#ifndef MLPACK_METHODS_ANN_LAYER_RECURRENT_ATTENTION_IMPL_HPP
#define MLPACK_METHODS_ANN_LAYER_RECURRENT_ATTENTION_IMPL_HPP

// In case it hasn't yet been included.
#include "recurrent_attention.hpp"

#include "../visitor/load_output_parameter_visitor.hpp"
#include "../visitor/save_output_parameter_visitor.hpp"
#include "../visitor/backward_visitor.hpp"
#include "../visitor/forward_visitor.hpp"
#include "../visitor/gradient_set_visitor.hpp"
#include "../visitor/gradient_update_visitor.hpp"
#include "../visitor/gradient_visitor.hpp"

namespace mlpack {
namespace ann /** Artificial Neural Network. */ {

template<typename InputDataType, typename OutputDataType>
RecurrentAttention<InputDataType, OutputDataType>::RecurrentAttention() :
    outSize(0),
    rho(0),
    forwardStep(0),
    backwardStep(0),
    deterministic(false)
{
  // Nothing to do here; a default-constructed object is not usable until its
  // members are set.
}
template<typename InputDataType, typename OutputDataType>
template<typename RNNModuleType, typename ActionModuleType>
RecurrentAttention<InputDataType, OutputDataType>::RecurrentAttention(
    const size_t outSize,
    const RNNModuleType& rnn,
    const ActionModuleType& action,
    const size_t rho) :
    outSize(outSize),
    rnnModule(new RNNModuleType(rnn)),
    actionModule(new ActionModuleType(action)),
    rho(rho),
    forwardStep(0),
    backwardStep(0),
    deterministic(false)
{
  network.push_back(rnnModule);
  network.push_back(actionModule);
}
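// A minimal usage sketch (illustrative names and sizes, not taken from this
// file): with this boost::variant-based ANN API, the layer is built from an
// RNN-style module and an action (location) module, for example:
//
//   Linear<> rnn(inSize, hiddenSize);
//   Linear<> action(hiddenSize, outSize);
//   RecurrentAttention<> attention(outSize, rnn, action, rho);
//
// where rho is the number of attention (glimpse) steps to unroll.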
template<typename InputDataType, typename OutputDataType>
template<typename eT>
void RecurrentAttention<InputDataType, OutputDataType>::Forward(
    const arma::Mat<eT>& input, arma::Mat<eT>& output)
{
  // Initialize the action input.
  if (initialInput.is_empty())
  {
    initialInput = arma::zeros(outSize, input.n_cols);
  }

  // Propagate through the action and recurrent module.
  for (forwardStep = 0; forwardStep < rho; ++forwardStep)
  {
    if (forwardStep == 0)
    {
      boost::apply_visitor(ForwardVisitor(initialInput, boost::apply_visitor(
          outputParameterVisitor, actionModule)), actionModule);
    }
    else
    {
      boost::apply_visitor(ForwardVisitor(boost::apply_visitor(
          outputParameterVisitor, rnnModule), boost::apply_visitor(
          outputParameterVisitor, actionModule)), actionModule);
    }

    // Initialize the glimpse input: column 0 holds the input, column 1 the
    // action produced by the action module.
    arma::mat glimpseInput = arma::zeros(input.n_elem, 2);
    glimpseInput.col(0) = input;
    glimpseInput.submat(0, 1, boost::apply_visitor(outputParameterVisitor,
        actionModule).n_elem - 1, 1) = boost::apply_visitor(
        outputParameterVisitor, actionModule);

    boost::apply_visitor(ForwardVisitor(glimpseInput, boost::apply_visitor(
        outputParameterVisitor, rnnModule)), rnnModule);

    // Save the output parameters when training the module.
    if (!deterministic)
    {
      for (size_t l = 0; l < network.size(); ++l)
      {
        boost::apply_visitor(SaveOutputParameterVisitor(
            moduleOutputParameter), network[l]);
      }
    }
  }

  output = boost::apply_visitor(outputParameterVisitor, rnnModule);

  forwardStep = 0;
  backwardStep = 0;
}
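// Backward() below unrolls the same rho attention steps in reverse: the
// module output parameters saved during Forward() are restored step by step,
// the error is propagated through the action and recurrent modules, and the
// per-step gradients are accumulated into attentionGradient via
// IntermediateGradient().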
template<typename InputDataType, typename OutputDataType>
template<typename eT>
void RecurrentAttention<InputDataType, OutputDataType>::Backward(
    const arma::Mat<eT>& /* input */,
    const arma::Mat<eT>& gy,
    arma::Mat<eT>& g)
{
  if (intermediateGradient.is_empty() && backwardStep == 0)
  {
    // Initialize the attention gradients.
    size_t weights = boost::apply_visitor(weightSizeVisitor, rnnModule) +
        boost::apply_visitor(weightSizeVisitor, actionModule);

    intermediateGradient = arma::zeros(weights, 1);
    attentionGradient = arma::zeros(weights, 1);

    // Initialize the action error.
    actionError = arma::zeros(
        boost::apply_visitor(outputParameterVisitor, actionModule).n_rows,
        boost::apply_visitor(outputParameterVisitor, actionModule).n_cols);
  }

  // Propagate the attention gradients.
  if (backwardStep == 0)
  {
    size_t offset = 0;
    offset += boost::apply_visitor(GradientSetVisitor(
        intermediateGradient, offset), rnnModule);
    boost::apply_visitor(GradientSetVisitor(
        intermediateGradient, offset), actionModule);

    attentionGradient.zeros();
  }

  // Back-propagate through time.
  for (; backwardStep < rho; backwardStep++)
  {
    if (backwardStep == 0)
    {
      recurrentError = gy;
    }
    else
    {
      recurrentError = actionDelta;
    }

    // Restore the module output parameters saved during the forward pass.
    for (size_t l = 0; l < network.size(); ++l)
    {
      boost::apply_visitor(LoadOutputParameterVisitor(
          moduleOutputParameter), network[network.size() - 1 - l]);
    }

    if (backwardStep == (rho - 1))
    {
      boost::apply_visitor(BackwardVisitor(boost::apply_visitor(
          outputParameterVisitor, actionModule), actionError,
          actionDelta), actionModule);
    }
    else
    {
      boost::apply_visitor(BackwardVisitor(initialInput, actionError,
          actionDelta), actionModule);
    }

    boost::apply_visitor(BackwardVisitor(boost::apply_visitor(
        outputParameterVisitor, rnnModule), recurrentError, rnnDelta),
        rnnModule);

    if (backwardStep == 0)
    {
      g = rnnDelta.col(1);
    }
    else
    {
      g += rnnDelta.col(1);
    }

    IntermediateGradient();
  }
}
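// Gradient() transfers the attention gradient accumulated during Backward()
// into the gradient parameters of the rnn and action modules; the input,
// error, and gradient arguments are unused here.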
template<typename InputDataType, typename OutputDataType>
template<typename eT>
void RecurrentAttention<InputDataType, OutputDataType>::Gradient(
    const arma::Mat<eT>& /* input */,
    const arma::Mat<eT>& /* error */,
    arma::Mat<eT>& /* gradient */)
{
  size_t offset = 0;
  offset += boost::apply_visitor(GradientUpdateVisitor(
      attentionGradient, offset), rnnModule);
  boost::apply_visitor(GradientUpdateVisitor(
      attentionGradient, offset), actionModule);
}
template<typename InputDataType, typename OutputDataType>
template<typename Archive>
void RecurrentAttention<InputDataType, OutputDataType>::serialize(
    Archive& ar, const uint32_t /* version */)
{
  ar(CEREAL_NVP(rho));
  ar(CEREAL_NVP(outSize));
  ar(CEREAL_NVP(forwardStep));
  ar(CEREAL_NVP(backwardStep));

  ar(CEREAL_VARIANT_POINTER(rnnModule));
  ar(CEREAL_VARIANT_POINTER(actionModule));
}

} // namespace ann
} // namespace mlpack

#endif