12 #ifndef MLPACK_CORE_CV_CV_BASE_IMPL_HPP 13 #define MLPACK_CORE_CV_CV_BASE_IMPL_HPP 20 template<
typename MLAlgorithm,
22 typename PredictionsType,
28 isDatasetInfoPassed(false),
32 "The given MLAlgorithm requires the numClasses parameter; " 33 "make sure that you pass numClasses with type size_t!");
36 template<
typename MLAlgorithm,
38 typename PredictionsType,
43 WeightsType>
::CVBase(
const size_t numClasses) :
44 isDatasetInfoPassed(false),
45 numClasses(numClasses)
48 "The given MLAlgorithm does not take the numClasses parameter");
51 template<
typename MLAlgorithm,
53 typename PredictionsType,
59 const size_t numClasses) :
60 datasetInfo(datasetInfo),
61 isDatasetInfoPassed(true),
62 numClasses(numClasses)
65 "The given MLAlgorithm does not take the numClasses parameter");
67 "The given MLAlgorithm does not accept a data::DatasetInfo parameter");
70 template<
typename MLAlgorithm,
72 typename PredictionsType,
74 template<
typename... MLAlgorithmArgs>
75 MLAlgorithm
CVBase<MLAlgorithm,
79 const PredictionsType& ys,
80 const MLAlgorithmArgs&... args)
82 return TrainModel(xs, ys, args...);
85 template<
typename MLAlgorithm,
87 typename PredictionsType,
89 template<
typename... MLAlgorithmArgs>
90 MLAlgorithm
CVBase<MLAlgorithm,
94 const PredictionsType& ys,
95 const WeightsType& weights,
96 const MLAlgorithmArgs&... args)
98 return TrainModel(xs, ys, weights, args...);
101 template<
typename MLAlgorithm,
103 typename PredictionsType,
104 typename WeightsType>
109 const PredictionsType& ys)
111 util::CheckSameSizes(xs, (
size_t) ys.n_cols,
112 "CVBase::AssertDataConsistency()",
"predictions");
115 template<
typename MLAlgorithm,
117 typename PredictionsType,
118 typename WeightsType>
123 const WeightsType& weights)
126 "The given MLAlgorithm does not support weighted learning");
128 util::CheckSameSizes(xs, weights,
"CVBase::AssertWeightsConsistency()",
132 template<
typename MLAlgorithm,
134 typename PredictionsType,
135 typename WeightsType>
136 template<
typename... MLAlgorithmArgs,
bool Enabled,
typename>
137 MLAlgorithm
CVBase<MLAlgorithm,
140 WeightsType>::TrainModel(
const MatType& xs,
141 const PredictionsType& ys,
142 const MLAlgorithmArgs&... args)
145 std::is_constructible<MLAlgorithm,
const MatType&,
const PredictionsType&,
146 MLAlgorithmArgs...>::value,
147 "The given MLAlgorithm is not constructible from the passed arguments");
149 return MLAlgorithm(xs, ys, args...);
152 template<
typename MLAlgorithm,
154 typename PredictionsType,
155 typename WeightsType>
156 template<
typename... MLAlgorithmArgs,
bool Enabled,
typename,
typename>
157 MLAlgorithm
CVBase<MLAlgorithm,
160 WeightsType>::TrainModel(
const MatType& xs,
161 const PredictionsType& ys,
162 const MLAlgorithmArgs&... args)
165 std::is_constructible<MLAlgorithm,
const MatType&,
const PredictionsType&,
166 const size_t, MLAlgorithmArgs...>::value,
167 "The given MLAlgorithm is not constructible from the passed arguments");
169 return MLAlgorithm(xs, ys, numClasses, args...);
172 template<
typename MLAlgorithm,
174 typename PredictionsType,
175 typename WeightsType>
176 template<
typename... MLAlgorithmArgs,
bool Enabled,
typename,
typename,
178 MLAlgorithm
CVBase<MLAlgorithm,
181 WeightsType>::TrainModel(
const MatType& xs,
182 const PredictionsType& ys,
183 const MLAlgorithmArgs&... args)
186 std::is_constructible<MLAlgorithm,
const MatType&,
188 MLAlgorithmArgs...>::value,
189 "The given MLAlgorithm is not constructible with a data::DatasetInfo " 190 "parameter and the passed arguments");
192 static const bool constructableWithoutDatasetInfo =
193 std::is_constructible<MLAlgorithm,
const MatType&,
const PredictionsType&,
194 const size_t, MLAlgorithmArgs...>::value;
195 return TrainModel<constructableWithoutDatasetInfo>(xs, ys, args...);
198 template<
typename MLAlgorithm,
200 typename PredictionsType,
201 typename WeightsType>
202 template<
typename... MLAlgorithmArgs,
bool Enabled,
typename>
203 MLAlgorithm
CVBase<MLAlgorithm,
206 WeightsType>::TrainModel(
const MatType& xs,
207 const PredictionsType& ys,
208 const WeightsType& weights,
209 const MLAlgorithmArgs&... args)
212 std::is_constructible<MLAlgorithm,
const MatType&,
const PredictionsType&,
213 const WeightsType&, MLAlgorithmArgs...>::value,
214 "The given MLAlgorithm is not constructible from the passed arguments");
216 return MLAlgorithm(xs, ys, weights, args...);
219 template<
typename MLAlgorithm,
221 typename PredictionsType,
222 typename WeightsType>
223 template<
typename... MLAlgorithmArgs,
bool Enabled,
typename,
typename>
224 MLAlgorithm
CVBase<MLAlgorithm,
227 WeightsType>::TrainModel(
const MatType& xs,
228 const PredictionsType& ys,
229 const WeightsType& weights,
230 const MLAlgorithmArgs&... args)
233 std::is_constructible<MLAlgorithm,
const MatType&,
const PredictionsType&,
234 const size_t,
const WeightsType&, MLAlgorithmArgs...>::value,
235 "The given MLAlgorithm is not constructible from the passed arguments");
237 return MLAlgorithm(xs, ys, numClasses, weights, args...);
240 template<
typename MLAlgorithm,
242 typename PredictionsType,
243 typename WeightsType>
244 template<
typename... MLAlgorithmArgs,
bool Enabled,
typename,
typename,
246 MLAlgorithm
CVBase<MLAlgorithm,
249 WeightsType>::TrainModel(
const MatType& xs,
250 const PredictionsType& ys,
251 const WeightsType& weights,
252 const MLAlgorithmArgs&... args)
255 std::is_constructible<MLAlgorithm,
const MatType&,
257 const WeightsType&, MLAlgorithmArgs...>::value,
258 "The given MLAlgorithm is not constructible with a data::DatasetInfo " 259 "parameter and the passed arguments");
261 static const bool constructableWithoutDatasetInfo =
262 std::is_constructible<MLAlgorithm,
const MatType&,
const PredictionsType&,
263 const size_t,
const WeightsType&, MLAlgorithmArgs...>::value;
264 return TrainModel<constructableWithoutDatasetInfo>(xs, ys, weights, args...);
267 template<
typename MLAlgorithm,
269 typename PredictionsType,
270 typename WeightsType>
271 template<
bool ConstructableWithoutDatasetInfo,
typename... MLAlgorithmArgs,
273 MLAlgorithm
CVBase<MLAlgorithm,
276 WeightsType>::TrainModel(
const MatType& xs,
277 const PredictionsType& ys,
278 const MLAlgorithmArgs&... args)
280 if (isDatasetInfoPassed)
281 return MLAlgorithm(xs, datasetInfo, ys, numClasses, args...);
283 return MLAlgorithm(xs, ys, numClasses, args...);
286 template<
typename MLAlgorithm,
288 typename PredictionsType,
289 typename WeightsType>
290 template<
bool ConstructableWithoutDatasetInfo,
typename... MLAlgorithmArgs,
292 MLAlgorithm
CVBase<MLAlgorithm,
295 WeightsType>::TrainModel(
const MatType& xs,
296 const PredictionsType& ys,
297 const MLAlgorithmArgs&... args)
299 if (!isDatasetInfoPassed)
300 throw std::invalid_argument(
301 "The given MLAlgorithm requires a data::DatasetInfo parameter");
303 return MLAlgorithm(xs, datasetInfo, ys, numClasses, args...);
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
Definition: dataset_mapper.hpp:41
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.
static void AssertDataConsistency(const MatType &xs, const PredictionsType &ys)
Assert that there are equal numbers of data points and predictions.
Definition: cv_base_impl.hpp:108
CVBase()
Assert that MLAlgorithm doesn't take any additional basic parameters like numClasses.
Definition: cv_base_impl.hpp:27
MLAlgorithm Train(const MatType &xs, const PredictionsType &ys, const MLAlgorithmArgs &... args)
Train MLAlgorithm with given data points, predictions, and hyperparameters depending on what CVBase c...
Definition: cv_base_impl.hpp:78
static void AssertWeightsConsistency(const MatType &xs, const WeightsType &weights)
Assert that weighted learning is supported and that there are equal numbers of data points and weights...
Definition: cv_base_impl.hpp:122
An auxiliary class for cross-validation.
Definition: cv_base.hpp:39