mlpack
gru_impl.hpp
/**
 * @file methods/ann/layer/gru_impl.hpp
 *
 * Implementation of the GRU class, which implements a gru network layer.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license.  You should have received a copy of the
 * 3-clause BSD license along with mlpack.  If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
#ifndef MLPACK_METHODS_ANN_LAYER_GRU_IMPL_HPP
#define MLPACK_METHODS_ANN_LAYER_GRU_IMPL_HPP

// In case it hasn't yet been included.
#include "gru.hpp"

#include "../visitor/forward_visitor.hpp"
#include "../visitor/backward_visitor.hpp"
#include "../visitor/gradient_visitor.hpp"

namespace mlpack {
namespace ann {

template<typename InputDataType, typename OutputDataType>
GRU<InputDataType, OutputDataType>::GRU()
{
  // Nothing to do here.
}

template<typename InputDataType, typename OutputDataType>
GRU<InputDataType, OutputDataType>::GRU(
    const size_t inSize,
    const size_t outSize,
    const size_t rho) :
    inSize(inSize),
    outSize(outSize),
    rho(rho),
    batchSize(1),
    forwardStep(0),
    backwardStep(0),
    gradientStep(0),
    deterministic(false)
{
  // Input-specific linear layers (for zt, rt, ot).
  input2GateModule = new Linear<>(inSize, 3 * outSize);

  // Previous-output gates (for zt and rt).
  output2GateModule = new LinearNoBias<>(outSize, 2 * outSize);

  // Previous-output gate for ot.
  outputHidden2GateModule = new LinearNoBias<>(outSize, outSize);

  network.push_back(input2GateModule);
  network.push_back(output2GateModule);
  network.push_back(outputHidden2GateModule);

  inputGateModule = new SigmoidLayer<>();
  forgetGateModule = new SigmoidLayer<>();
  hiddenStateModule = new TanHLayer<>();

  network.push_back(inputGateModule);
  network.push_back(hiddenStateModule);
  network.push_back(forgetGateModule);

  prevError = arma::zeros<arma::mat>(3 * outSize, batchSize);

  allZeros = arma::zeros<arma::mat>(outSize, batchSize);

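  // Seed outParameter with a strict, non-owning Armadillo alias of allZeros
  // (copy_aux_mem = false, strict = true): the initial recurrent state h_0 is
  // all zeros and shares allZeros' memory rather than copying it.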
  outParameter.emplace_back(allZeros.memptr(),
      allZeros.n_rows, allZeros.n_cols, false, true);

  prevOutput = outParameter.begin();
  backIterator = outParameter.end();
  gradIterator = outParameter.end();
}

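// For reference, the recurrence that Forward() below implements is a standard
// GRU cell (cf. Cho et al., 2014), writing '%' for elementwise multiplication
// and noting that here the update gate zt multiplies h(t-1):
//
//   zt = sigmoid(Wz xt + Uz h(t-1))        (update gate; inputGateModule)
//   rt = sigmoid(Wr xt + Ur h(t-1))        (reset gate; forgetGateModule)
//   ot = tanh(Wo xt + Uo (rt % h(t-1)))    (candidate; hiddenStateModule)
//   ht = zt % h(t-1) + (1 - zt) % ot
//
// Wz, Wr, and Wo are the three stacked blocks of input2GateModule; Uz and Ur
// are the two blocks of output2GateModule; Uo is outputHidden2GateModule.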
template<typename InputDataType, typename OutputDataType>
template<typename eT>
void GRU<InputDataType, OutputDataType>::Forward(
    const arma::Mat<eT>& input, arma::Mat<eT>& output)
{
  if (input.n_cols != batchSize)
  {
    batchSize = input.n_cols;
    prevError.resize(3 * outSize, batchSize);
    allZeros.zeros(outSize, batchSize);
    // Batch size better not change during an iteration...
    if (outParameter.size() > 1)
    {
      Log::Fatal << "GRU<>::Forward(): batch size cannot change during a "
          << "forward pass!" << std::endl;
    }

    outParameter.clear();
    outParameter.emplace_back(allZeros.memptr(),
        allZeros.n_rows, allZeros.n_cols, false, true);

    prevOutput = outParameter.begin();
    backIterator = outParameter.end();
    gradIterator = outParameter.end();
  }

  // Process the input linearly (zt, rt, ot).
  boost::apply_visitor(ForwardVisitor(input,
      boost::apply_visitor(outputParameterVisitor, input2GateModule)),
      input2GateModule);

  // Process the previous output (zt, rt) linearly.
  boost::apply_visitor(ForwardVisitor(*prevOutput,
      boost::apply_visitor(outputParameterVisitor, output2GateModule)),
      output2GateModule);

  // Merge the outputs (zt and rt).
  output = (boost::apply_visitor(outputParameterVisitor,
      input2GateModule).submat(0, 0, 2 * outSize - 1, batchSize - 1) +
      boost::apply_visitor(outputParameterVisitor, output2GateModule));

  // Pass the first outSize rows through the input gate (zt).
  boost::apply_visitor(ForwardVisitor(output.submat(
      0, 0, 1 * outSize - 1, batchSize - 1), boost::apply_visitor(
      outputParameterVisitor, inputGateModule)), inputGateModule);

  // Pass the second outSize rows through the forget gate (rt).
  boost::apply_visitor(ForwardVisitor(output.submat(
      1 * outSize, 0, 2 * outSize - 1, batchSize - 1),
      boost::apply_visitor(outputParameterVisitor, forgetGateModule)),
      forgetGateModule);

  arma::mat modInput = (boost::apply_visitor(outputParameterVisitor,
      forgetGateModule) % *prevOutput);

  // Pass rt % h(t-1) through outputHidden2GateModule.
  boost::apply_visitor(ForwardVisitor(modInput,
      boost::apply_visitor(outputParameterVisitor, outputHidden2GateModule)),
      outputHidden2GateModule);

  // Merge for ot.
  arma::mat outputH = boost::apply_visitor(outputParameterVisitor,
      input2GateModule).submat(2 * outSize, 0, 3 * outSize - 1, batchSize - 1) +
      boost::apply_visitor(outputParameterVisitor, outputHidden2GateModule);

  // Pass it through the hidden state module (ot).
  boost::apply_visitor(ForwardVisitor(outputH,
      boost::apply_visitor(outputParameterVisitor, hiddenStateModule)),
      hiddenStateModule);

  // Update the output: ht = zt % h(t-1) + (1 - zt) % ot, computed here in the
  // algebraically equivalent form zt % (h(t-1) - ot) + ot, which saves one
  // elementwise product.
  output = (boost::apply_visitor(outputParameterVisitor, inputGateModule)
      % (*prevOutput - boost::apply_visitor(outputParameterVisitor,
      hiddenStateModule))) + boost::apply_visitor(outputParameterVisitor,
      hiddenStateModule);

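  // BPTT bookkeeping: in training mode, each time step's output is appended
  // to outParameter so Backward() and Gradient() can revisit the sequence in
  // reverse; after rho steps the recurrent state is reset to zeros. In
  // deterministic mode, only the most recent output is kept.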
  forwardStep++;
  if (forwardStep == rho)
  {
    forwardStep = 0;
    if (!deterministic)
    {
      outParameter.emplace_back(allZeros.memptr(),
          allZeros.n_rows, allZeros.n_cols, false, true);
      prevOutput = --outParameter.end();
    }
    else
    {
      *prevOutput = arma::mat(allZeros.memptr(),
          allZeros.n_rows, allZeros.n_cols, false, true);
    }
  }
  else if (!deterministic)
  {
    outParameter.push_back(output);
    prevOutput = --outParameter.end();
  }
  else
  {
    if (forwardStep == 1)
    {
      outParameter.clear();
      outParameter.push_back(output);

      prevOutput = outParameter.begin();
    }
    else
    {
      *prevOutput = output;
    }
  }
}

template<typename InputDataType, typename OutputDataType>
template<typename eT>
void GRU<InputDataType, OutputDataType>::Backward(
    const arma::Mat<eT>& input, const arma::Mat<eT>& gy, arma::Mat<eT>& g)
{
  if (input.n_cols != batchSize)
  {
    batchSize = input.n_cols;
    prevError.resize(3 * outSize, batchSize);
    allZeros.zeros(outSize, batchSize);
    // Batch size better not change during an iteration...
    if (outParameter.size() > 1)
    {
      Log::Fatal << "GRU<>::Backward(): batch size cannot change during a "
          << "backward pass!" << std::endl;
    }

    outParameter.clear();
    outParameter.emplace_back(allZeros.memptr(),
        allZeros.n_rows, allZeros.n_cols, false, true);

    prevOutput = outParameter.begin();
    backIterator = outParameter.end();
    gradIterator = outParameter.end();
  }

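  // Except on the first Backward() call of a sequence, the incoming error gy
  // is combined with the recurrent error for h(t) that the previous
  // Backward() call left in output2GateModule's delta.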
  arma::Mat<eT> gyLocal;
  if ((outParameter.size() - backwardStep - 1) % rho != 0 && backwardStep != 0)
  {
    gyLocal = gy + boost::apply_visitor(deltaVisitor, output2GateModule);
  }
  else
  {
    gyLocal = arma::Mat<eT>(((arma::Mat<eT>&) gy).memptr(), gy.n_rows,
        gy.n_cols, false, false);
  }

  if (backIterator == outParameter.end())
  {
    backIterator = --(--outParameter.end());
  }

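  // Since ht = zt % h(t-1) + (1 - zt) % ot, the local derivatives are
  // dht/dzt = h(t-1) - ot and dht/dot = 1 - zt; dZt and dOt below follow by
  // the chain rule (*backIterator holds h(t-1)).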
  // Delta zt.
  arma::mat dZt = gyLocal % (*backIterator -
      boost::apply_visitor(outputParameterVisitor,
      hiddenStateModule));

  // Delta ot.
  arma::mat dOt = gyLocal % (arma::ones<arma::mat>(outSize, batchSize) -
      boost::apply_visitor(outputParameterVisitor, inputGateModule));

  // Delta of input gate.
  boost::apply_visitor(BackwardVisitor(boost::apply_visitor(
      outputParameterVisitor, inputGateModule), dZt,
      boost::apply_visitor(deltaVisitor, inputGateModule)),
      inputGateModule);

  // Delta of hidden gate.
  boost::apply_visitor(BackwardVisitor(boost::apply_visitor(
      outputParameterVisitor, hiddenStateModule), dOt,
      boost::apply_visitor(deltaVisitor, hiddenStateModule)),
      hiddenStateModule);

  // Delta of outputHidden2GateModule.
  boost::apply_visitor(BackwardVisitor(boost::apply_visitor(
      outputParameterVisitor, outputHidden2GateModule),
      boost::apply_visitor(deltaVisitor, hiddenStateModule),
      boost::apply_visitor(deltaVisitor, outputHidden2GateModule)),
      outputHidden2GateModule);

  // Delta rt.
  arma::mat dRt = boost::apply_visitor(deltaVisitor, outputHidden2GateModule) %
      *backIterator;

  // Delta of forget gate.
  boost::apply_visitor(BackwardVisitor(boost::apply_visitor(
      outputParameterVisitor, forgetGateModule), dRt,
      boost::apply_visitor(deltaVisitor, forgetGateModule)),
      forgetGateModule);

  // Put delta zt.
  prevError.submat(0, 0, 1 * outSize - 1, batchSize - 1) = boost::apply_visitor(
      deltaVisitor, inputGateModule);

  // Put delta rt.
  prevError.submat(1 * outSize, 0, 2 * outSize - 1, batchSize - 1) =
      boost::apply_visitor(deltaVisitor, forgetGateModule);

  // Put delta ot.
  prevError.submat(2 * outSize, 0, 3 * outSize - 1, batchSize - 1) =
      boost::apply_visitor(deltaVisitor, hiddenStateModule);

  // Get delta h(t-1) for the input gate and the forget gate.
  arma::mat prevErrorSubview = prevError.submat(0, 0, 2 * outSize - 1,
      batchSize - 1);
  boost::apply_visitor(BackwardVisitor(boost::apply_visitor(
      outputParameterVisitor, input2GateModule),
      prevErrorSubview,
      boost::apply_visitor(deltaVisitor, output2GateModule)),
      output2GateModule);

  // Add delta h(t-1) from the hidden state.
  boost::apply_visitor(deltaVisitor, output2GateModule) +=
      boost::apply_visitor(deltaVisitor, outputHidden2GateModule) %
      boost::apply_visitor(outputParameterVisitor, forgetGateModule);

  // Add delta h(t-1) from ht.
  boost::apply_visitor(deltaVisitor, output2GateModule) += gyLocal %
      boost::apply_visitor(outputParameterVisitor, inputGateModule);

  // Get delta input.
  boost::apply_visitor(BackwardVisitor(boost::apply_visitor(
      outputParameterVisitor, input2GateModule), prevError,
      boost::apply_visitor(deltaVisitor, input2GateModule)),
      input2GateModule);

  backwardStep++;
  backIterator--;

  g = boost::apply_visitor(deltaVisitor, input2GateModule);
}

template<typename InputDataType, typename OutputDataType>
template<typename eT>
void GRU<InputDataType, OutputDataType>::Gradient(
    const arma::Mat<eT>& input,
    const arma::Mat<eT>& /* error */,
    arma::Mat<eT>& /* gradient */)
{
  if (input.n_cols != batchSize)
  {
    batchSize = input.n_cols;
    prevError.resize(3 * outSize, batchSize);
    allZeros.zeros(outSize, batchSize);
    // Batch size better not change during an iteration...
    if (outParameter.size() > 1)
    {
      Log::Fatal << "GRU<>::Gradient(): batch size cannot change during a "
          << "gradient pass!" << std::endl;
    }

    outParameter.clear();
    outParameter.emplace_back(allZeros.memptr(),
        allZeros.n_rows, allZeros.n_cols, false, true);

    prevOutput = outParameter.begin();
    backIterator = outParameter.end();
    gradIterator = outParameter.end();
  }

  if (gradIterator == outParameter.end())
  {
    gradIterator = --(--outParameter.end());
  }

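  // Each module's gradient pairs the input it saw during the forward pass
  // with the matching slice of prevError filled in by Backward():
  // input2GateModule gets the raw input with all gate errors,
  // output2GateModule gets h(t-1) with the zt/rt rows, and
  // outputHidden2GateModule gets rt % h(t-1) with the ot rows.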
  boost::apply_visitor(GradientVisitor(input, prevError), input2GateModule);

  boost::apply_visitor(GradientVisitor(
      *gradIterator,
      prevError.submat(0, 0, 2 * outSize - 1, batchSize - 1)),
      output2GateModule);

  boost::apply_visitor(GradientVisitor(
      *gradIterator % boost::apply_visitor(outputParameterVisitor,
      forgetGateModule),
      prevError.submat(2 * outSize, 0, 3 * outSize - 1, batchSize - 1)),
      outputHidden2GateModule);

  gradIterator--;
}

template<typename InputDataType, typename OutputDataType>
void GRU<InputDataType, OutputDataType>::ResetCell(const size_t /* size */)
{
  outParameter.clear();
  outParameter.emplace_back(allZeros.memptr(),
      allZeros.n_rows, allZeros.n_cols, false, true);

  prevOutput = outParameter.begin();
  backIterator = outParameter.end();
  gradIterator = outParameter.end();

  forwardStep = 0;
  backwardStep = 0;
}

template<typename InputDataType, typename OutputDataType>
template<typename Archive>
void GRU<InputDataType, OutputDataType>::serialize(
    Archive& ar, const uint32_t /* version */)
{
  // If necessary, clean memory from the old model.
  if (cereal::is_loading<Archive>())
  {
    boost::apply_visitor(deleteVisitor, input2GateModule);
    boost::apply_visitor(deleteVisitor, output2GateModule);
    boost::apply_visitor(deleteVisitor, outputHidden2GateModule);
    boost::apply_visitor(deleteVisitor, inputGateModule);
    boost::apply_visitor(deleteVisitor, forgetGateModule);
    boost::apply_visitor(deleteVisitor, hiddenStateModule);
  }

  ar(CEREAL_NVP(inSize));
  ar(CEREAL_NVP(outSize));
  ar(CEREAL_NVP(rho));

  ar(CEREAL_VARIANT_POINTER(input2GateModule));
  ar(CEREAL_VARIANT_POINTER(output2GateModule));
  ar(CEREAL_VARIANT_POINTER(outputHidden2GateModule));
  ar(CEREAL_VARIANT_POINTER(inputGateModule));
  ar(CEREAL_VARIANT_POINTER(forgetGateModule));
  ar(CEREAL_VARIANT_POINTER(hiddenStateModule));
}

} // namespace ann
} // namespace mlpack

#endif
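
A minimal usage sketch (not part of the file above), assuming the boost::variant-based ANN API generation this header belongs to; the layer sizes, sequence length rho, and dummy data are illustrative assumptions:

#include <mlpack/core.hpp>
#include <mlpack/methods/ann/rnn.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>

using namespace mlpack::ann;

int main()
{
  const size_t rho = 10;         // Time steps per sequence (BPTT length).
  const size_t inputSize = 8;    // Features per time step.
  const size_t hiddenSize = 16;  // GRU output size.
  const size_t outputSize = 4;   // Number of classes.

  // Sequential data is an arma::cube: features x sequences x time steps.
  arma::cube input(inputSize, 100, rho, arma::fill::randu);
  arma::cube labels(1, 100, rho);
  labels.fill(1.0);  // Dummy class labels, just for the sketch.

  RNN<> model(rho);
  model.Add<IdentityLayer<> >();
  model.Add<GRU<> >(inputSize, hiddenSize, rho);
  model.Add<Linear<> >(hiddenSize, outputSize);
  model.Add<LogSoftMax<> >();

  model.Train(input, labels);

  arma::cube predictions;
  model.Predict(input, predictions);
}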