// Include guard for the LSTM layer implementation header.
// NOTE(review): this extraction fuses several original lines; the body of the
// default constructor (originally following this template header) is not
// visible here -- consult the full source before relying on this fragment.
12 #ifndef MLPACK_METHODS_ANN_LAYER_LSTM_IMPL_HPP 13 #define MLPACK_METHODS_ANN_LAYER_LSTM_IMPL_HPP 21 template<
    typename InputDataType,
    typename OutputDataType>
// Copy constructor: member-wise copy of hyper-parameters, step counters, and
// the flat weight vector.
// NOTE(review): the signature line and the initializers originally at lines
// 28-30, 32 and 40 (presumably `rho` and `rhoSize`) are missing from this
// extraction -- verify against the complete source.
27 template<
    typename InputDataType,
    typename OutputDataType>
31 outSize(layer.outSize),
33 forwardStep(layer.forwardStep),
34 backwardStep(layer.backwardStep),
35 gradientStep(layer.gradientStep),
36 weights(layer.weights),
37 batchSize(layer.batchSize),
38 batchStep(layer.batchStep),
39 gradientStepIdx(layer.gradientStepIdx),
41 bpttSteps(layer.bpttSteps)
46 template<
typename InputDataType,
typename OutputDataType>
49 inSize(
std::move(layer.inSize)),
50 outSize(
std::move(layer.outSize)),
51 rho(
std::move(layer.rho)),
52 forwardStep(
std::move(layer.forwardStep)),
53 backwardStep(
std::move(layer.backwardStep)),
54 gradientStep(
std::move(layer.gradientStep)),
55 weights(
std::move(layer.weights)),
56 batchSize(
std::move(layer.batchSize)),
57 batchStep(
std::move(layer.batchStep)),
58 gradientStepIdx(
std::move(layer.gradientStepIdx)),
59 rhoSize(
std::move(layer.rho)),
60 bpttSteps(
std::move(layer.bpttSteps))
// Copy assignment operator: member-wise copy.
// NOTE(review): the signature, the self-assignment guard, and the lines
// originally numbered 73 and 81-82 (presumably `rho`, `grad`, `rhoSize`) are
// missing from this extraction -- verify against the complete source.
65 template <
    typename InputDataType,
    typename OutputDataType>
71 inSize = layer.inSize;
72 outSize = layer.outSize;
74 forwardStep = layer.forwardStep;
75 backwardStep = layer.backwardStep;
76 gradientStep = layer.gradientStep;
77 weights = layer.weights;
78 batchSize = layer.batchSize;
79 batchStep = layer.batchStep;
80 gradientStepIdx = layer.gradientStepIdx;
83 bpttSteps = layer.bpttSteps;
88 template <
typename InputDataType,
typename OutputDataType>
94 inSize = std::move(layer.inSize);
95 outSize = std::move(layer.outSize);
96 rho = std::move(layer.rho);
97 forwardStep = std::move(layer.forwardStep);
98 backwardStep = std::move(layer.backwardStep);
99 gradientStep = std::move(layer.gradientStep);
100 weights = std::move(layer.weights);
101 batchSize = std::move(layer.batchSize);
102 batchStep = std::move(layer.batchStep);
103 gradientStepIdx = std::move(layer.gradientStepIdx);
104 grad = std::move(layer.grad);
105 rhoSize = std::move(layer.rho);
106 bpttSteps = std::move(layer.bpttSteps);
// Main constructor: input dimensionality, output (hidden/cell) dimensionality,
// and the BPTT truncation length `rho`.
// NOTE(review): the constructor name line, initializer list and body are not
// visible in this extraction -- only the parameter list survived.
111 template <
    typename InputDataType,
    typename OutputDataType>
113 const size_t inSize,
    const size_t outSize,
    const size_t rho) :
// ResetCell(size): (re)allocate the unrolled per-timestep workspace buffers
// for a sequence of `size` steps at the current batch size, and reset the
// step counters.  (Fragment: the function signature and some lines between
// the visible ones are missing from this extraction.)
129 template<
    typename InputDataType,
    typename OutputDataType>
