mlpack
fastmks_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_FASTMKS_FASTMKS_IMPL_HPP
13 #define MLPACK_METHODS_FASTMKS_FASTMKS_IMPL_HPP
14 
15 // In case it hasn't yet been included.
16 #include "fastmks.hpp"
17 
18 #include "fastmks_rules.hpp"
19 
21 
22 namespace mlpack {
23 namespace fastmks {
24 
25 // No data; create a model on an empty dataset.
26 template<typename KernelType,
27  typename MatType,
28  template<typename TreeMetricType,
29  typename TreeStatType,
30  typename TreeMatType> class TreeType>
31 FastMKS<KernelType, MatType, TreeType>::FastMKS(const bool singleMode,
32  const bool naive) :
33  referenceSet(new MatType()),
34  referenceTree(NULL),
35  treeOwner(true),
36  setOwner(true),
37  singleMode(singleMode),
38  naive(naive)
39 {
40  Timer::Start("tree_building");
41  if (!naive)
42  referenceTree = new Tree(*referenceSet);
43  Timer::Stop("tree_building");
44 }
45 
46 // No instantiated kernel.
47 template<typename KernelType,
48  typename MatType,
49  template<typename TreeMetricType,
50  typename TreeStatType,
51  typename TreeMatType> class TreeType>
52 FastMKS<KernelType, MatType, TreeType>::FastMKS(
53  const MatType& referenceSet,
54  const bool singleMode,
55  const bool naive) :
56  referenceSet(&referenceSet),
57  referenceTree(NULL),
58  treeOwner(true),
59  setOwner(false),
60  singleMode(singleMode),
61  naive(naive)
62 {
63  Timer::Start("tree_building");
64  if (!naive)
65  referenceTree = new Tree(referenceSet);
66  Timer::Stop("tree_building");
67 }
68 
69 // Instantiated kernel.
70 template<typename KernelType,
71  typename MatType,
72  template<typename TreeMetricType,
73  typename TreeStatType,
74  typename TreeMatType> class TreeType>
75 FastMKS<KernelType, MatType, TreeType>::FastMKS(const MatType& referenceSet,
76  KernelType& kernel,
77  const bool singleMode,
78  const bool naive) :
79  referenceSet(&referenceSet),
80  referenceTree(NULL),
81  treeOwner(true),
82  setOwner(false),
83  singleMode(singleMode),
84  naive(naive),
85  metric(kernel)
86 {
87  Timer::Start("tree_building");
88 
89  // If necessary, the reference tree should be built. There is no query tree.
90  if (!naive)
91  referenceTree = new Tree(referenceSet, metric);
92 
93  Timer::Stop("tree_building");
94 }
95 
96 // No instantiated kernel.
97 template<typename KernelType,
98  typename MatType,
99  template<typename TreeMetricType,
100  typename TreeStatType,
101  typename TreeMatType> class TreeType>
102 FastMKS<KernelType, MatType, TreeType>::FastMKS(
103  MatType&& referenceSet,
104  const bool singleMode,
105  const bool naive) :
106  referenceSet(naive ? new MatType(std::move(referenceSet)) : NULL),
107  referenceTree(NULL),
108  treeOwner(true),
109  setOwner(naive),
110  singleMode(singleMode),
111  naive(naive)
112 {
113  Timer::Start("tree_building");
114  if (!naive)
115  {
116  referenceTree = new Tree(std::move(referenceSet));
117  referenceSet = &referenceTree->Dataset();
118  }
119  Timer::Stop("tree_building");
120 }
121 
122 // Instantiated kernel.
123 template<typename KernelType,
124  typename MatType,
125  template<typename TreeMetricType,
126  typename TreeStatType,
127  typename TreeMatType> class TreeType>
128 FastMKS<KernelType, MatType, TreeType>::FastMKS(MatType&& referenceSet,
129  KernelType& kernel,
130  const bool singleMode,
131  const bool naive) :
132  referenceSet(naive ? new MatType(std::move(referenceSet)) : NULL),
133  referenceTree(NULL),
134  treeOwner(true),
135  setOwner(naive),
136  singleMode(singleMode),
137  naive(naive),
138  metric(kernel)
139 {
140  Timer::Start("tree_building");
141 
142  // If necessary, the reference tree should be built. There is no query tree.
143  if (!naive)
144  {
145  referenceTree = new Tree(referenceSet, metric);
146  referenceSet = &referenceTree->Dataset();
147  }
148 
149  Timer::Stop("tree_building");
150 }
151 
152 // One dataset, pre-built tree.
153 template<typename KernelType,
154  typename MatType,
155  template<typename TreeMetricType,
156  typename TreeStatType,
157  typename TreeMatType> class TreeType>
158 FastMKS<KernelType, MatType, TreeType>::FastMKS(Tree* referenceTree,
159  const bool singleMode) :
160  referenceSet(&referenceTree->Dataset()),
161  referenceTree(referenceTree),
162  treeOwner(false),
163  setOwner(false),
164  singleMode(singleMode),
165  naive(false),
166  metric(referenceTree->Metric())
167 {
168  // Nothing to do.
169 }
170 
171 template<typename KernelType,
172  typename MatType,
173  template<typename TreeMetricType,
174  typename TreeStatType,
175  typename TreeMatType> class TreeType>
176 FastMKS<KernelType, MatType, TreeType>::FastMKS(const FastMKS& other) :
177  referenceSet(NULL),
178  referenceTree(other.referenceTree ? new Tree(*other.referenceTree) : NULL),
179  treeOwner(other.referenceTree != NULL),
180  setOwner(other.referenceTree == NULL),
181  singleMode(other.singleMode),
182  naive(other.naive),
183  metric(other.metric)
184 {
185  // Set reference set correctly.
186  if (referenceTree)
187  referenceSet = &referenceTree->Dataset();
188  else
189  referenceSet = new MatType(*other.referenceSet);
190 }
191 
192 template<typename KernelType,
193  typename MatType,
194  template<typename TreeMetricType,
195  typename TreeStatType,
196  typename TreeMatType> class TreeType>
197 FastMKS<KernelType, MatType, TreeType>::FastMKS(FastMKS&& other) :
198  referenceSet(other.referenceSet),
199  referenceTree(other.referenceTree),
200  treeOwner(other.treeOwner),
201  setOwner(other.setOwner),
202  singleMode(other.singleMode),
203  naive(other.naive),
204  metric(std::move(other.metric))
205 {
206  // Clear information from the other.
207  other.referenceSet = NULL;
208  other.referenceTree = NULL;
209  other.treeOwner = false;
210  other.setOwner = false;
211  other.singleMode = false;
212  other.naive = false;
213 }
214 
215 template<typename KernelType,
216  typename MatType,
217  template<typename TreeMetricType,
218  typename TreeStatType,
219  typename TreeMatType> class TreeType>
220 FastMKS<KernelType, MatType, TreeType>&
222 {
223  if (this == &other)
224  return *this;
225 
226  // Clear anything we currently have.
227  if (treeOwner)
228  delete referenceTree;
229  if (setOwner)
230  delete referenceSet;
231 
232  referenceTree = NULL;
233  referenceSet = NULL;
234 
235  if (other.referenceTree)
236  {
237  referenceTree = new Tree(*other.referenceTree);
238  referenceSet = &referenceTree->Dataset();
239  treeOwner = true;
240  setOwner = false;
241  }
242  else
243  {
244  referenceSet = new MatType(*other.referenceSet);
245  treeOwner = false;
246  setOwner = true;
247  }
248 
249  singleMode = other.singleMode;
250  naive = other.naive;
251 }
252 
253 template<typename KernelType,
254  typename MatType,
255  template<typename TreeMetricType,
256  typename TreeStatType,
257  typename TreeMatType> class TreeType>
258 FastMKS<KernelType, MatType, TreeType>&
260 {
261  if (this != &other)
262  {
263  referenceSet = other.referenceSet;
264  referenceTree = other.referenceTree;
265  treeOwner = other.treeOwner;
266  setOwner = other.setOwner;
267  singleMode = other.singleMode;
268  naive = other.naive;
269  metric = std::move(other.metric);
270 
271  // Clear information from the other.
272  other.referenceSet = nullptr;
273  other.referenceTree = nullptr;
274  other.treeOwner = false;
275  other.setOwner = false;
276  other.singleMode = false;
277  other.naive = false;
278  }
279  return *this;
280 }
281 
282 template<typename KernelType,
283  typename MatType,
284  template<typename TreeMetricType,
285  typename TreeStatType,
286  typename TreeMatType> class TreeType>
287 FastMKS<KernelType, MatType, TreeType>::~FastMKS()
288 {
289  // If we created the trees, we must delete them.
290  if (treeOwner && referenceTree)
291  delete referenceTree;
292  if (setOwner)
293  delete referenceSet;
294 }
295 
296 template<typename KernelType,
297  typename MatType,
298  template<typename TreeMetricType,
299  typename TreeStatType,
300  typename TreeMatType> class TreeType>
301 void FastMKS<KernelType, MatType, TreeType>::Train(const MatType& referenceSet)
302 {
303  if (setOwner)
304  delete this->referenceSet;
305 
306  this->referenceSet = &referenceSet;
307  this->setOwner = false;
308 
309  if (!naive)
310  {
311  if (treeOwner && referenceTree)
312  delete referenceTree;
313  referenceTree = new Tree(referenceSet, metric);
314  treeOwner = true;
315  }
316 }
317 
318 template<typename KernelType,
319  typename MatType,
320  template<typename TreeMetricType,
321  typename TreeStatType,
322  typename TreeMatType> class TreeType>
323 void FastMKS<KernelType, MatType, TreeType>::Train(const MatType& referenceSet,
324  KernelType& kernel)
325 {
326  if (setOwner)
327  delete this->referenceSet;
328 
329  this->referenceSet = &referenceSet;
330  this->metric = metric::IPMetric<KernelType>(kernel);
331  this->setOwner = false;
332 
333  if (!naive)
334  {
335  if (treeOwner && referenceTree)
336  delete referenceTree;
337  referenceTree = new Tree(referenceSet, metric);
338  treeOwner = true;
339  }
340 }
341 
342 template<typename KernelType,
343  typename MatType,
344  template<typename TreeMetricType,
345  typename TreeStatType,
346  typename TreeMatType> class TreeType>
347 void FastMKS<KernelType, MatType, TreeType>::Train(MatType&& referenceSet)
348 {
349  if (setOwner)
350  delete this->referenceSet;
351 
352  if (!naive)
353  {
354  if (treeOwner && referenceTree)
355  delete referenceTree;
356  referenceTree = new Tree(std::move(referenceSet), metric);
357  referenceSet = referenceTree->Dataset();
358  treeOwner = true;
359  setOwner = false;
360  }
361  else
362  {
363  this->referenceSet = new MatType(std::move(referenceSet));
364  this->setOwner = true;
365  }
366 }
367 
368 template<typename KernelType,
369  typename MatType,
370  template<typename TreeMetricType,
371  typename TreeStatType,
372  typename TreeMatType> class TreeType>
373 void FastMKS<KernelType, MatType, TreeType>::Train(MatType&& referenceSet,
374  KernelType& kernel)
375 {
376  if (setOwner)
377  delete this->referenceSet;
378 
379  this->metric = metric::IPMetric<KernelType>(kernel);
380 
381  if (!naive)
382  {
383  if (treeOwner && referenceTree)
384  delete referenceTree;
385  referenceTree = new Tree(std::move(referenceSet), metric);
386  treeOwner = true;
387  setOwner = false;
388  }
389  else
390  {
391  this->referenceSet = new MatType(std::move(referenceSet));
392  this->setOwner = true;
393  }
394 }
395 
396 template<typename KernelType,
397  typename MatType,
398  template<typename TreeMetricType,
399  typename TreeStatType,
400  typename TreeMatType> class TreeType>
401 void FastMKS<KernelType, MatType, TreeType>::Train(Tree* tree)
402 {
403  if (naive)
404  throw std::invalid_argument("cannot call FastMKS::Train() with a tree when "
405  "in naive search mode");
406 
407  if (setOwner)
408  delete this->referenceSet;
409 
410  this->referenceSet = &tree->Dataset();
411  this->metric = metric::IPMetric<KernelType>(tree->Metric().Kernel());
412  this->setOwner = false;
413 
414  if (treeOwner && referenceTree)
415  delete referenceTree;
416 
417  this->referenceTree = tree;
418  this->treeOwner = true;
419 }
420 
421 template<typename KernelType,
422  typename MatType,
423  template<typename TreeMetricType,
424  typename TreeStatType,
425  typename TreeMatType> class TreeType>
426 void FastMKS<KernelType, MatType, TreeType>::Search(
427  const MatType& querySet,
428  const size_t k,
429  arma::Mat<size_t>& indices,
430  arma::mat& kernels)
431 {
432  if (k > referenceSet->n_cols)
433  {
434  std::stringstream ss;
435  ss << "requested value of k (" << k << ") is greater than the number of "
436  << "points in the reference set (" << referenceSet->n_cols << ")";
437  throw std::invalid_argument(ss.str());
438  }
439 
440  if (querySet.n_rows != referenceSet->n_rows)
441  {
442  std::stringstream ss;
443  ss << "The number of dimensions in the query set (" << querySet.n_rows
444  << ") must be equal to the number of dimensions in the reference set ("
445  << referenceSet->n_rows << ")!";
446  throw std::invalid_argument(ss.str());
447  }
448 
449  Timer::Start("computing_products");
450 
451  // No remapping will be necessary because we are using the cover tree.
452  indices.set_size(k, querySet.n_cols);
453  kernels.set_size(k, querySet.n_cols);
454 
455  // Naive implementation.
456  if (naive)
457  {
458  // Simple double loop. Stupid, slow, but a good benchmark.
459  for (size_t q = 0; q < querySet.n_cols; ++q)
460  {
461  const Candidate def = std::make_pair(-DBL_MAX, size_t() - 1);
462  std::vector<Candidate> cList(k, def);
463  CandidateList pqueue(CandidateCmp(), std::move(cList));
464 
465  for (size_t r = 0; r < referenceSet->n_cols; ++r)
466  {
467  const double eval = metric.Kernel().Evaluate(querySet.col(q),
468  referenceSet->col(r));
469 
470  if (eval > pqueue.top().first)
471  {
472  Candidate c = std::make_pair(eval, r);
473  pqueue.pop();
474  pqueue.push(c);
475  }
476  }
477 
478  for (size_t j = 1; j <= k; ++j)
479  {
480  indices(k - j, q) = pqueue.top().second;
481  kernels(k - j, q) = pqueue.top().first;
482  pqueue.pop();
483  }
484  }
485 
486  Timer::Stop("computing_products");
487 
488  return;
489  }
490 
491  // Single-tree implementation.
492  if (singleMode)
493  {
494  // Create rules object (this will store the results). This constructor
495  // precalculates each self-kernel value.
496  typedef FastMKSRules<KernelType, Tree> RuleType;
497  RuleType rules(*referenceSet, querySet, k, metric.Kernel());
498 
499  typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
500 
501  for (size_t i = 0; i < querySet.n_cols; ++i)
502  traverser.Traverse(i, *referenceTree);
503 
504  Log::Info << rules.BaseCases() << " base cases." << std::endl;
505  Log::Info << rules.Scores() << " scores." << std::endl;
506 
507  rules.GetResults(indices, kernels);
508 
509  Timer::Stop("computing_products");
510  return;
511  }
512 
513  // Dual-tree implementation. First, we need to build the query tree. We are
514  // assuming it doesn't map anything...
515  Timer::Stop("computing_products");
516  Timer::Start("tree_building");
517  Tree queryTree(querySet);
518  Timer::Stop("tree_building");
519 
520  Search(&queryTree, k, indices, kernels);
521 }
522 
523 template<typename KernelType,
524  typename MatType,
525  template<typename TreeMetricType,
526  typename TreeStatType,
527  typename TreeMatType> class TreeType>
528 void FastMKS<KernelType, MatType, TreeType>::Search(
529  Tree* queryTree,
530  const size_t k,
531  arma::Mat<size_t>& indices,
532  arma::mat& kernels)
533 {
534  if (k > referenceSet->n_cols)
535  {
536  std::stringstream ss;
537  ss << "requested value of k (" << k << ") is greater than the number of "
538  << "points in the reference set (" << referenceSet->n_cols << ")";
539  throw std::invalid_argument(ss.str());
540  }
541  if (queryTree->Dataset().n_rows != referenceSet->n_rows)
542  {
543  std::stringstream ss;
544  ss << "The number of dimensions in the query set ("
545  << queryTree->Dataset().n_rows << ") must be equal to the number of "
546  << "dimensions in the reference set (" << referenceSet->n_rows << ")!";
547  throw std::invalid_argument(ss.str());
548  }
549 
550  // If either naive mode or single mode is specified, this must fail.
551  if (naive || singleMode)
552  {
553  throw std::invalid_argument("can't call Search() with a query tree when "
554  "single mode or naive search is enabled");
555  }
556 
557  // No remapping will be necessary because we are using the cover tree.
558  indices.set_size(k, queryTree->Dataset().n_cols);
559  kernels.set_size(k, queryTree->Dataset().n_cols);
560 
561  Timer::Start("computing_products");
562  typedef FastMKSRules<KernelType, Tree> RuleType;
563  RuleType rules(*referenceSet, queryTree->Dataset(), k, metric.Kernel());
564 
565  typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
566 
567  traverser.Traverse(*queryTree, *referenceTree);
568 
569  Log::Info << rules.BaseCases() << " base cases." << std::endl;
570  Log::Info << rules.Scores() << " scores." << std::endl;
571 
572  rules.GetResults(indices, kernels);
573 
574  Timer::Stop("computing_products");
575 }
576 
577 template<typename KernelType,
578  typename MatType,
579  template<typename TreeMetricType,
580  typename TreeStatType,
581  typename TreeMatType> class TreeType>
582 void FastMKS<KernelType, MatType, TreeType>::Search(
583  const size_t k,
584  arma::Mat<size_t>& indices,
585  arma::mat& kernels)
586 {
587  // No remapping will be necessary because we are using the cover tree.
588  Timer::Start("computing_products");
589  indices.set_size(k, referenceSet->n_cols);
590  kernels.set_size(k, referenceSet->n_cols);
591 
592  // Naive implementation.
593  if (naive)
594  {
595  // Simple double loop. Stupid, slow, but a good benchmark.
596  for (size_t q = 0; q < referenceSet->n_cols; ++q)
597  {
598  const Candidate def = std::make_pair(-DBL_MAX, size_t() - 1);
599  std::vector<Candidate> cList(k, def);
600  CandidateList pqueue(CandidateCmp(), std::move(cList));
601 
602  for (size_t r = 0; r < referenceSet->n_cols; ++r)
603  {
604  if (q == r)
605  continue; // Don't return the point as its own candidate.
606 
607  const double eval = metric.Kernel().Evaluate(referenceSet->col(q),
608  referenceSet->col(r));
609 
610  if (eval > pqueue.top().first)
611  {
612  Candidate c = std::make_pair(eval, r);
613  pqueue.pop();
614  pqueue.push(c);
615  }
616  }
617 
618  for (size_t j = 1; j <= k; ++j)
619  {
620  indices(k - j, q) = pqueue.top().second;
621  kernels(k - j, q) = pqueue.top().first;
622  pqueue.pop();
623  }
624  }
625 
626  Timer::Stop("computing_products");
627 
628  return;
629  }
630 
631  // Single-tree implementation.
632  if (singleMode)
633  {
634  // Create rules object (this will store the results). This constructor
635  // precalculates each self-kernel value.
636  typedef FastMKSRules<KernelType, Tree> RuleType;
637  RuleType rules(*referenceSet, *referenceSet, k, metric.Kernel());
638 
639  typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
640 
641  for (size_t i = 0; i < referenceSet->n_cols; ++i)
642  traverser.Traverse(i, *referenceTree);
643 
644  // Save the number of pruned nodes.
645  const size_t numPrunes = traverser.NumPrunes();
646 
647  Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
648 
649  Log::Info << rules.BaseCases() << " base cases." << std::endl;
650  Log::Info << rules.Scores() << " scores." << std::endl;
651 
652  rules.GetResults(indices, kernels);
653 
654  Timer::Stop("computing_products");
655  return;
656  }
657 
658  // Dual-tree implementation.
659  Timer::Stop("computing_products");
660 
661  Search(referenceTree, k, indices, kernels);
662 }
663 
665 template<typename KernelType,
666  typename MatType,
667  template<typename TreeMetricType,
668  typename TreeStatType,
669  typename TreeMatType> class TreeType>
670 template<typename Archive>
671 void FastMKS<KernelType, MatType, TreeType>::serialize(
672  Archive& ar, const uint32_t /* version */)
673 {
674  // Serialize preferences for search.
675  ar(CEREAL_NVP(naive));
676  ar(CEREAL_NVP(singleMode));
677 
678  // If we are doing naive search, serialize the dataset. Otherwise we
679  // serialize the tree.
680  if (naive)
681  {
682  if (cereal::is_loading<Archive>())
683  {
684  if (setOwner && referenceSet)
685  delete referenceSet;
686 
687  setOwner = true;
688  }
689 
690  ar(CEREAL_POINTER(const_cast<MatType*&>(referenceSet)));
691  ar(CEREAL_NVP(metric));
692  }
693  else
694  {
695  // Delete the current reference tree, if necessary.
696  if (cereal::is_loading<Archive>())
697  {
698  if (treeOwner && referenceTree)
699  delete referenceTree;
700 
701  treeOwner = true;
702  }
703 
704  ar(CEREAL_POINTER(referenceTree));
705 
706  if (cereal::is_loading<Archive>())
707  {
708  if (setOwner && referenceSet)
709  delete referenceSet;
710 
711  referenceSet = &referenceTree->Dataset();
712  metric = metric::IPMetric<KernelType>(referenceTree->Metric().Kernel());
713  setOwner = false;
714  }
715  }
716 }
717 
718 } // namespace fastmks
719 } // namespace mlpack
720 
721 #endif
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The inner product metric, IPMetric, takes a given Mercer kernel (KernelType), and when Evaluate() is ...
Definition: ip_metric.hpp:32
FastMKS & operator=(const FastMKS &other)
Assign this model to be a copy of the given model.
Definition: fastmks_impl.hpp:221
Definition: hmm_train_main.cpp:300
The FastMKSRules class is a template helper class used by FastMKS class when performing exact max-ker...
Definition: fastmks_rules.hpp:34
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_wrapper.hpp:96
An implementation of fast exact max-kernel search.
Definition: fastmks.hpp:63
A cover tree is a tree specifically designed to speed up nearest-neighbor computation in high-dimensi...
Definition: cover_tree.hpp:99