mlpack
binary_space_tree_impl.hpp
Go to the documentation of this file.
1 
11 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_IMPL_HPP
12 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_IMPL_HPP
13 
14 // In case it wasn't included already for some reason.
15 #include "binary_space_tree.hpp"
16 
17 #include <mlpack/core/util/log.hpp>
18 #include <queue>
19 
20 namespace mlpack {
21 namespace tree {
22 
23 // Each of these overloads is kept as a separate function to keep the overhead
24 // from the two std::vectors out, if possible.
// Root-node constructor taking the dataset by const reference.  The matrix is
// copied into a newly allocated MatType, the node is split using a
// default-constructed SplitType, and finally the statistic is built from the
// finished node.
// NOTE(review): the line(s) carrying the qualified constructor name between
// the template header and the parameter list are absent from this listing.
25 template<typename MetricType,
26  typename StatisticType,
27  typename MatType,
28  template<typename BoundMetricType, typename...> class BoundType,
29  template<typename SplitBoundType, typename SplitMatType>
30  class SplitType>
33  const MatType& data,
34  const size_t maxLeafSize) :
35  left(NULL),
36  right(NULL),
37  parent(NULL),
38  begin(0), /* This root node starts at index 0, */
39  count(data.n_cols), /* and spans all of the dataset. */
40  bound(data.n_rows),
41  parentDistance(0), // Parent distance for the root is 0: it has no parent.
42  dataset(new MatType(data)) // Copies the dataset.
43 {
44  // Do the actual splitting of this node.
45  SplitType<BoundType<MetricType>, MatType> splitter;
46  SplitNode(maxLeafSize, splitter);
47 
48  // Create the statistic depending on if we are a leaf or not.
 // This is done after SplitNode() so the statistic sees the final node.
49  stat = StatisticType(*this);
50 }
51 
// Root-node constructor taking the dataset by const reference and reporting
// the point reordering.  oldFromNew is resized to data.n_cols and filled with
// the identity mapping before being handed to SplitNode().
// NOTE(review): the line(s) carrying the qualified constructor name between
// the template header and the parameter list are absent from this listing.
52 template<typename MetricType,
53  typename StatisticType,
54  typename MatType,
55  template<typename BoundMetricType, typename...> class BoundType,
56  template<typename SplitBoundType, typename SplitMatType>
57  class SplitType>
60  const MatType& data,
61  std::vector<size_t>& oldFromNew,
62  const size_t maxLeafSize) :
63  left(NULL),
64  right(NULL),
65  parent(NULL),
66  begin(0),
67  count(data.n_cols),
68  bound(data.n_rows),
69  parentDistance(0), // Parent distance for the root is 0: it has no parent.
70  dataset(new MatType(data)) // Copies the dataset.
71 {
72  // Initialize oldFromNew correctly.
73  oldFromNew.resize(data.n_cols);
74  for (size_t i = 0; i < data.n_cols; ++i)
75  oldFromNew[i] = i; // Fill with unharmed indices.
76 
77  // Now do the actual splitting.
78  SplitType<BoundType<MetricType>, MatType> splitter;
79  SplitNode(oldFromNew, maxLeafSize, splitter);
80 
81  // Create the statistic depending on if we are a leaf or not.
82  stat = StatisticType(*this);
83 }
84 
// Root-node constructor taking the dataset by const reference and reporting
// both index mappings.  oldFromNew starts as the identity and is passed to
// SplitNode(); afterwards newFromOld is filled as its inverse
// (newFromOld[oldFromNew[i]] = i).
// NOTE(review): the line(s) carrying the qualified constructor name between
// the template header and the parameter list are absent from this listing.
85 template<typename MetricType,
86  typename StatisticType,
87  typename MatType,
88  template<typename BoundMetricType, typename...> class BoundType,
89  template<typename SplitBoundType, typename SplitMatType>
90  class SplitType>
93  const MatType& data,
94  std::vector<size_t>& oldFromNew,
95  std::vector<size_t>& newFromOld,
96  const size_t maxLeafSize) :
97  left(NULL),
98  right(NULL),
99  parent(NULL),
100  begin(0),
101  count(data.n_cols),
102  bound(data.n_rows),
103  parentDistance(0), // Parent distance for the root is 0: it has no parent.
104  dataset(new MatType(data)) // Copies the dataset.
105 {
106  // Initialize the oldFromNew vector correctly.
107  oldFromNew.resize(data.n_cols);
108  for (size_t i = 0; i < data.n_cols; ++i)
109  oldFromNew[i] = i; // Fill with unharmed indices.
110 
111  // Now do the actual splitting.
112  SplitType<BoundType<MetricType>, MatType> splitter;
113  SplitNode(oldFromNew, maxLeafSize, splitter);
114 
115  // Create the statistic depending on if we are a leaf or not.
116  stat = StatisticType(*this);
117 
118  // Map the newFromOld indices correctly.
 // newFromOld is the inverse of oldFromNew.
119  newFromOld.resize(data.n_cols);
120  for (size_t i = 0; i < data.n_cols; ++i)
121  newFromOld[oldFromNew[i]] = i;
122 }
123 
// Root-node constructor taking the dataset by rvalue reference.  The matrix is
// moved (not copied) into a newly allocated MatType; the node is then split
// and the statistic built, as in the copying overload.
// NOTE(review): the class-qualifier line preceding "BinarySpaceTree(" is
// absent from this listing.
124 template<typename MetricType,
125  typename StatisticType,
126  typename MatType,
127  template<typename BoundMetricType, typename...> class BoundType,
128  template<typename SplitBoundType, typename SplitMatType>
129  class SplitType>
131 BinarySpaceTree(MatType&& data, const size_t maxLeafSize) :
132  left(NULL),
133  right(NULL),
134  parent(NULL),
135  begin(0),
136  count(data.n_cols),
137  bound(data.n_rows),
138  parentDistance(0), // Parent distance for the root is 0: it has no parent.
139  dataset(new MatType(std::move(data))) // Takes over the caller's matrix.
140 {
141  // Do the actual splitting of this node.
142  SplitType<BoundType<MetricType>, MatType> splitter;
143  SplitNode(maxLeafSize, splitter);
144 
145  // Create the statistic depending on if we are a leaf or not.
146  stat = StatisticType(*this);
147 }
148 
// Root-node constructor taking the dataset by rvalue reference and reporting
// the point reordering through oldFromNew.
// NOTE(review): the line(s) carrying the qualified constructor name between
// the template header and the parameter list are absent from this listing.
149 template<typename MetricType,
150  typename StatisticType,
151  typename MatType,
152  template<typename BoundMetricType, typename...> class BoundType,
153  template<typename SplitBoundType, typename SplitMatType>
154  class SplitType>
157  MatType&& data,
158  std::vector<size_t>& oldFromNew,
159  const size_t maxLeafSize) :
160  left(NULL),
161  right(NULL),
162  parent(NULL),
163  begin(0),
164  count(data.n_cols),
165  bound(data.n_rows),
166  parentDistance(0), // Parent distance for the root is 0: it has no parent.
167  dataset(new MatType(std::move(data)))
168 {
169  // Initialize oldFromNew correctly.
 // `data` has been moved from above, so sizes are read through `dataset`.
170  oldFromNew.resize(dataset->n_cols);
171  for (size_t i = 0; i < dataset->n_cols; ++i)
172  oldFromNew[i] = i; // Fill with unharmed indices.
173 
174  // Now do the actual splitting.
175  SplitType<BoundType<MetricType>, MatType> splitter;
176  SplitNode(oldFromNew, maxLeafSize, splitter);
177 
178  // Create the statistic depending on if we are a leaf or not.
179  stat = StatisticType(*this);
180 }
181 
// Root-node constructor taking the dataset by rvalue reference and reporting
// both index mappings; newFromOld is filled as the inverse of oldFromNew once
// the split has finished.
// NOTE(review): the line(s) carrying the qualified constructor name between
// the template header and the parameter list are absent from this listing.
182 template<typename MetricType,
183  typename StatisticType,
184  typename MatType,
185  template<typename BoundMetricType, typename...> class BoundType,
186  template<typename SplitBoundType, typename SplitMatType>
187  class SplitType>
190  MatType&& data,
191  std::vector<size_t>& oldFromNew,
192  std::vector<size_t>& newFromOld,
193  const size_t maxLeafSize) :
194  left(NULL),
195  right(NULL),
196  parent(NULL),
197  begin(0),
198  count(data.n_cols),
199  bound(data.n_rows),
200  parentDistance(0), // Parent distance for the root is 0: it has no parent.
201  dataset(new MatType(std::move(data)))
202 {
203  // Initialize the oldFromNew vector correctly.
 // `data` has been moved from above, so sizes are read through `dataset`.
204  oldFromNew.resize(dataset->n_cols);
205  for (size_t i = 0; i < dataset->n_cols; ++i)
206  oldFromNew[i] = i; // Fill with unharmed indices.
207 
208  // Now do the actual splitting.
209  SplitType<BoundType<MetricType>, MatType> splitter;
210  SplitNode(oldFromNew, maxLeafSize, splitter);
211 
212  // Create the statistic depending on if we are a leaf or not.
213  stat = StatisticType(*this);
214 
215  // Map the newFromOld indices correctly.
216  newFromOld.resize(dataset->n_cols);
217  for (size_t i = 0; i < dataset->n_cols; ++i)
218  newFromOld[oldFromNew[i]] = i;
219 }
220 
// Child-node constructor: builds the node covering `count` points starting at
// column `begin` of the parent's dataset.  The node does not allocate a
// matrix; it stores the address of the parent's dataset.  parentDistance is
// not set in the initializer list here.
// NOTE(review): the line(s) carrying the qualified constructor name between
// the template header and the parameter list are absent from this listing.
221 template<typename MetricType,
222  typename StatisticType,
223  typename MatType,
224  template<typename BoundMetricType, typename...> class BoundType,
225  template<typename SplitBoundType, typename SplitMatType>
226  class SplitType>
229  BinarySpaceTree* parent,
230  const size_t begin,
231  const size_t count,
232  SplitType<BoundType<MetricType>, MatType>& splitter,
233  const size_t maxLeafSize) :
234  left(NULL),
235  right(NULL),
236  parent(parent),
237  begin(begin),
238  count(count),
239  bound(parent->Dataset().n_rows),
240  dataset(&parent->Dataset()) // Point to the parent's dataset.
241 {
242  // Perform the actual splitting.
243  SplitNode(maxLeafSize, splitter);
244 
245  // Create the statistic depending on if we are a leaf or not.
246  stat = StatisticType(*this);
247 }
248 
// Child-node constructor that also maintains the oldFromNew mapping.  The
// caller must pass a mapping that already has one entry per dataset column;
// only its size is checked (via assert).
// NOTE(review): the line(s) carrying the qualified constructor name between
// the template header and the parameter list are absent from this listing.
249 template<typename MetricType,
250  typename StatisticType,
251  typename MatType,
252  template<typename BoundMetricType, typename...> class BoundType,
253  template<typename SplitBoundType, typename SplitMatType>
254  class SplitType>
257  BinarySpaceTree* parent,
258  const size_t begin,
259  const size_t count,
260  std::vector<size_t>& oldFromNew,
261  SplitType<BoundType<MetricType>, MatType>& splitter,
262  const size_t maxLeafSize) :
263  left(NULL),
264  right(NULL),
265  parent(parent),
266  begin(begin),
267  count(count),
268  bound(parent->Dataset().n_rows),
269  dataset(&parent->Dataset()) // Shares the parent's dataset; no copy made.
270 {
271  // Hopefully the vector is initialized correctly! We can't check that
272  // entirely but we can do a minor sanity check.
273  assert(oldFromNew.size() == dataset->n_cols);
274 
275  // Perform the actual splitting.
276  SplitNode(oldFromNew, maxLeafSize, splitter);
277 
278  // Create the statistic depending on if we are a leaf or not.
279  stat = StatisticType(*this);
280 }
281 
282 template<typename MetricType,
283  typename StatisticType,
284  typename MatType,
285  template<typename BoundMetricType, typename...> class BoundType,
286  template<typename SplitBoundType, typename SplitMatType>
287  class SplitType>
290  BinarySpaceTree* parent,
291  const size_t begin,
292  const size_t count,
293  std::vector<size_t>& oldFromNew,
294  std::vector<size_t>& newFromOld,
295  SplitType<BoundType<MetricType>, MatType>& splitter,
296  const size_t maxLeafSize) :
297  left(NULL),
298  right(NULL),
299  parent(parent),
300  begin(begin),
301  count(count),
302  bound(parent->Dataset()->n_rows),
303  dataset(&parent->Dataset())
304 {
305  // Hopefully the vector is initialized correctly! We can't check that
306  // entirely but we can do a minor sanity check.
307  Log::Assert(oldFromNew.size() == dataset->n_cols);
308 
309  // Perform the actual splitting.
310  SplitNode(oldFromNew, maxLeafSize, splitter);
311 
312  // Create the statistic depending on if we are a leaf or not.
313  stat = StatisticType(*this);
314 
315  // Map the newFromOld indices correctly.
316  newFromOld.resize(dataset->n_cols);
317  for (size_t i = 0; i < dataset->n_cols; ++i)
318  newFromOld[oldFromNew[i]] = i;
319 }
320 
// Copy constructor: makes a deep copy of `other` and its subtree.  The matrix
// is duplicated only when `other` has no parent; otherwise this node's dataset
// pointer is set to NULL.  Children are copied recursively and re-parented to
// this node.  When this node has no parent, the dataset pointer is then pushed
// down to every descendant with a breadth-first walk.
// NOTE(review): the line(s) carrying the qualified constructor name between
// the template header and the parameter list are absent from this listing.
325 template<typename MetricType,
326  typename StatisticType,
327  typename MatType,
328  template<typename BoundMetricType, typename...> class BoundType,
329  template<typename SplitBoundType, typename SplitMatType>
330  class SplitType>
333  const BinarySpaceTree& other) :
334  left(NULL),
335  right(NULL),
336  parent(other.parent),
337  begin(other.begin),
338  count(other.count),
339  bound(other.bound),
340  stat(other.stat),
341  parentDistance(other.parentDistance),
342  furthestDescendantDistance(other.furthestDescendantDistance),
343  minimumBoundDistance(other.minimumBoundDistance),
344  // Copy matrix, but only if we are the root.
345  dataset((other.parent == NULL) ? new MatType(*other.dataset) : NULL)
346 {
347  // Create left and right children (if any).
348  if (other.Left())
349  {
350  left = new BinarySpaceTree(*other.Left());
351  left->Parent() = this; // Set parent to this, not other tree.
352  }
353 
354  if (other.Right())
355  {
356  right = new BinarySpaceTree(*other.Right());
357  right->Parent() = this; // Set parent to this, not other tree.
358  }
359 
360  // Propagate matrix, but only if we are the root.
 // The recursively copied children were built with a NULL dataset (their
 // source nodes have parents), so each one is given this node's matrix here.
361  if (parent == NULL)
362  {
363  std::queue<BinarySpaceTree*> queue;
364  if (left)
365  queue.push(left);
366  if (right)
367  queue.push(right);
368  while (!queue.empty())
369  {
370  BinarySpaceTree* node = queue.front();
371  queue.pop();
372 
373  node->dataset = dataset;
374  if (node->left)
375  queue.push(node->left);
376  if (node->right)
377  queue.push(node->right);
378  }
379  }
380 }
381 
385 template<typename MetricType,
386  typename StatisticType,
387  typename MatType,
388  template<typename BoundMetricType, typename...> class BoundType,
389  template<typename SplitBoundType, typename SplitMatType>
390  class SplitType>
394 {
395  // Return if it's the same tree.
396  if (this == &other)
397  return *this;
398 
399  // Freeing memory that will not be used anymore.
400  delete dataset;
401  delete left;
402  delete right;
403 
404  left = NULL;
405  right = NULL;
406  parent = other.Parent();
407  begin = other.Begin();
408  count = other.Count();
409  bound = other.bound;
410  stat = other.stat;
411  parentDistance = other.ParentDistance();
412  furthestDescendantDistance = other.FurthestDescendantDistance();
413  minimumBoundDistance = other.MinimumBoundDistance();
414  // Copy matrix, but only if we are the root.
415  dataset = ((other.parent == NULL) ? new MatType(*other.dataset) : NULL);
416 
417  // Create left and right children (if any).
418  if (other.Left())
419  {
420  left = new BinarySpaceTree(*other.Left());
421  left->Parent() = this; // Set parent to this, not other tree.
422  }
423 
424  if (other.Right())
425  {
426  right = new BinarySpaceTree(*other.Right());
427  right->Parent() = this; // Set parent to this, not other tree.
428  }
429 
430  // Propagate matrix, but only if we are the root.
431  if (parent == NULL)
432  {
433  std::queue<BinarySpaceTree*> queue;
434  if (left)
435  queue.push(left);
436  if (right)
437  queue.push(right);
438  while (!queue.empty())
439  {
440  BinarySpaceTree* node = queue.front();
441  queue.pop();
442 
443  node->dataset = dataset;
444  if (node->left)
445  queue.push(node->left);
446  if (node->right)
447  queue.push(node->right);
448  }
449  }
450 
451  return *this;
452 }
453 
457 template<typename MetricType,
458  typename StatisticType,
459  typename MatType,
460  template<typename BoundMetricType, typename...> class BoundType,
461  template<typename SplitBoundType, typename SplitMatType>
462  class SplitType>
466 {
467  // Return if it's the same tree.
468  if (this == &other)
469  return *this;
470 
471  // Freeing memory that will not be used anymore.
472  delete dataset;
473  delete left;
474  delete right;
475 
476  parent = other.Parent();
477  left = other.Left();
478  right = other.Right();
479  begin = other.Begin();
480  count = other.Count();
481  bound = std::move(other.bound);
482  stat = std::move(other.stat);
483  parentDistance = other.ParentDistance();
484  furthestDescendantDistance = other.FurthestDescendantDistance();
485  minimumBoundDistance = other.MinimumBoundDistance();
486  dataset = other.dataset;
487 
488  other.left = NULL;
489  other.right = NULL;
490  other.parent = NULL;
491  other.begin = 0;
492  other.count = 0;
493  other.parentDistance = 0.0;
494  other.furthestDescendantDistance = 0.0;
495  other.minimumBoundDistance = 0.0;
496  other.dataset = NULL;
497 
498  return *this;
499 }
500 
501 
// Move constructor: takes over the children, bound, statistic and dataset
// pointer of `other`, resets `other` to an empty state, and re-parents the
// adopted children to this node.
// NOTE(review): the line(s) carrying the qualified constructor name and the
// parameter list are absent from this listing; the initializer list reads
// from an object named `other`.
505 template<typename MetricType,
506  typename StatisticType,
507  typename MatType,
508  template<typename BoundMetricType, typename...> class BoundType,
509  template<typename SplitBoundType, typename SplitMatType>
510  class SplitType>
513  left(other.left),
514  right(other.right),
515  parent(other.parent),
516  begin(other.begin),
517  count(other.count),
518  bound(std::move(other.bound)),
519  stat(std::move(other.stat)),
520  parentDistance(other.parentDistance),
521  furthestDescendantDistance(other.furthestDescendantDistance),
522  minimumBoundDistance(other.minimumBoundDistance),
523  dataset(other.dataset)
524 {
525  // Now we are a clone of the other tree. But we must also clear the other
526  // tree's contents, so it doesn't delete anything when it is destructed.
527  other.left = NULL;
528  other.right = NULL;
529  other.parent = NULL;
530  other.begin = 0;
531  other.count = 0;
532  other.parentDistance = 0.0;
533  other.furthestDescendantDistance = 0.0;
534  other.minimumBoundDistance = 0.0;
535  other.dataset = NULL;
536 
537  // Set new parent.
 // The adopted children still refer to `other` until this point.
538  if (left)
539  left->parent = this;
540  if (right)
541  right->parent = this;
542 }
543 
// Constructor from a cereal archive.  The trailing enable_if_t pointer
// parameter restricts this overload to archives for which
// cereal::is_loading<Archive>() is true.  It delegates to the default
// constructor and then reads the tree from `ar`.
// NOTE(review): the line(s) carrying the qualified constructor name between
// the template headers and the parameter list are absent from this listing.
547 template<typename MetricType,
548  typename StatisticType,
549  typename MatType,
550  template<typename BoundMetricType, typename...> class BoundType,
551  template<typename SplitBoundType, typename SplitMatType>
552  class SplitType>
553 template<typename Archive>
556  Archive& ar,
557  const typename std::enable_if_t<cereal::is_loading<Archive>()>*) :
558  BinarySpaceTree() // Create an empty BinarySpaceTree.
559 {
560  // We've delegated to the constructor which gives us an empty tree, and now we
561  // can serialize from it.
562  ar(CEREAL_NVP(*this));
563 }
564 
// Destructor body: deletes both children and, when this node has no parent,
// the dataset as well.
// NOTE(review): the line(s) carrying the qualified destructor name are absent
// from this listing.
570 template<typename MetricType,
571  typename StatisticType,
572  typename MatType,
573  template<typename BoundMetricType, typename...> class BoundType,
574  template<typename SplitBoundType, typename SplitMatType>
575  class SplitType>
578 {
579  delete left;
580  delete right;
581 
582  // If we're the root, delete the matrix.
 // Nodes with a parent share the root's matrix and must not free it.
583  if (!parent)
584  delete dataset;
585 }
586 
587 template<typename MetricType,
588  typename StatisticType,
589  typename MatType,
590  template<typename BoundMetricType, typename...> class BoundType,
591  template<typename SplitBoundType, typename SplitMatType>
592  class SplitType>
593 inline bool BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
594  SplitType>::IsLeaf() const
595 {
596  return !left;
597 }
598 
602 template<typename MetricType,
603  typename StatisticType,
604  typename MatType,
605  template<typename BoundMetricType, typename...> class BoundType,
606  template<typename SplitBoundType, typename SplitMatType>
607  class SplitType>
608 inline size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
609  SplitType>::NumChildren() const
610 {
611  if (left && right)
612  return 2;
613  if (left)
614  return 1;
615 
616  return 0;
617 }
618 
623 template<typename MetricType,
624  typename StatisticType,
625  typename MatType,
626  template<typename BoundMetricType, typename...> class BoundType,
627  template<typename SplitBoundType, typename SplitMatType>
628  class SplitType>
629 template<typename VecType>
630 size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
631  SplitType>::GetNearestChild(
632  const VecType& point,
633  typename std::enable_if_t<IsVector<VecType>::value>*)
634 {
635  if (IsLeaf() || !left || !right)
636  return 0;
637 
638  if (left->MinDistance(point) <= right->MinDistance(point))
639  return 0;
640  return 1;
641 }
642 
// Point overload of the furthest-child query: returns 0 when the left child's
// MaxDistance() to `point` is strictly greater than the right child's, and 1
// otherwise.  Also returns 0 when the node lacks either child.
// NOTE(review): the line completing the qualified function name is absent
// from this listing.
647 template<typename MetricType,
648  typename StatisticType,
649  typename MatType,
650  template<typename BoundMetricType, typename...> class BoundType,
651  template<typename SplitBoundType, typename SplitMatType>
652  class SplitType>
653 template<typename VecType>
654 size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
656  const VecType& point,
657  typename std::enable_if_t<IsVector<VecType>::value>*)
658 {
 // Nothing to compare unless both children are present.
659  if (IsLeaf() || !left || !right)
660  return 0;
661 
 // A tie selects the right child (index 1).
662  if (left->MaxDistance(point) > right->MaxDistance(point))
663  return 0;
664  return 1;
665 }
666 
671 template<typename MetricType,
672  typename StatisticType,
673  typename MatType,
674  template<typename BoundMetricType, typename...> class BoundType,
675  template<typename SplitBoundType, typename SplitMatType>
676  class SplitType>
677 size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
678  SplitType>::GetNearestChild(const BinarySpaceTree& queryNode)
679 {
680  if (IsLeaf() || !left || !right)
681  return 0;
682 
683  ElemType leftDist = left->MinDistance(queryNode);
684  ElemType rightDist = right->MinDistance(queryNode);
685  if (leftDist < rightDist)
686  return 0;
687  if (rightDist < leftDist)
688  return 1;
689  return NumChildren();
690 }
691 
696 template<typename MetricType,
697  typename StatisticType,
698  typename MatType,
699  template<typename BoundMetricType, typename...> class BoundType,
700  template<typename SplitBoundType, typename SplitMatType>
701  class SplitType>
702 size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
703  SplitType>::GetFurthestChild(const BinarySpaceTree& queryNode)
704 {
705  if (IsLeaf() || !left || !right)
706  return 0;
707 
708  ElemType leftDist = left->MaxDistance(queryNode);
709  ElemType rightDist = right->MaxDistance(queryNode);
710  if (leftDist > rightDist)
711  return 0;
712  if (rightDist > leftDist)
713  return 1;
714  return NumChildren();
715 }
716 
721 template<typename MetricType,
722  typename StatisticType,
723  typename MatType,
724  template<typename BoundMetricType, typename...> class BoundType,
725  template<typename SplitBoundType, typename SplitMatType>
726  class SplitType>
727 inline
728 typename BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
729  SplitType>::ElemType
730 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
731  SplitType>::FurthestPointDistance() const
732 {
733  if (!IsLeaf())
734  return 0.0;
735 
736  // Otherwise return the distance from the center to a corner of the bound.
737  return 0.5 * bound.Diameter();
738 }
739 
// Accessor returning the stored furthestDescendantDistance member by value.
// NOTE(review): the line completing the qualified function name is absent
// from this listing.
747 template<typename MetricType,
748  typename StatisticType,
749  typename MatType,
750  template<typename BoundMetricType, typename...> class BoundType,
751  template<typename SplitBoundType, typename SplitMatType>
752  class SplitType>
753 inline
754 typename BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
755  SplitType>::ElemType
756 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
758 {
759  return furthestDescendantDistance;
760 }
761 
763 template<typename MetricType,
764  typename StatisticType,
765  typename MatType,
766  template<typename BoundMetricType, typename...> class BoundType,
767  template<typename SplitBoundType, typename SplitMatType>
768  class SplitType>
769 inline
770 typename BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
771  SplitType>::ElemType
772 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
773  SplitType>::MinimumBoundDistance() const
774 {
775  return bound.MinWidth() / 2.0;
776 }
777 
781 template<typename MetricType,
782  typename StatisticType,
783  typename MatType,
784  template<typename BoundMetricType, typename...> class BoundType,
785  template<typename SplitBoundType, typename SplitMatType>
786  class SplitType>
787 inline BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
788  SplitType>&
789  BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
790  SplitType>::Child(const size_t child) const
791 {
792  if (child == 0)
793  return *left;
794  else
795  return *right;
796 }
797 
801 template<typename MetricType,
802  typename StatisticType,
803  typename MatType,
804  template<typename BoundMetricType, typename...> class BoundType,
805  template<typename SplitBoundType, typename SplitMatType>
806  class SplitType>
807 inline size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
808  SplitType>::NumPoints() const
809 {
810  if (left)
811  return 0;
812 
813  return count;
814 }
815 
819 template<typename MetricType,
820  typename StatisticType,
821  typename MatType,
822  template<typename BoundMetricType, typename...> class BoundType,
823  template<typename SplitBoundType, typename SplitMatType>
824  class SplitType>
825 inline size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
826  SplitType>::NumDescendants() const
827 {
828  return count;
829 }
830 
834 template<typename MetricType,
835  typename StatisticType,
836  typename MatType,
837  template<typename BoundMetricType, typename...> class BoundType,
838  template<typename SplitBoundType, typename SplitMatType>
839  class SplitType>
840 inline size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
841  SplitType>::Descendant(const size_t index) const
842 {
843  return (begin + index);
844 }
845 
849 template<typename MetricType,
850  typename StatisticType,
851  typename MatType,
852  template<typename BoundMetricType, typename...> class BoundType,
853  template<typename SplitBoundType, typename SplitMatType>
854  class SplitType>
855 inline size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
856  SplitType>::Point(const size_t index) const
857 {
858  return (begin + index);
859 }
860 
// Splits this node without tracking index mappings.  The steps are: grow the
// bound over the node's points, record the furthest descendant distance, stop
// if the node holds at most maxLeafSize points or the splitter declines to
// split, otherwise reorder the dataset around a split column, construct both
// children (which split themselves recursively), and store the distance from
// this node's center to each child's center in that child.
// NOTE(review): the line carrying the return type and class qualifier before
// "SplitNode(" is absent from this listing.
861 template<typename MetricType,
862  typename StatisticType,
863  typename MatType,
864  template<typename BoundMetricType, typename...> class BoundType,
865  template<typename SplitBoundType, typename SplitMatType>
866  class SplitType>
868  SplitNode(const size_t maxLeafSize,
869  SplitType<BoundType<MetricType>, MatType>& splitter)
870 {
871  // We need to expand the bounds of this node properly.
872  UpdateBound(bound);
873 
874  // Calculate the furthest descendant distance.
875  furthestDescendantDistance = 0.5 * bound.Diameter();
876 
877  // Now, check if we need to split at all.
878  if (count <= maxLeafSize)
879  return; // We can't split this.
880 
881  // splitCol denotes the two partitions of the dataset after the split. The
882  // points on its left go to the left child and the others go to the right
883  // child.
884  size_t splitCol;
885 
886  // Find the partition of the node. This method does not perform the split.
887  typename Split::SplitInfo splitInfo;
888 
889  const bool split = splitter.SplitNode(bound, *dataset, begin, count,
890  splitInfo);
891 
892  // The node may not be always split. For instance, if all the points are the
893  // same, we can't split them.
894  if (!split)
895  return;
896 
897  // Perform the actual splitting. This will order the dataset such that
898  // points that belong to the left subtree are on the left of splitCol, and
899  // points from the right subtree are on the right side of splitCol.
900  splitCol = splitter.PerformSplit(*dataset, begin, count, splitInfo);
901 
 // Both children must receive at least one point.
902  assert(splitCol > begin);
903  assert(splitCol < begin + count);
904 
905  // Now that we know the split column, we will recursively split the children
906  // by calling their constructors (which perform this splitting process).
907  left = new BinarySpaceTree(this, begin, splitCol - begin, splitter,
908  maxLeafSize);
909  right = new BinarySpaceTree(this, splitCol, begin + count - splitCol,
910  splitter, maxLeafSize);
911 
912  // Calculate parent distances for those two nodes.
913  arma::vec center, leftCenter, rightCenter;
914  Center(center);
915  left->Center(leftCenter);
916  right->Center(rightCenter);
917 
918  const ElemType leftParentDistance = bound.Metric().Evaluate(center,
919  leftCenter);
920  const ElemType rightParentDistance = bound.Metric().Evaluate(center,
921  rightCenter);
922 
923  left->ParentDistance() = leftParentDistance;
924  right->ParentDistance() = rightParentDistance;
925 }
926 
// Splits this node while keeping oldFromNew up to date.  Same steps as the
// overload without a mapping, except that oldFromNew is passed to
// PerformSplit() and on to the child constructors.
// NOTE(review): the line carrying the return type and class qualifier before
// "SplitNode(" is absent from this listing.
927 template<typename MetricType,
928  typename StatisticType,
929  typename MatType,
930  template<typename BoundMetricType, typename...> class BoundType,
931  template<typename SplitBoundType, typename SplitMatType>
932  class SplitType>
934 SplitNode(std::vector<size_t>& oldFromNew,
935  const size_t maxLeafSize,
936  SplitType<BoundType<MetricType>, MatType>& splitter)
937 {
938  // We need to expand the bounds of this node properly.
939  UpdateBound(bound);
940 
941  // Calculate the furthest descendant distance.
942  furthestDescendantDistance = 0.5 * bound.Diameter();
943 
944  // First, check if we need to split at all.
945  if (count <= maxLeafSize)
946  return; // We can't split this.
947 
948  // splitCol denotes the two partitions of the dataset after the split. The
949  // points on its left go to the left child and the others go to the right
950  // child.
951  size_t splitCol;
952 
953  // Find the partition of the node. This method does not perform the split.
954  typename Split::SplitInfo splitInfo;
955 
956  const bool split = splitter.SplitNode(bound, *dataset, begin, count,
957  splitInfo);
958 
959  // The node may not be always split. For instance, if all the points are the
960  // same, we can't split them.
961  if (!split)
962  return;
963 
964  // Perform the actual splitting. This will order the dataset such that
965  // points that belong to the left subtree are on the left of splitCol, and
966  // points from the right subtree are on the right side of splitCol.
967  splitCol = splitter.PerformSplit(*dataset, begin, count, splitInfo,
968  oldFromNew);
969 
 // Both children must receive at least one point.
970  assert(splitCol > begin);
971  assert(splitCol < begin + count);
972 
973  // Now that we know the split column, we will recursively split the children
974  // by calling their constructors (which perform this splitting process).
975  left = new BinarySpaceTree(this, begin, splitCol - begin, oldFromNew,
976  splitter, maxLeafSize);
977  right = new BinarySpaceTree(this, splitCol, begin + count - splitCol,
978  oldFromNew, splitter, maxLeafSize);
979 
980  // Calculate parent distances for those two nodes.
981  arma::vec center, leftCenter, rightCenter;
982  Center(center);
983  left->Center(leftCenter);
984  right->Center(rightCenter);
985 
986  const ElemType leftParentDistance = bound.Metric().Evaluate(center,
987  leftCenter);
988  const ElemType rightParentDistance = bound.Metric().Evaluate(center,
989  rightCenter);
990 
991  left->ParentDistance() = leftParentDistance;
992  right->ParentDistance() = rightParentDistance;
993 }
994 
// Generic bound update: grows boundToUpdate (via operator|=) to cover the
// dataset columns begin .. begin + count - 1 held by this node.
// NOTE(review): the line carrying the return type and class qualifier before
// "UpdateBound(" is absent from this listing.
995 template<typename MetricType,
996  typename StatisticType,
997  typename MatType,
998  template<typename BoundMetricType, typename...> class BoundType,
999  template<typename SplitBoundType, typename SplitMatType>
1000  class SplitType>
1001 template<typename BoundType2>
1003 UpdateBound(BoundType2& boundToUpdate)
1004 {
 // With count == 0, begin + count - 1 would not be a valid last column.
1005  if (count > 0)
1006  boundToUpdate |= dataset->cols(begin, begin + count - 1);
1007 }
1008 
// Bound update for a bound type exposing HollowCenter() and InnerRadius().
// A node without a parent only grows the bound over its own points.  For a
// node whose parent already has a different left child, the hollow center is
// first set to that sibling's bound center and the inner radius to the
// largest representable ElemType, and then the bound is grown over the node's
// points.
// NOTE(review): the declarator (return type, qualified name and the
// declaration of boundToUpdate) is absent from this listing.
1009 template<typename MetricType,
1010  typename StatisticType,
1011  typename MatType,
1012  template<typename BoundMetricType, typename...> class BoundType,
1013  template<typename SplitBoundType, typename SplitMatType>
1014  class SplitType>
1017 {
1018  if (!parent)
1019  {
1020  if (count > 0)
1021  boundToUpdate |= dataset->cols(begin, begin + count - 1);
1022  return;
1023  }
1024 
 // True only for a node that is not its parent's left child.
1025  if (parent->left != NULL && parent->left != this)
1026  {
1027  boundToUpdate.HollowCenter() = parent->left->bound.Center();
1028  boundToUpdate.InnerRadius() = std::numeric_limits<ElemType>::max();
1029  }
1030 
1031  if (count > 0)
1032  boundToUpdate |= dataset->cols(begin, begin + count - 1);
1033 }
1034 
1035 // Default constructor (private), for cereal.
1036 template<typename MetricType,
1037  typename StatisticType,
1038  typename MatType,
1039  template<typename BoundMetricType, typename...> class BoundType,
1040  template<typename SplitBoundType, typename SplitMatType>
1041  class SplitType>
1044  left(NULL),
1045  right(NULL),
1046  parent(NULL),
1047  begin(0),
1048  count(0),
1049  stat(*this),
1050  parentDistance(0),
1051  furthestDescendantDistance(0),
1052  dataset(NULL)
1053 {
1054  // Nothing to do.
1055 }
1056 
// Serializes or deserializes this node with cereal.  When loading, existing
// children (and the dataset, if this node has no parent) are released first.
// The scalar fields, bound and statistic are then archived, followed by
// presence flags, the children, and - for a node without a parent - the
// dataset.  After loading, children are re-parented to this node; finally a
// node without a parent pushes its dataset pointer down to every descendant.
// NOTE(review): the line carrying the return type and class qualifier before
// "serialize(" is absent from this listing.
// NOTE(review): std::stack is used below, but only <queue> is included in the
// visible part of this file - confirm <stack> is reachable through another
// header.
1060 template<typename MetricType,
1061  typename StatisticType,
1062  typename MatType,
1063  template<typename BoundMetricType, typename...> class BoundType,
1064  template<typename SplitBoundType, typename SplitMatType>
1065  class SplitType>
1066 template<typename Archive>
1068  serialize(Archive& ar, const uint32_t /* version */)
1069 {
1070  // If we're loading, and we have children, they need to be deleted.
1071  if (cereal::is_loading<Archive>())
1072  {
1073  if (left)
1074  delete left;
1075  if (right)
1076  delete right;
1077  if (!parent)
1078  delete dataset;
1079 
1080  parent = NULL;
1081  left = NULL;
1082  right = NULL;
1083  }
1084 
1085  ar(CEREAL_NVP(begin));
1086  ar(CEREAL_NVP(count));
1087  ar(CEREAL_NVP(bound));
1088  ar(CEREAL_NVP(stat));
1089 
1090  ar(CEREAL_NVP(parentDistance));
1091  ar(CEREAL_NVP(furthestDescendantDistance));
1092 
1093  // Save children last.
 // When loading, these flags are overwritten by the archive reads below.
1094  bool hasLeft = (left != NULL);
1095  bool hasRight = (right != NULL);
1096  bool hasParent = (parent != NULL);
1097 
1098  ar(CEREAL_NVP(hasLeft));
1099  ar(CEREAL_NVP(hasRight));
1100  ar(CEREAL_NVP(hasParent));
1101  if (hasLeft)
1102  ar(CEREAL_POINTER(left));
1103  if (hasRight)
1104  ar(CEREAL_POINTER(right));
1105  if (!hasParent)
1106  {
 // The const_cast yields a MatType*& that the archive can read into.
1107  MatType*& datasetTemp = const_cast<MatType*&>(dataset);
1108  ar(CEREAL_POINTER(datasetTemp));
1109  }
1110 
1111  if (cereal::is_loading<Archive>())
1112  {
1113  if (left)
1114  left->parent = this;
1115  if (right)
1116  right->parent = this;
1117  }
1118  // If we are the root, we need to restore the dataset pointer throughout
 // the subtree; every descendant is visited using an explicit stack.
1119  if (!hasParent)
1120  {
1121  std::stack<BinarySpaceTree*> stack;
1122  if (left)
1123  stack.push(left);
1124  if (right)
1125  stack.push(right);
1126  while (!stack.empty())
1127  {
1128  BinarySpaceTree* node = stack.top();
1129  stack.pop();
1130  node->dataset = dataset;
1131  if (node->left)
1132  stack.push(node->left);
1133  if (node->right)
1134  stack.push(node->right);
1135  }
1136  }
1137 }
1138 
1139 } // namespace tree
1140 } // namespace mlpack
1141 
1142 #endif
BinarySpaceTree * Parent() const
Gets the parent of this node.
Definition: binary_space_tree.hpp:342
size_t NumDescendants() const
Return the number of descendants of this node.
Definition: binary_space_tree_impl.hpp:826
void serialize(Archive &ar, const uint32_t version)
Serialize the tree.
Definition: binary_space_tree_impl.hpp:1068
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
BinarySpaceTree & operator=(const BinarySpaceTree &other)
Copy the given BinarySpaceTree.
Definition: binary_space_tree_impl.hpp:393
MatType::elem_type ElemType
The type of element held in MatType.
Definition: binary_space_tree.hpp:60
size_t Descendant(const size_t index) const
Return the index (with reference to the dataset) of a particular descendant of this node...
Definition: binary_space_tree_impl.hpp:841
size_t Count() const
Return the number of points in this subset.
Definition: binary_space_tree.hpp:503
Definition: pointer_wrapper.hpp:23
const arma::Col< ElemType > & HollowCenter() const
Get the center point of the hollow.
Definition: hollow_ball_bound.hpp:111
BinarySpaceTree * Right() const
Gets the right child of this node.
Definition: binary_space_tree.hpp:337
ElemType MinimumBoundDistance() const
Return the minimum distance from the center of the node to any bound edge.
Definition: binary_space_tree_impl.hpp:773
A binary space partitioning tree, such as a KD-tree or a ball tree.
Definition: binary_space_tree.hpp:54
~BinarySpaceTree()
Deletes this node, deallocating the memory for the children and calling their destructors in turn...
Definition: binary_space_tree_impl.hpp:577
ElemType MinDistance(const BinarySpaceTree &other) const
Return the minimum distance to another node.
Definition: binary_space_tree.hpp:453
size_t Point(const size_t index) const
Return the index (with reference to the dataset) of a particular point in this node.
Definition: binary_space_tree_impl.hpp:856
ElemType ParentDistance() const
Return the distance from the center of this node to the center of the parent node.
Definition: binary_space_tree.hpp:407
ElemType MaxDistance(const BinarySpaceTree &other) const
Return the maximum distance to another node.
Definition: binary_space_tree.hpp:459
size_t NumPoints() const
Return the number of points in this node (0 if not a leaf).
Definition: binary_space_tree_impl.hpp:808
size_t GetFurthestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the furthest child node to the given query point.
Definition: binary_space_tree_impl.hpp:655
BinarySpaceTree * Left() const
Gets the left child of this node.
Definition: binary_space_tree.hpp:332
const MatType & Dataset() const
Get the dataset which the tree is built on.
Definition: binary_space_tree.hpp:347
BinarySpaceTree & Child(const size_t child) const
Return the specified child (0 will be left, 1 will be right).
Definition: binary_space_tree_impl.hpp:790
bool IsLeaf() const
Return whether or not this node is a leaf (true if it has no children).
Definition: binary_space_tree_impl.hpp:594
BinarySpaceTree()
A default constructor.
Definition: binary_space_tree_impl.hpp:1043
Definition of generalized binary space partitioning tree (BinarySpaceTree).
size_t GetNearestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the nearest child node to the given query point.
Definition: binary_space_tree_impl.hpp:631
void Center(arma::vec &center) const
Store the center of the bounding region in the given vector.
Definition: binary_space_tree.hpp:508
size_t Begin() const
Return the index of the beginning point of this subset.
Definition: binary_space_tree.hpp:498
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_wrapper.hpp:96
ElemType FurthestPointDistance() const
Return the furthest distance to a point held in this node.
Definition: binary_space_tree_impl.hpp:731
ElemType InnerRadius() const
Get the inner radius of the ball.
Definition: hollow_ball_bound.hpp:101
Hollow ball bound encloses a set of points at a specific distance (radius) from a specific point (center).
Definition: hollow_ball_bound.hpp:33
ElemType FurthestDescendantDistance() const
Return the furthest possible descendant distance.
Definition: binary_space_tree_impl.hpp:757
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38
size_t NumChildren() const
Return the number of children in this node.
Definition: binary_space_tree_impl.hpp:609