13 #ifndef MLPACK_METHODS_ANN_BRNN_IMPL_HPP 14 #define MLPACK_METHODS_ANN_BRNN_IMPL_HPP 32 template<
typename OutputLayerType,
typename MergeLayerType,
33 typename MergeOutputType,
typename InitializationRuleType,
34 typename... CustomLayers>
35 BRNN<OutputLayerType, MergeLayerType, MergeOutputType,
36 InitializationRuleType, CustomLayers...>
::BRNN(
39 OutputLayerType outputLayer,
40 MergeLayerType* mergeLayer,
41 MergeOutputType* mergeOutput,
42 InitializationRuleType initializeRule) :
44 outputLayer(
std::move(outputLayer)),
45 mergeLayer(mergeLayer),
46 mergeOutput(mergeOutput),
47 initializeRule(
std::move(initializeRule)),
55 forwardRNN(rho, single, outputLayer, initializeRule),
56 backwardRNN(rho, single, outputLayer, initializeRule)
61 template<
typename OutputLayerType,
typename MergeLayerType,
62 typename MergeOutputType,
typename InitializationRuleType,
63 typename... CustomLayers>
64 BRNN<OutputLayerType, MergeLayerType, MergeOutputType,
65 InitializationRuleType, CustomLayers...>::~BRNN()
70 forwardRNN.network.pop_back();
71 backwardRNN.network.pop_back();
78 template<
typename OutputLayerType,
typename MergeLayerType,
79 typename MergeOutputType,
typename InitializationRuleType,
80 typename... CustomLayers>
81 template<
typename OptimizerType>
82 typename std::enable_if<
83 HasMaxIterations<OptimizerType, size_t&(OptimizerType::*)()>
85 BRNN<OutputLayerType, MergeLayerType, MergeOutputType,
87 (OptimizerType& optimizer,
size_t samples)
const 89 if (optimizer.MaxIterations() < samples &&
90 optimizer.MaxIterations() != 0)
92 Log::Warn <<
"The optimizer's maximum number of iterations " 93 <<
"is less than the size of the dataset; the " 94 <<
"optimizer will not pass over the entire " 95 <<
"dataset. To fix this, modify the maximum " 96 <<
"number of iterations to be at least equal " 97 <<
"to the number of points of your dataset " 98 <<
"(" << samples <<
")." << std::endl;
102 template<
typename OutputLayerType,
typename MergeLayerType,
103 typename MergeOutputType,
typename InitializationRuleType,
104 typename... CustomLayers>
105 template<
typename OptimizerType>
106 typename std::enable_if<
107 !HasMaxIterations<OptimizerType, size_t&(OptimizerType::*)()>
109 BRNN<OutputLayerType, MergeLayerType, MergeOutputType,
111 (OptimizerType& ,
size_t )
const 116 template<
typename OutputLayerType,
typename MergeLayerType,
117 typename MergeOutputType,
typename InitializationRuleType,
118 typename... CustomLayers>
119 template<
typename OptimizerType>
120 double BRNN<OutputLayerType, MergeLayerType, MergeOutputType,
121 InitializationRuleType, CustomLayers...>
::Train(
122 arma::cube predictors,
123 arma::cube responses,
124 OptimizerType& optimizer)
126 numFunctions = responses.n_cols;
128 this->predictors = std::move(predictors);
129 this->responses = std::move(responses);
131 this->deterministic =
true;
132 ResetDeterministic();
139 WarnMessageMaxIterations<OptimizerType>(optimizer, this->predictors.n_cols);
143 const double out = optimizer.Optimize(*
this, parameter);
146 Log::Info <<
"BRNN::BRNN(): final objective of trained model is " << out
151 template<
typename OutputLayerType,
typename MergeLayerType,
152 typename MergeOutputType,
typename InitializationRuleType,
153 typename... CustomLayers>
154 template<
typename OptimizerType>
155 double BRNN<OutputLayerType, MergeLayerType, MergeOutputType,
156 InitializationRuleType, CustomLayers...>
::Train(
157 arma::cube predictors,
158 arma::cube responses)
160 numFunctions = responses.n_cols;
162 this->predictors = std::move(predictors);
163 this->responses = std::move(responses);
165 this->deterministic =
true;
166 ResetDeterministic();
173 OptimizerType optimizer;
175 WarnMessageMaxIterations<OptimizerType>(optimizer, this->predictors.n_cols);
178 const double out = optimizer.Optimize(*
this, parameter);
180 Log::Info <<
"BRNN::BRNN(): final objective of trained model is " << out
185 template<
typename OutputLayerType,
typename MergeLayerType,
186 typename MergeOutputType,
typename InitializationRuleType,
187 typename... CustomLayers>
188 void BRNN<OutputLayerType, MergeLayerType, MergeOutputType,
190 arma::cube predictors, arma::cube& results,
const size_t batchSize)
192 forwardRNN.rho = backwardRNN.rho = rho;
194 forwardRNN.ResetCells();
195 backwardRNN.ResetCells();
199 deterministic =
true;
200 ResetDeterministic();
202 if (parameter.is_empty())
207 if (std::is_same<MergeLayerType,
Concat<>>::value)
209 results = arma::zeros<arma::cube>(outputSize * 2, predictors.n_cols, rho);
213 results = arma::zeros<arma::cube>(outputSize, predictors.n_cols, rho);
216 std::vector<arma::mat> results1, results2;
220 for (
size_t begin = 0; begin < predictors.n_cols; begin += batchSize)
222 const size_t effectiveBatchSize = std::min(batchSize,
223 size_t(predictors.n_cols - begin));
224 for (
size_t seqNum = 0; seqNum < rho; ++seqNum)
226 forwardRNN.Forward(arma::mat(
227 predictors.slice(seqNum).colptr(begin),
228 predictors.n_rows, effectiveBatchSize,
false,
true));
229 backwardRNN.Forward(std::move(arma::mat(
230 predictors.slice(rho - seqNum - 1).colptr(begin),
231 predictors.n_rows, effectiveBatchSize,
false,
true)));
234 forwardRNN.network.back());
236 backwardRNN.network.back());
238 reverse(results1.begin(), results1.end());
241 for (
size_t seqNum = 0; seqNum < rho; ++seqNum)
244 forwardRNN.network.back());
246 backwardRNN.network.back());
249 boost::apply_visitor(outputParameterVisitor, mergeLayer)),
252 boost::apply_visitor(outputParameterVisitor, mergeLayer),
253 boost::apply_visitor(outputParameterVisitor, mergeOutput)),
255 results.slice(seqNum).submat(0, begin, results.n_rows - 1, begin +
256 effectiveBatchSize - 1) =
257 boost::apply_visitor(outputParameterVisitor, mergeOutput);
262 template<
typename OutputLayerType,
typename MergeLayerType,
263 typename MergeOutputType,
typename InitializationRuleType,
264 typename... CustomLayers>
265 double BRNN<OutputLayerType, MergeLayerType, MergeOutputType,
269 const size_t batchSize,
270 const bool deterministic)
272 forwardRNN.rho = backwardRNN.rho = rho;
273 if (parameter.is_empty())
278 if (deterministic != this->deterministic)
280 this->deterministic = deterministic;
281 ResetDeterministic();
286 inputSize = predictors.n_rows;
287 targetSize = responses.n_rows;
289 else if (targetSize == 0)
291 targetSize = responses.n_rows;
294 forwardRNN.ResetCells();
295 backwardRNN.ResetCells();
297 double performance = 0;
298 size_t responseSeq = 0;
300 std::vector<arma::mat> results1, results2;
301 for (
size_t seqNum = 0; seqNum < rho; ++seqNum)
303 forwardRNN.Forward(arma::mat(
304 predictors.slice(seqNum).colptr(begin),
305 predictors.n_rows, batchSize,
false,
true));
306 backwardRNN.Forward(arma::mat(
307 predictors.slice(rho - seqNum - 1).colptr(begin),
308 predictors.n_rows, batchSize,
false,
true));
311 forwardRNN.network.back());
313 backwardRNN.network.back());
317 outputSize = boost::apply_visitor(outputParameterVisitor,
318 forwardRNN.network.back()).n_elem / batchSize;
319 forwardRNN.outputSize = backwardRNN.outputSize = outputSize;
321 reverse(results1.begin(), results1.end());
325 for (
size_t seqNum = 0; seqNum < rho; ++seqNum)
329 responseSeq = seqNum;
332 forwardRNN.network.back());
334 backwardRNN.network.back());
337 boost::apply_visitor(outputParameterVisitor, mergeLayer)),
340 boost::apply_visitor(outputParameterVisitor, mergeLayer),
341 boost::apply_visitor(outputParameterVisitor, mergeOutput)),
343 performance += outputLayer.Forward(
344 boost::apply_visitor(outputParameterVisitor, mergeOutput),
345 arma::mat(responses.slice(responseSeq).colptr(begin),
346 responses.n_rows, batchSize,
false,
true));
351 template<
typename OutputLayerType,
typename MergeLayerType,
352 typename MergeOutputType,
typename InitializationRuleType,
353 typename... CustomLayers>
354 double BRNN<OutputLayerType, MergeLayerType, MergeOutputType,
356 const arma::mat& parameters,
358 const size_t batchSize)
360 return Evaluate(parameters, begin, batchSize,
true);
363 template<
typename OutputLayerType,
typename MergeLayerType,
364 typename MergeOutputType,
typename InitializationRuleType,
365 typename... CustomLayers>
366 template<
typename GradType>
367 double BRNN<OutputLayerType, MergeLayerType, MergeOutputType,
368 InitializationRuleType, CustomLayers...>
:: 372 const size_t batchSize)
374 forwardRNN.rho = backwardRNN.rho = rho;
375 if (gradient.is_empty())
377 if (parameter.is_empty())
381 gradient = arma::zeros<arma::mat>(parameter.n_rows, parameter.n_cols);
388 if (backwardGradient.is_empty())
390 backwardGradient = arma::zeros<arma::mat>(
393 forwardGradient = arma::zeros<arma::mat>(
397 if (this->deterministic)
399 this->deterministic =
false;
400 ResetDeterministic();
405 inputSize = predictors.n_rows;
406 targetSize = responses.n_rows;
408 else if (targetSize == 0)
410 targetSize = responses.n_rows;
413 forwardRNN.ResetCells();
414 backwardRNN.ResetCells();
415 size_t networkSize = backwardRNN.network.size();
418 std::vector<arma::mat> results1, results2;
419 for (
size_t seqNum = 0; seqNum < rho; ++seqNum)
421 forwardRNN.Forward(arma::mat(
422 predictors.slice(seqNum).colptr(begin),
423 predictors.n_rows, batchSize,
false,
true));
424 backwardRNN.Forward(arma::mat(
425 predictors.slice(rho - seqNum - 1).colptr(begin),
426 predictors.n_rows, batchSize,
false,
true));
428 for (
size_t l = 0; l < networkSize; ++l)
431 forwardRNNOutputParameter), forwardRNN.network[l]);
433 backwardRNNOutputParameter), backwardRNN.network[l]);
436 forwardRNN.network.back());
438 backwardRNN.network.back());
442 outputSize = boost::apply_visitor(outputParameterVisitor,
443 forwardRNN.network.back()).n_elem / batchSize;
444 forwardRNN.outputSize = backwardRNN.outputSize = outputSize;
448 if (std::is_same<MergeLayerType,
Concat<>>::value)
450 results = arma::zeros<arma::cube>(outputSize * 2, batchSize, rho);
454 results = arma::zeros<arma::cube>(outputSize, batchSize, rho);
457 double performance = 0;
458 size_t responseSeq = 0;
461 reverse(results1.begin(), results1.end());
463 for (
size_t seqNum = 0; seqNum < rho; ++seqNum)
467 responseSeq = seqNum;
470 results1), forwardRNN.network.back());
472 results2), backwardRNN.network.back());
474 boost::apply_visitor(outputParameterVisitor, mergeLayer)),
477 boost::apply_visitor(outputParameterVisitor, mergeLayer),
478 results.slice(seqNum)), mergeOutput);
479 performance += outputLayer.Forward(results.slice(seqNum),
480 arma::mat(responses.slice(responseSeq).colptr(begin),
481 responses.n_rows, batchSize,
false,
true));
486 std::vector<arma::mat> allDelta;
488 for (
size_t seqNum = 0; seqNum < rho; ++seqNum)
490 if (single && seqNum > 0)
494 else if (single && seqNum == 0)
496 outputLayer.Backward(results.slice(seqNum),
497 arma::mat(responses.slice(0).colptr(begin),
498 responses.n_rows, batchSize,
false,
true), error);
502 outputLayer.Backward(results.slice(seqNum),
503 arma::mat(responses.slice(seqNum).colptr(begin),
504 responses.n_rows, batchSize,
false,
true), error);
507 boost::apply_visitor(
BackwardVisitor(results.slice(seqNum), error, delta),
509 allDelta.push_back(arma::mat(delta));
513 totalGradient = arma::mat(gradient.memptr(),
514 parameter.n_elem / 2, 1,
false,
false);
516 forwardGradient.zeros();
517 forwardRNN.ResetGradients(forwardGradient);
518 backwardGradient.zeros();
519 backwardRNN.ResetGradients(backwardGradient);
521 for (
size_t seqNum = 0; seqNum < rho; ++seqNum)
523 forwardGradient.zeros();
524 for (
size_t l = 0; l < networkSize; ++l)
527 forwardRNNOutputParameter),
528 forwardRNN.network[networkSize - 1 - l]);
531 outputParameterVisitor, forwardRNN.network.back()),
532 allDelta[rho - seqNum - 1], delta, 0),
535 for (
size_t i = 2; i < networkSize; ++i)
538 boost::apply_visitor(outputParameterVisitor,
539 forwardRNN.network[networkSize - i]),
540 boost::apply_visitor(deltaVisitor,
541 forwardRNN.network[networkSize - i + 1]),
542 boost::apply_visitor(deltaVisitor,
543 forwardRNN.network[networkSize - i])),
544 forwardRNN.network[networkSize - i]);
547 arma::mat(predictors.slice(rho - seqNum - 1).colptr(begin),
548 predictors.n_rows, batchSize,
false,
true));
550 boost::apply_visitor(outputParameterVisitor,
551 forwardRNN.network[networkSize - 2]),
552 allDelta[rho - seqNum - 1], 0), mergeLayer);
553 totalGradient += forwardGradient;
557 totalGradient = arma::mat(gradient.memptr() + parameter.n_elem/2,
558 parameter.n_elem/2, 1,
false,
false);
560 for (
size_t seqNum = 0; seqNum < rho; ++seqNum)
562 backwardGradient.zeros();
563 for (
size_t l = 0; l < networkSize; ++l)
566 backwardRNNOutputParameter),
567 backwardRNN.network[networkSize - 1 - l]);
570 boost::apply_visitor(outputParameterVisitor,
571 backwardRNN.network.back()),
572 allDelta[seqNum], delta, 1), mergeLayer);
573 for (
size_t i = 2; i < networkSize; ++i)
576 boost::apply_visitor(outputParameterVisitor,
577 backwardRNN.network[networkSize - i]), boost::apply_visitor(
578 deltaVisitor, backwardRNN.network[networkSize - i + 1]),
579 boost::apply_visitor(deltaVisitor,
580 backwardRNN.network[networkSize - i])),
581 backwardRNN.network[networkSize - i]);
585 arma::mat(predictors.slice(seqNum).colptr(begin),
586 predictors.n_rows, batchSize,
false,
true));
588 std::move(boost::apply_visitor(outputParameterVisitor,
589 backwardRNN.network[networkSize - 2])),
590 allDelta[seqNum], 1), mergeLayer);
591 totalGradient += backwardGradient;
596 template<
typename OutputLayerType,
typename MergeLayerType,
597 typename MergeOutputType,
typename InitializationRuleType,
598 typename... CustomLayers>
599 void BRNN<OutputLayerType, MergeLayerType, MergeOutputType,
601 const arma::mat& parameters,
604 const size_t batchSize)
609 template<
typename OutputLayerType,
typename MergeLayerType,
610 typename MergeOutputType,
typename InitializationRuleType,
611 typename... CustomLayers>
612 void BRNN<OutputLayerType, MergeLayerType, MergeOutputType,
615 arma::cube newPredictors, newResponses;
618 predictors = std::move(newPredictors);
619 responses = std::move(newResponses);
622 template<
typename OutputLayerType,
typename MergeLayerType,
623 typename MergeOutputType,
typename InitializationRuleType,
624 typename... CustomLayers>
625 template <
class LayerType,
class... Args>
626 void BRNN<OutputLayerType, MergeLayerType, MergeOutputType,
627 InitializationRuleType, CustomLayers...>
::Add(Args... args)
629 forwardRNN.network.push_back(
new LayerType(args...));
630 backwardRNN.network.push_back(
new LayerType(args...));
633 template<
typename OutputLayerType,
typename MergeLayerType,
634 typename MergeOutputType,
typename InitializationRuleType,
635 typename... CustomLayers>
636 void BRNN<OutputLayerType, MergeLayerType, MergeOutputType,
637 InitializationRuleType, CustomLayers...>
:: 638 Add(LayerTypes<CustomLayers...> layer)
640 forwardRNN.network.push_back(layer);
641 backwardRNN.network.push_back(boost::apply_visitor(copyVisitor, layer));
644 template<
typename OutputLayerType,
typename MergeLayerType,
645 typename MergeOutputType,
typename InitializationRuleType,
646 typename... CustomLayers>
647 void BRNN<OutputLayerType, MergeLayerType, MergeOutputType,
655 forwardRNN.network.back()), mergeLayer);
657 backwardRNN.network.back()), mergeLayer);
661 ResetDeterministic();
665 CustomLayers...> networkInit(initializeRule);
666 size_t rnnWeights = 0;
667 for (
size_t i = 0; i < forwardRNN.network.size(); ++i)
669 rnnWeights += boost::apply_visitor(weightSizeVisitor,
670 forwardRNN.network[i]);
673 parameter.set_size(2 * rnnWeights, 1);
675 forwardRNN.
Parameters() = arma::mat(parameter.memptr(),
676 rnnWeights, 1,
false,
false);
677 backwardRNN.
Parameters() = arma::mat(parameter.memptr() + rnnWeights,
678 rnnWeights, 1,
false,
false);
681 networkInit.Initialize(forwardRNN.network, parameter);
684 networkInit.Initialize(backwardRNN.network, parameter, rnnWeights);
686 reset = forwardRNN.reset = backwardRNN.reset =
true;
689 template<
typename OutputLayerType,
typename MergeLayerType,
690 typename MergeOutputType,
typename InitializationRuleType,
691 typename... CustomLayers>
692 void BRNN<OutputLayerType, MergeLayerType, MergeOutputType,
693 InitializationRuleType, CustomLayers...>
::Reset()
696 forwardRNN.ResetCells();
697 backwardRNN.ResetCells();
698 forwardGradient.zeros();
699 forwardRNN.ResetGradients(forwardGradient);
700 backwardGradient.zeros();
701 backwardRNN.ResetGradients(backwardGradient);
704 template<
typename OutputLayerType,
typename MergeLayerType,
705 typename MergeOutputType,
typename InitializationRuleType,
706 typename... CustomLayers>
707 void BRNN<OutputLayerType, MergeLayerType, MergeOutputType,
708 InitializationRuleType, CustomLayers...>::ResetDeterministic()
710 forwardRNN.deterministic = this->deterministic;
711 backwardRNN.deterministic = this->deterministic;
712 forwardRNN.ResetDeterministic();
713 backwardRNN.ResetDeterministic();
716 template<
typename OutputLayerType,
typename MergeLayerType,
717 typename MergeOutputType,
typename InitializationRuleType,
718 typename... CustomLayers>
719 template<
typename Archive>
720 void BRNN<OutputLayerType, MergeLayerType, MergeOutputType,
722 Archive& ar,
const uint32_t version)
724 ar(CEREAL_NVP(parameter));
725 ar(CEREAL_NVP(backwardRNN));
726 ar(CEREAL_NVP(forwardRNN));
void serialize(Archive &ar, const uint32_t)
Serialize the model.
Definition: brnn_impl.hpp:721
DeleteVisitor executes the destructor of the instantiated object.
Definition: delete_visitor.hpp:27
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
Implementation of the Add module class.
Definition: add.hpp:34
std::enable_if< HasMaxIterations< OptimizerType, size_t &(OptimizerType::*)()>::value, void >::type WarnMessageMaxIterations(OptimizerType &optimizer, size_t samples) const
Check if the optimizer has a MaxIterations() parameter; if it does, then check if its value is less than the number of samples in the dataset.
Definition: brnn_impl.hpp:87
RunSetVisitor set the run parameter given the run value.
Definition: run_set_visitor.hpp:28
BackwardVisitor executes the Backward() function given the input, error and delta parameter...
Definition: backward_visitor.hpp:28
double Train(arma::cube predictors, arma::cube responses, OptimizerType &optimizer)
Train the bidirectional recurrent neural network on the given input data using the given optimizer...
Definition: brnn_impl.hpp:121
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Gradient(const arma::mat &parameters, const size_t begin, arma::mat &gradient, const size_t batchSize)
Evaluate the gradient of the recurrent neural network with the given parameters, and with respect to ...
Definition: rnn_impl.hpp:474
Definition: pointer_wrapper.hpp:23
SaveOutputParameterVisitor saves the output parameter into the given parameter set.
Definition: save_output_parameter_visitor.hpp:27
void Gradient(const arma::mat &parameters, const size_t begin, arma::mat &gradient, const size_t batchSize)
Evaluate the gradient of the bidirectional recurrent neural network with the given parameters...
Definition: brnn_impl.hpp:600
Implementation of the Concat class.
Definition: concat.hpp:43
void Reset()
Reset the state of the network.
Definition: brnn_impl.hpp:693
static MLPACK_EXPORT util::PrefixedOutStream Warn
Prints warning messages prefixed with [WARN ].
Definition: log.hpp:87
ForwardVisitor executes the Forward() function given the input and output parameter.
Definition: forward_visitor.hpp:28
void Shuffle()
Shuffle the order of function visitation.
Definition: brnn_impl.hpp:613
void ResetParameters()
Reset the module information (weights/parameters).
Definition: brnn_impl.hpp:648
AddVisitor exposes the Add() method of the given module.
Definition: add_visitor.hpp:28
Implementation of a standard bidirectional recurrent neural network container.
Definition: brnn.hpp:48
double Evaluate(const arma::mat &parameters, const size_t begin, const size_t batchSize, const bool deterministic)
Evaluate the bidirectional recurrent neural network with the given parameters.
Definition: brnn_impl.hpp:266
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
SearchModeVisitor executes the Gradient() method of the given module using the input and delta parame...
Definition: gradient_visitor.hpp:28
double EvaluateWithGradient(const arma::mat &parameters, const size_t begin, GradType &gradient, const size_t batchSize)
Evaluate the bidirectional recurrent neural network with the given parameters.
Definition: brnn_impl.hpp:369
void ShuffleData(const MatType &inputPoints, const LabelsType &inputLabels, MatType &outputPoints, LabelsType &outputLabels, const std::enable_if_t<!arma::is_SpMat< MatType >::value > *=0, const std::enable_if_t<!arma::is_Cube< MatType >::value > *=0)
Shuffle a dataset and associated labels (or responses).
Definition: shuffle_data.hpp:28
void Predict(arma::cube predictors, arma::cube &results, const size_t batchSize=256)
Predict the responses to a given set of predictors.
Definition: brnn_impl.hpp:189
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
LoadOutputParameterVisitor restores the output parameter using the given parameter set...
Definition: load_output_parameter_visitor.hpp:28
This class is used to initialize the network with the given initialization rule.
Definition: network_init.hpp:33
const arma::mat & Parameters() const
Return the initial point for the optimization.
Definition: rnn.hpp:297