mlpack
breadth_first_dual_tree_traverser_impl.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_BF_DUAL_TREE_TRAVERSER_IMPL_HPP
15 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_BF_DUAL_TREE_TRAVERSER_IMPL_HPP
16 
17 // In case it hasn't been included yet.
19 
20 namespace mlpack {
21 namespace tree {
22 
// Constructor of the breadth-first dual-tree traverser.
//
// Keeps hold of the rule set passed in by the caller and resets the four
// traversal statistics (prunes, visited nodes, scores and base cases) to zero.
// The body itself is intentionally empty; all work happens in the initializer
// list.
//
// NOTE(review): the listing numbers embedded in this copy jump from 29 to 32,
// so the lines carrying the qualified constructor name are missing here.
// Restore them from the original header before attempting to compile.
23 template<typename MetricType,
24  typename StatisticType,
25  typename MatType,
26  template<typename BoundMetricType, typename...> class BoundType,
27  template<typename SplitBoundType, typename SplitMatType>
28  class SplitType>
29 template<typename RuleType>
32  RuleType& rule) :
 // The rule set is taken by non-const reference; the Traverse() overloads in
 // this file call Score(), BaseCase() and TraversalInfo() on it.
33  rule(rule),
 // All counters start from zero for a freshly constructed traverser.
34  numPrunes(0),
35  numVisited(0),
36  numScores(0),
37  numBaseCases(0)
38 { /* Nothing to do. */ }
39 
// Strict ordering of two queue frames, `a` and `b`.
//
// `a < b` holds when `a` has the larger queryDepth, or when both depths are
// equal and `a` has the larger score.  When frames are held in a
// std::priority_queue with its default comparator (which keeps the greatest
// element under operator< on top), this ordering makes the queue hand out the
// frame with the smallest query depth first and, among frames of equal depth,
// the one with the smallest score.
//
// NOTE(review): the listing numbers embedded in this copy jump from 41 to 43,
// so the line declaring the second parameter (`b`, used in the body below) is
// missing here.  Restore it from the original header before compiling.
40 template<typename TreeType, typename TraversalInfoType>
41 bool operator<(const QueueFrame<TreeType, TraversalInfoType>& a,
43 {
 // A deeper frame always sorts below a shallower one.
44  if (a.queryDepth > b.queryDepth)
45  return true;
 // Same depth: the frame with the higher score sorts below.
46  else if ((a.queryDepth == b.queryDepth) && (a.score > b.score))
47  return true;
 // Otherwise `a` is not ordered before `b` (this includes the fully-equal case).
48  return false;
49 }
50 
// Entry point of the traversal: takes the roots of the query tree and the
// reference tree (queryRoot, referenceRoot), scores the root pair once, and if
// it is not pruned seeds a priority queue with a single frame for that pair
// and hands it to the queue-based Traverse() overload.
//
// Side effects visible here: numVisited is incremented, the member
// traversalInfo is overwritten with the rule's current traversal info, and
// rule.Score() is called on the root pair.
//
// NOTE(review): the listing numbers embedded in this copy jump from 57 to 61
// and from 61 to 63, so the qualified function name and the parameter type
// lines are missing here.  Restore them from the original header before
// compiling.
51 template<typename MetricType,
52  typename StatisticType,
53  typename MatType,
54  template<typename BoundMetricType, typename...> class BoundType,
55  template<typename SplitBoundType, typename SplitMatType>
56  class SplitType>
57 template<typename RuleType>
61  queryRoot,
63  referenceRoot)
64 {
65  // Increment the visit counter.
66  ++numVisited;
67 
68  // Store the current traversal info.
69  traversalInfo = rule.TraversalInfo();
70 
71  // Must score the root combination.
 // NOTE(review): rootScore is only used for the prune check below.  The frame
 // pushed later carries a score of 0.0 and the queue-based Traverse() scores
 // the same pair again, so the root pair is scored twice.  Neither numScores
 // nor numPrunes is updated in this function -- confirm that is intended.
72  const double rootScore = rule.Score(queryRoot, referenceRoot);
73  if (rootScore == DBL_MAX)
74  return; // This probably means something is wrong.
75 
 // Queue of (query node, reference node) frames, ordered by the operator<
 // overload for queue frames defined in this file.
76  std::priority_queue<QueueFrameType> queue;
77 
 // Frame describing the root pair: depth 0, placeholder score 0.0, and the
 // traversal info as the rule reports it after the Score() call above.
78  QueueFrameType rootFrame;
79  rootFrame.queryNode = &queryRoot;
80  rootFrame.referenceNode = &referenceRoot;
81  rootFrame.queryDepth = 0;
82  rootFrame.score = 0.0;
83  rootFrame.traversalInfo = rule.TraversalInfo();
84 
85  queue.push(rootFrame);
86 
87  // Start the traversal.
88  Traverse(queryRoot, queue);
89 }
90 
// Queue-driven traversal step for one query node.
//
// Drains referenceQueue, scoring every (query node, reference node) frame it
// holds.  Pruned frames are counted and dropped.  For surviving frames:
//  - leaf / leaf:          BaseCase() is run on every point pair;
//  - query has children:   frames for the left and right query child are put
//                          on two separate child queues (depth + 1);
//  - only reference has children: frames for the reference children are pushed
//                          back onto referenceQueue (same depth), so they are
//                          handled by this same loop.
// Once referenceQueue is empty, the function recurses into the left and right
// query children with their respective queues, if those queues are non-empty.
//
// Parameters: queryNode is the query node whose children are recursed into at
// the end; referenceQueue is consumed (it is empty when the loop finishes).
//
// NOTE(review): the listing numbers embedded in this copy jump from 97 to 101,
// so the qualified function name and the type of the first parameter are
// missing here.  Restore them from the original header before compiling.
91 template<typename MetricType,
92  typename StatisticType,
93  typename MatType,
94  template<typename BoundMetricType, typename...> class BoundType,
95  template<typename SplitBoundType, typename SplitMatType>
96  class SplitType>
97 template<typename RuleType>
101  queryNode,
102  std::priority_queue<QueueFrameType>& referenceQueue)
103 {
104  // Store queues for the children. We will recurse into the children once our
105  // queue is empty.
106  std::priority_queue<QueueFrameType> leftChildQueue;
107  std::priority_queue<QueueFrameType> rightChildQueue;
108 
109  while (!referenceQueue.empty())
110  {
 // Take a copy of the top frame before popping it.
111  QueueFrameType currentFrame = referenceQueue.top();
112  referenceQueue.pop();
113 
 // NOTE: this local `queryNode` shadows the function parameter of the same
 // name for the duration of the loop body; the recursion after the loop uses
 // the parameter again.
114  BinarySpaceTree& queryNode = *currentFrame.queryNode;
115  BinarySpaceTree& referenceNode = *currentFrame.referenceNode;
 // `ti` is the traversal info stored in the frame, i.e. the state from before
 // the Score() call below; it is written back into the rule so Score() sees it.
116  typename RuleType::TraversalInfoType ti = currentFrame.traversalInfo;
117  rule.TraversalInfo() = ti;
118  const size_t queryDepth = currentFrame.queryDepth;
119 
120  double score = rule.Score(queryNode, referenceNode);
121  ++numScores;
122 
 // A score of DBL_MAX is treated as "prune this pair".
123  if (score == DBL_MAX)
124  {
125  ++numPrunes;
126  continue;
127  }
128 
129  // If both are leaves, we must evaluate the base case.
130  if (queryNode.IsLeaf() && referenceNode.IsLeaf())
131  {
132  // Loop through each of the points in each node.
 // Point indices run over [Begin(), Begin() + Count()) of each node.
133  const size_t queryEnd = queryNode.Begin() + queryNode.Count();
134  const size_t refEnd = referenceNode.Begin() + referenceNode.Count();
135  for (size_t query = queryNode.Begin(); query < queryEnd; ++query)
136  {
137  // See if we need to investigate this point (this function should be
138  // implemented for the single-tree recursion too). Restore the
139  // traversal information first.
140 // const double childScore = rule.Score(query, referenceNode);
141 
142 // if (childScore == DBL_MAX)
143 // continue; // We can't improve this particular point.
144 
145  for (size_t ref = referenceNode.Begin(); ref < refEnd; ++ref)
146  rule.BaseCase(query, ref);
147 
 // One base case was evaluated per reference point for this query point.
148  numBaseCases += referenceNode.Count();
149  }
150  }
151  else if ((!queryNode.IsLeaf()) && referenceNode.IsLeaf())
152  {
153  // We have to recurse down the query node.
 // NOTE(review): the left frame stores rule.TraversalInfo() as read after
 // Score(), while the right frame stores `ti`, the copy taken before Score().
 // Whether Score() changes the rule's traversal info is not visible in this
 // file; if it does, the two children receive different state -- confirm
 // this is intended.
154  QueueFrameType fl = { queryNode.Left(), &referenceNode, queryDepth + 1,
155  score, rule.TraversalInfo() };
156  leftChildQueue.push(fl);
157 
158  QueueFrameType fr = { queryNode.Right(), &referenceNode, queryDepth + 1,
159  score, ti };
160  rightChildQueue.push(fr);
161  }
162  else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf()))
163  {
164  // We have to recurse down the reference node. In this case the recursion
165  // order does matter. Before recursing, though, we have to set the
166  // traversal information correctly.
 // Both frames go back onto the queue being drained, at the same query depth,
 // so they will be popped by a later iteration of this loop.
 // NOTE(review): same rule.TraversalInfo() / `ti` asymmetry as in the branch
 // above -- confirm.
167  QueueFrameType fl = { &queryNode, referenceNode.Left(), queryDepth,
168  score, rule.TraversalInfo() };
169  referenceQueue.push(fl);
170 
171  QueueFrameType fr = { &queryNode, referenceNode.Right(), queryDepth,
172  score, ti };
173  referenceQueue.push(fr);
174  }
175  else
176  {
177  // We have to recurse down both query and reference nodes. Because the
178  // query descent order does not matter, we will go to the left query child
179  // first. Before recursing, we have to set the traversal information
180  // correctly.
 // Four child pairs: frames for the left query child go to leftChildQueue,
 // frames for the right query child go to rightChildQueue.  Unlike the two
 // branches above, all four frames store rule.TraversalInfo().
181  QueueFrameType fll = { queryNode.Left(), referenceNode.Left(),
182  queryDepth + 1, score, rule.TraversalInfo() };
183  leftChildQueue.push(fll);
184 
185  QueueFrameType flr = { queryNode.Left(), referenceNode.Right(),
186  queryDepth + 1, score, rule.TraversalInfo() };
187  leftChildQueue.push(flr);
188 
189  QueueFrameType frl = { queryNode.Right(), referenceNode.Left(),
190  queryDepth + 1, score, rule.TraversalInfo() };
191  rightChildQueue.push(frl);
192 
193  QueueFrameType frr = { queryNode.Right(), referenceNode.Right(),
194  queryDepth + 1, score, rule.TraversalInfo() };
195  rightChildQueue.push(frr);
196  }
197  }
198 
199  // Now, recurse into the left and right children queues. The order doesn't
200  // matter.
 // `queryNode` here is the function parameter again.  The child queues are
 // only filled by branches that run when the popped frame's query node is not
 // a leaf, so Left()/Right() are dereferenced only after such a branch ran.
 // NOTE(review): this relies on every frame in referenceQueue referring to the
 // same query node as the parameter -- confirm against callers.
