Fleet  0.0.9
Inference in the LOT
ChainPool.h
#pragma once

#include <thread>

//#define DEBUG_CHAINPOOL

#include <vector>

#include "Errors.h"
#include "MCMCChain.h"
#include "Timing.h"
#include "ThreadedInferenceInterface.h"
#include "OrderedLock.h"

template<typename HYP, typename Chain_t=MCMCChain<HYP>>
class ChainPool : public ThreadedInferenceInterface<HYP> {
public:

    // the pool stores a bunch of chains
    std::vector<Chain_t> pool;

    // this parameter defines how many steps a thread spends on each chain before switching to another.
    // NOTE: it interacts with the ParallelTempering swap/adapt values (if it is too small, then
    // we won't have time to update every chain before proposing more swaps).
    // NOTE: It seems better to set steps rather than time, because the hot chains typically run *much*
    // faster than the cold chains. This means that if you set it by time, you end up spending lots of
    // time on the bad chains, which is the opposite of what you want. Actually, we should probably
    // run for *fewer* samples on the hot chains because they are faster to sample from.
    // Also note that due to multithreading this is a little complex: we can't perfectly control the
    // number of samples each chain gets.
    unsigned long steps_before_change = 100;
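    // Illustrative note (a sketch added here, not from the original file): with 10 chains,
    // steps_before_change=100, and a single worker thread, each chain gets a 100-step chunk
    // roughly once every 1000 total steps; adding worker threads shortens that cycle, and
    // lowering steps_before_change interleaves the chains more finely.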

    // Store which chains are running and which are done
    enum class RunningState {READY, RUNNING, DONE};

    // keep track of which chains are currently running (indexed like pool)
    std::vector<RunningState> running;

    // lock guarding modifications to pool and running
    OrderedLock running_lock;

    ChainPool() {}

    ChainPool(HYP& h0, typename HYP::data_t* d, size_t n) {
        assert(n >= 1 && "*** You probably shouldn't have a chain pool with 0 elements");
        for(size_t i=0;i<n;i++) {
            add_chain(i==0 ? h0 : h0.restart(), d);
        }
    }
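
    // Usage sketch (not from the original file; MyHypothesis and mydata are placeholder
    // names): construct a pool of restarted copies of h0 and run it via the inherited
    // ThreadedInferenceInterface::run, which yields samples coming from all of the chains.
    //
    //   MyHypothesis h0;
    //   ChainPool<MyHypothesis> pool(h0, &mydata, 4);      // 4 chains on the same data
    //   Control ctl;  ctl.steps = 10000;  ctl.nthreads = 2;
    //   for(auto& h : pool.run(ctl)) { /* process each yielded sample h */ }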

    /// Set this data on every chain in the pool.
    void set_data(typename HYP::data_t* d, bool recompute=true) {
        for(auto& c : pool) {
            c.set_data(d, recompute);
        }
    }

    /// Lock and modify the pool: add a new chain, forwarding args to the Chain_t constructor.
    template<typename... ARGS>
    void add_chain(ARGS... args) {

        std::lock_guard guard(running_lock);

        pool.emplace_back(args...);
        running.push_back(RunningState::READY);
    }
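
    // Usage sketch (not from the original file): the arguments are forwarded to the Chain_t
    // (by default MCMCChain<HYP>) constructor, so adding another chain on the same data
    // might look like:
    //
    //   pool.add_chain(h0.restart(), &mydata);   // h0, mydata as in the sketch above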

    size_t nchains() const {
        return pool.size();
    }

    void show(std::string prefix) const {
        for(size_t i=0;i<nchains();i++) {
            std::lock_guard guard(this->pool[i].current_mutex);
            print(prefix, i, (double)this->pool[i].temperature, this->pool[i].getCurrent().posterior, this->pool[i].getCurrent());
        }
    }

    /// This run helper is called internally by multiple different threads, and runs a given pool of chains.
    generator<HYP&> run_thread(Control& ctl) override {

        assert(pool.size() > 0 && "*** Cannot run on an empty ChainPool");
        assert(this->nthreads() <= pool.size() && "*** Cannot have more threads than pool items");

        // We have to manage subthreads pretty differently depending on whether we have a time or a
        // sample constraint. For now, we assume we can't have both.

        // Here we don't care about being precise about the number of steps
        // (below we do)

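        // Scheduling overview (summary comment added for clarity): each worker thread loops,
        // picking the next chain via next_index() and skipping chains that are not READY,
        // marks that chain RUNNING, runs it for a chunk of steps while yielding each sample
        // tagged with born_chain_idx, and then marks the chain READY again (or DONE once the
        // total step budget is exhausted) so another thread can pick it up.
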
        if(ctl.steps == 0) {
            // we end up here if steps==0; if runtime is also 0, we are running until CTRL_C

            while( ctl.running() and (not CTRL_C) ) {
                // find the next open chain
                size_t idx;
                {
                    std::lock_guard guard(running_lock);

                    do {
                        idx = this->next_index() % pool.size();
                    } while( running[idx] != RunningState::READY ); // so we exit on a READY idx
                    running[idx] = RunningState::RUNNING; // say I'm running this one
                }

                // Actually run and yield, being sure to save where everything came from
                Control c = ctl; // make a copy of everything in control
                c.steps = steps_before_change; c.nthreads = 1; c.runtime = 0; // but update to run just this chunk on one thread
                for(auto x : pool[idx].run(c)) {
                    x.born_chain_idx = idx; // set this
                    co_yield x;
                }

                // and free this lock space
                {
                    std::lock_guard guard(running_lock);

                    ctl.done_steps += steps_before_change-1; // -1 because calling ctl.running adds one

                    running[idx] = RunningState::READY;
                }
            }

        }
        else { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

            assert(ctl.runtime == 0 && "*** Cannot have both time and steps specified in ChainPool (not implemented yet).");

            // note that in these while loops we don't use ctl.running(), because we need everything
            // we start to actually finish (or else we won't run enough steps)
            while( ctl.done_steps < ctl.steps and (not CTRL_C) ) {

                // find the next open chain
                size_t idx;
                unsigned long to_run_steps = 0;
                {
                    std::lock_guard guard(running_lock);

                    do {
                        idx = this->next_index() % pool.size();
                    } while( running[idx] != RunningState::READY ); // so we exit on a READY idx
                    running[idx] = RunningState::RUNNING; // say I'm running this one

                    to_run_steps = std::min(ctl.steps-ctl.done_steps, this->steps_before_change);

                    // exit here if there is nothing else to do
                    if(to_run_steps <= 0) {
                        break;
                    }

                    // now update ctl's number of steps
                    // NOTE if we do this here, we'll stop too early; if we do it later, we'll run too many...
                    // hmm.... Need a more complex solution it seems...
                    ctl.done_steps += to_run_steps;

//                  print(">>", idx, ctl.steps, ctl.done_steps, to_run_steps, this->steps_before_change);
                }
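
                // Worked example (comment added for clarity): with ctl.steps = 250 and
                // steps_before_change = 100, worker threads claim chunks of 100, 100, and 50 steps;
                // ctl.done_steps reaches 250 as soon as the last chunk is claimed, even though
                // that chunk is still running, which is the imprecision noted in the comment above.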

                // Actually run and yield, being sure to save where everything came from
                Control c = ctl; // make a copy of everything in control
                c.steps = to_run_steps; c.nthreads = 1; c.runtime = 0; // but update to run just this chunk on one thread
                for(auto& x : pool[idx].run(c)) {
                    x.born_chain_idx = idx; // set this
                    co_yield x;
                }

                #ifdef DEBUG_CHAINPOOL
                COUT "# Thread " << std::this_thread::get_id() << " stopping chain " << idx TAB "at " TAB pool[idx].getCurrent().posterior TAB pool[idx].getCurrent().string() ENDL;
                #endif

                // and free this lock space
                {
                    std::lock_guard guard(running_lock);

                    // If we ran a full chunk, this chain may have more work to do, so mark it READY;
                    // otherwise we just consumed the last remaining steps and this chain is DONE.
                    if(to_run_steps == steps_before_change) {
                        running[idx] = RunningState::READY;
                    }
                    else {
                        running[idx] = RunningState::DONE; // this chain has finished its share of steps
                    }
                }
            }

        }

    }

};