13 #ifndef MLPACK_METHODS_RL_ASYNC_LEARNING_IMPL_HPP 14 #define MLPACK_METHODS_RL_ASYNC_LEARNING_IMPL_HPP 24 typename EnvironmentType,
40 EnvironmentType environment):
41 config(
std::move(config)),
42 learningNetwork(
std::move(network)),
43 policy(
std::move(policy)),
44 updater(
std::move(updater)),
45 environment(
std::move(environment))
50 typename EnvironmentType,
55 template <
typename Measure>
68 NetworkType learningNetwork = std::move(this->learningNetwork);
69 if (learningNetwork.Parameters().is_empty())
70 learningNetwork.ResetParameters();
71 NetworkType targetNetwork = learningNetwork;
72 size_t totalSteps = 0;
73 PolicyType policy = this->policy;
77 std::vector<WorkerType> workers;
78 for (
size_t i = 0; i <= config.
NumWorkers(); ++i)
80 workers.push_back(WorkerType(updater, environment, config, !i));
81 workers.back().Initialize(learningNetwork);
84 std::queue<size_t> tasks;
85 for (
size_t i = 0; i <= config.
NumWorkers(); ++i)
93 size_t numThreads = 0;
94 #pragma omp parallel reduction(+:numThreads) 96 Log::Debug << numThreads <<
" threads will be used in total." << std::endl;
98 #pragma omp parallel for shared(stop, workers, tasks, learningNetwork, \ 99 targetNetwork, totalSteps, policy) 100 for (omp_size_t i = 0; i < numThreads; ++i)
105 Log::Debug <<
"Thread " << omp_get_thread_num() <<
106 " started." << std::endl;
109 size_t task = std::numeric_limits<size_t>::max();
115 if (task != std::numeric_limits<size_t>::max())
120 task = tasks.front();
126 if (task == std::numeric_limits<size_t>::max())
130 WorkerType& worker = workers[task];
131 double episodeReturn;
132 if (worker.Step(learningNetwork, targetNetwork, totalSteps,
133 policy, episodeReturn) && !task)
135 stop = measure(episodeReturn);
141 this->learningNetwork = std::move(learningNetwork);
static MLPACK_EXPORT util::NullOutStream Debug
MLPACK_EXPORT is required for global variables, so that they are properly exported by the Windows compiler.
Definition: log.hpp:79
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
size_t NumWorkers() const
Get the number of workers.
Definition: training_config.hpp:74
The core includes that mlpack expects; standard C++ includes and Armadillo.
Definition: pointer_wrapper.hpp:23
Definition: training_config.hpp:19
void Train(Measure &measure)
Start async training.
Definition: async_learning_impl.hpp:62
Wrapper of various asynchronous learning algorithms, e.g. async one-step Q-learning, async one-step Sarsa, and async n-step Q-learning.
Definition: async_learning.hpp:57