13 #ifndef MLPACK_CORE_TREE_CELLBOUND_IMPL_HPP 14 #define MLPACK_CORE_TREE_CELLBOUND_IMPL_HPP 27 template<
typename MetricType,
typename ElemType>
28 inline CellBound<MetricType, ElemType>::CellBound() :
31 loBound(
arma::Mat<ElemType>()),
32 hiBound(
arma::Mat<ElemType>()),
34 loAddress(
arma::Col<AddressElemType>()),
35 hiAddress(
arma::Col<AddressElemType>()),
43 template<
typename MetricType,
typename ElemType>
44 inline CellBound<MetricType, ElemType>::CellBound(
const size_t dimension) :
46 bounds(new math::RangeType<ElemType>[dim]),
47 loBound(
arma::Mat<ElemType>(dim, maxNumBounds)),
48 hiBound(
arma::Mat<ElemType>(dim, maxNumBounds)),
54 for (
size_t k = 0; k < dim ; ++k)
56 loAddress[k] = std::numeric_limits<AddressElemType>::max();
64 template<
typename MetricType,
typename ElemType>
65 inline CellBound<MetricType, ElemType>::CellBound(
66 const CellBound<MetricType, ElemType>& other) :
68 bounds(new math::RangeType<ElemType>[dim]),
69 loBound(other.loBound),
70 hiBound(other.hiBound),
71 numBounds(other.numBounds),
72 loAddress(other.loAddress),
73 hiAddress(other.hiAddress),
74 minWidth(other.MinWidth())
77 for (
size_t i = 0; i < dim; ++i)
78 bounds[i] = other.bounds[i];
84 template<
typename MetricType,
typename ElemType>
87 ElemType>& CellBound<MetricType, ElemType>::operator=(
88 const CellBound<MetricType, ElemType>& other)
93 if (dim != other.Dim())
99 bounds =
new math::RangeType<ElemType>[dim];
102 loBound = other.loBound;
103 hiBound = other.hiBound;
104 numBounds = other.numBounds;
105 loAddress = other.loAddress;
106 hiAddress = other.hiAddress;
109 for (
size_t i = 0; i < dim; ++i)
110 bounds[i] = other.bounds[i];
112 minWidth = other.MinWidth();
120 template<
typename MetricType,
typename ElemType>
121 inline CellBound<MetricType, ElemType>::CellBound(
122 CellBound<MetricType, ElemType>&& other) :
124 bounds(other.bounds),
125 loBound(
std::move(other.loBound)),
126 hiBound(
std::move(other.hiBound)),
127 numBounds(
std::move(other.numBounds)),
128 loAddress(
std::move(other.loAddress)),
129 hiAddress(
std::move(other.hiAddress)),
130 minWidth(other.minWidth)
135 other.minWidth = 0.0;
141 template<
typename MetricType,
typename ElemType>
142 inline CellBound<MetricType, ElemType>::~CellBound()
151 template<
typename MetricType,
typename ElemType>
152 inline void CellBound<MetricType, ElemType>::Clear()
154 for (
size_t k = 0; k < dim; ++k)
156 bounds[k] = math::RangeType<ElemType>();
158 loAddress[k] = std::numeric_limits<AddressElemType>::max();
170 template<
typename MetricType,
typename ElemType>
172 arma::Col<ElemType>& center)
const 175 if (!(center.n_elem == dim))
176 center.set_size(dim);
178 for (
size_t i = 0; i < dim; ++i)
179 center(i) = bounds[i].Mid();
182 template<
typename MetricType,
typename ElemType>
183 template<
typename MatType>
184 void CellBound<MetricType, ElemType>::AddBound(
185 const arma::Col<ElemType>& loCorner,
186 const arma::Col<ElemType>& hiCorner,
189 assert(numBounds < loBound.n_cols);
190 assert(loBound.n_rows == dim);
191 assert(loCorner.n_elem == dim);
192 assert(hiCorner.n_elem == dim);
194 for (
size_t k = 0; k < dim; ++k)
196 loBound(k, numBounds) = std::numeric_limits<ElemType>::max();
197 hiBound(k, numBounds) = std::numeric_limits<ElemType>::lowest();
200 for (
size_t i = 0; i < data.n_cols; ++i)
204 for (k = 0; k < dim; ++k)
205 if (data(k, i) < loCorner[k] || data(k, i) > hiCorner[k])
212 for (k = 0; k < dim; ++k)
214 loBound(k, numBounds) = std::min(loBound(k, numBounds), data(k, i));
215 hiBound(k, numBounds) = std::max(hiBound(k, numBounds), data(k, i));
219 for (
size_t k = 0; k < dim; ++k)
220 if (loBound(k, numBounds) > hiBound(k, numBounds))
227 template<
typename MetricType,
typename ElemType>
228 template<
typename MatType>
229 void CellBound<MetricType, ElemType>::InitHighBound(
size_t numEqualBits,
232 arma::Col<AddressElemType> tmpHiAddress(hiAddress);
233 arma::Col<AddressElemType> tmpLoAddress(hiAddress);
234 arma::Col<ElemType> loCorner(tmpHiAddress.n_elem);
235 arma::Col<ElemType> hiCorner(tmpHiAddress.n_elem);
237 assert(tmpHiAddress.n_elem > 0);
241 size_t numCorners = 0;
242 for (
size_t pos = numEqualBits + 1; pos < order * tmpHiAddress.n_elem; pos++)
244 size_t row = pos / order;
245 size_t bit = order - 1 - pos % order;
249 if (tmpHiAddress[row] & ((AddressElemType) 1 << bit))
254 if (numCorners >= maxNumBounds / 2)
255 tmpHiAddress[row] |= ((AddressElemType) 1 << bit);
258 size_t pos = order * tmpHiAddress.n_elem - 1;
261 for ( ; pos > numEqualBits; pos--)
263 size_t row = pos / order;
264 size_t bit = order - 1 - pos % order;
270 if (!(tmpHiAddress[row] & ((AddressElemType) 1 << bit)))
272 addr::AddressToPoint(loCorner, tmpLoAddress);
273 addr::AddressToPoint(hiCorner, tmpHiAddress);
275 AddBound(loCorner, hiCorner, data);
279 tmpLoAddress[row] &= ~((AddressElemType) 1 << bit);
283 if (pos == numEqualBits)
285 addr::AddressToPoint(loCorner, tmpLoAddress);
286 addr::AddressToPoint(hiCorner, tmpHiAddress);
288 AddBound(loCorner, hiCorner, data);
291 for ( ; pos > numEqualBits; pos--)
293 size_t row = pos / order;
294 size_t bit = order - 1 - pos % order;
297 tmpLoAddress[row] &= ~((AddressElemType) 1 << bit);
299 if (tmpHiAddress[row] & ((AddressElemType) 1 << bit))
305 tmpHiAddress[row] ^= (AddressElemType) 1 << bit;
306 addr::AddressToPoint(loCorner, tmpLoAddress);
307 addr::AddressToPoint(hiCorner, tmpHiAddress);
309 AddBound(loCorner, hiCorner, data);
312 tmpHiAddress[row] |= ((AddressElemType) 1 << bit);
316 template<
typename MetricType,
typename ElemType>
317 template<
typename MatType>
318 void CellBound<MetricType, ElemType>::InitLowerBound(
size_t numEqualBits,
321 arma::Col<AddressElemType> tmpHiAddress(loAddress);
322 arma::Col<AddressElemType> tmpLoAddress(loAddress);
323 arma::Col<ElemType> loCorner(tmpHiAddress.n_elem);
324 arma::Col<ElemType> hiCorner(tmpHiAddress.n_elem);
328 size_t numCorners = 0;
329 for (
size_t pos = numEqualBits + 1; pos < order * tmpHiAddress.n_elem; pos++)
331 size_t row = pos / order;
332 size_t bit = order - 1 - pos % order;
336 if (!(tmpLoAddress[row] & ((AddressElemType) 1 << bit)))
341 if (numCorners >= maxNumBounds - numBounds)
342 tmpLoAddress[row] &= ~((AddressElemType) 1 << bit);
345 size_t pos = order * tmpHiAddress.n_elem - 1;
348 for ( ; pos > numEqualBits; pos--)
350 size_t row = pos / order;
351 size_t bit = order - 1 - pos % order;
357 if (tmpLoAddress[row] & ((AddressElemType) 1 << bit))
359 addr::AddressToPoint(loCorner, tmpLoAddress);
360 addr::AddressToPoint(hiCorner, tmpHiAddress);
362 AddBound(loCorner, hiCorner, data);
367 tmpHiAddress[row] |= ((AddressElemType) 1 << bit);
371 if (pos == numEqualBits)
373 addr::AddressToPoint(loCorner, tmpLoAddress);
374 addr::AddressToPoint(hiCorner, tmpHiAddress);
376 AddBound(loCorner, hiCorner, data);
379 for ( ; pos > numEqualBits; pos--)
381 size_t row = pos / order;
382 size_t bit = order - 1 - pos % order;
385 tmpHiAddress[row] |= ((AddressElemType) 1 << bit);
387 if (!(tmpLoAddress[row] & ((AddressElemType) 1 << bit)))
393 tmpLoAddress[row] ^= (AddressElemType) 1 << bit;
395 addr::AddressToPoint(loCorner, tmpLoAddress);
396 addr::AddressToPoint(hiCorner, tmpHiAddress);
398 AddBound(loCorner, hiCorner, data);
402 tmpLoAddress[row] &= ~((AddressElemType) 1 << bit);
406 template<
typename MetricType,
typename ElemType>
407 template<
typename MatType>
408 void CellBound<MetricType, ElemType>::UpdateAddressBounds(
const MatType& data)
415 for ( ; row < hiAddress.n_elem; row++)
416 if (loAddress[row] != hiAddress[row])
420 if (row == hiAddress.n_elem)
422 for (
size_t i = 0; i < dim; ++i)
424 loBound(i, 0) = bounds[i].Lo();
425 hiBound(i, 0) = bounds[i].Hi();
433 for ( ; bit < order; bit++)
434 if ((loAddress[row] & ((AddressElemType) 1 << (order - 1 - bit))) !=
435 (hiAddress[row] & ((AddressElemType) 1 << (order - 1 - bit))))
438 if ((row == hiAddress.n_elem - 1) && (bit == order - 1))
441 for (
size_t i = 0; i < dim; ++i)
443 loBound(i, 0) = bounds[i].Lo();
444 hiBound(i, 0) = bounds[i].Hi();
452 size_t numEqualBits = row * order + bit;
453 InitHighBound(numEqualBits, data);
454 InitLowerBound(numEqualBits, data);
456 assert(numBounds <= maxNumBounds);
461 for (
size_t i = 0; i < dim; ++i)
463 loBound(i, 0) = bounds[i].Lo();
464 hiBound(i, 0) = bounds[i].Hi();
474 template<
typename MetricType,
typename ElemType>
475 template<
typename VecType>
476 inline ElemType CellBound<MetricType, ElemType>::MinDistance(
477 const VecType& point,
482 ElemType minSum = std::numeric_limits<ElemType>::max();
484 ElemType lower, higher;
486 for (
size_t i = 0; i < numBounds; ++i)
490 for (
size_t d = 0; d < dim; d++)
492 lower = loBound(d, i) - point[d];
493 higher = point[d] - hiBound(d, i);
499 if (MetricType::Power == 1)
500 sum += lower + std::fabs(lower) + higher + std::fabs(higher);
501 else if (MetricType::Power == 2)
503 ElemType dist = lower + std::fabs(lower) + higher + std::fabs(higher);
508 sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
509 (ElemType) MetricType::Power);
524 if (MetricType::Power == 1)
526 else if (MetricType::Power == 2)
528 if (MetricType::TakeRoot)
529 return (ElemType) std::sqrt(minSum) * 0.5;
531 return minSum * 0.25;
535 if (MetricType::TakeRoot)
536 return (ElemType) pow((
double) minSum,
537 1.0 / (
double) MetricType::Power) / 2.0;
539 return minSum / pow(2.0, MetricType::Power);
546 template<
typename MetricType,
typename ElemType>
547 ElemType CellBound<MetricType, ElemType>::MinDistance(
const CellBound& other)
552 ElemType minSum = std::numeric_limits<ElemType>::max();
554 ElemType lower, higher;
556 for (
size_t i = 0; i < numBounds; ++i)
557 for (
size_t j = 0; j < other.numBounds; ++j)
560 for (
size_t d = 0; d < dim; d++)
562 lower = other.loBound(d, j) - hiBound(d, i);
563 higher = loBound(d, i) - other.hiBound(d, j);
569 if (MetricType::Power == 1)
570 sum += (lower + std::fabs(lower)) + (higher + std::fabs(higher));
571 else if (MetricType::Power == 2)
573 ElemType dist = lower + std::fabs(lower) + higher + std::fabs(higher);
578 sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
579 (ElemType) MetricType::Power);
591 if (MetricType::Power == 1)
593 else if (MetricType::Power == 2)
595 if (MetricType::TakeRoot)
596 return (ElemType) std::sqrt(minSum) * 0.5;
598 return minSum * 0.25;
602 if (MetricType::TakeRoot)
603 return (ElemType) pow((
double) minSum,
604 1.0 / (
double) MetricType::Power) / 2.0;
606 return minSum / pow(2.0, MetricType::Power);
613 template<
typename MetricType,
typename ElemType>
614 template<
typename VecType>
615 inline ElemType CellBound<MetricType, ElemType>::MaxDistance(
616 const VecType& point,
619 ElemType maxSum = std::numeric_limits<ElemType>::lowest();
623 for (
size_t i = 0; i < numBounds; ++i)
626 for (
size_t d = 0; d < dim; d++)
628 ElemType v = std::max(fabs(point[d] - loBound(d, i)),
629 fabs(hiBound(d, i) - point[d]));
631 if (MetricType::Power == 1)
633 else if (MetricType::Power == 2)
636 sum += std::pow(v, (ElemType) MetricType::Power);
644 if (MetricType::TakeRoot)
646 if (MetricType::Power == 1)
648 else if (MetricType::Power == 2)
649 return (ElemType) std::sqrt(maxSum);
651 return (ElemType) pow((
double) maxSum, 1.0 / (
double) MetricType::Power);
660 template<
typename MetricType,
typename ElemType>
661 inline ElemType CellBound<MetricType, ElemType>::MaxDistance(
662 const CellBound& other)
665 ElemType maxSum = std::numeric_limits<ElemType>::lowest();
670 for (
size_t i = 0; i < numBounds; ++i)
671 for (
size_t j = 0; j < other.numBounds; ++j)
674 for (
size_t d = 0; d < dim; d++)
676 v = std::max(fabs(other.hiBound(d, j) - loBound(d, i)),
677 fabs(hiBound(d, i) - other.loBound(d, j)));
680 if (MetricType::Power == 1)
682 else if (MetricType::Power == 2)
685 sum += std::pow(v, (ElemType) MetricType::Power);
693 if (MetricType::TakeRoot)
695 if (MetricType::Power == 1)
697 else if (MetricType::Power == 2)
698 return (ElemType) std::sqrt(maxSum);
700 return (ElemType) pow((
double) maxSum, 1.0 / (
double) MetricType::Power);
709 template<
typename MetricType,
typename ElemType>
710 inline math::RangeType<ElemType>
711 CellBound<MetricType, ElemType>::RangeDistance(
712 const CellBound& other)
const 714 ElemType minLoSum = std::numeric_limits<ElemType>::max();
715 ElemType maxHiSum = std::numeric_limits<ElemType>::lowest();
719 ElemType v1, v2, vLo, vHi;
721 for (
size_t i = 0; i < numBounds; ++i)
722 for (
size_t j = 0; j < other.numBounds; ++j)
726 for (
size_t d = 0; d < dim; d++)
728 v1 = other.loBound(d, j) - hiBound(d, i);
729 v2 = loBound(d, i) - other.hiBound(d, j);
734 vLo = (v1 > 0) ? v1 : 0;
739 vLo = (v2 > 0) ? v2 : 0;
743 if (MetricType::Power == 1)
748 else if (MetricType::Power == 2)
755 loSum += std::pow(vLo, (ElemType) MetricType::Power);
756 hiSum += std::pow(vHi, (ElemType) MetricType::Power);
760 if (loSum < minLoSum)
762 if (hiSum > maxHiSum)
766 if (MetricType::TakeRoot)
768 if (MetricType::Power == 1)
769 return math::RangeType<ElemType>(minLoSum, maxHiSum);
770 else if (MetricType::Power == 2)
771 return math::RangeType<ElemType>((ElemType) std::sqrt(minLoSum),
772 (ElemType) std::sqrt(maxHiSum));
775 return math::RangeType<ElemType>(
776 (ElemType) pow((
double) minLoSum, 1.0 / (double) MetricType::Power),
777 (ElemType) pow((
double) maxHiSum, 1.0 / (double) MetricType::Power));
781 return math::RangeType<ElemType>(minLoSum, maxHiSum);
787 template<
typename MetricType,
typename ElemType>
788 template<
typename VecType>
789 inline math::RangeType<ElemType>
790 CellBound<MetricType, ElemType>::RangeDistance(
791 const VecType& point,
794 ElemType minLoSum = std::numeric_limits<ElemType>::max();
795 ElemType maxHiSum = std::numeric_limits<ElemType>::lowest();
799 ElemType v1, v2, vLo, vHi;
800 for (
size_t i = 0; i < numBounds; ++i)
804 for (
size_t d = 0; d < dim; d++)
806 v1 = loBound(d, i) - point[d];
807 v2 = point[d] - hiBound(d, i);
824 vHi = -std::min(v1, v2);
830 if (MetricType::Power == 1)
835 else if (MetricType::Power == 2)
842 loSum += std::pow(vLo, (ElemType) MetricType::Power);
843 hiSum += std::pow(vHi, (ElemType) MetricType::Power);
846 if (loSum < minLoSum)
848 if (hiSum > maxHiSum)
852 if (MetricType::TakeRoot)
854 if (MetricType::Power == 1)
855 return math::RangeType<ElemType>(minLoSum, maxHiSum);
856 else if (MetricType::Power == 2)
857 return math::RangeType<ElemType>((ElemType) std::sqrt(minLoSum),
858 (ElemType) std::sqrt(maxHiSum));
861 return math::RangeType<ElemType>(
862 (ElemType) pow((
double) minLoSum, 1.0 / (double) MetricType::Power),
863 (ElemType) pow((
double) maxHiSum, 1.0 / (double) MetricType::Power));
867 return math::RangeType<ElemType>(minLoSum, maxHiSum);
873 template<
typename MetricType,
typename ElemType>
874 template<
typename MatType>
875 inline CellBound<MetricType, ElemType>&
876 CellBound<MetricType, ElemType>::operator|=(
const MatType& data)
880 arma::Col<ElemType> mins(arma::min(data, 1));
881 arma::Col<ElemType> maxs(arma::max(data, 1));
883 minWidth = std::numeric_limits<ElemType>::max();
884 for (
size_t i = 0; i < dim; ++i)
886 bounds[i] |= math::RangeType<ElemType>(mins[i], maxs[i]);
887 const ElemType width = bounds[i].Width();
888 if (width < minWidth)
891 loBound(i, 0) = bounds[i].Lo();
892 hiBound(i, 0) = bounds[i].Hi();
903 template<
typename MetricType,
typename ElemType>
904 inline CellBound<MetricType, ElemType>&
905 CellBound<MetricType, ElemType>::operator|=(
const CellBound& other)
907 assert(other.dim == dim);
909 minWidth = std::numeric_limits<ElemType>::max();
910 for (
size_t i = 0; i < dim; ++i)
912 bounds[i] |= other.bounds[i];
913 const ElemType width = bounds[i].Width();
914 if (width < minWidth)
918 if (addr::CompareAddresses(other.loAddress, loAddress) < 0)
919 loAddress = other.loAddress;
921 if (addr::CompareAddresses(other.hiAddress, hiAddress) > 0)
922 hiAddress = other.hiAddress;
924 if (loAddress[0] > hiAddress[0])
926 for (
size_t i = 0; i < dim; ++i)
928 loBound(i, 0) = bounds[i].Lo();
929 hiBound(i, 0) = bounds[i].Hi();
941 template<
typename MetricType,
typename ElemType>
942 template<
typename VecType>
943 inline bool CellBound<MetricType, ElemType>::Contains(
944 const VecType& point)
const 946 for (
size_t i = 0; i < point.n_elem; ++i)
952 if (loAddress[0] > hiAddress[0])
955 arma::Col<AddressElemType> address(dim);
957 addr::PointToAddress(address, point);
959 return addr::Contains(address, loAddress, hiAddress);
966 template<
typename MetricType,
typename ElemType>
967 inline ElemType CellBound<MetricType, ElemType>::Diameter()
const 970 for (
size_t i = 0; i < dim; ++i)
971 d += std::pow(bounds[i].Hi() - bounds[i].Lo(),
972 (ElemType) MetricType::Power);
974 if (MetricType::TakeRoot)
975 return (ElemType) std::pow((
double) d, 1.0 / (
double) MetricType::Power);
981 template<
typename MetricType,
typename ElemType>
982 template<
typename Archive>
983 void CellBound<MetricType, ElemType>::serialize(
988 ar(CEREAL_NVP(minWidth));
989 ar(CEREAL_NVP(loBound));
990 ar(CEREAL_NVP(hiBound));
991 ar(CEREAL_NVP(numBounds));
992 ar(CEREAL_NVP(loAddress));
993 ar(CEREAL_NVP(hiAddress));
994 ar(CEREAL_NVP(metric));
1000 #endif // MLPACK_CORE_TREE_CELLBOUND_IMPL_HPP
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: pointer_wrapper.hpp:23
void Center(const arma::mat &x, arma::mat &xCentered)
Creates a centered matrix, where centering is done by subtracting the sum over the columns (a column ...
Definition: lin_alg.cpp:43
#define CEREAL_POINTER_ARRAY(T, S)
Cereal does not support serialization of raw pointers.
Definition: array_wrapper.hpp:87
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38
bool Contains(const AddressType1 &address, const AddressType2 &loBound, const AddressType3 &hiBound)
Returns true if an address is contained between two other addresses.
Definition: address.hpp:256