12 #ifndef MLPACK_METHODS_ANN_RNN_HPP 13 #define MLPACK_METHODS_ANN_RNN_HPP 29 #include <ensmallen.hpp> 41 typename OutputLayerType = NegativeLogLikelihood<>,
42 typename InitializationRuleType = RandomInitialization,
43 typename... CustomLayers
50 InitializationRuleType,
69 const bool single =
false,
70 OutputLayerType outputLayer = OutputLayerType(),
71 InitializationRuleType initializeRule = InitializationRuleType());
97 template<
typename OptimizerType>
98 typename std::enable_if<
99 HasMaxIterations<OptimizerType, size_t&(OptimizerType::*)()>
111 template<
typename OptimizerType>
112 typename std::enable_if<
113 !HasMaxIterations<OptimizerType, size_t&(OptimizerType::*)()>
// Train the recurrent neural network on the given input data using the given
// optimizer.
//
// predictors / responses: input cubes, taken by value (so the caller's cubes
//     are copied or moved into the call).
// optimizer: caller-owned optimizer instance, passed by non-const reference.
// callbacks: optional extra arguments, accepted as forwarding references.
//
// Returns a double. NOTE(review): presumably the final objective value
// reported by the optimizer -- confirm against the implementation.
144 template<
typename OptimizerType,
typename... CallbackTypes>
145 double Train(arma::cube predictors,
146 arma::cube responses,
147 OptimizerType& optimizer,
148 CallbackTypes&&... callbacks);
// Train overload that takes no optimizer argument. The optimizer type is a
// template parameter that defaults to ens::StandardSGD.
// NOTE(review): presumably an OptimizerType instance is default-constructed
// internally -- confirm against the implementation.
//
// predictors / responses: input cubes, taken by value.
// callbacks: optional extra arguments, accepted as forwarding references.
//
// Returns a double (presumably the final objective value -- TODO confirm).
177 template<
typename OptimizerType = ens::StandardSGD,
typename... CallbackTypes>
178 double Train(arma::cube predictors,
179 arma::cube responses,
180 CallbackTypes&&... callbacks);
201 void Predict(arma::cube predictors,
203 const size_t batchSize = 256);
217 double Evaluate(
const arma::mat& parameters,
219 const size_t batchSize,
220 const bool deterministic);
233 double Evaluate(
const arma::mat& parameters,
235 const size_t batchSize);
248 template<
typename GradType>
252 const size_t batchSize);
267 void Gradient(
const arma::mat& parameters,
270 const size_t batchSize);
// Construct a new layer of type LayerType from `args` and append it to the
// network.
//
// The layer is allocated with a raw `new` and the resulting pointer is stored
// in `network`. `args` are taken by value and passed on to the LayerType
// constructor as lvalues.
// NOTE(review): the pointer is presumably released by the class destructor
// -- confirm before copying this pattern elsewhere.
283 template <
class LayerType,
class... Args>
284 void Add(Args... args) { network.push_back(
new LayerType(args...)); }
// Append an already-constructed layer (one of the LayerTypes alternatives) to
// the network. `layer` is taken by value and copied into `network`.
291 void Add(LayerTypes<CustomLayers...> layer) { network.push_back(layer); }
// Return the maximum length of backpropagation through time (read-only
// reference to the `rho` member).
302 const size_t&
Rho()
const {
return rho; }
// Modify the maximum length of backpropagation through time (mutable
// reference to the `rho` member).
304 size_t&
Rho() {
return rho; }
// Get the matrix of responses to the input data points (read-only reference
// to the `responses` member).
307 const arma::cube&
Responses()
const {
return responses; }
// Serialize the model through the archive `ar`.
// The second parameter is an unnamed uint32_t that the declaration does not
// use by name (presumably the archive's class-version argument -- TODO
// confirm).
329 template<
typename Archive>
330 void serialize(Archive& ar,
const uint32_t );
// Forward helper taking a single `input` by const reference; returns nothing.
// NOTE(review): presumably feeds `input` through every layer in `network` --
// only the declaration is visible here, confirm against the implementation.
340 template<
typename InputType>
341 void Forward(
const InputType& input);
// Gradient helper taking a single `input` by const reference; returns
// nothing, so any result must be written to member state.
// NOTE(review): presumably computes per-layer gradients for `input` -- only
// the declaration is visible here, confirm against the implementation.
358 template<
typename InputType>
359 void Gradient(
const InputType& input);
// Takes no arguments and returns nothing.
// NOTE(review): presumably propagates the current deterministic
// (training vs. prediction) setting to the layers -- confirm.
365 void ResetDeterministic();
// Takes the gradient matrix by non-const reference, so it may be modified.
// NOTE(review): presumably re-points each layer's gradient storage into
// `gradient` -- confirm against the implementation.
370 void ResetGradients(arma::mat& gradient);
// Output layer instance of the user-selected OutputLayerType (presumably the
// object passed to the constructor's `outputLayer` argument -- TODO confirm).
376 OutputLayerType outputLayer;
// Initialization rule instance (presumably the constructor's `initializeRule`
// argument -- TODO confirm).
380 InitializationRuleType initializeRule;
// The layers of the network, in the order they were appended by Add().
398 std::vector<LayerTypes<CustomLayers...> > network;
// The matrix of data points (predictors).
401 arma::cube predictors;
// The matrix of responses to the input data points; exposed by Responses().
404 arma::cube responses;
// One matrix per entry. NOTE(review): presumably saved layer outputs per
// time step for backpropagation through time -- confirm.
422 std::vector<arma::mat> moduleOutputParameter;
// NOTE(review): presumably the gradient of the step currently being
// processed -- confirm against the implementation.
440 arma::mat currentGradient;
444 typename OutputLayerType1,
445 typename MergeLayerType1,
446 typename MergeOutputType1,
447 typename InitializationRuleType1,
448 typename... CustomLayers1
DeleteVisitor executes the destructor of the instantiated object.
Definition: delete_visitor.hpp:27
const arma::cube & Predictors() const
Get the matrix of data points (predictors).
Definition: rnn.hpp:312
Implementation of the Add module class.
Definition: add.hpp:34
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
RNN & operator=(const RNN &)
Copy assignment operator.
void Gradient(const arma::mat &parameters, const size_t begin, arma::mat &gradient, const size_t batchSize)
Evaluate the gradient of the recurrent neural network with the given parameters, and with respect to ...
Definition: rnn_impl.hpp:474
This visitor is to support copy constructor for neural network module.
Definition: copy_visitor.hpp:26
The core includes that mlpack expects; standard C++ includes and Armadillo.
~RNN()
Destructor to release allocated memory.
Definition: rnn_impl.hpp:100
WeightSizeVisitor returns the number of weights of the given module.
Definition: weight_size_visitor.hpp:27
size_t & Rho()
Modify the maximum length of backpropagation through time.
Definition: rnn.hpp:304
Implementation of a standard recurrent neural network container.
Definition: rnn.hpp:45
RNN(const size_t rho, const bool single=false, OutputLayerType outputLayer=OutputLayerType(), InitializationRuleType initializeRule=InitializationRuleType())
Create the RNN object.
Definition: rnn_impl.hpp:35
arma::cube & Predictors()
Modify the matrix of data points (predictors).
Definition: rnn.hpp:314
void Shuffle()
Shuffle the order of function visitation.
Definition: rnn_impl.hpp:485
ResetVisitor executes the Reset() function.
Definition: reset_visitor.hpp:26
OutputParameterVisitor exposes the output parameter of the given module.
Definition: output_parameter_visitor.hpp:27
void Reset()
Reset the state of the network.
Definition: rnn_impl.hpp:511
void Predict(arma::cube predictors, arma::cube &results, const size_t batchSize=256)
Predict the responses to a given set of predictors.
Definition: rnn_impl.hpp:231
const size_t & Rho() const
Return the maximum length of backpropagation through time.
Definition: rnn.hpp:302
const arma::cube & Responses() const
Get the matrix of responses to the input data points.
Definition: rnn.hpp:307
Implementation of a standard bidirectional recurrent neural network container.
Definition: brnn.hpp:48
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
Definition: rnn.hpp:294
double EvaluateWithGradient(const arma::mat &parameters, const size_t begin, GradType &gradient, const size_t batchSize)
Evaluate the recurrent neural network with the given parameters.
Definition: rnn_impl.hpp:354
arma::mat & Parameters()
Modify the initial point for the optimization.
Definition: rnn.hpp:299
DeltaVisitor exposes the delta parameter of the given module.
Definition: delta_visitor.hpp:27
arma::cube & Responses()
Modify the matrix of responses to the input data points.
Definition: rnn.hpp:309
void serialize(Archive &ar, const uint32_t)
Serialize the model.
Definition: rnn_impl.hpp:602
double Train(arma::cube predictors, arma::cube responses, OptimizerType &optimizer, CallbackTypes &&... callbacks)
Train the recurrent neural network on the given input data using the given optimizer.
Definition: rnn_impl.hpp:146
double Evaluate(const arma::mat &parameters, const size_t begin, const size_t batchSize, const bool deterministic)
Evaluate the recurrent neural network with the given parameters.
Definition: rnn_impl.hpp:282
const arma::mat & Parameters() const
Return the initial point for the optimization.
Definition: rnn.hpp:297
void ResetParameters()
Reset the module information (weights/parameters).
Definition: rnn_impl.hpp:497
std::enable_if< HasMaxIterations< OptimizerType, size_t &(OptimizerType::*)()>::value, void >::type WarnMessageMaxIterations(OptimizerType &optimizer, size_t samples) const
Check if the optimizer has a MaxIterations() parameter; if it does, then check if its value is less tha...
Definition: rnn_impl.hpp:115