mlpack
mean_split_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_IMPL_HPP
14 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_IMPL_HPP
15 
16 #include "mean_split.hpp"
17 
18 namespace mlpack {
19 namespace tree {
20 
21 template<typename BoundType, typename MatType>
22 bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
23  MatType& data,
24  const size_t begin,
25  const size_t count,
26  SplitInfo& splitInfo)
27 {
28  double maxWidth = -1;
29 
30  splitInfo.splitDimension = data.n_rows; // Indicate invalid.
31 
32  // Find the split dimension. If the bound is tight, we only need to consult
33  // the bound's width.
35  {
36  for (size_t d = 0; d < data.n_rows; d++)
37  {
38  const double width = bound[d].Width();
39 
40  if (width > maxWidth)
41  {
42  maxWidth = width;
43  splitInfo.splitDimension = d;
44  }
45  }
46  }
47  else
48  {
49  // We must individually calculate bounding boxes.
50  math::Range* ranges = new math::Range[data.n_rows];
51  for (size_t i = begin; i < begin + count; ++i)
52  {
53  // Expand each dimension as necessary.
54  for (size_t d = 0; d < data.n_rows; ++d)
55  {
56  const double val = data(d, i);
57  if (val < ranges[d].Lo())
58  ranges[d].Lo() = val;
59  if (val > ranges[d].Hi())
60  ranges[d].Hi() = val;
61  }
62  }
63 
64  // Now, which is the widest?
65  for (size_t d = 0; d < data.n_rows; d++)
66  {
67  const double width = ranges[d].Width();
68  if (width > maxWidth)
69  {
70  maxWidth = width;
71  splitInfo.splitDimension = d;
72  }
73  }
74 
75  delete[] ranges;
76  }
77 
78  if (maxWidth == 0) // All these points are the same. We can't split.
79  return false;
80 
81  // Split in the mean of that dimension.
82  splitInfo.splitVal = 0.0;
83  for (size_t i = begin; i < begin + count; ++i)
84  splitInfo.splitVal += data(splitInfo.splitDimension, i);
85  splitInfo.splitVal /= count;
86 
87  Log::Assert(splitInfo.splitVal >= bound[splitInfo.splitDimension].Lo());
88  Log::Assert(splitInfo.splitVal <= bound[splitInfo.splitDimension].Hi());
89 
90  return true;
91 }
92 
93 } // namespace tree
94 } // namespace mlpack
95 
96 #endif
T Lo() const
Get the lower bound.
Definition: range.hpp:61
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
A class to obtain compile-time traits about BoundType classes.
Definition: bound_traits.hpp:26
static bool SplitNode(const BoundType &bound, MatType &data, const size_t begin, const size_t count, SplitInfo &splitInfo)
Find the partition of the node.
Definition: mean_split_impl.hpp:22
Information about the partition.
Definition: mean_split.hpp:33
size_t splitDimension
The dimension to split the node on.
Definition: mean_split.hpp:36
T Hi() const
Get the upper bound.
Definition: range.hpp:66
double splitVal
The split in dimension splitDimension is based on this value.
Definition: mean_split.hpp:38
T Width() const
Gets the span of the range (hi - lo).
Definition: range_impl.hpp:47
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38