// MeanShift constructor: stores the parameters that mean shift will be run
// with. Per the generated reference entry for this constructor, the defaults
// are radius = 0, maxIterations = 1000 and a default-constructed KernelType.
//
// NOTE(review): this listing was extracted from a rendered source page. The
// numbers fused into the lines (12, 13, 32, 35, 36, 38, ...) are upstream line
// numbers, and their gaps show that lines are missing here (33-34, 37, 39+):
// presumably the class-qualified constructor name, the `radius` parameter and
// initializer, the `kernel` initializer and the (empty) body. Restore this
// block from the upstream file before compiling.
12 #ifndef MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_IMPL_HPP 13 #define MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_IMPL_HPP 32 template<
bool UseKernel,
typename KernelType,
typename MatType>
35 const size_t maxIterations,
36 const KernelType kernel) :
38 maxIterations(maxIterations),
// Setter for the `radius` member used by the clustering routines. It is
// presumably `void Radius(double radius)`, the counterpart of the
// `double Radius() const` getter.
// NOTE(review): upstream lines 45-46 (the qualified signature and the opening
// brace) and 48+ are missing from this listing; restore them from upstream.
44 template<
bool UseKernel,
typename KernelType,
typename MatType>
47 this->radius = radius;
// EstimateRadius(data, ratio = 0.2): heuristic radius estimate for a dataset.
// It runs a nearest-neighbor search with nNeighbors = size_t(n_cols * ratio).
// It then takes the maximum of the returned distances for each query point.
// Finally it divides the sum of those maxima by the number of points.
//
// NOTE(review): upstream lines 52-61, 64, 66-67 and 69-70 are missing from this
// listing, including the signature, the construction of `neighborSearch` and
// the declaration of `distances`; restore them from the upstream file.
51 template<
bool UseKernel,
typename KernelType,
typename MatType>
// size_t() truncates toward zero, so a small dataset or a small ratio can
// yield nNeighbors == 0. NOTE(review): consider guarding that case; what
// Search() does with k == 0 is not visible here.
62 const size_t nNeighbors = size_t(data.n_cols * ratio);
63 arma::Mat<size_t> neighbors;
65 neighborSearch.
Search(nNeighbors, neighbors, distances);
// Column-wise maximum, giving one value per column (hence the row vector).
// This assumes `distances` holds one column per query point (nNeighbors x
// n_cols). TODO confirm against the Search() contract.
68 arma::rowvec maxDistances = max(distances);
// Mean of the per-point maxima, under the layout assumption above.
71 return arma::sum(maxDistances) / (double) data.n_cols;
// Comparator for column vectors. It is used as `less<VecType>` to order the
// std::map of bins in GenSeeds. It compares the two vectors element by element
// and returns `first(i) < second(i)` at an index whose elements differ.
// NOTE(review): upstream lines 76-78, 80, 82, 84 and 86+ are missing (the class
// header, braces, and presumably a `continue;` after the equality test plus a
// final `return false;` for equal vectors, which together would make this a
// lexicographic ordering). Restore them from the upstream file.
75 template <
typename VecType>
// Only first.n_rows elements are visited, so both vectors are assumed to have
// the same length. Exact == is adequate for the GenSeeds keys because they are
// floor()-ed bin coordinates.
79 bool operator()(
const VecType& first,
const VecType& second)
const 81 for (
size_t i = 0; i < first.n_rows; ++i)
83 if (first[i] == second[i])
85 return first(i) < second(i);
// GenSeeds: builds seed points for mean shift by binning the data.
// - Each column of `data` is mapped to its bin, floor(x / binSize).
// - The number of points per bin is counted in a std::map.
// - Every bin with at least minFreq points becomes one column of `seeds`.
//   `seeds` is sized data.n_rows x <number of kept bins>.
// Bins are emitted in the map's key order. Cluster() calls this with
// binSize = radius and minFreq = 1.
//
// NOTE(review): upstream lines 93-98 (signature), 102, 106, 108-111, 113,
// 116-117, 119, 121, 123 and 125+ are missing from this listing. Among them
// are presumably the `else` between the two map updates, the declaration,
// increments and reset of `count`, and any post-processing of `seeds` (the
// stored keys are bin indices; whether they are scaled back by binSize is not
// visible here). Restore from the upstream file.
92 template<
bool UseKernel,
typename KernelType,
typename MatType>
99 typedef arma::colvec VecType;
100 std::map<VecType, int, less<VecType> > allSeeds;
101 for (
size_t i = 0; i < data.n_cols; ++i)
103 VecType binnedPoint = arma::floor(data.unsafe_col(i) / binSize);
// Count points per bin. NOTE(review): since std::map::operator[]
// value-initializes a missing int to 0, a single `++allSeeds[binnedPoint];`
// would do the same with one lookup instead of two.
104 if (allSeeds.find(binnedPoint) == allSeeds.end())
105 allSeeds[binnedPoint] = 1;
107 allSeeds[binnedPoint]++;
112 std::map<VecType, int, less<VecType> >::iterator it;
// First pass: count the bins that meet the minFreq threshold, so that
// `seeds` can be sized once below.
114 for (it = allSeeds.begin(); it != allSeeds.end(); ++it)
115 if (it->second >= minFreq)
118 seeds.set_size(data.n_rows, count);
// Second pass: copy each qualifying bin key into the next column of `seeds`.
120 for (it = allSeeds.begin(); it != allSeeds.end(); ++it)
122 if (it->second >= minFreq)
124 seeds.col(count) = it->first;
// CalculateCentroid, kernel-weighted overload (enabled when ApplyKernel is
// true; returns bool).
// - Adds each neighbor point into `centroid`, weighted by
//   kernel.Gradient(d) / d with d = distance / radius.
// - Then divides `centroid` by sumWeight.
// Cluster() keeps the previous centroid when this returns false.
//
// Neighbors at distance 0 are skipped because the weight divides by d.
// `centroid` is accumulated into, so it must arrive zero-initialized (Cluster()
// passes arma::zeros).
//
// NOTE(review): upstream lines 136-137 (presumably the qualified name and the
// `data` parameter), 141, 144, 146, 149, 151-155 and 157+ are missing. Line 149
// presumably holds `sumWeight += weight;`. The missing lines around the
// division presumably guard against sumWeight == 0 and choose the return
// value. Without such a guard, the division below would yield non-finite
// values. Confirm against upstream.
133 template<
bool UseKernel,
typename KernelType,
typename MatType>
134 template<
bool ApplyKernel>
135 typename std::enable_if<ApplyKernel, bool>::type
138 const std::vector<size_t>& neighbors,
139 const std::vector<double>& distances,
140 arma::colvec& centroid)
142 double sumWeight = 0;
143 for (
size_t i = 0; i < neighbors.size(); ++i)
145 if (distances[i] > 0)
147 double dist = distances[i] / radius;
148 double weight = kernel.Gradient(dist) / dist;
150 centroid += weight * data.unsafe_col(neighbors[i]);
156 centroid /= sumWeight;
// CalculateCentroid, unweighted overload (enabled when ApplyKernel is false;
// returns bool). It sets `centroid` to the plain mean of the neighbor points.
// The distances argument is unnamed because this overload does not use it.
// `centroid` is accumulated into, so it must arrive zero-initialized (Cluster()
// passes arma::zeros).
// NOTE(review): an empty `neighbors` would make the division below 0/0.
// Cluster() tests `neighbors[0].size() == 0` before calling, but the statement
// guarded by that test is missing from this listing; confirm that it skips
// this call.
// NOTE(review): upstream lines 166-167, 171, 174 and 176+ are missing
// (presumably the qualified name, the `data` parameter, braces and the return
// statement).
163 template<
bool UseKernel,
typename KernelType,
typename MatType>
164 template<
bool ApplyKernel>
165 typename std::enable_if<!ApplyKernel, bool>::type
168 const std::vector<size_t>& neighbors,
169 const std::vector<double>&,
170 arma::colvec& centroid)
172 for (
size_t i = 0; i < neighbors.size(); ++i)
173 centroid += data.unsafe_col(neighbors[i]);
175 centroid /= neighbors.size();
// Cluster(data, assignments, centroids, forceConvergence = true,
//         useSeeds = true): runs mean shift clustering on `data`.
// The visible steps are:
//  1. Seeds start as the data points themselves. They may instead be bin seeds
//     from GenSeeds(data, radius, 1, seeds).
//  2. Each seed is repeatedly moved to the centroid of its neighbors within
//     validRadius. This stops when the shift is below 1e-3 * radius, or
//     after maxIterations iterations unless forceConvergence is set.
//  3. A converged point becomes a cluster centroid unless it lies within
//     `radius` of a centroid that was already kept.
//  4. If nothing converged, a single fallback centroid is used, with a
//     warning.
//  5. Otherwise each data point is assigned to its nearest centroid with a
//     k = 1 neighbor search.
//
// NOTE(review): many upstream lines are missing from this listing (the fused
// line numbers skip, e.g. 184-185, 189-197, 199-200, 229-231, 234-236,
// 246-252, 283-289). They include the start of the signature, the radius
// handling, the useSeeds branch, the construction of rangeSearcher,
// validRadius and neighborSearcher, the distance call in the convergence
// test, the loop exits, and the single-centroid branch body. Restore from the
// upstream file before compiling.
183 template<
bool UseKernel,
typename KernelType,
typename MatType>
186 arma::Row<size_t>& assignments,
187 arma::mat& centroids,
188 bool forceConvergence,
// By default the seeds are the data points themselves.
198 const MatType* pSeeds = &data;
// Bin the data into seeds (bin size = radius, minFreq = 1). NOTE(review): the
// enclosing `if (useSeeds)` test and the repointing of pSeeds at `seeds` are
// presumably on the missing lines 199-202.
201 GenSeeds(data, radius, 1, seeds);
// One working centroid per seed: column i tracks seed i.
206 arma::mat allCentroids(pSeeds->n_rows, pSeeds->n_cols);
// One cluster label per data point.
208 assignments.set_size(data.n_cols);
// Range-search outputs. Each search below queries a single centroid, so only
// the [0] entries are read.
212 std::vector<std::vector<size_t> > neighbors;
213 std::vector<std::vector<double> > distances;
216 for (
size_t i = 0; i < pSeeds->n_cols; ++i)
// Start each working centroid at its seed.
219 allCentroids.col(i) = pSeeds->unsafe_col(i);
// NOTE(review): when forceConvergence is true this condition is always true.
// The loop can then only end through statements in its body, and some of
// those (e.g. after the empty-neighborhood test) are on missing lines. A
// seed that never converges would loop forever; consider an upper bound.
220 for (
size_t completedIterations = 0; completedIterations < maxIterations
221 || forceConvergence; completedIterations++)
// Fresh zero vector: CalculateCentroid accumulates into it.
224 arma::colvec newCentroid = arma::zeros<arma::colvec>(pSeeds->n_rows);
226 rangeSearcher.
Search(allCentroids.unsafe_col(i), validRadius,
227 neighbors, distances);
// No data points within validRadius of the current centroid. NOTE(review): the
// statement guarded by this test is missing (presumably `break;`).
228 if (neighbors[0].size() == 0)
// If no centroid could be computed, keep the current one.
232 if (!CalculateCentroid(data, neighbors[0], distances[0], newCentroid))
233 newCentroid = allCentroids.unsafe_col(i);
// Converged when the shift is below 1e-3 * radius. NOTE(review): the distance
// call that opens this condition (line 236) is missing; presumably it is a
// static metric Evaluate(newCentroid, ...) call.
237 allCentroids.unsafe_col(i)) < 1e-3 * radius)
// A converged point counts as a duplicate if it lies within `radius` of a
// centroid that was already kept.
240 bool isDuplicated =
false;
241 for (
size_t k = 0; k < centroids.n_cols; ++k)
// NOTE(review): the declaration of `distance` (line 243) is missing.
244 allCentroids.unsafe_col(i), centroids.unsafe_col(k));
245 if (distance < radius)
// Keep the converged point as a new centroid (the guarding non-duplicate test
// is presumably on the missing lines 246-252).
253 centroids.insert_cols(centroids.n_cols, allCentroids.unsafe_col(i));
// Not converged yet: move to the new centroid and iterate again.
260 allCentroids.col(i) = newCentroid;
// Fallback when no seed converged, e.g. too few iterations without
// forceConvergence.
266 if (centroids.empty())
268 Log::Warn <<
"No clusters converged; setting 1 random centroid calculated. " 269 <<
"Try increasing the maximum number of iterations or setting the " 270 <<
"option to force convergence." << std::endl;
// With zero iterations no seed was ever shifted, so the first data point is
// used. Otherwise the first working centroid is used.
272 if (maxIterations == 0)
274 centroids.insert_cols(centroids.n_cols, data.col(0));
278 centroids.insert_cols(centroids.n_cols, allCentroids.col(0));
// Exactly one centroid. NOTE(review): the statements for this case are on the
// missing lines 283-289 (presumably every point is assigned to cluster 0).
282 else if (centroids.n_cols == 1)
// Assign each data point to its nearest centroid (k = 1 search).
// NOTE(review): the construction of neighborSearcher is on a missing line;
// presumably it is built over `centroids`. The assignment to `assignments`
// assumes resultingNeighbors is 1 x data.n_cols. TODO confirm.
290 arma::mat neighborDistances;
291 arma::Mat<size_t> resultingNeighbors;
292 neighborSearcher.
Search(data, 1, resultingNeighbors, neighborDistances);
293 assignments = resultingNeighbors;
The RangeSearch class is a template class for performing range searches.
Definition: range_search.hpp:45
double Radius() const
Get the radius.
Definition: mean_shift.hpp:100
double EstimateRadius(const MatType &data, const double ratio=0.2)
Give an estimate of the radius based on the given dataset.
Definition: mean_shift_impl.hpp:53
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The NeighborSearch class is a template class for performing distance-based neighbor searches.
Definition: neighbor_search.hpp:88
This class implements mean shift clustering.
Definition: mean_shift.hpp:50
MeanShift(const double radius=0, const size_t maxIterations=1000, const KernelType kernel=KernelType())
Create a mean shift object and set the parameters which mean shift will be run with.
Definition: mean_shift_impl.hpp:34
static VecTypeA::elem_type Evaluate(const VecTypeA &a, const VecTypeB &b)
Computes the distance between two points.
Definition: lmetric_impl.hpp:24
Definition: mean_shift_impl.hpp:76
static MLPACK_EXPORT util::PrefixedOutStream Warn
Prints warning messages prefixed with [WARN ].
Definition: log.hpp:87
void Cluster(const MatType &data, arma::Row< size_t > &assignments, arma::mat &centroids, bool forceConvergence=true, bool useSeeds=true)
Perform mean shift clustering on the data, returning a list of cluster assignments and centroids.
Definition: mean_shift_impl.hpp:184
void Search(const MatType &querySet, const math::Range &range, std::vector< std::vector< size_t >> &neighbors, std::vector< std::vector< double >> &distances)
Search for all reference points in the given range for each point in the query set, returning the results in the neighbors and distances objects.
Definition: range_search_impl.hpp:309
void Search(const MatType &querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
For each point in the query set, compute the nearest neighbors and store the output in the given matrices.
Definition: neighbor_search_impl.hpp:389