mlpack
mlpack::neighbor::NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType > Class Template Reference

The NeighborSearch class is a template class for performing distance-based neighbor searches.

#include <neighbor_search.hpp>

Public Types

typedef TreeType< MetricType, NeighborSearchStat< SortPolicy >, MatType > Tree
 Convenience typedef.
 

Public Member Functions

 NeighborSearch (MatType referenceSet, const NeighborSearchMode mode=DUAL_TREE_MODE, const double epsilon=0, const MetricType metric=MetricType())
 Initialize the NeighborSearch object, passing a reference dataset (this is the dataset which is searched).
 
 NeighborSearch (Tree referenceTree, const NeighborSearchMode mode=DUAL_TREE_MODE, const double epsilon=0, const MetricType metric=MetricType())
 Initialize the NeighborSearch object with a copy of the given pre-constructed reference tree (this is the tree built on the points that will be searched).
 
 NeighborSearch (const NeighborSearchMode mode=DUAL_TREE_MODE, const double epsilon=0, const MetricType metric=MetricType())
 Create a NeighborSearch object without any reference data.
 
 NeighborSearch (const NeighborSearch &other)
 Construct the NeighborSearch object by copying the given NeighborSearch object.
 
 NeighborSearch (NeighborSearch &&other)
 Construct the NeighborSearch object by taking ownership of the given NeighborSearch object.
 
NeighborSearch & operator= (const NeighborSearch &other)
 Copy the given NeighborSearch object.
 
NeighborSearch & operator= (NeighborSearch &&other)
 Take ownership of the given NeighborSearch object.
 
 ~NeighborSearch ()
 Delete the NeighborSearch object.
 
void Train (MatType referenceSet)
 Set the reference set to a new reference set, and build a tree if necessary.
 
void Train (Tree referenceTree)
 Set the reference tree to a new reference tree.
 
void Search (const MatType &querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
 For each point in the query set, compute the nearest neighbors and store the output in the given matrices.
 
void Search (Tree &queryTree, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances, bool sameSet=false)
 Given a pre-built query tree, search for the nearest neighbors of each point in the query tree, storing the output in the given matrices.
 
void Search (const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
 Search for the nearest neighbors of every point in the reference set.
 
size_t BaseCases () const
 Return the total number of base case evaluations performed during the last search.
 
size_t Scores () const
 Return the number of node-combination scores computed during the last search.
 
NeighborSearchMode SearchMode () const
 Access the search mode.
 
NeighborSearchMode & SearchMode ()
 Modify the search mode.
 
double Epsilon () const
 Access the relative error to be considered in approximate search.
 
double & Epsilon ()
 Modify the relative error to be considered in approximate search.
 
const MatType & ReferenceSet () const
 Access the reference dataset.
 
const Tree & ReferenceTree () const
 Access the reference tree.
 
Tree & ReferenceTree ()
 Modify the reference tree.
 
template<typename Archive >
void serialize (Archive &ar, const uint32_t version)
 Serialize the NeighborSearch model.
 

Static Public Member Functions

static double EffectiveError (arma::mat &foundDistances, arma::mat &realDistances)
 Calculate the average relative error (effective error) between the distances calculated and the true distances provided.
 
static double Recall (arma::Mat< size_t > &foundNeighbors, arma::Mat< size_t > &realNeighbors)
 Calculate the recall (fraction of true neighbors found) given the list of found neighbors and the true set of neighbors.
 

Friends

class LeafSizeNSWrapper< SortPolicy, TreeType, DualTreeTraversalType, SingleTreeTraversalType >
 The NSModel class (through LeafSizeNSWrapper) should have access to internal members.
 

Detailed Description

template<typename SortPolicy = NearestNeighborSort, typename MetricType = mlpack::metric::EuclideanDistance, typename MatType = arma::mat, template< typename TreeMetricType, typename TreeStatType, typename TreeMatType > class TreeType = tree::KDTree, template< typename RuleType > class DualTreeTraversalType = TreeType<MetricType, NeighborSearchStat<SortPolicy>, MatType>::template DualTreeTraverser, template< typename RuleType > class SingleTreeTraversalType = TreeType<MetricType, NeighborSearchStat<SortPolicy>, MatType>::template SingleTreeTraverser>
class mlpack::neighbor::NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType >

The NeighborSearch class is a template class for performing distance-based neighbor searches.

It takes a query dataset and a reference dataset (or just a reference dataset) and, for each point in the query dataset, finds the k neighbors in the reference dataset that have the 'best' distance according to a given sorting policy. If Search() is called with only k (no query set or query tree), the reference dataset is also used as the query dataset.

The template parameters SortPolicy and MetricType define the sort function and the metric (distance function) used. More information on those classes can be found in the NearestNeighborSort class and the metric::LMetric class.
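To make the role of SortPolicy concrete, the following is a minimal, standard-library-only sketch (not mlpack code) of a sort policy deciding which k candidate distances are "best". The structs NearestSortSketch and FurthestSortSketch and the function KBestIndices are hypothetical; they only mirror the idea behind NearestNeighborSort and FurthestNeighborSort (a comparison saying which distance is better) and do not reproduce their actual interfaces.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical sort policy for nearest-neighbor search: smaller is better.
struct NearestSortSketch
{
  static bool IsBetter(double value, double ref) { return value < ref; }
};

// Hypothetical sort policy for furthest-neighbor search: larger is better.
struct FurthestSortSketch
{
  static bool IsBetter(double value, double ref) { return value > ref; }
};

// Return the indices of the k best candidates, best first, as judged by
// SortPolicy::IsBetter. Requires k <= distances.size().
template<typename SortPolicy>
std::vector<std::size_t> KBestIndices(const std::vector<double>& distances,
                                      std::size_t k)
{
  std::vector<std::size_t> order(distances.size());
  std::iota(order.begin(), order.end(), 0);  // 0, 1, ..., n - 1
  std::partial_sort(order.begin(), order.begin() + k, order.end(),
      [&](std::size_t a, std::size_t b)
      { return SortPolicy::IsBetter(distances[a], distances[b]); });
  order.resize(k);  // keep only the k best
  return order;
}
```

Swapping the policy type turns a k-nearest search into a k-furthest search without touching the selection code, which is the kind of flexibility a SortPolicy template parameter provides.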

Template Parameters
SortPolicy: The sort policy for distances; see NearestNeighborSort.
MetricType: The metric to use for computation.
MatType: The type of data matrix.
TreeType: The tree type to use; must adhere to the TreeType API.
DualTreeTraversalType: The type of dual tree traversal to use (defaults to the tree's default traverser).
SingleTreeTraversalType: The type of single tree traversal to use (defaults to the tree's default traverser).

Constructor & Destructor Documentation

◆ NeighborSearch() [1/5]

template<typename SortPolicy , typename MetricType, typename MatType, template< typename TreeMetricType, typename TreeStatType, typename TreeMatType > class TreeType, template< typename > class DualTreeTraversalType, template< typename > class SingleTreeTraversalType>
mlpack::neighbor::NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType >::NeighborSearch ( MatType  referenceSet,
const NeighborSearchMode  mode = DUAL_TREE_MODE,
const double  epsilon = 0,
const MetricType  metric = MetricType() 
)

Initialize the NeighborSearch object, passing a reference dataset (this is the dataset which is searched).

Optionally, perform the computation in a different mode. An initialized distance metric can be given for cases where the metric holds internal data (e.g. the metric::MahalanobisDistance class).

The reference set is taken by value and stored internally, where it may be rearranged during tree building. To avoid an extra copy of the data, pass std::move(yourReferenceSet), or pre-build the reference tree and pass it with std::move() to the constructor that takes a tree.
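The take-by-value-then-move pattern can be illustrated without mlpack. In this sketch the hypothetical class DatasetOwner stands in for NeighborSearch and std::vector<double> stands in for MatType: passing an lvalue copies the data, while passing std::move(...) hands over the existing buffer, so no element is copied.

```cpp
#include <utility>
#include <vector>

// Hypothetical stand-in for a class that takes its dataset by value.
class DatasetOwner
{
 public:
  // The caller's argument is copy-constructed (lvalue) or move-constructed
  // (std::move) into `dataset`; it is then moved, not copied, into the member.
  explicit DatasetOwner(std::vector<double> dataset) :
      data(std::move(dataset)) { }

  const std::vector<double>& Data() const { return data; }

 private:
  std::vector<double> data;
};
```

The same reasoning applies to the reference set here: after passing std::move(yourReferenceSet), the moved-from matrix should not be relied on.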

Parameters
referenceSet: Set of reference points.
mode: Neighbor search mode.
epsilon: Relative approximate error (non-negative).
metric: An optional instance of the MetricType class.

◆ NeighborSearch() [2/5]

template<typename SortPolicy , typename MetricType, typename MatType, template< typename TreeMetricType, typename TreeStatType, typename TreeMatType > class TreeType, template< typename > class DualTreeTraversalType, template< typename > class SingleTreeTraversalType>
mlpack::neighbor::NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType >::NeighborSearch ( Tree  referenceTree,
const NeighborSearchMode  mode = DUAL_TREE_MODE,
const double  epsilon = 0,
const MetricType  metric = MetricType() 
)

Initialize the NeighborSearch object with a copy of the given pre-constructed reference tree (this is the tree built on the points that will be searched).

Optionally, choose to use single-tree mode. Naive mode is not available as an option for this constructor. Additionally, an instantiated distance metric can be given, for cases where the distance metric holds data.

This method will copy the given tree. To avoid the copy and take ownership of the given tree instead, pass std::move(yourReferenceTree).

Note
Mapping the points of the matrix back to their original indices is not done when this constructor is used, so if your tree type rearranges points during construction (like BinarySpaceTree), you will have to re-map the results to the original indices manually.
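A sketch of the manual re-mapping, using only the standard library. It assumes the tree construction reported a mapping oldFromNew, where oldFromNew[i] is the original index of the point stored at position i of the tree's dataset (BinarySpaceTree constructors can optionally fill such a vector), and that the k x n neighbor matrix is flattened column-major, as arma::Mat stores it. RemapNeighborIndices is a hypothetical name. Only reference indices are handled; if the query points were also rearranged, the columns need an analogous permutation.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: translate neighbor indices from the tree's
// rearranged order back to the original dataset order.
// `neighbors` holds a k x n matrix in column-major order (one column per
// query point); oldFromNew[i] is the original index of tree point i.
void RemapNeighborIndices(std::vector<std::size_t>& neighbors,
                          const std::vector<std::size_t>& oldFromNew)
{
  for (std::size_t& index : neighbors)
    index = oldFromNew[index];
}
```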
Parameters
referenceTree: Pre-built tree for reference points.
mode: Neighbor search mode.
epsilon: Relative approximate error (non-negative).
metric: Instantiated distance metric.

◆ NeighborSearch() [3/5]

template<typename SortPolicy , typename MetricType, typename MatType, template< typename TreeMetricType, typename TreeStatType, typename TreeMatType > class TreeType, template< typename > class DualTreeTraversalType, template< typename > class SingleTreeTraversalType>
mlpack::neighbor::NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType >::NeighborSearch ( const NeighborSearchMode  mode = DUAL_TREE_MODE,
const double  epsilon = 0,
const MetricType  metric = MetricType() 
)

Create a NeighborSearch object without any reference data.

If Search() is called before a reference set is set with Train(), an exception will be thrown.

Parameters
mode: Neighbor search mode.
epsilon: Relative approximate error (non-negative).
metric: Instantiated metric.

◆ NeighborSearch() [4/5]

template<typename SortPolicy , typename MetricType, typename MatType, template< typename TreeMetricType, typename TreeStatType, typename TreeMatType > class TreeType, template< typename > class DualTreeTraversalType, template< typename > class SingleTreeTraversalType>
mlpack::neighbor::NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType >::NeighborSearch ( const NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType > &  other)

Construct the NeighborSearch object by copying the given NeighborSearch object.

Parameters
other: NeighborSearch object to copy.

◆ NeighborSearch() [5/5]

template<typename SortPolicy , typename MetricType, typename MatType, template< typename TreeMetricType, typename TreeStatType, typename TreeMatType > class TreeType, template< typename > class DualTreeTraversalType, template< typename > class SingleTreeTraversalType>
mlpack::neighbor::NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType >::NeighborSearch ( NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType > &&  other)

Construct the NeighborSearch object by taking ownership of the given NeighborSearch object.

Parameters
other: NeighborSearch object to take ownership of.

◆ ~NeighborSearch()

template<typename SortPolicy , typename MetricType , typename MatType , template< typename TreeMetricType, typename TreeStatType, typename TreeMatType > class TreeType, template< typename > class DualTreeTraversalType, template< typename > class SingleTreeTraversalType>
mlpack::neighbor::NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType >::~NeighborSearch ( )

Delete the NeighborSearch object.

The tree is the only member we are responsible for deleting. The others will take care of themselves.

Member Function Documentation

◆ BaseCases()

template<typename SortPolicy = NearestNeighborSort, typename MetricType = mlpack::metric::EuclideanDistance, typename MatType = arma::mat, template< typename TreeMetricType, typename TreeStatType, typename TreeMatType > class TreeType = tree::KDTree, template< typename RuleType > class DualTreeTraversalType = TreeType<MetricType, NeighborSearchStat<SortPolicy>, MatType>::template DualTreeTraverser, template< typename RuleType > class SingleTreeTraversalType = TreeType<MetricType, NeighborSearchStat<SortPolicy>, MatType>::template SingleTreeTraverser>
size_t mlpack::neighbor::NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType >::BaseCases ( ) const
inline

Return the total number of base case evaluations performed during the last search.

◆ EffectiveError()

template<typename SortPolicy , typename MetricType , typename MatType , template< typename TreeMetricType, typename TreeStatType, typename TreeMatType > class TreeType, template< typename > class DualTreeTraversalType, template< typename > class SingleTreeTraversalType>
double mlpack::neighbor::NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType >::EffectiveError ( arma::mat &  foundDistances,
arma::mat &  realDistances 
)
static

Calculate the average relative error (effective error) between the distances calculated and the true distances provided.

The input matrices must have the same size.

Cases where the true distance is zero (the same point) or the calculated distance is SortPolicy::WorstDistance() (not enough points were found) are ignored.

Parameters
foundDistances: Matrix storing lists of calculated distances for each query point.
realDistances: Matrix storing lists of true best distances for each query point.
Returns
Average relative error.
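The averaging rule above can be sketched over flattened distance matrices. This is an illustration rather than mlpack's source: EffectiveErrorSketch is a hypothetical name, the per-entry relative error is assumed to be |found - real| / real, the worstDistance parameter stands in for SortPolicy::WorstDistance() (assumed here to be the largest double, as for nearest-neighbor search), and returning 0 when every entry is skipped is a choice made for the sketch.

```cpp
#include <cmath>
#include <cstddef>
#include <limits>
#include <stdexcept>
#include <vector>

// Sketch: mean of |found - real| / real over all entries, skipping entries
// whose true distance is zero or whose found distance equals worstDistance.
// Both matrices are flattened into vectors of equal length.
double EffectiveErrorSketch(
    const std::vector<double>& foundDistances,
    const std::vector<double>& realDistances,
    double worstDistance = std::numeric_limits<double>::max())
{
  if (foundDistances.size() != realDistances.size())
    throw std::invalid_argument("matrices must have the same size");

  double totalError = 0.0;
  std::size_t numCases = 0;
  for (std::size_t i = 0; i < foundDistances.size(); ++i)
  {
    // Skip the cases the documentation says are ignored.
    if (realDistances[i] == 0.0 || foundDistances[i] == worstDistance)
      continue;
    totalError += std::abs(foundDistances[i] - realDistances[i]) /
        realDistances[i];
    ++numCases;
  }
  // Assumption for this sketch: report 0 if nothing could be compared.
  return (numCases > 0) ? totalError / numCases : 0.0;
}
```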

◆ operator=() [1/2]

template<typename SortPolicy , typename MetricType , typename MatType , template< typename TreeMetricType, typename TreeStatType, typename TreeMatType > class TreeType, template< typename > class DualTreeTraversalType, template< typename > class SingleTreeTraversalType>
NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType > & mlpack::neighbor::NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType >::operator= ( const NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType > &  other)

Copy the given NeighborSearch object.

Parameters
other: NeighborSearch object to copy.

◆ operator=() [2/2]

template<typename SortPolicy , typename MetricType , typename MatType , template< typename TreeMetricType, typename TreeStatType, typename TreeMatType > class TreeType, template< typename > class DualTreeTraversalType, template< typename > class SingleTreeTraversalType>
NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType > & mlpack::neighbor::NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType >::operator= ( NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType > &&  other)

Take ownership of the given NeighborSearch object.

Parameters
other: NeighborSearch object to take ownership of.

◆ Recall()

template<typename SortPolicy , typename MetricType , typename MatType , template< typename TreeMetricType, typename TreeStatType, typename TreeMatType > class TreeType, template< typename > class DualTreeTraversalType, template< typename > class SingleTreeTraversalType>
double mlpack::neighbor::NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType >::Recall ( arma::Mat< size_t > &  foundNeighbors,
arma::Mat< size_t > &  realNeighbors 
)
static

Calculate the recall (fraction of true neighbors found) given the list of found neighbors and the true set of neighbors.

The recall returned will be in the range [0, 1].

Parameters
foundNeighbors: Matrix storing lists of calculated neighbors for each query point.
realNeighbors: Matrix storing lists of true best neighbors for each query point.
Returns
Recall.
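A standard-library sketch of the recall computation, assuming recall is the number of found neighbors that also appear in the same query point's true neighbor list, divided by the total number of true neighbors. RecallSketch is a hypothetical name; each inner vector represents one column (one query point) of the neighbor matrix, and both inputs are assumed to have the same shape.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch: fraction of found neighbors that also appear in the true neighbor
// list of the same query point. Each inner vector is one query point's
// neighbor list (one column of the k x n matrix); shapes must match.
double RecallSketch(
    const std::vector<std::vector<std::size_t>>& foundNeighbors,
    const std::vector<std::vector<std::size_t>>& realNeighbors)
{
  std::size_t hits = 0;
  std::size_t total = 0;
  for (std::size_t q = 0; q < realNeighbors.size(); ++q)
  {
    const std::vector<std::size_t>& truth = realNeighbors[q];
    for (std::size_t neighbor : foundNeighbors[q])
      if (std::find(truth.begin(), truth.end(), neighbor) != truth.end())
        ++hits;  // this found neighbor is one of the true neighbors
    total += truth.size();
  }
  return (total > 0) ? static_cast<double>(hits) / total : 0.0;
}
```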

◆ Search() [1/3]

template<typename SortPolicy , typename MetricType , typename MatType, template< typename TreeMetricType, typename TreeStatType, typename TreeMatType > class TreeType, template< typename > class DualTreeTraversalType, template< typename > class SingleTreeTraversalType>
void mlpack::neighbor::NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType >::Search ( const MatType &  querySet,
const size_t  k,
arma::Mat< size_t > &  neighbors,
arma::mat &  distances 
)

For each point in the query set, compute the nearest neighbors and store the output in the given matrices.

Computes the best neighbors and stores them in neighbors and distances.

Both matrices will be resized to k rows by n columns, where n is the number of points in the query dataset and k is the number of neighbors being searched for.

If querySet contains only a few query points, the extra cost of building a tree on the points for dual-tree search may not be warranted, and it may be worthwhile to use single-tree search instead (pass SINGLE_TREE_MODE to the constructor or set it with SearchMode()).

Parameters
querySet: Set of query points (can be just one point).
k: Number of neighbors to search for.
neighbors: Matrix storing lists of neighbors for each query point.
distances: Matrix storing distances of neighbors for each query point.
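For a concrete picture of the output layout, here is a brute-force sketch using only the standard library. It computes the kind of exact result that NAIVE_MODE is meant to produce under the Euclidean distance, but it is not mlpack's implementation: Point, Euclidean, and NaiveSearchSketch are hypothetical names, and the k x n output matrices are flattened column-major into std::vector, the way arma::Mat stores its elements.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <stdexcept>
#include <vector>

using Point = std::vector<double>;

// Euclidean distance between two points of equal dimension.
double Euclidean(const Point& a, const Point& b)
{
  double sum = 0.0;
  for (std::size_t d = 0; d < a.size(); ++d)
    sum += (a[d] - b[d]) * (a[d] - b[d]);
  return std::sqrt(sum);
}

// Brute-force k-nearest-neighbor search. Results are k x n matrices stored
// column-major: entry (j, i), the j-th nearest neighbor of query i, lives
// at index i * k + j.
void NaiveSearchSketch(const std::vector<Point>& querySet,
                       const std::vector<Point>& referenceSet,
                       std::size_t k,
                       std::vector<std::size_t>& neighbors,
                       std::vector<double>& distances)
{
  if (k > referenceSet.size())
    throw std::invalid_argument("k exceeds the number of reference points");

  neighbors.assign(k * querySet.size(), 0);
  distances.assign(k * querySet.size(), 0.0);

  std::vector<double> dist(referenceSet.size());
  std::vector<std::size_t> order(referenceSet.size());
  for (std::size_t i = 0; i < querySet.size(); ++i)
  {
    for (std::size_t r = 0; r < referenceSet.size(); ++r)
      dist[r] = Euclidean(querySet[i], referenceSet[r]);

    // Keep the k closest reference points, closest first.
    std::iota(order.begin(), order.end(), 0);
    std::partial_sort(order.begin(), order.begin() + k, order.end(),
        [&](std::size_t a, std::size_t b) { return dist[a] < dist[b]; });

    for (std::size_t j = 0; j < k; ++j)
    {
      neighbors[i * k + j] = order[j];
      distances[i * k + j] = dist[order[j]];
    }
  }
}
```

Reading down one column (indices i * k through i * k + k - 1) gives query point i's neighbors from best to worst.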

◆ Search() [2/3]

template<typename SortPolicy , typename MetricType , typename MatType, template< typename TreeMetricType, typename TreeStatType, typename TreeMatType > class TreeType, template< typename > class DualTreeTraversalType, template< typename > class SingleTreeTraversalType>
void mlpack::neighbor::NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType >::Search ( Tree &  queryTree,
const size_t  k,
arma::Mat< size_t > &  neighbors,
arma::mat &  distances,
bool  sameSet = false 
)

Given a pre-built query tree, search for the nearest neighbors of each point in the query tree, storing the output in the given matrices.

Both matrices will be resized to k rows by n columns, where n is the number of points in the query dataset and k is the number of neighbors being searched for.

Note that if you call Search() multiple times with the same query tree, you need to reset the bounds in the statistic of each query node; otherwise the results may be wrong! You can do this by calling Stat().Reset() on every node in the query tree.

Parameters
queryTree: Tree built on query points.
k: Number of neighbors to search for.
neighbors: Matrix storing lists of neighbors for each query point.
distances: Matrix storing distances of neighbors for each query point.
sameSet: Denotes whether or not the reference and query sets are the same.
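Resetting the query tree between searches is a short recursive walk. The sketch below assumes only that a node exposes Stat(), NumChildren(), and Child(i), which mlpack's TreeType API documents, and that the statistic has a Reset() method as the note above describes. MockNode and BoundStat are hypothetical stand-ins so that the example compiles without mlpack; with a real query tree you would call ResetStatistics on the root node before each repeated Search().

```cpp
#include <cstddef>
#include <vector>

// Hypothetical statistic holding a bound that a search could tighten.
struct BoundStat
{
  double firstBound = 1e300;
  void Reset() { firstBound = 1e300; }
};

// Hypothetical tree node exposing the three calls the reset needs.
struct MockNode
{
  BoundStat stat;
  std::vector<MockNode> children;

  BoundStat& Stat() { return stat; }
  std::size_t NumChildren() const { return children.size(); }
  MockNode& Child(std::size_t i) { return children[i]; }
};

// Reset the statistic of `node` and of every node below it.
template<typename TreeNodeType>
void ResetStatistics(TreeNodeType& node)
{
  node.Stat().Reset();
  for (std::size_t i = 0; i < node.NumChildren(); ++i)
    ResetStatistics(node.Child(i));
}
```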

◆ Search() [3/3]

template<typename SortPolicy , typename MetricType , typename MatType, template< typename TreeMetricType, typename TreeStatType, typename TreeMatType > class TreeType, template< typename > class DualTreeTraversalType, template< typename > class SingleTreeTraversalType>
void mlpack::neighbor::NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType >::Search ( const size_t  k,
arma::Mat< size_t > &  neighbors,
arma::mat &  distances 
)

Search for the nearest neighbors of every point in the reference set.

This is essentially equivalent to calling another overload of Search() with the reference set as the query set, so it performs all-k-nearest-neighbors search. The results are stored in the given matrices, which will be resized to k rows by n columns, where n is the number of points in the reference set (here also the query set) and k is the number of neighbors being searched for.

Parameters
k: Number of neighbors to search for.
neighbors: Matrix storing lists of neighbors for each query point.
distances: Matrix storing distances of neighbors for each query point.

◆ Train() [1/2]

template<typename SortPolicy , typename MetricType , typename MatType, template< typename TreeMetricType, typename TreeStatType, typename TreeMatType > class TreeType, template< typename > class DualTreeTraversalType, template< typename > class SingleTreeTraversalType>
void mlpack::neighbor::NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType >::Train ( MatType  referenceSet)

Set the reference set to a new reference set, and build a tree if necessary.

The dataset is copied by default, but the copy can be avoided by transferring ownership of the dataset with std::move(). This method is called Train() to match the rest of the mlpack abstractions, even though calling this "training" is arguably a bit of a stretch.

Parameters
referenceSet: New set of reference data.

◆ Train() [2/2]

template<typename SortPolicy , typename MetricType , typename MatType, template< typename TreeMetricType, typename TreeStatType, typename TreeMatType > class TreeType, template< typename > class DualTreeTraversalType, template< typename > class SingleTreeTraversalType>
void mlpack::neighbor::NeighborSearch< SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType, SingleTreeTraversalType >::Train ( Tree  referenceTree)

Set the reference tree to a new reference tree.

The tree is copied by default, but the copy can be avoided by transferring ownership of the tree with std::move(). This method is called Train() to match the rest of the mlpack abstractions, even though calling this "training" is arguably a bit of a stretch.

Parameters
referenceTree: Pre-built tree for reference points.

The documentation for this class was generated from the following files: