Fleet  0.0.9
Inference in the LOT
BaseNode.h
Go to the documentation of this file.
1 #pragma once
2 
#include<algorithm>
#include<cassert>
#include<cstddef>
#include<functional>
#include<iterator>
#include<string>
#include<utility>
#include<vector>

#include "Errors.h"
#include "IO.h"
8 
19 template<typename this_t>
20 class BaseNode {
21 
22 protected:
23  std::vector<this_t> children;
24 
25 public:
26  this_t* parent;
27  size_t pi; // what index am I in the parent?
28 
35  BaseNode(size_t n=0, this_t* p=nullptr, size_t i=0) : children(n), parent(p), pi(i) {
36  }
37 
38  BaseNode(const this_t& n) {
39  children = n.children;
40  parent = n.parent;
41  pi = n.pi;
43  }
44  BaseNode(this_t&& n) {
45  parent = n.parent;
46  pi = n.pi;
47  children = std::move(n.children);
49  }
50 
51  void operator=(const this_t& t) {
52  parent = t.parent;
53  pi = t.pi;
54  children = t.children;
56  }
57  void operator=(const this_t&& t) {
58  parent = t.parent;
59  pi = t.pi;
60  children = std::move(t.children);
62  }
63 
64 
68  void make_root() {
69  parent = nullptr;
70  pi = 0;
71  }
72 
73  virtual ~BaseNode() {}
74 
75  // Functions to be defined by subclasses
76  virtual std::string string(bool usedot=true) const {
77  throw YouShouldNotBeHereError("*** BaseNode subclass has no defined string()");
78  }
79  virtual std::string my_string() const {
80  throw YouShouldNotBeHereError("*** BaseNode subclass has no defined my_string()");
81  }
82  virtual bool operator==(const this_t& n) const {
83  throw YouShouldNotBeHereError("*** BaseNode subclass has no defined operator==");
84  }
85 
87 
95  class NodeIterator : public std::iterator<std::forward_iterator_tag, this_t> {
96  // Define an iterator class to make managing trees easier.
97  // This iterates in postfix order, which is standard in the library
98  // because it is the order of linearization
99  protected:
100  this_t* current;
101 
102  // we need to store the start because when we start at a subtree, we want to iteratre through its subnodes
103  // and so we need to know when to stop
104  const this_t* start;
105 
106  public:
107 
108  NodeIterator(const this_t* n) : current( n!= nullptr ? n->left_descend() : nullptr), start(n) { }
109  this_t& operator*() const {
110  assert(current != nullptr);
111  return *current;
112  }
113  this_t* operator->() const {
114  assert(current != nullptr);
115  return current;
116  }
117 
118  NodeIterator& operator++(int blah) { this->operator++(); return *this; }
120  assert(current != nullptr);
121  assert(not (*this == EndNodeIterator) && "Can't iterate past the end!");
122  if(current == nullptr or current == start or current->is_root()) {
123  current = nullptr;
124  return EndNodeIterator;
125  }
126 
127  // go through the children if we can
128  if(current->pi+1 < current->parent->children.size()) {
129  current = current->parent->children[current->pi+1].left_descend();
130  }
131  else { // now we call the parent (if we're out of children)
132  current = current->parent;
133  }
134  return *this;
135  }
136 
137  NodeIterator& operator+(size_t n) {
138  for(size_t i=0;i<n;i++) this->operator++();
139  return *this;
140  }
141 
142  bool operator==(const NodeIterator& rhs) const { return current == rhs.current; }
143  bool operator!=(const NodeIterator& rhs) const { return not(current == rhs.current); }
144  };
145  static NodeIterator EndNodeIterator; // defined below
147 
148  NodeIterator begin() const { return BaseNode<this_t>::NodeIterator(static_cast<const this_t*>(this)); }
150 
151  virtual bool operator!=(const this_t& n) const{
152  return not (*this == n);
153  }
154 
155  void reserve_children(const size_t n) {
156  children.reserve(n);
157  }
158 
159  decltype(children)& get_children() {
160  return children;
161  }
162  const decltype(children)& get_children() const {
163  return children;
164  }
165 
166  this_t& child(const size_t i) {
173  return children.at(i);
174  }
175 
176  const this_t& child(const size_t i) const {
183  return children.at(i);
184  }
185 
186  template<typename... Args>
187  void fill(size_t n, Args... args) {
192  // ensure that all of my children are empty nodes
193  for(size_t i=0;i<n;i++) {
194  set_child(i, this_t(args...));
195  }
196  }
197 
198 
199  size_t nchildren() const {
205  return children.size();
206  }
207 
208  this_t* left_descend() const {
214  this_t* k = const_cast<this_t*>(static_cast<const this_t*>(this));
215  while(k != nullptr and k->nchildren() > 0) {
216  k = &(k->child(0));
217  }
218  return k;
219  }
220 
221  size_t depth() const {
222  size_t d = 0;
223  for(const auto& c: children) {
224  d = std::max(d, c.depth()+1);
225  }
226  return d;
227  }
228 
229  void fix_child_info() {
234  // go through children and assign their parents to me
235  // and fix their pi's
236  size_t i = 0;
237  for(auto& c : children) {
238  c.pi = i;
239  c.parent = static_cast<this_t*>(this);
240  i++;
241  }
242  }
243 
244  void check_child_info() const {
249  size_t i = 0;
250  for(const auto& c : children) {
251 
252  // check that the kids point to the right things
253  assert(c.pi == i);
254  assert(c.parent == this);
255  i++;
256  }
257  }
258 
259 
260  this_t& operator[](const size_t i) {
267  return children.at(i); // at does bounds checking
268  }
269 
270  const this_t& operator[](const size_t i) const {
271  return children.at(i);
272  }
273 
274  void set_child(const size_t i, this_t& n) {
281  while(children.size() <= i) // make it big enough for i
282  children.push_back(this_t());
283 
284  children[i] = n;
285  children[i].pi = i;
286  children[i].parent = static_cast<this_t*>(this);
287  }
288  void set_child(const size_t i, this_t&& n) {
294  // NOTE: if you add anything fancy to this, be sure to update the copy and move constructors
295 
296  while(children.size() <= i) // make it big enough for i
297  children.push_back(this_t());
298 
299  children[i] = n;
300  children[i].pi = i;
301  children[i].parent = static_cast<this_t*>(this);
302  }
303 
304  void push_back(this_t& n) {
305  set_child(children.size(), n);
306  }
307  void push_back(this_t&& n) {
308  set_child(children.size(), n);
309  }
310 
315  virtual bool is_root() const {
316  return parent == nullptr;
317  }
318 
319 
324  this_t* root() {
325  this_t* x = static_cast<this_t*>(this);
326  while(x->parent != nullptr) {
327  x = x->parent;
328  }
329  return x;
330  }
331 
337  this_t* get_via(std::function<bool(this_t&)>& f ) {
338  for(auto& n : *this) {
339  if(f(n)) return &n;
340  }
341  return nullptr;
342  }
343 
349  virtual size_t count() const {
350 
351  size_t n=1; // me
352  for(auto& c : children) {
353  n += c.count();
354  }
355  return n;
356  }
357 
363  virtual size_t count(const this_t& n) const {
364  size_t cnt = (n == *static_cast<const this_t*>(this));
365  for(auto& c : children) {
366  cnt += c.count(n);
367  }
368  return cnt;
369  }
370 
375  virtual bool is_terminal() const {
376  return children.size() == 0;
377  }
378 
379  virtual size_t count_terminals() const {
385  size_t cnt = 0;
386  for(const auto& n : *this) {
387  if(n.is_terminal()) ++cnt;
388  }
389  return cnt;
390  }
391 
392  virtual this_t* get_nth(int n, std::function<int(const this_t&)>& f) {
400  for(auto& x : *this) {
401  if(f(x)) {
402  if(n == 0) return &x;
403  else --n;
404  }
405  }
406 
407  return nullptr; // not here, losers
408  }
409  virtual this_t* get_nth(int n) { // default true on every node
410  std::function<int(const this_t&)> f = [](const this_t& x) { return 1;};
411  return get_nth(n, f);
412  }
413 
414 
415  template<typename T>
416  T sum(std::function<T(const this_t&)>& f ) const {
423  T s = f(* dynamic_cast<const this_t*>(this)); // a little ugly here because it needs to be the subclass type
424  for(auto& c: this->children) {
425  s += c.sum(f);
426  }
427  return s;
428  }
429 
430  template<typename T>
431  T sum(T(*f)(const this_t&) ) const {
432  std::function ff = f;
433  return sum(ff);
434  }
435 
436 
437  bool all(std::function<bool(const this_t&)>& f ) const {
444  if(not f(* dynamic_cast<const this_t*>(this)))
445  return false;
446 
447  for(auto& c: this->children) {
448  if(not c.all(f))
449  return false;
450  }
451  return true;
452  }
453 
454  void map( const std::function<void(this_t&)>& f) {
460  f(* dynamic_cast<const this_t*>(this));
461  for(auto& c: this->children) {
462  c.map(f);
463  }
464  }
465 
466 
467  void print(size_t t=0) const {
468 
469  std::string tabs(t,'\t');
470 
471  COUT tabs << this->my_string() ENDL;
472  for(auto& c : children) {
473  c.print(t+1);
474  }
475  }
476 
477 };
478 
479 
480 template<typename this_t>
482 
483 
484 
485 template<typename this_t>
486 std::ostream& operator<<(std::ostream& o, const BaseNode<this_t>& t) {
487  o << t.string();
488  return o;
489 }
virtual bool operator==(const this_t &n) const
Definition: BaseNode.h:82
virtual this_t * get_nth(int n, std::function< int(const this_t &)> &f)
Definition: BaseNode.h:392
size_t pi
Definition: BaseNode.h:27
void reserve_children(const size_t n)
Definition: BaseNode.h:155
decltype(children) & get_children()
Definition: BaseNode.h:159
std::vector< this_t > children
Definition: BaseNode.h:23
virtual bool is_terminal() const
Am I a terminal? I am if I have no children.
Definition: BaseNode.h:375
NodeIterator & operator++(int blah)
Definition: BaseNode.h:118
this_t * root()
Find the root of this node by walking up the tree.
Definition: BaseNode.h:324
void fix_child_info()
Definition: BaseNode.h:229
bool operator==(const NodeIterator &rhs) const
Definition: BaseNode.h:142
void push_back(this_t &n)
Definition: BaseNode.h:304
virtual std::string my_string() const
Definition: BaseNode.h:79
T sum(T(*f)(const this_t &)) const
Definition: BaseNode.h:431
void set_child(const size_t i, this_t &n)
Definition: BaseNode.h:274
void print(size_t t=0) const
Definition: BaseNode.h:467
this_t & operator*() const
Definition: BaseNode.h:109
this_t * get_via(std::function< bool(this_t &)> &f)
Return a pointer to the first node satisfying predicate f, in standard traversal; nullptr otherwise...
Definition: BaseNode.h:337
Definition: Errors.h:18
NodeIterator end() const
Definition: BaseNode.h:149
this_t & operator[](const size_t i)
Definition: BaseNode.h:260
decltype(children) const & get_children() const
Definition: BaseNode.h:162
T sum(std::function< T(const this_t &)> &f) const
Definition: BaseNode.h:416
NodeIterator & operator++()
Definition: BaseNode.h:119
virtual size_t count_terminals() const
Definition: BaseNode.h:379
BaseNode(this_t &&n)
Definition: BaseNode.h:44
void check_child_info() const
Definition: BaseNode.h:244
void map(const std::function< void(this_t &)> &f)
Definition: BaseNode.h:454
static NodeIterator EndNodeIterator
Definition: BaseNode.h:145
virtual bool operator!=(const this_t &n) const
Definition: BaseNode.h:151
BaseNode(size_t n=0, this_t *p=nullptr, size_t i=0)
Constructor of basenode – sizes children to n.
Definition: BaseNode.h:35
Definition: BaseNode.h:20
const this_t & child(const size_t i) const
Definition: BaseNode.h:176
void fill(size_t n, Args... args)
Definition: BaseNode.h:187
this_t * left_descend() const
Definition: BaseNode.h:208
NodeIterator(const this_t *n)
Definition: BaseNode.h:108
this_t & child(const size_t i)
Definition: BaseNode.h:166
void make_root()
Make a node root – just nulls the parent.
Definition: BaseNode.h:68
virtual std::string string(bool usedot=true) const
Definition: BaseNode.h:76
this_t * parent
Definition: BaseNode.h:26
Definition: BaseNode.h:95
#define ENDL
Definition: IO.h:21
virtual size_t count(const this_t &n) const
How many nodes below me are equal to n?
Definition: BaseNode.h:363
bool all(std::function< bool(const this_t &)> &f) const
Definition: BaseNode.h:437
BaseNode(const this_t &n)
Definition: BaseNode.h:38
bool operator!=(const NodeIterator &rhs) const
Definition: BaseNode.h:143
virtual this_t * get_nth(int n)
Definition: BaseNode.h:409
this_t * current
Definition: BaseNode.h:100
void operator=(const this_t &t)
Definition: BaseNode.h:51
size_t depth() const
Definition: BaseNode.h:221
const this_t * start
Definition: BaseNode.h:104
const this_t & operator[](const size_t i) const
Definition: BaseNode.h:270
void operator=(const this_t &&t)
Definition: BaseNode.h:57
void set_child(const size_t i, this_t &&n)
Definition: BaseNode.h:288
this_t * operator->() const
Definition: BaseNode.h:113
virtual ~BaseNode()
Definition: BaseNode.h:73
virtual size_t count() const
How many nodes total are below me?
Definition: BaseNode.h:349
NodeIterator & operator+(size_t n)
Definition: BaseNode.h:137
#define COUT
Definition: IO.h:24
NodeIterator begin() const
Definition: BaseNode.h:148
virtual bool is_root() const
Am I a root node? I am if my parent is nullptr.
Definition: BaseNode.h:315
size_t nchildren() const
Definition: BaseNode.h:199
void push_back(this_t &&n)
Definition: BaseNode.h:307