mlpack
split_data_test.cpp File Reference
#include <mlpack/core.hpp>
#include <mlpack/core/data/split_data.hpp>
#include "test_catch_tools.hpp"
#include "catch.hpp"
Include dependency graph for split_data_test.cpp:

Functions

void CompareData (const mat &inputData, const mat &compareData, const Row< size_t > &inputLabel)
 Compare the data after train test split.
 
void CheckMatEqual (const mat &inputData, const mat &compareData)
 
void CheckDuplication (const Row< size_t > &trainLabels, const Row< size_t > &testLabels)
 Check that no labels have been duplicated.
 
 TEST_CASE ("SplitShuffleDataResultMat", "[SplitDataTest]")
 
 TEST_CASE ("SplitDataResultMat", "[SplitDataTest]")
 
 TEST_CASE ("ZeroRatioSplitData", "[SplitDataTest]")
 
 TEST_CASE ("TotalRatioSplitData", "[SplitDataTest]")
 
 TEST_CASE ("SplitLabeledDataResultMat", "[SplitDataTest]")
 
 TEST_CASE ("SplitDataLargerTest", "[SplitDataTest]")
 The same test as above, but on a larger dataset.
 
 TEST_CASE ("SplitLabeledDataLargerTest", "[SplitDataTest]")
 
 TEST_CASE ("ZeroRatioStratifiedSplitData", "[SplitDataTest]")
 Check that a test ratio of 0 results in a full training set for stratified split.
 
 TEST_CASE ("TotalRatioStratifiedSplitData", "[SplitDataTest]")
 Check that a test ratio of 1 results in a full test set for stratified split.
 
 TEST_CASE ("StratifiedSplitDataResultTest", "[SplitDataTest]")
 Check if data is stratified according to labels.
 
 TEST_CASE ("StratifiedSplitLargerDataResultTest", "[SplitDataTest]")
 Check if data is stratified according to labels on a larger data set.
 
 TEST_CASE ("StratifiedSplitRunTimeErrorTest", "[SplitDataTest]")
 Check that Split() with stratifyData = true throws a runtime error if the labels are not of type arma::Row<>.
 
 TEST_CASE ("SplitDataResultField", "[SplitDataTest]")
 
 TEST_CASE ("SplitMatrixLabeledData", "[SplitDataTest]")
 Test for Split() with labels of type arma::Mat and shuffleData = false.
 
 TEST_CASE ("SplitLabeledDataResultField", "[SplitDataTest]")
 Split with input of type field<mat> and label of type field<vec>.
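CheckDuplication() is listed above as checking that no labels have been duplicated. The following is a minimal sketch of such a check using only the standard library; it is an assumption about how the check could be written, not mlpack's test code. Plain std::vector<size_t> stands in for arma::Row<size_t>, and NoDuplicateLabels is a hypothetical name. The idea is to concatenate both label sets, sort them, and look for adjacent equal values.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: returns true if no label appears more than once
// across the train and test label vectors combined.
bool NoDuplicateLabels(const std::vector<std::size_t>& trainLabels,
                       const std::vector<std::size_t>& testLabels)
{
  // Concatenate both label sets into one vector.
  std::vector<std::size_t> all(trainLabels);
  all.insert(all.end(), testLabels.begin(), testLabels.end());

  // After sorting, any duplicate values sit next to each other.
  std::sort(all.begin(), all.end());
  return std::adjacent_find(all.begin(), all.end()) == all.end();
}
```

Since the labels in these tests correspond to original column indices, a repeated value would mean that one original point appears twice across the two sets.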
 
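Several of the tests listed above (ZeroRatioSplitData, TotalRatioSplitData, and SplitMatrixLabeledData with shuffleData = false) depend on how a test ratio is turned into train and test set sizes. The sketch below rests on assumptions rather than mlpack's source: it takes the test set size to be floor(nCols * testRatio) and, without shuffling, assigns the first columns to the training set and the remaining ones to the test set. SplitIndices is a hypothetical name, and plain column indices stand in for the columns of an arma::mat.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper: split column indices 0..nCols-1 without shuffling.
// Assumption: testSize = floor(nCols * testRatio); the first
// (nCols - testSize) indices form the training set, the rest the test set.
std::pair<std::vector<std::size_t>, std::vector<std::size_t>>
SplitIndices(std::size_t nCols, double testRatio)
{
  const std::size_t testSize =
      static_cast<std::size_t>(std::floor(nCols * testRatio));
  const std::size_t trainSize = nCols - testSize;

  std::vector<std::size_t> train, test;
  for (std::size_t i = 0; i < nCols; ++i)
  {
    if (i < trainSize)
      train.push_back(i);
    else
      test.push_back(i);
  }
  return {train, test};
}
```

Under this assumption, a ratio of 0 sends every index to the training set and a ratio of 1 sends every index to the test set, which is presumably the edge-case behavior the ratio tests check.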

Detailed Description

Author
Tham Ngap Wei

Test the SplitData method.

mlpack is free software; you may redistribute it and/or modify it under the terms of the 3-clause BSD license. You should have received a copy of the 3-clause BSD license along with mlpack. If not, see http://www.opensource.org/licenses/BSD-3-Clause for more information.

Function Documentation

◆ CompareData()

void CompareData (const mat &inputData, const mat &compareData, const Row< size_t > &inputLabel)

Compare the data after train test split.

This assumes that the labels correspond to each column, so that we can easily check each point against its original.

Parameters
inputData: The original data set before the split.
compareData: The data to compare against inputData; this may be either the train data or the test data.
inputLabel: The labels of each point in compareData.
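Because each label gives the original column index of a point, the comparison described above reduces to checking that column i of compareData equals column inputLabel[i] of inputData. The sketch below illustrates that idea with the standard library only; it is not mlpack's implementation. Matrices are stored column-major as std::vector<std::vector<double>> (one inner vector per column) instead of arma::mat, CompareColumns is a hypothetical name, and exact equality is used for simplicity since a split only copies values.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Column-major stand-in for arma::mat: data[c] holds column c.
using Matrix = std::vector<std::vector<double>>;

// Hypothetical helper: returns true if every column of compareData equals
// the column of inputData whose index is given by inputLabel.
bool CompareColumns(const Matrix& inputData,
                    const Matrix& compareData,
                    const std::vector<std::size_t>& inputLabel)
{
  // Every point in compareData needs exactly one label.
  if (compareData.size() != inputLabel.size())
    return false;

  for (std::size_t i = 0; i < compareData.size(); ++i)
  {
    const std::size_t original = inputLabel[i];
    // A label outside the original data cannot match any point.
    if (original >= inputData.size())
      return false;
    // Compare the whole column against its original.
    if (compareData[i] != inputData[original])
      return false;
  }
  return true;
}
```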

◆ TEST_CASE()

TEST_CASE ("StratifiedSplitLargerDataResultTest", "[SplitDataTest]")

Check if data is stratified according to labels on a larger data set.

Example calculation of the resulting number of samples in the train and test sets:

Since there are 256 points labeled 0 and the test ratio is 0.3, the number of 0s in the test set is floor(256 * 0.3) = floor(76.8) = 76, and the number of 0s in the train set is 256 - 76 = 180.
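The per-class arithmetic in this example can be reproduced with a short sketch. It assumes, as the example does, that each class contributes floor(classCount * testRatio) points to the test set and the remainder to the training set. The name StratifiedCounts and the use of std::map are illustrative choices, not mlpack's implementation.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <map>
#include <utility>
#include <vector>

// Hypothetical helper: for each class label, return {trainCount, testCount},
// where testCount = floor(classCount * testRatio).
std::map<std::size_t, std::pair<std::size_t, std::size_t>>
StratifiedCounts(const std::vector<std::size_t>& labels, double testRatio)
{
  // Count how many points carry each label.
  std::map<std::size_t, std::size_t> classCounts;
  for (std::size_t label : labels)
    ++classCounts[label];

  std::map<std::size_t, std::pair<std::size_t, std::size_t>> result;
  for (const auto& entry : classCounts)
  {
    const std::size_t count = entry.second;
    const std::size_t testCount =
        static_cast<std::size_t>(std::floor(count * testRatio));
    result[entry.first] = std::make_pair(count - testCount, testCount);
  }
  return result;
}
```

With 256 labels of 0 and a ratio of 0.3 this yields 180 training and 76 test points; ratios of 0 and 1 give the full-train and full-test cases that the stratified ratio tests describe.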