mlpack
single_tree_traverser_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_TREE_OCTREE_SINGLE_TREE_TRAVERSER_IMPL_HPP
13 #define MLPACK_CORE_TREE_OCTREE_SINGLE_TREE_TRAVERSER_IMPL_HPP
14 
15 // In case it hasn't been included yet.
17 
18 namespace mlpack {
19 namespace tree {
20 
21 template<typename MetricType, typename StatisticType, typename MatType>
22 template<typename RuleType>
23 Octree<MetricType, StatisticType, MatType>::SingleTreeTraverser<RuleType>::
24  SingleTreeTraverser(RuleType& rule) :
25  rule(rule),
26  numPrunes(0)
27 {
28  // Nothing to do.
29 }
30 
31 template<typename MetricType, typename StatisticType, typename MatType>
32 template<typename RuleType>
34  Traverse(const size_t queryIndex, Octree& referenceNode)
35 {
36  // If we are a leaf, run the base cases.
37  if (referenceNode.NumChildren() == 0)
38  {
39  const size_t refBegin = referenceNode.Point(0);
40  const size_t refEnd = refBegin + referenceNode.NumPoints();
41  for (size_t r = refBegin; r < refEnd; ++r)
42  rule.BaseCase(queryIndex, r);
43  }
44  else
45  {
46  // If it's the root node, just score it.
47  if (referenceNode.Parent() == NULL)
48  {
49  const double rootScore = rule.Score(queryIndex, referenceNode);
50  // If root score is DBL_MAX, don't recurse into that node.
51  if (rootScore == DBL_MAX)
52  {
53  ++numPrunes;
54  return;
55  }
56  }
57 
58  // Do a prioritized recursion, by scoring all candidates and then sorting
59  // them.
60  arma::vec scores(referenceNode.NumChildren());
61  for (size_t i = 0; i < scores.n_elem; ++i)
62  scores[i] = rule.Score(queryIndex, referenceNode.Child(i));
63 
64  // Sort the scores.
65  arma::uvec sortedIndices = arma::sort_index(scores);
66 
67  for (size_t i = 0; i < sortedIndices.n_elem; ++i)
68  {
69  // If the node is pruned, all subsequent nodes in sorted order will also
70  // be pruned.
71  if (scores[sortedIndices[i]] == DBL_MAX)
72  {
73  numPrunes += (sortedIndices.n_elem - i);
74  break;
75  }
76 
77  Traverse(queryIndex, referenceNode.Child(sortedIndices[i]));
78  }
79  }
80 }
81 
82 } // namespace tree
83 } // namespace mlpack
84 
85 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
size_t NumPoints() const
Return the number of points in this node (0 if not a leaf).
Definition: octree_impl.hpp:627
size_t NumChildren() const
Return the number of children in this node.
Definition: octree_impl.hpp:509
void Traverse(const size_t queryIndex, Octree &referenceNode)
Traverse the reference tree with the given query point.
Definition: single_tree_traverser_impl.hpp:34
size_t Point(const size_t index) const
Return the index (with reference to the dataset) of a particular point in this node.
Definition: octree_impl.hpp:647
const Octree & Child(const size_t child) const
Return the specified child.
Definition: octree.hpp:340
Definition: octree.hpp:25
Octree * Parent() const
Get the pointer to the parent.
Definition: octree.hpp:256