mlpack
midpoint_split_impl.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_MIDPOINT_SPLIT_IMPL_HPP
15 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_MIDPOINT_SPLIT_IMPL_HPP
16 
#include "midpoint_split.hpp"

#include <vector>
19 
20 namespace mlpack {
21 namespace tree {
22 
23 template<typename BoundType, typename MatType>
25  MatType& data,
26  const size_t begin,
27  const size_t count,
28  SplitInfo& splitInfo)
29 {
30  double maxWidth = -1;
31  splitInfo.splitDimension = data.n_rows; // Indicate invalid.
32 
33  // Find the split dimension. If the bound is tight, we only need to consult
34  // the bound's width.
36  {
37  for (size_t d = 0; d < data.n_rows; d++)
38  {
39  const double width = bound[d].Width();
40 
41  if (width > maxWidth)
42  {
43  maxWidth = width;
44  splitInfo.splitDimension = d;
45 
46  // Split in the midpoint of that dimension.
47  splitInfo.splitVal = bound[d].Mid();
48  }
49  }
50  }
51  else
52  {
53  // We must individually calculate bounding boxes.
54  math::Range* ranges = new math::Range[data.n_rows];
55  for (size_t i = begin; i < begin + count; ++i)
56  {
57  // Expand each dimension as necessary.
58  for (size_t d = 0; d < data.n_rows; ++d)
59  {
60  const double val = data(d, i);
61  if (val < ranges[d].Lo())
62  ranges[d].Lo() = val;
63  if (val > ranges[d].Hi())
64  ranges[d].Hi() = val;
65  }
66  }
67 
68  // Now, which is the widest?
69  for (size_t d = 0; d < data.n_rows; d++)
70  {
71  const double width = ranges[d].Width();
72  if (width > maxWidth)
73  {
74  maxWidth = width;
75  splitInfo.splitDimension = d;
76  // Split in the midpoint of that dimension.
77  splitInfo.splitVal = ranges[d].Mid();
78  }
79  }
80 
81  delete[] ranges;
82  }
83 
84  if (maxWidth <= 0) // All these points are the same. We can't split.
85  return false;
86 
87  // Split in the midpoint of that dimension.
88  splitInfo.splitVal = bound[splitInfo.splitDimension].Mid();
89 
90  return true;
91 }
92 
93 } // namespace tree
94 } // namespace mlpack
95 
96 #endif
T Lo() const
Get the lower bound.
Definition: range.hpp:61
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
A class to obtain compile-time traits about BoundType classes.
Definition: bound_traits.hpp:26
Bounds that are useful for binary space partitioning trees.
double splitVal
The split in dimension splitDimension is based on this value.
Definition: midpoint_split.hpp:39
T Hi() const
Get the upper bound.
Definition: range.hpp:66
A struct that contains information about the split.
Definition: midpoint_split.hpp:34
static bool SplitNode(const BoundType &bound, MatType &data, const size_t begin, const size_t count, SplitInfo &splitInfo)
Find the partition of the node.
Definition: midpoint_split_impl.hpp:24
size_t splitDimension
The dimension to split the node on.
Definition: midpoint_split.hpp:37
T Width() const
Gets the span of the range (hi - lo).
Definition: range_impl.hpp:47
T Mid() const
Gets the midpoint of this range.
Definition: range_impl.hpp:59