mlpack
drusilla_select_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_APPROX_KFN_DRUSILLA_SELECT_IMPL_HPP
13 #define MLPACK_METHODS_APPROX_KFN_DRUSILLA_SELECT_IMPL_HPP
14 
15 // In case it hasn't been included yet.
16 #include "drusilla_select.hpp"
17 
18 #include <queue>
22 #include <algorithm>
23 
24 namespace mlpack {
25 namespace neighbor {
26 
27 // Constructor.
28 template<typename MatType>
29 DrusillaSelect<MatType>::DrusillaSelect(const MatType& referenceSet,
30  const size_t l,
31  const size_t m) :
32  candidateSet(referenceSet.n_cols, l * m),
33  candidateIndices(l * m),
34  l(l),
35  m(m)
36 {
37  if (l == 0)
38  throw std::invalid_argument("DrusillaSelect::DrusillaSelect(): invalid "
39  "value of l; must be greater than 0!");
40  else if (m == 0)
41  throw std::invalid_argument("DrusillaSelect::DrusillaSelect(): invalid "
42  "value of m; must be greater than 0!");
43 
44  Train(referenceSet, l, m);
45 }
46 
47 // Constructor with no training.
48 template<typename MatType>
49 DrusillaSelect<MatType>::DrusillaSelect(const size_t l, const size_t m) :
50  candidateSet(0, l * m),
51  candidateIndices(l * m),
52  l(l),
53  m(m)
54 {
55  if (l == 0)
56  throw std::invalid_argument("DrusillaSelect::DrusillaSelect(): invalid "
57  "value of l; must be greater than 0!");
58  else if (m == 0)
59  throw std::invalid_argument("DrusillaSelect::DrusillaSelect(): invalid "
60  "value of m; must be greater than 0!");
61 }
62 
63 // Train the model.
64 template<typename MatType>
66  const MatType& referenceSet,
67  const size_t lIn,
68  const size_t mIn)
69 {
70  // Did the user specify a new size? If so, use it.
71  if (lIn > 0)
72  l = lIn;
73  if (mIn > 0)
74  m = mIn;
75 
76  if ((l * m) > referenceSet.n_cols)
77  throw std::invalid_argument("DrusillaSelect::Train(): l and m are too "
78  "large! Choose smaller values. l*m must be smaller than the number "
79  "of points in the dataset.");
80 
81  candidateSet.set_size(referenceSet.n_rows, l * m);
82  candidateIndices.set_size(l * m);
83 
84  arma::vec dataMean(arma::mean(referenceSet, 1));
85  arma::vec norms(referenceSet.n_cols);
86 
87  MatType refCopy(referenceSet.n_rows, referenceSet.n_cols);
88  for (size_t i = 0; i < refCopy.n_cols; ++i)
89  {
90  refCopy.col(i) = referenceSet.col(i) - dataMean;
91  norms[i] = arma::norm(refCopy.col(i));
92  }
93 
94  // Find the top m points for each of the l projections...
95  for (size_t i = 0; i < l; ++i)
96  {
97  // Pick best index.
98  arma::uword maxIndex = 0;
99  norms.max(maxIndex);
100 
101  arma::vec line(refCopy.col(maxIndex) / arma::norm(refCopy.col(maxIndex)));
102 
103  // Calculate distortion and offset and make scores.
104  std::vector<bool> closeAngle(referenceSet.n_cols, false);
105  arma::vec sums(referenceSet.n_cols);
106  for (size_t j = 0; j < referenceSet.n_cols; ++j)
107  {
108  if (norms[j] > 0.0)
109  {
110  const double offset = arma::dot(refCopy.col(j), line);
111  const double distortion = arma::norm(refCopy.col(j) - offset * line);
112  sums[j] = std::abs(offset) - std::abs(distortion);
113  closeAngle[j] =
114  (std::atan(distortion / std::abs(offset)) < (M_PI / 8.0));
115  }
116  else
117  {
118  sums[j] = norms[j];
119  }
120  }
121 
122  // Find the top m elements using a priority queue.
123  typedef std::pair<double, size_t> Candidate;
124  struct CandidateCmp
125  {
126  bool operator()(const Candidate& c1, const Candidate& c2)
127  {
128  return c2.first < c1.first;
129  }
130  };
131 
132  std::vector<Candidate> clist(
133  m, std::make_pair(double(-DBL_MAX), size_t(-1)));
134  std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
135  pq(CandidateCmp(), std::move(clist));
136 
137  for (size_t j = 0; j < sums.n_elem; ++j)
138  {
139  Candidate c = std::make_pair(sums[j], j);
140  if (CandidateCmp()(c, pq.top()))
141  {
142  pq.pop();
143  pq.push(c);
144  }
145  }
146 
147  // Take the top m elements for this table.
148  for (size_t j = 0; j < m; ++j)
149  {
150  const size_t index = pq.top().second;
151  pq.pop();
152  candidateSet.col(i * m + j) = referenceSet.col(index);
153  candidateIndices[i * m + j] = index;
154 
155  // Mark the norm as -1 so we don't see this point again.
156  norms[index] = -1.0;
157  }
158 
159  // Calculate angles from the current projection. Anything close enough,
160  // mark the norm as 0.
161  for (size_t j = 0; j < norms.n_elem; ++j)
162  if (norms[j] > 0.0 && closeAngle[j])
163  norms[j] = 0.0;
164  }
165 }
166 
167 // Search.
// Find the k furthest neighbors of each query point by brute force over the
// l * m candidate points selected by Train().
//
// Throws std::runtime_error if the candidate set has no columns, and
// std::invalid_argument if k > l * m.  neighbors and distances are filled by
// the rules object's GetResults().  Each neighbor index is then mapped through
// candidateIndices, so it refers to a column of the original reference set
// rather than a column of the candidate set.
168 template<typename MatType>
169 void DrusillaSelect<MatType>::Search(const MatType& querySet,
170  const size_t k,
171  arma::Mat<size_t>& neighbors,
172  arma::mat& distances)
173 {
174  if (candidateSet.n_cols == 0)
175  throw std::runtime_error("DrusillaSelect::Search(): candidate set not "
176  "initialized! Call Train() first.");
177 
178  if (k > (l * m))
179  throw std::invalid_argument("DrusillaSelect::Search(): requested k is "
180  "greater than number of points in candidate set! Increase l or m.");
181 
182  // We'll use the NeighborSearchRules class to perform our brute-force search.
183  // Note that we aren't using trees for our search, so we can use 'int' as a
184  // TreeType.
188  rules(candidateSet, querySet, k, metric, 0, false);
189 
// Evaluate every (query, candidate) pair; no pruning is attempted.
190  for (size_t q = 0; q < querySet.n_cols; ++q)
191  for (size_t r = 0; r < candidateSet.n_cols; ++r)
192  rules.BaseCase(q, r);
193 
194  rules.GetResults(neighbors, distances);
195 
196  // Map the neighbors back to their original indices in the reference set.
// Until this loop runs, entries of neighbors are column indices into
// candidateSet.
197  for (size_t i = 0; i < neighbors.n_elem; ++i)
198  neighbors[i] = candidateIndices[neighbors[i]];
199 }
200 
// Serialize the model.  The candidate points, their indices in the original
// reference set, and the l and m parameters are (de)serialized in that order.
// Reordering these calls would presumably break loading of previously saved
// models.  The version argument is unused.
202 template<typename MatType>
203 template<typename Archive>
205  const uint32_t /* version */)
206 {
207  ar(CEREAL_NVP(candidateSet));
208  ar(CEREAL_NVP(candidateIndices));
209  ar(CEREAL_NVP(l));
210  ar(CEREAL_NVP(m));
211 }
212 
213 } // namespace neighbor
214 } // namespace mlpack
215 
216 #endif
DrusillaSelect(const MatType &referenceSet, const size_t l, const size_t m)
Construct the DrusillaSelect object with the given reference set (this is the set that will be search...
Definition: drusilla_select_impl.hpp:29
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
A binary space partitioning tree, such as a KD-tree or a ball tree.
Definition: binary_space_tree.hpp:54
This class implements the necessary methods for the SortPolicy template parameter of the NeighborSear...
Definition: furthest_neighbor_sort.hpp:27
The NeighborSearchRules class is a template helper class used by NeighborSearch class when performing...
Definition: neighbor_search_rules.hpp:35
void serialize(Archive &ar, const uint32_t)
Serialize the model.
Definition: drusilla_select_impl.hpp:204
The L_p metric for arbitrary integer p, with an option to take the root.
Definition: lmetric.hpp:63
void Train(const MatType &referenceSet, const size_t l=0, const size_t m=0)
Build the set of candidate points on the given reference set.
Definition: drusilla_select_impl.hpp:65
void Search(const MatType &querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
Search for the k furthest neighbors of the given query set.
Definition: drusilla_select_impl.hpp:169