mlpack
recurrent_attention_impl.hpp
/**
 * @file recurrent_attention_impl.hpp
 *
 * Implementation of the RecurrentAttention class, which implements the
 * recurrent model for visual attention ("Recurrent Models of Visual
 * Attention", Mnih et al.).
 */
#ifndef MLPACK_METHODS_ANN_LAYER_RECURRENT_ATTENTION_IMPL_HPP
#define MLPACK_METHODS_ANN_LAYER_RECURRENT_ATTENTION_IMPL_HPP

// In case it hasn't yet been included.
#include "recurrent_attention.hpp"

#include "../visitor/load_output_parameter_visitor.hpp"
#include "../visitor/save_output_parameter_visitor.hpp"
#include "../visitor/backward_visitor.hpp"
#include "../visitor/forward_visitor.hpp"
#include "../visitor/gradient_set_visitor.hpp"
#include "../visitor/gradient_update_visitor.hpp"
#include "../visitor/gradient_visitor.hpp"

namespace mlpack {
namespace ann {

template<typename InputDataType, typename OutputDataType>
RecurrentAttention<InputDataType, OutputDataType>::RecurrentAttention() :
    outSize(0),
    rho(0),
    forwardStep(0),
    backwardStep(0),
    deterministic(false)
{
  // Nothing to do.
}

template<typename InputDataType, typename OutputDataType>
template<typename RNNModuleType, typename ActionModuleType>
RecurrentAttention<InputDataType, OutputDataType>::RecurrentAttention(
    const size_t outSize,
    const RNNModuleType& rnn,
    const ActionModuleType& action,
    const size_t rho) :
    outSize(outSize),
    rnnModule(new RNNModuleType(rnn)),
    actionModule(new ActionModuleType(action)),
    rho(rho),
    forwardStep(0),
    backwardStep(0),
    deterministic(false)
{
  network.push_back(rnnModule);
  network.push_back(actionModule);
}

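The constructor clones the given RNN and action modules and registers them in network, mlpack's boost::variant-based module list; everything that follows then reaches the modules through boost::apply_visitor. The standalone sketch below illustrates only that dispatch pattern, with toy module types standing in for mlpack's actual LayerTypes variant:

#include <boost/variant.hpp>
#include <iostream>
#include <vector>

// Hypothetical stand-ins for mlpack modules, for illustration only.
struct LinearModule    { void Forward() { std::cout << "Linear::Forward\n"; } };
struct RecurrentModule { void Forward() { std::cout << "Recurrent::Forward\n"; } };

using ModuleVariant = boost::variant<LinearModule*, RecurrentModule*>;

// Visitor that calls Forward() on whichever module type the variant holds.
struct ToyForwardVisitor : public boost::static_visitor<void>
{
  template<typename ModuleType>
  void operator()(ModuleType* module) const { module->Forward(); }
};

int main()
{
  LinearModule action;
  RecurrentModule rnn;

  // Mirrors network.push_back(rnnModule); network.push_back(actionModule);
  std::vector<ModuleVariant> network{&rnn, &action};

  for (ModuleVariant& module : network)
    boost::apply_visitor(ToyForwardVisitor(), module);
}

Storing the modules as variants of pointers lets one std::vector hold heterogeneous layer types while each visitor resolves the concrete type at dispatch time.
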
template<typename InputDataType, typename OutputDataType>
template<typename eT>
void RecurrentAttention<InputDataType, OutputDataType>::Forward(
    const arma::Mat<eT>& input, arma::Mat<eT>& output)
{
  // Initialize the action input.
  if (initialInput.is_empty())
  {
    initialInput = arma::zeros(outSize, input.n_cols);
  }

  // Propagate through the action and recurrent module.
  for (forwardStep = 0; forwardStep < rho; ++forwardStep)
  {
    if (forwardStep == 0)
    {
      boost::apply_visitor(ForwardVisitor(initialInput,
          boost::apply_visitor(outputParameterVisitor, actionModule)),
          actionModule);
    }
    else
    {
      boost::apply_visitor(ForwardVisitor(boost::apply_visitor(
          outputParameterVisitor, rnnModule), boost::apply_visitor(
          outputParameterVisitor, actionModule)), actionModule);
    }

    // Initialize the glimpse input.
    arma::mat glimpseInput = arma::zeros(input.n_elem, 2);
    glimpseInput.col(0) = input;
    glimpseInput.submat(0, 1, boost::apply_visitor(outputParameterVisitor,
        actionModule).n_elem - 1, 1) = boost::apply_visitor(
        outputParameterVisitor, actionModule);

    boost::apply_visitor(ForwardVisitor(glimpseInput,
        boost::apply_visitor(outputParameterVisitor, rnnModule)),
        rnnModule);

    // Save the output parameter when training the module.
    if (!deterministic)
    {
      for (size_t l = 0; l < network.size(); ++l)
      {
        boost::apply_visitor(SaveOutputParameterVisitor(
            moduleOutputParameter), network[l]);
      }
    }
  }

  output = boost::apply_visitor(outputParameterVisitor, rnnModule);

  forwardStep = 0;
  backwardStep = 0;
}

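At each of the rho steps, Forward() packs the flattened input and the zero-padded action output side by side into a two-column glimpse matrix before feeding it to the RNN module. The following is a minimal, self-contained Armadillo sketch of that packing; the sizes are hypothetical, since the real dimensions come from the configured modules:

#include <armadillo>

int main()
{
  // Hypothetical sizes: an 8-element flattened input, a 2-D action output.
  arma::mat input  = arma::randu<arma::mat>(8, 1);
  arma::mat action = arma::randu<arma::mat>(2, 1);

  // Column 0 holds the input; column 1 holds the action output,
  // zero-padded to the same height. This is the layout built above.
  arma::mat glimpseInput = arma::zeros(input.n_elem, 2);
  glimpseInput.col(0) = arma::vectorise(input);
  glimpseInput.submat(0, 1, action.n_elem - 1, 1) = arma::vectorise(action);

  glimpseInput.print("glimpseInput");
  return 0;
}
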
template<typename InputDataType, typename OutputDataType>
template<typename eT>
void RecurrentAttention<InputDataType, OutputDataType>::Backward(
    const arma::Mat<eT>& /* input */,
    const arma::Mat<eT>& gy,
    arma::Mat<eT>& g)
{
  if (intermediateGradient.is_empty() && backwardStep == 0)
  {
    // Initialize the attention gradients.
    size_t weights = boost::apply_visitor(weightSizeVisitor, rnnModule) +
        boost::apply_visitor(weightSizeVisitor, actionModule);

    intermediateGradient = arma::zeros(weights, 1);
    attentionGradient = arma::zeros(weights, 1);

    // Initialize the action error.
    actionError = arma::zeros(
        boost::apply_visitor(outputParameterVisitor, actionModule).n_rows,
        boost::apply_visitor(outputParameterVisitor, actionModule).n_cols);
  }

  // Propagate the attention gradients.
  if (backwardStep == 0)
  {
    size_t offset = 0;
    offset += boost::apply_visitor(GradientSetVisitor(
        intermediateGradient, offset), rnnModule);
    boost::apply_visitor(GradientSetVisitor(
        intermediateGradient, offset), actionModule);

    attentionGradient.zeros();
  }

  // Back-propagate through time.
  for (; backwardStep < rho; backwardStep++)
  {
    if (backwardStep == 0)
    {
      recurrentError = gy;
    }
    else
    {
      recurrentError = actionDelta;
    }

    for (size_t l = 0; l < network.size(); ++l)
    {
      boost::apply_visitor(LoadOutputParameterVisitor(
          moduleOutputParameter), network[network.size() - 1 - l]);
    }

    if (backwardStep == (rho - 1))
    {
      boost::apply_visitor(BackwardVisitor(boost::apply_visitor(
          outputParameterVisitor, actionModule), actionError,
          actionDelta), actionModule);
    }
    else
    {
      boost::apply_visitor(BackwardVisitor(initialInput, actionError,
          actionDelta), actionModule);
    }

    boost::apply_visitor(BackwardVisitor(boost::apply_visitor(
        outputParameterVisitor, rnnModule), recurrentError, rnnDelta),
        rnnModule);

    if (backwardStep == 0)
    {
      g = rnnDelta.col(1);
    }
    else
    {
      g += rnnDelta.col(1);
    }

    IntermediateGradient();
  }
}

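Backward() unrolls the two modules back through the same rho steps: the outputs saved during Forward() are restored in reverse order (moduleOutputParameter acts as a stack), the delta flowing out of each step becomes the error of the step before it, and the per-step parameter gradients are summed into one buffer by IntermediateGradient(). A minimal sketch of that accumulate-over-time pattern, with a toy per-step gradient in place of the module visitors:

#include <armadillo>

int main()
{
  const size_t rho = 5;       // number of unrolled time steps
  const size_t weights = 10;  // total parameter count of both modules

  // Plays the role of attentionGradient: one buffer for all time steps.
  arma::mat attentionGradient = arma::zeros(weights, 1);

  for (size_t backwardStep = 0; backwardStep < rho; ++backwardStep)
  {
    // Toy per-step gradient; the real code fills intermediateGradient
    // through GradientVisitor on the RNN and action modules.
    arma::mat intermediateGradient =
        arma::ones<arma::mat>(weights, 1) * double(backwardStep + 1);

    // This sum is what IntermediateGradient() performs each iteration.
    attentionGradient += intermediateGradient;
  }

  attentionGradient.print("accumulated gradient");
  return 0;
}
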
template<typename InputDataType, typename OutputDataType>
template<typename eT>
void RecurrentAttention<InputDataType, OutputDataType>::Gradient(
    const arma::Mat<eT>& /* input */,
    const arma::Mat<eT>& /* error */,
    arma::Mat<eT>& /* gradient */)
{
  size_t offset = 0;
  offset += boost::apply_visitor(GradientUpdateVisitor(
      attentionGradient, offset), rnnModule);
  boost::apply_visitor(GradientUpdateVisitor(
      attentionGradient, offset), actionModule);
}

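Gradient() writes the accumulated attentionGradient back into the two modules: the RNN module consumes the leading slice of the flat vector and the visitor returns its weight count, which becomes the offset of the action module's slice. A self-contained sketch of that offset bookkeeping, using a hypothetical helper and weight counts in place of GradientUpdateVisitor:

#include <armadillo>
#include <cstddef>

// Toy stand-in for GradientUpdateVisitor: copy this module's slice out of
// the flat gradient vector and report how many elements were consumed.
size_t UpdateGradient(arma::mat& moduleGradient,
                      const arma::mat& flatGradient,
                      const size_t offset)
{
  moduleGradient =
      flatGradient.rows(offset, offset + moduleGradient.n_elem - 1);
  return moduleGradient.n_elem;
}

int main()
{
  arma::mat rnnGradient(6, 1);     // hypothetical: 6 RNN weights
  arma::mat actionGradient(4, 1);  // hypothetical: 4 action weights

  arma::mat attentionGradient = arma::randu<arma::mat>(10, 1);

  size_t offset = 0;
  offset += UpdateGradient(rnnGradient, attentionGradient, offset);
  UpdateGradient(actionGradient, attentionGradient, offset);

  rnnGradient.print("rnn slice");
  actionGradient.print("action slice");
  return 0;
}
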
template<typename InputDataType, typename OutputDataType>
template<typename Archive>
void RecurrentAttention<InputDataType, OutputDataType>::serialize(
    Archive& ar, const uint32_t /* version */)
{
  ar(CEREAL_NVP(rho));
  ar(CEREAL_NVP(outSize));
  ar(CEREAL_NVP(forwardStep));
  ar(CEREAL_NVP(backwardStep));

  ar(CEREAL_VARIANT_POINTER(rnnModule));
  ar(CEREAL_VARIANT_POINTER(actionModule));
}

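CEREAL_NVP wraps each member in a name-value pair, while CEREAL_VARIANT_POINTER is mlpack's own wrapper for the raw-pointer variants, since cereal cannot serialize raw pointers directly. A minimal sketch of the plain cereal side of this, with a hypothetical struct holding the same scalar members in place of the layer:

#include <cereal/archives/json.hpp>
#include <cereal/cereal.hpp>
#include <cstddef>
#include <iostream>
#include <sstream>

// Hypothetical stand-in carrying the plain members serialized above.
struct ToyLayer
{
  size_t rho = 5;
  size_t outSize = 10;
  size_t forwardStep = 0;
  size_t backwardStep = 0;

  template<typename Archive>
  void serialize(Archive& ar)
  {
    ar(CEREAL_NVP(rho), CEREAL_NVP(outSize),
       CEREAL_NVP(forwardStep), CEREAL_NVP(backwardStep));
  }
};

int main()
{
  std::stringstream stream;
  {
    // Scope ensures the archive flushes before the string is read.
    cereal::JSONOutputArchive archive(stream);
    ToyLayer layer;
    archive(cereal::make_nvp("layer", layer));
  }
  std::cout << stream.str() << std::endl;
  return 0;
}
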
} // namespace ann
} // namespace mlpack

#endif