mlpack
hamerly_kmeans_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_KMEANS_HAMERLY_KMEANS_IMPL_HPP
13 #define MLPACK_METHODS_KMEANS_HAMERLY_KMEANS_IMPL_HPP
14 
15 // In case it hasn't been included yet.
16 #include "hamerly_kmeans.hpp"
17 
18 namespace mlpack {
19 namespace kmeans {
20 
21 template<typename MetricType, typename MatType>
23  MetricType& metric) :
24  dataset(dataset),
25  metric(metric),
26  distanceCalculations(0)
27 {
28  // Nothing to do.
29 }
30 
31 template<typename MetricType, typename MatType>
32 double HamerlyKMeans<MetricType, MatType>::Iterate(const arma::mat& centroids,
33  arma::mat& newCentroids,
34  arma::Col<size_t>& counts)
35 {
36  size_t hamerlyPruned = 0;
37 
38  // If this is the first iteration, we need to set all the bounds.
39  if (minClusterDistances.n_elem != centroids.n_cols)
40  {
41  upperBounds.set_size(dataset.n_cols);
42  upperBounds.fill(DBL_MAX);
43  lowerBounds.zeros(dataset.n_cols);
44  assignments.zeros(dataset.n_cols);
45  minClusterDistances.set_size(centroids.n_cols);
46  }
47 
48  // Reset new centroids.
49  newCentroids.zeros(centroids.n_rows, centroids.n_cols);
50  counts.zeros(centroids.n_cols);
51 
52  // Calculate minimum intra-cluster distance for each cluster.
53  minClusterDistances.fill(DBL_MAX);
54  for (size_t i = 0; i < centroids.n_cols; ++i)
55  {
56  for (size_t j = i + 1; j < centroids.n_cols; ++j)
57  {
58  const double dist = metric.Evaluate(centroids.col(i), centroids.col(j)) /
59  2.0;
60  ++distanceCalculations;
61 
62  // Update bounds, if this intra-cluster distance is smaller.
63  if (dist < minClusterDistances(i))
64  minClusterDistances(i) = dist;
65  if (dist < minClusterDistances(j))
66  minClusterDistances(j) = dist;
67  }
68  }
69 
70  for (size_t i = 0; i < dataset.n_cols; ++i)
71  {
72  const double m = std::max(minClusterDistances(assignments[i]),
73  lowerBounds(i));
74 
75  // First bound test.
76  if (upperBounds(i) <= m)
77  {
78  ++hamerlyPruned;
79  newCentroids.col(assignments[i]) += dataset.col(i);
80  ++counts(assignments[i]);
81  continue;
82  }
83 
84  // Tighten upper bound.
85  upperBounds(i) = metric.Evaluate(dataset.col(i),
86  centroids.col(assignments[i]));
87  ++distanceCalculations;
88 
89  // Second bound test.
90  if (upperBounds(i) <= m)
91  {
92  newCentroids.col(assignments[i]) += dataset.col(i);
93  ++counts(assignments[i]);
94  continue;
95  }
96 
97  // The bounds failed. So test against all other clusters.
98  // This is Hamerly's Point-All-Ctrs() function from the paper.
99  // We have to reset the lower bound first.
100  lowerBounds(i) = DBL_MAX;
101  for (size_t c = 0; c < centroids.n_cols; ++c)
102  {
103  if (c == assignments[i])
104  continue;
105 
106  const double dist = metric.Evaluate(dataset.col(i), centroids.col(c));
107 
108  // Is this a better cluster? At this point, upperBounds[i] = d(i, c(i)).
109  if (dist < upperBounds(i))
110  {
111  // lowerBounds holds the second closest cluster.
112  lowerBounds(i) = upperBounds(i);
113  upperBounds(i) = dist;
114  assignments[i] = c;
115  }
116  else if (dist < lowerBounds(i))
117  {
118  // This is a closer second-closest cluster.
119  lowerBounds(i) = dist;
120  }
121  }
122  distanceCalculations += centroids.n_cols - 1;
123 
124  // Update new centroids.
125  newCentroids.col(assignments[i]) += dataset.col(i);
126  ++counts(assignments[i]);
127  }
128 
129  // Normalize centroids and calculate cluster movement (contains parts of
130  // Move-Centers() and Update-Bounds()).
131  double furthestMovement = 0.0;
132  double secondFurthestMovement = 0.0;
133  size_t furthestMovingCluster = 0;
134  arma::vec centroidMovements(centroids.n_cols);
135  double centroidMovement = 0.0;
136  for (size_t c = 0; c < centroids.n_cols; ++c)
137  {
138  if (counts(c) > 0)
139  newCentroids.col(c) /= counts(c);
140 
141  // Calculate movement.
142  const double movement = metric.Evaluate(centroids.col(c),
143  newCentroids.col(c));
144  centroidMovements(c) = movement;
145  centroidMovement += std::pow(movement, 2.0);
146  ++distanceCalculations;
147 
148  if (movement > furthestMovement)
149  {
150  secondFurthestMovement = furthestMovement;
151  furthestMovement = movement;
152  furthestMovingCluster = c;
153  }
154  else if (movement > secondFurthestMovement)
155  {
156  secondFurthestMovement = movement;
157  }
158  }
159 
160  // Now update bounds (lines 3-8 of Update-Bounds()).
161  for (size_t i = 0; i < dataset.n_cols; ++i)
162  {
163  upperBounds(i) += centroidMovements(assignments[i]);
164  if (assignments[i] == furthestMovingCluster)
165  lowerBounds(i) -= secondFurthestMovement;
166  else
167  lowerBounds(i) -= furthestMovement;
168  }
169 
170  Log::Info << "Hamerly prunes: " << hamerlyPruned << ".\n";
171 
172  return std::sqrt(centroidMovement);
173 }
174 
175 } // namespace kmeans
176 } // namespace mlpack
177 
178 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
HamerlyKMeans(const MatType &dataset, MetricType &metric)
Construct the HamerlyKMeans object, which must store several sets of bounds.
Definition: hamerly_kmeans_impl.hpp:22
double Iterate(const arma::mat &centroids, arma::mat &newCentroids, arma::Col< size_t > &counts)
Run a single iteration of Hamerly's algorithm, updating the given centroids into the newCentroids mat...
Definition: hamerly_kmeans_impl.hpp:32
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84