12 #ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_X_TREE_SPLIT_IMPL_HPP 13 #define MLPACK_CORE_TREE_RECTANGLE_TREE_X_TREE_SPLIT_IMPL_HPP 28 template<
typename TreeType>
32 typedef typename TreeType::ElemType ElemType;
34 if (tree->Count() <= tree->MaxLeafSize())
51 std::vector<std::pair<ElemType, size_t>> sorted(tree->Count());
52 for (
size_t i = 0; i < sorted.size(); ++i)
54 sorted[i].first = tree->Dataset().col(tree->Point(i))[bestAxis];
55 sorted[i].second = tree->Point(i);
58 std::sort(sorted.begin(), sorted.end(), PairComp<ElemType, size_t>);
71 TreeType* par = tree->Parent();
72 TreeType* treeOne = (par) ? tree :
new TreeType(tree);
73 TreeType* treeTwo = (par) ?
new TreeType(par) :
new TreeType(tree);
76 const size_t numPoints = tree->Count();
80 tree->numChildren = 0;
81 tree->numDescendants = 0;
86 for (
size_t i = 0; i < numPoints; ++i)
88 if (i < bestIndex + tree->MinLeafSize())
89 treeOne->InsertPoint(sorted[i].second);
91 treeTwo->InsertPoint(sorted[i].second);
97 par->children[par->NumChildren()++] = treeTwo;
101 InsertNodeIntoTree(tree, treeOne);
102 InsertNodeIntoTree(tree, treeTwo);
106 treeOne->AuxiliaryInfo().SplitHistory().history[bestAxis] =
true;
107 treeOne->AuxiliaryInfo().SplitHistory().lastDimension = bestAxis;
108 treeTwo->AuxiliaryInfo().SplitHistory().history[bestAxis] =
true;
109 treeTwo->AuxiliaryInfo().SplitHistory().lastDimension = bestAxis;
112 if (par && par->NumChildren() == par->MaxNumChildren() + 1)
123 template<
typename TreeType>
127 typedef typename TreeType::ElemType ElemType;
138 std::vector<bool> axes(tree->Bound().Dim(),
true);
139 std::vector<size_t> dimensionsLastUsed(tree->NumChildren());
140 for (
size_t i = 0; i < tree->NumChildren(); ++i)
141 dimensionsLastUsed[i] =
142 tree->Child(i).AuxiliaryInfo().SplitHistory().lastDimension;
143 std::sort(dimensionsLastUsed.begin(), dimensionsLastUsed.end());
145 size_t lastDim = dimensionsLastUsed[dimensionsLastUsed.size() / 2];
146 size_t minOverlapSplitDimension = tree->Bound().
Dim();
149 for (
size_t i = lastDim + 1; i < axes.size(); ++i)
151 for (
size_t j = 0; j < tree->NumChildren(); ++j)
153 tree->Child(j).AuxiliaryInfo().SplitHistory().history[i];
156 minOverlapSplitDimension = i;
161 if (minOverlapSplitDimension == tree->Bound().Dim())
163 for (
size_t i = 0; i < lastDim + 1; ++i)
166 for (
size_t j = 0; j < tree->NumChildren(); ++j)
168 tree->Child(j).AuxiliaryInfo().SplitHistory().history[i];
171 minOverlapSplitDimension = i;
177 bool minOverlapSplitUsesHi =
false;
178 ElemType bestScoreMinOverlapSplit = std::numeric_limits<ElemType>::max();
179 ElemType areaOfBestMinOverlapSplit = 0;
180 int bestIndexMinOverlapSplit = 0;
182 int bestOverlapIndexOnBestAxis = 0;
183 int bestAreaIndexOnBestAxis = 0;
184 bool tiedOnOverlap =
false;
185 bool lowIsBest =
true;
187 ElemType bestAxisScore = std::numeric_limits<ElemType>::max();
188 ElemType overlapBestOverlapAxis = 0;
189 ElemType areaBestOverlapAxis = 0;
190 ElemType overlapBestAreaAxis = 0;
191 ElemType areaBestAreaAxis = 0;
193 for (
size_t j = 0; j < tree->Bound().Dim(); ++j)
195 ElemType axisScore = 0.0;
198 std::vector<std::pair<ElemType, TreeType*>> sorted(tree->NumChildren());
199 for (
size_t i = 0; i < sorted.size(); ++i)
201 sorted[i].first = tree->Child(i).Bound()[j].Lo();
202 sorted[i].second = &tree->Child(i);
205 std::sort(sorted.begin(), sorted.end(), PairComp<ElemType, TreeType*>);
208 std::vector<ElemType> areas(tree->MaxNumChildren() -
209 2 * tree->MinNumChildren() + 2);
210 std::vector<ElemType> margins(tree->MaxNumChildren() -
211 2 * tree->MinNumChildren() + 2);
212 std::vector<ElemType> overlapedAreas(tree->MaxNumChildren() -
213 2 * tree->MinNumChildren() + 2);
214 for (
size_t i = 0; i < areas.size(); ++i)
218 overlapedAreas[i] = 0.0;
221 for (
size_t i = 0; i < areas.size(); ++i)
227 size_t cutOff = tree->MinNumChildren() + i;
229 BoundType bound1(tree->Bound().Dim());
230 BoundType bound2(tree->Bound().Dim());
232 for (
size_t l = 0; l < cutOff; l++)
233 bound1 |= sorted[l].second->Bound();
235 for (
size_t l = cutOff; l < tree->NumChildren(); l++)
236 bound2 |= sorted[l].second->Bound();
238 ElemType area1 = bound1.Volume();
239 ElemType area2 = bound2.Volume();
240 ElemType oArea = bound1.Overlap(bound2);
242 for (
size_t k = 0; k < bound1.Dim(); ++k)
243 margins[i] += bound1[k].Width() + bound2[k].Width();
245 areas[i] += area1 + area2;
246 overlapedAreas[i] += oArea;
247 axisScore += margins[i];
250 if (axisScore < bestAxisScore)
252 bestAxisScore = axisScore;
254 bestOverlapIndexOnBestAxis = 0;
255 bestAreaIndexOnBestAxis = 0;
256 overlapBestOverlapAxis = overlapedAreas[bestOverlapIndexOnBestAxis];
257 areaBestOverlapAxis = areas[bestAreaIndexOnBestAxis];
258 for (
size_t i = 1; i < areas.size(); ++i)
260 if (overlapedAreas[i] < overlapedAreas[bestOverlapIndexOnBestAxis])
262 tiedOnOverlap =
false;
263 bestAreaIndexOnBestAxis = i;
264 bestOverlapIndexOnBestAxis = i;
265 overlapBestOverlapAxis = overlapedAreas[i];
266 areaBestOverlapAxis = areas[i];
268 else if (overlapedAreas[i] ==
269 overlapedAreas[bestOverlapIndexOnBestAxis])
271 tiedOnOverlap =
true;
272 if (areas[i] < areas[bestAreaIndexOnBestAxis])
274 bestAreaIndexOnBestAxis = i;
275 overlapBestAreaAxis = overlapedAreas[i];
276 areaBestAreaAxis = areas[i];
283 if (minOverlapSplitDimension != tree->Bound().Dim() &&
284 j == minOverlapSplitDimension)
286 for (
size_t i = 0; i < overlapedAreas.size(); ++i)
288 if (overlapedAreas[i] < bestScoreMinOverlapSplit)
290 bestScoreMinOverlapSplit = overlapedAreas[i];
291 bestIndexMinOverlapSplit = i;
292 areaOfBestMinOverlapSplit = areas[i];
299 for (
size_t j = 0; j < tree->Bound().Dim(); ++j)
301 ElemType axisScore = 0.0;
303 std::vector<std::pair<ElemType, TreeType*>> sorted(tree->NumChildren());
304 for (
size_t i = 0; i < sorted.size(); ++i)
306 sorted[i].first = tree->Child(i).Bound()[j].Hi();
307 sorted[i].second = &tree->Child(i);
310 std::sort(sorted.begin(), sorted.end(), PairComp<ElemType, TreeType*>);
313 std::vector<ElemType> areas(tree->MaxNumChildren() -
314 2 * tree->MinNumChildren() + 2);
315 std::vector<ElemType> margins(tree->MaxNumChildren() -
316 2 * tree->MinNumChildren() + 2);
317 std::vector<ElemType> overlapedAreas(tree->MaxNumChildren() -
318 2 * tree->MinNumChildren() + 2);
319 for (
size_t i = 0; i < areas.size(); ++i)
323 overlapedAreas[i] = 0.0;
326 for (
size_t i = 0; i < areas.size(); ++i)
332 size_t cutOff = tree->MinNumChildren() + i;
334 BoundType bound1(tree->Bound().Dim());
335 BoundType bound2(tree->Bound().Dim());
337 for (
size_t l = 0; l < cutOff; l++)
338 bound1 |= sorted[l].second->Bound();
340 for (
size_t l = cutOff; l < tree->NumChildren(); l++)
341 bound2 |= sorted[l].second->Bound();
343 ElemType area1 = bound1.Volume();
344 ElemType area2 = bound2.Volume();
345 ElemType oArea = bound1.Overlap(bound2);
347 for (
size_t k = 0; k < bound1.Dim(); ++k)
348 margins[i] += bound1[k].Width() + bound2[k].Width();
351 areas[i] += area1 + area2;
352 overlapedAreas[i] += oArea;
353 axisScore += margins[i];
356 if (axisScore < bestAxisScore)
358 bestAxisScore = axisScore;
361 bestOverlapIndexOnBestAxis = 0;
362 bestAreaIndexOnBestAxis = 0;
363 overlapBestOverlapAxis = overlapedAreas[bestOverlapIndexOnBestAxis];
364 areaBestOverlapAxis = areas[bestAreaIndexOnBestAxis];
365 for (
size_t i = 1; i < areas.size(); ++i)
367 if (overlapedAreas[i] < overlapedAreas[bestOverlapIndexOnBestAxis])
369 tiedOnOverlap =
false;
370 bestAreaIndexOnBestAxis = i;
371 bestOverlapIndexOnBestAxis = i;
372 overlapBestOverlapAxis = overlapedAreas[i];
373 areaBestOverlapAxis = areas[i];
375 else if (overlapedAreas[i] ==
376 overlapedAreas[bestOverlapIndexOnBestAxis])
378 tiedOnOverlap =
true;
379 if (areas[i] < areas[bestAreaIndexOnBestAxis])
381 bestAreaIndexOnBestAxis = i;
382 overlapBestAreaAxis = overlapedAreas[i];
383 areaBestAreaAxis = areas[i];
390 if (minOverlapSplitDimension != tree->Bound().Dim() &&
391 j == minOverlapSplitDimension)
393 for (
size_t i = 0; i < overlapedAreas.size(); ++i)
395 if (overlapedAreas[i] < bestScoreMinOverlapSplit)
397 minOverlapSplitUsesHi =
true;
398 bestScoreMinOverlapSplit = overlapedAreas[i];
399 bestIndexMinOverlapSplit = i;
400 areaOfBestMinOverlapSplit = areas[i];
406 std::vector<std::pair<ElemType, TreeType*>> sorted(tree->NumChildren());
409 for (
size_t i = 0; i < sorted.size(); ++i)
411 sorted[i].first = tree->Child(i).Bound()[bestAxis].Lo();
412 sorted[i].second = &tree->Child(i);
417 for (
size_t i = 0; i < sorted.size(); ++i)
419 sorted[i].first = tree->Child(i).Bound()[bestAxis].Hi();
420 sorted[i].second = &tree->Child(i);
424 std::sort(sorted.begin(), sorted.end(), PairComp<ElemType, TreeType*>);
426 if (tree->Parent() != NULL)
429 TreeType* treeTwo =
new TreeType(tree->Parent(), tree->MaxNumChildren());
430 const size_t numChildren = tree->NumChildren();
431 tree->numChildren = 0;
435 bool useMinOverlapSplit =
false;
438 if (areaBestAreaAxis > 0 &&
439 overlapBestAreaAxis / areaBestAreaAxis <
MAX_OVERLAP)
441 tree->numDescendants = 0;
443 for (
size_t i = 0; i < numChildren; ++i)
445 if (i < bestAreaIndexOnBestAxis + tree->MinNumChildren())
446 InsertNodeIntoTree(tree, sorted[i].second);
448 InsertNodeIntoTree(treeTwo, sorted[i].second);
452 useMinOverlapSplit =
true;
456 if (overlapBestOverlapAxis / areaBestOverlapAxis <
MAX_OVERLAP)
458 tree->numDescendants = 0;
460 for (
size_t i = 0; i < numChildren; ++i)
462 if (i < bestOverlapIndexOnBestAxis + tree->MinNumChildren())
463 InsertNodeIntoTree(tree, sorted[i].second);
465 InsertNodeIntoTree(treeTwo, sorted[i].second);
469 useMinOverlapSplit =
true;
475 if (useMinOverlapSplit)
478 if ((minOverlapSplitDimension != tree->Bound().Dim()) &&
479 (bestScoreMinOverlapSplit / areaOfBestMinOverlapSplit <
MAX_OVERLAP))
481 std::vector<std::pair<ElemType, TreeType*>> sorted2(numChildren);
482 if (minOverlapSplitUsesHi)
484 for (
size_t i = 0; i < sorted2.size(); ++i)
486 sorted2[i].first = sorted[i].second->Bound()[bestAxis].Hi();
487 sorted2[i].second = sorted[i].second;
492 for (
size_t i = 0; i < sorted2.size(); ++i)
494 sorted2[i].first = sorted[i].second->Bound()[bestAxis].Lo();
495 sorted2[i].second = sorted[i].second;
498 std::sort(sorted2.begin(), sorted2.end(),
499 PairComp<ElemType, TreeType*>);
501 tree->numDescendants = 0;
503 for (
size_t i = 0; i < numChildren; ++i)
505 if (i < bestIndexMinOverlapSplit + tree->MinNumChildren())
506 InsertNodeIntoTree(tree, sorted2[i].second);
508 InsertNodeIntoTree(treeTwo, sorted2[i].second);
522 if ((tree->Parent()->Parent() == NULL) &&
523 (tree->Parent()->NumChildren() == 1))
526 tree->Parent()->MaxNumChildren() = tree->MaxNumChildren() +
527 tree->AuxiliaryInfo().NormalNodeMaxNumChildren();
528 tree->Parent()->children.resize(tree->Parent()->MaxNumChildren() + 1);
529 tree->Parent()->NumChildren() = tree->NumChildren();
530 for (
size_t i = 0; i < numChildren; ++i)
532 tree->Parent()->children[i] = sorted[i].second;
533 tree->Parent()->children[i]->Parent() = tree->Parent();
534 tree->children[i] = NULL;
544 tree->MaxNumChildren() +=
545 tree->AuxiliaryInfo().NormalNodeMaxNumChildren();
546 tree->children.resize(tree->MaxNumChildren() + 1);
547 tree->numChildren = numChildren;
548 for (
size_t i = 0; i < numChildren; ++i)
549 tree->Child(i).Parent() = tree;
557 tree->AuxiliaryInfo().SplitHistory().history[bestAxis] =
true;
558 tree->AuxiliaryInfo().SplitHistory().lastDimension = bestAxis;
559 treeTwo->AuxiliaryInfo().SplitHistory().history[bestAxis] =
true;
560 treeTwo->AuxiliaryInfo().SplitHistory().lastDimension = bestAxis;
563 TreeType* par = tree->Parent();
564 par->children[par->NumChildren()++] = treeTwo;
568 if (!(par->NumChildren() <= par->MaxNumChildren() + 1))
569 Log::Debug <<
"error " << par->NumChildren() <<
", " 570 << par->MaxNumChildren() + 1 << std::endl;
571 assert(par->NumChildren() <= par->MaxNumChildren() + 1);
573 if (par->NumChildren() == par->MaxNumChildren() + 1)
578 for (
size_t i = 0; i < treeTwo->NumChildren(); ++i)
579 treeTwo->Child(i).Parent() = treeTwo;
581 assert(tree->Parent()->NumChildren() <=
582 tree->Parent()->MaxNumChildren());
583 assert(tree->Parent()->NumChildren() >=
584 tree->Parent()->MinNumChildren());
585 assert(treeTwo->Parent()->NumChildren() <=
586 treeTwo->Parent()->MaxNumChildren());
587 assert(treeTwo->Parent()->NumChildren() >=
588 treeTwo->Parent()->MinNumChildren());
595 TreeType* treeOne =
new TreeType(tree, tree->MaxNumChildren());
596 TreeType* treeTwo =
new TreeType(tree, tree->MaxNumChildren());
597 const size_t numChildren = tree->NumChildren();
598 tree->numChildren = 0;
601 bool useMinOverlapSplit =
false;
604 if (overlapBestAreaAxis/areaBestAreaAxis <
MAX_OVERLAP)
606 for (
size_t i = 0; i < numChildren; ++i)
608 if (i < bestAreaIndexOnBestAxis + tree->MinNumChildren())
609 InsertNodeIntoTree(treeOne, sorted[i].second);
611 InsertNodeIntoTree(treeTwo, sorted[i].second);
615 useMinOverlapSplit =
true;
619 if (overlapBestOverlapAxis/areaBestOverlapAxis <
MAX_OVERLAP)
621 for (
size_t i = 0; i < numChildren; ++i)
623 if (i < bestOverlapIndexOnBestAxis + tree->MinNumChildren())
624 InsertNodeIntoTree(treeOne, sorted[i].second);
626 InsertNodeIntoTree(treeTwo, sorted[i].second);
630 useMinOverlapSplit =
true;
636 if (useMinOverlapSplit)
639 if ((minOverlapSplitDimension != tree->Bound().Dim()) &&
640 (bestScoreMinOverlapSplit / areaOfBestMinOverlapSplit <
MAX_OVERLAP))
642 std::vector<std::pair<ElemType, TreeType*>> sorted2(numChildren);
643 if (minOverlapSplitUsesHi)
645 for (
size_t i = 0; i < sorted2.size(); ++i)
647 sorted2[i].first = sorted[i].second->Bound()[bestAxis].Hi();
648 sorted2[i].second = sorted[i].second;
653 for (
size_t i = 0; i < sorted2.size(); ++i)
655 sorted2[i].first = sorted[i].second->Bound()[bestAxis].Lo();
656 sorted2[i].second = sorted[i].second;
659 std::sort(sorted2.begin(), sorted2.end(),
660 PairComp<ElemType, TreeType*>);
662 for (
size_t i = 0; i < numChildren; ++i)
664 if (i < bestIndexMinOverlapSplit + tree->MinNumChildren())
665 InsertNodeIntoTree(treeOne, sorted2[i].second);
667 InsertNodeIntoTree(treeTwo, sorted2[i].second);
673 tree->MaxNumChildren() +=
674 tree->AuxiliaryInfo().NormalNodeMaxNumChildren();
675 tree->children.resize(tree->MaxNumChildren() + 1);
676 tree->numChildren = numChildren;
677 for (
size_t i = 0; i < numChildren; ++i)
678 tree->Child(i).Parent() = tree;
687 treeOne->AuxiliaryInfo().SplitHistory().history[bestAxis] =
true;
688 treeOne->AuxiliaryInfo().SplitHistory().lastDimension = bestAxis;
689 treeTwo->AuxiliaryInfo().SplitHistory().history[bestAxis] =
true;
690 treeTwo->AuxiliaryInfo().SplitHistory().lastDimension = bestAxis;
693 tree->children[0] = treeOne;
694 tree->children[1] = treeTwo;
695 tree->numChildren = 2;
696 tree->numDescendants = treeOne->numDescendants + treeTwo->numDescendants;
700 for (
size_t i = 0; i < treeOne->NumChildren(); ++i)
701 treeOne->Child(i).Parent() = treeOne;
702 for (
size_t i = 0; i < treeTwo->NumChildren(); ++i)
703 treeTwo->Child(i).Parent() = treeTwo;
713 template<
typename TreeType>
714 void XTreeSplit::InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
716 destTree->Bound() |= srcNode->Bound();
717 destTree->numDescendants += srcNode->numDescendants;
718 destTree->children[destTree->NumChildren()++] = srcNode;
static void SplitLeafNode(TreeType *tree, std::vector< bool > &relevels)
Split a leaf node using the algorithm described in "The R*-tree: An Efficient and Robust Access metho...
Definition: x_tree_split_impl.hpp:29
static MLPACK_EXPORT util::NullOutStream Debug
MLPACK_EXPORT is required for global variables, so that they are properly exported by the Windows com...
Definition: log.hpp:79
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static void PickLeafSplit(TreeType *tree, size_t &bestAxis, size_t &bestIndex)
Given a node, return the best dimension and the best index to split on.
Definition: r_star_tree_split_impl.hpp:82
static bool SplitNonLeafNode(TreeType *tree, std::vector< bool > &relevels)
Split a non-leaf node using the "default" algorithm.
Definition: x_tree_split_impl.hpp:124
static size_t ReinsertPoints(TreeType *tree, std::vector< bool > &relevels)
Reinsert any points into the tree, if needed.
Definition: r_star_tree_split_impl.hpp:28
size_t Dim() const
Gets the dimensionality.
Definition: hrectbound.hpp:96
Definition of the Range class, which represents a simple range with a lower and upper bound...
const double MAX_OVERLAP
The X-tree paper says that a maximum allowable overlap of 20% works well.
Definition: x_tree_split.hpp:29