mlpack
qdafn_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_APPROX_KFN_QDAFN_IMPL_HPP
13 #define MLPACK_METHODS_APPROX_KFN_QDAFN_IMPL_HPP
14 
15 // In case it hasn't been included yet.
16 #include "qdafn.hpp"
17 
18 #include <queue>
20 
21 namespace mlpack {
22 namespace neighbor {
23 
24 // Non-training constructor.
25 template<typename MatType>
26 QDAFN<MatType>::QDAFN(const size_t l, const size_t m) : l(l), m(m)
27 {
28  if (l == 0)
29  throw std::invalid_argument("QDAFN::QDAFN(): l must be greater than 0!");
30  if (m == 0)
31  throw std::invalid_argument("QDAFN::QDAFN(): m must be greater than 0!");
32 }
33 
34 // Constructor.
35 template<typename MatType>
36 QDAFN<MatType>::QDAFN(const MatType& referenceSet,
37  const size_t l,
38  const size_t m) :
39  l(l),
40  m(m)
41 {
42  if (l == 0)
43  throw std::invalid_argument("QDAFN::QDAFN(): l must be greater than 0!");
44  if (m == 0)
45  throw std::invalid_argument("QDAFN::QDAFN(): m must be greater than 0!");
46 
47  Train(referenceSet);
48 }
49 
50 // Train the object.
51 template<typename MatType>
52 void QDAFN<MatType>::Train(const MatType& referenceSet,
53  const size_t lIn,
54  const size_t mIn)
55 {
56  if (lIn != 0)
57  l = lIn;
58  if (mIn != 0)
59  m = mIn;
60 
61  // Build tables. This is done by drawing random points from a Gaussian
62  // distribution as the vectors we project onto. The Gaussian should have zero
63  // mean and unit variance.
64  mlpack::distribution::GaussianDistribution gd(referenceSet.n_rows);
65  lines.set_size(referenceSet.n_rows, l);
66  for (size_t i = 0; i < l; ++i)
67  lines.col(i) = gd.Random();
68 
69  // Now, project each of the reference points onto each line, and collect the
70  // top m elements.
71  projections = referenceSet.t() * lines;
72 
73  // Loop over each projection and find the top m elements.
74  sIndices.set_size(m, l);
75  sValues.set_size(m, l);
76  candidateSet.resize(l);
77  for (size_t i = 0; i < l; ++i)
78  {
79  candidateSet[i].set_size(referenceSet.n_rows, m);
80  arma::uvec sortedIndices = arma::sort_index(projections.col(i), "descend");
81 
82  // Grab the top m elements.
83  for (size_t j = 0; j < m; ++j)
84  {
85  sIndices(j, i) = sortedIndices[j];
86  sValues(j, i) = projections(sortedIndices[j], i);
87  candidateSet[i].col(j) = referenceSet.col(sortedIndices[j]);
88  }
89  }
90 }
91 
92 // Search.
93 template<typename MatType>
94 void QDAFN<MatType>::Search(const MatType& querySet,
95  const size_t k,
96  arma::Mat<size_t>& neighbors,
97  arma::mat& distances)
98 {
99  if (k > m)
100  throw std::invalid_argument("QDAFN::Search(): requested k is greater than "
101  "value of m!");
102 
103  neighbors.set_size(k, querySet.n_cols);
104  neighbors.fill(size_t() - 1);
105  distances.zeros(k, querySet.n_cols);
106 
107  // Search for each point.
108  for (size_t q = 0; q < querySet.n_cols; ++q)
109  {
110  // Initialize a priority queue.
111  // The size_t represents the index of the table, and the double represents
112  // the value of l_i * S_i - l_i * query (see line 6 of Algorithm 1).
113  std::priority_queue<std::pair<double, size_t>> queue;
114  for (size_t i = 0; i < l; ++i)
115  {
116  const double val = sValues(0, i) - arma::dot(querySet.col(q),
117  lines.col(i));
118  queue.push(std::make_pair(val, i));
119  }
120 
121  // To track where we are in each S table, we keep the next index to look at
122  // in each table (they start at 0).
123  arma::Col<size_t> tableLocations = arma::zeros<arma::Col<size_t>>(l);
124 
125  // Now that the queue is initialized, iterate over m elements.
126  std::vector<std::pair<double, size_t>> v(k, std::make_pair(-1.0,
127  size_t(-1)));
128  std::priority_queue<std::pair<double, size_t>>
129  resultsQueue(std::less<std::pair<double, size_t>>(), std::move(v));
130  for (size_t i = 0; i < m; ++i)
131  {
132  std::pair<size_t, double> p = queue.top();
133  queue.pop();
134 
135  // Get index of reference point to look at.
136  const size_t tableIndex = tableLocations[p.second];
137 
138  // Calculate distance from query point.
140  querySet.col(q), candidateSet[p.second].col(tableIndex));
141 
142  resultsQueue.push(std::make_pair(dist, sIndices(tableIndex, p.second)));
143 
144  // Now (line 14) get the next element and insert into the queue. Do this
145  // by adjusting the previous value. Don't insert anything if we are at
146  // the end of the search, though.
147  if (i < m - 1)
148  {
149  tableLocations[p.second]++;
150  const double val = p.first - sValues(tableIndex, p.second) +
151  sValues(tableIndex + 1, p.second);
152 
153  queue.push(std::make_pair(val, p.second));
154  }
155  }
156 
157  // Extract the results and deduplicate them.
158  size_t extracted = 1;
159  neighbors(0, q) = resultsQueue.top().second;
160  distances(0, q) = resultsQueue.top().first;
161  resultsQueue.pop();
162 
163  while (!resultsQueue.empty())
164  {
165  if (extracted == k)
166  break;
167 
168  std::pair<double, size_t> result = resultsQueue.top();
169  resultsQueue.pop();
170 
171  // Avoid inserting any duplicates.
172  if (neighbors(extracted - 1, q) != result.second)
173  {
174  neighbors(extracted, q) = resultsQueue.top().second;
175  distances(extracted, q) = resultsQueue.top().first;
176  ++extracted;
177  }
178  }
179  }
180 }
181 
182 template<typename MatType>
183 template<typename Archive>
184 void QDAFN<MatType>::serialize(Archive& ar, const uint32_t /* version */)
185 {
186  ar(CEREAL_NVP(l));
187  ar(CEREAL_NVP(m));
188  ar(CEREAL_NVP(lines));
189  ar(CEREAL_NVP(projections));
190  ar(CEREAL_NVP(sIndices));
191  ar(CEREAL_NVP(sValues));
192  if (cereal::is_loading<Archive>())
193  candidateSet.clear();
194  ar(CEREAL_NVP(candidateSet));
195 }
196 
197 } // namespace neighbor
198 } // namespace mlpack
199 
200 #endif
arma::vec Random() const
Return a randomly generated observation according to the probability distribution defined by this object.
Definition: gaussian_distribution.cpp:79
A single multivariate Gaussian distribution.
Definition: gaussian_distribution.hpp:24
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
QDAFN(const size_t l, const size_t m)
Construct the QDAFN object but do not train it.
Definition: qdafn_impl.hpp:26
void Search(const MatType &querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
Search for the k furthest neighbors of the given query set.
Definition: qdafn_impl.hpp:94
static VecTypeA::elem_type Evaluate(const VecTypeA &a, const VecTypeB &b)
Computes the distance between two points.
Definition: lmetric_impl.hpp:24
void serialize(Archive &ar, const uint32_t)
Serialize the model.
Definition: qdafn_impl.hpp:184
void Train(const MatType &referenceSet, const size_t l=0, const size_t m=0)
Train the QDAFN model on the given reference set, optionally setting new values for the number of projections (l) and the number of points stored per projection (m).
Definition: qdafn_impl.hpp:52