// Sentinel: max() presumably means "single-step / unlimited" mode -- the
// branch body is not visible here; confirm against the full source.
132 if (size == std::numeric_limits<size_t>::max())
// Truncate BPTT to at most the stored sequence length.
140 bpttSteps = std::min(rho, rhoSize);
// Backward/gradient passes start from the last column block of the buffers.
143 backwardStep = batchSize * size - 1;
144 gradientStep = batchSize * size - 1;
146 const size_t rhoBatchSize = size * batchSize;
// Gate pre-activations: one outSize x batchSize column block per time step.
150 inputGate.set_size(outSize, rhoBatchSize);
151 forgetGate.set_size(outSize, rhoBatchSize);
152 hiddenLayer.set_size(outSize, rhoBatchSize);
153 outputGate.set_size(outSize, rhoBatchSize);
// Corresponding gate activations.
155 inputGateActivation.set_size(outSize, rhoBatchSize);
156 forgetGateActivation.set_size(outSize, rhoBatchSize);
157 outputGateActivation.set_size(outSize, rhoBatchSize);
158 hiddenLayerActivation.set_size(outSize, rhoBatchSize);
// tanh of the cell state, plus the 4-gate concatenated error for one step.
160 cellActivation.set_size(outSize, rhoBatchSize);
161 prevError.set_size(4 * outSize, batchSize);
// Cell state history; outParameter keeps one extra leading block of zeros as
// the initial recurrent input h_{-1}.
164 cell.zeros(outSize, size * batchSize);
165 outParameter.zeros(outSize, (size + 1) * batchSize);
168 template<
typename InputDataType,
typename OutputDataType>
172 input2GateOutputWeight = OutputDataType(weights.memptr(), outSize, inSize,
174 input2GateOutputBias = OutputDataType(weights.memptr() +
175 input2GateOutputWeight.n_elem, outSize, 1,
false,
false);
176 size_t offset = input2GateOutputWeight.n_elem + input2GateOutputBias.n_elem;
179 input2GateForgetWeight = OutputDataType(weights.memptr() + offset,
180 outSize, inSize,
false,
false);
181 input2GateForgetBias = OutputDataType(weights.memptr() +
182 offset + input2GateForgetWeight.n_elem, outSize, 1,
false,
false);
183 offset += input2GateForgetWeight.n_elem + input2GateForgetBias.n_elem;
186 input2GateInputWeight = OutputDataType(weights.memptr() +
187 offset, outSize, inSize,
false,
false);
188 input2GateInputBias = OutputDataType(weights.memptr() +
189 offset + input2GateInputWeight.n_elem, outSize, 1,
false,
false);
190 offset += input2GateInputWeight.n_elem + input2GateInputBias.n_elem;
193 input2HiddenWeight = OutputDataType(weights.memptr() +
194 offset, outSize, inSize,
false,
false);
195 input2HiddenBias = OutputDataType(weights.memptr() +
196 offset + input2HiddenWeight.n_elem, outSize, 1,
false,
false);
197 offset += input2HiddenWeight.n_elem + input2HiddenBias.n_elem;
200 output2GateOutputWeight = OutputDataType(weights.memptr() +
201 offset, outSize, outSize,
false,
false);
202 offset += output2GateOutputWeight.n_elem;
205 output2GateForgetWeight = OutputDataType(weights.memptr() +
206 offset, outSize, outSize,
false,
false);
207 offset += output2GateForgetWeight.n_elem;
210 output2GateInputWeight = OutputDataType(weights.memptr() +
211 offset, outSize, outSize,
false,
false);
212 offset += output2GateInputWeight.n_elem;
215 output2HiddenWeight = OutputDataType(weights.memptr() +
216 offset, outSize, outSize,
false,
false);
217 offset += output2HiddenWeight.n_elem;
220 cell2GateOutputWeight = OutputDataType(weights.memptr() +
221 offset, outSize, 1,
false,
false);
222 offset += cell2GateOutputWeight.n_elem;
225 cell2GateForgetWeight = OutputDataType(weights.memptr() +
226 offset, outSize, 1,
false,
false);
227 offset += cell2GateOutputWeight.n_elem;
230 cell2GateInputWeight = OutputDataType(weights.memptr() +
231 offset, outSize, 1,
false,
false);
// Two-argument Forward: convenience overload.  Delegates to the four-argument
// Forward() with a default-constructed (empty) cell-state matrix and a false
// flag (presumably "use the provided cell state" -- the parameter name is not
// visible in this extraction).
235 template<
    typename InputDataType,
    typename OutputDataType>
236 template<
    typename InputType,
    typename OutputType>
238 const InputType& input, OutputType& output)
241 OutputType cellState;
242 Forward(input, output, cellState,
    false);
