15 #ifndef MLPACK_METHODS_KMEANS_DTNN_KMEANS_IMPL_HPP 16 #define MLPACK_METHODS_KMEANS_DTNN_KMEANS_IMPL_HPP 27 template<
typename TreeType,
typename MatType>
30 std::vector<size_t>& oldFromNew,
31 const typename std::enable_if<
36 return new TreeType(std::forward<MatType>(dataset), oldFromNew, 1);
40 template<
typename TreeType,
typename MatType>
43 const std::vector<size_t>& ,
44 const typename std::enable_if<
47 return new TreeType(std::forward<MatType>(dataset));
50 template<
typename MetricType,
52 template<
typename TreeMetricType,
53 typename TreeStatType,
54 typename TreeMatType>
class TreeType>
56 const MatType& dataset,
59 tree(new Tree(const_cast<MatType&>(dataset))),
60 dataset(tree->Dataset()),
62 distanceCalculations(0),
64 upperBounds(dataset.n_cols),
65 lowerBounds(dataset.n_cols),
66 prunedPoints(dataset.n_cols, false),
67 assignments(dataset.n_cols),
68 visited(dataset.n_cols, false)
70 for (
size_t i = 0; i < dataset.n_cols; ++i)
72 prunedPoints[i] =
false;
75 assignments.fill(
size_t(-1));
76 upperBounds.fill(DBL_MAX);
77 lowerBounds.fill(DBL_MAX);
80 template<
typename MetricType,
82 template<
typename TreeMetricType,
83 typename TreeStatType,
84 typename TreeMatType>
class TreeType>
92 template<
typename MetricType,
94 template<
typename TreeMetricType,
95 typename TreeStatType,
96 typename TreeMatType>
class TreeType>
98 const arma::mat& centroids,
99 arma::mat& newCentroids,
100 arma::Col<size_t>& counts)
104 std::vector<size_t> oldFromNewCentroids;
105 Tree* centroidTree = BuildTree<Tree>(centroids, oldFromNewCentroids);
111 NNSTreeType> nns(std::move(*centroidTree));
119 arma::mat* interclusterDistancesTemp =
121 new arma::mat(1, centroids.n_elem) : &interclusterDistances;
123 arma::Mat<size_t> closestClusters;
124 nns.Search(1, closestClusters, *interclusterDistancesTemp);
125 distanceCalculations += nns.BaseCases() + nns.Scores();
130 for (
size_t i = 0; i < interclusterDistances.n_elem; ++i)
131 interclusterDistances[oldFromNewCentroids[i]] =
132 (*interclusterDistancesTemp)[i];
134 delete interclusterDistancesTemp;
139 UpdateTree(*tree, centroids);
141 for (
size_t i = 0; i < dataset.n_cols; ++i)
147 clusterDistances.set_size(centroids.n_cols + 1);
148 interclusterDistances.set_size(1, centroids.n_cols);
152 lastIterationCentroids = centroids;
154 RuleType rules(nns.ReferenceTree().Dataset(), dataset, assignments,
155 upperBounds, lowerBounds, metric, prunedPoints, oldFromNewCentroids,
158 typename Tree::template BreadthFirstDualTreeTraverser<RuleType>
166 tree->Stat().Pruned() = 0;
167 traverser.Traverse(*tree, nns.ReferenceTree());
168 distanceCalculations += rules.BaseCases() + rules.Scores();
171 DecoalesceTree(*tree);
175 newCentroids.zeros(centroids.n_rows, centroids.n_cols);
176 counts.zeros(centroids.n_cols);
177 ExtractCentroids(*tree, newCentroids, counts, centroids);
180 double residual = 0.0;
181 clusterDistances[centroids.n_cols] = 0.0;
182 for (
size_t c = 0; c < centroids.n_cols; ++c)
186 clusterDistances[c] = 0;
190 newCentroids.col(c) /= counts(c);
191 const double movement = metric.Evaluate(centroids.col(c),
192 newCentroids.col(c));
193 clusterDistances[c] = movement;
194 residual += std::pow(movement, 2.0);
196 if (movement > clusterDistances[centroids.n_cols])
197 clusterDistances[centroids.n_cols] = movement;
200 distanceCalculations += centroids.n_cols;
206 return std::sqrt(residual);
209 template<
typename MetricType,
211 template<
typename TreeMetricType,
212 typename TreeStatType,
213 typename TreeMatType>
class TreeType>
214 void DualTreeKMeans<MetricType, MatType, TreeType>::UpdateTree(
216 const arma::mat& centroids,
217 const double parentUpperBound,
218 const double adjustedParentUpperBound,
219 const double parentLowerBound,
220 const double adjustedParentLowerBound)
222 const bool prunedLastIteration = node.Stat().StaticPruned();
223 node.Stat().StaticPruned() =
false;
226 if (node.Parent() != NULL &&
227 node.Parent()->Stat().Pruned() == centroids.n_cols &&
228 node.Parent()->Stat().Owner() < centroids.n_cols)
234 node.Stat().UpperBound() = parentUpperBound;
235 node.Stat().LowerBound() = parentLowerBound;
236 node.Stat().Pruned() = node.Parent()->Stat().Pruned();
237 node.Stat().Owner() = node.Parent()->Stat().Owner();
239 const double unadjustedUpperBound = node.Stat().UpperBound();
240 double adjustedUpperBound = adjustedParentUpperBound;
241 const double unadjustedLowerBound = node.Stat().LowerBound();
242 double adjustedLowerBound = adjustedParentLowerBound;
296 if ((node.Stat().Pruned() == centroids.n_cols) &&
297 (node.Stat().Owner() < centroids.n_cols))
300 node.Stat().UpperBound() += clusterDistances[node.Stat().Owner()];
301 node.Stat().LowerBound() -= clusterDistances[centroids.n_cols];
303 if (adjustedParentUpperBound < node.Stat().UpperBound())
304 node.Stat().UpperBound() = adjustedParentUpperBound;
306 if (adjustedParentLowerBound > node.Stat().LowerBound())
307 node.Stat().LowerBound() = adjustedParentLowerBound;
311 const double interclusterBound = interclusterDistances[node.Stat().Owner()]
313 if (interclusterBound > node.Stat().LowerBound())
315 node.Stat().LowerBound() = interclusterBound;
316 adjustedLowerBound = node.Stat().LowerBound();
319 if (node.Stat().UpperBound() < node.Stat().LowerBound())
321 node.Stat().StaticPruned() =
true;
326 node.Stat().UpperBound() =
327 std::min(node.Stat().UpperBound(),
328 node.MaxDistance(centroids.col(node.Stat().Owner())));
329 adjustedUpperBound = node.Stat().UpperBound();
331 ++distanceCalculations;
332 if (node.Stat().UpperBound() < node.Stat().LowerBound())
333 node.Stat().StaticPruned() =
true;
338 node.Stat().LowerBound() -= clusterDistances[centroids.n_cols];
343 bool allChildrenPruned =
true;
344 for (
size_t i = 0; i < node.NumChildren(); ++i)
346 UpdateTree(node.Child(i), centroids, unadjustedUpperBound,
347 adjustedUpperBound, unadjustedLowerBound, adjustedLowerBound);
348 if (!node.Child(i).Stat().StaticPruned())
349 allChildrenPruned =
false;
352 bool allPointsPruned =
true;
361 allPointsPruned = node.Child(0).Stat().StaticPruned();
363 else if (!node.Stat().StaticPruned())
366 for (
size_t i = 0; i < node.NumPoints(); ++i)
368 const size_t index = node.Point(i);
369 if (!visited[index] && !prunedPoints[index])
371 upperBounds[index] = DBL_MAX;
372 lowerBounds[index] = DBL_MAX;
373 allPointsPruned =
false;
378 if (prunedLastIteration)
382 upperBounds[index] += node.Stat().StaticUpperBoundMovement();
383 lowerBounds[index] -= node.Stat().StaticLowerBoundMovement();
386 prunedPoints[index] =
false;
387 const size_t owner = assignments[index];
388 const double lowerBound = std::min(lowerBounds[index] -
389 clusterDistances[centroids.n_cols], node.Stat().LowerBound());
390 const double pruningLowerBound = std::max(lowerBound,
391 interclusterDistances[owner] / 2.0);
392 if (upperBounds[index] + clusterDistances[owner] < pruningLowerBound)
394 prunedPoints[index] =
true;
395 upperBounds[index] += clusterDistances[owner];
396 lowerBounds[index] = pruningLowerBound;
401 upperBounds[index] = metric.Evaluate(dataset.col(index),
402 centroids.col(owner));
403 ++distanceCalculations;
404 if (upperBounds[index] < pruningLowerBound)
406 prunedPoints[index] =
true;
407 lowerBounds[index] = pruningLowerBound;
416 node.NumChildren() == 0)
418 upperBounds[index] = DBL_MAX;
419 lowerBounds[index] = DBL_MAX;
421 allPointsPruned =
false;
439 if (allChildrenPruned && allPointsPruned && !node.Stat().StaticPruned())
441 node.Stat().StaticPruned() =
true;
442 node.Stat().Owner() = centroids.n_cols;
443 node.Stat().Pruned() = size_t(-1);
446 if (!node.Stat().StaticPruned())
448 node.Stat().UpperBound() = DBL_MAX;
449 node.Stat().LowerBound() = DBL_MAX;
450 node.Stat().Pruned() = size_t(-1);
451 node.Stat().Owner() = centroids.n_cols;
452 node.Stat().StaticPruned() =
false;
456 if (prunedLastIteration)
459 node.Stat().StaticUpperBoundMovement() +=
460 clusterDistances[node.Stat().Owner()];
461 node.Stat().StaticLowerBoundMovement() +=
462 clusterDistances[centroids.n_cols];
466 node.Stat().StaticUpperBoundMovement() =
467 clusterDistances[node.Stat().Owner()];
468 node.Stat().StaticLowerBoundMovement() =
469 clusterDistances[centroids.n_cols];
474 template<
typename MetricType,
476 template<
typename TreeMetricType,
477 typename TreeStatType,
478 typename TreeMatType>
class TreeType>
479 void DualTreeKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
481 arma::mat& newCentroids,
482 arma::Col<size_t>& newCounts,
483 const arma::mat& centroids)
486 if ((node.Stat().Pruned() == newCentroids.n_cols) ||
487 (node.Stat().StaticPruned() && node.Stat().Owner() < newCentroids.n_cols))
489 const size_t owner = node.Stat().Owner();
490 newCentroids.col(owner) += node.Stat().Centroid() * node.NumDescendants();
491 newCounts[owner] += node.NumDescendants();
524 if (node.NumChildren() == 0)
526 for (
size_t i = 0; i < node.NumPoints(); ++i)
528 const size_t owner = assignments[node.Point(i)];
529 newCentroids.col(owner) += dataset.col(node.Point(i));
562 for (
size_t i = 0; i < node.NumChildren(); ++i)
563 ExtractCentroids(node.Child(i), newCentroids, newCounts, centroids);
567 template<
typename MetricType,
569 template<
typename TreeMetricType,
570 typename TreeStatType,
571 typename TreeMatType>
class TreeType>
572 void DualTreeKMeans<MetricType, MatType, TreeType>::CoalesceTree(
577 if (node.NumChildren() == 0)
581 if (node.Parent() != NULL)
584 for (
size_t i = node.NumChildren() - 1; i > 0; --i)
586 if (node.Child(i).Stat().StaticPruned())
589 CoalesceTree(node.Child(i), i);
592 if (node.Child(0).Stat().StaticPruned())
595 CoalesceTree(node.Child(0), 0);
601 if (node.NumChildren() == 1)
603 node.Child(0).Parent() = node.Parent();
604 node.Parent()->ChildPtr(child) = node.ChildPtr(0);
611 for (
size_t i = 0; i < node.NumChildren(); ++i)
612 CoalesceTree(node.Child(i), i);
616 template<
typename MetricType,
618 template<
typename TreeMetricType,
619 typename TreeStatType,
620 typename TreeMatType>
class TreeType>
621 void DualTreeKMeans<MetricType, MatType, TreeType>::DecoalesceTree(Tree& node)
623 node.Parent() = (
Tree*) node.Stat().TrueParent();
626 for (
size_t i = 0; i < node.NumChildren(); ++i)
627 DecoalesceTree(node.Child(i));
631 template<
typename TreeType>
634 const typename std::enable_if_t<
639 node.Children().erase(node.Children().begin() + child);
643 template<
typename TreeType>
646 const typename std::enable_if_t<
653 node.ChildPtr(0) = node.ChildPtr(1);
654 node.ChildPtr(1) = NULL;
658 node.ChildPtr(1) = NULL;
663 template<
typename TreeType>
665 const typename std::enable_if_t<
668 node.Children().clear();
669 node.Children().resize(node.Stat().NumTrueChildren());
670 for (
size_t i = 0; i < node.Stat().NumTrueChildren(); ++i)
671 node.Children()[i] = (TreeType*) node.Stat().TrueChild(i);
675 template<
typename TreeType>
677 const typename std::enable_if_t<
680 if (node.Stat().NumTrueChildren() > 0)
682 node.ChildPtr(0) = (TreeType*) node.Stat().TrueChild(0);
683 node.ChildPtr(1) = (TreeType*) node.Stat().TrueChild(1);
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
TreeType< MetricType, DualTreeKMeansStatistic, MatType > Tree
Convenience typedef.
Definition: dual_tree_kmeans.hpp:45
void HideChild(TreeType &node, const size_t child, const typename std::enable_if_t< !tree::TreeTraits< TreeType >::BinaryTree > *junk=0)
Utility function for hiding children.
Definition: dual_tree_kmeans_impl.hpp:632
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The NeighborSearch class is a template class for performing distance-based neighbor searches.
Definition: neighbor_search.hpp:88
void RestoreChildren(TreeType &node, const typename std::enable_if_t<!tree::TreeTraits< TreeType >::BinaryTree > *junk=0)
Utility function for restoring children to a non-binary tree.
An algorithm for an exact Lloyd iteration which simply uses dual-tree nearest-neighbor search to find the nearest centroid of each point in the dataset.
Definition: dual_tree_kmeans.hpp:41
This class implements the necessary methods for the SortPolicy template parameter of the NeighborSearch class.
Definition: nearest_neighbor_sort.hpp:31
Definition: dual_tree_kmeans_rules.hpp:23
TreeType * BuildTree(MatType &&dataset, std::vector< size_t > &oldFromNew, const typename std::enable_if< tree::TreeTraits< TreeType >::RearrangesDataset >::type *=0)
Call the tree constructor that does mapping.
Definition: dual_tree_kmeans_impl.hpp:28
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
The TreeTraits class provides compile-time information on the characteristics of a given tree type.
Definition: tree_traits.hpp:77