mlpack
range_search_rules_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_IMPL_HPP
13 #define MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_IMPL_HPP
14 
15 // In case it hasn't been included yet.
16 #include "range_search_rules.hpp"
17 
18 namespace mlpack {
19 namespace range {
20 
// Constructor: initialises the member of the same name from each argument and
// starts the base-case and score counters at zero.
// NOTE(review): the listing numbering jumps from 21 to 23, so the line that
// names the constructor is missing from this extract — restore it from the
// original header before building.
21 template<typename MetricType, typename TreeType>
23  const arma::mat& referenceSet,
24  const arma::mat& querySet,
25  const math::Range& range,
26  std::vector<std::vector<size_t> >& neighbors,
27  std::vector<std::vector<double> >& distances,
28  MetricType& metric,
29  const bool sameSet) :
30  referenceSet(referenceSet),
31  querySet(querySet),
32  range(range),
33  neighbors(neighbors),
34  distances(distances),
35  metric(metric),
36  sameSet(sameSet),
// The "last" indices start at n_cols, one past the largest valid column index,
// so presumably no real (query, reference) pair matches them before the first
// base case has been evaluated.
37  lastQueryIndex(querySet.n_cols),
38  lastReferenceIndex(referenceSet.n_cols),
39  baseCases(0),
40  scores(0)
41 {
42  // Nothing to do.
43 }
44 
// Base case for one (query point, reference point) pair: evaluates the metric
// between the two columns and, when the distance lies inside the search range,
// appends the reference index and the distance to the query point's result
// lists.  Returns the computed distance, or 0.0 when the pair is skipped.
// NOTE(review): the listing numbering jumps from 48 to 50, so the line carrying
// the return type and qualified function name is missing from this extract.
47 template<typename MetricType, typename TreeType>
48 inline force_inline
50  const size_t queryIndex,
51  const size_t referenceIndex)
52 {
53  // If the datasets are the same, don't return the point as in its own range.
54  if (sameSet && (queryIndex == referenceIndex))
55  return 0.0;
56 
57  // If we have just performed this base case, don't do it again.
58  if ((lastQueryIndex == queryIndex) && (lastReferenceIndex == referenceIndex))
59  return 0.0; // No value to return... this shouldn't do anything bad.
60 
61  const double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
62  referenceSet.unsafe_col(referenceIndex));
// Only distances that were actually evaluated are counted; the two early
// returns above leave the counter untouched.
63  ++baseCases;
64 
65  // Update last indices, so we don't accidentally perform a base case twice.
66  lastQueryIndex = queryIndex;
67  lastReferenceIndex = referenceIndex;
68 
69  if (range.Contains(distance))
70  {
71  neighbors[queryIndex].push_back(referenceIndex);
72  distances[queryIndex].push_back(distance);
73  }
74 
75  return distance;
76 }
77 
// Single-tree score for a query point against a reference node.  Bounds the
// distance from the query point to the node, then returns DBL_MAX when the
// node needs no further recursion (either its bound misses the search range,
// or the bound lies entirely inside the range and all descendants were added
// through AddResult()), and 0.0 otherwise.
// NOTE(review): the listing numbering skips 87 and 92, so the conditions that
// open the outer branch and the inner self-child branch are missing from this
// extract — restore them from the original header before building.
79 template<typename MetricType, typename TreeType>
80 double RangeSearchRules<MetricType, TreeType>::Score(const size_t queryIndex,
81  TreeType& referenceNode)
82 {
83  // We must get the minimum and maximum distances and store them in this
84  // object.
// Note: this local shadows the member result list that is also named
// `distances`; inside this function the name refers to the distance bound.
85  math::Range distances;
86 
88  {
89  // In this situation, we calculate the base case. So we should check to be
90  // sure we haven't already done that.
91  double baseCase;
93  (referenceNode.Parent() != NULL) &&
94  (referenceNode.Point(0) == referenceNode.Parent()->Point(0)))
95  {
96  // If the tree has self-children and this is a self-child, the base case
97  // was already calculated.
98  baseCase = referenceNode.Parent()->Stat().LastDistance();
// Record the pair so that BaseCase() and AddResult() treat it as already
// handled.
99  lastQueryIndex = queryIndex;
100  lastReferenceIndex = referenceNode.Point(0);
101  }
102  else
103  {
104  // We must calculate the base case by hand.
105  baseCase = BaseCase(queryIndex, referenceNode.Point(0));
106  }
107 
108  // This may be possibly loose for non-ball bound trees.
109  distances.Lo() = baseCase - referenceNode.FurthestDescendantDistance();
110  distances.Hi() = baseCase + referenceNode.FurthestDescendantDistance();
111 
112  // Update last distance calculation.
113  referenceNode.Stat().LastDistance() = baseCase;
114  }
115  else
116  {
// Only this path increments the `scores` counter; the branch above goes
// through the base-case distance instead.
117  distances = referenceNode.RangeDistance(querySet.unsafe_col(queryIndex));
118  ++scores;
119  }
120 
121  // If the ranges do not overlap, prune this node.
122  if (!distances.Contains(range))
123  return DBL_MAX;
124 
125  // In this case, all of the points in the reference node will be part of the
126  // results.
127  if ((distances.Lo() >= range.Lo()) && (distances.Hi() <= range.Hi()))
128  {
129  AddResult(queryIndex, referenceNode);
130  return DBL_MAX; // We don't need to go any deeper.
131  }
132 
133  // Otherwise the score doesn't matter. Recursion order is irrelevant in
134  // range search.
135  return 0.0;
136 }
137 
// Single-tree rescore: returns oldScore unchanged.  The query index and the
// reference node are deliberately unnamed because they are not used.
// NOTE(review): the listing numbering jumps from 139 to 141, so the line
// carrying the return type and qualified function name is missing from this
// extract.
139 template<typename MetricType, typename TreeType>
141  const size_t /* queryIndex */,
142  TreeType& /* referenceNode */,
143  const double oldScore) const
144 {
145  // If it wasn't pruned before, it isn't pruned now.
146  return oldScore;
147 }
148 
// Dual-tree score for a query node against a reference node.  Bounds the
// distance between the two nodes, then returns DBL_MAX when the pair needs no
// further recursion (either the bound misses the search range, or the bound
// lies entirely inside the range and the reference node was added to every
// query descendant through AddResult()), and 0.0 otherwise.
// NOTE(review): the listing numbering skips 151 and 155, so the first line of
// the signature (which should declare queryNode) and the condition opening the
// outer branch are missing from this extract — restore them from the original
// header before building.
150 template<typename MetricType, typename TreeType>
152  TreeType& referenceNode)
153 {
// Note: this local shadows the member result list that is also named
// `distances`; inside this function the name refers to the distance bound.
154  math::Range distances;
156  {
157  // It is possible that the base case has already been calculated.
158  double baseCase = 0.0;
// Reuse the cached base case only when both cached nodes exist and their
// first points equal the first points of the current node pair.
159  if ((traversalInfo.LastQueryNode() != NULL) &&
160  (traversalInfo.LastReferenceNode() != NULL) &&
161  (traversalInfo.LastQueryNode()->Point(0) == queryNode.Point(0)) &&
162  (traversalInfo.LastReferenceNode()->Point(0) == referenceNode.Point(0)))
163  {
164  baseCase = traversalInfo.LastBaseCase();
165 
166  // Make sure that if BaseCase() is called, we don't duplicate results.
167  lastQueryIndex = queryNode.Point(0);
168  lastReferenceIndex = referenceNode.Point(0);
169  }
170  else
171  {
172  // We must calculate the base case.
173  baseCase = BaseCase(queryNode.Point(0), referenceNode.Point(0));
174  }
175 
// Widen the point-to-point distance by the furthest-descendant distance of
// both nodes to get a bound for the whole node pair.
176  distances.Lo() = baseCase - queryNode.FurthestDescendantDistance()
177  - referenceNode.FurthestDescendantDistance();
178  distances.Hi() = baseCase + queryNode.FurthestDescendantDistance()
179  + referenceNode.FurthestDescendantDistance();
180 
181  // Update the last distances performed for the query and reference node.
182  traversalInfo.LastBaseCase() = baseCase;
183  }
184  else
185  {
186  // Just perform the calculation.
// Only this path increments the `scores` counter.
187  distances = referenceNode.RangeDistance(queryNode);
188  ++scores;
189  }
190 
191  // If the ranges do not overlap, prune this node.
192  if (!distances.Contains(range))
193  return DBL_MAX;
194 
195  // In this case, all of the points in the reference node will be part of all
196  // the results for each point in the query node.
197  if ((distances.Lo() >= range.Lo()) && (distances.Hi() <= range.Hi()))
198  {
199  for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
200  AddResult(queryNode.Descendant(i), referenceNode);
201  return DBL_MAX; // We don't need to go any deeper.
202  }
203 
204  // Otherwise the score doesn't matter. Recursion order is irrelevant in range
205  // search.
// The cached node pair is updated only on this non-pruned path; both early
// returns above leave LastQueryNode()/LastReferenceNode() as they were.
206  traversalInfo.LastQueryNode() = &queryNode;
207  traversalInfo.LastReferenceNode() = &referenceNode;
208  return 0.0;
209 }
210 
// Dual-tree rescore: returns oldScore unchanged.  Both node arguments are
// deliberately unnamed because they are not used.
// NOTE(review): the listing numbering jumps from 212 to 214, so the line
// carrying the return type and qualified function name is missing from this
// extract.
212 template<typename MetricType, typename TreeType>
214  TreeType& /* queryNode */,
215  TreeType& /* referenceNode */,
216  const double oldScore) const
217 {
218  // If it wasn't pruned before, it isn't pruned now.
219  return oldScore;
220 }
221 
224 template<typename MetricType, typename TreeType>
225 void RangeSearchRules<MetricType, TreeType>::AddResult(const size_t queryIndex,
226  TreeType& referenceNode)
227 {
228  // Some types of trees calculate the base case evaluation before Score() is
229  // called, so if the base case has already been calculated, then we must avoid
230  // adding that point to the results again.
231  size_t baseCaseMod = 0;
233  (queryIndex == lastQueryIndex) &&
234  (referenceNode.Point(0) == lastReferenceIndex))
235  {
236  baseCaseMod = 1;
237  }
238 
239  // Resize distances and neighbors vectors appropriately. We have to use
240  // reserve() and not resize(), because we don't know if we will encounter the
241  // case where the datasets and points are the same (and we skip in that case).
242  const size_t oldSize = neighbors[queryIndex].size();
243  neighbors[queryIndex].reserve(oldSize + referenceNode.NumDescendants() -
244  baseCaseMod);
245  distances[queryIndex].reserve(oldSize + referenceNode.NumDescendants() -
246  baseCaseMod);
247 
248  for (size_t i = baseCaseMod; i < referenceNode.NumDescendants(); ++i)
249  {
250  if ((&referenceSet == &querySet) &&
251  (queryIndex == referenceNode.Descendant(i)))
252  continue;
253 
254  const double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
255  referenceNode.Dataset().unsafe_col(referenceNode.Descendant(i)));
256 
257  neighbors[queryIndex].push_back(referenceNode.Descendant(i));
258  distances[queryIndex].push_back(distance);
259  }
260 }
261 
262 } // namespace range
263 } // namespace mlpack
264 
265 #endif
T Lo() const
Get the lower bound.
Definition: range.hpp:61
RangeSearchRules(const arma::mat &referenceSet, const arma::mat &querySet, const math::Range &range, std::vector< std::vector< size_t > > &neighbors, std::vector< std::vector< double > > &distances, MetricType &metric, const bool sameSet=false)
Construct the RangeSearchRules object.
Definition: range_search_rules_impl.hpp:22
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
Definition: range_search_rules_impl.hpp:80
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
Compute the base case between the given query point and reference point.
Definition: range_search_rules_impl.hpp:49
double LastBaseCase() const
Get the base case associated with the last node combination.
Definition: traversal_info.hpp:78
TreeType * LastQueryNode() const
Get the last query node.
Definition: traversal_info.hpp:63
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore) const
Re-evaluate the score for recursion order.
Definition: range_search_rules_impl.hpp:140
T Hi() const
Get the upper bound.
Definition: range.hpp:66
The TreeTraits class provides compile-time information on the characteristics of a given tree type.
Definition: tree_traits.hpp:77
The RangeSearchRules class is a template helper class used by the RangeSearch class when performing range searches.
Definition: range_search_rules.hpp:28
bool Contains(const T d) const
Determines if a point is contained within the range.
Definition: range_impl.hpp:187
TreeType * LastReferenceNode() const
Get the last reference node.
Definition: traversal_info.hpp:68