12 #ifndef MLPACK_METHODS_ANN_LAYER_RECURRENT_ATTENTION_HPP 13 #define MLPACK_METHODS_ANN_LAYER_RECURRENT_ATTENTION_HPP 17 #include "../visitor/delta_visitor.hpp" 18 #include "../visitor/output_parameter_visitor.hpp" 19 #include "../visitor/reset_visitor.hpp" 20 #include "../visitor/weight_size_visitor.hpp" 52 typename InputDataType = arma::mat,
53 typename OutputDataType = arma::mat
55 class RecurrentAttention
72 template<
typename RNNModuleType,
typename ActionModuleType>
74 const RNNModuleType& rnn,
75 const ActionModuleType& action,
86 void Forward(
const arma::Mat<eT>& input, arma::Mat<eT>& output);
99 const arma::Mat<eT>& gy,
109 template<
typename eT>
110 void Gradient(
const arma::Mat<eT>& ,
111 const arma::Mat<eT>& ,
115 std::vector<LayerTypes<>>&
Model() {
return network; }
123 OutputDataType
const&
Parameters()
const {
return parameters; }
133 OutputDataType
const&
Delta()
const {
return delta; }
135 OutputDataType&
Delta() {
return delta; }
138 OutputDataType
const&
Gradient()
const {
return gradient; }
146 size_t const&
Rho()
const {
return rho; }
151 template<
typename Archive>
152 void serialize(Archive& ar,
const uint32_t );
156 void IntermediateGradient()
158 intermediateGradient.zeros();
161 if (backwardStep == (rho - 1))
169 outputParameterVisitor, actionModule), actionError),
175 outputParameterVisitor, rnnModule), recurrentError),
178 attentionGradient += intermediateGradient;
185 LayerTypes<> rnnModule;
188 LayerTypes<> actionModule;
203 OutputDataType parameters;
206 std::vector<LayerTypes<>> network;
218 std::vector<arma::mat> feedbackOutputParameter;
221 std::vector<arma::mat> moduleOutputParameter;
224 OutputDataType delta;
227 OutputDataType gradient;
230 OutputDataType outputParameter;
233 arma::mat recurrentError;
236 arma::mat actionError;
239 arma::mat actionDelta;
245 arma::mat initialInput;
251 arma::mat attentionGradient;
254 arma::mat intermediateGradient;
RecurrentAttention()
Default constructor: this will not give a usable RecurrentAttention object, so be sure to set all the...
Definition: recurrent_attention_impl.hpp:30
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: recurrent_attention.hpp:128
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: recurrent_attention_impl.hpp:211
OutputDataType const & Parameters() const
Get the parameters.
Definition: recurrent_attention.hpp:123
OutputDataType const & Delta() const
Get the delta.
Definition: recurrent_attention.hpp:133
OutputDataType const & Gradient() const
Get the gradient.
Definition: recurrent_attention.hpp:138
OutputDataType & Parameters()
Modify the parameters.
Definition: recurrent_attention.hpp:125
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.
WeightSizeVisitor returns the number of weights of the given module.
Definition: weight_size_visitor.hpp:27
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: recurrent_attention.hpp:130
size_t const & Rho() const
Get the number of steps to backpropagate through time.
Definition: recurrent_attention.hpp:146
ResetVisitor executes the Reset() function.
Definition: reset_visitor.hpp:26
OutputParameterVisitor exposes the output parameter of the given module.
Definition: output_parameter_visitor.hpp:27
OutputDataType & Gradient()
Modify the gradient.
Definition: recurrent_attention.hpp:140
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backw...
Definition: recurrent_attention_impl.hpp:116
bool & Deterministic()
Modify the value of the deterministic parameter.
Definition: recurrent_attention.hpp:120
std::vector< LayerTypes<> > & Model()
Get the model modules.
Definition: recurrent_attention.hpp:115
SearchModeVisitor executes the Gradient() method of the given module using the input and delta parame...
Definition: gradient_visitor.hpp:28
OutputDataType & Delta()
Modify the delta.
Definition: recurrent_attention.hpp:135
DeltaVisitor exposes the delta parameter of the given module.
Definition: delta_visitor.hpp:27
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activ...
Definition: recurrent_attention_impl.hpp:61
bool Deterministic() const
The value of the deterministic parameter.
Definition: recurrent_attention.hpp:118
size_t OutSize() const
Get the module output size.
Definition: recurrent_attention.hpp:143