mlpack
dual_tree_traverser_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_TREE_OCTREE_DUAL_TREE_TRAVERSER_IMPL_HPP
13 #define MLPACK_CORE_TREE_OCTREE_DUAL_TREE_TRAVERSER_IMPL_HPP
14 
15 // In case it hasn't been included yet.
16 #include "dual_tree_traverser.hpp"
17 
18 namespace mlpack {
19 namespace tree {
20 
21 template<typename MetricType, typename StatisticType, typename MatType>
22 template<typename RuleType>
23 Octree<MetricType, StatisticType, MatType>::DualTreeTraverser<RuleType>::
24  DualTreeTraverser(RuleType& rule) :
25  rule(rule),
26  numPrunes(0),
27  numVisited(0),
28  numScores(0),
29  numBaseCases(0)
30 {
31  // Nothing to do.
32 }
33 
34 template<typename MetricType, typename StatisticType, typename MatType>
35 template<typename RuleType>
37  Traverse(Octree& queryNode, Octree& referenceNode)
38 {
39  // Increment the visit counter.
40  ++numVisited;
41 
42  // Store the current traversal info.
43  traversalInfo = rule.TraversalInfo();
44 
45  // If both nodes are root nodes, just score them.
46  if (queryNode.Parent() == NULL && referenceNode.Parent() == NULL)
47  {
48  const double rootScore = rule.Score(queryNode, referenceNode);
49  // If root score is DBL_MAX, don't recurse.
50  if (rootScore == DBL_MAX)
51  {
52  ++numPrunes;
53  return;
54  }
55  }
56 
57  if (queryNode.IsLeaf() && referenceNode.IsLeaf())
58  {
59  const size_t begin = queryNode.Point(0);
60  const size_t end = begin + queryNode.NumPoints();
61  for (size_t q = begin; q < end; ++q)
62  {
63  // First, see if we can prune the reference node for this query point.
64  rule.TraversalInfo() = traversalInfo;
65  const double score = rule.Score(q, referenceNode);
66  if (score == DBL_MAX)
67  {
68  ++numPrunes;
69  continue;
70  }
71 
72  const size_t rBegin = referenceNode.Point(0);
73  const size_t rEnd = rBegin + referenceNode.NumPoints();
74  for (size_t r = rBegin; r < rEnd; ++r)
75  rule.BaseCase(q, r);
76 
77  numBaseCases += referenceNode.NumPoints();
78  }
79  }
80  else if (!queryNode.IsLeaf() && referenceNode.IsLeaf())
81  {
82  // We have to recurse down the query node. Order does not matter.
83  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
84  {
85  rule.TraversalInfo() = traversalInfo;
86  const double score = rule.Score(queryNode.Child(i), referenceNode);
87  if (score == DBL_MAX)
88  {
89  ++numPrunes;
90  continue;
91  }
92 
93  Traverse(queryNode.Child(i), referenceNode);
94  }
95  }
96  else if (queryNode.IsLeaf() && !referenceNode.IsLeaf())
97  {
98  // We have to recurse down the reference node, so we need to do it in an
99  // ordered manner.
100  arma::vec scores(referenceNode.NumChildren());
101  std::vector<typename RuleType::TraversalInfoType>
102  tis(referenceNode.NumChildren());
103  for (size_t i = 0; i < referenceNode.NumChildren(); ++i)
104  {
105  rule.TraversalInfo() = traversalInfo;
106  scores[i] = rule.Score(queryNode, referenceNode.Child(i));
107  tis[i] = rule.TraversalInfo();
108  }
109 
110  // Sort the scores.
111  arma::uvec scoreOrder = arma::sort_index(scores);
112  for (size_t i = 0; i < scoreOrder.n_elem; ++i)
113  {
114  if (scores[scoreOrder[i]] == DBL_MAX)
115  {
116  // We don't need to check any more---all children past here are pruned.
117  numPrunes += scoreOrder.n_elem - i;
118  break;
119  }
120 
121  rule.TraversalInfo() = tis[scoreOrder[i]];
122  Traverse(queryNode, referenceNode.Child(scoreOrder[i]));
123  }
124  }
125  else
126  {
127  // We have to recurse down both the query and reference nodes. Query order
128  // does not matter, so we will do that in sequence. However we will
129  // allocate the arrays for recursion at this level.
130  arma::vec scores(referenceNode.NumChildren());
131  std::vector<typename RuleType::TraversalInfoType>
132  tis(referenceNode.NumChildren());
133  for (size_t j = 0; j < queryNode.NumChildren(); ++j)
134  {
135  // Now we have to recurse down the reference node, which we will do in a
136  // prioritized manner.
137  for (size_t i = 0; i < referenceNode.NumChildren(); ++i)
138  {
139  rule.TraversalInfo() = traversalInfo;
140  scores[i] = rule.Score(queryNode.Child(j), referenceNode.Child(i));
141  tis[i] = rule.TraversalInfo();
142  }
143 
144  // Sort the scores.
145  arma::uvec scoreOrder = arma::sort_index(scores);
146  for (size_t i = 0; i < scoreOrder.n_elem; ++i)
147  {
148  if (scores[scoreOrder[i]] == DBL_MAX)
149  {
150  // We don't need to check any more
151  // All children past here are pruned.
152  numPrunes += scoreOrder.n_elem - i;
153  break;
154  }
155 
156  rule.TraversalInfo() = tis[scoreOrder[i]];
157  Traverse(queryNode.Child(j), referenceNode.Child(scoreOrder[i]));
158  }
159  }
160  }
161 }
162 
163 } // namespace tree
164 } // namespace mlpack
165 
166 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
size_t NumPoints() const
Return the number of points in this node (0 if not a leaf).
Definition: octree_impl.hpp:627
size_t NumChildren() const
Return the number of children in this node.
Definition: octree_impl.hpp:509
void Traverse(Octree &queryNode, Octree &referenceNode)
Traverse the two trees.
Definition: dual_tree_traverser_impl.hpp:37
size_t Point(const size_t index) const
Return the index (with reference to the dataset) of a particular point in this node.
Definition: octree_impl.hpp:647
bool IsLeaf() const
Return whether or not the node is a leaf.
Definition: octree.hpp:297
const Octree & Child(const size_t child) const
Return the specified child.
Definition: octree.hpp:340
Definition: octree.hpp:25
Octree * Parent() const
Get the pointer to the parent.
Definition: octree.hpp:256