12 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTIONS_POISSON_NLL_LOSS_IMPL_HPP 13 #define MLPACK_METHODS_ANN_LOSS_FUNCTIONS_POISSON_NLL_LOSS_IMPL_HPP 22 template<
typename InputDataType,
typename OutputDataType>
26 const typename InputDataType::elem_type eps,
33 Log::Assert(eps >= 0,
"Epsilon (eps) must be greater than or equal to zero.");
36 template<
typename InputDataType,
typename OutputDataType>
37 template<
typename PredictionType,
typename TargetType>
38 typename InputDataType::elem_type
40 const PredictionType& prediction,
41 const TargetType& target)
43 PredictionType loss(arma::size(prediction));
46 loss = arma::exp(prediction) - target % prediction;
49 CheckProbs(prediction);
50 loss = prediction - target % arma::log(prediction + eps);
55 const auto mask = target > 1.0;
56 const PredictionType approx = target % arma::log(target) - target
57 + 0.5 * arma::log(2 * M_PI * target);
58 loss.elem(arma::find(mask)) += approx.elem(arma::find(mask));
61 return mean ? arma::accu(loss) / loss.n_elem : arma::accu(loss);
64 template<
typename InputDataType,
typename OutputDataType>
65 template<
typename PredictionType,
typename TargetType,
typename LossType>
67 const PredictionType& prediction,
68 const TargetType& target,
71 loss.set_size(size(prediction));
74 loss = (arma::exp(prediction) - target);
76 loss = (1 - target / (prediction + eps));
79 loss = loss / loss.n_elem;
82 template<
typename InputDataType,
typename OutputDataType>
83 template<
typename Archive>
88 ar(CEREAL_NVP(logInput));
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: poisson_nll_loss_impl.hpp:84
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
InputDataType::elem_type Forward(const PredictionType &prediction, const TargetType &target)
Computes the Poisson negative log-likelihood loss.
Definition: poisson_nll_loss_impl.hpp:39
PoissonNLLLoss(const bool logInput=true, const bool full=false, const typename InputDataType::elem_type eps=1e-08, const bool mean=true)
Create the PoissonNLLLoss object.
Definition: poisson_nll_loss_impl.hpp:23
void Backward(const PredictionType &prediction, const TargetType &target, LossType &loss)
Ordinary feed-backward pass of a neural network; writes the gradient of the loss with respect to the prediction into `loss`.
Definition: poisson_nll_loss_impl.hpp:66
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38