Expression Templates Library (ETL)
gemm_rc_to_r.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
14 #pragma once
15 
16 #include "etl/vectorization.hpp"
17 namespace etl::impl::vec {
18 
27 template <typename V, typename T>
28 void gemm_small_kernel_rc_to_r(const T* a, const T* b, T* c, size_t M, size_t N, size_t K, T alpha) {
29  using vec_type = V;
30 
31  static constexpr size_t vec_size = vec_type::template traits<T>::size;
32 
33  size_t i = 0;
34 
35  const auto k_end = prev_multiple(K, vec_size);
36 
37  for (; i + 1 < M; i += 2) {
38  size_t j = 0;
39 
40  for (; j + 3 < N; j += 4) {
41  size_t k = 0;
42 
43  auto r11 = vec_type::template zero<T>();
44  auto r21 = vec_type::template zero<T>();
45 
46  auto r12 = vec_type::template zero<T>();
47  auto r22 = vec_type::template zero<T>();
48 
49  auto r13 = vec_type::template zero<T>();
50  auto r23 = vec_type::template zero<T>();
51 
52  auto r14 = vec_type::template zero<T>();
53  auto r24 = vec_type::template zero<T>();
54 
55  for (; k < k_end; k += vec_size) {
56  auto a1 = vec_type::loadu(a + (i + 0) * K + k + vec_size * 0);
57  auto a2 = vec_type::loadu(a + (i + 1) * K + k + vec_size * 0);
58 
59  auto b1 = vec_type::loadu(b + (j + 0) * K + k + vec_size * 0);
60  auto b2 = vec_type::loadu(b + (j + 1) * K + k + vec_size * 0);
61  auto b3 = vec_type::loadu(b + (j + 2) * K + k + vec_size * 0);
62  auto b4 = vec_type::loadu(b + (j + 3) * K + k + vec_size * 0);
63 
64  r11 = vec_type::fmadd(a1, b1, r11);
65  r21 = vec_type::fmadd(a2, b1, r21);
66 
67  r12 = vec_type::fmadd(a1, b2, r12);
68  r22 = vec_type::fmadd(a2, b2, r22);
69 
70  r13 = vec_type::fmadd(a1, b3, r13);
71  r23 = vec_type::fmadd(a2, b3, r23);
72 
73  r14 = vec_type::fmadd(a1, b4, r14);
74  r24 = vec_type::fmadd(a2, b4, r24);
75  }
76 
77  auto v11 = vec_type::hadd(r11);
78  auto v21 = vec_type::hadd(r21);
79 
80  auto v12 = vec_type::hadd(r12);
81  auto v22 = vec_type::hadd(r22);
82 
83  auto v13 = vec_type::hadd(r13);
84  auto v23 = vec_type::hadd(r23);
85 
86  auto v14 = vec_type::hadd(r14);
87  auto v24 = vec_type::hadd(r24);
88 
89  for (; k < K; ++k) {
90  v11 += a[(i + 0) * K + k] * b[k + (j + 0) * K];
91  v21 += a[(i + 1) * K + k] * b[k + (j + 0) * K];
92 
93  v12 += a[(i + 0) * K + k] * b[k + (j + 1) * K];
94  v22 += a[(i + 1) * K + k] * b[k + (j + 1) * K];
95 
96  v13 += a[(i + 0) * K + k] * b[k + (j + 2) * K];
97  v23 += a[(i + 1) * K + k] * b[k + (j + 2) * K];
98 
99  v14 += a[(i + 0) * K + k] * b[k + (j + 3) * K];
100  v24 += a[(i + 1) * K + k] * b[k + (j + 3) * K];
101  }
102 
103  c[(i + 0) * N + (j + 0)] = alpha * v11;
104  c[(i + 1) * N + (j + 0)] = alpha * v21;
105 
106  c[(i + 0) * N + (j + 1)] = alpha * v12;
107  c[(i + 1) * N + (j + 1)] = alpha * v22;
108 
109  c[(i + 0) * N + (j + 2)] = alpha * v13;
110  c[(i + 1) * N + (j + 2)] = alpha * v23;
111 
112  c[(i + 0) * N + (j + 3)] = alpha * v14;
113  c[(i + 1) * N + (j + 3)] = alpha * v24;
114  }
115 
116  for (; j + 1 < N; j += 2) {
117  size_t k = 0;
118 
119  auto r11 = vec_type::template zero<T>();
120  auto r21 = vec_type::template zero<T>();
121 
122  auto r12 = vec_type::template zero<T>();
123  auto r22 = vec_type::template zero<T>();
124 
125  for (; k < k_end; k += vec_size) {
126  auto a1 = vec_type::loadu(a + (i + 0) * K + k + vec_size * 0);
127  auto a2 = vec_type::loadu(a + (i + 1) * K + k + vec_size * 0);
128 
129  auto b1 = vec_type::loadu(b + (j + 0) * K + k + vec_size * 0);
130  auto b2 = vec_type::loadu(b + (j + 1) * K + k + vec_size * 0);
131 
132  r11 = vec_type::fmadd(a1, b1, r11);
133  r21 = vec_type::fmadd(a2, b1, r21);
134 
135  r12 = vec_type::fmadd(a1, b2, r12);
136  r22 = vec_type::fmadd(a2, b2, r22);
137  }
138 
139  auto v11 = vec_type::hadd(r11);
140  auto v21 = vec_type::hadd(r21);
141 
142  auto v12 = vec_type::hadd(r12);
143  auto v22 = vec_type::hadd(r22);
144 
145  for (; k < K; ++k) {
146  v11 += a[(i + 0) * K + k] * b[k + (j + 0) * K];
147  v21 += a[(i + 1) * K + k] * b[k + (j + 0) * K];
148 
149  v12 += a[(i + 0) * K + k] * b[k + (j + 1) * K];
150  v22 += a[(i + 1) * K + k] * b[k + (j + 1) * K];
151  }
152 
153  c[(i + 0) * N + (j + 0)] = alpha * v11;
154  c[(i + 1) * N + (j + 0)] = alpha * v21;
155 
156  c[(i + 0) * N + (j + 1)] = alpha * v12;
157  c[(i + 1) * N + (j + 1)] = alpha * v22;
158  }
159 
160  for (; j < N; ++j) {
161  size_t k = 0;
162 
163  auto r11 = vec_type::template zero<T>();
164  auto r21 = vec_type::template zero<T>();
165 
166  for (; k < k_end; k += vec_size) {
167  auto a1 = vec_type::loadu(a + (i + 0) * K + k + vec_size * 0);
168  auto a2 = vec_type::loadu(a + (i + 1) * K + k + vec_size * 0);
169 
170  auto b1 = vec_type::loadu(b + j * K + k + vec_size * 0);
171 
172  r11 = vec_type::fmadd(a1, b1, r11);
173  r21 = vec_type::fmadd(a2, b1, r21);
174  }
175 
176  auto v11 = vec_type::hadd(r11);
177  auto v21 = vec_type::hadd(r21);
178 
179  for (; k < K; ++k) {
180  v11 += a[(i + 0) * K + k] * b[k + j * K];
181  v21 += a[(i + 1) * K + k] * b[k + j * K];
182  }
183 
184  c[(i + 0) * N + j] = alpha * v11;
185  c[(i + 1) * N + j] = alpha * v21;
186  }
187  }
188 
189  for (; i < M; ++i) {
190  size_t j = 0;
191 
192 #ifdef __clang__
193  for (; j + 3 < N; j += 4) {
194  size_t k = 0;
195 
196  auto r11 = vec_type::template zero<T>();
197  auto r12 = vec_type::template zero<T>();
198  auto r13 = vec_type::template zero<T>();
199  auto r14 = vec_type::template zero<T>();
200 
201  for (; k < k_end; k += vec_size) {
202  auto a1 = vec_type::loadu(a + i * K + k + vec_size * 0);
203 
204  auto b1 = vec_type::loadu(b + (j + 0) * K + k + vec_size * 0);
205  auto b2 = vec_type::loadu(b + (j + 1) * K + k + vec_size * 0);
206  auto b3 = vec_type::loadu(b + (j + 2) * K + k + vec_size * 0);
207  auto b4 = vec_type::loadu(b + (j + 3) * K + k + vec_size * 0);
208 
209  r11 = vec_type::fmadd(a1, b1, r11);
210  r12 = vec_type::fmadd(a1, b2, r12);
211  r13 = vec_type::fmadd(a1, b3, r13);
212  r14 = vec_type::fmadd(a1, b4, r14);
213  }
214 
215  auto v11 = vec_type::hadd(r11);
216  auto v12 = vec_type::hadd(r12);
217  auto v13 = vec_type::hadd(r13);
218  auto v14 = vec_type::hadd(r14);
219 
220  for (; k < K; ++k) {
221  v11 += a[i * K + k] * b[k + (j + 0) * K];
222  v12 += a[i * K + k] * b[k + (j + 1) * K];
223  v13 += a[i * K + k] * b[k + (j + 2) * K];
224  v14 += a[i * K + k] * b[k + (j + 3) * K];
225  }
226 
227  c[i * N + (j + 0)] = alpha * v11;
228  c[i * N + (j + 1)] = alpha * v12;
229  c[i * N + (j + 2)] = alpha * v13;
230  c[i * N + (j + 3)] = alpha * v14;
231  }
232 #endif
233 
234  for (; j + 1 < N; j += 2) {
235  size_t k = 0;
236 
237  auto r11 = vec_type::template zero<T>();
238  auto r12 = vec_type::template zero<T>();
239 
240  for (; k < k_end; k += vec_size) {
241  auto a1 = vec_type::loadu(a + i * K + k + vec_size * 0);
242 
243  auto b1 = vec_type::loadu(b + (j + 0) * K + k + vec_size * 0);
244  auto b2 = vec_type::loadu(b + (j + 1) * K + k + vec_size * 0);
245 
246  r11 = vec_type::fmadd(a1, b1, r11);
247  r12 = vec_type::fmadd(a1, b2, r12);
248  }
249 
250  auto v11 = vec_type::hadd(r11);
251  auto v12 = vec_type::hadd(r12);
252 
253  for (; k < K; ++k) {
254  v11 += a[i * K + k] * b[k + (j + 0) * K];
255  v12 += a[i * K + k] * b[k + (j + 1) * K];
256  }
257 
258  c[i * N + (j + 0)] = alpha * v11;
259  c[i * N + (j + 1)] = alpha * v12;
260  }
261 
262  for (; j < N; ++j) {
263  size_t k = 0;
264 
265  auto r11 = vec_type::template zero<T>();
266 
267  for (; k < k_end; k += vec_size) {
268  auto a1 = vec_type::loadu(a + i * K + k + vec_size * 0);
269 
270  auto b1 = vec_type::loadu(b + j * K + k + vec_size * 0);
271 
272  r11 = vec_type::fmadd(a1, b1, r11);
273  }
274 
275  auto v11 = vec_type::hadd(r11);
276 
277  for (; k < K; ++k) {
278  v11 += a[i * K + k] * b[k + j * K];
279  }
280 
281  c[i * N + j] = alpha * v11;
282  }
283  }
284 }
285 
294 template <typename V, typename T>
295 void gemm_large_kernel_rc_to_r(const T* a, const T* b, T* c, size_t M, size_t N, size_t K, T alpha) {
296  using vec_type = V;
297 
298  static constexpr size_t vec_size = vec_type::template traits<T>::size;
299 
300  constexpr size_t n_block_size = 256UL;
301  constexpr size_t m_block_size = 128UL;
302  constexpr size_t k_block_size = 256UL;
303 
304  for (size_t ii = 0; ii < M; ii += m_block_size) {
305  const size_t i_end = std::min(ii + m_block_size, M);
306 
307  for (size_t jj = 0; jj < N; jj += n_block_size) {
308  const size_t j_end = std::min(jj + n_block_size, N);
309 
310  for (size_t kk = 0; kk < K; kk += k_block_size) {
311  const size_t k_end_a = std::min(kk + k_block_size, K);
312  const size_t k_end = prev_multiple(k_end_a, vec_size);
313 
314  size_t i = ii;
315 
316  for (; i + 1 < i_end; i += 2) {
317  size_t j = jj;
318 
319  for (; j + 3 < j_end; j += 4) {
320  size_t k = kk;
321 
322  auto r11 = vec_type::template zero<T>();
323  auto r21 = vec_type::template zero<T>();
324 
325  auto r12 = vec_type::template zero<T>();
326  auto r22 = vec_type::template zero<T>();
327 
328  auto r13 = vec_type::template zero<T>();
329  auto r23 = vec_type::template zero<T>();
330 
331  auto r14 = vec_type::template zero<T>();
332  auto r24 = vec_type::template zero<T>();
333 
334  for (; k < k_end; k += vec_size) {
335  auto a1 = vec_type::loadu(a + (i + 0) * K + k + vec_size * 0);
336  auto a2 = vec_type::loadu(a + (i + 1) * K + k + vec_size * 0);
337 
338  auto b1 = vec_type::loadu(b + (j + 0) * K + k + vec_size * 0);
339  auto b2 = vec_type::loadu(b + (j + 1) * K + k + vec_size * 0);
340  auto b3 = vec_type::loadu(b + (j + 2) * K + k + vec_size * 0);
341  auto b4 = vec_type::loadu(b + (j + 3) * K + k + vec_size * 0);
342 
343  r11 = vec_type::fmadd(a1, b1, r11);
344  r21 = vec_type::fmadd(a2, b1, r21);
345 
346  r12 = vec_type::fmadd(a1, b2, r12);
347  r22 = vec_type::fmadd(a2, b2, r22);
348 
349  r13 = vec_type::fmadd(a1, b3, r13);
350  r23 = vec_type::fmadd(a2, b3, r23);
351 
352  r14 = vec_type::fmadd(a1, b4, r14);
353  r24 = vec_type::fmadd(a2, b4, r24);
354  }
355 
356  auto v11 = vec_type::hadd(r11);
357  auto v21 = vec_type::hadd(r21);
358 
359  auto v12 = vec_type::hadd(r12);
360  auto v22 = vec_type::hadd(r22);
361 
362  auto v13 = vec_type::hadd(r13);
363  auto v23 = vec_type::hadd(r23);
364 
365  auto v14 = vec_type::hadd(r14);
366  auto v24 = vec_type::hadd(r24);
367 
368  for (; k < k_end_a; ++k) {
369  v11 += a[(i + 0) * K + k] * b[k + (j + 0) * K];
370  v21 += a[(i + 1) * K + k] * b[k + (j + 0) * K];
371 
372  v12 += a[(i + 0) * K + k] * b[k + (j + 1) * K];
373  v22 += a[(i + 1) * K + k] * b[k + (j + 1) * K];
374 
375  v13 += a[(i + 0) * K + k] * b[k + (j + 2) * K];
376  v23 += a[(i + 1) * K + k] * b[k + (j + 2) * K];
377 
378  v14 += a[(i + 0) * K + k] * b[k + (j + 3) * K];
379  v24 += a[(i + 1) * K + k] * b[k + (j + 3) * K];
380  }
381 
382  c[(i + 0) * N + (j + 0)] += alpha * v11;
383  c[(i + 1) * N + (j + 0)] += alpha * v21;
384 
385  c[(i + 0) * N + (j + 1)] += alpha * v12;
386  c[(i + 1) * N + (j + 1)] += alpha * v22;
387 
388  c[(i + 0) * N + (j + 2)] += alpha * v13;
389  c[(i + 1) * N + (j + 2)] += alpha * v23;
390 
391  c[(i + 0) * N + (j + 3)] += alpha * v14;
392  c[(i + 1) * N + (j + 3)] += alpha * v24;
393  }
394 
395  for (; j + 1 < j_end; j += 2) {
396  size_t k = kk;
397 
398  auto r11 = vec_type::template zero<T>();
399  auto r21 = vec_type::template zero<T>();
400 
401  auto r12 = vec_type::template zero<T>();
402  auto r22 = vec_type::template zero<T>();
403 
404  for (; k < k_end; k += vec_size) {
405  auto a1 = vec_type::loadu(a + (i + 0) * K + k + vec_size * 0);
406  auto a2 = vec_type::loadu(a + (i + 1) * K + k + vec_size * 0);
407 
408  auto b1 = vec_type::loadu(b + (j + 0) * K + k + vec_size * 0);
409  auto b2 = vec_type::loadu(b + (j + 1) * K + k + vec_size * 0);
410 
411  r11 = vec_type::fmadd(a1, b1, r11);
412  r21 = vec_type::fmadd(a2, b1, r21);
413 
414  r12 = vec_type::fmadd(a1, b2, r12);
415  r22 = vec_type::fmadd(a2, b2, r22);
416  }
417 
418  auto v11 = vec_type::hadd(r11);
419  auto v21 = vec_type::hadd(r21);
420 
421  auto v12 = vec_type::hadd(r12);
422  auto v22 = vec_type::hadd(r22);
423 
424  for (; k < k_end_a; ++k) {
425  v11 += a[(i + 0) * K + k] * b[k + (j + 0) * K];
426  v21 += a[(i + 1) * K + k] * b[k + (j + 0) * K];
427 
428  v12 += a[(i + 0) * K + k] * b[k + (j + 1) * K];
429  v22 += a[(i + 1) * K + k] * b[k + (j + 1) * K];
430  }
431 
432  c[(i + 0) * N + (j + 0)] += alpha * v11;
433  c[(i + 1) * N + (j + 0)] += alpha * v21;
434 
435  c[(i + 0) * N + (j + 1)] += alpha * v12;
436  c[(i + 1) * N + (j + 1)] += alpha * v22;
437  }
438 
439  for (; j < j_end; ++j) {
440  size_t k = kk;
441 
442  auto r11 = vec_type::template zero<T>();
443  auto r21 = vec_type::template zero<T>();
444 
445  for (; k < k_end; k += vec_size) {
446  auto a1 = vec_type::loadu(a + (i + 0) * K + k + vec_size * 0);
447  auto a2 = vec_type::loadu(a + (i + 1) * K + k + vec_size * 0);
448 
449  auto b1 = vec_type::loadu(b + j * K + k + vec_size * 0);
450 
451  r11 = vec_type::fmadd(a1, b1, r11);
452  r21 = vec_type::fmadd(a2, b1, r21);
453  }
454 
455  auto v11 = vec_type::hadd(r11);
456  auto v21 = vec_type::hadd(r21);
457 
458  for (; k < k_end_a; ++k) {
459  v11 += a[(i + 0) * K + k] * b[k + j * K];
460  v21 += a[(i + 1) * K + k] * b[k + j * K];
461  }
462 
463  c[(i + 0) * N + j] += alpha * v11;
464  c[(i + 1) * N + j] += alpha * v21;
465  }
466  }
467 
468  for (; i < i_end; ++i) {
469  size_t j = jj;
470 
471  for (; j + 1 < j_end; j += 2) {
472  size_t k = kk;
473 
474  auto r11 = vec_type::template zero<T>();
475  auto r12 = vec_type::template zero<T>();
476 
477  for (; k < k_end; k += vec_size) {
478  auto a1 = vec_type::loadu(a + i * K + k + vec_size * 0);
479 
480  auto b1 = vec_type::loadu(b + (j + 0) * K + k + vec_size * 0);
481  auto b2 = vec_type::loadu(b + (j + 1) * K + k + vec_size * 0);
482 
483  r11 = vec_type::fmadd(a1, b1, r11);
484  r12 = vec_type::fmadd(a1, b2, r12);
485  }
486 
487  auto v11 = vec_type::hadd(r11);
488  auto v12 = vec_type::hadd(r12);
489 
490  for (; k < k_end_a; ++k) {
491  v11 += a[i * K + k] * b[k + (j + 0) * K];
492  v12 += a[i * K + k] * b[k + (j + 1) * K];
493  }
494 
495  c[i * N + (j + 0)] += alpha * v11;
496  c[i * N + (j + 1)] += alpha * v12;
497  }
498 
499  for (; j < j_end; ++j) {
500  size_t k = kk;
501 
502  auto r11 = vec_type::template zero<T>();
503 
504  for (; k < k_end; k += vec_size) {
505  auto a1 = vec_type::loadu(a + i * K + k + vec_size * 0);
506 
507  auto b1 = vec_type::loadu(b + j * K + k + vec_size * 0);
508 
509  r11 = vec_type::fmadd(a1, b1, r11);
510  }
511 
512  auto v11 = vec_type::hadd(r11);
513 
514  for (; k < k_end_a; ++k) {
515  v11 += a[i * K + k] * b[k + j * K];
516  }
517 
518  c[i * N + j] += alpha * v11;
519  }
520  }
521  }
522  }
523  }
524 }
525 
538 template <typename T>
539 void gemm_rc_to_r(const T* a, const T* b, T* c, size_t M, size_t N, size_t K, T alpha) {
540  cpp_assert(vec_enabled, "At least one vector mode must be enabled for impl::VEC");
541  cpp_assert(vectorize_impl, "vectorize_impl must be enabled for impl::VEC");
542 
543  if (M * N <= gemm_nt_rr_small_threshold) {
544  gemm_small_kernel_rc_to_r<default_vec>(a, b, c, M, N, K, alpha);
545  } else {
546  direct_fill_n(c, M * N, T(0));
547  gemm_large_kernel_rc_to_r<default_vec>(a, b, c, M, N, K, alpha);
548  }
549 }
550 
551 } //end of namespace etl::impl::vec
constexpr size_t gemm_nt_rr_small_threshold
The number of elements of B after which we use BLAS-like kernel (for GEMM)
Definition: threshold.hpp:57
void gemm_rc_to_r(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Vectorized implementation of row-major matrix - column-major matrix multiplication and assignment int...
Definition: gemm_rc_to_r.hpp:539
Definition: bias_add.hpp:15
constexpr bool vectorize_impl
Indicates if the implementations can be automatically vectorized by ETL.
Definition: config.hpp:35
constexpr bool vec_enabled
Indicates if vectorization is available in any format.
Definition: config.hpp:220
void gemm_small_kernel_rc_to_r(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Optimized version of GEMM for assignment of a small Row-Major Matrix - Column Major Matrix to a Row M...
Definition: gemm_rc_to_r.hpp:28
typename V::template vec_type< value_type > vec_type
The vectorization type for V.
Definition: dyn_matrix_view.hpp:43
auto loadu(size_t x) const noexcept
Load several elements of the expression at once.
Definition: dyn_matrix_view.hpp:154
auto min(L &&lhs, R &&rhs)
Create an expression with the min value of lhs or rhs.
Definition: expression_builder.hpp:77
void gemm_large_kernel_rc_to_r(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Optimized version of GEMM for assignment of a large Row-Major Matrix - Column Major Matrix to a Row M...
Definition: gemm_rc_to_r.hpp:295
void direct_fill_n(S *first, size_t n, T value)
Fills the given memory with the given value.
Definition: memory.hpp:57
Contains vectorization utilities for the vectorized assignments (done by the evaluator).