Fleet  0.0.9
Inference in the LOT
Proposers.h
Go to the documentation of this file.
1 #pragma once
2 
3 //#define DEBUG_PROPOSE 1
4 
5 #include <optional>
6 #include <utility>
7 #include <tuple>
8 
9 #include "Node.h"
10 
11 namespace Proposals {
12 
18  double can_resample(const Node& n) {
19  return n.can_resample*1.0;
20  }
21 
22  template<typename GrammarType>
23  std::optional<std::pair<Node,double>> prior_proposal(GrammarType* grammar, const Node& from) {
24  auto g = grammar->generate(from.nt());
25  return {g, grammar->log_probability(g) - grammar->log_probability(from)};
26  }
27 
35  template<typename GrammarType>
36  double p_regeneration_propose_to(GrammarType* grammar, const Node& a, const Node& b) {
37 
38  // TODO: Currently does not take into account can_resample
39  // TODO: FIX THAT PLEASE
40 
41  // what's the probability of replacing the root of a and generating b?
42  double alp = -log(a.count()); // probability of choosing any given node o in a
43  double wholetree = alp + grammar->log_probability(b);
44 
45  if(a.rule != b.rule) {
46  // I must regenerate this whole tree
47  return wholetree;
48  }
49  else {
50  assert(a.nchildren() == b.nchildren()); // not handling missing kids
51 
52  size_t ndiff = 0; // how many children are different?
53  int who = 0; // if there is exactly one difference, who is it?
54  for(size_t i = 0;i<a.nchildren();i++) {
55  if(a.child(i) != b.child(i)) {
56  who = i;
57  ndiff++;
58  }
59  }
60 
61  if(ndiff == 0) {
62  // if all the kids are the same, we could propose to any of them...
63  double lp = wholetree; // could propose to whole tree
64  for(size_t i=0;i<a.nchildren();i++) {
65  lp = logplusexp(lp, alp + p_regeneration_propose_to(grammar, a.child(i), b.child(i)));
66  }
67  return lp;
68  }
69  else if(ndiff == 1) {
70  // we have to propose to who or root
71  return logplusexp(wholetree,
72  alp + p_regeneration_propose_to(grammar, a.child(who), b.child(who)));
73  }
74  else {
75  // more than one difference means we have to propose to the root
76  return wholetree;
77  }
78 
79  }
80  }
81 
82 // /**
83 // * @brief A little helper function that resamples everything below when we can. If we can't, then we'll recurse
84 // * @param grammar
85 // * @param from
86 // */
87 // template<typename GrammarType>
88 // void __regenerate_when_can_resample(const GrammarType* grammar, Node& from) {
89 // if(from.can_resample) {
90 // from.assign(grammar->generate(from.nt()));
91 // }
92 // else { // we can't regenerate so we have to recurse
93 // for(auto& c : from.get_children()) {
94 // __regenerate_when_can_resample(grammar, c);
95 // }
96 // }
97 //
98 // }
99 
106  template<typename GrammarType>
107  std::optional<std::pair<Node,double>> regenerate(GrammarType* grammar, const Node& from) {
108 
109  // copy, regenerate a random node, and return that and forward-backward prob
110 
111  Node ret = from; // copy
112 
113  if(from.sum<double>(can_resample) == 0.0)
114  return {};
115 
116  // 9/2023 -- No ide what this note below was, or what that edit was.
117  // We are changing this as of Feb 2022, we now can *choose* a can_resample Node, but
118  // we won't actually change any unless can_resample is true
119  // BUT if we DO change a can_resample node, we change everything below it.
120  // I changed to sample only those we can resample:
121  auto [s, slp] = sample<Node,Node>(ret, +[](const Node& n) { return 1.0*n.can_resample;});
122 
123 // #ifdef DEBUG_PROPOSE
124 // DEBUG("REGENERATE", from, *s, s->can_resample);
125 // #endif
126 
127  double oldgp = grammar->log_probability(*s); // reverse probability generating
128 
129  s->assign(grammar->generate(s->nt()));
130  // also removed 9/2023:
131 // __regenerate_when_can_resample(grammar,*s);
132 
133  //double fb = slp + grammar->log_probability(*s)
134  // - (log(can_resample(*s)) - log(ret.sum(can_resample)) + oldgp);
135  double fb = (-log(from.count()) + grammar->log_probability(*s)) -
136  (-log(ret.count()) + oldgp);
137 
138  #ifdef DEBUG_PROPOSE
139  CERR "FORWARD" TAB slp + grammar->log_probability(*s) ENDL;
140  CERR "BACKWARD" TAB log(can_resample(*s)) - log(ret.sum(can_resample)) + oldgp ENDL;
141  CERR "PROPOSING" TAB ret ENDL;
142  CERR "----------------" ENDL;
143 
144  #endif
145 
146  return std::make_pair(ret, fb);
147  }
148 
149 
150 
151 
161  template<typename GrammarType, int D>
162  std::optional<std::pair<Node,double>> regenerate_shallow(GrammarType* grammar, const Node& from) {
163 
164  auto my_can_resample = +[](const Node& n) {
165  return (n.can_resample and n.depth() <= D )*1.0;
166  };
167 
168 // #ifdef DEBUG_PROPOSE
169 // CERR "REGENERATE_SHALLOW" TAB from.string() ENDL;
170 // #endif
171 
172  Node ret = from; // copy
173 
174  if(from.sum<double>(my_can_resample) == 0.0)
175  return {};
176 
177  auto [s, slp] = sample<Node,Node>(ret, my_can_resample);
178 
179  double oldgp = grammar->log_probability(*s); // reverse probability generating
180 
181  s->assign(grammar->generate(s.first->nt()));
182 
183  double fb = slp + grammar->log_probability(*s.first)
184  - (log(my_can_resample(*s.first)) - log(ret.sum(my_can_resample)) + oldgp);
185 
186  return std::make_pair(ret, fb);
187  }
188 
189 
190 
191  template<typename GrammarType>
192  std::optional<std::pair<Node, double>> insert_tree(GrammarType* grammar, const Node& from) {
193  // This proposal selects a node, regenerates, and then copies what was there before somewhere below
194  // in the replaced tree. NOTE: it must regenerate something with the right nonterminal
195  // since that's what's being replaced!
196 
197  Node ret = from; // copy
198 
199  if(ret.sum<double>(can_resample) == 0.0)
200  return {};
201 
202  // So:
203  // we pick node s to be replaced.
204  // we create t to replace it with
205  // then somewhere below t, we choose something of type s.nt(), called q, to put s
206 
207  auto [s, slp] = sample<Node,Node>(ret, can_resample); // s is a ptr into ret
208 // print("Choosing s=", s->string());
209 
210  #ifdef DEBUG_PROPOSE
211  DEBUG("INSERT-TREE", from, *s);
212  #endif
213 
214 
215  Node old_s = *s; // the old value of s, copied -- needed for fb and for replacement
216 
217  Node* captured_s = s; // clang doesn't like taking s for some reason
218  std::function can_resample_matches_s_nt = [captured_s](const Node& n) -> double {
219  return can_resample(n)*(n.nt() == captured_s->nt());
220  };
221 
222  // make something of the same type as s that we can
223  // put s into as a subtree below
224  Node t = grammar->generate(s->nt());
225 // print("Generated t=", t.string());
226  s->assign(t);// copy, not move, since we need it below
227 // s->fullprint();
228 // checkNode(grammar, *s);
229 
230  // q is the thing in t that s will replace
231  // this is sampeld from t, but we do it from s after assignment
232  // since that was a bug before ught
233  auto [q, qlp] = sample<Node,Node>(*s, can_resample_matches_s_nt);
234 // print("Choosing q=", q->string());
235 
236  // and then we assign the subtree, q, to be the original s
237  q->assign(old_s);
238 // print("Ret after Q assignment=", ret.string());
239 // s->fullprint();
240 
241 // checkNode(grammar, *s);
242 
243  // now if we replace something below t with s, there are multiples ones we could have done...
244  auto lpq = lp_sample_eq(*q, *s, can_resample_matches_s_nt); //
245 
246 // CERR "----INSERT-----------" ENDL;
247 // CERR from ENDL;
248 // CERR *s ENDL;
249 // CERR t ENDL;
250 // CERR s.second TAB grammar->log_probability(t) TAB grammar->log_probability(old_s) TAB lpq ENDL;
251 
252  // forward is choosing s, generating everything *except* what replaced s, and then replacing
253  double forward = slp + // must get exactly this s
254  (grammar->log_probability(t)-grammar->log_probability(old_s)) + // generate the rest of the tree
255  lpq; // probability of getting any s
256 
258  double backward = lp_sample_one<Node,Node>(t, ret, can_resample) +
259  lp_sample_eq<Node,Node>(old_s, *s, can_resample_matches_s_nt);
260 // print("RETURNINGI", ret);
261 
262 // if(std::isinf(forward)) {
263 // print(slp, lpq, grammar->log_probability(t), grammar->log_probability(old_s) );
264 // print(s->string(), t.string());
265 //
266 // }
267 
268  assert(not std::isinf(forward));
269  assert(not std::isinf(backward));
270 
271  return std::make_pair(ret, forward-backward);
272  }
273 
274  template<typename GrammarType>
275  std::optional<std::pair<Node, double>> delete_tree(GrammarType* grammar, const Node& from) {
276  // This proposal selects a node, regenerates, and then copies what was there before somewhere below
277  // in the replaced tree. NOTE: it must regenerate something with the right nonterminal
278  // since that's what's being replaced
279 
280  Node ret = from; // copy
281 
282  if(ret.sum(can_resample) == 0.0)
283  return {};
284 
285  // s is who we edit at
286  auto [s, slp] = sample<Node,Node>(ret, can_resample); // s is a ptr to ret
287  Node old_s = *s; // the old value of s, copied -- needed for fb
288 
289  #ifdef DEBUG_PROPOSE
290  DEBUG("DELETE-TREE", from, *s);
291  #endif
292 
293  Node* captured_s = s; // clang doesn't like capturing s for some reason
294  std::function can_resample_matches_s_nt = [&](const Node& n) -> double {
295  return can_resample(n)*(n.nt() == captured_s->nt());
296  };
297 
298  // q is who we promote here
299  auto [q, qlp] = sample(*s, can_resample_matches_s_nt);
300  Node newq = *q; // must make a copy since q gets deleted in assignment... TODO: Can clean up with std::move?
301 
302  // forward is choosing s, and then anything equal to q within
303  double forward = slp + lp_sample_eq<Node,Node>(*q, *s, can_resample_matches_s_nt);
304 
305  // probability of generating everything in s except q
306  double tlp = grammar->log_probability(old_s) - grammar->log_probability(*q);
307 
308  // promote q here
309  s->assign(newq);
310 
312  double backward = lp_sample_one<Node,Node>(*s,ret,can_resample) +
313  tlp +
314  lp_sample_eq<Node,Node>(newq,old_s,can_resample_matches_s_nt);
315 
316  assert(not std::isinf(forward));
317  assert(not std::isinf(backward));
318 
319  return std::make_pair(ret, forward-backward);
320  }
321 
322 
330  template<typename GrammarType>
331  std::optional<std::pair<Node,double>> sample_function_leaving_args(GrammarType* grammar, const Node& from) {
332 
333  // We add a restriction here that it must be a function (not a leaf)
334  // since the leaves get well-done by regenration
335  std::function allowed = +[](const Node& n) {
336  return can_resample(n) and n.nchildren() > 0;
337  };
338 
339  Node ret = from; // copy
340 
341  auto z = sample_z<Node,Node>(ret, allowed);
342  if(z == 0.0) return {};
343 
344  auto [s, slp] = sample<Node,Node>(ret, z, allowed);
345 
346  #ifdef DEBUG_PROPOSE
347  print("SAMPLE-LEAVING-ARGS", from, *s);
348  #endif
349 
350  // find everything in the grammar that matches s's type
351  std::vector<Rule*> matching_rules;
352  for(auto& r: grammar->rules[s->rule->nt]) {
353  if(r.child_types == s->rule->child_types) {
354  matching_rules.push_back(&r);
355 // print("Matching ", r, *s->rule);
356  }
357  }
358  assert(matching_rules.size() >= 1); // we had to have matchd one...
359  if(matching_rules.size() == 1) { // don't do this proposal if there is only one rule
360  return {};
361  }
362 
363  // now sample from one
364  std::function sampler = +[](const Rule* r) -> double { return r->p; };
365  auto [newr, __rp] = sample<Rule*>(matching_rules, sampler); // sample according to p
366  assert(newr != nullptr and *newr != nullptr);
367  assert( (*newr)->nt == s->rule->nt);
368  // NOTE: confusingly, newr is a Rule** so it must be deferenced to get a poitner
369  const Rule* oldRule = s->rule;
370 
371  // set the rule and its probability -- save us the copying
372  s->rule = *newr;
373  s->lp = (*newr)->p / grammar->Z[s->rule->nt];
374 
375  // here we compute fb while ignoring the normalizing constants which is why we don't use rp
376  double fb = log(sampler(*newr))-log(sampler(oldRule));
377 
378  return std::make_pair(ret,fb); // forward-backward is just probability of sampling rp since s cancels
379  }
380 
381 
388  template<typename GrammarType>
389  std::optional<std::pair<Node,double>> swap_args(GrammarType* grammar, const Node& from) {
390 
391  Node ret = from; // copy
392 
393  auto z = sample_z<Node,Node>(ret, can_resample);
394  if(z == 0.0) return {};
395 
396  auto [s, slp] = sample<Node,Node>(ret, z, can_resample);
397  if(s->nchildren() <= 1) {
398  return {};
399  }
400 
401  auto N = s->nchildren();
402 
403  // go through and find children of the same type
404  // this is a bit inefficient but the ns here are very small typically
405  std::vector<int> possible_indices; // for swapping
406  for(size_t i=0;i<N;i++) {
407  for(size_t j=i+1;j<N;j++) {
408  if(s->child(j).rule->nt == s->child(i).rule->nt) { // if there is something else with the same type, we can swap
409  possible_indices.push_back(i);
410  break;
411  }
412  }
413  }
414  if(possible_indices.size() == 0) return {};
415 
416 
417 // #ifdef DEBUG_PROPOSE
418 // DEBUG("SWAP-ARGS", from, *s);
419 // #endif
420 
421  // Sample one we can swap
422  auto x = possible_indices.at(sample_int(possible_indices.size()).first);
423  auto y = sample_int(N, [&](const int v) { return 1*(s->child(v).rule->nt == s->child(x).rule->nt and v != x); }).first;
424 
425  Node tmp = s->child(x); // copy (unfortunately, though maybe it won't be necessary if we use std::swap)
426  s->set_child(x, std::move(s->child(y)));
427  s->set_child(y, std::move(tmp));
428 // print("SWAPPED:", ret);
429 
430  return std::make_pair(ret,0.0);
431  }
432 
433 
434 
435 }
MyGrammar grammar
std::optional< std::pair< Node, double > > sample_function_leaving_args(GrammarType *grammar, const Node &from)
This samples functions f(a,b) -> g(a,b) (e.g. without destroying what's below). This uses a little tr...
Definition: Proposers.h:331
Definition: Node.h:22
double p_regeneration_propose_to(GrammarType *grammar, const Node &a, const Node &b)
Probability of proposing from a to b under regeneration.
Definition: Proposers.h:36
#define TAB
Definition: IO.h:19
std::pair< t *, double > sample(const T &s, double z, const std::function< double(const t &)> &f=[](const t &v){return 1.0;})
Definition: Random.h:258
Definition: Rule.h:21
std::pair< int, double > sample_int(unsigned int max, const std::function< double(const int)> &f=[](const int v){return 1.0;})
Definition: Random.h:330
T sum(std::function< T(const this_t &)> &f) const
Definition: BaseNode.h:425
std::optional< std::pair< Node, double > > prior_proposal(GrammarType *grammar, const Node &from)
Definition: Proposers.h:23
void print(FIRST f, ARGS... args)
Lock output_lock and print to std::cout.
Definition: IO.h:53
std::optional< std::pair< Node, double > > delete_tree(GrammarType *grammar, const Node &from)
Definition: Proposers.h:275
T logplusexp(const T a, const T b)
Definition: Numerics.h:131
double lp_sample_eq(const t &x, const T &s, std::function< double(const t &)> &f=[](const t &v){return 1.0;})
Definition: Random.h:431
Definition: Proposers.h:11
A Node is the primary internal representation for a program – it recursively stores a rule and the a...
std::optional< std::pair< Node, double > > regenerate_shallow(GrammarType *grammar, const Node &from)
Regenerate with rational-rules style proposals, but only allow proposals to trees with a max depth of...
Definition: Proposers.h:162
void assign(Node &n)
Assign will set everything to n BUT it will not copy the parent pointer etc since we're assuming this...
Definition: Node.h:59
#define CERR
Definition: IO.h:23
nonterminal_t nt() const
Definition: Node.h:156
this_t & child(const size_t i)
Definition: BaseNode.h:175
const Rule * rule
Definition: Node.h:32
std::optional< std::pair< Node, double > > regenerate(GrammarType *grammar, const Node &from)
A little helper function that resamples everything below when we can. If we can't, then we'll recurse.
Definition: Proposers.h:107
void set_child(const size_t i, Node &n)
Definition: Node.h:88
double p
Definition: Rule.h:30
#define ENDL
Definition: IO.h:21
std::optional< std::pair< Node, double > > insert_tree(GrammarType *grammar, const Node &from)
Definition: Proposers.h:192
void DEBUG(FIRST f, ARGS... args)
Print to std::cout with debugging info.
Definition: IO.h:73
bool can_resample
Definition: Node.h:34
std::optional< std::pair< Node, double > > swap_args(GrammarType *grammar, const Node &from)
This proposal swaps around arguments of the same type.
Definition: Proposers.h:389
virtual size_t count() const
How many nodes total are below me?
Definition: BaseNode.h:358
double can_resample(const Node &n)
Helper function for whether we can resample from a node (just accesses n.can_resample) ...
Definition: Proposers.h:18
size_t nchildren() const
Definition: BaseNode.h:208
double D
Definition: Main.cpp:15