13 #ifndef MLPACK_METHODS_RL_Q_LEARNING_HPP 14 #define MLPACK_METHODS_RL_Q_LEARNING_HPP 17 #include <ensmallen.hpp> 53 typename EnvironmentType,
57 typename ReplayType = RandomReplay<EnvironmentType>
84 ReplayType& replayMethod,
85 UpdaterType updater = UpdaterType(),
86 EnvironmentType environment = EnvironmentType());
130 const EnvironmentType&
Environment()
const {
return environment; }
138 const NetworkType&
Network()
const {
return learningNetwork; }
140 NetworkType&
Network() {
return learningNetwork; }
148 arma::Col<size_t> BestAction(
const arma::mat& actionValues);
154 NetworkType& learningNetwork;
157 NetworkType targetNetwork;
163 ReplayType& replayMethod;
167 #if ENS_VERSION_MAJOR >= 2 168 typename UpdaterType::template Policy<arma::mat, arma::mat>* updatePolicy;
172 EnvironmentType environment;
EnvironmentType & Environment()
Modify the environment in which the agent acts.
Definition: q_learning.hpp:128
typename EnvironmentType::State StateType
Convenient typedef for state.
Definition: q_learning.hpp:63
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
const ActionType & Action() const
Get the action of the agent.
Definition: q_learning.hpp:125
QLearning(TrainingConfig &config, NetworkType &network, PolicyType &policy, ReplayType &replayMethod, UpdaterType updater=UpdaterType(), EnvironmentType environment=EnvironmentType())
Create the QLearning object with given settings.
Definition: q_learning_impl.hpp:33
The core includes that mlpack expects; standard C++ includes and Armadillo.
~QLearning()
Clean up memory.
Definition: q_learning_impl.hpp:87
void SelectAction()
Select an action, given an agent.
Definition: q_learning_impl.hpp:331
const bool & Deterministic() const
Get the indicator of training mode / test mode.
Definition: q_learning.hpp:135
const size_t & TotalSteps() const
Get total steps from beginning.
Definition: q_learning.hpp:117
const StateType & State() const
Get the state of the agent.
Definition: q_learning.hpp:122
double Episode()
Execute an episode.
Definition: q_learning_impl.hpp:354
void TrainAgent()
Trains the DQN agent (non-categorical).
Definition: q_learning_impl.hpp:133
bool & Deterministic()
Modify the training mode / test mode indicator.
Definition: q_learning.hpp:133
const NetworkType & Network() const
Return the learning network.
Definition: q_learning.hpp:138
size_t & TotalSteps()
Modify total steps from beginning.
Definition: q_learning.hpp:115
Definition: training_config.hpp:19
const EnvironmentType & Environment() const
Get the environment in which the agent acts.
Definition: q_learning.hpp:130
Implementation of various Q-Learning algorithms, such as DQN and double DQN.
Definition: q_learning.hpp:59
NetworkType & Network()
Modify the learning network.
Definition: q_learning.hpp:140
typename EnvironmentType::Action ActionType
Convenient typedef for action.
Definition: q_learning.hpp:66
StateType & State()
Modify the state of the agent.
Definition: q_learning.hpp:120
void TrainCategoricalAgent()
Trains the categorical DQN agent.
Definition: q_learning_impl.hpp:220