mlpack
ns_model.hpp
Go to the documentation of this file.
1 
16 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
17 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
18 
24 #include "neighbor_search.hpp"
25 
26 namespace mlpack {
27 namespace neighbor {
28 
35 class NSWrapperBase
36 {
37  public:
40  NSWrapperBase() { }
41 
44  virtual NSWrapperBase* Clone() const = 0;
45 
47  virtual ~NSWrapperBase() { }
48 
50  virtual const arma::mat& Dataset() const = 0;
51 
53  virtual NeighborSearchMode SearchMode() const = 0;
55  virtual NeighborSearchMode& SearchMode() = 0;
56 
58  virtual double Epsilon() const = 0;
60  virtual double& Epsilon() = 0;
61 
63  virtual void Train(arma::mat&& referenceSet,
64  const size_t leafSize,
65  const double tau,
66  const double rho) = 0;
67 
70  virtual void Search(arma::mat&& querySet,
71  const size_t k,
72  arma::Mat<size_t>& neighbors,
73  arma::mat& distances,
74  const size_t leafSize,
75  const double rho) = 0;
76 
79  virtual void Search(const size_t k,
80  arma::Mat<size_t>& neighbors,
81  arma::mat& distances) = 0;
82 };
83 
87 template<typename SortPolicy,
88  template<typename TreeMetricType,
89  typename TreeStatType,
90  typename TreeMatType> class TreeType,
91  template<typename RuleType> class DualTreeTraversalType =
94  arma::mat>::template DualTreeTraverser,
95  template<typename RuleType> class SingleTreeTraversalType =
96  TreeType<metric::EuclideanDistance,
97  NeighborSearchStat<SortPolicy>,
98  arma::mat>::template SingleTreeTraverser>
99 class NSWrapper : public NSWrapperBase
100 {
101  public:
104  NSWrapper(const NeighborSearchMode searchMode,
105  const double epsilon) :
106  ns(searchMode, epsilon)
107  {
108  // Nothing else to do.
109  }
110 
112  virtual ~NSWrapper() { }
113 
116  virtual NSWrapper* Clone() const { return new NSWrapper(*this); }
117 
119  const arma::mat& Dataset() const { return ns.ReferenceSet(); }
120 
122  NeighborSearchMode SearchMode() const { return ns.SearchMode(); }
124  NeighborSearchMode& SearchMode() { return ns.SearchMode(); }
125 
127  double Epsilon() const { return ns.Epsilon(); }
129  double& Epsilon() { return ns.Epsilon(); }
130 
133  virtual void Train(arma::mat&& referenceSet,
134  const size_t /* leafSize */,
135  const double /* tau */,
136  const double /* rho */);
137 
140  virtual void Search(arma::mat&& querySet,
141  const size_t k,
142  arma::Mat<size_t>& neighbors,
143  arma::mat& distances,
144  const size_t /* leafSize */,
145  const double /* rho */);
146 
149  virtual void Search(const size_t k,
150  arma::Mat<size_t>& neighbors,
151  arma::mat& distances);
152 
154  template<typename Archive>
155  void serialize(Archive& ar, const uint32_t /* version */)
156  {
157  ar(CEREAL_NVP(ns));
158  }
159 
160  protected:
161  // Convenience typedef for the neighbor search type held by this class.
162  typedef NeighborSearch<SortPolicy,
163  metric::EuclideanDistance,
164  arma::mat,
165  TreeType,
166  DualTreeTraversalType,
167  SingleTreeTraversalType> NSType;
168 
170  NSType ns;
171 };
172 
178 template<typename SortPolicy,
179  template<typename TreeMetricType,
180  typename TreeStatType,
181  typename TreeMatType> class TreeType,
182  template<typename RuleType> class DualTreeTraversalType =
183  TreeType<metric::EuclideanDistance,
184  NeighborSearchStat<SortPolicy>,
185  arma::mat>::template DualTreeTraverser,
186  template<typename RuleType> class SingleTreeTraversalType =
187  TreeType<metric::EuclideanDistance,
188  NeighborSearchStat<SortPolicy>,
189  arma::mat>::template SingleTreeTraverser>
190 class LeafSizeNSWrapper :
191  public NSWrapper<SortPolicy,
192  TreeType,
193  DualTreeTraversalType,
194  SingleTreeTraversalType>
195 {
196  public:
199  LeafSizeNSWrapper(const NeighborSearchMode searchMode,
200  const double epsilon) :
201  NSWrapper<SortPolicy,
202  TreeType,
203  DualTreeTraversalType,
204  SingleTreeTraversalType>(searchMode, epsilon)
205  {
206  // Nothing to do.
207  }
208 
210  virtual ~LeafSizeNSWrapper() { }
211 
213  virtual LeafSizeNSWrapper* Clone() const
214  {
215  return new LeafSizeNSWrapper(*this);
216  }
217 
220  virtual void Train(arma::mat&& referenceSet,
221  const size_t leafSize,
222  const double /* tau */,
223  const double /* rho */);
224 
227  virtual void Search(arma::mat&& querySet,
228  const size_t k,
229  arma::Mat<size_t>& neighbors,
230  arma::mat& distances,
231  const size_t leafSize,
232  const double /* rho */);
233 
235  template<typename Archive>
236  void serialize(Archive& ar, const uint32_t /* version */)
237  {
238  ar(CEREAL_NVP(ns));
239  }
240 
241  protected:
242  using NSWrapper<SortPolicy,
243  TreeType,
244  DualTreeTraversalType,
245  SingleTreeTraversalType>::ns;
246 };
247 
252 template<typename SortPolicy>
253 class SpillNSWrapper :
254  public NSWrapper<
255  SortPolicy,
256  tree::SPTree,
257  tree::SPTree<metric::EuclideanDistance,
258  NeighborSearchStat<SortPolicy>,
259  arma::mat>::template DefeatistDualTreeTraverser,
260  tree::SPTree<metric::EuclideanDistance,
261  NeighborSearchStat<SortPolicy>,
262  arma::mat>::template DefeatistSingleTreeTraverser>
263 {
264  public:
266  SpillNSWrapper(const NeighborSearchMode searchMode,
267  const double epsilon) :
268  NSWrapper<
269  SortPolicy,
270  tree::SPTree,
271  tree::SPTree<metric::EuclideanDistance,
272  NeighborSearchStat<SortPolicy>,
273  arma::mat>::template DefeatistDualTreeTraverser,
274  tree::SPTree<metric::EuclideanDistance,
275  NeighborSearchStat<SortPolicy>,
276  arma::mat>::template DefeatistSingleTreeTraverser>(
277  searchMode, epsilon)
278  {
279  // Nothing to do.
280  }
281 
283  virtual ~SpillNSWrapper() { }
284 
286  virtual SpillNSWrapper* Clone() const { return new SpillNSWrapper(*this); }
287 
289  virtual void Train(arma::mat&& referenceSet,
290  const size_t leafSize,
291  const double tau,
292  const double rho);
293 
296  virtual void Search(arma::mat&& querySet,
297  const size_t k,
298  arma::Mat<size_t>& neighbors,
299  arma::mat& distances,
300  const size_t leafSize,
301  const double rho);
302 
304  template<typename Archive>
305  void serialize(Archive& ar, const uint32_t /* version */)
306  {
307  ar(CEREAL_NVP(ns));
308  }
309 
310  protected:
311  using NSWrapper<
312  SortPolicy,
313  tree::SPTree,
314  tree::SPTree<metric::EuclideanDistance,
315  NeighborSearchStat<SortPolicy>,
316  arma::mat>::template DefeatistDualTreeTraverser,
317  tree::SPTree<metric::EuclideanDistance,
318  NeighborSearchStat<SortPolicy>,
319  arma::mat>::template DefeatistSingleTreeTraverser>::ns;
320 };
321 
332 template<typename SortPolicy>
333 class NSModel
334 {
335  public:
338  {
339  KD_TREE,
340  COVER_TREE,
341  R_TREE,
342  R_STAR_TREE,
343  BALL_TREE,
344  X_TREE,
345  HILBERT_R_TREE,
346  R_PLUS_TREE,
347  R_PLUS_PLUS_TREE,
348  VP_TREE,
349  RP_TREE,
350  MAX_RP_TREE,
351  SPILL_TREE,
352  UB_TREE,
353  OCTREE
354  };
355 
356  private:
358  TreeTypes treeType;
359 
361  bool randomBasis;
363  arma::mat q;
364 
365  size_t leafSize;
366  double tau;
367  double rho;
368 
373  NSWrapperBase* nSearch;
374 
375  public:
384  NSModel(TreeTypes treeType = TreeTypes::KD_TREE, bool randomBasis = false);
385 
391  NSModel(const NSModel& other);
392 
398  NSModel(NSModel&& other);
399 
405  NSModel& operator=(const NSModel& other);
406 
412  NSModel& operator=(NSModel&& other);
413 
415  ~NSModel();
416 
418  template<typename Archive>
419  void serialize(Archive& ar, const uint32_t /* version */);
420 
422  const arma::mat& Dataset() const;
423 
427 
429  size_t LeafSize() const { return leafSize; }
430  size_t& LeafSize() { return leafSize; }
431 
433  double Tau() const { return tau; }
434  double& Tau() { return tau; }
435 
437  double Rho() const { return rho; }
438  double& Rho() { return rho; }
439 
441  double Epsilon() const;
442  double& Epsilon();
443 
445  TreeTypes TreeType() const { return treeType; }
446  TreeTypes& TreeType() { return treeType; }
447 
449  bool RandomBasis() const { return randomBasis; }
450  bool& RandomBasis() { return randomBasis; }
451 
453  void InitializeModel(const NeighborSearchMode searchMode,
454  const double epsilon);
455 
457  void BuildModel(arma::mat&& referenceSet,
458  const NeighborSearchMode searchMode,
459  const double epsilon = 0);
460 
462  void Search(arma::mat&& querySet,
463  const size_t k,
464  arma::Mat<size_t>& neighbors,
465  arma::mat& distances);
466 
468  void Search(const size_t k,
469  arma::Mat<size_t>& neighbors,
470  arma::mat& distances);
471 
473  std::string TreeName() const;
474 };
475 
476 } // namespace neighbor
477 } // namespace mlpack
478 
479 // Include implementation.
480 #include "ns_model_impl.hpp"
481 
482 #endif
NSWrapper(const NeighborSearchMode searchMode, const double epsilon)
Construct the NSWrapper object, initializing the internally-held NeighborSearch object.
Definition: ns_model.hpp:104
virtual ~SpillNSWrapper()
Destruct the SpillNSWrapper.
Definition: ns_model.hpp:283
void serialize(Archive &ar, const uint32_t)
Serialize the NeighborSearch model.
Definition: ns_model.hpp:236
double Rho() const
Expose Rho.
Definition: ns_model.hpp:437
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
bool RandomBasis() const
Expose randomBasis.
Definition: ns_model.hpp:449
NSWrapper is a wrapper class for most NeighborSearch types.
Definition: ns_model.hpp:99
const arma::mat & Dataset() const
Get a reference to the reference set.
Definition: ns_model.hpp:119
virtual void Search(arma::mat &&querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances, const size_t leafSize, const double rho)=0
Perform bichromatic neighbor search (i.e., search a separate query set against the trained reference set).
TreeTypes
Enum type to identify each accepted tree type.
Definition: ns_model.hpp:337
Extra data for each node in the tree.
Definition: neighbor_search_stat.hpp:26
SpillNSWrapper(const NeighborSearchMode searchMode, const double epsilon)
Construct the SpillNSWrapper.
Definition: ns_model.hpp:266
LeafSizeNSWrapper wraps any NeighborSearch types that take a leaf size for tree construction.
Definition: neighbor_search.hpp:40
The NeighborSearch class is a template class for performing distance-based neighbor searches...
Definition: neighbor_search.hpp:88
LeafSizeNSWrapper(const NeighborSearchMode searchMode, const double epsilon)
Construct the LeafSizeNSWrapper by delegating to the NSWrapper constructor.
Definition: ns_model.hpp:199
virtual const arma::mat & Dataset() const =0
Return a reference to the dataset.
NeighborSearchMode & SearchMode()
Modify the search mode.
Definition: ns_model.hpp:124
virtual NSWrapper * Clone() const
Create a copy of this NSWrapper object.
Definition: ns_model.hpp:116
double Tau() const
Expose Tau.
Definition: ns_model.hpp:433
TreeTypes TreeType() const
Expose treeType.
Definition: ns_model.hpp:445
NSWrapperBase is a base wrapper class for holding all NeighborSearch types supported by NSModel...
Definition: ns_model.hpp:35
NSType ns
The instantiated NeighborSearch object that we are wrapping.
Definition: ns_model.hpp:170
NeighborSearchMode SearchMode() const
Get the search mode.
Definition: ns_model.hpp:122
virtual void Train(arma::mat &&referenceSet, const size_t leafSize, const double tau, const double rho)=0
Train the NeighborSearch model with the given parameters.
The NSModel class provides an easy way to serialize a model, abstracts away the different types of trees, and also reflects the NeighborSearch API.
Definition: ns_model.hpp:333
virtual SpillNSWrapper * Clone() const
Return a copy of the SpillNSWrapper.
Definition: ns_model.hpp:286
double & Epsilon()
Modify epsilon, the approximation parameter.
Definition: ns_model.hpp:129
virtual ~NSWrapper()
Delete the NSWrapper object.
Definition: ns_model.hpp:112
SpillTree< MetricType, StatisticType, MatType, AxisOrthogonalHyperplane, MidpointSpaceSplit > SPTree
The hybrid spill tree.
Definition: typedef.hpp:62
virtual LeafSizeNSWrapper * Clone() const
Return a copy of the LeafSizeNSWrapper.
Definition: ns_model.hpp:213
The L_p metric for arbitrary integer p, with an option to take the root.
Definition: lmetric.hpp:63
NeighborSearchMode
NeighborSearchMode represents the different neighbor search modes available.
Definition: neighbor_search.hpp:43
virtual double Epsilon() const =0
Get the approximation parameter epsilon.
virtual ~LeafSizeNSWrapper()
Delete the LeafSizeNSWrapper.
Definition: ns_model.hpp:210
NSWrapperBase()
Create the NSWrapperBase object.
Definition: ns_model.hpp:40
size_t LeafSize() const
Expose LeafSize.
Definition: ns_model.hpp:429
double Epsilon() const
Get epsilon, the approximation parameter.
Definition: ns_model.hpp:127
void serialize(Archive &ar, const uint32_t)
Serialize the NeighborSearch model.
Definition: ns_model.hpp:155
The SpillNSWrapper class wraps the NeighborSearch class when the spill tree is used.
Definition: ns_model.hpp:253
virtual NeighborSearchMode SearchMode() const =0
Get the search mode.
void serialize(Archive &ar, const uint32_t)
Serialize the NeighborSearch model.
Definition: ns_model.hpp:305
virtual NSWrapperBase * Clone() const =0
Create a new NSWrapperBase that is the same as this one.
virtual ~NSWrapperBase()
Destruct the NSWrapperBase (nothing to do).
Definition: ns_model.hpp:47