Fleet  0.0.9
Inference in the LOT
Grammar.h
Go to the documentation of this file.
1 #pragma once
2 
3 #include <tuple>
4 #include <array>
5 #include <exception>
6 
7 #include "IO.h"
8 #include "Errors.h"
9 #include "Node.h"
10 #include "Random.h"
11 #include "Nonterminal.h"
12 #include "VirtualMachineState.h"
13 #include "VirtualMachinePool.h"
14 #include "Builtins.h"
15 #include "Functional.h"
16 
17 // an exception for recursing too deep
18 struct DepthException : public std::exception {
21  }
22 };
23 
43 template<typename _input_t, typename _output_t, typename... GRAMMAR_TYPES>
44 class Grammar {
45 public:
46 
47  using input_t = _input_t;
48  using output_t = _output_t;
49  using this_t = Grammar<input_t, output_t, GRAMMAR_TYPES...>;
50 
51  // Keep track of what types we are using here as our types -- thesee types are
52  // stored in this tuple so they can be extracted
53  using TypeTuple = std::tuple<GRAMMAR_TYPES...>;
54 
55  // how many nonterminal types do we have?
56  static constexpr size_t N_NTs = std::tuple_size<TypeTuple>::value;
57 
58  // The input/output types must be repeated to VirtualMachineState
59  using VirtualMachineState_t = VirtualMachineState<input_t, output_t, GRAMMAR_TYPES...>;
60 
61  // This is the function type
62  using FT = typename VirtualMachineState_t::FT;
63 
64  // How many times will we silently ignore a DepthException
65  // before tossing an assert error
66  static const size_t GENERATE_DEPTH_EXCEPTION_RETRIES = 1000;
67 
68  // get the n'th type
69  //template<size_t N>
70  //using type = typename std::tuple_element<N, TypeTuple>::type;
71 
72  // rules[k] stores a SORTED vector of rules for the kth' nonterminal.
73  // our iteration order is first for k = 0 ... N_NTs then for r in rules[k]
74  std::vector<Rule> rules[N_NTs];
75  std::array<double,N_NTs> Z; // keep the normalizer handy for each nonterminal (not log space)
76 
77  size_t GRAMMAR_MAX_DEPTH = 64;
78 
79  // This function converts a type (passed as a template parameter) into a
80  // size_t index for which one it in in GRAMMAR_TYPES.
81  // This is used so that a Rule doesn't need type subclasses/templates, it can
82  // store a type as e.g. nt<double>() -> size_t
83  template <class T>
84  static constexpr nonterminal_t nt() {
85  static_assert(sizeof...(GRAMMAR_TYPES) > 0, "*** Cannot use empty grammar types here");
86  static_assert(contains_type<T, GRAMMAR_TYPES...>(), "*** The type T (decayed) must be in GRAMMAR_TYPES");
87  return (nonterminal_t)TypeIndex<T, std::tuple<GRAMMAR_TYPES...>>::value;
88  }
89 
90  Grammar() {
91  for(size_t i=0;i<N_NTs;i++) {
92  Z[i] = 0.0;
93  }
94  }
95 
96  // should not be doing these
97  Grammar(const Grammar& g) = delete;
98  Grammar(const Grammar&& g) = delete;
99 
108  class RuleIterator {
109 
110  // these are require din here for this to be an iterator
111  using iterator_category = std::forward_iterator_tag;
112  using value_type = Rule;
113  using difference_type = int;
114  using pointer = Rule;
115  using reference = Rule&;
116 
117 
118  protected:
121  std::vector<Rule>::iterator current_rule;
122 
123  public:
124 
125  RuleIterator(this_t* g, bool is_end) : grammar(g), current_nt(0) {
126  if(not is_end) {
127  current_rule = g->rules[0].begin();
128  }
129  else {
130  // by convention we set current_rule and current_nt to the last items
131  // since this is what ++ will leave them as below
132  current_nt = N_NTs-1;
133  current_rule = grammar->rules[current_nt].end();
134  }
135  }
136  Rule& operator*() const { return *current_rule; }
137 // Rule* operator->() const { return current_rule; }
138 
139  RuleIterator& operator++(int blah) { this->operator++(); return *this; }
141 
142  current_rule++;
143 
144  // keep incrementing over rules that are empty, and if we run out of
145  // nonterminals, set us to the end and break
146  while( current_rule == grammar->rules[current_nt].end() ) {
147  if(current_nt < grammar->N_NTs-1) {
148  current_nt++; // next nonterminal
149  current_rule = grammar->rules[current_nt].begin();
150  }
151  else {
152  current_rule = grammar->rules[current_nt].end();
153  break;
154  }
155  }
156 
157  return *this;
158  }
159 
160  RuleIterator& operator+(size_t n) {
161  for(size_t i=0;i<n;i++) this->operator++();
162  return *this;
163  }
164 
165  bool operator==(const RuleIterator& rhs) const {
166  return current_nt == rhs.current_nt and current_rule == rhs.current_rule;
167  }
168  };
169 
170  // these are set up to
171  RuleIterator begin() const { return RuleIterator(const_cast<this_t*>(this), false); }
172  RuleIterator end() const { return RuleIterator(const_cast<this_t*>(this), true);; }
173 
177  constexpr nonterminal_t start() {
178  return nt<output_t>();
179  }
180 
181  constexpr size_t count_nonterminals() const {
186  return N_NTs;
187  }
188 
189  size_t count_rules(const nonterminal_t nt) const {
195  assert(nt >= 0 and nt < N_NTs);
196  return rules[nt].size();
197  }
198  size_t count_rules() const {
204  size_t n=0;
205  for(size_t i=0;i<N_NTs;i++) {
206  n += count_rules((nonterminal_t)i);
207  }
208  return n;
209  }
210 
211  void change_probability(const std::string& s, const double newp) {
212  Rule* r = get_rule(s);
213  Z[r->nt] -= r->p;
214  r->p = newp;
215  Z[r->nt] += r->p;
216  }
217 
218  size_t count_terminals(nonterminal_t nt) const {
225  size_t n=0;
226  for(auto& r : rules[nt]) {
227  if(r.is_terminal()) n++;
228  }
229  return n;
230  }
238  size_t n=0;
239  for(auto& r : rules[nt]) {
240  if(not r.is_terminal()) n++;
241  }
242  return n;
243  }
244 
249 // void finite_size(nonterminal_t nt) const {
250 // assert(nt >=0 and nt <= N_NTs);
251 //
252 // // need to create a 2d table of what each thing can expand to
253 // std::vector<std::vector<int> > e(N_NTs, std::vector<int>(N_NTs, 0));
254 //
255 // for(auto& r : rules[nt]) {
256 // for(auto& t : r.child_types)
257 // ++e[nt][t]; // how many ways can I get to this one?
258 // }
259 //
260 // bool updated = false;
261 //
262 // do {
263 // for(size_t nt=0;nt<N_NTs;nt++) {
264 // for(auto& r : rules[nt]){
265 // for(auto& t : r.child_types) {
266 //
267 // }
268 // }
269 // }
270 //
271 // } while(updated);
272 // }
273 
274  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
275  // Managing rules
276  // (this holds a lot of complexity for how we initialize from PRIMITIVES)
277  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
278 
279 
280  template<typename X>
281  static constexpr bool is_in_GRAMMAR_TYPES() {
282  // check if X is in GRAMMAR_TYPES
283  // TODO: UPDATE FOR DECAY SINCE WE DONT WANT THAT UNTIL WE HAVE REFERENCES AGAIN?
284  return contains_type<X,GRAMMAR_TYPES...>();
285  }
286 
294  template<typename T, typename... args>
295  void add_vms(std::string fmt, FT* f, double p=1.0, Op o=Op::Standard, int a=0) {
296  assert(f != nullptr && "*** If you're passing a null f to add_vms, you've really screwed up.");
297 
298  nonterminal_t Tnt = this->nt<T>();
299  Rule r(Tnt, (void*)f, fmt, {nt<args>()...}, p, o, a);
300  Z[Tnt] += r.p; // keep track of the total probability
301  auto pos = std::lower_bound( rules[Tnt].begin(), rules[Tnt].end(), r);
302  rules[Tnt].insert( pos, r ); // put this before
303  }
304 
311  template<typename T, typename... args>
312  void add(std::string fmt, Primitive<T,args...>& b, double p=1.0, int a=0) {
313  // read f and o from b
314  assert(b.f != nullptr);
315  add_vms<T,args...>(fmt, (FT*)b.f, p, b.op, a);
316  }
317 
325  template<typename T, typename... args>
326  void add(std::string fmt, std::function<T(args...)> f, double p=1.0, Op o=Op::Standard, int a=0) {
327 
328  // first check that the types are allowed
329  static_assert((not std::is_reference<T>::value) && "*** Primitives cannot return references.");
330  static_assert((not std::is_reference<args>::value && ...) && "*** Arguments cannot be references.");
331  static_assert(is_in_GRAMMAR_TYPES<T>() , "*** Return type is not in GRAMMAR_TYPES");
332  static_assert((is_in_GRAMMAR_TYPES<args>() && ...), "*** Argument type is not in GRAMMAR_TYPES");
333 
334  // NOTE: We want something with friendly error messages instead of the above,
335  // but apparently this is not supported:
336  // first check that the types are allowed
337 // if constexpr(std::is_reference<T>::value){
338 // print("*** Primitives cannot return references, in ", fmt);
339 // static_assert(false);
340 // }
341 // if constexpr((std::is_reference<args>::value || ...)){
342 // print("*** Arguments cannot be references, in ", fmt);
343 // static_assert(false);
344 // }
345 // if constexpr(not is_in_GRAMMAR_TYPES<T>()){
346 // print("*** Return type T not in grammar types, in ", fmt);
347 // static_assert(false);
348 // }
349 // if constexpr(not (is_in_GRAMMAR_TYPES<args>() && ...)){
350 // print("*** Argument type not in grammar types, in ", fmt);
351 // static_assert(false);
352 // }
353 //
354  // create a lambda on the heap that is a function of a VMS, since
355  // this is what an instruction must be. This implements the calling order convention too.
356  //auto newf = new auto ( [=](VirtualMachineState_t*, int) -> void {
357  auto fvms = new FT([=](VirtualMachineState_t* vms, int _a=0) -> void {
358  assert(vms != nullptr);
359 
360  if constexpr (sizeof...(args) == 0){
361  vms->push( f() );
362  }
363  else if constexpr (sizeof...(args) == 1) {
364  auto a0 = vms->template getpop_nth<0,args...>();
365  vms->push(f(std::move(a0)));
366  }
367  else if constexpr (sizeof...(args) == 2) {
368  auto a1 = vms->template getpop_nth<1,args...>();
369  auto a0 = vms->template getpop_nth<0,args...>();
370  vms->push(f(std::move(a0), std::move(a1)));
371  }
372  else if constexpr (sizeof...(args) == 3) {
373  auto a2 = vms->template getpop_nth<2,args...>();
374  auto a1 = vms->template getpop_nth<1,args...>();
375  auto a0 = vms->template getpop_nth<0,args...>();
376  vms->push(f(std::move(a0), std::move(a1), std::move(a2)));
377  }
378  else if constexpr (sizeof...(args) == 4) {
379  auto a3 = vms->template getpop_nth<3,args...>();
380  auto a2 = vms->template getpop_nth<2,args...>();
381  auto a1 = vms->template getpop_nth<1,args...>();
382  auto a0 = vms->template getpop_nth<0,args...>();
383  vms->push(f(std::move(a0), std::move(a1), std::move(a2), std::move(a3)));
384  }
385  else if constexpr (sizeof...(args) == 5) {
386  auto a4 = vms->template getpop_nth<4,args...>();
387  auto a3 = vms->template getpop_nth<3,args...>();
388  auto a2 = vms->template getpop_nth<2,args...>();
389  auto a1 = vms->template getpop_nth<1,args...>();
390  auto a0 = vms->template getpop_nth<0,args...>();
391  vms->push(f(std::move(a0), std::move(a1), std::move(a2), std::move(a3), std::move(a4)));
392  }
393  else if constexpr (sizeof...(args) == 6) {
394  auto a5 = vms->template getpop_nth<5,args...>();
395  auto a4 = vms->template getpop_nth<4,args...>();
396  auto a3 = vms->template getpop_nth<3,args...>();
397  auto a2 = vms->template getpop_nth<2,args...>();
398  auto a1 = vms->template getpop_nth<1,args...>();
399  auto a0 = vms->template getpop_nth<0,args...>();
400  vms->push(f(std::move(a0), std::move(a1), std::move(a2), std::move(a3), std::move(a4), std::move(a5)));
401  }
402  else {
403  print("*** Error -- too many arguments for a function. Must be updated in Grammar.h ", sizeof...(args) );
404 
405  assert(false);
406  }
407  });
408 
409  add_vms<T,args...>(fmt, fvms, p, o, a);
410  }
411 
418  template<typename T, typename... args>
419  void add(std::string fmt, T(*_f)(args...), double p=1.0, Op o=Op::Standard, int a=0) {
420  add<T,args...>(fmt, std::function<T(args...)>(_f), p, o, a);
421  }
422 
423 
424 
434  template<typename T>
435  void add_terminal(std::string fmt, T x, double p=1.0, Op o=Op::Standard, int a=0) {
436  add(fmt, std::function( [=]()->T { return x; }), p, o, a);
437  }
438 
439 
450  template<typename T, typename... args>
451  void add_ft(std::string fmt, T(*_f)(args...), double p=1.0, Op o=Op::Standard, int a=0) {
452  std::function f = _f; // convert to std::function
453 
454  assert(not contains(fmt, "%s")); // should not contain %s since its not a function application
455 
456  add_terminal<ft<T,args...>>(fmt, f, p, o, a);
457  }
458 
465  rules[nt].clear();
466  Z[nt] = 0.0;
467  }
468 
469  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
470  // Methods for getting rules by some info
471  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
472 
473  size_t get_index_of(const Rule* r) const {
481  for(size_t i=0;i<rules[r->nt].size();i++) {
482  if(*get_rule(r->nt,i) == *r) {
483  return i;
484  }
485  }
486  throw YouShouldNotBeHereError("*** Did not find rule in get_index_of.");
487  }
488 
489  [[nodiscard]] virtual Rule* get_rule(const nonterminal_t nt, size_t k) const {
497  assert(nt < N_NTs);
498  assert(k < rules[nt].size());
499  return const_cast<Rule*>(&rules[nt][k]);
500  }
501 
502  [[nodiscard]] virtual Rule* get_rule(const nonterminal_t nt, const Op o, const int a=0) {
511  assert(nt >= 0 and nt < N_NTs);
512  for(auto& r: rules[nt]) {
513  // Need to fix this because it used is_a:
514  if(r.is_a(o) and r.arg == a)
515  return &r;
516  }
517  throw YouShouldNotBeHereError("*** Could not find rule");
518  }
519 
520  [[nodiscard]] virtual Rule* get_rule(const nonterminal_t nt, size_t i) {
521  return &rules[nt].at(i);
522  }
523 
524  [[nodiscard]] virtual Rule* get_rule(const nonterminal_t nt, const std::string s) const {
532  // we're going to allow matches to prefixes, but we have to keep track
533  // if we have matched a prefix so we don't mutliple count (e.g if one rule was "str" and one was "string"),
534  // we'd want to match "string" as "string" and not "str"
535 
536  bool was_partial_match = true;
537 
538  Rule* ret = nullptr;
539  for(auto& r: rules[nt]) {
540 
541  if(s == r.format) {
542  if(ret != nullptr and not was_partial_match) { // if we previously found a full match
543  CERR "*** Multiple rules found matching " << s TAB r.format ENDL;
544  throw YouShouldNotBeHereError();
545  }
546  else {
547  was_partial_match = false; // not a partial match
548  ret = const_cast<Rule*>(&r);
549  }
550  } // else we look at partial matches
551  else if( was_partial_match and ((s != "" and is_prefix(s, r.format)) or (s=="" and s==r.format))) {
552  if(ret != nullptr) {
553  CERR "*** Multiple rules found matching " << s TAB r.format ENDL;
554  throw YouShouldNotBeHereError();
555  }
556  else {
557  ret = const_cast<Rule*>(&r);
558  }
559  }
560  }
561 
562  if(ret != nullptr) {
563  return ret;
564  }
565  else {
566  CERR "*** No rule found to match " TAB QQ(s) ENDL;
567  throw YouShouldNotBeHereError();
568  }
569  }
570 
571  [[nodiscard]] virtual Rule* get_rule(const std::string s) const {
579  Rule* ret = nullptr;
580  for(auto& r : *this) {
581  if( (s != "" and is_prefix(s, r.format)) or (s=="" and s==r.format)) {
582  if(ret != nullptr) {
583  CERR "*** Multiple rules found matching " << s TAB r.format ENDL;
584  assert(0);
585  }
586  ret = &r;
587  }
588  }
589 
590  if(ret != nullptr) { return ret; }
591  else {
592  CERR "*** No rule found to match " TAB QQ(s) ENDL;
593  throw YouShouldNotBeHereError();
594  }
595  }
596 
597  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
598  // Sampling rules
599  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
600 
601  double rule_normalizer(const nonterminal_t nt) const {
608  assert(nt < N_NTs);
609  return Z[nt];
610  }
611 
612  virtual Rule* sample_rule(const nonterminal_t nt) const {
619  std::function<double(const Rule& r)> f = [](const Rule& r){return r.p;};
620  if(rules[nt].size() == 0) {
621  print("Failed nonterminal, not in grammar:", nt);
622  assert(false && "*** You are trying to sample from a nonterminal with no rules!");
623  }
624  return sample<Rule,std::vector<Rule>>(rules[nt], Z[nt], f).first; // ignore the probabiltiy
625  }
626 
627 
628  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
629  // Generation
630  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
631 
632 
633  Node makeNode(const Rule* r) const {
640  return Node(r, log(r->p)-log(rule_normalizer(r->nt)));
641  }
642 
643 
644  Node __generate(const nonterminal_t ntfrom=nt<output_t>(), unsigned long depth=0) const {
654  if(depth >= GRAMMAR_MAX_DEPTH) {
655  #ifdef WARN_DEPTH_EXCEPTION
656  CERR "*** Grammar exceeded max depth, are you sure the grammar probabilities are right?" ENDL;
657  CERR "*** You might be able to figure out what's wrong with gdb and then looking at the backtrace of" ENDL;
658  CERR "*** which nonterminals are called." ENDL;
659  CERR "*** Or.... maybe this nonterminal does not rewrite to a terminal?" ENDL;
660  #endif
661  throw DepthException();
662  }
663 
664  Rule* r = sample_rule(ntfrom);
665  Node n = makeNode(r);
666 
667  // we'll wrap in a catch so we can see the sequence of nonterminals that failed us:
668  try {
669 
670  for(size_t i=0;i<r->N;i++) {
671  n.set_child(i, __generate(r->type(i), depth+1)); // recurse down
672  }
673 
674  } catch(const DepthException& e) {
675  #ifdef WARN_DEPTH_EXCEPTION
676  CERR ntfrom << " ";
677  #endif
678  throw e;
679  }
680 
681  return n;
682  }
683 
693  Node generate(const nonterminal_t ntfrom=nt<output_t>(), unsigned long depth=0) const {
694  for(size_t tries=0;tries<GENERATE_DEPTH_EXCEPTION_RETRIES;tries++) {
695  try {
696  return __generate(ntfrom, depth);
697  } catch(DepthException& e) { }
698  }
699  assert(false && "*** Generate failed due to repeated depth exceptions");
700  }
701 
702  Node copy_resample(const Node& node, bool f(const Node& n)) const {
711  if(f(node)){
712  return generate(node.rule->nt);
713  }
714  else {
715 
716  // otherwise normal copy
717  auto ret = node;
718  for(size_t i=0;i<ret.nchildren();i++) {
719  ret.set_child(i, copy_resample(ret.child(i), f));
720  }
721  return ret;
722  }
723  }
724 
725  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
726  // Computing log probabilities and priors
727  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
728 
734  const std::map<const Rule*, size_t> get_rule_indexer() const {
735  std::map<const Rule*, size_t> out;
736  size_t idx = 0;
737  const size_t NT = count_nonterminals();
738  for(size_t nt=0;nt<NT;nt++) {
739  for(const auto& r : rules[nt]) {
740  out[&r] = idx;
741  idx++;
742  }
743  }
744  return out;
745  }
746 
752  std::vector<size_t> get_counts(const Node& node) const {
753  auto idx = get_rule_indexer();
754  return get_counts(node,idx);
755  }
756 
762  std::vector<size_t> get_counts(const Node& node, const std::map<const Rule*,size_t>& indexer) const {
763 
764  std::vector<size_t> out(count_rules(),0);
765  for(auto& n : node) {
766  // now increment out, accounting for the number of rules that must have come before!
767  out[indexer.at(n.rule)] += 1;
768  }
769 
770  return out;
771  }
772 
779  template<typename K, typename V>
780  std::vector<size_t> get_counts(const std::map<K,V>& m, const std::map<const Rule*,size_t>& indexer) const {
781 
782  std::vector<size_t> out(count_rules(),0);
783 
784  for(const auto& [key,fac] : m) {
785  auto c = get_counts(fac.get_value(), indexer); // extract counts using indexer
786  for(size_t r=0;r<c.size();r++) // update cv
787  out[r] += c[r];
788  }
789 
790  return out;
791  }
792 
793 
// If eigen is defined we can get the transition matrix
#ifdef AM_I_USING_EIGEN
// m(to, from) = expected number of children of type `to` produced by one expansion of
// `from`, i.e. the sum over from's rules r of (r.p / Z[from]) * (#children of type `to` in r).
// const: it only reads the grammar.
Matrix get_nonterminal_transition_matrix() const {
    const size_t NT = count_nonterminals();
    Matrix m = Matrix::Zero(NT,NT);
    for(size_t nt=0;nt<NT;nt++) {
        double z = rule_normalizer(nt);
        for(auto& r : rules[nt]) {
            double p = r.p / z;
            for(auto& to : r.get_child_types()) {
                m(to,nt) += p;
            }
        }
    }

    return m;
}
#endif
812 
819 // double get_expected_length(size_t max_depth=50) const {
820 //
821 // const size_t NT = count_nonterminals();
822 // nonterminal_t start = nt<output_t>();
823 // // we'll build up a NT x max_depth dynamic programming table
824 //
825 // Vector2D<double> tab(NT, max_depth);
826 // tab.fill(0.0);
827 // tab[start,0] = 1; // start with 1
828 //
829 // for(size_t d=1;d<max_depth;d++) {
830 // for(nonterminal_t nt=0;nt<NT;nt++) {
831 //
832 // double z = rule_normalizer(nt);
833 // for(auto& r : rules[nt]) {
834 // double p = r.p / z;
835 // for(auto& to : r.get_child_types()) {
836 // m(to,nt) += p;
837 // }
838 // }
839 //
840 // tab[d][nt] = 0.0;
841 // }
842 // }
843 //
844 //
845 // double l = 0.0;
846 // }
847 //
848 
849  double log_probability(const Node& n) const {
857  double lp = 0.0;
858  for(auto& x : n) {
859  if(x.rule == NullRule) continue;
860  lp += log(x.rule->p) - log(rule_normalizer(x.rule->nt));
861  }
862 
863  return lp;
864  }
865 
866  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
867  // Implementation of converting strings to nodes
868  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
869 
870 
871  Node from_parseable(std::deque<std::string>& q) const {
880  assert(!q.empty() && "*** Should not ever get to here with an empty queue -- are you missing arguments?");
881 
882  auto [nts, pfx] = divide(q.front(), Node::NTDelimiter);
883  q.pop_front();
884 
885  // null rules:
886  if(pfx == NullRule->format)
887  return makeNode(NullRule);
888 
889  // otherwise find the matching rule
890  Rule* r = this->get_rule(stoi(nts), pfx);
891 
892  Node v = makeNode(r);
893  for(size_t i=0;i<r->N;i++) {
894 
895  v.set_child(i, from_parseable(q));
896 
897  if(r->type(i) != v.child(i).rule->nt) {
898  CERR "*** Grammar expected type " << r->type(i) << " but got type " << v.child(i).rule->nt << " at " << r->format << " argument " << i ENDL;
899  assert(false && "Bad names in from_parseable."); // just check that we didn't miss this up
900  }
901 
902  }
903  return v;
904  }
905 
906  Node from_parseable(std::string s) const {
913  std::deque<std::string> stk = split(s, Node::RuleDelimiter);
914  return from_parseable(stk);
915  }
916 
917  Node from_parseable(const char* c) const {
918  std::string s = c;
919  return from_parseable(s);
920  }
921 
922 
923  size_t neighbors(const Node& node) const {
924  // How many neighbors do I have? This is the number of neighbors the first gap has
925  for(size_t i=0;i<node.rule->N;i++){
926  if(node.child(i).is_null()) {
927  return count_rules(node.rule->type(i)); // NOTE: must use rule->child_types since child[i]->rule->nt is always 0 for NullRules
928  }
929  else {
930  auto cn = neighbors(node.child(i));
931  if(cn > 0) return cn; // we return the number of neighbors for the first gap
932  }
933  }
934  return 0;
935  }
936 
937  void expand_to_neighbor(Node& node, int& which) {
938  // here we find the neighbor indicated by which and expand it into the which'th neighbor
939  // to do this, we loop through until which is less than the number of neighbors,
940  // and then it must specify which expansion we want to take. This means that when we
941  // skip a nullptr, we have to subtract from it the number of neighbors (expansions)
942  // we could have taken.
943  for(size_t i=0;i<node.rule->N;i++){
944  if(node.child(i).is_null()) {
945  int c = count_rules(node.rule->type(i));
946  if(which >= 0 and which < c) {
947  auto r = get_rule(node.rule->type(i), (size_t)which);
948  node.set_child(i, makeNode(r));
949  }
950  which -= c;
951  }
952  else { // otherwise we have to process that which
953  expand_to_neighbor(node.child(i), which);
954  }
955  }
956  }
957 
958  double neighbor_prior(const Node& node, int& which) const {
959  // here we find the neighbor indicated by which and expand it into the which'th neighbor
960  // to do this, we loop through until which is less than the number of neighbors,
961  // and then it must specify which expansion we want to take. This means that when we
962  // skip a nullptr, we have to subtract from it the number of neighbors (expansions)
963  // we could have taken.
964  for(size_t i=0;i<node.rule->N;i++){
965  if(node.child(i).is_null()) {
966  int c = count_rules(node.rule->type(i));
967  if(which >= 0 and which < c) {
968  auto r = get_rule(node.rule->type(i), (size_t)which);
969  return log(r->p)-log(rule_normalizer(r->nt));
970  }
971  which -= c;
972  }
973  else { // otherwise we have to process that which
974  auto o = neighbor_prior(node.child(i), which);
975  if(not std::isnan(o)) { // if this child returned something.
976  //assert(which <= 0);
977  return o;
978  }
979  }
980  }
981 
982  return NaN; // if no neighbors
983  }
984 
985  void complete(Node& node) {
986  // go through and fill in the tree at random
987  for(size_t i=0;i<node.rule->N;i++){
988  if(node.child(i).is_null()) {
989  node.set_child(i, generate(node.rule->type(i)));
990  }
991  else {
992  complete(node.child(i));
993  }
994  }
995  }
996 
997 
998 
999  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1000  // Simple parsing routines -- not very well debugged
1001  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1002 
// Locate the outermost call structure of s, e.g. "f(a,g(b,c))".
// Returns (position of the first '(' at depth 0,
//          positions of ',' at depth 1,
//          position of the ')' that closes depth 1); -1 where absent.
// static: uses no Grammar state. s is taken by const& to avoid copying it.
static std::tuple<int,std::vector<int>,int> find_open_commas_close(const std::string& s) {

    int openpos = -1;
    std::vector<int> commas; // all commas with only one open
    int closepos = -1;

    int opencount = 0; // current nesting depth

    for(size_t i=0;i<s.length();i++){
        const char c = s[i];

        if(opencount==0 and openpos==-1 and c=='(') {
            openpos = static_cast<int>(i);
        }

        if(opencount==1 and closepos==-1 and c==')') {
            closepos = static_cast<int>(i);
        }

        // commas are recorded only when exactly one paren is open
        if(opencount==1 and c==','){
            assert(closepos == -1);
            commas.push_back(static_cast<int>(i));
        }

        opencount += (c=='(');
        opencount -= (c==')');
    }

    return std::make_tuple(openpos, commas, closepos);
}
1037 
1048  Node simple_parse(std::string s) {
1049  //print("Parsing ", s);
1050 
1051  // remove the lambda x. if its there
1052  if(s.substr(0,3) == LAMBDAXDOT_STRING) s.erase(0,3);
1053  // remove leading whitespace
1054  while(s.at(0) == ' ' or s.at(0) == '\t') s.erase(0,1);
1055  // remove trailing whitespace
1056  while(s.at(s.size()-1) == ' ' or s.at(s.size()-1) == '\t') s.erase(s.size()-1,1);
1057 
1058  // use the above function to find chunks
1059  auto [open, commas, close] = find_open_commas_close(s);
1060 
1061  // if it's a terminal
1062  if(open == -1) {
1063  assert(commas.size()==0 and close==-1);
1064  auto r = this->get_rule(s);
1065  return this->makeNode(r);
1066  }
1067  else if(close == open+1) { // special case of "f()"
1068  return this->makeNode(get_rule(s)); // the whole string is what we want and its a terminal
1069  }
1070  else {
1071  assert(close != -1);
1072 
1073  // recover the rule format -- no spaces, etc.
1074  std::string fmt = s.substr(0,open) + "(%s";
1075  for(auto& c : commas) {
1076  UNUSED(c);
1077  fmt += ",%s";
1078  }
1079  fmt += ")";
1080 
1081  // find the rule for this format
1082  auto r = this->get_rule(fmt);
1083  auto out = this->makeNode(r);
1084 
1085  int prev=open+1;
1086  int ci=0;
1087  for(auto& c : commas) {
1088  auto child_string = s.substr(prev,c-prev);
1089  out.set_child(ci,simple_parse(child_string));
1090  prev = c+1;
1091  ci++;
1092  }
1093 
1094  // and the last child
1095  auto child_string = s.substr(prev,close-prev);
1096  out.set_child(ci,simple_parse(child_string));
1097 
1098  return out;
1099  }
1100  }
1101 
1102 
1103 
1104 };
Node from_parseable(const char *c) const
Definition: Grammar.h:917
The Primitive type just stores a function pointer and an Op command.
Definition: Grammar.h:44
std::tuple< int, std::vector< int >, int > find_open_commas_close(const std::string s)
Definition: Grammar.h:1003
double neighbor_prior(const Node &node, int &which) const
Definition: Grammar.h:958
void UNUSED(const T &x)
Definition: Miscellaneous.h:38
Node generate(const nonterminal_t ntfrom=nt< output_t >(), unsigned long depth=0) const
A wrapper to catch DepthExceptions and retry. This means that by default we try to generate GENERATE_D...
Definition: Grammar.h:693
std::string QQ(const std::string &x)
Definition: Strings.h:190
Definition: VirtualMachineState.h:46
std::vector< Rule >::iterator current_rule
Definition: Grammar.h:121
Definition: Node.h:22
constexpr nonterminal_t start()
The start nonterminal type.
Definition: Grammar.h:177
virtual Rule * get_rule(const nonterminal_t nt, const std::string s) const
Definition: Grammar.h:524
This represents the state of a partial evaluation of a program, corresponding to the value of all of ...
Definition: Grammar.h:108
Node copy_resample(const Node &node, bool f(const Node &n)) const
Definition: Grammar.h:702
RuleIterator end() const
Definition: Grammar.h:172
#define TAB
Definition: IO.h:19
std::atomic< uintmax_t > depth_exceptions(0)
size_t count_rules() const
Definition: Grammar.h:198
nonterminal_t type(size_t i) const
Definition: Rule.h:152
std::pair< std::string, std::string > divide(const std::string &s, const char delimiter)
Definition: Strings.h:144
RuleIterator & operator++(int blah)
Definition: Grammar.h:139
Definition: Rule.h:21
RuleIterator & operator+(size_t n)
Definition: Grammar.h:160
Definition: Primitive.h:13
this_t * grammar
Definition: Grammar.h:119
Eigen::MatrixXf Matrix
Definition: EigenLib.h:18
Node from_parseable(std::string s) const
Definition: Grammar.h:906
std::deque< std::string > split(const std::string &s, const char delimiter)
Split returns a deque of s split up at the character delimiter. It handles these special cases: sp...
Definition: str.h:50
void add(std::string fmt, T(*_f)(args...), double p=1.0, Op o=Op::Standard, int a=0)
Wrapper for add to use function pointers.
Definition: Grammar.h:419
Node makeNode(const Rule *r) const
Definition: Grammar.h:633
void * f
Definition: Primitive.h:16
RuleIterator begin() const
Definition: Grammar.h:171
Definition: Errors.h:18
virtual Rule * get_rule(const nonterminal_t nt, size_t i)
Definition: Grammar.h:520
DepthException()
Definition: Grammar.h:19
Node from_parseable(std::deque< std::string > &q) const
Definition: Grammar.h:871
Op
Definition: Ops.h:3
const Rule * NullRule
Definition: Rule.h:186
Node simple_parse(std::string s)
Very simple parsing routine that takes a string like "and(not(or(eq_pos(pos(parent(x)),'NP-POSS'),eq_pos('NP-S',pos(x)))),corefers(x))" (from the Binding example) and parses it into a Node.
Definition: Grammar.h:1048
Definition: Grammar.h:18
static const char RuleDelimiter
Definition: Node.h:29
size_t N
Definition: Rule.h:29
nonterminal_t current_nt
Definition: Grammar.h:120
size_t count_nonterminals(nonterminal_t nt) const
Definition: Grammar.h:231
virtual Rule * sample_rule(const nonterminal_t nt) const
Definition: Grammar.h:612
void print(FIRST f, ARGS... args)
Lock output_lock and print to std:cout.
Definition: IO.h:53
bool is_null() const
Definition: Node.h:165
static const char NTDelimiter
Definition: Node.h:28
void push(T &x)
Definition: VirtualMachineState.h:184
std::vector< size_t > get_counts(const std::map< K, V > &m, const std::map< const Rule *, size_t > &indexer) const
Support for map so we can call on Lexicon::get_value.
Definition: Grammar.h:780
std::array< double, N_NTs > Z
Definition: Grammar.h:75
static constexpr nonterminal_t nt()
Definition: Grammar.h:84
Node __generate(const nonterminal_t ntfrom=nt< output_t >(), unsigned long depth=0) const
Definition: Grammar.h:644
void remove_all(nonterminal_t nt)
Remove all the nonterminals of this type from the grammar. NOTE: This is generally a really bad idea ...
Definition: Grammar.h:464
A Node is the primary internal representation for a program – it recursively stores a rule and the a...
std::string format
Definition: Rule.h:28
size_t count_terminals(nonterminal_t nt) const
Definition: Grammar.h:218
std::function< void(this_t *, int)> FT
Definition: VirtualMachineState.h:56
#define CERR
Definition: IO.h:23
void expand_to_neighbor(Node &node, int &which)
Definition: Grammar.h:937
size_t neighbors(const Node &node) const
Definition: Grammar.h:923
this_t & child(const size_t i)
Definition: BaseNode.h:175
const Rule * rule
Definition: Node.h:32
unsigned short nonterminal_t
Definition: Nonterminal.h:4
void set_child(const size_t i, Node &n)
Definition: Node.h:88
void change_probability(const std::string &s, const double newp)
Definition: Grammar.h:211
void add_terminal(std::string fmt, T x, double p=1.0, Op o=Op::Standard, int a=0)
Add a variable that is NOT A function – simplification for adding alphabets etc. This just wraps stu...
Definition: Grammar.h:435
double rule_normalizer(const nonterminal_t nt) const
Definition: Grammar.h:601
double log_probability(const Node &n) const
This computes the expected length of productions from this grammar, counting terminals and nontermina...
Definition: Grammar.h:849
const std::string LAMBDAXDOT_STRING
Definition: Strings.h:20
bool is_prefix(const T &prefix, const T &x)
Check if prefix is a prefix of x – works with iterables, including strings and vectors.
Definition: Strings.h:39
double p
Definition: Rule.h:30
constexpr size_t count_nonterminals() const
Definition: Grammar.h:181
#define ENDL
Definition: IO.h:21
void add_ft(std::string fmt, T(*_f)(args...), double p=1.0, Op o=Op::Standard, int a=0)
Adds this as a function type (see Functional.h) rather than as a function itself. For example...
Definition: Grammar.h:451
bool operator==(const RuleIterator &rhs) const
Definition: Grammar.h:165
static constexpr bool is_in_GRAMMAR_TYPES()
For a given nt, returns the number of finite trees that nt can expand to if its finite; 0 if its infi...
Definition: Grammar.h:281
size_t get_index_of(const Rule *r) const
Definition: Grammar.h:473
constexpr double NaN
Definition: Numerics.h:21
This is a thread_local rng whose first object is used to seed the others (in other threads). This way, we can have thread_local rngs that are all seeded deterministically in Fleet via --seed=X.
size_t count_rules(const nonterminal_t nt) const
Definition: Grammar.h:189
Op op
Definition: Primitive.h:15
std::vector< size_t > get_counts(const Node &node) const
Compute a vector of counts of how often each rule was used, in a standard order given by iterating ov...
Definition: Grammar.h:752
RuleIterator & operator++()
Definition: Grammar.h:140
nonterminal_t nt
Definition: Rule.h:27
Grammar()
Definition: Grammar.h:90
std::function< out(args...)> ft
Definition: Functional.h:14
A little class that any VirtualMachinePool AND VirtualMachines inherit to control their behavior...
void complete(Node &node)
Definition: Grammar.h:985
const std::map< const Rule *, size_t > get_rule_indexer() const
Returns a map from rule pointers to indices in e.g. a vector, so that every rule has a unique index a...
Definition: Grammar.h:734
Helpers to find the numerical index (as a nonterminal_t) of a given type in a tuple.
Definition: Miscellaneous.h:105
void add(std::string fmt, Primitive< T, args... > &b, double p=1.0, int a=0)
Definition: Grammar.h:312
RuleIterator(this_t *g, bool is_end)
Definition: Grammar.h:125
virtual Rule * get_rule(const std::string s) const
Definition: Grammar.h:571
virtual Rule * get_rule(const nonterminal_t nt, size_t k) const
Definition: Grammar.h:489
void add(std::string fmt, std::function< T(args...)> f, double p=1.0, Op o=Op::Standard, int a=0)
Definition: Grammar.h:326
bool contains(const std::string &s, const std::string &x)
Definition: Strings.h:53
std::vector< size_t > get_counts(const Node &node, const std::map< const Rule *, size_t > &indexer) const
Compute a vector of counts of how often each rule was used, using indexer to map each rule to an inde...
Definition: Grammar.h:762
Rule & operator*() const
Definition: Grammar.h:136
virtual Rule * get_rule(const nonterminal_t nt, const Op o, const int a=0)
Definition: Grammar.h:502
std::vector< Rule > rules[N_NTs]
Definition: Grammar.h:74
void add_vms(std::string fmt, FT *f, double p=1.0, Op o=Op::Standard, int a=0)
Definition: Grammar.h:295