// Fragment of gemm_small_kernel_rc_to_c taken from a numbered source listing.
// The number at the start of a line is the listing's own line number; listing
// lines that are absent here (the function signature, the vec_type alias, the
// vector loads of a1..a4 / b1..b2, some loop headers, closing braces) are not
// part of this excerpt.
//
// What the visible lines establish:
//   - a is indexed as a[row * K + k], b as b[k + col * K], and c is written
//     as c[row + col * M].
//   - every output element is overwritten (plain '=') with alpha times an
//     accumulated sum of a[...] * b[...] products over k.
//   - the work is tiled 4x2, 4x1, 2x2, 2x1, 1x2 and 1x1 (rows of a x columns of b).
26 template <
typename V,
typename T>
30 static constexpr
size_t vec_size = vec_type::template traits<T>::size;
// Upper bound of the vectorized k loops; presumably the largest multiple of
// vec_size that does not exceed K -- confirm against prev_multiple.
32 const size_t k_pos = prev_multiple(K, vec_size);
// Row tiles of four: rows i .. i+3.
36 for (; i + 3 < M; i += 4) {
// 4x2 tile: columns j and j+1.
39 for (; j + 1 < N; j += 2) {
// One vector accumulator per output element; rXY pairs column X with row Y.
42 auto r11 = vec_type::template zero<T>();
43 auto r12 = vec_type::template zero<T>();
44 auto r13 = vec_type::template zero<T>();
45 auto r14 = vec_type::template zero<T>();
47 auto r21 = vec_type::template zero<T>();
48 auto r22 = vec_type::template zero<T>();
49 auto r23 = vec_type::template zero<T>();
50 auto r24 = vec_type::template zero<T>();
// Vectorized part of the k reduction, vec_size elements per step.
52 for (; k < k_pos; k += vec_size) {
// presumably rXY = aY * bX + rXY (fused multiply-add) -- confirm against vec_type::fmadd.
61 r11 = vec_type::fmadd(a1, b1, r11);
62 r12 = vec_type::fmadd(a2, b1, r12);
63 r13 = vec_type::fmadd(a3, b1, r13);
64 r14 = vec_type::fmadd(a4, b1, r14);
66 r21 = vec_type::fmadd(a1, b2, r21);
67 r22 = vec_type::fmadd(a2, b2, r22);
68 r23 = vec_type::fmadd(a3, b2, r23);
69 r24 = vec_type::fmadd(a4, b2, r24);
// Reduce each vector accumulator to a single value (presumably the sum of its lanes).
72 auto v11 = vec_type::hadd(r11);
73 auto v12 = vec_type::hadd(r12);
74 auto v13 = vec_type::hadd(r13);
75 auto v14 = vec_type::hadd(r14);
77 auto v21 = vec_type::hadd(r21);
78 auto v22 = vec_type::hadd(r22);
79 auto v23 = vec_type::hadd(r23);
80 auto v24 = vec_type::hadd(r24);
// Scalar accumulation for the k values not handled by the vector loop
// (the enclosing loop header is not part of this excerpt).
83 v11 += a[(i + 0) * K + k] * b[k + (j + 0) * K];
84 v12 += a[(i + 1) * K + k] * b[k + (j + 0) * K];
85 v13 += a[(i + 2) * K + k] * b[k + (j + 0) * K];
86 v14 += a[(i + 3) * K + k] * b[k + (j + 0) * K];
88 v21 += a[(i + 0) * K + k] * b[k + (j + 1) * K];
89 v22 += a[(i + 1) * K + k] * b[k + (j + 1) * K];
90 v23 += a[(i + 2) * K + k] * b[k + (j + 1) * K];
91 v24 += a[(i + 3) * K + k] * b[k + (j + 1) * K];
// Store the scaled results; c is overwritten, not accumulated.
94 c[(i + 0) + (j + 0) * M] = alpha * v11;
95 c[(i + 1) + (j + 0) * M] = alpha * v12;
96 c[(i + 2) + (j + 0) * M] = alpha * v13;
97 c[(i + 3) + (j + 0) * M] = alpha * v14;
99 c[(i + 0) + (j + 1) * M] = alpha * v21;
100 c[(i + 1) + (j + 1) * M] = alpha * v22;
101 c[(i + 2) + (j + 1) * M] = alpha * v23;
102 c[(i + 3) + (j + 1) * M] = alpha * v24;
// 4x1 tile: same scheme for a single column j (the header of the enclosing
// column loop is not part of this excerpt).
108 auto r11 = vec_type::template zero<T>();
109 auto r12 = vec_type::template zero<T>();
110 auto r13 = vec_type::template zero<T>();
111 auto r14 = vec_type::template zero<T>();
113 for (; k < k_pos; k += vec_size) {
121 r11 = vec_type::fmadd(a1, b1, r11);
122 r12 = vec_type::fmadd(a2, b1, r12);
123 r13 = vec_type::fmadd(a3, b1, r13);
124 r14 = vec_type::fmadd(a4, b1, r14);
127 auto v11 = vec_type::hadd(r11);
128 auto v12 = vec_type::hadd(r12);
129 auto v13 = vec_type::hadd(r13);
130 auto v14 = vec_type::hadd(r14);
133 v11 += a[(i + 0) * K + k] * b[k + j * K];
134 v12 += a[(i + 1) * K + k] * b[k + j * K];
135 v13 += a[(i + 2) * K + k] * b[k + j * K];
136 v14 += a[(i + 3) * K + k] * b[k + j * K];
139 c[(i + 0) + j * M] = alpha * v11;
140 c[(i + 1) + j * M] = alpha * v12;
141 c[(i + 2) + j * M] = alpha * v13;
142 c[(i + 3) + j * M] = alpha * v14;
// Row tiles of two: rows i and i+1, for the rows left over by the 4-row loop.
146 for (; i + 1 < M; i += 2) {
// 2x2 tile.
149 for (; j + 1 < N; j += 2) {
152 auto r11 = vec_type::template zero<T>();
153 auto r12 = vec_type::template zero<T>();
155 auto r21 = vec_type::template zero<T>();
156 auto r22 = vec_type::template zero<T>();
158 for (; k < k_pos; k += vec_size) {
165 r11 = vec_type::fmadd(a1, b1, r11);
166 r12 = vec_type::fmadd(a2, b1, r12);
168 r21 = vec_type::fmadd(a1, b2, r21);
169 r22 = vec_type::fmadd(a2, b2, r22);
172 auto v11 = vec_type::hadd(r11);
173 auto v12 = vec_type::hadd(r12);
175 auto v21 = vec_type::hadd(r21);
176 auto v22 = vec_type::hadd(r22);
179 v11 += a[(i + 0) * K + k] * b[k + (j + 0) * K];
180 v12 += a[(i + 1) * K + k] * b[k + (j + 0) * K];
182 v21 += a[(i + 0) * K + k] * b[k + (j + 1) * K];
183 v22 += a[(i + 1) * K + k] * b[k + (j + 1) * K];
186 c[(i + 0) + (j + 0) * M] = alpha * v11;
187 c[(i + 1) + (j + 0) * M] = alpha * v12;
189 c[(i + 0) + (j + 1) * M] = alpha * v21;
190 c[(i + 1) + (j + 1) * M] = alpha * v22;
// 2x1 tile: single column j (enclosing loop header not part of this excerpt).
196 auto r11 = vec_type::template zero<T>();
197 auto r12 = vec_type::template zero<T>();
199 for (; k < k_pos; k += vec_size) {
205 r11 = vec_type::fmadd(a1, b1, r11);
206 r12 = vec_type::fmadd(a2, b1, r12);
209 auto v11 = vec_type::hadd(r11);
210 auto v12 = vec_type::hadd(r12);
213 v11 += a[(i + 0) * K + k] * b[k + j * K];
214 v12 += a[(i + 1) * K + k] * b[k + j * K];
217 c[(i + 0) + j * M] = alpha * v11;
218 c[(i + 1) + j * M] = alpha * v12;
// 1x2 tile: single row i, columns j and j+1 (the header of the enclosing row
// loop is not part of this excerpt).
225 for (; j + 1 < N; j += 2) {
228 auto r11 = vec_type::template zero<T>();
229 auto r21 = vec_type::template zero<T>();
231 for (; k < k_pos; k += vec_size) {
237 r11 = vec_type::fmadd(a1, b1, r11);
238 r21 = vec_type::fmadd(a1, b2, r21);
241 auto v11 = vec_type::hadd(r11);
242 auto v21 = vec_type::hadd(r21);
245 v11 += a[i * K + k] * b[k + (j + 0) * K];
246 v21 += a[i * K + k] * b[k + (j + 1) * K];
249 c[i + (j + 0) * M] = alpha * v11;
250 c[i + (j + 1) * M] = alpha * v21;
// 1x1 tile: one row, one column.
256 auto r11 = vec_type::template zero<T>();
258 for (; k < k_pos; k += vec_size) {
263 r11 = vec_type::fmadd(a1, b1, r11);
266 auto v11 = vec_type::hadd(r11);
269 v11 += a[i * K + k] * b[k + j * K];
272 c[i + j * M] = alpha * v11;
// Fragment of gemm_large_kernel_rc_to_c taken from a numbered source listing.
// The number at the start of a line is the listing's own line number; listing
// lines that are absent here (the function signature, the vec_type alias, the
// initialisation of i / j / k, the vector loads of a1..a4 / b1..b2, closing
// braces) are not part of this excerpt.
//
// What the visible lines establish:
//   - the iteration space is cut into blocks of m_block_size x n_block_size x
//     k_block_size, each clamped to M / N / K with std::min.
//   - inside a block the work is tiled 4x2, 4x1, 2x2, 2x1, 1x2 and 1x1.
//   - a is indexed as a[row * K + k], b as b[k + col * K], c as c[row + col * M].
//   - c is updated with '+=' so that every k block adds its partial sum.
// NOTE(review): because of the '+=', c must already hold its starting values
// (e.g. zeros) before the first k block; no such initialisation is visible in
// this excerpt -- confirm where it happens.
285 template <
typename V,
typename T>
289 static constexpr
size_t vec_size = vec_type::template traits<T>::size;
// Block extents along the columns of b (n), the rows of a (m) and the shared dimension (k).
291 constexpr
size_t n_block_size = 128UL;
292 constexpr
size_t m_block_size = 64UL;
293 constexpr
size_t k_block_size = 128UL;
// Row blocks [ii, i_end).
295 for (
size_t ii = 0; ii < M; ii += m_block_size) {
296 const size_t i_end =
std::min(ii + m_block_size, M);
// Column blocks [jj, j_end).
298 for (
size_t jj = 0; jj < N; jj += n_block_size) {
299 const size_t j_end =
std::min(jj + n_block_size, N);
// Blocks [kk, k_end) along the shared dimension.
301 for (
size_t kk = 0; kk < K; kk += k_block_size) {
302 const size_t k_end =
std::min(kk + k_block_size, K);
// End of the vectorized range in this k block; presumably the largest multiple
// of vec_size that does not exceed k_end -- confirm against prev_multiple.
303 const size_t k_pos = prev_multiple(k_end, vec_size);
// Row tiles of four inside the current row block.
307 for (; i + 3 < i_end; i += 4) {
// 4x2 tile: columns j and j+1.
310 for (; j + 1 < j_end; j += 2) {
// One vector accumulator per output element; rXY pairs column X with row Y.
313 auto r11 = vec_type::template zero<T>();
314 auto r12 = vec_type::template zero<T>();
315 auto r13 = vec_type::template zero<T>();
316 auto r14 = vec_type::template zero<T>();
318 auto r21 = vec_type::template zero<T>();
319 auto r22 = vec_type::template zero<T>();
320 auto r23 = vec_type::template zero<T>();
321 auto r24 = vec_type::template zero<T>();
// Vectorized part of the k reduction, vec_size elements per step.
323 for (; k < k_pos; k += vec_size) {
// presumably rXY = aY * bX + rXY (fused multiply-add) -- confirm against vec_type::fmadd.
332 r11 = vec_type::fmadd(a1, b1, r11);
333 r12 = vec_type::fmadd(a2, b1, r12);
334 r13 = vec_type::fmadd(a3, b1, r13);
335 r14 = vec_type::fmadd(a4, b1, r14);
337 r21 = vec_type::fmadd(a1, b2, r21);
338 r22 = vec_type::fmadd(a2, b2, r22);
339 r23 = vec_type::fmadd(a3, b2, r23);
340 r24 = vec_type::fmadd(a4, b2, r24);
// Reduce each vector accumulator to a single value (presumably the sum of its lanes).
343 auto v11 = vec_type::hadd(r11);
344 auto v12 = vec_type::hadd(r12);
345 auto v13 = vec_type::hadd(r13);
346 auto v14 = vec_type::hadd(r14);
348 auto v21 = vec_type::hadd(r21);
349 auto v22 = vec_type::hadd(r22);
350 auto v23 = vec_type::hadd(r23);
351 auto v24 = vec_type::hadd(r24);
// Scalar tail: remaining k values of this block, up to k_end.
353 for (; k < k_end; ++k) {
354 v11 += a[(i + 0) * K + k] * b[k + (j + 0) * K];
355 v12 += a[(i + 1) * K + k] * b[k + (j + 0) * K];
356 v13 += a[(i + 2) * K + k] * b[k + (j + 0) * K];
357 v14 += a[(i + 3) * K + k] * b[k + (j + 0) * K];
359 v21 += a[(i + 0) * K + k] * b[k + (j + 1) * K];
360 v22 += a[(i + 1) * K + k] * b[k + (j + 1) * K];
361 v23 += a[(i + 2) * K + k] * b[k + (j + 1) * K];
362 v24 += a[(i + 3) * K + k] * b[k + (j + 1) * K];
// Add this k block's scaled partial sums to c.
365 c[(i + 0) + (j + 0) * M] += alpha * v11;
366 c[(i + 1) + (j + 0) * M] += alpha * v12;
367 c[(i + 2) + (j + 0) * M] += alpha * v13;
368 c[(i + 3) + (j + 0) * M] += alpha * v14;
370 c[(i + 0) + (j + 1) * M] += alpha * v21;
371 c[(i + 1) + (j + 1) * M] += alpha * v22;
372 c[(i + 2) + (j + 1) * M] += alpha * v23;
373 c[(i + 3) + (j + 1) * M] += alpha * v24;
// 4x1 tile: columns left over by the two-column loop.
376 for (; j < j_end; ++j) {
379 auto r11 = vec_type::template zero<T>();
380 auto r12 = vec_type::template zero<T>();
381 auto r13 = vec_type::template zero<T>();
382 auto r14 = vec_type::template zero<T>();
384 for (; k < k_pos; k += vec_size) {
392 r11 = vec_type::fmadd(a1, b1, r11);
393 r12 = vec_type::fmadd(a2, b1, r12);
394 r13 = vec_type::fmadd(a3, b1, r13);
395 r14 = vec_type::fmadd(a4, b1, r14);
398 auto v11 = vec_type::hadd(r11);
399 auto v12 = vec_type::hadd(r12);
400 auto v13 = vec_type::hadd(r13);
401 auto v14 = vec_type::hadd(r14);
403 for (; k < k_end; ++k) {
404 v11 += a[(i + 0) * K + k] * b[k + j * K];
405 v12 += a[(i + 1) * K + k] * b[k + j * K];
406 v13 += a[(i + 2) * K + k] * b[k + j * K];
407 v14 += a[(i + 3) * K + k] * b[k + j * K];
410 c[(i + 0) + j * M] += alpha * v11;
411 c[(i + 1) + j * M] += alpha * v12;
412 c[(i + 2) + j * M] += alpha * v13;
413 c[(i + 3) + j * M] += alpha * v14;
// Row tiles of two, for the rows left over by the 4-row loop.
417 for (; i + 1 < i_end; i += 2) {
// 2x2 tile.
420 for (; j + 1 < j_end; j += 2) {
423 auto r11 = vec_type::template zero<T>();
424 auto r12 = vec_type::template zero<T>();
426 auto r21 = vec_type::template zero<T>();
427 auto r22 = vec_type::template zero<T>();
429 for (; k < k_pos; k += vec_size) {
436 r11 = vec_type::fmadd(a1, b1, r11);
437 r12 = vec_type::fmadd(a2, b1, r12);
439 r21 = vec_type::fmadd(a1, b2, r21);
440 r22 = vec_type::fmadd(a2, b2, r22);
443 auto v11 = vec_type::hadd(r11);
444 auto v12 = vec_type::hadd(r12);
446 auto v21 = vec_type::hadd(r21);
447 auto v22 = vec_type::hadd(r22);
449 for (; k < k_end; ++k) {
450 v11 += a[(i + 0) * K + k] * b[k + (j + 0) * K];
451 v12 += a[(i + 1) * K + k] * b[k + (j + 0) * K];
453 v21 += a[(i + 0) * K + k] * b[k + (j + 1) * K];
454 v22 += a[(i + 1) * K + k] * b[k + (j + 1) * K];
457 c[(i + 0) + (j + 0) * M] += alpha * v11;
458 c[(i + 1) + (j + 0) * M] += alpha * v12;
460 c[(i + 0) + (j + 1) * M] += alpha * v21;
461 c[(i + 1) + (j + 1) * M] += alpha * v22;
// 2x1 tile.
464 for (; j < j_end; ++j) {
467 auto r11 = vec_type::template zero<T>();
468 auto r12 = vec_type::template zero<T>();
470 for (; k < k_pos; k += vec_size) {
476 r11 = vec_type::fmadd(a1, b1, r11);
477 r12 = vec_type::fmadd(a2, b1, r12);
480 auto v11 = vec_type::hadd(r11);
481 auto v12 = vec_type::hadd(r12);
483 for (; k < k_end; ++k) {
484 v11 += a[(i + 0) * K + k] * b[k + j * K];
485 v12 += a[(i + 1) * K + k] * b[k + j * K];
488 c[(i + 0) + j * M] += alpha * v11;
489 c[(i + 1) + j * M] += alpha * v12;
// Remaining single rows of the row block.
493 for (; i < i_end; ++i) {
// 1x2 tile.
496 for (; j + 1 < j_end; j += 2) {
499 auto r11 = vec_type::template zero<T>();
500 auto r21 = vec_type::template zero<T>();
502 for (; k < k_pos; k += vec_size) {
508 r11 = vec_type::fmadd(a1, b1, r11);
509 r21 = vec_type::fmadd(a1, b2, r21);
512 auto v11 = vec_type::hadd(r11);
513 auto v21 = vec_type::hadd(r21);
515 for (; k < k_end; ++k) {
516 v11 += a[i * K + k] * b[k + (j + 0) * K];
517 v21 += a[i * K + k] * b[k + (j + 1) * K];
520 c[i + (j + 0) * M] += alpha * v11;
521 c[i + (j + 1) * M] += alpha * v21;
// 1x1 tile.
524 for (; j < j_end; ++j) {
527 auto r11 = vec_type::template zero<T>();
529 for (; k < k_pos; k += vec_size) {
534 r11 = vec_type::fmadd(a1, b1, r11);
537 auto v11 = vec_type::hadd(r11);
539 for (; k < k_end; ++k) {
540 v11 += a[i * K + k] * b[k + j * K];
543 c[i + j * M] += alpha * v11;
// Entry point gemm_rc_to_c, taken from a numbered source listing. The number
// at the start of a line is the listing's own line number; the listing lines
// between the two kernel calls (and the closing brace) are not part of this
// excerpt, so the condition that selects between the calls is not visible.
//
// Parameters as declared below: input pointers a and b, output pointer c,
// sizes M, N and K, and the scalar alpha. All of them are forwarded unchanged
// to the kernel.
563 template <
typename T>
564 void gemm_rc_to_c(
const T* a,
const T* b, T* c,
size_t M,
size_t N,
size_t K, T alpha) {
// Both configuration flags are asserted before any work is dispatched.
565 cpp_assert(
vec_enabled,
"At least one vector mode must be enabled for impl::VEC");
566 cpp_assert(
vectorize_impl,
"vectorize_impl must be enabled for impl::VEC");
// NOTE(review): both visible call sites (listing lines 569 and 572) invoke
// gemm_small_kernel_rc_to_c with identical arguments; gemm_large_kernel_rc_to_c
// is never called from the visible lines. Confirm whether this is intentional.
// If the second call is ever switched to the large kernel, note that the large
// kernel accumulates into c with '+=' and would need c initialised first.
569 gemm_small_kernel_rc_to_c<default_vec>(a, b, c, M, N, K, alpha);
572 gemm_small_kernel_rc_to_c<default_vec>(a, b, c, M, N, K, alpha);
Definition: bias_add.hpp:15
constexpr bool vectorize_impl
Indicates if the implementations can be automatically vectorized by ETL.
Definition: config.hpp:35
constexpr bool vec_enabled
Indicates if vectorization is available in any format.
Definition: config.hpp:220
typename V::template vec_type< value_type > vec_type
The vectorization type for V.
Definition: dyn_matrix_view.hpp:43
auto loadu(size_t x) const noexcept
Load several elements of the expression at once.
Definition: dyn_matrix_view.hpp:154
auto min(L &&lhs, R &&rhs)
Create an expression with the min value of lhs or rhs.
Definition: expression_builder.hpp:77
void gemm_rc_to_c(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Vectorized implementation of row-major matrix - column-major matrix multiplication and assignment into a column-major matrix.
Definition: gemm_rc_to_c.hpp:564
void gemm_large_kernel_rc_to_c(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Optimized version of GEMM for assignment of a large Row-Major Matrix - Column Major Matrix to a Column Major Matrix.
Definition: gemm_rc_to_c.hpp:286
void gemm_small_kernel_rc_to_c(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Optimized version of GEMM for assignment of a small Row-Major Matrix - Column Major Matrix to a Column Major Matrix.
Definition: gemm_rc_to_c.hpp:27