12 #ifndef MLPACK_METHODS_ANN_ORTHOGONAL_REGULARIZER_IMPL_HPP 13 #define MLPACK_METHODS_ANN_ORTHOGONAL_REGULARIZER_IMPL_HPP 27 template<
typename MatType>
30 arma::mat grad = arma::zeros(arma::size(weight));
32 for (
size_t i = 0; i < weight.n_rows; ++i)
34 for (
size_t j = 0; j < weight.n_rows; ++j)
40 arma::sign((weight.row(i) * weight.row(i).t()) - 1));
41 grad.row(i) += 2 * s * weight.row(i);
45 double s = arma::as_scalar(
46 arma::sign(weight.row(i) * weight.row(j).t()));
47 grad.row(i) += s * weight.row(j);
48 grad.row(j) += s * weight.row(i);
53 gradient += arma::vectorise(grad) *
factor;
56 template<
typename Archive>
58 Archive& ar,
const uint32_t )
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
double factor
The constant for the regularization.
Definition: orthogonal_regularizer.hpp:63
OrthogonalRegularizer(double factor=1.0)
Create the regularizer object.
Definition: orthogonal_regularizer_impl.hpp:21
void serialize(Archive &ar, const uint32_t)
Serialize the regularizer (nothing to do).
Definition: orthogonal_regularizer_impl.hpp:57
void Evaluate(const MatType &weight, MatType &gradient)
Calculate the gradient for regularization.
Definition: orthogonal_regularizer_impl.hpp:28