mlpack
Wrapper of various asynchronous learning algorithms, e.g. async one-step Q-learning, async one-step Sarsa, async n-step Q-learning, and async advantage actor-critic.
#include <async_learning.hpp>
Public Member Functions
AsyncLearning (TrainingConfig config, NetworkType network, PolicyType policy, UpdaterType updater=UpdaterType(), EnvironmentType environment=EnvironmentType())
 Construct an instance of the given async learning algorithm.
template<typename Measure >
void Train (Measure &measure)
 Start async training.
TrainingConfig & Config ()
 Modify the training config.
const TrainingConfig & Config () const
 Get the training config.
NetworkType & Network ()
 Modify the learning network.
const NetworkType & Network () const
 Get the learning network.
PolicyType & Policy ()
 Modify the behavior policy.
const PolicyType & Policy () const
 Get the behavior policy.
UpdaterType & Updater ()
 Modify the optimizer.
const UpdaterType & Updater () const
 Get the optimizer.
EnvironmentType & Environment ()
 Modify the environment.
const EnvironmentType & Environment () const
 Get the environment.
Detailed Description

Wrapper of various asynchronous learning algorithms, e.g. async one-step Q-learning, async one-step Sarsa, async n-step Q-learning, and async advantage actor-critic.

For more details, see the following:

 Mnih, V., Badia, A. P., Mirza, M., Graves, A., Lillicrap, T. P., Harley, T., Silver, D., and Kavukcuoglu, K. "Asynchronous methods for deep reinforcement learning." ICML 2016.
Template Parameters

 WorkerType | The type of the worker.
 EnvironmentType | The type of the reinforcement learning task.
 NetworkType | The type of the network model.
 UpdaterType | The type of the optimizer.
 PolicyType | The type of the behavior policy.
mlpack::rl::AsyncLearning< WorkerType, EnvironmentType, NetworkType, UpdaterType, PolicyType >::AsyncLearning (TrainingConfig config, NetworkType network, PolicyType policy, UpdaterType updater = UpdaterType(), EnvironmentType environment = EnvironmentType())
Construct an instance of the given async learning algorithm.
Parameters

 config | Hyper-parameters for training.
 network | The network model.
 policy | The behavior policy.
 updater | The optimizer.
 environment | The reinforcement learning task.
void mlpack::rl::AsyncLearning< WorkerType, EnvironmentType, NetworkType, UpdaterType, PolicyType >::Train (Measure &measure)
Start async training.
Template Parameters

 Measure | The type of the measurement. It should be a callable object like bool foo(double reward);

Parameters

 measure | The measurement instance.
OpenMP data-sharing clauses don't support class member variables directly, so we copy them to local variables first.
Compute the number of threads for the for-loop. In general, OpenMP tasks would fit better than a for-loop; we use a for-loop here to stay compatible with some compilers, and can switch to OpenMP tasks once MSVC supports OpenMP 3.0.