mlpack
dtree_impl.hpp
Go to the documentation of this file.
1 
14 #include "dtree.hpp"
15 #include <stack>
16 #include <vector>
17 
18 using namespace mlpack;
19 using namespace det;
20 
21 namespace details
22 {
23 
28 template<typename ElemType, typename MatType>
29 void ExtractSplits(std::vector<std::pair<ElemType, size_t>>& splitVec,
30  const MatType& data,
31  size_t dim,
32  const size_t start,
33  const size_t end,
34  const size_t minLeafSize)
35 {
36  static_assert(
37  std::is_same<typename MatType::elem_type, ElemType>::value == true,
38  "The ElemType does not correspond to the matrix's element type.");
39 
40  typedef std::pair<ElemType, size_t> SplitItem;
41  const typename MatType::row_type dimVec =
42  arma::sort(data(dim, arma::span(start, end - 1)));
43 
44  // Ensure the minimum leaf size on both sides. We need to figure out why there
45  // are spikes if this minLeafSize is enforced here...
46  for (size_t i = minLeafSize - 1; i < dimVec.n_elem - minLeafSize; ++i)
47  {
48  // This makes sense for real continuous data. This kinda corrupts the data
49  // and estimation if the data is ordinal. Potentially we can fix that by
50  // taking into account ordinality later in the min/max update, but then we
51  // can end-up with a zero-volumed dimension. No good.
52  const ElemType split = (dimVec[i] + dimVec[i + 1]) / 2.0;
53 
54  // Check if we can split here (two points are different)
55  if (split != dimVec[i])
56  splitVec.push_back(SplitItem(split, i + 1));
57  }
58 }
59 
60 // Now the custom arma::Mat implementation.
61 template<typename ElemType>
62 void ExtractSplits(std::vector<std::pair<ElemType, size_t>>& splitVec,
63  const arma::Mat<ElemType>& data,
64  size_t dim,
65  const size_t start,
66  const size_t end,
67  const size_t minLeafSize)
68 {
69  typedef std::pair<ElemType, size_t> SplitItem;
70  arma::rowvec dimVec = data(dim, arma::span(start, end - 1));
71 
72  // We sort these, in-place (it's a copy of the data, anyways).
73  std::sort(dimVec.begin(), dimVec.end());
74 
75  for (size_t i = minLeafSize - 1; i < dimVec.n_elem - minLeafSize; ++i)
76  {
77  // This makes sense for real continuous data. This kinda corrupts the data
78  // and estimation if the data is ordinal. Potentially we can fix that by
79  // taking into account ordinality later in the min/max update, but then we
80  // can end-up with a zero-volumed dimension. No good.
81  const ElemType split = (dimVec[i] + dimVec[i + 1]) / 2.0;
82 
83  if (split != dimVec[i])
84  splitVec.push_back(SplitItem(split, i + 1));
85  }
86 }
87 
88 // This the custom, sparse optimized implementation of the same routine.
89 template<typename ElemType>
90 void ExtractSplits(std::vector<std::pair<ElemType, size_t>>& splitVec,
91  const arma::SpMat<ElemType>& data,
92  size_t dim,
93  const size_t start,
94  const size_t end,
95  const size_t minLeafSize)
96 {
97  // It's common sense, but we also use it in a check later.
98  Log::Assert(minLeafSize > 0);
99 
100  typedef std::pair<ElemType, size_t> SplitItem;
101  const size_t n_elem = end - start;
102 
103  // Construct a vector of values.
104  const arma::SpRow<ElemType> row = data(dim, arma::span(start, end - 1));
105  std::vector<ElemType> valsVec(row.begin(), row.end());
106 
107  // ... and sort it!
108  std::sort(valsVec.begin(), valsVec.end());
109 
110  // Now iterate over the values, taking account for the over-the-zeroes jump
111  // and construct the splits vector.
112  const size_t zeroes = n_elem - valsVec.size();
113  ElemType lastVal = -std::numeric_limits<ElemType>::max();
114  size_t padding = 0;
115 
116  for (size_t i = 0; i < valsVec.size(); ++i)
117  {
118  const ElemType newVal = valsVec[i];
119  if (lastVal < ElemType(0) && newVal > ElemType(0) && zeroes > 0)
120  {
121  Log::Assert(padding == 0); // We should arrive here once!
122 
123  // The minLeafSize > 0 also guarantees we're not entering right at the
124  // start.
125  if (i >= minLeafSize && i <= n_elem - minLeafSize)
126  splitVec.push_back(SplitItem(lastVal / 2.0, i));
127 
128  padding = zeroes;
129  lastVal = ElemType(0);
130  }
131 
132  // This is the normal case.
133  if (i + padding >= minLeafSize && i + padding <= n_elem - minLeafSize)
134  {
135  // This makes sense for real continuous data. This kinda corrupts the
136  // data and estimation if the data is ordinal. Potentially we can fix that
137  // by taking into account ordinality later in the min/max update, but then
138  // we can end-up with a zero-volumed dimension. No good.
139  const ElemType split = (lastVal + newVal) / 2.0;
140 
141  // Check if we can split here (two points are different)
142  if (split != newVal)
143  splitVec.push_back(SplitItem(split, i + padding));
144  }
145 
146  lastVal = newVal;
147  }
148 }
149 
150 } // namespace details
151 
152 template<typename MatType, typename TagType>
154  start(0),
155  end(0),
156  splitDim(size_t(-1)),
157  splitValue(std::numeric_limits<ElemType>::max()),
158  logNegError(-DBL_MAX),
159  subtreeLeavesLogNegError(-DBL_MAX),
160  subtreeLeaves(0),
161  root(true),
162  ratio(1.0),
163  logVolume(-DBL_MAX),
164  bucketTag(-1),
165  alphaUpper(0.0),
166  left(NULL),
167  right(NULL)
168 { /* Nothing to do. */ }
169 
170 template<typename MatType, typename TagType>
172  start(obj.start),
173  end(obj.end),
174  maxVals(obj.maxVals),
175  minVals(obj.minVals),
176  splitDim(obj.splitDim),
177  splitValue(obj.splitValue),
178  logNegError(obj.logNegError),
179  subtreeLeavesLogNegError(obj.subtreeLeavesLogNegError),
180  subtreeLeaves(obj.subtreeLeaves),
181  root(obj.root),
182  ratio(obj.ratio),
183  logVolume(obj.logVolume),
184  bucketTag(obj.bucketTag),
185  alphaUpper(obj.alphaUpper),
186  left((obj.left == NULL) ? NULL : new DTree(*obj.left)),
187  right((obj.right == NULL) ? NULL : new DTree(*obj.right))
188 {
189  /* Nothing to do. */
190 }
191 
192 template<typename MatType, typename TagType>
194  const DTree<MatType, TagType>& obj)
195 {
196  if (this == &obj)
197  return *this;
198 
199  // Copy the values from the other tree.
200  start = obj.start;
201  end = obj.end;
202  maxVals = obj.maxVals;
203  minVals = obj.minVals;
204  splitDim = obj.splitDim;
205  splitValue = obj.splitValue;
206  logNegError = obj.logNegError;
207  subtreeLeavesLogNegError = obj.subtreeLeavesLogNegError;
208  subtreeLeaves = obj.subtreeLeaves;
209  root = obj.root;
210  ratio = obj.ratio;
211  logVolume = obj.logVolume;
212  bucketTag = obj.bucketTag;
213  alphaUpper = obj.alphaUpper;
214 
215  // Free the space allocated.
216  delete left;
217  delete right;
218 
219  // Copy the children.
220  left = ((obj.left == NULL) ? NULL : new DTree(*obj.left));
221  right = ((obj.right == NULL) ? NULL : new DTree(*obj.right));
222 
223  return *this;
224 }
225 
226 template<typename MatType, typename TagType>
228  start(obj.start),
229  end(obj.end),
230  maxVals(std::move(obj.maxVals)),
231  minVals(std::move(obj.minVals)),
232  splitDim(obj.splitDim),
233  splitValue(std::move(obj.splitValue)),
234  logNegError(obj.logNegError),
235  subtreeLeavesLogNegError(obj.subtreeLeavesLogNegError),
236  subtreeLeaves(obj.subtreeLeaves),
237  root(obj.root),
238  ratio(obj.ratio),
239  logVolume(obj.logVolume),
240  bucketTag(std::move(obj.bucketTag)),
241  alphaUpper(obj.alphaUpper),
242  left(obj.left),
243  right(obj.right)
244 {
245  // Set obj to default values.
246  obj.start = 0;
247  obj.end = 0;
248  obj.splitDim = size_t(-1);
249  obj.splitValue = std::numeric_limits<ElemType>::max();
250  obj.logNegError = -DBL_MAX;
251  obj.subtreeLeavesLogNegError = -DBL_MAX;
252  obj.subtreeLeaves = 0;
253  obj.root = true;
254  obj.ratio = 1.0;
255  obj.logVolume = -DBL_MAX;
256  obj.bucketTag = -1;
257  obj.alphaUpper = 0.0;
258  obj.left = NULL;
259  obj.right = NULL;
260 }
261 
262 template<typename MatType, typename TagType>
265 {
266  if (this == &obj)
267  return *this;
268 
269  // Move the values from the other tree.
270  start = obj.start;
271  end = obj.end;
272  splitDim = obj.splitDim;
273  logNegError = obj.logNegError;
274  subtreeLeavesLogNegError = obj.subtreeLeavesLogNegError;
275  subtreeLeaves = obj.subtreeLeaves;
276  root = obj.root;
277  ratio = obj.ratio;
278  logVolume = obj.logVolume;
279  alphaUpper = obj.alphaUpper;
280  maxVals = std::move(obj.maxVals);
281  minVals = std::move(obj.minVals);
282  splitValue = std::move(obj.splitValue);
283  bucketTag = std::move(obj.bucketTag);
284 
285  // Free the space allocated.
286  delete left;
287  delete right;
288 
289  // Move children.
290  left = obj.left;
291  right = obj.right;
292 
293  // Set obj to default values.
294  obj.start = 0;
295  obj.end = 0;
296  obj.splitDim = size_t(-1);
297  obj.splitValue = std::numeric_limits<ElemType>::max();
298  obj.logNegError = -DBL_MAX;
299  obj.subtreeLeavesLogNegError = -DBL_MAX;
300  obj.subtreeLeaves = 0;
301  obj.root = true;
302  obj.ratio = 1.0;
303  obj.logVolume = -DBL_MAX;
304  obj.bucketTag = -1;
305  obj.alphaUpper = 0.0;
306  obj.left = NULL;
307  obj.right = NULL;
308 
309  return *this;
310 }
311 
312 
313 // Root node initializers.
314 template<typename MatType, typename TagType>
316  const StatType& minVals,
317  const size_t totalPoints) :
318  start(0),
319  end(totalPoints),
320  maxVals(maxVals),
321  minVals(minVals),
322  splitDim(size_t(-1)),
323  splitValue(std::numeric_limits<ElemType>::max()),
324  logNegError(LogNegativeError(totalPoints)),
325  subtreeLeavesLogNegError(-DBL_MAX),
326  subtreeLeaves(0),
327  root(true),
328  ratio(1.0),
329  logVolume(-DBL_MAX),
330  bucketTag(-1),
331  alphaUpper(0.0),
332  left(NULL),
333  right(NULL)
334 { /* Nothing to do. */ }
335 
336 template<typename MatType, typename TagType>
338  start(0),
339  end(data.n_cols),
340  maxVals(arma::max(data, 1)),
341  minVals(arma::min(data, 1)),
342  splitDim(size_t(-1)),
343  splitValue(std::numeric_limits<ElemType>::max()),
344  subtreeLeavesLogNegError(-DBL_MAX),
345  subtreeLeaves(0),
346  root(true),
347  ratio(1.0),
348  logVolume(-DBL_MAX),
349  bucketTag(-1),
350  alphaUpper(0.0),
351  left(NULL),
352  right(NULL)
353 {
354  logNegError = LogNegativeError(data.n_cols);
355 }
356 
357 // Non-root node initializers.
358 template<typename MatType, typename TagType>
360  const StatType& minVals,
361  const size_t start,
362  const size_t end,
363  const double logNegError) :
364  start(start),
365  end(end),
366  maxVals(maxVals),
367  minVals(minVals),
368  splitDim(size_t(-1)),
369  splitValue(std::numeric_limits<ElemType>::max()),
370  logNegError(logNegError),
371  subtreeLeavesLogNegError(-DBL_MAX),
372  subtreeLeaves(0),
373  root(false),
374  ratio(1.0),
375  logVolume(-DBL_MAX),
376  bucketTag(-1),
377  alphaUpper(0.0),
378  left(NULL),
379  right(NULL)
380 { /* Nothing to do. */ }
381 
382 template<typename MatType, typename TagType>
384  const StatType& minVals,
385  const size_t totalPoints,
386  const size_t start,
387  const size_t end) :
388  start(start),
389  end(end),
390  maxVals(maxVals),
391  minVals(minVals),
392  splitDim(size_t(-1)),
393  splitValue(std::numeric_limits<ElemType>::max()),
394  logNegError(LogNegativeError(totalPoints)),
395  subtreeLeavesLogNegError(-DBL_MAX),
396  subtreeLeaves(0),
397  root(false),
398  ratio(1.0),
399  logVolume(-DBL_MAX),
400  bucketTag(-1),
401  alphaUpper(0.0),
402  left(NULL),
403  right(NULL)
404 { /* Nothing to do. */ }
405 
406 template<typename MatType, typename TagType>
408 {
409  delete left;
410  delete right;
411 }
412 
413 // This function computes the log-l2-negative-error of a given node from the
414 // formula R(t) = log(|t|^2 / (N^2 V_t)).
415 template<typename MatType, typename TagType>
416 double DTree<MatType, TagType>::LogNegativeError(const size_t totalPoints) const
417 {
418  // log(-|t|^2 / (N^2 V_t)) = log(-1) + 2 log(|t|) - 2 log(N) - log(V_t).
419  double err = 2 * std::log((double) (end - start)) -
420  2 * std::log((double) totalPoints);
421 
422  StatType valDiffs = maxVals - minVals;
423  for (size_t i = 0; i < valDiffs.n_elem; ++i)
424  {
425  // Ignore very small dimensions to prevent overflow.
426  if (valDiffs[i] > 1e-50)
427  err -= std::log(valDiffs[i]);
428  }
429 
430  return err;
431 }
432 
433 // This function finds the best split with respect to the L2-error, by trying
434 // all possible splits. The dataset is the full data set but the start and
435 // end are used to obtain the point in this node.
436 template<typename MatType, typename TagType>
437 bool DTree<MatType, TagType>::FindSplit(const MatType& data,
438  size_t& splitDim,
439  ElemType& splitValue,
440  double& leftError,
441  double& rightError,
442  const size_t minLeafSize) const
443 {
444  typedef std::pair<ElemType, size_t> SplitItem;
445 
446  // Ensure the dimensionality of the data is the same as the dimensionality of
447  // the bounding rectangle.
448  Log::Assert(data.n_rows == maxVals.n_elem);
449  Log::Assert(data.n_rows == minVals.n_elem);
450 
451  const size_t points = end - start;
452 
453  double minError = logNegError;
454  bool splitFound = false;
455 
456  // Loop through each dimension.
457 #ifdef _WIN32
458  #pragma omp parallel for default(shared)
459  for (intmax_t dim = 0; dim < (intmax_t) maxVals.n_elem; ++dim)
460 #else
461  #pragma omp parallel for default(shared)
462  for (size_t dim = 0; dim < maxVals.n_elem; ++dim)
463 #endif
464  {
465  const ElemType min = minVals[dim];
466  const ElemType max = maxVals[dim];
467 
468  // If there is nothing to split in this dimension, move on.
469  if (max - min == 0.0)
470  continue; // Skip to next dimension.
471 
472  // Find the log volume of all the other dimensions.
473  const double volumeWithoutDim = logVolume - std::log(max - min);
474 
475  // Initializing all other stuff for this dimension.
476  bool dimSplitFound = false;
477  // Take an error estimate for this dimension.
478  double minDimError = std::pow(points, 2.0) / (max - min);
479  double dimLeftError = 0.0; // For -Wuninitialized. These variables will
480  double dimRightError = 0.0; // always be set to something else before use.
481  ElemType dimSplitValue = 0.0;
482 
483  // Get the values for splitting. The old implementation:
484  // dimVec = data.row(dim).subvec(start, end - 1);
485  // dimVec = arma::sort(dimVec);
486  // could be quite inefficient for sparse matrices, due to
487  // copy operations (3). This one has custom implementation for dense and
488  // sparse matrices.
489 
490  std::vector<SplitItem> splitVec;
491  details::ExtractSplits<ElemType>(splitVec, data, dim, start, end,
492  minLeafSize);
493 
494  // Iterate on all the splits for this dimension
495  for (typename std::vector<SplitItem>::iterator i = splitVec.begin();
496  i != splitVec.end();
497  ++i)
498  {
499  const ElemType split = i->first;
500  const size_t position = i->second;
501 
502  // Another way of picking split is using this:
503  // split = leftsplit;
504  if ((split - min > 0.0) && (max - split > 0.0))
505  {
506  // Ensure that the right node will have at least the minimum number of
507  // points.
508  Log::Assert((points - position) >= minLeafSize);
509 
510  // Now we have to see if the error will be reduced. Simple manipulation
511  // of the error function gives us the condition we must satisfy:
512  // |t_l|^2 / V_l + |t_r|^2 / V_r >= |t|^2 / (V_l + V_r)
513  // and because the volume is only dependent on the dimension we are
514  // splitting, we can assume V_l is just the range of the left and V_r is
515  // just the range of the right.
516  double negLeftError = std::pow(position, 2.0) / (split - min);
517  double negRightError = std::pow(points - position, 2.0) / (max - split);
518 
519  // If this is better, take it.
520  if ((negLeftError + negRightError) >= minDimError)
521  {
522  minDimError = negLeftError + negRightError;
523  dimLeftError = negLeftError;
524  dimRightError = negRightError;
525  dimSplitValue = split;
526  dimSplitFound = true;
527  }
528  }
529  }
530 
531  const double actualMinDimError = std::log(minDimError)
532  - 2 * std::log((double) data.n_cols)
533  - volumeWithoutDim;
534 
535 #pragma omp critical(DTreeFindUpdate)
536  if ((actualMinDimError > minError) && dimSplitFound)
537  {
538  // Calculate actual error (in logspace) by adding terms back to our
539  // estimate.
540  minError = actualMinDimError;
541  splitDim = dim;
542  splitValue = dimSplitValue;
543  leftError = std::log(dimLeftError) - 2 * std::log((double) data.n_cols)
544  - volumeWithoutDim;
545  rightError = std::log(dimRightError) - 2 * std::log((double) data.n_cols)
546  - volumeWithoutDim;
547  splitFound = true;
548  } // end if better split found in this dimension.
549  }
550 
551  return splitFound;
552 }
553 
554 template<typename MatType, typename TagType>
556  const size_t splitDim,
557  const ElemType splitValue,
558  arma::Col<size_t>& oldFromNew) const
559 {
560  // Swap all columns such that any columns with value in dimension splitDim
561  // less than or equal to splitValue are on the left side, and all others are
562  // on the right side. A similar sort to this is also performed in
563  // BinarySpaceTree construction (its comments are more detailed).
564  size_t left = start;
565  size_t right = end - 1;
566  for (;;)
567  {
568  while (data(splitDim, left) <= splitValue)
569  ++left;
570  while (data(splitDim, right) > splitValue)
571  --right;
572 
573  if (left > right)
574  break;
575 
576  data.swap_cols(left, right);
577 
578  // Store the mapping from old to new. Do not put std::swap here...
579  const size_t tmp = oldFromNew[left];
580  oldFromNew[left] = oldFromNew[right];
581  oldFromNew[right] = tmp;
582  }
583 
584  // This now refers to the first index of the "right" side.
585  return left;
586 }
587 
588 // Greedily expand the tree.
589 template<typename MatType, typename TagType>
591  arma::Col<size_t>& oldFromNew,
592  const bool useVolReg,
593  const size_t maxLeafSize,
594  const size_t minLeafSize)
595 {
596  Log::Assert(data.n_rows == maxVals.n_elem);
597  Log::Assert(data.n_rows == minVals.n_elem);
598 
599  double leftG, rightG;
600 
601  // Compute points ratio.
602  ratio = (double) (end - start) / (double) oldFromNew.n_elem;
603 
604  // Compute the log of the volume of the node.
605  logVolume = 0;
606  for (size_t i = 0; i < maxVals.n_elem; ++i)
607  if (maxVals[i] - minVals[i] > 0.0)
608  logVolume += std::log(maxVals[i] - minVals[i]);
609 
610  // Check if node is large enough to split.
611  if ((size_t) (end - start) > maxLeafSize)
612  {
613  // Find the split.
614  size_t dim;
615  double splitValueTmp;
616  double leftError, rightError;
617  if (FindSplit(data, dim, splitValueTmp, leftError, rightError, minLeafSize))
618  {
619  // Move the data around for the children to have points in a node lie
620  // contiguously (to increase efficiency during the training).
621  const size_t splitIndex = SplitData(data, dim, splitValueTmp, oldFromNew);
622 
623  // Make max and min vals for the children.
624  StatType maxValsL(maxVals);
625  StatType maxValsR(maxVals);
626  StatType minValsL(minVals);
627  StatType minValsR(minVals);
628 
629  maxValsL[dim] = splitValueTmp;
630  minValsR[dim] = splitValueTmp;
631 
632  // Store split dim and split val in the node.
633  splitValue = splitValueTmp;
634  splitDim = dim;
635 
636  // Recursively grow the children.
637  left = new DTree(maxValsL, minValsL, start, splitIndex, leftError);
638  right = new DTree(maxValsR, minValsR, splitIndex, end, rightError);
639 
640  leftG = left->Grow(data, oldFromNew, useVolReg, maxLeafSize,
641  minLeafSize);
642  rightG = right->Grow(data, oldFromNew, useVolReg, maxLeafSize,
643  minLeafSize);
644 
645  // Store values of R(T~) and |T~|.
646  subtreeLeaves = left->SubtreeLeaves() + right->SubtreeLeaves();
647 
648  // Find the log negative error of the subtree leaves. This is kind of an
649  // odd one because we don't want to represent the error in non-log-space,
650  // but we have to calculate log(E_l + E_r). So we multiply E_l and E_r by
651  // V_t (remember E_l has an inverse relationship to the volume of the
652  // nodes) and then subtract log(V_t) at the end of the whole expression.
653  // As a result we do leave log-space, but the largest quantity we
654  // represent is on the order of (V_t / V_i) where V_i is the smallest leaf
655  // node below this node, which depends heavily on the depth of the tree.
656  subtreeLeavesLogNegError = std::log(
657  std::exp(logVolume + left->SubtreeLeavesLogNegError()) +
658  std::exp(logVolume + right->SubtreeLeavesLogNegError()))
659  - logVolume;
660  }
661  else
662  {
663  // No split found so make a leaf out of it.
664  subtreeLeaves = 1;
665  subtreeLeavesLogNegError = logNegError;
666  }
667  }
668  else
669  {
670  // We can make this a leaf node.
671  Log::Assert((size_t) (end - start) >= minLeafSize);
672  subtreeLeaves = 1;
673  subtreeLeavesLogNegError = logNegError;
674  }
675 
676  // If this is a leaf, do not compute g_k(t); otherwise compute, store, and
677  // propagate min(g_k(t_L), g_k(t_R), g_k(t)), unless t_L and/or t_R are
678  // leaves.
679  if (subtreeLeaves == 1)
680  {
681  return std::numeric_limits<double>::max();
682  }
683  else
684  {
685  const double range = maxVals[splitDim] - minVals[splitDim];
686  const double leftRatio = (splitValue - minVals[splitDim]) / range;
687  const double rightRatio = (maxVals[splitDim] - splitValue) / range;
688 
689  const size_t leftPow = std::pow((double) (left->End() - left->Start()), 2);
690  const size_t rightPow = std::pow((double) (right->End() - right->Start()),
691  2);
692  const size_t thisPow = std::pow((double) (end - start), 2);
693 
694  double tmpAlphaSum = leftPow / leftRatio + rightPow / rightRatio - thisPow;
695 
696  if (left->SubtreeLeaves() > 1)
697  {
698  const double exponent = 2 * std::log((double) data.n_cols) + logVolume +
699  left->AlphaUpper();
700 
701  // Whether or not this will overflow is highly dependent on the depth of
702  // the tree.
703  tmpAlphaSum += std::exp(exponent);
704  }
705 
706  if (right->SubtreeLeaves() > 1)
707  {
708  const double exponent = 2 * std::log((double) data.n_cols)
709  + logVolume
710  + right->AlphaUpper();
711 
712  tmpAlphaSum += std::exp(exponent);
713  }
714 
715  alphaUpper = std::log(tmpAlphaSum) - 2 * std::log((double) data.n_cols)
716  - logVolume;
717 
718  double gT;
719  if (useVolReg)
720  {
721  // This is wrong for now!
722  gT = alphaUpper; // / (subtreeLeavesVTInv - vTInv);
723  }
724  else
725  {
726  gT = alphaUpper - std::log((double) (subtreeLeaves - 1));
727  }
728 
729  return std::min(gT, std::min(leftG, rightG));
730  }
731 
732  // We need to compute (c_t^2) * r_t for all subtree leaves; this is equal to
733  // n_t ^ 2 / r_t * n ^ 2 = -error. Therefore the value we need is actually
734  // -1.0 * subtreeLeavesError.
735 }
736 
737 
738 template<typename MatType, typename TagType>
739 double DTree<MatType, TagType>::PruneAndUpdate(const double oldAlpha,
740  const size_t points,
741  const bool useVolReg)
742 {
743  // Compute gT.
744  if (subtreeLeaves == 1) // If we are a leaf...
745  {
746  return std::numeric_limits<double>::max();
747  }
748  else
749  {
750  // Compute gT value for node t.
751  volatile double gT;
752  if (useVolReg)
753  gT = alphaUpper; // - std::log(subtreeLeavesVTInv - vTInv);
754  else
755  gT = alphaUpper - std::log((double) (subtreeLeaves - 1));
756 
757  if (gT > oldAlpha)
758  {
759  // Go down the tree and update accordingly. Traverse the children.
760  double leftG = left->PruneAndUpdate(oldAlpha, points, useVolReg);
761  double rightG = right->PruneAndUpdate(oldAlpha, points, useVolReg);
762 
763  // Update values.
764  subtreeLeaves = left->SubtreeLeaves() + right->SubtreeLeaves();
765 
766  // Find the log negative error of the subtree leaves. This is kind of an
767  // odd one because we don't want to represent the error in non-log-space,
768  // but we have to calculate log(E_l + E_r). So we multiply E_l and E_r by
769  // V_t (remember E_l has an inverse relationship to the volume of the
770  // nodes) and then subtract log(V_t) at the end of the whole expression.
771  // As a result we do leave log-space, but the largest quantity we
772  // represent is on the order of (V_t / V_i) where V_i is the smallest leaf
773  // node below this node, which depends heavily on the depth of the tree.
774  subtreeLeavesLogNegError = std::log(
775  std::exp(logVolume + left->SubtreeLeavesLogNegError()) +
776  std::exp(logVolume + right->SubtreeLeavesLogNegError()))
777  - logVolume;
778 
779  // Recalculate upper alpha.
780  const double range = maxVals[splitDim] - minVals[splitDim];
781  const double leftRatio = (splitValue - minVals[splitDim]) / range;
782  const double rightRatio = (maxVals[splitDim] - splitValue) / range;
783 
784  const size_t leftPow = std::pow((double) (left->End() - left->Start()),
785  2);
786  const size_t rightPow = std::pow((double) (right->End() - right->Start()),
787  2);
788  const size_t thisPow = std::pow((double) (end - start), 2);
789 
790  double tmpAlphaSum = leftPow / leftRatio + rightPow / rightRatio -
791  thisPow;
792 
793  if (left->SubtreeLeaves() > 1)
794  {
795  const double exponent = 2 * std::log((double) points) + logVolume +
796  left->AlphaUpper();
797 
798  // Whether or not this will overflow is highly dependent on the depth of
799  // the tree.
800  tmpAlphaSum += std::exp(exponent);
801  }
802 
803  if (right->SubtreeLeaves() > 1)
804  {
805  const double exponent = 2 * std::log((double) points) + logVolume +
806  right->AlphaUpper();
807 
808  tmpAlphaSum += std::exp(exponent);
809  }
810 
811  alphaUpper = std::log(tmpAlphaSum) - 2 * std::log((double) points) -
812  logVolume;
813 
814  // Update gT value.
815  if (useVolReg)
816  {
817  // This is incorrect.
818  gT = alphaUpper; // / (subtreeLeavesVTInv - vTInv);
819  }
820  else
821  {
822  gT = alphaUpper - std::log((double) (subtreeLeaves - 1));
823  }
824 
825  Log::Assert(gT < std::numeric_limits<double>::max());
826 
827  return std::min((double) gT, std::min(leftG, rightG));
828  }
829  else
830  {
831  // Prune this subtree.
832  // First, make this node a leaf node.
833  subtreeLeaves = 1;
834  subtreeLeavesLogNegError = logNegError;
835 
836  delete left;
837  delete right;
838 
839  left = NULL;
840  right = NULL;
841 
842  // Pass information upward.
843  return std::numeric_limits<double>::max();
844  }
845  }
846 }
847 
848 // Check whether a given point is within the bounding box of this node (check
849 // generally done at the root, so its the bounding box of the data).
850 //
851 // Future improvement: Open up the range with epsilons on both sides where
852 // epsilon depends on the density near the boundary.
853 template<typename MatType, typename TagType>
855 {
856  for (size_t i = 0; i < query.n_elem; ++i)
857  if ((query[i] < minVals[i]) || (query[i] > maxVals[i]))
858  return false;
859 
860  return true;
861 }
862 
863 
864 template<typename MatType, typename TagType>
866 {
867  Log::Assert(query.n_elem == maxVals.n_elem);
868 
869  if (root == 1) // If we are the root...
870  {
871  // Check if the query is within range.
872  if (!WithinRange(query))
873  return 0.0;
874  }
875 
876  if (subtreeLeaves == 1) // If we are a leaf...
877  return std::exp(std::log(ratio) - logVolume);
878 
879  // Return either of the two children - left or right, depending on the
880  // splitValue.
881  return (query[splitDim] <= splitValue) ?
882  left->ComputeValue(query) :
883  right->ComputeValue(query);
884 }
885 
886 // Index the buckets for possible usage later.
887 template<typename MatType, typename TagType>
888 TagType DTree<MatType, TagType>::TagTree(const TagType& tag, bool every)
889 {
890  if (subtreeLeaves == 1)
891  {
892  // Only label leaves.
893  bucketTag = tag;
894  return (tag + 1);
895  }
896 
897  TagType nextTag;
898  if (every)
899  {
900  bucketTag = tag;
901  nextTag = (tag + 1);
902  }
903  else
904  nextTag = tag;
905 
906  return right->TagTree(left->TagTree(nextTag, every), every);
907 }
908 
909 template<typename MatType, typename TagType>
910 TagType DTree<MatType, TagType>::FindBucket(const VecType& query) const
911 {
912  Log::Assert(query.n_elem == maxVals.n_elem);
913 
914  if (root == 1) // If we are the root...
915  {
916  // Check if the query is within range.
917  if (!WithinRange(query))
918  return -1;
919  }
920 
921  // If we are a leaf...
922  if (subtreeLeaves == 1)
923  {
924  return bucketTag;
925  }
926  else
927  {
928  // Return the tag from either of the two children - left or right.
929  return (query[splitDim] <= splitValue) ?
930  left->FindBucket(query) :
931  right->FindBucket(query);
932  }
933 }
934 
935 template<typename MatType, typename TagType>
937  const
938 {
939  // Clear and set to right size.
940  importances.zeros(maxVals.n_elem);
941 
942  std::stack<const DTree*> nodes;
943  nodes.push(this);
944 
945  while (!nodes.empty())
946  {
947  const DTree& curNode = *nodes.top();
948  nodes.pop();
949 
950  if (curNode.subtreeLeaves == 1)
951  continue; // Do nothing for leaves.
952 
953  // The way to do this entirely in log-space is (at this time) somewhat
954  // unclear. So this risks overflow.
955  importances[curNode.SplitDim()] += (-std::exp(curNode.LogNegError()) -
956  (-std::exp(curNode.Left()->LogNegError()) +
957  -std::exp(curNode.Right()->LogNegError())));
958 
959  nodes.push(curNode.Left());
960  nodes.push(curNode.Right());
961  }
962 }
963 
964 template<typename MatType, typename TagType>
966  const StatType& maxs)
967 {
968  if (!root)
969  {
970  minVals = mins;
971  maxVals = maxs;
972  }
973 
974  if (left && right)
975  {
976  StatType maxValsL(maxs);
977  StatType maxValsR(maxs);
978  StatType minValsL(mins);
979  StatType minValsR(mins);
980 
981  maxValsL[splitDim] = minValsR[splitDim] = splitValue;
982  left->FillMinMax(minValsL, maxValsL);
983  right->FillMinMax(minValsR, maxValsR);
984  }
985 }
986 
987 template <typename MatType, typename TagType>
988 template <typename Archive>
990  const uint32_t /* version */)
991 {
992  ar(CEREAL_NVP(start));
993  ar(CEREAL_NVP(end));
994  ar(CEREAL_NVP(maxVals));
995  ar(CEREAL_NVP(minVals));
996  ar(CEREAL_NVP(splitDim));
997  ar(CEREAL_NVP(splitValue));
998  ar(CEREAL_NVP(logNegError));
999  ar(CEREAL_NVP(subtreeLeavesLogNegError));
1000  ar(CEREAL_NVP(subtreeLeaves));
1001  ar(CEREAL_NVP(root));
1002  ar(CEREAL_NVP(ratio));
1003  ar(CEREAL_NVP(logVolume));
1004  ar(CEREAL_NVP(bucketTag));
1005  ar(CEREAL_NVP(alphaUpper));
1006 
1007  if (cereal::is_loading<Archive>())
1008  {
1009  if (left)
1010  delete left;
1011  if (right)
1012  delete right;
1013 
1014  left = NULL;
1015  right = NULL;
1016  }
1017 
1018  bool hasLeft = (left != NULL);
1019  bool hasRight = (right != NULL);
1020 
1021  ar(CEREAL_NVP(hasLeft));
1022  ar(CEREAL_NVP(hasRight));
1023 
1024  if (hasLeft)
1025  ar(CEREAL_POINTER(left));
1026  if (hasRight)
1027  ar(CEREAL_POINTER(right));
1028 
1029  if (root)
1030  {
1031  ar(CEREAL_NVP(maxVals));
1032  ar(CEREAL_NVP(minVals));
1033 
1034  // This is added in order to reduce (dramatically!) the model file size.
1035  if (cereal::is_loading<Archive>() && left && right)
1036  FillMinMax(minVals, maxVals);
1037  }
1038 }
void serialize(Archive &ar, const uint32_t)
Serialize the density estimation tree.
Definition: dtree_impl.hpp:989
DTree * Right() const
Return the right child.
Definition: dtree.hpp:303
double Grow(MatType &data, arma::Col< size_t > &oldFromNew, const bool useVolReg=false, const size_t maxLeafSize=10, const size_t minLeafSize=5)
Greedily expand the tree.
Definition: dtree_impl.hpp:590
arma::Col< ElemType > StatType
The statistic type we are holding.
Definition: dtree.hpp:54
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: pointer_wrapper.hpp:23
DTree()
Create an empty density estimation tree.
Definition: dtree_impl.hpp:153
double LogNegError() const
Return the log negative error of this node.
Definition: dtree.hpp:290
size_t End() const
Return the first index of a point not contained in this node.
Definition: dtree.hpp:284
bool WithinRange(const VecType &query) const
Return whether a query point is within the range of this node.
Definition: dtree_impl.hpp:854
double LogNegativeError(const size_t totalPoints) const
Compute the log-negative-error for this point, given the total number of points in the dataset...
Definition: dtree_impl.hpp:416
double PruneAndUpdate(const double oldAlpha, const size_t points, const bool useVolReg=false)
Perform alpha pruning on a tree.
Definition: dtree_impl.hpp:739
Definition: dtree_impl.hpp:21
TagType FindBucket(const VecType &query) const
Return the tag of the leaf containing the query.
Definition: dtree_impl.hpp:910
size_t Start() const
Return the starting index of points contained in this node.
Definition: dtree.hpp:282
double SubtreeLeavesLogNegError() const
Return the log negative error of all descendants of this node.
Definition: dtree.hpp:292
MatType::elem_type ElemType
The actual, underlying type we're working with.
Definition: dtree.hpp:50
size_t SubtreeLeaves() const
Return the number of leaves which are descendants of this node.
Definition: dtree.hpp:294
size_t SplitDim() const
Return the split dimension of this node.
Definition: dtree.hpp:286
double ComputeValue(const VecType &query) const
Compute the density estimate of a given query point (0 if the query lies outside the range of the root node).
Definition: dtree_impl.hpp:865
MatType::vec_type VecType
The type of vector we are using.
Definition: dtree.hpp:52
A density estimation tree is similar to both a decision tree and a space partitioning tree (like a kd...
Definition: dtree.hpp:46
DTree * Left() const
Return the left child.
Definition: dtree.hpp:301
DTree & operator=(const DTree &obj)
Copy the given tree.
Definition: dtree_impl.hpp:193
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_wrapper.hpp:96
double AlphaUpper() const
Return the upper part of the alpha sum.
Definition: dtree.hpp:307
void ComputeVariableImportance(arma::vec &importances) const
Compute the variable importance of each dimension in the learned tree.
Definition: dtree_impl.hpp:936
void ExtractSplits(std::vector< std::pair< ElemType, size_t >> &splitVec, const MatType &data, size_t dim, const size_t start, const size_t end, const size_t minLeafSize)
This one sorts and scans the given per-dimension extract and puts all splits in a vector...
Definition: dtree_impl.hpp:29
TagType TagTree(const TagType &tag=0, bool everyNode=false)
Index the buckets for possible usage later; this results in every leaf in the tree having a specific ...
Definition: dtree_impl.hpp:888
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38
~DTree()
Clean up memory allocated by the tree.
Definition: dtree_impl.hpp:407