201  if (leftChildQueue.size() > 0)
202  Traverse(*queryNode.Left(), leftChildQueue);
203  if (rightChildQueue.size() > 0)
204  Traverse(*queryNode.Right(), rightChildQueue);
205 }
206 
207 } // namespace tree
208 } // namespace mlpack
209 
210 #endif // MLPACK_CORE_TREE_BINARY_SPACE_TREE_BF_DUAL_TREE_TRAVERSER_IMPL_HPP
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
size_t Count() const
Return the number of points in this subset.
Definition: binary_space_tree.hpp:503
BreadthFirstDualTreeTraverser(RuleType &rule)
Instantiate the dual-tree traverser with the given rule set.
Definition: breadth_first_dual_tree_traverser_impl.hpp:31
BinarySpaceTree * Right() const
Gets the right child of this node.
Definition: binary_space_tree.hpp:337
A binary space partitioning tree, such as a KD-tree or a ball tree.
Definition: binary_space_tree.hpp:54
Definition: breadth_first_dual_tree_traverser.hpp:27
BinarySpaceTree * Left() const
Gets the left child of this node.
Definition: binary_space_tree.hpp:332
bool IsLeaf() const
Return whether or not this node is a leaf (true if it has no children).
Definition: binary_space_tree_impl.hpp:594
size_t Begin() const
Return the index of the beginning point of this subset.
Definition: binary_space_tree.hpp:498
void Traverse(BinarySpaceTree &queryNode, BinarySpaceTree &referenceNode)
Traverse the two trees.
Definition: breadth_first_dual_tree_traverser_impl.hpp:59