mlpack
best_binary_numeric_split_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_DECISION_TREE_BEST_BINARY_NUMERIC_SPLIT_IMPL_HPP
13 #define MLPACK_METHODS_DECISION_TREE_BEST_BINARY_NUMERIC_SPLIT_IMPL_HPP
14 
15 namespace mlpack {
16 namespace tree {
17 
18 // Overload used for classification.
19 template<typename FitnessFunction>
20 template<bool UseWeights, typename VecType, typename WeightVecType>
22  const double bestGain,
23  const VecType& data,
24  const arma::Row<size_t>& labels,
25  const size_t numClasses,
26  const WeightVecType& weights,
27  const size_t minimumLeafSize,
28  const double minimumGainSplit,
29  arma::vec& splitInfo,
30  AuxiliarySplitInfo& /* aux */)
31 {
32  // First sanity check: if we don't have enough points, we can't split.
33  if (data.n_elem < (minimumLeafSize * 2))
34  return DBL_MAX;
35  if (bestGain == 0.0)
36  return DBL_MAX; // It can't be outperformed.
37 
38  // Next, sort the data.
39  arma::uvec sortedIndices = arma::sort_index(data);
40  arma::Row<size_t> sortedLabels(labels.n_elem);
41  arma::rowvec sortedWeights;
42  for (size_t i = 0; i < sortedLabels.n_elem; ++i)
43  sortedLabels[i] = labels[sortedIndices[i]];
44 
45  // Sanity check: if the first element is the same as the last, we can't split
46  // in this dimension.
47  if (data[sortedIndices[0]] == data[sortedIndices[sortedIndices.n_elem - 1]])
48  return DBL_MAX;
49 
50  // Only initialize if we are using weights.
51  if (UseWeights)
52  {
53  sortedWeights.set_size(sortedLabels.n_elem);
54  // The weights must keep the same order as the labels.
55  for (size_t i = 0; i < sortedLabels.n_elem; ++i)
56  sortedWeights[i] = weights[sortedIndices[i]];
57  }
58 
59  // Loop through all possible split points, choosing the best one. Also, force
60  // a minimum leaf size of 1 (empty children don't make sense).
61  double bestFoundGain = std::min(bestGain + minimumGainSplit, 0.0);
62  bool improved = false;
63  const size_t minimum = std::max(minimumLeafSize, (size_t) 1);
64 
65  // We need to count the number of points for each class.
66  arma::Mat<size_t> classCounts;
67  arma::mat classWeightSums;
68  double totalWeight = 0.0;
69  double totalLeftWeight = 0.0;
70  double totalRightWeight = 0.0;
71  if (UseWeights)
72  {
73  classWeightSums.zeros(numClasses, 2);
74  totalWeight = arma::accu(sortedWeights);
75  bestFoundGain *= totalWeight;
76 
77  // Initialize the counts.
78  // These points have to be on the left.
79  for (size_t i = 0; i < minimum - 1; ++i)
80  {
81  classWeightSums(sortedLabels[i], 0) += sortedWeights[i];
82  totalLeftWeight += sortedWeights[i];
83  }
84 
85  // These points have to be on the right.
86  for (size_t i = minimum - 1; i < data.n_elem; ++i)
87  {
88  classWeightSums(sortedLabels[i], 1) += sortedWeights[i];
89  totalRightWeight += sortedWeights[i];
90  }
91  }
92  else
93  {
94  classCounts.zeros(numClasses, 2);
95  bestFoundGain *= data.n_elem;
96 
97  // Initialize the counts.
98  // These points have to be on the left.
99  for (size_t i = 0; i < minimum - 1; ++i)
100  ++classCounts(sortedLabels[i], 0);
101 
102  // These points have to be on the right.
103  for (size_t i = minimum - 1; i < data.n_elem; ++i)
104  ++classCounts(sortedLabels[i], 1);
105  }
106 
107  for (size_t index = minimum; index < data.n_elem - minimum; ++index)
108  {
109  // Update class weight sums or counts.
110  if (UseWeights)
111  {
112  classWeightSums(sortedLabels[index - 1], 1) -= sortedWeights[index - 1];
113  classWeightSums(sortedLabels[index - 1], 0) += sortedWeights[index - 1];
114  totalLeftWeight += sortedWeights[index - 1];
115  totalRightWeight -= sortedWeights[index - 1];
116  }
117  else
118  {
119  --classCounts(sortedLabels[index - 1], 1);
120  ++classCounts(sortedLabels[index - 1], 0);
121  }
122 
123  // Make sure that the value has changed.
124  if (data[sortedIndices[index]] == data[sortedIndices[index - 1]])
125  continue;
126 
127  // Calculate the gain for the left and right child. Only use weights if
128  // needed.
129  const double leftGain = UseWeights ?
130  FitnessFunction::template EvaluatePtr<true>(classWeightSums.colptr(0),
131  numClasses, totalLeftWeight) :
132  FitnessFunction::template EvaluatePtr<false>(classCounts.colptr(0),
133  numClasses, index);
134  const double rightGain = UseWeights ?
135  FitnessFunction::template EvaluatePtr<true>(classWeightSums.colptr(1),
136  numClasses, totalRightWeight) :
137  FitnessFunction::template EvaluatePtr<false>(classCounts.colptr(1),
138  numClasses, size_t(sortedLabels.n_elem - index));
139 
140  double gain;
141  if (UseWeights)
142  {
143  gain = totalLeftWeight * leftGain + totalRightWeight * rightGain;
144  }
145  else
146  {
147  // Calculate the gain at this split point.
148  gain = double(index) * leftGain +
149  double(sortedLabels.n_elem - index) * rightGain;
150  }
151 
152  // Corner case: is this the best possible split?
153  if (gain >= 0.0)
154  {
155  // We can take a shortcut: no split will be better than this, so just
156  // take this one. The actual split value will be halfway between the
157  // value at index - 1 and index.
158  splitInfo.set_size(1);
159  splitInfo[0] = (data[sortedIndices[index - 1]] +
160  data[sortedIndices[index]]) / 2.0;
161 
162  return gain;
163  }
164  else if (gain > bestFoundGain)
165  {
166  // We still have a better split.
167  bestFoundGain = gain;
168  splitInfo.set_size(1);
169  splitInfo[0] = (data[sortedIndices[index - 1]] +
170  data[sortedIndices[index]]) / 2.0;
171  improved = true;
172  }
173  }
174 
175  // If we didn't improve, return the original gain exactly as we got it
176  // (without introducing floating point errors).
177  if (!improved)
178  return DBL_MAX;
179 
180  if (UseWeights)
181  bestFoundGain /= totalWeight;
182  else
183  bestFoundGain /= sortedLabels.n_elem;
184 
185  return bestFoundGain;
186 }
187 
188 // Overload used for regression.
//
// Tries to find a binary split of one numeric dimension that beats
// `bestGain`.  The points are sorted by value and every threshold between two
// consecutive distinct values is scored: FitnessFunction::Evaluate is called on
// the sorted responses for the left range (0, index) and the right range
// (index, responses.n_elem) -- presumably half-open [begin, end) ranges;
// confirm against the fitness function.  The split's score is the size- (or
// weight-) weighted sum of the two gains.
//
// NOTE(review): this extracted listing is missing original lines 193 (the
// enable_if condition that selects this overload) and 195 (the qualified
// function name); restore them from the real source before editing.
//
// bestGain          Gain a split must beat; the acceptance threshold is
//                   min(bestGain + minimumGainSplit, 0.0).  If bestGain is
//                   already 0.0, no split is attempted.
// data              Values of a single dimension, one per point.
// responses         Regression targets, reordered together with `data`
//                   (presumably the same length as `data` -- confirm).
// weights           Per-point weights; only read when UseWeights is true.
// minimumLeafSize   Minimum number of points in each child (at least 1 is
//                   always enforced).
// minimumGainSplit  Extra gain a split must add on top of bestGain.
// splitInfo         On success, set to the midpoint of the two values that
//                   straddle the chosen threshold.
// Returns the best gain divided by the total weight (or number of points);
// the raw, unnormalized gain if a candidate reaches gain >= 0.0; or DBL_MAX if
// no split beats the threshold or splitting is impossible.
189 template<typename FitnessFunction>
190 template<bool UseWeights, typename VecType, typename ResponsesType,
191  typename WeightVecType>
192 typename std::enable_if<
194  double>::type
196  const double bestGain,
197  const VecType& data,
198  const ResponsesType& responses,
199  const WeightVecType& weights,
200  const size_t minimumLeafSize,
201  const double minimumGainSplit,
202  double& splitInfo,
203  AuxiliarySplitInfo& /* aux */)
204 {
205  typedef typename ResponsesType::elem_type RType;
206  typedef typename WeightVecType::elem_type WType;
207 
208  // First sanity check: if we don't have enough points, we can't split.
209  if (data.n_elem < (minimumLeafSize * 2))
210  return DBL_MAX;
211  if (bestGain == 0.0)
212  return DBL_MAX; // It can't be outperformed.
213 
214  // Next, sort the data.
215  arma::uvec sortedIndices = arma::sort_index(data);
216  arma::Row<RType> sortedResponses(responses.n_elem);
217  arma::Row<WType> sortedWeights;
218  for (size_t i = 0; i < sortedResponses.n_elem; ++i)
219  sortedResponses[i] = responses[sortedIndices[i]];
220 
// NOTE(review): if minimumLeafSize == 0 and data is empty, the size check above
// passes and sortedIndices[0] below is read from an empty vector.
221  // Sanity check: if the first element is the same as the last, we can't split
222  // in this dimension.
223  if (data[sortedIndices[0]] == data[sortedIndices[sortedIndices.n_elem - 1]])
224  return DBL_MAX;
225 
226  // Only initialize if we are using weights.
227  if (UseWeights)
228  {
229  sortedWeights.set_size(sortedResponses.n_elem);
230  // The weights must keep the same order as the responses.
231  for (size_t i = 0; i < sortedResponses.n_elem; ++i)
232  sortedWeights[i] = weights[sortedIndices[i]];
233  }
234 
// Gains are kept scaled by the total weight (or number of points) inside the
// loop; the result is normalized back before the final return.
235  double bestFoundGain = std::min(bestGain + minimumGainSplit, 0.0);
236  bool improved = false;
237  // Force a minimum leaf size of 1 (empty children don't make sense).
238  const size_t minimum = std::max(minimumLeafSize, (size_t) 1);
239 
240  WType totalWeight = 0.0;
241  WType totalLeftWeight = 0.0;
242  WType totalRightWeight = 0.0;
243 
244  if (UseWeights)
245  {
246  totalWeight = arma::accu(sortedWeights);
247  bestFoundGain *= totalWeight;
248 
// The first minimum - 1 points always belong to the left child; the rest start
// out in the right child and are moved over one at a time by the loop below.
249  for (size_t i = 0; i < minimum - 1; ++i)
250  totalLeftWeight += sortedWeights[i];
251 
252  for (size_t i = minimum - 1; i < data.n_elem; ++i)
253  totalRightWeight += sortedWeights[i];
254  }
255  else
256  {
257  bestFoundGain *= data.n_elem;
258  }
259 
// `index` is the number of points in the left child; it ranges over
// [minimum, n_elem - minimum] so that both children get at least `minimum`
// points.  Each iteration first moves point index - 1 into the left child.
260  // Loop through all possible split points, choosing the best one.
261  for (size_t index = minimum; index < data.n_elem - minimum + 1; ++index)
262  {
263  if (UseWeights)
264  {
265  totalLeftWeight += sortedWeights[index - 1];
266  totalRightWeight -= sortedWeights[index - 1];
267  }
268  // Make sure that the value has changed.
269  if (data[sortedIndices[index]] == data[sortedIndices[index - 1]])
270  continue;
271 
272  // Calculate the gain for the left and right child.
273  const double leftGain = FitnessFunction::template
274  Evaluate<UseWeights>(sortedResponses, sortedWeights, 0, index);
275  const double rightGain = FitnessFunction::template
276  Evaluate<UseWeights>(sortedResponses, sortedWeights, index,
277  responses.n_elem);
278 
279  double gain;
280  if (UseWeights)
281  {
282  gain = totalLeftWeight * leftGain + totalRightWeight * rightGain;
283  }
284  else
285  {
286  // Calculate the gain at this split point.
287  gain = double(index) * leftGain +
288  double(sortedResponses.n_elem - index) * rightGain;
289  }
290 
// Note: this early return does not divide `gain` by the total weight / point
// count, unlike the final return at the end of the function.
291  // Corner case: is this the best possible split?
292  if (gain >= 0.0)
293  {
294  // We can take a shortcut: no split will be better than this, so just
295  // take this one. The actual split value will be halfway between the
296  // value at index - 1 and index.
297  splitInfo = (data[sortedIndices[index - 1]] +
298  data[sortedIndices[index]]) / 2.0;
299 
300  return gain;
301  }
302  if (gain > bestFoundGain)
303  {
304  // We still have a better split.
305  bestFoundGain = gain;
306  splitInfo = (data[sortedIndices[index - 1]] +
307  data[sortedIndices[index]]) / 2.0;
308  improved = true;
309  }
310  }
311 
312  // No candidate beat the threshold: return DBL_MAX to signal that no
313  // acceptable split was found in this dimension.
314  if (!improved)
315  return DBL_MAX;
316 
// Undo the scaling applied to bestFoundGain before the loop.
317  if (UseWeights)
318  bestFoundGain /= totalWeight;
319  else
320  bestFoundGain /= data.n_elem;
321 
322  return bestFoundGain;
323 }
324 
325 // Optimized version for any fitness function that implements
326 // BinaryScanInitialize(), BinaryStep() and BinaryGains() functions.
//
// Tries to find a binary split of one numeric dimension that beats
// `bestGain`.  Instead of re-evaluating each child's full response range at
// every threshold, it creates one local FitnessFunction instance, calls
// BinaryScanInitialize() once, then calls BinaryStep() for every point moved
// into the left child and reads both child gains from BinaryGains().
// (Presumably the fitness function keeps running statistics so each step is
// cheap -- confirm against the FitnessFunction implementation.)
//
// NOTE(review): this extracted listing is missing original lines 331 (the
// enable_if condition that selects this overload) and 333 (the qualified
// function name); restore them from the real source before editing.
//
// bestGain          Gain a split must beat; the acceptance threshold is
//                   min(bestGain + minimumGainSplit, 0.0).  If bestGain is
//                   already 0.0, no split is attempted.
// data              Values of a single dimension, one per point.
// responses         Regression targets, reordered together with `data`
//                   (presumably the same length as `data` -- confirm).
// weights           Per-point weights; only read when UseWeights is true.
// minimumLeafSize   Minimum number of points in each child (at least 1 is
//                   always enforced).
// minimumGainSplit  Extra gain a split must add on top of bestGain.
// splitInfo         On success, set to the midpoint of the two values that
//                   straddle the chosen threshold.
// Returns the best gain divided by the total weight (or number of points);
// the raw, unnormalized gain if a candidate reaches gain >= 0.0; or DBL_MAX if
// no split beats the threshold or splitting is impossible.
327 template<typename FitnessFunction>
328 template<bool UseWeights, typename VecType, typename ResponsesType,
329  typename WeightVecType>
330 typename std::enable_if<
332  double>::type
334  const double bestGain,
335  const VecType& data,
336  const ResponsesType& responses,
337  const WeightVecType& weights,
338  const size_t minimumLeafSize,
339  const double minimumGainSplit,
340  double& splitInfo,
341  AuxiliarySplitInfo& /* aux */)
342 {
343  typedef typename ResponsesType::elem_type RType;
344  typedef typename WeightVecType::elem_type WType;
345 
// Stateful scorer: holds whatever BinaryScanInitialize()/BinaryStep() cache.
346  FitnessFunction fitnessFunction;
347 
348  // First sanity check: if we don't have enough points, we can't split.
349  if (data.n_elem < (minimumLeafSize * 2))
350  return DBL_MAX;
351  if (bestGain == 0.0)
352  return DBL_MAX; // It can't be outperformed.
353 
354  // Next, sort the data.
355  arma::uvec sortedIndices = arma::sort_index(data);
356  arma::Row<RType> sortedResponses(responses.n_elem);
357  arma::Row<WType> sortedWeights;
358  for (size_t i = 0; i < sortedResponses.n_elem; ++i)
359  sortedResponses[i] = responses[sortedIndices[i]];
360 
// NOTE(review): if minimumLeafSize == 0 and data is empty, the size check above
// passes and sortedIndices[0] below is read from an empty vector.
361  // Sanity check: if the first element is the same as the last, we can't split
362  // in this dimension.
363  if (data[sortedIndices[0]] == data[sortedIndices[sortedIndices.n_elem - 1]])
364  return DBL_MAX;
365 
366  // Only initialize if we are using weights.
367  if (UseWeights)
368  {
369  sortedWeights.set_size(sortedResponses.n_elem);
370  // The weights must keep the same order as the responses.
371  for (size_t i = 0; i < sortedResponses.n_elem; ++i)
372  sortedWeights[i] = weights[sortedIndices[i]];
373  }
374 
// Gains are kept scaled by the total weight (or number of points) inside the
// loop; the result is normalized back before the final return.
375  double bestFoundGain = std::min(bestGain + minimumGainSplit, 0.0);
376  bool improved = false;
377  // Force a minimum leaf size of 1 (empty children don't make sense).
378  const size_t minimum = std::max(minimumLeafSize, (size_t) 1);
379 
380  WType totalWeight = 0.0;
381  WType leftChildWeight = 0.0;
382  WType rightChildWeight = 0.0;
383 
384  if (UseWeights)
385  {
386  totalWeight = arma::accu(sortedWeights);
387  bestFoundGain *= totalWeight;
388 
// The first minimum - 1 points always belong to the left child; the rest start
// out in the right child and are moved over one at a time by the loop below.
389  for (size_t i = 0; i < minimum - 1; ++i)
390  leftChildWeight += sortedWeights[i];
391 
392  for (size_t i = minimum - 1; i < data.n_elem; ++i)
393  rightChildWeight += sortedWeights[i];
394  }
395  else
396  {
397  bestFoundGain *= data.n_elem;
398  }
399 
400  // Initialize and precompute various statistics to efficiently compute gain
401  // values for all possible splits.
402  fitnessFunction.template BinaryScanInitialize<UseWeights>(sortedResponses,
403  sortedWeights, minimum);
404 
// `index` is the number of points in the left child; it ranges over
// [minimum, n_elem - minimum] so that both children get at least `minimum`
// points.  Each iteration first moves point index - 1 into the left child.
405  // Loop through all possible split points, choosing the best one.
406  for (size_t index = minimum; index < data.n_elem - minimum + 1; ++index)
407  {
408  if (UseWeights)
409  {
410  leftChildWeight += sortedWeights[index - 1];
411  rightChildWeight -= sortedWeights[index - 1];
412  }
413 
// BinaryStep() runs before the equal-value check below so that the fitness
// function's cached state stays in sync even for skipped candidates.
414  // Steps through the current index and updates the cached data.
415  fitnessFunction.template BinaryStep<UseWeights>(sortedResponses,
416  sortedWeights, index - 1);
417 
418  // Make sure that the value has changed.
419  if (data[sortedIndices[index]] == data[sortedIndices[index - 1]])
420  continue;
421 
422  // Calculate the gain for the left and right child.
423  std::tuple<double, double> binaryGains = fitnessFunction.BinaryGains();
424  const double leftGain = std::get<0>(binaryGains);
425  const double rightGain = std::get<1>(binaryGains);
426 
427  double gain;
428  if (UseWeights)
429  {
430  gain = leftChildWeight * leftGain + rightChildWeight * rightGain;
431  }
432  else
433  {
434  // Calculate the gain at this split point.
435  gain = double(index) * leftGain +
436  double(sortedResponses.n_elem - index) * rightGain;
437  }
438 
// Note: this early return does not divide `gain` by the total weight / point
// count, unlike the final return at the end of the function.
439  // Corner case: is this the best possible split?
440  if (gain >= 0.0)
441  {
442  // We can take a shortcut: no split will be better than this, so just
443  // take this one. The actual split value will be halfway between the
444  // value at index - 1 and index.
445  splitInfo = (data[sortedIndices[index - 1]] +
446  data[sortedIndices[index]]) / 2.0;
447 
448  return gain;
449  }
450  if (gain > bestFoundGain)
451  {
452  // We still have a better split.
453  bestFoundGain = gain;
454  splitInfo = (data[sortedIndices[index - 1]] +
455  data[sortedIndices[index]]) / 2.0;
456  improved = true;
457  }
458  }
459  // No candidate beat the threshold: return DBL_MAX to signal that no
460  // acceptable split was found in this dimension.
461  if (!improved)
462  return DBL_MAX;
463 
// Undo the scaling applied to bestFoundGain before the loop.
464  if (UseWeights)
465  bestFoundGain /= totalWeight;
466  else
467  bestFoundGain /= data.n_elem;
468 
469  return bestFoundGain;
470 }
471 
472 template<typename FitnessFunction>
473 template<typename ElemType>
475  const ElemType& point,
476  const double& splitInfo,
477  const AuxiliarySplitInfo& /* aux */)
478 {
479  if (point <= splitInfo)
480  return 0; // Go left.
481  else
482  return 1; // Go right.
483 }
484 
485 } // namespace tree
486 } // namespace mlpack
487 
488 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static double SplitIfBetter(const double bestGain, const VecType &data, const arma::Row< size_t > &labels, const size_t numClasses, const WeightVecType &weights, const size_t minimumLeafSize, const double minimumGainSplit, arma::vec &splitInfo, AuxiliarySplitInfo &aux)
Check if we can split a node.
Definition: best_binary_numeric_split_impl.hpp:21
static size_t CalculateDirection(const ElemType &point, const double &splitInfo, const AuxiliarySplitInfo &)
Given a point, calculate which child it should go to (left or right).
Definition: best_binary_numeric_split_impl.hpp:474
Definition: best_binary_numeric_split.hpp:36
Definition: best_binary_numeric_split.hpp:53