/**
 * @file reparametrization.hpp
 *
 * Declaration of the mlpack::ann::Reparametrization layer class (mlpack).
 */
13 #ifndef MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP
14 #define MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 #include "layer_types.hpp"
19 #include "../activation_functions/softplus_function.hpp"
20 
21 namespace mlpack {
22 namespace ann {
23 
52 template <
53  typename InputDataType = arma::mat,
54  typename OutputDataType = arma::mat
55 >
56 class Reparametrization
57 {
58  public:
61 
70  Reparametrization(const size_t latentSize,
71  const bool stochastic = true,
72  const bool includeKl = true,
73  const double beta = 1);
74 
77 
80 
83 
86 
94  template<typename eT>
95  void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
96 
106  template<typename eT>
107  void Backward(const arma::Mat<eT>& input,
108  const arma::Mat<eT>& gy,
109  arma::Mat<eT>& g);
110 
112  OutputDataType const& OutputParameter() const { return outputParameter; }
114  OutputDataType& OutputParameter() { return outputParameter; }
115 
117  OutputDataType const& Delta() const { return delta; }
119  OutputDataType& Delta() { return delta; }
120 
122  size_t const& OutputSize() const { return latentSize; }
124  size_t& OutputSize() { return latentSize; }
125 
127  double Loss()
128  {
129  if (!includeKl)
130  return 0;
131 
132  return -0.5 * beta * arma::accu(2 * arma::log(stdDev) - arma::pow(stdDev, 2)
133  - arma::pow(mean, 2) + 1) / mean.n_cols;
134  }
135 
137  bool Stochastic() const { return stochastic; }
138 
140  bool IncludeKL() const { return includeKl; }
141 
143  double Beta() const { return beta; }
144 
145  size_t InputShape() const
146  {
147  return 2 * latentSize;
148  }
149 
153  template<typename Archive>
154  void serialize(Archive& ar, const uint32_t /* version */);
155 
156  private:
158  size_t latentSize;
159 
161  bool stochastic;
162 
164  bool includeKl;
165 
167  double beta;
168 
170  OutputDataType delta;
171 
173  OutputDataType gaussianSample;
174 
176  OutputDataType mean;
177 
180  OutputDataType preStdDev;
181 
183  OutputDataType stdDev;
184 
186  OutputDataType outputParameter;
187 }; // class Reparametrization
188 
189 } // namespace ann
190 } // namespace mlpack
191 
192 // Include implementation.
194 
195 #endif
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: reparametrization_impl.hpp:145
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
OutputDataType & Delta()
Modify the delta.
Definition: reparametrization.hpp:119
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Definition: reparametrization_impl.hpp:129
The core includes that mlpack expects; standard C++ includes and Armadillo.
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: reparametrization.hpp:114
bool IncludeKL() const
Get the value of the includeKl parameter.
Definition: reparametrization.hpp:140
Reparametrization()
Create the Reparametrization object.
Definition: reparametrization_impl.hpp:23
OutputDataType const & Delta() const
Get the delta.
Definition: reparametrization.hpp:117
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: reparametrization.hpp:112
double Loss()
Get the KL divergence with standard normal.
Definition: reparametrization.hpp:127
Reparametrization & operator=(const Reparametrization &layer)
Copy assignment operator.
Definition: reparametrization_impl.hpp:75
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: reparametrization_impl.hpp:105
size_t & OutputSize()
Modify the output size.
Definition: reparametrization.hpp:124
size_t const & OutputSize() const
Get the output size.
Definition: reparametrization.hpp:122
bool Stochastic() const
Get the value of the stochastic parameter.
Definition: reparametrization.hpp:137
double Beta() const
Get the value of the beta hyperparameter.
Definition: reparametrization.hpp:143