// NOTE(review): this text is a fragmentary extraction of
// bernoulli_distribution_impl.hpp. The leading numbers ("12", "13", "21", ...)
// are original source line numbers fused into the text, and several lines of
// every definition below are missing, so none of it compiles as shown.
// Include guard for the implementation header, followed by the template
// header of (presumably) the default constructor BernoulliDistribution();
// that constructor's signature and body are not visible here.
12 #ifndef MLPACK_METHODS_ANN_DISTRIBUTIONS_BERNOULLI_DISTRIBUTION_IMPL_HPP 13 #define MLPACK_METHODS_ANN_DISTRIBUTIONS_BERNOULLI_DISTRIBUTION_IMPL_HPP 21 template<
typename DataType>
// Fragment of what appears to be the parameterized constructor. The opening
// of the signature, any further parameters, and most of the body are missing
// from this view.
29 template<
typename DataType>
// Distribution parameters; presumably logits or probabilities depending on
// applyLogistic -- TODO confirm against the missing body.
31 const DataType& param,
// Stored as-is in the member of the same name (see the initializer below).
32 const bool applyLogistic,
35 applyLogistic(applyLogistic),
// Builds probability from logits' memory using Armadillo's auxiliary-memory
// constructor (copy_aux_mem = false, strict = false). The branch condition
// guarding this line is not visible; presumably it runs when the logistic
// transform is not applied -- confirm.
// NOTE(review): arma::mat is hard-coded although the class is templated on
// DataType; if DataType's element type is not double (e.g. arma::fmat),
// logits.memptr() will not convert to double*. Consider DataType(...) here.
44 probability = arma::mat(logits.memptr(), logits.n_rows,
45 logits.n_cols,
false,
false);
// Fragment of Sample() ("DataType Sample() const" per the tooltip below):
// draws one Bernoulli sample per element of the probability matrix. The
// signature line and the trailing return statement are missing from this view.
49 template<
typename DataType>
// Uniform random draws in [0, 1] with the same shape as probability.
52 DataType sample = arma::randu<DataType>
53 (probability.n_rows, probability.n_cols);
// Threshold each draw against its element's probability: the comparison
// yields 1 when u < p and 0 otherwise, i.e. a 1 with probability p.
55 for (
size_t i = 0; i < sample.n_elem; ++i)
56 sample(i) = sample(i) < probability(i);
// Fragment of LogProbability(const DataType& observation) const: the Bernoulli
// log-likelihood of observation under the probability matrix. The signature
// opening and the divisor after the trailing "/" are missing from this view.
61 template<
typename DataType>
63 const DataType& observation)
// Sums obs * log(p) + (1 - obs) * log(1 - p) over all elements; eps is added
// inside each log() to avoid log(0).
const 65 return arma::accu(arma::log(probability + eps) % observation +
66 arma::log(1 - probability + eps) % (1 - observation)) /
// Fragment of LogProbBackward(const DataType& observation, DataType& output)
// const: writes the gradient of the Bernoulli log-likelihood into output.
// The signature opening and the lines between the two assignments (including
// whatever condition selects between them) are missing from this view.
70 template<
typename DataType>
72 const DataType& observation, DataType& output)
// d/dp [obs * log(p) + (1 - obs) * log(1 - p)] = obs / p - (1 - obs) / (1 - p),
// with eps added to both denominators to avoid division by zero.
const 76 output = observation / (probability + eps) - (1 - observation) /
77 (1 - probability + eps);
// Same gradient, multiplied element-wise (%) by the current contents of output.
// NOTE(review): presumably output holds the logistic derivative at this point
// (chain rule through the logistic transform) -- confirm against the missing
// lines.
82 output = (observation / (probability + eps) - (1 - observation) /
83 (1 - probability + eps)) % output;
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static double Deriv(const double x)
Computes the first derivative of the logistic function.
Definition: logistic_function.hpp:70
static double Fn(const eT x)
Computes the logistic function.
Definition: logistic_function.hpp:39
DataType Sample() const
Return a matrix of randomly generated samples according to the per-element probabilities held by this distribution.
Definition: bernoulli_distribution_impl.hpp:50
void LogProbBackward(const DataType &observation, DataType &output) const
Stores the gradient of the log probabilities of the observations in the output matrix.
Definition: bernoulli_distribution_impl.hpp:71
BernoulliDistribution()
Default constructor, which creates a Bernoulli distribution with zero dimension.
Definition: bernoulli_distribution_impl.hpp:22
double LogProbability(const DataType &observation) const
Return the log probabilities of the given matrix of observations.
Definition: bernoulli_distribution_impl.hpp:62