12 #ifndef MLPACK_CORE_TREE_HRECTBOUND_IMPL_HPP 13 #define MLPACK_CORE_TREE_HRECTBOUND_IMPL_HPP 26 template<
typename MetricType,
typename ElemType>
37 template<
typename MetricType,
typename ElemType>
40 bounds(new math::RangeType<ElemType>[dim]),
47 template<
typename MetricType,
typename ElemType>
51 bounds(new math::RangeType<ElemType>[dim]),
52 minWidth(other.MinWidth())
55 for (
size_t i = 0; i < dim; ++i)
62 template<
typename MetricType,
typename ElemType>
71 if (dim != other.Dim())
82 for (
size_t i = 0; i < dim; ++i)
85 minWidth = other.MinWidth();
93 template<
typename MetricType,
typename ElemType>
98 minWidth(other.minWidth)
103 other.minWidth = 0.0;
109 template<
typename MetricType,
typename ElemType>
116 bounds = other.bounds;
117 minWidth = other.minWidth;
120 other.bounds =
nullptr;
121 other.minWidth = 0.0;
129 template<
typename MetricType,
typename ElemType>
139 template<
typename MetricType,
typename ElemType>
142 for (
size_t i = 0; i < dim; ++i)
152 template<
typename MetricType,
typename ElemType>
154 arma::Col<ElemType>& center)
const 157 if (!(center.n_elem == dim))
158 center.set_size(dim);
160 for (
size_t i = 0; i < dim; ++i)
161 center(i) = bounds[i].Mid();
169 template<
typename MetricType,
typename ElemType>
172 ElemType volume = 1.0;
173 for (
size_t i = 0; i < dim; ++i)
175 if (bounds[i].Lo() >= bounds[i].Hi())
178 volume *= (bounds[i].Hi() - bounds[i].Lo());
187 template<
typename MetricType,
typename ElemType>
188 template<
typename VecType>
190 const VecType& point,
193 Log::Assert(point.n_elem == dim);
197 ElemType lower, higher;
198 for (
size_t d = 0; d < dim; d++)
200 lower = bounds[d].Lo() - point[d];
201 higher = point[d] - bounds[d].Hi();
206 if (MetricType::Power == 1)
207 sum += (lower + std::fabs(lower)) + (higher + std::fabs(higher));
208 else if (MetricType::Power == 2)
210 ElemType dist = (lower + std::fabs(lower)) + (higher + std::fabs(higher));
215 sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
216 (ElemType) MetricType::Power);
224 if (MetricType::Power == 1)
226 else if (MetricType::Power == 2)
228 if (MetricType::TakeRoot)
229 return (ElemType) std::sqrt(sum) * 0.5;
235 if (MetricType::TakeRoot)
236 return (ElemType) pow((
double) sum,
237 1.0 / (
double) MetricType::Power) / 2.0;
239 return sum / pow(2.0, MetricType::Power);
246 template<
typename MetricType,
typename ElemType>
250 Log::Assert(dim == other.dim);
256 ElemType lower, higher;
257 for (
size_t d = 0; d < dim; d++)
259 lower = obound->
Lo() - mbound->
Hi();
260 higher = mbound->
Lo() - obound->
Hi();
266 if (MetricType::Power == 1)
267 sum += (lower + std::fabs(lower)) + (higher + std::fabs(higher));
268 else if (MetricType::Power == 2)
270 ElemType dist = (lower + std::fabs(lower)) + (higher + std::fabs(higher));
275 sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
276 (ElemType) MetricType::Power);
285 if (MetricType::Power == 1)
287 else if (MetricType::Power == 2)
289 if (MetricType::TakeRoot)
290 return (ElemType) std::sqrt(sum) * 0.5;
296 if (MetricType::TakeRoot)
297 return (ElemType) pow((
double) sum,
298 1.0 / (
double) MetricType::Power) / 2.0;
300 return sum / pow(2.0, MetricType::Power);
307 template<
typename MetricType,
typename ElemType>
308 template<
typename VecType>
310 const VecType& point,
315 Log::Assert(point.n_elem == dim);
317 for (
size_t d = 0; d < dim; d++)
319 ElemType v = std::max(fabs(point[d] - bounds[d].Lo()),
320 fabs(bounds[d].Hi() - point[d]));
323 if (MetricType::Power == 1)
325 else if (MetricType::Power == 2)
328 sum += std::pow(v, (ElemType) MetricType::Power);
332 if (MetricType::TakeRoot)
334 if (MetricType::Power == 1)
336 else if (MetricType::Power == 2)
337 return (ElemType) std::sqrt(sum);
339 return (ElemType) pow((
double) sum, 1.0 / (
double) MetricType::Power);
348 template<
typename MetricType,
typename ElemType>
355 Log::Assert(dim == other.dim);
358 for (
size_t d = 0; d < dim; d++)
360 v = std::max(fabs(other.bounds[d].
Hi() - bounds[d].Lo()),
361 fabs(bounds[d].Hi() - other.bounds[d].
Lo()));
364 if (MetricType::Power == 1)
366 else if (MetricType::Power == 2)
369 sum += std::pow(v, (ElemType) MetricType::Power);
373 if (MetricType::TakeRoot)
375 if (MetricType::Power == 1)
377 else if (MetricType::Power == 2)
378 return (ElemType) std::sqrt(sum);
380 return (ElemType) pow((
double) sum, 1.0 / (
double) MetricType::Power);
389 template<
typename MetricType,
typename ElemType>
397 Log::Assert(dim == other.dim);
399 ElemType v1, v2, vLo, vHi;
400 for (
size_t d = 0; d < dim; d++)
402 v1 = other.bounds[d].
Lo() - bounds[d].Hi();
403 v2 = bounds[d].Lo() - other.bounds[d].
Hi();
408 vLo = (v1 > 0) ? v1 : 0;
413 vLo = (v2 > 0) ? v2 : 0;
417 if (MetricType::Power == 1)
422 else if (MetricType::Power == 2)
429 loSum += std::pow(vLo, (ElemType) MetricType::Power);
430 hiSum += std::pow(vHi, (ElemType) MetricType::Power);
434 if (MetricType::TakeRoot)
436 if (MetricType::Power == 1)
438 else if (MetricType::Power == 2)
440 (ElemType) std::sqrt(hiSum));
444 (ElemType) pow((
double) loSum, 1.0 / (double) MetricType::Power),
445 (ElemType) pow((
double) hiSum, 1.0 / (double) MetricType::Power));
455 template<
typename MetricType,
typename ElemType>
456 template<
typename VecType>
459 const VecType& point,
465 Log::Assert(point.n_elem == dim);
467 ElemType v1, v2, vLo, vHi;
468 for (
size_t d = 0; d < dim; d++)
470 v1 = bounds[d].Lo() - point[d];
471 v2 = point[d] - bounds[d].Hi();
487 vHi = -std::min(v1, v2);
493 if (MetricType::Power == 1)
498 else if (MetricType::Power == 2)
505 loSum += std::pow(vLo, (ElemType) MetricType::Power);
506 hiSum += std::pow(vHi, (ElemType) MetricType::Power);
510 if (MetricType::TakeRoot)
512 if (MetricType::Power == 1)
514 else if (MetricType::Power == 2)
516 (ElemType) std::sqrt(hiSum));
520 (ElemType) pow((
double) loSum, 1.0 / (double) MetricType::Power),
521 (ElemType) pow((
double) hiSum, 1.0 / (double) MetricType::Power));
531 template<
typename MetricType,
typename ElemType>
532 template<
typename MatType>
536 Log::Assert(data.n_rows == dim);
538 arma::Col<ElemType> mins(min(data, 1));
539 arma::Col<ElemType> maxs(max(data, 1));
541 minWidth = std::numeric_limits<ElemType>::max();
542 for (
size_t i = 0; i < dim; ++i)
545 const ElemType width = bounds[i].
Width();
546 if (width < minWidth)
556 template<
typename MetricType,
typename ElemType>
560 assert(other.dim == dim);
562 minWidth = std::numeric_limits<ElemType>::max();
563 for (
size_t i = 0; i < dim; ++i)
565 bounds[i] |= other.bounds[i];
566 const ElemType width = bounds[i].
Width();
567 if (width < minWidth)
577 template<
typename MetricType,
typename ElemType>
578 template<
typename VecType>
580 const VecType& point)
const 582 for (
size_t i = 0; i < point.n_elem; ++i)
584 if (!bounds[i].Contains(point(i)))
594 template<
typename MetricType,
typename ElemType>
598 for (
size_t i = 0; i < dim; ++i)
604 if (r_a.
Hi() <= r_b.
Lo() || r_a.
Lo() >= r_b.
Hi())
614 template<
typename MetricType,
typename ElemType>
620 for (
size_t k = 0; k < dim; ++k)
622 result[k].Lo() = std::max(bounds[k].Lo(), bound.bounds[k].
Lo());
623 result[k].Hi() = std::min(bounds[k].Hi(), bound.bounds[k].
Hi());
631 template<
typename MetricType,
typename ElemType>
635 for (
size_t k = 0; k < dim; ++k)
637 bounds[k].Lo() = std::max(bounds[k].Lo(), bound.bounds[k].
Lo());
638 bounds[k].Hi() = std::min(bounds[k].Hi(), bound.bounds[k].
Hi());
646 template<
typename MetricType,
typename ElemType>
650 ElemType volume = 1.0;
652 for (
size_t k = 0; k < dim; ++k)
654 ElemType lo = std::max(bounds[k].Lo(), bound.bounds[k].
Lo());
655 ElemType hi = std::min(bounds[k].Hi(), bound.bounds[k].
Hi());
668 template<
typename MetricType,
typename ElemType>
672 for (
size_t i = 0; i < dim; ++i)
673 d += std::pow(bounds[i].Hi() - bounds[i].Lo(),
674 (ElemType) MetricType::Power);
676 if (MetricType::TakeRoot)
677 return (ElemType) std::pow((
double) d, 1.0 / (
double) MetricType::Power);
683 template<
typename MetricType,
typename ElemType>
684 template<
typename Archive>
691 ar(CEREAL_NVP(minWidth));
692 ar(CEREAL_NVP(metric));
698 #endif // MLPACK_CORE_TREE_HRECTBOUND_IMPL_HPP
T Lo() const
Get the lower bound.
Definition: range.hpp:61
Bounds that are useful for binary space partitioning trees.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Hyper-rectangle bound for an L-metric.
Definition: hrectbound.hpp:54
HRectBound()
Empty constructor; creates a bound of dimensionality 0.
Definition: hrectbound_impl.hpp:27
T Hi() const
Get the upper bound.
Definition: range.hpp:66
size_t Dim() const
Gets the dimensionality.
Definition: hrectbound.hpp:96
void Center(const arma::mat &x, arma::mat &xCentered)
Creates a centered matrix, where centering is done by subtracting the sum over the columns (a column ...
Definition: lin_alg.cpp:43
T Width() const
Gets the span of the range (hi - lo).
Definition: range_impl.hpp:47
#define CEREAL_POINTER_ARRAY(T, S)
Cereal does not support the serialization of raw pointers.
Definition: array_wrapper.hpp:87
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35