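mlpack — source listing of space_split_impl.hpp
(implementation of SpaceSplit::GetProjVector for spill trees; see the documentation page for this file for the API reference).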
1 
13 #ifndef MLPACK_CORE_TREE_SPILL_TREE_SPACE_SPLIT_IMPL_HPP
14 #define MLPACK_CORE_TREE_SPILL_TREE_SPACE_SPLIT_IMPL_HPP
15 
16 #include "space_split.hpp"
17 
18 namespace mlpack {
19 namespace tree {
20 
21 template<typename MetricType, typename MatType>
23  const bound::HRectBound<MetricType>& bound,
24  const MatType& data,
25  const arma::Col<size_t>& /* points */,
26  AxisParallelProjVector& projVector,
27  double& midValue)
28 {
29  // Get the dimension that has the maximum width.
30  size_t splitDim = data.n_rows; // Indicate invalid.
31  double maxWidth = -1;
32 
33  for (size_t d = 0; d < data.n_rows; d++)
34  {
35  const double width = bound[d].Width();
36 
37  if (width > maxWidth)
38  {
39  maxWidth = width;
40  splitDim = d;
41  }
42  }
43 
44  if (maxWidth <= 0) // All these points are the same.
45  return false;
46 
47  projVector = AxisParallelProjVector(splitDim);
48 
49  midValue = bound[splitDim].Mid();
50 
51  return true;
52 }
53 
54 template<typename MetricType, typename MatType>
55 template<typename BoundType>
57  const BoundType& /* bound */,
58  const MatType& data,
59  const arma::Col<size_t>& points,
60  ProjVector& projVector,
61  double& midValue)
62 {
63  MetricType metric;
64 
65  // Efficiently estimate the farthest pair of points in the given set.
66  size_t fst = points[rand() % points.n_elem];
67  size_t snd = points[0];
68  double max = metric.Evaluate(data.col(fst), data.col(snd));
69 
70  for (size_t i = 1; i < points.n_elem; ++i)
71  {
72  double dist = metric.Evaluate(data.col(fst), data.col(points[i]));
73  if (dist > max)
74  {
75  max = dist;
76  snd = points[i];
77  }
78  }
79 
80  std::swap(fst, snd);
81 
82  for (size_t i = 0; i < points.n_elem; ++i)
83  {
84  double dist = metric.Evaluate(data.col(fst), data.col(points[i]));
85  if (dist > max)
86  {
87  max = dist;
88  snd = points[i];
89  }
90  }
91 
92  if (max == 0) // All these points are the same.
93  return false;
94 
95  // Calculate the normalized projection vector.
96  projVector = ProjVector(data.col(snd) - data.col(fst));
97 
98  arma::vec midPoint = (data.col(snd) + data.col(fst)) / 2;
99 
100  midValue = projVector.Project(midPoint);
101 
102  return true;
103 }
104 
105 } // namespace tree
106 } // namespace mlpack
107 
108 #endif
double Project(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Project the given point onto the projection vector.
Definition: projection_vector.hpp:119
AxisParallelProjVector defines an axis-parallel projection vector.
Definition: projection_vector.hpp:24
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static bool GetProjVector(const bound::HRectBound< MetricType > &bound, const MatType &data, const arma::Col< size_t > &points, AxisParallelProjVector &projVector, double &midValue)
Create a projection vector based on the given set of points.
Definition: space_split_impl.hpp:22
ProjVector defines a general projection vector (not necessarily axis-parallel).
Definition: projection_vector.hpp:91