13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_REGRESSOR_HPP 14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_REGRESSOR_HPP 23 #include <type_traits> 36 template<
typename FitnessFunction = MSEGain,
37 template<
typename>
class NumericSplitType = BestBinaryNumericSplit,
38 template<
typename>
class CategoricalSplitType = AllCategoricalSplit,
39 typename DimensionSelectionType = AllDimensionSelect,
40 bool NoRecursion =
false>
42 public NumericSplitType<FitnessFunction>::AuxiliarySplitInfo,
43 public CategoricalSplitType<FitnessFunction>::AuxiliarySplitInfo
74 template<
typename MatType,
typename ResponsesType>
77 ResponsesType responses,
78 const size_t minimumLeafSize = 10,
79 const double minimumGainSplit = 1e-7,
80 const size_t maximumDepth = 0,
81 DimensionSelectionType dimensionSelector =
82 DimensionSelectionType());
99 template<
typename MatType,
typename ResponsesType>
101 ResponsesType responses,
102 const size_t minimumLeafSize = 10,
103 const double minimumGainSplit = 1e-7,
104 const size_t maximumDepth = 0,
105 DimensionSelectionType dimensionSelector =
106 DimensionSelectionType());
126 template<
typename MatType,
typename ResponsesType,
typename WeightsType>
130 ResponsesType responses,
132 const size_t minimumLeafSize = 10,
133 const double minimumGainSplit = 1e-7,
134 const size_t maximumDepth = 0,
135 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
136 const std::enable_if_t<arma::is_arma_type<
137 typename std::remove_reference<WeightsType>::type>::value>* = 0);
156 template<
typename MatType,
typename ResponsesType,
typename WeightsType>
159 ResponsesType responses,
161 const size_t minimumLeafSize = 10,
162 const double minimumGainSplit = 1e-7,
163 const size_t maximumDepth = 0,
164 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
165 const std::enable_if_t<arma::is_arma_type<
166 typename std::remove_reference<WeightsType>::type>::value>* = 0);
186 template<
typename MatType,
typename ResponsesType,
typename WeightsType>
191 ResponsesType responses,
193 const size_t minimumLeafSize = 10,
194 const double minimumGainSplit = 1e-7,
195 const std::enable_if_t<arma::is_arma_type<
196 typename std::remove_reference<WeightsType>::type>::value>* = 0);
215 template<
typename MatType,
typename ResponsesType,
typename WeightsType>
219 ResponsesType responses,
221 const size_t minimumLeafSize = 10,
222 const double minimumGainSplit = 1e-7,
223 const size_t maximumDepth = 0,
224 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
225 const std::enable_if_t<arma::is_arma_type<
226 typename std::remove_reference<WeightsType>::type>::value>* = 0);
281 template<
typename MatType,
typename ResponsesType>
282 double Train(MatType data,
284 ResponsesType responses,
285 const size_t minimumLeafSize = 10,
286 const double minimumGainSplit = 1e-7,
287 const size_t maximumDepth = 0,
288 DimensionSelectionType dimensionSelector =
289 DimensionSelectionType());
307 template<
typename MatType,
typename ResponsesType>
308 double Train(MatType data,
309 ResponsesType responses,
310 const size_t minimumLeafSize = 10,
311 const double minimumGainSplit = 1e-7,
312 const size_t maximumDepth = 0,
313 DimensionSelectionType dimensionSelector =
314 DimensionSelectionType());
336 template<
typename MatType,
typename ResponsesType,
typename WeightsType>
337 double Train(MatType data,
339 ResponsesType responses,
341 const size_t minimumLeafSize = 10,
342 const double minimumGainSplit = 1e-7,
343 const size_t maximumDepth = 0,
344 DimensionSelectionType dimensionSelector =
345 DimensionSelectionType(),
346 const std::enable_if_t<arma::is_arma_type<
typename 347 std::remove_reference<WeightsType>::type>::value>* = 0);
367 template<
typename MatType,
typename ResponsesType,
typename WeightsType>
368 double Train(MatType data,
369 ResponsesType responses,
371 const size_t minimumLeafSize = 10,
372 const double minimumGainSplit = 1e-7,
373 const size_t maximumDepth = 0,
374 DimensionSelectionType dimensionSelector =
375 DimensionSelectionType(),
376 const std::enable_if_t<arma::is_arma_type<
typename 377 std::remove_reference<WeightsType>::type>::value>* = 0);
385 template<
typename VecType>
386 double Predict(
const VecType& point)
const;
395 template<
typename MatType>
396 void Predict(
const MatType& data,
397 arma::Row<double>& predictions)
const;
402 template<
typename Archive>
403 void serialize(Archive& ar,
const uint32_t );
430 template<
typename VecType>
435 std::vector<DecisionTreeRegressor*> children;
437 size_t splitDimension;
440 size_t dimensionType;
448 double splitPointOrPrediction;
453 typedef typename NumericSplit::AuxiliarySplitInfo
454 NumericAuxiliarySplitInfo;
455 typedef typename CategoricalSplit::AuxiliarySplitInfo
456 CategoricalAuxiliarySplitInfo;
461 template<
bool UseWeights,
typename ResponsesType,
typename WeightsType>
462 void CalculatePrediction(
const ResponsesType& responses,
463 const WeightsType& weights);
481 template<
bool UseWeights,
typename MatType,
typename ResponsesType>
482 double Train(MatType& data,
486 ResponsesType& responses,
487 arma::rowvec& weights,
488 const size_t minimumLeafSize,
489 const double minimumGainSplit,
490 const size_t maximumDepth,
491 DimensionSelectionType& dimensionSelector);
508 template<
bool UseWeights,
typename MatType,
typename ResponsesType>
509 double Train(MatType& data,
512 ResponsesType& responses,
513 arma::rowvec& weights,
514 const size_t minimumLeafSize,
515 const double minimumGainSplit,
516 const size_t maximumDepth,
517 DimensionSelectionType& dimensionSelector);
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
Definition: decision_tree_regressor.hpp:47
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
Definition: dataset_mapper.hpp:41
size_t NumChildren() const
Get the number of children.
Definition: decision_tree_regressor.hpp:406
size_t SplitDimension() const
Get the split dimension (only meaningful if this is a non-leaf in a trained tree).
Definition: decision_tree_regressor.hpp:421
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
size_t NumLeaves() const
Get the number of leaves in the tree.
Definition: decision_tree_regressor_impl.hpp:1058
double Train(MatType data, const data::DatasetInfo &datasetInfo, ResponsesType responses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data.
Definition: decision_tree_regressor_impl.hpp:428
DecisionTreeRegressor & operator=(const DecisionTreeRegressor &other)
Copy another tree.
Definition: decision_tree_regressor_impl.hpp:336
DecisionTreeRegressor()
Construct a decision tree without training it.
Definition: decision_tree_regressor_impl.hpp:31
The core includes that mlpack expects; standard C++ includes and Armadillo.
DecisionTreeRegressor & Child(const size_t i)
Modify the child of the given index (be careful!).
Definition: decision_tree_regressor.hpp:417
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
Definition: decision_tree_regressor.hpp:51
This class implements a generic decision tree learner.
Definition: decision_tree_regressor.hpp:41
double Predict(const VecType &point) const
Make prediction for the given point, using the entire tree.
Definition: decision_tree_regressor_impl.hpp:930
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point woul...
Definition: decision_tree_regressor_impl.hpp:1008
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
Definition: decision_tree_regressor.hpp:49
void serialize(Archive &ar, const uint32_t)
Serialize the tree.
Definition: decision_tree_regressor_impl.hpp:1030
const DecisionTreeRegressor & Child(const size_t i) const
Get the child of the given index.
Definition: decision_tree_regressor.hpp:412
~DecisionTreeRegressor()
Clean up memory.
Definition: decision_tree_regressor_impl.hpp:411