11 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_IMPL_HPP 12 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_IMPL_HPP 25 template<
typename MetricType,
26 typename StatisticType,
28 template<
typename BoundMetricType,
typename...>
class BoundType,
29 template<
typename SplitBoundType,
typename SplitMatType>
34 const size_t maxLeafSize) :
42 dataset(new MatType(data))
45 SplitType<BoundType<MetricType>, MatType> splitter;
46 SplitNode(maxLeafSize, splitter);
49 stat = StatisticType(*
this);
52 template<
typename MetricType,
53 typename StatisticType,
55 template<
typename BoundMetricType,
typename...>
class BoundType,
56 template<
typename SplitBoundType,
typename SplitMatType>
61 std::vector<size_t>& oldFromNew,
62 const size_t maxLeafSize) :
70 dataset(new MatType(data))
73 oldFromNew.resize(data.n_cols);
74 for (
size_t i = 0; i < data.n_cols; ++i)
78 SplitType<BoundType<MetricType>, MatType> splitter;
79 SplitNode(oldFromNew, maxLeafSize, splitter);
82 stat = StatisticType(*
this);
85 template<
typename MetricType,
86 typename StatisticType,
88 template<
typename BoundMetricType,
typename...>
class BoundType,
89 template<
typename SplitBoundType,
typename SplitMatType>
94 std::vector<size_t>& oldFromNew,
95 std::vector<size_t>& newFromOld,
96 const size_t maxLeafSize) :
104 dataset(new MatType(data))
107 oldFromNew.resize(data.n_cols);
108 for (
size_t i = 0; i < data.n_cols; ++i)
112 SplitType<BoundType<MetricType>, MatType> splitter;
113 SplitNode(oldFromNew, maxLeafSize, splitter);
116 stat = StatisticType(*
this);
119 newFromOld.resize(data.n_cols);
120 for (
size_t i = 0; i < data.n_cols; ++i)
121 newFromOld[oldFromNew[i]] = i;
124 template<
typename MetricType,
125 typename StatisticType,
127 template<
typename BoundMetricType,
typename...>
class BoundType,
128 template<
typename SplitBoundType,
typename SplitMatType>
139 dataset(new MatType(
std::move(
data)))
142 SplitType<BoundType<MetricType>, MatType> splitter;
143 SplitNode(maxLeafSize, splitter);
146 stat = StatisticType(*
this);
149 template<
typename MetricType,
150 typename StatisticType,
152 template<
typename BoundMetricType,
typename...>
class BoundType,
153 template<
typename SplitBoundType,
typename SplitMatType>
158 std::vector<size_t>& oldFromNew,
159 const size_t maxLeafSize) :
167 dataset(new MatType(
std::move(
data)))
170 oldFromNew.resize(dataset->n_cols);
171 for (
size_t i = 0; i < dataset->n_cols; ++i)
175 SplitType<BoundType<MetricType>, MatType> splitter;
176 SplitNode(oldFromNew, maxLeafSize, splitter);
179 stat = StatisticType(*
this);
182 template<
typename MetricType,
183 typename StatisticType,
185 template<
typename BoundMetricType,
typename...>
class BoundType,
186 template<
typename SplitBoundType,
typename SplitMatType>
191 std::vector<size_t>& oldFromNew,
192 std::vector<size_t>& newFromOld,
193 const size_t maxLeafSize) :
201 dataset(new MatType(
std::move(
data)))
204 oldFromNew.resize(dataset->n_cols);
205 for (
size_t i = 0; i < dataset->n_cols; ++i)
209 SplitType<BoundType<MetricType>, MatType> splitter;
210 SplitNode(oldFromNew, maxLeafSize, splitter);
213 stat = StatisticType(*
this);
216 newFromOld.resize(dataset->n_cols);
217 for (
size_t i = 0; i < dataset->n_cols; ++i)
218 newFromOld[oldFromNew[i]] = i;
221 template<
typename MetricType,
222 typename StatisticType,
224 template<
typename BoundMetricType,
typename...>
class BoundType,
225 template<
typename SplitBoundType,
typename SplitMatType>
232 SplitType<BoundType<MetricType>, MatType>& splitter,
233 const size_t maxLeafSize) :
239 bound(parent->
Dataset().n_rows),
243 SplitNode(maxLeafSize, splitter);
246 stat = StatisticType(*
this);
249 template<
typename MetricType,
250 typename StatisticType,
252 template<
typename BoundMetricType,
typename...>
class BoundType,
253 template<
typename SplitBoundType,
typename SplitMatType>
260 std::vector<size_t>& oldFromNew,
261 SplitType<BoundType<MetricType>, MatType>& splitter,
262 const size_t maxLeafSize) :
268 bound(parent->
Dataset().n_rows),
273 assert(oldFromNew.size() == dataset->n_cols);
276 SplitNode(oldFromNew, maxLeafSize, splitter);
279 stat = StatisticType(*
this);
282 template<
typename MetricType,
283 typename StatisticType,
285 template<
typename BoundMetricType,
typename...>
class BoundType,
286 template<
typename SplitBoundType,
typename SplitMatType>
293 std::vector<size_t>& oldFromNew,
294 std::vector<size_t>& newFromOld,
295 SplitType<BoundType<MetricType>, MatType>& splitter,
296 const size_t maxLeafSize) :
302 bound(parent->
Dataset()->n_rows),
310 SplitNode(oldFromNew, maxLeafSize, splitter);
313 stat = StatisticType(*
this);
316 newFromOld.resize(dataset->n_cols);
317 for (
size_t i = 0; i < dataset->n_cols; ++i)
318 newFromOld[oldFromNew[i]] = i;
325 template<
typename MetricType,
326 typename StatisticType,
328 template<
typename BoundMetricType,
typename...>
class BoundType,
329 template<
typename SplitBoundType,
typename SplitMatType>
336 parent(other.parent),
341 parentDistance(other.parentDistance),
342 furthestDescendantDistance(other.furthestDescendantDistance),
343 minimumBoundDistance(other.minimumBoundDistance),
345 dataset((other.parent == NULL) ? new MatType(*other.dataset) : NULL)
363 std::queue<BinarySpaceTree*> queue;
368 while (!queue.empty())
373 node->dataset = dataset;
375 queue.push(node->left);
377 queue.push(node->right);
385 template<
typename MetricType,
386 typename StatisticType,
388 template<
typename BoundMetricType,
typename...>
class BoundType,
389 template<
typename SplitBoundType,
typename SplitMatType>
407 begin = other.
Begin();
408 count = other.
Count();
415 dataset = ((other.parent == NULL) ?
new MatType(*other.dataset) : NULL);
433 std::queue<BinarySpaceTree*> queue;
438 while (!queue.empty())
443 node->dataset = dataset;
445 queue.push(node->left);
447 queue.push(node->right);
457 template<
typename MetricType,
458 typename StatisticType,
460 template<
typename BoundMetricType,
typename...>
class BoundType,
461 template<
typename SplitBoundType,
typename SplitMatType>
478 right = other.
Right();
479 begin = other.Begin();
480 count = other.Count();
481 bound = std::move(other.bound);
482 stat = std::move(other.stat);
483 parentDistance = other.ParentDistance();
484 furthestDescendantDistance = other.FurthestDescendantDistance();
485 minimumBoundDistance = other.MinimumBoundDistance();
486 dataset = other.dataset;
493 other.parentDistance = 0.0;
494 other.furthestDescendantDistance = 0.0;
495 other.minimumBoundDistance = 0.0;
496 other.dataset = NULL;
505 template<
typename MetricType,
506 typename StatisticType,
508 template<
typename BoundMetricType,
typename...>
class BoundType,
509 template<
typename SplitBoundType,
typename SplitMatType>
515 parent(other.parent),
518 bound(
std::move(other.bound)),
519 stat(
std::move(other.stat)),
520 parentDistance(other.parentDistance),
521 furthestDescendantDistance(other.furthestDescendantDistance),
522 minimumBoundDistance(other.minimumBoundDistance),
523 dataset(other.dataset)
532 other.parentDistance = 0.0;
533 other.furthestDescendantDistance = 0.0;
534 other.minimumBoundDistance = 0.0;
535 other.dataset = NULL;
541 right->parent =
this;
547 template<
typename MetricType,
548 typename StatisticType,
550 template<
typename BoundMetricType,
typename...>
class BoundType,
551 template<
typename SplitBoundType,
typename SplitMatType>
553 template<
typename Archive>
557 const typename std::enable_if_t<cereal::is_loading<Archive>()>*) :
562 ar(CEREAL_NVP(*
this));
570 template<
typename MetricType,
571 typename StatisticType,
573 template<
typename BoundMetricType,
typename...>
class BoundType,
574 template<
typename SplitBoundType,
typename SplitMatType>
587 template<
typename MetricType,
588 typename StatisticType,
590 template<
typename BoundMetricType,
typename...>
class BoundType,
591 template<
typename SplitBoundType,
typename SplitMatType>
593 inline bool BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
602 template<
typename MetricType,
603 typename StatisticType,
605 template<
typename BoundMetricType,
typename...>
class BoundType,
606 template<
typename SplitBoundType,
typename SplitMatType>
608 inline size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
623 template<
typename MetricType,
624 typename StatisticType,
626 template<
typename BoundMetricType,
typename...>
class BoundType,
627 template<
typename SplitBoundType,
typename SplitMatType>
629 template<
typename VecType>
632 const VecType& point,
635 if (
IsLeaf() || !left || !right)
647 template<
typename MetricType,
648 typename StatisticType,
650 template<
typename BoundMetricType,
typename...>
class BoundType,
651 template<
typename SplitBoundType,
typename SplitMatType>
653 template<
typename VecType>
656 const VecType& point,
659 if (
IsLeaf() || !left || !right)
671 template<
typename MetricType,
672 typename StatisticType,
674 template<
typename BoundMetricType,
typename...>
class BoundType,
675 template<
typename SplitBoundType,
typename SplitMatType>
680 if (
IsLeaf() || !left || !right)
685 if (leftDist < rightDist)
687 if (rightDist < leftDist)
696 template<
typename MetricType,
697 typename StatisticType,
699 template<
typename BoundMetricType,
typename...>
class BoundType,
700 template<
typename SplitBoundType,
typename SplitMatType>
705 if (
IsLeaf() || !left || !right)
710 if (leftDist > rightDist)
712 if (rightDist > leftDist)
721 template<
typename MetricType,
722 typename StatisticType,
724 template<
typename BoundMetricType,
typename...>
class BoundType,
725 template<
typename SplitBoundType,
typename SplitMatType>
728 typename BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
737 return 0.5 * bound.Diameter();
747 template<
typename MetricType,
748 typename StatisticType,
750 template<
typename BoundMetricType,
typename...>
class BoundType,
751 template<
typename SplitBoundType,
typename SplitMatType>
754 typename BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
759 return furthestDescendantDistance;
763 template<
typename MetricType,
764 typename StatisticType,
766 template<
typename BoundMetricType,
typename...>
class BoundType,
767 template<
typename SplitBoundType,
typename SplitMatType>
770 typename BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
775 return bound.MinWidth() / 2.0;
781 template<
typename MetricType,
782 typename StatisticType,
784 template<
typename BoundMetricType,
typename...>
class BoundType,
785 template<
typename SplitBoundType,
typename SplitMatType>
801 template<
typename MetricType,
802 typename StatisticType,
804 template<
typename BoundMetricType,
typename...>
class BoundType,
805 template<
typename SplitBoundType,
typename SplitMatType>
807 inline size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
819 template<
typename MetricType,
820 typename StatisticType,
822 template<
typename BoundMetricType,
typename...>
class BoundType,
823 template<
typename SplitBoundType,
typename SplitMatType>
825 inline size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
834 template<
typename MetricType,
835 typename StatisticType,
837 template<
typename BoundMetricType,
typename...>
class BoundType,
838 template<
typename SplitBoundType,
typename SplitMatType>
840 inline size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
843 return (begin + index);
849 template<
typename MetricType,
850 typename StatisticType,
852 template<
typename BoundMetricType,
typename...>
class BoundType,
853 template<
typename SplitBoundType,
typename SplitMatType>
855 inline size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
858 return (begin + index);
861 template<
typename MetricType,
862 typename StatisticType,
864 template<
typename BoundMetricType,
typename...>
class BoundType,
865 template<
typename SplitBoundType,
typename SplitMatType>
869 SplitType<BoundType<MetricType>, MatType>& splitter)
875 furthestDescendantDistance = 0.5 * bound.Diameter();
878 if (count <= maxLeafSize)
887 typename Split::SplitInfo splitInfo;
889 const bool split = splitter.SplitNode(bound, *dataset, begin, count,
900 splitCol = splitter.PerformSplit(*dataset, begin, count, splitInfo);
902 assert(splitCol > begin);
903 assert(splitCol < begin + count);
910 splitter, maxLeafSize);
913 arma::vec center, leftCenter, rightCenter;
916 right->
Center(rightCenter);
918 const ElemType leftParentDistance = bound.Metric().Evaluate(center,
920 const ElemType rightParentDistance = bound.Metric().Evaluate(center,
927 template<
typename MetricType,
928 typename StatisticType,
930 template<
typename BoundMetricType,
typename...>
class BoundType,
931 template<
typename SplitBoundType,
typename SplitMatType>
934 SplitNode(std::vector<size_t>& oldFromNew,
935 const size_t maxLeafSize,
936 SplitType<BoundType<MetricType>, MatType>& splitter)
942 furthestDescendantDistance = 0.5 * bound.Diameter();
945 if (count <= maxLeafSize)
954 typename Split::SplitInfo splitInfo;
956 const bool split = splitter.SplitNode(bound, *dataset, begin, count,
967 splitCol = splitter.PerformSplit(*dataset, begin, count, splitInfo,
970 assert(splitCol > begin);
971 assert(splitCol < begin + count);
976 splitter, maxLeafSize);
978 oldFromNew, splitter, maxLeafSize);
981 arma::vec center, leftCenter, rightCenter;
984 right->
Center(rightCenter);
986 const ElemType leftParentDistance = bound.Metric().Evaluate(center,
988 const ElemType rightParentDistance = bound.Metric().Evaluate(center,
995 template<
typename MetricType,
996 typename StatisticType,
998 template<
typename BoundMetricType,
typename...>
class BoundType,
999 template<
typename SplitBoundType,
typename SplitMatType>
1001 template<
typename BoundType2>
1006 boundToUpdate |= dataset->cols(begin, begin + count - 1);
1009 template<
typename MetricType,
1010 typename StatisticType,
1012 template<
typename BoundMetricType,
typename...>
class BoundType,
1013 template<
typename SplitBoundType,
typename SplitMatType>
1021 boundToUpdate |= dataset->cols(begin, begin + count - 1);
1025 if (parent->left != NULL && parent->left !=
this)
1027 boundToUpdate.
HollowCenter() = parent->left->bound.Center();
1028 boundToUpdate.
InnerRadius() = std::numeric_limits<ElemType>::max();
1032 boundToUpdate |= dataset->cols(begin, begin + count - 1);
1036 template<
typename MetricType,
1037 typename StatisticType,
1039 template<
typename BoundMetricType,
typename...>
class BoundType,
1040 template<
typename SplitBoundType,
typename SplitMatType>
1051 furthestDescendantDistance(0),
1060 template<
typename MetricType,
1061 typename StatisticType,
1063 template<
typename BoundMetricType,
typename...>
class BoundType,
1064 template<
typename SplitBoundType,
typename SplitMatType>
1066 template<
typename Archive>
1071 if (cereal::is_loading<Archive>())
1085 ar(CEREAL_NVP(begin));
1086 ar(CEREAL_NVP(count));
1087 ar(CEREAL_NVP(bound));
1088 ar(CEREAL_NVP(stat));
1090 ar(CEREAL_NVP(parentDistance));
1091 ar(CEREAL_NVP(furthestDescendantDistance));
1094 bool hasLeft = (left != NULL);
1095 bool hasRight = (right != NULL);
1096 bool hasParent = (parent != NULL);
1098 ar(CEREAL_NVP(hasLeft));
1099 ar(CEREAL_NVP(hasRight));
1100 ar(CEREAL_NVP(hasParent));
1107 MatType*& datasetTemp =
const_cast<MatType*&
>(dataset);
1111 if (cereal::is_loading<Archive>())
1114 left->parent =
this;
1116 right->parent =
this;
1121 std::stack<BinarySpaceTree*> stack;
1126 while (!stack.empty())
1130 node->dataset = dataset;
1132 stack.push(node->left);
1134 stack.push(node->right);
BinarySpaceTree * Parent() const
Gets the parent of this node.
Definition: binary_space_tree.hpp:342
size_t NumDescendants() const
Return the number of descendants of this node.
Definition: binary_space_tree_impl.hpp:826
void serialize(Archive &ar, const uint32_t version)
Serialize the tree.
Definition: binary_space_tree_impl.hpp:1068
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
BinarySpaceTree & operator=(const BinarySpaceTree &other)
Copy the given BinarySpaceTree.
Definition: binary_space_tree_impl.hpp:393
MatType::elem_type ElemType
The type of element held in MatType.
Definition: binary_space_tree.hpp:60
size_t Descendant(const size_t index) const
Return the index (with reference to the dataset) of a particular descendant of this node...
Definition: binary_space_tree_impl.hpp:841
size_t Count() const
Return the number of points in this subset.
Definition: binary_space_tree.hpp:503
Definition: pointer_wrapper.hpp:23
const arma::Col< ElemType > & HollowCenter() const
Get the center point of the hollow.
Definition: hollow_ball_bound.hpp:111
BinarySpaceTree * Right() const
Gets the right child of this node.
Definition: binary_space_tree.hpp:337
ElemType MinimumBoundDistance() const
Return the minimum distance from the center of the node to any bound edge.
Definition: binary_space_tree_impl.hpp:773
A binary space partitioning tree, such as a KD-tree or a ball tree.
Definition: binary_space_tree.hpp:54
~BinarySpaceTree()
Deletes this node, deallocating the memory for the children and calling their destructors in turn...
Definition: binary_space_tree_impl.hpp:577
ElemType MinDistance(const BinarySpaceTree &other) const
Return the minimum distance to another node.
Definition: binary_space_tree.hpp:453
size_t Point(const size_t index) const
Return the index (with reference to the dataset) of a particular point in this node.
Definition: binary_space_tree_impl.hpp:856
ElemType ParentDistance() const
Return the distance from the center of this node to the center of the parent node.
Definition: binary_space_tree.hpp:407
ElemType MaxDistance(const BinarySpaceTree &other) const
Return the maximum distance to another node.
Definition: binary_space_tree.hpp:459
size_t NumPoints() const
Return the number of points in this node (0 if not a leaf).
Definition: binary_space_tree_impl.hpp:808
size_t GetFurthestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the furthest child node to the given query point.
Definition: binary_space_tree_impl.hpp:655
BinarySpaceTree * Left() const
Gets the left child of this node.
Definition: binary_space_tree.hpp:332
const MatType & Dataset() const
Get the dataset which the tree is built on.
Definition: binary_space_tree.hpp:347
BinarySpaceTree & Child(const size_t child) const
Return the specified child (0 will be left, 1 will be right).
Definition: binary_space_tree_impl.hpp:790
bool IsLeaf() const
Return whether or not this node is a leaf (true if it has no children).
Definition: binary_space_tree_impl.hpp:594
BinarySpaceTree()
A default constructor.
Definition: binary_space_tree_impl.hpp:1043
Definition of generalized binary space partitioning tree (BinarySpaceTree).
size_t GetNearestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the nearest child node to the given query point.
Definition: binary_space_tree_impl.hpp:631
void Center(arma::vec &center) const
Store the center of the bounding region in the given vector.
Definition: binary_space_tree.hpp:508
size_t Begin() const
Return the index of the beginning point of this subset.
Definition: binary_space_tree.hpp:498
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_wrapper.hpp:96
ElemType FurthestPointDistance() const
Return the furthest distance to a point held in this node.
Definition: binary_space_tree_impl.hpp:731
ElemType InnerRadius() const
Get the inner radius of the ball.
Definition: hollow_ball_bound.hpp:101
Hollow ball bound encloses a set of points at a specific distance (radius) from a specific point (cen...
Definition: hollow_ball_bound.hpp:33
ElemType FurthestDescendantDistance() const
Return the furthest possible descendant distance.
Definition: binary_space_tree_impl.hpp:757
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38
size_t NumChildren() const
Return the number of children in this node.
Definition: binary_space_tree_impl.hpp:609