mlpack
normal_distribution_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_DISTRIBUTIONS_NORMAL_DISTRIBUTION_IMPL_HPP
14 #define MLPACK_METHODS_ANN_DISTRIBUTIONS_NORMAL_DISTRIBUTION_IMPL_HPP
15 
16 // In case it hasn't yet been included.
17 #include "normal_distribution.hpp"
18 
19 namespace mlpack {
20 namespace ann {
21 
22 template<typename DataType>
24 {
25  // Nothing to do here.
26 }
27 
28 template<typename DataType>
30  const DataType& mean,
31  const DataType& sigma) :
32  mean(mean),
33  sigma(sigma)
34 {
35  // Nothing to do here.
36 }
37 
38 template<typename DataType>
40 {
41  return sigma * arma::randn<DataType>(mean.n_elem) + mean;
42 }
43 
44 template<typename DataType>
46  const DataType& observation) const
47 {
48  const DataType v1 = arma::log(sigma) + std::log(std::sqrt(2 * M_PI));
49  const DataType v2 = arma::square(observation - mean) /
50  (2 * arma::square(sigma));
51  return (-v1 - v2);
52 }
53 
54 template<typename DataType>
56  const DataType& observation,
57  DataType& dmu,
58  DataType& dsigma) const
59 {
60  dmu = (observation - mean) / (arma::square(sigma)) % Probability(observation);
61  dsigma = (- 1.0 / sigma +
62  (arma::square(observation - mean) / arma::pow(sigma, 3)))
63  % Probability(observation);
64 }
65 
66 template<typename DataType>
67 template<typename Archive>
69  const uint32_t /* version */)
70 {
71  ar(CEREAL_NVP(mean));
72  ar(CEREAL_NVP(sigma));
73 }
74 
75 } // namespace ann
76 } // namespace mlpack
77 
78 #endif
DataType Sample() const
Return a randomly generated observation according to the probability distribution defined by this object.
Definition: normal_distribution_impl.hpp:39
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
DataType Probability(const DataType &observation) const
Return the probabilities of the given matrix of observations.
Definition: normal_distribution.hpp:54
void serialize(Archive &ar, const uint32_t)
Serialize the distribution.
Definition: normal_distribution_impl.hpp:68
NormalDistribution()
Default constructor, which creates a Normal distribution with zero dimension.
Definition: normal_distribution_impl.hpp:23
DataType LogProbability(const DataType &observation) const
Return the log probabilities of the given matrix of observations.
Definition: normal_distribution_impl.hpp:45
void ProbBackward(const DataType &observation, DataType &dmu, DataType &dsigma) const
Stores the gradient of the probabilities of the observations with respect to mean and standard deviation.
Definition: normal_distribution_impl.hpp:55