12 #ifndef MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_IMPL_HPP 13 #define MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_IMPL_HPP 21 template<
typename MetricType,
typename TreeType>
23 const arma::mat& referenceSet,
24 const arma::mat& querySet,
26 std::vector<std::vector<size_t> >& neighbors,
27 std::vector<std::vector<double> >& distances,
30 referenceSet(referenceSet),
37 lastQueryIndex(querySet.n_cols),
38 lastReferenceIndex(referenceSet.n_cols),
47 template<
typename MetricType,
typename TreeType>
50 const size_t queryIndex,
51 const size_t referenceIndex)
54 if (sameSet && (queryIndex == referenceIndex))
58 if ((lastQueryIndex == queryIndex) && (lastReferenceIndex == referenceIndex))
61 const double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
62 referenceSet.unsafe_col(referenceIndex));
66 lastQueryIndex = queryIndex;
67 lastReferenceIndex = referenceIndex;
71 neighbors[queryIndex].push_back(referenceIndex);
72 distances[queryIndex].push_back(distance);
79 template<
typename MetricType,
typename TreeType>
81 TreeType& referenceNode)
93 (referenceNode.Parent() != NULL) &&
94 (referenceNode.Point(0) == referenceNode.Parent()->Point(0)))
98 baseCase = referenceNode.Parent()->Stat().LastDistance();
99 lastQueryIndex = queryIndex;
100 lastReferenceIndex = referenceNode.Point(0);
105 baseCase =
BaseCase(queryIndex, referenceNode.Point(0));
109 distances.
Lo() = baseCase - referenceNode.FurthestDescendantDistance();
110 distances.
Hi() = baseCase + referenceNode.FurthestDescendantDistance();
113 referenceNode.Stat().LastDistance() = baseCase;
117 distances = referenceNode.RangeDistance(querySet.unsafe_col(queryIndex));
127 if ((distances.
Lo() >= range.
Lo()) && (distances.
Hi() <= range.
Hi()))
129 AddResult(queryIndex, referenceNode);
139 template<
typename MetricType,
typename TreeType>
143 const double oldScore)
const 150 template<
typename MetricType,
typename TreeType>
152 TreeType& referenceNode)
158 double baseCase = 0.0;
161 (traversalInfo.
LastQueryNode()->Point(0) == queryNode.Point(0)) &&
167 lastQueryIndex = queryNode.Point(0);
168 lastReferenceIndex = referenceNode.Point(0);
173 baseCase =
BaseCase(queryNode.Point(0), referenceNode.Point(0));
176 distances.
Lo() = baseCase - queryNode.FurthestDescendantDistance()
177 - referenceNode.FurthestDescendantDistance();
178 distances.
Hi() = baseCase + queryNode.FurthestDescendantDistance()
179 + referenceNode.FurthestDescendantDistance();
187 distances = referenceNode.RangeDistance(queryNode);
197 if ((distances.
Lo() >= range.
Lo()) && (distances.
Hi() <= range.
Hi()))
199 for (
size_t i = 0; i < queryNode.NumDescendants(); ++i)
200 AddResult(queryNode.Descendant(i), referenceNode);
212 template<
typename MetricType,
typename TreeType>
216 const double oldScore)
const 224 template<
typename MetricType,
typename TreeType>
226 TreeType& referenceNode)
231 size_t baseCaseMod = 0;
233 (queryIndex == lastQueryIndex) &&
234 (referenceNode.Point(0) == lastReferenceIndex))
242 const size_t oldSize = neighbors[queryIndex].size();
243 neighbors[queryIndex].reserve(oldSize + referenceNode.NumDescendants() -
245 distances[queryIndex].reserve(oldSize + referenceNode.NumDescendants() -
248 for (
size_t i = baseCaseMod; i < referenceNode.NumDescendants(); ++i)
250 if ((&referenceSet == &querySet) &&
251 (queryIndex == referenceNode.Descendant(i)))
254 const double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
255 referenceNode.Dataset().unsafe_col(referenceNode.Descendant(i)));
257 neighbors[queryIndex].push_back(referenceNode.Descendant(i));
258 distances[queryIndex].push_back(distance);
T Lo() const
Get the lower bound.
Definition: range.hpp:61
RangeSearchRules(const arma::mat &referenceSet, const arma::mat &querySet, const math::Range &range, std::vector< std::vector< size_t > > &neighbors, std::vector< std::vector< double > > &distances, MetricType &metric, const bool sameSet=false)
Construct the RangeSearchRules object.
Definition: range_search_rules_impl.hpp:22
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
Definition: range_search_rules_impl.hpp:80
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
Compute the base case between the given query point and reference point.
Definition: range_search_rules_impl.hpp:49
double LastBaseCase() const
Get the base case associated with the last node combination.
Definition: traversal_info.hpp:78
TreeType * LastQueryNode() const
Get the last query node.
Definition: traversal_info.hpp:63
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore) const
Re-evaluate the score for recursion order.
Definition: range_search_rules_impl.hpp:140
T Hi() const
Get the upper bound.
Definition: range.hpp:66
The TreeTraits class provides compile-time information on the characteristics of a given tree type...
Definition: tree_traits.hpp:77
The RangeSearchRules class is a template helper class used by RangeSearch class when performing range...
Definition: range_search_rules.hpp:28
bool Contains(const T d) const
Determines if a point is contained within the range.
Definition: range_impl.hpp:187
TreeType * LastReferenceNode() const
Get the last reference node.
Definition: traversal_info.hpp:68