/**
 * @file methods/ann/layer/multihead_attention_impl.hpp
 *
 * Implementation of the MultiheadAttention layer.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license.  You should have received a copy of the
 * 3-clause BSD license along with mlpack.  If not, see
 * http://www.mlpack.org/license.txt for more information.
 */
#ifndef MLPACK_METHODS_ANN_LAYER_MULTIHEAD_ATTENTION_IMPL_HPP
#define MLPACK_METHODS_ANN_LAYER_MULTIHEAD_ATTENTION_IMPL_HPP

// In case it hasn't yet been included.
#include "multihead_attention.hpp"

#include <mlpack/core/math/multiply_slices.hpp>

namespace mlpack {
namespace ann {

template <typename InputDataType, typename OutputDataType,
          typename RegularizerType>
MultiheadAttention<InputDataType, OutputDataType, RegularizerType>::
MultiheadAttention() :
    tgtSeqLen(0),
    srcSeqLen(0),
    embedDim(0),
    numHeads(0),
    headDim(0)
{
  // Nothing to do here.
}

template <typename InputDataType, typename OutputDataType,
          typename RegularizerType>
MultiheadAttention<InputDataType, OutputDataType, RegularizerType>::
MultiheadAttention(
    const size_t tgtSeqLen,
    const size_t srcSeqLen,
    const size_t embedDim,
    const size_t numHeads) :
    tgtSeqLen(tgtSeqLen),
    srcSeqLen(srcSeqLen),
    embedDim(embedDim),
    numHeads(numHeads)
{
  if (embedDim % numHeads != 0)
  {
    Log::Fatal << "Embedding dimension must be divisible by number of "
        << "attention heads." << std::endl;
  }

  headDim = embedDim / numHeads;
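
  // The layer stores four embedDim x embedDim projection matrices and four
  // bias vectors of length embedDim, so WeightSize() amounts to
  // 4 * embedDim * (embedDim + 1) elements in total.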
  weights.set_size(WeightSize(), 1);
}
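
// A rough usage sketch (the numbers are illustrative): a layer constructed
// with embedDim = 512 and numHeads = 8 splits each 512-dimensional embedding
// into 8 heads of 64 dimensions. Forward() then expects an input of shape
// (512 * (tgtSeqLen + 2 * srcSeqLen), batchSize), i.e. the query, key, and
// value stacked in each column, and produces an output of shape
// (512 * tgtSeqLen, batchSize).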

template <typename InputDataType, typename OutputDataType,
          typename RegularizerType>
void MultiheadAttention<InputDataType, OutputDataType, RegularizerType>::
Reset()
{
  typedef typename arma::Mat<typename OutputDataType::elem_type> MatType;

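  // The flat `weights` vector is laid out as [queryWt | keyWt | valueWt |
  // outWt | qBias | kBias | vBias | outBias]; the matrices below are aliases
  // into that memory (no copy is made), so updating `weights` updates them.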
  queryWt = MatType(weights.memptr(), embedDim, embedDim, false, false);
  keyWt = MatType(weights.memptr() + embedDim * embedDim,
      embedDim, embedDim, false, false);
  valueWt = MatType(weights.memptr() + 2 * embedDim * embedDim,
      embedDim, embedDim, false, false);
  outWt = MatType(weights.memptr() + 3 * embedDim * embedDim,
      embedDim, embedDim, false, false);

  qBias = MatType(weights.memptr()
      + 4 * embedDim * embedDim, embedDim, 1, false, false);
  kBias = MatType(weights.memptr()
      + (4 * embedDim + 1) * embedDim, embedDim, 1, false, false);
  vBias = MatType(weights.memptr()
      + (4 * embedDim + 2) * embedDim, embedDim, 1, false, false);
  outBias = MatType(weights.memptr()
      + (4 * embedDim + 3) * embedDim, 1, embedDim, false, false);
}

template <typename InputDataType, typename OutputDataType,
          typename RegularizerType>
template <typename eT>
void MultiheadAttention<InputDataType, OutputDataType, RegularizerType>::
Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output)
{
  typedef typename arma::Cube<eT> CubeType;

  if (input.n_rows != embedDim * (tgtSeqLen + 2 * srcSeqLen))
  {
    Log::Fatal << "Incorrect input dimensions!" << std::endl;
  }

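  // Each column of `input` packs the query, key, and value sequences
  // end-to-end: the first embedDim * tgtSeqLen rows hold the query, the next
  // embedDim * srcSeqLen rows hold the key, and the last embedDim * srcSeqLen
  // rows hold the value.
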
  const size_t batchSize = input.n_cols;

  // The shape of output : (embedDim * tgtSeqLen, batchSize).
  output.set_size(embedDim * tgtSeqLen, batchSize);

  // Reshape the input matrix into separate query, key, and value cubes.
  // The shape of q : (embedDim, tgtSeqLen, batchSize).
  // The shape of k : (embedDim, srcSeqLen, batchSize).
  // The shape of v : (embedDim, srcSeqLen, batchSize).
  const CubeType q(const_cast<arma::Mat<eT>&>(input).memptr(),
      embedDim, tgtSeqLen, batchSize, false, false);
  const CubeType k(const_cast<arma::Mat<eT>&>(input).memptr() +
      embedDim * tgtSeqLen * batchSize,
      embedDim, srcSeqLen, batchSize, false, false);
  const CubeType v(const_cast<arma::Mat<eT>&>(input).memptr() +
      embedDim * (tgtSeqLen + srcSeqLen) * batchSize,
      embedDim, srcSeqLen, batchSize, false, false);

  // qProj, kProj, and vProj are the linearly projected query, key, and value,
  // respectively.
  qProj.set_size(tgtSeqLen, embedDim, batchSize);
  kProj.set_size(srcSeqLen, embedDim, batchSize);
  vProj.set_size(srcSeqLen, embedDim, batchSize);

  for (size_t i = 0; i < batchSize; ++i)
  {
    qProj.slice(i) = arma::trans(
        queryWt * q.slice(i) + arma::repmat(qBias, 1, tgtSeqLen));
    kProj.slice(i) = arma::trans(
        keyWt * k.slice(i) + arma::repmat(kBias, 1, srcSeqLen));
    vProj.slice(i) = arma::trans(
        valueWt * v.slice(i) + arma::repmat(vBias, 1, srcSeqLen));
  }

  // The scaling factor sqrt(headDim) is used to prevent exploding values
  // after the dot product, i.e. when qProj is multiplied with kProj.
  qProj /= std::sqrt(headDim);

  // Split qProj, kProj, and vProj into numHeads heads; this is what makes the
  // attention multi-headed.
  qProj.reshape(tgtSeqLen, headDim, numHeads * batchSize);
  kProj.reshape(srcSeqLen, headDim, numHeads * batchSize);
  vProj.reshape(srcSeqLen, headDim, numHeads * batchSize);
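  // After the reshape, slices [i * numHeads, (i + 1) * numHeads) of each cube
  // hold the numHeads heads of batch element i, with each head covering
  // headDim consecutive columns of the original projection.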

  // Calculate the scores, i.e. perform the matrix multiplication operation
  // on qProj and kProj. Here scores = qProj * kProj'.
  scores = math::MultiplyCube2Cube(qProj, kProj, false, true);

  // Apply the attention mask if provided. The attention mask is used to black
  // out future positions, and is generally used in decoder self-attention.
  // The attention mask has elements 0 or -infinity.
  // The shape of the attention mask : (tgtSeqLen, srcSeqLen).
  if (!attnMask.is_empty())
  {
    if (attnMask.n_rows != tgtSeqLen || attnMask.n_cols != srcSeqLen)
      Log::Fatal << "The size of the 'attn_mask' is not correct.\n";
    scores.each_slice() += attnMask;
  }

  // Apply the key padding mask when provided. It masks out particular
  // positions (e.g. padding) in the source sequence.
  // The key padding mask has elements 0 or -infinity.
  // The shape of keyPaddingMask : (1, srcSeqLen).
  if (!keyPaddingMask.is_empty())
  {
    if (keyPaddingMask.n_rows != 1 || keyPaddingMask.n_cols != srcSeqLen)
      Log::Fatal << "The size of the 'keyPaddingMask' is not correct.\n";
    scores.each_slice() += arma::repmat(keyPaddingMask, tgtSeqLen, 1);
  }

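  // Convert the raw scores of each slice (one slice per head per batch
  // element) into attention weights with the softmax layer.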
  for (size_t i = 0; i < numHeads * batchSize; ++i)
  {
    softmax.Forward(scores.slice(i), softmax.OutputParameter());
    scores.slice(i) = softmax.OutputParameter();
  }

  // Calculate the attention output, i.e. the matrix multiplication of the
  // softmax output and vProj.
  // The shape of attnOut : (tgtSeqLen, headDim, numHeads * batchSize).
  attnOut = math::MultiplyCube2Cube(scores, vProj, false, false);

  // Now concatenate the output of all the heads, i.e. reshape attnOut to
  // (tgtSeqLen, embedDim, batchSize).
  attnOut.reshape(tgtSeqLen, embedDim, batchSize);

  // The final output is the linear projection of the attention output.
  for (size_t i = 0; i < batchSize; ++i)
  {
    output.col(i) = arma::vectorise(arma::trans(attnOut.slice(i) * outWt
        + arma::repmat(outBias, tgtSeqLen, 1)));
  }
}

template <typename InputDataType, typename OutputDataType,
          typename RegularizerType>
template <typename eT>
void MultiheadAttention<InputDataType, OutputDataType, RegularizerType>::
Backward(const arma::Mat<eT>& /* input */,
         const arma::Mat<eT>& gy,
         arma::Mat<eT>& g)
{
  typedef typename arma::Cube<eT> CubeType;

  if (gy.n_rows != tgtSeqLen * embedDim)
  {
    Log::Fatal << "Backpropagated error has incorrect dimensions!" << std::endl;
  }

  const size_t batchSize = gy.n_cols;
  g.set_size(embedDim * (tgtSeqLen + 2 * srcSeqLen), batchSize);
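
  // Each column of `g` mirrors the layout of `input`: the first
  // embedDim * tgtSeqLen rows receive the error of the query, the next
  // embedDim * srcSeqLen rows the error of the key, and the last
  // embedDim * srcSeqLen rows the error of the value.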

  // Reshape the propagated gradient into a cube.
  // The shape of gyTemp : (embedDim, tgtSeqLen, batchSize).
  // There is no need to split it into heads here, because this is the point
  // where the outputs of all the heads were concatenated.
  CubeType gyTemp(const_cast<arma::Mat<eT>&>(gy).memptr(), embedDim,
      tgtSeqLen, batchSize, true, false);

  // The shape of gyTemp : (embedDim, tgtSeqLen, batchSize).
  // The shape of outWt : (embedDim, embedDim).
  // The shape of the result : (tgtSeqLen, embedDim, batchSize).
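  // Slice-wise, this computes dL/d(attnOut) = dL/d(output) * outWt^T, undoing
  // the final linear projection applied in Forward().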
  gyTemp = math::MultiplyCube2Mat(gyTemp, outWt, true, true);

  // Since the shape of gyTemp is now (tgtSeqLen, embedDim, batchSize), we can
  // split it into n heads.
  // The shape of gyTemp : (tgtSeqLen, headDim, numHeads * batchSize).
  gyTemp.reshape(tgtSeqLen, headDim, numHeads * batchSize);

  // Obtain the backpropagated error of the value.
  // Shape of gyTemp : (tgtSeqLen, headDim, numHeads * batchSize).
  // Shape of scores : (tgtSeqLen, srcSeqLen, numHeads * batchSize).
  // The shape of tmp : (srcSeqLen, headDim, numHeads * batchSize).
  CubeType tmp = math::MultiplyCube2Cube(scores, gyTemp, true, false);

  // Concatenate the results of all the attention heads.
  tmp.reshape(srcSeqLen, embedDim, batchSize);

  for (size_t i = 0; i < batchSize; ++i)
  {
    g.submat((tgtSeqLen + srcSeqLen) * embedDim, i, g.n_rows - 1, i)
        = arma::vectorise(arma::trans(tmp.slice(i) * valueWt));
  }

  // The shape of gyTemp : (tgtSeqLen, headDim, numHeads * batchSize).
  // The shape of vProj : (srcSeqLen, headDim, numHeads * batchSize).
  // So the new shape of gyTemp : (tgtSeqLen, srcSeqLen, numHeads * batchSize).
  gyTemp = math::MultiplyCube2Cube(gyTemp, vProj, false, true);
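  // gyTemp now holds the error with respect to the attention weights (the
  // softmax output), so it has to be propagated through the softmax next.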

  for (size_t i = 0; i < numHeads * batchSize; ++i)
  {
    // We will perform backpropagation of the softmax over each slice of gyTemp.
    softmax.Backward(scores.slice(i), gyTemp.slice(i), gyTemp.slice(i));
  }

  // Obtain the backpropagated error of the key.
  // The shape of qProj : (tgtSeqLen, headDim, numHeads * batchSize).
  // The shape of gyTemp : (tgtSeqLen, srcSeqLen, numHeads * batchSize).
  // The new shape of tmp : (srcSeqLen, headDim, numHeads * batchSize).
  tmp = math::MultiplyCube2Cube(gyTemp, qProj, true, false);

  // Concatenate the results of all the attention heads.
  tmp.reshape(srcSeqLen, embedDim, batchSize);

  for (size_t i = 0; i < batchSize; ++i)
  {
    g.submat(tgtSeqLen * embedDim, i, (tgtSeqLen + srcSeqLen) * embedDim - 1, i)
        = arma::vectorise(arma::trans(tmp.slice(i) * keyWt));
  }

  // Obtain the backpropagated error of the query.
  // The shape of kProj : (srcSeqLen, headDim, numHeads * batchSize).
  // The shape of gyTemp : (tgtSeqLen, srcSeqLen, numHeads * batchSize).
  // The new shape of tmp : (tgtSeqLen, headDim, numHeads * batchSize).
  tmp = math::MultiplyCube2Cube(gyTemp, kProj) / std::sqrt(headDim);

  // Concatenate the results of all the attention heads.
  tmp.reshape(tgtSeqLen, embedDim, batchSize);

  for (size_t i = 0; i < batchSize; ++i)
  {
    g.submat(0, i, tgtSeqLen * embedDim - 1, i)
        = arma::vectorise(arma::trans(tmp.slice(i) * queryWt));
  }
}

template <typename InputDataType, typename OutputDataType,
          typename RegularizerType>
template <typename eT>
void MultiheadAttention<InputDataType, OutputDataType, RegularizerType>::
Gradient(const arma::Mat<eT>& input,
         const arma::Mat<eT>& error,
         arma::Mat<eT>& gradient)
{
  typedef typename arma::Cube<eT> CubeType;
  typedef typename arma::Mat<eT> MatType;

  if (input.n_rows != embedDim * (tgtSeqLen + 2 * srcSeqLen))
  {
    Log::Fatal << "Incorrect input dimensions!" << std::endl;
  }

  if (error.n_rows != tgtSeqLen * embedDim)
  {
    Log::Fatal << "Backpropagated error has incorrect dimensions." << std::endl;
  }

  const size_t batchSize = input.n_cols;
  const size_t wtSize = embedDim * embedDim;

  // The shape of gradient : (4 * embedDim * embedDim + 4 * embedDim, 1).
  gradient.set_size(arma::size(weights));
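
  // The rows of `gradient` follow the same packing as `weights`: four
  // embedDim * embedDim blocks for queryWt, keyWt, valueWt, and outWt,
  // followed by four embedDim blocks for qBias, kBias, vBias, and outBias.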

  const CubeType q(const_cast<MatType&>(input).memptr(),
      embedDim, tgtSeqLen, batchSize, false, false);
  const CubeType k(const_cast<MatType&>(input).memptr() + q.n_elem,
      embedDim, srcSeqLen, batchSize, false, false);
  const CubeType v(const_cast<MatType&>(input).memptr() + q.n_elem + k.n_elem,
      embedDim, srcSeqLen, batchSize, false, false);

  // Reshape the propagated error into a cube.
  // The shape of errorTemp : (embedDim, tgtSeqLen, batchSize).
  CubeType errorTemp(const_cast<arma::Mat<eT>&>(error).memptr(), embedDim,
      tgtSeqLen, batchSize, true, false);

  // Gradient wrt. outBias, i.e. dL/d(outBias).
  gradient.rows(4 * wtSize + 3 * embedDim, 4 * wtSize + 4 * embedDim - 1)
      = arma::vectorise(arma::sum(arma::sum(errorTemp, 2), 1));

  // The shape of attnOut : (tgtSeqLen, embedDim, batchSize).
  // The shape of errorTemp : (embedDim, tgtSeqLen, batchSize).
  // The shape of gyTemp : (embedDim, embedDim, batchSize).
  CubeType gyTemp = math::MultiplyCube2Cube(attnOut, errorTemp, true, true);

  // Gradient wrt. outWt, i.e. dL/d(outWt). We will take the sum of gyTemp
  // along the slices and vectorise the output.
  gradient.rows(3 * wtSize, 4 * wtSize - 1)
      = arma::vectorise(arma::sum(gyTemp, 2));

  // Partial derivative wrt. attnOut.
  // The shape of outWt : (embedDim, embedDim).
  // The shape of errorTemp : (embedDim, tgtSeqLen, batchSize).
  // The shape of gyTemp : (tgtSeqLen, embedDim, batchSize).
  gyTemp = math::MultiplyCube2Mat(errorTemp, outWt, true, true);

  // Now we will split it into n heads, i.e. reshape it into a cube of shape
  // (tgtSeqLen, headDim, numHeads * batchSize).
  gyTemp.reshape(tgtSeqLen, headDim, numHeads * batchSize);

  // Shape of gyTemp : (tgtSeqLen, headDim, numHeads * batchSize).
  // Shape of scores : (tgtSeqLen, srcSeqLen, numHeads * batchSize).
  // The new shape of errorTemp : (srcSeqLen, headDim, numHeads * batchSize).
  errorTemp = math::MultiplyCube2Cube(scores, gyTemp, true, false);

  // Now we will concatenate the propagated errors from all heads, i.e. we
  // will reshape errorTemp to (srcSeqLen, embedDim, batchSize).
  errorTemp.reshape(srcSeqLen, embedDim, batchSize);

  // Gradient wrt. vBias, i.e. dL/d(vBias). We will take the summation of
  // errorTemp over all the batches and over all the sequences.
  gradient.rows(4 * wtSize + 2 * embedDim, 4 * wtSize + 3 * embedDim - 1)
      = arma::vectorise(arma::sum(arma::sum(errorTemp, 2), 0));

  // Shape of v : (embedDim, srcSeqLen, batchSize).
  // Shape of errorTemp : (srcSeqLen, embedDim, batchSize).
  // The new shape of errorTemp : (embedDim, embedDim, batchSize).
  errorTemp = math::MultiplyCube2Cube(errorTemp, v, true, true);

  // Gradient wrt. valueWt, i.e. dL/d(valueWt). We will take the summation
  // over all the batches of errorTemp.
  gradient.rows(2 * wtSize, 3 * wtSize - 1)
      = arma::vectorise(arma::sum(errorTemp, 2));

  // Now, the shape of gyTemp : (tgtSeqLen, headDim, numHeads * batchSize).
  // The shape of vProj : (srcSeqLen, headDim, numHeads * batchSize).
  // The new shape of errorTemp : (tgtSeqLen, srcSeqLen, numHeads * batchSize).
  errorTemp = math::MultiplyCube2Cube(gyTemp, vProj, false, true);

  for (size_t i = 0; i < numHeads * batchSize; ++i)
  {
    // The shape of scores : (tgtSeqLen, srcSeqLen, numHeads * batchSize).
    // The shape of errorTemp : (tgtSeqLen, srcSeqLen, numHeads * batchSize).
    // The shape of errorTemp remains the same after the softmax backward pass.
    softmax.Backward(scores.slice(i), errorTemp.slice(i), errorTemp.slice(i));
  }

  // The shape of qProj : (tgtSeqLen, headDim, numHeads * batchSize).
  // The shape of errorTemp : (tgtSeqLen, srcSeqLen, numHeads * batchSize).
  // The shape of gyTemp : (srcSeqLen, headDim, numHeads * batchSize).
  gyTemp = math::MultiplyCube2Cube(errorTemp, qProj, true, false);

  // We will now concatenate the propagated errors from all heads.
  // The new shape of gyTemp : (srcSeqLen, embedDim, batchSize).
  gyTemp.reshape(srcSeqLen, embedDim, batchSize);

  // Gradient wrt. kBias, i.e. dL/d(kBias). We will take the summation over
  // all the batches of gyTemp and then over all the sequences.
  gradient.rows(4 * wtSize + embedDim, 4 * wtSize + 2 * embedDim - 1)
      = arma::vectorise(arma::sum(arma::sum(gyTemp, 2), 0));

  // The shape of k : (embedDim, srcSeqLen, batchSize).
  // The shape of gyTemp : (srcSeqLen, embedDim, batchSize).
  // The shape of dkeyWt : (embedDim, embedDim, batchSize).
  gyTemp = math::MultiplyCube2Cube(gyTemp, k, true, true);

  // Gradient wrt. keyWt, i.e. dL/d(keyWt). We will take the summation over
  // all the batches of dkeyWt.
  gradient.rows(wtSize, 2 * wtSize - 1) = arma::vectorise(arma::sum(gyTemp, 2));

  // The shape of kProj : (srcSeqLen, headDim, numHeads * batchSize).
  // The shape of errorTemp : (tgtSeqLen, srcSeqLen, numHeads * batchSize).
  // The shape of gyTemp : (tgtSeqLen, headDim, numHeads * batchSize).
  gyTemp = math::MultiplyCube2Cube(errorTemp, kProj, false, false);

  // Now, we will concatenate the propagated error of all heads.
  gyTemp.reshape(tgtSeqLen, embedDim, batchSize);
  gyTemp /= std::sqrt(headDim);

  // Gradient wrt. qBias, i.e. dL/d(qBias). We will take the summation over
  // all the batches of gyTemp and over all the sequences.
  gradient.rows(4 * wtSize, 4 * wtSize + embedDim - 1)
      = arma::vectorise(arma::sum(arma::sum(gyTemp, 2), 0));

  // The shape of gyTemp : (tgtSeqLen, embedDim, batchSize).
  // The shape of q : (embedDim, tgtSeqLen, batchSize).
  // The new shape of gyTemp : (embedDim, embedDim, batchSize).
  gyTemp = math::MultiplyCube2Cube(gyTemp, q, true, true);

  // Gradient wrt. queryWt, i.e. dL/d(queryWt). We will take the summation
  // over all the batches of gyTemp.
  gradient.rows(0, wtSize - 1) = arma::vectorise(arma::sum(gyTemp, 2));

  // Regularize according to the given regularization rule.
  regularizer.Evaluate(weights, gradient);
}

template <typename InputDataType, typename OutputDataType,
          typename RegularizerType>
template <typename Archive>
void MultiheadAttention<InputDataType, OutputDataType, RegularizerType>::
serialize(Archive& ar, const uint32_t /* version */)
{
  ar(CEREAL_NVP(tgtSeqLen));
  ar(CEREAL_NVP(srcSeqLen));
  ar(CEREAL_NVP(embedDim));
  ar(CEREAL_NVP(numHeads));
  ar(CEREAL_NVP(headDim));

  // This is inefficient, but we have to allocate this memory so that
  // WeightSetVisitor gets the right size.
  if (cereal::is_loading<Archive>())
    weights.set_size(4 * embedDim * (embedDim + 1), 1);
}

} // namespace ann
} // namespace mlpack

#endif