mlpack
all_categorical_split_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_DECISION_TREE_ALL_CATEGORICAL_SPLIT_IMPL_HPP
13 #define MLPACK_METHODS_DECISION_TREE_ALL_CATEGORICAL_SPLIT_IMPL_HPP
14 
15 namespace mlpack {
16 namespace tree {
17 
18 // Overload used in classification.
19 template<typename FitnessFunction>
20 template<bool UseWeights, typename VecType, typename LabelsType,
21  typename WeightVecType>
23  const double bestGain,
24  const VecType& data,
25  const size_t numCategories,
26  const LabelsType& labels,
27  const size_t numClasses,
28  const WeightVecType& weights,
29  const size_t minimumLeafSize,
30  const double minimumGainSplit,
31  arma::vec& splitInfo,
32  AuxiliarySplitInfo& /* aux */)
33 {
34  // Count the number of elements in each potential child.
35  const double epsilon = 1e-7; // Tolerance for floating-point errors.
36  arma::Col<size_t> counts(numCategories, arma::fill::zeros);
37 
38  // If we are using weighted training, split the weights for each child too.
39  arma::vec childWeightSums;
40  double sumWeight = 0.0;
41  if (UseWeights)
42  childWeightSums.zeros(numCategories);
43 
44  for (size_t i = 0; i < data.n_elem; ++i)
45  {
46  counts[(size_t) data[i]]++;
47 
48  if (UseWeights)
49  {
50  childWeightSums[(size_t) data[i]] += weights[i];
51  sumWeight += weights[i];
52  }
53  }
54 
55  // If each child will have the minimum number of points in it, we can split.
56  // Otherwise we can't.
57  if (arma::min(counts) < minimumLeafSize)
58  return DBL_MAX;
59 
60  // Calculate the gain of the split. First we have to calculate the labels
61  // that would be assigned to each child.
62  arma::uvec childPositions(numCategories, arma::fill::zeros);
63  std::vector<arma::Row<size_t>> childLabels(numCategories);
64  std::vector<arma::Row<double>> childWeights(numCategories);
65 
66  for (size_t i = 0; i < numCategories; ++i)
67  {
68  // Labels and weights should have same length.
69  childLabels[i].zeros(counts[i]);
70  if (UseWeights)
71  childWeights[i].zeros(counts[i]);
72  }
73 
74  // Extract labels for each child.
75  for (size_t i = 0; i < data.n_elem; ++i)
76  {
77  const size_t category = (size_t) data[i];
78 
79  if (UseWeights)
80  {
81  childLabels[category][childPositions[category]] = labels[i];
82  childWeights[category][childPositions[category]++] = weights[i];
83  }
84  else
85  {
86  childLabels[category][childPositions[category]++] = labels[i];
87  }
88  }
89 
90  double overallGain = 0.0;
91  for (size_t i = 0; i < counts.n_elem; ++i)
92  {
93  // Calculate the gain of this child.
94  const double childPct = UseWeights ?
95  double(childWeightSums[i]) / sumWeight :
96  double(counts[i]) / double(data.n_elem);
97  const double childGain = FitnessFunction::template Evaluate<UseWeights>(
98  childLabels[i], numClasses, childWeights[i]);
99 
100  overallGain += childPct * childGain;
101  }
102 
103  if (overallGain > bestGain + minimumGainSplit + epsilon)
104  {
105  // This is better, so store it in splitInfo and return.
106  splitInfo.set_size(1);
107  splitInfo[0] = numCategories;
108  return overallGain;
109  }
110 
111  // Otherwise there was no improvement.
112  return DBL_MAX;
113 }
114 
115 // Overload used in regression.
116 template<typename FitnessFunction>
117 template<bool UseWeights, typename VecType, typename ResponsesType,
118  typename WeightVecType>
120  const double bestGain,
121  const VecType& data,
122  const size_t numCategories,
123  const ResponsesType& responses,
124  const WeightVecType& weights,
125  const size_t minimumLeafSize,
126  const double minimumGainSplit,
127  double& splitInfo,
128  AuxiliarySplitInfo& /* aux */)
129 {
130  // Count the number of elements in each potential child.
131  const double epsilon = 1e-7; // Tolerance for floating-point errors.
132  arma::Col<size_t> counts(numCategories, arma::fill::zeros);
133 
134  // If we are using weighted training, split the weights for each child too.
135  arma::vec childWeightSums;
136  double sumWeight = 0.0;
137  if (UseWeights)
138  childWeightSums.zeros(numCategories);
139 
140  for (size_t i = 0; i < data.n_elem; ++i)
141  {
142  counts[(size_t) data[i]]++;
143 
144  if (UseWeights)
145  {
146  childWeightSums[(size_t) data[i]] += weights[i];
147  sumWeight += weights[i];
148  }
149  }
150 
151  // If each child will have the minimum number of points in it, we can split.
152  // Otherwise we can't.
153  if (arma::min(counts) < minimumLeafSize)
154  return DBL_MAX;
155 
156  // Calculate the gain of the split. First we have to calculate the labels
157  // that would be assigned to each child.
158  arma::uvec childPositions(numCategories, arma::fill::zeros);
159  std::vector<arma::rowvec> childResponses(numCategories);
160  std::vector<arma::rowvec> childWeights(numCategories);
161 
162  for (size_t i = 0; i < numCategories; ++i)
163  {
164  // Responses and weights should have same length.
165  childResponses[i].zeros(counts[i]);
166  if (UseWeights)
167  childWeights[i].zeros(counts[i]);
168  }
169 
170  // Extract labels for each child.
171  for (size_t i = 0; i < data.n_elem; ++i)
172  {
173  const size_t category = (size_t) data[i];
174 
175  if (UseWeights)
176  {
177  childResponses[category][childPositions[category]] = responses[i];
178  childWeights[category][childPositions[category]++] = weights[i];
179  }
180  else
181  {
182  childResponses[category][childPositions[category]++] = responses[i];
183  }
184  }
185 
186  double overallGain = 0.0;
187  for (size_t i = 0; i < counts.n_elem; ++i)
188  {
189  // Calculate the gain of this child.
190  const double childPct = UseWeights ?
191  double(childWeightSums[i]) / sumWeight :
192  double(counts[i]) / double(data.n_elem);
193  const double childGain = FitnessFunction::template Evaluate<UseWeights>(
194  childResponses[i], childWeights[i]);
195 
196  overallGain += childPct * childGain;
197  }
198 
199  if (overallGain > bestGain + minimumGainSplit + epsilon)
200  {
201  // This is better, so store it in splitInfo and return.
202  splitInfo = numCategories;
203  return overallGain;
204  }
205 
206  // Otherwise there was no improvement.
207  return DBL_MAX;
208 }
209 
210 template<typename FitnessFunction>
212  const double& splitInfo,
213  const AuxiliarySplitInfo& /* aux */)
214 {
215  return (size_t) splitInfo;
216 }
217 
218 template<typename FitnessFunction>
219 template<typename ElemType>
221  const ElemType& point,
222  const double& /* splitInfo */,
223  const AuxiliarySplitInfo& /* aux */)
224 {
225  return (size_t) point;
226 }
227 
228 } // namespace tree
229 } // namespace mlpack
230 
231 #endif
static size_t CalculateDirection(const ElemType &point, const double &splitInfo, const AuxiliarySplitInfo &)
Calculate the direction a point should percolate to.
Definition: all_categorical_split_impl.hpp:220
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static size_t NumChildren(const double &splitInfo, const AuxiliarySplitInfo &)
Return the number of children in the split.
Definition: all_categorical_split_impl.hpp:211
Definition: all_categorical_split.hpp:34
static double SplitIfBetter(const double bestGain, const VecType &data, const size_t numCategories, const LabelsType &labels, const size_t numClasses, const WeightVecType &weights, const size_t minimumLeafSize, const double minimumGainSplit, arma::vec &splitInfo, AuxiliarySplitInfo &aux)
Check if we can split a node.
Definition: all_categorical_split_impl.hpp:22