mlpack
sac_impl.hpp
#ifndef MLPACK_METHODS_RL_SAC_IMPL_HPP
#define MLPACK_METHODS_RL_SAC_IMPL_HPP

#include <mlpack/prereqs.hpp>

#include "sac.hpp"

namespace mlpack {
namespace rl {

template <
  typename EnvironmentType,
  typename QNetworkType,
  typename PolicyNetworkType,
  typename UpdaterType,
  typename ReplayType
>
SAC<
  EnvironmentType,
  QNetworkType,
  PolicyNetworkType,
  UpdaterType,
  ReplayType
>::SAC(TrainingConfig& config,
       QNetworkType& learningQ1Network,
       PolicyNetworkType& policyNetwork,
       ReplayType& replayMethod,
       UpdaterType qNetworkUpdater,
       UpdaterType policyNetworkUpdater,
       EnvironmentType environment):
    config(config),
    learningQ1Network(learningQ1Network),
    policyNetwork(policyNetwork),
    replayMethod(replayMethod),
    qNetworkUpdater(std::move(qNetworkUpdater)),
    #if ENS_VERSION_MAJOR >= 2
    qNetworkUpdatePolicy(NULL),
    #endif
    policyNetworkUpdater(std::move(policyNetworkUpdater)),
    #if ENS_VERSION_MAJOR >= 2
    policyNetworkUpdatePolicy(NULL),
    #endif
    environment(std::move(environment)),
    totalSteps(0),
    deterministic(false)
{
  // Set up q-learning and policy networks.
  targetQ1Network = learningQ1Network;
  learningQ2Network = learningQ1Network;
  targetQ2Network = learningQ2Network;

  // Reset all the networks.
  // Note: the Q and policy networks are only reset if they are empty, so that
  // a loaded (possibly pre-trained) model passed to this constructor is not
  // overwritten.
  if (learningQ1Network.Parameters().is_empty())
  {
    learningQ1Network.ResetParameters();
    learningQ2Network.ResetParameters();
  }
  if (policyNetwork.Parameters().is_empty())
    policyNetwork.ResetParameters();
  targetQ1Network.ResetParameters();
  targetQ2Network.ResetParameters();

  #if ENS_VERSION_MAJOR == 1
  this->qNetworkUpdater.Initialize(learningQ1Network.Parameters().n_rows,
      learningQ1Network.Parameters().n_cols);
  #else
  this->qNetworkUpdatePolicy = new typename UpdaterType::template
      Policy<arma::mat, arma::mat>(this->qNetworkUpdater,
                                   learningQ1Network.Parameters().n_rows,
                                   learningQ1Network.Parameters().n_cols);
  #endif

  #if ENS_VERSION_MAJOR == 1
  this->policyNetworkUpdater.Initialize(policyNetwork.Parameters().n_rows,
      policyNetwork.Parameters().n_cols);
  #else
  this->policyNetworkUpdatePolicy = new typename UpdaterType::template
      Policy<arma::mat, arma::mat>(this->policyNetworkUpdater,
                                   policyNetwork.Parameters().n_rows,
                                   policyNetwork.Parameters().n_cols);
  #endif

  // Copy over the learning networks to their respective target networks.
  targetQ1Network.Parameters() = learningQ1Network.Parameters();
  targetQ2Network.Parameters() = learningQ2Network.Parameters();
}

template <
  typename EnvironmentType,
  typename QNetworkType,
  typename PolicyNetworkType,
  typename UpdaterType,
  typename ReplayType
>
SAC<
  EnvironmentType,
  QNetworkType,
  PolicyNetworkType,
  UpdaterType,
  ReplayType
>::~SAC()
{
  #if ENS_VERSION_MAJOR >= 2
  delete qNetworkUpdatePolicy;
  delete policyNetworkUpdatePolicy;
  #endif
}

template <
  typename EnvironmentType,
  typename QNetworkType,
  typename PolicyNetworkType,
  typename UpdaterType,
  typename ReplayType
>
void SAC<
  EnvironmentType,
  QNetworkType,
  PolicyNetworkType,
  UpdaterType,
  ReplayType
>::SoftUpdate(double rho)
{
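  // Polyak averaging: each target parameter moves a fraction rho toward the
  // corresponding learning-network parameter and keeps the rest of its value.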
  targetQ1Network.Parameters() = (1 - rho) * targetQ1Network.Parameters() +
      rho * learningQ1Network.Parameters();
  targetQ2Network.Parameters() = (1 - rho) * targetQ2Network.Parameters() +
      rho * learningQ2Network.Parameters();
}

template <
  typename EnvironmentType,
  typename QNetworkType,
  typename PolicyNetworkType,
  typename UpdaterType,
  typename ReplayType
>
void SAC<
  EnvironmentType,
  QNetworkType,
  PolicyNetworkType,
  UpdaterType,
  ReplayType
>::Update()
{
  // Sample from previous experience.
  arma::mat sampledStates;
  std::vector<ActionType> sampledActions;
  arma::rowvec sampledRewards;
  arma::mat sampledNextStates;
  arma::irowvec isTerminal;

  replayMethod.Sample(sampledStates, sampledActions, sampledRewards,
      sampledNextStates, isTerminal);

  // Critic network update.

  // Get the actions for the sampled next states from the policy.
  arma::mat nextStateActions;
  policyNetwork.Predict(sampledNextStates, nextStateActions);

  arma::mat targetQInput = arma::join_vert(nextStateActions,
      sampledNextStates);
  arma::rowvec Q1, Q2;
  targetQ1Network.Predict(targetQInput, Q1);
  targetQ2Network.Predict(targetQInput, Q2);
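  // Clipped double-Q target: nextQ = reward + discount * (1 - terminal) *
  // min(Q1, Q2), with both Q values taken from the target networks at the
  // policy's action for the next state.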
  arma::rowvec nextQ = sampledRewards + config.Discount() * ((1 - isTerminal)
      % arma::min(Q1, Q2));

  arma::mat sampledActionValues(action.size, sampledActions.size());
  for (size_t i = 0; i < sampledActions.size(); i++)
    sampledActionValues.col(i) = arma::conv_to<arma::colvec>::from
        (sampledActions[i].action);
  arma::mat learningQInput = arma::join_vert(sampledActionValues,
      sampledStates);
  learningQ1Network.Forward(learningQInput, Q1);
  learningQ2Network.Forward(learningQInput, Q2);

  arma::mat gradQ1Loss, gradQ2Loss;
  lossFunction.Backward(Q1, nextQ, gradQ1Loss);
  lossFunction.Backward(Q2, nextQ, gradQ2Loss);

  // Update the critic networks.
  arma::mat gradientQ1, gradientQ2;
  learningQ1Network.Backward(learningQInput, gradQ1Loss, gradientQ1);
  #if ENS_VERSION_MAJOR == 1
  qNetworkUpdater.Update(learningQ1Network.Parameters(), config.StepSize(),
      gradientQ1);
  #else
  qNetworkUpdatePolicy->Update(learningQ1Network.Parameters(),
      config.StepSize(), gradientQ1);
  #endif
  learningQ2Network.Backward(learningQInput, gradQ2Loss, gradientQ2);
  #if ENS_VERSION_MAJOR == 1
  qNetworkUpdater.Update(learningQ2Network.Parameters(), config.StepSize(),
      gradientQ2);
  #else
  qNetworkUpdatePolicy->Update(learningQ2Network.Parameters(),
      config.StepSize(), gradientQ2);
  #endif

  // Actor network update.

  arma::mat pi;
  policyNetwork.Predict(sampledStates, pi);

  arma::mat qInput = arma::join_vert(pi, sampledStates);
  learningQ1Network.Predict(qInput, Q1);
  learningQ2Network.Predict(qInput, Q2);

  // Get the size of the first hidden layer in the Q network.
  size_t hidden1 = boost::get<mlpack::ann::Linear<> *>
      (learningQ1Network.Model()[0])->OutputSize();

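  // Policy gradient, computed per sample through whichever learning Q network
  // currently gives the smaller estimate: calling Backward() with an output
  // gradient of -1 yields the parameter gradients for maximizing Q; the first
  // layer's bias gradient equals the gradient at that layer's pre-activations,
  // and mapping it back through the action block of the first Linear layer's
  // weights gives dQ/d(action), which is then backpropagated through the
  // policy network.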
  arma::mat gradient;
  for (size_t i = 0; i < sampledStates.n_cols; i++)
  {
    arma::mat grad, gradQ, q;
    arma::colvec singleState = sampledStates.col(i);
    arma::colvec singlePi;
    policyNetwork.Forward(singleState, singlePi);
    arma::colvec input = arma::join_vert(singlePi, singleState);
    arma::mat weightLastLayer;

    if (Q1(i) < Q2(i))
    {
      learningQ1Network.Forward(input, q);
      learningQ1Network.Backward(input, -1, gradQ);
      weightLastLayer = arma::reshape(learningQ1Network.Parameters().
          rows(0, hidden1 * singlePi.n_rows - 1), hidden1, singlePi.n_rows);
    }
    else
    {
      learningQ2Network.Forward(input, q);
      learningQ2Network.Backward(input, -1, gradQ);
      weightLastLayer = arma::reshape(learningQ2Network.Parameters().
          rows(0, hidden1 * singlePi.n_rows - 1), hidden1, singlePi.n_rows);
    }

    arma::colvec gradQBias = gradQ(input.n_rows * hidden1, 0,
        arma::size(hidden1, 1));
    arma::mat gradPolicy = weightLastLayer.t() * gradQBias;
    policyNetwork.Backward(singleState, gradPolicy, grad);
    if (i == 0)
    {
      gradient.copy_size(grad);
      gradient.fill(0.0);
    }
    gradient += grad;
  }
  gradient /= sampledStates.n_cols;

  #if ENS_VERSION_MAJOR == 1
  policyNetworkUpdater.Update(policyNetwork.Parameters(), config.StepSize(),
      gradient);
  #else
  policyNetworkUpdatePolicy->Update(policyNetwork.Parameters(),
      config.StepSize(), gradient);
  #endif

  // Update the target networks.
  if (totalSteps % config.TargetNetworkSyncInterval() == 0)
    SoftUpdate(config.Rho());
}

template <
  typename EnvironmentType,
  typename QNetworkType,
  typename PolicyNetworkType,
  typename UpdaterType,
  typename ReplayType
>
void SAC<
  EnvironmentType,
  QNetworkType,
  PolicyNetworkType,
  UpdaterType,
  ReplayType
>::SelectAction()
{
  // Get the action at the current state from the policy.
  arma::colvec outputAction;
  policyNetwork.Predict(state.Encode(), outputAction);

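  // In training (non-deterministic) mode, add zero-mean Gaussian exploration
  // noise with standard deviation 0.1, clipped to [-0.25, 0.25].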
  if (!deterministic)
  {
    arma::colvec noise = arma::randn<arma::colvec>(outputAction.n_rows) * 0.1;
    noise = arma::clamp(noise, -0.25, 0.25);
    outputAction = outputAction + noise;
  }
  action.action = arma::conv_to<std::vector<double>>::from(outputAction);
}

template <
  typename EnvironmentType,
  typename QNetworkType,
  typename PolicyNetworkType,
  typename UpdaterType,
  typename ReplayType
>
double SAC<
  EnvironmentType,
  QNetworkType,
  PolicyNetworkType,
  UpdaterType,
  ReplayType
>::Episode()
{
  // Get the initial state from the environment.
  state = environment.InitialSample();

  // Track the steps in this episode.
  size_t steps = 0;

  // Track the return of this episode.
  double totalReturn = 0.0;

  // Run until a terminal state is reached.
  while (!environment.IsTerminal(state))
  {
    if (config.StepLimit() && steps >= config.StepLimit())
      break;
    SelectAction();

    // Interact with the environment to advance to the next state.
    StateType nextState;
    double reward = environment.Sample(state, action, nextState);

    totalReturn += reward;
    steps++;
    totalSteps++;

    // Store the transition for replay.
    replayMethod.Store(state, action, reward, nextState,
        environment.IsTerminal(nextState), config.Discount());

    // Update the current state.
    state = nextState;

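    // Gradient updates start only after ExplorationSteps() environment steps
    // have been collected, and never in deterministic (evaluation) mode; each
    // subsequent environment step then performs UpdateInterval() updates.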
    if (deterministic || totalSteps < config.ExplorationSteps())
      continue;
    for (size_t i = 0; i < config.UpdateInterval(); i++)
      Update();
  }
  return totalReturn;
}

} // namespace rl
} // namespace mlpack
#endif
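
For reference, below is a minimal sketch of how this class is typically constructed and trained. It assumes the mlpack 3.x ANN and reinforcement learning headers (FFN, EmptyLoss, GaussianInitialization, RandomReplay, the Pendulum environment) and ensmallen's AdamUpdate; the layer sizes and hyperparameter values are illustrative only and are not taken from this file.

#include <mlpack/core.hpp>
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/init_rules/gaussian_init.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/loss_functions/empty_loss.hpp>
#include <mlpack/methods/reinforcement_learning/sac.hpp>
#include <mlpack/methods/reinforcement_learning/environment/pendulum.hpp>
#include <mlpack/methods/reinforcement_learning/training_config.hpp>
#include <ensmallen.hpp>

using namespace mlpack::ann;
using namespace mlpack::rl;

int main()
{
  // Policy network: maps the 3-dimensional Pendulum state to a
  // 1-dimensional action in [-1, 1].
  FFN<EmptyLoss<>, GaussianInitialization>
      policyNetwork(EmptyLoss<>(), GaussianInitialization(0, 0.1));
  policyNetwork.Add(new Linear<>(3, 128));
  policyNetwork.Add(new ReLULayer<>());
  policyNetwork.Add(new Linear<>(128, 1));
  policyNetwork.Add(new TanHLayer<>());

  // Q network: takes the concatenated (action, state) input, as in Update().
  FFN<EmptyLoss<>, GaussianInitialization>
      qNetwork(EmptyLoss<>(), GaussianInitialization(0, 0.1));
  qNetwork.Add(new Linear<>(3 + 1, 128));
  qNetwork.Add(new ReLULayer<>());
  qNetwork.Add(new Linear<>(128, 1));

  // Replay buffer: batch size 32, capacity 10000 (illustrative values).
  RandomReplay<Pendulum> replayMethod(32, 10000);

  TrainingConfig config;
  config.StepSize() = 0.001;
  config.Discount() = 0.99;
  config.TargetNetworkSyncInterval() = 1;
  config.UpdateInterval() = 3;
  config.Rho() = 0.001;
  config.ExplorationSteps() = 100;

  // Construct the agent, passing the updaters and environment explicitly.
  SAC<Pendulum, decltype(qNetwork), decltype(policyNetwork), ens::AdamUpdate,
      RandomReplay<Pendulum>>
      agent(config, qNetwork, policyNetwork, replayMethod,
            ens::AdamUpdate(), ens::AdamUpdate(), Pendulum());

  // Run a few training episodes; Episode() returns the undiscounted return.
  for (size_t episode = 0; episode < 100; ++episode)
    agent.Episode();

  return 0;
}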