13 #ifndef MLPACK_METHODS_ANN_LAYER_MULTIHEAD_ATTENTION_IMPL_HPP 14 #define MLPACK_METHODS_ANN_LAYER_MULTIHEAD_ATTENTION_IMPL_HPP 24 template <
typename InputDataType,
typename OutputDataType,
25 typename RegularizerType>
37 template <
typename InputDataType,
typename OutputDataType,
38 typename RegularizerType>
41 const size_t tgtSeqLen,
42 const size_t srcSeqLen,
43 const size_t embedDim,
44 const size_t numHeads) :
50 if (embedDim % numHeads != 0)
52 Log::Fatal <<
"Embedding dimension must be divisible by number of \ 53 attention heads." << std::endl;
56 headDim = embedDim / numHeads;
/**
 * Reset the layer parameters: rebuild the alias matrices over the flat
 * `weights` buffer.  No memory is copied (copy_aux_mem = false,
 * strict = false) — queryWt/keyWt/valueWt/outWt and the four biases are
 * views into `weights`.
 *
 * Layout of `weights` (row offsets):
 *   [0, E^2)            queryWt   (E x E)
 *   [E^2, 2E^2)         keyWt     (E x E)
 *   [2E^2, 3E^2)        valueWt   (E x E)
 *   [3E^2, 4E^2)        outWt     (E x E)
 *   [4E^2, 4E^2 + E)    qBias     (E x 1)
 *   ... + E             kBias     (E x 1)
 *   ... + E             vBias     (E x 1)
 *   ... + E             outBias   (1 x E)
 * where E = embedDim.
 */
template <typename InputDataType,
          typename OutputDataType,
          typename RegularizerType>
