26 template <
typename V,
typename T>
30 static constexpr
size_t vec_size = vec_type::template traits<T>::size;
32 auto alpha_vec = vec_type::set(alpha);
34 const auto j_end = prev_multiple(N, vec_size);
38 for (; j + 7 * vec_size < j_end; j += 8 * vec_size) {
42 auto r11 = vec_type::template zero<T>();
43 auto r12 = vec_type::template zero<T>();
44 auto r13 = vec_type::template zero<T>();
45 auto r14 = vec_type::template zero<T>();
46 auto r15 = vec_type::template zero<T>();
47 auto r16 = vec_type::template zero<T>();
48 auto r17 = vec_type::template zero<T>();
49 auto r18 = vec_type::template zero<T>();
51 for (
size_t k = 0; k < K; ++k) {
52 auto a1 = vec_type::set(a[(i + 0) + k * M]);
63 r11 = vec_type::fmadd(a1, b1, r11);
64 r12 = vec_type::fmadd(a1, b2, r12);
65 r13 = vec_type::fmadd(a1, b3, r13);
66 r14 = vec_type::fmadd(a1, b4, r14);
67 r15 = vec_type::fmadd(a1, b5, r15);
68 r16 = vec_type::fmadd(a1, b6, r16);
69 r17 = vec_type::fmadd(a1, b7, r17);
70 r18 = vec_type::fmadd(a1, b8, r18);
73 vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
74 vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r12));
75 vec_type::storeu(c + (i + 0) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r13));
76 vec_type::storeu(c + (i + 0) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r14));
77 vec_type::storeu(c + (i + 0) * N + j + 4 * vec_size, vec_type::mul(alpha_vec, r15));
78 vec_type::storeu(c + (i + 0) * N + j + 5 * vec_size, vec_type::mul(alpha_vec, r16));
79 vec_type::storeu(c + (i + 0) * N + j + 6 * vec_size, vec_type::mul(alpha_vec, r17));
80 vec_type::storeu(c + (i + 0) * N + j + 7 * vec_size, vec_type::mul(alpha_vec, r18));
84 for (; j + 3 * vec_size < j_end; j += 4 * vec_size) {
87 for (; i + 1 < M; i += 2) {
88 auto r11 = vec_type::template zero<T>();
89 auto r21 = vec_type::template zero<T>();
91 auto r12 = vec_type::template zero<T>();
92 auto r22 = vec_type::template zero<T>();
94 auto r13 = vec_type::template zero<T>();
95 auto r23 = vec_type::template zero<T>();
97 auto r14 = vec_type::template zero<T>();
98 auto r24 = vec_type::template zero<T>();
100 for (
size_t k = 0; k < K; ++k) {
101 auto a1 = vec_type::set(a[(i + 0) + k * M]);
102 auto a2 = vec_type::set(a[(i + 1) + k * M]);
109 r11 = vec_type::fmadd(a1, b1, r11);
110 r21 = vec_type::fmadd(a2, b1, r21);
112 r12 = vec_type::fmadd(a1, b2, r12);
113 r22 = vec_type::fmadd(a2, b2, r22);
115 r13 = vec_type::fmadd(a1, b3, r13);
116 r23 = vec_type::fmadd(a2, b3, r23);
118 r14 = vec_type::fmadd(a1, b4, r14);
119 r24 = vec_type::fmadd(a2, b4, r24);
122 vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
123 vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r21));
125 vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r12));
126 vec_type::storeu(c + (i + 1) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r22));
128 vec_type::storeu(c + (i + 0) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r13));
129 vec_type::storeu(c + (i + 1) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r23));
131 vec_type::storeu(c + (i + 0) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r14));
132 vec_type::storeu(c + (i + 1) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r24));
136 auto r11 = vec_type::template zero<T>();
137 auto r12 = vec_type::template zero<T>();
138 auto r13 = vec_type::template zero<T>();
139 auto r14 = vec_type::template zero<T>();
141 for (
size_t k = 0; k < K; ++k) {
142 auto a1 = vec_type::set(a[(i + 0) + k * M]);
149 r11 = vec_type::fmadd(a1, b1, r11);
150 r12 = vec_type::fmadd(a1, b2, r12);
151 r13 = vec_type::fmadd(a1, b3, r13);
152 r14 = vec_type::fmadd(a1, b4, r14);
155 vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
156 vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r12));
157 vec_type::storeu(c + (i + 0) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r13));
158 vec_type::storeu(c + (i + 0) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r14));
162 for (; j + 1 * vec_size < j_end; j += 2 * vec_size) {
165 for (; i + 1 < M; i += 2) {
166 auto r11 = vec_type::template zero<T>();
167 auto r21 = vec_type::template zero<T>();
169 auto r12 = vec_type::template zero<T>();
170 auto r22 = vec_type::template zero<T>();
172 for (
size_t k = 0; k < K; ++k) {
173 auto a1 = vec_type::set(a[(i + 0) + k * M]);
174 auto a2 = vec_type::set(a[(i + 1) + k * M]);
179 r11 = vec_type::fmadd(a1, b1, r11);
180 r21 = vec_type::fmadd(a2, b1, r21);
182 r12 = vec_type::fmadd(a1, b2, r12);
183 r22 = vec_type::fmadd(a2, b2, r22);
186 vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
187 vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r21));
189 vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r12));
190 vec_type::storeu(c + (i + 1) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r22));
194 auto r11 = vec_type::template zero<T>();
195 auto r12 = vec_type::template zero<T>();
197 for (
size_t k = 0; k < K; ++k) {
198 auto a1 = vec_type::set(a[(i + 0) + k * M]);
203 r11 = vec_type::fmadd(a1, b1, r11);
204 r12 = vec_type::fmadd(a1, b2, r12);
207 vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
208 vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r12));
212 for (; j < j_end; j += vec_size) {
215 for (; i + 1 < M; i += 2) {
216 auto r11 = vec_type::template zero<T>();
217 auto r21 = vec_type::template zero<T>();
219 for (
size_t k = 0; k < K; ++k) {
220 auto a1 = vec_type::set(a[(i + 0) + k * M]);
221 auto a2 = vec_type::set(a[(i + 1) + k * M]);
225 r11 = vec_type::fmadd(a1, b1, r11);
226 r21 = vec_type::fmadd(a2, b1, r21);
229 vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
230 vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r21));
234 auto r11 = vec_type::template zero<T>();
236 for (
size_t k = 0; k < K; ++k) {
237 auto a1 = vec_type::set(a[i + k * M]);
241 r11 = vec_type::fmadd(a1, b1, r11);
244 vec_type::storeu(c + i * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
251 for (; i + 1 < M; i += 2) {
255 for (
size_t k = 0; k < K; ++k) {
256 r1 += a[(i + 0) + k * M] * b[k * N + j];
257 r2 += a[(i + 1) + k * M] * b[k * N + j];
260 c[(i + 0) * N + j] = alpha * r1;
261 c[(i + 1) * N + j] = alpha * r2;
267 for (
size_t k = 0; k < K; ++k) {
268 r1 += a[i + k * M] * b[k * N + j];
271 c[i * N + j] = alpha * r1;
284 template <
typename V,
typename T>
288 static constexpr
size_t vec_size = vec_type::template traits<T>::size;
290 constexpr
size_t n_block_size = 128UL;
291 constexpr
size_t m_block_size = 64UL;
292 constexpr
size_t k_block_size = 128UL;
294 auto alpha_vec = vec_type::set(alpha);
296 for (
size_t jj = 0; jj < N; jj += n_block_size) {
297 const size_t j_end_a =
std::min(jj + n_block_size, N);
298 const auto j_end = prev_multiple(j_end_a, vec_size);
300 for (
size_t ii = 0; ii < M; ii += m_block_size) {
301 const size_t i_end =
std::min(ii + m_block_size, M);
303 for (
size_t kk = 0; kk < K; kk += k_block_size) {
304 const size_t k_end =
std::min(kk + k_block_size, K);
309 for (; j + 3 * vec_size < j_end; j += 4 * vec_size) {
312 for (; i + 1 < i_end; i += 2) {
323 for (
size_t k = kk; k < k_end; ++k) {
324 auto a1 = vec_type::set(a[(i + 0) + k * M]);
325 auto a2 = vec_type::set(a[(i + 1) + k * M]);
332 r11 = vec_type::fmadd(a1, b1, r11);
333 r12 = vec_type::fmadd(a1, b2, r12);
334 r13 = vec_type::fmadd(a1, b3, r13);
335 r14 = vec_type::fmadd(a1, b4, r14);
337 r21 = vec_type::fmadd(a2, b1, r21);
338 r22 = vec_type::fmadd(a2, b2, r22);
339 r23 = vec_type::fmadd(a2, b3, r23);
340 r24 = vec_type::fmadd(a2, b4, r24);
343 vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
344 vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r12));
345 vec_type::storeu(c + (i + 0) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r13));
346 vec_type::storeu(c + (i + 0) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r14));
348 vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r21));
349 vec_type::storeu(c + (i + 1) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r22));
350 vec_type::storeu(c + (i + 1) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r23));
351 vec_type::storeu(c + (i + 1) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r24));
360 for (
size_t k = kk; k < k_end; ++k) {
361 auto a1 = vec_type::set(a[(i + 0) + k * M]);
368 r11 = vec_type::fmadd(a1, b1, r11);
369 r12 = vec_type::fmadd(a1, b2, r12);
370 r13 = vec_type::fmadd(a1, b3, r13);
371 r14 = vec_type::fmadd(a1, b4, r14);
374 vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
375 vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r12));
376 vec_type::storeu(c + (i + 0) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r13));
377 vec_type::storeu(c + (i + 0) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r14));
382 for (; j + vec_size < j_end; j += 2 * vec_size) {
385 for (; i + 3 < i_end; i += 4) {
398 for (
size_t k = kk; k < k_end; ++k) {
399 auto a1 = vec_type::set(a[(i + 0) + k * M]);
400 auto a2 = vec_type::set(a[(i + 1) + k * M]);
401 auto a3 = vec_type::set(a[(i + 2) + k * M]);
402 auto a4 = vec_type::set(a[(i + 3) + k * M]);
407 r11 = vec_type::fmadd(a1, b1, r11);
408 r12 = vec_type::fmadd(a1, b2, r12);
410 r21 = vec_type::fmadd(a2, b1, r21);
411 r22 = vec_type::fmadd(a2, b2, r22);
413 r31 = vec_type::fmadd(a3, b1, r31);
414 r32 = vec_type::fmadd(a3, b2, r32);
416 r41 = vec_type::fmadd(a4, b1, r41);
417 r42 = vec_type::fmadd(a4, b2, r42);
420 vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
421 vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r12));
423 vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r21));
424 vec_type::storeu(c + (i + 1) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r22));
426 vec_type::storeu(c + (i + 2) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r31));
427 vec_type::storeu(c + (i + 2) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r32));
429 vec_type::storeu(c + (i + 3) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r41));
430 vec_type::storeu(c + (i + 3) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r42));
433 for (; i + 1 < i_end; i += 2) {
440 for (
size_t k = kk; k < k_end; ++k) {
441 auto a1 = vec_type::set(a[(i + 0) + k * M]);
442 auto a2 = vec_type::set(a[(i + 1) + k * M]);
447 r11 = vec_type::fmadd(a1, b1, r11);
448 r12 = vec_type::fmadd(a1, b2, r12);
450 r21 = vec_type::fmadd(a2, b1, r21);
451 r22 = vec_type::fmadd(a2, b2, r22);
454 vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
455 vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r12));
457 vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r21));
458 vec_type::storeu(c + (i + 1) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r22));
465 for (
size_t k = kk; k < k_end; ++k) {
466 auto a1 = vec_type::set(a[(i + 0) + k * M]);
471 r11 = vec_type::fmadd(a1, b1, r11);
472 r12 = vec_type::fmadd(a1, b2, r12);
475 vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
476 vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r12));
480 for (; j < j_end; j += vec_size) {
481 for (
size_t i = ii; i < i_end; ++i) {
484 for (
size_t k = kk; k < k_end; ++k) {
485 auto a1 = vec_type::set(a[(i + 0) + k * M]);
489 r11 = vec_type::fmadd(a1, b1, r11);
492 vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
496 for (; j < j_end_a; ++j) {
497 for (
size_t i = ii; i < i_end; ++i) {
498 auto r11 = c[(i + 0) * N + j];
500 for (
size_t k = kk; k < k_end; ++k) {
501 r11 += a[(i + 0) + k * M] * b[k * N + j];
504 c[(i + 0) * N + j] = alpha * r11;
524 template <
typename T>
525 void gemm_cr_to_r(
const T* a,
const T* b, T* c,
size_t M,
size_t N,
size_t K, T alpha) {
526 cpp_assert(
vec_enabled,
"At least one vector mode must be enabled for impl::VEC");
527 cpp_assert(
vectorize_impl,
"vectorize_impl must be enabled for impl::VEC");
530 gemm_small_kernel_cr_to_r<default_vec>(a, b, c, M, N, K, alpha);
533 gemm_large_kernel_cr_to_r<default_vec>(a, b, c, M, N, K, alpha);
Definition: bias_add.hpp:15
constexpr size_t gemm_rr_small_threshold
The number of elements of B above which the BLAS-like (large) kernel is used for GEMM.
Definition: threshold.hpp:55
constexpr bool vectorize_impl
Indicates if the implementations can be automatically vectorized by ETL.
Definition: config.hpp:35
constexpr bool vec_enabled
Indicates if vectorization is available in any format.
Definition: config.hpp:220
typename V::template vec_type< value_type > vec_type
The vectorization type for V.
Definition: dyn_matrix_view.hpp:43
void gemm_cr_to_r(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Vectorized implementation of column-major matrix - row-major matrix multiplication and assignment into a row-major matrix.
Definition: gemm_cr_to_r.hpp:525
void storeu(vec_type< V > in, size_t i) noexcept
Store several elements in the matrix at once.
Definition: dyn_matrix_view.hpp:187
void gemm_large_kernel_cr_to_r(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Optimized version of GEMM for assignment of a large Column-Major Matrix - Row-Major Matrix product to a Row-Major Matrix.
Definition: gemm_cr_to_r.hpp:285
auto loadu(size_t x) const noexcept
Load several elements of the expression at once.
Definition: dyn_matrix_view.hpp:154
auto min(L &&lhs, R &&rhs)
Create an expression with the min value of lhs or rhs.
Definition: expression_builder.hpp:77
void gemm_small_kernel_cr_to_r(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Optimized version of GEMM for assignment of a small Column-Major Matrix - Row-Major Matrix product to a Row-Major Matrix.
Definition: gemm_cr_to_r.hpp:27
void direct_fill_n(S *first, size_t n, T value)
Fills the given memory with the given value.
Definition: memory.hpp:57