// Header include guard, then the opening of the SplitHelper template.
// NOTE(review): this listing carries Doxygen line numbers and omits some
// source lines (braces, the function-name line); the full signature is
// reproduced in the reference index at the end of the file.
13 #ifndef MLPACK_CORE_DATA_SPLIT_DATA_HPP 14 #define MLPACK_CORE_DATA_SPLIT_DATA_HPP 26 template<
typename InputType>
// testRatio: fraction of the input columns that goes to the test set.
30 const double testRatio,
// order: optional column ordering; the default is an empty vector, in which
// case the `order`-driven copy below is skipped.
31 const arma::uvec& order = arma::uvec())
// The cast truncates, so the test set never receives more than
// n_cols * testRatio columns; the training set takes whatever is left.
33 const size_t testSize =
static_cast<size_t>(input.n_cols * testRatio);
34 const size_t trainSize = input.n_cols - testSize;
// Size both outputs up front; each keeps the row count of the input.
37 train.set_size(input.n_rows, trainSize);
38 test.set_size(input.n_rows, testSize);
// An ordering was supplied: copy column by column, following `order`.
41 if (!order.is_empty())
// The first trainSize entries of `order` select the training columns.
45 for (
size_t i = 0; i < trainSize; ++i)
46 train.col(i) = input.col(order(i));
// Only touch the test set when at least one column was assigned to it.
48 if (trainSize < input.n_cols)
// The remaining entries of `order` fill the test set, re-based to index 0.
50 for (
size_t i = trainSize; i < input.n_cols; ++i)
51 test.col(i - trainSize) = input.col(order(i));
// Contiguous split: leading columns go to train, trailing columns to test.
// Presumably the `else` branch of the `order` check (that line is not visible
// in this listing).
// NOTE(review): `trainSize - 1` wraps around when trainSize is 0; confirm a
// guard exists on the lines missing from this listing.
58 train = input.cols(0, trainSize - 1);
60 if (trainSize < input.n_cols)
61 test = input.cols(trainSize, input.n_cols - 1);
// StratifiedSplit: given a dataset and its labels, split both into training
// and test sets so that each label contributes floor(count * testRatio) of its
// points to the test set. The function-name line is not visible in this
// listing; the signature appears in the reference index at the end of the file.
101 template<
typename T,
typename LabelsType,
// Overload participates only when LabelsType is an Armadillo type.
102 typename = std::enable_if_t<arma::is_arma_type<LabelsType>::value> >
104 const LabelsType& inputLabel,
105 arma::Mat<T>& trainData,
106 arma::Mat<T>& testData,
107 LabelsType& trainLabel,
108 LabelsType& testLabel,
// testRatio: per-label fraction of points assigned to the test set.
109 const double testRatio,
// shuffleData: defaults to true; presumably selects between the shuffled and
// the in-order assignment loops below (the branch line is not visible).
110 const bool shuffleData =
true)
// Labels must be a row or a column vector type.
145 const bool typeCheck = (arma::is_Row<LabelsType>::value)
146 || (arma::is_Col<LabelsType>::value);
// Presumably thrown when typeCheck is false (the guarding `if` is not visible).
// NOTE(review): the message names only `arma::Row<>` although the check above
// also accepts `arma::Col` - confirm which is intended.
148 throw std::runtime_error(
"data::Split(): when stratified sampling is done, " 149 "labels must have type `arma::Row<>`!");
// Running total of training points; `testSize`, `testIdx` and `trainIdx` are
// used below but declared on lines missing from this listing.
152 size_t trainSize = 0;
// labelCounts[l]: number of points with label l.
// testLabelCounts[l]: number of points with label l already sent to the test set.
154 arma::uvec labelCounts;
155 arma::uvec testLabelCounts;
// Labels are used directly as indices, so both count vectors need maxLabel + 1
// slots. Assumes labels are non-negative integers - TODO confirm with callers.
156 typename LabelsType::elem_type maxLabel = inputLabel.max();
158 labelCounts.zeros(maxLabel+1);
159 testLabelCounts.zeros(maxLabel+1);
// Build the per-label histogram.
161 for (
typename LabelsType::elem_type label : inputLabel)
162 ++labelCounts[label];
// Each label sends floor(count * testRatio) points to test and the rest to
// train; summing over labels gives the final set sizes.
164 for (arma::uword labelCount : labelCounts)
166 testSize += floor(labelCount * testRatio);
167 trainSize += labelCount - floor(labelCount * testRatio);
// Allocate the outputs with the sizes computed above.
170 trainData.set_size(input.n_rows, trainSize);
171 testData.set_size(input.n_rows, testSize);
172 trainLabel.set_size(inputLabel.n_rows, trainSize);
173 testLabel.set_size(inputLabel.n_rows, testSize);
// Random permutation of the column indices 0 .. n_cols - 1.
177 arma::uvec order = arma::shuffle(
178 arma::linspace<arma::uvec>(0, input.n_cols - 1, input.n_cols));
// Visit points in shuffled order: a point goes to the test set while its
// label's test quota is unfilled, otherwise to the training set.
180 for (arma::uword i : order)
182 typename LabelsType::elem_type label = inputLabel[i];
183 if (testLabelCounts[label] < floor(labelCounts[label] * testRatio))
185 testLabelCounts[label] += 1;
186 testData.col(testIdx) = input.col(i);
187 testLabel[testIdx] = inputLabel[i];
// Quota for this label already met (presumably the `else` branch; the index
// increments are on lines not visible here).
192 trainData.col(trainIdx) = input.col(i);
193 trainLabel[trainIdx] = inputLabel[i];
// Same quota-based assignment, but visiting points in their original order.
200 for (arma::uword i = 0; i < input.n_cols; i++)
202 typename LabelsType::elem_type label = inputLabel[i];
203 if (testLabelCounts[label] < floor(labelCounts[label] * testRatio))
205 testLabelCounts[label] += 1;
206 testData.col(testIdx) = input.col(i);
207 testLabel[testIdx] = inputLabel[i];
212 trainData.col(trainIdx) = input.col(i);
213 trainLabel[trainIdx] = inputLabel[i];
// Split a dataset and its labels into training and test sets, writing the
// results into the four output arguments.
253 template<
typename T,
typename LabelsType,
// Overload participates only when LabelsType is an Armadillo type.
254 typename = std::enable_if_t<arma::is_arma_type<LabelsType>::value> >
255 void Split(
const arma::Mat<T>& input,
256 const LabelsType& inputLabel,
257 arma::Mat<T>& trainData,
258 arma::Mat<T>& testData,
259 LabelsType& trainLabel,
260 LabelsType& testLabel,
// testRatio: fraction of the columns placed in the test set (see SplitHelper).
261 const double testRatio,
// shuffleData: defaults to true; presumably selects between the two pairs of
// SplitHelper calls below (the branch line is not visible in this listing).
262 const bool shuffleData =
true)
// Random permutation of the column indices 0 .. n_cols - 1.
266 arma::uvec order = arma::shuffle(arma::linspace<arma::uvec>(0,
267 input.n_cols - 1, input.n_cols));
// The same `order` is passed for data and labels so each point stays paired
// with its label.
268 SplitHelper(input, trainData, testData, testRatio, order);
269 SplitHelper(inputLabel, trainLabel, testLabel, testRatio, order);
// No ordering passed: SplitHelper receives its default (empty) `order`.
273 SplitHelper(input, trainData, testData, testRatio);
274 SplitHelper(inputLabel, trainLabel, testLabel, testRatio);
// Split an unlabeled dataset into training and test sets, writing the results
// into trainData and testData. (The template header line is not visible in
// this listing.)
302 void Split(
const arma::Mat<T>& input,
303 arma::Mat<T>& trainData,
304 arma::Mat<T>& testData,
// testRatio: fraction of the columns placed in the test set (see SplitHelper).
305 const double testRatio,
// shuffleData: defaults to true; presumably selects between the shuffled and
// unshuffled SplitHelper calls below (the branch line is not visible).
306 const bool shuffleData =
true)
// Random permutation of the column indices 0 .. n_cols - 1.
310 arma::uvec order = arma::shuffle(arma::linspace<arma::uvec>(0,
311 input.n_cols - 1, input.n_cols));
312 SplitHelper(input, trainData, testData, testRatio, order);
// No ordering passed: SplitHelper receives its default (empty) `order`.
316 SplitHelper(input, trainData, testData, testRatio);
// Convenience overload for labeled data that returns the split as a tuple
// instead of filling output arguments.
348 template<
typename T,
typename LabelsType,
// Overload participates only when LabelsType is an Armadillo type.
349 typename = std::enable_if_t<arma::is_arma_type<LabelsType>::value> >
// Return type: two data matrices followed by two label containers.
350 std::tuple<arma::Mat<T>, arma::Mat<T>, LabelsType, LabelsType>
352 const LabelsType& inputLabel,
353 const double testRatio,
// shuffleData defaults to true.
354 const bool shuffleData =
true,
// stratifyData defaults to false.
355 const bool stratifyData =
false)
// Local outputs that are filled by the call below and then moved into the
// returned tuple.
357 arma::Mat<T> trainData;
358 arma::Mat<T> testData;
359 LabelsType trainLabel;
360 LabelsType testLabel;
// Tail of a call whose first lines are missing from this listing; presumably
// the StratifiedSplit call taken when stratifyData is true - TODO confirm.
365 testLabel, testRatio, shuffleData);
// Plain (non-stratified) split via the output-argument overload.
369 Split(input, inputLabel, trainData, testData, trainLabel, testLabel,
370 testRatio, shuffleData);
// Move the locals into the tuple to avoid copying them (one element of the
// tuple sits on a line not visible in this listing).
373 return std::make_tuple(std::move(trainData),
375 std::move(trainLabel),
376 std::move(testLabel));
// Convenience overload for unlabeled data: returns (trainData, testData) as a
// tuple instead of filling output arguments. (Template header and function-name
// lines are not visible in this listing.)
398 std::tuple<arma::Mat<T>, arma::Mat<T>>
400 const double testRatio,
// shuffleData defaults to true and is forwarded unchanged.
401 const bool shuffleData =
true)
403 arma::Mat<T> trainData;
404 arma::Mat<T> testData;
// Delegate to the output-argument overload.
405 Split(input, trainData, testData, testRatio, shuffleData);
// Move the locals into the tuple to avoid copying them.
407 return std::make_tuple(std::move(trainData),
408 std::move(testData));
// Split for data held in a field type, with labels held in an arma::field<T>;
// results are written into the output arguments.
445 template <
typename FieldType,
typename T,
// Overload participates only when the field's object_type is an Armadillo
// column vector or a plain matrix.
446 typename = std::enable_if_t<
447 arma::is_Col<typename FieldType::object_type>::value ||
448 arma::is_Mat_only<typename FieldType::object_type>::value>>
450 const arma::field<T>& inputLabel,
// NOTE(review): trainLabel is listed before testLabel with one parameter line
// missing in between (presumably testData), so the argument order appears to
// differ from the arma::Mat overload - confirm against the full header.
451 FieldType& trainData,
452 arma::field<T>& trainLabel,
454 arma::field<T>& testLabel,
455 const double testRatio,
// shuffleData: defaults to true; presumably selects between the two pairs of
// SplitHelper calls below (the branch line is not visible in this listing).
456 const bool shuffleData =
true)
// Random permutation of the indices 0 .. n_cols - 1.
460 arma::uvec order = arma::shuffle(arma::linspace<arma::uvec>(0,
461 input.n_cols - 1, input.n_cols));
// The same `order` is passed for data and labels so they stay paired.
462 SplitHelper(input, trainData, testData, testRatio, order);
463 SplitHelper(inputLabel, trainLabel, testLabel, testRatio, order);
// No ordering passed: SplitHelper receives its default (empty) `order`.
467 SplitHelper(input, trainData, testData, testRatio);
468 SplitHelper(inputLabel, trainLabel, testLabel, testRatio);
// Split for unlabeled data held in a field type; results are written into the
// output arguments.
500 template <
class FieldType,
// Overload participates only when the field's object_type is an Armadillo
// column vector or a plain matrix.
501 class = std::enable_if_t<
502 arma::is_Col<typename FieldType::object_type>::value ||
503 arma::is_Mat_only<typename FieldType::object_type>::value>>
505 FieldType& trainData,
507 const double testRatio,
// shuffleData: defaults to true; presumably selects between the shuffled and
// unshuffled SplitHelper calls below (the branch line is not visible).
508 const bool shuffleData =
true)
// Random permutation of the indices 0 .. n_cols - 1.
512 arma::uvec order = arma::shuffle(arma::linspace<arma::uvec>(0,
513 input.n_cols - 1, input.n_cols));
514 SplitHelper(input, trainData, testData, testRatio, order);
// No ordering passed: SplitHelper receives its default (empty) `order`.
518 SplitHelper(input, trainData, testData, testRatio);
// Convenience overload for labeled field data that returns the split as a
// tuple instead of filling output arguments.
549 template <
class FieldType,
typename T,
// Overload participates only when the field's object_type is an Armadillo
// column vector or a plain matrix.
550 class = std::enable_if_t<
551 arma::is_Col<typename FieldType::object_type>::value ||
552 arma::is_Mat_only<typename FieldType::object_type>::value>>
// Return type: two data fields followed by two label fields.
553 std::tuple<FieldType, FieldType, arma::field<T>, arma::field<T>>
555 const arma::field<T>& inputLabel,
556 const double testRatio,
// shuffleData defaults to true and is forwarded unchanged.
557 const bool shuffleData =
true)
// Local label outputs; trainData and testData are declared on lines missing
// from this listing.
561 arma::field<T> trainLabel;
562 arma::field<T> testLabel;
// Delegate to the output-argument overload; note that trainLabel is passed
// before testData in this call.
564 Split(input, inputLabel, trainData, trainLabel, testData, testLabel,
565 testRatio, shuffleData);
// Move the locals into the tuple to avoid copying them (one element of the
// tuple sits on a line not visible in this listing).
567 return std::make_tuple(std::move(trainData),
569 std::move(trainLabel),
570 std::move(testLabel));
// Convenience overload for unlabeled field data: returns (trainData, testData)
// as a tuple instead of filling output arguments.
596 template <
class FieldType,
// Overload participates only when the field's object_type is an Armadillo
// column vector or a plain matrix.
597 class = std::enable_if_t<
598 arma::is_Col<typename FieldType::object_type>::value ||
599 arma::is_Mat_only<typename FieldType::object_type>::value>>
600 std::tuple<FieldType, FieldType>
602 const double testRatio,
// shuffleData defaults to true and is forwarded unchanged.
603 const bool shuffleData =
true)
// Delegate to the output-argument overload (trainData and testData are
// declared on lines missing from this listing).
607 Split(input, trainData, testData, testRatio, shuffleData);
// Move the locals into the tuple to avoid copying them.
609 return std::make_tuple(std::move(trainData),
610 std::move(testData));
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Split(const arma::Mat< T > &input, const LabelsType &inputLabel, arma::Mat< T > &trainData, arma::Mat< T > &testData, LabelsType &trainLabel, LabelsType &testLabel, const double testRatio, const bool shuffleData=true)
Given an input dataset and labels, split into a training set and test set.
Definition: split_data.hpp:255
void StratifiedSplit(const arma::Mat< T > &input, const LabelsType &inputLabel, arma::Mat< T > &trainData, arma::Mat< T > &testData, LabelsType &trainLabel, LabelsType &testLabel, const double testRatio, const bool shuffleData=true)
Given an input dataset and labels, stratify into a training set and test set.
Definition: split_data.hpp:103
void SplitHelper(const InputType &input, InputType &train, InputType &test, const double testRatio, const arma::uvec &order=arma::uvec())
This helper function splits any input data into training and testing parts.
Definition: split_data.hpp:27