mlpack
|
#include <mlpack/core.hpp>
#include <ensmallen.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/loss_functions/mean_squared_error.hpp>
#include <mlpack/methods/ann/rnn.hpp>
#include <mlpack/methods/ann/brnn.hpp>
#include <mlpack/core/data/binarize.hpp>
#include <mlpack/core/math/random.hpp>
#include "catch.hpp"
#include "serialization.hpp"
#include "custom_layer.hpp"
Functions | |
void | GenerateNoisySines (arma::cube &data, arma::mat &labels, const size_t points, const size_t sequences, const double noise=0.3) |
Construct a 2-class dataset out of noisy sines. More... | |
void | GenerateDistractedSequence (arma::mat &input, arma::mat &output) |
template<typename RecurrentLayerType > | |
void | DistractedSequenceRecallTestNetwork (const size_t cellSize, const size_t hiddenSize) |
Construct the distracted sequence recall dataset and train the specified network on it. | |
TEST_CASE ("LSTMDistractedSequenceRecallTest", "[RecurrentNetworkTest]") | |
Train the specified networks on Derek D. Monner's distracted sequence recall task. More... | |
TEST_CASE ("FastLSTMDistractedSequenceRecallTest", "[RecurrentNetworkTest]") | |
Train the specified networks on Derek D. Monner's distracted sequence recall task. More... | |
TEST_CASE ("GRUDistractedSequenceRecallTest", "[RecurrentNetworkTest]") | |
Train the specified networks on Derek D. Monner's distracted sequence recall task. More... | |
template<typename RecurrentLayerType > | |
void | BatchSizeTest () |
Create a simple recurrent neural network for the noisy sines task, and require that it produces the exact same network for a few batch sizes. | |
TEST_CASE ("LSTMBatchSizeTest", "[RecurrentNetworkTest]") | |
Ensure LSTMs work with larger batch sizes. | |
TEST_CASE ("FastLSTMBatchSizeTest", "[RecurrentNetworkTest]") | |
Ensure fast LSTMs work with larger batch sizes. | |
TEST_CASE ("GRUBatchSizeTest", "[RecurrentNetworkTest]") | |
Ensure GRUs work with larger batch sizes. | |
TEST_CASE ("RNNSerializationTest", "[RecurrentNetworkTest]") | |
Make sure the RNN can be properly serialized. More... | |
TEST_CASE ("SequenceClassificationBRNNTest", "[RecurrentNetworkTest]") | |
Train the BRNN on a larger dataset. | |
TEST_CASE ("SequenceClassificationTest", "[RecurrentNetworkTest]") | |
Train the vanilla network on a larger dataset. More... | |
void | GenerateNoisySinRNN (arma::cube &data, arma::cube &labels, size_t rho, size_t outputSteps=1, const int dataPoints=100, const double gain=1.0, const int freq=10, const double phase=0, const int noisePercent=20, const double numCycles=6.0, const bool normalize=true) |
Generates noisy sine wave and outputs the data and the labels that can be used directly for training and testing with RNN. More... | |
double | RNNSineTest (size_t hiddenUnits, size_t rho, size_t numEpochs=100) |
RNNSineTest: test a simple RNN using a noisy sine wave. More... | |
TEST_CASE ("MultiTimestepTest", "[RecurrentNetworkTest]") | |
Test RNN using multiple timestep input and single output. | |
TEST_CASE ("RNNTrainReturnObjective", "[RecurrentNetworkTest]") | |
Test that RNN::Train() returns a finite objective value. More... | |
TEST_CASE ("BRNNTrainReturnObjective", "[RecurrentNetworkTest]") | |
Test that BRNN::Train() returns a finite objective value. | |
TEST_CASE ("LargeRhoValueRnnTest", "[RecurrentNetworkTest]") | |
Test that RNN::Train() does not give an error for large rho. | |
TEST_CASE ("RNNCheckInputShapeTest", "[RecurrentNetworkTest]") | |
Test to make sure that an error is thrown when input with the wrong shape is provided to an RNN. More... | |
Tests the recurrent network.
mlpack is free software; you may redistribute it and/or modify it under the terms of the 3-clause BSD license. You should have received a copy of the 3-clause BSD license along with mlpack. If not, see http://www.opensource.org/licenses/BSD-3-Clause for more information.
void GenerateNoisySines(arma::cube& data,
                        arma::mat& labels,
                        const size_t points,
                        const size_t sequences,
                        const double noise = 0.3)
