// RASearchRules constructor. This listing is fragmentary: the declarator
// line and several parameters/initializers are not visible. Visible work:
//  - derive the percentile size t from tau;
//  - emit diagnostics when t is small relative to k;
//  - reset the sampling counters and compute samplingRatio;
//  - seed one k-entry candidate queue per query point;
//  - run sampled base cases directly, under a condition not visible here.
12 #ifndef MLPACK_METHODS_RANN_RA_SEARCH_RULES_IMPL_HPP 13 #define MLPACK_METHODS_RANN_RA_SEARCH_RULES_IMPL_HPP 21 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
24 const arma::mat& querySet,
30 const bool sampleAtLeaves,
31 const bool firstLeafExact,
32 const size_t singleSampleLimit,
34 referenceSet(referenceSet),
38 sampleAtLeaves(sampleAtLeaves),
39 firstLeafExact(firstLeafExact),
40 singleSampleLimit(singleSampleLimit),
// t is the number of reference points inside the top-tau percentile. tau is
// a percentage, hence the division by 100; ceil() rounds a fractional point
// count up.
47 const size_t n = referenceSet.n_cols;
48 const size_t t = (size_t) std::ceil(tau * (
double) n / 100.0);
// NOTE(review): the if/else guards around the three diagnostics below are
// not visible. From the messages, t < k appears to be fatal (k neighbors
// cannot come from fewer than k points), and a t that equals k presumably
// degenerates to exact search. Confirm against the full source.
51 Log::Warn <<
"Rank-approximation percentile " << tau <<
" corresponds to " 52 << t <<
" points, which is less than k (" << k <<
").";
53 Log::Fatal <<
"Cannot return " << k <<
" approximate nearest neighbors " 54 <<
"from the nearest " << t <<
" points. Increase tau!" << std::endl;
57 Log::Warn <<
"Rank-approximation percentile " << tau <<
" corresponds to " 58 << t <<
" points; because k = " << k <<
", this is exact search!" 66 numSamplesMade = arma::zeros<arma::Col<size_t> >(querySet.n_cols);
// Per-query sample counts (numSamplesMade, zero-initialised on the previous
// line) and the total distance-computation counter both start at zero.
// samplingRatio is the fraction of the n reference points that each query
// must sample. numSamplesReqd is computed on lines not visible here,
// presumably via MinimumSamplesReqd(n, k, tau, alpha). Confirm.
67 numDistComputations = 0;
68 samplingRatio = (double) numSamplesReqd / (
double) n;
70 Log::Info <<
"Minimum samples required per query: " << numSamplesReqd <<
71 ", sampling ratio: " << samplingRatio << std::endl;
// Seed the candidate lists with k placeholder entries at distance
// SortPolicy::WorstDistance(), so that (presumably) any real candidate
// compares better. The second member of the placeholder pair is on a line
// not visible here.
77 const Candidate def = std::make_pair(SortPolicy::WorstDistance(),
80 std::vector<Candidate> vect(k, def);
81 CandidateList pqueue(CandidateCmp(), std::move(vect));
// Give every query point its own copy of the seeded queue.
83 candidates.reserve(querySet.n_cols);
84 for (
size_t i = 0; i < querySet.n_cols; ++i)
85 candidates.push_back(pqueue);
// Direct sampling path.
// NOTE(review): the guarding condition and the call that fills
// distinctSamples for each query are not visible. Presumably this is the
// naive mode: it skips tree traversal and runs BaseCase() against randomly
// chosen distinct reference indices. Confirm.
90 arma::uvec distinctSamples;
91 for (
size_t i = 0; i < querySet.n_cols; ++i)
94 for (
size_t j = 0; j < distinctSamples.n_elem; ++j)
95 BaseCase(i, (
size_t) distinctSamples[j]);
// GetResults(): write the k retained candidates of every query into the
// output matrices. Both matrices are resized to k x querySet.n_cols, and
// column i holds the results for query i.
100 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
102 arma::Mat<size_t>& neighbors,
103 arma::mat& distances)
105 neighbors.set_size(k, querySet.n_cols);
106 distances.set_size(k, querySet.n_cols);
108 for (
size_t i = 0; i < querySet.n_cols; ++i)
110 CandidateList& pqueue = candidates[i];
// Rows are filled from k - 1 down to 0 using the queue's top().
// InsertNeighbor() only admits a candidate that beats top(), so top()
// appears to be the worst retained candidate, and the best neighbor ends up
// in row 0.
// NOTE(review): the pop() that must follow each read is on a line not
// visible here. If it is there, GetResults() empties the candidate queues,
// and a second call would not return the same results. Confirm.
111 for (
size_t j = 1; j <= k; ++j)
113 neighbors(k - j, i) = pqueue.top().second;
114 distances(k - j, i) = pqueue.top().first;
// BaseCase(): evaluate the metric between query point queryIndex and
// reference point referenceIndex, and offer the pair to the query's
// candidate list. Each call counts as one sample for the query and one
// distance computation overall. The return statement is not visible in this
// listing.
120 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
123 const size_t queryIndex,
124 const size_t referenceIndex)
// When the query and reference sets are the same, a point is not paired
// with itself. The body of this branch is not visible (presumably an early
// return).
128 if (sameSet && (queryIndex == referenceIndex))
131 double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
132 referenceSet.unsafe_col(referenceIndex));
134 InsertNeighbor(queryIndex, referenceIndex, distance);
// Record the work: one more sample for this query, one more distance
// evaluation in total.
136 numSamplesMade[queryIndex]++;
138 numDistComputations++;
// Score(queryIndex, referenceNode): compute the best-case distance from the
// query point to the reference node. Then take the query's current pruning
// distance, which is the distance at the top of its candidate queue.
// Finally, defer to the four-argument Score() overload.
143 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
145 const size_t queryIndex,
146 TreeType& referenceNode)
// The remaining arguments of BestPointToNodeDistance() are on a line not
// visible in this listing.
148 const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
149 const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
151 const double bestDistance = candidates[queryIndex].top().first;
153 return Score(queryIndex, referenceNode, distance, bestDistance);
// Score(queryIndex, referenceNode, baseCaseResult): same as the two-argument
// point Score(), except that it passes the already-computed baseCaseResult
// to SortPolicy::BestPointToNodeDistance(). Presumably this lets the bound
// reuse that distance instead of recomputing it. Confirm.
156 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
158 const size_t queryIndex,
159 TreeType& referenceNode,
160 const double baseCaseResult)
162 const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
163 const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
164 &referenceNode, baseCaseResult);
// The query's current pruning distance: the top of its candidate queue.
165 const double bestDistance = candidates[queryIndex].top().first;
167 return Score(queryIndex, referenceNode, distance, bestDistance);
// Score(queryIndex, referenceNode, distance, bestDistance): the core
// single-tree sampling decision for one query point and one reference node.
// NOTE(review): the return statements and several branch conditions are on
// lines not visible in this listing. The comments below describe only what
// is shown.
170 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
172 const size_t queryIndex,
173 TreeType& referenceNode,
174 const double distance,
175 const double bestDistance)
// Work on this node only if it could still improve the query's candidates
// and the query has not yet made its required number of samples.
180 if (SortPolicy::IsBetter(distance, bestDistance)
181 && numSamplesMade[queryIndex] < numSamplesReqd)
// When firstLeafExact is set, nothing is sampled here until the query has
// made at least one sample. The alternative branch is not visible;
// presumably it descends so that the first leaf is searched exactly.
187 if (numSamplesMade[queryIndex] > 0 || !firstLeafExact)
// Samples needed from this node: ceil(samplingRatio * node size), capped at
// the number the query still needs. The subtraction cannot underflow,
// because the condition above requires
// numSamplesMade[queryIndex] < numSamplesReqd.
190 size_t samplesReqd = (size_t) std::ceil(samplingRatio *
191 (
double) referenceNode.NumDescendants());
192 samplesReqd = std::min(samplesReqd,
193 numSamplesReqd - numSamplesMade[queryIndex]);
// Too many samples for one internal node. The body is not visible;
// presumably the search recurses into the children instead of sampling
// here.
195 if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
// Internal node small enough to sample directly: draw up to samplesReqd
// distinct descendant indices and run BaseCase() on each. The head of the
// sampling call is not visible; presumably it is ObtainDistinctSamples().
202 if (!referenceNode.IsLeaf())
206 arma::uvec distinctSamples;
208 samplesReqd, distinctSamples);
209 for (
size_t i = 0; i < distinctSamples.n_elem; ++i)
212 BaseCase(queryIndex, referenceNode.Descendant(distinctSamples[i]));
// Leaf reference node: sampled the same way as the internal-node case. The
// guarding condition is not visible here (presumably sampleAtLeaves).
222 arma::uvec distinctSamples;
224 samplesReqd, distinctSamples);
225 for (
size_t i = 0; i < distinctSamples.n_elem; ++i)
229 referenceNode.Descendant(distinctSamples[i]));
// Branch not fully visible; presumably the prune path. The query is
// credited with floor(samplingRatio * node size) samples without computing
// any distances. Note that this uses floor(), whereas the sampling branches
// above use ceil().
259 numSamplesMade[queryIndex] += (size_t) std::floor(
260 samplingRatio * (
double) referenceNode.NumDescendants());
// Rescore(queryIndex, referenceNode, oldScore): re-evaluate a reference node
// that was scored earlier, now that the query's candidate list may have
// improved. oldScore is used in place of a freshly computed node distance.
// The queryIndex parameter line and most return statements are not visible
// in this listing.
266 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
269 TreeType& referenceNode,
270 const double oldScore)
// A node whose old score is DBL_MAX is left alone; the branch body is not
// visible.
// NOTE(review): this compares against the literal DBL_MAX rather than a
// SortPolicy-provided value. It only matches if the prune path returns
// exactly DBL_MAX. Confirm.
273 if (oldScore == DBL_MAX)
// Re-read the query's current pruning distance, which may have tightened
// since the node was first scored.
277 const double bestDistance = candidates[queryIndex].top().first;
282 if (SortPolicy::IsBetter(oldScore, bestDistance)
283 && numSamplesMade[queryIndex] < numSamplesReqd)
// Samples needed from this node: ceil(samplingRatio * node size), capped at
// the number the query still needs. The condition above guarantees the
// subtraction does not underflow.
293 size_t samplesReqd = (size_t) std::ceil(samplingRatio *
294 (
double) referenceNode.NumDescendants());
295 samplesReqd = std::min(samplesReqd, numSamplesReqd -
296 numSamplesMade[queryIndex]);
// Too many samples for one internal node. The body is not visible
// (presumably recurse instead of sampling here).
298 if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
// Internal node small enough to sample directly. The head of the sampling
// call is not visible; presumably it is ObtainDistinctSamples().
306 if (!referenceNode.IsLeaf())
310 arma::uvec distinctSamples;
312 samplesReqd, distinctSamples);
313 for (
size_t i = 0; i < distinctSamples.n_elem; ++i)
316 BaseCase(queryIndex, referenceNode.Descendant(distinctSamples[i]));
// Leaf reference node, sampled the same way. The guarding condition is not
// visible (presumably sampleAtLeaves).
326 arma::uvec distinctSamples;
328 samplesReqd, distinctSamples);
329 for (
size_t i = 0; i < distinctSamples.n_elem; ++i)
332 BaseCase(queryIndex, referenceNode.Descendant(distinctSamples[i]));
// Presumably the prune path: credit floor(samplingRatio * node size)
// samples to the query without computing any distances.
353 numSamplesMade[queryIndex] += (size_t) std::floor(samplingRatio *
354 (
double) referenceNode.NumDescendants());
// Score(queryNode, referenceNode): dual-tree scoring. It computes the
// best-case node-to-node distance, refreshes the query node's cached
// pruning bound (Stat().Bound()), and defers to the four-argument node
// Score(). The queryNode parameter line and the remaining arguments of
// BestNodeToNodeDistance() are not visible in this listing.
360 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
363 TreeType& referenceNode)
368 const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
// How the node bound is built:
//  - pointBound: over the points held directly by queryNode, the smallest
//    value of (top-of-queue distance + FurthestDescendantDistance()).
//  - childBound: the smallest cached bound among the children.
//  - node bound: the smaller of the two.
// NOTE(review): this uses raw `+`, `<`, std::min and DBL_MAX instead of
// SortPolicy operations, so it looks specific to nearest-neighbor ordering.
// NOTE(review): it also keeps the *smallest* child bound. A prune that must
// be valid for every query under this node would normally use the worst
// child bound. Confirm that this aggressiveness is intended for
// rank-approximate search.
371 double pointBound = DBL_MAX;
372 double childBound = DBL_MAX;
373 const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
375 for (
size_t i = 0; i < queryNode.NumPoints(); ++i)
377 const double bound = candidates[queryNode.Point(i)].top().first
378 + maxDescendantDistance;
379 if (bound < pointBound)
383 for (
size_t i = 0; i < queryNode.NumChildren(); ++i)
385 const double bound = queryNode.Child(i).Stat().Bound();
386 if (bound < childBound)
// Cache the refreshed bound on the node and use it as the pruning distance.
391 queryNode.Stat().Bound() = std::min(pointBound, childBound);
392 const double bestDistance = queryNode.Stat().Bound();
394 return Score(queryNode, referenceNode, distance, bestDistance);
// Score(queryNode, referenceNode, baseCaseResult): same as the two-argument
// dual-tree Score(), except that it passes baseCaseResult to
// SortPolicy::BestNodeToNodeDistance(). Presumably this lets the bound reuse
// an already-computed distance. Confirm.
397 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
400 TreeType& referenceNode,
401 const double baseCaseResult)
407 const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
408 &referenceNode, baseCaseResult);
// How the node bound is built:
//  - pointBound: the smallest (top-of-queue distance +
//    FurthestDescendantDistance()) over the points held directly by
//    queryNode.
//  - childBound: the smallest cached child bound.
// NOTE(review): this uses raw `+`, `<`, std::min and DBL_MAX rather than
// SortPolicy operations, and keeps the smallest child bound rather than the
// worst. Confirm that both choices are intended.
410 double pointBound = DBL_MAX;
411 double childBound = DBL_MAX;
412 const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
414 for (
size_t i = 0; i < queryNode.NumPoints(); ++i)
416 const double bound = candidates[queryNode.Point(i)].top().first
417 + maxDescendantDistance;
418 if (bound < pointBound)
422 for (
size_t i = 0; i < queryNode.NumChildren(); ++i)
424 const double bound = queryNode.Child(i).Stat().Bound();
425 if (bound < childBound)
// Cache the refreshed bound on the node and use it as the pruning distance.
430 queryNode.Stat().Bound() = std::min(pointBound, childBound);
431 const double bestDistance = queryNode.Stat().Bound();
433 return Score(queryNode, referenceNode, distance, bestDistance);
// Score(queryNode, referenceNode, distance, bestDistance): dual-tree
// sampling decision, the node-level counterpart of the single-point
// Score(). Sample counts for a query node are tracked in
// Stat().NumSamplesMade().
// NOTE(review): the queryNode parameter line, the return statements and
// several branch conditions are on lines not visible in this listing.
436 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
439 TreeType& referenceNode,
440 const double distance,
441 const double bestDistance)
// Pull sample counts up from the children. Take the minimum over the
// children, and raise the node's own count to that minimum if its count is
// lower.
449 if (!queryNode.IsLeaf())
451 size_t numSamplesMadeInChildNodes = std::numeric_limits<size_t>::max();
454 for (
size_t i = 0; i < queryNode.NumChildren(); ++i)
456 const size_t numSamples = queryNode.Child(i).Stat().NumSamplesMade();
457 if (numSamples < numSamplesMadeInChildNodes)
458 numSamplesMadeInChildNodes = numSamples;
464 queryNode.Stat().NumSamplesMade() = std::max(
465 queryNode.Stat().NumSamplesMade(), numSamplesMadeInChildNodes);
// Work on the reference node only if it could still improve candidates and
// this query node still needs samples.
473 if (SortPolicy::IsBetter(distance, bestDistance)
474 && queryNode.Stat().NumSamplesMade() < numSamplesReqd)
// When firstLeafExact is set, nothing is sampled until this node has made
// at least one sample. The alternative branch is not visible.
480 if (queryNode.Stat().NumSamplesMade() > 0 || !firstLeafExact)
// Samples per query: ceil(samplingRatio * reference node size), capped at
// the remaining requirement. The condition above prevents underflow.
483 size_t samplesReqd = (size_t) std::ceil(samplingRatio
484 * (
double) referenceNode.NumDescendants());
485 samplesReqd = std::min(samplesReqd, numSamplesReqd -
486 queryNode.Stat().NumSamplesMade());
// Too many samples to take here. The node's count is pushed down to its
// children: each child keeps the larger of its own count and the parent's.
// The rest of this branch is not visible (presumably recurse).
488 if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
497 for (
size_t i = 0; i < queryNode.NumChildren(); ++i)
498 queryNode.Child(i).Stat().NumSamplesMade() = std::max(
499 queryNode.Stat().NumSamplesMade(),
500 queryNode.Child(i).Stat().NumSamplesMade());
// Internal reference node small enough to sample directly. Every query
// point under queryNode independently draws up to samplesReqd distinct
// reference descendants and runs BaseCase() on each. The head of the
// sampling call is not visible; presumably it is ObtainDistinctSamples().
506 if (!referenceNode.IsLeaf())
510 arma::uvec distinctSamples;
511 for (
size_t i = 0; i < queryNode.NumDescendants(); ++i)
513 const size_t queryIndex = queryNode.Descendant(i);
515 samplesReqd, distinctSamples);
516 for (
size_t j = 0; j < distinctSamples.n_elem; ++j)
520 referenceNode.Descendant(distinctSamples[j]));
// NOTE(review): the node is credited with samplesReqd, not with
// distinctSamples.n_elem. If the sampler can return fewer samples than
// requested, the node's count overstates the work actually done. Confirm.
525 queryNode.Stat().NumSamplesMade() += samplesReqd;
// Leaf reference node, sampled the same way. The guarding condition is not
// visible (presumably sampleAtLeaves). The same crediting note applies to
// the += samplesReqd below.
540 arma::uvec distinctSamples;
541 for (
size_t i = 0; i < queryNode.NumDescendants(); ++i)
543 const size_t queryIndex = queryNode.Descendant(i);
545 samplesReqd, distinctSamples);
546 for (
size_t j = 0; j < distinctSamples.n_elem; ++j)
550 referenceNode.Descendant(distinctSamples[j]));
555 queryNode.Stat().NumSamplesMade() += samplesReqd;
// In branches not visible here, the node's count is pushed down to its
// children in the same way.
571 for (
size_t i = 0; i < queryNode.NumChildren(); ++i)
572 queryNode.Child(i).Stat().NumSamplesMade() = std::max(
573 queryNode.Stat().NumSamplesMade(),
574 queryNode.Child(i).Stat().NumSamplesMade());
586 for (
size_t i = 0; i < queryNode.NumChildren(); ++i)
587 queryNode.Child(i).Stat().NumSamplesMade() = std::max(
588 queryNode.Stat().NumSamplesMade(),
589 queryNode.Child(i).Stat().NumSamplesMade());
// Presumably the prune path: credit floor(samplingRatio * reference node
// size) samples to the query node without computing any distances.
603 queryNode.Stat().NumSamplesMade() += (size_t) std::floor(samplingRatio *
604 (
double) referenceNode.NumDescendants());
// Rescore(queryNode, referenceNode, oldScore): re-evaluate a previously
// scored node pair. It refreshes the query node's cached bound, then
// re-applies the dual-tree sampling logic with oldScore as the node
// distance.
// NOTE(review): the queryNode parameter line, the return statements and
// several branch conditions are not visible in this listing.
614 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
617 TreeType& referenceNode,
618 const double oldScore)
// A pair whose old score is DBL_MAX is left alone; the branch body is not
// visible.
// NOTE(review): this compares against the literal DBL_MAX, not a
// SortPolicy value. Confirm that the prune path returns exactly DBL_MAX.
620 if (oldScore == DBL_MAX)
// Refresh the node bound: the smaller of pointBound (the smallest
// top-of-queue distance + FurthestDescendantDistance() over the points held
// directly) and childBound (the smallest cached child bound).
// NOTE(review): this uses raw `+`, `<`, std::min and DBL_MAX rather than
// SortPolicy operations, and keeps the smallest child bound rather than the
// worst. Confirm that both choices are intended.
624 double pointBound = DBL_MAX;
625 double childBound = DBL_MAX;
626 const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
628 for (
size_t i = 0; i < queryNode.NumPoints(); ++i)
630 const double bound = candidates[queryNode.Point(i)].top().first
631 + maxDescendantDistance;
632 if (bound < pointBound)
636 for (
size_t i = 0; i < queryNode.NumChildren(); ++i)
638 const double bound = queryNode.Child(i).Stat().Bound();
639 if (bound < childBound)
644 queryNode.Stat().Bound() = std::min(pointBound, childBound);
645 const double bestDistance = queryNode.Stat().Bound();
// Pull sample counts up from the children. Take the minimum over the
// children, and raise the node's own count to that minimum if its count is
// lower.
654 if (!queryNode.IsLeaf())
656 size_t numSamplesMadeInChildNodes = std::numeric_limits<size_t>::max();
659 for (
size_t i = 0; i < queryNode.NumChildren(); ++i)
661 const size_t numSamples = queryNode.Child(i).Stat().NumSamplesMade();
662 if (numSamples < numSamplesMadeInChildNodes)
663 numSamplesMadeInChildNodes = numSamples;
669 queryNode.Stat().NumSamplesMade() = std::max(
670 queryNode.Stat().NumSamplesMade(), numSamplesMadeInChildNodes);
// Work on the pair only if it could still improve candidates and this query
// node still needs samples.
678 if (SortPolicy::IsBetter(oldScore, bestDistance) &&
679 queryNode.Stat().NumSamplesMade() < numSamplesReqd)
// Samples per query: ceil(samplingRatio * reference node size), capped at
// the remaining requirement. The condition above prevents underflow.
686 size_t samplesReqd = (size_t) std::ceil(
687 samplingRatio * (
double) referenceNode.NumDescendants());
688 samplesReqd = std::min(samplesReqd,
689 numSamplesReqd - queryNode.Stat().NumSamplesMade());
// Too many samples to take here. Push the node's count down to its
// children; the rest of this branch is not visible (presumably recurse).
691 if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
702 for (
size_t i = 0; i < queryNode.NumChildren(); ++i)
703 queryNode.Child(i).Stat().NumSamplesMade() = std::max(
704 queryNode.Stat().NumSamplesMade(),
705 queryNode.Child(i).Stat().NumSamplesMade());
// Internal reference node small enough to sample directly. Every query
// point under queryNode draws up to samplesReqd distinct reference
// descendants; the head of the sampling call is not visible.
// NOTE(review): the node is credited with samplesReqd rather than
// distinctSamples.n_elem. Confirm that the sampler always returns the full
// count.
711 if (!referenceNode.IsLeaf())
715 arma::uvec distinctSamples;
716 for (
size_t i = 0; i < queryNode.NumDescendants(); ++i)
718 const size_t queryIndex = queryNode.Descendant(i);
720 samplesReqd, distinctSamples);
721 for (
size_t j = 0; j < distinctSamples.n_elem; ++j)
724 BaseCase(queryIndex, referenceNode.Descendant(distinctSamples[j]));
729 queryNode.Stat().NumSamplesMade() += samplesReqd;
// Leaf reference node, sampled the same way. The guarding condition is not
// visible (presumably sampleAtLeaves).
744 arma::uvec distinctSamples;
745 for (
size_t i = 0; i < queryNode.NumDescendants(); ++i)
747 const size_t queryIndex = queryNode.Descendant(i);
749 samplesReqd, distinctSamples);
750 for (
size_t j = 0; j < distinctSamples.n_elem; ++j)
754 referenceNode.Descendant(distinctSamples[j]));
759 queryNode.Stat().NumSamplesMade() += samplesReqd;
// In a branch not visible here, the node's count is pushed down to its
// children.
772 for (
size_t i = 0; i < queryNode.NumChildren(); ++i)
773 queryNode.Child(i).Stat().NumSamplesMade() = std::max(
774 queryNode.Stat().NumSamplesMade(),
775 queryNode.Child(i).Stat().NumSamplesMade());
// Presumably the prune path: credit floor(samplingRatio * reference node
// size) samples to the query node without computing any distances.
791 queryNode.Stat().NumSamplesMade() += (size_t) std::floor(samplingRatio *
792 (
double) referenceNode.NumDescendants());
// InsertNeighbor(): offer the pair (distance, neighbor) to the candidate
// queue of query queryIndex. The candidate is considered only if
// CandidateCmp ranks it ahead of the queue's current top(). The body of that
// branch is not visible here; presumably it pops the top and pushes the new
// candidate, keeping the queue at k entries.
808 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
811 const size_t queryIndex,
812 const size_t neighbor,
813 const double distance)
815 CandidateList& pqueue = candidates[queryIndex];
816 Candidate c = std::make_pair(distance, neighbor);
818 if (CandidateCmp()(c, pqueue.top()))
828 #endif // MLPACK_METHODS_RANN_RA_SEARCH_RULES_IMPL_HPP void ObtainDistinctSamples(const size_t loInclusive, const size_t hiExclusive, const size_t maxNumSamples, arma::uvec &distinctSamples)
Obtains no more than maxNumSamples distinct samples.
Definition: random.hpp:153
void GetResults(arma::Mat< size_t > &neighbors, arma::mat &distances)
Store the list of candidates for each query point in the given matrices.
Definition: ra_search_rules_impl.hpp:101
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore)
Re-evaluate the score for recursion order.
Definition: ra_search_rules_impl.hpp:268
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
Definition: ra_search_rules_impl.hpp:144
static MLPACK_EXPORT util::PrefixedOutStream Warn
Prints warning messages prefixed with [WARN ].
Definition: log.hpp:87
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
Get the distance from the query point to the reference point.
Definition: ra_search_rules_impl.hpp:122
static size_t MinimumSamplesReqd(const size_t n, const size_t k, const double tau, const double alpha)
Compute the minimum number of samples required to guarantee the given rank-approximation and success probability.
Definition: ra_util.cpp:18
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
RASearchRules(const arma::mat &referenceSet, const arma::mat &querySet, const size_t k, MetricType &metric, const double tau=5, const double alpha=0.95, const bool naive=false, const bool sampleAtLeaves=false, const bool firstLeafExact=false, const size_t singleSampleLimit=20, const bool sameSet=false)
Construct the RASearchRules object.
Definition: ra_search_rules_impl.hpp:23
The RASearchRules class is a template helper class used by RASearch class when performing rank-approximate search via random sampling.
Definition: ra_search_rules.hpp:33