// NOTE(review): lossy listing — every line is prefixed with its line number
// from the original file, and the signature, braces and some statements are
// not present. Comments describe only the statements that are visible.
// Constructor fragment of LMNNFunction (the reference list at the end of this
// listing places the constructor at original line 23).
12 #ifndef MLPACK_METHODS_LMNN_FUNCTION_IMPL_HPP 13 #define MLPACK_METHODS_LMNN_FUNCTION_IMPL_HPP 22 template<
typename MetricType>
24 const arma::Row<size_t>& labels,
26 double regularization,
// The dataset and labels members are built with math::MakeAlias from the
// const_cast'ed arguments, passing false as the second argument.
29 dataset(math::
MakeAlias(const_cast<
arma::mat&>(dataset), false)),
30 labels(math::
MakeAlias(const_cast<
arma::Row<size_t>&>(labels), false)),
33 regularization(regularization),
36 constraint(dataset, labels, k),
37 points(dataset.n_cols),
// initialPoint becomes an identity matrix of size n_rows x n_rows.
41 initialPoint.eye(dataset.n_rows, dataset.n_rows);
43 transformedDataset = dataset;
// Cache arma::norm of every dataset column in norm.
46 norm.set_size(dataset.n_cols);
47 for (
size_t i = 0; i < dataset.n_cols; ++i)
49 norm(i) = arma::norm(dataset.col(i));
// Per-point caches: evalOld is k x k x n_cols, maxImpNorm is k x n_cols, and
// lastTransformationIndices holds one zero-initialized entry per point.
53 evalOld.set_size(k, k, dataset.n_cols);
56 maxImpNorm.set_size(k, dataset.n_cols);
59 lastTransformationIndices.set_size(dataset.n_cols);
60 lastTransformationIndices.zeros();
// First entry of the transformation cache: emptyMat (declared on a line that
// is not shown) with a count of n_cols.
64 oldTransformationMatrices.push_back(emptyMat);
65 oldTransformationCounts.push_back(dataset.n_cols);
// minCount is the smallest entry of the histogram of labels taken over the
// unique label values.
68 size_t minCount = arma::min(arma::histc(labels, arma::unique(labels)));
// When minCount is at most k + 1, the neighbor buffers get k rows.
69 if (minCount <= k + 1)
72 targetNeighbors.set_size(k, dataset.n_cols);
73 impostors.set_size(k, dataset.n_cols);
74 distance.set_size(k, dataset.n_cols);
// NOTE(review): the next statements presumably belong to an else branch whose
// keyword is not visible: constraint.K() becomes k + 1 and the buffers get
// k + 1 rows. Other fragments index impostors(k, i) / distance(k, i), which
// needs the k + 1 row sizing — confirm those accesses are guarded.
79 constraint.K() = k + 1;
82 targetNeighbors.set_size(k + 1, dataset.n_cols);
83 impostors.set_size(k + 1, dataset.n_cols);
84 distance.set_size(k + 1, dataset.n_cols);
// Fill targetNeighbors and impostors through the constraint object.
87 constraint.TargetNeighbors(targetNeighbors, dataset, labels, norm);
88 constraint.Impostors(impostors, dataset, labels, norm);
// Fragment of Shuffle() (the reference list at the end of this listing places
// it at original line 96); the signature and braces are not present here.
// Reorders the per-point state with one random permutation of the column
// indices.
95 template<
typename MetricType>
// Copies of the current state, used as the source for the reordering below.
98 arma::mat newDataset = dataset;
99 arma::Mat<size_t> newLabels = labels;
100 arma::cube newEvalOld = evalOld;
101 arma::vec newlastTransformationIndices = lastTransformationIndices;
102 arma::mat newMaxImpNorm = maxImpNorm;
103 arma::vec newNorm = norm;
// Shuffled sequence of the n_cols values from 0 to n_cols - 1.
// NOTE(review): n_cols - 1 presumably wraps around for an empty dataset if
// n_cols is unsigned — confirm an empty dataset is never shuffled.
106 arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
107 dataset.n_cols - 1, dataset.n_cols));
// Apply the same ordering to every per-column / per-element container.
112 dataset = newDataset.cols(ordering);
113 labels = newLabels.cols(ordering);
114 maxImpNorm = newMaxImpNorm.cols(ordering);
115 lastTransformationIndices = newlastTransformationIndices.elem(ordering);
116 norm = newNorm.elem(ordering);
// evalOld is a cube, so its slices are reordered one at a time.
118 for (
size_t i = 0; i < ordering.n_elem; ++i)
120 evalOld.slice(i) = newEvalOld.slice(ordering(i));
// Clear the constraint's PreCalulated flag and recompute targetNeighbors on
// the reordered data.
// NOTE(review): impostors and distance are not reordered in the visible
// lines — confirm they are refreshed before their next use.
124 constraint.PreCalulated() =
false;
125 constraint.TargetNeighbors(targetNeighbors, dataset, labels, norm);
// Fragment of a cache-update routine; the signature and braces are not
// present in this listing. Presumably UpdateCache, which other fragments call
// with (transformation, begin, batchSize) — confirm against the class
// declaration.
// Stores the given transformation in the cache of old transformation matrices
// and re-points the batch [begin, begin + batchSize) at it.
129 template<
typename MetricType>
131 const arma::mat& transformation,
133 const size_t batchSize)
// index starts at the container size; the scan below starts at i = 1 and tests
// for a zero count. The body of that test is not visible — presumably it
// records the free slot in index.
136 size_t index = oldTransformationMatrices.size();
137 for (
size_t i = 1; i < oldTransformationCounts.size(); ++i)
139 if (oldTransformationCounts[i] == 0)
// index still equal to the size: append a new matrix with a count of 0.
147 if (index == oldTransformationMatrices.size())
149 oldTransformationMatrices.push_back(transformation);
150 oldTransformationCounts.push_back(0);
// NOTE(review): presumably the else branch (keyword not visible): overwrite
// the slot at index.
154 oldTransformationMatrices[index] = transformation;
// Move every point of the batch from its previous cache entry to index,
// decrementing the previous entry's count.
158 for (
size_t i = begin; i < begin + batchSize; ++i)
160 --oldTransformationCounts[lastTransformationIndices(i)];
161 lastTransformationIndices(i) = index;
164 oldTransformationCounts[index] += batchSize;
// Consistency checks: each count from entry 1 onward is asserted to be at most
// n_cols, and the counts are summed into total (declared on a line that is
// not shown).
168 for (
size_t i = 1; i < oldTransformationCounts.size(); ++i)
170 std::ostringstream oss;
171 oss <<
"transformation counts for matrix " << i
172 <<
" invalid (" << oldTransformationCounts[i] <<
")!";
173 Log::Assert(oldTransformationCounts[i] <= dataset.n_cols, oss.str());
174 total += oldTransformationCounts[i];
// Message for a check on total; the check itself is on lines that are not
// shown.
// NOTE(review): the message opens a parenthesis that the visible text never
// closes.
177 std::ostringstream oss;
178 oss <<
"total count for transformation matrices invalid (" << total
179 <<
", " <<
"should be " << dataset.n_cols <<
"!";
// True when the batch ends exactly at the last column; the guarded statement
// is not visible.
180 if (begin + batchSize == dataset.n_cols)
// Fragment of a helper that fills transformationDiffs for a batch; the
// signature and braces are not present in this listing. Presumably TransDiff,
// which other fragments call with (transformationDiffs, transformation, begin,
// batchSize) — confirm against the class declaration.
186 template<
typename MetricType>
188 std::map<size_t, double>& transformationDiffs,
189 const arma::mat& transformation,
191 const size_t batchSize)
193 for (
size_t i = begin; i < begin + batchSize; ++i)
// Only compute an entry the first time a cache index is seen in this batch.
195 if (transformationDiffs.count(lastTransformationIndices[i]) == 0)
// Cache index 0 gets a difference of 0.0.
197 if (lastTransformationIndices[i] == 0)
199 transformationDiffs[0] = 0.0;
// NOTE(review): presumably the else branch (keyword not visible): the entry is
// arma::norm of the difference between the current transformation and the
// cached matrix for that index.
203 transformationDiffs[lastTransformationIndices[i]] =
204 arma::norm(transformation -
205 oldTransformationMatrices[lastTransformationIndices(i)]);
// Fragment of Evaluate(const arma::mat& transformation) (the reference list at
// the end of this listing places it at original line 213); the signature,
// braces and the declaration of cost are not present here.
// Accumulates the objective over every column of the dataset.
212 template<
typename MetricType>
218 transformedDataset = transformation * dataset;
// Norm of the change since the previously stored transformation; stays 0 when
// no previous transformation is stored.
220 double transformationDiff = 0;
221 if (!transformationOld.is_empty())
224 transformationDiff = arma::norm(transformation - transformationOld);
// Impostor refresh, attempted when iteration % range == 0.
// NOTE(review): if the `else if` further down pairs with this `if` (braces are
// not visible), iteration is incremented twice whenever transformationOld is
// non-empty and the modulo test fails — confirm this is intended.
227 if (!transformationOld.is_empty() && iteration++ % range == 0)
// Collect into points the indices for which the bound test below holds
// (transformationDiff times a sum of norms exceeding the gap between
// distance(k, i) and distance(k - 1, i)).
232 size_t numPoints = 0;
234 for (
size_t i = 0; i < dataset.n_cols; ++i)
236 if (transformationDiff * (2 * norm(i) + norm(impostors(k - 1, i)) +
237 norm(impostors(k, i))) > distance(k, i) - distance(k - 1, i))
239 points(numPoints++) = i;
// Impostors call that is given the collected points and their count.
244 constraint.Impostors(impostors, distance,
245 transformedDataset, labels, norm, points, numPoints);
// Second Impostors call; its remaining arguments are on a line that is not
// shown.
250 constraint.Impostors(impostors, distance, transformedDataset, labels,
254 else if (iteration++ % range == 0)
257 constraint.Impostors(impostors, distance, transformedDataset, labels, norm);
260 for (
size_t i = 0; i < dataset.n_cols; ++i)
// Target-neighbor term: metric value weighted by (1 - regularization).
262 for (
size_t j = 0; j < k ; ++j)
265 double eval = metric.Evaluate(transformedDataset.col(i),
266 transformedDataset.col(targetNeighbors(j, i)));
267 cost += (1 - regularization) * eval;
// Impostor term over (target neighbor j, impostor l) pairs; bp is the upper
// limit of the inner loop and starts at k.
270 for (
int j = k - 1; j >= 0; j--)
274 for (
size_t l = 0, bp = k; l < bp ; l++)
// With a previous transformation and a cached value below -1, eval is derived
// from the cached value plus transformationDiff times a sum of norms, using
// the running maximum impostor norm kept in maxImpNorm.
281 if (!transformationOld.is_empty() && evalOld(l, j, i) < -1)
284 maxImpNorm(l, i) = std::max(maxImpNorm(l, i), norm(impostors(l, i)));
286 eval = evalOld(l, j, i) + transformationDiff *
287 (norm(targetNeighbors(j, i)) + maxImpNorm(l, i) +
// NOTE(review): `%` binds tighter than `-`, so this condition is
// `iteration - (1 % range) == 0`, not `(iteration - 1) % range == 0` —
// confirm which is intended before changing it.
294 if (iteration - 1 % range == 0)
296 eval = metric.Evaluate(transformedDataset.col(i),
297 transformedDataset.col(targetNeighbors(j, i))) -
// Target-neighbor metric value minus impostor metric value.
302 eval = metric.Evaluate(transformedDataset.col(i),
303 transformedDataset.col(targetNeighbors(j, i))) -
304 metric.Evaluate(transformedDataset.col(i),
305 transformedDataset.col(impostors(l, i)));
310 evalOld(l, j, i) = eval;
320 cost += regularization * (1 + eval);
// Cache reset; the condition guarding it is on a line that is not shown.
326 evalOld(l, j, i) = 0;
327 maxImpNorm(l, i) = 0;
// Remember this transformation for the next call.
334 transformationOld = transformation;
// Fragment of a batch routine over [begin, begin + batchSize) that accumulates
// cost; the signature, braces and the declaration of cost are not present in
// this listing. Presumably the batch overload of Evaluate — confirm against
// the class declaration.
340 template<
typename MetricType>
343 const size_t batchSize)
// Per-cache-index transformation differences for the points of this batch.
348 std::map<size_t, double> transformationDiffs;
349 TransDiff(transformationDiffs, transformation, begin, batchSize);
352 transformedDataset = transformation * dataset;
// Impostor refresh, attempted when iteration % range == 0.
// NOTE(review): if the `else if` further down pairs with this `if` (braces are
// not visible), iteration is incremented twice whenever impBounds is true and
// the modulo test fails — confirm this is intended.
354 if (impBounds && iteration++ % range == 0)
357 size_t numPoints = 0;
359 for (
size_t i = begin; i < begin + batchSize; ++i)
// Points with a non-zero cache index are added only when the bound test
// holds.
361 if (lastTransformationIndices(i))
363 if (transformationDiffs[lastTransformationIndices[i]] *
364 (2 * norm(i) + norm(impostors(k - 1, i)) +
365 norm(impostors(k, i))) > distance(k, i) - distance(k - 1, i))
367 points(numPoints++) = i;
// NOTE(review): presumably the else branch (keyword not visible): points with
// cache index 0 are always added.
372 points(numPoints++) = i;
377 constraint.Impostors(impostors, distance,
378 transformedDataset, labels, norm, points, numPoints);
380 else if (iteration++ % range == 0)
383 constraint.Impostors(impostors, distance, transformedDataset, labels,
384 norm, begin, batchSize);
387 for (
size_t i = begin; i < begin + batchSize; ++i)
// Target-neighbor term: metric value weighted by (1 - regularization).
389 for (
size_t j = 0; j < k ; ++j)
392 double eval = metric.Evaluate(transformedDataset.col(i),
393 transformedDataset.col(targetNeighbors(j, i)));
394 cost += (1 - regularization) * eval;
// Impostor term over (target neighbor j, impostor l) pairs; bp is the upper
// limit of the inner loop and starts at k.
397 for (
int j = k - 1; j >= 0; j--)
401 for (
size_t l = 0, bp = k; l < bp ; l++)
// With a non-zero cache index and a cached value below -1, eval is derived
// from the cached value plus the point's transformation difference times a
// sum of norms.
408 if (lastTransformationIndices(i) && evalOld(l, j, i) < -1)
411 maxImpNorm(l, i) = std::max(maxImpNorm(l, i), norm(impostors(l, i)));
413 eval = evalOld(l, j, i) +
414 transformationDiffs[lastTransformationIndices[i]] *
415 (norm(targetNeighbors(j, i)) + maxImpNorm(l, i) + 2 * norm(i));
// NOTE(review): `%` binds tighter than `-`, so this condition is
// `iteration - (1 % range) == 0`, not `(iteration - 1) % range == 0` —
// confirm which is intended before changing it.
421 if (iteration - 1 % range == 0)
423 eval = metric.Evaluate(transformedDataset.col(i),
424 transformedDataset.col(targetNeighbors(j, i))) -
// Target-neighbor metric value minus impostor metric value.
429 eval = metric.Evaluate(transformedDataset.col(i),
430 transformedDataset.col(targetNeighbors(j, i))) -
431 metric.Evaluate(transformedDataset.col(i),
432 transformedDataset.col(impostors(l, i)));
437 evalOld(l, j, i) = eval;
447 cost += regularization * (1 + eval);
// eval above -1 for a point with a non-zero cache index: clear its cached
// values, release its cache entry and point it at index 0.
450 if (eval > -1 && lastTransformationIndices(i))
453 evalOld(l, j, i) = 0;
454 maxImpNorm(l, i) = 0;
455 --oldTransformationCounts[lastTransformationIndices(i)];
456 lastTransformationIndices(i) = 0;
// Record the current transformation for this batch.
463 UpdateCache(transformation, begin, batchSize);
// Fragment of Gradient(transformation, gradient) (the reference list at the
// end of this listing places it at original line 471); the signature and
// braces are not present here.
// Builds the gradient from the matrix copied from pCij (cij) and the triplet
// matrix accumulated over every column (cil).
469 template<
typename MetricType>
470 template<
typename GradType>
475 transformedDataset = transformation * dataset;
477 double transformationDiff = 0;
// Impostor refresh, attempted when iteration % range == 0.
// NOTE(review): if the `else if` further down pairs with this `if` (braces are
// not visible), iteration is incremented twice whenever transformationOld is
// non-empty and the modulo test fails — confirm this is intended.
478 if (!transformationOld.is_empty() && iteration++ % range == 0)
// Norm of the change since the previously stored transformation.
481 transformationDiff = arma::norm(transformation - transformationOld);
// Collect into points the indices for which the bound test below holds.
486 size_t numPoints = 0;
488 for (
size_t i = 0; i < dataset.n_cols; ++i)
490 if (transformationDiff * (2 * norm(i) + norm(impostors(k - 1, i)) +
491 norm(impostors(k, i))) > distance(k, i) - distance(k - 1, i))
493 points(numPoints++) = i;
498 constraint.Impostors(impostors, distance,
499 transformedDataset, labels, norm, points, numPoints);
// Second Impostors call; its remaining arguments are on a line that is not
// shown.
504 constraint.Impostors(impostors, distance, transformedDataset, labels,
508 else if (iteration++ % range == 0)
511 constraint.Impostors(impostors, distance, transformedDataset, labels,
515 gradient.zeros(transformation.n_rows, transformation.n_cols);
// cij starts as a copy of the member pCij; cil starts at zero.
518 arma::mat cij = pCij;
521 arma::mat cil = arma::zeros(dataset.n_rows, dataset.n_rows);
523 for (
size_t i = 0; i < dataset.n_cols; ++i)
// Loop over (target neighbor j, impostor l) pairs; bp is the upper limit of
// the inner loop and starts at k.
525 for (
int j = k - 1; j >= 0; j--)
528 for (
size_t l = 0, bp = k; l < bp ; l++)
// With a previous transformation and a cached value below -1, eval is derived
// from the cached value plus transformationDiff times a sum of norms.
535 if (!transformationOld.is_empty() && evalOld(l, j, i) < -1)
538 maxImpNorm(l, i) = std::max(maxImpNorm(l, i), norm(impostors(l, i)));
540 eval = evalOld(l, j, i) + transformationDiff *
541 (norm(targetNeighbors(j, i)) + maxImpNorm(l, i) +
// NOTE(review): `%` binds tighter than `-`, so this condition is
// `iteration - (1 % range) == 0`, not `(iteration - 1) % range == 0` —
// confirm which is intended before changing it.
548 if (iteration - 1 % range == 0)
550 eval = metric.Evaluate(transformedDataset.col(i),
551 transformedDataset.col(targetNeighbors(j, i))) -
// Target-neighbor metric value minus impostor metric value.
556 eval = metric.Evaluate(transformedDataset.col(i),
557 transformedDataset.col(targetNeighbors(j, i))) -
558 metric.Evaluate(transformedDataset.col(i),
559 transformedDataset.col(impostors(l, i)));
564 evalOld(l, j, i) = eval;
// Cache reset; the condition guarding it is on a line that is not shown.
578 evalOld(l, j, i) = 0;
579 maxImpNorm(l, i) = 0;
// Add the outer product for the target neighbor and subtract the outer
// product for the impostor.
583 arma::vec diff = dataset.col(i) - dataset.col(targetNeighbors(j, i));
584 cil += diff * arma::trans(diff);
586 diff = dataset.col(i) - dataset.col(impostors(l, i));
587 cil -= diff * arma::trans(diff);
// gradient = 2 * transformation * weighted sum of cij and cil.
592 gradient = 2 * transformation * ((1 - regularization) * cij +
593 regularization * cil);
// Remember this transformation for the next call.
596 transformationOld = transformation;
// Fragment of a batch routine over [begin, begin + batchSize) that writes
// gradient; the signature and braces are not present in this listing.
// Presumably the batch overload of Gradient — confirm against the class
// declaration.
600 template<
typename MetricType>
601 template<
typename GradType>
605 const size_t batchSize)
608 transformedDataset = transformation * dataset;
// Per-cache-index transformation differences for the points of this batch.
611 std::map<size_t, double> transformationDiffs;
612 TransDiff(transformationDiffs, transformation, begin, batchSize);
// Impostor refresh, attempted when iteration % range == 0.
// NOTE(review): if the `else if` further down pairs with this `if` (braces are
// not visible), iteration is incremented twice whenever impBounds is true and
// the modulo test fails — confirm this is intended.
614 if (impBounds && iteration++ % range == 0)
617 size_t numPoints = 0;
619 for (
size_t i = begin; i < begin + batchSize; ++i)
// Points with a non-zero cache index are added only when the bound test
// holds.
621 if (lastTransformationIndices(i))
623 if (transformationDiffs[lastTransformationIndices[i]] *
624 (2 * norm(i) + norm(impostors(k - 1, i)) +
625 norm(impostors(k, i))) > distance(k, i) - distance(k - 1, i))
627 points(numPoints++) = i;
// NOTE(review): presumably the else branch (keyword not visible): points with
// cache index 0 are always added.
632 points(numPoints++) = i;
637 constraint.Impostors(impostors, distance,
638 transformedDataset, labels, norm, points, numPoints);
640 else if (iteration++ % range == 0)
643 constraint.Impostors(impostors, distance, transformedDataset, labels,
644 norm, begin, batchSize);
647 gradient.zeros(transformation.n_rows, transformation.n_cols);
// Both accumulators start at zero and are filled from this batch only.
649 arma::mat cij = arma::zeros(dataset.n_rows, dataset.n_rows);
650 arma::mat cil = arma::zeros(dataset.n_rows, dataset.n_rows);
652 for (
size_t i = begin; i < begin + batchSize; ++i)
// cij: sum of outer products of (point - target neighbor).
654 for (
size_t j = 0; j < k ; ++j)
657 arma::vec diff = dataset.col(i) - dataset.col(targetNeighbors(j, i));
658 cij += diff * arma::trans(diff);
// Loop over (target neighbor j, impostor l) pairs; bp is the upper limit of
// the inner loop and starts at k.
661 for (
int j = k - 1; j >= 0; j--)
664 for (
size_t l = 0, bp = k; l < bp ; l++)
// With a non-zero cache index and a cached value below -1, eval is derived
// from the cached value plus the point's transformation difference times a
// sum of norms.
671 if (lastTransformationIndices(i) && evalOld(l, j, i) < -1)
674 maxImpNorm(l, i) = std::max(maxImpNorm(l, i), norm(impostors(l, i)));
676 eval = evalOld(l, j, i) +
677 transformationDiffs[lastTransformationIndices[i]] *
678 (norm(targetNeighbors(j, i)) + maxImpNorm(l, i) + 2 * norm(i));
// NOTE(review): `%` binds tighter than `-`, so this condition is
// `iteration - (1 % range) == 0`, not `(iteration - 1) % range == 0` —
// confirm which is intended before changing it.
684 if (iteration - 1 % range == 0)
686 eval = metric.Evaluate(transformedDataset.col(i),
687 transformedDataset.col(targetNeighbors(j, i))) -
// Target-neighbor metric value minus impostor metric value.
692 eval = metric.Evaluate(transformedDataset.col(i),
693 transformedDataset.col(targetNeighbors(j, i))) -
694 metric.Evaluate(transformedDataset.col(i),
695 transformedDataset.col(impostors(l, i)));
700 evalOld(l, j, i) = eval;
// eval above -1 for a point with a non-zero cache index: clear its cached
// values, release its cache entry and point it at index 0.
711 if (eval > -1 && lastTransformationIndices(i))
714 evalOld(l, j, i) = 0;
715 maxImpNorm(l, i) = 0;
716 --oldTransformationCounts[lastTransformationIndices(i)];
717 lastTransformationIndices(i) = 0;
// cil: add the outer product for the target neighbor and subtract the outer
// product for the impostor.
721 arma::vec diff = dataset.col(i) - dataset.col(targetNeighbors(j, i));
722 cil += diff * arma::trans(diff);
724 diff = dataset.col(i) - dataset.col(impostors(l, i));
725 cil -= diff * arma::trans(diff);
// gradient = 2 * transformation * weighted sum of cij and cil.
730 gradient = 2 * transformation * ((1 - regularization) * cij +
731 regularization * cil);
// Record the current transformation for this batch.
734 UpdateCache(transformation, begin, batchSize);
// Fragment of EvaluateWithGradient(transformation, gradient) (the reference
// list at the end of this listing places it at original line 740); the rest of
// the signature, the braces and the declaration of cost are not present here.
// Accumulates cost and builds gradient in one pass over every column.
738 template<
typename MetricType>
739 template<
typename GradType>
741 const arma::mat& transformation,
747 transformedDataset = transformation * dataset;
// Norm of the change since the previously stored transformation; stays 0 when
// no previous transformation is stored.
749 double transformationDiff = 0;
750 if (!transformationOld.is_empty())
753 transformationDiff = arma::norm(transformation - transformationOld);
// Impostor refresh, attempted when iteration % range == 0.
// NOTE(review): if the `else if` further down pairs with this `if` (braces are
// not visible), iteration is incremented twice whenever transformationOld is
// non-empty and the modulo test fails — confirm this is intended.
756 if (!transformationOld.is_empty() && iteration++ % range == 0)
// Collect into points the indices for which the bound test below holds.
761 size_t numPoints = 0;
763 for (
size_t i = 0; i < dataset.n_cols; ++i)
765 if (transformationDiff * (2 * norm(i) + norm(impostors(k - 1, i)) +
766 norm(impostors(k, i))) > distance(k, i) - distance(k - 1, i))
768 points(numPoints++) = i;
773 constraint.Impostors(impostors, distance,
774 transformedDataset, labels, norm, points, numPoints);
// Second Impostors call; its remaining arguments are on a line that is not
// shown.
779 constraint.Impostors(impostors, distance, transformedDataset, labels,
783 else if (iteration++ % range == 0)
786 constraint.Impostors(impostors, distance, transformedDataset, labels,
790 gradient.zeros(transformation.n_rows, transformation.n_cols);
// cij starts as a copy of the member pCij; cil starts at zero.
793 arma::mat cij = pCij;
796 arma::mat cil = arma::zeros(dataset.n_rows, dataset.n_rows);
798 for (
size_t i = 0; i < dataset.n_cols; ++i)
// Target-neighbor term: metric value weighted by (1 - regularization).
800 for (
size_t j = 0; j < k ; ++j)
803 double eval = metric.Evaluate(transformedDataset.col(i),
804 transformedDataset.col(targetNeighbors(j, i)));
805 cost += (1 - regularization) * eval;
// Impostor term over (target neighbor j, impostor l) pairs; bp is the upper
// limit of the inner loop and starts at k.
808 for (
int j = k - 1; j >= 0; j--)
811 for (
size_t l = 0, bp = k; l < bp ; l++)
// With a previous transformation and a cached value below -1, eval is derived
// from the cached value plus transformationDiff times a sum of norms.
818 if (!transformationOld.is_empty() && evalOld(l, j, i) < -1)
821 maxImpNorm(l, i) = std::max(maxImpNorm(l, i), norm(impostors(l, i)));
823 eval = evalOld(l, j, i) + transformationDiff *
824 (norm(targetNeighbors(j, i)) + maxImpNorm(l, i) +
// NOTE(review): `%` binds tighter than `-`, so this condition is
// `iteration - (1 % range) == 0`, not `(iteration - 1) % range == 0` —
// confirm which is intended before changing it.
831 if (iteration - 1 % range == 0)
833 eval = metric.Evaluate(transformedDataset.col(i),
834 transformedDataset.col(targetNeighbors(j, i))) -
// Target-neighbor metric value minus impostor metric value.
839 eval = metric.Evaluate(transformedDataset.col(i),
840 transformedDataset.col(targetNeighbors(j, i))) -
841 metric.Evaluate(transformedDataset.col(i),
842 transformedDataset.col(impostors(l, i)));
847 evalOld(l, j, i) = eval;
857 cost += regularization * (1 + eval);
// cil: add the outer product for the target neighbor and subtract the outer
// product for the impostor.
860 arma::vec diff = dataset.col(i) - dataset.col(targetNeighbors(j, i));
861 cil += diff * arma::trans(diff);
863 diff = dataset.col(i) - dataset.col(impostors(l, i));
864 cil -= diff * arma::trans(diff);
// gradient = 2 * transformation * weighted sum of cij and cil.
869 gradient = 2 * transformation * ((1 - regularization) * cij +
870 regularization * cil);
// Remember this transformation for the next call.
873 transformationOld = transformation;
// Fragment of a batch routine over [begin, begin + batchSize) that accumulates
// cost and writes gradient; the rest of the signature, the braces and the
// declaration of cost are not present in this listing. Presumably the batch
// overload of EvaluateWithGradient — confirm against the class declaration.
879 template<
typename MetricType>
880 template<
typename GradType>
882 const arma::mat& transformation,
885 const size_t batchSize)
// Per-cache-index transformation differences for the points of this batch.
890 std::map<size_t, double> transformationDiffs;
891 TransDiff(transformationDiffs, transformation, begin, batchSize);
894 transformedDataset = transformation * dataset;
// Impostor refresh, attempted when iteration % range == 0.
// NOTE(review): if the `else if` further down pairs with this `if` (braces are
// not visible), iteration is incremented twice whenever impBounds is true and
// the modulo test fails — confirm this is intended.
896 if (impBounds && iteration++ % range == 0)
899 size_t numPoints = 0;
901 for (
size_t i = begin; i < begin + batchSize; ++i)
// Points with a non-zero cache index are added only when the bound test
// holds.
903 if (lastTransformationIndices(i))
905 if (transformationDiffs[lastTransformationIndices[i]] *
906 (2 * norm(i) + norm(impostors(k - 1, i)) +
907 norm(impostors(k, i))) > distance(k, i) - distance(k - 1, i))
909 points(numPoints++) = i;
// NOTE(review): presumably the else branch (keyword not visible): points with
// cache index 0 are always added.
914 points(numPoints++) = i;
919 constraint.Impostors(impostors, distance,
920 transformedDataset, labels, norm, points, numPoints);
922 else if (iteration++ % range == 0)
925 constraint.Impostors(impostors, distance, transformedDataset, labels,
926 norm, begin, batchSize);
929 gradient.zeros(transformation.n_rows, transformation.n_cols);
// Both accumulators start at zero and are filled from this batch only.
931 arma::mat cij = arma::zeros(dataset.n_rows, dataset.n_rows);
932 arma::mat cil = arma::zeros(dataset.n_rows, dataset.n_rows);
934 for (
size_t i = begin; i < begin + batchSize; ++i)
// Target-neighbor term: adds to cost and accumulates the outer product in cij.
936 for (
size_t j = 0; j < k ; ++j)
939 double eval = metric.Evaluate(transformedDataset.col(i),
940 transformedDataset.col(targetNeighbors(j, i)));
941 cost += (1 - regularization) * eval;
944 arma::vec diff = dataset.col(i) - dataset.col(targetNeighbors(j, i));
945 cij += diff * arma::trans(diff);
// Impostor term over (target neighbor j, impostor l) pairs; bp is the upper
// limit of the inner loop and starts at k.
948 for (
int j = k - 1; j >= 0; j--)
951 for (
size_t l = 0, bp = k; l < bp ; l++)
// With a non-zero cache index and a cached value below -1, eval is derived
// from the cached value plus the point's transformation difference times a
// sum of norms.
958 if (lastTransformationIndices(i) && evalOld(l, j, i) < -1)
961 maxImpNorm(l, i) = std::max(maxImpNorm(l, i), norm(impostors(l, i)));
963 eval = evalOld(l, j, i) +
964 transformationDiffs[lastTransformationIndices[i]] *
965 (norm(targetNeighbors(j, i)) + maxImpNorm(l, i) + 2 * norm(i));
// NOTE(review): `%` binds tighter than `-`, so this condition is
// `iteration - (1 % range) == 0`, not `(iteration - 1) % range == 0` —
// confirm which is intended before changing it.
971 if (iteration - 1 % range == 0)
973 eval = metric.Evaluate(transformedDataset.col(i),
974 transformedDataset.col(targetNeighbors(j, i))) -
// Target-neighbor metric value minus impostor metric value.
979 eval = metric.Evaluate(transformedDataset.col(i),
980 transformedDataset.col(targetNeighbors(j, i))) -
981 metric.Evaluate(transformedDataset.col(i),
982 transformedDataset.col(impostors(l, i)));
987 evalOld(l, j, i) = eval;
997 cost += regularization * (1 + eval);
// cil: add the outer product for the target neighbor and subtract the outer
// product for the impostor.
1000 arma::vec diff = dataset.col(i) - dataset.col(targetNeighbors(j, i));
1001 cil += diff * arma::trans(diff);
1003 diff = dataset.col(i) - dataset.col(impostors(l, i));
1004 cil -= diff * arma::trans(diff);
// gradient = 2 * transformation * weighted sum of cij and cil.
1009 gradient = 2 * transformation * ((1 - regularization) * cij +
1010 regularization * cil);
// Record the current transformation for this batch.
1013 UpdateCache(transformation, begin, batchSize);
// Fragment of the routine that fills the member pCij; the signature and braces
// are not present in this listing, so the routine's name is not visible —
// confirm against the class declaration.
// pCij is reset to a zero n_rows x n_rows matrix and then accumulates, for
// every point and each of its k target neighbors, the outer product of the
// difference between the two columns.
1018 template<
typename MetricType>
1021 pCij.zeros(dataset.n_rows, dataset.n_rows);
1023 for (
size_t i = 0; i < dataset.n_cols; ++i)
1025 for (
size_t j = 0; j < k ; ++j)
1028 arma::vec diff = dataset.col(i) - dataset.col(targetNeighbors(j, i));
1029 pCij += diff * arma::trans(diff);
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
The Large Margin Nearest Neighbors function.
Definition: lmnn_function.hpp:46
double EvaluateWithGradient(const arma::mat &transformation, GradType &gradient)
Evaluate the LMNN objective function together with gradient for the given transformation matrix...
Definition: lmnn_function_impl.hpp:740
double Evaluate(const arma::mat &transformation)
Evaluate the LMNN function for the given transformation matrix.
Definition: lmnn_function_impl.hpp:213
void Gradient(const arma::mat &transformation, GradType &gradient)
Evaluate the gradient of the LMNN function for the given transformation matrix.
Definition: lmnn_function_impl.hpp:471
LMNNFunction(const arma::mat &dataset, const arma::Row< size_t > &labels, size_t k, double regularization, size_t range, MetricType metric=MetricType())
Constructor for LMNNFunction class.
Definition: lmnn_function_impl.hpp:23
void ClearAlias(arma::Mat< ElemType > &mat)
Clear an alias so that no data is overwritten.
Definition: make_alias.hpp:110
arma::Cube< ElemType > MakeAlias(arma::Cube< ElemType > &input, const bool strict=true)
Make an alias of a dense cube.
Definition: make_alias.hpp:24
void Shuffle()
Shuffle the points in the dataset.
Definition: lmnn_function_impl.hpp:96
static void Assert(bool condition, const std::string &message="Assert Failed.")
Checks if the specified condition is true.
Definition: log.cpp:38