mlpack
range_search_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_IMPL_HPP
13 #define MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_IMPL_HPP
14 
15 // Just in case it hasn't been included.
16 #include "range_search.hpp"
17 
18 // The rules for traversal.
19 #include "range_search_rules.hpp"
20 
21 namespace mlpack {
22 namespace range {
23 
24 template<typename TreeType, typename MatType>
25 TreeType* BuildTree(
26  MatType&& dataset,
27  std::vector<size_t>& oldFromNew,
28  const typename std::enable_if<
30 {
31  return new TreeType(std::forward<MatType>(dataset), oldFromNew);
32 }
33 
35 template<typename TreeType, typename MatType>
36 TreeType* BuildTree(
37  MatType&& dataset,
38  const std::vector<size_t>& /* oldFromNew */,
39  const typename std::enable_if<
41 {
42  return new TreeType(std::forward<MatType>(dataset));
43 }
44 
45 template<typename MetricType,
46  typename MatType,
47  template<typename TreeMetricType,
48  typename TreeStatType,
49  typename TreeMatType> class TreeType>
50 RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
51  MatType referenceSet,
52  const bool naive,
53  const bool singleMode,
54  const MetricType metric) :
55  referenceTree(naive ? NULL : BuildTree<Tree>(std::move(referenceSet),
56  oldFromNewReferences)),
57  referenceSet(naive ? new MatType(std::move(referenceSet)) :
58  &referenceTree->Dataset()),
59  treeOwner(!naive),
60  naive(naive),
61  singleMode(!naive && singleMode),
62  metric(metric),
63  baseCases(0),
64  scores(0)
65 {
66  // Nothing to do.
67 }
68 
69 template<typename MetricType,
70  typename MatType,
71  template<typename TreeMetricType,
72  typename TreeStatType,
73  typename TreeMatType> class TreeType>
74 RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
75  Tree* referenceTree,
76  const bool singleMode,
77  const MetricType metric) :
78  referenceTree(referenceTree),
79  referenceSet(&referenceTree->Dataset()),
80  treeOwner(false),
81  naive(false),
82  singleMode(singleMode),
83  metric(metric),
84  baseCases(0),
85  scores(0)
86 {
87  // Nothing else to initialize.
88 }
89 
90 template<typename MetricType,
91  typename MatType,
92  template<typename TreeMetricType,
93  typename TreeStatType,
94  typename TreeMatType> class TreeType>
95 RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
96  const bool naive,
97  const bool singleMode,
98  const MetricType metric) :
99  referenceTree(NULL),
100  referenceSet(naive ? new MatType() : NULL), // Empty matrix.
101  treeOwner(false),
102  naive(naive),
103  singleMode(singleMode),
104  metric(metric),
105  baseCases(0),
106  scores(0)
107 {
108  // Build the tree on the empty dataset, if necessary.
109  if (!naive)
110  {
111  referenceTree = BuildTree<Tree>(std::move(arma::mat()),
112  oldFromNewReferences);
113  referenceSet = &referenceTree->Dataset();
114  treeOwner = true;
115  }
116 }
117 
118 template<typename MetricType,
119  typename MatType,
120  template<typename TreeMetricType,
121  typename TreeStatType,
122  typename TreeMatType> class TreeType>
123 RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
124  const RangeSearch& other) :
125  oldFromNewReferences(other.oldFromNewReferences),
126  referenceTree(other.referenceTree ? new Tree(*other.referenceTree) : NULL),
127  referenceSet(other.referenceTree ? &referenceTree->Dataset() :
128  new MatType(*other.referenceSet)),
129  treeOwner(other.referenceTree),
130  naive(other.naive),
131  singleMode(other.singleMode),
132  metric(other.metric),
133  baseCases(other.baseCases),
134  scores(other.scores)
135 {
136  // Nothing to do.
137 }
138 
139 template<typename MetricType,
140  typename MatType,
141  template<typename TreeMetricType,
142  typename TreeStatType,
143  typename TreeMatType> class TreeType>
144 RangeSearch<MetricType, MatType, TreeType>::RangeSearch(RangeSearch&& other) :
145  oldFromNewReferences(std::move(other.oldFromNewReferences)),
146  referenceTree(other.referenceTree),
147  referenceSet(other.referenceSet),
148  treeOwner(other.treeOwner),
149  naive(other.naive),
150  singleMode(other.singleMode),
151  metric(std::move(other.metric)),
152  baseCases(other.baseCases),
153  scores(other.scores)
154 {
155  // Clear other object.
156  other.referenceTree =
157  BuildTree<Tree>(std::move(arma::mat()), other.oldFromNewReferences);
158  other.referenceSet = &other.referenceTree->Dataset();
159  other.treeOwner = true;
160  other.naive = false;
161  other.singleMode = false;
162  other.baseCases = 0;
163  other.scores = 0;
164 }
165 
166 template<typename MetricType,
167  typename MatType,
168  template<typename TreeMetricType,
169  typename TreeStatType,
170  typename TreeMatType> class TreeType>
171 RangeSearch<MetricType, MatType, TreeType>&
173 {
174  if (this != &other)
175  {
176  oldFromNewReferences = other.oldFromNewReferences;
177  referenceTree = other.referenceTree ? new Tree(*other.referenceTree) :
178  nullptr;
179  referenceSet = other.referenceTree ? &referenceTree->Dataset() :
180  new MatType(*other.referenceSet);
181  treeOwner = other.referenceTree;
182  naive = other.naive;
183  singleMode = other.singleMode;
184  metric = other.metric;
185  baseCases = other.baseCases;
186  scores = other.scores;
187  }
188  return *this;
189 }
190 
191 template<typename MetricType,
192  typename MatType,
193  template<typename TreeMetricType,
194  typename TreeStatType,
195  typename TreeMatType> class TreeType>
196 RangeSearch<MetricType, MatType, TreeType>&
198 {
199  if (this != &other)
200  {
201  // Clean memory first.
202  if (treeOwner)
203  delete referenceTree;
204  if (naive)
205  delete referenceSet;
206 
207  // Move the other model.
208  oldFromNewReferences = std::move(other.oldFromNewReferences);
209  referenceTree = other.referenceTree;
210  referenceSet = other.referenceSet;
211  treeOwner = other.treeOwner;
212  naive = other.naive;
213  singleMode = other.singleMode;
214  metric = std::move(other.metric);
215  baseCases = other.baseCases;
216  scores = other.scores;
217 
218  // Clear other object.
219  other.referenceTree = nullptr;
220  other.referenceSet = nullptr;
221  other.treeOwner = false;
222  other.naive = false;
223  other.singleMode = false;
224  other.baseCases = 0;
225  other.scores = 0;
226  }
227  return *this;
228 }
229 
230 template<typename MetricType,
231  typename MatType,
232  template<typename TreeMetricType,
233  typename TreeStatType,
234  typename TreeMatType> class TreeType>
235 RangeSearch<MetricType, MatType, TreeType>::~RangeSearch()
236 {
237  if (treeOwner && referenceTree)
238  delete referenceTree;
239  if (naive && referenceSet)
240  delete referenceSet;
241 }
242 
243 template<typename MetricType,
244  typename MatType,
245  template<typename TreeMetricType,
246  typename TreeStatType,
247  typename TreeMatType> class TreeType>
248 void RangeSearch<MetricType, MatType, TreeType>::Train(
249  MatType referenceSet)
250 {
251  // Clean up the old tree, if we built one.
252  if (treeOwner && referenceTree)
253  delete referenceTree;
254 
255  // We may need to rebuild the tree.
256  if (!naive)
257  {
258  referenceTree = BuildTree<Tree>(std::move(referenceSet),
259  oldFromNewReferences);
260  treeOwner = true;
261  }
262  else
263  {
264  treeOwner = false;
265  }
266 
267  // Delete the old reference set, if we owned it.
268  if (naive && this->referenceSet)
269  delete this->referenceSet;
270 
271  if (!naive)
272  {
273  this->referenceSet = &referenceTree->Dataset();
274  }
275  else
276  {
277  this->referenceSet = new MatType(std::move(referenceSet));
278  }
279 }
280 
281 template<typename MetricType,
282  typename MatType,
283  template<typename TreeMetricType,
284  typename TreeStatType,
285  typename TreeMatType> class TreeType>
286 void RangeSearch<MetricType, MatType, TreeType>::Train(
287  Tree* referenceTree)
288 {
289  if (naive)
290  throw std::invalid_argument("cannot train on given reference tree when "
291  "naive search (without trees) is desired");
292 
293  // Can only train when passed argument `referenceTree` is not nullptr.
294  if (treeOwner && referenceTree)
295  {
296  delete this->referenceTree;
297 
298  this->referenceTree = referenceTree;
299  this->referenceSet = &referenceTree->Dataset();
300  treeOwner = false;
301  }
302 }
303 
304 template<typename MetricType,
305  typename MatType,
306  template<typename TreeMetricType,
307  typename TreeStatType,
308  typename TreeMatType> class TreeType>
309 void RangeSearch<MetricType, MatType, TreeType>::Search(
310  const MatType& querySet,
311  const math::Range& range,
312  std::vector<std::vector<size_t>>& neighbors,
313  std::vector<std::vector<double>>& distances)
314 {
315  util::CheckSameDimensionality(querySet, *referenceSet,
316  "RangeSearch::Search()", "query set");
317 
318  // If there are no points, there is no search to be done.
319  if (referenceSet->n_cols == 0)
320  return;
321 
322  Timer::Start("range_search/computing_neighbors");
323 
324  // This will hold mappings for query points, if necessary.
325  std::vector<size_t> oldFromNewQueries;
326 
327  // If we have built the trees ourselves, then we will have to map all the
328  // indices back to their original indices when this computation is finished.
329  // To avoid extra copies, we will store the unmapped neighbors and distances
330  // in a separate object.
331  std::vector<std::vector<size_t>>* neighborPtr = &neighbors;
332  std::vector<std::vector<double>>* distancePtr = &distances;
333 
334  // Mapping is only necessary if the tree rearranges points.
336  {
337  // Query indices only need to be mapped if we are building the query tree
338  // ourselves.
339  if (!singleMode && !naive)
340  {
341  distancePtr = new std::vector<std::vector<double>>;
342  neighborPtr = new std::vector<std::vector<size_t>>;
343  }
344 
345  // Reference indices only need to be mapped if we built the reference tree
346  // ourselves.
347  else if (treeOwner)
348  neighborPtr = new std::vector<std::vector<size_t>>;
349  }
350 
351  // Resize each vector.
352  neighborPtr->clear(); // Just in case there was anything in it.
353  neighborPtr->resize(querySet.n_cols);
354  distancePtr->clear();
355  distancePtr->resize(querySet.n_cols);
356 
357  // Create the helper object for the traversal.
358  typedef RangeSearchRules<MetricType, Tree> RuleType;
359 
360  // Reset counts.
361  baseCases = 0;
362  scores = 0;
363 
364  if (naive)
365  {
366  RuleType rules(*referenceSet, querySet, range, *neighborPtr, *distancePtr,
367  metric);
368 
369  // The naive brute-force solution.
370  for (size_t i = 0; i < querySet.n_cols; ++i)
371  for (size_t j = 0; j < referenceSet->n_cols; ++j)
372  rules.BaseCase(i, j);
373 
374  baseCases += (querySet.n_cols * referenceSet->n_cols);
375  }
376  else if (singleMode)
377  {
378  // Create the traverser.
379  RuleType rules(*referenceSet, querySet, range, *neighborPtr, *distancePtr,
380  metric);
381  typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
382 
383  // Now have it traverse for each point.
384  for (size_t i = 0; i < querySet.n_cols; ++i)
385  traverser.Traverse(i, *referenceTree);
386 
387  baseCases += rules.BaseCases();
388  scores += rules.Scores();
389  }
390  else // Dual-tree recursion.
391  {
392  // Build the query tree.
393  Timer::Stop("range_search/computing_neighbors");
394  Timer::Start("range_search/tree_building");
395  Tree* queryTree = BuildTree<Tree>(querySet, oldFromNewQueries);
396  Timer::Stop("range_search/tree_building");
397  Timer::Start("range_search/computing_neighbors");
398 
399  // Create the traverser.
400  RuleType rules(*referenceSet, queryTree->Dataset(), range, *neighborPtr,
401  *distancePtr, metric);
402  typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
403 
404  traverser.Traverse(*queryTree, *referenceTree);
405 
406  baseCases += rules.BaseCases();
407  scores += rules.Scores();
408 
409  // Clean up tree memory.
410  delete queryTree;
411  }
412 
413  Timer::Stop("range_search/computing_neighbors");
414 
415  // Map points back to original indices, if necessary.
417  {
418  if (!singleMode && !naive && treeOwner)
419  {
420  // We must map both query and reference indices.
421  neighbors.clear();
422  neighbors.resize(querySet.n_cols);
423  distances.clear();
424  distances.resize(querySet.n_cols);
425 
426  for (size_t i = 0; i < distances.size(); ++i)
427  {
428  // Map distances (copy a column).
429  const size_t queryMapping = oldFromNewQueries[i];
430  distances[queryMapping] = (*distancePtr)[i];
431 
432  // Copy each neighbor individually, because we need to map it.
433  neighbors[queryMapping].resize(distances[queryMapping].size());
434  for (size_t j = 0; j < distances[queryMapping].size(); ++j)
435  neighbors[queryMapping][j] =
436  oldFromNewReferences[(*neighborPtr)[i][j]];
437  }
438 
439  // Finished with temporary objects.
440  delete neighborPtr;
441  delete distancePtr;
442  }
443  else if (!singleMode && !naive)
444  {
445  // We must map query indices only.
446  neighbors.clear();
447  neighbors.resize(querySet.n_cols);
448  distances.clear();
449  distances.resize(querySet.n_cols);
450 
451  for (size_t i = 0; i < distances.size(); ++i)
452  {
453  // Map distances and neighbors (copy a column).
454  const size_t queryMapping = oldFromNewQueries[i];
455  distances[queryMapping] = (*distancePtr)[i];
456  neighbors[queryMapping] = (*neighborPtr)[i];
457  }
458 
459  // Finished with temporary objects.
460  delete neighborPtr;
461  delete distancePtr;
462  }
463  else if (treeOwner)
464  {
465  // We must map reference indices only.
466  neighbors.clear();
467  neighbors.resize(querySet.n_cols);
468 
469  for (size_t i = 0; i < neighbors.size(); ++i)
470  {
471  neighbors[i].resize((*neighborPtr)[i].size());
472  for (size_t j = 0; j < neighbors[i].size(); ++j)
473  neighbors[i][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
474  }
475 
476  // Finished with temporary object.
477  delete neighborPtr;
478  }
479  }
480 }
481 
482 template<typename MetricType,
483  typename MatType,
484  template<typename TreeMetricType,
485  typename TreeStatType,
486  typename TreeMatType> class TreeType>
487 void RangeSearch<MetricType, MatType, TreeType>::Search(
488  Tree* queryTree,
489  const math::Range& range,
490  std::vector<std::vector<size_t>>& neighbors,
491  std::vector<std::vector<double>>& distances)
492 {
493  // If there are no points, there is no search to be done.
494  if (referenceSet->n_cols == 0)
495  return;
496 
497  Timer::Start("range_search/computing_neighbors");
498 
499  // Get a reference to the query set.
500  const MatType& querySet = queryTree->Dataset();
501 
502  // Make sure we are in dual-tree mode.
503  if (singleMode || naive)
504  throw std::invalid_argument("cannot call RangeSearch::Search() with a "
505  "query tree when naive or singleMode are set to true");
506 
507  // We won't need to map query indices, but will we need to map distances?
508  std::vector<std::vector<size_t>>* neighborPtr = &neighbors;
509 
511  neighborPtr = new std::vector<std::vector<size_t>>;
512 
513  // Resize each vector.
514  neighborPtr->clear(); // Just in case there was anything in it.
515  neighborPtr->resize(querySet.n_cols);
516  distances.clear();
517  distances.resize(querySet.n_cols);
518 
519  // Create the helper object for the traversal.
520  typedef RangeSearchRules<MetricType, Tree> RuleType;
521  RuleType rules(*referenceSet, queryTree->Dataset(), range, *neighborPtr,
522  distances, metric);
523 
524  // Create the traverser.
525  typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
526 
527  traverser.Traverse(*queryTree, *referenceTree);
528 
529  Timer::Stop("range_search/computing_neighbors");
530 
531  baseCases = rules.BaseCases();
532  scores = rules.Scores();
533 
534  // Do we need to map indices?
536  {
537  // We must map reference indices only.
538  neighbors.clear();
539  neighbors.resize(querySet.n_cols);
540 
541  for (size_t i = 0; i < neighbors.size(); ++i)
542  {
543  neighbors[i].resize((*neighborPtr)[i].size());
544  for (size_t j = 0; j < neighbors[i].size(); ++j)
545  neighbors[i][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
546  }
547 
548  // Finished with temporary object.
549  delete neighborPtr;
550  }
551 }
552 
553 template<typename MetricType,
554  typename MatType,
555  template<typename TreeMetricType,
556  typename TreeStatType,
557  typename TreeMatType> class TreeType>
558 void RangeSearch<MetricType, MatType, TreeType>::Search(
559  const math::Range& range,
560  std::vector<std::vector<size_t>>& neighbors,
561  std::vector<std::vector<double>>& distances)
562 {
563  // If there are no points, there is no search to be done.
564  if (referenceSet->n_cols == 0)
565  return;
566 
567  Timer::Start("range_search/computing_neighbors");
568 
569  // Here, we will use the query set as the reference set.
570  std::vector<std::vector<size_t>>* neighborPtr = &neighbors;
571  std::vector<std::vector<double>>* distancePtr = &distances;
572 
574  {
575  // We will always need to rearrange in this case.
576  distancePtr = new std::vector<std::vector<double>>;
577  neighborPtr = new std::vector<std::vector<size_t>>;
578  }
579 
580  // Resize each vector.
581  neighborPtr->clear(); // Just in case there was anything in it.
582  neighborPtr->resize(referenceSet->n_cols);
583  distancePtr->clear();
584  distancePtr->resize(referenceSet->n_cols);
585 
586  // Create the helper object for the traversal.
587  typedef RangeSearchRules<MetricType, Tree> RuleType;
588  RuleType rules(*referenceSet, *referenceSet, range, *neighborPtr,
589  *distancePtr, metric, true /* don't return the query in the results */);
590 
591  if (naive)
592  {
593  // The naive brute-force solution.
594  for (size_t i = 0; i < referenceSet->n_cols; ++i)
595  for (size_t j = 0; j < referenceSet->n_cols; ++j)
596  rules.BaseCase(i, j);
597 
598  baseCases = (referenceSet->n_cols * referenceSet->n_cols);
599  scores = 0;
600  }
601  else if (singleMode)
602  {
603  // Create the traverser.
604  typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
605 
606  // Now have it traverse for each point.
607  for (size_t i = 0; i < referenceSet->n_cols; ++i)
608  traverser.Traverse(i, *referenceTree);
609 
610  baseCases = rules.BaseCases();
611  scores = rules.Scores();
612  }
613  else // Dual-tree recursion.
614  {
615  // Create the traverser.
616  typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
617 
618  traverser.Traverse(*referenceTree, *referenceTree);
619 
620  baseCases = rules.BaseCases();
621  scores = rules.Scores();
622  }
623 
624  Timer::Stop("range_search/computing_neighbors");
625 
626  // Do we need to map the reference indices?
628  {
629  neighbors.clear();
630  neighbors.resize(referenceSet->n_cols);
631  distances.clear();
632  distances.resize(referenceSet->n_cols);
633 
634  for (size_t i = 0; i < distances.size(); ++i)
635  {
636  // Map distances (copy a column).
637  const size_t refMapping = oldFromNewReferences[i];
638  distances[refMapping] = (*distancePtr)[i];
639 
640  // Copy each neighbor individually, because we need to map it.
641  neighbors[refMapping].resize(distances[refMapping].size());
642  for (size_t j = 0; j < distances[refMapping].size(); ++j)
643  {
644  neighbors[refMapping][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
645  }
646  }
647 
648  // Finished with temporary objects.
649  delete neighborPtr;
650  delete distancePtr;
651  }
652 }
653 
654 template<typename MetricType,
655  typename MatType,
656  template<typename TreeMetricType,
657  typename TreeStatType,
658  typename TreeMatType> class TreeType>
659 template<typename Archive>
660 void RangeSearch<MetricType, MatType, TreeType>::serialize(
661  Archive& ar, const uint32_t /* version */)
662 {
663  // Serialize preferences for search.
664  ar(CEREAL_NVP(naive));
665  ar(CEREAL_NVP(singleMode));
666 
667  // Reset base cases and scores if we are loading.
668  if (cereal::is_loading<Archive>())
669  {
670  baseCases = 0;
671  scores = 0;
672  }
673 
674  // If we are doing naive search, we serialize the dataset. Otherwise we
675  // serialize the tree.
676  if (naive)
677  {
678  if (cereal::is_loading<Archive>())
679  {
680  if (referenceSet)
681  delete referenceSet;
682  }
683 
684  ar(CEREAL_POINTER(const_cast<MatType*&>(referenceSet)));
685  ar(CEREAL_NVP(metric));
686 
687  // If we are loading, set the tree to NULL and clean up memory if necessary.
688  if (cereal::is_loading<Archive>())
689  {
690  if (treeOwner && referenceTree)
691  delete referenceTree;
692 
693  referenceTree = NULL;
694  oldFromNewReferences.clear();
695  treeOwner = false;
696  }
697  }
698  else
699  {
700  // Delete the current reference tree, if necessary and if we are loading.
701  if (cereal::is_loading<Archive>())
702  {
703  if (treeOwner && referenceTree)
704  delete referenceTree;
705 
706  // After we load the tree, we will own it.
707  treeOwner = true;
708  }
709 
710  ar(CEREAL_POINTER(referenceTree));
711  ar(CEREAL_NVP(oldFromNewReferences));
712 
713  // If we are loading, set the dataset accordingly and clean up memory if
714  // necessary.
715  if (cereal::is_loading<Archive>())
716  {
717  referenceSet = &referenceTree->Dataset();
718  metric = referenceTree->Metric(); // Get the metric from the tree.
719  }
720  }
721 }
722 
723 } // namespace range
724 } // namespace mlpack
725 
726 #endif
The RangeSearch class is a template class for performing range searches.
Definition: range_search.hpp:45
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
RangeSearch & operator=(const RangeSearch &other)
Deep copy the given RangeSearch model.
Definition: range_search_impl.hpp:172
static const bool RearrangesDataset
This is true if the tree rearranges points in the dataset when it is built.
Definition: tree_traits.hpp:105
Definition: hmm_train_main.cpp:300
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
The TreeTraits class provides compile-time information on the characteristics of a given tree type...
Definition: tree_traits.hpp:77
TreeType * BuildTree(MatType &&dataset, std::vector< size_t > &oldFromNew, const typename std::enable_if< tree::TreeTraits< TreeType >::RearrangesDataset >::type *=0)
Call the tree constructor that does mapping.
Definition: dtb_impl.hpp:22
The RangeSearchRules class is a template helper class used by RangeSearch class when performing range...
Definition: range_search_rules.hpp:28
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_wrapper.hpp:96