/**
 * @file methods/ann/layer/lstm_impl.hpp
 *
 * Implementation of the LSTM class, which implements an LSTM network layer.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license.  You should have received a copy of the
 * 3-clause BSD license along with mlpack.  If not, see
 * http://www.mlpack.org/license.txt for more information.
 */
#ifndef MLPACK_METHODS_ANN_LAYER_LSTM_IMPL_HPP
#define MLPACK_METHODS_ANN_LAYER_LSTM_IMPL_HPP

// In case it hasn't yet been included.
#include "lstm.hpp"

namespace mlpack {
namespace ann {

template<typename InputDataType, typename OutputDataType>
LSTM<InputDataType, OutputDataType>::LSTM()
{
  // Nothing to do here.
}

template<typename InputDataType, typename OutputDataType>
LSTM<InputDataType, OutputDataType>::LSTM(const LSTM& layer) :
    inSize(layer.inSize),
    outSize(layer.outSize),
    rho(layer.rho),
    forwardStep(layer.forwardStep),
    backwardStep(layer.backwardStep),
    gradientStep(layer.gradientStep),
    weights(layer.weights),
    batchSize(layer.batchSize),
    batchStep(layer.batchStep),
    gradientStepIdx(layer.gradientStepIdx),
    rhoSize(layer.rhoSize),
    bpttSteps(layer.bpttSteps)
{
  // Nothing to do here.
}

template<typename InputDataType, typename OutputDataType>
LSTM<InputDataType, OutputDataType>::LSTM(LSTM&& layer) :
    inSize(std::move(layer.inSize)),
    outSize(std::move(layer.outSize)),
    rho(std::move(layer.rho)),
    forwardStep(std::move(layer.forwardStep)),
    backwardStep(std::move(layer.backwardStep)),
    gradientStep(std::move(layer.gradientStep)),
    weights(std::move(layer.weights)),
    batchSize(std::move(layer.batchSize)),
    batchStep(std::move(layer.batchStep)),
    gradientStepIdx(std::move(layer.gradientStepIdx)),
    rhoSize(std::move(layer.rhoSize)),
    bpttSteps(std::move(layer.bpttSteps))
{
  // Nothing to do here.
}

template<typename InputDataType, typename OutputDataType>
LSTM<InputDataType, OutputDataType>&
LSTM<InputDataType, OutputDataType>::operator=(const LSTM& layer)
{
  if (this != &layer)
  {
    inSize = layer.inSize;
    outSize = layer.outSize;
    rho = layer.rho;
    forwardStep = layer.forwardStep;
    backwardStep = layer.backwardStep;
    gradientStep = layer.gradientStep;
    weights = layer.weights;
    batchSize = layer.batchSize;
    batchStep = layer.batchStep;
    gradientStepIdx = layer.gradientStepIdx;
    grad = layer.grad;
    rhoSize = layer.rhoSize;
    bpttSteps = layer.bpttSteps;
  }
  return *this;
}

template<typename InputDataType, typename OutputDataType>
LSTM<InputDataType, OutputDataType>&
LSTM<InputDataType, OutputDataType>::operator=(LSTM&& layer)
{
  if (this != &layer)
  {
    inSize = std::move(layer.inSize);
    outSize = std::move(layer.outSize);
    rho = std::move(layer.rho);
    forwardStep = std::move(layer.forwardStep);
    backwardStep = std::move(layer.backwardStep);
    gradientStep = std::move(layer.gradientStep);
    weights = std::move(layer.weights);
    batchSize = std::move(layer.batchSize);
    batchStep = std::move(layer.batchStep);
    gradientStepIdx = std::move(layer.gradientStepIdx);
    grad = std::move(layer.grad);
    rhoSize = std::move(layer.rhoSize);
    bpttSteps = std::move(layer.bpttSteps);
  }
  return *this;
}

template<typename InputDataType, typename OutputDataType>
LSTM<InputDataType, OutputDataType>::LSTM(
    const size_t inSize, const size_t outSize, const size_t rho) :
    inSize(inSize),
    outSize(outSize),
    rho(rho),
    forwardStep(0),
    backwardStep(0),
    gradientStep(0),
    batchSize(0),
    batchStep(0),
    gradientStepIdx(0),
    rhoSize(rho),
    bpttSteps(0)
{
  weights.set_size(WeightSize(), 1);
}

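// Note: given the slicing done in Reset() below, WeightSize() works out to
// 4 * outSize * inSize + 4 * outSize * outSize + 7 * outSize: four
// input-to-gate weight matrices with their biases, four recurrent
// (output-to-gate) matrices, and three peephole weight vectors.
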
template<typename InputDataType, typename OutputDataType>
void LSTM<InputDataType, OutputDataType>::ResetCell(const size_t size)
{
  if (size == std::numeric_limits<size_t>::max())
    return;

  rhoSize = size;

  if (batchSize == 0)
    return;

  bpttSteps = std::min(rho, rhoSize);
  forwardStep = 0;
  gradientStepIdx = 0;
  backwardStep = batchSize * size - 1;
  gradientStep = batchSize * size - 1;

  const size_t rhoBatchSize = size * batchSize;

  // Make sure all of the different matrices we will use to hold parameters
  // are at least as large as we need.
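  // Each matrix stores one outSize x batchSize slice per time step, laid out
  // side by side: column block [t * batchSize, t * batchSize + batchStep]
  // belongs to time step t. outParameter additionally keeps one extra slice
  // at the front for the initial (all-zero) hidden state.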
  inputGate.set_size(outSize, rhoBatchSize);
  forgetGate.set_size(outSize, rhoBatchSize);
  hiddenLayer.set_size(outSize, rhoBatchSize);
  outputGate.set_size(outSize, rhoBatchSize);

  inputGateActivation.set_size(outSize, rhoBatchSize);
  forgetGateActivation.set_size(outSize, rhoBatchSize);
  outputGateActivation.set_size(outSize, rhoBatchSize);
  hiddenLayerActivation.set_size(outSize, rhoBatchSize);

  cellActivation.set_size(outSize, rhoBatchSize);
  prevError.set_size(4 * outSize, batchSize);

  // Now reset recurrent values to 0.
  cell.zeros(outSize, size * batchSize);
  outParameter.zeros(outSize, (size + 1) * batchSize);
}

template<typename InputDataType, typename OutputDataType>
void LSTM<InputDataType, OutputDataType>::Reset()
{
  // Set the weight parameter for the output gate.
  input2GateOutputWeight = OutputDataType(weights.memptr(), outSize, inSize,
      false, false);
  input2GateOutputBias = OutputDataType(weights.memptr() +
      input2GateOutputWeight.n_elem, outSize, 1, false, false);
  size_t offset = input2GateOutputWeight.n_elem + input2GateOutputBias.n_elem;

  // Set the weight parameter for the forget gate.
  input2GateForgetWeight = OutputDataType(weights.memptr() + offset,
      outSize, inSize, false, false);
  input2GateForgetBias = OutputDataType(weights.memptr() +
      offset + input2GateForgetWeight.n_elem, outSize, 1, false, false);
  offset += input2GateForgetWeight.n_elem + input2GateForgetBias.n_elem;

  // Set the weight parameter for the input gate.
  input2GateInputWeight = OutputDataType(weights.memptr() +
      offset, outSize, inSize, false, false);
  input2GateInputBias = OutputDataType(weights.memptr() +
      offset + input2GateInputWeight.n_elem, outSize, 1, false, false);
  offset += input2GateInputWeight.n_elem + input2GateInputBias.n_elem;

  // Set the weight parameter for the hidden gate.
  input2HiddenWeight = OutputDataType(weights.memptr() +
      offset, outSize, inSize, false, false);
  input2HiddenBias = OutputDataType(weights.memptr() +
      offset + input2HiddenWeight.n_elem, outSize, 1, false, false);
  offset += input2HiddenWeight.n_elem + input2HiddenBias.n_elem;

  // Set the weight parameter for the output multiplication.
  output2GateOutputWeight = OutputDataType(weights.memptr() +
      offset, outSize, outSize, false, false);
  offset += output2GateOutputWeight.n_elem;

  // Set the weight parameter for the forget multiplication.
  output2GateForgetWeight = OutputDataType(weights.memptr() +
      offset, outSize, outSize, false, false);
  offset += output2GateForgetWeight.n_elem;

  // Set the weight parameter for the input multiplication.
  output2GateInputWeight = OutputDataType(weights.memptr() +
      offset, outSize, outSize, false, false);
  offset += output2GateInputWeight.n_elem;

  // Set the weight parameter for the hidden multiplication.
  output2HiddenWeight = OutputDataType(weights.memptr() +
      offset, outSize, outSize, false, false);
  offset += output2HiddenWeight.n_elem;

  // Set the weight parameter for the cell multiplication.
  cell2GateOutputWeight = OutputDataType(weights.memptr() +
      offset, outSize, 1, false, false);
  offset += cell2GateOutputWeight.n_elem;

  // Set the weight parameter for the cell - forget gate multiplication.
  cell2GateForgetWeight = OutputDataType(weights.memptr() +
      offset, outSize, 1, false, false);
  offset += cell2GateForgetWeight.n_elem;

  // Set the weight parameter for the cell - input gate multiplication.
  cell2GateInputWeight = OutputDataType(weights.memptr() +
      offset, outSize, 1, false, false);
}
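
// To summarize, the weights vector is packed as: the four input-to-gate
// weight/bias pairs (output, forget, input, hidden), then the four recurrent
// outSize x outSize matrices, and finally the three peephole vectors
// (cell-to-output, cell-to-forget, cell-to-input). Gradient() below writes
// its results in exactly this order.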

// Forward when cellState is not needed.
template<typename InputDataType, typename OutputDataType>
template<typename InputType, typename OutputType>
void LSTM<InputDataType, OutputDataType>::Forward(
    const InputType& input, OutputType& output)
{
  // A cell state is still computed internally; it just isn't returned.
  OutputType cellState;
  Forward(input, output, cellState, false);
}

// Forward when cellState is needed: overloaded LSTM::Forward().
template<typename InputDataType, typename OutputDataType>
template<typename InputType, typename OutputType>
void LSTM<InputDataType, OutputDataType>::Forward(const InputType& input,
                                                  OutputType& output,
                                                  OutputType& cellState,
                                                  bool useCellState)
{
  // Check if the batch size changed; the number of columns defines the input
  // batch size.
  if (input.n_cols != batchSize)
  {
    batchSize = input.n_cols;
    batchStep = batchSize - 1;
    ResetCell(rhoSize);
  }

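  // This implements a "peephole" LSTM. Writing % for elementwise
  // multiplication, the recurrences computed below are:
  //
  //   i_t = sigmoid(W_i x_t + U_i h_{t-1} + w_i % c_{t-1} + b_i)   (input)
  //   f_t = sigmoid(W_f x_t + U_f h_{t-1} + w_f % c_{t-1} + b_f)   (forget)
  //   z_t = tanh(W_z x_t + U_z h_{t-1} + b_z)                      (hidden)
  //   c_t = f_t % c_{t-1} + i_t % z_t                              (cell)
  //   o_t = sigmoid(W_o x_t + U_o h_{t-1} + w_o % c_t + b_o)       (output)
  //   h_t = o_t % tanh(c_t)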
  inputGate.cols(forwardStep, forwardStep + batchStep) = input2GateInputWeight *
      input + output2GateInputWeight * outParameter.cols(forwardStep,
      forwardStep + batchStep);
  inputGate.cols(forwardStep, forwardStep + batchStep).each_col() +=
      input2GateInputBias;

  forgetGate.cols(forwardStep, forwardStep + batchStep) = input2GateForgetWeight
      * input + output2GateForgetWeight * outParameter.cols(
      forwardStep, forwardStep + batchStep);
  forgetGate.cols(forwardStep, forwardStep + batchStep).each_col() +=
      input2GateForgetBias;

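  // After the first time step, add the peephole contributions
  // w_i % c_{t-1} and w_f % c_{t-1} to the input and forget gates; if the
  // caller supplied a cell state, install it as c_{t-1} first.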
  if (forwardStep > 0)
  {
    if (useCellState)
    {
      if (!cellState.is_empty())
      {
        cell.cols(forwardStep - batchSize,
            forwardStep - batchSize + batchStep) = cellState;
      }
      else
      {
        throw std::runtime_error("Cell parameter is empty.");
      }
    }

    inputGate.cols(forwardStep, forwardStep + batchStep) +=
        arma::repmat(cell2GateInputWeight, 1, batchSize) %
        cell.cols(forwardStep - batchSize, forwardStep - batchSize + batchStep);

    forgetGate.cols(forwardStep, forwardStep + batchStep) +=
        arma::repmat(cell2GateForgetWeight, 1, batchSize) %
        cell.cols(forwardStep - batchSize, forwardStep - batchSize + batchStep);
  }

  inputGateActivation.cols(forwardStep, forwardStep + batchStep) = 1.0 /
      (1 + arma::exp(-inputGate.cols(forwardStep, forwardStep + batchStep)));

  forgetGateActivation.cols(forwardStep, forwardStep + batchStep) = 1.0 /
      (1 + arma::exp(-forgetGate.cols(forwardStep, forwardStep + batchStep)));

  hiddenLayer.cols(forwardStep, forwardStep + batchStep) = input2HiddenWeight *
      input + output2HiddenWeight * outParameter.cols(
      forwardStep, forwardStep + batchStep);

  hiddenLayer.cols(forwardStep, forwardStep + batchStep).each_col() +=
      input2HiddenBias;

  hiddenLayerActivation.cols(forwardStep, forwardStep + batchStep) =
      arma::tanh(hiddenLayer.cols(forwardStep, forwardStep + batchStep));

  if (forwardStep == 0)
  {
    cell.cols(forwardStep, forwardStep + batchStep) =
        inputGateActivation.cols(forwardStep, forwardStep + batchStep) %
        hiddenLayerActivation.cols(forwardStep, forwardStep + batchStep);
  }
  else
  {
    cell.cols(forwardStep, forwardStep + batchStep) =
        forgetGateActivation.cols(forwardStep, forwardStep + batchStep) %
        cell.cols(forwardStep - batchSize, forwardStep - batchSize + batchStep)
        + inputGateActivation.cols(forwardStep, forwardStep + batchStep) %
        hiddenLayerActivation.cols(forwardStep, forwardStep + batchStep);
  }

  outputGate.cols(forwardStep, forwardStep + batchStep) = input2GateOutputWeight
      * input + output2GateOutputWeight * outParameter.cols(
      forwardStep, forwardStep + batchStep) + cell.cols(forwardStep,
      forwardStep + batchStep).each_col() % cell2GateOutputWeight;

  outputGate.cols(forwardStep, forwardStep + batchStep).each_col() +=
      input2GateOutputBias;

  outputGateActivation.cols(forwardStep, forwardStep + batchStep) = 1.0 /
      (1 + arma::exp(-outputGate.cols(forwardStep, forwardStep + batchStep)));

  cellActivation.cols(forwardStep, forwardStep + batchStep) =
      arma::tanh(cell.cols(forwardStep, forwardStep + batchStep));

  outParameter.cols(forwardStep + batchSize,
      forwardStep + batchSize + batchStep) =
      cellActivation.cols(forwardStep, forwardStep + batchStep) %
      outputGateActivation.cols(forwardStep, forwardStep + batchStep);

  output = OutputType(outParameter.memptr() +
      (forwardStep + batchSize) * outSize, outSize, batchSize, false, false);

  cellState = OutputType(cell.memptr() +
      forwardStep * outSize, outSize, batchSize, false, false);

  forwardStep += batchSize;
  if ((forwardStep / batchSize) == bpttSteps)
  {
    forwardStep = 0;
  }
}

template<typename InputDataType, typename OutputDataType>
template<typename InputType, typename ErrorType, typename GradientType>
void LSTM<InputDataType, OutputDataType>::Backward(
    const InputType& /* input */, const ErrorType& gy, GradientType& g)
{
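  // Backward() is called once per time step, moving from the last step toward
  // the first. prevError holds the error that flowed back through the
  // recurrent connections from the previously processed (later) step, so it
  // is added to the incoming error for all but the last step.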
  ErrorType gyLocal;
  if (gradientStepIdx > 0)
  {
    gyLocal = gy + prevError;
  }
  else
  {
    // Make an alias.
    gyLocal = ErrorType(((ErrorType&) gy).memptr(), gy.n_rows, gy.n_cols, false,
        false);
  }

  outputGateError =
      gyLocal % cellActivation.cols(backwardStep - batchStep, backwardStep) %
      (outputGateActivation.cols(backwardStep - batchStep, backwardStep) %
      (1.0 - outputGateActivation.cols(backwardStep - batchStep,
      backwardStep)));

  OutputDataType cellError = gyLocal %
      outputGateActivation.cols(backwardStep - batchStep, backwardStep) %
      (1 - arma::pow(cellActivation.cols(backwardStep -
      batchStep, backwardStep), 2)) + outputGateError.each_col() %
      cell2GateOutputWeight;

  if (gradientStepIdx > 0)
  {
    cellError += inputCellError;
  }

  if (backwardStep > batchStep)
  {
    forgetGateError = cell.cols((backwardStep - batchSize) - batchStep,
        (backwardStep - batchSize)) % cellError % (forgetGateActivation.cols(
        backwardStep - batchStep, backwardStep) % (1.0 -
        forgetGateActivation.cols(backwardStep - batchStep, backwardStep)));
  }
  else
  {
    forgetGateError.zeros();
  }

  inputGateError = hiddenLayerActivation.cols(backwardStep - batchStep,
      backwardStep) % cellError %
      (inputGateActivation.cols(backwardStep - batchStep, backwardStep) %
      (1.0 - inputGateActivation.cols(backwardStep - batchStep, backwardStep)));

  hiddenError = inputGateActivation.cols(backwardStep - batchStep,
      backwardStep) % cellError % (1 - arma::pow(hiddenLayerActivation.cols(
      backwardStep - batchStep, backwardStep), 2));

  inputCellError = forgetGateActivation.cols(backwardStep - batchStep,
      backwardStep) % cellError + forgetGateError.each_col() %
      cell2GateForgetWeight + inputGateError.each_col() % cell2GateInputWeight;

  g = input2GateInputWeight.t() * inputGateError +
      input2HiddenWeight.t() * hiddenError +
      input2GateForgetWeight.t() * forgetGateError +
      input2GateOutputWeight.t() * outputGateError;

  prevError = output2GateOutputWeight.t() * outputGateError +
      output2GateForgetWeight.t() * forgetGateError +
      output2GateInputWeight.t() * inputGateError +
      output2HiddenWeight.t() * hiddenError;
  backwardStep -= batchSize;
  gradientStepIdx++;
  if (gradientStepIdx == bpttSteps)
  {
    backwardStep = batchSize * bpttSteps - 1;
    gradientStepIdx = 0;
  }
}

template<typename InputDataType, typename OutputDataType>
template<typename InputType, typename ErrorType, typename GradientType>
void LSTM<InputDataType, OutputDataType>::Gradient(
    const InputType& input,
    const ErrorType& /* error */,
    GradientType& gradient)
{
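  // The gradient is written into a single long column vector, packed in
  // exactly the order in which Reset() slices the weights vector.
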
  // Input2GateOutputWeight and input2GateOutputBias gradients.
  gradient.submat(0, 0, input2GateOutputWeight.n_elem - 1, 0) =
      arma::vectorise(outputGateError * input.t());
  gradient.submat(input2GateOutputWeight.n_elem, 0,
      input2GateOutputWeight.n_elem + input2GateOutputBias.n_elem - 1, 0) =
      arma::sum(outputGateError, 1);
  size_t offset = input2GateOutputWeight.n_elem + input2GateOutputBias.n_elem;

  // Input2GateForgetWeight and input2GateForgetBias gradients.
  gradient.submat(offset, 0, offset + input2GateForgetWeight.n_elem - 1, 0) =
      arma::vectorise(forgetGateError * input.t());
  gradient.submat(offset + input2GateForgetWeight.n_elem, 0,
      offset + input2GateForgetWeight.n_elem +
      input2GateForgetBias.n_elem - 1, 0) = arma::sum(forgetGateError, 1);
  offset += input2GateForgetWeight.n_elem + input2GateForgetBias.n_elem;

  // Input2GateInputWeight and input2GateInputBias gradients.
  gradient.submat(offset, 0, offset + input2GateInputWeight.n_elem - 1, 0) =
      arma::vectorise(inputGateError * input.t());
  gradient.submat(offset + input2GateInputWeight.n_elem, 0,
      offset + input2GateInputWeight.n_elem +
      input2GateInputBias.n_elem - 1, 0) = arma::sum(inputGateError, 1);
  offset += input2GateInputWeight.n_elem + input2GateInputBias.n_elem;

  // Input2HiddenWeight and input2HiddenBias gradients.
  gradient.submat(offset, 0, offset + input2HiddenWeight.n_elem - 1, 0) =
      arma::vectorise(hiddenError * input.t());
  gradient.submat(offset + input2HiddenWeight.n_elem, 0,
      offset + input2HiddenWeight.n_elem + input2HiddenBias.n_elem - 1, 0) =
      arma::sum(hiddenError, 1);
  offset += input2HiddenWeight.n_elem + input2HiddenBias.n_elem;

  // Output2GateOutputWeight gradients.
  gradient.submat(offset, 0, offset + output2GateOutputWeight.n_elem - 1, 0) =
      arma::vectorise(outputGateError *
      outParameter.cols(gradientStep - batchStep, gradientStep).t());
  offset += output2GateOutputWeight.n_elem;

  // Output2GateForgetWeight gradients.
  gradient.submat(offset, 0, offset + output2GateForgetWeight.n_elem - 1, 0) =
      arma::vectorise(forgetGateError *
      outParameter.cols(gradientStep - batchStep, gradientStep).t());
  offset += output2GateForgetWeight.n_elem;

  // Output2GateInputWeight gradients.
  gradient.submat(offset, 0, offset + output2GateInputWeight.n_elem - 1, 0) =
      arma::vectorise(inputGateError *
      outParameter.cols(gradientStep - batchStep, gradientStep).t());
  offset += output2GateInputWeight.n_elem;

  // Output2HiddenWeight gradients.
  gradient.submat(offset, 0, offset + output2HiddenWeight.n_elem - 1, 0) =
      arma::vectorise(hiddenError *
      outParameter.cols(gradientStep - batchStep, gradientStep).t());
  offset += output2HiddenWeight.n_elem;

  // Cell2GateOutputWeight gradients.
  gradient.submat(offset, 0, offset + cell2GateOutputWeight.n_elem - 1, 0) =
      arma::sum(outputGateError %
      cell.cols(gradientStep - batchStep, gradientStep), 1);
  offset += cell2GateOutputWeight.n_elem;

  // Cell2GateForgetWeight and cell2GateInputWeight gradients.
  if (gradientStep > batchStep)
  {
    gradient.submat(offset, 0, offset + cell2GateForgetWeight.n_elem - 1, 0) =
        arma::sum(forgetGateError %
        cell.cols((gradientStep - batchSize) - batchStep,
        (gradientStep - batchSize)), 1);
    gradient.submat(offset + cell2GateForgetWeight.n_elem, 0, offset +
        cell2GateForgetWeight.n_elem + cell2GateInputWeight.n_elem - 1, 0) =
        arma::sum(inputGateError %
        cell.cols((gradientStep - batchSize) - batchStep,
        (gradientStep - batchSize)), 1);
  }
  else
  {
    // There is no previous cell state at the first time step, so these
    // peephole gradients are zero.
    gradient.submat(offset, 0, offset +
        cell2GateForgetWeight.n_elem - 1, 0).zeros();
    gradient.submat(offset + cell2GateForgetWeight.n_elem, 0, offset +
        cell2GateForgetWeight.n_elem +
        cell2GateInputWeight.n_elem - 1, 0).zeros();
  }

  if (gradientStep == 0)
  {
    gradientStep = batchSize * bpttSteps - 1;
  }
  else
  {
    gradientStep -= batchSize;
  }
}
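
// A minimal usage sketch (not part of the original file; it assumes the
// mlpack 3.x ann API, in which this layer is normally driven by an enclosing
// RNN<> model that initializes the weights):
//
//   LSTM<> lstm(10, 4, 100);      // inSize = 10, outSize = 4, rho = 100.
//   lstm.Parameters().randu();    // Normally done by the enclosing model.
//   lstm.Reset();                 // Create the weight aliases.
//   arma::mat input(10, 32, arma::fill::randu), output;
//   lstm.Forward(input, output);  // output is a 4 x 32 alias.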

template<typename InputDataType, typename OutputDataType>
template<typename Archive>
void LSTM<InputDataType, OutputDataType>::serialize(
    Archive& ar, const uint32_t /* version */)
{
  ar(CEREAL_NVP(weights));
  ar(CEREAL_NVP(inSize));
  ar(CEREAL_NVP(outSize));
  ar(CEREAL_NVP(rho));
  ar(CEREAL_NVP(bpttSteps));
  ar(CEREAL_NVP(batchSize));
  ar(CEREAL_NVP(batchStep));
  ar(CEREAL_NVP(forwardStep));
  ar(CEREAL_NVP(backwardStep));
  ar(CEREAL_NVP(gradientStep));
  ar(CEREAL_NVP(gradientStepIdx));
  ar(CEREAL_NVP(cell));
  ar(CEREAL_NVP(inputGateActivation));
  ar(CEREAL_NVP(forgetGateActivation));
  ar(CEREAL_NVP(outputGateActivation));
  ar(CEREAL_NVP(hiddenLayerActivation));
  ar(CEREAL_NVP(cellActivation));
  ar(CEREAL_NVP(prevError));
  ar(CEREAL_NVP(outParameter));
}

} // namespace ann
} // namespace mlpack

#endif