mlpack
minimal_splits_number_sweep_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_IMPL_HPP
14 #define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_IMPL_HPP
15 
17 
18 namespace mlpack {
19 namespace tree {
20 
21 template<typename SplitPolicy>
22 template<typename TreeType>
24  const size_t axis,
25  const TreeType* node,
26  typename TreeType::ElemType& axisCut)
27 {
28  typedef typename TreeType::ElemType ElemType;
29 
30  std::vector<std::pair<ElemType, size_t>> sorted(node->NumChildren());
31 
32  for (size_t i = 0; i < node->NumChildren(); ++i)
33  {
34  sorted[i].first = SplitPolicy::Bound(node->Child(i))[axis].Hi();
35  sorted[i].second = i;
36  }
37 
38  // Sort candidates in order to check balancing.
39  std::sort(sorted.begin(), sorted.end(),
40  [] (const std::pair<ElemType, size_t>& s1,
41  const std::pair<ElemType, size_t>& s2)
42  {
43  return s1.first < s2.first;
44  });
45 
46  size_t minCost = SIZE_MAX;
47 
48  // Find a split with the minimal cost.
49  for (size_t i = 0; i < sorted.size(); ++i)
50  {
51  size_t numTreeOneChildren = 0;
52  size_t numTreeTwoChildren = 0;
53  size_t numSplits = 0;
54 
55  // Calculate the number of splits.
56  for (size_t j = 0; j < node->NumChildren(); ++j)
57  {
58  const TreeType& child = node->Child(j);
59  int policy = SplitPolicy::GetSplitPolicy(child, axis, sorted[i].first);
60  if (policy == SplitPolicy::AssignToFirstTree)
61  numTreeOneChildren++;
62  else if (policy == SplitPolicy::AssignToSecondTree)
63  numTreeTwoChildren++;
64  else
65  {
66  numTreeOneChildren++;
67  numTreeTwoChildren++;
68  numSplits++;
69  }
70  }
71 
72  // Check if the split is possible.
73  if (numTreeOneChildren <= node->MaxNumChildren() &&
74  numTreeOneChildren > 0 && numTreeTwoChildren <= node->MaxNumChildren()
75  && numTreeTwoChildren > 0)
76  {
77  // Evaluate the cost using the number of splits and balancing.
78  size_t balance;
79 
80  if (sorted.size() / 2 > i )
81  balance = sorted.size() / 2 - i;
82  else
83  balance = i - sorted.size() / 2;
84 
85  size_t cost = numSplits * balance;
86  if (cost < minCost)
87  {
88  minCost = cost;
89  axisCut = sorted[i].first;
90  }
91  }
92  }
93  return minCost;
94 }
95 
96 template<typename SplitPolicy>
97 template<typename TreeType>
99  const size_t axis,
100  const TreeType* node,
101  typename TreeType::ElemType& axisCut)
102 {
103  // Split along the median.
104  axisCut = (node->Bound()[axis].Lo() + node->Bound()[axis].Hi()) * 0.5;
105 
106  if (node->Bound()[axis].Lo() == axisCut)
107  return SIZE_MAX;
108 
109  return 0;
110 }
111 
112 
113 } // namespace tree
114 } // namespace mlpack
115 
116 #endif // MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_IMPL_HPP
117 
118 
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static size_t SweepLeafNode(const size_t axis, const TreeType *node, typename TreeType::ElemType &axisCut)
Find a suitable partition of a leaf node along the provided axis.
Definition: minimal_splits_number_sweep_impl.hpp:98
static size_t SweepNonLeafNode(const size_t axis, const TreeType *node, typename TreeType::ElemType &axisCut)
Find a suitable partition of a non-leaf node along the provided axis.
Definition: minimal_splits_number_sweep_impl.hpp:23