mlpack
|
Unit tests for the cross-validation module (see the detailed description at the end of this page).
#include <type_traits>
#include <mlpack/core/cv/meta_info_extractor.hpp>
#include <mlpack/core/cv/metrics/accuracy.hpp>
#include <mlpack/core/cv/metrics/f1.hpp>
#include <mlpack/core/cv/metrics/mse.hpp>
#include <mlpack/core/cv/metrics/precision.hpp>
#include <mlpack/core/cv/metrics/recall.hpp>
#include <mlpack/core/cv/metrics/r2_score.hpp>
#include <mlpack/core/cv/metrics/silhouette_score.hpp>
#include <mlpack/core/cv/simple_cv.hpp>
#include <mlpack/core/cv/k_fold_cv.hpp>
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/init_rules/const_init.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/loss_functions/mean_squared_error.hpp>
#include <mlpack/methods/decision_tree/decision_tree.hpp>
#include <mlpack/methods/decision_tree/information_gain.hpp>
#include <mlpack/methods/hoeffding_trees/hoeffding_tree.hpp>
#include <mlpack/methods/lars/lars.hpp>
#include <mlpack/methods/linear_regression/linear_regression.hpp>
#include <mlpack/methods/logistic_regression/logistic_regression.hpp>
#include <mlpack/methods/naive_bayes/naive_bayes_classifier.hpp>
#include <mlpack/methods/softmax_regression/softmax_regression.hpp>
#include <mlpack/core/data/confusion_matrix.hpp>
#include <ensmallen.hpp>
#include "catch.hpp"
#include "mock_categorical_data.hpp"
Functions | |
TEST_CASE ("BinaryClassificationMetricsTest", "[CVTest]") | |
Test metrics for binary classification. | |
TEST_CASE ("ConfusionMatrixTest", "[CVTest]") | |
Test for confusion matrix. | |
TEST_CASE ("MulticlassClassificationMetricsTest", "[CVTest]") | |
Test metrics for multiclass classification. | |
TEST_CASE ("MSETest", "[CVTest]") | |
Test the mean squared error. | |
TEST_CASE ("R2ScoreTest", "[CVTest]") | |
Test the R squared metric (R2 Score). | |
TEST_CASE ("AdjR2ScoreTest", "[CVTest]") | |
Test the Adjusted R squared metric. | |
TEST_CASE ("MSEMatResponsesTest", "[CVTest]") | |
Test the mean squared error with matrix responses. | |
template<typename Class , typename ExpectedPT , typename PassedMT = arma::mat, typename PassedPT = arma::Row<size_t>> | |
void | CheckPredictionsType () |
TEST_CASE ("PredictionsTypeTest", "[CVTest]") | |
Test that MetaInfoExtractor correctly recognizes the type of predictions for a given machine learning algorithm. | |
TEST_CASE ("SupportsWeightsTest", "[CVTest]") | |
Test that MetaInfoExtractor correctly identifies whether a given machine learning algorithm supports weighted learning. | |
template<typename Class , typename ExpectedWT , typename PassedMT = arma::mat, typename PassedPT = arma::Row<size_t>, typename PassedWT = arma::rowvec> | |
void | CheckWeightsType () |
TEST_CASE ("WeightsTypeTest", "[CVTest]") | |
Test that MetaInfoExtractor correctly recognizes the type of weights for a given machine learning algorithm. | |
TEST_CASE ("TakesDatasetInfoTest", "[CVTest]") | |
Test that MetaInfoExtractor correctly identifies whether a given machine learning algorithm takes a data::DatasetInfo parameter. | |
TEST_CASE ("TakesNumClassesTest", "[CVTest]") | |
Test that MetaInfoExtractor correctly identifies whether a given machine learning algorithm takes the numClasses parameter. | |
TEST_CASE ("SimpleCVAccuracyTest", "[CVTest]") | |
Test the simple cross-validation strategy implementation with the Accuracy metric. | |
TEST_CASE ("SimpleCVMSETest", "[CVTest]") | |
Test the simple cross-validation strategy implementation with the MSE metric. | |
TEST_CASE ("FilterNANCVTest", "[CVTest]") | |
Test that NaN scores (printed as -nan) are filtered out. | |
template<typename... DTArgs> | |
arma::Row< size_t > | PredictLabelsWithDT (const arma::mat &data, const DTArgs &... args) |
TEST_CASE ("SimpleCVWithDTTest", "[CVTest]") | |
Test the simple cross-validation strategy implementation with decision trees constructed in multiple ways. | |
TEST_CASE ("KFoldCVMSETest", "[CVTest]") | |
Test k-fold cross-validation with the MSE metric. | |
TEST_CASE ("KFoldCVAccuracyTest", "[CVTest]") | |
Test k-fold cross-validation with the Accuracy metric. | |
TEST_CASE ("KFoldCVWithWeightedLRTest", "[CVTest]") | |
Test k-fold cross-validation with weighted linear regression. | |
TEST_CASE ("KFoldCVWithDTTest", "[CVTest]") | |
Test k-fold cross-validation with decision trees constructed in multiple ways. | |
TEST_CASE ("KFoldCVWithDTTestLargeKNoShuffle", "[CVTest]") | |
Test k-fold cross-validation with decision trees constructed in multiple ways, but with a larger k and no shuffling. | |
TEST_CASE ("KFoldCVWithDTTestUnevenBinsNoShuffle", "[CVTest]") | |
Test k-fold cross-validation with decision trees constructed in multiple ways, but with a larger k such that the points cannot be split evenly across the cross-validation bins (the last bin is smaller), and with no shuffling. | |
TEST_CASE ("KFoldCVWithDTTestLargeK", "[CVTest]") | |
Test k-fold cross-validation with decision trees constructed in multiple ways, but with a larger k. | |
TEST_CASE ("KFoldCVWithDTTestUnevenBins", "[CVTest]") | |
Test k-fold cross-validation with decision trees constructed in multiple ways, but with a larger k such that the points cannot be split evenly across the cross-validation bins (the last bin is smaller). | |
TEST_CASE ("KFoldCVWithDTTestLargeKWeighted", "[CVTest]") | |
Test k-fold cross-validation with decision trees constructed in multiple ways, but with a larger k and with weights. | |
TEST_CASE ("KFoldCVWithDTTestUnevenBinsWeighted", "[CVTest]") | |
Test k-fold cross-validation with decision trees constructed in multiple ways, but with a larger k such that the points cannot be split evenly across the cross-validation bins (the last bin is smaller), and with weights. | |
TEST_CASE ("SilhouetteScoreTest", "[CVTest]") | |
Test Silhouette Score. | |
Unit tests for the cross-validation module.
mlpack is free software; you may redistribute it and/or modify it under the terms of the 3-clause BSD license. You should have received a copy of the 3-clause BSD license along with mlpack. If not, see http://www.opensource.org/licenses/BSD-3-Clause for more information.