mlpack
mlpack::rl::PrioritizedReplay< EnvironmentType > Class Template Reference

Implementation of prioritized experience replay.

#include <prioritized_replay.hpp>

Classes

struct  Transition
 

Public Types

using ActionType = typename EnvironmentType::Action
 Convenient typedef for action.
 
using StateType = typename EnvironmentType::State
 Convenient typedef for state.
 

Public Member Functions

 PrioritizedReplay ()
 Default constructor.
 
 PrioritizedReplay (const size_t batchSize, const size_t capacity, const double alpha, const size_t nSteps=1, const size_t dimension=StateType::dimension)
 Construct an instance of the prioritized experience replay class.
 
void Store (StateType state, ActionType action, double reward, StateType nextState, bool isEnd, const double &discount)
 Store the given experience and set its priority.
 
void GetNStepInfo (double &reward, StateType &nextState, bool &isEnd, const double &discount)
 Get the reward, next state, and terminal flag for the n-th step.
 
arma::ucolvec SampleProportional ()
 Sample transition indices according to their priorities.
 
void Sample (arma::mat &sampledStates, std::vector< ActionType > &sampledActions, arma::rowvec &sampledRewards, arma::mat &sampledNextStates, arma::irowvec &isTerminal)
 Sample a batch of transitions according to their priorities.
 
void UpdatePriorities (arma::ucolvec &indices, arma::colvec &priorities)
 Update priorities of sampled transitions.
 
const size_t & Size ()
 Get the number of transitions in the memory.
 
void BetaAnneal ()
 Anneal beta, the importance-sampling exponent.
 
void Update (arma::mat target, std::vector< ActionType > sampledActions, arma::mat nextActionValues, arma::mat &gradients)
 Update the priorities of the sampled transitions and update the gradients.
 
const size_t & NSteps () const
 Get the number of steps for n-step agent.
 

Detailed Description

template<typename EnvironmentType>
class mlpack::rl::PrioritizedReplay< EnvironmentType >

Implementation of prioritized experience replay.

Prioritized experience replay assigns priorities to transitions so that important ones are replayed more frequently, which lets the agent learn more efficiently.

@article{schaul2015prioritized,
title = {Prioritized experience replay},
author = {Schaul, Tom and Quan, John and Antonoglou,
Ioannis and Silver, David},
journal = {arXiv preprint arXiv:1511.05952},
year = {2015}
}
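As a rough illustration of proportional prioritization, the standalone C++ sketch below (not mlpack code; `SamplingProbabilities` is a hypothetical name) follows the cited paper: a transition with priority p_i is drawn with probability P(i) = p_i^alpha / sum_k p_k^alpha, so `alpha = 0` reduces to uniform sampling and `alpha = 1` samples in direct proportion to priority.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Standalone sketch (not mlpack code): turn raw priorities p_i into
// sampling probabilities P(i) = p_i^alpha / sum_k p_k^alpha.
std::vector<double> SamplingProbabilities(const std::vector<double>& priorities,
                                          const double alpha)
{
  assert(!priorities.empty());
  std::vector<double> probs(priorities.size());
  double total = 0.0;
  for (size_t i = 0; i < priorities.size(); ++i)
  {
    probs[i] = std::pow(priorities[i], alpha);  // p_i^alpha
    total += probs[i];
  }
  for (double& p : probs)
    p /= total;  // Normalize so the probabilities sum to 1.
  return probs;
}
```

For priorities {1, 3}, `alpha = 0` gives {0.5, 0.5} and `alpha = 1` gives {0.25, 0.75}.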
Template Parameters
EnvironmentType: Desired task.

Constructor & Destructor Documentation

◆ PrioritizedReplay()

template<typename EnvironmentType >
mlpack::rl::PrioritizedReplay< EnvironmentType >::PrioritizedReplay ( const size_t  batchSize,
const size_t  capacity,
const double  alpha,
const size_t  nSteps = 1,
const size_t  dimension = StateType::dimension 
)
inline

Construct an instance of the prioritized experience replay class.

Parameters
batchSize: Number of examples returned at each sample.
capacity: Total memory size in terms of number of examples.
alpha: How much prioritization is used (as in the paper, alpha = 0 corresponds to uniform sampling).
nSteps: Number of steps to look into the future.
dimension: The dimension of an encoded state.

Member Function Documentation

◆ GetNStepInfo()

template<typename EnvironmentType >
void mlpack::rl::PrioritizedReplay< EnvironmentType >::GetNStepInfo ( double &  reward,
StateType &  nextState,
bool &  isEnd,
const double &  discount 
)
inline

Get the reward, next state, and terminal flag for the n-th step.

Parameters
reward: Given reward.
nextState: Given next state.
isEnd: Whether the next state is a terminal state.
discount: The discount parameter.

◆ Sample()

template<typename EnvironmentType >
void mlpack::rl::PrioritizedReplay< EnvironmentType >::Sample ( arma::mat &  sampledStates,
std::vector< ActionType > &  sampledActions,
arma::rowvec &  sampledRewards,
arma::mat &  sampledNextStates,
arma::irowvec &  isTerminal 
)
inline

Sample a batch of transitions according to their priorities.

Parameters
sampledStates: Sampled encoded states.
sampledActions: Sampled actions.
sampledRewards: Sampled rewards.
sampledNextStates: Sampled encoded next states.
isTerminal: Indicates whether the corresponding next state is a terminal state.
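Sampling non-uniformly biases the learning updates. The cited paper corrects for this with importance-sampling weights w_i = (N * P(i))^(-beta), divided by the largest weight, where N is the number of stored transitions. The sketch below computes those weights in plain C++; `ImportanceWeights` is a hypothetical helper, and whether or how mlpack applies such weights internally is not stated on this page. In the paper beta is annealed toward 1 during training, which is presumably the role of `BetaAnneal()`.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Standalone sketch of the PER importance-sampling correction:
// w_i = (N * P(i))^(-beta), normalized by the largest weight in the batch
// so that updates are only ever scaled down.
std::vector<double> ImportanceWeights(const std::vector<double>& sampledProbs,
                                      const size_t memorySize,
                                      const double beta)
{
  assert(!sampledProbs.empty() && memorySize > 0);
  std::vector<double> weights(sampledProbs.size());
  for (size_t i = 0; i < sampledProbs.size(); ++i)
    weights[i] = std::pow(memorySize * sampledProbs[i], -beta);
  const double maxWeight = *std::max_element(weights.begin(), weights.end());
  for (double& w : weights)
    w /= maxWeight;
  return weights;
}
```

With `beta = 0` every weight is 1 (no correction); with `beta = 1` rarely sampled transitions receive the largest weights.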

◆ SampleProportional()

template<typename EnvironmentType >
arma::ucolvec mlpack::rl::PrioritizedReplay< EnvironmentType >::SampleProportional ( )
inline

Sample transition indices according to their priorities.

Returns
The indices of the sampled transitions.
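One way to implement proportional sampling is sketched below, assuming that `priorities` already holds the p_i^alpha values: split the total priority mass into `batchSize` equal segments, draw one uniform value per segment, and map each value to an index through the running sum of priorities. The cited paper uses a sum tree for this lookup so that updates and queries both cost O(log N); rebuilding a prefix-sum array costs O(N) per call but keeps the sketch short. `SampleProportionalSketch` is a hypothetical name, not mlpack's implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <random>
#include <vector>

std::vector<size_t> SampleProportionalSketch(
    const std::vector<double>& priorities,
    const size_t batchSize,
    std::mt19937& rng)
{
  assert(!priorities.empty() && batchSize > 0);
  // cumulative[i] = priorities[0] + ... + priorities[i].
  std::vector<double> cumulative(priorities.size());
  std::partial_sum(priorities.begin(), priorities.end(), cumulative.begin());
  const double total = cumulative.back();
  assert(total > 0.0);
  const double segment = total / batchSize;

  std::vector<size_t> indices(batchSize);
  for (size_t i = 0; i < batchSize; ++i)
  {
    // One uniform draw inside the i-th equal slice of the priority mass.
    std::uniform_real_distribution<double> dist(i * segment, (i + 1) * segment);
    // Clamp just below the total so round-off can never step past the end.
    const double mass = std::min(dist(rng), std::nextafter(total, 0.0));
    // The first index whose cumulative priority exceeds the drawn mass.
    const auto it = std::upper_bound(cumulative.begin(), cumulative.end(), mass);
    indices[i] = static_cast<size_t>(it - cumulative.begin());
  }
  return indices;
}
```

Because every segment contributes one index, the batch is spread over the whole priority range: for priorities {1, 3} and `batchSize = 4`, the result is always {0, 1, 1, 1}.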

◆ Size()

template<typename EnvironmentType >
const size_t& mlpack::rl::PrioritizedReplay< EnvironmentType >::Size ( )
inline

Get the number of transitions in the memory.

Returns
The number of transitions currently stored (the used portion of the memory).
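`Size()` reports how many transitions are actually stored, which stays below `capacity` until the memory fills. The hypothetical `RingMemory` class below sketches that bookkeeping with a circular buffer, assuming, as is common for replay memories, that the oldest entries are overwritten once the buffer is full; this page does not document mlpack's overwrite policy.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Fixed-capacity circular memory: once full, new items overwrite the
// oldest ones, and Size() never exceeds the capacity.
template<typename T>
class RingMemory
{
 public:
  explicit RingMemory(const size_t capacity) :
      data(capacity), position(0), size(0)
  {
    assert(capacity > 0);
  }

  void Store(const T& item)
  {
    data[position] = item;
    position = (position + 1) % data.size();  // Wrap around to the oldest slot.
    size = std::min(size + 1, data.size());
  }

  const size_t& Size() const { return size; }

  const T& At(const size_t i) const { return data[i]; }

 private:
  std::vector<T> data;
  size_t position;  // Slot the next item will be written to.
  size_t size;      // Number of slots currently holding items.
};
```

After storing 1, 2, 3, 4 in a memory of capacity 3, `Size()` is 3 and the slots hold {4, 2, 3}.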

◆ Store()

template<typename EnvironmentType >
void mlpack::rl::PrioritizedReplay< EnvironmentType >::Store ( StateType  state,
ActionType  action,
double  reward,
StateType  nextState,
bool  isEnd,
const double &  discount 
)
inline

Store the given experience and set its priority.

Parameters
state: Given state.
action: Given action.
reward: Given reward.
nextState: Given next state.
isEnd: Whether the next state is a terminal state.
discount: The discount parameter.
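When a transition is stored, its TD error is not yet known. The cited paper handles this by giving each new transition the maximum priority seen so far (starting from 1), so it is likely to be replayed at least once. The hypothetical `PriorityTable` below sketches that rule, storing priorities as p^alpha; it is not mlpack's `Store()` and ignores the state, action, and reward payload.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Priority table in which every newly stored transition receives the
// largest raw priority observed so far.
class PriorityTable
{
 public:
  PriorityTable(const size_t capacity, const double alpha) :
      priorities(capacity, 0.0), alpha(alpha), maxPriority(1.0), position(0)
  {
    assert(capacity > 0);
  }

  // Store a new transition in the next slot and return that slot's index.
  size_t StoreNew()
  {
    const size_t slot = position;
    priorities[slot] = std::pow(maxPriority, alpha);  // Stored as p^alpha.
    position = (position + 1) % priorities.size();
    return slot;
  }

  // Set the raw priority of an existing slot and track the running maximum.
  void Update(const size_t slot, const double priority)
  {
    assert(slot < priorities.size());
    priorities[slot] = std::pow(priority, alpha);
    maxPriority = std::max(maxPriority, priority);
  }

  double At(const size_t slot) const { return priorities[slot]; }

 private:
  std::vector<double> priorities;
  double alpha;
  double maxPriority;  // Largest raw priority observed; starts at 1.
  size_t position;     // Slot the next transition will be written to.
};
```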

◆ Update()

template<typename EnvironmentType >
void mlpack::rl::PrioritizedReplay< EnvironmentType >::Update ( arma::mat  target,
std::vector< ActionType >  sampledActions,
arma::mat  nextActionValues,
arma::mat &  gradients 
)
inline

Update the priorities of the sampled transitions and update the gradients.

Parameters
target: The learned value.
sampledActions: The agent's sampled actions.
nextActionValues: The agent's next action values.
gradients: The model's gradients.
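In the cited paper, the new priority of a replayed transition is its absolute TD error plus a small constant epsilon, p_i = |delta_i| + epsilon, so that a transition whose error reaches zero can still be sampled again. The sketch below computes such priorities from per-sample targets and predictions. The function name, the default `epsilon`, and the flat vector inputs are assumptions; how mlpack's `Update()` derives these values from `target`, `sampledActions`, and `nextActionValues`, or how it rescales `gradients`, is not shown here.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// New priority for each sampled transition: absolute TD error plus a small
// epsilon, so no priority (and no sampling probability) becomes exactly zero.
std::vector<double> TdErrorPriorities(const std::vector<double>& targets,
                                      const std::vector<double>& predictions,
                                      const double epsilon = 1e-6)
{
  assert(targets.size() == predictions.size());
  std::vector<double> priorities(targets.size());
  for (size_t i = 0; i < targets.size(); ++i)
  {
    const double tdError = targets[i] - predictions[i];  // delta_i
    priorities[i] = std::abs(tdError) + epsilon;
  }
  return priorities;
}
```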

◆ UpdatePriorities()

template<typename EnvironmentType >
void mlpack::rl::PrioritizedReplay< EnvironmentType >::UpdatePriorities ( arma::ucolvec &  indices,
arma::colvec &  priorities 
)
inline

Update priorities of sampled transitions.

Parameters
indices: The indices of the samples to be updated.
priorities: Their corresponding priorities.
