mlpack
binary_numeric_split_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_HOEFFDING_TREES_BINARY_NUMERIC_SPLIT_IMPL_HPP
13 #define MLPACK_METHODS_HOEFFDING_TREES_BINARY_NUMERIC_SPLIT_IMPL_HPP
14 
15 // In case it hasn't been included yet.
16 #include "binary_numeric_split.hpp"
17 
18 namespace mlpack {
19 namespace tree {
20 
21 template<typename FitnessFunction, typename ObservationType>
23  const size_t numClasses) :
24  classCounts(numClasses),
25  bestSplit(std::numeric_limits<ObservationType>::min()),
26  isAccurate(true)
27 {
28  // Zero out class counts.
29  classCounts.zeros();
30 }
31 
32 template<typename FitnessFunction, typename ObservationType>
34  const size_t numClasses,
35  const BinaryNumericSplit& /* other */) :
36  classCounts(numClasses),
37  bestSplit(std::numeric_limits<ObservationType>::min()),
38  isAccurate(true)
39 {
40  // Zero out class counts.
41  classCounts.zeros();
42 }
43 
44 template<typename FitnessFunction, typename ObservationType>
46  ObservationType value,
47  const size_t label)
48 {
49  // Push it into the multimap, and update the class counts.
50  sortedElements.insert(std::pair<ObservationType, size_t>(value, label));
51  ++classCounts[label];
52 
53  // Whatever we have cached is no longer valid.
54  isAccurate = false;
55 }
56 
57 template<typename FitnessFunction, typename ObservationType>
59  EvaluateFitnessFunction(double& bestFitness,
60  double& secondBestFitness)
61 {
62  // Unfortunately, we have to iterate over the map.
63  bestSplit = std::numeric_limits<ObservationType>::min();
64 
65  // Initialize the sufficient statistics.
66  arma::Mat<size_t> counts(classCounts.n_elem, 2);
67  counts.col(0).zeros();
68  counts.col(1) = classCounts;
69 
70  bestFitness = FitnessFunction::Evaluate(counts);
71  secondBestFitness = 0.0;
72 
73  // Initialize to the first observation, so we don't calculate gain on the
74  // first iteration (it will be 0).
75  ObservationType lastObservation = (*sortedElements.begin()).first;
76  size_t lastClass = classCounts.n_elem;
77  for (typename std::multimap<ObservationType, size_t>::const_iterator it =
78  sortedElements.begin(); it != sortedElements.end(); ++it)
79  {
80  // If this value is the same as the last, or if this is the first value, or
81  // we have the same class as the previous observation, don't calculate the
82  // gain---it can't be any better. (See Fayyad and Irani, 1991.)
83  if (((*it).first != lastObservation) || ((*it).second != lastClass))
84  {
85  lastObservation = (*it).first;
86  lastClass = (*it).second;
87 
88  const double value = FitnessFunction::Evaluate(counts);
89  if (value > bestFitness)
90  {
91  bestFitness = value;
92  bestSplit = (*it).first;
93  }
94  else if (value > secondBestFitness)
95  {
96  secondBestFitness = value;
97  }
98  }
99 
100  // Move the point to the right side of the split.
101  --counts((*it).second, 1);
102  ++counts((*it).second, 0);
103  }
104 
105  isAccurate = true;
106 }
107 
108 template<typename FitnessFunction, typename ObservationType>
110  arma::Col<size_t>& childMajorities,
111  SplitInfo& splitInfo)
112 {
113  if (!isAccurate)
114  {
115  double bestGain, secondBestGain;
116  EvaluateFitnessFunction(bestGain, secondBestGain);
117  }
118 
119  // Make one child for each side of the split.
120  childMajorities.set_size(2);
121 
122  arma::Mat<size_t> counts(classCounts.n_elem, 2);
123  counts.col(0).zeros();
124  counts.col(1) = classCounts;
125 
126  double min = DBL_MAX;
127  double max = -DBL_MAX;
128  for (typename std::multimap<ObservationType, size_t>::const_iterator it =
129  sortedElements.begin(); // (*it).first < bestSplit; ++it)
130  it != sortedElements.end(); ++it)
131  {
132  // Move the point to the correct side of the split.
133  if ((*it).first < bestSplit)
134  {
135  --counts((*it).second, 1);
136  ++counts((*it).second, 0);
137  }
138  if ((*it).first < min)
139  min = (*it).first;
140  if ((*it).first > max)
141  max = (*it).first;
142  }
143 
144  // Calculate the majority classes of the children.
145  arma::uword maxIndex;
146  counts.unsafe_col(0).max(maxIndex);
147  childMajorities[0] = size_t(maxIndex);
148  counts.unsafe_col(1).max(maxIndex);
149  childMajorities[1] = size_t(maxIndex);
150 
151  // Create the according SplitInfo object.
152  splitInfo = SplitInfo(bestSplit);
153 }
154 
155 template<typename FitnessFunction, typename ObservationType>
157  const
158 {
159  arma::uword maxIndex;
160  classCounts.max(maxIndex);
161  return size_t(maxIndex);
162 }
163 
164 template<typename FitnessFunction, typename ObservationType>
167 {
168  return double(arma::max(classCounts)) / double(arma::accu(classCounts));
169 }
170 
171 template<typename FitnessFunction, typename ObservationType>
172 template<typename Archive>
174  Archive& ar,
175  const uint32_t /* version */)
176 {
177  // Serialize.
178  ar(CEREAL_NVP(sortedElements));
179  ar(CEREAL_NVP(classCounts));
180 }
181 
182 
183 } // namespace tree
184 } // namespace mlpack
185 
186 #endif
size_t MajorityClass() const
The majority class of the points seen so far.
Definition: binary_numeric_split_impl.hpp:156
double MajorityProbability() const
The probability of the majority class given the points seen so far.
Definition: binary_numeric_split_impl.hpp:166
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
BinaryNumericSplit(const size_t numClasses=0)
Create the BinaryNumericSplit object with the given number of classes.
Definition: binary_numeric_split_impl.hpp:22
Definition: pointer_wrapper.hpp:23
The BinaryNumericSplit class implements the numeric feature splitting strategy devised by Gama...
Definition: binary_numeric_split.hpp:47
BinaryNumericSplitInfo< ObservationType > SplitInfo
The splitting information required by the BinaryNumericSplit.
Definition: binary_numeric_split.hpp:51
Definition: binary_numeric_split_info.hpp:22
void Train(ObservationType value, const size_t label)
Train on the given value with the given label.
Definition: binary_numeric_split_impl.hpp:45
void serialize(Archive &ar, const uint32_t)
Serialize the object.
Definition: binary_numeric_split_impl.hpp:173
void EvaluateFitnessFunction(double &bestFitness, double &secondBestFitness)
Given the points seen so far, evaluate the fitness function, returning the best possible gain of a binary split.
Definition: binary_numeric_split_impl.hpp:59
void Split(arma::Col< size_t > &childMajorities, SplitInfo &splitInfo)
Given that a split should happen, return the majority classes of the (two) children and an initialized SplitInfo object.
Definition: binary_numeric_split_impl.hpp:109