mlpack
naive_kmeans_impl.hpp
Go to the documentation of this file.
1 
16 #ifndef MLPACK_METHODS_KMEANS_NAIVE_KMEANS_IMPL_HPP
17 #define MLPACK_METHODS_KMEANS_NAIVE_KMEANS_IMPL_HPP
18 
19 // In case it hasn't been included yet.
20 #include "naive_kmeans.hpp"
21 
22 namespace mlpack {
23 namespace kmeans {
24 
25 template<typename MetricType, typename MatType>
27  MetricType& metric) :
28  dataset(dataset),
29  metric(metric),
30  distanceCalculations(0)
31 { /* Nothing to do. */ }
32 
33 // Run a single iteration.
34 template<typename MetricType, typename MatType>
35 double NaiveKMeans<MetricType, MatType>::Iterate(const arma::mat& centroids,
36  arma::mat& newCentroids,
37  arma::Col<size_t>& counts)
38 {
39  newCentroids.zeros(centroids.n_rows, centroids.n_cols);
40  counts.zeros(centroids.n_cols);
41 
42  // Find the closest centroid to each point and update the new centroids.
43  // Computed in parallel over the complete dataset
44  #pragma omp parallel
45  {
46  // The current state of the K-means is private for each thread
47  arma::mat localCentroids(centroids.n_rows, centroids.n_cols,
48  arma::fill::zeros);
49  arma::Col<size_t> localCounts(centroids.n_cols, arma::fill::zeros);
50 
51  #pragma omp for
52  for (omp_size_t i = 0; i < (omp_size_t) dataset.n_cols; ++i)
53  {
54  // Find the closest centroid to this point.
55  double minDistance = std::numeric_limits<double>::infinity();
56  size_t closestCluster = centroids.n_cols; // Invalid value.
57 
58  for (size_t j = 0; j < centroids.n_cols; ++j)
59  {
60  const double distance = metric.Evaluate(dataset.col(i),
61  centroids.unsafe_col(j));
62  if (distance < minDistance)
63  {
64  minDistance = distance;
65  closestCluster = j;
66  }
67  }
68 
69  Log::Assert(closestCluster != centroids.n_cols);
70 
71  // We now have the minimum distance centroid index. Update that centroid.
72  localCentroids.unsafe_col(closestCluster) += dataset.col(i);
73  localCounts(closestCluster)++;
74  }
75  // Combine calculated state from each thread
76  #pragma omp critical
77  {
78  newCentroids += localCentroids;
79  counts += localCounts;
80  }
81  }
82 
83  // Now normalize the centroid.
84  for (size_t i = 0; i < centroids.n_cols; ++i)
85  if (counts(i) != 0)
86  newCentroids.col(i) /= counts(i);
87 
88  distanceCalculations += centroids.n_cols * dataset.n_cols;
89 
90  // Calculate cluster distortion for this iteration.
91  double cNorm = 0.0;
92  for (size_t i = 0; i < centroids.n_cols; ++i)
93  {
94  cNorm += std::pow(metric.Evaluate(centroids.col(i), newCentroids.col(i)),
95  2.0);
96  }
97  distanceCalculations += centroids.n_cols;
98 
99  return std::sqrt(cNorm);
100 }
101 
102 } // namespace kmeans
103 } // namespace mlpack
104 
105 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
NaiveKMeans(const MatType &dataset, MetricType &metric)
Construct the NaiveKMeans object with the given dataset and metric.
Definition: naive_kmeans_impl.hpp:26
double Iterate(const arma::mat &centroids, arma::mat &newCentroids, arma::Col< size_t > &counts)
Run a single iteration of the Lloyd algorithm, updating the given centroids into the newCentroids matrix.
Definition: naive_kmeans_impl.hpp:35
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38