13 #ifndef MLPACK_METHODS_RL_SAC_HPP 14 #define MLPACK_METHODS_RL_SAC_HPP 17 #include <ensmallen.hpp> 58 typename EnvironmentType,
59 typename QNetworkType,
60 typename PolicyNetworkType,
62 typename ReplayType = RandomReplay<EnvironmentType>
90 QNetworkType& learningQ1Network,
91 PolicyNetworkType& policyNetwork,
92 ReplayType& replayMethod,
93 UpdaterType qNetworkUpdater = UpdaterType(),
94 UpdaterType policyNetworkUpdater = UpdaterType(),
95 EnvironmentType environment = EnvironmentType());
150 QNetworkType& learningQ1Network;
151 QNetworkType learningQ2Network;
154 QNetworkType targetQ1Network;
155 QNetworkType targetQ2Network;
158 PolicyNetworkType& policyNetwork;
161 ReplayType& replayMethod;
164 UpdaterType qNetworkUpdater;
165 #if ENS_VERSION_MAJOR >= 2 166 typename UpdaterType::template Policy<arma::mat, arma::mat>*
167 qNetworkUpdatePolicy;
171 UpdaterType policyNetworkUpdater;
172 #if ENS_VERSION_MAJOR >= 2 173 typename UpdaterType::template Policy<arma::mat, arma::mat>*
174 policyNetworkUpdatePolicy;
178 EnvironmentType environment;
const size_t & TotalSteps() const
Get total steps from beginning.
Definition: sac.hpp:129
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Implementation of Soft Actor-Critic, a model-free off-policy actor-critic based deep reinforcement learning algorithm.
Definition: sac.hpp:64
The core includes that mlpack expects; standard C++ includes and Armadillo.
typename EnvironmentType::State StateType
Convenient typedef for state.
Definition: sac.hpp:68
void SoftUpdate(double rho)
Softly update the learning Q network parameters to the target Q network parameters.
Definition: sac_impl.hpp:137
typename EnvironmentType::Action ActionType
Convenient typedef for action.
Definition: sac.hpp:71
void Update()
Update the Q and policy networks.
Definition: sac_impl.hpp:158
Definition: training_config.hpp:19
bool & Deterministic()
Modify the training mode / test mode indicator.
Definition: sac.hpp:140
const bool & Deterministic() const
Get the indicator of training mode / test mode.
Definition: sac.hpp:142
~SAC()
Clean memory.
Definition: sac_impl.hpp:116
SAC(TrainingConfig &config, QNetworkType &learningQ1Network, PolicyNetworkType &policyNetwork, ReplayType &replayMethod, UpdaterType qNetworkUpdater=UpdaterType(), UpdaterType policyNetworkUpdater=UpdaterType(), EnvironmentType environment=EnvironmentType())
Create the SAC object with given settings.
Definition: sac_impl.hpp:36
void SelectAction()
Select an action, given an agent.
Definition: sac_impl.hpp:292
double Episode()
Execute an episode.
Definition: sac_impl.hpp:320
The mean squared error performance function measures the network's performance according to the mean of squared errors.
Definition: mean_squared_error.hpp:34
const StateType & State() const
Get the state of the agent.
Definition: sac.hpp:134
size_t & TotalSteps()
Modify total steps from beginning.
Definition: sac.hpp:127
StateType & State()
Modify the state of the agent.
Definition: sac.hpp:132
const ActionType & Action() const
Get the action of the agent.
Definition: sac.hpp:137