mlpack
mlpack::rl::SAC< EnvironmentType, QNetworkType, PolicyNetworkType, UpdaterType, ReplayType > Class Template Reference

Implementation of Soft Actor-Critic, a model-free, off-policy, actor-critic deep reinforcement learning algorithm. More...

#include <sac.hpp>

Public Types

using StateType = typename EnvironmentType::State
 Convenient typedef for state.
 
using ActionType = typename EnvironmentType::Action
 Convenient typedef for action.
 

Public Member Functions

 SAC (TrainingConfig &config, QNetworkType &learningQ1Network, PolicyNetworkType &policyNetwork, ReplayType &replayMethod, UpdaterType qNetworkUpdater=UpdaterType(), UpdaterType policyNetworkUpdater=UpdaterType(), EnvironmentType environment=EnvironmentType())
 Create the SAC object with given settings. More...
 
 ~SAC ()
 Clean memory.
 
void SoftUpdate (double rho)
 Softly update the target Q network parameters toward the learning Q network parameters. More...
 
void Update ()
 Update the Q and policy networks.
 
void SelectAction ()
 Select an action according to the agent's current state.
 
double Episode ()
 Execute an episode. More...
 
size_t & TotalSteps ()
 Modify total steps from beginning.
 
const size_t & TotalSteps () const
 Get total steps from beginning.
 
StateType & State ()
 Modify the state of the agent.
 
const StateType & State () const
 Get the state of the agent.
 
const ActionType & Action () const
 Get the action of the agent.
 
bool & Deterministic ()
 Modify the training mode / test mode indicator.
 
const bool & Deterministic () const
 Get the indicator of training mode / test mode.
 

Detailed Description

template<typename EnvironmentType, typename QNetworkType, typename PolicyNetworkType, typename UpdaterType, typename ReplayType = RandomReplay<EnvironmentType>>
class mlpack::rl::SAC< EnvironmentType, QNetworkType, PolicyNetworkType, UpdaterType, ReplayType >

Implementation of Soft Actor-Critic, a model-free, off-policy, actor-critic deep reinforcement learning algorithm.

For more details, see the following:

@misc{haarnoja2018soft,
  author = {Tuomas Haarnoja and Aurick Zhou and Kristian Hartikainen and
            George Tucker and Sehoon Ha and Jie Tan and Vikash Kumar and
            Henry Zhu and Abhishek Gupta and Pieter Abbeel and Sergey Levine},
  title  = {Soft Actor-Critic Algorithms and Applications},
  year   = {2018},
  url    = {https://arxiv.org/abs/1812.05905}
}
Template Parameters
    EnvironmentType      The environment of the reinforcement learning task.
    QNetworkType         The network to compute action value.
    PolicyNetworkType    The network to produce an action given a state.
    UpdaterType          How to apply gradients when training.
    ReplayType           Experience replay method.
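
A minimal usage sketch, loosely patterned on mlpack's SAC test setup and assuming the continuous-action Pendulum environment that ships with mlpack; the layer sizes, hyper-parameters, and episode count below are illustrative, not prescriptive:

#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/loss_functions/empty_loss.hpp>
#include <mlpack/methods/ann/init_rules/gaussian_init.hpp>
#include <mlpack/methods/reinforcement_learning/sac.hpp>
#include <mlpack/methods/reinforcement_learning/environment/pendulum.hpp>
#include <mlpack/methods/reinforcement_learning/training_config.hpp>
#include <ensmallen.hpp>
#include <iostream>

using namespace mlpack::ann;
using namespace mlpack::rl;

int main()
{
  // Policy network: maps the 3-dimensional Pendulum state to a
  // 1-dimensional action squashed into [-1, 1] by TanH.
  FFN<EmptyLoss<>, GaussianInitialization> policyNetwork(
      EmptyLoss<>(), GaussianInitialization(0, 0.1));
  policyNetwork.Add<Linear<>>(3, 128);
  policyNetwork.Add<ReLULayer<>>();
  policyNetwork.Add<Linear<>>(128, 1);
  policyNetwork.Add<TanHLayer<>>();

  // Q network: scores a concatenated (state, action) pair.
  FFN<EmptyLoss<>, GaussianInitialization> qNetwork(
      EmptyLoss<>(), GaussianInitialization(0, 0.1));
  qNetwork.Add<Linear<>>(3 + 1, 128);
  qNetwork.Add<ReLULayer<>>();
  qNetwork.Add<Linear<>>(128, 1);

  // Replay buffer: batches of 32 sampled from up to 10000 transitions.
  RandomReplay<Pendulum> replayMethod(32, 10000);

  TrainingConfig config;
  config.StepSize() = 0.001;
  config.TargetNetworkSyncInterval() = 1;
  config.UpdateInterval() = 3;

  SAC<Pendulum, decltype(qNetwork), decltype(policyNetwork), ens::AdamUpdate>
      agent(config, qNetwork, policyNetwork, replayMethod);

  // Train: each call to Episode() runs one full episode and returns its
  // cumulative reward.
  agent.Deterministic() = false;
  for (size_t i = 0; i < 100; ++i)
    std::cout << "Episode return: " << agent.Episode() << std::endl;
}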

Constructor & Destructor Documentation

◆ SAC()

template<typename EnvironmentType , typename QNetworkType , typename PolicyNetworkType , typename UpdaterType , typename ReplayType >
mlpack::rl::SAC< EnvironmentType, QNetworkType, PolicyNetworkType, UpdaterType, ReplayType >::SAC ( TrainingConfig &  config,
QNetworkType &  learningQ1Network,
PolicyNetworkType &  policyNetwork,
ReplayType &  replayMethod,
UpdaterType  qNetworkUpdater = UpdaterType(),
UpdaterType  policyNetworkUpdater = UpdaterType(),
EnvironmentType  environment = EnvironmentType() 
)

Create the SAC object with given settings.

Because the constructor takes the networks and the replay method by reference, no copies are made: the agent trains and queries the very objects you pass in, so they must outlive the SAC object.

Parameters
    config                  Hyper-parameters for training.
    learningQ1Network       The network to compute action value.
    policyNetwork           The network to produce an action given a state.
    replayMethod            Experience replay method.
    qNetworkUpdater         How to apply gradients to the Q network when training.
    policyNetworkUpdater    How to apply gradients to the policy network when training.
    environment             Reinforcement learning task.
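
For instance, custom updaters can be supplied for the two networks. A sketch, reusing the objects from the earlier example; the updaters are taken by value, so temporaries are fine, while the networks and replay method must be lvalues:

// Separate Adam updaters for the Q and policy networks (the arguments shown
// are ensmallen's defaults: epsilon, beta1, beta2).
SAC<Pendulum, decltype(qNetwork), decltype(policyNetwork), ens::AdamUpdate>
    agent(config, qNetwork, policyNetwork, replayMethod,
          ens::AdamUpdate(1e-8, 0.9, 0.999),   // qNetworkUpdater
          ens::AdamUpdate(1e-8, 0.9, 0.999));  // policyNetworkUpdater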

Member Function Documentation

◆ Episode()

template<typename EnvironmentType , typename QNetworkType , typename PolicyNetworkType , typename UpdaterType , typename ReplayType >
double mlpack::rl::SAC< EnvironmentType, QNetworkType, PolicyNetworkType, UpdaterType, ReplayType >::Episode ( )

Execute an episode.

Returns
Return of the episode.
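
Deterministic() toggles between training and evaluation behavior, so a typical evaluation loop looks like the sketch below, reusing the agent from the earlier example (the episode count is arbitrary):

// Evaluate the learned policy without exploration.
agent.Deterministic() = true;
double averageReturn = 0.0;
for (size_t i = 0; i < 10; ++i)
  averageReturn += agent.Episode() / 10.0;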

◆ SoftUpdate()

template<typename EnvironmentType , typename QNetworkType , typename PolicyNetworkType , typename UpdaterType , typename ReplayType >
void mlpack::rl::SAC< EnvironmentType, QNetworkType, PolicyNetworkType, UpdaterType, ReplayType >::SoftUpdate ( double  rho)

Softly update the target Q network parameters toward the learning Q network parameters.

Parameters
    rho    How "softly" the parameters should be copied.
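
This is the conventional Polyak averaging step; a sketch of the rule with illustrative member names (rho = 0 leaves the target untouched, rho = 1 copies the learning network outright):

// Blend a fraction rho of the learning network into the target network.
targetQNetwork.Parameters() = (1 - rho) * targetQNetwork.Parameters() +
    rho * learningQNetwork.Parameters();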
