12 #ifndef MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_IMPL_HPP 13 #define MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_IMPL_HPP 21 template <
typename MatType>
22 template <
typename OptimizerType,
typename... CallbackTypes>
25 const arma::Row<size_t>& labels,
26 const size_t numClasses,
29 const bool fitIntercept,
30 OptimizerType optimizer,
31 CallbackTypes&&... callbacks) :
32 numClasses(numClasses),
35 fitIntercept(fitIntercept)
37 Train(data, labels, numClasses, optimizer, callbacks...);
40 template <
typename MatType>
41 template <
typename OptimizerType>
44 const arma::Row<size_t>& labels,
45 const size_t numClasses,
48 const bool fitIntercept,
49 OptimizerType optimizer) :
50 numClasses(numClasses),
53 fitIntercept(fitIntercept)
55 Train(data, labels, numClasses, optimizer);
58 template <
typename MatType>
60 const size_t inputSize,
61 const size_t numClasses,
64 const bool fitIntercept) :
65 numClasses(numClasses),
68 fitIntercept(fitIntercept)
71 numClasses, fitIntercept);
74 template <
typename MatType>
76 const size_t numClasses,
79 const bool fitIntercept) :
80 numClasses(numClasses),
83 fitIntercept(fitIntercept)
88 template <
typename MatType>
89 template <
typename OptimizerType,
typename... CallbackTypes>
92 const arma::Row<size_t>& labels,
93 const size_t numClasses,
94 OptimizerType optimizer,
95 CallbackTypes&&... callbacks)
99 throw std::invalid_argument(
"LinearSVM dataset has 0 number of classes!");
104 if (parameters.is_empty())
109 const double out = optimizer.Optimize(svm, parameters, callbacks...);
112 Log::Info <<
"LinearSVM::LinearSVM(): final objective of " 113 <<
"trained model is " << out <<
"." << std::endl;
118 template <
typename MatType>
119 template <
typename OptimizerType>
122 const arma::Row<size_t>& labels,
123 const size_t numClasses,
124 OptimizerType optimizer)
128 throw std::invalid_argument(
"LinearSVM dataset has 0 number of classes!");
133 if (parameters.is_empty())
138 const double out = optimizer.Optimize(svm, parameters);
141 Log::Info <<
"LinearSVM::LinearSVM(): final objective of " 142 <<
"trained model is " << out <<
"." << std::endl;
147 template <
typename MatType>
150 arma::Row<size_t>& labels)
const 156 template <
typename MatType>
159 arma::Row<size_t>& labels,
160 arma::mat& scores)
const 165 labels.zeros(data.n_cols);
167 labels = arma::conv_to<arma::Row<size_t>>::from(
168 arma::index_max(scores));
171 template <
typename MatType>
174 arma::mat& scores)
const 176 util::CheckSameDimensionality(data,
FeatureSize(),
"LinearSVM::Classify()");
180 scores = parameters.rows(0, parameters.n_rows - 2).t() * data
181 + arma::repmat(parameters.row(parameters.n_rows - 1).t(), 1,
186 scores = parameters.t() * data;
190 template <
typename MatType>
191 template <
typename VecType>
194 arma::Row<size_t> label(1);
196 return size_t(label(0));
199 template <
typename MatType>
201 const MatType& testData,
202 const arma::Row<size_t>& testLabels)
const 204 arma::Row<size_t> labels;
211 for (
size_t i = 0; i < labels.n_elem ; ++i)
212 if (testLabels(i) == labels(i))
216 return (
double) count / labels.n_elem;
222 #endif // MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_IMPL_HPP The hinge loss function for the linear SVM objective function.
Definition: linear_svm_function.hpp:28
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
const arma::mat & InitialPoint() const
Return the initial point for the optimization.
Definition: linear_svm_function.hpp:161
LinearSVM(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const double lambda, const double delta, const bool fitIntercept, OptimizerType optimizer, CallbackTypes &&... callbacks)
Construct the LinearSVM class with the provided data and labels.
Definition: linear_svm_impl.hpp:23
double ComputeAccuracy(const MatType &testData, const arma::Row< size_t > &testLabels) const
Computes accuracy of the learned model given the feature data and the labels associated with each data point.
Definition: linear_svm_impl.hpp:200
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, OptimizerType optimizer, CallbackTypes &&... callbacks)
Train the Linear SVM with the given training data.
Definition: linear_svm_impl.hpp:90
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
size_t FeatureSize() const
Gets the feature size of the training data.
Definition: linear_svm.hpp:286
static void InitializeWeights(arma::mat &weights, const size_t featureSize, const size_t numClasses, const bool fitIntercept=false)
Initialize Linear SVM weights (trainable parameters) with the given parameters.
Definition: linear_svm_function_impl.hpp:53
void Classify(const MatType &data, arma::Row< size_t > &labels) const
Classify the given points, returning the predicted labels for each point.
Definition: linear_svm_impl.hpp:148