13 #ifndef MLPACK_METHODS_ANN_BRNN_HPP 14 #define MLPACK_METHODS_ANN_BRNN_HPP 30 #include <ensmallen.hpp> 42 typename OutputLayerType = NegativeLogLikelihood<>,
43 typename MergeLayerType = Concat<>,
44 typename MergeOutputType = LogSoftMax<>,
45 typename InitializationRuleType = RandomInitialization,
46 typename... CustomLayers
55 InitializationRuleType,
75 BRNN(
const size_t rho,
76 const bool single =
false,
77 OutputLayerType outputLayer = OutputLayerType(),
78 MergeLayerType* mergeLayer =
new MergeLayerType(),
79 MergeOutputType* mergeOutput =
new MergeOutputType(),
80 InitializationRuleType initializeRule = InitializationRuleType());
93 template<
typename OptimizerType>
94 typename std::enable_if<
95 HasMaxIterations<OptimizerType, size_t&(OptimizerType::*)()>
107 template<
typename OptimizerType>
108 typename std::enable_if<
109 !HasMaxIterations<OptimizerType, size_t&(OptimizerType::*)()>
136 template<
typename OptimizerType>
137 double Train(arma::cube predictors,
138 arma::cube responses,
139 OptimizerType& optimizer);
164 template<
typename OptimizerType = ens::StandardSGD>
165 double Train(arma::cube predictors, arma::cube responses);
186 void Predict(arma::cube predictors,
188 const size_t batchSize = 256);
203 double Evaluate(
const arma::mat& parameters,
205 const size_t batchSize,
206 const bool deterministic);
/**
 * Evaluate the bidirectional recurrent neural network with the given
 * parameters.
 *
 * NOTE(review): this listing is missing original line 221 and carries stray
 * line-number tokens ("220", "222"). By analogy with the four-argument
 * Evaluate() overload and ensmallen's separable-function API, the missing
 * line is presumably `const size_t begin,` -- restore it from brnn.hpp
 * before compiling.
 */
220 double Evaluate(
const arma::mat& parameters,
222 const size_t batchSize);
237 template<
typename GradType>
241 const size_t batchSize);
256 void Gradient(
const arma::mat& parameters,
259 const size_t batchSize);
/**
 * Add a new module of type LayerType to the model.
 *
 * @tparam LayerType Type of the layer to add.
 * @tparam Args Types of the arguments for the layer.
 * @param args Arguments (taken by value), presumably passed to LayerType's
 *     constructor -- see brnn_impl.hpp for how the layer is built and stored.
 */
template<class LayerType, class... Args>
void Add(Args... args);
280 void Add(LayerTypes<CustomLayers...> layer);
291 const size_t&
Rho()
const {
return rho; }
293 size_t&
Rho() {
return rho; }
296 const arma::cube&
Responses()
const {
return responses; }
/**
 * Serialize the model.
 *
 * @param ar Archive to serialize to or from.
 * @param version Unnamed version parameter (unused in the declaration).
 */
template<typename Archive>
void serialize(Archive& ar, const uint32_t /* version */);
/**
 * NOTE(review): the name suggests this resets the deterministic/training mode
 * of the contained layers; the behaviour is defined in brnn_impl.hpp --
 * confirm before documenting further.
 */
