// Include guard for this implementation header, immediately followed by the
// start of the primary HoeffdingNumericSplit constructor.
// NOTE(review): the embedded listing numbers jump (20 -> 22 -> 24 ...), so the
// line naming the constructor and, presumably, the `bins` parameter and the
// `bins` / `samplesSeen` initializers are not visible in this fragment.
12 #ifndef MLPACK_METHODS_HOEFFDING_TREES_HOEFFDING_NUMERIC_SPLIT_IMPL_HPP 13 #define MLPACK_METHODS_HOEFFDING_TREES_HOEFFDING_NUMERIC_SPLIT_IMPL_HPP 20 template<
typename FitnessFunction,
typename ObservationType>
22 const size_t numClasses,
24 const size_t observationsBeforeBinning) :
// The pre-binning buffers each hold observationsBeforeBinning - 1 entries.
// NOTE(review): observationsBeforeBinning is a size_t, so a value of 0 would
// make this subtraction wrap around -- confirm callers never pass 0.
25 observations(observationsBeforeBinning - 1),
26 labels(observationsBeforeBinning - 1),
28 observationsBeforeBinning(observationsBeforeBinning),
// Per-(class, bin) counts start at zero: numClasses rows by bins columns.
30 sufficientStatistics(
arma::zeros<
arma::Mat<size_t>>(numClasses, bins))
// Second constructor fragment: builds a fresh split whose binning settings are
// taken from an existing object `other` while the class count comes from
// `numClasses`.
// NOTE(review): the line naming the constructor and the declaration of the
// `other` parameter are not visible here (listing numbers jump 36 -> 38 -> 40);
// presumably `bins` is also copied from `other` on an omitted line.
36 template<
typename FitnessFunction,
typename ObservationType>
38 const size_t numClasses,
// Buffers are sized from other's threshold; no observations or labels are
// copied from `other` in the visible initializers.
40 observations(other.observationsBeforeBinning - 1),
41 labels(other.observationsBeforeBinning - 1),
43 observationsBeforeBinning(other.observationsBeforeBinning),
// Statistics start empty (all zeros), numClasses rows by bins columns.
45 sufficientStatistics(
arma::zeros<
arma::Mat<size_t>>(numClasses, bins))
// Train: feed one (value, label) observation of a single numeric feature.
// Visible behavior, in three phases:
//   1. While fewer than observationsBeforeBinning - 1 samples have been seen,
//      the value and label are only stored in the buffers.
//   2. On the sample where samplesSeen == observationsBeforeBinning - 1, the
//      bin boundaries are computed and every buffered sample is counted.
//   3. The incoming value is binned and counted in sufficientStatistics.
// NOTE(review): the listing numbers jump (e.g. 60 -> 64, 74 -> 79, 90 -> 93),
// so braces, the `bin` declarations, the loop bodies of the `while`
// statements, any early return and the updates of samplesSeen are not visible
// in this fragment.
51 template<
typename FitnessFunction,
typename ObservationType>
53 ObservationType value,
// Phase 1: still collecting raw samples; samplesSeen is the write index.
56 if (samplesSeen < observationsBeforeBinning - 1)
59 observations[samplesSeen] = value;
60 labels[samplesSeen] = label;
// Phase 2: enough samples are available to define the bins.
64 else if (samplesSeen == observationsBeforeBinning - 1)
// The range starts from the incoming value and is widened by every buffered
// observation.
67 ObservationType min = value;
68 ObservationType max = value;
69 for (
size_t i = 0; i < observationsBeforeBinning - 1; ++i)
71 if (observations[i] < min)
72 min = observations[i];
73 else if (observations[i] > max)
74 max = observations[i];
// bins - 1 interior boundaries at min + (i + 1) * binWidth; the endpoints min
// and max themselves are not stored as split points.
// NOTE(review): if ObservationType is an integer type this division
// truncates, and if max == min every split point equals min.
79 splitPoints.resize(bins - 1);
80 const ObservationType binWidth = (max - min) / bins;
81 for (
size_t i = 0; i < bins - 1; ++i)
82 splitPoints[i] = min + (i + 1) * binWidth;
// Replay the buffered samples into the per-(class, bin) counts.
86 for (
size_t i = 0; i < observationsBeforeBinning - 1; ++i)
// Scan boundaries while the observation is strictly greater than the current
// split point; the scan is capped at the last bin index (bins - 1).
90 while (bin < bins - 1 && observations[i] > splitPoints[bin])
93 sufficientStatistics(labels[i], bin)++;
// Phase 3: the same bin lookup and count for the incoming value itself.
100 while (bin < bins - 1 && value > splitPoints[bin])
103 sufficientStatistics(label, bin)++;
// EvaluateFitnessFunction: report the fitness of splitting on this feature
// through the two output references.
// NOTE(review): the line naming the function and the `bestFitness` parameter
// are not visible in this fragment, nor is the body of the `if` below.
106 template<
typename FitnessFunction,
typename ObservationType>
109 double& secondBestFitness)
// The second-best fitness is unconditionally set to 0.0.
const 111 secondBestFitness = 0.0;
// Guard for the phase in which binning has not happened yet; what is assigned
// and whether the function returns early here is on omitted lines.
112 if (samplesSeen < observationsBeforeBinning)
// Fitness is computed by FitnessFunction::Evaluate on the per-(class, bin)
// count matrix.
115 bestFitness = FitnessFunction::Evaluate(sufficientStatistics);
// Split: fill `childMajorities` with the majority class of each bin, i.e. of
// each column of sufficientStatistics.
// NOTE(review): the line naming the function, the `splitInfo` parameter and
// whatever is done with it are not visible in this fragment.
118 template<
typename FitnessFunction,
typename ObservationType>
120 arma::Col<size_t>& childMajorities,
// One output entry per column (bin) of the statistics matrix.
123 childMajorities.set_size(sufficientStatistics.n_cols);
124 for (
size_t i = 0; i < sufficientStatistics.n_cols; ++i)
// .max(maxIndex) writes the row index of the largest count in column i; that
// row index is the class label stored for the child.
126 arma::uword maxIndex = 0;
127 sufficientStatistics.unsafe_col(i).max(maxIndex);
128 childMajorities[i] = size_t(maxIndex);
// MajorityClass: return the index of the class with the most observations.
// NOTE(review): the line naming the function, the braces and any zeroing of
// `classes` are not visible in this fragment (listing numbers jump
// 135 -> 140 -> 142 -> 145).
135 template<
typename FitnessFunction,
typename ObservationType>
// Before binning: count classes directly from the buffered labels.
140 if (samplesSeen < observationsBeforeBinning)
// One counter per class (one per row of sufficientStatistics).
142 arma::Col<size_t> classes(sufficientStatistics.n_rows);
// Only the first samplesSeen buffered labels are counted.
145 for (
size_t i = 0; i < samplesSeen; ++i)
146 classes[labels[i]]++;
// .max(index) stores the position of the largest count in majorityClass.
148 arma::uword majorityClass;
149 classes.max(majorityClass);
150 return size_t(majorityClass);
// After binning: sum each row of the statistics (dimension 1) to get one total
// per class, then return the index of the largest total.
156 arma::Col<size_t> classCounts = arma::sum(sufficientStatistics, 1);
158 arma::uword maxIndex = 0;
159 classCounts.max(maxIndex);
160 return size_t(maxIndex);
// MajorityProbability: return the fraction of observations that belong to the
// most frequent class (largest class count divided by the total count).
// NOTE(review): the line naming the function, the braces and any zeroing of
// `classes` are not visible in this fragment.
// NOTE(review): if no observations have been counted the denominator is 0 and
// the floating-point division would not give a meaningful probability --
// confirm callers only ask after at least one sample.
164 template<
typename FitnessFunction,
typename ObservationType>
// Before binning: count classes from the first samplesSeen buffered labels.
169 if (samplesSeen < observationsBeforeBinning)
171 arma::Col<size_t> classes(sufficientStatistics.n_rows);
174 for (
size_t i = 0; i < samplesSeen; ++i)
175 classes[labels[i]]++;
177 return double(classes.max()) /
double(arma::accu(classes));
// After binning: per-class totals are the row sums of the statistics matrix.
183 arma::Col<size_t> classCounts = arma::sum(sufficientStatistics, 1);
185 return double(classCounts.max()) /
double(arma::sum(classCounts));
// serialize: cereal save/load routine. Which members are archived depends on
// whether binning has already happened.
// NOTE(review): the line naming the function, the braces, the `else` that
// presumably separates the two branches, and the declaration of `numClasses`
// are not visible in this fragment (listing numbers jump, e.g. 208 -> 216).
189 template<
typename FitnessFunction,
typename ObservationType>
190 template<
typename Archive>
// Always archived: progress counter and binning configuration.
195 ar(CEREAL_NVP(samplesSeen));
196 ar(CEREAL_NVP(observationsBeforeBinning));
197 ar(CEREAL_NVP(bins));
// Binning already done: the boundaries and the counts are the state to keep.
199 if (samplesSeen >= observationsBeforeBinning)
202 ar(CEREAL_NVP(splitPoints));
203 ar(CEREAL_NVP(sufficientStatistics));
// On load the raw-sample buffer is emptied in this branch; presumably `labels`
// is cleared as well on an omitted line.
205 if (cereal::is_loading<Archive>())
208 observations.clear();
// Binning not done yet (presumably the else branch): the raw buffers are the
// state to keep. On load they are first sized and zeroed.
// NOTE(review): the size used here is observationsBeforeBinning, whereas the
// constructors size these buffers to observationsBeforeBinning - 1; the
// buffers are archived again below, so confirm the mismatch is harmless.
216 if (cereal::is_loading<Archive>())
218 observations.zeros(observationsBeforeBinning);
219 labels.zeros(observationsBeforeBinning);
// The class count is not stored separately in memory; when saving it is taken
// from the row count of the statistics matrix.
224 if (cereal::is_saving<Archive>())
225 numClasses = sufficientStatistics.n_rows;
226 ar(CEREAL_NVP(numClasses));
227 ar(CEREAL_NVP(observations));
228 ar(CEREAL_NVP(labels));
// On load the statistics are rebuilt as an all-zero numClasses x bins matrix,
// since no counts exist before binning.
230 if (cereal::is_loading<Archive>())
234 sufficientStatistics.zeros(numClasses, bins);
void EvaluateFitnessFunction(double &bestFitness, double &secondBestFitness) const
Evaluate the fitness function given what has been calculated so far.
Definition: hoeffding_numeric_split_impl.hpp:108
double MajorityProbability() const
Return the probability of the majority class.
Definition: hoeffding_numeric_split_impl.hpp:166
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Split(arma::Col< size_t > &childMajorities, SplitInfo &splitInfo) const
Return the majority class of each child to be created, if a split on this dimension was performed. Also create the split object.
Definition: hoeffding_numeric_split_impl.hpp:119
The HoeffdingNumericSplit class implements the numeric feature splitting strategy alluded to by Domingos and Hulten.
Definition: hoeffding_numeric_split.hpp:53
void Train(ObservationType value, const size_t label)
Train the HoeffdingNumericSplit on the given observed value (remember that this object only cares about the information for a single feature, not an entire point).
Definition: hoeffding_numeric_split_impl.hpp:52
NumericSplitInfo< ObservationType > SplitInfo
The splitting information type required by the HoeffdingNumericSplit.
Definition: hoeffding_numeric_split.hpp:57
void serialize(Archive &ar, const uint32_t)
Serialize the object.
Definition: hoeffding_numeric_split_impl.hpp:191
Definition: numeric_split_info.hpp:21
size_t MajorityClass() const
Return the majority class.
Definition: hoeffding_numeric_split_impl.hpp:137
HoeffdingNumericSplit(const size_t numClasses=0, const size_t bins=10, const size_t observationsBeforeBinning=100)
Create the HoeffdingNumericSplit class, and specify some basic parameters about how the binning should take place.
Definition: hoeffding_numeric_split_impl.hpp:21