// Include guard for this implementation header, immediately followed by the
// class-level template header (InputDataType / OutputDataType) of the first
// definition in the file.
// NOTE(review): the signature and body introduced by this template header are
// not present in this listing. The index at the end of the listing places
// LayerNorm() at layer_norm_impl.hpp:24, so this is presumably the default
// constructor -- confirm against the real file.
13 #ifndef MLPACK_METHODS_ANN_LAYER_LAYERNORM_IMPL_HPP 14 #define MLPACK_METHODS_ANN_LAYER_LAYERNORM_IMPL_HPP 23 template<
typename InputDataType,
typename OutputDataType>
// Constructor fragment taking `size` and `eps`.
// NOTE(review): the line carrying the function name and the member-initialiser
// list are missing from this listing; presumably `size` and `eps` are stored
// in members of the same name (both are read by other fragments below) --
// confirm against the real file.
32 template <
typename InputDataType,
typename OutputDataType>
34 const size_t size,
const double eps) :
// Allocate the parameter storage as a single (2 * size) x 1 column. The
// Reset() fragment further down builds `gamma` and `beta` on top of this
// buffer, each taking `size` elements.
39 weights.set_size(size + size, 1);
// Reset(): rebuild `gamma` and `beta` as matrices that use the `weights`
// buffer as auxiliary memory (the listing's index places Reset() at line 43;
// the signature line itself is missing here).
42 template<
typename InputDataType,
typename OutputDataType>
// gamma: size x 1 matrix constructed directly over the first `size` elements
// of `weights`. The first `false` is Armadillo's copy_aux_mem flag, so the
// memory is used in place rather than copied; the second `false` is the
// `strict` flag.
45 gamma = arma::mat(weights.memptr(), size, 1,
false,
false);
// beta: size x 1 matrix over the next `size` elements, starting right after
// gamma's elements (offset gamma.n_elem into the same buffer).
46 beta = arma::mat(weights.memptr() + gamma.n_elem, size, 1,
false,
false);
// Forward(input, output): normalise `input` with per-column statistics, then
// apply the learned scale (`gamma`) and shift (`beta`).
// NOTE(review): the `template<typename eT>` line, the function-name line and
// listing lines 69-73 are missing from this excerpt. `normalized`, which the
// gradient fragment below reads, is presumably stored in that gap -- confirm
// against the real file.
57 template<
typename InputDataType,
typename OutputDataType>
60 const arma::Mat<eT>& input, arma::Mat<eT>& output)
// dim = 0: one mean per column of `input` (result is a row vector).
62 mean = arma::mean(input, 0);
// norm_type = 1 normalises by N (not N - 1); dim = 0 gives one variance per
// column.
63 variance = arma::var(input, 1, 0);
// Centre: subtract the row vector of column means from every row.
66 output = input.each_row() - mean;
// Scale every row element-wise by 1 / sqrt(variance + eps); `eps` keeps the
// denominator away from zero.
68 output.each_row() /= arma::sqrt(variance + eps);
// Element-wise scale of every column by `gamma`, then add `beta` to every
// column.
74 output.each_col() %= gamma;
75 output.each_col() += beta;
// Backward(input, gy, g): compute `g` from the incoming gradient `gy`, using
// the `variance` stored by the forward pass.
// NOTE(review): the function-name line and the line defining `inputMean` are
// missing from this listing; `inputMean` is presumably `input` with the stored
// mean subtracted from each row -- confirm against the real file.
// NOTE(review): the temporaries below are `arma::mat` (double) even though the
// arguments are `arma::Mat<eT>` -- check whether non-double `eT` is intended
// to be supported.
78 template<
typename InputDataType,
typename OutputDataType>
81 const arma::Mat<eT>& input,
const arma::Mat<eT>& gy, arma::Mat<eT>& g)
// Element-wise 1 / sqrt(variance + eps), one entry per column.
83 const arma::mat stdInv = 1.0 / arma::sqrt(variance + eps);
// Incoming gradient with every column scaled element-wise by `gamma`.
86 const arma::mat norm = gy.each_col() % gamma;
// Per-column term: sum over rows of norm % inputMean, multiplied element-wise
// by stdInv^3 and then by -0.5.
89 const arma::mat var = arma::sum(norm % inputMean, 0) %
90 arma::pow(stdInv, 3.0) * -0.5;
// First two contributions to g: norm scaled row-wise by stdInv, plus
// inputMean scaled row-wise by `var`, times 2 / input.n_rows.
94 g = (norm.each_row() % stdInv) + (inputMean.each_row() %
95 var * 2 / input.n_rows);
// Third contribution, added to every row: the per-column sum of
// norm % (-stdInv), divided by input.n_rows.
99 g.each_row() += arma::sum(norm.each_row() % -stdInv, 0) / input.n_rows;
// Parameter-gradient fragment: fills `gradient` from `error`; the first
// argument is unnamed and unused.
// NOTE(review): the function-name line is missing from this listing
// (presumably Gradient()), and the right-hand side of the final assignment is
// cut off -- confirm both against the real file.
102 template<
typename InputDataType,
typename OutputDataType>
103 template<
typename eT>
105 const arma::Mat<eT>& ,
106 const arma::Mat<eT>& error,
107 arma::Mat<eT>& gradient)
// Same (2 * size) x 1 layout as `weights`.
109 gradient.set_size(size + size, 1);
// Rows [0, gamma.n_elem): sum along each row (dim = 1) of normalized % error.
112 gradient.submat(0, 0, gamma.n_elem - 1, 0) = arma::sum(normalized % error, 1);
// Rows [gamma.n_elem, gradient.n_elem): the remaining half of the vector.
// The assigned expression is truncated in this listing.
115 gradient.submat(gamma.n_elem, 0, gradient.n_elem - 1, 0) =
// serialize(ar, version): cereal serialization of the layer; the version
// argument is unnamed and unused.
// NOTE(review): the function-name line, the braces and listing lines 129-132
// are missing from this excerpt, so anything archived or reset in that gap is
// not visible here -- confirm against the real file.
119 template<
typename InputDataType,
typename OutputDataType>
120 template<
typename Archive>
122 Archive& ar,
const uint32_t )
// `size` is archived first, so it is already known when loading.
124 ar(CEREAL_NVP(size));
// When loading, re-allocate the (2 * size) x 1 parameter buffer.
126 if (cereal::is_loading<Archive>())
128 weights.set_size(size + size, 1);
// Archive the scale and shift parameters as separate named values.
133 ar(CEREAL_NVP(gamma));
134 ar(CEREAL_NVP(beta));
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
OutputDataType const & Gradient() const
Get the gradient.
Definition: layer_norm.hpp:135
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Forward pass of Layer Normalization.
Definition: layer_norm_impl.hpp:59
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Backward pass through the layer.
Definition: layer_norm_impl.hpp:80
void Reset()
Reset the layer parameters.
Definition: layer_norm_impl.hpp:43
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: layer_norm_impl.hpp:121
LayerNorm()
Create the LayerNorm object.
Definition: layer_norm_impl.hpp:24