mlpack
mlpack::meanshift::MeanShift< UseKernel, KernelType, MatType > Class Template Reference

This class implements mean shift clustering.

#include <mean_shift.hpp>

Public Member Functions

 MeanShift (const double radius=0, const size_t maxIterations=1000, const KernelType kernel=KernelType())
 Create a mean shift object and set the parameters which mean shift will be run with.
 
double EstimateRadius (const MatType &data, const double ratio=0.2)
 Estimate the radius from the given dataset.
 
void Cluster (const MatType &data, arma::Row< size_t > &assignments, arma::mat &centroids, bool forceConvergence=true, bool useSeeds=true)
 Perform mean shift clustering on the data, returning a list of cluster assignments and centroids.
 
size_t MaxIterations () const
 Get the maximum number of iterations.
 
size_t & MaxIterations ()
 Set the maximum number of iterations.
 
double Radius () const
 Get the radius.
 
void Radius (double radius)
 Set the radius.
 
const KernelType & Kernel () const
 Get the kernel.
 
KernelType & Kernel ()
 Modify the kernel.
 

Detailed Description

template<bool UseKernel = false, typename KernelType = kernel::GaussianKernel, typename MatType = arma::mat>
class mlpack::meanshift::MeanShift< UseKernel, KernelType, MatType >

This class implements mean shift clustering.

For each point in the dataset, apply the mean shift algorithm until it converges or the maximum number of iterations is reached. Then remove duplicate centroids.
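The per-point procedure can be sketched in standalone C++ (standard library only, no mlpack or Armadillo). This is a minimal sketch assuming a flat kernel, i.e. each step moves the current estimate to the mean of all data points within `radius`, which appears to correspond to the `UseKernel = false` setting. `ShiftToMode`, `Distance`, and the tolerance `tol` are hypothetical names, not mlpack API.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Point = std::vector<double>;

// Euclidean distance between two points of equal dimension.
double Distance(const Point& a, const Point& b)
{
  double sum = 0.0;
  for (std::size_t d = 0; d < a.size(); ++d)
    sum += (a[d] - b[d]) * (a[d] - b[d]);
  return std::sqrt(sum);
}

// Hypothetical helper (flat-kernel assumption): repeatedly replace the
// current estimate by the mean of all data points within `radius`.
// Stops when the shift falls below `tol`, or after `maxIterations` steps
// unless `forceConvergence` is set.
Point ShiftToMode(const std::vector<Point>& data, Point current,
                  double radius, std::size_t maxIterations,
                  bool forceConvergence, double tol = 1e-8)
{
  for (std::size_t it = 0; forceConvergence || it < maxIterations; ++it)
  {
    Point mean(current.size(), 0.0);
    std::size_t count = 0;
    for (const Point& p : data)
    {
      if (Distance(p, current) <= radius)
      {
        for (std::size_t d = 0; d < p.size(); ++d)
          mean[d] += p[d];
        ++count;
      }
    }
    if (count == 0)
      break; // No neighbors within the radius: nothing to shift toward.
    for (double& v : mean)
      v /= static_cast<double>(count);
    const double shift = Distance(mean, current);
    current = mean;
    if (shift < tol)
      break; // Converged.
  }
  return current;
}
```

Starting from `{0.0}` in the data set `{0.0}, {0.2}, {0.4}, {10.0}, {10.2}` with radius 1, the estimate moves to the local mean 0.2 and stays there.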

A simple example of how to run mean shift clustering is shown below.

#include <mlpack/methods/mean_shift/mean_shift.hpp>
using namespace mlpack::meanshift;

extern arma::mat data; // Dataset we want to run mean shift on.

void RunMeanShift()
{
  arma::Row<size_t> assignments; // Cluster assignments.
  arma::mat centroids;           // Cluster centroids.
  // Force each centroid seed to converge regardless of maxIterations.
  bool forceConvergence = true;
  MeanShift<> meanShift; // Not "meanShift();", which would declare a function.
  meanShift.Cluster(data, assignments, centroids, forceConvergence);
}
Template Parameters
UseKernel: Whether to use a kernel or the plain mean to calculate the new centroid. If false, KernelType is ignored.
KernelType: The kernel to use.
MatType: The type of matrix the data is stored in.
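The `UseKernel` switch chooses between two centroid-update rules, which can be contrasted in one dimension. This sketch assumes a Gaussian weighting profile when the kernel is enabled (GaussianKernel is the default KernelType) and plain averaging otherwise. `NewCentroid` and `bandwidth` are hypothetical, and mlpack's exact kernel weighting may differ.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical 1-D helper. `neighbors` holds the data points already found
// within the search radius of `current`.
double NewCentroid(const std::vector<double>& neighbors, double current,
                   bool useKernel, double bandwidth)
{
  double weightedSum = 0.0;
  double weightTotal = 0.0;
  for (double x : neighbors)
  {
    // Plain mean: every neighbor has weight 1.
    // Kernel mean (assumed Gaussian profile): nearer points weigh more.
    const double diff = x - current;
    const double w = useKernel
        ? std::exp(-diff * diff / (2.0 * bandwidth * bandwidth))
        : 1.0;
    weightedSum += w * x;
    weightTotal += w;
  }
  // With no neighbors, leave the centroid where it is.
  return weightTotal > 0.0 ? weightedSum / weightTotal : current;
}
```

For an asymmetric neighborhood such as `{0, 1, 2}` around 0, the kernel-weighted update moves less far than the plain mean, because the distant point at 2 contributes little weight.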

Constructor & Destructor Documentation

◆ MeanShift()

template<bool UseKernel, typename KernelType , typename MatType >
mlpack::meanshift::MeanShift< UseKernel, KernelType, MatType >::MeanShift ( const double  radius = 0,
const size_t  maxIterations = 1000,
const KernelType  kernel = KernelType() 
)

Create a mean shift object and set the parameters which mean shift will be run with.

Parameters
radius: If the distance between two centroids is less than this value, one of them is removed. If this value is not positive, it is estimated when clustering.
maxIterations: Maximum number of iterations allowed before giving up.
kernel: Optional KernelType object.
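The role of `radius` as a duplicate threshold can be illustrated with a greedy filter over the converged centroids. This sketch assumes centroids are examined in order and a later one is dropped if it lies within `radius` of one already kept; `RemoveDuplicates` is a hypothetical helper, not the library's code.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Point = std::vector<double>;

// Euclidean distance between two points of equal dimension.
double Distance(const Point& a, const Point& b)
{
  double sum = 0.0;
  for (std::size_t d = 0; d < a.size(); ++d)
    sum += (a[d] - b[d]) * (a[d] - b[d]);
  return std::sqrt(sum);
}

// Hypothetical helper: keep a converged centroid only if no centroid kept so
// far lies within `radius` of it; otherwise treat it as a duplicate.
std::vector<Point> RemoveDuplicates(const std::vector<Point>& converged,
                                    double radius)
{
  std::vector<Point> kept;
  for (const Point& c : converged)
  {
    bool duplicate = false;
    for (const Point& k : kept)
    {
      if (Distance(c, k) < radius)
      {
        duplicate = true;
        break;
      }
    }
    if (!duplicate)
      kept.push_back(c);
  }
  return kept;
}
```

Because the filter is greedy, which of two nearby centroids survives depends on the order in which they are examined.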

Member Function Documentation

◆ Cluster()

template<bool UseKernel, typename KernelType , typename MatType >
void mlpack::meanshift::MeanShift< UseKernel, KernelType, MatType >::Cluster ( const MatType &  data,
arma::Row< size_t > &  assignments,
arma::mat &  centroids,
bool  forceConvergence = true,
bool  useSeeds = true 
)
inline

Perform mean shift clustering on the data, returning a list of cluster assignments and centroids.
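The relationship between the returned `assignments` and `centroids` can be sketched as a nearest-centroid labeling step. This assumes each data point is labeled with the index of its closest final centroid under Euclidean distance; `Assign` is a hypothetical name, and the brute-force loop stands in for whatever neighbor search the library actually uses.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

using Point = std::vector<double>;

// Euclidean distance between two points of equal dimension.
double Distance(const Point& a, const Point& b)
{
  double sum = 0.0;
  for (std::size_t d = 0; d < a.size(); ++d)
    sum += (a[d] - b[d]) * (a[d] - b[d]);
  return std::sqrt(sum);
}

// Hypothetical helper: label each data point with the index of its nearest
// centroid (brute force over all centroids).
std::vector<std::size_t> Assign(const std::vector<Point>& data,
                                const std::vector<Point>& centroids)
{
  std::vector<std::size_t> assignments(data.size(), 0);
  for (std::size_t i = 0; i < data.size(); ++i)
  {
    double best = std::numeric_limits<double>::infinity();
    for (std::size_t c = 0; c < centroids.size(); ++c)
    {
      const double d = Distance(data[i], centroids[c]);
      if (d < best)
      {
        best = d;
        assignments[i] = c;
      }
    }
  }
  return assignments;
}
```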

Template Parameters
MatType: Type of matrix.
Parameters
data: Dataset to cluster.
assignments: Vector to store cluster assignments in.
centroids: Matrix in which the centroids are stored.
forceConvergence: Flag whether to force each centroid seed to converge regardless of maxIterations.
useSeeds: Set to true to start mean shift from a set of seed points generated from the data instead of from every data point.
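This page does not say how seeds are generated when `useSeeds` is true. One plausible scheme, shown here purely as an assumption, bins points on a grid and starts mean shift from one seed per sufficiently populated cell, so fewer starting points are needed than when every data point is used. `GenerateSeeds`, `binSize`, and `minFreq` are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <map>
#include <vector>

using Point = std::vector<double>;

// Hypothetical seeding scheme: snap every point to a grid cell of side
// `binSize`, keep cells holding at least `minFreq` points, and use each kept
// cell's lower corner as a starting seed.
std::vector<Point> GenerateSeeds(const std::vector<Point>& data,
                                 double binSize, std::size_t minFreq)
{
  // Count points per grid cell; cells are keyed by integer coordinates.
  std::map<std::vector<long long>, std::size_t> counts;
  for (const Point& p : data)
  {
    std::vector<long long> cell(p.size());
    for (std::size_t d = 0; d < p.size(); ++d)
      cell[d] = static_cast<long long>(std::floor(p[d] / binSize));
    ++counts[cell];
  }

  // Convert each sufficiently populated cell back to coordinates.
  std::vector<Point> seeds;
  for (const auto& entry : counts)
  {
    if (entry.second < minFreq)
      continue;
    Point seed(entry.first.size());
    for (std::size_t d = 0; d < seed.size(); ++d)
      seed[d] = static_cast<double>(entry.first[d]) * binSize;
    seeds.push_back(seed);
  }
  return seeds;
}
```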

◆ EstimateRadius()

template<bool UseKernel, typename KernelType , typename MatType >
double mlpack::meanshift::MeanShift< UseKernel, KernelType, MatType >::EstimateRadius ( const MatType &  data,
const double  ratio = 0.2 
)

Estimate the radius from the given dataset.

Parameters
data: Dataset for estimation.
ratio: Fraction of the dataset (for example, 0.2 for 20%) to use as the number of neighbors in the nearest neighbor search.

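For each point in the dataset, find its nNeighbors nearest neighbors, where nNeighbors is ratio times the number of points, and record their distances. The maximum of these distances for each point is used to estimate the duplicate threshold (the radius).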
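The radius estimate can be sketched with a brute-force neighbor search. This sketch assumes that the per-point maximum neighbor distances are averaged into a single radius and that a point is not counted as its own neighbor; `EstimateRadiusSketch` is hypothetical and uses a quadratic-time search rather than a tree-based one.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Point = std::vector<double>;

// Euclidean distance between two points of equal dimension.
double Distance(const Point& a, const Point& b)
{
  double sum = 0.0;
  for (std::size_t d = 0; d < a.size(); ++d)
    sum += (a[d] - b[d]) * (a[d] - b[d]);
  return std::sqrt(sum);
}

// Hypothetical helper: for each point, take the largest distance among its
// nNeighbors = floor(ratio * n) nearest other points (the k-th smallest
// distance), then average that value over all points (assumed reduction).
double EstimateRadiusSketch(const std::vector<Point>& data, double ratio)
{
  const std::size_t n = data.size();
  const std::size_t nNeighbors = static_cast<std::size_t>(n * ratio);
  if (n < 2 || nNeighbors == 0)
    return 0.0;

  double total = 0.0;
  for (std::size_t i = 0; i < n; ++i)
  {
    std::vector<double> dists;
    for (std::size_t j = 0; j < n; ++j)
      if (j != i)
        dists.push_back(Distance(data[i], data[j]));
    const std::size_t k = std::min(nNeighbors, dists.size());
    // Partially sort so that dists[k - 1] is the k-th smallest distance.
    std::nth_element(dists.begin(), dists.begin() + (k - 1), dists.end());
    total += dists[k - 1];
  }
  return total / static_cast<double>(n);
}
```

For five evenly spaced points `0, 1, 2, 3, 4` and the default ratio 0.2, each point uses a single neighbor at distance 1, so the estimate is 1.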


The documentation for this class was generated from the following files: