12 #ifndef MLPACK_METHODS_APPROX_KFN_QDAFN_IMPL_HPP 13 #define MLPACK_METHODS_APPROX_KFN_QDAFN_IMPL_HPP 25 template<
typename MatType>
29 throw std::invalid_argument(
"QDAFN::QDAFN(): l must be greater than 0!");
31 throw std::invalid_argument(
"QDAFN::QDAFN(): m must be greater than 0!");
35 template<
typename MatType>
43 throw std::invalid_argument(
"QDAFN::QDAFN(): l must be greater than 0!");
45 throw std::invalid_argument(
"QDAFN::QDAFN(): m must be greater than 0!");
51 template<
typename MatType>
65 lines.set_size(referenceSet.n_rows, l);
66 for (
size_t i = 0; i < l; ++i)
67 lines.col(i) = gd.
Random();
71 projections = referenceSet.t() * lines;
74 sIndices.set_size(m, l);
75 sValues.set_size(m, l);
76 candidateSet.resize(l);
77 for (
size_t i = 0; i < l; ++i)
79 candidateSet[i].set_size(referenceSet.n_rows, m);
80 arma::uvec sortedIndices = arma::sort_index(projections.col(i),
"descend");
83 for (
size_t j = 0; j < m; ++j)
85 sIndices(j, i) = sortedIndices[j];
86 sValues(j, i) = projections(sortedIndices[j], i);
87 candidateSet[i].col(j) = referenceSet.col(sortedIndices[j]);
93 template<
typename MatType>
96 arma::Mat<size_t>& neighbors,
100 throw std::invalid_argument(
"QDAFN::Search(): requested k is greater than " 103 neighbors.set_size(k, querySet.n_cols);
104 neighbors.fill(
size_t() - 1);
105 distances.zeros(k, querySet.n_cols);
108 for (
size_t q = 0; q < querySet.n_cols; ++q)
113 std::priority_queue<std::pair<double, size_t>> queue;
114 for (
size_t i = 0; i < l; ++i)
116 const double val = sValues(0, i) - arma::dot(querySet.col(q),
118 queue.push(std::make_pair(val, i));
123 arma::Col<size_t> tableLocations = arma::zeros<arma::Col<size_t>>(l);
126 std::vector<std::pair<double, size_t>> v(k, std::make_pair(-1.0,
128 std::priority_queue<std::pair<double, size_t>>
129 resultsQueue(std::less<std::pair<double, size_t>>(), std::move(v));
130 for (
size_t i = 0; i < m; ++i)
132 std::pair<size_t, double> p = queue.top();
136 const size_t tableIndex = tableLocations[p.second];
140 querySet.col(q), candidateSet[p.second].col(tableIndex));
142 resultsQueue.push(std::make_pair(dist, sIndices(tableIndex, p.second)));
149 tableLocations[p.second]++;
150 const double val = p.first - sValues(tableIndex, p.second) +
151 sValues(tableIndex + 1, p.second);
153 queue.push(std::make_pair(val, p.second));
158 size_t extracted = 1;
159 neighbors(0, q) = resultsQueue.top().second;
160 distances(0, q) = resultsQueue.top().first;
163 while (!resultsQueue.empty())
168 std::pair<double, size_t> result = resultsQueue.top();
172 if (neighbors(extracted - 1, q) != result.second)
174 neighbors(extracted, q) = resultsQueue.top().second;
175 distances(extracted, q) = resultsQueue.top().first;
182 template<
typename MatType>
183 template<
typename Archive>
188 ar(CEREAL_NVP(lines));
189 ar(CEREAL_NVP(projections));
190 ar(CEREAL_NVP(sIndices));
191 ar(CEREAL_NVP(sValues));
192 if (cereal::is_loading<Archive>())
193 candidateSet.clear();
194 ar(CEREAL_NVP(candidateSet));
arma::vec Random() const
Return a randomly generated observation according to the probability distribution defined by this object.
Definition: gaussian_distribution.cpp:79
A single multivariate Gaussian distribution.
Definition: gaussian_distribution.hpp:24
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
QDAFN(const size_t l, const size_t m)
Construct the QDAFN object but do not train it.
Definition: qdafn_impl.hpp:26
void Search(const MatType &querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
Search for the k furthest neighbors of the given query set.
Definition: qdafn_impl.hpp:94
static VecTypeA::elem_type Evaluate(const VecTypeA &a, const VecTypeB &b)
Computes the distance between two points.
Definition: lmetric_impl.hpp:24
void serialize(Archive &ar, const uint32_t)
Serialize the model.
Definition: qdafn_impl.hpp:184
void Train(const MatType &referenceSet, const size_t l=0, const size_t m=0)
Train the QDAFN model on the given reference set, optionally setting new parameters for the number of projections/tables (l) and the number of elements stored for each projection/table (m).
Definition: qdafn_impl.hpp:52