mlpack
ballbound_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_TREE_BALLBOUND_IMPL_HPP
13 #define MLPACK_CORE_TREE_BALLBOUND_IMPL_HPP
14 
15 // In case it hasn't been included already.
16 #include "ballbound.hpp"
18 
19 #include <string>
20 
21 namespace mlpack {
22 namespace bound {
23 
25 template<typename MetricType, typename VecType>
27  radius(std::numeric_limits<ElemType>::lowest()),
28  metric(new MetricType()),
29  ownsMetric(true)
30 { /* Nothing to do. */ }
31 
37 template<typename MetricType, typename VecType>
39  radius(std::numeric_limits<ElemType>::lowest()),
40  center(dimension),
41  metric(new MetricType()),
42  ownsMetric(true)
43 { /* Nothing to do. */ }
44 
51 template<typename MetricType, typename VecType>
53  const VecType& center) :
54  radius(radius),
55  center(center),
56  metric(new MetricType()),
57  ownsMetric(true)
58 { /* Nothing to do. */ }
59 
61 template<typename MetricType, typename VecType>
63  radius(other.radius),
64  center(other.center),
65  metric(other.metric),
66  ownsMetric(false)
67 { /* Nothing to do. */ }
68 
70 template<typename MetricType, typename VecType>
72  const BallBound& other)
73 {
74  if (this != &other)
75  {
76  radius = other.radius;
77  center = other.center;
78  metric = other.metric;
79  ownsMetric = false;
80  }
81  return *this;
82 }
83 
85 template<typename MetricType, typename VecType>
87  radius(other.radius),
88  center(other.center),
89  metric(other.metric),
90  ownsMetric(other.ownsMetric)
91 {
92  // Fix the other bound.
93  other.radius = 0.0;
94  other.center = VecType();
95  other.metric = NULL;
96  other.ownsMetric = false;
97 }
98 
100 template<typename MetricType, typename VecType>
102  BallBound&& other)
103 {
104  if (this != &other)
105  {
106  radius = other.radius;
107  center = std::move(other.center);
108  metric = other.metric;
109  ownsMetric = other.ownsMetric;
110 
111  other.radius = 0.0;
112  other.center = VecType();
113  other.metric = nullptr;
114  other.ownsMetric = false;
115  }
116  return *this;
117 }
118 
120 template<typename MetricType, typename VecType>
122 {
123  if (ownsMetric)
124  delete metric;
125 }
126 
128 template<typename MetricType, typename VecType>
131 {
132  if (radius < 0)
133  return math::Range();
134  else
135  return math::Range(center[i] - radius, center[i] + radius);
136 }
137 
141 template<typename MetricType, typename VecType>
142 bool BallBound<MetricType, VecType>::Contains(const VecType& point) const
143 {
144  if (radius < 0)
145  return false;
146  else
147  return metric->Evaluate(center, point) <= radius;
148 }
149 
153 template<typename MetricType, typename VecType>
154 template<typename OtherVecType>
157  const OtherVecType& point,
158  typename std::enable_if_t<IsVector<OtherVecType>::value>* /* junk */) const
159 {
160  if (radius < 0)
161  return std::numeric_limits<ElemType>::max();
162  else
163  return math::ClampNonNegative(metric->Evaluate(point, center) - radius);
164 }
165 
169 template<typename MetricType, typename VecType>
172  const
173 {
174  if (radius < 0)
175  return std::numeric_limits<ElemType>::max();
176  else
177  {
178  const ElemType delta = metric->Evaluate(center, other.center) - radius -
179  other.radius;
180  return math::ClampNonNegative(delta);
181  }
182 }
183 
187 template<typename MetricType, typename VecType>
188 template<typename OtherVecType>
191  const OtherVecType& point,
192  typename std::enable_if_t<IsVector<OtherVecType>::value>* /* junk */) const
193 {
194  if (radius < 0)
195  return std::numeric_limits<ElemType>::max();
196  else
197  return metric->Evaluate(point, center) + radius;
198 }
199 
203 template<typename MetricType, typename VecType>
206  const
207 {
208  if (radius < 0)
209  return std::numeric_limits<ElemType>::max();
210  else
211  return metric->Evaluate(other.center, center) + radius + other.radius;
212 }
213 
219 template<typename MetricType, typename VecType>
220 template<typename OtherVecType>
223  const OtherVecType& point,
224  typename std::enable_if_t<IsVector<OtherVecType>::value>* /* junk */) const
225 {
226  if (radius < 0)
227  return math::Range(std::numeric_limits<ElemType>::max(),
228  std::numeric_limits<ElemType>::max());
229  else
230  {
231  const ElemType dist = metric->Evaluate(center, point);
232  return math::Range(math::ClampNonNegative(dist - radius),
233  dist + radius);
234  }
235 }
236 
237 template<typename MetricType, typename VecType>
240  const BallBound& other) const
241 {
242  if (radius < 0)
243  return math::Range(std::numeric_limits<ElemType>::max(),
244  std::numeric_limits<ElemType>::max());
245  else
246  {
247  const ElemType dist = metric->Evaluate(center, other.center);
248  const ElemType sumradius = radius + other.radius;
249  return math::Range(math::ClampNonNegative(dist - sumradius),
250  dist + sumradius);
251  }
252 }
253 
277 template<typename MetricType, typename VecType>
278 template<typename MatType>
281 {
282  if (radius < 0)
283  {
284  center = data.col(0);
285  radius = 0;
286  }
287 
288  // Now iteratively add points.
289  for (size_t i = 0; i < data.n_cols; ++i)
290  {
291  const ElemType dist = metric->Evaluate(center, (VecType) data.col(i));
292 
293  // See if the new point lies outside the bound.
294  if (dist > radius)
295  {
296  // Move towards the new point and increase the radius just enough to
297  // accommodate the new point.
298  const VecType diff = data.col(i) - center;
299  center += ((dist - radius) / (2 * dist)) * diff;
300  radius = 0.5 * (dist + radius);
301  }
302  }
303 
304  return *this;
305 }
306 
308 template<typename MetricType, typename VecType>
309 template<typename Archive>
311  Archive& ar,
312  const uint32_t /* version */)
313 {
314  ar(CEREAL_NVP(radius));
315  ar(CEREAL_NVP(center));
316 
317  if (cereal::is_loading<Archive>())
318  {
319  // If we're loading, delete the local metric since we'll have a new one.
320  if (ownsMetric)
321  delete metric;
322  }
323 
324  ar(CEREAL_POINTER(metric));
325  ar(CEREAL_NVP(ownsMetric));
326 }
327 
328 } // namespace bound
329 } // namespace mlpack
330 
331 #endif // MLPACK_CORE_TREE_BALLBOUND_IMPL_HPP
double ClampNonNegative(const double d)
Forces a number to be non-negative, turning negative numbers into zero.
Definition: clamp.hpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
BallBound()
Empty Constructor.
Definition: ballbound_impl.hpp:26
Definition: pointer_wrapper.hpp:23
RangeType< double > Range
3.0.0 TODO: break reverse-compatibility by changing RangeType to Range.
Definition: range.hpp:19
Miscellaneous math clamping routines.
Ball bound encloses a set of points at a specific distance (radius) from a specific point (center)...
Definition: ballbound.hpp:32
~BallBound()
Destructor to release allocated memory.
Definition: ballbound_impl.hpp:121
bool Contains(const VecType &point) const
Determines if a point is within this bound.
Definition: ballbound_impl.hpp:142
Simple real-valued range.
Definition: range.hpp:19
Bounds that are useful for binary space partitioning trees.
VecType::elem_type ElemType
The underlying data type.
Definition: ballbound.hpp:36
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_wrapper.hpp:96
ElemType MinDistance(const OtherVecType &point, typename std::enable_if_t< IsVector< OtherVecType >::value > *=0) const
Calculates minimum bound-to-point distance.
Definition: ballbound_impl.hpp:156
BallBound & operator=(const BallBound &other)
For the same reason as the copy constructor: to prevent memory leaks.
Definition: ballbound_impl.hpp:71
ElemType MaxDistance(const OtherVecType &point, typename std::enable_if_t< IsVector< OtherVecType >::value > *=0) const
Computes maximum distance.
Definition: ballbound_impl.hpp:190
math::RangeType< ElemType > operator[](const size_t i) const
Get the range in a certain dimension.
Definition: ballbound_impl.hpp:130
void serialize(Archive &ar, const uint32_t version)
Serialize the bound.
Definition: ballbound_impl.hpp:310
math::RangeType< ElemType > RangeDistance(const OtherVecType &other, typename std::enable_if_t< IsVector< OtherVecType >::value > *=0) const
Calculates minimum and maximum bound-to-point distance.
const BallBound & operator|=(const BallBound &other)
Expand the bound to include the given node.
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35