mlpack
simple_cv_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_CV_SIMPLE_CV_IMPL_HPP
13 #define MLPACK_CORE_CV_SIMPLE_CV_IMPL_HPP
14 
15 namespace mlpack {
16 namespace cv {
17 
18 template<typename MLAlgorithm,
19  typename Metric,
20  typename MatType,
21  typename PredictionsType,
22  typename WeightsType>
23 template<typename MIT, typename PIT>
24 SimpleCV<MLAlgorithm,
25  Metric,
26  MatType,
27  PredictionsType,
28  WeightsType>::SimpleCV(const double validationSize,
29  MIT&& xs,
30  PIT&& ys) :
31  SimpleCV(Base(), validationSize, std::forward<MIT>(xs),
32  std::forward<PIT>(ys))
33 { /* Nothing left to do. */ }
34 
35 template<typename MLAlgorithm,
36  typename Metric,
37  typename MatType,
38  typename PredictionsType,
39  typename WeightsType>
40 template<typename MIT, typename PIT>
41 SimpleCV<MLAlgorithm,
42  Metric,
43  MatType,
44  PredictionsType,
45  WeightsType>::SimpleCV(const double validationSize,
46  MIT&& xs,
47  PIT&& ys,
48  const size_t numClasses) :
49  SimpleCV(Base(numClasses), validationSize, std::forward<MIT>(xs),
50  std::forward<PIT>(ys))
51 { /* Nothing left to do. */ }
52 
53 template<typename MLAlgorithm,
54  typename Metric,
55  typename MatType,
56  typename PredictionsType,
57  typename WeightsType>
58 template<typename MIT, typename PIT>
59 SimpleCV<MLAlgorithm,
60  Metric,
61  MatType,
62  PredictionsType,
63  WeightsType>::SimpleCV(const double validationSize,
64  MIT&& xs,
65  const data::DatasetInfo& datasetInfo,
66  PIT&& ys,
67  const size_t numClasses) :
68  SimpleCV(Base(datasetInfo, numClasses), validationSize,
69  std::forward<MIT>(xs), std::forward<PIT>(ys))
70 { /* Nothing left to do. */ }
71 
72 template<typename MLAlgorithm,
73  typename Metric,
74  typename MatType,
75  typename PredictionsType,
76  typename WeightsType>
77 template<typename MIT, typename PIT, typename WIT>
78 SimpleCV<MLAlgorithm,
79  Metric,
80  MatType,
81  PredictionsType,
82  WeightsType>::SimpleCV(const double validationSize,
83  MIT&& xs,
84  PIT&& ys,
85  WIT&& weights) :
86  SimpleCV(Base(), validationSize, std::forward<MIT>(xs),
87  std::forward<PIT>(ys), std::forward<WIT>(weights))
88 { /* Nothing left to do. */ }
89 
90 template<typename MLAlgorithm,
91  typename Metric,
92  typename MatType,
93  typename PredictionsType,
94  typename WeightsType>
95 template<typename MIT, typename PIT, typename WIT>
96 SimpleCV<MLAlgorithm,
97  Metric,
98  MatType,
99  PredictionsType,
100  WeightsType>::SimpleCV(const double validationSize,
101  MIT&& xs,
102  PIT&& ys,
103  const size_t numClasses,
104  WIT&& weights) :
105  SimpleCV(Base(numClasses), validationSize, std::forward<MIT>(xs),
106  std::forward<PIT>(ys), std::forward<WIT>(weights))
107 { /* Nothing left to do. */ }
108 
109 template<typename MLAlgorithm,
110  typename Metric,
111  typename MatType,
112  typename PredictionsType,
113  typename WeightsType>
114 template<typename MIT, typename PIT, typename WIT>
115 SimpleCV<MLAlgorithm,
116  Metric,
117  MatType,
118  PredictionsType,
119  WeightsType>::SimpleCV(const double validationSize,
120  MIT&& xs,
121  const data::DatasetInfo& datasetInfo,
122  PIT&& ys,
123  const size_t numClasses,
124  WIT&& weights) :
125  SimpleCV(Base(datasetInfo, numClasses), validationSize,
126  std::forward<MIT>(xs), std::forward<PIT>(ys),
127  std::forward<WIT>(weights))
128 { /* Nothing left to do. */ }
129 
130 template<typename MLAlgorithm,
131  typename Metric,
132  typename MatType,
133  typename PredictionsType,
134  typename WeightsType>
135 template<typename MIT, typename PIT>
136 SimpleCV<MLAlgorithm,
137  Metric,
138  MatType,
139  PredictionsType,
140  WeightsType>::SimpleCV(Base&& base,
141  const double validationSize,
142  MIT&& xs,
143  PIT&& ys) :
144  base(std::move(base)),
145  xs(std::forward<MIT>(xs)),
146  ys(std::forward<PIT>(ys))
147 {
148  Base::AssertDataConsistency(this->xs, this->ys);
149 
150  size_t numberOfTrainingPoints = CalculateAndAssertNumberOfTrainingPoints(
151  validationSize);
152 
153  trainingXs = GetSubset(this->xs, 0, numberOfTrainingPoints - 1);
154  trainingYs = GetSubset(this->ys, 0, numberOfTrainingPoints - 1);
155 
156  validationXs = GetSubset(this->xs, numberOfTrainingPoints, xs.n_cols - 1);
157  validationYs = GetSubset(this->ys, numberOfTrainingPoints, xs.n_cols - 1);
158 }
159 
160 template<typename MLAlgorithm,
161  typename Metric,
162  typename MatType,
163  typename PredictionsType,
164  typename WeightsType>
165 template<typename MIT, typename PIT, typename WIT>
166 SimpleCV<MLAlgorithm,
167  Metric,
168  MatType,
169  PredictionsType,
170  WeightsType>::SimpleCV(Base&& base,
171  const double validationSize,
172  MIT&& xs,
173  PIT&& ys,
174  WIT&& weights) :
175  SimpleCV(std::move(base), validationSize, std::forward<MIT>(xs),
176  std::forward<PIT>(ys))
177 {
178  this->weights = std::forward<WIT>(weights);
179 
180  Base::AssertWeightsConsistency(this->xs, this->weights);
181 
182  trainingWeights = GetSubset(this->weights, 0, trainingXs.n_cols - 1);
183 }
184 
185 template<typename MLAlgorithm,
186  typename Metric,
187  typename MatType,
188  typename PredictionsType,
189  typename WeightsType>
190 template<typename... MLAlgorithmArgs>
191 double SimpleCV<MLAlgorithm,
192  Metric,
193  MatType,
194  PredictionsType,
195  WeightsType>::Evaluate(const MLAlgorithmArgs&... args)
196 {
197  return TrainAndEvaluate(args...);
198 }
199 
200 template<typename MLAlgorithm,
201  typename Metric,
202  typename MatType,
203  typename PredictionsType,
204  typename WeightsType>
205 MLAlgorithm& SimpleCV<MLAlgorithm,
206  Metric,
207  MatType,
208  PredictionsType,
209  WeightsType>::Model()
210 {
211  if (modelPtr == nullptr)
212  throw std::logic_error(
213  "SimpleCV::Model(): attempted to access an uninitialized model");
214 
215  return *modelPtr;
216 }
217 
218 template<typename MLAlgorithm,
219  typename Metric,
220  typename MatType,
221  typename PredictionsType,
222  typename WeightsType>
223 size_t SimpleCV<MLAlgorithm,
224  Metric,
225  MatType,
226  PredictionsType,
227  WeightsType>::CalculateAndAssertNumberOfTrainingPoints(
228  const double validationSize)
229 {
230  if (validationSize < 0.0 || validationSize > 1.0)
231  throw std::invalid_argument("SimpleCV: the validationSize parameter should "
232  "be more than 0 and less than 1");
233 
234  if (xs.n_cols < 2)
235  throw std::invalid_argument("SimpleCV: 2 or more data points are expected");
236 
237  size_t trainingPoints = round(xs.n_cols * (1.0 - validationSize));
238 
239  if (trainingPoints == 0 || trainingPoints == xs.n_cols)
240  throw std::invalid_argument("SimpleCV: the validationSize parameter is "
241  "either too small or too big");
242 
243  return trainingPoints;
244 }
245 
246 template<typename MLAlgorithm,
247  typename Metric,
248  typename MatType,
249  typename PredictionsType,
250  typename WeightsType>
251 template<typename ElementType>
252 arma::Mat<ElementType> SimpleCV<MLAlgorithm,
253  Metric,
254  MatType,
255  PredictionsType,
256  WeightsType>::GetSubset(
257  arma::Mat<ElementType>& m,
258  const size_t firstCol,
259  const size_t lastCol)
260 {
261  return arma::Mat<ElementType>(m.colptr(firstCol), m.n_rows,
262  lastCol - firstCol + 1, false, true);
263 }
264 
265 template<typename MLAlgorithm,
266  typename Metric,
267  typename MatType,
268  typename PredictionsType,
269  typename WeightsType>
270 template<typename ElementType>
271 arma::Row<ElementType> SimpleCV<MLAlgorithm,
272  Metric,
273  MatType,
274  PredictionsType,
275  WeightsType>::GetSubset(
276  arma::Row<ElementType>& r,
277  const size_t firstCol,
278  const size_t lastCol)
279 {
280  return arma::Row<ElementType>(r.colptr(firstCol), lastCol - firstCol + 1,
281  false, true);
282 }
283 
284 template<typename MLAlgorithm,
285  typename Metric,
286  typename MatType,
287  typename PredictionsType,
288  typename WeightsType>
289 template<typename... MLAlgorithmArgs, bool Enabled, typename>
290 double SimpleCV<MLAlgorithm,
291  Metric,
292  MatType,
293  PredictionsType,
294  WeightsType>::TrainAndEvaluate(const MLAlgorithmArgs&... args)
295 {
296  modelPtr.reset(new MLAlgorithm(base.Train(trainingXs, trainingYs, args...)));
297 
298  return Metric::Evaluate(*modelPtr, validationXs, validationYs);
299 }
300 
301 template<typename MLAlgorithm,
302  typename Metric,
303  typename MatType,
304  typename PredictionsType,
305  typename WeightsType>
306 template<typename... MLAlgorithmArgs, bool Enabled, typename, typename>
307 double SimpleCV<MLAlgorithm,
308  Metric,
309  MatType,
310  PredictionsType,
311  WeightsType>::TrainAndEvaluate(const MLAlgorithmArgs&... args)
312 {
313  if (trainingWeights.n_elem > 0)
314  modelPtr.reset(new MLAlgorithm(
315  base.Train(trainingXs, trainingYs, trainingWeights, args...)));
316  else
317  modelPtr.reset(new MLAlgorithm(
318  base.Train(trainingXs, trainingYs, args...)));
319 
320  return Metric::Evaluate(*modelPtr, validationXs, validationYs);
321 }
322 
323 } // namespace cv
324 } // namespace mlpack
325 
326 #endif
SimpleCV splits data into two sets - training and validation sets - and then runs training on the tra...
Definition: simple_cv.hpp:68
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
SimpleCV(const double validationSize, MatInType &&xs, PredictionsInType &&ys)
This constructor can be used for regression algorithms and for binary classification algorithms...
Definition: pointer_wrapper.hpp:23
double Evaluate(const MLAlgorithmArgs &... args)
Train on the training set and assess performance on the validation set by using the class Metric...
Definition: simple_cv_impl.hpp:195
static void AssertDataConsistency(const MatType &xs, const PredictionsType &ys)
Assert that there is an equal number of data points and predictions.
Definition: cv_base_impl.hpp:108
MLAlgorithm Train(const MatType &xs, const PredictionsType &ys, const MLAlgorithmArgs &... args)
Train MLAlgorithm with given data points, predictions, and hyperparameters depending on what CVBase c...
Definition: cv_base_impl.hpp:78
static void AssertWeightsConsistency(const MatType &xs, const WeightsType &weights)
Assert that weighted learning is supported and that there is an equal number of data points and weights...
Definition: cv_base_impl.hpp:122
MLAlgorithm & Model()
Access and modify the last trained model.
Definition: simple_cv_impl.hpp:209