13 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_RP_TREE_MEAN_SPLIT_IMPL_HPP 14 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_RP_TREE_MEAN_SPLIT_IMPL_HPP 21 template<
typename BoundType,
typename MatType>
28 const size_t maxNumSamples = 100;
29 const size_t numSamples = std::min(maxNumSamples, count);
36 ElemType averageDistanceSq = GetAveragePointDistance(data, samples);
40 if (bound.Diameter() * bound.Diameter() <= threshold * averageDistanceSq)
61 return GetMeanMedian(data, samples, splitInfo.
mean, splitInfo.
splitVal);
65 template<
typename BoundType,
typename MatType>
69 const arma::uvec& samples)
73 for (
size_t i = 0; i < samples.n_elem; ++i)
74 for (
size_t j = i + 1; j < samples.n_elem; ++j)
76 data.col(samples[j]));
78 dist /= (samples.n_elem * (samples.n_elem - 1) / 2);
83 template<
typename BoundType,
typename MatType>
86 const arma::uvec& samples,
87 const arma::Col<ElemType>& direction,
90 arma::Col<ElemType> values(samples.n_elem);
92 for (
size_t k = 0; k < samples.n_elem; ++k)
93 values[k] = arma::dot(data.col(samples[k]), direction);
95 const ElemType maximum = arma::max(values);
96 const ElemType minimum = arma::min(values);
97 if (minimum == maximum)
100 splitVal = arma::median(values);
102 if (splitVal == maximum)
108 template<
typename BoundType,
typename MatType>
111 const arma::uvec& samples,
112 arma::Col<ElemType>& mean,
115 arma::Col<ElemType> values(samples.n_elem);
117 mean = arma::mean(data.cols(samples), 1);
119 arma::Col<ElemType> tmp(data.n_rows);
121 for (
size_t k = 0; k < samples.n_elem; ++k)
123 tmp = data.col(samples[k]);
126 values[k] = arma::dot(tmp, tmp);
129 const ElemType maximum = arma::max(values);
130 const ElemType minimum = arma::min(values);
131 if (minimum == maximum)
134 splitVal = arma::median(values);
136 if (splitVal == maximum)
145 #endif // MLPACK_CORE_TREE_BINARY_SPACE_TREE_RP_TREE_MEAN_SPLIT_IMPL_HPP void ObtainDistinctSamples(const size_t loInclusive, const size_t hiExclusive, const size_t maxNumSamples, arma::uvec &distinctSamples)
Obtains no more than maxNumSamples distinct samples.
Definition: random.hpp:153
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Information about the partition.
Definition: rp_tree_mean_split.hpp:39
static bool SplitNode(const BoundType &, MatType &data, const size_t begin, const size_t count, SplitInfo &splitInfo)
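Split the node using either a median split along a random projection direction or a mean split, chosen by comparing the bound's squared diameter with a threshold multiple of the average squared distance between sampled points.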
Definition: rp_tree_mean_split_impl.hpp:22
ElemType splitVal
The value according to which the split will be performed.
Definition: rp_tree_mean_split.hpp:46
static VecTypeA::elem_type Evaluate(const VecTypeA &a, const VecTypeB &b)
Computes the distance between two points.
Definition: lmetric_impl.hpp:24
This class splits nodes of a binary space tree.
Definition: rp_tree_mean_split.hpp:33
void RandVector(arma::vec &v)
Overwrites a dimension-N vector with a random vector on the unit sphere in $\mathbb{R}^N$.
Definition: lin_alg.cpp:79
MatType::elem_type ElemType
The element type held by the matrix type.
Definition: rp_tree_mean_split.hpp:37
arma::Col< ElemType > mean
The mean of some sampled points.
Definition: rp_tree_mean_split.hpp:44
arma::Col< ElemType > direction
The normal to the hyperplane that will split the node.
Definition: rp_tree_mean_split.hpp:42
bool meanSplit
Indicates that we should use the mean split algorithm instead of the median split.
Definition: rp_tree_mean_split.hpp:49