mlpack
lmnn_function_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_LMNN_FUNCTION_IMPL_HPP
13 #define MLPACK_METHODS_LMNN_FUNCTION_IMPL_HPP
14 
15 #include "lmnn_function.hpp"
16 
18 
19 namespace mlpack {
20 namespace lmnn {
21 
22 template<typename MetricType>
23 LMNNFunction<MetricType>::LMNNFunction(const arma::mat& dataset,
24  const arma::Row<size_t>& labels,
25  size_t k,
26  double regularization,
27  size_t range,
28  MetricType metric) :
29  dataset(math::MakeAlias(const_cast<arma::mat&>(dataset), false)),
30  labels(math::MakeAlias(const_cast<arma::Row<size_t>&>(labels), false)),
31  k(k),
32  metric(metric),
33  regularization(regularization),
34  iteration(0),
35  range(range),
36  constraint(dataset, labels, k),
37  points(dataset.n_cols),
38  impBounds(false)
39 {
40  // Initialize the initial learning point.
41  initialPoint.eye(dataset.n_rows, dataset.n_rows);
42  // Initialize transformed dataset to base dataset.
43  transformedDataset = dataset;
44 
45  // Calculate and store norm of datapoints.
46  norm.set_size(dataset.n_cols);
47  for (size_t i = 0; i < dataset.n_cols; ++i)
48  {
49  norm(i) = arma::norm(dataset.col(i));
50  }
51 
52  // Initialize cache.
53  evalOld.set_size(k, k, dataset.n_cols);
54  evalOld.zeros();
55 
56  maxImpNorm.set_size(k, dataset.n_cols);
57  maxImpNorm.zeros();
58 
59  lastTransformationIndices.set_size(dataset.n_cols);
60  lastTransformationIndices.zeros();
61 
62  // Reserve the first element of cache.
63  arma::mat emptyMat;
64  oldTransformationMatrices.push_back(emptyMat);
65  oldTransformationCounts.push_back(dataset.n_cols);
66 
67  // Check if we can impose bounds over impostors.
68  size_t minCount = arma::min(arma::histc(labels, arma::unique(labels)));
69  if (minCount <= k + 1)
70  {
71  // Initialize target neighbors & impostors.
72  targetNeighbors.set_size(k, dataset.n_cols);
73  impostors.set_size(k, dataset.n_cols);
74  distance.set_size(k, dataset.n_cols);
75  }
76  else
77  {
78  // Update parameters.
79  constraint.K() = k + 1;
80  impBounds = true;
81  // Initialize target neighbors & impostors.
82  targetNeighbors.set_size(k + 1, dataset.n_cols);
83  impostors.set_size(k + 1, dataset.n_cols);
84  distance.set_size(k + 1, dataset.n_cols);
85  }
86 
87  constraint.TargetNeighbors(targetNeighbors, dataset, labels, norm);
88  constraint.Impostors(impostors, dataset, labels, norm);
89 
90  // Precalculate and save the gradient due to target neighbors.
91  Precalculate();
92 }
93 
95 template<typename MetricType>
97 {
98  arma::mat newDataset = dataset;
99  arma::Mat<size_t> newLabels = labels;
100  arma::cube newEvalOld = evalOld;
101  arma::vec newlastTransformationIndices = lastTransformationIndices;
102  arma::mat newMaxImpNorm = maxImpNorm;
103  arma::vec newNorm = norm;
104 
105  // Generate ordering.
106  arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
107  dataset.n_cols - 1, dataset.n_cols));
108 
109  math::ClearAlias(dataset);
110  math::ClearAlias(labels);
111 
112  dataset = newDataset.cols(ordering);
113  labels = newLabels.cols(ordering);
114  maxImpNorm = newMaxImpNorm.cols(ordering);
115  lastTransformationIndices = newlastTransformationIndices.elem(ordering);
116  norm = newNorm.elem(ordering);
117 
118  for (size_t i = 0; i < ordering.n_elem; ++i)
119  {
120  evalOld.slice(i) = newEvalOld.slice(ordering(i));
121  }
122 
123  // Re-calculate target neighbors as indices changed.
124  constraint.PreCalulated() = false;
125  constraint.TargetNeighbors(targetNeighbors, dataset, labels, norm);
126 }
127 
128 // Update cache transformation matrices.
129 template<typename MetricType>
131  const arma::mat& transformation,
132  const size_t begin,
133  const size_t batchSize)
134 {
135  // Are there any empty transformation matrices?
136  size_t index = oldTransformationMatrices.size();
137  for (size_t i = 1; i < oldTransformationCounts.size(); ++i)
138  {
139  if (oldTransformationCounts[i] == 0)
140  {
141  index = i; // Reuse this index.
142  break;
143  }
144  }
145 
146  // Did we find an unused matrix? If not, we have to allocate new space.
147  if (index == oldTransformationMatrices.size())
148  {
149  oldTransformationMatrices.push_back(transformation);
150  oldTransformationCounts.push_back(0);
151  }
152  else
153  {
154  oldTransformationMatrices[index] = transformation;
155  }
156 
157  // Update all the transformation indices.
158  for (size_t i = begin; i < begin + batchSize; ++i)
159  {
160  --oldTransformationCounts[lastTransformationIndices(i)];
161  lastTransformationIndices(i) = index;
162  }
163 
164  oldTransformationCounts[index] += batchSize;
165 
166  #ifdef DEBUG
167  size_t total = 0;
168  for (size_t i = 1; i < oldTransformationCounts.size(); ++i)
169  {
170  std::ostringstream oss;
171  oss << "transformation counts for matrix " << i
172  << " invalid (" << oldTransformationCounts[i] << ")!";
173  Log::Assert(oldTransformationCounts[i] <= dataset.n_cols, oss.str());
174  total += oldTransformationCounts[i];
175  }
176 
177  std::ostringstream oss;
178  oss << "total count for transformation matrices invalid (" << total
179  << ", " << "should be " << dataset.n_cols << "!";
180  if (begin + batchSize == dataset.n_cols)
181  Log::Assert(total == dataset.n_cols, oss.str());
182  #endif
183 }
184 
185 // Calculate norm of change in transformation.
186 template<typename MetricType>
188  std::map<size_t, double>& transformationDiffs,
189  const arma::mat& transformation,
190  const size_t begin,
191  const size_t batchSize)
192 {
193  for (size_t i = begin; i < begin + batchSize; ++i)
194  {
195  if (transformationDiffs.count(lastTransformationIndices[i]) == 0)
196  {
197  if (lastTransformationIndices[i] == 0)
198  {
199  transformationDiffs[0] = 0.0; // This won't be used anyway...
200  }
201  else
202  {
203  transformationDiffs[lastTransformationIndices[i]] =
204  arma::norm(transformation -
205  oldTransformationMatrices[lastTransformationIndices(i)]);
206  }
207  }
208  }
209 }
210 
212 template<typename MetricType>
213 double LMNNFunction<MetricType>::Evaluate(const arma::mat& transformation)
214 {
215  double cost = 0;
216 
217  // Apply metric over dataset.
218  transformedDataset = transformation * dataset;
219 
220  double transformationDiff = 0;
221  if (!transformationOld.is_empty())
222  {
223  // Calculate norm of change in transformation.
224  transformationDiff = arma::norm(transformation - transformationOld);
225  }
226 
227  if (!transformationOld.is_empty() && iteration++ % range == 0)
228  {
229  if (impBounds)
230  {
231  // Track number of data points to use for impostors calculatiom.
232  size_t numPoints = 0;
233 
234  for (size_t i = 0; i < dataset.n_cols; ++i)
235  {
236  if (transformationDiff * (2 * norm(i) + norm(impostors(k - 1, i)) +
237  norm(impostors(k, i))) > distance(k, i) - distance(k - 1, i))
238  {
239  points(numPoints++) = i;
240  }
241  }
242 
243  // Re-calculate impostors on transformed dataset.
244  constraint.Impostors(impostors, distance,
245  transformedDataset, labels, norm, points, numPoints);
246  }
247  else
248  {
249  // Re-calculate impostors on transformed dataset.
250  constraint.Impostors(impostors, distance, transformedDataset, labels,
251  norm);
252  }
253  }
254  else if (iteration++ % range == 0)
255  {
256  // Re-calculate impostors on transformed dataset.
257  constraint.Impostors(impostors, distance, transformedDataset, labels, norm);
258  }
259 
260  for (size_t i = 0; i < dataset.n_cols; ++i)
261  {
262  for (size_t j = 0; j < k ; ++j)
263  {
264  // Calculate cost due to distance between target neighbors & data point.
265  double eval = metric.Evaluate(transformedDataset.col(i),
266  transformedDataset.col(targetNeighbors(j, i)));
267  cost += (1 - regularization) * eval;
268  }
269 
270  for (int j = k - 1; j >= 0; j--)
271  {
272  // Bound constraints to avoid uneccesary computation. Here bp stands for
273  // breaking point.
274  for (size_t l = 0, bp = k; l < bp ; l++)
275  {
276  // Calculate cost due to {data point, target neighbors, impostors}
277  // triplets.
278  double eval = 0;
279 
280  // Bounds for eval.
281  if (!transformationOld.is_empty() && evalOld(l, j, i) < -1)
282  {
283  // Update cache max impostor norm.
284  maxImpNorm(l, i) = std::max(maxImpNorm(l, i), norm(impostors(l, i)));
285 
286  eval = evalOld(l, j, i) + transformationDiff *
287  (norm(targetNeighbors(j, i)) + maxImpNorm(l, i) +
288  2 * norm(i));
289  }
290 
291  // Calculate exact eval value.
292  if (eval > -1)
293  {
294  if (iteration - 1 % range == 0)
295  {
296  eval = metric.Evaluate(transformedDataset.col(i),
297  transformedDataset.col(targetNeighbors(j, i))) -
298  distance(l, i);
299  }
300  else
301  {
302  eval = metric.Evaluate(transformedDataset.col(i),
303  transformedDataset.col(targetNeighbors(j, i))) -
304  metric.Evaluate(transformedDataset.col(i),
305  transformedDataset.col(impostors(l, i)));
306  }
307  }
308 
309  // Update cache eval value.
310  evalOld(l, j, i) = eval;
311 
312  // Check bounding condition.
313  if (eval <= -1)
314  {
315  // update bound.
316  bp = l;
317  break;
318  }
319 
320  cost += regularization * (1 + eval);
321 
322  // Reset cache.
323  if (eval > -1)
324  {
325  // update bound.
326  evalOld(l, j, i) = 0;
327  maxImpNorm(l, i) = 0;
328  }
329  }
330  }
331  }
332 
333  // Update cache transformation matrix.
334  transformationOld = transformation;
335 
336  return cost;
337 }
338 
340 template<typename MetricType>
341 double LMNNFunction<MetricType>::Evaluate(const arma::mat& transformation,
342  const size_t begin,
343  const size_t batchSize)
344 {
345  double cost = 0;
346 
347  // Calculate norm of change in transformation.
348  std::map<size_t, double> transformationDiffs;
349  TransDiff(transformationDiffs, transformation, begin, batchSize);
350 
351  // Apply metric over dataset.
352  transformedDataset = transformation * dataset;
353 
354  if (impBounds && iteration++ % range == 0)
355  {
356  // Track number of data points to use for impostors calculatiom.
357  size_t numPoints = 0;
358 
359  for (size_t i = begin; i < begin + batchSize; ++i)
360  {
361  if (lastTransformationIndices(i))
362  {
363  if (transformationDiffs[lastTransformationIndices[i]] *
364  (2 * norm(i) + norm(impostors(k - 1, i)) +
365  norm(impostors(k, i))) > distance(k, i) - distance(k - 1, i))
366  {
367  points(numPoints++) = i;
368  }
369  }
370  else
371  {
372  points(numPoints++) = i;
373  }
374  }
375 
376  // Re-calculate impostors on transformed dataset.
377  constraint.Impostors(impostors, distance,
378  transformedDataset, labels, norm, points, numPoints);
379  }
380  else if (iteration++ % range == 0)
381  {
382  // Re-calculate impostors on transformed dataset.
383  constraint.Impostors(impostors, distance, transformedDataset, labels,
384  norm, begin, batchSize);
385  }
386 
387  for (size_t i = begin; i < begin + batchSize; ++i)
388  {
389  for (size_t j = 0; j < k ; ++j)
390  {
391  // Calculate cost due to distance between target neighbors & data point.
392  double eval = metric.Evaluate(transformedDataset.col(i),
393  transformedDataset.col(targetNeighbors(j, i)));
394  cost += (1 - regularization) * eval;
395  }
396 
397  for (int j = k - 1; j >= 0; j--)
398  {
399  // Bound constraints to avoid uneccesary computation. Here bp stands for
400  // breaking point.
401  for (size_t l = 0, bp = k; l < bp ; l++)
402  {
403  // Calculate cost due to {data point, target neighbors, impostors}
404  // triplets.
405  double eval = 0;
406 
407  // Bounds for eval.
408  if (lastTransformationIndices(i) && evalOld(l, j, i) < -1)
409  {
410  // Update cache max impostor norm.
411  maxImpNorm(l, i) = std::max(maxImpNorm(l, i), norm(impostors(l, i)));
412 
413  eval = evalOld(l, j, i) +
414  transformationDiffs[lastTransformationIndices[i]] *
415  (norm(targetNeighbors(j, i)) + maxImpNorm(l, i) + 2 * norm(i));
416  }
417 
418  // Calculate exact eval value.
419  if (eval > -1)
420  {
421  if (iteration - 1 % range == 0)
422  {
423  eval = metric.Evaluate(transformedDataset.col(i),
424  transformedDataset.col(targetNeighbors(j, i))) -
425  distance(l, i);
426  }
427  else
428  {
429  eval = metric.Evaluate(transformedDataset.col(i),
430  transformedDataset.col(targetNeighbors(j, i))) -
431  metric.Evaluate(transformedDataset.col(i),
432  transformedDataset.col(impostors(l, i)));
433  }
434  }
435 
436  // Update cache eval value.
437  evalOld(l, j, i) = eval;
438 
439  // Check bounding condition.
440  if (eval <= -1)
441  {
442  // update bound.
443  bp = l;
444  break;
445  }
446 
447  cost += regularization * (1 + eval);
448 
449  // Reset cache.
450  if (eval > -1 && lastTransformationIndices(i))
451  {
452  // update bound.
453  evalOld(l, j, i) = 0;
454  maxImpNorm(l, i) = 0;
455  --oldTransformationCounts[lastTransformationIndices(i)];
456  lastTransformationIndices(i) = 0;
457  }
458  }
459  }
460  }
461 
462  // Update cache.
463  UpdateCache(transformation, begin, batchSize);
464 
465  return cost;
466 }
467 
469 template<typename MetricType>
470 template<typename GradType>
471 void LMNNFunction<MetricType>::Gradient(const arma::mat& transformation,
472  GradType& gradient)
473 {
474  // Apply metric over dataset.
475  transformedDataset = transformation * dataset;
476 
477  double transformationDiff = 0;
478  if (!transformationOld.is_empty() && iteration++ % range == 0)
479  {
480  // Calculate norm of change in transformation.
481  transformationDiff = arma::norm(transformation - transformationOld);
482 
483  if (impBounds)
484  {
485  // Track number of data points to use for impostors calculatiom.
486  size_t numPoints = 0;
487 
488  for (size_t i = 0; i < dataset.n_cols; ++i)
489  {
490  if (transformationDiff * (2 * norm(i) + norm(impostors(k - 1, i)) +
491  norm(impostors(k, i))) > distance(k, i) - distance(k - 1, i))
492  {
493  points(numPoints++) = i;
494  }
495  }
496 
497  // Re-calculate impostors on transformed dataset.
498  constraint.Impostors(impostors, distance,
499  transformedDataset, labels, norm, points, numPoints);
500  }
501  else
502  {
503  // Re-calculate impostors on transformed dataset.
504  constraint.Impostors(impostors, distance, transformedDataset, labels,
505  norm);
506  }
507  }
508  else if (iteration++ % range == 0)
509  {
510  // Re-calculate impostors on transformed dataset.
511  constraint.Impostors(impostors, distance, transformedDataset, labels,
512  norm);
513  }
514 
515  gradient.zeros(transformation.n_rows, transformation.n_cols);
516 
517  // Calculate gradient due to target neighbors.
518  arma::mat cij = pCij;
519 
520  // Calculate gradient due to impostors.
521  arma::mat cil = arma::zeros(dataset.n_rows, dataset.n_rows);
522 
523  for (size_t i = 0; i < dataset.n_cols; ++i)
524  {
525  for (int j = k - 1; j >= 0; j--)
526  {
527  // Bound constraints to avoid uneccesary computation.
528  for (size_t l = 0, bp = k; l < bp ; l++)
529  {
530  // Calculate cost due to {data point, target neighbors, impostors}
531  // triplets.
532  double eval = 0;
533 
534  // Bounds for eval.
535  if (!transformationOld.is_empty() && evalOld(l, j, i) < -1)
536  {
537  // Update cache max impostor norm.
538  maxImpNorm(l, i) = std::max(maxImpNorm(l, i), norm(impostors(l, i)));
539 
540  eval = evalOld(l, j, i) + transformationDiff *
541  (norm(targetNeighbors(j, i)) + maxImpNorm(l, i) +
542  2 * norm(i));
543  }
544 
545  // Calculate exact eval value.
546  if (eval > -1)
547  {
548  if (iteration - 1 % range == 0)
549  {
550  eval = metric.Evaluate(transformedDataset.col(i),
551  transformedDataset.col(targetNeighbors(j, i))) -
552  distance(l, i);
553  }
554  else
555  {
556  eval = metric.Evaluate(transformedDataset.col(i),
557  transformedDataset.col(targetNeighbors(j, i))) -
558  metric.Evaluate(transformedDataset.col(i),
559  transformedDataset.col(impostors(l, i)));
560  }
561  }
562 
563  // Update cache eval value.
564  evalOld(l, j, i) = eval;
565 
566  // Check bounding condition.
567  if (eval <= -1)
568  {
569  // update bound.
570  bp = l;
571  break;
572  }
573 
574  // Reset cache.
575  if (eval > -1)
576  {
577  // update bound.
578  evalOld(l, j, i) = 0;
579  maxImpNorm(l, i) = 0;
580  }
581 
582  // Caculate gradient due to impostors.
583  arma::vec diff = dataset.col(i) - dataset.col(targetNeighbors(j, i));
584  cil += diff * arma::trans(diff);
585 
586  diff = dataset.col(i) - dataset.col(impostors(l, i));
587  cil -= diff * arma::trans(diff);
588  }
589  }
590  }
591 
592  gradient = 2 * transformation * ((1 - regularization) * cij +
593  regularization * cil);
594 
595  // Update cache transformation matrix.
596  transformationOld = transformation;
597 }
598 
600 template<typename MetricType>
601 template<typename GradType>
602 void LMNNFunction<MetricType>::Gradient(const arma::mat& transformation,
603  const size_t begin,
604  GradType& gradient,
605  const size_t batchSize)
606 {
607  // Apply metric over dataset.
608  transformedDataset = transformation * dataset;
609 
610  // Calculate norm of change in transformation.
611  std::map<size_t, double> transformationDiffs;
612  TransDiff(transformationDiffs, transformation, begin, batchSize);
613 
614  if (impBounds && iteration++ % range == 0)
615  {
616  // Track number of data points to use for impostors calculatiom.
617  size_t numPoints = 0;
618 
619  for (size_t i = begin; i < begin + batchSize; ++i)
620  {
621  if (lastTransformationIndices(i))
622  {
623  if (transformationDiffs[lastTransformationIndices[i]] *
624  (2 * norm(i) + norm(impostors(k - 1, i)) +
625  norm(impostors(k, i))) > distance(k, i) - distance(k - 1, i))
626  {
627  points(numPoints++) = i;
628  }
629  }
630  else
631  {
632  points(numPoints++) = i;
633  }
634  }
635 
636  // Re-calculate impostors on transformed dataset.
637  constraint.Impostors(impostors, distance,
638  transformedDataset, labels, norm, points, numPoints);
639  }
640  else if (iteration++ % range == 0)
641  {
642  // Re-calculate impostors on transformed dataset.
643  constraint.Impostors(impostors, distance, transformedDataset, labels,
644  norm, begin, batchSize);
645  }
646 
647  gradient.zeros(transformation.n_rows, transformation.n_cols);
648 
649  arma::mat cij = arma::zeros(dataset.n_rows, dataset.n_rows);
650  arma::mat cil = arma::zeros(dataset.n_rows, dataset.n_rows);
651 
652  for (size_t i = begin; i < begin + batchSize; ++i)
653  {
654  for (size_t j = 0; j < k ; ++j)
655  {
656  // Calculate gradient due to target neighbors.
657  arma::vec diff = dataset.col(i) - dataset.col(targetNeighbors(j, i));
658  cij += diff * arma::trans(diff);
659  }
660 
661  for (int j = k - 1; j >= 0; j--)
662  {
663  // Bound constraints to avoid uneccesary computation.
664  for (size_t l = 0, bp = k; l < bp ; l++)
665  {
666  // Calculate cost due to {data point, target neighbors, impostors}
667  // triplets.
668  double eval = 0;
669 
670  // Bounds for eval.
671  if (lastTransformationIndices(i) && evalOld(l, j, i) < -1)
672  {
673  // Update cache max impostor norm.
674  maxImpNorm(l, i) = std::max(maxImpNorm(l, i), norm(impostors(l, i)));
675 
676  eval = evalOld(l, j, i) +
677  transformationDiffs[lastTransformationIndices[i]] *
678  (norm(targetNeighbors(j, i)) + maxImpNorm(l, i) + 2 * norm(i));
679  }
680 
681  // Calculate exact eval value.
682  if (eval > -1)
683  {
684  if (iteration - 1 % range == 0)
685  {
686  eval = metric.Evaluate(transformedDataset.col(i),
687  transformedDataset.col(targetNeighbors(j, i))) -
688  distance(l, i);
689  }
690  else
691  {
692  eval = metric.Evaluate(transformedDataset.col(i),
693  transformedDataset.col(targetNeighbors(j, i))) -
694  metric.Evaluate(transformedDataset.col(i),
695  transformedDataset.col(impostors(l, i)));
696  }
697  }
698 
699  // Update cache eval value.
700  evalOld(l, j, i) = eval;
701 
702  // Check bounding condition.
703  if (eval <= -1)
704  {
705  // update bound.
706  bp = l;
707  break;
708  }
709 
710  // Reset cache.
711  if (eval > -1 && lastTransformationIndices(i))
712  {
713  // update bound.
714  evalOld(l, j, i) = 0;
715  maxImpNorm(l, i) = 0;
716  --oldTransformationCounts[lastTransformationIndices(i)];
717  lastTransformationIndices(i) = 0;
718  }
719 
720  // Caculate gradient due to impostors.
721  arma::vec diff = dataset.col(i) - dataset.col(targetNeighbors(j, i));
722  cil += diff * arma::trans(diff);
723 
724  diff = dataset.col(i) - dataset.col(impostors(l, i));
725  cil -= diff * arma::trans(diff);
726  }
727  }
728  }
729 
730  gradient = 2 * transformation * ((1 - regularization) * cij +
731  regularization * cil);
732 
733  // Update cache.
734  UpdateCache(transformation, begin, batchSize);
735 }
736 
738 template<typename MetricType>
739 template<typename GradType>
741  const arma::mat& transformation,
742  GradType& gradient)
743 {
744  double cost = 0;
745 
746  // Apply metric over dataset.
747  transformedDataset = transformation * dataset;
748 
749  double transformationDiff = 0;
750  if (!transformationOld.is_empty())
751  {
752  // Calculate norm of change in transformation.
753  transformationDiff = arma::norm(transformation - transformationOld);
754  }
755 
756  if (!transformationOld.is_empty() && iteration++ % range == 0)
757  {
758  if (impBounds)
759  {
760  // Track number of data points to use for impostors calculatiom.
761  size_t numPoints = 0;
762 
763  for (size_t i = 0; i < dataset.n_cols; ++i)
764  {
765  if (transformationDiff * (2 * norm(i) + norm(impostors(k - 1, i)) +
766  norm(impostors(k, i))) > distance(k, i) - distance(k - 1, i))
767  {
768  points(numPoints++) = i;
769  }
770  }
771 
772  // Re-calculate impostors on transformed dataset.
773  constraint.Impostors(impostors, distance,
774  transformedDataset, labels, norm, points, numPoints);
775  }
776  else
777  {
778  // Re-calculate impostors on transformed dataset.
779  constraint.Impostors(impostors, distance, transformedDataset, labels,
780  norm);
781  }
782  }
783  else if (iteration++ % range == 0)
784  {
785  // Re-calculate impostors on transformed dataset.
786  constraint.Impostors(impostors, distance, transformedDataset, labels,
787  norm);
788  }
789 
790  gradient.zeros(transformation.n_rows, transformation.n_cols);
791 
792  // Calculate gradient due to target neighbors.
793  arma::mat cij = pCij;
794 
795  // Calculate gradient due to impostors.
796  arma::mat cil = arma::zeros(dataset.n_rows, dataset.n_rows);
797 
798  for (size_t i = 0; i < dataset.n_cols; ++i)
799  {
800  for (size_t j = 0; j < k ; ++j)
801  {
802  // Calculate cost due to distance between target neighbors & data point.
803  double eval = metric.Evaluate(transformedDataset.col(i),
804  transformedDataset.col(targetNeighbors(j, i)));
805  cost += (1 - regularization) * eval;
806  }
807 
808  for (int j = k - 1; j >= 0; j--)
809  {
810  // Bound constraints to avoid uneccesary computation.
811  for (size_t l = 0, bp = k; l < bp ; l++)
812  {
813  // Calculate cost due to {data point, target neighbors, impostors}
814  // triplets.
815  double eval = 0;
816 
817  // Bounds for eval.
818  if (!transformationOld.is_empty() && evalOld(l, j, i) < -1)
819  {
820  // Update cache max impostor norm.
821  maxImpNorm(l, i) = std::max(maxImpNorm(l, i), norm(impostors(l, i)));
822 
823  eval = evalOld(l, j, i) + transformationDiff *
824  (norm(targetNeighbors(j, i)) + maxImpNorm(l, i) +
825  2 * norm(i));
826  }
827 
828  // Calculate exact eval value.
829  if (eval > -1)
830  {
831  if (iteration - 1 % range == 0)
832  {
833  eval = metric.Evaluate(transformedDataset.col(i),
834  transformedDataset.col(targetNeighbors(j, i))) -
835  distance(l, i);
836  }
837  else
838  {
839  eval = metric.Evaluate(transformedDataset.col(i),
840  transformedDataset.col(targetNeighbors(j, i))) -
841  metric.Evaluate(transformedDataset.col(i),
842  transformedDataset.col(impostors(l, i)));
843  }
844  }
845 
846  // Update cache eval value.
847  evalOld(l, j, i) = eval;
848 
849  // Check bounding condition.
850  if (eval <= -1)
851  {
852  // update bound.
853  bp = l;
854  break;
855  }
856 
857  cost += regularization * (1 + eval);
858 
859  // Caculate gradient due to impostors.
860  arma::vec diff = dataset.col(i) - dataset.col(targetNeighbors(j, i));
861  cil += diff * arma::trans(diff);
862 
863  diff = dataset.col(i) - dataset.col(impostors(l, i));
864  cil -= diff * arma::trans(diff);
865  }
866  }
867  }
868 
869  gradient = 2 * transformation * ((1 - regularization) * cij +
870  regularization * cil);
871 
872  // Update cache transformation matrix.
873  transformationOld = transformation;
874 
875  return cost;
876 }
877 
879 template<typename MetricType>
880 template<typename GradType>
882  const arma::mat& transformation,
883  const size_t begin,
884  GradType& gradient,
885  const size_t batchSize)
886 {
887  double cost = 0;
888 
889  // Calculate norm of change in transformation.
890  std::map<size_t, double> transformationDiffs;
891  TransDiff(transformationDiffs, transformation, begin, batchSize);
892 
893  // Apply metric over dataset.
894  transformedDataset = transformation * dataset;
895 
896  if (impBounds && iteration++ % range == 0)
897  {
898  // Track number of data points to use for impostors calculatiom.
899  size_t numPoints = 0;
900 
901  for (size_t i = begin; i < begin + batchSize; ++i)
902  {
903  if (lastTransformationIndices(i))
904  {
905  if (transformationDiffs[lastTransformationIndices[i]] *
906  (2 * norm(i) + norm(impostors(k - 1, i)) +
907  norm(impostors(k, i))) > distance(k, i) - distance(k - 1, i))
908  {
909  points(numPoints++) = i;
910  }
911  }
912  else
913  {
914  points(numPoints++) = i;
915  }
916  }
917 
918  // Re-calculate impostors on transformed dataset.
919  constraint.Impostors(impostors, distance,
920  transformedDataset, labels, norm, points, numPoints);
921  }
922  else if (iteration++ % range == 0)
923  {
924  // Re-calculate impostors on transformed dataset.
925  constraint.Impostors(impostors, distance, transformedDataset, labels,
926  norm, begin, batchSize);
927  }
928 
929  gradient.zeros(transformation.n_rows, transformation.n_cols);
930 
931  arma::mat cij = arma::zeros(dataset.n_rows, dataset.n_rows);
932  arma::mat cil = arma::zeros(dataset.n_rows, dataset.n_rows);
933 
934  for (size_t i = begin; i < begin + batchSize; ++i)
935  {
936  for (size_t j = 0; j < k ; ++j)
937  {
938  // Calculate cost due to distance between target neighbors & data point.
939  double eval = metric.Evaluate(transformedDataset.col(i),
940  transformedDataset.col(targetNeighbors(j, i)));
941  cost += (1 - regularization) * eval;
942 
943  // Calculate gradient due to target neighbors.
944  arma::vec diff = dataset.col(i) - dataset.col(targetNeighbors(j, i));
945  cij += diff * arma::trans(diff);
946  }
947 
948  for (int j = k - 1; j >= 0; j--)
949  {
950  // Bound constraints to avoid uneccesary computation.
951  for (size_t l = 0, bp = k; l < bp ; l++)
952  {
953  // Calculate cost due to {data point, target neighbors, impostors}
954  // triplets.
955  double eval = 0;
956 
957  // Bounds for eval.
958  if (lastTransformationIndices(i) && evalOld(l, j, i) < -1)
959  {
960  // Update cache max impostor norm.
961  maxImpNorm(l, i) = std::max(maxImpNorm(l, i), norm(impostors(l, i)));
962 
963  eval = evalOld(l, j, i) +
964  transformationDiffs[lastTransformationIndices[i]] *
965  (norm(targetNeighbors(j, i)) + maxImpNorm(l, i) + 2 * norm(i));
966  }
967 
968  // Calculate exact eval value.
969  if (eval > -1)
970  {
971  if (iteration - 1 % range == 0)
972  {
973  eval = metric.Evaluate(transformedDataset.col(i),
974  transformedDataset.col(targetNeighbors(j, i))) -
975  distance(l, i);
976  }
977  else
978  {
979  eval = metric.Evaluate(transformedDataset.col(i),
980  transformedDataset.col(targetNeighbors(j, i))) -
981  metric.Evaluate(transformedDataset.col(i),
982  transformedDataset.col(impostors(l, i)));
983  }
984  }
985 
986  // Update cache eval value.
987  evalOld(l, j, i) = eval;
988 
989  // Check bounding condition.
990  if (eval <= -1)
991  {
992  // update bound.
993  bp = l;
994  break;
995  }
996 
997  cost += regularization * (1 + eval);
998 
999  // Caculate gradient due to impostors.
1000  arma::vec diff = dataset.col(i) - dataset.col(targetNeighbors(j, i));
1001  cil += diff * arma::trans(diff);
1002 
1003  diff = dataset.col(i) - dataset.col(impostors(l, i));
1004  cil -= diff * arma::trans(diff);
1005  }
1006  }
1007  }
1008 
1009  gradient = 2 * transformation * ((1 - regularization) * cij +
1010  regularization * cil);
1011 
1012  // Update cache.
1013  UpdateCache(transformation, begin, batchSize);
1014 
1015  return cost;
1016 }
1017 
1018 template<typename MetricType>
1020 {
1021  pCij.zeros(dataset.n_rows, dataset.n_rows);
1022 
1023  for (size_t i = 0; i < dataset.n_cols; ++i)
1024  {
1025  for (size_t j = 0; j < k ; ++j)
1026  {
1027  // Calculate gradient due to target neighbors.
1028  arma::vec diff = dataset.col(i) - dataset.col(targetNeighbors(j, i));
1029  pCij += diff * arma::trans(diff);
1030  }
1031  }
1032 }
1033 
1034 } // namespace lmnn
1035 } // namespace mlpack
1036 
1037 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The Large Margin Nearest Neighbors function.
Definition: lmnn_function.hpp:46
double EvaluateWithGradient(const arma::mat &transformation, GradType &gradient)
Evaluate the LMNN objective function together with gradient for the given transformation matrix...
Definition: lmnn_function_impl.hpp:740
double Evaluate(const arma::mat &transformation)
Evaluate the LMNN function for the given transformation matrix.
Definition: lmnn_function_impl.hpp:213
void Gradient(const arma::mat &transformation, GradType &gradient)
Evaluate the gradient of the LMNN function for the given transformation matrix.
Definition: lmnn_function_impl.hpp:471
LMNNFunction(const arma::mat &dataset, const arma::Row< size_t > &labels, size_t k, double regularization, size_t range, MetricType metric=MetricType())
Constructor for LMNNFunction class.
Definition: lmnn_function_impl.hpp:23
void ClearAlias(arma::Mat< ElemType > &mat)
Clear an alias so that no data is overwritten.
Definition: make_alias.hpp:110
arma::Cube< ElemType > MakeAlias(arma::Cube< ElemType > &input, const bool strict=true)
Make an alias of a dense cube.
Definition: make_alias.hpp:24
void Shuffle()
Shuffle the points in the dataset.
Definition: lmnn_function_impl.hpp:96
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38