43 template<
typename _input_t,
typename _output_t,
typename... GRAMMAR_TYPES>
56 static constexpr
size_t N_NTs = std::tuple_size<TypeTuple>::value;
66 static const size_t GENERATE_DEPTH_EXCEPTION_RETRIES = 1000;
74 std::vector<Rule> rules[N_NTs];
75 std::array<double,N_NTs>
Z;
77 size_t GRAMMAR_MAX_DEPTH = 64;
85 static_assert(
sizeof...(GRAMMAR_TYPES) > 0,
"*** Cannot use empty grammar types here");
86 static_assert(contains_type<T, GRAMMAR_TYPES...>(),
"*** The type T (decayed) must be in GRAMMAR_TYPES");
91 for(
size_t i=0;i<N_NTs;i++) {
111 using iterator_category = std::forward_iterator_tag;
113 using difference_type = int;
127 current_rule = g->
rules[0].begin();
132 current_nt = N_NTs-1;
133 current_rule = grammar->
rules[current_nt].end();
146 while( current_rule == grammar->
rules[current_nt].end() ) {
147 if(current_nt < grammar->N_NTs-1) {
149 current_rule = grammar->
rules[current_nt].begin();
152 current_rule = grammar->
rules[current_nt].end();
161 for(
size_t i=0;i<n;i++) this->
operator++();
171 RuleIterator begin() const {
    // Hands back an iterator built with is_end == false, i.e. the "not at the
    // end" state of RuleIterator(this_t* g, bool is_end).
    // RuleIterator keeps a non-const this_t* to its grammar, so the constness
    // added by this const member function is cast away before passing `this`.
    this_t* self = const_cast<this_t*>(this);
    return RuleIterator(self, false);
}
172 RuleIterator end() const {
    // Hands back an iterator built with is_end == true, the past-the-end
    // counterpart of begin() for RuleIterator(this_t* g, bool is_end).
    // RuleIterator keeps a non-const this_t* to its grammar, so the constness
    // added by this const member function is cast away before passing `this`.
    this_t* self = const_cast<this_t*>(this);
    return RuleIterator(self, true);  // the stray second ';' (empty statement) is removed
}
178 return nt<output_t>();
195 assert(nt >= 0 and nt < N_NTs);
196 return rules[nt].size();
205 for(
size_t i=0;i<N_NTs;i++) {
212 Rule* r = get_rule(s);
226 for(
auto& r : rules[nt]) {
227 if(r.is_terminal()) n++;
239 for(
auto& r : rules[nt]) {
240 if(not r.is_terminal()) n++;
284 return contains_type<
X,GRAMMAR_TYPES...>();
294 template<
typename T,
typename... args>
296 assert(f !=
nullptr &&
"*** If you're passing a null f to add_vms, you've really screwed up.");
299 Rule r(Tnt, (
void*)f, fmt, {nt<args>()...}, p, o, a);
301 auto pos = std::lower_bound( rules[Tnt].begin(), rules[Tnt].end(), r);
302 rules[Tnt].insert( pos, r );
311 template<
typename T,
typename... args>
314 assert(b.
f !=
nullptr);
315 add_vms<T,args...>(fmt, (
FT*)b.
f, p, b.
op, a);
325 template<
typename T,
typename... args>
326 void add(std::string fmt, std::function<T(args...)> f,
double p=1.0,
Op o=
Op::Standard,
int a=0) {
329 static_assert((not std::is_reference<T>::value) &&
"*** Primitives cannot return references.");
330 static_assert((not std::is_reference<args>::value && ...) &&
"*** Arguments cannot be references.");
331 static_assert(is_in_GRAMMAR_TYPES<T>() ,
"*** Return type is not in GRAMMAR_TYPES");
332 static_assert((is_in_GRAMMAR_TYPES<args>() && ...),
"*** Argument type is not in GRAMMAR_TYPES");
358 assert(vms !=
nullptr);
360 if constexpr (
sizeof...(args) == 0){
363 else if constexpr (
sizeof...(args) == 1) {
364 auto a0 = vms->template getpop_nth<0,args...>();
365 vms->
push(f(std::move(a0)));
367 else if constexpr (
sizeof...(args) == 2) {
368 auto a1 = vms->template getpop_nth<1,args...>();
369 auto a0 = vms->template getpop_nth<0,args...>();
370 vms->
push(f(std::move(a0), std::move(a1)));
372 else if constexpr (
sizeof...(args) == 3) {
373 auto a2 = vms->template getpop_nth<2,args...>();
374 auto a1 = vms->template getpop_nth<1,args...>();
375 auto a0 = vms->template getpop_nth<0,args...>();
376 vms->
push(f(std::move(a0), std::move(a1), std::move(a2)));
378 else if constexpr (
sizeof...(args) == 4) {
379 auto a3 = vms->template getpop_nth<3,args...>();
380 auto a2 = vms->template getpop_nth<2,args...>();
381 auto a1 = vms->template getpop_nth<1,args...>();
382 auto a0 = vms->template getpop_nth<0,args...>();
383 vms->
push(f(std::move(a0), std::move(a1), std::move(a2), std::move(a3)));
385 else if constexpr (
sizeof...(args) == 5) {
386 auto a4 = vms->template getpop_nth<4,args...>();
387 auto a3 = vms->template getpop_nth<3,args...>();
388 auto a2 = vms->template getpop_nth<2,args...>();
389 auto a1 = vms->template getpop_nth<1,args...>();
390 auto a0 = vms->template getpop_nth<0,args...>();
391 vms->
push(f(std::move(a0), std::move(a1), std::move(a2), std::move(a3), std::move(a4)));
393 else if constexpr (
sizeof...(args) == 6) {
394 auto a5 = vms->template getpop_nth<5,args...>();
395 auto a4 = vms->template getpop_nth<4,args...>();
396 auto a3 = vms->template getpop_nth<3,args...>();
397 auto a2 = vms->template getpop_nth<2,args...>();
398 auto a1 = vms->template getpop_nth<1,args...>();
399 auto a0 = vms->template getpop_nth<0,args...>();
400 vms->
push(f(std::move(a0), std::move(a1), std::move(a2), std::move(a3), std::move(a4), std::move(a5)));
403 print(
"*** Error -- too many arguments for a function. Must be updated in Grammar.h ",
sizeof...(args) );
409 add_vms<T,args...>(fmt, fvms, p, o, a);
418 template<
typename T,
typename... args>
420 add<T,args...>(fmt, std::function<T(args...)>(_f), p, o, a);
436 add(fmt, std::function( [=]()->T {
return x; }), p, o, a);
450 template<
typename T,
typename... args>
452 std::function f = _f;
456 add_terminal<
ft<T,args...>>(fmt, f, p, o, a);
481 for(
size_t i=0;i<rules[r->
nt].size();i++) {
482 if(*get_rule(r->
nt,i) == *r) {
498 assert(k < rules[nt].size());
499 return const_cast<Rule*
>(&rules[nt][k]);
511 assert(nt >= 0 and nt < N_NTs);
512 for(
auto& r: rules[nt]) {
514 if(r.is_a(o) and r.arg == a)
521 return &rules[nt].at(i);
536 bool was_partial_match =
true;
539 for(
auto& r: rules[nt]) {
542 if(ret !=
nullptr and not was_partial_match) {
543 CERR "*** Multiple rules found matching " << s
TAB r.format
ENDL;
547 was_partial_match =
false;
548 ret =
const_cast<Rule*
>(&r);
551 else if( was_partial_match and ((s !=
"" and
is_prefix(s, r.format)) or (s==
"" and s==r.format))) {
553 CERR "*** Multiple rules found matching " << s
TAB r.format
ENDL;
557 ret =
const_cast<Rule*
>(&r);
580 for(
auto& r : *
this) {
581 if( (s !=
"" and
is_prefix(s, r.format)) or (s==
"" and s==r.format)) {
583 CERR "*** Multiple rules found matching " << s
TAB r.format
ENDL;
590 if(ret !=
nullptr) {
return ret; }
619 std::function<double(const Rule& r)> f = [](
const Rule& r){
return r.
p;};
620 if(rules[nt].size() == 0) {
621 print(
"Failed nonterminal, not in grammar:", nt);
622 assert(
false &&
"*** You are trying to sample from a nonterminal with no rules!");
624 return sample<Rule,std::vector<Rule>>(rules[nt], Z[nt], f).first;
640 return Node(r, log(r->
p)-log(rule_normalizer(r->
nt)));
654 if(depth >= GRAMMAR_MAX_DEPTH) {
655 #ifdef WARN_DEPTH_EXCEPTION 656 CERR "*** Grammar exceeded max depth, are you sure the grammar probabilities are right?" ENDL;
657 CERR "*** You might be able to figure out what's wrong with gdb and then looking at the backtrace of" ENDL;
658 CERR "*** which nonterminals are called." ENDL;
659 CERR "*** Or.... maybe this nonterminal does not rewrite to a terminal?" ENDL;
664 Rule* r = sample_rule(ntfrom);
665 Node n = makeNode(r);
670 for(
size_t i=0;i<r->
N;i++) {
675 #ifdef WARN_DEPTH_EXCEPTION 694 for(
size_t tries=0;tries<GENERATE_DEPTH_EXCEPTION_RETRIES;tries++) {
696 return __generate(ntfrom, depth);
699 assert(
false &&
"*** Generate failed due to repeated depth exceptions");
712 return generate(node.
rule->
nt);
718 for(
size_t i=0;i<ret.nchildren();i++) {
719 ret.
set_child(i, copy_resample(ret.child(i), f));
735 std::map<const Rule*, size_t> out;
737 const size_t NT = count_nonterminals();
738 for(
size_t nt=0;nt<NT;nt++) {
739 for(
const auto& r : rules[nt]) {
753 auto idx = get_rule_indexer();
754 return get_counts(node,idx);
762 std::vector<size_t>
get_counts(
const Node& node,
const std::map<const Rule*,size_t>& indexer)
const {
764 std::vector<size_t> out(count_rules(),0);
765 for(
auto& n : node) {
767 out[indexer.at(n.rule)] += 1;
779 template<
typename K,
typename V>
780 std::vector<size_t>
get_counts(
const std::map<K,V>& m,
const std::map<const Rule*,size_t>& indexer)
const {
782 std::vector<size_t> out(count_rules(),0);
784 for(
const auto& [key,fac] : m) {
785 auto c = get_counts(fac.get_value(), indexer);
786 for(
size_t r=0;r<c.size();r++)
795 #ifdef AM_I_USING_EIGEN 796 Matrix get_nonterminal_transition_matrix() {
797 const size_t NT = count_nonterminals();
798 Matrix m = Matrix::Zero(NT,NT);
799 for(
size_t nt=0;nt<NT;nt++) {
800 double z = rule_normalizer(nt);
801 for(
auto& r : rules[nt]) {
803 for(
auto& to : r.get_child_types()) {
860 lp += log(x.rule->p) - log(rule_normalizer(x.rule->nt));
880 assert(!q.empty() &&
"*** Should not ever get to here with an empty queue -- are you missing arguments?");
890 Rule* r = this->get_rule(stoi(nts), pfx);
892 Node v = makeNode(r);
893 for(
size_t i=0;i<r->
N;i++) {
899 assert(
false &&
"Bad names in from_parseable.");
914 return from_parseable(stk);
919 return from_parseable(s);
925 for(
size_t i=0;i<node.
rule->
N;i++){
927 return count_rules(node.
rule->
type(i));
930 auto cn = neighbors(node.
child(i));
931 if(cn > 0)
return cn;
943 for(
size_t i=0;i<node.
rule->
N;i++){
945 int c = count_rules(node.
rule->
type(i));
946 if(which >= 0 and which < c) {
947 auto r = get_rule(node.
rule->
type(i), (size_t)which);
953 expand_to_neighbor(node.
child(i), which);
964 for(
size_t i=0;i<node.
rule->
N;i++){
966 int c = count_rules(node.
rule->
type(i));
967 if(which >= 0 and which < c) {
968 auto r = get_rule(node.
rule->
type(i), (size_t)which);
969 return log(r->p)-log(rule_normalizer(r->nt));
974 auto o = neighbor_prior(node.
child(i), which);
975 if(not std::isnan(o)) {
987 for(
size_t i=0;i<node.
rule->
N;i++){
992 complete(node.
child(i));
1007 std::vector<int> commas;
1013 for(
size_t i=0;i<s.length();i++){
1017 if(opencount==0 and openpos==-1 and c==
'(') {
1021 if(opencount==1 and closepos==-1 and c==
')') {
1026 if(opencount==1 and c==
','){
1027 assert(closepos == -1);
1028 commas.push_back(i);
1031 opencount += (c==
'(');
1032 opencount -= (c==
')');
1035 return std::make_tuple(openpos, commas, closepos);
1054 while(s.at(0) ==
' ' or s.at(0) ==
'\t') s.erase(0,1);
1056 while(s.at(s.size()-1) ==
' ' or s.at(s.size()-1) ==
'\t') s.erase(s.size()-1,1);
1059 auto [open, commas, close] = find_open_commas_close(s);
1063 assert(commas.size()==0 and close==-1);
1064 auto r = this->get_rule(s);
1065 return this->makeNode(r);
1067 else if(close == open+1) {
1068 return this->makeNode(get_rule(s));
1071 assert(close != -1);
1074 std::string fmt = s.substr(0,open) +
"(%s";
1075 for(
auto& c : commas) {
1082 auto r = this->get_rule(fmt);
1083 auto out = this->makeNode(r);
1087 for(
auto& c : commas) {
1088 auto child_string = s.substr(prev,c-prev);
1089 out.set_child(ci,simple_parse(child_string));
1095 auto child_string = s.substr(prev,close-prev);
1096 out.set_child(ci,simple_parse(child_string));
Node from_parseable(const char *c) const
Definition: Grammar.h:917
The Primitive type just stores a function pointer and an Op command.
std::tuple< int, std::vector< int >, int > find_open_commas_close(const std::string s)
Definition: Grammar.h:1003
double neighbor_prior(const Node &node, int &which) const
Definition: Grammar.h:958
void UNUSED(const T &x)
Definition: Miscellaneous.h:38
Node generate(const nonterminal_t ntfrom=nt< output_t >(), unsigned long depth=0) const
A wrapper to catch DepthExceptions and retry. This means that by default we try to generate GENERATE_D...
Definition: Grammar.h:693
std::string QQ(const std::string &x)
Definition: Strings.h:190
Definition: VirtualMachineState.h:46
std::vector< Rule >::iterator current_rule
Definition: Grammar.h:121
Grammar< MyInput, bool, bool, MyObject, MyInput, ObjectSet, ft< bool, bool >, ft< bool, bool, bool >, ft< bool, MyObject >, ft< bool, MyObject, MyObject >, ft< bool, bool, MyObject > >::output_t bool output_t
Definition: Grammar.h:48
constexpr nonterminal_t start()
The start nonterminal type.
Definition: Grammar.h:177
virtual Rule * get_rule(const nonterminal_t nt, const std::string s) const
Definition: Grammar.h:524
This represents the state of a partial evaluation of a program, corresponding to the value of all of ...
Definition: Grammar.h:108
Node copy_resample(const Node &node, bool f(const Node &n)) const
Definition: Grammar.h:702
RuleIterator end() const
Definition: Grammar.h:172
#define TAB
Definition: IO.h:19
std::atomic< uintmax_t > depth_exceptions(0)
size_t count_rules() const
Definition: Grammar.h:198
nonterminal_t type(size_t i) const
Definition: Rule.h:152
std::pair< std::string, std::string > divide(const std::string &s, const char delimiter)
Definition: Strings.h:144
RuleIterator & operator++(int blah)
Definition: Grammar.h:139
RuleIterator & operator+(size_t n)
Definition: Grammar.h:160
Definition: Primitive.h:13
this_t * grammar
Definition: Grammar.h:119
Eigen::MatrixXf Matrix
Definition: EigenLib.h:18
Node from_parseable(std::string s) const
Definition: Grammar.h:906
Grammar< MyInput, bool, bool, MyObject, MyInput, ObjectSet, ft< bool, bool >, ft< bool, bool, bool >, ft< bool, MyObject >, ft< bool, MyObject, MyObject >, ft< bool, bool, MyObject > >::input_t MyInput input_t
Definition: Grammar.h:47
std::deque< std::string > split(const std::string &s, const char delimiter)
Split returns a deque of s split up at the character delimiter. It handles these special cases: sp...
Definition: str.h:50
void add(std::string fmt, T(*_f)(args...), double p=1.0, Op o=Op::Standard, int a=0)
Wrapper for add to use function pointers.
Definition: Grammar.h:419
Node makeNode(const Rule *r) const
Definition: Grammar.h:633
void * f
Definition: Primitive.h:16
RuleIterator begin() const
Definition: Grammar.h:171
virtual Rule * get_rule(const nonterminal_t nt, size_t i)
Definition: Grammar.h:520
DepthException()
Definition: Grammar.h:19
Node from_parseable(std::deque< std::string > &q) const
Definition: Grammar.h:871
const Rule * NullRule
Definition: Rule.h:186
Node simple_parse(std::string s)
Very simple parsing routine that takes a string like "and(not(or(eq_pos(pos(parent(x)),'NP-POSS'),eq_pos('NP-S',pos(x)))),corefers(x))" (from the Binding example) and parses it into a Node.
Definition: Grammar.h:1048
static const char RuleDelimiter
Definition: Node.h:29
size_t N
Definition: Rule.h:29
nonterminal_t current_nt
Definition: Grammar.h:120
size_t count_nonterminals(nonterminal_t nt) const
Definition: Grammar.h:231
virtual Rule * sample_rule(const nonterminal_t nt) const
Definition: Grammar.h:612
void print(FIRST f, ARGS... args)
Lock output_lock and print to std::cout.
Definition: IO.h:53
bool is_null() const
Definition: Node.h:165
static const char NTDelimiter
Definition: Node.h:28
void push(T &x)
Definition: VirtualMachineState.h:184
std::vector< size_t > get_counts(const std::map< K, V > &m, const std::map< const Rule *, size_t > &indexer) const
Support for map so we can call on Lexicon::get_value.
Definition: Grammar.h:780
std::array< double, N_NTs > Z
Definition: Grammar.h:75
static constexpr nonterminal_t nt()
Definition: Grammar.h:84
Node __generate(const nonterminal_t ntfrom=nt< output_t >(), unsigned long depth=0) const
Definition: Grammar.h:644
void remove_all(nonterminal_t nt)
Remove all the nonterminals of this type from the grammar. NOTE: This is generally a really bad idea ...
Definition: Grammar.h:464
A Node is the primary internal representation for a program – it recursively stores a rule and the a...
std::string format
Definition: Rule.h:28
size_t count_terminals(nonterminal_t nt) const
Definition: Grammar.h:218
std::function< void(this_t *, int)> FT
Definition: VirtualMachineState.h:56
#define CERR
Definition: IO.h:23
void expand_to_neighbor(Node &node, int &which)
Definition: Grammar.h:937
Grammar< MyInput, bool, bool, MyObject, MyInput, ObjectSet, ft< bool, bool >, ft< bool, bool, bool >, ft< bool, MyObject >, ft< bool, MyObject, MyObject >, ft< bool, bool, MyObject > >::FT typename VirtualMachineState_t::FT FT
Definition: Grammar.h:62
size_t neighbors(const Node &node) const
Definition: Grammar.h:923
this_t & child(const size_t i)
Definition: BaseNode.h:175
const Rule * rule
Definition: Node.h:32
unsigned short nonterminal_t
Definition: Nonterminal.h:4
void set_child(const size_t i, Node &n)
Definition: Node.h:88
void change_probability(const std::string &s, const double newp)
Definition: Grammar.h:211
void add_terminal(std::string fmt, T x, double p=1.0, Op o=Op::Standard, int a=0)
Add a variable that is NOT A function – simplification for adding alphabets etc. This just wraps stu...
Definition: Grammar.h:435
double rule_normalizer(const nonterminal_t nt) const
Definition: Grammar.h:601
double log_probability(const Node &n) const
This computes the expected length of productions from this grammar, counting terminals and nontermina...
Definition: Grammar.h:849
const std::string LAMBDAXDOT_STRING
Definition: Strings.h:20
bool is_prefix(const T &prefix, const T &x)
Check if prefix is a prefix of x – works with iterables, including strings and vectors.
Definition: Strings.h:39
double p
Definition: Rule.h:30
constexpr size_t count_nonterminals() const
Definition: Grammar.h:181
#define ENDL
Definition: IO.h:21
void add_ft(std::string fmt, T(*_f)(args...), double p=1.0, Op o=Op::Standard, int a=0)
Adds this as a function type (see Function.h) rather than as a function itself. For example...
Definition: Grammar.h:451
bool operator==(const RuleIterator &rhs) const
Definition: Grammar.h:165
static constexpr bool is_in_GRAMMAR_TYPES()
For a given nt, returns the number of finite trees that nt can expand to if it's finite; 0 if it's infi...
Definition: Grammar.h:281
size_t get_index_of(const Rule *r) const
Definition: Grammar.h:473
constexpr double NaN
Definition: Numerics.h:21
This is a thread_local rng whose first object is used to seed others (in other threads). This way, we can have thread_local rngs that all are seeded deterministically in Fleet via --seed=X.
size_t count_rules(const nonterminal_t nt) const
Definition: Grammar.h:189
Op op
Definition: Primitive.h:15
std::vector< size_t > get_counts(const Node &node) const
Compute a vector of counts of how often each rule was used, in a standard order given by iterating ov...
Definition: Grammar.h:752
RuleIterator & operator++()
Definition: Grammar.h:140
Grammar< MyInput, bool, bool, MyObject, MyInput, ObjectSet, ft< bool, bool >, ft< bool, bool, bool >, ft< bool, MyObject >, ft< bool, MyObject, MyObject >, ft< bool, bool, MyObject > >::TypeTuple std::tuple< GRAMMAR_TYPES... > TypeTuple
Definition: Grammar.h:53
nonterminal_t nt
Definition: Rule.h:27
Grammar()
Definition: Grammar.h:90
std::function< out(args...)> ft
Definition: Functional.h:14
A little class that any VirtualMachinePool AND VirtualMachines inherit to control their behavior...
void complete(Node &node)
Definition: Grammar.h:985
const std::map< const Rule *, size_t > get_rule_indexer() const
Returns a map from rule pointers to indices in e.g. a vector, so that every rule has a unique index a...
Definition: Grammar.h:734
Helpers to Find the numerical index (as a nonterminal_t) in a tuple of a given type.
Definition: Miscellaneous.h:105
void add(std::string fmt, Primitive< T, args... > &b, double p=1.0, int a=0)
Definition: Grammar.h:312
RuleIterator(this_t *g, bool is_end)
Definition: Grammar.h:125
virtual Rule * get_rule(const std::string s) const
Definition: Grammar.h:571
virtual Rule * get_rule(const nonterminal_t nt, size_t k) const
Definition: Grammar.h:489
void add(std::string fmt, std::function< T(args...)> f, double p=1.0, Op o=Op::Standard, int a=0)
Definition: Grammar.h:326
bool contains(const std::string &s, const std::string &x)
Definition: Strings.h:53
std::vector< size_t > get_counts(const Node &node, const std::map< const Rule *, size_t > &indexer) const
Compute a vector of counts of how often each rule was used, using indexer to map each rule to an inde...
Definition: Grammar.h:762
Rule & operator*() const
Definition: Grammar.h:136
virtual Rule * get_rule(const nonterminal_t nt, const Op o, const int a=0)
Definition: Grammar.h:502
std::vector< Rule > rules[N_NTs]
Definition: Grammar.h:74
void add_vms(std::string fmt, FT *f, double p=1.0, Op o=Op::Standard, int a=0)
Definition: Grammar.h:295