void ResetDeterministic();
333 OutputLayerType outputLayer;
336 LayerTypes<CustomLayers...> mergeLayer;
339 LayerTypes<CustomLayers...> mergeOutput;
343 InitializationRuleType initializeRule;
361 arma::cube predictors;
364 arma::cube responses;
382 std::vector<arma::mat> forwardRNNOutputParameter;
385 std::vector<arma::mat> backwardRNNOutputParameter;
403 arma::mat forwardGradient;
406 arma::mat backwardGradient;
409 arma::mat totalGradient;
412 RNN<OutputLayerType, InitializationRuleType, CustomLayers...> forwardRNN;
415 RNN<OutputLayerType, InitializationRuleType, CustomLayers...> backwardRNN;
void serialize(Archive &ar, const uint32_t)
Serialize the model.
Definition: brnn_impl.hpp:721
DeleteVisitor executes the destructor of the instantiated object.
Definition: delete_visitor.hpp:27
Implementation of the Add module class.
Definition: add.hpp:34
arma::mat & Parameters()
Modify the initial point for the optimization.
Definition: brnn.hpp:288
std::enable_if< HasMaxIterations< OptimizerType, size_t &(OptimizerType::*)()>::value, void >::type WarnMessageMaxIterations(OptimizerType &optimizer, size_t samples) const
Check whether the optimizer has a MaxIterations() parameter; if it does, check whether its value is less than the number of samples.
Definition: brnn_impl.hpp:87
double Train(arma::cube predictors, arma::cube responses, OptimizerType &optimizer)
Train the bidirectional recurrent neural network on the given input data using the given optimizer...
Definition: brnn_impl.hpp:121
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
Definition: brnn.hpp:283
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
This visitor is to support copy constructor for neural network module.
Definition: copy_visitor.hpp:26
The core includes that mlpack expects; standard C++ includes and Armadillo.
WeightSizeVisitor returns the number of weights of the given module.
Definition: weight_size_visitor.hpp:27
size_t & Rho()
Modify the maximum length of backpropagation through time.
Definition: brnn.hpp:293
const arma::cube & Responses() const
Get the matrix of responses to the input data points.
Definition: brnn.hpp:296
const size_t & Rho() const
Return the maximum length of backpropagation through time.
Definition: brnn.hpp:291
Implementation of a standard recurrent neural network container.
Definition: rnn.hpp:45
void Gradient(const arma::mat &parameters, const size_t begin, arma::mat &gradient, const size_t batchSize)
Evaluate the gradient of the bidirectional recurrent neural network with the given parameters...
Definition: brnn_impl.hpp:600
void Reset()
Reset the state of the network.
Definition: brnn_impl.hpp:693
BRNN(const size_t rho, const bool single=false, OutputLayerType outputLayer=OutputLayerType(), MergeLayerType *mergeLayer=new MergeLayerType(), MergeOutputType *mergeOutput=new MergeOutputType(), InitializationRuleType initializeRule=InitializationRuleType())
Create the BRNN object.
Definition: brnn_impl.hpp:36
arma::cube & Predictors()
Modify the matrix of data points (predictors).
Definition: brnn.hpp:303
arma::cube & Responses()
Modify the matrix of responses to the input data points.
Definition: brnn.hpp:298
ResetVisitor executes the Reset() function.
Definition: reset_visitor.hpp:26
OutputParameterVisitor exposes the output parameter of the given module.
Definition: output_parameter_visitor.hpp:27
const arma::cube & Predictors() const
Get the matrix of data points (predictors).
Definition: brnn.hpp:301
void Shuffle()
Shuffle the order of function visitation.
Definition: brnn_impl.hpp:613
void ResetParameters()
Reset the module information (weights/parameters).
Definition: brnn_impl.hpp:648
Implementation of a standard bidirectional recurrent neural network container.
Definition: brnn.hpp:48
double Evaluate(const arma::mat &parameters, const size_t begin, const size_t batchSize, const bool deterministic)
Evaluate the bidirectional recurrent neural network with the given parameters.
Definition: brnn_impl.hpp:266
double EvaluateWithGradient(const arma::mat &parameters, const size_t begin, GradType &gradient, const size_t batchSize)
Evaluate the bidirectional recurrent neural network with the given parameters.
Definition: brnn_impl.hpp:369
void Predict(arma::cube predictors, arma::cube &results, const size_t batchSize=256)
Predict the responses to a given set of predictors.
Definition: brnn_impl.hpp:189
DeltaVisitor exposes the delta parameter of the given module.
Definition: delta_visitor.hpp:27
const arma::mat & Parameters() const
Return the initial point for the optimization.
Definition: brnn.hpp:286