mlpack
orthogonal_regularizer_impl.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_ORTHOGONAL_REGULARIZER_IMPL_HPP
13 #define MLPACK_METHODS_ANN_ORTHOGONAL_REGULARIZER_IMPL_HPP
14 
// In case it hasn't been included.
#include "orthogonal_regularizer.hpp"

18 namespace mlpack {
19 namespace ann {
20 
22  factor(factor)
23 {
24  // Nothing to do here.
25 }
26 
27 template<typename MatType>
28 void OrthogonalRegularizer::Evaluate(const MatType& weight, MatType& gradient)
29 {
30  arma::mat grad = arma::zeros(arma::size(weight));
31 
32  for (size_t i = 0; i < weight.n_rows; ++i)
33  {
34  for (size_t j = 0; j < weight.n_rows; ++j)
35  {
36  if (i == j)
37  {
38  double s =
39  arma::as_scalar(
40  arma::sign((weight.row(i) * weight.row(i).t()) - 1));
41  grad.row(i) += 2 * s * weight.row(i);
42  }
43  else
44  {
45  double s = arma::as_scalar(
46  arma::sign(weight.row(i) * weight.row(j).t()));
47  grad.row(i) += s * weight.row(j);
48  grad.row(j) += s * weight.row(i);
49  }
50  }
51  }
52 
53  gradient += arma::vectorise(grad) * factor;
54 }
55 
56 template<typename Archive>
58  Archive& ar, const uint32_t /* version */)
59 {
60  ar(CEREAL_NVP(factor));
61 }
62 
63 } // namespace ann
64 } // namespace mlpack
65 
66 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
double factor
The constant for the regularization.
Definition: orthogonal_regularizer.hpp:63
OrthogonalRegularizer(double factor=1.0)
Create the regularizer object.
Definition: orthogonal_regularizer_impl.hpp:21
void serialize(Archive &ar, const uint32_t)
Serialize the regularizer (its regularization constant `factor`).
Definition: orthogonal_regularizer_impl.hpp:57
void Evaluate(const MatType &weight, MatType &gradient)
Calculate the regularization gradient and add it to the given gradient.
Definition: orthogonal_regularizer_impl.hpp:28