27 template <
typename V,
typename T>
31 static constexpr
size_t vec_size = vec_type::template traits<T>::size;
35 const auto k_end = prev_multiple(K, vec_size);
37 for (; i + 1 < M; i += 2) {
40 for (; j + 3 < N; j += 4) {
43 auto r11 = vec_type::template zero<T>();
44 auto r21 = vec_type::template zero<T>();
46 auto r12 = vec_type::template zero<T>();
47 auto r22 = vec_type::template zero<T>();
49 auto r13 = vec_type::template zero<T>();
50 auto r23 = vec_type::template zero<T>();
52 auto r14 = vec_type::template zero<T>();
53 auto r24 = vec_type::template zero<T>();
55 for (; k < k_end; k += vec_size) {
64 r11 = vec_type::fmadd(a1, b1, r11);
65 r21 = vec_type::fmadd(a2, b1, r21);
67 r12 = vec_type::fmadd(a1, b2, r12);
68 r22 = vec_type::fmadd(a2, b2, r22);
70 r13 = vec_type::fmadd(a1, b3, r13);
71 r23 = vec_type::fmadd(a2, b3, r23);
73 r14 = vec_type::fmadd(a1, b4, r14);
74 r24 = vec_type::fmadd(a2, b4, r24);
77 auto v11 = vec_type::hadd(r11);
78 auto v21 = vec_type::hadd(r21);
80 auto v12 = vec_type::hadd(r12);
81 auto v22 = vec_type::hadd(r22);
83 auto v13 = vec_type::hadd(r13);
84 auto v23 = vec_type::hadd(r23);
86 auto v14 = vec_type::hadd(r14);
87 auto v24 = vec_type::hadd(r24);
90 v11 += a[(i + 0) * K + k] * b[k + (j + 0) * K];
91 v21 += a[(i + 1) * K + k] * b[k + (j + 0) * K];
93 v12 += a[(i + 0) * K + k] * b[k + (j + 1) * K];
94 v22 += a[(i + 1) * K + k] * b[k + (j + 1) * K];
96 v13 += a[(i + 0) * K + k] * b[k + (j + 2) * K];
97 v23 += a[(i + 1) * K + k] * b[k + (j + 2) * K];
99 v14 += a[(i + 0) * K + k] * b[k + (j + 3) * K];
100 v24 += a[(i + 1) * K + k] * b[k + (j + 3) * K];
103 c[(i + 0) * N + (j + 0)] = alpha * v11;
104 c[(i + 1) * N + (j + 0)] = alpha * v21;
106 c[(i + 0) * N + (j + 1)] = alpha * v12;
107 c[(i + 1) * N + (j + 1)] = alpha * v22;
109 c[(i + 0) * N + (j + 2)] = alpha * v13;
110 c[(i + 1) * N + (j + 2)] = alpha * v23;
112 c[(i + 0) * N + (j + 3)] = alpha * v14;
113 c[(i + 1) * N + (j + 3)] = alpha * v24;
116 for (; j + 1 < N; j += 2) {
119 auto r11 = vec_type::template zero<T>();
120 auto r21 = vec_type::template zero<T>();
122 auto r12 = vec_type::template zero<T>();
123 auto r22 = vec_type::template zero<T>();
125 for (; k < k_end; k += vec_size) {
132 r11 = vec_type::fmadd(a1, b1, r11);
133 r21 = vec_type::fmadd(a2, b1, r21);
135 r12 = vec_type::fmadd(a1, b2, r12);
136 r22 = vec_type::fmadd(a2, b2, r22);
139 auto v11 = vec_type::hadd(r11);
140 auto v21 = vec_type::hadd(r21);
142 auto v12 = vec_type::hadd(r12);
143 auto v22 = vec_type::hadd(r22);
146 v11 += a[(i + 0) * K + k] * b[k + (j + 0) * K];
147 v21 += a[(i + 1) * K + k] * b[k + (j + 0) * K];
149 v12 += a[(i + 0) * K + k] * b[k + (j + 1) * K];
150 v22 += a[(i + 1) * K + k] * b[k + (j + 1) * K];
153 c[(i + 0) * N + (j + 0)] = alpha * v11;
154 c[(i + 1) * N + (j + 0)] = alpha * v21;
156 c[(i + 0) * N + (j + 1)] = alpha * v12;
157 c[(i + 1) * N + (j + 1)] = alpha * v22;
163 auto r11 = vec_type::template zero<T>();
164 auto r21 = vec_type::template zero<T>();
166 for (; k < k_end; k += vec_size) {
172 r11 = vec_type::fmadd(a1, b1, r11);
173 r21 = vec_type::fmadd(a2, b1, r21);
176 auto v11 = vec_type::hadd(r11);
177 auto v21 = vec_type::hadd(r21);
180 v11 += a[(i + 0) * K + k] * b[k + j * K];
181 v21 += a[(i + 1) * K + k] * b[k + j * K];
184 c[(i + 0) * N + j] = alpha * v11;
185 c[(i + 1) * N + j] = alpha * v21;
193 for (; j + 3 < N; j += 4) {
196 auto r11 = vec_type::template zero<T>();
197 auto r12 = vec_type::template zero<T>();
198 auto r13 = vec_type::template zero<T>();
199 auto r14 = vec_type::template zero<T>();
201 for (; k < k_end; k += vec_size) {
209 r11 = vec_type::fmadd(a1, b1, r11);
210 r12 = vec_type::fmadd(a1, b2, r12);
211 r13 = vec_type::fmadd(a1, b3, r13);
212 r14 = vec_type::fmadd(a1, b4, r14);
215 auto v11 = vec_type::hadd(r11);
216 auto v12 = vec_type::hadd(r12);
217 auto v13 = vec_type::hadd(r13);
218 auto v14 = vec_type::hadd(r14);
221 v11 += a[i * K + k] * b[k + (j + 0) * K];
222 v12 += a[i * K + k] * b[k + (j + 1) * K];
223 v13 += a[i * K + k] * b[k + (j + 2) * K];
224 v14 += a[i * K + k] * b[k + (j + 3) * K];
227 c[i * N + (j + 0)] = alpha * v11;
228 c[i * N + (j + 1)] = alpha * v12;
229 c[i * N + (j + 2)] = alpha * v13;
230 c[i * N + (j + 3)] = alpha * v14;
234 for (; j + 1 < N; j += 2) {
237 auto r11 = vec_type::template zero<T>();
238 auto r12 = vec_type::template zero<T>();
240 for (; k < k_end; k += vec_size) {
246 r11 = vec_type::fmadd(a1, b1, r11);
247 r12 = vec_type::fmadd(a1, b2, r12);
250 auto v11 = vec_type::hadd(r11);
251 auto v12 = vec_type::hadd(r12);
254 v11 += a[i * K + k] * b[k + (j + 0) * K];
255 v12 += a[i * K + k] * b[k + (j + 1) * K];
258 c[i * N + (j + 0)] = alpha * v11;
259 c[i * N + (j + 1)] = alpha * v12;
265 auto r11 = vec_type::template zero<T>();
267 for (; k < k_end; k += vec_size) {
272 r11 = vec_type::fmadd(a1, b1, r11);
275 auto v11 = vec_type::hadd(r11);
278 v11 += a[i * K + k] * b[k + j * K];
281 c[i * N + j] = alpha * v11;
294 template <
typename V,
typename T>
298 static constexpr
size_t vec_size = vec_type::template traits<T>::size;
300 constexpr
size_t n_block_size = 256UL;
301 constexpr
size_t m_block_size = 128UL;
302 constexpr
size_t k_block_size = 256UL;
304 for (
size_t ii = 0; ii < M; ii += m_block_size) {
305 const size_t i_end =
std::min(ii + m_block_size, M);
307 for (
size_t jj = 0; jj < N; jj += n_block_size) {
308 const size_t j_end =
std::min(jj + n_block_size, N);
310 for (
size_t kk = 0; kk < K; kk += k_block_size) {
311 const size_t k_end_a =
std::min(kk + k_block_size, K);
312 const size_t k_end = prev_multiple(k_end_a, vec_size);
316 for (; i + 1 < i_end; i += 2) {
319 for (; j + 3 < j_end; j += 4) {
322 auto r11 = vec_type::template zero<T>();
323 auto r21 = vec_type::template zero<T>();
325 auto r12 = vec_type::template zero<T>();
326 auto r22 = vec_type::template zero<T>();
328 auto r13 = vec_type::template zero<T>();
329 auto r23 = vec_type::template zero<T>();
331 auto r14 = vec_type::template zero<T>();
332 auto r24 = vec_type::template zero<T>();
334 for (; k < k_end; k += vec_size) {
343 r11 = vec_type::fmadd(a1, b1, r11);
344 r21 = vec_type::fmadd(a2, b1, r21);
346 r12 = vec_type::fmadd(a1, b2, r12);
347 r22 = vec_type::fmadd(a2, b2, r22);
349 r13 = vec_type::fmadd(a1, b3, r13);
350 r23 = vec_type::fmadd(a2, b3, r23);
352 r14 = vec_type::fmadd(a1, b4, r14);
353 r24 = vec_type::fmadd(a2, b4, r24);
356 auto v11 = vec_type::hadd(r11);
357 auto v21 = vec_type::hadd(r21);
359 auto v12 = vec_type::hadd(r12);
360 auto v22 = vec_type::hadd(r22);
362 auto v13 = vec_type::hadd(r13);
363 auto v23 = vec_type::hadd(r23);
365 auto v14 = vec_type::hadd(r14);
366 auto v24 = vec_type::hadd(r24);
368 for (; k < k_end_a; ++k) {
369 v11 += a[(i + 0) * K + k] * b[k + (j + 0) * K];
370 v21 += a[(i + 1) * K + k] * b[k + (j + 0) * K];
372 v12 += a[(i + 0) * K + k] * b[k + (j + 1) * K];
373 v22 += a[(i + 1) * K + k] * b[k + (j + 1) * K];
375 v13 += a[(i + 0) * K + k] * b[k + (j + 2) * K];
376 v23 += a[(i + 1) * K + k] * b[k + (j + 2) * K];
378 v14 += a[(i + 0) * K + k] * b[k + (j + 3) * K];
379 v24 += a[(i + 1) * K + k] * b[k + (j + 3) * K];
382 c[(i + 0) * N + (j + 0)] += alpha * v11;
383 c[(i + 1) * N + (j + 0)] += alpha * v21;
385 c[(i + 0) * N + (j + 1)] += alpha * v12;
386 c[(i + 1) * N + (j + 1)] += alpha * v22;
388 c[(i + 0) * N + (j + 2)] += alpha * v13;
389 c[(i + 1) * N + (j + 2)] += alpha * v23;
391 c[(i + 0) * N + (j + 3)] += alpha * v14;
392 c[(i + 1) * N + (j + 3)] += alpha * v24;
395 for (; j + 1 < j_end; j += 2) {
398 auto r11 = vec_type::template zero<T>();
399 auto r21 = vec_type::template zero<T>();
401 auto r12 = vec_type::template zero<T>();
402 auto r22 = vec_type::template zero<T>();
404 for (; k < k_end; k += vec_size) {
411 r11 = vec_type::fmadd(a1, b1, r11);
412 r21 = vec_type::fmadd(a2, b1, r21);
414 r12 = vec_type::fmadd(a1, b2, r12);
415 r22 = vec_type::fmadd(a2, b2, r22);
418 auto v11 = vec_type::hadd(r11);
419 auto v21 = vec_type::hadd(r21);
421 auto v12 = vec_type::hadd(r12);
422 auto v22 = vec_type::hadd(r22);
424 for (; k < k_end_a; ++k) {
425 v11 += a[(i + 0) * K + k] * b[k + (j + 0) * K];
426 v21 += a[(i + 1) * K + k] * b[k + (j + 0) * K];
428 v12 += a[(i + 0) * K + k] * b[k + (j + 1) * K];
429 v22 += a[(i + 1) * K + k] * b[k + (j + 1) * K];
432 c[(i + 0) * N + (j + 0)] += alpha * v11;
433 c[(i + 1) * N + (j + 0)] += alpha * v21;
435 c[(i + 0) * N + (j + 1)] += alpha * v12;
436 c[(i + 1) * N + (j + 1)] += alpha * v22;
439 for (; j < j_end; ++j) {
442 auto r11 = vec_type::template zero<T>();
443 auto r21 = vec_type::template zero<T>();
445 for (; k < k_end; k += vec_size) {
451 r11 = vec_type::fmadd(a1, b1, r11);
452 r21 = vec_type::fmadd(a2, b1, r21);
455 auto v11 = vec_type::hadd(r11);
456 auto v21 = vec_type::hadd(r21);
458 for (; k < k_end_a; ++k) {
459 v11 += a[(i + 0) * K + k] * b[k + j * K];
460 v21 += a[(i + 1) * K + k] * b[k + j * K];
463 c[(i + 0) * N + j] += alpha * v11;
464 c[(i + 1) * N + j] += alpha * v21;
468 for (; i < i_end; ++i) {
471 for (; j + 1 < j_end; j += 2) {
474 auto r11 = vec_type::template zero<T>();
475 auto r12 = vec_type::template zero<T>();
477 for (; k < k_end; k += vec_size) {
483 r11 = vec_type::fmadd(a1, b1, r11);
484 r12 = vec_type::fmadd(a1, b2, r12);
487 auto v11 = vec_type::hadd(r11);
488 auto v12 = vec_type::hadd(r12);
490 for (; k < k_end_a; ++k) {
491 v11 += a[i * K + k] * b[k + (j + 0) * K];
492 v12 += a[i * K + k] * b[k + (j + 1) * K];
495 c[i * N + (j + 0)] += alpha * v11;
496 c[i * N + (j + 1)] += alpha * v12;
499 for (; j < j_end; ++j) {
502 auto r11 = vec_type::template zero<T>();
504 for (; k < k_end; k += vec_size) {
509 r11 = vec_type::fmadd(a1, b1, r11);
512 auto v11 = vec_type::hadd(r11);
514 for (; k < k_end_a; ++k) {
515 v11 += a[i * K + k] * b[k + j * K];
518 c[i * N + j] += alpha * v11;
538 template <
typename T>
539 void gemm_rc_to_r(
const T* a,
const T* b, T* c,
size_t M,
size_t N,
size_t K, T alpha) {
540 cpp_assert(
vec_enabled,
"At least one vector mode must be enabled for impl::VEC");
541 cpp_assert(
vectorize_impl,
"vectorize_impl must be enabled for impl::VEC");
544 gemm_small_kernel_rc_to_r<default_vec>(a, b, c, M, N, K, alpha);
547 gemm_large_kernel_rc_to_r<default_vec>(a, b, c, M, N, K, alpha);
constexpr size_t gemm_nt_rr_small_threshold
The number of elements of B after which we use BLAS-like kernel (for GEMM)
Definition: threshold.hpp:57
void gemm_rc_to_r(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Vectorized implementation of row-major matrix - column-major matrix multiplication and assignment into a row-major matrix.
Definition: gemm_rc_to_r.hpp:539
Definition: bias_add.hpp:15
constexpr bool vectorize_impl
Indicates if the implementations can be automatically vectorized by ETL.
Definition: config.hpp:35
constexpr bool vec_enabled
Indicates if vectorization is available in any format.
Definition: config.hpp:220
void gemm_small_kernel_rc_to_r(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Optimized version of GEMM for assignment of a small Row-Major Matrix - Column-Major Matrix product to a Row-Major Matrix.
Definition: gemm_rc_to_r.hpp:28
typename V::template vec_type< value_type > vec_type
The vectorization type for V.
Definition: dyn_matrix_view.hpp:43
auto loadu(size_t x) const noexcept
Load several elements of the expression at once.
Definition: dyn_matrix_view.hpp:154
auto min(L &&lhs, R &&rhs)
Create an expression with the min value of lhs or rhs.
Definition: expression_builder.hpp:77
void gemm_large_kernel_rc_to_r(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Optimized version of GEMM for assignment of a large Row-Major Matrix - Column-Major Matrix product to a Row-Major Matrix.
Definition: gemm_rc_to_r.hpp:295
void direct_fill_n(S *first, size_t n, T value)
Fills the given memory with the given value.
Definition: memory.hpp:57
Contains vectorization utilities for the vectorized assignments (done by the evaluator).