mlpack
dual_tree_traverser_impl.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
15 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
16 
17 // In case it hasn't been included yet.
18 #include "dual_tree_traverser.hpp"
19 
20 namespace mlpack {
21 namespace tree {
22 
23 template<typename MetricType,
24  typename StatisticType,
25  typename MatType,
26  template<typename BoundMetricType, typename...> class BoundType,
27  template<typename SplitBoundType, typename SplitMatType>
28  class SplitType>
29 template<typename RuleType>
32  rule(rule),
33  numPrunes(0),
34  numVisited(0),
35  numScores(0),
36  numBaseCases(0)
37 { /* Nothing to do. */ }
38 
39 template<typename MetricType,
40  typename StatisticType,
41  typename MatType,
42  template<typename BoundMetricType, typename...> class BoundType,
43  template<typename SplitBoundType, typename SplitMatType>
44  class SplitType>
45 template<typename RuleType>
49  queryNode,
51  referenceNode)
52 {
53  // Increment the visit counter.
54  ++numVisited;
55 
56  // Store the current traversal info.
57  traversalInfo = rule.TraversalInfo();
58 
59  // If both nodes are root nodes, just score them.
60  if (queryNode.Parent() == NULL && referenceNode.Parent() == NULL)
61  {
62  const double rootScore = rule.Score(queryNode, referenceNode);
63  // If root score is DBL_MAX, don't recurse.
64  if (rootScore == DBL_MAX)
65  {
66  ++numPrunes;
67  return;
68  }
69  }
70 
71  // If both are leaves, we must evaluate the base case.
72  if (queryNode.IsLeaf() && referenceNode.IsLeaf())
73  {
74  // Loop through each of the points in each node.
75  const size_t queryEnd = queryNode.Begin() + queryNode.Count();
76  const size_t refEnd = referenceNode.Begin() + referenceNode.Count();
77  for (size_t query = queryNode.Begin(); query < queryEnd; ++query)
78  {
79  // See if we need to investigate this point (this function should be
80  // implemented for the single-tree recursion too). Restore the traversal
81  // information first.
82  rule.TraversalInfo() = traversalInfo;
83  const double childScore = rule.Score(query, referenceNode);
84 
85  if (childScore == DBL_MAX)
86  continue; // We can't improve this particular point.
87 
88  for (size_t ref = referenceNode.Begin(); ref < refEnd; ++ref)
89  rule.BaseCase(query, ref);
90 
91  numBaseCases += referenceNode.Count();
92  }
93  }
94  else if (((!queryNode.IsLeaf()) && referenceNode.IsLeaf()) ||
95  (queryNode.NumDescendants() > 3 * referenceNode.NumDescendants() &&
96  !queryNode.IsLeaf() && !referenceNode.IsLeaf()))
97  {
98  // We have to recurse down the query node. In this case the recursion order
99  // does not matter.
100  const double leftScore = rule.Score(*queryNode.Left(), referenceNode);
101  ++numScores;
102 
103  if (leftScore != DBL_MAX)
104  Traverse(*queryNode.Left(), referenceNode);
105  else
106  ++numPrunes;
107 
108  // Before recursing, we have to set the traversal information correctly.
109  rule.TraversalInfo() = traversalInfo;
110  const double rightScore = rule.Score(*queryNode.Right(), referenceNode);
111  ++numScores;
112 
113  if (rightScore != DBL_MAX)
114  Traverse(*queryNode.Right(), referenceNode);
115  else
116  ++numPrunes;
117  }
118  else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf()))
119  {
120  // We have to recurse down the reference node. In this case the recursion
121  // order does matter. Before recursing, though, we have to set the
122  // traversal information correctly.
123  double leftScore = rule.Score(queryNode, *referenceNode.Left());
124  typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
125  rule.TraversalInfo() = traversalInfo;
126  double rightScore = rule.Score(queryNode, *referenceNode.Right());
127  numScores += 2;
128 
129  if (leftScore < rightScore)
130  {
131  // Recurse to the left. Restore the left traversal info. Store the right
132  // traversal info.
133  traversalInfo = rule.TraversalInfo();
134  rule.TraversalInfo() = leftInfo;
135  Traverse(queryNode, *referenceNode.Left());
136 
137  // Is it still valid to recurse to the right?
138  rightScore = rule.Rescore(queryNode, *referenceNode.Right(), rightScore);
139 
140  if (rightScore != DBL_MAX)
141  {
142  // Restore the right traversal info.
143  rule.TraversalInfo() = traversalInfo;
144  Traverse(queryNode, *referenceNode.Right());
145  }
146  else
147  ++numPrunes;
148  }
149  else if (rightScore < leftScore)
150  {
151  // Recurse to the right.
152  Traverse(queryNode, *referenceNode.Right());
153 
154  // Is it still valid to recurse to the left?
155  leftScore = rule.Rescore(queryNode, *referenceNode.Left(), leftScore);
156 
157  if (leftScore != DBL_MAX)
158  {
159  // Restore the left traversal info.
160  rule.TraversalInfo() = leftInfo;
161  Traverse(queryNode, *referenceNode.Left());
162  }
163  else
164  ++numPrunes;
165  }
166  else // leftScore is equal to rightScore.
167  {
168  if (leftScore == DBL_MAX)
169  {
170  numPrunes += 2;
171  }
172  else
173  {
174  // Choose the left first. Restore the left traversal info. Store the
175  // right traversal info.
176  traversalInfo = rule.TraversalInfo();
177  rule.TraversalInfo() = leftInfo;
178  Traverse(queryNode, *referenceNode.Left());
179 
180  rightScore = rule.Rescore(queryNode, *referenceNode.Right(),
181  rightScore);
182 
183  if (rightScore != DBL_MAX)
184  {
185  // Restore the right traversal info.
186  rule.TraversalInfo() = traversalInfo;
187  Traverse(queryNode, *referenceNode.Right());
188  }
189  else
190  ++numPrunes;
191  }
192  }
193  }
194  else
195  {
196  // We have to recurse down both query and reference nodes. Because the
197  // query descent order does not matter, we will go to the left query child
198  // first. Before recursing, we have to set the traversal information
199  // correctly.
200  double leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left());
201  typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
202  rule.TraversalInfo() = traversalInfo;
203  double rightScore = rule.Score(*queryNode.Left(), *referenceNode.Right());
204  typename RuleType::TraversalInfoType rightInfo;
205  numScores += 2;
206 
207  if (leftScore < rightScore)
208  {
209  // Recurse to the left. Restore the left traversal info. Store the right
210  // traversal info.
211  rightInfo = rule.TraversalInfo();
212  rule.TraversalInfo() = leftInfo;
213  Traverse(*queryNode.Left(), *referenceNode.Left());
214 
215  // Is it still valid to recurse to the right?
216  rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
217  rightScore);
218 
219  if (rightScore != DBL_MAX)
220  {
221  // Restore the right traversal info.
222  rule.TraversalInfo() = rightInfo;
223  Traverse(*queryNode.Left(), *referenceNode.Right());
224  }
225  else
226  ++numPrunes;
227  }
228  else if (rightScore < leftScore)
229  {
230  // Recurse to the right.
231  Traverse(*queryNode.Left(), *referenceNode.Right());
232 
233  // Is it still valid to recurse to the left?
234  leftScore = rule.Rescore(*queryNode.Left(), *referenceNode.Left(),
235  leftScore);
236 
237  if (leftScore != DBL_MAX)
238  {
239  // Restore the left traversal info.
240  rule.TraversalInfo() = leftInfo;
241  Traverse(*queryNode.Left(), *referenceNode.Left());
242  }
243  else
244  ++numPrunes;
245  }
246  else
247  {
248  if (leftScore == DBL_MAX)
249  {
250  numPrunes += 2;
251  }
252  else
253  {
254  // Choose the left first. Restore the left traversal info and store the
255  // right traversal info.
256  rightInfo = rule.TraversalInfo();
257  rule.TraversalInfo() = leftInfo;
258  Traverse(*queryNode.Left(), *referenceNode.Left());
259 
260  // Is it still valid to recurse to the right?
261  rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
262  rightScore);
263 
264  if (rightScore != DBL_MAX)
265  {
266  // Restore the right traversal information.
267  rule.TraversalInfo() = rightInfo;
268  Traverse(*queryNode.Left(), *referenceNode.Right());
269  }
270  else
271  ++numPrunes;
272  }
273  }
274 
275  // Restore the main traversal information.
276  rule.TraversalInfo() = traversalInfo;
277 
278  // Now recurse down the right query node.
279  leftScore = rule.Score(*queryNode.Right(), *referenceNode.Left());
280  leftInfo = rule.TraversalInfo();
281  rule.TraversalInfo() = traversalInfo;
282  rightScore = rule.Score(*queryNode.Right(), *referenceNode.Right());
283  numScores += 2;
284 
285  if (leftScore < rightScore)
286  {
287  // Recurse to the left. Restore the left traversal info. Store the right
288  // traversal info.
289  rightInfo = rule.TraversalInfo();
290  rule.TraversalInfo() = leftInfo;
291  Traverse(*queryNode.Right(), *referenceNode.Left());
292 
293  // Is it still valid to recurse to the right?
294  rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
295  rightScore);
296 
297  if (rightScore != DBL_MAX)
298  {
299  // Restore the right traversal info.
300  rule.TraversalInfo() = rightInfo;
301  Traverse(*queryNode.Right(), *referenceNode.Right());
302  }
303  else
304  ++numPrunes;
305  }
306  else if (rightScore < leftScore)
307  {
308  // Recurse to the right.
309  Traverse(*queryNode.Right(), *referenceNode.Right());
310 
311  // Is it still valid to recurse to the left?
312  leftScore = rule.Rescore(*queryNode.Right(), *referenceNode.Left(),
313  leftScore);
314 
315  if (leftScore != DBL_MAX)
316  {
317  // Restore the left traversal info.
318  rule.TraversalInfo() = leftInfo;
319  Traverse(*queryNode.Right(), *referenceNode.Left());
320  }
321  else
322  ++numPrunes;
323  }
324  else
325  {
326  if (leftScore == DBL_MAX)
327  {
328  numPrunes += 2;
329  }
330  else
331  {
332  // Choose the left first. Restore the left traversal info. Store the
333  // right traversal info.
334  rightInfo = rule.TraversalInfo();
335  rule.TraversalInfo() = leftInfo;
336  Traverse(*queryNode.Right(), *referenceNode.Left());
337 
338  // Is it still valid to recurse to the right?
339  rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
340  rightScore);
341 
342  if (rightScore != DBL_MAX)
343  {
344  // Restore the right traversal info.
345  rule.TraversalInfo() = rightInfo;
346  Traverse(*queryNode.Right(), *referenceNode.Right());
347  }
348  else
349  ++numPrunes;
350  }
351  }
352  }
353 }
354 
355 } // namespace tree
356 } // namespace mlpack
357 
358 #endif // MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
BinarySpaceTree * Parent() const
Gets the parent of this node.
Definition: binary_space_tree.hpp:342
size_t NumDescendants() const
Return the number of descendants of this node.
Definition: binary_space_tree_impl.hpp:826
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
size_t Count() const
Return the number of points in this subset.
Definition: binary_space_tree.hpp:503
void Traverse(BinarySpaceTree &queryNode, BinarySpaceTree &referenceNode)
Traverse the two trees.
Definition: dual_tree_traverser_impl.hpp:47
BinarySpaceTree * Right() const
Gets the right child of this node.
Definition: binary_space_tree.hpp:337
A binary space partitioning tree, such as a KD-tree or a ball tree.
Definition: binary_space_tree.hpp:54
DualTreeTraverser(RuleType &rule)
Instantiate the dual-tree traverser with the given rule set.
Definition: dual_tree_traverser_impl.hpp:31
BinarySpaceTree * Left() const
Gets the left child of this node.
Definition: binary_space_tree.hpp:332
bool IsLeaf() const
Return whether or not this node is a leaf (true if it has no children).
Definition: binary_space_tree_impl.hpp:594
size_t Begin() const
Return the index of the beginning point of this subset.
Definition: binary_space_tree.hpp:498