12 #ifndef MLPACK_METHODS_RL_SIMPLE_DQN_HPP 13 #define MLPACK_METHODS_RL_SIMPLE_DQN_HPP 60 const bool isNoisy =
false,
61 InitType init = InitType(),
62 OutputLayerType outputLayer = OutputLayerType()):
63 network(outputLayer, init),
66 network.Add(
new Linear<>(inputDim, h1));
70 noisyLayerIndex.push_back(network.Model().size());
73 noisyLayerIndex.push_back(network.Model().size());
80 network.Add(
new Linear<>(h2, outputDim));
90 SimpleDQN(NetworkType& network,
const bool isNoisy =
false):
106 void Predict(
const arma::mat state, arma::mat& actionValue)
108 network.Predict(state, actionValue);
117 void Forward(
const arma::mat state, arma::mat& target)
119 network.Forward(state, target);
127 network.ResetParameters();
135 for (
size_t i = 0; i < noisyLayerIndex.size(); i++)
137 boost::get<NoisyLinear<>*>
138 (network.Model()[noisyLayerIndex[i]])->ResetNoise();
143 const arma::mat&
Parameters()
const {
return network.Parameters(); }
154 void Backward(
const arma::mat state, arma::mat& target, arma::mat& gradient)
156 network.Backward(state, target, gradient);
167 std::vector<size_t> noisyLayerIndex;
void Backward(const arma::mat state, arma::mat &target, arma::mat &gradient)
Perform the backward pass of the states in real batch mode.
Definition: simple_dqn.hpp:154
Artificial Neural Network.
Definition: elish_function.hpp:32
arma::mat & Parameters()
Modify the Parameters.
Definition: simple_dqn.hpp:145
void ResetNoise()
Resets the noise of the network if the network is noisy.
Definition: simple_dqn.hpp:133
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Predict(const arma::mat state, arma::mat &actionValue)
Predict the responses to a given set of predictors.
Definition: simple_dqn.hpp:106
Implementation of the Linear layer class.
Definition: layer_types.hpp:93
void ResetParameters()
Resets the parameters of the network.
Definition: simple_dqn.hpp:125
Implementation of the base layer.
Definition: base_layer.hpp:71
Implementation of the NoisyLinear layer class.
Definition: layer_types.hpp:107
The mean squared error performance function measures the network's performance according to the mean of squared errors.
Definition: mean_squared_error.hpp:34
SimpleDQN()
Default constructor.
Definition: simple_dqn.hpp:42
Implementation of a standard feed forward network.
Definition: ffn.hpp:52
SimpleDQN(const int inputDim, const int h1, const int h2, const int outputDim, const bool isNoisy=false, InitType init=InitType(), OutputLayerType outputLayer=OutputLayerType())
Construct an instance of SimpleDQN class.
Definition: simple_dqn.hpp:56
const arma::mat & Parameters() const
Return the Parameters.
Definition: simple_dqn.hpp:143
SimpleDQN(NetworkType &network, const bool isNoisy=false)
Construct an instance of SimpleDQN class from a pre-constructed network.
Definition: simple_dqn.hpp:90
Definition: simple_dqn.hpp:36
void Forward(const arma::mat state, arma::mat &target)
Perform the forward pass of the states in real batch mode.
Definition: simple_dqn.hpp:117
This class is used to initialize a weight matrix with a Gaussian distribution.
Definition: gaussian_init.hpp:28