13 #ifndef MLPACK_METHODS_DET_DT_UTILS_IMPL_HPP 14 #define MLPACK_METHODS_DET_DT_UTILS_IMPL_HPP 22 template <
typename MatType>
25 const arma::Mat<size_t>& labels,
26 const size_t numClasses,
27 const std::string& leafClassMembershipFile)
30 int numLeaves = dtree->TagTree();
32 arma::Mat<size_t> table(numLeaves, (numClasses + 1));
35 for (
size_t i = 0; i < data.n_cols; ++i)
37 const typename MatType::vec_type testPoint = data.unsafe_col(i);
38 const int leafTag = dtree->FindBucket(testPoint);
39 const size_t label = labels[i];
40 table(leafTag, label) += 1;
43 if (leafClassMembershipFile ==
"")
45 Log::Info <<
"Leaf membership; row represents leaf id, column represents " 46 <<
"class id; value represents number of points in leaf in class." 47 << std::endl << table;
52 std::ofstream outfile(leafClassMembershipFile.c_str());
56 Log::Info <<
"Leaf membership printed to '" << leafClassMembershipFile
61 Log::Warn <<
"Can't open '" << leafClassMembershipFile <<
"' to write " 62 <<
"leaf membership to." << std::endl;
70 template <
typename MatType,
typename TagType>
72 const std::string viFile)
78 for (
size_t i = 0; i < imps.n_elem; ++i)
82 Log::Info <<
"Maximum variable importance: " << max <<
"." << std::endl;
86 Log::Info <<
"Variable importance: " << std::endl << imps.t() << std::endl;
90 std::ofstream outfile(viFile.c_str());
94 Log::Info <<
"Variable importance printed to '" << viFile <<
"'." 99 Log::Warn <<
"Can't open '" << viFile <<
"' to write variable importance " 100 <<
"to." << std::endl;
109 template <
typename MatType,
typename TagType>
112 const bool useVolumeReg,
113 const size_t maxLeafSize,
114 const size_t minLeafSize,
115 const bool skipPruning)
122 arma::Col<size_t> oldFromNew(dataset.n_cols);
123 for (
size_t i = 0; i < oldFromNew.n_elem; ++i)
127 MatType newDataset(dataset);
130 double oldAlpha = 0.0;
131 double alpha = dtree->
Grow(newDataset, oldFromNew, useVolumeReg, maxLeafSize,
136 <<
"dataset; minimum alpha: " << alpha <<
"." << std::endl;
141 if (folds == dataset.n_cols)
142 Log::Info <<
"Performing leave-one-out cross validation." << std::endl;
144 Log::Info <<
"Performing " << folds <<
"-fold cross validation." <<
150 std::vector<std::pair<double, double> > prunedSequence;
153 std::pair<double, double> treeSeq(oldAlpha,
155 prunedSequence.push_back(treeSeq);
157 alpha = dtree->
PruneAndUpdate(oldAlpha, dataset.n_cols, useVolumeReg);
162 Log::Assert((alpha < std::numeric_limits<double>::max())
168 std::pair<double, double> treeSeq(oldAlpha,
170 prunedSequence.push_back(treeSeq);
173 Log::Info << prunedSequence.size() <<
" trees in the sequence; maximum alpha:" 174 <<
" " << oldAlpha <<
"." << std::endl;
176 const MatType cvData(dataset);
177 const size_t testSize = dataset.n_cols / folds;
179 arma::vec regularizationConstants(prunedSequence.size());
180 regularizationConstants.fill(0.0);
187 #pragma omp parallel for shared(prunedSequence, regularizationConstants) 188 for (omp_size_t fold = 0; fold < (omp_size_t) folds; fold++)
191 const size_t start = fold * testSize;
192 const size_t end = std::min((
size_t) (fold + 1)
193 * testSize, (
size_t) cvData.n_cols);
195 MatType test = cvData.cols(start, end - 1);
196 MatType train(cvData.n_rows, cvData.n_cols - test.n_cols);
198 if (start == 0 && end < cvData.n_cols)
200 train.cols(0, train.n_cols - 1) = cvData.cols(end, cvData.n_cols - 1);
202 else if (start > 0 && end == cvData.n_cols)
204 train.cols(0, train.n_cols - 1) = cvData.cols(0, start - 1);
208 train.cols(0, start - 1) = cvData.cols(0, start - 1);
209 train.cols(start, train.n_cols - 1) = cvData.cols(end, cvData.n_cols - 1);
216 arma::Col<size_t> cvOldFromNew(train.n_cols);
217 for (
size_t i = 0; i < cvOldFromNew.n_elem; ++i)
221 cvDTree.
Grow(train, cvOldFromNew, useVolumeReg, maxLeafSize,
227 arma::vec cvRegularizationConstants(prunedSequence.size());
228 cvRegularizationConstants.fill(0.0);
230 i < ((prunedSequence.size() < 2) ? 0 : prunedSequence.size() - 2); ++i)
234 for (
size_t j = 0; j < test.n_cols; ++j)
236 arma::vec testPoint = test.unsafe_col(j);
241 cvRegularizationConstants[i] += 2.0 * cvVal / (double) cvData.n_cols;
244 double cvOldAlpha = 0.5 * (prunedSequence[i + 1].first
245 + prunedSequence[i + 2].first);
251 for (
size_t i = 0; i < test.n_cols; ++i)
253 typename MatType::vec_type testPoint = test.unsafe_col(i);
257 if (prunedSequence.size() > 2)
258 cvRegularizationConstants[prunedSequence.size() - 2] += 2.0 * cvVal
259 / (double) cvData.n_cols;
261 #pragma omp critical(DTreeCVUpdate)
262 regularizationConstants += cvRegularizationConstants;
266 double optimalAlpha = -1.0;
267 long double cvBestError = -std::numeric_limits<long double>::max();
269 for (
size_t i = 0; i < prunedSequence.size() - 1; ++i)
273 long double thisError = -std::exp((
long double) prunedSequence[i].second) +
274 (
long double) regularizationConstants[i];
276 if (thisError > cvBestError)
278 cvBestError = thisError;
279 optimalAlpha = prunedSequence[i].first;
283 Log::Info <<
"Optimal alpha: " << optimalAlpha <<
"." << std::endl;
290 for (
size_t i = 0; i < oldFromNew.n_elem; ++i)
294 newDataset = dataset;
298 alpha = dtree->
Grow(newDataset,
305 while ((oldAlpha < optimalAlpha) && (dtree->
SubtreeLeaves() > 1))
308 alpha = dtree->
PruneAndUpdate(oldAlpha, newDataset.n_cols, useVolumeReg);
311 Log::Assert((alpha < std::numeric_limits<double>::max()) ||
317 <<
"pruned tree; optimal alpha: " << oldAlpha <<
"." << std::endl;
322 template<
typename MatType>
328 pathCache.resize(dtree->
TagTree(0,
true));
329 pathCache[0] = PathCacheType::value_type(-1,
"");
333 template<
typename MatType>
337 if (parent ==
nullptr)
342 path.push_back(PathType::value_type(parent->
Left() == node, tag));
343 pathCache[tag] = PathCacheType::value_type(parent->
BucketTag(),
348 template<
typename MatType>
352 if (parent !=
nullptr)
356 std::string PathCacher::BuildString()
359 for (PathType::iterator it = path.begin(); it != path.end(); it++)
364 str += it->first ?
"L" :
"R";
367 str += (it->first ?
"L" :
"R") + std::to_string(it->second);
370 str += std::to_string(it->second) + (it->first ?
"L" :
"R");
380 return pathCache[tag].first;
385 return pathCache[tag].second;
391 #endif // MLPACK_METHODS_DET_DT_UTILS_IMPL_HPP static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
double Grow(MatType &data, arma::Col< size_t > &oldFromNew, const bool useVolReg=false, const size_t maxLeafSize=10, const size_t minLeafSize=5)
Greedily expand the tree.
Definition: dtree_impl.hpp:590
void Enter(const DTree< MatType, int > *node, const DTree< MatType, int > *parent)
Enter a given node.
Definition: dt_utils_impl.hpp:334
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
DTree< MatType, TagType > * Trainer(MatType &dataset, const size_t folds, const bool useVolumeReg=false, const size_t maxLeafSize=10, const size_t minLeafSize=5, const std::string unprunedTreeOutput="", const bool skipPruning=false)
Train the optimal decision tree using cross-validation with the given number of folds.
Print the direction, then the tag of the node.
Definition: dt_utils.hpp:91
PathCacher(PathFormat fmt, DTree< MatType, int > *tree)
Construct a PathCacher object on the given tree with the given format.
Definition: dt_utils_impl.hpp:323
void EnumerateTree(TreeType *tree, Walker &walker)
Traverses all nodes of the tree, including the inner ones.
Definition: enumerate_tree.hpp:56
void PrintLeafMembership(DTree< MatType, TagType > *dtree, const MatType &data, const arma::Mat< size_t > &labels, const size_t numClasses, const std::string &leafClassMembershipFile="")
Print the membership of leaves of a density estimation tree given the labels and number of classes.
double PruneAndUpdate(const double oldAlpha, const size_t points, const bool useVolReg=false)
Perform alpha pruning on a tree.
Definition: dtree_impl.hpp:739
static MLPACK_EXPORT util::PrefixedOutStream Warn
Prints warning messages prefixed with [WARN ].
Definition: log.hpp:87
const std::string & PathFor(int tag) const
Return the constructed path for a given tag.
Definition: dt_utils_impl.hpp:383
double SubtreeLeavesLogNegError() const
Return the log negative error of all descendants of this node.
Definition: dtree.hpp:292
size_t SubtreeLeaves() const
Return the number of leaves which are descendants of this node.
Definition: dtree.hpp:294
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
double ComputeValue(const VecType &query) const
Compute the logarithm of the density estimate of a given query point.
Definition: dtree_impl.hpp:865
TagType BucketTag() const
Return the current bucket's ID, if leaf, or -1 otherwise.
Definition: dtree.hpp:309
A density estimation tree is similar to both a decision tree and a space partitioning tree (like a kd-tree).
Definition: dtree.hpp:46
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
void Leave(const DTree< MatType, int > *node, const DTree< MatType, int > *parent)
Leave the given node.
Definition: dt_utils_impl.hpp:349
format
Define the formats we can read through cereal.
Definition: format.hpp:20
Print only whether we went left or right.
Definition: dt_utils.hpp:89
Print the tag of the node, then the direction.
Definition: dt_utils.hpp:93
DTree * Left() const
Return the left child.
Definition: dtree.hpp:301
PathFormat
Possible formats to use for output.
Definition: dt_utils.hpp:86
void ComputeVariableImportance(arma::vec &importances) const
Compute the variable importance of each dimension in the learned tree.
Definition: dtree_impl.hpp:936
TagType TagTree(const TagType &tag=0, bool everyNode=false)
Index the buckets for possible usage later; this results in every leaf in the tree having a specific tag.
Definition: dtree_impl.hpp:888
void PrintVariableImportance(const DTree< MatType, TagType > *dtree, const std::string viFile="")
Print the variable importance of each dimension of a density estimation tree.
Definition: dt_utils_impl.hpp:71
int ParentOf(int tag) const
Get the parent tag of a given tag.
Definition: dt_utils_impl.hpp:378
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38