// Include guard for this implementation header, followed on the same listing
// line by the start of the template header for the PellegMooreKMeansRules
// constructor.
// NOTE(review): this is a lossy listing -- the embedded line numbers skip, so
// the constructor's name line, its last parameter, several member
// initializers, and its body are not visible here.
14 #ifndef MLPACK_METHODS_KMEANS_PELLEG_MOORE_KMEANS_RULES_IMPL_HPP 15 #define MLPACK_METHODS_KMEANS_PELLEG_MOORE_KMEANS_RULES_IMPL_HPP 23 template<
typename MetricType,
typename TreeType>
// Visible constructor parameters: the dataset matrix (read-only), the current
// centroids (read-only), and two mutable outputs -- newCentroids and counts.
25 const typename TreeType::Mat& dataset,
26 const arma::mat& centroids,
27 arma::mat& newCentroids,
28 arma::Col<size_t>& counts,
// Member initializers: the newCentroids member is initialized from the
// parameter of the same name, and the distance-calculation counter starts at 0.
32 newCentroids(newCentroids),
35 distanceCalculations(0)
// Template header at listing line 40. The function signature and body that
// follow it are not shown in this listing.
// NOTE(review): presumably this introduces BaseCase(), which the listing's
// tooltip places at line 42 and describes as doing nothing for this
// single-tree algorithm -- confirm against the full file.
40 template<
typename MetricType,
typename TreeType>
// Score(): decides, for one tree node, which centroids can still own points in
// that node, and accumulates the node's contribution into newCentroids/counts.
// NOTE(review): this is a lossy listing -- the embedded line numbers skip, so
// braces, `else` keywords, `continue` statements, the first parameter, and the
// return statements are not visible. Comments below describe only the visible
// lines and flag what is presumed.
49 template<
typename MetricType,
typename TreeType>
52 TreeType& referenceNode)
// Blacklist setup: a node with no parent, or whose parent has an empty
// blacklist, gets a fresh all-zero blacklist with one entry per centroid.
57 if (referenceNode.Parent() == NULL ||
58 referenceNode.Parent()->Stat().Blacklist().n_elem == 0)
59 referenceNode.Stat().Blacklist().zeros(centroids.n_cols);
// Presumably the `else` branch (the keyword is on an omitted line): inherit
// the parent's blacklist by copy.
61 referenceNode.Stat().Blacklist() =
62 referenceNode.Parent()->Stat().Blacklist();
// Number of centroids not blacklisted: total centroids minus the sum of the
// blacklist entries.
67 const size_t whitelisted = centroids.n_cols -
68 arma::accu(referenceNode.Stat().Blacklist());
// Advance the counter by the number of whitelisted centroids, matching the one
// MinDistance() call per whitelisted centroid in the loop that follows.
70 distanceCalculations += whitelisted;
// Search the whitelisted centroids for the one with the smallest minimum
// distance to this node. closestCluster starts at centroids.n_cols, which is
// not a valid column index; the assignment that records the winning index is
// presumably on an omitted line.
73 size_t closestCluster = centroids.n_cols;
74 double minMinDistance = DBL_MAX;
75 for (
size_t i = 0; i < centroids.n_cols; ++i)
77 if (referenceNode.Stat().Blacklist()[i] == 0)
79 const double minDistance = referenceNode.MinDistance(centroids.col(i));
80 if (minDistance < minMinDistance)
82 minMinDistance = minDistance;
// Try to blacklist each remaining centroid c against closestCluster.
// newBlacklisted is initialized here; its increment is presumably on an
// omitted line next to the blacklist assignment below.
93 size_t newBlacklisted = 0;
94 for (
size_t c = 0; c < centroids.n_cols; ++c)
// Centroids already blacklisted, and closestCluster itself, are not tested
// (the skipping statement is on an omitted line).
96 if (referenceNode.Stat().Blacklist()[c] == 1 || c == closestCluster)
// Build a corner of the node's bound: in each dimension take the bound's Hi()
// when centroid c's coordinate exceeds closestCluster's coordinate, and Lo()
// otherwise (the `else` keyword is on an omitted line).
102 arma::vec cornerPoint(centroids.n_rows);
103 for (
size_t d = 0; d < referenceNode.Bound().Dim(); ++d)
105 if (centroids(d, c) > centroids(d, closestCluster))
106 cornerPoint(d) = referenceNode.Bound()[d].Hi();
108 cornerPoint(d) = referenceNode.Bound()[d].Lo();
// Distances from that corner to the closest centroid and to centroid c.
111 const double closestDist = metric.Evaluate(cornerPoint,
112 centroids.col(closestCluster));
113 const double otherDist = metric.Evaluate(cornerPoint, centroids.col(c));
// NOTE(review): the counter advances by 3, but only two Evaluate() calls are
// visible here -- confirm what the third unit accounts for.
115 distanceCalculations += 3;
// If the closest centroid is nearer than c even at this corner, mark c as
// blacklisted for this node.
117 if (closestDist < otherDist)
121 referenceNode.Stat().Blacklist()[c] = 1;
// Exactly one centroid remains un-blacklisted: credit every descendant point
// of this node to closestCluster in one step, adding the descendant count to
// counts and the node's centroid scaled by that count to newCentroids.
126 if (whitelisted - newBlacklisted == 1)
129 counts[closestCluster] += referenceNode.NumDescendants();
130 newCentroids.col(closestCluster) += referenceNode.NumDescendants() *
131 referenceNode.Stat().Centroid();
// Per-point pass over the points held directly in this node. The enclosing
// condition is not visible in this listing.
137 for (
size_t i = 0; i < referenceNode.NumPoints(); ++i)
// Find the nearest centroid for point i among those not blacklisted.
// bestCluster starts at centroids.n_cols (not a valid index); the assignment
// that records the winning index is presumably on an omitted line.
139 size_t bestCluster = centroids.n_cols;
140 double bestDistance = DBL_MAX;
141 for (
size_t c = 0; c < centroids.n_cols; ++c)
// Blacklisted centroids are not evaluated (the skipping statement is on an
// omitted line).
143 if (referenceNode.Stat().Blacklist()[c] == 1)
// One distance evaluation per non-blacklisted centroid for this point.
146 ++distanceCalculations;
149 const double distance = metric.Evaluate(centroids.col(c),
150 dataset.col(referenceNode.Point(i)));
152 if (distance < bestDistance)
154 bestDistance = distance;
// Add the point to the running sum of its best cluster and bump that
// cluster's count.
160 newCentroids.col(bestCluster) += dataset.col(referenceNode.Point(i));
161 ++counts(bestCluster);
// Rescore(): only the template header and the final parameter (the previously
// computed score, passed by const value) are visible in this listing; the
// function-name line, earlier parameters, and body are omitted.
// NOTE(review): the listing's tooltip describes this as "Rescore to determine
// if a node can be pruned" -- the body is not available here to confirm how
// oldScore is used.
169 template<
typename MetricType,
typename TreeType>
173 const double oldScore)
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore)
Rescore to determine if a node can be pruned.
Definition: pelleg_moore_kmeans_rules_impl.hpp:170
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
The BaseCase() function for this single-tree algorithm does nothing.
Definition: pelleg_moore_kmeans_rules_impl.hpp:42
PellegMooreKMeansRules(const typename TreeType::Mat &dataset, const arma::mat &centroids, arma::mat &newCentroids, arma::Col< size_t > &counts, MetricType &metric)
Create the PellegMooreKMeansRules object.
Definition: pelleg_moore_kmeans_rules_impl.hpp:24
double Score(const size_t queryIndex, TreeType &referenceNode)
Determine if a cluster can be pruned, and if not, perform point-to-cluster comparisons.
Definition: pelleg_moore_kmeans_rules_impl.hpp:50