// NOTE(review): this listing is fragmentary. Source lines are missing here,
// e.g. the constructor's declarator line and its (presumably empty) body.
// Restore them from the original source before compiling.
//
// ReinforceNormal constructor: stores the given standard deviation `stdev`
// (the declaration defaults it to 1.0), and starts with reward = 0.0 and
// deterministic = false.
13 #ifndef MLPACK_METHODS_ANN_LAYER_REINFORCE_NORMAL_IMPL_HPP 14 #define MLPACK_METHODS_ANN_LAYER_REINFORCE_NORMAL_IMPL_HPP 22 template<
typename InputDataType,
typename OutputDataType>
24 const double stdev) : stdev(stdev), reward(0.0), deterministic(false)
// Forward pass: draws one sample per element of `input`. It scales
// standard-normal noise (Armadillo randn, N(0, 1)) by `stdev` and adds
// `input`, so output(i, j) ~ N(input(i, j), stdev^2).
// NOTE(review): lines are missing from this listing (e.g. the
// `template<typename eT>` line and the function declarator). Whether the
// sampling below is guarded by `deterministic` cannot be seen here; confirm
// against the original source.
29 template<
typename InputDataType,
typename OutputDataType>
32 const arma::Mat<eT>& input, arma::Mat<eT>& output)
// randn(rows, cols) resizes `output` to the shape of `input` before filling it.
37 output = output.randn(input.n_rows, input.n_cols) * stdev + input;
// Remember the distribution mean (the input) for the matching Backward call,
// which reads it with back() and then pops it.
39 moduleInputParameter.push_back(input);
// Backward pass: computes g = (input - mu) / stdev^2, where mu is the most
// recently stored Forward input. It then discards that stored value.
// Presumably `input` here is the sample produced by Forward. In that case
// this expression is d/dmu of log N(input; mu, stdev^2). TODO: confirm with
// the caller.
// NOTE(review): lines are missing from this listing (e.g. the declarator and
// whatever sits between the gradient line and pop_back). Any scaling of `g`
// by `reward` is not visible here.
// NOTE(review): back() is only valid if a matching Forward call pushed an
// entry. Calling Backward without one (or more times than Forward pushed)
// presumably reads an empty container, which is undefined behavior for
// std::vector. Verify the container type and the call pattern.
48 template<
typename InputDataType,
typename OutputDataType>
49 template<
typename DataType>
51 const DataType& input,
// The second argument is unnamed and unused in the visible code.
const DataType& , DataType& g)
// Divide by the variance (stdev squared).
53 g = (input - moduleInputParameter.back()) / std::pow(stdev, 2.0);
// Drop the mean consumed by this Backward call.
59 moduleInputParameter.pop_back();
// Serializes the layer. Only `stdev` is archived in the visible code; `reward`
// and `deterministic` are not written by this function.
// NOTE(review): the declarator line and braces are missing from this listing.
62 template<
typename InputDataType,
typename OutputDataType>
63 template<
typename Archive>
65 Archive& ar,
// The version argument is unnamed and unused.
const uint32_t )
67 ar(CEREAL_NVP(stdev));
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: reinforce_normal_impl.hpp:31
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: reinforce_normal_impl.hpp:64
void Backward(const DataType &input, const DataType &, DataType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
Definition: reinforce_normal_impl.hpp:50
ReinforceNormal(const double stdev=1.0)
Create the ReinforceNormal object.
Definition: reinforce_normal_impl.hpp:23