mlpack
dual_tree_traverser_impl.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPAC_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
15 #define MLPAC_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
16 
17 #include "dual_tree_traverser.hpp"
18 
19 #include <algorithm>
20 #include <stack>
21 
22 namespace mlpack {
23 namespace tree {
24 
// Constructor for the rectangle-tree dual-tree traverser.
//
// Initializes the `rule` member from the given rule object and resets the four
// bookkeeping counters (numPrunes, numVisited, numScores, numBaseCases) to 0.
//
// rule: the rule object whose Score(), Rescore(), BaseCase() and
//       TraversalInfo() members are called by Traverse().
// NOTE(review): whether `rule` is held by reference or by copy depends on the
// member declaration in dual_tree_traverser.hpp, which is not visible here.
25 template<typename MetricType,
26  typename StatisticType,
27  typename MatType,
28  typename SplitType,
29  typename DescentType,
30  template<typename> class AuxiliaryInformationType>
31 template<typename RuleType>
32 RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
33  AuxiliaryInformationType>::
34 DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
35  rule(rule),
36  numPrunes(0),
37  numVisited(0),
38  numScores(0),
39  numBaseCases(0)
40 { /* Every member is set in the initializer list; the body is empty. */ }
41 
// Recursively traverse a query node against a reference node.
//
// Each call increments numVisited, snapshots rule.TraversalInfo() into the
// traversalInfo member, and then handles one of four cases depending on which
// of the two nodes are leaves.  A score equal to DBL_MAX (from Score() or
// Rescore()) is treated as "prune this combination".
//
// queryNode:     node of the query tree to descend.
// referenceNode: node of the reference tree to descend.
//
// Counters updated here: numVisited, numScores, numPrunes, numBaseCases.
//
// NOTE(review): listing line 51 (the `DualTreeTraverser<RuleType>::Traverse(
// RectangleTree& queryNode,` part of the declarator) is absent from this
// listing; the cross-reference index at the end of the page gives the
// signature as `void Traverse(RectangleTree &queryNode,
// RectangleTree &referenceNode)`.
// NOTE(review): DBL_MAX is used below but <cfloat> is not among the includes
// visible in this file; presumably it arrives transitively -- confirm.
42 template<typename MetricType,
43  typename StatisticType,
44  typename MatType,
45  typename SplitType,
46  typename DescentType,
47  template<typename> class AuxiliaryInformationType>
48 template<typename RuleType>
49 void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
50  AuxiliaryInformationType>::
52  RectangleTree& referenceNode)
53 {
54  // Count this node combination as visited.
55  ++numVisited;
56 
57  // Snapshot the rule's traversal info so it can be restored before every
58  // Score() call made on behalf of this node combination.
58  traversalInfo = rule.TraversalInfo();
59 
60  // We now have four options.
61  // 1) Both nodes are leaf nodes.
62  // 2) Only the reference node is a leaf node.
63  // 3) Only the query node is a leaf node.
64  // 4) Neither node is a leaf node.
65  // We go through those options in that order.
66 
67  if (queryNode.IsLeaf() && referenceNode.IsLeaf())
68  {
69  // Evaluate the base case. Do the query points on the outside so we can
70  // possibly prune the reference node for that particular point.
71  for (size_t query = 0; query < queryNode.Count(); ++query)
72  {
73  // Restore the traversal information.
74  rule.TraversalInfo() = traversalInfo;
  // Single-point score against the whole reference leaf.  This call is not
  // added to numScores, and a skip below is not added to numPrunes.
75  const double childScore = rule.Score(queryNode.Point(query),
76  referenceNode);
77 
78  if (childScore == DBL_MAX)
79  continue; // We don't require a search in this reference node.
80 
81  for (size_t ref = 0; ref < referenceNode.Count(); ++ref)
82  rule.BaseCase(queryNode.Point(query), referenceNode.Point(ref));
83 
  // One base case was evaluated per reference point for this query point.
84  numBaseCases += referenceNode.Count();
85  }
86  }
87  else if (!queryNode.IsLeaf() && referenceNode.IsLeaf())
88  {
89  // We only need to traverse down the query node. Order doesn't matter here.
90  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
91  {
92  // Before recursing, we have to set the traversal information correctly.
93  rule.TraversalInfo() = traversalInfo;
94  ++numScores;
95  if (rule.Score(queryNode.Child(i), referenceNode) < DBL_MAX)
96  Traverse(queryNode.Child(i), referenceNode);
97  else
98  numPrunes++;
99  }
100  }
101  else if (queryNode.IsLeaf() && !referenceNode.IsLeaf())
102  {
103  // We only need to traverse down the reference node. Order does matter
104  // here.
105 
106  // We sort the children of the reference node by their scores.
107  std::vector<NodeAndScore> nodesAndScores(referenceNode.NumChildren());
108  for (size_t i = 0; i < referenceNode.NumChildren(); ++i)
109  {
  // Score each reference child from the same starting traversal info, and
  // keep the traversal info that Score() leaves behind for that child.
110  rule.TraversalInfo() = traversalInfo;
111  nodesAndScores[i].node = &(referenceNode.Child(i));
112  nodesAndScores[i].score = rule.Score(queryNode,
113  *(nodesAndScores[i].node));
114  nodesAndScores[i].travInfo = rule.TraversalInfo();
115  }
  // nodeComparator is defined elsewhere; presumably it orders by ascending
  // score, which is what the early break below relies on -- confirm.
116  std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
117  numScores += nodesAndScores.size();
118 
119  for (size_t i = 0; i < nodesAndScores.size(); ++i)
120  {
  // Reinstate the traversal info recorded when this child was scored.
121  rule.TraversalInfo() = nodesAndScores[i].travInfo;
122  if (rule.Rescore(queryNode, *(nodesAndScores[i].node),
123  nodesAndScores[i].score) < DBL_MAX)
124  {
125  Traverse(queryNode, *(nodesAndScores[i].node));
126  }
127  else
128  {
  // This child and every child after it in sorted order are counted as
  // pruned, and the loop stops.
129  numPrunes += nodesAndScores.size() - i;
130  break;
131  }
132  }
133  }
134  else
135  {
136  // We need to traverse down both the reference and the query trees.
137  // We loop through all of the query nodes, and for each of them, we
138  // loop through the reference nodes to see where we need to descend.
139  for (size_t j = 0; j < queryNode.NumChildren(); ++j)
140  {
141  // We sort the children of the reference node by their scores.
142  std::vector<NodeAndScore> nodesAndScores(referenceNode.NumChildren());
143  for (size_t i = 0; i < referenceNode.NumChildren(); ++i)
144  {
  // Same scoring pattern as the query-leaf case above, but against query
  // child j instead of the query node itself.
145  rule.TraversalInfo() = traversalInfo;
146  nodesAndScores[i].node = &(referenceNode.Child(i));
147  nodesAndScores[i].score = rule.Score(queryNode.Child(j),
148  *nodesAndScores[i].node);
149  nodesAndScores[i].travInfo = rule.TraversalInfo();
150  }
151  std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
152  numScores += nodesAndScores.size();
153 
154  for (size_t i = 0; i < nodesAndScores.size(); ++i)
155  {
  // Reinstate the traversal info recorded when this pair was scored.
156  rule.TraversalInfo() = nodesAndScores[i].travInfo;
157  if (rule.Rescore(queryNode.Child(j), *(nodesAndScores[i].node),
158  nodesAndScores[i].score) < DBL_MAX)
159  {
160  Traverse(queryNode.Child(j), *(nodesAndScores[i].node));
161  }
162  else
163  {
  // Count this and all remaining reference children as pruned for query
  // child j, then move on to the next query child.
164  numPrunes += nodesAndScores.size() - i;
165  break;
166  }
167  }
168  }
169  }
170 }
171 
172 } // namespace tree
173 } // namespace mlpack
174 
175 #endif
size_t Point(const size_t index) const
Return the index (with reference to the dataset) of a particular point in this node.
Definition: rectangle_tree.hpp:480
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
friend DescentType
Give friend access for DescentType.
Definition: rectangle_tree.hpp:580
size_t NumChildren() const
Return the number of child nodes. (One level beneath this one only.)
Definition: rectangle_tree.hpp:371
friend SplitType
Give friend access for SplitType.
Definition: rectangle_tree.hpp:583
void Traverse(RectangleTree &queryNode, RectangleTree &referenceNode)
Traverse the two trees.
Definition: dual_tree_traverser_impl.hpp:51
A rectangle type tree, such as an R-tree or X-tree.
Definition: rectangle_tree.hpp:54
A dual tree traverser for rectangle type trees.
Definition: dual_tree_traverser.hpp:31
bool IsLeaf() const
Return whether or not this node is a leaf (true if it has no children).
Definition: rectangle_tree_impl.hpp:760
RectangleTree & Child(const size_t child) const
Get the specified child.
Definition: rectangle_tree.hpp:437
size_t Count() const
Return the number of points in this subset.
Definition: rectangle_tree.hpp:548