mlpack
hoeffding_categorical_split_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_HOEFFDING_TREES_HOEFFDING_CATEGORICAL_SPLIT_IMPL_HPP
13 #define MLPACK_METHODS_HOEFFDING_TREES_HOEFFDING_CATEGORICAL_SPLIT_IMPL_HPP
14 
15 // In case it hasn't been included yet.
17 
18 namespace mlpack {
19 namespace tree {
20 
21 template<typename FitnessFunction>
23  const size_t numCategories,
24  const size_t numClasses) :
25  sufficientStatistics(numClasses, numCategories)
26 {
27  sufficientStatistics.zeros();
28 }
29 
30 template<typename FitnessFunction>
32  const size_t numCategories,
33  const size_t numClasses,
34  const HoeffdingCategoricalSplit& /* other */) :
35  sufficientStatistics(numClasses, numCategories)
36 {
37  sufficientStatistics.zeros();
38 }
39 
40 template<typename FitnessFunction>
41 template<typename eT>
43  const size_t label)
44 {
45  // Add this to our counts.
46  // 'value' should be categorical, so we should be able to cast to size_t...
47  sufficientStatistics(label, size_t(value))++;
48 }
49 
50 template<typename FitnessFunction>
52  double& bestFitness,
53  double& secondBestFitness) const
54 {
55  bestFitness = FitnessFunction::Evaluate(sufficientStatistics);
56  secondBestFitness = 0.0; // We only split one possible way.
57 }
58 
59 template<typename FitnessFunction>
61  arma::Col<size_t>& childMajorities,
62  SplitInfo& splitInfo)
63 {
64  // We'll make one child for each category.
65  childMajorities.set_size(sufficientStatistics.n_cols);
66  for (size_t i = 0; i < sufficientStatistics.n_cols; ++i)
67  {
68  arma::uword maxIndex = 0;
69  sufficientStatistics.unsafe_col(i).max(maxIndex);
70  childMajorities[i] = size_t(maxIndex);
71  }
72 
73  // Create the according SplitInfo object.
74  splitInfo = SplitInfo(sufficientStatistics.n_cols);
75 }
76 
77 template<typename FitnessFunction>
79 {
80  // Calculate the class that we have seen the most of.
81  arma::Col<size_t> classCounts = arma::sum(sufficientStatistics, 1);
82 
83  arma::uword maxIndex = 0;
84  classCounts.max(maxIndex);
85 
86  return size_t(maxIndex);
87 }
88 
89 template<typename FitnessFunction>
91 {
92  arma::Col<size_t> classCounts = arma::sum(sufficientStatistics, 1);
93 
94  return double(classCounts.max()) / double(arma::accu(classCounts));
95 }
96 
97 } // namespace tree
98 } // namespace mlpack
99 
100 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void EvaluateFitnessFunction(double &bestFitness, double &secondBestFitness) const
Given the points seen so far, evaluate the fitness function, returning the gain for the best possible...
Definition: hoeffding_categorical_split_impl.hpp:51
size_t MajorityClass() const
Get the majority class seen so far.
Definition: hoeffding_categorical_split_impl.hpp:78
Definition: categorical_split_info.hpp:20
void Split(arma::Col< size_t > &childMajorities, SplitInfo &splitInfo)
Gather the information for a split: get the labels of the child majorities, and initialize the SplitI...
Definition: hoeffding_categorical_split_impl.hpp:60
CategoricalSplitInfo SplitInfo
The type of split information required by the HoeffdingCategoricalSplit.
Definition: hoeffding_categorical_split.hpp:48
HoeffdingCategoricalSplit(const size_t numCategories=0, const size_t numClasses=0)
Create the HoeffdingCategoricalSplit given a number of categories for this dimension and a number of ...
Definition: hoeffding_categorical_split_impl.hpp:22
double MajorityProbability() const
Get the probability of the majority class given the points seen so far.
Definition: hoeffding_categorical_split_impl.hpp:90
This is the standard Hoeffding-bound categorical feature proposed in the paper below: ...
Definition: hoeffding_categorical_split.hpp:44
void Train(eT value, const size_t label)
Train on the given value with the given label.
Definition: hoeffding_categorical_split_impl.hpp:42