12 #ifndef MLPACK_METHODS_RL_CATEGORICAL_DQN_HPP 13 #define MLPACK_METHODS_RL_CATEGORICAL_DQN_HPP 21 #include "../training_config.hpp" 57 network(), atomSize(0), vMin(0.0), vMax(0.0), isNoisy(false)
77 const bool isNoisy =
false,
78 InitType init = InitType(),
79 OutputLayerType outputLayer = OutputLayerType()):
80 network(outputLayer, init),
81 atomSize(config.AtomSize()),
86 network.Add(
new Linear<>(inputDim, h1));
90 noisyLayerIndex.push_back(network.Model().size());
93 noisyLayerIndex.push_back(network.Model().size());
100 network.Add(
new Linear<>(h2, outputDim * atomSize));
114 const bool isNoisy =
false):
115 network(
std::move(network)),
116 atomSize(config.AtomSize()),
133 void Predict(
const arma::mat state, arma::mat& actionValue)
136 network.Predict(state, q_atoms);
137 activations.copy_size(q_atoms);
138 actionValue.set_size(q_atoms.n_rows / atomSize, q_atoms.n_cols);
139 arma::rowvec support = arma::linspace<arma::rowvec>(vMin, vMax, atomSize);
140 for (
size_t i = 0; i < q_atoms.n_rows; i += atomSize)
142 arma::mat activation = activations.rows(i, i + atomSize - 1);
143 arma::mat input = q_atoms.rows(i, i + atomSize - 1);
144 softMax.Forward(input, activation);
145 activations.rows(i, i + atomSize - 1) = activation;
146 actionValue.row(i/atomSize) = support * activation;
156 void Forward(
const arma::mat state, arma::mat& dist)
159 network.Forward(state, q_atoms);
160 activations.copy_size(q_atoms);
161 for (
size_t i = 0; i < q_atoms.n_rows; i += atomSize)
163 arma::mat activation = activations.rows(i, i + atomSize - 1);
164 arma::mat input = q_atoms.rows(i, i + atomSize - 1);
165 softMax.Forward(input, activation);
166 activations.rows(i, i + atomSize - 1) = activation;
176 network.ResetParameters();
184 for (
size_t i = 0; i < noisyLayerIndex.size(); i++)
186 boost::get<NoisyLinear<>*>
187 (network.Model()[noisyLayerIndex[i]])->ResetNoise();
192 const arma::mat&
Parameters()
const {
return network.Parameters(); }
204 arma::mat& lossGradients,
207 arma::mat activationGradients(arma::size(activations));
208 for (
size_t i = 0; i < activations.n_rows; i += atomSize)
210 arma::mat activationGrad;
211 arma::mat lossGrad = lossGradients.rows(i, i + atomSize - 1);
212 arma::mat activation = activations.rows(i, i + atomSize - 1);
213 softMax.Backward(activation, lossGrad, activationGrad);
214 activationGradients.rows(i, i + atomSize - 1) = activationGrad;
216 network.Backward(state, activationGradients, gradient);
236 std::vector<size_t> noisyLayerIndex;
242 arma::mat activations;
Artificial Neural Network.
Definition: elish_function.hpp:32
void ResetNoise()
Resets noise of the network, if the network is of type noisy.
Definition: categorical_dqn.hpp:182
void Forward(const arma::mat state, arma::mat &dist)
Perform the forward pass of the states in real batch mode.
Definition: categorical_dqn.hpp:156
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
const arma::mat & Parameters() const
Return the Parameters.
Definition: categorical_dqn.hpp:192
The core includes that mlpack expects; standard C++ includes and Armadillo.
Implementation of the Linear layer class.
Definition: layer_types.hpp:93
Definition: pointer_wrapper.hpp:23
The empty loss does nothing, letting the user calculate the loss outside the model.
Definition: empty_loss.hpp:35
Implementation of the Softmax layer.
Definition: softmax.hpp:38
Implementation of the base layer.
Definition: base_layer.hpp:71
Definition: training_config.hpp:19
Implementation of the NoisyLinear layer class.
Definition: layer_types.hpp:107
void Predict(const arma::mat state, arma::mat &actionValue)
Predict the responses to a given set of predictors.
Definition: categorical_dqn.hpp:133
void ResetParameters()
Resets the parameters of the network.
Definition: categorical_dqn.hpp:174
CategoricalDQN(const int inputDim, const int h1, const int h2, const int outputDim, TrainingConfig config, const bool isNoisy=false, InitType init=InitType(), OutputLayerType outputLayer=OutputLayerType())
Construct an instance of CategoricalDQN class.
Definition: categorical_dqn.hpp:72
CategoricalDQN()
Default constructor.
Definition: categorical_dqn.hpp:56
Implementation of a standard feed forward network.
Definition: ffn.hpp:52
CategoricalDQN(NetworkType &network, TrainingConfig config, const bool isNoisy=false)
Construct an instance of CategoricalDQN class from a pre-constructed network.
Definition: categorical_dqn.hpp:112
arma::mat & Parameters()
Modify the Parameters.
Definition: categorical_dqn.hpp:194
void Backward(const arma::mat state, arma::mat &lossGradients, arma::mat &gradient)
Perform the backward pass of the state in real batch mode.
Definition: categorical_dqn.hpp:203
Implementation of the Categorical Deep Q-Learning network.
Definition: categorical_dqn.hpp:50
This class is used to initialize the weight matrix with a Gaussian distribution.
Definition: gaussian_init.hpp:28