/**
 * @file decision_tree_impl.hpp
 *
 * Implementation of the generic decision tree class.
 */
#ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_IMPL_HPP
#define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_IMPL_HPP

#include "decision_tree.hpp"

namespace mlpack {
namespace tree {

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
template<typename MatType, typename LabelsType>
DecisionTree<FitnessFunction,
             NumericSplitType,
             CategoricalSplitType,
             DimensionSelectionType,
             NoRecursion>::DecisionTree(
    MatType data,
    const data::DatasetInfo& datasetInfo,
    LabelsType labels,
    const size_t numClasses,
    const size_t minimumLeafSize,
    const double minimumGainSplit,
    const size_t maximumDepth,
    DimensionSelectionType dimensionSelector)
{
  using TrueMatType = typename std::decay<MatType>::type;
  using TrueLabelsType = typename std::decay<LabelsType>::type;

  // Copy or move data.
  TrueMatType tmpData(std::move(data));
  TrueLabelsType tmpLabels(std::move(labels));

  // Set the correct dimensionality for the dimension selector.
  dimensionSelector.Dimensions() = tmpData.n_rows;

  // Pass off work to the Train() method.
  arma::rowvec weights; // Fake weights, not used.
  Train<false>(tmpData, 0, tmpData.n_cols, datasetInfo, tmpLabels, numClasses,
      weights, minimumLeafSize, minimumGainSplit, maximumDepth,
      dimensionSelector);
}
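
// Example usage (a minimal sketch; `dataset`, `info`, and `labels` are
// hypothetical caller-provided variables, with three classes assumed):
//
//   arma::mat dataset;          // Column-major: one point per column.
//   data::DatasetInfo info;     // Records which dimensions are categorical.
//   arma::Row<size_t> labels;   // One label in [0, 3) per point.
//   // ... load dataset, info, and labels ...
//   DecisionTree<> tree(dataset, info, labels, 3 /* numClasses */);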

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
template<typename MatType, typename LabelsType>
DecisionTree<FitnessFunction,
             NumericSplitType,
             CategoricalSplitType,
             DimensionSelectionType,
             NoRecursion>::DecisionTree(
    MatType data,
    LabelsType labels,
    const size_t numClasses,
    const size_t minimumLeafSize,
    const double minimumGainSplit,
    const size_t maximumDepth,
    DimensionSelectionType dimensionSelector)
{
  using TrueMatType = typename std::decay<MatType>::type;
  using TrueLabelsType = typename std::decay<LabelsType>::type;

  // Copy or move data.
  TrueMatType tmpData(std::move(data));
  TrueLabelsType tmpLabels(std::move(labels));

  // Set the correct dimensionality for the dimension selector.
  dimensionSelector.Dimensions() = tmpData.n_rows;

  // Pass off work to the Train() method.
  arma::rowvec weights; // Fake weights, not used.
  Train<false>(tmpData, 0, tmpData.n_cols, tmpLabels, numClasses, weights,
      minimumLeafSize, minimumGainSplit, maximumDepth, dimensionSelector);
}

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
template<typename MatType, typename LabelsType, typename WeightsType>
DecisionTree<FitnessFunction,
             NumericSplitType,
             CategoricalSplitType,
             DimensionSelectionType,
             NoRecursion>::DecisionTree(
    MatType data,
    const data::DatasetInfo& datasetInfo,
    LabelsType labels,
    const size_t numClasses,
    WeightsType weights,
    const size_t minimumLeafSize,
    const double minimumGainSplit,
    const size_t maximumDepth,
    DimensionSelectionType dimensionSelector,
    const std::enable_if_t<arma::is_arma_type<
        typename std::remove_reference<WeightsType>::type>::value>*)
{
  using TrueMatType = typename std::decay<MatType>::type;
  using TrueLabelsType = typename std::decay<LabelsType>::type;
  using TrueWeightsType = typename std::decay<WeightsType>::type;

  // Copy or move data.
  TrueMatType tmpData(std::move(data));
  TrueLabelsType tmpLabels(std::move(labels));
  TrueWeightsType tmpWeights(std::move(weights));

  // Set the correct dimensionality for the dimension selector.
  dimensionSelector.Dimensions() = tmpData.n_rows;

  // Pass off work to the weighted Train() method.
  Train<true>(tmpData, 0, tmpData.n_cols, datasetInfo, tmpLabels, numClasses,
      tmpWeights, minimumLeafSize, minimumGainSplit, maximumDepth,
      dimensionSelector);
}
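
// Example usage (sketch): instance weights are given as an arma::rowvec with
// one weight per column of the data.  `dataset`, `info`, and `labels` are the
// same hypothetical variables as in the earlier example.
//
//   arma::rowvec weights(dataset.n_cols, arma::fill::ones);
//   weights(0) = 2.0;  // Give the first point twice the influence.
//   DecisionTree<> tree(dataset, info, labels, 3, weights);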

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
template<typename MatType, typename LabelsType, typename WeightsType>
DecisionTree<FitnessFunction,
             NumericSplitType,
             CategoricalSplitType,
             DimensionSelectionType,
             NoRecursion>::DecisionTree(
    const DecisionTree& other,
    MatType data,
    const data::DatasetInfo& datasetInfo,
    LabelsType labels,
    const size_t numClasses,
    WeightsType weights,
    const size_t minimumLeafSize,
    const double minimumGainSplit,
    const std::enable_if_t<arma::is_arma_type<
        typename std::remove_reference<WeightsType>::type>::value>*):
    NumericAuxiliarySplitInfo(other),
    CategoricalAuxiliarySplitInfo(other)
{
  using TrueMatType = typename std::decay<MatType>::type;
  using TrueLabelsType = typename std::decay<LabelsType>::type;
  using TrueWeightsType = typename std::decay<WeightsType>::type;

  // Copy or move data.
  TrueMatType tmpData(std::move(data));
  TrueLabelsType tmpLabels(std::move(labels));
  TrueWeightsType tmpWeights(std::move(weights));

  // Pass off work to the weighted Train() method.
  Train<true>(tmpData, 0, tmpData.n_cols, datasetInfo, tmpLabels, numClasses,
      tmpWeights, minimumLeafSize, minimumGainSplit);
}

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
template<typename MatType, typename LabelsType, typename WeightsType>
DecisionTree<FitnessFunction,
             NumericSplitType,
             CategoricalSplitType,
             DimensionSelectionType,
             NoRecursion>::DecisionTree(
    MatType data,
    LabelsType labels,
    const size_t numClasses,
    WeightsType weights,
    const size_t minimumLeafSize,
    const double minimumGainSplit,
    const size_t maximumDepth,
    DimensionSelectionType dimensionSelector,
    const std::enable_if_t<arma::is_arma_type<
        typename std::remove_reference<WeightsType>::type>::value>*)
{
  using TrueMatType = typename std::decay<MatType>::type;
  using TrueLabelsType = typename std::decay<LabelsType>::type;
  using TrueWeightsType = typename std::decay<WeightsType>::type;

  // Copy or move data.
  TrueMatType tmpData(std::move(data));
  TrueLabelsType tmpLabels(std::move(labels));
  TrueWeightsType tmpWeights(std::move(weights));

  // Set the correct dimensionality for the dimension selector.
  dimensionSelector.Dimensions() = tmpData.n_rows;

  // Pass off work to the weighted Train() method.
  Train<true>(tmpData, 0, tmpData.n_cols, tmpLabels, numClasses, tmpWeights,
      minimumLeafSize, minimumGainSplit, maximumDepth, dimensionSelector);
}

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
template<typename MatType, typename LabelsType, typename WeightsType>
DecisionTree<FitnessFunction,
             NumericSplitType,
             CategoricalSplitType,
             DimensionSelectionType,
             NoRecursion>::DecisionTree(
    const DecisionTree& other,
    MatType data,
    LabelsType labels,
    const size_t numClasses,
    WeightsType weights,
    const size_t minimumLeafSize,
    const double minimumGainSplit,
    const size_t maximumDepth,
    DimensionSelectionType dimensionSelector,
    const std::enable_if_t<arma::is_arma_type<
        typename std::remove_reference<WeightsType>::type>::value>*):
    NumericAuxiliarySplitInfo(other),
    CategoricalAuxiliarySplitInfo(other) // Other info does need to be copied.
{
  using TrueMatType = typename std::decay<MatType>::type;
  using TrueLabelsType = typename std::decay<LabelsType>::type;
  using TrueWeightsType = typename std::decay<WeightsType>::type;

  // Copy or move data.
  TrueMatType tmpData(std::move(data));
  TrueLabelsType tmpLabels(std::move(labels));
  TrueWeightsType tmpWeights(std::move(weights));

  // Set the correct dimensionality for the dimension selector.
  dimensionSelector.Dimensions() = tmpData.n_rows;

  // Pass off work to the weighted Train() method.
  Train<true>(tmpData, 0, tmpData.n_cols, tmpLabels, numClasses, tmpWeights,
      minimumLeafSize, minimumGainSplit, maximumDepth, dimensionSelector);
}

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
DecisionTree<FitnessFunction,
             NumericSplitType,
             CategoricalSplitType,
             DimensionSelectionType,
             NoRecursion>::DecisionTree(const size_t numClasses) :
    splitDimension(0),
    dimensionTypeOrMajorityClass(0),
    classProbabilities(numClasses)
{
  // Initialize utility vector.
  classProbabilities.fill(1.0 / (double) numClasses);
}
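
// For illustration: a tree built with only this constructor is a single leaf
// that assigns equal probability to every class, e.g.
//
//   DecisionTree<> stub(4);  // Leaf node; every class has probability 0.25.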

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
DecisionTree<FitnessFunction,
             NumericSplitType,
             CategoricalSplitType,
             DimensionSelectionType,
             NoRecursion>::DecisionTree(const DecisionTree& other) :
    NumericAuxiliarySplitInfo(other),
    CategoricalAuxiliarySplitInfo(other),
    splitDimension(other.splitDimension),
    dimensionTypeOrMajorityClass(other.dimensionTypeOrMajorityClass),
    classProbabilities(other.classProbabilities)
{
  // Copy each child.
  for (size_t i = 0; i < other.children.size(); ++i)
    children.push_back(new DecisionTree(*other.children[i]));
}

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
DecisionTree<FitnessFunction,
             NumericSplitType,
             CategoricalSplitType,
             DimensionSelectionType,
             NoRecursion>::DecisionTree(DecisionTree&& other) :
    NumericAuxiliarySplitInfo(std::move(other)),
    CategoricalAuxiliarySplitInfo(std::move(other)),
    children(std::move(other.children)),
    splitDimension(other.splitDimension),
    dimensionTypeOrMajorityClass(other.dimensionTypeOrMajorityClass),
    classProbabilities(std::move(other.classProbabilities))
{
  // Reset the other object.
  other.classProbabilities.ones(1); // One class, P(1) = 1.
}

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
DecisionTree<FitnessFunction,
             NumericSplitType,
             CategoricalSplitType,
             DimensionSelectionType,
             NoRecursion>&
DecisionTree<FitnessFunction,
             NumericSplitType,
             CategoricalSplitType,
             DimensionSelectionType,
             NoRecursion>::operator=(const DecisionTree& other)
{
  if (this == &other)
    return *this; // Nothing to copy.

  // Clean memory if needed.
  for (size_t i = 0; i < children.size(); ++i)
    delete children[i];
  children.clear();

  // Copy everything from the other tree.
  splitDimension = other.splitDimension;
  dimensionTypeOrMajorityClass = other.dimensionTypeOrMajorityClass;
  classProbabilities = other.classProbabilities;

  // Copy the children.
  for (size_t i = 0; i < other.children.size(); ++i)
    children.push_back(new DecisionTree(*other.children[i]));

  // Copy the auxiliary info.
  NumericAuxiliarySplitInfo::operator=(other);
  CategoricalAuxiliarySplitInfo::operator=(other);

  return *this;
}

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
DecisionTree<FitnessFunction,
             NumericSplitType,
             CategoricalSplitType,
             DimensionSelectionType,
             NoRecursion>&
DecisionTree<FitnessFunction,
             NumericSplitType,
             CategoricalSplitType,
             DimensionSelectionType,
             NoRecursion>::operator=(DecisionTree&& other)
{
  if (this == &other)
    return *this; // Nothing to move.

  // Clean memory if needed.
  for (size_t i = 0; i < children.size(); ++i)
    delete children[i];
  children.clear();

  // Take ownership of the other tree's components.
  children = std::move(other.children);
  splitDimension = other.splitDimension;
  dimensionTypeOrMajorityClass = other.dimensionTypeOrMajorityClass;
  classProbabilities = std::move(other.classProbabilities);

  // Reset the class probabilities of the other object.
  other.classProbabilities.ones(1); // One class, P(1) = 1.

  // Take ownership of the auxiliary info.
  NumericAuxiliarySplitInfo::operator=(std::move(other));
  CategoricalAuxiliarySplitInfo::operator=(std::move(other));

  return *this;
}

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
DecisionTree<FitnessFunction,
             NumericSplitType,
             CategoricalSplitType,
             DimensionSelectionType,
             NoRecursion>::~DecisionTree()
{
  for (size_t i = 0; i < children.size(); ++i)
    delete children[i];
}

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
template<typename MatType, typename LabelsType>
double DecisionTree<FitnessFunction,
                    NumericSplitType,
                    CategoricalSplitType,
                    DimensionSelectionType,
                    NoRecursion>::Train(
    MatType data,
    const data::DatasetInfo& datasetInfo,
    LabelsType labels,
    const size_t numClasses,
    const size_t minimumLeafSize,
    const double minimumGainSplit,
    const size_t maximumDepth,
    DimensionSelectionType dimensionSelector)
{
  // Sanity check on data.
  util::CheckSameSizes(data, labels, "DecisionTree::Train()");

  using TrueMatType = typename std::decay<MatType>::type;
  using TrueLabelsType = typename std::decay<LabelsType>::type;

  // Copy or move data.
  TrueMatType tmpData(std::move(data));
  TrueLabelsType tmpLabels(std::move(labels));

  // Set the correct dimensionality for the dimension selector.
  dimensionSelector.Dimensions() = tmpData.n_rows;

  // Pass off work to the Train() method.
  arma::rowvec weights; // Fake weights, not used.
  return Train<false>(tmpData, 0, tmpData.n_cols, datasetInfo, tmpLabels,
      numClasses, weights, minimumLeafSize, minimumGainSplit, maximumDepth,
      dimensionSelector);
}
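
// Example usage (sketch; `tree`, `newData`, `info`, and `newLabels` are
// hypothetical): an existing tree can be retrained from scratch; the returned
// double is the negated gain of the trained tree, i.e. an impurity-style
// value where 0 means perfectly pure leaves.
//
//   const double finalGain = tree.Train(newData, info, newLabels, 3);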

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
template<typename MatType, typename LabelsType>
double DecisionTree<FitnessFunction,
                    NumericSplitType,
                    CategoricalSplitType,
                    DimensionSelectionType,
                    NoRecursion>::Train(
    MatType data,
    LabelsType labels,
    const size_t numClasses,
    const size_t minimumLeafSize,
    const double minimumGainSplit,
    const size_t maximumDepth,
    DimensionSelectionType dimensionSelector)
{
  // Sanity check on data.
  util::CheckSameSizes(data, labels, "DecisionTree::Train()");

  using TrueMatType = typename std::decay<MatType>::type;
  using TrueLabelsType = typename std::decay<LabelsType>::type;

  // Copy or move data.
  TrueMatType tmpData(std::move(data));
  TrueLabelsType tmpLabels(std::move(labels));

  // Set the correct dimensionality for the dimension selector.
  dimensionSelector.Dimensions() = tmpData.n_rows;

  // Pass off work to the Train() method.
  arma::rowvec weights; // Fake weights, not used.
  return Train<false>(tmpData, 0, tmpData.n_cols, tmpLabels, numClasses,
      weights, minimumLeafSize, minimumGainSplit, maximumDepth,
      dimensionSelector);
}

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
template<typename MatType, typename LabelsType, typename WeightsType>
double DecisionTree<FitnessFunction,
                    NumericSplitType,
                    CategoricalSplitType,
                    DimensionSelectionType,
                    NoRecursion>::Train(
    MatType data,
    const data::DatasetInfo& datasetInfo,
    LabelsType labels,
    const size_t numClasses,
    WeightsType weights,
    const size_t minimumLeafSize,
    const double minimumGainSplit,
    const size_t maximumDepth,
    DimensionSelectionType dimensionSelector,
    const std::enable_if_t<arma::is_arma_type<
        typename std::remove_reference<WeightsType>::type>::value>*)
{
  // Sanity check on data.
  util::CheckSameSizes(data, labels, "DecisionTree::Train()");

  using TrueMatType = typename std::decay<MatType>::type;
  using TrueLabelsType = typename std::decay<LabelsType>::type;
  using TrueWeightsType = typename std::decay<WeightsType>::type;

  // Copy or move data.
  TrueMatType tmpData(std::move(data));
  TrueLabelsType tmpLabels(std::move(labels));
  TrueWeightsType tmpWeights(std::move(weights));

  // Set the correct dimensionality for the dimension selector.
  dimensionSelector.Dimensions() = tmpData.n_rows;

  // Pass off work to the Train() method.
  return Train<true>(tmpData, 0, tmpData.n_cols, datasetInfo, tmpLabels,
      numClasses, tmpWeights, minimumLeafSize, minimumGainSplit, maximumDepth,
      dimensionSelector);
}

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
template<typename MatType, typename LabelsType, typename WeightsType>
double DecisionTree<FitnessFunction,
                    NumericSplitType,
                    CategoricalSplitType,
                    DimensionSelectionType,
                    NoRecursion>::Train(
    MatType data,
    LabelsType labels,
    const size_t numClasses,
    WeightsType weights,
    const size_t minimumLeafSize,
    const double minimumGainSplit,
    const size_t maximumDepth,
    DimensionSelectionType dimensionSelector,
    const std::enable_if_t<arma::is_arma_type<
        typename std::remove_reference<WeightsType>::type>::value>*)
{
  // Sanity check on data.
  util::CheckSameSizes(data, labels, "DecisionTree::Train()");

  using TrueMatType = typename std::decay<MatType>::type;
  using TrueLabelsType = typename std::decay<LabelsType>::type;
  using TrueWeightsType = typename std::decay<WeightsType>::type;

  // Copy or move data.
  TrueMatType tmpData(std::move(data));
  TrueLabelsType tmpLabels(std::move(labels));
  TrueWeightsType tmpWeights(std::move(weights));

  // Set the correct dimensionality for the dimension selector.
  dimensionSelector.Dimensions() = tmpData.n_rows;

  // Pass off work to the Train() method.
  return Train<true>(tmpData, 0, tmpData.n_cols, tmpLabels, numClasses,
      tmpWeights, minimumLeafSize, minimumGainSplit, maximumDepth,
      dimensionSelector);
}

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
template<bool UseWeights, typename MatType>
double DecisionTree<FitnessFunction,
                    NumericSplitType,
                    CategoricalSplitType,
                    DimensionSelectionType,
                    NoRecursion>::Train(
    MatType& data,
    const size_t begin,
    const size_t count,
    const data::DatasetInfo& datasetInfo,
    arma::Row<size_t>& labels,
    const size_t numClasses,
    arma::rowvec& weights,
    const size_t minimumLeafSize,
    const double minimumGainSplit,
    const size_t maximumDepth,
    DimensionSelectionType& dimensionSelector)
{
  // Clear children if needed.
  for (size_t i = 0; i < children.size(); ++i)
    delete children[i];
  children.clear();

  // Look through the list of dimensions and obtain the gain of the best split.
  // We'll cache the best numeric and categorical split auxiliary information
  // in numericAux and categoricalAux (and clear them later if we make no
  // split), and use classProbabilities as auxiliary information.  Later we'll
  // overwrite classProbabilities to the empirical class probabilities if we do
  // not split.
  double bestGain = FitnessFunction::template Evaluate<UseWeights>(
      labels.subvec(begin, begin + count - 1),
      numClasses,
      UseWeights ? weights.subvec(begin, begin + count - 1) : weights);
  size_t bestDim = datasetInfo.Dimensionality(); // This means "no split".
  const size_t end = dimensionSelector.End();

  if (maximumDepth != 1)
  {
    for (size_t i = dimensionSelector.Begin(); i != end;
         i = dimensionSelector.Next())
    {
      double dimGain = -DBL_MAX;
      if (datasetInfo.Type(i) == data::Datatype::categorical)
      {
        dimGain = CategoricalSplit::template SplitIfBetter<UseWeights>(bestGain,
            data.cols(begin, begin + count - 1).row(i),
            datasetInfo.NumMappings(i),
            labels.subvec(begin, begin + count - 1),
            numClasses,
            UseWeights ? weights.subvec(begin, begin + count - 1) : weights,
            minimumLeafSize,
            minimumGainSplit,
            classProbabilities,
            *this);
      }
      else if (datasetInfo.Type(i) == data::Datatype::numeric)
      {
        dimGain = NumericSplit::template SplitIfBetter<UseWeights>(bestGain,
            data.cols(begin, begin + count - 1).row(i),
            labels.subvec(begin, begin + count - 1),
            numClasses,
            UseWeights ? weights.subvec(begin, begin + count - 1) : weights,
            minimumLeafSize,
            minimumGainSplit,
            classProbabilities,
            *this);
      }

      // If the splitter reported that it did not split, move to the next
      // dimension.
      if (dimGain == DBL_MAX)
        continue;

      // Was there an improvement?  If so, mark that it's the new best
      // dimension.
      bestDim = i;
      bestGain = dimGain;

      // If the gain is the best possible, no need to keep looking.
      if (bestGain >= 0.0)
        break;
    }
  }

  // Did we split or not?  If so, then split the data and create the children.
  if (bestDim != datasetInfo.Dimensionality())
  {
    dimensionTypeOrMajorityClass = (size_t) datasetInfo.Type(bestDim);
    splitDimension = bestDim;

    // Get the number of children we will have.
    size_t numChildren = 0;
    if (datasetInfo.Type(bestDim) == data::Datatype::categorical)
      numChildren = CategoricalSplit::NumChildren(classProbabilities[0], *this);
    else
      numChildren = NumericSplit::NumChildren(classProbabilities[0], *this);

    // Calculate all child assignments.
    arma::Row<size_t> childAssignments(count);
    if (datasetInfo.Type(bestDim) == data::Datatype::categorical)
    {
      for (size_t j = begin; j < begin + count; ++j)
        childAssignments[j - begin] = CategoricalSplit::CalculateDirection(
            data(bestDim, j), classProbabilities[0], *this);
    }
    else
    {
      for (size_t j = begin; j < begin + count; ++j)
      {
        childAssignments[j - begin] = NumericSplit::CalculateDirection(
            data(bestDim, j), classProbabilities[0], *this);
      }
    }

    // Figure out counts of children.
    arma::Row<size_t> childCounts(numChildren, arma::fill::zeros);
    for (size_t i = begin; i < begin + count; ++i)
      childCounts[childAssignments[i - begin]]++;

    // Initialize bestGain if recursive split is allowed.
    if (!NoRecursion)
    {
      bestGain = 0.0;
    }

    // Split into children.
    size_t currentCol = begin;
    for (size_t i = 0; i < numChildren; ++i)
    {
      size_t currentChildBegin = currentCol;
      for (size_t j = currentChildBegin; j < begin + count; ++j)
      {
        if (childAssignments[j - begin] == i)
        {
          childAssignments.swap_cols(currentCol - begin, j - begin);
          data.swap_cols(currentCol, j);
          labels.swap_cols(currentCol, j);
          if (UseWeights)
            weights.swap_cols(currentCol, j);
          ++currentCol;
        }
      }

      // Now build the child recursively.
      DecisionTree* child = new DecisionTree();
      if (NoRecursion)
      {
        child->Train<UseWeights>(data, currentChildBegin,
            currentCol - currentChildBegin, datasetInfo, labels, numClasses,
            weights, currentCol - currentChildBegin, minimumGainSplit,
            maximumDepth - 1, dimensionSelector);
      }
      else
      {
        // During recursion, the entropy of the child node may change.
        double childGain = child->Train<UseWeights>(data, currentChildBegin,
            currentCol - currentChildBegin, datasetInfo, labels, numClasses,
            weights, minimumLeafSize, minimumGainSplit, maximumDepth - 1,
            dimensionSelector);
        bestGain += double(childCounts[i]) / double(count) * (-childGain);
      }
      children.push_back(child);
    }
  }
  else
  {
    // Clear auxiliary info objects.
    NumericAuxiliarySplitInfo::operator=(NumericAuxiliarySplitInfo());
    CategoricalAuxiliarySplitInfo::operator=(CategoricalAuxiliarySplitInfo());

    // Calculate class probabilities because we are a leaf.
    CalculateClassProbabilities<UseWeights>(
        labels.subvec(begin, begin + count - 1),
        numClasses,
        UseWeights ? weights.subvec(begin, begin + count - 1) : weights);
  }

  return -bestGain;
}
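
// A note on the bookkeeping above (illustrative only, assuming a Gini-style
// FitnessFunction whose Evaluate() is non-positive, with 0 meaning a pure
// node): each recursive Train() call returns -bestGain, a non-negative
// impurity.  When recursion is enabled, the parent accumulates the
// count-weighted sum of its children's negated returns, so with two children
// holding 30 and 70 of 100 points and returned impurities 0.2 and 0.1, this
// function returns 0.3 * 0.2 + 0.7 * 0.1 = 0.13.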

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
template<bool UseWeights, typename MatType>
double DecisionTree<FitnessFunction,
                    NumericSplitType,
                    CategoricalSplitType,
                    DimensionSelectionType,
                    NoRecursion>::Train(
    MatType& data,
    const size_t begin,
    const size_t count,
    arma::Row<size_t>& labels,
    const size_t numClasses,
    arma::rowvec& weights,
    const size_t minimumLeafSize,
    const double minimumGainSplit,
    const size_t maximumDepth,
    DimensionSelectionType& dimensionSelector)
{
  // Clear children if needed.
  for (size_t i = 0; i < children.size(); ++i)
    delete children[i];
  children.clear();

  // We won't be using these members, so reset them.
  CategoricalAuxiliarySplitInfo::operator=(CategoricalAuxiliarySplitInfo());

  // Look through the list of dimensions and obtain the best split.  We'll
  // cache the best numeric split auxiliary information in numericAux (and
  // clear it later if we don't make a split), and use classProbabilities as
  // auxiliary information.  Later we'll overwrite classProbabilities to the
  // empirical class probabilities if we do not split.
  double bestGain = FitnessFunction::template Evaluate<UseWeights>(
      labels.subvec(begin, begin + count - 1),
      numClasses,
      UseWeights ? weights.subvec(begin, begin + count - 1) : weights);
  size_t bestDim = data.n_rows; // This means "no split".

  if (maximumDepth != 1)
  {
    for (size_t i = dimensionSelector.Begin(); i != dimensionSelector.End();
         i = dimensionSelector.Next())
    {
      const double dimGain = NumericSplitType<FitnessFunction>::template
          SplitIfBetter<UseWeights>(bestGain,
          data.cols(begin, begin + count - 1).row(i),
          labels.cols(begin, begin + count - 1),
          numClasses,
          UseWeights ?
              weights.cols(begin, begin + count - 1) :
              weights,
          minimumLeafSize,
          minimumGainSplit,
          classProbabilities,
          *this);

      // If the splitter did not report that it improved, then move to the next
      // dimension.
      if (dimGain == DBL_MAX)
        continue;

      bestDim = i;
      bestGain = dimGain;

      // If the gain is the best possible, no need to keep looking.
      if (bestGain >= 0.0)
        break;
    }
  }

  // Did we split or not?  If so, then split the data and create the children.
  if (bestDim != data.n_rows)
  {
    // We know that the split is numeric.
    size_t numChildren =
        NumericSplit::NumChildren(classProbabilities[0], *this);
    splitDimension = bestDim;
    dimensionTypeOrMajorityClass = (size_t) data::Datatype::numeric;

    // Calculate all child assignments.
    arma::Row<size_t> childAssignments(count);

    for (size_t j = begin; j < begin + count; ++j)
    {
      childAssignments[j - begin] = NumericSplit::CalculateDirection(
          data(bestDim, j), classProbabilities[0], *this);
    }

    // Calculate counts of children in each node.
    arma::Row<size_t> childCounts(numChildren);
    childCounts.zeros();
    for (size_t j = begin; j < begin + count; ++j)
      childCounts[childAssignments[j - begin]]++;

    // Initialize bestGain if recursive split is allowed.
    if (!NoRecursion)
    {
      bestGain = 0.0;
    }

    size_t currentCol = begin;
    for (size_t i = 0; i < numChildren; ++i)
    {
      size_t currentChildBegin = currentCol;
      for (size_t j = currentChildBegin; j < begin + count; ++j)
      {
        if (childAssignments[j - begin] == i)
        {
          childAssignments.swap_cols(currentCol - begin, j - begin);
          data.swap_cols(currentCol, j);
          labels.swap_cols(currentCol, j);
          if (UseWeights)
            weights.swap_cols(currentCol, j);
          ++currentCol;
        }
      }

      // Now build the child recursively.
      DecisionTree* child = new DecisionTree();
      if (NoRecursion)
      {
        child->Train<UseWeights>(data, currentChildBegin,
            currentCol - currentChildBegin, labels, numClasses, weights,
            currentCol - currentChildBegin, minimumGainSplit, maximumDepth - 1,
            dimensionSelector);
      }
      else
      {
        // During recursion, the entropy of the child node may change.
        double childGain = child->Train<UseWeights>(data, currentChildBegin,
            currentCol - currentChildBegin, labels, numClasses, weights,
            minimumLeafSize, minimumGainSplit, maximumDepth - 1,
            dimensionSelector);
        bestGain += double(childCounts[i]) / double(count) * (-childGain);
      }
      children.push_back(child);
    }
  }
  else
  {
    // We won't be needing these members, so reset them.
    NumericAuxiliarySplitInfo::operator=(NumericAuxiliarySplitInfo());

    // Calculate class probabilities because we are a leaf.
    CalculateClassProbabilities<UseWeights>(
        labels.subvec(begin, begin + count - 1),
        numClasses,
        UseWeights ? weights.subvec(begin, begin + count - 1) : weights);
  }

  return -bestGain;
}

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
template<typename VecType>
size_t DecisionTree<FitnessFunction,
                    NumericSplitType,
                    CategoricalSplitType,
                    DimensionSelectionType,
                    NoRecursion>::Classify(const VecType& point) const
{
  if (children.size() == 0)
  {
    // Return cached max of probabilities.
    return dimensionTypeOrMajorityClass;
  }

  return children[CalculateDirection(point)]->Classify(point);
}
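
// Example usage (sketch; `tree` and `dataset` are hypothetical):
//
//   const size_t predictedClass = tree.Classify(dataset.col(0));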

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
template<typename VecType>
void DecisionTree<FitnessFunction,
                  NumericSplitType,
                  CategoricalSplitType,
                  DimensionSelectionType,
                  NoRecursion>::Classify(const VecType& point,
                                         size_t& prediction,
                                         arma::vec& probabilities) const
{
  if (children.size() == 0)
  {
    prediction = dimensionTypeOrMajorityClass;
    probabilities = classProbabilities;
    return;
  }

  children[CalculateDirection(point)]->Classify(point, prediction,
      probabilities);
}

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
template<typename MatType>
void DecisionTree<FitnessFunction,
                  NumericSplitType,
                  CategoricalSplitType,
                  DimensionSelectionType,
                  NoRecursion>::Classify(const MatType& data,
                                         arma::Row<size_t>& predictions) const
{
  predictions.set_size(data.n_cols);
  if (children.size() == 0)
  {
    predictions.fill(dimensionTypeOrMajorityClass);
    return;
  }

  // Loop over each point.
  for (size_t i = 0; i < data.n_cols; ++i)
    predictions[i] = Classify(data.col(i));
}

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
template<typename MatType>
void DecisionTree<FitnessFunction,
                  NumericSplitType,
                  CategoricalSplitType,
                  DimensionSelectionType,
                  NoRecursion>::Classify(const MatType& data,
                                         arma::Row<size_t>& predictions,
                                         arma::mat& probabilities) const
{
  predictions.set_size(data.n_cols);
  if (children.size() == 0)
  {
    predictions.fill(dimensionTypeOrMajorityClass);
    probabilities = arma::repmat(classProbabilities, 1, data.n_cols);
    return;
  }

  // Otherwise we have to find the right size to set the probabilities matrix
  // to.
  DecisionTree* node = children[0];
  while (node->NumChildren() != 0)
    node = &node->Child(0);
  probabilities.set_size(node->classProbabilities.n_elem, data.n_cols);

  for (size_t i = 0; i < data.n_cols; ++i)
  {
    arma::vec v = probabilities.unsafe_col(i); // Alias of column.
    Classify(data.col(i), predictions[i], v);
  }
}
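
// Example usage (sketch; `tree` and `testData` are hypothetical):
//
//   arma::Row<size_t> predictions;
//   arma::mat probabilities;  // One column of class probabilities per point.
//   tree.Classify(testData, predictions, probabilities);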

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
template<typename Archive>
void DecisionTree<FitnessFunction,
                  NumericSplitType,
                  CategoricalSplitType,
                  DimensionSelectionType,
                  NoRecursion>::serialize(Archive& ar,
                                          const uint32_t /* version */)
{
  // Clean memory if needed.
  if (cereal::is_loading<Archive>())
  {
    for (size_t i = 0; i < children.size(); ++i)
      delete children[i];
    children.clear();
  }
  // Serialize the children first.
  ar(CEREAL_VECTOR_POINTER(children));

  // Now serialize the rest of the object.
  ar(CEREAL_NVP(splitDimension));
  ar(CEREAL_NVP(dimensionTypeOrMajorityClass));
  ar(CEREAL_NVP(classProbabilities));
}
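
// Example usage (sketch, assuming mlpack's data::Save()/data::Load() model
// overloads are available; `tree` is a hypothetical trained tree):
//
//   mlpack::data::Save("tree.bin", "tree", tree);   // Serialize to disk.
//   DecisionTree<> tree2;
//   mlpack::data::Load("tree.bin", "tree", tree2);  // Deserialize.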

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
template<typename VecType>
size_t DecisionTree<FitnessFunction,
                    NumericSplitType,
                    CategoricalSplitType,
                    DimensionSelectionType,
                    NoRecursion>::CalculateDirection(const VecType& point) const
{
  if ((data::Datatype) dimensionTypeOrMajorityClass ==
      data::Datatype::categorical)
    return CategoricalSplit::CalculateDirection(point[splitDimension],
        classProbabilities[0], *this);
  else
    return NumericSplit::CalculateDirection(point[splitDimension],
        classProbabilities[0], *this);
}

// Get the number of classes in the tree.
template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
size_t DecisionTree<FitnessFunction,
                    NumericSplitType,
                    CategoricalSplitType,
                    DimensionSelectionType,
                    NoRecursion>::NumClasses() const
{
  // Recurse to the nearest leaf and return the number of elements in the
  // probability vector.
  if (children.size() == 0)
    return classProbabilities.n_elem;
  else
    return children[0]->NumClasses();
}

template<typename FitnessFunction,
         template<typename> class NumericSplitType,
         template<typename> class CategoricalSplitType,
         typename DimensionSelectionType,
         bool NoRecursion>
template<bool UseWeights, typename RowType, typename WeightsRowType>
void DecisionTree<FitnessFunction,
                  NumericSplitType,
                  CategoricalSplitType,
                  DimensionSelectionType,
                  NoRecursion>::CalculateClassProbabilities(
    const RowType& labels,
    const size_t numClasses,
    const WeightsRowType& weights)
{
  classProbabilities.zeros(numClasses);
  double sumWeights = 0.0;
  for (size_t i = 0; i < labels.n_elem; ++i)
  {
    if (UseWeights)
    {
      classProbabilities[labels[i]] += weights[i];
      sumWeights += weights[i];
    }
    else
    {
      classProbabilities[labels[i]]++;
    }
  }

  // Now normalize into probabilities.
  classProbabilities /= UseWeights ? sumWeights : labels.n_elem;
  arma::uword maxIndex = 0;
  classProbabilities.max(maxIndex);
  dimensionTypeOrMajorityClass = (size_t) maxIndex;
}
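
// Worked example (illustrative): with labels {0, 0, 1}, UseWeights = true, and
// weights {0.5, 0.5, 3.0}, the accumulated counts are {1.0, 3.0} and
// sumWeights = 4.0, so classProbabilities becomes {0.25, 0.75} and the
// majority class stored in dimensionTypeOrMajorityClass is 1.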

} // namespace tree
} // namespace mlpack

#endif