12 #ifndef MLPACK_CORE_CV_METRICS_SILHOUETTE_SCORE_IMPL_HPP 13 #define MLPACK_CORE_CV_METRICS_SILHOUETTE_SCORE_IMPL_HPP 20 template<
typename DataType,
typename Metric>
22 const arma::Row<size_t>& labels,
25 util::CheckSameSizes(X, labels,
"SilhouetteScore::Overall()");
29 template<
typename DataType>
31 const arma::Row<size_t>& labels)
33 util::CheckSameSizes(distances, labels,
"SilhouetteScore::SamplesScore()");
36 arma::rowvec sampleScores(distances.n_rows);
38 arma::ucolvec clusterLabels = arma::find_unique(labels,
false);
40 for (
size_t i = 0; i < distances.n_rows; i++)
42 double interClusterDistance = DBL_MAX, intraClusterDistance = 0;
43 double minInterClusterDistance = DBL_MAX;
44 for (
size_t j = 0; j < clusterLabels.n_elem; j++)
46 size_t clusterLabel = labels(clusterLabels(j));
47 if (labels(i) != clusterLabel) {
49 distances.col(i), labels, clusterLabel,
false);
50 if (interClusterDistance < minInterClusterDistance) {
51 minInterClusterDistance = interClusterDistance;
55 distances.col(i), labels, clusterLabel,
true);
56 if (intraClusterDistance == 0) {
62 if (intraClusterDistance == 0) {
64 sampleScores(i) = 0.0;
66 sampleScores(i) = minInterClusterDistance - intraClusterDistance;
67 sampleScores(i) /= std::max(
68 intraClusterDistance, minInterClusterDistance);
74 template<
typename DataType,
typename Metric>
76 const arma::Row<size_t>& labels,
79 util::CheckSameSizes(X, labels,
"SilhouetteScore::SamplesScore()");
85 const arma::Row<size_t>& labels,
86 const size_t& elemLabel,
87 const bool& sameCluster)
90 arma::uvec sameClusterIndices = arma::find(labels == elemLabel);
93 size_t numSameCluster = sameClusterIndices.n_elem;
94 if ((sameCluster ==
true) && (numSameCluster == 1))
99 double distance = arma::accu(distances.elem(sameClusterIndices));
100 distance /= (numSameCluster - sameCluster);
static double MeanDistanceFromCluster(const arma::colvec &distances, const arma::Row< size_t > &labels, const size_t &label, const bool &sameCluster=false)
Find mean distance of element from a given cluster.
Definition: silhouette_score_impl.hpp:84
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
DataType PairwiseDistances(const DataType &data, const Metric &metric)
Pairwise distance of the given data.
Definition: facilities.hpp:29
static double Overall(const DataType &X, const arma::Row< size_t > &labels, const Metric &metric)
Find the overall silhouette score.
Definition: silhouette_score_impl.hpp:21
static arma::rowvec SamplesScore(const DataType &distances, const arma::Row< size_t > &labels)
Find the individual silhouette scores for precomputed dissimilarities.
Definition: silhouette_score_impl.hpp:30