mlpack
mean_shift_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_IMPL_HPP
13 #define MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_IMPL_HPP
14 
20 
21 #include "map"
22 
23 // In case it hasn't been included yet.
24 #include "mean_shift.hpp"
25 
26 namespace mlpack {
27 namespace meanshift {
28 
32 template<bool UseKernel, typename KernelType, typename MatType>
34 MeanShift(const double radius,
35  const size_t maxIterations,
36  const KernelType kernel) :
37  radius(radius),
38  maxIterations(maxIterations),
39  kernel(kernel)
40 {
41  // Nothing to do.
42 }
43 
44 template<bool UseKernel, typename KernelType, typename MatType>
46 {
47  this->radius = radius;
48 }
49 
50 // Estimate radius based on given dataset.
51 template<bool UseKernel, typename KernelType, typename MatType>
53 EstimateRadius(const MatType& data, double ratio)
54 {
55  neighbor::KNN neighborSearch(data);
56 
62  const size_t nNeighbors = size_t(data.n_cols * ratio);
63  arma::Mat<size_t> neighbors;
64  arma::mat distances;
65  neighborSearch.Search(nNeighbors, neighbors, distances);
66 
67  // Get max distance for each point.
68  arma::rowvec maxDistances = max(distances);
69 
70  // Calculate and return the radius.
71  return arma::sum(maxDistances) / (double) data.n_cols;
72 }
73 
// Strict lexicographic "less than" for column vectors; used as the ordering of
// the std::map that counts points per grid cell in GenSeeds().
//
// Elements are compared in order with operator<.  When one vector is a prefix
// of the other, the shorter one orders first.  The original loop only looked
// at first.n_rows and so read past the end of a shorter `second`.
template <typename VecType>
class less
{
 public:
  bool operator()(const VecType& first, const VecType& second) const
  {
    const size_t shared = (first.n_rows < second.n_rows) ? first.n_rows :
        second.n_rows;
    for (size_t i = 0; i < shared; ++i)
    {
      if (first[i] < second[i])
        return true;
      if (second[i] < first[i])
        return false;
    }
    return first.n_rows < second.n_rows;
  }
};
90 
91 // Generate seeds from given data set.
92 template<bool UseKernel, typename KernelType, typename MatType>
94  const MatType& data,
95  const double binSize,
96  const int minFreq,
97  MatType& seeds)
98 {
99  typedef arma::colvec VecType;
100  std::map<VecType, int, less<VecType> > allSeeds;
101  for (size_t i = 0; i < data.n_cols; ++i)
102  {
103  VecType binnedPoint = arma::floor(data.unsafe_col(i) / binSize);
104  if (allSeeds.find(binnedPoint) == allSeeds.end())
105  allSeeds[binnedPoint] = 1;
106  else
107  allSeeds[binnedPoint]++;
108  }
109 
110  // Remove seeds with too few points. First we count the number of seeds we
111  // end up with, then we add them.
112  std::map<VecType, int, less<VecType> >::iterator it;
113  size_t count = 0;
114  for (it = allSeeds.begin(); it != allSeeds.end(); ++it)
115  if (it->second >= minFreq)
116  ++count;
117 
118  seeds.set_size(data.n_rows, count);
119  count = 0;
120  for (it = allSeeds.begin(); it != allSeeds.end(); ++it)
121  {
122  if (it->second >= minFreq)
123  {
124  seeds.col(count) = it->first;
125  ++count;
126  }
127  }
128 
129  seeds *= binSize;
130 }
131 
132 // Calculate new centroid with given kernel.
133 template<bool UseKernel, typename KernelType, typename MatType>
134 template<bool ApplyKernel>
135 typename std::enable_if<ApplyKernel, bool>::type
137 CalculateCentroid(const MatType& data,
138  const std::vector<size_t>& neighbors,
139  const std::vector<double>& distances,
140  arma::colvec& centroid)
141 {
142  double sumWeight = 0;
143  for (size_t i = 0; i < neighbors.size(); ++i)
144  {
145  if (distances[i] > 0)
146  {
147  double dist = distances[i] / radius;
148  double weight = kernel.Gradient(dist) / dist;
149  sumWeight += weight;
150  centroid += weight * data.unsafe_col(neighbors[i]);
151  }
152  }
153 
154  if (sumWeight != 0)
155  {
156  centroid /= sumWeight;
157  return true;
158  }
159  return false;
160 }
161 
162 // Calculate new centroid by mean.
163 template<bool UseKernel, typename KernelType, typename MatType>
164 template<bool ApplyKernel>
165 typename std::enable_if<!ApplyKernel, bool>::type
167 CalculateCentroid(const MatType& data,
168  const std::vector<size_t>& neighbors,
169  const std::vector<double>&, /*unused*/
170  arma::colvec& centroid)
171 {
172  for (size_t i = 0; i < neighbors.size(); ++i)
173  centroid += data.unsafe_col(neighbors[i]);
174 
175  centroid /= neighbors.size();
176  return true;
177 }
178 
183 template<bool UseKernel, typename KernelType, typename MatType>
185  const MatType& data,
186  arma::Row<size_t>& assignments,
187  arma::mat& centroids,
188  bool forceConvergence,
189  bool useSeeds)
190 {
191  if (radius <= 0)
192  {
193  // An invalid radius is given; an estimation is needed.
194  Radius(EstimateRadius(data));
195  }
196 
197  MatType seeds;
198  const MatType* pSeeds = &data;
199  if (useSeeds)
200  {
201  GenSeeds(data, radius, 1, seeds);
202  pSeeds = &seeds;
203  }
204 
205  // Holds all centroids before removing duplicate ones.
206  arma::mat allCentroids(pSeeds->n_rows, pSeeds->n_cols);
207 
208  assignments.set_size(data.n_cols);
209 
210  range::RangeSearch<> rangeSearcher(data);
211  math::Range validRadius(0, radius);
212  std::vector<std::vector<size_t> > neighbors;
213  std::vector<std::vector<double> > distances;
214 
215  // For each seed, perform mean shift algorithm.
216  for (size_t i = 0; i < pSeeds->n_cols; ++i)
217  {
218  // Initial centroid is the seed itself.
219  allCentroids.col(i) = pSeeds->unsafe_col(i);
220  for (size_t completedIterations = 0; completedIterations < maxIterations
221  || forceConvergence; completedIterations++)
222  {
223  // Store new centroid in this.
224  arma::colvec newCentroid = arma::zeros<arma::colvec>(pSeeds->n_rows);
225 
226  rangeSearcher.Search(allCentroids.unsafe_col(i), validRadius,
227  neighbors, distances);
228  if (neighbors[0].size() == 0) // There are no points in the cluster.
229  break;
230 
231  // Calculate new centroid.
232  if (!CalculateCentroid(data, neighbors[0], distances[0], newCentroid))
233  newCentroid = allCentroids.unsafe_col(i);
234 
235  // If the mean shift vector is small enough, it has converged.
236  if (metric::EuclideanDistance::Evaluate(newCentroid,
237  allCentroids.unsafe_col(i)) < 1e-3 * radius)
238  {
239  // Determine if the new centroid is duplicate with old ones.
240  bool isDuplicated = false;
241  for (size_t k = 0; k < centroids.n_cols; ++k)
242  {
243  const double distance = metric::EuclideanDistance::Evaluate(
244  allCentroids.unsafe_col(i), centroids.unsafe_col(k));
245  if (distance < radius)
246  {
247  isDuplicated = true;
248  break;
249  }
250  }
251 
252  if (!isDuplicated)
253  centroids.insert_cols(centroids.n_cols, allCentroids.unsafe_col(i));
254 
255  // Get out of the loop.
256  break;
257  }
258 
259  // Update the centroid.
260  allCentroids.col(i) = newCentroid;
261  }
262  }
263 
264  // If no centroid has converged due to too little iterations and without
265  // forcing convergence, take 1 random centroid calculated.
266  if (centroids.empty())
267  {
268  Log::Warn << "No clusters converged; setting 1 random centroid calculated. "
269  << "Try increasing the maximum number of iterations or setting the "
270  << "option to force convergence." << std::endl;
271 
272  if (maxIterations == 0)
273  {
274  centroids.insert_cols(centroids.n_cols, data.col(0));
275  }
276  else
277  {
278  centroids.insert_cols(centroids.n_cols, allCentroids.col(0));
279  }
280  assignments.zeros();
281  }
282  else if (centroids.n_cols == 1)
283  {
284  assignments.zeros();
285  }
286  else
287  {
288  // Assign centroids to each point.
289  neighbor::KNN neighborSearcher(centroids);
290  arma::mat neighborDistances;
291  arma::Mat<size_t> resultingNeighbors;
292  neighborSearcher.Search(data, 1, resultingNeighbors, neighborDistances);
293  assignments = resultingNeighbors;
294  }
295 }
296 
297 } // namespace meanshift
298 } // namespace mlpack
299 
300 #endif
The RangeSearch class is a template class for performing range searches.
Definition: range_search.hpp:45
double Radius() const
Get the radius.
Definition: mean_shift.hpp:100
double EstimateRadius(const MatType &data, const double ratio=0.2)
Give an estimation of radius based on given dataset.
Definition: mean_shift_impl.hpp:53
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The NeighborSearch class is a template class for performing distance-based neighbor searches...
Definition: neighbor_search.hpp:88
This class implements mean shift clustering.
Definition: mean_shift.hpp:50
MeanShift(const double radius=0, const size_t maxIterations=1000, const KernelType kernel=KernelType())
Create a mean shift object and set the parameters which mean shift will be run with.
Definition: mean_shift_impl.hpp:34
static VecTypeA::elem_type Evaluate(const VecTypeA &a, const VecTypeB &b)
Computes the distance between two points.
Definition: lmetric_impl.hpp:24
Definition: mean_shift_impl.hpp:76
static MLPACK_EXPORT util::PrefixedOutStream Warn
Prints warning messages prefixed with [WARN ].
Definition: log.hpp:87
void Cluster(const MatType &data, arma::Row< size_t > &assignments, arma::mat &centroids, bool forceConvergence=true, bool useSeeds=true)
Perform mean shift clustering on the data, returning a list of cluster assignments and centroids.
Definition: mean_shift_impl.hpp:184
void Search(const MatType &querySet, const math::Range &range, std::vector< std::vector< size_t >> &neighbors, std::vector< std::vector< double >> &distances)
Search for all reference points in the given range for each point in the query set, returning the results in the neighbors and distances objects.
Definition: range_search_impl.hpp:309
void Search(const MatType &querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
For each point in the query set, compute the nearest neighbors and store the output in the given matrices.
Definition: neighbor_search_impl.hpp:389