mlpack
spill_tree_impl.hpp
Go to the documentation of this file.
1 
11 #ifndef MLPACK_CORE_TREE_SPILL_TREE_SPILL_TREE_IMPL_HPP
12 #define MLPACK_CORE_TREE_SPILL_TREE_SPILL_TREE_IMPL_HPP
13 
14 // In case it wasn't included already for some reason.
15 #include "spill_tree.hpp"
16 
17 #include <queue>
18 
19 namespace mlpack {
20 namespace tree {
21 
22 template<typename MetricType,
23  typename StatisticType,
24  typename MatType,
25  template<typename HyperplaneMetricType> class HyperplaneType,
26  template<typename SplitMetricType, typename SplitMatType>
27  class SplitType>
30  const MatType& data,
31  const double tau,
32  const size_t maxLeafSize,
33  const double rho) :
34  left(NULL),
35  right(NULL),
36  parent(NULL),
37  count(data.n_cols),
38  pointsIndex(NULL),
39  overlappingNode(false),
40  hyperplane(),
41  bound(data.n_rows),
42  parentDistance(0), // Parent distance for the root is 0: it has no parent.
43  dataset(&data),
44  localDataset(false)
45 {
46  arma::Col<size_t> points;
47  if (dataset->n_cols > 0)
48  // Fill points with all possible indexes: 0 .. (dataset->n_cols - 1).
49  points = arma::linspace<arma::Col<size_t>>(0, dataset->n_cols - 1,
50  dataset->n_cols);
51 
52  // Do the actual splitting of this node.
53  SplitNode(points, maxLeafSize, tau, rho);
54 
55  // Create the statistic depending on if we are a leaf or not.
56  stat = StatisticType(*this);
57 }
58 
59 template<typename MetricType,
60  typename StatisticType,
61  typename MatType,
62  template<typename HyperplaneMetricType> class HyperplaneType,
63  template<typename SplitMetricType, typename SplitMatType>
64  class SplitType>
67  MatType&& data,
68  const double tau,
69  const size_t maxLeafSize,
70  const double rho) :
71  left(NULL),
72  right(NULL),
73  parent(NULL),
74  count(data.n_cols),
75  pointsIndex(NULL),
76  overlappingNode(false),
77  hyperplane(),
78  bound(data.n_rows),
79  parentDistance(0), // Parent distance for the root is 0: it has no parent.
80  dataset(new MatType(std::move(data))),
81  localDataset(true)
82 {
83  arma::Col<size_t> points;
84  if (dataset->n_cols > 0)
85  // Fill points with all possible indexes: 0 .. (dataset->n_cols - 1).
86  points = arma::linspace<arma::Col<size_t>>(0, dataset->n_cols - 1,
87  dataset->n_cols);
88 
89  // Do the actual splitting of this node.
90  SplitNode(points, maxLeafSize, tau, rho);
91 
92  // Create the statistic depending on if we are a leaf or not.
93  stat = StatisticType(*this);
94 }
95 
96 template<typename MetricType,
97  typename StatisticType,
98  typename MatType,
99  template<typename HyperplaneMetricType> class HyperplaneType,
100  template<typename SplitMetricType, typename SplitMatType>
101  class SplitType>
104  SpillTree* parent,
105  arma::Col<size_t>& points,
106  const double tau,
107  const size_t maxLeafSize,
108  const double rho) :
109  left(NULL),
110  right(NULL),
111  parent(parent),
112  count(points.n_elem),
113  pointsIndex(NULL),
114  overlappingNode(false),
115  hyperplane(),
116  bound(parent->Dataset().n_rows),
117  dataset(&parent->Dataset()), // Point to the parent's dataset.
118  localDataset(false)
119 {
120  // Perform the actual splitting.
121  SplitNode(points, maxLeafSize, tau, rho);
122 
123  // Create the statistic depending on if we are a leaf or not.
124  stat = StatisticType(*this);
125 }
126 
131 template<typename MetricType,
132  typename StatisticType,
133  typename MatType,
134  template<typename HyperplaneMetricType> class HyperplaneType,
135  template<typename SplitMetricType, typename SplitMatType>
136  class SplitType>
138 SpillTree(const SpillTree& other) :
139  left(NULL),
140  right(NULL),
141  parent(other.parent),
142  count(other.count),
143  pointsIndex(NULL),
144  overlappingNode(other.overlappingNode),
145  hyperplane(other.hyperplane),
146  bound(other.bound),
147  stat(other.stat),
148  parentDistance(other.parentDistance),
149  furthestDescendantDistance(other.furthestDescendantDistance),
150  // Copy matrix, but only if we are the root and the other tree has its own
151  // copy of the dataset.
152  dataset((other.parent == NULL && other.localDataset) ?
153  new MatType(*other.dataset) : other.dataset),
154  localDataset(other.parent == NULL && other.localDataset)
155 {
156  // Create left and right children (if any).
157  if (other.Left())
158  {
159  left = new SpillTree(*other.Left());
160  left->Parent() = this; // Set parent to this, not other tree.
161  }
162 
163  if (other.Right())
164  {
165  right = new SpillTree(*other.Right());
166  right->Parent() = this; // Set parent to this, not other tree.
167  }
168 
169  // If vector of indexes, copy it.
170  if (other.pointsIndex)
171  pointsIndex = new arma::Col<size_t>(*other.pointsIndex);
172 
173  // Propagate matrix, but only if we are the root.
174  if (parent == NULL && localDataset)
175  {
176  std::queue<SpillTree*> queue;
177  if (left)
178  queue.push(left);
179  if (right)
180  queue.push(right);
181  while (!queue.empty())
182  {
183  SpillTree* node = queue.front();
184  queue.pop();
185 
186  node->dataset = dataset;
187  if (node->left)
188  queue.push(node->left);
189  if (node->right)
190  queue.push(node->right);
191  }
192  }
193 }
194 
198 template<typename MetricType,
199  typename StatisticType,
200  typename MatType,
201  template<typename HyperplaneMetricType> class HyperplaneType,
202  template<typename SplitMetricType, typename SplitMatType>
203  class SplitType>
206 operator=(const SpillTree& other)
207 {
208  if (this == &other)
209  return *this;
210 
211  // Freeing memory that will not be used anymore.
212  if (localDataset)
213  delete dataset;
214 
215  delete pointsIndex;
216  delete left;
217  delete right;
218 
219  left = NULL;
220  right = NULL;
221  parent = other.parent;
222  count = other.count;
223  pointsIndex = NULL;
224  overlappingNode = other.overlappingNode;
225  hyperplane = other.hyperplane;
226  bound = other.bound;
227  stat = other.stat;
228  parentDistance = other.parentDistance;
229  furthestDescendantDistance = other.furthestDescendantDistance;
230 
231  // Copy matrix, but only if we are the root and the other tree has its own
232  // copy of the dataset.
233  dataset = (other.parent == NULL && other.localDataset) ?
234  new MatType(*other.dataset) : other.dataset;
235  localDataset = other.parent == NULL && other.localDataset;
236 
237  // Create left and right children (if any).
238  if (other.Left())
239  {
240  left = new SpillTree(*other.Left());
241  left->Parent() = this; // Set parent to this, not other tree.
242  }
243 
244  if (other.Right())
245  {
246  right = new SpillTree(*other.Right());
247  right->Parent() = this; // Set parent to this, not other tree.
248  }
249 
250  // If vector of indexes, copy it.
251  if (other.pointsIndex)
252  pointsIndex = new arma::Col<size_t>(*other.pointsIndex);
253 
254  // Propagate matrix, but only if we are the root.
255  if (parent == NULL && localDataset)
256  {
257  std::queue<SpillTree*> queue;
258  if (left)
259  queue.push(left);
260  if (right)
261  queue.push(right);
262  while (!queue.empty())
263  {
264  SpillTree* node = queue.front();
265  queue.pop();
266 
267  node->dataset = dataset;
268  if (node->left)
269  queue.push(node->left);
270  if (node->right)
271  queue.push(node->right);
272  }
273  }
274  return *this;
275 }
276 
280 template<typename MetricType,
281  typename StatisticType,
282  typename MatType,
283  template<typename HyperplaneMetricType> class HyperplaneType,
284  template<typename SplitMetricType, typename SplitMatType>
285  class SplitType>
288  left(other.left),
289  right(other.right),
290  parent(other.parent),
291  count(other.count),
292  pointsIndex(other.pointsIndex),
293  overlappingNode(other.overlappingNode),
294  hyperplane(other.hyperplane),
295  bound(std::move(other.bound)),
296  stat(std::move(other.stat)),
297  parentDistance(other.parentDistance),
298  furthestDescendantDistance(other.furthestDescendantDistance),
299  minimumBoundDistance(other.minimumBoundDistance),
300  dataset(other.dataset),
301  localDataset(other.localDataset)
302 {
303  // Now we are a clone of the other tree. But we must also clear the other
304  // tree's contents, so it doesn't delete anything when it is destructed.
305  other.left = NULL;
306  other.right = NULL;
307  other.count = 0;
308  other.pointsIndex = NULL;
309  other.parentDistance = 0.0;
310  other.furthestDescendantDistance = 0.0;
311  other.minimumBoundDistance = 0.0;
312  other.dataset = NULL;
313  other.localDataset = false;
314 
315  // Set new parent.
316  if (left)
317  left->parent = this;
318  if (right)
319  right->parent = this;
320 }
321 
325 template<typename MetricType,
326  typename StatisticType,
327  typename MatType,
328  template<typename HyperplaneMetricType> class HyperplaneType,
329  template<typename SplitMetricType, typename SplitMatType>
330  class SplitType>
334 {
335  if (this == &other)
336  return *this;
337 
338  // Freeing memory that will not be used anymore.
339  if (localDataset)
340  delete dataset;
341 
342  delete pointsIndex;
343  delete left;
344  delete right;
345 
346  left = other.left;
347  right = other.right;
348  parent = other.parent;
349  count = other.count;
350  pointsIndex = other.pointsIndex;
351  overlappingNode = other.overlappingNode;
352  hyperplane = other.hyperplane;
353  bound = std::move(other.bound);
354  stat = std::move(other.stat);
355  parentDistance = other.parentDistance;
356  furthestDescendantDistance = other.furthestDescendantDistance;
357  minimumBoundDistance = other.minimumBoundDistance;
358  dataset = other.dataset;
359  localDataset = other.localDataset;
360 
361  // Now we are a clone of the other tree. But we must also clear the other
362  // tree's contents, so it doesn't delete anything when it is destructed.
363  other.left = NULL;
364  other.right = NULL;
365  other.count = 0;
366  other.pointsIndex = NULL;
367  other.parentDistance = 0.0;
368  other.furthestDescendantDistance = 0.0;
369  other.minimumBoundDistance = 0.0;
370  other.dataset = NULL;
371  other.localDataset = false;
372 
373  // Set new parent.
374  if (left)
375  left->parent = this;
376  if (right)
377  right->parent = this;
378 
379  return *this;
380 }
381 
385 template<typename MetricType,
386  typename StatisticType,
387  typename MatType,
388  template<typename HyperplaneMetricType> class HyperplaneType,
389  template<typename SplitMetricType, typename SplitMatType>
390  class SplitType>
391 template<typename Archive>
394  Archive& ar,
395  const typename std::enable_if_t<cereal::is_loading<Archive>()>*) :
396  SpillTree() // Create an empty SpillTree.
397 {
398  // We've delegated to the constructor which gives us an empty tree, and now we
399  // can serialize from it.
400  ar(CEREAL_NVP(*this));
401 }
402 
408 template<typename MetricType,
409  typename StatisticType,
410  typename MatType,
411  template<typename HyperplaneMetricType> class HyperplaneType,
412  template<typename SplitMetricType, typename SplitMatType>
413  class SplitType>
416 {
417  delete left;
418  delete right;
419  delete pointsIndex;
420 
421  // If we're the root and we own the dataset, delete it.
422  if (!parent && localDataset)
423  delete dataset;
424 }
425 
426 template<typename MetricType,
427  typename StatisticType,
428  typename MatType,
429  template<typename HyperplaneMetricType> class HyperplaneType,
430  template<typename SplitMetricType, typename SplitMatType>
431  class SplitType>
432 inline bool SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
433  SplitType>::IsLeaf() const
434 {
435  return !left;
436 }
437 
441 template<typename MetricType,
442  typename StatisticType,
443  typename MatType,
444  template<typename HyperplaneMetricType> class HyperplaneType,
445  template<typename SplitMetricType, typename SplitMatType>
446  class SplitType>
447 inline size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
448  SplitType>::NumChildren() const
449 {
450  if (left && right)
451  return 2;
452  if (left)
453  return 1;
454 
455  return 0;
456 }
457 
464 template<typename MetricType,
465  typename StatisticType,
466  typename MatType,
467  template<typename HyperplaneMetricType> class HyperplaneType,
468  template<typename SplitMetricType, typename SplitMatType>
469  class SplitType>
470 template<typename VecType>
471 size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
472  SplitType>::GetNearestChild(
473  const VecType& point,
474  typename std::enable_if_t<IsVector<VecType>::value>*)
475 {
476  if (IsLeaf() || !left || !right)
477  return 0;
478 
479  if (hyperplane.Left(point))
480  return 0;
481  return 1;
482 }
483 
490 template<typename MetricType,
491  typename StatisticType,
492  typename MatType,
493  template<typename HyperplaneMetricType> class HyperplaneType,
494  template<typename SplitMetricType, typename SplitMatType>
495  class SplitType>
496 template<typename VecType>
497 size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
499  const VecType& point,
500  typename std::enable_if_t<IsVector<VecType>::value>*)
501 {
502  if (IsLeaf() || !left || !right)
503  return 0;
504 
505  if (hyperplane.Left(point))
506  return 1;
507  return 0;
508 }
509 
516 template<typename MetricType,
517  typename StatisticType,
518  typename MatType,
519  template<typename HyperplaneMetricType> class HyperplaneType,
520  template<typename SplitMetricType, typename SplitMatType>
521  class SplitType>
522 size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
523  SplitType>::GetNearestChild(const SpillTree& queryNode)
524 {
525  if (IsLeaf() || !left || !right)
526  return 0;
527 
528  if (hyperplane.Left(queryNode.Bound()))
529  return 0;
530  if (hyperplane.Right(queryNode.Bound()))
531  return 1;
532  // Can't decide.
533  return 2;
534 }
535 
542 template<typename MetricType,
543  typename StatisticType,
544  typename MatType,
545  template<typename HyperplaneMetricType> class HyperplaneType,
546  template<typename SplitMetricType, typename SplitMatType>
547  class SplitType>
548 size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
549  SplitType>::GetFurthestChild(const SpillTree& queryNode)
550 {
551  if (IsLeaf() || !left || !right)
552  return 0;
553 
554  if (hyperplane.Left(queryNode.Bound()))
555  return 1;
556  if (hyperplane.Right(queryNode.Bound()))
557  return 0;
558  // Can't decide.
559  return 2;
560 }
561 
566 template<typename MetricType,
567  typename StatisticType,
568  typename MatType,
569  template<typename HyperplaneMetricType> class HyperplaneType,
570  template<typename SplitMetricType, typename SplitMatType>
571  class SplitType>
572 inline typename SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
573  SplitType>::ElemType
576 {
577  if (!IsLeaf())
578  return 0.0;
579 
580  // Otherwise return the distance from the center to a corner of the bound.
581  return 0.5 * bound.Diameter();
582 }
583 
591 template<typename MetricType,
592  typename StatisticType,
593  typename MatType,
594  template<typename HyperplaneMetricType> class HyperplaneType,
595  template<typename SplitMetricType, typename SplitMatType>
596  class SplitType>
597 inline typename SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
598  SplitType>::ElemType
601 {
602  return furthestDescendantDistance;
603 }
604 
606 template<typename MetricType,
607  typename StatisticType,
608  typename MatType,
609  template<typename HyperplaneMetricType> class HyperplaneType,
610  template<typename SplitMetricType, typename SplitMatType>
611  class SplitType>
612 inline typename SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
613  SplitType>::ElemType
616 {
617  return bound.MinWidth() / 2.0;
618 }
619 
623 template<typename MetricType,
624  typename StatisticType,
625  typename MatType,
626  template<typename HyperplaneMetricType> class HyperplaneType,
627  template<typename SplitMetricType, typename SplitMatType>
628  class SplitType>
631  Child(const size_t child) const
632 {
633  if (child == 0)
634  return *left;
635  else
636  return *right;
637 }
638 
642 template<typename MetricType,
643  typename StatisticType,
644  typename MatType,
645  template<typename HyperplaneMetricType> class HyperplaneType,
646  template<typename SplitMetricType, typename SplitMatType>
647  class SplitType>
648 inline size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
649  SplitType>::NumPoints() const
650 {
651  if (IsLeaf())
652  return count;
653  return 0;
654 }
655 
659 template<typename MetricType,
660  typename StatisticType,
661  typename MatType,
662  template<typename HyperplaneMetricType> class HyperplaneType,
663  template<typename SplitMetricType, typename SplitMatType>
664  class SplitType>
665 inline size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
666  SplitType>::NumDescendants() const
667 {
668  return count;
669 }
670 
674 template<typename MetricType,
675  typename StatisticType,
676  typename MatType,
677  template<typename HyperplaneMetricType> class HyperplaneType,
678  template<typename SplitMetricType, typename SplitMatType>
679  class SplitType>
680 inline size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
681  SplitType>::Descendant(const size_t index) const
682 {
683  if (IsLeaf() || overlappingNode)
684  return (*pointsIndex)[index];
685 
686  // If this is not a leaf and not an overlapping node, then determine whether
687  // we should get the descendant from the left or the right node.
688  const size_t num = left->NumDescendants();
689  if (index < num)
690  return left->Descendant(index);
691  else
692  return right->Descendant(index - num);
693 }
694 
698 template<typename MetricType,
699  typename StatisticType,
700  typename MatType,
701  template<typename HyperplaneMetricType> class HyperplaneType,
702  template<typename SplitMetricType, typename SplitMatType>
703  class SplitType>
704 inline size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
705  SplitType>::Point(const size_t index) const
706 {
707  if (IsLeaf())
708  return (*pointsIndex)[index];
709  // This should never happen.
710  return (size_t() - 1);
711 }
712 
713 template<typename MetricType,
714  typename StatisticType,
715  typename MatType,
716  template<typename HyperplaneMetricType> class HyperplaneType,
717  template<typename SplitMetricType, typename SplitMatType>
718  class SplitType>
720  SplitNode(arma::Col<size_t>& points,
721  const size_t maxLeafSize,
722  const double tau,
723  const double rho)
724 {
725  // We need to expand the bounds of this node properly.
726  for (size_t i = 0; i < points.n_elem; ++i)
727  bound |= dataset->col(points[i]);
728 
729  // Calculate the furthest descendant distance.
730  furthestDescendantDistance = 0.5 * bound.Diameter();
731 
732  // Now, check if we need to split at all.
733  if (points.n_elem <= maxLeafSize)
734  {
735  pointsIndex = new arma::Col<size_t>();
736  pointsIndex->swap(points);
737  return; // We can't split this.
738  }
739 
740  const bool split = SplitType<MetricType, MatType>::SplitSpace(bound,
741  *dataset, points, hyperplane);
742  // The node may not be always split. For instance, if all the points are the
743  // same, we can't split them.
744  if (!split)
745  {
746  pointsIndex = new arma::Col<size_t>();
747  pointsIndex->swap(points);
748  return; // We can't split this.
749  }
750 
751  arma::Col<size_t> leftPoints, rightPoints;
752  // Split the node.
753  overlappingNode = SplitPoints(tau, rho, points, leftPoints, rightPoints);
754 
755  if (overlappingNode)
756  {
757  // If the node is overlapping, we have to keep track of which points are
758  // held in the node.
759  pointsIndex = new arma::Col<size_t>();
760  pointsIndex->swap(points);
761  }
762  else
763  {
764  // Otherwise, we don't need the information in points, so let's clean it.
765  arma::Col<size_t>().swap(points);
766  }
767 
768  // Now we will recursively split the children by calling their constructors
769  // (which perform this splitting process).
770  left = new SpillTree(this, leftPoints, tau, maxLeafSize, rho);
771  right = new SpillTree(this, rightPoints, tau, maxLeafSize, rho);
772 
773  // Calculate parent distances for those two nodes.
774  arma::vec center, leftCenter, rightCenter;
775  Center(center);
776  left->Center(leftCenter);
777  right->Center(rightCenter);
778 
779  const ElemType leftParentDistance = MetricType::Evaluate(center, leftCenter);
780  const ElemType rightParentDistance = MetricType::Evaluate(center,
781  rightCenter);
782 
783  left->ParentDistance() = leftParentDistance;
784  right->ParentDistance() = rightParentDistance;
785 }
786 
787 template<typename MetricType,
788  typename StatisticType,
789  typename MatType,
790  template<typename HyperplaneMetricType> class HyperplaneType,
791  template<typename SplitMetricType, typename SplitMatType>
792  class SplitType>
794  SplitPoints(const double tau,
795  const double rho,
796  const arma::Col<size_t>& points,
797  arma::Col<size_t>& leftPoints,
798  arma::Col<size_t>& rightPoints)
799 {
800  arma::vec projections(points.n_elem);
801  size_t left = 0, right = 0, leftFrontier = 0, rightFrontier = 0;
802 
803  // Count the number of points to the left/right of the splitting hyperplane.
804  for (size_t i = 0; i < points.n_elem; ++i)
805  {
806  // Store projection value for future use.
807  projections[i] = hyperplane.Project(dataset->col(points[i]));
808  if (projections[i] <= 0)
809  {
810  left++;
811  if (projections[i] > -tau)
812  leftFrontier++;
813  }
814  else
815  {
816  right++;
817  if (projections[i] < tau)
818  rightFrontier++;
819  }
820  }
821 
822  const double p1 = (double) (left + rightFrontier) / points.n_elem;
823  const double p2 = (double) (right + leftFrontier) / points.n_elem;
824 
825  if ((p1 <= rho || rightFrontier == 0) &&
826  (p2 <= rho || leftFrontier == 0))
827  {
828  // Perform the actual splitting considering the overlapping buffer. Points
829  // with projection value in the range (-tau, tau) are included in both,
830  // leftPoints and rightPoints.
831  const size_t leftUnique = points.n_elem - right - leftFrontier;
832  const size_t overlap = leftFrontier + rightFrontier;
833 
834  leftPoints.resize(left + rightFrontier);
835  rightPoints.resize(right + leftFrontier);
836  for (size_t i = 0, rc = overlap, lc = 0, rf = 0, lf = leftUnique;
837  i < points.n_elem; ++i)
838  {
839  // We put any points in the frontier should come last in the left node,
840  // and first in the right node. (This ordering is not required.)
841  if (projections[i] < -tau)
842  leftPoints[lc++] = points[i];
843  else if (projections[i] < tau)
844  leftPoints[lf++] = points[i];
845 
846  if (projections[i] > tau)
847  rightPoints[rc++] = points[i];
848  else if (projections[i] > -tau)
849  rightPoints[rf++] = points[i];
850  }
851  // Return true, because it is a overlapping node.
852  return true;
853  }
854 
855  // Perform the actual splitting ignoring the overlapping buffer. Points
856  // with projection value less than or equal to zero are included in leftPoints
857  // and points with projection value greater than zero are included in
858  // rightPoints.
859  leftPoints.resize(left);
860  rightPoints.resize(right);
861  for (size_t i = 0, rc = 0, lc = 0; i < points.n_elem; ++i)
862  {
863  if (projections[i] <= 0)
864  leftPoints[lc++] = points[i];
865  else
866  rightPoints[rc++] = points[i];
867  }
868  // Return false, because it isn't a overlapping node.
869  return false;
870 }
871 
872 // Default constructor (private), for cereal.
873 template<typename MetricType,
874  typename StatisticType,
875  typename MatType,
876  template<typename HyperplaneMetricType> class HyperplaneType,
877  template<typename SplitMetricType, typename SplitMatType>
878  class SplitType>
881  left(NULL),
882  right(NULL),
883  parent(NULL),
884  count(0),
885  pointsIndex(NULL),
886  overlappingNode(false),
887  stat(*this),
888  parentDistance(0),
889  furthestDescendantDistance(0),
890  dataset(NULL),
891  localDataset(false)
892 {
893  // Nothing to do.
894 }
895 
899 template<typename MetricType,
900  typename StatisticType,
901  typename MatType,
902  template<typename HyperplaneMetricType> class HyperplaneType,
903  template<typename SplitMetricType, typename SplitMatType>
904  class SplitType>
905 template<typename Archive>
907  serialize(Archive& ar, const uint32_t /* version */)
908 {
909  // If we're loading, and we have children, they need to be deleted.
910  if (cereal::is_loading<Archive>())
911  {
912  if (left)
913  delete left;
914  if (right)
915  delete right;
916  if (!parent && localDataset)
917  delete dataset;
918 
919  parent = NULL;
920  left = NULL;
921  right = NULL;
922  }
923 
924  if (cereal::is_loading<Archive>())
925  {
926  localDataset = true;
927  }
928  ar(CEREAL_NVP(count));
929  ar(CEREAL_POINTER(pointsIndex));
930  ar(CEREAL_NVP(overlappingNode));
931  ar(CEREAL_NVP(hyperplane));
932  ar(CEREAL_NVP(bound));
933  ar(CEREAL_NVP(stat));
934  ar(CEREAL_NVP(parentDistance));
935  ar(CEREAL_NVP(furthestDescendantDistance));
936  // Force a non-const pointer.
937  MatType*& datasetPtr = const_cast<MatType*&>(dataset);
938 
939  // Save children last; otherwise cereal gets confused.
940  bool hasLeft = (left != NULL);
941  bool hasRight = (right != NULL);
942  bool hasParent = (parent != NULL);
943 
944  ar(CEREAL_NVP(hasLeft));
945  ar(CEREAL_NVP(hasRight));
946  ar(CEREAL_NVP(hasParent));
947 
948  if (hasLeft)
949  ar(CEREAL_POINTER(left));
950  if (hasRight)
951  ar(CEREAL_POINTER(right));
952  if (!hasParent)
953  ar(CEREAL_POINTER(datasetPtr));
954 
955  if (cereal::is_loading<Archive>())
956  {
957  if (left)
958  {
959  left->parent = this;
960  left->localDataset = false;
961  }
962  if (right)
963  {
964  right->parent = this;
965  right->localDataset = false;
966  }
967  }
968 
969  // If we are the root, we need to restore the dataset pointer throughout
970  if (!hasParent)
971  {
972  std::stack<SpillTree*> stack;
973  if (left)
974  stack.push(left);
975  if (right)
976  stack.push(right);
977  while (!stack.empty())
978  {
979  SpillTree* node = stack.top();
980  stack.pop();
981  node->dataset = dataset;
982  if (node->left)
983  stack.push(node->left);
984  if (node->right)
985  stack.push(node->right);
986  }
987  }
988 }
989 
990 } // namespace tree
991 } // namespace mlpack
992 
993 #endif
SpillTree & operator=(const SpillTree &other)
Copy the given Spill Tree.
Definition: spill_tree_impl.hpp:206
MatType::elem_type ElemType
The type of element held in MatType.
Definition: spill_tree.hpp:79
ElemType FurthestDescendantDistance() const
Return the furthest possible descendant distance.
Definition: spill_tree_impl.hpp:600
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: pointer_wrapper.hpp:23
size_t NumPoints() const
Return the number of points in this node (0 if not a leaf).
Definition: spill_tree_impl.hpp:649
SpillTree & Child(const size_t child) const
Return the specified child (0 will be left, 1 will be right).
Definition: spill_tree_impl.hpp:631
SpillTree * Right() const
Gets the right child of this node.
Definition: spill_tree.hpp:262
A hybrid spill tree is a variant of binary space trees in which the children of a node can "spill over" each other and contain shared points.
Definition: spill_tree.hpp:73
size_t GetNearestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the nearest child node to the given query point (this is an efficient estimation based on the splitting hyperplane, so the returned child is not necessarily the nearest one).
Definition: spill_tree_impl.hpp:472
size_t NumDescendants() const
Return the number of descendants of this node.
Definition: spill_tree_impl.hpp:666
SpillTree * Parent() const
Gets the parent of this node.
Definition: spill_tree.hpp:267
SpillTree * Left() const
Gets the left child of this node.
Definition: spill_tree.hpp:257
bool IsLeaf() const
Return whether or not this node is a leaf (true if it has no children).
Definition: spill_tree_impl.hpp:433
const BoundType & Bound() const
Return the bound object for this node.
Definition: spill_tree.hpp:244
SpillTree()
A default constructor.
Definition: spill_tree_impl.hpp:880
void serialize(Archive &ar, const uint32_t version)
Serialize the tree.
Definition: spill_tree_impl.hpp:907
ElemType MinimumBoundDistance() const
Return the minimum distance from the center of the node to any bound edge.
Definition: spill_tree_impl.hpp:615
size_t Descendant(const size_t index) const
Return the index (with reference to the dataset) of a particular descendant of this node.
Definition: spill_tree_impl.hpp:681
size_t GetFurthestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the furthest child node to the given query point (this is an efficient estimation based on the splitting hyperplane, so the returned child is not necessarily the furthest one).
Definition: spill_tree_impl.hpp:498
Definition of generalized hybrid spill tree (SpillTree).
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_wrapper.hpp:96
const MatType & Dataset() const
Get the dataset which the tree is built on.
Definition: spill_tree.hpp:272
~SpillTree()
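Deletes this node, deallocating the memory for the children and calling their destructors in turn; it also frees the node's point index vector and, if this node is the root and owns the dataset, the dataset itself.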
Definition: spill_tree_impl.hpp:415
ElemType FurthestPointDistance() const
Return the furthest distance to a point held in this node.
Definition: spill_tree_impl.hpp:575
ElemType ParentDistance() const
Return the distance from the center of this node to the center of the parent node.
Definition: spill_tree.hpp:344
size_t NumChildren() const
Return the number of children in this node.
Definition: spill_tree_impl.hpp:448
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35
void Center(arma::vec &center)
Store the center of the bounding region in the given vector.
Definition: spill_tree.hpp:438
size_t Point(const size_t index) const
Return the index (with reference to the dataset) of a particular point in this node.
Definition: spill_tree_impl.hpp:705