12 #ifndef MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_IMPL_HPP 13 #define MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_IMPL_HPP 25 template<
typename TreeType,
typename StatisticType>
26 void BuildStatistics(TreeType* node)
29 for (
size_t i = 0; i < node->NumChildren(); ++i)
30 BuildStatistics<TreeType, StatisticType>(&node->Child(i));
33 node->Stat() = StatisticType(*node);
39 typename StatisticType,
41 typename RootPointPolicy
44 const MatType& dataset,
48 point(RootPointPolicy::ChooseRoot(dataset)),
54 furthestDescendantDistance(0),
55 localMetric(metric == NULL),
62 this->metric =
new MetricType();
66 if (dataset.n_cols <= 1)
73 arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
74 dataset.n_cols - 1, dataset.n_cols - 1);
78 indices[point - 1] = 0;
80 arma::vec distances(dataset.n_cols - 1);
83 ComputeDistances(point, indices, distances, dataset.n_cols - 1);
86 size_t farSetSize = 0;
87 size_t usedSetSize = 0;
88 CreateChildren(indices, distances, dataset.n_cols - 1, farSetSize,
92 while (children.size() == 1)
98 children.erase(children.begin());
101 children.push_back(&(old->
Child(i)));
111 scale = old->
Scale();
121 if (furthestDescendantDistance == 0.0 && dataset.n_cols == 1)
123 else if (furthestDescendantDistance == 0.0)
126 scale = (int) ceil(log(furthestDescendantDistance) / log(base));
130 BuildStatistics<CoverTree, StatisticType>(
this);
132 Log::Info << distanceComps <<
" distance computations during tree " 133 <<
"construction." << std::endl;
138 typename StatisticType,
140 typename RootPointPolicy
143 const MatType& dataset,
147 point(RootPointPolicy::ChooseRoot(dataset)),
153 furthestDescendantDistance(0),
156 metric(new MetricType(metric)),
161 if (dataset.n_cols <= 1)
168 arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
169 dataset.n_cols - 1, dataset.n_cols - 1);
173 indices[point - 1] = 0;
175 arma::vec distances(dataset.n_cols - 1);
178 ComputeDistances(point, indices, distances, dataset.n_cols - 1);
181 size_t farSetSize = 0;
182 size_t usedSetSize = 0;
183 CreateChildren(indices, distances, dataset.n_cols - 1, farSetSize,
187 while (children.size() == 1)
193 children.erase(children.begin());
196 children.push_back(&(old->
Child(i)));
206 scale = old->
Scale();
216 if (furthestDescendantDistance == 0.0 && dataset.n_cols == 1)
218 else if (furthestDescendantDistance == 0.0)
221 scale = (int) ceil(log(furthestDescendantDistance) / log(base));
225 BuildStatistics<CoverTree, StatisticType>(
this);
227 Log::Info << distanceComps <<
" distance computations during tree " 228 <<
"construction." << std::endl;
233 typename StatisticType,
235 typename RootPointPolicy
240 dataset(new MatType(
std::move(
data))),
241 point(RootPointPolicy::ChooseRoot(dataset)),
247 furthestDescendantDistance(0),
253 this->metric =
new MetricType();
257 if (dataset->n_cols <= 1)
264 arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
265 dataset->n_cols - 1, dataset->n_cols - 1);
269 indices[point - 1] = 0;
271 arma::vec distances(dataset->n_cols - 1);
274 ComputeDistances(point, indices, distances, dataset->n_cols - 1);
277 size_t farSetSize = 0;
278 size_t usedSetSize = 0;
279 CreateChildren(indices, distances, dataset->n_cols - 1, farSetSize,
283 while (children.size() == 1)
289 children.erase(children.begin());
292 children.push_back(&(old->
Child(i)));
302 scale = old->
Scale();
312 if (furthestDescendantDistance == 0.0 && dataset->n_cols == 1)
314 else if (furthestDescendantDistance == 0.0)
317 scale = (int) ceil(log(furthestDescendantDistance) / log(base));
321 BuildStatistics<CoverTree, StatisticType>(
this);
323 Log::Info << distanceComps <<
" distance computations during tree " 324 <<
"construction." << std::endl;
329 typename StatisticType,
331 typename RootPointPolicy
337 dataset(new MatType(
std::move(
data))),
338 point(RootPointPolicy::ChooseRoot(dataset)),
344 furthestDescendantDistance(0),
347 metric(new MetricType(metric)),
352 if (dataset->n_cols <= 1)
359 arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
360 dataset->n_cols - 1, dataset->n_cols - 1);
364 indices[point - 1] = 0;
366 arma::vec distances(dataset->n_cols - 1);
369 ComputeDistances(point, indices, distances, dataset->n_cols - 1);
372 size_t farSetSize = 0;
373 size_t usedSetSize = 0;
374 CreateChildren(indices, distances, dataset->n_cols - 1, farSetSize,
378 while (children.size() == 1)
384 children.erase(children.begin());
387 children.push_back(&(old->
Child(i)));
397 scale = old->
Scale();
407 if (furthestDescendantDistance == 0.0 && dataset->n_cols == 1)
409 else if (furthestDescendantDistance == 0.0)
412 scale = (int) ceil(log(furthestDescendantDistance) / log(base));
416 BuildStatistics<CoverTree, StatisticType>(
this);
418 Log::Info << distanceComps <<
" distance computations during tree " 419 <<
"construction." << std::endl;
424 typename StatisticType,
426 typename RootPointPolicy
429 const MatType& dataset,
431 const size_t pointIndex,
435 arma::Col<size_t>& indices,
436 arma::vec& distances,
440 MetricType& metric) :
447 parentDistance(parentDistance),
448 furthestDescendantDistance(0),
455 if (nearSetSize == 0)
457 this->scale = INT_MIN;
463 CreateChildren(indices, distances, nearSetSize, farSetSize, usedSetSize);
469 typename StatisticType,
471 typename RootPointPolicy
474 const MatType& dataset,
476 const size_t pointIndex,
480 const ElemType furthestDescendantDistance,
481 MetricType* metric) :
488 parentDistance(parentDistance),
489 furthestDescendantDistance(furthestDescendantDistance),
490 localMetric(metric == NULL),
497 this->metric =
new MetricType();
503 typename StatisticType,
505 typename RootPointPolicy
509 dataset((other.parent == NULL && other.localDataset) ?
510 new MatType(*other.dataset) : other.dataset),
515 numDescendants(other.numDescendants),
516 parent(other.parent),
517 parentDistance(other.parentDistance),
518 furthestDescendantDistance(other.furthestDescendantDistance),
519 localMetric(other.localMetric),
520 localDataset(other.parent == NULL && other.localDataset),
521 metric((other.localMetric ? new MetricType() : other.metric)),
528 children[i]->Parent() =
this;
532 if (parent == NULL && localDataset)
534 std::queue<CoverTree*> queue;
537 queue.push(children[i]);
539 while (!queue.empty())
544 node->dataset = dataset;
546 queue.push(node->children[i]);
554 typename StatisticType,
556 typename RootPointPolicy
572 for (
size_t i = 0; i < children.size(); ++i)
576 dataset = ((other.parent == NULL && other.localDataset) ?
577 new MatType(*other.dataset) : other.dataset);
582 numDescendants = other.numDescendants;
583 parent = other.parent;
584 parentDistance = other.parentDistance;
585 furthestDescendantDistance = other.furthestDescendantDistance;
586 localMetric = other.localMetric;
587 localDataset = (other.parent == NULL && other.localDataset);
588 metric = (other.localMetric ?
new MetricType() : other.metric);
595 children[i]->Parent() =
this;
599 if (parent == NULL && localDataset)
601 std::queue<CoverTree*> queue;
604 queue.push(children[i]);
606 while (!queue.empty())
611 node->dataset = dataset;
613 queue.push(node->children[i]);
623 typename StatisticType,
625 typename RootPointPolicy
629 dataset(other.dataset),
631 children(
std::move(other.children)),
634 stat(
std::move(other.stat)),
635 numDescendants(other.numDescendants),
636 parent(other.parent),
637 parentDistance(other.parentDistance),
638 furthestDescendantDistance(other.furthestDescendantDistance),
639 localMetric(other.localMetric),
640 localDataset(other.localDataset),
641 metric(other.metric),
642 distanceComps(other.distanceComps)
645 for (
size_t i = 0; i < children.size(); ++i)
646 children[i]->
Parent() =
this;
648 other.dataset = NULL;
650 other.scale = INT_MIN;
652 other.numDescendants = 0;
654 other.parentDistance = 0;
655 other.furthestDescendantDistance = 0;
656 other.localMetric =
false;
657 other.localDataset =
false;
664 typename StatisticType,
666 typename RootPointPolicy
682 for (
size_t i = 0; i < children.size(); ++i)
685 dataset = other.dataset;
687 children = std::move(other.children);
690 stat = std::move(other.stat);
691 numDescendants = other.numDescendants;
692 parent = other.parent;
693 parentDistance = other.parentDistance;
694 furthestDescendantDistance = other.furthestDescendantDistance;
695 localMetric = other.localMetric;
696 localDataset = other.localDataset;
697 metric = other.metric;
698 distanceComps = other.distanceComps;
701 for (
size_t i = 0; i < children.size(); ++i)
702 children[i]->
Parent() =
this;
704 other.dataset = NULL;
706 other.scale = INT_MIN;
708 other.numDescendants = 0;
710 other.parentDistance = 0;
711 other.furthestDescendantDistance = 0;
712 other.localMetric =
false;
713 other.localDataset =
false;
722 typename StatisticType,
724 typename RootPointPolicy
726 template<
typename Archive>
729 const typename std::enable_if_t<cereal::is_loading<Archive>()>*) :
733 ar(cereal::make_nvp(
"this", *
this));
739 typename StatisticType,
741 typename RootPointPolicy
746 for (
size_t i = 0; i < children.size(); ++i)
761 typename StatisticType,
763 typename RootPointPolicy
769 return numDescendants;
775 typename StatisticType,
777 typename RootPointPolicy
781 const size_t index)
const 789 return children[0]->Descendant(index);
792 size_t sum = children[0]->NumDescendants();
793 for (
size_t i = 1; i < children.size(); ++i)
797 sum += children[i]->NumDescendants();
801 return (
size_t() - 1);
808 template<
typename MetricType,
809 typename StatisticType,
811 typename RootPointPolicy>
812 template<
typename VecType>
820 ElemType bestDistance = std::numeric_limits<ElemType>::max();
821 size_t bestIndex = 0;
822 for (
size_t i = 0; i < children.size(); ++i)
824 ElemType distance = children[i]->MinDistance(point);
825 if (distance <= bestDistance)
827 bestDistance = distance;
838 template<
typename MetricType,
839 typename StatisticType,
841 typename RootPointPolicy>
842 template<
typename VecType>
851 size_t bestIndex = 0;
852 for (
size_t i = 0; i < children.size(); ++i)
854 ElemType distance = children[i]->MaxDistance(point);
855 if (distance >= bestDistance)
857 bestDistance = distance;
868 template<
typename MetricType,
869 typename StatisticType,
871 typename RootPointPolicy>
878 ElemType bestDistance = std::numeric_limits<ElemType>::max();
879 size_t bestIndex = 0;
880 for (
size_t i = 0; i < children.size(); ++i)
882 ElemType distance = children[i]->MinDistance(queryNode);
883 if (distance <= bestDistance)
885 bestDistance = distance;
896 template<
typename MetricType,
897 typename StatisticType,
899 typename RootPointPolicy>
907 size_t bestIndex = 0;
908 for (
size_t i = 0; i < children.size(); ++i)
910 ElemType distance = children[i]->MaxDistance(queryNode);
911 if (distance >= bestDistance)
913 bestDistance = distance;
922 typename StatisticType,
924 typename RootPointPolicy
926 typename CoverTree<MetricType, StatisticType, MatType,
932 return std::max(metric->Evaluate(dataset->col(point),
939 typename StatisticType,
941 typename RootPointPolicy
943 typename CoverTree<MetricType, StatisticType, MatType,
949 return std::max(distance - furthestDescendantDistance -
955 typename StatisticType,
957 typename RootPointPolicy
959 typename CoverTree<MetricType, StatisticType, MatType,
964 return std::max(metric->Evaluate(dataset->col(point), other) -
965 furthestDescendantDistance, 0.0);
970 typename StatisticType,
972 typename RootPointPolicy
974 typename CoverTree<MetricType, StatisticType, MatType,
979 return std::max(distance - furthestDescendantDistance, 0.0);
984 typename StatisticType,
986 typename RootPointPolicy
988 typename CoverTree<MetricType, StatisticType, MatType,
993 return metric->Evaluate(dataset->col(point),
1000 typename StatisticType,
1002 typename RootPointPolicy
1004 typename CoverTree<MetricType, StatisticType, MatType,
1010 return distance + furthestDescendantDistance +
1015 typename MetricType,
1016 typename StatisticType,
1018 typename RootPointPolicy
1020 typename CoverTree<MetricType, StatisticType, MatType,
1025 return metric->Evaluate(dataset->col(point), other) +
1026 furthestDescendantDistance;
1030 typename MetricType,
1031 typename StatisticType,
1033 typename RootPointPolicy
1035 typename CoverTree<MetricType, StatisticType, MatType,
1040 return distance + furthestDescendantDistance;
1045 typename MetricType,
1046 typename StatisticType,
1048 typename RootPointPolicy
1055 const ElemType distance = metric->Evaluate(dataset->col(point),
1059 result.
Lo() = std::max(distance - furthestDescendantDistance -
1061 result.
Hi() = distance + furthestDescendantDistance +
1070 typename MetricType,
1071 typename StatisticType,
1073 typename RootPointPolicy
1082 result.
Lo() = std::max(distance - furthestDescendantDistance -
1084 result.
Hi() = distance + furthestDescendantDistance +
1092 typename MetricType,
1093 typename StatisticType,
1095 typename RootPointPolicy
1102 const ElemType distance = metric->Evaluate(dataset->col(point), other);
1105 std::max(distance - furthestDescendantDistance, 0.0),
1106 distance + furthestDescendantDistance);
1112 typename MetricType,
1113 typename StatisticType,
1115 typename RootPointPolicy
1124 std::max(distance - furthestDescendantDistance, 0.0),
1125 distance + furthestDescendantDistance);
1130 typename MetricType,
1131 typename StatisticType,
1133 typename RootPointPolicy
1137 arma::Col<size_t>& indices,
1138 arma::vec& distances,
1141 size_t& usedSetSize)
1151 const ElemType maxDistance = max(distances.rows(0,
1152 nearSetSize + farSetSize - 1));
1153 if (maxDistance == 0)
1157 size_t tempSize = 0;
1158 children.push_back(
new CoverTree(*dataset, base, point, INT_MIN,
this, 0,
1159 indices, distances, 0, tempSize, usedSetSize, *metric));
1160 distanceComps += children.back()->DistanceComps();
1163 for (
size_t i = 0; i < nearSetSize; ++i)
1166 children.push_back(
new CoverTree(*dataset, base, indices[i],
1167 INT_MIN,
this, distances[i], indices, distances, 0, tempSize,
1168 usedSetSize, *metric));
1169 distanceComps += children.back()->DistanceComps();
1175 numDescendants = children.size();
1181 SortPointSet(indices, distances, 0, usedSetSize, farSetSize);
1186 const int nextScale = std::min(scale,
1187 (
int) ceil(log(maxDistance) / log(base))) - 1;
1188 const ElemType bound = pow(base, nextScale);
1192 size_t childNearSetSize =
1193 SplitNearFar(indices, distances, bound, nearSetSize);
1196 size_t childFarSetSize = nearSetSize - childNearSetSize;
1197 size_t childUsedSetSize = 0;
1198 children.push_back(
new CoverTree(*dataset, base, point, nextScale,
this, 0,
1199 indices, distances, childNearSetSize, childFarSetSize, childUsedSetSize,
1202 numDescendants += children[0]->NumDescendants();
1206 furthestDescendantDistance = children[0]->FurthestDescendantDistance();
1209 RemoveNewImplicitNodes();
1211 distanceComps += children[0]->DistanceComps();
1220 SortPointSet(indices, distances, childFarSetSize, childUsedSetSize,
1224 nearSetSize -= childUsedSetSize;
1225 usedSetSize += childUsedSetSize;
1231 while (nearSetSize > 0)
1233 size_t newPointIndex = nearSetSize - 1;
1236 if (newPointIndex != 0)
1238 const size_t tempIndex = indices[newPointIndex];
1239 const ElemType tempDist = distances[newPointIndex];
1241 indices[newPointIndex] = indices[0];
1242 distances[newPointIndex] = distances[0];
1244 indices[0] = tempIndex;
1245 distances[0] = tempDist;
1249 if (distances[0] > furthestDescendantDistance)
1250 furthestDescendantDistance = distances[0];
1253 if ((nearSetSize == 1) && (farSetSize == 0))
1255 size_t childNearSetSize = 0;
1256 children.push_back(
new CoverTree(*dataset, base, indices[0], nextScale,
1257 this, distances[0], indices, distances, childNearSetSize, farSetSize,
1258 usedSetSize, *metric));
1259 distanceComps += children.back()->DistanceComps();
1260 numDescendants += children.back()->NumDescendants();
1273 arma::Col<size_t> childIndices(nearSetSize + farSetSize);
1274 childIndices.rows(0, (nearSetSize + farSetSize - 2)) = indices.rows(1,
1275 nearSetSize + farSetSize - 1);
1276 arma::vec childDistances(nearSetSize + farSetSize);
1279 ComputeDistances(indices[0], childIndices, childDistances, nearSetSize
1283 childNearSetSize = SplitNearFar(childIndices, childDistances, bound,
1284 nearSetSize + farSetSize - 1);
1285 childFarSetSize = PruneFarSet(childIndices, childDistances,
1286 base * bound, childNearSetSize,
1287 (nearSetSize + farSetSize - 1));
1293 childIndices(childNearSetSize + childFarSetSize) = indices[0];
1294 childDistances(childNearSetSize + childFarSetSize) = 0;
1297 childUsedSetSize = 1;
1298 children.push_back(
new CoverTree(*dataset, base, indices[0], nextScale,
1299 this, distances[0], childIndices, childDistances, childNearSetSize,
1300 childFarSetSize, childUsedSetSize, *metric));
1301 numDescendants += children.back()->NumDescendants();
1304 RemoveNewImplicitNodes();
1306 distanceComps += children.back()->DistanceComps();
1313 MoveToUsedSet(indices, distances, nearSetSize, farSetSize, usedSetSize,
1314 childIndices, childFarSetSize, childUsedSetSize);
1318 for (
size_t i = (nearSetSize + farSetSize); i < (nearSetSize + farSetSize +
1320 if (distances[i] > furthestDescendantDistance)
1321 furthestDescendantDistance = distances[i];
1325 typename MetricType,
1326 typename StatisticType,
1328 typename RootPointPolicy
1332 arma::vec& distances,
1334 const size_t pointSetSize)
1338 if (pointSetSize <= 1)
1343 size_t right = pointSetSize - 1;
1348 while ((distances[left] <= bound) && (left != right))
1350 while ((distances[right] > bound) && (left != right))
1353 while (left != right)
1356 const size_t tempPoint = indices[left];
1357 const ElemType tempDist = distances[left];
1359 indices[left] = indices[right];
1360 distances[left] = distances[right];
1362 indices[right] = tempPoint;
1363 distances[right] = tempDist;
1367 while ((distances[left] <= bound) && (left != right))
1373 while ((distances[right] > bound) && (left != right))
1383 typename MetricType,
1384 typename StatisticType,
1386 typename RootPointPolicy
1390 const arma::Col<size_t>& indices,
1391 arma::vec& distances,
1392 const size_t pointSetSize)
1396 distanceComps += pointSetSize;
1397 for (
size_t i = 0; i < pointSetSize; ++i)
1399 distances[i] = metric->Evaluate(dataset->col(pointIndex),
1400 dataset->col(indices[i]));
1405 typename MetricType,
1406 typename StatisticType,
1408 typename RootPointPolicy
1412 arma::vec& distances,
1413 const size_t childFarSetSize,
1414 const size_t childUsedSetSize,
1415 const size_t farSetSize)
1420 const size_t bufferSize = std::min(farSetSize, childUsedSetSize);
1421 const size_t bigCopySize = std::max(farSetSize, childUsedSetSize);
1425 if (bufferSize == 0)
1426 return (childFarSetSize + farSetSize);
1428 size_t* indicesBuffer =
new size_t[bufferSize];
1432 const size_t bufferFromLocation = ((bufferSize == farSetSize) ?
1433 (childFarSetSize + childUsedSetSize) : childFarSetSize);
1435 const size_t directFromLocation = ((bufferSize == farSetSize) ?
1436 childFarSetSize : (childFarSetSize + childUsedSetSize));
1438 const size_t bufferToLocation = ((bufferSize == farSetSize) ?
1439 childFarSetSize : (childFarSetSize + farSetSize));
1441 const size_t directToLocation = ((bufferSize == farSetSize) ?
1442 (childFarSetSize + farSetSize) : childFarSetSize);
1445 memcpy(indicesBuffer, indices.memptr() + bufferFromLocation,
1446 sizeof(size_t) * bufferSize);
1447 memcpy(distancesBuffer, distances.memptr() + bufferFromLocation,
1451 memmove(indices.memptr() + directToLocation,
1452 indices.memptr() + directFromLocation,
sizeof(size_t) * bigCopySize);
1453 memmove(distances.memptr() + directToLocation,
1454 distances.memptr() + directFromLocation,
sizeof(
ElemType) * bigCopySize);
1457 memcpy(indices.memptr() + bufferToLocation, indicesBuffer,
1458 sizeof(size_t) * bufferSize);
1459 memcpy(distances.memptr() + bufferToLocation, distancesBuffer,
1462 delete[] indicesBuffer;
1463 delete[] distancesBuffer;
1466 return (childFarSetSize + farSetSize);
1470 typename MetricType,
1471 typename StatisticType,
1473 typename RootPointPolicy
1477 arma::vec& distances,
1478 size_t& nearSetSize,
1480 size_t& usedSetSize,
1481 arma::Col<size_t>& childIndices,
1482 const size_t childFarSetSize,
1483 const size_t childUsedSetSize)
1485 const size_t originalSum = nearSetSize + farSetSize + usedSetSize;
1490 size_t startChildUsedSet = 0;
1491 for (
size_t i = 0; i < nearSetSize; ++i)
1494 for (
size_t j = startChildUsedSet; j < childUsedSetSize; ++j)
1496 if (childIndices[childFarSetSize + j] == indices[i])
1504 if ((nearSetSize - 1) != i)
1507 size_t tempIndex = indices[nearSetSize + farSetSize - 1];
1508 ElemType tempDist = distances[nearSetSize + farSetSize - 1];
1510 size_t tempNearIndex = indices[nearSetSize - 1];
1511 ElemType tempNearDist = distances[nearSetSize - 1];
1513 indices[nearSetSize + farSetSize - 1] = indices[i];
1514 distances[nearSetSize + farSetSize - 1] = distances[i];
1516 indices[nearSetSize - 1] = tempIndex;
1517 distances[nearSetSize - 1] = tempDist;
1519 indices[i] = tempNearIndex;
1520 distances[i] = tempNearDist;
1525 size_t tempIndex = indices[nearSetSize + farSetSize - 1];
1526 ElemType tempDist = distances[nearSetSize + farSetSize - 1];
1528 indices[nearSetSize + farSetSize - 1] = indices[i];
1529 distances[nearSetSize + farSetSize - 1] = distances[i];
1531 indices[i] = tempIndex;
1532 distances[i] = tempDist;
1535 else if ((nearSetSize - 1) != i)
1538 size_t tempIndex = indices[nearSetSize + farSetSize - 1];
1539 ElemType tempDist = distances[nearSetSize + farSetSize - 1];
1541 indices[nearSetSize + farSetSize - 1] = indices[i];
1542 distances[nearSetSize + farSetSize - 1] = distances[i];
1544 indices[i] = tempIndex;
1545 distances[i] = tempDist;
1555 if (j != startChildUsedSet)
1557 childIndices[childFarSetSize + j] = childIndices[childFarSetSize +
1562 ++startChildUsedSet;
1574 for (
size_t i = 0; i < farSetSize; ++i)
1577 for (
size_t j = startChildUsedSet; j < childUsedSetSize; ++j)
1579 if (childIndices[childFarSetSize + j] == indices[i + nearSetSize])
1584 size_t tempIndex = indices[nearSetSize + farSetSize - 1];
1585 ElemType tempDist = distances[nearSetSize + farSetSize - 1];
1587 indices[nearSetSize + farSetSize - 1] = indices[nearSetSize + i];
1588 distances[nearSetSize + farSetSize - 1] = distances[nearSetSize + i];
1590 indices[nearSetSize + i] = tempIndex;
1591 distances[nearSetSize + i] = tempDist;
1593 if (j != startChildUsedSet)
1595 childIndices[childFarSetSize + j] = childIndices[childFarSetSize +
1600 ++startChildUsedSet;
1610 usedSetSize += childUsedSetSize;
1612 Log::Assert(originalSum == (nearSetSize + farSetSize + usedSetSize));
1616 typename MetricType,
1617 typename StatisticType,
1619 typename RootPointPolicy
1623 arma::vec& distances,
1625 const size_t nearSetSize,
1626 const size_t pointSetSize)
1631 size_t left = nearSetSize;
1632 size_t right = pointSetSize - 1;
1633 while ((distances[left] <= bound) && (left != right))
1635 while ((distances[right] > bound) && (left != right))
1638 while (left != right)
1641 indices[left] = indices[right];
1642 distances[left] = distances[right];
1646 while ((distances[left] <= bound) && (left != right))
1648 while ((distances[right] > bound) && (left != right))
1654 return (left - nearSetSize);
1662 typename MetricType,
1663 typename StatisticType,
1665 typename RootPointPolicy
1672 while (children[children.size() - 1]->NumChildren() == 1)
1674 CoverTree* old = children[children.size() - 1];
1675 children.erase(children.begin() + children.size() - 1);
1678 children.push_back(&(old->
Child(0)));
1683 old->
Child(0).DistanceComps() = old->DistanceComps();
1697 typename MetricType,
1698 typename StatisticType,
1700 typename RootPointPolicy
1709 parentDistance(0.0),
1710 furthestDescendantDistance(0.0),
1712 localDataset(false),
1723 typename MetricType,
1724 typename StatisticType,
1726 typename RootPointPolicy
1728 template<
typename Archive>
1735 if (cereal::is_loading<Archive>())
1737 for (
size_t i = 0; i < children.size(); ++i)
1740 if (localMetric && metric)
1742 if (localDataset && dataset)
1748 bool hasParent = (parent != NULL);
1749 ar(CEREAL_NVP(hasParent));
1750 MatType*& datasetTemp =
const_cast<MatType*&
>(dataset);
1754 ar(CEREAL_NVP(point));
1755 ar(CEREAL_NVP(scale));
1756 ar(CEREAL_NVP(base));
1757 ar(CEREAL_NVP(stat));
1758 ar(CEREAL_NVP(numDescendants));
1759 ar(CEREAL_NVP(parentDistance));
1760 ar(CEREAL_NVP(furthestDescendantDistance));
1763 if (cereal::is_loading<Archive>() && !hasParent)
1766 localDataset =
true;
1772 if (cereal::is_loading<Archive>())
1775 for (
size_t i = 0; i < children.size(); ++i)
1777 children[i]->localMetric =
false;
1778 children[i]->localDataset =
false;
1779 children[i]->Parent() =
this;
1785 std::stack<CoverTree*> stack;
1786 for (
size_t i = 0; i < children.size(); ++i)
1788 stack.push(children[i]);
1790 while (!stack.empty())
1794 node->dataset = dataset;
1795 for (
size_t i = 0; i < node->children.size(); ++i)
1797 stack.push(node->children[i]);
T Lo() const
Get the lower bound.
Definition: range.hpp:61
CoverTree()
A default constructor.
Definition: cover_tree_impl.hpp:1702
ElemType ParentDistance() const
Get the distance to the parent.
Definition: cover_tree.hpp:409
int Scale() const
Get the scale of this node.
Definition: cover_tree.hpp:315
void serialize(Archive &ar, const uint32_t)
Serialize the tree.
Definition: cover_tree_impl.hpp:1729
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
CoverTree & operator=(const CoverTree &other)
Copy the given Cover Tree.
Definition: cover_tree_impl.hpp:560
Definition: pointer_wrapper.hpp:23
MatType::elem_type ElemType
The type held by the matrix type.
Definition: cover_tree.hpp:105
size_t GetNearestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the nearest child node to the given query point.
Definition: cover_tree_impl.hpp:814
size_t NumDescendants() const
Get the number of descendant points.
Definition: cover_tree_impl.hpp:767
const std::vector< CoverTree * > & Children() const
Get the children.
Definition: cover_tree.hpp:304
const MatType & Dataset() const
Get a reference to the dataset.
Definition: cover_tree.hpp:283
ElemType FurthestDescendantDistance() const
Get the distance from the center of the node to the furthest descendant.
Definition: cover_tree.hpp:417
Simple real-valued range.
Definition: range.hpp:19
size_t GetFurthestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the furthest child node to the given query point.
Definition: cover_tree_impl.hpp:844
#define CEREAL_VECTOR_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_vector_wrapper.hpp:93
size_t Point() const
Get the index of the point which this node represents.
Definition: cover_tree.hpp:286
const CoverTree & Child(const size_t index) const
Get a particular child node.
Definition: cover_tree.hpp:294
T Hi() const
Get the upper bound.
Definition: range.hpp:66
ElemType MaxDistance(const CoverTree &other) const
Return the maximum distance to another node.
Definition: cover_tree_impl.hpp:991
ElemType MinDistance(const CoverTree &other) const
Return the minimum distance to another node.
Definition: cover_tree_impl.hpp:929
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_wrapper.hpp:96
math::RangeType< ElemType > RangeDistance(const CoverTree &other) const
Return the minimum and maximum distance to another node.
Definition: cover_tree_impl.hpp:1053
size_t NumChildren() const
Get the number of children.
Definition: cover_tree.hpp:301
A cover tree is a tree specifically designed to speed up nearest-neighbor computation in high-dimensional spaces.
Definition: cover_tree.hpp:99
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35
size_t Descendant(const size_t index) const
Get the index of a particular descendant point.
Definition: cover_tree_impl.hpp:780
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38
~CoverTree()
Delete this cover tree node and its children.
Definition: cover_tree_impl.hpp:743
CoverTree * Parent() const
Get the parent node.
Definition: cover_tree.hpp:404