#ifndef MLPACK_METHODS_ANN_LAYER_FAST_LSTM_IMPL_HPP
#define MLPACK_METHODS_ANN_LAYER_FAST_LSTM_IMPL_HPP

// In case it hasn't yet been included.
#include "fast_lstm.hpp"

namespace mlpack {
namespace ann /** Artificial Neural Network. */ {

template<typename InputDataType, typename OutputDataType>
FastLSTM<InputDataType, OutputDataType>::FastLSTM()
{
  // Nothing to do here.
}
template<typename InputDataType, typename OutputDataType>
FastLSTM<InputDataType, OutputDataType>::FastLSTM(
    const size_t inSize,
    const size_t outSize,
    const size_t rho) :
    inSize(inSize),
    outSize(outSize),
    rho(rho),
    forwardStep(0),
    backwardStep(0),
    gradientStep(0),
    batchSize(0),
    batchStep(0),
    gradientStepIdx(0),
    rhoSize(rho),
    bpttSteps(0)
{
  weights.set_size(WeightSize(), 1);
}
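// Example usage (a minimal sketch, not taken from this file: the RNN model
// API and the layer sizes shown here are illustrative assumptions based on
// the mlpack 3.x interface):
//
//   RNN<MeanSquaredError<> > model(rho);
//   model.Add<FastLSTM<> >(inputSize, hiddenSize, rho);
//   model.Add<Linear<> >(hiddenSize, outputSize);
//
// The flat `weights` vector allocated above holds all trainable parameters:
// 4 * outSize * inSize input-to-gate weights, 4 * outSize biases, and
// 4 * outSize * outSize recurrent (output-to-gate) weights; see Reset()
// below for how it is partitioned.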
template<typename InputDataType, typename OutputDataType>
FastLSTM<InputDataType, OutputDataType>::FastLSTM(const FastLSTM& layer) :
    inSize(layer.inSize),
    outSize(layer.outSize),
    rho(layer.rho),
    forwardStep(layer.forwardStep),
    backwardStep(layer.backwardStep),
    gradientStep(layer.gradientStep),
    weights(layer.weights),
    batchSize(layer.batchSize),
    batchStep(layer.batchStep),
    gradientStepIdx(layer.gradientStepIdx),
    grad(layer.grad),
    rhoSize(layer.rho),
    bpttSteps(layer.bpttSteps)
{
  // Nothing to do here.
}
template<typename InputDataType, typename OutputDataType>
FastLSTM<InputDataType, OutputDataType>::FastLSTM(FastLSTM&& layer) :
    inSize(std::move(layer.inSize)),
    outSize(std::move(layer.outSize)),
    rho(std::move(layer.rho)),
    forwardStep(std::move(layer.forwardStep)),
    backwardStep(std::move(layer.backwardStep)),
    gradientStep(std::move(layer.gradientStep)),
    weights(std::move(layer.weights)),
    batchSize(std::move(layer.batchSize)),
    batchStep(std::move(layer.batchStep)),
    gradientStepIdx(std::move(layer.gradientStepIdx)),
    grad(std::move(layer.grad)),
    rhoSize(std::move(layer.rho)),
    bpttSteps(std::move(layer.bpttSteps))
{
  // Nothing to do here.
}
template<typename InputDataType, typename OutputDataType>
FastLSTM<InputDataType, OutputDataType>&
FastLSTM<InputDataType, OutputDataType>::operator=(const FastLSTM& layer)
{
  if (this != &layer)
  {
    inSize = layer.inSize;
    outSize = layer.outSize;
    rho = layer.rho;
    forwardStep = layer.forwardStep;
    backwardStep = layer.backwardStep;
    gradientStep = layer.gradientStep;
    weights = layer.weights;
    batchSize = layer.batchSize;
    batchStep = layer.batchStep;
    gradientStepIdx = layer.gradientStepIdx;
    grad = layer.grad;
    rhoSize = layer.rho;
    bpttSteps = layer.bpttSteps;
  }
  return *this;
}
template<typename InputDataType, typename OutputDataType>
FastLSTM<InputDataType, OutputDataType>&
FastLSTM<InputDataType, OutputDataType>::operator=(FastLSTM&& layer)
{
  if (this != &layer)
  {
    inSize = std::move(layer.inSize);
    outSize = std::move(layer.outSize);
    rho = std::move(layer.rho);
    forwardStep = std::move(layer.forwardStep);
    backwardStep = std::move(layer.backwardStep);
    gradientStep = std::move(layer.gradientStep);
    weights = std::move(layer.weights);
    batchSize = std::move(layer.batchSize);
    batchStep = std::move(layer.batchStep);
    gradientStepIdx = std::move(layer.gradientStepIdx);
    grad = std::move(layer.grad);
    rhoSize = std::move(layer.rho);
    bpttSteps = std::move(layer.bpttSteps);
  }
  return *this;
}
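// Reset() rebuilds the per-gate weight matrices as Armadillo aliases into the
// flat `weights` vector (the trailing `false, false` arguments request a
// non-copying, non-strict view), so updating `weights` updates every gate.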
template<typename InputDataType, typename OutputDataType>
void FastLSTM<InputDataType, OutputDataType>::Reset()
{
  // Set the weight parameter for the input to gate layer (linear layer) using
  // the overall layer parameter matrix.
  input2GateWeight = OutputDataType(weights.memptr(),
      4 * outSize, inSize, false, false);
  input2GateBias = OutputDataType(weights.memptr() + input2GateWeight.n_elem,
      4 * outSize, 1, false, false);

  // Set the weight parameter for the output to gate layer (linear layer
  // without bias) using the overall layer parameter matrix.
  output2GateWeight = OutputDataType(weights.memptr() + input2GateWeight.n_elem
      + input2GateBias.n_elem, 4 * outSize, outSize, false, false);
}
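// ResetCell() is called whenever the sequence length or batch size changes;
// it re-dimensions every state buffer so that activations for up to
// bpttSteps time steps can be stored for backpropagation through time.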
template<typename InputDataType, typename OutputDataType>
void FastLSTM<InputDataType, OutputDataType>::ResetCell(const size_t size)
{
  if (size == std::numeric_limits<size_t>::max())
    return;

  rhoSize = size;

  if (batchSize == 0)
    return;

  bpttSteps = std::min(rho, rhoSize);
  forwardStep = 0;
  gradientStepIdx = 0;
  backwardStep = batchSize * size - 1;
  gradientStep = batchSize * size - 1;

  const size_t rhoBatchSize = size * batchSize;

  // Make sure all of the matrices we use to store state are at least as
  // large as we need.
  gate.set_size(4 * outSize, rhoBatchSize);
  gateActivation.set_size(outSize * 3, rhoBatchSize);
  stateActivation.set_size(outSize, rhoBatchSize);
  cellActivation.set_size(outSize, rhoBatchSize);
  prevError.set_size(4 * outSize, batchSize);

  // Reset the stored state to zeros.
  prevOutput.zeros(outSize, batchSize);
  cell.zeros(outSize, size * batchSize);
  cellActivationError.zeros(outSize, batchSize);
  outParameter.zeros(outSize, (size + 1) * batchSize);
}
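// Forward() computes one time step for the whole batch.  The `gate` matrix
// stacks all four LSTM gates: rows [0, outSize) hold the input gate,
// [outSize, 2 * outSize) the output gate, [2 * outSize, 3 * outSize) the
// forget gate, and [3 * outSize, 4 * outSize) the cell input, so the first
// 3 * outSize rows get a sigmoid activation and the last outSize rows a tanh.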
template<typename InputDataType, typename OutputDataType>
template<typename InputType, typename OutputType>
void FastLSTM<InputDataType, OutputDataType>::Forward(
    const InputType& input, OutputType& output)
{
  // Check if the batch size changed; the number of input columns defines the
  // batch size.
  if (input.n_cols != batchSize)
  {
    batchSize = input.n_cols;
    batchStep = batchSize - 1;
    ResetCell(rhoSize);
  }
  gate.cols(forwardStep, forwardStep + batchStep) = input2GateWeight * input +
      output2GateWeight * outParameter.cols(
      forwardStep, forwardStep + batchStep);
  gate.cols(forwardStep, forwardStep + batchStep).each_col() += input2GateBias;

  arma::subview<double> sigmoidOut = gateActivation.cols(forwardStep,
      forwardStep + batchStep);
  FastSigmoid(
      gate.submat(0, forwardStep, 3 * outSize - 1, forwardStep + batchStep),
      sigmoidOut);

  stateActivation.cols(forwardStep, forwardStep + batchStep) = arma::tanh(
      gate.submat(3 * outSize, forwardStep, 4 * outSize - 1,
      forwardStep + batchStep));
  // Update the cell state: input gate activation times the candidate state,
  // plus (after the first step) forget gate activation times the previous
  // cell state.
  if (forwardStep == 0)
  {
    cell.cols(forwardStep, forwardStep + batchStep) =
        gateActivation.submat(0, forwardStep, outSize - 1,
        forwardStep + batchStep) %
        stateActivation.cols(forwardStep, forwardStep + batchStep);
  }
  else
  {
    cell.cols(forwardStep, forwardStep + batchStep) =
        gateActivation.submat(0, forwardStep, outSize - 1,
        forwardStep + batchStep) %
        stateActivation.cols(forwardStep, forwardStep + batchStep) +
        gateActivation.submat(2 * outSize, forwardStep, 3 * outSize - 1,
        forwardStep + batchStep) %
        cell.cols(forwardStep - batchSize, forwardStep - batchSize + batchStep);
  }
  cellActivation.cols(forwardStep, forwardStep + batchStep) =
      arma::tanh(cell.cols(forwardStep, forwardStep + batchStep));

  // The layer output: output gate activation times the squashed cell state.
  outParameter.cols(forwardStep + batchSize,
      forwardStep + batchSize + batchStep) = cellActivation.cols(
      forwardStep, forwardStep + batchStep) % gateActivation.submat(
      outSize, forwardStep, 2 * outSize - 1, forwardStep + batchStep);

  output = OutputType(outParameter.memptr() +
      (forwardStep + batchSize) * outSize, outSize, batchSize, false, false);

  forwardStep += batchSize;
  if ((forwardStep / batchSize) == bpttSteps)
    forwardStep = 0;
}
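// Backward() runs one step of backpropagation through time.  On all but the
// first call of a sequence, the error flowing back through the recurrent
// connection (output2GateWeight.t() * prevError) is added to the incoming
// error gy before the gate derivatives are computed.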
template<typename InputDataType, typename OutputDataType>
template<typename InputType, typename ErrorType, typename GradientType>
void FastLSTM<InputDataType, OutputDataType>::Backward(
    const InputType& /* input */, const ErrorType& gy, GradientType& g)
{
  ErrorType gyLocal;
  if (gradientStepIdx > 0)
  {
    gyLocal = gy + output2GateWeight.t() * prevError;
  }
  else
  {
    // Make an alias of gy; no copy is needed on the first backward step.
    gyLocal = ErrorType(((ErrorType&) gy).memptr(), gy.n_rows, gy.n_cols,
        false, false);
  }

  cellActivationError = gyLocal % gateActivation.submat(outSize,
      backwardStep - batchStep, 2 * outSize - 1, backwardStep) %
      (1 - arma::pow(cellActivation.cols(backwardStep - batchStep,
      backwardStep), 2));

  if (gradientStepIdx > 0)
    cellActivationError += forgetGateError;

  forgetGateError = gateActivation.submat(2 * outSize,
      backwardStep - batchStep, 3 * outSize - 1, backwardStep) %
      cellActivationError;
  if (backwardStep > batchStep)
  {
    prevError.submat(2 * outSize, 0, 3 * outSize - 1, batchStep) =
        cell.cols((backwardStep - batchSize) - batchStep,
        (backwardStep - batchSize)) % cellActivationError %
        gateActivation.submat(2 * outSize, backwardStep - batchStep,
        3 * outSize - 1, backwardStep) % (1.0 - gateActivation.submat(
        2 * outSize, backwardStep - batchStep, 3 * outSize - 1, backwardStep));
  }
  else
  {
    prevError.submat(2 * outSize, 0, 3 * outSize - 1, batchStep).zeros();
  }
  prevError.submat(0, 0, outSize - 1, batchStep) =
      stateActivation.cols(backwardStep - batchStep,
      backwardStep) % cellActivationError % gateActivation.submat(
      0, backwardStep - batchStep, outSize - 1, backwardStep) %
      (1.0 - gateActivation.submat(
      0, backwardStep - batchStep, outSize - 1, backwardStep));

  prevError.submat(3 * outSize, 0, 4 * outSize - 1, batchStep) =
      gateActivation.submat(0, backwardStep - batchStep,
      outSize - 1, backwardStep) % cellActivationError % (1 - arma::pow(
      stateActivation.cols(backwardStep - batchStep, backwardStep), 2));

  prevError.submat(outSize, 0, 2 * outSize - 1, batchStep) =
      cellActivation.cols(backwardStep - batchStep,
      backwardStep) % gyLocal % gateActivation.submat(
      outSize, backwardStep - batchStep, 2 * outSize - 1, backwardStep) %
      (1.0 - gateActivation.submat(
      outSize, backwardStep - batchStep, 2 * outSize - 1, backwardStep));
  g = input2GateWeight.t() * prevError;

  backwardStep -= batchSize;
  gradientStepIdx++;
  if (gradientStepIdx == bpttSteps)
  {
    backwardStep = bpttSteps - 1;
    gradientStepIdx = 0;
  }
}
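// Gradient() accumulates the parameter gradients for one time step from the
// gate errors computed in Backward(): the outer product of prevError with the
// step's input gives the input-to-gate gradient, a row-wise sum gives the
// bias gradient, and the outer product with the previous output gives the
// recurrent (output-to-gate) gradient.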
template<typename InputDataType, typename OutputDataType>
template<typename InputType, typename ErrorType, typename GradientType>
void FastLSTM<InputDataType, OutputDataType>::Gradient(
    const InputType& input,
    const ErrorType& /* error */,
    GradientType& gradient)
{
  // Gradient of the input to gate layer.
  gradient.submat(0, 0, input2GateWeight.n_elem - 1, 0) =
      arma::vectorise(prevError * input.t());

  gradient.submat(input2GateWeight.n_elem, 0, input2GateWeight.n_elem +
      input2GateBias.n_elem - 1, 0) = arma::sum(prevError, 1);

  // Gradient of the output to gate layer.
  gradient.submat(input2GateWeight.n_elem + input2GateBias.n_elem, 0,
      gradient.n_elem - 1, 0) = arma::vectorise(prevError *
      outParameter.cols(gradientStep - batchStep, gradientStep).t());

  if (gradientStep > batchStep)
  {
    gradientStep -= batchSize;
  }
  else
  {
    gradientStep = batchSize * bpttSteps - 1;
  }
}
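// serialize() stores the parameters together with the recurrent state and
// step counters, so a saved model can be reloaded via cereal without losing
// its cell state.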
template<typename InputDataType, typename OutputDataType>
template<typename Archive>
void FastLSTM<InputDataType, OutputDataType>::serialize(
    Archive& ar, const uint32_t /* version */)
{
  ar(CEREAL_NVP(weights));
  ar(CEREAL_NVP(inSize));
  ar(CEREAL_NVP(outSize));
  ar(CEREAL_NVP(rho));
  ar(CEREAL_NVP(bpttSteps));
  ar(CEREAL_NVP(batchSize));
  ar(CEREAL_NVP(batchStep));
  ar(CEREAL_NVP(forwardStep));
  ar(CEREAL_NVP(backwardStep));
  ar(CEREAL_NVP(gradientStep));
  ar(CEREAL_NVP(gradientStepIdx));
  ar(CEREAL_NVP(cell));
  ar(CEREAL_NVP(stateActivation));
  ar(CEREAL_NVP(gateActivation));
  ar(CEREAL_NVP(gate));
  ar(CEREAL_NVP(cellActivation));
  ar(CEREAL_NVP(forgetGateError));
  ar(CEREAL_NVP(prevError));
  ar(CEREAL_NVP(outParameter));
}

} // namespace ann
} // namespace mlpack

#endif