mlpack
hrectbound_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_TREE_HRECTBOUND_IMPL_HPP
13 #define MLPACK_CORE_TREE_HRECTBOUND_IMPL_HPP
14 
15 #include <math.h>
16 
17 // In case it has not been included yet.
18 #include "hrectbound.hpp"
19 
20 namespace mlpack {
21 namespace bound {
22 
// Initializer list and empty body of a constructor that produces an empty
// bound: dim is 0, bounds is NULL (no array is allocated) and minWidth is 0.
// NOTE(review): the declarator (original line 27) is missing from this
// listing; the index at the end of the page attributes HRectBound() to it.
26 template<typename MetricType, typename ElemType>
28  dim(0),
29  bounds(NULL),
30  minWidth(0)
31 { /* Nothing to do. */ }
32 
37 template<typename MetricType, typename ElemType>
38 inline HRectBound<MetricType, ElemType>::HRectBound(const size_t dimension) :
39  dim(dimension),
40  bounds(new math::RangeType<ElemType>[dim]),
41  minWidth(0)
42 { /* Nothing to do. */ }
43 
// Deep-copying constructor: takes dim and minWidth from `other` through its
// Dim() / MinWidth() accessors, allocates a fresh array of dim ranges, and
// copies each range over with other[i].
// NOTE(review): the first declarator line (original line 48) is missing from
// this listing.
47 template<typename MetricType, typename ElemType>
49  const HRectBound<MetricType, ElemType>& other) :
50  dim(other.Dim()),
51  bounds(new math::RangeType<ElemType>[dim]),
52  minWidth(other.MinWidth())
53 {
54  // Copy other bounds over.
55  for (size_t i = 0; i < dim; ++i)
56  bounds[i] = other[i];
57 }
58 
62 template<typename MetricType, typename ElemType>
63 inline HRectBound<
64  MetricType,
65  ElemType>& HRectBound<MetricType,
66  ElemType>::operator=(const HRectBound<MetricType, ElemType>& other)
67 {
68  if (this == &other)
69  return *this;
70 
71  if (dim != other.Dim())
72  {
73  // Reallocation is necessary.
74  if (bounds)
75  delete[] bounds;
76 
77  dim = other.Dim();
78  bounds = new math::RangeType<ElemType>[dim];
79  }
80 
81  // Now copy each of the bound values.
82  for (size_t i = 0; i < dim; ++i)
83  bounds[i] = other[i];
84 
85  minWidth = other.MinWidth();
86 
87  return *this;
88 }
89 
// Move constructor body: takes over other's dim, bounds pointer and minWidth,
// then resets `other` to the empty state (dim 0, null bounds, minWidth 0) so
// that the array has a single owner.
// NOTE(review): the declarator (original lines 94-95) is missing from this
// listing.
93 template<typename MetricType, typename ElemType>
96  dim(other.dim),
97  bounds(other.bounds),
98  minWidth(other.minWidth)
99 {
100  // Fix the other bound.
101  other.dim = 0;
102  other.bounds = NULL;
103  other.minWidth = 0.0;
104 }
105 
// Move assignment body: unless assigning to itself, takes over other's
// bounds pointer, minWidth and dim, resets `other` to the empty state, and
// returns *this.
// NOTE(review): the declarator (original lines 110-112) is missing from this
// listing.
// NOTE(review): `bounds` is overwritten without a preceding delete[], so any
// array this object already owned is not released here — looks like a leak;
// confirm and free the old array before taking other's.
109 template<typename MetricType, typename ElemType>
113 {
114  if (this != &other)
115  {
116  bounds = other.bounds;
117  minWidth = other.minWidth;
118  dim = other.dim;
119  other.dim = 0;
120  other.bounds = nullptr;
121  other.minWidth = 0.0;
122  }
123  return *this;
124 }
125 
// Releases the range array with delete[] when the pointer is non-null.
// NOTE(review): the declarator (original line 130) is missing from this
// listing; presumably the destructor — confirm against the header.
129 template<typename MetricType, typename ElemType>
131 {
132  if (bounds)
133  delete[] bounds;
134 }
135 
// Resets every dimension to a default-constructed range and sets minWidth to
// 0; dim and the allocation itself are left untouched.
// NOTE(review): the declarator (original line 140) is missing from this
// listing.
139 template<typename MetricType, typename ElemType>
141 {
142  for (size_t i = 0; i < dim; ++i)
143  bounds[i] = math::RangeType<ElemType>();
144  minWidth = 0;
145 }
146 
147 /**
148  * Calculates the center of the bound, placing it into the given vector.
149  *
150  * @param center Vector which the per-dimension midpoints are written to.
151  */
// NOTE(review): the first declarator line (original line 153) is missing from
// this listing.
152 template<typename MetricType, typename ElemType>
154  arma::Col<ElemType>& center) const
155 {
156  // Set size correctly if necessary.
157  if (!(center.n_elem == dim))
158  center.set_size(dim);
159 
160  for (size_t i = 0; i < dim; ++i)
161  center(i) = bounds[i].Mid();
162 }
163 
// Returns the product of (Hi - Lo) over all dimensions.  As soon as any
// dimension has Lo >= Hi the result is 0.  With dim == 0 the loop does not
// run and the initial value 1.0 is returned.
// NOTE(review): the declarator (original line 170) is missing from this
// listing.
169 template<typename MetricType, typename ElemType>
171 {
172  ElemType volume = 1.0;
173  for (size_t i = 0; i < dim; ++i)
174  {
175  if (bounds[i].Lo() >= bounds[i].Hi())
176  return 0;
177 
178  volume *= (bounds[i].Hi() - bounds[i].Lo());
179  }
180 
181  return volume;
182 }
183 
// Lower-bound distance from `point` to this bound.  For each dimension the
// amount by which the point lies outside [Lo, Hi] is obtained (doubled) via
// the x + |x| trick, raised to MetricType::Power and summed; afterwards the
// factor of two is removed and, if MetricType::TakeRoot is set, the Power'th
// root is taken.  Power 1 and Power 2 have dedicated branches.
// The trailing pointer parameter only serves the enable_if constraint on
// IsVector<VecType>.
// NOTE(review): the first declarator line (original line 189) is missing from
// this listing.
// NOTE(review): the generic-Power branch calls unqualified pow/fabs while the
// other branches use std::fabs/std::sqrt; which overload is picked for a
// non-double ElemType depends on the headers in scope — confirm.
187 template<typename MetricType, typename ElemType>
188 template<typename VecType>
190  const VecType& point,
191  typename std::enable_if_t<IsVector<VecType>::value>* /* junk */) const
192 {
193  Log::Assert(point.n_elem == dim);
194 
195  ElemType sum = 0;
196 
197  ElemType lower, higher;
198  for (size_t d = 0; d < dim; d++)
199  {
200  lower = bounds[d].Lo() - point[d];
201  higher = point[d] - bounds[d].Hi();
202 
203  // Since only one of 'lower' or 'higher' is negative, if we add each's
204  // absolute value to itself and then sum those two, our result is the
205  // nonnegative half of the equation times two; then we raise to power Power.
206  if (MetricType::Power == 1)
207  sum += (lower + std::fabs(lower)) + (higher + std::fabs(higher));
208  else if (MetricType::Power == 2)
209  {
210  ElemType dist = (lower + std::fabs(lower)) + (higher + std::fabs(higher));
211  sum += dist * dist;
212  }
213  else
214  {
215  sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
216  (ElemType) MetricType::Power);
217  }
218  }
219 
220  // Now take the Power'th root (but make sure our result is squared if it needs
221  // to be); then cancel out the constant of 2 (which may have been squared now)
222  // that was introduced earlier. The compiler should optimize out the if
223  // statement entirely.
224  if (MetricType::Power == 1)
225  return sum * 0.5;
226  else if (MetricType::Power == 2)
227  {
228  if (MetricType::TakeRoot)
229  return (ElemType) std::sqrt(sum) * 0.5;
230  else
231  return sum * 0.25;
232  }
233  else
234  {
235  if (MetricType::TakeRoot)
236  return (ElemType) pow((double) sum,
237  1.0 / (double) MetricType::Power) / 2.0;
238  else
239  return sum / pow(2.0, MetricType::Power);
240  }
241 }
242 
// Lower-bound distance between this bound and `other`.  Walks both range
// arrays in lock-step with raw pointers; per dimension the (doubled) gap
// between the two intervals is computed with the x + |x| trick, raised to
// MetricType::Power and summed.  The factor of two is cancelled at the end
// and the root is taken when MetricType::TakeRoot is set.
// NOTE(review): the first declarator line (original line 247) is missing from
// this listing.
246 template<typename MetricType, typename ElemType>
248  const
249 {
250  Log::Assert(dim == other.dim);
251 
252  ElemType sum = 0;
253  const math::RangeType<ElemType>* mbound = bounds;
254  const math::RangeType<ElemType>* obound = other.bounds;
255 
256  ElemType lower, higher;
257  for (size_t d = 0; d < dim; d++)
258  {
259  lower = obound->Lo() - mbound->Hi();
260  higher = mbound->Lo() - obound->Hi();
261  // We invoke the following:
262  // x + fabs(x) = max(x * 2, 0)
263  // (x * 2)^2 / 4 = x^2
264 
265  // The compiler should optimize out this if statement entirely.
266  if (MetricType::Power == 1)
267  sum += (lower + std::fabs(lower)) + (higher + std::fabs(higher));
268  else if (MetricType::Power == 2)
269  {
270  ElemType dist = (lower + std::fabs(lower)) + (higher + std::fabs(higher));
271  sum += dist * dist;
272  }
273  else
274  {
275  sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
276  (ElemType) MetricType::Power);
277  }
278 
279  // Move bound pointers.
280  mbound++;
281  obound++;
282  }
283 
284  // The compiler should optimize out this if statement entirely.
285  if (MetricType::Power == 1)
286  return sum * 0.5;
287  else if (MetricType::Power == 2)
288  {
289  if (MetricType::TakeRoot)
290  return (ElemType) std::sqrt(sum) * 0.5;
291  else
292  return sum * 0.25;
293  }
294  else
295  {
296  if (MetricType::TakeRoot)
297  return (ElemType) pow((double) sum,
298  1.0 / (double) MetricType::Power) / 2.0;
299  else
300  return sum / pow(2.0, MetricType::Power);
301  }
302 }
303 
// Upper-bound distance from `point` to this bound.  Per dimension the larger
// of |point - Lo| and |Hi - point| is taken, raised to MetricType::Power and
// summed; the Power'th root is applied only when MetricType::TakeRoot is set.
// The trailing pointer parameter only serves the enable_if constraint on
// IsVector<VecType>.
// NOTE(review): the first declarator line (original line 309) is missing from
// this listing.
307 template<typename MetricType, typename ElemType>
308 template<typename VecType>
310  const VecType& point,
311  typename std::enable_if_t<IsVector<VecType>::value>* /* junk */) const
312 {
313  ElemType sum = 0;
314 
315  Log::Assert(point.n_elem == dim);
316 
317  for (size_t d = 0; d < dim; d++)
318  {
319  ElemType v = std::max(fabs(point[d] - bounds[d].Lo()),
320  fabs(bounds[d].Hi() - point[d]));
321 
322  // The compiler should optimize out this if statement entirely.
323  if (MetricType::Power == 1)
324  sum += v; // v is non-negative.
325  else if (MetricType::Power == 2)
326  sum += v * v;
327  else
328  sum += std::pow(v, (ElemType) MetricType::Power);
329  }
330 
331  // The compiler should optimize out this if statement entirely.
332  if (MetricType::TakeRoot)
333  {
334  if (MetricType::Power == 1)
335  return sum;
336  else if (MetricType::Power == 2)
337  return (ElemType) std::sqrt(sum);
338  else
339  return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power);
340  }
341  else
342  return sum;
343 }
344 
// Upper-bound distance between this bound and `other`.  Per dimension the
// larger of |other.Hi - Lo| and |Hi - other.Lo| is taken, raised to
// MetricType::Power and summed; the Power'th root is applied only when
// MetricType::TakeRoot is set.
// NOTE(review): the first declarator line (original line 349) is missing from
// this listing.
348 template<typename MetricType, typename ElemType>
350  const HRectBound& other)
351  const
352 {
353  ElemType sum = 0;
354 
355  Log::Assert(dim == other.dim);
356 
357  ElemType v;
358  for (size_t d = 0; d < dim; d++)
359  {
360  v = std::max(fabs(other.bounds[d].Hi() - bounds[d].Lo()),
361  fabs(bounds[d].Hi() - other.bounds[d].Lo()));
362 
363  // The compiler should optimize out this if statement entirely.
364  if (MetricType::Power == 1)
365  sum += v; // v is non-negative.
366  else if (MetricType::Power == 2)
367  sum += v * v;
368  else
369  sum += std::pow(v, (ElemType) MetricType::Power);
370  }
371 
372  // The compiler should optimize out this if statement entirely.
373  if (MetricType::TakeRoot)
374  {
375  if (MetricType::Power == 1)
376  return sum;
377  else if (MetricType::Power == 2)
378  return (ElemType) std::sqrt(sum);
379  else
380  return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power);
381  }
382  else
383  return sum;
384 }
385 
// Computes the lower and upper distance bounds to `other` in a single pass
// and returns them as a math::RangeType<ElemType>(lo, hi).  Per dimension,
// vLo is the positive gap between the intervals (0 when they overlap) and vHi
// is the negated smaller of the two signed gaps; both are raised to
// MetricType::Power and accumulated, with the root applied when
// MetricType::TakeRoot is set.
// NOTE(review): the declarator (original lines 390-391) and the
// `return ...(` line of the generic-Power branch (original line 443) are
// missing from this listing.
389 template<typename MetricType, typename ElemType>
392  const HRectBound& other) const
393 {
394  ElemType loSum = 0;
395  ElemType hiSum = 0;
396 
397  Log::Assert(dim == other.dim);
398 
399  ElemType v1, v2, vLo, vHi;
400  for (size_t d = 0; d < dim; d++)
401  {
402  v1 = other.bounds[d].Lo() - bounds[d].Hi();
403  v2 = bounds[d].Lo() - other.bounds[d].Hi();
404  // One of v1 or v2 is negative.
405  if (v1 >= v2)
406  {
407  vHi = -v2; // Make it nonnegative.
408  vLo = (v1 > 0) ? v1 : 0; // Force to be 0 if negative.
409  }
410  else
411  {
412  vHi = -v1; // Make it nonnegative.
413  vLo = (v2 > 0) ? v2 : 0; // Force to be 0 if negative.
414  }
415 
416  // The compiler should optimize out this if statement entirely.
417  if (MetricType::Power == 1)
418  {
419  loSum += vLo; // vLo is non-negative.
420  hiSum += vHi; // vHi is non-negative.
421  }
422  else if (MetricType::Power == 2)
423  {
424  loSum += vLo * vLo;
425  hiSum += vHi * vHi;
426  }
427  else
428  {
429  loSum += std::pow(vLo, (ElemType) MetricType::Power);
430  hiSum += std::pow(vHi, (ElemType) MetricType::Power);
431  }
432  }
433 
434  if (MetricType::TakeRoot)
435  {
436  if (MetricType::Power == 1)
437  return math::RangeType<ElemType>(loSum, hiSum);
438  else if (MetricType::Power == 2)
439  return math::RangeType<ElemType>((ElemType) std::sqrt(loSum),
440  (ElemType) std::sqrt(hiSum));
441  else
442  {
444  (ElemType) pow((double) loSum, 1.0 / (double) MetricType::Power),
445  (ElemType) pow((double) hiSum, 1.0 / (double) MetricType::Power));
446  }
447  }
448  else
449  return math::RangeType<ElemType>(loSum, hiSum);
450 }
451 
// Computes the lower and upper distance bounds to `point` in a single pass
// and returns them as a math::RangeType<ElemType>(lo, hi).  Three cases are
// handled per dimension: point at or below Lo, point at or above Hi, and
// point strictly inside (where the lower contribution is 0).  Contributions
// are raised to MetricType::Power and accumulated, with the root applied when
// MetricType::TakeRoot is set.
// The trailing pointer parameter only serves the enable_if constraint on
// IsVector<VecType>.
// NOTE(review): the declarator (original lines 457-458) and the
// `return ...(` line of the generic-Power branch (original line 519) are
// missing from this listing.
455 template<typename MetricType, typename ElemType>
456 template<typename VecType>
459  const VecType& point,
460  typename std::enable_if_t<IsVector<VecType>::value>* /* junk */) const
461 {
462  ElemType loSum = 0;
463  ElemType hiSum = 0;
464 
465  Log::Assert(point.n_elem == dim);
466 
467  ElemType v1, v2, vLo, vHi;
468  for (size_t d = 0; d < dim; d++)
469  {
470  v1 = bounds[d].Lo() - point[d]; // Negative if point[d] > lo.
471  v2 = point[d] - bounds[d].Hi(); // Negative if point[d] < hi.
472  // One of v1 or v2 (or both) is negative.
473  if (v1 >= 0) // point[d] <= bounds_[d].Lo().
474  {
475  vHi = -v2; // v2 will be larger but must be negated.
476  vLo = v1;
477  }
478  else // point[d] is between lo and hi, or greater than hi.
479  {
480  if (v2 >= 0)
481  {
482  vHi = -v1; // v1 will be larger, but must be negated.
483  vLo = v2;
484  }
485  else
486  {
487  vHi = -std::min(v1, v2); // Both are negative, but we need the larger.
488  vLo = 0;
489  }
490  }
491 
492  // The compiler should optimize out this if statement entirely.
493  if (MetricType::Power == 1)
494  {
495  loSum += vLo; // vLo is non-negative.
496  hiSum += vHi; // vHi is non-negative.
497  }
498  else if (MetricType::Power == 2)
499  {
500  loSum += vLo * vLo;
501  hiSum += vHi * vHi;
502  }
503  else
504  {
505  loSum += std::pow(vLo, (ElemType) MetricType::Power);
506  hiSum += std::pow(vHi, (ElemType) MetricType::Power);
507  }
508  }
509 
510  if (MetricType::TakeRoot)
511  {
512  if (MetricType::Power == 1)
513  return math::RangeType<ElemType>(loSum, hiSum);
514  else if (MetricType::Power == 2)
515  return math::RangeType<ElemType>((ElemType) std::sqrt(loSum),
516  (ElemType) std::sqrt(hiSum));
517  else
518  {
520  (ElemType) pow((double) loSum, 1.0 / (double) MetricType::Power),
521  (ElemType) pow((double) hiSum, 1.0 / (double) MetricType::Power));
522  }
523  }
524  else
525  return math::RangeType<ElemType>(loSum, hiSum);
526 }
527 
// Grows the bound so that it also covers `data`: each dimension's range is
// OR-ed with [mins[i], maxs[i]], and minWidth is recomputed from scratch as
// the smallest resulting width.  Returns *this.
// NOTE(review): the declarator (original lines 533-534) is missing from this
// listing.  min(data, 1) / max(data, 1) are presumably Armadillo's per-row
// extrema, i.e. one point per column of `data` — confirm.
531 template<typename MetricType, typename ElemType>
532 template<typename MatType>
535 {
536  Log::Assert(data.n_rows == dim);
537 
538  arma::Col<ElemType> mins(min(data, 1));
539  arma::Col<ElemType> maxs(max(data, 1));
540 
541  minWidth = std::numeric_limits<ElemType>::max();
542  for (size_t i = 0; i < dim; ++i)
543  {
544  bounds[i] |= math::RangeType<ElemType>(mins[i], maxs[i]);
545  const ElemType width = bounds[i].Width();
546  if (width < minWidth)
547  minWidth = width;
548  }
549 
550  return *this;
551 }
552 
// Grows the bound so that it also covers `other`: each dimension's range is
// OR-ed with the matching range of `other`, and minWidth is recomputed from
// scratch as the smallest resulting width.  Returns *this.
// NOTE(review): the declarator (original lines 557-558) is missing from this
// listing.
// NOTE(review): this is the only dimension check in the file that uses plain
// assert() instead of Log::Assert() — confirm whether that is intentional.
556 template<typename MetricType, typename ElemType>
559 {
560  assert(other.dim == dim);
561 
562  minWidth = std::numeric_limits<ElemType>::max();
563  for (size_t i = 0; i < dim; ++i)
564  {
565  bounds[i] |= other.bounds[i];
566  const ElemType width = bounds[i].Width();
567  if (width < minWidth)
568  minWidth = width;
569  }
570 
571  return *this;
572 }
573 
// Returns true when, for every index i of `point`, bounds[i] contains
// point(i); returns false at the first dimension that does not.
// NOTE(review): the first declarator line (original line 579) is missing from
// this listing.
// NOTE(review): the loop runs to point.n_elem rather than dim and there is no
// size assertion, so a point with more elements than dim would index past the
// end of `bounds` — confirm callers guarantee matching sizes.
577 template<typename MetricType, typename ElemType>
578 template<typename VecType>
580  const VecType& point) const
581 {
582  for (size_t i = 0; i < point.n_elem; ++i)
583  {
584  if (!bounds[i].Contains(point(i)))
585  return false;
586  }
587 
588  return true;
589 }
590 
// Returns true only if, in every dimension, this bound's range and the given
// bound's range overlap with non-zero extent; ranges that merely touch at an
// endpoint count as not overlapping because the comparisons are non-strict.
// NOTE(review): the first declarator line (original line 595) is missing from
// this listing.  No check is made that bound.dim equals dim.
594 template<typename MetricType, typename ElemType>
596  const HRectBound& bound) const
597 {
598  for (size_t i = 0; i < dim; ++i)
599  {
600  const math::RangeType<ElemType>& r_a = bounds[i];
601  const math::RangeType<ElemType>& r_b = bound.bounds[i];
602 
603  // If a does not overlap b at all.
604  if (r_a.Hi() <= r_b.Lo() || r_a.Lo() >= r_b.Hi())
605  return false;
606  }
607 
608  return true;
609 }
610 
// Builds and returns the per-dimension intersection of this bound and
// `bound`: the lower edge is the larger of the two Lo values and the upper
// edge is the smaller of the two Hi values.
// NOTE(review): the declarator (original lines 615-616) and the line that
// declares `result` (original line 618) are missing from this listing.
// Nothing visible here sets result's minimum width — confirm.
614 template<typename MetricType, typename ElemType>
617 {
619 
620  for (size_t k = 0; k < dim; ++k)
621  {
622  result[k].Lo() = std::max(bounds[k].Lo(), bound.bounds[k].Lo());
623  result[k].Hi() = std::min(bounds[k].Hi(), bound.bounds[k].Hi());
624  }
625  return result;
626 }
627 
// Shrinks this bound in place to its per-dimension intersection with `bound`
// (Lo becomes the larger Lo, Hi the smaller Hi) and returns *this.
// NOTE(review): the declarator (original lines 632-633) is missing from this
// listing.
// NOTE(review): unlike the |= operators, minWidth is not recomputed after the
// ranges change — confirm whether it should be.
631 template<typename MetricType, typename ElemType>
634 {
635  for (size_t k = 0; k < dim; ++k)
636  {
637  bounds[k].Lo() = std::max(bounds[k].Lo(), bound.bounds[k].Lo());
638  bounds[k].Hi() = std::min(bounds[k].Hi(), bound.bounds[k].Hi());
639  }
640  return *this;
641 }
642 
// Returns the volume of the intersection of this bound with `bound`: the
// product over dimensions of (min(Hi) - max(Lo)).  Returns 0 as soon as any
// dimension has an empty or zero-width intersection.
// NOTE(review): the first declarator line (original line 647) is missing from
// this listing.
646 template<typename MetricType, typename ElemType>
648  const HRectBound& bound) const
649 {
650  ElemType volume = 1.0;
651 
652  for (size_t k = 0; k < dim; ++k)
653  {
654  ElemType lo = std::max(bounds[k].Lo(), bound.bounds[k].Lo());
655  ElemType hi = std::min(bounds[k].Hi(), bound.bounds[k].Hi());
656 
657  if ( hi <= lo)
658  return 0;
659 
660  volume *= hi - lo;
661  }
662  return volume;
663 }
664 
// Sums (Hi - Lo)^Power over all dimensions and, when MetricType::TakeRoot is
// set, returns the Power'th root of that sum; otherwise returns the sum
// itself.
// NOTE(review): the declarator (original line 669) is missing from this
// listing.
668 template<typename MetricType, typename ElemType>
670 {
671  ElemType d = 0;
672  for (size_t i = 0; i < dim; ++i)
673  d += std::pow(bounds[i].Hi() - bounds[i].Lo(),
674  (ElemType) MetricType::Power);
675 
676  if (MetricType::TakeRoot)
677  return (ElemType) std::pow((double) d, 1.0 / (double) MetricType::Power);
678  else
679  return d;
680 }
681 
// Passes the bound's state through the given archive: the range array
// (wrapped together with dim by CEREAL_POINTER_ARRAY), then minWidth, then
// metric.  The version argument is unused.
// NOTE(review): the first declarator line (original line 685) is missing from
// this listing.
683 template<typename MetricType, typename ElemType>
684 template<typename Archive>
686  Archive& ar,
687  const uint32_t /* version */)
688 {
689  // We can't serialize a raw array directly, so wrap it.
690  ar(CEREAL_POINTER_ARRAY(bounds, dim));
691  ar(CEREAL_NVP(minWidth));
692  ar(CEREAL_NVP(metric));
693 }
694 
695 } // namespace bound
696 } // namespace mlpack
697 
698 #endif // MLPACK_CORE_TREE_HRECTBOUND_IMPL_HPP
T Lo() const
Get the lower bound.
Definition: range.hpp:61
Bounds that are useful for binary space partitioning trees.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Hyper-rectangle bound for an L-metric.
Definition: hrectbound.hpp:54
HRectBound()
Empty constructor; creates a bound of dimensionality 0.
Definition: hrectbound_impl.hpp:27
T Hi() const
Get the upper bound.
Definition: range.hpp:66
size_t Dim() const
Gets the dimensionality.
Definition: hrectbound.hpp:96
void Center(const arma::mat &x, arma::mat &xCentered)
Creates a centered matrix, where centering is done by subtracting the sum over the columns (a column ...
Definition: lin_alg.cpp:43
T Width() const
Gets the span of the range (hi - lo).
Definition: range_impl.hpp:47
#define CEREAL_POINTER_ARRAY(T, S)
Cereal does not support the serialization of raw pointers.
Definition: array_wrapper.hpp:87
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35