Fleet  0.0.9
Inference in the LOT
Builtins.h
Go to the documentation of this file.
1 #pragma once
2 
3 #include "Numerics.h"
4 #include "IO.h"
5 #include "Errors.h"
6 #include "Primitive.h"
7 
8 // Define this because it's an ugly expression to keep typing and we might want to change all at once
9 #define BUILTIN_LAMBDA +[](typename Grammar_t::VirtualMachineState_t* vms, int arg) -> void
10 
11 namespace Builtins {
12 
13  template<typename Grammar_t>
15  assert(!vms->xstack.empty());
16  vms->template push<typename Grammar_t::input_t>(vms->xstack.top()); // not a pop!
17  });
18 
19  template<typename Grammar_t>
21 
22  // process the short circuit
23  bool b = vms->template getpop<bool>(); // bool has already evaluated
24 
25  if(!b) {
26  vms->program.popn(arg); // pop off the other branch
27  vms->template push<bool>(false); // entire and must be false
28  }
29  else {
30  // else our value is just the other value -- true when its true and false when its false
31  }
32  });
33 
34  template<typename Grammar_t>
36 
37  // process the short circuit
38  bool b = vms->template getpop<bool>(); // bool has already evaluated
39 
40  if(b) {
41  vms->program.popn(arg); // pop off the other branch
42  vms->template push<bool>(true);
43  }
44  else {
45  // else our value is just the other value -- true when its true and false when its false
46  }
47  });
48 
49  template<typename Grammar_t>
51  vms->push(not vms->template getpop<bool>());
52  });
53 
54 
55  template<typename Grammar_t, typename T>
57  // Op::If has to short circuit and skip (pop some of the stack)
58  bool b = vms->template getpop<bool>(); // bool has already evaluated
59 
60  // now ops must skip the xbranch
61  if(!b) vms->program.popn(arg);
62  else {}; // do nothing, we pass through and then get to the jump we placed at the end of the x branch
63  });
64 
65  template<typename Grammar_t>
67  vms->program.popn(arg);
68  });
69 
70  template<typename Grammar_t>
72  vms->xstack.pop();
73  });
74 
75  template<typename Grammar_t>
77  assert(vms->pool != nullptr);
78 
79  // push both routes onto the stack
80  vms->pool->copy_increment_push(vms, true, -LOG2);
81  bool b = vms->pool->increment_push(vms, false, -LOG2);
82 
83  // TODO: This is clumsy, ugly mechanism -- need to re-do
84 
85  // since we pushed this back onto the queue (via increment_push), we need to tell the
86  // pool not to delete this, so we send back this special signal
87  if(b) { // the outcome of increment_push decides whether I am deleted or not
89  }
90  else {
91  vms->status = vmstatus_t::RANDOM_CHOICE;
92  }
93  });
94 
95  template<typename Grammar_t>
97  assert(vms->pool != nullptr);
98 
99  // get the coin weight
100  double p = vms->template getpop<double>();
101 
102  // some checking
103  if(std::isnan(p)) { p = 0.0; } // treat nans as 0s
104  if(p > 1.0 or p < 0.0) {
105  print("*** Error, received p not in [0,1]:", p);
106  assert(false);
107  }
108 
109  // push both routes onto the stack
110  vms->pool->copy_increment_push(vms, true, log(p));
111  bool b = vms->pool->increment_push(vms, false, log(1.0-p));
112 
113  // TODO: This is clumsy, ugly mechanism -- need to re-do
114 
115  // since we pushed this back onto the queue (via increment_push), we need to tell the
116  // pool not to delete this, so we send back this special signal
117  if(b) { // the outcome of increment_push decides whether I am deleted or not
119  }
120  else {
121  vms->status = vmstatus_t::RANDOM_CHOICE;
122  }
123  });
124 
125 
126  // This is a version of FlipP that doesn't complain if p>1 or p<0 -- it
127  // just sets them to those values
128  template<typename Grammar_t>
130  assert(vms->pool != nullptr);
131 
132  // get the coin weight
133  double p = vms->template getpop<double>();
134 
135  // some checking
136  if(std::isnan(p)) p = 0.0; // treat nans as 0s
137  else if(p > 1.0) p = 1.0;
138  else if(p < 0.0) p = 0.0;
139 
140  // push both routes onto the stack
141  vms->pool->copy_increment_push(vms, true, log(p));
142  bool b = vms->pool->increment_push(vms, false, log(1.0-p));
143 
144  // TODO: This is clumsy, ugly mechanism -- need to re-do
145 
146  // since we pushed this back onto the queue (via increment_push), we need to tell the
147  // pool not to delete this, so we send back this special signal
148  if(b) { // the outcome of increment_push decides whether I am deleted or not
150  }
151  else {
152  vms->status = vmstatus_t::RANDOM_CHOICE;
153  }
154  });
155 
156  template<typename Grammar_t, typename t, typename T=std::set<t>>
158 
159  // implement sampling from the set.
160  // to do this, we read the set and then push all the alternatives onto the stack
161  auto s = vms->template getpop<T>();
162 
163  // One useful optimization here is that sometimes that set only has one element. So first we check that, and if so we don't need to do anything
164  // also this is especially convenient because we only have one element to pop
165  if(s.size() == 1) {
166  auto v = std::move(*s.begin());
167  vms->template push<t>(std::move(v));
168  }
169  else {
170  // else there is more than one, so we have to copy the stack and increment the lp etc for each
171  // NOTE: The function could just be this latter branch, but that's much slower because it copies vms
172  // even for single stack elements
173 
174  // now just push on each, along with their probability
175  // which is here decided to be uniform.
176  const double lp = -log(s.size());
177  for(const auto& x : s) {
178  bool b = vms->pool->copy_increment_push(vms,x,lp);
179  if(not b) break; // we can break since all of these have the same lp -- if we don't add one, we won't add any!
180  }
181 
182  vms->status = vmstatus_t::RANDOM_CHOICE; // we don't continue with this context
183  }
184  });
185 
186  template<typename Grammar_t>
188  // sample 0,1,2,3, ... <first argument>-1
189 
190  const int mx = vms->template getpop<int>();
191 
192  const double lp = -log(mx);
193  for(int i=0;i<mx;i++) {
194  bool b = vms->pool->copy_increment_push(vms,i,lp);
195  if(not b) break; // we can break since all of these have the same lp -- if we don't add one, we won't add any!
196  }
197  vms->status = vmstatus_t::RANDOM_CHOICE; // we don't continue with this context
198  });
199 
200 
201  template<typename Grammar_t, typename key_t, typename output_t=typename Grammar_t::output_t>
203  auto memindex = vms->template memstack<key_t>().top(); vms->template memstack<key_t>().pop();
204  if(vms->template mem<key_t>().count(memindex)==0) { // you might actually have already placed mem in crazy recursive situations, so don't overwrite if you have
205  vms->template mem<key_t>()[memindex] = vms->template gettop<output_t>(); // what I should memoize should be on top here, but don't remove because we also return it
206  }
207  });
208 
209 
210  template<typename Grammar_t>
212  });
213 
214  template<typename Grammar_t>
216  assert(false);
217  });
218 
219  template<typename Grammar_t,
220  typename input_t=typename Grammar_t::input_t,
221  typename output_t=typename Grammar_t::output_t>
223 
224  assert(vms->program.loader != nullptr);
225 
226  if(vms->recursion_depth++ > vms->MAX_RECURSE) { // there is one of these for each recurse
227  throw VMSRuntimeError();
228  }
229 
230  // if we get here, then we have processed our arguments and they are stored in the input_t stack.
231  // so we must move them to the x stack (where they are accessible by op_X)
232  auto mynewx = vms->template getpop<input_t>();
233  vms->xstack.push(std::move(mynewx));
234  vms->program.push(Builtins::PopX<Grammar_t>.makeInstruction()); // we have to remember to remove X once the other program evaluates, *after* everything has evaluated
235 
236  // push this program
237  // but we give i.arg so that we can pass factorized recursed
238  // in argument if we want to
239  vms->program.loader->push_program(vms->program);
240 
241  // after execution is done, the result will be pushed onto output_t
242  // which is what gets returned when we are all done
243 
244  });
245 
246 
247 
248 
249 
250  template<typename Grammar_t,
251  typename input_t=typename Grammar_t::input_t,
252  typename output_t=typename Grammar_t::output_t>
254  assert(not vms->template stack<input_t>().empty());
255 
256  // if the size of the top is zero, we return output{}
257  if(vms->template stack<input_t>().top().size() == 0) {
258  vms->template getpop<input_t>(); // this would have been the argument
259  vms->template push<output_t>(output_t{}); //push default (null) return
260  }
261  else {
262  Recurse<Grammar_t>.call(vms,arg);
263  }
264  });
265 
266 
267 
268  template<typename Grammar_t,
269  typename input_t=typename Grammar_t::input_t,
270  typename output_t=typename Grammar_t::output_t>
271  Primitive<output_t,input_t> MemRecurse(Op::MemRecurse, BUILTIN_LAMBDA { // note the order switch -- that's right!
272  assert(vms->program.loader != nullptr);
273 
274  using mykey_t = short; // this is just the default type used for non-lex recursion
275 
276  if(vms->recursion_depth++ > vms->MAX_RECURSE) { // there is one of these for each recurse
277  throw VMSRuntimeError();
278  }
279 
280  auto x = vms->template getpop<input_t>(); // get the argument
281  auto memindex = std::make_pair(arg,x);
282 
283  if(vms->template mem<mykey_t>().count(memindex)){
284  vms->push(vms->template mem<mykey_t>()[memindex]); // hmm probably should not be a move?
285  }
286  else {
287  vms->xstack.push(x);
288  vms->program.push(Builtins::PopX<Grammar_t>.makeInstruction());
289 
290  vms->template memstack<mykey_t>().push(memindex); // popped off by op_MEM
291  vms->program.push(Builtins::Mem<Grammar_t,mykey_t,output_t>.makeInstruction());
292 
293  vms->program.loader->push_program(vms->program); // this leaves the answer on top
294  }
295  });
296 
297 
298  template<typename Grammar_t,
299  typename input_t=typename Grammar_t::input_t,
300  typename output_t=typename Grammar_t::output_t>
302  assert(not vms->template stack<input_t>().empty());
303 
304  // if the size of the top is zero, we return output{}
305  if(vms->template stack<input_t>().top().size() == 0) {
306  vms->template getpop<input_t>(); // this would have been the argument
307  vms->template push<output_t>(output_t{}); //push default (null) return
308  }
309  else {
310  MemRecurse<Grammar_t>.call(vms, arg);
311  }
312  });
313 
314 
315  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
316 
317 
318 
319  template<typename Grammar_t,
320  typename key_t,
321  typename input_t=typename Grammar_t::input_t,
322  typename output_t=typename Grammar_t::output_t>
324 
325  assert(vms->program.loader != nullptr);
326 
327  if(vms->recursion_depth++ > vms->MAX_RECURSE) { // there is one of these for each recurse
328  throw VMSRuntimeError();
329  }
330 
331  // the key here is the index into the lexicon
332  // NOTE: this is popped before the argument (which is relevant when they are the same type)
333  auto key = vms->template getpop<key_t>();
334 
335  // if we get here, then we have processed our arguments and they are stored in the input_t stack.
336  // so we must move them to the x stack (where they are accessible by op_X)
337  auto mynewx = vms->template getpop<input_t>();
338  vms->xstack.push(std::move(mynewx));
339  vms->program.push(Builtins::PopX<Grammar_t>.makeInstruction()); // we have to remember to remove X once the other program evaluates, *after* everything has evaluated
340 
341  // push this program
342  // but we give i.arg so that we can pass factorized recursed
343  // in argument if we want to
344  vms->program.loader->push_program(vms->program,key);
345  });
346 
347 
348  template<typename Grammar_t,
349  typename key_t,
350  typename input_t=typename Grammar_t::input_t,
351  typename output_t=typename Grammar_t::output_t>
353  assert(not vms->template stack<input_t>().empty());
354  assert(vms->program.loader != nullptr);
355 
356  // if the size of the top is zero, we return output{}
357  if(vms->template stack<input_t>().top().size() == 0) {
358  vms->template getpop<key_t>();
359  vms->template getpop<input_t>(); // this would have been the argument
360 
361  vms->template push<output_t>(output_t{}); //push default (null) return
362  }
363  else {
364  LexiconRecurse<Grammar_t,key_t>.call(vms,arg); // NOTE: Key is popped in this call
365  }
366 
367  });
368 
369 
370  template<typename Grammar_t,
371  typename key_t,
372  typename input_t=typename Grammar_t::input_t,
373  typename output_t=typename Grammar_t::output_t>
375  assert(vms->program.loader != nullptr);
376 
377  if(vms->recursion_depth++ > vms->MAX_RECURSE) { // there is one of these for each recurse
378  throw VMSRuntimeError();
379  }
380 
381  auto key = vms->template getpop<key_t>();
382  auto x = vms->template getpop<input_t>(); // get the argument
383 
384  auto memindex = std::make_pair(key,x);
385 
386  if(vms->template mem<key_t>().count(memindex)){
387  vms->push(vms->template mem<key_t>()[memindex]); // copy over here
388  }
389  else {
390  vms->xstack.push(x);
391  vms->program.push(Builtins::PopX<Grammar_t>.makeInstruction());
392 
393  vms->template memstack<key_t>().push(memindex); // popped off by op_MEM
394  vms->program.push(Builtins::Mem<Grammar_t,key_t,output_t>.makeInstruction());
395 
396  vms->program.loader->push_program(vms->program,key); // this leaves the answer on top
397  }
398  });
399 
400 
401  template<typename Grammar_t,
402  typename key_t,
403  typename input_t=typename Grammar_t::input_t,
404  typename output_t=typename Grammar_t::output_t>
406  assert(not vms->template stack<input_t>().empty());
407  assert(vms->program.loader != nullptr);
408 
409  // if the size of the top is zero, we return output{}
410  if(vms->template stack<input_t>().top().size() == 0) {
411 
412  vms->template getpop<key_t>();
413  vms->template getpop<input_t>(); // this would have been the argument
414 
415  vms->template push<output_t>(output_t{}); //push default (null) return
416  }
417  else {
418  LexiconMemRecurse<Grammar_t,key_t>.call(vms,arg);
419  }
420 
421  });
422 }
423 
Primitive< output_t, key_t, input_t > LexiconSafeMemRecurse(Op::LexiconSafeMemRecurse, BUILTIN_LAMBDA { assert(not vms->template stack< input_t >().empty());assert(vms->program.loader !=nullptr);if(vms->template stack< input_t >().top().size()==0) { vms->template getpop< key_t >();vms->template getpop< input_t >();vms->template push< output_t >(output_t{});} else { LexiconMemRecurse< Grammar_t, key_t >.call(vms, arg);} })
Definition: VMSRuntimeError.h:13
Primitive< output_t, input_t > MemRecurse(Op::MemRecurse, BUILTIN_LAMBDA { assert(vms->program.loader !=nullptr);using mykey_t=short;if(vms->recursion_depth++> vms->MAX_RECURSE) { throw VMSRuntimeError();} auto x=vms->template getpop< input_t >();auto memindex=std::make_pair(arg, x);if(vms->template mem< mykey_t >().count(memindex)){ vms->push(vms->template mem< mykey_t >()[memindex]);} else { vms->xstack.push(x);vms->program.push(Builtins::PopX< Grammar_t >.makeInstruction());vms->template memstack< mykey_t >().push(memindex);vms->program.push(Builtins::Mem< Grammar_t, mykey_t, output_t >.makeInstruction());vms->program.loader->push_program(vms->program);} })
Primitive< int, int > Sample_int(Op::Sample, BUILTIN_LAMBDA { const int mx=vms->template getpop< int >();const double lp=-log(mx);for(int i=0;i< mx;i++) { bool b=vms->pool->copy_increment_push(vms, i, lp);if(not b) break;} vms->status=vmstatus_t::RANDOM_CHOICE;})
Primitive< output_t, key_t, input_t > LexiconMemRecurse(Op::LexiconMemRecurse, BUILTIN_LAMBDA { assert(vms->program.loader !=nullptr);if(vms->recursion_depth++> vms->MAX_RECURSE) { throw VMSRuntimeError();} auto key=vms->template getpop< key_t >();auto x=vms->template getpop< input_t >();auto memindex=std::make_pair(key, x);if(vms->template mem< key_t >().count(memindex)){ vms->push(vms->template mem< key_t >()[memindex]);} else { vms->xstack.push(x);vms->program.push(Builtins::PopX< Grammar_t >.makeInstruction());vms->template memstack< key_t >().push(memindex);vms->program.push(Builtins::Mem< Grammar_t, key_t, output_t >.makeInstruction());vms->program.loader->push_program(vms->program, key);} })
Primitive< bool, bool, bool > And(Op::And, BUILTIN_LAMBDA { bool b=vms->template getpop< bool >();if(!b) { vms->program.popn(arg);vms->template push< bool >(false);} else { } })
Primitive< bool, bool > Not(Op::Not, BUILTIN_LAMBDA { vms->push(not vms->template getpop< bool >());})
Primitive< output_t, key_t, input_t > LexiconSafeRecurse(Op::LexiconSafeRecurse, BUILTIN_LAMBDA { assert(not vms->template stack< input_t >().empty());assert(vms->program.loader !=nullptr);if(vms->template stack< input_t >().top().size()==0) { vms->template getpop< key_t >();vms->template getpop< input_t >();vms->template push< output_t >(output_t{});} else { LexiconRecurse< Grammar_t, key_t >.call(vms, arg);} })
Definition: Primitive.h:15
Primitive Mem(Op::Mem, BUILTIN_LAMBDA { auto memindex=vms->template memstack< key_t >().top();vms->template memstack< key_t >().pop();if(vms->template mem< key_t >().count(memindex)==0) { vms->template mem< key_t >()[memindex]=vms->template gettop< output_t >();} })
Primitive< T, bool, T, T > If(Op::If, BUILTIN_LAMBDA { bool b=vms->template getpop< bool >();if(!b) vms->program.popn(arg);else {};})
Primitive< t, T > Sample(Op::Sample, BUILTIN_LAMBDA { auto s=vms->template getpop< T >();if(s.size()==1) { auto v=std::move(*s.begin());vms->template push< t >(std::move(v));} else { const double lp=-log(s.size());for(const auto &x :s) { bool b=vms->pool->copy_increment_push(vms, x, lp);if(not b) break;} vms->status=vmstatus_t::RANDOM_CHOICE;} })
Primitive< bool, bool, bool > Or(Op::Or, BUILTIN_LAMBDA { bool b=vms->template getpop< bool >();if(b) { vms->program.popn(arg);vms->template push< bool >(true);} else { } })
Primitive NoOp(Op::NoOp, BUILTIN_LAMBDA { })
Primitive< bool > Flip(Op::Flip, BUILTIN_LAMBDA { assert(vms->pool !=nullptr);vms->pool->copy_increment_push(vms, true, -LOG2);bool b=vms->pool->increment_push(vms, false, -LOG2);if(b) { vms->status=vmstatus_t::RANDOM_CHOICE_NO_DELETE;} else { vms->status=vmstatus_t::RANDOM_CHOICE;} })
void print(FIRST f, ARGS... args)
Lock output_lock and print to std::cout.
Definition: IO.h:53
Primitive< bool, double > SafeFlipP(Op::SafeFlipP, BUILTIN_LAMBDA { assert(vms->pool !=nullptr);double p=vms->template getpop< double >();if(std::isnan(p)) p=0.0;else if(p > 1.0) p=1.0;else if(p< 0.0) p=0.0;vms->pool->copy_increment_push(vms, true, log(p));bool b=vms->pool->increment_push(vms, false, log(1.0-p));if(b) { vms->status=vmstatus_t::RANDOM_CHOICE_NO_DELETE;} else { vms->status=vmstatus_t::RANDOM_CHOICE;} })
const double LOG2
Definition: Numerics.h:16
Primitive Jmp(Op::Jmp, BUILTIN_LAMBDA { vms->program.popn(arg);})
Definition: Builtins.h:11
Primitive< output_t, input_t > SafeMemRecurse(Op::SafeMemRecurse, BUILTIN_LAMBDA { assert(not vms->template stack< input_t >().empty());if(vms->template stack< input_t >().top().size()==0) { vms->template getpop< input_t >();vms->template push< output_t >(output_t{});} else { MemRecurse< Grammar_t >.call(vms, arg);} })
Primitive< typename Grammar_t::input_t > X(Op::X, BUILTIN_LAMBDA { assert(!vms->xstack.empty());vms->template push< typename Grammar_t::input_t >(vms->xstack.top());})
#define BUILTIN_LAMBDA
Definition: Builtins.h:9
Primitive UnusedNoOp(Op::NoOp, BUILTIN_LAMBDA { assert(false);})
Primitive< output_t, input_t > Recurse(Op::Recurse, BUILTIN_LAMBDA { assert(vms->program.loader !=nullptr);if(vms->recursion_depth++> vms->MAX_RECURSE) { throw VMSRuntimeError();} auto mynewx=vms->template getpop< input_t >();vms->xstack.push(std::move(mynewx));vms->program.push(Builtins::PopX< Grammar_t >.makeInstruction());vms->program.loader->push_program(vms->program);})
Primitive< output_t, input_t > SafeRecurse(Op::SafeRecurse, BUILTIN_LAMBDA { assert(not vms->template stack< input_t >().empty());if(vms->template stack< input_t >().top().size()==0) { vms->template getpop< input_t >();vms->template push< output_t >(output_t{});} else { Recurse< Grammar_t >.call(vms, arg);} })
Primitive< bool, double > FlipP(Op::FlipP, BUILTIN_LAMBDA { assert(vms->pool !=nullptr);double p=vms->template getpop< double >();if(std::isnan(p)) { p=0.0;} if(p > 1.0 or p< 0.0) { print("*** Error, received p not in [0,1]:", p);assert(false);} vms->pool->copy_increment_push(vms, true, log(p));bool b=vms->pool->increment_push(vms, false, log(1.0-p));if(b) { vms->status=vmstatus_t::RANDOM_CHOICE_NO_DELETE;} else { vms->status=vmstatus_t::RANDOM_CHOICE;} })
Primitive PopX(Op::PopX, BUILTIN_LAMBDA { vms->xstack.pop();})
Primitive< output_t, key_t, input_t > LexiconRecurse(Op::LexiconRecurse, BUILTIN_LAMBDA { assert(vms->program.loader !=nullptr);if(vms->recursion_depth++> vms->MAX_RECURSE) { throw VMSRuntimeError();} auto key=vms->template getpop< key_t >();auto mynewx=vms->template getpop< input_t >();vms->xstack.push(std::move(mynewx));vms->program.push(Builtins::PopX< Grammar_t >.makeInstruction());vms->program.loader->push_program(vms->program, key);})