// NOTE(review): this listing omits many original lines (the embedded line
// numbers jump): the function signature, braces, return statements, the
// definition of randomPivot and the leaf-size counter updates are not visible.
// The comments below describe only the statements that are visible.
//
// Classification overload of SplitIfBetter (name taken from the reference text
// at the end of this listing). The visible statements partition `data` around
// `randomPivot`, score both sides with FitnessFunction and store the pivot in
// `splitInfo`.
12 #ifndef MLPACK_METHODS_DECISION_TREE_RANDOM_BINARY_NUMERIC_SPLIT_IMPL_HPP 13 #define MLPACK_METHODS_DECISION_TREE_RANDOM_BINARY_NUMERIC_SPLIT_IMPL_HPP 21 template<
typename FitnessFunction>
22 template<
bool UseWeights,
typename VecType,
typename WeightVecType>
24 const double bestGain,
26 const arma::Row<size_t>& labels,
27 const size_t numClasses,
28 const WeightVecType& weights,
29 const size_t minimumLeafSize,
30 const double minimumGainSplit,
33 const bool splitIfBetterGain)
// Threshold a candidate must beat: bestGain + minimumGainSplit, capped at 0.
35 double bestFoundGain = std::min(bestGain + minimumGainSplit, 0.0);
// Effective minimum leaf size is never less than 1.
37 const size_t minimum = std::max(minimumLeafSize, (
size_t) 1);
// Guard for fewer than 2 * minimum points; the action taken is on an omitted
// line (presumably an early return -- TODO confirm).
40 if (data.n_elem < (minimum * 2))
// Range of the values in this dimension.
45 typename VecType::elem_type maxValue = arma::max(data);
46 typename VecType::elem_type minValue = arma::min(data);
// Guard for a dimension whose values are all equal; body not visible here.
50 if (maxValue == minValue)
// Accumulators. In both matrices column 0 collects the points below the pivot
// and column 1 the remaining points (see the loops below).
57 arma::Mat<size_t> classCounts;
58 arma::mat classWeightSums;
59 double totalWeight = 0.0;
60 double totalLeftWeight = 0.0;
61 double totalRightWeight = 0.0;
62 size_t leftLeafSize = 0;
63 size_t rightLeafSize = 0;
// Weighted bookkeeping: per-class weight sums, with the threshold rescaled by
// the total weight. Presumably guarded by UseWeights on an omitted line.
66 classWeightSums.zeros(numClasses, 2);
67 totalWeight = arma::accu(weights);
68 bestFoundGain *= totalWeight;
70 for (
size_t i = 0; i < data.n_elem; ++i)
// Strictly-less comparison: a value equal to randomPivot is counted on the
// right-hand side (column 1).
72 if (data(i) < randomPivot)
75 classWeightSums(labels(i), 0) += weights(i);
76 totalLeftWeight += weights(i);
81 classWeightSums(labels(i), 1) += weights(i);
82 totalRightWeight += weights(i);
// Unweighted bookkeeping: per-class counts, with the threshold rescaled by the
// number of points.
88 classCounts.zeros(numClasses, 2);
89 bestFoundGain *= data.n_elem;
91 for (
size_t i = 0; i < data.n_elem; ++i)
93 if (data(i) < randomPivot)
96 ++classCounts(labels(i), 0);
101 ++classCounts(labels(i), 1);
// Score each side: the weighted variant reads the weight sums, the unweighted
// variant reads the counts. leftLeafSize / rightLeafSize are only ever seen
// initialised to 0 in this listing; their updates are presumably on omitted
// lines -- TODO confirm.
108 const double leftGain = UseWeights ?
109 FitnessFunction::template EvaluatePtr<true>(classWeightSums.colptr(0),
110 numClasses, totalLeftWeight) :
111 FitnessFunction::template EvaluatePtr<false>(classCounts.colptr(0),
112 numClasses, leftLeafSize);
113 const double rightGain = UseWeights ?
114 FitnessFunction::template EvaluatePtr<true>(classWeightSums.colptr(1),
115 numClasses, totalRightWeight) :
116 FitnessFunction::template EvaluatePtr<false>(classCounts.colptr(1),
117 numClasses, rightLeafSize);
// Combined gain: each side's gain scaled by that side's total weight (first
// form) or by its point count (second form). The condition selecting between
// the two assignments is not visible.
121 gain = totalLeftWeight * leftGain + totalRightWeight * rightGain;
124 gain = double(leftLeafSize) * leftGain + double(rightLeafSize) * rightGain;
// Only when splitIfBetterGain is set is a gain below the threshold treated
// specially; the resulting action is on an omitted line.
126 if (gain < bestFoundGain && splitIfBetterGain)
// Report the chosen pivot as a one-element splitInfo.
129 splitInfo.set_size(1);
130 splitInfo[0] = randomPivot;
// Normalise the gain by the number of labels. NOTE(review): whether a
// weight-normalised counterpart exists on the omitted lines cannot be told
// from this listing.
135 gain /= labels.n_elem;
// NOTE(review): this listing omits many original lines (the embedded line
// numbers jump): the function signature, braces, return statements, the
// definition of randomPivot and of the indices l / r are not visible. The
// comments below describe only the statements that are visible.
//
// Variant taking real-valued `responses` instead of class labels (presumably
// the regression overload of SplitIfBetter -- TODO confirm). The visible
// statements partition `data` around `randomPivot`, copy each side's responses
// (and weights) into separate vectors, score both sides with FitnessFunction
// and store the pivot in `splitInfo`.
141 template<
typename FitnessFunction>
142 template<
bool UseWeights,
typename VecType,
typename WeightVecType>
144 const double bestGain,
146 const arma::rowvec& responses,
147 const WeightVecType& weights,
148 const size_t minimumLeafSize,
149 const double minimumGainSplit,
152 const bool splitIfBetterGain)
// Threshold a candidate must beat: bestGain + minimumGainSplit, capped at 0.
154 double bestFoundGain = std::min(bestGain + minimumGainSplit, 0.0);
// Effective minimum leaf size is never less than 1.
156 const size_t minimum = std::max(minimumLeafSize, (
size_t) 1);
// Guard for fewer than 2 * minimum points; the action taken is on an omitted
// line (presumably an early return -- TODO confirm).
159 if (data.n_elem < (minimum * 2))
// Range of the values in this dimension.
164 typename VecType::elem_type maxValue = arma::max(data);
165 typename VecType::elem_type minValue = arma::min(data);
// Guard for a dimension whose values are all equal; body not visible here.
169 if (maxValue == minValue)
172 double totalWeight = 0.0;
173 double totalLeftWeight = 0.0;
174 double totalRightWeight = 0.0;
// Rescale the threshold by the total weight ...
177 totalWeight = arma::accu(weights);
178 bestFoundGain *= totalWeight;
// ... or by the number of points. The condition choosing between the two
// scalings is not visible (presumably UseWeights).
182 bestFoundGain *= data.n_elem;
// First pass: accumulate per-side totals. The updates of leftLeafSize /
// rightLeafSize are presumably on omitted lines -- TODO confirm.
189 size_t leftLeafSize = 0;
190 size_t rightLeafSize = 0;
191 for (
size_t i = 0; i < data.n_elem; ++i)
// Strictly-less comparison: a value equal to randomPivot goes to the right.
195 if (data[i] < randomPivot)
196 totalLeftWeight += weights[i];
198 totalRightWeight += weights[i];
201 if (data[i] < randomPivot)
// Per-side buffers sized from the counts gathered in the first pass.
208 arma::rowvec leftResponses(leftLeafSize), rightResponses(rightLeafSize);
209 arma::rowvec leftWeights, rightWeights;
212 leftWeights.set_size(leftLeafSize);
213 rightWeights.set_size(rightLeafSize);
// Second pass: copy each point's response (and weight) to its side.
217 for (
size_t i = 0; i < data.n_elem; ++i)
// The weight is stored at index l / r before the response store below
// post-increments that index, so the two vectors stay aligned.
221 if (data[i] < randomPivot)
222 leftWeights[l] = weights[i];
224 rightWeights[r] = weights[i];
226 if (data[i] < randomPivot)
227 leftResponses[l++] = responses[i];
229 rightResponses[r++] = responses[i];
// Score each side over the index range [0, leafSize) of its buffer.
233 const double leftGain = FitnessFunction::template
234 Evaluate<UseWeights>(leftResponses, leftWeights, 0, leftLeafSize);
235 const double rightGain = FitnessFunction::template
236 Evaluate<UseWeights>(rightResponses, rightWeights, 0, rightLeafSize);
// Combined gain: each side's gain scaled by that side's total weight (first
// form) or by its point count (second form). The condition selecting between
// the two assignments is not visible.
241 gain = totalLeftWeight * leftGain + totalRightWeight * rightGain;
243 gain = double(leftLeafSize) * leftGain + double(rightLeafSize) * rightGain;
// Only when splitIfBetterGain is set is a gain below the threshold treated
// specially; the resulting action is on an omitted line.
245 if (gain < bestFoundGain && splitIfBetterGain)
// Report the chosen pivot; here splitInfo is assigned the pivot value directly.
248 splitInfo = randomPivot;
// Normalise the gain by the number of responses. NOTE(review): whether a
// weight-normalised counterpart exists on the omitted lines cannot be told
// from this listing.
253 gain /= responses.n_elem;
// NOTE(review): the signature line, braces and return statements of this
// function are not visible in this listing. The reference text at the end of
// the listing names it CalculateDirection and describes it as choosing the
// child (left or right) for a point.
258 template<
typename FitnessFunction>
259 template<
typename ElemType>
261 const ElemType& point,
262 const double& splitInfo,
// Compares the point against the stored pivot with `<=`; the values returned
// for each outcome are on omitted lines.
// NOTE(review): the training loops in this file partition with
// `data < randomPivot` (strictly less), whereas this test is `<=`. A point
// exactly equal to the pivot is therefore counted on the right during
// training but takes the `<=` branch here -- confirm this is intended.
265 if (point <= splitInfo)
static double SplitIfBetter(const double bestGain, const VecType &data, const arma::Row< size_t > &labels, const size_t numClasses, const WeightVecType &weights, const size_t minimumLeafSize, const double minimumGainSplit, arma::vec &splitInfo, AuxiliarySplitInfo &aux, const bool splitIfBetterGain=false)
Check if we can split a node.
Definition: random_binary_numeric_split_impl.hpp:23
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static size_t CalculateDirection(const ElemType &point, const double &splitInfo, const AuxiliarySplitInfo &)
Given a point, calculate which child it should go to (left or right).
Definition: random_binary_numeric_split_impl.hpp:260
Definition: random_binary_numeric_split.hpp:32
double Random()
Generates a uniform random number between 0 and 1.
Definition: random.hpp:83
Miscellaneous math random-related routines.