mlpack
kl_divergence_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTION_KL_DIVERGENCE_IMPL_HPP
14 #define MLPACK_METHODS_ANN_LOSS_FUNCTION_KL_DIVERGENCE_IMPL_HPP
15 
16 // In case it hasn't yet been included.
17 #include "kl_divergence.hpp"
18 
19 namespace mlpack {
20 namespace ann {
21 
22 template<typename InputDataType, typename OutputDataType>
24  takeMean(takeMean)
25 {
26  // Nothing to do here.
27 }
28 
29 template<typename InputDataType, typename OutputDataType>
30 template<typename PredictionType, typename TargetType>
31 typename PredictionType::elem_type
33  const PredictionType& prediction,
34  const TargetType& target)
35 {
36  if (takeMean)
37  {
38  return arma::as_scalar(arma::mean(
39  arma::mean(prediction % (arma::log(prediction) - arma::log(target)))));
40  }
41  else
42  {
43  return arma::accu(prediction % (arma::log(prediction) - arma::log(target)));
44  }
45 }
46 
47 template<typename InputDataType, typename OutputDataType>
48 template<typename PredictionType, typename TargetType, typename LossType>
50  const PredictionType& prediction,
51  const TargetType& target,
52  LossType& loss)
53 {
54  if (takeMean)
55  {
56  loss = arma::mean(arma::mean(
57  arma::log(prediction) - arma::log(target) + 1));
58  }
59  else
60  {
61  loss = arma::accu(arma::log(prediction) - arma::log(target) + 1);
62  }
63 }
64 
65 template<typename InputDataType, typename OutputDataType>
66 template<typename Archive>
68  Archive& ar,
69  const uint32_t /* version */)
70 {
71  ar(CEREAL_NVP(takeMean));
72 }
73 
74 } // namespace ann
75 } // namespace mlpack
76 
77 #endif
void serialize(Archive &ar, const uint32_t)
Serialize the loss function.
Definition: kl_divergence_impl.hpp:67
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
KLDivergence(const bool takeMean=false)
Create the Kullback–Leibler Divergence object with the specified parameters.
Definition: kl_divergence_impl.hpp:23
PredictionType::elem_type Forward(const PredictionType &prediction, const TargetType &target)
Computes the Kullback–Leibler divergence error function.
Definition: kl_divergence_impl.hpp:32
void Backward(const PredictionType &prediction, const TargetType &target, LossType &loss)
Ordinary feed backward pass of a neural network.
Definition: kl_divergence_impl.hpp:49