// Small GEMM kernel for column-major operands (gemm_small_kernel_cc_to_c).
//
// Operands are addressed as a[i + k * M], b[k + j * K] and c[i + j * M], so a
// is M x K, b is K x N and c is M x N, each stored column by column. Every
// output is a sum over k of a * b that is multiplied by alpha when written.
//
// Rows are consumed in tiles of 4, 2 and 1 SIMD vectors, then one scalar row
// at a time; inside each row tile, columns are handled two at a time followed
// by a single-column remainder.
//
// NOTE(review): this listing omits some original lines (the function header,
// the declarations of i and j, the loads of the a-vectors, a few stores and
// the closing braces); remarks about those parts are assumptions to confirm.
24 template <
typename V,
typename T>
// Step, in elements of T, used to advance i by one vector in the loops below.
28 static constexpr
size_t vec_size = vec_type::template traits<T>::size;
// alpha in vector form, used to scale each accumulator via vec_type::mul.
30 auto alpha_vec = vec_type::set(alpha);
// ---- Row tile: 4 vectors (4 * vec_size rows) ----
34 for (; i + 4 * vec_size - 1 < M; i += 4 * vec_size) {
// Two columns of c per iteration: eight accumulators, named r<row vector><column>.
37 for (; (j + 2UL) <= N; j += 2UL) {
38 auto r11 = vec_type::template zero<T>();
39 auto r12 = vec_type::template zero<T>();
41 auto r21 = vec_type::template zero<T>();
42 auto r22 = vec_type::template zero<T>();
44 auto r31 = vec_type::template zero<T>();
45 auto r32 = vec_type::template zero<T>();
47 auto r41 = vec_type::template zero<T>();
48 auto r42 = vec_type::template zero<T>();
// Full reduction over K. a11..a41 are presumably the four row vectors of
// column k of a (their loads are not shown here) - TODO confirm.
50 for (
size_t k = 0; k < K; ++k) {
// One scalar of b per output column, in vector form.
56 auto b1 = vec_type::set(b[k + (j + 0) * K]);
57 auto b2 = vec_type::set(b[k + (j + 1) * K]);
59 r11 = vec_type::fmadd(a11, b1, r11);
60 r12 = vec_type::fmadd(a11, b2, r12);
62 r21 = vec_type::fmadd(a21, b1, r21);
63 r22 = vec_type::fmadd(a21, b2, r22);
65 r31 = vec_type::fmadd(a31, b1, r31);
66 r32 = vec_type::fmadd(a31, b2, r32);
68 r41 = vec_type::fmadd(a41, b1, r41);
69 r42 = vec_type::fmadd(a41, b2, r42);
// Scale by alpha and write the 4-vector x 2-column tile into c.
72 vec_type::storeu(c + (i + vec_size * 0) + (j + 0) * M, vec_type::mul(alpha_vec, r11));
73 vec_type::storeu(c + (i + vec_size * 0) + (j + 1) * M, vec_type::mul(alpha_vec, r12));
75 vec_type::storeu(c + (i + vec_size * 1) + (j + 0) * M, vec_type::mul(alpha_vec, r21));
76 vec_type::storeu(c + (i + vec_size * 1) + (j + 1) * M, vec_type::mul(alpha_vec, r22));
78 vec_type::storeu(c + (i + vec_size * 2) + (j + 0) * M, vec_type::mul(alpha_vec, r31));
79 vec_type::storeu(c + (i + vec_size * 2) + (j + 1) * M, vec_type::mul(alpha_vec, r32));
81 vec_type::storeu(c + (i + vec_size * 3) + (j + 0) * M, vec_type::mul(alpha_vec, r41));
82 vec_type::storeu(c + (i + vec_size * 3) + (j + 1) * M, vec_type::mul(alpha_vec, r42));
// Single-column remainder for the 4-vector tile (the enclosing condition or
// loop header is not shown; presumably the last column when N is odd).
86 auto r11 = vec_type::template zero<T>();
87 auto r21 = vec_type::template zero<T>();
88 auto r31 = vec_type::template zero<T>();
89 auto r41 = vec_type::template zero<T>();
91 for (
size_t k = 0; k < K; ++k) {
97 auto b1 = vec_type::set(b[k + j * K]);
99 r11 = vec_type::fmadd(a11, b1, r11);
100 r21 = vec_type::fmadd(a21, b1, r21);
101 r31 = vec_type::fmadd(a31, b1, r31);
102 r41 = vec_type::fmadd(a41, b1, r41);
105 vec_type::storeu(c + i + vec_size * 0 + j * M, vec_type::mul(alpha_vec, r11));
106 vec_type::storeu(c + i + vec_size * 1 + j * M, vec_type::mul(alpha_vec, r21));
107 vec_type::storeu(c + i + vec_size * 2 + j * M, vec_type::mul(alpha_vec, r31));
108 vec_type::storeu(c + i + vec_size * 3 + j * M, vec_type::mul(alpha_vec, r41));
// ---- Row tile: 2 vectors (2 * vec_size rows) ----
112 for (; i + 2 * vec_size - 1 < M; i += 2 * vec_size) {
// Two columns at a time: four accumulators.
115 for (; (j + 2UL) <= N; j += 2UL) {
116 auto r11 = vec_type::template zero<T>();
117 auto r12 = vec_type::template zero<T>();
119 auto r21 = vec_type::template zero<T>();
120 auto r22 = vec_type::template zero<T>();
122 for (
size_t k = 0; k < K; ++k) {
126 auto b1 = vec_type::set(b[k + (j + 0) * K]);
127 auto b2 = vec_type::set(b[k + (j + 1) * K]);
129 r11 = vec_type::fmadd(a11, b1, r11);
130 r12 = vec_type::fmadd(a11, b2, r12);
132 r21 = vec_type::fmadd(a21, b1, r21);
133 r22 = vec_type::fmadd(a21, b2, r22);
136 vec_type::storeu(c + (i + vec_size * 0) + (j + 0) * M, vec_type::mul(alpha_vec, r11));
137 vec_type::storeu(c + (i + vec_size * 0) + (j + 1) * M, vec_type::mul(alpha_vec, r12));
139 vec_type::storeu(c + (i + vec_size * 1) + (j + 0) * M, vec_type::mul(alpha_vec, r21));
140 vec_type::storeu(c + (i + vec_size * 1) + (j + 1) * M, vec_type::mul(alpha_vec, r22));
// Single-column remainder for the 2-vector tile (header not shown).
144 auto r11 = vec_type::template zero<T>();
145 auto r21 = vec_type::template zero<T>();
147 for (
size_t k = 0; k < K; ++k) {
151 auto b1 = vec_type::set(b[k + j * K]);
153 r11 = vec_type::fmadd(a11, b1, r11);
154 r21 = vec_type::fmadd(a21, b1, r21);
157 vec_type::storeu(c + i + vec_size * 0 + j * M, vec_type::mul(alpha_vec, r11));
158 vec_type::storeu(c + i + vec_size * 1 + j * M, vec_type::mul(alpha_vec, r21));
// ---- Row tile: 1 vector (vec_size rows) ----
162 for (; i + vec_size - 1 < M; i += vec_size) {
// Two columns at a time; the stores of r1/r2 are not shown in this listing.
165 for (; (j + 2UL) <= N; j += 2UL) {
166 auto r1 = vec_type::template zero<T>();
167 auto r2 = vec_type::template zero<T>();
169 for (
size_t k = 0; k < K; ++k) {
172 auto b1 = vec_type::set(b[k + (j + 0) * K]);
173 auto b2 = vec_type::set(b[k + (j + 1) * K]);
175 r1 = vec_type::fmadd(a1, b1, r1);
176 r2 = vec_type::fmadd(a1, b2, r2);
// Single-column remainder for the 1-vector tile (header and store not shown).
184 auto r1 = vec_type::template zero<T>();
186 for (
size_t k = 0; k < K; ++k) {
189 auto b1 = vec_type::set(b[k + j * K]);
191 r1 = vec_type::fmadd(a1, b1, r1);
// ---- Scalar rows (presumably those left over after the vector tiles; the
// enclosing loop over i is not shown) ----
// Two columns per iteration; value1/value2 are declared on lines not shown.
201 for (; j + 1 < N; j += 2) {
205 for (
size_t k = 0; k < K; ++k) {
206 value1 += a[i + k * M] * b[k + (j + 0) * K];
207 value2 += a[i + k * M] * b[k + (j + 1) * K];
210 c[i + (j + 0) * M] = alpha * value1;
211 c[i + (j + 1) * M] = alpha * value2;
// Single-column scalar remainder (header and declaration of value not shown).
217 for (
size_t k = 0; k < K; ++k) {
218 value += a[i + k * M] * b[k + j * K];
221 c[i + j * M] = alpha * value;
// Large GEMM kernel for column-major operands (gemm_large_kernel_cc_to_c).
//
// Same addressing as the scalar path below shows: a[i + k * M], b[k + j * K]
// and c[i + j * M]. The iteration space is cut into blocks of at most
// m_block_size rows, n_block_size columns and k_block_size reduction steps;
// inside a block, rows are processed in tiles of 4, 2 and 1 SIMD vectors and
// finally one scalar row at a time.
//
// NOTE(review): this listing omits some original lines (the function header,
// the declarations of i and j inside a block, accumulator initialisation, the
// loads of the a-vectors, a few stores and the closing braces); remarks about
// those parts are assumptions to confirm.
//
// NOTE(review): the visible scalar path reads c, adds the partial sum of one
// k-block and writes back alpha * x, and the vector paths also store
// alpha_vec * r using k ranges bounded by block_k / k_end. If K spans more
// than one k-block and alpha != 1, alpha looks like it is applied once per
// k-block rather than once overall - confirm against the full source.
232 template <
typename V,
typename T>
// Step, in elements of T, used to advance i by one vector in the loops below.
236 static constexpr
size_t vec_size = vec_type::template traits<T>::size;
// Block extents along the M, N and K dimensions respectively.
238 constexpr
size_t m_block_size = 128;
239 constexpr
size_t n_block_size = 64;
240 constexpr
size_t k_block_size = 128;
// alpha in vector form, used to scale each accumulator via vec_type::mul.
242 auto alpha_vec = vec_type::set(alpha);
// Walk the row blocks; i_end clamps the last block to M.
244 for (
size_t block_i = 0; block_i < M; block_i += m_block_size) {
245 const size_t i_end =
std::min(block_i + m_block_size, M);
// Walk the column blocks; j_end clamps the last block to N.
247 for (
size_t block_j = 0; block_j < N; block_j += n_block_size) {
248 const size_t j_end =
std::min(block_j + n_block_size, N);
// Pass over every element of the current (block_i, block_j) tile of c. The
// loop body is not shown; presumably it clears c before accumulation - confirm.
251 for (
size_t j = block_j; j < j_end; ++j) {
252 for (
size_t i = block_i; i < i_end; ++i) {
// Walk the reduction dimension in chunks; k_end clamps the last chunk to K.
257 for (
size_t block_k = 0; block_k < K; block_k += k_block_size) {
258 const size_t k_end =
std::min(block_k + k_block_size, K);
// ---- Row tile: 4 vectors ----
263 for (; i + 4 * vec_size - 1 < i_end; i += 4 * vec_size) {
// Two columns at a time. Accumulators are named r<column><row vector> here
// (the first digit follows b1/b2, the second follows a1..a4); their
// initialisation is not shown.
266 for (; j + 1 < j_end; j += 2) {
277 for (
size_t k = block_k; k < k_end; ++k) {
283 auto b1 = vec_type::set(b[k + (j + 0) * K]);
284 auto b2 = vec_type::set(b[k + (j + 1) * K]);
286 r11 = vec_type::fmadd(a1, b1, r11);
287 r12 = vec_type::fmadd(a2, b1, r12);
288 r13 = vec_type::fmadd(a3, b1, r13);
289 r14 = vec_type::fmadd(a4, b1, r14);
291 r21 = vec_type::fmadd(a1, b2, r21);
292 r22 = vec_type::fmadd(a2, b2, r22);
293 r23 = vec_type::fmadd(a3, b2, r23);
294 r24 = vec_type::fmadd(a4, b2, r24);
// Scale by alpha and write column j, then column j + 1.
297 vec_type::storeu(c + i + 0 * vec_size + (j + 0) * M, vec_type::mul(alpha_vec, r11));
298 vec_type::storeu(c + i + 1 * vec_size + (j + 0) * M, vec_type::mul(alpha_vec, r12));
299 vec_type::storeu(c + i + 2 * vec_size + (j + 0) * M, vec_type::mul(alpha_vec, r13));
300 vec_type::storeu(c + i + 3 * vec_size + (j + 0) * M, vec_type::mul(alpha_vec, r14));
302 vec_type::storeu(c + i + 0 * vec_size + (j + 1) * M, vec_type::mul(alpha_vec, r21));
303 vec_type::storeu(c + i + 1 * vec_size + (j + 1) * M, vec_type::mul(alpha_vec, r22));
304 vec_type::storeu(c + i + 2 * vec_size + (j + 1) * M, vec_type::mul(alpha_vec, r23));
305 vec_type::storeu(c + i + 3 * vec_size + (j + 1) * M, vec_type::mul(alpha_vec, r24));
// Remaining columns of the block, one at a time, for the 4-vector tile.
308 for (; j < j_end; ++j) {
314 for (
size_t k = block_k; k < k_end; ++k) {
320 auto b1 = vec_type::set(b[k + j * K]);
322 r1 = vec_type::fmadd(a1, b1, r1);
323 r2 = vec_type::fmadd(a2, b1, r2);
324 r3 = vec_type::fmadd(a3, b1, r3);
325 r4 = vec_type::fmadd(a4, b1, r4);
328 vec_type::storeu(c + i + 0 * vec_size + j * M, vec_type::mul(alpha_vec, r1));
329 vec_type::storeu(c + i + 1 * vec_size + j * M, vec_type::mul(alpha_vec, r2));
330 vec_type::storeu(c + i + 2 * vec_size + j * M, vec_type::mul(alpha_vec, r3));
331 vec_type::storeu(c + i + 3 * vec_size + j * M, vec_type::mul(alpha_vec, r4));
// ---- Row tile: 2 vectors ----
336 for (; i + 2 * vec_size - 1 < i_end; i += 2 * vec_size) {
// Four columns at a time: eight accumulators r<column><row vector>.
339 for (; j + 3 < j_end; j += 4) {
352 for (
size_t k = block_k; k < k_end; ++k) {
356 auto b1 = vec_type::set(b[k + (j + 0) * K]);
357 auto b2 = vec_type::set(b[k + (j + 1) * K]);
358 auto b3 = vec_type::set(b[k + (j + 2) * K]);
359 auto b4 = vec_type::set(b[k + (j + 3) * K]);
361 r11 = vec_type::fmadd(a1, b1, r11);
362 r12 = vec_type::fmadd(a2, b1, r12);
364 r21 = vec_type::fmadd(a1, b2, r21);
365 r22 = vec_type::fmadd(a2, b2, r22);
367 r31 = vec_type::fmadd(a1, b3, r31);
368 r32 = vec_type::fmadd(a2, b3, r32);
370 r41 = vec_type::fmadd(a1, b4, r41);
371 r42 = vec_type::fmadd(a2, b4, r42);
374 vec_type::storeu(c + i + 0 * vec_size + (j + 0) * M, vec_type::mul(alpha_vec, r11));
375 vec_type::storeu(c + i + 1 * vec_size + (j + 0) * M, vec_type::mul(alpha_vec, r12));
377 vec_type::storeu(c + i + 0 * vec_size + (j + 1) * M, vec_type::mul(alpha_vec, r21));
378 vec_type::storeu(c + i + 1 * vec_size + (j + 1) * M, vec_type::mul(alpha_vec, r22));
380 vec_type::storeu(c + i + 0 * vec_size + (j + 2) * M, vec_type::mul(alpha_vec, r31));
381 vec_type::storeu(c + i + 1 * vec_size + (j + 2) * M, vec_type::mul(alpha_vec, r32));
383 vec_type::storeu(c + i + 0 * vec_size + (j + 3) * M, vec_type::mul(alpha_vec, r41));
384 vec_type::storeu(c + i + 1 * vec_size + (j + 3) * M, vec_type::mul(alpha_vec, r42));
// Then two columns at a time.
387 for (; j + 1 < j_end; j += 2) {
394 for (
size_t k = block_k; k < k_end; ++k) {
398 auto b1 = vec_type::set(b[k + (j + 0) * K]);
399 auto b2 = vec_type::set(b[k + (j + 1) * K]);
401 r11 = vec_type::fmadd(a1, b1, r11);
402 r12 = vec_type::fmadd(a2, b1, r12);
404 r21 = vec_type::fmadd(a1, b2, r21);
405 r22 = vec_type::fmadd(a2, b2, r22);
408 vec_type::storeu(c + i + 0 * vec_size + (j + 0) * M, vec_type::mul(alpha_vec, r11));
409 vec_type::storeu(c + i + 1 * vec_size + (j + 0) * M, vec_type::mul(alpha_vec, r12));
411 vec_type::storeu(c + i + 0 * vec_size + (j + 1) * M, vec_type::mul(alpha_vec, r21));
412 vec_type::storeu(c + i + 1 * vec_size + (j + 1) * M, vec_type::mul(alpha_vec, r22));
// Then any remaining column of the block.
415 for (; j < j_end; ++j) {
419 for (
size_t k = block_k; k < k_end; ++k) {
423 auto b1 = vec_type::set(b[k + j * K]);
425 r1 = vec_type::fmadd(a1, b1, r1);
426 r2 = vec_type::fmadd(a2, b1, r2);
429 vec_type::storeu(c + i + 0 * vec_size + j * M, vec_type::mul(alpha_vec, r1));
430 vec_type::storeu(c + i + 1 * vec_size + j * M, vec_type::mul(alpha_vec, r2));
// ---- Row tile: 1 vector, one column at a time (store not shown) ----
435 for (; i + vec_size - 1 < i_end; i += vec_size) {
436 for (
size_t j = block_j; j < j_end; ++j) {
439 for (
size_t k = block_k; k < k_end; ++k) {
441 auto b1 = vec_type::set(b[k + j * K]);
443 r1 = vec_type::fmadd(a1, b1, r1);
// ---- Scalar rows left in the block ----
451 for (; i < i_end; ++i) {
452 for (
size_t j = block_j; j < j_end; ++j) {
// Start from the value currently stored in c.
453 auto x = c[i + j * M];
455 for (
size_t k = block_k; k < k_end; ++k) {
456 x += a[i + k * M] * b[k + j * K];
// The value read from c above is scaled by alpha again here, together with
// the new partial sum (see the NOTE(review) at the top of this kernel).
459 c[i + j * M] = alpha * x;
/*!
 * \brief Entry point of the vectorized column-major x column-major GEMM:
 * forwards its arguments unchanged to one of the two kernels of this file.
 *
 * The kernels in this file address the operands as a[i + k * M],
 * b[k + j * K] and c[i + j * M].
 *
 * \param a Read-only pointer to the left operand (used as M x K by the kernels)
 * \param b Read-only pointer to the right operand (used as K x N by the kernels)
 * \param c Pointer to the output (used as M x N by the kernels)
 * \param M Row count of a and c
 * \param N Column count of b and c
 * \param K Column count of a and row count of b
 * \param alpha Scalar factor applied by the kernels to the product
 */
479 template <
typename T>
480 void gemm_cc_to_c(
const T* a,
const T* b, T* c,
size_t M,
size_t N,
size_t K, T alpha) {
// Both configuration flags are asserted before any kernel is called.
481 cpp_assert(
vec_enabled,
"At least one vector mode must be enabled for impl::VEC");
482 cpp_assert(
vectorize_impl,
"vectorize_impl must be enabled for impl::VEC");
// NOTE(review): the branch choosing between the two calls below is not part of
// this listing; presumably it compares a size against gemm_cc_small_threshold
// - confirm against the full source.
487 gemm_small_kernel_cc_to_c<default_vec>(a, b, c, M, N, K, alpha);
489 gemm_large_kernel_cc_to_c<default_vec>(a, b, c, M, N, K, alpha);
Definition: bias_add.hpp:15
constexpr bool vectorize_impl
Indicates if the implementations can be automatically vectorized by ETL.
Definition: config.hpp:35
constexpr bool vec_enabled
Indicates if vectorization is available in any format.
Definition: config.hpp:220
typename V::template vec_type< value_type > vec_type
The vectorization type for V.
Definition: dyn_matrix_view.hpp:43
void storeu(vec_type< V > in, size_t i) noexcept
Store several elements in the matrix at once.
Definition: dyn_matrix_view.hpp:187
auto loadu(size_t x) const noexcept
Load several elements of the expression at once.
Definition: dyn_matrix_view.hpp:154
auto min(L &&lhs, R &&rhs)
Create an expression with the min value of lhs or rhs.
Definition: expression_builder.hpp:77
constexpr size_t gemm_cc_small_threshold
The number of elements of B after which we use a BLAS-like kernel (for GEMM)
Definition: threshold.hpp:58
void gemm_cc_to_c(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Vectorized implementation of column-major matrix - column-major matrix multiplication and assignment ...
Definition: gemm_cc_to_c.hpp:480
void gemm_small_kernel_cc_to_c(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Optimized version of small GEMM for column-major matrices.
Definition: gemm_cc_to_c.hpp:25
void gemm_large_kernel_cc_to_c(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Optimized version of large GEMM for column-major matrices.
Definition: gemm_cc_to_c.hpp:233