12 #ifndef MLPACK_CORE_CV_K_FOLD_CV_IMPL_HPP 13 #define MLPACK_CORE_CV_K_FOLD_CV_IMPL_HPP 18 template<
typename MLAlgorithm,
21 typename PredictionsType,
29 const PredictionsType& ys,
34 template<
typename MLAlgorithm,
37 typename PredictionsType,
45 const PredictionsType& ys,
46 const size_t numClasses,
51 template<
typename MLAlgorithm,
54 typename PredictionsType,
63 const PredictionsType& ys,
64 const size_t numClasses,
66 KFoldCV(
Base(datasetInfo, numClasses), k, xs, ys, shuffle)
69 template<
typename MLAlgorithm,
72 typename PredictionsType,
80 const PredictionsType& ys,
81 const WeightsType& weights,
86 template<
typename MLAlgorithm,
89 typename PredictionsType,
97 const PredictionsType& ys,
98 const size_t numClasses,
99 const WeightsType& weights,
100 const bool shuffle) :
101 KFoldCV(
Base(numClasses), k, xs, ys, weights, shuffle)
104 template<
typename MLAlgorithm,
107 typename PredictionsType,
108 typename WeightsType>
116 const PredictionsType& ys,
117 const size_t numClasses,
118 const WeightsType& weights,
119 const bool shuffle) :
120 KFoldCV(
Base(datasetInfo, numClasses), k, xs, ys, weights, shuffle)
123 template<
typename MLAlgorithm,
126 typename PredictionsType,
127 typename WeightsType>
135 const PredictionsType& ys,
136 const bool shuffle) :
137 base(std::move(base)),
141 throw std::invalid_argument(
"KFoldCV: k should not be less than 2");
145 InitKFoldCVMat(xs, this->xs);
146 InitKFoldCVMat(ys, this->ys);
153 template<
typename MLAlgorithm,
156 typename PredictionsType,
157 typename WeightsType>
165 const PredictionsType& ys,
166 const WeightsType& weights,
167 const bool shuffle) :
168 base(std::move(base)),
173 InitKFoldCVMat(xs, this->xs);
174 InitKFoldCVMat(ys, this->ys);
175 InitKFoldCVMat(weights, this->weights);
182 template<
typename MLAlgorithm,
185 typename PredictionsType,
186 typename WeightsType>
187 template<
typename... MLAlgorithmArgs>
194 return TrainAndEvaluate(args...);
197 template<
typename MLAlgorithm,
200 typename PredictionsType,
201 typename WeightsType>
202 MLAlgorithm&
KFoldCV<MLAlgorithm,
208 if (modelPtr ==
nullptr)
209 throw std::logic_error(
210 "KFoldCV::Model(): attempted to access an uninitialized model");
215 template<
typename MLAlgorithm,
218 typename PredictionsType,
219 typename WeightsType>
220 template<
typename DataType>
225 WeightsType>::InitKFoldCVMat(
const DataType& source,
226 DataType& destination)
228 binSize = source.n_cols / k;
229 lastBinSize = source.n_cols - ((k - 1) * binSize);
231 destination = (k == 2) ? source : arma::join_rows(source,
232 source.cols(0, source.n_cols - lastBinSize - 1));
235 template<
typename MLAlgorithm,
238 typename PredictionsType,
239 typename WeightsType>
240 template<
typename... MLAlgorithmArgs,
bool Enabled,
typename>
245 WeightsType>::TrainAndEvaluate(
const MLAlgorithmArgs&... args)
247 arma::vec evaluations(k);
249 size_t numInvalidScores = 0;
250 for (
size_t i = 0; i < k; ++i)
252 MLAlgorithm&& model = base.
Train(GetTrainingSubset(xs, i),
253 GetTrainingSubset(ys, i), args...);
254 evaluations(i) = Metric::Evaluate(model, GetValidationSubset(xs, i),
255 GetValidationSubset(ys, i));
256 if (std::isnan(evaluations(i)) || std::isinf(evaluations(i)))
259 Log::Warn <<
"KFoldCV::TrainAndEvaluate(): fold " << i <<
" returned " 260 <<
"a score of " << evaluations(i) <<
"; ignoring when computing " 261 <<
"the average score." << std::endl;
264 modelPtr.reset(
new MLAlgorithm(std::move(model)));
267 if (numInvalidScores == k)
269 Log::Warn <<
"KFoldCV::TrainAndEvaluate(): all folds returned invalid " 270 <<
"scores! Returning 0.0 as overall score." << std::endl;
274 return arma::mean(evaluations.elem(arma::find_finite(evaluations)));
277 template<
typename MLAlgorithm,
280 typename PredictionsType,
281 typename WeightsType>
282 template<
typename... MLAlgorithmArgs,
bool Enabled,
typename,
typename>
287 WeightsType>::TrainAndEvaluate(
const MLAlgorithmArgs&... args)
289 arma::vec evaluations(k);
291 for (
size_t i = 0; i < k; ++i)
293 MLAlgorithm&& model = (weights.n_elem > 0) ?
294 base.
Train(GetTrainingSubset(xs, i), GetTrainingSubset(ys, i),
295 GetTrainingSubset(weights, i), args...) :
296 base.
Train(GetTrainingSubset(xs, i), GetTrainingSubset(ys, i),
298 evaluations(i) = Metric::Evaluate(model, GetValidationSubset(xs, i),
299 GetValidationSubset(ys, i));
301 modelPtr.reset(
new MLAlgorithm(std::move(model)));
304 return arma::mean(evaluations);
307 template<
typename MLAlgorithm,
310 typename PredictionsType,
311 typename WeightsType>
312 template<
bool Enabled,
typename>
319 MatType xsOrig = xs.cols(0, (k - 1) * binSize + lastBinSize - 1);
320 PredictionsType ysOrig = ys.cols(0, (k - 1) * binSize + lastBinSize - 1);
325 InitKFoldCVMat(xsOrig, xs);
326 InitKFoldCVMat(ysOrig, ys);
329 template<
typename MLAlgorithm,
332 typename PredictionsType,
333 typename WeightsType>
334 template<
bool Enabled,
typename,
typename>
341 MatType xsOrig = xs.cols(0, (k - 1) * binSize + lastBinSize - 1);
342 PredictionsType ysOrig = ys.cols(0, (k - 1) * binSize + lastBinSize - 1);
343 WeightsType weightsOrig;
344 if (weights.n_elem > 0)
345 weightsOrig = weights.cols(0, (k - 1) * binSize + lastBinSize - 1);
348 if (weights.n_elem > 0)
353 InitKFoldCVMat(xsOrig, xs);
354 InitKFoldCVMat(ysOrig, ys);
355 if (weights.n_elem > 0)
356 InitKFoldCVMat(weightsOrig, weights);
359 template<
typename MLAlgorithm,
362 typename PredictionsType,
363 typename WeightsType>
368 WeightsType>::ValidationSubsetFirstCol(
const size_t i)
371 return (i == 0) ? binSize * (k - 1) : binSize * (i - 1);
374 template<
typename MLAlgorithm,
377 typename PredictionsType,
378 typename WeightsType>
379 template<
typename ElementType>
380 arma::Mat<ElementType>
KFoldCV<MLAlgorithm,
384 WeightsType>::GetTrainingSubset(
385 arma::Mat<ElementType>& m,
391 const size_t subsetSize = (i != 0) ? lastBinSize + (k - 2) * binSize :
394 return arma::Mat<ElementType>(m.colptr(binSize * i), m.n_rows, subsetSize,
398 template<
typename MLAlgorithm,
401 typename PredictionsType,
402 typename WeightsType>
403 template<
typename ElementType>
404 arma::Row<ElementType>
KFoldCV<MLAlgorithm,
408 WeightsType>::GetTrainingSubset(
409 arma::Row<ElementType>& r,
415 const size_t subsetSize = (i != 0) ? lastBinSize + (k - 2) * binSize :
418 return arma::Row<ElementType>(r.colptr(binSize * i), subsetSize,
false,
true);
421 template<
typename MLAlgorithm,
424 typename PredictionsType,
425 typename WeightsType>
426 template<
typename ElementType>
427 arma::Mat<ElementType>
KFoldCV<MLAlgorithm,
431 WeightsType>::GetValidationSubset(
432 arma::Mat<ElementType>& m,
435 const size_t subsetSize = (i == 0) ? lastBinSize : binSize;
436 return arma::Mat<ElementType>(m.colptr(ValidationSubsetFirstCol(i)), m.n_rows,
437 subsetSize,
false,
true);
440 template<
typename MLAlgorithm,
443 typename PredictionsType,
444 typename WeightsType>
445 template<
typename ElementType>
446 arma::Row<ElementType>
KFoldCV<MLAlgorithm,
450 WeightsType>::GetValidationSubset(
451 arma::Row<ElementType>& r,
454 const size_t subsetSize = (i == 0) ? lastBinSize : binSize;
455 return arma::Row<ElementType>(r.colptr(ValidationSubsetFirstCol(i)),
456 subsetSize,
false,
true);
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
Definition: dataset_mapper.hpp:41
void Shuffle()
Shuffle the data.
Definition: k_fold_cv_impl.hpp:317
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static MLPACK_EXPORT util::PrefixedOutStream Warn
Prints warning messages prefixed with [WARN ].
Definition: log.hpp:87
static void AssertDataConsistency(const MatType &xs, const PredictionsType &ys)
Assert there is the equal number of data points and predictions.
Definition: cv_base_impl.hpp:108
void ShuffleData(const MatType &inputPoints, const LabelsType &inputLabels, MatType &outputPoints, LabelsType &outputLabels, const std::enable_if_t<!arma::is_SpMat< MatType >::value > *=0, const std::enable_if_t<!arma::is_Cube< MatType >::value > *=0)
Shuffle a dataset and associated labels (or responses).
Definition: shuffle_data.hpp:28
MLAlgorithm Train(const MatType &xs, const PredictionsType &ys, const MLAlgorithmArgs &... args)
Train MLAlgorithm with given data points, predictions, and hyperparameters depending on what CVBase c...
Definition: cv_base_impl.hpp:78
KFoldCV(const size_t k, const MatType &xs, const PredictionsType &ys, const bool shuffle=true)
This constructor can be used for regression algorithms and for binary classification algorithms...
Definition: k_fold_cv_impl.hpp:27
The class KFoldCV implements k-fold cross-validation for regression and classification algorithms...
Definition: k_fold_cv.hpp:65
double Evaluate(const MLAlgorithmArgs &...args)
Run k-fold cross-validation.
Definition: k_fold_cv_impl.hpp:192
static void AssertWeightsConsistency(const MatType &xs, const WeightsType &weights)
Assert weighted learning is supported and there is the equal number of data points and weights...
Definition: cv_base_impl.hpp:122
An auxiliary class for cross-validation.
Definition: cv_base.hpp:39
MLAlgorithm & Model()
Access and modify a model from the last run of k-fold cross-validation.
Definition: k_fold_cv_impl.hpp:206