mlpack
poisson_nll_loss_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTIONS_POISSON_NLL_LOSS_IMPL_HPP
13 #define MLPACK_METHODS_ANN_LOSS_FUNCTIONS_POISSON_NLL_LOSS_IMPL_HPP
14 
15 
16 // In case it hasn't yet been included.
17 #include "poisson_nll_loss.hpp"
18 
19 namespace mlpack {
20 namespace ann {
21 
22 template<typename InputDataType, typename OutputDataType>
24  const bool logInput,
25  const bool full,
26  const typename InputDataType::elem_type eps,
27  const bool mean):
28  logInput(logInput),
29  full(full),
30  eps(eps),
31  mean(mean)
32 {
33  Log::Assert(eps >= 0, "Epsilon (eps) must be greater than or equal to zero.");
34 }
35 
36 template<typename InputDataType, typename OutputDataType>
37 template<typename PredictionType, typename TargetType>
38 typename InputDataType::elem_type
40  const PredictionType& prediction,
41  const TargetType& target)
42 {
43  PredictionType loss(arma::size(prediction));
44 
45  if (logInput)
46  loss = arma::exp(prediction) - target % prediction;
47  else
48  {
49  CheckProbs(prediction);
50  loss = prediction - target % arma::log(prediction + eps);
51  }
52 
53  if (full)
54  {
55  const auto mask = target > 1.0;
56  const PredictionType approx = target % arma::log(target) - target
57  + 0.5 * arma::log(2 * M_PI * target);
58  loss.elem(arma::find(mask)) += approx.elem(arma::find(mask));
59  }
60 
61  return mean ? arma::accu(loss) / loss.n_elem : arma::accu(loss);
62 }
63 
64 template<typename InputDataType, typename OutputDataType>
65 template<typename PredictionType, typename TargetType, typename LossType>
67  const PredictionType& prediction,
68  const TargetType& target,
69  LossType& loss)
70 {
71  loss.set_size(size(prediction));
72 
73  if (logInput)
74  loss = (arma::exp(prediction) - target);
75  else
76  loss = (1 - target / (prediction + eps));
77 
78  if (mean)
79  loss = loss / loss.n_elem;
80 }
81 
82 template<typename InputDataType, typename OutputDataType>
83 template<typename Archive>
85  Archive& ar,
86  const uint32_t /* version */)
87 {
88  ar(CEREAL_NVP(logInput));
89  ar(CEREAL_NVP(full));
90  ar(CEREAL_NVP(eps));
91  ar(CEREAL_NVP(mean));
92 }
93 
94 } // namespace ann
95 } // namespace mlpack
96 
97 #endif
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: poisson_nll_loss_impl.hpp:84
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
InputDataType::elem_type Forward(const PredictionType &prediction, const TargetType &target)
Computes the Poisson negative log likelihood Loss.
Definition: poisson_nll_loss_impl.hpp:39
PoissonNLLLoss(const bool logInput=true, const bool full=false, const typename InputDataType::elem_type eps=1e-08, const bool mean=true)
Create the PoissonNLLLoss object.
Definition: poisson_nll_loss_impl.hpp:23
void Backward(const PredictionType &prediction, const TargetType &target, LossType &loss)
Ordinary feed backward pass of a neural network.
Definition: poisson_nll_loss_impl.hpp:66
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38