mlpack
rs_model_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_RANGE_SEARCH_RS_MODEL_IMPL_HPP
13 #define MLPACK_METHODS_RANGE_SEARCH_RS_MODEL_IMPL_HPP
14 
// In case it hasn't been included yet.
#include "rs_model.hpp"

#include <memory>
17 
19 
20 namespace mlpack {
21 namespace range {
22 
23 template<template<typename TreeMetricType,
24  typename TreeStatType,
25  typename TreeMatType> class TreeType>
26 void RSWrapper<TreeType>::Train(arma::mat&& referenceSet,
27  const size_t /* leafSize */)
28 {
29  rs.Train(std::move(referenceSet));
30 }
31 
32 template<template<typename TreeMetricType,
33  typename TreeStatType,
34  typename TreeMatType> class TreeType>
35 void RSWrapper<TreeType>::Search(arma::mat&& querySet,
36  const math::Range& range,
37  std::vector<std::vector<size_t>>& neighbors,
38  std::vector<std::vector<double>>& distances,
39  const size_t /* leafSize */)
40 {
41  rs.Search(std::move(querySet), range, neighbors, distances);
42 }
43 
44 template<template<typename TreeMetricType,
45  typename TreeStatType,
46  typename TreeMatType> class TreeType>
47 void RSWrapper<TreeType>::Search(const math::Range& range,
48  std::vector<std::vector<size_t>>& neighbors,
49  std::vector<std::vector<double>>& distances)
50 {
51  rs.Search(range, neighbors, distances);
52 }
53 
54 template<template<typename TreeMetricType,
55  typename TreeStatType,
56  typename TreeMatType> class TreeType>
57 void LeafSizeRSWrapper<TreeType>::Train(arma::mat&& referenceSet,
58  const size_t leafSize)
59 {
60  if (rs.Naive())
61  {
62  rs.Train(std::move(referenceSet));
63  }
64  else
65  {
66  std::vector<size_t> oldFromNewReferences;
67  typename decltype(rs)::Tree* tree =
68  new typename decltype(rs)::Tree(std::move(referenceSet),
69  oldFromNewReferences,
70  leafSize);
71  rs.Train(tree);
72 
73  // Give the model ownership of the tree and the mappings.
74  rs.treeOwner = true;
75  rs.oldFromNewReferences = std::move(oldFromNewReferences);
76  }
77 }
78 
79 template<template<typename TreeMetricType,
80  typename TreeStatType,
81  typename TreeMatType> class TreeType>
82 void LeafSizeRSWrapper<TreeType>::Search(
83  arma::mat&& querySet,
84  const math::Range& range,
85  std::vector<std::vector<size_t>>& neighbors,
86  std::vector<std::vector<double>>& distances,
87  const size_t leafSize)
88 {
89  if (!rs.Naive() && !rs.SingleMode())
90  {
91  // Build a second tree and search.
92  Timer::Start("tree_building");
93  Log::Info << "Building query tree..." << std::endl;
94  std::vector<size_t> oldFromNewQueries;
95  typename decltype(rs)::Tree queryTree(std::move(querySet),
96  oldFromNewQueries,
97  leafSize);
98  Log::Info << "Tree built." << std::endl;
99  Timer::Stop("tree_building");
100 
101  std::vector<std::vector<size_t>> neighborsOut;
102  std::vector<std::vector<double>> distancesOut;
103  rs.Search(&queryTree, range, neighborsOut, distancesOut);
104 
105  // Remap the query points.
106  neighbors.resize(queryTree.Dataset().n_cols);
107  distances.resize(queryTree.Dataset().n_cols);
108  for (size_t i = 0; i < queryTree.Dataset().n_cols; ++i)
109  {
110  neighbors[oldFromNewQueries[i]] = neighborsOut[i];
111  distances[oldFromNewQueries[i]] = distancesOut[i];
112  }
113  }
114  else
115  {
116  rs.Search(std::move(querySet), range, neighbors, distances);
117  }
118 }
119 
120 // Serialize the model.
121 template<typename Archive>
122 void RSModel::serialize(Archive& ar, const uint32_t /* version */)
123 {
124  ar(CEREAL_NVP(treeType));
125  ar(CEREAL_NVP(randomBasis));
126  ar(CEREAL_NVP(q));
127 
128  // This should never happen, but just in case...
129  if (cereal::is_loading<Archive>())
130  InitializeModel(false, false); // Values will be overwritten.
131 
132  // Avoid polymorphic serialization by explicitly serializing the correct type.
133  switch (treeType)
134  {
135  case KD_TREE:
136  {
137  LeafSizeRSWrapper<tree::KDTree>& typedSearch =
138  dynamic_cast<LeafSizeRSWrapper<tree::KDTree>&>(*rSearch);
139  ar(CEREAL_NVP(typedSearch));
140  break;
141  }
142  case COVER_TREE:
143  {
145  dynamic_cast<RSWrapper<tree::StandardCoverTree>&>(*rSearch);
146  ar(CEREAL_NVP(typedSearch));
147  break;
148  }
149 
150  case R_TREE:
151  {
152  RSWrapper<tree::RTree>& typedSearch =
153  dynamic_cast<RSWrapper<tree::RTree>&>(*rSearch);
154  ar(CEREAL_NVP(typedSearch));
155  break;
156  }
157 
158  case R_STAR_TREE:
159  {
160  RSWrapper<tree::RStarTree>& typedSearch =
161  dynamic_cast<RSWrapper<tree::RStarTree>&>(*rSearch);
162  ar(CEREAL_NVP(typedSearch));
163  break;
164  }
165 
166  case BALL_TREE:
167  {
168  LeafSizeRSWrapper<tree::BallTree>& typedSearch =
169  dynamic_cast<LeafSizeRSWrapper<tree::BallTree>&>(*rSearch);
170  ar(CEREAL_NVP(typedSearch));
171  break;
172  }
173  case X_TREE:
174  {
175  RSWrapper<tree::XTree>& typedSearch =
176  dynamic_cast<RSWrapper<tree::XTree>&>(*rSearch);
177  ar(CEREAL_NVP(typedSearch));
178  break;
179  }
180 
181  case HILBERT_R_TREE:
182  {
183  RSWrapper<tree::HilbertRTree>& typedSearch =
184  dynamic_cast<RSWrapper<tree::HilbertRTree>&>(*rSearch);
185  ar(CEREAL_NVP(typedSearch));
186  break;
187  }
188 
189  case R_PLUS_TREE:
190  {
191  RSWrapper<tree::RPlusTree>& typedSearch =
192  dynamic_cast<RSWrapper<tree::RPlusTree>&>(*rSearch);
193  ar(CEREAL_NVP(typedSearch));
194  break;
195  }
196 
197  case R_PLUS_PLUS_TREE:
198  {
199  RSWrapper<tree::RPlusPlusTree>& typedSearch =
200  dynamic_cast<RSWrapper<tree::RPlusPlusTree>&>(*rSearch);
201  ar(CEREAL_NVP(typedSearch));
202  break;
203  }
204 
205  case VP_TREE:
206  {
207  RSWrapper<tree::VPTree>& typedSearch =
208  dynamic_cast<RSWrapper<tree::VPTree>&>(*rSearch);
209  ar(CEREAL_NVP(typedSearch));
210  break;
211  }
212 
213  case RP_TREE:
214  {
215  RSWrapper<tree::RPTree>& typedSearch =
216  dynamic_cast<RSWrapper<tree::RPTree>&>(*rSearch);
217  ar(CEREAL_NVP(typedSearch));
218  break;
219  }
220 
221  case MAX_RP_TREE:
222  {
223  RSWrapper<tree::MaxRPTree>& typedSearch =
224  dynamic_cast<RSWrapper<tree::MaxRPTree>&>(*rSearch);
225  ar(CEREAL_NVP(typedSearch));
226  break;
227  }
228  case UB_TREE:
229  {
230  RSWrapper<tree::UBTree>& typedSearch =
231  dynamic_cast<RSWrapper<tree::UBTree>&>(*rSearch);
232  ar(CEREAL_NVP(typedSearch));
233  break;
234  }
235  case OCTREE:
236  {
237  LeafSizeRSWrapper<tree::Octree>& typedSearch =
238  dynamic_cast<LeafSizeRSWrapper<tree::Octree>&>(*rSearch);
239  ar(CEREAL_NVP(typedSearch));
240  break;
241  }
242  }
243 }
244 
245 } // namespace range
246 } // namespace mlpack
247 
248 #endif
RSWrapper is a wrapper class for most RangeSearch types.
Definition: rs_model.hpp:86
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Forward declaration.
Definition: range_search.hpp:28
void serialize(Archive &ar, const uint32_t)
Serialize the range search model.
Definition: rs_model_impl.hpp:122
virtual void Search(arma::mat &&querySet, const math::Range &range, std::vector< std::vector< size_t >> &neighbors, std::vector< std::vector< double >> &distances, const size_t)
Perform bichromatic range search (i.e. search a separate query set against the reference set).
Definition: rs_model_impl.hpp:35
Definition: hmm_train_main.cpp:300
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84