Fleet  0.0.9
Inference in the LOT
Grammar.h
Go to the documentation of this file.
1 #pragma once
2 
3 #include <tuple>
4 #include <array>
5 #include <exception>
6 
7 #include "IO.h"
8 #include "Errors.h"
9 #include "Node.h"
10 #include "Random.h"
11 #include "Nonterminal.h"
12 #include "VirtualMachineState.h"
13 #include "VirtualMachinePool.h"
14 #include "Builtins.h"
15 
16 // an exception for recursing too deep
17 struct DepthException : public std::exception {
20  }
21 };
22 
42 template<typename _input_t, typename _output_t, typename... GRAMMAR_TYPES>
43 class Grammar {
44 public:
45 
46  using input_t = _input_t;
47  using output_t = _output_t;
48  using this_t = Grammar<input_t, output_t, GRAMMAR_TYPES...>;
49 
50  // Keep track of what types we are using here as our types -- thesee types are
51  // stored in this tuple so they can be extracted
52  using TypeTuple = std::tuple<GRAMMAR_TYPES...>;
53 
54  // how many nonterminal types do we have?
55  static constexpr size_t N_NTs = std::tuple_size<TypeTuple>::value;
56 
57  // The input/output types must be repeated to VirtualMachineState
58  using VirtualMachineState_t = VirtualMachineState<input_t, output_t, GRAMMAR_TYPES...>;
59 
60  // This is the function type
61  using FT = typename VirtualMachineState_t::FT;
62 
63  // How many times will we silently ignore a DepthException
64  // before tossing an assert error
65  static const size_t GENERATE_DEPTH_EXCEPTION_RETRIES = 1000;
66 
67  // get the n'th type
68  //template<size_t N>
69  //using type = typename std::tuple_element<N, TypeTuple>::type;
70 
71  // rules[k] stores a SORTED vector of rules for the kth' nonterminal.
72  // our iteration order is first for k = 0 ... N_NTs then for r in rules[k]
73  std::vector<Rule> rules[N_NTs];
74  std::array<double,N_NTs> Z; // keep the normalizer handy for each nonterminal (not log space)
75 
76  size_t GRAMMAR_MAX_DEPTH = 64;
77 
78  // This function converts a type (passed as a template parameter) into a
79  // size_t index for which one it in in GRAMMAR_TYPES.
80  // This is used so that a Rule doesn't need type subclasses/templates, it can
81  // store a type as e.g. nt<double>() -> size_t
82  template <class T>
83  static constexpr nonterminal_t nt() {
84  static_assert(sizeof...(GRAMMAR_TYPES) > 0, "*** Cannot use empty grammar types here");
85  static_assert(contains_type<T, GRAMMAR_TYPES...>(), "*** The type T (decayed) must be in GRAMMAR_TYPES");
86  return (nonterminal_t)TypeIndex<T, std::tuple<GRAMMAR_TYPES...>>::value;
87  }
88 
89  Grammar() {
90  for(size_t i=0;i<N_NTs;i++) {
91  Z[i] = 0.0;
92  }
93  }
94 
95  // should not be doing these
96  Grammar(const Grammar& g) = delete;
97  Grammar(const Grammar&& g) = delete;
98 
107  class RuleIterator : public std::iterator<std::forward_iterator_tag, Rule> {
108  protected:
111  std::vector<Rule>::iterator current_rule;
112 
113  public:
114 
115  RuleIterator(this_t* g, bool is_end) : grammar(g), current_nt(0) {
116  if(not is_end) {
117  current_rule = g->rules[0].begin();
118  }
119  else {
120  // by convention we set current_rule and current_nt to the last items
121  // since this is what ++ will leave them as below
122  current_nt = N_NTs-1;
123  current_rule = grammar->rules[current_nt].end();
124  }
125  }
126  Rule& operator*() const { return *current_rule; }
127 // Rule* operator->() const { return current_rule; }
128 
129  RuleIterator& operator++(int blah) { this->operator++(); return *this; }
131 
132  current_rule++;
133 
134  // keep incrementing over rules that are empty, and if we run out of
135  // nonterminals, set us to the end and break
136  while( current_rule == grammar->rules[current_nt].end() ) {
137  if(current_nt < grammar->N_NTs-1) {
138  current_nt++; // next nonterminal
139  current_rule = grammar->rules[current_nt].begin();
140  }
141  else {
142  current_rule = grammar->rules[current_nt].end();
143  break;
144  }
145  }
146 
147  return *this;
148  }
149 
150  RuleIterator& operator+(size_t n) {
151  for(size_t i=0;i<n;i++) this->operator++();
152  return *this;
153  }
154 
155  bool operator==(const RuleIterator& rhs) const {
156  return current_nt == rhs.current_nt and current_rule == rhs.current_rule;
157  }
158  };
159 
160  // these are set up to
161  RuleIterator begin() const { return RuleIterator(const_cast<this_t*>(this), false); }
162  RuleIterator end() const { return RuleIterator(const_cast<this_t*>(this), true);; }
163 
167  constexpr nonterminal_t start() {
168  return nt<output_t>();
169  }
170 
171  constexpr size_t count_nonterminals() const {
176  return N_NTs;
177  }
178 
179  size_t count_rules(const nonterminal_t nt) const {
185  assert(nt >= 0 and nt < N_NTs);
186  return rules[nt].size();
187  }
188  size_t count_rules() const {
194  size_t n=0;
195  for(size_t i=0;i<N_NTs;i++) {
196  n += count_rules((nonterminal_t)i);
197  }
198  return n;
199  }
200 
201  void change_probability(const std::string& s, const double newp) {
202  Rule* r = get_rule(s);
203  Z[r->nt] -= r->p;
204  r->p = newp;
205  Z[r->nt] += r->p;
206  }
207 
208  size_t count_terminals(nonterminal_t nt) const {
215  size_t n=0;
216  for(auto& r : rules[nt]) {
217  if(r.is_terminal()) n++;
218  }
219  return n;
220  }
228  size_t n=0;
229  for(auto& r : rules[nt]) {
230  if(not r.is_terminal()) n++;
231  }
232  return n;
233  }
234 
239 // void finite_size(nonterminal_t nt) const {
240 // assert(nt >=0 and nt <= N_NTs);
241 //
242 // // need to create a 2d table of what each thing can expand to
243 // std::vector<std::vector<int> > e(N_NTs, std::vector<int>(N_NTs, 0));
244 //
245 // for(auto& r : rules[nt]) {
246 // for(auto& t : r.child_types)
247 // ++e[nt][t]; // how many ways can I get to this one?
248 // }
249 //
250 // bool updated = false;
251 //
252 // do {
253 // for(size_t nt=0;nt<N_NTs;nt++) {
254 // for(auto& r : rules[nt]){
255 // for(auto& t : r.child_types) {
256 //
257 // }
258 // }
259 // }
260 //
261 // } while(updated);
262 // }
263 
264  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
265  // Managing rules
266  // (this holds a lot of complexity for how we initialize from PRIMITIVES)
267  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
268 
269 
270  template<typename X>
271  static constexpr bool is_in_GRAMMAR_TYPES() {
272  // check if X is in GRAMMAR_TYPES
273  // TODO: UPDATE FOR DECAY SINCE WE DONT WANT THAT UNTIL WE HAVE REFERENCES AGAIN?
274  return contains_type<X,GRAMMAR_TYPES...>();
275  }
276 
284  template<typename T, typename... args>
285  void add_vms(std::string fmt, FT* f, double p=1.0, Op o=Op::Standard, int a=0) {
286  assert(f != nullptr && "*** If you're passing a null f to add_vms, you've really screwed up.");
287 
288  nonterminal_t Tnt = this->nt<T>();
289  Rule r(Tnt, (void*)f, fmt, {nt<args>()...}, p, o, a);
290  Z[Tnt] += r.p; // keep track of the total probability
291  auto pos = std::lower_bound( rules[Tnt].begin(), rules[Tnt].end(), r);
292  rules[Tnt].insert( pos, r ); // put this before
293  }
294 
301  template<typename T, typename... args>
302  void add(std::string fmt, Primitive<T,args...>& b, double p=1.0, int a=0) {
303  // read f and o from b
304  assert(b.f != nullptr);
305  add_vms<T,args...>(fmt, (FT*)b.f, p, b.op, a);
306  }
307 
315  template<typename T, typename... args>
316  void add(std::string fmt, std::function<T(args...)> f, double p=1.0, Op o=Op::Standard, int a=0) {
317 
318  // first check that the types are allowed
319  static_assert((not std::is_reference<T>::value) && "*** Primitives cannot return references.");
320  static_assert((not std::is_reference<args>::value && ...) && "*** Arguments cannot be references.");
321  static_assert(is_in_GRAMMAR_TYPES<T>() , "*** Return type is not in GRAMMAR_TYPES");
322  static_assert((is_in_GRAMMAR_TYPES<args>() && ...), "*** Argument type is not in GRAMMAR_TYPES");
323 
324  // create a lambda on the heap that is a function of a VMS, since
325  // this is what an instruction must be. This implements the calling order convention too.
326  //auto newf = new auto ( [=](VirtualMachineState_t* vms) -> void {
327  auto fvms = new FT([=](VirtualMachineState_t* vms, int _a=0) -> void {
328  assert(vms != nullptr);
329 
330  if constexpr (sizeof...(args) == 0){
331  vms->push( f() );
332  }
333  if constexpr (sizeof...(args) == 1) {
334  auto a0 = vms->template getpop_nth<0,args...>();
335  vms->push(f(std::move(a0)));
336  }
337  else if constexpr (sizeof...(args) == 2) {
338  auto a1 = vms->template getpop_nth<1,args...>();
339  auto a0 = vms->template getpop_nth<0,args...>();
340  vms->push(f(std::move(a0), std::move(a1)));
341  }
342  else if constexpr (sizeof...(args) == 3) {
343  auto a2 = vms->template getpop_nth<2,args...>(); ;
344  auto a1 = vms->template getpop_nth<1,args...>();
345  auto a0 = vms->template getpop_nth<0,args...>();
346  vms->push(f(std::move(a0), std::move(a1), std::move(a2)));
347  }
348  else if constexpr (sizeof...(args) == 4) {
349  auto a3 = vms->template getpop_nth<3,args...>();
350  auto a2 = vms->template getpop_nth<2,args...>();
351  auto a1 = vms->template getpop_nth<1,args...>();
352  auto a0 = vms->template getpop_nth<0,args...>();
353  vms->push(f(std::move(a0), std::move(a1), std::move(a2), std::move(a3)));
354  }
355  });
356 
357  add_vms<T,args...>(fmt, fvms, p, o, a);
358  }
359 
366  template<typename T, typename... args>
367  void add(std::string fmt, T(*_f)(args...), double p=1.0, Op o=Op::Standard, int a=0) {
368  add<T,args...>(fmt, std::function<T(args...)>(_f), p, o, a);
369  }
370 
380  template<typename T>
381  void add_terminal(std::string fmt, T x, double p=1.0, Op o=Op::Standard, int a=0) {
382  add(fmt, std::function( [=]()->T { return x; }), p, o, a);
383  }
384 
385 
392  rules[nt].clear();
393  Z[nt] = 0.0;
394  }
395 
396  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
397  // Methods for getting rules by some info
398  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
399 
400  size_t get_index_of(const Rule* r) const {
407  for(size_t i=0;i<rules[r->nt].size();i++) {
408  if(*get_rule(r->nt,i) == *r) {
409  return i;
410  }
411  }
412  throw YouShouldNotBeHereError("*** Did not find rule in get_index_of.");
413  }
414 
415  [[nodiscard]] virtual Rule* get_rule(const nonterminal_t nt, size_t k) const {
423  assert(nt < N_NTs);
424  assert(k < rules[nt].size());
425  return const_cast<Rule*>(&rules[nt][k]);
426  }
427 
428  [[nodiscard]] virtual Rule* get_rule(const nonterminal_t nt, const Op o, const int a=0) {
437  assert(nt >= 0 and nt < N_NTs);
438  for(auto& r: rules[nt]) {
439  // Need to fix this because it used is_a:
440  if(r.is_a(o) and r.arg == a)
441  return &r;
442  }
443  throw YouShouldNotBeHereError("*** Could not find rule");
444  }
445 
446  [[nodiscard]] virtual Rule* get_rule(const nonterminal_t nt, size_t i) {
447  return &rules[nt].at(i);
448  }
449 
450  [[nodiscard]] virtual Rule* get_rule(const nonterminal_t nt, const std::string s) const {
458  // we're going to allow matches to prefixes, but we have to keep track
459  // if we have matched a prefix so we don't mutliple count (e.g if one rule was "str" and one was "string"),
460  // we'd want to match "string" as "string" and not "str"
461 
462  bool was_partial_match = true;
463 
464  Rule* ret = nullptr;
465  for(auto& r: rules[nt]) {
466 
467  if(s == r.format) {
468  if(ret != nullptr and not was_partial_match) { // if we previously found a full match
469  CERR "*** Multiple rules found matching " << s TAB r.format ENDL;
470  throw YouShouldNotBeHereError();
471  }
472  else {
473  was_partial_match = false; // not a partial match
474  ret = const_cast<Rule*>(&r);
475  }
476  } // else we look at partial matches
477  else if( was_partial_match and ((s != "" and is_prefix(s, r.format)) or (s=="" and s==r.format))) {
478  if(ret != nullptr) {
479  CERR "*** Multiple rules found matching " << s TAB r.format ENDL;
480  throw YouShouldNotBeHereError();
481  }
482  else {
483  ret = const_cast<Rule*>(&r);
484  }
485  }
486  }
487 
488  if(ret != nullptr) {
489  return ret;
490  }
491  else {
492  CERR "*** No rule found to match " TAB QQ(s) ENDL;
493  throw YouShouldNotBeHereError();
494  }
495  }
496 
497  [[nodiscard]] virtual Rule* get_rule(const std::string s) const {
505  Rule* ret = nullptr;
506  for(auto& r : *this) {
507  if( (s != "" and is_prefix(s, r.format)) or (s=="" and s==r.format)) {
508  if(ret != nullptr) {
509  CERR "*** Multiple rules found matching " << s TAB r.format ENDL;
510  assert(0);
511  }
512  ret = &r;
513  }
514  }
515 
516  if(ret != nullptr) { return ret; }
517  else {
518  CERR "*** No rule found to match " TAB QQ(s) ENDL;
519  throw YouShouldNotBeHereError();
520  }
521  }
522 
523  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
524  // Sampling rules
525  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
526 
527  double rule_normalizer(const nonterminal_t nt) const {
534  assert(nt < N_NTs);
535  return Z[nt];
536  }
537 
538  virtual Rule* sample_rule(const nonterminal_t nt) const {
545  std::function<double(const Rule& r)> f = [](const Rule& r){return r.p;};
546  assert(rules[nt].size() > 0 && "*** You are trying to sample from a nonterminal with no rules!");
547  return sample<Rule,std::vector<Rule>>(rules[nt], Z[nt], f).first; // ignore the probabiltiy
548  }
549 
550 
551  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
552  // Generation
553  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
554 
555 
556  Node makeNode(const Rule* r) const {
563  return Node(r, log(r->p)-log(rule_normalizer(r->nt)));
564  }
565 
566 
567  Node __generate(const nonterminal_t ntfrom=nt<output_t>(), unsigned long depth=0) const {
577  if(depth >= GRAMMAR_MAX_DEPTH) {
578  #ifdef WARN_DEPTH_EXCEPTION
579  CERR "*** Grammar exceeded max depth, are you sure the grammar probabilities are right?" ENDL;
580  CERR "*** You might be able to figure out what's wrong with gdb and then looking at the backtrace of" ENDL;
581  CERR "*** which nonterminals are called." ENDL;
582  CERR "*** Or.... maybe this nonterminal does not rewrite to a terminal?" ENDL;
583  #endif
584  throw DepthException();
585  }
586 
587  Rule* r = sample_rule(ntfrom);
588  Node n = makeNode(r);
589 
590  // we'll wrap in a catch so we can see the sequence of nonterminals that failed us:
591  try {
592 
593  for(size_t i=0;i<r->N;i++) {
594  n.set_child(i, __generate(r->type(i), depth+1)); // recurse down
595  }
596 
597  } catch(const DepthException& e) {
598  #ifdef WARN_DEPTH_EXCEPTION
599  CERR ntfrom << " ";
600  #endif
601  throw e;
602  }
603 
604  return n;
605  }
606 
616  Node generate(const nonterminal_t ntfrom=nt<output_t>(), unsigned long depth=0) const {
617  for(size_t tries=0;tries<GENERATE_DEPTH_EXCEPTION_RETRIES;tries++) {
618  try {
619  return __generate(ntfrom, depth);
620  } catch(DepthException& e) { }
621  }
622  assert(false && "*** Generate failed due to repeated depth exceptions");
623  }
624 
625  Node copy_resample(const Node& node, bool f(const Node& n)) const {
634  if(f(node)){
635  return generate(node.rule->nt);
636  }
637  else {
638 
639  // otherwise normal copy
640  auto ret = node;
641  for(size_t i=0;i<ret.nchildren();i++) {
642  ret.set_child(i, copy_resample(ret.child(i), f));
643  }
644  return ret;
645  }
646  }
647 
648  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
649  // Computing log probabilities and priors
650  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
651 
652  std::vector<size_t> get_cumulative_indices() const {
653  // NOTE: This is an inefficiency because we build this up each time
654  // We had a version of grammar that cached this, but it was complex and ugly
655  // so I think we'll take the small performance hit and put it all in here
656  const size_t NT = count_nonterminals();
657  std::vector<size_t> rule_cumulative(NT); // how many rules are there before this (in our ordering)
658  rule_cumulative[0] = 0;
659  for(size_t nt=1;nt<NT;nt++) {
660  rule_cumulative[nt] = rule_cumulative[nt-1] + count_rules( nonterminal_t(nt-1) );
661  }
662  return rule_cumulative;
663  }
664 
665  std::vector<size_t> get_counts(const Node& node) const {
672  const size_t R = count_rules();
673 
674  std::vector<size_t> out(R,0.0);
675 
676  auto rule_cumulative = get_cumulative_indices();
677 
678  for(auto& n : node) {
679  // now increment out, accounting for the number of rules that must have come before!
680  out[rule_cumulative[n.rule->nt] + get_index_of(n.rule)] += 1;
681  }
682 
683  return out;
684  }
685 
691  template<typename T>
692  std::vector<size_t> get_counts(const std::vector<T>& v) const {
693 
694  // NOTE: When we use requires, we can require something like T is a hypothesis with a value...
695  //static_assert(std::is_base_of<LOTHypothesis,T>::value); // need to have LOTHypotheses as T to use T::get_value() below
696 
697  std::vector<size_t> out(count_rules(),0.0);
698 
699  const auto rule_cumulative = get_cumulative_indices();
700  for(auto vi : v) {
701 
702  for(auto& n : vi.get_value()) { // assuming vi has a "value"
703  // now increment out, accounting for the number of rules that must have come before!
704  out[rule_cumulative[n.rule->nt] + get_index_of(n.rule)] += 1;
705  }
706  }
707 
708  return out;
709  }
710 
711 
712  template<typename K, typename T>
713  std::vector<size_t> get_counts(const std::map<K,T>& v) const {
714 
715  std::vector<size_t> out(count_rules(),0.0);
716 
717  const auto rule_cumulative = get_cumulative_indices();
718  for(auto vi : v) {
719  for(const auto& n : vi.second.get_value()) { // assuming vi has a "value"
720  // now increment out, accounting for the number of rules that must have come before!
721  out[rule_cumulative[n.rule->nt] + get_index_of(n.rule)] += 1;
722  }
723  }
724 
725  return out;
726  }
727 
728 
729 
730 
731 // template<typename T>
732 // std::vector<size_t> get_counts(const T& v) const {
733 // if constexpr (std::is_base_of<LOTHypothesis, T>) {
734 // return get_counts()
735 // }
736 // }
737 
738 
// Only available when built with Eigen (AM_I_USING_EIGEN).
// Entry (to, from) is the sum, over the rules of nonterminal `from`, of each rule's
// normalized probability p/Z times the number of times `to` appears among its child types.
#ifdef AM_I_USING_EIGEN
Matrix get_nonterminal_transition_matrix() {
    const size_t NT = count_nonterminals();
    Matrix m = Matrix::Zero(NT,NT);
    for(size_t from = 0; from < NT; ++from) {
        const double z = rule_normalizer(from);
        for(auto& r : rules[from]) {
            const double prob = r.p / z;
            for(auto& child : r.get_child_types()) {
                m(child, from) += prob;
            }
        }
    }
    return m;
}
#endif
757 
764 // double get_expected_length(size_t max_depth=50) const {
765 //
766 // const size_t NT = count_nonterminals();
767 // nonterminal_t start = nt<output_t>();
768 // // we'll build up a NT x max_depth dynamic programming table
769 //
770 // Vector2D<double> tab(NT, max_depth);
771 // tab.fill(0.0);
772 // tab[start,0] = 1; // start with 1
773 //
774 // for(size_t d=1;d<max_depth;d++) {
775 // for(nonterminal_t nt=0;nt<NT;nt++) {
776 //
777 // double z = rule_normalizer(nt);
778 // for(auto& r : rules[nt]) {
779 // double p = r.p / z;
780 // for(auto& to : r.get_child_types()) {
781 // m(to,nt) += p;
782 // }
783 // }
784 //
785 // tab[d][nt] = 0.0;
786 // }
787 // }
788 //
789 //
790 // double l = 0.0;
791 // }
792 //
793 
794  double log_probability(const Node& n) const {
802  double lp = 0.0;
803  for(auto& x : n) {
804  if(x.rule == NullRule) continue;
805  lp += log(x.rule->p) - log(rule_normalizer(x.rule->nt));
806  }
807 
808  return lp;
809  }
810 
811  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
812  // Implementation of converting strings to nodes
813  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
814 
815 
816  Node from_parseable(std::deque<std::string>& q) const {
825  assert(!q.empty() && "*** Should not ever get to here with an empty queue -- are you missing arguments?");
826 
827  auto [nts, pfx] = divide(q.front(), Node::NTDelimiter);
828  q.pop_front();
829 
830  // null rules:
831  if(pfx == NullRule->format)
832  return makeNode(NullRule);
833 
834  // otherwise find the matching rule
835  Rule* r = this->get_rule(stoi(nts), pfx);
836 
837  Node v = makeNode(r);
838  for(size_t i=0;i<r->N;i++) {
839 
840  v.set_child(i, from_parseable(q));
841 
842  if(r->type(i) != v.child(i).rule->nt) {
843  CERR "*** Grammar expected type " << r->type(i) << " but got type " << v.child(i).rule->nt << " at " << r->format << " argument " << i ENDL;
844  assert(false && "Bad names in from_parseable."); // just check that we didn't miss this up
845  }
846 
847  }
848  return v;
849  }
850 
851  Node from_parseable(std::string s) const {
858  std::deque<std::string> stk = split(s, Node::RuleDelimiter);
859  return from_parseable(stk);
860  }
861 
862  Node from_parseable(const char* c) const {
863  std::string s = c;
864  return from_parseable(s);
865  }
866 
867 
868  size_t neighbors(const Node& node) const {
869  // How many neighbors do I have? This is the number of neighbors the first gap has
870  for(size_t i=0;i<node.rule->N;i++){
871  if(node.child(i).is_null()) {
872  return count_rules(node.rule->type(i)); // NOTE: must use rule->child_types since child[i]->rule->nt is always 0 for NullRules
873  }
874  else {
875  auto cn = neighbors(node.child(i));
876  if(cn > 0) return cn; // we return the number of neighbors for the first gap
877  }
878  }
879  return 0;
880  }
881 
882  void expand_to_neighbor(Node& node, int& which) {
883  // here we find the neighbor indicated by which and expand it into the which'th neighbor
884  // to do this, we loop through until which is less than the number of neighbors,
885  // and then it must specify which expansion we want to take. This means that when we
886  // skip a nullptr, we have to subtract from it the number of neighbors (expansions)
887  // we could have taken.
888  for(size_t i=0;i<node.rule->N;i++){
889  if(node.child(i).is_null()) {
890  int c = count_rules(node.rule->type(i));
891  if(which >= 0 and which < c) {
892  auto r = get_rule(node.rule->type(i), (size_t)which);
893  node.set_child(i, makeNode(r));
894  }
895  which -= c;
896  }
897  else { // otherwise we have to process that which
898  expand_to_neighbor(node.child(i), which);
899  }
900  }
901  }
902 
903  double neighbor_prior(const Node& node, int& which) const {
904  // here we find the neighbor indicated by which and expand it into the which'th neighbor
905  // to do this, we loop through until which is less than the number of neighbors,
906  // and then it must specify which expansion we want to take. This means that when we
907  // skip a nullptr, we have to subtract from it the number of neighbors (expansions)
908  // we could have taken.
909  for(size_t i=0;i<node.rule->N;i++){
910  if(node.child(i).is_null()) {
911  int c = count_rules(node.rule->type(i));
912  if(which >= 0 and which < c) {
913  auto r = get_rule(node.rule->type(i), (size_t)which);
914  return log(r->p)-log(rule_normalizer(r->nt));
915  }
916  which -= c;
917  }
918  else { // otherwise we have to process that which
919  auto o = neighbor_prior(node.child(i), which);
920  if(not std::isnan(o)) { // if this child returned something.
921  //assert(which <= 0);
922  return o;
923  }
924  }
925  }
926 
927  return NaN; // if no neighbors
928  }
929 
930  void complete(Node& node) {
931  // go through and fill in the tree at random
932  for(size_t i=0;i<node.rule->N;i++){
933  if(node.child(i).is_null()) {
934  node.set_child(i, generate(node.rule->type(i)));
935  }
936  else {
937  complete(node.child(i));
938  }
939  }
940  }
941 
942 
943 
944  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
945  // Simple parsing routines -- not very well debugged
946  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
947 
// Locate the top-level structure of a call-like string "f(a,b,...)".
// Returns (openpos, commas, closepos):
//   openpos  - index of the first '(' seen at nesting depth 0, or -1 if none;
//   commas   - indices of every ',' at nesting depth 1 (argument separators);
//   closepos - index of the first ')' that closes back to depth 0, or -1.
// Asserts if a depth-1 comma appears after closepos (e.g. "f(a)(b,c)").
// Takes s by const reference: it is called once per parsed sub-term, so copying was wasted work.
std::tuple<int,std::vector<int>,int> find_open_commas_close(const std::string& s) {

    // return values
    int openpos = -1;
    std::vector<int> commas; // all commas with only one open
    int closepos = -1;

    // how many are open?
    int opencount = 0;

    for(size_t i=0;i<s.length();i++){
        const char c = s[i];
        const int pos = static_cast<int>(i);

        if(opencount==0 and openpos==-1 and c=='(') {
            openpos = pos;
        }

        if(opencount==1 and closepos==-1 and c==')') {
            closepos = pos;
        }

        // argument separators are the commas seen while exactly one paren is open
        if(opencount==1 and c==','){
            assert(closepos == -1);
            commas.push_back(pos);
        }

        opencount += (c=='(');
        opencount -= (c==')');
    }

    return std::make_tuple(openpos, commas, closepos);
}
982 
993  Node simple_parse(std::string s) {
994  //print("Parsing ", s);
995 
996  // remove the lambda x. if its there
997  if(s.substr(0,3) == LAMBDAXDOT_STRING) s.erase(0,3);
998  // remove leading whitespace
999  while(s.at(0) == ' ' or s.at(0) == '\t') s.erase(0,1);
1000  // remove trailing whitespace
1001  while(s.at(s.size()-1) == ' ' or s.at(s.size()-1) == '\t') s.erase(s.size()-1,1);
1002 
1003  // use the above function to find chunks
1004  auto [open, commas, close] = find_open_commas_close(s);
1005 
1006  // if it's a terminal
1007  if(open == -1) {
1008  assert(commas.size()==0 and close==-1);
1009  auto r = this->get_rule(s);
1010  return this->makeNode(r);
1011  }
1012  else if(close == open+1) { // special case of "f()"
1013  return this->makeNode(get_rule(s)); // the whole string is what we want and its a terminal
1014  }
1015  else {
1016  assert(close != -1);
1017 
1018  // recover the rule format -- no spaces, etc.
1019  std::string fmt = s.substr(0,open) + "(%s";
1020  for(auto& c : commas) {
1021  UNUSED(c);
1022  fmt += ",%s";
1023  }
1024  fmt += ")";
1025 
1026  // find the rule for this format
1027  auto r = this->get_rule(fmt);
1028  auto out = this->makeNode(r);
1029 
1030  int prev=open+1;
1031  int ci=0;
1032  for(auto& c : commas) {
1033  auto child_string = s.substr(prev,c-prev);
1034  out.set_child(ci,simple_parse(child_string));
1035  prev = c+1;
1036  ci++;
1037  }
1038 
1039  // and the last child
1040  auto child_string = s.substr(prev,close-prev);
1041  out.set_child(ci,simple_parse(child_string));
1042 
1043  return out;
1044  }
1045  }
1046 
1047 
1048 
1049 };
Node from_parseable(const char *c) const
Definition: Grammar.h:862
The Builtin type just stores a function pointer and an Op command. This makes it a bit handier to def...
Definition: Grammar.h:43
std::tuple< int, std::vector< int >, int > find_open_commas_close(const std::string s)
Definition: Grammar.h:948
double neighbor_prior(const Node &node, int &which) const
Definition: Grammar.h:903
Node generate(const nonterminal_t ntfrom=nt< output_t >(), unsigned long depth=0) const
A wrapper to catch DepthExceptions and retry. This means that by default we try to generate GENERATE_D...
Definition: Grammar.h:616
Definition: VirtualMachineState.h:45
std::vector< Rule >::iterator current_rule
Definition: Grammar.h:111
CL output_t
Definition: Grammar.h:47
Definition: Node.h:22
constexpr nonterminal_t start()
The start nonterminal type.
Definition: Grammar.h:167
virtual Rule * get_rule(const nonterminal_t nt, const std::string s) const
Definition: Grammar.h:450
This represents the state of a partial evaluation of a program, corresponding to the value of all of ...
Definition: Grammar.h:107
Node copy_resample(const Node &node, bool f(const Node &n)) const
Definition: Grammar.h:625
RuleIterator end() const
Definition: Grammar.h:162
#define TAB
Definition: IO.h:19
std::atomic< uintmax_t > depth_exceptions(0)
std::vector< size_t > get_counts(const std::map< K, T > &v) const
Definition: Grammar.h:713
size_t count_rules() const
Definition: Grammar.h:188
nonterminal_t type(size_t i) const
Definition: Rule.h:147
std::pair< std::string, std::string > divide(const std::string &s, const char delimiter)
Definition: Strings.h:331
RuleIterator & operator++(int blah)
Definition: Grammar.h:129
Definition: Rule.h:21
RuleIterator & operator+(size_t n)
Definition: Grammar.h:150
Definition: Primitive.h:15
this_t * grammar
Definition: Grammar.h:109
Eigen::MatrixXf Matrix
Definition: EigenLib.h:18
Node from_parseable(std::string s) const
Definition: Grammar.h:851
cl_void input_t
Definition: Grammar.h:46
void add(std::string fmt, T(*_f)(args...), double p=1.0, Op o=Op::Standard, int a=0)
Definition: Grammar.h:367
Node makeNode(const Rule *r) const
Definition: Grammar.h:556
void * f
Definition: Primitive.h:18
RuleIterator begin() const
Definition: Grammar.h:161
Definition: Errors.h:18
std::deque< std::string > split(const std::string &s, const char delimiter)
Split returns a deque of s split up at the character delimiter. It handles these special cases: sp...
Definition: Strings.h:277
virtual Rule * get_rule(const nonterminal_t nt, size_t i)
Definition: Grammar.h:446
DepthException()
Definition: Grammar.h:18
Node from_parseable(std::deque< std::string > &q) const
Definition: Grammar.h:816
Op
Definition: Ops.h:3
const Rule * NullRule
Definition: Rule.h:181
Node simple_parse(std::string s)
Very simple parsing routine that takes a string like "and(not(or(eq_pos(pos(parent(x)),'NP-POSS'),eq_pos('NP-S',pos(x)))),corefers(x))" (from the Binding example) and parses it into a Node.
Definition: Grammar.h:993
Definition: Grammar.h:17
static const char RuleDelimiter
Definition: Node.h:29
size_t N
Definition: Rule.h:28
nonterminal_t current_nt
Definition: Grammar.h:110
std::vector< size_t > get_cumulative_indices() const
Definition: Grammar.h:652
size_t count_nonterminals(nonterminal_t nt) const
Definition: Grammar.h:221
virtual Rule * sample_rule(const nonterminal_t nt) const
Definition: Grammar.h:538
bool is_null() const
Definition: Node.h:167
static const char NTDelimiter
Definition: Node.h:28
constexpr bool contains_type()
Check if a type is contained in parameter pack.
Definition: Miscellaneous.h:135
void push(T &x)
Definition: VirtualMachineState.h:177
std::array< double, N_NTs > Z
Definition: Grammar.h:74
static constexpr nonterminal_t nt()
Definition: Grammar.h:83
Node __generate(const nonterminal_t ntfrom=nt< output_t >(), unsigned long depth=0) const
Definition: Grammar.h:567
void remove_all(nonterminal_t nt)
Remove all the nonterminals of this type from the grammar. NOTE: This is generally a really bad idea ...
Definition: Grammar.h:391
A Node is the primary internal representation for a program – it recursively stores a rule and the a...
std::string QQ(std::string x)
Definition: Strings.h:402
std::string format
Definition: Rule.h:27
size_t count_terminals(nonterminal_t nt) const
Definition: Grammar.h:208
std::function< void(this_t *, int)> FT
Definition: VirtualMachineState.h:55
#define CERR
Definition: IO.h:23
void expand_to_neighbor(Node &node, int &which)
Definition: Grammar.h:882
typename VirtualMachineState_t::FT FT
Definition: Grammar.h:61
size_t neighbors(const Node &node) const
Definition: Grammar.h:868
this_t & child(const size_t i)
Definition: BaseNode.h:166
const Rule * rule
Definition: Node.h:32
unsigned short nonterminal_t
Definition: Nonterminal.h:4
void set_child(const size_t i, Node &n)
Definition: Node.h:88
void change_probability(const std::string &s, const double newp)
Definition: Grammar.h:201
void add_terminal(std::string fmt, T x, double p=1.0, Op o=Op::Standard, int a=0)
Add a variable that is NOT A function – simplification for adding alphabets etc. This just wraps stu...
Definition: Grammar.h:381
double rule_normalizer(const nonterminal_t nt) const
Definition: Grammar.h:527
double log_probability(const Node &n) const
This computes the expected length of productions from this grammar, counting terminals and nontermina...
Definition: Grammar.h:794
const std::string LAMBDAXDOT_STRING
Definition: Strings.h:21
bool is_prefix(const T &prefix, const T &x)
Check if prefix is a prefix of x – works with iterables, including strings and vectors.
Definition: Strings.h:179
double p
Definition: Rule.h:29
constexpr size_t count_nonterminals() const
Definition: Grammar.h:171
#define ENDL
Definition: IO.h:21
bool operator==(const RuleIterator &rhs) const
Definition: Grammar.h:155
static constexpr bool is_in_GRAMMAR_TYPES()
For a given nt, returns the number of finite trees that nt can expand to if its finite; 0 if its infi...
Definition: Grammar.h:271
size_t get_index_of(const Rule *r) const
Definition: Grammar.h:400
constexpr double NaN
Definition: Numerics.h:19
This is a thread_local rng whose first object is used to seed the others (in other threads). This way, we can have thread_local rngs that are all seeded deterministically in Fleet via --seed=X.
size_t count_rules(const nonterminal_t nt) const
Definition: Grammar.h:179
Op op
Definition: Primitive.h:17
std::vector< size_t > get_counts(const Node &node) const
Definition: Grammar.h:665
RuleIterator & operator++()
Definition: Grammar.h:130
std::tuple< GRAMMAR_TYPES... > TypeTuple
Definition: Grammar.h:52
nonterminal_t nt
Definition: Rule.h:26
Grammar()
Definition: Grammar.h:89
A little class that any VirtualMachinePool AND VirtualMachines inherit to control their behavior...
void complete(Node &node)
Definition: Grammar.h:930
std::vector< size_t > get_counts(const std::vector< T > &v) const
Compute a vector of counts of how often each rule was used, in a standard order given by iterating ov...
Definition: Grammar.h:692
void add(std::string fmt, Primitive< T, args... > &b, double p=1.0, int a=0)
Definition: Grammar.h:302
RuleIterator(this_t *g, bool is_end)
Definition: Grammar.h:115
virtual Rule * get_rule(const std::string s) const
Definition: Grammar.h:497
virtual Rule * get_rule(const nonterminal_t nt, size_t k) const
Definition: Grammar.h:415
void add(std::string fmt, std::function< T(args...)> f, double p=1.0, Op o=Op::Standard, int a=0)
Definition: Grammar.h:316
Rule & operator*() const
Definition: Grammar.h:126
virtual Rule * get_rule(const nonterminal_t nt, const Op o, const int a=0)
Definition: Grammar.h:428
std::vector< Rule > rules[N_NTs]
Definition: Grammar.h:73
void add_vms(std::string fmt, FT *f, double p=1.0, Op o=Op::Standard, int a=0)
Definition: Grammar.h:285