12 #ifndef MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_IMPL_HPP 13 #define MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_IMPL_HPP 22 typename FitnessFunction,
23 typename DimensionSelectionType,
24 template<
typename>
class NumericSplitType,
25 template<
typename>
class CategoricalSplitType,
30 DimensionSelectionType,
41 typename FitnessFunction,
42 typename DimensionSelectionType,
43 template<
typename>
class NumericSplitType,
44 template<
typename>
class CategoricalSplitType,
47 template<
typename MatType>
50 DimensionSelectionType,
55 const arma::Row<size_t>& labels,
56 const size_t numClasses,
57 const size_t numTrees,
58 const size_t minimumLeafSize,
59 const double minimumGainSplit,
60 const size_t maximumDepth,
61 DimensionSelectionType dimensionSelector) :
68 minimumLeafSize, minimumGainSplit, maximumDepth, dimensionSelector,
73 typename FitnessFunction,
74 typename DimensionSelectionType,
75 template<
typename>
class NumericSplitType,
76 template<
typename>
class CategoricalSplitType,
79 template<
typename MatType>
82 DimensionSelectionType,
88 const arma::Row<size_t>& labels,
89 const size_t numClasses,
90 const size_t numTrees,
91 const size_t minimumLeafSize,
92 const double minimumGainSplit,
93 const size_t maximumDepth,
94 DimensionSelectionType dimensionSelector):
100 numTrees, minimumLeafSize, minimumGainSplit, maximumDepth,
101 dimensionSelector,
false);
105 typename FitnessFunction,
106 typename DimensionSelectionType,
107 template<
typename>
class NumericSplitType,
108 template<
typename>
class CategoricalSplitType,
111 template<
typename MatType>
114 DimensionSelectionType,
116 CategoricalSplitType,
119 const arma::Row<size_t>& labels,
120 const size_t numClasses,
121 const arma::rowvec& weights,
122 const size_t numTrees,
123 const size_t minimumLeafSize,
124 const double minimumGainSplit,
125 const size_t maximumDepth,
126 DimensionSelectionType dimensionSelector) :
132 minimumLeafSize, minimumGainSplit, maximumDepth, dimensionSelector,
137 typename FitnessFunction,
138 typename DimensionSelectionType,
139 template<
typename>
class NumericSplitType,
140 template<
typename>
class CategoricalSplitType,
143 template<
typename MatType>
146 DimensionSelectionType,
148 CategoricalSplitType,
152 const arma::Row<size_t>& labels,
153 const size_t numClasses,
154 const arma::rowvec& weights,
155 const size_t numTrees,
156 const size_t minimumLeafSize,
157 const double minimumGainSplit,
158 const size_t maximumDepth,
159 DimensionSelectionType dimensionSelector) :
164 numTrees, minimumLeafSize, minimumGainSplit, maximumDepth,
165 dimensionSelector,
false);
169 typename FitnessFunction,
170 typename DimensionSelectionType,
171 template<
typename>
class NumericSplitType,
172 template<
typename>
class CategoricalSplitType,
175 template<
typename MatType>
178 DimensionSelectionType,
180 CategoricalSplitType,
183 const arma::Row<size_t>& labels,
184 const size_t numClasses,
185 const size_t numTrees,
186 const size_t minimumLeafSize,
187 const double minimumGainSplit,
188 const size_t maximumDepth,
189 const bool warmStart,
190 DimensionSelectionType dimensionSelector)
194 arma::rowvec weights;
196 numTrees, minimumLeafSize, minimumGainSplit, maximumDepth,
197 dimensionSelector, warmStart);
201 typename FitnessFunction,
202 typename DimensionSelectionType,
203 template<
typename>
class NumericSplitType,
204 template<
typename>
class CategoricalSplitType,
207 template<
typename MatType>
210 DimensionSelectionType,
212 CategoricalSplitType,
216 const arma::Row<size_t>& labels,
217 const size_t numClasses,
218 const size_t numTrees,
219 const size_t minimumLeafSize,
220 const double minimumGainSplit,
221 const size_t maximumDepth,
222 const bool warmStart,
223 DimensionSelectionType dimensionSelector)
226 arma::rowvec weights;
228 numTrees, minimumLeafSize, minimumGainSplit, maximumDepth,
229 dimensionSelector, warmStart);
233 typename FitnessFunction,
234 typename DimensionSelectionType,
235 template<
typename>
class NumericSplitType,
236 template<
typename>
class CategoricalSplitType,
239 template<
typename MatType>
242 DimensionSelectionType,
244 CategoricalSplitType,
247 const arma::Row<size_t>& labels,
248 const size_t numClasses,
249 const arma::rowvec& weights,
250 const size_t numTrees,
251 const size_t minimumLeafSize,
252 const double minimumGainSplit,
253 const size_t maximumDepth,
254 const bool warmStart,
255 DimensionSelectionType dimensionSelector)
260 numTrees, minimumLeafSize, minimumGainSplit, maximumDepth,
261 dimensionSelector, warmStart);
265 typename FitnessFunction,
266 typename DimensionSelectionType,
267 template<
typename>
class NumericSplitType,
268 template<
typename>
class CategoricalSplitType,
271 template<
typename MatType>
274 DimensionSelectionType,
276 CategoricalSplitType,
280 const arma::Row<size_t>& labels,
281 const size_t numClasses,
282 const arma::rowvec& weights,
283 const size_t numTrees,
284 const size_t minimumLeafSize,
285 const double minimumGainSplit,
286 const size_t maximumDepth,
287 const bool warmStart,
288 DimensionSelectionType dimensionSelector)
292 numTrees, minimumLeafSize, minimumGainSplit, maximumDepth,
293 dimensionSelector, warmStart);
297 typename FitnessFunction,
298 typename DimensionSelectionType,
299 template<
typename>
class NumericSplitType,
300 template<
typename>
class CategoricalSplitType,
303 template<
typename VecType>
306 DimensionSelectionType,
308 CategoricalSplitType,
313 size_t predictedClass;
314 arma::vec probabilities;
315 Classify(point, predictedClass, probabilities);
317 return predictedClass;
321 typename FitnessFunction,
322 typename DimensionSelectionType,
323 template<
typename>
class NumericSplitType,
324 template<
typename>
class CategoricalSplitType,
327 template<
typename VecType>
330 DimensionSelectionType,
332 CategoricalSplitType,
336 arma::vec& probabilities)
const 339 if (trees.size() == 0)
341 probabilities.clear();
344 throw std::invalid_argument(
"RandomForest::Classify(): no random forest " 348 probabilities.zeros(trees[0].NumClasses());
349 for (
size_t i = 0; i < trees.size(); ++i)
352 size_t treePrediction;
353 trees[i].Classify(point, treePrediction, treeProbs);
355 probabilities += treeProbs;
359 probabilities /= trees.size();
360 arma::uword maxIndex = 0;
361 probabilities.max(maxIndex);
364 prediction = (size_t) maxIndex;
368 typename FitnessFunction,
369 typename DimensionSelectionType,
370 template<
typename>
class NumericSplitType,
371 template<
typename>
class CategoricalSplitType,
374 template<
typename MatType>
377 DimensionSelectionType,
379 CategoricalSplitType,
382 arma::Row<size_t>& predictions)
const 385 if (trees.size() == 0)
389 throw std::invalid_argument(
"RandomForest::Classify(): no random forest " 393 predictions.set_size(data.n_cols);
395 #pragma omp parallel for 396 for (omp_size_t i = 0; i < data.n_cols; ++i)
398 predictions[i] =
Classify(data.col(i));
403 typename FitnessFunction,
404 typename DimensionSelectionType,
405 template<
typename>
class NumericSplitType,
406 template<
typename>
class CategoricalSplitType,
409 template<
typename MatType>
412 DimensionSelectionType,
414 CategoricalSplitType,
417 arma::Row<size_t>& predictions,
418 arma::mat& probabilities)
const 421 if (trees.size() == 0)
424 probabilities.clear();
426 throw std::invalid_argument(
"RandomForest::Classify(): no random forest " 430 probabilities.set_size(trees[0].NumClasses(), data.n_cols);
431 predictions.set_size(data.n_cols);
432 #pragma omp parallel for 433 for (omp_size_t i = 0; i < data.n_cols; ++i)
435 arma::vec probs = probabilities.unsafe_col(i);
436 Classify(data.col(i), predictions[i], probs);
441 typename FitnessFunction,
442 typename DimensionSelectionType,
443 template<
typename>
class NumericSplitType,
444 template<
typename>
class CategoricalSplitType,
447 template<
typename Archive>
450 DimensionSelectionType,
452 CategoricalSplitType,
457 if (cereal::is_loading<Archive>())
460 numTrees = trees.size();
462 ar(CEREAL_NVP(numTrees));
465 if (cereal::is_loading<Archive>())
466 trees.resize(numTrees);
468 ar(CEREAL_NVP(trees));
469 ar(CEREAL_NVP(avgGain));
473 typename FitnessFunction,
474 typename DimensionSelectionType,
475 template<
typename>
class NumericSplitType,
476 template<
typename>
class CategoricalSplitType,
479 template<
bool UseWeights,
bool UseDatasetInfo,
typename MatType>
482 DimensionSelectionType,
484 CategoricalSplitType,
486 >
::Train(
const MatType& dataset,
488 const arma::Row<size_t>& labels,
489 const size_t numClasses,
490 const arma::rowvec& weights,
491 const size_t numTrees,
492 const size_t minimumLeafSize,
493 const double minimumGainSplit,
494 const size_t maximumDepth,
495 DimensionSelectionType& dimensionSelector,
496 const bool warmStart)
501 const size_t oldNumTrees = trees.size();
502 trees.resize(trees.size() + numTrees);
505 double totalGain = avgGain * oldNumTrees;
508 #pragma omp parallel for reduction( + : totalGain) 509 for (omp_size_t i = 0; i < numTrees; ++i)
511 MatType bootstrapDataset;
512 arma::Row<size_t> bootstrapLabels;
513 arma::rowvec bootstrapWeights;
517 Bootstrap<UseWeights>(dataset, labels, weights, bootstrapDataset,
518 bootstrapLabels, bootstrapWeights);
527 totalGain += UseBootstrap ?
528 trees[oldNumTrees + i].Train(bootstrapDataset, datasetInfo,
529 bootstrapLabels, numClasses, bootstrapWeights, minimumLeafSize,
530 minimumGainSplit, maximumDepth, dimensionSelector) :
531 trees[oldNumTrees + i].Train(dataset, datasetInfo, labels,
532 numClasses, weights, minimumLeafSize, minimumGainSplit,
533 maximumDepth, dimensionSelector);
537 totalGain += UseBootstrap ?
538 trees[oldNumTrees + i].Train(bootstrapDataset, bootstrapLabels,
539 numClasses, bootstrapWeights, minimumLeafSize,
540 minimumGainSplit, maximumDepth, dimensionSelector) :
541 trees[oldNumTrees + i].Train(dataset, labels, numClasses,
542 weights, minimumLeafSize, minimumGainSplit, maximumDepth,
550 totalGain += UseBootstrap ?
551 trees[oldNumTrees + i].Train(bootstrapDataset, datasetInfo,
552 bootstrapLabels, numClasses, minimumLeafSize, minimumGainSplit,
553 maximumDepth, dimensionSelector) :
554 trees[oldNumTrees + i].Train(dataset, datasetInfo, labels,
555 numClasses, minimumLeafSize, minimumGainSplit, maximumDepth,
560 totalGain += UseBootstrap ?
561 trees[oldNumTrees + i].Train(bootstrapDataset, bootstrapLabels,
562 numClasses, minimumLeafSize, minimumGainSplit, maximumDepth,
564 trees[oldNumTrees + i].Train(dataset, labels, numClasses,
565 minimumLeafSize, minimumGainSplit, maximumDepth,
573 avgGain = totalGain / trees.size();
The RandomForest class provides an implementation of random forests, described in Breiman's seminal paper.
Definition: random_forest.hpp:44
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
Definition: dataset_mapper.hpp:41
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
RandomForest()
Construct the random forest without any training or specifying the number of trees.
Definition: random_forest_impl.hpp:34
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t numTrees=20, const size_t minimumLeafSize=1, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, const bool warmStart=false, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the random forest on the given labeled training data with the given number of trees...
Definition: random_forest_impl.hpp:182
size_t Classify(const VecType &point) const
Predict the class of the given point.
Definition: random_forest_impl.hpp:310
Definition: hmm_train_main.cpp:300
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
void serialize(Archive &ar, const uint32_t)
Serialize the random forest.
Definition: random_forest_impl.hpp:454