// NOTE(review): this file appears to be extracted from a rendered source
// listing. Original line numbers are fused into the code (e.g. "13", "29"),
// and many lines are absent: function names, braces, some parameters and
// branch conditions. The text is therefore not compilable as-is. Comments in
// this excerpt describe only what the visible lines show.
//
// aux::BuildTree() overload that hands an index-mapping vector to the tree
// constructor. The enable_if condition that selects this overload is on
// lines not shown. Presumably it picks tree types that rearrange the dataset
// (see TreeTraits); confirm against the complete header.
//
// @param oldFromNew Passed to the TreeType constructor. Callers in this file
//     later index it with tree-order point indices to recover original
//     indices, so the constructor is presumably expected to fill it.
// @return A tree allocated with `new`. Owners in this file release it with
//     `delete` when their treeOwner flag is set.
13 #ifndef MLPACK_METHODS_RANN_RA_SEARCH_IMPL_HPP 14 #define MLPACK_METHODS_RANN_RA_SEARCH_IMPL_HPP 26 template<
typename TreeType,
typename MatType>
29 std::vector<size_t>& oldFromNew,
30 typename std::enable_if<
// std::forward preserves the value category of `dataset` (declared on a line
// not shown). A caller passing std::move(data) therefore hands the tree an
// rvalue it may move from instead of copying.
33 return new TreeType(std::forward<MatType>(dataset), oldFromNew);
// aux::BuildTree() overload for the other enable_if branch (condition not
// shown). The mapping-vector parameter is unnamed and never used, and the
// tree is built from the dataset alone. Any mapping the caller passes is
// left exactly as it was. Code that later reads the mapping must therefore
// skip trees built through this overload; presumably the conditions around
// the remapping code (not shown) enforce this.
//
// @return A tree allocated with `new`.
37 template<
typename TreeType,
typename MatType>
40 const std::vector<size_t>& ,
41 const typename std::enable_if<
44 return new TreeType(std::forward<MatType>(dataset));
// Constructor taking the reference set by value.
//
// `naive` below is presumably the constructor parameter; its declaration is
// on a line not shown.
// - If it is set, no tree is built: referenceTree is NULL, and the
//   reference data is moved into a newly allocated matrix.
// - Otherwise the data is moved into a tree built by aux::BuildTree(), which
//   also fills oldFromNewReferences. referenceSet then points at the tree's
//   own dataset, so no second copy is kept.
//
// NOTE(review): referenceSet's initializer dereferences referenceTree. This
// relies on referenceTree being declared before referenceSet in the class,
// because members are initialized in declaration order, not in the order
// written here. Confirm in the class definition.
50 template<
typename SortPolicy,
53 template<
typename TreeMetricType,
54 typename TreeStatType,
55 typename TreeMatType>
class TreeType>
56 RASearch<SortPolicy, MetricType, MatType, TreeType>::
57 RASearch(MatType referenceSetIn,
59 const bool singleMode,
62 const bool sampleAtLeaves,
63 const bool firstLeafExact,
64 const size_t singleSampleLimit,
65 const MetricType metric) :
// Both ternaries test the same `naive`, so referenceSetIn is moved from in
// exactly one of them.
66 referenceTree(naive ? NULL : aux::BuildTree<Tree>(
67 std::move(referenceSetIn), oldFromNewReferences)),
68 referenceSet(naive ? new MatType(std::move(referenceSetIn)) :
69 &referenceTree->Dataset()),
// Single-tree mode needs a tree, so it is forced off in naive mode.
73 singleMode(!naive && singleMode),
76 sampleAtLeaves(sampleAtLeaves),
77 firstLeafExact(firstLeafExact),
78 singleSampleLimit(singleSampleLimit),
// Constructor taking an already-built reference tree.
//
// The tree pointer is stored as given, and referenceSet points at the tree's
// dataset; the data itself is not copied. Ownership flags and the remaining
// members are initialized on lines not shown. Presumably the caller keeps
// ownership of the tree; confirm via the treeOwner initializer.
//
// Inside the initializer list, `referenceTree` names the constructor
// parameter, which shadows the member. The referenceSet initializer
// therefore reads the caller's tree directly.
//
// NOTE(review): singleMode is stored as given here. The constructor that
// takes a dataset instead masks it with `!naive`.
85 template<
typename SortPolicy,
88 template<
typename TreeMetricType,
89 typename TreeStatType,
90 typename TreeMatType>
class TreeType>
91 RASearch<SortPolicy, MetricType, MatType, TreeType>::
92 RASearch(Tree* referenceTree,
93 const bool singleMode,
96 const bool sampleAtLeaves,
97 const bool firstLeafExact,
98 const size_t singleSampleLimit,
99 const MetricType metric) :
100 referenceTree(referenceTree),
101 referenceSet(&referenceTree->Dataset()),
105 singleMode(singleMode),
108 sampleAtLeaves(sampleAtLeaves),
109 firstLeafExact(firstLeafExact),
110 singleSampleLimit(singleSampleLimit),
// Constructor that starts with an empty reference set (no data yet).
//
// referenceSet is a newly allocated, empty matrix. In the body, a tree is
// built over that matrix. The condition guarding the build (presumably
// `!naive`) and the other members are on lines not shown.
116 template<
typename SortPolicy,
119 template<
typename TreeMetricType,
120 typename TreeStatType,
121 typename TreeMatType>
class TreeType>
122 RASearch<SortPolicy, MetricType, MatType, TreeType>::
123 RASearch(const bool naive,
124 const bool singleMode,
127 const bool sampleAtLeaves,
128 const bool firstLeafExact,
129 const size_t singleSampleLimit,
130 const MetricType metric) :
132 referenceSet(new MatType()),
136 singleMode(singleMode),
139 sampleAtLeaves(sampleAtLeaves),
140 firstLeafExact(firstLeafExact),
141 singleSampleLimit(singleSampleLimit),
// *referenceSet is passed as an lvalue, so BuildTree() forwards an lvalue to
// the tree constructor. The constructors taking data by value instead move it.
147 referenceTree = aux::BuildTree<Tree>(*referenceSet, oldFromNewReferences);
// This appears to be the destructor; the `~RASearch()` declarator line is not
// shown. The reference tree is deleted only if this object owns it
// (treeOwner). Any cleanup of referenceSet is on lines not shown.
//
// `delete` on a null pointer is a no-op, so the `&& referenceTree` test is
// defensive rather than required.
156 template<
typename SortPolicy,
159 template<
typename TreeMetricType,
160 typename TreeStatType,
161 typename TreeMatType>
class TreeType>
162 RASearch<SortPolicy, MetricType, MatType, TreeType>::
165 if (treeOwner && referenceTree)
166 delete referenceTree;
// Train() on a new reference set, taken by value.
//
// Steps:
// 1. Release the previous tree, if owned.
// 2. Either build a new tree by moving the data into aux::BuildTree(), or
//    store the data in a newly allocated matrix. The condition choosing
//    between these (presumably `naive`) is on lines not shown.
// 3. Release the previous reference set, if owned.
//
// The parameter `referenceSet` shadows the member of the same name, which is
// why the member is always written as `this->referenceSet`.
//
// NOTE(review): both the BuildTree() call and the `new MatType(...)` call
// move from the parameter. This is correct only if those two statements sit
// on mutually exclusive branches; confirm in the full source.
172 template<
typename SortPolicy,
175 template<
typename TreeMetricType,
176 typename TreeStatType,
177 typename TreeMatType>
class TreeType>
178 void RASearch<SortPolicy, MetricType, MatType, TreeType>::
Train(
179 MatType referenceSet)
182 if (treeOwner && referenceTree)
183 delete referenceTree;
188 referenceTree = aux::BuildTree<Tree>(std::move(referenceSet),
189 oldFromNewReferences);
// Release the old reference set only if this object allocated it.
198 if (setOwner && this->referenceSet)
199 delete this->referenceSet;
// Tree mode: view the tree's copy of the data, whose points may have been
// reordered (hence oldFromNewReferences).
203 this->referenceSet = &referenceTree->Dataset();
208 this->referenceSet =
new MatType(std::move(referenceSet));
// Train() on a caller-supplied reference tree. The `Tree*` parameter,
// presumably named referenceTree, is declared on a line not shown.
//
// Throws std::invalid_argument when naive search is requested. The guarding
// condition is not shown, but the message names naive search. Otherwise the
// previously owned tree and reference set are released, the given tree is
// adopted, and referenceSet points at its dataset.
//
// NOTE(review): the guard before deleting the old tree tests `referenceTree`,
// which here presumably names the parameter, yet deletes the member
// `this->referenceTree`. If a caller passes the very tree this object already
// owns, that tree is deleted and then adopted, so `referenceTree->Dataset()`
// below reads freed memory. Testing `this->referenceTree` and skipping the
// delete when `this->referenceTree == referenceTree` would be safer.
//
// NOTE(review): the `225` between the two string literals of the error
// message looks like a fused line number from extraction. As written it is
// not valid C++.
214 template<
typename SortPolicy,
217 template<
typename TreeMetricType,
218 typename TreeStatType,
219 typename TreeMatType>
class TreeType>
220 void RASearch<SortPolicy, MetricType, MatType, TreeType>::
Train(
224 throw std::invalid_argument(
"cannot train on given reference tree when " 225 "naive search (without trees) is desired");
227 if (treeOwner && referenceTree)
228 delete this->referenceTree;
229 if (setOwner && referenceSet)
230 delete this->referenceSet;
232 this->referenceTree = referenceTree;
233 this->referenceSet = &referenceTree->Dataset();
// Search() for the k rank-approximate neighbors of each point in querySet.
// The `k` parameter is declared on a line not shown.
//
// Results are k x querySet.n_cols matrices: column i of `neighbors` and
// `distances` belongs to query point i. Three strategies are visible below:
// - What appears to be the naive branch: BaseCase() over a set of sampled
//   reference points.
// - Single-tree: one traversal of the reference tree per query point.
// - Dual-tree: a query tree is built and traversed against the reference
//   tree.
// Afterwards, indices are mapped back through oldFromNewQueries and/or
// oldFromNewReferences wherever a tree may have reordered points.
//
// @throws std::invalid_argument if k exceeds the number of reference points.
242 template<
typename SortPolicy,
245 template<
typename TreeMetricType,
246 typename TreeStatType,
247 typename TreeMatType>
class TreeType>
248 void RASearch<SortPolicy, MetricType, MatType, TreeType>::
249 Search(const MatType& querySet,
251 arma::Mat<size_t>& neighbors,
252 arma::mat& distances)
254 if (k > referenceSet->n_cols)
256 std::stringstream ss;
257 ss <<
"requested value of k (" << k <<
") is greater than the number of " 258 <<
"points in the reference set (" << referenceSet->n_cols <<
")";
259 throw std::invalid_argument(ss.str());
265 std::vector<size_t> oldFromNewQueries;
// By default, results go straight into the caller's matrices. Some modes
// redirect them to heap temporaries; their deletion is on lines not shown.
271 arma::Mat<size_t>* neighborPtr = &neighbors;
272 arma::mat* distancePtr = &distances;
// Dual-tree mode: both outputs come back in query-tree order and must be
// unpermuted afterwards.
278 if (!singleMode && !naive)
280 distancePtr =
new arma::mat;
281 neighborPtr =
new arma::Mat<size_t>;
// Another mode (condition not shown) needs a temporary only for neighbor
// indices, which are remapped at the end.
285 neighborPtr =
new arma::Mat<size_t>;
289 neighborPtr->set_size(k, querySet.n_cols);
290 distancePtr->set_size(k, querySet.n_cols);
296 RuleType rules(*referenceSet, querySet, k, metric, tau, alpha, naive,
297 sampleAtLeaves, firstLeafExact, singleSampleLimit,
false);
// Presumably filled on a line not shown (e.g. by ObtainDistinctSamples());
// only these reference columns are compared below.
303 arma::uvec distinctSamples;
309 for (
size_t i = 0; i < querySet.n_cols; ++i)
310 for (
size_t j = 0; j < distinctSamples.n_elem; ++j)
311 rules.BaseCase(i, (
size_t) distinctSamples[j]);
313 rules.GetResults(*neighborPtr, *distancePtr);
317 RuleType rules(*referenceSet, querySet, k, metric, tau, alpha, naive,
318 sampleAtLeaves, firstLeafExact, singleSampleLimit,
false);
322 if (!referenceTree->IsLeaf())
324 Log::Info <<
"Performing single-tree traversal..." << std::endl;
327 typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
330 for (
size_t i = 0; i < querySet.n_cols; ++i)
331 traverser.Traverse(i, *referenceTree);
333 Log::Info <<
"Single-tree traversal complete." << std::endl;
// NOTE(review): if NumDistComputations() returns an integer type, this
// logged average is truncated, and an empty querySet divides by zero. The
// "335"/"339" tokens below are fused line numbers from extraction.
334 Log::Info <<
"Average number of distance calculations per query point: " 335 << (rules.NumDistComputations() / querySet.n_cols) <<
"." 339 rules.GetResults(*neighborPtr, *distancePtr);
343 Log::Info <<
"Performing dual-tree traversal..." << std::endl;
// The const_cast strips const from the caller's query set so it can be
// handed to BuildTree(); the mapping argument is on a line not shown.
// NOTE(review): this is safe only if the tree constructor does not modify an
// lvalue dataset (e.g. it copies it). Confirm for every supported TreeType.
348 Tree* queryTree = aux::BuildTree<Tree>(
const_cast<MatType&
>(querySet),
// The rules read the query tree's dataset, whose column order may differ
// from querySet's.
353 RuleType rules(*referenceSet, queryTree->Dataset(), k, metric, tau, alpha,
354 naive, sampleAtLeaves, firstLeafExact, singleSampleLimit,
false);
355 typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
357 Log::Info <<
"Query statistic pre-search: " 358 << queryTree->Stat().NumSamplesMade() << std::endl;
360 traverser.Traverse(*queryTree, *referenceTree);
362 Log::Info <<
"Dual-tree traversal complete." << std::endl;
363 Log::Info <<
"Average number of distance calculations per query point: " 364 << (rules.NumDistComputations() / querySet.n_cols) <<
"." << std::endl;
366 rules.GetResults(*neighborPtr, *distancePtr);
// Map results back to the caller's orderings. Dual-tree with a reference tree
// this object built: unpermute both query columns and reference indices.
376 if (!singleMode && !naive && treeOwner)
379 neighbors.set_size(k, querySet.n_cols);
380 distances.set_size(k, querySet.n_cols);
382 for (
size_t i = 0; i < distances.n_cols; ++i)
385 distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
388 for (
size_t j = 0; j < distances.n_rows; ++j)
390 neighbors(j, oldFromNewQueries[i]) =
391 oldFromNewReferences[(*neighborPtr)(j, i)];
// Dual-tree with a caller-supplied reference tree: only query columns are
// unpermuted. Reference indices are left as the tree reports them.
399 else if (!singleMode && !naive)
402 neighbors.set_size(k, querySet.n_cols);
403 distances.set_size(k, querySet.n_cols);
405 for (
size_t i = 0; i < distances.n_cols; ++i)
408 const size_t queryMapping = oldFromNewQueries[i];
409 distances.col(queryMapping) = distancePtr->col(i);
410 neighbors.col(queryMapping) = neighborPtr->col(i);
// Remaining case (condition not shown): query columns already match
// querySet, so only reference indices are remapped.
420 neighbors.set_size(k, querySet.n_cols);
423 for (
size_t i = 0; i < neighbors.n_cols; ++i)
424 for (
size_t j = 0; j < neighbors.n_rows; ++j)
425 neighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
// Search() with a caller-supplied query tree. This overload is dual-tree only.
// The query-tree and `k` parameters are declared on lines not shown.
//
// Results are k x n matrices, where n is the number of points in the query
// tree's dataset. No query unpermutation is visible here, so columns follow
// the query tree's point order. Reference indices are mapped back through
// oldFromNewReferences.
//
// @throws std::invalid_argument if singleMode or naive is set.
//
// NOTE(review): the error message names NeighborSearch::Search() although
// this is RASearch::Search(), which looks like a copy-paste leftover. The
// "453" token between the literals is a fused line number from extraction.
433 template<
typename SortPolicy,
436 template<
typename TreeMetricType,
437 typename TreeStatType,
438 typename TreeMatType>
class TreeType>
439 void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
442 arma::Mat<size_t>& neighbors,
443 arma::mat& distances)
448 const MatType& querySet = queryTree->Dataset();
451 if (singleMode || naive)
452 throw std::invalid_argument(
"cannot call NeighborSearch::Search() with a " 453 "query tree when naive or singleMode are set to true");
456 arma::Mat<size_t>* neighborPtr = &neighbors;
// Condition not shown: neighbor indices go to a temporary so they can be
// remapped into `neighbors` below.
459 neighborPtr =
new arma::Mat<size_t>;
461 neighborPtr->set_size(k, querySet.n_cols);
462 distances.set_size(k, querySet.n_cols);
466 RuleType rules(*referenceSet, queryTree->Dataset(), k, metric, tau, alpha,
467 467 naive, sampleAtLeaves, firstLeafExact, singleSampleLimit,
false);
470 typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
471 traverser.Traverse(*queryTree, *referenceTree);
// Distances are written straight into the caller's matrix.
473 rules.GetResults(*neighborPtr, distances);
481 neighbors.set_size(k, querySet.n_cols);
484 for (
size_t i = 0; i < neighbors.n_cols; ++i)
485 for (
size_t j = 0; j < neighbors.n_rows; ++j)
486 neighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
// Search() that uses the reference set as its own query set ("monochromatic"
// search). The `k` parameter is declared on a line not shown.
//
// The rules receive the reference set as both the query and reference set,
// plus a final `true` argument. The bichromatic overloads pass `false` here;
// presumably the flag tells the rules the two sets are the same. Confirm in
// RASearchRules.
//
// Outputs are k x referenceSet->n_cols. Queries and references share one
// ordering here. In the final remap, both output columns and neighbor
// indices therefore go through oldFromNewReferences.
493 template<
typename SortPolicy,
496 template<
typename TreeMetricType,
497 typename TreeStatType,
498 typename TreeMatType>
class TreeType>
499 void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
501 arma::Mat<size_t>& neighbors,
502 arma::mat& distances)
506 arma::Mat<size_t>* neighborPtr = &neighbors;
507 arma::mat* distancePtr = &distances;
// Condition not shown: results are produced in tree order into temporaries.
512 distancePtr =
new arma::mat;
513 neighborPtr =
new arma::Mat<size_t>;
517 neighborPtr->set_size(k, referenceSet->n_cols);
518 distancePtr->set_size(k, referenceSet->n_cols);
522 RuleType rules(*referenceSet, *referenceSet, k, metric, tau, alpha, naive,
523 sampleAtLeaves, firstLeafExact, singleSampleLimit,
true );
// NOTE(review): distinctSamples is declared, but the visible naive loop below
// does not use it. It calls BaseCase() on every (i, j) pair, i.e. exact brute
// force, unlike the bichromatic Search(), which iterates only over
// distinctSamples. Confirm whether sampling was intended here.
531 arma::uvec distinctSamples;
536 for (
size_t i = 0; i < referenceSet->n_cols; ++i)
537 for (
size_t j = 0; j < referenceSet->n_cols; ++j)
538 rules.BaseCase(i, j);
543 typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
546 for (
size_t i = 0; i < referenceSet->n_cols; ++i)
547 traverser.Traverse(i, *referenceTree);
552 typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
554 traverser.Traverse(*referenceTree, *referenceTree);
557 rules.GetResults(*neighborPtr, *distancePtr);
564 neighbors.set_size(k, referenceSet->n_cols);
565 distances.set_size(k, referenceSet->n_cols);
567 for (
size_t i = 0; i < distances.n_cols; ++i)
570 const size_t refMapping = oldFromNewReferences[i];
571 distances.col(refMapping) = distancePtr->col(i);
574 for (
size_t j = 0; j < distances.n_rows; ++j)
575 neighbors(j, refMapping) = oldFromNewReferences[(*neighborPtr)(j, i)];
// Recursively reset the per-node search statistics of a tree. For the given
// node and every descendant:
// - Bound is set to SortPolicy::WorstDistance().
// - NumSamplesMade is set to 0.
// Recursion depth equals the depth of the subtree.
//
// The method is const with respect to this RASearch object, but it modifies
// the nodes reached through `queryNode`.
584 template<
typename SortPolicy,
587 template<
typename TreeMetricType,
588 typename TreeStatType,
589 typename TreeMatType>
class TreeType>
590 void RASearch<SortPolicy, MetricType, MatType, TreeType>::ResetQueryTree(
591 Tree* queryNode) const
593 queryNode->Stat().Bound() = SortPolicy::WorstDistance();
594 queryNode->Stat().NumSamplesMade() = 0;
596 for (
size_t i = 0; i < queryNode->NumChildren(); ++i)
597 ResetQueryTree(&queryNode->Child(i));
// Save or load this RASearch object with a cereal archive. The same code path
// runs in both directions; is_loading() selects the load-only steps.
//
// The scalar search parameters visible below are archived first. Then,
// depending on a condition not shown (presumably `naive`), one of two paths
// runs:
// - No-tree path: the metric (and presumably the reference set) is archived
//   directly.
// - Tree path: the reference tree (on lines not shown) and
//   oldFromNewReferences are archived.
// When loading, any previously owned tree or set is released first. On the
// tree path, referenceSet and metric are then taken from the loaded tree.
600 template<
typename SortPolicy,
603 template<
typename TreeMetricType,
604 typename TreeStatType,
605 typename TreeMatType>
class TreeType>
606 template<typename Archive>
607 void RASearch<SortPolicy, MetricType, MatType, TreeType>::serialize(
608 Archive& ar, const uint32_t )
611 ar(CEREAL_NVP(naive));
612 ar(CEREAL_NVP(singleMode));
615 ar(CEREAL_NVP(alpha));
616 ar(CEREAL_NVP(sampleAtLeaves));
617 ar(CEREAL_NVP(firstLeafExact));
618 ar(CEREAL_NVP(singleSampleLimit));
624 if (cereal::is_loading<Archive>())
626 if (setOwner && referenceSet)
632 ar(CEREAL_NVP(metric));
635 if (cereal::is_loading<Archive>())
637 if (treeOwner && referenceTree)
638 delete referenceTree;
// No tree on this path: drop stale tree state left from before loading.
640 referenceTree = NULL;
641 oldFromNewReferences.clear();
648 if (cereal::is_loading<Archive>())
650 if (treeOwner && referenceTree)
651 delete referenceTree;
658 ar(CEREAL_NVP(oldFromNewReferences));
662 if (cereal::is_loading<Archive>())
664 if (setOwner && referenceSet)
// The loaded tree holds the data; view it rather than keeping a second copy.
667 referenceSet = &referenceTree->Dataset();
668 metric = referenceTree->Metric();
void ObtainDistinctSamples(const size_t loInclusive, const size_t hiExclusive, const size_t maxNumSamples, arma::uvec &distinctSamples)
Obtains no more than maxNumSamples distinct samples.
Definition: random.hpp:153
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.
static size_t MinimumSamplesReqd(const size_t n, const size_t k, const double tau, const double alpha)
Compute the minimum number of samples required to guarantee the given rank-approximation and success probability.
Definition: ra_util.cpp:18
Definition: hmm_train_main.cpp:300
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
The TreeTraits class provides compile-time information on the characteristics of a given tree type.
Definition: tree_traits.hpp:77
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_wrapper.hpp:96
The RASearchRules class is a template helper class used by the RASearch class when performing rank-approximate search.
Definition: ra_search_rules.hpp:33