// Include guard for this implementation header; the first constructor
// definition starts at the end of the same line.
//
// Constructor (unweighted, with dataset info): builds owning local copies of
// the data and labels, tells the dimension selector how many dimensions the
// data has, and hands off to the unweighted Train<false>() overload.
// NOTE(review): several original lines (function name, leading parameters,
// braces, the declaration of `weights`) are not visible in this listing.
12 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_IMPL_HPP 13 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_IMPL_HPP 21 template<
typename FitnessFunction,
22 template<
typename>
class NumericSplitType,
23 template<
typename>
class CategoricalSplitType,
24 typename DimensionSelectionType,
26 template<
typename MatType,
typename LabelsType>
27 DecisionTree<FitnessFunction,
30 DimensionSelectionType,
35 const size_t numClasses,
36 const size_t minimumLeafSize,
37 const double minimumGainSplit,
38 const size_t maximumDepth,
39 DimensionSelectionType dimensionSelector)
// std::decay strips references and cv-qualifiers, so the locals below are
// plain owning objects regardless of how MatType/LabelsType were deduced.
41 using TrueMatType =
typename std::decay<MatType>::type;
42 using TrueLabelsType =
typename std::decay<LabelsType>::type;
// Local copies (or moved-in values); training reorders these in place.
45 TrueMatType tmpData(std::move(data));
46 TrueLabelsType tmpLabels(std::move(labels));
// One dimension per row of the data matrix.
49 dimensionSelector.Dimensions() = tmpData.n_rows;
// Train on all columns [0, n_cols) without using weights.
53 Train<false>(tmpData, 0, tmpData.n_cols, datasetInfo, tmpLabels, numClasses,
54 weights, minimumLeafSize, minimumGainSplit, maximumDepth,
// Constructor (unweighted, no dataset info): same flow as the datasetInfo
// variant, but calls the Train<false>() overload that takes no datasetInfo.
// NOTE(review): the function-name line, leading parameters and the
// declaration of `weights` are not visible in this listing.
59 template<
typename FitnessFunction,
60 template<
typename>
class NumericSplitType,
61 template<
typename>
class CategoricalSplitType,
62 typename DimensionSelectionType,
64 template<
typename MatType,
typename LabelsType>
68 DimensionSelectionType,
72 const size_t numClasses,
73 const size_t minimumLeafSize,
74 const double minimumGainSplit,
75 const size_t maximumDepth,
76 DimensionSelectionType dimensionSelector)
// Owning, non-reference versions of the deduced argument types.
78 using TrueMatType =
typename std::decay<MatType>::type;
79 using TrueLabelsType =
typename std::decay<LabelsType>::type;
// Local copies (or moved-in values) that training may reorder.
82 TrueMatType tmpData(std::move(data));
83 TrueLabelsType tmpLabels(std::move(labels));
// One dimension per row of the data matrix.
86 dimensionSelector.Dimensions() = tmpData.n_rows;
// Train on all columns [0, n_cols) without using weights.
90 Train<false>(tmpData, 0, tmpData.n_cols, tmpLabels, numClasses, weights,
91 minimumLeafSize, minimumGainSplit, maximumDepth, dimensionSelector);
// Constructor (weighted, with dataset info): copies/moves data, labels and
// weights into owning locals, sets the selector's dimension count, and calls
// the weighted Train<true>() overload.
// The trailing unnamed enable_if_t pointer parameter restricts this overload
// to WeightsType arguments for which arma::is_arma_type is true.
95 template<
typename FitnessFunction,
96 template<
typename>
class NumericSplitType,
97 template<
typename>
class CategoricalSplitType,
98 typename DimensionSelectionType,
100 template<
typename MatType,
typename LabelsType,
typename WeightsType>
103 CategoricalSplitType,
104 DimensionSelectionType,
109 const size_t numClasses,
111 const size_t minimumLeafSize,
112 const double minimumGainSplit,
113 const size_t maximumDepth,
114 DimensionSelectionType dimensionSelector,
115 const std::enable_if_t<arma::is_arma_type<
116 typename std::remove_reference<WeightsType>::type>::value>*)
// Owning, non-reference versions of the deduced argument types.
118 using TrueMatType =
typename std::decay<MatType>::type;
119 using TrueLabelsType =
typename std::decay<LabelsType>::type;
120 using TrueWeightsType =
typename std::decay<WeightsType>::type;
// Local copies (or moved-in values) that training may reorder.
123 TrueMatType tmpData(std::move(data));
124 TrueLabelsType tmpLabels(std::move(labels));
125 TrueWeightsType tmpWeights(std::move(weights));
// One dimension per row of the data matrix.
128 dimensionSelector.Dimensions() = tmpData.n_rows;
// Train on all columns [0, n_cols) using instance weights.
131 Train<true>(tmpData, 0, tmpData.n_cols, datasetInfo, tmpLabels, numClasses,
132 tmpWeights, minimumLeafSize, minimumGainSplit, maximumDepth,
// Constructor (weighted, with dataset info, seeded from `other`): copies the
// numeric and categorical auxiliary split info from `other`, then trains on
// the given weighted data.
// NOTE(review): the leading parameters (including `other`) are not visible
// in this listing.
137 template<
typename FitnessFunction,
138 template<
typename>
class NumericSplitType,
139 template<
typename>
class CategoricalSplitType,
140 typename DimensionSelectionType,
142 template<
typename MatType,
typename LabelsType,
typename WeightsType>
145 CategoricalSplitType,
146 DimensionSelectionType,
152 const size_t numClasses,
154 const size_t minimumLeafSize,
155 const double minimumGainSplit,
156 const std::enable_if_t<arma::is_arma_type<
157 typename std::remove_reference<WeightsType>::type>::value>*):
// Base-class subobjects are copy-initialized from `other`.
158 NumericAuxiliarySplitInfo(other),
159 CategoricalAuxiliarySplitInfo(other)
// Owning, non-reference versions of the deduced argument types.
161 using TrueMatType =
typename std::decay<MatType>::type;
162 using TrueLabelsType =
typename std::decay<LabelsType>::type;
163 using TrueWeightsType =
typename std::decay<WeightsType>::type;
// Local copies (or moved-in values) that training may reorder.
166 TrueMatType tmpData(std::move(data));
167 TrueLabelsType tmpLabels(std::move(labels));
168 TrueWeightsType tmpWeights(std::move(weights));
// NOTE(review): unlike the sibling constructors, this call passes neither a
// maximumDepth nor a dimensionSelector, and no selector dimension count is
// set beforehand. The Train<UseWeights>() definitions visible in this file
// take both arguments -- confirm this call resolves when instantiated.
171 Train<true>(tmpData, 0, tmpData.n_cols, datasetInfo, tmpLabels, numClasses,
172 tmpWeights, minimumLeafSize, minimumGainSplit);
// Constructor (weighted, no dataset info): copies/moves data, labels and
// weights into owning locals, sets the selector's dimension count, and calls
// the weighted Train<true>() overload that takes no datasetInfo.
// NOTE(review): original line 196 (the trait inside enable_if_t) is not
// visible; the sibling overloads use arma::is_arma_type there.
176 template<
typename FitnessFunction,
177 template<
typename>
class NumericSplitType,
178 template<
typename>
class CategoricalSplitType,
179 typename DimensionSelectionType,
181 template<
typename MatType,
typename LabelsType,
typename WeightsType>
184 CategoricalSplitType,
185 DimensionSelectionType,
189 const size_t numClasses,
191 const size_t minimumLeafSize,
192 const double minimumGainSplit,
193 const size_t maximumDepth,
194 DimensionSelectionType dimensionSelector,
195 const std::enable_if_t<
197 typename std::remove_reference<
198 WeightsType>::type>::value>*)
// Owning, non-reference versions of the deduced argument types.
200 using TrueMatType =
typename std::decay<MatType>::type;
201 using TrueLabelsType =
typename std::decay<LabelsType>::type;
202 using TrueWeightsType =
typename std::decay<WeightsType>::type;
// Local copies (or moved-in values) that training may reorder.
205 TrueMatType tmpData(std::move(data));
206 TrueLabelsType tmpLabels(std::move(labels));
207 TrueWeightsType tmpWeights(std::move(weights));
// One dimension per row of the data matrix.
210 dimensionSelector.Dimensions() = tmpData.n_rows;
// Train on all columns [0, n_cols) using instance weights.
213 Train<true>(tmpData, 0, tmpData.n_cols, tmpLabels, numClasses, tmpWeights,
214 minimumLeafSize, minimumGainSplit, maximumDepth, dimensionSelector);
// Constructor (weighted, no dataset info, seeded from `other`): copies the
// numeric and categorical auxiliary split info from `other`, then trains on
// the given weighted data with the supplied depth limit and selector.
// NOTE(review): the leading parameters (including `other`) are not visible
// in this listing.
218 template<
typename FitnessFunction,
219 template<
typename>
class NumericSplitType,
220 template<
typename>
class CategoricalSplitType,
221 typename DimensionSelectionType,
223 template<
typename MatType,
typename LabelsType,
typename WeightsType>
226 CategoricalSplitType,
227 DimensionSelectionType,
232 const size_t numClasses,
234 const size_t minimumLeafSize,
235 const double minimumGainSplit,
236 const size_t maximumDepth,
237 DimensionSelectionType dimensionSelector,
238 const std::enable_if_t<arma::is_arma_type<
239 typename std::remove_reference<
240 WeightsType>::type>::value>*):
// Base-class subobjects are copy-initialized from `other`.
241 NumericAuxiliarySplitInfo(other),
242 CategoricalAuxiliarySplitInfo(other)
// Owning, non-reference versions of the deduced argument types.
244 using TrueMatType =
typename std::decay<MatType>::type;
245 using TrueLabelsType =
typename std::decay<LabelsType>::type;
246 using TrueWeightsType =
typename std::decay<WeightsType>::type;
// Local copies (or moved-in values) that training may reorder.
249 TrueMatType tmpData(std::move(data));
250 TrueLabelsType tmpLabels(std::move(labels));
251 TrueWeightsType tmpWeights(std::move(weights));
// One dimension per row of the data matrix.
254 dimensionSelector.Dimensions() = tmpData.n_rows;
// Train on all columns [0, n_cols) using instance weights.
257 Train<true>(tmpData, 0, tmpData.n_cols, tmpLabels, numClasses, tmpWeights,
258 minimumLeafSize, minimumGainSplit, maximumDepth, dimensionSelector);
// Constructor for an untrained tree: majority class 0 and a uniform
// probability of 1 / numClasses for every class.
// NOTE(review): the signature line is not visible; `numClasses` is
// presumably its parameter.
262 template<
typename FitnessFunction,
263 template<
typename>
class NumericSplitType,
264 template<
typename>
class CategoricalSplitType,
265 typename DimensionSelectionType,
269 CategoricalSplitType,
270 DimensionSelectionType,
273 dimensionTypeOrMajorityClass(0),
274 classProbabilities(numClasses)
// Uniform distribution over the classes.
277 classProbabilities.fill(1.0 / (
double) numClasses);
// Copy constructor: copies the auxiliary split info, split dimension,
// dimension-type/majority-class value and class probabilities from `other`,
// then deep-copies every child so the two trees share no nodes.
281 template<
typename FitnessFunction,
282 template<
typename>
class NumericSplitType,
283 template<
typename>
class CategoricalSplitType,
284 typename DimensionSelectionType,
288 CategoricalSplitType,
289 DimensionSelectionType,
291 NumericAuxiliarySplitInfo(other),
292 CategoricalAuxiliarySplitInfo(other),
293 splitDimension(other.splitDimension),
294 dimensionTypeOrMajorityClass(other.dimensionTypeOrMajorityClass),
295 classProbabilities(other.classProbabilities)
// Each child is heap-allocated as a recursive copy of other's child.
298 for (
size_t i = 0; i < other.children.size(); ++i)
299 children.push_back(
new DecisionTree(*other.children[i]));
// Move constructor: move-constructs the auxiliary split info bases, takes
// over other's child pointer vector and class probabilities, and copies the
// scalar fields.
// NOTE(review): original lines 319-320 are not visible, so any further
// resetting of `other` cannot be confirmed from this listing.
303 template<
typename FitnessFunction,
304 template<
typename>
class NumericSplitType,
305 template<
typename>
class CategoricalSplitType,
306 typename DimensionSelectionType,
310 CategoricalSplitType,
311 DimensionSelectionType,
313 NumericAuxiliarySplitInfo(
std::move(other)),
314 CategoricalAuxiliarySplitInfo(
std::move(other)),
315 children(
std::move(other.children)),
316 splitDimension(other.splitDimension),
317 dimensionTypeOrMajorityClass(other.dimensionTypeOrMajorityClass),
318 classProbabilities(
std::move(other.classProbabilities))
// Leave the moved-from tree with a one-element probability vector of 1.
321 other.classProbabilities.ones(1);
// Copy assignment: replaces this tree's contents with a deep copy of
// `other`.
// NOTE(review): the signature, any self-assignment guard, the body of the
// first loop and the return statement are not visible in this listing.
325 template<
typename FitnessFunction,
326 template<
typename>
class NumericSplitType,
327 template<
typename>
class CategoricalSplitType,
328 typename DimensionSelectionType,
332 CategoricalSplitType,
333 DimensionSelectionType,
337 CategoricalSplitType,
338 DimensionSelectionType,
// Walk the existing children; presumably each is deleted here (loop body
// not visible) -- TODO confirm.
345 for (
size_t i = 0; i < children.size(); ++i)
// Copy the scalar fields and the class probabilities.
350 splitDimension = other.splitDimension;
351 dimensionTypeOrMajorityClass = other.dimensionTypeOrMajorityClass;
352 classProbabilities = other.classProbabilities;
// Deep-copy each of other's children.
355 for (
size_t i = 0; i < other.children.size(); ++i)
356 children.push_back(
new DecisionTree(*other.children[i]));
// Copy the auxiliary split info held in the base classes.
359 NumericAuxiliarySplitInfo::operator=(other);
360 CategoricalAuxiliarySplitInfo::operator=(other);
// Move assignment: replaces this tree's contents by taking over other's
// children and class probabilities.
// NOTE(review): the signature, any self-assignment guard, the body of the
// first loop and the return statement are not visible in this listing.
366 template<
typename FitnessFunction,
367 template<
typename>
class NumericSplitType,
368 template<
typename>
class CategoricalSplitType,
369 typename DimensionSelectionType,
373 CategoricalSplitType,
374 DimensionSelectionType,
378 CategoricalSplitType,
379 DimensionSelectionType,
// Walk the existing children; presumably each is deleted here (loop body
// not visible) -- TODO confirm.
386 for (
size_t i = 0; i < children.size(); ++i)
// Take over other's child pointers and probabilities; copy the scalars.
391 children = std::move(other.children);
392 splitDimension = other.splitDimension;
393 dimensionTypeOrMajorityClass = other.dimensionTypeOrMajorityClass;
394 classProbabilities = std::move(other.classProbabilities);
// Leave the moved-from tree with a one-element probability vector of 1.
397 other.classProbabilities.ones(1);
// Move the auxiliary split info held in the base classes.
400 NumericAuxiliarySplitInfo::operator=(std::move(other));
401 CategoricalAuxiliarySplitInfo::operator=(std::move(other));
// Destructor: iterates over all children.
// NOTE(review): the loop body is not visible in this listing; presumably it
// deletes each heap-allocated child -- TODO confirm.
407 template<
typename FitnessFunction,
408 template<
typename>
class NumericSplitType,
409 template<
typename>
class CategoricalSplitType,
410 typename DimensionSelectionType,
414 CategoricalSplitType,
415 DimensionSelectionType,
418 for (
size_t i = 0; i < children.size(); ++i)
// Public Train() (unweighted, with dataset info): checks that data and
// labels have matching sizes, makes owning local copies, sets the selector's
// dimension count, and returns the result of the internal Train<false>()
// over all columns.
423 template<
typename FitnessFunction,
424 template<
typename>
class NumericSplitType,
425 template<
typename>
class CategoricalSplitType,
426 typename DimensionSelectionType,
428 template<
typename MatType,
typename LabelsType>
431 CategoricalSplitType,
432 DimensionSelectionType,
437 const size_t numClasses,
438 const size_t minimumLeafSize,
439 const double minimumGainSplit,
440 const size_t maximumDepth,
441 DimensionSelectionType dimensionSelector)
// Validate the inputs before any copying is done.
444 util::CheckSameSizes(data, labels,
"DecisionTree::Train()");
// Owning, non-reference versions of the deduced argument types.
446 using TrueMatType =
typename std::decay<MatType>::type;
447 using TrueLabelsType =
typename std::decay<LabelsType>::type;
// Local copies (or moved-in values) that training may reorder.
450 TrueMatType tmpData(std::move(data));
451 TrueLabelsType tmpLabels(std::move(labels));
// One dimension per row of the data matrix.
454 dimensionSelector.Dimensions() = tmpData.n_rows;
// Empty placeholder: the internal Train() takes a weights argument even
// when UseWeights is false.
457 arma::rowvec weights;
458 return Train<false>(tmpData, 0, tmpData.n_cols, datasetInfo, tmpLabels,
459 numClasses, weights, minimumLeafSize, minimumGainSplit, maximumDepth,
// Public Train() (unweighted, no dataset info): checks that data and labels
// have matching sizes, makes owning local copies, sets the selector's
// dimension count, and returns the result of the internal Train<false>()
// overload that takes no datasetInfo.
464 template<
typename FitnessFunction,
465 template<
typename>
class NumericSplitType,
466 template<
typename>
class CategoricalSplitType,
467 typename DimensionSelectionType,
469 template<
typename MatType,
typename LabelsType>
472 CategoricalSplitType,
473 DimensionSelectionType,
477 const size_t numClasses,
478 const size_t minimumLeafSize,
479 const double minimumGainSplit,
480 const size_t maximumDepth,
481 DimensionSelectionType dimensionSelector)
// Validate the inputs before any copying is done.
484 util::CheckSameSizes(data, labels,
"DecisionTree::Train()");
// Owning, non-reference versions of the deduced argument types.
486 using TrueMatType =
typename std::decay<MatType>::type;
487 using TrueLabelsType =
typename std::decay<LabelsType>::type;
// Local copies (or moved-in values) that training may reorder.
490 TrueMatType tmpData(std::move(data));
491 TrueLabelsType tmpLabels(std::move(labels));
// One dimension per row of the data matrix.
494 dimensionSelector.Dimensions() = tmpData.n_rows;
// Empty placeholder: the internal Train() takes a weights argument even
// when UseWeights is false.
497 arma::rowvec weights;
498 return Train<false>(tmpData, 0, tmpData.n_cols, tmpLabels, numClasses,
499 weights, minimumLeafSize, minimumGainSplit, maximumDepth,
// Public Train() (weighted, with dataset info): checks that data and labels
// have matching sizes, makes owning local copies of data, labels and
// weights, sets the selector's dimension count, and returns the result of
// the internal Train<true>().
// NOTE(review): only data and labels are size-checked in the visible code;
// the length of `weights` is not validated here.
504 template<
typename FitnessFunction,
505 template<
typename>
class NumericSplitType,
506 template<
typename>
class CategoricalSplitType,
507 typename DimensionSelectionType,
509 template<
typename MatType,
typename LabelsType,
typename WeightsType>
512 CategoricalSplitType,
513 DimensionSelectionType,
518 const size_t numClasses,
520 const size_t minimumLeafSize,
521 const double minimumGainSplit,
522 const size_t maximumDepth,
523 DimensionSelectionType dimensionSelector,
524 const std::enable_if_t<
526 typename std::remove_reference<
527 WeightsType>::type>::value>*)
// Validate the inputs before any copying is done.
530 util::CheckSameSizes(data, labels,
"DecisionTree::Train()");
// Owning, non-reference versions of the deduced argument types.
532 using TrueMatType =
typename std::decay<MatType>::type;
533 using TrueLabelsType =
typename std::decay<LabelsType>::type;
534 using TrueWeightsType =
typename std::decay<WeightsType>::type;
// Local copies (or moved-in values) that training may reorder.
537 TrueMatType tmpData(std::move(data));
538 TrueLabelsType tmpLabels(std::move(labels));
539 TrueWeightsType tmpWeights(std::move(weights));
// One dimension per row of the data matrix.
542 dimensionSelector.Dimensions() = tmpData.n_rows;
// Train on all columns [0, n_cols) using instance weights.
545 return Train<true>(tmpData, 0, tmpData.n_cols, datasetInfo, tmpLabels,
546 numClasses, tmpWeights, minimumLeafSize, minimumGainSplit, maximumDepth,
// Public Train() (weighted, no dataset info): checks that data and labels
// have matching sizes, makes owning local copies of data, labels and
// weights, sets the selector's dimension count, and returns the result of
// the internal Train<true>() overload that takes no datasetInfo.
// NOTE(review): only data and labels are size-checked in the visible code;
// the length of `weights` is not validated here.
551 template<
typename FitnessFunction,
552 template<
typename>
class NumericSplitType,
553 template<
typename>
class CategoricalSplitType,
554 typename DimensionSelectionType,
556 template<
typename MatType,
typename LabelsType,
typename WeightsType>
559 CategoricalSplitType,
560 DimensionSelectionType,
564 const size_t numClasses,
566 const size_t minimumLeafSize,
567 const double minimumGainSplit,
568 const size_t maximumDepth,
569 DimensionSelectionType dimensionSelector,
570 const std::enable_if_t<
572 typename std::remove_reference<
573 WeightsType>::type>::value>*)
// Validate the inputs before any copying is done.
576 util::CheckSameSizes(data, labels,
"DecisionTree::Train()");
// Owning, non-reference versions of the deduced argument types.
578 using TrueMatType =
typename std::decay<MatType>::type;
579 using TrueLabelsType =
typename std::decay<LabelsType>::type;
580 using TrueWeightsType =
typename std::decay<WeightsType>::type;
// Local copies (or moved-in values) that training may reorder.
583 TrueMatType tmpData(std::move(data));
584 TrueLabelsType tmpLabels(std::move(labels));
585 TrueWeightsType tmpWeights(std::move(weights));
// One dimension per row of the data matrix.
588 dimensionSelector.Dimensions() = tmpData.n_rows;
// Train on all columns [0, n_cols) using instance weights.
591 return Train<true>(tmpData, 0, tmpData.n_cols, tmpLabels, numClasses,
592 tmpWeights, minimumLeafSize, minimumGainSplit, maximumDepth,
// Internal recursive trainer for data that may mix categorical and numeric
// dimensions. Works on the `count` columns of `data` starting at `begin`,
// and reorders data, labels and weights in place so that each child's
// points are contiguous.
// NOTE(review): the leading parameters (data, begin, count, datasetInfo),
// braces, several branch conditions and the return statement are not
// visible in this listing; comments below cover only the visible statements.
597 template<
typename FitnessFunction,
598 template<
typename>
class NumericSplitType,
599 template<
typename>
class CategoricalSplitType,
600 typename DimensionSelectionType,
602 template<
bool UseWeights,
typename MatType>
605 CategoricalSplitType,
606 DimensionSelectionType,
612 arma::Row<size_t>& labels,
613 const size_t numClasses,
614 arma::rowvec& weights,
615 const size_t minimumLeafSize,
616 const double minimumGainSplit,
617 const size_t maximumDepth,
618 DimensionSelectionType& dimensionSelector)
// Walk any existing children (loop body not visible; presumably clears a
// previously trained tree -- TODO confirm).
621 for (
size_t i = 0; i < children.size(); ++i)
// Fitness of this node's labels before any split; used as the baseline
// that candidate splits must beat.
630 double bestGain = FitnessFunction::template Evaluate<UseWeights>(
631 labels.subvec(begin, begin + count - 1),
633 UseWeights ? weights.subvec(begin, begin + count - 1) : weights);
635 const size_t end = dimensionSelector.End();
// A maximumDepth of exactly 1 skips the split search entirely.
637 if (maximumDepth != 1)
// Try each dimension offered by the selector.
639 for (
size_t i = dimensionSelector.Begin(); i != end;
640 i = dimensionSelector.Next())
642 double dimGain = -DBL_MAX;
// Dispatch to the categorical or numeric splitter based on the
// dimension's declared type.
643 if (datasetInfo.
Type(i) == data::Datatype::categorical)
645 dimGain = CategoricalSplit::template SplitIfBetter<UseWeights>(bestGain,
646 data.cols(begin, begin + count - 1).row(i),
648 labels.subvec(begin, begin + count - 1),
650 UseWeights ? weights.subvec(begin, begin + count - 1) : weights,
656 else if (datasetInfo.
Type(i) == data::Datatype::numeric)
658 dimGain = NumericSplit::template SplitIfBetter<UseWeights>(bestGain,
659 data.cols(begin, begin + count - 1).row(i),
660 labels.subvec(begin, begin + count - 1),
662 UseWeights ? weights.subvec(begin, begin + count - 1) : weights,
// DBL_MAX is compared as a sentinel from SplitIfBetter(); the handling is
// not visible here (presumably "no improvement") -- TODO confirm.
671 if (dimGain == DBL_MAX)
// Record the chosen split: for an internal node this field holds the
// dimension's datatype rather than a majority class.
687 dimensionTypeOrMajorityClass = (size_t) datasetInfo.
Type(bestDim);
688 splitDimension = bestDim;
// Ask the matching splitter how many children this split produces.
691 size_t numChildren = 0;
692 if (datasetInfo.
Type(bestDim) == data::Datatype::categorical)
693 numChildren = CategoricalSplit::NumChildren(classProbabilities[0], *
this);
695 numChildren = NumericSplit::NumChildren(classProbabilities[0], *
this);
// Compute, for every point in this node, the child it belongs to.
698 arma::Row<size_t> childAssignments(count);
699 if (datasetInfo.
Type(bestDim) == data::Datatype::categorical)
701 for (
size_t j = begin; j < begin + count; ++j)
702 childAssignments[j - begin] = CategoricalSplit::CalculateDirection(
703 data(bestDim, j), classProbabilities[0], *
this);
707 for (
size_t j = begin; j < begin + count; ++j)
709 childAssignments[j - begin] = NumericSplit::CalculateDirection(
710 data(bestDim, j), classProbabilities[0], *
this);
// Number of points assigned to each child.
715 arma::Row<size_t> childCounts(numChildren, arma::fill::zeros);
716 for (
size_t i = begin; i < begin + count; ++i)
717 childCounts[childAssignments[i - begin]]++;
// Partition in place: for each child i, swap its points forward so they
// occupy the contiguous range [currentChildBegin, currentCol).
726 size_t currentCol = begin;
727 for (
size_t i = 0; i < numChildren; ++i)
729 size_t currentChildBegin = currentCol;
730 for (
size_t j = currentChildBegin; j < begin + count; ++j)
732 if (childAssignments[j - begin] == i)
// Keep assignments, data, labels and weights in step with each other.
734 childAssignments.swap_cols(currentCol - begin, j - begin);
735 data.swap_cols(currentCol, j);
736 labels.swap_cols(currentCol, j);
738 weights.swap_cols(currentCol, j);
// This call passes the child's own point count as minimumLeafSize (the
// guarding condition is not visible in this listing).
747 child->
Train<UseWeights>(data, currentChildBegin,
748 currentCol - currentChildBegin, datasetInfo, labels, numClasses,
749 weights, currentCol - currentChildBegin, minimumGainSplit,
750 maximumDepth - 1, dimensionSelector);
// Otherwise recurse with the caller's limits and one less level of depth.
755 double childGain = child->
Train<UseWeights>(data, currentChildBegin,
756 currentCol - currentChildBegin, datasetInfo, labels, numClasses,
757 weights, minimumLeafSize, minimumGainSplit, maximumDepth - 1,
// Accumulate the negated child result, weighted by the child's share of
// this node's points.
759 bestGain += double(childCounts[i]) / double(count) * (-childGain);
761 children.push_back(child);
// Reset auxiliary split info to default-constructed state, then store
// class probabilities for this node's label range (presumably the leaf
// path -- the enclosing branch is not visible).
767 NumericAuxiliarySplitInfo::operator=(NumericAuxiliarySplitInfo());
768 CategoricalAuxiliarySplitInfo::operator=(CategoricalAuxiliarySplitInfo());
771 CalculateClassProbabilities<UseWeights>(
772 labels.subvec(begin, begin + count - 1),
774 UseWeights ? weights.subvec(begin, begin + count - 1) : weights);
// Internal recursive trainer for all-numeric data (no datasetInfo). Works on
// the `count` columns of `data` starting at `begin`, and reorders data,
// labels and weights in place so that each child's points are contiguous.
// NOTE(review): the leading parameters (data, begin, count), braces, several
// branch conditions and the return statement are not visible in this
// listing; comments below cover only the visible statements.
781 template<
typename FitnessFunction,
782 template<
typename>
class NumericSplitType,
783 template<
typename>
class CategoricalSplitType,
784 typename DimensionSelectionType,
786 template<
bool UseWeights,
typename MatType>
789 CategoricalSplitType,
790 DimensionSelectionType,
795 arma::Row<size_t>& labels,
796 const size_t numClasses,
797 arma::rowvec& weights,
798 const size_t minimumLeafSize,
799 const double minimumGainSplit,
800 const size_t maximumDepth,
801 DimensionSelectionType& dimensionSelector)
// Walk any existing children (loop body not visible; presumably clears a
// previously trained tree -- TODO confirm).
804 for (
size_t i = 0; i < children.size(); ++i)
// Categorical split info is unused for numeric-only data; reset it.
809 CategoricalAuxiliarySplitInfo::operator=(CategoricalAuxiliarySplitInfo());
// Fitness of this node's labels before any split.
816 double bestGain = FitnessFunction::template Evaluate<UseWeights>(
817 labels.subvec(begin, begin + count - 1),
819 UseWeights ? weights.subvec(begin, begin + count - 1) : weights);
// data.n_rows is used as the "no split chosen" sentinel for bestDim.
820 size_t bestDim = data.n_rows;
// A maximumDepth of exactly 1 skips the split search entirely.
822 if (maximumDepth != 1)
824 for (
size_t i = dimensionSelector.Begin(); i != dimensionSelector.End();
825 i = dimensionSelector.Next())
827 const double dimGain = NumericSplitType<FitnessFunction>::template
828 SplitIfBetter<UseWeights>(bestGain,
829 data.cols(begin, begin + count - 1).row(i),
830 labels.cols(begin, begin + count - 1),
833 weights.cols(begin, begin + count - 1) :
// DBL_MAX is compared as a sentinel from SplitIfBetter(); the handling is
// not visible here (presumably "no improvement") -- TODO confirm.
842 if (dimGain == DBL_MAX)
// A split was found only if bestDim moved off its sentinel value.
855 if (bestDim != data.n_rows)
859 NumericSplit::NumChildren(classProbabilities[0], *
this);
860 splitDimension = bestDim;
861 dimensionTypeOrMajorityClass = (size_t) data::Datatype::numeric;
// Compute, for every point in this node, the child it belongs to.
864 arma::Row<size_t> childAssignments(count);
866 for (
size_t j = begin; j < begin + count; ++j)
868 childAssignments[j - begin] = NumericSplit::CalculateDirection(
869 data(bestDim, j), classProbabilities[0], *
this);
// NOTE(review): unlike the datasetInfo overload, no arma::fill::zeros is
// visible here, and original line 874 is missing from this listing --
// confirm the counts are zero-initialized before being incremented.
873 arma::Row<size_t> childCounts(numChildren);
875 for (
size_t j = begin; j < begin + count; ++j)
876 childCounts[childAssignments[j - begin]]++;
// Partition in place: for each child i, swap its points forward so they
// occupy the contiguous range [currentChildBegin, currentCol).
884 size_t currentCol = begin;
885 for (
size_t i = 0; i < numChildren; ++i)
887 size_t currentChildBegin = currentCol;
888 for (
size_t j = currentChildBegin; j < begin + count; ++j)
890 if (childAssignments[j - begin] == i)
// Keep assignments, data, labels and weights in step with each other.
892 childAssignments.swap_cols(currentCol - begin, j - begin);
893 data.swap_cols(currentCol, j);
894 labels.swap_cols(currentCol, j);
896 weights.swap_cols(currentCol, j);
// This call passes the child's own point count as minimumLeafSize (the
// guarding condition is not visible in this listing).
905 child->
Train<UseWeights>(data, currentChildBegin,
906 currentCol - currentChildBegin, labels, numClasses, weights,
907 currentCol - currentChildBegin, minimumGainSplit, maximumDepth - 1,
// Otherwise recurse with the caller's limits and one less level of depth.
913 double childGain = child->
Train<UseWeights>(data, currentChildBegin,
914 currentCol - currentChildBegin, labels, numClasses, weights,
915 minimumLeafSize, minimumGainSplit, maximumDepth - 1,
// Accumulate the negated child result, weighted by the child's share of
// this node's points.
917 bestGain += double(childCounts[i]) / double(count) * (-childGain);
919 children.push_back(child);
// Reset numeric split info to default-constructed state, then store class
// probabilities for this node's label range (presumably the leaf path --
// the enclosing branch is not visible).
925 NumericAuxiliarySplitInfo::operator=(NumericAuxiliarySplitInfo());
928 CalculateClassProbabilities<UseWeights>(
929 labels.subvec(begin, begin + count - 1),
931 UseWeights ? weights.subvec(begin, begin + count - 1) : weights);
// Classify a single point. At a leaf (no children) the stored majority
// class is returned.
// NOTE(review): the signature and the non-leaf path are not visible in this
// listing.
938 template<
typename FitnessFunction,
939 template<
typename>
class NumericSplitType,
940 template<
typename>
class CategoricalSplitType,
941 typename DimensionSelectionType,
943 template<
typename VecType>
946 CategoricalSplitType,
947 DimensionSelectionType,
// Leaf: this field holds the majority class.
950 if (children.size() == 0)
953 return dimensionTypeOrMajorityClass;
// Classify a single point and also report class probabilities. At a leaf
// (no children) the outputs are the stored majority class and the stored
// class probability vector.
// NOTE(review): the leading parameters and the non-leaf path are not visible
// in this listing.
960 template<
typename FitnessFunction,
961 template<
typename>
class NumericSplitType,
962 template<
typename>
class CategoricalSplitType,
963 typename DimensionSelectionType,
965 template<
typename VecType>
968 CategoricalSplitType,
969 DimensionSelectionType,
972 arma::vec& probabilities)
const 974 if (children.size() == 0)
// Leaf: copy out the stored prediction and probabilities.
976 prediction = dimensionTypeOrMajorityClass;
977 probabilities = classProbabilities;
// Classify every column of `data`, writing one label per column into
// `predictions`.
// NOTE(review): the leading parameter (data), braces and any early return
// are not visible in this listing.
986 template<
typename FitnessFunction,
987 template<
typename>
class NumericSplitType,
988 template<
typename>
class CategoricalSplitType,
989 typename DimensionSelectionType,
991 template<
typename MatType>
994 CategoricalSplitType,
995 DimensionSelectionType,
997 arma::Row<size_t>& predictions)
const 999 predictions.set_size(data.n_cols);
// Leaf: every point receives the same majority class.
1000 if (children.size() == 0)
1002 predictions.fill(dimensionTypeOrMajorityClass);
// Otherwise classify each column individually.
1007 for (
size_t i = 0; i < data.n_cols; ++i)
1008 predictions[i] =
Classify(data.col(i));
// Classify every column of `data`, writing labels into `predictions` and
// per-class probabilities into the columns of `probabilities`.
// NOTE(review): the leading parameter (data), braces, the loop around the
// Child(0) descent and any early return are not visible in this listing.
1012 template<
typename FitnessFunction,
1013 template<
typename>
class NumericSplitType,
1014 template<
typename>
class CategoricalSplitType,
1015 typename DimensionSelectionType,
1017 template<
typename MatType>
1020 CategoricalSplitType,
1021 DimensionSelectionType,
1023 arma::Row<size_t>& predictions,
1024 arma::mat& probabilities)
const 1026 predictions.set_size(data.n_cols);
// Leaf: every point receives the same class and the same probabilities.
1027 if (children.size() == 0)
1029 predictions.fill(dimensionTypeOrMajorityClass);
1030 probabilities = arma::repmat(classProbabilities, 1, data.n_cols);
// Follow first children (presumably down to a leaf -- TODO confirm) to
// learn the class count used to size the output matrix.
1038 node = &node->
Child(0);
1039 probabilities.set_size(node->classProbabilities.n_elem, data.n_cols);
1041 for (
size_t i = 0; i < data.n_cols; ++i)
// unsafe_col(i) is intended to alias column i so the per-point Classify()
// writes straight into `probabilities` -- assumes the assignment inside
// Classify() does not reallocate; TODO confirm.
1043 arma::vec v = probabilities.unsafe_col(i);
1044 Classify(data.col(i), predictions[i], v);
// Serialize or deserialize this node with a cereal archive.
// NOTE(review): the signature, the body of the loading-time loop and the
// serialization of `children` and the base classes are not visible in this
// listing.
1049 template<
typename FitnessFunction,
1050 template<
typename>
class NumericSplitType,
1051 template<
typename>
class CategoricalSplitType,
1052 typename DimensionSelectionType,
1054 template<
typename Archive>
1057 CategoricalSplitType,
1058 DimensionSelectionType,
// When loading, existing children are visited first (presumably deleted so
// they can be replaced -- loop body not visible; TODO confirm).
1063 if (cereal::is_loading<Archive>())
1065 for (
size_t i = 0; i < children.size(); ++i)
// Scalar fields and class probabilities are archived as name-value pairs.
1073 ar(CEREAL_NVP(splitDimension));
1074 ar(CEREAL_NVP(dimensionTypeOrMajorityClass));
1075 ar(CEREAL_NVP(classProbabilities));
// For a non-leaf node, compute which child the given point falls into by
// delegating to the categorical or numeric splitter. The splitter receives
// the point's value in the split dimension, classProbabilities[0] and this
// node.
// NOTE(review): the signature and the left-hand side of the categorical
// comparison are not visible in this listing.
1078 template<
typename FitnessFunction,
1079 template<
typename>
class NumericSplitType,
1080 template<
typename>
class CategoricalSplitType,
1081 typename DimensionSelectionType,
1083 template<
typename VecType>
1086 CategoricalSplitType,
1087 DimensionSelectionType,
1091 data::Datatype::categorical)
1092 return CategoricalSplit::CalculateDirection(point[splitDimension],
1093 classProbabilities[0], *
this);
// Any other case is handled by the numeric splitter.
1095 return NumericSplit::CalculateDirection(point[splitDimension],
1096 classProbabilities[0], *
this);
// Return the number of classes the tree was trained on. A leaf reports the
// length of its class probability vector; an internal node asks its first
// child, recursing until a leaf is reached.
1100 template<
typename FitnessFunction,
1101 template<
typename>
class NumericSplitType,
1102 template<
typename>
class CategoricalSplitType,
1103 typename DimensionSelectionType,
1107 CategoricalSplitType,
1108 DimensionSelectionType,
1113 if (children.size() == 0)
1114 return classProbabilities.n_elem;
1116 return children[0]->NumClasses();
// Compute this node's class probability vector and majority class from the
// given labels (and weights when UseWeights is true).
//   labels     - class index of each point; used directly as an index into
//                classProbabilities, so each must be < numClasses.
//   numClasses - length of the resulting probability vector.
//   weights    - per-point weights; read only in the weighted branch.
// NOTE(review): the branch lines selecting weighted vs. unweighted
// accumulation and the braces are not visible in this listing.
1119 template<
typename FitnessFunction,
1120 template<
typename>
class NumericSplitType,
1121 template<
typename>
class CategoricalSplitType,
1122 typename DimensionSelectionType,
1124 template<
bool UseWeights,
typename RowType,
typename WeightsRowType>
1127 CategoricalSplitType,
1128 DimensionSelectionType,
1129 NoRecursion>::CalculateClassProbabilities(
1130 const RowType& labels,
1131 const size_t numClasses,
1132 const WeightsRowType& weights)
// Start from a zeroed histogram with one slot per class.
1134 classProbabilities.zeros(numClasses);
1135 double sumWeights = 0.0;
1136 for (
size_t i = 0; i < labels.n_elem; ++i)
// Weighted: add the point's weight to its class and to the running total.
1140 classProbabilities[labels[i]] += weights[i];
1141 sumWeights += weights[i];
// Unweighted: count the point once.
1145 classProbabilities[labels[i]]++;
// Normalize by total weight or by point count.
// NOTE(review): an empty label set or zero total weight makes this a
// division by zero -- confirm callers never pass such input.
1150 classProbabilities /= UseWeights ? sumWeights : labels.n_elem;
// The majority class is the index of the largest probability.
1151 arma::uword maxIndex = 0;
1152 classProbabilities.max(maxIndex);
1153 dimensionTypeOrMajorityClass = (size_t) maxIndex;
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
Definition: dataset_mapper.hpp:41
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Datatype
The Datatype enum specifies the types of data mlpack algorithms can use.
Definition: datatype.hpp:24
This class implements a generic decision tree learner.
Definition: decision_tree.hpp:40
Definition: pointer_wrapper.hpp:23
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
Definition: decision_tree_impl.hpp:948
size_t Dimensionality() const
Get the dimensionality of the DatasetMapper object (that is, how many dimensions it has information for).
Definition: dataset_mapper_impl.hpp:228
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data.
Definition: decision_tree_impl.hpp:433
void serialize(Archive &ar, const uint32_t)
Serialize the tree.
Definition: decision_tree_impl.hpp:1059
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point would go towards.
Definition: decision_tree_impl.hpp:1088
Datatype Type(const size_t dimension) const
Return the type of a given dimension (numeric or categorical).
Definition: dataset_mapper_impl.hpp:196
Definition: hmm_train_main.cpp:300
#define CEREAL_VECTOR_POINTER(T)
Cereal does not support the serialization of raw pointers.
Definition: pointer_vector_wrapper.hpp:93
size_t NumMappings(const size_t dimension) const
Get the number of mappings for a particular dimension.
Definition: dataset_mapper_impl.hpp:222
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
Definition: decision_tree.hpp:458
~DecisionTree()
Clean up memory.
Definition: decision_tree_impl.hpp:416
size_t NumClasses() const
Get the number of classes in the tree.
Definition: decision_tree_impl.hpp:1109
size_t NumChildren() const
Get the number of children.
Definition: decision_tree.hpp:455
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.
Definition: decision_tree_impl.hpp:31