Fleet  0.0.9
Inference in the LOT
Lexicon.h
Go to the documentation of this file.
1 #pragma once
2 
3 #include <map>
4 #include <limits.h>
5 
9 
19 template<typename this_t,
20  typename _key_t,
21  typename INNER,
22  typename _input_t,
23  typename _output_t,
24  typename datum_t=defaultdatum_t<_input_t, _output_t>,
25  typename _VirtualMachineState_t=typename INNER::Grammar_t::VirtualMachineState_t
26  >
27 class Lexicon : public MCMCable<this_t,datum_t>,
28  public Searchable<this_t, _input_t, _output_t>, // TODO: Interface a little broken
29  public Serializable<this_t>,
30  public ProgramLoader<_VirtualMachineState_t>
31 {
32 public:
33  using Grammar_t = typename INNER::Grammar_t;
34  using input_t = _input_t;
35  using output_t = _output_t;
36  using key_t = _key_t;
37  using VirtualMachineState_t = _VirtualMachineState_t;
38 
39  // Store a lexicon of type INNER elements
40  const static char FactorDelimiter = '|';
41 
42  static double p_factor_propose;
43 
44  std::map<key_t,INNER> factors;
45 
46  Lexicon() : MCMCable<this_t,datum_t>() { }
47 
52  size_t nfactors() const {
53  return factors.size();
54  }
55 
56  // A lexicon's value is just that vector (this is used by GrammarHypothesis)
57  auto& get_value() { return factors; }
58  const auto& get_value() const { return factors; }
59 
60  INNER& at(const key_t& k) { return factors.at(k); }
61  const INNER& at(const key_t& k) const { return factors.at(k); }
62 
63  INNER& operator[](const key_t& k) { return factors[k]; }
64  const INNER& operator[](const key_t& k) const { return factors[k]; }
65 
66  bool contains(const key_t& key) {
67  return factors.contains(key);
68  }
69 
71  // NOTE that since LOTHypothesis has grammar as its type, all INNER Must have the
72  // same grammar type (But this may change in the future -- if it does, we need to
73  // update GrammarHypothesis::set_hypotheses_and_data)
74  return factors.begin()->second.get_grammar();
75  }
76 
77 
83  [[nodiscard]] static this_t sample(std::initializer_list<key_t> lst) {
84 
85  this_t out;
86  for(auto& k : lst){
87  out[k] = INNER::sample();
88  }
89 
90  return out;
91  }
92  [[nodiscard]] static this_t sample(const std::vector<key_t>& lst) {
93 
94  this_t out;
95  for(auto& k : lst){
96  out[k] = INNER::sample();
97  }
98 
99  return out;
100  }
101 
102  virtual std::string string(std::string prefix="") const override {
108  std::string s = prefix + "[";
109  for(auto& [k,f] : factors) {
110  s += "F("+str(k)+",x):=" + f.string() + ". ";
111  }
112  s.erase(s.size()-1); // remove last empty space
113  s.append("]");
114  return s;
115  }
116 
117  virtual size_t hash() const override {
123  std::hash<size_t> h;
124  size_t out = h(factors.size());
125  size_t i=0;
126  for(auto& [k,f] : factors){
127  hash_combine(out, f.hash(), k);
128  i++;
129  }
130  return out;
131  }
132 
138  virtual bool operator==(const this_t& l) const override {
139  return factors == l.factors;
140  }
141 
142 // /**
143 // * @brief A lexicon has valid indices if calls to op_RECURSE, op_MEM_RECURSE, op_SAFE_RECURSE, and op_SAFE_MEM_RECURSE all have arguments that are less than the size.
144 // * (So this places no restrictions on the calling earlier factors)
145 // * @return
146 // */
147 // bool has_valid_indices() const {
148 // for(auto& [k, f] : factors) {
149 // for(const auto& n : f.get_value() ) {
150 // if(n.rule->is_recursive()) {
151 // int fi = n.rule->arg; // which factor is called?
152 // if(fi >= (int)factors.size() or fi < 0)
153 // return false;
154 // }
155 // }
156 // }
157 // return true;
158 // }
159 
160 //
161 // bool check_reachable() const {
162 // /**
163 // * @brief Check if the last factor call everything else transitively (e.g. are we "wasting" factors)
164 // * We do this by making a graph of what factors call which others and then computing the transitive closure.
165 // * NOTE that this requires the key_type and assumes that a rule of that type can be gotten directly
166 // * from the first child of a recursive call (e.g. a terminal)
167 // * @return
168 // */
169 //
170 // const size_t N = factors.size();
171 // assert(N > 0);
172 //
173 // // is calls[i][j] stores whether factor i calls factor j
174 // //bool calls[N][N];
175 // std::vector<std::vector<bool> > calls(N, std::vector<bool>(N, false));
176 //
177 // // everyone calls themselves, zero the rest
178 // for(size_t i=0;i<N;i++) {
179 // for(size_t j=0;j<N;j++){
180 // calls[i][j] = (i==j);
181 // }
182 // }
183 //
184 // {
185 // int i=0;
186 // for(auto& [k, f] : factors) {
187 // for(const auto& n : f.get_value() ) {
188 // if(n.rule->is_recursive()) {
189 //
190 // // NOTE This assumes that n.child[0] is directly evaluable.
191 // const Rule* r = n.child(0).rule;
192 // assert(r->is_terminal()); // or else we can't use this
193 // auto fptr = reinterpret_cast<VirtualMachineState_t::FT*>(r->fptr);
194 // CERR string() ENDL;
195 // CERR ((*fptr)(nullptr,0)) ENDL;
196 //
197 // BLEH this doesn't work well anymore because we have to run the function
198 //
199 // calls[i][n.rule->arg] = true;
200 // }
201 // }
202 // i++;
203 // }
204 // }
205 //
206 // // now we take the transitive closure to see if calls[N-1] calls everything (eventually)
207 // // otherwise it has probability of zero
208 // // TOOD: This could probably be lazier because we really only need to check reachability
209 // for(size_t a=0;a<N;a++) {
210 // for(size_t b=0;b<N;b++) {
211 // for(size_t c=0;c<N;c++) {
212 // calls[b][c] = calls[b][c] or (calls[b][a] and calls[a][c]);
213 // }
214 // }
215 // }
216 //
217 // // don't do anything if we have uncalled functions from the root
218 // for(size_t i=0;i<N;i++) {
219 // if(not calls[N-1][i]) {
220 // return false;
221 // }
222 // }
223 // return true;
224 // }
225 
226 
227  /********************************************************
228  * Required for VMS to dispatch to the right sub
229  ********************************************************/
230 
236  virtual void push_program(Program<VirtualMachineState_t>& s, const key_t k) override {
237  this->was_called = true; // set this since we're a program loader
238  // dispath to the right factor
239  factors.at(k).push_program(s); // on a LOTHypothesis, we must call wiht j=0 (j is used in Lexicon to select the right one)
240  }
241 
242  /********************************************************
243  * Implementation of MCMCable interace
244  ********************************************************/
245 
246  virtual void complete() override {
247  for(auto& [k, f] : factors){
248  f.complete();
249  }
250  }
251 
252  virtual double compute_prior() override {
253  // this uses a proper prior which flips a coin to determine the number of factors
254 
255  this->prior = log(0.5)*(factors.size()+1); // +1 to end adding factors, as in a geometric
256 
257  for(auto& [k,f] : factors) {
258  this->prior += f.compute_prior();
259  }
260 
261  return this->prior;
262  }
263 
269  [[nodiscard]] virtual std::optional<std::pair<this_t,double>> propose() const override {
270 
271  // let's first make a vector to see which factor we propose to.
272  auto should_propose = random_nonempty_subset(factors.size(), p_factor_propose);
273 
274  // now go through and propose to those factors
275  // (NOTE fb is always zero)
276  // NOTE: This is not great because it doesn't copy like we might want...
277  this_t x; double fb = 0.0;
278  int idx = 0;
279  for(auto& [k,f] : factors) {
280  if(should_propose[idx]) {
281  auto p = f.propose();
282  if(p){
283  auto [h, _fb] = p.value();
284  x.factors[k] = h;
285  fb += _fb;
286  }
287  else {
288  x.factors[k] = f; // on failed proposal just copy
289  }
290  } else {
291  x.factors[k] = f;
292  }
293  idx++;
294  }
295  assert(x.factors.size() == factors.size());
296 
297  return std::make_pair(x,fb);
298  }
299 
300 
301  [[nodiscard]] virtual this_t restart() const override {
302  this_t x;
303  for(auto& [k,f] : factors) {
304  x.factors[k] = f.restart();
305  }
306  return x;
307  }
308 
309  template<typename... A>
310  [[nodiscard]] static this_t make(A... a) {
311  auto h = this_t(a...);
312  return h.restart();
313  }
314 
315  /********************************************************
316  * Implementation of Searchable interace
317  ********************************************************/
318  // Main complication is that we need to manage nullptrs/empty in order to start, so we piggyback
319  // on LOTHypothesis. To od this, we assume we can construct T with T tmp(grammar, nullptr) and
320  // use that temporary object to compute neighbors etc.
321 
322  // for now, we define neighbors only for *complete* (evaluable) trees -- so you can add a factor if you have a complete tree already
323  // otherwise, no adding factors
324  int neighbors() const override {
325 
326  if(is_evaluable()) {
327  INNER tmp; // make something null, and then count its neighbors
328  return tmp.neighbors();
329  }
330  else {
331  // we should have everything complete except the last
332  return factors.rbegin()->second.neighbors();
333  }
334  }
335 
336  void expand_to_neighbor(int k) override {
337  // This is currently a bit broken -- it used to know when to add a factor
338  // but now we can't add a factor without knowing the key, so for now
339  // this is just assert(false);
340  throw NotImplementedError();
341  }
342 
343 
344  virtual double neighbor_prior(int k) override {
345  assert(k < factors.rbegin()->second.neighbors());
346  return factors.rbegin()->second.neighbor_prior(k);
347  }
348 
349  bool is_evaluable() const override {
350  for(auto& [k,f]: factors) {
351  if(not f.is_evaluable()) return false;
352  }
353  return true;
354  }
355 
356 
357  /********************************************************
358  * How to call
359  ********************************************************/
360 
361  virtual DiscreteDistribution<output_t> call(const key_t k, const input_t x, const output_t& err=output_t{}) {
362  throw NotImplementedError();
363  }
364 
365  virtual std::string serialize() const override {
371  std::string out = str(this->prior) + Lexicon::FactorDelimiter +
374  for(auto& [k,f] : factors) {
375  out += str(k) + Lexicon::FactorDelimiter +
376  f.serialize() + Lexicon::FactorDelimiter;
377  }
378  out.erase(out.size()-1); // remove last delmiter
379  return out;
380  }
381 
382  static this_t deserialize(const std::string s) {
390  this_t h;
391  auto q = split(s, Lexicon::FactorDelimiter);
392 
393  h.prior = string_to<double>(q.front()); q.pop_front();
394  h.likelihood = string_to<double>(q.front()); q.pop_front();
395  h.posterior = string_to<double>(q.front()); q.pop_front();
396 
397  while(not q.empty()){
398  key_t k = string_to<key_t>(q.front()); q.pop_front();
399  auto v = INNER::deserialize(q.front()); q.pop_front();
400  h.factors[k] = v;
401  }
402  return h;
403  }
404 };
405 
406 
407 template<typename this_t,
408  typename key_t,
409  typename INNER,
410  typename _input_t,
411  typename _output_t,
412  typename datum_t,
413  typename _VirtualMachineState_t
414  >
std::vector< bool > random_nonempty_subset(const size_t n, const double p)
Returns a random nonempty subset of n elements, as a vector<bool> of length n with trues for elements...
Definition: Random.h:220
virtual size_t hash() const override
Definition: Lexicon.h:117
virtual void complete() override
Fill in all the holes in this hypothesis, at random, modifying self. NOTE for LOTHypotheses this will...
Definition: Lexicon.h:246
virtual double neighbor_prior(int k) override
What is the prior of the k'th neighbor? This does not need to return the full prior, only the relative prior (among the ks)
Definition: Lexicon.h:344
double likelihood
Definition: Bayesable.h:42
bool is_evaluable() const override
Check if we can evaluate this node (meaning compute a prior and posterior). NOTE that this is not the...
Definition: Lexicon.h:349
typename InnerHypothesis ::Grammar_t::VirtualMachineState_t VirtualMachineState_t
Definition: Lexicon.h:37
const INNER & at(const key_t &k) const
Definition: Lexicon.h:61
Definition: Program.h:17
Definition: MCMCable.h:14
std::pair< t *, double > sample(const T &s, double z, const std::function< double(const t &)> &f=[](const t &v){return 1.0;})
Definition: Random.h:258
A class is searchable if it permits us to enumerate and make its neighbors. This class is used by MCTS a...
Definition: Serializable.h:4
std::deque< std::string > split(const std::string &s, const char delimiter)
Split returns a deque of s split up at the character delimiter. It handles these special cases: sp...
Definition: str.h:44
double prior
Definition: Bayesable.h:41
Definition: DiscreteDistribution.h:25
virtual DiscreteDistribution< output_t > call(const key_t k, const input_t x, const output_t &err=output_t{})
Definition: Lexicon.h:361
static double p_factor_propose
Definition: Lexicon.h:42
Grammar_t * get_grammar()
Definition: Lexicon.h:70
int neighbors() const override
Count the number of neighbors that are possible. (This should be size_t but int is more convenient...
Definition: Lexicon.h:324
const INNER & operator[](const key_t &k) const
Definition: Lexicon.h:64
bool contains(const key_t &key)
Definition: Lexicon.h:66
std::string str(BindingTree *t)
Definition: BindingTree.h:195
double posterior
Definition: Bayesable.h:43
const auto & get_value() const
Definition: Lexicon.h:58
virtual std::optional< std::pair< this_t, double > > propose() const override
This proposal guarantees that there will be at least one factor that is proposed to. Each individual factor is proposed to with p_factor_propose.
Definition: Lexicon.h:269
Definition: Datum.h:15
static this_t sample(const std::vector< key_t > &lst)
Definition: Lexicon.h:92
typename InnerHypothesis ::Grammar_t Grammar_t
Definition: Lexicon.h:33
auto & get_value()
Definition: Lexicon.h:57
virtual std::string serialize() const override
Definition: Lexicon.h:365
static const char FactorDelimiter
Definition: Lexicon.h:40
Definition: MyHypothesis.h:80
Args... datum_t
Definition: Bayesable.h:37
Definition: Lexicon.h:27
std::map< key_t, INNER > factors
Definition: Lexicon.h:44
static this_t make(A... a)
Definition: Lexicon.h:310
Definition: Errors.h:7
Lexicon()
Definition: Lexicon.h:46
The Bayesable class provides an interface for hypotheses that support Bayesian inference (e...
Definition: Searchable.h:13
virtual std::string string(std::string prefix="") const override
Definition: Lexicon.h:102
void expand_to_neighbor(int k) override
Modify this hypothesis to become the k'th neighbor. NOTE This does not compile since it might not be ...
Definition: Lexicon.h:336
static this_t sample(std::initializer_list< key_t > lst)
Sample with n factors.
Definition: Lexicon.h:83
Definition: Program.h:6
size_t nfactors() const
Return the number of factors.
Definition: Lexicon.h:52
static this_t deserialize(const std::string s)
Definition: Lexicon.h:382
virtual this_t restart() const override
Definition: Lexicon.h:301
virtual bool operator==(const this_t &l) const override
Equality checks equality on each part.
Definition: Lexicon.h:138
bool was_called
Definition: Program.h:23
INNER & at(const key_t &k)
Definition: Lexicon.h:60
virtual void push_program(Program< VirtualMachineState_t > &s, const key_t k) override
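Dispatch to the factor stored under key k and push its program onto s, marking this program loader as called. (Doxygen attached the comment of the commented-out has_valid_indices() here: "A lexicon has valid indices if calls to op_RECURSE, op_MEM_RECURSE, op_SAFE_RECURSE, and op_SAFE_MEM_RECURSE all have arguments that are less than the size. (So this places no restrictions on the calling earlier factors)")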
Definition: Lexicon.h:236
A class is MCMCable if it is Bayesable and lets us propose, restart, and check equality (which MCMC d...
virtual double compute_prior() override
Definition: Lexicon.h:252
INNER & operator[](const key_t &k)
Definition: Lexicon.h:63