mlpack
spill_dual_tree_traverser_impl.hpp
Go to the documentation of this file.
1 
15 #ifndef MLPACK_CORE_TREE_SPILL_TREE_SPILL_DUAL_TREE_TRAVERSER_IMPL_HPP
16 #define MLPACK_CORE_TREE_SPILL_TREE_SPILL_DUAL_TREE_TRAVERSER_IMPL_HPP
17 
18 // In case it hasn't been included yet.
20 
21 namespace mlpack {
22 namespace tree {
23 
// Constructor of the dual-tree traverser: binds the traverser to the given
// rule set and zeroes every statistics counter (prunes, visited node
// combinations, scores, base cases).
//
// rule: the rule object, stored in the `rule` member; the by-reference
//     parameter suggests the caller must keep it alive -- confirm against the
//     member declaration in the class header.
//
// NOTE(review): the qualified constructor name (listing lines 31-32) is absent
// from this rendering; only the parameter list and initializers are visible.
24 template<typename MetricType,
25  typename StatisticType,
26  typename MatType,
27  template<typename HyperplaneMetricType> class HyperplaneType,
28  template<typename SplitMetricType, typename SplitMatType>
29  class SplitType>
30 template<typename RuleType, bool Defeatist>
33  RuleType& rule) :
34  rule(rule),
35  numPrunes(0),
36  numVisited(0),
37  numScores(0),
38  numBaseCases(0)
39 { /* Nothing to do. */ }
40 
// Recursive dual-tree traversal over a (queryNode, referenceNode) pair.
//
// queryNode:     current node of the query tree.
// referenceNode: current node of the reference tree.
// bruteForce:    when true, skip all further recursion and evaluate the base
//                case between every descendant of both nodes.
//
// The rule object is expected to provide TraversalInfo(), Score(), Rescore(),
// BaseCase(), GetBestChild() and MinimumBaseCases(); a score of DBL_MAX is
// treated as "prune this combination".
//
// The function dispatches on five mutually exclusive cases, in this order:
//  1. Defeatist mode, not already brute-forcing, the reference node's parent
//     is an overlapping node and the reference node has fewer descendants
//     than rule.MinimumBaseCases(): re-run on the parent with
//     bruteForce = true and return.
//  2. Both nodes are leaves, or bruteForce is set: per query descendant, score
//     the point against the reference node and, unless pruned, run BaseCase()
//     against every reference descendant.
//  3. Only the query node can be split, or both can but the query node has
//     more than three times as many descendants: recurse into both query
//     children (each scored first).
//  4. Query node is a leaf, reference node is not: either defeatist descent
//     into the single best reference child, or a score-ordered descent into
//     both reference children with Rescore() before the second one.
//  5. Neither node is a leaf: same two strategies as case 4, applied once for
//     the left query child and once for the right query child.
//
// Counters updated: numVisited (once per call), numScores, numPrunes,
// numBaseCases.
//
// NOTE(review): the function signature (listing lines 48-50, 52) and listing
// line 142 are absent from this rendering.
41 template<typename MetricType,
42  typename StatisticType,
43  typename MatType,
44  template<typename HyperplaneMetricType> class HyperplaneType,
45  template<typename SplitMetricType, typename SplitMatType>
46  class SplitType>
47 template<typename RuleType, bool Defeatist>
51  queryNode,
53  referenceNode,
54  const bool bruteForce)
55 {
56  // Increment the visit counter.
57  ++numVisited;
58 
59  // Store the current traversal info.
60  traversalInfo = rule.TraversalInfo();
61 
62  // Determine whether we need to brute-force the reference node. We have no
63  // realistic way to track how many base cases we've done for each point, so we
64  // act as though we have done zero.
65  if (!bruteForce && Defeatist &&
66  (referenceNode.Parent() != NULL) &&
67  (referenceNode.Parent()->Overlap()) &&
68  (referenceNode.NumDescendants() < rule.MinimumBaseCases()))
69  {
70  // We've actually recursed too far. Go back up one level and brute-force
71  // the computation, and then we are done.
72  Traverse(queryNode, *referenceNode.Parent(), true);
73  return;
74  }
75  else if ((queryNode.IsLeaf() && referenceNode.IsLeaf()) || bruteForce)
76  {
77  // If both are leaves or if we explicitly need to do brute-force search, we
78  // must evaluate the base cases.
79 
80  // Loop through each of the points in each node.
81  const size_t queryEnd = queryNode.NumDescendants();
82  const size_t refEnd = referenceNode.NumDescendants();
83  for (size_t query = 0; query < queryEnd; ++query)
84  {
85  const size_t queryIndex = queryNode.Descendant(query);
86  // See if we need to investigate this point. Restore the traversal
87  // information first.
88  rule.TraversalInfo() = traversalInfo;
89  const double childScore = rule.Score(queryIndex, referenceNode);
90 
91  if (childScore == DBL_MAX)
92  continue; // We can't improve this particular point.
93 
94  for (size_t ref = 0; ref < refEnd; ++ref)
95  rule.BaseCase(queryIndex, referenceNode.Descendant(ref));
96 
97  numBaseCases += refEnd;
98  }
99  }
100  else if (((!queryNode.IsLeaf()) && referenceNode.IsLeaf()) ||
101  (queryNode.NumDescendants() > 3 * referenceNode.NumDescendants() &&
102  !queryNode.IsLeaf() && !referenceNode.IsLeaf()))
103  {
104  // We have to recurse down the query node. In this case the recursion order
105  // does not matter.
106  const double leftScore = rule.Score(*queryNode.Left(), referenceNode);
107  ++numScores;
108 
109  if (leftScore != DBL_MAX)
110  Traverse(*queryNode.Left(), referenceNode);
111  else
112  ++numPrunes;
113 
114  // Before recursing, we have to set the traversal information correctly.
115  rule.TraversalInfo() = traversalInfo;
116  const double rightScore = rule.Score(*queryNode.Right(), referenceNode);
117  ++numScores;
118 
119  if (rightScore != DBL_MAX)
120  Traverse(*queryNode.Right(), referenceNode);
121  else
122  ++numPrunes;
123  }
124  else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf()))
125  {
126  if (Defeatist && referenceNode.Overlap())
127  {
128  // If referenceNode is an overlapping node, let's do defeatist search.
129  size_t bestChild = rule.GetBestChild(queryNode, referenceNode);
130  if (bestChild < referenceNode.NumChildren())
131  {
132  Traverse(queryNode, referenceNode.Child(bestChild));
133  ++numPrunes;
134  }
135  else
136  {
137  // If we can't decide which child node to traverse, this means that
138  // queryNode is at both sides of the splitting hyperplane. So, as
139  // queryNode is a leaf node, all we can do is single tree search for each
140  // point in the query node.
141  const size_t queryEnd = queryNode.NumPoints();
143  // Loop through each of the points in query node. NOTE(review): `st` used below is not declared in this listing (line 142 is absent); presumably a single-tree traverser built from `rule` -- confirm.
144  for (size_t query = 0; query < queryEnd; ++query)
145  {
146  const size_t queryIndex = queryNode.Point(query);
147  // See if we need to investigate this point.
148  const double childScore = rule.Score(queryIndex, referenceNode);
149 
150  if (childScore == DBL_MAX)
151  continue; // We can't improve this particular point.
152 
153  st.Traverse(queryIndex, referenceNode);
154  }
155  }
156  }
157  else
158  {
159  // We have to recurse down the reference node. In this case the recursion
160  // order does matter. Before recursing, though, we have to set the
161  // traversal information correctly.
162  double leftScore = rule.Score(queryNode, *referenceNode.Left());
163  typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
164  rule.TraversalInfo() = traversalInfo;
165  double rightScore = rule.Score(queryNode, *referenceNode.Right());
166  numScores += 2;
167 
168  if (leftScore < rightScore)
169  {
170  // Recurse to the left. Restore the left traversal info. Store the
171  // right traversal info.
172  traversalInfo = rule.TraversalInfo();
173  rule.TraversalInfo() = leftInfo;
174  Traverse(queryNode, *referenceNode.Left());
175 
176  // Is it still valid to recurse to the right?
177  rightScore = rule.Rescore(queryNode, *referenceNode.Right(),
178  rightScore);
179 
180  if (rightScore != DBL_MAX)
181  {
182  // Restore the right traversal info.
183  rule.TraversalInfo() = traversalInfo;
184  Traverse(queryNode, *referenceNode.Right());
185  }
186  else
187  ++numPrunes;
188  }
189  else if (rightScore < leftScore)
190  {
191  // Recurse to the right. The right child was scored last, so rule.TraversalInfo() has not been overwritten since and needs no restore.
192  Traverse(queryNode, *referenceNode.Right());
193 
194  // Is it still valid to recurse to the left?
195  leftScore = rule.Rescore(queryNode, *referenceNode.Left(), leftScore);
196 
197  if (leftScore != DBL_MAX)
198  {
199  // Restore the left traversal info.
200  rule.TraversalInfo() = leftInfo;
201  Traverse(queryNode, *referenceNode.Left());
202  }
203  else
204  ++numPrunes;
205  }
206  else // leftScore is equal to rightScore.
207  {
208  if (leftScore == DBL_MAX)
209  {
210  numPrunes += 2;
211  }
212  else
213  {
214  // Choose the left first. Restore the left traversal info. Store the
215  // right traversal info.
216  traversalInfo = rule.TraversalInfo();
217  rule.TraversalInfo() = leftInfo;
218  Traverse(queryNode, *referenceNode.Left());
219 
220  rightScore = rule.Rescore(queryNode, *referenceNode.Right(),
221  rightScore);
222 
223  if (rightScore != DBL_MAX)
224  {
225  // Restore the right traversal info.
226  rule.TraversalInfo() = traversalInfo;
227  Traverse(queryNode, *referenceNode.Right());
228  }
229  else
230  ++numPrunes;
231  }
232  }
233  }
234  }
235  else
236  {
237  if (Defeatist && referenceNode.Overlap())
238  {
239  // If referenceNode is an overlapping node, let's do defeatist search.
240  size_t bestChild = rule.GetBestChild(*queryNode.Left(), referenceNode);
241  if (bestChild < referenceNode.NumChildren())
242  {
243  Traverse(*queryNode.Left(), referenceNode.Child(bestChild));
244  ++numPrunes;
245  }
246  else
247  {
248  // If we can't decide which child node to traverse, this means that
249  // queryNode.Left() is at both sides of the splitting hyperplane. So,
250  // let's recurse down only the query node.
251  Traverse(*queryNode.Left(), referenceNode);
252  }
253 
254  bestChild = rule.GetBestChild(*queryNode.Right(), referenceNode);
255  if (bestChild < referenceNode.NumChildren())
256  {
257  Traverse(*queryNode.Right(), referenceNode.Child(bestChild));
258  ++numPrunes;
259  }
260  else
261  {
262  // If we can't decide which child node to traverse, this means that
263  // queryNode.Right() is at both sides of the splitting hyperplane. So,
264  // let's recurse down only the query node.
265  Traverse(*queryNode.Right(), referenceNode);
266  }
267  }
268  else
269  {
270  // We have to recurse down both query and reference nodes. Because the
271  // query descent order does not matter, we will go to the left query child
272  // first. Before recursing, we have to set the traversal information
273  // correctly.
274  double leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left());
275  typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
276  rule.TraversalInfo() = traversalInfo;
277  double rightScore = rule.Score(*queryNode.Left(), *referenceNode.Right());
278  typename RuleType::TraversalInfoType rightInfo;
279  numScores += 2;
280 
281  if (leftScore < rightScore)
282  {
283  // Recurse to the left. Restore the left traversal info. Store the
284  // right traversal info.
285  rightInfo = rule.TraversalInfo();
286  rule.TraversalInfo() = leftInfo;
287  Traverse(*queryNode.Left(), *referenceNode.Left());
288 
289  // Is it still valid to recurse to the right?
290  rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
291  rightScore);
292 
293  if (rightScore != DBL_MAX)
294  {
295  // Restore the right traversal info.
296  rule.TraversalInfo() = rightInfo;
297  Traverse(*queryNode.Left(), *referenceNode.Right());
298  }
299  else
300  ++numPrunes;
301  }
302  else if (rightScore < leftScore)
303  {
304  // Recurse to the right. The right child was scored last, so rule.TraversalInfo() has not been overwritten since and needs no restore.
305  Traverse(*queryNode.Left(), *referenceNode.Right());
306 
307  // Is it still valid to recurse to the left?
308  leftScore = rule.Rescore(*queryNode.Left(), *referenceNode.Left(),
309  leftScore);
310 
311  if (leftScore != DBL_MAX)
312  {
313  // Restore the left traversal info.
314  rule.TraversalInfo() = leftInfo;
315  Traverse(*queryNode.Left(), *referenceNode.Left());
316  }
317  else
318  ++numPrunes;
319  }
320  else // leftScore is equal to rightScore.
321  {
322  if (leftScore == DBL_MAX)
323  {
324  numPrunes += 2;
325  }
326  else
327  {
328  // Choose the left first. Restore the left traversal info and store
329  // the right traversal info.
330  rightInfo = rule.TraversalInfo();
331  rule.TraversalInfo() = leftInfo;
332  Traverse(*queryNode.Left(), *referenceNode.Left());
333 
334  // Is it still valid to recurse to the right?
335  rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
336  rightScore);
337 
338  if (rightScore != DBL_MAX)
339  {
340  // Restore the right traversal information.
341  rule.TraversalInfo() = rightInfo;
342  Traverse(*queryNode.Left(), *referenceNode.Right());
343  }
344  else
345  ++numPrunes;
346  }
347  }
348 
349  // Restore the main traversal information.
350  rule.TraversalInfo() = traversalInfo;
351 
352  // Now recurse down the right query node.
353  leftScore = rule.Score(*queryNode.Right(), *referenceNode.Left());
354  leftInfo = rule.TraversalInfo();
355  rule.TraversalInfo() = traversalInfo;
356  rightScore = rule.Score(*queryNode.Right(), *referenceNode.Right());
357  numScores += 2;
358 
359  if (leftScore < rightScore)
360  {
361  // Recurse to the left. Restore the left traversal info. Store the
362  // right traversal info.
363  rightInfo = rule.TraversalInfo();
364  rule.TraversalInfo() = leftInfo;
365  Traverse(*queryNode.Right(), *referenceNode.Left());
366 
367  // Is it still valid to recurse to the right?
368  rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
369  rightScore);
370 
371  if (rightScore != DBL_MAX)
372  {
373  // Restore the right traversal info.
374  rule.TraversalInfo() = rightInfo;
375  Traverse(*queryNode.Right(), *referenceNode.Right());
376  }
377  else
378  ++numPrunes;
379  }
380  else if (rightScore < leftScore)
381  {
382  // Recurse to the right. The right child was scored last, so rule.TraversalInfo() has not been overwritten since and needs no restore.
383  Traverse(*queryNode.Right(), *referenceNode.Right());
384 
385  // Is it still valid to recurse to the left?
386  leftScore = rule.Rescore(*queryNode.Right(), *referenceNode.Left(),
387  leftScore);
388 
389  if (leftScore != DBL_MAX)
390  {
391  // Restore the left traversal info.
392  rule.TraversalInfo() = leftInfo;
393  Traverse(*queryNode.Right(), *referenceNode.Left());
394  }
395  else
396  ++numPrunes;
397  }
398  else // leftScore is equal to rightScore.
399  {
400  if (leftScore == DBL_MAX)
401  {
402  numPrunes += 2;
403  }
404  else
405  {
406  // Choose the left first. Restore the left traversal info. Store the
407  // right traversal info.
408  rightInfo = rule.TraversalInfo();
409  rule.TraversalInfo() = leftInfo;
410  Traverse(*queryNode.Right(), *referenceNode.Left());
411 
412  // Is it still valid to recurse to the right?
413  rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
414  rightScore);
415 
416  if (rightScore != DBL_MAX)
417  {
418  // Restore the right traversal info.
419  rule.TraversalInfo() = rightInfo;
420  Traverse(*queryNode.Right(), *referenceNode.Right());
421  }
422  else
423  ++numPrunes;
424  }
425  }
426  }
427  }
428 }
429 
430 } // namespace tree
431 } // namespace mlpack
432 
433 #endif // MLPACK_CORE_TREE_SPILL_TREE_SPILL_DUAL_TREE_TRAVERSER_IMPL_HPP
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
A generic single-tree traverser for hybrid spill trees; see spill_single_tree_traverser.hpp for implementation.
Definition: spill_single_tree_traverser.hpp:34
size_t NumPoints() const
Return the number of points in this node (0 if not a leaf).
Definition: spill_tree_impl.hpp:649
SpillTree & Child(const size_t child) const
Return the specified child (0 will be left, 1 will be right).
Definition: spill_tree_impl.hpp:631
void Traverse(const size_t queryIndex, SpillTree &referenceNode, const bool bruteForce=false)
Traverse the tree with the given point.
Definition: spill_single_tree_traverser_impl.hpp:46
SpillTree * Right() const
Gets the right child of this node.
Definition: spill_tree.hpp:262
A hybrid spill tree is a variant of binary space trees in which the children of a node can "spill over" each other, and contain shared datapoints.
Definition: spill_tree.hpp:73
bool Overlap() const
Distinguish overlapping nodes from non-overlapping nodes.
Definition: spill_tree.hpp:275
size_t NumDescendants() const
Return the number of descendants of this node.
Definition: spill_tree_impl.hpp:666
SpillTree * Parent() const
Gets the parent of this node.
Definition: spill_tree.hpp:267
SpillTree * Left() const
Gets the left child of this node.
Definition: spill_tree.hpp:257
bool IsLeaf() const
Return whether or not this node is a leaf (true if it has no children).
Definition: spill_tree_impl.hpp:433
size_t Descendant(const size_t index) const
Return the index (with reference to the dataset) of a particular descendant of this node...
Definition: spill_tree_impl.hpp:681
void Traverse(SpillTree &queryNode, SpillTree &referenceNode, const bool bruteForce=false)
Traverse the two trees.
Definition: spill_dual_tree_traverser_impl.hpp:49
size_t NumChildren() const
Return the number of children in this node.
Definition: spill_tree_impl.hpp:448
SpillDualTreeTraverser(RuleType &rule)
Instantiate the dual-tree traverser with the given rule set.
Definition: spill_dual_tree_traverser_impl.hpp:32
size_t Point(const size_t index) const
Return the index (with reference to the dataset) of a particular point in this node.
Definition: spill_tree_impl.hpp:705