12 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP 13 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP 22 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
24 const typename TreeType::Mat& referenceSet,
25 const typename TreeType::Mat& querySet,
30 referenceSet(referenceSet),
36 lastQueryIndex(querySet.n_cols),
37 lastReferenceIndex(referenceSet.n_cols),
51 const Candidate def = std::make_pair(SortPolicy::WorstDistance(),
54 std::vector<Candidate> vect(k, def);
58 for (
size_t i = 0; i < querySet.n_cols; ++i)
62 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
64 arma::Mat<size_t>& neighbors,
70 for (
size_t i = 0; i <
querySet.n_cols; ++i)
73 for (
size_t j = 1; j <=
k; ++j)
75 neighbors(
k - j, i) = pqueue.top().second;
76 distances(
k - j, i) = pqueue.top().first;
82 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
85 BaseCase(
const size_t queryIndex,
const size_t referenceIndex)
89 if (
sameSet && (queryIndex == referenceIndex))
110 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
112 const size_t queryIndex,
113 TreeType& referenceNode)
121 double baseCase = -1.0;
126 if ((referenceNode.Parent() != NULL) &&
127 (referenceNode.Point(0) == referenceNode.Parent()->Point(0)))
128 baseCase = referenceNode.Parent()->Stat().LastDistance();
130 baseCase =
BaseCase(queryIndex, referenceNode.Point(0));
133 referenceNode.Stat().LastDistance() = baseCase;
136 distance = SortPolicy::CombineBest(baseCase,
137 referenceNode.FurthestDescendantDistance());
141 distance = SortPolicy::BestPointToNodeDistance(
querySet.col(queryIndex),
146 double bestDistance =
candidates[queryIndex].top().first;
147 bestDistance = SortPolicy::Relax(bestDistance,
epsilon);
149 return (SortPolicy::IsBetter(distance, bestDistance)) ?
150 SortPolicy::ConvertToScore(distance) : DBL_MAX;
153 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
158 return SortPolicy::GetBestChild(
querySet.col(queryIndex), referenceNode);
161 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
166 return SortPolicy::GetBestChild(queryNode, referenceNode);
169 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
171 const size_t queryIndex,
173 const double oldScore)
const 176 if (oldScore == DBL_MAX)
179 const double distance = SortPolicy::ConvertToDistance(oldScore);
182 double bestDistance =
candidates[queryIndex].top().first;
183 bestDistance = SortPolicy::Relax(bestDistance,
epsilon);
185 return (SortPolicy::IsBetter(distance, bestDistance)) ? oldScore : DBL_MAX;
188 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
191 TreeType& referenceNode)
201 const double queryParentDist = queryNode.ParentDistance();
202 const double queryDescDist = queryNode.FurthestDescendantDistance();
203 const double refParentDist = referenceNode.ParentDistance();
204 const double refDescDist = referenceNode.FurthestDescendantDistance();
206 double adjustedScore;
215 else if (score == 0.0)
227 const double lastQueryDescDist =
229 const double lastRefDescDist =
231 adjustedScore = SortPolicy::CombineWorst(score, lastQueryDescDist);
232 adjustedScore = SortPolicy::CombineWorst(adjustedScore, lastRefDescDist);
243 const double queryAdjust = queryParentDist + queryDescDist;
244 adjustedScore = SortPolicy::CombineBest(adjustedScore, queryAdjust);
248 adjustedScore = SortPolicy::CombineBest(adjustedScore, queryDescDist);
260 adjustedScore = SortPolicy::BestDistance();
265 const double refAdjust = refParentDist + refDescDist;
266 adjustedScore = SortPolicy::CombineBest(adjustedScore, refAdjust);
270 adjustedScore = SortPolicy::CombineBest(adjustedScore, refDescDist);
282 adjustedScore = SortPolicy::BestDistance();
286 if (!SortPolicy::IsBetter(adjustedScore, bestDistance))
303 double baseCase = -1.0;
313 baseCase =
BaseCase(queryNode.Point(0), referenceNode.Point(0));
316 distance = SortPolicy::CombineBest(baseCase,
317 queryNode.FurthestDescendantDistance() +
318 referenceNode.FurthestDescendantDistance());
328 distance = SortPolicy::BestNodeToNodeDistance(&queryNode, &referenceNode);
331 if (SortPolicy::IsBetter(distance, bestDistance))
338 return SortPolicy::ConvertToScore(distance);
349 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
353 const double oldScore)
const 355 if (oldScore == DBL_MAX || oldScore == 0.0)
358 const double distance = SortPolicy::ConvertToDistance(oldScore);
363 return (SortPolicy::IsBetter(distance, bestDistance)) ? oldScore : DBL_MAX;
368 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
399 double worstDistance = SortPolicy::BestDistance();
400 double bestPointDistance = SortPolicy::WorstDistance();
403 for (
size_t i = 0; i < queryNode.NumPoints(); ++i)
405 const double distance =
candidates[queryNode.Point(i)].top().first;
406 if (SortPolicy::IsBetter(worstDistance, distance))
407 worstDistance = distance;
408 if (SortPolicy::IsBetter(distance, bestPointDistance))
409 bestPointDistance = distance;
412 double auxDistance = bestPointDistance;
416 for (
size_t i = 0; i < queryNode.NumChildren(); ++i)
418 const double firstBound = queryNode.Child(i).Stat().FirstBound();
419 const double auxBound = queryNode.Child(i).Stat().AuxBound();
421 if (SortPolicy::IsBetter(worstDistance, firstBound))
422 worstDistance = firstBound;
423 if (SortPolicy::IsBetter(auxBound, auxDistance))
424 auxDistance = auxBound;
429 double bestDistance = SortPolicy::CombineWorst(auxDistance,
430 2 * queryNode.FurthestDescendantDistance());
433 bestPointDistance = SortPolicy::CombineWorst(bestPointDistance,
434 queryNode.FurthestPointDistance() +
435 queryNode.FurthestDescendantDistance());
437 if (SortPolicy::IsBetter(bestPointDistance, bestDistance))
438 bestDistance = bestPointDistance;
446 if (queryNode.Parent() != NULL)
451 if (SortPolicy::IsBetter(queryNode.Parent()->Stat().FirstBound(),
453 worstDistance = queryNode.Parent()->Stat().FirstBound();
458 if (SortPolicy::IsBetter(queryNode.Parent()->Stat().SecondBound(),
460 bestDistance = queryNode.Parent()->Stat().SecondBound();
464 if (SortPolicy::IsBetter(queryNode.Stat().FirstBound(), worstDistance))
465 worstDistance = queryNode.Stat().FirstBound();
466 if (SortPolicy::IsBetter(queryNode.Stat().SecondBound(), bestDistance))
467 bestDistance = queryNode.Stat().SecondBound();
470 queryNode.Stat().FirstBound() = worstDistance;
471 queryNode.Stat().SecondBound() = bestDistance;
472 queryNode.Stat().AuxBound() = auxDistance;
474 worstDistance = SortPolicy::Relax(worstDistance,
epsilon);
478 return worstDistance;
480 if (SortPolicy::IsBetter(worstDistance, bestDistance))
481 return worstDistance;
493 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
496 const size_t queryIndex,
497 const size_t neighbor,
498 const double distance)
501 Candidate c = std::make_pair(distance, neighbor);
513 #endif // MLPACK_METHODS_NEIGHBOR_SEARCH_NEAREST_NEIGHBOR_RULES_IMPL_HPP std::priority_queue< Candidate, std::vector< Candidate >, CandidateCmp > CandidateList
Use a priority queue to represent the list of candidate neighbors.
Definition: neighbor_search_rules.hpp:184
std::vector< CandidateList > candidates
Set of candidate neighbors for each point.
Definition: neighbor_search_rules.hpp:187
void InsertNeighbor(const size_t queryIndex, const size_t neighbor, const double distance)
Helper function to insert a point into the list of candidate points.
Definition: neighbor_search_rules_impl.hpp:495
size_t lastQueryIndex
The last query point BaseCase() was called with.
Definition: neighbor_search_rules.hpp:202
Definition: is_spill_tree.hpp:21
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
TraversalInfoType traversalInfo
Traversal info for the parent combination; this is updated by the traversal before each call to Score...
Definition: neighbor_search_rules.hpp:215
const TreeType::Mat & referenceSet
The reference set.
Definition: neighbor_search_rules.hpp:166
std::pair< double, size_t > Candidate
Candidate represents a possible candidate neighbor (distance, index).
Definition: neighbor_search_rules.hpp:172
const size_t k
Number of neighbors to search for.
Definition: neighbor_search_rules.hpp:190
const double epsilon
Relative error to be considered in approximate search.
Definition: neighbor_search_rules.hpp:199
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
Get the distance from the query point to the reference point.
Definition: neighbor_search_rules_impl.hpp:85
void GetResults(arma::Mat< size_t > &neighbors, arma::mat &distances)
Store the list of candidates for each query point in the given matrices.
Definition: neighbor_search_rules_impl.hpp:63
double lastBaseCase
The last base case result.
Definition: neighbor_search_rules.hpp:206
size_t lastReferenceIndex
The last reference point BaseCase() was called with.
Definition: neighbor_search_rules.hpp:204
double LastBaseCase() const
Get the base case associated with the last node combination.
Definition: traversal_info.hpp:78
TreeType * LastQueryNode() const
Get the last query node.
Definition: traversal_info.hpp:63
double LastScore() const
Get the score associated with the last query and reference nodes.
Definition: traversal_info.hpp:73
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
Definition: neighbor_search_rules_impl.hpp:111
The TreeTraits class provides compile-time information on the characteristics of a given tree type...
Definition: tree_traits.hpp:77
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore) const
Re-evaluate the score for recursion order.
Definition: neighbor_search_rules_impl.hpp:170
NeighborSearchRules(const typename TreeType::Mat &referenceSet, const typename TreeType::Mat &querySet, const size_t k, MetricType &metric, const double epsilon=0, const bool sameSet=false)
Construct the NeighborSearchRules object.
Definition: neighbor_search_rules_impl.hpp:23
double CalculateBound(TreeType &queryNode) const
Recalculate the bound for a given query node.
Definition: neighbor_search_rules_impl.hpp:370
size_t GetBestChild(const size_t queryIndex, TreeType &referenceNode)
Get the child node with the best score.
Definition: neighbor_search_rules_impl.hpp:155
MetricType & metric
The instantiated metric.
Definition: neighbor_search_rules.hpp:193
size_t scores
The number of scores that have been performed.
Definition: neighbor_search_rules.hpp:211
size_t baseCases
The number of base cases that have been performed.
Definition: neighbor_search_rules.hpp:209
Definition of IsSpillTree.
const TreeType::Mat & querySet
The query set.
Definition: neighbor_search_rules.hpp:169
TreeType * LastReferenceNode() const
Get the last reference node.
Definition: traversal_info.hpp:68
Compare two candidates based on the distance.
Definition: neighbor_search_rules.hpp:175
bool sameSet
Denotes whether or not the reference and query sets are the same.
Definition: neighbor_search_rules.hpp:196