mlpack
lsh_search_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_IMPL_HPP
13 #define MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_IMPL_HPP
14 
15 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace neighbor {
20 
21 // Construct the object with random tables
22 template<typename SortPolicy, typename MatType>
24 LSHSearch(MatType referenceSet,
25  const size_t numProj,
26  const size_t numTables,
27  const double hashWidthIn,
28  const size_t secondHashSize,
29  const size_t bucketSize) :
30  numProj(numProj),
31  numTables(numTables),
32  hashWidth(hashWidthIn),
33  secondHashSize(secondHashSize),
34  bucketSize(bucketSize),
35  distanceEvaluations(0)
36 {
37  // Pass work to training function.
38  Train(std::move(referenceSet), numProj, numTables, hashWidthIn,
39  secondHashSize, bucketSize);
40 }
41 
42 // Construct the object with given tables
43 template<typename SortPolicy, typename MatType>
45 LSHSearch(MatType referenceSet,
46  const arma::cube& projections,
47  const double hashWidthIn,
48  const size_t secondHashSize,
49  const size_t bucketSize) :
50  numProj(projections.n_cols),
51  numTables(projections.n_slices),
52  hashWidth(hashWidthIn),
53  secondHashSize(secondHashSize),
54  bucketSize(bucketSize),
55  distanceEvaluations(0)
56 {
57  // Pass work to training function.
58  Train(std::move(referenceSet), numProj, numTables, hashWidthIn,
59  secondHashSize, bucketSize, projections);
60 }
61 
62 // Empty constructor.
63 template<typename SortPolicy, typename MatType>
65  numProj(0),
66  numTables(0),
67  hashWidth(0),
68  secondHashSize(99901),
69  bucketSize(500),
70  distanceEvaluations(0)
71 {
72 }
73 
74 // Copy constructor.
75 template<typename SortPolicy, typename MatType>
77  referenceSet(other.referenceSet), // Copy the other set.
78  numProj(other.numProj),
79  numTables(other.numTables),
80  projections(other.projections),
81  offsets(other.offsets),
82  hashWidth(other.hashWidth),
83  secondHashSize(other.secondHashSize),
84  secondHashWeights(other.secondHashWeights),
85  bucketSize(other.bucketSize),
86  secondHashTable(other.secondHashTable),
87  bucketContentSize(other.bucketContentSize),
88  bucketRowInHashTable(other.bucketRowInHashTable),
89  distanceEvaluations(other.distanceEvaluations)
90 {
91  // Nothing to do.
92 }
93 
94 // Move constructor.
95 template<typename SortPolicy, typename MatType>
97  referenceSet(std::move(other.referenceSet)),
98  numProj(other.numProj),
99  numTables(other.numTables),
100  projections(std::move(other.projections)),
101  offsets(std::move(other.offsets)),
102  hashWidth(other.hashWidth),
103  secondHashSize(other.secondHashSize),
104  secondHashWeights(std::move(other.secondHashWeights)),
105  bucketSize(other.bucketSize),
106  secondHashTable(std::move(other.secondHashTable)),
107  bucketContentSize(std::move(other.bucketContentSize)),
108  bucketRowInHashTable(std::move(other.bucketRowInHashTable)),
109  distanceEvaluations(other.distanceEvaluations)
110 {
111  // Reset other model to defaults.
112  other.numProj = 0;
113  other.numTables = 0;
114  other.hashWidth = 0;
115  other.secondHashSize = 99901;
116  other.bucketSize = 500;
117  other.distanceEvaluations = 0;
118 }
119 
120 // Copy operator.
121 template<typename SortPolicy, typename MatType>
123  const LSHSearch& other)
124 {
125  referenceSet = other.referenceSet;
126  numProj = other.numProj;
127  numTables = other.numTables;
128  projections = other.projections;
129  offsets = other.offsets;
130  hashWidth = other.hashWidth;
131  secondHashSize = other.secondHashSize;
132  secondHashWeights = other.secondHashWeights;
133  bucketSize = other.bucketSize;
134  secondHashTable = other.secondHashTable;
135  bucketContentSize = other.bucketContentSize;
136  bucketRowInHashTable = other.bucketRowInHashTable;
137  distanceEvaluations = other.distanceEvaluations;
138 
139  return *this;
140 }
141 
142 // Move operator.
143 template<typename SortPolicy, typename MatType>
145  LSHSearch&& other)
146 {
147  referenceSet = std::move(other.referenceSet);
148  numProj = other.numProj;
149  numTables = other.numTables;
150  projections = std::move(other.projections);
151  offsets = std::move(other.offsets);
152  hashWidth = other.hashWidth;
153  secondHashSize = other.secondHashSize;
154  secondHashWeights = std::move(other.secondHashWeights);
155  bucketSize = other.bucketSize;
156  secondHashTable = std::move(other.secondHashTable);
157  bucketContentSize = std::move(other.bucketContentSize);
158  bucketRowInHashTable = std::move(other.bucketRowInHashTable);
159  distanceEvaluations = other.distanceEvaluations;
160 
161  // Reset other model to defaults.
162  other.numProj = 0;
163  other.numTables = 0;
164  other.hashWidth = 0;
165  other.secondHashSize = 99901;
166  other.bucketSize = 500;
167  other.distanceEvaluations = 0;
168 
169  return *this;
170 }
171 
172 // Train on a new reference set.
173 template<typename SortPolicy, typename MatType>
174 void LSHSearch<SortPolicy, MatType>::Train(MatType referenceSet,
175  const size_t numProj,
176  const size_t numTables,
177  const double hashWidthIn,
178  const size_t secondHashSize,
179  const size_t bucketSize,
180  const arma::cube& projection)
181 {
182  // Set new reference set.
183  this->referenceSet = std::move(referenceSet);
184 
185  // Set new parameters.
186  this->numProj = numProj;
187  this->numTables = numTables;
188  this->hashWidth = hashWidthIn;
189  this->secondHashSize = secondHashSize;
190  this->bucketSize = bucketSize;
191 
192  if (hashWidth == 0.0) // The user has not provided any value.
193  {
194  const size_t numSamples = 25;
195  // Compute a heuristic hash width from the data.
196  for (size_t i = 0; i < numSamples; ++i)
197  {
198  size_t p1 = (size_t) math::RandInt(this->referenceSet.n_cols);
199  size_t p2 = (size_t) math::RandInt(this->referenceSet.n_cols);
200 
201  hashWidth += std::sqrt(metric::EuclideanDistance::Evaluate(
202  this->referenceSet.col(p1),
203  this->referenceSet.col(p2)));
204  }
205 
206  hashWidth /= numSamples;
207  }
208 
209  Log::Info << "Hash width chosen as: " << hashWidth << std::endl;
210 
211  // Hash building procedure:
212  // The first level hash for a single table outputs a 'numProj'-dimensional
213  // integer key for each point in the set -- (key, pointID). The key creation
214  // details are presented below.
215 
216  // Step I: Prepare the second level hash.
217 
218  // Obtain the weights for the second hash.
219  secondHashWeights = arma::floor(arma::randu(numProj) *
220  (double) secondHashSize);
221 
222  // Instead of putting the points in the row corresponding to the bucket, we
223  // chose the next empty row and keep track of the row in which the bucket
224  // lies. This allows us to stack together and slice out the empty buckets at
225  // the end of the hashing.
226  bucketRowInHashTable.set_size(secondHashSize);
227  bucketRowInHashTable.fill(secondHashSize);
228 
229  // Step II: The offsets for all projections in all tables.
230  // Since the 'offsets' are in [0, hashWidth], we obtain the 'offsets'
231  // as randu(numProj, numTables) * hashWidth.
232  offsets.randu(numProj, numTables);
233  offsets *= hashWidth;
234 
235  // Step III: Obtain the 'numProj' projections for each table.
236  projections.clear(); // Reset projections vector.
237 
238  if (projection.n_slices == 0) // Randomly generate the tables.
239  {
240  // For L2 metric, 2-stable distributions are used, and the normal Z ~ N(0,
241  // 1) is a 2-stable distribution.
242 
243  // Build numTables random tables arranged in a cube.
244  projections.randn(this->referenceSet.n_rows, numProj, numTables);
245  }
246  else if (projection.n_slices == numTables) // Take user-defined tables.
247  {
248  projections = projection;
249  }
250  else // The user gave something wrong.
251  {
252  throw std::invalid_argument("LSHSearch::Train(): number of projection "
253  "tables provided must be equal to numProj");
254  }
255 
256  // We will store the second hash vectors in this matrix; the second hash
257  // vector for table i will be held in row i. We have to use int and not
258  // size_t, otherwise negative numbers are cast to 0.
259  arma::Mat<size_t> secondHashVectors(numTables, this->referenceSet.n_cols);
260 
261  for (size_t i = 0; i < numTables; ++i)
262  {
263  // Step IV: create the 'numProj'-dimensional key for each point in each
264  // table.
265 
266  // The following code performs the task of hashing each point to a
267  // 'numProj'-dimensional integer key. Hence you get a ('numProj' x
268  // 'referenceSet.n_cols') key matrix.
269  //
270  // For a single table, let the 'numProj' projections be denoted by 'proj_i'
271  // and the corresponding offset be 'offset_i'. Then the key of a single
272  // point is obtained as:
273  // key = { floor((<proj_i, point> + offset_i) / 'hashWidth') forall i }
274  arma::mat offsetMat = arma::repmat(offsets.unsafe_col(i), 1,
275  this->referenceSet.n_cols);
276  arma::mat hashMat = projections.slice(i).t() * (this->referenceSet);
277  hashMat += offsetMat;
278  hashMat /= hashWidth;
279 
280  // Step V: Putting the points in the 'secondHashTable' by hashing the key.
281  // Now we hash every key, point ID to its corresponding bucket. We must
282  // also normalize the hashes to the range [0, secondHashSize).
283  arma::rowvec unmodVector = secondHashWeights.t() * arma::floor(hashMat);
284  for (size_t j = 0; j < unmodVector.n_elem; ++j)
285  {
286  double shs = (double) secondHashSize; // Convenience cast.
287  if (unmodVector[j] >= 0.0)
288  {
289  const size_t key = size_t(fmod(unmodVector[j], shs));
290  secondHashVectors(i, j) = key;
291  }
292  else
293  {
294  const double mod = fmod(-unmodVector[j], shs);
295  const size_t key = (mod < 1.0) ? 0 : secondHashSize - size_t(mod);
296  secondHashVectors(i, j) = key;
297  }
298  }
299  }
300 
301  // Now, using the hash vectors for each table, count the number of rows we
302  // have in the second hash table.
303  arma::Row<size_t> secondHashBinCounts(secondHashSize, arma::fill::zeros);
304  for (size_t i = 0; i < secondHashVectors.n_elem; ++i)
305  secondHashBinCounts[secondHashVectors[i]]++;
306 
307  // Enforce the maximum bucket size.
308  const size_t effectiveBucketSize = (bucketSize == 0) ? SIZE_MAX : bucketSize;
309  secondHashBinCounts.transform([effectiveBucketSize](size_t val)
310  { return std::min(val, effectiveBucketSize); });
311 
312  const size_t numRowsInTable = arma::accu(secondHashBinCounts > 0);
313  bucketContentSize.zeros(numRowsInTable);
314  secondHashTable.resize(numRowsInTable);
315 
316  // Next we must assign each point in each table to the right second hash
317  // table.
318  size_t currentRow = 0;
319  for (size_t i = 0; i < numTables; ++i)
320  {
321  // Insert the point in the corresponding row to its bucket in the
322  // 'secondHashTable'.
323  for (size_t j = 0; j < secondHashVectors.n_cols; ++j)
324  {
325  // This is the bucket number.
326  size_t hashInd = (size_t) secondHashVectors(i, j);
327  // The point ID is 'j'.
328 
329  // If this is currently an empty bucket, start a new row keep track of
330  // which row corresponds to the bucket.
331  const size_t maxSize = secondHashBinCounts[hashInd];
332  if (bucketRowInHashTable[hashInd] == secondHashSize)
333  {
334  bucketRowInHashTable[hashInd] = currentRow;
335  secondHashTable[currentRow].set_size(maxSize);
336  currentRow++;
337  }
338 
339  // If this vector in the hash table is not full, add the point.
340  const size_t index = bucketRowInHashTable[hashInd];
341  if (bucketContentSize[index] < maxSize)
342  secondHashTable[index](bucketContentSize[index]++) = j;
343  } // Loop over all points in the reference set.
344  } // Loop over tables.
345 
346  Log::Info << "Final hash table size: " << numRowsInTable << " rows, with a "
347  << "maximum length of " << arma::max(secondHashBinCounts) << ", "
348  << "totaling " << arma::accu(secondHashBinCounts) << " elements."
349  << std::endl;
350 }
351 
352 // Base case where the query set is the reference set. (So, we can't return
353 // ourselves as the nearest neighbor.)
354 template<typename SortPolicy, typename MatType>
355 inline force_inline
357  const size_t queryIndex,
358  const arma::uvec& referenceIndices,
359  const size_t k,
360  arma::Mat<size_t>& neighbors,
361  arma::mat& distances) const
362 {
363  // Let's build the list of candidate neighbors for the given query point.
364  // It will be initialized with k candidates:
365  // (WorstDistance, referenceSet.n_cols)
366  const Candidate def = std::make_pair(SortPolicy::WorstDistance(),
367  referenceSet.n_cols);
368  std::vector<Candidate> vect(k, def);
369  CandidateList pqueue(CandidateCmp(), std::move(vect));
370 
371  for (size_t j = 0; j < referenceIndices.n_elem; ++j)
372  {
373  const size_t referenceIndex = referenceIndices[j];
374  // If the points are the same, skip this point.
375  if (queryIndex == referenceIndex)
376  continue;
377 
378  const double distance = metric::EuclideanDistance::Evaluate(
379  referenceSet.col(queryIndex),
380  referenceSet.col(referenceIndex));
381 
382  Candidate c = std::make_pair(distance, referenceIndex);
383  // If this distance is better than the worst candidate, let's insert it.
384  if (CandidateCmp()(c, pqueue.top()))
385  {
386  pqueue.pop();
387  pqueue.push(c);
388  }
389  }
390 
391  for (size_t j = 1; j <= k; ++j)
392  {
393  neighbors(k - j, queryIndex) = pqueue.top().second;
394  distances(k - j, queryIndex) = pqueue.top().first;
395  pqueue.pop();
396  }
397 }
398 
399 // Base case for bichromatic search.
400 template<typename SortPolicy, typename MatType>
401 inline force_inline
403  const size_t queryIndex,
404  const arma::uvec& referenceIndices,
405  const size_t k,
406  const MatType& querySet,
407  arma::Mat<size_t>& neighbors,
408  arma::mat& distances) const
409 {
410  // Let's build the list of candidate neighbors for the given query point.
411  // It will be initialized with k candidates:
412  // (WorstDistance, referenceSet.n_cols)
413  const Candidate def = std::make_pair(SortPolicy::WorstDistance(),
414  referenceSet.n_cols);
415  std::vector<Candidate> vect(k, def);
416  CandidateList pqueue(CandidateCmp(), std::move(vect));
417 
418  for (size_t j = 0; j < referenceIndices.n_elem; ++j)
419  {
420  const size_t referenceIndex = referenceIndices[j];
421  const double distance = metric::EuclideanDistance::Evaluate(
422  querySet.col(queryIndex),
423  referenceSet.col(referenceIndex));
424 
425  Candidate c = std::make_pair(distance, referenceIndex);
426  // If this distance is better than the worst candidate, let's insert it.
427  if (CandidateCmp()(c, pqueue.top()))
428  {
429  pqueue.pop();
430  pqueue.push(c);
431  }
432  }
433 
434  for (size_t j = 1; j <= k; ++j)
435  {
436  neighbors(k - j, queryIndex) = pqueue.top().second;
437  distances(k - j, queryIndex) = pqueue.top().first;
438  pqueue.pop();
439  }
440 }
441 
442 template<typename SortPolicy, typename MatType>
443 inline force_inline
445  const std::vector<bool>& A,
446  const arma::vec& scores) const
447 {
448  double score = 0.0;
449  for (size_t i = 0; i < A.size(); ++i)
450  if (A[i])
451  score += scores(i); // add scores of non-zero indices
452  return score;
453 }
454 
455 template<typename SortPolicy, typename MatType>
456 inline force_inline
458  std::vector<bool>& A) const
459 {
460  size_t maxPos = 0;
461  for (size_t i = 0; i < A.size(); ++i)
462  if (A[i] == 1) // Marked true.
463  maxPos = i;
464 
465  if (maxPos + 1 < A.size()) // Otherwise, this is an invalid vector.
466  {
467  A[maxPos] = 0;
468  A[maxPos + 1] = 1;
469  return true; // valid
470  }
471  return false; // invalid
472 }
473 
474 template<typename SortPolicy, typename MatType>
475 inline force_inline
477  std::vector<bool>& A) const
478 {
479  // Find the last '1' in A.
480  size_t maxPos = 0;
481  for (size_t i = 0; i < A.size(); ++i)
482  if (A[i]) // Marked true.
483  maxPos = i;
484 
485  if (maxPos + 1 < A.size()) // Otherwise, this is an invalid vector.
486  {
487  A[maxPos + 1] = 1;
488  return true;
489  }
490  return false;
491 }
492 
493 template<typename SortPolicy, typename MatType>
494 inline force_inline
496  const std::vector<bool>& A) const
497 {
498  // Use check to mark dimensions we have seen before in A. If a dimension is
499  // seen twice (or more), A is not a valid perturbation.
500  std::vector<bool> check(numProj);
501 
502  if (A.size() > 2 * numProj)
503  return false; // This should never happen.
504 
505  // Check that we only see each dimension once. If not, vector is not valid.
506  for (size_t i = 0; i < A.size(); ++i)
507  {
508  // Only check dimensions that were included.
509  if (!A[i])
510  continue;
511 
512  // If dimesnion is unseen thus far, mark it as seen.
513  if (check[i % numProj] == false)
514  check[i % numProj] = true;
515  else
516  return false; // If dimension was seen before, set is not valid.
517  }
518  // If we didn't fail, set is valid.
519  return true;
520 }
521 
522 // Compute additional probing bins for a query
523 template<typename SortPolicy, typename MatType>
525  const arma::vec& queryCode,
526  const arma::vec& queryCodeNotFloored,
527  const size_t T,
528  arma::mat& additionalProbingBins) const
529 {
530  // No additional bins requested. Our work is done.
531  if (T == 0)
532  return;
533 
534  // Each column of additionalProbingBins is the code of a bin.
535  additionalProbingBins.set_size(numProj, T);
536 
537  // Copy the query's code, then in the end we will add/subtract according
538  // to perturbations we calculated.
539  for (size_t c = 0; c < T; ++c)
540  additionalProbingBins.col(c) = queryCode;
541 
542 
543  // Calculate query point's projection position.
544  arma::mat projection = queryCodeNotFloored;
545 
546  // Use projection to calculate query's distance from hash limits.
547  arma::vec limLow = projection - queryCode * hashWidth;
548  arma::vec limHigh = hashWidth - limLow;
549 
550  // Calculate scores. score = distance^2.
551  arma::vec scores(2 * numProj);
552  scores.rows(0, numProj - 1) = arma::pow(limLow, 2);
553  scores.rows(numProj, (2 * numProj) - 1) = arma::pow(limHigh, 2);
554 
555  // Actions vector describes what perturbation (-1/+1) corresponds to a score.
556  arma::Col<short int> actions(2 * numProj); // will be [-1 ... 1 ...]
557  actions.rows(0, numProj - 1) = // First numProj rows.
558  -1 * arma::ones< arma::Col<short int> > (numProj); // -1s
559  actions.rows(numProj, (2 * numProj) - 1) = // Last numProj rows.
560  arma::ones< arma::Col<short int> > (numProj); // 1s
561 
562 
563  // Acting dimension vector shows which coordinate to transform according to
564  // actions (actions are described by actions vector above).
565  arma::Col<size_t> positions(2 * numProj); // Will be [0 1 2 ... 0 1 2 ...].
566  positions.rows(0, numProj - 1) =
567  arma::linspace< arma::Col<size_t> >(0, numProj - 1, numProj);
568  positions.rows(numProj, 2 * numProj - 1) =
569  arma::linspace< arma::Col<size_t> >(0, numProj - 1, numProj);
570 
571  // Special case: No need to create heap for 1 or 2 codes.
572  if (T <= 2)
573  {
574  // First, find location of minimum score, generate 1 perturbation vector,
575  // and add its code to additionalProbingBins column 0.
576 
577  // Find location and value of smallest element of scores vector.
578  double minscore = scores[0];
579  size_t minloc = 0;
580  for (size_t s = 1; s < (2 * numProj); ++s)
581  {
582  if (minscore > scores[s])
583  {
584  minscore = scores[s];
585  minloc = s;
586  }
587  }
588 
589  // Add or subtract 1 to dimension corresponding to minimum score.
590  additionalProbingBins(positions[minloc], 0) += actions[minloc];
591  if (T == 1)
592  return; // Done if asked for only 1 code.
593 
594  // Now, find location of second smallest score and generate one more vector.
595  // The second perturbation vector still can't comprise of more than one
596  // change in the bin codes, because of the way perturbation vectors
597  // are generated: First we create the one with the smallest score (Ao) and
598  // then we either add 1 extra dimension to it (Ae) or shift it by one (As).
599  // Since As contains the second smallest score, and Ae contains both the
600  // smallest and the second smallest, it's obvious that score(Ae) >
601  // score(As). Therefore the second perturbation vector is ALWAYS the vector
602  // containing only the second-lowest scoring perturbation.
603  double minscore2 = scores[0];
604  size_t minloc2 = 0;
605  for (size_t s = 0; s < (2 * numProj); ++s) // Here we can't start from 1.
606  {
607  if (minscore2 > scores[s] && s != minloc) // Second smallest.
608  {
609  minscore2 = scores[s];
610  minloc2 = s;
611  }
612  }
613 
614  // Add or subtract 1 to create second-lowest scoring vector.
615  additionalProbingBins(positions[minloc2], 1) += actions[minloc2];
616  return;
617  }
618 
619  // General case: more than 2 perturbation vectors require use of minheap.
620  // Sort everything in increasing order.
621  arma::uvec sortidx = arma::sort_index(scores);
622  scores = scores(sortidx);
623  actions = actions(sortidx);
624  positions = positions(sortidx);
625 
626  // Theory:
627  // A probing sequence is a sequence of T probing bins where a query's
628  // neighbors are most likely to be. Likelihood is dependent only on a bin's
629  // score, which is the sum of scores of all dimension-action pairs, so we
630  // need to calculate the T smallest sums of scores that are not conflicting.
631  //
632  // Method:
633  // Store each perturbation set (pair of (dimension, action)) in a
634  // std::vector. Create a minheap of scores, with each node pointing to its
635  // relevant perturbation set. Each perturbation set popped from the minheap
636  // is the next most likely perturbation set.
637  // Transform perturbation set to perturbation vector by setting the
638  // dimensions specified by the set to queryCode+action (action is {-1, 1}).
639 
640  // Perturbation sets (A) mark with 1 the (score, action, dimension) positions
641  // included in a given perturbation vector. Other spaces are 0.
642  std::vector<bool> Ao(2 * numProj);
643  Ao[0] = 1; // Smallest vector includes only smallest score.
644 
645  std::vector< std::vector<bool> > perturbationSets;
646  perturbationSets.push_back(Ao); // Storage of perturbation sets.
647 
648  std::priority_queue<
649  std::pair<double, size_t>, // contents: pairs of (score, index)
650  std::vector< // container: vector of pairs
651  std::pair<double, size_t>
652  >,
653  std::greater< std::pair<double, size_t> > // comparator of pairs
654  > minHeap; // our minheap
655 
656  // Start by adding the lowest scoring set to the minheap.
657  minHeap.push(std::make_pair(PerturbationScore(Ao, scores), 0));
658 
659  // Loop invariable: after pvec iterations, additionalProbingBins contains pvec
660  // valid codes of the lowest-scoring bins (bins most likely to contain
661  // neighbors of the query).
662  for (size_t pvec = 0; pvec < T; ++pvec)
663  {
664  std::vector<bool> Ai;
665  do
666  {
667  // Get the perturbation set corresponding to the minimum score.
668  Ai = perturbationSets[ minHeap.top().second ];
669  minHeap.pop(); // .top() returns, .pop() removes
670 
671  // Shift operation on Ai (replace max with max+1).
672  std::vector<bool> As = Ai;
673 
674  // Don't add invalid sets.
675  if (PerturbationShift(As) && PerturbationValid(As))
676  {
677  perturbationSets.push_back(As); // add shifted set to sets
678  minHeap.push(
679  std::make_pair(PerturbationScore(As, scores),
680  perturbationSets.size() - 1));
681  }
682 
683  // Expand operation on Ai (add max+1 to set).
684  std::vector<bool> Ae = Ai;
685 
686  // Don't add invalid sets.
687  if (PerturbationExpand(Ae) && PerturbationValid(Ae))
688  {
689  perturbationSets.push_back(Ae); // add expanded set to sets
690  minHeap.push(
691  std::make_pair(PerturbationScore(Ae, scores),
692  perturbationSets.size() - 1));
693  }
694  } while (!PerturbationValid(Ai)); // Discard invalid perturbations
695 
696  // Found valid perturbation set Ai. Construct perturbation vector from set.
697  for (size_t pos = 0; pos < Ai.size(); ++pos)
698  {
699  // If Ai[pos] is marked, add action to probing vector.
700  additionalProbingBins(positions(pos), pvec) += Ai[pos] ? actions(pos) : 0;
701  }
702  }
703 }
704 
705 template<typename SortPolicy, typename MatType>
706 template<typename VecType>
708  const VecType& queryPoint,
709  arma::uvec& referenceIndices,
710  size_t numTablesToSearch,
711  const size_t T) const
712 {
713  // Decide on the number of tables to look into.
714  if (numTablesToSearch == 0) // If no user input is given, search all.
715  numTablesToSearch = numTables;
716 
717  // Sanity check to make sure that the existing number of tables is not
718  // exceeded.
719  if (numTablesToSearch > numTables)
720  numTablesToSearch = numTables;
721 
722  // Hash the query in each of the 'numTablesToSearch' hash tables using the
723  // 'numProj' projections for each table. This gives us 'numTablesToSearch'
724  // keys for the query where each key is a 'numProj' dimensional integer
725  // vector.
726 
727  // Compute the projection of the query in each table.
728  arma::mat allProjInTables(numProj, numTablesToSearch);
729  arma::mat queryCodesNotFloored(numProj, numTablesToSearch);
730  for (size_t i = 0; i < numTablesToSearch; ++i)
731  queryCodesNotFloored.unsafe_col(i) = projections.slice(i).t() * queryPoint;
732 
733  queryCodesNotFloored += offsets.cols(0, numTablesToSearch - 1);
734  allProjInTables = arma::floor(queryCodesNotFloored / hashWidth);
735 
736  // Use hashMat to store the primary probing codes and any additional codes
737  // from multiprobe LSH.
738  arma::Mat<size_t> hashMat;
739  hashMat.set_size(T + 1, numTablesToSearch);
740 
741  // Compute the primary hash value of each key of the query into a bucket of
742  // the secondHashTable using the secondHashWeights.
743  hashMat.row(0) = arma::conv_to<arma::Row<size_t>> // Floor by typecasting
744  ::from(secondHashWeights.t() * allProjInTables);
745  // Mod to compute 2nd-level codes.
746  for (size_t i = 0; i < numTablesToSearch; ++i)
747  hashMat(0, i) = (hashMat(0, i) % secondHashSize);
748 
749  // Compute hash codes of additional probing bins.
750  if (T > 0)
751  {
752  for (size_t i = 0; i < numTablesToSearch; ++i)
753  {
754  // Construct this table's probing sequence of length T.
755  arma::mat additionalProbingBins;
756  GetAdditionalProbingBins(allProjInTables.unsafe_col(i),
757  queryCodesNotFloored.unsafe_col(i),
758  T,
759  additionalProbingBins);
760 
761  // Map each probing bin to a bin in secondHashTable (just like we did for
762  // the primary hash table).
763  hashMat(arma::span(1, T), i) = // Compute code of rows 1:end of column i
764  arma::conv_to< arma::Col<size_t> >:: // floor by typecasting to size_t
765  from(secondHashWeights.t() * additionalProbingBins);
766  for (size_t p = 1; p < T + 1; ++p)
767  hashMat(p, i) = (hashMat(p, i) % secondHashSize);
768  }
769  }
770 
771  // Count number of points hashed in the same bucket as the query.
772  size_t maxNumPoints = 0;
773  for (size_t i = 0; i < numTablesToSearch; ++i)
774  {
775  for (size_t p = 0; p < T + 1; ++p)
776  {
777  const size_t hashInd = hashMat(p, i); // find query's bucket
778  const size_t tableRow = bucketRowInHashTable[hashInd];
779  if (tableRow < secondHashSize)
780  maxNumPoints += bucketContentSize[tableRow]; // count bucket contents
781  }
782  }
783 
784  // There are two ways to proceed here:
785  // Either allocate a maxNumPoints-size vector, place all candidates, and run
786  // unique on the vector to discard duplicates.
787  // Or allocate a referenceSet.n_cols size vector (i.e. number of reference
788  // points) of zeros, and mark found indices as 1.
789  // Option 1 runs faster for small maxNumPoints but worse for larger values, so
790  // we choose based on a heuristic.
791  const float cutoff = 0.1;
792  const float selectivity = static_cast<float>(maxNumPoints) /
793  static_cast<float>(referenceSet.n_cols);
794 
795  if (selectivity > cutoff)
796  {
797  // Heuristic: larger maxNumPoints means we should use find() because it
798  // should be faster.
799  // Reference points hashed in the same bucket as the query are set to >0.
800  arma::Col<size_t> refPointsConsidered;
801  refPointsConsidered.zeros(referenceSet.n_cols);
802 
803  for (size_t i = 0; i < numTablesToSearch; ++i) // for all tables
804  {
805  for (size_t p = 0; p < T + 1; ++p) // For entire probing sequence.
806  {
807  // get the sequence code
808  size_t hashInd = hashMat(p, i);
809  size_t tableRow = bucketRowInHashTable[hashInd];
810 
811  if (tableRow < secondHashSize && bucketContentSize[tableRow] > 0)
812  {
813  // Pick the indices in the bucket corresponding to hashInd.
814  for (size_t j = 0; j < bucketContentSize[tableRow]; ++j)
815  refPointsConsidered[ secondHashTable[tableRow](j) ]++;
816  }
817  }
818  }
819 
820  // Only keep reference points found in at least one bucket.
821  referenceIndices = arma::find(refPointsConsidered > 0);
822  return;
823  }
824  else
825  {
826  // Heuristic: smaller maxNumPoints means we should use unique() because it
827  // should be faster.
828  // Allocate space for the query's potential neighbors.
829  arma::uvec refPointsConsideredSmall;
830  refPointsConsideredSmall.zeros(maxNumPoints);
831 
832  // Retrieve candidates.
833  size_t start = 0;
834 
835  for (size_t i = 0; i < numTablesToSearch; ++i) // For all tables
836  {
837  for (size_t p = 0; p < T + 1; ++p)
838  {
839  const size_t hashInd = hashMat(p, i); // Find the query's bucket.
840  const size_t tableRow = bucketRowInHashTable[hashInd];
841 
842  if (tableRow < secondHashSize)
843  {
844  // Store all secondHashTable points in the candidates set.
845  for (size_t j = 0; j < bucketContentSize[tableRow]; ++j)
846  refPointsConsideredSmall(start++) = secondHashTable[tableRow](j);
847  }
848  }
849  }
850 
851  // Keep only one copy of each candidate.
852  referenceIndices = arma::unique(refPointsConsideredSmall);
853  return;
854  }
855 }
856 
857 // Search for nearest neighbors in a given query set.
858 template<typename SortPolicy, typename MatType>
860  const MatType& querySet,
861  const size_t k,
862  arma::Mat<size_t>& resultingNeighbors,
863  arma::mat& distances,
864  const size_t numTablesToSearch,
865  const size_t T)
866 {
867  // Ensure the dimensionality of the query set is correct.
868  util::CheckSameDimensionality(querySet, referenceSet, "LSHSearch::Search()",
869  "query set");
870 
871  if (k > referenceSet.n_cols)
872  {
873  std::ostringstream oss;
874  oss << "LSHSearch::Search(): requested " << k << " approximate nearest "
875  << "neighbors, but reference set has " << referenceSet.n_cols
876  << " points!" << std::endl;
877  throw std::invalid_argument(oss.str());
878  }
879 
880  // Set the size of the neighbor and distance matrices.
881  resultingNeighbors.set_size(k, querySet.n_cols);
882  distances.set_size(k, querySet.n_cols);
883 
884  // If the user asked for 0 nearest neighbors... uh... we're done.
885  if (k == 0)
886  return;
887 
888  // If the user requested more than the available number of additional probing
889  // bins, set Teffective to maximum T. Maximum T is 2^numProj - 1
890  size_t Teffective = T;
891  if (T > ((size_t) ((1 << numProj) - 1)))
892  {
893  Teffective = (1 << numProj) - 1;
894  Log::Warn << "Requested " << T << " additional bins are more than "
895  << "theoretical maximum. Using " << Teffective << " instead."
896  << std::endl;
897  }
898 
899  // If the user set multiprobe, log it
900  if (Teffective > 0)
901  Log::Info << "Running multiprobe LSH with " << Teffective
902  <<" additional probing bins per table per query." << std::endl;
903 
904  size_t avgIndicesReturned = 0;
905 
906  Timer::Start("computing_neighbors");
907 
908  // Parallelization to process more than one query at a time.
909  #pragma omp parallel for \
910  shared(resultingNeighbors, distances) \
911  schedule(dynamic)\
912  reduction(+:avgIndicesReturned)
913  for (omp_size_t i = 0; i < (omp_size_t) querySet.n_cols; ++i)
914  {
915  // Go through every query point.
916  // Hash every query into every hash table and eventually into the
917  // 'secondHashTable' to obtain the neighbor candidates.
918  arma::uvec refIndices;
919  ReturnIndicesFromTable(querySet.col(i), refIndices, numTablesToSearch,
920  Teffective);
921 
922  // An informative book-keeping for the number of neighbor candidates
923  // returned on average.
924  // Make atomic to avoid race conditions when multiple threads are running
925  // #pragma omp atomic
926  avgIndicesReturned = avgIndicesReturned + refIndices.n_elem;
927 
928  // Sequentially go through all the candidates and save the best 'k'
929  // candidates.
930  BaseCase(i, refIndices, k, querySet, resultingNeighbors, distances);
931  }
932 
933  Timer::Stop("computing_neighbors");
934 
935  distanceEvaluations += avgIndicesReturned;
936  avgIndicesReturned /= querySet.n_cols;
937  Log::Info << avgIndicesReturned << " distinct indices returned on average." <<
938  std::endl;
939 }
940 
941 // Search for approximate neighbors of the reference set.
942 template<typename SortPolicy, typename MatType>
944 Search(const size_t k,
945  arma::Mat<size_t>& resultingNeighbors,
946  arma::mat& distances,
947  const size_t numTablesToSearch,
948  size_t T)
949 {
950  // This is monochromatic search; the query set is the reference set.
951  resultingNeighbors.set_size(k, referenceSet.n_cols);
952  distances.set_size(k, referenceSet.n_cols);
953 
954  // If the user requested more than the available number of additional probing
955  // bins, set Teffective to maximum T. Maximum T is 2^numProj - 1
956  size_t Teffective = T;
957  if (T > ((size_t) ((1 << numProj) - 1)))
958  {
959  Teffective = (1 << numProj) - 1;
960  Log::Warn << "Requested " << T << " additional bins are more than "
961  << "theoretical maximum. Using " << Teffective << " instead."
962  << std::endl;
963  }
964 
965  // If the user set multiprobe, log it
966  if (T > 0)
967  Log::Info << "Running multiprobe LSH with " << Teffective <<
968  " additional probing bins per table per query."<< std::endl;
969 
970  size_t avgIndicesReturned = 0;
971 
972  Timer::Start("computing_neighbors");
973 
974  // Parallelization to process more than one query at a time.
975  #pragma omp parallel for \
976  shared(resultingNeighbors, distances) \
977  schedule(dynamic)\
978  reduction(+:avgIndicesReturned)
979  for (omp_size_t i = 0; i < (omp_size_t) referenceSet.n_cols; ++i)
980  {
981  // Go through every query point.
982  // Hash every query into every hash table and eventually into the
983  // 'secondHashTable' to obtain the neighbor candidates.
984  arma::uvec refIndices;
985  ReturnIndicesFromTable(referenceSet.col(i), refIndices, numTablesToSearch,
986  Teffective);
987 
988  // An informative book-keeping for the number of neighbor candidates
989  // returned on average.
990  // Make atomic to avoid race conditions when multiple threads are running.
991  // #pragma omp atomic
992  avgIndicesReturned += refIndices.n_elem;
993 
994  // Sequentially go through all the candidates and save the best 'k'
995  // candidates.
996  BaseCase(i, refIndices, k, resultingNeighbors, distances);
997  }
998 
999  Timer::Stop("computing_neighbors");
1000 
1001  distanceEvaluations += avgIndicesReturned;
1002  avgIndicesReturned /= referenceSet.n_cols;
1003  Log::Info << avgIndicesReturned << " distinct indices returned on average." <<
1004  std::endl;
1005 }
1006 
1007 template<typename SortPolicy, typename MatType>
1009  const arma::Mat<size_t>& foundNeighbors,
1010  const arma::Mat<size_t>& realNeighbors)
1011 {
1012  if (foundNeighbors.n_rows != realNeighbors.n_rows ||
1013  foundNeighbors.n_cols != realNeighbors.n_cols)
1014  throw std::invalid_argument("LSHSearch::ComputeRecall(): matrices provided"
1015  " must have equal size");
1016 
1017  const size_t queries = foundNeighbors.n_cols;
1018  const size_t neighbors = foundNeighbors.n_rows; // Should be equal to k.
1019 
1020  // The recall is the set intersection of found and real neighbors.
1021  size_t found = 0;
1022  for (size_t col = 0; col < queries; ++col)
1023  for (size_t row = 0; row < neighbors; ++row)
1024  for (size_t nei = 0; nei < realNeighbors.n_rows; ++nei)
1025  if (realNeighbors(row, col) == foundNeighbors(nei, col))
1026  {
1027  found++;
1028  break;
1029  }
1030 
1031  return ((double) found) / realNeighbors.n_elem;
1032 }
1033 
1034 template<typename SortPolicy, typename MatType>
1035 template<typename Archive>
1037  const uint32_t /* version */)
1038 {
1039  ar(CEREAL_NVP(referenceSet));
1040  ar(CEREAL_NVP(numProj));
1041  ar(CEREAL_NVP(numTables));
1042 
1043  // Delete existing projections, if necessary.
1044  if (cereal::is_loading<Archive>())
1045  projections.reset();
1046 
1047  ar(CEREAL_NVP(projections));
1048  ar(CEREAL_NVP(offsets));
1049  ar(CEREAL_NVP(hashWidth));
1050  ar(CEREAL_NVP(secondHashSize));
1051  ar(CEREAL_NVP(secondHashWeights));
1052  ar(CEREAL_NVP(bucketSize));
1053  ar(CEREAL_NVP(secondHashTable));
1054  ar(CEREAL_NVP(bucketContentSize));
1055  ar(CEREAL_NVP(bucketRowInHashTable));
1056  ar(CEREAL_NVP(distanceEvaluations));
1057 }
1058 
1059 } // namespace neighbor
1060 } // namespace mlpack
1061 
1062 #endif
static void Start(const std::string &name)
Start the given timer.
Definition: timers.cpp:28
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The core includes that mlpack expects; standard C++ includes and Armadillo.
Definition: pointer_wrapper.hpp:23
void serialize(Archive &ar, const uint32_t version)
Serialize the LSH model.
Definition: lsh_search_impl.hpp:1036
Definition: sfinae_test.cpp:18
LSHSearch & operator=(const LSHSearch &other)
Copy the given LSH model.
Definition: lsh_search_impl.hpp:122
void Train(MatType referenceSet, const size_t numProj, const size_t numTables, const double hashWidth=0.0, const size_t secondHashSize=99901, const size_t bucketSize=500, const arma::cube &projection=arma::cube())
Train the LSH model on the given dataset.
Definition: lsh_search_impl.hpp:174
The LSHSearch class; this class builds a hash on the reference set and uses this hash to compute the ...
Definition: lsh_search.hpp:72
static VecTypeA::elem_type Evaluate(const VecTypeA &a, const VecTypeB &b)
Computes the distance between two points.
Definition: lmetric_impl.hpp:24
static MLPACK_EXPORT util::PrefixedOutStream Warn
Prints warning messages prefixed with [WARN ].
Definition: log.hpp:87
static double ComputeRecall(const arma::Mat< size_t > &foundNeighbors, const arma::Mat< size_t > &realNeighbors)
Compute the recall (% of neighbors found) given the neighbors returned by LSHSearch::Search and a "gr...
Definition: lsh_search_impl.hpp:1008
static void Stop(const std::string &name)
Stop the given timer.
Definition: timers.cpp:36
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
Miscellaneous math random-related routines.
int RandInt(const int hiExclusive)
Generates a uniform random integer.
Definition: random.hpp:110
void Search(const MatType &querySet, const size_t k, arma::Mat< size_t > &resultingNeighbors, arma::mat &distances, const size_t numTablesToSearch=0, const size_t T=0)
Compute the nearest neighbors of the points in the given query set and store the output in the given ...
Definition: lsh_search_impl.hpp:859
LSHSearch()
Create an untrained LSH model.
Definition: lsh_search_impl.hpp:64