#ifndef MLPACK_METHODS_RL_SAC_IMPL_HPP
#define MLPACK_METHODS_RL_SAC_IMPL_HPP

// In case it hasn't been included yet.
#include "sac.hpp"

template <
  typename EnvironmentType,
  typename QNetworkType,
  typename PolicyNetworkType,
  typename UpdaterType,
  typename ReplayType
>
SAC<EnvironmentType, QNetworkType, PolicyNetworkType, UpdaterType,
    ReplayType>::SAC(TrainingConfig& config,
                     QNetworkType& learningQ1Network,
                     PolicyNetworkType& policyNetwork,
                     ReplayType& replayMethod,
                     UpdaterType qNetworkUpdater,
                     UpdaterType policyNetworkUpdater,
                     EnvironmentType environment):
    config(config),
    learningQ1Network(learningQ1Network),
    policyNetwork(policyNetwork),
    replayMethod(replayMethod),
    qNetworkUpdater(std::move(qNetworkUpdater)),
#if ENS_VERSION_MAJOR >= 2
    qNetworkUpdatePolicy(NULL),
#endif
    policyNetworkUpdater(std::move(policyNetworkUpdater)),
#if ENS_VERSION_MAJOR >= 2
    policyNetworkUpdatePolicy(NULL),
#endif
    environment(std::move(environment)),
    totalSteps(0)
{
  // Set up the second learning Q network and the target networks as copies
  // of the given Q network.
  targetQ1Network = learningQ1Network;
  learningQ2Network = learningQ1Network;
  targetQ2Network = learningQ2Network;

  // Reset the networks only if they are still uninitialized, so that a
  // pre-trained model passed to this constructor is not overwritten.
  if (learningQ1Network.Parameters().is_empty())
  {
    learningQ1Network.ResetParameters();
    learningQ2Network.ResetParameters();
  }
  if (policyNetwork.Parameters().is_empty())
    policyNetwork.ResetParameters();

  targetQ1Network.ResetParameters();
  targetQ2Network.ResetParameters();
#if ENS_VERSION_MAJOR == 1
  this->qNetworkUpdater.Initialize(learningQ1Network.Parameters().n_rows,
                                   learningQ1Network.Parameters().n_cols);
#else
  this->qNetworkUpdatePolicy = new typename UpdaterType::template
      Policy<arma::mat, arma::mat>(this->qNetworkUpdater,
                                   learningQ1Network.Parameters().n_rows,
                                   learningQ1Network.Parameters().n_cols);
#endif
#if ENS_VERSION_MAJOR == 1
  this->policyNetworkUpdater.Initialize(policyNetwork.Parameters().n_rows,
                                        policyNetwork.Parameters().n_cols);
#else
  this->policyNetworkUpdatePolicy = new typename UpdaterType::template
      Policy<arma::mat, arma::mat>(this->policyNetworkUpdater,
                                   policyNetwork.Parameters().n_rows,
                                   policyNetwork.Parameters().n_cols);
#endif
  // Copy the learning network parameters over to the target networks.
  targetQ1Network.Parameters() = learningQ1Network.Parameters();
  targetQ2Network.Parameters() = learningQ2Network.Parameters();
}
template <
  typename EnvironmentType,
  typename QNetworkType,
  typename PolicyNetworkType,
  typename UpdaterType,
  typename ReplayType
>
SAC<EnvironmentType, QNetworkType, PolicyNetworkType, UpdaterType,
    ReplayType>::~SAC()
{
#if ENS_VERSION_MAJOR >= 2
  // The update policies were allocated with new in the constructor.
  delete qNetworkUpdatePolicy;
  delete policyNetworkUpdatePolicy;
#endif
}
template <
  typename EnvironmentType,
  typename QNetworkType,
  typename PolicyNetworkType,
  typename UpdaterType,
  typename ReplayType
>
void SAC<EnvironmentType, QNetworkType, PolicyNetworkType, UpdaterType,
    ReplayType>::SoftUpdate(double rho)
{
  // Polyak-average the learning network parameters into the target networks.
  targetQ1Network.Parameters() = (1 - rho) * targetQ1Network.Parameters() +
      rho * learningQ1Network.Parameters();
  targetQ2Network.Parameters() = (1 - rho) * targetQ2Network.Parameters() +
      rho * learningQ2Network.Parameters();
}
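// SoftUpdate() is plain Polyak averaging of the target parameters.  Writing
// theta for the learning parameters and theta' for the target parameters,
//
//     theta' <- (1 - rho) * theta' + rho * theta,
//
// so a small rho (the value supplied at the call site, e.g. config.Rho())
// makes the targets track the learning networks slowly, which keeps the TD
// targets stable.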
template <
  typename EnvironmentType,
  typename QNetworkType,
  typename PolicyNetworkType,
  typename UpdaterType,
  typename ReplayType
>
void SAC<EnvironmentType, QNetworkType, PolicyNetworkType, UpdaterType,
    ReplayType>::Update()
{
  // Sample a minibatch from previous experience.
  arma::mat sampledStates;
  std::vector<ActionType> sampledActions;
  arma::rowvec sampledRewards;
  arma::mat sampledNextStates;
  arma::irowvec isTerminal;

  replayMethod.Sample(sampledStates, sampledActions, sampledRewards,
      sampledNextStates, isTerminal);
  // Critic network update: build the TD targets from the target Q networks.

  // Get the actions for the sampled next states from the policy.
  arma::mat nextStateActions;
  policyNetwork.Predict(sampledNextStates, nextStateActions);

  arma::mat targetQInput = arma::join_vert(nextStateActions,
      sampledNextStates);
  arma::rowvec Q1, Q2;
  targetQ1Network.Predict(targetQInput, Q1);
  targetQ2Network.Predict(targetQInput, Q2);
  arma::rowvec nextQ = sampledRewards + config.Discount() * ((1 - isTerminal)
      % arma::min(Q1, Q2));
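// Written out, the target computed above is the clipped double-Q TD target:
// for a sampled transition (s, a, r, s', d),
//
//     y = r + gamma * (1 - d) * min(Q1'(s', pi(s')), Q2'(s', pi(s'))),
//
// with gamma = config.Discount() and Q1', Q2' the target networks.  Note that
// this target does not include the entropy bonus -alpha * log pi(a'|s') of the
// full soft actor-critic objective; the critics here are trained on the plain
// clipped double-Q target.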
  // Stack the sampled actions into a matrix so they can be joined with the
  // sampled states as critic input.
  arma::mat sampledActionValues(action.size, sampledActions.size());
  for (size_t i = 0; i < sampledActions.size(); i++)
    sampledActionValues.col(i) = arma::conv_to<arma::colvec>::from(
        sampledActions[i].action);
  arma::mat learningQInput = arma::join_vert(sampledActionValues,
      sampledStates);
  learningQ1Network.Forward(learningQInput, Q1);
  learningQ2Network.Forward(learningQInput, Q2);
  arma::mat gradQ1Loss, gradQ2Loss;
  lossFunction.Backward(Q1, nextQ, gradQ1Loss);
  lossFunction.Backward(Q2, nextQ, gradQ2Loss);
  // Update the first critic network.
  arma::mat gradientQ1, gradientQ2;
  learningQ1Network.Backward(learningQInput, gradQ1Loss, gradientQ1);
#if ENS_VERSION_MAJOR == 1
  qNetworkUpdater.Update(learningQ1Network.Parameters(), config.StepSize(),
      gradientQ1);
#else
  qNetworkUpdatePolicy->Update(learningQ1Network.Parameters(),
      config.StepSize(), gradientQ1);
#endif

  // Update the second critic network.
  learningQ2Network.Backward(learningQInput, gradQ2Loss, gradientQ2);
#if ENS_VERSION_MAJOR == 1
  qNetworkUpdater.Update(learningQ2Network.Parameters(), config.StepSize(),
      gradientQ2);
#else
  qNetworkUpdatePolicy->Update(learningQ2Network.Parameters(),
      config.StepSize(), gradientQ2);
#endif
  // Actor (policy) network update.
  arma::mat pi;
  policyNetwork.Predict(sampledStates, pi);

  arma::mat qInput = arma::join_vert(pi, sampledStates);
  learningQ1Network.Predict(qInput, Q1);
  learningQ2Network.Predict(qInput, Q2);

  // Number of output units of the first Linear layer of the Q network; used
  // below to slice out the weights that connect the action input to it.
  size_t hidden1 = boost::get<mlpack::ann::Linear<>*>(
      learningQ1Network.Model()[0])->OutputSize();
  arma::mat gradient;
  for (size_t i = 0; i < sampledStates.n_cols; i++)
  {
    arma::mat grad, gradQ, q;
    arma::colvec singleState = sampledStates.col(i);
    arma::colvec singlePi;
    policyNetwork.Forward(singleState, singlePi);
    arma::colvec input = arma::join_vert(singlePi, singleState);
    arma::mat weightLastLayer;

    // Differentiate through whichever critic currently gives the smaller
    // Q-value for this state-action pair.
    if (Q1(i) < Q2(i))
    {
      learningQ1Network.Forward(input, q);
      learningQ1Network.Backward(input, -1, gradQ);
      weightLastLayer = arma::reshape(learningQ1Network.Parameters().
          rows(0, hidden1 * singlePi.n_rows - 1), hidden1, singlePi.n_rows);
    }
    else
    {
      learningQ2Network.Forward(input, q);
      learningQ2Network.Backward(input, -1, gradQ);
      weightLastLayer = arma::reshape(learningQ2Network.Parameters().
          rows(0, hidden1 * singlePi.n_rows - 1), hidden1, singlePi.n_rows);
    }

    // Gradient at the first hidden layer (the bias gradient), and from it the
    // gradient of the critic output with respect to the action input.
    arma::colvec gradQBias = gradQ(input.n_rows * hidden1, 0,
        arma::size(hidden1, 1));
    arma::mat gradPolicy = weightLastLayer.t() * gradQBias;
    policyNetwork.Backward(singleState, gradPolicy, grad);
    if (i == 0)
    {
      gradient.copy_size(grad);
      gradient.fill(0.0);
    }
    gradient += grad;
  }
  gradient /= sampledStates.n_cols;
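// Taken together, the loop above accumulates a deterministic-policy-style
// gradient estimate over the minibatch:
//
//     grad_theta J ~ (1/N) * sum_i dQ(s_i, a)/da |_{a = pi(s_i)} * dpi(s_i)/dtheta,
//
// where Q is whichever learning critic gives the smaller value for sample i.
// The dQ/da factor is recovered by hand: gradQBias is the gradient at the
// first Linear layer's output, and weightLastLayer holds the weights that
// connect the action part of the critic input to that layer, so
// weightLastLayer.t() * gradQBias is the gradient with respect to the action.
// Because Backward() is driven with an upstream gradient of -1, the
// accumulated gradient is that of -Q, and the descent-style updater below
// therefore ascends Q.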
  // Update the policy network.
#if ENS_VERSION_MAJOR == 1
  policyNetworkUpdater.Update(policyNetwork.Parameters(), config.StepSize(),
      gradient);
#else
  policyNetworkUpdatePolicy->Update(policyNetwork.Parameters(),
      config.StepSize(), gradient);
#endif

  // Periodically move the target networks towards the learning networks.
  if (totalSteps % config.TargetNetworkSyncInterval() == 0)
    SoftUpdate(config.Rho());
}
template <
  typename EnvironmentType,
  typename QNetworkType,
  typename PolicyNetworkType,
  typename UpdaterType,
  typename ReplayType
>
void SAC<EnvironmentType, QNetworkType, PolicyNetworkType, UpdaterType,
    ReplayType>::SelectAction()
{
  // Get the action for the current state from the policy network.
  arma::colvec outputAction;
  policyNetwork.Predict(state.Encode(), outputAction);

  // Add clipped Gaussian exploration noise to the action.
  arma::colvec noise = arma::randn<arma::colvec>(outputAction.n_rows) * 0.1;
  noise = arma::clamp(noise, -0.25, 0.25);
  outputAction = outputAction + noise;

  action.action = arma::conv_to<std::vector<double>>::from(outputAction);
}
template <
  typename EnvironmentType,
  typename QNetworkType,
  typename PolicyNetworkType,
  typename UpdaterType,
  typename ReplayType
>
double SAC<EnvironmentType, QNetworkType, PolicyNetworkType, UpdaterType,
    ReplayType>::Episode()
{
  // Get the initial state from the environment.
  state = environment.InitialSample();

  // Track the return of this episode.
  double totalReturn = 0.0;

  // Run until a terminal state is reached.
  while (!environment.IsTerminal(state))
  {
    SelectAction();

    // Interact with the environment to advance to the next state.
    StateType nextState;
    double reward = environment.Sample(state, action, nextState);
    totalReturn += reward;
    totalSteps++;

    // Store the transition for experience replay.
    replayMethod.Store(state, action, reward, nextState,
        environment.IsTerminal(nextState), config.Discount());

    // Move on to the next state.
    state = nextState;

    // Train the networks once the exploration phase is over.
    if (totalSteps >= config.ExplorationSteps())
    {
      for (size_t i = 0; i < config.UpdateInterval(); i++)
        Update();
    }
  }
  return totalReturn;
}

#endif // MLPACK_METHODS_RL_SAC_IMPL_HPP
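// ---------------------------------------------------------------------------
// A minimal usage sketch, not part of this header.  It is written against the
// mlpack 3.x API and rests on a few assumptions that should be checked for a
// given version: the FFN / Linear / ReLULayer / TanHLayer / EmptyLoss /
// GaussianInitialization names from mlpack::ann, RandomReplay, ens::AdamUpdate
// as the updater, default-constructed updater and environment arguments being
// provided by the declaration in sac.hpp, and a hypothetical continuous
// environment type EnvType with state dimension stateDim and action dimension
// actionDim.
//
//   // Critic: consumes join_vert(action, state) and emits a single Q-value.
//   FFN<EmptyLoss<>, GaussianInitialization> qNetwork(
//       EmptyLoss<>(), GaussianInitialization(0, 0.1));
//   qNetwork.Add(new Linear<>(stateDim + actionDim, 128));
//   qNetwork.Add(new ReLULayer<>());
//   qNetwork.Add(new Linear<>(128, 1));
//
//   // Actor: maps a state to a bounded continuous action.
//   FFN<EmptyLoss<>, GaussianInitialization> policyNetwork(
//       EmptyLoss<>(), GaussianInitialization(0, 0.1));
//   policyNetwork.Add(new Linear<>(stateDim, 128));
//   policyNetwork.Add(new ReLULayer<>());
//   policyNetwork.Add(new Linear<>(128, actionDim));
//   policyNetwork.Add(new TanHLayer<>());
//
//   RandomReplay<EnvType> replayMethod(32, 10000);  // Batch size, capacity.
//
//   TrainingConfig config;
//   config.StepSize() = 0.001;
//   config.Discount() = 0.99;
//   config.TargetNetworkSyncInterval() = 1;
//   config.UpdateInterval() = 1;
//   config.Rho() = 0.005;
//   config.ExplorationSteps() = 1000;
//
//   SAC<EnvType, decltype(qNetwork), decltype(policyNetwork), ens::AdamUpdate>
//       agent(config, qNetwork, policyNetwork, replayMethod);
//
//   for (size_t episode = 0; episode < 100; ++episode)
//   {
//     const double episodeReturn = agent.Episode();
//     // ... log or evaluate episodeReturn ...
//   }
// ---------------------------------------------------------------------------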