mlpack
ra_search_rules_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_RANN_RA_SEARCH_RULES_IMPL_HPP
13 #define MLPACK_METHODS_RANN_RA_SEARCH_RULES_IMPL_HPP
14 
15 // In case it hasn't been included yet.
16 #include "ra_search_rules.hpp"
17 
18 namespace mlpack {
19 namespace neighbor {
20 
// Constructor of the rank-approximate search rules.
//
// Stores the datasets, k, the metric and the sampling options, then:
//  - turns the percentile tau into a rank t = ceil(tau * n / 100) over the n
//    reference points; t < k is reported through Log::Fatal, t == k only
//    produces a warning;
//  - obtains the per-query sample budget (numSamplesReqd) from
//    RAUtil::MinimumSamplesReqd(n, k, tau, alpha) and derives
//    samplingRatio = numSamplesReqd / n;
//  - zeroes the per-query sample counters and the distance counter;
//  - gives every query point a candidate queue pre-filled with k placeholder
//    entries;
//  - in naive mode, immediately samples reference points for every query and
//    passes each sampled index to BaseCase().
//
// tau and alpha do not appear in the initializer list; in this listing they
// are only used inside the constructor body.
//
// NOTE(review): this listing skips source line 22 (presumably the
// class-qualifier part of the declarator) -- confirm against the real file.
21 template<typename SortPolicy, typename MetricType, typename TreeType>
23 RASearchRules(const arma::mat& referenceSet,
24  const arma::mat& querySet,
25  const size_t k,
26  MetricType& metric,
27  const double tau,
28  const double alpha,
29  const bool naive,
30  const bool sampleAtLeaves,
31  const bool firstLeafExact,
32  const size_t singleSampleLimit,
33  const bool sameSet) :
34  referenceSet(referenceSet),
35  querySet(querySet),
36  k(k),
37  metric(metric),
38  sampleAtLeaves(sampleAtLeaves),
39  firstLeafExact(firstLeafExact),
40  singleSampleLimit(singleSampleLimit),
41  sameSet(sameSet)
42 {
43  // Validate tau to make sure that the rank approximation is greater than the
44  // number of neighbors requested.
45 
46  // The rank approximation.
47  const size_t n = referenceSet.n_cols;
48  const size_t t = (size_t) std::ceil(tau * (double) n / 100.0);
49  if (t < k)
50  {
51  Log::Warn << "Rank-approximation percentile " << tau << " corresponds to "
52  << t << " points, which is less than k (" << k << ").";
53  Log::Fatal << "Cannot return " << k << " approximate nearest neighbors "
54  << "from the nearest " << t << " points. Increase tau!" << std::endl;
55  }
56  else if (t == k)
57  Log::Warn << "Rank-approximation percentile " << tau << " corresponds to "
58  << t << " points; because k = " << k << ", this is exact search!"
59  << std::endl;
60 
   // The sample-budget computation is wrapped in its own named timer.
61  Timer::Start("computing_number_of_samples_reqd");
62  numSamplesReqd = RAUtil::MinimumSamplesReqd(n, k, tau, alpha);
63  Timer::Stop("computing_number_of_samples_reqd");
64 
65  // Initialize some statistics to be collected during the search.
66  numSamplesMade = arma::zeros<arma::Col<size_t> >(querySet.n_cols);
67  numDistComputations = 0;
68  samplingRatio = (double) numSamplesReqd / (double) n;
69 
70  Log::Info << "Minimum samples required per query: " << numSamplesReqd <<
71  ", sampling ratio: " << samplingRatio << std::endl;
72 
73  // Let's build the list of candidate neighbors for each query point.
74  // It will be initialized with k candidates: (WorstDistance, size_t() - 1)
75  // The list of candidates will be updated when visiting new points with the
76  // BaseCase() method.
   // size_t() is zero, so size_t() - 1 wraps around to the largest size_t
   // value; it serves as the neighbor index of the placeholder entries.
77  const Candidate def = std::make_pair(SortPolicy::WorstDistance(),
78  size_t() - 1);
79 
   // vect is moved into the prototype queue, which is then copied once per
   // query point below.
80  std::vector<Candidate> vect(k, def);
81  CandidateList pqueue(CandidateCmp(), std::move(vect));
82 
83  candidates.reserve(querySet.n_cols);
84  for (size_t i = 0; i < querySet.n_cols; ++i)
85  candidates.push_back(pqueue);
86 
87  if (naive) // No tree traversal; just do naive sampling here.
88  {
89  // Sample enough points.
   // For every query, request numSamplesReqd distinct indices from [0, n)
   // and evaluate each of them through BaseCase().
90  arma::uvec distinctSamples;
91  for (size_t i = 0; i < querySet.n_cols; ++i)
92  {
93  math::ObtainDistinctSamples(0, n, numSamplesReqd, distinctSamples);
94  for (size_t j = 0; j < distinctSamples.n_elem; ++j)
95  BaseCase(i, (size_t) distinctSamples[j]);
96  }
97  }
98 }
99 
// GetResults(): writes the candidates of every query point into the given
// matrices.
//
// Both outputs are resized to k rows by querySet.n_cols columns. For query
// column i the candidate queue is popped k times; the j-th popped entry
// (j = 1..k) is stored in row k - j, so the entry that was on top of the
// queue ends up in the last row. The queues are accessed by reference, so
// each of them loses k entries during this call.
//
// NOTE(review): the ';' after the closing brace is redundant.
100 template<typename SortPolicy, typename MetricType, typename TreeType>
102  arma::Mat<size_t>& neighbors,
103  arma::mat& distances)
104 {
105  neighbors.set_size(k, querySet.n_cols);
106  distances.set_size(k, querySet.n_cols);
107 
108  for (size_t i = 0; i < querySet.n_cols; ++i)
109  {
110  CandidateList& pqueue = candidates[i];
    // Fill the column from the bottom row upwards while draining the queue.
111  for (size_t j = 1; j <= k; ++j)
112  {
113  neighbors(k - j, i) = pqueue.top().second;
114  distances(k - j, i) = pqueue.top().first;
115  pqueue.pop();
116  }
117  }
118 };
119 
// BaseCase(): evaluates one (query point, reference point) pair.
//
// When sameSet is set and both indices are equal, 0.0 is returned right away:
// no distance is computed, the candidate list is not touched and no sample is
// counted. Otherwise the metric is evaluated on the two columns, the result
// is offered to InsertNeighbor(), numSamplesMade[queryIndex] and
// numDistComputations are each incremented by one, and the distance is
// returned.
120 template<typename SortPolicy, typename MetricType, typename TreeType>
121 inline force_inline
123  const size_t queryIndex,
124  const size_t referenceIndex)
125 {
126  // If the datasets are the same, then this search is only using one dataset
127  // and we should not return identical points.
128  if (sameSet && (queryIndex == referenceIndex))
129  return 0.0;
130 
131  double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
132  referenceSet.unsafe_col(referenceIndex));
133 
134  InsertNeighbor(queryIndex, referenceIndex, distance);
135 
    // Every evaluated pair counts as one sample for this query.
