mlpack
hoeffding_numeric_split_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_HOEFFDING_TREES_HOEFFDING_NUMERIC_SPLIT_IMPL_HPP
13 #define MLPACK_METHODS_HOEFFDING_TREES_HOEFFDING_NUMERIC_SPLIT_IMPL_HPP
14 
16 
17 namespace mlpack {
18 namespace tree {
19 
20 template<typename FitnessFunction, typename ObservationType>
22  const size_t numClasses,
23  const size_t bins,
24  const size_t observationsBeforeBinning) :
25  observations(observationsBeforeBinning - 1),
26  labels(observationsBeforeBinning - 1),
27  bins(bins),
28  observationsBeforeBinning(observationsBeforeBinning),
29  samplesSeen(0),
30  sufficientStatistics(arma::zeros<arma::Mat<size_t>>(numClasses, bins))
31 {
32  observations.zeros();
33  labels.zeros();
34 }
35 
36 template<typename FitnessFunction, typename ObservationType>
38  const size_t numClasses,
39  const HoeffdingNumericSplit& other) :
40  observations(other.observationsBeforeBinning - 1),
41  labels(other.observationsBeforeBinning - 1),
42  bins(other.bins),
43  observationsBeforeBinning(other.observationsBeforeBinning),
44  samplesSeen(0),
45  sufficientStatistics(arma::zeros<arma::Mat<size_t>>(numClasses, bins))
46 {
47  observations.zeros();
48  labels.zeros();
49 }
50 
51 template<typename FitnessFunction, typename ObservationType>
53  ObservationType value,
54  const size_t label)
55 {
56  if (samplesSeen < observationsBeforeBinning - 1)
57  {
58  // Add this to the samples we have seen.
59  observations[samplesSeen] = value;
60  labels[samplesSeen] = label;
61  ++samplesSeen;
62  return;
63  }
64  else if (samplesSeen == observationsBeforeBinning - 1)
65  {
66  // Now we need to make the bins.
67  ObservationType min = value;
68  ObservationType max = value;
69  for (size_t i = 0; i < observationsBeforeBinning - 1; ++i)
70  {
71  if (observations[i] < min)
72  min = observations[i];
73  else if (observations[i] > max)
74  max = observations[i];
75  }
76 
77  // Now split these. We can't use linspace, because we don't want to include
78  // the endpoints.
79  splitPoints.resize(bins - 1);
80  const ObservationType binWidth = (max - min) / bins;
81  for (size_t i = 0; i < bins - 1; ++i)
82  splitPoints[i] = min + (i + 1) * binWidth;
83  ++samplesSeen;
84 
85  // Now, add all of the points we've seen to the sufficient statistics.
86  for (size_t i = 0; i < observationsBeforeBinning - 1; ++i)
87  {
88  // What bin does the point fall into?
89  size_t bin = 0;
90  while (bin < bins - 1 && observations[i] > splitPoints[bin])
91  ++bin;
92 
93  sufficientStatistics(labels[i], bin)++;
94  }
95  }
96 
97  // If we've gotten to here, then we need to add the point to the sufficient
98  // statistics. What bin does the point fall into?
99  size_t bin = 0;
100  while (bin < bins - 1 && value > splitPoints[bin])
101  ++bin;
102 
103  sufficientStatistics(label, bin)++;
104 }
105 
106 template<typename FitnessFunction, typename ObservationType>
108  EvaluateFitnessFunction(double& bestFitness,
109  double& secondBestFitness) const
110 {
111  secondBestFitness = 0.0; // We can only split one way.
112  if (samplesSeen < observationsBeforeBinning)
113  bestFitness = 0.0;
114  else
115  bestFitness = FitnessFunction::Evaluate(sufficientStatistics);
116 }
117 
118 template<typename FitnessFunction, typename ObservationType>
120  arma::Col<size_t>& childMajorities,
121  SplitInfo& splitInfo) const
122 {
123  childMajorities.set_size(sufficientStatistics.n_cols);
124  for (size_t i = 0; i < sufficientStatistics.n_cols; ++i)
125  {
126  arma::uword maxIndex = 0;
127  sufficientStatistics.unsafe_col(i).max(maxIndex);
128  childMajorities[i] = size_t(maxIndex);
129  }
130 
131  // Create the SplitInfo object.
132  splitInfo = SplitInfo(splitPoints);
133 }
134 
135 template<typename FitnessFunction, typename ObservationType>
138 {
139  // If we haven't yet determined the bins, we must calculate this by hand.
140  if (samplesSeen < observationsBeforeBinning)
141  {
142  arma::Col<size_t> classes(sufficientStatistics.n_rows);
143  classes.zeros();
144 
145  for (size_t i = 0; i < samplesSeen; ++i)
146  classes[labels[i]]++;
147 
148  arma::uword majorityClass;
149  classes.max(majorityClass);
150  return size_t(majorityClass);
151  }
152  else
153  {
154  // We've calculated the bins, so we can just sum over the sufficient
155  // statistics.
156  arma::Col<size_t> classCounts = arma::sum(sufficientStatistics, 1);
157 
158  arma::uword maxIndex = 0;
159  classCounts.max(maxIndex);
160  return size_t(maxIndex);
161  }
162 }
163 
164 template<typename FitnessFunction, typename ObservationType>
167 {
168  // If we haven't yet determined the bins, we must calculate this by hand.
169  if (samplesSeen < observationsBeforeBinning)
170  {
171  arma::Col<size_t> classes(sufficientStatistics.n_rows);
172  classes.zeros();
173 
174  for (size_t i = 0; i < samplesSeen; ++i)
175  classes[labels[i]]++;
176 
177  return double(classes.max()) / double(arma::accu(classes));
178  }
179  else
180  {
181  // We've calculated the bins, so we can just sum over the sufficient
182  // statistics.
183  arma::Col<size_t> classCounts = arma::sum(sufficientStatistics, 1);
184 
185  return double(classCounts.max()) / double(arma::sum(classCounts));
186  }
187 }
188 
189 template<typename FitnessFunction, typename ObservationType>
190 template<typename Archive>
192  Archive& ar,
193  const uint32_t /* version */)
194 {
195  ar(CEREAL_NVP(samplesSeen));
196  ar(CEREAL_NVP(observationsBeforeBinning));
197  ar(CEREAL_NVP(bins));
198 
199  if (samplesSeen >= observationsBeforeBinning)
200  {
201  // The binning has happened, so we only need to save the resulting bins.
202  ar(CEREAL_NVP(splitPoints));
203  ar(CEREAL_NVP(sufficientStatistics));
204 
205  if (cereal::is_loading<Archive>())
206  {
207  // Clean other objects.
208  observations.clear();
209  labels.clear();
210  }
211  }
212  else
213  {
214  // The binning has not happened yet, so we only need to save the information
215  // required before binning.
216  if (cereal::is_loading<Archive>())
217  {
218  observations.zeros(observationsBeforeBinning);
219  labels.zeros(observationsBeforeBinning);
220  }
221 
222  // Save the number of classes.
223  size_t numClasses;
224  if (cereal::is_saving<Archive>())
225  numClasses = sufficientStatistics.n_rows;
226  ar(CEREAL_NVP(numClasses));
227  ar(CEREAL_NVP(observations));
228  ar(CEREAL_NVP(labels));
229 
230  if (cereal::is_loading<Archive>())
231  {
232  // Clean other objects.
233  splitPoints.clear();
234  sufficientStatistics.zeros(numClasses, bins);
235  }
236  }
237 }
238 
239 } // namespace tree
240 } // namespace mlpack
241 
242 #endif
void EvaluateFitnessFunction(double &bestFitness, double &secondBestFitness) const
Evaluate the fitness function given what has been calculated so far.
Definition: hoeffding_numeric_split_impl.hpp:108
double MajorityProbability() const
Return the probability of the majority class.
Definition: hoeffding_numeric_split_impl.hpp:166
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Split(arma::Col< size_t > &childMajorities, SplitInfo &splitInfo) const
Return the majority class of each child to be created, if a split on this dimension was performed. Also create the split object.
Definition: hoeffding_numeric_split_impl.hpp:119
The HoeffdingNumericSplit class implements the numeric feature splitting strategy alluded to by Domin...
Definition: hoeffding_numeric_split.hpp:53
void Train(ObservationType value, const size_t label)
Train the HoeffdingNumericSplit on the given observed value (remember that this object only cares about the information for a single feature, not an entire point).
Definition: hoeffding_numeric_split_impl.hpp:52
NumericSplitInfo< ObservationType > SplitInfo
The splitting information type required by the HoeffdingNumericSplit.
Definition: hoeffding_numeric_split.hpp:57
void serialize(Archive &ar, const uint32_t)
Serialize the object.
Definition: hoeffding_numeric_split_impl.hpp:191
Definition: numeric_split_info.hpp:21
size_t MajorityClass() const
Return the majority class.
Definition: hoeffding_numeric_split_impl.hpp:137
HoeffdingNumericSplit(const size_t numClasses=0, const size_t bins=10, const size_t observationsBeforeBinning=100)
Create the HoeffdingNumericSplit class, and specify some basic parameters about how the binning should take place.
Definition: hoeffding_numeric_split_impl.hpp:21