// gemm_small_kernel_cr_to_c (fragment): vectorized GEMM kernel for small problems.
// The scalar tail below shows the computation:
//   C(i, j) = alpha * sum_k A(i, k) * B(k, j)
// A is read as a[i + k * M] (column-major, M x K).
// B is read as b[k * N + j] (row-major, K x N).
// C is written as c[i + j * M] (column-major, M x N).
// In the vector paths the accumulators start at zero and are stored with
// storeu, so C's previous contents are overwritten, not accumulated into.
// NOTE(review): this listing is incomplete. The function signature, the loads
// of a1..a4 from A, the declarations of i and j, the single-column guards and
// the closing braces are not present. The comments below describe only the
// visible lines.
26 template <
typename V,
typename T>
// Number of T lanes held by one vector of the selected vector mode V.
30 static constexpr
size_t vec_size = vec_type::template traits<T>::size;
// alpha broadcast to all lanes; every store below is vec_type::mul(alpha_vec, acc).
32 auto alpha_vec = vec_type::set(alpha);
// Upper bound of the vectorized row range. Presumably prev_multiple rounds M
// down to a multiple of vec_size (TODO confirm); leftover rows go to the scalar tail.
34 const auto i_end = prev_multiple(M, vec_size);
// Tier 1: 4 vectors (4 * vec_size rows) of C per iteration.
38 for (; i + 3 * vec_size < i_end; i += 4 * vec_size) {
// Two columns of C at a time: 4 row-vectors x 2 columns = 8 accumulators.
// Naming: r<row-vector><column>.
41 for (; j + 1 < N; j += 2) {
42 auto r11 = vec_type::template zero<T>();
43 auto r21 = vec_type::template zero<T>();
44 auto r31 = vec_type::template zero<T>();
45 auto r41 = vec_type::template zero<T>();
47 auto r12 = vec_type::template zero<T>();
48 auto r22 = vec_type::template zero<T>();
49 auto r32 = vec_type::template zero<T>();
50 auto r42 = vec_type::template zero<T>();
// Full reduction over K. a1..a4 are loaded on lines not shown here;
// presumably they are 4 consecutive vectors of column k of A.
52 for (
size_t k = 0; k < K; ++k) {
// Broadcast B(k, j) and B(k, j + 1); B is row-major, so both come from row k.
58 auto b1 = vec_type::set(b[k * N + (j + 0)]);
59 auto b2 = vec_type::set(b[k * N + (j + 1)]);
// Rank-1 update: acc = fmadd(a, b, acc) for every (row-vector, column) pair.
61 r11 = vec_type::fmadd(a1, b1, r11);
62 r21 = vec_type::fmadd(a2, b1, r21);
63 r31 = vec_type::fmadd(a3, b1, r31);
64 r41 = vec_type::fmadd(a4, b1, r41);
66 r12 = vec_type::fmadd(a1, b2, r12);
67 r22 = vec_type::fmadd(a2, b2, r22);
68 r32 = vec_type::fmadd(a3, b2, r32);
69 r42 = vec_type::fmadd(a4, b2, r42);
// Scale by alpha and store. Column j of column-major C starts at c + j * M.
// storeu is presumably an unaligned store.
72 vec_type::storeu(c + i + (j + 0) * M + 0 * vec_size, vec_type::mul(alpha_vec, r11));
73 vec_type::storeu(c + i + (j + 0) * M + 1 * vec_size, vec_type::mul(alpha_vec, r21));
74 vec_type::storeu(c + i + (j + 0) * M + 2 * vec_size, vec_type::mul(alpha_vec, r31));
75 vec_type::storeu(c + i + (j + 0) * M + 3 * vec_size, vec_type::mul(alpha_vec, r41));
77 vec_type::storeu(c + i + (j + 1) * M + 0 * vec_size, vec_type::mul(alpha_vec, r12));
78 vec_type::storeu(c + i + (j + 1) * M + 1 * vec_size, vec_type::mul(alpha_vec, r22));
79 vec_type::storeu(c + i + (j + 1) * M + 2 * vec_size, vec_type::mul(alpha_vec, r32));
80 vec_type::storeu(c + i + (j + 1) * M + 3 * vec_size, vec_type::mul(alpha_vec, r42));
// Single remaining column when N is odd. The guard is not visible here;
// presumably `if (j < N)` -- confirm.
84 auto r11 = vec_type::template zero<T>();
85 auto r21 = vec_type::template zero<T>();
86 auto r31 = vec_type::template zero<T>();
87 auto r41 = vec_type::template zero<T>();
89 for (
size_t k = 0; k < K; ++k) {
95 auto b1 = vec_type::set(b[k * N + j]);
97 r11 = vec_type::fmadd(a1, b1, r11);
98 r21 = vec_type::fmadd(a2, b1, r21);
99 r31 = vec_type::fmadd(a3, b1, r31);
100 r41 = vec_type::fmadd(a4, b1, r41);
103 vec_type::storeu(c + i + j * M + 0 * vec_size, vec_type::mul(alpha_vec, r11));
104 vec_type::storeu(c + i + j * M + 1 * vec_size, vec_type::mul(alpha_vec, r21));
105 vec_type::storeu(c + i + j * M + 2 * vec_size, vec_type::mul(alpha_vec, r31));
106 vec_type::storeu(c + i + j * M + 3 * vec_size, vec_type::mul(alpha_vec, r41));
// Tier 2: remaining rows two vectors (2 * vec_size) at a time, same column blocking.
110 for (; i + 1 * vec_size < i_end; i += 2 * vec_size) {
113 for (; j + 1 < N; j += 2) {
114 auto r11 = vec_type::template zero<T>();
115 auto r21 = vec_type::template zero<T>();
117 auto r12 = vec_type::template zero<T>();
118 auto r22 = vec_type::template zero<T>();
120 for (
size_t k = 0; k < K; ++k) {
124 auto b1 = vec_type::set(b[k * N + (j + 0)]);
125 auto b2 = vec_type::set(b[k * N + (j + 1)]);
127 r11 = vec_type::fmadd(a1, b1, r11);
128 r21 = vec_type::fmadd(a2, b1, r21);
130 r12 = vec_type::fmadd(a1, b2, r12);
131 r22 = vec_type::fmadd(a2, b2, r22);
134 vec_type::storeu(c + i + (j + 0) * M + 0 * vec_size, vec_type::mul(alpha_vec, r11));
135 vec_type::storeu(c + i + (j + 0) * M + 1 * vec_size, vec_type::mul(alpha_vec, r21));
137 vec_type::storeu(c + i + (j + 1) * M + 0 * vec_size, vec_type::mul(alpha_vec, r12));
138 vec_type::storeu(c + i + (j + 1) * M + 1 * vec_size, vec_type::mul(alpha_vec, r22));
// Single remaining column for tier 2 (guard not visible in this listing).
142 auto r11 = vec_type::template zero<T>();
143 auto r21 = vec_type::template zero<T>();
145 for (
size_t k = 0; k < K; ++k) {
149 auto b1 = vec_type::set(b[k * N + j]);
151 r11 = vec_type::fmadd(a1, b1, r11);
152 r21 = vec_type::fmadd(a2, b1, r21);
155 vec_type::storeu(c + i + j * M + 0 * vec_size, vec_type::mul(alpha_vec, r11));
156 vec_type::storeu(c + i + j * M + 1 * vec_size, vec_type::mul(alpha_vec, r21));
// Tier 3: remaining full vectors of rows, one vector at a time.
// The stores for this tier are on lines not shown in this listing.
160 for (; i < i_end; i += vec_size) {
163 for (; j + 1 < N; j += 2) {
164 auto r11 = vec_type::template zero<T>();
165 auto r12 = vec_type::template zero<T>();
167 for (
size_t k = 0; k < K; ++k) {
170 auto b1 = vec_type::set(b[k * N + (j + 0)]);
171 auto b2 = vec_type::set(b[k * N + (j + 1)]);
173 r11 = vec_type::fmadd(a1, b1, r11);
174 r12 = vec_type::fmadd(a1, b2, r12);
// Single remaining column for tier 3 (guard not visible in this listing).
182 auto r11 = vec_type::template zero<T>();
184 for (
size_t k = 0; k < K; ++k) {
187 auto b1 = vec_type::set(b[k * N + j]);
189 r11 = vec_type::fmadd(a1, b1, r11);
// Scalar tail. Presumably this covers rows from i_end up to M that do not fill a
// whole vector; the row-loop header and the r11/r12 initialisation are not visible.
// Same dot product, written with plain scalar indexing.
199 for (; j + 1 < N; j += 2) {
203 for (
size_t k = 0; k < K; ++k) {
204 r11 += a[i + k * M] * b[k * N + (j + 0)];
205 r12 += a[i + k * M] * b[k * N + (j + 1)];
208 c[i + (j + 0) * M] = alpha * r11;
209 c[i + (j + 1) * M] = alpha * r12;
// Single remaining column of the scalar tail.
215 for (
size_t k = 0; k < K; ++k) {
216 r11 += a[i + k * M] * b[k * N + j];
219 c[i + j * M] = alpha * r11;
// gemm_large_kernel_cr_to_c (fragment): cache-blocked GEMM kernel for large problems.
// Layout, as shown by the scalar tail:
//   A is read as a[i + k * M], B as b[k * N + j], C as c[i + j * M].
// The iteration space is tiled into m_block_size x n_block_size x k_block_size
// (64 x 128 x 128) blocks. The kk loop sits inside the jj loop, so every
// (ii, jj) tile of C is stored once per kk block.
// NOTE(review): this listing is incomplete. The signature, the initial values of
// i and j, the accumulator initialisation of the vector paths, the loads of
// a1..a4 and the closing braces are not present. The comments describe only
// the visible lines.
232 template <
typename V,
typename T>
// Number of T lanes held by one vector of the selected vector mode V.
236 static constexpr
size_t vec_size = vec_type::template traits<T>::size;
// Tile sizes (in elements) along N, M and K.
238 constexpr
size_t n_block_size = 128UL;
239 constexpr
size_t m_block_size = 64UL;
240 constexpr
size_t k_block_size = 128UL;
// alpha broadcast to all lanes; every vector store is vec_type::mul(alpha_vec, acc).
242 auto alpha_vec = vec_type::set(alpha);
// Row tiles [ii, i_end). Rows up to i_pos are vectorized; presumably prev_multiple
// rounds i_end down to a multiple of vec_size (TODO confirm). The rest is scalar.
244 for (
size_t ii = 0; ii < M; ii += m_block_size) {
245 const size_t i_end =
std::min(ii + m_block_size, M);
246 const size_t i_pos = prev_multiple(i_end, vec_size);
// Column tiles [jj, j_end).
248 for (
size_t jj = 0; jj < N; jj += n_block_size) {
249 const size_t j_end =
std::min(jj + n_block_size, N);
// Depth tiles [kk, k_end). Each tile only contributes a partial sum over k.
251 for (
size_t kk = 0; kk < K; kk += k_block_size) {
252 const size_t k_end =
std::min(kk + k_block_size, K);
// Tier 1: 4 row-vectors x 2 columns. Naming: r<column><row-vector>.
// Where i starts (presumably ii) and how the accumulators are initialised
// are not visible.
257 for (; i + 3 * vec_size < i_pos; i += 4 * vec_size) {
260 for (; j + 1 < j_end; j += 2) {
271 for (
size_t k = kk; k < k_end; ++k) {
// Broadcast B(k, j) and B(k, j + 1) from row k of row-major B.
277 auto b1 = vec_type::set(b[k * N + (j + 0)]);
278 auto b2 = vec_type::set(b[k * N + (j + 1)]);
280 r11 = vec_type::fmadd(a1, b1, r11);
281 r12 = vec_type::fmadd(a2, b1, r12);
282 r13 = vec_type::fmadd(a3, b1, r13);
283 r14 = vec_type::fmadd(a4, b1, r14);
285 r21 = vec_type::fmadd(a1, b2, r21);
286 r22 = vec_type::fmadd(a2, b2, r22);
287 r23 = vec_type::fmadd(a3, b2, r23);
288 r24 = vec_type::fmadd(a4, b2, r24);
// Scale by alpha and store into column-major C (column j starts at c + j * M).
291 vec_type::storeu(c + i + (j + 0) * M + 0 * vec_size, vec_type::mul(alpha_vec, r11));
292 vec_type::storeu(c + i + (j + 0) * M + 1 * vec_size, vec_type::mul(alpha_vec, r12));
293 vec_type::storeu(c + i + (j + 0) * M + 2 * vec_size, vec_type::mul(alpha_vec, r13));
294 vec_type::storeu(c + i + (j + 0) * M + 3 * vec_size, vec_type::mul(alpha_vec, r14));
296 vec_type::storeu(c + i + (j + 1) * M + 0 * vec_size, vec_type::mul(alpha_vec, r21));
297 vec_type::storeu(c + i + (j + 1) * M + 1 * vec_size, vec_type::mul(alpha_vec, r22));
298 vec_type::storeu(c + i + (j + 1) * M + 2 * vec_size, vec_type::mul(alpha_vec, r23));
299 vec_type::storeu(c + i + (j + 1) * M + 3 * vec_size, vec_type::mul(alpha_vec, r24));
// Single remaining column of the tile for tier 1 (guard not visible).
308 for (
size_t k = kk; k < k_end; ++k) {
314 auto b1 = vec_type::set(b[k * N + j]);
316 r11 = vec_type::fmadd(a1, b1, r11);
317 r12 = vec_type::fmadd(a2, b1, r12);
318 r13 = vec_type::fmadd(a3, b1, r13);
319 r14 = vec_type::fmadd(a4, b1, r14);
322 vec_type::storeu(c + i + j * M + 0 * vec_size, vec_type::mul(alpha_vec, r11));
323 vec_type::storeu(c + i + j * M + 1 * vec_size, vec_type::mul(alpha_vec, r12));
324 vec_type::storeu(c + i + j * M + 2 * vec_size, vec_type::mul(alpha_vec, r13));
325 vec_type::storeu(c + i + j * M + 3 * vec_size, vec_type::mul(alpha_vec, r14));
// Tier 2: 2 row-vectors. Columns are taken 4 at a time, then 2, then 1.
330 for (; i + 1 * vec_size < i_pos; i += 2 * vec_size) {
333 for (; j + 3 < j_end; j += 4) {
346 for (
size_t k = kk; k < k_end; ++k) {
350 auto b1 = vec_type::set(b[k * N + (j + 0)]);
351 auto b2 = vec_type::set(b[k * N + (j + 1)]);
352 auto b3 = vec_type::set(b[k * N + (j + 2)]);
353 auto b4 = vec_type::set(b[k * N + (j + 3)]);
355 r11 = vec_type::fmadd(a1, b1, r11);
356 r12 = vec_type::fmadd(a2, b1, r12);
358 r21 = vec_type::fmadd(a1, b2, r21);
359 r22 = vec_type::fmadd(a2, b2, r22);
361 r31 = vec_type::fmadd(a1, b3, r31);
362 r32 = vec_type::fmadd(a2, b3, r32);
364 r41 = vec_type::fmadd(a1, b4, r41);
365 r42 = vec_type::fmadd(a2, b4, r42);
368 vec_type::storeu(c + i + (j + 0) * M + 0 * vec_size, vec_type::mul(alpha_vec, r11));
369 vec_type::storeu(c + i + (j + 0) * M + 1 * vec_size, vec_type::mul(alpha_vec, r12));
371 vec_type::storeu(c + i + (j + 1) * M + 0 * vec_size, vec_type::mul(alpha_vec, r21));
372 vec_type::storeu(c + i + (j + 1) * M + 1 * vec_size, vec_type::mul(alpha_vec, r22));
374 vec_type::storeu(c + i + (j + 2) * M + 0 * vec_size, vec_type::mul(alpha_vec, r31));
375 vec_type::storeu(c + i + (j + 2) * M + 1 * vec_size, vec_type::mul(alpha_vec, r32));
377 vec_type::storeu(c + i + (j + 3) * M + 0 * vec_size, vec_type::mul(alpha_vec, r41));
378 vec_type::storeu(c + i + (j + 3) * M + 1 * vec_size, vec_type::mul(alpha_vec, r42));
// Tier 2, remaining pair of columns.
381 for (; j + 1 < j_end; j += 2) {
388 for (
size_t k = kk; k < k_end; ++k) {
392 auto b1 = vec_type::set(b[k * N + (j + 0)]);
393 auto b2 = vec_type::set(b[k * N + (j + 1)]);
395 r11 = vec_type::fmadd(a1, b1, r11);
396 r12 = vec_type::fmadd(a2, b1, r12);
398 r21 = vec_type::fmadd(a1, b2, r21);
399 r22 = vec_type::fmadd(a2, b2, r22);
402 vec_type::storeu(c + i + (j + 0) * M + 0 * vec_size, vec_type::mul(alpha_vec, r11));
403 vec_type::storeu(c + i + (j + 0) * M + 1 * vec_size, vec_type::mul(alpha_vec, r12));
405 vec_type::storeu(c + i + (j + 1) * M + 0 * vec_size, vec_type::mul(alpha_vec, r21));
406 vec_type::storeu(c + i + (j + 1) * M + 1 * vec_size, vec_type::mul(alpha_vec, r22));
// Tier 2, single remaining column (guard not visible).
413 for (
size_t k = kk; k < k_end; ++k) {
417 auto b1 = vec_type::set(b[k * N + j]);
419 r11 = vec_type::fmadd(a1, b1, r11);
420 r12 = vec_type::fmadd(a2, b1, r12);
423 vec_type::storeu(c + i + j * M + 0 * vec_size, vec_type::mul(alpha_vec, r11));
424 vec_type::storeu(c + i + j * M + 1 * vec_size, vec_type::mul(alpha_vec, r12));
// Tier 3: one row-vector at a time, every column of the tile individually.
428 for (; i < i_pos; i += vec_size) {
429 for (
size_t j = jj; j < j_end; ++j) {
432 for (
size_t k = kk; k < k_end; ++k) {
435 auto b1 = vec_type::set(b[k * N + j]);
437 r11 = vec_type::fmadd(a1, b1, r11);
440 vec_type::storeu(c + i + j * M + 0 * vec_size, vec_type::mul(alpha_vec, r11));
// Scalar tail: rows [i, i_end) of this tile that do not fill a whole vector.
// The accumulator is seeded from the current value of C, so partial sums carry
// over between kk blocks; C presumably must be zeroed by the caller beforehand.
// NOTE(review): C is overwritten with alpha * (old C + partial sum) on every kk
// block. When K > k_block_size and alpha != 1, earlier blocks' contributions
// therefore appear to be rescaled by alpha. With two blocks the result is
// alpha^2 * s1 + alpha * s2 rather than alpha * (s1 + s2).
// If the vector paths above also seed their accumulators from C (those lines
// are not visible), they share this issue. Confirm against the full source.
444 for (; i < i_end; ++i) {
445 for (
size_t j = jj; j < j_end; ++j) {
446 auto r11 = c[i + j * M];
448 for (
size_t k = kk; k < k_end; ++k) {
449 r11 += a[i + k * M] * b[k * N + j];
452 c[i + j * M] = alpha * r11;
// gemm_cr_to_c (fragment): computes C = alpha * (A * B) with the default vector
// mode. The kernels' indexing shows the layouts:
//   A is column-major M x K, B is row-major K x N, C is column-major M x N.
// Work is delegated to the small kernel or to the cache-blocked large kernel.
// NOTE(review): the condition that selects the kernel is not visible in this
// listing, nor is any preparation of C before the large kernel. Presumably it
// compares a size against a threshold such as gemm_rr_small_threshold.
// The large kernel's scalar path accumulates into the existing contents of C,
// so C presumably must be zeroed first (e.g. via direct_fill_n) -- confirm.
472 template <
typename T>
473 void gemm_cr_to_c(
const T* a,
const T* b, T* c,
size_t M,
size_t N,
size_t K, T alpha) {
// This implementation is only meaningful when vectorization is available and allowed.
474 cpp_assert(
vec_enabled,
"At least one vector mode must be enabled for impl::VEC");
475 cpp_assert(
vectorize_impl,
"vectorize_impl must be enabled for impl::VEC");
// Small-problem path: zero-initialised accumulators, C is overwritten.
478 gemm_small_kernel_cr_to_c<default_vec>(a, b, c, M, N, K, alpha);
// Large-problem path: blocked kernel that revisits each C tile once per k-block.
481 gemm_large_kernel_cr_to_c<default_vec>(a, b, c, M, N, K, alpha);
void gemm_large_kernel_cr_to_c(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Optimized version of GEMM for assignment of a large Column-Major Matrix - Row-Major Matrix product to a Column-Major Matrix.
Definition: gemm_cr_to_c.hpp:233
Definition: bias_add.hpp:15
constexpr size_t gemm_rr_small_threshold
The number of elements of B after which we use BLAS-like kernel (for GEMM)
Definition: threshold.hpp:55
constexpr bool vectorize_impl
Indicates if the implementations can be automatically vectorized by ETL.
Definition: config.hpp:35
constexpr bool vec_enabled
Indicates if vectorization is available in any format.
Definition: config.hpp:220
typename V::template vec_type< value_type > vec_type
The vectorization type for V.
Definition: dyn_matrix_view.hpp:43
void gemm_cr_to_c(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Vectorized implementation of column-major matrix - row-major matrix multiplication and assignment into a column-major matrix.
Definition: gemm_cr_to_c.hpp:473
void storeu(vec_type< V > in, size_t i) noexcept
Store several elements in the matrix at once.
Definition: dyn_matrix_view.hpp:187
auto loadu(size_t x) const noexcept
Load several elements of the expression at once.
Definition: dyn_matrix_view.hpp:154
auto min(L &&lhs, R &&rhs)
Create an expression with the min value of lhs or rhs.
Definition: expression_builder.hpp:77
void gemm_small_kernel_cr_to_c(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Optimized version of GEMM for assignment of a small Column-Major Matrix - Row-Major Matrix product to a Column-Major Matrix.
Definition: gemm_cr_to_c.hpp:27
void direct_fill_n(S *first, size_t n, T value)
Fills the given memory with the given value.
Definition: memory.hpp:57