12 #ifndef MLPACK_CORE_TREE_OCTREE_OCTREE_IMPL_HPP 13 #define MLPACK_CORE_TREE_OCTREE_OCTREE_IMPL_HPP 23 template<
typename MetricType,
typename StatisticType,
typename MatType>
25 const size_t maxLeafSize) :
27 count(dataset.n_cols),
28 bound(dataset.n_rows),
29 dataset(new MatType(dataset)),
36 bound |= *this->dataset;
40 double maxWidth = 0.0;
41 for (
size_t i = 0; i < bound.
Dim(); ++i)
42 if (bound[i].Hi() - bound[i].Lo() > maxWidth)
43 maxWidth = bound[i].Hi() - bound[i].Lo();
45 SplitNode(center, maxWidth, maxLeafSize);
47 furthestDescendantDistance = 0.5 * bound.
Diameter();
51 furthestDescendantDistance = 0.0;
55 stat = StatisticType(*
this);
59 template<
typename MetricType,
typename StatisticType,
typename MatType>
61 const MatType& dataset,
62 std::vector<size_t>& oldFromNew,
63 const size_t maxLeafSize) :
65 count(dataset.n_cols),
66 bound(dataset.n_rows),
67 dataset(new MatType(dataset)),
71 oldFromNew.resize(this->dataset->n_cols);
72 for (
size_t i = 0; i < this->dataset->n_cols; ++i)
78 bound |= *this->dataset;
82 double maxWidth = 0.0;
83 for (
size_t i = 0; i < bound.
Dim(); ++i)
84 if (bound[i].Hi() - bound[i].Lo() > maxWidth)
85 maxWidth = bound[i].Hi() - bound[i].Lo();
87 SplitNode(center, maxWidth, oldFromNew, maxLeafSize);
89 furthestDescendantDistance = 0.5 * bound.
Diameter();
93 furthestDescendantDistance = 0.0;
97 stat = StatisticType(*
this);
101 template<
typename MetricType,
typename StatisticType,
typename MatType>
103 const MatType& dataset,
104 std::vector<size_t>& oldFromNew,
105 std::vector<size_t>& newFromOld,
106 const size_t maxLeafSize) :
108 count(dataset.n_cols),
109 bound(dataset.n_rows),
110 dataset(new MatType(dataset)),
114 oldFromNew.resize(this->dataset->n_cols);
115 for (
size_t i = 0; i < this->dataset->n_cols; ++i)
121 bound |= *this->dataset;
125 double maxWidth = 0.0;
126 for (
size_t i = 0; i < bound.
Dim(); ++i)
127 if (bound[i].Hi() - bound[i].Lo() > maxWidth)
128 maxWidth = bound[i].Hi() - bound[i].Lo();
130 SplitNode(center, maxWidth, oldFromNew, maxLeafSize);
132 furthestDescendantDistance = 0.5 * bound.
Diameter();
136 furthestDescendantDistance = 0.0;
140 stat = StatisticType(*
this);
143 newFromOld.resize(this->dataset->n_cols);
144 for (
size_t i = 0; i < this->dataset->n_cols; ++i)
145 newFromOld[oldFromNew[i]] = i;
149 template<
typename MetricType,
typename StatisticType,
typename MatType>
151 const size_t maxLeafSize) :
153 count(dataset.n_cols),
154 bound(dataset.n_rows),
155 dataset(new MatType(
std::move(dataset))),
162 bound |= *this->dataset;
166 double maxWidth = 0.0;
167 for (
size_t i = 0; i < bound.
Dim(); ++i)
168 if (bound[i].Hi() - bound[i].Lo() > maxWidth)
169 maxWidth = bound[i].Hi() - bound[i].Lo();
171 SplitNode(center, maxWidth, maxLeafSize);
173 furthestDescendantDistance = 0.5 * bound.
Diameter();
177 furthestDescendantDistance = 0.0;
181 stat = StatisticType(*
this);
185 template<
typename MetricType,
typename StatisticType,
typename MatType>
188 std::vector<size_t>& oldFromNew,
189 const size_t maxLeafSize) :
191 count(dataset.n_cols),
192 bound(dataset.n_rows),
193 dataset(new MatType(
std::move(dataset))),
197 oldFromNew.resize(this->dataset->n_cols);
198 for (
size_t i = 0; i < this->dataset->n_cols; ++i)
204 bound |= *this->dataset;
208 double maxWidth = 0.0;
209 for (
size_t i = 0; i < bound.
Dim(); ++i)
210 if (bound[i].Hi() - bound[i].Lo() > maxWidth)
211 maxWidth = bound[i].Hi() - bound[i].Lo();
213 SplitNode(center, maxWidth, oldFromNew, maxLeafSize);
215 furthestDescendantDistance = 0.5 * bound.
Diameter();
219 furthestDescendantDistance = 0.0;
223 stat = StatisticType(*
this);
227 template<
typename MetricType,
typename StatisticType,
typename MatType>
230 std::vector<size_t>& oldFromNew,
231 std::vector<size_t>& newFromOld,
232 const size_t maxLeafSize) :
234 count(dataset.n_cols),
235 bound(dataset.n_rows),
236 dataset(new MatType(
std::move(dataset))),
240 oldFromNew.resize(this->dataset->n_cols);
241 for (
size_t i = 0; i < this->dataset->n_cols; ++i)
247 bound |= *this->dataset;
251 double maxWidth = 0.0;
252 for (
size_t i = 0; i < bound.
Dim(); ++i)
253 if (bound[i].Hi() - bound[i].Lo() > maxWidth)
254 maxWidth = bound[i].Hi() - bound[i].Lo();
256 SplitNode(center, maxWidth, oldFromNew, maxLeafSize);
258 furthestDescendantDistance = 0.5 * bound.
Diameter();
262 furthestDescendantDistance = 0.0;
266 stat = StatisticType(*
this);
269 newFromOld.resize(this->dataset->n_cols);
270 for (
size_t i = 0; i < this->dataset->n_cols; ++i)
271 newFromOld[oldFromNew[i]] = i;
275 template<
typename MetricType,
typename StatisticType,
typename MatType>
280 const arma::vec& center,
282 const size_t maxLeafSize) :
285 bound(parent->dataset->n_rows),
286 dataset(parent->dataset),
290 bound |= dataset->cols(begin, begin + count - 1);
293 SplitNode(center, width, maxLeafSize);
297 arma::vec trueCenter, parentCenter;
300 parentDistance = metric.Evaluate(trueCenter, parentCenter);
302 furthestDescendantDistance = 0.5 * bound.
Diameter();
305 stat = StatisticType(*
this);
309 template<
typename MetricType,
typename StatisticType,
typename MatType>
314 std::vector<size_t>& oldFromNew,
315 const arma::vec& center,
317 const size_t maxLeafSize) :
320 bound(parent->dataset->n_rows),
321 dataset(parent->dataset),
325 bound |= dataset->cols(begin, begin + count - 1);
328 SplitNode(center, width, oldFromNew, maxLeafSize);
332 arma::vec trueCenter, parentCenter;
335 parentDistance = metric.Evaluate(trueCenter, parentCenter);
337 furthestDescendantDistance = 0.5 * bound.
Diameter();
340 stat = StatisticType(*
this);
344 template<
typename MetricType,
typename StatisticType,
typename MatType>
349 dataset((other.parent == NULL) ? new MatType(*other.dataset) : NULL),
352 parentDistance(other.parentDistance),
353 furthestDescendantDistance(other.furthestDescendantDistance),
361 children[i]->parent =
this;
362 children[i]->dataset = this->dataset;
367 template<
typename MetricType,
typename StatisticType,
typename MatType>
378 for (
size_t i = 0; i < children.size(); ++i)
382 begin = other.Begin();
383 count = other.Count();
385 dataset = ((other.parent == NULL) ?
new MatType(*other.dataset) : NULL);
390 metric = other.metric;
397 children[i]->parent =
this;
398 children[i]->dataset = this->dataset;
404 template<
typename MetricType,
typename StatisticType,
typename MatType>
406 children(
std::move(other.children)),
409 bound(
std::move(other.bound)),
410 dataset(other.dataset),
411 parent(other.parent),
412 stat(
std::move(other.stat)),
413 parentDistance(other.parentDistance),
414 furthestDescendantDistance(other.furthestDescendantDistance),
415 metric(
std::move(other.metric))
418 for (
size_t i = 0; i < children.size(); ++i)
419 children[i]->parent =
this;
423 other.dataset =
new MatType();
424 other.parentDistance = 0.0;
425 other.furthestDescendantDistance = 0.0;
430 template<
typename MetricType,
typename StatisticType,
typename MatType>
441 for (
size_t i = 0; i < children.size(); ++i)
445 children = std::move(other.children);
446 begin = other.Begin();
447 count = other.Count();
448 bound = std::move(other.bound);
449 dataset = other.dataset;
451 stat = std::move(other.stat);
452 parentDistance = other.ParentDistance();
453 furthestDescendantDistance = other.furthestDescendantDistance();
454 metric = std::move(other.metric);
457 for (
size_t i = 0; i < children.size(); ++i)
458 children[i]->parent =
this;
462 other.dataset =
new MatType();
463 other.parentDistance = 0.0;
464 other.numDescendants = 0;
465 other.furthestDescendantDistance = 0.0;
471 template<
typename MetricType,
typename StatisticType,
typename MatType>
476 dataset(new MatType()),
479 furthestDescendantDistance(0.0)
484 template<
typename MetricType,
typename StatisticType,
typename MatType>
485 template<
typename Archive>
488 const typename std::enable_if_t<cereal::is_loading<Archive>()>*) :
492 ar(CEREAL_NVP(*
this));
495 template<
typename MetricType,
typename StatisticType,
typename MatType>
503 for (
size_t i = 0; i < children.size(); ++i)
508 template<
typename MetricType,
typename StatisticType,
typename MatType>
511 return children.size();
514 template<
typename MetricType,
typename StatisticType,
typename MatType>
515 template<
typename VecType>
517 const VecType& point,
526 const double dist = children[i]->MinDistance(point);
527 if (dist < bestDistance)
537 template<
typename MetricType,
typename StatisticType,
typename MatType>
538 template<
typename VecType>
540 const VecType& point,
549 const double dist = children[i]->MaxDistance(point);
550 if (dist > bestDistance)
560 template<
typename MetricType,
typename StatisticType,
typename MatType>
562 const Octree& queryNode)
const 570 const double dist = children[i]->MinDistance(queryNode);
571 if (dist < bestDistance)
581 template<
typename MetricType,
typename StatisticType,
typename MatType>
583 const Octree& queryNode)
const 591 const double dist = children[i]->MaxDistance(queryNode);
592 if (dist > bestDistance)
602 template<
typename MetricType,
typename StatisticType,
typename MatType>
609 return (children.size() > 0) ? 0.0 : furthestDescendantDistance;
612 template<
typename MetricType,
typename StatisticType,
typename MatType>
616 return furthestDescendantDistance;
619 template<
typename MetricType,
typename StatisticType,
typename MatType>
626 template<
typename MetricType,
typename StatisticType,
typename MatType>
630 return (children.size() > 0) ? 0 : count;
633 template<
typename MetricType,
typename StatisticType,
typename MatType>
639 template<
typename MetricType,
typename StatisticType,
typename MatType>
641 const size_t index)
const 643 return begin + index;
646 template<
typename MetricType,
typename StatisticType,
typename MatType>
650 return begin + index;
653 template<
typename MetricType,
typename StatisticType,
typename MatType>
661 template<
typename MetricType,
typename StatisticType,
typename MatType>
669 template<
typename MetricType,
typename StatisticType,
typename MatType>
677 template<
typename MetricType,
typename StatisticType,
typename MatType>
678 template<
typename VecType>
681 const VecType& point,
687 template<
typename MetricType,
typename StatisticType,
typename MatType>
688 template<
typename VecType>
691 const VecType& point,
698 template<
typename MetricType,
typename StatisticType,
typename MatType>
699 template<
typename VecType>
702 const VecType& point,
709 template<
typename MetricType,
typename StatisticType,
typename MatType>
710 template<
typename Archive>
716 if (cereal::is_loading<Archive>())
718 for (
size_t i = 0; i < children.size(); ++i)
728 bool hasParent = (parent != NULL);
730 ar(CEREAL_NVP(begin));
731 ar(CEREAL_NVP(count));
732 ar(CEREAL_NVP(bound));
733 ar(CEREAL_NVP(stat));
734 ar(CEREAL_NVP(parentDistance));
735 ar(CEREAL_NVP(furthestDescendantDistance));
736 ar(CEREAL_NVP(metric));
737 ar(CEREAL_NVP(hasParent));
740 MatType*& datasetTemp =
const_cast<MatType*&
>(dataset);
746 if (cereal::is_loading<Archive>())
748 for (
size_t i = 0; i < children.size(); ++i)
749 children[i]->parent =
this;
755 std::stack<Octree*> stack;
756 for (
size_t i = 0; i < children.size(); ++i)
758 stack.push(children[i]);
760 while (!stack.empty())
762 Octree* node = stack.top();
764 node->dataset = dataset;
765 for (
size_t i = 0; i < node->children.size(); ++i)
767 stack.push(node->children[i]);
774 template<
typename MetricType,
typename StatisticType,
typename MatType>
776 const arma::vec& center,
778 const size_t maxLeafSize)
782 if (count <= maxLeafSize)
786 arma::Col<size_t> childBegins(((
size_t) 1 << dataset->n_rows) + 1);
787 childBegins[0] = begin;
788 childBegins[childBegins.n_elem - 1] = begin + count;
792 std::stack<std::tuple<size_t, size_t, size_t, size_t>> stack;
793 stack.push(std::tuple<size_t, size_t, size_t, size_t>(dataset->n_rows - 1,
796 while (!stack.empty())
798 std::tuple<size_t, size_t, size_t, size_t> t = stack.top();
801 const size_t d = std::get<0>(t);
802 const size_t childBegin = std::get<1>(t);
803 const size_t childCount = std::get<2>(t);
804 const size_t leftChildIndex = std::get<3>(t);
811 const size_t firstRight = split::PerformSplit<MatType, SplitType>(*dataset,
812 childBegin, childCount, s);
816 const size_t rightChildIndex = leftChildIndex + ((size_t) 1 << d);
817 childBegins[rightChildIndex] = firstRight;
822 if (firstRight > childBegin)
824 stack.push(std::tuple<size_t, size_t, size_t, size_t>(d - 1, childBegin,
825 firstRight - childBegin, leftChildIndex));
830 for (
size_t c = leftChildIndex + 1; c < rightChildIndex; ++c)
831 childBegins[c] = childBegins[leftChildIndex];
834 if (firstRight < childBegin + childCount)
836 stack.push(std::tuple<size_t, size_t, size_t, size_t>(d - 1, firstRight,
837 childCount - (firstRight - childBegin), rightChildIndex));
842 for (
size_t c = rightChildIndex + 1;
843 c < rightChildIndex + (rightChildIndex - leftChildIndex); ++c)
844 childBegins[c] = childBegins[rightChildIndex];
850 arma::vec childCenter(center.n_elem);
851 const double childWidth = width / 2.0;
852 for (
size_t i = 0; i < childBegins.n_elem - 1; ++i)
855 if (childBegins[i + 1] - childBegins[i] == 0)
859 for (
size_t d = 0; d < center.n_elem; ++d)
862 if (((i >> d) & 1) == 0)
863 childCenter[d] = center[d] - childWidth;
865 childCenter[d] = center[d] + childWidth;
868 children.push_back(
new Octree(
this, childBegins[i],
869 childBegins[i + 1] - childBegins[i], childCenter, childWidth,
875 template<
typename MetricType,
typename StatisticType,
typename MatType>
877 const arma::vec& center,
879 std::vector<size_t>& oldFromNew,
880 const size_t maxLeafSize)
884 if (count <= maxLeafSize)
888 arma::Col<size_t> childBegins(((
size_t) 1 << dataset->n_rows) + 1);
889 childBegins[0] = begin;
890 childBegins[childBegins.n_elem - 1] = begin + count;
894 std::stack<std::tuple<size_t, size_t, size_t, size_t>> stack;
895 stack.push(std::tuple<size_t, size_t, size_t, size_t>(dataset->n_rows - 1,
898 while (!stack.empty())
900 std::tuple<size_t, size_t, size_t, size_t> t = stack.top();
903 const size_t d = std::get<0>(t);
904 const size_t childBegin = std::get<1>(t);
905 const size_t childCount = std::get<2>(t);
906 const size_t leftChildIndex = std::get<3>(t);
913 const size_t firstRight = split::PerformSplit<MatType, SplitType>(*dataset,
914 childBegin, childCount, s, oldFromNew);
918 const size_t rightChildIndex = leftChildIndex + ((size_t) 1 << d);
919 childBegins[rightChildIndex] = firstRight;
924 if (firstRight > childBegin)
926 stack.push(std::tuple<size_t, size_t, size_t, size_t>(d - 1, childBegin,
927 firstRight - childBegin, leftChildIndex));
932 for (
size_t c = leftChildIndex + 1; c < rightChildIndex; ++c)
933 childBegins[c] = childBegins[leftChildIndex];
936 if (firstRight < childBegin + childCount)
938 stack.push(std::tuple<size_t, size_t, size_t, size_t>(d - 1, firstRight,
939 childCount - (firstRight - childBegin), rightChildIndex));
944 for (
size_t c = rightChildIndex + 1;
945 c < rightChildIndex + (rightChildIndex - leftChildIndex); ++c)
946 childBegins[c] = childBegins[rightChildIndex];
952 arma::vec childCenter(center.n_elem);
953 const double childWidth = width / 2.0;
954 for (
size_t i = 0; i < childBegins.n_elem - 1; ++i)
957 if (childBegins[i + 1] - childBegins[i] == 0)
961 for (
size_t d = 0; d < center.n_elem; ++d)
964 if (((i >> d) & 1) == 0)
965 childCenter[d] = center[d] - childWidth;
967 childCenter[d] = center[d] + childWidth;
970 children.push_back(
new Octree(
this, childBegins[i],
971 childBegins[i + 1] - childBegins[i], oldFromNew, childCenter,
972 childWidth, maxLeafSize));
math::RangeType< ElemType > RangeDistance(const HRectBound &other) const
Calculates minimum and maximum bound-to-bound distance.
Definition: hrectbound_impl.hpp:391
ElemType MinWidth() const
Get the minimum width of the bound.
Definition: hrectbound.hpp:106
size_t GetNearestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Return the index of the nearest child node to the given query point.
Definition: octree_impl.hpp:516
math::RangeType< ElemType > RangeDistance(const Octree &other) const
Return the minimum and maximum distance to another node.
Definition: octree_impl.hpp:671
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
~Octree()
Destroy the tree.
Definition: octree_impl.hpp:496
size_t NumPoints() const
Return the number of points in this node (0 if not a leaf).
Definition: octree_impl.hpp:627
Octree()
A default constructor.
Definition: octree_impl.hpp:472
Definition: pointer_wrapper.hpp:23
ElemType Diameter() const
Returns the diameter of the hyperrectangle (that is, the longest diagonal).
Definition: hrectbound_impl.hpp:669
size_t NumChildren() const
Return the number of children in this node.
Definition: octree_impl.hpp:509
ElemType MaxDistance(const Octree &other) const
Return the maximum distance to another node.
Definition: octree_impl.hpp:663
ElemType MinDistance(const Octree &other) const
Return the minimum distance to another node.
Definition: octree_impl.hpp:655
ElemType MaxDistance(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Calculates maximum bound-to-point squared distance.
Definition: hrectbound_impl.hpp:309
ElemType MinDistance(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Calculates minimum bound-to-point distance.
Definition: hrectbound_impl.hpp:189
Simple real-valued range.
Definition: range.hpp:19
MatType::elem_type ElemType
The type of element held in MatType.
Definition: octree.hpp:31
size_t Descendant(const size_t index) const
Return the index (with reference to the dataset) of a particular descendant.
Definition: octree_impl.hpp:640
Octree & operator=(const Octree &other)
Copy the given Octree.
Definition: octree_impl.hpp:370
ElemType ParentDistance() const
Return the distance from the center of this node to the center of the parent node.
Definition: octree.hpp:331
size_t Point(const size_t index) const
Return the index (with reference to the dataset) of a particular point in this node.
Definition: octree_impl.hpp:647
void Center(arma::Col< ElemType > &center) const
Calculates the center of the range, placing it into the given vector.
Definition: hrectbound_impl.hpp:153
const Octree & Child(const size_t child) const
Return the specified child.
Definition: octree.hpp:340
#define CEREAL_VECTOR_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_vector_wrapper.hpp:93
size_t GetFurthestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Return the index of the furthest child node to the given query point.
Definition: octree_impl.hpp:539
Definition: octree.hpp:447
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_wrapper.hpp:96
ElemType MinimumBoundDistance() const
Return the minimum distance from the center of the node to any bound edge.
Definition: octree_impl.hpp:621
const bound::HRectBound< MetricType > & Bound() const
Return the bound object for this node.
Definition: octree.hpp:261
void serialize(Archive &ar, const uint32_t)
Serialize the tree.
Definition: octree_impl.hpp:711
size_t NumDescendants() const
Return the number of descendants of this node.
Definition: octree_impl.hpp:634
size_t Dim() const
Gets the dimensionality.
Definition: hrectbound.hpp:96
Definition: octree.hpp:25
ElemType FurthestPointDistance() const
Return the furthest distance to a point held in this node.
Definition: octree_impl.hpp:604
Octree * Parent() const
Get the pointer to the parent.
Definition: octree.hpp:256
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35
ElemType FurthestDescendantDistance() const
Return the furthest possible descendant distance.
Definition: octree_impl.hpp:614