mlpack
lmnn_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_LMNN_LMNN_IMPL_HPP
13 #define MLPACK_METHODS_LMNN_LMNN_IMPL_HPP
14 
15 // In case it was not already included.
16 #include "lmnn.hpp"
17 
18 namespace mlpack {
19 namespace lmnn {
20 
25 template<typename MetricType, typename OptimizerType>
26 LMNN<MetricType, OptimizerType>::LMNN(const arma::mat& dataset,
27  const arma::Row<size_t>& labels,
28  const size_t k,
29  const MetricType metric) :
30  dataset(dataset),
31  labels(labels),
32  k(k),
33  regularization(0.5),
34  range(1),
35  metric(metric)
36 { /* nothing to do */ }
37 
38 template<typename MetricType, typename OptimizerType>
39 template<typename... CallbackTypes>
41  CallbackTypes&&... callbacks)
42 {
43  // LMNN objective function.
44  LMNNFunction<MetricType> objFunction(dataset, labels, k,
45  regularization, range);
46 
47  // See if we were passed an initialized matrix. outputMatrix (L) must be
48  // having r x d dimensionality.
49  if ((outputMatrix.n_cols != dataset.n_rows) ||
50  (outputMatrix.n_rows > dataset.n_rows) ||
51  !(arma::is_finite(outputMatrix)))
52  {
53  Log::Info << "Initial learning point have invalid dimensionality. "
54  "Identity matrix will be used as initial learning point for "
55  "optimization." << std::endl;
56  outputMatrix.eye(dataset.n_rows, dataset.n_rows);
57  }
58 
59  Timer::Start("lmnn_optimization");
60 
61  optimizer.Optimize(objFunction, outputMatrix, callbacks...);
62 
63  Timer::Stop("lmnn_optimization");
64 }
65 
66 
67 } // namespace lmnn
68 } // namespace mlpack
69 
70 #endif
void LearnDistance(arma::mat &outputMatrix, CallbackTypes &&... callbacks)
Perform Large Margin Nearest Neighbors metric learning.
Definition: lmnn_impl.hpp:40
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The Large Margin Nearest Neighbors function.
Definition: lmnn_function.hpp:46
LMNN(const arma::mat &dataset, const arma::Row< size_t > &labels, const size_t k, const MetricType metric=MetricType())
Initialize the LMNN object, passing a dataset (distance metric is learned using this dataset) and labels.
Definition: lmnn_impl.hpp:26
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84