13 #ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_IMPL_HPP 14 #define MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_IMPL_HPP 21 template<
typename SplitPolicy>
22 template<
typename TreeType>
26 typename TreeType::ElemType& axisCut)
28 typedef typename TreeType::ElemType ElemType;
30 std::vector<std::pair<ElemType, size_t>> sorted(node->NumChildren());
32 for (
size_t i = 0; i < node->NumChildren(); ++i)
34 sorted[i].first = SplitPolicy::Bound(node->Child(i))[axis].Hi();
39 std::sort(sorted.begin(), sorted.end(),
40 [] (
const std::pair<ElemType, size_t>& s1,
41 const std::pair<ElemType, size_t>& s2)
43 return s1.first < s2.first;
46 size_t minCost = SIZE_MAX;
49 for (
size_t i = 0; i < sorted.size(); ++i)
51 size_t numTreeOneChildren = 0;
52 size_t numTreeTwoChildren = 0;
56 for (
size_t j = 0; j < node->NumChildren(); ++j)
58 const TreeType& child = node->Child(j);
59 int policy = SplitPolicy::GetSplitPolicy(child, axis, sorted[i].first);
60 if (policy == SplitPolicy::AssignToFirstTree)
62 else if (policy == SplitPolicy::AssignToSecondTree)
73 if (numTreeOneChildren <= node->MaxNumChildren() &&
74 numTreeOneChildren > 0 && numTreeTwoChildren <= node->MaxNumChildren()
75 && numTreeTwoChildren > 0)
80 if (sorted.size() / 2 > i )
81 balance = sorted.size() / 2 - i;
83 balance = i - sorted.size() / 2;
85 size_t cost = numSplits * balance;
89 axisCut = sorted[i].first;
96 template<
typename SplitPolicy>
97 template<
typename TreeType>
100 const TreeType* node,
101 typename TreeType::ElemType& axisCut)
104 axisCut = (node->Bound()[axis].Lo() + node->Bound()[axis].Hi()) * 0.5;
106 if (node->Bound()[axis].Lo() == axisCut)
116 #endif // MLPACK_CORE_TREE_RECTANGLE_TREE_MINIMAL_SPLITS_NUMBER_SWEEP_IMPL_HPP Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static size_t SweepLeafNode(const size_t axis, const TreeType *node, typename TreeType::ElemType &axisCut)
Find a suitable partition of a leaf node along the provided axis.
Definition: minimal_splits_number_sweep_impl.hpp:98
static size_t SweepNonLeafNode(const size_t axis, const TreeType *node, typename TreeType::ElemType &axisCut)
Find a suitable partition of a non-leaf node along the provided axis.
Definition: minimal_splits_number_sweep_impl.hpp:23