mlpack
linear_svm_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_IMPL_HPP
13 #define MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_IMPL_HPP
14 
15 // In case it hasn't been included yet.
16 #include "linear_svm.hpp"
17 
18 namespace mlpack {
19 namespace svm {
20 
21 template <typename MatType>
22 template <typename OptimizerType, typename... CallbackTypes>
24  const MatType& data,
25  const arma::Row<size_t>& labels,
26  const size_t numClasses,
27  const double lambda,
28  const double delta,
29  const bool fitIntercept,
30  OptimizerType optimizer,
31  CallbackTypes&&... callbacks) :
32  numClasses(numClasses),
33  lambda(lambda),
34  delta(delta),
35  fitIntercept(fitIntercept)
36 {
37  Train(data, labels, numClasses, optimizer, callbacks...);
38 }
39 
40 template <typename MatType>
41 template <typename OptimizerType>
43  const MatType& data,
44  const arma::Row<size_t>& labels,
45  const size_t numClasses,
46  const double lambda,
47  const double delta,
48  const bool fitIntercept,
49  OptimizerType optimizer) :
50  numClasses(numClasses),
51  lambda(lambda),
52  delta(delta),
53  fitIntercept(fitIntercept)
54 {
55  Train(data, labels, numClasses, optimizer);
56 }
57 
58 template <typename MatType>
60  const size_t inputSize,
61  const size_t numClasses,
62  const double lambda,
63  const double delta,
64  const bool fitIntercept) :
65  numClasses(numClasses),
66  lambda(lambda),
67  delta(delta),
68  fitIntercept(fitIntercept)
69 {
71  numClasses, fitIntercept);
72 }
73 
74 template <typename MatType>
76  const size_t numClasses,
77  const double lambda,
78  const double delta,
79  const bool fitIntercept) :
80  numClasses(numClasses),
81  lambda(lambda),
82  delta(delta),
83  fitIntercept(fitIntercept)
84 {
85  // No training to do here.
86 }
87 
88 template <typename MatType>
89 template <typename OptimizerType, typename... CallbackTypes>
91  const MatType& data,
92  const arma::Row<size_t>& labels,
93  const size_t numClasses,
94  OptimizerType optimizer,
95  CallbackTypes&&... callbacks)
96 {
97  if (numClasses <= 1)
98  {
99  throw std::invalid_argument("LinearSVM dataset has 0 number of classes!");
100  }
101 
102  LinearSVMFunction<MatType> svm(data, labels, numClasses, lambda, delta,
103  fitIntercept);
104  if (parameters.is_empty())
105  parameters = svm.InitialPoint();
106 
107  // Train the model.
108  Timer::Start("linear_svm_optimization");
109  const double out = optimizer.Optimize(svm, parameters, callbacks...);
110  Timer::Stop("linear_svm_optimization");
111 
112  Log::Info << "LinearSVM::LinearSVM(): final objective of "
113  << "trained model is " << out << "." << std::endl;
114 
115  return out;
116 }
117 
118 template <typename MatType>
119 template <typename OptimizerType>
121  const MatType& data,
122  const arma::Row<size_t>& labels,
123  const size_t numClasses,
124  OptimizerType optimizer)
125 {
126  if (numClasses <= 1)
127  {
128  throw std::invalid_argument("LinearSVM dataset has 0 number of classes!");
129  }
130 
131  LinearSVMFunction<MatType> svm(data, labels, numClasses, lambda, delta,
132  fitIntercept);
133  if (parameters.is_empty())
134  parameters = svm.InitialPoint();
135 
136  // Train the model.
137  Timer::Start("linear_svm_optimization");
138  const double out = optimizer.Optimize(svm, parameters);
139  Timer::Stop("linear_svm_optimization");
140 
141  Log::Info << "LinearSVM::LinearSVM(): final objective of "
142  << "trained model is " << out << "." << std::endl;
143 
144  return out;
145 }
146 
147 template <typename MatType>
149  const MatType& data,
150  arma::Row<size_t>& labels) const
151 {
152  arma::mat scores;
153  Classify(data, labels, scores);
154 }
155 
156 template <typename MatType>
158  const MatType& data,
159  arma::Row<size_t>& labels,
160  arma::mat& scores) const
161 {
162  Classify(data, scores);
163 
164  // Prepare necessary data.
165  labels.zeros(data.n_cols);
166 
167  labels = arma::conv_to<arma::Row<size_t>>::from(
168  arma::index_max(scores));
169 }
170 
171 template <typename MatType>
173  const MatType& data,
174  arma::mat& scores) const
175 {
176  util::CheckSameDimensionality(data, FeatureSize(), "LinearSVM::Classify()");
177 
178  if (fitIntercept)
179  {
180  scores = parameters.rows(0, parameters.n_rows - 2).t() * data
181  + arma::repmat(parameters.row(parameters.n_rows - 1).t(), 1,
182  data.n_cols);
183  }
184  else
185  {
186  scores = parameters.t() * data;
187  }
188 }
189 
190 template <typename MatType>
191 template <typename VecType>
192 size_t LinearSVM<MatType>::Classify(const VecType& point) const
193 {
194  arma::Row<size_t> label(1);
195  Classify(point, label);
196  return size_t(label(0));
197 }
198 
199 template <typename MatType>
201  const MatType& testData,
202  const arma::Row<size_t>& testLabels) const
203 {
204  arma::Row<size_t> labels;
205 
206  // Get predictions for the provided data.
207  Classify(testData, labels);
208 
209  // Increment count for every correctly predicted label.
210  size_t count = 0;
211  for (size_t i = 0; i < labels.n_elem ; ++i)
212  if (testLabels(i) == labels(i))
213  count++;
214 
215  // Return the accuracy.
216  return (double) count / labels.n_elem;
217 }
218 
219 } // namespace svm
220 } // namespace mlpack
221 
222 #endif // MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_IMPL_HPP
The hinge loss function for the linear SVM objective function.
Definition: linear_svm_function.hpp:28
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
const arma::mat & InitialPoint() const
Return the initial point for the optimization.
Definition: linear_svm_function.hpp:161
LinearSVM(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const double lambda, const double delta, const bool fitIntercept, OptimizerType optimizer, CallbackTypes &&... callbacks)
Construct the LinearSVM class with the provided data and labels.
Definition: linear_svm_impl.hpp:23
double ComputeAccuracy(const MatType &testData, const arma::Row< size_t > &testLabels) const
Computes accuracy of the learned model given the feature data and the labels associated with each data point.
Definition: linear_svm_impl.hpp:200
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, OptimizerType optimizer, CallbackTypes &&... callbacks)
Train the Linear SVM with the given training data.
Definition: linear_svm_impl.hpp:90
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
size_t FeatureSize() const
Gets the features size of the training data.
Definition: linear_svm.hpp:286
static void InitializeWeights(arma::mat &weights, const size_t featureSize, const size_t numClasses, const bool fitIntercept=false)
Initialize Linear SVM weights (trainable parameters) with the given parameters.
Definition: linear_svm_function_impl.hpp:53
void Classify(const MatType &data, arma::Row< size_t > &labels) const
Classify the given points, returning the predicted labels for each point.
Definition: linear_svm_impl.hpp:148