// Small GEMM kernel for row-major operands:
//   C (M x N) = alpha * A (M x K) * B (K x N), addressed as a[i * K + k], b[k * N + j], c[i * N + j].
// Every visible store writes alpha * accumulator into C (no beta term), so C is overwritten.
// Columns are processed in vector panels of decreasing width (8, 5, 4, 3, 2, 1 vectors),
// then the remaining columns [j_end, N) with scalar code.
// NOTE(review): this listing is a lossy extract: the function signature, the `j` / `k`
// initialisations, the b1..b5 loads, some stores and all closing braces are not visible.
22 #ifndef ETL_GEMM_SMALL_RR_R_UNROLL_8 24 #define ETL_GEMM_SMALL_RR_R_UNROLL_8 34 template <
typename V,
typename T>
38 static constexpr
size_t vec_size = vec_type::template traits<T>::size;
// Vectorised columns stop at j_end; presumably prev_multiple rounds N down to a multiple
// of vec_size (its definition is not visible here).
40 const auto j_end = prev_multiple(N, vec_size);
// alpha is broadcast once and applied only when results are stored.
44 auto alpha_vec = vec_type::set(alpha);
// ETL_GEMM_SMALL_RR_R_UNROLL_8 is defined at the top of this block whenever it was not
// already defined, so this 8-vector path is always compiled in.
// 8-vector panel: one row of A at a time, eight accumulators r1..r8.
49 #ifdef ETL_GEMM_SMALL_RR_R_UNROLL_8 51 for (; j + vec_size * 7 < j_end; j += vec_size * 8) {
52 for (
size_t i = 0; i < M; ++i) {
55 auto a1 = vec_type::set(a[i * K + k]);
// The first k initialises the accumulators with mul; later k values use fused multiply-add.
57 auto r1 = vec_type::mul(a1,
vec_type::loadu(b + k * N + j + vec_size * 0));
58 auto r2 = vec_type::mul(a1,
vec_type::loadu(b + k * N + j + vec_size * 1));
59 auto r3 = vec_type::mul(a1,
vec_type::loadu(b + k * N + j + vec_size * 2));
60 auto r4 = vec_type::mul(a1,
vec_type::loadu(b + k * N + j + vec_size * 3));
61 auto r5 = vec_type::mul(a1,
vec_type::loadu(b + k * N + j + vec_size * 4));
62 auto r6 = vec_type::mul(a1,
vec_type::loadu(b + k * N + j + vec_size * 5));
63 auto r7 = vec_type::mul(a1,
vec_type::loadu(b + k * N + j + vec_size * 6));
64 auto r8 = vec_type::mul(a1,
vec_type::loadu(b + k * N + j + vec_size * 7));
66 for (++k; k < K; ++k) {
67 a1 = vec_type::set(a[i * K + k]);
69 r1 = vec_type::fmadd(a1,
vec_type::loadu(b + k * N + j + vec_size * 0), r1);
70 r2 = vec_type::fmadd(a1,
vec_type::loadu(b + k * N + j + vec_size * 1), r2);
71 r3 = vec_type::fmadd(a1,
vec_type::loadu(b + k * N + j + vec_size * 2), r3);
72 r4 = vec_type::fmadd(a1,
vec_type::loadu(b + k * N + j + vec_size * 3), r4);
73 r5 = vec_type::fmadd(a1,
vec_type::loadu(b + k * N + j + vec_size * 4), r5);
74 r6 = vec_type::fmadd(a1,
vec_type::loadu(b + k * N + j + vec_size * 5), r6);
75 r7 = vec_type::fmadd(a1,
vec_type::loadu(b + k * N + j + vec_size * 6), r7);
76 r8 = vec_type::fmadd(a1,
vec_type::loadu(b + k * N + j + vec_size * 7), r8);
// Scale by alpha once, at store time, rather than on every k step.
79 vec_type::storeu(c + i * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r1));
80 vec_type::storeu(c + i * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r2));
81 vec_type::storeu(c + i * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r3));
82 vec_type::storeu(c + i * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r4));
83 vec_type::storeu(c + i * N + j + 4 * vec_size, vec_type::mul(alpha_vec, r5));
84 vec_type::storeu(c + i * N + j + 5 * vec_size, vec_type::mul(alpha_vec, r6));
85 vec_type::storeu(c + i * N + j + 6 * vec_size, vec_type::mul(alpha_vec, r7));
86 vec_type::storeu(c + i * N + j + 7 * vec_size, vec_type::mul(alpha_vec, r8));
// 5-vector panel: two rows of A per iteration, so each B vector (b1..b5) is reused by a1 and a2.
93 for (; j + vec_size * 4 < j_end; j += 5 * vec_size) {
96 for (; i + 1 < M; i += 2) {
99 auto a1 = vec_type::set(a[(i + 0) * K + k]);
100 auto a2 = vec_type::set(a[(i + 1) * K + k]);
108 auto r11 = vec_type::mul(a1, b1);
109 auto r12 = vec_type::mul(a2, b1);
111 auto r21 = vec_type::mul(a1, b2);
112 auto r22 = vec_type::mul(a2, b2);
114 auto r31 = vec_type::mul(a1, b3);
115 auto r32 = vec_type::mul(a2, b3);
117 auto r41 = vec_type::mul(a1, b4);
118 auto r42 = vec_type::mul(a2, b4);
120 auto r51 = vec_type::mul(a1, b5);
121 auto r52 = vec_type::mul(a2, b5);
123 for (++k; k < K; ++k) {
124 a1 = vec_type::set(a[(i + 0) * K + k]);
125 a2 = vec_type::set(a[(i + 1) * K + k]);
133 r11 = vec_type::fmadd(a1, b1, r11);
134 r12 = vec_type::fmadd(a2, b1, r12);
136 r21 = vec_type::fmadd(a1, b2, r21);
137 r22 = vec_type::fmadd(a2, b2, r22);
139 r31 = vec_type::fmadd(a1, b3, r31);
140 r32 = vec_type::fmadd(a2, b3, r32);
142 r41 = vec_type::fmadd(a1, b4, r41);
143 r42 = vec_type::fmadd(a2, b4, r42);
145 r51 = vec_type::fmadd(a1, b5, r51);
146 r52 = vec_type::fmadd(a2, b5, r52);
149 vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
150 vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r21));
151 vec_type::storeu(c + (i + 0) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r31));
152 vec_type::storeu(c + (i + 0) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r41));
153 vec_type::storeu(c + (i + 0) * N + j + 4 * vec_size, vec_type::mul(alpha_vec, r51));
155 vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r12));
156 vec_type::storeu(c + (i + 1) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r22));
157 vec_type::storeu(c + (i + 1) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r32));
158 vec_type::storeu(c + (i + 1) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r42));
159 vec_type::storeu(c + (i + 1) * N + j + 4 * vec_size, vec_type::mul(alpha_vec, r52));
// Presumably the single leftover row when M is odd (its loop header is not visible).
165 auto a1 = vec_type::set(a[(i + 0) * K + k]);
167 auto r11 = vec_type::mul(a1,
vec_type::loadu(b + k * N + j + vec_size * 0));
168 auto r21 = vec_type::mul(a1,
vec_type::loadu(b + k * N + j + vec_size * 1));
169 auto r31 = vec_type::mul(a1,
vec_type::loadu(b + k * N + j + vec_size * 2));
170 auto r41 = vec_type::mul(a1,
vec_type::loadu(b + k * N + j + vec_size * 3));
171 auto r51 = vec_type::mul(a1,
vec_type::loadu(b + k * N + j + vec_size * 4));
173 for (++k; k < K; ++k) {
174 a1 = vec_type::set(a[(i + 0) * K + k]);
176 r11 = vec_type::fmadd(a1,
vec_type::loadu(b + k * N + j + vec_size * 0), r11);
177 r21 = vec_type::fmadd(a1,
vec_type::loadu(b + k * N + j + vec_size * 1), r21);
178 r31 = vec_type::fmadd(a1,
vec_type::loadu(b + k * N + j + vec_size * 2), r31);
179 r41 = vec_type::fmadd(a1,
vec_type::loadu(b + k * N + j + vec_size * 3), r41);
180 r51 = vec_type::fmadd(a1,
vec_type::loadu(b + k * N + j + vec_size * 4), r51);
183 vec_type::storeu(c + i * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
184 vec_type::storeu(c + i * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r21));
185 vec_type::storeu(c + i * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r31));
186 vec_type::storeu(c + i * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r41));
187 vec_type::storeu(c + i * N + j + 4 * vec_size, vec_type::mul(alpha_vec, r51));
// 4-vector panel: two rows at a time, then a single leftover row.
192 for (; j + vec_size * 3 < j_end; j += 4 * vec_size) {
195 for (; i + 1 < M; i += 2) {
198 auto a1 = vec_type::set(a[(i + 0) * K + k]);
199 auto a2 = vec_type::set(a[(i + 1) * K + k]);
206 auto r11 = vec_type::mul(a1, b1);
207 auto r12 = vec_type::mul(a2, b1);
209 auto r21 = vec_type::mul(a1, b2);
210 auto r22 = vec_type::mul(a2, b2);
212 auto r31 = vec_type::mul(a1, b3);
213 auto r32 = vec_type::mul(a2, b3);
215 auto r41 = vec_type::mul(a1, b4);
216 auto r42 = vec_type::mul(a2, b4);
218 for (++k; k < K; ++k) {
219 a1 = vec_type::set(a[(i + 0) * K + k]);
220 a2 = vec_type::set(a[(i + 1) * K + k]);
227 r11 = vec_type::fmadd(a1, b1, r11);
228 r12 = vec_type::fmadd(a2, b1, r12);
230 r21 = vec_type::fmadd(a1, b2, r21);
231 r22 = vec_type::fmadd(a2, b2, r22);
233 r31 = vec_type::fmadd(a1, b3, r31);
234 r32 = vec_type::fmadd(a2, b3, r32);
236 r41 = vec_type::fmadd(a1, b4, r41);
237 r42 = vec_type::fmadd(a2, b4, r42);
240 vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
241 vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r21));
242 vec_type::storeu(c + (i + 0) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r31));
243 vec_type::storeu(c + (i + 0) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r41));
245 vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r12));
246 vec_type::storeu(c + (i + 1) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r22));
247 vec_type::storeu(c + (i + 1) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r32));
248 vec_type::storeu(c + (i + 1) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r42));
254 auto a1 = vec_type::set(a[(i + 0) * K + k]);
256 auto r11 = vec_type::mul(a1,
vec_type::loadu(b + k * N + j + vec_size * 0));
257 auto r21 = vec_type::mul(a1,
vec_type::loadu(b + k * N + j + vec_size * 1));
258 auto r31 = vec_type::mul(a1,
vec_type::loadu(b + k * N + j + vec_size * 2));
259 auto r41 = vec_type::mul(a1,
vec_type::loadu(b + k * N + j + vec_size * 3));
261 for (++k; k < K; ++k) {
262 a1 = vec_type::set(a[(i + 0) * K + k]);
264 r11 = vec_type::fmadd(a1,
vec_type::loadu(b + k * N + j + vec_size * 0), r11);
265 r21 = vec_type::fmadd(a1,
vec_type::loadu(b + k * N + j + vec_size * 1), r21);
266 r31 = vec_type::fmadd(a1,
vec_type::loadu(b + k * N + j + vec_size * 2), r31);
267 r41 = vec_type::fmadd(a1,
vec_type::loadu(b + k * N + j + vec_size * 3), r41);
270 vec_type::storeu(c + i * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
271 vec_type::storeu(c + i * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r21));
272 vec_type::storeu(c + i * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r31));
273 vec_type::storeu(c + i * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r41));
// 3-vector panel: two rows at a time, then a single leftover row.
278 for (; j + vec_size * 2 < j_end; j += 3 * vec_size) {
281 for (; i + 1 < M; i += 2) {
284 auto a1 = vec_type::set(a[(i + 0) * K + k]);
285 auto a2 = vec_type::set(a[(i + 1) * K + k]);
291 auto r11 = vec_type::mul(a1, b1);
292 auto r12 = vec_type::mul(a2, b1);
294 auto r21 = vec_type::mul(a1, b2);
295 auto r22 = vec_type::mul(a2, b2);
297 auto r31 = vec_type::mul(a1, b3);
298 auto r32 = vec_type::mul(a2, b3);
300 for (++k; k < K; ++k) {
301 a1 = vec_type::set(a[(i + 0) * K + k]);
302 a2 = vec_type::set(a[(i + 1) * K + k]);
308 r11 = vec_type::fmadd(a1, b1, r11);
309 r12 = vec_type::fmadd(a2, b1, r12);
311 r21 = vec_type::fmadd(a1, b2, r21);
312 r22 = vec_type::fmadd(a2, b2, r22);
314 r31 = vec_type::fmadd(a1, b3, r31);
315 r32 = vec_type::fmadd(a2, b3, r32);
318 vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
319 vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r21));
320 vec_type::storeu(c + (i + 0) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r31));
322 vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r12));
323 vec_type::storeu(c + (i + 1) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r22));
324 vec_type::storeu(c + (i + 1) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r32));
330 auto a1 = vec_type::set(a[(i + 0) * K + k]);
332 auto r11 = vec_type::mul(a1,
vec_type::loadu(b + k * N + j + vec_size * 0));
333 auto r21 = vec_type::mul(a1,
vec_type::loadu(b + k * N + j + vec_size * 1));
334 auto r31 = vec_type::mul(a1,
vec_type::loadu(b + k * N + j + vec_size * 2));
336 for (++k; k < K; ++k) {
337 a1 = vec_type::set(a[(i + 0) * K + k]);
339 r11 = vec_type::fmadd(a1,
vec_type::loadu(b + k * N + j + vec_size * 0), r11);
340 r21 = vec_type::fmadd(a1,
vec_type::loadu(b + k * N + j + vec_size * 1), r21);
341 r31 = vec_type::fmadd(a1,
vec_type::loadu(b + k * N + j + vec_size * 2), r31);
344 vec_type::storeu(c + i * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
345 vec_type::storeu(c + i * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r21));
346 vec_type::storeu(c + i * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r31));
// 2-vector panel: four rows at a time, then two rows, then a single leftover row.
351 for (; j + vec_size < j_end; j += 2 * vec_size) {
354 for (; i + 3 < M; i += 4) {
357 auto a1 = vec_type::set(a[(i + 0) * K + k]);
358 auto a2 = vec_type::set(a[(i + 1) * K + k]);
359 auto a3 = vec_type::set(a[(i + 2) * K + k]);
360 auto a4 = vec_type::set(a[(i + 3) * K + k]);
365 auto r11 = vec_type::mul(a1, b1);
366 auto r12 = vec_type::mul(a2, b1);
367 auto r13 = vec_type::mul(a3, b1);
368 auto r14 = vec_type::mul(a4, b1);
370 auto r21 = vec_type::mul(a1, b2);
371 auto r22 = vec_type::mul(a2, b2);
372 auto r23 = vec_type::mul(a3, b2);
373 auto r24 = vec_type::mul(a4, b2);
375 for (++k; k < K; ++k) {
376 a1 = vec_type::set(a[(i + 0) * K + k]);
377 a2 = vec_type::set(a[(i + 1) * K + k]);
378 a3 = vec_type::set(a[(i + 2) * K + k]);
379 a4 = vec_type::set(a[(i + 3) * K + k]);
384 r11 = vec_type::fmadd(a1, b1, r11);
385 r12 = vec_type::fmadd(a2, b1, r12);
386 r13 = vec_type::fmadd(a3, b1, r13);
387 r14 = vec_type::fmadd(a4, b1, r14);
389 r21 = vec_type::fmadd(a1, b2, r21);
390 r22 = vec_type::fmadd(a2, b2, r22);
391 r23 = vec_type::fmadd(a3, b2, r23);
392 r24 = vec_type::fmadd(a4, b2, r24);
395 vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
396 vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r21));
398 vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r12));
399 vec_type::storeu(c + (i + 1) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r22));
401 vec_type::storeu(c + (i + 2) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r13));
402 vec_type::storeu(c + (i + 2) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r23));
404 vec_type::storeu(c + (i + 3) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r14));
405 vec_type::storeu(c + (i + 3) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r24));
// Two remaining rows.
408 for (; i + 1 < M; i += 2) {
411 auto a1 = vec_type::set(a[(i + 0) * K + k]);
412 auto a2 = vec_type::set(a[(i + 1) * K + k]);
417 auto r11 = vec_type::mul(a1, b1);
418 auto r12 = vec_type::mul(a2, b1);
420 auto r21 = vec_type::mul(a1, b2);
421 auto r22 = vec_type::mul(a2, b2);
423 for (++k; k < K; ++k) {
424 a1 = vec_type::set(a[(i + 0) * K + k]);
425 a2 = vec_type::set(a[(i + 1) * K + k]);
430 r11 = vec_type::fmadd(a1, b1, r11);
431 r12 = vec_type::fmadd(a2, b1, r12);
433 r21 = vec_type::fmadd(a1, b2, r21);
434 r22 = vec_type::fmadd(a2, b2, r22);
437 vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
438 vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r21));
440 vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r12));
441 vec_type::storeu(c + (i + 1) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r22));
// Presumably the single leftover row (its loop header is not visible).
450 auto a1 = vec_type::set(a[(i + 0) * K + k]);
452 auto r11 = vec_type::mul(a1, b1);
453 auto r21 = vec_type::mul(a1, b2);
455 for (++k; k < K; ++k) {
459 a1 = vec_type::set(a[(i + 0) * K + k]);
461 r11 = vec_type::fmadd(a1, b1, r11);
462 r21 = vec_type::fmadd(a1, b2, r21);
465 vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
466 vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r21));
// 1-vector panel up to j_end: two rows at a time, then a single leftover row.
471 for (; j < j_end; j += vec_size) {
474 for (; i + 1 < M; i += 2) {
477 auto a1 = vec_type::set(a[(i + 0) * K + k]);
478 auto a2 = vec_type::set(a[(i + 1) * K + k]);
482 auto r1 = vec_type::mul(a1, b1);
483 auto r2 = vec_type::mul(a2, b1);
485 for (++k; k < K; ++k) {
488 a1 = vec_type::set(a[(i + 0) * K + k]);
489 a2 = vec_type::set(a[(i + 1) * K + k]);
491 r1 = vec_type::fmadd(a1, b1, r1);
492 r2 = vec_type::fmadd(a2, b1, r2);
502 auto a1 = vec_type::set(a[(i + 0) * K + k]);
504 auto r1 = vec_type::mul(a1,
vec_type::loadu(b + k * N + j + vec_size * 0));
506 for (++k; k < K; ++k) {
507 a1 = vec_type::set(a[(i + 0) * K + k]);
509 r1 = vec_type::fmadd(a1,
vec_type::loadu(b + k * N + j + vec_size * 0), r1);
// Scalar tail for the columns the vector panels did not cover (j >= j_end): two columns
// (j1, j2) and two rows at a time, then a single leftover row.
517 for (; j + 1 < N; j += 2) {
518 const size_t j1 = j + 0;
519 const size_t j2 = j + 1;
523 for (; i + 1 < M; i += 2) {
526 auto r11 = a[(i + 0) * K + k] * b[k * N + j1];
527 auto r21 = a[(i + 0) * K + k] * b[k * N + j2];
528 auto r12 = a[(i + 1) * K + k] * b[k * N + j1];
529 auto r22 = a[(i + 1) * K + k] * b[k * N + j2];
531 for (++k; k < K; ++k) {
532 r11 += a[(i + 0) * K + k] * b[k * N + j1];
533 r21 += a[(i + 0) * K + k] * b[k * N + j2];
534 r12 += a[(i + 1) * K + k] * b[k * N + j1];
535 r22 += a[(i + 1) * K + k] * b[k * N + j2];
538 c[(i + 0) * N + j1] = alpha * r11;
539 c[(i + 0) * N + j2] = alpha * r21;
540 c[(i + 1) * N + j1] = alpha * r12;
541 c[(i + 1) * N + j2] = alpha * r22;
547 auto r1 = a[i * K + k] * b[k * N + j1];
548 auto r2 = a[i * K + k] * b[k * N + j2];
550 for (++k; k < K; ++k) {
551 r1 += a[i * K + k] * b[k * N + j1];
552 r2 += a[i * K + k] * b[k * N + j2];
555 c[i * N + j1] = alpha * r1;
556 c[i * N + j2] = alpha * r2;
// Presumably the very last column when the scalar tail has an odd width.
564 for (; i + 1 < M; i += 2) {
567 auto r1 = a[(i + 0) * K + k] * b[k * N + j];
568 auto r2 = a[(i + 1) * K + k] * b[k * N + j];
570 for (++k; k < K; ++k) {
571 r1 += a[(i + 0) * K + k] * b[k * N + j];
572 r2 += a[(i + 1) * K + k] * b[k * N + j];
575 c[(i + 0) * N + j] = alpha * r1;
576 c[(i + 1) * N + j] = alpha * r2;
582 auto r1 = a[i * K + k] * b[k * N + j];
584 for (++k; k < K; ++k) {
585 r1 += a[i * K + k] * b[k * N + j];
588 c[i * N + j] = alpha * r1;
// Tiled GEMM kernel for row-major operands, addressed as a[i * K + k], b[k * N + j], c[i * N + j].
// C is walked in tiles of n_block_size columns x m_block_size rows. Each tile is first
// scaled by beta (the beta == 0 branch body is not visible; presumably it zeroes the tile).
// A * B contributions are then accumulated over k blocks of k_block_size.
// NOTE(review): lossy extract: accumulator initialisation, B loads, vector stores and
// all closing braces are not visible in this listing.
600 template <
typename V,
typename T>
604 static constexpr
size_t vec_size = vec_type::template traits<T>::size;
606 const size_t n_block_size = 128;
607 const size_t m_block_size = 64;
608 const size_t k_block_size = 128;
610 auto alpha_vec = vec_type::set(alpha);
616 for (
size_t block_j = 0; block_j < N; block_j += n_block_size) {
617 const size_t j_end =
std::min(block_j + n_block_size, N);
619 for (
size_t block_i = 0; block_i < M; block_i += m_block_size) {
620 const size_t i_end =
std::min(block_i + m_block_size, M);
// Apply beta to the C tile once, before any k block is accumulated into it.
622 if (beta == T(0.0)) {
623 for (
size_t i = block_i; i < i_end; ++i) {
624 for (
size_t j = block_j; j < j_end; ++j) {
629 for (
size_t i = block_i; i < i_end; ++i) {
630 for (
size_t j = block_j; j < j_end; ++j) {
631 c[i * N + j] = beta * c[i * N + j];
636 for (
size_t block_k = 0; block_k < K; block_k += k_block_size) {
637 const size_t k_end =
std::min(block_k + k_block_size, K);
// 4-vector column panels (j, j1, j2, j3): two rows at a time, then a single leftover row.
641 for (; j + vec_size * 4 - 1 < j_end; j += vec_size * 4) {
642 const size_t j1 = j + vec_size * 1;
643 const size_t j2 = j + vec_size * 2;
644 const size_t j3 = j + vec_size * 3;
648 for (; i + 1 < i_end; i += 2) {
659 for (
size_t k = block_k; k < k_end; ++k) {
660 auto a1 = vec_type::set(a[(i + 0) * K + k]);
661 auto a2 = vec_type::set(a[(i + 1) * K + k]);
668 r11 = vec_type::fmadd(a1, b1, r11);
669 r12 = vec_type::fmadd(a1, b2, r12);
670 r13 = vec_type::fmadd(a1, b3, r13);
671 r14 = vec_type::fmadd(a1, b4, r14);
673 r21 = vec_type::fmadd(a2, b1, r21);
674 r22 = vec_type::fmadd(a2, b2, r22);
675 r23 = vec_type::fmadd(a2, b3, r23);
676 r24 = vec_type::fmadd(a2, b4, r24);
695 for (
size_t k = block_k; k < k_end; ++k) {
696 auto a1 = vec_type::set(a[(i + 0) * K + k]);
703 r1 = vec_type::fmadd(a1, b1, r1);
704 r2 = vec_type::fmadd(a1, b2, r2);
705 r3 = vec_type::fmadd(a1, b3, r3);
706 r4 = vec_type::fmadd(a1, b4, r4);
// 2-vector column panels (j, j1): four rows, then two rows, then a single leftover row.
716 for (; j + vec_size * 2 - 1 < j_end; j += vec_size * 2) {
717 const size_t j1(j + vec_size);
721 for (; i + 3 < i_end; i += 4) {
734 for (
size_t k = block_k; k < k_end; ++k) {
735 auto a1 = vec_type::set(a[(i + 0) * K + k]);
736 auto a2 = vec_type::set(a[(i + 1) * K + k]);
737 auto a3 = vec_type::set(a[(i + 2) * K + k]);
738 auto a4 = vec_type::set(a[(i + 3) * K + k]);
743 r11 = vec_type::fmadd(a1, b1, r11);
744 r12 = vec_type::fmadd(a1, b2, r12);
746 r21 = vec_type::fmadd(a2, b1, r21);
747 r22 = vec_type::fmadd(a2, b2, r22);
749 r31 = vec_type::fmadd(a3, b1, r31);
750 r32 = vec_type::fmadd(a3, b2, r32);
752 r41 = vec_type::fmadd(a4, b1, r41);
753 r42 = vec_type::fmadd(a4, b2, r42);
766 for (; i + 2 - 1 < i_end; i += 2) {
773 for (
size_t k = block_k; k < k_end; ++k) {
774 auto a1 = vec_type::set(a[(i + 0) * K + k]);
775 auto a2 = vec_type::set(a[(i + 1) * K + k]);
780 r11 = vec_type::fmadd(a1, b1, r11);
781 r12 = vec_type::fmadd(a1, b2, r12);
783 r21 = vec_type::fmadd(a2, b1, r21);
784 r22 = vec_type::fmadd(a2, b2, r22);
797 for (
size_t k = block_k; k < k_end; ++k) {
798 auto a1 = vec_type::set(a[(i + 0) * K + k]);
803 r1 = vec_type::fmadd(a1, b1, r1);
804 r2 = vec_type::fmadd(a1, b2, r2);
// 1-vector column panels, one row at a time.
812 for (; j + vec_size - 1 < j_end; j += vec_size) {
813 for (
size_t i = block_i; i < i_end; ++i) {
816 for (
size_t k = block_k; k < k_end; ++k) {
817 auto a1 = vec_type::set(a[(i + 0) * K + k]);
819 r1 = vec_type::fmadd(a1, b1, r1);
// Scalar tail for the tile columns not covered by the vector panels.
// NOTE(review): value starts from the current c and is stored as alpha * value for every
// k block, i.e. c = alpha * (c + partial). When alpha != 1 this rescales the beta-scaled C
// and every earlier k block (K > k_block_size) by alpha again. BLAS GEMM semantics would be
// c += alpha * partial. The vector-path stores are not visible here, so it is unconfirmed
// whether they share this behavior. Verify with alpha != 1 and K > 128.
826 for (; j < j_end; ++j) {
827 for (
size_t i = block_i; i < i_end; ++i) {
828 auto value = c[i * N + j];
830 for (
size_t k = block_k; k < k_end; ++k) {
831 value += a[i * K + k] * b[k * N + j];
834 c[i * N + j] = alpha * value;
/*!
 * \brief Round a value down to the nearest multiple of vec_size.
 *
 * Used to keep a k block width a multiple of the vector size, so that the
 * vectorized k loops have no partial vector.
 *
 * \tparam vec_size The vector width; must be non-zero (value % 0 would be undefined behaviour).
 * \param value The value to round down.
 * \return The largest multiple of vec_size that is not greater than value.
 */
template <size_t vec_size>
inline constexpr size_t prev_vec_block(size_t value) noexcept {
    static_assert(vec_size > 0, "prev_vec_block requires a non-zero vector size");
    return value - (value % vec_size);
}
// GEMM kernel with packed temporaries.
// For each k block (at most K_BLOCK wide; the tail block is rounded down to a multiple of
// vec_size with prev_vec_block), rows of A are copied into A2 and a slice of B, at most
// J_BLOCK columns wide, is copied into B2. alpha * (dot products) is then accumulated into
// C via horizontal sums (hadd). The last K - kk (< vec_size) k values are handled by a
// scalar pass.
// Columns are processed by batch_fun_j over [jfirst, jlast); how that lambda is invoked
// (possibly engine_dispatch_1d) is not visible in this listing.
// NOTE(review): lossy extract: many statements (loads, temporaries, indices) and all
// closing braces are not visible here.
854 template <
typename V,
typename T>
858 constexpr
size_t vec_size = vec_type::template traits<T>::size;
860 constexpr
size_t K_BLOCK = 112 * (16 /
sizeof(T));
861 constexpr
size_t J_BLOCK = 96;
873 auto batch_fun_j = [&](
const size_t jfirst,
const size_t jlast) {
877 auto * A2M = A2.memory_start();
878 auto * B2M = B2.memory_start();
// Vectorized k blocks: run while at least one full vector of k values remains.
884 for (; kk + vec_size - 1 < K; kk += kblock) {
885 kblock = kk + K_BLOCK <= K ? K_BLOCK : prev_vec_block<vec_size>(K - kk);
// Pack the current k block of every row of A into A2.
892 for (
size_t iii = 0; iii < M; ++iii) {
893 for (
size_t kkk = 0; kkk < kblock; ++kkk) {
894 A2(iii, kkk) = A(iii, kkk + kk);
901 for (; jj < jlast; jj += jblock) {
902 jblock = jj + J_BLOCK <= jlast ? J_BLOCK : jlast - jj;
// Pack the kblock x jblock slice of B into B2.
905 for (
size_t kkk = 0; kkk < kblock; ++kkk) {
906 for (
size_t jjj = 0; jjj < jblock; ++jjj) {
907 B2(kkk, jjj) = B(kkk + kk, jjj + jj);
// 5-row micro-kernels: 5 x 2 columns, then 5 x 1.
913 for (; i + 4 < M; i += 5) {
916 for (; j + 1 < jblock; j += 2) {
928 auto xmm1 = vec_type::mul(a1, b1);
929 auto xmm2 = vec_type::mul(a1, b2);
930 auto xmm3 = vec_type::mul(a2, b1);
931 auto xmm4 = vec_type::mul(a2, b2);
932 auto xmm5 = vec_type::mul(a3, b1);
933 auto xmm6 = vec_type::mul(a3, b2);
934 auto xmm7 = vec_type::mul(a4, b1);
935 auto xmm8 = vec_type::mul(a4, b2);
936 auto xmm9 = vec_type::mul(a5, b1);
937 auto xmm10 = vec_type::mul(a5, b2);
939 for (k += vec_size; k < kblock; k += vec_size) {
949 xmm1 = vec_type::fmadd(a1, b1, xmm1);
950 xmm2 = vec_type::fmadd(a1, b2, xmm2);
951 xmm3 = vec_type::fmadd(a2, b1, xmm3);
952 xmm4 = vec_type::fmadd(a2, b2, xmm4);
953 xmm5 = vec_type::fmadd(a3, b1, xmm5);
954 xmm6 = vec_type::fmadd(a3, b2, xmm6);
955 xmm7 = vec_type::fmadd(a4, b1, xmm7);
956 xmm8 = vec_type::fmadd(a4, b2, xmm8);
957 xmm9 = vec_type::fmadd(a5, b1, xmm9);
958 xmm10 = vec_type::fmadd(a5, b2, xmm10);
// Each accumulator holds vec_size partial products; hadd reduces them to one scalar.
961 C(i + 0, jj + j + 0) += alpha * vec_type::hadd(xmm1);
962 C(i + 0, jj + j + 1) += alpha * vec_type::hadd(xmm2);
963 C(i + 1, jj + j + 0) += alpha * vec_type::hadd(xmm3);
964 C(i + 1, jj + j + 1) += alpha * vec_type::hadd(xmm4);
965 C(i + 2, jj + j + 0) += alpha * vec_type::hadd(xmm5);
966 C(i + 2, jj + j + 1) += alpha * vec_type::hadd(xmm6);
967 C(i + 3, jj + j + 0) += alpha * vec_type::hadd(xmm7);
968 C(i + 3, jj + j + 1) += alpha * vec_type::hadd(xmm8);
969 C(i + 4, jj + j + 0) += alpha * vec_type::hadd(xmm9);
970 C(i + 4, jj + j + 1) += alpha * vec_type::hadd(xmm10);
984 auto xmm1 = vec_type::mul(a1, b1);
985 auto xmm2 = vec_type::mul(a2, b1);
986 auto xmm3 = vec_type::mul(a3, b1);
987 auto xmm4 = vec_type::mul(a4, b1);
988 auto xmm5 = vec_type::mul(a5, b1);
990 for (k += vec_size; k < kblock; k += vec_size) {
999 xmm1 = vec_type::fmadd(a1, b1, xmm1);
1000 xmm2 = vec_type::fmadd(a2, b1, xmm2);
1001 xmm3 = vec_type::fmadd(a3, b1, xmm3);
1002 xmm4 = vec_type::fmadd(a4, b1, xmm4);
1003 xmm5 = vec_type::fmadd(a5, b1, xmm5);
1006 C(i + 0, jj + j) += alpha * vec_type::hadd(xmm1);
1007 C(i + 1, jj + j) += alpha * vec_type::hadd(xmm2);
1008 C(i + 2, jj + j) += alpha * vec_type::hadd(xmm3);
1009 C(i + 3, jj + j) += alpha * vec_type::hadd(xmm4);
1010 C(i + 4, jj + j) += alpha * vec_type::hadd(xmm5);
// 2-row micro-kernels: 2 x 4 columns, then 2 x 2, then 2 x 1.
1014 for (; i + 1 < M; i += 2) {
1017 for (; j + 3 < jblock; j += 4) {
1028 auto xmm1 = vec_type::mul(a1, b1);
1029 auto xmm2 = vec_type::mul(a1, b2);
1030 auto xmm3 = vec_type::mul(a1, b3);
1031 auto xmm4 = vec_type::mul(a1, b4);
1032 auto xmm5 = vec_type::mul(a2, b1);
1033 auto xmm6 = vec_type::mul(a2, b2);
1034 auto xmm7 = vec_type::mul(a2, b3);
1035 auto xmm8 = vec_type::mul(a2, b4);
1037 for (k += vec_size; k < kblock; k += vec_size) {
1046 xmm1 = vec_type::fmadd(a1, b1, xmm1);
1047 xmm2 = vec_type::fmadd(a1, b2, xmm2);
1048 xmm3 = vec_type::fmadd(a1, b3, xmm3);
1049 xmm4 = vec_type::fmadd(a1, b4, xmm4);
1051 xmm5 = vec_type::fmadd(a2, b1, xmm5);
1052 xmm6 = vec_type::fmadd(a2, b2, xmm6);
1053 xmm7 = vec_type::fmadd(a2, b3, xmm7);
// xmm8 accumulates row i + 1 against the fourth B column (it starts as a2 * b4 and is
// stored into C(i + 1, jj + j + 3)), so it must be paired with b4, not b2.
1054 xmm8 = vec_type::fmadd(a2, b4, xmm8);
1057 C(i + 0, jj + j + 0) += alpha * vec_type::hadd(xmm1);
1058 C(i + 0, jj + j + 1) += alpha * vec_type::hadd(xmm2);
1059 C(i + 0, jj + j + 2) += alpha * vec_type::hadd(xmm3);
1060 C(i + 0, jj + j + 3) += alpha * vec_type::hadd(xmm4);
1062 C(i + 1, jj + j + 0) += alpha * vec_type::hadd(xmm5);
1063 C(i + 1, jj + j + 1) += alpha * vec_type::hadd(xmm6);
1064 C(i + 1, jj + j + 2) += alpha * vec_type::hadd(xmm7);
1065 C(i + 1, jj + j + 3) += alpha * vec_type::hadd(xmm8);
1068 for (; j + 1 < jblock; j += 2) {
1077 auto xmm1 = vec_type::mul(a1, b1);
1078 auto xmm2 = vec_type::mul(a1, b2);
1079 auto xmm3 = vec_type::mul(a2, b1);
1080 auto xmm4 = vec_type::mul(a2, b2);
1082 for (k += vec_size; k < kblock; k += vec_size) {
1089 xmm1 = vec_type::fmadd(a1, b1, xmm1);
1090 xmm2 = vec_type::fmadd(a1, b2, xmm2);
1092 xmm3 = vec_type::fmadd(a2, b1, xmm3);
1093 xmm4 = vec_type::fmadd(a2, b2, xmm4);
1096 C(i + 0, jj + j + 0) += alpha * vec_type::hadd(xmm1);
1097 C(i + 0, jj + j + 1) += alpha * vec_type::hadd(xmm2);
1099 C(i + 1, jj + j + 0) += alpha * vec_type::hadd(xmm3);
1100 C(i + 1, jj + j + 1) += alpha * vec_type::hadd(xmm4);
1111 auto xmm1 = vec_type::mul(a1, b1);
1112 auto xmm2 = vec_type::mul(a2, b1);
1114 for (k += vec_size; k < kblock; k += vec_size) {
1120 xmm1 = vec_type::fmadd(a1, b1, xmm1);
1121 xmm2 = vec_type::fmadd(a2, b1, xmm2);
1124 C(i + 0, jj + j) += alpha * vec_type::hadd(xmm1);
1125 C(i + 1, jj + j) += alpha * vec_type::hadd(xmm2);
// Presumably the single leftover row: 1 x 2 columns, then 1 x 1.
1132 for (; j + 1 < jblock; j += 2) {
1140 auto xmm1 = vec_type::mul(a1, b1);
1141 auto xmm2 = vec_type::mul(a1, b2);
1143 for (k += vec_size; k < kblock; k += vec_size) {
1149 xmm1 = vec_type::fmadd(a1, b1, xmm1);
1150 xmm2 = vec_type::fmadd(a1, b2, xmm2);
1153 C(i, jj + j + 0) += alpha * vec_type::hadd(xmm1);
1154 C(i, jj + j + 1) += alpha * vec_type::hadd(xmm2);
1164 auto xmm1 = vec_type::mul(a1, b1);
1166 for (k += vec_size; k < kblock; k += vec_size) {
1171 xmm1 = vec_type::fmadd(a1, b1, xmm1);
1174 C(i, jj + j) += alpha * vec_type::hadd(xmm1);
// Scalar pass over the last K - kk (< vec_size) k values, left over by the vector loop.
1182 const size_t kend = K - kk;
1185 for (
size_t iii = 0; iii < M; ++iii) {
1186 for (
size_t kkk = 0; kkk < kend; ++kkk) {
1187 A2(iii, kkk) = A(iii, kkk + kk);
1194 for (; jj < jlast; jj += jblock) {
1195 jblock = jj + J_BLOCK <= jlast ? J_BLOCK : jlast - jj;
1198 for (
size_t kkk = 0; kkk < kend; ++kkk) {
1199 for (
size_t jjj = 0; jjj < jblock; ++jjj) {
1200 B2(kkk, jjj) = B(kkk + kk, jjj + jj);
// Same 5 / 2 / 1 row structure as the vectorized section, with plain scalar products.
1206 for (; i + 4 < M; i += 5) {
1209 for (; j + 1 < jblock; j += 2) {
1210 for (
size_t k = 0; k < kend; ++k) {
1211 C(i + 0, jj + j + 0) += alpha * A2(i + 0, k) * B2(k, j + 0);
1212 C(i + 0, jj + j + 1) += alpha * A2(i + 0, k) * B2(k, j + 1);
1213 C(i + 1, jj + j + 0) += alpha * A2(i + 1, k) * B2(k, j + 0);
1214 C(i + 1, jj + j + 1) += alpha * A2(i + 1, k) * B2(k, j + 1);
1215 C(i + 2, jj + j + 0) += alpha * A2(i + 2, k) * B2(k, j + 0);
1216 C(i + 2, jj + j + 1) += alpha * A2(i + 2, k) * B2(k, j + 1);
1217 C(i + 3, jj + j + 0) += alpha * A2(i + 3, k) * B2(k, j + 0);
1218 C(i + 3, jj + j + 1) += alpha * A2(i + 3, k) * B2(k, j + 1);
1219 C(i + 4, jj + j + 0) += alpha * A2(i + 4, k) * B2(k, j + 0);
1220 C(i + 4, jj + j + 1) += alpha * A2(i + 4, k) * B2(k, j + 1);
1225 for (
size_t k = 0; k < kend; ++k) {
1226 C(i + 0, jj + j) += alpha * A2(i + 0, k) * B2(k, j);
1227 C(i + 1, jj + j) += alpha * A2(i + 1, k) * B2(k, j);
1228 C(i + 2, jj + j) += alpha * A2(i + 2, k) * B2(k, j);
1229 C(i + 3, jj + j) += alpha * A2(i + 3, k) * B2(k, j);
1230 C(i + 4, jj + j) += alpha * A2(i + 4, k) * B2(k, j);
1235 for (; i + 1 < M; i += 2) {
1238 for (; j + 1 < jblock; j += 2) {
1239 for (
size_t k = 0; k < kend; ++k) {
1240 C(i + 0, jj + j + 0) += alpha * A2(i + 0, k) * B2(k, j + 0);
1241 C(i + 0, jj + j + 1) += alpha * A2(i + 0, k) * B2(k, j + 1);
1242 C(i + 1, jj + j + 0) += alpha * A2(i + 1, k) * B2(k, j + 0);
1243 C(i + 1, jj + j + 1) += alpha * A2(i + 1, k) * B2(k, j + 1);
1248 for (
size_t k = 0; k < kend; ++k) {
1249 C(i + 0, jj + j) += alpha * A2(i + 0, k) * B2(k, j);
1250 C(i + 1, jj + j) += alpha * A2(i + 1, k) * B2(k, j);
1258 for (; j + 1 < jblock; j += 2) {
1259 for (
size_t k = 0; k < kend; ++k) {
1260 C(i, jj + j + 0) += alpha * A2(i, k) * B2(k, j + 0);
1261 C(i, jj + j + 1) += alpha * A2(i, k) * B2(k, j + 1);
1266 for (
size_t k = 0; k < kend; ++k) {
1267 C(i, jj + j) += alpha * A2(i, k) * B2(k, j);
// Vectorized entry point: C (M x N) = alpha * A (M x K) * B (K x N), all row-major.
// Asserts that vectorization is available, then calls one of three kernels (small,
// tiled large, or packed-temporary large). The size conditions that choose between them
// are not visible in this listing. Presumably they compare against gemm_rr_small_threshold /
// gemm_rr_medium_threshold; confirm against the full file.
1290 template <
typename T>
1291 void gemm_rr_to_r(
const T* a,
const T* b, T* c,
size_t M,
size_t N,
size_t K, T alpha) {
1292 cpp_assert(
vec_enabled,
"At least one vector mode must be enabled for impl::VEC");
1293 cpp_assert(
vectorize_impl,
"vectorize_impl must be enabled for impl::VEC");
1298 gemm_small_kernel_rr_to_r<default_vec>(a, b, c, M, N, K, alpha);
// The large kernels take a beta argument; T(0) is passed, so the previous contents of C
// are not meant to contribute to the result.
1300 gemm_large_kernel_rr_to_r<default_vec>(a, b, c, M, N, K, alpha, T(0));
1302 gemm_large_kernel_rr_to_r_temp<default_vec>(a, b, c, M, N, K, alpha, T(0));
void engine_dispatch_1d(Functor &&functor, size_t first, size_t last, [[maybe_unused]] size_t threshold, [[maybe_unused]] size_t n_threads=etl::threads)
Dispatch the elements of a range to a functor in a parallel manner, using the global thread engine...
Definition: parallel_support.hpp:708
constexpr size_t gemm_rr_medium_threshold
The number of elements of B after which we use BLAS-like kernel (for GEMM)
Definition: threshold.hpp:56
Definition: bias_add.hpp:15
constexpr size_t gemm_rr_small_threshold
The number of elements of B after which we use BLAS-like kernel (for GEMM)
Definition: threshold.hpp:55
constexpr bool vectorize_impl
Indicates if the implementations can be automatically vectorized by ETL.
Definition: config.hpp:35
constexpr bool vec_enabled
Indicates if vectorization is available in any format.
Definition: config.hpp:220
typename V::template vec_type< value_type > vec_type
The vectorization type for V.
Definition: dyn_matrix_view.hpp:43
void gemm_large_kernel_rr_to_r(const T *a, const T *b, T *ETL_RESTRICT c, size_t M, size_t N, size_t K, T alpha, T beta)
Optimized version of large GEMM for row major version.
Definition: gemm_rr_to_r.hpp:601
auto load(size_t x) const noexcept
Load several elements of the expression at once.
Definition: dyn_matrix_view.hpp:143
void gemm_large_kernel_rr_to_r_temp(const T *a, const T *b, T *ETL_RESTRICT c, size_t M, size_t N, size_t K, T alpha, T beta)
Optimized version of large GEMM for row major version.
Definition: gemm_rr_to_r.hpp:855
void storeu(vec_type< V > in, size_t i) noexcept
Store several elements in the matrix at once.
Definition: dyn_matrix_view.hpp:187
Matrix with run-time fixed dimensions.
Definition: dyn.hpp:26
auto loadu(size_t x) const noexcept
Load several elements of the expression at once.
Definition: dyn_matrix_view.hpp:154
auto min(L &&lhs, R &&rhs)
Create an expression with the min value of lhs or rhs.
Definition: expression_builder.hpp:77
void gemm_small_kernel_rr_to_r(const T *a, const T *b, T *ETL_RESTRICT c, size_t M, size_t N, size_t K, T alpha)
Optimized version of small GEMM for row major version.
Definition: gemm_rr_to_r.hpp:35
Matrix with run-time fixed dimensions.
Definition: custom_dyn.hpp:27
void gemm_rr_to_r(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Vectorized implementation of row-major matrix - row-major matrix multiplication and assignment into a...
Definition: gemm_rr_to_r.hpp:1291