12 #ifndef MLPACK_CORE_TREE_OCTREE_DUAL_TREE_TRAVERSER_IMPL_HPP 13 #define MLPACK_CORE_TREE_OCTREE_DUAL_TREE_TRAVERSER_IMPL_HPP 21 template<
typename MetricType,
typename StatisticType,
typename MatType>
22 template<
typename RuleType>
23 Octree<MetricType, StatisticType, MatType>::DualTreeTraverser<RuleType>::
24 DualTreeTraverser(RuleType& rule) :
34 template<
typename MetricType,
typename StatisticType,
typename MatType>
35 template<
typename RuleType>
43 traversalInfo = rule.TraversalInfo();
46 if (queryNode.
Parent() == NULL && referenceNode.
Parent() == NULL)
48 const double rootScore = rule.Score(queryNode, referenceNode);
50 if (rootScore == DBL_MAX)
59 const size_t begin = queryNode.
Point(0);
60 const size_t end = begin + queryNode.
NumPoints();
61 for (
size_t q = begin; q < end; ++q)
64 rule.TraversalInfo() = traversalInfo;
65 const double score = rule.Score(q, referenceNode);
72 const size_t rBegin = referenceNode.
Point(0);
73 const size_t rEnd = rBegin + referenceNode.
NumPoints();
74 for (
size_t r = rBegin; r < rEnd; ++r)
77 numBaseCases += referenceNode.
NumPoints();
83 for (
size_t i = 0; i < queryNode.
NumChildren(); ++i)
85 rule.TraversalInfo() = traversalInfo;
86 const double score = rule.Score(queryNode.
Child(i), referenceNode);
101 std::vector<typename RuleType::TraversalInfoType>
103 for (
size_t i = 0; i < referenceNode.
NumChildren(); ++i)
105 rule.TraversalInfo() = traversalInfo;
106 scores[i] = rule.Score(queryNode, referenceNode.
Child(i));
107 tis[i] = rule.TraversalInfo();
111 arma::uvec scoreOrder = arma::sort_index(scores);
112 for (
size_t i = 0; i < scoreOrder.n_elem; ++i)
114 if (scores[scoreOrder[i]] == DBL_MAX)
117 numPrunes += scoreOrder.n_elem - i;
121 rule.TraversalInfo() = tis[scoreOrder[i]];
131 std::vector<typename RuleType::TraversalInfoType>
133 for (
size_t j = 0; j < queryNode.
NumChildren(); ++j)
137 for (
size_t i = 0; i < referenceNode.
NumChildren(); ++i)
139 rule.TraversalInfo() = traversalInfo;
140 scores[i] = rule.Score(queryNode.
Child(j), referenceNode.
Child(i));
141 tis[i] = rule.TraversalInfo();
145 arma::uvec scoreOrder = arma::sort_index(scores);
146 for (
size_t i = 0; i < scoreOrder.n_elem; ++i)
148 if (scores[scoreOrder[i]] == DBL_MAX)
152 numPrunes += scoreOrder.n_elem - i;
156 rule.TraversalInfo() = tis[scoreOrder[i]];
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
size_t NumPoints() const
Return the number of points in this node (0 if not a leaf).
Definition: octree_impl.hpp:627
size_t NumChildren() const
Return the number of children in this node.
Definition: octree_impl.hpp:509
void Traverse(Octree &queryNode, Octree &referenceNode)
Traverse the two trees.
Definition: dual_tree_traverser_impl.hpp:37
size_t Point(const size_t index) const
Return the index (with reference to the dataset) of a particular point in this node.
Definition: octree_impl.hpp:647
bool IsLeaf() const
Return whether or not the node is a leaf.
Definition: octree.hpp:297
const Octree & Child(const size_t child) const
Return the specified child.
Definition: octree.hpp:340
Definition: octree.hpp:25
Octree * Parent() const
Get the pointer to the parent.
Definition: octree.hpp:256