12 #ifndef MLPACK_METHODS_RL_DUELING_DQN_HPP 13 #define MLPACK_METHODS_RL_DUELING_DQN_HPP 67 concat->Add(valueNetwork);
68 concat->Add(advantageNetwork);
70 completeNetwork.Add(featureNetwork);
71 completeNetwork.Add(concat);
89 const bool isNoisy =
false,
90 InitType init = InitType(),
91 OutputLayerType outputLayer = OutputLayerType()):
92 completeNetwork(outputLayer, init),
96 featureNetwork->Add(
new Linear<>(inputDim, h1));
104 noisyLayerIndex.push_back(valueNetwork->Model().size());
111 noisyLayerIndex.push_back(valueNetwork->Model().size());
117 valueNetwork->Add(
new Linear<>(h1, h2));
119 valueNetwork->Add(
new Linear<>(h2, 1));
121 advantageNetwork->Add(
new Linear<>(h1, h2));
123 advantageNetwork->Add(
new Linear<>(h2, outputDim));
127 concat->Add(valueNetwork);
128 concat->Add(advantageNetwork);
131 completeNetwork.Add(featureNetwork);
132 completeNetwork.Add(concat);
133 this->ResetParameters();
145 AdvantageNetworkType& advantageNetwork,
146 ValueNetworkType& valueNetwork,
147 const bool isNoisy =
false):
148 featureNetwork(featureNetwork),
149 advantageNetwork(advantageNetwork),
150 valueNetwork(valueNetwork),
154 concat->Add(valueNetwork);
155 concat->Add(advantageNetwork);
157 completeNetwork.Add(featureNetwork);
158 completeNetwork.Add(concat);
159 this->ResetParameters();
169 *valueNetwork = *model.valueNetwork;
170 *advantageNetwork = *model.advantageNetwork;
171 *featureNetwork = *model.featureNetwork;
172 isNoisy = model.isNoisy;
173 noisyLayerIndex = model.noisyLayerIndex;
187 void Predict(
const arma::mat state, arma::mat& actionValue)
189 arma::mat advantage, value, networkOutput;
190 completeNetwork.Predict(state, networkOutput);
191 value = networkOutput.row(0);
192 advantage = networkOutput.rows(1, networkOutput.n_rows - 1);
193 actionValue = advantage.each_row() +
194 (value - arma::mean(advantage));
203 void Forward(
const arma::mat state, arma::mat& actionValue)
205 arma::mat advantage, value, networkOutput;
206 completeNetwork.Forward(state, networkOutput);
207 value = networkOutput.row(0);
208 advantage = networkOutput.rows(1, networkOutput.n_rows - 1);
209 actionValue = advantage.each_row() +
210 (value - arma::mean(advantage));
211 this->actionValues = actionValue;
// Backward pass: turns the loss gradient on the combined action values into a
// gradient for the stacked (value, advantage) network output and passes it to
// the complete network.
//   state    - input forwarded to completeNetwork.Backward().
//   target   - target passed to the loss function.
//   gradient - receives the gradient computed by completeNetwork.Backward().
221 void Backward(
const arma::mat state, arma::mat& target, arma::mat& gradient)
// Gradient of the loss with respect to the action values stored by Forward().
// NOTE(review): the declaration of gradLoss is not visible in this listing.
224 lossFunction.Backward(this->actionValues, target, gradLoss);
// Value-stream gradient: sum of gradLoss over each column (a single row).
226 arma::mat gradValue = arma::sum(gradLoss);
// Advantage-stream gradient: gradLoss with its column mean subtracted.
227 arma::mat gradAdvantage = gradLoss.each_row() - arma::mean(gradLoss);
// Stack the value row on top of the advantage rows, the same layout that
// Forward() splits the network output into (row 0, then the remaining rows).
229 arma::mat grad = arma::join_cols(gradValue, gradAdvantage);
230 completeNetwork.Backward(state, grad, gradient);
238 completeNetwork.ResetParameters();
246 for (
size_t i = 0; i < noisyLayerIndex.size(); i++)
248 boost::get<NoisyLinear<>*>
249 (valueNetwork->Model()[noisyLayerIndex[i]])->ResetNoise();
250 boost::get<NoisyLinear<>*>
251 (advantageNetwork->Model()[noisyLayerIndex[i]])->ResetNoise();
256 const arma::mat&
Parameters()
const {
return completeNetwork.Parameters(); }
258 arma::mat&
Parameters() {
return completeNetwork.Parameters(); }
262 CompleteNetworkType completeNetwork;
268 FeatureNetworkType* featureNetwork;
271 AdvantageNetworkType* advantageNetwork;
274 ValueNetworkType* valueNetwork;
280 std::vector<size_t> noisyLayerIndex;
283 arma::mat actionValues;
const arma::mat & Parameters() const
Return the Parameters.
Definition: dueling_dqn.hpp:256
Artificial Neural Network.
Definition: elish_function.hpp:32
DuelingDQN()
Default constructor.
Definition: dueling_dqn.hpp:60
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Predict(const arma::mat state, arma::mat &actionValue)
Predict the responses to a given set of predictors.
Definition: dueling_dqn.hpp:187
void Forward(const arma::mat state, arma::mat &actionValue)
Perform the forward pass of the states in real batch mode.
Definition: dueling_dqn.hpp:203
The core includes that mlpack expects; standard C++ includes and Armadillo.
Implementation of the Linear layer class.
Definition: layer_types.hpp:93
Implementation of the Dueling Deep Q-Learning network.
Definition: dueling_dqn.hpp:56
The empty loss does nothing, letting the user calculate the loss outside the model.
Definition: empty_loss.hpp:35
Implementation of the base layer.
Definition: base_layer.hpp:71
DuelingDQN(FeatureNetworkType &featureNetwork, AdvantageNetworkType &advantageNetwork, ValueNetworkType &valueNetwork, const bool isNoisy=false)
Construct an instance of DuelingDQN class from a pre-constructed network.
Definition: dueling_dqn.hpp:144
Implementation of the Concat class.
Definition: concat.hpp:43
void ResetNoise()
Resets noise of the network, if the network is of type noisy.
Definition: dueling_dqn.hpp:244
void ResetParameters()
Resets the parameters of the network.
Definition: dueling_dqn.hpp:236
Implementation of the NoisyLinear layer class.
Definition: layer_types.hpp:107
void Backward(const arma::mat state, arma::mat &target, arma::mat &gradient)
Perform the backward pass of the state in real batch mode.
Definition: dueling_dqn.hpp:221
arma::mat & Parameters()
Modify the Parameters.
Definition: dueling_dqn.hpp:258
DuelingDQN(const DuelingDQN &)
Copy constructor.
Definition: dueling_dqn.hpp:163
DuelingDQN(const int inputDim, const int h1, const int h2, const int outputDim, const bool isNoisy=false, InitType init=InitType(), OutputLayerType outputLayer=OutputLayerType())
Construct an instance of DuelingDQN class.
Definition: dueling_dqn.hpp:85
The mean squared error performance function measures the network's performance according to the mean of squared errors.
Definition: mean_squared_error.hpp:34
Implementation of a standard feed forward network.
Definition: ffn.hpp:52
Implementation of the Sequential class.
Definition: layer_types.hpp:145
This class is used to initialize the weight matrix with a Gaussian distribution.
Definition: gaussian_init.hpp:28