mlpack
neighbor_search_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_IMPL_HPP
14 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_IMPL_HPP
15 
16 #include <mlpack/prereqs.hpp>
20 
21 namespace mlpack {
22 namespace neighbor {
23 
// Build a reference tree on `dataset` using the tree constructor that also
// takes `oldFromNew` (non-const), which the tree constructor presumably fills
// with the mapping from the tree's rearranged point order back to the
// original order. The caller owns the returned heap-allocated tree.
// NOTE(review): source line 30 (the enable_if condition) is missing from this
// rendering; presumably it enables this overload only for tree types that
// rearrange their dataset. Confirm against the original header.
25 template<typename TreeType, typename MatType>
26 TreeType* BuildTree(
27  MatType&& dataset,
28  std::vector<size_t>& oldFromNew,
29  typename std::enable_if_t<
31  >* = 0)
32 {
33  return new TreeType(std::forward<MatType>(dataset), oldFromNew);
34 }
35 
// Build a reference tree on `dataset` without producing an index mapping:
// `oldFromNew` is accepted only for signature compatibility and is ignored.
// The caller owns the returned heap-allocated tree.
// NOTE(review): source line 42 (the enable_if condition) is missing from this
// rendering; presumably it enables this overload only for tree types that do
// NOT rearrange their dataset. Confirm against the original header.
37 template<typename TreeType, typename MatType>
38 TreeType* BuildTree(
39  MatType&& dataset,
40  const std::vector<size_t>& /* oldFromNew */,
41  const typename std::enable_if_t<
43  >* = 0)
44 {
45  return new TreeType(std::forward<MatType>(dataset));
46 }
47 
48 // Construct the object.
49 template<typename SortPolicy,
50  typename MetricType,
51  typename MatType,
52  template<typename TreeMetricType,
53  typename TreeStatType,
54  typename TreeMatType> class TreeType,
55  template<typename> class DualTreeTraversalType,
56  template<typename> class SingleTreeTraversalType>
57 NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
58 SingleTreeTraversalType>::NeighborSearch(MatType referenceSetIn,
59  const NeighborSearchMode mode,
60  const double epsilon,
61  const MetricType metric) :
62  referenceTree(mode == NAIVE_MODE ? NULL :
63  BuildTree<Tree>(std::move(referenceSetIn), oldFromNewReferences)),
64  referenceSet(mode == NAIVE_MODE ? new MatType(std::move(referenceSetIn)) :
65  &referenceTree->Dataset()),
66  searchMode(mode),
67  epsilon(epsilon),
68  metric(metric),
69  baseCases(0),
70  scores(0),
71  treeNeedsReset(false)
72 {
73  if (epsilon < 0)
74  throw std::invalid_argument("epsilon must be non-negative");
75 }
76 
77 // Construct the object.
78 template<typename SortPolicy,
79  typename MetricType,
80  typename MatType,
81  template<typename TreeMetricType,
82  typename TreeStatType,
83  typename TreeMatType> class TreeType,
84  template<typename> class DualTreeTraversalType,
85  template<typename> class SingleTreeTraversalType>
86 NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
87 SingleTreeTraversalType>::NeighborSearch(Tree referenceTree,
88  const NeighborSearchMode mode,
89  const double epsilon,
90  const MetricType metric) :
91  referenceTree(new Tree(std::move(referenceTree))),
92  referenceSet(&this->referenceTree->Dataset()),
93  searchMode(mode),
94  epsilon(epsilon),
95  metric(metric),
96  baseCases(0),
97  scores(0),
98  treeNeedsReset(false)
99 {
100  if (epsilon < 0)
101  throw std::invalid_argument("epsilon must be non-negative");
102 }
103 
104 // Construct the object without a reference dataset.
105 template<typename SortPolicy,
106  typename MetricType,
107  typename MatType,
108  template<typename TreeMetricType,
109  typename TreeStatType,
110  typename TreeMatType> class TreeType,
111  template<typename> class DualTreeTraversalType,
112  template<typename> class SingleTreeTraversalType>
113 NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
114 SingleTreeTraversalType>::NeighborSearch(const NeighborSearchMode mode,
115  const double epsilon,
116  const MetricType metric) :
117  referenceTree(NULL),
118  referenceSet(mode == NAIVE_MODE ? new MatType() : NULL), // Empty matrix.
119  searchMode(mode),
120  epsilon(epsilon),
121  metric(metric),
122  baseCases(0),
123  scores(0),
124  treeNeedsReset(false)
125 {
126  if (epsilon < 0)
127  throw std::invalid_argument("epsilon must be non-negative");
128 
129  // Build the tree on the empty dataset, if necessary.
130  if (mode != NAIVE_MODE)
131  {
132  referenceTree = BuildTree<Tree>(std::move(arma::mat()),
133  oldFromNewReferences);
134  referenceSet = &referenceTree->Dataset();
135  }
136 }
137 
138 // Copy constructor.
139 template<typename SortPolicy,
140  typename MetricType,
141  typename MatType,
142  template<typename TreeMetricType,
143  typename TreeStatType,
144  typename TreeMatType> class TreeType,
145  template<typename> class DualTreeTraversalType,
146  template<typename> class SingleTreeTraversalType>
147 NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
148 SingleTreeTraversalType>::NeighborSearch(const NeighborSearch& other) :
149  oldFromNewReferences(other.oldFromNewReferences),
150  referenceTree(other.referenceTree ? new Tree(*other.referenceTree) : NULL),
151  referenceSet(other.referenceTree ? &referenceTree->Dataset() :
152  new MatType(*other.referenceSet)),
153  searchMode(other.searchMode),
154  epsilon(other.epsilon),
155  metric(other.metric),
156  baseCases(other.baseCases),
157  scores(other.scores),
158  treeNeedsReset(false)
159 {
160  // Nothing else to do.
161 }
162 
163 // Move constructor.
164 template<typename SortPolicy,
165  typename MetricType,
166  typename MatType,
167  template<typename TreeMetricType,
168  typename TreeStatType,
169  typename TreeMatType> class TreeType,
170  template<typename> class DualTreeTraversalType,
171  template<typename> class SingleTreeTraversalType>
172 NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
173 SingleTreeTraversalType>::NeighborSearch(NeighborSearch&& other) :
174  oldFromNewReferences(std::move(other.oldFromNewReferences)),
175  referenceTree(other.referenceTree),
176  referenceSet(other.referenceSet),
177  searchMode(other.searchMode),
178  epsilon(other.epsilon),
179  metric(std::move(other.metric)),
180  baseCases(other.baseCases),
181  scores(other.scores),
182  treeNeedsReset(other.treeNeedsReset)
183 {
184  // Clear the other model.
185  other.referenceTree = BuildTree<Tree>(std::move(MatType()),
186  other.oldFromNewReferences);
187  other.referenceSet = &other.referenceTree->Dataset();
188  other.searchMode = DUAL_TREE_MODE,
189  other.epsilon = 0.0;
190  other.baseCases = 0;
191  other.scores = 0;
192  other.treeNeedsReset = false;
193 }
194 
195 // Copy operator.
196 template<typename SortPolicy,
197  typename MetricType,
198  typename MatType,
199  template<typename TreeMetricType,
200  typename TreeStatType,
201  typename TreeMatType> class TreeType,
202  template<typename> class DualTreeTraversalType,
203  template<typename> class SingleTreeTraversalType>
204 NeighborSearch<SortPolicy,
205  MetricType,
206  MatType,
207  TreeType,
208  DualTreeTraversalType,
209  SingleTreeTraversalType>&
210 NeighborSearch<SortPolicy,
211  MetricType,
212  MatType,
213  TreeType,
214  DualTreeTraversalType,
215  SingleTreeTraversalType>::operator=(const NeighborSearch& other)
216 {
217  if (&other == this)
218  return *this; // Nothing to do.
219 
220  // Clean memory first.
221  if (referenceTree)
222  delete referenceTree;
223  else
224  delete referenceSet;
225 
226  oldFromNewReferences = other.oldFromNewReferences;
227  referenceTree = other.referenceTree ? new Tree(*other.referenceTree) : NULL;
228  referenceSet = other.referenceTree ? &referenceTree->Dataset() :
229  new MatType(*other.referenceSet);
230  searchMode = other.searchMode;
231  epsilon = other.epsilon;
232  metric = other.metric;
233  baseCases = other.baseCases;
234  scores = other.scores;
235  treeNeedsReset = false;
236 }
237 
238 // Move operator.
239 template<typename SortPolicy,
240  typename MetricType,
241  typename MatType,
242  template<typename TreeMetricType,
243  typename TreeStatType,
244  typename TreeMatType> class TreeType,
245  template<typename> class DualTreeTraversalType,
246  template<typename> class SingleTreeTraversalType>
247 NeighborSearch<SortPolicy,
248  MetricType,
249  MatType,
250  TreeType,
251  DualTreeTraversalType,
252  SingleTreeTraversalType>&
253 NeighborSearch<SortPolicy,
254  MetricType,
255  MatType,
256  TreeType,
257  DualTreeTraversalType,
258  SingleTreeTraversalType>::operator=(NeighborSearch&& other)
259 {
260  if (&other == this)
261  return *this; // Nothing to do.
262 
263  // Clean memory first.
264  if (referenceTree)
265  delete referenceTree;
266  else
267  delete referenceSet;
268 
269  oldFromNewReferences = std::move(other.oldFromNewReferences);
270  referenceTree = other.referenceTree;
271  referenceSet = other.referenceSet;
272  searchMode = other.searchMode;
273  epsilon = other.epsilon;
274  metric = other.metric;
275  baseCases = other.baseCases;
276  scores = other.scores;
277  treeNeedsReset = other.treeNeedsReset;
278 
279  // Reset the other object. Clean memory if needed.
280  if (!other.referenceTree)
281  delete other.referenceSet;
282 
283  other.referenceTree = BuildTree<Tree>(std::move(arma::mat()),
284  other.oldFromNewReferences);
285  other.referenceSet = &other.referenceTree->Dataset();
286  other.searchMode = DUAL_TREE_MODE,
287  other.epsilon = 0.0;
288  other.baseCases = 0;
289  other.scores = 0;
290  other.treeNeedsReset = false;
291 }
292 
293 // Clean memory.
294 template<typename SortPolicy,
295  typename MetricType,
296  typename MatType,
297  template<typename TreeMetricType,
298  typename TreeStatType,
299  typename TreeMatType> class TreeType,
300  template<typename> class DualTreeTraversalType,
301  template<typename> class SingleTreeTraversalType>
302 NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
303 SingleTreeTraversalType>::~NeighborSearch()
304 {
305  if (referenceTree)
306  delete referenceTree;
307  else
308  delete referenceSet;
309 }
310 
311 template<typename SortPolicy,
312  typename MetricType,
313  typename MatType,
314  template<typename TreeMetricType,
315  typename TreeStatType,
316  typename TreeMatType> class TreeType,
317  template<typename> class DualTreeTraversalType,
318  template<typename> class SingleTreeTraversalType>
319 void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
320 DualTreeTraversalType, SingleTreeTraversalType>::Train(MatType referenceSetIn)
321 {
322  // Clean up the old tree, if we built one.
323  if (referenceTree)
324  {
325  oldFromNewReferences.clear();
326  delete referenceTree;
327  referenceTree = NULL;
328  }
329  else
330  {
331  delete referenceSet;
332  }
333 
334  // We may need to rebuild the tree.
335  if (searchMode != NAIVE_MODE)
336  {
337  referenceTree = BuildTree<Tree>(std::move(referenceSetIn),
338  oldFromNewReferences);
339  referenceSet = &referenceTree->Dataset();
340  }
341  else
342  {
343  referenceSet = new MatType(std::move(referenceSetIn));
344  }
345 }
346 
347 template<typename SortPolicy,
348  typename MetricType,
349  typename MatType,
350  template<typename TreeMetricType,
351  typename TreeStatType,
352  typename TreeMatType> class TreeType,
353  template<typename> class DualTreeTraversalType,
354  template<typename> class SingleTreeTraversalType>
355 void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
356 DualTreeTraversalType, SingleTreeTraversalType>::Train(Tree referenceTree)
357 {
358  if (searchMode == NAIVE_MODE)
359  throw std::invalid_argument("cannot train on given reference tree when "
360  "naive search (without trees) is desired");
361 
362  if (this->referenceTree)
363  {
364  oldFromNewReferences.clear();
365  delete this->referenceTree;
366  }
367  else
368  {
369  delete this->referenceSet;
370  }
371 
372  this->referenceTree = new Tree(std::move(referenceTree));
373  this->referenceSet = &this->referenceTree->Dataset();
374 }
375 
// Bichromatic search: for every column of querySet, find k neighbors in the
// reference set, writing neighbor indices into `neighbors` and distances into
// `distances` (both sized k x querySet.n_cols). The strategy (naive,
// single-tree, dual-tree, greedy single-tree) is selected by searchMode. When
// indices must be remapped (query tree built here, or a rearranged reference
// tree), results are first written into heap-allocated temporaries and then
// mapped back to the original point order.
// Throws std::invalid_argument if k exceeds the number of reference points.
// NOTE(review): this rendering is missing source lines 419, 434, 512 and 534.
// Presumably they are the "tree rearranges points" checks, the RuleType
// typedef and the greedy traverser declaration. Restore them from the original.
// NOTE(review): the temporaries allocated with `new` below leak (and the timer
// is left running) if a traversal or tree build throws; consider RAII.
380 template<typename SortPolicy,
381  typename MetricType,
382  typename MatType,
383  template<typename TreeMetricType,
384  typename TreeStatType,
385  typename TreeMatType> class TreeType,
386  template<typename> class DualTreeTraversalType,
387  template<typename> class SingleTreeTraversalType>
388 void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
389 DualTreeTraversalType, SingleTreeTraversalType>::Search(
390  const MatType& querySet,
391  const size_t k,
392  arma::Mat<size_t>& neighbors,
393  arma::mat& distances)
394 {
395  if (k > referenceSet->n_cols)
396  {
397  std::stringstream ss;
398  ss << "Requested value of k (" << k << ") is greater than the number of "
399  << "points in the reference set (" << referenceSet->n_cols << ")";
400  throw std::invalid_argument(ss.str());
401  }
402 
403  Timer::Start("computing_neighbors");
404 
405  baseCases = 0;
406  scores = 0;
407 
408  // This will hold mappings for query points, if necessary.
409  std::vector<size_t> oldFromNewQueries;
410 
411  // If we have built the trees ourselves, then we will have to map all the
412  // indices back to their original indices when this computation is finished.
413  // To avoid an extra copy, we will store the neighbors and distances in a
414  // separate matrix.
415  arma::Mat<size_t>* neighborPtr = &neighbors;
416  arma::mat* distancePtr = &distances;
417 
418  // Mapping is only necessary if the tree rearranges points.
420  {
421  if (searchMode == DUAL_TREE_MODE)
422  {
423  distancePtr = new arma::mat; // Query indices need to be mapped.
424  neighborPtr = new arma::Mat<size_t>;
425  }
426  else if (!oldFromNewReferences.empty())
427  neighborPtr = new arma::Mat<size_t>; // Reference indices need mapping.
428  }
429 
430  // Set the size of the neighbor and distance matrices.
431  neighborPtr->set_size(k, querySet.n_cols);
432  distancePtr->set_size(k, querySet.n_cols);
433 
435 
436  switch (searchMode)
437  {
438  case NAIVE_MODE:
439  {
440  // Create the helper object for the tree traversal.
441  RuleType rules(*referenceSet, querySet, k, metric, epsilon);
442 
443  // The naive brute-force traversal.
444  for (size_t i = 0; i < querySet.n_cols; ++i)
445  for (size_t j = 0; j < referenceSet->n_cols; ++j)
446  rules.BaseCase(i, j);
447 
448  baseCases += querySet.n_cols * referenceSet->n_cols;
449 
450  rules.GetResults(*neighborPtr, *distancePtr);
451  break;
452  }
453  case SINGLE_TREE_MODE:
454  {
455  // Create the helper object for the tree traversal.
456  RuleType rules(*referenceSet, querySet, k, metric, epsilon);
457 
458  // Create the traverser.
459  SingleTreeTraversalType<RuleType> traverser(rules);
460 
461  // Now have it traverse for each point.
462  for (size_t i = 0; i < querySet.n_cols; ++i)
463  traverser.Traverse(i, *referenceTree);
464 
465  scores += rules.Scores();
466  baseCases += rules.BaseCases();
467 
468  Log::Info << rules.Scores() << " node combinations were scored."
469  << std::endl;
470  Log::Info << rules.BaseCases() << " base cases were calculated."
471  << std::endl;
472 
473  rules.GetResults(*neighborPtr, *distancePtr);
474  break;
475  }
476  case DUAL_TREE_MODE:
477  {
478  // Build the query tree.
479  Timer::Stop("computing_neighbors");
480  Timer::Start("tree_building");
481  Tree* queryTree = BuildTree<Tree>(querySet, oldFromNewQueries);
482  Timer::Stop("tree_building");
483  Timer::Start("computing_neighbors");
484 
485  // Create the helper object for the tree traversal.
486  RuleType rules(*referenceSet, queryTree->Dataset(), k, metric, epsilon);
487 
488  // Create the traverser.
489  DualTreeTraversalType<RuleType> traverser(rules);
490 
491  traverser.Traverse(*queryTree, *referenceTree);
492 
493  scores += rules.Scores();
494  baseCases += rules.BaseCases();
495 
496  Log::Info << rules.Scores() << " node combinations were scored."
497  << std::endl;
498  Log::Info << rules.BaseCases() << " base cases were calculated."
499  << std::endl;
500 
501  rules.GetResults(*neighborPtr, *distancePtr);
502 
503  delete queryTree;
504  break;
505  }
506  case GREEDY_SINGLE_TREE_MODE:
507  {
508  // Create the helper object for the tree traversal.
// NOTE(review): unlike the other modes, epsilon is not passed to the rules
// here; confirm this is intended.
509  RuleType rules(*referenceSet, querySet, k, metric);
510 
511  // Create the traverser.
513 
514  // Now have it traverse for each point.
515  for (size_t i = 0; i < querySet.n_cols; ++i)
516  traverser.Traverse(i, *referenceTree);
517 
518  scores += rules.Scores();
519  baseCases += rules.BaseCases();
520 
521  Log::Info << rules.Scores() << " node combinations were scored."
522  << std::endl;
523  Log::Info << rules.BaseCases() << " base cases were calculated."
524  << std::endl;
525 
526  rules.GetResults(*neighborPtr, *distancePtr);
527  break;
528  }
529  }
530 
531  Timer::Stop("computing_neighbors");
532 
533  // Map points back to original indices, if necessary.
535  {
536  if (searchMode == DUAL_TREE_MODE && !oldFromNewReferences.empty())
537  {
538  // We must map both query and reference indices.
539  neighbors.set_size(k, querySet.n_cols);
540  distances.set_size(k, querySet.n_cols);
541 
542  for (size_t i = 0; i < distances.n_cols; ++i)
543  {
544  // Map distances (copy a column).
545  distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
546 
547  // Map indices of neighbors.
548  for (size_t j = 0; j < distances.n_rows; ++j)
549  {
550  neighbors(j, oldFromNewQueries[i]) =
551  oldFromNewReferences[(*neighborPtr)(j, i)];
552  }
553  }
554 
555  // Finished with temporary matrices.
556  delete neighborPtr;
557  delete distancePtr;
558  }
559  else if (searchMode == DUAL_TREE_MODE)
560  {
561  // We must map query indices only.
562  neighbors.set_size(k, querySet.n_cols);
563  distances.set_size(k, querySet.n_cols);
564 
565  for (size_t i = 0; i < distances.n_cols; ++i)
566  {
567  // Map distances (copy a column).
568  const size_t queryMapping = oldFromNewQueries[i];
569  distances.col(queryMapping) = distancePtr->col(i);
570  neighbors.col(queryMapping) = neighborPtr->col(i);
571  }
572 
573  // Finished with temporary matrices.
574  delete neighborPtr;
575  delete distancePtr;
576  }
577  else if (!oldFromNewReferences.empty())
578  {
579  // We must map reference indices only.
580  neighbors.set_size(k, querySet.n_cols);
581 
582  // Map indices of neighbors.
583  for (size_t i = 0; i < neighbors.n_cols; ++i)
584  for (size_t j = 0; j < neighbors.n_rows; ++j)
585  neighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
586 
587  // Finished with temporary matrix.
588  delete neighborPtr;
589  }
590  }
591 } // Search()
592 
// Dual-tree search with a caller-built query tree: for each point of
// queryTree.Dataset(), find k neighbors in the reference set. Only dual-tree
// mode is supported. Query indices are reported in the query tree's own point
// order (they are not mapped back); reference indices are mapped through
// oldFromNewReferences when that mapping is non-empty. `sameSet` is forwarded
// to the rules, presumably so that a point is not reported as its own neighbor
// when query and reference sets coincide.
// Throws std::invalid_argument if k exceeds the number of reference points or
// if searchMode is not DUAL_TREE_MODE.
// NOTE(review): source lines 634, 641 and 663 are missing from this rendering
// (presumably the rearrangement checks and the RuleType typedef).
593 template<typename SortPolicy,
594  typename MetricType,
595  typename MatType,
596  template<typename TreeMetricType,
597  typename TreeStatType,
598  typename TreeMatType> class TreeType,
599  template<typename> class DualTreeTraversalType,
600  template<typename> class SingleTreeTraversalType>
601 void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
602 DualTreeTraversalType, SingleTreeTraversalType>::Search(
603  Tree& queryTree,
604  const size_t k,
605  arma::Mat<size_t>& neighbors,
606  arma::mat& distances,
607  bool sameSet)
608 {
609  if (k > referenceSet->n_cols)
610  {
611  std::stringstream ss;
612  ss << "Requested value of k (" << k << ") is greater than the number of "
613  << "points in the reference set (" << referenceSet->n_cols << ")";
614  throw std::invalid_argument(ss.str());
615  }
616 
617  // Make sure we are in dual-tree mode.
618  if (searchMode != DUAL_TREE_MODE)
619  throw std::invalid_argument("cannot call NeighborSearch::Search() with a "
620  "query tree when naive or singleMode are set to true");
621 
622  Timer::Start("computing_neighbors");
623 
624  baseCases = 0;
625  scores = 0;
626 
627  // Get a reference to the query set.
628  const MatType& querySet = queryTree.Dataset();
629 
630  // We won't need to map query indices, but will we need to map distances?
631  arma::Mat<size_t>* neighborPtr = &neighbors;
632 
633  if (!oldFromNewReferences.empty() &&
635  neighborPtr = new arma::Mat<size_t>;
636 
637  neighborPtr->set_size(k, querySet.n_cols);
638  distances.set_size(k, querySet.n_cols);
639 
640  // Create the helper object for the traversal.
642  RuleType rules(*referenceSet, querySet, k, metric, epsilon, sameSet);
643 
644  // Create the traverser.
645  DualTreeTraversalType<RuleType> traverser(rules);
646  traverser.Traverse(queryTree, *referenceTree);
647 
648  scores += rules.Scores();
649  baseCases += rules.BaseCases();
650 
651  Log::Info << rules.Scores() << " node combinations were scored." << std::endl;
652  Log::Info << rules.BaseCases() << " base cases were calculated." << std::endl;
653 
654  rules.GetResults(*neighborPtr, distances);
655 
// NOTE(review): the next two lines repeat the log messages already emitted
// above; one of the two pairs is redundant.
656  Log::Info << rules.Scores() << " node combinations were scored.\n";
657  Log::Info << rules.BaseCases() << " base cases were calculated.\n";
658 
659  Timer::Stop("computing_neighbors");
660 
661  // Do we need to map indices?
662  if (!oldFromNewReferences.empty() &&
664  {
665  // We must map reference indices only.
666  neighbors.set_size(k, querySet.n_cols);
667 
668  // Map indices of neighbors.
669  for (size_t i = 0; i < neighbors.n_cols; ++i)
670  for (size_t j = 0; j < neighbors.n_rows; ++j)
671  neighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
672 
673  // Finished with temporary matrix.
674  delete neighborPtr;
675  }
676 }
677 
// Monochromatic search: for every reference point, find k neighbors within the
// reference set itself. The rules are built with a `true` flag so that a point
// is not returned as its own neighbor. For that reason k must be strictly less
// than the number of reference points. Results are k x referenceSet->n_cols and
// are mapped back to the original point order when oldFromNewReferences is
// non-empty (and, presumably, the tree rearranges its dataset).
// Throws std::invalid_argument if k >= the number of reference points.
// NOTE(review): source lines 717, 729, 788, 817 and 840 are missing from this
// rendering. Presumably they are the rearrangement checks, the RuleType
// typedef, a spill-tree check guarding the separate query tree, and the greedy
// traverser declaration.
678 template<typename SortPolicy,
679  typename MetricType,
680  typename MatType,
681  template<typename TreeMetricType,
682  typename TreeStatType,
683  typename TreeMatType> class TreeType,
684  template<typename> class DualTreeTraversalType,
685  template<typename> class SingleTreeTraversalType>
686 void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
687 DualTreeTraversalType, SingleTreeTraversalType>::Search(
688  const size_t k,
689  arma::Mat<size_t>& neighbors,
690  arma::mat& distances)
691 {
692  if (k > referenceSet->n_cols)
693  {
694  std::stringstream ss;
695  ss << "Requested value of k (" << k << ") is greater than the number of "
696  << "points in the reference set (" << referenceSet->n_cols << ")";
697  throw std::invalid_argument(ss.str());
698  }
699  if (k == referenceSet->n_cols)
700  {
701  std::stringstream ss;
702  ss << "Requested value of k (" << k << ") is equal to the number of "
703  << "points in the reference set (" << referenceSet->n_cols << ") and "
704  << "no query set has been provided.";
705  throw std::invalid_argument(ss.str());
706  }
707 
708  Timer::Start("computing_neighbors");
709 
710  baseCases = 0;
711  scores = 0;
712 
713  arma::Mat<size_t>* neighborPtr = &neighbors;
714  arma::mat* distancePtr = &distances;
715 
716  if (!oldFromNewReferences.empty() &&
718  {
719  // We will always need to rearrange in this case.
720  distancePtr = new arma::mat;
721  neighborPtr = new arma::Mat<size_t>;
722  }
723 
724  // Initialize results.
725  neighborPtr->set_size(k, referenceSet->n_cols);
726  distancePtr->set_size(k, referenceSet->n_cols);
727 
728  // Create the helper object for the traversal.
730  RuleType rules(*referenceSet, *referenceSet, k, metric, epsilon,
731  true /* don't return the same point as nearest neighbor */);
732 
733  switch (searchMode)
734  {
735  case NAIVE_MODE:
736  {
737  // The naive brute-force solution.
738  for (size_t i = 0; i < referenceSet->n_cols; ++i)
739  for (size_t j = 0; j < referenceSet->n_cols; ++j)
740  rules.BaseCase(i, j);
741 
742  baseCases += referenceSet->n_cols * referenceSet->n_cols;
743  break;
744  }
745  case SINGLE_TREE_MODE:
746  {
747  // Create the traverser.
748  SingleTreeTraversalType<RuleType> traverser(rules);
749 
750  // Now have it traverse for each point.
751  for (size_t i = 0; i < referenceSet->n_cols; ++i)
752  traverser.Traverse(i, *referenceTree);
753 
754  scores += rules.Scores();
755  baseCases += rules.BaseCases();
756 
757  Log::Info << rules.Scores() << " node combinations were scored."
758  << std::endl;
759  Log::Info << rules.BaseCases() << " base cases were calculated."
760  << std::endl;
761  break;
762  }
763  case DUAL_TREE_MODE:
764  {
765  // The dual-tree monochromatic search case may require resetting the
766  // bounds in the tree.
767  if (treeNeedsReset)
768  {
769  std::stack<Tree*> nodes;
770  nodes.push(referenceTree);
771  while (!nodes.empty())
772  {
773  Tree* node = nodes.top();
774  nodes.pop();
775 
776  // Reset bounds of this node.
777  node->Stat().Reset();
778 
779  // Then add the children.
780  for (size_t i = 0; i < node->NumChildren(); ++i)
781  nodes.push(&node->Child(i));
782  }
783  }
784 
785  // Create the traverser.
786  DualTreeTraversalType<RuleType> traverser(rules);
787 
789  {
790  // For Dual Tree Search on SpillTree, the queryTree must be built with
791  // non overlapping (tau = 0).
792  Tree queryTree(*referenceSet);
793  traverser.Traverse(queryTree, *referenceTree);
794  }
795  else
796  {
797  traverser.Traverse(*referenceTree, *referenceTree);
798  // Next time we perform this search, we'll need to reset the tree.
799  treeNeedsReset = true;
800  }
801 
802  scores += rules.Scores();
803  baseCases += rules.BaseCases();
804 
805  Log::Info << rules.Scores() << " node combinations were scored."
806  << std::endl;
807  Log::Info << rules.BaseCases() << " base cases were calculated."
808  << std::endl;
809 
// NOTE(review): this unconditional assignment makes the identical one in the
// else-branch above redundant.
810  // Next time we perform this search, we'll need to reset the tree.
811  treeNeedsReset = true;
812  break;
813  }
814  case GREEDY_SINGLE_TREE_MODE:
815  {
816  // Create the traverser.
818 
819  // Now have it traverse for each point.
820  for (size_t i = 0; i < referenceSet->n_cols; ++i)
821  traverser.Traverse(i, *referenceTree);
822 
823  scores += rules.Scores();
824  baseCases += rules.BaseCases();
825 
826  Log::Info << rules.Scores() << " node combinations were scored."
827  << std::endl;
828  Log::Info << rules.BaseCases() << " base cases were calculated."
829  << std::endl;
830  break;
831  }
832  }
833 
834  rules.GetResults(*neighborPtr, *distancePtr);
835 
836  Timer::Stop("computing_neighbors");
837 
838  // Do we need to map the reference indices?
839  if (!oldFromNewReferences.empty() &&
841  {
842  neighbors.set_size(k, referenceSet->n_cols);
843  distances.set_size(k, referenceSet->n_cols);
844 
845  for (size_t i = 0; i < distances.n_cols; ++i)
846  {
847  // Map distances (copy a column).
848  const size_t refMapping = oldFromNewReferences[i];
849  distances.col(refMapping) = distancePtr->col(i);
850 
851  // Map each neighbor's index.
852  for (size_t j = 0; j < distances.n_rows; ++j)
853  neighbors(j, refMapping) = oldFromNewReferences[(*neighborPtr)(j, i)];
854  }
855 
856  // Finished with temporary matrices.
857  delete neighborPtr;
858  delete distancePtr;
859  }
860 }
861 
863 template<typename SortPolicy,
864  typename MetricType,
865  typename MatType,
866  template<typename TreeMetricType,
867  typename TreeStatType,
868  typename TreeMatType> class TreeType,
869  template<typename> class DualTreeTraversalType,
870  template<typename> class SingleTreeTraversalType>
871 double NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
872 DualTreeTraversalType, SingleTreeTraversalType>::EffectiveError(
873  arma::mat& foundDistances,
874  arma::mat& realDistances)
875 {
876  if (foundDistances.n_rows != realDistances.n_rows ||
877  foundDistances.n_cols != realDistances.n_cols)
878  throw std::invalid_argument("matrices provided must have equal size");
879 
880  double effectiveError = 0;
881  size_t numCases = 0;
882 
883  for (size_t i = 0; i < foundDistances.n_elem; ++i)
884  {
885  if (realDistances(i) != 0 &&
886  foundDistances(i) != SortPolicy::WorstDistance())
887  {
888  effectiveError += fabs(foundDistances(i) - realDistances(i)) /
889  realDistances(i);
890  numCases++;
891  }
892  }
893 
894  if (numCases)
895  effectiveError /= numCases;
896 
897  return effectiveError;
898 }
899 
901 template<typename SortPolicy,
902  typename MetricType,
903  typename MatType,
904  template<typename TreeMetricType,
905  typename TreeStatType,
906  typename TreeMatType> class TreeType,
907  template<typename> class DualTreeTraversalType,
908  template<typename> class SingleTreeTraversalType>
909 double NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
910 DualTreeTraversalType, SingleTreeTraversalType>::Recall(
911  arma::Mat<size_t>& foundNeighbors,
912  arma::Mat<size_t>& realNeighbors)
913 {
914  if (foundNeighbors.n_rows != realNeighbors.n_rows ||
915  foundNeighbors.n_cols != realNeighbors.n_cols)
916  throw std::invalid_argument("matrices provided must have equal size");
917 
918  size_t found = 0;
919  for (size_t col = 0; col < foundNeighbors.n_cols; ++col)
920  for (size_t row = 0; row < foundNeighbors.n_rows; ++row)
921  for (size_t nei = 0; nei < realNeighbors.n_rows; ++nei)
922  if (foundNeighbors(row, col) == realNeighbors(nei, col))
923  {
924  found++;
925  break;
926  }
927 
928  return ((double) found) / realNeighbors.n_elem;
929 }
930 
932 template<typename SortPolicy,
933  typename MetricType,
934  typename MatType,
935  template<typename TreeMetricType,
936  typename TreeStatType,
937  typename TreeMatType> class TreeType,
938  template<typename> class DualTreeTraversalType,
939  template<typename> class SingleTreeTraversalType>
940 template<typename Archive>
941 void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
942 DualTreeTraversalType, SingleTreeTraversalType>::serialize(
943  Archive& ar, const uint32_t /* version */)
944 {
945  // Serialize preferences for search.
946  ar(CEREAL_NVP(searchMode));
947  ar(CEREAL_NVP(treeNeedsReset));
948 
949  // If we are doing naive search, we serialize the dataset. Otherwise we
950  // serialize the tree.
951  if (searchMode == NAIVE_MODE)
952  {
953  // Delete the current reference set, if necessary and if we are loading.
954  if (cereal::is_loading<Archive>() && referenceSet)
955  {
956  delete referenceSet;
957  }
958 
959  ar(CEREAL_POINTER(const_cast<MatType*&>(referenceSet)));
960  ar(CEREAL_NVP(metric));
961 
962  // If we are loading, set the tree to NULL and clean up memory if necessary.
963  if (cereal::is_loading<Archive>())
964  {
965  if (referenceTree)
966  delete referenceTree;
967 
968  referenceTree = NULL;
969  oldFromNewReferences.clear();
970  }
971  }
972  else
973  {
974  // Delete the current reference tree, if necessary and if we are loading.
975  if (cereal::is_loading<Archive>() && referenceTree)
976  {
977  delete referenceTree;
978  }
979 
980  ar(CEREAL_POINTER(referenceTree));
981  ar(CEREAL_NVP(oldFromNewReferences));
982 
983  // If we are loading, set the dataset accordingly and clean up memory if
984  // necessary.
985  if (cereal::is_loading<Archive>())
986  {
987  referenceSet = &referenceTree->Dataset();
988  metric = referenceTree->Metric(); // Get the metric from the tree.
989  }
990  }
991 
992  // Reset base cases and scores.
993  if (cereal::is_loading<Archive>())
994  {
995  baseCases = 0;
996  scores = 0;
997  }
998 }
999 
1000 } // namespace neighbor
1001 } // namespace mlpack
1002 
1003 #endif
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
Definition: is_spill_tree.hpp:21
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Traverse(const size_t queryIndex, TreeType &referenceNode)
Traverse the tree with the given point.
Definition: greedy_single_tree_traverser_impl.hpp:31
The core includes that mlpack expects; standard C++ includes and Armadillo.
Definition: greedy_single_tree_traverser.hpp:23
The NeighborSearchRules class is a template helper class used by NeighborSearch class when performing...
Definition: neighbor_search_rules.hpp:35
Definition: hmm_train_main.cpp:300
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
The TreeTraits class provides compile-time information on the characteristics of a given tree type...
Definition: tree_traits.hpp:77
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_wrapper.hpp:96
Definition of IsSpillTree.