// Four-argument Forward: one peephole-LSTM step (or sequence step) over a
// batch.  Columns [forwardStep, forwardStep + batchStep] of the unrolled
// buffers hold the current step; outParameter is shifted by one batch block
// so that its block at forwardStep is h_{t-1}.  (Fragment: signature and a
// few lines are missing from this extraction.)
246 template<
    typename InputDataType,
    typename OutputDataType>
247 template<
    typename InputType,
    typename OutputType>
250 OutputType& cellState,
// Re-derive the batch bookkeeping when the batch size changes between calls.
255 if (input.n_cols != batchSize)
257 batchSize = input.n_cols;
258 batchStep = batchSize - 1;
// Input gate pre-activation: W_i * x_t + U_i * h_{t-1}, plus bias below.
262 inputGate.cols(forwardStep, forwardStep + batchStep) = input2GateInputWeight *
263 input + output2GateInputWeight * outParameter.cols(forwardStep,
264 forwardStep + batchStep);
265 inputGate.cols(forwardStep, forwardStep + batchStep).each_col() +=
// Forget gate pre-activation: W_f * x_t + U_f * h_{t-1} + b_f.
268 forgetGate.cols(forwardStep, forwardStep + batchStep) = input2GateForgetWeight
269 * input + output2GateForgetWeight * outParameter.cols(
270 forwardStep, forwardStep + batchStep);
271 forgetGate.cols(forwardStep, forwardStep + batchStep).each_col() +=
272 input2GateForgetBias;
// Externally supplied cell state replaces the stored c_{t-1} block.
278 if (!cellState.is_empty())
280 cell.cols(forwardStep - batchSize,
281 forwardStep - batchSize + batchStep) = cellState;
285 throw std::runtime_error(
    "Cell parameter is empty.");
// Peephole contributions from c_{t-1} to the input and forget gates.
288 inputGate.cols(forwardStep, forwardStep + batchStep) +=
289 arma::repmat(cell2GateInputWeight, 1, batchSize) %
290 cell.cols(forwardStep - batchSize, forwardStep - batchSize + batchStep);
292 forgetGate.cols(forwardStep, forwardStep + batchStep) +=
293 arma::repmat(cell2GateForgetWeight, 1, batchSize) %
294 cell.cols(forwardStep - batchSize, forwardStep - batchSize + batchStep);
// Logistic sigmoid of the input and forget gate pre-activations.
297 inputGateActivation.cols(forwardStep, forwardStep + batchStep) = 1.0 /
298 (1 + arma::exp(-inputGate.cols(forwardStep, forwardStep + batchStep)));
300 forgetGateActivation.cols(forwardStep, forwardStep + batchStep) = 1.0 /
301 (1 + arma::exp(-forgetGate.cols(forwardStep, forwardStep + batchStep)));
// Cell candidate g_t = tanh(W_g * x_t + U_g * h_{t-1} + b_g).
303 hiddenLayer.cols(forwardStep, forwardStep + batchStep) = input2HiddenWeight *
304 input + output2HiddenWeight * outParameter.cols(
305 forwardStep, forwardStep + batchStep);
307 hiddenLayer.cols(forwardStep, forwardStep + batchStep).each_col() +=
310 hiddenLayerActivation.cols(forwardStep, forwardStep + batchStep) =
311 arma::tanh(hiddenLayer.cols(forwardStep, forwardStep + batchStep));
// Cell update: first step has no previous cell, so c_0 = i % g; afterwards
// c_t = f % c_{t-1} + i % g.
313 if (forwardStep == 0)
315 cell.cols(forwardStep, forwardStep + batchStep) =
316 inputGateActivation.cols(forwardStep, forwardStep + batchStep) %
317 hiddenLayerActivation.cols(forwardStep, forwardStep + batchStep);
321 cell.cols(forwardStep, forwardStep + batchStep) =
322 forgetGateActivation.cols(forwardStep, forwardStep + batchStep) %
323 cell.cols(forwardStep - batchSize, forwardStep - batchSize + batchStep)
324 + inputGateActivation.cols(forwardStep, forwardStep + batchStep) %
325 hiddenLayerActivation.cols(forwardStep, forwardStep + batchStep);
// Output gate uses the *current* cell through its peephole weight.
328 outputGate.cols(forwardStep, forwardStep + batchStep) = input2GateOutputWeight
329 * input + output2GateOutputWeight * outParameter.cols(
330 forwardStep, forwardStep + batchStep) + cell.cols(forwardStep,
331 forwardStep + batchStep).each_col() % cell2GateOutputWeight;
333 outputGate.cols(forwardStep, forwardStep + batchStep).each_col() +=
334 input2GateOutputBias;
336 outputGateActivation.cols(forwardStep, forwardStep + batchStep) = 1.0 /
337 (1 + arma::exp(-outputGate.cols(forwardStep, forwardStep + batchStep)));
// h_t = tanh(c_t) % o_t, written into the next outParameter block.
339 cellActivation.cols(forwardStep, forwardStep + batchStep) =
340 arma::tanh(cell.cols(forwardStep, forwardStep + batchStep));
342 outParameter.cols(forwardStep + batchSize,
343 forwardStep + batchSize + batchStep) =
344 cellActivation.cols(forwardStep, forwardStep + batchStep) %
345 outputGateActivation.cols(forwardStep, forwardStep + batchStep);
// Hand back non-owning views of h_t and c_t (advanced ctor, no copy).
347 output = OutputType(outParameter.memptr() +
348 (forwardStep + batchSize) * outSize, outSize, batchSize,
    false,
    false);
