mlpack
ra_search_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RANN_RA_SEARCH_IMPL_HPP
14 #define MLPACK_METHODS_RANN_RA_SEARCH_IMPL_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 #include "ra_search_rules.hpp"
19 
20 namespace mlpack {
21 namespace neighbor {
22 
23 namespace aux {
24 
26 template<typename TreeType, typename MatType>
27 TreeType* BuildTree(
28  MatType&& dataset,
29  std::vector<size_t>& oldFromNew,
30  typename std::enable_if<
32 {
33  return new TreeType(std::forward<MatType>(dataset), oldFromNew);
34 }
35 
37 template<typename TreeType, typename MatType>
38 TreeType* BuildTree(
39  MatType&& dataset,
40  const std::vector<size_t>& /* oldFromNew */,
41  const typename std::enable_if<
43 {
44  return new TreeType(std::forward<MatType>(dataset));
45 }
46 
47 } // namespace aux
48 
49 // Construct the object, taking ownership of the data matrix.
50 template<typename SortPolicy,
51  typename MetricType,
52  typename MatType,
53  template<typename TreeMetricType,
54  typename TreeStatType,
55  typename TreeMatType> class TreeType>
56 RASearch<SortPolicy, MetricType, MatType, TreeType>::
57 RASearch(MatType referenceSetIn,
58  const bool naive,
59  const bool singleMode,
60  const double tau,
61  const double alpha,
62  const bool sampleAtLeaves,
63  const bool firstLeafExact,
64  const size_t singleSampleLimit,
65  const MetricType metric) :
66  referenceTree(naive ? NULL : aux::BuildTree<Tree>(
67  std::move(referenceSetIn), oldFromNewReferences)),
68  referenceSet(naive ? new MatType(std::move(referenceSetIn)) :
69  &referenceTree->Dataset()),
70  treeOwner(!naive),
71  setOwner(naive),
72  naive(naive),
73  singleMode(!naive && singleMode), // No single mode if naive.
74  tau(tau),
75  alpha(alpha),
76  sampleAtLeaves(sampleAtLeaves),
77  firstLeafExact(firstLeafExact),
78  singleSampleLimit(singleSampleLimit),
79  metric(metric)
80 {
81  // Nothing to do.
82 }
83 
84 // Construct the object.
85 template<typename SortPolicy,
86  typename MetricType,
87  typename MatType,
88  template<typename TreeMetricType,
89  typename TreeStatType,
90  typename TreeMatType> class TreeType>
91 RASearch<SortPolicy, MetricType, MatType, TreeType>::
92 RASearch(Tree* referenceTree,
93  const bool singleMode,
94  const double tau,
95  const double alpha,
96  const bool sampleAtLeaves,
97  const bool firstLeafExact,
98  const size_t singleSampleLimit,
99  const MetricType metric) :
100  referenceTree(referenceTree),
101  referenceSet(&referenceTree->Dataset()),
102  treeOwner(false),
103  setOwner(false),
104  naive(false),
105  singleMode(singleMode),
106  tau(tau),
107  alpha(alpha),
108  sampleAtLeaves(sampleAtLeaves),
109  firstLeafExact(firstLeafExact),
110  singleSampleLimit(singleSampleLimit),
111  metric(metric)
112 // Nothing else to initialize.
113 { }
114 
115 // Empty constructor.
116 template<typename SortPolicy,
117  typename MetricType,
118  typename MatType,
119  template<typename TreeMetricType,
120  typename TreeStatType,
121  typename TreeMatType> class TreeType>
122 RASearch<SortPolicy, MetricType, MatType, TreeType>::
123 RASearch(const bool naive,
124  const bool singleMode,
125  const double tau,
126  const double alpha,
127  const bool sampleAtLeaves,
128  const bool firstLeafExact,
129  const size_t singleSampleLimit,
130  const MetricType metric) :
131  referenceTree(NULL),
132  referenceSet(new MatType()),
133  treeOwner(false),
134  setOwner(true),
135  naive(naive),
136  singleMode(singleMode),
137  tau(tau),
138  alpha(alpha),
139  sampleAtLeaves(sampleAtLeaves),
140  firstLeafExact(firstLeafExact),
141  singleSampleLimit(singleSampleLimit),
142  metric(metric)
143 {
144  // Build the tree on the empty dataset, if necessary.
145  if (!naive)
146  {
147  referenceTree = aux::BuildTree<Tree>(*referenceSet, oldFromNewReferences);
148  treeOwner = true;
149  }
150 }
151 
156 template<typename SortPolicy,
157  typename MetricType,
158  typename MatType,
159  template<typename TreeMetricType,
160  typename TreeStatType,
161  typename TreeMatType> class TreeType>
162 RASearch<SortPolicy, MetricType, MatType, TreeType>::
163 ~RASearch()
164 {
165  if (treeOwner && referenceTree)
166  delete referenceTree;
167  if (setOwner)
168  delete referenceSet;
169 }
170 
171 // Train on a new reference set.
172 template<typename SortPolicy,
173  typename MetricType,
174  typename MatType,
175  template<typename TreeMetricType,
176  typename TreeStatType,
177  typename TreeMatType> class TreeType>
178 void RASearch<SortPolicy, MetricType, MatType, TreeType>::Train(
179  MatType referenceSet)
180 {
181  // Clean up the old tree, if we built one.
182  if (treeOwner && referenceTree)
183  delete referenceTree;
184 
185  // We may need to rebuild the tree.
186  if (!naive)
187  {
188  referenceTree = aux::BuildTree<Tree>(std::move(referenceSet),
189  oldFromNewReferences);
190  treeOwner = true;
191  }
192  else
193  {
194  treeOwner = false;
195  }
196 
197  // Delete the old reference set, if we owned it.
198  if (setOwner && this->referenceSet)
199  delete this->referenceSet;
200 
201  if (!naive)
202  {
203  this->referenceSet = &referenceTree->Dataset();
204  setOwner = false;
205  }
206  else
207  {
208  this->referenceSet = new MatType(std::move(referenceSet));
209  setOwner = true;
210  }
211 }
212 
214 template<typename SortPolicy,
215  typename MetricType,
216  typename MatType,
217  template<typename TreeMetricType,
218  typename TreeStatType,
219  typename TreeMatType> class TreeType>
220 void RASearch<SortPolicy, MetricType, MatType, TreeType>::Train(
221  Tree* referenceTree)
222 {
223  if (naive)
224  throw std::invalid_argument("cannot train on given reference tree when "
225  "naive search (without trees) is desired");
226 
227  if (treeOwner && referenceTree)
228  delete this->referenceTree;
229  if (setOwner && referenceSet)
230  delete this->referenceSet;
231 
232  this->referenceTree = referenceTree;
233  this->referenceSet = &referenceTree->Dataset();
234  treeOwner = false;
235  setOwner = false;
236 }
237 
242 template<typename SortPolicy,
243  typename MetricType,
244  typename MatType,
245  template<typename TreeMetricType,
246  typename TreeStatType,
247  typename TreeMatType> class TreeType>
248 void RASearch<SortPolicy, MetricType, MatType, TreeType>::
249 Search(const MatType& querySet,
250  const size_t k,
251  arma::Mat<size_t>& neighbors,
252  arma::mat& distances)
253 {
254  if (k > referenceSet->n_cols)
255  {
256  std::stringstream ss;
257  ss << "requested value of k (" << k << ") is greater than the number of "
258  << "points in the reference set (" << referenceSet->n_cols << ")";
259  throw std::invalid_argument(ss.str());
260  }
261 
262  Timer::Start("computing_neighbors");
263 
264  // This will hold mappings for query points, if necessary.
265  std::vector<size_t> oldFromNewQueries;
266 
267  // If we have built the trees ourselves, then we will have to map all the
268  // indices back to their original indices when this computation is finished.
269  // To avoid an extra copy, we will store the neighbors and distances in a
270  // separate matrix.
271  arma::Mat<size_t>* neighborPtr = &neighbors;
272  arma::mat* distancePtr = &distances;
273 
274  // Mapping is only required if this tree type rearranges points and we are not
275  // in naive mode.
277  {
278  if (!singleMode && !naive)
279  {
280  distancePtr = new arma::mat; // Query indices need to be mapped.
281  neighborPtr = new arma::Mat<size_t>;
282  }
283 
284  else if (treeOwner)
285  neighborPtr = new arma::Mat<size_t>; // All indices need mapping.
286  }
287 
288  // Set the size of the neighbor and distance matrices.
289  neighborPtr->set_size(k, querySet.n_cols);
290  distancePtr->set_size(k, querySet.n_cols);
291 
293 
294  if (naive)
295  {
296  RuleType rules(*referenceSet, querySet, k, metric, tau, alpha, naive,
297  sampleAtLeaves, firstLeafExact, singleSampleLimit, false);
298 
299  // Find how many samples from the reference set we need and sample uniformly
300  // from the reference set without replacement.
301  const size_t numSamples = RAUtil::MinimumSamplesReqd(referenceSet->n_cols,
302  k, tau, alpha);
303  arma::uvec distinctSamples;
304  math::ObtainDistinctSamples(0, referenceSet->n_cols, numSamples,
305  distinctSamples);
306 
307  // Run the base case on each combination of query point and sampled
308  // reference point.
309  for (size_t i = 0; i < querySet.n_cols; ++i)
310  for (size_t j = 0; j < distinctSamples.n_elem; ++j)
311  rules.BaseCase(i, (size_t) distinctSamples[j]);
312 
313  rules.GetResults(*neighborPtr, *distancePtr);
314  }
315  else if (singleMode)
316  {
317  RuleType rules(*referenceSet, querySet, k, metric, tau, alpha, naive,
318  sampleAtLeaves, firstLeafExact, singleSampleLimit, false);
319 
320  // If the reference root node is a leaf, then the sampling has already been
321  // done in the RASearchRules constructor. This happens when naive = true.
322  if (!referenceTree->IsLeaf())
323  {
324  Log::Info << "Performing single-tree traversal..." << std::endl;
325 
326  // Create the traverser.
327  typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
328 
329  // Now have it traverse for each point.
330  for (size_t i = 0; i < querySet.n_cols; ++i)
331  traverser.Traverse(i, *referenceTree);
332 
333  Log::Info << "Single-tree traversal complete." << std::endl;
334  Log::Info << "Average number of distance calculations per query point: "
335  << (rules.NumDistComputations() / querySet.n_cols) << "."
336  << std::endl;
337  }
338 
339  rules.GetResults(*neighborPtr, *distancePtr);
340  }
341  else // Dual-tree recursion.
342  {
343  Log::Info << "Performing dual-tree traversal..." << std::endl;
344 
345  // Build the query tree.
346  Timer::Stop("computing_neighbors");
347  Timer::Start("tree_building");
348  Tree* queryTree = aux::BuildTree<Tree>(const_cast<MatType&>(querySet),
349  oldFromNewQueries);
350  Timer::Stop("tree_building");
351  Timer::Start("computing_neighbors");
352 
353  RuleType rules(*referenceSet, queryTree->Dataset(), k, metric, tau, alpha,
354  naive, sampleAtLeaves, firstLeafExact, singleSampleLimit, false);
355  typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
356 
357  Log::Info << "Query statistic pre-search: "
358  << queryTree->Stat().NumSamplesMade() << std::endl;
359 
360  traverser.Traverse(*queryTree, *referenceTree);
361 
362  Log::Info << "Dual-tree traversal complete." << std::endl;
363  Log::Info << "Average number of distance calculations per query point: "
364  << (rules.NumDistComputations() / querySet.n_cols) << "." << std::endl;
365 
366  rules.GetResults(*neighborPtr, *distancePtr);
367 
368  delete queryTree;
369  }
370 
371  Timer::Stop("computing_neighbors");
372 
373  // Map points back to original indices, if necessary.
375  {
376  if (!singleMode && !naive && treeOwner)
377  {
378  // We must map both query and reference indices.
379  neighbors.set_size(k, querySet.n_cols);
380  distances.set_size(k, querySet.n_cols);
381 
382  for (size_t i = 0; i < distances.n_cols; ++i)
383  {
384  // Map distances (copy a column).
385  distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
386 
387  // Map indices of neighbors.
388  for (size_t j = 0; j < distances.n_rows; ++j)
389  {
390  neighbors(j, oldFromNewQueries[i]) =
391  oldFromNewReferences[(*neighborPtr)(j, i)];
392  }
393  }
394 
395  // Finished with temporary matrices.
396  delete neighborPtr;
397  delete distancePtr;
398  }
399  else if (!singleMode && !naive)
400  {
401  // We must map query indices only.
402  neighbors.set_size(k, querySet.n_cols);
403  distances.set_size(k, querySet.n_cols);
404 
405  for (size_t i = 0; i < distances.n_cols; ++i)
406  {
407  // Map distances (copy a column).
408  const size_t queryMapping = oldFromNewQueries[i];
409  distances.col(queryMapping) = distancePtr->col(i);
410  neighbors.col(queryMapping) = neighborPtr->col(i);
411  }
412 
413  // Finished with temporary matrices.
414  delete neighborPtr;
415  delete distancePtr;
416  }
417  else if (treeOwner)
418  {
419  // We must map reference indices only.
420  neighbors.set_size(k, querySet.n_cols);
421 
422  // Map indices of neighbors.
423  for (size_t i = 0; i < neighbors.n_cols; ++i)
424  for (size_t j = 0; j < neighbors.n_rows; ++j)
425  neighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
426 
427  // Finished with temporary matrix.
428  delete neighborPtr;
429  }
430  }
431 }
432 
433 template<typename SortPolicy,
434  typename MetricType,
435  typename MatType,
436  template<typename TreeMetricType,
437  typename TreeStatType,
438  typename TreeMatType> class TreeType>
439 void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
440  Tree* queryTree,
441  const size_t k,
442  arma::Mat<size_t>& neighbors,
443  arma::mat& distances)
444 {
445  Timer::Start("computing_neighbors");
446 
447  // Get a reference to the query set.
448  const MatType& querySet = queryTree->Dataset();
449 
450  // Make sure we are in dual-tree mode.
451  if (singleMode || naive)
452  throw std::invalid_argument("cannot call NeighborSearch::Search() with a "
453  "query tree when naive or singleMode are set to true");
454 
455  // We won't need to map query indices, but will we need to map distances?
456  arma::Mat<size_t>* neighborPtr = &neighbors;
457 
459  neighborPtr = new arma::Mat<size_t>;
460 
461  neighborPtr->set_size(k, querySet.n_cols);
462  distances.set_size(k, querySet.n_cols);
463 
464  // Create the helper object for the tree traversal.
466  RuleType rules(*referenceSet, queryTree->Dataset(), k, metric, tau, alpha,
467  naive, sampleAtLeaves, firstLeafExact, singleSampleLimit, false);
468 
469  // Create the traverser.
470  typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
471  traverser.Traverse(*queryTree, *referenceTree);
472 
473  rules.GetResults(*neighborPtr, distances);
474 
475  Timer::Stop("computing_neighbors");
476 
477  // Do we need to map indices?
479  {
480  // We must map reference indices only.
481  neighbors.set_size(k, querySet.n_cols);
482 
483  // Map indices of neighbors.
484  for (size_t i = 0; i < neighbors.n_cols; ++i)
485  for (size_t j = 0; j < neighbors.n_rows; ++j)
486  neighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
487 
488  // Finished with temporary matrix.
489  delete neighborPtr;
490  }
491 }
492 
493 template<typename SortPolicy,
494  typename MetricType,
495  typename MatType,
496  template<typename TreeMetricType,
497  typename TreeStatType,
498  typename TreeMatType> class TreeType>
499 void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
500  const size_t k,
501  arma::Mat<size_t>& neighbors,
502  arma::mat& distances)
503 {
504  Timer::Start("computing_neighbors");
505 
506  arma::Mat<size_t>* neighborPtr = &neighbors;
507  arma::mat* distancePtr = &distances;
508 
510  {
511  // We will always need to rearrange in this case.
512  distancePtr = new arma::mat;
513  neighborPtr = new arma::Mat<size_t>;
514  }
515 
516  // Initialize results.
517  neighborPtr->set_size(k, referenceSet->n_cols);
518  distancePtr->set_size(k, referenceSet->n_cols);
519 
520  // Create the helper object for the tree traversal.
522  RuleType rules(*referenceSet, *referenceSet, k, metric, tau, alpha, naive,
523  sampleAtLeaves, firstLeafExact, singleSampleLimit, true /* same sets */);
524 
525  if (naive)
526  {
527  // Find how many samples from the reference set we need and sample uniformly
528  // from the reference set without replacement.
529  const size_t numSamples = RAUtil::MinimumSamplesReqd(referenceSet->n_cols,
530  k, tau, alpha);
531  arma::uvec distinctSamples;
532  math::ObtainDistinctSamples(0, referenceSet->n_cols, numSamples,
533  distinctSamples);
534 
535  // The naive brute-force solution.
536  for (size_t i = 0; i < referenceSet->n_cols; ++i)
537  for (size_t j = 0; j < referenceSet->n_cols; ++j)
538  rules.BaseCase(i, j);
539  }
540  else if (singleMode)
541  {
542  // Create the traverser.
543  typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
544 
545  // Now have it traverse for each point.
546  for (size_t i = 0; i < referenceSet->n_cols; ++i)
547  traverser.Traverse(i, *referenceTree);
548  }
549  else
550  {
551  // Create the traverser.
552  typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
553 
554  traverser.Traverse(*referenceTree, *referenceTree);
555  }
556 
557  rules.GetResults(*neighborPtr, *distancePtr);
558 
559  Timer::Stop("computing_neighbors");
560 
561  // Do we need to map the reference indices?
563  {
564  neighbors.set_size(k, referenceSet->n_cols);
565  distances.set_size(k, referenceSet->n_cols);
566 
567  for (size_t i = 0; i < distances.n_cols; ++i)
568  {
569  // Map distances (copy a column).
570  const size_t refMapping = oldFromNewReferences[i];
571  distances.col(refMapping) = distancePtr->col(i);
572 
573  // Map each neighbor's index.
574  for (size_t j = 0; j < distances.n_rows; ++j)
575  neighbors(j, refMapping) = oldFromNewReferences[(*neighborPtr)(j, i)];
576  }
577 
578  // Finished with temporary matrices.
579  delete neighborPtr;
580  delete distancePtr;
581  }
582 }
583 
584 template<typename SortPolicy,
585  typename MetricType,
586  typename MatType,
587  template<typename TreeMetricType,
588  typename TreeStatType,
589  typename TreeMatType> class TreeType>
590 void RASearch<SortPolicy, MetricType, MatType, TreeType>::ResetQueryTree(
591  Tree* queryNode) const
592 {
593  queryNode->Stat().Bound() = SortPolicy::WorstDistance();
594  queryNode->Stat().NumSamplesMade() = 0;
595 
596  for (size_t i = 0; i < queryNode->NumChildren(); ++i)
597  ResetQueryTree(&queryNode->Child(i));
598 }
599 
600 template<typename SortPolicy,
601  typename MetricType,
602  typename MatType,
603  template<typename TreeMetricType,
604  typename TreeStatType,
605  typename TreeMatType> class TreeType>
606 template<typename Archive>
607 void RASearch<SortPolicy, MetricType, MatType, TreeType>::serialize(
608  Archive& ar, const uint32_t /* version */)
609 {
610  // Serialize preferences for search.
611  ar(CEREAL_NVP(naive));
612  ar(CEREAL_NVP(singleMode));
613 
614  ar(CEREAL_NVP(tau));
615  ar(CEREAL_NVP(alpha));
616  ar(CEREAL_NVP(sampleAtLeaves));
617  ar(CEREAL_NVP(firstLeafExact));
618  ar(CEREAL_NVP(singleSampleLimit));
619 
620  // If we are doing naive search, we serialize the dataset. Otherwise we
621  // serialize the tree.
622  if (naive)
623  {
624  if (cereal::is_loading<Archive>())
625  {
626  if (setOwner && referenceSet)
627  delete referenceSet;
628 
629  setOwner = true;
630  }
631  ar(CEREAL_POINTER(const_cast<MatType*&>(referenceSet)));
632  ar(CEREAL_NVP(metric));
633 
634  // If we are loading, set the tree to NULL and clean up memory if necessary.
635  if (cereal::is_loading<Archive>())
636  {
637  if (treeOwner && referenceTree)
638  delete referenceTree;
639 
640  referenceTree = NULL;
641  oldFromNewReferences.clear();
642  treeOwner = false;
643  }
644  }
645  else
646  {
647  // Delete the current reference tree, if necessary and if we are loading.
648  if (cereal::is_loading<Archive>())
649  {
650  if (treeOwner && referenceTree)
651  delete referenceTree;
652 
653  // After we load the tree, we will own it.
654  treeOwner = true;
655  }
656 
657  ar(CEREAL_POINTER(referenceTree));
658  ar(CEREAL_NVP(oldFromNewReferences));
659 
660  // If we are loading, set the dataset accordingly and clean up memory if
661  // necessary.
662  if (cereal::is_loading<Archive>())
663  {
664  if (setOwner && referenceSet)
665  delete referenceSet;
666 
667  referenceSet = &referenceTree->Dataset();
668  metric = referenceTree->Metric();
669  setOwner = false;
670  }
671  }
672 }
673 
674 } // namespace neighbor
675 } // namespace mlpack
676 
677 #endif
void ObtainDistinctSamples(const size_t loInclusive, const size_t hiExclusive, const size_t maxNumSamples, arma::uvec &distinctSamples)
Obtains no more than maxNumSamples distinct samples.
Definition: random.hpp:153
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.
static size_t MinimumSamplesReqd(const size_t n, const size_t k, const double tau, const double alpha)
Compute the minimum number of samples required to guarantee the given rank-approximation and success probability.
Definition: ra_util.cpp:18
Definition: hmm_train_main.cpp:300
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
The TreeTraits class provides compile-time information on the characteristics of a given tree type...
Definition: tree_traits.hpp:77
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_wrapper.hpp:96
The RASearchRules class is a template helper class used by RASearch class when performing rank-approx...
Definition: ra_search_rules.hpp:33