mlpack
fastmks_rules_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_FASTMKS_FASTMKS_RULES_IMPL_HPP
13 #define MLPACK_METHODS_FASTMKS_FASTMKS_RULES_IMPL_HPP
14 
15 // In case it hasn't already been included.
16 #include "fastmks_rules.hpp"
17 
18 namespace mlpack {
19 namespace fastmks {
20 
21 template<typename KernelType, typename TreeType>
23  const typename TreeType::Mat& referenceSet,
24  const typename TreeType::Mat& querySet,
25  const size_t k,
26  KernelType& kernel) :
27  referenceSet(referenceSet),
28  querySet(querySet),
29  k(k),
30  kernel(kernel),
31  lastQueryIndex(-1),
32  lastReferenceIndex(-1),
33  lastKernel(0.0),
34  baseCases(0),
35  scores(0)
36 {
37  // Precompute each self-kernel.
38  queryKernels.set_size(querySet.n_cols);
39  for (size_t i = 0; i < querySet.n_cols; ++i)
40  queryKernels[i] = sqrt(kernel.Evaluate(querySet.col(i),
41  querySet.col(i)));
42 
43  referenceKernels.set_size(referenceSet.n_cols);
44  for (size_t i = 0; i < referenceSet.n_cols; ++i)
45  referenceKernels[i] = sqrt(kernel.Evaluate(referenceSet.col(i),
46  referenceSet.col(i)));
47 
48  // Set to invalid memory, so that the first node combination does not try to
49  // dereference null pointers.
50  traversalInfo.LastQueryNode() = (TreeType*) this;
51  traversalInfo.LastReferenceNode() = (TreeType*) this;
52 
53  // Let's build the list of candidate points for each query point.
54  // It will be initialized with k candidates: (-DBL_MAX, size_t() - 1)
55  // The list of candidates will be updated when visiting new points with the
56  // BaseCase() method.
57  const Candidate def = std::make_pair(-DBL_MAX, size_t() - 1);
58 
59  CandidateList pqueue;
60  pqueue.reserve(k);
61  for (size_t i = 0; i < k; ++i)
62  pqueue.push(def);
63  std::vector<CandidateList> tmp(querySet.n_cols, pqueue);
64  candidates.swap(tmp);
65 }
66 
67 template<typename KernelType, typename TreeType>
69  arma::Mat<size_t>& indices,
70  arma::mat& products)
71 {
72  indices.set_size(k, querySet.n_cols);
73  products.set_size(k, querySet.n_cols);
74 
75  for (size_t i = 0; i < querySet.n_cols; ++i)
76  {
77  CandidateList& pqueue = candidates[i];
78  for (size_t j = 1; j <= k; ++j)
79  {
80  indices(k - j, i) = pqueue.top().second;
81  products(k - j, i) = pqueue.top().first;
82  pqueue.pop();
83  }
84  }
85 }
86 
87 template<typename KernelType, typename TreeType>
88 inline force_inline
90  const size_t queryIndex,
91  const size_t referenceIndex)
92 {
93  // Score() always happens before BaseCase() for a given node combination. For
94  // cover trees, the kernel evaluation between the two centroid points already
95  // happened. So we don't need to do it. Note that this optimizes out if the
96  // first conditional is false (its result is known at compile time).
98  {
99  if ((queryIndex == lastQueryIndex) &&
100  (referenceIndex == lastReferenceIndex))
101  return lastKernel;
102 
103  // Store new values.
104  lastQueryIndex = queryIndex;
105  lastReferenceIndex = referenceIndex;
106  }
107 
108  ++baseCases;
109  double kernelEval = kernel.Evaluate(querySet.col(queryIndex),
110  referenceSet.col(referenceIndex));
111 
112  // Update the last kernel value, if we need to.
114  lastKernel = kernelEval;
115 
116  // If the reference and query sets are identical, we still need to compute the
117  // base case (so that things can be bounded properly), but we won't add it to
118  // the results.
119  if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
120  return kernelEval;
121 
122  InsertNeighbor(queryIndex, referenceIndex, kernelEval);
123 
124  return kernelEval;
125 }
126 
127 template<typename KernelType, typename TreeType>
128 double FastMKSRules<KernelType, TreeType>::Score(const size_t queryIndex,
129  TreeType& referenceNode)
130 {
131  // Compare with the current best.
132  const double bestKernel = candidates[queryIndex].top().first;
133 
134  // See if we can perform a parent-child prune.
135  const double furthestDist = referenceNode.FurthestDescendantDistance();
136  if (referenceNode.Parent() != NULL)
137  {
138  double maxKernelBound;
139  const double parentDist = referenceNode.ParentDistance();
140  const double combinedDistBound = parentDist + furthestDist;
141  const double lastKernel = referenceNode.Parent()->Stat().LastKernel();
143  {
144  const double squaredDist = std::pow(combinedDistBound, 2.0);
145  const double delta = (1 - 0.5 * squaredDist);
146  if (lastKernel <= delta)
147  {
148  const double gamma = combinedDistBound * sqrt(1 - 0.25 * squaredDist);
149  maxKernelBound = lastKernel * delta +
150  gamma * sqrt(1 - std::pow(lastKernel, 2.0));
151  }
152  else
153  {
154  maxKernelBound = 1.0;
155  }
156  }
157  else
158  {
159  maxKernelBound = lastKernel +
160  combinedDistBound * queryKernels[queryIndex];
161  }
162 
163  if (maxKernelBound < bestKernel)
164  return DBL_MAX;
165  }
166 
167  // Calculate the maximum possible kernel value, either by calculating the
168  // centroid or, if the centroid is a point, use that.
169  ++scores;
170  double kernelEval;
172  {
173  // Could it be that this kernel evaluation has already been calculated?
175  referenceNode.Parent() != NULL &&
176  referenceNode.Point(0) == referenceNode.Parent()->Point(0))
177  {
178  kernelEval = referenceNode.Parent()->Stat().LastKernel();
179  }
180  else
181  {
182  kernelEval = BaseCase(queryIndex, referenceNode.Point(0));
183  }
184  }
185  else
186  {
187  arma::vec refCenter;
188  referenceNode.Center(refCenter);
189 
190  kernelEval = kernel.Evaluate(querySet.col(queryIndex), refCenter);
191  }
192 
193  referenceNode.Stat().LastKernel() = kernelEval;
194 
195  double maxKernel;
197  {
198  const double squaredDist = std::pow(furthestDist, 2.0);
199  const double delta = (1 - 0.5 * squaredDist);
200  if (kernelEval <= delta)
201  {
202  const double gamma = furthestDist * sqrt(1 - 0.25 * squaredDist);
203  maxKernel = kernelEval * delta +
204  gamma * sqrt(1 - std::pow(kernelEval, 2.0));
205  }
206  else
207  {
208  maxKernel = 1.0;
209  }
210  }
211  else
212  {
213  maxKernel = kernelEval + furthestDist * queryKernels[queryIndex];
214  }
215 
216  // We return the inverse of the maximum kernel so that larger kernels are
217  // recursed into first.
218  return (maxKernel >= bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
219 }
220 
221 template<typename KernelType, typename TreeType>
222 double FastMKSRules<KernelType, TreeType>::Score(TreeType& queryNode,
223  TreeType& referenceNode)
224 {
225  // Update and get the query node's bound.
226  queryNode.Stat().Bound() = CalculateBound(queryNode);
227  const double bestKernel = queryNode.Stat().Bound();
228 
229  // First, see if we can make a parent-child or parent-parent prune. These
230  // four bounds on the maximum kernel value are looser than the bound normally
231  // used, but they can prevent a base case from needing to be calculated.
232 
233  // Convenience caching so lines are shorter.
234  const double queryParentDist = queryNode.ParentDistance();
235  const double queryDescDist = queryNode.FurthestDescendantDistance();
236  const double refParentDist = referenceNode.ParentDistance();
237  const double refDescDist = referenceNode.FurthestDescendantDistance();
238  double adjustedScore = traversalInfo.LastBaseCase();
239 
240  const double queryDistBound = (queryParentDist + queryDescDist);
241  const double refDistBound = (refParentDist + refDescDist);
242  double dualQueryTerm;
243  double dualRefTerm;
244 
245  // The parent-child and parent-parent prunes work by applying the same pruning
246  // condition as when the parent node was used, except they are tighter because
247  // queryDistBound < queryNode.Parent()->FurthestDescendantDistance()
248  // and
249  // refDistBound < referenceNode.Parent()->FurthestDescendantDistance()
250  // so we construct the same bounds that were used when Score() was called with
251  // the parents, except with the tighter distance bounds. Sometimes this
252  // allows us to prune nodes without evaluating the base cases between them.
253  if (traversalInfo.LastQueryNode() == queryNode.Parent())
254  {
255  // We can assume that queryNode.Parent() != NULL, because at the root node
256  // combination, the traversalInfo.LastQueryNode() pointer will _not_ be
257  // NULL. We also should be guaranteed that
258  // traversalInfo.LastReferenceNode() is either the reference node or the
259  // parent of the reference node.
260  adjustedScore += queryDistBound *
261  traversalInfo.LastReferenceNode()->Stat().SelfKernel();
262  dualQueryTerm = queryDistBound;
263  }
264  else
265  {
266  // The query parent could be NULL, which does weird things and we have to
267  // consider.
268  if (traversalInfo.LastReferenceNode() != NULL)
269  {
270  adjustedScore += queryDescDist *
271  traversalInfo.LastReferenceNode()->Stat().SelfKernel();
272  dualQueryTerm = queryDescDist;
273  }
274  else
275  {
276  // This makes it so a child-parent (or parent-parent) prune is not
277  // possible.
278  dualQueryTerm = 0.0;
279  adjustedScore = bestKernel;
280  }
281  }
282 
283  if (traversalInfo.LastReferenceNode() == referenceNode.Parent())
284  {
285  // We can assume that referenceNode.Parent() != NULL, because at the root
286  // node combination, the traversalInfo.LastReferenceNode() pointer will
287  // _not_ be NULL.
288  adjustedScore += refDistBound *
289  traversalInfo.LastQueryNode()->Stat().SelfKernel();
290  dualRefTerm = refDistBound;
291  }
292  else
293  {
294  // The reference parent could be NULL, which does weird things and we have
295  // to consider.
296  if (traversalInfo.LastQueryNode() != NULL)
297  {
298  adjustedScore += refDescDist *
299  traversalInfo.LastQueryNode()->Stat().SelfKernel();
300  dualRefTerm = refDescDist;
301  }
302  else
303  {
304  // This makes it so a child-parent (or parent-parent) prune is not
305  // possible.
306  dualRefTerm = 0.0;
307  adjustedScore = bestKernel;
308  }
309  }
310 
311  // Now add the dual term.
312  adjustedScore += (dualQueryTerm * dualRefTerm);
313 
314  if (adjustedScore < bestKernel)
315  {
316  // It is not possible that this node combination can contain a point
317  // combination with kernel value better than the minimum kernel value to
318  // improve any of the results, so we can prune it.
319  return DBL_MAX;
320  }
321 
322  // We were unable to perform a parent-child or parent-parent prune, so now we
323  // must calculate kernel evaluation, if necessary.
324  double kernelEval = 0.0;
326  {
327  // For this type of tree, we may have already calculated the base case in
328  // the parents.
329  if ((traversalInfo.LastQueryNode() != NULL) &&
330  (traversalInfo.LastReferenceNode() != NULL) &&
331  (traversalInfo.LastQueryNode()->Point(0) == queryNode.Point(0)) &&
332  (traversalInfo.LastReferenceNode()->Point(0) == referenceNode.Point(0)))
333  {
334  // Base case already done.
335  kernelEval = traversalInfo.LastBaseCase();
336 
337  // When BaseCase() is called after Score(), these must be correct so that
338  // another kernel evaluation is not performed.
339  lastQueryIndex = queryNode.Point(0);
340  lastReferenceIndex = referenceNode.Point(0);
341  }
342  else
343  {
344  // The kernel must be evaluated, but it is between points in the dataset,
345  // so we can call BaseCase(). BaseCase() will set lastQueryIndex and
346  // lastReferenceIndex correctly.
347  kernelEval = BaseCase(queryNode.Point(0), referenceNode.Point(0));
348  }
349 
350  traversalInfo.LastBaseCase() = kernelEval;
351  }
352  else
353  {
354  // Calculate the maximum possible kernel value.
355  arma::vec queryCenter;
356  arma::vec refCenter;
357  queryNode.Center(queryCenter);
358  referenceNode.Center(refCenter);
359 
360  kernelEval = kernel.Evaluate(queryCenter, refCenter);
361 
362  traversalInfo.LastBaseCase() = kernelEval;
363  }
364  ++scores;
365 
366  double maxKernel;
368  {
369  // We have a tighter bound for normalized kernels.
370  const double querySqDist = std::pow(queryDescDist, 2.0);
371  const double refSqDist = std::pow(refDescDist, 2.0);
372  const double bothSqDist = std::pow((queryDescDist + refDescDist), 2.0);
373 
374  if (kernelEval <= (1 - 0.5 * bothSqDist))
375  {
376  const double queryDelta = (1 - 0.5 * querySqDist);
377  const double queryGamma = queryDescDist * sqrt(1 - 0.25 * querySqDist);
378  const double refDelta = (1 - 0.5 * refSqDist);
379  const double refGamma = refDescDist * sqrt(1 - 0.25 * refSqDist);
380 
381  maxKernel = kernelEval * (queryDelta * refDelta - queryGamma * refGamma) +
382  sqrt(1 - std::pow(kernelEval, 2.0)) *
383  (queryGamma * refDelta + queryDelta * refGamma);
384  }
385  else
386  {
387  maxKernel = 1.0;
388  }
389  }
390  else
391  {
392  // Use standard bound; kernel is not normalized.
393  const double refKernelTerm = queryDescDist *
394  referenceNode.Stat().SelfKernel();
395  const double queryKernelTerm = refDescDist * queryNode.Stat().SelfKernel();
396 
397  maxKernel = kernelEval + refKernelTerm + queryKernelTerm +
398  (queryDescDist * refDescDist);
399  }
400 
401  // Store relevant information for parent-child pruning.
402  traversalInfo.LastQueryNode() = &queryNode;
403  traversalInfo.LastReferenceNode() = &referenceNode;
404 
405  // We return the inverse of the maximum kernel so that larger kernels are
406  // recursed into first.
407  return (maxKernel >= bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
408 }
409 
410 template<typename KernelType, typename TreeType>
411 double FastMKSRules<KernelType, TreeType>::Rescore(const size_t queryIndex,
412  TreeType& /*referenceNode*/,
413  const double oldScore) const
414 {
415  const double bestKernel = candidates[queryIndex].top().first;
416 
417  return ((1.0 / oldScore) >= bestKernel) ? oldScore : DBL_MAX;
418 }
419 
420 template<typename KernelType, typename TreeType>
422  TreeType& /*referenceNode*/,
423  const double oldScore) const
424 {
425  queryNode.Stat().Bound() = CalculateBound(queryNode);
426  const double bestKernel = queryNode.Stat().Bound();
427 
428  return ((1.0 / oldScore) >= bestKernel) ? oldScore : DBL_MAX;
429 }
430 
438 template<typename KernelType, typename TreeType>
439 double FastMKSRules<KernelType, TreeType>::CalculateBound(TreeType& queryNode)
440  const
441 {
442  // We have four possible bounds -- just like NeighborSearchRules, but they are
443  // slightly different in this context.
444  //
445  // (1) min ( min_{all points p in queryNode} P_p[k],
446  // min_{all children c in queryNode} B(c) );
447  // (2) max_{all points p in queryNode} P_p[k] + (worst child distance + worst
448  // descendant distance) sqrt(K(I_p[k], I_p[k]));
449  // (3) max_{all children c in queryNode} B(c) + <-- not done yet. ignored.
450  // (4) B(parent of queryNode);
451  double worstPointKernel = DBL_MAX;
452  double bestAdjustedPointKernel = -DBL_MAX;
453 
454  const double queryDescendantDistance = queryNode.FurthestDescendantDistance();
455 
456  // Loop over all points in this node to find the worst max-kernel value and
457  // the best possible adjusted max-kernel value that could be held by any
458  // descendant.
459  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
460  {
461  const size_t point = queryNode.Point(i);
462  const CandidateList& candidatesPoints = candidates[point];
463  if (candidatesPoints.top().first < worstPointKernel)
464  worstPointKernel = candidatesPoints.top().first;
465 
466  if (candidatesPoints.top().first == -DBL_MAX)
467  continue; // Avoid underflow.
468 
469  // This should be (queryDescendantDistance + centroidDistance) for any tree
470  // but it works for cover trees since centroidDistance = 0 for cover trees.
471  // The formulation here is slightly different than in Equation 43 of
472  // "Dual-tree fast exact max-kernel search". Because we could be searching
473  // for k max kernels and not just one, the bound for this point must
474  // actually be the minimum adjusted kernel of all k candidate kernels.
475  // So,
476  // B(N_q) = min_{1 \le j \le k} k_j^*(p_q) -
477  // \lambda_q \sqrt(K(p_j^*(p_q), p_j^*(p_q)))
478  // where p_j^*(p_q) is the j'th kernel candidate for query point p_q and
479  // k_j^*(p_q) is K(p_q, p_j^*(p_q)).
480  double worstPointCandidateKernel = DBL_MAX;
481  typedef typename CandidateList::const_iterator iter;
482  for (iter it = candidatesPoints.begin(); it != candidatesPoints.end(); ++it)
483  {
484  const double candidateKernel = it->first - queryDescendantDistance *
485  referenceKernels[it->second];
486  if (candidateKernel < worstPointCandidateKernel)
487  worstPointCandidateKernel = candidateKernel;
488  }
489 
490  if (worstPointCandidateKernel > bestAdjustedPointKernel)
491  bestAdjustedPointKernel = worstPointCandidateKernel;
492  }
493 
494  // Loop over all the children in the node.
495  double worstChildKernel = DBL_MAX;
496 
497  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
498  {
499  if (queryNode.Child(i).Stat().Bound() < worstChildKernel)
500  worstChildKernel = queryNode.Child(i).Stat().Bound();
501  }
502 
503  // Now assemble bound (1).
504  const double firstBound = (worstPointKernel < worstChildKernel) ?
505  worstPointKernel : worstChildKernel;
506 
507  // Bound (2) is bestAdjustedPointKernel.
508  const double fourthBound = (queryNode.Parent() == NULL) ? -DBL_MAX :
509  queryNode.Parent()->Stat().Bound();
510 
511  // Pick the best of these bounds.
512  const double interA = (firstBound > bestAdjustedPointKernel) ? firstBound :
513  bestAdjustedPointKernel;
514  const double interB = fourthBound;
515 
516  return (interA > interB) ? interA : interB;
517 }
518 
526 template<typename KernelType, typename TreeType>
528  const size_t queryIndex,
529  const size_t index,
530  const double product)
531 {
532  CandidateList& pqueue = candidates[queryIndex];
533  if (product > pqueue.top().first)
534  {
535  Candidate c = std::make_pair(product, index);
536  pqueue.pop();
537  pqueue.push(c);
538  }
539 }
540 
541 } // namespace fastmks
542 } // namespace mlpack
543 
544 #endif
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore) const
Re-evaluate the score for recursion order.
Definition: fastmks_rules_impl.hpp:411
This is a template class that can provide information about various kernels.
Definition: kernel_traits.hpp:27
void GetResults(arma::Mat< size_t > &indices, arma::mat &products)
Store the list of candidates for each query point in the given matrices.
Definition: fastmks_rules_impl.hpp:68
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
Compute the base case (kernel value) between two points.
Definition: fastmks_rules_impl.hpp:89
FastMKSRules(const typename TreeType::Mat &referenceSet, const typename TreeType::Mat &querySet, const size_t k, KernelType &kernel)
Construct the FastMKSRules object.
Definition: fastmks_rules_impl.hpp:22
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
Definition: fastmks_rules_impl.hpp:128
double LastBaseCase() const
Get the base case associated with the last node combination.
Definition: traversal_info.hpp:78
TreeType * LastQueryNode() const
Get the last query node.
Definition: traversal_info.hpp:63
The FastMKSRules class is a template helper class used by the FastMKS class when performing exact max-kernel search.
Definition: fastmks_rules.hpp:34
The TreeTraits class provides compile-time information on the characteristics of a given tree type...
Definition: tree_traits.hpp:77
TreeType * LastReferenceNode() const
Get the last reference node.
Definition: traversal_info.hpp:68