mlpack
hollow_ball_bound_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_TREE_HOLLOW_BALL_BOUND_IMPL_HPP
13 #define MLPACK_CORE_TREE_HOLLOW_BALL_BOUND_IMPL_HPP
14 
15 // In case it hasn't been included already.
16 #include "hollow_ball_bound.hpp"
17 
18 namespace mlpack {
19 namespace bound {
20 
22 template<typename TMetricType, typename ElemType>
24  radii(std::numeric_limits<ElemType>::lowest(),
25  std::numeric_limits<ElemType>::lowest()),
26  metric(new MetricType()),
27  ownsMetric(true)
28 { /* Nothing to do. */ }
29 
35 template<typename TMetricType, typename ElemType>
37 HollowBallBound(const size_t dimension) :
38  radii(std::numeric_limits<ElemType>::lowest(),
39  std::numeric_limits<ElemType>::lowest()),
40  center(dimension),
41  hollowCenter(dimension),
42  metric(new MetricType()),
43  ownsMetric(true)
44 { /* Nothing to do. */ }
45 
53 template<typename TMetricType, typename ElemType>
54 template<typename VecType>
56 HollowBallBound(const ElemType innerRadius,
57  const ElemType outerRadius,
58  const VecType& center) :
59  radii(innerRadius,
60  outerRadius),
61  center(center),
62  hollowCenter(center),
63  metric(new MetricType()),
64  ownsMetric(true)
65 { /* Nothing to do. */ }
66 
68 template<typename TMetricType, typename ElemType>
70  const HollowBallBound& other) :
71  radii(other.radii),
72  center(other.center),
73  hollowCenter(other.hollowCenter),
74  metric(other.metric),
75  ownsMetric(false)
76 { /* Nothing to do. */ }
77 
79 template<typename TMetricType, typename ElemType>
82 {
83  if (this != &other)
84  {
85  if (ownsMetric)
86  delete metric;
87 
88  radii = other.radii;
89  center = other.center;
90  hollowCenter = other.hollowCenter;
91  metric = other.metric;
92  ownsMetric = false;
93  }
94  return *this;
95 }
96 
98 template<typename TMetricType, typename ElemType>
100  HollowBallBound&& other) :
101  radii(other.radii),
102  center(std::move(other.center)),
103  hollowCenter(std::move(other.hollowCenter)),
104  metric(other.metric),
105  ownsMetric(other.ownsMetric)
106 {
107  // Fix the other bound.
108  other.radii.Hi() = 0.0;
109  other.radii.Lo() = 0.0;
110  other.center = arma::Col<ElemType>();
111  other.hollowCenter = arma::Col<ElemType>();
112  other.metric = NULL;
113  other.ownsMetric = false;
114 }
115 
117 template<typename TMetricType, typename ElemType>
120 {
121  if (this != &other)
122  {
123  radii = other.radii;
124  center = std::move(other.center);
125  hollowCenter = std::move(other.hollowCenter);
126  metric = other.metric;
127  ownsMetric = other.ownsMetric;
128 
129  other.radii.Hi() = 0.0;
130  other.radii.Lo() = 0.0;
131  other.center = arma::Col<ElemType>();
132  other.hollowCenter = arma::Col<ElemType>();
133  other.metric = nullptr;
134  other.ownsMetric = false;
135  }
136  return *this;
137 }
138 
140 template<typename TMetricType, typename ElemType>
142 {
143  if (ownsMetric)
144  delete metric;
145 }
146 
148 template<typename TMetricType, typename ElemType>
150  const size_t i) const
151 {
152  if (radii.Hi() < 0)
153  return math::Range();
154  else
155  return math::Range(center[i] - radii.Hi(), center[i] + radii.Hi());
156 }
157 
161 template<typename TMetricType, typename ElemType>
162 template<typename VecType>
164  const VecType& point) const
165 {
166  if (radii.Hi() < 0)
167  return false;
168  else
169  {
170  ElemType dist = metric->Evaluate(center, point);
171  if (dist > radii.Hi())
172  return false; // The point is situated outside the outer ball.
173 
174  // Check if the point is situated outside the hole.
175  dist = metric->Evaluate(hollowCenter, point);
176 
177  return (dist >= radii.Lo());
178  }
179 }
180 
184 template<typename TMetricType, typename ElemType>
186  const HollowBallBound& other) const
187 {
188  if (radii.Hi() < 0)
189  return false;
190  else
191  {
192  const ElemType dist = metric->Evaluate(center, other.center);
193  const ElemType hollowCenterDist = metric->Evaluate(hollowCenter,
194  other.center);
195  const ElemType hollowHollowDist = metric->Evaluate(hollowCenter,
196  other.hollowCenter);
197 
198  // The outer ball of the second bound does not contain the hole of the first
199  // bound.
200  bool containOnOneSide = (hollowCenterDist - other.radii.Hi() >= radii.Lo())
201  && (dist + other.radii.Hi() <= radii.Hi());
202 
203  // The hole of the second bound contains the hole of the first bound.
204  bool containOnEverySide = (hollowHollowDist + radii.Lo() <=
205  other.radii.Lo()) && (dist + other.radii.Hi() <= radii.Hi());
206 
207  // The first bound has not got a hole.
208  bool containAsBall = (radii.Lo() == 0) &&
209  (dist + other.radii.Hi() <= radii.Hi());
210 
211  return (containOnOneSide || containOnEverySide || containAsBall);
212  }
213 }
214 
215 
219 template<typename TMetricType, typename ElemType>
220 template<typename VecType>
222  const VecType& point,
223  typename std::enable_if_t<IsVector<VecType>::value>* /* junk */) const
224 {
225  if (radii.Hi() < 0)
226  return std::numeric_limits<ElemType>::max();
227  else
228  {
229  const ElemType outerDistance = metric->Evaluate(point, center) - radii.Hi();
230 
231  if (outerDistance >= 0)
232  return outerDistance; // The outer ball does not contain the point.
233 
234  // Check if the point is situated in the hole.
235  const ElemType innerDistance = math::ClampNonNegative(radii.Lo() -
236  metric->Evaluate(point, hollowCenter));
237 
238  return innerDistance;
239  }
240 }
241 
245 template<typename TMetricType, typename ElemType>
247  const HollowBallBound& other)
248  const
249 {
250  if (radii.Hi() < 0 || other.radii.Hi() < 0)
251  return std::numeric_limits<ElemType>::max();
252  else
253  {
254  const ElemType outerDistance = metric->Evaluate(center, other.center) -
255  radii.Hi() - other.radii.Hi();
256  if (outerDistance >= 0)
257  return outerDistance; // The outer hollows do not overlap.
258 
259  // Check if the hole of the second bound contains the outer ball of the
260  // first bound.
261  const ElemType innerDistance1 = other.radii.Lo() -
262  metric->Evaluate(center, other.hollowCenter) - radii.Hi();
263  if (innerDistance1 >= 0)
264  return innerDistance1;
265 
266  // Check if the hole of the first bound contains the outer ball of the
267  // second bound.
268  const ElemType innerDistance2 = math::ClampNonNegative(radii.Lo() -
269  metric->Evaluate(hollowCenter, other.center) - other.radii.Hi());
270 
271  return innerDistance2;
272  }
273 }
274 
278 template<typename TMetricType, typename ElemType>
279 template<typename VecType>
281  const VecType& point,
282  typename std::enable_if_t<IsVector<VecType>::value>* /* junk */) const
283 {
284  if (radii.Hi() < 0)
285  return std::numeric_limits<ElemType>::max();
286  else
287  return metric->Evaluate(point, center) + radii.Hi();
288 }
289 
293 template<typename TMetricType, typename ElemType>
295  const HollowBallBound& other)
296  const
297 {
298  if (radii.Hi() < 0)
299  return std::numeric_limits<ElemType>::max();
300  else
301  return metric->Evaluate(other.center, center) + radii.Hi() +
302  other.radii.Hi();
303 }
304 
310 template<typename TMetricType, typename ElemType>
311 template<typename VecType>
313  const VecType& point,
314  typename std::enable_if_t<IsVector<VecType>::value>* /* junk */) const
315 {
316  if (radii.Hi() < 0)
317  return math::Range(std::numeric_limits<ElemType>::max(),
318  std::numeric_limits<ElemType>::max());
319  else
320  {
322  const ElemType dist = metric->Evaluate(point, center);
323 
324  if (dist >= radii.Hi()) // The outer ball does not contain the point.
325  range.Lo() = dist - radii.Hi();
326  else
327  {
328  // Check if the point is situated in the hole.
329  range.Lo() = math::ClampNonNegative(radii.Lo() -
330  metric->Evaluate(point, hollowCenter));
331  }
332  range.Hi() = dist + radii.Hi();
333 
334  return range;
335  }
336 }
337 
338 template<typename TMetricType, typename ElemType>
340  const HollowBallBound& other) const
341 {
342  if (radii.Hi() < 0)
343  return math::Range(std::numeric_limits<ElemType>::max(),
344  std::numeric_limits<ElemType>::max());
345  else
346  {
348 
349  const ElemType dist = metric->Evaluate(center, other.center);
350 
351  const ElemType outerDistance = dist - radii.Hi() - other.radii.Hi();
352  if (outerDistance >= 0)
353  range.Lo() = outerDistance; // The outer balls do not overlap.
354  else
355  {
356  const ElemType innerDistance1 = other.radii.Lo() -
357  metric->Evaluate(center, other.hollowCenter) - radii.Hi();
358  // Check if the outer ball of the first bound is contained in the
359  // hole of the second bound.
360  if (innerDistance1 >= 0)
361  range.Lo() = innerDistance1;
362  else
363  {
364  // Check if the outer ball of the second bound is contained in the
365  // hole of the first bound.
366  range.Lo() = math::ClampNonNegative(radii.Lo() -
367  metric->Evaluate(hollowCenter, other.center) - other.radii.Hi());
368  }
369  }
370  range.Hi() = dist + radii.Hi() + other.radii.Hi();
371  return range;
372  }
373 }
374 
381 template<typename TMetricType, typename ElemType>
382 template<typename MatType>
385 {
386  if (radii.Hi() < 0)
387  {
388  center = data.col(0);
389  radii.Hi() = 0;
390  }
391  if (radii.Lo() < 0)
392  {
393  hollowCenter = data.col(0);
394  radii.Lo() = 0;
395  }
396  // Now iteratively add points.
397  for (size_t i = 0; i < data.n_cols; ++i)
398  {
399  const ElemType dist = metric->Evaluate(center, data.col(i));
400  const ElemType hollowDist = metric->Evaluate(hollowCenter, data.col(i));
401 
402  // See if the new point lies outside the bound.
403  if (dist > radii.Hi())
404  {
405  // Move towards the new point and increase the radius just enough to
406  // accommodate the new point.
407  const arma::Col<ElemType> diff = data.col(i) - center;
408  center += ((dist - radii.Hi()) / (2 * dist)) * diff;
409  radii.Hi() = 0.5 * (dist + radii.Hi());
410  }
411  if (hollowDist < radii.Lo())
412  radii.Lo() = hollowDist;
413  }
414 
415  return *this;
416 }
417 
421 template<typename TMetricType, typename ElemType>
424 {
425  if (radii.Hi() < 0)
426  {
427  center = other.center;
428  hollowCenter = other.hollowCenter;
429  radii.Hi() = other.radii.Hi();
430  radii.Lo() = other.radii.Lo();
431  return *this;
432  }
433 
434  const ElemType dist = metric->Evaluate(center, other.center);
435  // Check if the outer balls overlap.
436  if (radii.Hi() < dist + other.radii.Hi())
437  radii.Hi() = dist + other.radii.Hi();
438 
439  const ElemType innerDist = math::ClampNonNegative(other.radii.Lo() -
440  metric->Evaluate(hollowCenter, other.hollowCenter));
441  // Check if the hole of the first bound is not contained in the hole of the
442  // second bound.
443  if (radii.Lo() > innerDist)
444  radii.Lo() = innerDist;
445 
446  return *this;
447 }
448 
449 
451 template<typename TMetricType, typename ElemType>
452 template<typename Archive>
454  Archive& ar,
455  const uint32_t /* version */)
456 {
457  ar(CEREAL_NVP(radii));
458  ar(CEREAL_NVP(center));
459  ar(CEREAL_NVP(hollowCenter));
460  ar(CEREAL_POINTER(metric));
461  if (cereal::is_loading<Archive>())
462  {
463  // If we're loading, delete the local metric since we'll have a new one.
464  if (ownsMetric)
465  delete metric;
466 
467  ownsMetric = true;
468  }
469 }
470 
471 } // namespace bound
472 } // namespace mlpack
473 
474 #endif // MLPACK_CORE_TREE_HOLLOW_BALL_BOUND_IMPL_HPP
T Lo() const
Get the lower bound.
Definition: range.hpp:61
double ClampNonNegative(const double d)
Forces a number to be non-negative, turning negative numbers into zero.
Definition: clamp.hpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
bool Contains(const VecType &point) const
Determines if a point is within this bound.
Definition: hollow_ball_bound_impl.hpp:163
Definition: pointer_wrapper.hpp:23
TMetricType MetricType
A public version of the metric type.
Definition: hollow_ball_bound.hpp:37
RangeType< double > Range
3.0.0 TODO: break reverse-compatibility by changing RangeType to Range.
Definition: range.hpp:19
HollowBallBound()
Empty Constructor.
Definition: hollow_ball_bound_impl.hpp:23
const HollowBallBound & operator|=(const MatType &data)
Expand the bound to include the given points (the columns of the matrix).
~HollowBallBound()
Destructor to release allocated memory.
Definition: hollow_ball_bound_impl.hpp:141
math::RangeType< ElemType > operator[](const size_t i) const
Get the range in a certain dimension.
Definition: hollow_ball_bound_impl.hpp:149
void serialize(Archive &ar, const uint32_t version)
Serialize the bound.
Definition: hollow_ball_bound_impl.hpp:453
HollowBallBound & operator=(const HollowBallBound &other)
Copy assignment; like the copy constructor, it shares the other bound's metric without taking ownership of it, to prevent memory leaks.
Definition: hollow_ball_bound_impl.hpp:81
T Hi() const
Get the upper bound.
Definition: range.hpp:66
Bounds that are useful for binary space partitioning trees.
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_wrapper.hpp:96
ElemType MaxDistance(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Computes maximum distance.
Definition: hollow_ball_bound_impl.hpp:280
ElemType MinDistance(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Calculates minimum bound-to-point distance.
Definition: hollow_ball_bound_impl.hpp:221
Hollow ball bound encloses a set of points at a specific distance (radius) from a specific point (cen...
Definition: hollow_ball_bound.hpp:33
math::RangeType< ElemType > RangeDistance(const VecType &other, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Calculates minimum and maximum bound-to-point distance.
Definition: hollow_ball_bound_impl.hpp:312
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35