12 #ifndef MLPACK_METHODS_FASTMKS_FASTMKS_IMPL_HPP 13 #define MLPACK_METHODS_FASTMKS_FASTMKS_IMPL_HPP 26 template<
typename KernelType,
28 template<
typename TreeMetricType,
29 typename TreeStatType,
30 typename TreeMatType>
class TreeType>
33 referenceSet(new MatType()),
37 singleMode(singleMode),
42 referenceTree =
new Tree(*referenceSet);
47 template<
typename KernelType,
49 template<
typename TreeMetricType,
50 typename TreeStatType,
51 typename TreeMatType>
class TreeType>
53 const MatType& referenceSet,
54 const bool singleMode,
56 referenceSet(&referenceSet),
60 singleMode(singleMode),
65 referenceTree =
new Tree(referenceSet);
70 template<
typename KernelType,
72 template<
typename TreeMetricType,
73 typename TreeStatType,
74 typename TreeMatType>
class TreeType>
77 const bool singleMode,
79 referenceSet(&referenceSet),
83 singleMode(singleMode),
91 referenceTree =
new Tree(referenceSet, metric);
97 template<
typename KernelType,
99 template<
typename TreeMetricType,
100 typename TreeStatType,
101 typename TreeMatType>
class TreeType>
103 MatType&& referenceSet,
104 const bool singleMode,
106 referenceSet(naive ? new MatType(std::move(referenceSet)) : NULL),
110 singleMode(singleMode),
116 referenceTree =
new Tree(std::move(referenceSet));
117 referenceSet = &referenceTree->Dataset();
123 template<
typename KernelType,
125 template<
typename TreeMetricType,
126 typename TreeStatType,
127 typename TreeMatType>
class TreeType>
130 const bool singleMode,
132 referenceSet(naive ? new MatType(std::move(referenceSet)) : NULL),
136 singleMode(singleMode),
145 referenceTree =
new Tree(referenceSet, metric);
146 referenceSet = &referenceTree->Dataset();
153 template<
typename KernelType,
155 template<
typename TreeMetricType,
156 typename TreeStatType,
157 typename TreeMatType>
class TreeType>
159 const bool singleMode) :
160 referenceSet(&referenceTree->Dataset()),
161 referenceTree(referenceTree),
164 singleMode(singleMode),
166 metric(referenceTree->Metric())
171 template<
typename KernelType,
173 template<
typename TreeMetricType,
174 typename TreeStatType,
175 typename TreeMatType>
class TreeType>
178 referenceTree(other.referenceTree ? new Tree(*other.referenceTree) : NULL),
179 treeOwner(other.referenceTree != NULL),
180 setOwner(other.referenceTree == NULL),
181 singleMode(other.singleMode),
187 referenceSet = &referenceTree->Dataset();
189 referenceSet =
new MatType(*other.referenceSet);
192 template<
typename KernelType,
194 template<
typename TreeMetricType,
195 typename TreeStatType,
196 typename TreeMatType>
class TreeType>
198 referenceSet(other.referenceSet),
199 referenceTree(other.referenceTree),
200 treeOwner(other.treeOwner),
201 setOwner(other.setOwner),
202 singleMode(other.singleMode),
204 metric(std::move(other.metric))
207 other.referenceSet = NULL;
208 other.referenceTree = NULL;
209 other.treeOwner =
false;
210 other.setOwner =
false;
211 other.singleMode =
false;
215 template<
typename KernelType,
217 template<
typename TreeMetricType,
218 typename TreeStatType,
219 typename TreeMatType>
class TreeType>
220 FastMKS<KernelType, MatType, TreeType>&
228 delete referenceTree;
232 referenceTree = NULL;
235 if (other.referenceTree)
237 referenceTree =
new Tree(*other.referenceTree);
238 referenceSet = &referenceTree->Dataset();
244 referenceSet =
new MatType(*other.referenceSet);
249 singleMode = other.singleMode;
253 template<
typename KernelType,
255 template<
typename TreeMetricType,
256 typename TreeStatType,
257 typename TreeMatType>
class TreeType>
258 FastMKS<KernelType, MatType, TreeType>&
263 referenceSet = other.referenceSet;
264 referenceTree = other.referenceTree;
265 treeOwner = other.treeOwner;
266 setOwner = other.setOwner;
267 singleMode = other.singleMode;
269 metric = std::move(other.metric);
272 other.referenceSet =
nullptr;
273 other.referenceTree =
nullptr;
274 other.treeOwner =
false;
275 other.setOwner =
false;
276 other.singleMode =
false;
282 template<
typename KernelType,
284 template<
typename TreeMetricType,
285 typename TreeStatType,
286 typename TreeMatType>
class TreeType>
287 FastMKS<KernelType, MatType, TreeType>::~FastMKS()
290 if (treeOwner && referenceTree)
291 delete referenceTree;
296 template<
typename KernelType,
298 template<
typename TreeMetricType,
299 typename TreeStatType,
300 typename TreeMatType>
class TreeType>
301 void FastMKS<KernelType, MatType, TreeType>::
Train(const MatType& referenceSet)
304 delete this->referenceSet;
306 this->referenceSet = &referenceSet;
307 this->setOwner =
false;
311 if (treeOwner && referenceTree)
312 delete referenceTree;
313 referenceTree =
new Tree(referenceSet, metric);
318 template<
typename KernelType,
320 template<
typename TreeMetricType,
321 typename TreeStatType,
322 typename TreeMatType>
class TreeType>
323 void FastMKS<KernelType, MatType, TreeType>::
Train(const MatType& referenceSet,
327 delete this->referenceSet;
329 this->referenceSet = &referenceSet;
331 this->setOwner =
false;
335 if (treeOwner && referenceTree)
336 delete referenceTree;
337 referenceTree =
new Tree(referenceSet, metric);
342 template<
typename KernelType,
344 template<
typename TreeMetricType,
345 typename TreeStatType,
346 typename TreeMatType>
class TreeType>
347 void FastMKS<KernelType, MatType, TreeType>::
Train(MatType&& referenceSet)
350 delete this->referenceSet;
354 if (treeOwner && referenceTree)
355 delete referenceTree;
356 referenceTree =
new Tree(std::move(referenceSet), metric);
357 referenceSet = referenceTree->Dataset();
363 this->referenceSet =
new MatType(std::move(referenceSet));
364 this->setOwner =
true;
368 template<
typename KernelType,
370 template<
typename TreeMetricType,
371 typename TreeStatType,
372 typename TreeMatType>
class TreeType>
373 void FastMKS<KernelType, MatType, TreeType>::
Train(MatType&& referenceSet,
377 delete this->referenceSet;
383 if (treeOwner && referenceTree)
384 delete referenceTree;
385 referenceTree =
new Tree(std::move(referenceSet), metric);
391 this->referenceSet =
new MatType(std::move(referenceSet));
392 this->setOwner =
true;
396 template<
typename KernelType,
398 template<
typename TreeMetricType,
399 typename TreeStatType,
400 typename TreeMatType>
class TreeType>
401 void FastMKS<KernelType, MatType, TreeType>::
Train(Tree* tree)
404 throw std::invalid_argument(
"cannot call FastMKS::Train() with a tree when " 405 "in naive search mode");
408 delete this->referenceSet;
410 this->referenceSet = &tree->Dataset();
412 this->setOwner =
false;
414 if (treeOwner && referenceTree)
415 delete referenceTree;
417 this->referenceTree = tree;
418 this->treeOwner =
true;
421 template<
typename KernelType,
423 template<
typename TreeMetricType,
424 typename TreeStatType,
425 typename TreeMatType>
class TreeType>
426 void FastMKS<KernelType, MatType, TreeType>::Search(
427 const MatType& querySet,
429 arma::Mat<size_t>& indices,
432 if (k > referenceSet->n_cols)
434 std::stringstream ss;
435 ss <<
"requested value of k (" << k <<
") is greater than the number of " 436 <<
"points in the reference set (" << referenceSet->n_cols <<
")";
437 throw std::invalid_argument(ss.str());
440 if (querySet.n_rows != referenceSet->n_rows)
442 std::stringstream ss;
443 ss <<
"The number of dimensions in the query set (" << querySet.n_rows
444 <<
") must be equal to the number of dimensions in the reference set (" 445 << referenceSet->n_rows <<
")!";
446 throw std::invalid_argument(ss.str());
452 indices.set_size(k, querySet.n_cols);
453 kernels.set_size(k, querySet.n_cols);
459 for (
size_t q = 0; q < querySet.n_cols; ++q)
461 const Candidate def = std::make_pair(-DBL_MAX,
size_t() - 1);
462 std::vector<Candidate> cList(k, def);
463 CandidateList pqueue(CandidateCmp(), std::move(cList));
465 for (
size_t r = 0; r < referenceSet->n_cols; ++r)
467 const double eval = metric.Kernel().Evaluate(querySet.col(q),
468 referenceSet->col(r));
470 if (eval > pqueue.top().first)
472 Candidate c = std::make_pair(eval, r);
478 for (
size_t j = 1; j <= k; ++j)
480 indices(k - j, q) = pqueue.top().second;
481 kernels(k - j, q) = pqueue.top().first;
497 RuleType rules(*referenceSet, querySet, k, metric.Kernel());
499 typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
501 for (
size_t i = 0; i < querySet.n_cols; ++i)
502 traverser.Traverse(i, *referenceTree);
504 Log::Info << rules.BaseCases() <<
" base cases." << std::endl;
505 Log::Info << rules.Scores() <<
" scores." << std::endl;
507 rules.GetResults(indices, kernels);
517 Tree queryTree(querySet);
520 Search(&queryTree, k, indices, kernels);
523 template<
typename KernelType,
525 template<
typename TreeMetricType,
526 typename TreeStatType,
527 typename TreeMatType>
class TreeType>
528 void FastMKS<KernelType, MatType, TreeType>::Search(
531 arma::Mat<size_t>& indices,
534 if (k > referenceSet->n_cols)
536 std::stringstream ss;
537 ss <<
"requested value of k (" << k <<
") is greater than the number of " 538 <<
"points in the reference set (" << referenceSet->n_cols <<
")";
539 throw std::invalid_argument(ss.str());
541 if (queryTree->Dataset().n_rows != referenceSet->n_rows)
543 std::stringstream ss;
544 ss <<
"The number of dimensions in the query set (" 545 << queryTree->Dataset().n_rows <<
") must be equal to the number of " 546 <<
"dimensions in the reference set (" << referenceSet->n_rows <<
")!";
547 throw std::invalid_argument(ss.str());
551 if (naive || singleMode)
553 throw std::invalid_argument(
"can't call Search() with a query tree when " 554 "single mode or naive search is enabled");
558 indices.set_size(k, queryTree->Dataset().n_cols);
559 kernels.set_size(k, queryTree->Dataset().n_cols);
563 RuleType rules(*referenceSet, queryTree->Dataset(), k, metric.Kernel());
565 typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
567 traverser.Traverse(*queryTree, *referenceTree);
569 Log::Info << rules.BaseCases() <<
" base cases." << std::endl;
570 Log::Info << rules.Scores() <<
" scores." << std::endl;
572 rules.GetResults(indices, kernels);
577 template<
typename KernelType,
579 template<
typename TreeMetricType,
580 typename TreeStatType,
581 typename TreeMatType>
class TreeType>
582 void FastMKS<KernelType, MatType, TreeType>::Search(
584 arma::Mat<size_t>& indices,
589 indices.set_size(k, referenceSet->n_cols);
590 kernels.set_size(k, referenceSet->n_cols);
596 for (
size_t q = 0; q < referenceSet->n_cols; ++q)
598 const Candidate def = std::make_pair(-DBL_MAX,
size_t() - 1);
599 std::vector<Candidate> cList(k, def);
600 CandidateList pqueue(CandidateCmp(), std::move(cList));
602 for (
size_t r = 0; r < referenceSet->n_cols; ++r)
607 const double eval = metric.Kernel().Evaluate(referenceSet->col(q),
608 referenceSet->col(r));
610 if (eval > pqueue.top().first)
612 Candidate c = std::make_pair(eval, r);
618 for (
size_t j = 1; j <= k; ++j)
620 indices(k - j, q) = pqueue.top().second;
621 kernels(k - j, q) = pqueue.top().first;
637 RuleType rules(*referenceSet, *referenceSet, k, metric.Kernel());
639 typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
641 for (
size_t i = 0; i < referenceSet->n_cols; ++i)
642 traverser.Traverse(i, *referenceTree);
645 const size_t numPrunes = traverser.NumPrunes();
647 Log::Info <<
"Pruned " << numPrunes <<
" nodes." << std::endl;
649 Log::Info << rules.BaseCases() <<
" base cases." << std::endl;
650 Log::Info << rules.Scores() <<
" scores." << std::endl;
652 rules.GetResults(indices, kernels);
661 Search(referenceTree, k, indices, kernels);
665 template<
typename KernelType,
667 template<
typename TreeMetricType,
668 typename TreeStatType,
669 typename TreeMatType>
class TreeType>
670 template<typename Archive>
671 void FastMKS<KernelType, MatType, TreeType>::serialize(
672 Archive& ar, const uint32_t )
675 ar(CEREAL_NVP(naive));
676 ar(CEREAL_NVP(singleMode));
682 if (cereal::is_loading<Archive>())
684 if (setOwner && referenceSet)
691 ar(CEREAL_NVP(metric));
696 if (cereal::is_loading<Archive>())
698 if (treeOwner && referenceTree)
699 delete referenceTree;
706 if (cereal::is_loading<Archive>())
708 if (setOwner && referenceSet)
711 referenceSet = &referenceTree->Dataset();
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The inner product metric, IPMetric, takes a given Mercer kernel (KernelType), and when Evaluate() is ...
Definition: ip_metric.hpp:32
FastMKS & operator=(const FastMKS &other)
Assign this model to be a copy of the given model.
Definition: fastmks_impl.hpp:221
Definition: hmm_train_main.cpp:300
The FastMKSRules class is a template helper class used by FastMKS class when performing exact max-ker...
Definition: fastmks_rules.hpp:34
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if –verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointer.
Definition: pointer_wrapper.hpp:96
An implementation of fast exact max-kernel search.
Definition: fastmks.hpp:63
A cover tree is a tree specifically designed to speed up nearest-neighbor computation in high-dimensi...
Definition: cover_tree.hpp:99