// NOTE: this listing carries the original file's line numbers at the start of
// each line and skips some of them, so parts of this definition (function-name
// line, braces, some conditions and early exits) are not visible here.  The
// comments below describe only the statements that are visible and hedge
// everything else.
//
// Include guard for this implementation header, followed by the numeric split
// search for classification: the visible code sorts one dimension, scans the
// candidate split positions between adjacent sorted values, and returns the
// best gain found.  Presumably this is SplitIfBetter() taking `labels` and
// `numClasses` -- the name line is not visible; confirm.
12 #ifndef MLPACK_METHODS_DECISION_TREE_BEST_BINARY_NUMERIC_SPLIT_IMPL_HPP 13 #define MLPACK_METHODS_DECISION_TREE_BEST_BINARY_NUMERIC_SPLIT_IMPL_HPP 19 template<
typename FitnessFunction>
20 template<
bool UseWeights,
typename VecType,
typename WeightVecType>
22 const double bestGain,
24 const arma::Row<size_t>& labels,
25 const size_t numClasses,
26 const WeightVecType& weights,
27 const size_t minimumLeafSize,
28 const double minimumGainSplit,
// Guard on having fewer than 2 * minimumLeafSize points.  The action taken is
// not visible here (presumably an early exit with no split -- confirm).
33 if (data.n_elem < (minimumLeafSize * 2))
// Indices that put `data` in sorted order; labels are reordered to match.
39 arma::uvec sortedIndices = arma::sort_index(data);
40 arma::Row<size_t> sortedLabels(labels.n_elem);
41 arma::rowvec sortedWeights;
42 for (
size_t i = 0; i < sortedLabels.n_elem; ++i)
43 sortedLabels[i] = labels[sortedIndices[i]];
// Smallest sorted value equals the largest, i.e. every value in `data` is the
// same.  The action taken is not visible here (presumably no split -- confirm).
47 if (data[sortedIndices[0]] == data[sortedIndices[sortedIndices.n_elem - 1]])
// Reorder the weights to match the sorted data.  Presumably guarded by
// UseWeights on a line that is not visible here -- confirm.
53 sortedWeights.set_size(sortedLabels.n_elem);
55 for (
size_t i = 0; i < sortedLabels.n_elem; ++i)
56 sortedWeights[i] = weights[sortedIndices[i]];
// Starting threshold a candidate has to beat; never greater than 0.
61 double bestFoundGain = std::min(bestGain + minimumGainSplit, 0.0);
62 bool improved =
false;
// Treat a minimum leaf size of 0 as 1.
63 const size_t minimum = std::max(minimumLeafSize, (
size_t) 1);
// Per-class bookkeeping with two columns: column 0 is the left side of the
// candidate split, column 1 is the right side.
66 arma::Mat<size_t> classCounts;
67 arma::mat classWeightSums;
68 double totalWeight = 0.0;
69 double totalLeftWeight = 0.0;
70 double totalRightWeight = 0.0;
// Weighted bookkeeping (presumably the UseWeights branch; the condition is not
// visible).  The threshold is scaled by the total weight so it can be compared
// against the un-normalized gains computed in the scan below.
73 classWeightSums.zeros(numClasses, 2);
74 totalWeight = arma::accu(sortedWeights);
75 bestFoundGain *= totalWeight;
// The first (minimum - 1) sorted points start on the left side...
79 for (
size_t i = 0; i < minimum - 1; ++i)
81 classWeightSums(sortedLabels[i], 0) += sortedWeights[i];
82 totalLeftWeight += sortedWeights[i];
// ...and all remaining points start on the right side.
86 for (
size_t i = minimum - 1; i < data.n_elem; ++i)
88 classWeightSums(sortedLabels[i], 1) += sortedWeights[i];
89 totalRightWeight += sortedWeights[i];
// Unweighted counterpart of the setup above: class counts instead of weight
// sums, and the threshold is scaled by the number of points.
94 classCounts.zeros(numClasses, 2);
95 bestFoundGain *= data.n_elem;
99 for (
size_t i = 0; i < minimum - 1; ++i)
100 ++classCounts(sortedLabels[i], 0);
103 for (
size_t i = minimum - 1; i < data.n_elem; ++i)
104 ++classCounts(sortedLabels[i], 1);
// Scan candidate split positions; at position `index` the left side holds
// `index` points and the right side holds (n_elem - index) points.
// NOTE(review): the upper bound here is `data.n_elem - minimum`, whereas the
// two other split searches later in this file use `data.n_elem - minimum + 1`.
// With this bound the candidate whose right side has exactly `minimum` points
// is never evaluated -- confirm whether that difference is intentional.
107 for (
size_t index = minimum; index < data.n_elem - minimum; ++index)
// Move sorted point (index - 1) from the right side to the left side
// (weighted bookkeeping).
112 classWeightSums(sortedLabels[index - 1], 1) -= sortedWeights[index - 1];
113 classWeightSums(sortedLabels[index - 1], 0) += sortedWeights[index - 1];
114 totalLeftWeight += sortedWeights[index - 1];
115 totalRightWeight -= sortedWeights[index - 1];
// Same move for the unweighted class counts.
119 --classCounts(sortedLabels[index - 1], 1);
120 ++classCounts(sortedLabels[index - 1], 0);
// Adjacent sorted values are equal, so there is no threshold between them.
// The action taken is not visible here (presumably skip this index -- confirm).
124 if (data[sortedIndices[index]] == data[sortedIndices[index - 1]])
// Fitness of each side, computed from the weight sums or the counts depending
// on UseWeights.
129 const double leftGain = UseWeights ?
130 FitnessFunction::template EvaluatePtr<true>(classWeightSums.colptr(0),
131 numClasses, totalLeftWeight) :
132 FitnessFunction::template EvaluatePtr<false>(classCounts.colptr(0),
134 const double rightGain = UseWeights ?
135 FitnessFunction::template EvaluatePtr<true>(classWeightSums.colptr(1),
136 numClasses, totalRightWeight) :
137 FitnessFunction::template EvaluatePtr<false>(classCounts.colptr(1),
138 numClasses, size_t(sortedLabels.n_elem - index));
// Combined gain: each side's fitness weighted by that side's total weight...
143 gain = totalLeftWeight * leftGain + totalRightWeight * rightGain;
// ...or, in the unweighted form, by that side's number of points.
148 gain = double(index) * leftGain +
149 double(sortedLabels.n_elem - index) * rightGain;
// Store the midpoint between the two adjacent sorted values as the split
// value.  This first assignment sits in a branch whose condition is not
// visible here.
158 splitInfo.set_size(1);
159 splitInfo[0] = (data[sortedIndices[index - 1]] +
160 data[sortedIndices[index]]) / 2.0;
// Otherwise keep this candidate only if it beats the best gain so far.
164 else if (gain > bestFoundGain)
167 bestFoundGain = gain;
168 splitInfo.set_size(1);
169 splitInfo[0] = (data[sortedIndices[index - 1]] +
170 data[sortedIndices[index]]) / 2.0;
// Undo the earlier scaling of the gain (by total weight or by number of
// points; the selecting condition is not visible here) before returning it.
181 bestFoundGain /= totalWeight;
183 bestFoundGain /= sortedLabels.n_elem;
185 return bestFoundGain;
// NOTE: this listing carries the original file's line numbers at the start of
// each line and skips some of them, so parts of this definition (function-name
// line, the enable_if condition, braces, some conditions and early exits) are
// not visible here.  Comments describe only what is visible.
//
// Numeric split search over `responses` (element type taken from
// ResponsesType) rather than class labels.  Each candidate position is scored
// by calling FitnessFunction::Evaluate<UseWeights>() on the left and right
// ranges of the sorted responses.  Presumably a SplitIfBetter() overload
// selected via std::enable_if -- the name and condition are not visible;
// confirm.
189 template<
typename FitnessFunction>
190 template<
bool UseWeights,
typename VecType,
typename ResponsesType,
191 typename WeightVecType>
192 typename std::enable_if<
196 const double bestGain,
198 const ResponsesType& responses,
199 const WeightVecType& weights,
200 const size_t minimumLeafSize,
201 const double minimumGainSplit,
// Element types of the responses and weights, used for the sorted copies and
// the weight accumulators below.
205 typedef typename ResponsesType::elem_type RType;
206 typedef typename WeightVecType::elem_type WType;
// Guard on having fewer than 2 * minimumLeafSize points.  The action taken is
// not visible here (presumably an early exit with no split -- confirm).
209 if (data.n_elem < (minimumLeafSize * 2))
// Indices that put `data` in sorted order; responses are reordered to match.
215 arma::uvec sortedIndices = arma::sort_index(data);
216 arma::Row<RType> sortedResponses(responses.n_elem);
217 arma::Row<WType> sortedWeights;
218 for (
size_t i = 0; i < sortedResponses.n_elem; ++i)
219 sortedResponses[i] = responses[sortedIndices[i]];
// Smallest sorted value equals the largest, i.e. every value in `data` is the
// same.  The action taken is not visible here (presumably no split -- confirm).
223 if (data[sortedIndices[0]] == data[sortedIndices[sortedIndices.n_elem - 1]])
// Reorder the weights to match the sorted data.  Presumably guarded by
// UseWeights on a line that is not visible here -- confirm.
229 sortedWeights.set_size(sortedResponses.n_elem);
231 for (
size_t i = 0; i < sortedResponses.n_elem; ++i)
232 sortedWeights[i] = weights[sortedIndices[i]];
// Starting threshold a candidate has to beat; never greater than 0.
235 double bestFoundGain = std::min(bestGain + minimumGainSplit, 0.0);
236 bool improved =
false;
// Treat a minimum leaf size of 0 as 1.
238 const size_t minimum = std::max(minimumLeafSize, (
size_t) 1);
240 WType totalWeight = 0.0;
241 WType totalLeftWeight = 0.0;
242 WType totalRightWeight = 0.0;
// Weighted setup (presumably the UseWeights branch; the condition is not
// visible): scale the threshold by the total weight, then put the first
// (minimum - 1) sorted points' weight on the left and the rest on the right.
246 totalWeight = arma::accu(sortedWeights);
247 bestFoundGain *= totalWeight;
249 for (
size_t i = 0; i < minimum - 1; ++i)
250 totalLeftWeight += sortedWeights[i];
252 for (
size_t i = minimum - 1; i < data.n_elem; ++i)
253 totalRightWeight += sortedWeights[i];
// Unweighted setup: scale the threshold by the number of points.
257 bestFoundGain *= data.n_elem;
// Scan candidate split positions; at position `index` the left side is the
// sorted range [0, index) and the right side starts at `index`.  The bound
// includes index == n_elem - minimum, where the right side has `minimum`
// points.
261 for (
size_t index = minimum; index < data.n_elem - minimum + 1; ++index)
// Move the weight of sorted point (index - 1) from the right to the left.
265 totalLeftWeight += sortedWeights[index - 1];
266 totalRightWeight -= sortedWeights[index - 1];
// Adjacent sorted values are equal, so there is no threshold between them.
// The action taken is not visible here (presumably skip this index -- confirm).
269 if (data[sortedIndices[index]] == data[sortedIndices[index - 1]])
// Fitness of each side, evaluated from scratch on its range of the sorted
// responses.  The end argument of the right-hand call is on a line that is
// not visible here.
273 const double leftGain = FitnessFunction::template
274 Evaluate<UseWeights>(sortedResponses, sortedWeights, 0, index);
275 const double rightGain = FitnessFunction::template
276 Evaluate<UseWeights>(sortedResponses, sortedWeights, index,
// Combined gain: each side's fitness weighted by that side's total weight...
282 gain = totalLeftWeight * leftGain + totalRightWeight * rightGain;
// ...or, in the unweighted form, by that side's number of points.
287 gain = double(index) * leftGain +
288 double(sortedResponses.n_elem - index) * rightGain;
// Store the midpoint between the two adjacent sorted values as the split
// value.  This first assignment sits in a branch whose condition is not
// visible here.
297 splitInfo = (data[sortedIndices[index - 1]] +
298 data[sortedIndices[index]]) / 2.0;
// Keep this candidate if it beats the best gain so far.
302 if (gain > bestFoundGain)
305 bestFoundGain = gain;
306 splitInfo = (data[sortedIndices[index - 1]] +
307 data[sortedIndices[index]]) / 2.0;
// Undo the earlier scaling of the gain (by total weight or by number of
// points; the selecting condition is not visible here) before returning it.
318 bestFoundGain /= totalWeight;
320 bestFoundGain /= data.n_elem;
322 return bestFoundGain;
// NOTE: this listing carries the original file's line numbers at the start of
// each line and skips some of them, so parts of this definition (function-name
// line, the enable_if condition, braces, some conditions and early exits) are
// not visible here.  Comments describe only what is visible.
//
// Numeric split search over `responses` that drives a FitnessFunction
// instance incrementally: BinaryScanInitialize() once before the scan,
// BinaryStep() once per position, and BinaryGains() to read the left/right
// fitness.  Presumably a SplitIfBetter() overload selected via std::enable_if
// for fitness functions that offer this interface -- the name and condition
// are not visible; confirm.
327 template<
typename FitnessFunction>
328 template<
bool UseWeights,
typename VecType,
typename ResponsesType,
329 typename WeightVecType>
330 typename std::enable_if<
334 const double bestGain,
336 const ResponsesType& responses,
337 const WeightVecType& weights,
338 const size_t minimumLeafSize,
339 const double minimumGainSplit,
// Element types of the responses and weights, used for the sorted copies and
// the weight accumulators below.
343 typedef typename ResponsesType::elem_type RType;
344 typedef typename WeightVecType::elem_type WType;
// Stateful fitness object updated as the scan advances.
346 FitnessFunction fitnessFunction;
// Guard on having fewer than 2 * minimumLeafSize points.  The action taken is
// not visible here (presumably an early exit with no split -- confirm).
349 if (data.n_elem < (minimumLeafSize * 2))
// Indices that put `data` in sorted order; responses are reordered to match.
355 arma::uvec sortedIndices = arma::sort_index(data);
356 arma::Row<RType> sortedResponses(responses.n_elem);
357 arma::Row<WType> sortedWeights;
358 for (
size_t i = 0; i < sortedResponses.n_elem; ++i)
359 sortedResponses[i] = responses[sortedIndices[i]];
// Smallest sorted value equals the largest, i.e. every value in `data` is the
// same.  The action taken is not visible here (presumably no split -- confirm).
363 if (data[sortedIndices[0]] == data[sortedIndices[sortedIndices.n_elem - 1]])
// Reorder the weights to match the sorted data.  Presumably guarded by
// UseWeights on a line that is not visible here -- confirm.
369 sortedWeights.set_size(sortedResponses.n_elem);
371 for (
size_t i = 0; i < sortedResponses.n_elem; ++i)
372 sortedWeights[i] = weights[sortedIndices[i]];
// Starting threshold a candidate has to beat; never greater than 0.
375 double bestFoundGain = std::min(bestGain + minimumGainSplit, 0.0);
376 bool improved =
false;
// Treat a minimum leaf size of 0 as 1.
378 const size_t minimum = std::max(minimumLeafSize, (
size_t) 1);
380 WType totalWeight = 0.0;
381 WType leftChildWeight = 0.0;
382 WType rightChildWeight = 0.0;
// Weighted setup (presumably the UseWeights branch; the condition is not
// visible): scale the threshold by the total weight, then put the first
// (minimum - 1) sorted points' weight on the left and the rest on the right.
386 totalWeight = arma::accu(sortedWeights);
387 bestFoundGain *= totalWeight;
389 for (
size_t i = 0; i < minimum - 1; ++i)
390 leftChildWeight += sortedWeights[i];
392 for (
size_t i = minimum - 1; i < data.n_elem; ++i)
393 rightChildWeight += sortedWeights[i];
// Unweighted setup: scale the threshold by the number of points.
397 bestFoundGain *= data.n_elem;
// Prime the fitness object for the scan, passing the sorted data and the
// effective minimum leaf size.
402 fitnessFunction.template BinaryScanInitialize<UseWeights>(sortedResponses,
403 sortedWeights, minimum);
// Scan candidate split positions; the bound includes
// index == n_elem - minimum, where (n_elem - index) == minimum.
406 for (
size_t index = minimum; index < data.n_elem - minimum + 1; ++index)
// Move the weight of sorted point (index - 1) from the right to the left.
410 leftChildWeight += sortedWeights[index - 1];
411 rightChildWeight -= sortedWeights[index - 1];
// Advance the fitness object by the same point (index - 1).
415 fitnessFunction.template BinaryStep<UseWeights>(sortedResponses,
416 sortedWeights, index - 1);
// Adjacent sorted values are equal, so there is no threshold between them.
// The action taken is not visible here (presumably skip this index -- confirm).
// Note that BinaryStep() above has already been called for this index.
419 if (data[sortedIndices[index]] == data[sortedIndices[index - 1]])
// Read the current (left, right) fitness pair from the fitness object.
423 std::tuple<double, double> binaryGains = fitnessFunction.BinaryGains();
424 const double leftGain = std::get<0>(binaryGains);
425 const double rightGain = std::get<1>(binaryGains);
// Combined gain: each side's fitness weighted by that side's total weight...
430 gain = leftChildWeight * leftGain + rightChildWeight * rightGain;
// ...or, in the unweighted form, by that side's number of points.
435 gain = double(index) * leftGain +
436 double(sortedResponses.n_elem - index) * rightGain;
// Store the midpoint between the two adjacent sorted values as the split
// value.  This first assignment sits in a branch whose condition is not
// visible here.
445 splitInfo = (data[sortedIndices[index - 1]] +
446 data[sortedIndices[index]]) / 2.0;
// Keep this candidate if it beats the best gain so far.
450 if (gain > bestFoundGain)
453 bestFoundGain = gain;
454 splitInfo = (data[sortedIndices[index - 1]] +
455 data[sortedIndices[index]]) / 2.0;
// Undo the earlier scaling of the gain (by total weight or by number of
// points; the selecting condition is not visible here) before returning it.
465 bestFoundGain /= totalWeight;
467 bestFoundGain /= data.n_elem;
469 return bestFoundGain;
// Given a point's value in the split dimension and the stored split value,
// decide which child the point goes to (left or right).  Only the comparison
// is visible in this listing; the function-name line and the values returned
// for each outcome are on lines that are not shown here.
472 template<
typename FitnessFunction>
473 template<
typename ElemType>
475 const ElemType& point,
476 const double& splitInfo,
// Points less than or equal to the split value take the first branch
// (presumably the left child -- confirm against the full file).
479 if (point <= splitInfo)
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static double SplitIfBetter(const double bestGain, const VecType &data, const arma::Row< size_t > &labels, const size_t numClasses, const WeightVecType &weights, const size_t minimumLeafSize, const double minimumGainSplit, arma::vec &splitInfo, AuxiliarySplitInfo &aux)
Check if we can split a node.
Definition: best_binary_numeric_split_impl.hpp:21
static size_t CalculateDirection(const ElemType &point, const double &splitInfo, const AuxiliarySplitInfo &)
Given a point, calculate which child it should go to (left or right).
Definition: best_binary_numeric_split_impl.hpp:474
Definition: best_binary_numeric_split.hpp:53