12 #ifndef MLPACK_METHODS_APPROX_KFN_DRUSILLA_SELECT_IMPL_HPP 13 #define MLPACK_METHODS_APPROX_KFN_DRUSILLA_SELECT_IMPL_HPP 28 template<
typename MatType>
32 candidateSet(referenceSet.n_cols, l * m),
33 candidateIndices(l * m),
38 throw std::invalid_argument(
"DrusillaSelect::DrusillaSelect(): invalid " 39 "value of l; must be greater than 0!");
41 throw std::invalid_argument(
"DrusillaSelect::DrusillaSelect(): invalid " 42 "value of m; must be greater than 0!");
44 Train(referenceSet, l, m);
48 template<
typename MatType>
50 candidateSet(0, l * m),
51 candidateIndices(l * m),
56 throw std::invalid_argument(
"DrusillaSelect::DrusillaSelect(): invalid " 57 "value of l; must be greater than 0!");
59 throw std::invalid_argument(
"DrusillaSelect::DrusillaSelect(): invalid " 60 "value of m; must be greater than 0!");
64 template<
typename MatType>
66 const MatType& referenceSet,
76 if ((l * m) > referenceSet.n_cols)
77 throw std::invalid_argument(
"DrusillaSelect::Train(): l and m are too " 78 "large! Choose smaller values. l*m must be smaller than the number " 79 "of points in the dataset.");
81 candidateSet.set_size(referenceSet.n_rows, l * m);
82 candidateIndices.set_size(l * m);
84 arma::vec dataMean(arma::mean(referenceSet, 1));
85 arma::vec norms(referenceSet.n_cols);
87 MatType refCopy(referenceSet.n_rows, referenceSet.n_cols);
88 for (
size_t i = 0; i < refCopy.n_cols; ++i)
90 refCopy.col(i) = referenceSet.col(i) - dataMean;
91 norms[i] = arma::norm(refCopy.col(i));
95 for (
size_t i = 0; i < l; ++i)
98 arma::uword maxIndex = 0;
101 arma::vec line(refCopy.col(maxIndex) / arma::norm(refCopy.col(maxIndex)));
104 std::vector<bool> closeAngle(referenceSet.n_cols,
false);
105 arma::vec sums(referenceSet.n_cols);
106 for (
size_t j = 0; j < referenceSet.n_cols; ++j)
110 const double offset = arma::dot(refCopy.col(j), line);
111 const double distortion = arma::norm(refCopy.col(j) - offset * line);
112 sums[j] = std::abs(offset) - std::abs(distortion);
114 (std::atan(distortion / std::abs(offset)) < (M_PI / 8.0));
123 typedef std::pair<double, size_t> Candidate;
126 bool operator()(
const Candidate& c1,
const Candidate& c2)
128 return c2.first < c1.first;
132 std::vector<Candidate> clist(
133 m, std::make_pair(
double(-DBL_MAX),
size_t(-1)));
134 std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
135 pq(CandidateCmp(), std::move(clist));
137 for (
size_t j = 0; j < sums.n_elem; ++j)
139 Candidate c = std::make_pair(sums[j], j);
140 if (CandidateCmp()(c, pq.top()))
148 for (
size_t j = 0; j < m; ++j)
150 const size_t index = pq.top().second;
152 candidateSet.col(i * m + j) = referenceSet.col(index);
153 candidateIndices[i * m + j] = index;
161 for (
size_t j = 0; j < norms.n_elem; ++j)
162 if (norms[j] > 0.0 && closeAngle[j])
168 template<
typename MatType>
171 arma::Mat<size_t>& neighbors,
172 arma::mat& distances)
174 if (candidateSet.n_cols == 0)
175 throw std::runtime_error(
"DrusillaSelect::Search(): candidate set not " 176 "initialized! Call Train() first.");
179 throw std::invalid_argument(
"DrusillaSelect::Search(): requested k is " 180 "greater than number of points in candidate set! Increase l or m.");
188 rules(candidateSet, querySet, k, metric, 0,
false);
190 for (
size_t q = 0; q < querySet.n_cols; ++q)
191 for (
size_t r = 0; r < candidateSet.n_cols; ++r)
192 rules.BaseCase(q, r);
194 rules.GetResults(neighbors, distances);
197 for (
size_t i = 0; i < neighbors.n_elem; ++i)
198 neighbors[i] = candidateIndices[neighbors[i]];
202 template<
typename MatType>
203 template<
typename Archive>
207 ar(CEREAL_NVP(candidateSet));
208 ar(CEREAL_NVP(candidateIndices));
DrusillaSelect(const MatType &referenceSet, const size_t l, const size_t m)
Construct the DrusillaSelect object with the given reference set (this is the set that will be searched).
Definition: drusilla_select_impl.hpp:29
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
A binary space partitioning tree, such as a KD-tree or a ball tree.
Definition: binary_space_tree.hpp:54
This class implements the necessary methods for the SortPolicy template parameter of the NeighborSearch class.
Definition: furthest_neighbor_sort.hpp:27
The NeighborSearchRules class is a template helper class used by the NeighborSearch class when performing distance-based neighbor searches.
Definition: neighbor_search_rules.hpp:35
void serialize(Archive &ar, const uint32_t)
Serialize the model.
Definition: drusilla_select_impl.hpp:204
The L_p metric for arbitrary integer p, with an option to take the root.
Definition: lmetric.hpp:63
void Train(const MatType &referenceSet, const size_t l=0, const size_t m=0)
Build the set of candidate points on the given reference set.
Definition: drusilla_select_impl.hpp:65
void Search(const MatType &querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
Search for the k furthest neighbors of the given query set.
Definition: drusilla_select_impl.hpp:169