// NOTE(review): this listing appears to be text extracted from a rendered
// source view. The leading integers on many lines are the original file's
// line numbers, and numerous lines (function signatures, braces, several
// statements) are missing. The comments below describe only the visible lines.
12 #ifndef MLPACK_CORE_TREE_RECTANGLE_TREE_R_TREE_SPLIT_IMPL_HPP 13 #define MLPACK_CORE_TREE_RECTANGLE_TREE_R_TREE_SPLIT_IMPL_HPP 28 template<
typename TreeType>
// RTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
// (the signature line itself is not part of this excerpt).
// Splits an over-full leaf: two seed points are picked with GetPointSeeds,
// the points are distributed between two new nodes by AssignPointDestNode,
// and those two nodes replace `tree` in its parent's child array.
//
// Early exit when the leaf holds no more than MaxLeafSize() points (the
// guarded statement is not visible; presumably a return).
31 if (tree->Count() <= tree->MaxLeafSize())
// Root case: a root has no parent to receive the two halves. Its contents
// are therefore copied into a new node `copy`, which becomes a child of
// `tree`. NOTE(review): the constructor's second argument `false` is
// presumably a "no deep copy" flag; confirm against TreeType. The lines that
// clear `tree` and then split `copy` are not visible here.
36 if (tree->Parent() == NULL)
39 TreeType* copy =
new TreeType(*tree,
false);
40 copy->Parent() = tree;
44 tree->children[(tree->NumChildren())++] = copy;
// Sanity check: the parent is not already over capacity before this split
// adds one more child to it.
49 assert(tree->Parent()->NumChildren() <= tree->Parent()->MaxNumChildren());
// Choose the two seed point indices (i and j are declared on lines not
// shown here).
56 RTreeSplit::GetPointSeeds(tree, i, j);
// Two new nodes, both constructed with tree's parent.
58 TreeType* treeOne =
new TreeType(tree->Parent());
59 TreeType* treeTwo =
new TreeType(tree->Parent());
// Distribute tree's points between treeOne and treeTwo, seeded with i and j.
62 AssignPointDestNode(tree, treeOne, treeTwo, i, j);
// Replace `tree` with treeOne in the parent's child array and append
// treeTwo. `index` is declared on a line not shown (presumably starting at 0).
// NOTE(review): the search has no upper bound; if `tree` were not among
// par's children it would read past the array. SplitNonLeafNode asserts
// after its equivalent loop; this function does not.
65 TreeType* par = tree->Parent();
67 while (par->children[index] != tree) { ++index; }
69 par->children[index] = treeOne;
70 par->children[par->NumChildren()++] = treeTwo;
// The parent may now hold exactly one child too many. The body that handles
// that overflow is not visible here (presumably it splits `par` in turn).
74 assert(par->NumChildren() <= par->MaxNumChildren() + 1);
75 if (par->NumChildren() == par->MaxNumChildren() + 1)
// Post-conditions: the parent's child count lies within the new nodes'
// [MinNumChildren(), MaxNumChildren()] limits. NOTE(review): this compares
// the parent's count against the children's limits, which assumes every
// node in the tree shares the same limits.
78 assert(treeOne->Parent()->NumChildren() <= treeOne->MaxNumChildren());
79 assert(treeOne->Parent()->NumChildren() >= treeOne->MinNumChildren());
80 assert(treeTwo->Parent()->NumChildren() <= treeTwo->MaxNumChildren());
81 assert(treeTwo->Parent()->NumChildren() >= treeTwo->MinNumChildren());
// RTreeSplit::SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels),
// declared as returning bool. The signature line and the return statements
// are not part of this excerpt.
// Splits an over-full internal node: two seed children are picked with
// GetBoundSeeds, the children are distributed between two new nodes by
// AssignNodeDestNode, and those two nodes replace `tree` in its parent.
94 template<
typename TreeType>
// Root case: the root's contents are copied into `copy`. The root's child
// count is then reset to 0 and `copy` is installed as its only child
// (index 0). NOTE(review): the meaning of the constructor's `false` argument
// is not visible here; confirm against TreeType. The follow-up (presumably
// splitting `copy`) is on lines not shown.
100 if (tree->Parent() == NULL)
103 TreeType* copy =
new TreeType(*tree,
false);
104 copy->Parent() = tree;
105 tree->NumChildren() = 0;
107 tree->children[(tree->NumChildren())++] = copy;
// Choose the two seed child indices (i and j are declared on lines not shown).
114 RTreeSplit::GetBoundSeeds(tree, i, j);
// Two new nodes, both constructed with tree's parent.
118 TreeType* treeOne =
new TreeType(tree->Parent());
119 TreeType* treeTwo =
new TreeType(tree->Parent());
// Distribute tree's children between treeOne and treeTwo, seeded with i and j.
122 AssignNodeDestNode(tree, treeOne, treeTwo, i, j);
// Replace `tree` with treeOne in the parent's child array and append treeTwo.
// NOTE(review): the assert can only catch a match found at or beyond
// NumChildren() (a stale slot). If `tree` is absent altogether, the unbounded
// loop has already read past the array before the assert runs.
125 TreeType* par = tree->Parent();
127 while (par->children[index] != tree) { ++index; }
129 assert(index != par->NumChildren());
130 par->children[index] = treeOne;
131 par->children[par->NumChildren()++] = treeTwo;
// Debug check: the old node is no longer referenced by its parent.
133 for (
size_t i = 0; i < par->NumChildren(); ++i)
134 assert(par->children[i] != tree);
// The parent may now hold exactly one child too many. The handling body is
// not visible here (presumably a recursive split of `par`).
138 assert(par->NumChildren() <= par->MaxNumChildren() + 1);
140 if (par->NumChildren() == par->MaxNumChildren() + 1)
// Re-parent the moved children. InsertNodeIntoTree only updates the
// destination's bound, numDescendants and child array; it never updates the
// child's Parent(), so that is done here.
145 for (
size_t i = 0; i < treeOne->NumChildren(); ++i)
146 treeOne->children[i]->Parent() = treeOne;
148 for (
size_t i = 0; i < treeTwo->NumChildren(); ++i)
149 treeTwo->children[i]->Parent() = treeTwo;
// Post-conditions on the new nodes. NOTE(review): only treeOne's parent is
// checked against the limits; treeTwo's parent is not.
151 assert(treeOne->NumChildren() <= treeOne->MaxNumChildren());
152 assert(treeTwo->NumChildren() <= treeTwo->MaxNumChildren());
153 assert(treeOne->Parent()->NumChildren() <= treeOne->MaxNumChildren());
// Chooses the two points of `tree` used to seed a leaf split.
// Every pair (i, j) with i < j is scored by the product over all dimensions
// of |x_i - x_j|, i.e. the volume of the axis-aligned box spanned by the two
// points. The pair with the largest score is kept. Cost: Count()^2 pairs,
// each scored in O(dim).
//
// @param tree  Leaf whose points are examined (read only).
// @param iRet  Out: index (into tree's point list) of the first seed.
// @param jRet  Out: index of the second seed.
// NOTE(review): the statements that write iRet/jRet are not visible in this
// excerpt. Because the score is a product, any pair sharing one coordinate
// scores 0 regardless of its separation in the other dimensions.
166 template<
typename TreeType>
167 void RTreeSplit::GetPointSeeds(
const TreeType *tree,
int& iRet,
int& jRet)
// Start below any achievable score (all scores are >= 0), so the first pair
// examined is always accepted.
172 typename TreeType::ElemType worstPairScore = -1.0;
173 for (
size_t i = 0; i < tree->Count(); ++i)
175 for (
size_t j = i + 1; j < tree->Count(); ++j)
177 const typename TreeType::ElemType score = arma::prod(arma::abs(
178 tree->Dataset().col(tree->Point(i)) -
179 tree->Dataset().col(tree->Point(j))));
// Strict '>' means the earliest pair wins ties.
181 if (score > worstPairScore)
183 worstPairScore = score;
// Chooses the two children of `tree` used to seed a non-leaf split.
// Every pair (i, j) with i < j is scored by the volume of the smallest
// axis-aligned box covering both children's bounds: the product over
// dimensions of (max Hi - min Lo). The pair with the largest score is kept.
// NOTE(review): only the combined volume is used; the children's own volumes
// are not subtracted from it.
//
// @param tree  Node whose children are examined (read only).
// @param iRet  Out: index of the first seed child.
// @param jRet  Out: index of the second seed child.
// The statements that write iRet/jRet are not visible in this excerpt.
195 template<
typename TreeType>
196 void RTreeSplit::GetBoundSeeds(
const TreeType *tree,
int& iRet,
int& jRet)
199 typedef typename TreeType::ElemType ElemType;
// Start below any achievable score so the first pair is always accepted.
201 ElemType worstPairScore = -1.0;
202 for (
size_t i = 0; i < tree->NumChildren(); ++i)
204 for (
size_t j = i + 1; j < tree->NumChildren(); ++j)
206 ElemType score = 1.0;
207 for (
size_t k = 0; k < tree->Bound().Dim(); ++k)
// Extent in dimension k of the box covering both child bounds.
209 const ElemType hiMax = std::max(tree->Child(i).Bound()[k].Hi(),
210 tree->Child(j).Bound()[k].Hi());
211 const ElemType loMin = std::min(tree->Child(i).Bound()[k].Lo(),
212 tree->Child(j).Bound()[k].Lo());
213 score *= (hiMax - loMin);
// Strict '>' means the earliest pair wins ties.
216 if (score > worstPairScore)
218 worstPairScore = score;
// Distributes the points of the leaf `oldTree` between `treeOne` and
// `treeTwo`, seeded with points intI and intJ. The remaining parameters
// (treeOne, treeTwo, intI, intJ) are declared on lines not shown.
// Points are assigned greedily: each round picks the pending point whose
// insertion enlarges one of the two nodes' bounding volume the least.
226 template<
typename TreeType>
227 void RTreeSplit::AssignPointDestNode(TreeType* oldTree,
234 typedef typename TreeType::ElemType ElemType;
// `end` counts the points still pending assignment. They occupy
// oldTree->Point(0 .. end-1); assigned points are swap-removed past `end`.
236 size_t end = oldTree->Count();
// All three counts are reset. oldTree's point indices are still read below
// through Point() even after its Count() is zeroed.
241 oldTree->Count() = 0;
242 treeOne->Count() = 0;
243 treeTwo->Count() = 0;
// Seed each new node with one of the two chosen points.
245 treeOne->InsertPoint(oldTree->Point(intI));
246 treeTwo->InsertPoint(oldTree->Point(intJ));
// Remove both seeds from the pending range by overwriting them with the last
// pending entries. The condition choosing between the two orders is not
// visible; presumably the larger index is removed first so the second
// overwrite cannot clobber a just-moved entry.
252 oldTree->Point(intI) = oldTree->Point(--end);
253 oldTree->Point(intJ) = oldTree->Point(--end);
257 oldTree->Point(intJ) = oldTree->Point(--end);
258 oldTree->Point(intI) = oldTree->Point(--end);
// Number of points assigned to each new node so far (the seeds).
261 size_t numAssignedOne = 1;
262 size_t numAssignedTwo = 1;
// Keep assigning greedily while points remain and more than
// (MinLeafSize() - smaller assigned count) are pending.
// NOTE(review): std::min(...) yields size_t here, so the subtraction is
// unsigned. If the smaller count exceeds MinLeafSize() it wraps to a huge
// value and the loop stops. Confirm this is the intended exit.
273 while ((end > 0) && (end > oldTree->MinLeafSize() -
274 std::min(numAssignedOne, numAssignedTwo)))
277 ElemType bestScore = std::numeric_limits<ElemType>::max();
// Current bounding volumes of the two new nodes.
284 ElemType volOne = 1.0;
285 ElemType volTwo = 1.0;
286 for (
size_t i = 0; i < oldTree->Bound().Dim(); ++i)
288 volOne *= treeOne->Bound()[i].Width();
289 volTwo *= treeTwo->Bound()[i].Width();
// For each pending point, compute each node's volume if the point were
// added. Per dimension, the width is unchanged when the coordinate already
// lies in the range; otherwise the range is stretched to reach it.
294 for (
size_t index = 0; index < end; index++)
296 ElemType newVolOne = 1.0;
297 ElemType newVolTwo = 1.0;
298 for (
size_t i = 0; i < oldTree->Bound().Dim(); ++i)
300 ElemType c = oldTree->Dataset().col(oldTree->Point(index))[i];
301 newVolOne *= treeOne->Bound()[i].Contains(c) ?
302 treeOne->Bound()[i].Width() : (c < treeOne->Bound()[i].Lo() ?
303 (treeOne->Bound()[i].Hi() - c) : (c - treeOne->Bound()[i].Lo()));
304 newVolTwo *= treeTwo->Bound()[i].Contains(c) ?
305 treeTwo->Bound()[i].Width() : (c < treeTwo->Bound()[i].Lo() ?
306 (treeTwo->Bound()[i].Hi() - c) : (c - treeTwo->Bound()[i].Lo()));
// Prefer the node with the smaller enlargement. Equal enlargements fall to
// treeTwo because the comparison is strict. Track the overall smallest
// enlargement. The lines recording the best index/destination are not shown.
310 if ((newVolOne - volOne) < (newVolTwo - volTwo))
312 if (newVolOne - volOne < bestScore)
314 bestScore = newVolOne - volOne;
321 if (newVolTwo - volTwo < bestScore)
323 bestScore = newVolTwo - volTwo;
// Insert the chosen point into whichever node was selected (the selecting
// condition is not visible).
334 treeOne->InsertPoint(oldTree->Point(bestIndex));
339 treeTwo->InsertPoint(oldTree->Point(bestIndex));
// Swap-remove the chosen point from the pending range.
343 oldTree->Point(bestIndex) = oldTree->Point(--end);
// Give every remaining pending point to the node that has received fewer
// points. treeOne is chosen only when strictly smaller; the other branch
// assigns them to treeTwo.
349 if (numAssignedOne < numAssignedTwo)
351 for (
size_t i = 0; i < end; ++i)
352 treeOne->InsertPoint(oldTree->Point(i));
356 for (
size_t i = 0; i < end; ++i)
357 treeTwo->InsertPoint(oldTree->Point(i));
// Distributes the children of the internal node `oldTree` between `treeOne`
// and `treeTwo`, seeded with children intI and intJ. The remaining
// parameters are declared on lines not shown.
// Children are assigned greedily: each round picks the pending child whose
// insertion enlarges one of the two nodes' bounding volume the least.
// Children are moved with InsertNodeIntoTree, which does not update their
// Parent() pointers. The caller (SplitNonLeafNode) re-parents them afterwards.
362 template<
typename TreeType>
363 void RTreeSplit::AssignNodeDestNode(TreeType* oldTree,
370 typedef typename TreeType::ElemType ElemType;
// `end` counts the children still pending assignment. They occupy
// oldTree->children[0 .. end-1].
372 size_t end = oldTree->NumChildren();
375 assert(intI != intJ);
// Debug check: oldTree holds no duplicate child pointers.
377 for (
size_t i = 0; i < oldTree->NumChildren(); ++i)
378 for (
size_t j = i + 1; j < oldTree->NumChildren(); ++j)
379 assert(oldTree->children[i] != oldTree->children[j]);
// Seed each new node with one of the two chosen children.
381 InsertNodeIntoTree(treeOne, oldTree->children[intI]);
382 InsertNodeIntoTree(treeTwo, oldTree->children[intJ]);
// Remove both seeds from the pending range by overwriting them with the last
// pending entries. The condition choosing between the two orders is not
// visible; presumably the larger index is removed first.
388 oldTree->children[intI] = oldTree->children[--end];
389 oldTree->children[intJ] = oldTree->children[--end];
393 oldTree->children[intJ] = oldTree->children[--end];
394 oldTree->children[intI] = oldTree->children[--end];
// Debug checks: each new node holds exactly its seed, the pending range has
// no duplicates, and neither seed is still pending.
397 assert(treeOne->NumChildren() == 1);
398 assert(treeTwo->NumChildren() == 1);
400 for (
size_t i = 0; i < end; ++i)
401 for (
size_t j = i + 1; j < end; ++j)
402 assert(oldTree->children[i] != oldTree->children[j]);
404 for (
size_t i = 0; i < end; ++i)
405 assert(oldTree->children[i] != treeOne->children[0]);
407 for (
size_t i = 0; i < end; ++i)
408 assert(oldTree->children[i] != treeTwo->children[0]);
// Number of children assigned to each new node so far (the seeds).
410 size_t numAssignTreeOne = 1;
411 size_t numAssignTreeTwo = 1;
// Keep assigning greedily while children remain and more than
// (MinNumChildren() - smaller assigned count) are pending.
// NOTE(review): std::min(...) yields size_t, so the subtraction is unsigned.
// If the smaller count exceeds MinNumChildren() it wraps and the loop stops.
// Confirm this is the intended exit.
416 while ((end > 0) && (end > oldTree->MinNumChildren() -
417 std::min(numAssignTreeOne, numAssignTreeTwo)))
420 ElemType bestScore = std::numeric_limits<ElemType>::max();
// Current bounding volumes of the two new nodes.
425 ElemType volOne = 1.0;
426 ElemType volTwo = 1.0;
427 for (
size_t i = 0; i < oldTree->Bound().Dim(); ++i)
429 volOne *= treeOne->Bound()[i].Width();
430 volTwo *= treeTwo->Bound()[i].Width();
// For each pending child, compute each node's volume if that child's bound
// were merged in. `range` is the child's bound in dimension i; its
// declaration starts on a line not shown. Per dimension the merged width is:
//   - the node's width, if the node's range already contains `range`;
//   - range's width, if `range` contains the node's range;
//   - otherwise the two ranges only partly overlap or are disjoint, and the
//     merged span runs from the lower Lo to the higher Hi.
433 for (
size_t index = 0; index < end; index++)
435 ElemType newVolOne = 1.0;
436 ElemType newVolTwo = 1.0;
437 for (
size_t i = 0; i < oldTree->Bound().Dim(); ++i)
442 oldTree->Child(index).Bound()[i];
443 newVolOne *= treeOne->Bound()[i].Contains(range) ?
444 treeOne->Bound()[i].Width() : (range.
Contains(treeOne->Bound()[i]) ?
445 range.
Width() : (range.
Lo() < treeOne->Bound()[i].Lo() ?
446 (treeOne->Bound()[i].Hi() - range.
Lo()) : (range.
Hi() -
447 treeOne->Bound()[i].Lo())));
449 newVolTwo *= treeTwo->Bound()[i].Contains(range) ?
450 treeTwo->Bound()[i].Width() : (range.
Contains(treeTwo->Bound()[i]) ?
451 range.
Width() : (range.
Lo() < treeTwo->Bound()[i].Lo() ?
452 (treeTwo->Bound()[i].Hi() - range.
Lo()) : (range.
Hi() -
453 treeTwo->Bound()[i].Lo())));
// Prefer the node with the smaller enlargement. Equal enlargements fall to
// treeTwo because the comparison is strict. The lines recording the best
// index/destination are not shown.
457 if ((newVolOne - volOne) < (newVolTwo - volTwo))
459 if (newVolOne - volOne < bestScore)
461 bestScore = newVolOne - volOne;
468 if (newVolTwo - volTwo < bestScore)
470 bestScore = newVolTwo - volTwo;
// Move the chosen child into the selected node (the selecting condition is
// not visible), then swap-remove it from the pending range.
481 InsertNodeIntoTree(treeOne, oldTree->children[bestIndex]);
486 InsertNodeIntoTree(treeTwo, oldTree->children[bestIndex]);
490 oldTree->children[bestIndex] = oldTree->children[--end];
// Give every remaining pending child to the node that has received fewer
// children. treeOne is chosen only when strictly smaller; otherwise the
// children go to treeTwo.
496 if (numAssignTreeOne < numAssignTreeTwo)
498 for (
size_t i = 0; i < end; ++i)
500 InsertNodeIntoTree(treeOne, oldTree->children[i]);
506 for (
size_t i = 0; i < end; ++i)
508 InsertNodeIntoTree(treeTwo, oldTree->children[i]);
// Debug checks: neither new node holds a duplicate child pointer.
514 for (
size_t i = 0; i < treeOne->NumChildren(); ++i)
515 for (
size_t j = i + 1; j < treeOne->NumChildren(); ++j)
516 assert(treeOne->children[i] != treeOne->children[j]);
518 for (
size_t i = 0; i < treeTwo->NumChildren(); ++i)
519 for (
size_t j = i + 1; j < treeTwo->NumChildren(); ++j)
520 assert(treeTwo->children[i] != treeTwo->children[j]);
// Appends `srcNode` as a child of `destTree`. It does three things:
//   - grows destTree's bound to cover srcNode's bound (operator|=);
//   - adds srcNode's descendant count to destTree's;
//   - stores the pointer in the next free child slot.
// It does NOT update srcNode->Parent(); callers must re-parent it
// themselves. There is also no capacity check against MaxNumChildren(), so
// callers must ensure a free slot exists.
527 template<
typename TreeType>
528 void RTreeSplit::InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
530 destTree->Bound() |= srcNode->Bound();
531 destTree->numDescendants += srcNode->numDescendants;
532 destTree->children[destTree->NumChildren()++] = srcNode;
T Lo() const
Get the lower bound.
Definition: range.hpp:61
static bool SplitNonLeafNode(TreeType *tree, std::vector< bool > &relevels)
Split a non-leaf node using the "default" algorithm.
Definition: r_tree_split_impl.hpp:95
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static void SplitLeafNode(TreeType *tree, std::vector< bool > &relevels)
Split a leaf node using the "default" algorithm.
Definition: r_tree_split_impl.hpp:29
T Hi() const
Get the upper bound.
Definition: range.hpp:66
bool Contains(const T d) const
Determines if a point is contained within the range.
Definition: range_impl.hpp:187
Definition of the Range class, which represents a simple range with a lower and upper bound...
T Width() const
Gets the span of the range (hi - lo).
Definition: range_impl.hpp:47