mlpack
silhouette_score_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_CV_METRICS_SILHOUETTE_SCORE_IMPL_HPP
13 #define MLPACK_CORE_CV_METRICS_SILHOUETTE_SCORE_IMPL_HPP
14 
16 
17 namespace mlpack {
18 namespace cv {
19 
20 template<typename DataType, typename Metric>
21 double SilhouetteScore::Overall(const DataType& X,
22  const arma::Row<size_t>& labels,
23  const Metric& metric)
24 {
25  util::CheckSameSizes(X, labels, "SilhouetteScore::Overall()");
26  return arma::mean(SamplesScore(X, labels, metric));
27 }
28 
29 template<typename DataType>
30 arma::rowvec SilhouetteScore::SamplesScore(const DataType& distances,
31  const arma::Row<size_t>& labels)
32 {
33  util::CheckSameSizes(distances, labels, "SilhouetteScore::SamplesScore()");
34 
35  // Stores the silhouette scores of individual samples.
36  arma::rowvec sampleScores(distances.n_rows);
37  // Finds one index per cluster.
38  arma::ucolvec clusterLabels = arma::find_unique(labels, false);
39 
40  for (size_t i = 0; i < distances.n_rows; i++)
41  {
42  double interClusterDistance = DBL_MAX, intraClusterDistance = 0;
43  double minInterClusterDistance = DBL_MAX;
44  for (size_t j = 0; j < clusterLabels.n_elem; j++)
45  {
46  size_t clusterLabel = labels(clusterLabels(j));
47  if (labels(i) != clusterLabel) {
48  interClusterDistance = MeanDistanceFromCluster(
49  distances.col(i), labels, clusterLabel, false);
50  if (interClusterDistance < minInterClusterDistance) {
51  minInterClusterDistance = interClusterDistance;
52  }
53  } else {
54  intraClusterDistance = MeanDistanceFromCluster(
55  distances.col(i), labels, clusterLabel, true);
56  if (intraClusterDistance == 0) {
57  // s(i) = 0, no more calculation needed.
58  break;
59  }
60  }
61  }
62  if (intraClusterDistance == 0) {
63  // i is the only element in the cluster.
64  sampleScores(i) = 0.0;
65  } else {
66  sampleScores(i) = minInterClusterDistance - intraClusterDistance;
67  sampleScores(i) /= std::max(
68  intraClusterDistance, minInterClusterDistance);
69  }
70  }
71  return sampleScores;
72 }
73 
74 template<typename DataType, typename Metric>
75 arma::rowvec SilhouetteScore::SamplesScore(const DataType& X,
76  const arma::Row<size_t>& labels,
77  const Metric& metric)
78 {
79  util::CheckSameSizes(X, labels, "SilhouetteScore::SamplesScore()");
80  DataType distances = PairwiseDistances(X, metric);
81  return SamplesScore(distances, labels);
82 }
83 
84 double SilhouetteScore::MeanDistanceFromCluster(const arma::colvec& distances,
85  const arma::Row<size_t>& labels,
86  const size_t& elemLabel,
87  const bool& sameCluster)
88 {
89  // Find indices of elements with same label as elemLabel.
90  arma::uvec sameClusterIndices = arma::find(labels == elemLabel);
91 
92  // Numver of elements in the given cluster.
93  size_t numSameCluster = sameClusterIndices.n_elem;
94  if ((sameCluster == true) && (numSameCluster == 1))
95  {
96  // Return 0 if subject element is the only element in cluster.
97  return 0.0;
98  } else {
99  double distance = arma::accu(distances.elem(sameClusterIndices));
100  distance /= (numSameCluster - sameCluster);
101  return distance;
102  }
103 }
104 
105 } // namespace cv
106 } // namespace mlpack
107 
108 #endif
static double MeanDistanceFromCluster(const arma::colvec &distances, const arma::Row< size_t > &labels, const size_t &label, const bool &sameCluster=false)
Find mean distance of element from a given cluster.
Definition: silhouette_score_impl.hpp:84
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
DataType PairwiseDistances(const DataType &data, const Metric &metric)
Pairwise distance of the given data.
Definition: facilities.hpp:29
static double Overall(const DataType &X, const arma::Row< size_t > &labels, const Metric &metric)
Find the overall silhouette score.
Definition: silhouette_score_impl.hpp:21
static arma::rowvec SamplesScore(const DataType &distances, const arma::Row< size_t > &labels)
Find the individual silhouette scores for precomputed dissimilarities.
Definition: silhouette_score_impl.hpp:30