mlpack
minimal_coverage_sweep_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_IMPL_HPP
14 #define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_IMPL_HPP
15 
17 
18 namespace mlpack {
19 namespace tree {
20 
21 template<typename SplitPolicy>
22 template<typename TreeType>
23 typename TreeType::ElemType MinimalCoverageSweep<SplitPolicy>::
24 SweepNonLeafNode(const size_t axis,
25  const TreeType* node,
26  typename TreeType::ElemType& axisCut)
27 {
28  typedef typename TreeType::ElemType ElemType;
30 
31  std::vector<std::pair<ElemType, size_t>> sorted(node->NumChildren());
32 
33  for (size_t i = 0; i < node->NumChildren(); ++i)
34  {
35  sorted[i].first = SplitPolicy::Bound(node->Child(i))[axis].Hi();
36  sorted[i].second = i;
37  }
38  // Sort high bounds of children.
39  std::sort(sorted.begin(), sorted.end(),
40  [] (const std::pair<ElemType, size_t>& s1,
41  const std::pair<ElemType, size_t>& s2)
42  {
43  return s1.first < s2.first;
44  });
45 
46  size_t splitPointer = node->NumChildren() / 2;
47 
48  axisCut = sorted[splitPointer - 1].first;
49 
50  // Check if the midpoint split is suitable.
51  if (!CheckNonLeafSweep(node, axis, axisCut))
52  {
53  // Find any suitable partition if the default partition is not acceptable.
54  for (splitPointer = 1; splitPointer < sorted.size(); splitPointer++)
55  {
56  axisCut = sorted[splitPointer - 1].first;
57  if (CheckNonLeafSweep(node, axis, axisCut))
58  break;
59  }
60 
61  if (splitPointer == node->NumChildren())
62  return std::numeric_limits<ElemType>::max();
63  }
64 
65  BoundType bound1(node->Bound().Dim());
66  BoundType bound2(node->Bound().Dim());
67 
68  // Find bounds of two resulting nodes.
69  for (size_t i = 0; i < splitPointer; ++i)
70  bound1 |= node->Child(sorted[i].second).Bound();
71 
72  for (size_t i = splitPointer; i < node->NumChildren(); ++i)
73  bound2 |= node->Child(sorted[i].second).Bound();
74 
75 
76  // Evaluate the cost of the split i.e. calculate the total coverage
77  // of two resulting nodes.
78 
79  ElemType area1 = bound1.Volume();
80  ElemType area2 = bound2.Volume();
81 
82  return area1 + area2;
83 }
84 
85 template<typename SplitPolicy>
86 template<typename TreeType>
87 typename TreeType::ElemType MinimalCoverageSweep<SplitPolicy>::
88 SweepLeafNode(const size_t axis,
89  const TreeType* node,
90  typename TreeType::ElemType& axisCut)
91 {
92  typedef typename TreeType::ElemType ElemType;
94 
95  std::vector<std::pair<ElemType, size_t>> sorted(node->Count());
96 
97  sorted.resize(node->Count());
98 
99  for (size_t i = 0; i < node->NumPoints(); ++i)
100  {
101  sorted[i].first = node->Dataset().col(node->Point(i))[axis];
102  sorted[i].second = i;
103  }
104 
105  // Sort high bounds of children.
106  std::sort(sorted.begin(), sorted.end(),
107  [] (const std::pair<ElemType, size_t>& s1,
108  const std::pair<ElemType, size_t>& s2)
109  {
110  return s1.first < s2.first;
111  });
112 
113  size_t splitPointer = node->Count() / 2;
114 
115  axisCut = sorted[splitPointer - 1].first;
116 
117  // Check if the partition is suitable.
118  if (!CheckLeafSweep(node, axis, axisCut))
119  return std::numeric_limits<ElemType>::max();
120 
121  BoundType bound1(node->Bound().Dim());
122  BoundType bound2(node->Bound().Dim());
123 
124  // Find bounds of two resulting nodes.
125  for (size_t i = 0; i < splitPointer; ++i)
126  bound1 |= node->Dataset().col(node->Point(sorted[i].second));
127 
128  for (size_t i = splitPointer; i < node->NumChildren(); ++i)
129  bound2 |= node->Dataset().col(node->Point(sorted[i].second));
130 
131  // Evaluate the cost of the split i.e. calculate the total coverage
132  // of two resulting nodes.
133 
134  return bound1.Volume() + bound2.Volume();
135 }
136 
137 template<typename SplitPolicy>
138 template<typename TreeType, typename ElemType>
140 CheckNonLeafSweep(const TreeType* node,
141  const size_t cutAxis,
142  const ElemType cut)
143 {
144  size_t numTreeOneChildren = 0;
145  size_t numTreeTwoChildren = 0;
146 
147  // Calculate the number of children in the resulting nodes.
148  for (size_t i = 0; i < node->NumChildren(); ++i)
149  {
150  const TreeType& child = node->Child(i);
151  int policy = SplitPolicy::GetSplitPolicy(child, cutAxis, cut);
152  if (policy == SplitPolicy::AssignToFirstTree)
153  numTreeOneChildren++;
154  else if (policy == SplitPolicy::AssignToSecondTree)
155  numTreeTwoChildren++;
156  else
157  {
158  // The split is required.
159  numTreeOneChildren++;
160  numTreeTwoChildren++;
161  }
162  }
163 
164  if (numTreeOneChildren <= node->MaxNumChildren() && numTreeOneChildren > 0 &&
165  numTreeTwoChildren <= node->MaxNumChildren() && numTreeTwoChildren > 0)
166  return true;
167  return false;
168 }
169 
170 template<typename SplitPolicy>
171 template<typename TreeType, typename ElemType>
173 CheckLeafSweep(const TreeType* node,
174  const size_t cutAxis,
175  const ElemType cut)
176 {
177  size_t numTreeOnePoints = 0;
178  size_t numTreeTwoPoints = 0;
179 
180  // Calculate the number of points in the resulting nodes.
181  for (size_t i = 0; i < node->NumPoints(); ++i)
182  {
183  if (node->Dataset().col(node->Point(i))[cutAxis] <= cut)
184  numTreeOnePoints++;
185  else
186  numTreeTwoPoints++;
187  }
188 
189  if (numTreeOnePoints <= node->MaxLeafSize() && numTreeOnePoints > 0 &&
190  numTreeTwoPoints <= node->MaxLeafSize() && numTreeTwoPoints > 0)
191  return true;
192  return false;
193 }
194 
195 } // namespace tree
196 } // namespace mlpack
197 
198 #endif // MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_COVERAGE_SWEEP_IMPL_HPP
199 
static TreeType::ElemType SweepLeafNode(const size_t axis, const TreeType *node, typename TreeType::ElemType &axisCut)
Find a suitable partition of a leaf node along the provided axis.
Definition: minimal_coverage_sweep_impl.hpp:88
static bool CheckNonLeafSweep(const TreeType *node, const size_t cutAxis, const ElemType cut)
Check if an intermediate node can be split along the axis at the provided coordinate.
Definition: minimal_coverage_sweep_impl.hpp:140
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static TreeType::ElemType SweepNonLeafNode(const size_t axis, const TreeType *node, typename TreeType::ElemType &axisCut)
Find a suitable partition of a non-leaf node along the provided axis.
Definition: minimal_coverage_sweep_impl.hpp:24
static bool CheckLeafSweep(const TreeType *node, const size_t cutAxis, const ElemType cut)
Check if a leaf node can be split along the axis at the provided coordinate.
Definition: minimal_coverage_sweep_impl.hpp:173