12 #ifndef MLPACK_CORE_TREE_HOLLOW_BALL_BOUND_IMPL_HPP 13 #define MLPACK_CORE_TREE_HOLLOW_BALL_BOUND_IMPL_HPP 22 template<
typename TMetricType,
typename ElemType>
24 radii(
std::numeric_limits<ElemType>::lowest(),
25 std::numeric_limits<ElemType>::lowest()),
35 template<
typename TMetricType,
typename ElemType>
38 radii(
std::numeric_limits<ElemType>::lowest(),
39 std::numeric_limits<ElemType>::lowest()),
41 hollowCenter(dimension),
53 template<
typename TMetricType,
typename ElemType>
54 template<
typename VecType>
57 const ElemType outerRadius,
58 const VecType& center) :
68 template<
typename TMetricType,
typename ElemType>
73 hollowCenter(other.hollowCenter),
79 template<
typename TMetricType,
typename ElemType>
89 center = other.center;
90 hollowCenter = other.hollowCenter;
91 metric = other.metric;
98 template<
typename TMetricType,
typename ElemType>
102 center(
std::move(other.center)),
103 hollowCenter(
std::move(other.hollowCenter)),
104 metric(other.metric),
105 ownsMetric(other.ownsMetric)
108 other.radii.Hi() = 0.0;
109 other.radii.Lo() = 0.0;
110 other.center = arma::Col<ElemType>();
111 other.hollowCenter = arma::Col<ElemType>();
113 other.ownsMetric =
false;
117 template<
typename TMetricType,
typename ElemType>
124 center = std::move(other.center);
125 hollowCenter = std::move(other.hollowCenter);
126 metric = other.metric;
127 ownsMetric = other.ownsMetric;
129 other.radii.Hi() = 0.0;
130 other.radii.Lo() = 0.0;
131 other.center = arma::Col<ElemType>();
132 other.hollowCenter = arma::Col<ElemType>();
133 other.metric =
nullptr;
134 other.ownsMetric =
false;
140 template<
typename TMetricType,
typename ElemType>
148 template<
typename TMetricType,
typename ElemType>
150 const size_t i)
const 161 template<
typename TMetricType,
typename ElemType>
162 template<
typename VecType>
164 const VecType& point)
const 170 ElemType dist = metric->Evaluate(center, point);
171 if (dist > radii.
Hi())
175 dist = metric->Evaluate(hollowCenter, point);
177 return (dist >= radii.
Lo());
184 template<
typename TMetricType,
typename ElemType>
192 const ElemType dist = metric->Evaluate(center, other.center);
193 const ElemType hollowCenterDist = metric->Evaluate(hollowCenter,
195 const ElemType hollowHollowDist = metric->Evaluate(hollowCenter,
200 bool containOnOneSide = (hollowCenterDist - other.radii.
Hi() >= radii.
Lo())
201 && (dist + other.radii.
Hi() <= radii.
Hi());
204 bool containOnEverySide = (hollowHollowDist + radii.
Lo() <=
205 other.radii.
Lo()) && (dist + other.radii.
Hi() <= radii.
Hi());
208 bool containAsBall = (radii.
Lo() == 0) &&
209 (dist + other.radii.
Hi() <= radii.
Hi());
211 return (containOnOneSide || containOnEverySide || containAsBall);
219 template<
typename TMetricType,
typename ElemType>
220 template<
typename VecType>
222 const VecType& point,
226 return std::numeric_limits<ElemType>::max();
229 const ElemType outerDistance = metric->Evaluate(point, center) - radii.
Hi();
231 if (outerDistance >= 0)
232 return outerDistance;
236 metric->Evaluate(point, hollowCenter));
238 return innerDistance;
245 template<
typename TMetricType,
typename ElemType>
250 if (radii.
Hi() < 0 || other.radii.
Hi() < 0)
251 return std::numeric_limits<ElemType>::max();
254 const ElemType outerDistance = metric->Evaluate(center, other.center) -
255 radii.
Hi() - other.radii.
Hi();
256 if (outerDistance >= 0)
257 return outerDistance;
261 const ElemType innerDistance1 = other.radii.
Lo() -
262 metric->Evaluate(center, other.hollowCenter) - radii.
Hi();
263 if (innerDistance1 >= 0)
264 return innerDistance1;
269 metric->Evaluate(hollowCenter, other.center) - other.radii.
Hi());
271 return innerDistance2;
278 template<
typename TMetricType,
typename ElemType>
279 template<
typename VecType>
281 const VecType& point,
285 return std::numeric_limits<ElemType>::max();
287 return metric->Evaluate(point, center) + radii.
Hi();
293 template<
typename TMetricType,
typename ElemType>
299 return std::numeric_limits<ElemType>::max();
301 return metric->Evaluate(other.center, center) + radii.
Hi() +
310 template<
typename TMetricType,
typename ElemType>
311 template<
typename VecType>
313 const VecType& point,
317 return math::Range(std::numeric_limits<ElemType>::max(),
318 std::numeric_limits<ElemType>::max());
322 const ElemType dist = metric->Evaluate(point, center);
324 if (dist >= radii.
Hi())
325 range.
Lo() = dist - radii.
Hi();
330 metric->Evaluate(point, hollowCenter));
332 range.
Hi() = dist + radii.
Hi();
338 template<
typename TMetricType,
typename ElemType>
343 return math::Range(std::numeric_limits<ElemType>::max(),
344 std::numeric_limits<ElemType>::max());
349 const ElemType dist = metric->Evaluate(center, other.center);
351 const ElemType outerDistance = dist - radii.
Hi() - other.radii.
Hi();
352 if (outerDistance >= 0)
353 range.
Lo() = outerDistance;
356 const ElemType innerDistance1 = other.radii.
Lo() -
357 metric->Evaluate(center, other.hollowCenter) - radii.
Hi();
360 if (innerDistance1 >= 0)
361 range.
Lo() = innerDistance1;
367 metric->Evaluate(hollowCenter, other.center) - other.radii.
Hi());
370 range.
Hi() = dist + radii.
Hi() + other.radii.
Hi();
381 template<
typename TMetricType,
typename ElemType>
382 template<
typename MatType>
388 center = data.col(0);
393 hollowCenter = data.col(0);
397 for (
size_t i = 0; i < data.n_cols; ++i)
399 const ElemType dist = metric->Evaluate(center, data.col(i));
400 const ElemType hollowDist = metric->Evaluate(hollowCenter, data.col(i));
403 if (dist > radii.
Hi())
407 const arma::Col<ElemType> diff = data.col(i) - center;
408 center += ((dist - radii.
Hi()) / (2 * dist)) * diff;
409 radii.
Hi() = 0.5 * (dist + radii.
Hi());
411 if (hollowDist < radii.
Lo())
412 radii.
Lo() = hollowDist;
421 template<
typename TMetricType,
typename ElemType>
427 center = other.center;
428 hollowCenter = other.hollowCenter;
429 radii.
Hi() = other.radii.
Hi();
430 radii.
Lo() = other.radii.
Lo();
434 const ElemType dist = metric->Evaluate(center, other.center);
436 if (radii.
Hi() < dist + other.radii.
Hi())
437 radii.
Hi() = dist + other.radii.
Hi();
440 metric->Evaluate(hollowCenter, other.hollowCenter));
443 if (radii.
Lo() > innerDist)
444 radii.
Lo() = innerDist;
451 template<
typename TMetricType,
typename ElemType>
452 template<
typename Archive>
457 ar(CEREAL_NVP(radii));
458 ar(CEREAL_NVP(center));
459 ar(CEREAL_NVP(hollowCenter));
461 if (cereal::is_loading<Archive>())
474 #endif // MLPACK_CORE_TREE_HOLLOW_BALL_BOUND_IMPL_HPP T Lo() const
Get the lower bound.
Definition: range.hpp:61
double ClampNonNegative(const double d)
Forces a number to be non-negative, turning negative numbers into zero.
Definition: clamp.hpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
bool Contains(const VecType &point) const
Determines if a point is within this bound.
Definition: hollow_ball_bound_impl.hpp:163
Definition: pointer_wrapper.hpp:23
TMetricType MetricType
A public version of the metric type.
Definition: hollow_ball_bound.hpp:37
RangeType< double > Range
3.0.0 TODO: break reverse-compatibility by changing RangeType to Range.
Definition: range.hpp:19
HollowBallBound()
Empty Constructor.
Definition: hollow_ball_bound_impl.hpp:23
const HollowBallBound & operator|=(const MatType &data)
Expand the bound to include the given point.
~HollowBallBound()
Destructor to release allocated memory.
Definition: hollow_ball_bound_impl.hpp:141
math::RangeType< ElemType > operator[](const size_t i) const
Get the range in a certain dimension.
Definition: hollow_ball_bound_impl.hpp:149
void serialize(Archive &ar, const uint32_t version)
Serialize the bound.
Definition: hollow_ball_bound_impl.hpp:453
HollowBallBound & operator=(const HollowBallBound &other)
For the same reason as the copy constructor: to prevent memory leaks.
Definition: hollow_ball_bound_impl.hpp:81
T Hi() const
Get the upper bound.
Definition: range.hpp:66
Bounds that are useful for binary space partitioning trees.
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointer.
Definition: pointer_wrapper.hpp:96
ElemType MaxDistance(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Computes maximum distance.
Definition: hollow_ball_bound_impl.hpp:280
ElemType MinDistance(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Calculates minimum bound-to-point squared distance.
Definition: hollow_ball_bound_impl.hpp:221
Hollow ball bound encloses a set of points at a specific distance (radius) from a specific point (cen...
Definition: hollow_ball_bound.hpp:33
math::RangeType< ElemType > RangeDistance(const VecType &other, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Calculates minimum and maximum bound-to-point distance.
Definition: hollow_ball_bound_impl.hpp:312
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35