mlpack
k_fold_cv_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_CV_K_FOLD_CV_IMPL_HPP
13 #define MLPACK_CORE_CV_K_FOLD_CV_IMPL_HPP
14 
15 namespace mlpack {
16 namespace cv {
17 
18 template<typename MLAlgorithm,
19  typename Metric,
20  typename MatType,
21  typename PredictionsType,
22  typename WeightsType>
23 KFoldCV<MLAlgorithm,
24  Metric,
25  MatType,
26  PredictionsType,
27  WeightsType>::KFoldCV(const size_t k,
28  const MatType& xs,
29  const PredictionsType& ys,
30  const bool shuffle) :
31  KFoldCV(Base(), k, xs, ys, shuffle)
32 { /* Nothing left to do. */ }
33 
34 template<typename MLAlgorithm,
35  typename Metric,
36  typename MatType,
37  typename PredictionsType,
38  typename WeightsType>
39 KFoldCV<MLAlgorithm,
40  Metric,
41  MatType,
42  PredictionsType,
43  WeightsType>::KFoldCV(const size_t k,
44  const MatType& xs,
45  const PredictionsType& ys,
46  const size_t numClasses,
47  const bool shuffle) :
48  KFoldCV(Base(numClasses), k, xs, ys, shuffle)
49 { /* Nothing left to do. */ }
50 
51 template<typename MLAlgorithm,
52  typename Metric,
53  typename MatType,
54  typename PredictionsType,
55  typename WeightsType>
56 KFoldCV<MLAlgorithm,
57  Metric,
58  MatType,
59  PredictionsType,
60  WeightsType>::KFoldCV(const size_t k,
61  const MatType& xs,
62  const data::DatasetInfo& datasetInfo,
63  const PredictionsType& ys,
64  const size_t numClasses,
65  const bool shuffle) :
66  KFoldCV(Base(datasetInfo, numClasses), k, xs, ys, shuffle)
67 { /* Nothing left to do. */ }
68 
69 template<typename MLAlgorithm,
70  typename Metric,
71  typename MatType,
72  typename PredictionsType,
73  typename WeightsType>
74 KFoldCV<MLAlgorithm,
75  Metric,
76  MatType,
77  PredictionsType,
78  WeightsType>::KFoldCV(const size_t k,
79  const MatType& xs,
80  const PredictionsType& ys,
81  const WeightsType& weights,
82  const bool shuffle) :
83  KFoldCV(Base(), k, xs, ys, weights, shuffle)
84 { /* Nothing left to do. */ }
85 
86 template<typename MLAlgorithm,
87  typename Metric,
88  typename MatType,
89  typename PredictionsType,
90  typename WeightsType>
91 KFoldCV<MLAlgorithm,
92  Metric,
93  MatType,
94  PredictionsType,
95  WeightsType>::KFoldCV(const size_t k,
96  const MatType& xs,
97  const PredictionsType& ys,
98  const size_t numClasses,
99  const WeightsType& weights,
100  const bool shuffle) :
101  KFoldCV(Base(numClasses), k, xs, ys, weights, shuffle)
102 { /* Nothing left to do. */ }
103 
104 template<typename MLAlgorithm,
105  typename Metric,
106  typename MatType,
107  typename PredictionsType,
108  typename WeightsType>
109 KFoldCV<MLAlgorithm,
110  Metric,
111  MatType,
112  PredictionsType,
113  WeightsType>::KFoldCV(const size_t k,
114  const MatType& xs,
115  const data::DatasetInfo& datasetInfo,
116  const PredictionsType& ys,
117  const size_t numClasses,
118  const WeightsType& weights,
119  const bool shuffle) :
120  KFoldCV(Base(datasetInfo, numClasses), k, xs, ys, weights, shuffle)
121 { /* Nothing left to do. */ }
122 
123 template<typename MLAlgorithm,
124  typename Metric,
125  typename MatType,
126  typename PredictionsType,
127  typename WeightsType>
128 KFoldCV<MLAlgorithm,
129  Metric,
130  MatType,
131  PredictionsType,
132  WeightsType>::KFoldCV(Base&& base,
133  const size_t k,
134  const MatType& xs,
135  const PredictionsType& ys,
136  const bool shuffle) :
137  base(std::move(base)),
138  k(k)
139 {
140  if (k < 2)
141  throw std::invalid_argument("KFoldCV: k should not be less than 2");
142 
144 
145  InitKFoldCVMat(xs, this->xs);
146  InitKFoldCVMat(ys, this->ys);
147 
148  // Do we need to shuffle the dataset?
149  if (shuffle)
150  Shuffle();
151 }
152 
153 template<typename MLAlgorithm,
154  typename Metric,
155  typename MatType,
156  typename PredictionsType,
157  typename WeightsType>
158 KFoldCV<MLAlgorithm,
159  Metric,
160  MatType,
161  PredictionsType,
162  WeightsType>::KFoldCV(Base&& base,
163  const size_t k,
164  const MatType& xs,
165  const PredictionsType& ys,
166  const WeightsType& weights,
167  const bool shuffle) :
168  base(std::move(base)),
169  k(k)
170 {
171  Base::AssertWeightsConsistency(xs, weights);
172 
173  InitKFoldCVMat(xs, this->xs);
174  InitKFoldCVMat(ys, this->ys);
175  InitKFoldCVMat(weights, this->weights);
176 
177  // Do we need to shuffle the dataset?
178  if (shuffle)
179  Shuffle();
180 }
181 
182 template<typename MLAlgorithm,
183  typename Metric,
184  typename MatType,
185  typename PredictionsType,
186  typename WeightsType>
187 template<typename... MLAlgorithmArgs>
188 double KFoldCV<MLAlgorithm,
189  Metric,
190  MatType,
191  PredictionsType,
192  WeightsType>::Evaluate(const MLAlgorithmArgs&... args)
193 {
194  return TrainAndEvaluate(args...);
195 }
196 
197 template<typename MLAlgorithm,
198  typename Metric,
199  typename MatType,
200  typename PredictionsType,
201  typename WeightsType>
202 MLAlgorithm& KFoldCV<MLAlgorithm,
203  Metric,
204  MatType,
205  PredictionsType,
206  WeightsType>::Model()
207 {
208  if (modelPtr == nullptr)
209  throw std::logic_error(
210  "KFoldCV::Model(): attempted to access an uninitialized model");
211 
212  return *modelPtr;
213 }
214 
215 template<typename MLAlgorithm,
216  typename Metric,
217  typename MatType,
218  typename PredictionsType,
219  typename WeightsType>
220 template<typename DataType>
221 void KFoldCV<MLAlgorithm,
222  Metric,
223  MatType,
224  PredictionsType,
225  WeightsType>::InitKFoldCVMat(const DataType& source,
226  DataType& destination)
227 {
228  binSize = source.n_cols / k;
229  lastBinSize = source.n_cols - ((k - 1) * binSize);
230 
231  destination = (k == 2) ? source : arma::join_rows(source,
232  source.cols(0, source.n_cols - lastBinSize - 1));
233 }
234 
235 template<typename MLAlgorithm,
236  typename Metric,
237  typename MatType,
238  typename PredictionsType,
239  typename WeightsType>
240 template<typename... MLAlgorithmArgs, bool Enabled, typename>
241 double KFoldCV<MLAlgorithm,
242  Metric,
243  MatType,
244  PredictionsType,
245  WeightsType>::TrainAndEvaluate(const MLAlgorithmArgs&... args)
246 {
247  arma::vec evaluations(k);
248 
249  size_t numInvalidScores = 0;
250  for (size_t i = 0; i < k; ++i)
251  {
252  MLAlgorithm&& model = base.Train(GetTrainingSubset(xs, i),
253  GetTrainingSubset(ys, i), args...);
254  evaluations(i) = Metric::Evaluate(model, GetValidationSubset(xs, i),
255  GetValidationSubset(ys, i));
256  if (std::isnan(evaluations(i)) || std::isinf(evaluations(i)))
257  {
258  ++numInvalidScores;
259  Log::Warn << "KFoldCV::TrainAndEvaluate(): fold " << i << " returned "
260  << "a score of " << evaluations(i) << "; ignoring when computing "
261  << "the average score." << std::endl;
262  }
263  if (i == k - 1)
264  modelPtr.reset(new MLAlgorithm(std::move(model)));
265  }
266 
267  if (numInvalidScores == k)
268  {
269  Log::Warn << "KFoldCV::TrainAndEvaluate(): all folds returned invalid "
270  << "scores! Returning 0.0 as overall score." << std::endl;
271  return 0.0;
272  }
273 
274  return arma::mean(evaluations.elem(arma::find_finite(evaluations)));
275 }
276 
277 template<typename MLAlgorithm,
278  typename Metric,
279  typename MatType,
280  typename PredictionsType,
281  typename WeightsType>
282 template<typename... MLAlgorithmArgs, bool Enabled, typename, typename>
283 double KFoldCV<MLAlgorithm,
284  Metric,
285  MatType,
286  PredictionsType,
287  WeightsType>::TrainAndEvaluate(const MLAlgorithmArgs&... args)
288 {
289  arma::vec evaluations(k);
290 
291  for (size_t i = 0; i < k; ++i)
292  {
293  MLAlgorithm&& model = (weights.n_elem > 0) ?
294  base.Train(GetTrainingSubset(xs, i), GetTrainingSubset(ys, i),
295  GetTrainingSubset(weights, i), args...) :
296  base.Train(GetTrainingSubset(xs, i), GetTrainingSubset(ys, i),
297  args...);
298  evaluations(i) = Metric::Evaluate(model, GetValidationSubset(xs, i),
299  GetValidationSubset(ys, i));
300  if (i == k - 1)
301  modelPtr.reset(new MLAlgorithm(std::move(model)));
302  }
303 
304  return arma::mean(evaluations);
305 }
306 
307 template<typename MLAlgorithm,
308  typename Metric,
309  typename MatType,
310  typename PredictionsType,
311  typename WeightsType>
312 template<bool Enabled, typename>
313 void KFoldCV<MLAlgorithm,
314  Metric,
315  MatType,
316  PredictionsType,
317  WeightsType>::Shuffle()
318 {
319  MatType xsOrig = xs.cols(0, (k - 1) * binSize + lastBinSize - 1);
320  PredictionsType ysOrig = ys.cols(0, (k - 1) * binSize + lastBinSize - 1);
321 
322  // Now shuffle the data.
323  math::ShuffleData(xsOrig, ysOrig, xsOrig, ysOrig);
324 
325  InitKFoldCVMat(xsOrig, xs);
326  InitKFoldCVMat(ysOrig, ys);
327 }
328 
329 template<typename MLAlgorithm,
330  typename Metric,
331  typename MatType,
332  typename PredictionsType,
333  typename WeightsType>
334 template<bool Enabled, typename, typename>
335 void KFoldCV<MLAlgorithm,
336  Metric,
337  MatType,
338  PredictionsType,
339  WeightsType>::Shuffle()
340 {
341  MatType xsOrig = xs.cols(0, (k - 1) * binSize + lastBinSize - 1);
342  PredictionsType ysOrig = ys.cols(0, (k - 1) * binSize + lastBinSize - 1);
343  WeightsType weightsOrig;
344  if (weights.n_elem > 0)
345  weightsOrig = weights.cols(0, (k - 1) * binSize + lastBinSize - 1);
346 
347  // Now shuffle the data.
348  if (weights.n_elem > 0)
349  math::ShuffleData(xsOrig, ysOrig, weightsOrig, xsOrig, ysOrig, weightsOrig);
350  else
351  math::ShuffleData(xsOrig, ysOrig, xsOrig, ysOrig);
352 
353  InitKFoldCVMat(xsOrig, xs);
354  InitKFoldCVMat(ysOrig, ys);
355  if (weights.n_elem > 0)
356  InitKFoldCVMat(weightsOrig, weights);
357 }
358 
359 template<typename MLAlgorithm,
360  typename Metric,
361  typename MatType,
362  typename PredictionsType,
363  typename WeightsType>
364 size_t KFoldCV<MLAlgorithm,
365  Metric,
366  MatType,
367  PredictionsType,
368  WeightsType>::ValidationSubsetFirstCol(const size_t i)
369 {
370  // Use as close to the beginning of the dataset as we can.
371  return (i == 0) ? binSize * (k - 1) : binSize * (i - 1);
372 }
373 
374 template<typename MLAlgorithm,
375  typename Metric,
376  typename MatType,
377  typename PredictionsType,
378  typename WeightsType>
379 template<typename ElementType>
380 arma::Mat<ElementType> KFoldCV<MLAlgorithm,
381  Metric,
382  MatType,
383  PredictionsType,
384  WeightsType>::GetTrainingSubset(
385  arma::Mat<ElementType>& m,
386  const size_t i)
387 {
388  // If this is not the first fold, we have to handle it a little bit
389  // differently, since the last fold may contain slightly more than 'binSize'
390  // points.
391  const size_t subsetSize = (i != 0) ? lastBinSize + (k - 2) * binSize :
392  (k - 1) * binSize;
393 
394  return arma::Mat<ElementType>(m.colptr(binSize * i), m.n_rows, subsetSize,
395  false, true);
396 }
397 
398 template<typename MLAlgorithm,
399  typename Metric,
400  typename MatType,
401  typename PredictionsType,
402  typename WeightsType>
403 template<typename ElementType>
404 arma::Row<ElementType> KFoldCV<MLAlgorithm,
405  Metric,
406  MatType,
407  PredictionsType,
408  WeightsType>::GetTrainingSubset(
409  arma::Row<ElementType>& r,
410  const size_t i)
411 {
412  // If this is not the first fold, we have to handle it a little bit
413  // differently, since the last fold may contain slightly more than 'binSize'
414  // points.
415  const size_t subsetSize = (i != 0) ? lastBinSize + (k - 2) * binSize :
416  (k - 1) * binSize;
417 
418  return arma::Row<ElementType>(r.colptr(binSize * i), subsetSize, false, true);
419 }
420 
421 template<typename MLAlgorithm,
422  typename Metric,
423  typename MatType,
424  typename PredictionsType,
425  typename WeightsType>
426 template<typename ElementType>
427 arma::Mat<ElementType> KFoldCV<MLAlgorithm,
428  Metric,
429  MatType,
430  PredictionsType,
431  WeightsType>::GetValidationSubset(
432  arma::Mat<ElementType>& m,
433  const size_t i)
434 {
435  const size_t subsetSize = (i == 0) ? lastBinSize : binSize;
436  return arma::Mat<ElementType>(m.colptr(ValidationSubsetFirstCol(i)), m.n_rows,
437  subsetSize, false, true);
438 }
439 
440 template<typename MLAlgorithm,
441  typename Metric,
442  typename MatType,
443  typename PredictionsType,
444  typename WeightsType>
445 template<typename ElementType>
446 arma::Row<ElementType> KFoldCV<MLAlgorithm,
447  Metric,
448  MatType,
449  PredictionsType,
450  WeightsType>::GetValidationSubset(
451  arma::Row<ElementType>& r,
452  const size_t i)
453 {
454  const size_t subsetSize = (i == 0) ? lastBinSize : binSize;
455  return arma::Row<ElementType>(r.colptr(ValidationSubsetFirstCol(i)),
456  subsetSize, false, true);
457 }
458 
459 } // namespace cv
460 } // namespace mlpack
461 
462 #endif
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
Definition: dataset_mapper.hpp:41
void Shuffle()
Shuffle the data.
Definition: k_fold_cv_impl.hpp:317
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static MLPACK_EXPORT util::PrefixedOutStream Warn
Prints warning messages prefixed with [WARN ].
Definition: log.hpp:87
static void AssertDataConsistency(const MatType &xs, const PredictionsType &ys)
Assert that there is an equal number of data points and predictions.
Definition: cv_base_impl.hpp:108
void ShuffleData(const MatType &inputPoints, const LabelsType &inputLabels, MatType &outputPoints, LabelsType &outputLabels, const std::enable_if_t<!arma::is_SpMat< MatType >::value > *=0, const std::enable_if_t<!arma::is_Cube< MatType >::value > *=0)
Shuffle a dataset and associated labels (or responses).
Definition: shuffle_data.hpp:28
MLAlgorithm Train(const MatType &xs, const PredictionsType &ys, const MLAlgorithmArgs &... args)
Train MLAlgorithm with given data points, predictions, and hyperparameters depending on what CVBase c...
Definition: cv_base_impl.hpp:78
KFoldCV(const size_t k, const MatType &xs, const PredictionsType &ys, const bool shuffle=true)
This constructor can be used for regression algorithms and for binary classification algorithms...
Definition: k_fold_cv_impl.hpp:27
The class KFoldCV implements k-fold cross-validation for regression and classification algorithms...
Definition: k_fold_cv.hpp:65
double Evaluate(const MLAlgorithmArgs &...args)
Run k-fold cross-validation.
Definition: k_fold_cv_impl.hpp:192
static void AssertWeightsConsistency(const MatType &xs, const WeightsType &weights)
Assert that weighted learning is supported and that there is an equal number of data points and weights...
Definition: cv_base_impl.hpp:122
An auxiliary class for cross-validation.
Definition: cv_base.hpp:39
MLAlgorithm & Model()
Access and modify a model from the last run of k-fold cross-validation.
Definition: k_fold_cv_impl.hpp:206