12 #ifndef MLPACK_METHODS_HOEFFDING_TREES_BINARY_NUMERIC_SPLIT_IMPL_HPP 13 #define MLPACK_METHODS_HOEFFDING_TREES_BINARY_NUMERIC_SPLIT_IMPL_HPP 21 template<
typename FitnessFunction,
typename ObservationType>
23 const size_t numClasses) :
24 classCounts(numClasses),
25 bestSplit(
std::numeric_limits<ObservationType>::min()),
32 template<
typename FitnessFunction,
typename ObservationType>
34 const size_t numClasses,
36 classCounts(numClasses),
37 bestSplit(
std::numeric_limits<ObservationType>::min()),
44 template<
typename FitnessFunction,
typename ObservationType>
46 ObservationType value,
50 sortedElements.insert(std::pair<ObservationType, size_t>(value, label));
57 template<
typename FitnessFunction,
typename ObservationType>
60 double& secondBestFitness)
63 bestSplit = std::numeric_limits<ObservationType>::min();
66 arma::Mat<size_t> counts(classCounts.n_elem, 2);
67 counts.col(0).zeros();
68 counts.col(1) = classCounts;
70 bestFitness = FitnessFunction::Evaluate(counts);
71 secondBestFitness = 0.0;
75 ObservationType lastObservation = (*sortedElements.begin()).first;
76 size_t lastClass = classCounts.n_elem;
77 for (
typename std::multimap<ObservationType, size_t>::const_iterator it =
78 sortedElements.begin(); it != sortedElements.end(); ++it)
83 if (((*it).first != lastObservation) || ((*it).second != lastClass))
85 lastObservation = (*it).first;
86 lastClass = (*it).second;
88 const double value = FitnessFunction::Evaluate(counts);
89 if (value > bestFitness)
92 bestSplit = (*it).first;
94 else if (value > secondBestFitness)
96 secondBestFitness = value;
101 --counts((*it).second, 1);
102 ++counts((*it).second, 0);
108 template<
typename FitnessFunction,
typename ObservationType>
110 arma::Col<size_t>& childMajorities,
115 double bestGain, secondBestGain;
120 childMajorities.set_size(2);
122 arma::Mat<size_t> counts(classCounts.n_elem, 2);
123 counts.col(0).zeros();
124 counts.col(1) = classCounts;
126 double min = DBL_MAX;
127 double max = -DBL_MAX;
128 for (
typename std::multimap<ObservationType, size_t>::const_iterator it =
129 sortedElements.begin();
130 it != sortedElements.end(); ++it)
133 if ((*it).first < bestSplit)
135 --counts((*it).second, 1);
136 ++counts((*it).second, 0);
138 if ((*it).first < min)
140 if ((*it).first > max)
145 arma::uword maxIndex;
146 counts.unsafe_col(0).max(maxIndex);
147 childMajorities[0] = size_t(maxIndex);
148 counts.unsafe_col(1).max(maxIndex);
149 childMajorities[1] = size_t(maxIndex);
155 template<
typename FitnessFunction,
typename ObservationType>
159 arma::uword maxIndex;
160 classCounts.max(maxIndex);
161 return size_t(maxIndex);
164 template<
typename FitnessFunction,
typename ObservationType>
168 return double(arma::max(classCounts)) / double(arma::accu(classCounts));
171 template<
typename FitnessFunction,
typename ObservationType>
172 template<
typename Archive>
178 ar(CEREAL_NVP(sortedElements));
179 ar(CEREAL_NVP(classCounts));
size_t MajorityClass() const
The majority class of the points seen so far.
Definition: binary_numeric_split_impl.hpp:156
double MajorityProbability() const
The probability of the majority class given the points seen so far.
Definition: binary_numeric_split_impl.hpp:166
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
BinaryNumericSplit(const size_t numClasses=0)
Create the BinaryNumericSplit object with the given number of classes.
Definition: binary_numeric_split_impl.hpp:22
Definition: pointer_wrapper.hpp:23
The BinaryNumericSplit class implements the numeric feature splitting strategy devised by Gama, Rocha, and Medas ("Accurate Decision Trees for Mining High-Speed Data Streams", KDD 2003).
Definition: binary_numeric_split.hpp:47
BinaryNumericSplitInfo< ObservationType > SplitInfo
The splitting information required by the BinaryNumericSplit.
Definition: binary_numeric_split.hpp:51
Definition: binary_numeric_split_info.hpp:22
void Train(ObservationType value, const size_t label)
Train on the given value with the given label.
Definition: binary_numeric_split_impl.hpp:45
void serialize(Archive &ar, const uint32_t)
Serialize the object.
Definition: binary_numeric_split_impl.hpp:173
void EvaluateFitnessFunction(double &bestFitness, double &secondBestFitness)
Given the points seen so far, evaluate the fitness function, returning the best possible gain of a binary split (and the second-best gain).
Definition: binary_numeric_split_impl.hpp:59
void Split(arma::Col< size_t > &childMajorities, SplitInfo &splitInfo)
Given that a split should happen, return the majority classes of the (two) children and an initialized SplitInfo object.
Definition: binary_numeric_split_impl.hpp:109