mlpack
random_forest_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_IMPL_HPP
13 #define MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_IMPL_HPP
14 
15 // In case it hasn't been included yet.
16 #include "random_forest.hpp"
17 
18 namespace mlpack {
19 namespace tree {
20 
21 template<
22  typename FitnessFunction,
23  typename DimensionSelectionType,
24  template<typename> class NumericSplitType,
25  template<typename> class CategoricalSplitType,
26  bool UseBootstrap
27 >
28 RandomForest<
29  FitnessFunction,
30  DimensionSelectionType,
31  NumericSplitType,
32  CategoricalSplitType,
33  UseBootstrap
35  avgGain(0.0)
36 {
37  // Nothing to do here.
38 }
39 
40 template<
41  typename FitnessFunction,
42  typename DimensionSelectionType,
43  template<typename> class NumericSplitType,
44  template<typename> class CategoricalSplitType,
45  bool UseBootstrap
46 >
47 template<typename MatType>
49  FitnessFunction,
50  DimensionSelectionType,
51  NumericSplitType,
52  CategoricalSplitType,
53  UseBootstrap
54 >::RandomForest(const MatType& dataset,
55  const arma::Row<size_t>& labels,
56  const size_t numClasses,
57  const size_t numTrees,
58  const size_t minimumLeafSize,
59  const double minimumGainSplit,
60  const size_t maximumDepth,
61  DimensionSelectionType dimensionSelector) :
62  avgGain(0.0)
63 {
64  // Pass off work to the Train() method.
65  data::DatasetInfo info; // Ignored.
66  arma::rowvec weights; // Fake weights, not used.
67  Train<false, false>(dataset, info, labels, numClasses, weights, numTrees,
68  minimumLeafSize, minimumGainSplit, maximumDepth, dimensionSelector,
69  false);
70 }
71 
72 template<
73  typename FitnessFunction,
74  typename DimensionSelectionType,
75  template<typename> class NumericSplitType,
76  template<typename> class CategoricalSplitType,
77  bool UseBootstrap
78 >
79 template<typename MatType>
81  FitnessFunction,
82  DimensionSelectionType,
83  NumericSplitType,
84  CategoricalSplitType,
85  UseBootstrap
86 >::RandomForest(const MatType& dataset,
87  const data::DatasetInfo& datasetInfo,
88  const arma::Row<size_t>& labels,
89  const size_t numClasses,
90  const size_t numTrees,
91  const size_t minimumLeafSize,
92  const double minimumGainSplit,
93  const size_t maximumDepth,
94  DimensionSelectionType dimensionSelector):
95  avgGain(0.0)
96 {
97  // Pass off work to the Train() method.
98  arma::rowvec weights; // Fake weights, not used.
99  Train<false, true>(dataset, datasetInfo, labels, numClasses, weights,
100  numTrees, minimumLeafSize, minimumGainSplit, maximumDepth,
101  dimensionSelector, false);
102 }
103 
104 template<
105  typename FitnessFunction,
106  typename DimensionSelectionType,
107  template<typename> class NumericSplitType,
108  template<typename> class CategoricalSplitType,
109  bool UseBootstrap
110 >
111 template<typename MatType>
113  FitnessFunction,
114  DimensionSelectionType,
115  NumericSplitType,
116  CategoricalSplitType,
117  UseBootstrap
118 >::RandomForest(const MatType& dataset,
119  const arma::Row<size_t>& labels,
120  const size_t numClasses,
121  const arma::rowvec& weights,
122  const size_t numTrees,
123  const size_t minimumLeafSize,
124  const double minimumGainSplit,
125  const size_t maximumDepth,
126  DimensionSelectionType dimensionSelector) :
127  avgGain(0.0)
128 {
129  // Pass off work to the Train() method.
130  data::DatasetInfo info; // Ignored by Train().
131  Train<true, false>(dataset, info, labels, numClasses, weights, numTrees,
132  minimumLeafSize, minimumGainSplit, maximumDepth, dimensionSelector,
133  false);
134 }
135 
136 template<
137  typename FitnessFunction,
138  typename DimensionSelectionType,
139  template<typename> class NumericSplitType,
140  template<typename> class CategoricalSplitType,
141  bool UseBootstrap
142 >
143 template<typename MatType>
145  FitnessFunction,
146  DimensionSelectionType,
147  NumericSplitType,
148  CategoricalSplitType,
149  UseBootstrap
150 >::RandomForest(const MatType& dataset,
151  const data::DatasetInfo& datasetInfo,
152  const arma::Row<size_t>& labels,
153  const size_t numClasses,
154  const arma::rowvec& weights,
155  const size_t numTrees,
156  const size_t minimumLeafSize,
157  const double minimumGainSplit,
158  const size_t maximumDepth,
159  DimensionSelectionType dimensionSelector) :
160  avgGain(0.0)
161 {
162  // Pass off work to the Train() method.
163  Train<true, true>(dataset, datasetInfo, labels, numClasses, weights,
164  numTrees, minimumLeafSize, minimumGainSplit, maximumDepth,
165  dimensionSelector, false);
166 }
167 
168 template<
169  typename FitnessFunction,
170  typename DimensionSelectionType,
171  template<typename> class NumericSplitType,
172  template<typename> class CategoricalSplitType,
173  bool UseBootstrap
174 >
175 template<typename MatType>
176 double RandomForest<
177  FitnessFunction,
178  DimensionSelectionType,
179  NumericSplitType,
180  CategoricalSplitType,
181  UseBootstrap
182 >::Train(const MatType& dataset,
183  const arma::Row<size_t>& labels,
184  const size_t numClasses,
185  const size_t numTrees,
186  const size_t minimumLeafSize,
187  const double minimumGainSplit,
188  const size_t maximumDepth,
189  const bool warmStart,
190  DimensionSelectionType dimensionSelector)
191 {
192  // Pass off to Train().
193  data::DatasetInfo datasetInfo; // Ignored by Train().
194  arma::rowvec weights; // Ignored by Train().
195  return Train<false, false>(dataset, datasetInfo, labels, numClasses, weights,
196  numTrees, minimumLeafSize, minimumGainSplit, maximumDepth,
197  dimensionSelector, warmStart);
198 }
199 
200 template<
201  typename FitnessFunction,
202  typename DimensionSelectionType,
203  template<typename> class NumericSplitType,
204  template<typename> class CategoricalSplitType,
205  bool UseBootstrap
206 >
207 template<typename MatType>
208 double RandomForest<
209  FitnessFunction,
210  DimensionSelectionType,
211  NumericSplitType,
212  CategoricalSplitType,
213  UseBootstrap
214 >::Train(const MatType& dataset,
215  const data::DatasetInfo& datasetInfo,
216  const arma::Row<size_t>& labels,
217  const size_t numClasses,
218  const size_t numTrees,
219  const size_t minimumLeafSize,
220  const double minimumGainSplit,
221  const size_t maximumDepth,
222  const bool warmStart,
223  DimensionSelectionType dimensionSelector)
224 {
225  // Pass off to Train().
226  arma::rowvec weights; // Ignored by Train().
227  return Train<false, true>(dataset, datasetInfo, labels, numClasses, weights,
228  numTrees, minimumLeafSize, minimumGainSplit, maximumDepth,
229  dimensionSelector, warmStart);
230 }
231 
232 template<
233  typename FitnessFunction,
234  typename DimensionSelectionType,
235  template<typename> class NumericSplitType,
236  template<typename> class CategoricalSplitType,
237  bool UseBootstrap
238 >
239 template<typename MatType>
240 double RandomForest<
241  FitnessFunction,
242  DimensionSelectionType,
243  NumericSplitType,
244  CategoricalSplitType,
245  UseBootstrap
246 >::Train(const MatType& dataset,
247  const arma::Row<size_t>& labels,
248  const size_t numClasses,
249  const arma::rowvec& weights,
250  const size_t numTrees,
251  const size_t minimumLeafSize,
252  const double minimumGainSplit,
253  const size_t maximumDepth,
254  const bool warmStart,
255  DimensionSelectionType dimensionSelector)
256 {
257  // Pass off to Train().
258  data::DatasetInfo datasetInfo; // Ignored by Train().
259  return Train<false, false>(dataset, datasetInfo, labels, numClasses, weights,
260  numTrees, minimumLeafSize, minimumGainSplit, maximumDepth,
261  dimensionSelector, warmStart);
262 }
263 
264 template<
265  typename FitnessFunction,
266  typename DimensionSelectionType,
267  template<typename> class NumericSplitType,
268  template<typename> class CategoricalSplitType,
269  bool UseBootstrap
270 >
271 template<typename MatType>
272 double RandomForest<
273  FitnessFunction,
274  DimensionSelectionType,
275  NumericSplitType,
276  CategoricalSplitType,
277  UseBootstrap
278 >::Train(const MatType& dataset,
279  const data::DatasetInfo& datasetInfo,
280  const arma::Row<size_t>& labels,
281  const size_t numClasses,
282  const arma::rowvec& weights,
283  const size_t numTrees,
284  const size_t minimumLeafSize,
285  const double minimumGainSplit,
286  const size_t maximumDepth,
287  const bool warmStart,
288  DimensionSelectionType dimensionSelector)
289 {
290  // Pass off to Train().
291  return Train<true, true>(dataset, datasetInfo, labels, numClasses, weights,
292  numTrees, minimumLeafSize, minimumGainSplit, maximumDepth,
293  dimensionSelector, warmStart);
294 }
295 
296 template<
297  typename FitnessFunction,
298  typename DimensionSelectionType,
299  template<typename> class NumericSplitType,
300  template<typename> class CategoricalSplitType,
301  bool UseBootstrap
302 >
303 template<typename VecType>
304 size_t RandomForest<
305  FitnessFunction,
306  DimensionSelectionType,
307  NumericSplitType,
308  CategoricalSplitType,
309  UseBootstrap
310 >::Classify(const VecType& point) const
311 {
312  // Pass off to another Classify() overload.
313  size_t predictedClass;
314  arma::vec probabilities;
315  Classify(point, predictedClass, probabilities);
316 
317  return predictedClass;
318 }
319 
320 template<
321  typename FitnessFunction,
322  typename DimensionSelectionType,
323  template<typename> class NumericSplitType,
324  template<typename> class CategoricalSplitType,
325  bool UseBootstrap
326 >
327 template<typename VecType>
328 void RandomForest<
329  FitnessFunction,
330  DimensionSelectionType,
331  NumericSplitType,
332  CategoricalSplitType,
333  UseBootstrap
334 >::Classify(const VecType& point,
335  size_t& prediction,
336  arma::vec& probabilities) const
337 {
338  // Check edge case.
339  if (trees.size() == 0)
340  {
341  probabilities.clear();
342  prediction = 0;
343 
344  throw std::invalid_argument("RandomForest::Classify(): no random forest "
345  "trained!");
346  }
347 
348  probabilities.zeros(trees[0].NumClasses());
349  for (size_t i = 0; i < trees.size(); ++i)
350  {
351  arma::vec treeProbs;
352  size_t treePrediction; // Ignored.
353  trees[i].Classify(point, treePrediction, treeProbs);
354 
355  probabilities += treeProbs;
356  }
357 
358  // Find maximum element after renormalizing probabilities.
359  probabilities /= trees.size();
360  arma::uword maxIndex = 0;
361  probabilities.max(maxIndex);
362 
363  // Set prediction.
364  prediction = (size_t) maxIndex;
365 }
366 
367 template<
368  typename FitnessFunction,
369  typename DimensionSelectionType,
370  template<typename> class NumericSplitType,
371  template<typename> class CategoricalSplitType,
372  bool UseBootstrap
373 >
374 template<typename MatType>
375 void RandomForest<
376  FitnessFunction,
377  DimensionSelectionType,
378  NumericSplitType,
379  CategoricalSplitType,
380  UseBootstrap
381 >::Classify(const MatType& data,
382  arma::Row<size_t>& predictions) const
383 {
384  // Check edge case.
385  if (trees.size() == 0)
386  {
387  predictions.clear();
388 
389  throw std::invalid_argument("RandomForest::Classify(): no random forest "
390  "trained!");
391  }
392 
393  predictions.set_size(data.n_cols);
394 
395  #pragma omp parallel for
396  for (omp_size_t i = 0; i < data.n_cols; ++i)
397  {
398  predictions[i] = Classify(data.col(i));
399  }
400 }
401 
402 template<
403  typename FitnessFunction,
404  typename DimensionSelectionType,
405  template<typename> class NumericSplitType,
406  template<typename> class CategoricalSplitType,
407  bool UseBootstrap
408 >
409 template<typename MatType>
410 void RandomForest<
411  FitnessFunction,
412  DimensionSelectionType,
413  NumericSplitType,
414  CategoricalSplitType,
415  UseBootstrap
416 >::Classify(const MatType& data,
417  arma::Row<size_t>& predictions,
418  arma::mat& probabilities) const
419 {
420  // Check edge case.
421  if (trees.size() == 0)
422  {
423  predictions.clear();
424  probabilities.clear();
425 
426  throw std::invalid_argument("RandomForest::Classify(): no random forest "
427  "trained!");
428  }
429 
430  probabilities.set_size(trees[0].NumClasses(), data.n_cols);
431  predictions.set_size(data.n_cols);
432  #pragma omp parallel for
433  for (omp_size_t i = 0; i < data.n_cols; ++i)
434  {
435  arma::vec probs = probabilities.unsafe_col(i);
436  Classify(data.col(i), predictions[i], probs);
437  }
438 }
439 
440 template<
441  typename FitnessFunction,
442  typename DimensionSelectionType,
443  template<typename> class NumericSplitType,
444  template<typename> class CategoricalSplitType,
445  bool UseBootstrap
446 >
447 template<typename Archive>
448 void RandomForest<
449  FitnessFunction,
450  DimensionSelectionType,
451  NumericSplitType,
452  CategoricalSplitType,
453  UseBootstrap
454 >::serialize(Archive& ar, const uint32_t /* version */)
455 {
456  size_t numTrees;
457  if (cereal::is_loading<Archive>())
458  trees.clear();
459  else
460  numTrees = trees.size();
461 
462  ar(CEREAL_NVP(numTrees));
463 
464  // Allocate space if needed.
465  if (cereal::is_loading<Archive>())
466  trees.resize(numTrees);
467 
468  ar(CEREAL_NVP(trees));
469  ar(CEREAL_NVP(avgGain));
470 }
471 
472 template<
473  typename FitnessFunction,
474  typename DimensionSelectionType,
475  template<typename> class NumericSplitType,
476  template<typename> class CategoricalSplitType,
477  bool UseBootstrap
478 >
479 template<bool UseWeights, bool UseDatasetInfo, typename MatType>
480 double RandomForest<
481  FitnessFunction,
482  DimensionSelectionType,
483  NumericSplitType,
484  CategoricalSplitType,
485  UseBootstrap
486 >::Train(const MatType& dataset,
487  const data::DatasetInfo& datasetInfo,
488  const arma::Row<size_t>& labels,
489  const size_t numClasses,
490  const arma::rowvec& weights,
491  const size_t numTrees,
492  const size_t minimumLeafSize,
493  const double minimumGainSplit,
494  const size_t maximumDepth,
495  DimensionSelectionType& dimensionSelector,
496  const bool warmStart)
497 {
498  // Reset the forest if we are not doing a warm-start.
499  if (!warmStart)
500  trees.clear();
501  const size_t oldNumTrees = trees.size();
502  trees.resize(trees.size() + numTrees);
503 
504  // Convert avgGain to total gain.
505  double totalGain = avgGain * oldNumTrees;
506 
507  // Train each tree individually.
508  #pragma omp parallel for reduction( + : totalGain)
509  for (omp_size_t i = 0; i < numTrees; ++i)
510  {
511  MatType bootstrapDataset;
512  arma::Row<size_t> bootstrapLabels;
513  arma::rowvec bootstrapWeights;
514  if (UseBootstrap)
515  {
516  Timer::Start("bootstrap");
517  Bootstrap<UseWeights>(dataset, labels, weights, bootstrapDataset,
518  bootstrapLabels, bootstrapWeights);
519  Timer::Stop("bootstrap");
520  }
521 
522  Timer::Start("train_tree");
523  if (UseWeights)
524  {
525  if (UseDatasetInfo)
526  {
527  totalGain += UseBootstrap ?
528  trees[oldNumTrees + i].Train(bootstrapDataset, datasetInfo,
529  bootstrapLabels, numClasses, bootstrapWeights, minimumLeafSize,
530  minimumGainSplit, maximumDepth, dimensionSelector) :
531  trees[oldNumTrees + i].Train(dataset, datasetInfo, labels,
532  numClasses, weights, minimumLeafSize, minimumGainSplit,
533  maximumDepth, dimensionSelector);
534  }
535  else
536  {
537  totalGain += UseBootstrap ?
538  trees[oldNumTrees + i].Train(bootstrapDataset, bootstrapLabels,
539  numClasses, bootstrapWeights, minimumLeafSize,
540  minimumGainSplit, maximumDepth, dimensionSelector) :
541  trees[oldNumTrees + i].Train(dataset, labels, numClasses,
542  weights, minimumLeafSize, minimumGainSplit, maximumDepth,
543  dimensionSelector);
544  }
545  }
546  else
547  {
548  if (UseDatasetInfo)
549  {
550  totalGain += UseBootstrap ?
551  trees[oldNumTrees + i].Train(bootstrapDataset, datasetInfo,
552  bootstrapLabels, numClasses, minimumLeafSize, minimumGainSplit,
553  maximumDepth, dimensionSelector) :
554  trees[oldNumTrees + i].Train(dataset, datasetInfo, labels,
555  numClasses, minimumLeafSize, minimumGainSplit, maximumDepth,
556  dimensionSelector);
557  }
558  else
559  {
560  totalGain += UseBootstrap ?
561  trees[oldNumTrees + i].Train(bootstrapDataset, bootstrapLabels,
562  numClasses, minimumLeafSize, minimumGainSplit, maximumDepth,
563  dimensionSelector) :
564  trees[oldNumTrees + i].Train(dataset, labels, numClasses,
565  minimumLeafSize, minimumGainSplit, maximumDepth,
566  dimensionSelector);
567  }
568  }
569 
570  Timer::Stop("train_tree");
571  }
572 
573  avgGain = totalGain / trees.size();
574  return avgGain;
575 }
576 
577 } // namespace tree
578 } // namespace mlpack
579 
580 #endif
The RandomForest class provides an implementation of random forests, described in Breiman's seminal p...
Definition: random_forest.hpp:44
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
Definition: dataset_mapper.hpp:41
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
RandomForest()
Construct the random forest without any training or specifying the number of trees.
Definition: random_forest_impl.hpp:34
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t numTrees=20, const size_t minimumLeafSize=1, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, const bool warmStart=false, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the random forest on the given labeled training data with the given number of trees...
Definition: random_forest_impl.hpp:182
size_t Classify(const VecType &point) const
Predict the class of the given point.
Definition: random_forest_impl.hpp:310
Definition: hmm_train_main.cpp:300
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
void serialize(Archive &ar, const uint32_t)
Serialize the random forest.
Definition: random_forest_impl.hpp:454