mlpack
dual_tree_kmeans_impl.hpp
Go to the documentation of this file.
1 
15 #ifndef MLPACK_METHODS_KMEANS_DTNN_KMEANS_IMPL_HPP
16 #define MLPACK_METHODS_KMEANS_DTNN_KMEANS_IMPL_HPP
17 
18 // In case it hasn't been included yet.
19 #include "dual_tree_kmeans.hpp"
20 
22 
23 namespace mlpack {
24 namespace kmeans {
25 
27 template<typename TreeType, typename MatType>
28 TreeType* BuildTree(
29  MatType&& dataset,
30  std::vector<size_t>& oldFromNew,
31  const typename std::enable_if<
33 {
34  // This is a hack. I know this will be BinarySpaceTree, so force a leaf size
35  // of two.
36  return new TreeType(std::forward<MatType>(dataset), oldFromNew, 1);
37 }
38 
40 template<typename TreeType, typename MatType>
41 TreeType* BuildTree(
42  MatType&& dataset,
43  const std::vector<size_t>& /* oldFromNew */,
44  const typename std::enable_if<
46 {
47  return new TreeType(std::forward<MatType>(dataset));
48 }
49 
50 template<typename MetricType,
51  typename MatType,
52  template<typename TreeMetricType,
53  typename TreeStatType,
54  typename TreeMatType> class TreeType>
55 DualTreeKMeans<MetricType, MatType, TreeType>::DualTreeKMeans(
56  const MatType& dataset,
57  MetricType& metric) :
58  datasetOrig(dataset),
59  tree(new Tree(const_cast<MatType&>(dataset))),
60  dataset(tree->Dataset()),
61  metric(metric),
62  distanceCalculations(0),
63  iteration(0),
64  upperBounds(dataset.n_cols),
65  lowerBounds(dataset.n_cols),
66  prunedPoints(dataset.n_cols, false), // Fill with false.
67  assignments(dataset.n_cols),
68  visited(dataset.n_cols, false) // Fill with false.
69 {
70  for (size_t i = 0; i < dataset.n_cols; ++i)
71  {
72  prunedPoints[i] = false;
73  visited[i] = false;
74  }
75  assignments.fill(size_t(-1));
76  upperBounds.fill(DBL_MAX);
77  lowerBounds.fill(DBL_MAX);
78 }
79 
80 template<typename MetricType,
81  typename MatType,
82  template<typename TreeMetricType,
83  typename TreeStatType,
84  typename TreeMatType> class TreeType>
85 DualTreeKMeans<MetricType, MatType, TreeType>::~DualTreeKMeans()
86 {
87  if (tree)
88  delete tree;
89 }
90 
91 // Run a single iteration.
92 template<typename MetricType,
93  typename MatType,
94  template<typename TreeMetricType,
95  typename TreeStatType,
96  typename TreeMatType> class TreeType>
97 double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
98  const arma::mat& centroids,
99  arma::mat& newCentroids,
100  arma::Col<size_t>& counts)
101 {
102  // Build a tree on the centroids. This will make a copy if necessary, which
103  // is unfortunate, but I don't see a reasonable way around it.
104  std::vector<size_t> oldFromNewCentroids;
105  Tree* centroidTree = BuildTree<Tree>(centroids, oldFromNewCentroids);
106 
107  // Find the nearest neighbors of each of the clusters. We have to make our
108  // own TreeType, which is a little bit abuse, but we know for sure the
109  // TreeStatType we have will work.
111  NNSTreeType> nns(std::move(*centroidTree));
112 
113  // Reset information in the tree, if we need to.
114  if (iteration > 0)
115  {
116  Timer::Start("knn");
117 
118  // If the tree maps points, we need an intermediate result matrix.
119  arma::mat* interclusterDistancesTemp =
121  new arma::mat(1, centroids.n_elem) : &interclusterDistances;
122 
123  arma::Mat<size_t> closestClusters; // We don't actually care about these.
124  nns.Search(1, closestClusters, *interclusterDistancesTemp);
125  distanceCalculations += nns.BaseCases() + nns.Scores();
126 
127  // We need to do the unmapping ourselves, if the tree does mapping.
129  {
130  for (size_t i = 0; i < interclusterDistances.n_elem; ++i)
131  interclusterDistances[oldFromNewCentroids[i]] =
132  (*interclusterDistancesTemp)[i];
133 
134  delete interclusterDistancesTemp;
135  }
136 
137  Timer::Stop("knn");
138 
139  UpdateTree(*tree, centroids);
140 
141  for (size_t i = 0; i < dataset.n_cols; ++i)
142  visited[i] = false;
143  }
144  else
145  {
146  // Not initialized yet.
147  clusterDistances.set_size(centroids.n_cols + 1);
148  interclusterDistances.set_size(1, centroids.n_cols);
149  }
150 
151  // We won't use the KNN class here because we have our own set of rules.
152  lastIterationCentroids = centroids;
153  typedef DualTreeKMeansRules<MetricType, Tree> RuleType;
154  RuleType rules(nns.ReferenceTree().Dataset(), dataset, assignments,
155  upperBounds, lowerBounds, metric, prunedPoints, oldFromNewCentroids,
156  visited);
157 
158  typename Tree::template BreadthFirstDualTreeTraverser<RuleType>
159  traverser(rules);
160 
161  Timer::Start("tree_mod");
162  CoalesceTree(*tree);
163  Timer::Stop("tree_mod");
164 
165  // Set the number of pruned centroids in the root to 0.
166  tree->Stat().Pruned() = 0;
167  traverser.Traverse(*tree, nns.ReferenceTree());
168  distanceCalculations += rules.BaseCases() + rules.Scores();
169 
170  Timer::Start("tree_mod");
171  DecoalesceTree(*tree);
172  Timer::Stop("tree_mod");
173 
174  // Now we need to extract the clusters.
175  newCentroids.zeros(centroids.n_rows, centroids.n_cols);
176  counts.zeros(centroids.n_cols);
177  ExtractCentroids(*tree, newCentroids, counts, centroids);
178 
179  // Now, calculate how far the clusters moved, after normalizing them.
180  double residual = 0.0;
181  clusterDistances[centroids.n_cols] = 0.0;
182  for (size_t c = 0; c < centroids.n_cols; ++c)
183  {
184  if (counts[c] == 0)
185  {
186  clusterDistances[c] = 0;
187  }
188  else
189  {
190  newCentroids.col(c) /= counts(c);
191  const double movement = metric.Evaluate(centroids.col(c),
192  newCentroids.col(c));
193  clusterDistances[c] = movement;
194  residual += std::pow(movement, 2.0);
195 
196  if (movement > clusterDistances[centroids.n_cols])
197  clusterDistances[centroids.n_cols] = movement;
198  }
199  }
200  distanceCalculations += centroids.n_cols;
201 
202  delete centroidTree;
203 
204  ++iteration;
205 
206  return std::sqrt(residual);
207 }
208 
209 template<typename MetricType,
210  typename MatType,
211  template<typename TreeMetricType,
212  typename TreeStatType,
213  typename TreeMatType> class TreeType>
214 void DualTreeKMeans<MetricType, MatType, TreeType>::UpdateTree(
215  Tree& node,
216  const arma::mat& centroids,
217  const double parentUpperBound,
218  const double adjustedParentUpperBound,
219  const double parentLowerBound,
220  const double adjustedParentLowerBound)
221 {
222  const bool prunedLastIteration = node.Stat().StaticPruned();
223  node.Stat().StaticPruned() = false;
224 
225  // Grab information from the parent, if we can.
226  if (node.Parent() != NULL &&
227  node.Parent()->Stat().Pruned() == centroids.n_cols &&
228  node.Parent()->Stat().Owner() < centroids.n_cols)
229  {
230  // When taking bounds from the parent, note that the parent has already
231  // adjusted the bounds according to the cluster movements, so we need to
232  // de-adjust them since we'll adjust them again. Maybe there is a smarter
233  // way to do this...
234  node.Stat().UpperBound() = parentUpperBound;
235  node.Stat().LowerBound() = parentLowerBound;
236  node.Stat().Pruned() = node.Parent()->Stat().Pruned();
237  node.Stat().Owner() = node.Parent()->Stat().Owner();
238  }
239  const double unadjustedUpperBound = node.Stat().UpperBound();
240  double adjustedUpperBound = adjustedParentUpperBound;
241  const double unadjustedLowerBound = node.Stat().LowerBound();
242  double adjustedLowerBound = adjustedParentLowerBound;
243 
244  // Exhaustive lower bound check. Sigh.
245 /*
246  if (!prunedLastIteration)
247  {
248  for (size_t i = 0; i < node.NumDescendants(); ++i)
249  {
250  double closest = DBL_MAX;
251  double secondClosest = DBL_MAX;
252  arma::vec distances(centroids.n_cols);
253  for (size_t j = 0; j < centroids.n_cols; ++j)
254  {
255  const double dist = metric.Evaluate(dataset.col(node.Descendant(i)),
256  lastIterationCentroids.col(j));
257  distances(j) = dist;
258 
259  if (dist < closest)
260  {
261  secondClosest = closest;
262  closest = dist;
263  }
264  else if (dist < secondClosest)
265  secondClosest = dist;
266  }
267  if (closest - 1e-10 > node.Stat().UpperBound())
268  {
269  Log::Warn << distances.t();
270  Log::Fatal << "Point " << node.Descendant(i) << " in " << node.Point(0) <<
271 "c" << node.NumDescendants() << " invalidates upper bound " <<
272 node.Stat().UpperBound() << " with closest cluster distance " << closest <<
273 ".\n";
274  }
275 
276  if (node.NumChildren() == 0)
277  {
278  if (secondClosest + 1e-10 < std::min(lowerBounds[node.Descendant(i)],
279  node.Stat().LowerBound()))
280  {
281  Log::Warn << distances.t();
282  Log::Warn << node;
283  Log::Fatal << "Point " << node.Descendant(i) << " in " << node.Point(0) <<
284 "c" << node.NumDescendants() << " invalidates lower bound " <<
285 std::min(lowerBounds[node.Descendant(i)], node.Stat().LowerBound()) << " (" <<
286 lowerBounds[node.Descendant(i)] << ", " << node.Stat().LowerBound() << ") with "
287  << "second closest cluster distance " << secondClosest << ". cd " <<
288 closest << "; pruned " << prunedPoints[node.Descendant(i)] << " visited " <<
289 visited[node.Descendant(i)] << ".\n";
290  }
291  }
292  }
293  }
294 */
295 
296  if ((node.Stat().Pruned() == centroids.n_cols) &&
297  (node.Stat().Owner() < centroids.n_cols))
298  {
299  // Adjust bounds.
300  node.Stat().UpperBound() += clusterDistances[node.Stat().Owner()];
301  node.Stat().LowerBound() -= clusterDistances[centroids.n_cols];
302 
303  if (adjustedParentUpperBound < node.Stat().UpperBound())
304  node.Stat().UpperBound() = adjustedParentUpperBound;
305 
306  if (adjustedParentLowerBound > node.Stat().LowerBound())
307  node.Stat().LowerBound() = adjustedParentLowerBound;
308 
309  // Try to use the inter-cluster distances to produce a better lower bound,
310  // if possible.
311  const double interclusterBound = interclusterDistances[node.Stat().Owner()]
312  / 2.0;
313  if (interclusterBound > node.Stat().LowerBound())
314  {
315  node.Stat().LowerBound() = interclusterBound;
316  adjustedLowerBound = node.Stat().LowerBound();
317  }
318 
319  if (node.Stat().UpperBound() < node.Stat().LowerBound())
320  {
321  node.Stat().StaticPruned() = true;
322  }
323  else
324  {
325  // Tighten bound.
326  node.Stat().UpperBound() =
327  std::min(node.Stat().UpperBound(),
328  node.MaxDistance(centroids.col(node.Stat().Owner())));
329  adjustedUpperBound = node.Stat().UpperBound();
330 
331  ++distanceCalculations;
332  if (node.Stat().UpperBound() < node.Stat().LowerBound())
333  node.Stat().StaticPruned() = true;
334  }
335  }
336  else
337  {
338  node.Stat().LowerBound() -= clusterDistances[centroids.n_cols];
339  }
340 
341  // Recurse into children, and if all the children (and all the points) are
342  // pruned, then we can mark this as statically pruned.
343  bool allChildrenPruned = true;
344  for (size_t i = 0; i < node.NumChildren(); ++i)
345  {
346  UpdateTree(node.Child(i), centroids, unadjustedUpperBound,
347  adjustedUpperBound, unadjustedLowerBound, adjustedLowerBound);
348  if (!node.Child(i).Stat().StaticPruned())
349  allChildrenPruned = false;
350  }
351 
352  bool allPointsPruned = true;
353  if (tree::TreeTraits<Tree>::HasSelfChildren && node.NumChildren() > 0)
354  {
355  // If this tree type has self-children, then we have already adjusted the
356  // point bounds at a lower level, and we can determine if all of our points
357  // are pruned simply by seeing if all of the children's points are pruned.
358  // This particular line below additionally assumes that each node's points
359  // are all contained in its first child. This is valid for the cover tree,
360  // but maybe not others.
361  allPointsPruned = node.Child(0).Stat().StaticPruned();
362  }
363  else if (!node.Stat().StaticPruned())
364  {
365  // Try to prune individual points.
366  for (size_t i = 0; i < node.NumPoints(); ++i)
367  {
368  const size_t index = node.Point(i);
369  if (!visited[index] && !prunedPoints[index])
370  {
371  upperBounds[index] = DBL_MAX; // Reset the bounds.
372  lowerBounds[index] = DBL_MAX;
373  allPointsPruned = false;
374  continue; // We didn't visit it and we don't have valid bounds -- so we
375  // can't prune it.
376  }
377 
378  if (prunedLastIteration)
379  {
380  // It was pruned last iteration but not this iteration.
381  // Set the bounds correctly.
382  upperBounds[index] += node.Stat().StaticUpperBoundMovement();
383  lowerBounds[index] -= node.Stat().StaticLowerBoundMovement();
384  }
385 
386  prunedPoints[index] = false;
387  const size_t owner = assignments[index];
388  const double lowerBound = std::min(lowerBounds[index] -
389  clusterDistances[centroids.n_cols], node.Stat().LowerBound());
390  const double pruningLowerBound = std::max(lowerBound,
391  interclusterDistances[owner] / 2.0);
392  if (upperBounds[index] + clusterDistances[owner] < pruningLowerBound)
393  {
394  prunedPoints[index] = true;
395  upperBounds[index] += clusterDistances[owner];
396  lowerBounds[index] = pruningLowerBound;
397  }
398  else
399  {
400  // Attempt to tighten the bound.
401  upperBounds[index] = metric.Evaluate(dataset.col(index),
402  centroids.col(owner));
403  ++distanceCalculations;
404  if (upperBounds[index] < pruningLowerBound)
405  {
406  prunedPoints[index] = true;
407  lowerBounds[index] = pruningLowerBound;
408  }
409  else
410  {
411  // Point cannot be pruned. We may have to inspect the point at a
412  // lower level, though. If that's the case, then we shouldn't
413  // invalidate the bounds we've got -- it will happen at the lower
414  // level.
416  node.NumChildren() == 0)
417  {
418  upperBounds[index] = DBL_MAX;
419  lowerBounds[index] = DBL_MAX;
420  }
421  allPointsPruned = false;
422  }
423  }
424  }
425  }
426 
427 /*
428  if (node.Stat().StaticPruned() && !allChildrenPruned)
429  {
430  Log::Warn << node;
431  for (size_t i = 0; i < node.NumChildren(); ++i)
432  Log::Warn << "child " << i << ":\n" << node.Child(i);
433  Log::Fatal << "Node is statically pruned but not all its children are!\n";
434  }
435 */
436 
437  // If all of the children and points are pruned, we may mark this node as
438  // pruned.
439  if (allChildrenPruned && allPointsPruned && !node.Stat().StaticPruned())
440  {
441  node.Stat().StaticPruned() = true;
442  node.Stat().Owner() = centroids.n_cols; // Invalid owner.
443  node.Stat().Pruned() = size_t(-1);
444  }
445 
446  if (!node.Stat().StaticPruned())
447  {
448  node.Stat().UpperBound() = DBL_MAX;
449  node.Stat().LowerBound() = DBL_MAX;
450  node.Stat().Pruned() = size_t(-1);
451  node.Stat().Owner() = centroids.n_cols;
452  node.Stat().StaticPruned() = false;
453  }
454  else // The node is now pruned.
455  {
456  if (prunedLastIteration)
457  {
458  // Track total movement while pruned.
459  node.Stat().StaticUpperBoundMovement() +=
460  clusterDistances[node.Stat().Owner()];
461  node.Stat().StaticLowerBoundMovement() +=
462  clusterDistances[centroids.n_cols];
463  }
464  else
465  {
466  node.Stat().StaticUpperBoundMovement() =
467  clusterDistances[node.Stat().Owner()];
468  node.Stat().StaticLowerBoundMovement() =
469  clusterDistances[centroids.n_cols];
470  }
471  }
472 }
473 
474 template<typename MetricType,
475  typename MatType,
476  template<typename TreeMetricType,
477  typename TreeStatType,
478  typename TreeMatType> class TreeType>
479 void DualTreeKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
480  Tree& node,
481  arma::mat& newCentroids,
482  arma::Col<size_t>& newCounts,
483  const arma::mat& centroids)
484 {
485  // Does this node own points?
486  if ((node.Stat().Pruned() == newCentroids.n_cols) ||
487  (node.Stat().StaticPruned() && node.Stat().Owner() < newCentroids.n_cols))
488  {
489  const size_t owner = node.Stat().Owner();
490  newCentroids.col(owner) += node.Stat().Centroid() * node.NumDescendants();
491  newCounts[owner] += node.NumDescendants();
492 
493  // Perform the sanity check here.
494 /*
495  for (size_t i = 0; i < node.NumDescendants(); ++i)
496  {
497  const size_t index = node.Descendant(i);
498  arma::vec trueDistances(centroids.n_cols);
499  for (size_t j = 0; j < centroids.n_cols; ++j)
500  {
501  const double dist = metric.Evaluate(dataset.col(index),
502  centroids.col(j));
503  trueDistances[j] = dist;
504  }
505 
506  arma::uword minIndex;
507  const double minDist = trueDistances.min(minIndex);
508  if (size_t(minIndex) != owner)
509  {
510  Log::Warn << node;
511  Log::Warn << trueDistances.t();
512  Log::Fatal << "Point " << index << " of node " << node.Point(0) << "c"
513 << node.NumDescendants() << " has true minimum cluster " << minIndex << " with "
514  << "distance " << minDist << " but node is pruned with upper bound " <<
515 node.Stat().UpperBound() << " and owner " << node.Stat().Owner() << ".\n";
516  }
517  }
518 */
519  }
520  else
521  {
522  // Check each point held in the node.
523  // Only check at leaves.
524  if (node.NumChildren() == 0)
525  {
526  for (size_t i = 0; i < node.NumPoints(); ++i)
527  {
528  const size_t owner = assignments[node.Point(i)];
529  newCentroids.col(owner) += dataset.col(node.Point(i));
530  ++newCounts[owner];
531 
532 /*
533  const size_t index = node.Point(i);
534  arma::vec trueDistances(centroids.n_cols);
535  for (size_t j = 0; j < centroids.n_cols; ++j)
536  {
537  const double dist = metric.Evaluate(dataset.col(index),
538  centroids.col(j));
539  trueDistances[j] = dist;
540  }
541 
542  arma::uword minIndex;
543  const double minDist = trueDistances.min(minIndex);
544  if (size_t(minIndex) != owner)
545  {
546  Log::Warn << node;
547  Log::Warn << trueDistances.t();
548  Log::Fatal << "Point " << index << " of node " << node.Point(0) << "c"
549  << node.NumDescendants() << " has true minimum cluster " << minIndex << " with "
550  << "distance " << minDist << " but was assigned to cluster " <<
551 assignments[node.Point(i)] << " with ub " << upperBounds[node.Point(i)] <<
552 " and lb " << lowerBounds[node.Point(i)] << "; pp " <<
553 (prunedPoints[node.Point(i)] ? "true" : "false") << ", visited " <<
554 (visited[node.Point(i)] ? "true"
555 : "false") << ".\n";
556  }
557 */
558  }
559  }
560 
561  // The node is not entirely owned by a cluster. Recurse.
562  for (size_t i = 0; i < node.NumChildren(); ++i)
563  ExtractCentroids(node.Child(i), newCentroids, newCounts, centroids);
564  }
565 }
566 
567 template<typename MetricType,
568  typename MatType,
569  template<typename TreeMetricType,
570  typename TreeStatType,
571  typename TreeMatType> class TreeType>
572 void DualTreeKMeans<MetricType, MatType, TreeType>::CoalesceTree(
573  Tree& node,
574  const size_t child /* Which child are we? */)
575 {
576  // If all children except one are pruned, we can hide this node.
577  if (node.NumChildren() == 0)
578  return; // We can't do anything.
579 
580  // If this is the root node, we can't coalesce.
581  if (node.Parent() != NULL)
582  {
583  // First, we should coalesce those nodes that aren't statically pruned.
584  for (size_t i = node.NumChildren() - 1; i > 0; --i)
585  {
586  if (node.Child(i).Stat().StaticPruned())
587  HideChild(node, i);
588  else
589  CoalesceTree(node.Child(i), i);
590  }
591 
592  if (node.Child(0).Stat().StaticPruned())
593  HideChild(node, 0);
594  else
595  CoalesceTree(node.Child(0), 0);
596 
597  // If we've pruned all but one child, then notPrunedIndex will contain the
598  // index of that child, and we can coalesce this node entirely. Note that
599  // the case where all children are statically pruned should not happen,
600  // because then this node should itself be statically pruned.
601  if (node.NumChildren() == 1)
602  {
603  node.Child(0).Parent() = node.Parent();
604  node.Parent()->ChildPtr(child) = node.ChildPtr(0);
605  }
606  }
607  else
608  {
609  // We can't coalesce the root, so call the children individually and
610  // coalesce them.
611  for (size_t i = 0; i < node.NumChildren(); ++i)
612  CoalesceTree(node.Child(i), i);
613  }
614 }
615 
616 template<typename MetricType,
617  typename MatType,
618  template<typename TreeMetricType,
619  typename TreeStatType,
620  typename TreeMatType> class TreeType>
621 void DualTreeKMeans<MetricType, MatType, TreeType>::DecoalesceTree(Tree& node)
622 {
623  node.Parent() = (Tree*) node.Stat().TrueParent();
624  RestoreChildren(node);
625 
626  for (size_t i = 0; i < node.NumChildren(); ++i)
627  DecoalesceTree(node.Child(i));
628 }
629 
631 template<typename TreeType>
632 void HideChild(TreeType& node,
633  const size_t child,
634  const typename std::enable_if_t<
636 {
637  // We're going to assume we have a Children() function open to us. If we
638  // don't, then this won't work, I guess...
639  node.Children().erase(node.Children().begin() + child);
640 }
641 
643 template<typename TreeType>
644 void HideChild(TreeType& node,
645  const size_t child,
646  const typename std::enable_if_t<
648 {
649  // If we're hiding the left child, then take the right child as the new left
650  // child.
651  if (child == 0)
652  {
653  node.ChildPtr(0) = node.ChildPtr(1);
654  node.ChildPtr(1) = NULL;
655  }
656  else
657  {
658  node.ChildPtr(1) = NULL;
659  }
660 }
661 
663 template<typename TreeType>
664 void RestoreChildren(TreeType& node,
665  const typename std::enable_if_t<
667 {
668  node.Children().clear();
669  node.Children().resize(node.Stat().NumTrueChildren());
670  for (size_t i = 0; i < node.Stat().NumTrueChildren(); ++i)
671  node.Children()[i] = (TreeType*) node.Stat().TrueChild(i);
672 }
673 
675 template<typename TreeType>
676 void RestoreChildren(TreeType& node,
677  const typename std::enable_if_t<
679 {
680  if (node.Stat().NumTrueChildren() > 0)
681  {
682  node.ChildPtr(0) = (TreeType*) node.Stat().TrueChild(0);
683  node.ChildPtr(1) = (TreeType*) node.Stat().TrueChild(1);
684  }
685 }
686 
687 } // namespace kmeans
688 } // namespace mlpack
689 
690 #endif
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
TreeType< MetricType, DualTreeKMeansStatistic, MatType > Tree
Convenience typedef.
Definition: dual_tree_kmeans.hpp:45
void HideChild(TreeType &node, const size_t child, const typename std::enable_if_t< !tree::TreeTraits< TreeType >::BinaryTree > *junk=0)
Utility function for hiding children.
Definition: dual_tree_kmeans_impl.hpp:632
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The NeighborSearch class is a template class for performing distance-based neighbor searches...
Definition: neighbor_search.hpp:88
void RestoreChildren(TreeType &node, const typename std::enable_if_t<!tree::TreeTraits< TreeType >::BinaryTree > *junk=0)
Utility function for restoring children to a non-binary tree.
An algorithm for an exact Lloyd iteration which simply uses dual-tree nearest-neighbor search to find...
Definition: dual_tree_kmeans.hpp:41
This class implements the necessary methods for the SortPolicy template parameter of the NeighborSear...
Definition: nearest_neighbor_sort.hpp:31
Definition: dual_tree_kmeans_rules.hpp:23
TreeType * BuildTree(MatType &&dataset, std::vector< size_t > &oldFromNew, const typename std::enable_if< tree::TreeTraits< TreeType >::RearrangesDataset >::type *=0)
Call the tree constructor that does mapping.
Definition: dual_tree_kmeans_impl.hpp:28
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
The TreeTraits class provides compile-time information on the characteristics of a given tree type...
Definition: tree_traits.hpp:77