mlpack
rp_tree_mean_split_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_RP_TREE_MEAN_SPLIT_IMPL_HPP
14 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_RP_TREE_MEAN_SPLIT_IMPL_HPP
15 
16 #include "rp_tree_max_split.hpp"
17 
18 namespace mlpack {
19 namespace tree {
20 
21 template<typename BoundType, typename MatType>
23  MatType& data,
24  const size_t begin,
25  const size_t count,
26  SplitInfo& splitInfo)
27 {
28  const size_t maxNumSamples = 100;
29  const size_t numSamples = std::min(maxNumSamples, count);
30  arma::uvec samples;
31 
32  // Get no more than numSamples distinct samples.
33  math::ObtainDistinctSamples(begin, begin + count, numSamples, samples);
34 
35  // Find the average distance between points.
36  ElemType averageDistanceSq = GetAveragePointDistance(data, samples);
37 
38  const ElemType threshold = 10;
39 
40  if (bound.Diameter() * bound.Diameter() <= threshold * averageDistanceSq)
41  {
42  // We will perform the median split.
43  splitInfo.meanSplit = false;
44 
45  splitInfo.direction.zeros(data.n_rows);
46 
47  // Get a random normal vector.
48  math::RandVector(splitInfo.direction);
49 
50  // Get the median value of the scalar products of the normal and the
51  // sampled points. The node will be split according to this value.
52  return GetDotMedian(data, samples, splitInfo.direction, splitInfo.splitVal);
53  }
54  else
55  {
56  // We will perform the mean split.
57  splitInfo.meanSplit = true;
58 
59  // Get the median of the distances between the mean point and the sampled
60  // points. The node will be split according to this value.
61  return GetMeanMedian(data, samples, splitInfo.mean, splitInfo.splitVal);
62  }
63 }
64 
65 template<typename BoundType, typename MatType>
66 typename MatType::elem_type RPTreeMeanSplit<BoundType, MatType>::
68  MatType& data,
69  const arma::uvec& samples)
70 {
71  ElemType dist = 0;
72 
73  for (size_t i = 0; i < samples.n_elem; ++i)
74  for (size_t j = i + 1; j < samples.n_elem; ++j)
75  dist += metric::SquaredEuclideanDistance::Evaluate(data.col(samples[i]),
76  data.col(samples[j]));
77 
78  dist /= (samples.n_elem * (samples.n_elem - 1) / 2);
79 
80  return dist;
81 }
82 
83 template<typename BoundType, typename MatType>
85  const MatType& data,
86  const arma::uvec& samples,
87  const arma::Col<ElemType>& direction,
88  ElemType& splitVal)
89 {
90  arma::Col<ElemType> values(samples.n_elem);
91 
92  for (size_t k = 0; k < samples.n_elem; ++k)
93  values[k] = arma::dot(data.col(samples[k]), direction);
94 
95  const ElemType maximum = arma::max(values);
96  const ElemType minimum = arma::min(values);
97  if (minimum == maximum)
98  return false;
99 
100  splitVal = arma::median(values);
101 
102  if (splitVal == maximum)
103  splitVal = minimum;
104 
105  return true;
106 }
107 
108 template<typename BoundType, typename MatType>
110  const MatType& data,
111  const arma::uvec& samples,
112  arma::Col<ElemType>& mean,
113  ElemType& splitVal)
114 {
115  arma::Col<ElemType> values(samples.n_elem);
116 
117  mean = arma::mean(data.cols(samples), 1);
118 
119  arma::Col<ElemType> tmp(data.n_rows);
120 
121  for (size_t k = 0; k < samples.n_elem; ++k)
122  {
123  tmp = data.col(samples[k]);
124  tmp -= mean;
125 
126  values[k] = arma::dot(tmp, tmp);
127  }
128 
129  const ElemType maximum = arma::max(values);
130  const ElemType minimum = arma::min(values);
131  if (minimum == maximum)
132  return false;
133 
134  splitVal = arma::median(values);
135 
136  if (splitVal == maximum)
137  splitVal = minimum;
138 
139  return true;
140 }
141 
142 } // namespace tree
143 } // namespace mlpack
144 
145 #endif // MLPACK_CORE_TREE_BINARY_SPACE_TREE_RP_TREE_MEAN_SPLIT_IMPL_HPP
void ObtainDistinctSamples(const size_t loInclusive, const size_t hiExclusive, const size_t maxNumSamples, arma::uvec &distinctSamples)
Obtains no more than maxNumSamples distinct samples.
Definition: random.hpp:153
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Information about the partition.
Definition: rp_tree_mean_split.hpp:39
static bool SplitNode(const BoundType &, MatType &data, const size_t begin, const size_t count, SplitInfo &splitInfo)
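Split the node using either a random-projection median split or a distance-from-mean split, chosen by comparing the bound's squared diameter with the average squared distance between sampled points.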
Definition: rp_tree_mean_split_impl.hpp:22
ElemType splitVal
The value according to which the split will be performed.
Definition: rp_tree_mean_split.hpp:46
static VecTypeA::elem_type Evaluate(const VecTypeA &a, const VecTypeB &b)
Computes the distance between two points.
Definition: lmetric_impl.hpp:24
This class splits a binary space tree.
Definition: rp_tree_mean_split.hpp:33
void RandVector(arma::vec &v)
Overwrites a dimension-N vector with a random vector on the unit sphere in R^N.
Definition: lin_alg.cpp:79
MatType::elem_type ElemType
The element type held by the matrix type.
Definition: rp_tree_mean_split.hpp:37
arma::Col< ElemType > mean
The mean of some sampled points.
Definition: rp_tree_mean_split.hpp:44
arma::Col< ElemType > direction
The normal to the hyperplane that will split the node.
Definition: rp_tree_mean_split.hpp:42
bool meanSplit
Indicates that we should use the mean split algorithm instead of the median split.
Definition: rp_tree_mean_split.hpp:49