136  numSamplesMade[queryIndex]++;
137 
138  numDistComputations++;
139 
140  return distance;
141 }
142 
// Score(queryIndex, referenceNode): scores a reference node for one query
// point.
//
// Computes SortPolicy::BestPointToNodeDistance() between the query column and
// the node, reads the distance stored at the top of the query's candidate
// queue, and forwards both values to the four-argument Score() overload,
// whose result is returned unchanged.
143 template<typename SortPolicy, typename MetricType, typename TreeType>
145  const size_t queryIndex,
146  TreeType& referenceNode)
147 {
148  const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
149  const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
150  &referenceNode);
151  const double bestDistance = candidates[queryIndex].top().first;
152 
153  return Score(queryIndex, referenceNode, distance, bestDistance);
154 }
155 
// Score(queryIndex, referenceNode, baseCaseResult): same as the two-argument
// single-tree Score(), except that baseCaseResult is additionally passed to
// SortPolicy::BestPointToNodeDistance() when the point-to-node distance is
// computed. The top-of-queue distance of the query is used as the current
// bound and the four-argument Score() overload produces the return value.
156 template<typename SortPolicy, typename MetricType, typename TreeType>
158  const size_t queryIndex,
159  TreeType& referenceNode,
160  const double baseCaseResult)
161 {
162  const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
163  const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
164  &referenceNode, baseCaseResult);
165  const double bestDistance = candidates[queryIndex].top().first;
166 
167  return Score(queryIndex, referenceNode, distance, bestDistance);
168 }
169 
// Score(queryIndex, referenceNode, distance, bestDistance): shared scoring
// logic for one query point and one reference node.
//
// Returns `distance` when the node still has to be visited and DBL_MAX when
// it is pruned. The cases, in the order they are tested:
//  1. distance is not better than bestDistance (SortPolicy::IsBetter), or the
//     query already has numSamplesReqd samples: floor(samplingRatio *
//     NumDescendants()) "fake" samples are credited to the query and DBL_MAX
//     is returned.
//  2. No sample has been made for the query yet and firstLeafExact is set:
//     `distance` is returned without sampling.
//  3. Otherwise samplesReqd = min(ceil(samplingRatio * NumDescendants()),
//     remaining budget of the query), and
//      - non-leaf node with samplesReqd > singleSampleLimit: `distance`;
//      - any other non-leaf node: up to samplesReqd distinct descendants are
//        drawn and passed to BaseCase(), then DBL_MAX;
//      - leaf with sampleAtLeaves: same sampling, then DBL_MAX;
//      - leaf without sampleAtLeaves: `distance`.
170 template<typename SortPolicy, typename MetricType, typename TreeType>
172  const size_t queryIndex,
173  TreeType& referenceNode,
174  const double distance,
175  const double bestDistance)
176 {
177  // If this is better than the best distance we've seen so far, maybe there
178  // will be something down this node. Also check if enough samples are already
179  // made for this query.
180  if (SortPolicy::IsBetter(distance, bestDistance)
181  && numSamplesMade[queryIndex] < numSamplesReqd)
182  {
183  // We cannot prune this node; try approximating it by sampling.
184 
185  // If we are required to visit the first leaf (to find possible duplicates),
186  // make sure we do not approximate.
187  if (numSamplesMade[queryIndex] > 0 || !firstLeafExact)
188  {
189  // Check if this node can be approximated by sampling.
190  size_t samplesReqd = (size_t) std::ceil(samplingRatio *
191  (double) referenceNode.NumDescendants());
    // The subtraction cannot wrap: the enclosing condition guarantees
    // numSamplesMade[queryIndex] < numSamplesReqd.
192  samplesReqd = std::min(samplesReqd,
193  numSamplesReqd - numSamplesMade[queryIndex]);
194 
195  if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
196  {
197  // If too many samples required and not at a leaf, then can't prune.
198  return distance;
199  }
200  else
201  {
202  if (!referenceNode.IsLeaf())
203  {
204  // Then samplesReqd <= singleSampleLimit.
205  // Hence, approximate the node by sampling enough number of points.
206  arma::uvec distinctSamples;
207  math::ObtainDistinctSamples(0, referenceNode.NumDescendants(),
208  samplesReqd, distinctSamples);
209  for (size_t i = 0; i < distinctSamples.n_elem; ++i)
210  // The counting of the samples are done in the 'BaseCase' function
211  // so no book-keeping is required here.
212  BaseCase(queryIndex, referenceNode.Descendant(distinctSamples[i]));
213 
214  // Node approximated, so we can prune it.
215  return DBL_MAX;
216  }
217  else // We are at a leaf.
218  {
219  if (sampleAtLeaves) // If allowed to sample at leaves.
220  {
221  // Approximate node by sampling enough number of points.
222  arma::uvec distinctSamples;
223  math::ObtainDistinctSamples(0, referenceNode.NumDescendants(),
224  samplesReqd, distinctSamples);
225  for (size_t i = 0; i < distinctSamples.n_elem; ++i)
226  // The counting of the samples are done in the 'BaseCase' function
227  // so no book-keeping is required here.
228  BaseCase(queryIndex,
229  referenceNode.Descendant(distinctSamples[i]));
230 
231  // (Leaf) node approximated, so we can prune it.
232  return DBL_MAX;
233  }
234  else
235  {
236  // Not allowed to sample from leaves, so cannot prune.
237  return distance;
238  }
239  }
240  }
241  }
242  else
243  {
244  // Try first to visit the first leaf to boost your accuracy and find
245  // (near) duplicates if they exist.
246  return distance;
247  }
248  }
249  else
250  {
251  // Either there cannot be anything better in this node, or enough number of
252  // samples are already made. So prune it.
253 
254  // Add 'fake' samples from this node; they are fake because the distances to
255  // these samples need not be computed.
256 
257  // If enough samples are already made, this step does not change the result
258  // of the search.
259  numSamplesMade[queryIndex] += (size_t) std::floor(
260  samplingRatio * (double) referenceNode.NumDescendants());
261 
262  return DBL_MAX;
263  }
264 }
265 
// Rescore(queryIndex, referenceNode, oldScore): re-evaluates a previously
// computed score for one query point and one reference node.
//
// An oldScore of DBL_MAX is returned as is. Otherwise oldScore is compared
// with the distance currently at the top of the query's candidate queue:
//  - oldScore is not better (SortPolicy::IsBetter), or the query already has
//    numSamplesReqd samples: floor(samplingRatio * NumDescendants()) "fake"
//    samples are credited and DBL_MAX is returned;
//  - otherwise samplesReqd = min(ceil(samplingRatio * NumDescendants()),
//    remaining budget), and
//      - non-leaf node with samplesReqd > singleSampleLimit: oldScore;
//      - any other non-leaf node, or a leaf with sampleAtLeaves: up to
//        samplesReqd distinct descendants are drawn and passed to BaseCase(),
//        then DBL_MAX;
//      - leaf without sampleAtLeaves: oldScore.
// Unlike Score(), firstLeafExact is not consulted here.
266 template<typename SortPolicy, typename MetricType, typename TreeType>
268 Rescore(const size_t queryIndex,
269  TreeType& referenceNode,
270  const double oldScore)
271 {
272  // If we are already pruning, still prune.
273  if (oldScore == DBL_MAX)
274  return oldScore;
275 
276  // Just check the score again against the distances.
277  const double bestDistance = candidates[queryIndex].top().first;
278 
279  // If this is better than the best distance we've seen so far,
280  // maybe there will be something down this node.
281  // Also check if enough samples are already made for this query.
282  if (SortPolicy::IsBetter(oldScore, bestDistance)
283  && numSamplesMade[queryIndex] < numSamplesReqd)
284  {
285  // We cannot prune this node; thus, we try approximating this node by
286  // sampling.
287 
288  // Here, we assume that since we are re-scoring, the algorithm has already
289  // sampled some candidates, and if specified, also traversed to the first
290  // leaf. So no check regarding that is made any more.
291 
292  // Check if this node can be approximated by sampling.
293  size_t samplesReqd = (size_t) std::ceil(samplingRatio *
294  (double) referenceNode.NumDescendants());
    // The subtraction cannot wrap: the enclosing condition guarantees
    // numSamplesMade[queryIndex] < numSamplesReqd.
295  samplesReqd = std::min(samplesReqd, numSamplesReqd -
296  numSamplesMade[queryIndex]);
297 
298  if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
299  {
300  // If too many samples are required and we are not at a leaf, then we
301  // can't prune.
302  return oldScore;
303  }
304  else
305  {
306  if (!referenceNode.IsLeaf())
307  {
308  // Then, samplesReqd <= singleSampleLimit. Hence, approximate the node
309  // by sampling enough number of points.
310  arma::uvec distinctSamples;
311  math::ObtainDistinctSamples(0, referenceNode.NumDescendants(),
312  samplesReqd, distinctSamples);
313  for (size_t i = 0; i < distinctSamples.n_elem; ++i)
314  // The counting of the samples are done in the 'BaseCase' function so
315  // no book-keeping is required here.
316  BaseCase(queryIndex, referenceNode.Descendant(distinctSamples[i]));
317 
318  // Node approximated, so we can prune it.
319  return DBL_MAX;
320  }
321  else // We are at a leaf.
322  {
323  if (sampleAtLeaves)
324  {
325  // Approximate node by sampling enough points.
326  arma::uvec distinctSamples;
327  math::ObtainDistinctSamples(0, referenceNode.NumDescendants(),
328  samplesReqd, distinctSamples);
329  for (size_t i = 0; i < distinctSamples.n_elem; ++i)
330  // The counting of the samples are done in the 'BaseCase' function
331  // so no book-keeping is required here.
332  BaseCase(queryIndex, referenceNode.Descendant(distinctSamples[i]));
333 
334  // (Leaf) node approximated, so we can prune it.
335  return DBL_MAX;
336  }
337  else
338  {
339  // We cannot sample from leaves, so we cannot prune.
340  return oldScore;
341  }
342  }
343  }
344  }
345  else
346  {
347  // Either there cannot be anything better in this node, or enough number of
348  // samples are already made, so prune it.
349 
350  // Add 'fake' samples from this node; they are fake because the distances to
351  // these samples need not be computed. If enough samples are already made,
352  // this step does not change the result of the search.
353  numSamplesMade[queryIndex] += (size_t) std::floor(samplingRatio *
354  (double) referenceNode.NumDescendants());
355 
356  return DBL_MAX;
357  }
358 } // Rescore(point, node, oldScore)
359 
// Score(queryNode, referenceNode): scores a pair of nodes.
//
// Computes SortPolicy::BestNodeToNodeDistance() for the pair, then refreshes
// queryNode.Stat().Bound() as the minimum of
//  - (top-of-queue distance + queryNode.FurthestDescendantDistance()) over
//    the points held directly by queryNode, and
//  - Stat().Bound() over the children of queryNode.
// The refreshed bound and the node-to-node distance are forwarded to the
// four-argument dual-tree Score() overload, whose result is returned.
360 template<typename SortPolicy, typename MetricType, typename TreeType>
362  TreeType& queryNode,
363  TreeType& referenceNode)
364 {
365  // First try to find the distance bound to check if we can prune by distance.
366 
367  // Calculate the best node-to-node distance.
368  const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
369  &referenceNode);
370 
   // Both partial bounds start at DBL_MAX and are only ever lowered.
