mlpack
ra_model_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_RANN_RA_MODEL_IMPL_HPP
13 #define MLPACK_METHODS_RANN_RA_MODEL_IMPL_HPP
14 
15 // In case it hasn't been included yet.
#include "ra_model.hpp"

#include <memory>
18 
19 namespace mlpack {
20 namespace neighbor {
21 
22 template<template<typename TreeMetricType,
23  typename TreeStatType,
24  typename TreeMatType> class TreeType>
25 void RAWrapper<TreeType>::Train(arma::mat&& referenceSet,
26  const size_t /* leafSize */)
27 {
28  ra.Train(std::move(referenceSet));
29 }
30 
31 template<template<typename TreeMetricType,
32  typename TreeStatType,
33  typename TreeMatType> class TreeType>
34 void RAWrapper<TreeType>::Search(arma::mat&& querySet,
35  const size_t k,
36  arma::Mat<size_t>& neighbors,
37  arma::mat& distances,
38  const size_t /* leafSize */)
39 {
40  ra.Search(querySet, k, neighbors, distances);
41 }
42 
43 template<template<typename TreeMetricType,
44  typename TreeStatType,
45  typename TreeMatType> class TreeType>
46 void RAWrapper<TreeType>::Search(const size_t k,
47  arma::Mat<size_t>& neighbors,
48  arma::mat& distances)
49 {
50  ra.Search(k, neighbors, distances);
51 }
52 
53 template<template<typename TreeMetricType,
54  typename TreeStatType,
55  typename TreeMatType> class TreeType>
56 void LeafSizeRAWrapper<TreeType>::Train(arma::mat&& referenceSet,
57  const size_t leafSize)
58 {
59  // Build tree, if necessary.
60  if (ra.Naive())
61  {
62  ra.Train(std::move(referenceSet));
63  }
64  else
65  {
66  std::vector<size_t> oldFromNewReferences;
67  typename decltype(ra)::Tree* tree =
68  new typename decltype(ra)::Tree(std::move(referenceSet),
69  oldFromNewReferences,
70  leafSize);
71  ra.Train(tree);
72 
73  // Give the model ownership of the tree and the mappings.
74  ra.treeOwner = true;
75  ra.oldFromNewReferences = std::move(oldFromNewReferences);
76  }
77 }
78 
79 template<template<typename TreeMetricType,
80  typename TreeStatType,
81  typename TreeMatType> class TreeType>
82 void LeafSizeRAWrapper<TreeType>::Search(arma::mat&& querySet,
83  const size_t k,
84  arma::Mat<size_t>& neighbors,
85  arma::mat& distances,
86  const size_t leafSize)
87 {
88  if (!ra.Naive() && !ra.SingleMode())
89  {
90  // Build a second tree and search, taking the leaf size into account.
91  Timer::Start("tree_building");
92  Log::Info << "Building query tree...."<< std::endl;
93  std::vector<size_t> oldFromNewQueries;
94  typename decltype(ra)::Tree queryTree(std::move(querySet),
95  oldFromNewQueries,
96  leafSize);
97  Log::Info << "Tree built." << std::endl;
98  Timer::Stop("tree_building");
99 
100  arma::Mat<size_t> neighborsOut;
101  arma::mat distancesOut;
102  ra.Search(&queryTree, k, neighborsOut, distancesOut);
103 
104  // Unmap the query points.
105  distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
106  neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
107  for (size_t i = 0; i < oldFromNewQueries.size(); ++i)
108  {
109  neighbors.col(oldFromNewQueries[i]) = neighborsOut.col(i);
110  distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
111  }
112  }
113  else
114  {
115  // Search without building a second tree.
116  ra.Search(querySet, k, neighbors, distances);
117  }
118 }
119 
120 template<typename Archive>
121 void RAModel::serialize(Archive& ar, const uint32_t /* version */)
122 {
123  ar(CEREAL_NVP(treeType));
124  ar(CEREAL_NVP(randomBasis));
125  ar(CEREAL_NVP(q));
126 
127  // This should never happen, but just in case, be clean with memory.
128  if (cereal::is_loading<Archive>())
129  InitializeModel(false, false); // Values will be overwritten.
130 
131  // Avoid polymorphic serialization by explicitly serializing the correct type.
132  switch (treeType)
133  {
134  case KD_TREE:
135  {
136  LeafSizeRAWrapper<tree::KDTree>& typedSearch =
137  dynamic_cast<LeafSizeRAWrapper<tree::KDTree>&>(*raSearch);
138  ar(CEREAL_NVP(typedSearch));
139  break;
140  }
141  case COVER_TREE:
142  {
144  dynamic_cast<RAWrapper<tree::StandardCoverTree>&>(*raSearch);
145  ar(CEREAL_NVP(typedSearch));
146  break;
147  }
148  case R_TREE:
149  {
150  RAWrapper<tree::RTree>& typedSearch =
151  dynamic_cast<RAWrapper<tree::RTree>&>(*raSearch);
152  ar(CEREAL_NVP(typedSearch));
153  break;
154  }
155  case R_STAR_TREE:
156  {
157  RAWrapper<tree::RStarTree>& typedSearch =
158  dynamic_cast<RAWrapper<tree::RStarTree>&>(*raSearch);
159  ar(CEREAL_NVP(typedSearch));
160  break;
161  }
162  case X_TREE:
163  {
164  RAWrapper<tree::XTree>& typedSearch =
165  dynamic_cast<RAWrapper<tree::XTree>&>(*raSearch);
166  ar(CEREAL_NVP(typedSearch));
167  break;
168  }
169  case HILBERT_R_TREE:
170  {
171  RAWrapper<tree::HilbertRTree>& typedSearch =
172  dynamic_cast<RAWrapper<tree::HilbertRTree>&>(*raSearch);
173  ar(CEREAL_NVP(typedSearch));
174  break;
175  }
176  case R_PLUS_TREE:
177  {
178  RAWrapper<tree::RPlusTree>& typedSearch =
179  dynamic_cast<RAWrapper<tree::RPlusTree>&>(*raSearch);
180  ar(CEREAL_NVP(typedSearch));
181  break;
182  }
183  case R_PLUS_PLUS_TREE:
184  {
185  RAWrapper<tree::RPlusPlusTree>& typedSearch =
186  dynamic_cast<RAWrapper<tree::RPlusPlusTree>&>(*raSearch);
187  ar(CEREAL_NVP(typedSearch));
188  break;
189  }
190  case UB_TREE:
191  {
192  RAWrapper<tree::UBTree>& typedSearch =
193  dynamic_cast<RAWrapper<tree::UBTree>&>(*raSearch);
194  ar(CEREAL_NVP(typedSearch));
195  break;
196  }
197  case OCTREE:
198  {
199  LeafSizeRAWrapper<tree::Octree>& typedSearch =
200  dynamic_cast<LeafSizeRAWrapper<tree::Octree>&>(*raSearch);
201  ar(CEREAL_NVP(typedSearch));
202  break;
203  }
204  }
205 }
206 
207 } // namespace neighbor
208 } // namespace mlpack
209 
210 #endif
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
virtual void Search(arma::mat &&querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances, const size_t)
Perform bichromatic neighbor search (i.e. search with a separate query set).
Definition: ra_model_impl.hpp:34
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void serialize(Archive &ar, const uint32_t)
Serialize the model.
Definition: ra_model_impl.hpp:121
Definition: hmm_train_main.cpp:300
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
LeafSizeRAWrapper wraps any RASearch type that needs to be able to take the leaf size into account when building trees.
Definition: ra_model.hpp:208
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
RAWrapper is a wrapper class for most RASearch types.
Definition: ra_model.hpp:109