mlpack
mse_gain.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_MSE_GAIN_HPP
14 #define MLPACK_METHODS_DECISION_TREE_MSE_GAIN_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include "utils.hpp"
18 
namespace mlpack {
namespace tree {

/**
 * The MSE (mean squared error) gain is a measure of set purity based on the
 * variance of the response values in a set.  Gains are returned negated, so a
 * purer set (smaller variance) has a higher gain; the maximum gain is zero.
 *
 * Besides the stateless Evaluate() overloads, the class caches running
 * statistics for a binary split scan: call BinaryScanInitialize() once, then
 * BinaryStep() for each point that moves from the right child into the left
 * child, and BinaryGains() to read the gains of both children at the current
 * position.
 */
class MSEGain
{
 public:
  /**
   * Evaluate the mean squared error gain of values[begin, end).  An empty
   * range, or a range whose weights sum to zero, has zero impurity.
   *
   * @tparam UseWeights Whether each value is weighted by weights[i].
   * @param values Response values.
   * @param weights Per-value weights (only read when UseWeights is true).
   * @param begin Index of the first value in the range.
   * @param end One past the index of the last value in the range.
   * @return The negated (weighted) variance of the values in the range.
   */
  template<bool UseWeights, typename VecType, typename WeightVecType>
  static double Evaluate(const VecType& values,
                         const WeightVecType& weights,
                         const size_t begin,
                         const size_t end)
  {
    // Corner case: an empty range has zero impurity.  Without this guard the
    // unweighted branch would divide by zero.
    if (begin >= end)
      return 0.0;

    double mse = 0.0;

    if (UseWeights)
    {
      double accWeights = 0.0;
      double weightedMean = 0.0;
      // Fills accWeights with the total weight and weightedMean with the
      // weighted sum of the values in the range.
      WeightedSum(values, weights, begin, end, accWeights, weightedMean);

      // Catch edge case: if there are no weights, the impurity is zero.
      if (accWeights == 0.0)
        return 0.0;

      weightedMean /= accWeights;

      for (size_t i = begin; i < end; ++i)
      {
        const double diff = values[i] - weightedMean;
        mse += weights[i] * diff * diff;
      }

      mse /= accWeights;
    }
    else
    {
      const double count = static_cast<double>(end - begin);

      double mean = 0.0;
      // Fills mean with the (unnormalized) sum of the values in the range.
      Sum(values, begin, end, mean);
      mean /= count;

      for (size_t i = begin; i < end; ++i)
      {
        const double diff = values[i] - mean;
        mse += diff * diff;
      }

      mse /= count;
    }

    return -mse;
  }

  /**
   * Evaluate the MSE gain on the complete vector.
   *
   * @tparam UseWeights Whether each value is weighted by weights[i].
   * @param values Response values.
   * @param weights Per-value weights (only read when UseWeights is true).
   * @return The negated (weighted) variance of all values; 0 if empty.
   */
  template<bool UseWeights, typename VecType, typename WeightVecType>
  static double Evaluate(const VecType& values,
                         const WeightVecType& weights)
  {
    // Corner case: if there are no elements, the impurity is zero.
    if (values.n_elem == 0)
      return 0.0;

    return Evaluate<UseWeights>(values, weights, 0, values.n_elem);
  }

  /**
   * Calculates the mean squared error gain for the left and right children at
   * the current split position.  BinaryScanInitialize() must have been called
   * first.  A child that is empty, or whose total weight is effectively zero,
   * has zero impurity (the same convention as Evaluate()).
   *
   * @return A tuple (left gain, right gain), each the negated (weighted)
   *     variance of the corresponding child.
   */
  std::tuple<double, double> BinaryGains()
  {
    const double rightSumSquares = totalSumSquares - leftSumSquares;

    return std::make_tuple(-ChildMSE(leftSum, leftSumSquares, leftSize),
                           -ChildMSE(rightSum, rightSumSquares, rightSize));
  }