void MultiheadAttention<InputDataType, OutputDataType, RegularizerType>::
Reset()
{
  typedef typename arma::Mat<typename OutputDataType::elem_type> MatType;

  queryWt = MatType(weights.memptr(), embedDim, embedDim, false, false);
  keyWt = MatType(weights.memptr() + embedDim * embedDim,
      embedDim, embedDim, false, false);
  valueWt = MatType(weights.memptr() + 2 * embedDim * embedDim,
      embedDim, embedDim, false, false);
  outWt = MatType(weights.memptr() + 3 * embedDim * embedDim,
      embedDim, embedDim, false, false);

  qBias = MatType(weights.memptr() + 4 * embedDim * embedDim,
      embedDim, 1, false, false);
  kBias = MatType(weights.memptr() + (4 * embedDim + 1) * embedDim,
      embedDim, 1, false, false);
  vBias = MatType(weights.memptr() + (4 * embedDim + 2) * embedDim,
      embedDim, 1, false, false);
  // Note: outBias is a row vector (1 x embedDim), unlike the other biases.
  outBias = MatType(weights.memptr() + (4 * embedDim + 3) * embedDim,
      1, embedDim, false, false);
}
85 template <
typename InputDataType,
typename OutputDataType,
86 typename RegularizerType>
87 template <
typename eT>
89 Forward(
const arma::Mat<eT>& input, arma::Mat<eT>& output)
91 typedef typename arma::Cube<eT> CubeType;
93 if (input.n_rows != embedDim * (tgtSeqLen + 2 * srcSeqLen))
95 Log::Fatal <<
"Incorrect input dimensions!" << std::endl;
98 const size_t batchSize = input.n_cols;
101 output.set_size(embedDim * tgtSeqLen, batchSize);
107 const CubeType q(
const_cast<arma::Mat<eT>&
>(input).memptr(),
108 embedDim, tgtSeqLen, batchSize,
false,
false);
109 const CubeType k(
const_cast<arma::Mat<eT>&
>(input).memptr() +
110 embedDim * tgtSeqLen * batchSize,
111 embedDim, srcSeqLen, batchSize,
false,
false);
112 const CubeType v(
const_cast<arma::Mat<eT>&
>(input).memptr() +
113 embedDim * (tgtSeqLen + srcSeqLen) * batchSize,
114 embedDim, srcSeqLen, batchSize,
false,
false);
118 qProj.set_size(tgtSeqLen, embedDim, batchSize);
119 kProj.set_size(srcSeqLen, embedDim, batchSize);
120 vProj.set_size(srcSeqLen, embedDim, batchSize);
122 for (
size_t i = 0; i < batchSize; ++i)
124 qProj.slice(i) = arma::trans(
125 queryWt * q.slice(i) + arma::repmat(qBias, 1, tgtSeqLen));
126 kProj.slice(i) = arma::trans(
127 keyWt * k.slice(i) + arma::repmat(kBias, 1, srcSeqLen));
128 vProj.slice(i) = arma::trans(
129 valueWt * v.slice(i) + arma::repmat(vBias, 1, srcSeqLen));
134 qProj /= std::sqrt(headDim);
138 qProj.reshape(tgtSeqLen, headDim, numHeads * batchSize);
139 kProj.reshape(srcSeqLen, headDim, numHeads * batchSize);
140 vProj.reshape(srcSeqLen, headDim, numHeads * batchSize);
150 if (!attnMask.is_empty())
152 if (attnMask.n_rows != tgtSeqLen || attnMask.n_cols != srcSeqLen)
153 Log::Fatal <<
"The size of the 'attn_mask' is not correct.\n";
154 scores.each_slice() += attnMask;
161 if (!keyPaddingMask.is_empty())
163 if (keyPaddingMask.n_rows != 1 || keyPaddingMask.n_cols != srcSeqLen)
164 Log::Fatal <<
"The size of the 'keyPaddingMask' is not correct.\n";
165 scores.each_slice() += arma::repmat(keyPaddingMask, tgtSeqLen, 1);
168 for (
size_t i = 0; i < numHeads * batchSize; ++i)
170 softmax.Forward(scores.slice(i), softmax.OutputParameter());
171 scores.slice(i) = softmax.OutputParameter();
181 attnOut.reshape(tgtSeqLen, embedDim, batchSize);
184 for (
size_t i = 0; i < batchSize; ++i)
186 output.col(i) = arma::vectorise(arma::trans(attnOut.slice(i) * outWt
187 + arma::repmat(outBias, tgtSeqLen, 1)));
191 template <
typename InputDataType,
typename OutputDataType,
192 typename RegularizerType>
193 template <
typename eT>
196 const arma::Mat<eT>& gy,
199 typedef typename arma::Cube<eT> CubeType;
201 if (gy.n_rows != tgtSeqLen * embedDim)
203 Log::Fatal <<
"Backpropagated error has incorrect dimensions!" << std::endl;
206 const size_t batchSize = gy.n_cols;
207 g.set_size(embedDim * (tgtSeqLen + 2 * srcSeqLen), batchSize);
213 CubeType gyTemp(
const_cast<arma::Mat<eT>&
>(gy).memptr(), embedDim,
214 tgtSeqLen, batchSize,
true,
false);
224 gyTemp.reshape(tgtSeqLen, headDim, numHeads * batchSize);
233 tmp.reshape(srcSeqLen, embedDim, batchSize);
235 for (
size_t i = 0; i < batchSize; ++i)
237 g.submat((tgtSeqLen + srcSeqLen) * embedDim, i, g.n_rows - 1, i)
238 = arma::vectorise(arma::trans(tmp.slice(i) * valueWt));
246 for (
size_t i = 0; i < numHeads * batchSize; ++i)
249 softmax.Backward(scores.slice(i), gyTemp.slice(i), gyTemp.slice(i));
259 tmp.reshape(srcSeqLen, embedDim, batchSize);
261 for (
size_t i = 0; i < batchSize; ++i)
263 g.submat(tgtSeqLen * embedDim, i, (tgtSeqLen + srcSeqLen) * embedDim - 1, i)
264 = arma::vectorise(arma::trans(tmp.slice(i) * keyWt));
274 tmp.reshape(tgtSeqLen, embedDim, batchSize);
276 for (
size_t i = 0; i < batchSize; ++i)
278 g.submat(0, i, tgtSeqLen * embedDim - 1, i)
279 = arma::vectorise(arma::trans(tmp.slice(i) * queryWt));
283 template <
typename InputDataType,
typename OutputDataType,
284 typename RegularizerType>
285 template <
typename eT>
288 const arma::Mat<eT>& error,
289 arma::Mat<eT>& gradient)
291 typedef typename arma::Cube<eT> CubeType;
292 typedef typename arma::Mat<eT> MatType;
294 if (input.n_rows != embedDim * (tgtSeqLen + 2 * srcSeqLen))
296 Log::Fatal <<
"Incorrect input dimensions!" << std::endl;
299 if (error.n_rows != tgtSeqLen * embedDim)
301 Log::Fatal <<
"Backpropagated error has incorrect dimensions." << std::endl;
304 const size_t batchSize = input.n_cols;
305 const size_t wtSize = embedDim * embedDim;
308 gradient.set_size(arma::size(weights));
310 const CubeType q(const_cast<MatType&>(input).memptr(),
311 embedDim, tgtSeqLen, batchSize,
false,
false);
312 const CubeType k(const_cast<MatType&>(input).memptr() + q.n_elem,
313 embedDim, srcSeqLen, batchSize,
false,
false);
314 const CubeType v(const_cast<MatType&>(input).memptr() + q.n_elem + k.n_elem,
315 embedDim, srcSeqLen, batchSize,
false,
false);
319 CubeType errorTemp(
const_cast<arma::Mat<eT>&
>(error).memptr(), embedDim,
320 tgtSeqLen, batchSize,
true,
false);
323 gradient.rows(4 * wtSize + 3 * embedDim, 4 * wtSize + 4 * embedDim - 1)
324 = arma::vectorise(arma::sum(arma::sum(errorTemp, 2), 1));
333 gradient.rows(3 * wtSize, 4 * wtSize - 1)
334 = arma::vectorise(arma::sum(gyTemp, 2));
344 gyTemp.reshape(tgtSeqLen, headDim, numHeads * batchSize);
353 errorTemp.reshape(srcSeqLen, embedDim, batchSize);
357 gradient.rows(4 * wtSize + 2 * embedDim, 4 * wtSize + 3 * embedDim - 1)
358 = arma::vectorise(arma::sum(arma::sum(errorTemp, 2), 0));
367 gradient.rows(2 * wtSize, 3 * wtSize - 1)
368 = arma::vectorise(arma::sum(errorTemp, 2));
375 for (
size_t i = 0; i < numHeads * batchSize; ++i)
380 softmax.Backward(scores.slice(i), errorTemp.slice(i), errorTemp.slice(i));
390 gyTemp.reshape(srcSeqLen, embedDim, batchSize);
394 gradient.rows(4 * wtSize + embedDim, 4 * wtSize + 2 * embedDim - 1)
395 = arma::vectorise(arma::sum(arma::sum(gyTemp, 2), 0));
404 gradient.rows(wtSize, 2 * wtSize - 1) = arma::vectorise(arma::sum(gyTemp, 2));
412 gyTemp.reshape(tgtSeqLen, embedDim, batchSize);
413 gyTemp /= std::sqrt(headDim);
417 gradient.rows(4 * wtSize, 4 * wtSize + embedDim - 1)
418 = arma::vectorise(arma::sum(arma::sum(gyTemp, 2), 0));
427 gradient.rows(0, wtSize - 1) = arma::vectorise(arma::sum(gyTemp, 2));
430 regularizer.Evaluate(weights, gradient);
433 template <
typename InputDataType,
typename OutputDataType,
434 typename RegularizerType>
435 template <
typename Archive>
439 ar(CEREAL_NVP(tgtSeqLen));
440 ar(CEREAL_NVP(srcSeqLen));
441 ar(CEREAL_NVP(embedDim));
442 ar(CEREAL_NVP(numHeads));
443 ar(CEREAL_NVP(headDim));
447 if (cereal::is_loading<Archive>())
448 weights.set_size(4 * embedDim * (embedDim + 1), 1);
OutputDataType const & Gradient() const
Get the gradient.
Definition: multihead_attention.hpp:173
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: multihead_attention_impl.hpp:437
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
Definition: multihead_attention_impl.hpp:195
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: multihead_attention_impl.hpp:89
size_t WeightSize() const
Get the size of the weights.
Definition: multihead_attention.hpp:124
CubeType MultiplyCube2Cube(const CubeType &cubeA, const CubeType &cubeB, const bool aTranspose=false, const bool bTranspose=false)
Matrix multiplication of slices of two cubes.
Definition: multiply_slices_impl.hpp:22
MultiheadAttention()
Default constructor.
Definition: multihead_attention_impl.hpp:27
void Reset()
Reset the layer parameters.
Definition: multihead_attention_impl.hpp:63
CubeType MultiplyCube2Mat(const CubeType &cubeA, const MatType &matB, const bool aTranspose=false, const bool bTranspose=false)
Matrix multiplication of all slices of a cube with a matrix.
Definition: multiply_slices_impl.hpp:141