40 template<
typename this_t,
typename HYP>
67 BaseNode<this_t>(0,par,0), open(true), which_expansion(w), nvisits(0) {
78 open(true), which_expansion(0), nvisits(0) {
100 throw YouShouldNotBeHereError(
"*** This must be defined for children.emplace_back, but should never be called");
110 void print(HYP from, std::ostream& o,
const int depth,
const bool sort) {
113 if(nvisits == 0)
return;
115 std::string idnt(depth,
'\t');
117 std::string opn = (open?
" ":
"*");
124 std::vector<std::pair<this_t*,int>> c2;
127 c2.emplace_back(&c,w);
130 std::sort(c2.begin(),
132 [](
const auto a,
const auto b) {
133 return a.first->nvisits > b.first->nvisits;
138 HYP newfrom = from; newfrom.expand_to_neighbor(c.second);
139 c.first->print(newfrom, o, depth+1, sort);
145 HYP newfrom = from; newfrom.expand_to_neighbor(w);
146 c.print(newfrom, o, depth+1, sort);
153 void print(HYP& start,
const bool sort=
true) {
154 print(start, std::cout, 0, sort);
157 void print(HYP& start,
const char* filename,
const bool sort=
true) {
158 std::ofstream out(filename);
159 print(start, out, 0, sort);
168 ptr->statistics << v;
170 }
while (ptr !=
nullptr);
197 current.compute_posterior(*data);
207 int neigh = current.neighbors();
212 while(this->
nchildren() < (
size_t)neigh) {
214 HYP kc = current; kc.expand_to_neighbor(k);
215 this->
get_children().emplace_back(kc, reinterpret_cast<this_t*>(
this), k);
225 int neigh = current.neighbors();
228 std::vector<double> children_lps(neigh, -
infinity);
231 bool all_visited =
true;
232 for(
int k=0;k<neigh;k++) {
233 if(this->
child(k).open and this->
child(k).nvisits == 0) {
235 children_lps[k] = current.neighbor_prior(k);
243 for(
int k=0;k<neigh;k++) {
244 if(this->
child(k).open){
245 if(this->
child(k).nvisits > 0) {
246 children_lps[k] = exp(this->
child(k).statistics.
max-this->statistics.max) +
278 for(
int k=0;k<neigh;k++) {
279 if(this->
child(k).open and not std::isnan(children_lps[k])) {
292 return arg_max_int(neigh, [&](
const int i) ->
double {
return children_lps[i];} ).first;
301 if(
DEBUG_MCTS)
DEBUG(
"descend_to_childless ",
this,
"\t["+current.string()+
"] ", (
unsigned long)this->nvisits);
305 if(current.is_evaluable()) {
306 return reinterpret_cast<this_t*
>(
this);
310 if(this->
children.size() < (size_t)current.neighbors()) {
312 return reinterpret_cast<this_t*
>(
this);
317 current.expand_to_neighbor(idx);
319 return this->
children[idx].descend_to_childless(current);
327 if(
DEBUG_MCTS)
DEBUG(
"descend_to_evaluable ",
this,
"\t["+current.string()+
"] ", (
unsigned long)this->nvisits);
331 if(current.is_evaluable()) {
332 return reinterpret_cast<this_t*
>(
this);
336 if(this->
children.size() < (size_t)current.neighbors()) {
343 current.expand_to_neighbor(idx);
345 return this->
children[idx].descend_to_evaluable(current);
357 template<
typename this_t,
typename HYP>
360 template<
typename this_t,
typename HYP>
void operator=(const MCTSBase &m)
Definition: MCTSBase.h:103
MCTSBase()
Definition: MCTSBase.h:63
T myrandom(T max)
Definition: Random.h:176
double max
Definition: StreamingStatistics.h:21
Definition: SpinLock.h:10
int which_expansion
Definition: MCTSBase.h:52
void reserve_children(const size_t n)
Definition: BaseNode.h:164
MCTSBase(this_t &&m)
Definition: MCTSBase.h:96
decltype(children) & get_children()
Definition: BaseNode.h:168
std::vector< this_t > children
Definition: BaseNode.h:23
void lock()
Definition: SpinLock.h:14
This is a general tree class, which we are adding because there are currently at least 3 different tr...
virtual generator< HYP & > search_one(HYP &current)=0
This is not implemented in MCTSBase because it is different in Partial and Full (below). So, subclasses must implement this.
double mean
Definition: StreamingStatistics.h:24
#define TAB
Definition: IO.h:19
#define DEBUG_MCTS
Definition: MCTSBase.h:3
double explore
Definition: FleetArgs.h:19
bool open
Definition: MCTSBase.h:48
Definition: StreamingStatistics.h:14
virtual this_t * descend_to_evaluable(HYP &current)
This goes down the tree, sampling children until it finds an evaluable (building a full tree) ...
Definition: MCTSBase.h:326
MCTSBase(HYP &start, double ex, data_t *d)
Definition: MCTSBase.h:76
virtual generator< HYP & > run_thread(Control &ctl, HYP h0) override
Definition: MCTSBase.h:174
virtual int sample_child_index(HYP &current)
Definition: MCTSBase.h:221
Inference by sampling from the prior – doesn't tend to work well, but might be a useful baseline...
void print(HYP &start, const char *filename, const bool sort=true)
Definition: MCTSBase.h:157
typename HYP::data_t data_t
Definition: MCTSBase.h:46
void start()
Definition: Control.h:54
This manages multiple threads for running inference. This requires a subclass to define run_thread...
void unlock()
Definition: SpinLock.h:17
Definition: BaseNode.h:20
virtual void process_evaluable(HYP &current)
If we can evaluate this current node (usually: compute a posterior and add_sample) ...
Definition: MCTSBase.h:193
constexpr double infinity
Definition: Numerics.h:20
Definition: generator.hpp:21
SpinLock mylock
Definition: MCTSBase.h:50
virtual this_t * descend_to_childless(HYP &current)
This goes down the tree to a node with no children (OR evaluable)
Definition: MCTSBase.h:300
void print(HYP from, std::ostream &o, const int depth, const bool sort)
Definition: MCTSBase.h:110
double get_sd()
Definition: StreamingStatistics.h:102
this_t & child(const size_t i)
Definition: BaseNode.h:175
MCTSBase(HYP &start, this_t *par, size_t w)
Definition: MCTSBase.h:66
double N
Definition: StreamingStatistics.h:27
virtual void add_children(HYP &current)
This gets called before descending the tree if we don't have all of our children. NOTE: This could ad...
Definition: MCTSBase.h:206
#define ENDL
Definition: IO.h:21
Definition: MCTSBase.h:41
This is a thread_local rng whose first object is used to seed others (in other threads). This way, we can have thread_local rngs that are all seeded deterministically in Fleet via --seed=X.
void add_sample(const float v)
Definition: MCTSBase.h:163
void DEBUG(FIRST f, ARGS... args)
Print to std::cout with debugging info.
Definition: IO.h:73
std::pair< size_t, double > arg_max_int(unsigned int max, const std::function< double(const int)> &f)
Definition: Random.h:395
Definition: ThreadedInferenceInterface.h:22
bool running()
Definition: Control.h:63
size_t depth() const
Definition: BaseNode.h:230
static double explore
Definition: MCTSBase.h:56
MCTSBase(const this_t &m)
Definition: MCTSBase.h:92
std::atomic< unsigned int > nvisits
Definition: MCTSBase.h:59
A class to store a bunch of statistics about incoming data points, including min, max...
void print(HYP &start, const bool sort=true)
Definition: MCTSBase.h:153
static data_t * data
Definition: MCTSBase.h:57
This class has all the information for running MCMC or MCTS in a little package. It defaultly constru...
void operator=(MCTSBase &&m)
Definition: MCTSBase.h:106
StreamingStatistics statistics
Definition: MCTSBase.h:61
size_t nchildren() const
Definition: BaseNode.h:208