mlpack
mlpack::ann::ReinforceNormal< InputDataType, OutputDataType > Class Template Reference

Implementation of the reinforce normal layer.

#include <reinforce_normal.hpp>

Public Member Functions

 ReinforceNormal (const double stdev=1.0)
 Create the ReinforceNormal object.
 
template<typename eT >
void Forward (const arma::Mat< eT > &input, arma::Mat< eT > &output)
 Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
 
template<typename DataType >
void Backward (const DataType &input, const DataType &, DataType &g)
 Ordinary feed backward pass of a neural network, calculating the gradient by propagating the error backwards through f.
 
OutputDataType & OutputParameter () const
 Get the output parameter.
 
OutputDataType & OutputParameter ()
 Modify the output parameter.
 
OutputDataType & Delta () const
 Get the delta.
 
OutputDataType & Delta ()
 Modify the delta.
 
bool Deterministic () const
 Get the value of the deterministic parameter.
 
bool & Deterministic ()
 Modify the value of the deterministic parameter.
 
double Reward () const
 Get the value of the reward parameter.
 
double & Reward ()
 Modify the value of the reward parameter.
 
double StandardDeviation () const
 Get the standard deviation used during the forward and backward passes.
 
template<typename Archive >
void serialize (Archive &ar, const uint32_t)
 Serialize the layer.
 

Detailed Description

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
class mlpack::ann::ReinforceNormal< InputDataType, OutputDataType >

Implementation of the reinforce normal layer.

The reinforce normal layer implements the REINFORCE algorithm for the normal distribution.
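As background, the standard REINFORCE (score-function) gradient for a normal distribution can be written as follows. This assumes that the layer input is the mean μ, that the constructor's stdev is a fixed standard deviation σ, and that R is the scalar set through Reward(). This page does not spell out the exact formula the layer uses, so treat this as a sketch of the technique rather than a transcription of the implementation. The last line is the gradient of the loss -R log p(y | μ) with respect to μ.

```latex
\begin{aligned}
y &\sim \mathcal{N}(\mu, \sigma^2) \\
\log p(y \mid \mu) &= -\frac{(y - \mu)^2}{2\sigma^2} - \log\left(\sigma\sqrt{2\pi}\right) \\
\frac{\partial}{\partial \mu} \log p(y \mid \mu) &= \frac{y - \mu}{\sigma^2} \\
g = \frac{\partial}{\partial \mu}\left[-R \log p(y \mid \mu)\right] &= -R\,\frac{y - \mu}{\sigma^2}
\end{aligned}
```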

Template Parameters
InputDataType: Type of the input data (arma::colvec, arma::mat, arma::sp_mat or arma::cube).
OutputDataType: Type of the output data (arma::colvec, arma::mat, arma::sp_mat or arma::cube).

Constructor & Destructor Documentation

◆ ReinforceNormal()

template<typename InputDataType , typename OutputDataType >
mlpack::ann::ReinforceNormal< InputDataType, OutputDataType >::ReinforceNormal (const double stdev = 1.0)

Create the ReinforceNormal object.

Parameters
stdev: Standard deviation used during the forward and backward passes.

Member Function Documentation

◆ Backward()

template<typename InputDataType , typename OutputDataType >
template<typename DataType >
void mlpack::ann::ReinforceNormal< InputDataType, OutputDataType >::Backward (const DataType & input, const DataType &, DataType & g)

Ordinary feed backward pass of a neural network, calculating the gradient by propagating the error backwards through f, using the results from the feed forward pass.

Parameters
input: The propagated input activation.
gy: The backpropagated error (unnamed in the signature).
g: The calculated gradient.
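Here is a minimal sketch of the gradient this pass is expected to produce. It assumes the standard REINFORCE rule for a normal distribution with fixed standard deviation: g = -R (y - μ) / σ². In this formula, y is the sample drawn during Forward(), μ is the mean that Forward() received, σ is stdev, and R is the value set through Reward().

The function ReinforceNormalGradient is hypothetical, and its std::vector arguments stand in for the layer's Armadillo matrices. Which of y and μ the layer receives as input and which it caches internally is also an assumption here, not something this page states.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of mlpack): REINFORCE gradient of the loss
// -reward * log N(sample; mean, stdev^2) with respect to the mean.
//   sample: values drawn during the forward pass (y)
//   mean:   input the forward pass received (mu)
std::vector<double> ReinforceNormalGradient(const std::vector<double>& sample,
                                            const std::vector<double>& mean,
                                            double stdev, double reward)
{
  assert(sample.size() == mean.size());
  assert(stdev > 0.0);

  const double variance = stdev * stdev;
  std::vector<double> g(sample.size());
  for (std::size_t i = 0; i < sample.size(); ++i)
  {
    // Score function (y - mu) / sigma^2, scaled by the reward and negated so
    // that gradient descent on the mean increases the expected reward.
    g[i] = -reward * (sample[i] - mean[i]) / variance;
  }
  return g;
}
```

As a worked example, take stdev = 0.5 and a reward of 2.0. A sample lying 0.5 above its mean then yields a gradient of -4.0, so a gradient-descent step moves the mean toward the rewarded sample.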

◆ Forward()

template<typename InputDataType , typename OutputDataType >
template<typename eT >
void mlpack::ann::ReinforceNormal< InputDataType, OutputDataType >::Forward (const arma::Mat< eT > & input, arma::Mat< eT > & output)

Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.

Parameters
input: Input data used for evaluating the specified function.
output: Resulting output activation.
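Here is a minimal sketch of the forward pass under stated assumptions. In training mode, each output element is drawn from a normal distribution centered on the corresponding input, with standard deviation stdev. In deterministic mode (see Deterministic()), the input is assumed to pass through unchanged as the most likely value.

ReinforceNormalForward is a hypothetical helper that uses std::vector and the standard <random> header in place of Armadillo. It is not mlpack's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical sketch (not mlpack's code): forward pass of a Gaussian
// REINFORCE layer. In training mode each output is drawn from
// N(input[i], stdev^2); in deterministic mode the mean is returned unchanged.
std::vector<double> ReinforceNormalForward(const std::vector<double>& input,
                                           double stdev, bool deterministic,
                                           std::mt19937& rng)
{
  assert(stdev > 0.0);
  if (deterministic)
    return input;  // Assumed behavior: output the mean itself.

  std::normal_distribution<double> noise(0.0, 1.0);
  std::vector<double> output(input.size());
  for (std::size_t i = 0; i < input.size(); ++i)
  {
    // Scale standard-normal noise by stdev and re-center it on the mean.
    output[i] = input[i] + stdev * noise(rng);
  }
  return output;
}
```

Sampling as mean + stdev · N(0, 1) draws from N(mean, stdev²) using a single standard-normal distribution object for every element. The sampled values are what a REINFORCE backward pass compares against the mean.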