350 cellState = OutputType(cell.memptr() +
351 forwardStep * outSize, outSize, batchSize,
    false,
    false);
// Advance the step pointer; wrap-around branch body is not visible here.
353 forwardStep += batchSize;
354 if ((forwardStep / batchSize) == bpttSteps)
// Backward: one BPTT step.  Combines the incoming error gy with the recurrent
// error carried over from the previously processed (chronologically later)
// step, computes the four gate errors, and produces g, the error w.r.t. the
// layer input.  (Fragment: signature and some lines are missing.)
360 template<
    typename InputDataType,
    typename OutputDataType>
361 template<
    typename InputType,
    typename ErrorType,
    typename GradientType>
363 const InputType& ,
    const ErrorType& gy, GradientType& g)
// After the first processed step, fold in the recurrent error from t+1.
366 if (gradientStepIdx > 0)
368 gyLocal = gy + prevError;
// Otherwise alias gy directly (no copy; const_cast-style reuse of gy's memory).
373 gyLocal = ErrorType(((ErrorType&) gy).memptr(), gy.n_rows, gy.n_cols,
    false,
// Output gate error: dE/dh % tanh(c) % sigmoid'(o).
378 gyLocal % cellActivation.cols(backwardStep - batchStep, backwardStep) %
379 (outputGateActivation.cols(backwardStep - batchStep, backwardStep) %
380 (1.0 - outputGateActivation.cols(backwardStep - batchStep,
// Cell error: through tanh(c) plus the output-gate peephole path.
383 OutputDataType cellError = gyLocal %
384 outputGateActivation.cols(backwardStep - batchStep, backwardStep) %
385 (1 - arma::pow(cellActivation.cols(backwardStep -
386 batchStep, backwardStep), 2)) + outputGateError.each_col() %
387 cell2GateOutputWeight;
// Add the cell error carried from the later time step.
389 if (gradientStepIdx > 0)
391 cellError += inputCellError;
// Forget gate error needs c_{t-1}; at the first time step it is zero.
394 if (backwardStep > batchStep)
396 forgetGateError = cell.cols((backwardStep - batchSize) - batchStep,
397 (backwardStep - batchSize)) % cellError % (forgetGateActivation.cols(
398 backwardStep - batchStep, backwardStep) % (1.0 -
399 forgetGateActivation.cols(backwardStep - batchStep, backwardStep)));
403 forgetGateError.zeros();
// Input gate and candidate (hidden) errors via their sigmoid/tanh derivatives.
406 inputGateError = hiddenLayerActivation.cols(backwardStep - batchStep,
407 backwardStep) % cellError %
408 (inputGateActivation.cols(backwardStep - batchStep, backwardStep) %
409 (1.0 - inputGateActivation.cols(backwardStep - batchStep, backwardStep)));
411 hiddenError = inputGateActivation.cols(backwardStep - batchStep,
412 backwardStep) % cellError % (1 - arma::pow(hiddenLayerActivation.cols(
413 backwardStep - batchStep, backwardStep), 2));
// Cell error to propagate to t-1: forget path plus both gate peepholes.
415 inputCellError = forgetGateActivation.cols(backwardStep - batchStep,
416 backwardStep) % cellError + forgetGateError.each_col() %
417 cell2GateForgetWeight + inputGateError.each_col() % cell2GateInputWeight;
// Error w.r.t. the layer input, through the four input weight matrices.
419 g = input2GateInputWeight.t() * inputGateError +
420 input2HiddenWeight.t() * hiddenError +
421 input2GateForgetWeight.t() * forgetGateError +
422 input2GateOutputWeight.t() * outputGateError;
// Recurrent error for the next Backward call (step t-1).
424 prevError = output2GateOutputWeight.t() * outputGateError +
425 output2GateForgetWeight.t() * forgetGateError +
426 output2GateInputWeight.t() * inputGateError +
427 output2HiddenWeight.t() * hiddenError;
429 backwardStep -= batchSize;
// Wrap the step pointer after a full truncated-BPTT window.
431 if (gradientStepIdx == bpttSteps)
433 backwardStep = bpttSteps - 1;
// Gradient: accumulate this step's parameter gradients into the flat
// `gradient` column, using the same offset layout that Reset() uses for the
// weight aliases: input weights/biases for output/forget/input/hidden gates,
// then the four recurrent matrices, then the three peephole vectors.
// Relies on the gate errors computed by the preceding Backward() call.
438 template<
    typename InputDataType,
    typename OutputDataType>
439 template<
    typename InputType,
    typename ErrorType,
    typename GradientType>
441 const InputType& input,
443 GradientType& gradient)
// Input-to-output-gate weight and bias gradients.
446 gradient.submat(0, 0, input2GateOutputWeight.n_elem - 1, 0) =
447 arma::vectorise(outputGateError * input.t());
448 gradient.submat(input2GateOutputWeight.n_elem, 0,
449 input2GateOutputWeight.n_elem + input2GateOutputBias.n_elem - 1, 0) =
450 arma::sum(outputGateError, 1);
451 size_t offset = input2GateOutputWeight.n_elem + input2GateOutputBias.n_elem;
// Input-to-forget-gate weight and bias gradients.
454 gradient.submat(offset, 0, offset + input2GateForgetWeight.n_elem - 1, 0) =
455 arma::vectorise(forgetGateError * input.t());
456 gradient.submat(offset + input2GateForgetWeight.n_elem, 0,
457 offset + input2GateForgetWeight.n_elem +
458 input2GateForgetBias.n_elem - 1, 0) = arma::sum(forgetGateError, 1);
459 offset += input2GateForgetWeight.n_elem + input2GateForgetBias.n_elem;
// Input-to-input-gate weight and bias gradients.
462 gradient.submat(offset, 0, offset + input2GateInputWeight.n_elem - 1, 0) =
463 arma::vectorise(inputGateError * input.t());
464 gradient.submat(offset + input2GateInputWeight.n_elem, 0,
465 offset + input2GateInputWeight.n_elem +
466 input2GateInputBias.n_elem - 1, 0) = arma::sum(inputGateError, 1);
467 offset += input2GateInputWeight.n_elem + input2GateInputBias.n_elem;
// Input-to-hidden (candidate) weight and bias gradients.
470 gradient.submat(offset, 0, offset + input2HiddenWeight.n_elem - 1, 0) =
471 arma::vectorise(hiddenError * input.t());
472 gradient.submat(offset + input2HiddenWeight.n_elem, 0,
473 offset + input2HiddenWeight.n_elem + input2HiddenBias.n_elem - 1, 0) =
474 arma::sum(hiddenError, 1);
475 offset += input2HiddenWeight.n_elem + input2HiddenBias.n_elem;
// Recurrent weight gradients: gate error times h_{t-1} (outParameter block).
478 gradient.submat(offset, 0, offset + output2GateOutputWeight.n_elem - 1, 0) =
479 arma::vectorise(outputGateError *
480 outParameter.cols(gradientStep - batchStep, gradientStep).t());
481 offset += output2GateOutputWeight.n_elem;
484 gradient.submat(offset, 0, offset + output2GateForgetWeight.n_elem - 1, 0) =
485 arma::vectorise(forgetGateError *
486 outParameter.cols(gradientStep - batchStep, gradientStep).t());
487 offset += output2GateForgetWeight.n_elem;
490 gradient.submat(offset, 0, offset + output2GateInputWeight.n_elem - 1, 0) =
491 arma::vectorise(inputGateError *
492 outParameter.cols(gradientStep - batchStep, gradientStep).t());
493 offset += output2GateInputWeight.n_elem;
496 gradient.submat(offset, 0, offset + output2HiddenWeight.n_elem - 1, 0) =
497 arma::vectorise(hiddenError *
498 outParameter.cols(gradientStep - batchStep, gradientStep).t());
499 offset += output2HiddenWeight.n_elem;
// Output-gate peephole gradient uses the *current* cell state c_t.
502 gradient.submat(offset, 0, offset + cell2GateOutputWeight.n_elem - 1, 0) =
503 arma::sum(outputGateError %
504 cell.cols(gradientStep - batchStep, gradientStep), 1);
505 offset += cell2GateOutputWeight.n_elem;
// Forget/input peephole gradients use c_{t-1}; zero at the first time step.
508 if (gradientStep > batchStep)
510 gradient.submat(offset, 0, offset + cell2GateForgetWeight.n_elem - 1, 0) =
511 arma::sum(forgetGateError %
512 cell.cols((gradientStep - batchSize) - batchStep,
513 (gradientStep - batchSize)), 1);
514 gradient.submat(offset + cell2GateForgetWeight.n_elem, 0, offset +
515 cell2GateForgetWeight.n_elem + cell2GateInputWeight.n_elem - 1, 0) =
516 arma::sum(inputGateError %
517 cell.cols((gradientStep - batchSize) - batchStep,
518 (gradientStep - batchSize)), 1);
522 gradient.submat(offset, 0, offset +
523 cell2GateForgetWeight.n_elem - 1, 0).zeros();
524 gradient.submat(offset + cell2GateForgetWeight.n_elem, 0, offset +
525 cell2GateForgetWeight.n_elem +
526 cell2GateInputWeight.n_elem - 1, 0).zeros();
// Step the gradient pointer backwards, wrapping after a full BPTT window.
529 if (gradientStep == 0)
531 gradientStep = batchSize * bpttSteps - 1;
535 gradientStep -= batchSize;
// Serialize the layer with cereal: trainable weights, hyper-parameters,
// batch/step counters, and the recurrent buffers needed to resume an
// unrolled sequence after deserialization.
539 template<
    typename InputDataType,
    typename OutputDataType>
540 template<
    typename Archive>
542 Archive& ar,
    const uint32_t )
// Parameters and dimensions.
544 ar(CEREAL_NVP(weights));
545 ar(CEREAL_NVP(inSize));
546 ar(CEREAL_NVP(outSize));
// BPTT bookkeeping.
548 ar(CEREAL_NVP(bpttSteps));
549 ar(CEREAL_NVP(batchSize));
550 ar(CEREAL_NVP(batchStep));
551 ar(CEREAL_NVP(forwardStep));
552 ar(CEREAL_NVP(backwardStep));
553 ar(CEREAL_NVP(gradientStep));
554 ar(CEREAL_NVP(gradientStepIdx));
// Recurrent state: cell history, gate activations, carried error, outputs.
555 ar(CEREAL_NVP(cell));
556 ar(CEREAL_NVP(inputGateActivation));
557 ar(CEREAL_NVP(forgetGateActivation));
558 ar(CEREAL_NVP(outputGateActivation));
559 ar(CEREAL_NVP(hiddenLayerActivation));
560 ar(CEREAL_NVP(cellActivation));
561 ar(CEREAL_NVP(prevError));
562 ar(CEREAL_NVP(outParameter));
LSTM & operator=(const LSTM &layer)
Copy assignment operator.
Definition: lstm_impl.hpp:67
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: pointer_wrapper.hpp:23
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: lstm_impl.hpp:541
OutputDataType const & Gradient() const
Get the gradient.
Definition: lstm.hpp:176
Implementation of the LSTM module class.
Definition: layer_types.hpp:82
size_t WeightSize() const
Get the size of the weights.
Definition: lstm.hpp:187
void Forward(const InputType &input, OutputType &output)
Ordinary feed-forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f(x).
Definition: lstm_impl.hpp:237
LSTM()
Create the LSTM object.
Definition: lstm_impl.hpp:22
void Backward(const InputType &input, const ErrorType &gy, GradientType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f(x).
Definition: lstm_impl.hpp:362