/**
 * @file async_learning_impl.hpp
 * @author Shangtong Zhang
 *
 * This file is the implementation of AsyncLearning class,
 * which is a wrapper for various asynchronous learning algorithms.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license.  You should have received a copy of the
 * 3-clause BSD license along with mlpack.  If not, see
 * http://www.mlpack.org/licenses/ for more information.
 */
#ifndef MLPACK_METHODS_RL_ASYNC_LEARNING_IMPL_HPP
#define MLPACK_METHODS_RL_ASYNC_LEARNING_IMPL_HPP

#include <mlpack/prereqs.hpp>
#include <queue>

namespace mlpack {
namespace rl {

template <
  typename WorkerType,
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
AsyncLearning<
  WorkerType,
  EnvironmentType,
  NetworkType,
  UpdaterType,
  PolicyType
>::AsyncLearning(
    TrainingConfig config,
    NetworkType network,
    PolicyType policy,
    UpdaterType updater,
    EnvironmentType environment):
    config(std::move(config)),
    learningNetwork(std::move(network)),
    policy(std::move(policy)),
    updater(std::move(updater)),
    environment(std::move(environment))
{ /* Nothing to do here. */ }
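
// A minimal construction sketch (MyWorker, MyEnv, MyNetwork, MyUpdater, and
// MyPolicy are hypothetical placeholder types, not mlpack APIs; the argument
// order follows the constructor above):
//
//   TrainingConfig config;
//   AsyncLearning<MyWorker, MyEnv, MyNetwork, MyUpdater, MyPolicy> learner(
//       config, MyNetwork(), MyPolicy(), MyUpdater(), MyEnv());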

template <
  typename WorkerType,
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
template <typename Measure>
void AsyncLearning<
  WorkerType,
  EnvironmentType,
  NetworkType,
  UpdaterType,
  PolicyType
>::Train(Measure& measure)
{
  /**
   * OpenMP doesn't support sharing class member variables between threads,
   * so copy everything the workers need into local variables first.
   */
  NetworkType learningNetwork = std::move(this->learningNetwork);
  if (learningNetwork.Parameters().is_empty())
    learningNetwork.ResetParameters();
  NetworkType targetNetwork = learningNetwork;
  size_t totalSteps = 0;
  PolicyType policy = this->policy;
  bool stop = false;

  // Set up the worker pool; worker 0 is deterministic and used for
  // evaluation, so config.NumWorkers() + 1 workers are created in total.
  std::vector<WorkerType> workers;
  for (size_t i = 0; i <= config.NumWorkers(); ++i)
  {
    workers.push_back(WorkerType(updater, environment, config, !i));
    workers.back().Initialize(learningNetwork);
  }
  // Set up the task queue corresponding to the worker pool.
  std::queue<size_t> tasks;
  for (size_t i = 0; i <= config.NumWorkers(); ++i)
    tasks.push(i);

  /**
   * Determine the number of threads OpenMP will actually launch: each thread
   * in the parallel region increments its private copy of numThreads once,
   * and the reduction sums the copies.
   */
  size_t numThreads = 0;
  #pragma omp parallel reduction(+:numThreads)
  numThreads++;
  Log::Debug << numThreads << " threads will be used in total." << std::endl;

  #pragma omp parallel for shared(stop, workers, tasks, learningNetwork, \
      targetNetwork, totalSteps, policy)
  for (omp_size_t i = 0; i < numThreads; ++i)
  {
    #pragma omp critical
    {
      #ifdef HAS_OPENMP
      Log::Debug << "Thread " << omp_get_thread_num() <<
          " started." << std::endl;
      #endif
    }
    size_t task = std::numeric_limits<size_t>::max();
    while (!stop)
    {
      // Claim a task (a worker index) from the queue for this thread,
      // returning the previously held task first so another thread can
      // take it over.
      #pragma omp critical
      {
        if (task != std::numeric_limits<size_t>::max())
          tasks.push(task);

        if (!tasks.empty())
        {
          task = tasks.front();
          tasks.pop();
        }
      }

      // This can happen when there are more threads than workers.
      if (task == std::numeric_limits<size_t>::max())
        continue;

      // Get the corresponding worker.
      WorkerType& worker = workers[task];
      double episodeReturn;
      // Step the worker; if an episode finished and this is the
      // deterministic evaluation worker (task 0), report its return to the
      // measure, which decides whether training should stop.
      if (worker.Step(learningNetwork, targetNetwork, totalSteps,
          policy, episodeReturn) && !task)
      {
        stop = measure(episodeReturn);
      }
    }
  }

  // Write back the learning network.
  this->learningNetwork = std::move(learningNetwork);
}
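
// A minimal sketch of a Measure callable for Train() (hypothetical example,
// not part of mlpack): Train() passes the return of each finished evaluation
// episode to the measure and stops training once the measure returns true.
// Assumes <vector> and <numeric> are available.
//
//   struct MeanReturnMeasure
//   {
//     std::vector<double> returns;
//     bool operator()(double episodeReturn)
//     {
//       returns.push_back(episodeReturn);
//       double mean = std::accumulate(returns.begin(), returns.end(), 0.0) /
//           returns.size();
//       // The threshold is an assumption; pick one suited to the task.
//       return mean >= 100.0;
//     }
//   };
//
//   MeanReturnMeasure measure;
//   learner.Train(measure);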

} // namespace rl
} // namespace mlpack

#endif