13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP 14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP 23 #include <type_traits> 35 template<
typename FitnessFunction = GiniGain,
36 template<
typename>
class NumericSplitType = BestBinaryNumericSplit,
37 template<
typename>
class CategoricalSplitType = AllCategoricalSplit,
38 typename DimensionSelectionType = AllDimensionSelect,
39 bool NoRecursion =
false>
41 public NumericSplitType<FitnessFunction>::AuxiliarySplitInfo,
42 public CategoricalSplitType<FitnessFunction>::AuxiliarySplitInfo
69 template<
typename MatType,
typename LabelsType>
73 const size_t numClasses,
74 const size_t minimumLeafSize = 10,
75 const double minimumGainSplit = 1e-7,
76 const size_t maximumDepth = 0,
77 DimensionSelectionType dimensionSelector =
78 DimensionSelectionType());
96 template<
typename MatType,
typename LabelsType>
99 const size_t numClasses,
100 const size_t minimumLeafSize = 10,
101 const double minimumGainSplit = 1e-7,
102 const size_t maximumDepth = 0,
103 DimensionSelectionType dimensionSelector =
104 DimensionSelectionType());
125 template<
typename MatType,
typename LabelsType,
typename WeightsType>
130 const size_t numClasses,
132 const size_t minimumLeafSize = 10,
133 const double minimumGainSplit = 1e-7,
134 const size_t maximumDepth = 0,
135 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
136 const std::enable_if_t<arma::is_arma_type<
137 typename std::remove_reference<WeightsType>::type>::value>* = 0);
157 template<
typename MatType,
typename LabelsType,
typename WeightsType>
163 const size_t numClasses,
165 const size_t minimumLeafSize = 10,
166 const double minimumGainSplit = 1e-7,
167 const std::enable_if_t<arma::is_arma_type<
168 typename std::remove_reference<WeightsType>::type>::value>* = 0);
187 template<
typename MatType,
typename LabelsType,
typename WeightsType>
191 const size_t numClasses,
193 const size_t minimumLeafSize = 10,
194 const double minimumGainSplit = 1e-7,
195 const size_t maximumDepth = 0,
196 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
197 const std::enable_if_t<arma::is_arma_type<
198 typename std::remove_reference<WeightsType>::type>::value>* = 0);
218 template<
typename MatType,
typename LabelsType,
typename WeightsType>
223 const size_t numClasses,
225 const size_t minimumLeafSize = 10,
226 const double minimumGainSplit = 1e-7,
227 const size_t maximumDepth = 0,
228 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
229 const std::enable_if_t<arma::is_arma_type<
230 typename std::remove_reference<WeightsType>::type>::value>* = 0);
294 template<
typename MatType,
typename LabelsType>
295 double Train(MatType data,
298 const size_t numClasses,
299 const size_t minimumLeafSize = 10,
300 const double minimumGainSplit = 1e-7,
301 const size_t maximumDepth = 0,
302 DimensionSelectionType dimensionSelector =
303 DimensionSelectionType());
322 template<
typename MatType,
typename LabelsType>
323 double Train(MatType data,
325 const size_t numClasses,
326 const size_t minimumLeafSize = 10,
327 const double minimumGainSplit = 1e-7,
328 const size_t maximumDepth = 0,
329 DimensionSelectionType dimensionSelector =
330 DimensionSelectionType());
353 template<
typename MatType,
typename LabelsType,
typename WeightsType>
354 double Train(MatType data,
357 const size_t numClasses,
359 const size_t minimumLeafSize = 10,
360 const double minimumGainSplit = 1e-7,
361 const size_t maximumDepth = 0,
362 DimensionSelectionType dimensionSelector =
363 DimensionSelectionType(),
364 const std::enable_if_t<arma::is_arma_type<
typename 365 std::remove_reference<WeightsType>::type>::value>* = 0);
386 template<
typename MatType,
typename LabelsType,
typename WeightsType>
387 double Train(MatType data,
389 const size_t numClasses,
391 const size_t minimumLeafSize = 10,
392 const double minimumGainSplit = 1e-7,
393 const size_t maximumDepth = 0,
394 DimensionSelectionType dimensionSelector =
395 DimensionSelectionType(),
396 const std::enable_if_t<arma::is_arma_type<
typename 397 std::remove_reference<WeightsType>::type>::value>* = 0);
405 template<
typename VecType>
406 size_t Classify(
const VecType& point)
const;
417 template<
typename VecType>
420 arma::vec& probabilities)
const;
429 template<
typename MatType>
431 arma::Row<size_t>& predictions)
const;
443 template<
typename MatType>
445 arma::Row<size_t>& predictions,
446 arma::mat& probabilities)
const;
451 template<
typename Archive>
452 void serialize(Archive& ar,
const uint32_t );
473 template<
typename VecType>
483 std::vector<DecisionTree*> children;
485 size_t splitDimension;
488 size_t dimensionTypeOrMajorityClass;
496 arma::vec classProbabilities;
501 typedef typename NumericSplit::AuxiliarySplitInfo
502 NumericAuxiliarySplitInfo;
503 typedef typename CategoricalSplit::AuxiliarySplitInfo
504 CategoricalAuxiliarySplitInfo;
509 template<
bool UseWeights,
typename RowType,
typename WeightsRowType>
510 void CalculateClassProbabilities(
const RowType& labels,
511 const size_t numClasses,
512 const WeightsRowType& weights);
531 template<
bool UseWeights,
typename MatType>
532 double Train(MatType& data,
536 arma::Row<size_t>& labels,
537 const size_t numClasses,
538 arma::rowvec& weights,
539 const size_t minimumLeafSize,
540 const double minimumGainSplit,
541 const size_t maximumDepth,
542 DimensionSelectionType& dimensionSelector);
560 template<
bool UseWeights,
typename MatType>
561 double Train(MatType& data,
564 arma::Row<size_t>& labels,
565 const size_t numClasses,
566 arma::rowvec& weights,
567 const size_t minimumLeafSize,
568 const double minimumGainSplit,
569 const size_t maximumDepth,
570 DimensionSelectionType& dimensionSelector);
576 template<
typename FitnessFunction =
GiniGain,
582 CategoricalSplitType,
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
Definition: dataset_mapper.hpp:41
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a...
Definition: best_binary_numeric_split.hpp:49
size_t SplitDimension() const
Get the split dimension (only meaningful if this is a non-leaf in a trained tree).
Definition: decision_tree.hpp:464
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
DecisionTree< InformationGain, BestBinaryNumericSplit, AllCategoricalSplit, AllDimensionSelect, true > ID3DecisionStump
Convenience typedef for ID3 decision stumps (single level decision trees made with the ID3 algorithm)...
Definition: decision_tree.hpp:594
This class implements a generic decision tree learner.
Definition: decision_tree.hpp:40
The core includes that mlpack expects; standard C++ includes and Armadillo.
The AllCategoricalSplit is a splitting function that will split categorical features into many childr...
Definition: all_categorical_split.hpp:30
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
Definition: decision_tree.hpp:48
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
Definition: decision_tree_impl.hpp:948
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data.
Definition: decision_tree_impl.hpp:433
void serialize(Archive &ar, const uint32_t)
Serialize the tree.
Definition: decision_tree_impl.hpp:1059
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point woul...
Definition: decision_tree_impl.hpp:1088
This dimension selection policy allows any dimension to be selected for splitting.
Definition: all_dimension_select.hpp:22
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
Definition: decision_tree.hpp:46
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision tr...
Definition: gini_gain.hpp:27
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
Definition: decision_tree_impl.hpp:339
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
Definition: decision_tree.hpp:460
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
Definition: decision_tree.hpp:50
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
Definition: decision_tree.hpp:458
~DecisionTree()
Clean up memory.
Definition: decision_tree_impl.hpp:416
size_t NumClasses() const
Get the number of classes in the tree.
Definition: decision_tree_impl.hpp:1109
size_t NumChildren() const
Get the number of children.
Definition: decision_tree.hpp:455
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Construct the decision tree on the given data and labels, where the data can be both numeric and cate...
Definition: decision_tree_impl.hpp:31