mlpack
dual_tree_traverser_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_TREE_COVER_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
13 #define MLPACK_CORE_TREE_COVER_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 #include <queue>
17 
18 namespace mlpack {
19 namespace tree {
20 
// Constructor of the cover tree dual-tree traverser. It initializes the
// `rule` member from the given rule and starts the prune counter
// (`numPrunes`) at zero.
// NOTE(review): the signature line(s) (source lines 28-29) are missing from
// this listing; per the symbol index at the end of this page the constructor
// is DualTreeTraverser(RuleType& rule), defined at line 29.
21 template<
22  typename MetricType,
23  typename StatisticType,
24  typename MatType,
25  typename RootPointPolicy
26 >
27 template<typename RuleType>
30  rule(rule),
31  numPrunes(0)
32 { /* Nothing to do. */ }
33 
// Entry point for a dual-tree traversal that starts at the roots of the query
// and reference trees.
// It builds a scale-keyed map of reference entries, ordered from largest
// scale to smallest via std::greater<int>. The map is seeded with a single
// entry for the reference root, carrying:
//   - the root-pair score,
//   - the root-root base case,
//   - the rule's traversal info after those calls.
// The map is then handed to the map-based Traverse() overload.
// NOTE(review): source lines 41-42 (return type and start of the parameter
// list) are missing from this listing; per the symbol index at the end of
// this page this is Traverse(CoverTree& queryNode, CoverTree& referenceNode),
// defined at line 42.
34 template<
35  typename MetricType,
36  typename StatisticType,
37  typename MatType,
38  typename RootPointPolicy
39 >
40 template<typename RuleType>
43  CoverTree& referenceNode)
44 {
45  // Start by creating a map and adding the reference root node to it.
46  std::map<int, std::vector<DualCoverTreeMapEntry>, std::greater<int>> refMap;
47 
48  DualCoverTreeMapEntry rootRefEntry;
49 
50  rootRefEntry.referenceNode = &referenceNode;
51 
52  // Perform the evaluation between the roots of both trees.
53  rootRefEntry.score = rule.Score(queryNode, referenceNode);
54  rootRefEntry.baseCase = rule.BaseCase(queryNode.Point(),
55  referenceNode.Point());
// Save the rule's traversal info as it stands after scoring the roots, so it
// can be restored (rule.TraversalInfo() = frame.traversalInfo) when this
// entry is processed later in the traversal.
56  rootRefEntry.traversalInfo = rule.TraversalInfo();
57 
// File the root entry under the reference root's own scale.
58  refMap[referenceNode.Scale()].push_back(rootRefEntry);
59 
60  Traverse(queryNode, refMap);
61 }
62 
// Recursive step of the dual-tree traversal for one query node and the map
// of reference entries (keyed by scale, largest first) still alive for it.
// The steps are:
//   1. ReferenceRecursion() descends the reference side until the largest
//      reference scale is low enough for this query node.
//   2. If the query node is not a leaf (scale != INT_MIN) and its scale is at
//      least the largest reference scale, recurse into its children. Children
//      1..n-1 go first and the self child (index 0) last. Each child gets its
//      own pruned copy of the map, built by PruneMap().
//   3. Non-leaf query nodes stop here.
//   4. For a leaf query node, every remaining INT_MIN reference entry is
//      handled in one of three ways:
//        - skipped, if both nodes share their parents' points;
//        - pruned, if Score() returns DBL_MAX;
//        - otherwise, its base case is evaluated.
// NOTE(review): source lines 70-71 (return type and function name) are
// missing from this listing; presumably this is the
// Traverse(queryNode, referenceMap) overload called by the root Traverse().
63 template<
64  typename MetricType,
65  typename StatisticType,
66  typename MatType,
67  typename RootPointPolicy
68 >
69 template<typename RuleType>
72  CoverTree& queryNode,
73  std::map<int, std::vector<DualCoverTreeMapEntry>, std::greater<int>>&
74  referenceMap)
75 {
76  if (referenceMap.size() == 0)
77  return; // Nothing to do!
78 
79  // First recurse down the reference nodes as necessary.
80  ReferenceRecursion(queryNode, referenceMap);
81 
82  // Did the map get emptied?
83  if (referenceMap.size() == 0)
84  return; // Nothing to do!
85 
86  // Now, reduce the scale of the query node by recursing. But we can't recurse
87  // if the query node is a leaf node.
// (*referenceMap.begin()).first is the largest reference scale, because the
// map is ordered with std::greater<int>.
88  if ((queryNode.Scale() != INT_MIN) &&
89  (queryNode.Scale() >= (*referenceMap.begin()).first))
90  {
91  // Recurse into the non-self-children first. The intent is that recursion
92  // order should not affect the runtime of the algorithm, because each query
93  // child recursion's results are separate and independent. This may not
94  // hold in every case, so this section may need to take scores into
95  // account in the future.
96  for (size_t i = 1; i < queryNode.NumChildren(); ++i)
97  {
98  // We need a copy of the map for this child.
99  std::map<int, std::vector<DualCoverTreeMapEntry>, std::greater<int>>
100  childMap;
101 
102  PruneMap(queryNode.Child(i), referenceMap, childMap);
103  Traverse(queryNode.Child(i), childMap);
104  }
// The self child (index 0) is handled last, with its own pruned map.
105  std::map<int, std::vector<DualCoverTreeMapEntry>, std::greater<int>>
106  selfChildMap;
107 
108  PruneMap(queryNode.Child(0), referenceMap, selfChildMap);
109  Traverse(queryNode.Child(0), selfChildMap);
110  }
111 
112  if (queryNode.Scale() != INT_MIN)
113  return; // No need to evaluate base cases at this level. It's all done.
114 
115  // If we have made it this far, all we have is a bunch of base case
116  // evaluations to do.
// NOTE(review): if the first assertion does not hold, referenceMap[INT_MIN]
// below default-inserts an empty vector and the loop simply does nothing.
117  Log::Assert((*referenceMap.begin()).first == INT_MIN);
118  Log::Assert(queryNode.Scale() == INT_MIN);
119  std::vector<DualCoverTreeMapEntry>& pointVector = referenceMap[INT_MIN];
120 
121  for (size_t i = 0; i < pointVector.size(); ++i)
122  {
123  // Get a reference to the frame.
124  const DualCoverTreeMapEntry& frame = pointVector[i];
125 
126  CoverTree* refNode = frame.referenceNode;
127 
128  // If both nodes have the same point as their parents, then this base case
129  // has already been done.
// NOTE(review): Parent() is dereferenced here without a NULL check, yet
// ReferenceRecursion() tests queryNode.Parent() == NULL, so parentless
// nodes exist. A root that is also a leaf (scale INT_MIN), on either the
// reference or the query side, would be dereferenced as NULL here. Confirm
// that such roots can never reach this loop, or guard with Parent() != NULL.
130  if ((refNode->Point() == refNode->Parent()->Point()) &&
131  (queryNode.Point() == queryNode.Parent()->Point()))
132  {
133  ++numPrunes;
134  continue;
135  }
136 
137  // Score the node, to see if we can prune it, after restoring the traversal
138  // info.
139  rule.TraversalInfo() = frame.traversalInfo;
140  double score = rule.Score(queryNode, *refNode);
141 
142  if (score == DBL_MAX)
143  {
144  ++numPrunes;
145  continue;
146  }
147 
148  // If not, compute the base case.
// pointVector[i].referenceNode is the same pointer as refNode above.
149  rule.BaseCase(queryNode.Point(), pointVector[i].referenceNode->Point());
150  }
151 }
152 
// Builds childMap: the copy of referenceMap that is used to traverse the
// query child `queryNode`.
// Each entry of referenceMap is rescored against queryNode after its saved
// traversal info is restored:
//   - entries scoring DBL_MAX are dropped and counted in numPrunes;
//   - survivors have their base case evaluated and are copied into childMap
//     with the new score, base case and traversal info.
// Scales that end up with no survivors are erased from childMap.
// Side effect: every per-scale vector of referenceMap is sorted in place.
// NOTE(review): source lines 160-161 (return type and function name) are
// missing from this listing; presumably this is the
// PruneMap(queryChild, referenceMap, childMap) member called from Traverse().
// NOTE(review): the INT_MIN block and the body of the per-scale loop below
// are identical; a shared helper would remove the duplication.
153 template<
154  typename MetricType,
155  typename StatisticType,
156  typename MatType,
157  typename RootPointPolicy
158 >
159 template<typename RuleType>
162  CoverTree& queryNode,
163  std::map<int, std::vector<DualCoverTreeMapEntry>, std::greater<int>>&
164  referenceMap,
165  std::map<int, std::vector<DualCoverTreeMapEntry>, std::greater<int>>&
166  childMap)
167 {
168  if (referenceMap.empty())
169  return; // Nothing to do.
170 
171  // Copy the zero set first.
172  if (referenceMap.count(INT_MIN) == 1)
173  {
174  // Get a reference to the vector representing the entries at this scale.
175  std::vector<DualCoverTreeMapEntry>& scaleVector = referenceMap[INT_MIN];
176 
177  // Before traversing all the points in this scale, sort by score.
178  std::sort(scaleVector.begin(), scaleVector.end());
179 
// operator[] creates the INT_MIN entry in childMap; it is erased again
// below if nothing survives.
180  childMap[INT_MIN].reserve(scaleVector.size());
181  std::vector<DualCoverTreeMapEntry>& newScaleVector = childMap[INT_MIN];
182 
183  // Loop over each entry in the vector.
184  for (size_t j = 0; j < scaleVector.size(); ++j)
185  {
186  const DualCoverTreeMapEntry& frame = scaleVector[j];
187 
188  // First evaluate if we can prune without performing the base case.
189  CoverTree* refNode = frame.referenceNode;
190 
191  // Perform the actual scoring, after restoring the traversal info.
192  rule.TraversalInfo() = frame.traversalInfo;
193  double score = rule.Score(queryNode, *refNode);
194 
195  if (score == DBL_MAX)
196  {
197  // Pruned. Move on.
198  ++numPrunes;
199  continue;
200  }
201 
202  // If it isn't pruned, we must evaluate the base case.
203  const double baseCase = rule.BaseCase(queryNode.Point(),
204  refNode->Point());
205 
206  // Add to child map.
207  newScaleVector.push_back(frame);
208  newScaleVector.back().score = score;
209  newScaleVector.back().baseCase = baseCase;
210  newScaleVector.back().traversalInfo = rule.TraversalInfo();
211  }
212 
213  // If we didn't add anything, then strike this vector from the map.
214  if (newScaleVector.size() == 0)
215  childMap.erase(INT_MIN);
216  }
217 
// Walk the remaining scales, largest first (std::greater<int>). INT_MIN
// sorts last, so reaching it means every other scale has been handled.
218  typename std::map<int, std::vector<DualCoverTreeMapEntry>,
219  std::greater<int>>::iterator it = referenceMap.begin();
220 
221  while ((it != referenceMap.end()))
222  {
223  const int thisScale = (*it).first;
224  if (thisScale == INT_MIN) // We already did it.
225  break;
226 
227  // Get a reference to the vector representing the entries at this scale.
228  std::vector<DualCoverTreeMapEntry>& scaleVector = (*it).second;
229 
230  // Before traversing all the points in this scale, sort by score.
231  std::sort(scaleVector.begin(), scaleVector.end());
232 
// As above, operator[] creates the entry; it is erased below if empty.
233  childMap[thisScale].reserve(scaleVector.size());
234  std::vector<DualCoverTreeMapEntry>& newScaleVector = childMap[thisScale];
235 
236  // Loop over each entry in the vector.
237  for (size_t j = 0; j < scaleVector.size(); ++j)
238  {
239  const DualCoverTreeMapEntry& frame = scaleVector[j];
240 
241  // First evaluate if we can prune without performing the base case.
242  CoverTree* refNode = frame.referenceNode;
243 
244  // Perform the actual scoring, after restoring the traversal info.
245  rule.TraversalInfo() = frame.traversalInfo;
246  double score = rule.Score(queryNode, *refNode);
247 
248  if (score == DBL_MAX)
249  {
250  // Pruned. Move on.
251  ++numPrunes;
252  continue;
253  }
254 
255  // If it isn't pruned, we must evaluate the base case.
256  const double baseCase = rule.BaseCase(queryNode.Point(),
257  refNode->Point());
258 
259  // Add to child map.
260  newScaleVector.push_back(frame);
261  newScaleVector.back().score = score;
262  newScaleVector.back().baseCase = baseCase;
263  newScaleVector.back().traversalInfo = rule.TraversalInfo();
264  }
265 
266  // If we didn't add anything, then strike this vector from the map.
267  if (newScaleVector.size() == 0)
268  childMap.erase((*it).first);
269 
270  ++it; // Advance to next scale.
271  }
272 }
273 
// Descends the reference side of the traversal until the largest scale left
// in referenceMap is low enough for queryNode. Descent stops when:
//   - queryNode has no parent and the largest reference scale is strictly
//     below the query scale;
//   - queryNode has a parent and the largest reference scale is at or below
//     the query scale;
//   - both the query scale and the largest reference scale are INT_MIN;
//   - the map becomes empty.
// For each entry at the current largest scale (sorted first):
//   - the entry is rescored with rule.Rescore();
//   - if that returns DBL_MAX, all of its children are pruned at once;
//   - otherwise each child is scored. Surviving children have their base case
//     evaluated and are inserted into referenceMap under the child's scale.
// The processed scale is then erased from the map.
// NOTE(review): source lines 281-282 (return type and function name) are
// missing from this listing; presumably this is the
// ReferenceRecursion(queryNode, referenceMap) member called from Traverse().
274 template<
275  typename MetricType,
276  typename StatisticType,
277  typename MatType,
278  typename RootPointPolicy
279 >
280 template<typename RuleType>
283  CoverTree& queryNode,
284  std::map<int, std::vector<DualCoverTreeMapEntry>, std::greater<int>>&
285  referenceMap)
286 {
287  // First, reduce the maximum scale in the reference map down to the scale of
288  // the query node.
289  while (!referenceMap.empty())
290  {
291  const int maxScale = ((*referenceMap.begin()).first);
292  // Special case to imitate the JL cover tree: a query node with no parent
// (the root) stops once every reference scale is strictly below its own
// scale; any other query node stops once the reference scale is at or below
// its own scale.
293  if (queryNode.Parent() == NULL && maxScale < queryNode.Scale())
294  break;
295  if (queryNode.Parent() != NULL && maxScale <= queryNode.Scale())
296  break;
297  // If the query node's scale is INT_MIN and the reference map's maximum
298  // scale is INT_MIN, don't try to recurse...
// (Only a parentless query node needs this test; for any other query node
// the previous check already breaks in this case.)
299  if (queryNode.Scale() == INT_MIN && maxScale == INT_MIN)
300  break;
301 
302  // Get a reference to the current largest scale.
303  std::vector<DualCoverTreeMapEntry>& scaleVector = referenceMap[maxScale];
304 
305  // Before traversing all the points in this scale, sort by score.
306  std::sort(scaleVector.begin(), scaleVector.end());
307 
308  // Now loop over each element.
// NOTE(review): `frame` and this loop over scaleVector assume no child has
// scale == maxScale. Otherwise the push_back below would grow scaleVector
// itself, which may reallocate it and invalidate `frame`. Confirm that cover
// tree children always have a strictly smaller scale than their parent.
309  for (size_t i = 0; i < scaleVector.size(); ++i)
310  {
311  // Get a reference to the current element.
312  const DualCoverTreeMapEntry& frame = scaleVector.at(i);
313  CoverTree* refNode = frame.referenceNode;
314 
315  // Create the score for the children.
316  double score = rule.Rescore(queryNode, *refNode, frame.score);
317 
318  // Now if this score is DBL_MAX we can prune all children. In this
319  // recursion setup pruning is all or nothing for children.
320  if (score == DBL_MAX)
321  {
322  ++numPrunes;
323  continue;
324  }
325 
326  // If it is not pruned, score each child and evaluate the base case for
// the children that survive.
327 
328  // Add the children.
329  for (size_t j = 0; j < refNode->NumChildren(); ++j)
330  {
// Restore the parent's saved traversal info, so that every sibling is
// scored from the same starting state.
331  rule.TraversalInfo() = frame.traversalInfo;
332  double childScore = rule.Score(queryNode, refNode->Child(j));
333  if (childScore == DBL_MAX)
334  {
335  ++numPrunes;
336  continue;
337  }
338 
339  // It wasn't pruned; evaluate the base case.
340  const double baseCase = rule.BaseCase(queryNode.Point(),
341  refNode->Child(j).Point());
342 
343  DualCoverTreeMapEntry newFrame;
344  newFrame.referenceNode = &refNode->Child(j);
345  newFrame.score = childScore; // Use the score of the child.
346  newFrame.baseCase = baseCase;
347  newFrame.traversalInfo = rule.TraversalInfo();
348  referenceMap[newFrame.referenceNode->Scale()].push_back(newFrame);
349  }
350  }
351 
352  // Now clear the memory for this scale; it isn't needed anymore.
353  referenceMap.erase(maxScale);
354  }
355 }
356 
357 } // namespace tree
358 } // namespace mlpack
359 
360 #endif
int Scale() const
Get the scale of this node.
Definition: cover_tree.hpp:315
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Traverse(CoverTree &queryNode, CoverTree &referenceNode)
Traverse the two specified trees.
Definition: dual_tree_traverser_impl.hpp:42
DualTreeTraverser(RuleType &rule)
Initialize the dual tree traverser with the given rule type.
Definition: dual_tree_traverser_impl.hpp:29
size_t Point() const
Get the index of the point which this node represents.
Definition: cover_tree.hpp:286
const CoverTree & Child(const size_t index) const
Get a particular child node.
Definition: cover_tree.hpp:294
size_t NumChildren() const
Get the number of children.
Definition: cover_tree.hpp:301
A cover tree is a tree specifically designed to speed up nearest-neighbor computation in high-dimensional spaces.
Definition: cover_tree.hpp:99
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38
CoverTree * Parent() const
Get the parent node.
Definition: cover_tree.hpp:404