mlpack
kde_impl.hpp
Go to the documentation of this file.
1 
13 #include "kde.hpp"
14 #include "kde_rules.hpp"
15 
16 namespace mlpack {
17 namespace kde {
18 
20 template<typename TreeType, typename MatType>
21 TreeType* BuildTree(
22  MatType&& dataset,
23  std::vector<size_t>& oldFromNew,
24  const typename std::enable_if<
26 {
27  return new TreeType(std::forward<MatType>(dataset), oldFromNew);
28 }
29 
31 template<typename TreeType, typename MatType>
32 TreeType* BuildTree(
33  MatType&& dataset,
34  const std::vector<size_t>& /* oldFromNew */,
35  const typename std::enable_if<
37 {
38  return new TreeType(std::forward<MatType>(dataset));
39 }
40 
41 template<typename KernelType,
42  typename MetricType,
43  typename MatType,
44  template<typename TreeMetricType,
45  typename TreeStatType,
46  typename TreeMatType> class TreeType,
47  template<typename> class DualTreeTraversalType,
48  template<typename> class SingleTreeTraversalType>
49 KDE<KernelType,
50  MetricType,
51  MatType,
52  TreeType,
53  DualTreeTraversalType,
54  SingleTreeTraversalType>::
55 KDE(const double relError,
56  const double absError,
57  KernelType kernel,
58  const KDEMode mode,
59  MetricType metric,
60  const bool monteCarlo,
61  const double mcProb,
62  const size_t initialSampleSize,
63  const double mcEntryCoef,
64  const double mcBreakCoef) :
65  kernel(kernel),
66  metric(metric),
67  referenceTree(nullptr),
68  oldFromNewReferences(nullptr),
69  relError(relError),
70  absError(absError),
71  ownsReferenceTree(false),
72  trained(false),
73  mode(mode),
74  monteCarlo(monteCarlo),
75  initialSampleSize(initialSampleSize)
76 {
77  CheckErrorValues(relError, absError);
78  MCProb(mcProb);
79  MCEntryCoef(mcEntryCoef);
80  MCBreakCoef(mcBreakCoef);
81 }
82 
83 template<typename KernelType,
84  typename MetricType,
85  typename MatType,
86  template<typename TreeMetricType,
87  typename TreeStatType,
88  typename TreeMatType> class TreeType,
89  template<typename> class DualTreeTraversalType,
90  template<typename> class SingleTreeTraversalType>
91 KDE<KernelType,
92  MetricType,
93  MatType,
94  TreeType,
95  DualTreeTraversalType,
96  SingleTreeTraversalType>::
97 KDE(const KDE& other) :
98  kernel(KernelType(other.kernel)),
99  metric(MetricType(other.metric)),
100  relError(other.relError),
101  absError(other.absError),
102  ownsReferenceTree(other.ownsReferenceTree),
103  trained(other.trained),
104  mode(other.mode),
105  monteCarlo(other.monteCarlo),
106  mcProb(other.mcProb),
107  initialSampleSize(other.initialSampleSize),
108  mcEntryCoef(other.mcEntryCoef),
109  mcBreakCoef(other.mcBreakCoef)
110 {
111  if (trained)
112  {
113  if (ownsReferenceTree)
114  {
115  oldFromNewReferences =
116  new std::vector<size_t>(*other.oldFromNewReferences);
117  referenceTree = new Tree(*other.referenceTree);
118  }
119  else
120  {
121  oldFromNewReferences = other.oldFromNewReferences;
122  referenceTree = other.referenceTree;
123  }
124  }
125 }
126 
127 template<typename KernelType,
128  typename MetricType,
129  typename MatType,
130  template<typename TreeMetricType,
131  typename TreeStatType,
132  typename TreeMatType> class TreeType,
133  template<typename> class DualTreeTraversalType,
134  template<typename> class SingleTreeTraversalType>
135 KDE<KernelType,
136  MetricType,
137  MatType,
138  TreeType,
139  DualTreeTraversalType,
140  SingleTreeTraversalType>::
141 KDE(KDE&& other) :
142  kernel(std::move(other.kernel)),
143  metric(std::move(other.metric)),
144  referenceTree(other.referenceTree),
145  oldFromNewReferences(other.oldFromNewReferences),
146  relError(other.relError),
147  absError(other.absError),
148  ownsReferenceTree(other.ownsReferenceTree),
149  trained(other.trained),
150  mode(other.mode),
151  monteCarlo(other.monteCarlo),
152  mcProb(other.mcProb),
153  initialSampleSize(other.initialSampleSize),
154  mcEntryCoef(other.mcEntryCoef),
155  mcBreakCoef(other.mcBreakCoef)
156 {
157  other.kernel = std::move(KernelType());
158  other.metric = std::move(MetricType());
159  other.referenceTree = nullptr;
160  other.oldFromNewReferences = nullptr;
161  other.relError = KDEDefaultParams::relError;
162  other.absError = KDEDefaultParams::absError;
163  other.ownsReferenceTree = false;
164  other.trained = false;
165  other.mode = KDEDefaultParams::mode;
166  other.monteCarlo = KDEDefaultParams::monteCarlo;
167  other.mcProb = KDEDefaultParams::mcProb;
168  other.initialSampleSize = KDEDefaultParams::initialSampleSize;
169  other.mcEntryCoef = KDEDefaultParams::mcEntryCoef;
170  other.mcBreakCoef = KDEDefaultParams::mcBreakCoef;
171 }
172 
173 template<typename KernelType,
174  typename MetricType,
175  typename MatType,
176  template<typename TreeMetricType,
177  typename TreeStatType,
178  typename TreeMatType> class TreeType,
179  template<typename> class DualTreeTraversalType,
180  template<typename> class SingleTreeTraversalType>
181 KDE<KernelType,
182  MetricType,
183  MatType,
184  TreeType,
185  DualTreeTraversalType,
186  SingleTreeTraversalType>&
187 KDE<KernelType,
188  MetricType,
189  MatType,
190  TreeType,
191  DualTreeTraversalType,
192  SingleTreeTraversalType>::
193 operator=(const KDE& other)
194 {
195  if (this != &other)
196  {
197  // Clean memory.
198  if (ownsReferenceTree)
199  {
200  delete referenceTree;
201  delete oldFromNewReferences;
202  }
203  kernel = KernelType(other.kernel);
204  metric = MetricType(other.metric);
205  relError = other.relError;
206  absError = other.absError;
207  ownsReferenceTree = other.ownsReferenceTree;
208  trained = other.trained;
209  mode = other.mode;
210  monteCarlo = other.monteCarlo;
211  mcProb = other.mcProb;
212  initialSampleSize = other.initialSampleSize;
213  mcEntryCoef = other.mcEntryCoef;
214  mcBreakCoef = other.mcBreakCoef;
215  if (trained)
216  {
217  if (ownsReferenceTree)
218  {
219  oldFromNewReferences =
220  new std::vector<size_t>(*other.oldFromNewReferences);
221  referenceTree = new Tree(*other.referenceTree);
222  }
223  else
224  {
225  oldFromNewReferences = other.oldFromNewReferences;
226  referenceTree = other.referenceTree;
227  }
228  }
229  }
230  return *this;
231 }
232 
233 template<typename KernelType,
234  typename MetricType,
235  typename MatType,
236  template<typename TreeMetricType,
237  typename TreeStatType,
238  typename TreeMatType> class TreeType,
239  template<typename> class DualTreeTraversalType,
240  template<typename> class SingleTreeTraversalType>
241 KDE<KernelType,
242  MetricType,
243  MatType,
244  TreeType,
245  DualTreeTraversalType,
246  SingleTreeTraversalType>&
247 KDE<KernelType,
248  MetricType,
249  MatType,
250  TreeType,
251  DualTreeTraversalType,
252  SingleTreeTraversalType>::
253 operator=(KDE&& other)
254 {
255  if (this != &other)
256  {
257  // Clean memory.
258  if (ownsReferenceTree)
259  {
260  delete referenceTree;
261  delete oldFromNewReferences;
262  }
263 
264  // Move the other object.
265  this->kernel = std::move(other.kernel);
266  this->metric = std::move(other.metric);
267  this->referenceTree = std::move(other.referenceTree);
268  this->oldFromNewReferences = std::move(other.oldFromNewReferences);
269  this->relError = other.relError;
270  this->absError = other.absError;
271  this->ownsReferenceTree = other.ownsReferenceTree;
272  this->trained = other.trained;
273  this->mode = other.mode;
274  this->monteCarlo = other.monteCarlo;
275  this->mcProb = other.mcProb;
276  this->initialSampleSize = other.initialSampleSize;
277  this->mcEntryCoef = other.mcEntryCoef;
278  this->mcBreakCoef = other.mcBreakCoef;
279  }
280  return *this;
281 }
282 
283 template<typename KernelType,
284  typename MetricType,
285  typename MatType,
286  template<typename TreeMetricType,
287  typename TreeStatType,
288  typename TreeMatType> class TreeType,
289  template<typename> class DualTreeTraversalType,
290  template<typename> class SingleTreeTraversalType>
291 KDE<KernelType,
292  MetricType,
293  MatType,
294  TreeType,
295  DualTreeTraversalType,
296  SingleTreeTraversalType>::
297 ~KDE()
298 {
299  if (ownsReferenceTree)
300  {
301  delete referenceTree;
302  delete oldFromNewReferences;
303  }
304 }
305 
306 template<typename KernelType,
307  typename MetricType,
308  typename MatType,
309  template<typename TreeMetricType,
310  typename TreeStatType,
311  typename TreeMatType> class TreeType,
312  template<typename> class DualTreeTraversalType,
313  template<typename> class SingleTreeTraversalType>
314 void KDE<KernelType,
315  MetricType,
316  MatType,
317  TreeType,
318  DualTreeTraversalType,
319  SingleTreeTraversalType>::
320 Train(MatType referenceSet)
321 {
322  // Check if referenceSet is not an empty set.
323  if (referenceSet.n_cols == 0)
324  {
325  throw std::invalid_argument("cannot train KDE model with an empty "
326  "reference set");
327  }
328 
329  if (ownsReferenceTree)
330  {
331  delete referenceTree;
332  delete oldFromNewReferences;
333  }
334 
335  this->ownsReferenceTree = true;
336  Timer::Start("building_reference_tree");
337  this->oldFromNewReferences = new std::vector<size_t>;
338  this->referenceTree = BuildTree<Tree>(std::move(referenceSet),
339  *oldFromNewReferences);
340  Timer::Stop("building_reference_tree");
341  this->trained = true;
342 }
343 
344 template<typename KernelType,
345  typename MetricType,
346  typename MatType,
347  template<typename TreeMetricType,
348  typename TreeStatType,
349  typename TreeMatType> class TreeType,
350  template<typename> class DualTreeTraversalType,
351  template<typename> class SingleTreeTraversalType>
352 void KDE<KernelType,
353  MetricType,
354  MatType,
355  TreeType,
356  DualTreeTraversalType,
357  SingleTreeTraversalType>::
358 Train(Tree* referenceTree, std::vector<size_t>* oldFromNewReferences)
359 {
360  // Check if referenceTree dataset is not an empty set.
361  if (referenceTree->Dataset().n_cols == 0)
362  {
363  throw std::invalid_argument("cannot train KDE model with an empty "
364  "reference set");
365  }
366 
367  if (ownsReferenceTree == true)
368  {
369  delete this->referenceTree;
370  delete this->oldFromNewReferences;
371  }
372 
373  this->ownsReferenceTree = false;
374  this->referenceTree = referenceTree;
375  this->oldFromNewReferences = oldFromNewReferences;
376  this->trained = true;
377 }
378 
379 template<typename KernelType,
380  typename MetricType,
381  typename MatType,
382  template<typename TreeMetricType,
383  typename TreeStatType,
384  typename TreeMatType> class TreeType,
385  template<typename> class DualTreeTraversalType,
386  template<typename> class SingleTreeTraversalType>
387 void KDE<KernelType,
388  MetricType,
389  MatType,
390  TreeType,
391  DualTreeTraversalType,
392  SingleTreeTraversalType>::
393 Evaluate(MatType querySet, arma::vec& estimations)
394 {
395  if (mode == DUAL_TREE_MODE)
396  {
397  Timer::Start("building_query_tree");
398  std::vector<size_t> oldFromNewQueries;
399  Tree* queryTree = BuildTree<Tree>(std::move(querySet), oldFromNewQueries);
400  Timer::Stop("building_query_tree");
401  try
402  {
403  this->Evaluate(queryTree, oldFromNewQueries, estimations);
404  }
405  catch (std::exception& e)
406  {
407  // Make sure we delete the query tree.
408  delete queryTree;
409  throw;
410  }
411  delete queryTree;
412  }
413  else if (mode == SINGLE_TREE_MODE)
414  {
415  // Get estimations vector ready.
416  estimations.clear();
417  estimations.set_size(querySet.n_cols);
418  estimations.fill(arma::fill::zeros);
419 
420  // Check whether has already been trained.
421  if (!trained)
422  {
423  throw std::runtime_error("cannot evaluate KDE model: model needs to be "
424  "trained before evaluation");
425  }
426 
427  // Check querySet has at least 1 element to evaluate.
428  if (querySet.n_cols == 0)
429  {
430  Log::Warn << "KDE::Evaluate(): querySet is empty, no predictions will "
431  << "be returned" << std::endl;
432  return;
433  }
434 
435  // Check whether dimensions match.
436  if (querySet.n_rows != referenceTree->Dataset().n_rows)
437  {
438  throw std::invalid_argument("cannot evaluate KDE model: querySet and "
439  "referenceSet dimensions don't match");
440  }
441 
442  Timer::Start("computing_kde");
443 
444  // Evaluate.
446  RuleType rules = RuleType(referenceTree->Dataset(),
447  querySet,
448  estimations,
449  relError,
450  absError,
451  mcProb,
452  initialSampleSize,
453  mcEntryCoef,
454  mcBreakCoef,
455  metric,
456  kernel,
457  monteCarlo,
458  false);
459 
460  // Create traverser.
461  SingleTreeTraversalType<RuleType> traverser(rules);
462 
463  // Traverse for each point.
464  for (size_t i = 0; i < querySet.n_cols; ++i)
465  traverser.Traverse(i, *referenceTree);
466 
467  estimations /= referenceTree->Dataset().n_cols;
468  Timer::Stop("computing_kde");
469 
470  Log::Info << rules.Scores() << " node combinations were scored."
471  << std::endl;
472  Log::Info << rules.BaseCases() << " base cases were calculated."
473  << std::endl;
474  }
475 }
476 
477 template<typename KernelType,
478  typename MetricType,
479  typename MatType,
480  template<typename TreeMetricType,
481  typename TreeStatType,
482  typename TreeMatType> class TreeType,
483  template<typename> class DualTreeTraversalType,
484  template<typename> class SingleTreeTraversalType>
485 void KDE<KernelType,
486  MetricType,
487  MatType,
488  TreeType,
489  DualTreeTraversalType,
490  SingleTreeTraversalType>::
491 Evaluate(Tree* queryTree,
492  const std::vector<size_t>& oldFromNewQueries,
493  arma::vec& estimations)
494 {
495  // Get estimations vector ready.
496  estimations.clear();
497  estimations.set_size(queryTree->Dataset().n_cols);
498  estimations.fill(arma::fill::zeros);
499 
500  // Check whether has already been trained.
501  if (!trained)
502  {
503  throw std::runtime_error("cannot evaluate KDE model: model needs to be "
504  "trained before evaluation");
505  }
506 
507  // Check querySet has at least 1 element to evaluate.
508  if (queryTree->Dataset().n_cols == 0)
509  {
510  Log::Warn << "KDE::Evaluate(): querySet is empty, no predictions will "
511  << "be returned" << std::endl;
512  return;
513  }
514 
515  // Check whether dimensions match.
516  if (queryTree->Dataset().n_rows != referenceTree->Dataset().n_rows)
517  {
518  throw std::invalid_argument("cannot evaluate KDE model: querySet and "
519  "referenceSet dimensions don't match");
520  }
521 
522  // Check the mode is correct.
523  if (mode != DUAL_TREE_MODE)
524  {
525  throw std::invalid_argument("cannot evaluate KDE model: cannot use "
526  "a query tree when mode is different from "
527  "dual-tree");
528  }
529 
530  // Clean accumulated alpha if Monte Carlo estimations are available.
531  if (monteCarlo && std::is_same<KernelType, kernel::GaussianKernel>::value)
532  {
533  Timer::Start("cleaning_query_tree");
534  KDECleanRules<Tree> cleanRules;
535  SingleTreeTraversalType<KDECleanRules<Tree>> cleanTraverser(cleanRules);
536  cleanTraverser.Traverse(0, *queryTree);
537  Timer::Stop("cleaning_query_tree");
538  }
539 
540  Timer::Start("computing_kde");
541 
542  // Evaluate.
544  RuleType rules = RuleType(referenceTree->Dataset(),
545  queryTree->Dataset(),
546  estimations,
547  relError,
548  absError,
549  mcProb,
550  initialSampleSize,
551  mcEntryCoef,
552  mcBreakCoef,
553  metric,
554  kernel,
555  monteCarlo,
556  false);
557 
558  // Create traverser.
559  DualTreeTraversalType<RuleType> traverser(rules);
560  traverser.Traverse(*queryTree, *referenceTree);
561  estimations /= referenceTree->Dataset().n_cols;
562  Timer::Stop("computing_kde");
563 
564  // Rearrange if necessary.
565  RearrangeEstimations(oldFromNewQueries, estimations);
566 
567  Log::Info << rules.Scores() << " node combinations were scored." << std::endl;
568  Log::Info << rules.BaseCases() << " base cases were calculated." << std::endl;
569 }
570 
571 template<typename KernelType,
572  typename MetricType,
573  typename MatType,
574  template<typename TreeMetricType,
575  typename TreeStatType,
576  typename TreeMatType> class TreeType,
577  template<typename> class DualTreeTraversalType,
578  template<typename> class SingleTreeTraversalType>
579 void KDE<KernelType,
580  MetricType,
581  MatType,
582  TreeType,
583  DualTreeTraversalType,
584  SingleTreeTraversalType>::
585 Evaluate(arma::vec& estimations)
586 {
587  // Check whether has already been trained.
588  if (!trained)
589  {
590  throw std::runtime_error("cannot evaluate KDE model: model needs to be "
591  "trained before evaluation");
592  }
593 
594  // Get estimations vector ready.
595  estimations.clear();
596  estimations.set_size(referenceTree->Dataset().n_cols);
597  estimations.fill(arma::fill::zeros);
598 
599  // Clean accumulated alpha if Monte Carlo estimations are available.
600  if (monteCarlo && std::is_same<KernelType, kernel::GaussianKernel>::value)
601  {
602  Timer::Start("cleaning_query_tree");
603  KDECleanRules<Tree> cleanRules;
604  SingleTreeTraversalType<KDECleanRules<Tree>> cleanTraverser(cleanRules);
605  cleanTraverser.Traverse(0, *referenceTree);
606  Timer::Stop("cleaning_query_tree");
607  }
608 
609  Timer::Start("computing_kde");
610 
611  // Evaluate.
613  RuleType rules = RuleType(referenceTree->Dataset(),
614  referenceTree->Dataset(),
615  estimations,
616  relError,
617  absError,
618  mcProb,
619  initialSampleSize,
620  mcEntryCoef,
621  mcBreakCoef,
622  metric,
623  kernel,
624  monteCarlo,
625  true);
626 
627  if (mode == DUAL_TREE_MODE)
628  {
629  // Create traverser.
630  DualTreeTraversalType<RuleType> traverser(rules);
631  traverser.Traverse(*referenceTree, *referenceTree);
632  }
633  else if (mode == SINGLE_TREE_MODE)
634  {
635  SingleTreeTraversalType<RuleType> traverser(rules);
636  for (size_t i = 0; i < referenceTree->Dataset().n_cols; ++i)
637  traverser.Traverse(i, *referenceTree);
638  }
639 
640  estimations /= referenceTree->Dataset().n_cols;
641  // Rearrange if necessary.
642  RearrangeEstimations(*oldFromNewReferences, estimations);
643  Timer::Stop("computing_kde");
644 
645  Log::Info << rules.Scores() << " node combinations were scored." << std::endl;
646  Log::Info << rules.BaseCases() << " base cases were calculated." << std::endl;
647 }
648 
649 template<typename KernelType,
650  typename MetricType,
651  typename MatType,
652  template<typename TreeMetricType,
653  typename TreeStatType,
654  typename TreeMatType> class TreeType,
655  template<typename> class DualTreeTraversalType,
656  template<typename> class SingleTreeTraversalType>
657 void KDE<KernelType,
658  MetricType,
659  MatType,
660  TreeType,
661  DualTreeTraversalType,
662  SingleTreeTraversalType>::
663 RelativeError(const double newError)
664 {
665  CheckErrorValues(newError, absError);
666  relError = newError;
667 }
668 
669 template<typename KernelType,
670  typename MetricType,
671  typename MatType,
672  template<typename TreeMetricType,
673  typename TreeStatType,
674  typename TreeMatType> class TreeType,
675  template<typename> class DualTreeTraversalType,
676  template<typename> class SingleTreeTraversalType>
677 void KDE<KernelType,
678  MetricType,
679  MatType,
680  TreeType,
681  DualTreeTraversalType,
682  SingleTreeTraversalType>::
683 AbsoluteError(const double newError)
684 {
685  CheckErrorValues(relError, newError);
686  absError = newError;
687 }
688 
689 template<typename KernelType,
690  typename MetricType,
691  typename MatType,
692  template<typename TreeMetricType,
693  typename TreeStatType,
694  typename TreeMatType> class TreeType,
695  template<typename> class DualTreeTraversalType,
696  template<typename> class SingleTreeTraversalType>
697 void KDE<KernelType,
698  MetricType,
699  MatType,
700  TreeType,
701  DualTreeTraversalType,
702  SingleTreeTraversalType>::
703 MCProb(const double newProb)
704 {
705  if (newProb < 0 || newProb >= 1)
706  {
707  throw std::invalid_argument("Monte Carlo probability must be a value "
708  "greater than or equal to 0 and smaller than"
709  "1");
710  }
711  mcProb = newProb;
712 }
713 
714 template<typename KernelType,
715  typename MetricType,
716  typename MatType,
717  template<typename TreeMetricType,
718  typename TreeStatType,
719  typename TreeMatType> class TreeType,
720  template<typename> class DualTreeTraversalType,
721  template<typename> class SingleTreeTraversalType>
722 void KDE<KernelType,
723  MetricType,
724  MatType,
725  TreeType,
726  DualTreeTraversalType,
727  SingleTreeTraversalType>::
728 MCEntryCoef(const double newCoef)
729 {
730  if (newCoef < 1)
731  {
732  throw std::invalid_argument("Monte Carlo entry coefficient must be a value "
733  "greater than or equal to 1");
734  }
735  mcEntryCoef = newCoef;
736 }
737 
738 template<typename KernelType,
739  typename MetricType,
740  typename MatType,
741  template<typename TreeMetricType,
742  typename TreeStatType,
743  typename TreeMatType> class TreeType,
744  template<typename> class DualTreeTraversalType,
745  template<typename> class SingleTreeTraversalType>
746 void KDE<KernelType,
747  MetricType,
748  MatType,
749  TreeType,
750  DualTreeTraversalType,
751  SingleTreeTraversalType>::
752 MCBreakCoef(const double newCoef)
753 {
754  if (newCoef <= 0 || newCoef > 1)
755  {
756  throw std::invalid_argument("Monte Carlo break coefficient must be a value "
757  "greater than 0 and less than or equal to 1");
758  }
759  mcBreakCoef = newCoef;
760 }
761 
762 template<typename KernelType,
763  typename MetricType,
764  typename MatType,
765  template<typename TreeMetricType,
766  typename TreeStatType,
767  typename TreeMatType> class TreeType,
768  template<typename> class DualTreeTraversalType,
769  template<typename> class SingleTreeTraversalType>
770 template<typename Archive>
771 void KDE<KernelType,
772  MetricType,
773  MatType,
774  TreeType,
775  DualTreeTraversalType,
776  SingleTreeTraversalType>::
777 serialize(Archive& ar, const uint32_t /* version */)
778 {
779  // Serialize preferences.
780  ar(CEREAL_NVP(relError));
781  ar(CEREAL_NVP(absError));
782  ar(CEREAL_NVP(trained));
783  ar(CEREAL_NVP(mode));
784  ar(CEREAL_NVP(monteCarlo));
785  ar(CEREAL_NVP(mcProb));
786  ar(CEREAL_NVP(initialSampleSize));
787  ar(CEREAL_NVP(mcEntryCoef));
788  ar(CEREAL_NVP(mcBreakCoef));
789 
790  // If we are loading, clean up memory if necessary.
791  if (cereal::is_loading<Archive>())
792  {
793  if (ownsReferenceTree && referenceTree)
794  {
795  delete referenceTree;
796  delete oldFromNewReferences;
797  }
798  // After loading tree, we own it.
799  ownsReferenceTree = true;
800  }
801 
802  // Serialize the rest of values.
803  ar(CEREAL_NVP(kernel));
804  ar(CEREAL_NVP(metric));
805  ar(CEREAL_POINTER(referenceTree));
806  ar(CEREAL_POINTER(oldFromNewReferences));
807 }
808 
809 template<typename KernelType,
810  typename MetricType,
811  typename MatType,
812  template<typename TreeMetricType,
813  typename TreeStatType,
814  typename TreeMatType> class TreeType,
815  template<typename> class DualTreeTraversalType,
816  template<typename> class SingleTreeTraversalType>
817 void KDE<KernelType,
818  MetricType,
819  MatType,
820  TreeType,
821  DualTreeTraversalType,
822  SingleTreeTraversalType>::
823 CheckErrorValues(const double relError, const double absError)
824 {
825  if (relError < 0 || relError > 1)
826  {
827  throw std::invalid_argument("Relative error tolerance must be a value "
828  "between 0 and 1");
829  }
830  if (absError < 0)
831  {
832  throw std::invalid_argument("Absolute error tolerance must be a value "
833  "greater than or equal to 0");
834  }
835 }
836 
837 template<typename KernelType,
838  typename MetricType,
839  typename MatType,
840  template<typename TreeMetricType,
841  typename TreeStatType,
842  typename TreeMatType> class TreeType,
843  template<typename> class DualTreeTraversalType,
844  template<typename> class SingleTreeTraversalType>
845 void KDE<KernelType,
846  MetricType,
847  MatType,
848  TreeType,
849  DualTreeTraversalType,
850  SingleTreeTraversalType>::
851 RearrangeEstimations(const std::vector<size_t>& oldFromNew,
852  arma::vec& estimations)
853 {
855  {
856  const size_t nQueries = oldFromNew.size();
857  arma::vec rearrangedEstimations(nQueries);
858 
859  // Remap vector.
860  for (size_t i = 0; i < nQueries; ++i)
861  rearrangedEstimations(oldFromNew.at(i)) = estimations(i);
862 
863  estimations = std::move(rearrangedEstimations);
864  }
865 }
866 
867 } // namespace kde
868 } // namespace mlpack
static constexpr bool monteCarlo
Whether to use Monte Carlo estimations when possible.
Definition: kde.hpp:44
static constexpr double mcProb
Probability of a Monte Carlo estimation to be bounded by the relative error tolerance.
Definition: kde.hpp:48
TreeType * BuildTree(MatType &&dataset, std::vector< size_t > &oldFromNew, const typename std::enable_if< tree::TreeTraits< TreeType >::RearrangesDataset >::type *=0)
Construct tree that rearranges the dataset.
Definition: kde_impl.hpp:21
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
A dual-tree traversal Rules class for cleaning used trees before performing kernel density estimation...
Definition: kde_rules.hpp:189
static constexpr KDEMode mode
KDE algorithm mode.
Definition: kde.hpp:41
static constexpr double relError
Relative error tolerance.
Definition: kde.hpp:35
static MLPACK_EXPORT util::PrefixedOutStream Warn
Prints warning messages prefixed with [WARN ].
Definition: log.hpp:87
static constexpr double mcEntryCoef
Monte Carlo entry coefficient.
Definition: kde.hpp:54
static constexpr double absError
Absolute error tolerance.
Definition: kde.hpp:38
Definition: hmm_train_main.cpp:300
A dual-tree traversal Rules class for kernel density estimation.
Definition: kde_rules.hpp:26
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
The TreeTraits class provides compile-time information on the characteristics of a given tree type...
Definition: tree_traits.hpp:77
static constexpr double mcBreakCoef
Monte Carlo break coefficient.
Definition: kde.hpp:57
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_wrapper.hpp:96
static constexpr size_t initialSampleSize
Initial sample size for Monte Carlo estimations.
Definition: kde.hpp:51