mlpack
mlpack::lmnn::LMNN< MetricType, OptimizerType > Class Template Reference

An implementation of the Large Margin Nearest Neighbor (LMNN) metric learning technique.

#include <lmnn.hpp>

Public Member Functions

 LMNN (const arma::mat &dataset, const arma::Row< size_t > &labels, const size_t k, const MetricType metric=MetricType())
 Initialize the LMNN object, passing a dataset (distance metric is learned using this dataset) and labels.
 
template<typename... CallbackTypes>
void LearnDistance (arma::mat &outputMatrix, CallbackTypes &&... callbacks)
 Perform Large Margin Nearest Neighbors metric learning.
 
const arma::mat & Dataset () const
 Get the dataset reference.
 
const arma::Row< size_t > & Labels () const
 Get the labels reference.
 
const double & Regularization () const
 Access the regularization value.
 
double & Regularization ()
 Modify the regularization value.
 
const size_t & Range () const
 Access the range value.
 
size_t & Range ()
 Modify the range value.
 
const size_t & K () const
 Access the value of k.
 
size_t K ()
 Modify the value of k.
 
const OptimizerType & Optimizer () const
 Get the optimizer.
 
OptimizerType & Optimizer ()
 Modify the optimizer.
 

Detailed Description

template<typename MetricType = metric::SquaredEuclideanDistance, typename OptimizerType = ens::AMSGrad>
class mlpack::lmnn::LMNN< MetricType, OptimizerType >

An implementation of the Large Margin Nearest Neighbor (LMNN) metric learning technique.

The method seeks to improve clustering and classification algorithms on a dataset by transforming the dataset into a representation that is more convenient for them. It introduces the concepts of target neighbors and impostors. Target neighbors are the nearby points with the same label that a data point should stay close to; impostors are differently labeled points that intrude on that neighborhood. The distance between each impostor and the perimeter established by a data point's target neighbors should be large, while the distance between a data point and its target neighbors should be small. The method requires the target neighbors to be known beforehand, and once initialized they remain the same.
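
To make the two concepts concrete, the following is a minimal sketch in standard C++ that picks target neighbors and impostors for one point of a toy 2-D dataset. It is not mlpack's implementation: `Point`, `SqDist`, `TargetNeighbors`, and `Impostors` are hypothetical names introduced here, distances are plain squared Euclidean distances before any metric is learned, and the unit margin follows the convention of the paper cited below.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// A labeled point in 2-D (toy type for illustration; not an mlpack type).
struct Point { double x, y; std::size_t label; };

// Squared Euclidean distance between two points.
double SqDist(const Point& a, const Point& b)
{
  const double dx = a.x - b.x, dy = a.y - b.y;
  return dx * dx + dy * dy;
}

// Indices of the (at most) k nearest points to data[i] that share its label.
std::vector<std::size_t> TargetNeighbors(const std::vector<Point>& data,
                                         std::size_t i, std::size_t k)
{
  std::vector<std::size_t> same;
  for (std::size_t j = 0; j < data.size(); ++j)
    if (j != i && data[j].label == data[i].label)
      same.push_back(j);

  // Order the same-label candidates by distance to data[i].
  std::sort(same.begin(), same.end(), [&](std::size_t a, std::size_t b)
      { return SqDist(data[i], data[a]) < SqDist(data[i], data[b]); });

  if (same.size() > k)
    same.resize(k);
  return same;
}

// Indices of differently labeled points that fall inside the perimeter set by
// the farthest target neighbor of data[i], widened by a unit margin.
std::vector<std::size_t> Impostors(const std::vector<Point>& data,
                                   std::size_t i,
                                   const std::vector<std::size_t>& targets)
{
  double radius = 0.0;
  for (std::size_t j : targets)
    radius = std::max(radius, SqDist(data[i], data[j]));

  std::vector<std::size_t> result;
  for (std::size_t l = 0; l < data.size(); ++l)
    if (data[l].label != data[i].label &&
        SqDist(data[i], data[l]) <= radius + 1.0)
      result.push_back(l);
  return result;
}
```

Metric learning then tries to shrink the distances to the indices returned by `TargetNeighbors` while pushing the indices returned by `Impostors` outside the margin.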

For more details, see the following published paper:

@ARTICLE{weinberger09distance,
author = {Weinberger, K.Q. and Saul, L.K.},
title = {{Distance metric learning for large margin nearest neighbor
classification}},
journal = {The Journal of Machine Learning Research},
year = {2009},
volume = {10},
pages = {207--244},
publisher = {MIT Press}
}
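
For orientation, the objective minimized in the cited paper can be written as below. This is the paper's formulation, not a statement about mlpack's internals: in particular, how the value exposed through Regularization() maps onto the weight \mu is an assumption that should be checked against the source.

```latex
\varepsilon(L) =
  (1 - \mu) \sum_{i,\; j \rightsquigarrow i} \lVert L(x_i - x_j) \rVert^2
  + \mu \sum_{i,\; j \rightsquigarrow i} \sum_{l} (1 - y_{il})
    \left[ 1 + \lVert L(x_i - x_j) \rVert^2 - \lVert L(x_i - x_l) \rVert^2 \right]_+
```

Here L is the linear transformation being learned, j \rightsquigarrow i means that x_j is a target neighbor of x_i, y_{il} is 1 when x_i and x_l share a label and 0 otherwise, [z]_+ = \max(z, 0) is the hinge function, and \mu \in [0, 1] balances the pull term (first sum) against the push term (second sum).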
Template Parameters
MetricType: The type of metric to use for computation.
OptimizerType: Optimizer to use for developing distance.

Constructor & Destructor Documentation

◆ LMNN()

template<typename MetricType , typename OptimizerType >
mlpack::lmnn::LMNN< MetricType, OptimizerType >::LMNN ( const arma::mat &  dataset,
const arma::Row< size_t > &  labels,
const size_t  k,
const MetricType  metric = MetricType() 
)

Initialize the LMNN object, passing a dataset (distance metric is learned using this dataset) and labels.

Takes in a reference to the dataset.

Initialization copies both the dataset and the labels into internal copies.

Parameters
dataset: Input dataset.
labels: Input dataset labels.
k: Number of targets to consider.
metric: Type of metric used for computation.

Copies the data, initializes all of the member variables and the constraint object, and generates the constraints.

Member Function Documentation

◆ LearnDistance()

template<typename MetricType , typename OptimizerType >
template<typename... CallbackTypes>
void mlpack::lmnn::LMNN< MetricType, OptimizerType >::LearnDistance ( arma::mat &  outputMatrix,
CallbackTypes &&...  callbacks 
)

Perform Large Margin Nearest Neighbors metric learning.

The output distance matrix is written into the passed reference. If LearnDistance() is called with an outputMatrix that already has the correct dimensions, that matrix is used as the starting point for optimization.

Template Parameters
CallbackTypes: Types of callback functions.
Parameters
outputMatrix: Covariance matrix of Mahalanobis distance.
callbacks: Callback function for ensmallen optimizer OptimizerType. See https://www.ensmallen.org/docs.html#callback-documentation.
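
As a sketch of how a learned matrix defines a distance, the following standard C++ snippet shows the two equivalent forms used in Mahalanobis-style metric learning: the squared norm ||L(x - y)||^2 under a linear transformation L, and the quadratic form (x - y)^T M (x - y) with M = L^T L. This page does not settle which of the two roles outputMatrix plays, so that should be verified before reusing the matrix. `Vec`, `Mat`, `TransformedSqDist`, `GramMatrix`, and `QuadraticFormSqDist` are hypothetical helpers, not part of mlpack or Armadillo.

```cpp
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<Vec>;  // Row-major storage: m[row][col].

// Squared distance after mapping both points through L: ||L (x - y)||^2.
double TransformedSqDist(const Mat& L, const Vec& x, const Vec& y)
{
  double total = 0.0;
  for (const Vec& row : L)
  {
    // One coordinate of L (x - y).
    double proj = 0.0;
    for (std::size_t c = 0; c < row.size(); ++c)
      proj += row[c] * (x[c] - y[c]);
    total += proj * proj;
  }
  return total;
}

// Build M = L^T L, the matrix of the equivalent quadratic form.
Mat GramMatrix(const Mat& L)
{
  const std::size_t d = L.empty() ? 0 : L[0].size();
  Mat M(d, Vec(d, 0.0));
  for (const Vec& row : L)
    for (std::size_t a = 0; a < d; ++a)
      for (std::size_t b = 0; b < d; ++b)
        M[a][b] += row[a] * row[b];
  return M;
}

// Quadratic form (x - y)^T M (x - y).
double QuadraticFormSqDist(const Mat& M, const Vec& x, const Vec& y)
{
  double total = 0.0;
  for (std::size_t a = 0; a < M.size(); ++a)
    for (std::size_t b = 0; b < M[a].size(); ++b)
      total += (x[a] - y[a]) * M[a][b] * (x[b] - y[b]);
  return total;
}
```

Both functions return the same value for any L, which is why a learned transformation can be applied to the dataset once and followed by an ordinary Euclidean nearest-neighbor search.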
