12 #ifndef MLPACK_CORE_CV_METRICS_F1_IMPL_HPP 13 #define MLPACK_CORE_CV_METRICS_F1_IMPL_HPP 21 template<
typename MLAlgorithm,
typename DataType>
24 const arma::Row<size_t>& labels)
26 return Evaluate<AS>(model, data, labels);
30 template<AverageStrategy _AS,
typename MLAlgorithm,
typename DataType,
typename>
33 const arma::Row<size_t>& labels)
35 util::CheckSameSizes(data, labels,
"F1<Binary>::Evaluate()");
37 arma::Row<size_t> predictedLabels;
38 model.Classify(data, predictedLabels);
40 size_t tp = arma::sum((labels == PC) % (predictedLabels == PC));
41 size_t numberOfPositivePredictions = arma::sum(predictedLabels == PC);
42 size_t numberOfPositiveClassInstances = arma::sum(labels == PC);
44 double precision = double(tp) / numberOfPositivePredictions;
45 double recall = double(tp) / numberOfPositiveClassInstances;
47 return (precision + recall == 0.0) ? 0.0 :
48 2.0 * precision * recall / (precision + recall);
52 template<
AverageStrategy _AS,
typename MLAlgorithm,
typename DataType,
typename,
56 const arma::Row<size_t>& labels)
58 util::CheckSameSizes(data, labels,
"F1<Micro>::Evaluate()");
66 template<
AverageStrategy _AS,
typename MLAlgorithm,
typename DataType,
typename,
70 const arma::Row<size_t>& labels)
72 util::CheckSameSizes(data, labels,
"F1<Macro>::Evaluate()");
74 arma::Row<size_t> predictedLabels;
75 model.Classify(data, predictedLabels);
77 size_t numClasses = arma::max(labels) + 1;
79 arma::vec f1s = arma::vec(numClasses);
80 for (
size_t c = 0; c < numClasses; ++c)
82 size_t tp = arma::sum((labels == c) % (predictedLabels == c));
83 size_t positivePredictions = arma::sum(predictedLabels == c);
84 size_t positiveLabels = arma::sum(labels == c);
86 double precision = double(tp) / positivePredictions;
87 double recall = double(tp) / positiveLabels;
88 f1s(c) = (precision + recall == 0.0) ? 0.0 :
89 2.0 * precision * recall / (precision + recall);
92 return arma::mean(f1s);
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static double Evaluate(MLAlgorithm &model, const DataType &data, const arma::Row< size_t > &labels)
Run classification and calculate F1.
Definition: f1_impl.hpp:22
static double Evaluate(MLAlgorithm &model, const DataType &data, const arma::Row< size_t > &labels)
Run classification and calculate accuracy.
Definition: accuracy_impl.hpp:19
AverageStrategy
This enum declares possible strategies for averaging that can be used in some metrics like precision...
Definition: average_strategy.hpp:25