Fleet  0.0.9
Inference in the LOT
Classes | Typedefs | Functions | Variables
Main.cpp File Reference
#include <string>
#include <vector>
#include <assert.h>
#include <iostream>
#include <exception>
#include <map>
#include "BindingTree.h"
#include "Grammar.h"
#include "Singleton.h"
#include "LOTHypothesis.h"
#include "Timing.h"
#include "Lexicon.h"
#include "CachedCallHypothesis.h"
#include "Fleet.h"
#include "TopN.h"
#include "MCMCChain.h"
#include "ParallelTempering.h"
#include "SExpression.h"
Include dependency graph for Main.cpp:

Classes

class  TreeException
 

Typedefs

using S = std::string
 

Functions

int main (int argc, char **argv)
 

Variables

std::vector< std::string > words = {"REXP", "him", "his", "he", "himself"}
 
int NDATA = 10
 
MyHypothesis::data_t target_precisionrecall_data
 
MyHypothesis target
 

Typedef Documentation

◆ S

using S = std::string

Function Documentation

◆ main()

int main ( int  argc,
char **  argv 
)

Variable Documentation

◆ NDATA

int NDATA = 10

◆ target

MyHypothesis target

◆ target_precisionrecall_data

MyHypothesis::data_t target_precisionrecall_data
Declare a grammar
This requires template arguments that specify which types the grammar uses (and the order in which they are stored)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ */
// declare a grammar with our primitives
// Note that this ordering of primitives defines the order in Grammar
using Super::Super;
public:
add("null(%s)", +[](BindingTree* x) -> bool {
return x == nullptr;
});
add("parent(%s)", +[](BindingTree* x) -> BindingTree* {
if(x == nullptr) throw TreeException();
return x->parent;
});
add("root(%s)", +[](BindingTree* x) -> BindingTree* {
if(x==nullptr) throw TreeException();
return x->root();
});
// linear order predicates
add("linear(%s)", +[](BindingTree* x) -> int {
if(x==nullptr) throw TreeException();
return x->linear_order;
});
add("gt(%s,%s)", +[](int a, int b) -> bool { return (a>b); });
// equality
add("eq_bool(%s,%s)", +[](bool a, bool b) -> bool { return (a==b); });
add("eq_int(%s,%s)", +[](int a, int b) -> bool { return (a==b); });
add("eq_str(%s,%s)", +[](S a, S b) -> bool { return (a==b); });
add("eq_pos(%s,%s)", +[](POS a, POS b) -> bool { return (a==b); });
add("eq_bt(%s,%s)", +[](BindingTree* x, BindingTree* y) -> bool { return x == y;});
// pos predicates
add("pos(%s)", +[](BindingTree* x) -> POS {
if(x==nullptr) throw TreeException();
return x->pos;
});
// tree predicates
add("coreferent(%s)", +[](BindingTree* x) -> BindingTree* {
if(x==nullptr) throw TreeException();
return x->coreferent();
});
add("corefers(%s)", +[](BindingTree* x) -> bool {
if(x==nullptr) throw TreeException();
return x->coreferent() != nullptr;
});
add("leaf(%s)", +[](BindingTree* x) -> bool {
if(x==nullptr) throw TreeException();
return x->nchildren() == 0;
});
add("word(%s)", +[](BindingTree* x) -> S {
if(x==nullptr) throw TreeException();
if(x->target) throw TreeException(); // well it's very clever -- we can't allow label on the target or the problem is trivial
return x->word;
});
add("dominates(%s,%s)", +[](BindingTree* x, BindingTree* y) -> bool {
// NOTE: a node will dominate itself
if(y == nullptr or x == nullptr)
throw TreeException();
while(true) {
if(y == nullptr) return false; // by definition, null doesn't dominate anything
else if(y == x) return true;
y = y->parent;
}
});
add("first-dominating(%s,%s)", +[](POS s, BindingTree* x) -> BindingTree* {
if(x == nullptr) throw TreeException();
while(true) {
if(x == nullptr) return nullptr; // by definition, null doesn't dominate anything
else if(x->pos == s) return x;
x = x->parent;
}
});
add("true", +[]() -> bool { return true; }, 5);
add("false", +[]() -> bool { return false; }, 5);
add("and(%s,%s)", Builtins::And<MyGrammar>);
add("or(%s,%s)", Builtins::Or<MyGrammar>);
add("not(%s)", Builtins::Not<MyGrammar>);
// add("if(%s,%s,%s)", Builtins::If<MyGrammar,int>);
// add("if(%s,%s,%s)", Builtins::If<MyGrammar,S>);
// add("if(%s,%s,%s)", Builtins::If<MyGrammar,POS>);
// add("if(%s,%s,%s)", Builtins::If<MyGrammar,BindingTree*>);
//
add("x", Builtins::X<MyGrammar>, 10);
for(auto& w : words) {
add_terminal<S>(Q(w), w, 5.0/words.size());
}
for(auto [l,p] : posmap) {
add_terminal<POS>(Q(l), p, 5.0/posmap.size());
}
// for absolute positions (up to 24)
// for(int p=0;p<24;p++) {
// add_terminal<int>(str(p), p, 5.0/24.);
// }
// NOTE THIS DOES NOT WORK WITH THE CACHED VERSION
// add("F(%s,%s)" , Builtins::LexiconRecurse<MyGrammar,S>); // note the number of arguments here
}
public:
using Super::Super; // inherit the constructors
using output_t = Super::output_t;
using data_t = Super::data_t;
/// Copy constructor: copy-initializes Super and then CCH from `other`.
InnerHypothesis(const InnerHypothesis& other)
    : Super(other), CCH(other)
{
}
InnerHypothesis(const InnerHypothesis&& c) : Super(c), CCH(c) { }
/// Copy assignment: assigns the Super part and then the CachedCallHypothesis part from
/// `other`, and returns *this so assignments can be chained.
InnerHypothesis& operator=(const InnerHypothesis& other) {
    this->Super::operator=(other);
    this->CachedCallHypothesis::operator=(other);
    return *this;
}
/// Move assignment.
///
/// The parameter is a non-const rvalue reference. The previous `const InnerHypothesis&&`
/// form could never be moved from, so it always copy-assigned both bases.
/// Each base assignment binds only to its own subobject of `c`, so moving from `c` twice
/// is safe here.
/// NOTE(review): this only avoids copies if Super / CachedCallHypothesis provide move
/// assignment; otherwise it degrades to the same copies as before.
InnerHypothesis& operator=(InnerHypothesis&& c) {
    // Guard against self-move so we never move a base subobject onto itself.
    if(this != &c) {
        Super::operator=(std::move(c));
        CachedCallHypothesis::operator=(std::move(c));
    }
    return *this;
}
/// Forwards `v` unchanged to Super::set_value (lvalue overload).
void set_value(Node& v) {
    this->Super::set_value(v);
}
/// Rvalue overload: hands `v` on to Super::set_value.
/// NOTE(review): inside this function `v` is a named variable and therefore an lvalue, so
/// this call selects whichever Super::set_value overload accepts an lvalue Node. If Super
/// has a Node&& overload and a real move is wanted, this would need std::move(v) -- confirm
/// against Super before changing.
void set_value(Node&& v) {
    this->Super::set_value(v);
}
/// Computes the prior, stores it in this->prior, and returns it.
///
/// A value whose count() exceeds 16 gets a prior of -infinity. This size cap is an
/// important check: without it we spend a great deal of time on really long hypotheses.
/// Any other value gets the grammar's log probability of that value.
virtual double compute_prior() override {
    const bool too_long = (this->value.count() > 16);
    if(!too_long) {
        return this->prior = this->get_grammar()->log_probability(value);
    }
    return this->prior = -infinity;
}
[[nodiscard]] virtual std::pair<InnerHypothesis,double> propose() const override {
std::pair<Node,double> x;
if(flip(0.5)) x = Proposals::regenerate(&grammar, value);
else if(flip(0.1)) x = Proposals::sample_function_leaving_args(&grammar, value);
else if(flip(0.1)) x = Proposals::swap_args(&grammar, value);
else if(flip()) x = Proposals::insert_tree(&grammar, value);
else x = Proposals::delete_tree(&grammar, value);
return std::make_pair(InnerHypothesis(std::move(x.first)), x.second);
}
};
/**
 * A lexicon whose factors are InnerHypothesis objects, looked up by word.
 *
 * compute_likelihood scores data using each factor's cached per-datum result, and print
 * reports how closely those cached results agree with the global `target` hypothesis.
 */
