12 #ifndef MLPACK_METHODS_ANN_DISTRIBUTIONS_BERNOULLI_DISTRIBUTION_HPP 13 #define MLPACK_METHODS_ANN_DISTRIBUTIONS_BERNOULLI_DISTRIBUTION_HPP 16 #include "../activation_functions/logistic_function.hpp" 33 template <
typename DataType = arma::mat>
64 const bool applyLogistic =
true,
65 const double eps = 1e-10);
91 void LogProbBackward(
const DataType& observation, DataType& output)
const;
108 const DataType&
Logits()
const {
return logits; }
116 template<
typename Archive>
120 ar(CEREAL_NVP(probability));
121 ar(CEREAL_NVP(logits));
122 ar(CEREAL_NVP(applyLogistic));
128 DataType probability;
const DataType & Logits() const
Return the logits matrix.
Definition: bernoulli_distribution.hpp:108
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.
Multiple independent Bernoulli distributions.
Definition: bernoulli_distribution.hpp:34
double Probability(const DataType &observation) const
Return the probability of the given matrix of observations, as a single value.
Definition: bernoulli_distribution.hpp:72
DataType & Logits()
Return a modifiable reference to the logits (pre-probability) matrix.
Definition: bernoulli_distribution.hpp:111
const DataType & Probability() const
Return the probability matrix.
Definition: bernoulli_distribution.hpp:102
DataType & Probability()
Return a modifiable reference to the probability matrix.
Definition: bernoulli_distribution.hpp:105
DataType Sample() const
Return a matrix of randomly generated samples according to the probability distributions defined by this object.
Definition: bernoulli_distribution_impl.hpp:50
void LogProbBackward(const DataType &observation, DataType &output) const
Stores the gradient of the log probabilities of the observations in the output matrix.
Definition: bernoulli_distribution_impl.hpp:71
BernoulliDistribution()
Default constructor, which creates a Bernoulli distribution with zero dimension.
Definition: bernoulli_distribution_impl.hpp:22
double LogProbability(const DataType &observation) const
Return the log probability of the given matrix of observations, as a single value.
Definition: bernoulli_distribution_impl.hpp:62
void serialize(Archive &ar, const uint32_t)
Serialize the distribution.
Definition: bernoulli_distribution.hpp:117