16 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP 17 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP 50 virtual const arma::mat&
Dataset()
const = 0;
58 virtual double Epsilon()
const = 0;
63 virtual void Train(arma::mat&& referenceSet,
64 const size_t leafSize,
66 const double rho) = 0;
70 virtual void Search(arma::mat&& querySet,
72 arma::Mat<size_t>& neighbors,
74 const size_t leafSize,
75 const double rho) = 0;
79 virtual void Search(
const size_t k,
80 arma::Mat<size_t>& neighbors,
81 arma::mat& distances) = 0;
87 template<
typename SortPolicy,
88 template<
typename TreeMetricType,
89 typename TreeStatType,
90 typename TreeMatType>
class TreeType,
91 template<typename RuleType> class DualTreeTraversalType =
94 arma::mat>::template DualTreeTraverser,
95 template<
typename RuleType>
class SingleTreeTraversalType =
96 TreeType<metric::EuclideanDistance,
97 NeighborSearchStat<SortPolicy>,
98 arma::mat>::template SingleTreeTraverser>
105 const double epsilon) :
106 ns(searchMode, epsilon)
119 const arma::mat&
Dataset()
const {
return ns.ReferenceSet(); }
127 double Epsilon()
const {
return ns.Epsilon(); }
133 virtual void Train(arma::mat&& referenceSet,
140 virtual void Search(arma::mat&& querySet,
142 arma::Mat<size_t>& neighbors,
143 arma::mat& distances,
149 virtual void Search(
const size_t k,
150 arma::Mat<size_t>& neighbors,
151 arma::mat& distances);
154 template<
typename Archive>
163 metric::EuclideanDistance,
166 DualTreeTraversalType,
167 SingleTreeTraversalType>
NSType;
178 template<
typename SortPolicy,
179 template<
typename TreeMetricType,
180 typename TreeStatType,
181 typename TreeMatType>
class TreeType,
182 template<typename RuleType> class DualTreeTraversalType =
183 TreeType<metric::EuclideanDistance,
184 NeighborSearchStat<SortPolicy>,
185 arma::mat>::template DualTreeTraverser,
186 template<
typename RuleType>
class SingleTreeTraversalType =
187 TreeType<metric::EuclideanDistance,
188 NeighborSearchStat<SortPolicy>,
189 arma::mat>::template SingleTreeTraverser>
193 DualTreeTraversalType,
194 SingleTreeTraversalType>
200 const double epsilon) :
203 DualTreeTraversalType,
204 SingleTreeTraversalType>(searchMode, epsilon)
220 virtual void Train(arma::mat&& referenceSet,
221 const size_t leafSize,
227 virtual void Search(arma::mat&& querySet,
229 arma::Mat<size_t>& neighbors,
230 arma::mat& distances,
231 const size_t leafSize,
235 template<
typename Archive>
244 DualTreeTraversalType,
245 SingleTreeTraversalType>::ns;
252 template<
typename SortPolicy>
257 tree::SPTree<metric::EuclideanDistance,
258 NeighborSearchStat<SortPolicy>,
259 arma::mat>::template DefeatistDualTreeTraverser,
260 tree::SPTree<metric::EuclideanDistance,
261 NeighborSearchStat<SortPolicy>,
262 arma::mat>::template DefeatistSingleTreeTraverser>
267 const double epsilon) :
271 tree::
SPTree<metric::EuclideanDistance,
273 arma::mat>::template DefeatistDualTreeTraverser,
274 tree::
SPTree<metric::EuclideanDistance,
276 arma::mat>::template DefeatistSingleTreeTraverser>(
289 virtual void Train(arma::mat&& referenceSet,
290 const size_t leafSize,
296 virtual void Search(arma::mat&& querySet,
298 arma::Mat<size_t>& neighbors,
299 arma::mat& distances,
300 const size_t leafSize,
304 template<
typename Archive>
314 tree::SPTree<metric::EuclideanDistance,
315 NeighborSearchStat<SortPolicy>,
316 arma::mat>::template DefeatistDualTreeTraverser,
317 tree::SPTree<metric::EuclideanDistance,
318 NeighborSearchStat<SortPolicy>,
319 arma::mat>::template DefeatistSingleTreeTraverser>::ns;
332 template<
typename SortPolicy>
418 template<
typename Archive>
419 void serialize(Archive& ar,
const uint32_t );
422 const arma::mat&
Dataset()
const;
430 size_t& LeafSize() {
return leafSize; }
433 double Tau()
const {
return tau; }
434 double& Tau() {
return tau; }
437 double Rho()
const {
return rho; }
438 double& Rho() {
return rho; }
446 TreeTypes& TreeType() {
return treeType; }
450 bool& RandomBasis() {
return randomBasis; }
454 const double epsilon);
457 void BuildModel(arma::mat&& referenceSet,
459 const double epsilon = 0);
462 void Search(arma::mat&& querySet,
464 arma::Mat<size_t>& neighbors,
465 arma::mat& distances);
468 void Search(
const size_t k,
469 arma::Mat<size_t>& neighbors,
470 arma::mat& distances);
473 std::string TreeName()
const;
NSWrapper(const NeighborSearchMode searchMode, const double epsilon)
Construct the NSWrapper object, initializing the internally-held NeighborSearch object.
Definition: ns_model.hpp:104
virtual ~SpillNSWrapper()
Destruct the SpillNSWrapper.
Definition: ns_model.hpp:283
void serialize(Archive &ar, const uint32_t)
Serialize the NeighborSearch model.
Definition: ns_model.hpp:236
double Rho() const
Expose Rho.
Definition: ns_model.hpp:437
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
bool RandomBasis() const
Expose randomBasis.
Definition: ns_model.hpp:449
NSWrapper is a wrapper class for most NeighborSearch types.
Definition: ns_model.hpp:99
const arma::mat & Dataset() const
Get a reference to the reference set.
Definition: ns_model.hpp:119
virtual void Search(arma::mat &&querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances, const size_t leafSize, const double rho)=0
Perform bichromatic neighbor search (i.e.
TreeTypes
Enum type to identify each accepted tree type.
Definition: ns_model.hpp:337
Extra data for each node in the tree.
Definition: neighbor_search_stat.hpp:26
SpillNSWrapper(const NeighborSearchMode searchMode, const double epsilon)
Construct the SpillNSWrapper.
Definition: ns_model.hpp:266
LeafSizeNSWrapper wraps any NeighborSearch types that take a leaf size for tree construction.
Definition: neighbor_search.hpp:40
The NeighborSearch class is a template class for performing distance-based neighbor searches...
Definition: neighbor_search.hpp:88
LeafSizeNSWrapper(const NeighborSearchMode searchMode, const double epsilon)
Construct the LeafSizeNSWrapper by delegating to the NSWrapper constructor.
Definition: ns_model.hpp:199
virtual const arma::mat & Dataset() const =0
Return a reference to the dataset.
NeighborSearchMode & SearchMode()
Modify the search mode.
Definition: ns_model.hpp:124
virtual NSWrapper * Clone() const
Create a copy of this NSWrapper object.
Definition: ns_model.hpp:116
double Tau() const
Expose Tau.
Definition: ns_model.hpp:433
TreeTypes TreeType() const
Expose treeType.
Definition: ns_model.hpp:445
NSWrapperBase is a base wrapper class for holding all NeighborSearch types supported by NSModel...
Definition: ns_model.hpp:35
NSType ns
The instantiated NeighborSearch object that we are wrapping.
Definition: ns_model.hpp:170
NeighborSearchMode SearchMode() const
Get the search mode.
Definition: ns_model.hpp:122
virtual void Train(arma::mat &&referenceSet, const size_t leafSize, const double tau, const double rho)=0
Train the NeighborSearch model with the given parameters.
The NSModel class provides an easy way to serialize a model, abstracts away the different types of tr...
Definition: ns_model.hpp:333
virtual SpillNSWrapper * Clone() const
Return a copy of the SpillNSWrapper.
Definition: ns_model.hpp:286
double & Epsilon()
Modify epsilon, the approximation parameter.
Definition: ns_model.hpp:129
virtual ~NSWrapper()
Delete the NSWrapper object.
Definition: ns_model.hpp:112
SpillTree< MetricType, StatisticType, MatType, AxisOrthogonalHyperplane, MidpointSpaceSplit > SPTree
The hybrid spill tree.
Definition: typedef.hpp:62
virtual LeafSizeNSWrapper * Clone() const
Return a copy of the LeafSizeNSWrapper.
Definition: ns_model.hpp:213
The L_p metric for arbitrary integer p, with an option to take the root.
Definition: lmetric.hpp:63
NeighborSearchMode
NeighborSearchMode represents the different neighbor search modes available.
Definition: neighbor_search.hpp:43
virtual double Epsilon() const =0
Get the approximation parameter epsilon.
virtual ~LeafSizeNSWrapper()
Delete the LeafSizeNSWrapper.
Definition: ns_model.hpp:210
NSWrapperBase()
Create the NSWrapperBase object.
Definition: ns_model.hpp:40
size_t LeafSize() const
Expose LeafSize.
Definition: ns_model.hpp:429
double Epsilon() const
Get epsilon, the approximation parameter.
Definition: ns_model.hpp:127
void serialize(Archive &ar, const uint32_t)
Serialize the NeighborSearch model.
Definition: ns_model.hpp:155
The SpillNSWrapper class wraps the NeighborSearch class when the spill tree is used.
Definition: ns_model.hpp:253
virtual NeighborSearchMode SearchMode() const =0
Get the search mode.
void serialize(Archive &ar, const uint32_t)
Serialize the NeighborSearch model.
Definition: ns_model.hpp:305
virtual NSWrapperBase * Clone() const =0
Create a new NSWrapperBase that is the same as this one.
virtual ~NSWrapperBase()
Destruct the NSWrapperBase (nothing to do).
Definition: ns_model.hpp:47