28 template<
typename ElemType,
typename MatType>
34 const size_t minLeafSize)
37 std::is_same<typename MatType::elem_type, ElemType>::value ==
true,
38 "The ElemType does not correspond to the matrix's element type.");
40 typedef std::pair<ElemType, size_t> SplitItem;
41 const typename MatType::row_type dimVec =
42 arma::sort(data(dim, arma::span(start, end - 1)));
46 for (
size_t i = minLeafSize - 1; i < dimVec.n_elem - minLeafSize; ++i)
52 const ElemType split = (dimVec[i] + dimVec[i + 1]) / 2.0;
55 if (split != dimVec[i])
56 splitVec.push_back(SplitItem(split, i + 1));
61 template<
typename ElemType>
62 void ExtractSplits(std::vector<std::pair<ElemType, size_t>>& splitVec,
63 const arma::Mat<ElemType>&
data,
67 const size_t minLeafSize)
69 typedef std::pair<ElemType, size_t> SplitItem;
70 arma::rowvec dimVec = data(dim, arma::span(start, end - 1));
73 std::sort(dimVec.begin(), dimVec.end());
75 for (
size_t i = minLeafSize - 1; i < dimVec.n_elem - minLeafSize; ++i)
81 const ElemType split = (dimVec[i] + dimVec[i + 1]) / 2.0;
83 if (split != dimVec[i])
84 splitVec.push_back(SplitItem(split, i + 1));
89 template<
typename ElemType>
90 void ExtractSplits(std::vector<std::pair<ElemType, size_t>>& splitVec,
91 const arma::SpMat<ElemType>& data,
95 const size_t minLeafSize)
100 typedef std::pair<ElemType, size_t> SplitItem;
101 const size_t n_elem = end - start;
104 const arma::SpRow<ElemType> row = data(dim, arma::span(start, end - 1));
105 std::vector<ElemType> valsVec(row.begin(), row.end());
108 std::sort(valsVec.begin(), valsVec.end());
112 const size_t zeroes = n_elem - valsVec.size();
113 ElemType lastVal = -std::numeric_limits<ElemType>::max();
116 for (
size_t i = 0; i < valsVec.size(); ++i)
118 const ElemType newVal = valsVec[i];
119 if (lastVal < ElemType(0) && newVal > ElemType(0) && zeroes > 0)
125 if (i >= minLeafSize && i <= n_elem - minLeafSize)
126 splitVec.push_back(SplitItem(lastVal / 2.0, i));
129 lastVal = ElemType(0);
133 if (i + padding >= minLeafSize && i + padding <= n_elem - minLeafSize)
139 const ElemType split = (lastVal + newVal) / 2.0;
143 splitVec.push_back(SplitItem(split, i + padding));
152 template<
typename MatType,
typename TagType>
156 splitDim(size_t(-1)),
158 logNegError(-DBL_MAX),
159 subtreeLeavesLogNegError(-DBL_MAX),
170 template<
typename MatType,
typename TagType>
174 maxVals(obj.maxVals),
175 minVals(obj.minVals),
176 splitDim(obj.splitDim),
177 splitValue(obj.splitValue),
178 logNegError(obj.logNegError),
179 subtreeLeavesLogNegError(obj.subtreeLeavesLogNegError),
180 subtreeLeaves(obj.subtreeLeaves),
183 logVolume(obj.logVolume),
184 bucketTag(obj.bucketTag),
185 alphaUpper(obj.alphaUpper),
186 left((obj.left == NULL) ? NULL : new
DTree(*obj.left)),
187 right((obj.right == NULL) ? NULL : new
DTree(*obj.right))
192 template<
typename MatType,
typename TagType>
202 maxVals = obj.maxVals;
203 minVals = obj.minVals;
204 splitDim = obj.splitDim;
205 splitValue = obj.splitValue;
206 logNegError = obj.logNegError;
207 subtreeLeavesLogNegError = obj.subtreeLeavesLogNegError;
208 subtreeLeaves = obj.subtreeLeaves;
211 logVolume = obj.logVolume;
212 bucketTag = obj.bucketTag;
213 alphaUpper = obj.alphaUpper;
220 left = ((obj.left == NULL) ? NULL :
new DTree(*obj.left));
221 right = ((obj.right == NULL) ? NULL :
new DTree(*obj.right));
226 template<
typename MatType,
typename TagType>
230 maxVals(
std::move(obj.maxVals)),
231 minVals(
std::move(obj.minVals)),
232 splitDim(obj.splitDim),
233 splitValue(
std::move(obj.splitValue)),
234 logNegError(obj.logNegError),
235 subtreeLeavesLogNegError(obj.subtreeLeavesLogNegError),
236 subtreeLeaves(obj.subtreeLeaves),
239 logVolume(obj.logVolume),
240 bucketTag(
std::move(obj.bucketTag)),
241 alphaUpper(obj.alphaUpper),
248 obj.splitDim = size_t(-1);
249 obj.splitValue = std::numeric_limits<ElemType>::max();
250 obj.logNegError = -DBL_MAX;
251 obj.subtreeLeavesLogNegError = -DBL_MAX;
252 obj.subtreeLeaves = 0;
255 obj.logVolume = -DBL_MAX;
257 obj.alphaUpper = 0.0;
262 template<
typename MatType,
typename TagType>
272 splitDim = obj.splitDim;
273 logNegError = obj.logNegError;
274 subtreeLeavesLogNegError = obj.subtreeLeavesLogNegError;
275 subtreeLeaves = obj.subtreeLeaves;
278 logVolume = obj.logVolume;
279 alphaUpper = obj.alphaUpper;
280 maxVals = std::move(obj.maxVals);
281 minVals = std::move(obj.minVals);
282 splitValue = std::move(obj.splitValue);
283 bucketTag = std::move(obj.bucketTag);
296 obj.splitDim = size_t(-1);
297 obj.splitValue = std::numeric_limits<ElemType>::max();
298 obj.logNegError = -DBL_MAX;
299 obj.subtreeLeavesLogNegError = -DBL_MAX;
300 obj.subtreeLeaves = 0;
303 obj.logVolume = -DBL_MAX;
305 obj.alphaUpper = 0.0;
314 template<
typename MatType,
typename TagType>
317 const size_t totalPoints) :
322 splitDim(size_t(-1)),
325 subtreeLeavesLogNegError(-DBL_MAX),
336 template<
typename MatType,
typename TagType>
340 maxVals(
arma::max(data, 1)),
341 minVals(
arma::min(data, 1)),
342 splitDim(size_t(-1)),
344 subtreeLeavesLogNegError(-DBL_MAX),
358 template<
typename MatType,
typename TagType>
363 const double logNegError) :
368 splitDim(size_t(-1)),
370 logNegError(logNegError),
371 subtreeLeavesLogNegError(-DBL_MAX),
382 template<
typename MatType,
typename TagType>
385 const size_t totalPoints,
392 splitDim(size_t(-1)),
395 subtreeLeavesLogNegError(-DBL_MAX),
406 template<
typename MatType,
typename TagType>
415 template<
typename MatType,
typename TagType>
419 double err = 2 * std::log((
double) (end - start)) -
420 2 * std::log((
double) totalPoints);
422 StatType valDiffs = maxVals - minVals;
423 for (
size_t i = 0; i < valDiffs.n_elem; ++i)
426 if (valDiffs[i] > 1e-50)
427 err -= std::log(valDiffs[i]);
436 template<
typename MatType,
typename TagType>
442 const size_t minLeafSize)
const 444 typedef std::pair<ElemType, size_t> SplitItem;
451 const size_t points = end - start;
453 double minError = logNegError;
454 bool splitFound =
false;
458 #pragma omp parallel for default(shared) 459 for (intmax_t dim = 0; dim < (intmax_t) maxVals.n_elem; ++dim)
461 #pragma omp parallel for default(shared) 462 for (
size_t dim = 0; dim < maxVals.n_elem; ++dim)
469 if (max - min == 0.0)
473 const double volumeWithoutDim = logVolume - std::log(max - min);
476 bool dimSplitFound =
false;
478 double minDimError = std::pow(points, 2.0) / (max - min);
479 double dimLeftError = 0.0;
480 double dimRightError = 0.0;
490 std::vector<SplitItem> splitVec;
491 details::ExtractSplits<ElemType>(splitVec,
data, dim, start, end,
495 for (
typename std::vector<SplitItem>::iterator i = splitVec.begin();
500 const size_t position = i->second;
504 if ((split - min > 0.0) && (max - split > 0.0))
516 double negLeftError = std::pow(position, 2.0) / (split - min);
517 double negRightError = std::pow(points - position, 2.0) / (max - split);
520 if ((negLeftError + negRightError) >= minDimError)
522 minDimError = negLeftError + negRightError;
523 dimLeftError = negLeftError;
524 dimRightError = negRightError;
525 dimSplitValue = split;
526 dimSplitFound =
true;
531 const double actualMinDimError = std::log(minDimError)
532 - 2 * std::log((
double) data.n_cols)
535 #pragma omp critical(DTreeFindUpdate) 536 if ((actualMinDimError > minError) && dimSplitFound)
540 minError = actualMinDimError;
542 splitValue = dimSplitValue;
543 leftError = std::log(dimLeftError) - 2 * std::log((
double) data.n_cols)
545 rightError = std::log(dimRightError) - 2 * std::log((
double) data.n_cols)
554 template<
typename MatType,
typename TagType>
556 const size_t splitDim,
558 arma::Col<size_t>& oldFromNew)
const 565 size_t right = end - 1;
568 while (
data(splitDim, left) <= splitValue)
570 while (
data(splitDim, right) > splitValue)
576 data.swap_cols(left, right);
579 const size_t tmp = oldFromNew[left];
580 oldFromNew[left] = oldFromNew[right];
581 oldFromNew[right] = tmp;
589 template<
typename MatType,
typename TagType>
591 arma::Col<size_t>& oldFromNew,
592 const bool useVolReg,
593 const size_t maxLeafSize,
594 const size_t minLeafSize)
599 double leftG, rightG;
602 ratio = (double) (end - start) / (double) oldFromNew.n_elem;
606 for (
size_t i = 0; i < maxVals.n_elem; ++i)
607 if (maxVals[i] - minVals[i] > 0.0)
608 logVolume += std::log(maxVals[i] - minVals[i]);
611 if ((
size_t) (end - start) > maxLeafSize)
615 double splitValueTmp;
616 double leftError, rightError;
617 if (FindSplit(data, dim, splitValueTmp, leftError, rightError, minLeafSize))
621 const size_t splitIndex = SplitData(data, dim, splitValueTmp, oldFromNew);
629 maxValsL[dim] = splitValueTmp;
630 minValsR[dim] = splitValueTmp;
633 splitValue = splitValueTmp;
637 left =
new DTree(maxValsL, minValsL, start, splitIndex, leftError);
638 right =
new DTree(maxValsR, minValsR, splitIndex, end, rightError);
640 leftG = left->
Grow(data, oldFromNew, useVolReg, maxLeafSize,
642 rightG = right->
Grow(data, oldFromNew, useVolReg, maxLeafSize,
656 subtreeLeavesLogNegError = std::log(
665 subtreeLeavesLogNegError = logNegError;
671 Log::Assert((
size_t) (end - start) >= minLeafSize);
673 subtreeLeavesLogNegError = logNegError;
679 if (subtreeLeaves == 1)
681 return std::numeric_limits<double>::max();
685 const double range = maxVals[splitDim] - minVals[splitDim];
686 const double leftRatio = (splitValue - minVals[splitDim]) / range;
687 const double rightRatio = (maxVals[splitDim] - splitValue) / range;
689 const size_t leftPow = std::pow((
double) (left->
End() - left->
Start()), 2);
690 const size_t rightPow = std::pow((
double) (right->
End() - right->
Start()),
692 const size_t thisPow = std::pow((
double) (end - start), 2);
694 double tmpAlphaSum = leftPow / leftRatio + rightPow / rightRatio - thisPow;
698 const double exponent = 2 * std::log((
double) data.n_cols) + logVolume +
703 tmpAlphaSum += std::exp(exponent);
708 const double exponent = 2 * std::log((
double) data.n_cols)
712 tmpAlphaSum += std::exp(exponent);
715 alphaUpper = std::log(tmpAlphaSum) - 2 * std::log((
double) data.n_cols)
726 gT = alphaUpper - std::log((
double) (subtreeLeaves - 1));
729 return std::min(gT, std::min(leftG, rightG));
738 template<
typename MatType,
typename TagType>
741 const bool useVolReg)
744 if (subtreeLeaves == 1)
746 return std::numeric_limits<double>::max();
755 gT = alphaUpper - std::log((
double) (subtreeLeaves - 1));
761 double rightG = right->
PruneAndUpdate(oldAlpha, points, useVolReg);
774 subtreeLeavesLogNegError = std::log(
780 const double range = maxVals[splitDim] - minVals[splitDim];
781 const double leftRatio = (splitValue - minVals[splitDim]) / range;
782 const double rightRatio = (maxVals[splitDim] - splitValue) / range;
784 const size_t leftPow = std::pow((
double) (left->
End() - left->
Start()),
786 const size_t rightPow = std::pow((
double) (right->
End() - right->
Start()),
788 const size_t thisPow = std::pow((
double) (end - start), 2);
790 double tmpAlphaSum = leftPow / leftRatio + rightPow / rightRatio -
795 const double exponent = 2 * std::log((
double) points) + logVolume +
800 tmpAlphaSum += std::exp(exponent);
805 const double exponent = 2 * std::log((
double) points) + logVolume +
808 tmpAlphaSum += std::exp(exponent);
811 alphaUpper = std::log(tmpAlphaSum) - 2 * std::log((
double) points) -
822 gT = alphaUpper - std::log((
double) (subtreeLeaves - 1));
825 Log::Assert(gT < std::numeric_limits<double>::max());
827 return std::min((
double) gT, std::min(leftG, rightG));
834 subtreeLeavesLogNegError = logNegError;
843 return std::numeric_limits<double>::max();
853 template<
typename MatType,
typename TagType>
856 for (
size_t i = 0; i < query.n_elem; ++i)
857 if ((query[i] < minVals[i]) || (query[i] > maxVals[i]))
864 template<
typename MatType,
typename TagType>
876 if (subtreeLeaves == 1)
877 return std::exp(std::log(ratio) - logVolume);
881 return (query[splitDim] <= splitValue) ?
887 template<
typename MatType,
typename TagType>
890 if (subtreeLeaves == 1)
909 template<
typename MatType,
typename TagType>
922 if (subtreeLeaves == 1)
929 return (query[splitDim] <= splitValue) ?
935 template<
typename MatType,
typename TagType>
940 importances.zeros(maxVals.n_elem);
942 std::stack<const DTree*> nodes;
945 while (!nodes.empty())
947 const DTree& curNode = *nodes.top();
950 if (curNode.subtreeLeaves == 1)
959 nodes.push(curNode.
Left());
960 nodes.push(curNode.
Right());
964 template<
typename MatType,
typename TagType>
981 maxValsL[splitDim] = minValsR[splitDim] = splitValue;
982 left->FillMinMax(minValsL, maxValsL);
983 right->FillMinMax(minValsR, maxValsR);
987 template <
typename MatType,
typename TagType>
988 template <
typename Archive>
992 ar(CEREAL_NVP(start));
994 ar(CEREAL_NVP(maxVals));
995 ar(CEREAL_NVP(minVals));
996 ar(CEREAL_NVP(splitDim));
997 ar(CEREAL_NVP(splitValue));
998 ar(CEREAL_NVP(logNegError));
999 ar(CEREAL_NVP(subtreeLeavesLogNegError));
1000 ar(CEREAL_NVP(subtreeLeaves));
1001 ar(CEREAL_NVP(root));
1002 ar(CEREAL_NVP(ratio));
1003 ar(CEREAL_NVP(logVolume));
1004 ar(CEREAL_NVP(bucketTag));
1005 ar(CEREAL_NVP(alphaUpper));
1007 if (cereal::is_loading<Archive>())
1018 bool hasLeft = (left != NULL);
1019 bool hasRight = (right != NULL);
1021 ar(CEREAL_NVP(hasLeft));
1022 ar(CEREAL_NVP(hasRight));
1031 ar(CEREAL_NVP(maxVals));
1032 ar(CEREAL_NVP(minVals));
1035 if (cereal::is_loading<Archive>() && left && right)
1036 FillMinMax(minVals, maxVals);
void serialize(Archive &ar, const uint32_t)
Serialize the density estimation tree.
Definition: dtree_impl.hpp:989
DTree * Right() const
Return the right child.
Definition: dtree.hpp:303
double Grow(MatType &data, arma::Col< size_t > &oldFromNew, const bool useVolReg=false, const size_t maxLeafSize=10, const size_t minLeafSize=5)
Greedily expand the tree.
Definition: dtree_impl.hpp:590
arma::Col< ElemType > StatType
The statistic type we are holding.
Definition: dtree.hpp:54
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: pointer_wrapper.hpp:23
DTree()
Create an empty density estimation tree.
Definition: dtree_impl.hpp:153
double LogNegError() const
Return the log negative error of this node.
Definition: dtree.hpp:290
size_t End() const
Return the first index of a point not contained in this node.
Definition: dtree.hpp:284
bool WithinRange(const VecType &query) const
Return whether a query point is within the range of this node.
Definition: dtree_impl.hpp:854
double LogNegativeError(const size_t totalPoints) const
Compute the log-negative-error for this point, given the total number of points in the dataset...
Definition: dtree_impl.hpp:416
double PruneAndUpdate(const double oldAlpha, const size_t points, const bool useVolReg=false)
Perform alpha pruning on a tree.
Definition: dtree_impl.hpp:739
Definition: dtree_impl.hpp:21
TagType FindBucket(const VecType &query) const
Return the tag of the leaf containing the query.
Definition: dtree_impl.hpp:910
size_t Start() const
Return the starting index of points contained in this node.
Definition: dtree.hpp:282
double SubtreeLeavesLogNegError() const
Return the log negative error of all descendants of this node.
Definition: dtree.hpp:292
MatType::elem_type ElemType
The actual, underlying type we're working with.
Definition: dtree.hpp:50
size_t SubtreeLeaves() const
Return the number of leaves which are descendants of this node.
Definition: dtree.hpp:294
size_t SplitDim() const
Return the split dimension of this node.
Definition: dtree.hpp:286
double ComputeValue(const VecType &query) const
Compute the logarithm of the density estimate of a given query point.
Definition: dtree_impl.hpp:865
MatType::vec_type VecType
The type of vector we are using.
Definition: dtree.hpp:52
A density estimation tree is similar to both a decision tree and a space partitioning tree (like a kd...
Definition: dtree.hpp:46
DTree * Left() const
Return the left child.
Definition: dtree.hpp:301
DTree & operator=(const DTree &obj)
Copy the given tree.
Definition: dtree_impl.hpp:193
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_wrapper.hpp:96
double AlphaUpper() const
Return the upper part of the alpha sum.
Definition: dtree.hpp:307
void ComputeVariableImportance(arma::vec &importances) const
Compute the variable importance of each dimension in the learned tree.
Definition: dtree_impl.hpp:936
void ExtractSplits(std::vector< std::pair< ElemType, size_t >> &splitVec, const MatType &data, size_t dim, const size_t start, const size_t end, const size_t minLeafSize)
This one sorts and scans the given per-dimension extract and puts all splits in a vector...
Definition: dtree_impl.hpp:29
TagType TagTree(const TagType &tag=0, bool everyNode=false)
Index the buckets for possible usage later; this results in every leaf in the tree having a specific ...
Definition: dtree_impl.hpp:888
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38
~DTree()
Clean up memory allocated by the tree.
Definition: dtree_impl.hpp:407