13 #ifndef MLPACK_METHODS_KDE_RULES_IMPL_HPP 14 #define MLPACK_METHODS_KDE_RULES_IMPL_HPP 20 #include <boost/math/distributions/normal.hpp> 25 template<
typename MetricType,
typename KernelType,
typename TreeType>
27 const arma::mat& referenceSet,
28 const arma::mat& querySet,
30 const double relError,
31 const double absError,
33 const size_t initialSampleSize,
34 const double mcAccessCoef,
35 const double mcBreakCoef,
38 const bool monteCarlo,
40 referenceSet(referenceSet),
46 initialSampleSize(initialSampleSize),
47 mcAccessCoef(mcAccessCoef),
48 mcBreakCoef(mcBreakCoef),
51 monteCarlo(monteCarlo),
53 absErrorTol(absError / referenceSet.n_cols),
54 lastQueryIndex(querySet.n_cols),
55 lastReferenceIndex(referenceSet.n_cols),
60 accumError = arma::vec(querySet.n_cols, arma::fill::zeros);
63 if (monteCarlo && kernelIsGaussian)
64 accumMCAlpha = arma::vec(querySet.n_cols, arma::fill::zeros);
68 template<
typename MetricType,
typename KernelType,
typename TreeType>
71 const size_t queryIndex,
72 const size_t referenceIndex)
76 if (sameSet && (queryIndex == referenceIndex))
80 if ((lastQueryIndex == queryIndex) && (lastReferenceIndex == referenceIndex))
84 const double distance = metric.Evaluate(querySet.col(queryIndex),
85 referenceSet.col(referenceIndex));
86 const double kernelValue = kernel.Evaluate(distance);
87 densities(queryIndex) += kernelValue;
90 accumError(queryIndex) += 2 * relError * kernelValue;
93 lastQueryIndex = queryIndex;
94 lastReferenceIndex = referenceIndex;
100 template<
typename MetricType,
typename KernelType,
typename TreeType>
102 Score(
const size_t queryIndex, TreeType& referenceNode)
105 const arma::vec& queryPoint = querySet.unsafe_col(queryIndex);
106 const size_t refNumDesc = referenceNode.NumDescendants();
107 double score, minDistance, maxDistance, depthAlpha;
109 bool alreadyDidRefPoint0 =
false;
112 if (monteCarlo && kernelIsGaussian)
113 depthAlpha = CalculateAlpha(&referenceNode);
118 lastQueryIndex == queryIndex &&
120 lastReferenceIndex == referenceNode.Point(0))
123 alreadyDidRefPoint0 =
true;
124 const double furthestDescDist = referenceNode.FurthestDescendantDistance();
125 minDistance = std::max(traversalInfo.
LastBaseCase() - furthestDescDist,
127 maxDistance = traversalInfo.
LastBaseCase() + furthestDescDist;
132 const math::Range r = referenceNode.RangeDistance(queryPoint);
133 minDistance = r.
Lo();
134 maxDistance = r.
Hi();
138 referenceNode.Parent() != NULL &&
139 referenceNode.Parent()->Point(0) == referenceNode.Point(0))
141 alreadyDidRefPoint0 =
true;
145 const double maxKernel = kernel.Evaluate(minDistance);
146 const double minKernel = kernel.Evaluate(maxDistance);
147 const double bound = maxKernel - minKernel;
151 const double relErrorTol = relError * minKernel;
152 const double errorTolerance = absErrorTol + relErrorTol;
157 double pointAccumErrorTol;
158 if (alreadyDidRefPoint0)
159 pointAccumErrorTol = accumError(queryIndex) / (refNumDesc - 1);
161 pointAccumErrorTol = accumError(queryIndex) / refNumDesc;
163 if (bound <= 2 * errorTolerance + pointAccumErrorTol)
166 const double kernelValue = (maxKernel + minKernel) / 2.0;
168 if (alreadyDidRefPoint0)
169 densities(queryIndex) += (refNumDesc - 1) * kernelValue;
171 densities(queryIndex) += refNumDesc * kernelValue;
178 if (alreadyDidRefPoint0)
179 accumError(queryIndex) -= (refNumDesc - 1) * (bound - 2 * errorTolerance);
181 accumError(queryIndex) -= refNumDesc * (bound - 2 * errorTolerance);
184 if (kernelIsGaussian && monteCarlo)
185 accumMCAlpha(queryIndex) += depthAlpha;
187 else if (monteCarlo &&
188 refNumDesc >= mcAccessCoef * initialSampleSize &&
193 const double alpha = depthAlpha + accumMCAlpha(queryIndex);
194 const boost::math::normal normalDist;
196 std::abs(boost::math::quantile(normalDist, alpha / 2));
200 size_t m = initialSampleSize;
201 double meanSample = 0;
202 bool useMonteCarloPredictions =
true;
207 const size_t oldSize = sample.size();
208 const size_t newSize = oldSize + m;
212 if (newSize >= mcBreakCoef * refNumDesc)
214 useMonteCarloPredictions =
false;
219 sample.resize(newSize);
220 for (
size_t i = 0; i < m; ++i)
224 if (alreadyDidRefPoint0)
229 sample(oldSize + i) =
230 EvaluateKernel(queryIndex, referenceNode.Descendant(randomPoint));
232 meanSample = arma::mean(sample);
233 const double stddev = arma::stddev(sample);
234 const double mThreshBase =
235 z * stddev * (1 + relError) / (relError * meanSample);
236 const size_t mThresh = std::ceil(mThreshBase * mThreshBase);
238 if (sample.size() < mThresh)
239 m = mThresh - sample.size();
244 if (useMonteCarloPredictions)
247 if (alreadyDidRefPoint0)
248 densities(queryIndex) += (refNumDesc - 1) * meanSample;
250 densities(queryIndex) += refNumDesc * meanSample;
256 accumMCAlpha(queryIndex) = 0;
263 if (referenceNode.IsLeaf())
266 accumMCAlpha(queryIndex) += depthAlpha;
276 if (referenceNode.IsLeaf())
278 if (alreadyDidRefPoint0)
279 accumError(queryIndex) += (refNumDesc - 1) * 2 * absErrorTol;
281 accumError(queryIndex) += refNumDesc * 2 * absErrorTol;
286 if (kernelIsGaussian && monteCarlo && referenceNode.IsLeaf())
287 accumMCAlpha(queryIndex) += depthAlpha;
296 template<
typename MetricType,
typename KernelType,
typename TreeType>
300 const double oldScore)
const 307 template<
typename MetricType,
typename KernelType,
typename TreeType>
309 Score(TreeType& queryNode, TreeType& referenceNode)
312 const size_t refNumDesc = referenceNode.NumDescendants();
313 double score, minDistance, maxDistance, depthAlpha;
315 bool alreadyDidRefPoint0 =
false;
318 if (monteCarlo && kernelIsGaussian)
319 depthAlpha = CalculateAlpha(&referenceNode);
325 const bool canReclaimAlpha = kernelIsGaussian &&
327 referenceNode.IsLeaf() &&
333 (traversalInfo.
LastQueryNode()->Point(0) == queryNode.Point(0)) &&
337 alreadyDidRefPoint0 =
true;
338 lastQueryIndex = queryNode.Point(0);
339 lastReferenceIndex = referenceNode.Point(0);
342 const double refFurtDescDist = referenceNode.FurthestDescendantDistance();
343 const double queryFurtDescDist = queryNode.FurthestDescendantDistance();
344 const double sumFurtDescDist = refFurtDescDist + queryFurtDescDist;
345 minDistance = std::max(traversalInfo.
LastBaseCase() - sumFurtDescDist, 0.0);
346 maxDistance = traversalInfo.
LastBaseCase() + sumFurtDescDist;
351 const math::Range r = queryNode.RangeDistance(referenceNode);
352 minDistance = r.
Lo();
353 maxDistance = r.
Hi();
356 const double maxKernel = kernel.Evaluate(minDistance);
357 const double minKernel = kernel.Evaluate(maxDistance);
358 const double bound = maxKernel - minKernel;
362 const double relErrorTol = relError * minKernel;
363 const double errorTolerance = absErrorTol + relErrorTol;
368 const double pointAccumErrorTol = queryStat.
AccumError() / refNumDesc;
371 if (bound <= 2 * errorTolerance + pointAccumErrorTol)
374 const double kernelValue = (maxKernel + minKernel) / 2.0;
377 for (
size_t i = 0; i < queryNode.NumDescendants(); ++i)
379 if (alreadyDidRefPoint0 && i == 0)
380 densities(queryNode.Descendant(i)) += (refNumDesc - 1) * kernelValue;
382 densities(queryNode.Descendant(i)) += refNumDesc * kernelValue;
390 queryStat.
AccumError() -= refNumDesc * (bound - 2 * errorTolerance);
393 if (kernelIsGaussian && monteCarlo)
396 else if (monteCarlo &&
397 refNumDesc >= mcAccessCoef * initialSampleSize &&
402 const double alpha = depthAlpha + queryStat.
AccumAlpha();
403 const boost::math::normal normalDist;
405 std::abs(boost::math::quantile(normalDist, alpha / 2));
409 arma::vec means = arma::zeros(queryNode.NumDescendants());
411 double meanSample = 0;
412 bool useMonteCarloPredictions =
true;
415 for (
size_t i = 0; i < queryNode.NumDescendants(); ++i)
417 const size_t queryIndex = queryNode.Descendant(i);
419 m = initialSampleSize;
424 const size_t oldSize = sample.size();
425 const size_t newSize = oldSize + m;
429 if (newSize >= mcBreakCoef * refNumDesc)
431 useMonteCarloPredictions =
false;
436 sample.resize(newSize);
437 for (
size_t i = 0; i < m; ++i)
441 if (alreadyDidRefPoint0)
446 sample(oldSize + i) =
447 EvaluateKernel(queryIndex, referenceNode.Descendant(randomPoint));
449 meanSample = arma::mean(sample);
450 const double stddev = arma::stddev(sample);
451 const double mThreshBase =
452 z * stddev * (1 + relError) / (relError * meanSample);
453 const size_t mThresh = std::ceil(mThreshBase * mThreshBase);
455 if (sample.size() < mThresh)
456 m = mThresh - sample.size();
462 if (useMonteCarloPredictions)
463 means(i) = meanSample;
468 if (useMonteCarloPredictions)
471 for (
size_t i = 0; i < queryNode.NumDescendants(); ++i)
473 if (alreadyDidRefPoint0 && i == 0)
474 densities(queryNode.Descendant(i)) += (refNumDesc - 1) * means(i);
476 densities(queryNode.Descendant(i)) += refNumDesc * means(i);
504 if (referenceNode.IsLeaf() && queryNode.IsLeaf())
505 queryStat.
AccumError() += refNumDesc * 2 * errorTolerance;
521 template<
typename MetricType,
typename KernelType,
typename TreeType>
525 const double oldScore)
const 531 template<
typename MetricType,
typename KernelType,
typename TreeType>
534 const size_t referenceIndex)
const 536 return EvaluateKernel(querySet.unsafe_col(queryIndex),
537 referenceSet.unsafe_col(referenceIndex));
540 template<
typename MetricType,
typename KernelType,
typename TreeType>
542 EvaluateKernel(
const arma::vec& query,
const arma::vec& reference)
const 544 return kernel.Evaluate(metric.Evaluate(query, reference));
547 template<
typename MetricType,
typename KernelType,
typename TreeType>
555 if (std::abs(stat.
MCBeta() - mcBeta) > DBL_EPSILON)
557 TreeType* parent = node->Parent();
566 stat.
MCAlpha() = parent->Stat().MCAlpha() / parent->NumChildren();
577 template<
typename TreeType>
586 template<
typename TreeType>
589 TreeType& referenceNode)
591 referenceNode.Stat().AccumAlpha() = 0;
592 referenceNode.Stat().AccumError() = 0;
597 template<
typename TreeType>
600 TreeType& referenceNode)
602 queryNode.Stat().AccumAlpha() = 0;
603 referenceNode.Stat().AccumAlpha() = 0;
605 queryNode.Stat().AccumError() = 0;
606 referenceNode.Stat().AccumError() = 0;
T Lo() const
Get the lower bound.
Definition: range.hpp:61
double AccumError() const
Get accumulated error tolerance of the node.
Definition: kde_stat.hpp:57
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Extra data for each node in the tree for the task of kernel density estimation.
Definition: kde_stat.hpp:24
double MCBeta() const
Get Monte Carlo beta of the node.
Definition: kde_stat.hpp:45
double AccumAlpha() const
Get accumulated Monte Carlo alpha of the node.
Definition: kde_stat.hpp:51
double LastBaseCase() const
Get the base case associated with the last node combination.
Definition: traversal_info.hpp:78
TreeType * LastQueryNode() const
Get the last query node.
Definition: traversal_info.hpp:63
KDERules(const arma::mat &referenceSet, const arma::mat &querySet, arma::vec &densities, const double relError, const double absError, const double mcProb, const size_t initialSampleSize, const double mcAccessCoef, const double mcBreakCoef, MetricType &metric, KernelType &kernel, const bool monteCarlo, const bool sameSet)
Construct KDERules.
Definition: kde_rules_impl.hpp:26
double LastScore() const
Get the score associated with the last query and reference nodes.
Definition: traversal_info.hpp:73
A dual-tree traversal Rules class for kernel density estimation.
Definition: kde_rules.hpp:26
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
Base Case.
Definition: kde_rules_impl.hpp:70
T Hi() const
Get the upper bound.
Definition: range.hpp:66
double BaseCase(const size_t, const size_t)
Base Case.
Definition: kde_rules_impl.hpp:579
The TreeTraits class provides compile-time information on the characteristics of a given tree type...
Definition: tree_traits.hpp:77
double Score(const size_t, TreeType &referenceNode)
SingleTree Score.
Definition: kde_rules_impl.hpp:588
int RandInt(const int hiExclusive)
Generates a uniform random integer.
Definition: random.hpp:110
double MCAlpha() const
Get Monte Carlo alpha of the node.
Definition: kde_stat.hpp:63
TreeType * LastReferenceNode() const
Get the last reference node.
Definition: traversal_info.hpp:68
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore) const
SingleTree Rescore.
Definition: kde_rules_impl.hpp:298
double Score(const size_t queryIndex, TreeType &referenceNode)
SingleTree Score.
Definition: kde_rules_impl.hpp:102