Fleet  0.0.9
Inference in the LOT
LOTHypothesis.h
Go to the documentation of this file.
1 #pragma once
2 
3 #include <string.h>
4 #include "Proposers.h"
5 #include "Program.h"
6 #include "Node.h"
7 #include "DiscreteDistribution.h"
8 #include "Errors.h"
9 
14 
31 template<typename this_t, // NOTE: IF YOU CHANGE THESE, CHANGE IN DeterministicLOTHypothesis and StochasticLOTHypothesis
32  typename _input_t,
33  typename _output_t,
34  typename _Grammar_t,
35  _Grammar_t* grammar,
36  typename _datum_t=defaultdatum_t<_input_t, _output_t>,
37  typename _data_t=std::vector<_datum_t>,
38  typename _VirtualMachineState_t=typename _Grammar_t::VirtualMachineState_t
39  >
40 class LOTHypothesis : public MCMCable<this_t,_datum_t,_data_t>, // remember, this defines data_t, datum_t
41  public Searchable<this_t,_input_t,_output_t>,
42  public Serializable<this_t>,
43  public ProgramLoader<_VirtualMachineState_t>
44 {
45 public:
46 
49  using Grammar_t = _Grammar_t;
50  using input_t = _input_t;
51  using output_t = _output_t;
52  using VirtualMachineState_t = _VirtualMachineState_t;
53 
54  // proposals return this:
55  using ProposalType = std::optional<std::pair<this_t,double>>;
56 
57  // this splits the prior,likelihood,posterior,and value
58  static const char SerializationDelimiter = '\t';
59 
60  static const size_t MAX_NODES = 64; // max number of nodes we allow; otherwise -inf prior
61 
62  // store the the total number of instructions on the last call
63  // (summed up for stochastic, else just one for deterministic)
65  unsigned long total_vms_steps;
66 
67  // A Callable stores its program
69 
70 protected:
71 
73 
74 public:
75 
76  LOTHypothesis() : MCMCable<this_t,datum_t,data_t>(), value(NullRule,0.0,true) {
77  }
78 
79  LOTHypothesis(Node& x) : MCMCable<this_t,datum_t,data_t>() {
80  set_value(x);
81  }
82 
83  LOTHypothesis(Node&& x) : MCMCable<this_t,datum_t,data_t>() {
84  set_value(x);
85  }
86 
87  // parse this from a string
88  LOTHypothesis(std::string s) : MCMCable<this_t,datum_t,data_t>() {
90  }
91 
93  this->operator=(c); // copy all this garbage -- not sure what to do here
94  }
95 
97  this->operator=(std::move(c));
98  }
99 
102 
103  total_instruction_count_last_call = c.total_instruction_count_last_call;
104  total_vms_steps = c.total_vms_steps;
105  program = c.program;
106  value = c.value; // no setting -- just copy the program
107  if(c.program.loader == &c)
108  program.loader = this; // something special here -- if c's loader was itself, make this myself; else leave it
109  return *this;
110  }
111 
114 
115  total_instruction_count_last_call = c.total_instruction_count_last_call;
116  total_vms_steps = c.total_vms_steps;
117  program = c.program; //std::move(c.program);
118  value = c.value; // std::move(c.value); // no setting -- just copy the program
119  if(c.program.loader == &c)
120  program.loader = this; // something special here -- if c's loader was itself, make this myself; else leave it
121  return *this;
122  }
123 
128  [[nodiscard]] virtual ProposalType propose() const override {
129 
130  // simplest way of doing proposals
131  auto p = Proposals::regenerate(grammar, value);
132  if(not p) return {}; // if failed
133 
134  auto x = p.value();
135 
136  // return a pair of Hypothesis and forward-backward probabilities
137  return std::make_pair(this_t(std::move(x.first)), x.second); // return this_t and fb
138  }
139 
140  [[nodiscard]] static this_t sample() {
141  return this_t(grammar->generate());
142  }
143 
148  [[nodiscard]] virtual this_t restart() const override {
149  // This is used in MCMC to restart chains
150  // this ordinarily would be a resample from the grammar, but sometimes we have can_resample=false
151  // and in that case we want to leave the non-propose nodes alone.
152 
153  if(not value.is_null()) { // if we are null
154  return this_t(grammar->copy_resample(value, [](const Node& n) { return n.can_resample; }));
155  }
156  else {
157  return this_t::sample(); //this_t(grammar->generate());
158  }
159  }
160 
161  Node& get_value() { return value; }
162  const Node& get_value() const { return value; }
163 
168  void set_value(Node& v, bool should_compile=true) {
169  value = v;
170  if(should_compile)
171  this->compile(); // compile with myself defaultly as a loader
172  }
173  void set_value(Node&& v, bool should_compile=true) {
174  value = v;
175  if(should_compile)
176  this->compile();
177  }
178 
179  Grammar_t* get_grammar() const { return grammar; }
180 
185  virtual double compute_prior() override {
186  /* This ends up being a really important check -- otherwise we spend tons of time on really long
187  * this_totheses */
188  if(this->value.count() > MAX_NODES) {
189  return this->prior = -infinity;
190  }
191 
192  return this->prior = grammar->log_probability(value);
193  }
194 
195  virtual double compute_single_likelihood(const datum_t& datum) override {
196  // compute the likelihood of a *single* data point.
197  throw NotImplementedError("*** You must define compute_single_likelihood");// for base classes to implement, but don't set = 0 since then we can't create Hypothesis classes.
198  }
199 
200  void compile() {
201  this->program.clear();
202  value.template linearize<VirtualMachineState_t, Grammar_t>(this->program);
203  this->program.loader = this; // program loader defaults to myself
204  }
205 
211  virtual void push_program(Program<VirtualMachineState_t>& s) override {
212  this->was_called = true; // by definition we should be setting this if we're a program loader
213 
214  for(auto it = this->program.begin(); it != this->program.end(); it++) {
215  s.push(*it);
216  }
217  }
218 
219  virtual std::string string(std::string prefix="") const override {
220  return this->string(prefix,true);
221  }
222  virtual std::string string(std::string prefix, bool usedot) const {
223  return prefix + std::string("\u03BBx.") + value.string(usedot);
224  }
225 
226  static this_t from_string(Grammar_t* g, std::string s) {
227  return this_t(g, g->from_parseable(s));
228  }
229 
230  virtual size_t hash() const override {
231  return value.hash();
232  }
233 
239  virtual bool operator==(const this_t& h) const override {
240  return this->value == h.value;
241  }
242 
246  virtual void complete() override {
247  if(value.is_null()) {
248  auto nt = grammar->template nt<output_t>();
249  set_value(grammar->generate(nt));
250  }
251  else {
252  grammar->complete(value);
253  set_value(value); // this will compile it
254  }
255  }
256 
257  /********************************************************
258  * Implementation of Searchable interace
259  ********************************************************/
260  // The main complication with these is that they handle nullptr
261 
262  virtual int neighbors() const override {
263  if(value.is_null()) { // if the value is null, our neighbors is the number of ways we can do nt
264  auto nt = grammar->template nt<output_t>();
265  return grammar->count_rules(nt);
266  }
267  else {
268  return grammar->neighbors(value);
269 // to rein in the mcts branching factor, we'll count neighbors as just the first unfilled gap
270 // we should not need to change make_neighbor since it fills in the first, first
271  }
272  }
273 
274  virtual void expand_to_neighbor(int k) override {
275  //assert(grammar != nullptr);
276 
277  if(value.is_null()){
278  auto nt = grammar->template nt<output_t>();
279  auto r = grammar->get_rule(nt,(size_t)k);
280  value = grammar->makeNode(r);
281  }
282  else {
283  grammar->expand_to_neighbor(value, k); // NOTE that if we do this, we have to compile still...
284  }
285 
286  if(value.is_complete()) compile(); // this seems slow/wasteful, but I'm not sure of the alternative.
287  }
288 
289 
290  virtual double neighbor_prior(int k) override {
291  // What is the prior for this neighbor?
292  if(value.is_null()){
293  auto nt = grammar->template nt<output_t>();
294  auto r = grammar->get_rule(nt,(size_t)k);
295  return log(r->p)-log(grammar->rule_normalizer(r->nt));
296  }
297  else {
298  return grammar->neighbor_prior(value, k);
299  }
300  }
301 
306  virtual bool is_evaluable() const override {
307  // This checks whether it should be allowed to call "call" on this Hypothesis.
308  // Usually this means that that the value is complete, meaning no partial subtrees
309  return value.is_complete();
310  }
311 
316  size_t recursion_count() {
317  size_t cnt = 0;
318  for(auto& n : value) {
319  cnt += n.rule->is_recursive();
320  }
321  return cnt;
322  }
323 
328  virtual std::string serialize() const override {
329  // NOTE: This doesn't preseve everything (it doesn't save can_propose, for example)
330  return str(this->prior) + SerializationDelimiter +
331  str(this->likelihood) + SerializationDelimiter +
332  str(this->posterior) + SerializationDelimiter +
333  value.parseable();
334  }
335 
341  static this_t deserialize(const std::string& s) {
342  auto [pr, li, po, v] = split<4>(s, SerializationDelimiter);
343 
344  auto h = this_t(grammar->from_parseable(v));
345 
346  // restore the bayes stats
347  h.prior = string_to<double>(pr);
348  h.likelihood = string_to<double>(li);
349  h.posterior = string_to<double>(po);
350 
351  return h;
352  }
353 
354 
355 };
LOTHypothesis(const LOTHypothesis &&c)
Definition: LOTHypothesis.h:96
virtual std::string string(std::string prefix="") const override
Definition: LOTHypothesis.h:219
Node & get_value()
Definition: LOTHypothesis.h:161
static this_t deserialize(const std::string &s)
Convert this from a string which was in a file.
Definition: LOTHypothesis.h:341
virtual this_t restart() const override
This is used to restart chains, sampling from prior but ONLY for nodes that are can_resample.
Definition: LOTHypothesis.h:148
MyGrammar grammar
virtual int neighbors() const override
Count the number of neighbors that are possible. (This should be size_t but int is more convenient...
Definition: LOTHypothesis.h:262
double neighbor_prior(const Node &node, int &which) const
Definition: Grammar.h:958
LOTHypothesis(const LOTHypothesis &c)
Definition: LOTHypothesis.h:92
Node generate(const nonterminal_t ntfrom=nt< output_t >(), unsigned long depth=0) const
A wrapper to catch DepthExceptions and retry. This means that by default we try to generate GENERATE_D...
Definition: Grammar.h:693
void push(const Instruction &val)
Push val onto the stack.
Definition: Stack.h:49
virtual std::string string(std::string prefix, bool usedot) const
Definition: LOTHypothesis.h:222
virtual void complete() override
Modify this hypothesis's value by (randomly) filling in all the gaps.
Definition: LOTHypothesis.h:246
double likelihood
Definition: Bayesable.h:43
Definition: Node.h:22
void clear()
Definition: Stack.h:41
virtual size_t hash() const override
Definition: LOTHypothesis.h:230
size_t recursion_count()
Count up how many times I use recursion – we keep a list of recursion here.
Definition: LOTHypothesis.h:316
LOTHypothesis()
Definition: LOTHypothesis.h:76
void set_value(Node &v, bool should_compile=true)
Set the value to v. (NOTE: This compiles into a program)
Definition: LOTHypothesis.h:168
virtual std::string string(bool usedot=true) const override
Definition: Node.h:205
Node copy_resample(const Node &node, bool f(const Node &n)) const
Definition: Grammar.h:702
Definition: Program.h:17
Definition: MCMCable.h:14
LOTHypothesis & operator=(const LOTHypothesis &&c)
Definition: LOTHypothesis.h:112
void compile()
Definition: LOTHypothesis.h:200
std::pair< t *, double > sample(const T &s, double z, const std::function< double(const t &)> &f=[](const t &v){return 1.0;})
Definition: Random.h:258
virtual std::string parseable() const
Definition: Node.h:269
virtual std::string serialize() const override
Convert this into a string which can be written to a file.
Definition: LOTHypothesis.h:328
Node value
Definition: LOTHypothesis.h:72
virtual void push_program(Program< VirtualMachineState_t > &s) override
This puts the code from my node onto s. Used internally in e.g. recursion.
Definition: LOTHypothesis.h:211
virtual bool is_evaluable() const override
A node is "evaluable" if it is complete (meaning no null subnodes)
Definition: LOTHypothesis.h:306
A class is searchable if it permits us to enumerate and make its neighbors. This class is used by MCTS a...
Definition: Serializable.h:4
Program< VirtualMachineState_t > program
Definition: LOTHypothesis.h:68
double prior
Definition: Bayesable.h:42
Node makeNode(const Rule *r) const
Definition: Grammar.h:633
auto begin()
These are for iterating through the underlying vector.
Definition: Stack.h:138
virtual bool operator==(const this_t &h) const override
Equality is checked on equality of values; note that greater-than is still on posteriors.
Definition: LOTHypothesis.h:239
Node from_parseable(std::deque< std::string > &q) const
Definition: Grammar.h:871
const Rule * NullRule
Definition: Rule.h:186
std::string str(BindingTree *t)
Definition: BindingTree.h:195
double posterior
Definition: Bayesable.h:44
virtual bool is_complete() const
Definition: Node.h:188
ProgramLoader< VirtualMachineState_t > * loader
Definition: Program.h:48
bool is_null() const
Definition: Node.h:165
Definition: Datum.h:15
unsigned long total_vms_steps
Definition: LOTHypothesis.h:65
constexpr double infinity
Definition: Numerics.h:20
A Node is the primary internal representation for a program – it recursively stores a rule and the a...
Bayesable< _datum_t, _data_t >::datum_t datum_t
Definition: LOTHypothesis.h:47
static this_t sample()
Definition: LOTHypothesis.h:140
void expand_to_neighbor(Node &node, int &which)
Definition: Grammar.h:937
size_t neighbors(const Node &node) const
Definition: Grammar.h:923
Grammar_t * get_grammar() const
Definition: LOTHypothesis.h:179
LOTHypothesis(Node &x)
Definition: LOTHypothesis.h:79
std::optional< std::pair< Node, double > > regenerate(GrammarType *grammar, const Node &from)
A little helper function that resamples everything below when we can. If we can't, then we'll recurse.
Definition: Proposers.h:107
_datum_t datum_t
Definition: Bayesable.h:38
static const char SerializationDelimiter
Definition: LOTHypothesis.h:58
double rule_normalizer(const nonterminal_t nt) const
Definition: Grammar.h:601
double log_probability(const Node &n) const
This computes the expected length of productions from this grammar, counting terminals and nontermina...
Definition: Grammar.h:849
const Node & get_value() const
Definition: LOTHypothesis.h:162
Definition: Errors.h:7
LOTHypothesis & operator=(const LOTHypothesis &c)
Definition: LOTHypothesis.h:100
void set_value(Node &&v, bool should_compile=true)
Definition: LOTHypothesis.h:173
The Bayesable class provides an interface for hypotheses that support Bayesian inference (e...
Definition: LOTHypothesis.h:40
Definition: Searchable.h:13
unsigned long total_instruction_count_last_call
Definition: LOTHypothesis.h:64
virtual void expand_to_neighbor(int k) override
Modify this hypothesis to become the k'th neighbor. NOTE This does not compile since it might not be ...
Definition: LOTHypothesis.h:274
Definition: MyGrammar.h:72
virtual double compute_prior() override
Compute the prior – defaultly just the PCFG (grammar) prior.
Definition: LOTHypothesis.h:185
This stores a distribution from values of T to log probabilities. It is used as the return value from...
static this_t from_string(Grammar_t *g, std::string s)
Definition: LOTHypothesis.h:226
size_t count_rules(const nonterminal_t nt) const
Definition: Grammar.h:189
_data_t data_t
Definition: Bayesable.h:39
LOTHypothesis(std::string s)
Definition: LOTHypothesis.h:88
auto end()
Definition: Stack.h:139
static const size_t MAX_NODES
Definition: LOTHypothesis.h:60
Bayesable< _datum_t, _data_t >::data_t data_t
Definition: LOTHypothesis.h:48
Definition: Program.h:6
virtual ProposalType propose() const override
Default proposal is rational-rules style regeneration.
Definition: LOTHypothesis.h:128
void complete(Node &node)
Definition: Grammar.h:985
bool was_called
Definition: Program.h:23
A program here stores just a stack of instructions which can be executed by the VirtualMachineState_t...
virtual size_t count() const
How many nodes total are below me?
Definition: BaseNode.h:358
virtual double neighbor_prior(int k) override
What is the prior of the k'th neighbor? This does not need to return the full prior, only relative (among ks)
Definition: LOTHypothesis.h:290
LOTHypothesis(Node &&x)
Definition: LOTHypothesis.h:83
virtual Rule * get_rule(const nonterminal_t nt, size_t k) const
Definition: Grammar.h:489
virtual double compute_single_likelihood(const datum_t &datum) override
Definition: LOTHypothesis.h:195
A class is MCMCable if it is Bayesable and lets us propose, restart, and check equality (which MCMC d...
virtual size_t hash(size_t depth=0) const
Definition: Node.h:391