12 #ifndef MLPACK_METHODS_DECISION_TREE_ALL_CATEGORICAL_SPLIT_IMPL_HPP 13 #define MLPACK_METHODS_DECISION_TREE_ALL_CATEGORICAL_SPLIT_IMPL_HPP 19 template<
typename FitnessFunction>
20 template<
bool UseWeights,
typename VecType,
typename LabelsType,
21 typename WeightVecType>
23 const double bestGain,
25 const size_t numCategories,
26 const LabelsType& labels,
27 const size_t numClasses,
28 const WeightVecType& weights,
29 const size_t minimumLeafSize,
30 const double minimumGainSplit,
35 const double epsilon = 1e-7;
36 arma::Col<size_t> counts(numCategories, arma::fill::zeros);
39 arma::vec childWeightSums;
40 double sumWeight = 0.0;
42 childWeightSums.zeros(numCategories);
44 for (
size_t i = 0; i < data.n_elem; ++i)
46 counts[(size_t) data[i]]++;
50 childWeightSums[(size_t) data[i]] += weights[i];
51 sumWeight += weights[i];
57 if (arma::min(counts) < minimumLeafSize)
62 arma::uvec childPositions(numCategories, arma::fill::zeros);
63 std::vector<arma::Row<size_t>> childLabels(numCategories);
64 std::vector<arma::Row<double>> childWeights(numCategories);
66 for (
size_t i = 0; i < numCategories; ++i)
69 childLabels[i].zeros(counts[i]);
71 childWeights[i].zeros(counts[i]);
75 for (
size_t i = 0; i < data.n_elem; ++i)
77 const size_t category = (size_t) data[i];
81 childLabels[category][childPositions[category]] = labels[i];
82 childWeights[category][childPositions[category]++] = weights[i];
86 childLabels[category][childPositions[category]++] = labels[i];
90 double overallGain = 0.0;
91 for (
size_t i = 0; i < counts.n_elem; ++i)
94 const double childPct = UseWeights ?
95 double(childWeightSums[i]) / sumWeight :
96 double(counts[i]) / double(data.n_elem);
97 const double childGain = FitnessFunction::template Evaluate<UseWeights>(
98 childLabels[i], numClasses, childWeights[i]);
100 overallGain += childPct * childGain;
103 if (overallGain > bestGain + minimumGainSplit + epsilon)
106 splitInfo.set_size(1);
107 splitInfo[0] = numCategories;
116 template<
typename FitnessFunction>
117 template<
bool UseWeights,
typename VecType,
typename ResponsesType,
118 typename WeightVecType>
120 const double bestGain,
122 const size_t numCategories,
123 const ResponsesType& responses,
124 const WeightVecType& weights,
125 const size_t minimumLeafSize,
126 const double minimumGainSplit,
131 const double epsilon = 1e-7;
132 arma::Col<size_t> counts(numCategories, arma::fill::zeros);
135 arma::vec childWeightSums;
136 double sumWeight = 0.0;
138 childWeightSums.zeros(numCategories);
140 for (
size_t i = 0; i < data.n_elem; ++i)
142 counts[(size_t) data[i]]++;
146 childWeightSums[(size_t) data[i]] += weights[i];
147 sumWeight += weights[i];
153 if (arma::min(counts) < minimumLeafSize)
158 arma::uvec childPositions(numCategories, arma::fill::zeros);
159 std::vector<arma::rowvec> childResponses(numCategories);
160 std::vector<arma::rowvec> childWeights(numCategories);
162 for (
size_t i = 0; i < numCategories; ++i)
165 childResponses[i].zeros(counts[i]);
167 childWeights[i].zeros(counts[i]);
171 for (
size_t i = 0; i < data.n_elem; ++i)
173 const size_t category = (size_t) data[i];
177 childResponses[category][childPositions[category]] = responses[i];
178 childWeights[category][childPositions[category]++] = weights[i];
182 childResponses[category][childPositions[category]++] = responses[i];
186 double overallGain = 0.0;
187 for (
size_t i = 0; i < counts.n_elem; ++i)
190 const double childPct = UseWeights ?
191 double(childWeightSums[i]) / sumWeight :
192 double(counts[i]) / double(data.n_elem);
193 const double childGain = FitnessFunction::template Evaluate<UseWeights>(
194 childResponses[i], childWeights[i]);
196 overallGain += childPct * childGain;
199 if (overallGain > bestGain + minimumGainSplit + epsilon)
202 splitInfo = numCategories;
210 template<
typename FitnessFunction>
212 const double& splitInfo,
215 return (
size_t) splitInfo;
218 template<
typename FitnessFunction>
219 template<
typename ElemType>
221 const ElemType& point,
225 return (
size_t) point;
static size_t CalculateDirection(const ElemType &point, const double &splitInfo, const AuxiliarySplitInfo &)
Calculate the direction a point should percolate to.
Definition: all_categorical_split_impl.hpp:220
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static size_t NumChildren(const double &splitInfo, const AuxiliarySplitInfo &)
Return the number of children in the split.
Definition: all_categorical_split_impl.hpp:211
Definition: all_categorical_split.hpp:34
static double SplitIfBetter(const double bestGain, const VecType &data, const size_t numCategories, const LabelsType &labels, const size_t numClasses, const WeightVecType &weights, const size_t minimumLeafSize, const double minimumGainSplit, arma::vec &splitInfo, AuxiliarySplitInfo &aux)
Check if we can split a node.
Definition: all_categorical_split_impl.hpp:22