371  double pointBound = DBL_MAX;
372  double childBound = DBL_MAX;
373  const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
374 
375  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
376  {
377  const double bound = candidates[queryNode.Point(i)].top().first
378  + maxDescendantDistance;
379  if (bound < pointBound)
380  pointBound = bound;
381  }
382 
383  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
384  {
385  const double bound = queryNode.Child(i).Stat().Bound();
386  if (bound < childBound)
387  childBound = bound;
388  }
389 
390  // Update the bound.
391  queryNode.Stat().Bound() = std::min(pointBound, childBound);
392  const double bestDistance = queryNode.Stat().Bound();
393 
394  return Score(queryNode, referenceNode, distance, bestDistance);
395 }
396 
// Score(queryNode, referenceNode, baseCaseResult): same as the two-argument
// dual-tree Score(), except that baseCaseResult is additionally passed to
// SortPolicy::BestNodeToNodeDistance().
//
// queryNode.Stat().Bound() is refreshed as the minimum of
//  - (top-of-queue distance + queryNode.FurthestDescendantDistance()) over
//    the points held directly by queryNode, and
//  - Stat().Bound() over the children of queryNode,
// and the four-argument dual-tree Score() overload produces the return value.
397 template<typename SortPolicy, typename MetricType, typename TreeType>
399  TreeType& queryNode,
400  TreeType& referenceNode,
401  const double baseCaseResult)
402 {
403  // First try to find the distance bound to check if we can prune
404  // by distance.
405 
406  // Find the best node-to-node distance.
407  const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
408  &referenceNode, baseCaseResult);
409 
   // Both partial bounds start at DBL_MAX and are only ever lowered.
