mlpack
simple_cv.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_CV_SIMPLE_CV_HPP
13 #define MLPACK_CORE_CV_SIMPLE_CV_HPP
14 
17 
18 namespace mlpack {
19 namespace cv {
20 
/**
 * SimpleCV splits data into two sets - training and validation sets - and then
 * trains MLAlgorithm on the training set and assesses its performance on the
 * validation set by using the class Metric (see Evaluate()).
 *
 * NOTE(review): this text is a Doxygen source listing, not the raw header.
 * The leading number on each line is the original header's line number. The
 * gaps in that numbering (e.g. 70 -> 84, 88 -> 101) are presumably the dropped
 * documentation comments, and a few code lines are missing too (flagged
 * below). The text does not compile as-is, so check any change against the
 * real mlpack/core/cv/simple_cv.hpp.
 *
 * @tparam MLAlgorithm The machine learning algorithm to train and validate.
 * @tparam Metric The class used to assess performance on the validation set.
 * @tparam MatType The type of the data points (defaults to arma::mat).
 * @tparam PredictionsType The type of the predictions; its default is missing
 *     from this listing (see the note inside the parameter list).
 * @tparam WeightsType The type of the weights; defaults to
 *     MetaInfoExtractor<MLAlgorithm, MatType, PredictionsType>::WeightsType.
 */
60 template<typename MLAlgorithm,
61  typename Metric,
62  typename MatType = arma::mat,
63  typename PredictionsType =
// NOTE(review): listing line 64, the default argument of PredictionsType, is
// missing, so "PredictionsType =" is followed directly by the next parameter.
// By analogy with the WeightsType default below it is presumably
// typename MetaInfoExtractor<MLAlgorithm, MatType>::PredictionsType
// -- confirm against the real header.
65  typename WeightsType =
66  typename MetaInfoExtractor<MLAlgorithm, MatType,
67  PredictionsType>::WeightsType>
68 class SimpleCV
69 {
70  public:
  /**
   * This constructor can be used for regression algorithms and for binary
   * classification algorithms.
   *
   * @param validationSize Presumably the proportion of the data that is held
   *     out as the validation set -- confirm in simple_cv_impl.hpp.
   * @param xs The data points.
   * @param ys The predictions for the data points.
   */
84  template<typename MatInType, typename PredictionsInType>
85  SimpleCV(const double validationSize,
86  MatInType&& xs,
87  PredictionsInType&& ys);
88 
  /**
   * Variant that also takes the number of classes (presumably for multiclass
   * classification algorithms -- confirm).
   *
   * @param validationSize Size of the validation set (see the first
   *     constructor).
   * @param xs The data points.
   * @param ys The predictions for the data points.
   * @param numClasses The number of classes.
   */
101  template<typename MatInType, typename PredictionsInType>
102  SimpleCV(const double validationSize,
103  MatInType&& xs,
104  PredictionsInType&& ys,
105  const size_t numClasses);
106 
  /**
   * Variant that also takes a data::DatasetInfo, i.e. auxiliary information
   * for the dataset, including mappings to/from strings.
   *
   * @param validationSize Size of the validation set.
   * @param xs The data points.
   * @param datasetInfo Auxiliary information about the dataset.
   * @param ys The predictions for the data points.
   * @param numClasses The number of classes.
   */
121  template<typename MatInType, typename PredictionsInType>
122  SimpleCV(const double validationSize,
123  MatInType&& xs,
124  const data::DatasetInfo& datasetInfo,
125  PredictionsInType&& ys,
126  const size_t numClasses);
127 
  /**
   * Weighted counterpart of the first constructor.
   *
   * @param validationSize Size of the validation set.
   * @param xs The data points.
   * @param ys The predictions for the data points.
   * @param weights Weights for the data points (presumably only used when
   *     MLAlgorithm supports weighted learning -- confirm).
   */
143  template<typename MatInType,
144  typename PredictionsInType,
145  typename WeightsInType>
146  SimpleCV(const double validationSize,
147  MatInType&& xs,
148  PredictionsInType&& ys,
149  WeightsInType&& weights);
150 
  /**
   * Weighted counterpart of the constructor taking numClasses.
   *
   * @param validationSize Size of the validation set.
   * @param xs The data points.
   * @param ys The predictions for the data points.
   * @param numClasses The number of classes.
   * @param weights Weights for the data points.
   */
166  template<typename MatInType,
167  typename PredictionsInType,
168  typename WeightsInType>
169  SimpleCV(const double validationSize,
170  MatInType&& xs,
171  PredictionsInType&& ys,
172  const size_t numClasses,
173  WeightsInType&& weights);
174 
  /**
   * Weighted counterpart of the constructor taking a data::DatasetInfo.
   *
   * @param validationSize Size of the validation set.
   * @param xs The data points.
   * @param datasetInfo Auxiliary information about the dataset.
   * @param ys The predictions for the data points.
   * @param numClasses The number of classes.
   * @param weights Weights for the data points.
   */
191  template<typename MatInType,
192  typename PredictionsInType,
193  typename WeightsInType>
194  SimpleCV(const double validationSize,
195  MatInType&& xs,
196  const data::DatasetInfo& datasetInfo,
197  PredictionsInType&& ys,
198  const size_t numClasses,
199  WeightsInType&& weights);
200 
  /**
   * Train on the training set and assess performance on the validation set by
   * using the class Metric.
   *
   * @param args Arguments for MLAlgorithm (presumably forwarded to its
   *     construction/training -- confirm in simple_cv_impl.hpp).
   * @return The assessed performance as a double (presumably the Metric value
   *     on the validation set).
   */
208  template<typename... MLAlgorithmArgs>
209  double Evaluate(const MLAlgorithmArgs&... args);
210 
  //! Access and modify the last trained model.
212  MLAlgorithm& Model();
213 
214  private:
  // NOTE(review): the declaration of Base (listing lines 215-216) is missing
  // here, although Base is used below (e.g. Base::MIE::SupportsWeights).
  // Presumably it is an alias for the auxiliary cross-validation class from
  // cv_base.hpp (CVBase) -- confirm against the real header.
217 
  //! Helper object of the (missing) Base type.
219  Base base;
220 
  //! The data points held by this object.
222  MatType xs;
  //! The predictions held by this object.
224  PredictionsType ys;
  //! The weights held by this object.
226  WeightsType weights;
227 
  // Presumably the training subset of xs / ys / weights -- confirm in
  // simple_cv_impl.hpp.
229  MatType trainingXs;
231  PredictionsType trainingYs;
233  WeightsType trainingWeights;
234 
  // Presumably the validation subset of xs / ys. Note that there is no
  // validation counterpart for the weights.
236  MatType validationXs;
238  PredictionsType validationYs;
239 
  //! Owns the model; presumably the last trained model returned by Model().
241  std::unique_ptr<MLAlgorithm> modelPtr;
242 
  //! Private constructor that takes an already-built Base object (presumably
  //! the target the unweighted public constructors delegate to -- confirm).
247  template<typename MatInType,
248  typename PredictionsInType>
249  SimpleCV(Base&& base,
250  const double validationSize,
251  MatInType&& xs,
252  PredictionsInType&& ys);
253 
  //! Weighted counterpart of the private constructor above.
258  template<typename MatInType,
259  typename PredictionsInType,
260  typename WeightsInType>
261  SimpleCV(Base&& base,
262  const double validationSize,
263  MatInType&& xs,
264  PredictionsInType&& ys,
265  WeightsInType&& weights);
266 
  //! Compute the number of training points for the given validationSize;
  //! judging by the name it also asserts the result is valid -- see
  //! simple_cv_impl.hpp.
270  size_t CalculateAndAssertNumberOfTrainingPoints(const double validationSize);
271 
  //! Return the columns of m from firstCol to lastCol (check in
  //! simple_cv_impl.hpp whether lastCol is inclusive).
  //! NOTE(review): m is taken by non-const reference; check whether it is
  //! modified or only aliased.
275  template<typename ElementType>
276  arma::Mat<ElementType> GetSubset(arma::Mat<ElementType>& m,
277  const size_t firstCol,
278  const size_t lastCol);
279 
  //! Row-vector overload of GetSubset.
283  template<typename ElementType>
284  arma::Row<ElementType> GetSubset(arma::Row<ElementType>& r,
285  const size_t firstCol,
286  const size_t lastCol);
287 
  //! Overload used when MLAlgorithm does NOT support weighted learning
  //! (Base::MIE::SupportsWeights is false). When Enabled is false,
  //! std::enable_if<Enabled>::type does not exist, so this template is
  //! removed from overload resolution (SFINAE).
291  template<typename... MLAlgorithmArgs,
292  bool Enabled = !Base::MIE::SupportsWeights,
293  typename = typename std::enable_if<Enabled>::type>
294  double TrainAndEvaluate(const MLAlgorithmArgs&... args);
295 
  //! Overload used when MLAlgorithm supports weighted learning.
  //! The extra unnamed "typename = void" parameter gives this template a
  //! different template-parameter list from the overload above. Default
  //! template arguments are not part of a template's signature, so without it
  //! the two declarations would declare the same member template twice.
299  template<typename... MLAlgorithmArgs,
300  bool Enabled = Base::MIE::SupportsWeights,
301  typename = typename std::enable_if<Enabled>::type,
302  typename = void>
303  double TrainAndEvaluate(const MLAlgorithmArgs&... args);
304 };
305 
306 } // namespace cv
307 } // namespace mlpack
308 
309 // Include implementation
310 #include "simple_cv_impl.hpp"
311 
312 #endif
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
Definition: dataset_mapper.hpp:41
SimpleCV splits data into two sets - training and validation sets - and then runs training on the training set and evaluates performance on the validation set.
Definition: simple_cv.hpp:68
static const bool SupportsWeights
An indication whether MLAlgorithm supports weighted learning.
Definition: meta_info_extractor.hpp:337
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
SimpleCV(const double validationSize, MatInType &&xs, PredictionsInType &&ys)
This constructor can be used for regression algorithms and for binary classification algorithms.
typename Select< TF1, TF2, TF3, TF4, TF5 >::Type::PredictionsType PredictionsType
The type of predictions used in MLAlgorithm.
Definition: meta_info_extractor.hpp:319
double Evaluate(const MLAlgorithmArgs &... args)
Train on the training set and assess performance on the validation set by using the class Metric.
Definition: simple_cv_impl.hpp:195
MLAlgorithm & Model()
Access and modify the last trained model.
Definition: simple_cv_impl.hpp:209
An auxiliary class for cross-validation.
Definition: cv_base.hpp:39