mlpack
fast_lstm_impl.hpp
/**
 * @file fast_lstm_impl.hpp
 *
 * Implementation of the Fast LSTM class, a faster version of the LSTM
 * network layer.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license.  You should have received a copy of the
 * 3-clause BSD license along with mlpack.  If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
#ifndef MLPACK_METHODS_ANN_LAYER_FAST_LSTM_IMPL_HPP
#define MLPACK_METHODS_ANN_LAYER_FAST_LSTM_IMPL_HPP

// In case it hasn't yet been included.
#include "fast_lstm.hpp"

namespace mlpack {
namespace ann {
template<typename InputDataType, typename OutputDataType>
FastLSTM<InputDataType, OutputDataType>::FastLSTM()
{
  // Nothing to do here.
}

template<typename InputDataType, typename OutputDataType>
FastLSTM<InputDataType, OutputDataType>::FastLSTM(
    const size_t inSize, const size_t outSize, const size_t rho) :
    inSize(inSize),
    outSize(outSize),
    rho(rho),
    forwardStep(0),
    backwardStep(0),
    gradientStep(0),
    batchSize(0),
    batchStep(0),
    gradientStepIdx(0),
    rhoSize(rho),
    bpttSteps(0)
{
  // Weights for: input to gate layer (4 * outSize * inSize + 4 * outSize)
  // and output to gate layer (4 * outSize * outSize).
  weights.set_size(WeightSize(), 1);
}
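
// Note: Reset() packs the parameters as 4 * outSize * inSize input-to-gate
// weights, then 4 * outSize gate biases, then 4 * outSize * outSize
// output-to-gate weights, so WeightSize() is their sum. For example (an
// illustrative calculation), inSize = 10 and outSize = 4 gives
// 160 + 16 + 64 = 240 parameters.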

template<typename InputDataType, typename OutputDataType>
FastLSTM<InputDataType, OutputDataType>::FastLSTM(const FastLSTM& layer) :
    inSize(layer.inSize),
    outSize(layer.outSize),
    rho(layer.rho),
    forwardStep(layer.forwardStep),
    backwardStep(layer.backwardStep),
    gradientStep(layer.gradientStep),
    weights(layer.weights),
    batchSize(layer.batchSize),
    batchStep(layer.batchStep),
    gradientStepIdx(layer.gradientStepIdx),
    grad(layer.grad),
    rhoSize(layer.rhoSize),
    bpttSteps(layer.bpttSteps)
{
  // Nothing to do here.
}

template<typename InputDataType, typename OutputDataType>
FastLSTM<InputDataType, OutputDataType>::FastLSTM(FastLSTM&& layer) :
    inSize(std::move(layer.inSize)),
    outSize(std::move(layer.outSize)),
    rho(std::move(layer.rho)),
    forwardStep(std::move(layer.forwardStep)),
    backwardStep(std::move(layer.backwardStep)),
    gradientStep(std::move(layer.gradientStep)),
    weights(std::move(layer.weights)),
    batchSize(std::move(layer.batchSize)),
    batchStep(std::move(layer.batchStep)),
    gradientStepIdx(std::move(layer.gradientStepIdx)),
    grad(std::move(layer.grad)),
    rhoSize(std::move(layer.rhoSize)),
    bpttSteps(std::move(layer.bpttSteps))
{
  // Nothing to do here.
}

template<typename InputDataType, typename OutputDataType>
FastLSTM<InputDataType, OutputDataType>&
FastLSTM<InputDataType, OutputDataType>::operator=(const FastLSTM& layer)
{
  if (this != &layer)
  {
    inSize = layer.inSize;
    outSize = layer.outSize;
    rho = layer.rho;
    forwardStep = layer.forwardStep;
    backwardStep = layer.backwardStep;
    gradientStep = layer.gradientStep;
    weights = layer.weights;
    batchSize = layer.batchSize;
    batchStep = layer.batchStep;
    gradientStepIdx = layer.gradientStepIdx;
    grad = layer.grad;
    rhoSize = layer.rhoSize;
    bpttSteps = layer.bpttSteps;
  }
  return *this;
}

template<typename InputDataType, typename OutputDataType>
FastLSTM<InputDataType, OutputDataType>&
FastLSTM<InputDataType, OutputDataType>::operator=(FastLSTM&& layer)
{
  if (this != &layer)
  {
    inSize = std::move(layer.inSize);
    outSize = std::move(layer.outSize);
    rho = std::move(layer.rho);
    forwardStep = std::move(layer.forwardStep);
    backwardStep = std::move(layer.backwardStep);
    gradientStep = std::move(layer.gradientStep);
    weights = std::move(layer.weights);
    batchSize = std::move(layer.batchSize);
    batchStep = std::move(layer.batchStep);
    gradientStepIdx = std::move(layer.gradientStepIdx);
    grad = std::move(layer.grad);
    rhoSize = std::move(layer.rhoSize);
    bpttSteps = std::move(layer.bpttSteps);
  }
  return *this;
}

template<typename InputDataType, typename OutputDataType>
void FastLSTM<InputDataType, OutputDataType>::Reset()
{
  // Set the weight parameter for the input to gate layer (linear layer) using
  // the overall layer parameter matrix.
  input2GateWeight = OutputDataType(weights.memptr(),
      4 * outSize, inSize, false, false);
  input2GateBias = OutputDataType(weights.memptr() + input2GateWeight.n_elem,
      4 * outSize, 1, false, false);

  // Set the weight parameter for the output to gate layer (linear layer
  // without a bias) using the overall layer parameter matrix.
  output2GateWeight = OutputDataType(weights.memptr() + input2GateWeight.n_elem
      + input2GateBias.n_elem, 4 * outSize, outSize, false, false);
}
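
// Note: the (ptr, n_rows, n_cols, copy_aux_mem = false, strict = false)
// Armadillo constructor used above creates aliases into `weights` rather
// than copies, so optimizer updates to `weights` are immediately visible
// through input2GateWeight, input2GateBias, and output2GateWeight.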

template<typename InputDataType, typename OutputDataType>
void FastLSTM<InputDataType, OutputDataType>::ResetCell(const size_t size)
{
  if (size == std::numeric_limits<size_t>::max())
    return;

  rhoSize = size;

  if (batchSize == 0)
    return;

  bpttSteps = std::min(rho, rhoSize);
  forwardStep = 0;
  gradientStepIdx = 0;
  backwardStep = batchSize * size - 1;
  gradientStep = batchSize * size - 1;

  const size_t rhoBatchSize = size * batchSize;

  // Make sure all of the matrices we use to store state are at least as large
  // as we need.
  gate.set_size(4 * outSize, rhoBatchSize);
  gateActivation.set_size(outSize * 3, rhoBatchSize);
  stateActivation.set_size(outSize, rhoBatchSize);
  cellActivation.set_size(outSize, rhoBatchSize);
  prevError.set_size(4 * outSize, batchSize);

  // Reset stored state to zeros.
  prevOutput.zeros(outSize, batchSize);
  cell.zeros(outSize, size * batchSize);
  cellActivationError.zeros(outSize, batchSize);
  outParameter.zeros(outSize, (size + 1) * batchSize);
}
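
// Note: each state matrix above stores one column per (time step, batch
// item) pair; time step t occupies the column block
// [t * batchSize, t * batchSize + batchStep]. outParameter holds one extra
// block of batchSize columns because block t + 1 stores the output of step
// t, with block 0 reserved for the initial all-zero recurrent state.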

template<typename InputDataType, typename OutputDataType>
template<typename InputType, typename OutputType>
void FastLSTM<InputDataType, OutputDataType>::Forward(
    const InputType& input, OutputType& output)
{
  // Check if the batch size changed; the number of columns defines the input
  // batch size.
  if (input.n_cols != batchSize)
  {
    batchSize = input.n_cols;
    batchStep = batchSize - 1;
    ResetCell(rhoSize);
  }

  gate.cols(forwardStep, forwardStep + batchStep) = input2GateWeight * input +
      output2GateWeight * outParameter.cols(
      forwardStep, forwardStep + batchStep);
  gate.cols(forwardStep, forwardStep + batchStep).each_col() += input2GateBias;

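  // Row layout of `gate` (as used below): rows [0, outSize) hold the input
  // gate, [outSize, 2 * outSize) the output gate, [2 * outSize, 3 * outSize)
  // the forget gate, and [3 * outSize, 4 * outSize) the candidate cell value.
  // The first three blocks receive a sigmoid activation; the candidate block
  // receives tanh.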
  arma::subview<double> sigmoidOut = gateActivation.cols(forwardStep,
      forwardStep + batchStep);
  FastSigmoid(
      gate.submat(0, forwardStep, 3 * outSize - 1, forwardStep + batchStep),
      sigmoidOut);

  stateActivation.cols(forwardStep, forwardStep + batchStep) = arma::tanh(
      gate.submat(3 * outSize, forwardStep, 4 * outSize - 1,
      forwardStep + batchStep));

  // Update the cell: cmul1 + cmul2, where cmul1 is the input gate times the
  // candidate state and cmul2 is the forget gate times the previous cell
  // value (prevCell).
  if (forwardStep == 0)
  {
    cell.cols(forwardStep, forwardStep + batchStep) =
        gateActivation.submat(0, forwardStep, outSize - 1,
        forwardStep + batchStep) %
        stateActivation.cols(forwardStep, forwardStep + batchStep);
  }
  else
  {
    cell.cols(forwardStep, forwardStep + batchStep) =
        gateActivation.submat(0, forwardStep, outSize - 1,
        forwardStep + batchStep) %
        stateActivation.cols(forwardStep, forwardStep + batchStep) +
        gateActivation.submat(2 * outSize, forwardStep, 3 * outSize - 1,
        forwardStep + batchStep) %
        cell.cols(forwardStep - batchSize, forwardStep - batchSize + batchStep);
  }

  cellActivation.cols(forwardStep, forwardStep + batchStep) =
      arma::tanh(cell.cols(forwardStep, forwardStep + batchStep));

  outParameter.cols(forwardStep + batchSize,
      forwardStep + batchSize + batchStep) = cellActivation.cols(
      forwardStep, forwardStep + batchStep) % gateActivation.submat(
      outSize, forwardStep, 2 * outSize - 1, forwardStep + batchStep);

  output = OutputType(outParameter.memptr() +
      (forwardStep + batchSize) * outSize, outSize, batchSize, false, false);

  forwardStep += batchSize;
  if ((forwardStep / batchSize) == bpttSteps)
  {
    forwardStep = 0;
  }
}

template<typename InputDataType, typename OutputDataType>
template<typename InputType, typename ErrorType, typename GradientType>
void FastLSTM<InputDataType, OutputDataType>::Backward(
    const InputType& /* input */, const ErrorType& gy, GradientType& g)
{
  ErrorType gyLocal;
  if (gradientStepIdx > 0)
  {
    gyLocal = gy + output2GateWeight.t() * prevError;
  }
  else
  {
    gyLocal = ErrorType(((ErrorType&) gy).memptr(), gy.n_rows, gy.n_cols,
        false, false);
  }
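
  // The cell-state error computed next is the incoming error gated by the
  // output gate and scaled by the tanh derivative 1 - tanh(cell)^2; for all
  // but the first BPTT step, the error carried back through the forget gate
  // is accumulated into it as well.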

  cellActivationError = gyLocal % gateActivation.submat(outSize,
      backwardStep - batchStep, 2 * outSize - 1, backwardStep) %
      (1 - arma::pow(cellActivation.cols(backwardStep - batchStep,
      backwardStep), 2));

  if (gradientStepIdx > 0)
    cellActivationError += forgetGateError;

  forgetGateError = gateActivation.submat(2 * outSize,
      backwardStep - batchStep, 3 * outSize - 1, backwardStep) %
      cellActivationError;

  if (backwardStep > batchStep)
  {
    prevError.submat(2 * outSize, 0, 3 * outSize - 1, batchStep) =
        cell.cols((backwardStep - batchSize) - batchStep,
        (backwardStep - batchSize)) % cellActivationError %
        gateActivation.submat(2 * outSize, backwardStep - batchStep,
        3 * outSize - 1, backwardStep) % (1.0 - gateActivation.submat(
        2 * outSize, backwardStep - batchStep, 3 * outSize - 1, backwardStep));
  }
  else
  {
    prevError.submat(2 * outSize, 0, 3 * outSize - 1, batchStep).zeros();
  }

  prevError.submat(0, 0, outSize - 1, batchStep) =
      stateActivation.cols(backwardStep - batchStep,
      backwardStep) % cellActivationError % gateActivation.submat(
      0, backwardStep - batchStep, outSize - 1, backwardStep) %
      (1.0 - gateActivation.submat(
      0, backwardStep - batchStep, outSize - 1, backwardStep));

  prevError.submat(3 * outSize, 0, 4 * outSize - 1, batchStep) =
      gateActivation.submat(0, backwardStep - batchStep,
      outSize - 1, backwardStep) % cellActivationError % (1 - arma::pow(
      stateActivation.cols(backwardStep - batchStep, backwardStep), 2));

  prevError.submat(outSize, 0, 2 * outSize - 1, batchStep) =
      cellActivation.cols(backwardStep - batchStep,
      backwardStep) % gyLocal % gateActivation.submat(
      outSize, backwardStep - batchStep, 2 * outSize - 1, backwardStep) %
      (1.0 - gateActivation.submat(
      outSize, backwardStep - batchStep, 2 * outSize - 1, backwardStep));

  g = input2GateWeight.t() * prevError;

  backwardStep -= batchSize;
  gradientStepIdx++;
  if (gradientStepIdx == bpttSteps)
  {
    backwardStep = batchSize * bpttSteps - 1;
    gradientStepIdx = 0;
  }
}
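
// Note: Backward() leaves the per-gate deltas for this time step in
// `prevError`; Gradient() below consumes them, so the two are called in
// alternation for each step of the BPTT unrolling.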

template<typename InputDataType, typename OutputDataType>
template<typename InputType, typename ErrorType, typename GradientType>
void FastLSTM<InputDataType, OutputDataType>::Gradient(
    const InputType& input,
    const ErrorType& /* error */,
    GradientType& gradient)
{
  // Gradient of the input to gate layer.
  gradient.submat(0, 0, input2GateWeight.n_elem - 1, 0) =
      arma::vectorise(prevError * input.t());

  gradient.submat(input2GateWeight.n_elem, 0, input2GateWeight.n_elem +
      input2GateBias.n_elem - 1, 0) = arma::sum(prevError, 1);

  // Gradient of the output to gate layer.
  gradient.submat(input2GateWeight.n_elem + input2GateBias.n_elem, 0,
      gradient.n_elem - 1, 0) = arma::vectorise(prevError *
      outParameter.cols(gradientStep - batchStep, gradientStep).t());

  if (gradientStep > batchStep)
  {
    gradientStep -= batchSize;
  }
  else
  {
    gradientStep = batchSize * bpttSteps - 1;
  }
}

template<typename InputDataType, typename OutputDataType>
template<typename Archive>
void FastLSTM<InputDataType, OutputDataType>::serialize(
    Archive& ar, const uint32_t /* version */)
{
  ar(CEREAL_NVP(weights));
  ar(CEREAL_NVP(inSize));
  ar(CEREAL_NVP(outSize));
  ar(CEREAL_NVP(rho));
  ar(CEREAL_NVP(bpttSteps));
  ar(CEREAL_NVP(batchSize));
  ar(CEREAL_NVP(batchStep));
  ar(CEREAL_NVP(forwardStep));
  ar(CEREAL_NVP(backwardStep));
  ar(CEREAL_NVP(gradientStep));
  ar(CEREAL_NVP(gradientStepIdx));
  ar(CEREAL_NVP(cell));
  ar(CEREAL_NVP(stateActivation));
  ar(CEREAL_NVP(gateActivation));
  ar(CEREAL_NVP(gate));
  ar(CEREAL_NVP(cellActivation));
  ar(CEREAL_NVP(forgetGateError));
  ar(CEREAL_NVP(prevError));
  ar(CEREAL_NVP(outParameter));
}

} // namespace ann
} // namespace mlpack

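// A minimal usage sketch (illustrative only; it assumes the default arma::mat
// template arguments and the public Parameters()/Reset() interface declared
// in fast_lstm.hpp; normally an enclosing RNN handles these calls):
//
//   FastLSTM<> lstm(10 /* inSize */, 4 /* outSize */, 5 /* rho */);
//   lstm.Parameters().randu();  // The optimizer usually initializes this.
//   lstm.Reset();               // Build the weight aliases used by Forward().
//
//   arma::mat input(10, 32, arma::fill::randu);  // One column per batch item.
//   arma::mat output;
//   lstm.Forward(input, output);                 // output is 4 x 32.
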
#endif