13 #ifndef MLPACK_METHODS_ANN_LAYER_GRU_IMPL_HPP 14 #define MLPACK_METHODS_ANN_LAYER_GRU_IMPL_HPP 19 #include "../visitor/forward_visitor.hpp" 20 #include "../visitor/backward_visitor.hpp" 21 #include "../visitor/gradient_visitor.hpp" 26 template<
typename InputDataType,
typename OutputDataType>
32 template <
typename InputDataType,
typename OutputDataType>
47 input2GateModule =
new Linear<>(inSize, 3 * outSize);
55 network.push_back(input2GateModule);
56 network.push_back(output2GateModule);
57 network.push_back(outputHidden2GateModule);
63 network.push_back(inputGateModule);
64 network.push_back(hiddenStateModule);
65 network.push_back(forgetGateModule);
67 prevError = arma::zeros<arma::mat>(3 * outSize, batchSize);
69 allZeros = arma::zeros<arma::mat>(outSize, batchSize);
71 outParameter.emplace_back(allZeros.memptr(),
72 allZeros.n_rows, allZeros.n_cols,
false,
true);
74 prevOutput = outParameter.begin();
75 backIterator = outParameter.end();
76 gradIterator = outParameter.end();
79 template<
typename InputDataType,
typename OutputDataType>
82 const arma::Mat<eT>& input, arma::Mat<eT>& output)
84 if (input.n_cols != batchSize)
86 batchSize = input.n_cols;
87 prevError.resize(3 * outSize, batchSize);
88 allZeros.zeros(outSize, batchSize);
90 if (outParameter.size() > 1)
92 Log::Fatal <<
"GRU<>::Forward(): batch size cannot change during a " 93 <<
"forward pass!" << std::endl;
97 outParameter.emplace_back(allZeros.memptr(),
98 allZeros.n_rows, allZeros.n_cols,
false,
true);
100 prevOutput = outParameter.begin();
101 backIterator = outParameter.end();
102 gradIterator = outParameter.end();
107 boost::apply_visitor(outputParameterVisitor, input2GateModule)),
112 boost::apply_visitor(outputParameterVisitor, output2GateModule)),
116 output = (boost::apply_visitor(outputParameterVisitor,
117 input2GateModule).submat(0, 0, 2 * outSize - 1, batchSize - 1) +
118 boost::apply_visitor(outputParameterVisitor, output2GateModule));
122 0, 0, 1 * outSize - 1, batchSize - 1), boost::apply_visitor(
123 outputParameterVisitor, inputGateModule)), inputGateModule);
127 1 * outSize, 0, 2 * outSize - 1, batchSize - 1),
128 boost::apply_visitor(outputParameterVisitor, forgetGateModule)),
131 arma::mat modInput = (boost::apply_visitor(outputParameterVisitor,
132 forgetGateModule) % *prevOutput);
136 boost::apply_visitor(outputParameterVisitor, outputHidden2GateModule)),
137 outputHidden2GateModule);
140 arma::mat outputH = boost::apply_visitor(outputParameterVisitor,
141 input2GateModule).submat(2 * outSize, 0, 3 * outSize - 1, batchSize - 1) +
142 boost::apply_visitor(outputParameterVisitor, outputHidden2GateModule);
146 boost::apply_visitor(outputParameterVisitor, hiddenStateModule)),
152 output = (boost::apply_visitor(outputParameterVisitor, inputGateModule)
153 % (*prevOutput - boost::apply_visitor(outputParameterVisitor,
154 hiddenStateModule))) + boost::apply_visitor(outputParameterVisitor,
158 if (forwardStep == rho)
163 outParameter.emplace_back(allZeros.memptr(),
164 allZeros.n_rows, allZeros.n_cols,
false,
true);
165 prevOutput = --outParameter.end();
169 *prevOutput = arma::mat(allZeros.memptr(),
170 allZeros.n_rows, allZeros.n_cols,
false,
true);
173 else if (!deterministic)
175 outParameter.push_back(output);
176 prevOutput = --outParameter.end();
180 if (forwardStep == 1)
182 outParameter.clear();
183 outParameter.push_back(output);
185 prevOutput = outParameter.begin();
189 *prevOutput = output;
194 template<
typename InputDataType,
typename OutputDataType>
195 template<
typename eT>
197 const arma::Mat<eT>& input,
const arma::Mat<eT>& gy, arma::Mat<eT>& g)
199 if (input.n_cols != batchSize)
201 batchSize = input.n_cols;
202 prevError.resize(3 * outSize, batchSize);
203 allZeros.zeros(outSize, batchSize);
205 if (outParameter.size() > 1)
207 Log::Fatal <<
"GRU<>::Forward(): batch size cannot change during a " 208 <<
"forward pass!" << std::endl;
211 outParameter.clear();
212 outParameter.emplace_back(allZeros.memptr(),
213 allZeros.n_rows, allZeros.n_cols,
false,
true);
215 prevOutput = outParameter.begin();
216 backIterator = outParameter.end();
217 gradIterator = outParameter.end();
220 arma::Mat<eT> gyLocal;
221 if ((outParameter.size() - backwardStep - 1) % rho != 0 && backwardStep != 0)
223 gyLocal = gy + boost::apply_visitor(deltaVisitor, output2GateModule);
227 gyLocal = arma::Mat<eT>(((arma::Mat<eT>&) gy).memptr(), gy.n_rows,
228 gy.n_cols,
false,
false);
231 if (backIterator == outParameter.end())
233 backIterator = --(--outParameter.end());
237 arma::mat dZt = gyLocal % (*backIterator -
238 boost::apply_visitor(outputParameterVisitor,
242 arma::mat dOt = gyLocal % (arma::ones<arma::mat>(outSize, batchSize) -
243 boost::apply_visitor(outputParameterVisitor, inputGateModule));
247 outputParameterVisitor, inputGateModule), dZt,
248 boost::apply_visitor(deltaVisitor, inputGateModule)),
253 outputParameterVisitor, hiddenStateModule), dOt,
254 boost::apply_visitor(deltaVisitor, hiddenStateModule)),
259 outputParameterVisitor, outputHidden2GateModule),
260 boost::apply_visitor(deltaVisitor, hiddenStateModule),
261 boost::apply_visitor(deltaVisitor, outputHidden2GateModule)),
262 outputHidden2GateModule);
265 arma::mat dRt = boost::apply_visitor(deltaVisitor, outputHidden2GateModule) %
270 outputParameterVisitor, forgetGateModule), dRt,
271 boost::apply_visitor(deltaVisitor, forgetGateModule)),
275 prevError.submat(0, 0, 1 * outSize - 1, batchSize - 1) = boost::apply_visitor(
276 deltaVisitor, inputGateModule);
279 prevError.submat(1 * outSize, 0, 2 * outSize - 1, batchSize - 1) =
280 boost::apply_visitor(deltaVisitor, forgetGateModule);
283 prevError.submat(2 * outSize, 0, 3 * outSize - 1, batchSize - 1) =
284 boost::apply_visitor(deltaVisitor, hiddenStateModule);
287 arma::mat prevErrorSubview = prevError.submat(0, 0, 2 * outSize - 1,
290 outputParameterVisitor, input2GateModule),
292 boost::apply_visitor(deltaVisitor, output2GateModule)),
296 boost::apply_visitor(deltaVisitor, output2GateModule) +=
297 boost::apply_visitor(deltaVisitor, outputHidden2GateModule) %
298 boost::apply_visitor(outputParameterVisitor, forgetGateModule);
301 boost::apply_visitor(deltaVisitor, output2GateModule) += gyLocal %
302 boost::apply_visitor(outputParameterVisitor, inputGateModule);
306 outputParameterVisitor, input2GateModule), prevError,
307 boost::apply_visitor(deltaVisitor, input2GateModule)),
313 g = boost::apply_visitor(deltaVisitor, input2GateModule);
316 template<
typename InputDataType,
typename OutputDataType>
317 template<
typename eT>
319 const arma::Mat<eT>& input,
320 const arma::Mat<eT>& ,
323 if (input.n_cols != batchSize)
325 batchSize = input.n_cols;
326 prevError.resize(3 * outSize, batchSize);
327 allZeros.zeros(outSize, batchSize);
329 if (outParameter.size() > 1)
331 Log::Fatal <<
"GRU<>::Forward(): batch size cannot change during a " 332 <<
"forward pass!" << std::endl;
335 outParameter.clear();
336 outParameter.emplace_back(allZeros.memptr(),
337 allZeros.n_rows, allZeros.n_cols,
false,
true);
339 prevOutput = outParameter.begin();
340 backIterator = outParameter.end();
341 gradIterator = outParameter.end();
344 if (gradIterator == outParameter.end())
346 gradIterator = --(--outParameter.end());
349 boost::apply_visitor(
GradientVisitor(input, prevError), input2GateModule);
353 prevError.submat(0, 0, 2 * outSize - 1, batchSize - 1)),
357 *gradIterator % boost::apply_visitor(outputParameterVisitor,
359 prevError.submat(2 * outSize, 0, 3 * outSize - 1, batchSize - 1)),
360 outputHidden2GateModule);
365 template<
typename InputDataType,
typename OutputDataType>
368 outParameter.clear();
369 outParameter.emplace_back(allZeros.memptr(),
370 allZeros.n_rows, allZeros.n_cols,
false,
true);
372 prevOutput = outParameter.begin();
373 backIterator = outParameter.end();
374 gradIterator = outParameter.end();
380 template<
typename InputDataType,
typename OutputDataType>
381 template<
typename Archive>
383 Archive& ar,
const uint32_t )
386 if (cereal::is_loading<Archive>())
388 boost::apply_visitor(deleteVisitor, input2GateModule);
389 boost::apply_visitor(deleteVisitor, output2GateModule);
390 boost::apply_visitor(deleteVisitor, outputHidden2GateModule);
391 boost::apply_visitor(deleteVisitor, inputGateModule);
392 boost::apply_visitor(deleteVisitor, forgetGateModule);
393 boost::apply_visitor(deleteVisitor, hiddenStateModule);
396 ar(CEREAL_NVP(inSize));
397 ar(CEREAL_NVP(outSize));
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: gru_impl.hpp:81
BackwardVisitor executes the Backward() function given the input, error and delta parameters.
Definition: backward_visitor.hpp:28
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
GRU()
Create the GRU object.
Definition: gru_impl.hpp:27
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: gru_impl.hpp:382
Implementation of the Linear layer class.
Definition: layer_types.hpp:93
Implementation of the base layer.
Definition: base_layer.hpp:71
ForwardVisitor executes the Forward() function given the input and output parameter.
Definition: forward_visitor.hpp:28
Implementation of the LinearNoBias class.
Definition: layer_types.hpp:103
#define CEREAL_VARIANT_POINTER(T)
Cereal does not support the serialization of raw pointer.
Definition: pointer_variant_wrapper.hpp:155
An implementation of a gru network layer.
Definition: gru.hpp:58
GradientVisitor executes the Gradient() method of the given module using the input and delta parameters.
Definition: gradient_visitor.hpp:28
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Definition: gru_impl.hpp:196
OutputDataType const & Gradient() const
Get the gradient.
Definition: gru.hpp:145