mlpack
decision_tree.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
15 
#include <mlpack/prereqs.hpp>
#include "gini_gain.hpp"
#include "information_gain.hpp"
#include "best_binary_numeric_split.hpp"
#include "all_categorical_split.hpp"
#include "all_dimension_select.hpp"
#include <type_traits>
24 
25 namespace mlpack {
26 namespace tree {
27 
35 template<typename FitnessFunction = GiniGain,
36  template<typename> class NumericSplitType = BestBinaryNumericSplit,
37  template<typename> class CategoricalSplitType = AllCategoricalSplit,
38  typename DimensionSelectionType = AllDimensionSelect,
39  bool NoRecursion = false>
40 class DecisionTree :
41  public NumericSplitType<FitnessFunction>::AuxiliarySplitInfo,
42  public CategoricalSplitType<FitnessFunction>::AuxiliarySplitInfo
43 {
44  public:
46  typedef NumericSplitType<FitnessFunction> NumericSplit;
48  typedef CategoricalSplitType<FitnessFunction> CategoricalSplit;
50  typedef DimensionSelectionType DimensionSelection;
51 
69  template<typename MatType, typename LabelsType>
70  DecisionTree(MatType data,
71  const data::DatasetInfo& datasetInfo,
72  LabelsType labels,
73  const size_t numClasses,
74  const size_t minimumLeafSize = 10,
75  const double minimumGainSplit = 1e-7,
76  const size_t maximumDepth = 0,
77  DimensionSelectionType dimensionSelector =
78  DimensionSelectionType());
79 
96  template<typename MatType, typename LabelsType>
97  DecisionTree(MatType data,
98  LabelsType labels,
99  const size_t numClasses,
100  const size_t minimumLeafSize = 10,
101  const double minimumGainSplit = 1e-7,
102  const size_t maximumDepth = 0,
103  DimensionSelectionType dimensionSelector =
104  DimensionSelectionType());
105 
125  template<typename MatType, typename LabelsType, typename WeightsType>
126  DecisionTree(
127  MatType data,
128  const data::DatasetInfo& datasetInfo,
129  LabelsType labels,
130  const size_t numClasses,
131  WeightsType weights,
132  const size_t minimumLeafSize = 10,
133  const double minimumGainSplit = 1e-7,
134  const size_t maximumDepth = 0,
135  DimensionSelectionType dimensionSelector = DimensionSelectionType(),
136  const std::enable_if_t<arma::is_arma_type<
137  typename std::remove_reference<WeightsType>::type>::value>* = 0);
138 
157  template<typename MatType, typename LabelsType, typename WeightsType>
158  DecisionTree(
159  const DecisionTree& other,
160  MatType data,
161  const data::DatasetInfo& datasetInfo,
162  LabelsType labels,
163  const size_t numClasses,
164  WeightsType weights,
165  const size_t minimumLeafSize = 10,
166  const double minimumGainSplit = 1e-7,
167  const std::enable_if_t<arma::is_arma_type<
168  typename std::remove_reference<WeightsType>::type>::value>* = 0);
187  template<typename MatType, typename LabelsType, typename WeightsType>
188  DecisionTree(
189  MatType data,
190  LabelsType labels,
191  const size_t numClasses,
192  WeightsType weights,
193  const size_t minimumLeafSize = 10,
194  const double minimumGainSplit = 1e-7,
195  const size_t maximumDepth = 0,
196  DimensionSelectionType dimensionSelector = DimensionSelectionType(),
197  const std::enable_if_t<arma::is_arma_type<
198  typename std::remove_reference<WeightsType>::type>::value>* = 0);
199 
218  template<typename MatType, typename LabelsType, typename WeightsType>
219  DecisionTree(
220  const DecisionTree& other,
221  MatType data,
222  LabelsType labels,
223  const size_t numClasses,
224  WeightsType weights,
225  const size_t minimumLeafSize = 10,
226  const double minimumGainSplit = 1e-7,
227  const size_t maximumDepth = 0,
228  DimensionSelectionType dimensionSelector = DimensionSelectionType(),
229  const std::enable_if_t<arma::is_arma_type<
230  typename std::remove_reference<WeightsType>::type>::value>* = 0);
231 
238  DecisionTree(const size_t numClasses = 1);
239 
246  DecisionTree(const DecisionTree& other);
247 
253  DecisionTree(DecisionTree&& other);
254 
261  DecisionTree& operator=(const DecisionTree& other);
262 
269 
273  ~DecisionTree();
274 
294  template<typename MatType, typename LabelsType>
295  double Train(MatType data,
296  const data::DatasetInfo& datasetInfo,
297  LabelsType labels,
298  const size_t numClasses,
299  const size_t minimumLeafSize = 10,
300  const double minimumGainSplit = 1e-7,
301  const size_t maximumDepth = 0,
302  DimensionSelectionType dimensionSelector =
303  DimensionSelectionType());
304 
322  template<typename MatType, typename LabelsType>
323  double Train(MatType data,
324  LabelsType labels,
325  const size_t numClasses,
326  const size_t minimumLeafSize = 10,
327  const double minimumGainSplit = 1e-7,
328  const size_t maximumDepth = 0,
329  DimensionSelectionType dimensionSelector =
330  DimensionSelectionType());
331 
353  template<typename MatType, typename LabelsType, typename WeightsType>
354  double Train(MatType data,
355  const data::DatasetInfo& datasetInfo,
356  LabelsType labels,
357  const size_t numClasses,
358  WeightsType weights,
359  const size_t minimumLeafSize = 10,
360  const double minimumGainSplit = 1e-7,
361  const size_t maximumDepth = 0,
362  DimensionSelectionType dimensionSelector =
363  DimensionSelectionType(),
364  const std::enable_if_t<arma::is_arma_type<typename
365  std::remove_reference<WeightsType>::type>::value>* = 0);
366 
386  template<typename MatType, typename LabelsType, typename WeightsType>
387  double Train(MatType data,
388  LabelsType labels,
389  const size_t numClasses,
390  WeightsType weights,
391  const size_t minimumLeafSize = 10,
392  const double minimumGainSplit = 1e-7,
393  const size_t maximumDepth = 0,
394  DimensionSelectionType dimensionSelector =
395  DimensionSelectionType(),
396  const std::enable_if_t<arma::is_arma_type<typename
397  std::remove_reference<WeightsType>::type>::value>* = 0);
398 
405  template<typename VecType>
406  size_t Classify(const VecType& point) const;
407 
417  template<typename VecType>
418  void Classify(const VecType& point,
419  size_t& prediction,
420  arma::vec& probabilities) const;
421 
429  template<typename MatType>
430  void Classify(const MatType& data,
431  arma::Row<size_t>& predictions) const;
432 
443  template<typename MatType>
444  void Classify(const MatType& data,
445  arma::Row<size_t>& predictions,
446  arma::mat& probabilities) const;
447 
451  template<typename Archive>
452  void serialize(Archive& ar, const uint32_t /* version */);
453 
455  size_t NumChildren() const { return children.size(); }
456 
458  const DecisionTree& Child(const size_t i) const { return *children[i]; }
460  DecisionTree& Child(const size_t i) { return *children[i]; }
461 
464  size_t SplitDimension() const { return splitDimension; }
465 
473  template<typename VecType>
474  size_t CalculateDirection(const VecType& point) const;
475 
479  size_t NumClasses() const;
480 
481  private:
483  std::vector<DecisionTree*> children;
485  size_t splitDimension;
488  size_t dimensionTypeOrMajorityClass;
496  arma::vec classProbabilities;
497 
501  typedef typename NumericSplit::AuxiliarySplitInfo
502  NumericAuxiliarySplitInfo;
503  typedef typename CategoricalSplit::AuxiliarySplitInfo
504  CategoricalAuxiliarySplitInfo;
505 
509  template<bool UseWeights, typename RowType, typename WeightsRowType>
510  void CalculateClassProbabilities(const RowType& labels,
511  const size_t numClasses,
512  const WeightsRowType& weights);
513 
531  template<bool UseWeights, typename MatType>
532  double Train(MatType& data,
533  const size_t begin,
534  const size_t count,
535  const data::DatasetInfo& datasetInfo,
536  arma::Row<size_t>& labels,
537  const size_t numClasses,
538  arma::rowvec& weights,
539  const size_t minimumLeafSize,
540  const double minimumGainSplit,
541  const size_t maximumDepth,
542  DimensionSelectionType& dimensionSelector);
543 
560  template<bool UseWeights, typename MatType>
561  double Train(MatType& data,
562  const size_t begin,
563  const size_t count,
564  arma::Row<size_t>& labels,
565  const size_t numClasses,
566  arma::rowvec& weights,
567  const size_t minimumLeafSize,
568  const double minimumGainSplit,
569  const size_t maximumDepth,
570  DimensionSelectionType& dimensionSelector);
571 };
572 
576 template<typename FitnessFunction = GiniGain,
577  template<typename> class NumericSplitType = BestBinaryNumericSplit,
578  template<typename> class CategoricalSplitType = AllCategoricalSplit,
579  typename DimensionSelectType = AllDimensionSelect>
580 using DecisionStump = DecisionTree<FitnessFunction,
581  NumericSplitType,
582  CategoricalSplitType,
583  DimensionSelectType,
584  false>;
585 
595 } // namespace tree
596 } // namespace mlpack
597 
598 // Include implementation.
599 #include "decision_tree_impl.hpp"
600 
601 #endif
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
Definition: dataset_mapper.hpp:41
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a...
Definition: best_binary_numeric_split.hpp:49
size_t SplitDimension() const
Get the split dimension (only meaningful if this is a non-leaf in a trained tree).
Definition: decision_tree.hpp:464
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
DecisionTree< InformationGain, BestBinaryNumericSplit, AllCategoricalSplit, AllDimensionSelect, true > ID3DecisionStump
Convenience typedef for ID3 decision stumps (single level decision trees made with the ID3 algorithm)...
Definition: decision_tree.hpp:594
This class implements a generic decision tree learner.
Definition: decision_tree.hpp:40
The core includes that mlpack expects; standard C++ includes and Armadillo.
The AllCategoricalSplit is a splitting function that will split categorical features into many children.
Definition: all_categorical_split.hpp:30
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
Definition: decision_tree.hpp:48
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
Definition: decision_tree_impl.hpp:948
The standard information gain criterion, used for calculating gain in decision trees.
Definition: information_gain.hpp:25
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data.
Definition: decision_tree_impl.hpp:433
void serialize(Archive &ar, const uint32_t)
Serialize the tree.
Definition: decision_tree_impl.hpp:1059
size_t CalculateDirection(const VecType &point) const
Given a point, and assuming this node is not a leaf, calculate the index of the child node this point would go towards.
Definition: decision_tree_impl.hpp:1088
This dimension selection policy allows any dimension to be selected for splitting.
Definition: all_dimension_select.hpp:22
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
Definition: decision_tree.hpp:46
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision trees.
Definition: gini_gain.hpp:27
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
Definition: decision_tree_impl.hpp:339
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
Definition: decision_tree.hpp:460
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
Definition: decision_tree.hpp:50
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
Definition: decision_tree.hpp:458
~DecisionTree()
Clean up memory.
Definition: decision_tree_impl.hpp:416
size_t NumClasses() const
Get the number of classes in the tree.
Definition: decision_tree_impl.hpp:1109
size_t NumChildren() const
Get the number of children.
Definition: decision_tree.hpp:455
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.
Definition: decision_tree_impl.hpp:31