class MyHypothesis : public Lexicon<MyHypothesis, std::string, InnerHypothesis, BindingTree*, std::string> {
    // Takes a node (in a bigger tree) and a word
    using Super::Super; // inherit the constructors
public:

    /**
     * Noisy size-principle likelihood of `data` under this lexicon.
     *
     * For each datum, counts how many words in `words` have a true cached result (ntrue)
     * and whether the observed word d.output is among them (wtrue). The datum contributes
     *     NDATA * log( (wtrue ? reliability/ntrue : 0) + (1-reliability)/words.size() ).
     *
     * @param data     the data to score; factor caches are refreshed on it first.
     * @param breakout once the running total falls below this, stop early and return -infinity.
     * @return the value stored in `likelihood`; -infinity if any factor reported an error
     *         while its cache was being filled, or if the breakout bound was crossed.
     */
    double compute_likelihood(const data_t& data, const double breakout=-infinity) override {

        // TODO: Can set to null so that we get an error on recurse
        // Every loader is set before any factor is evaluated, so this stays a separate pass.
        for(auto& [k, f] : factors) f.program.loader = this;

        // make sure everyone's cache is right on this data
        for(auto& [k, f] : factors) {
            //f.clear_cache(); // if we always want to recompute (e.g. if using recursion)
            f.cached_callOne(data);

            // now if anything threw an error, break out, we don't have to compute
            if(f.got_error)
                return likelihood = -infinity;
        }

        // The likelihood here samples from all words that are true
        likelihood = 0.0;
        for(size_t di=0;di<data.size();di++) {
            auto& d = data[di];

            // see how many factors are true:
            bool wtrue = false; // was w true?
            int ntrue = 0; // how many are true?

            // see which words are permitted
            for(const auto& w : words) {
                auto b = factors[w].cache.at(di);
                ntrue += 1*b;
                if(d.output == w) wtrue = b;
            }

            // Noisy size-principle likelihood.
            // ntrue is only used as a divisor when wtrue is set, which implies ntrue >= 1.
            likelihood += NDATA * log( (wtrue ? d.reliability/ntrue : 0.0) +
                                       (1.0-d.reliability)/words.size());

            if(likelihood < breakout) return likelihood = -infinity;
        }

        return likelihood;
    }

