mlpack
split_data.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_CORE_DATA_SPLIT_DATA_HPP
14 #define MLPACK_CORE_DATA_SPLIT_DATA_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace data {
20 
26 template<typename InputType>
27 void SplitHelper(const InputType& input,
28  InputType& train,
29  InputType& test,
30  const double testRatio,
31  const arma::uvec& order = arma::uvec())
32 {
33  const size_t testSize = static_cast<size_t>(input.n_cols * testRatio);
34  const size_t trainSize = input.n_cols - testSize;
35 
36  // Initialising the sizes of outputs if not already initialized.
37  train.set_size(input.n_rows, trainSize);
38  test.set_size(input.n_rows, testSize);
39 
40  // Shuffling and spliting simultaneously.
41  if (!order.is_empty())
42  {
43  if (trainSize > 0)
44  {
45  for (size_t i = 0; i < trainSize; ++i)
46  train.col(i) = input.col(order(i));
47  }
48  if (trainSize < input.n_cols)
49  {
50  for (size_t i = trainSize; i < input.n_cols; ++i)
51  test.col(i - trainSize) = input.col(order(i));
52  }
53  }
54  // Spliting only.
55  else
56  {
57  if (trainSize > 0)
58  train = input.cols(0, trainSize - 1);
59 
60  if (trainSize < input.n_cols)
61  test = input.cols(trainSize, input.n_cols - 1);
62  }
63 }
64 
101 template<typename T, typename LabelsType,
102  typename = std::enable_if_t<arma::is_arma_type<LabelsType>::value> >
103 void StratifiedSplit(const arma::Mat<T>& input,
104  const LabelsType& inputLabel,
105  arma::Mat<T>& trainData,
106  arma::Mat<T>& testData,
107  LabelsType& trainLabel,
108  LabelsType& testLabel,
109  const double testRatio,
110  const bool shuffleData = true)
111 {
145  const bool typeCheck = (arma::is_Row<LabelsType>::value)
146  || (arma::is_Col<LabelsType>::value);
147  if (!typeCheck)
148  throw std::runtime_error("data::Split(): when stratified sampling is done, "
149  "labels must have type `arma::Row<>`!");
150  size_t trainIdx = 0;
151  size_t testIdx = 0;
152  size_t trainSize = 0;
153  size_t testSize = 0;
154  arma::uvec labelCounts;
155  arma::uvec testLabelCounts;
156  typename LabelsType::elem_type maxLabel = inputLabel.max();
157 
158  labelCounts.zeros(maxLabel+1);
159  testLabelCounts.zeros(maxLabel+1);
160 
161  for (typename LabelsType::elem_type label : inputLabel)
162  ++labelCounts[label];
163 
164  for (arma::uword labelCount : labelCounts)
165  {
166  testSize += floor(labelCount * testRatio);
167  trainSize += labelCount - floor(labelCount * testRatio);
168  }
169 
170  trainData.set_size(input.n_rows, trainSize);
171  testData.set_size(input.n_rows, testSize);
172  trainLabel.set_size(inputLabel.n_rows, trainSize);
173  testLabel.set_size(inputLabel.n_rows, testSize);
174 
175  if (shuffleData)
176  {
177  arma::uvec order = arma::shuffle(
178  arma::linspace<arma::uvec>(0, input.n_cols - 1, input.n_cols));
179 
180  for (arma::uword i : order)
181  {
182  typename LabelsType::elem_type label = inputLabel[i];
183  if (testLabelCounts[label] < floor(labelCounts[label] * testRatio))
184  {
185  testLabelCounts[label] += 1;
186  testData.col(testIdx) = input.col(i);
187  testLabel[testIdx] = inputLabel[i];
188  testIdx += 1;
189  }
190  else
191  {
192  trainData.col(trainIdx) = input.col(i);
193  trainLabel[trainIdx] = inputLabel[i];
194  trainIdx += 1;
195  }
196  }
197  }
198  else
199  {
200  for (arma::uword i = 0; i < input.n_cols; i++)
201  {
202  typename LabelsType::elem_type label = inputLabel[i];
203  if (testLabelCounts[label] < floor(labelCounts[label] * testRatio))
204  {
205  testLabelCounts[label] += 1;
206  testData.col(testIdx) = input.col(i);
207  testLabel[testIdx] = inputLabel[i];
208  testIdx += 1;
209  }
210  else
211  {
212  trainData.col(trainIdx) = input.col(i);
213  trainLabel[trainIdx] = inputLabel[i];
214  trainIdx += 1;
215  }
216  }
217  }
218 }
219 
253 template<typename T, typename LabelsType,
254  typename = std::enable_if_t<arma::is_arma_type<LabelsType>::value> >
255 void Split(const arma::Mat<T>& input,
256  const LabelsType& inputLabel,
257  arma::Mat<T>& trainData,
258  arma::Mat<T>& testData,
259  LabelsType& trainLabel,
260  LabelsType& testLabel,
261  const double testRatio,
262  const bool shuffleData = true)
263 {
264  if (shuffleData)
265  {
266  arma::uvec order = arma::shuffle(arma::linspace<arma::uvec>(0,
267  input.n_cols - 1, input.n_cols));
268  SplitHelper(input, trainData, testData, testRatio, order);
269  SplitHelper(inputLabel, trainLabel, testLabel, testRatio, order);
270  }
271  else
272  {
273  SplitHelper(input, trainData, testData, testRatio);
274  SplitHelper(inputLabel, trainLabel, testLabel, testRatio);
275  }
276 }
277 
301 template<typename T>
302 void Split(const arma::Mat<T>& input,
303  arma::Mat<T>& trainData,
304  arma::Mat<T>& testData,
305  const double testRatio,
306  const bool shuffleData = true)
307 {
308  if (shuffleData)
309  {
310  arma::uvec order = arma::shuffle(arma::linspace<arma::uvec>(0,
311  input.n_cols - 1, input.n_cols));
312  SplitHelper(input, trainData, testData, testRatio, order);
313  }
314  else
315  {
316  SplitHelper(input, trainData, testData, testRatio);
317  }
318 }
319 
348 template<typename T, typename LabelsType,
349  typename = std::enable_if_t<arma::is_arma_type<LabelsType>::value> >
350 std::tuple<arma::Mat<T>, arma::Mat<T>, LabelsType, LabelsType>
351 Split(const arma::Mat<T>& input,
352  const LabelsType& inputLabel,
353  const double testRatio,
354  const bool shuffleData = true,
355  const bool stratifyData = false)
356 {
357  arma::Mat<T> trainData;
358  arma::Mat<T> testData;
359  LabelsType trainLabel;
360  LabelsType testLabel;
361 
362  if (stratifyData)
363  {
364  StratifiedSplit(input, inputLabel, trainData, testData, trainLabel,
365  testLabel, testRatio, shuffleData);
366  }
367  else
368  {
369  Split(input, inputLabel, trainData, testData, trainLabel, testLabel,
370  testRatio, shuffleData);
371  }
372 
373  return std::make_tuple(std::move(trainData),
374  std::move(testData),
375  std::move(trainLabel),
376  std::move(testLabel));
377 }
378 
397 template<typename T>
398 std::tuple<arma::Mat<T>, arma::Mat<T>>
399 Split(const arma::Mat<T>& input,
400  const double testRatio,
401  const bool shuffleData = true)
402 {
403  arma::Mat<T> trainData;
404  arma::Mat<T> testData;
405  Split(input, trainData, testData, testRatio, shuffleData);
406 
407  return std::make_tuple(std::move(trainData),
408  std::move(testData));
409 }
410 
445 template <typename FieldType, typename T,
446  typename = std::enable_if_t<
447  arma::is_Col<typename FieldType::object_type>::value ||
448  arma::is_Mat_only<typename FieldType::object_type>::value>>
449 void Split(const FieldType& input,
450  const arma::field<T>& inputLabel,
451  FieldType& trainData,
452  arma::field<T>& trainLabel,
453  FieldType& testData,
454  arma::field<T>& testLabel,
455  const double testRatio,
456  const bool shuffleData = true)
457 {
458  if (shuffleData)
459  {
460  arma::uvec order = arma::shuffle(arma::linspace<arma::uvec>(0,
461  input.n_cols - 1, input.n_cols));
462  SplitHelper(input, trainData, testData, testRatio, order);
463  SplitHelper(inputLabel, trainLabel, testLabel, testRatio, order);
464  }
465  else
466  {
467  SplitHelper(input, trainData, testData, testRatio);
468  SplitHelper(inputLabel, trainLabel, testLabel, testRatio);
469  }
470 }
471 
500 template <class FieldType,
501  class = std::enable_if_t<
502  arma::is_Col<typename FieldType::object_type>::value ||
503  arma::is_Mat_only<typename FieldType::object_type>::value>>
504 void Split(const FieldType& input,
505  FieldType& trainData,
506  FieldType& testData,
507  const double testRatio,
508  const bool shuffleData = true)
509 {
510  if (shuffleData)
511  {
512  arma::uvec order = arma::shuffle(arma::linspace<arma::uvec>(0,
513  input.n_cols - 1, input.n_cols));
514  SplitHelper(input, trainData, testData, testRatio, order);
515  }
516  else
517  {
518  SplitHelper(input, trainData, testData, testRatio);
519  }
520 }
521 
549 template <class FieldType, typename T,
550  class = std::enable_if_t<
551  arma::is_Col<typename FieldType::object_type>::value ||
552  arma::is_Mat_only<typename FieldType::object_type>::value>>
553 std::tuple<FieldType, FieldType, arma::field<T>, arma::field<T>>
554 Split(const FieldType& input,
555  const arma::field<T>& inputLabel,
556  const double testRatio,
557  const bool shuffleData = true)
558 {
559  FieldType trainData;
560  FieldType testData;
561  arma::field<T> trainLabel;
562  arma::field<T> testLabel;
563 
564  Split(input, inputLabel, trainData, trainLabel, testData, testLabel,
565  testRatio, shuffleData);
566 
567  return std::make_tuple(std::move(trainData),
568  std::move(testData),
569  std::move(trainLabel),
570  std::move(testLabel));
571 }
572 
596 template <class FieldType,
597  class = std::enable_if_t<
598  arma::is_Col<typename FieldType::object_type>::value ||
599  arma::is_Mat_only<typename FieldType::object_type>::value>>
600 std::tuple<FieldType, FieldType>
601 Split(const FieldType& input,
602  const double testRatio,
603  const bool shuffleData = true)
604 {
605  FieldType trainData;
606  FieldType testData;
607  Split(input, trainData, testData, testRatio, shuffleData);
608 
609  return std::make_tuple(std::move(trainData),
610  std::move(testData));
611 }
612 
613 } // namespace data
614 } // namespace mlpack
615 
616 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Split(const arma::Mat< T > &input, const LabelsType &inputLabel, arma::Mat< T > &trainData, arma::Mat< T > &testData, LabelsType &trainLabel, LabelsType &testLabel, const double testRatio, const bool shuffleData=true)
Given an input dataset and labels, split into a training set and test set.
Definition: split_data.hpp:255
void StratifiedSplit(const arma::Mat< T > &input, const LabelsType &inputLabel, arma::Mat< T > &trainData, arma::Mat< T > &testData, LabelsType &trainLabel, LabelsType &testLabel, const double testRatio, const bool shuffleData=true)
Given an input dataset and labels, stratify into a training set and test set.
Definition: split_data.hpp:103
void SplitHelper(const InputType &input, InputType &train, InputType &test, const double testRatio, const arma::uvec &order=arma::uvec())
This helper function splits any input data into training and testing parts.
Definition: split_data.hpp:27