20 template<
typename TreeType,
typename MatType>
23 std::vector<size_t>& oldFromNew,
24 const typename std::enable_if<
27 return new TreeType(std::forward<MatType>(dataset), oldFromNew);
31 template<
typename TreeType,
typename MatType>
34 const std::vector<size_t>& ,
35 const typename std::enable_if<
38 return new TreeType(std::forward<MatType>(dataset));
41 template<
typename KernelType,
44 template<
typename TreeMetricType,
45 typename TreeStatType,
46 typename TreeMatType>
class TreeType,
47 template<typename> class DualTreeTraversalType,
48 template<typename> class SingleTreeTraversalType>
53 DualTreeTraversalType,
54 SingleTreeTraversalType>::
55 KDE(const double relError,
56 const double absError,
60 const bool monteCarlo,
62 const size_t initialSampleSize,
63 const double mcEntryCoef,
64 const double mcBreakCoef) :
67 referenceTree(nullptr),
68 oldFromNewReferences(nullptr),
71 ownsReferenceTree(false),
74 monteCarlo(monteCarlo),
75 initialSampleSize(initialSampleSize)
77 CheckErrorValues(relError, absError);
79 MCEntryCoef(mcEntryCoef);
80 MCBreakCoef(mcBreakCoef);
83 template<
typename KernelType,
86 template<
typename TreeMetricType,
87 typename TreeStatType,
88 typename TreeMatType>
class TreeType,
89 template<typename> class DualTreeTraversalType,
90 template<typename> class SingleTreeTraversalType>
95 DualTreeTraversalType,
96 SingleTreeTraversalType>::
97 KDE(const KDE& other) :
98 kernel(KernelType(other.kernel)),
99 metric(MetricType(other.metric)),
100 relError(other.relError),
101 absError(other.absError),
102 ownsReferenceTree(other.ownsReferenceTree),
103 trained(other.trained),
105 monteCarlo(other.monteCarlo),
106 mcProb(other.mcProb),
107 initialSampleSize(other.initialSampleSize),
108 mcEntryCoef(other.mcEntryCoef),
109 mcBreakCoef(other.mcBreakCoef)
113 if (ownsReferenceTree)
115 oldFromNewReferences =
116 new std::vector<size_t>(*other.oldFromNewReferences);
117 referenceTree =
new Tree(*other.referenceTree);
121 oldFromNewReferences = other.oldFromNewReferences;
122 referenceTree = other.referenceTree;
127 template<
typename KernelType,
130 template<
typename TreeMetricType,
131 typename TreeStatType,
132 typename TreeMatType>
class TreeType,
133 template<typename> class DualTreeTraversalType,
134 template<typename> class SingleTreeTraversalType>
139 DualTreeTraversalType,
140 SingleTreeTraversalType>::
142 kernel(std::move(other.kernel)),
143 metric(std::move(other.metric)),
144 referenceTree(other.referenceTree),
145 oldFromNewReferences(other.oldFromNewReferences),
146 relError(other.relError),
147 absError(other.absError),
148 ownsReferenceTree(other.ownsReferenceTree),
149 trained(other.trained),
151 monteCarlo(other.monteCarlo),
152 mcProb(other.mcProb),
153 initialSampleSize(other.initialSampleSize),
154 mcEntryCoef(other.mcEntryCoef),
155 mcBreakCoef(other.mcBreakCoef)
157 other.kernel = std::move(KernelType());
158 other.metric = std::move(MetricType());
159 other.referenceTree =
nullptr;
160 other.oldFromNewReferences =
nullptr;
163 other.ownsReferenceTree =
false;
164 other.trained =
false;
173 template<
typename KernelType,
176 template<
typename TreeMetricType,
177 typename TreeStatType,
178 typename TreeMatType>
class TreeType,
179 template<typename> class DualTreeTraversalType,
180 template<typename> class SingleTreeTraversalType>
185 DualTreeTraversalType,
186 SingleTreeTraversalType>&
191 DualTreeTraversalType,
192 SingleTreeTraversalType>::
193 operator=(
const KDE& other)
198 if (ownsReferenceTree)
200 delete referenceTree;
201 delete oldFromNewReferences;
203 kernel = KernelType(other.kernel);
204 metric = MetricType(other.metric);
205 relError = other.relError;
206 absError = other.absError;
207 ownsReferenceTree = other.ownsReferenceTree;
208 trained = other.trained;
210 monteCarlo = other.monteCarlo;
211 mcProb = other.mcProb;
212 initialSampleSize = other.initialSampleSize;
213 mcEntryCoef = other.mcEntryCoef;
214 mcBreakCoef = other.mcBreakCoef;
217 if (ownsReferenceTree)
219 oldFromNewReferences =
220 new std::vector<size_t>(*other.oldFromNewReferences);
221 referenceTree =
new Tree(*other.referenceTree);
225 oldFromNewReferences = other.oldFromNewReferences;
226 referenceTree = other.referenceTree;
233 template<
typename KernelType,
236 template<
typename TreeMetricType,
237 typename TreeStatType,
238 typename TreeMatType>
class TreeType,
239 template<typename> class DualTreeTraversalType,
240 template<typename> class SingleTreeTraversalType>
245 DualTreeTraversalType,
246 SingleTreeTraversalType>&
251 DualTreeTraversalType,
252 SingleTreeTraversalType>::
253 operator=(KDE&& other)
258 if (ownsReferenceTree)
260 delete referenceTree;
261 delete oldFromNewReferences;
265 this->kernel = std::move(other.kernel);
266 this->metric = std::move(other.metric);
267 this->referenceTree = std::move(other.referenceTree);
268 this->oldFromNewReferences = std::move(other.oldFromNewReferences);
269 this->relError = other.relError;
270 this->absError = other.absError;
271 this->ownsReferenceTree = other.ownsReferenceTree;
272 this->trained = other.trained;
273 this->mode = other.mode;
274 this->monteCarlo = other.monteCarlo;
275 this->mcProb = other.mcProb;
276 this->initialSampleSize = other.initialSampleSize;
277 this->mcEntryCoef = other.mcEntryCoef;
278 this->mcBreakCoef = other.mcBreakCoef;
283 template<
typename KernelType,
286 template<
typename TreeMetricType,
287 typename TreeStatType,
288 typename TreeMatType>
class TreeType,
289 template<typename> class DualTreeTraversalType,
290 template<typename> class SingleTreeTraversalType>
295 DualTreeTraversalType,
296 SingleTreeTraversalType>::
299 if (ownsReferenceTree)
301 delete referenceTree;
302 delete oldFromNewReferences;
306 template<
typename KernelType,
309 template<
typename TreeMetricType,
310 typename TreeStatType,
311 typename TreeMatType>
class TreeType,
312 template<typename> class DualTreeTraversalType,
313 template<typename> class SingleTreeTraversalType>
318 DualTreeTraversalType,
319 SingleTreeTraversalType>::
323 if (referenceSet.n_cols == 0)
325 throw std::invalid_argument(
"cannot train KDE model with an empty " 329 if (ownsReferenceTree)
331 delete referenceTree;
332 delete oldFromNewReferences;
335 this->ownsReferenceTree =
true;
337 this->oldFromNewReferences =
new std::vector<size_t>;
338 this->referenceTree = BuildTree<Tree>(std::move(referenceSet),
339 *oldFromNewReferences);
341 this->trained =
true;
344 template<
typename KernelType,
347 template<
typename TreeMetricType,
348 typename TreeStatType,
349 typename TreeMatType>
class TreeType,
350 template<typename> class DualTreeTraversalType,
351 template<typename> class SingleTreeTraversalType>
356 DualTreeTraversalType,
357 SingleTreeTraversalType>::
358 Train(Tree* referenceTree, std::vector<size_t>* oldFromNewReferences)
361 if (referenceTree->Dataset().n_cols == 0)
363 throw std::invalid_argument(
"cannot train KDE model with an empty " 367 if (ownsReferenceTree ==
true)
369 delete this->referenceTree;
370 delete this->oldFromNewReferences;
373 this->ownsReferenceTree =
false;
374 this->referenceTree = referenceTree;
375 this->oldFromNewReferences = oldFromNewReferences;
376 this->trained =
true;
379 template<
typename KernelType,
382 template<
typename TreeMetricType,
383 typename TreeStatType,
384 typename TreeMatType>
class TreeType,
385 template<typename> class DualTreeTraversalType,
386 template<typename> class SingleTreeTraversalType>
391 DualTreeTraversalType,
392 SingleTreeTraversalType>::
393 Evaluate(MatType querySet, arma::vec& estimations)
395 if (mode == DUAL_TREE_MODE)
398 std::vector<size_t> oldFromNewQueries;
399 Tree* queryTree = BuildTree<Tree>(std::move(querySet), oldFromNewQueries);
403 this->Evaluate(queryTree, oldFromNewQueries, estimations);
405 catch (std::exception& e)
413 else if (mode == SINGLE_TREE_MODE)
417 estimations.set_size(querySet.n_cols);
418 estimations.fill(arma::fill::zeros);
423 throw std::runtime_error(
"cannot evaluate KDE model: model needs to be " 424 "trained before evaluation");
428 if (querySet.n_cols == 0)
430 Log::Warn <<
"KDE::Evaluate(): querySet is empty, no predictions will " 431 <<
"be returned" << std::endl;
436 if (querySet.n_rows != referenceTree->Dataset().n_rows)
438 throw std::invalid_argument(
"cannot evaluate KDE model: querySet and " 439 "referenceSet dimensions don't match");
446 RuleType rules = RuleType(referenceTree->Dataset(),
461 SingleTreeTraversalType<RuleType> traverser(rules);
464 for (
size_t i = 0; i < querySet.n_cols; ++i)
465 traverser.Traverse(i, *referenceTree);
467 estimations /= referenceTree->Dataset().n_cols;
470 Log::Info << rules.Scores() <<
" node combinations were scored." 472 Log::Info << rules.BaseCases() <<
" base cases were calculated." 477 template<
typename KernelType,
480 template<
typename TreeMetricType,
481 typename TreeStatType,
482 typename TreeMatType>
class TreeType,
483 template<typename> class DualTreeTraversalType,
484 template<typename> class SingleTreeTraversalType>
489 DualTreeTraversalType,
490 SingleTreeTraversalType>::
491 Evaluate(Tree* queryTree,
492 const std::vector<size_t>& oldFromNewQueries,
493 arma::vec& estimations)
497 estimations.set_size(queryTree->Dataset().n_cols);
498 estimations.fill(arma::fill::zeros);
503 throw std::runtime_error(
"cannot evaluate KDE model: model needs to be " 504 "trained before evaluation");
508 if (queryTree->Dataset().n_cols == 0)
510 Log::Warn <<
"KDE::Evaluate(): querySet is empty, no predictions will " 511 <<
"be returned" << std::endl;
516 if (queryTree->Dataset().n_rows != referenceTree->Dataset().n_rows)
518 throw std::invalid_argument(
"cannot evaluate KDE model: querySet and " 519 "referenceSet dimensions don't match");
523 if (mode != DUAL_TREE_MODE)
525 throw std::invalid_argument(
"cannot evaluate KDE model: cannot use " 526 "a query tree when mode is different from " 531 if (monteCarlo && std::is_same<KernelType, kernel::GaussianKernel>::value)
535 SingleTreeTraversalType<KDECleanRules<Tree>> cleanTraverser(cleanRules);
536 cleanTraverser.Traverse(0, *queryTree);
544 RuleType rules = RuleType(referenceTree->Dataset(),
545 queryTree->Dataset(),
559 DualTreeTraversalType<RuleType> traverser(rules);
560 traverser.Traverse(*queryTree, *referenceTree);
561 estimations /= referenceTree->Dataset().n_cols;
565 RearrangeEstimations(oldFromNewQueries, estimations);
567 Log::Info << rules.Scores() <<
" node combinations were scored." << std::endl;
568 Log::Info << rules.BaseCases() <<
" base cases were calculated." << std::endl;
571 template<
typename KernelType,
574 template<
typename TreeMetricType,
575 typename TreeStatType,
576 typename TreeMatType>
class TreeType,
577 template<typename> class DualTreeTraversalType,
578 template<typename> class SingleTreeTraversalType>
583 DualTreeTraversalType,
584 SingleTreeTraversalType>::
585 Evaluate(arma::vec& estimations)
590 throw std::runtime_error(
"cannot evaluate KDE model: model needs to be " 591 "trained before evaluation");
596 estimations.set_size(referenceTree->Dataset().n_cols);
597 estimations.fill(arma::fill::zeros);
600 if (monteCarlo && std::is_same<KernelType, kernel::GaussianKernel>::value)
604 SingleTreeTraversalType<KDECleanRules<Tree>> cleanTraverser(cleanRules);
605 cleanTraverser.Traverse(0, *referenceTree);
613 RuleType rules = RuleType(referenceTree->Dataset(),
614 referenceTree->Dataset(),
627 if (mode == DUAL_TREE_MODE)
630 DualTreeTraversalType<RuleType> traverser(rules);
631 traverser.Traverse(*referenceTree, *referenceTree);
633 else if (mode == SINGLE_TREE_MODE)
635 SingleTreeTraversalType<RuleType> traverser(rules);
636 for (
size_t i = 0; i < referenceTree->Dataset().n_cols; ++i)
637 traverser.Traverse(i, *referenceTree);
640 estimations /= referenceTree->Dataset().n_cols;
642 RearrangeEstimations(*oldFromNewReferences, estimations);
645 Log::Info << rules.Scores() <<
" node combinations were scored." << std::endl;
646 Log::Info << rules.BaseCases() <<
" base cases were calculated." << std::endl;
649 template<
typename KernelType,
652 template<
typename TreeMetricType,
653 typename TreeStatType,
654 typename TreeMatType>
class TreeType,
655 template<typename> class DualTreeTraversalType,
656 template<typename> class SingleTreeTraversalType>
661 DualTreeTraversalType,
662 SingleTreeTraversalType>::
663 RelativeError(const double newError)
665 CheckErrorValues(newError, absError);
669 template<
typename KernelType,
672 template<
typename TreeMetricType,
673 typename TreeStatType,
674 typename TreeMatType>
class TreeType,
675 template<typename> class DualTreeTraversalType,
676 template<typename> class SingleTreeTraversalType>
681 DualTreeTraversalType,
682 SingleTreeTraversalType>::
683 AbsoluteError(const double newError)
685 CheckErrorValues(relError, newError);
689 template<
typename KernelType,
692 template<
typename TreeMetricType,
693 typename TreeStatType,
694 typename TreeMatType>
class TreeType,
695 template<typename> class DualTreeTraversalType,
696 template<typename> class SingleTreeTraversalType>
701 DualTreeTraversalType,
702 SingleTreeTraversalType>::
703 MCProb(const double newProb)
705 if (newProb < 0 || newProb >= 1)
707 throw std::invalid_argument(
"Monte Carlo probability must be a value " 708 "greater than or equal to 0 and smaller than" 714 template<
typename KernelType,
717 template<
typename TreeMetricType,
718 typename TreeStatType,
719 typename TreeMatType>
class TreeType,
720 template<typename> class DualTreeTraversalType,
721 template<typename> class SingleTreeTraversalType>
726 DualTreeTraversalType,
727 SingleTreeTraversalType>::
728 MCEntryCoef(const double newCoef)
732 throw std::invalid_argument(
"Monte Carlo entry coefficient must be a value " 733 "greater than or equal to 1");
735 mcEntryCoef = newCoef;
738 template<
typename KernelType,
741 template<
typename TreeMetricType,
742 typename TreeStatType,
743 typename TreeMatType>
class TreeType,
744 template<typename> class DualTreeTraversalType,
745 template<typename> class SingleTreeTraversalType>
750 DualTreeTraversalType,
751 SingleTreeTraversalType>::
752 MCBreakCoef(const double newCoef)
754 if (newCoef <= 0 || newCoef > 1)
756 throw std::invalid_argument(
"Monte Carlo break coefficient must be a value " 757 "greater than 0 and less than or equal to 1");
759 mcBreakCoef = newCoef;
762 template<
typename KernelType,
765 template<
typename TreeMetricType,
766 typename TreeStatType,
767 typename TreeMatType>
class TreeType,
768 template<typename> class DualTreeTraversalType,
769 template<typename> class SingleTreeTraversalType>
770 template<typename Archive>
775 DualTreeTraversalType,
776 SingleTreeTraversalType>::
777 serialize(Archive& ar, const uint32_t )
780 ar(CEREAL_NVP(relError));
781 ar(CEREAL_NVP(absError));
782 ar(CEREAL_NVP(trained));
783 ar(CEREAL_NVP(mode));
784 ar(CEREAL_NVP(monteCarlo));
785 ar(CEREAL_NVP(mcProb));
786 ar(CEREAL_NVP(initialSampleSize));
787 ar(CEREAL_NVP(mcEntryCoef));
788 ar(CEREAL_NVP(mcBreakCoef));
791 if (cereal::is_loading<Archive>())
793 if (ownsReferenceTree && referenceTree)
795 delete referenceTree;
796 delete oldFromNewReferences;
799 ownsReferenceTree =
true;
803 ar(CEREAL_NVP(kernel));
804 ar(CEREAL_NVP(metric));
809 template<
typename KernelType,
812 template<
typename TreeMetricType,
813 typename TreeStatType,
814 typename TreeMatType>
class TreeType,
815 template<typename> class DualTreeTraversalType,
816 template<typename> class SingleTreeTraversalType>
821 DualTreeTraversalType,
822 SingleTreeTraversalType>::
823 CheckErrorValues(const double relError, const double absError)
825 if (relError < 0 || relError > 1)
827 throw std::invalid_argument(
"Relative error tolerance must be a value " 832 throw std::invalid_argument(
"Absolute error tolerance must be a value " 833 "greater than or equal to 0");
837 template<
typename KernelType,
840 template<
typename TreeMetricType,
841 typename TreeStatType,
842 typename TreeMatType>
class TreeType,
843 template<typename> class DualTreeTraversalType,
844 template<typename> class SingleTreeTraversalType>
849 DualTreeTraversalType,
850 SingleTreeTraversalType>::
851 RearrangeEstimations(const std::vector<size_t>& oldFromNew,
852 arma::vec& estimations)
856 const size_t nQueries = oldFromNew.size();
857 arma::vec rearrangedEstimations(nQueries);
860 for (
size_t i = 0; i < nQueries; ++i)
861 rearrangedEstimations(oldFromNew.at(i)) = estimations(i);
863 estimations = std::move(rearrangedEstimations);
static constexpr bool monteCarlo
Whether to use Monte Carlo estimations when possible.
Definition: kde.hpp:44
static constexpr double mcProb
Probability of a Monte Carlo estimation to be bounded by the relative error tolerance.
Definition: kde.hpp:48
TreeType * BuildTree(MatType &&dataset, std::vector< size_t > &oldFromNew, const typename std::enable_if< tree::TreeTraits< TreeType >::RearrangesDataset >::type *=0)
Construct tree that rearranges the dataset.
Definition: kde_impl.hpp:21
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
A dual-tree traversal Rules class for cleaning used trees before performing kernel density estimation...
Definition: kde_rules.hpp:189
static constexpr KDEMode mode
KDE algorithm mode.
Definition: kde.hpp:41
static constexpr double relError
Relative error tolerance.
Definition: kde.hpp:35
static MLPACK_EXPORT util::PrefixedOutStream Warn
Prints warning messages prefixed with [WARN ].
Definition: log.hpp:87
static constexpr double mcEntryCoef
Monte Carlo entry coefficient.
Definition: kde.hpp:54
static constexpr double absError
Absolute error tolerance.
Definition: kde.hpp:38
Definition: hmm_train_main.cpp:300
A dual-tree traversal Rules class for kernel density estimation.
Definition: kde_rules.hpp:26
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
The TreeTraits class provides compile-time information on the characteristics of a given tree type...
Definition: tree_traits.hpp:77
static constexpr double mcBreakCoef
Monte Carlo break coefficient.
Definition: kde.hpp:57
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if –verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointer.
Definition: pointer_wrapper.hpp:96
static constexpr size_t initialSampleSize
Initial sample size for Monte Carlo estimations.
Definition: kde.hpp:51