mlpack
|
#include <mlpack/core.hpp>
#include <mlpack/core/data/split_data.hpp>
#include "test_catch_tools.hpp"
#include "catch.hpp"
Functions | |
void | CompareData (const mat &inputData, const mat &compareData, const Row< size_t > &inputLabel) |
Compare the data after train test split. More... | |
void | CheckMatEqual (const mat &inputData, const mat &compareData) |
void | CheckDuplication (const Row< size_t > &trainLabels, const Row< size_t > &testLabels) |
Check that no labels have been duplicated. | |
TEST_CASE ("SplitShuffleDataResultMat", "[SplitDataTest]") | |
TEST_CASE ("SplitDataResultMat", "[SplitDataTest]") | |
TEST_CASE ("ZeroRatioSplitData", "[SplitDataTest]") | |
TEST_CASE ("TotalRatioSplitData", "[SplitDataTest]") | |
TEST_CASE ("SplitLabeledDataResultMat", "[SplitDataTest]") | |
TEST_CASE ("SplitDataLargerTest", "[SplitDataTest]") | |
The same test as above, but on a larger dataset. | |
TEST_CASE ("SplitLabeledDataLargerTest", "[SplitDataTest]") | |
TEST_CASE ("ZeroRatioStratifiedSplitData", "[SplitDataTest]") | |
Check that test ratio of 0 results in a full train set for stratified split. | |
TEST_CASE ("TotalRatioStratifiedSplitData", "[SplitDataTest]") | |
Check that test ratio of 1 results in a full test set for stratified split. | |
TEST_CASE ("StratifiedSplitDataResultTest", "[SplitDataTest]") | |
Check if data is stratified according to labels. | |
TEST_CASE ("StratifiedSplitLargerDataResultTest", "[SplitDataTest]") | |
Check if data is stratified according to labels on a larger data set. More... | |
TEST_CASE ("StratifiedSplitRunTimeErrorTest", "[SplitDataTest]") | |
Check that Split() with stratifyData true throws a runtime error if labels are not of type arma::Row<>. | |
TEST_CASE ("SplitDataResultField", "[SplitDataTest]") | |
TEST_CASE ("SplitMatrixLabeledData", "[SplitDataTest]") | |
Test for Split() with labels of type arma::Mat and shuffleData = false. | |
TEST_CASE ("SplitLabeledDataResultField", "[SplitDataTest]") | |
Split with input of type field<mat> and label of type field<vec>. | |
Test the SplitData method.
mlpack is free software; you may redistribute it and/or modify it under the terms of the 3-clause BSD license. You should have received a copy of the 3-clause BSD license along with mlpack. If not, see http://www.opensource.org/licenses/BSD-3-Clause for more information.
void CompareData | ( | const mat & | inputData, |
const mat & | compareData, | ||
const Row< size_t > & | inputLabel | ||
) |
Compare the data after train test split.
This assumes that the labels correspond to each column, so that we can easily check each point against its original.
inputData | The original data set before split. |
compareData | The data to compare against inputData; it may be either the train data or the test data. |
inputLabel | The labels of each point in compareData. |
TEST_CASE | ( | "StratifiedSplitLargerDataResultTest" | , |
"[SplitDataTest]" | ||
) |
Check if data is stratified according to labels on a larger data set.
Example calculation of the resulting number of samples in the train and test sets:
Since there are 256 0s and the test ratio is 0.3, the number of 0s in the test set is 76 ( floor(256 * 0.3) = floor(76.8) = 76 ), and the number of 0s in the train set is 180 ( 256 - 76 ).