mlpack
random_binary_numeric_split_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_DECISION_TREE_RANDOM_BINARY_NUMERIC_SPLIT_IMPL_HPP
13 #define MLPACK_METHODS_DECISION_TREE_RANDOM_BINARY_NUMERIC_SPLIT_IMPL_HPP
14 
16 
17 namespace mlpack {
18 namespace tree {
19 
20 // Overload used for classification.
21 template<typename FitnessFunction>
22 template<bool UseWeights, typename VecType, typename WeightVecType>
24  const double bestGain,
25  const VecType& data,
26  const arma::Row<size_t>& labels,
27  const size_t numClasses,
28  const WeightVecType& weights,
29  const size_t minimumLeafSize,
30  const double minimumGainSplit,
31  arma::vec& splitInfo,
32  AuxiliarySplitInfo& /* aux */,
33  const bool splitIfBetterGain)
34 {
35  double bestFoundGain = std::min(bestGain + minimumGainSplit, 0.0);
36  // Forcing a minimum leaf size of 1 (empty children don't make sense).
37  const size_t minimum = std::max(minimumLeafSize, (size_t) 1);
38 
39  // First sanity check: if we don't have enough points, we can't split.
40  if (data.n_elem < (minimum * 2))
41  return DBL_MAX;
42  if (bestGain == 0.0)
43  return DBL_MAX; // It can't be outperformed.
44 
45  typename VecType::elem_type maxValue = arma::max(data);
46  typename VecType::elem_type minValue = arma::min(data);
47 
48  // Sanity check: if the maximum element is the same as the minimum, we
49  // can't split in this dimension.
50  if (maxValue == minValue)
51  return DBL_MAX;
52 
53  // Picking a random pivot to split the dimension.
54  double randomPivot = math::Random(minValue, maxValue);
55 
56  // We need to count the number of points for each class.
57  arma::Mat<size_t> classCounts;
58  arma::mat classWeightSums;
59  double totalWeight = 0.0;
60  double totalLeftWeight = 0.0;
61  double totalRightWeight = 0.0;
62  size_t leftLeafSize = 0;
63  size_t rightLeafSize = 0;
64  if (UseWeights)
65  {
66  classWeightSums.zeros(numClasses, 2);
67  totalWeight = arma::accu(weights);
68  bestFoundGain *= totalWeight;
69 
70  for (size_t i = 0; i < data.n_elem; ++i)
71  {
72  if (data(i) < randomPivot)
73  {
74  ++leftLeafSize;
75  classWeightSums(labels(i), 0) += weights(i);
76  totalLeftWeight += weights(i);
77  }
78  else
79  {
80  ++rightLeafSize;
81  classWeightSums(labels(i), 1) += weights(i);
82  totalRightWeight += weights(i);
83  }
84  }
85  }
86  else
87  {
88  classCounts.zeros(numClasses, 2);
89  bestFoundGain *= data.n_elem;
90 
91  for (size_t i = 0; i < data.n_elem; ++i)
92  {
93  if (data(i) < randomPivot)
94  {
95  ++leftLeafSize;
96  ++classCounts(labels(i), 0);
97  }
98  else
99  {
100  ++rightLeafSize;
101  ++classCounts(labels(i), 1);
102  }
103  }
104  }
105 
106  // Calculate the gain for the left and right child. Only use weights if
107  // needed.
108  const double leftGain = UseWeights ?
109  FitnessFunction::template EvaluatePtr<true>(classWeightSums.colptr(0),
110  numClasses, totalLeftWeight) :
111  FitnessFunction::template EvaluatePtr<false>(classCounts.colptr(0),
112  numClasses, leftLeafSize);
113  const double rightGain = UseWeights ?
114  FitnessFunction::template EvaluatePtr<true>(classWeightSums.colptr(1),
115  numClasses, totalRightWeight) :
116  FitnessFunction::template EvaluatePtr<false>(classCounts.colptr(1),
117  numClasses, rightLeafSize);
118 
119  double gain;
120  if (UseWeights)
121  gain = totalLeftWeight * leftGain + totalRightWeight * rightGain;
122  else
123  // Calculate the gain at this split point.
124  gain = double(leftLeafSize) * leftGain + double(rightLeafSize) * rightGain;
125 
126  if (gain < bestFoundGain && splitIfBetterGain)
127  return DBL_MAX;
128 
129  splitInfo.set_size(1);
130  splitInfo[0] = randomPivot;
131 
132  if (UseWeights)
133  gain /= totalWeight;
134  else
135  gain /= labels.n_elem;
136 
137  return gain;
138 }
139 
140 // Overload used for regression.
141 template<typename FitnessFunction>
142 template<bool UseWeights, typename VecType, typename WeightVecType>
144  const double bestGain,
145  const VecType& data,
146  const arma::rowvec& responses,
147  const WeightVecType& weights,
148  const size_t minimumLeafSize,
149  const double minimumGainSplit,
150  double& splitInfo,
151  AuxiliarySplitInfo& /* aux */,
152  const bool splitIfBetterGain)
153 {
154  double bestFoundGain = std::min(bestGain + minimumGainSplit, 0.0);
155  // Forcing a minimum leaf size of 1 (empty children don't make sense).
156  const size_t minimum = std::max(minimumLeafSize, (size_t) 1);
157 
158  // First sanity check: if we don't have enough points, we can't split.
159  if (data.n_elem < (minimum * 2))
160  return DBL_MAX;
161  if (bestGain == 0.0)
162  return DBL_MAX; // It can't be outperformed.
163 
164  typename VecType::elem_type maxValue = arma::max(data);
165  typename VecType::elem_type minValue = arma::min(data);
166 
167  // Sanity check: if the maximum element is the same as the minimum, we
168  // can't split in this dimension.
169  if (maxValue == minValue)
170  return DBL_MAX;
171 
172  double totalWeight = 0.0;
173  double totalLeftWeight = 0.0;
174  double totalRightWeight = 0.0;
175  if (UseWeights)
176  {
177  totalWeight = arma::accu(weights);
178  bestFoundGain *= totalWeight;
179  }
180  else
181  {
182  bestFoundGain *= data.n_elem;
183  }
184 
185  // Picking a random pivot to split the dimension.
186  double randomPivot = math::Random(minValue, maxValue);
187 
188  // We need to count the number of points for each leaf.
189  size_t leftLeafSize = 0;
190  size_t rightLeafSize = 0;
191  for (size_t i = 0; i < data.n_elem; ++i)
192  {
193  if (UseWeights)
194  {
195  if (data[i] < randomPivot)
196  totalLeftWeight += weights[i];
197  else
198  totalRightWeight += weights[i];
199  }
200 
201  if (data[i] < randomPivot)
202  ++leftLeafSize;
203  else
204  ++rightLeafSize;
205  }
206 
207  // Splitting data to compute gain.
208  arma::rowvec leftResponses(leftLeafSize), rightResponses(rightLeafSize);
209  arma::rowvec leftWeights, rightWeights;
210  if (UseWeights)
211  {
212  leftWeights.set_size(leftLeafSize);
213  rightWeights.set_size(rightLeafSize);
214  }
215 
216  size_t l = 0, r = 0;
217  for (size_t i = 0; i < data.n_elem; ++i)
218  {
219  if (UseWeights)
220  {
221  if (data[i] < randomPivot)
222  leftWeights[l] = weights[i];
223  else
224  rightWeights[r] = weights[i];
225  }
226  if (data[i] < randomPivot)
227  leftResponses[l++] = responses[i];
228  else
229  rightResponses[r++] = responses[i];
230  }
231 
232  // Calculate the gain for the left and right child.
233  const double leftGain = FitnessFunction::template
234  Evaluate<UseWeights>(leftResponses, leftWeights, 0, leftLeafSize);
235  const double rightGain = FitnessFunction::template
236  Evaluate<UseWeights>(rightResponses, rightWeights, 0, rightLeafSize);
237 
238  // Calculate the gain at this split point.
239  double gain;
240  if (UseWeights)
241  gain = totalLeftWeight * leftGain + totalRightWeight * rightGain;
242  else
243  gain = double(leftLeafSize) * leftGain + double(rightLeafSize) * rightGain;
244 
245  if (gain < bestFoundGain && splitIfBetterGain)
246  return DBL_MAX;
247 
248  splitInfo = randomPivot;
249 
250  if (UseWeights)
251  gain /= totalWeight;
252  else
253  gain /= responses.n_elem;
254 
255  return gain;
256 }
257 
258 template<typename FitnessFunction>
259 template<typename ElemType>
261  const ElemType& point,
262  const double& splitInfo,
263  const AuxiliarySplitInfo& /* aux */)
264 {
265  if (point <= splitInfo)
266  return 0; // Go left.
267  else
268  return 1; // Go right.
269 }
270 
271 } // namespace tree
272 } // namespace mlpack
273 
274 #endif
static double SplitIfBetter(const double bestGain, const VecType &data, const arma::Row< size_t > &labels, const size_t numClasses, const WeightVecType &weights, const size_t minimumLeafSize, const double minimumGainSplit, arma::vec &splitInfo, AuxiliarySplitInfo &aux, const bool splitIfBetterGain=false)
Check if we can split a node.
Definition: random_binary_numeric_split_impl.hpp:23
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static size_t CalculateDirection(const ElemType &point, const double &splitInfo, const AuxiliarySplitInfo &)
Given a point, calculate which child it should go to (left or right).
Definition: random_binary_numeric_split_impl.hpp:260
Definition: random_binary_numeric_split.hpp:32
double Random()
Generates a uniform random number between 0 and 1.
Definition: random.hpp:83
Miscellaneous math random-related routines.