Fleet  0.0.9
Inference in the LOT
MCTSBase.h
Go to the documentation of this file.
1 #pragma once
2 
3 #define DEBUG_MCTS 0
4 
5 #include <atomic>
6 #include <mutex>
7 #include <set>
8 #include <algorithm>
9 #include <iostream>
10 #include <fstream>
11 #include <functional>
12 
13 #include "StreamingStatistics.h"
15 #include "Control.h"
16 #include "SpinLock.h"
17 #include "Random.h"
18 #include "FleetArgs.h"
19 #include "PriorInference.h"
20 
21 #include "BaseNode.h"
22 
40 template<typename this_t, typename HYP>
41 class MCTSBase : public ThreadedInferenceInterface<HYP, HYP>, public BaseNode<this_t> {
42  friend class BaseNode<this_t>;
43 
44 public:
45 
46  using data_t = typename HYP::data_t;
47 
48  bool open; // am I still an available node?
49 
51 
53 
54  // these are static variables that makes them not have to be stored in each node
55  // this means that if we want to change them for multiple MCTS nodes, we need to subclass
56  static double explore;
57  static data_t* data;
58 
59  std::atomic<unsigned int> nvisits; // how many times have I visited each node?
60 
62 
64  }
65 
66  MCTSBase(HYP& start, this_t* par, size_t w) :
67  BaseNode<this_t>(0,par,0), open(true), which_expansion(w), nvisits(0) {
68  // here we don't expand the children because this is the constructor called when enlarging the tree
69  mylock.lock();
70 // print("MCTSBase constructor ", this, par);
71 
72  this->reserve_children(start.neighbors());
73  mylock.unlock();
74  }
75 
76  MCTSBase(HYP& start, double ex, data_t* d) :
77  BaseNode<this_t>(),
78  open(true), which_expansion(0), nvisits(0) {
79  // This is the constructor that gets called from main, and it sets the static variables. All the other constructors
80  // should use the above one
81 
82  mylock.lock();
83 
84  this->reserve_children(start.neighbors());
85  explore = ex;
86  data = d;
87 
88  mylock.unlock();
89  }
90 
91  // should not copy or move because then the parent pointers get all messed up
92  MCTSBase(const this_t& m) { // because what happens to the child pointers?
93  throw YouShouldNotBeHereError("*** should not be copying or moving MCTS nodes");
94  }
95 
96  MCTSBase(this_t&& m) {
97  // This must be defined for us to emplace_Back, but we don't actually want to move because that messes with
98  // the multithreading. So we'll throw an exception. You can avoid moves by reserving children up above,
99  // which should make it so that we never call this
100  throw YouShouldNotBeHereError("*** This must be defined for children.emplace_back, but should never be called");
101  }
102 
103  void operator=(const MCTSBase& m) {
104  throw YouShouldNotBeHereError("*** This must be defined for but should never be called");
105  }
106  void operator=(MCTSBase&& m) {
107  throw YouShouldNotBeHereError("*** This must be defined for but should never be called");
108  }
109 
110  void print(HYP from, std::ostream& o, const int depth, const bool sort) {
111  // here from is not a reference since we want to copy when we recurse
112 
113  if(nvisits == 0) return; // do nothing
114 
115  std::string idnt(depth, '\t'); // how far should we indent?
116 
117  std::string opn = (open?" ":"*");
118 
119  o << idnt TAB opn TAB statistics.max TAB statistics.N TAB statistics.mean TAB statistics.get_sd() TAB "visits=" << nvisits TAB which_expansion TAB from ENDL;
120 
121  // optional sort
122  if(sort) {
123  // we have to make a copy of our pointer array because otherwise its not const
124  std::vector<std::pair<this_t*,int>> c2;
125  int w = 0;
126  for(auto& c : this->get_children()) {
127  c2.emplace_back(&c,w);
128  ++w;
129  }
130  std::sort(c2.begin(),
131  c2.end(),
132  [](const auto a, const auto b) {
133  return a.first->nvisits > b.first->nvisits;
134  }
135  ); // sort by how many samples
136 
137  for(auto& c : c2) {
138  HYP newfrom = from; newfrom.expand_to_neighbor(c.second);
139  c.first->print(newfrom, o, depth+1, sort);
140  }
141  }
142  else {
143  int w = 0;
144  for(auto& c : this->get_children()) {
145  HYP newfrom = from; newfrom.expand_to_neighbor(w);
146  c.print(newfrom, o, depth+1, sort);
147  ++w;;
148  }
149  }
150  }
151 
152  // wrappers for file io
153  void print(HYP& start, const bool sort=true) {
154  print(start, std::cout, 0, sort);
155  }
156 
157  void print(HYP& start, const char* filename, const bool sort=true) {
158  std::ofstream out(filename);
159  print(start, out, 0, sort);
160  out.close();
161  }
162 
163  void add_sample(const float v) {
164 
165  // walk up to the root, adding please
166  auto ptr = this;
167  do {
168  ptr->statistics << v;
169  ptr = ptr->parent;
170  } while (ptr != nullptr);
171 
172  }
173 
174  virtual generator<HYP&> run_thread(Control& ctl, HYP h0) override {
175  ctl.start();
176 
177  while(ctl.running()) {
178  if(DEBUG_MCTS) DEBUG("\tMCTS SEARCH LOOP");
179 
180  HYP current = h0; // each thread makes a real copy here
181 
182  for(auto& h : this->search_one(current)) {
183  co_yield h;
184  }
185  }
186  }
187 
188 
193  virtual void process_evaluable(HYP& current) {
194  open = false; // make sure nobody else takes this one
195 
196  // if its a terminal, compute the posterior
197  current.compute_posterior(*data);
198  add_sample(current.posterior);
199  }
200 
206  virtual void add_children(HYP& current) {
207  int neigh = current.neighbors();
208 
209  mylock.lock();
210  // TODO: This is a bit inefficient because it copies current a bunch of times...
211  // this is intentionally a while loop in case another thread has snuck in and added
212  while(this->nchildren() < (size_t)neigh) {
213  int k = this->nchildren(); // which to expand
214  HYP kc = current; kc.expand_to_neighbor(k);
215  this->get_children().emplace_back(kc, reinterpret_cast<this_t*>(this), k);
216  }
217  mylock.unlock();
218  }
219 
220 
221  virtual int sample_child_index(HYP& current) {
222 
223  assert(this->children.size() > 0) ;
224 
225  int neigh = current.neighbors();
226 
227  // probability of expanding to each child
228  std::vector<double> children_lps(neigh, -infinity);
229 
230  // first, load up children_lps with whether they have been visited
231  bool all_visited = true;
232  for(int k=0;k<neigh;k++) {
233  if(this->child(k).open and this->child(k).nvisits == 0) {
234  all_visited = false;
235  children_lps[k] = current.neighbor_prior(k);
236  }
237  }
238 
239  // if all the neighbors have been visited, we'll overwrite everything
240  if(all_visited) {
241 
242  // this is basically UCT
243  for(int k=0;k<neigh;k++) {
244  if(this->child(k).open){
245  if(this->child(k).nvisits > 0) { // technically this can happen because multithreading
246  children_lps[k] = exp(this->child(k).statistics.max-this->statistics.max) + //this->statistics.max / this->child(k).statistics.max +
247  FleetArgs::explore * sqrt(log(double(this->nvisits))/this->child(k).nvisits);
248 // children_lps[k] = exp(this->child(k).statistics.median-this->statistics.median) + //this->statistics.max / this->child(k).statistics.max +
249 // FleetArgs::explore * sqrt(log(double(this->nvisits))/this->child(k).nvisits);
250  }
251  }
252  }
253 
254  // We'll use our thing to try to compute the posterior upper bound on the max
255  // here, we essentially assume that the mean and variance of the generative model
256  // are the sample mean and variance. Then, a uniform prior on the max M
257  // will mean that it only affects the likelihod of those N samples
258  // so the probability mass saved per sample is tau = 1-norm_cdf(M)
259  // So, the posterior should look like tau
260 // const double l1ma = log(1.0-0.9); // upper bound of CI we want to compute
261 // for(int k=0;k<neigh;k++) {
262 // auto& c = this->children[k];
263 // if(c.open){
264 // // here, we'll model as an exponential above the max
265 // // and compute the CI
266 // // a = 1 - exp(-sd*x) // CDF of exponential
267 // // x = l1ma / -sd
268 // double sd = c.statistics.get_sd()+1; // the +1 here is just a quick hack to deal with zeroes...
269 // double r = c.statistics.max + l1ma / -sd;
270 // children_lps[k] = (c.nvisits == 0 or std::isnan(sd) ? 0.0 : r);
271 // }
272 // }
273 
274  }
275 
276  // sometimes we'll get all NaNs, which is bad news for sampling
277  bool allNaN = true;
278  for(int k=0;k<neigh;k++) {
279  if(this->child(k).open and not std::isnan(children_lps[k])) {
280  allNaN = false;
281  break;
282  }
283  }
284 
285  if(allNaN) {
286  return myrandom(neigh); // just pick at random if all NaN
287  }
288  else {
289  // choose an index into children
290  //idx = sample_int_lp(neigh, [&](const int i) -> double {return children_lps[i];} ).first;
291  // choose the max (either of prior or of UCT)
292  return arg_max_int(neigh, [&](const int i) -> double {return children_lps[i];} ).first;
293  }
294  }
295 
300  virtual this_t* descend_to_childless(HYP& current) {
301  if(DEBUG_MCTS) DEBUG("descend_to_childless ", this, "\t["+current.string()+"] ", (unsigned long)this->nvisits);
302 
303  this->nvisits++; // change on the way down so other threads don't follow
304 
305  if(current.is_evaluable()) {
306  return reinterpret_cast<this_t*>(this);
307  }
308 
309  // add missing kids if we need them
310  if(this->children.size() < (size_t)current.neighbors()) {
311  this->nvisits--; // take this away so that others will return
312  return reinterpret_cast<this_t*>(this);
313  }
314 
315  // expand
316  auto idx = sample_child_index(current);
317  current.expand_to_neighbor(idx);
318 
319  return this->children[idx].descend_to_childless(current);
320  }
321 
326  virtual this_t* descend_to_evaluable(HYP& current) {
327  if(DEBUG_MCTS) DEBUG("descend_to_evaluable ", this, "\t["+current.string()+"] ", (unsigned long)this->nvisits);
328 
329  this->nvisits++; // change on the way down so other threads don't follow
330 
331  if(current.is_evaluable()) {
332  return reinterpret_cast<this_t*>(this);
333  }
334 
335  // add missing kids if we need them
336  if(this->children.size() < (size_t)current.neighbors()) {
337  this->nvisits--; // take this away so that others will return
338  this->add_children(current);
339  }
340 
341  // expand
342  auto idx = this->sample_child_index(current);
343  current.expand_to_neighbor(idx);
344 
345  return this->children[idx].descend_to_evaluable(current);
346  }
347 
353  virtual generator<HYP&> search_one(HYP& current) = 0;
354 };
355 
356 // Must be defined for the linker to find them, apparently:
357 template<typename this_t, typename HYP>
358 double MCTSBase<this_t, HYP>::explore = 1.0;
359 
360 template<typename this_t, typename HYP>
361 typename HYP::data_t* MCTSBase<this_t, HYP>::data = nullptr;
362 
363 
void operator=(const MCTSBase &m)
Definition: MCTSBase.h:103
MCTSBase()
Definition: MCTSBase.h:63
T myrandom(T max)
Definition: Random.h:176
double max
Definition: StreamingStatistics.h:21
Definition: SpinLock.h:10
int which_expansion
Definition: MCTSBase.h:52
void reserve_children(const size_t n)
Definition: BaseNode.h:164
MCTSBase(this_t &&m)
Definition: MCTSBase.h:96
decltype(children) & get_children()
Definition: BaseNode.h:168
std::vector< this_t > children
Definition: BaseNode.h:23
void lock()
Definition: SpinLock.h:14
This is a general tree class, which we are adding because there are currently at least 3 different tr...
virtual generator< HYP & > search_one(HYP &current)=0
This is not implemented in MCTSBase because it is different in Partial and Full (below). So, subclasses must implement this.
double mean
Definition: StreamingStatistics.h:24
#define TAB
Definition: IO.h:19
#define DEBUG_MCTS
Definition: MCTSBase.h:3
double explore
Definition: FleetArgs.h:19
bool open
Definition: MCTSBase.h:48
Definition: StreamingStatistics.h:14
virtual this_t * descend_to_evaluable(HYP &current)
This goes down the tree, sampling children until it finds an evaluable (building a full tree) ...
Definition: MCTSBase.h:326
MCTSBase(HYP &start, double ex, data_t *d)
Definition: MCTSBase.h:76
virtual generator< HYP & > run_thread(Control &ctl, HYP h0) override
Definition: MCTSBase.h:174
Definition: Errors.h:18
virtual int sample_child_index(HYP &current)
Definition: MCTSBase.h:221
Inference by sampling from the prior – doesn't tend to work well, but might be a useful baseline...
void print(HYP &start, const char *filename, const bool sort=true)
Definition: MCTSBase.h:157
typename HYP::data_t data_t
Definition: MCTSBase.h:46
void start()
Definition: Control.h:54
This manages multiple threads for running inference. This requires a subclass to define run_thread...
Definition: Control.h:23
void unlock()
Definition: SpinLock.h:17
Definition: BaseNode.h:20
virtual void process_evaluable(HYP &current)
If we can evaluate this current node (usually: compute a posterior and add_sample) ...
Definition: MCTSBase.h:193
constexpr double infinity
Definition: Numerics.h:20
Definition: generator.hpp:21
SpinLock mylock
Definition: MCTSBase.h:50
virtual this_t * descend_to_childless(HYP &current)
This goes down the tree to a node with no children (OR evaluable)
Definition: MCTSBase.h:300
void print(HYP from, std::ostream &o, const int depth, const bool sort)
Definition: MCTSBase.h:110
double get_sd()
Definition: StreamingStatistics.h:102
this_t & child(const size_t i)
Definition: BaseNode.h:175
MCTSBase(HYP &start, this_t *par, size_t w)
Definition: MCTSBase.h:66
double N
Definition: StreamingStatistics.h:27
virtual void add_children(HYP &current)
This gets called before descending the tree if we don't have all of our children. NOTE: This could ad...
Definition: MCTSBase.h:206
#define ENDL
Definition: IO.h:21
Definition: MCTSBase.h:41
This is a thread_local rng whose first object is used to seed others (in other threads). This way, we can have thread_local rngs that are all seeded deterministically in Fleet via --seed=X.
void add_sample(const float v)
Definition: MCTSBase.h:163
void DEBUG(FIRST f, ARGS... args)
Print to std::cout with debugging info.
Definition: IO.h:73
std::pair< size_t, double > arg_max_int(unsigned int max, const std::function< double(const int)> &f)
Definition: Random.h:395
Definition: ThreadedInferenceInterface.h:22
bool running()
Definition: Control.h:63
size_t depth() const
Definition: BaseNode.h:230
static double explore
Definition: MCTSBase.h:56
MCTSBase(const this_t &m)
Definition: MCTSBase.h:92
std::atomic< unsigned int > nvisits
Definition: MCTSBase.h:59
A class to store a bunch of statistics about incoming data points, including min, max...
void print(HYP &start, const bool sort=true)
Definition: MCTSBase.h:153
static data_t * data
Definition: MCTSBase.h:57
This class has all the information for running MCMC or MCTS in a little package. It defaultly constru...
void operator=(MCTSBase &&m)
Definition: MCTSBase.h:106
StreamingStatistics statistics
Definition: MCTSBase.h:61
size_t nchildren() const
Definition: BaseNode.h:208