mlpack
cv_base_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_CV_CV_BASE_IMPL_HPP
13 #define MLPACK_CORE_CV_CV_BASE_IMPL_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace cv {
19 
20 template<typename MLAlgorithm,
21  typename MatType,
22  typename PredictionsType,
23  typename WeightsType>
24 CVBase<MLAlgorithm,
25  MatType,
26  PredictionsType,
27  WeightsType>::CVBase() :
28  isDatasetInfoPassed(false),
29  numClasses(0)
30 {
31  static_assert(!MIE::TakesNumClasses,
32  "The given MLAlgorithm requires the numClasses parameter; "
33  "make sure that you pass numClasses with type size_t!");
34 }
35 
36 template<typename MLAlgorithm,
37  typename MatType,
38  typename PredictionsType,
39  typename WeightsType>
40 CVBase<MLAlgorithm,
41  MatType,
42  PredictionsType,
43  WeightsType>::CVBase(const size_t numClasses) :
44  isDatasetInfoPassed(false),
45  numClasses(numClasses)
46 {
47  static_assert(MIE::TakesNumClasses,
48  "The given MLAlgorithm does not take the numClasses parameter");
49 }
50 
51 template<typename MLAlgorithm,
52  typename MatType,
53  typename PredictionsType,
54  typename WeightsType>
55 CVBase<MLAlgorithm,
56  MatType,
57  PredictionsType,
58  WeightsType>::CVBase(const data::DatasetInfo& datasetInfo,
59  const size_t numClasses) :
60  datasetInfo(datasetInfo),
61  isDatasetInfoPassed(true),
62  numClasses(numClasses)
63 {
64  static_assert(MIE::TakesNumClasses,
65  "The given MLAlgorithm does not take the numClasses parameter");
66  static_assert(MIE::TakesDatasetInfo,
67  "The given MLAlgorithm does not accept a data::DatasetInfo parameter");
68 }
69 
70 template<typename MLAlgorithm,
71  typename MatType,
72  typename PredictionsType,
73  typename WeightsType>
74 template<typename... MLAlgorithmArgs>
75 MLAlgorithm CVBase<MLAlgorithm,
76  MatType,
77  PredictionsType,
78  WeightsType>::Train(const MatType& xs,
79  const PredictionsType& ys,
80  const MLAlgorithmArgs&... args)
81 {
82  return TrainModel(xs, ys, args...);
83 }
84 
85 template<typename MLAlgorithm,
86  typename MatType,
87  typename PredictionsType,
88  typename WeightsType>
89 template<typename... MLAlgorithmArgs>
90 MLAlgorithm CVBase<MLAlgorithm,
91  MatType,
92  PredictionsType,
93  WeightsType>::Train(const MatType& xs,
94  const PredictionsType& ys,
95  const WeightsType& weights,
96  const MLAlgorithmArgs&... args)
97 {
98  return TrainModel(xs, ys, weights, args...);
99 }
100 
101 template<typename MLAlgorithm,
102  typename MatType,
103  typename PredictionsType,
104  typename WeightsType>
105 void CVBase<MLAlgorithm,
106  MatType,
107  PredictionsType,
108  WeightsType>::AssertDataConsistency(const MatType& xs,
109  const PredictionsType& ys)
110 {
111  util::CheckSameSizes(xs, (size_t) ys.n_cols,
112  "CVBase::AssertDataConsistency()", "predictions");
113 }
114 
115 template<typename MLAlgorithm,
116  typename MatType,
117  typename PredictionsType,
118  typename WeightsType>
119 void CVBase<MLAlgorithm,
120  MatType,
121  PredictionsType,
122  WeightsType>::AssertWeightsConsistency(const MatType& xs,
123  const WeightsType& weights)
124 {
125  static_assert(MIE::SupportsWeights,
126  "The given MLAlgorithm does not support weighted learning");
127 
128  util::CheckSameSizes(xs, weights, "CVBase::AssertWeightsConsistency()",
129  "weights");
130 }
131 
132 template<typename MLAlgorithm,
133  typename MatType,
134  typename PredictionsType,
135  typename WeightsType>
136 template<typename... MLAlgorithmArgs, bool Enabled, typename>
137 MLAlgorithm CVBase<MLAlgorithm,
138  MatType,
139  PredictionsType,
140  WeightsType>::TrainModel(const MatType& xs,
141  const PredictionsType& ys,
142  const MLAlgorithmArgs&... args)
143 {
144  static_assert(
145  std::is_constructible<MLAlgorithm, const MatType&, const PredictionsType&,
146  MLAlgorithmArgs...>::value,
147  "The given MLAlgorithm is not constructible from the passed arguments");
148 
149  return MLAlgorithm(xs, ys, args...);
150 }
151 
152 template<typename MLAlgorithm,
153  typename MatType,
154  typename PredictionsType,
155  typename WeightsType>
156 template<typename... MLAlgorithmArgs, bool Enabled, typename, typename>
157 MLAlgorithm CVBase<MLAlgorithm,
158  MatType,
159  PredictionsType,
160  WeightsType>::TrainModel(const MatType& xs,
161  const PredictionsType& ys,
162  const MLAlgorithmArgs&... args)
163 {
164  static_assert(
165  std::is_constructible<MLAlgorithm, const MatType&, const PredictionsType&,
166  const size_t, MLAlgorithmArgs...>::value,
167  "The given MLAlgorithm is not constructible from the passed arguments");
168 
169  return MLAlgorithm(xs, ys, numClasses, args...);
170 }
171 
172 template<typename MLAlgorithm,
173  typename MatType,
174  typename PredictionsType,
175  typename WeightsType>
176 template<typename... MLAlgorithmArgs, bool Enabled, typename, typename,
177  typename>
178 MLAlgorithm CVBase<MLAlgorithm,
179  MatType,
180  PredictionsType,
181  WeightsType>::TrainModel(const MatType& xs,
182  const PredictionsType& ys,
183  const MLAlgorithmArgs&... args)
184 {
185  static_assert(
186  std::is_constructible<MLAlgorithm, const MatType&,
187  const data::DatasetInfo, const PredictionsType&, const size_t,
188  MLAlgorithmArgs...>::value,
189  "The given MLAlgorithm is not constructible with a data::DatasetInfo "
190  "parameter and the passed arguments");
191 
192  static const bool constructableWithoutDatasetInfo =
193  std::is_constructible<MLAlgorithm, const MatType&, const PredictionsType&,
194  const size_t, MLAlgorithmArgs...>::value;
195  return TrainModel<constructableWithoutDatasetInfo>(xs, ys, args...);
196 }
197 
198 template<typename MLAlgorithm,
199  typename MatType,
200  typename PredictionsType,
201  typename WeightsType>
202 template<typename... MLAlgorithmArgs, bool Enabled, typename>
203 MLAlgorithm CVBase<MLAlgorithm,
204  MatType,
205  PredictionsType,
206  WeightsType>::TrainModel(const MatType& xs,
207  const PredictionsType& ys,
208  const WeightsType& weights,
209  const MLAlgorithmArgs&... args)
210 {
211  static_assert(
212  std::is_constructible<MLAlgorithm, const MatType&, const PredictionsType&,
213  const WeightsType&, MLAlgorithmArgs...>::value,
214  "The given MLAlgorithm is not constructible from the passed arguments");
215 
216  return MLAlgorithm(xs, ys, weights, args...);
217 }
218 
219 template<typename MLAlgorithm,
220  typename MatType,
221  typename PredictionsType,
222  typename WeightsType>
223 template<typename... MLAlgorithmArgs, bool Enabled, typename, typename>
224 MLAlgorithm CVBase<MLAlgorithm,
225  MatType,
226  PredictionsType,
227  WeightsType>::TrainModel(const MatType& xs,
228  const PredictionsType& ys,
229  const WeightsType& weights,
230  const MLAlgorithmArgs&... args)
231 {
232  static_assert(
233  std::is_constructible<MLAlgorithm, const MatType&, const PredictionsType&,
234  const size_t, const WeightsType&, MLAlgorithmArgs...>::value,
235  "The given MLAlgorithm is not constructible from the passed arguments");
236 
237  return MLAlgorithm(xs, ys, numClasses, weights, args...);
238 }
239 
240 template<typename MLAlgorithm,
241  typename MatType,
242  typename PredictionsType,
243  typename WeightsType>
244 template<typename... MLAlgorithmArgs, bool Enabled, typename, typename,
245  typename>
246 MLAlgorithm CVBase<MLAlgorithm,
247  MatType,
248  PredictionsType,
249  WeightsType>::TrainModel(const MatType& xs,
250  const PredictionsType& ys,
251  const WeightsType& weights,
252  const MLAlgorithmArgs&... args)
253 {
254  static_assert(
255  std::is_constructible<MLAlgorithm, const MatType&,
256  const data::DatasetInfo, const PredictionsType&, const size_t,
257  const WeightsType&, MLAlgorithmArgs...>::value,
258  "The given MLAlgorithm is not constructible with a data::DatasetInfo "
259  "parameter and the passed arguments");
260 
261  static const bool constructableWithoutDatasetInfo =
262  std::is_constructible<MLAlgorithm, const MatType&, const PredictionsType&,
263  const size_t, const WeightsType&, MLAlgorithmArgs...>::value;
264  return TrainModel<constructableWithoutDatasetInfo>(xs, ys, weights, args...);
265 }
266 
267 template<typename MLAlgorithm,
268  typename MatType,
269  typename PredictionsType,
270  typename WeightsType>
271 template<bool ConstructableWithoutDatasetInfo, typename... MLAlgorithmArgs,
272  typename>
273 MLAlgorithm CVBase<MLAlgorithm,
274  MatType,
275  PredictionsType,
276  WeightsType>::TrainModel(const MatType& xs,
277  const PredictionsType& ys,
278  const MLAlgorithmArgs&... args)
279 {
280  if (isDatasetInfoPassed)
281  return MLAlgorithm(xs, datasetInfo, ys, numClasses, args...);
282  else
283  return MLAlgorithm(xs, ys, numClasses, args...);
284 }
285 
286 template<typename MLAlgorithm,
287  typename MatType,
288  typename PredictionsType,
289  typename WeightsType>
290 template<bool ConstructableWithoutDatasetInfo, typename... MLAlgorithmArgs,
291  typename, typename>
292 MLAlgorithm CVBase<MLAlgorithm,
293  MatType,
294  PredictionsType,
295  WeightsType>::TrainModel(const MatType& xs,
296  const PredictionsType& ys,
297  const MLAlgorithmArgs&... args)
298 {
299  if (!isDatasetInfoPassed)
300  throw std::invalid_argument(
301  "The given MLAlgorithm requires a data::DatasetInfo parameter");
302 
303  return MLAlgorithm(xs, datasetInfo, ys, numClasses, args...);
304 }
305 
306 } // namespace cv
307 } // namespace mlpack
308 
309 #endif
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
Definition: dataset_mapper.hpp:41
static const bool SupportsWeights
An indication whether MLAlgorithm supports weighted learning.
Definition: meta_info_extractor.hpp:337
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.
static const bool TakesDatasetInfo
An indication whether MLAlgorithm takes a data::DatasetInfo parameter.
Definition: meta_info_extractor.hpp:342
static void AssertDataConsistency(const MatType &xs, const PredictionsType &ys)
Assert that the number of data points equals the number of predictions.
Definition: cv_base_impl.hpp:108
CVBase()
Assert that MLAlgorithm doesn't take any additional basic parameters like numClasses.
Definition: cv_base_impl.hpp:27
MLAlgorithm Train(const MatType &xs, const PredictionsType &ys, const MLAlgorithmArgs &... args)
Train MLAlgorithm with given data points, predictions, and hyperparameters depending on what CVBase c...
Definition: cv_base_impl.hpp:78
static void AssertWeightsConsistency(const MatType &xs, const WeightsType &weights)
Assert that weighted learning is supported and that the number of data points equals the number of weights...
Definition: cv_base_impl.hpp:122
An auxiliary class for cross-validation.
Definition: cv_base.hpp:39
static const bool TakesNumClasses
An indication whether MLAlgorithm takes the numClasses (size_t) parameter.
Definition: meta_info_extractor.hpp:347