// NOTE(review): this copy was extracted from a generated documentation page.
// The numbers fused into the start of lines are listing line numbers, and
// several template-parameter, function-parameter and brace lines are missing.
// Confirm against the upstream header before editing the code itself.
//
// The first line below also carries the header's include guard.
//
// SimpleCV constructor for algorithms that need no extra model-construction
// arguments. It builds a default Base() and delegates, together with
// validationSize and the perfectly-forwarded xs / ys, to the SimpleCV
// constructor that takes a Base object as its first argument.
12 #ifndef MLPACK_CORE_CV_SIMPLE_CV_IMPL_HPP 13 #define MLPACK_CORE_CV_SIMPLE_CV_IMPL_HPP 18 template<
typename MLAlgorithm,
21 typename PredictionsType,
23 template<
typename MIT,
typename PIT>
28 WeightsType>::SimpleCV(
const double validationSize,
31 SimpleCV(Base(), validationSize,
std::forward<MIT>(xs),
32 std::forward<PIT>(ys))
// SimpleCV constructor taking the number of classes. It builds
// Base(numClasses) and delegates, with validationSize and the
// perfectly-forwarded xs / ys, to the constructor that takes a Base object
// first.
// NOTE(review): the lines declaring the xs / ys parameters are missing from
// this extracted copy.
35 template<
typename MLAlgorithm,
38 typename PredictionsType,
40 template<
typename MIT,
typename PIT>
45 WeightsType>
::SimpleCV(
const double validationSize,
48 const size_t numClasses) :
49 SimpleCV(Base(numClasses), validationSize,
std::forward<MIT>(xs),
50 std::forward<PIT>(ys))
// SimpleCV constructor taking dataset information and the number of classes.
// It builds Base(datasetInfo, numClasses) and delegates, with validationSize
// and the perfectly-forwarded xs / ys, to the constructor that takes a Base
// object first.
53 template<
typename MLAlgorithm,
56 typename PredictionsType,
58 template<
typename MIT,
typename PIT>
63 WeightsType>
::SimpleCV(
const double validationSize,
65 const data::DatasetInfo& datasetInfo,
67 const size_t numClasses) :
68 SimpleCV(Base(datasetInfo, numClasses), validationSize,
69 std::forward<MIT>(xs),
std::forward<PIT>(ys))
// Weighted SimpleCV constructor with no extra model-construction arguments.
// It builds a default Base() and delegates, with validationSize and the
// perfectly-forwarded xs, ys and weights, to the constructor that takes a
// Base object first plus weights.
72 template<
typename MLAlgorithm,
75 typename PredictionsType,
77 template<
typename MIT,
typename PIT,
typename WIT>
82 WeightsType>
::SimpleCV(
const double validationSize,
86 SimpleCV(Base(), validationSize,
std::forward<MIT>(xs),
87 std::forward<PIT>(ys),
std::forward<WIT>(weights))
// Weighted SimpleCV constructor taking the number of classes. It builds
// Base(numClasses) and delegates, with validationSize and the
// perfectly-forwarded xs, ys and weights, to the constructor that takes a
// Base object first plus weights.
90 template<
typename MLAlgorithm,
93 typename PredictionsType,
95 template<
typename MIT,
typename PIT,
typename WIT>
100 WeightsType>
::SimpleCV(
const double validationSize,
103 const size_t numClasses,
105 SimpleCV(Base(numClasses), validationSize,
std::forward<MIT>(xs),
106 std::forward<PIT>(ys),
std::forward<WIT>(weights))
// Weighted SimpleCV constructor taking dataset information and the number of
// classes. It builds Base(datasetInfo, numClasses) and delegates, with
// validationSize and the perfectly-forwarded xs, ys and weights, to the
// constructor that takes a Base object first plus weights.
109 template<
typename MLAlgorithm,
112 typename PredictionsType,
113 typename WeightsType>
114 template<
typename MIT,
typename PIT,
typename WIT>
119 WeightsType>
::SimpleCV(
const double validationSize,
121 const data::DatasetInfo& datasetInfo,
123 const size_t numClasses,
125 SimpleCV(Base(datasetInfo, numClasses), validationSize,
126 std::forward<MIT>(xs),
std::forward<PIT>(ys),
127 std::forward<WIT>(weights))
// Primary (unweighted) SimpleCV constructor.
//
// Stores the Base object and the data in the members base / xs / ys. It then
// asks CalculateAndAssertNumberOfTrainingPoints how many leading columns
// belong to the training set. The leading columns become trainingXs /
// trainingYs and the remaining columns become validationXs / validationYs,
// all taken from the stored members via GetSubset.
//
// Fix: the last column index of the validation subsets is now computed from
// the member this->xs, not from the unqualified name `xs`. Inside the
// constructor body, `xs` refers to the constructor parameter. The member
// initializer `xs(std::forward<MIT>(xs))` may have moved from that parameter
// when an rvalue was passed. Armadillo's move constructor may leave the
// source empty (n_cols == 0), in which case `xs.n_cols - 1` wraps to SIZE_MAX
// and GetSubset would build a view far past the end of the data.
//
// NOTE(review): this copy is a documentation-page extraction. Parameter,
// brace and argument lines are missing, and listing line numbers are fused
// into the code.
130 template<
typename MLAlgorithm,
133 typename PredictionsType,
134 typename WeightsType>
135 template<
typename MIT,
typename PIT>
141 const double validationSize,
144 base(
std::move(base)),
145 xs(
std::forward<MIT>(xs)),
146 ys(
std::forward<PIT>(ys))
150 size_t numberOfTrainingPoints = CalculateAndAssertNumberOfTrainingPoints(
153 trainingXs = GetSubset(this->xs, 0, numberOfTrainingPoints - 1);
154 trainingYs = GetSubset(this->ys, 0, numberOfTrainingPoints - 1);
// Use the stored member: the parameter `xs` may have been moved from above.
156 validationXs = GetSubset(this->xs, numberOfTrainingPoints, this->xs.n_cols - 1);
157 validationYs = GetSubset(this->ys, numberOfTrainingPoints, this->xs.n_cols - 1);
// Primary weighted SimpleCV constructor.
//
// Forwards ys onward; presumably this is a delegation to the unweighted
// primary constructor, which performs the train/validation split, but the
// target is on lines missing from this copy (TODO confirm). It then stores the
// forwarded weights in this->weights. trainingWeights is taken as the first
// trainingXs.n_cols entries of this->weights, so the weights line up
// column-for-column with the training points.
// NOTE(review): no validation-weights subset is computed in the visible lines.
160 template<
typename MLAlgorithm,
163 typename PredictionsType,
164 typename WeightsType>
165 template<
typename MIT,
typename PIT,
typename WIT>
171 const double validationSize,
176 std::forward<PIT>(ys))
178 this->weights = std::forward<WIT>(weights);
182 trainingWeights = GetSubset(this->weights, 0, trainingXs.n_cols - 1);
// Evaluate(args...): trains a model on the training subset and scores it on
// the validation subset. The extra arguments are passed through to
// TrainAndEvaluate, whose result is returned unchanged. The overload used
// (weighted or unweighted) is chosen by TrainAndEvaluate's template parameters.
185 template<
typename MLAlgorithm,
188 typename PredictionsType,
189 typename WeightsType>
190 template<
typename... MLAlgorithmArgs>
197 return TrainAndEvaluate(args...);
// Model(): gives access to the most recently trained model.
// Throws std::logic_error if no model has been trained yet (modelPtr is
// null). The return statement is on a line missing from this copy.
200 template<
typename MLAlgorithm,
203 typename PredictionsType,
204 typename WeightsType>
211 if (modelPtr ==
nullptr)
212 throw std::logic_error(
213 "SimpleCV::Model(): attempted to access an uninitialized model");
// Returns how many of the leading columns of xs form the training set:
// round(xs.n_cols * (1 - validationSize)).
//
// Throws std::invalid_argument when:
//  - validationSize is not strictly between 0 and 1 (NaN included);
//  - too few data points are present ("2 or more data points are expected";
//    the guarding condition is on a line missing from this copy);
//  - the rounded split would leave the training set or the validation set
//    empty.
//
// Fix: the range check is now written so that NaN, 0 and 1 are all rejected
// by it, matching its message ("more than 0 and less than 1"). Before, NaN
// failed both `<` and `>` comparisons and reached round(). Converting the
// resulting NaN to size_t is undefined behavior. The exception type for 0 and
// 1 is unchanged; only the message differs.
218 template<
typename MLAlgorithm,
221 typename PredictionsType,
222 typename WeightsType>
227 WeightsType>::CalculateAndAssertNumberOfTrainingPoints(
228 const double validationSize)
230 if (!(validationSize > 0.0 && validationSize < 1.0))
231 throw std::invalid_argument(
"SimpleCV: the validationSize parameter should " 232 "be more than 0 and less than 1");
235 throw std::invalid_argument(
"SimpleCV: 2 or more data points are expected");
237 size_t trainingPoints = round(xs.n_cols * (1.0 - validationSize));
// Rounding can still put every point (or none) into one set.
239 if (trainingPoints == 0 || trainingPoints == xs.n_cols)
240 throw std::invalid_argument(
"SimpleCV: the validationSize parameter is " 241 "either too small or too big");
243 return trainingPoints;
// GetSubset (matrix overload): returns the columns firstCol..lastCol
// (inclusive) of m as an arma::Mat with m.n_rows rows and
// lastCol - firstCol + 1 columns.
// It is built with Armadillo's auxiliary-memory constructor over
// m.colptr(firstCol), with copy_aux_mem = false and strict = true. So at
// construction it uses m's storage directly instead of copying it, which is
// presumably why m is taken by non-const reference.
// NOTE(review): the caller must ensure firstCol <= lastCol < m.n_cols; no
// bounds check is visible here.
246 template<
typename MLAlgorithm,
249 typename PredictionsType,
250 typename WeightsType>
251 template<
typename ElementType>
252 arma::Mat<ElementType>
SimpleCV<MLAlgorithm,
256 WeightsType>::GetSubset(
257 arma::Mat<ElementType>& m,
258 const size_t firstCol,
259 const size_t lastCol)
261 return arma::Mat<ElementType>(m.colptr(firstCol), m.n_rows,
262 lastCol - firstCol + 1,
false,
true);
// GetSubset (row-vector overload): returns elements firstCol..lastCol
// (inclusive) of r as an arma::Row of length lastCol - firstCol + 1, built
// over r.colptr(firstCol).
// NOTE(review): the rest of the return statement is missing from this
// extracted copy. It presumably mirrors the matrix overload's
// auxiliary-memory flags — confirm against upstream.
265 template<
typename MLAlgorithm,
268 typename PredictionsType,
269 typename WeightsType>
270 template<
typename ElementType>
271 arma::Row<ElementType>
SimpleCV<MLAlgorithm,
275 WeightsType>::GetSubset(
276 arma::Row<ElementType>& r,
277 const size_t firstCol,
278 const size_t lastCol)
280 return arma::Row<ElementType>(r.colptr(firstCol), lastCol - firstCol + 1,
// TrainAndEvaluate (unweighted overload): selected through the defaulted
// `Enabled` / unnamed template parameters. The enabling condition is on
// lines missing from this copy.
// It trains an MLAlgorithm on trainingXs / trainingYs (plus args) through
// base.Train and stores a heap copy of the result in modelPtr, replacing any
// previous model. It then returns Metric::Evaluate of that model on
// validationXs / validationYs.
284 template<
typename MLAlgorithm,
287 typename PredictionsType,
288 typename WeightsType>
289 template<
typename... MLAlgorithmArgs,
bool Enabled,
typename>
294 WeightsType>::TrainAndEvaluate(
const MLAlgorithmArgs&... args)
296 modelPtr.reset(
new MLAlgorithm(base.
Train(trainingXs, trainingYs, args...)));
298 return Metric::Evaluate(*modelPtr, validationXs, validationYs);
// TrainAndEvaluate (overload that can use weights): selected through the
// extra defaulted template parameters. The enabling conditions are on lines
// missing from this copy.
// If trainingWeights is non-empty, the model is trained with
// trainingXs / trainingYs / trainingWeights (plus args). Otherwise it is
// trained without weights. The trained model replaces modelPtr's previous
// model and is scored with Metric::Evaluate on the validation subsets.
301 template<
typename MLAlgorithm,
304 typename PredictionsType,
305 typename WeightsType>
306 template<
typename... MLAlgorithmArgs,
bool Enabled,
typename,
typename>
311 WeightsType>::TrainAndEvaluate(
const MLAlgorithmArgs&... args)
313 if (trainingWeights.n_elem > 0)
314 modelPtr.reset(
new MLAlgorithm(
315 base.
Train(trainingXs, trainingYs, trainingWeights, args...)));
// NOTE(review): the listing line between the two resets is missing in this
// copy (presumably `else`). Without it, the unweighted model would always
// overwrite the weighted one — confirm against upstream.
317 modelPtr.reset(
new MLAlgorithm(
318 base.
Train(trainingXs, trainingYs, args...)));
320 return Metric::Evaluate(*modelPtr, validationXs, validationYs);
SimpleCV splits data into two sets (a training set and a validation set), trains on the training set, and then evaluates performance on the validation set.
Definition: simple_cv.hpp:68
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
SimpleCV(const double validationSize, MatInType &&xs, PredictionsInType &&ys)
This constructor can be used for regression algorithms and for binary classification algorithms...
Definition: pointer_wrapper.hpp:23
double Evaluate(const MLAlgorithmArgs &... args)
Train on the training set and assess performance on the validation set by using the class Metric...
Definition: simple_cv_impl.hpp:195
static void AssertDataConsistency(const MatType &xs, const PredictionsType &ys)
Assert that the numbers of data points and predictions are equal.
Definition: cv_base_impl.hpp:108
MLAlgorithm Train(const MatType &xs, const PredictionsType &ys, const MLAlgorithmArgs &... args)
Train MLAlgorithm with given data points, predictions, and hyperparameters depending on what CVBase c...
Definition: cv_base_impl.hpp:78
static void AssertWeightsConsistency(const MatType &xs, const WeightsType &weights)
Assert that weighted learning is supported and that the numbers of data points and weights are equal...
Definition: cv_base_impl.hpp:122
MLAlgorithm & Model()
Access and modify the last trained model.
Definition: simple_cv_impl.hpp:209