mlpack
recurrent_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LAYER_RECURRENT_IMPL_HPP
13 #define MLPACK_METHODS_ANN_LAYER_RECURRENT_IMPL_HPP
14 
15 // In case it hasn't yet been included.
16 #include "recurrent.hpp"
17 
18 #include "../visitor/add_visitor.hpp"
19 #include "../visitor/backward_visitor.hpp"
20 #include "../visitor/gradient_visitor.hpp"
21 #include "../visitor/gradient_zero_visitor.hpp"
22 #include "../visitor/input_shape_visitor.hpp"
23 
24 namespace mlpack {
25 namespace ann {
26 
27 template<typename InputDataType, typename OutputDataType,
28  typename... CustomLayers>
30  rho(0),
31  forwardStep(0),
32  backwardStep(0),
33  gradientStep(0),
34  deterministic(false),
35  ownsLayer(false)
36 {
37  // Nothing to do.
38 }
39 
40 template <typename InputDataType, typename OutputDataType,
41  typename... CustomLayers>
42 template<
43  typename StartModuleType,
44  typename InputModuleType,
45  typename FeedbackModuleType,
46  typename TransferModuleType
47 >
49  const StartModuleType& start,
50  const InputModuleType& input,
51  const FeedbackModuleType& feedback,
52  const TransferModuleType& transfer,
53  const size_t rho) :
54  startModule(new StartModuleType(start)),
55  inputModule(new InputModuleType(input)),
56  feedbackModule(new FeedbackModuleType(feedback)),
57  transferModule(new TransferModuleType(transfer)),
58  rho(rho),
59  forwardStep(0),
60  backwardStep(0),
61  gradientStep(0),
62  deterministic(false),
63  ownsLayer(true)
64 {
65  initialModule = new Sequential<>();
66  mergeModule = new AddMerge<>(false, false, false);
67  recurrentModule = new Sequential<>(false, false);
68 
69  boost::apply_visitor(AddVisitor<CustomLayers...>(inputModule),
70  initialModule);
71  boost::apply_visitor(AddVisitor<CustomLayers...>(startModule),
72  initialModule);
73  boost::apply_visitor(AddVisitor<CustomLayers...>(transferModule),
74  initialModule);
75 
76  boost::apply_visitor(AddVisitor<CustomLayers...>(inputModule), mergeModule);
77  boost::apply_visitor(AddVisitor<CustomLayers...>(feedbackModule),
78  mergeModule);
79  boost::apply_visitor(AddVisitor<CustomLayers...>(mergeModule),
80  recurrentModule);
81  boost::apply_visitor(AddVisitor<CustomLayers...>(transferModule),
82  recurrentModule);
83 
84  network.push_back(initialModule);
85  network.push_back(mergeModule);
86  network.push_back(feedbackModule);
87  network.push_back(recurrentModule);
88 }
89 
90 template<typename InputDataType, typename OutputDataType,
91  typename... CustomLayers>
93  const Recurrent& network) :
94  rho(network.rho),
95  forwardStep(network.forwardStep),
96  backwardStep(network.backwardStep),
97  gradientStep(network.gradientStep),
98  deterministic(network.deterministic),
99  ownsLayer(network.ownsLayer)
100 {
101  startModule = boost::apply_visitor(copyVisitor, network.startModule);
102  inputModule = boost::apply_visitor(copyVisitor, network.inputModule);
103  feedbackModule = boost::apply_visitor(copyVisitor, network.feedbackModule);
104  transferModule = boost::apply_visitor(copyVisitor, network.transferModule);
105  initialModule = new Sequential<>();
106  mergeModule = new AddMerge<>(false, false, false);
107  recurrentModule = new Sequential<>(false, false);
108 
109  boost::apply_visitor(AddVisitor<CustomLayers...>(inputModule),
110  initialModule);
111  boost::apply_visitor(AddVisitor<CustomLayers...>(startModule),
112  initialModule);
113  boost::apply_visitor(AddVisitor<CustomLayers...>(transferModule),
114  initialModule);
115 
116  boost::apply_visitor(AddVisitor<CustomLayers...>(inputModule), mergeModule);
117  boost::apply_visitor(AddVisitor<CustomLayers...>(feedbackModule),
118  mergeModule);
119  boost::apply_visitor(AddVisitor<CustomLayers...>(mergeModule),
120  recurrentModule);
121  boost::apply_visitor(AddVisitor<CustomLayers...>(transferModule),
122  recurrentModule);
123  this->network.push_back(initialModule);
124  this->network.push_back(mergeModule);
125  this->network.push_back(feedbackModule);
126  this->network.push_back(recurrentModule);
127 }
128 
129 template<typename InputDataType, typename OutputDataType,
130  typename... CustomLayers>
131 size_t
133 {
134  const size_t inputShapeStartModule = boost::apply_visitor(InShapeVisitor(),
135  startModule);
136 
137  // Return the input shape of the first module that we have.
138  if (inputShapeStartModule != 0)
139  {
140  return inputShapeStartModule;
141  }
142  // If input shape of first module is 0.
143  else
144  {
145  // Return input shape of the second module that we have.
146  const size_t inputShapeInputModule = boost::apply_visitor(InShapeVisitor(),
147  inputModule);
148  if (inputShapeInputModule != 0)
149  {
150  return inputShapeInputModule;
151  }
152  else // If the input shape of second module is 0.
153  {
154  // Return input shape of the third module that we have.
155  const size_t inputShapeFeedbackModule = boost::apply_visitor(
156  InShapeVisitor(), feedbackModule);
157  if (inputShapeFeedbackModule != 0)
158  {
159  return inputShapeFeedbackModule;
160  }
161  else // If the input shape of the third module is 0.
162  {
163  // Return the shape of the fourth module that we have.
164  const size_t inputShapeTransferModule = boost::apply_visitor(
165  InShapeVisitor(), transferModule);
166  if (inputShapeTransferModule != 0)
167  {
168  return inputShapeTransferModule;
169  }
170  else // If the input shape of the fourth module is 0.
171  {
172  return 0;
173  }
174  }
175  }
176  }
177 }
178 
179 template<typename InputDataType, typename OutputDataType,
180  typename... CustomLayers>
181 template<typename eT>
183  const arma::Mat<eT>& input, arma::Mat<eT>& output)
184 {
185  if (forwardStep == 0)
186  {
187  boost::apply_visitor(ForwardVisitor(input, output), initialModule);
188  }
189  else
190  {
191  boost::apply_visitor(ForwardVisitor(input,
192  boost::apply_visitor(outputParameterVisitor, inputModule)),
193  inputModule);
194 
195  boost::apply_visitor(ForwardVisitor(boost::apply_visitor(
196  outputParameterVisitor, transferModule),
197  boost::apply_visitor(outputParameterVisitor, feedbackModule)),
198  feedbackModule);
199 
200  boost::apply_visitor(ForwardVisitor(input, output), recurrentModule);
201  }
202 
203  output = boost::apply_visitor(outputParameterVisitor, transferModule);
204 
205  // Save the feedback output parameter when training the module.
206  if (!deterministic)
207  {
208  feedbackOutputParameter.push_back(output);
209  }
210 
211  forwardStep++;
212  if (forwardStep == rho)
213  {
214  forwardStep = 0;
215  backwardStep = 0;
216 
217  if (!recurrentError.is_empty())
218  {
219  recurrentError.zeros();
220  }
221  }
222 }
223 
224 template<typename InputDataType, typename OutputDataType,
225  typename... CustomLayers>
226 template<typename eT>
228  const arma::Mat<eT>& /* input */, const arma::Mat<eT>& gy, arma::Mat<eT>& g)
229 {
230  if (!recurrentError.is_empty())
231  {
232  recurrentError += gy;
233  }
234  else
235  {
236  recurrentError = gy;
237  }
238 
239  if (backwardStep < (rho - 1))
240  {
241  boost::apply_visitor(BackwardVisitor(boost::apply_visitor(
242  outputParameterVisitor, recurrentModule), recurrentError,
243  boost::apply_visitor(deltaVisitor, recurrentModule)),
244  recurrentModule);
245 
246  boost::apply_visitor(BackwardVisitor(boost::apply_visitor(
247  outputParameterVisitor, inputModule),
248  boost::apply_visitor(deltaVisitor, recurrentModule), g),
249  inputModule);
250 
251  boost::apply_visitor(BackwardVisitor(boost::apply_visitor(
252  outputParameterVisitor, feedbackModule),
253  boost::apply_visitor(deltaVisitor, recurrentModule),
254  boost::apply_visitor(deltaVisitor, feedbackModule)), feedbackModule);
255  }
256  else
257  {
258  boost::apply_visitor(BackwardVisitor(boost::apply_visitor(
259  outputParameterVisitor, initialModule), recurrentError, g),
260  initialModule);
261  }
262 
263  recurrentError = boost::apply_visitor(deltaVisitor, feedbackModule);
264  backwardStep++;
265 }
266 
267 template<typename InputDataType, typename OutputDataType,
268  typename... CustomLayers>
269 template<typename eT>
271  const arma::Mat<eT>& input,
272  const arma::Mat<eT>& error,
273  arma::Mat<eT>& /* gradient */)
274 {
275  if (gradientStep < (rho - 1))
276  {
277  boost::apply_visitor(GradientVisitor(input, error), recurrentModule);
278 
279  boost::apply_visitor(GradientVisitor(input,
280  boost::apply_visitor(deltaVisitor, mergeModule)), inputModule);
281 
282  boost::apply_visitor(GradientVisitor(
283  feedbackOutputParameter[feedbackOutputParameter.size() - 2 -
284  gradientStep], boost::apply_visitor(deltaVisitor,
285  mergeModule)), feedbackModule);
286  }
287  else
288  {
289  boost::apply_visitor(GradientZeroVisitor(), recurrentModule);
290  boost::apply_visitor(GradientZeroVisitor(), inputModule);
291  boost::apply_visitor(GradientZeroVisitor(), feedbackModule);
292 
293  boost::apply_visitor(GradientVisitor(input,
294  boost::apply_visitor(deltaVisitor, startModule)), initialModule);
295  }
296 
297  gradientStep++;
298  if (gradientStep == rho)
299  {
300  gradientStep = 0;
301  feedbackOutputParameter.clear();
302  }
303 }
304 
305 template<typename InputDataType, typename OutputDataType,
306  typename... CustomLayers>
307 template<typename Archive>
309  Archive& ar, const uint32_t /* version */)
310 {
311  // Clean up memory, if we are loading.
312  if (cereal::is_loading<Archive>())
313  {
314  // Clear old things, if needed.
315  boost::apply_visitor(DeleteVisitor(), recurrentModule);
316  boost::apply_visitor(DeleteVisitor(), initialModule);
317  boost::apply_visitor(DeleteVisitor(), startModule);
318  network.clear();
319  }
320 
321  ar(CEREAL_VARIANT_POINTER(startModule));
322  ar(CEREAL_VARIANT_POINTER(inputModule));
323  ar(CEREAL_VARIANT_POINTER(feedbackModule));
324  ar(CEREAL_VARIANT_POINTER(transferModule));
325  ar(CEREAL_NVP(rho));
326  ar(CEREAL_NVP(ownsLayer));
327 
328  // Set up the network.
329  if (cereal::is_loading<Archive>())
330  {
331  initialModule = new Sequential<>();
332  mergeModule = new AddMerge<>(false, false, false);
333  recurrentModule = new Sequential<>(false, false);
334 
335  boost::apply_visitor(AddVisitor<CustomLayers...>(inputModule),
336  initialModule);
337  boost::apply_visitor(AddVisitor<CustomLayers...>(startModule),
338  initialModule);
339  boost::apply_visitor(AddVisitor<CustomLayers...>(transferModule),
340  initialModule);
341 
342  boost::apply_visitor(AddVisitor<CustomLayers...>(inputModule),
343  mergeModule);
344  boost::apply_visitor(AddVisitor<CustomLayers...>(feedbackModule),
345  mergeModule);
346  boost::apply_visitor(AddVisitor<CustomLayers...>(mergeModule),
347  recurrentModule);
348  boost::apply_visitor(AddVisitor<CustomLayers...>(transferModule),
349  recurrentModule);
350 
351  network.push_back(initialModule);
352  network.push_back(mergeModule);
353  network.push_back(feedbackModule);
354  network.push_back(recurrentModule);
355  }
356 }
357 
358 } // namespace ann
359 } // namespace mlpack
360 
361 #endif
DeleteVisitor executes the destructor of the instantiated object.
Definition: delete_visitor.hpp:27
OutputDataType const & Gradient() const
Get the gradient.
Definition: recurrent.hpp:135
BackwardVisitor executes the Backward() function given the input, error and delta parameter...
Definition: backward_visitor.hpp:28
size_t InputShape() const
Get the shape of the input.
Definition: recurrent_impl.hpp:132
Implementation of the AddMerge module class.
Definition: add_merge.hpp:42
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
Definition: recurrent_impl.hpp:227
Recurrent()
Default constructor—this will create a Recurrent object that can't be used, so be careful! Make sure to construct the layer with its modules before use.
Definition: recurrent_impl.hpp:29
Definition: gradient_zero_visitor.hpp:27
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: recurrent_impl.hpp:308
ForwardVisitor executes the Forward() function given the input and output parameter.
Definition: forward_visitor.hpp:28
InShapeVisitor returns the input shape a Layer expects.
Definition: input_shape_visitor.hpp:29
#define CEREAL_VARIANT_POINTER(T)
Cereal does not support the serialization of raw pointer.
Definition: pointer_variant_wrapper.hpp:155
AddVisitor exposes the Add() method of the given module.
Definition: add_visitor.hpp:28
SearchModeVisitor executes the Gradient() method of the given module using the input and delta parame...
Definition: gradient_visitor.hpp:28
Implementation of the RecurrentLayer class.
Definition: layer_types.hpp:157
Implementation of the Sequential class.
Definition: layer_types.hpp:145
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: recurrent_impl.hpp:182