410  double pointBound = DBL_MAX;
411  double childBound = DBL_MAX;
412  const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
413 
414  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
415  {
416  const double bound = candidates[queryNode.Point(i)].top().first
417  + maxDescendantDistance;
418  if (bound < pointBound)
419  pointBound = bound;
420  }
421 
422  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
423  {
424  const double bound = queryNode.Child(i).Stat().Bound();
425  if (bound < childBound)
426  childBound = bound;
427  }
428 
429  // Update the bound.
430  queryNode.Stat().Bound() = std::min(pointBound, childBound);
431  const double bestDistance = queryNode.Stat().Bound();
432 
433  return Score(queryNode, referenceNode, distance, bestDistance);
434 }
435 
// Score(queryNode, referenceNode, distance, bestDistance): shared scoring
// logic for a pair of nodes.
//
// Book-keeping first: for a non-leaf queryNode, its Stat().NumSamplesMade()
// is raised to the minimum NumSamplesMade() found among its children when
// that minimum is larger.
//
// Returns `distance` when the pair still has to be visited and DBL_MAX when
// it is pruned. The cases, in the order they are tested:
//  1. distance is not better than bestDistance (SortPolicy::IsBetter), or
//     queryNode already has numSamplesReqd samples: floor(samplingRatio *
//     referenceNode.NumDescendants()) "fake" samples are added to queryNode's
//     counter and DBL_MAX is returned.
//  2. queryNode has no samples yet and firstLeafExact is set: the counter is
//     pushed down to the children and `distance` is returned.
//  3. Otherwise samplesReqd = min(ceil(samplingRatio *
//     referenceNode.NumDescendants()), remaining budget of queryNode), and
//      - non-leaf reference node with samplesReqd > singleSampleLimit: the
//        counter is pushed down to the children, `distance` is returned;
//      - any other non-leaf reference node, or a leaf with sampleAtLeaves:
//        for every descendant point of queryNode, up to samplesReqd distinct
//        descendants of referenceNode are drawn and passed to BaseCase();
//        samplesReqd is added to queryNode's counter; DBL_MAX is returned;
//      - leaf reference node without sampleAtLeaves: the counter is pushed
//        down to the children, `distance` is returned.
//
// "Pushed down" means each child's NumSamplesMade() becomes the maximum of
// its own value and queryNode's value.
436 template<typename SortPolicy, typename MetricType, typename TreeType>
438  TreeType& queryNode,
439  TreeType& referenceNode,
440  const double distance,
441  const double bestDistance)
442 {
443  // Update the number of samples made for this node -- propagate up from child
444  // nodes if child nodes have made samples that the parent node is not aware
445  // of. Remember, we must propagate down samples made to the child nodes if
446  // 'queryNode' descend is deemed necessary.
447 
448  // Only update from children if a non-leaf node, obviously.
449  if (!queryNode.IsLeaf())
450  {
451  size_t numSamplesMadeInChildNodes = std::numeric_limits<size_t>::max();
452 
453  // Find the minimum number of samples made among all children.
454  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
455  {
456  const size_t numSamples = queryNode.Child(i).Stat().NumSamplesMade();
457  if (numSamples < numSamplesMadeInChildNodes)
458  numSamplesMadeInChildNodes = numSamples;
459  }
460 
461  // The number of samples made for a node is propagated up from the child
462  // nodes if the child nodes have made samples that the parent (which is the
463  // current 'queryNode') is not aware of.
464  queryNode.Stat().NumSamplesMade() = std::max(
465  queryNode.Stat().NumSamplesMade(), numSamplesMadeInChildNodes);
466  }
467 
468  // Now check if the node-pair interaction can be pruned.
469 
470  // If this is better than the best distance we've seen so far, maybe there
471  // will be something down this node. Also check if enough samples are already
472  // made for this 'queryNode'.
473  if (SortPolicy::IsBetter(distance, bestDistance)
474  && queryNode.Stat().NumSamplesMade() < numSamplesReqd)
475  {
476  // We cannot prune this node; try approximating this node by sampling.
477 
478  // If we are required to visit the first leaf (to find possible duplicates),
479  // make sure we do not approximate.
480  if (queryNode.Stat().NumSamplesMade() > 0 || !firstLeafExact)
481  {
482  // Check if this node can be approximated by sampling.
483  size_t samplesReqd = (size_t) std::ceil(samplingRatio
484  * (double) referenceNode.NumDescendants());
    // The subtraction cannot wrap: the enclosing condition guarantees
    // queryNode.Stat().NumSamplesMade() < numSamplesReqd.
485  samplesReqd = std::min(samplesReqd, numSamplesReqd -
486  queryNode.Stat().NumSamplesMade());
487 
488  if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
489  {
490  // If too many samples are required and we are not at a leaf, then we
491  // can't prune. Since query tree descent is necessary now, propagate
492  // the number of samples made down to the children.
493 
494  // Iterate through all children and propagate the number of samples made
495  // to the children. Only update if the parent node has made samples the
496  // children have not seen.
497  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
498  queryNode.Child(i).Stat().NumSamplesMade() = std::max(
499  queryNode.Stat().NumSamplesMade(),
500  queryNode.Child(i).Stat().NumSamplesMade());
501 
502  return distance;
503  }
504  else
505  {
506  if (!referenceNode.IsLeaf())
507  {
508  // Then samplesReqd <= singleSampleLimit. Hence, approximate node by
509  // sampling enough number of points for every query in the query node.
510  arma::uvec distinctSamples;
511  for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
512  {
513  const size_t queryIndex = queryNode.Descendant(i);
514  math::ObtainDistinctSamples(0, referenceNode.NumDescendants(),
515  samplesReqd, distinctSamples);
516  for (size_t j = 0; j < distinctSamples.n_elem; ++j)
517  // The counting of the samples are done in the 'BaseCase' function
518  // so no book-keeping is required here.
519  BaseCase(queryIndex,
520  referenceNode.Descendant(distinctSamples[j]));
521  }
522 
523  // Update the number of samples made for the queryNode only; the child
524  // nodes are intentionally left untouched (see the note below).
525  queryNode.Stat().NumSamplesMade() += samplesReqd;
526 
527  // Since we are not going to descend down the query tree for this
528  // reference node, there is no point updating the number of samples
529  // made for the child nodes of this query node.
530 
531  // Node is approximated, so we can prune it.
532  return DBL_MAX;
533  }
534  else
535  {
536  if (sampleAtLeaves)
537  {
538  // Approximate node by sampling enough number of points for every
539  // query in the query node.
540  arma::uvec distinctSamples;
541  for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
542  {
543  const size_t queryIndex = queryNode.Descendant(i);
544  math::ObtainDistinctSamples(0, referenceNode.NumDescendants(),
545  samplesReqd, distinctSamples);
546  for (size_t j = 0; j < distinctSamples.n_elem; ++j)
547  // The counting of the samples are done in the 'BaseCase'
548  // function so no book-keeping is required here.
549  BaseCase(queryIndex,
550  referenceNode.Descendant(distinctSamples[j]));
551  }
552 
553  // Update the number of samples made for the queryNode only; the
554  // child nodes are intentionally left untouched (see the note below).
555  queryNode.Stat().NumSamplesMade() += samplesReqd;
556 
557  // Since we are not going to descend down the query tree for this
558  // reference node, there is no point updating the number of samples
559  // made for the child nodes of this query node.
560 
561  // (Leaf) node is approximated, so we can prune it.
562  return DBL_MAX;
563  }
564  else
565  {
566  // We cannot sample from leaves, so we cannot prune. Propagate the
567  // number of samples made down to the children.
568 
569  // Go through all children and propagate the number of
570  // samples made to the children.
571  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
572  queryNode.Child(i).Stat().NumSamplesMade() = std::max(
573  queryNode.Stat().NumSamplesMade(),
574  queryNode.Child(i).Stat().NumSamplesMade());
575 
576  return distance;
577  }
578  }
579  }
580  }
581  else
582  {
583  // We must first visit the first leaf to boost accuracy.
584  // Go through all children and propagate the number of
585  // samples made to the children.
586  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
587  queryNode.Child(i).Stat().NumSamplesMade() = std::max(
588  queryNode.Stat().NumSamplesMade(),
589  queryNode.Child(i).Stat().NumSamplesMade());
590 
591  return distance;
592  }
593  }
594  else
595  {
596  // Either there cannot be anything better in this node, or enough number of
597  // samples are already made, so prune it.
598 
599  // Add 'fake' samples from this node; fake because the distances to
600  // these samples need not be computed. If enough samples are already made,
601  // this step does not change the result of the search since this queryNode
602  // will never be descended anymore.
603  queryNode.Stat().NumSamplesMade() += (size_t) std::floor(samplingRatio *
604  (double) referenceNode.NumDescendants());
605 
606  // Since we are not going to descend down the query tree for this reference
607  // node, there is no point updating the number of samples made for the child
608  // nodes of this query node.
609 
610  return DBL_MAX;
611  }
612 }
613 
// Rescore(queryNode, referenceNode, oldScore): re-evaluates a previously
// computed score for a pair of nodes.
//
// An oldScore of DBL_MAX is returned as is. Otherwise:
//  - queryNode.Stat().Bound() is refreshed as the minimum of (top-of-queue
//    distance + FurthestDescendantDistance()) over queryNode's own points and
//    Stat().Bound() over its children;
//  - for a non-leaf queryNode, Stat().NumSamplesMade() is raised to the
//    minimum NumSamplesMade() among its children when that is larger;
//  - if oldScore is not better than the refreshed bound
//    (SortPolicy::IsBetter), or queryNode already has numSamplesReqd samples:
//    floor(samplingRatio * referenceNode.NumDescendants()) "fake" samples are
//    added to queryNode's counter and DBL_MAX is returned;
//  - otherwise samplesReqd = min(ceil(samplingRatio *
//    referenceNode.NumDescendants()), remaining budget of queryNode), and
//      - non-leaf reference node with samplesReqd > singleSampleLimit: the
//        counter is pushed down to the children, oldScore is returned;
//      - any other non-leaf reference node, or a leaf with sampleAtLeaves:
//        for every descendant point of queryNode, up to samplesReqd distinct
//        descendants of referenceNode are drawn and passed to BaseCase();
//        samplesReqd is added to queryNode's counter; DBL_MAX is returned;
//      - leaf reference node without sampleAtLeaves: the counter is pushed
//        down to the children, oldScore is returned.
//
// "Pushed down" means each child's NumSamplesMade() becomes the maximum of
// its own value and queryNode's value. Unlike Score(), firstLeafExact is not
// consulted here.
614 template<typename SortPolicy, typename MetricType, typename TreeType>
616 Rescore(TreeType& queryNode,
617  TreeType& referenceNode,
618  const double oldScore)
619 {
620  if (oldScore == DBL_MAX)
621  return oldScore;
622 
623  // First try to find the distance bound to check if we can prune by distance.
624  double pointBound = DBL_MAX;
625  double childBound = DBL_MAX;
626  const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
627 
628  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
629  {
630  const double bound = candidates[queryNode.Point(i)].top().first
631  + maxDescendantDistance;
632  if (bound < pointBound)
633  pointBound = bound;
634  }
635 
636  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
637  {
638  const double bound = queryNode.Child(i).Stat().Bound();
639  if (bound < childBound)
640  childBound = bound;
641  }
642 
643  // Update the bound.
644  queryNode.Stat().Bound() = std::min(pointBound, childBound);
645  const double bestDistance = queryNode.Stat().Bound();
646 
647  // Now check if the node-pair interaction can be pruned by sampling.
648  // Update the number of samples made for that node. Propagate up from child
649  // nodes if child nodes have made samples that the parent node is not aware
650  // of. Remember, we must propagate down samples made to the child nodes if
651  // the parent samples.
652 
653  // Only update from children if a non-leaf node, obviously.
654  if (!queryNode.IsLeaf())
655  {
656  size_t numSamplesMadeInChildNodes = std::numeric_limits<size_t>::max();
657 
658  // Find the minimum number of samples made among all children
659  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
660  {
661  const size_t numSamples = queryNode.Child(i).Stat().NumSamplesMade();
662  if (numSamples < numSamplesMadeInChildNodes)
663  numSamplesMadeInChildNodes = numSamples;
664  }
665 
666  // The number of samples made for a node is propagated up from the child
667  // nodes if the child nodes have made samples that the parent (which is the
668  // current 'queryNode') is not aware of.
669  queryNode.Stat().NumSamplesMade() = std::max(
670  queryNode.Stat().NumSamplesMade(), numSamplesMadeInChildNodes);
671  }
672 
673  // Now check if the node-pair interaction can be pruned by sampling.
674 
675  // If this is better than the best distance we've seen so far, maybe there
676  // will be something down this node. Also check if enough samples are already
677  // made for this query.
678  if (SortPolicy::IsBetter(oldScore, bestDistance) &&
679  queryNode.Stat().NumSamplesMade() < numSamplesReqd)
680  {
681  // We cannot prune this node, so approximate by sampling.
682 
683  // Here we assume that since we are re-scoring, the algorithm has already
684  // sampled some candidates, and if specified, also traversed to the first
685  // leaf. So no checks regarding that are made any more.
686  size_t samplesReqd = (size_t) std::ceil(
687  samplingRatio * (double) referenceNode.NumDescendants());
    // The subtraction cannot wrap: the enclosing condition guarantees
    // queryNode.Stat().NumSamplesMade() < numSamplesReqd.
688  samplesReqd = std::min(samplesReqd,
689  numSamplesReqd - queryNode.Stat().NumSamplesMade());
690 
691  if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
692  {
693  // If too many samples are required and we are not at a leaf, then we
694  // can't prune.
695 
696  // Since query tree descent is necessary now, propagate the number of
697  // samples made down to the children.
698 
699  // Go through all children and propagate the number of samples made to the
700  // children. Only update if the parent node has made samples the children
701  // have not seen.
702  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
703  queryNode.Child(i).Stat().NumSamplesMade() = std::max(
704  queryNode.Stat().NumSamplesMade(),
705  queryNode.Child(i).Stat().NumSamplesMade());
706 
707  return oldScore;
708  }
709  else
710  {
711  if (!referenceNode.IsLeaf()) // If not a leaf,
712  {
713  // then samplesReqd <= singleSampleLimit. Hence, approximate the node
714  // by sampling enough points for every query in the query node.
715  arma::uvec distinctSamples;
716  for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
717  {
718  const size_t queryIndex = queryNode.Descendant(i);
719  math::ObtainDistinctSamples(0, referenceNode.NumDescendants(),
720  samplesReqd, distinctSamples);
721  for (size_t j = 0; j < distinctSamples.n_elem; ++j)
722  // The counting of the samples are done in the 'BaseCase'
723  // function so no book-keeping is required here.
724  BaseCase(queryIndex, referenceNode.Descendant(distinctSamples[j]));
725  }
726 
727  // Update the number of samples made for the query node only; the child
728  // nodes are intentionally left untouched (see the note below).
729  queryNode.Stat().NumSamplesMade() += samplesReqd;
730 
731  // Since we are not going to descend down the query tree for this
732  // reference node, there is no point updating the number of samples made
733  // for the child nodes of this query node.
734 
735  // Node approximated, so we can prune it.
736  return DBL_MAX;
737  }
738  else // We are at a leaf.
739  {
740  if (sampleAtLeaves)
741  {
742  // Approximate node by sampling enough points for every query in the
743  // query node.
744  arma::uvec distinctSamples;
745  for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
746  {
747  const size_t queryIndex = queryNode.Descendant(i);
748  math::ObtainDistinctSamples(0, referenceNode.NumDescendants(),
749  samplesReqd, distinctSamples);
750  for (size_t j = 0; j < distinctSamples.n_elem; ++j)
751  // The counting of the samples are done in BaseCase() so no
752  // book-keeping is required here.
753  BaseCase(queryIndex,
754  referenceNode.Descendant(distinctSamples[j]));
755  }
756 
757  // Update the number of samples made for the query node only; the
758  // child nodes are intentionally left untouched (see the note below).
759  queryNode.Stat().NumSamplesMade() += samplesReqd;
760 
761  // Since we are not going to descend down the query tree for this
762  // reference node, there is no point updating the number of samples
763  // made for the child nodes of this query node.
764 
765  // (Leaf) node approximated, so we can prune it.
766  return DBL_MAX;
767  }
768  else
769  {
770  // We cannot sample from leaves, so we cannot prune.
771  // Propagate the number of samples made down to the children.
772  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
773  queryNode.Child(i).Stat().NumSamplesMade() = std::max(
774  queryNode.Stat().NumSamplesMade(),
775  queryNode.Child(i).Stat().NumSamplesMade());
776 
777  return oldScore;
778  }
779  }
780  }
781  }
782  else
783  {
784  // Either there cannot be anything better in this node, or enough samples
785  // are already made, so prune it.
786 
787  // Add 'fake' samples from this node; fake because the distances to
788  // these samples need not be computed. If enough samples are already made,
789  // this step does not change the result of the search since this query node
790  // will never be descended anymore.
791  queryNode.Stat().NumSamplesMade() += (size_t) std::floor(samplingRatio *
792  (double) referenceNode.NumDescendants());
793 
794  // Since we are not going to descend down the query tree for this reference
795  // node, there is no point updating the number of samples made for the child
796  // nodes of this query node.
797  return DBL_MAX;
798  }
799 } // Rescore(node, node, oldScore)
800 
// Offers (distance, neighbor) as a candidate for the given query point.
//
// The new candidate is compared against the top of the query's candidate
// queue with CandidateCmp; when the comparison succeeds, the top entry is
// popped and the new candidate is pushed, so the number of entries in the
// queue does not change. Otherwise the queue is left untouched.
//
// NOTE(review): this listing skips source lines 809-810 (presumably the
// return type and qualified function name; the call site in BaseCase() names
// it InsertNeighbor) -- confirm against the real file.
808 template<typename SortPolicy, typename MetricType, typename TreeType>
811  const size_t queryIndex,
812  const size_t neighbor,
813  const double distance)
814 {
815  CandidateList& pqueue = candidates[queryIndex];
816  Candidate c = std::make_pair(distance, neighbor);
817 
818  if (CandidateCmp()(c, pqueue.top()))
819  {
820  pqueue.pop();
821  pqueue.push(c);
822  }
823 }
824 
825 } // namespace neighbor
826 } // namespace mlpack
827 
828 #endif // MLPACK_METHODS_RANN_RA_SEARCH_RULES_IMPL_HPP
void ObtainDistinctSamples(const size_t loInclusive, const size_t hiExclusive, const size_t maxNumSamples, arma::uvec &distinctSamples)
Obtains no more than maxNumSamples distinct samples.
Definition: random.hpp:153
void GetResults(arma::Mat< size_t > &neighbors, arma::mat &distances)
Store the list of candidates for each query point in the given matrices.
Definition: ra_search_rules_impl.hpp:101
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore)
Re-evaluate the score for recursion order.
Definition: ra_search_rules_impl.hpp:268
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
Definition: ra_search_rules_impl.hpp:144
static MLPACK_EXPORT util::PrefixedOutStream Warn
Prints warning messages prefixed with [WARN ].
Definition: log.hpp:87
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
Get the distance from the query point to the reference point.
Definition: ra_search_rules_impl.hpp:122
static size_t MinimumSamplesReqd(const size_t n, const size_t k, const double tau, const double alpha)
Compute the minimum number of samples required to guarantee the given rank-approximation and success ...
Definition: ra_util.cpp:18
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
RASearchRules(const arma::mat &referenceSet, const arma::mat &querySet, const size_t k, MetricType &metric, const double tau=5, const double alpha=0.95, const bool naive=false, const bool sampleAtLeaves=false, const bool firstLeafExact=false, const size_t singleSampleLimit=20, const bool sameSet=false)
Construct the RASearchRules object.
Definition: ra_search_rules_impl.hpp:23
The RASearchRules class is a template helper class used by RASearch class when performing rank-approx...
Definition: ra_search_rules.hpp:33