mlpack
mlpack::rl::QLearning< EnvironmentType, NetworkType, UpdaterType, PolicyType, ReplayType > Class Template Reference

Implementation of various Q-Learning algorithms, such as DQN and double DQN.

#include <q_learning.hpp>

Public Types

using StateType = typename EnvironmentType::State
 Convenient typedef for state.
 
using ActionType = typename EnvironmentType::Action
 Convenient typedef for action.
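Both typedefs re-export types that the environment already defines. Below is a minimal sketch of the same pattern. It assumes a hypothetical toy environment `ToyEnv` with nested `State` and `Action` types and a hypothetical agent template `ToyAgent`. Neither name is part of mlpack.

```cpp
#include <type_traits>

// Hypothetical environment: it only has to expose nested State and Action types.
struct ToyEnv
{
  struct State { double position = 0.0; };
  enum class Action { Left, Right };
};

// Hypothetical agent template that takes its state and action types from the
// environment, in the same way as the StateType/ActionType typedefs above.
template<typename EnvironmentType>
class ToyAgent
{
 public:
  // `typename` is required because State and Action are dependent names.
  using StateType = typename EnvironmentType::State;
  using ActionType = typename EnvironmentType::Action;
};

// Compile-time check that the aliases resolve to the environment's own types.
static_assert(std::is_same<ToyAgent<ToyEnv>::StateType, ToyEnv::State>::value,
              "StateType should be ToyEnv::State");
```

With this pattern, changing the environment template argument is enough to change the state and action types used throughout the agent.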
 

Public Member Functions

 QLearning (TrainingConfig &config, NetworkType &network, PolicyType &policy, ReplayType &replayMethod, UpdaterType updater=UpdaterType(), EnvironmentType environment=EnvironmentType())
 Create the QLearning object with the given settings.
 
 ~QLearning ()
 Clean up memory.
 
void TrainAgent ()
 Trains the DQN agent (non-categorical).
 
void TrainCategoricalAgent ()
 Trains the DQN agent of categorical type.
 
void SelectAction ()
 Select an action for the agent, given its current state.
 
double Episode ()
 Execute an episode.
 
size_t & TotalSteps ()
 Modify the total number of steps taken since the beginning.
 
const size_t & TotalSteps () const
 Get the total number of steps taken since the beginning.
 
StateType & State ()
 Modify the state of the agent.
 
const StateType & State () const
 Get the state of the agent.
 
const ActionType & Action () const
 Get the action of the agent.
 
EnvironmentType & Environment ()
 Modify the environment in which the agent acts.
 
const EnvironmentType & Environment () const
 Get the environment in which the agent acts.
 
bool & Deterministic ()
 Modify the training mode / test mode indicator.
 
const bool & Deterministic () const
 Get the indicator of training mode / test mode.
 
const NetworkType & Network () const
 Return the learning network.
 
NetworkType & Network ()
 Modify the learning network.
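Most of the members above come in pairs. A non-const overload returns a reference, so the caller can modify the value, and a const overload gives read-only access. The sketch below shows that idiom together with a `Deterministic()` flag that switches between greedy (test-mode) and epsilon-greedy (training-mode) action selection. The class `ToyGreedyAgent`, its `epsilon` parameter, and its selection rule are hypothetical illustrations, not mlpack's implementation. In this sketch, every call to `SelectAction` also counts as one step.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical agent illustrating reference-returning accessor pairs.
class ToyGreedyAgent
{
 public:
  explicit ToyGreedyAgent(double epsilon, unsigned seed = 42)
      : epsilon(epsilon), rng(seed) { }

  // Non-const overloads return references, so callers can modify the values.
  bool& Deterministic() { return deterministic; }
  std::size_t& TotalSteps() { return totalSteps; }

  // Const overloads give read-only access.
  const bool& Deterministic() const { return deterministic; }
  const std::size_t& TotalSteps() const { return totalSteps; }

  // Pick an action index from estimated action values.
  // Deterministic (test) mode: always the greedy (argmax) action.
  // Training mode: a uniformly random action with probability epsilon,
  // otherwise the greedy action.
  std::size_t SelectAction(const std::vector<double>& actionValues)
  {
    assert(!actionValues.empty());
    ++totalSteps;  // In this sketch, each selection counts as one step.
    if (!deterministic)
    {
      std::uniform_real_distribution<double> coin(0.0, 1.0);
      if (coin(rng) < epsilon)
      {
        std::uniform_int_distribution<std::size_t> pick(0, actionValues.size() - 1);
        return pick(rng);
      }
    }
    std::size_t best = 0;
    for (std::size_t i = 1; i < actionValues.size(); ++i)
      if (actionValues[i] > actionValues[best])
        best = i;
    return best;
  }

 private:
  double epsilon;
  std::mt19937 rng;
  bool deterministic = false;
  std::size_t totalSteps = 0;
};
```

Returning references lets callers write, for example, `agent.Deterministic() = true;`. The same style is used by the `Deterministic()`, `TotalSteps()`, `State()` and `Environment()` pairs listed above.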
 

Detailed Description

template<typename EnvironmentType, typename NetworkType, typename UpdaterType, typename PolicyType, typename ReplayType = RandomReplay<EnvironmentType>>
class mlpack::rl::QLearning< EnvironmentType, NetworkType, UpdaterType, PolicyType, ReplayType >

Implementation of various Q-Learning algorithms, such as DQN and double DQN.

For more details, see the following:

@article{Mnih2013,
author = {Volodymyr Mnih and
Koray Kavukcuoglu and
David Silver and
Alex Graves and
Ioannis Antonoglou and
Daan Wierstra and
Martin A. Riedmiller},
title = {Playing Atari with Deep Reinforcement Learning},
journal = {CoRR},
year = {2013},
url = {http://arxiv.org/abs/1312.5602}
}
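DQN and double DQN differ mainly in how they form the bootstrap target for a transition (s, a, r, s'). The sketch below shows both targets for a single non-terminal transition. It assumes the usual setup with a separate target network and treats the next-state action values as already-computed vectors. The function names `DqnTarget`, `DoubleDqnTarget` and `ArgMax` are hypothetical, not mlpack API.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Index of the largest entry.
std::size_t ArgMax(const std::vector<double>& values)
{
  assert(!values.empty());
  return static_cast<std::size_t>(
      std::distance(values.begin(), std::max_element(values.begin(), values.end())));
}

// DQN target: r + discount * max_a Q_target(s', a).
double DqnTarget(double reward, double discount,
                 const std::vector<double>& targetNextValues)
{
  return reward + discount * targetNextValues[ArgMax(targetNextValues)];
}

// Double DQN target: the online (learning) network selects the next action
// and the target network evaluates it. Separating selection from evaluation
// reduces the overestimation bias of taking a plain max.
double DoubleDqnTarget(double reward, double discount,
                       const std::vector<double>& onlineNextValues,
                       const std::vector<double>& targetNextValues)
{
  assert(onlineNextValues.size() == targetNextValues.size());
  return reward + discount * targetNextValues[ArgMax(onlineNextValues)];
}
```

Consider reward 1, discount 0.5, online values {1, 2} and target values {5, 3}. DQN bootstraps from the target network's maximum and gives 1 + 0.5 * 5 = 3.5. Double DQN evaluates the online network's choice (action 1) with the target network and gives 1 + 0.5 * 3 = 2.5.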
Template Parameters
EnvironmentType: The environment of the reinforcement learning task.
NetworkType: The network used to compute action values.
UpdaterType: How gradients are applied during training.
PolicyType: The behavior policy of the agent.
ReplayType: The experience replay method.

Constructor & Destructor Documentation

◆ QLearning()

template<typename EnvironmentType , typename NetworkType , typename UpdaterType , typename PolicyType , typename ReplayType >
mlpack::rl::QLearning< EnvironmentType, NetworkType, UpdaterType, PolicyType, ReplayType >::QLearning ( TrainingConfig &  config,
NetworkType &  network,
PolicyType &  policy,
ReplayType &  replayMethod,
UpdaterType  updater = UpdaterType(),
EnvironmentType  environment = EnvironmentType() 
)

Create the QLearning object with given settings.

If you want to pass in an argument and discard the original object, use std::move to avoid an unnecessary copy. This applies to the parameters taken by value, updater and environment.
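The sketch below shows why std::move helps for a by-value parameter. It assumes that the constructor moves its by-value argument into a member, as the note above implies. `CountingUpdater` and `Holder` are hypothetical stand-ins for an updater type and for the QLearning constructor.

```cpp
#include <cassert>
#include <utility>

// Hypothetical updater type that counts how often it is copied or moved.
struct CountingUpdater
{
  static int copies;
  static int moves;
  CountingUpdater() = default;
  CountingUpdater(const CountingUpdater&) { ++copies; }
  CountingUpdater(CountingUpdater&&) noexcept { ++moves; }
};
int CountingUpdater::copies = 0;
int CountingUpdater::moves = 0;

// Hypothetical stand-in for a constructor that takes its argument by value
// (like `updater` and `environment` above) and stores it in a member.
class Holder
{
 public:
  // Inside the parentheses, `updater` names the parameter, not the member.
  explicit Holder(CountingUpdater updater) : updater(std::move(updater)) { }

 private:
  CountingUpdater updater;
};
```

Passing an lvalue, as in `Holder a(u);`, costs one copy into the parameter followed by one move into the member. Passing `std::move(u)` replaces that copy with a move, and `u` should not be relied on afterwards.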

Parameters
config: Hyper-parameters for training.
network: The network used to compute action values.
policy: The behavior policy of the agent.
replayMethod: The experience replay method.
updater: How gradients are applied during training.
environment: The reinforcement learning task.

Member Function Documentation

◆ Episode()

template<typename EnvironmentType , typename NetworkType , typename UpdaterType , typename BehaviorPolicyType , typename ReplayType >
double mlpack::rl::QLearning< EnvironmentType, NetworkType, UpdaterType, BehaviorPolicyType, ReplayType >::Episode ( )

Execute an episode.

Returns
The return of the episode, i.e. the sum of the rewards collected during it.
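Executing an episode involves the following steps: start from an initial state, then repeatedly select an action, step the environment, and accumulate the reward until a terminal state (or an optional step limit) is reached. The sketch below shows this loop. `CountdownEnv` and `RunEpisode` are hypothetical. The method names `InitialSample`, `Sample` and `IsTerminal` are modeled on the general shape of an RL environment and should be treated as assumptions, so check the actual environment classes for the exact signatures. The sketch also uses a fixed action and omits the training updates that happen between steps during training.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical toy environment: the state counts down to zero, and every step
// yields a reward of 1.0.
struct CountdownEnv
{
  struct State { int remaining; };
  enum class Action { Step };

  explicit CountdownEnv(int length) : length(length) { assert(length >= 0); }

  State InitialSample() const { return State{length}; }

  // Advance from `state` under the given action, write the successor into
  // `next`, and return the reward for this transition.
  double Sample(const State& state, Action /* action */, State& next) const
  {
    next.remaining = state.remaining - 1;
    return 1.0;
  }

  bool IsTerminal(const State& state) const { return state.remaining <= 0; }

  int length;
};

// Run one episode and return its undiscounted return (the sum of rewards).
// A stepLimit of 0 means "no limit".
template<typename EnvironmentType>
double RunEpisode(const EnvironmentType& env, std::size_t stepLimit = 0)
{
  typename EnvironmentType::State state = env.InitialSample();
  double episodeReturn = 0.0;
  std::size_t steps = 0;
  while (!env.IsTerminal(state))
  {
    if (stepLimit != 0 && steps >= stepLimit)
      break;
    // A learned agent would choose this action from its Q-values.
    typename EnvironmentType::Action action = EnvironmentType::Action::Step;
    typename EnvironmentType::State next{};
    episodeReturn += env.Sample(state, action, next);
    state = next;
    ++steps;
  }
  return episodeReturn;
}
```

For example, `RunEpisode(CountdownEnv(5))` takes five steps and returns 5.0, while `RunEpisode(CountdownEnv(5), 3)` stops at the step limit and returns 3.0.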

◆ TrainAgent()

template<typename EnvironmentType , typename NetworkType , typename UpdaterType , typename BehaviorPolicyType , typename ReplayType >
void mlpack::rl::QLearning< EnvironmentType, NetworkType, UpdaterType, BehaviorPolicyType, ReplayType >::TrainAgent ( )

Trains the DQN agent (non-categorical).

If the next state is terminal, the discounted estimate of future value is not added to the target, because the agent performs no further actions once it reaches a terminal state.
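The sketch below shows the terminal-state rule applied to a batch of one-step targets. It assumes that the per-transition bootstrap values `maxNextQ` have already been computed, either as a max over the target network's outputs (DQN) or with the double-DQN selection rule. The function name `QTargets` is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One-step Q-learning targets for a batch of transitions:
//   target_i = r_i + discount * maxNextQ_i * (1 - isTerminal_i)
// When the next state is terminal, the bootstrap term is dropped, because the
// agent takes no further action (and collects no further reward) after it.
std::vector<double> QTargets(const std::vector<double>& rewards,
                             const std::vector<double>& maxNextQ,
                             const std::vector<bool>& isTerminal,
                             double discount)
{
  assert(rewards.size() == maxNextQ.size() && rewards.size() == isTerminal.size());
  std::vector<double> targets(rewards.size());
  for (std::size_t i = 0; i < rewards.size(); ++i)
    targets[i] = rewards[i] + (isTerminal[i] ? 0.0 : discount * maxNextQ[i]);
  return targets;
}
```

Consider two transitions with rewards {1, 1}, bootstrap values {4, 4} and discount 0.5. If only the second transition ends in a terminal state, the targets are {3, 1}: the terminal target is just the immediate reward.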

