// NOTE(review): this text is a lossy extraction of positional_encoding_impl.hpp.
// Original line numbers are fused into the code, and many source lines
// (includes, namespace declarations, function declarators, braces) are absent.
13 #ifndef MLPACK_METHODS_ANN_LAYER_POSITIONAL_ENCODING_IMPL_HPP 14 #define MLPACK_METHODS_ANN_LAYER_POSITIONAL_ENCODING_IMPL_HPP 22 template<
typename InputDataType,
typename OutputDataType>
// NOTE(review): the definition that followed this template header (original
// lines 23-28, the default constructor) is missing from this view.
30 template<
typename InputDataType,
typename OutputDataType>
// Constructor taking the embedding size and the maximum sequence length.
// It initializes maxSequenceLength and then builds the encoding table by
// calling InitPositionalEncoding().
// NOTE(review): the declarator line (original line 31) and the embedDim
// initializer (original line 34) are missing from this view; embedDim is
// presumably initialized from the parameter of the same name - confirm.
32 const size_t embedDim,
33 const size_t maxSequenceLength) :
35 maxSequenceLength(maxSequenceLength)
37 InitPositionalEncoding();
40 template<
typename InputDataType,
typename OutputDataType>
// Builds the sinusoidal positional-encoding table and flattens it into a
// single column vector stored in positionalEncoding.
// NOTE(review): the declarator line (original line 41) and the braces of this
// definition are missing from this view.
//
// Table layout before flattening: maxSequenceLength rows (one per position)
// by embedDim columns.
43 positionalEncoding.set_size(maxSequenceLength, embedDim);
// Positions 0, 1, ..., maxSequenceLength - 1.
// NOTE(review): if maxSequenceLength == 0 and the member is unsigned (the
// constructor parameter is size_t - confirm the member type), the "- 1"
// wraps around to a huge value. The same applies to embedDim below.
44 const InputDataType position = arma::regspace(0, 1, maxSequenceLength - 1);
// For k = 0, 2, 4, ... <= embedDim - 1 this computes
// exp(k * -ln(10000) / embedDim) == 10000^(-k / embedDim).
45 const InputDataType divTerm = arma::exp(arma::regspace(0, 2, embedDim - 1)
46 * (- std::log(10000.0) / embedDim));
// Outer product: theta(p, i) = p * divTerm(i).
47 const InputDataType theta = position * divTerm.t();
48 for (
// NOTE(review): theta.n_cols == ceil(embedDim / 2). When embedDim is odd, the
// last iteration writes col(2 * i + 1) == col(embedDim), one past the last
// column allocated by set_size above. Armadillo reports this as an error
// when bounds checks are on; with ARMA_NO_DEBUG it is undefined behaviour.
// A guard such as `if (2 * i + 1 < embedDim)` around the cosine write would
// fix it.
size_t i = 0; i < theta.n_cols; ++i)
// Even columns receive sin(theta), odd columns receive cos(theta).
50 positionalEncoding.col(2 * i) = arma::sin(theta.col(i));
51 positionalEncoding.col(2 * i + 1) = arma::cos(theta.col(i));
// Transpose, then vectorise column by column. The result is a column vector of
// length embedDim * maxSequenceLength in which the embedDim values for
// position p occupy rows p * embedDim .. p * embedDim + embedDim - 1.
53 positionalEncoding = arma::vectorise(positionalEncoding.t());
56 template<
typename InputDataType,
typename OutputDataType>
// Forward pass: adds the precomputed positional-encoding column to every
// column of `input` and writes the sum to `output`.
// NOTE(review): the declarator line and whatever declares the element type
// `eT` are missing from this view.
59 const arma::Mat<eT>& input, arma::Mat<eT>& output)
// Every column must hold embedDim * maxSequenceLength values (presumably one
// flattened sequence per column - confirm). Any other row count is reported
// through Log::Fatal, which is documented as printing a [FATAL] message and
// terminating.
61 if (input.n_rows != embedDim * maxSequenceLength)
62 Log::Fatal <<
"Incorrect input dimensions!" << std::endl;
// Broadcast the single encoding column across all columns of the input.
64 output = input.each_col() + positionalEncoding;
67 template<
typename InputDataType,
typename OutputDataType>
// Backward pass: maps the incoming gradient `gy` to the gradient `g` with
// respect to this layer's input. The input argument itself is unused
// (the first parameter is unnamed).
// NOTE(review): the declarator and the entire body of this function are
// missing from this view. Forward only adds a constant, so the Jacobian is
// the identity and the body is presumably `g = gy;` - confirm against the
// full source.
70 const arma::Mat<eT>& ,
const arma::Mat<eT>& gy, arma::Mat<eT>& g)
75 template<
typename InputDataType,
typename OutputDataType>
76 template<
typename Archive>
// Serializes the layer. Only the two hyper-parameters are archived; the
// encoding table is not stored but regenerated from them on load (presumably
// to keep archives small). The version argument is unused.
// NOTE(review): the declarator line (original line 77) is missing from this
// view.
78 Archive& ar,
const uint32_t )
80 ar(CEREAL_NVP(embedDim));
81 ar(CEREAL_NVP(maxSequenceLength));
// On load, rebuild positionalEncoding from the restored sizes.
83 if (cereal::is_loading<Archive>())
84 InitPositionalEncoding();
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed-forward pass of a neural network: evaluates the function f(x) by propagating the input activations forward through the layer.
Definition: positional_encoding_impl.hpp:58
Positional Encoding injects information about the relative or absolute position of the tokens in the input sequence.
Definition: positional_encoding.hpp:37
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: positional_encoding_impl.hpp:77
PositionalEncoding()
Create a PositionalEncoding object.
Definition: positional_encoding_impl.hpp:23
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary backward pass of a neural network: propagates the incoming gradient gy backwards through the layer to compute the gradient g with respect to its input.
Definition: positional_encoding_impl.hpp:69