mlpack
ra_search_rules.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
15 #define MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
16 
18 
19 #include <queue>
20 
21 namespace mlpack {
22 namespace neighbor {
23 
32 template<typename SortPolicy, typename MetricType, typename TreeType>
34 {
35  public:
57  RASearchRules(const arma::mat& referenceSet,
58  const arma::mat& querySet,
59  const size_t k,
60  MetricType& metric,
61  const double tau = 5,
62  const double alpha = 0.95,
63  const bool naive = false,
64  const bool sampleAtLeaves = false,
65  const bool firstLeafExact = false,
66  const size_t singleSampleLimit = 20,
67  const bool sameSet = false);
68 
76  void GetResults(arma::Mat<size_t>& neighbors, arma::mat& distances);
77 
85  double BaseCase(const size_t queryIndex, const size_t referenceIndex);
86 
109  double Score(const size_t queryIndex, TreeType& referenceNode);
110 
134  double Score(const size_t queryIndex,
135  TreeType& referenceNode,
136  const double baseCaseResult);
137 
155  double Rescore(const size_t queryIndex,
156  TreeType& referenceNode,
157  const double oldScore);
158 
177  double Score(TreeType& queryNode, TreeType& referenceNode);
178 
199  double Score(TreeType& queryNode,
200  TreeType& referenceNode,
201  const double baseCaseResult);
202 
225  double Rescore(TreeType& queryNode,
226  TreeType& referenceNode,
227  const double oldScore);
228 
229 
230  size_t NumDistComputations() { return numDistComputations; }
231  size_t NumEffectiveSamples()
232  {
233  if (numSamplesMade.n_elem == 0)
234  return 0;
235  else
236  return arma::sum(numSamplesMade);
237  }
238 
240 
241  const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
242  TraversalInfoType& TraversalInfo() { return traversalInfo; }
243 
247  size_t MinimumBaseCases() const { return k; }
248 
249  private:
251  const arma::mat& referenceSet;
252 
254  const arma::mat& querySet;
255 
257  typedef std::pair<double, size_t> Candidate;
258 
260  struct CandidateCmp {
261  bool operator()(const Candidate& c1, const Candidate& c2)
262  {
263  return !SortPolicy::IsBetter(c2.first, c1.first);
264  };
265  };
266 
268  typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
269  CandidateList;
270 
272  std::vector<CandidateList> candidates;
273 
275  const size_t k;
276 
278  MetricType& metric;
279 
281  bool sampleAtLeaves;
282 
284  bool firstLeafExact;
285 
287  size_t singleSampleLimit;
288 
290  size_t numSamplesReqd;
291 
293  arma::Col<size_t> numSamplesMade;
294 
296  double samplingRatio;
297 
299  size_t numDistComputations;
300 
302  bool sameSet;
303 
304  TraversalInfoType traversalInfo;
305 
313  void InsertNeighbor(const size_t queryIndex,
314  const size_t neighbor,
315  const double distance);
316 
320  double Score(const size_t queryIndex,
321  TreeType& referenceNode,
322  const double distance,
323  const double bestDistance);
324 
328  double Score(TreeType& queryNode,
329  TreeType& referenceNode,
330  const double distance,
331  const double bestDistance);
332 
333  static_assert(tree::TreeTraits<TreeType>::UniqueNumDescendants, "TreeType "
334  "must provide a unique number of descendants points.");
335 }; // class RASearchRules
336 
337 } // namespace neighbor
338 } // namespace mlpack
339 
340 // Include implementation.
341 #include "ra_search_rules_impl.hpp"
342 
343 #endif // MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
void GetResults(arma::Mat< size_t > &neighbors, arma::mat &distances)
Store the list of candidates for each query point in the given matrices.
Definition: ra_search_rules_impl.hpp:101
The TraversalInfo class holds traversal information which is used in dual-tree (and single-tree) traversals.
Definition: traversal_info.hpp:50
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore)
Re-evaluate the score for recursion order.
Definition: ra_search_rules_impl.hpp:268
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
Definition: ra_search_rules_impl.hpp:144
size_t MinimumBaseCases() const
Get the minimum number of base cases that must be performed for each query point for an acceptable result.
Definition: ra_search_rules.hpp:247
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
Get the distance from the query point to the reference point.
Definition: ra_search_rules_impl.hpp:122
The TreeTraits class provides compile-time information on the characteristics of a given tree type.
Definition: tree_traits.hpp:77
RASearchRules(const arma::mat &referenceSet, const arma::mat &querySet, const size_t k, MetricType &metric, const double tau=5, const double alpha=0.95, const bool naive=false, const bool sampleAtLeaves=false, const bool firstLeafExact=false, const size_t singleSampleLimit=20, const bool sameSet=false)
Construct the RASearchRules object.
Definition: ra_search_rules_impl.hpp:23
The RASearchRules class is a template helper class used by the RASearch class when performing rank-approximate search.
Definition: ra_search_rules.hpp:33