12 #ifndef MLPACK_METHODS_ANN_LAYER_RECURRENT_IMPL_HPP 13 #define MLPACK_METHODS_ANN_LAYER_RECURRENT_IMPL_HPP 18 #include "../visitor/add_visitor.hpp" 19 #include "../visitor/backward_visitor.hpp" 20 #include "../visitor/gradient_visitor.hpp" 21 #include "../visitor/gradient_zero_visitor.hpp" 22 #include "../visitor/input_shape_visitor.hpp" 27 template<
typename InputDataType,
typename OutputDataType,
28 typename... CustomLayers>
// Constructor fragment that builds the layer from four caller-supplied modules.
// NOTE(review): the leading numbers appear to be the original file's line
// numbers; the gaps (42, 47-48, 53, 58-65, 67-83) mean the signature line and
// parts of the body are missing from this listing.
40 template <
typename InputDataType,
typename OutputDataType,
41 typename... CustomLayers>
// Second template parameter list: one type per module argument below.
43 typename StartModuleType,
44 typename InputModuleType,
45 typename FeedbackModuleType,
46 typename TransferModuleType
// The four modules are taken by const reference and are only read here.
49 const StartModuleType& start,
50 const InputModuleType& input,
51 const FeedbackModuleType& feedback,
52 const TransferModuleType& transfer,
// Initializer list: each argument is copy-constructed into a fresh heap object
// with `new`, so the layer holds its own copies rather than the caller's.
54 startModule(new StartModuleType(start)),
55 inputModule(new InputModuleType(input)),
56 feedbackModule(new FeedbackModuleType(feedback)),
57 transferModule(new TransferModuleType(transfer)),
// mergeModule is a newly allocated AddMerge<> built with three `false`
// arguments; what each flag means is not visible in this listing.
66 mergeModule =
new AddMerge<>(
false,
false,
false);
// The composite modules are appended to `network` in this fixed order.
// initialModule and recurrentModule are created in lines not shown here.
84 network.push_back(initialModule);
85 network.push_back(mergeModule);
86 network.push_back(feedbackModule);
87 network.push_back(recurrentModule);
// Copy-constructor fragment. NOTE(review): the signature (original lines
// 92-94) and several body lines are missing from this listing.
90 template<
typename InputDataType,
typename OutputDataType,
91 typename... CustomLayers>
// Here `network` names the object being copied from (its members are read as
// network.xxx); the step counters and flags are copied member by member.
95 forwardStep(network.forwardStep),
96 backwardStep(network.backwardStep),
97 gradientStep(network.gradientStep),
98 deterministic(network.deterministic),
99 ownsLayer(network.ownsLayer)
// Each of the source's four modules is passed through copyVisitor and the
// result stored in the matching member (presumably a deep copy - confirm
// against the visitor's definition).
101 startModule = boost::apply_visitor(copyVisitor, network.startModule);
102 inputModule = boost::apply_visitor(copyVisitor, network.inputModule);
103 feedbackModule = boost::apply_visitor(copyVisitor, network.feedbackModule);
104 transferModule = boost::apply_visitor(copyVisitor, network.transferModule);
// mergeModule is not copied from the source; a new AddMerge<> is allocated
// with the same three `false` arguments used by the other constructor.
106 mergeModule =
new AddMerge<>(
false,
false,
false);
// `this->` is needed because the source object above is also called `network`;
// these calls fill this object's own module list, in the same order as the
// other constructor.
123 this->network.push_back(initialModule);
124 this->network.push_back(mergeModule);
125 this->network.push_back(feedbackModule);
126 this->network.push_back(recurrentModule);
// InputShape() fragment: asks the sub-modules, one after another, for the
// input shape they expect and returns the first non-zero answer.
// NOTE(review): the signature, the module argument of each apply_visitor call
// and the final fallback return are missing from this listing; which module
// each query targets is inferred from the local variable names.
129 template<
typename InputDataType,
typename OutputDataType,
130 typename... CustomLayers>
// 1) Start module (by variable name).
134 const size_t inputShapeStartModule = boost::apply_visitor(
InShapeVisitor(),
138 if (inputShapeStartModule != 0)
140 return inputShapeStartModule;
// 2) Input module (by variable name).
146 const size_t inputShapeInputModule = boost::apply_visitor(
InShapeVisitor(),
148 if (inputShapeInputModule != 0)
150 return inputShapeInputModule;
// 3) Feedback module (by variable name).
155 const size_t inputShapeFeedbackModule = boost::apply_visitor(
157 if (inputShapeFeedbackModule != 0)
159 return inputShapeFeedbackModule;
// 4) Transfer module (by variable name). What is returned when every query
// yields 0 is not visible here.
164 const size_t inputShapeTransferModule = boost::apply_visitor(
166 if (inputShapeTransferModule != 0)
168 return inputShapeTransferModule;
// Forward() fragment: one forward step of the recurrent layer.
// `input` is read-only; `output` is written by this function.
// NOTE(review): the signature line, braces and parts of several call
// expressions are missing from this listing, so nesting below is partly
// inferred.
179 template<
typename InputDataType,
typename OutputDataType,
180 typename... CustomLayers>
181 template<
typename eT>
183 const arma::Mat<eT>& input, arma::Mat<eT>& output)
// First step (forwardStep == 0): run the input through initialModule only.
185 if (forwardStep == 0)
187 boost::apply_visitor(
ForwardVisitor(input, output), initialModule);
// Remaining visible pieces (presumably the else branch): they reference the
// output parameters of inputModule, transferModule and feedbackModule, and
// then run recurrentModule's forward pass on (input, output).
192 boost::apply_visitor(outputParameterVisitor, inputModule)),
196 outputParameterVisitor, transferModule),
197 boost::apply_visitor(outputParameterVisitor, feedbackModule)),
200 boost::apply_visitor(
ForwardVisitor(input, output), recurrentModule);
// The value handed back to the caller is transferModule's output parameter.
203 output = boost::apply_visitor(outputParameterVisitor, transferModule);
// A copy of this step's output is appended to feedbackOutputParameter.
208 feedbackOutputParameter.push_back(output);
// Check for forwardStep reaching rho; the statements it guards are only
// partly visible.
212 if (forwardStep == rho)
// A non-empty recurrentError is zeroed (presumably inside the check above).
217 if (!recurrentError.is_empty())
219 recurrentError.zeros();
// Backward() fragment: one backward step of the recurrent layer.
// The first parameter is unnamed and therefore unused; `gy` is the incoming
// error and `g` is written by this function.
// NOTE(review): the signature line, braces and parts of the visitor calls are
// missing from this listing.
224 template<
typename InputDataType,
typename OutputDataType,
225 typename... CustomLayers>
226 template<
typename eT>
228 const arma::Mat<eT>& ,
const arma::Mat<eT>& gy, arma::Mat<eT>& g)
// If an error is already stored in recurrentError, the incoming gy is added
// to it (the handling of the empty case is not visible here).
230 if (!recurrentError.is_empty())
232 recurrentError += gy;
// While backwardStep < rho - 1: the visible pieces reference recurrentModule,
// inputModule and feedbackModule, passing recurrentModule's delta on to the
// input side (written into g) and to feedbackModule's delta.
239 if (backwardStep < (rho - 1))
242 outputParameterVisitor, recurrentModule), recurrentError,
243 boost::apply_visitor(deltaVisitor, recurrentModule)),
247 outputParameterVisitor, inputModule),
248 boost::apply_visitor(deltaVisitor, recurrentModule), g),
252 outputParameterVisitor, feedbackModule),
253 boost::apply_visitor(deltaVisitor, recurrentModule),
254 boost::apply_visitor(deltaVisitor, feedbackModule)), feedbackModule);
// Presumably the other branch: initialModule is used with recurrentError as
// the error and g as the result.
259 outputParameterVisitor, initialModule), recurrentError, g),
// recurrentError is replaced by feedbackModule's delta.
263 recurrentError = boost::apply_visitor(deltaVisitor, feedbackModule);
// Gradient-step fragment (function name not visible; presumably Gradient(),
// given that it is driven by gradientStep).
// NOTE(review): the signature line, the parameter after `error`, braces and
// parts of the visitor calls are missing from this listing.
267 template<
typename InputDataType,
typename OutputDataType,
268 typename... CustomLayers>
269 template<
typename eT>
271 const arma::Mat<eT>& input,
272 const arma::Mat<eT>& error,
// While gradientStep < rho - 1: the visible pieces use mergeModule's delta
// for both inputModule and feedbackModule.
275 if (gradientStep < (rho - 1))
280 boost::apply_visitor(deltaVisitor, mergeModule)), inputModule);
// feedbackModule is given the stored output at index
// size() - 2 - gradientStep of feedbackOutputParameter.
283 feedbackOutputParameter[feedbackOutputParameter.size() - 2 -
284 gradientStep], boost::apply_visitor(deltaVisitor,
285 mergeModule)), feedbackModule);
// Presumably the other branch: initialModule is used with startModule's delta.
294 boost::apply_visitor(deltaVisitor, startModule)), initialModule);
// Check for gradientStep reaching rho; the stored per-step outputs are
// cleared (presumably inside this check).
298 if (gradientStep == rho)
301 feedbackOutputParameter.clear();
// serialize() fragment: cereal save/load of the layer.
// The version argument is unnamed and therefore unused.
// NOTE(review): the signature line, braces and most archived members are
// missing from this listing.
305 template<
typename InputDataType,
typename OutputDataType,
306 typename... CustomLayers>
307 template<
typename Archive>
309 Archive& ar,
const uint32_t )
// Load-only step before the members are read; its body is not visible here.
312 if (cereal::is_loading<Archive>())
// ownsLayer is archived as a name-value pair.
326 ar(CEREAL_NVP(ownsLayer));
// Load-only step after the members are read: mergeModule is allocated anew
// and `network` is refilled in the same order the constructors use
// (presumably all inside this check).
329 if (cereal::is_loading<Archive>())
332 mergeModule =
new AddMerge<>(
false,
false,
false);
351 network.push_back(initialModule);
352 network.push_back(mergeModule);
353 network.push_back(feedbackModule);
354 network.push_back(recurrentModule);
DeleteVisitor executes the destructor of the instantiated object.
Definition: delete_visitor.hpp:27
OutputDataType const & Gradient() const
Get the gradient.
Definition: recurrent.hpp:135
BackwardVisitor executes the Backward() function given the input, error and delta parameter...
Definition: backward_visitor.hpp:28
size_t InputShape() const
Get the shape of the input.
Definition: recurrent_impl.hpp:132
Implementation of the AddMerge module class.
Definition: add_merge.hpp:42
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backw...
Definition: recurrent_impl.hpp:227
Recurrent()
Default constructor—this will create a Recurrent object that can't be used, so be careful! Make sure...
Definition: recurrent_impl.hpp:29
Definition: gradient_zero_visitor.hpp:27
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: recurrent_impl.hpp:308
ForwardVisitor executes the Forward() function given the input and output parameter.
Definition: forward_visitor.hpp:28
InShapeVisitor returns the input shape a Layer expects.
Definition: input_shape_visitor.hpp:29
#define CEREAL_VARIANT_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_variant_wrapper.hpp:155
AddVisitor exposes the Add() method of the given module.
Definition: add_visitor.hpp:28
GradientVisitor executes the Gradient() method of the given module using the input and delta parame...
Definition: gradient_visitor.hpp:28
Implementation of the RecurrentLayer class.
Definition: layer_types.hpp:157
Implementation of the Sequential class.
Definition: layer_types.hpp:145
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activ...
Definition: recurrent_impl.hpp:182