/**
 * @file kde_rules_impl.hpp
 *
 * Template implementation of the traversal rules for Kernel Density
 * Estimation (single-tree and dual-tree), including Monte Carlo estimation
 * for the Gaussian kernel.
 */
#ifndef MLPACK_METHODS_KDE_RULES_IMPL_HPP
#define MLPACK_METHODS_KDE_RULES_IMPL_HPP

// In case it hasn't been included yet.
#include "kde_rules.hpp"

// Used for Monte Carlo estimation.
#include <boost/math/distributions/normal.hpp>

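// A minimal single-tree usage sketch (hypothetical setup: in practice the KDE
// class in kde.hpp drives these rules and handles details such as tree
// building and point remapping, which are glossed over here):
//
//   using Tree = mlpack::tree::KDTree<mlpack::metric::EuclideanDistance,
//                                     mlpack::kde::KDEStat, arma::mat>;
//   Tree referenceTree(referenceSet);
//   mlpack::metric::EuclideanDistance metric;
//   mlpack::kernel::GaussianKernel kernel(0.5);  // Example bandwidth.
//   arma::vec densities(querySet.n_cols, arma::fill::zeros);
//
//   mlpack::kde::KDERules<mlpack::metric::EuclideanDistance,
//                         mlpack::kernel::GaussianKernel, Tree>
//       rules(referenceTree.Dataset(), querySet, densities,
//             0.05 /* relError */, 0.0 /* absError */, 0.95 /* mcProb */,
//             100 /* initialSampleSize */, 3.0 /* mcAccessCoef */,
//             0.4 /* mcBreakCoef */, metric, kernel,
//             false /* monteCarlo */, false /* sameSet */);
//
//   Tree::SingleTreeTraverser<decltype(rules)> traverser(rules);
//   for (size_t i = 0; i < querySet.n_cols; ++i)
//     traverser.Traverse(i, referenceTree);
//
//   // Normalization by the reference set size and the kernel constant is
//   // left out; the KDE class takes care of it.
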
namespace mlpack {
namespace kde {

//! Construct KDERules.
template<typename MetricType, typename KernelType, typename TreeType>
KDERules<MetricType, KernelType, TreeType>::KDERules(
    const arma::mat& referenceSet,
    const arma::mat& querySet,
    arma::vec& densities,
    const double relError,
    const double absError,
    const double mcProb,
    const size_t initialSampleSize,
    const double mcAccessCoef,
    const double mcBreakCoef,
    MetricType& metric,
    KernelType& kernel,
    const bool monteCarlo,
    const bool sameSet) :
    referenceSet(referenceSet),
    querySet(querySet),
    densities(densities),
    absError(absError),
    relError(relError),
    mcBeta(1 - mcProb),
    initialSampleSize(initialSampleSize),
    mcAccessCoef(mcAccessCoef),
    mcBreakCoef(mcBreakCoef),
    metric(metric),
    kernel(kernel),
    monteCarlo(monteCarlo),
    sameSet(sameSet),
    absErrorTol(absError / referenceSet.n_cols),
    lastQueryIndex(querySet.n_cols),
    lastReferenceIndex(referenceSet.n_cols),
    baseCases(0),
    scores(0)
{
  // Initialize accumError.
  accumError = arma::vec(querySet.n_cols, arma::fill::zeros);

  // Initialize accumMCAlpha only if Monte Carlo estimations are available.
  if (monteCarlo && kernelIsGaussian)
    accumMCAlpha = arma::vec(querySet.n_cols, arma::fill::zeros);
}

//! The base case.
template<typename MetricType, typename KernelType, typename TreeType>
inline force_inline
double KDERules<MetricType, KernelType, TreeType>::BaseCase(
    const size_t queryIndex,
    const size_t referenceIndex)
{
  // If the reference and query sets are the same, we don't want to compute
  // the estimate of a point with itself.
  if (sameSet && (queryIndex == referenceIndex))
    return 0.0;

  // Avoid duplicated calculations.
  if ((lastQueryIndex == queryIndex) && (lastReferenceIndex == referenceIndex))
    return 0.0;

  // Kernel evaluation.
  const double distance = metric.Evaluate(querySet.col(queryIndex),
                                          referenceSet.col(referenceIndex));
  const double kernelValue = kernel.Evaluate(distance);
  densities(queryIndex) += kernelValue;

  // Update accumulated relative error tolerance for single-tree pruning.
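  // Exactly-evaluated points don't spend their relative error allowance, so
  // we bank 2 * relError * kernelValue of slack that later prunes on this
  // query point may use (see Score()).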
  accumError(queryIndex) += 2 * relError * kernelValue;

  ++baseCases;
  lastQueryIndex = queryIndex;
  lastReferenceIndex = referenceIndex;
  traversalInfo.LastBaseCase() = distance;
  return distance;
}

//! Single-tree scoring function.
template<typename MetricType, typename KernelType, typename TreeType>
inline double KDERules<MetricType, KernelType, TreeType>::
Score(const size_t queryIndex, TreeType& referenceNode)
{
  // Auxiliary variables.
  const arma::vec& queryPoint = querySet.unsafe_col(queryIndex);
  const size_t refNumDesc = referenceNode.NumDescendants();
  double score, minDistance, maxDistance, depthAlpha;
  // Whether the kernel value for the reference node's first point has already
  // been computed.
  bool alreadyDidRefPoint0 = false;

  // Calculate alpha if Monte Carlo is available.
  if (monteCarlo && kernelIsGaussian)
    depthAlpha = CalculateAlpha(&referenceNode);
  else
    depthAlpha = -1;

  // Reuse the last base case if the tree's first point is its centroid.
  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid &&
      lastQueryIndex == queryIndex &&
      traversalInfo.LastReferenceNode() != NULL &&
      lastReferenceIndex == referenceNode.Point(0))
  {
    // Don't duplicate calculations.
    alreadyDidRefPoint0 = true;
    const double furthestDescDist = referenceNode.FurthestDescendantDistance();
    minDistance = std::max(traversalInfo.LastBaseCase() - furthestDescDist,
                           0.0);
    maxDistance = traversalInfo.LastBaseCase() + furthestDescDist;
  }
  else
  {
    // All calculations are new.
    const math::Range r = referenceNode.RangeDistance(queryPoint);
    minDistance = r.Lo();
    maxDistance = r.Hi();

    // Check if we are a self-child.
    if (tree::TreeTraits<TreeType>::HasSelfChildren &&
        referenceNode.Parent() != NULL &&
        referenceNode.Parent()->Point(0) == referenceNode.Point(0))
    {
      alreadyDidRefPoint0 = true;
    }
  }

  const double maxKernel = kernel.Evaluate(minDistance);
  const double minKernel = kernel.Evaluate(maxDistance);
  const double bound = maxKernel - minKernel;

  // Error tolerance of the current combination of query point and reference
  // node.
  const double relErrorTol = relError * minKernel;
  const double errorTolerance = absErrorTol + relErrorTol;

  // We relax the bound for pruning by accumError(queryIndex), so that if there
  // is any leftover error tolerance from the rest of the traversal, we can use
  // it here to prune more.
  double pointAccumErrorTol;
  if (alreadyDidRefPoint0)
    pointAccumErrorTol = accumError(queryIndex) / (refNumDesc - 1);
  else
    pointAccumErrorTol = accumError(queryIndex) / refNumDesc;

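  // Approximating every descendant by the kernel midpoint value below incurs
  // at most bound / 2 error per reference point, so the prune is safe whenever
  // bound <= 2 * errorTolerance, relaxed further by any banked per-point
  // tolerance.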
  if (bound <= 2 * errorTolerance + pointAccumErrorTol)
  {
    // Estimate kernel value.
    const double kernelValue = (maxKernel + minKernel) / 2.0;

    if (alreadyDidRefPoint0)
      densities(queryIndex) += (refNumDesc - 1) * kernelValue;
    else
      densities(queryIndex) += refNumDesc * kernelValue;

    // Don't explore this tree branch.
    score = DBL_MAX;

    // Subtract the used error tolerance, or add any extra tolerance made
    // available by this prune.
    if (alreadyDidRefPoint0)
      accumError(queryIndex) -= (refNumDesc - 1) * (bound - 2 * errorTolerance);
    else
      accumError(queryIndex) -= refNumDesc * (bound - 2 * errorTolerance);

    // Store unused alpha for Monte Carlo.
    if (kernelIsGaussian && monteCarlo)
      accumMCAlpha(queryIndex) += depthAlpha;
  }
  else if (monteCarlo &&
           refNumDesc >= mcAccessCoef * initialSampleSize &&
           kernelIsGaussian)
  {
    // Monte Carlo probabilistic estimation.
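    // We only get here when the reference node has at least
    // mcAccessCoef * initialSampleSize descendants, so sampling has a chance
    // to be cheaper than exact evaluation.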
    // Calculate z using the accumulated alpha if possible.
    const double alpha = depthAlpha + accumMCAlpha(queryIndex);
    const boost::math::normal normalDist;
    const double z =
        std::abs(boost::math::quantile(normalDist, alpha / 2));

    // Auxiliary variables.
    arma::vec sample;
    size_t m = initialSampleSize;
    double meanSample = 0;
    bool useMonteCarloPredictions = true;

    // Resample as long as the confidence is not high enough.
    while (m > 0)
    {
      const size_t oldSize = sample.size();
      const size_t newSize = oldSize + m;

      // Don't use probabilistic estimation if this is going to take a similar
      // amount of computation to the exact calculation.
      if (newSize >= mcBreakCoef * refNumDesc)
      {
        useMonteCarloPredictions = false;
        break;
      }

      // Increase the sample size.
      sample.resize(newSize);
      for (size_t i = 0; i < m; ++i)
      {
        // Sample and evaluate random points from the reference node.
        size_t randomPoint;
        if (alreadyDidRefPoint0)
          randomPoint = math::RandInt(1, refNumDesc);
        else
          randomPoint = math::RandInt(0, refNumDesc);

        sample(oldSize + i) =
            EvaluateKernel(queryIndex, referenceNode.Descendant(randomPoint));
      }
      meanSample = arma::mean(sample);
      const double stddev = arma::stddev(sample);
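      // CLT-based sample size: to hit the relative error target with
      // probability about 1 - alpha, we need roughly
      // (z * stddev * (1 + relError) / (relError * meanSample))^2 samples.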
      const double mThreshBase =
          z * stddev * (1 + relError) / (relError * meanSample);
      const size_t mThresh = std::ceil(mThreshBase * mThreshBase);

      if (sample.size() < mThresh)
        m = mThresh - sample.size();
      else
        m = 0;
    }

    if (useMonteCarloPredictions)
    {
      // Confidence is high enough, so we can use the Monte Carlo estimate.
      if (alreadyDidRefPoint0)
        densities(queryIndex) += (refNumDesc - 1) * meanSample;
      else
        densities(queryIndex) += refNumDesc * meanSample;

      // Prune.
      score = DBL_MAX;

      // The accumulated alpha has been used.
      accumMCAlpha(queryIndex) = 0;
    }
    else
    {
      // Recurse.
      score = minDistance;

      if (referenceNode.IsLeaf())
      {
        // Reclaim unused alpha since the node will be computed exactly.
        accumMCAlpha(queryIndex) += depthAlpha;
      }
    }
  }
  else
  {
    // Recurse.
    score = minDistance;

    // Add accumulated unused absolute error tolerance.
    if (referenceNode.IsLeaf())
    {
      if (alreadyDidRefPoint0)
        accumError(queryIndex) += (refNumDesc - 1) * 2 * absErrorTol;
      else
        accumError(queryIndex) += refNumDesc * 2 * absErrorTol;
    }

    // If the node is going to be computed exactly, reclaim unused alpha for
    // Monte Carlo estimations.
    if (kernelIsGaussian && monteCarlo && referenceNode.IsLeaf())
      accumMCAlpha(queryIndex) += depthAlpha;
  }

  ++scores;
  traversalInfo.LastReferenceNode() = &referenceNode;
  traversalInfo.LastScore() = score;
  return score;
}

//! Single-tree rescoring function.
template<typename MetricType, typename KernelType, typename TreeType>
inline force_inline double KDERules<MetricType, KernelType, TreeType>::
Rescore(const size_t /* queryIndex */,
        TreeType& /* referenceNode */,
        const double oldScore) const
{
  // If it's pruned, it continues to be pruned.
  return oldScore;
}

//! Dual-tree scoring function.
template<typename MetricType, typename KernelType, typename TreeType>
inline double KDERules<MetricType, KernelType, TreeType>::
Score(TreeType& queryNode, TreeType& referenceNode)
{
  kde::KDEStat& queryStat = queryNode.Stat();
  const size_t refNumDesc = referenceNode.NumDescendants();
  double score, minDistance, maxDistance, depthAlpha;
  // Whether the kernel value for the reference node's first point has already
  // been computed.
  bool alreadyDidRefPoint0 = false;

  // Calculate alpha if Monte Carlo is available.
  if (monteCarlo && kernelIsGaussian)
    depthAlpha = CalculateAlpha(&referenceNode);
  else
    depthAlpha = -1;

  // Check whether unused Monte Carlo alpha can be reclaimed for this
  // combination of nodes.
  const bool canReclaimAlpha = kernelIsGaussian &&
                               monteCarlo &&
                               referenceNode.IsLeaf() &&
                               queryNode.IsLeaf();

  // Reuse the last base case if the tree's first point is its centroid.
  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid &&
      (traversalInfo.LastQueryNode() != NULL) &&
      (traversalInfo.LastReferenceNode() != NULL) &&
      (traversalInfo.LastQueryNode()->Point(0) == queryNode.Point(0)) &&
      (traversalInfo.LastReferenceNode()->Point(0) == referenceNode.Point(0)))
  {
    // Don't duplicate calculations.
    alreadyDidRefPoint0 = true;
    lastQueryIndex = queryNode.Point(0);
    lastReferenceIndex = referenceNode.Point(0);

    // Calculate the min and max distances.
    const double refFurtDescDist = referenceNode.FurthestDescendantDistance();
    const double queryFurtDescDist = queryNode.FurthestDescendantDistance();
    const double sumFurtDescDist = refFurtDescDist + queryFurtDescDist;
    minDistance = std::max(traversalInfo.LastBaseCase() - sumFurtDescDist, 0.0);
    maxDistance = traversalInfo.LastBaseCase() + sumFurtDescDist;
  }
  else
  {
    // All calculations are new.
    const math::Range r = queryNode.RangeDistance(referenceNode);
    minDistance = r.Lo();
    maxDistance = r.Hi();
  }

  const double maxKernel = kernel.Evaluate(minDistance);
  const double minKernel = kernel.Evaluate(maxDistance);
  const double bound = maxKernel - minKernel;

  // Error tolerance of the current combination of query node and reference
  // node.
  const double relErrorTol = relError * minKernel;
  const double errorTolerance = absErrorTol + relErrorTol;

  // We relax the bound for pruning by queryStat.AccumError(), so that if there
  // is any leftover error tolerance from the rest of the traversal, we can use
  // it here to prune more.
  const double pointAccumErrorTol = queryStat.AccumError() / refNumDesc;

  // If possible, avoid some calculations because of the error tolerance.
  if (bound <= 2 * errorTolerance + pointAccumErrorTol)
  {
    // Estimate kernel value.
    const double kernelValue = (maxKernel + minKernel) / 2.0;

    // Sum up estimations.
    for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
    {
      if (alreadyDidRefPoint0 && i == 0)
        densities(queryNode.Descendant(i)) += (refNumDesc - 1) * kernelValue;
      else
        densities(queryNode.Descendant(i)) += refNumDesc * kernelValue;
    }

    // Prune.
    score = DBL_MAX;

    // Subtract the used error tolerance, or add any extra tolerance made
    // available by this prune.
    queryStat.AccumError() -= refNumDesc * (bound - 2 * errorTolerance);

    // Store unused alpha for Monte Carlo.
    if (kernelIsGaussian && monteCarlo)
      queryStat.AccumAlpha() += depthAlpha;
  }
  else if (monteCarlo &&
           refNumDesc >= mcAccessCoef * initialSampleSize &&
           kernelIsGaussian)
  {
    // Monte Carlo probabilistic estimation.
    // Calculate z using the accumulated alpha if possible.
    const double alpha = depthAlpha + queryStat.AccumAlpha();
    const boost::math::normal normalDist;
    const double z =
        std::abs(boost::math::quantile(normalDist, alpha / 2));

    // Auxiliary variables.
    arma::vec sample;
    arma::vec means = arma::zeros(queryNode.NumDescendants());
    size_t m;
    double meanSample = 0;
    bool useMonteCarloPredictions = true;

    // Pick a sample for every descendant point of the query node.
    for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
    {
      const size_t queryIndex = queryNode.Descendant(i);
      sample.clear();
      m = initialSampleSize;

      // Resample as long as the confidence is not high enough.
      while (m > 0)
      {
        const size_t oldSize = sample.size();
        const size_t newSize = oldSize + m;

        // Don't use probabilistic estimation if this is going to take a
        // similar amount of computation to the exact calculation.
        if (newSize >= mcBreakCoef * refNumDesc)
        {
          useMonteCarloPredictions = false;
          break;
        }

        // Increase the sample size.
        sample.resize(newSize);
        for (size_t j = 0; j < m; ++j)
        {
          // Sample and evaluate random points from the reference node.
          size_t randomPoint;
          if (alreadyDidRefPoint0)
            randomPoint = math::RandInt(1, refNumDesc);
          else
            randomPoint = math::RandInt(0, refNumDesc);

          sample(oldSize + j) =
              EvaluateKernel(queryIndex, referenceNode.Descendant(randomPoint));
        }
        meanSample = arma::mean(sample);
        const double stddev = arma::stddev(sample);
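        // Same CLT-based sample size bound as in the single-tree Score().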
        const double mThreshBase =
            z * stddev * (1 + relError) / (relError * meanSample);
        const size_t mThresh = std::ceil(mThreshBase * mThreshBase);

        if (sample.size() < mThresh)
          m = mThresh - sample.size();
        else
          m = 0;
      }

      // Store the mean for the i-th descendant point of the query node.
      if (useMonteCarloPredictions)
        means(i) = meanSample;
      else
        break;
    }

    if (useMonteCarloPredictions)
    {
      // Confidence is high enough, so we can use the Monte Carlo estimate.
      for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
      {
        if (alreadyDidRefPoint0 && i == 0)
          densities(queryNode.Descendant(i)) += (refNumDesc - 1) * means(i);
        else
          densities(queryNode.Descendant(i)) += refNumDesc * means(i);
      }

      // Prune.
      score = DBL_MAX;

      // The accumulated alpha has been used.
      queryStat.AccumAlpha() = 0;
    }
    else
    {
      // Recurse.
      score = minDistance;

      if (canReclaimAlpha)
      {
        // Reclaim unused Monte Carlo alpha since the nodes will be computed
        // exactly.
        queryStat.AccumAlpha() += depthAlpha;
      }
    }
  }
  else
  {
    // Recurse.
    score = minDistance;

    // Add accumulated unused error tolerance.
    if (referenceNode.IsLeaf() && queryNode.IsLeaf())
      queryStat.AccumError() += refNumDesc * 2 * errorTolerance;

    // If the nodes are going to be computed exactly, reclaim unused alpha for
    // Monte Carlo estimations.
    if (canReclaimAlpha)
      queryStat.AccumAlpha() += depthAlpha;
  }

  ++scores;
  traversalInfo.LastQueryNode() = &queryNode;
  traversalInfo.LastReferenceNode() = &referenceNode;
  traversalInfo.LastScore() = score;
  return score;
}

//! Dual-tree rescoring function.
template<typename MetricType, typename KernelType, typename TreeType>
inline force_inline double KDERules<MetricType, KernelType, TreeType>::
Rescore(TreeType& /* queryNode */,
        TreeType& /* referenceNode */,
        const double oldScore) const
{
  // If a branch is pruned, then it continues to be pruned.
  return oldScore;
}

template<typename MetricType, typename KernelType, typename TreeType>
inline force_inline double KDERules<MetricType, KernelType, TreeType>::
EvaluateKernel(const size_t queryIndex,
               const size_t referenceIndex) const
{
  return EvaluateKernel(querySet.unsafe_col(queryIndex),
                        referenceSet.unsafe_col(referenceIndex));
}

template<typename MetricType, typename KernelType, typename TreeType>
inline force_inline double KDERules<MetricType, KernelType, TreeType>::
EvaluateKernel(const arma::vec& query, const arma::vec& reference) const
{
  return kernel.Evaluate(metric.Evaluate(query, reference));
}

template<typename MetricType, typename KernelType, typename TreeType>
inline force_inline double KDERules<MetricType, KernelType, TreeType>::
CalculateAlpha(TreeType* node)
{
  KDEStat& stat = node->Stat();

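  // Each node's alpha is its share of the global failure probability
  // mcBeta = 1 - mcProb: the root receives the full budget and every node
  // splits its own share evenly among its children. The intent is that, by a
  // union bound, the total probability of exceeding the error bound across
  // all Monte Carlo prunes stays below mcBeta.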
  // If the new mcBeta differs from the previously computed mcBeta, then the
  // alpha for this node is recomputed.
  if (std::abs(stat.MCBeta() - mcBeta) > DBL_EPSILON)
  {
    TreeType* parent = node->Parent();
    if (parent == NULL)
    {
      // If it's the root node, assign mcBeta.
      stat.MCAlpha() = mcBeta;
    }
    else
    {
      // Distribute the parent's alpha among its children.
      stat.MCAlpha() = parent->Stat().MCAlpha() / parent->NumChildren();
    }

    // Set the beta value for which this alpha is valid.
    stat.MCBeta() = mcBeta;
  }

  return stat.MCAlpha();
}

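// KDECleanRules resets the accumulated error tolerance and Monte Carlo alpha
// stored in each visited node, so that the trees can be reused for a new KDE
// computation.
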
//! Clean rules base case.
template<typename TreeType>
inline force_inline
double KDECleanRules<TreeType>::BaseCase(const size_t /* queryIndex */,
                                         const size_t /* refIndex */)
{
  return 0;
}

//! Clean rules single-tree score.
template<typename TreeType>
inline force_inline
double KDECleanRules<TreeType>::Score(const size_t /* queryIndex */,
                                      TreeType& referenceNode)
{
  referenceNode.Stat().AccumAlpha() = 0;
  referenceNode.Stat().AccumError() = 0;
  return 0;
}

//! Clean rules dual-tree score.
template<typename TreeType>
inline force_inline
double KDECleanRules<TreeType>::Score(TreeType& queryNode,
                                      TreeType& referenceNode)
{
  queryNode.Stat().AccumAlpha() = 0;
  referenceNode.Stat().AccumAlpha() = 0;

  queryNode.Stat().AccumError() = 0;
  referenceNode.Stat().AccumError() = 0;

  return 0;
}

} // namespace kde
} // namespace mlpack

#endif