mlpack
recurrent_network_test.cpp File Reference
#include <mlpack/core.hpp>
#include <ensmallen.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/loss_functions/mean_squared_error.hpp>
#include <mlpack/methods/ann/rnn.hpp>
#include <mlpack/methods/ann/brnn.hpp>
#include <mlpack/core/data/binarize.hpp>
#include <mlpack/core/math/random.hpp>
#include "catch.hpp"
#include "serialization.hpp"
#include "custom_layer.hpp"
Include dependency graph for recurrent_network_test.cpp:

Functions

void GenerateNoisySines (arma::cube &data, arma::mat &labels, const size_t points, const size_t sequences, const double noise=0.3)
 Construct a 2-class dataset out of noisy sines. More...
 
void GenerateDistractedSequence (arma::mat &input, arma::mat &output)
 
template<typename RecurrentLayerType >
void DistractedSequenceRecallTestNetwork (const size_t cellSize, const size_t hiddenSize)
 Construct the distracted sequence recall dataset and train the specified network.
 
 TEST_CASE ("LSTMDistractedSequenceRecallTest", "[RecurrentNetworkTest]")
 Train the specified network on Derek D. Monner's distracted sequence recall task. More...
 
 TEST_CASE ("FastLSTMDistractedSequenceRecallTest", "[RecurrentNetworkTest]")
 Train the specified network on Derek D. Monner's distracted sequence recall task. More...
 
 TEST_CASE ("GRUDistractedSequenceRecallTest", "[RecurrentNetworkTest]")
 Train the specified network on Derek D. Monner's distracted sequence recall task. More...
 
template<typename RecurrentLayerType >
void BatchSizeTest ()
 Create a simple recurrent neural network for the noisy sines task, and require that it produces the exact same network for a few batch sizes.
 
 TEST_CASE ("LSTMBatchSizeTest", "[RecurrentNetworkTest]")
 Ensure LSTMs work with larger batch sizes.
 
 TEST_CASE ("FastLSTMBatchSizeTest", "[RecurrentNetworkTest]")
 Ensure fast LSTMs work with larger batch sizes.
 
 TEST_CASE ("GRUBatchSizeTest", "[RecurrentNetworkTest]")
 Ensure GRUs work with larger batch sizes.
 
 TEST_CASE ("RNNSerializationTest", "[RecurrentNetworkTest]")
 Make sure the RNN can be properly serialized. More...
 
 TEST_CASE ("SequenceClassificationBRNNTest", "[RecurrentNetworkTest]")
 Train the BRNN on a larger dataset.
 
 TEST_CASE ("SequenceClassificationTest", "[RecurrentNetworkTest]")
 Train the vanilla network on a larger dataset. More...
 
void GenerateNoisySinRNN (arma::cube &data, arma::cube &labels, size_t rho, size_t outputSteps=1, const int dataPoints=100, const double gain=1.0, const int freq=10, const double phase=0, const int noisePercent=20, const double numCycles=6.0, const bool normalize=true)
 Generates a noisy sine wave and outputs data and labels that can be used directly for training and testing an RNN. More...
 
double RNNSineTest (size_t hiddenUnits, size_t rho, size_t numEpochs=100)
 Test a simple RNN using a noisy sine wave. More...
 
 TEST_CASE ("MultiTimestepTest", "[RecurrentNetworkTest]")
 Test an RNN using multiple-timestep input and a single output.
 
 TEST_CASE ("RNNTrainReturnObjective", "[RecurrentNetworkTest]")
 Test that RNN::Train() returns a finite objective value. More...
 
 TEST_CASE ("BRNNTrainReturnObjective", "[RecurrentNetworkTest]")
 Test that BRNN::Train() returns a finite objective value.
 
 TEST_CASE ("LargeRhoValueRnnTest", "[RecurrentNetworkTest]")
 Test that RNN::Train() does not give an error for large rho.
 
 TEST_CASE ("RNNCheckInputShapeTest", "[RecurrentNetworkTest]")
 Test to make sure that an error is thrown when input with the wrong shape is provided to an RNN. More...
 

Detailed Description

Author
Marcus Edel

Tests the recurrent network.

mlpack is free software; you may redistribute it and/or modify it under the terms of the 3-clause BSD license. You should have received a copy of the 3-clause BSD license along with mlpack. If not, see http://www.opensource.org/licenses/BSD-3-Clause for more information.

Function Documentation

◆ GenerateNoisySines()

void GenerateNoisySines(arma::cube& data,
                        arma::mat& labels,
                        const size_t points,
                        const size_t sequences,
                        const double noise = 0.3)

Construct a 2-class dataset out of noisy sines.

Parameters
data       Input data used to store the noisy sines.
labels     Labels used to store the target class of the noisy sines.
points     Number of points/features in a single sequence.
sequences  Number of sequences for each class.
noise      The noise factor that influences the sines.
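The recipe above can be sketched as follows. This is a minimal illustration, not the test's exact implementation: the real helper fills an arma::cube and arma::mat, while this version uses std::vector to stay self-contained, and the assumption that the two classes differ by a phase shift is for illustration only.

```cpp
#include <cmath>
#include <cstddef>
#include <random>
#include <utility>
#include <vector>

// Sketch of a 2-class noisy-sine generator. Each element pairs a
// sequence of `points` samples with its class label (0 or 1).
std::vector<std::pair<std::vector<double>, int>> GenerateNoisySines(
    const std::size_t points,
    const std::size_t sequences,
    const double noise = 0.3)
{
  const double pi = 3.14159265358979323846;
  std::mt19937 gen(42);  // fixed seed for reproducibility
  std::normal_distribution<double> dist(0.0, 1.0);

  std::vector<std::pair<std::vector<double>, int>> data;
  for (int label = 0; label < 2; ++label)
  {
    for (std::size_t s = 0; s < sequences; ++s)
    {
      std::vector<double> seq(points);
      for (std::size_t i = 0; i < points; ++i)
      {
        // Phase-shift the sine for class 1 and add Gaussian noise
        // scaled by the noise factor.
        seq[i] = std::sin(0.1 * i + label * pi / 2.0) + noise * dist(gen);
      }
      data.emplace_back(seq, label);
    }
  }
  return data;
}
```

With `sequences = 3`, the result holds 3 sequences per class, giving 6 labeled sequences in total.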

◆ GenerateNoisySinRNN()

void GenerateNoisySinRNN(arma::cube& data,
                         arma::cube& labels,
                         size_t rho,
                         size_t outputSteps = 1,
                         const int dataPoints = 100,
                         const double gain = 1.0,
                         const int freq = 10,
                         const double phase = 0,
                         const int noisePercent = 20,
                         const double numCycles = 6.0,
                         const bool normalize = true)

Generates a noisy sine wave and outputs data and labels that can be used directly for training and testing an RNN.

Parameters
data          The data points as output.
labels        The expected values as output.
rho           The size of the sequence of each data point.
outputSteps   How many output steps to consider for every rho inputs.
dataPoints    The number of generated data points. The actual number may be higher to accommodate outputSteps, but at least this many points will be generated.
gain          The gain on the amplitude.
freq          The frequency of the sine wave.
phase         The phase shift, if any.
noisePercent  The percent noise to induce.
numCycles     How many full sine-wave cycles are required. All the data points are fit into these cycles.
normalize     Whether to normalize the data. This may be required for some layers, like LSTM. Default is true.
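A single-vector sketch of how these parameters interact is shown below, assuming the exact interplay of freq and numCycles is illustrative; the real helper emits arma::cube data and labels shaped for RNN training, while this version returns a flat std::vector to stay self-contained.

```cpp
#include <algorithm>
#include <cmath>
#include <random>
#include <vector>

// Hypothetical flat-vector variant of the generator described above.
std::vector<double> NoisySine(const int dataPoints = 100,
                              const double gain = 1.0,
                              const int freq = 10,
                              const double phase = 0,
                              const int noisePercent = 20,
                              const double numCycles = 6.0,
                              const bool normalize = true)
{
  const double pi = 3.14159265358979323846;
  std::mt19937 gen(7);  // fixed seed for reproducibility
  std::uniform_real_distribution<double> dist(-1.0, 1.0);

  std::vector<double> y(dataPoints);
  double maxAbs = 0.0;
  for (int i = 0; i < dataPoints; ++i)
  {
    // Fit all dataPoints samples into numCycles full wave cycles.
    const double t = numCycles * 2.0 * pi * i / dataPoints;
    y[i] = gain * std::sin(freq * t + phase)
         + gain * (noisePercent / 100.0) * dist(gen);
    maxAbs = std::max(maxAbs, std::abs(y[i]));
  }

  // Scale into [-1, 1]; helpful for saturating layers such as LSTM.
  if (normalize && maxAbs > 0.0)
    for (double& v : y)
      v /= maxAbs;
  return y;
}
```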

◆ RNNSineTest()

double RNNSineTest(size_t hiddenUnits,
                   size_t rho,
                   size_t numEpochs = 100)

Test a simple RNN using a noisy sine wave.

Uses a single output for multiple inputs.

Parameters
hiddenUnits  Number of units in the hidden layer.
rho          The input sequence length.
numEpochs    The number of epochs to run.
Returns
The mean squared error of the prediction.
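The returned metric reduces to the following computation, shown here on plain vectors rather than the Armadillo types the test actually uses:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Mean squared error between predictions and targets:
// MSE = (1/n) * sum_i (pred_i - target_i)^2
double MeanSquaredError(const std::vector<double>& pred,
                        const std::vector<double>& target)
{
  assert(pred.size() == target.size() && !pred.empty());
  double sum = 0.0;
  for (std::size_t i = 0; i < pred.size(); ++i)
  {
    const double diff = pred[i] - target[i];
    sum += diff * diff;  // accumulate squared errors
  }
  return sum / static_cast<double>(pred.size());
}
```

For example, predictions {1, 2, 3} against targets {1, 2, 5} give (0 + 0 + 4) / 3 = 4/3.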

◆ TEST_CASE() [1/7]

TEST_CASE("LSTMDistractedSequenceRecallTest", "[RecurrentNetworkTest]")

Train the specified network on Derek D. Monner's distracted sequence recall task.

◆ TEST_CASE() [2/7]

TEST_CASE("FastLSTMDistractedSequenceRecallTest", "[RecurrentNetworkTest]")

Train the specified network on Derek D. Monner's distracted sequence recall task.

◆ TEST_CASE() [3/7]

TEST_CASE("GRUDistractedSequenceRecallTest", "[RecurrentNetworkTest]")

Train the specified network on Derek D. Monner's distracted sequence recall task.

◆ TEST_CASE() [4/7]

TEST_CASE("RNNSerializationTest", "[RecurrentNetworkTest]")

Make sure the RNN can be properly serialized.

Construct a network with 1 input unit, 4 hidden units and 10 output units. The hidden layer is connected to itself. The network structure looks like:

Input         Hidden        Output
Layer(1)      Layer(4)      Layer(10)
+-----+       +-----+       +-----+
|     |       |     |       |     |
|     +------>|     +------>|     |
|     |    ..>|     |       |     |
+-----+    .  +--+--+       +-----+
           .     .
           .     .
           .......

◆ TEST_CASE() [5/7]

TEST_CASE("SequenceClassificationTest", "[RecurrentNetworkTest]")

Train the vanilla network on a larger dataset.

Construct a network with 1 input unit, 4 hidden units and 10 output units. The hidden layer is connected to itself. The network structure looks like:

Input         Hidden        Output
Layer(1)      Layer(4)      Layer(10)
+-----+       +-----+       +-----+
|     |       |     |       |     |
|     +------>|     +------>|     |
|     |    ..>|     |       |     |
+-----+    .  +--+--+       +-----+
           .     .
           .     .
           .......

◆ TEST_CASE() [6/7]

TEST_CASE("RNNTrainReturnObjective", "[RecurrentNetworkTest]")

Test that RNN::Train() returns a finite objective value.

Construct a network with 1 input unit, 4 hidden units and 10 output units. The hidden layer is connected to itself. The network structure looks like:

Input         Hidden        Output
Layer(1)      Layer(4)      Layer(10)
+-----+       +-----+       +-----+
|     |       |     |       |     |
|     +------>|     +------>|     |
|     |    ..>|     |       |     |
+-----+    .  +--+--+       +-----+
           .     .
           .     .
           .......

◆ TEST_CASE() [7/7]

TEST_CASE("RNNCheckInputShapeTest", "[RecurrentNetworkTest]")

Test to make sure that an error is thrown when input with the wrong shape is provided to an RNN.

Construct a network with 1 input unit, 4 hidden units and 10 output units. The hidden layer is connected to itself. The network structure looks like:

Input         Hidden        Output
Layer(1)      Layer(4)      Layer(10)
+-----+       +-----+       +-----+
|     |       |     |       |     |
|     +------>|     +------>|     |
|     |    ..>|     |       |     |
+-----+    .  +--+--+       +-----+
           .     .
           .     .
           .......
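The kind of check this test exercises can be sketched as a standalone function. This is a hypothetical illustration: mlpack performs an equivalent check internally when Predict() or Train() receives data whose first dimension does not match the network's input size, and the function name and message below are not mlpack's.

```cpp
#include <cstddef>
#include <stdexcept>

// Reject input whose row count does not match the dimensionality the
// network was built for, mirroring the shape validation the test expects.
void CheckInputShape(const std::size_t expectedRows,
                     const std::size_t givenRows)
{
  if (expectedRows != givenRows)
  {
    throw std::logic_error(
        "input dimensionality does not match the network's input size");
  }
}
```

A matching shape passes silently; a mismatch throws, which is what the test asserts on.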