#ifndef MLPACK_METHODS_RL_Q_LEARNING_IMPL_HPP
#define MLPACK_METHODS_RL_Q_LEARNING_IMPL_HPP

// In case it hasn't been included yet.
#include "q_learning.hpp"

namespace mlpack {
namespace rl {

template <typename EnvironmentType, typename NetworkType, typename UpdaterType,
          typename BehaviorPolicyType, typename ReplayType>
QLearning<EnvironmentType, NetworkType, UpdaterType, BehaviorPolicyType,
    ReplayType>::QLearning(TrainingConfig& config,
                           NetworkType& network,
                           BehaviorPolicyType& policy,
                           ReplayType& replayMethod,
                           UpdaterType updater,
                           EnvironmentType environment):
    config(config),
    learningNetwork(network),
    policy(policy),
    replayMethod(replayMethod),
    updater(std::move(updater)),
#if ENS_VERSION_MAJOR >= 2
    updatePolicy(NULL),
#endif
    environment(std::move(environment)),
    totalSteps(0),
    deterministic(false)
{
  // The target network starts as a copy of the learning network.
  targetNetwork = learningNetwork;

  if (learningNetwork.Parameters().is_empty())
    learningNetwork.ResetParameters();
  targetNetwork.ResetParameters();

#if ENS_VERSION_MAJOR == 1
  this->updater.Initialize(learningNetwork.Parameters().n_rows,
                           learningNetwork.Parameters().n_cols);
#else
  this->updatePolicy = new typename UpdaterType::template
      Policy<arma::mat, arma::mat>(this->updater,
                                   learningNetwork.Parameters().n_rows,
                                   learningNetwork.Parameters().n_cols);
#endif

  // Initialize the target network with the parameters of the learning network.
  targetNetwork.Parameters() = learningNetwork.Parameters();
}
template <typename EnvironmentType, typename NetworkType, typename UpdaterType,
          typename BehaviorPolicyType, typename ReplayType>
QLearning<EnvironmentType, NetworkType, UpdaterType, BehaviorPolicyType,
    ReplayType>::~QLearning()
{
#if ENS_VERSION_MAJOR >= 2
  delete updatePolicy;
#endif
}

template <typename EnvironmentType, typename NetworkType, typename UpdaterType,
          typename BehaviorPolicyType, typename ReplayType>
arma::Col<size_t> QLearning<EnvironmentType, NetworkType, UpdaterType,
    BehaviorPolicyType, ReplayType>::BestAction(const arma::mat& actionValues)
{
  // Each column of actionValues holds the values of one sampled state; take
  // the index of the maximum entry (the greedy action) per column.
  arma::Col<size_t> bestActions(actionValues.n_cols);
  arma::rowvec maxActionValues = arma::max(actionValues, 0);
  for (size_t i = 0; i < actionValues.n_cols; ++i)
  {
    bestActions(i) = arma::as_scalar(
        arma::find(actionValues.col(i) == maxActionValues[i], 1));
  }
  return bestActions;
}
template <typename EnvironmentType, typename NetworkType, typename UpdaterType,
          typename BehaviorPolicyType, typename ReplayType>
void QLearning<EnvironmentType, NetworkType, UpdaterType, BehaviorPolicyType,
    ReplayType>::TrainAgent()
{
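  // Experience replay: learn from a minibatch of stored transitions rather
  // than only the most recent step, which decorrelates consecutive updates.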
  arma::mat sampledStates;
  std::vector<ActionType> sampledActions;
  arma::rowvec sampledRewards;
  arma::mat sampledNextStates;
  arma::irowvec isTerminal;

  replayMethod.Sample(sampledStates, sampledActions, sampledRewards,
      sampledNextStates, isTerminal);
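  // Q-values of the next states come from the separate target network, which
  // is only synced periodically. With double Q-learning, the learning network
  // selects the best next action while the target network still evaluates it,
  // which reduces the overestimation bias of standard DQN.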
  arma::mat nextActionValues;
  targetNetwork.Predict(sampledNextStates, nextActionValues);

  arma::Col<size_t> bestActions;
  if (config.DoubleQLearning())
  {
    // Use the learning network to select the best next action.
    arma::mat nextActionValues;
    learningNetwork.Predict(sampledNextStates, nextActionValues);
    bestActions = BestAction(nextActionValues);
  }
  else
  {
    bestActions = BestAction(nextActionValues);
  }
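  // n-step TD target for the taken action of sample i:
  //   y_i = r_i + discount * Q_target(s'_i, bestActions(i)),
  // with discount = Discount()^NSteps(); the bootstrap term is dropped for
  // terminal transitions. All other entries keep the learning network's own
  // predictions, so only the taken action contributes to the loss.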
  arma::mat target;
  learningNetwork.Forward(sampledStates, target);

  double discount = std::pow(config.Discount(), replayMethod.NSteps());

  for (size_t i = 0; i < sampledNextStates.n_cols; ++i)
  {
    target(sampledActions[i].action, i) = sampledRewards(i) + discount *
        nextActionValues(bestActions(i), i) * (1 - isTerminal[i]);
  }
  // Learn from experience.
  arma::mat gradients;
  learningNetwork.Backward(sampledStates, target, gradients);

  // Let the replay method update its internal state (e.g. the priorities of
  // prioritized experience replay).
  replayMethod.Update(target, sampledActions, nextActionValues, gradients);
#if ENS_VERSION_MAJOR == 1
  updater.Update(learningNetwork.Parameters(), config.StepSize(), gradients);
#else
  updatePolicy->Update(learningNetwork.Parameters(), config.StepSize(),
      gradients);
#endif
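  // Noisy Q-learning (NoisyNet DQN) replaces epsilon-greedy exploration with
  // learned parametric noise in the linear layers; the noise is re-sampled
  // after every update.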
  if (config.NoisyQLearning() == true)
  {
    learningNetwork.ResetNoise();
    targetNetwork.ResetNoise();
  }

  // Periodically sync the target network with the learning network.
  if (totalSteps % config.TargetNetworkSyncInterval() == 0)
    targetNetwork.Parameters() = learningNetwork.Parameters();
}
template <typename EnvironmentType, typename NetworkType, typename UpdaterType,
          typename BehaviorPolicyType, typename ReplayType>
void QLearning<EnvironmentType, NetworkType, UpdaterType, BehaviorPolicyType,
    ReplayType>::TrainCategoricalAgent()
{
  arma::mat sampledStates;
  std::vector<ActionType> sampledActions;
  arma::rowvec sampledRewards;
  arma::mat sampledNextStates;
  arma::irowvec isTerminal;

  replayMethod.Sample(sampledStates, sampledActions, sampledRewards,
      sampledNextStates, isTerminal);
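  // Categorical (C51) DQN: for every action the network predicts a
  // probability distribution over atomSize fixed support points spaced
  // evenly between VMin() and VMax(), instead of a single Q-value.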
  size_t atomSize = config.AtomSize();
  arma::colvec support = arma::linspace<arma::colvec>(config.VMin(),
      config.VMax(), atomSize);

  size_t batchSize = sampledNextStates.n_cols;
  arma::mat nextActionValues;
  targetNetwork.Predict(sampledNextStates, nextActionValues);

  arma::Col<size_t> nextAction;
  if (config.DoubleQLearning())
  {
    // Use the learning network to select the best next action.
    arma::mat nextActionValues;
    learningNetwork.Predict(sampledNextStates, nextActionValues);
    nextAction = BestAction(nextActionValues);
  }
  else
  {
    nextAction = BestAction(nextActionValues);
  }
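  // Each output column stacks the atom probabilities of all actions, so the
  // distribution of action a occupies rows [a * atomSize, (a + 1) * atomSize).
  // Extract the distribution of the chosen next action for every sample.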
  arma::mat nextDists, nextDist(atomSize, batchSize);
  targetNetwork.Forward(sampledNextStates, nextDists);
  for (size_t i = 0; i < batchSize; ++i)
  {
    nextDist.col(i) = nextDists(nextAction(i) * atomSize, i,
        arma::size(atomSize, 1));
  }
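  // Distributional Bellman projection: shift and scale the support
  // (tZ = r + discount * z, with z zeroed for terminal transitions), clamp it
  // to [VMin, VMax], convert it to fractional atom indices b, and split each
  // atom's probability mass between the neighbouring atoms floor(b) and
  // ceil(b).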
  arma::mat tZ = (arma::conv_to<arma::mat>::from(config.Discount() *
      (support * (1 - isTerminal))).each_row() + sampledRewards);
  tZ = arma::clamp(tZ, config.VMin(), config.VMax());
  arma::mat b = (tZ - config.VMin()) / (config.VMax() - config.VMin()) *
      (atomSize - 1);
  arma::mat l = arma::floor(b);
  arma::mat u = arma::ceil(b);
  arma::mat projDistUpper = nextDist % (u - b);
  arma::mat projDistLower = nextDist % (b - l);
  arma::mat projDist = arma::zeros<arma::mat>(arma::size(nextDist));
  for (size_t batchNo = 0; batchNo < batchSize; batchNo++)
  {
    for (size_t j = 0; j < atomSize; j++)
    {
      projDist(l(j, batchNo), batchNo) += projDistUpper(j, batchNo);
      projDist(u(j, batchNo), batchNo) += projDistLower(j, batchNo);
    }
  }
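  // The loss is the cross-entropy between the projected target distribution
  // and the predicted distribution of the taken action; its gradient with
  // respect to the predicted probabilities is -projDist / dists (the 1e-10
  // term guards against division by zero).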
  arma::mat dists;
  learningNetwork.Forward(sampledStates, dists);
  arma::mat lossGradients = arma::zeros<arma::mat>(arma::size(dists));
  for (size_t i = 0; i < batchSize; ++i)
  {
    lossGradients(sampledActions[i].action * atomSize, i,
        arma::size(atomSize, 1)) = -(projDist.col(i) / (1e-10 + dists(
        sampledActions[i].action * atomSize, i, arma::size(atomSize, 1))));
  }
  // Learn from experience.
  arma::mat gradients;
  learningNetwork.Backward(sampledStates, lossGradients, gradients);

#if ENS_VERSION_MAJOR == 1
  updater.Update(learningNetwork.Parameters(), config.StepSize(), gradients);
#else
  updatePolicy->Update(learningNetwork.Parameters(), config.StepSize(),
      gradients);
#endif
  if (config.NoisyQLearning() == true)
  {
    learningNetwork.ResetNoise();
    targetNetwork.ResetNoise();
  }

  // Periodically sync the target network with the learning network.
  if (totalSteps % config.TargetNetworkSyncInterval() == 0)
    targetNetwork.Parameters() = learningNetwork.Parameters();
}
template <typename EnvironmentType, typename NetworkType, typename UpdaterType,
          typename BehaviorPolicyType, typename ReplayType>
void QLearning<EnvironmentType, NetworkType, UpdaterType, BehaviorPolicyType,
    ReplayType>::SelectAction()
{
  // Get the predicted action values for the current state.
  arma::colvec actionValue;
  learningNetwork.Predict(state.Encode(), actionValue);

  // Select an action according to the behavior policy.
  action = policy.Sample(actionValue, deterministic, config.NoisyQLearning());
}
template <typename EnvironmentType, typename NetworkType, typename UpdaterType,
          typename BehaviorPolicyType, typename ReplayType>
double QLearning<EnvironmentType, NetworkType, UpdaterType, BehaviorPolicyType,
    ReplayType>::Episode()
{
  // Get the initial state from the environment.
  state = environment.InitialSample();

  // Track the return of this episode.
  double totalReturn = 0.0;

  // Run until a terminal state is reached.
  while (!environment.IsTerminal(state))
  {
    SelectAction();

    // Interact with the environment to advance to the next state.
    StateType nextState;
    double reward = environment.Sample(state, action, nextState);

    totalReturn += reward;
    totalSteps++;

    // Store the transition for experience replay.
    replayMethod.Store(state, action, reward, nextState,
        environment.IsTerminal(nextState), config.Discount());
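    // Advance to the next state.
    state = nextState;

    // Train only after the exploration phase has passed, and never in
    // deterministic (evaluation) mode.
    if (deterministic || totalSteps < config.ExplorationSteps())
      continue;

    if (config.IsCategorical())
      TrainCategoricalAgent();
    else
      TrainAgent();
  }
  return totalReturn;
}

} // namespace rl
} // namespace mlpack

#endif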