mlpack
octree_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_TREE_OCTREE_OCTREE_IMPL_HPP
13 #define MLPACK_CORE_TREE_OCTREE_OCTREE_IMPL_HPP
14 
15 #include "octree.hpp"
17 #include <stack>
18 
19 namespace mlpack {
20 namespace tree {
21 
23 template<typename MetricType, typename StatisticType, typename MatType>
25  const size_t maxLeafSize) :
26  begin(0),
27  count(dataset.n_cols),
28  bound(dataset.n_rows),
29  dataset(new MatType(dataset)),
30  parent(NULL),
31  parentDistance(0.0)
32 {
33  if (count > 0)
34  {
35  // Calculate empirical center of data.
36  bound |= *this->dataset;
37  arma::vec center;
38  bound.Center(center);
39 
40  double maxWidth = 0.0;
41  for (size_t i = 0; i < bound.Dim(); ++i)
42  if (bound[i].Hi() - bound[i].Lo() > maxWidth)
43  maxWidth = bound[i].Hi() - bound[i].Lo();
44 
45  SplitNode(center, maxWidth, maxLeafSize);
46 
47  furthestDescendantDistance = 0.5 * bound.Diameter();
48  }
49  else
50  {
51  furthestDescendantDistance = 0.0;
52  }
53 
54  // Initialize the statistic.
55  stat = StatisticType(*this);
56 }
57 
59 template<typename MetricType, typename StatisticType, typename MatType>
61  const MatType& dataset,
62  std::vector<size_t>& oldFromNew,
63  const size_t maxLeafSize) :
64  begin(0),
65  count(dataset.n_cols),
66  bound(dataset.n_rows),
67  dataset(new MatType(dataset)),
68  parent(NULL),
69  parentDistance(0.0)
70 {
71  oldFromNew.resize(this->dataset->n_cols);
72  for (size_t i = 0; i < this->dataset->n_cols; ++i)
73  oldFromNew[i] = i;
74 
75  if (count > 0)
76  {
77  // Calculate empirical center of data.
78  bound |= *this->dataset;
79  arma::vec center;
80  bound.Center(center);
81 
82  double maxWidth = 0.0;
83  for (size_t i = 0; i < bound.Dim(); ++i)
84  if (bound[i].Hi() - bound[i].Lo() > maxWidth)
85  maxWidth = bound[i].Hi() - bound[i].Lo();
86 
87  SplitNode(center, maxWidth, oldFromNew, maxLeafSize);
88 
89  furthestDescendantDistance = 0.5 * bound.Diameter();
90  }
91  else
92  {
93  furthestDescendantDistance = 0.0;
94  }
95 
96  // Initialize the statistic.
97  stat = StatisticType(*this);
98 }
99 
101 template<typename MetricType, typename StatisticType, typename MatType>
103  const MatType& dataset,
104  std::vector<size_t>& oldFromNew,
105  std::vector<size_t>& newFromOld,
106  const size_t maxLeafSize) :
107  begin(0),
108  count(dataset.n_cols),
109  bound(dataset.n_rows),
110  dataset(new MatType(dataset)),
111  parent(NULL),
112  parentDistance(0.0)
113 {
114  oldFromNew.resize(this->dataset->n_cols);
115  for (size_t i = 0; i < this->dataset->n_cols; ++i)
116  oldFromNew[i] = i;
117 
118  if (count > 0)
119  {
120  // Calculate empirical center of data.
121  bound |= *this->dataset;
122  arma::vec center;
123  bound.Center(center);
124 
125  double maxWidth = 0.0;
126  for (size_t i = 0; i < bound.Dim(); ++i)
127  if (bound[i].Hi() - bound[i].Lo() > maxWidth)
128  maxWidth = bound[i].Hi() - bound[i].Lo();
129 
130  SplitNode(center, maxWidth, oldFromNew, maxLeafSize);
131 
132  furthestDescendantDistance = 0.5 * bound.Diameter();
133  }
134  else
135  {
136  furthestDescendantDistance = 0.0;
137  }
138 
139  // Initialize the statistic.
140  stat = StatisticType(*this);
141 
142  // Map the newFromOld indices correctly.
143  newFromOld.resize(this->dataset->n_cols);
144  for (size_t i = 0; i < this->dataset->n_cols; ++i)
145  newFromOld[oldFromNew[i]] = i;
146 }
147 
149 template<typename MetricType, typename StatisticType, typename MatType>
151  const size_t maxLeafSize) :
152  begin(0),
153  count(dataset.n_cols),
154  bound(dataset.n_rows),
155  dataset(new MatType(std::move(dataset))),
156  parent(NULL),
157  parentDistance(0.0)
158 {
159  if (count > 0)
160  {
161  // Calculate empirical center of data.
162  bound |= *this->dataset;
163  arma::vec center;
164  bound.Center(center);
165 
166  double maxWidth = 0.0;
167  for (size_t i = 0; i < bound.Dim(); ++i)
168  if (bound[i].Hi() - bound[i].Lo() > maxWidth)
169  maxWidth = bound[i].Hi() - bound[i].Lo();
170 
171  SplitNode(center, maxWidth, maxLeafSize);
172 
173  furthestDescendantDistance = 0.5 * bound.Diameter();
174  }
175  else
176  {
177  furthestDescendantDistance = 0.0;
178  }
179 
180  // Initialize the statistic.
181  stat = StatisticType(*this);
182 }
183 
185 template<typename MetricType, typename StatisticType, typename MatType>
187  MatType&& dataset,
188  std::vector<size_t>& oldFromNew,
189  const size_t maxLeafSize) :
190  begin(0),
191  count(dataset.n_cols),
192  bound(dataset.n_rows),
193  dataset(new MatType(std::move(dataset))),
194  parent(NULL),
195  parentDistance(0.0)
196 {
197  oldFromNew.resize(this->dataset->n_cols);
198  for (size_t i = 0; i < this->dataset->n_cols; ++i)
199  oldFromNew[i] = i;
200 
201  if (count > 0)
202  {
203  // Calculate empirical center of data.
204  bound |= *this->dataset;
205  arma::vec center;
206  bound.Center(center);
207 
208  double maxWidth = 0.0;
209  for (size_t i = 0; i < bound.Dim(); ++i)
210  if (bound[i].Hi() - bound[i].Lo() > maxWidth)
211  maxWidth = bound[i].Hi() - bound[i].Lo();
212 
213  SplitNode(center, maxWidth, oldFromNew, maxLeafSize);
214 
215  furthestDescendantDistance = 0.5 * bound.Diameter();
216  }
217  else
218  {
219  furthestDescendantDistance = 0.0;
220  }
221 
222  // Initialize the statistic.
223  stat = StatisticType(*this);
224 }
225 
227 template<typename MetricType, typename StatisticType, typename MatType>
229  MatType&& dataset,
230  std::vector<size_t>& oldFromNew,
231  std::vector<size_t>& newFromOld,
232  const size_t maxLeafSize) :
233  begin(0),
234  count(dataset.n_cols),
235  bound(dataset.n_rows),
236  dataset(new MatType(std::move(dataset))),
237  parent(NULL),
238  parentDistance(0.0)
239 {
240  oldFromNew.resize(this->dataset->n_cols);
241  for (size_t i = 0; i < this->dataset->n_cols; ++i)
242  oldFromNew[i] = i;
243 
244  if (count > 0)
245  {
246  // Calculate empirical center of data.
247  bound |= *this->dataset;
248  arma::vec center;
249  bound.Center(center);
250 
251  double maxWidth = 0.0;
252  for (size_t i = 0; i < bound.Dim(); ++i)
253  if (bound[i].Hi() - bound[i].Lo() > maxWidth)
254  maxWidth = bound[i].Hi() - bound[i].Lo();
255 
256  SplitNode(center, maxWidth, oldFromNew, maxLeafSize);
257 
258  furthestDescendantDistance = 0.5 * bound.Diameter();
259  }
260  else
261  {
262  furthestDescendantDistance = 0.0;
263  }
264 
265  // Initialize the statistic.
266  stat = StatisticType(*this);
267 
268  // Map the newFromOld indices correctly.
269  newFromOld.resize(this->dataset->n_cols);
270  for (size_t i = 0; i < this->dataset->n_cols; ++i)
271  newFromOld[oldFromNew[i]] = i;
272 }
273 
// Construct a non-root node holding columns [begin, begin + count) of the
// matrix owned by the root.  The matrix is not copied: `dataset` is set to
// parent->dataset, and the destructor only deletes the dataset of a node that
// has no parent.
// SplitNode() creates these nodes and skips empty children, so count > 0 here;
// cols(begin, begin + count - 1) would not be a valid range otherwise.
// `center` and `width` describe this node's cube and are only forwarded to
// SplitNode(); they are not stored.
// NOTE(review): `metric` is default-constructed rather than copied from the
// parent, which presumably assumes MetricType is stateless — confirm.
275 template<typename MetricType, typename StatisticType, typename MatType>
277  Octree* parent,
278  const size_t begin,
279  const size_t count,
280  const arma::vec& center,
281  const double width,
282  const size_t maxLeafSize) :
283  begin(begin),
284  count(count),
285  bound(parent->dataset->n_rows),
286  dataset(parent->dataset),
287  parent(parent)
288 {
289  // Fit this node's bounding box to the points it holds.
290  bound |= dataset->cols(begin, begin + count - 1);
291 
292  // Now split the node; this recursively constructs the whole subtree.
293  SplitNode(center, width, maxLeafSize);
294 
295  // Distance from the center of this node's bounding box to the center of
296  // the parent's bounding box, measured with this node's metric.
297  arma::vec trueCenter, parentCenter;
298  bound.Center(trueCenter);
299  parent->Bound().Center(parentCenter);
300  parentDistance = metric.Evaluate(trueCenter, parentCenter);
301 
// Half of the bounding box's longest diagonal bounds the distance from its
// center to any point inside it.
302  furthestDescendantDistance = 0.5 * bound.Diameter();
303 
304  // Initialize the statistic.
305  stat = StatisticType(*this);
306 }
307 
// Construct a non-root node holding columns [begin, begin + count) of the
// root's matrix.  This is identical to the mapping-free child constructor
// except that oldFromNew is forwarded to SplitNode(), which passes it to the
// point-reordering routine so the permutation stays up to date.
// The matrix is shared with the parent (dataset = parent->dataset), not
// copied.  SplitNode() never creates empty children, so count > 0 here.
// NOTE(review): `metric` is default-constructed rather than copied from the
// parent — confirm MetricType is stateless.
309 template<typename MetricType, typename StatisticType, typename MatType>
311  Octree* parent,
312  const size_t begin,
313  const size_t count,
314  std::vector<size_t>& oldFromNew,
315  const arma::vec& center,
316  const double width,
317  const size_t maxLeafSize) :
318  begin(begin),
319  count(count),
320  bound(parent->dataset->n_rows),
321  dataset(parent->dataset),
322  parent(parent)
323 {
324  // Fit this node's bounding box to the points it holds.
325  bound |= dataset->cols(begin, begin + count - 1);
326 
327  // Now split the node; this recursively constructs the whole subtree.
328  SplitNode(center, width, oldFromNew, maxLeafSize);
329 
330  // Distance from the center of this node's bounding box to the center of
331  // the parent's bounding box, measured with this node's metric.
332  arma::vec trueCenter, parentCenter;
333  bound.Center(trueCenter);
334  parent->Bound().Center(parentCenter);
335  parentDistance = metric.Evaluate(trueCenter, parentCenter);
336 
// Half of the bounding box's longest diagonal bounds the distance from its
// center to any point inside it.
337  furthestDescendantDistance = 0.5 * bound.Diameter();
338 
339  // Initialize the statistic.
340  stat = StatisticType(*this);
341 }
342 
344 template<typename MetricType, typename StatisticType, typename MatType>
346  begin(other.begin),
347  count(other.count),
348  bound(other.bound),
349  dataset((other.parent == NULL) ? new MatType(*other.dataset) : NULL),
350  parent(NULL),
351  stat(other.stat),
352  parentDistance(other.parentDistance),
353  furthestDescendantDistance(other.furthestDescendantDistance),
354  metric(other.metric)
355 {
356  // If we have any children, we need to create them, and then ensure that their
357  // parent links are set right.
358  for (size_t i = 0; i < other.NumChildren(); ++i)
359  {
360  children.push_back(new Octree(other.Child(i)));
361  children[i]->parent = this;
362  children[i]->dataset = this->dataset;
363  }
364 }
365 
367 template<typename MetricType, typename StatisticType, typename MatType>
370 operator=(const Octree& other)
371 {
372  // Return if it's the same tree.
373  if (this == &other)
374  return *this;
375 
376  // Freeing memory that will not be used anymore.
377  delete dataset;
378  for (size_t i = 0; i < children.size(); ++i)
379  delete children[i];
380  children.clear();
381 
382  begin = other.Begin();
383  count = other.Count();
384  bound = other.bound;
385  dataset = ((other.parent == NULL) ? new MatType(*other.dataset) : NULL);
386  parent = NULL;
387  stat = other.stat;
388  parentDistance = other.ParentDistance();
389  furthestDescendantDistance = other.FurthestDescendantDistance();
390  metric = other.metric;
391 
392  // If we have any children, we need to create them, and then ensure that their
393  // parent links are set right.
394  for (size_t i = 0; i < other.NumChildren(); ++i)
395  {
396  children.push_back(new Octree(other.Child(i)));
397  children[i]->parent = this;
398  children[i]->dataset = this->dataset;
399  }
400  return *this;
401 }
402 
// Move constructor: takes over other's children, dataset pointer, parent link
// and per-node state.
// Grandchildren need no fix-up: their parent pointers refer to the children,
// which are not moved, and they already hold the same dataset pointer.
// Afterwards `other` is an empty, parentless node that owns a freshly
// allocated empty matrix, so its destructor still has something valid to
// delete.
// NOTE(review): if `other` was a non-root node, its parent's children vector
// still points at `other`, not at this node; callers must fix that link.
404 template<typename MetricType, typename StatisticType, typename MatType>
406  children(std::move(other.children)),
407  begin(other.begin),
408  count(other.count),
409  bound(std::move(other.bound)),
410  dataset(other.dataset),
411  parent(other.parent),
412  stat(std::move(other.stat)),
413  parentDistance(other.parentDistance),
414  furthestDescendantDistance(other.furthestDescendantDistance),
415  metric(std::move(other.metric))
416 {
417  // Update the parent pointers of the direct children.
418  for (size_t i = 0; i < children.size(); ++i)
419  children[i]->parent = this;
420 
// Leave `other` as a valid empty root.
421  other.begin = 0;
422  other.count = 0;
423  other.dataset = new MatType();
424  other.parentDistance = 0.0;
425  other.furthestDescendantDistance = 0.0;
426  other.parent = NULL;
427 }
428 
430 template<typename MetricType, typename StatisticType, typename MatType>
434 {
435  // Return if it's the same tree.
436  if (this == &other)
437  return *this;
438 
439  // Freeing memory that will not be used anymore.
440  delete dataset;
441  for (size_t i = 0; i < children.size(); ++i)
442  delete children[i];
443  children.clear();
444 
445  children = std::move(other.children);
446  begin = other.Begin();
447  count = other.Count();
448  bound = std::move(other.bound);
449  dataset = other.dataset;
450  parent = other.Parent();
451  stat = std::move(other.stat);
452  parentDistance = other.ParentDistance();
453  furthestDescendantDistance = other.furthestDescendantDistance();
454  metric = std::move(other.metric);
455 
456  // Update the parent pointers of the direct children.
457  for (size_t i = 0; i < children.size(); ++i)
458  children[i]->parent = this;
459 
460  other.begin = 0;
461  other.count = 0;
462  other.dataset = new MatType();
463  other.parentDistance = 0.0;
464  other.numDescendants = 0;
465  other.furthestDescendantDistance = 0.0;
466  other.parent = NULL;
467 
468  return *this;
469 }
470 
// Default constructor: an empty, parentless node with no children and a
// zero-dimensional bound.  It owns a newly allocated empty matrix, which its
// destructor deletes.  Used as the starting point for deserialization.
471 template<typename MetricType, typename StatisticType, typename MatType>
473  begin(0),
474  count(0),
475  bound(0),
476  dataset(new MatType()),
477  parent(NULL),
478  parentDistance(0.0),
479  furthestDescendantDistance(0.0)
480 {
481  // Nothing to do.
482 }
483 
// Deserializing constructor: only enabled for loading archives.  It delegates
// to the default constructor and then loads the whole tree from `ar` via
// serialize().
484 template<typename MetricType, typename StatisticType, typename MatType>
485 template<typename Archive>
487  Archive& ar,
488  const typename std::enable_if_t<cereal::is_loading<Archive>()>*) :
489  Octree() // Create an empty tree.
490 {
491  // De-serialize the tree into this object.
492  ar(CEREAL_NVP(*this));
493 }
494 
// Destructor: frees the dataset if this node owns it, then recursively
// destroys the subtree.
495 template<typename MetricType, typename StatisticType, typename MatType>
497 {
498  // Only the root (no parent) owns the dataset; other nodes share its pointer.
499  if (!parent)
500  delete dataset;
501 
502  // Now delete each of the children (each one recursively frees its subtree).
503  for (size_t i = 0; i < children.size(); ++i)
504  delete children[i];
505  children.clear();
506 }
507 
// Number of direct children.  SplitNode() never creates empty children, so
// this can be smaller than 2^dimensionality.
508 template<typename MetricType, typename StatisticType, typename MatType>
510 {
511  return children.size();
512 }
513 
// Index of the child whose bound has the smallest minimum distance to `point`.
// Ties go to the lowest index.  Returns NumChildren() if there are no
// children.
// NOTE(review): bestDistance starts at DBL_MAX and `dist` is a double; if
// ElemType is float, DBL_MAX is outside float's range, so
// std::numeric_limits<ElemType>::max() and ElemType would be safer.
514 template<typename MetricType, typename StatisticType, typename MatType>
515 template<typename VecType>
517  const VecType& point,
518  typename std::enable_if_t<IsVector<VecType>::value>*) const
519 {
520  // It's possible that this could be improved by caching which children we have
521  // and which we don't, but for now this is just a brute force search.
522  ElemType bestDistance = DBL_MAX;
523  size_t bestIndex = NumChildren();
524  for (size_t i = 0; i < NumChildren(); ++i)
525  {
526  const double dist = children[i]->MinDistance(point);
527  if (dist < bestDistance)
528  {
529  bestDistance = dist;
530  bestIndex = i;
531  }
532  }
533 
534  return bestIndex;
535 }
536 
// Index of the child whose bound has the largest maximum distance to `point`.
// Ties go to the lowest index.  Returns NumChildren() if there are no
// children.
537 template<typename MetricType, typename StatisticType, typename MatType>
538 template<typename VecType>
540  const VecType& point,
541  typename std::enable_if_t<IsVector<VecType>::value>*) const
542 {
543  // It's possible that this could be improved by caching which children we have
544  // and which we don't, but for now this is just a brute force search.
545  ElemType bestDistance = -1.0; // Initialize to invalid distance.
546  size_t bestIndex = NumChildren();
547  for (size_t i = 0; i < NumChildren(); ++i)
548  {
549  const double dist = children[i]->MaxDistance(point);
550  if (dist > bestDistance)
551  {
552  bestDistance = dist;
553  bestIndex = i;
554  }
555  }
556 
557  return bestIndex;
558 }
559 
// Index of the child whose bound has the smallest minimum bound-to-bound
// distance to queryNode.  Ties go to the lowest index.  Returns NumChildren()
// if there are no children.
// NOTE(review): same DBL_MAX / double vs ElemType concern as the point
// overload.
560 template<typename MetricType, typename StatisticType, typename MatType>
562  const Octree& queryNode) const
563 {
564  // It's possible that this could be improved by caching which children we have
565  // and which we don't, but for now this is just a brute force search.
566  ElemType bestDistance = DBL_MAX;
567  size_t bestIndex = NumChildren();
568  for (size_t i = 0; i < NumChildren(); ++i)
569  {
570  const double dist = children[i]->MinDistance(queryNode);
571  if (dist < bestDistance)
572  {
573  bestDistance = dist;
574  bestIndex = i;
575  }
576  }
577 
578  return bestIndex;
579 }
580 
// Index of the child whose bound has the largest maximum bound-to-bound
// distance to queryNode.  Ties go to the lowest index.  Returns NumChildren()
// if there are no children.
581 template<typename MetricType, typename StatisticType, typename MatType>
583  const Octree& queryNode) const
584 {
585  // It's possible that this could be improved by caching which children we have
586  // and which we don't, but for now this is just a brute force search.
587  ElemType bestDistance = -1.0; // Initialize to invalid distance.
588  size_t bestIndex = NumChildren();
589  for (size_t i = 0; i < NumChildren(); ++i)
590  {
591  const double dist = children[i]->MaxDistance(queryNode);
592  if (dist > bestDistance)
593  {
594  bestDistance = dist;
595  bestIndex = i;
596  }
597  }
598 
599  return bestIndex;
600 }
601 
// Bound on the distance from this node's center to any point it holds
// directly: 0 for internal nodes (they hold no points themselves), otherwise
// the cached furthest-descendant distance.
602 template<typename MetricType, typename StatisticType, typename MatType>
605  const
606 {
607  // If we are not a leaf, then this distance is 0. Otherwise, return the
608  // furthest descendant distance.
609  return (children.size() > 0) ? 0.0 : furthestDescendantDistance;
610 }
611 
// Cached value computed at construction: half the diameter of this node's
// bounding box.
612 template<typename MetricType, typename StatisticType, typename MatType>
615 {
616  return furthestDescendantDistance;
617 }
618 
// Distance from the bounding box's center to its nearest face: half of the
// box's smallest width.
619 template<typename MetricType, typename StatisticType, typename MatType>
622 {
623  return bound.MinWidth() / 2.0;
624 }
625 
// Number of points held directly by this node: `count` for a leaf, 0
// otherwise.
626 template<typename MetricType, typename StatisticType, typename MatType>
628 {
629  // We have no points unless we are a leaf.
630  return (children.size() > 0) ? 0 : count;
631 }
632 
// Number of points in this node's whole subtree (they occupy the contiguous
// dataset columns [begin, begin + count)).
633 template<typename MetricType, typename StatisticType, typename MatType>
635 {
636  return count;
637 }
638 
// Dataset column of the index-th descendant point.  Descendants are contiguous
// starting at `begin`.  No range check: index is assumed < NumDescendants().
639 template<typename MetricType, typename StatisticType, typename MatType>
641  const size_t index) const
642 {
643  return begin + index;
644 }
645 
// Dataset column of the index-th point held by this node (same mapping as
// Descendant()).  No range check: index is assumed < NumPoints().
646 template<typename MetricType, typename StatisticType, typename MatType>
648  const
649 {
650  return begin + index;
651 }
652 
// Minimum distance between this node's bounding box and other's.
653 template<typename MetricType, typename StatisticType, typename MatType>
656  const
657 {
658  return bound.MinDistance(other.Bound());
659 }
660 
// Maximum distance between this node's bounding box and other's.
661 template<typename MetricType, typename StatisticType, typename MatType>
664  const
665 {
666  return bound.MaxDistance(other.Bound());
667 }
668 
// Minimum and maximum distance between this node's bounding box and other's,
// as one range.
669 template<typename MetricType, typename StatisticType, typename MatType>
672  const
673 {
674  return bound.RangeDistance(other.Bound());
675 }
676 
// Minimum distance from `point` to this node's bounding box (delegates to the
// bound).
677 template<typename MetricType, typename StatisticType, typename MatType>
678 template<typename VecType>
681  const VecType& point,
682  typename std::enable_if_t<IsVector<VecType>::value>*) const
683 {
684  return bound.MinDistance(point);
685 }
686 
// Maximum distance from `point` to this node's bounding box (delegates to the
// bound).
687 template<typename MetricType, typename StatisticType, typename MatType>
688 template<typename VecType>
691  const VecType& point,
692  typename std::enable_if_t<IsVector<VecType>::value>*) const
693 {
694  return bound.MaxDistance(point);
695 }
696 
697 
// Minimum and maximum distance from `point` to this node's bounding box, as
// one range.
698 template<typename MetricType, typename StatisticType, typename MatType>
699 template<typename VecType>
702  const VecType& point,
703  typename std::enable_if_t<IsVector<VecType>::value>*) const
704 {
705  return bound.RangeDistance(point);
706 }
707 
// Save or load this node and, recursively, its children.
// The dataset is written/read only for a node whose hasParent flag is false,
// i.e. a root at save time.  After loading such a node, every descendant's
// dataset pointer is reset to the loaded matrix.  A node saved while it had a
// parent carries no dataset in the archive; its pointer is only valid once
// the enclosing root's fix-up below has run.
709 template<typename MetricType, typename StatisticType, typename MatType>
710 template<typename Archive>
712  Archive& ar,
713  const uint32_t /* version */)
714 {
715  // When loading, discard the current subtree (and the dataset, if this node
715 // owns it) because both are about to be replaced.
716  if (cereal::is_loading<Archive>())
717  {
718  for (size_t i = 0; i < children.size(); ++i)
719  delete children[i];
720  children.clear();
721 
722  if (!parent)
723  delete dataset;
724 
725  parent = NULL;
726  }
727 
// On save this records whether the node is a root.  On load parent was just
// reset to NULL, so this starts false and is then overwritten by the archived
// value below.
728  bool hasParent = (parent != NULL);
729 
730  ar(CEREAL_NVP(begin));
731  ar(CEREAL_NVP(count));
732  ar(CEREAL_NVP(bound));
733  ar(CEREAL_NVP(stat));
734  ar(CEREAL_NVP(parentDistance));
735  ar(CEREAL_NVP(furthestDescendantDistance));
736  ar(CEREAL_NVP(metric));
737  ar(CEREAL_NVP(hasParent));
738  if (!hasParent)
739  {
740  MatType*& datasetTemp = const_cast<MatType*&>(dataset);
741  ar(CEREAL_POINTER(datasetTemp));
742  }
743 
744  ar(CEREAL_VECTOR_POINTER(children));
745 
// Freshly loaded children do not know their parent yet.
746  if (cereal::is_loading<Archive>())
747  {
748  for (size_t i = 0; i < children.size(); ++i)
749  children[i]->parent = this;
750  }
751 
752  // We have to correct the dataset pointers in all of the descendants
752 // (iterative depth-first walk), since only this root's pointer was archived.
753  if (!hasParent)
754  {
755  std::stack<Octree*> stack;
756  for (size_t i = 0; i < children.size(); ++i)
757  {
758  stack.push(children[i]);
759  }
760  while (!stack.empty())
761  {
762  Octree* node = stack.top();
763  stack.pop();
764  node->dataset = dataset;
765  for (size_t i = 0; i < node->children.size(); ++i)
766  {
767  stack.push(node->children[i]);
768  }
769  }
770  }
771 }
772 
774 template<typename MetricType, typename StatisticType, typename MatType>
776  const arma::vec& center,
777  const double width,
778  const size_t maxLeafSize)
779 {
780  // No need to split if we have fewer than the maximum number of points in this
781  // node.
782  if (count <= maxLeafSize)
783  return;
784 
785  // This will hold the index of the first point in each child.
786  arma::Col<size_t> childBegins(((size_t) 1 << dataset->n_rows) + 1);
787  childBegins[0] = begin;
788  childBegins[childBegins.n_elem - 1] = begin + count;
789 
790  // We will make log2(dim) passes, splitting along the last down to the first
791  // dimension. The tuple holds { dim, begin, count, leftChildIndex }.
792  std::stack<std::tuple<size_t, size_t, size_t, size_t>> stack;
793  stack.push(std::tuple<size_t, size_t, size_t, size_t>(dataset->n_rows - 1,
794  begin, count, 0));
795 
796  while (!stack.empty())
797  {
798  std::tuple<size_t, size_t, size_t, size_t> t = stack.top();
799  stack.pop();
800 
801  const size_t d = std::get<0>(t);
802  const size_t childBegin = std::get<1>(t);
803  const size_t childCount = std::get<2>(t);
804  const size_t leftChildIndex = std::get<3>(t);
805 
806  // Perform a "half-split": after this split, all points belonging to
807  // children of index 2^(d - 1) - 1 and less will be on the left side, and
808  // all points belonging to children of index 2^(d - 1) and above will be on
809  // the right side.
810  typename SplitType::SplitInfo s(d, center);
811  const size_t firstRight = split::PerformSplit<MatType, SplitType>(*dataset,
812  childBegin, childCount, s);
813 
814  // We can set the first index of the right child. The first index of the
815  // left child is already set.
816  const size_t rightChildIndex = leftChildIndex + ((size_t) 1 << d);
817  childBegins[rightChildIndex] = firstRight;
818 
819  // Now we have to recurse, if this was not the last dimension.
820  if (d != 0)
821  {
822  if (firstRight > childBegin)
823  {
824  stack.push(std::tuple<size_t, size_t, size_t, size_t>(d - 1, childBegin,
825  firstRight - childBegin, leftChildIndex));
826  }
827  else
828  {
829  // Set beginning indices correctly for all children below this level.
830  for (size_t c = leftChildIndex + 1; c < rightChildIndex; ++c)
831  childBegins[c] = childBegins[leftChildIndex];
832  }
833 
834  if (firstRight < childBegin + childCount)
835  {
836  stack.push(std::tuple<size_t, size_t, size_t, size_t>(d - 1, firstRight,
837  childCount - (firstRight - childBegin), rightChildIndex));
838  }
839  else
840  {
841  // Set beginning indices correctly for all children below this level.
842  for (size_t c = rightChildIndex + 1;
843  c < rightChildIndex + (rightChildIndex - leftChildIndex); ++c)
844  childBegins[c] = childBegins[rightChildIndex];
845  }
846  }
847  }
848 
849  // Now that the dataset is reordered, we can create the children.
850  arma::vec childCenter(center.n_elem);
851  const double childWidth = width / 2.0;
852  for (size_t i = 0; i < childBegins.n_elem - 1; ++i)
853  {
854  // If the child has no points, don't create it.
855  if (childBegins[i + 1] - childBegins[i] == 0)
856  continue;
857 
858  // Create the correct center.
859  for (size_t d = 0; d < center.n_elem; ++d)
860  {
861  // Is the dimension "right" (1) or "left" (0)?
862  if (((i >> d) & 1) == 0)
863  childCenter[d] = center[d] - childWidth;
864  else
865  childCenter[d] = center[d] + childWidth;
866  }
867 
868  children.push_back(new Octree(this, childBegins[i],
869  childBegins[i + 1] - childBegins[i], childCenter, childWidth,
870  maxLeafSize));
871  }
872 }
873 
875 template<typename MetricType, typename StatisticType, typename MatType>
877  const arma::vec& center,
878  const double width,
879  std::vector<size_t>& oldFromNew,
880  const size_t maxLeafSize)
881 {
882  // No need to split if we have fewer than the maximum number of points in this
883  // node.
884  if (count <= maxLeafSize)
885  return;
886 
887  // This will hold the index of the first point in each child.
888  arma::Col<size_t> childBegins(((size_t) 1 << dataset->n_rows) + 1);
889  childBegins[0] = begin;
890  childBegins[childBegins.n_elem - 1] = begin + count;
891 
892  // We will make log2(dim) passes, splitting along the last down to the first
893  // dimension. The tuple holds { dim, begin, count, leftChildIndex }.
894  std::stack<std::tuple<size_t, size_t, size_t, size_t>> stack;
895  stack.push(std::tuple<size_t, size_t, size_t, size_t>(dataset->n_rows - 1,
896  begin, count, 0));
897 
898  while (!stack.empty())
899  {
900  std::tuple<size_t, size_t, size_t, size_t> t = stack.top();
901  stack.pop();
902 
903  const size_t d = std::get<0>(t);
904  const size_t childBegin = std::get<1>(t);
905  const size_t childCount = std::get<2>(t);
906  const size_t leftChildIndex = std::get<3>(t);
907 
908  // Perform a "half-split": after this split, all points belonging to
909  // children of index 2^(d - 1) - 1 and less will be on the left side, and
910  // all points belonging to children of index 2^(d - 1) and above will be on
911  // the right side.
912  typename SplitType::SplitInfo s(d, center);
913  const size_t firstRight = split::PerformSplit<MatType, SplitType>(*dataset,
914  childBegin, childCount, s, oldFromNew);
915 
916  // We can set the first index of the right child. The first index of the
917  // left child is already set.
918  const size_t rightChildIndex = leftChildIndex + ((size_t) 1 << d);
919  childBegins[rightChildIndex] = firstRight;
920 
921  // Now we have to recurse, if this was not the last dimension.
922  if (d != 0)
923  {
924  if (firstRight > childBegin)
925  {
926  stack.push(std::tuple<size_t, size_t, size_t, size_t>(d - 1, childBegin,
927  firstRight - childBegin, leftChildIndex));
928  }
929  else
930  {
931  // Set beginning indices correctly for all children below this level.
932  for (size_t c = leftChildIndex + 1; c < rightChildIndex; ++c)
933  childBegins[c] = childBegins[leftChildIndex];
934  }
935 
936  if (firstRight < childBegin + childCount)
937  {
938  stack.push(std::tuple<size_t, size_t, size_t, size_t>(d - 1, firstRight,
939  childCount - (firstRight - childBegin), rightChildIndex));
940  }
941  else
942  {
943  // Set beginning indices correctly for all children below this level.
944  for (size_t c = rightChildIndex + 1;
945  c < rightChildIndex + (rightChildIndex - leftChildIndex); ++c)
946  childBegins[c] = childBegins[rightChildIndex];
947  }
948  }
949  }
950 
951  // Now that the dataset is reordered, we can create the children.
952  arma::vec childCenter(center.n_elem);
953  const double childWidth = width / 2.0;
954  for (size_t i = 0; i < childBegins.n_elem - 1; ++i)
955  {
956  // If the child has no points, don't create it.
957  if (childBegins[i + 1] - childBegins[i] == 0)
958  continue;
959 
960  // Create the correct center.
961  for (size_t d = 0; d < center.n_elem; ++d)
962  {
963  // Is the dimension "right" (1) or "left" (0)?
964  if (((i >> d) & 1) == 0)
965  childCenter[d] = center[d] - childWidth;
966  else
967  childCenter[d] = center[d] + childWidth;
968  }
969 
970  children.push_back(new Octree(this, childBegins[i],
971  childBegins[i + 1] - childBegins[i], oldFromNew, childCenter,
972  childWidth, maxLeafSize));
973  }
974 }
975 
976 } // namespace tree
977 } // namespace mlpack
978 
979 #endif
math::RangeType< ElemType > RangeDistance(const HRectBound &other) const
Calculates minimum and maximum bound-to-bound distance.
Definition: hrectbound_impl.hpp:391
ElemType MinWidth() const
Get the minimum width of the bound.
Definition: hrectbound.hpp:106
size_t GetNearestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Return the index of the nearest child node to the given query point.
Definition: octree_impl.hpp:516
math::RangeType< ElemType > RangeDistance(const Octree &other) const
Return the minimum and maximum distance to another node.
Definition: octree_impl.hpp:671
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
~Octree()
Destroy the tree.
Definition: octree_impl.hpp:496
size_t NumPoints() const
Return the number of points in this node (0 if not a leaf).
Definition: octree_impl.hpp:627
Octree()
A default constructor.
Definition: octree_impl.hpp:472
Definition: pointer_wrapper.hpp:23
ElemType Diameter() const
Returns the diameter of the hyperrectangle (that is, the longest diagonal).
Definition: hrectbound_impl.hpp:669
size_t NumChildren() const
Return the number of children in this node.
Definition: octree_impl.hpp:509
ElemType MaxDistance(const Octree &other) const
Return the maximum distance to another node.
Definition: octree_impl.hpp:663
ElemType MinDistance(const Octree &other) const
Return the minimum distance to another node.
Definition: octree_impl.hpp:655
ElemType MaxDistance(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Calculates maximum bound-to-point squared distance.
Definition: hrectbound_impl.hpp:309
ElemType MinDistance(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Calculates minimum bound-to-point distance.
Definition: hrectbound_impl.hpp:189
Simple real-valued range.
Definition: range.hpp:19
MatType::elem_type ElemType
The type of element held in MatType.
Definition: octree.hpp:31
size_t Descendant(const size_t index) const
Return the index (with reference to the dataset) of a particular descendant.
Definition: octree_impl.hpp:640
Octree & operator=(const Octree &other)
Copy the given Octree.
Definition: octree_impl.hpp:370
ElemType ParentDistance() const
Return the distance from the center of this node to the center of the parent node.
Definition: octree.hpp:331
size_t Point(const size_t index) const
Return the index (with reference to the dataset) of a particular point in this node.
Definition: octree_impl.hpp:647
void Center(arma::Col< ElemType > &center) const
Calculates the center of the range, placing it into the given vector.
Definition: hrectbound_impl.hpp:153
const Octree & Child(const size_t child) const
Return the specified child.
Definition: octree.hpp:340
#define CEREAL_VECTOR_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_vector_wrapper.hpp:93
size_t GetFurthestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Return the index of the furthest child node to the given query point.
Definition: octree_impl.hpp:539
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_wrapper.hpp:96
ElemType MinimumBoundDistance() const
Return the minimum distance from the center of the node to any bound edge.
Definition: octree_impl.hpp:621
const bound::HRectBound< MetricType > & Bound() const
Return the bound object for this node.
Definition: octree.hpp:261
void serialize(Archive &ar, const uint32_t)
Serialize the tree.
Definition: octree_impl.hpp:711
size_t NumDescendants() const
Return the number of descendants of this node.
Definition: octree_impl.hpp:634
size_t Dim() const
Gets the dimensionality.
Definition: hrectbound.hpp:96
Definition: octree.hpp:25
ElemType FurthestPointDistance() const
Return the furthest distance to a point held in this node.
Definition: octree_impl.hpp:604
Octree * Parent() const
Get the pointer to the parent.
Definition: octree.hpp:256
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35
ElemType FurthestDescendantDistance() const
Return the furthest possible descendant distance.
Definition: octree_impl.hpp:614