mlpack
nca_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_NCA_NCA_IMPL_HPP
13 #define MLPACK_METHODS_NCA_NCA_IMPL_HPP
14 
// In case it was not already included.
#include "nca.hpp"

#include <utility>
17 
18 namespace mlpack {
19 namespace nca {
20 
21 // Just set the internal matrix reference.
22 template<typename MetricType, typename OptimizerType>
23 NCA<MetricType, OptimizerType>::NCA(const arma::mat& dataset,
24  const arma::Row<size_t>& labels,
25  MetricType metric) :
26  dataset(dataset),
27  labels(labels),
28  metric(metric),
29  errorFunction(dataset, labels, metric)
30 { /* Nothing to do. */ }
31 
32 template<typename MetricType, typename OptimizerType>
33 template<typename... CallbackTypes>
34 void NCA<MetricType, OptimizerType>::LearnDistance(arma::mat& outputMatrix,
35  CallbackTypes&&... callbacks)
36 {
37  // See if we were passed an initialized matrix.
38  if ((outputMatrix.n_rows != dataset.n_rows) ||
39  (outputMatrix.n_cols != dataset.n_rows))
40  outputMatrix.eye(dataset.n_rows, dataset.n_rows);
41 
42  Timer::Start("nca_sgd_optimization");
43 
44  optimizer.Optimize(errorFunction, outputMatrix, callbacks...);
45 
46  Timer::Stop("nca_sgd_optimization");
47 }
48 
49 } // namespace nca
50 } // namespace mlpack
51 
52 #endif
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void LearnDistance(arma::mat &outputMatrix, CallbackTypes &&... callbacks)
Perform Neighborhood Components Analysis.
Definition: nca_impl.hpp:34
NCA(const arma::mat &dataset, const arma::Row< size_t > &labels, MetricType metric=MetricType())
Construct the Neighborhood Components Analysis object.
Definition: nca_impl.hpp:23
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36