15 #ifndef MLPACK_CORE_TREE_SPILL_TREE_SPILL_DUAL_TREE_TRAVERSER_IMPL_HPP 16 #define MLPACK_CORE_TREE_SPILL_TREE_SPILL_DUAL_TREE_TRAVERSER_IMPL_HPP 24 template<
typename MetricType,
25 typename StatisticType,
27 template<
typename HyperplaneMetricType>
class HyperplaneType,
28 template<
typename SplitMetricType,
typename SplitMatType>
30 template<
typename RuleType,
bool Defeatist>
41 template<
typename MetricType,
42 typename StatisticType,
44 template<
typename HyperplaneMetricType>
class HyperplaneType,
45 template<
typename SplitMetricType,
typename SplitMatType>
47 template<
typename RuleType,
bool Defeatist>
54 const bool bruteForce)
60 traversalInfo = rule.TraversalInfo();
65 if (!bruteForce && Defeatist &&
66 (referenceNode.
Parent() != NULL) &&
67 (referenceNode.
Parent()->Overlap()) &&
75 else if ((queryNode.
IsLeaf() && referenceNode.
IsLeaf()) || bruteForce)
83 for (
size_t query = 0; query < queryEnd; ++query)
85 const size_t queryIndex = queryNode.
Descendant(query);
88 rule.TraversalInfo() = traversalInfo;
89 const double childScore = rule.Score(queryIndex, referenceNode);
91 if (childScore == DBL_MAX)
94 for (
size_t ref = 0; ref < refEnd; ++ref)
95 rule.BaseCase(queryIndex, referenceNode.
Descendant(ref));
97 numBaseCases += refEnd;
100 else if (((!queryNode.
IsLeaf()) && referenceNode.
IsLeaf()) ||
106 const double leftScore = rule.Score(*queryNode.
Left(), referenceNode);
109 if (leftScore != DBL_MAX)
115 rule.TraversalInfo() = traversalInfo;
116 const double rightScore = rule.Score(*queryNode.
Right(), referenceNode);
119 if (rightScore != DBL_MAX)
124 else if (queryNode.
IsLeaf() && (!referenceNode.
IsLeaf()))
126 if (Defeatist && referenceNode.
Overlap())
129 size_t bestChild = rule.GetBestChild(queryNode, referenceNode);
141 const size_t queryEnd = queryNode.
NumPoints();
144 for (
size_t query = 0; query < queryEnd; ++query)
146 const size_t queryIndex = queryNode.
Point(query);
148 const double childScore = rule.Score(queryIndex, referenceNode);
150 if (childScore == DBL_MAX)
153 st.
Traverse(queryIndex, referenceNode);
162 double leftScore = rule.Score(queryNode, *referenceNode.
Left());
163 typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
164 rule.TraversalInfo() = traversalInfo;
165 double rightScore = rule.Score(queryNode, *referenceNode.
Right());
168 if (leftScore < rightScore)
172 traversalInfo = rule.TraversalInfo();
173 rule.TraversalInfo() = leftInfo;
177 rightScore = rule.Rescore(queryNode, *referenceNode.
Right(),
180 if (rightScore != DBL_MAX)
183 rule.TraversalInfo() = traversalInfo;
189 else if (rightScore < leftScore)
195 leftScore = rule.Rescore(queryNode, *referenceNode.
Left(), leftScore);
197 if (leftScore != DBL_MAX)
200 rule.TraversalInfo() = leftInfo;
208 if (leftScore == DBL_MAX)
216 traversalInfo = rule.TraversalInfo();
217 rule.TraversalInfo() = leftInfo;
220 rightScore = rule.Rescore(queryNode, *referenceNode.
Right(),
223 if (rightScore != DBL_MAX)
226 rule.TraversalInfo() = traversalInfo;
237 if (Defeatist && referenceNode.
Overlap())
240 size_t bestChild = rule.GetBestChild(*queryNode.
Left(), referenceNode);
254 bestChild = rule.GetBestChild(*queryNode.
Right(), referenceNode);
274 double leftScore = rule.Score(*queryNode.
Left(), *referenceNode.
Left());
275 typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
276 rule.TraversalInfo() = traversalInfo;
277 double rightScore = rule.Score(*queryNode.
Left(), *referenceNode.
Right());
278 typename RuleType::TraversalInfoType rightInfo;
281 if (leftScore < rightScore)
285 rightInfo = rule.TraversalInfo();
286 rule.TraversalInfo() = leftInfo;
290 rightScore = rule.Rescore(*queryNode.
Left(), *referenceNode.
Right(),
293 if (rightScore != DBL_MAX)
296 rule.TraversalInfo() = rightInfo;
302 else if (rightScore < leftScore)
308 leftScore = rule.Rescore(*queryNode.
Left(), *referenceNode.
Left(),
311 if (leftScore != DBL_MAX)
314 rule.TraversalInfo() = leftInfo;
322 if (leftScore == DBL_MAX)
330 rightInfo = rule.TraversalInfo();
331 rule.TraversalInfo() = leftInfo;
335 rightScore = rule.Rescore(*queryNode.
Left(), *referenceNode.
Right(),
338 if (rightScore != DBL_MAX)
341 rule.TraversalInfo() = rightInfo;
350 rule.TraversalInfo() = traversalInfo;
353 leftScore = rule.Score(*queryNode.
Right(), *referenceNode.
Left());
354 leftInfo = rule.TraversalInfo();
355 rule.TraversalInfo() = traversalInfo;
356 rightScore = rule.Score(*queryNode.
Right(), *referenceNode.
Right());
359 if (leftScore < rightScore)
363 rightInfo = rule.TraversalInfo();
364 rule.TraversalInfo() = leftInfo;
368 rightScore = rule.Rescore(*queryNode.
Right(), *referenceNode.
Right(),
371 if (rightScore != DBL_MAX)
374 rule.TraversalInfo() = rightInfo;
380 else if (rightScore < leftScore)
386 leftScore = rule.Rescore(*queryNode.
Right(), *referenceNode.
Left(),
389 if (leftScore != DBL_MAX)
392 rule.TraversalInfo() = leftInfo;
400 if (leftScore == DBL_MAX)
408 rightInfo = rule.TraversalInfo();
409 rule.TraversalInfo() = leftInfo;
413 rightScore = rule.Rescore(*queryNode.
Right(), *referenceNode.
Right(),
416 if (rightScore != DBL_MAX)
419 rule.TraversalInfo() = rightInfo;
433 #endif // MLPACK_CORE_TREE_SPILL_TREE_SPILL_DUAL_TREE_TRAVERSER_IMPL_HPP Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
A generic single-tree traverser for hybrid spill trees; see spill_single_tree_traverser_impl.hpp for the implementation.
Definition: spill_single_tree_traverser.hpp:34
size_t NumPoints() const
Return the number of points in this node (0 if not a leaf).
Definition: spill_tree_impl.hpp:649
SpillTree & Child(const size_t child) const
Return the specified child (0 will be left, 1 will be right).
Definition: spill_tree_impl.hpp:631
void Traverse(const size_t queryIndex, SpillTree &referenceNode, const bool bruteForce=false)
Traverse the tree with the given point.
Definition: spill_single_tree_traverser_impl.hpp:46
SpillTree * Right() const
Gets the right child of this node.
Definition: spill_tree.hpp:262
A hybrid spill tree is a variant of binary space trees in which the children of a node can "spill ove...
Definition: spill_tree.hpp:73
bool Overlap() const
Distinguish overlapping nodes from non-overlapping nodes.
Definition: spill_tree.hpp:275
size_t NumDescendants() const
Return the number of descendants of this node.
Definition: spill_tree_impl.hpp:666
SpillTree * Parent() const
Gets the parent of this node.
Definition: spill_tree.hpp:267
SpillTree * Left() const
Gets the left child of this node.
Definition: spill_tree.hpp:257
bool IsLeaf() const
Return whether or not this node is a leaf (true if it has no children).
Definition: spill_tree_impl.hpp:433
size_t Descendant(const size_t index) const
Return the index (with reference to the dataset) of a particular descendant of this node...
Definition: spill_tree_impl.hpp:681
void Traverse(SpillTree &queryNode, SpillTree &referenceNode, const bool bruteForce=false)
Traverse the two trees.
Definition: spill_dual_tree_traverser_impl.hpp:49
size_t NumChildren() const
Return the number of children in this node.
Definition: spill_tree_impl.hpp:448
SpillDualTreeTraverser(RuleType &rule)
Instantiate the dual-tree traverser with the given rule set.
Definition: spill_dual_tree_traverser_impl.hpp:32
size_t Point(const size_t index) const
Return the index (with reference to the dataset) of a particular point in this node.
Definition: spill_tree_impl.hpp:705