mlpack
linear_svm_function.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_FUNCTION_HPP
15 #define MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_FUNCTION_HPP
16 
17 #include <mlpack/prereqs.hpp>
18 
19 namespace mlpack {
20 namespace svm {
21 
27 template <typename MatType = arma::mat>
29 {
30  public:
41  LinearSVMFunction(const MatType& dataset,
42  const arma::Row<size_t>& labels,
43  const size_t numClasses,
44  const double lambda = 0.0001,
45  const double delta = 1.0,
46  const bool fitIntercept = false);
47 
51  void Shuffle();
52 
62  static void InitializeWeights(arma::mat& weights,
63  const size_t featureSize,
64  const size_t numClasses,
65  const bool fitIntercept = false);
66 
73  void GetGroundTruthMatrix(const arma::Row<size_t>& labels,
74  arma::sp_mat& groundTruth);
75 
82  double Evaluate(const arma::mat& parameters);
83 
93  double Evaluate(const arma::mat& parameters,
94  const size_t firstId,
95  const size_t batchSize = 1);
96 
105  template <typename GradType>
106  void Gradient(const arma::mat& parameters,
107  GradType& gradient);
108 
119  template <typename GradType>
120  void Gradient(const arma::mat& parameters,
121  const size_t firstId,
122  GradType& gradient,
123  const size_t batchSize = 1);
124 
136  template <typename GradType>
137  double EvaluateWithGradient(const arma::mat& parameters,
138  GradType& gradient) const;
139 
154  template <typename GradType>
155  double EvaluateWithGradient(const arma::mat& parameters,
156  const size_t firstId,
157  GradType& gradient,
158  const size_t batchSize = 1) const;
159 
161  const arma::mat& InitialPoint() const { return initialPoint; }
163  arma::mat& InitialPoint() { return initialPoint; }
164 
166  const arma::sp_mat& Dataset() const { return dataset; }
168  arma::sp_mat& Dataset() { return dataset; }
169 
171  double& Lambda() { return lambda; }
173  double Lambda() const { return lambda; }
174 
176  bool FitIntercept() const { return fitIntercept; }
177 
179  size_t NumFunctions() const;
180 
181  private:
183  arma::mat initialPoint;
184 
186  arma::sp_mat groundTruth;
187 
189  MatType dataset;
190 
192  size_t numClasses;
193 
195  double lambda;
196 
198  double delta;
199 
201  bool fitIntercept;
202 };
203 
204 } // namespace svm
205 } // namespace mlpack
206 
207 // Include implementation
209 
210 #endif // MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_FUNCTION_HPP
double Lambda() const
Gets the regularization parameter.
Definition: linear_svm_function.hpp:173
The hinge loss function for the linear SVM objective function.
Definition: linear_svm_function.hpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.
void GetGroundTruthMatrix(const arma::Row< size_t > &labels, arma::sp_mat &groundTruth)
Constructs the ground truth label matrix with the passed labels.
Definition: linear_svm_function_impl.hpp:74
arma::mat & InitialPoint()
Modify the initial point for the optimization.
Definition: linear_svm_function.hpp:163
double EvaluateWithGradient(const arma::mat &parameters, GradType &gradient) const
Evaluate both the hinge loss function and its gradient, following the LinearFunctionType requirements; the objective value is returned and the gradient is written to `gradient`.
Definition: linear_svm_function_impl.hpp:360
const arma::mat & InitialPoint() const
Return the initial point for the optimization.
Definition: linear_svm_function.hpp:161
void Gradient(const arma::mat &parameters, GradType &gradient)
Evaluate the gradient of the hinge loss function, following the LinearFunctionType requirements; the result is written to `gradient`.
Definition: linear_svm_function_impl.hpp:239
LinearSVMFunction(const MatType &dataset, const arma::Row< size_t > &labels, const size_t numClasses, const double lambda=0.0001, const double delta=1.0, const bool fitIntercept=false)
Construct the Linear SVM objective function with given parameters.
Definition: linear_svm_function_impl.hpp:27
bool FitIntercept() const
Gets the intercept flag.
Definition: linear_svm_function.hpp:176
arma::sp_mat & Dataset()
Modify the dataset.
Definition: linear_svm_function.hpp:168
double Evaluate(const arma::mat &parameters)
Evaluate the hinge loss function for all the datapoints.
Definition: linear_svm_function_impl.hpp:146
const arma::sp_mat & Dataset() const
Get the dataset.
Definition: linear_svm_function.hpp:166
size_t NumFunctions() const
Return the number of functions.
Definition: linear_svm_function_impl.hpp:494
void Shuffle()
Shuffle the dataset.
Definition: linear_svm_function_impl.hpp:111
double & Lambda()
Sets the regularization parameter.
Definition: linear_svm_function.hpp:171
static void InitializeWeights(arma::mat &weights, const size_t featureSize, const size_t numClasses, const bool fitIntercept=false)
Initialize Linear SVM weights (trainable parameters) with the given parameters.
Definition: linear_svm_function_impl.hpp:53