mlpack
reinforce_normal_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_REINFORCE_NORMAL_IMPL_HPP
14 #define MLPACK_METHODS_ANN_LAYER_REINFORCE_NORMAL_IMPL_HPP
15 
16 // In case it hasn't yet been included.
17 #include "reinforce_normal.hpp"
18 
19 namespace mlpack {
20 namespace ann {
21 
22 template<typename InputDataType, typename OutputDataType>
24  const double stdev) : stdev(stdev), reward(0.0), deterministic(false)
25 {
26  // Nothing to do here.
27 }
28 
29 template<typename InputDataType, typename OutputDataType>
30 template<typename eT>
32  const arma::Mat<eT>& input, arma::Mat<eT>& output)
33 {
34  if (!deterministic)
35  {
36  // Multiply by standard deviations and re-center the means to the mean.
37  output = output.randn(input.n_rows, input.n_cols) * stdev + input;
38 
39  moduleInputParameter.push_back(input);
40  }
41  else
42  {
43  // Use maximum a posteriori.
44  output = input;
45  }
46 }
47 
48 template<typename InputDataType, typename OutputDataType>
49 template<typename DataType>
51  const DataType& input, const DataType& /* gy */, DataType& g)
52 {
53  g = (input - moduleInputParameter.back()) / std::pow(stdev, 2.0);
54 
55  // Multiply by reward and multiply by -1.
56  g *= reward;
57  g *= -1;
58 
59  moduleInputParameter.pop_back();
60 }
61 
62 template<typename InputDataType, typename OutputDataType>
63 template<typename Archive>
65  Archive& ar, const uint32_t /* version */)
66 {
67  ar(CEREAL_NVP(stdev));
68 }
69 
70 } // namespace ann
71 } // namespace mlpack
72 
73 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: reinforce_normal_impl.hpp:31
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: reinforce_normal_impl.hpp:64
void Backward(const DataType &input, const DataType &, DataType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Definition: reinforce_normal_impl.hpp:50
ReinforceNormal(const double stdev=1.0)
Create the ReinforceNormal object.
Definition: reinforce_normal_impl.hpp:23