// NOTE(review): this text looks like a scrape of a rendered source listing:
// original line numbers are fused into the code and many lines (braces,
// signatures, loop bodies) are missing, so it does not compile as-is.
// The comments below describe only what the visible lines show.
12 #ifndef MLPACK_METHODS_FASTMKS_FASTMKS_RULES_IMPL_HPP 13 #define MLPACK_METHODS_FASTMKS_FASTMKS_RULES_IMPL_HPP 21 template<
typename KernelType,
typename TreeType>
// FastMKSRules constructor (the line naming it is not visible). It keeps the
// reference and query sets and precomputes a kernel-space norm per point.
23 const typename TreeType::Mat& referenceSet,
24 const typename TreeType::Mat& querySet,
27 referenceSet(referenceSet),
// -1 marks "no base case evaluated yet". NOTE(review): the member's type is
// not visible; if it is size_t, -1 converts to the maximum size_t value.
32 lastReferenceIndex(-1),
// queryKernels[i] = sqrt(K(q_i, ...)); the second Evaluate argument is on a
// line not shown, presumably querySet.col(i), mirroring the reference loop.
38 queryKernels.set_size(querySet.n_cols);
39 for (
size_t i = 0; i < querySet.n_cols; ++i)
40 queryKernels[i] = sqrt(kernel.Evaluate(querySet.col(i),
// referenceKernels[i] = sqrt(K(r_i, r_i)), i.e. the norm of r_i in the
// kernel's feature space (for a positive-definite kernel).
43 referenceKernels.set_size(referenceSet.n_cols);
44 for (
size_t i = 0; i < referenceSet.n_cols; ++i)
45 referenceKernels[i] = sqrt(kernel.Evaluate(referenceSet.col(i),
46 referenceSet.col(i)));
// Sentinel candidate: a kernel value of -DBL_MAX (below any real value) and
// an index of size_t() - 1, which wraps to the largest size_t.
57 const Candidate def = std::make_pair(-DBL_MAX,
size_t() - 1);
// The loop body is not visible; presumably it seeds pqueue with k copies of
// the sentinel, and tmp then gives every query point its own copy of it.
61 for (
size_t i = 0; i < k; ++i)
63 std::vector<CandidateList> tmp(querySet.n_cols, pqueue);
67 template<
typename KernelType,
typename TreeType>
// GetResults(indices, products): copy each query point's candidate list
// into indices and products, one column per query point and k rows.
69 arma::Mat<size_t>& indices,
72 indices.set_size(k, querySet.n_cols);
73 products.set_size(k, querySet.n_cols);
75 for (
size_t i = 0; i < querySet.n_cols; ++i)
// Non-const reference: presumably the queue is consumed (popped) as it is
// read.
77 CandidateList& pqueue = candidates[i];
// The queue's top goes into row k - j, so rows are filled from the bottom
// (k - 1) up to row 0. InsertNeighbor only accepts values greater than
// top(), which suggests top() is the smallest retained kernel; if each
// iteration pops (that line is not visible), row 0 ends up with the largest
// kernel value.
78 for (
size_t j = 1; j <= k; ++j)
80 indices(k - j, i) = pqueue.top().second;
81 products(k - j, i) = pqueue.top().first;
87 template<
typename KernelType,
typename TreeType>
// BaseCase(queryIndex, referenceIndex): evaluate the kernel between query
// point queryIndex and reference point referenceIndex, remember the result,
// and offer it as a candidate neighbor for the query point.
90 const size_t queryIndex,
91 const size_t referenceIndex)
// Same pair as the previous call. The statements that follow are not
// visible; presumably the cached lastKernel is returned without
// re-evaluating.
99 if ((queryIndex == lastQueryIndex) &&
100 (referenceIndex == lastReferenceIndex))
// Remember this pair so an immediate repeat can be detected above.
104 lastQueryIndex = queryIndex;
105 lastReferenceIndex = referenceIndex;
109 double kernelEval = kernel.Evaluate(querySet.col(queryIndex),
110 referenceSet.col(referenceIndex));
114 lastKernel = kernelEval;
// Self-match: true only when both sets are the same object (address
// comparison) and the indices agree. The guarded statement is not visible;
// presumably it avoids inserting a point as its own neighbor.
119 if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
122 InsertNeighbor(queryIndex, referenceIndex, kernelEval);
127 template<
typename KernelType,
typename TreeType>
// Score(queryIndex, referenceNode): compute an upper bound maxKernel on the
// kernel between the query point and points under referenceNode. Returns
// 1 / maxKernel, or DBL_MAX (prune) when the bound cannot reach the query's
// worst retained candidate.
129 TreeType& referenceNode)
// Worst kernel value currently retained for this query (the queue's top).
132 const double bestKernel = candidates[queryIndex].top().first;
135 const double furthestDist = referenceNode.FurthestDescendantDistance();
// Cheap pre-check before any new kernel evaluation: it combines the
// parent's cached LastKernel with parentDist + furthestDist.
136 if (referenceNode.Parent() != NULL)
138 double maxKernelBound;
139 const double parentDist = referenceNode.ParentDistance();
140 const double combinedDistBound = parentDist + furthestDist;
141 const double lastKernel = referenceNode.Parent()->Stat().LastKernel();
// Presumably the normalized-kernel branch (the condition is not visible).
// delta^2 + gamma^2 == 1, so with lastKernel = cos(a) and delta = cos(b) the
// bound below is cos(a - b). Once lastKernel > delta, the bound saturates
// at 1.
144 const double squaredDist = std::pow(combinedDistBound, 2.0);
145 const double delta = (1 - 0.5 * squaredDist);
146 if (lastKernel <= delta)
148 const double gamma = combinedDistBound * sqrt(1 - 0.25 * squaredDist);
149 maxKernelBound = lastKernel * delta +
150 gamma * sqrt(1 - std::pow(lastKernel, 2.0));
154 maxKernelBound = 1.0;
// General case: the cached kernel plus distance times the query's
// kernel-space norm.
159 maxKernelBound = lastKernel +
160 combinedDistBound * queryKernels[queryIndex];
// Guarded statement not visible; presumably prunes (returns DBL_MAX).
163 if (maxKernelBound < bestKernel)
// Reuse the parent's kernel value when this node's first point is also the
// parent's first point (the start of this condition is not visible).
175 referenceNode.Parent() != NULL &&
176 referenceNode.Point(0) == referenceNode.Parent()->Point(0))
178 kernelEval = referenceNode.Parent()->Stat().LastKernel();
182 kernelEval =
BaseCase(queryIndex, referenceNode.Point(0));
// In another branch (structure not visible), evaluate against the node's
// center instead.
188 referenceNode.Center(refCenter);
190 kernelEval = kernel.Evaluate(querySet.col(queryIndex), refCenter);
// Cache for this node's children, which read it via
// Parent()->Stat().LastKernel().
193 referenceNode.Stat().LastKernel() = kernelEval;
// Same bound as the pre-check, now around this node with radius
// furthestDist.
// NOTE(review): if rounding pushes kernelEval below -1, the sqrt below is
// NaN, maxKernel >= bestKernel is then false, and the node is pruned.
198 const double squaredDist = std::pow(furthestDist, 2.0);
199 const double delta = (1 - 0.5 * squaredDist);
200 if (kernelEval <= delta)
202 const double gamma = furthestDist * sqrt(1 - 0.25 * squaredDist);
203 maxKernel = kernelEval * delta +
204 gamma * sqrt(1 - std::pow(kernelEval, 2.0));
213 maxKernel = kernelEval + furthestDist * queryKernels[queryIndex];
// Larger bounds give smaller scores; presumably lower scores are recursed
// into first.
// NOTE(review): a negative maxKernel gives a negative score, which would
// sort ahead of every positive score. Confirm this case cannot arise.
218 return (maxKernel >= bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
221 template<
typename KernelType,
typename TreeType>
// Dual-tree Score(queryNode, referenceNode): compute an upper bound
// maxKernel on the kernel between descendants of the two nodes. Returns
// 1 / maxKernel, or DBL_MAX (prune) when the bound cannot reach the query
// node's bound.
223 TreeType& referenceNode)
// Recompute and cache the query node's bound; it is the pruning threshold.
226 queryNode.Stat().Bound() = CalculateBound(queryNode);
227 const double bestKernel = queryNode.Stat().Bound();
234 const double queryParentDist = queryNode.ParentDistance();
235 const double queryDescDist = queryNode.FurthestDescendantDistance();
236 const double refParentDist = referenceNode.ParentDistance();
237 const double refDescDist = referenceNode.FurthestDescendantDistance();
// Radii measured from each node's parent (parent distance + own radius).
240 const double queryDistBound = (queryParentDist + queryDescDist);
241 const double refDistBound = (refParentDist + refDescDist);
// Pre-check from cached values; most of it is not visible. Each side adds
// a term scaled by either its parent-based radius (...DistBound) or its own
// radius (...DescDist), chosen by conditions not shown. In some hidden
// cases adjustedScore is reset to bestKernel. Finally the product of the
// two chosen radii is added.
242 double dualQueryTerm;
260 adjustedScore += queryDistBound *
262 dualQueryTerm = queryDistBound;
270 adjustedScore += queryDescDist *
272 dualQueryTerm = queryDescDist;
279 adjustedScore = bestKernel;
288 adjustedScore += refDistBound *
290 dualRefTerm = refDistBound;
298 adjustedScore += refDescDist *
300 dualRefTerm = refDescDist;
307 adjustedScore = bestKernel;
312 adjustedScore += (dualQueryTerm * dualRefTerm);
// Guarded statement not visible; presumably prunes (returns DBL_MAX).
314 if (adjustedScore < bestKernel)
324 double kernelEval = 0.0;
// Reuse the previous base case when the traversal's last query node shares
// this node's first point. The rest of the condition is not visible;
// presumably it also checks the last reference node and reads
// traversalInfo.LastBaseCase().
331 (traversalInfo.
LastQueryNode()->Point(0) == queryNode.Point(0)) &&
339 lastQueryIndex = queryNode.Point(0);
340 lastReferenceIndex = referenceNode.Point(0);
347 kernelEval =
BaseCase(queryNode.Point(0), referenceNode.Point(0));
// In another branch, evaluate the kernel between the two node centers.
355 arma::vec queryCenter;
357 queryNode.Center(queryCenter);
358 referenceNode.Center(refCenter);
360 kernelEval = kernel.Evaluate(queryCenter, refCenter);
370 const double querySqDist = std::pow(queryDescDist, 2.0);
371 const double refSqDist = std::pow(refDescDist, 2.0);
372 const double bothSqDist = std::pow((queryDescDist + refDescDist), 2.0);
// Presumably the normalized-kernel branch (the condition is not visible).
// For each side delta^2 + gamma^2 == 1, so the two bracketed factors below
// are cos(bq + br) and sin(bq + br). With kernelEval = cos(a), maxKernel is
// cos(a - bq - br).
// NOTE(review): if rounding makes kernelEval < -1, the sqrt is NaN and
// maxKernel >= bestKernel is then false, so the pair is pruned.
374 if (kernelEval <= (1 - 0.5 * bothSqDist))
376 const double queryDelta = (1 - 0.5 * querySqDist);
377 const double queryGamma = queryDescDist * sqrt(1 - 0.25 * querySqDist);
378 const double refDelta = (1 - 0.5 * refSqDist);
379 const double refGamma = refDescDist * sqrt(1 - 0.25 * refSqDist);
381 maxKernel = kernelEval * (queryDelta * refDelta - queryGamma * refGamma) +
382 sqrt(1 - std::pow(kernelEval, 2.0)) *
383 (queryGamma * refDelta + queryDelta * refGamma);
// General case: the center-to-center kernel, plus the query radius times
// the reference node's SelfKernel(), the reference radius times the query
// node's SelfKernel(), and the product of the two radii.
393 const double refKernelTerm = queryDescDist *
394 referenceNode.Stat().SelfKernel();
395 const double queryKernelTerm = refDescDist * queryNode.Stat().SelfKernel();
397 maxKernel = kernelEval + refKernelTerm + queryKernelTerm +
398 (queryDescDist * refDescDist);
// 1 / bound, or DBL_MAX to prune.
// NOTE(review): a negative maxKernel yields a negative score, which would
// sort ahead of every positive score.
407 return (maxKernel >= bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
410 template<
typename KernelType,
typename TreeType>
// Single-tree Rescore(queryIndex, referenceNode, oldScore) const; the
// leading signature lines are not visible.
413 const double oldScore)
const 415 const double bestKernel = candidates[queryIndex].top().first;
// Scores are 1 / maxKernel, so 1.0 / oldScore recovers the bound. Keep the
// old score while that bound still reaches the query's worst retained
// kernel (the queue's top); otherwise prune with DBL_MAX.
417 return ((1.0 / oldScore) >= bestKernel) ? oldScore : DBL_MAX;
420 template<
typename KernelType,
typename TreeType>
// Dual-tree Rescore(queryNode, referenceNode, oldScore) const; the leading
// signature lines are not visible. Although the method is const, it writes
// queryNode's statistic; presumably queryNode is a non-const reference to a
// tree node that is not part of this object.
423 const double oldScore)
const 425 queryNode.Stat().Bound() = CalculateBound(queryNode);
426 const double bestKernel = queryNode.Stat().Bound();
// Scores are 1 / maxKernel, so 1.0 / oldScore recovers the bound. Keep the
// old score while it still reaches the freshly recomputed query-node bound;
// otherwise prune with DBL_MAX.
428 return ((1.0 / oldScore) >= bestKernel) ? oldScore : DBL_MAX;
438 template<
typename KernelType,
typename TreeType>
// Presumably CalculateBound(queryNode), which the dual-tree Score and
// Rescore call (its signature lines are not visible). Returns the kernel
// threshold that reference subtrees must be able to reach for queryNode.
451 double worstPointKernel = DBL_MAX;
452 double bestAdjustedPointKernel = -DBL_MAX;
454 const double queryDescendantDistance = queryNode.FurthestDescendantDistance();
// Scan the points held directly in this node.
459 for (
size_t i = 0; i < queryNode.NumPoints(); ++i)
461 const size_t point = queryNode.Point(i);
462 const CandidateList& candidatesPoints = candidates[point];
// Track the smallest queue top over the node's points.
463 if (candidatesPoints.top().first < worstPointKernel)
464 worstPointKernel = candidatesPoints.top().first;
// A top of -DBL_MAX means this point's list still holds a sentinel. The
// guarded lines are not visible; presumably the point is skipped. If top()
// is the list's smallest value, that skip also keeps the sentinel index
// (size_t() - 1) out of the referenceKernels[it->second] lookup below.
466 if (candidatesPoints.top().first == -DBL_MAX)
// For each candidate, subtract the query radius times the candidate's
// reference norm. Keep the smallest result for this point, and the largest
// such per-point minimum across points.
480 double worstPointCandidateKernel = DBL_MAX;
481 typedef typename CandidateList::const_iterator iter;
482 for (iter it = candidatesPoints.begin(); it != candidatesPoints.end(); ++it)
484 const double candidateKernel = it->first - queryDescendantDistance *
485 referenceKernels[it->second];
486 if (candidateKernel < worstPointCandidateKernel)
487 worstPointCandidateKernel = candidateKernel;
490 if (worstPointCandidateKernel > bestAdjustedPointKernel)
491 bestAdjustedPointKernel = worstPointCandidateKernel;
// Smallest cached bound among the children.
495 double worstChildKernel = DBL_MAX;
497 for (
size_t i = 0; i < queryNode.NumChildren(); ++i)
499 if (queryNode.Child(i).Stat().Bound() < worstChildKernel)
500 worstChildKernel = queryNode.Child(i).Stat().Bound();
// Result is the largest of:
//   min(worst point top, worst child bound),
//   the best adjusted per-point bound, and
//   the parent's cached bound (-DBL_MAX at the root).
504 const double firstBound = (worstPointKernel < worstChildKernel) ?
505 worstPointKernel : worstChildKernel;
508 const double fourthBound = (queryNode.Parent() == NULL) ? -DBL_MAX :
509 queryNode.Parent()->Stat().Bound();
512 const double interA = (firstBound > bestAdjustedPointKernel) ? firstBound :
513 bestAdjustedPointKernel;
514 const double interB = fourthBound;
516 return (interA > interB) ? interA : interB;
526 template<
typename KernelType,
typename TreeType>
// InsertNeighbor(queryIndex, index, product): offer reference point index,
// with kernel value product, as a candidate for query point queryIndex. The
// index parameter's line is not visible; BaseCase calls this with
// (queryIndex, referenceIndex, kernelEval).
528 const size_t queryIndex,
530 const double product)
532 CandidateList& pqueue = candidates[queryIndex];
// Only a value strictly greater than the queue's top is accepted; a tie
// with the current top is rejected. The update continues past the visible
// lines (presumably c replaces the current top).
533 if (product > pqueue.top().first)
535 Candidate c = std::make_pair(product, index);
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore) const
Re-evaluate the score for recursion order.
Definition: fastmks_rules_impl.hpp:411
This is a template class that can provide information about various kernels.
Definition: kernel_traits.hpp:27
void GetResults(arma::Mat< size_t > &indices, arma::mat &products)
Store the list of candidates for each query point in the given matrices.
Definition: fastmks_rules_impl.hpp:68
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
Compute the base case (kernel value) between two points.
Definition: fastmks_rules_impl.hpp:89
FastMKSRules(const typename TreeType::Mat &referenceSet, const typename TreeType::Mat &querySet, const size_t k, KernelType &kernel)
Construct the FastMKSRules object.
Definition: fastmks_rules_impl.hpp:22
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
Definition: fastmks_rules_impl.hpp:128
double LastBaseCase() const
Get the base case associated with the last node combination.
Definition: traversal_info.hpp:78
TreeType * LastQueryNode() const
Get the last query node.
Definition: traversal_info.hpp:63
The FastMKSRules class is a template helper class used by the FastMKS class when performing exact max-kernel search.
Definition: fastmks_rules.hpp:34
The TreeTraits class provides compile-time information on the characteristics of a given tree type...
Definition: tree_traits.hpp:77
TreeType * LastReferenceNode() const
Get the last reference node.
Definition: traversal_info.hpp:68