14 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP 15 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP 23 template<
typename MetricType,
24 typename StatisticType,
26 template<
typename BoundMetricType,
typename...>
class BoundType,
27 template<
typename SplitBoundType,
typename SplitMatType>
29 template<
typename RuleType>
39 template<
typename MetricType,
40 typename StatisticType,
42 template<
typename BoundMetricType,
typename...>
class BoundType,
43 template<
typename SplitBoundType,
typename SplitMatType>
45 template<
typename RuleType>
57 traversalInfo = rule.TraversalInfo();
60 if (queryNode.
Parent() == NULL && referenceNode.
Parent() == NULL)
62 const double rootScore = rule.Score(queryNode, referenceNode);
64 if (rootScore == DBL_MAX)
75 const size_t queryEnd = queryNode.
Begin() + queryNode.
Count();
76 const size_t refEnd = referenceNode.
Begin() + referenceNode.
Count();
77 for (
size_t query = queryNode.
Begin(); query < queryEnd; ++query)
82 rule.TraversalInfo() = traversalInfo;
83 const double childScore = rule.Score(query, referenceNode);
85 if (childScore == DBL_MAX)
88 for (
size_t ref = referenceNode.
Begin(); ref < refEnd; ++ref)
89 rule.BaseCase(query, ref);
91 numBaseCases += referenceNode.
Count();
94 else if (((!queryNode.
IsLeaf()) && referenceNode.
IsLeaf()) ||
100 const double leftScore = rule.Score(*queryNode.
Left(), referenceNode);
103 if (leftScore != DBL_MAX)
109 rule.TraversalInfo() = traversalInfo;
110 const double rightScore = rule.Score(*queryNode.
Right(), referenceNode);
113 if (rightScore != DBL_MAX)
118 else if (queryNode.
IsLeaf() && (!referenceNode.
IsLeaf()))
123 double leftScore = rule.Score(queryNode, *referenceNode.
Left());
124 typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
125 rule.TraversalInfo() = traversalInfo;
126 double rightScore = rule.Score(queryNode, *referenceNode.
Right());
129 if (leftScore < rightScore)
133 traversalInfo = rule.TraversalInfo();
134 rule.TraversalInfo() = leftInfo;
138 rightScore = rule.Rescore(queryNode, *referenceNode.
Right(), rightScore);
140 if (rightScore != DBL_MAX)
143 rule.TraversalInfo() = traversalInfo;
149 else if (rightScore < leftScore)
155 leftScore = rule.Rescore(queryNode, *referenceNode.
Left(), leftScore);
157 if (leftScore != DBL_MAX)
160 rule.TraversalInfo() = leftInfo;
168 if (leftScore == DBL_MAX)
176 traversalInfo = rule.TraversalInfo();
177 rule.TraversalInfo() = leftInfo;
180 rightScore = rule.Rescore(queryNode, *referenceNode.
Right(),
183 if (rightScore != DBL_MAX)
186 rule.TraversalInfo() = traversalInfo;
200 double leftScore = rule.Score(*queryNode.
Left(), *referenceNode.
Left());
201 typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
202 rule.TraversalInfo() = traversalInfo;
203 double rightScore = rule.Score(*queryNode.
Left(), *referenceNode.
Right());
204 typename RuleType::TraversalInfoType rightInfo;
207 if (leftScore < rightScore)
211 rightInfo = rule.TraversalInfo();
212 rule.TraversalInfo() = leftInfo;
216 rightScore = rule.Rescore(*queryNode.
Left(), *referenceNode.
Right(),
219 if (rightScore != DBL_MAX)
222 rule.TraversalInfo() = rightInfo;
228 else if (rightScore < leftScore)
234 leftScore = rule.Rescore(*queryNode.
Left(), *referenceNode.
Left(),
237 if (leftScore != DBL_MAX)
240 rule.TraversalInfo() = leftInfo;
248 if (leftScore == DBL_MAX)
256 rightInfo = rule.TraversalInfo();
257 rule.TraversalInfo() = leftInfo;
261 rightScore = rule.Rescore(*queryNode.
Left(), *referenceNode.
Right(),
264 if (rightScore != DBL_MAX)
267 rule.TraversalInfo() = rightInfo;
276 rule.TraversalInfo() = traversalInfo;
279 leftScore = rule.Score(*queryNode.
Right(), *referenceNode.
Left());
280 leftInfo = rule.TraversalInfo();
281 rule.TraversalInfo() = traversalInfo;
282 rightScore = rule.Score(*queryNode.
Right(), *referenceNode.
Right());
285 if (leftScore < rightScore)
289 rightInfo = rule.TraversalInfo();
290 rule.TraversalInfo() = leftInfo;
294 rightScore = rule.Rescore(*queryNode.
Right(), *referenceNode.
Right(),
297 if (rightScore != DBL_MAX)
300 rule.TraversalInfo() = rightInfo;
306 else if (rightScore < leftScore)
312 leftScore = rule.Rescore(*queryNode.
Right(), *referenceNode.
Left(),
315 if (leftScore != DBL_MAX)
318 rule.TraversalInfo() = leftInfo;
326 if (leftScore == DBL_MAX)
334 rightInfo = rule.TraversalInfo();
335 rule.TraversalInfo() = leftInfo;
339 rightScore = rule.Rescore(*queryNode.
Right(), *referenceNode.
Right(),
342 if (rightScore != DBL_MAX)
345 rule.TraversalInfo() = rightInfo;
358 #endif // MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP BinarySpaceTree * Parent() const
Gets the parent of this node.
Definition: binary_space_tree.hpp:342
size_t NumDescendants() const
Return the number of descendants of this node.
Definition: binary_space_tree_impl.hpp:826
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
size_t Count() const
Return the number of points in this subset.
Definition: binary_space_tree.hpp:503
void Traverse(BinarySpaceTree &queryNode, BinarySpaceTree &referenceNode)
Traverse the two trees.
Definition: dual_tree_traverser_impl.hpp:47
BinarySpaceTree * Right() const
Gets the right child of this node.
Definition: binary_space_tree.hpp:337
A binary space partitioning tree, such as a KD-tree or a ball tree.
Definition: binary_space_tree.hpp:54
DualTreeTraverser(RuleType &rule)
Instantiate the dual-tree traverser with the given rule set.
Definition: dual_tree_traverser_impl.hpp:31
BinarySpaceTree * Left() const
Gets the left child of this node.
Definition: binary_space_tree.hpp:332
bool IsLeaf() const
Return whether or not this node is a leaf (true if it has no children).
Definition: binary_space_tree_impl.hpp:594
size_t Begin() const
Return the index of the beginning point of this subset.
Definition: binary_space_tree.hpp:498