mlpack
cellbound_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_CORE_TREE_CELLBOUND_IMPL_HPP
14 #define MLPACK_CORE_TREE_CELLBOUND_IMPL_HPP
15 
16 #include <math.h>
17 
18 // In case it has not been included yet.
19 #include "cellbound.hpp"
20 
21 namespace mlpack {
22 namespace bound {
23 
27 template<typename MetricType, typename ElemType>
28 inline CellBound<MetricType, ElemType>::CellBound() :
29  dim(0),
30  bounds(NULL),
31  loBound(arma::Mat<ElemType>()),
32  hiBound(arma::Mat<ElemType>()),
33  numBounds(0),
34  loAddress(arma::Col<AddressElemType>()),
35  hiAddress(arma::Col<AddressElemType>()),
36  minWidth(0)
37 { /* Nothing to do. */ }
38 
43 template<typename MetricType, typename ElemType>
44 inline CellBound<MetricType, ElemType>::CellBound(const size_t dimension) :
45  dim(dimension),
46  bounds(new math::RangeType<ElemType>[dim]),
47  loBound(arma::Mat<ElemType>(dim, maxNumBounds)),
48  hiBound(arma::Mat<ElemType>(dim, maxNumBounds)),
49  numBounds(0),
50  loAddress(dim),
51  hiAddress(dim),
52  minWidth(0)
53 {
54  for (size_t k = 0; k < dim ; ++k)
55  {
56  loAddress[k] = std::numeric_limits<AddressElemType>::max();
57  hiAddress[k] = 0;
58  }
59 }
60 
64 template<typename MetricType, typename ElemType>
65 inline CellBound<MetricType, ElemType>::CellBound(
66  const CellBound<MetricType, ElemType>& other) :
67  dim(other.Dim()),
68  bounds(new math::RangeType<ElemType>[dim]),
69  loBound(other.loBound),
70  hiBound(other.hiBound),
71  numBounds(other.numBounds),
72  loAddress(other.loAddress),
73  hiAddress(other.hiAddress),
74  minWidth(other.MinWidth())
75 {
76  // Copy other bounds over.
77  for (size_t i = 0; i < dim; ++i)
78  bounds[i] = other.bounds[i];
79 }
80 
84 template<typename MetricType, typename ElemType>
85 inline CellBound<
86  MetricType,
87  ElemType>& CellBound<MetricType, ElemType>::operator=(
88  const CellBound<MetricType, ElemType>& other)
89 {
90  if (this == &other)
91  return *this;
92 
93  if (dim != other.Dim())
94  {
95  // Reallocation is necessary.
96  delete[] bounds;
97 
98  dim = other.Dim();
99  bounds = new math::RangeType<ElemType>[dim];
100  }
101 
102  loBound = other.loBound;
103  hiBound = other.hiBound;
104  numBounds = other.numBounds;
105  loAddress = other.loAddress;
106  hiAddress = other.hiAddress;
107 
108  // Now copy each of the bound values.
109  for (size_t i = 0; i < dim; ++i)
110  bounds[i] = other.bounds[i];
111 
112  minWidth = other.MinWidth();
113 
114  return *this;
115 }
116 
120 template<typename MetricType, typename ElemType>
121 inline CellBound<MetricType, ElemType>::CellBound(
122  CellBound<MetricType, ElemType>&& other) :
123  dim(other.dim),
124  bounds(other.bounds),
125  loBound(std::move(other.loBound)),
126  hiBound(std::move(other.hiBound)),
127  numBounds(std::move(other.numBounds)),
128  loAddress(std::move(other.loAddress)),
129  hiAddress(std::move(other.hiAddress)),
130  minWidth(other.minWidth)
131 {
132  // Fix the other bound.
133  other.dim = 0;
134  other.bounds = NULL;
135  other.minWidth = 0.0;
136 }
137 
141 template<typename MetricType, typename ElemType>
142 inline CellBound<MetricType, ElemType>::~CellBound()
143 {
144  if (bounds)
145  delete[] bounds;
146 }
147 
151 template<typename MetricType, typename ElemType>
152 inline void CellBound<MetricType, ElemType>::Clear()
153 {
154  for (size_t k = 0; k < dim; ++k)
155  {
156  bounds[k] = math::RangeType<ElemType>();
157 
158  loAddress[k] = std::numeric_limits<AddressElemType>::max();
159  hiAddress[k] = 0;
160  }
161 
162  minWidth = 0;
163 }
164 
165 /***
166  * Calculates the centroid of the range, placing it into the given vector.
167  *
168  * @param centroid Vector which the centroid will be written to.
169  */
170 template<typename MetricType, typename ElemType>
172  arma::Col<ElemType>& center) const
173 {
174  // Set size correctly if necessary.
175  if (!(center.n_elem == dim))
176  center.set_size(dim);
177 
178  for (size_t i = 0; i < dim; ++i)
179  center(i) = bounds[i].Mid();
180 }
181 
182 template<typename MetricType, typename ElemType>
183 template<typename MatType>
184 void CellBound<MetricType, ElemType>::AddBound(
185  const arma::Col<ElemType>& loCorner,
186  const arma::Col<ElemType>& hiCorner,
187  const MatType& data)
188 {
189  assert(numBounds < loBound.n_cols);
190  assert(loBound.n_rows == dim);
191  assert(loCorner.n_elem == dim);
192  assert(hiCorner.n_elem == dim);
193 
194  for (size_t k = 0; k < dim; ++k)
195  {
196  loBound(k, numBounds) = std::numeric_limits<ElemType>::max();
197  hiBound(k, numBounds) = std::numeric_limits<ElemType>::lowest();
198  }
199 
200  for (size_t i = 0; i < data.n_cols; ++i)
201  {
202  size_t k = 0;
203  // Check if the point is contained in the hyperrectangle.
204  for (k = 0; k < dim; ++k)
205  if (data(k, i) < loCorner[k] || data(k, i) > hiCorner[k])
206  break;
207 
208  if (k < dim)
209  continue; // The point is not contained in the hyperrectangle.
210 
211  // Shrink the bound.
212  for (k = 0; k < dim; ++k)
213  {
214  loBound(k, numBounds) = std::min(loBound(k, numBounds), data(k, i));
215  hiBound(k, numBounds) = std::max(hiBound(k, numBounds), data(k, i));
216  }
217  }
218 
219  for (size_t k = 0; k < dim; ++k)
220  if (loBound(k, numBounds) > hiBound(k, numBounds))
221  return; // The hyperrectangle does not contain points.
222 
223  numBounds++;
224 }
225 
226 
227 template<typename MetricType, typename ElemType>
228 template<typename MatType>
229 void CellBound<MetricType, ElemType>::InitHighBound(size_t numEqualBits,
230  const MatType& data)
231 {
232  arma::Col<AddressElemType> tmpHiAddress(hiAddress);
233  arma::Col<AddressElemType> tmpLoAddress(hiAddress);
234  arma::Col<ElemType> loCorner(tmpHiAddress.n_elem);
235  arma::Col<ElemType> hiCorner(tmpHiAddress.n_elem);
236 
237  assert(tmpHiAddress.n_elem > 0);
238 
239  // We have to calculate the number of subrectangles since the maximum number
240  // of hyperrectangles is restricted.
241  size_t numCorners = 0;
242  for (size_t pos = numEqualBits + 1; pos < order * tmpHiAddress.n_elem; pos++)
243  {
244  size_t row = pos / order;
245  size_t bit = order - 1 - pos % order;
246 
247  // This hyperrectangle is not contained entirely in the bound.
248  // So, the number of hyperrectangles should be increased.
249  if (tmpHiAddress[row] & ((AddressElemType) 1 << bit))
250  numCorners++;
251 
252  // We ran out of the limit of hyperrectangles. In that case we enlare
253  // the last hyperrectangle.
254  if (numCorners >= maxNumBounds / 2)
255  tmpHiAddress[row] |= ((AddressElemType) 1 << bit);
256  }
257 
258  size_t pos = order * tmpHiAddress.n_elem - 1;
259 
260  // Find the last hyperrectangle and add it to the bound.
261  for ( ; pos > numEqualBits; pos--)
262  {
263  size_t row = pos / order;
264  size_t bit = order - 1 - pos % order;
265 
266  // All last bits after pos of tmpHiAddress are equal to 1 and
267  // All last bits of tmpLoAddress (after pos) are equal to 0.
268  // Thus, tmpHiAddress corresponds to the high corner of the enlarged
269  // rectangle and tmpLoAddress corresponds to the lower corner.
270  if (!(tmpHiAddress[row] & ((AddressElemType) 1 << bit)))
271  {
272  addr::AddressToPoint(loCorner, tmpLoAddress);
273  addr::AddressToPoint(hiCorner, tmpHiAddress);
274 
275  AddBound(loCorner, hiCorner, data);
276  break;
277  }
278  // Nullify the bit that corresponds to this step.
279  tmpLoAddress[row] &= ~((AddressElemType) 1 << bit);
280  }
281 
282  // Add the enlarged rectangle if we have not done that.
283  if (pos == numEqualBits)
284  {
285  addr::AddressToPoint(loCorner, tmpLoAddress);
286  addr::AddressToPoint(hiCorner, tmpHiAddress);
287 
288  AddBound(loCorner, hiCorner, data);
289  }
290 
291  for ( ; pos > numEqualBits; pos--)
292  {
293  size_t row = pos / order;
294  size_t bit = order - 1 - pos % order;
295 
296  // The lower bound should correspond to this step.
297  tmpLoAddress[row] &= ~((AddressElemType) 1 << bit);
298 
299  if (tmpHiAddress[row] & ((AddressElemType) 1 << bit))
300  {
301  // This hyperrectangle is contained entirely in the bound and do not
302  // overlap with other hyperrectangles since loAddress is less than
303  // tmpLoAddress and tmpHiAddress is less that the lower addresses
304  // of hyperrectangles that we have added previously.
305  tmpHiAddress[row] ^= (AddressElemType) 1 << bit;
306  addr::AddressToPoint(loCorner, tmpLoAddress);
307  addr::AddressToPoint(hiCorner, tmpHiAddress);
308 
309  AddBound(loCorner, hiCorner, data);
310  }
311  // The high bound should correspond to this step.
312  tmpHiAddress[row] |= ((AddressElemType) 1 << bit);
313  }
314 }
315 
316 template<typename MetricType, typename ElemType>
317 template<typename MatType>
318 void CellBound<MetricType, ElemType>::InitLowerBound(size_t numEqualBits,
319  const MatType& data)
320 {
321  arma::Col<AddressElemType> tmpHiAddress(loAddress);
322  arma::Col<AddressElemType> tmpLoAddress(loAddress);
323  arma::Col<ElemType> loCorner(tmpHiAddress.n_elem);
324  arma::Col<ElemType> hiCorner(tmpHiAddress.n_elem);
325 
326  // We have to calculate the number of subrectangles since the maximum number
327  // of hyperrectangles is restricted.
328  size_t numCorners = 0;
329  for (size_t pos = numEqualBits + 1; pos < order * tmpHiAddress.n_elem; pos++)
330  {
331  size_t row = pos / order;
332  size_t bit = order - 1 - pos % order;
333 
334  // This hyperrectangle is not contained entirely in the bound.
335  // So, the number of hyperrectangles should be increased.
336  if (!(tmpLoAddress[row] & ((AddressElemType) 1 << bit)))
337  numCorners++;
338 
339  // We ran out of the limit of hyperrectangles. In that case we enlare
340  // the last hyperrectangle.
341  if (numCorners >= maxNumBounds - numBounds)
342  tmpLoAddress[row] &= ~((AddressElemType) 1 << bit);
343  }
344 
345  size_t pos = order * tmpHiAddress.n_elem - 1;
346 
347  // Find the last hyperrectangle and add it to the bound.
348  for ( ; pos > numEqualBits; pos--)
349  {
350  size_t row = pos / order;
351  size_t bit = order - 1 - pos % order;
352 
353  // All last bits after pos of tmpHiAddress are equal to 1 and
354  // All last bits of tmpLoAddress (after pos) are equal to 0.
355  // Thus, tmpHiAddress corresponds to the high corner of the enlarged
356  // rectangle and tmpLoAddress corresponds to the lower corner.
357  if (tmpLoAddress[row] & ((AddressElemType) 1 << bit))
358  {
359  addr::AddressToPoint(loCorner, tmpLoAddress);
360  addr::AddressToPoint(hiCorner, tmpHiAddress);
361 
362  AddBound(loCorner, hiCorner, data);
363  break;
364  }
365  // Enlarge the hyperrectangle at this step since it is contained
366  // entirely in the bound.
367  tmpHiAddress[row] |= ((AddressElemType) 1 << bit);
368  }
369 
370  // Add the enlarged rectangle if we have not done that.
371  if (pos == numEqualBits)
372  {
373  addr::AddressToPoint(loCorner, tmpLoAddress);
374  addr::AddressToPoint(hiCorner, tmpHiAddress);
375 
376  AddBound(loCorner, hiCorner, data);
377  }
378 
379  for ( ; pos > numEqualBits; pos--)
380  {
381  size_t row = pos / order;
382  size_t bit = order - 1 - pos % order;
383 
384  // The high bound should correspond to this step.
385  tmpHiAddress[row] |= ((AddressElemType) 1 << bit);
386 
387  if (!(tmpLoAddress[row] & ((AddressElemType) 1 << bit)))
388  {
389  // This hyperrectangle is contained entirely in the bound and do not
390  // overlap with other hyperrectangles since hiAddress is greater than
391  // tmpHiAddress and tmpLoAddress is greater that the high addresses
392  // of hyperrectangles that we have added previously.
393  tmpLoAddress[row] ^= (AddressElemType) 1 << bit;
394 
395  addr::AddressToPoint(loCorner, tmpLoAddress);
396  addr::AddressToPoint(hiCorner, tmpHiAddress);
397 
398  AddBound(loCorner, hiCorner, data);
399  }
400 
401  // The lower bound should correspond to this step.
402  tmpLoAddress[row] &= ~((AddressElemType) 1 << bit);
403  }
404 }
405 
406 template<typename MetricType, typename ElemType>
407 template<typename MatType>
408 void CellBound<MetricType, ElemType>::UpdateAddressBounds(const MatType& data)
409 {
410  numBounds = 0;
411 
412  // Calculate the number of equal leading bits of the lower address and
413  // the high address.
414  size_t row = 0;
415  for ( ; row < hiAddress.n_elem; row++)
416  if (loAddress[row] != hiAddress[row])
417  break;
418 
419  // If the high address is equal to the lower address.
420  if (row == hiAddress.n_elem)
421  {
422  for (size_t i = 0; i < dim; ++i)
423  {
424  loBound(i, 0) = bounds[i].Lo();
425  hiBound(i, 0) = bounds[i].Hi();
426  }
427  numBounds = 1;
428 
429  return;
430  }
431 
432  size_t bit = 0;
433  for ( ; bit < order; bit++)
434  if ((loAddress[row] & ((AddressElemType) 1 << (order - 1 - bit))) !=
435  (hiAddress[row] & ((AddressElemType) 1 << (order - 1 - bit))))
436  break;
437 
438  if ((row == hiAddress.n_elem - 1) && (bit == order - 1))
439  {
440  // If the addresses differ in the last bit.
441  for (size_t i = 0; i < dim; ++i)
442  {
443  loBound(i, 0) = bounds[i].Lo();
444  hiBound(i, 0) = bounds[i].Hi();
445  }
446 
447  numBounds = 1;
448 
449  return;
450  }
451 
452  size_t numEqualBits = row * order + bit;
453  InitHighBound(numEqualBits, data);
454  InitLowerBound(numEqualBits, data);
455 
456  assert(numBounds <= maxNumBounds);
457 
458  if (numBounds == 0)
459  {
460  // I think this should never happen.
461  for (size_t i = 0; i < dim; ++i)
462  {
463  loBound(i, 0) = bounds[i].Lo();
464  hiBound(i, 0) = bounds[i].Hi();
465  }
466 
467  numBounds = 1;
468  }
469 }
470 
474 template<typename MetricType, typename ElemType>
475 template<typename VecType>
476 inline ElemType CellBound<MetricType, ElemType>::MinDistance(
477  const VecType& point,
478  typename std::enable_if_t<IsVector<VecType>::value>* /* junk */) const
479 {
480  Log::Assert(point.n_elem == dim);
481 
482  ElemType minSum = std::numeric_limits<ElemType>::max();
483 
484  ElemType lower, higher;
485 
486  for (size_t i = 0; i < numBounds; ++i)
487  {
488  ElemType sum = 0;
489 
490  for (size_t d = 0; d < dim; d++)
491  {
492  lower = loBound(d, i) - point[d];
493  higher = point[d] - hiBound(d, i);
494 
495  // Since only one of 'lower' or 'higher' is negative, if we add
496  // each's absolute value to itself and then sum those two, our
497  // result is the non negative half of the equation times two;
498  // then we raise to power Power.
499  if (MetricType::Power == 1)
500  sum += lower + std::fabs(lower) + higher + std::fabs(higher);
501  else if (MetricType::Power == 2)
502  {
503  ElemType dist = lower + std::fabs(lower) + higher + std::fabs(higher);
504  sum += dist * dist;
505  }
506  else
507  {
508  sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
509  (ElemType) MetricType::Power);
510  }
511 
512  if (sum >= minSum)
513  break;
514  }
515 
516  if (sum < minSum)
517  minSum = sum;
518  }
519 
520  // Now take the Power'th root (but make sure our result is squared if it needs
521  // to be); then cancel out the constant of 2 (which may have been squared now)
522  // that was introduced earlier. The compiler should optimize out the if
523  // statement entirely.
524  if (MetricType::Power == 1)
525  return minSum * 0.5;
526  else if (MetricType::Power == 2)
527  {
528  if (MetricType::TakeRoot)
529  return (ElemType) std::sqrt(minSum) * 0.5;
530  else
531  return minSum * 0.25;
532  }
533  else
534  {
535  if (MetricType::TakeRoot)
536  return (ElemType) pow((double) minSum,
537  1.0 / (double) MetricType::Power) / 2.0;
538  else
539  return minSum / pow(2.0, MetricType::Power);
540  }
541 }
542 
546 template<typename MetricType, typename ElemType>
547 ElemType CellBound<MetricType, ElemType>::MinDistance(const CellBound& other)
548  const
549 {
550  Log::Assert(dim == other.dim);
551 
552  ElemType minSum = std::numeric_limits<ElemType>::max();
553 
554  ElemType lower, higher;
555 
556  for (size_t i = 0; i < numBounds; ++i)
557  for (size_t j = 0; j < other.numBounds; ++j)
558  {
559  ElemType sum = 0;
560  for (size_t d = 0; d < dim; d++)
561  {
562  lower = other.loBound(d, j) - hiBound(d, i);
563  higher = loBound(d, i) - other.hiBound(d, j);
564  // We invoke the following:
565  // x + fabs(x) = max(x * 2, 0)
566  // (x * 2)^2 / 4 = x^2
567 
568  // The compiler should optimize out this if statement entirely.
569  if (MetricType::Power == 1)
570  sum += (lower + std::fabs(lower)) + (higher + std::fabs(higher));
571  else if (MetricType::Power == 2)
572  {
573  ElemType dist = lower + std::fabs(lower) + higher + std::fabs(higher);
574  sum += dist * dist;
575  }
576  else
577  {
578  sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
579  (ElemType) MetricType::Power);
580  }
581 
582  if (sum >= minSum)
583  break;
584  }
585 
586  if (sum < minSum)
587  minSum = sum;
588  }
589 
590  // The compiler should optimize out this if statement entirely.
591  if (MetricType::Power == 1)
592  return minSum * 0.5;
593  else if (MetricType::Power == 2)
594  {
595  if (MetricType::TakeRoot)
596  return (ElemType) std::sqrt(minSum) * 0.5;
597  else
598  return minSum * 0.25;
599  }
600  else
601  {
602  if (MetricType::TakeRoot)
603  return (ElemType) pow((double) minSum,
604  1.0 / (double) MetricType::Power) / 2.0;
605  else
606  return minSum / pow(2.0, MetricType::Power);
607  }
608 }
609 
613 template<typename MetricType, typename ElemType>
614 template<typename VecType>
615 inline ElemType CellBound<MetricType, ElemType>::MaxDistance(
616  const VecType& point,
617  typename std::enable_if_t<IsVector<VecType>::value>* /* junk */) const
618 {
619  ElemType maxSum = std::numeric_limits<ElemType>::lowest();
620 
621  Log::Assert(point.n_elem == dim);
622 
623  for (size_t i = 0; i < numBounds; ++i)
624  {
625  ElemType sum = 0;
626  for (size_t d = 0; d < dim; d++)
627  {
628  ElemType v = std::max(fabs(point[d] - loBound(d, i)),
629  fabs(hiBound(d, i) - point[d]));
630 
631  if (MetricType::Power == 1)
632  sum += v; // v is non-negative.
633  else if (MetricType::Power == 2)
634  sum += v * v;
635  else
636  sum += std::pow(v, (ElemType) MetricType::Power);
637  }
638 
639  if (sum > maxSum)
640  maxSum = sum;
641  }
642 
643  // The compiler should optimize out this if statement entirely.
644  if (MetricType::TakeRoot)
645  {
646  if (MetricType::Power == 1)
647  return maxSum;
648  else if (MetricType::Power == 2)
649  return (ElemType) std::sqrt(maxSum);
650  else
651  return (ElemType) pow((double) maxSum, 1.0 / (double) MetricType::Power);
652  }
653 
654  return maxSum;
655 }
656 
660 template<typename MetricType, typename ElemType>
661 inline ElemType CellBound<MetricType, ElemType>::MaxDistance(
662  const CellBound& other)
663  const
664 {
665  ElemType maxSum = std::numeric_limits<ElemType>::lowest();
666 
667  Log::Assert(dim == other.dim);
668 
669  ElemType v;
670  for (size_t i = 0; i < numBounds; ++i)
671  for (size_t j = 0; j < other.numBounds; ++j)
672  {
673  ElemType sum = 0;
674  for (size_t d = 0; d < dim; d++)
675  {
676  v = std::max(fabs(other.hiBound(d, j) - loBound(d, i)),
677  fabs(hiBound(d, i) - other.loBound(d, j)));
678 
679  // The compiler should optimize out this if statement entirely.
680  if (MetricType::Power == 1)
681  sum += v; // v is non-negative.
682  else if (MetricType::Power == 2)
683  sum += v * v;
684  else
685  sum += std::pow(v, (ElemType) MetricType::Power);
686  }
687 
688  if (sum > maxSum)
689  maxSum = sum;
690  }
691 
692  // The compiler should optimize out this if statement entirely.
693  if (MetricType::TakeRoot)
694  {
695  if (MetricType::Power == 1)
696  return maxSum;
697  else if (MetricType::Power == 2)
698  return (ElemType) std::sqrt(maxSum);
699  else
700  return (ElemType) pow((double) maxSum, 1.0 / (double) MetricType::Power);
701  }
702 
703  return maxSum;
704 }
705 
709 template<typename MetricType, typename ElemType>
710 inline math::RangeType<ElemType>
711 CellBound<MetricType, ElemType>::RangeDistance(
712  const CellBound& other) const
713 {
714  ElemType minLoSum = std::numeric_limits<ElemType>::max();
715  ElemType maxHiSum = std::numeric_limits<ElemType>::lowest();
716 
717  Log::Assert(dim == other.dim);
718 
719  ElemType v1, v2, vLo, vHi;
720 
721  for (size_t i = 0; i < numBounds; ++i)
722  for (size_t j = 0; j < other.numBounds; ++j)
723  {
724  ElemType loSum = 0;
725  ElemType hiSum = 0;
726  for (size_t d = 0; d < dim; d++)
727  {
728  v1 = other.loBound(d, j) - hiBound(d, i);
729  v2 = loBound(d, i) - other.hiBound(d, j);
730  // One of v1 or v2 is negative.
731  if (v1 >= v2)
732  {
733  vHi = -v2; // Make it nonnegative.
734  vLo = (v1 > 0) ? v1 : 0; // Force to be 0 if negative.
735  }
736  else
737  {
738  vHi = -v1; // Make it nonnegative.
739  vLo = (v2 > 0) ? v2 : 0; // Force to be 0 if negative.
740  }
741 
742  // The compiler should optimize out this if statement entirely.
743  if (MetricType::Power == 1)
744  {
745  loSum += vLo; // vLo is non-negative.
746  hiSum += vHi; // vHi is non-negative.
747  }
748  else if (MetricType::Power == 2)
749  {
750  loSum += vLo * vLo;
751  hiSum += vHi * vHi;
752  }
753  else
754  {
755  loSum += std::pow(vLo, (ElemType) MetricType::Power);
756  hiSum += std::pow(vHi, (ElemType) MetricType::Power);
757  }
758  }
759 
760  if (loSum < minLoSum)
761  minLoSum = loSum;
762  if (hiSum > maxHiSum)
763  maxHiSum = hiSum;
764  }
765 
766  if (MetricType::TakeRoot)
767  {
768  if (MetricType::Power == 1)
769  return math::RangeType<ElemType>(minLoSum, maxHiSum);
770  else if (MetricType::Power == 2)
771  return math::RangeType<ElemType>((ElemType) std::sqrt(minLoSum),
772  (ElemType) std::sqrt(maxHiSum));
773  else
774  {
775  return math::RangeType<ElemType>(
776  (ElemType) pow((double) minLoSum, 1.0 / (double) MetricType::Power),
777  (ElemType) pow((double) maxHiSum, 1.0 / (double) MetricType::Power));
778  }
779  }
780 
781  return math::RangeType<ElemType>(minLoSum, maxHiSum);
782 }
783 
787 template<typename MetricType, typename ElemType>
788 template<typename VecType>
789 inline math::RangeType<ElemType>
790 CellBound<MetricType, ElemType>::RangeDistance(
791  const VecType& point,
792  typename std::enable_if_t<IsVector<VecType>::value>* /* junk */) const
793 {
794  ElemType minLoSum = std::numeric_limits<ElemType>::max();
795  ElemType maxHiSum = std::numeric_limits<ElemType>::lowest();
796 
797  Log::Assert(point.n_elem == dim);
798 
799  ElemType v1, v2, vLo, vHi;
800  for (size_t i = 0; i < numBounds; ++i)
801  {
802  ElemType loSum = 0;
803  ElemType hiSum = 0;
804  for (size_t d = 0; d < dim; d++)
805  {
806  v1 = loBound(d, i) - point[d]; // Negative if point[d] > lo.
807  v2 = point[d] - hiBound(d, i); // Negative if point[d] < hi.
808 
809  // One of v1 or v2 (or both) is negative.
810  if (v1 >= 0) // point[d] <= bounds_[d].Lo().
811  {
812  vHi = -v2; // v2 will be larger but must be negated.
813  vLo = v1;
814  }
815  else // point[d] is between lo and hi, or greater than hi.
816  {
817  if (v2 >= 0)
818  {
819  vHi = -v1; // v1 will be larger, but must be negated.
820  vLo = v2;
821  }
822  else
823  {
824  vHi = -std::min(v1, v2); // Both are negative, but we need the larger.
825  vLo = 0;
826  }
827  }
828 
829  // The compiler should optimize out this if statement entirely.
830  if (MetricType::Power == 1)
831  {
832  loSum += vLo; // vLo is non-negative.
833  hiSum += vHi; // vHi is non-negative.
834  }
835  else if (MetricType::Power == 2)
836  {
837  loSum += vLo * vLo;
838  hiSum += vHi * vHi;
839  }
840  else
841  {
842  loSum += std::pow(vLo, (ElemType) MetricType::Power);
843  hiSum += std::pow(vHi, (ElemType) MetricType::Power);
844  }
845  }
846  if (loSum < minLoSum)
847  minLoSum = loSum;
848  if (hiSum > maxHiSum)
849  maxHiSum = hiSum;
850  }
851 
852  if (MetricType::TakeRoot)
853  {
854  if (MetricType::Power == 1)
855  return math::RangeType<ElemType>(minLoSum, maxHiSum);
856  else if (MetricType::Power == 2)
857  return math::RangeType<ElemType>((ElemType) std::sqrt(minLoSum),
858  (ElemType) std::sqrt(maxHiSum));
859  else
860  {
861  return math::RangeType<ElemType>(
862  (ElemType) pow((double) minLoSum, 1.0 / (double) MetricType::Power),
863  (ElemType) pow((double) maxHiSum, 1.0 / (double) MetricType::Power));
864  }
865  }
866 
867  return math::RangeType<ElemType>(minLoSum, maxHiSum);
868 }
869 
873 template<typename MetricType, typename ElemType>
874 template<typename MatType>
875 inline CellBound<MetricType, ElemType>&
876 CellBound<MetricType, ElemType>::operator|=(const MatType& data)
877 {
878  Log::Assert(data.n_rows == dim);
879 
880  arma::Col<ElemType> mins(arma::min(data, 1));
881  arma::Col<ElemType> maxs(arma::max(data, 1));
882 
883  minWidth = std::numeric_limits<ElemType>::max();
884  for (size_t i = 0; i < dim; ++i)
885  {
886  bounds[i] |= math::RangeType<ElemType>(mins[i], maxs[i]);
887  const ElemType width = bounds[i].Width();
888  if (width < minWidth)
889  minWidth = width;
890 
891  loBound(i, 0) = bounds[i].Lo();
892  hiBound(i, 0) = bounds[i].Hi();
893  }
894 
895  numBounds = 1;
896 
897  return *this;
898 }
899 
903 template<typename MetricType, typename ElemType>
904 inline CellBound<MetricType, ElemType>&
905 CellBound<MetricType, ElemType>::operator|=(const CellBound& other)
906 {
907  assert(other.dim == dim);
908 
909  minWidth = std::numeric_limits<ElemType>::max();
910  for (size_t i = 0; i < dim; ++i)
911  {
912  bounds[i] |= other.bounds[i];
913  const ElemType width = bounds[i].Width();
914  if (width < minWidth)
915  minWidth = width;
916  }
917 
918  if (addr::CompareAddresses(other.loAddress, loAddress) < 0)
919  loAddress = other.loAddress;
920 
921  if (addr::CompareAddresses(other.hiAddress, hiAddress) > 0)
922  hiAddress = other.hiAddress;
923 
924  if (loAddress[0] > hiAddress[0])
925  {
926  for (size_t i = 0; i < dim; ++i)
927  {
928  loBound(i, 0) = bounds[i].Lo();
929  hiBound(i, 0) = bounds[i].Hi();
930  }
931 
932  numBounds = 1;
933  }
934 
935  return *this;
936 }
937 
941 template<typename MetricType, typename ElemType>
942 template<typename VecType>
943 inline bool CellBound<MetricType, ElemType>::Contains(
944  const VecType& point) const
945 {
946  for (size_t i = 0; i < point.n_elem; ++i)
947  {
948  if (!bounds[i].Contains(point(i)))
949  return false;
950  }
951 
952  if (loAddress[0] > hiAddress[0])
953  return true;
954 
955  arma::Col<AddressElemType> address(dim);
956 
957  addr::PointToAddress(address, point);
958 
959  return addr::Contains(address, loAddress, hiAddress);
960 }
961 
962 
966 template<typename MetricType, typename ElemType>
967 inline ElemType CellBound<MetricType, ElemType>::Diameter() const
968 {
969  ElemType d = 0;
970  for (size_t i = 0; i < dim; ++i)
971  d += std::pow(bounds[i].Hi() - bounds[i].Lo(),
972  (ElemType) MetricType::Power);
973 
974  if (MetricType::TakeRoot)
975  return (ElemType) std::pow((double) d, 1.0 / (double) MetricType::Power);
976 
977  return d;
978 }
979 
981 template<typename MetricType, typename ElemType>
982 template<typename Archive>
983 void CellBound<MetricType, ElemType>::serialize(
984  Archive& ar,
985  const uint32_t /* version */)
986 {
987  ar(CEREAL_POINTER_ARRAY(bounds, dim));
988  ar(CEREAL_NVP(minWidth));
989  ar(CEREAL_NVP(loBound));
990  ar(CEREAL_NVP(hiBound));
991  ar(CEREAL_NVP(numBounds));
992  ar(CEREAL_NVP(loAddress));
993  ar(CEREAL_NVP(hiAddress));
994  ar(CEREAL_NVP(metric));
995 }
996 
997 } // namespace bound
998 } // namespace mlpack
999 
1000 #endif // MLPACK_CORE_TREE_CELLBOUND_IMPL_HPP
1001 
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: pointer_wrapper.hpp:23
void Center(const arma::mat &x, arma::mat &xCentered)
Creates a centered matrix, where centering is done by subtracting the sum over the columns (a column ...
Definition: lin_alg.cpp:43
#define CEREAL_POINTER_ARRAY(T, S)
Cereal does not support the serialization of raw pointer.
Definition: array_wrapper.hpp:87
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38
bool Contains(const AddressType1 &address, const AddressType2 &loBound, const AddressType3 &hiBound)
Returns true if an address is contained between two other addresses.
Definition: address.hpp:256