11 #ifndef MLPACK_CORE_TREE_SPILL_TREE_SPILL_TREE_IMPL_HPP 12 #define MLPACK_CORE_TREE_SPILL_TREE_SPILL_TREE_IMPL_HPP 22 template<
typename MetricType,
23 typename StatisticType,
25 template<
typename HyperplaneMetricType>
class HyperplaneType,
26 template<
typename SplitMetricType,
typename SplitMatType>
32 const size_t maxLeafSize,
39 overlappingNode(false),
46 arma::Col<size_t> points;
47 if (dataset->n_cols > 0)
49 points = arma::linspace<arma::Col<size_t>>(0, dataset->n_cols - 1,
53 SplitNode(points, maxLeafSize, tau, rho);
56 stat = StatisticType(*
this);
59 template<
typename MetricType,
60 typename StatisticType,
62 template<
typename HyperplaneMetricType>
class HyperplaneType,
63 template<
typename SplitMetricType,
typename SplitMatType>
69 const size_t maxLeafSize,
76 overlappingNode(false),
80 dataset(new MatType(
std::move(
data))),
83 arma::Col<size_t> points;
84 if (dataset->n_cols > 0)
86 points = arma::linspace<arma::Col<size_t>>(0, dataset->n_cols - 1,
90 SplitNode(points, maxLeafSize, tau, rho);
93 stat = StatisticType(*
this);
96 template<
typename MetricType,
97 typename StatisticType,
99 template<
typename HyperplaneMetricType>
class HyperplaneType,
100 template<
typename SplitMetricType,
typename SplitMatType>
105 arma::Col<size_t>& points,
107 const size_t maxLeafSize,
112 count(points.n_elem),
114 overlappingNode(false),
116 bound(parent->
Dataset().n_rows),
121 SplitNode(points, maxLeafSize, tau, rho);
124 stat = StatisticType(*
this);
131 template<
typename MetricType,
132 typename StatisticType,
134 template<
typename HyperplaneMetricType>
class HyperplaneType,
135 template<
typename SplitMetricType,
typename SplitMatType>
141 parent(other.parent),
144 overlappingNode(other.overlappingNode),
145 hyperplane(other.hyperplane),
148 parentDistance(other.parentDistance),
149 furthestDescendantDistance(other.furthestDescendantDistance),
152 dataset((other.parent == NULL && other.localDataset) ?
153 new MatType(*other.dataset) : other.dataset),
154 localDataset(other.parent == NULL && other.localDataset)
170 if (other.pointsIndex)
171 pointsIndex =
new arma::Col<size_t>(*other.pointsIndex);
174 if (parent == NULL && localDataset)
176 std::queue<SpillTree*> queue;
181 while (!queue.empty())
186 node->dataset = dataset;
188 queue.push(node->left);
190 queue.push(node->right);
198 template<
typename MetricType,
199 typename StatisticType,
201 template<
typename HyperplaneMetricType>
class HyperplaneType,
202 template<
typename SplitMetricType,
typename SplitMatType>
221 parent = other.parent;
224 overlappingNode = other.overlappingNode;
225 hyperplane = other.hyperplane;
228 parentDistance = other.parentDistance;
229 furthestDescendantDistance = other.furthestDescendantDistance;
233 dataset = (other.parent == NULL && other.localDataset) ?
234 new MatType(*other.dataset) : other.dataset;
235 localDataset = other.parent == NULL && other.localDataset;
251 if (other.pointsIndex)
252 pointsIndex =
new arma::Col<size_t>(*other.pointsIndex);
255 if (parent == NULL && localDataset)
257 std::queue<SpillTree*> queue;
262 while (!queue.empty())
267 node->dataset = dataset;
269 queue.push(node->left);
271 queue.push(node->right);
280 template<
typename MetricType,
281 typename StatisticType,
283 template<
typename HyperplaneMetricType>
class HyperplaneType,
284 template<
typename SplitMetricType,
typename SplitMatType>
290 parent(other.parent),
292 pointsIndex(other.pointsIndex),
293 overlappingNode(other.overlappingNode),
294 hyperplane(other.hyperplane),
295 bound(
std::move(other.bound)),
296 stat(
std::move(other.stat)),
297 parentDistance(other.parentDistance),
298 furthestDescendantDistance(other.furthestDescendantDistance),
299 minimumBoundDistance(other.minimumBoundDistance),
300 dataset(other.dataset),
301 localDataset(other.localDataset)
308 other.pointsIndex = NULL;
309 other.parentDistance = 0.0;
310 other.furthestDescendantDistance = 0.0;
311 other.minimumBoundDistance = 0.0;
312 other.dataset = NULL;
313 other.localDataset =
false;
319 right->parent =
this;
325 template<
typename MetricType,
326 typename StatisticType,
328 template<
typename HyperplaneMetricType>
class HyperplaneType,
329 template<
typename SplitMetricType,
typename SplitMatType>
348 parent = other.parent;
350 pointsIndex = other.pointsIndex;
351 overlappingNode = other.overlappingNode;
352 hyperplane = other.hyperplane;
353 bound = std::move(other.bound);
354 stat = std::move(other.stat);
355 parentDistance = other.parentDistance;
356 furthestDescendantDistance = other.furthestDescendantDistance;
357 minimumBoundDistance = other.minimumBoundDistance;
358 dataset = other.dataset;
359 localDataset = other.localDataset;
366 other.pointsIndex = NULL;
367 other.parentDistance = 0.0;
368 other.furthestDescendantDistance = 0.0;
369 other.minimumBoundDistance = 0.0;
370 other.dataset = NULL;
371 other.localDataset =
false;
377 right->parent =
this;
385 template<
typename MetricType,
386 typename StatisticType,
388 template<
typename HyperplaneMetricType>
class HyperplaneType,
389 template<
typename SplitMetricType,
typename SplitMatType>
391 template<
typename Archive>
395 const typename std::enable_if_t<cereal::is_loading<Archive>()>*) :
400 ar(CEREAL_NVP(*
this));
408 template<
typename MetricType,
409 typename StatisticType,
411 template<
typename HyperplaneMetricType>
class HyperplaneType,
412 template<
typename SplitMetricType,
typename SplitMatType>
422 if (!parent && localDataset)
426 template<
typename MetricType,
427 typename StatisticType,
429 template<
typename HyperplaneMetricType>
class HyperplaneType,
430 template<
typename SplitMetricType,
typename SplitMatType>
432 inline bool SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
441 template<
typename MetricType,
442 typename StatisticType,
444 template<
typename HyperplaneMetricType>
class HyperplaneType,
445 template<
typename SplitMetricType,
typename SplitMatType>
447 inline size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
464 template<
typename MetricType,
465 typename StatisticType,
467 template<
typename HyperplaneMetricType>
class HyperplaneType,
468 template<
typename SplitMetricType,
typename SplitMatType>
470 template<
typename VecType>
471 size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
473 const VecType& point,
476 if (
IsLeaf() || !left || !right)
479 if (hyperplane.Left(point))
490 template<
typename MetricType,
491 typename StatisticType,
493 template<
typename HyperplaneMetricType>
class HyperplaneType,
494 template<
typename SplitMetricType,
typename SplitMatType>
496 template<
typename VecType>
497 size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
499 const VecType& point,
502 if (
IsLeaf() || !left || !right)
505 if (hyperplane.Left(point))
516 template<
typename MetricType,
517 typename StatisticType,
519 template<
typename HyperplaneMetricType>
class HyperplaneType,
520 template<
typename SplitMetricType,
typename SplitMatType>
522 size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
525 if (
IsLeaf() || !left || !right)
528 if (hyperplane.Left(queryNode.
Bound()))
530 if (hyperplane.Right(queryNode.
Bound()))
542 template<
typename MetricType,
543 typename StatisticType,
545 template<
typename HyperplaneMetricType>
class HyperplaneType,
546 template<
typename SplitMetricType,
typename SplitMatType>
548 size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
551 if (
IsLeaf() || !left || !right)
554 if (hyperplane.Left(queryNode.
Bound()))
556 if (hyperplane.Right(queryNode.
Bound()))
566 template<
typename MetricType,
567 typename StatisticType,
569 template<
typename HyperplaneMetricType>
class HyperplaneType,
570 template<
typename SplitMetricType,
typename SplitMatType>
572 inline typename SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
581 return 0.5 * bound.Diameter();
591 template<
typename MetricType,
592 typename StatisticType,
594 template<
typename HyperplaneMetricType>
class HyperplaneType,
595 template<
typename SplitMetricType,
typename SplitMatType>
597 inline typename SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
602 return furthestDescendantDistance;
606 template<
typename MetricType,
607 typename StatisticType,
609 template<
typename HyperplaneMetricType>
class HyperplaneType,
610 template<
typename SplitMetricType,
typename SplitMatType>
612 inline typename SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
617 return bound.MinWidth() / 2.0;
623 template<
typename MetricType,
624 typename StatisticType,
626 template<
typename HyperplaneMetricType>
class HyperplaneType,
627 template<
typename SplitMetricType,
typename SplitMatType>
642 template<
typename MetricType,
643 typename StatisticType,
645 template<
typename HyperplaneMetricType>
class HyperplaneType,
646 template<
typename SplitMetricType,
typename SplitMatType>
648 inline size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
659 template<
typename MetricType,
660 typename StatisticType,
662 template<
typename HyperplaneMetricType>
class HyperplaneType,
663 template<
typename SplitMetricType,
typename SplitMatType>
665 inline size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
674 template<
typename MetricType,
675 typename StatisticType,
677 template<
typename HyperplaneMetricType>
class HyperplaneType,
678 template<
typename SplitMetricType,
typename SplitMatType>
680 inline size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
683 if (
IsLeaf() || overlappingNode)
684 return (*pointsIndex)[index];
698 template<
typename MetricType,
699 typename StatisticType,
701 template<
typename HyperplaneMetricType>
class HyperplaneType,
702 template<
typename SplitMetricType,
typename SplitMatType>
704 inline size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
708 return (*pointsIndex)[index];
710 return (
size_t() - 1);
713 template<
typename MetricType,
714 typename StatisticType,
716 template<
typename HyperplaneMetricType>
class HyperplaneType,
717 template<
typename SplitMetricType,
typename SplitMatType>
721 const size_t maxLeafSize,
726 for (
size_t i = 0; i < points.n_elem; ++i)
727 bound |= dataset->col(points[i]);
730 furthestDescendantDistance = 0.5 * bound.Diameter();
733 if (points.n_elem <= maxLeafSize)
735 pointsIndex =
new arma::Col<size_t>();
736 pointsIndex->swap(points);
740 const bool split = SplitType<MetricType, MatType>::SplitSpace(bound,
741 *dataset, points, hyperplane);
746 pointsIndex =
new arma::Col<size_t>();
747 pointsIndex->swap(points);
751 arma::Col<size_t> leftPoints, rightPoints;
753 overlappingNode = SplitPoints(tau, rho, points, leftPoints, rightPoints);
759 pointsIndex =
new arma::Col<size_t>();
760 pointsIndex->swap(points);
765 arma::Col<size_t>().swap(points);
770 left =
new SpillTree(
this, leftPoints, tau, maxLeafSize, rho);
771 right =
new SpillTree(
this, rightPoints, tau, maxLeafSize, rho);
774 arma::vec center, leftCenter, rightCenter;
777 right->
Center(rightCenter);
779 const ElemType leftParentDistance = MetricType::Evaluate(center, leftCenter);
780 const ElemType rightParentDistance = MetricType::Evaluate(center,
787 template<
typename MetricType,
788 typename StatisticType,
790 template<
typename HyperplaneMetricType>
class HyperplaneType,
791 template<
typename SplitMetricType,
typename SplitMatType>
796 const arma::Col<size_t>& points,
797 arma::Col<size_t>& leftPoints,
798 arma::Col<size_t>& rightPoints)
800 arma::vec projections(points.n_elem);
801 size_t left = 0, right = 0, leftFrontier = 0, rightFrontier = 0;
804 for (
size_t i = 0; i < points.n_elem; ++i)
807 projections[i] = hyperplane.Project(dataset->col(points[i]));
808 if (projections[i] <= 0)
811 if (projections[i] > -tau)
817 if (projections[i] < tau)
822 const double p1 = (double) (left + rightFrontier) / points.n_elem;
823 const double p2 = (double) (right + leftFrontier) / points.n_elem;
825 if ((p1 <= rho || rightFrontier == 0) &&
826 (p2 <= rho || leftFrontier == 0))
831 const size_t leftUnique = points.n_elem - right - leftFrontier;
832 const size_t overlap = leftFrontier + rightFrontier;
834 leftPoints.resize(left + rightFrontier);
835 rightPoints.resize(right + leftFrontier);
836 for (
size_t i = 0, rc = overlap, lc = 0, rf = 0, lf = leftUnique;
837 i < points.n_elem; ++i)
841 if (projections[i] < -tau)
842 leftPoints[lc++] = points[i];
843 else if (projections[i] < tau)
844 leftPoints[lf++] = points[i];
846 if (projections[i] > tau)
847 rightPoints[rc++] = points[i];
848 else if (projections[i] > -tau)
849 rightPoints[rf++] = points[i];
859 leftPoints.resize(left);
860 rightPoints.resize(right);
861 for (
size_t i = 0, rc = 0, lc = 0; i < points.n_elem; ++i)
863 if (projections[i] <= 0)
864 leftPoints[lc++] = points[i];
866 rightPoints[rc++] = points[i];
873 template<
typename MetricType,
874 typename StatisticType,
876 template<
typename HyperplaneMetricType>
class HyperplaneType,
877 template<
typename SplitMetricType,
typename SplitMatType>
886 overlappingNode(false),
889 furthestDescendantDistance(0),
899 template<
typename MetricType,
900 typename StatisticType,
902 template<
typename HyperplaneMetricType>
class HyperplaneType,
903 template<
typename SplitMetricType,
typename SplitMatType>
905 template<
typename Archive>
910 if (cereal::is_loading<Archive>())
916 if (!parent && localDataset)
924 if (cereal::is_loading<Archive>())
928 ar(CEREAL_NVP(count));
930 ar(CEREAL_NVP(overlappingNode));
931 ar(CEREAL_NVP(hyperplane));
932 ar(CEREAL_NVP(bound));
933 ar(CEREAL_NVP(stat));
934 ar(CEREAL_NVP(parentDistance));
935 ar(CEREAL_NVP(furthestDescendantDistance));
937 MatType*& datasetPtr =
const_cast<MatType*&
>(dataset);
940 bool hasLeft = (left != NULL);
941 bool hasRight = (right != NULL);
942 bool hasParent = (parent != NULL);
944 ar(CEREAL_NVP(hasLeft));
945 ar(CEREAL_NVP(hasRight));
946 ar(CEREAL_NVP(hasParent));
955 if (cereal::is_loading<Archive>())
960 left->localDataset =
false;
964 right->parent =
this;
965 right->localDataset =
false;
972 std::stack<SpillTree*> stack;
977 while (!stack.empty())
981 node->dataset = dataset;
983 stack.push(node->left);
985 stack.push(node->right);
SpillTree & operator=(const SpillTree &other)
Copy the given Spill Tree.
Definition: spill_tree_impl.hpp:206
MatType::elem_type ElemType
The type of element held in MatType.
Definition: spill_tree.hpp:79
ElemType FurthestDescendantDistance() const
Return the furthest possible descendant distance.
Definition: spill_tree_impl.hpp:600
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: pointer_wrapper.hpp:23
size_t NumPoints() const
Return the number of points in this node (0 if not a leaf).
Definition: spill_tree_impl.hpp:649
SpillTree & Child(const size_t child) const
Return the specified child (0 will be left, 1 will be right).
Definition: spill_tree_impl.hpp:631
SpillTree * Right() const
Gets the right child of this node.
Definition: spill_tree.hpp:262
A hybrid spill tree is a variant of binary space trees in which the children of a node can "spill over" each other, and contain shared datapoints.
Definition: spill_tree.hpp:73
size_t GetNearestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the nearest child node to the given query point (this is an efficient estimation ...
Definition: spill_tree_impl.hpp:472
size_t NumDescendants() const
Return the number of descendants of this node.
Definition: spill_tree_impl.hpp:666
SpillTree * Parent() const
Gets the parent of this node.
Definition: spill_tree.hpp:267
SpillTree * Left() const
Gets the left child of this node.
Definition: spill_tree.hpp:257
bool IsLeaf() const
Return whether or not this node is a leaf (true if it has no children).
Definition: spill_tree_impl.hpp:433
const BoundType & Bound() const
Return the bound object for this node.
Definition: spill_tree.hpp:244
SpillTree()
A default constructor.
Definition: spill_tree_impl.hpp:880
void serialize(Archive &ar, const uint32_t version)
Serialize the tree.
Definition: spill_tree_impl.hpp:907
ElemType MinimumBoundDistance() const
Return the minimum distance from the center of the node to any bound edge.
Definition: spill_tree_impl.hpp:615
size_t Descendant(const size_t index) const
Return the index (with reference to the dataset) of a particular descendant of this node...
Definition: spill_tree_impl.hpp:681
size_t GetFurthestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the furthest child node to the given query point (this is an efficient estimation...
Definition: spill_tree_impl.hpp:498
Definition of generalized hybrid spill tree (SpillTree).
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_wrapper.hpp:96
const MatType & Dataset() const
Get the dataset which the tree is built on.
Definition: spill_tree.hpp:272
~SpillTree()
Deletes this node, deallocating the memory for the children and calling their destructors in turn...
Definition: spill_tree_impl.hpp:415
ElemType FurthestPointDistance() const
Return the furthest distance to a point held in this node.
Definition: spill_tree_impl.hpp:575
ElemType ParentDistance() const
Return the distance from the center of this node to the center of the parent node.
Definition: spill_tree.hpp:344
size_t NumChildren() const
Return the number of children in this node.
Definition: spill_tree_impl.hpp:448
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35
void Center(arma::vec &center)
Store the center of the bounding region in the given vector.
Definition: spill_tree.hpp:438
size_t Point(const size_t index) const
Return the index (with reference to the dataset) of a particular point in this node.
Definition: spill_tree_impl.hpp:705