q_learning_impl.hpp
#ifndef MLPACK_METHODS_RL_Q_LEARNING_IMPL_HPP
#define MLPACK_METHODS_RL_Q_LEARNING_IMPL_HPP

#include "q_learning.hpp"

namespace mlpack {
namespace rl {

template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType,
  typename ReplayType
>
QLearning<
  EnvironmentType,
  NetworkType,
  UpdaterType,
  PolicyType,
  ReplayType
>::QLearning(TrainingConfig& config,
             NetworkType& network,
             PolicyType& policy,
             ReplayType& replayMethod,
             UpdaterType updater,
             EnvironmentType environment):
    config(config),
    learningNetwork(network),
    policy(policy),
    replayMethod(replayMethod),
    updater(std::move(updater)),
    #if ENS_VERSION_MAJOR >= 2
    updatePolicy(NULL),
    #endif
    environment(std::move(environment)),
    totalSteps(0),
    deterministic(false)
{
  // Copy over the network structure.
  targetNetwork = learningNetwork;

  // Set up the Q-learning network.
  if (learningNetwork.Parameters().is_empty())
    learningNetwork.ResetParameters();

  targetNetwork.ResetParameters();

  #if ENS_VERSION_MAJOR == 1
  this->updater.Initialize(learningNetwork.Parameters().n_rows,
                           learningNetwork.Parameters().n_cols);
  #else
  this->updatePolicy = new typename UpdaterType::template
      Policy<arma::mat, arma::mat>(this->updater,
                                   learningNetwork.Parameters().n_rows,
                                   learningNetwork.Parameters().n_cols);
  #endif

  // Initialize the target network with the parameters of the learning network.
  targetNetwork.Parameters() = learningNetwork.Parameters();
}

template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType,
  typename ReplayType
>
QLearning<
  EnvironmentType,
  NetworkType,
  UpdaterType,
  PolicyType,
  ReplayType
>::~QLearning()
{
  #if ENS_VERSION_MAJOR >= 2
  delete updatePolicy;
  #endif
}

template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType,
  typename ReplayType
>
arma::Col<size_t> QLearning<
  EnvironmentType,
  NetworkType,
  UpdaterType,
  PolicyType,
  ReplayType
>::BestAction(const arma::mat& actionValues)
{
  // Take the best possible action at a particular instance.
  arma::Col<size_t> bestActions(actionValues.n_cols);
  arma::rowvec maxActionValues = arma::max(actionValues, 0);
  for (size_t i = 0; i < actionValues.n_cols; ++i)
  {
    bestActions(i) = arma::as_scalar(
        arma::find(actionValues.col(i) == maxActionValues[i], 1));
  }
  return bestActions;
}

template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename BehaviorPolicyType,
  typename ReplayType
>
void QLearning<
  EnvironmentType,
  NetworkType,
  UpdaterType,
  BehaviorPolicyType,
  ReplayType
>::TrainAgent()
{
  // Start experience replay: sample a batch of transitions from previous
  // experience.
  arma::mat sampledStates;
  std::vector<ActionType> sampledActions;
  arma::rowvec sampledRewards;
  arma::mat sampledNextStates;
  arma::irowvec isTerminal;

  replayMethod.Sample(sampledStates, sampledActions, sampledRewards,
      sampledNextStates, isTerminal);

  // Compute action values for the next states with the target network.
  arma::mat nextActionValues;
  targetNetwork.Predict(sampledNextStates, nextActionValues);

  arma::Col<size_t> bestActions;
  if (config.DoubleQLearning())
  {
    // If using double Q-learning, use the learning network to select the
    // best action.
    arma::mat nextActionValues;
    learningNetwork.Predict(sampledNextStates, nextActionValues);
    bestActions = BestAction(nextActionValues);
  }
  else
  {
    bestActions = BestAction(nextActionValues);
  }

  // Compute the update target.
  arma::mat target;
  learningNetwork.Forward(sampledStates, target);

  double discount = std::pow(config.Discount(), replayMethod.NSteps());

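  // Form the temporal-difference target for each sampled transition i:
  //
  //   target(a_i, i) = r_i + discount * Q_target(s'_i, bestActions(i))
  //                    * (1 - isTerminal(i))
  //
  // where discount = Discount()^NSteps() (n-step replay), bestActions(i) is
  // chosen by the learning network under double Q-learning and by the target
  // network otherwise, and the (1 - isTerminal) factor drops the bootstrap
  // term for terminal transitions. All other entries of `target` keep the
  // learning network's own predictions, so (with the usual squared-error
  // output) they produce no error signal.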
  for (size_t i = 0; i < sampledNextStates.n_cols; ++i)
  {
    target(sampledActions[i].action, i) = sampledRewards(i) + discount *
        nextActionValues(bestActions(i), i) * (1 - isTerminal[i]);
  }

  // Learn from experience.
  arma::mat gradients;
  learningNetwork.Backward(sampledStates, target, gradients);

  replayMethod.Update(target, sampledActions, nextActionValues, gradients);

  #if ENS_VERSION_MAJOR == 1
  updater.Update(learningNetwork.Parameters(), config.StepSize(), gradients);
  #else
  updatePolicy->Update(learningNetwork.Parameters(), config.StepSize(),
      gradients);
  #endif

  if (config.NoisyQLearning())
  {
    learningNetwork.ResetNoise();
    targetNetwork.ResetNoise();
  }

  // Update target network.
  if (totalSteps % config.TargetNetworkSyncInterval() == 0)
    targetNetwork.Parameters() = learningNetwork.Parameters();

  if (totalSteps > config.ExplorationSteps())
    policy.Anneal();
}

template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename BehaviorPolicyType,
  typename ReplayType
>
void QLearning<
  EnvironmentType,
  NetworkType,
  UpdaterType,
  BehaviorPolicyType,
  ReplayType
>::TrainCategoricalAgent()
{
  // Start experience replay: sample a batch of transitions from previous
  // experience.
  arma::mat sampledStates;
  std::vector<ActionType> sampledActions;
  arma::rowvec sampledRewards;
  arma::mat sampledNextStates;
  arma::irowvec isTerminal;

  replayMethod.Sample(sampledStates, sampledActions, sampledRewards,
      sampledNextStates, isTerminal);

  size_t atomSize = config.AtomSize();
  arma::colvec support = arma::linspace<arma::colvec>(config.VMin(),
      config.VMax(), atomSize);

  size_t batchSize = sampledNextStates.n_cols;

  // Compute action values for the next states with the target network.
  arma::mat nextActionValues;
  targetNetwork.Predict(sampledNextStates, nextActionValues);

  arma::Col<size_t> nextAction;
  if (config.DoubleQLearning())
  {
    // If using double Q-learning, use the learning network to select the
    // best action.
    arma::mat nextActionValues;
    learningNetwork.Predict(sampledNextStates, nextActionValues);
    nextAction = BestAction(nextActionValues);
  }
  else
  {
    nextAction = BestAction(nextActionValues);
  }

  // Gather the predicted return distribution of the greedy next action for
  // each sample in the batch.
  arma::mat nextDists, nextDist(atomSize, batchSize);
  targetNetwork.Forward(sampledNextStates, nextDists);
  for (size_t i = 0; i < batchSize; ++i)
  {
    nextDist.col(i) = nextDists(nextAction(i) * atomSize, i,
        arma::size(atomSize, 1));
  }

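  // Distributional (categorical) update: shift the next-state return
  // distribution by the observed reward and discount, clamp the shifted
  // values to [VMin, VMax], and project them back onto the fixed support of
  // `atomSize` atoms. Each shifted atom lands at a fractional index b between
  // the neighbouring atoms l = floor(b) and u = ceil(b); its probability mass
  // is split between l and u in proportion to how close b is to each.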
  arma::mat tZ = (arma::conv_to<arma::mat>::from(config.Discount() *
      (support * (1 - isTerminal))).each_row() + sampledRewards);
  tZ = arma::clamp(tZ, config.VMin(), config.VMax());
  arma::mat b = (tZ - config.VMin()) / (config.VMax() - config.VMin()) *
      (atomSize - 1);
  arma::mat l = arma::floor(b);
  arma::mat u = arma::ceil(b);

  arma::mat projDistUpper = nextDist % (u - b);
  arma::mat projDistLower = nextDist % (b - l);

  arma::mat projDist = arma::zeros<arma::mat>(arma::size(nextDist));
  for (size_t batchNo = 0; batchNo < batchSize; batchNo++)
  {
    for (size_t j = 0; j < atomSize; j++)
    {
      projDist(l(j, batchNo), batchNo) += projDistUpper(j, batchNo);
      projDist(u(j, batchNo), batchNo) += projDistLower(j, batchNo);
    }
  }

  arma::mat dists;
  learningNetwork.Forward(sampledStates, dists);

  // Gradient of the cross-entropy loss between the projected target
  // distribution and the predicted distribution, written only into the rows
  // of the action that was actually taken.
  arma::mat lossGradients = arma::zeros<arma::mat>(arma::size(dists));
  for (size_t i = 0; i < batchSize; ++i)
  {
    lossGradients(sampledActions[i].action * atomSize, i,
        arma::size(atomSize, 1)) = -(projDist.col(i) / (1e-10 + dists(
        sampledActions[i].action * atomSize, i, arma::size(atomSize, 1))));
  }

  // Learn from experience.
  arma::mat gradients;
  learningNetwork.Backward(sampledStates, lossGradients, gradients);

  #if ENS_VERSION_MAJOR == 1
  updater.Update(learningNetwork.Parameters(), config.StepSize(), gradients);
  #else
  updatePolicy->Update(learningNetwork.Parameters(), config.StepSize(),
      gradients);
  #endif

  if (config.NoisyQLearning())
  {
    learningNetwork.ResetNoise();
    targetNetwork.ResetNoise();
  }

  // Update target network.
  if (totalSteps % config.TargetNetworkSyncInterval() == 0)
    targetNetwork.Parameters() = learningNetwork.Parameters();

  if (totalSteps > config.ExplorationSteps())
    policy.Anneal();
}

template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename BehaviorPolicyType,
  typename ReplayType
>
void QLearning<
  EnvironmentType,
  NetworkType,
  UpdaterType,
  BehaviorPolicyType,
  ReplayType
>::SelectAction()
{
  // Get the action-value estimate for each action at the current state.
  arma::colvec actionValue;
  learningNetwork.Predict(state.Encode(), actionValue);

  // Select an action according to the behavior policy.
  action = policy.Sample(actionValue, deterministic, config.NoisyQLearning());
}

template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename BehaviorPolicyType,
  typename ReplayType
>
double QLearning<
  EnvironmentType,
  NetworkType,
  UpdaterType,
  BehaviorPolicyType,
  ReplayType
>::Episode()
{
  // Get the initial state from the environment.
  state = environment.InitialSample();

  // Track the return of this episode.
  double totalReturn = 0.0;

  // Run until a terminal state is reached.
  while (!environment.IsTerminal(state))
  {
    SelectAction();

    // Interact with the environment to advance to the next state.
    StateType nextState;
    double reward = environment.Sample(state, action, nextState);

    totalReturn += reward;
    totalSteps++;

    // Store the transition for replay.
    replayMethod.Store(state, action, reward, nextState,
        environment.IsTerminal(nextState), config.Discount());

    // Update the current state.
    state = nextState;

    // Start training only after the initial exploration phase, and never in
    // deterministic (evaluation) mode.
    if (deterministic || totalSteps < config.ExplorationSteps())
      continue;

    if (config.IsCategorical())
      TrainCategoricalAgent();
    else
      TrainAgent();
  }
  return totalReturn;
}

} // namespace rl
} // namespace mlpack

#endif
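Taken together, these methods form the standard DQN loop: Episode() drives the environment, stores transitions in the replay buffer, and calls TrainAgent() or TrainCategoricalAgent() once the exploration phase is over. A rough usage sketch follows; it assumes the CartPole environment, GreedyPolicy, RandomReplay, the FFN-based ANN API, and ensmallen's AdamUpdate that ship alongside this version of the class, and the layer sizes and hyperparameter values are illustrative rather than recommended.

#include <mlpack/core.hpp>
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/loss_functions/mean_squared_error.hpp>
#include <mlpack/methods/ann/init_rules/gaussian_init.hpp>
#include <mlpack/methods/reinforcement_learning/q_learning.hpp>
#include <mlpack/methods/reinforcement_learning/environment/cart_pole.hpp>
#include <mlpack/methods/reinforcement_learning/policy/greedy_policy.hpp>
#include <mlpack/methods/reinforcement_learning/replay/random_replay.hpp>
#include <mlpack/methods/reinforcement_learning/training_config.hpp>
#include <ensmallen.hpp>

using namespace mlpack;
using namespace mlpack::ann;
using namespace mlpack::rl;

int main()
{
  // Q-network mapping CartPole's 4 state values to 2 action values.
  FFN<MeanSquaredError<>, GaussianInitialization> network(
      MeanSquaredError<>(), GaussianInitialization(0, 0.001));
  network.Add<Linear<>>(4, 64);
  network.Add<ReLULayer<>>();
  network.Add<Linear<>>(64, 2);

  // Epsilon-greedy behavior policy and uniform experience replay.
  GreedyPolicy<CartPole> policy(1.0, 1000, 0.1);
  RandomReplay<CartPole> replayMethod(32, 10000);

  TrainingConfig config;
  config.StepSize() = 0.01;
  config.Discount() = 0.99;
  config.TargetNetworkSyncInterval() = 100;
  config.ExplorationSteps() = 100;

  // Assemble the agent and run training episodes.
  QLearning<CartPole, decltype(network), ens::AdamUpdate, decltype(policy)>
      agent(config, network, policy, replayMethod);

  for (size_t episode = 0; episode < 200; ++episode)
    agent.Episode();

  return 0;
}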