Construct a 2-class dataset out of noisy sines.
data | Input data used to store the noisy sines. |
labels | Labels used to store the target class of the noisy sines. |
points | Number of points/features in a single sequence. |
sequences | Number of sequences for each class. |
noise | The noise factor that influences the sines. |
void GenerateNoisySinRNN(arma::cube& data,
                         arma::cube& labels,
                         size_t rho,
                         size_t outputSteps = 1,
                         const int dataPoints = 100,
                         const double gain = 1.0,
                         const int freq = 10,
                         const double phase = 0,
                         const int noisePercent = 20,
                         const double numCycles = 6.0,
                         const bool normalize = true)
Generates noisy sine wave and outputs the data and the labels that can be used directly for training and testing with RNN.
data | The data points as output |
labels | The expected values as output |
rho | The size of the sequence of each data point |
outputSteps | How many output steps to consider for every rho input steps |
dataPoints | The number of data points to generate. The actual number of generated data points may exceed this value in order to accommodate outputSteps, but at least this many data points will always be generated. |
gain | The gain on the amplitude |
freq | The frequency of the sine wave |
phase | The phase shift if any |
noisePercent | The percentage of noise to add |
numCycles | The number of full wave cycles to generate. All of the data points are fit into these cycles. |
normalize | Whether to normalize the data. This may be required for some layers, such as LSTM. Default is true. |
double RNNSineTest(size_t hiddenUnits,
                   size_t rho,
                   size_t numEpochs = 100)
RNNSineTest: test a simple RNN using a noisy sine wave.
Use a single output for multiple inputs.
hiddenUnits | Number of units in the hidden layer. |
rho | The input sequence length. |
numEpochs | The number of epochs to run. |
TEST_CASE("LSTMDistractedSequenceRecallTest", "[RecurrentNetworkTest]")
Train the specified networks on Derek D. Monner's distracted sequence recall task.
TEST_CASE("FastLSTMDistractedSequenceRecallTest", "[RecurrentNetworkTest]")
Train the specified networks on Derek D. Monner's distracted sequence recall task.
TEST_CASE("GRUDistractedSequenceRecallTest", "[RecurrentNetworkTest]")
Train the specified networks on Derek D. Monner's distracted sequence recall task.
TEST_CASE("RNNSerializationTest", "[RecurrentNetworkTest]")
Make sure the RNN can be properly serialized.
Construct a network with 1 input unit, 4 hidden units and 10 output units. The hidden layer is connected to itself. The network structure looks like:
 Input         Hidden        Output
Layer(1)      Layer(4)      Layer(10)
+-----+       +-----+       +-----+
|     |       |     |       |     |
|     +------>|     +------>|     |
|     |    ..>|     |       |     |
+-----+    .  +--+--+       +-----+
           .     .
           .     .
           .......
TEST_CASE("SequenceClassificationTest", "[RecurrentNetworkTest]")
Train the vanilla network on a larger dataset.
Construct a network with 1 input unit, 4 hidden units and 10 output units. The hidden layer is connected to itself. The network structure looks like:
 Input         Hidden        Output
Layer(1)      Layer(4)      Layer(10)
+-----+       +-----+       +-----+
|     |       |     |       |     |
|     +------>|     +------>|     |
|     |    ..>|     |       |     |
+-----+    .  +--+--+       +-----+
           .     .
           .     .
           .......
TEST_CASE("RNNTrainReturnObjective", "[RecurrentNetworkTest]")
Test that RNN::Train() returns a finite objective value.
Construct a network with 1 input unit, 4 hidden units and 10 output units. The hidden layer is connected to itself. The network structure looks like:
 Input         Hidden        Output
Layer(1)      Layer(4)      Layer(10)
+-----+       +-----+       +-----+
|     |       |     |       |     |
|     +------>|     +------>|     |
|     |    ..>|     |       |     |
+-----+    .  +--+--+       +-----+
           .     .
           .     .
           .......
TEST_CASE("RNNCheckInputShapeTest", "[RecurrentNetworkTest]")
Test to make sure that an error is thrown when input with the wrong shape is provided to an RNN.
Construct a network with 1 input unit, 4 hidden units and 10 output units. The hidden layer is connected to itself. The network structure looks like:
 Input         Hidden        Output
Layer(1)      Layer(4)      Layer(10)
+-----+       +-----+       +-----+
|     |       |     |       |     |
|     +------>|     +------>|     |
|     |    ..>|     |       |     |
+-----+    .  +--+--+       +-----+
           .     .
           .     .
           .......