Fleet  0.0.9
Inference in the LOT
ShibbolethSampler.h
Go to the documentation of this file.
1 // TODO: This spends too much time on bad hypotheses since it divides time equally
2 // better might be to constrain the chain. The challenge is figuring out how to update
3 // the constraint/chain so that it is still a correct sampler.
4 
5 // For now, we don't try to do that -- we just sample conditional on a constraint
6 
7 #pragma once
8 
9 #include <utility>
10 #include <functional>
11 #include "SampleStreams.h"
12 #include "ChainPool.h"
13 #include "FiniteHistory.h"
14 #include "Control.h"
15 #include "OrderedLock.h"
16 #include "FleetStatistics.h"
17 #include "Random.h"
18 #include "MCMCChain.h"
19 
20 template<typename HYP>
21 class ConstrainedMCMC : public MCMCChain<HYP> {
22 public:
23  using input_t = HYP::input_t;
24 
25  // my god we need this to figure out the return type of shibboleth_call
26  using output_t = decltype(std::declval<HYP>().shibboleth_call( std::declval<input_t>()));
27 
28  size_t FLIP_EVERY = 10; // how many samples before we
29  bool store = false;
30 
31  size_t steps = 0;
32 
35 
36  ConstrainedMCMC(HYP& h0, input_t& in, typename HYP::data_t* d) :
37  MCMCChain<HYP>(h0,d), sh_in(in) {
38 
39  // we store the output on the shib on our initial sample
40  sh_out = h0.shibboleth_call(sh_in);
41  }
42 
43  // this is called inside MCMCChain to quickly reject hypotheses that don't match
44  virtual bool check(HYP& h) override {
45 
46  steps++;
47  //print(h.string());
48 
49  auto o = h.shibboleth_call(sh_in);
50 
51  if(steps % 1000 == 0) {
52  ::print("good", str(o));
53  sh_out = o;
54  store = false; // we stored it;
55  return true;
56  }
57  else {
58  if(o == sh_out) {
59  ::print("good", str(o), str(h.string()));
60  return true;
61  }
62  else {
63  ::print("bad", str(o), str(h.string()));
64  ++this->samples;
66  return false; // reject this sample otherwise but pretend it was a sample ok?
67  }
68  }
69 
70  }
71  /*
72  // this is called inside MCMCChain to quickly reject hypotheses that don't match
73  virtual bool check(HYP& h) override {
74  steps++;
75 
76  // if we are in an "even" range of FLIP_EVERY we have no constraint
77  // otherwise we constrain to sh_out
78  if(int(steps / FLIP_EVERY) % 2 == 0){
79  store = true;
80  return true;
81  }
82  else { // odd chunk of FLIP_EVERY
83 
84  auto o = h.shibboleth_call(sh_in);
85 
86  if(store) {
87  sh_out = o;
88  store = false; // we stored it;
89  return true;
90  }
91  else {
92  if(o == sh_out) {
93  return true;
94  }
95  else {
96  ++this->samples;
97  ++FleetStatistics::global_sample_count;
98  return false; // reject this sample otherwise but pretend it was a sample ok?
99  }
100  }
101  }
102  }*/
103 
104 };
105 
106 template<typename HYP>
107 class ShibbolethSampler : public ChainPool<HYP,ConstrainedMCMC<HYP>> {
108 
109 public:
111  using Super::Super;
112 
113  ShibbolethSampler(HYP& h0, typename HYP::input_t& in, typename HYP::data_t* d) {
114  this->add_chain(h0, in, d);
115  }
116 };
117 
118 
119 //#define DEBUG_MCMC 1
120 //template<typename HYP> class ShibbolethSampler;
121 //
122 //template<typename HYP>
123 //class ConstrainedMCMC : public MCMCChain<HYP> {
124 //public:
125 // using input_t = HYP::input_t;
126 //
127 // // my god we need this to figure out the return type of shibboleth_call
128 // using output_t = decltype(std::declval<HYP>().shibboleth_call( std::declval<input_t>()));
129 //
130 // input_t sh_in;
131 // output_t sh_out;
132 //
133 // // we need to store parent so that we can add to it when we get a new data type
134 // ShibbolethSampler<HYP>* parent;
135 // typename HYP::data_t* data;
136 //
137 // ConstrainedMCMC(ShibbolethSampler<HYP>* _parent, HYP& h0, input_t& in, typename HYP::data_t* d) :
138 // sh_in(in), parent(_parent), data(d), MCMCChain<HYP>(h0,d) {
139 //
140 // // we store the output on the shib on our initial sample
141 // sh_out = h0.shibboleth_call(sh_in);
142 // }
143 //
144 // // this is called inside MCMCChain to quickly reject hypotheses that don't match
145 // virtual bool check(HYP& h) override {
146 // auto o = h.shibboleth_call(sh_in);
147 // if(o == sh_out) {
148 // //print("# sample ", this->current.posterior, sh_out, h.string());
149 // return true;
150 // }
151 // else {
152 // // add if it doesn't exist already
153 // if(not parent->seen.contains(o)) {
154 // assert(parent->pool.size() + 1 < parent->MAX_POOL_SIZE);
155 // print("# Shibboleth adding", o, h.string());
156 // parent->seen.insert(o); // add this
157 // parent->add_chain(parent, h, sh_in, data); // note, locking done internal to ChainPool
158 // }
159 //
160 // // we're going to count these as samples, even though we won't return current (for speed)
161 // ++this->samples;
162 // ++FleetStatistics::global_sample_count;
163 //
164 // return false; // reject this sample otherwise
165 // }
166 // }
167 //};
168 //
169 //template<typename HYP>
170 //class ShibbolethSampler : public ChainPool<HYP,ConstrainedMCMC<HYP>> {
171 //
172 //public:
173 // using Super = ChainPool<HYP,ConstrainedMCMC<HYP>> ;
174 // using Super::Super;
175 //
176 // // we must pre-allocate the pool, so
177 // static const size_t MAX_POOL_SIZE = 1024;
178 //
179 // using output_t = ConstrainedMCMC<HYP>::output_t;
180 //
181 // std::set<output_t> seen;
182 //
183 // ShibbolethSampler(HYP& h0, typename HYP::input_t& in, typename HYP::data_t* d) {
184 // this->add_chain(this, h0, in, d);
185 // print("# starting with", in);
186 //
187 // // must reserve so they don't move/reallocate -- TODO: FIX THIS
188 // this->pool.reserve(MAX_POOL_SIZE);
189 // this->running.reserve(MAX_POOL_SIZE);
190 //
191 // }
192 //
193 //};
unsigned long samples
Definition: MCMCChain.h:41
size_t steps
Definition: ShibbolethSampler.h:31
A ChainPool stores a bunch of MCMCChains and allows you to run them serially or in parallel...
std::atomic< uintmax_t > global_sample_count(0)
input_t sh_in
Definition: ShibbolethSampler.h:33
Definition: ShibbolethSampler.h:107
HYP HYP
Definition: MCMCChain.h:26
bool store
Definition: ShibbolethSampler.h:29
ShibbolethSampler(HYP &h0, typename HYP::input_t &in, typename HYP::data_t *d)
Definition: ShibbolethSampler.h:113
decltype(std::declval< HYP >().shibboleth_call(std::declval< input_t >())) output_t
Definition: ShibbolethSampler.h:26
Definition: MCMCChain.h:23
std::string str(BindingTree *t)
Definition: BindingTree.h:195
Definition: ChainPool.h:25
size_t FLIP_EVERY
Definition: ShibbolethSampler.h:28
void print(FIRST f, ARGS... args)
Lock output_lock and print to std::cout.
Definition: IO.h:53
A FiniteHistory stores the previous N examples of something of type T. This is used e...
A FIFO mutex (from stackoverflow) https://stackoverflow.com/questions/14792016/creating-a-lock-that-p...
output_t sh_out
Definition: ShibbolethSampler.h:34
ConstrainedMCMC(HYP &h0, input_t &in, typename HYP::data_t *d)
Definition: ShibbolethSampler.h:36
virtual bool check(HYP &h) override
This allows us to overwrite/enforce stuff about proposals in subclasses of MCMCChain.
Definition: ShibbolethSampler.h:44
HYP::input_t input_t
Definition: ShibbolethSampler.h:23
This is a thread_local rng whose first object is used to seed others (in other threads). This way, we can have thread_local rngs that are all seeded deterministically in Fleet via --seed=X.
This represents an MCMC chain on a hypothesis of type HYP. It uses HYP::propose and HYP::compute_poste...
This class has all the information for running MCMC or MCTS in a little package. It defaultly constru...
Definition: ShibbolethSampler.h:21