Fleet  0.0.9
Inference in the LOT
ThreadedInferenceInterface.h
Go to the documentation of this file.
1 #pragma once
2 
3 
4 #include <atomic>
5 #include <mutex>
6 #include <thread>
7 
8 #include "Control.h"
9 #include "SampleStreams.h"
10 #include "ConcurrentQueue.h"
11 
template<typename X, typename... Args>
// NOTE(review): the class declaration line itself (class ThreadedInferenceInterface ...)
// and the declaration of the to_yield member (a ConcurrentQueueRing<X>, judging from
// the constructor initializer below) are not visible in this extract.
public:

	// Subclasses must implement run_thread, which is what each individual thread
	// gets called on (and each thread manages its own locks etc). It is a coroutine
	// generator: each produced X& is one sample/item for the consumer.
	virtual generator<X&> run_thread(Control& ctl, Args... args) = 0;

	// index here is used to index into larger parallel collections. Each thread
	// is expected to get its next item to work on through index, though how will vary
	std::atomic<size_t> index;

	// How many threads? Used by some subclasses as asserts; set by run() from ctl.nthreads
	size_t __nthreads;
	std::atomic<size_t> __nrunning; // count of live workers; each decrements it when done, run() polls it to know when to stop draining

	// Construct with zeroed counters; to_yield's capacity comes from FleetArgs::nthreads
	ThreadedInferenceInterface() : index(0), __nthreads(0), __nrunning(0), to_yield(FleetArgs::nthreads) { }
45  unsigned long next_index() { return index++; }
46 
51  size_t nthreads() { return __nthreads; }
52 
60  void run_thread_generator_wrapper(size_t thr, Control& ctl, Args... args) {
61 
62  for(auto& x : run_thread(ctl, args...)) {
63 
64  if(x.born_chain_idx == 0 or not FleetArgs::yieldOnlyChainOne) {
65  to_yield.push(x, thr);
66 
67  }
68 
69  if(CTRL_C) break;
70  }
71 
72  // we always notify when we're done, after making sure we're not running or else the
73  // other thread can block
74  __nrunning--;
75  }
76 
	// Set up ctl.nthreads worker threads (each running run_thread via
	// run_thread_generator_wrapper), then co_yield samples out of to_yield as
	// the workers push them, draining the queue and joining all workers before
	// returning. Note ctl is taken by value: each call gets its own copy.
	generator<X&> run(Control ctl, Args... args) {

		std::vector<std::thread> threads(ctl.nthreads);
		__nthreads = ctl.nthreads; // save this for children
		assert(__nrunning==0); // no workers from a previous run may still be active

		// Make a new control to run on each thread and then pass this to
		// each subthread. This way multiple threads all share the same control
		// which is required for getting an accurate total count
		Control ctl2 = ctl;
		ctl2.nthreads = 1;
		ctl2.start();

		// give this just some extra space here
		//to_yield.resize(FleetArgs::MCMC_QUEUE_MULTIPLIER*ctl.nthreads); // just some extra space here

		// start each thread; __nrunning is incremented here (not inside the
		// worker) so the drain loop below cannot observe a zero count before a
		// slow worker has even started
		for(unsigned long thr=0;thr<ctl.nthreads;thr++) {
			++__nrunning;
			threads[thr] = std::thread(&ThreadedInferenceInterface<X, Args...>::run_thread_generator_wrapper, this, thr, std::ref(ctl2), args...);
		}

		// now yield as long as we have some that are running
		while(__nrunning > 0 and !CTRL_C) { // we don't want to stop when it's empty because a thread might fill it
//			if(not to_yield.empty()) { // w/o this we might pop when its empty...
//				//print((size_t)to_yield.push_idx, (size_t)to_yield.pop_idx, to_yield.size(), to_yield.N);
//				co_yield to_yield.pop();
//			}
			if(not to_yield.empty()) {
				auto val = to_yield.pop(); // pop() returns a std::optional<X>
				// NOTE(review): val is a local copy, yet we co_yield X& -- the
				// reference is only valid until the next resume; confirm callers
				// never retain it across iterations
				if(val.has_value()) co_yield val.value();
				else break; // empty optional signals shutdown (presumably CTRL_C inside pop -- verify in ConcurrentQueue.h)
			}

		}

		// now we're done filling but we still may have stuff in the queue;
		// some threads may be waiting (blocked on a full queue) so we can't join yet
		while(not to_yield.empty()) {
			auto val = to_yield.pop(); // drain whatever the workers left behind
			if(val.has_value()) co_yield val.value();
			else break; // NOTE: we might break here and leave some in queue -- but we only get break on CTRL_C
		}

		// wait for all to complete
		for(auto& t : threads)
			t.join();

	}
131 
132  generator<X&> unthreaded_run(Control ctl, Args... args) {
133 
134  // This is a simple version where we don't use threads -- useful
135  // because it's hard to debug otherwise
136 
137  std::cerr << "*** Warning running unthreaded_run (intended for debugging)" << std::endl;
138  ctl.start();
139  for(auto& x : run_thread(ctl, args...)) {
140  if(CTRL_C) break; // must come first or else we yield something invalid
141 
142  auto val = to_yield.pop(); // search through until we find one
143  if(val.has_value()) co_yield val.value();
144  else break;
145  }
146  }
147 
148 };
149 
150 
void push(const T &item, size_t thr)
Definition: ConcurrentQueue.h:116
virtual generator< X & > run_thread(Control &ctl, Args... args)=0
bool empty()
Definition: ConcurrentQueue.h:137
size_t nthreads()
How many threads are currently run in this interface?
Definition: ThreadedInferenceInterface.h:51
A concurrent queue class that allows multiple threads to push and consume. Note that this has a fixed...
volatile sig_atomic_t CTRL_C
void start()
Definition: Control.h:54
Definition: Control.h:23
bool yieldOnlyChainOne
Definition: FleetArgs.h:56
std::atomic< size_t > __nrunning
Definition: ThreadedInferenceInterface.h:35
ThreadedInferenceInterface()
Definition: ThreadedInferenceInterface.h:39
Definition: generator.hpp:21
unsigned long next_index()
Return the next index to operate on (in a thread-safe way).
Definition: ThreadedInferenceInterface.h:45
void run_thread_generator_wrapper(size_t thr, Control &ctl, Args... args)
We have to wrap run_thread in something that manages the sync with main. This really just synchronize...
Definition: ThreadedInferenceInterface.h:60
ConcurrentQueueRing< X > to_yield
Definition: ThreadedInferenceInterface.h:37
Primitive< typename Grammar_t::input_t > X(Op::X, BUILTIN_LAMBDA { assert(!vms->xstack.empty());vms->template push< typename Grammar_t::input_t >(vms->xstack.top());})
size_t nthreads
Definition: Control.h:26
std::atomic< size_t > index
Definition: ThreadedInferenceInterface.h:31
size_t __nthreads
Definition: ThreadedInferenceInterface.h:34
Definition: ThreadedInferenceInterface.h:22
generator< X & > unthreaded_run(Control ctl, Args... args)
Definition: ThreadedInferenceInterface.h:132
generator< X & > run(Control ctl, Args... args)
Set up the multiple threads and actually run, calling run_thread_generator_wrapper.
Definition: ThreadedInferenceInterface.h:82
This class has all the information for running MCMC or MCTS in a little package. It defaultly constru...
Definition: FleetArgs.h:10
std::optional< T > pop()
Definition: ConcurrentQueue.h:121