13 #ifndef MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_IMPL_HPP 14 #define MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_IMPL_HPP 22 template<
typename InputDataType,
typename OutputDataType>
32 template <
typename InputDataType,
typename OutputDataType>
34 const size_t latentSize,
35 const bool stochastic,
38 latentSize(latentSize),
39 stochastic(stochastic),
43 if (includeKl ==
false && beta != 1)
45 Log::Info <<
"The beta parameter will be ignored as KL divergence is not " 46 <<
"included." << std::endl;
50 template <
typename InputDataType,
typename OutputDataType>
53 latentSize(layer.latentSize),
54 stochastic(layer.stochastic),
55 includeKl(layer.includeKl),
61 template <
typename InputDataType,
typename OutputDataType>
64 latentSize(
std::move(layer.latentSize)),
65 stochastic(
std::move(layer.stochastic)),
66 includeKl(
std::move(layer.includeKl)),
67 beta(
std::move(layer.beta))
72 template <
typename InputDataType,
typename OutputDataType>
79 latentSize = layer.latentSize;
80 stochastic = layer.stochastic;
81 includeKl = layer.includeKl;
87 template <
typename InputDataType,
typename OutputDataType>
94 latentSize = std::move(layer.latentSize);
95 stochastic = std::move(layer.stochastic);
96 includeKl = std::move(layer.includeKl);
97 beta = std::move(layer.beta);
103 template<
typename InputDataType,
typename OutputDataType>
104 template<
typename eT>
106 const arma::Mat<eT>& input, arma::Mat<eT>& output)
108 if (input.n_rows != 2 * latentSize)
110 Log::Fatal <<
"The output size of layer before the Reparametrization " 111 <<
"layer should be 2 * latent size of the Reparametrization layer!" 115 mean = input.submat(latentSize, 0, 2 * latentSize - 1, input.n_cols - 1);
116 preStdDev = input.submat(0, 0, latentSize - 1, input.n_cols - 1);
119 gaussianSample = arma::randn<arma::Mat<eT> >(latentSize, input.n_cols);
121 gaussianSample = arma::ones<arma::Mat<eT> >(latentSize, input.n_cols) * 0.7;
124 output = mean + stdDev % gaussianSample;
127 template<
typename InputDataType,
typename OutputDataType>
128 template<
typename eT>
130 const arma::Mat<eT>& ,
const arma::Mat<eT>& gy, arma::Mat<eT>& g)
136 g = join_cols(gy % std::move(gaussianSample) % g + (-1 / stdDev + stdDev)
137 % g * beta, gy + mean * beta / mean.n_cols);
140 g = join_cols(gy % std::move(gaussianSample) % g, gy);
143 template<
typename InputDataType,
typename OutputDataType>
144 template<
typename Archive>
146 Archive& ar,
const uint32_t )
148 ar(CEREAL_NVP(latentSize));
149 ar(CEREAL_NVP(stochastic));
150 ar(CEREAL_NVP(includeKl));
static double Deriv(const double y)
Computes the first derivative of the softplus function.
Definition: softplus_function.hpp:81
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: reparametrization_impl.hpp:145
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
Definition: reparametrization_impl.hpp:129
Definition: pointer_wrapper.hpp:23
Implementation of the Reparametrization layer class.
Definition: layer_types.hpp:132
Reparametrization()
Create the Reparametrization object.
Definition: reparametrization_impl.hpp:23
Reparametrization & operator=(const Reparametrization &layer)
Copy assignment operator.
Definition: reparametrization_impl.hpp:75
static double Fn(const double x)
Computes the softplus function.
Definition: softplus_function.hpp:52
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: reparametrization_impl.hpp:105
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84