mlpack
positional_encoding_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_POSITIONAL_ENCODING_IMPL_HPP
14 #define MLPACK_METHODS_ANN_LAYER_POSITIONAL_ENCODING_IMPL_HPP
15 
16 // In case it hasn't yet been included.
17 #include "positional_encoding.hpp"
18 
19 namespace mlpack {
20 namespace ann {
21 
// Default constructor. The signature line (listing line 23) is missing from
// this extract. Both stored dimensions are set to zero and no positional
// encoding table is built here.
22 template<typename InputDataType, typename OutputDataType>
24  embedDim(0),
25  maxSequenceLength(0)
26 {
27  // Nothing to do here.
28 }
29 
// Constructor taking the embedding dimension and the maximum sequence length.
// The signature line (listing line 31) is missing from this extract. Both
// sizes are stored in the members of the same name, after which the encoding
// table is computed once, up front.
30 template<typename InputDataType, typename OutputDataType>
32  const size_t embedDim,
33  const size_t maxSequenceLength) :
34  embedDim(embedDim),
35  maxSequenceLength(maxSequenceLength)
36 {
   // Fill positionalEncoding from the two sizes stored above.
37  InitPositionalEncoding();
38 }
39 
// Builds the sinusoidal positional encoding table and stores it, flattened,
// in positionalEncoding. The signature line (listing line 41) is missing from
// this extract; presumably this is InitPositionalEncoding(), which the
// parameterised constructor and serialize() call.
//
// NOTE(review): embedDim and maxSequenceLength look like size_t (the
// constructor parameters are), so a value of 0 would make the `- 1`
// expressions below wrap around -- confirm callers never pass 0.
// NOTE(review): when embedDim is odd, the last loop iteration appears to
// write column 2 * i + 1 == embedDim, one past the last column -- confirm
// embedDim is required to be even.
40 template<typename InputDataType, typename OutputDataType>
42 {
   // One row per position, one column per embedding dimension.
43  positionalEncoding.set_size(maxSequenceLength, embedDim);
   // Position indices 0, 1, ..., maxSequenceLength - 1.
44  const InputDataType position = arma::regspace(0, 1, maxSequenceLength - 1);
   // exp(k * -ln(10000) / embedDim) for k = 0, 2, 4, ..., i.e. the factor
   // 10000^(-k / embedDim) for every even k below embedDim.
45  const InputDataType divTerm = arma::exp(arma::regspace(0, 2, embedDim - 1)
46  * (- std::log(10000.0) / embedDim));
   // Product of positions and factors: theta(p, i) = position(p) * divTerm(i).
47  const InputDataType theta = position * divTerm.t();
48  for (size_t i = 0; i < theta.n_cols; ++i)
49  {
   // Even columns receive the sine, odd columns the cosine of the same angle.
50  positionalEncoding.col(2 * i) = arma::sin(theta.col(i));
51  positionalEncoding.col(2 * i + 1) = arma::cos(theta.col(i));
52  }
   // Transpose and flatten the table into a single vector so that Forward()
   // can add it to every input column.
53  positionalEncoding = arma::vectorise(positionalEncoding.t());
54 }
55 
// Forward pass. The signature line (listing line 58) is missing from this
// extract. Adds the precomputed positional encoding to every column of the
// input.
//
// input:  matrix that must have embedDim * maxSequenceLength rows.
// output: receives the input with positionalEncoding added to each column.
//
// A row-count mismatch is reported through Log::Fatal.
56 template<typename InputDataType, typename OutputDataType>
57 template<typename eT>
59  const arma::Mat<eT>& input, arma::Mat<eT>& output)
60 {
   // Each input column has to match the length of the flattened table.
61  if (input.n_rows != embedDim * maxSequenceLength)
62  Log::Fatal << "Incorrect input dimensions!" << std::endl;
63 
   // Add the same encoding vector to every column of the input.
64  output = input.each_col() + positionalEncoding;
65 }
66 
// Backward pass. The signature line (listing line 69) is missing from this
// extract. The incoming gradient gy is copied to g unchanged; the forward
// input is not used.
67 template<typename InputDataType, typename OutputDataType>
68 template<typename eT>
70  const arma::Mat<eT>& /* input */, const arma::Mat<eT>& gy, arma::Mat<eT>& g)
71 {
72  g = gy;
73 }
74 
// Serializes the layer through a cereal archive. The signature line (listing
// line 77) is missing from this extract. Only embedDim and maxSequenceLength
// pass through the archive; the encoding table itself is not archived here.
75 template<typename InputDataType, typename OutputDataType>
76 template<typename Archive>
78  Archive& ar, const uint32_t /* version */)
79 {
80  ar(CEREAL_NVP(embedDim));
81  ar(CEREAL_NVP(maxSequenceLength));
82 
   // When loading, rebuild the table from the two sizes just read.
83  if (cereal::is_loading<Archive>())
84  InitPositionalEncoding();
85 }
86 
87 } // namespace ann
88 } // namespace mlpack
89 
90 #endif
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: positional_encoding_impl.hpp:58
Positional Encoding injects some information about the relative or absolute position of the tokens in the sequence.
Definition: positional_encoding.hpp:37
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: positional_encoding_impl.hpp:77
PositionalEncoding()
Create PositionalEncoding object.
Definition: positional_encoding_impl.hpp:23
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Definition: positional_encoding_impl.hpp:69