mlpack
bernoulli_distribution_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_DISTRIBUTIONS_BERNOULLI_DISTRIBUTION_IMPL_HPP
13 #define MLPACK_METHODS_ANN_DISTRIBUTIONS_BERNOULLI_DISTRIBUTION_IMPL_HPP
14 
15 // In case it hasn't yet been included.
17 
18 namespace mlpack {
19 namespace ann {
20 
21 template<typename DataType>
23  applyLogistic(true),
24  eps(1e-10)
25 {
26  // Nothing to do here.
27 }
28 
29 template<typename DataType>
31  const DataType& param,
32  const bool applyLogistic,
33  const double eps) :
34  logits(param),
35  applyLogistic(applyLogistic),
36  eps(eps)
37 {
38  if (applyLogistic)
39  {
40  LogisticFunction::Fn(logits, probability);
41  }
42  else
43  {
44  probability = arma::mat(logits.memptr(), logits.n_rows,
45  logits.n_cols, false, false);
46  }
47 }
48 
49 template<typename DataType>
51 {
52  DataType sample = arma::randu<DataType>
53  (probability.n_rows, probability.n_cols);
54 
55  for (size_t i = 0; i < sample.n_elem; ++i)
56  sample(i) = sample(i) < probability(i);
57 
58  return sample;
59 }
60 
61 template<typename DataType>
63  const DataType& observation) const
64 {
65  return arma::accu(arma::log(probability + eps) % observation +
66  arma::log(1 - probability + eps) % (1 - observation)) /
67  observation.n_cols;
68 }
69 
70 template<typename DataType>
72  const DataType& observation, DataType& output) const
73 {
74  if (!applyLogistic)
75  {
76  output = observation / (probability + eps) - (1 - observation) /
77  (1 - probability + eps);
78  }
79  else
80  {
81  LogisticFunction::Deriv(probability, output);
82  output = (observation / (probability + eps) - (1 - observation) /
83  (1 - probability + eps)) % output;
84  }
85 }
86 
87 } // namespace ann
88 } // namespace mlpack
89 
90 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static double Deriv(const double x)
Computes the first derivative of the logistic function.
Definition: logistic_function.hpp:70
static double Fn(const eT x)
Computes the logistic function.
Definition: logistic_function.hpp:39
DataType Sample() const
Return a matrix of randomly generated samples according to the probability distributions defined by this object.
Definition: bernoulli_distribution_impl.hpp:50
void LogProbBackward(const DataType &observation, DataType &output) const
Stores the gradient of the log probabilities of the observations in the output matrix.
Definition: bernoulli_distribution_impl.hpp:71
BernoulliDistribution()
Default constructor, which creates a Bernoulli distribution with zero dimension.
Definition: bernoulli_distribution_impl.hpp:22
double LogProbability(const DataType &observation) const
Return the log probabilities of the given matrix of observations.
Definition: bernoulli_distribution_impl.hpp:62