    /**
     * Prints one line describing this hypothesis, holding output_lock while writing.
     *
     * The line holds `prefix`, the posterior, prior and likelihood, then for each word the
     * fraction of items in target_precisionrecall_data on which this hypothesis' cached
     * result equals `target`'s cached result, and finally QQ(this->string()).
     *
     * NOTE(review): this reads cache entries by index, so it presumably expects both this
     * hypothesis and `target` to have caches filled on target_precisionrecall_data -- confirm
     * against the caller. If that data is empty, the fraction printed is 0/0.
     *
     * @param prefix text written at the start of the line.
     */
    virtual void print(std::string prefix="") {
        std::lock_guard guard(output_lock);

        COUT std::setprecision(5) << prefix << this->posterior TAB this->prior TAB this->likelihood TAB "";

        // The denominator is just the number of target items, so compute it once.
        const size_t ntot = target_precisionrecall_data.size();

        // when we print, we are going to compute overlap with each target item
        for(const auto& w : words) {
            int nagree = 0;

            for(size_t di=0;di<ntot;di++) {
                if(factors[w].cache.at(di) == target.factors[w].cache.at(di)) {
                    ++nagree;
                }
            }

            COUT float(nagree)/float(ntot) TAB "";
        }

        COUT QQ(this->string()) ENDL;
    }
};
/*

Main code ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

◆ words

std::vector<std::string> words = {"REXP", "him", "his", "he", "himself"}