// Include guard for this implementation header, followed by the first
// BuildTree() overload. This overload is used for tree types that rearrange
// the dataset while building. It constructs the tree from the forwarded
// dataset and records the point permutation in oldFromNew. The tree is
// allocated with `new`; the caller is responsible for deleting it.
12 #ifndef MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_IMPL_HPP 13 #define MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_IMPL_HPP 24 template<
typename TreeType,
typename MatType>
// oldFromNew: output mapping, filled in by the tree constructor.
27 std::vector<size_t>& oldFromNew,
// Enabled only when tree::TreeTraits<TreeType>::RearrangesDataset is true.
28 const typename std::enable_if<
// std::forward keeps the dataset's value category, so an rvalue dataset can be
// moved into the tree instead of copied.
31 return new TreeType(std::forward<MatType>(dataset), oldFromNew);
// Second BuildTree() overload, for tree types that do not rearrange the
// dataset. The mapping parameter is unnamed and ignored, and the tree is built
// from the forwarded dataset alone. The caller owns the returned tree.
// NOTE(review): the enable_if condition line is not visible in this excerpt;
// presumably it is the negation of the first overload's RearrangesDataset test.
35 template<
typename TreeType,
typename MatType>
38 const std::vector<size_t>& ,
39 const typename std::enable_if<
42 return new TreeType(std::forward<MatType>(dataset));
// Constructor that takes a reference set by value.
// Naive mode: no tree is built, and the data is moved into a heap-allocated
// MatType.
// Otherwise: a tree is built from the moved data, and referenceSet points at
// the tree's dataset.
45 template<
typename MetricType,
47 template<
typename TreeMetricType,
48 typename TreeStatType,
49 typename TreeMatType>
class TreeType>
53 const bool singleMode,
54 const MetricType metric) :
// Inside these initializer expressions, the unqualified names presumably refer
// to the constructor parameters, which shadow the members. The reference-set
// argument is moved from in exactly one of the next two initializers; the
// naive flag selects which one.
55 referenceTree(naive ? NULL : BuildTree<Tree>(std::move(referenceSet),
56 oldFromNewReferences)),
// NOTE(review): the tree branch reads referenceTree. Members are initialized in
// declaration order, so this relies on referenceTree being declared before
// referenceSet in the class. Confirm against the header.
57 referenceSet(naive ? new MatType(std::move(referenceSet)) :
58 &referenceTree->Dataset()),
// Single-tree mode is enabled only when a tree exists (naive is false).
61 singleMode(!naive && singleMode),
// Constructor that searches with an already-built reference tree.
// referenceSet points at the tree's dataset. The tree pointer is dereferenced
// unconditionally below, so it must not be null.
69 template<
typename MetricType,
71 template<
typename TreeMetricType,
72 typename TreeStatType,
73 typename TreeMatType>
class TreeType>
76 const bool singleMode,
77 const MetricType metric) :
78 referenceTree(referenceTree),
79 referenceSet(&referenceTree->Dataset()),
82 singleMode(singleMode),
// Constructor without reference data.
// Naive mode: an empty heap-allocated MatType serves as the reference set.
// Otherwise: an empty tree is built in the constructor body, and referenceSet
// points at its dataset.
90 template<
typename MetricType,
92 template<
typename TreeMetricType,
93 typename TreeStatType,
94 typename TreeMatType>
class TreeType>
97 const bool singleMode,
98 const MetricType metric) :
100 referenceSet(naive ? new MatType() : NULL),
103 singleMode(singleMode),
// NOTE(review): the condition around these lines is not visible here;
// presumably they run only when naive is false. Two further observations:
// - The empty matrix is an arma::mat rather than MatType, which matters if
//   MatType is some other matrix type.
// - std::move on a temporary is redundant, since it is already an rvalue.
111 referenceTree = BuildTree<Tree>(std::move(arma::mat()),
112 oldFromNewReferences);
113 referenceSet = &referenceTree->Dataset();
// Copy constructor: deep-copies another model.
// - If other has a tree, the tree is copied and referenceSet points into the
//   copy.
// - Otherwise, other's reference matrix is copied onto the heap.
118 template<
typename MetricType,
120 template<
typename TreeMetricType,
121 typename TreeStatType,
122 typename TreeMatType>
class TreeType>
124 const RangeSearch& other) :
125 oldFromNewReferences(other.oldFromNewReferences),
126 referenceTree(other.referenceTree ? new Tree(*other.referenceTree) : NULL),
127 referenceSet(other.referenceTree ? &referenceTree->Dataset() :
128 new MatType(*other.referenceSet)),
// Pointer-to-bool conversion: this object owns a tree exactly when it made its
// own copy of one above.
129 treeOwner(other.referenceTree),
131 singleMode(other.singleMode),
132 metric(other.metric),
133 baseCases(other.baseCases),
// Move constructor: takes over other's pointers and state, then leaves other
// holding a freshly built empty tree that it owns. BuildTree() allocates with
// `new`, so this constructor can throw.
139 template<
typename MetricType,
141 template<
typename TreeMetricType,
142 typename TreeStatType,
143 typename TreeMatType>
class TreeType>
145 oldFromNewReferences(std::move(other.oldFromNewReferences)),
146 referenceTree(other.referenceTree),
147 referenceSet(other.referenceSet),
148 treeOwner(other.treeOwner),
150 singleMode(other.singleMode),
151 metric(std::move(other.metric)),
152 baseCases(other.baseCases),
// Reset other by giving it an empty tree. The moved-from
// other.oldFromNewReferences is passed along as the mapping output.
// NOTE(review): the move assignment operator instead leaves other with null
// pointers, so the two moved-from states differ.
156 other.referenceTree =
157 BuildTree<Tree>(std::move(arma::mat()), other.oldFromNewReferences);
158 other.referenceSet = &other.referenceTree->Dataset();
159 other.treeOwner =
true;
161 other.singleMode =
false;
// Copy assignment: deep-copies other, mirroring the copy constructor.
166 template<
typename MetricType,
168 template<
typename TreeMetricType,
169 typename TreeStatType,
170 typename TreeMatType>
class TreeType>
171 RangeSearch<MetricType, MatType, TreeType>&
// NOTE(review): none of the visible lines release the tree or matrix this
// object already owned before the pointers below are overwritten. Unless lines
// not shown in this excerpt do so, this leaks. The move assignment operator,
// by contrast, deletes referenceTree first. Confirm against the full source.
176 oldFromNewReferences = other.oldFromNewReferences;
177 referenceTree = other.referenceTree ?
new Tree(*other.referenceTree) :
179 referenceSet = other.referenceTree ? &referenceTree->Dataset() :
180 new MatType(*other.referenceSet);
// True exactly when a tree copy was made above.
181 treeOwner = other.referenceTree;
183 singleMode = other.singleMode;
184 metric = other.metric;
185 baseCases = other.baseCases;
186 scores = other.scores;
// Move assignment: releases this object's tree, takes over other's pointers
// and state, and leaves other empty (null pointers, owning nothing).
191 template<
typename MetricType,
193 template<
typename TreeMetricType,
194 typename TreeStatType,
195 typename TreeMatType>
class TreeType>
196 RangeSearch<MetricType, MatType, TreeType>&
// NOTE(review): the guard around this delete is not visible in this excerpt;
// presumably it checks treeOwner, as the destructor does.
203 delete referenceTree;
208 oldFromNewReferences = std::move(other.oldFromNewReferences);
209 referenceTree = other.referenceTree;
210 referenceSet = other.referenceSet;
211 treeOwner = other.treeOwner;
213 singleMode = other.singleMode;
214 metric = std::move(other.metric);
215 baseCases = other.baseCases;
216 scores = other.scores;
// Leave other without data. Search() dereferences referenceSet, so other must
// be retrained before it is searched again.
219 other.referenceTree =
nullptr;
220 other.referenceSet =
nullptr;
221 other.treeOwner =
false;
223 other.singleMode =
false;
// Destructor: frees the reference tree only if this object owns it.
230 template<
typename MetricType,
232 template<
typename TreeMetricType,
233 typename TreeStatType,
234 typename TreeMatType>
class TreeType>
235 RangeSearch<MetricType, MatType, TreeType>::~RangeSearch()
237 if (treeOwner && referenceTree)
238 delete referenceTree;
// In naive mode, referenceSet is heap-allocated by the constructors. The
// statement this guards (presumably its delete) is not visible here.
239 if (naive && referenceSet)
// Train on a new reference set (taken by value), replacing any previous tree
// or naive-mode matrix.
243 template<
typename MetricType,
245 template<
typename TreeMetricType,
246 typename TreeStatType,
247 typename TreeMatType>
class TreeType>
248 void RangeSearch<MetricType, MatType, TreeType>::
Train(
249 MatType referenceSet)
// Free the previous tree if this object owned it.
252 if (treeOwner && referenceTree)
253 delete referenceTree;
// Tree modes: build a new tree from the moved data. The enclosing condition is
// not visible in this excerpt; presumably it is !naive.
258 referenceTree = BuildTree<Tree>(std::move(referenceSet),
259 oldFromNewReferences);
// Release the previously owned naive-mode matrix before repointing.
// this->referenceSet is the member; the unqualified name is the parameter.
268 if (naive && this->referenceSet)
269 delete this->referenceSet;
273 this->referenceSet = &referenceTree->Dataset();
277 this->referenceSet =
new MatType(std::move(referenceSet));
// Train on an externally built reference tree. Throws std::invalid_argument in
// naive mode.
281 template<
typename MetricType,
283 template<
typename TreeMetricType,
284 typename TreeStatType,
285 typename TreeMatType>
class TreeType>
286 void RangeSearch<MetricType, MatType, TreeType>::
Train(
290 throw std::invalid_argument(
"cannot train on given reference tree when " 291 "naive search (without trees) is desired");
// NOTE(review): in this function the unqualified name referenceTree is the
// parameter (see the assignment to this->referenceTree below). This guard
// therefore tests the incoming tree, not the owned member it goes on to
// delete; this->referenceTree matches the apparent intent. Also, if the caller
// passes the tree this object already owns, it is deleted here and then stored
// and dereferenced, assuming no hidden line intervenes.
294 if (treeOwner && referenceTree)
296 delete this->referenceTree;
298 this->referenceTree = referenceTree;
299 this->referenceSet = &referenceTree->Dataset();
// Bichromatic range search: for each column of querySet, find the reference
// points whose distance lies within range. On return, neighbors[i] and
// distances[i] describe query column i, with indices in original reference
// order (see the unmapping at the end for the caller-supplied-tree case).
304 template<
typename MetricType,
306 template<
typename TreeMetricType,
307 typename TreeStatType,
308 typename TreeMatType>
class TreeType>
309 void RangeSearch<MetricType, MatType, TreeType>::Search(
310 const MatType& querySet,
311 const math::Range& range,
312 std::vector<std::vector<size_t>>& neighbors,
313 std::vector<std::vector<double>>& distances)
315 util::CheckSameDimensionality(querySet, *referenceSet,
316 "RangeSearch::Search()",
"query set");
// Empty reference set; the handling body is not visible in this excerpt.
319 if (referenceSet->n_cols == 0)
// Filled only when a query tree is built, to map tree order back to querySet.
325 std::vector<size_t> oldFromNewQueries;
// By default, results go straight into the caller's vectors. In some modes
// they go into temporaries and are unmapped into the caller's vectors at the end.
331 std::vector<std::vector<size_t>>* neighborPtr = &neighbors;
332 std::vector<std::vector<double>>* distancePtr = &distances;
// Dual-tree mode: both outputs are collected in tree order.
339 if (!singleMode && !naive)
341 distancePtr =
new std::vector<std::vector<double>>;
342 neighborPtr =
new std::vector<std::vector<size_t>>;
// Other modes: only neighbor indices need a temporary (this branch's condition
// is not visible in this excerpt).
348 neighborPtr =
new std::vector<std::vector<size_t>>;
352 neighborPtr->clear();
353 neighborPtr->resize(querySet.n_cols);
354 distancePtr->clear();
355 distancePtr->resize(querySet.n_cols);
// Naive search: evaluate every (query, reference) pair.
366 RuleType rules(*referenceSet, querySet, range, *neighborPtr, *distancePtr,
370 for (
size_t i = 0; i < querySet.n_cols; ++i)
371 for (
size_t j = 0; j < referenceSet->n_cols; ++j)
372 rules.BaseCase(i, j);
374 baseCases += (querySet.n_cols * referenceSet->n_cols);
// Single-tree search: one traversal of the reference tree per query point.
379 RuleType rules(*referenceSet, querySet, range, *neighborPtr, *distancePtr,
381 typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
384 for (
size_t i = 0; i < querySet.n_cols; ++i)
385 traverser.Traverse(i, *referenceTree);
387 baseCases += rules.BaseCases();
388 scores += rules.Scores();
// Dual-tree search: also build a tree over the queries.
// querySet is passed as a const lvalue, so the tree presumably holds its own
// copy of the query data.
395 Tree* queryTree = BuildTree<Tree>(querySet, oldFromNewQueries);
400 RuleType rules(*referenceSet, queryTree->Dataset(), range, *neighborPtr,
401 *distancePtr, metric);
402 typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
404 traverser.Traverse(*queryTree, *referenceTree);
406 baseCases += rules.BaseCases();
407 scores += rules.Scores();
// Dual-tree with an owned reference tree: unmap both the query index (i) and
// the reference indices of its neighbors.
418 if (!singleMode && !naive && treeOwner)
422 neighbors.resize(querySet.n_cols);
424 distances.resize(querySet.n_cols);
426 for (
size_t i = 0; i < distances.size(); ++i)
429 const size_t queryMapping = oldFromNewQueries[i];
430 distances[queryMapping] = (*distancePtr)[i];
433 neighbors[queryMapping].resize(distances[queryMapping].size());
434 for (
size_t j = 0; j < distances[queryMapping].size(); ++j)
435 neighbors[queryMapping][j] =
436 oldFromNewReferences[(*neighborPtr)[i][j]];
// Dual-tree with a caller-supplied reference tree: only query indices are
// unmapped; neighbor indices are copied as-is, in that tree's ordering.
443 else if (!singleMode && !naive)
447 neighbors.resize(querySet.n_cols);
449 distances.resize(querySet.n_cols);
451 for (
size_t i = 0; i < distances.size(); ++i)
454 const size_t queryMapping = oldFromNewQueries[i];
455 distances[queryMapping] = (*distancePtr)[i];
456 neighbors[queryMapping] = (*neighborPtr)[i];
// Remaining case: queries were processed in their original order, so only the
// reference indices are unmapped.
// NOTE(review): the enclosing condition is not visible. oldFromNewReferences
// is never filled in naive mode, so it presumably excludes naive search.
467 neighbors.resize(querySet.n_cols);
469 for (
size_t i = 0; i < neighbors.size(); ++i)
471 neighbors[i].resize((*neighborPtr)[i].size());
472 for (
size_t j = 0; j < neighbors[i].size(); ++j)
473 neighbors[i][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
// Dual-tree range search with a caller-supplied query tree. Results are
// indexed in the order of queryTree->Dataset(): only reference indices are
// unmapped here; query indices are not.
482 template<
typename MetricType,
484 template<
typename TreeMetricType,
485 typename TreeStatType,
486 typename TreeMatType>
class TreeType>
487 void RangeSearch<MetricType, MatType, TreeType>::Search(
489 const math::Range& range,
490 std::vector<std::vector<size_t>>& neighbors,
491 std::vector<std::vector<double>>& distances)
494 if (referenceSet->n_cols == 0)
500 const MatType& querySet = queryTree->Dataset();
// A query tree can only be used with dual-tree traversal.
503 if (singleMode || naive)
504 throw std::invalid_argument(
"cannot call RangeSearch::Search() with a " 505 "query tree when naive or singleMode are set to true");
508 std::vector<std::vector<size_t>>* neighborPtr = &neighbors;
// Neighbor indices are collected in tree order in a temporary and unmapped
// below (this allocation's condition is not visible in this excerpt).
511 neighborPtr =
new std::vector<std::vector<size_t>>;
514 neighborPtr->clear();
515 neighborPtr->resize(querySet.n_cols);
// NOTE(review): no distances.clear() is visible before this resize. If none
// exists on a hidden line, inner vectors left from a previous call may survive
// and no longer line up with neighbors.
517 distances.resize(querySet.n_cols);
521 RuleType rules(*referenceSet, queryTree->Dataset(), range, *neighborPtr,
525 typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
527 traverser.Traverse(*queryTree, *referenceTree);
// NOTE(review): this overload overwrites baseCases/scores, while the querySet
// overload of Search() accumulates them with +=.
531 baseCases = rules.BaseCases();
532 scores = rules.Scores();
// Map reference indices back to the original reference-set order.
539 neighbors.resize(querySet.n_cols);
541 for (
size_t i = 0; i < neighbors.size(); ++i)
543 neighbors[i].resize((*neighborPtr)[i].size());
544 for (
size_t j = 0; j < neighbors[i].size(); ++j)
545 neighbors[i][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
// Monochromatic range search: the reference set also serves as the query set.
// One result list is produced per reference point.
553 template<
typename MetricType,
555 template<
typename TreeMetricType,
556 typename TreeStatType,
557 typename TreeMatType>
class TreeType>
558 void RangeSearch<MetricType, MatType, TreeType>::Search(
559 const math::Range& range,
560 std::vector<std::vector<size_t>>& neighbors,
561 std::vector<std::vector<double>>& distances)
564 if (referenceSet->n_cols == 0)
570 std::vector<std::vector<size_t>>* neighborPtr = &neighbors;
571 std::vector<std::vector<double>>* distancePtr = &distances;
// Temporaries for results in tree order (the guarding condition is not visible
// in this excerpt).
576 distancePtr =
new std::vector<std::vector<double>>;
577 neighborPtr =
new std::vector<std::vector<size_t>>;
581 neighborPtr->clear();
582 neighborPtr->resize(referenceSet->n_cols);
583 distancePtr->clear();
584 distancePtr->resize(referenceSet->n_cols);
// The trailing `true` presumably tells the rules that the query and reference
// sets are the same. Confirm its exact effect in range_search_rules.hpp.
588 RuleType rules(*referenceSet, *referenceSet, range, *neighborPtr,
589 *distancePtr, metric,
true );
// Naive search: all n * n pairs.
594 for (
size_t i = 0; i < referenceSet->n_cols; ++i)
595 for (
size_t j = 0; j < referenceSet->n_cols; ++j)
596 rules.BaseCase(i, j);
598 baseCases = (referenceSet->n_cols * referenceSet->n_cols);
// Single-tree search: one traversal per reference point.
604 typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
607 for (
size_t i = 0; i < referenceSet->n_cols; ++i)
608 traverser.Traverse(i, *referenceTree);
610 baseCases = rules.BaseCases();
611 scores = rules.Scores();
// Dual-tree search: the reference tree is both the query and reference tree.
616 typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
618 traverser.Traverse(*referenceTree, *referenceTree);
620 baseCases = rules.BaseCases();
621 scores = rules.Scores();
// Unmap both the point index (i) and its neighbor indices. Both use
// oldFromNewReferences because queries and references share one tree.
630 neighbors.resize(referenceSet->n_cols);
632 distances.resize(referenceSet->n_cols);
634 for (
size_t i = 0; i < distances.size(); ++i)
637 const size_t refMapping = oldFromNewReferences[i];
638 distances[refMapping] = (*distancePtr)[i];
641 neighbors[refMapping].resize(distances[refMapping].size());
642 for (
size_t j = 0; j < distances[refMapping].size(); ++j)
644 neighbors[refMapping][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
// Serialize or deserialize the model with cereal. When loading, any previously
// owned tree is freed before new data is read. Afterwards, referenceSet and
// metric are taken from the loaded tree.
654 template<
typename MetricType,
656 template<
typename TreeMetricType,
657 typename TreeStatType,
658 typename TreeMatType>
class TreeType>
659 template<typename Archive>
660 void RangeSearch<MetricType, MatType, TreeType>::serialize(
661 Archive& ar, const uint32_t )
664 ar(CEREAL_NVP(naive));
665 ar(CEREAL_NVP(singleMode));
668 if (cereal::is_loading<Archive>())
678 if (cereal::is_loading<Archive>())
685 ar(CEREAL_NVP(metric));
// Loading without a tree (presumably the naive branch): drop any owned tree
// and the index mapping.
688 if (cereal::is_loading<Archive>())
690 if (treeOwner && referenceTree)
691 delete referenceTree;
693 referenceTree = NULL;
694 oldFromNewReferences.clear();
// Loading with a tree: free the owned tree before it is replaced.
701 if (cereal::is_loading<Archive>())
703 if (treeOwner && referenceTree)
704 delete referenceTree;
711 ar(CEREAL_NVP(oldFromNewReferences));
// After loading a tree, derive the reference set and metric from it.
715 if (cereal::is_loading<Archive>())
717 referenceSet = &referenceTree->Dataset();
718 metric = referenceTree->Metric();
The RangeSearch class is a template class for performing range searches.
Definition: range_search.hpp:45
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
RangeSearch & operator=(const RangeSearch &other)
Deep copy the given RangeSearch model.
Definition: range_search_impl.hpp:172
static const bool RearrangesDataset
This is true if the tree rearranges points in the dataset when it is built.
Definition: tree_traits.hpp:105
Definition: hmm_train_main.cpp:300
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
The TreeTraits class provides compile-time information on the characteristics of a given tree type...
Definition: tree_traits.hpp:77
TreeType * BuildTree(MatType &&dataset, std::vector< size_t > &oldFromNew, const typename std::enable_if< tree::TreeTraits< TreeType >::RearrangesDataset >::type *=0)
Call the tree constructor that does mapping.
Definition: dtb_impl.hpp:22
The RangeSearchRules class is a template helper class used by RangeSearch class when performing range...
Definition: range_search_rules.hpp:28
#define CEREAL_POINTER(T)
Cereal does not support serialization of raw pointers.
Definition: pointer_wrapper.hpp:96