  /**
   * Resets the cached statistics for a binary split scan over responses.  The
   * left child starts with responses[0, minimum - 1) and the right child with
   * the remaining responses; each BinaryStep() then moves one response from
   * the right child into the left child.
   *
   * @tparam UseWeights Whether each response is weighted by weights[i].
   * @param responses Responses, in the order the split scan visits them.
   * @param weights Per-response weights (only read when UseWeights is true).
   * @param minimum One more than the number of responses initially placed in
   *     the left child.  Values of 0 or beyond the data size are clamped.
   */
  template<bool UseWeights, typename ResponsesType, typename WeightVecType>
  void BinaryScanInitialize(const ResponsesType& responses,
                            const WeightVecType& weights,
                            const size_t minimum)
  {
    leftSize = 0.0;
    rightSize = 0.0;
    leftSum = 0.0;
    rightSum = 0.0;
    leftSumSquares = 0.0;
    totalSumSquares = 0.0;

    const size_t n = responses.n_elem;

    // Avoid size_t underflow when minimum == 0, and never read past the end
    // of responses when minimum - 1 exceeds its length.
    size_t leftCount = (minimum > 0) ? minimum - 1 : 0;
    if (leftCount > n)
      leftCount = n;

    // One pass gathers both children's statistics and the total sum of
    // squares.
    for (size_t i = 0; i < n; ++i)
    {
      const double x = responses[i];
      const double w = UseWeights ? static_cast<double>(weights[i]) : 1.0;
      const double wx = w * x;

      totalSumSquares += wx * x;

      if (i < leftCount)
      {
        leftSize += w;
        leftSum += wx;
        leftSumSquares += wx * x;
      }
      else
      {
        rightSize += w;
        rightSum += wx;
      }
    }
  }

  /**
   * Updates the cached statistics by moving responses[index] from the right
   * child into the left child.
   *
   * @tparam UseWeights Whether the response is weighted by weights[index].
   * @param responses Responses, in the order the split scan visits them.
   * @param weights Per-response weights (only read when UseWeights is true).
   * @param index Index of the response that changes sides.
   */
  template<bool UseWeights, typename ResponsesType, typename WeightVecType>
  void BinaryStep(const ResponsesType& responses,
                  const WeightVecType& weights,
                  const size_t index)
  {
    const double x = responses[index];
    const double w = UseWeights ? static_cast<double>(weights[index]) : 1.0;
    const double wx = w * x;

    // Only sums are maintained; means are derived (with an empty-child guard)
    // in BinaryGains(), so a zero-weight point or an emptied child can no
    // longer cause a 0 / 0 division here.
    leftSize += w;
    leftSum += wx;
    leftSumSquares += wx * x;

    rightSize -= w;
    rightSum -= wx;
  }

 private:
  /**
   * Returns the (weighted) variance of a child from its (weighted) sum, sum
   * of squares and size, or 0 when the child is effectively empty.
   */
  static double ChildMSE(const double sum,
                         const double sumSquares,
                         const double size)
  {
    // Children whose total weight is at most this are treated as empty; the
    // running size can be left with tiny round-off residue after repeated
    // subtraction.
    if (size <= 1e-9)
      return 0.0;

    const double mean = sum / size;
    double mse = sumSquares / size - mean * mean;

    // E[x^2] - E[x]^2 can dip slightly below zero through floating-point
    // cancellation, but a variance is never negative.  (NaN is left as is.)
    if (mse < 0.0)
      mse = 0.0;

    return mse;
  }

  // Sum of squares / weighted sum of squares of the left child.
  double leftSumSquares = 0.0;
  // Sum / weighted sum of the responses in each child.
  double leftSum = 0.0;
  double rightSum = 0.0;
  // For unweighted data, the number of elements in each child; for weighted
  // data, the sum of the weights of the elements in each child.
  double leftSize = 0.0;
  double rightSize = 0.0;
  // Total sum of squares / total weighted sum of squares of all responses.
  double totalSumSquares = 0.0;
};

} // namespace tree
} // namespace mlpack
265 
266 #endif
void BinaryStep(const ResponsesType &responses, const WeightVecType &weights, const size_t index)
Updates the statistics for the given index.
Definition: mse_gain.hpp:206
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
static double Evaluate(const VecType &values, const WeightVecType &weights, const size_t begin, const size_t end)
Evaluate the mean squared error gain of values from begin to end index.
Definition: mse_gain.hpp:44
The core includes that mlpack expects; standard C++ includes and Armadillo.
void WeightedSum(const VecType &values, const WeightVecType &weights, const size_t begin, const size_t end, double &accWeights, double &weightedMean)
Calculates the weighted sum and total weight of labels.
Definition: utils.hpp:19
The MSE (mean squared error) gain is a measure of set purity based on the variance of response values...
Definition: mse_gain.hpp:28
void BinaryScanInitialize(const ResponsesType &responses, const WeightVecType &weights, const size_t minimum)
Caches the prefix sum of squares to efficiently compute the gain value for each split.
Definition: mse_gain.hpp:127
static double Evaluate(const VecType &values, const WeightVecType &weights)
Evaluate the MSE gain on the complete vector.
Definition: mse_gain.hpp:88
std::tuple< double, double > BinaryGains()
Calculates the mean squared error gain for the left and right children for the current index...
Definition: mse_gain.hpp:109
void Sum(const VecType &values, const size_t begin, const size_t end, double &mean)
Sums up the labels vector.
Definition: utils.hpp:96