mlpack
dt_utils_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DET_DT_UTILS_IMPL_HPP
14 #define MLPACK_METHODS_DET_DT_UTILS_IMPL_HPP
15 
#include "dt_utils.hpp"

#include <stdexcept>
18 
19 namespace mlpack {
20 namespace det {
21 
22 template <typename MatType>
23 void PrintLeafMembership(DTree<MatType, int>* dtree,
24  const MatType& data,
25  const arma::Mat<size_t>& labels,
26  const size_t numClasses,
27  const std::string& leafClassMembershipFile)
28 {
29  // Tag the leaves with numbers.
30  int numLeaves = dtree->TagTree();
31 
32  arma::Mat<size_t> table(numLeaves, (numClasses + 1));
33  table.zeros();
34 
35  for (size_t i = 0; i < data.n_cols; ++i)
36  {
37  const typename MatType::vec_type testPoint = data.unsafe_col(i);
38  const int leafTag = dtree->FindBucket(testPoint);
39  const size_t label = labels[i];
40  table(leafTag, label) += 1;
41  }
42 
43  if (leafClassMembershipFile == "")
44  {
45  Log::Info << "Leaf membership; row represents leaf id, column represents "
46  << "class id; value represents number of points in leaf in class."
47  << std::endl << table;
48  }
49  else
50  {
51  // Create a stream for the file.
52  std::ofstream outfile(leafClassMembershipFile.c_str());
53  if (outfile.good())
54  {
55  outfile << table;
56  Log::Info << "Leaf membership printed to '" << leafClassMembershipFile
57  << "'." << std::endl;
58  }
59  else
60  {
61  Log::Warn << "Can't open '" << leafClassMembershipFile << "' to write "
62  << "leaf membership to." << std::endl;
63  }
64  outfile.close();
65  }
66 
67  return;
68 }
69 
70 template <typename MatType, typename TagType>
72  const std::string viFile)
73 {
74  arma::vec imps;
75  dtree->ComputeVariableImportance(imps);
76 
77  double max = 0.0;
78  for (size_t i = 0; i < imps.n_elem; ++i)
79  if (imps[i] > max)
80  max = imps[i];
81 
82  Log::Info << "Maximum variable importance: " << max << "." << std::endl;
83 
84  if (viFile == "")
85  {
86  Log::Info << "Variable importance: " << std::endl << imps.t() << std::endl;
87  }
88  else
89  {
90  std::ofstream outfile(viFile.c_str());
91  if (outfile.good())
92  {
93  outfile << imps;
94  Log::Info << "Variable importance printed to '" << viFile << "'."
95  << std::endl;
96  }
97  else
98  {
99  Log::Warn << "Can't open '" << viFile << "' to write variable importance "
100  << "to." << std::endl;
101  }
102  outfile.close();
103  }
104 }
105 
106 
107 // This function trains the optimal decision tree using the given number of
108 // folds.
109 template <typename MatType, typename TagType>
110 DTree<MatType, TagType>* Trainer(MatType& dataset,
111  const size_t folds,
112  const bool useVolumeReg,
113  const size_t maxLeafSize,
114  const size_t minLeafSize,
115  const bool skipPruning)
116 {
117  // Initialize the tree.
118  DTree<MatType, TagType>* dtree = new DTree<MatType, TagType>(dataset);
119 
120  Timer::Start("tree_growing");
121  // Prepare to grow the tree...
122  arma::Col<size_t> oldFromNew(dataset.n_cols);
123  for (size_t i = 0; i < oldFromNew.n_elem; ++i)
124  oldFromNew[i] = i;
125 
126  // Save the dataset since it would be modified while growing the tree.
127  MatType newDataset(dataset);
128 
129  // Growing the tree
130  double oldAlpha = 0.0;
131  double alpha = dtree->Grow(newDataset, oldFromNew, useVolumeReg, maxLeafSize,
132  minLeafSize);
133 
134  Timer::Stop("tree_growing");
135  Log::Info << dtree->SubtreeLeaves() << " leaf nodes in the tree using full "
136  << "dataset; minimum alpha: " << alpha << "." << std::endl;
137 
138  if (skipPruning)
139  return dtree;
140 
141  if (folds == dataset.n_cols)
142  Log::Info << "Performing leave-one-out cross validation." << std::endl;
143  else
144  Log::Info << "Performing " << folds << "-fold cross validation." <<
145  std::endl;
146 
147  Timer::Start("pruning_sequence");
148 
149  // Sequentially prune and save the alpha values and the values of c_t^2 * r_t.
150  std::vector<std::pair<double, double> > prunedSequence;
151  while (dtree->SubtreeLeaves() > 1)
152  {
153  std::pair<double, double> treeSeq(oldAlpha,
154  dtree->SubtreeLeavesLogNegError());
155  prunedSequence.push_back(treeSeq);
156  oldAlpha = alpha;
157  alpha = dtree->PruneAndUpdate(oldAlpha, dataset.n_cols, useVolumeReg);
158 
159  // Some sanity checks. It seems that on some datasets, the error does not
160  // increase as the tree is pruned but instead stays the same---hence the
161  // "<=" in the final assert.
162  Log::Assert((alpha < std::numeric_limits<double>::max())
163  || (dtree->SubtreeLeaves() == 1));
164  Log::Assert(alpha > oldAlpha);
165  Log::Assert(dtree->SubtreeLeavesLogNegError() <= treeSeq.second);
166  }
167 
168  std::pair<double, double> treeSeq(oldAlpha,
169  dtree->SubtreeLeavesLogNegError());
170  prunedSequence.push_back(treeSeq);
171 
172  Timer::Stop("pruning_sequence");
173  Log::Info << prunedSequence.size() << " trees in the sequence; maximum alpha:"
174  << " " << oldAlpha << "." << std::endl;
175 
176  const MatType cvData(dataset);
177  const size_t testSize = dataset.n_cols / folds;
178 
179  arma::vec regularizationConstants(prunedSequence.size());
180  regularizationConstants.fill(0.0);
181 
182  Timer::Start("cross_validation");
183  // Go through each fold. On the Visual Studio compiler, we have to use
184  // intmax_t because size_t is not yet supported by their OpenMP
185  // implementation. omp_size_t is the appropriate type according to the
186  // platform.
187  #pragma omp parallel for shared(prunedSequence, regularizationConstants)
188  for (omp_size_t fold = 0; fold < (omp_size_t) folds; fold++)
189  {
190  // Break up data into train and test sets.
191  const size_t start = fold * testSize;
192  const size_t end = std::min((size_t) (fold + 1)
193  * testSize, (size_t) cvData.n_cols);
194 
195  MatType test = cvData.cols(start, end - 1);
196  MatType train(cvData.n_rows, cvData.n_cols - test.n_cols);
197 
198  if (start == 0 && end < cvData.n_cols)
199  {
200  train.cols(0, train.n_cols - 1) = cvData.cols(end, cvData.n_cols - 1);
201  }
202  else if (start > 0 && end == cvData.n_cols)
203  {
204  train.cols(0, train.n_cols - 1) = cvData.cols(0, start - 1);
205  }
206  else
207  {
208  train.cols(0, start - 1) = cvData.cols(0, start - 1);
209  train.cols(start, train.n_cols - 1) = cvData.cols(end, cvData.n_cols - 1);
210  }
211 
212  // Initialize the tree.
213  DTree<MatType, TagType> cvDTree(train);
214 
215  // Getting ready to grow the tree...
216  arma::Col<size_t> cvOldFromNew(train.n_cols);
217  for (size_t i = 0; i < cvOldFromNew.n_elem; ++i)
218  cvOldFromNew[i] = i;
219 
220  // Grow the tree.
221  cvDTree.Grow(train, cvOldFromNew, useVolumeReg, maxLeafSize,
222  minLeafSize);
223 
224  // Sequentially prune with all the values of available alphas and adding
225  // values for test values. Don't enter this loop if there are less than two
226  // trees in the pruned sequence.
227  arma::vec cvRegularizationConstants(prunedSequence.size());
228  cvRegularizationConstants.fill(0.0);
229  for (size_t i = 0;
230  i < ((prunedSequence.size() < 2) ? 0 : prunedSequence.size() - 2); ++i)
231  {
232  // Compute test values for this state of the tree.
233  double cvVal = 0.0;
234  for (size_t j = 0; j < test.n_cols; ++j)
235  {
236  arma::vec testPoint = test.unsafe_col(j);
237  cvVal += cvDTree.ComputeValue(testPoint);
238  }
239 
240  // Update the cv regularization constant.
241  cvRegularizationConstants[i] += 2.0 * cvVal / (double) cvData.n_cols;
242 
243  // Determine the new alpha value and prune accordingly.
244  double cvOldAlpha = 0.5 * (prunedSequence[i + 1].first
245  + prunedSequence[i + 2].first);
246  cvDTree.PruneAndUpdate(cvOldAlpha, train.n_cols, useVolumeReg);
247  }
248 
249  // Compute test values for this state of the tree.
250  double cvVal = 0.0;
251  for (size_t i = 0; i < test.n_cols; ++i)
252  {
253  typename MatType::vec_type testPoint = test.unsafe_col(i);
254  cvVal += cvDTree.ComputeValue(testPoint);
255  }
256 
257  if (prunedSequence.size() > 2)
258  cvRegularizationConstants[prunedSequence.size() - 2] += 2.0 * cvVal
259  / (double) cvData.n_cols;
260 
261  #pragma omp critical(DTreeCVUpdate)
262  regularizationConstants += cvRegularizationConstants;
263  }
264  Timer::Stop("cross_validation");
265 
266  double optimalAlpha = -1.0;
267  long double cvBestError = -std::numeric_limits<long double>::max();
268 
269  for (size_t i = 0; i < prunedSequence.size() - 1; ++i)
270  {
271  // We can no longer work in the log-space for this because we have no
272  // guarantee the quantity will be positive.
273  long double thisError = -std::exp((long double) prunedSequence[i].second) +
274  (long double) regularizationConstants[i];
275 
276  if (thisError > cvBestError)
277  {
278  cvBestError = thisError;
279  optimalAlpha = prunedSequence[i].first;
280  }
281  }
282 
283  Log::Info << "Optimal alpha: " << optimalAlpha << "." << std::endl;
284 
285  // Re-Initialize the tree.
286  delete dtree;
287  dtree = new DTree<MatType, TagType>(dataset);
288 
289  // Getting ready to grow the tree...
290  for (size_t i = 0; i < oldFromNew.n_elem; ++i)
291  oldFromNew[i] = i;
292 
293  // Save the dataset since it would be modified while growing the tree.
294  newDataset = dataset;
295 
296  // Grow the tree.
297  oldAlpha = -DBL_MAX;
298  alpha = dtree->Grow(newDataset,
299  oldFromNew,
300  useVolumeReg,
301  maxLeafSize,
302  minLeafSize);
303 
304  // Prune with optimal alpha.
305  while ((oldAlpha < optimalAlpha) && (dtree->SubtreeLeaves() > 1))
306  {
307  oldAlpha = alpha;
308  alpha = dtree->PruneAndUpdate(oldAlpha, newDataset.n_cols, useVolumeReg);
309 
310  // Some sanity checks.
311  Log::Assert((alpha < std::numeric_limits<double>::max()) ||
312  (dtree->SubtreeLeaves() == 1));
313  Log::Assert(alpha > oldAlpha);
314  }
315 
316  Log::Info << dtree->SubtreeLeaves() << " leaf nodes in the optimally "
317  << "pruned tree; optimal alpha: " << oldAlpha << "." << std::endl;
318 
319  return dtree;
320 }
321 
322 template<typename MatType>
324  format(fmt)
325 {
326  // Here we use TagTree()'s output to determine the
327  // number of _nodes_ in the tree.
328  pathCache.resize(dtree->TagTree(0, true));
329  pathCache[0] = PathCacheType::value_type(-1, "");
330  tree::EnumerateTree(dtree, *this);
331 }
332 
333 template<typename MatType>
335  const DTree<MatType, int>* parent)
336 {
337  if (parent == nullptr)
338  return;
339 
340  int tag = node->BucketTag();
341 
342  path.push_back(PathType::value_type(parent->Left() == node, tag));
343  pathCache[tag] = PathCacheType::value_type(parent->BucketTag(),
344  (node->SubtreeLeaves() > 1) ?
345  "" : BuildString());
346 }
347 
348 template<typename MatType>
349 void PathCacher::Leave(const DTree<MatType, int>* /* node */,
350  const DTree<MatType, int>* parent)
351 {
352  if (parent != nullptr)
353  path.pop_back();
354 }
355 
356 std::string PathCacher::BuildString()
357 {
358  std::string str("");
359  for (PathType::iterator it = path.begin(); it != path.end(); it++)
360  {
361  switch (format)
362  {
363  case FormatLR:
364  str += it->first ? "L" : "R";
365  break;
366  case FormatLR_ID:
367  str += (it->first ? "L" : "R") + std::to_string(it->second);
368  break;
369  case FormatID_LR:
370  str += std::to_string(it->second) + (it->first ? "L" : "R");
371  break;
372  }
373  }
374 
375  return str;
376 }
377 
378 int PathCacher::ParentOf(int tag) const
379 {
380  return pathCache[tag].first;
381 }
382 
383 const std::string& PathCacher::PathFor(int tag) const
384 {
385  return pathCache[tag].second;
386 }
387 
388 } // namespace det
389 } // namespace mlpack
390 
391 #endif // MLPACK_METHODS_DET_DT_UTILS_IMPL_HPP
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
double Grow(MatType &data, arma::Col< size_t > &oldFromNew, const bool useVolReg=false, const size_t maxLeafSize=10, const size_t minLeafSize=5)
Greedily expand the tree.
Definition: dtree_impl.hpp:590
void Enter(const DTree< MatType, int > *node, const DTree< MatType, int > *parent)
Enter a given node.
Definition: dt_utils_impl.hpp:334
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
DTree< MatType, TagType > * Trainer(MatType &dataset, const size_t folds, const bool useVolumeReg=false, const size_t maxLeafSize=10, const size_t minLeafSize=5, const std::string unprunedTreeOutput="", const bool skipPruning=false)
Train the optimal decision tree using cross-validation with the given number of folds.
Print the direction, then the tag of the node.
Definition: dt_utils.hpp:91
PathCacher(PathFormat fmt, DTree< MatType, int > *tree)
Construct a PathCacher object on the given tree with the given format.
Definition: dt_utils_impl.hpp:323
void EnumerateTree(TreeType *tree, Walker &walker)
Traverses all nodes of the tree, including the inner ones.
Definition: enumerate_tree.hpp:56
void PrintLeafMembership(DTree< MatType, TagType > *dtree, const MatType &data, const arma::Mat< size_t > &labels, const size_t numClasses, const std::string &leafClassMembershipFile="")
Print the membership of leaves of a density estimation tree given the labels and number of classes...
double PruneAndUpdate(const double oldAlpha, const size_t points, const bool useVolReg=false)
Perform alpha pruning on a tree.
Definition: dtree_impl.hpp:739
static MLPACK_EXPORT util::PrefixedOutStream Warn
Prints warning messages prefixed with [WARN ].
Definition: log.hpp:87
const std::string & PathFor(int tag) const
Return the constructed path for a given tag.
Definition: dt_utils_impl.hpp:383
double SubtreeLeavesLogNegError() const
Return the log negative error of all descendants of this node.
Definition: dtree.hpp:292
size_t SubtreeLeaves() const
Return the number of leaves which are descendants of this node.
Definition: dtree.hpp:294
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
double ComputeValue(const VecType &query) const
Compute the logarithm of the density estimate of a given query point.
Definition: dtree_impl.hpp:865
TagType BucketTag() const
Return the current bucket's ID, if leaf, or -1 otherwise.
Definition: dtree.hpp:309
A density estimation tree is similar to both a decision tree and a space partitioning tree (like a kd...
Definition: dtree.hpp:46
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
void Leave(const DTree< MatType, int > *node, const DTree< MatType, int > *parent)
Leave the given node.
Definition: dt_utils_impl.hpp:349
format
Define the formats we can read through cereal.
Definition: format.hpp:20
Print only whether we went left or right.
Definition: dt_utils.hpp:89
Print the tag of the node, then the direction.
Definition: dt_utils.hpp:93
DTree * Left() const
Return the left child.
Definition: dtree.hpp:301
PathFormat
Possible formats to use for output.
Definition: dt_utils.hpp:86
void ComputeVariableImportance(arma::vec &importances) const
Compute the variable importance of each dimension in the learned tree.
Definition: dtree_impl.hpp:936
TagType TagTree(const TagType &tag=0, bool everyNode=false)
Index the buckets for possible usage later; this results in every leaf in the tree having a specific ...
Definition: dtree_impl.hpp:888
void PrintVariableImportance(const DTree< MatType, TagType > *dtree, const std::string viFile="")
Print the variable importance of each dimension of a density estimation tree.
Definition: dt_utils_impl.hpp:71
int ParentOf(int tag) const
Get the parent tag of a given tag.
Definition: dt_utils_impl.hpp:378
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38