mlpack
neighbor_search_rules_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
13 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
14 
15 // In case it hasn't been included yet.
18 
19 namespace mlpack {
20 namespace neighbor {
21 
// Constructs the rules object used by the tree traversals for neighbor search.
//
// Parameters:
//   referenceSet - points searched for neighbors (one point per column; read
//                  via col() in BaseCase()).
//   querySet     - points whose neighbors are sought; one candidate list is
//                  created per column.
//   k            - number of neighbors kept per query point.
//   metric       - distance metric; its Evaluate() is called in BaseCase().
//   epsilon      - relative error allowed for approximate search, applied
//                  through SortPolicy::Relax().
//   sameSet      - when true, a query point is never paired with the reference
//                  point of the same index (see BaseCase()).
//
// referenceSet, querySet and metric are stored as references (per their member
// declarations), so the caller must keep them alive while this object is used.
//
// NOTE(review): original source line 23, which names this constructor, is
// missing from this listing.
22 template<typename SortPolicy, typename MetricType, typename TreeType>
24  const typename TreeType::Mat& referenceSet,
25  const typename TreeType::Mat& querySet,
26  const size_t k,
27  MetricType& metric,
28  const double epsilon,
29  const bool sameSet) :
30  referenceSet(referenceSet),
31  querySet(querySet),
32  k(k),
33  metric(metric),
34  sameSet(sameSet),
35  epsilon(epsilon),
// Out-of-range sentinels (valid column indices are < n_cols), so BaseCase()'s
// "same pair as last time" cache check cannot match before a real base case
// has been computed.
36  lastQueryIndex(querySet.n_cols),
37  lastReferenceIndex(referenceSet.n_cols),
// Counters for the number of base cases and Score() calls performed.
38  baseCases(0),
39  scores(0)
40 {
41  // We must set the traversal info last query and reference node pointers to
42  // something that is both invalid (i.e. not a tree node) and not NULL. We'll
43  // use the this pointer.
44  traversalInfo.LastQueryNode() = (TreeType*) this;
45  traversalInfo.LastReferenceNode() = (TreeType*) this;
46 
47  // Let's build the list of candidate neighbors for each query point.
48  // It will be initialized with k candidates: (WorstDistance, size_t() - 1)
49  // The list of candidates will be updated when visiting new points with the
50  // BaseCase() method.
// size_t() - 1 wraps around to the largest size_t value and serves as the
// "no neighbor found yet" index.
51  const Candidate def = std::make_pair(SortPolicy::WorstDistance(),
52  size_t() - 1);
53 
54  std::vector<Candidate> vect(k, def);
55  CandidateList pqueue(CandidateCmp(), std::move(vect));
56 
// Give every query point its own copy of the k-element candidate queue.
57  candidates.reserve(querySet.n_cols);
58  for (size_t i = 0; i < querySet.n_cols; ++i)
59  candidates.push_back(pqueue);
60 }
61 
// Writes the final k candidates of every query point into the output matrices.
//
// neighbors and distances are resized to k x querySet.n_cols; column i holds
// the results for query point i. Each candidate queue keeps on top the
// candidate that InsertNeighbor() would evict next (i.e. the current k-th
// best). Popping therefore yields candidates from worst to best, and they are
// stored from row k - 1 up to row 0, so row 0 holds the best neighbor.
//
// NOTE(review): every candidate is popped, which leaves all candidate lists
// empty. A second call would invoke top() on an empty priority_queue (undefined
// behavior), so this should only be called once per search.
// NOTE(review): original source line 63, which names this member function, is
// missing from this listing. The semicolon after the closing brace is
// redundant.
62 template<typename SortPolicy, typename MetricType, typename TreeType>
64  arma::Mat<size_t>& neighbors,
65  arma::mat& distances)
66 {
67  neighbors.set_size(k, querySet.n_cols);
68  distances.set_size(k, querySet.n_cols);
69 
70  for (size_t i = 0; i < querySet.n_cols; ++i)
71  {
72  CandidateList& pqueue = candidates[i];
// j = 1 pops the k-th best candidate into row k - 1; j = k fills row 0.
73  for (size_t j = 1; j <= k; ++j)
74  {
75  neighbors(k - j, i) = pqueue.top().second;
76  distances(k - j, i) = pqueue.top().first;
77  pqueue.pop();
78  }
79  }
80 };
81 
// Computes the distance between query point queryIndex and reference point
// referenceIndex, offers the reference point as a neighbor candidate of the
// query point, and returns the distance.
//
// Special cases:
//  - sameSet and queryIndex == referenceIndex: returns 0.0 without evaluating
//    the metric, counting a base case, or inserting a candidate, so a point is
//    never reported as its own neighbor.
//  - Same (queryIndex, referenceIndex) pair as the last cached evaluation:
//    returns the cached distance without re-evaluating or re-inserting.
//
// NOTE(review): original source line 84 (presumably the return type and class
// qualifier) is missing from this listing.
82 template<typename SortPolicy, typename MetricType, typename TreeType>
83 inline force_inline // Absolutely MUST be inline so optimizations can happen.
85 BaseCase(const size_t queryIndex, const size_t referenceIndex)
86 {
87  // If the datasets are the same, then this search is only using one dataset
88  // and we should not return identical points.
89  if (sameSet && (queryIndex == referenceIndex))
90  return 0.0;
91 
92  // If we have already performed this base case, then do not perform it again.
93  if ((lastQueryIndex == queryIndex) && (lastReferenceIndex == referenceIndex))
94  return lastBaseCase;
95 
96  double distance = metric.Evaluate(querySet.col(queryIndex),
97  referenceSet.col(referenceIndex));
// Only real metric evaluations are counted.
98  ++baseCases;
99 
100  InsertNeighbor(queryIndex, referenceIndex, distance);
101 
102  // Cache this information for the next time BaseCase() is called.
103  lastQueryIndex = queryIndex;
104  lastReferenceIndex = referenceIndex;
105  lastBaseCase = distance;
106 
107  return distance;
108 }
109 
// Scores the combination of a single query point and a reference node.
//
// Returns DBL_MAX (prune) when no point in referenceNode can beat the query
// point's current k-th candidate distance, relaxed by epsilon. Otherwise
// returns SortPolicy::ConvertToScore() of the best possible point-to-node
// distance. Every call increments the score counter.
//
// NOTE(review): original source lines 111 (function name), 117 and 122 are
// missing from this listing. Lines 117 and 122 hold the conditions of the two
// `if` blocks opened at source lines 118 and 123. Judging by the comments,
// line 117 presumably tests whether the tree's first point is its centroid;
// confirm against the original source.
110 template<typename SortPolicy, typename MetricType, typename TreeType>
112  const size_t queryIndex,
113  TreeType& referenceNode)
114 {
115  ++scores; // Count number of Score() calls.
116  double distance;
118  {
119  // The first point in the tree is the centroid. So we can then calculate
120  // the base case between that and the query point.
121  double baseCase = -1.0;
123  {
124  // If the parent node is the same, then we have already calculated the
125  // base case.
126  if ((referenceNode.Parent() != NULL) &&
127  (referenceNode.Point(0) == referenceNode.Parent()->Point(0)))
128  baseCase = referenceNode.Parent()->Stat().LastDistance();
129  else
130  baseCase = BaseCase(queryIndex, referenceNode.Point(0));
131 
132  // Save this evaluation.
// Children that share this node's first point reuse it (see the parent
// check above).
133  referenceNode.Stat().LastDistance() = baseCase;
134  }
135 
// Combine the first-point distance with the node's furthest descendant
// distance to get the best possible distance to any point in the node.
136  distance = SortPolicy::CombineBest(baseCase,
137  referenceNode.FurthestDescendantDistance());
138  }
139  else
140  {
141  distance = SortPolicy::BestPointToNodeDistance(querySet.col(queryIndex),
142  &referenceNode);
143  }
144 
145  // Compare against the best k'th distance for this query point so far.
146  double bestDistance = candidates[queryIndex].top().first;
147  bestDistance = SortPolicy::Relax(bestDistance, epsilon);
148 
149  return (SortPolicy::IsBetter(distance, bestDistance)) ?
150  SortPolicy::ConvertToScore(distance) : DBL_MAX;
151 }
152 
// Asks SortPolicy which child of referenceNode is best for query point
// queryIndex and returns its answer (a size_t, presumably a child index).
// Counts as one score.
//
// NOTE(review): original source line 154 (return type and class qualifier) is
// missing from this listing.
153 template<typename SortPolicy, typename MetricType, typename TreeType>
155 GetBestChild(const size_t queryIndex, TreeType& referenceNode)
156 {
157  ++scores;
158  return SortPolicy::GetBestChild(querySet.col(queryIndex), referenceNode);
159 }
160 
// Asks SortPolicy which child of referenceNode is best for queryNode and
// returns its answer (a size_t, presumably a child index). Counts as one
// score.
//
// NOTE(review): original source line 162 (return type and class qualifier) is
// missing from this listing.
161 template<typename SortPolicy, typename MetricType, typename TreeType>
163 GetBestChild(const TreeType& queryNode, TreeType& referenceNode)
164 {
165  ++scores;
166  return SortPolicy::GetBestChild(queryNode, referenceNode);
167 }
168 
// Re-checks a previously computed point-to-node score, since the query point's
// candidate list may have improved since the score was computed.
//
// A score of DBL_MAX (already pruned) is returned unchanged. Otherwise the
// score is converted back to a distance and compared with the query point's
// current k-th candidate distance, relaxed by epsilon. oldScore is kept if it
// is still better; otherwise DBL_MAX (prune) is returned. The reference node
// itself is not consulted.
//
// NOTE(review): original source line 170 (return type, class qualifier and
// name) is missing from this listing.
169 template<typename SortPolicy, typename MetricType, typename TreeType>
171  const size_t queryIndex,
172  TreeType& /* referenceNode */,
173  const double oldScore) const
174 {
175  // If we are already pruning, still prune.
176  if (oldScore == DBL_MAX)
177  return oldScore;
178 
179  const double distance = SortPolicy::ConvertToDistance(oldScore);
180 
181  // Just check the score again against the distances.
182  double bestDistance = candidates[queryIndex].top().first;
183  bestDistance = SortPolicy::Relax(bestDistance, epsilon);
184 
185  return (SortPolicy::IsBetter(distance, bestDistance)) ? oldScore : DBL_MAX;
186 }
187 
// Scores the combination of a query node and a reference node during a
// dual-tree traversal.
//
// Steps:
//  1. Compute the query node's bound with CalculateBound().
//  2. Try a cheap prune. The previous combination stored in traversalInfo
//     (last score or last base case) is adjusted for how queryNode and
//     referenceNode relate to the last visited nodes (the same node, or a
//     child of it). If that adjusted value cannot beat the bound, return
//     DBL_MAX.
//  3. Otherwise compute the actual best node-to-node distance. This uses either
//     the nodes' first points (when the first point is the centroid) or
//     SortPolicy::BestNodeToNodeDistance().
//  4. If that distance beats the bound, record this combination in
//     traversalInfo and return the converted score. Otherwise return DBL_MAX
//     (prune).
//
// NOTE(review): original source lines 189 (function name), 211, 288, 298 and
// 304 are missing from this listing:
//  - 211 is the condition of the first branch that sets adjustedScore.
//  - 288 is an inner condition guarding the early prune.
//  - 298 is the condition selecting the first-point (centroid) computation.
//  - 304 is the first part of the condition continued on lines 305-306.
// Confirm their contents against the original source.
188 template<typename SortPolicy, typename MetricType, typename TreeType>
190  TreeType& queryNode,
191  TreeType& referenceNode)
192 {
193  ++scores; // Count number of Score() calls.
194 
195  // Update our bound.
196  const double bestDistance = CalculateBound(queryNode);
197 
198  // Use the traversal info to see if a parent-child or parent-parent prune is
199  // possible. This is a looser bound than we could make, but it might be
200  // sufficient.
201  const double queryParentDist = queryNode.ParentDistance();
202  const double queryDescDist = queryNode.FurthestDescendantDistance();
203  const double refParentDist = referenceNode.ParentDistance();
204  const double refDescDist = referenceNode.FurthestDescendantDistance();
205  const double score = traversalInfo.LastScore();
206  double adjustedScore;
207 
208  // We want to set adjustedScore to be the distance between the centroid of the
209  // last query node and last reference node. We will do this by adjusting the
210  // last score. In some cases, we can just use the last base case.
// NOTE(review): the condition for this branch (source line 211) is not
// visible in this listing.
212  {
213  adjustedScore = traversalInfo.LastBaseCase();
214  }
215  else if (score == 0.0) // Nothing we can do here.
216  {
217  adjustedScore = 0.0;
218  }
219  else
220  {
221  // The last score is equal to the distance between the centroids minus the
222  // radii of the query and reference bounds along the axis of the line
223  // between the two centroids. In the best case, these radii are the
224  // furthest descendant distances, but that is not always true. It would
225  // take too long to calculate the exact radii, so we are forced to use
226  // MinimumBoundDistance() as a lower-bound approximation.
227  const double lastQueryDescDist =
228  traversalInfo.LastQueryNode()->MinimumBoundDistance();
229  const double lastRefDescDist =
230  traversalInfo.LastReferenceNode()->MinimumBoundDistance();
231  adjustedScore = SortPolicy::CombineWorst(score, lastQueryDescDist);
232  adjustedScore = SortPolicy::CombineWorst(adjustedScore, lastRefDescDist);
233  }
234 
235  // Assemble an adjusted score. For nearest neighbor search, this adjusted
236  // score is a lower bound on MinDistance(queryNode, referenceNode) that is
237  // assembled without actually calculating MinDistance(). For furthest
238  // neighbor search, it is an upper bound on
239  // MaxDistance(queryNode, referenceNode). If the traversalInfo isn't usable
240  // then the node should not be pruned by this.
241  if (traversalInfo.LastQueryNode() == queryNode.Parent())
242  {
243  const double queryAdjust = queryParentDist + queryDescDist;
244  adjustedScore = SortPolicy::CombineBest(adjustedScore, queryAdjust);
245  }
246  else if (traversalInfo.LastQueryNode() == &queryNode)
247  {
248  adjustedScore = SortPolicy::CombineBest(adjustedScore, queryDescDist);
249  }
250  else
251  {
252  // The last query node wasn't this query node or its parent. So we force
253  // the adjustedScore to be such that this combination can't be pruned here,
254  // because we don't really know anything about it.
255 
256  // It would be possible to modify this section to try and make a prune based
257  // on the query descendant distance and the distance between the query node
258  // and last traversal query node, but this case doesn't actually happen for
259  // kd-trees or cover trees.
260  adjustedScore = SortPolicy::BestDistance();
261  }
262 
263  if (traversalInfo.LastReferenceNode() == referenceNode.Parent())
264  {
265  const double refAdjust = refParentDist + refDescDist;
266  adjustedScore = SortPolicy::CombineBest(adjustedScore, refAdjust);
267  }
268  else if (traversalInfo.LastReferenceNode() == &referenceNode)
269  {
270  adjustedScore = SortPolicy::CombineBest(adjustedScore, refDescDist);
271  }
272  else
273  {
274  // The last reference node wasn't this reference node or its parent. So we
275  // force the adjustedScore to be such that this combination can't be pruned
276  // here, because we don't really know anything about it.
277 
278  // It would be possible to modify this section to try and make a prune based
279  // on the reference descendant distance and the distance between the
280  // reference node and last traversal reference node, but this case doesn't
281  // actually happen for kd-trees or cover trees.
282  adjustedScore = SortPolicy::BestDistance();
283  }
284 
285  // Can we prune?
286  if (!SortPolicy::IsBetter(adjustedScore, bestDistance))
287  {
// NOTE(review): the inner condition guarding this early prune (source line
// 288) is not visible in this listing.
289  {
290  // There isn't any need to set the traversal information because no
291  // descendant combinations will be visited, and those are the only
292  // combinations that would depend on the traversal information.
293  return DBL_MAX;
294  }
295  }
296 
297  double distance;
// NOTE(review): the condition selecting this first-point branch (source line
// 298) is not visible in this listing.
299  {
300  // The first point in the node is the centroid, so we can calculate the
301  // distance between the two points using BaseCase() and then find the
302  // bounds. This is potentially loose for non-ball bounds.
303  double baseCase = -1.0;
// NOTE(review): the opening clause(s) of this condition (source line 304) are
// not visible in this listing.
305  (traversalInfo.LastQueryNode()->Point(0) == queryNode.Point(0)) &&
306  (traversalInfo.LastReferenceNode()->Point(0) == referenceNode.Point(0)))
307  {
308  // We already calculated it.
309  baseCase = traversalInfo.LastBaseCase();
310  }
311  else
312  {
313  baseCase = BaseCase(queryNode.Point(0), referenceNode.Point(0));
314  }
315 
316  distance = SortPolicy::CombineBest(baseCase,
317  queryNode.FurthestDescendantDistance() +
318  referenceNode.FurthestDescendantDistance());
319 
// Prime BaseCase()'s cache so a later BaseCase() call on these same two first
// points returns baseCase without another metric evaluation.
320  lastQueryIndex = queryNode.Point(0);
321  lastReferenceIndex = referenceNode.Point(0);
322  lastBaseCase = baseCase;
323 
324  traversalInfo.LastBaseCase() = baseCase;
325  }
326  else
327  {
328  distance = SortPolicy::BestNodeToNodeDistance(&queryNode, &referenceNode);
329  }
330 
331  if (SortPolicy::IsBetter(distance, bestDistance))
332  {
333  // Set traversal information.
334  traversalInfo.LastQueryNode() = &queryNode;
335  traversalInfo.LastReferenceNode() = &referenceNode;
336  traversalInfo.LastScore() = distance;
337 
338  return SortPolicy::ConvertToScore(distance);
339  }
340  else
341  {
342  // There isn't any need to set the traversal information because no
343  // descendant combinations will be visited, and those are the only
344  // combinations that would depend on the traversal information.
345  return DBL_MAX;
346  }
347 }
348 
// Re-checks a previously computed node-to-node score against the query node's
// freshly recalculated bound.
//
// Scores of DBL_MAX (already pruned) and 0.0 are returned unchanged. Otherwise
// the score is converted back to a distance and kept only if it still beats
// CalculateBound(queryNode); if not, DBL_MAX (prune) is returned. The
// reference node itself is not consulted.
//
// Although this method is const, CalculateBound() writes its recomputed bounds
// into queryNode.Stat(), so calling it updates the query node's cached
// statistics.
//
// NOTE(review): original source line 350 (return type, class qualifier and
// name) is missing from this listing.
349 template<typename SortPolicy, typename MetricType, typename TreeType>
351  TreeType& queryNode,
352  TreeType& /* referenceNode */,
353  const double oldScore) const
354 {
355  if (oldScore == DBL_MAX || oldScore == 0.0)
356  return oldScore;
357 
358  const double distance = SortPolicy::ConvertToDistance(oldScore);
359 
360  // Update our bound.
361  const double bestDistance = CalculateBound(queryNode);
362 
363  return (SortPolicy::IsBetter(distance, bestDistance)) ? oldScore : DBL_MAX;
364 }
365 
366 // Calculate the bound for a given query node in its current state and update
367 // it.
//
// Returns the value a node combination's best possible distance must beat to
// avoid pruning in Score() and Rescore().
//
// Side effect: stores the un-relaxed B_1, B_2 and B_aux values in
// queryNode.Stat().FirstBound(), SecondBound() and AuxBound(). Only the B_1
// value (worstDistance) is relaxed by epsilon before the final comparison;
// B_2 is returned unrelaxed.
//
// NOTE(review): original source lines 369 (return type, class qualifier and
// name) and 477 are missing from this listing. Line 477 holds the condition
// for the early `return worstDistance;` near the end. Judging by the comment
// before it, it is presumably a spill-tree check. Without that condition the
// final comparison, as listed here, would be unreachable.
368 template<typename SortPolicy, typename MetricType, typename TreeType>
370  CalculateBound(TreeType& queryNode) const
371 {
372  // This is an adapted form of the B(N_q) function in the paper
373  // ``Tree-Independent Dual-Tree Algorithms'' by Curtin et. al.; the goal is to
374  // place a bound on the worst possible distance a point combination could have
375  // to improve any of the current neighbor estimates. If the best possible
376  // distance between two nodes is greater than this bound, then the node
377  // combination can be pruned (see Score()).
378 
379  // There are a couple ways we can assemble a bound. For simplicity, this is
380  // described for nearest neighbor search (SortPolicy = NearestNeighborSort),
381  // but the code that is written is adapted for whichever SortPolicy.
382 
383  // First, we can consider the current worst neighbor candidate distance of any
384  // descendant point. This is assembled with 'worstDistance' by looping
385  // through the points held by the query node, and then by taking the cached
386  // worst distance from any child nodes (Stat().FirstBound()). This
387  // corresponds roughly to B_1(N_q) in the paper.
388 
389  // The other way of bounding is to use the triangle inequality. To do this,
390  // we find the current best kth-neighbor candidate distance of any descendant
391  // query point, and use the triangle inequality to place a bound on the
392  // distance that candidate would have to any other descendant query point.
393  // This corresponds roughly to B_2(N_q) in the paper, and is the bounding
394  // style for cover trees.
395 
396  // Then, to assemble the final bound, since both bounds are valid, we simply
397  // take the better of the two.
398 
399  double worstDistance = SortPolicy::BestDistance();
400  double bestPointDistance = SortPolicy::WorstDistance();
401 
402  // Loop over points held in the node.
403  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
404  {
405  const double distance = candidates[queryNode.Point(i)].top().first;
406  if (SortPolicy::IsBetter(worstDistance, distance))
407  worstDistance = distance;
408  if (SortPolicy::IsBetter(distance, bestPointDistance))
409  bestPointDistance = distance;
410  }
411 
412  double auxDistance = bestPointDistance;
413 
414  // Loop over children of the node, and use their cached information to
415  // assemble bounds.
416  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
417  {
418  const double firstBound = queryNode.Child(i).Stat().FirstBound();
419  const double auxBound = queryNode.Child(i).Stat().AuxBound();
420 
421  if (SortPolicy::IsBetter(worstDistance, firstBound))
422  worstDistance = firstBound;
423  if (SortPolicy::IsBetter(auxBound, auxDistance))
424  auxDistance = auxBound;
425  }
426 
427  // Add triangle inequality adjustment to best distance. It is possible this
428  // could be tighter for some certain types of trees.
429  double bestDistance = SortPolicy::CombineWorst(auxDistance,
430  2 * queryNode.FurthestDescendantDistance());
431 
432  // Add triangle inequality adjustment to best distance of points in node.
433  bestPointDistance = SortPolicy::CombineWorst(bestPointDistance,
434  queryNode.FurthestPointDistance() +
435  queryNode.FurthestDescendantDistance());
436 
437  if (SortPolicy::IsBetter(bestPointDistance, bestDistance))
438  bestDistance = bestPointDistance;
439 
440  // At this point:
441  // worstDistance holds the value of B_1(N_q).
442  // bestDistance holds the value of B_2(N_q).
443  // auxDistance holds the value of B_aux(N_q).
444 
445  // Now consider the parent bounds.
446  if (queryNode.Parent() != NULL)
447  {
448  // The parent's worst distance bound implies that the bound for this node
449  // must be at least as good. Thus, if the parent worst distance bound is
450  // better, then take it.
451  if (SortPolicy::IsBetter(queryNode.Parent()->Stat().FirstBound(),
452  worstDistance))
453  worstDistance = queryNode.Parent()->Stat().FirstBound();
454 
455  // The parent's best distance bound implies that the bound for this node
456  // must be at least as good. Thus, if the parent best distance bound is
457  // better, then take it.
458  if (SortPolicy::IsBetter(queryNode.Parent()->Stat().SecondBound(),
459  bestDistance))
460  bestDistance = queryNode.Parent()->Stat().SecondBound();
461  }
462 
463  // Could the existing bounds be better?
464  if (SortPolicy::IsBetter(queryNode.Stat().FirstBound(), worstDistance))
465  worstDistance = queryNode.Stat().FirstBound();
466  if (SortPolicy::IsBetter(queryNode.Stat().SecondBound(), bestDistance))
467  bestDistance = queryNode.Stat().SecondBound();
468 
469  // Cache bounds for later.
470  queryNode.Stat().FirstBound() = worstDistance;
471  queryNode.Stat().SecondBound() = bestDistance;
472  queryNode.Stat().AuxBound() = auxDistance;
473 
// Relax only after caching, so the cached FirstBound stays un-relaxed.
474  worstDistance = SortPolicy::Relax(worstDistance, epsilon);
475 
476  // We can't consider B_2 for Spill Trees.
478  return worstDistance;
479 
480  if (SortPolicy::IsBetter(worstDistance, bestDistance))
481  return worstDistance;
482  else
483  return bestDistance;
484 }
485 
// Offers (distance, neighbor) as a candidate neighbor of query point
// queryIndex.
//
// The query point's queue holds k candidates (it is seeded with k in the
// constructor), with the candidate to be evicted next on top. If the new
// candidate compares before the top under CandidateCmp, the top is removed and
// the new candidate is pushed, so the queue size stays k. Otherwise the
// candidate is discarded.
//
// NOTE(review): original source lines 494-495 (return type, class qualifier
// and function name) are missing from this listing. Per the Doxygen tooltip,
// the declaration is
// `void InsertNeighbor(const size_t queryIndex, const size_t neighbor, const double distance)`.
493 template<typename SortPolicy, typename MetricType, typename TreeType>
496  const size_t queryIndex,
497  const size_t neighbor,
498  const double distance)
499 {
500  CandidateList& pqueue = candidates[queryIndex];
501  Candidate c = std::make_pair(distance, neighbor);
502 
503  if (CandidateCmp()(c, pqueue.top()))
504  {
505  pqueue.pop();
506  pqueue.push(c);
507  }
508 }
509 
510 } // namespace neighbor
511 } // namespace mlpack
512 
513 #endif // MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP
std::priority_queue< Candidate, std::vector< Candidate >, CandidateCmp > CandidateList
Use a priority queue to represent the list of candidate neighbors.
Definition: neighbor_search_rules.hpp:184
std::vector< CandidateList > candidates
Set of candidate neighbors for each point.
Definition: neighbor_search_rules.hpp:187
void InsertNeighbor(const size_t queryIndex, const size_t neighbor, const double distance)
Helper function to insert a point into the list of candidate points.
Definition: neighbor_search_rules_impl.hpp:495
size_t lastQueryIndex
The last query point BaseCase() was called with.
Definition: neighbor_search_rules.hpp:202
Definition: is_spill_tree.hpp:21
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
TraversalInfoType traversalInfo
Traversal info for the parent combination; this is updated by the traversal before each call to Score().
Definition: neighbor_search_rules.hpp:215
const TreeType::Mat & referenceSet
The reference set.
Definition: neighbor_search_rules.hpp:166
std::pair< double, size_t > Candidate
Candidate represents a possible candidate neighbor (distance, index).
Definition: neighbor_search_rules.hpp:172
const size_t k
Number of neighbors to search for.
Definition: neighbor_search_rules.hpp:190
const double epsilon
Relative error to be considered in approximate search.
Definition: neighbor_search_rules.hpp:199
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
Get the distance from the query point to the reference point.
Definition: neighbor_search_rules_impl.hpp:85
void GetResults(arma::Mat< size_t > &neighbors, arma::mat &distances)
Store the list of candidates for each query point in the given matrices.
Definition: neighbor_search_rules_impl.hpp:63
double lastBaseCase
The last base case result.
Definition: neighbor_search_rules.hpp:206
size_t lastReferenceIndex
The last reference point BaseCase() was called with.
Definition: neighbor_search_rules.hpp:204
double LastBaseCase() const
Get the base case associated with the last node combination.
Definition: traversal_info.hpp:78
TreeType * LastQueryNode() const
Get the last query node.
Definition: traversal_info.hpp:63
double LastScore() const
Get the score associated with the last query and reference nodes.
Definition: traversal_info.hpp:73
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
Definition: neighbor_search_rules_impl.hpp:111
The TreeTraits class provides compile-time information on the characteristics of a given tree type.
Definition: tree_traits.hpp:77
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore) const
Re-evaluate the score for recursion order.
Definition: neighbor_search_rules_impl.hpp:170
NeighborSearchRules(const typename TreeType::Mat &referenceSet, const typename TreeType::Mat &querySet, const size_t k, MetricType &metric, const double epsilon=0, const bool sameSet=false)
Construct the NeighborSearchRules object.
Definition: neighbor_search_rules_impl.hpp:23
double CalculateBound(TreeType &queryNode) const
Recalculate the bound for a given query node.
Definition: neighbor_search_rules_impl.hpp:370
size_t GetBestChild(const size_t queryIndex, TreeType &referenceNode)
Get the child node with the best score.
Definition: neighbor_search_rules_impl.hpp:155
MetricType & metric
The instantiated metric.
Definition: neighbor_search_rules.hpp:193
size_t scores
The number of scores that have been performed.
Definition: neighbor_search_rules.hpp:211
size_t baseCases
The number of base cases that have been performed.
Definition: neighbor_search_rules.hpp:209
Definition of IsSpillTree.
const TreeType::Mat & querySet
The query set.
Definition: neighbor_search_rules.hpp:169
TreeType * LastReferenceNode() const
Get the last reference node.
Definition: traversal_info.hpp:68
Compare two candidates based on the distance.
Definition: neighbor_search_rules.hpp:175
bool sameSet
Denotes whether or not the reference and query sets are the same.
Definition: neighbor_search_rules.hpp:196