mlpack
decision_tree_regressor_impl.hpp
1 
12 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_REGRESSOR_IMPL_HPP
13 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_REGRESSOR_IMPL_HPP
14 
15 #include "decision_tree_regressor.hpp"
16 #include "utils.hpp"
17 
18 namespace mlpack {
19 namespace tree {
20 
22 template<typename FitnessFunction,
23  template<typename> class NumericSplitType,
24  template<typename> class CategoricalSplitType,
25  typename DimensionSelectionType,
26  bool NoRecursion>
27 DecisionTreeRegressor<FitnessFunction,
28  NumericSplitType,
29  CategoricalSplitType,
30  DimensionSelectionType,
31  NoRecursion>::DecisionTreeRegressor() :
32  splitDimension(0),
33  dimensionType(0),
34  splitPointOrPrediction(0.0)
35 {
36  // Nothing to do here.
37 }
38 
40 template<typename FitnessFunction,
41  template<typename> class NumericSplitType,
42  template<typename> class CategoricalSplitType,
43  typename DimensionSelectionType,
44  bool NoRecursion>
45 template<typename MatType, typename ResponsesType>
46 DecisionTreeRegressor<FitnessFunction,
47  NumericSplitType,
48  CategoricalSplitType,
49  DimensionSelectionType,
50  NoRecursion>::DecisionTreeRegressor(
51  MatType data,
52  const data::DatasetInfo& datasetInfo,
53  ResponsesType responses,
54  const size_t minimumLeafSize,
55  const double minimumGainSplit,
56  const size_t maximumDepth,
57  DimensionSelectionType dimensionSelector)
58 {
59  using TrueMatType = typename std::decay<MatType>::type;
60  using TrueResponsesType = typename std::decay<ResponsesType>::type;
61 
62  // Copy or move data.
63  TrueMatType tmpData(std::move(data));
64  TrueResponsesType tmpResponses(std::move(responses));
65 
66  // Set the correct dimensionality for the dimension selector.
67  dimensionSelector.Dimensions() = tmpData.n_rows;
68 
69  // Pass off work to the Train() method.
70  arma::rowvec weights; // Fake weights, not used.
71  Train<false>(tmpData, 0, tmpData.n_cols, datasetInfo, tmpResponses,
72  weights, minimumLeafSize, minimumGainSplit, maximumDepth,
73  dimensionSelector);
74 }
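// Example (not part of the original source): a minimal sketch of using the
// constructor above on mixed categorical/numeric data. It assumes the default
// template arguments declared in decision_tree_regressor.hpp and that `info`
// has been filled in (e.g. by data::Load()) so it knows which dimensions are
// categorical; all variable names here are illustrative.
//
//   arma::mat dataset;        // one row per dimension, one column per point
//   arma::rowvec responses;   // one response value per column of `dataset`
//   data::DatasetInfo info;   // dimension types, typically set while loading
//   // ... load dataset, responses, and info ...
//   DecisionTreeRegressor<> tree(dataset, info, responses,
//       10 /* minimumLeafSize */, 1e-7 /* minimumGainSplit */);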
75 
77 template<typename FitnessFunction,
78  template<typename> class NumericSplitType,
79  template<typename> class CategoricalSplitType,
80  typename DimensionSelectionType,
81  bool NoRecursion>
82 template<typename MatType, typename ResponsesType>
83 DecisionTreeRegressor<FitnessFunction,
84  NumericSplitType,
85  CategoricalSplitType,
86  DimensionSelectionType,
87  NoRecursion>::DecisionTreeRegressor(
88  MatType data,
89  ResponsesType responses,
90  const size_t minimumLeafSize,
91  const double minimumGainSplit,
92  const size_t maximumDepth,
93  DimensionSelectionType dimensionSelector)
94 {
95  using TrueMatType = typename std::decay<MatType>::type;
96  using TrueResponsesType = typename std::decay<ResponsesType>::type;
97 
98  // Copy or move data.
99  TrueMatType tmpData(std::move(data));
100  TrueResponsesType tmpResponses(std::move(responses));
101 
102  // Set the correct dimensionality for the dimension selector.
103  dimensionSelector.Dimensions() = tmpData.n_rows;
104 
105  // Pass off work to the Train() method.
106  arma::rowvec weights; // Fake weights, not used.
107  Train<false>(tmpData, 0, tmpData.n_cols, tmpResponses, weights,
108  minimumLeafSize, minimumGainSplit, maximumDepth, dimensionSelector);
109 }
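// Example (not part of the original source): the all-numeric overload above
// needs no DatasetInfo. A minimal sketch, again assuming the default template
// arguments from decision_tree_regressor.hpp:
//
//   arma::mat dataset(5, 1000, arma::fill::randu);    // 5 dimensions, 1000 points
//   arma::rowvec responses(1000, arma::fill::randu);  // one response per point
//   DecisionTreeRegressor<> tree(dataset, responses, 10 /* minimumLeafSize */);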
110 
112 template<typename FitnessFunction,
113  template<typename> class NumericSplitType,
114  template<typename> class CategoricalSplitType,
115  typename DimensionSelectionType,
116  bool NoRecursion>
117 template<typename MatType, typename ResponsesType, typename WeightsType>
118 DecisionTreeRegressor<FitnessFunction,
119  NumericSplitType,
120  CategoricalSplitType,
121  DimensionSelectionType,
122  NoRecursion>::DecisionTreeRegressor(
123  MatType data,
124  const data::DatasetInfo& datasetInfo,
125  ResponsesType responses,
126  WeightsType weights,
127  const size_t minimumLeafSize,
128  const double minimumGainSplit,
129  const size_t maximumDepth,
130  DimensionSelectionType dimensionSelector,
131  const std::enable_if_t<arma::is_arma_type<
132  typename std::remove_reference<WeightsType>::type>::value>*)
133 {
134  using TrueMatType = typename std::decay<MatType>::type;
135  using TrueResponsesType = typename std::decay<ResponsesType>::type;
136  using TrueWeightsType = typename std::decay<WeightsType>::type;
137 
138  TrueMatType tmpData(std::move(data));
139  TrueResponsesType tmpResponses(std::move(responses));
140  TrueWeightsType tmpWeights(std::move(weights));
141 
142  // Set the correct dimensionality for the dimension selector.
143  dimensionSelector.Dimensions() = tmpData.n_rows;
144 
145  // Pass off work to the weighted Train() method.
146  Train<true>(tmpData, 0, tmpData.n_cols, datasetInfo, tmpResponses,
147  tmpWeights, minimumLeafSize, minimumGainSplit, maximumDepth,
148  dimensionSelector);
149 }
150 
152 template<typename FitnessFunction,
153  template<typename> class NumericSplitType,
154  template<typename> class CategoricalSplitType,
155  typename DimensionSelectionType,
156  bool NoRecursion>
157 template<typename MatType, typename ResponsesType, typename WeightsType>
158 DecisionTreeRegressor<FitnessFunction,
159  NumericSplitType,
160  CategoricalSplitType,
161  DimensionSelectionType,
162  NoRecursion>::DecisionTreeRegressor(
163  MatType data,
164  ResponsesType responses,
165  WeightsType weights,
166  const size_t minimumLeafSize,
167  const double minimumGainSplit,
168  const size_t maximumDepth,
169  DimensionSelectionType dimensionSelector,
170  const std::enable_if_t<
171  arma::is_arma_type<
172  typename std::remove_reference<
173  WeightsType>::type>::value>*)
174 {
175  using TrueMatType = typename std::decay<MatType>::type;
176  using TrueResponsesType = typename std::decay<ResponsesType>::type;
177  using TrueWeightsType = typename std::decay<WeightsType>::type;
178 
179  // Copy or move data.
180  TrueMatType tmpData(std::move(data));
181  TrueResponsesType tmpResponses(std::move(responses));
182  TrueWeightsType tmpWeights(std::move(weights));
183 
184  // Set the correct dimensionality for the dimension selector.
185  dimensionSelector.Dimensions() = tmpData.n_rows;
186 
187  // Pass off work to the weighted Train() method.
188  Train<true>(tmpData, 0, tmpData.n_cols, tmpResponses, tmpWeights,
189  minimumLeafSize, minimumGainSplit, maximumDepth, dimensionSelector);
190 }
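// Example (not part of the original source): a sketch of weighted training
// with the overload above, where each column carries an instance weight.
// Assumes the default template arguments; names are illustrative.
//
//   arma::mat dataset(5, 1000, arma::fill::randu);
//   arma::rowvec responses(1000, arma::fill::randu);
//   arma::rowvec weights(1000, arma::fill::ones);
//   weights.head(100) *= 5.0;  // e.g. emphasize the first 100 points
//   DecisionTreeRegressor<> tree(dataset, responses, weights,
//       10 /* minimumLeafSize */, 1e-7 /* minimumGainSplit */);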
191 
193 template<typename FitnessFunction,
194  template<typename> class NumericSplitType,
195  template<typename> class CategoricalSplitType,
196  typename DimensionSelectionType,
197  bool NoRecursion>
198 template<typename MatType, typename ResponsesType, typename WeightsType>
199 DecisionTreeRegressor<FitnessFunction,
200  NumericSplitType,
201  CategoricalSplitType,
202  DimensionSelectionType,
203  NoRecursion>::DecisionTreeRegressor(
204  const DecisionTreeRegressor& other,
205  MatType data,
206  const data::DatasetInfo& datasetInfo,
207  ResponsesType responses,
208  WeightsType weights,
209  const size_t minimumLeafSize,
210  const double minimumGainSplit,
211  const std::enable_if_t<arma::is_arma_type<
212  typename std::remove_reference<WeightsType>::type>::value>*):
213  NumericAuxiliarySplitInfo(other),
214  CategoricalAuxiliarySplitInfo(other)
215 {
216  using TrueMatType = typename std::decay<MatType>::type;
217  using TrueResponsesType = typename std::decay<ResponsesType>::type;
218  using TrueWeightsType = typename std::decay<WeightsType>::type;
219 
220  // Copy or move data.
221  TrueMatType tmpData(std::move(data));
222  TrueResponsesType tmpResponses(std::move(responses));
223  TrueWeightsType tmpWeights(std::move(weights));
224 
225  // Pass off work to the weighted Train() method.
226  Train<true>(tmpData, 0, tmpData.n_cols, datasetInfo, tmpResponses,
227  tmpWeights, minimumLeafSize, minimumGainSplit);
228 }
229 
231 template<typename FitnessFunction,
232  template<typename> class NumericSplitType,
233  template<typename> class CategoricalSplitType,
234  typename DimensionSelectionType,
235  bool NoRecursion>
236 template<typename MatType, typename ResponsesType, typename WeightsType>
237 DecisionTreeRegressor<FitnessFunction,
238  NumericSplitType,
239  CategoricalSplitType,
240  DimensionSelectionType,
241  NoRecursion>::DecisionTreeRegressor(
242  const DecisionTreeRegressor& other,
243  MatType data,
244  ResponsesType responses,
245  WeightsType weights,
246  const size_t minimumLeafSize,
247  const double minimumGainSplit,
248  const size_t maximumDepth,
249  DimensionSelectionType dimensionSelector,
250  const std::enable_if_t<arma::is_arma_type<
251  typename std::remove_reference<
252  WeightsType>::type>::value>*):
253  NumericAuxiliarySplitInfo(other),
254  CategoricalAuxiliarySplitInfo(other) // The other tree's auxiliary info does need to be copied.
255 {
256  using TrueMatType = typename std::decay<MatType>::type;
257  using TrueResponsesType = typename std::decay<ResponsesType>::type;
258  using TrueWeightsType = typename std::decay<WeightsType>::type;
259 
260  // Copy or move data.
261  TrueMatType tmpData(std::move(data));
262  TrueResponsesType tmpResponses(std::move(responses));
263  TrueWeightsType tmpWeights(std::move(weights));
264 
265  // Set the correct dimensionality for the dimension selector.
266  dimensionSelector.Dimensions() = tmpData.n_rows;
267 
268  // Pass off work to the weighted Train() method.
269  Train<true>(tmpData, 0, tmpData.n_cols, tmpResponses, tmpWeights,
270  minimumLeafSize, minimumGainSplit, maximumDepth, dimensionSelector);
271 }
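// Example (not part of the original source): the two constructors above copy
// the auxiliary split information of an existing tree before training on new
// weighted data. A sketch, assuming `oldTree` was trained earlier and the
// default template arguments are used; all names are illustrative.
//
//   DecisionTreeRegressor<> newTree(oldTree, newData, newResponses, newWeights,
//       10 /* minimumLeafSize */, 1e-7 /* minimumGainSplit */);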
272 
274 template<typename FitnessFunction,
275  template<typename> class NumericSplitType,
276  template<typename> class CategoricalSplitType,
277  typename DimensionSelectionType,
278  bool NoRecursion>
279 DecisionTreeRegressor<FitnessFunction,
280  NumericSplitType,
281  CategoricalSplitType,
282  DimensionSelectionType,
283  NoRecursion
284  >::DecisionTreeRegressor(
285  const DecisionTreeRegressor& other) :
286  NumericAuxiliarySplitInfo(other),
287  CategoricalAuxiliarySplitInfo(other),
288  splitDimension(other.splitDimension),
289  dimensionType(other.dimensionType),
290  splitPointOrPrediction(other.splitPointOrPrediction)
291 {
292  // Copy each child.
293  for (size_t i = 0; i < other.children.size(); ++i)
294  children.push_back(new DecisionTreeRegressor(*other.children[i]));
295 }
296 
298 template<typename FitnessFunction,
299  template<typename> class NumericSplitType,
300  template<typename> class CategoricalSplitType,
301  typename DimensionSelectionType,
302  bool NoRecursion>
303 DecisionTreeRegressor<FitnessFunction,
304  NumericSplitType,
305  CategoricalSplitType,
306  DimensionSelectionType,
307  NoRecursion
308  >::DecisionTreeRegressor(
309  DecisionTreeRegressor&& other) :
310  NumericAuxiliarySplitInfo(std::move(other)),
311  CategoricalAuxiliarySplitInfo(std::move(other)),
312  children(std::move(other.children)),
313  splitDimension(other.splitDimension),
314  dimensionType(other.dimensionType),
315  splitPointOrPrediction(other.splitPointOrPrediction)
316 {
317  // Nothing to do here.
318 }
319 
321 template<typename FitnessFunction,
322  template<typename> class NumericSplitType,
323  template<typename> class CategoricalSplitType,
324  typename DimensionSelectionType,
325  bool NoRecursion>
326 DecisionTreeRegressor<FitnessFunction,
327  NumericSplitType,
328  CategoricalSplitType,
329  DimensionSelectionType,
330  NoRecursion>&
331 DecisionTreeRegressor<FitnessFunction,
332  NumericSplitType,
333  CategoricalSplitType,
334  DimensionSelectionType,
335  NoRecursion
336 >::operator=(const DecisionTreeRegressor& other)
337 {
338  if (this == &other)
339  return *this; // Nothing to copy.
340 
341  // Clean memory if needed.
342  for (size_t i = 0; i < children.size(); ++i)
343  delete children[i];
344  children.clear();
345 
346  // Copy everything from the other tree.
347  splitDimension = other.splitDimension;
348  dimensionType = other.dimensionType;
349  splitPointOrPrediction = other.splitPointOrPrediction;
350 
351  // Copy the children.
352  for (size_t i = 0; i < other.children.size(); ++i)
353  children.push_back(new DecisionTreeRegressor(*other.children[i]));
354 
355  // Copy the auxiliary info.
356  NumericAuxiliarySplitInfo::operator=(other);
357  CategoricalAuxiliarySplitInfo::operator=(other);
358 
359  return *this;
360 }
361 
363 template<typename FitnessFunction,
364  template<typename> class NumericSplitType,
365  template<typename> class CategoricalSplitType,
366  typename DimensionSelectionType,
367  bool NoRecursion>
368 DecisionTreeRegressor<FitnessFunction,
369  NumericSplitType,
370  CategoricalSplitType,
371  DimensionSelectionType,
372  NoRecursion>&
373 DecisionTreeRegressor<FitnessFunction,
374  NumericSplitType,
375  CategoricalSplitType,
376  DimensionSelectionType,
377  NoRecursion
378 >::operator=(DecisionTreeRegressor&& other)
379 {
380  if (this == &other)
381  return *this; // Nothing to move.
382 
383  // Clean memory if needed.
384  for (size_t i = 0; i < children.size(); ++i)
385  delete children[i];
386  children.clear();
387 
388  // Take ownership of the other tree's components.
389  children = std::move(other.children);
390  splitDimension = other.splitDimension;
391  dimensionType = other.dimensionType;
392  splitPointOrPrediction = other.splitPointOrPrediction;
393 
394  // Take ownership of the auxiliary info.
395  NumericAuxiliarySplitInfo::operator=(std::move(other));
396  CategoricalAuxiliarySplitInfo::operator=(std::move(other));
397 
398  return *this;
399 }
400 
402 template<typename FitnessFunction,
403  template<typename> class NumericSplitType,
404  template<typename> class CategoricalSplitType,
405  typename DimensionSelectionType,
406  bool NoRecursion>
407 DecisionTreeRegressor<FitnessFunction,
408  NumericSplitType,
409  CategoricalSplitType,
410  DimensionSelectionType,
411  NoRecursion>::~DecisionTreeRegressor()
412 {
413  for (size_t i = 0; i < children.size(); ++i)
414  delete children[i];
415 }
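// Example (not part of the original source): the special members above give
// the tree deep-copy and ownership-transfer semantics for its child pointers.
// A sketch:
//
//   DecisionTreeRegressor<> a(dataset, responses);   // trained tree
//   DecisionTreeRegressor<> b(a);                    // deep copy: children are cloned
//   DecisionTreeRegressor<> c(std::move(a));         // move: c takes over a's children
//   b = c;                                           // copy assignment also clones children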
416 
418 template<typename FitnessFunction,
419  template<typename> class NumericSplitType,
420  template<typename> class CategoricalSplitType,
421  typename DimensionSelectionType,
422  bool NoRecursion>
423 template<typename MatType, typename ResponsesType>
424 double DecisionTreeRegressor<FitnessFunction,
425  NumericSplitType,
426  CategoricalSplitType,
427  DimensionSelectionType,
428  NoRecursion>::Train(
429  MatType data,
430  const data::DatasetInfo& datasetInfo,
431  ResponsesType responses,
432  const size_t minimumLeafSize,
433  const double minimumGainSplit,
434  const size_t maximumDepth,
435  DimensionSelectionType dimensionSelector)
436 {
437  // Sanity check on data.
438  util::CheckSameSizes(data, responses, "DecisionTreeRegressor::Train()");
439 
440  using TrueMatType = typename std::decay<MatType>::type;
441  using TrueResponsesType = typename std::decay<ResponsesType>::type;
442 
443  // Copy or move data.
444  TrueMatType tmpData(std::move(data));
445  TrueResponsesType tmpResponses(std::move(responses));
446 
447  // Set the correct dimensionality for the dimension selector.
448  dimensionSelector.Dimensions() = tmpData.n_rows;
449 
450  // Pass off work to the Train() method.
451  arma::rowvec weights; // Fake weights, not used.
452  return Train<false>(tmpData, 0, tmpData.n_cols, datasetInfo, tmpResponses,
453  weights, minimumLeafSize, minimumGainSplit, maximumDepth,
454  dimensionSelector);
455 }
456 
458 template<typename FitnessFunction,
459  template<typename> class NumericSplitType,
460  template<typename> class CategoricalSplitType,
461  typename DimensionSelectionType,
462  bool NoRecursion>
463 template<typename MatType, typename ResponsesType>
464 double DecisionTreeRegressor<FitnessFunction,
465  NumericSplitType,
466  CategoricalSplitType,
467  DimensionSelectionType,
468  NoRecursion>::Train(
469  MatType data,
470  ResponsesType responses,
471  const size_t minimumLeafSize,
472  const double minimumGainSplit,
473  const size_t maximumDepth,
474  DimensionSelectionType dimensionSelector)
475 {
476  // Sanity check on data.
477  util::CheckSameSizes(data, responses, "DecisionTreeRegressor::Train()");
478 
479  using TrueMatType = typename std::decay<MatType>::type;
480  using TrueResponsesType = typename std::decay<ResponsesType>::type;
481 
482  // Copy or move data.
483  TrueMatType tmpData(std::move(data));
484  TrueResponsesType tmpResponses(std::move(responses));
485 
486  // Set the correct dimensionality for the dimension selector.
487  dimensionSelector.Dimensions() = tmpData.n_rows;
488 
489  // Pass off work to the Train() method.
490  arma::rowvec weights; // Fake weights, not used.
491  return Train<false>(tmpData, 0, tmpData.n_cols, tmpResponses,
492  weights, minimumLeafSize, minimumGainSplit, maximumDepth,
493  dimensionSelector);
494 }
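// Example (not part of the original source): Train() can also be called on an
// already-constructed (or previously trained) tree; any existing children are
// deleted first, and the returned double is the -bestGain value produced by
// the recursive Train() below. A sketch using the overload above:
//
//   DecisionTreeRegressor<> tree;   // built with the default constructor
//   const double obj = tree.Train(dataset, responses, 10 /* minimumLeafSize */);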
495 
497 template<typename FitnessFunction,
498  template<typename> class NumericSplitType,
499  template<typename> class CategoricalSplitType,
500  typename DimensionSelectionType,
501  bool NoRecursion>
502 template<typename MatType, typename ResponsesType, typename WeightsType>
503 double DecisionTreeRegressor<FitnessFunction,
504  NumericSplitType,
505  CategoricalSplitType,
506  DimensionSelectionType,
507  NoRecursion>::Train(
508  MatType data,
509  const data::DatasetInfo& datasetInfo,
510  ResponsesType responses,
511  WeightsType weights,
512  const size_t minimumLeafSize,
513  const double minimumGainSplit,
514  const size_t maximumDepth,
515  DimensionSelectionType dimensionSelector,
516  const std::enable_if_t<
517  arma::is_arma_type<
518  typename std::remove_reference<
519  WeightsType>::type>::value>*)
520 {
521  // Sanity check on data.
522  util::CheckSameSizes(data, responses, "DecisionTreeRegressor::Train()");
523 
524  using TrueMatType = typename std::decay<MatType>::type;
525  using TrueResponsesType = typename std::decay<ResponsesType>::type;
526  using TrueWeightsType = typename std::decay<WeightsType>::type;
527 
528  // Copy or move data.
529  TrueMatType tmpData(std::move(data));
530  TrueResponsesType tmpResponses(std::move(responses));
531  TrueWeightsType tmpWeights(std::move(weights));
532 
533  // Set the correct dimensionality for the dimension selector.
534  dimensionSelector.Dimensions() = tmpData.n_rows;
535 
536  // Pass off work to the Train() method.
537  return Train<true>(tmpData, 0, tmpData.n_cols, datasetInfo, tmpResponses,
538  tmpWeights, minimumLeafSize, minimumGainSplit, maximumDepth,
539  dimensionSelector);
540 }
541 
543 template<typename FitnessFunction,
544  template<typename> class NumericSplitType,
545  template<typename> class CategoricalSplitType,
546  typename DimensionSelectionType,
547  bool NoRecursion>
548 template<typename MatType, typename ResponsesType, typename WeightsType>
549 double DecisionTreeRegressor<FitnessFunction,
550  NumericSplitType,
551  CategoricalSplitType,
552  DimensionSelectionType,
553  NoRecursion>::Train(
554  MatType data,
555  ResponsesType responses,
556  WeightsType weights,
557  const size_t minimumLeafSize,
558  const double minimumGainSplit,
559  const size_t maximumDepth,
560  DimensionSelectionType dimensionSelector,
561  const std::enable_if_t<
562  arma::is_arma_type<
563  typename std::remove_reference<
564  WeightsType>::type>::value>*)
565 {
566  // Sanity check on data.
567  util::CheckSameSizes(data, responses, "DecisionTreeRegressor::Train()");
568 
569  using TrueMatType = typename std::decay<MatType>::type;
570  using TrueResponsesType = typename std::decay<ResponsesType>::type;
571  using TrueWeightsType = typename std::decay<WeightsType>::type;
572 
573  // Copy or move data.
574  TrueMatType tmpData(std::move(data));
575  TrueResponsesType tmpResponses(std::move(responses));
576  TrueWeightsType tmpWeights(std::move(weights));
577 
578  // Set the correct dimensionality for the dimension selector.
579  dimensionSelector.Dimensions() = tmpData.n_rows;
580 
581  // Pass off work to the Train() method.
582  return Train<true>(tmpData, 0, tmpData.n_cols, tmpResponses,
583  tmpWeights, minimumLeafSize, minimumGainSplit, maximumDepth,
584  dimensionSelector);
585 }
586 
588 template<typename FitnessFunction,
589  template<typename> class NumericSplitType,
590  template<typename> class CategoricalSplitType,
591  typename DimensionSelectionType,
592  bool NoRecursion>
593 template<bool UseWeights, typename MatType, typename ResponsesType>
594 double DecisionTreeRegressor<FitnessFunction,
595  NumericSplitType,
596  CategoricalSplitType,
597  DimensionSelectionType,
598  NoRecursion>::Train(
599  MatType& data,
600  const size_t begin,
601  const size_t count,
602  const data::DatasetInfo& datasetInfo,
603  ResponsesType& responses,
604  arma::rowvec& weights,
605  const size_t minimumLeafSize,
606  const double minimumGainSplit,
607  const size_t maximumDepth,
608  DimensionSelectionType& dimensionSelector)
609 {
610  // Clear children if needed.
611  for (size_t i = 0; i < children.size(); ++i)
612  delete children[i];
613  children.clear();
614 
615  // Look through the list of dimensions and obtain the gain of the best split.
616  // We'll cache the best numeric and categorical split auxiliary information
617  // in numericAux and categoricalAux (and clear them later if we make no
618  // split). The split point is stored in splitPointOrPrediction for all
619  // internal nodes of the tree.
620  double bestGain = FitnessFunction::template Evaluate<UseWeights>(
621  responses.subvec(begin, begin + count - 1),
622  UseWeights ? weights.subvec(begin, begin + count - 1) : weights);
623  size_t bestDim = datasetInfo.Dimensionality(); // This means "no split".
624  const size_t end = dimensionSelector.End();
625 
626  if (maximumDepth != 1)
627  {
628  for (size_t i = dimensionSelector.Begin(); i != end;
629  i = dimensionSelector.Next())
630  {
631  double dimGain = -DBL_MAX;
632  if (datasetInfo.Type(i) == data::Datatype::categorical)
633  {
634  dimGain = CategoricalSplit::template SplitIfBetter<UseWeights>(bestGain,
635  data.cols(begin, begin + count - 1).row(i),
636  datasetInfo.NumMappings(i),
637  responses.subvec(begin, begin + count - 1),
638  UseWeights ? weights.subvec(begin, begin + count - 1) : weights,
639  minimumLeafSize,
640  minimumGainSplit,
641  splitPointOrPrediction,
642  *this);
643  }
644  else if (datasetInfo.Type(i) == data::Datatype::numeric)
645  {
646  dimGain = NumericSplit::template SplitIfBetter<UseWeights>(bestGain,
647  data.cols(begin, begin + count - 1).row(i),
648  responses.subvec(begin, begin + count - 1),
649  UseWeights ? weights.subvec(begin, begin + count - 1) : weights,
650  minimumLeafSize,
651  minimumGainSplit,
652  splitPointOrPrediction,
653  *this);
654  }
655 
656  // If the splitter reported that it did not split, move to the next
657  // dimension.
658  if (dimGain == DBL_MAX)
659  continue;
660 
661  // Was there an improvement? If so, mark this as the new best dimension.
662  bestDim = i;
663  bestGain = dimGain;
664 
665  // If the gain is the best possible, no need to keep looking.
666  if (bestGain >= 0.0)
667  break;
668  }
669  }
670 
671  // Did we split or not? If so, then split the data and create the children.
672  if (bestDim != datasetInfo.Dimensionality())
673  {
674  dimensionType = (size_t) datasetInfo.Type(bestDim);
675  splitDimension = bestDim;
676 
677  // Get the number of children we will have.
678  size_t numChildren = 0;
679  if (datasetInfo.Type(bestDim) == data::Datatype::categorical)
680  numChildren = CategoricalSplit::NumChildren(splitPointOrPrediction,
681  *this);
682  else
683  numChildren = NumericSplit::NumChildren(splitPointOrPrediction, *this);
684 
685  // Calculate all child assignments.
686  arma::Row<size_t> childAssignments(count);
687  if (datasetInfo.Type(bestDim) == data::Datatype::categorical)
688  {
689  for (size_t j = begin; j < begin + count; ++j)
690  childAssignments[j - begin] = CategoricalSplit::CalculateDirection(
691  data(bestDim, j), splitPointOrPrediction, *this);
692  }
693  else
694  {
695  for (size_t j = begin; j < begin + count; ++j)
696  {
697  childAssignments[j - begin] = NumericSplit::CalculateDirection(
698  data(bestDim, j), splitPointOrPrediction, *this);
699  }
700  }
701 
702  // Figure out counts of children.
703  arma::Row<size_t> childCounts(numChildren, arma::fill::zeros);
704  for (size_t i = begin; i < begin + count; ++i)
705  childCounts[childAssignments[i - begin]]++;
706 
707  // Initialize bestGain if recursive split is allowed.
708  if (!NoRecursion)
709  {
710  bestGain = 0.0;
711  }
712 
713  // Split into children.
714  size_t currentCol = begin;
715  for (size_t i = 0; i < numChildren; ++i)
716  {
717  size_t currentChildBegin = currentCol;
718  for (size_t j = currentChildBegin; j < begin + count; ++j)
719  {
720  if (childAssignments[j - begin] == i)
721  {
722  childAssignments.swap_cols(currentCol - begin, j - begin);
723  data.swap_cols(currentCol, j);
724  responses.swap_cols(currentCol, j);
725  if (UseWeights)
726  weights.swap_cols(currentCol, j);
727  ++currentCol;
728  }
729  }
730 
731  // Now build the child recursively.
732  DecisionTreeRegressor* child = new DecisionTreeRegressor();
733  if (NoRecursion)
734  {
735  child->Train<UseWeights>(data, currentChildBegin,
736  currentCol - currentChildBegin, datasetInfo, responses,
737  weights, currentCol - currentChildBegin, minimumGainSplit,
738  maximumDepth - 1, dimensionSelector);
739  }
740  else
741  {
742  // During recursion, the gain of the child node may change.
743  double childGain = child->Train<UseWeights>(data, currentChildBegin,
744  currentCol - currentChildBegin, datasetInfo, responses,
745  weights, minimumLeafSize, minimumGainSplit, maximumDepth - 1,
746  dimensionSelector);
747  bestGain += double(childCounts[i]) / double(count) * (-childGain);
748  }
749  children.push_back(child);
750  }
751  }
752  else
753  {
754  // Clear auxiliary info objects.
755  NumericAuxiliarySplitInfo::operator=(NumericAuxiliarySplitInfo());
756  CategoricalAuxiliarySplitInfo::operator=(CategoricalAuxiliarySplitInfo());
757 
758  // Calculate the prediction value because we are a leaf.
759  CalculatePrediction<UseWeights>(
760  responses.subvec(begin, begin + count - 1),
761  UseWeights ? weights.subvec(begin, begin + count - 1) : weights);
762  }
763 
764  return -bestGain;
765 }
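// Note (not part of the original source): when recursion is enabled, bestGain
// is reset to 0 after a split and then accumulated as the size-weighted sum of
// the children's negated return values:
//
//   bestGain = sum_i (childCounts[i] / count) * (-childGain_i)
//
// For example, with two children holding 60 and 40 of 100 points and child
// returns of -0.2 and -0.1, bestGain = 0.6 * 0.2 + 0.4 * 0.1 = 0.16, and this
// call returns -0.16.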
766 
768 template<typename FitnessFunction,
769  template<typename> class NumericSplitType,
770  template<typename> class CategoricalSplitType,
771  typename DimensionSelectionType,
772  bool NoRecursion>
773 template<bool UseWeights, typename MatType, typename ResponsesType>
774 double DecisionTreeRegressor<FitnessFunction,
775  NumericSplitType,
776  CategoricalSplitType,
777  DimensionSelectionType,
778  NoRecursion>::Train(
779  MatType& data,
780  const size_t begin,
781  const size_t count,
782  ResponsesType& responses,
783  arma::rowvec& weights,
784  const size_t minimumLeafSize,
785  const double minimumGainSplit,
786  const size_t maximumDepth,
787  DimensionSelectionType& dimensionSelector)
788 {
789  // Clear children if needed.
790  for (size_t i = 0; i < children.size(); ++i)
791  delete children[i];
792  children.clear();
793 
794  // We won't be using these members, so reset them.
795  CategoricalAuxiliarySplitInfo::operator=(CategoricalAuxiliarySplitInfo());
796 
797  // Look through the list of dimensions and obtain the best split. We'll cache
798  // the best numeric split auxiliary information in numericAux (and clear it
799  // later if we don't make a split). The split point is stored in
800  // splitPointOrPrediction for all internal nodes of the tree.
801  double bestGain = FitnessFunction::template Evaluate<UseWeights>(
802  responses.subvec(begin, begin + count - 1),
803  UseWeights ? weights.subvec(begin, begin + count - 1) : weights);
804  size_t bestDim = data.n_rows; // This means "no split".
805 
806  if (maximumDepth != 1)
807  {
808  for (size_t i = dimensionSelector.Begin(); i != dimensionSelector.End();
809  i = dimensionSelector.Next())
810  {
811  const double dimGain = NumericSplitType<FitnessFunction>::template
812  SplitIfBetter<UseWeights>(bestGain,
813  data.cols(begin, begin + count - 1).row(i),
814  responses.cols(begin, begin + count - 1),
815  UseWeights ?
816  weights.cols(begin, begin + count - 1) :
817  weights,
818  minimumLeafSize,
819  minimumGainSplit,
820  splitPointOrPrediction,
821  *this);
822 
823  // If the splitter did not report that it improved, then move to the next
824  // dimension.
825  if (dimGain == DBL_MAX)
826  continue;
827 
828  bestDim = i;
829  bestGain = dimGain;
830 
831  // If the gain is the best possible, no need to keep looking.
832  if (bestGain >= 0.0)
833  break;
834  }
835  }
836 
837  // Did we split or not? If so, then split the data and create the children.
838  if (bestDim != data.n_rows)
839  {
840  // We know that the split is numeric.
841  size_t numChildren = NumericSplit::NumChildren(splitPointOrPrediction,
842  *this);
843  splitDimension = bestDim;
844  dimensionType = (size_t) data::Datatype::numeric;
845 
846  // Calculate all child assignments.
847  arma::Row<size_t> childAssignments(count);
848 
849  for (size_t j = begin; j < begin + count; ++j)
850  {
851  childAssignments[j - begin] = NumericSplit::CalculateDirection(
852  data(bestDim, j), splitPointOrPrediction, *this);
853  }
854 
855  // Calculate counts of children in each node.
856  arma::Row<size_t> childCounts(numChildren);
857  childCounts.zeros();
858  for (size_t j = begin; j < begin + count; ++j)
859  childCounts[childAssignments[j - begin]]++;
860 
861  // Initialize bestGain if recursive split is allowed.
862  if (!NoRecursion)
863  {
864  bestGain = 0.0;
865  }
866 
867  size_t currentCol = begin;
868  for (size_t i = 0; i < numChildren; ++i)
869  {
870  size_t currentChildBegin = currentCol;
871  for (size_t j = currentChildBegin; j < begin + count; ++j)
872  {
873  if (childAssignments[j - begin] == i)
874  {
875  childAssignments.swap_cols(currentCol - begin, j - begin);
876  data.swap_cols(currentCol, j);
877  responses.swap_cols(currentCol, j);
878  if (UseWeights)
879  weights.swap_cols(currentCol, j);
880  ++currentCol;
881  }
882  }
883 
884  // Now build the child recursively.
885  DecisionTreeRegressor* child = new DecisionTreeRegressor();
886  if (NoRecursion)
887  {
888  child->Train<UseWeights>(data, currentChildBegin,
889  currentCol - currentChildBegin, responses, weights,
890  currentCol - currentChildBegin, minimumGainSplit, maximumDepth - 1,
891  dimensionSelector);
892  }
893  else
894  {
895  // During recursion, the gain of the child node may change.
896  double childGain = child->Train<UseWeights>(data, currentChildBegin,
897  currentCol - currentChildBegin, responses, weights,
898  minimumLeafSize, minimumGainSplit, maximumDepth - 1,
899  dimensionSelector);
900  bestGain += double(childCounts[i]) / double(count) * (-childGain);
901  }
902  children.push_back(child);
903  }
904  }
905  else
906  {
907  // We won't be needing these members, so reset them.
908  NumericAuxiliarySplitInfo::operator=(NumericAuxiliarySplitInfo());
909 
910  // Calculate the prediction value because we are a leaf.
911  CalculatePrediction<UseWeights>(
912  responses.subvec(begin, begin + count - 1),
913  UseWeights ? weights.subvec(begin, begin + count - 1) : weights);
914  }
915 
916  return -bestGain;
917 }
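// Note (not part of the original source): the swap_cols() loop above partitions
// the columns of `data` (and of `responses`/`weights`) in place so that each
// child's points become contiguous. For instance, with count = 4 and
// childAssignments = [1 0 1 0], the pass for child 0 swaps columns until the
// assignments read [0 0 1 1]; the pass for child 1 then finds its points
// already in place, and each child trains on its [currentChildBegin,
// currentCol) range.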
918 
920 template<typename FitnessFunction,
921  template<typename> class NumericSplitType,
922  template<typename> class CategoricalSplitType,
923  typename DimensionSelectionType,
924  bool NoRecursion>
925 template<typename VecType>
926 double DecisionTreeRegressor<FitnessFunction,
927  NumericSplitType,
928  CategoricalSplitType,
929  DimensionSelectionType,
930  NoRecursion>::Predict(const VecType& point) const
931 {
932  if (children.size() == 0)
933  {
934  // Return cached prediction.
935  return splitPointOrPrediction;
936  }
937 
938  return children[CalculateDirection(point)]->Predict(point);
939 }
940 
942 template<typename FitnessFunction,
943  template<typename> class NumericSplitType,
944  template<typename> class CategoricalSplitType,
945  typename DimensionSelectionType,
946  bool NoRecursion>
947 template<typename MatType>
948 void DecisionTreeRegressor<FitnessFunction,
949  NumericSplitType,
950  CategoricalSplitType,
951  DimensionSelectionType,
952  NoRecursion
953 >::Predict(const MatType& data, arma::Row<double>& predictions) const
954 {
955  predictions.set_size(data.n_cols);
956  // If the tree's root is a leaf.
957  if (children.size() == 0)
958  {
959  predictions.fill(splitPointOrPrediction);
960  return;
961  }
962 
963  // Loop over each point.
964  for (size_t i = 0; i < data.n_cols; ++i)
965  predictions[i] = Predict(data.col(i));
966 }
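// Example (not part of the original source): a sketch of prediction with the
// two overloads above, assuming `tree` has already been trained and `testData`
// holds one point per column:
//
//   arma::Row<double> predictions;
//   tree.Predict(testData, predictions);              // one prediction per column
//   const double p0 = tree.Predict(testData.col(0));  // single-point overload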
967 
968 template<typename FitnessFunction,
969  template<typename> class NumericSplitType,
970  template<typename> class CategoricalSplitType,
971  typename DimensionSelectionType,
972  bool NoRecursion>
973 template<bool UseWeights, typename ResponsesType, typename WeightsType>
974 void DecisionTreeRegressor<FitnessFunction,
975  NumericSplitType,
976  CategoricalSplitType,
977  DimensionSelectionType,
978  NoRecursion
979 >::CalculatePrediction(const ResponsesType& responses,
980  const WeightsType& weights)
981 {
982  if (UseWeights)
983  {
984  double accWeights, weightedSum;
985  WeightedSum(responses, weights, 0, responses.n_elem, accWeights,
986  weightedSum);
987  splitPointOrPrediction = weightedSum / accWeights;
988  }
989  else
990  {
991  double sum;
992  Sum(responses, 0, responses.n_elem, sum);
993  splitPointOrPrediction = sum / responses.n_elem;
994  }
995 }
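// Note (not part of the original source): for a leaf this stores the (weighted)
// mean of the responses reaching it, i.e.
//
//   splitPointOrPrediction = sum_i(w_i * y_i) / sum_i(w_i)   (weighted)
//   splitPointOrPrediction = (1 / n) * sum_i(y_i)            (unweighted)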
996 
997 template<typename FitnessFunction,
998  template<typename> class NumericSplitType,
999  template<typename> class CategoricalSplitType,
1000  typename DimensionSelectionType,
1001  bool NoRecursion>
1002 template<typename VecType>
1003 size_t DecisionTreeRegressor<FitnessFunction,
1004  NumericSplitType,
1005  CategoricalSplitType,
1006  DimensionSelectionType,
1007  NoRecursion
1008 >::CalculateDirection(const VecType& point) const
1009 {
1010  if ((data::Datatype) dimensionType == data::Datatype::categorical)
1011  return CategoricalSplit::CalculateDirection(point[splitDimension],
1012  splitPointOrPrediction, *this);
1013  else
1014  return NumericSplit::CalculateDirection(point[splitDimension],
1015  splitPointOrPrediction, *this);
1016 }
1017 
1019 template<typename FitnessFunction,
1020  template<typename> class NumericSplitType,
1021  template<typename> class CategoricalSplitType,
1022  typename DimensionSelectionType,
1023  bool NoRecursion>
1024 template<typename Archive>
1025 void DecisionTreeRegressor<FitnessFunction,
1026  NumericSplitType,
1027  CategoricalSplitType,
1028  DimensionSelectionType,
1029  NoRecursion
1030 >::serialize(Archive& ar, const uint32_t /* version */)
1031 {
1032  // Clean memory if needed.
1033  if (cereal::is_loading<Archive>())
1034  {
1035  for (size_t i = 0; i < children.size(); ++i)
1036  delete children[i];
1037  children.clear();
1038  }
1039  // Serialize the children first.
1040  ar(CEREAL_VECTOR_POINTER(children));
1041 
1042  // Now serialize the rest of the object.
1043  ar(CEREAL_NVP(splitDimension));
1044  ar(CEREAL_NVP(dimensionType));
1045  ar(CEREAL_NVP(splitPointOrPrediction));
1046 }
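// Example (not part of the original source): because the tree serializes via
// cereal, it can be saved and reloaded, e.g. with mlpack's data::Save() and
// data::Load() model helpers (a sketch; assumes those headers are included):
//
//   data::Save("tree.bin", "tree", tree);
//   DecisionTreeRegressor<> tree2;
//   data::Load("tree.bin", "tree", tree2);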
1047 
1049 template<typename FitnessFunction,
1050  template<typename> class NumericSplitType,
1051  template<typename> class CategoricalSplitType,
1052  typename DimensionSelectionType,
1053  bool NoRecursion>
1054 size_t DecisionTreeRegressor<FitnessFunction,
1055  NumericSplitType,
1056  CategoricalSplitType,
1057  DimensionSelectionType,
1058  NoRecursion>::NumLeaves() const
1059 {
1060  if (this->NumChildren() == 0)
1061  return 1;
1062 
1063  size_t numLeaves = 0;
1064  for (size_t i = 0; i < this->NumChildren(); ++i)
1065  numLeaves += children[i]->NumLeaves();
1066 
1067  return numLeaves;
1068 }
1069 
1070 } // namespace tree
1071 } // namespace mlpack
1072 
1073 #endif