mlpack
reparametrization_impl.hpp
/**
 * @file reparametrization_impl.hpp
 *
 * Implementation of the Reparametrization layer class.
 */
#ifndef MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_IMPL_HPP
#define MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_IMPL_HPP

// In case it hasn't yet been included.
#include "reparametrization.hpp"

namespace mlpack {
namespace ann {

template<typename InputDataType, typename OutputDataType>
Reparametrization<InputDataType, OutputDataType>::Reparametrization() :
    latentSize(0),
    stochastic(true),
    includeKl(true),
    beta(1)
{
  // Nothing to do here.
}

template <typename InputDataType, typename OutputDataType>
Reparametrization<InputDataType, OutputDataType>::Reparametrization(
    const size_t latentSize,
    const bool stochastic,
    const bool includeKl,
    const double beta) :
    latentSize(latentSize),
    stochastic(stochastic),
    includeKl(includeKl),
    beta(beta)
{
  if (includeKl == false && beta != 1)
  {
    Log::Info << "The beta parameter will be ignored as KL divergence is not "
        << "included." << std::endl;
  }
}
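A short aside on beta, offered as an interpretation rather than something stated in this file: Backward() below scales the KL gradient terms by beta, so the parameter plays the usual role of a weight on the KL term of the variational objective,

\mathcal{L} = \mathcal{L}_{\text{reconstruction}} + \beta \, D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, p(z)\big),

which is why it has no effect, and only triggers the informational message above, when the KL contribution is excluded.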
template <typename InputDataType, typename OutputDataType>
Reparametrization<InputDataType, OutputDataType>::Reparametrization(
    const Reparametrization& layer) :
    latentSize(layer.latentSize),
    stochastic(layer.stochastic),
    includeKl(layer.includeKl),
    beta(layer.beta)
{
  // Nothing to do here.
}

template <typename InputDataType, typename OutputDataType>
Reparametrization<InputDataType, OutputDataType>::Reparametrization(
    Reparametrization&& layer) :
    latentSize(std::move(layer.latentSize)),
    stochastic(std::move(layer.stochastic)),
    includeKl(std::move(layer.includeKl)),
    beta(std::move(layer.beta))
{
  // Nothing to do here.
}

template <typename InputDataType, typename OutputDataType>
Reparametrization<InputDataType, OutputDataType>&
Reparametrization<InputDataType, OutputDataType>::
operator=(const Reparametrization& layer)
{
  if (this != &layer)
  {
    latentSize = layer.latentSize;
    stochastic = layer.stochastic;
    includeKl = layer.includeKl;
    beta = layer.beta;
  }
  return *this;
}

template <typename InputDataType, typename OutputDataType>
Reparametrization<InputDataType, OutputDataType>&
Reparametrization<InputDataType, OutputDataType>::
operator=(Reparametrization&& layer)
{
  if (this != &layer)
  {
    latentSize = std::move(layer.latentSize);
    stochastic = std::move(layer.stochastic);
    includeKl = std::move(layer.includeKl);
    beta = std::move(layer.beta);
  }
  return *this;
}

template<typename InputDataType, typename OutputDataType>
template<typename eT>
void Reparametrization<InputDataType, OutputDataType>::Forward(
    const arma::Mat<eT>& input, arma::Mat<eT>& output)
{
  if (input.n_rows != 2 * latentSize)
  {
    Log::Fatal << "The output size of layer before the Reparametrization "
        << "layer should be 2 * latent size of the Reparametrization layer!"
        << std::endl;
  }

  mean = input.submat(latentSize, 0, 2 * latentSize - 1, input.n_cols - 1);
  preStdDev = input.submat(0, 0, latentSize - 1, input.n_cols - 1);

  if (stochastic)
    gaussianSample = arma::randn<arma::Mat<eT> >(latentSize, input.n_cols);
  else
    gaussianSample = arma::ones<arma::Mat<eT> >(latentSize, input.n_cols) * 0.7;

  SoftplusFunction::Fn(preStdDev, stdDev);
  output = mean + stdDev % gaussianSample;
}
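Forward() above is the reparametrization trick in its usual form: the input is split into pre-standard-deviations (the first latentSize rows) and means (the remaining rows), the pre-standard-deviations are mapped through the softplus, and a latent sample is formed as

\sigma = \operatorname{softplus}(\rho) = \log(1 + e^{\rho}), \qquad
z = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I),

where epsilon is drawn with arma::randn when stochastic is true and is pinned to the constant 0.7 otherwise, making the pass deterministic.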
template<typename InputDataType, typename OutputDataType>
template<typename eT>
void Reparametrization<InputDataType, OutputDataType>::Backward(
    const arma::Mat<eT>& /* input */, const arma::Mat<eT>& gy, arma::Mat<eT>& g)
{
  SoftplusFunction::Deriv(preStdDev, g);

  if (includeKl)
  {
    g = join_cols(gy % std::move(gaussianSample) % g + (-1 / stdDev + stdDev)
        % g * beta, gy + mean * beta / mean.n_cols);
  }
  else
    g = join_cols(gy % std::move(gaussianSample) % g, gy);
}
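The extra terms used when includeKl is true are, up to the chain rule through the softplus derivative g and the batch averaging visible in the code, the gradients of the closed-form KL divergence between the approximate posterior and a standard normal prior; this is a sketch under the assumption that \mathcal{N}(0, I) is indeed the prior:

D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\big)
    = \tfrac{1}{2} \sum_i \big(\sigma_i^2 + \mu_i^2 - 1 - \log \sigma_i^2\big),
\qquad
\frac{\partial}{\partial \sigma_i} = \sigma_i - \frac{1}{\sigma_i},
\qquad
\frac{\partial}{\partial \mu_i} = \mu_i.

The (-1 / stdDev + stdDev) % g term is the sigma-gradient pushed back through the softplus and weighted by beta, and mean * beta / mean.n_cols is the mu-gradient averaged over the batch.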
template<typename InputDataType, typename OutputDataType>
template<typename Archive>
void Reparametrization<InputDataType, OutputDataType>::serialize(
    Archive& ar, const uint32_t /* version */)
{
  ar(CEREAL_NVP(latentSize));
  ar(CEREAL_NVP(stochastic));
  ar(CEREAL_NVP(includeKl));
}
} // namespace ann
} // namespace mlpack

#endif
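Finally, a minimal sketch of exercising the layer on its own. The constructor and Forward() signatures come from this file; the include paths, the default template arguments of Reparametrization<>, and the sizes and data are illustrative assumptions rather than a canonical mlpack example.

#include <mlpack/core.hpp>
#include <mlpack/methods/ann/layer/reparametrization.hpp>

using namespace mlpack::ann;

int main()
{
  const size_t latentSize = 4;
  const size_t batchSize = 10;

  // Output of a hypothetical previous layer: 2 * latentSize rows per column,
  // holding the pre-standard-deviations on top of the means, as Forward()
  // expects.
  arma::mat input(2 * latentSize, batchSize, arma::fill::randn);
  arma::mat output;

  // Stochastic sampling, KL divergence included, default beta weight.
  Reparametrization<> layer(latentSize, true, true, 1.0);
  layer.Forward(input, output);  // output is latentSize x batchSize.

  return 0;
}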