12 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_REGRESSOR_IMPL_HPP 13 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_REGRESSOR_IMPL_HPP 22 template<
typename FitnessFunction,
23 template<
typename>
class NumericSplitType,
24 template<
typename>
class CategoricalSplitType,
25 typename DimensionSelectionType,
27 DecisionTreeRegressor<FitnessFunction,
30 DimensionSelectionType,
34 splitPointOrPrediction(0.0)
40 template<
typename FitnessFunction,
41 template<
typename>
class NumericSplitType,
42 template<
typename>
class CategoricalSplitType,
43 typename DimensionSelectionType,
45 template<
typename MatType,
typename ResponsesType>
49 DimensionSelectionType,
53 ResponsesType responses,
54 const size_t minimumLeafSize,
55 const double minimumGainSplit,
56 const size_t maximumDepth,
57 DimensionSelectionType dimensionSelector)
59 using TrueMatType =
typename std::decay<MatType>::type;
60 using TrueResponsesType =
typename std::decay<ResponsesType>::type;
63 TrueMatType tmpData(std::move(data));
64 TrueResponsesType tmpResponses(std::move(responses));
67 dimensionSelector.Dimensions() = tmpData.n_rows;
71 Train<false>(tmpData, 0, tmpData.n_cols, datasetInfo, tmpResponses,
72 weights, minimumLeafSize, minimumGainSplit, maximumDepth,
77 template<
typename FitnessFunction,
78 template<
typename>
class NumericSplitType,
79 template<
typename>
class CategoricalSplitType,
80 typename DimensionSelectionType,
82 template<
typename MatType,
typename ResponsesType>
86 DimensionSelectionType,
89 ResponsesType responses,
90 const size_t minimumLeafSize,
91 const double minimumGainSplit,
92 const size_t maximumDepth,
93 DimensionSelectionType dimensionSelector)
95 using TrueMatType =
typename std::decay<MatType>::type;
96 using TrueResponsesType =
typename std::decay<ResponsesType>::type;
99 TrueMatType tmpData(std::move(data));
100 TrueResponsesType tmpResponses(std::move(responses));
103 dimensionSelector.Dimensions() = tmpData.n_rows;
106 arma::rowvec weights;
107 Train<false>(tmpData, 0, tmpData.n_cols, tmpResponses, weights,
108 minimumLeafSize, minimumGainSplit, maximumDepth, dimensionSelector);
112 template<
typename FitnessFunction,
113 template<
typename>
class NumericSplitType,
114 template<
typename>
class CategoricalSplitType,
115 typename DimensionSelectionType,
117 template<
typename MatType,
typename ResponsesType,
typename WeightsType>
120 CategoricalSplitType,
121 DimensionSelectionType,
125 ResponsesType responses,
127 const size_t minimumLeafSize,
128 const double minimumGainSplit,
129 const size_t maximumDepth,
130 DimensionSelectionType dimensionSelector,
131 const std::enable_if_t<arma::is_arma_type<
132 typename std::remove_reference<WeightsType>::type>::value>*)
134 using TrueMatType =
typename std::decay<MatType>::type;
135 using TrueResponsesType =
typename std::decay<ResponsesType>::type;
136 using TrueWeightsType =
typename std::decay<WeightsType>::type;
138 TrueMatType tmpData(std::move(data));
139 TrueResponsesType tmpResponses(std::move(responses));
140 TrueWeightsType tmpWeights(std::move(weights));
143 dimensionSelector.Dimensions() = tmpData.n_rows;
146 Train<true>(tmpData, 0, tmpData.n_cols, datasetInfo, tmpResponses,
147 tmpWeights, minimumLeafSize, minimumGainSplit, maximumDepth,
152 template<
typename FitnessFunction,
153 template<
typename>
class NumericSplitType,
154 template<
typename>
class CategoricalSplitType,
155 typename DimensionSelectionType,
157 template<
typename MatType,
typename ResponsesType,
typename WeightsType>
160 CategoricalSplitType,
161 DimensionSelectionType,
164 ResponsesType responses,
166 const size_t minimumLeafSize,
167 const double minimumGainSplit,
168 const size_t maximumDepth,
169 DimensionSelectionType dimensionSelector,
170 const std::enable_if_t<
172 typename std::remove_reference<
173 WeightsType>::type>::value>*)
175 using TrueMatType =
typename std::decay<MatType>::type;
176 using TrueResponsesType =
typename std::decay<ResponsesType>::type;
177 using TrueWeightsType =
typename std::decay<WeightsType>::type;
180 TrueMatType tmpData(std::move(data));
181 TrueResponsesType tmpResponses(std::move(responses));
182 TrueWeightsType tmpWeights(std::move(weights));
185 dimensionSelector.Dimensions() = tmpData.n_rows;
188 Train<true>(tmpData, 0, tmpData.n_cols, tmpResponses, tmpWeights,
189 minimumLeafSize, minimumGainSplit, maximumDepth, dimensionSelector);
193 template<
typename FitnessFunction,
194 template<
typename>
class NumericSplitType,
195 template<
typename>
class CategoricalSplitType,
196 typename DimensionSelectionType,
198 template<
typename MatType,
typename ResponsesType,
typename WeightsType>
201 CategoricalSplitType,
202 DimensionSelectionType,
207 ResponsesType responses,
209 const size_t minimumLeafSize,
210 const double minimumGainSplit,
211 const std::enable_if_t<arma::is_arma_type<
212 typename std::remove_reference<WeightsType>::type>::value>*):
213 NumericAuxiliarySplitInfo(other),
214 CategoricalAuxiliarySplitInfo(other)
216 using TrueMatType =
typename std::decay<MatType>::type;
217 using TrueResponsesType =
typename std::decay<ResponsesType>::type;
218 using TrueWeightsType =
typename std::decay<WeightsType>::type;
221 TrueMatType tmpData(std::move(data));
222 TrueResponsesType tmpResponses(std::move(responses));
223 TrueWeightsType tmpWeights(std::move(weights));
226 Train<true>(tmpData, 0, tmpData.n_cols, datasetInfo, tmpResponses,
227 tmpWeights, minimumLeafSize, minimumGainSplit);
231 template<
typename FitnessFunction,
232 template<
typename>
class NumericSplitType,
233 template<
typename>
class CategoricalSplitType,
234 typename DimensionSelectionType,
236 template<
typename MatType,
typename ResponsesType,
typename WeightsType>
239 CategoricalSplitType,
240 DimensionSelectionType,
244 ResponsesType responses,
246 const size_t minimumLeafSize,
247 const double minimumGainSplit,
248 const size_t maximumDepth,
249 DimensionSelectionType dimensionSelector,
250 const std::enable_if_t<arma::is_arma_type<
251 typename std::remove_reference<
252 WeightsType>::type>::value>*):
253 NumericAuxiliarySplitInfo(other),
254 CategoricalAuxiliarySplitInfo(other)
256 using TrueMatType =
typename std::decay<MatType>::type;
257 using TrueResponsesType =
typename std::decay<ResponsesType>::type;
258 using TrueWeightsType =
typename std::decay<WeightsType>::type;
261 TrueMatType tmpData(std::move(data));
262 TrueResponsesType tmpResponses(std::move(responses));
263 TrueWeightsType tmpWeights(std::move(weights));
266 dimensionSelector.Dimensions() = tmpData.n_rows;
269 Train<true>(tmpData, 0, tmpData.n_cols, tmpResponses, tmpWeights,
270 minimumLeafSize, minimumGainSplit, maximumDepth, dimensionSelector);
274 template<
typename FitnessFunction,
275 template<
typename>
class NumericSplitType,
276 template<
typename>
class CategoricalSplitType,
277 typename DimensionSelectionType,
281 CategoricalSplitType,
282 DimensionSelectionType,
286 NumericAuxiliarySplitInfo(other),
287 CategoricalAuxiliarySplitInfo(other),
288 splitDimension(other.splitDimension),
289 dimensionType(other.dimensionType),
290 splitPointOrPrediction(other.splitPointOrPrediction)
293 for (
size_t i = 0; i < other.children.size(); ++i)
298 template<
typename FitnessFunction,
299 template<
typename>
class NumericSplitType,
300 template<
typename>
class CategoricalSplitType,
301 typename DimensionSelectionType,
305 CategoricalSplitType,
306 DimensionSelectionType,
310 NumericAuxiliarySplitInfo(
std::move(other)),
311 CategoricalAuxiliarySplitInfo(
std::move(other)),
312 children(
std::move(other.children)),
313 splitDimension(other.splitDimension),
314 dimensionType(other.dimensionType),
315 splitPointOrPrediction(other.splitPointOrPrediction)
321 template<
typename FitnessFunction,
322 template<
typename>
class NumericSplitType,
323 template<
typename>
class CategoricalSplitType,
324 typename DimensionSelectionType,
328 CategoricalSplitType,
329 DimensionSelectionType,
333 CategoricalSplitType,
334 DimensionSelectionType,
342 for (
size_t i = 0; i < children.size(); ++i)
347 splitDimension = other.splitDimension;
348 dimensionType = other.dimensionType;
349 splitPointOrPrediction = other.splitPointOrPrediction;
352 for (
size_t i = 0; i < other.children.size(); ++i)
356 NumericAuxiliarySplitInfo::operator=(other);
357 CategoricalAuxiliarySplitInfo::operator=(other);
363 template<
typename FitnessFunction,
364 template<
typename>
class NumericSplitType,
365 template<
typename>
class CategoricalSplitType,
366 typename DimensionSelectionType,
370 CategoricalSplitType,
371 DimensionSelectionType,
375 CategoricalSplitType,
376 DimensionSelectionType,
384 for (
size_t i = 0; i < children.size(); ++i)
389 children = std::move(other.children);
390 splitDimension = other.splitDimension;
391 dimensionType = other.dimensionType;
392 splitPointOrPrediction = other.splitPointOrPrediction;
395 NumericAuxiliarySplitInfo::operator=(std::move(other));
396 CategoricalAuxiliarySplitInfo::operator=(std::move(other));
402 template<
typename FitnessFunction,
403 template<
typename>
class NumericSplitType,
404 template<
typename>
class CategoricalSplitType,
405 typename DimensionSelectionType,
409 CategoricalSplitType,
410 DimensionSelectionType,
413 for (
size_t i = 0; i < children.size(); ++i)
418 template<
typename FitnessFunction,
419 template<
typename>
class NumericSplitType,
420 template<
typename>
class CategoricalSplitType,
421 typename DimensionSelectionType,
423 template<
typename MatType,
typename ResponsesType>
426 CategoricalSplitType,
427 DimensionSelectionType,
431 ResponsesType responses,
432 const size_t minimumLeafSize,
433 const double minimumGainSplit,
434 const size_t maximumDepth,
435 DimensionSelectionType dimensionSelector)
438 util::CheckSameSizes(data, responses,
"DecisionTreeRegressor::Train()");
440 using TrueMatType =
typename std::decay<MatType>::type;
441 using TrueResponsesType =
typename std::decay<ResponsesType>::type;
444 TrueMatType tmpData(std::move(data));
445 TrueResponsesType tmpResponses(std::move(responses));
448 dimensionSelector.Dimensions() = tmpData.n_rows;
451 arma::rowvec weights;
452 return Train<false>(tmpData, 0, tmpData.n_cols, datasetInfo, tmpResponses,
453 weights, minimumLeafSize, minimumGainSplit, maximumDepth,
458 template<
typename FitnessFunction,
459 template<
typename>
class NumericSplitType,
460 template<
typename>
class CategoricalSplitType,
461 typename DimensionSelectionType,
463 template<
typename MatType,
typename ResponsesType>
466 CategoricalSplitType,
467 DimensionSelectionType,
470 ResponsesType responses,
471 const size_t minimumLeafSize,
472 const double minimumGainSplit,
473 const size_t maximumDepth,
474 DimensionSelectionType dimensionSelector)
477 util::CheckSameSizes(data, responses,
"DecisionTreeRegressor::Train()");
479 using TrueMatType =
typename std::decay<MatType>::type;
480 using TrueResponsesType =
typename std::decay<ResponsesType>::type;
483 TrueMatType tmpData(std::move(data));
484 TrueResponsesType tmpResponses(std::move(responses));
487 dimensionSelector.Dimensions() = tmpData.n_rows;
490 arma::rowvec weights;
491 return Train<false>(tmpData, 0, tmpData.n_cols, tmpResponses,
492 weights, minimumLeafSize, minimumGainSplit, maximumDepth,
497 template<
typename FitnessFunction,
498 template<
typename>
class NumericSplitType,
499 template<
typename>
class CategoricalSplitType,
500 typename DimensionSelectionType,
502 template<
typename MatType,
typename ResponsesType,
typename WeightsType>
505 CategoricalSplitType,
506 DimensionSelectionType,
510 ResponsesType responses,
512 const size_t minimumLeafSize,
513 const double minimumGainSplit,
514 const size_t maximumDepth,
515 DimensionSelectionType dimensionSelector,
516 const std::enable_if_t<
518 typename std::remove_reference<
519 WeightsType>::type>::value>*)
522 util::CheckSameSizes(data, responses,
"DecisionTreeRegressor::Train()");
524 using TrueMatType =
typename std::decay<MatType>::type;
525 using TrueResponsesType =
typename std::decay<ResponsesType>::type;
526 using TrueWeightsType =
typename std::decay<WeightsType>::type;
529 TrueMatType tmpData(std::move(data));
530 TrueResponsesType tmpResponses(std::move(responses));
531 TrueWeightsType tmpWeights(std::move(weights));
534 dimensionSelector.Dimensions() = tmpData.n_rows;
537 return Train<true>(tmpData, 0, tmpData.n_cols, datasetInfo, tmpResponses,
538 tmpWeights, minimumLeafSize, minimumGainSplit, maximumDepth,
543 template<
typename FitnessFunction,
544 template<
typename>
class NumericSplitType,
545 template<
typename>
class CategoricalSplitType,
546 typename DimensionSelectionType,
548 template<
typename MatType,
typename ResponsesType,
typename WeightsType>
551 CategoricalSplitType,
552 DimensionSelectionType,
555 ResponsesType responses,
557 const size_t minimumLeafSize,
558 const double minimumGainSplit,
559 const size_t maximumDepth,
560 DimensionSelectionType dimensionSelector,
561 const std::enable_if_t<
563 typename std::remove_reference<
564 WeightsType>::type>::value>*)
567 util::CheckSameSizes(data, responses,
"DecisionTreeRegressor::Train()");
569 using TrueMatType =
typename std::decay<MatType>::type;
570 using TrueResponsesType =
typename std::decay<ResponsesType>::type;
571 using TrueWeightsType =
typename std::decay<WeightsType>::type;
574 TrueMatType tmpData(std::move(data));
575 TrueResponsesType tmpResponses(std::move(responses));
576 TrueWeightsType tmpWeights(std::move(weights));
579 dimensionSelector.Dimensions() = tmpData.n_rows;
582 return Train<true>(tmpData, 0, tmpData.n_cols, tmpResponses,
583 tmpWeights, minimumLeafSize, minimumGainSplit, maximumDepth,
588 template<
typename FitnessFunction,
589 template<
typename>
class NumericSplitType,
590 template<
typename>
class CategoricalSplitType,
591 typename DimensionSelectionType,
593 template<
bool UseWeights,
typename MatType,
typename ResponsesType>
596 CategoricalSplitType,
597 DimensionSelectionType,
603 ResponsesType& responses,
604 arma::rowvec& weights,
605 const size_t minimumLeafSize,
606 const double minimumGainSplit,
607 const size_t maximumDepth,
608 DimensionSelectionType& dimensionSelector)
611 for (
size_t i = 0; i < children.size(); ++i)
620 double bestGain = FitnessFunction::template Evaluate<UseWeights>(
621 responses.subvec(begin, begin + count - 1),
622 UseWeights ? weights.subvec(begin, begin + count - 1) : weights);
624 const size_t end = dimensionSelector.End();
626 if (maximumDepth != 1)
628 for (
size_t i = dimensionSelector.Begin(); i != end;
629 i = dimensionSelector.Next())
631 double dimGain = -DBL_MAX;
632 if (datasetInfo.
Type(i) == data::Datatype::categorical)
634 dimGain = CategoricalSplit::template SplitIfBetter<UseWeights>(bestGain,
635 data.cols(begin, begin + count - 1).row(i),
637 responses.subvec(begin, begin + count - 1),
638 UseWeights ? weights.subvec(begin, begin + count - 1) : weights,
641 splitPointOrPrediction,
644 else if (datasetInfo.
Type(i) == data::Datatype::numeric)
646 dimGain = NumericSplit::template SplitIfBetter<UseWeights>(bestGain,
647 data.cols(begin, begin + count - 1).row(i),
648 responses.subvec(begin, begin + count - 1),
649 UseWeights ? weights.subvec(begin, begin + count - 1) : weights,
652 splitPointOrPrediction,
658 if (dimGain == DBL_MAX)
674 dimensionType = (size_t) datasetInfo.
Type(bestDim);
675 splitDimension = bestDim;
678 size_t numChildren = 0;
679 if (datasetInfo.
Type(bestDim) == data::Datatype::categorical)
680 numChildren = CategoricalSplit::NumChildren(splitPointOrPrediction,
683 numChildren = NumericSplit::NumChildren(splitPointOrPrediction, *
this);
686 arma::Row<size_t> childAssignments(count);
687 if (datasetInfo.
Type(bestDim) == data::Datatype::categorical)
689 for (
size_t j = begin; j < begin + count; ++j)
690 childAssignments[j - begin] = CategoricalSplit::CalculateDirection(
691 data(bestDim, j), splitPointOrPrediction, *
this);
695 for (
size_t j = begin; j < begin + count; ++j)
697 childAssignments[j - begin] = NumericSplit::CalculateDirection(
698 data(bestDim, j), splitPointOrPrediction, *
this);
703 arma::Row<size_t> childCounts(numChildren, arma::fill::zeros);
704 for (
size_t i = begin; i < begin + count; ++i)
705 childCounts[childAssignments[i - begin]]++;
714 size_t currentCol = begin;
715 for (
size_t i = 0; i < numChildren; ++i)
717 size_t currentChildBegin = currentCol;
718 for (
size_t j = currentChildBegin; j < begin + count; ++j)
720 if (childAssignments[j - begin] == i)
722 childAssignments.swap_cols(currentCol - begin, j - begin);
723 data.swap_cols(currentCol, j);
724 responses.swap_cols(currentCol, j);
726 weights.swap_cols(currentCol, j);
735 child->
Train<UseWeights>(data, currentChildBegin,
736 currentCol - currentChildBegin, datasetInfo, responses,
737 weights, currentCol - currentChildBegin, minimumGainSplit,
738 maximumDepth - 1, dimensionSelector);
743 double childGain = child->
Train<UseWeights>(data, currentChildBegin,
744 currentCol - currentChildBegin, datasetInfo, responses,
745 weights, minimumLeafSize, minimumGainSplit, maximumDepth - 1,
747 bestGain += double(childCounts[i]) / double(count) * (-childGain);
749 children.push_back(child);
755 NumericAuxiliarySplitInfo::operator=(NumericAuxiliarySplitInfo());
756 CategoricalAuxiliarySplitInfo::operator=(CategoricalAuxiliarySplitInfo());
759 CalculatePrediction<UseWeights>(
760 responses.subvec(begin, begin + count - 1),
761 UseWeights ? weights.subvec(begin, begin + count - 1) : weights);
768 template<
typename FitnessFunction,
769 template<
typename>
class NumericSplitType,
770 template<
typename>
class CategoricalSplitType,
771 typename DimensionSelectionType,
773 template<
bool UseWeights,
typename MatType,
typename ResponsesType>
776 CategoricalSplitType,
777 DimensionSelectionType,
782 ResponsesType& responses,
783 arma::rowvec& weights,
784 const size_t minimumLeafSize,
785 const double minimumGainSplit,
786 const size_t maximumDepth,
787 DimensionSelectionType& dimensionSelector)
790 for (
size_t i = 0; i < children.size(); ++i)
795 CategoricalAuxiliarySplitInfo::operator=(CategoricalAuxiliarySplitInfo());
801 double bestGain = FitnessFunction::template Evaluate<UseWeights>(
802 responses.subvec(begin, begin + count - 1),
803 UseWeights ? weights.subvec(begin, begin + count - 1) : weights);
804 size_t bestDim = data.n_rows;
806 if (maximumDepth != 1)
808 for (
size_t i = dimensionSelector.Begin(); i != dimensionSelector.End();
809 i = dimensionSelector.Next())
811 const double dimGain = NumericSplitType<FitnessFunction>::template
812 SplitIfBetter<UseWeights>(bestGain,
813 data.cols(begin, begin + count - 1).row(i),
814 responses.cols(begin, begin + count - 1),
816 weights.cols(begin, begin + count - 1) :
820 splitPointOrPrediction,
825 if (dimGain == DBL_MAX)
838 if (bestDim != data.n_rows)
841 size_t numChildren = NumericSplit::NumChildren(splitPointOrPrediction,
843 splitDimension = bestDim;
844 dimensionType = (size_t) data::Datatype::numeric;
847 arma::Row<size_t> childAssignments(count);
849 for (
size_t j = begin; j < begin + count; ++j)
851 childAssignments[j - begin] = NumericSplit::CalculateDirection(
852 data(bestDim, j), splitPointOrPrediction, *
this);
856 arma::Row<size_t> childCounts(numChildren);
858 for (
size_t j = begin; j < begin + count; ++j)
859 childCounts[childAssignments[j - begin]]++;
867 size_t currentCol = begin;
868 for (
size_t i = 0; i < numChildren; ++i)
870 size_t currentChildBegin = currentCol;
871 for (
size_t j = currentChildBegin; j < begin + count; ++j)
873 if (childAssignments[j - begin] == i)
875 childAssignments.swap_cols(currentCol - begin, j - begin);
876 data.swap_cols(currentCol, j);
877 responses.swap_cols(currentCol, j);
879 weights.swap_cols(currentCol, j);
888 child->
Train<UseWeights>(data, currentChildBegin,
889 currentCol - currentChildBegin, responses, weights,
890 currentCol - currentChildBegin, minimumGainSplit, maximumDepth - 1,
896 double childGain = child->
Train<UseWeights>(data, currentChildBegin,
897 currentCol - currentChildBegin, responses, weights,
898 minimumLeafSize, minimumGainSplit, maximumDepth - 1,
900 bestGain += double(childCounts[i]) / double(count) * (-childGain);
902 children.push_back(child);
908 NumericAuxiliarySplitInfo::operator=(NumericAuxiliarySplitInfo());
911 CalculatePrediction<UseWeights>(
912 responses.subvec(begin, begin + count - 1),
913 UseWeights ? weights.subvec(begin, begin + count - 1) : weights);
920 template<
typename FitnessFunction,
921 template<
typename>
class NumericSplitType,
922 template<
typename>
class CategoricalSplitType,
923 typename DimensionSelectionType,
925 template<
typename VecType>
928 CategoricalSplitType,
929 DimensionSelectionType,
932 if (children.size() == 0)
935 return splitPointOrPrediction;
942 template<
typename FitnessFunction,
943 template<
typename>
class NumericSplitType,
944 template<
typename>
class CategoricalSplitType,
945 typename DimensionSelectionType,
947 template<
typename MatType>
950 CategoricalSplitType,
951 DimensionSelectionType,
953 >
::Predict(
const MatType& data, arma::Row<double>& predictions)
const 955 predictions.set_size(data.n_cols);
957 if (children.size() == 0)
959 predictions.fill(splitPointOrPrediction);
964 for (
size_t i = 0; i < data.n_cols; ++i)
965 predictions[i] =
Predict(data.col(i));
968 template<
typename FitnessFunction,
969 template<
typename>
class NumericSplitType,
970 template<
typename>
class CategoricalSplitType,
971 typename DimensionSelectionType,
973 template<
bool UseWeights,
typename ResponsesType,
typename WeightsType>
976 CategoricalSplitType,
977 DimensionSelectionType,
979 >::CalculatePrediction(
const ResponsesType& responses,
980 const WeightsType& weights)
984 double accWeights, weightedSum;
985 WeightedSum(responses, weights, 0, responses.n_elem, accWeights,
987 splitPointOrPrediction = weightedSum / accWeights;
992 Sum(responses, 0, responses.n_elem, sum);
993 splitPointOrPrediction = sum / responses.n_elem;
997 template<
typename FitnessFunction,
998 template<
typename>
class NumericSplitType,
999 template<
typename>
class CategoricalSplitType,
1000 typename DimensionSelectionType,
1002 template<
typename VecType>
1005 CategoricalSplitType,
1006 DimensionSelectionType,
1010 if ((
data::Datatype) dimensionType == data::Datatype::categorical)
1011 return CategoricalSplit::CalculateDirection(point[splitDimension],
1012 splitPointOrPrediction, *
this);
1014 return NumericSplit::CalculateDirection(point[splitDimension],
1015 splitPointOrPrediction, *
this);
1019 template<
typename FitnessFunction,
1020 template<
typename>
class NumericSplitType,
1021 template<
typename>
class CategoricalSplitType,
1022 typename DimensionSelectionType,
1024 template<
typename Archive>
1027 CategoricalSplitType,
1028 DimensionSelectionType,
1033 if (cereal::is_loading<Archive>())
1035 for (
size_t i = 0; i < children.size(); ++i)
1043 ar(CEREAL_NVP(splitDimension));
1044 ar(CEREAL_NVP(dimensionType));
1045 ar(CEREAL_NVP(splitPointOrPrediction));
1049 template<
typename FitnessFunction,
1050 template<
typename>
class NumericSplitType,
1051 template<
typename>
class CategoricalSplitType,
1052 typename DimensionSelectionType,
1056 CategoricalSplitType,
1057 DimensionSelectionType,
1063 size_t numLeaves = 0;
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
Definition: dataset_mapper.hpp:41
size_t NumChildren() const
Get the number of children.
Definition: decision_tree_regressor.hpp:406
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Datatype
The Datatype enum specifies the types of data mlpack algorithms can use.
Definition: datatype.hpp:24
size_t NumLeaves() const
Get the number of leaves in the tree.
Definition: decision_tree_regressor_impl.hpp:1058
double Train(MatType data, const data::DatasetInfo &datasetInfo, ResponsesType responses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data.
Definition: decision_tree_regressor_impl.hpp:428
DecisionTreeRegressor()
Construct a decision tree without training it.
Definition: decision_tree_regressor_impl.hpp:31
Definition: pointer_wrapper.hpp:23
size_t Dimensionality() const
Get the dimensionality of the DatasetMapper object (that is, how many dimensions it has information for).
Definition: dataset_mapper_impl.hpp:228
void WeightedSum(const VecType &values, const WeightVecType &weights, const size_t begin, const size_t end, double &accWeights, double &weightedMean)
Calculates the weighted sum and total weight of labels.
Definition: utils.hpp:19
Datatype Type(const size_t dimension) const
Return the type of a given dimension (numeric or categorical).
Definition: dataset_mapper_impl.hpp:196
Definition: hmm_train_main.cpp:300
This class implements a generic decision tree learner.
Definition: decision_tree_regressor.hpp:41
#define CEREAL_VECTOR_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_vector_wrapper.hpp:93
double Predict(const VecType &point) const
Make prediction for the given point, using the entire tree.
Definition: decision_tree_regressor_impl.hpp:930
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point would go towards.
Definition: decision_tree_regressor_impl.hpp:1008
size_t NumMappings(const size_t dimension) const
Get the number of mappings for a particular dimension.
Definition: dataset_mapper_impl.hpp:222
void serialize(Archive &ar, const uint32_t)
Serialize the tree.
Definition: decision_tree_regressor_impl.hpp:1030
void Sum(const VecType &values, const size_t begin, const size_t end, double &mean)
Sums up the labels vector.
Definition: utils.hpp:96
~DecisionTreeRegressor()
Clean up memory.
Definition: decision_tree_regressor_impl.hpp:411