layer_norm_impl.hpp
/**
 * @file layer_norm_impl.hpp
 *
 * Implementation of the LayerNorm class.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license.  You should have received a copy of the
 * 3-clause BSD license along with mlpack.  If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
#ifndef MLPACK_METHODS_ANN_LAYER_LAYERNORM_IMPL_HPP
#define MLPACK_METHODS_ANN_LAYER_LAYERNORM_IMPL_HPP

// In case it is not included.
#include "layer_norm.hpp"

namespace mlpack {
namespace ann {
template<typename InputDataType, typename OutputDataType>
LayerNorm<InputDataType, OutputDataType>::LayerNorm() :
    size(0),
    eps(1e-8),
    loading(false)
{
  // Nothing to do here.
}

template <typename InputDataType, typename OutputDataType>
LayerNorm<InputDataType, OutputDataType>::LayerNorm(
    const size_t size, const double eps) :
    size(size),
    eps(eps),
    loading(false)
{
  weights.set_size(size + size, 1);
}
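
// Reset() maps gamma and beta onto the single `weights` vector allocated in
// the constructor: the first `size` elements hold the scale (gamma) and the
// next `size` elements hold the shift (beta).  On a fresh layer they are
// initialized to 1 and 0; after deserialization (`loading == true`) the
// loaded values are kept.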
template<typename InputDataType, typename OutputDataType>
void LayerNorm<InputDataType, OutputDataType>::Reset()
{
  gamma = arma::mat(weights.memptr(), size, 1, false, false);
  beta = arma::mat(weights.memptr() + gamma.n_elem, size, 1, false, false);

  if (!loading)
  {
    gamma.fill(1.0);
    beta.fill(0.0);
  }

  loading = false;
}
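
// Forward() treats each column of the input as one data point and normalizes
// over its rows (features): xhat = (x - mean) / sqrt(variance + eps), where
// mean and variance are computed per column; the result is then scaled by
// gamma and shifted by beta per feature.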
template<typename InputDataType, typename OutputDataType>
template<typename eT>
void LayerNorm<InputDataType, OutputDataType>::Forward(
    const arma::Mat<eT>& input, arma::Mat<eT>& output)
{
  mean = arma::mean(input, 0);
  variance = arma::var(input, 1, 0);

  // Normalize the input.
  output = input.each_row() - mean;
  inputMean = output;
  output.each_row() /= arma::sqrt(variance + eps);

  // Reused in the backward and gradient step.
  normalized = output;

  // Scale and shift the output.
  output.each_col() %= gamma;
  output.each_col() += beta;
}
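
// Backward() applies the chain rule through the per-column normalization:
// with xhat = (x - mu) * stdInv and m = input.n_rows, dl/dx is the sum of
// the direct term dl/dxhat * stdInv and the contributions that flow through
// dl/dvar and dl/dmu.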
template<typename InputDataType, typename OutputDataType>
template<typename eT>
void LayerNorm<InputDataType, OutputDataType>::Backward(
    const arma::Mat<eT>& input, const arma::Mat<eT>& gy, arma::Mat<eT>& g)
{
  const arma::mat stdInv = 1.0 / arma::sqrt(variance + eps);

  // dl / dxhat.
  const arma::mat norm = gy.each_col() % gamma;

  // dl / dvar = sum(dl / dxhat % (x - mu)) * -0.5 * stdInv^3.
  const arma::mat var = arma::sum(norm % inputMean, 0) %
      arma::pow(stdInv, 3.0) * -0.5;

  // dl / dx = dl / dxhat * stdInv + dl / dvar * 2 * (x - mu) / m +
  // dl / dmu / m.
  g = (norm.each_row() % stdInv) + (inputMean.each_row() %
      var * 2 / input.n_rows);

  // dl / dmu = sum(dl / dxhat * -stdInv); the dl / dvar contribution to
  // dl / dmu vanishes because sum(x - mu) = 0.
  g.each_row() += arma::sum(norm.each_row() % -stdInv, 0) / input.n_rows;
}
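
// Gradient() accumulates the parameter gradients over the batch:
// dl/dgamma = sum(dl/dy % xhat) and dl/dbeta = sum(dl/dy), stored in the
// same stacked layout as `weights` (gamma first, then beta).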
template<typename InputDataType, typename OutputDataType>
template<typename eT>
void LayerNorm<InputDataType, OutputDataType>::Gradient(
    const arma::Mat<eT>& /* input */,
    const arma::Mat<eT>& error,
    arma::Mat<eT>& gradient)
{
  gradient.set_size(size + size, 1);

  // Step 5: dl / dy * xhat.
  gradient.submat(0, 0, gamma.n_elem - 1, 0) = arma::sum(normalized % error, 1);

  // Step 6: dl / dy.
  gradient.submat(gamma.n_elem, 0, gradient.n_elem - 1, 0) =
      arma::sum(error, 1);
}
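
// serialize() stores the layer size and parameters; when loading, the
// weights buffer is re-allocated and `loading` is set so that the next
// Reset() call keeps the deserialized gamma and beta instead of
// re-initializing them.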
template<typename InputDataType, typename OutputDataType>
template<typename Archive>
void LayerNorm<InputDataType, OutputDataType>::serialize(
    Archive& ar, const uint32_t /* version */)
{
  ar(CEREAL_NVP(size));

  if (cereal::is_loading<Archive>())
  {
    weights.set_size(size + size, 1);
    loading = true;
  }

  ar(CEREAL_NVP(eps));
  ar(CEREAL_NVP(gamma));
  ar(CEREAL_NVP(beta));
}
} // namespace ann
} // namespace mlpack

#endif
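
A minimal usage sketch (not part of the header): it drives the layer standalone with Reset() and Forward(), the way mlpack's layer tests exercise layers. The default `LayerNorm<>` template arguments, the includes, `main()`, and the batch shape are illustrative assumptions.

#include <mlpack/core.hpp>
#include <mlpack/methods/ann/layer/layer_norm.hpp>

int main()
{
  using namespace mlpack::ann;

  // One data point per column: 10 features, 4 points.
  arma::mat input = arma::randu<arma::mat>(10, 4);
  arma::mat output;

  LayerNorm<> layer(10);
  layer.Reset();                 // Alias gamma/beta into weights; gamma = 1, beta = 0.
  layer.Forward(input, output);

  // Each column of `output` now has (approximately) zero mean and unit
  // variance, prior to the gamma/beta affine transform.
  output.print("layer-normalized batch");

  return 0;
}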