15 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_IMPL_HPP 16 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_IMPL_HPP 26 template<
typename SortPolicy,
27 template<
typename TreeMetricType,
28 typename TreeStatType,
29 typename TreeMatType>
class TreeType,
30 template<typename RuleType> class DualTreeTraversalType,
31 template<typename RuleType> class SingleTreeTraversalType>
33 SortPolicy, TreeType, DualTreeTraversalType, SingleTreeTraversalType
34 >::
Train(arma::mat&& referenceSet,
39 ns.Train(std::move(referenceSet));
44 template<
typename SortPolicy,
45 template<
typename TreeMetricType,
46 typename TreeStatType,
47 typename TreeMatType>
class TreeType,
48 template<typename RuleType> class DualTreeTraversalType,
49 template<typename RuleType> class SingleTreeTraversalType>
51 SortPolicy, TreeType, DualTreeTraversalType, SingleTreeTraversalType
52 >::Search(arma::mat&& querySet,
54 arma::Mat<size_t>& neighbors,
59 ns.Search(std::move(querySet), k, neighbors, distances);
64 template<
typename SortPolicy,
65 template<
typename TreeMetricType,
66 typename TreeStatType,
67 typename TreeMatType>
class TreeType,
68 template<typename RuleType> class DualTreeTraversalType,
69 template<typename RuleType> class SingleTreeTraversalType>
71 SortPolicy, TreeType, DualTreeTraversalType, SingleTreeTraversalType
72 >::Search(const size_t k,
73 arma::Mat<size_t>& neighbors,
76 ns.Search(k, neighbors, distances);
81 template<
typename SortPolicy,
82 template<
typename TreeMetricType,
83 typename TreeStatType,
84 typename TreeMatType>
class TreeType,
85 template<typename RuleType> class DualTreeTraversalType,
86 template<typename RuleType> class SingleTreeTraversalType>
87 void LeafSizeNSWrapper<
88 SortPolicy, TreeType, DualTreeTraversalType, SingleTreeTraversalType
89 >::
Train(arma::mat&& referenceSet,
90 const size_t leafSize,
94 if (ns.SearchMode() == NAIVE_MODE)
96 ns.Train(std::move(referenceSet));
101 std::vector<size_t> oldFromNewReferences;
102 typename decltype(ns)::Tree referenceTree(std::move(referenceSet),
103 oldFromNewReferences, leafSize);
104 ns.Train(std::move(referenceTree));
105 ns.oldFromNewReferences = std::move(oldFromNewReferences);
111 template<
typename SortPolicy,
112 template<
typename TreeMetricType,
113 typename TreeStatType,
114 typename TreeMatType>
class TreeType,
115 template<typename RuleType> class DualTreeTraversalType,
116 template<typename RuleType> class SingleTreeTraversalType>
117 void LeafSizeNSWrapper<
118 SortPolicy, TreeType, DualTreeTraversalType, SingleTreeTraversalType
119 >::Search(arma::mat&& querySet,
121 arma::Mat<size_t>& neighbors,
122 arma::mat& distances,
123 const size_t leafSize,
126 if (ns.SearchMode() == DUAL_TREE_MODE)
132 std::vector<size_t> oldFromNewQueries;
133 typename decltype(ns)::Tree queryTree(std::move(querySet),
134 oldFromNewQueries, leafSize);
136 arma::Mat<size_t> neighborsOut;
137 arma::mat distancesOut;
138 ns.Search(queryTree, k, neighborsOut, distancesOut);
141 distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
142 neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
143 for (
size_t i = 0; i < neighborsOut.n_cols; ++i)
145 neighbors.col(oldFromNewQueries[i]) = neighborsOut.col(i);
146 distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
151 ns.Search(querySet, k, neighbors, distances);
156 template<
typename SortPolicy>
158 const size_t leafSize,
162 typename decltype(ns)::Tree tree(std::move(referenceSet), tau, leafSize,
164 ns.Train(std::move(tree));
169 template<
typename SortPolicy>
172 arma::Mat<size_t>& neighbors,
173 arma::mat& distances,
174 const size_t leafSize,
177 if (ns.SearchMode() == DUAL_TREE_MODE)
181 typename decltype(ns)::Tree queryTree(std::move(querySet), 0 ,
183 ns.Search(queryTree, k, neighbors, distances);
187 ns.Search(querySet, k, neighbors, distances);
195 template<
typename SortPolicy>
198 randomBasis(randomBasis),
207 template<
typename SortPolicy>
209 treeType(other.treeType),
210 randomBasis(other.randomBasis),
212 leafSize(other.leafSize),
215 nSearch(other.nSearch->Clone())
220 template<
typename SortPolicy>
222 treeType(other.treeType),
223 randomBasis(other.randomBasis),
224 q(
std::move(other.q)),
225 leafSize(other.leafSize),
228 nSearch(other.nSearch)
231 other.treeType = TreeTypes::KD_TREE;
232 other.randomBasis =
false;
236 other.nSearch = NULL;
239 template<
typename SortPolicy>
246 treeType = other.treeType;
247 randomBasis = other.randomBasis;
249 leafSize = other.leafSize;
252 nSearch = other.nSearch->
Clone();
258 template<
typename SortPolicy>
265 treeType = other.treeType;
266 randomBasis = other.randomBasis;
267 q = std::move(other.q);
268 leafSize = other.leafSize;
271 nSearch = other.nSearch;
274 other.treeType = TreeTypes::KD_TREE;
275 other.randomBasis =
false;
279 other.nSearch = NULL;
286 template<
typename SortPolicy>
293 template<
typename SortPolicy>
294 template<
typename Archive>
297 ar(CEREAL_NVP(treeType));
298 ar(CEREAL_NVP(randomBasis));
300 ar(CEREAL_NVP(leafSize));
305 if (cereal::is_loading<Archive>())
314 dynamic_cast<LeafSizeNSWrapper<SortPolicy,
316 ar(CEREAL_NVP(typedSearch));
322 dynamic_cast<NSWrapper<SortPolicy,
324 ar(CEREAL_NVP(typedSearch));
331 ar(CEREAL_NVP(typedSearch));
338 ar(CEREAL_NVP(typedSearch));
344 dynamic_cast<LeafSizeNSWrapper<SortPolicy,
346 ar(CEREAL_NVP(typedSearch));
353 ar(CEREAL_NVP(typedSearch));
360 ar(CEREAL_NVP(typedSearch));
367 ar(CEREAL_NVP(typedSearch));
370 case R_PLUS_PLUS_TREE:
374 ar(CEREAL_NVP(typedSearch));
381 ar(CEREAL_NVP(typedSearch));
388 ar(CEREAL_NVP(typedSearch));
395 ar(CEREAL_NVP(typedSearch));
402 ar(CEREAL_NVP(typedSearch));
409 ar(CEREAL_NVP(typedSearch));
415 dynamic_cast<LeafSizeNSWrapper<SortPolicy,
417 ar(CEREAL_NVP(typedSearch));
424 template<
typename SortPolicy>
431 template<
typename SortPolicy>
438 template<
typename SortPolicy>
444 template<
typename SortPolicy>
450 template<
typename SortPolicy>
457 template<
typename SortPolicy>
459 const double epsilon)
495 case R_PLUS_PLUS_TREE:
522 template<
typename SortPolicy>
525 const double epsilon)
530 Log::Info <<
"Creating random basis..." << std::endl;
536 if (arma::qr(q, r, arma::randn<arma::mat>(referenceSet.n_rows,
537 referenceSet.n_rows)))
539 arma::vec rDiag(r.n_rows);
540 for (
size_t i = 0; i < rDiag.n_elem; ++i)
544 else if (r(i, i) > 0)
550 q *= arma::diagmat(rDiag);
553 if (arma::det(q) >= 0)
561 referenceSet = q * referenceSet;
563 if (searchMode != NAIVE_MODE)
566 Log::Info <<
"Building reference tree..." << std::endl;
570 nSearch->
Train(std::move(referenceSet), leafSize, tau, rho);
572 if (searchMode != NAIVE_MODE)
580 template<
typename SortPolicy>
583 arma::Mat<size_t>& neighbors,
584 arma::mat& distances)
588 querySet = q * querySet;
590 Log::Info <<
"Searching for " << k <<
" neighbors with ";
595 Log::Info <<
"brute-force (naive) search..." << std::endl;
597 case SINGLE_TREE_MODE:
603 case GREEDY_SINGLE_TREE_MODE:
609 nSearch->
Search(std::move(querySet), k, neighbors, distances, leafSize, rho);
613 template<
typename SortPolicy>
615 arma::Mat<size_t>& neighbors,
616 arma::mat& distances)
618 Log::Info <<
"Searching for " << k <<
" neighbors with ";
623 Log::Info <<
"brute-force (naive) search..." << std::endl;
625 case SINGLE_TREE_MODE:
631 case GREEDY_SINGLE_TREE_MODE:
641 nSearch->
Search(k, neighbors, distances);
645 template<
typename SortPolicy>
663 return "Hilbert R tree";
666 case R_PLUS_PLUS_TREE:
671 return "vantage point tree";
673 return "random projection tree (mean split)";
675 return "random projection tree (max split)";
681 return "unknown tree";
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
NeighborSearchMode SearchMode() const
Expose SearchMode.
Definition: ns_model_impl.hpp:432
NSModel(TreeTypes treeType=TreeTypes::KD_TREE, bool randomBasis=false)
Initialize the NSModel with the given type and whether or not a random basis should be used...
Definition: ns_model_impl.hpp:196
NSWrapper is a wrapper class for most NeighborSearch types.
Definition: ns_model.hpp:99
virtual void Search(arma::mat &&querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances, const size_t leafSize, const double rho)=0
Perform bichromatic neighbor search (i.e.
TreeTypes
Enum type to identify each accepted tree type.
Definition: ns_model.hpp:337
LeafSizeNSWrapper wraps any NeighborSearch types that take a leaf size for tree construction.
Definition: neighbor_search.hpp:40
Definition: pointer_wrapper.hpp:23
virtual const arma::mat & Dataset() const =0
Return a reference to the dataset.
void BuildModel(arma::mat &&referenceSet, const NeighborSearchMode searchMode, const double epsilon=0)
Build the reference tree.
Definition: ns_model_impl.hpp:523
A binary space partitioning tree, such as a KD-tree or a ball tree.
Definition: binary_space_tree.hpp:54
double Epsilon() const
Expose Epsilon.
Definition: ns_model_impl.hpp:445
const arma::mat & Dataset() const
Expose the dataset.
Definition: ns_model_impl.hpp:425
virtual void Train(arma::mat &&referenceSet, const size_t leafSize, const double tau, const double rho)=0
Train the NeighborSearch model with the given parameters.
The NSModel class provides an easy way to serialize a model, abstracts away the different types of tr...
Definition: ns_model.hpp:333
void InitializeModel(const NeighborSearchMode searchMode, const double epsilon)
Initialize the model type. (This does not perform any training.)
Definition: ns_model_impl.hpp:458
Definition: hmm_train_main.cpp:300
virtual void Train(arma::mat &&referenceSet, const size_t leafSize, const double tau, const double rho)
Train the model using the given parameters.
Definition: ns_model_impl.hpp:157
void serialize(Archive &ar, const uint32_t)
Serialize the neighbor search model.
Definition: ns_model_impl.hpp:295
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
NeighborSearchMode
NeighborSearchMode represents the different neighbor search modes available.
Definition: neighbor_search.hpp:43
virtual double Epsilon() const =0
Get the approximation parameter epsilon.
virtual void Search(arma::mat &&querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances, const size_t leafSize, const double rho)
Perform bichromatic search (i.e.
Definition: ns_model_impl.hpp:170
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if –verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
Definition: octree.hpp:25
The SpillNSWrapper class wraps the NeighborSearch class when the spill tree is used.
Definition: ns_model.hpp:253
virtual NeighborSearchMode SearchMode() const =0
Get the search mode.
NSModel & operator=(const NSModel &other)
Copy the given NSModel.
Definition: ns_model_impl.hpp:240
std::string TreeName() const
Return a string representation of the current tree type.
Definition: ns_model_impl.hpp:646
void Search(arma::mat &&querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
Perform neighbor search. The query set will be reordered.
Definition: ns_model_impl.hpp:581
A cover tree is a tree specifically designed to speed up nearest-neighbor computation in high-dimensi...
Definition: cover_tree.hpp:99
virtual NSWrapperBase * Clone() const =0
Create a new NSWrapperBase that is the same as this one.
~NSModel()
Clean memory, if necessary.
Definition: ns_model_impl.hpp:287