mlpack
cover_tree_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_IMPL_HPP
13 #define MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_IMPL_HPP
14 
15 // In case it hasn't already been included.
16 #include "cover_tree.hpp"
17 
18 #include <queue>
19 #include <string>
20 
21 namespace mlpack {
22 namespace tree {
23 
24 // Build the statistics, bottom-up.
// Fill in the statistic of every node in the subtree rooted at `node`.
// Each child subtree is finished before the statistic of `node` itself is
// constructed, so StatisticType's constructor sees fully built children.
template<typename TreeType, typename StatisticType>
void BuildStatistics(TreeType* node)
{
  // Handle the whole subtree below each child first.
  size_t childIndex = 0;
  while (childIndex < node->NumChildren())
  {
    TreeType* const child = &node->Child(childIndex);
    BuildStatistics<TreeType, StatisticType>(child);
    ++childIndex;
  }

  // All descendants are done; build this node's statistic from the node.
  node->Stat() = StatisticType(*node);
}
35 
36 // Create the cover tree.
37 template<
38  typename MetricType,
39  typename StatisticType,
40  typename MatType,
41  typename RootPointPolicy
42 >
44  const MatType& dataset,
45  const ElemType base,
46  MetricType* metric) :
47  dataset(&dataset),
48  point(RootPointPolicy::ChooseRoot(dataset)),
49  scale(INT_MAX),
50  base(base),
51  numDescendants(0),
52  parent(NULL),
53  parentDistance(0),
54  furthestDescendantDistance(0),
55  localMetric(metric == NULL),
56  localDataset(false),
57  metric(metric),
58  distanceComps(0)
59 {
60  // If we need to create a metric, do that. We'll just do it on the heap.
61  if (localMetric)
62  this->metric = new MetricType();
63 
64  // If there is only one point or zero points in the dataset... uh, we're done.
65  // Technically, if the dataset has zero points, our node is not correct...
66  if (dataset.n_cols <= 1)
67  {
68  scale = INT_MIN;
69  return;
70  }
71 
72  // Kick off the building. Create the indices array and the distances array.
73  arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
74  dataset.n_cols - 1, dataset.n_cols - 1);
75  // This is now [1 2 3 4 ... n]. We must be sure that our point does not
76  // occur.
77  if (point != 0)
78  indices[point - 1] = 0; // Put 0 back into the set; remove what was there.
79 
80  arma::vec distances(dataset.n_cols - 1);
81 
82  // Build the initial distances.
83  ComputeDistances(point, indices, distances, dataset.n_cols - 1);
84 
85  // Create the children.
86  size_t farSetSize = 0;
87  size_t usedSetSize = 0;
88  CreateChildren(indices, distances, dataset.n_cols - 1, farSetSize,
89  usedSetSize);
90 
91  // If we ended up creating only one child, remove the implicit node.
92  while (children.size() == 1)
93  {
94  // Prepare to delete the implicit child node.
95  CoverTree* old = children[0];
96 
97  // Now take its children and set their parent correctly.
98  children.erase(children.begin());
99  for (size_t i = 0; i < old->NumChildren(); ++i)
100  {
101  children.push_back(&(old->Child(i)));
102 
103  // Set its parent correctly.
104  old->Child(i).Parent() = this;
105  }
106 
107  // Remove all the children so they don't get erased.
108  old->Children().clear();
109 
110  // Reduce our own scale.
111  scale = old->Scale();
112 
113  // Now delete it.
114  delete old;
115  }
116 
117  // Use the furthest descendant distance to determine the scale of the root
118  // node. Note that if the root is a leaf, we can have scale INT_MIN, but if
119  // it *isn't* a leaf, we need to mark the scale as one higher than INT_MIN, so
120  // that the recursions don't fail.
121  if (furthestDescendantDistance == 0.0 && dataset.n_cols == 1)
122  scale = INT_MIN;
123  else if (furthestDescendantDistance == 0.0)
124  scale = INT_MIN + 1;
125  else
126  scale = (int) ceil(log(furthestDescendantDistance) / log(base));
127 
128  // Initialize statistics recursively after the entire tree construction is
129  // complete.
130  BuildStatistics<CoverTree, StatisticType>(this);
131 
132  Log::Info << distanceComps << " distance computations during tree "
133  << "construction." << std::endl;
134 }
135 
136 template<
137  typename MetricType,
138  typename StatisticType,
139  typename MatType,
140  typename RootPointPolicy
141 >
143  const MatType& dataset,
144  MetricType& metric,
145  const ElemType base) :
146  dataset(&dataset),
147  point(RootPointPolicy::ChooseRoot(dataset)),
148  scale(INT_MAX),
149  base(base),
150  numDescendants(0),
151  parent(NULL),
152  parentDistance(0),
153  furthestDescendantDistance(0),
154  localMetric(true),
155  localDataset(false),
156  metric(new MetricType(metric)),
157  distanceComps(0)
158 {
159  // If there is only one point or zero points in the dataset... uh, we're done.
160  // Technically, if the dataset has zero points, our node is not correct...
161  if (dataset.n_cols <= 1)
162  {
163  scale = INT_MIN;
164  return;
165  }
166 
167  // Kick off the building. Create the indices array and the distances array.
168  arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
169  dataset.n_cols - 1, dataset.n_cols - 1);
170  // This is now [1 2 3 4 ... n]. We must be sure that our point does not
171  // occur.
172  if (point != 0)
173  indices[point - 1] = 0; // Put 0 back into the set; remove what was there.
174 
175  arma::vec distances(dataset.n_cols - 1);
176 
177  // Build the initial distances.
178  ComputeDistances(point, indices, distances, dataset.n_cols - 1);
179 
180  // Create the children.
181  size_t farSetSize = 0;
182  size_t usedSetSize = 0;
183  CreateChildren(indices, distances, dataset.n_cols - 1, farSetSize,
184  usedSetSize);
185 
186  // If we ended up creating only one child, remove the implicit node.
187  while (children.size() == 1)
188  {
189  // Prepare to delete the implicit child node.
190  CoverTree* old = children[0];
191 
192  // Now take its children and set their parent correctly.
193  children.erase(children.begin());
194  for (size_t i = 0; i < old->NumChildren(); ++i)
195  {
196  children.push_back(&(old->Child(i)));
197 
198  // Set its parent correctly.
199  old->Child(i).Parent() = this;
200  }
201 
202  // Remove all the children so they don't get erased.
203  old->Children().clear();
204 
205  // Reduce our own scale.
206  scale = old->Scale();
207 
208  // Now delete it.
209  delete old;
210  }
211 
212  // Use the furthest descendant distance to determine the scale of the root
213  // node. Note that if the root is a leaf, we can have scale INT_MIN, but if
214  // it *isn't* a leaf, we need to mark the scale as one higher than INT_MIN, so
215  // that the recursions don't fail.
216  if (furthestDescendantDistance == 0.0 && dataset.n_cols == 1)
217  scale = INT_MIN;
218  else if (furthestDescendantDistance == 0.0)
219  scale = INT_MIN + 1;
220  else
221  scale = (int) ceil(log(furthestDescendantDistance) / log(base));
222 
223  // Initialize statistics recursively after the entire tree construction is
224  // complete.
225  BuildStatistics<CoverTree, StatisticType>(this);
226 
227  Log::Info << distanceComps << " distance computations during tree "
228  << "construction." << std::endl;
229 }
230 
231 template<
232  typename MetricType,
233  typename StatisticType,
234  typename MatType,
235  typename RootPointPolicy
236 >
238  MatType&& data,
239  const ElemType base) :
240  dataset(new MatType(std::move(data))),
241  point(RootPointPolicy::ChooseRoot(dataset)),
242  scale(INT_MAX),
243  base(base),
244  numDescendants(0),
245  parent(NULL),
246  parentDistance(0),
247  furthestDescendantDistance(0),
248  localMetric(true),
249  localDataset(true),
250  distanceComps(0)
251 {
252  // We need to create a metric. We'll just do it on the heap.
253  this->metric = new MetricType();
254 
255  // If there is only one point or zero points in the dataset... uh, we're done.
256  // Technically, if the dataset has zero points, our node is not correct...
257  if (dataset->n_cols <= 1)
258  {
259  scale = INT_MIN;
260  return;
261  }
262 
263  // Kick off the building. Create the indices array and the distances array.
264  arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
265  dataset->n_cols - 1, dataset->n_cols - 1);
266  // This is now [1 2 3 4 ... n]. We must be sure that our point does not
267  // occur.
268  if (point != 0)
269  indices[point - 1] = 0; // Put 0 back into the set; remove what was there.
270 
271  arma::vec distances(dataset->n_cols - 1);
272 
273  // Build the initial distances.
274  ComputeDistances(point, indices, distances, dataset->n_cols - 1);
275 
276  // Create the children.
277  size_t farSetSize = 0;
278  size_t usedSetSize = 0;
279  CreateChildren(indices, distances, dataset->n_cols - 1, farSetSize,
280  usedSetSize);
281 
282  // If we ended up creating only one child, remove the implicit node.
283  while (children.size() == 1)
284  {
285  // Prepare to delete the implicit child node.
286  CoverTree* old = children[0];
287 
288  // Now take its children and set their parent correctly.
289  children.erase(children.begin());
290  for (size_t i = 0; i < old->NumChildren(); ++i)
291  {
292  children.push_back(&(old->Child(i)));
293 
294  // Set its parent correctly.
295  old->Child(i).Parent() = this;
296  }
297 
298  // Remove all the children so they don't get erased.
299  old->Children().clear();
300 
301  // Reduce our own scale.
302  scale = old->Scale();
303 
304  // Now delete it.
305  delete old;
306  }
307 
308  // Use the furthest descendant distance to determine the scale of the root
309  // node. Note that if the root is a leaf, we can have scale INT_MIN, but if
310  // it *isn't* a leaf, we need to mark the scale as one higher than INT_MIN, so
311  // that the recursions don't fail.
312  if (furthestDescendantDistance == 0.0 && dataset->n_cols == 1)
313  scale = INT_MIN;
314  else if (furthestDescendantDistance == 0.0)
315  scale = INT_MIN + 1;
316  else
317  scale = (int) ceil(log(furthestDescendantDistance) / log(base));
318 
319  // Initialize statistics recursively after the entire tree construction is
320  // complete.
321  BuildStatistics<CoverTree, StatisticType>(this);
322 
323  Log::Info << distanceComps << " distance computations during tree "
324  << "construction." << std::endl;
325 }
326 
327 template<
328  typename MetricType,
329  typename StatisticType,
330  typename MatType,
331  typename RootPointPolicy
332 >
334  MatType&& data,
335  MetricType& metric,
336  const ElemType base) :
337  dataset(new MatType(std::move(data))),
338  point(RootPointPolicy::ChooseRoot(dataset)),
339  scale(INT_MAX),
340  base(base),
341  numDescendants(0),
342  parent(NULL),
343  parentDistance(0),
344  furthestDescendantDistance(0),
345  localMetric(true),
346  localDataset(true),
347  metric(new MetricType(metric)),
348  distanceComps(0)
349 {
350  // If there is only one point or zero points in the dataset... uh, we're done.
351  // Technically, if the dataset has zero points, our node is not correct...
352  if (dataset->n_cols <= 1)
353  {
354  scale = INT_MIN;
355  return;
356  }
357 
358  // Kick off the building. Create the indices array and the distances array.
359  arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
360  dataset->n_cols - 1, dataset->n_cols - 1);
361  // This is now [1 2 3 4 ... n]. We must be sure that our point does not
362  // occur.
363  if (point != 0)
364  indices[point - 1] = 0; // Put 0 back into the set; remove what was there.
365 
366  arma::vec distances(dataset->n_cols - 1);
367 
368  // Build the initial distances.
369  ComputeDistances(point, indices, distances, dataset->n_cols - 1);
370 
371  // Create the children.
372  size_t farSetSize = 0;
373  size_t usedSetSize = 0;
374  CreateChildren(indices, distances, dataset->n_cols - 1, farSetSize,
375  usedSetSize);
376 
377  // If we ended up creating only one child, remove the implicit node.
378  while (children.size() == 1)
379  {
380  // Prepare to delete the implicit child node.
381  CoverTree* old = children[0];
382 
383  // Now take its children and set their parent correctly.
384  children.erase(children.begin());
385  for (size_t i = 0; i < old->NumChildren(); ++i)
386  {
387  children.push_back(&(old->Child(i)));
388 
389  // Set its parent correctly.
390  old->Child(i).Parent() = this;
391  }
392 
393  // Remove all the children so they don't get erased.
394  old->Children().clear();
395 
396  // Reduce our own scale.
397  scale = old->Scale();
398 
399  // Now delete it.
400  delete old;
401  }
402 
403  // Use the furthest descendant distance to determine the scale of the root
404  // node. Note that if the root is a leaf, we can have scale INT_MIN, but if
405  // it *isn't* a leaf, we need to mark the scale as one higher than INT_MIN, so
406  // that the recursions don't fail.
407  if (furthestDescendantDistance == 0.0 && dataset->n_cols == 1)
408  scale = INT_MIN;
409  else if (furthestDescendantDistance == 0.0)
410  scale = INT_MIN + 1;
411  else
412  scale = (int) ceil(log(furthestDescendantDistance) / log(base));
413 
414  // Initialize statistics recursively after the entire tree construction is
415  // complete.
416  BuildStatistics<CoverTree, StatisticType>(this);
417 
418  Log::Info << distanceComps << " distance computations during tree "
419  << "construction." << std::endl;
420 }
421 
422 template<
423  typename MetricType,
424  typename StatisticType,
425  typename MatType,
426  typename RootPointPolicy
427 >
429  const MatType& dataset,
430  const ElemType base,
431  const size_t pointIndex,
432  const int scale,
433  CoverTree* parent,
434  const ElemType parentDistance,
435  arma::Col<size_t>& indices,
436  arma::vec& distances,
437  size_t nearSetSize,
438  size_t& farSetSize,
439  size_t& usedSetSize,
440  MetricType& metric) :
441  dataset(&dataset),
442  point(pointIndex),
443  scale(scale),
444  base(base),
445  numDescendants(0),
446  parent(parent),
447  parentDistance(parentDistance),
448  furthestDescendantDistance(0),
449  localMetric(false),
450  localDataset(false),
451  metric(&metric),
452  distanceComps(0)
453 {
454  // If the size of the near set is 0, this is a leaf.
455  if (nearSetSize == 0)
456  {
457  this->scale = INT_MIN;
458  numDescendants = 1;
459  return;
460  }
461 
462  // Otherwise, create the children.
463  CreateChildren(indices, distances, nearSetSize, farSetSize, usedSetSize);
464 }
465 
466 // Manually create a cover tree node.
467 template<
468  typename MetricType,
469  typename StatisticType,
470  typename MatType,
471  typename RootPointPolicy
472 >
474  const MatType& dataset,
475  const ElemType base,
476  const size_t pointIndex,
477  const int scale,
478  CoverTree* parent,
479  const ElemType parentDistance,
480  const ElemType furthestDescendantDistance,
481  MetricType* metric) :
482  dataset(&dataset),
483  point(pointIndex),
484  scale(scale),
485  base(base),
486  numDescendants(0),
487  parent(parent),
488  parentDistance(parentDistance),
489  furthestDescendantDistance(furthestDescendantDistance),
490  localMetric(metric == NULL),
491  localDataset(false),
492  metric(metric),
493  distanceComps(0)
494 {
495  // If necessary, create a local metric.
496  if (localMetric)
497  this->metric = new MetricType();
498 }
499 
500 // Copy Constructor.
501 template<
502  typename MetricType,
503  typename StatisticType,
504  typename MatType,
505  typename RootPointPolicy
506 >
508  const CoverTree& other) :
509  dataset((other.parent == NULL && other.localDataset) ?
510  new MatType(*other.dataset) : other.dataset),
511  point(other.point),
512  scale(other.scale),
513  base(other.base),
514  stat(other.stat),
515  numDescendants(other.numDescendants),
516  parent(other.parent),
517  parentDistance(other.parentDistance),
518  furthestDescendantDistance(other.furthestDescendantDistance),
519  localMetric(other.localMetric),
520  localDataset(other.parent == NULL && other.localDataset),
521  metric((other.localMetric ? new MetricType() : other.metric)),
522  distanceComps(0)
523 {
524  // Copy each child by hand.
525  for (size_t i = 0; i < other.NumChildren(); ++i)
526  {
527  children.push_back(new CoverTree(other.Child(i)));
528  children[i]->Parent() = this;
529  }
530 
531  // Propagate matrix, but only if we are the root.
532  if (parent == NULL && localDataset)
533  {
534  std::queue<CoverTree*> queue;
535 
536  for (size_t i = 0; i < NumChildren(); ++i)
537  queue.push(children[i]);
538 
539  while (!queue.empty())
540  {
541  CoverTree* node = queue.front();
542  queue.pop();
543 
544  node->dataset = dataset;
545  for (size_t i = 0; i < node->NumChildren(); ++i)
546  queue.push(node->children[i]);
547  }
548  }
549 }
550 
551 // Copy assignment operator: copy the given other tree.
552 template<
553  typename MetricType,
554  typename StatisticType,
555  typename MatType,
556  typename RootPointPolicy
557 >
560 operator=(const CoverTree& other)
561 {
562  if (this == &other)
563  return *this;
564 
565  // Freeing memory that will not be used anymore.
566  if (localDataset)
567  delete dataset;
568 
569  if (localMetric)
570  delete metric;
571 
572  for (size_t i = 0; i < children.size(); ++i)
573  delete children[i];
574  children.clear();
575 
576  dataset = ((other.parent == NULL && other.localDataset) ?
577  new MatType(*other.dataset) : other.dataset);
578  point = other.point;
579  scale = other.scale;
580  base = other.base;
581  stat = other.stat;
582  numDescendants = other.numDescendants;
583  parent = other.parent;
584  parentDistance = other.parentDistance;
585  furthestDescendantDistance = other.furthestDescendantDistance;
586  localMetric = other.localMetric;
587  localDataset = (other.parent == NULL && other.localDataset);
588  metric = (other.localMetric ? new MetricType() : other.metric);
589  distanceComps = 0;
590 
591  // Copy each child by hand.
592  for (size_t i = 0; i < other.NumChildren(); ++i)
593  {
594  children.push_back(new CoverTree(other.Child(i)));
595  children[i]->Parent() = this;
596  }
597 
598  // Propagate matrix, but only if we are the root.
599  if (parent == NULL && localDataset)
600  {
601  std::queue<CoverTree*> queue;
602 
603  for (size_t i = 0; i < NumChildren(); ++i)
604  queue.push(children[i]);
605 
606  while (!queue.empty())
607  {
608  CoverTree* node = queue.front();
609  queue.pop();
610 
611  node->dataset = dataset;
612  for (size_t i = 0; i < node->NumChildren(); ++i)
613  queue.push(node->children[i]);
614  }
615  }
616 
617  return *this;
618 }
619 
620 // Move Constructor.
621 template<
622  typename MetricType,
623  typename StatisticType,
624  typename MatType,
625  typename RootPointPolicy
626 >
628  CoverTree&& other) :
629  dataset(other.dataset),
630  point(other.point),
631  children(std::move(other.children)),
632  scale(other.scale),
633  base(other.base),
634  stat(std::move(other.stat)),
635  numDescendants(other.numDescendants),
636  parent(other.parent),
637  parentDistance(other.parentDistance),
638  furthestDescendantDistance(other.furthestDescendantDistance),
639  localMetric(other.localMetric),
640  localDataset(other.localDataset),
641  metric(other.metric),
642  distanceComps(other.distanceComps)
643 {
644  // Set proper parent pointer.
645  for (size_t i = 0; i < children.size(); ++i)
646  children[i]->Parent() = this;
647 
648  other.dataset = NULL;
649  other.point = 0;
650  other.scale = INT_MIN;
651  other.base = 0;
652  other.numDescendants = 0;
653  other.parent = NULL;
654  other.parentDistance = 0;
655  other.furthestDescendantDistance = 0;
656  other.localMetric = false;
657  other.localDataset = false;
658  other.metric = NULL;
659 }
660 
661 // Move assignment operator: take ownership of the given tree.
662 template<
663  typename MetricType,
664  typename StatisticType,
665  typename MatType,
666  typename RootPointPolicy
667 >
671 {
672  if (this == &other)
673  return *this;
674 
675  // Freeing memory that will not be used anymore.
676  if (localDataset)
677  delete dataset;
678 
679  if (localMetric)
680  delete metric;
681 
682  for (size_t i = 0; i < children.size(); ++i)
683  delete children[i];
684 
685  dataset = other.dataset;
686  point = other.point;
687  children = std::move(other.children);
688  scale = other.scale;
689  base = other.base;
690  stat = std::move(other.stat);
691  numDescendants = other.numDescendants;
692  parent = other.parent;
693  parentDistance = other.parentDistance;
694  furthestDescendantDistance = other.furthestDescendantDistance;
695  localMetric = other.localMetric;
696  localDataset = other.localDataset;
697  metric = other.metric;
698  distanceComps = other.distanceComps;
699 
700  // Set proper parent pointer.
701  for (size_t i = 0; i < children.size(); ++i)
702  children[i]->Parent() = this;
703 
704  other.dataset = NULL;
705  other.point = 0;
706  other.scale = INT_MIN;
707  other.base = 0;
708  other.numDescendants = 0;
709  other.parent = NULL;
710  other.parentDistance = 0;
711  other.furthestDescendantDistance = 0;
712  other.localMetric = false;
713  other.localDataset = false;
714  other.metric = NULL;
715 
716  return *this;
717 }
718 
719 // Construct from a cereal archive.
720 template<
721  typename MetricType,
722  typename StatisticType,
723  typename MatType,
724  typename RootPointPolicy
725 >
726 template<typename Archive>
728  Archive& ar,
729  const typename std::enable_if_t<cereal::is_loading<Archive>()>*) :
730  CoverTree() // Create an empty CoverTree.
731 {
732  // Now, serialize to our empty tree.
733  ar(cereal::make_nvp("this", *this));
734 }
735 
736 
737 template<
738  typename MetricType,
739  typename StatisticType,
740  typename MatType,
741  typename RootPointPolicy
742 >
744 {
745  // Delete each child.
746  for (size_t i = 0; i < children.size(); ++i)
747  delete children[i];
748 
749  // Delete the local metric, if necessary.
750  if (localMetric)
751  delete metric;
752 
753  // Delete the local dataset, if necessary.
754  if (localDataset)
755  delete dataset;
756 }
757 
759 template<
760  typename MetricType,
761  typename StatisticType,
762  typename MatType,
763  typename RootPointPolicy
764 >
765 inline size_t
768 {
769  return numDescendants;
770 }
771 
773 template<
774  typename MetricType,
775  typename StatisticType,
776  typename MatType,
777  typename RootPointPolicy
778 >
779 inline size_t
781  const size_t index) const
782 {
783  // The first descendant is the point contained within this node.
784  if (index == 0)
785  return point;
786 
787  // Is it in the self-child?
788  if (index < children[0]->NumDescendants())
789  return children[0]->Descendant(index);
790 
791  // Now check the other children.
792  size_t sum = children[0]->NumDescendants();
793  for (size_t i = 1; i < children.size(); ++i)
794  {
795  if (index - sum < children[i]->NumDescendants())
796  return children[i]->Descendant(index - sum);
797  sum += children[i]->NumDescendants();
798  }
799 
800  // This should never happen.
801  return (size_t() - 1);
802 }
803 
808 template<typename MetricType,
809  typename StatisticType,
810  typename MatType,
811  typename RootPointPolicy>
812 template<typename VecType>
814  GetNearestChild(const VecType& point,
815  typename std::enable_if_t<IsVector<VecType>::value>*)
816 {
817  if (IsLeaf())
818  return 0;
819 
820  ElemType bestDistance = std::numeric_limits<ElemType>::max();
821  size_t bestIndex = 0;
822  for (size_t i = 0; i < children.size(); ++i)
823  {
824  ElemType distance = children[i]->MinDistance(point);
825  if (distance <= bestDistance)
826  {
827  bestDistance = distance;
828  bestIndex = i;
829  }
830  }
831  return bestIndex;
832 }
833 
838 template<typename MetricType,
839  typename StatisticType,
840  typename MatType,
841  typename RootPointPolicy>
842 template<typename VecType>
844  GetFurthestChild(const VecType& point,
845  typename std::enable_if_t<IsVector<VecType>::value>*)
846 {
847  if (IsLeaf())
848  return 0;
849 
850  ElemType bestDistance = 0;
851  size_t bestIndex = 0;
852  for (size_t i = 0; i < children.size(); ++i)
853  {
854  ElemType distance = children[i]->MaxDistance(point);
855  if (distance >= bestDistance)
856  {
857  bestDistance = distance;
858  bestIndex = i;
859  }
860  }
861  return bestIndex;
862 }
863 
868 template<typename MetricType,
869  typename StatisticType,
870  typename MatType,
871  typename RootPointPolicy>
873  GetNearestChild(const CoverTree& queryNode)
874 {
875  if (IsLeaf())
876  return 0;
877 
878  ElemType bestDistance = std::numeric_limits<ElemType>::max();
879  size_t bestIndex = 0;
880  for (size_t i = 0; i < children.size(); ++i)
881  {
882  ElemType distance = children[i]->MinDistance(queryNode);
883  if (distance <= bestDistance)
884  {
885  bestDistance = distance;
886  bestIndex = i;
887  }
888  }
889  return bestIndex;
890 }
891 
896 template<typename MetricType,
897  typename StatisticType,
898  typename MatType,
899  typename RootPointPolicy>
901  GetFurthestChild(const CoverTree& queryNode)
902 {
903  if (IsLeaf())
904  return 0;
905 
906  ElemType bestDistance = 0;
907  size_t bestIndex = 0;
908  for (size_t i = 0; i < children.size(); ++i)
909  {
910  ElemType distance = children[i]->MaxDistance(queryNode);
911  if (distance >= bestDistance)
912  {
913  bestDistance = distance;
914  bestIndex = i;
915  }
916  }
917  return bestIndex;
918 }
919 
920 template<
921  typename MetricType,
922  typename StatisticType,
923  typename MatType,
924  typename RootPointPolicy
925 >
926 typename CoverTree<MetricType, StatisticType, MatType,
927  RootPointPolicy>::ElemType
929  MinDistance(const CoverTree& other) const
930 {
931  // Every cover tree node will contain points up to base^(scale + 1) away.
932  return std::max(metric->Evaluate(dataset->col(point),
933  other.Dataset().col(other.Point())) -
934  furthestDescendantDistance - other.FurthestDescendantDistance(), 0.0);
935 }
936 
937 template<
938  typename MetricType,
939  typename StatisticType,
940  typename MatType,
941  typename RootPointPolicy
942 >
943 typename CoverTree<MetricType, StatisticType, MatType,
944  RootPointPolicy>::ElemType
946  MinDistance(const CoverTree& other, const ElemType distance) const
947 {
948  // We already have the distance as evaluated by the metric.
949  return std::max(distance - furthestDescendantDistance -
950  other.FurthestDescendantDistance(), 0.0);
951 }
952 
953 template<
954  typename MetricType,
955  typename StatisticType,
956  typename MatType,
957  typename RootPointPolicy
958 >
959 typename CoverTree<MetricType, StatisticType, MatType,
960  RootPointPolicy>::ElemType
962  MinDistance(const arma::vec& other) const
963 {
964  return std::max(metric->Evaluate(dataset->col(point), other) -
965  furthestDescendantDistance, 0.0);
966 }
967 
968 template<
969  typename MetricType,
970  typename StatisticType,
971  typename MatType,
972  typename RootPointPolicy
973 >
974 typename CoverTree<MetricType, StatisticType, MatType,
975  RootPointPolicy>::ElemType
977  MinDistance(const arma::vec& /* other */, const ElemType distance) const
978 {
979  return std::max(distance - furthestDescendantDistance, 0.0);
980 }
981 
982 template<
983  typename MetricType,
984  typename StatisticType,
985  typename MatType,
986  typename RootPointPolicy
987 >
988 typename CoverTree<MetricType, StatisticType, MatType,
989  RootPointPolicy>::ElemType
991  MaxDistance(const CoverTree& other) const
992 {
993  return metric->Evaluate(dataset->col(point),
994  other.Dataset().col(other.Point())) +
995  furthestDescendantDistance + other.FurthestDescendantDistance();
996 }
997 
998 template<
999  typename MetricType,
1000  typename StatisticType,
1001  typename MatType,
1002  typename RootPointPolicy
1003 >
1004 typename CoverTree<MetricType, StatisticType, MatType,
1005  RootPointPolicy>::ElemType
1007  MaxDistance(const CoverTree& other, const ElemType distance) const
1008 {
1009  // We already have the distance as evaluated by the metric.
1010  return distance + furthestDescendantDistance +
1012 }
1013 
1014 template<
1015  typename MetricType,
1016  typename StatisticType,
1017  typename MatType,
1018  typename RootPointPolicy
1019 >
1020 typename CoverTree<MetricType, StatisticType, MatType,
1021  RootPointPolicy>::ElemType
1023  MaxDistance(const arma::vec& other) const
1024 {
1025  return metric->Evaluate(dataset->col(point), other) +
1026  furthestDescendantDistance;
1027 }
1028 
1029 template<
1030  typename MetricType,
1031  typename StatisticType,
1032  typename MatType,
1033  typename RootPointPolicy
1034 >
1035 typename CoverTree<MetricType, StatisticType, MatType,
1036  RootPointPolicy>::ElemType
1038  MaxDistance(const arma::vec& /* other */, const ElemType distance) const
1039 {
1040  return distance + furthestDescendantDistance;
1041 }
1042 
1044 template<
1045  typename MetricType,
1046  typename StatisticType,
1047  typename MatType,
1048  typename RootPointPolicy
1049 >
1050 math::RangeType<typename
1053  RangeDistance(const CoverTree& other) const
1054 {
1055  const ElemType distance = metric->Evaluate(dataset->col(point),
1056  other.Dataset().col(other.Point()));
1057 
1059  result.Lo() = std::max(distance - furthestDescendantDistance -
1060  other.FurthestDescendantDistance(), 0.0);
1061  result.Hi() = distance + furthestDescendantDistance +
1063 
1064  return result;
1065 }
1066 
1069 template<
1070  typename MetricType,
1071  typename StatisticType,
1072  typename MatType,
1073  typename RootPointPolicy
1074 >
1075 math::RangeType<typename
1079  const ElemType distance) const
1080 {
1082  result.Lo() = std::max(distance - furthestDescendantDistance -
1083  other.FurthestDescendantDistance(), 0.0);
1084  result.Hi() = distance + furthestDescendantDistance +
1086 
1087  return result;
1088 }
1089 
1091 template<
1092  typename MetricType,
1093  typename StatisticType,
1094  typename MatType,
1095  typename RootPointPolicy
1096 >
1097 math::RangeType<typename
1100  RangeDistance(const arma::vec& other) const
1101 {
1102  const ElemType distance = metric->Evaluate(dataset->col(point), other);
1103 
1105  std::max(distance - furthestDescendantDistance, 0.0),
1106  distance + furthestDescendantDistance);
1107 }
1108 
1111 template<
1112  typename MetricType,
1113  typename StatisticType,
1114  typename MatType,
1115  typename RootPointPolicy
1116 >
1117 math::RangeType<typename
1120  RangeDistance(const arma::vec& /* other */,
1121  const ElemType distance) const
1122 {
1124  std::max(distance - furthestDescendantDistance, 0.0),
1125  distance + furthestDescendantDistance);
1126 }
1127 
1129 template<
1130  typename MetricType,
1131  typename StatisticType,
1132  typename MatType,
1133  typename RootPointPolicy
1134 >
1135 inline void
1137  arma::Col<size_t>& indices,
1138  arma::vec& distances,
1139  size_t nearSetSize,
1140  size_t& farSetSize,
1141  size_t& usedSetSize)
1142 {
1143  // Determine the next scale level. This should be the first level where there
1144  // are any points in the far set. So, if we know the maximum distance in the
1145  // distances array, this will be the largest i such that
1146  // maxDistance > pow(base, i)
1147  // and using this for the scale factor should guarantee we are not creating an
1148  // implicit node. If the maximum distance is 0, every point in the near set
1149  // will be created as a leaf, and a child to this node. We also do not need
1150  // to change the furthestChildDistance or furthestDescendantDistance.
1151  const ElemType maxDistance = max(distances.rows(0,
1152  nearSetSize + farSetSize - 1));
1153  if (maxDistance == 0)
1154  {
1155  // Make the self child at the lowest possible level.
1156  // This should not modify farSetSize or usedSetSize.
1157  size_t tempSize = 0;
1158  children.push_back(new CoverTree(*dataset, base, point, INT_MIN, this, 0,
1159  indices, distances, 0, tempSize, usedSetSize, *metric));
1160  distanceComps += children.back()->DistanceComps();
1161 
1162  // Every point in the near set should be a leaf.
1163  for (size_t i = 0; i < nearSetSize; ++i)
1164  {
1165  // farSetSize and usedSetSize will not be modified.
1166  children.push_back(new CoverTree(*dataset, base, indices[i],
1167  INT_MIN, this, distances[i], indices, distances, 0, tempSize,
1168  usedSetSize, *metric));
1169  distanceComps += children.back()->DistanceComps();
1170  usedSetSize++;
1171  }
1172 
1173  // The number of descendants is just the number of children, because each of
1174  // them are leaves and contain one point.
1175  numDescendants = children.size();
1176 
1177  // Re-sort the dataset. We have
1178  // [ used | far | other used ]
1179  // and we want
1180  // [ far | all used ].
1181  SortPointSet(indices, distances, 0, usedSetSize, farSetSize);
1182 
1183  return;
1184  }
1185 
1186  const int nextScale = std::min(scale,
1187  (int) ceil(log(maxDistance) / log(base))) - 1;
1188  const ElemType bound = pow(base, nextScale);
1189 
1190  // First, make the self child. We must split the given near set into the near
1191  // set and far set for the self child.
1192  size_t childNearSetSize =
1193  SplitNearFar(indices, distances, bound, nearSetSize);
1194 
1195  // Build the self child (recursively).
1196  size_t childFarSetSize = nearSetSize - childNearSetSize;
1197  size_t childUsedSetSize = 0;
1198  children.push_back(new CoverTree(*dataset, base, point, nextScale, this, 0,
1199  indices, distances, childNearSetSize, childFarSetSize, childUsedSetSize,
1200  *metric));
1201  // Don't double-count the self-child (so, subtract one).
1202  numDescendants += children[0]->NumDescendants();
1203 
1204  // The self-child can't modify the furthestChildDistance away from 0, but it
1205  // can modify the furthestDescendantDistance.
1206  furthestDescendantDistance = children[0]->FurthestDescendantDistance();
1207 
1208  // Remove any implicit nodes we may have created.
1209  RemoveNewImplicitNodes();
1210 
1211  distanceComps += children[0]->DistanceComps();
1212 
1213  // Now the arrays, in memory, look like this:
1214  // [ childFar | childUsed | far | used ]
1215  // but we need to move the used points past our far set:
1216  // [ childFar | far | childUsed + used ]
1217  // and keeping in mind that childFar = our near set,
1218  // [ near | far | childUsed + used ]
1219  // is what we are trying to make.
1220  SortPointSet(indices, distances, childFarSetSize, childUsedSetSize,
1221  farSetSize);
1222 
1223  // Update size of near set and used set.
1224  nearSetSize -= childUsedSetSize;
1225  usedSetSize += childUsedSetSize;
1226 
1227  // Now for each point in the near set, we need to make children. To save
1228  // computation later, we'll create an array holding the points in the near
1229  // set, and then after each run we'll check which of those (if any) were used
1230  // and we will remove them. ...if that's faster. I think it is.
1231  while (nearSetSize > 0)
1232  {
1233  size_t newPointIndex = nearSetSize - 1;
1234 
1235  // Swap to front if necessary.
1236  if (newPointIndex != 0)
1237  {
1238  const size_t tempIndex = indices[newPointIndex];
1239  const ElemType tempDist = distances[newPointIndex];
1240 
1241  indices[newPointIndex] = indices[0];
1242  distances[newPointIndex] = distances[0];
1243 
1244  indices[0] = tempIndex;
1245  distances[0] = tempDist;
1246  }
1247 
1248  // Will this be a new furthest child?
1249  if (distances[0] > furthestDescendantDistance)
1250  furthestDescendantDistance = distances[0];
1251 
1252  // If there's only one point left, we don't need this crap.
1253  if ((nearSetSize == 1) && (farSetSize == 0))
1254  {
1255  size_t childNearSetSize = 0;
1256  children.push_back(new CoverTree(*dataset, base, indices[0], nextScale,
1257  this, distances[0], indices, distances, childNearSetSize, farSetSize,
1258  usedSetSize, *metric));
1259  distanceComps += children.back()->DistanceComps();
1260  numDescendants += children.back()->NumDescendants();
1261 
1262  // Because the far set size is 0, we don't have to do any swapping to
1263  // move the point into the used set.
1264  ++usedSetSize;
1265  --nearSetSize;
1266 
1267  // And we're done.
1268  break;
1269  }
1270 
1271  // Create the near and far set indices and distance vectors. We don't fill
1272  // in the self-point, yet.
1273  arma::Col<size_t> childIndices(nearSetSize + farSetSize);
1274  childIndices.rows(0, (nearSetSize + farSetSize - 2)) = indices.rows(1,
1275  nearSetSize + farSetSize - 1);
1276  arma::vec childDistances(nearSetSize + farSetSize);
1277 
1278  // Build distances for the child.
1279  ComputeDistances(indices[0], childIndices, childDistances, nearSetSize
1280  + farSetSize - 1);
1281 
1282  // Split into near and far sets for this point.
1283  childNearSetSize = SplitNearFar(childIndices, childDistances, bound,
1284  nearSetSize + farSetSize - 1);
1285  childFarSetSize = PruneFarSet(childIndices, childDistances,
1286  base * bound, childNearSetSize,
1287  (nearSetSize + farSetSize - 1));
1288 
1289  // Now that we know the near and far set sizes, we can put the used point
1290  // (the self point) in the correct place; now, when we call
1291  // MoveToUsedSet(), it will move the self-point correctly. The distance
1292  // does not matter.
1293  childIndices(childNearSetSize + childFarSetSize) = indices[0];
1294  childDistances(childNearSetSize + childFarSetSize) = 0;
1295 
1296  // Build this child (recursively).
1297  childUsedSetSize = 1; // Mark self point as used.
1298  children.push_back(new CoverTree(*dataset, base, indices[0], nextScale,
1299  this, distances[0], childIndices, childDistances, childNearSetSize,
1300  childFarSetSize, childUsedSetSize, *metric));
1301  numDescendants += children.back()->NumDescendants();
1302 
1303  // Remove any implicit nodes.
1304  RemoveNewImplicitNodes();
1305 
1306  distanceComps += children.back()->DistanceComps();
1307 
1308  // Now with the child created, it returns the childIndices and
1309  // childDistances vectors in this form:
1310  // [ childFar | childUsed ]
1311  // For each point in the childUsed set, we must move that point to the used
1312  // set in our own vector.
1313  MoveToUsedSet(indices, distances, nearSetSize, farSetSize, usedSetSize,
1314  childIndices, childFarSetSize, childUsedSetSize);
1315  }
1316 
1317  // Calculate furthest descendant.
1318  for (size_t i = (nearSetSize + farSetSize); i < (nearSetSize + farSetSize +
1319  usedSetSize); ++i)
1320  if (distances[i] > furthestDescendantDistance)
1321  furthestDescendantDistance = distances[i];
1322 }
1323 
1324 template<
1325  typename MetricType,
1326  typename StatisticType,
1327  typename MatType,
1328  typename RootPointPolicy
1329 >
1331  SplitNearFar(arma::Col<size_t>& indices,
1332  arma::vec& distances,
1333  const ElemType bound,
1334  const size_t pointSetSize)
1335 {
1336  // Sanity check; there is no guarantee that this condition will not be true.
1337  // ...or is there?
1338  if (pointSetSize <= 1)
1339  return 0;
1340 
1341  // We'll traverse from both left and right.
1342  size_t left = 0;
1343  size_t right = pointSetSize - 1;
1344 
1345  // A modification of quicksort, with the pivot value set to the bound.
1346  // Everything on the left of the pivot will be less than or equal to the
1347  // bound; everything on the right will be greater than the bound.
1348  while ((distances[left] <= bound) && (left != right))
1349  ++left;
1350  while ((distances[right] > bound) && (left != right))
1351  --right;
1352 
1353  while (left != right)
1354  {
1355  // Now swap the values and indices.
1356  const size_t tempPoint = indices[left];
1357  const ElemType tempDist = distances[left];
1358 
1359  indices[left] = indices[right];
1360  distances[left] = distances[right];
1361 
1362  indices[right] = tempPoint;
1363  distances[right] = tempDist;
1364 
1365  // Traverse the left, seeing how many points are correctly on that side.
1366  // When we encounter an incorrect point, stop. We will switch it later.
1367  while ((distances[left] <= bound) && (left != right))
1368  ++left;
1369 
1370  // Traverse the right, seeing how many points are correctly on that side.
1371  // When we encounter an incorrect point, stop. We will switch it with the
1372  // wrong point from the left side.
1373  while ((distances[right] > bound) && (left != right))
1374  --right;
1375  }
1376 
1377  // The final left value is the index of the first far value.
1378  return left;
1379 }
1380 
1381 // Returns the maximum distance between points.
1382 template<
1383  typename MetricType,
1384  typename StatisticType,
1385  typename MatType,
1386  typename RootPointPolicy
1387 >
1389  ComputeDistances(const size_t pointIndex,
1390  const arma::Col<size_t>& indices,
1391  arma::vec& distances,
1392  const size_t pointSetSize)
1393 {
1394  // For each point, rebuild the distances. The indices do not need to be
1395  // modified.
1396  distanceComps += pointSetSize;
1397  for (size_t i = 0; i < pointSetSize; ++i)
1398  {
1399  distances[i] = metric->Evaluate(dataset->col(pointIndex),
1400  dataset->col(indices[i]));
1401  }
1402 }
1403 
1404 template<
1405  typename MetricType,
1406  typename StatisticType,
1407  typename MatType,
1408  typename RootPointPolicy
1409 >
1411  SortPointSet(arma::Col<size_t>& indices,
1412  arma::vec& distances,
1413  const size_t childFarSetSize,
1414  const size_t childUsedSetSize,
1415  const size_t farSetSize)
1416 {
1417  // We'll use low-level memcpy calls ourselves, just to ensure it's done
1418  // quickly and the way we want it to be. Unfortunately this takes up more
1419  // memory than one-element swaps, but there's not a great way around that.
1420  const size_t bufferSize = std::min(farSetSize, childUsedSetSize);
1421  const size_t bigCopySize = std::max(farSetSize, childUsedSetSize);
1422 
1423  // Sanity check: there is no need to sort if the buffer size is going to be
1424  // zero.
1425  if (bufferSize == 0)
1426  return (childFarSetSize + farSetSize);
1427 
1428  size_t* indicesBuffer = new size_t[bufferSize];
1429  ElemType* distancesBuffer = new ElemType[bufferSize];
1430 
1431  // The start of the memory region to copy to the buffer.
1432  const size_t bufferFromLocation = ((bufferSize == farSetSize) ?
1433  (childFarSetSize + childUsedSetSize) : childFarSetSize);
1434  // The start of the memory region to move directly to the new place.
1435  const size_t directFromLocation = ((bufferSize == farSetSize) ?
1436  childFarSetSize : (childFarSetSize + childUsedSetSize));
1437  // The destination to copy the buffer back to.
1438  const size_t bufferToLocation = ((bufferSize == farSetSize) ?
1439  childFarSetSize : (childFarSetSize + farSetSize));
1440  // The destination of the directly moved memory region.
1441  const size_t directToLocation = ((bufferSize == farSetSize) ?
1442  (childFarSetSize + farSetSize) : childFarSetSize);
1443 
1444  // Copy the smaller piece to the buffer.
1445  memcpy(indicesBuffer, indices.memptr() + bufferFromLocation,
1446  sizeof(size_t) * bufferSize);
1447  memcpy(distancesBuffer, distances.memptr() + bufferFromLocation,
1448  sizeof(ElemType) * bufferSize);
1449 
1450  // Now move the other memory.
1451  memmove(indices.memptr() + directToLocation,
1452  indices.memptr() + directFromLocation, sizeof(size_t) * bigCopySize);
1453  memmove(distances.memptr() + directToLocation,
1454  distances.memptr() + directFromLocation, sizeof(ElemType) * bigCopySize);
1455 
1456  // Now copy the temporary memory to the right place.
1457  memcpy(indices.memptr() + bufferToLocation, indicesBuffer,
1458  sizeof(size_t) * bufferSize);
1459  memcpy(distances.memptr() + bufferToLocation, distancesBuffer,
1460  sizeof(ElemType) * bufferSize);
1461 
1462  delete[] indicesBuffer;
1463  delete[] distancesBuffer;
1464 
1465  // This returns the complete size of the far set.
1466  return (childFarSetSize + farSetSize);
1467 }
1468 
1469 template<
1470  typename MetricType,
1471  typename StatisticType,
1472  typename MatType,
1473  typename RootPointPolicy
1474 >
1476  MoveToUsedSet(arma::Col<size_t>& indices,
1477  arma::vec& distances,
1478  size_t& nearSetSize,
1479  size_t& farSetSize,
1480  size_t& usedSetSize,
1481  arma::Col<size_t>& childIndices,
1482  const size_t childFarSetSize, // childNearSetSize is 0 here.
1483  const size_t childUsedSetSize)
1484 {
1485  const size_t originalSum = nearSetSize + farSetSize + usedSetSize;
1486 
1487  // Loop across the set. We will swap points as we need. It should be noted
1488  // that farSetSize and nearSetSize may change with each iteration of this loop
1489  // (depending on if we make a swap or not).
1490  size_t startChildUsedSet = 0; // Where to start in the child set.
1491  for (size_t i = 0; i < nearSetSize; ++i)
1492  {
1493  // Discover if this point was in the child's used set.
1494  for (size_t j = startChildUsedSet; j < childUsedSetSize; ++j)
1495  {
1496  if (childIndices[childFarSetSize + j] == indices[i])
1497  {
1498  // We have found a point; a swap is necessary.
1499 
1500  // Since this point is from the near set, to preserve the near set, we
1501  // must do a swap.
1502  if (farSetSize > 0)
1503  {
1504  if ((nearSetSize - 1) != i)
1505  {
1506  // In this case it must be a three-way swap.
1507  size_t tempIndex = indices[nearSetSize + farSetSize - 1];
1508  ElemType tempDist = distances[nearSetSize + farSetSize - 1];
1509 
1510  size_t tempNearIndex = indices[nearSetSize - 1];
1511  ElemType tempNearDist = distances[nearSetSize - 1];
1512 
1513  indices[nearSetSize + farSetSize - 1] = indices[i];
1514  distances[nearSetSize + farSetSize - 1] = distances[i];
1515 
1516  indices[nearSetSize - 1] = tempIndex;
1517  distances[nearSetSize - 1] = tempDist;
1518 
1519  indices[i] = tempNearIndex;
1520  distances[i] = tempNearDist;
1521  }
1522  else
1523  {
1524  // We can do a two-way swap.
1525  size_t tempIndex = indices[nearSetSize + farSetSize - 1];
1526  ElemType tempDist = distances[nearSetSize + farSetSize - 1];
1527 
1528  indices[nearSetSize + farSetSize - 1] = indices[i];
1529  distances[nearSetSize + farSetSize - 1] = distances[i];
1530 
1531  indices[i] = tempIndex;
1532  distances[i] = tempDist;
1533  }
1534  }
1535  else if ((nearSetSize - 1) != i)
1536  {
1537  // A two-way swap is possible.
1538  size_t tempIndex = indices[nearSetSize + farSetSize - 1];
1539  ElemType tempDist = distances[nearSetSize + farSetSize - 1];
1540 
1541  indices[nearSetSize + farSetSize - 1] = indices[i];
1542  distances[nearSetSize + farSetSize - 1] = distances[i];
1543 
1544  indices[i] = tempIndex;
1545  distances[i] = tempDist;
1546  }
1547  else
1548  {
1549  // No swap is necessary.
1550  }
1551 
1552  // We don't need to do a complete preservation of the child index set,
1553  // but we want to make sure we only loop over points we haven't seen.
1554  // So increment the child counter by 1 and move a point if we need.
1555  if (j != startChildUsedSet)
1556  {
1557  childIndices[childFarSetSize + j] = childIndices[childFarSetSize +
1558  startChildUsedSet];
1559  }
1560 
1561  // Update all counters from the swaps we have done.
1562  ++startChildUsedSet;
1563  --nearSetSize;
1564  --i; // Since we moved a point out of the near set we must step back.
1565 
1566  break; // Break out of this for loop; back to the first one.
1567  }
1568  }
1569  }
1570 
1571  // Now loop over the far set. This loop is different because we only require
1572  // a normal two-way swap instead of the three-way swap to preserve the near
1573  // set / far set ordering.
1574  for (size_t i = 0; i < farSetSize; ++i)
1575  {
1576  // Discover if this point was in the child's used set.
1577  for (size_t j = startChildUsedSet; j < childUsedSetSize; ++j)
1578  {
1579  if (childIndices[childFarSetSize + j] == indices[i + nearSetSize])
1580  {
1581  // We have found a point to swap.
1582 
1583  // Perform the swap.
1584  size_t tempIndex = indices[nearSetSize + farSetSize - 1];
1585  ElemType tempDist = distances[nearSetSize + farSetSize - 1];
1586 
1587  indices[nearSetSize + farSetSize - 1] = indices[nearSetSize + i];
1588  distances[nearSetSize + farSetSize - 1] = distances[nearSetSize + i];
1589 
1590  indices[nearSetSize + i] = tempIndex;
1591  distances[nearSetSize + i] = tempDist;
1592 
1593  if (j != startChildUsedSet)
1594  {
1595  childIndices[childFarSetSize + j] = childIndices[childFarSetSize +
1596  startChildUsedSet];
1597  }
1598 
1599  // Update all counters from the swaps we have done.
1600  ++startChildUsedSet;
1601  --farSetSize;
1602  --i;
1603 
1604  break; // Break out of this for loop; back to the first one.
1605  }
1606  }
1607  }
1608 
1609  // Update used set size.
1610  usedSetSize += childUsedSetSize;
1611 
1612  Log::Assert(originalSum == (nearSetSize + farSetSize + usedSetSize));
1613 }
1614 
1615 template<
1616  typename MetricType,
1617  typename StatisticType,
1618  typename MatType,
1619  typename RootPointPolicy
1620 >
1622  PruneFarSet(arma::Col<size_t>& indices,
1623  arma::vec& distances,
1624  const ElemType bound,
1625  const size_t nearSetSize,
1626  const size_t pointSetSize)
1627 {
1628  // What we are trying to do is remove any points greater than the bound from
1629  // the far set. We don't care what happens to those indices and distances...
1630  // so, we don't need to properly swap points -- just drop new ones in place.
1631  size_t left = nearSetSize;
1632  size_t right = pointSetSize - 1;
1633  while ((distances[left] <= bound) && (left != right))
1634  ++left;
1635  while ((distances[right] > bound) && (left != right))
1636  --right;
1637 
1638  while (left != right)
1639  {
1640  // We don't care what happens to the point which should be on the right.
1641  indices[left] = indices[right];
1642  distances[left] = distances[right];
1643  --right; // Since we aren't changing the right.
1644 
1645  // Advance to next location which needs to switch.
1646  while ((distances[left] <= bound) && (left != right))
1647  ++left;
1648  while ((distances[right] > bound) && (left != right))
1649  --right;
1650  }
1651 
1652  // The far set size is the left pointer, with the near set size subtracted
1653  // from it.
1654  return (left - nearSetSize);
1655 }
1656 
1661 template<
1662  typename MetricType,
1663  typename StatisticType,
1664  typename MatType,
1665  typename RootPointPolicy
1666 >
1669 {
1670  // If we created an implicit node, take its self-child instead (this could
1671  // happen multiple times).
1672  while (children[children.size() - 1]->NumChildren() == 1)
1673  {
1674  CoverTree* old = children[children.size() - 1];
1675  children.erase(children.begin() + children.size() - 1);
1676 
1677  // Now take its child.
1678  children.push_back(&(old->Child(0)));
1679 
1680  // Set its parent and parameters correctly.
1681  old->Child(0).Parent() = this;
1682  old->Child(0).ParentDistance() = old->ParentDistance();
1683  old->Child(0).DistanceComps() = old->DistanceComps();
1684 
1685  // Remove its child (so it doesn't delete it).
1686  old->Children().erase(old->Children().begin() + old->Children().size() - 1);
1687 
1688  // Now delete it.
1689  delete old;
1690  }
1691 }
1692 
1696 template<
1697  typename MetricType,
1698  typename StatisticType,
1699  typename MatType,
1700  typename RootPointPolicy
1701 >
1703  dataset(NULL),
1704  point(0),
1705  scale(INT_MIN),
1706  base(0.0),
1707  numDescendants(0),
1708  parent(NULL),
1709  parentDistance(0.0),
1710  furthestDescendantDistance(0.0),
1711  localMetric(false),
1712  localDataset(false),
1713  metric(NULL),
1714  distanceComps(0)
1715 {
1716  // Nothing to do.
1717 }
1718 
1722 template<
1723  typename MetricType,
1724  typename StatisticType,
1725  typename MatType,
1726  typename RootPointPolicy
1727 >
1728 template<typename Archive>
1730  Archive& ar,
1731  const uint32_t /* version */)
1732 {
1733  // If we're loading, and we have children, they need to be deleted. We may
1734  // also need to delete the local metric and dataset.
1735  if (cereal::is_loading<Archive>())
1736  {
1737  for (size_t i = 0; i < children.size(); ++i)
1738  delete children[i];
1739 
1740  if (localMetric && metric)
1741  delete metric;
1742  if (localDataset && dataset)
1743  delete dataset;
1744 
1745  parent = NULL;
1746  }
1747 
1748  bool hasParent = (parent != NULL);
1749  ar(CEREAL_NVP(hasParent));
1750  MatType*& datasetTemp = const_cast<MatType*&>(dataset);
1751  if (!hasParent)
1752  ar(CEREAL_POINTER(datasetTemp));
1753 
1754  ar(CEREAL_NVP(point));
1755  ar(CEREAL_NVP(scale));
1756  ar(CEREAL_NVP(base));
1757  ar(CEREAL_NVP(stat));
1758  ar(CEREAL_NVP(numDescendants));
1759  ar(CEREAL_NVP(parentDistance));
1760  ar(CEREAL_NVP(furthestDescendantDistance));
1761  ar(CEREAL_POINTER(metric));
1762 
1763  if (cereal::is_loading<Archive>() && !hasParent)
1764  {
1765  localMetric = true;
1766  localDataset = true;
1767  }
1768 
1769  // Lastly, serialize the children.
1770  ar(CEREAL_VECTOR_POINTER(children));
1771 
1772  if (cereal::is_loading<Archive>())
1773  {
1774  // Look through each child individually.
1775  for (size_t i = 0; i < children.size(); ++i)
1776  {
1777  children[i]->localMetric = false;
1778  children[i]->localDataset = false;
1779  children[i]->Parent() = this;
1780  }
1781  }
1782 
1783  if (!hasParent)
1784  {
1785  std::stack<CoverTree*> stack;
1786  for (size_t i = 0; i < children.size(); ++i)
1787  {
1788  stack.push(children[i]);
1789  }
1790  while (!stack.empty())
1791  {
1792  CoverTree* node = stack.top();
1793  stack.pop();
1794  node->dataset = dataset;
1795  for (size_t i = 0; i < node->children.size(); ++i)
1796  {
1797  stack.push(node->children[i]);
1798  }
1799  }
1800  }
1801 }
1802 
1803 } // namespace tree
1804 } // namespace mlpack
1805 
1806 #endif
T Lo() const
Get the lower bound.
Definition: range.hpp:61
CoverTree()
A default constructor.
Definition: cover_tree_impl.hpp:1702
ElemType ParentDistance() const
Get the distance to the parent.
Definition: cover_tree.hpp:409
int Scale() const
Get the scale of this node.
Definition: cover_tree.hpp:315
void serialize(Archive &ar, const uint32_t)
Serialize the tree.
Definition: cover_tree_impl.hpp:1729
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
CoverTree & operator=(const CoverTree &other)
Copy the given Cover Tree.
Definition: cover_tree_impl.hpp:560
Definition: pointer_wrapper.hpp:23
MatType::elem_type ElemType
The type held by the matrix type.
Definition: cover_tree.hpp:105
size_t GetNearestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the nearest child node to the given query point.
Definition: cover_tree_impl.hpp:814
size_t NumDescendants() const
Get the number of descendant points.
Definition: cover_tree_impl.hpp:767
const std::vector< CoverTree * > & Children() const
Get the children.
Definition: cover_tree.hpp:304
const MatType & Dataset() const
Get a reference to the dataset.
Definition: cover_tree.hpp:283
ElemType FurthestDescendantDistance() const
Get the distance from the center of the node to the furthest descendant.
Definition: cover_tree.hpp:417
Simple real-valued range.
Definition: range.hpp:19
size_t GetFurthestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the furthest child node to the given query point.
Definition: cover_tree_impl.hpp:844
#define CEREAL_VECTOR_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_vector_wrapper.hpp:93
size_t Point() const
Get the index of the point which this node represents.
Definition: cover_tree.hpp:286
const CoverTree & Child(const size_t index) const
Get a particular child node.
Definition: cover_tree.hpp:294
T Hi() const
Get the upper bound.
Definition: range.hpp:66
ElemType MaxDistance(const CoverTree &other) const
Return the maximum distance to another node.
Definition: cover_tree_impl.hpp:991
ElemType MinDistance(const CoverTree &other) const
Return the minimum distance to another node.
Definition: cover_tree_impl.hpp:929
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_wrapper.hpp:96
math::RangeType< ElemType > RangeDistance(const CoverTree &other) const
Return the minimum and maximum distance to another node.
Definition: cover_tree_impl.hpp:1053
size_t NumChildren() const
Get the number of children.
Definition: cover_tree.hpp:301
A cover tree is a tree specifically designed to speed up nearest-neighbor computation in high-dimensional spaces.
Definition: cover_tree.hpp:99
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35
size_t Descendant(const size_t index) const
Get the index of a particular descendant point.
Definition: cover_tree_impl.hpp:780
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38
~CoverTree()
Delete this cover tree node and its children.
Definition: cover_tree_impl.hpp:743
CoverTree * Parent() const
Get the parent node.
Definition: cover_tree.hpp:404