13 #ifndef MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP 14 #define MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP 19 #include "../activation_functions/softplus_function.hpp" 53 typename InputDataType = arma::mat,
54 typename OutputDataType = arma::mat
56 class Reparametrization
71 const bool stochastic =
true,
72 const bool includeKl =
true,
73 const double beta = 1);
95 void Forward(
const arma::Mat<eT>& input, arma::Mat<eT>& output);
106 template<
typename eT>
107 void Backward(
const arma::Mat<eT>& input,
108 const arma::Mat<eT>& gy,
117 OutputDataType
const&
Delta()
const {
return delta; }
119 OutputDataType&
Delta() {
return delta; }
132 return -0.5 * beta * arma::accu(2 * arma::log(stdDev) - arma::pow(stdDev, 2)
133 - arma::pow(mean, 2) + 1) / mean.n_cols;
143 double Beta()
const {
return beta; }
145 size_t InputShape()
const 147 return 2 * latentSize;
153 template<
typename Archive>
154 void serialize(Archive& ar,
const uint32_t );
170 OutputDataType delta;
173 OutputDataType gaussianSample;
180 OutputDataType preStdDev;
183 OutputDataType stdDev;
186 OutputDataType outputParameter;
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: reparametrization_impl.hpp:145
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
OutputDataType & Delta()
Modify the delta.
Definition: reparametrization.hpp:119
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Definition: reparametrization_impl.hpp:129
The core includes that mlpack expects: standard C++ includes and Armadillo.
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: reparametrization.hpp:114
bool IncludeKL() const
Get the value of the includeKl parameter.
Definition: reparametrization.hpp:140
Reparametrization()
Create the Reparametrization object.
Definition: reparametrization_impl.hpp:23
OutputDataType const & Delta() const
Get the delta.
Definition: reparametrization.hpp:117
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: reparametrization.hpp:112
double Loss()
Get the KL divergence with the standard normal distribution.
Definition: reparametrization.hpp:127
Reparametrization & operator=(const Reparametrization &layer)
Copy assignment operator.
Definition: reparametrization_impl.hpp:75
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: reparametrization_impl.hpp:105
size_t & OutputSize()
Modify the output size.
Definition: reparametrization.hpp:124
size_t const & OutputSize() const
Get the output size.
Definition: reparametrization.hpp:122
bool Stochastic() const
Get the value of the stochastic parameter.
Definition: reparametrization.hpp:137
double Beta() const
Get the value of the beta hyperparameter.
Definition: reparametrization.hpp:143