12 #ifndef MLPACK_CORE_TREE_COVER_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP 13 #define MLPACK_CORE_TREE_COVER_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP 23 typename StatisticType,
25 typename RootPointPolicy
27 template<
typename RuleType>
36 typename StatisticType,
38 typename RootPointPolicy
40 template<
typename RuleType>
46 std::map<int, std::vector<DualCoverTreeMapEntry>, std::greater<int>> refMap;
48 DualCoverTreeMapEntry rootRefEntry;
50 rootRefEntry.referenceNode = &referenceNode;
53 rootRefEntry.score = rule.Score(queryNode, referenceNode);
54 rootRefEntry.baseCase = rule.BaseCase(queryNode.
Point(),
55 referenceNode.
Point());
56 rootRefEntry.traversalInfo = rule.TraversalInfo();
58 refMap[referenceNode.
Scale()].push_back(rootRefEntry);
65 typename StatisticType,
67 typename RootPointPolicy
69 template<
typename RuleType>
73 std::map<
int, std::vector<DualCoverTreeMapEntry>, std::greater<int>>&
76 if (referenceMap.size() == 0)
80 ReferenceRecursion(queryNode, referenceMap);
83 if (referenceMap.size() == 0)
88 if ((queryNode.
Scale() != INT_MIN) &&
89 (queryNode.
Scale() >= (*referenceMap.begin()).first))
96 for (
size_t i = 1; i < queryNode.
NumChildren(); ++i)
99 std::map<int, std::vector<DualCoverTreeMapEntry>, std::greater<int>>
102 PruneMap(queryNode.
Child(i), referenceMap, childMap);
105 std::map<int, std::vector<DualCoverTreeMapEntry>, std::greater<int>>
108 PruneMap(queryNode.
Child(0), referenceMap, selfChildMap);
112 if (queryNode.
Scale() != INT_MIN)
117 Log::Assert((*referenceMap.begin()).first == INT_MIN);
119 std::vector<DualCoverTreeMapEntry>& pointVector = referenceMap[INT_MIN];
121 for (
size_t i = 0; i < pointVector.size(); ++i)
124 const DualCoverTreeMapEntry& frame = pointVector[i];
126 CoverTree* refNode = frame.referenceNode;
139 rule.TraversalInfo() = frame.traversalInfo;
140 double score = rule.Score(queryNode, *refNode);
142 if (score == DBL_MAX)
149 rule.BaseCase(queryNode.
Point(), pointVector[i].referenceNode->Point());
155 typename StatisticType,
157 typename RootPointPolicy
159 template<
typename RuleType>
163 std::map<
int, std::vector<DualCoverTreeMapEntry>, std::greater<int>>&
165 std::map<
int, std::vector<DualCoverTreeMapEntry>, std::greater<int>>&
168 if (referenceMap.empty())
172 if (referenceMap.count(INT_MIN) == 1)
175 std::vector<DualCoverTreeMapEntry>& scaleVector = referenceMap[INT_MIN];
178 std::sort(scaleVector.begin(), scaleVector.end());
180 childMap[INT_MIN].reserve(scaleVector.size());
181 std::vector<DualCoverTreeMapEntry>& newScaleVector = childMap[INT_MIN];
184 for (
size_t j = 0; j < scaleVector.size(); ++j)
186 const DualCoverTreeMapEntry& frame = scaleVector[j];
189 CoverTree* refNode = frame.referenceNode;
192 rule.TraversalInfo() = frame.traversalInfo;
193 double score = rule.Score(queryNode, *refNode);
195 if (score == DBL_MAX)
203 const double baseCase = rule.BaseCase(queryNode.
Point(),
207 newScaleVector.push_back(frame);
208 newScaleVector.back().score = score;
209 newScaleVector.back().baseCase = baseCase;
210 newScaleVector.back().traversalInfo = rule.TraversalInfo();
214 if (newScaleVector.size() == 0)
215 childMap.erase(INT_MIN);
218 typename std::map<int, std::vector<DualCoverTreeMapEntry>,
219 std::greater<int>>::iterator it = referenceMap.begin();
221 while ((it != referenceMap.end()))
223 const int thisScale = (*it).first;
224 if (thisScale == INT_MIN)
228 std::vector<DualCoverTreeMapEntry>& scaleVector = (*it).second;
231 std::sort(scaleVector.begin(), scaleVector.end());
233 childMap[thisScale].reserve(scaleVector.size());
234 std::vector<DualCoverTreeMapEntry>& newScaleVector = childMap[thisScale];
237 for (
size_t j = 0; j < scaleVector.size(); ++j)
239 const DualCoverTreeMapEntry& frame = scaleVector[j];
242 CoverTree* refNode = frame.referenceNode;
245 rule.TraversalInfo() = frame.traversalInfo;
246 double score = rule.Score(queryNode, *refNode);
248 if (score == DBL_MAX)
256 const double baseCase = rule.BaseCase(queryNode.
Point(),
260 newScaleVector.push_back(frame);
261 newScaleVector.back().score = score;
262 newScaleVector.back().baseCase = baseCase;
263 newScaleVector.back().traversalInfo = rule.TraversalInfo();
267 if (newScaleVector.size() == 0)
268 childMap.erase((*it).first);
276 typename StatisticType,
278 typename RootPointPolicy
280 template<
typename RuleType>
284 std::map<
int, std::vector<DualCoverTreeMapEntry>, std::greater<int>>&
289 while (!referenceMap.empty())
291 const int maxScale = ((*referenceMap.begin()).first);
293 if (queryNode.
Parent() == NULL && maxScale < queryNode.
Scale())
295 if (queryNode.
Parent() != NULL && maxScale <= queryNode.
Scale())
299 if (queryNode.
Scale() == INT_MIN && maxScale == INT_MIN)
303 std::vector<DualCoverTreeMapEntry>& scaleVector = referenceMap[maxScale];
306 std::sort(scaleVector.begin(), scaleVector.end());
309 for (
size_t i = 0; i < scaleVector.size(); ++i)
312 const DualCoverTreeMapEntry& frame = scaleVector.at(i);
313 CoverTree* refNode = frame.referenceNode;
316 double score = rule.Rescore(queryNode, *refNode, frame.score);
320 if (score == DBL_MAX)
329 for (
size_t j = 0; j < refNode->
NumChildren(); ++j)
331 rule.TraversalInfo() = frame.traversalInfo;
332 double childScore = rule.Score(queryNode, refNode->
Child(j));
333 if (childScore == DBL_MAX)
340 const double baseCase = rule.BaseCase(queryNode.
Point(),
343 DualCoverTreeMapEntry newFrame;
344 newFrame.referenceNode = &refNode->
Child(j);
345 newFrame.score = childScore;
346 newFrame.baseCase = baseCase;
347 newFrame.traversalInfo = rule.TraversalInfo();
348 referenceMap[newFrame.referenceNode->Scale()].push_back(newFrame);
353 referenceMap.erase(maxScale);
int Scale() const
Get the scale of this node.
Definition: cover_tree.hpp:315
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Traverse(CoverTree &queryNode, CoverTree &referenceNode)
Traverse the two specified trees.
Definition: dual_tree_traverser_impl.hpp:42
DualTreeTraverser(RuleType &rule)
Initialize the dual tree traverser with the given rule type.
Definition: dual_tree_traverser_impl.hpp:29
size_t Point() const
Get the index of the point which this node represents.
Definition: cover_tree.hpp:286
const CoverTree & Child(const size_t index) const
Get a particular child node.
Definition: cover_tree.hpp:294
size_t NumChildren() const
Get the number of children.
Definition: cover_tree.hpp:301
A cover tree is a tree specifically designed to speed up nearest-neighbor computation in high-dimensional spaces.
Definition: cover_tree.hpp:99
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38
CoverTree * Parent() const
Get the parent node.
Definition: cover_tree.hpp:404