Expression Templates Library (ETL)
gemm_rr_to_r.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
14 #pragma once
15 
16 namespace etl::impl::vec {
17 
18 // Optimizations opportunities
19 // For alpha=1, specific kernels could be done to avoid the multiplication
20 
21 // The 8-times unrolled loop is poorly handled by clang (3.9, 4.0)
22 #ifndef ETL_GEMM_SMALL_RR_R_UNROLL_8
23 #ifndef __clang__
24 #define ETL_GEMM_SMALL_RR_R_UNROLL_8
25 #endif
26 #endif
27 
34 template <typename V, typename T>
35 void gemm_small_kernel_rr_to_r(const T* a, const T* b, T* ETL_RESTRICT c, size_t M, size_t N, size_t K, T alpha) {
36  using vec_type = V;
37 
38  static constexpr size_t vec_size = vec_type::template traits<T>::size;
39 
40  const auto j_end = prev_multiple(N, vec_size);
41 
42  size_t j = 0;
43 
44  auto alpha_vec = vec_type::set(alpha);
45 
46  // As an optimization, we directly do the first iteration of the K-loop
47  // since K cannot be zero. This avoids having to preset the vector with zero
48 
49 #ifdef ETL_GEMM_SMALL_RR_R_UNROLL_8
50  // Vectorized loop unrolled eight times
51  for (; j + vec_size * 7 < j_end; j += vec_size * 8) {
52  for (size_t i = 0; i < M; ++i) {
53  size_t k = 0;
54 
55  auto a1 = vec_type::set(a[i * K + k]);
56 
57  auto r1 = vec_type::mul(a1, vec_type::loadu(b + k * N + j + vec_size * 0));
58  auto r2 = vec_type::mul(a1, vec_type::loadu(b + k * N + j + vec_size * 1));
59  auto r3 = vec_type::mul(a1, vec_type::loadu(b + k * N + j + vec_size * 2));
60  auto r4 = vec_type::mul(a1, vec_type::loadu(b + k * N + j + vec_size * 3));
61  auto r5 = vec_type::mul(a1, vec_type::loadu(b + k * N + j + vec_size * 4));
62  auto r6 = vec_type::mul(a1, vec_type::loadu(b + k * N + j + vec_size * 5));
63  auto r7 = vec_type::mul(a1, vec_type::loadu(b + k * N + j + vec_size * 6));
64  auto r8 = vec_type::mul(a1, vec_type::loadu(b + k * N + j + vec_size * 7));
65 
66  for (++k; k < K; ++k) {
67  a1 = vec_type::set(a[i * K + k]);
68 
69  r1 = vec_type::fmadd(a1, vec_type::loadu(b + k * N + j + vec_size * 0), r1);
70  r2 = vec_type::fmadd(a1, vec_type::loadu(b + k * N + j + vec_size * 1), r2);
71  r3 = vec_type::fmadd(a1, vec_type::loadu(b + k * N + j + vec_size * 2), r3);
72  r4 = vec_type::fmadd(a1, vec_type::loadu(b + k * N + j + vec_size * 3), r4);
73  r5 = vec_type::fmadd(a1, vec_type::loadu(b + k * N + j + vec_size * 4), r5);
74  r6 = vec_type::fmadd(a1, vec_type::loadu(b + k * N + j + vec_size * 5), r6);
75  r7 = vec_type::fmadd(a1, vec_type::loadu(b + k * N + j + vec_size * 6), r7);
76  r8 = vec_type::fmadd(a1, vec_type::loadu(b + k * N + j + vec_size * 7), r8);
77  }
78 
79  vec_type::storeu(c + i * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r1));
80  vec_type::storeu(c + i * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r2));
81  vec_type::storeu(c + i * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r3));
82  vec_type::storeu(c + i * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r4));
83  vec_type::storeu(c + i * N + j + 4 * vec_size, vec_type::mul(alpha_vec, r5));
84  vec_type::storeu(c + i * N + j + 5 * vec_size, vec_type::mul(alpha_vec, r6));
85  vec_type::storeu(c + i * N + j + 6 * vec_size, vec_type::mul(alpha_vec, r7));
86  vec_type::storeu(c + i * N + j + 7 * vec_size, vec_type::mul(alpha_vec, r8));
87  }
88  }
89 #endif
90 
91  // Vectorized loop unrolled five times
92  // This should max out the number of registers better than four
93  for (; j + vec_size * 4 < j_end; j += 5 * vec_size) {
94  size_t i = 0;
95 
96  for (; i + 1 < M; i += 2) {
97  size_t k = 0;
98 
99  auto a1 = vec_type::set(a[(i + 0) * K + k]);
100  auto a2 = vec_type::set(a[(i + 1) * K + k]);
101 
102  auto b1 = vec_type::loadu(b + k * N + j + vec_size * 0);
103  auto b2 = vec_type::loadu(b + k * N + j + vec_size * 1);
104  auto b3 = vec_type::loadu(b + k * N + j + vec_size * 2);
105  auto b4 = vec_type::loadu(b + k * N + j + vec_size * 3);
106  auto b5 = vec_type::loadu(b + k * N + j + vec_size * 4);
107 
108  auto r11 = vec_type::mul(a1, b1);
109  auto r12 = vec_type::mul(a2, b1);
110 
111  auto r21 = vec_type::mul(a1, b2);
112  auto r22 = vec_type::mul(a2, b2);
113 
114  auto r31 = vec_type::mul(a1, b3);
115  auto r32 = vec_type::mul(a2, b3);
116 
117  auto r41 = vec_type::mul(a1, b4);
118  auto r42 = vec_type::mul(a2, b4);
119 
120  auto r51 = vec_type::mul(a1, b5);
121  auto r52 = vec_type::mul(a2, b5);
122 
123  for (++k; k < K; ++k) {
124  a1 = vec_type::set(a[(i + 0) * K + k]);
125  a2 = vec_type::set(a[(i + 1) * K + k]);
126 
127  b1 = vec_type::loadu(b + k * N + j + vec_size * 0);
128  b2 = vec_type::loadu(b + k * N + j + vec_size * 1);
129  b3 = vec_type::loadu(b + k * N + j + vec_size * 2);
130  b4 = vec_type::loadu(b + k * N + j + vec_size * 3);
131  b5 = vec_type::loadu(b + k * N + j + vec_size * 4);
132 
133  r11 = vec_type::fmadd(a1, b1, r11);
134  r12 = vec_type::fmadd(a2, b1, r12);
135 
136  r21 = vec_type::fmadd(a1, b2, r21);
137  r22 = vec_type::fmadd(a2, b2, r22);
138 
139  r31 = vec_type::fmadd(a1, b3, r31);
140  r32 = vec_type::fmadd(a2, b3, r32);
141 
142  r41 = vec_type::fmadd(a1, b4, r41);
143  r42 = vec_type::fmadd(a2, b4, r42);
144 
145  r51 = vec_type::fmadd(a1, b5, r51);
146  r52 = vec_type::fmadd(a2, b5, r52);
147  }
148 
149  vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
150  vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r21));
151  vec_type::storeu(c + (i + 0) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r31));
152  vec_type::storeu(c + (i + 0) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r41));
153  vec_type::storeu(c + (i + 0) * N + j + 4 * vec_size, vec_type::mul(alpha_vec, r51));
154 
155  vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r12));
156  vec_type::storeu(c + (i + 1) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r22));
157  vec_type::storeu(c + (i + 1) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r32));
158  vec_type::storeu(c + (i + 1) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r42));
159  vec_type::storeu(c + (i + 1) * N + j + 4 * vec_size, vec_type::mul(alpha_vec, r52));
160  }
161 
162  if (i < M) {
163  size_t k = 0;
164 
165  auto a1 = vec_type::set(a[(i + 0) * K + k]);
166 
167  auto r11 = vec_type::mul(a1, vec_type::loadu(b + k * N + j + vec_size * 0));
168  auto r21 = vec_type::mul(a1, vec_type::loadu(b + k * N + j + vec_size * 1));
169  auto r31 = vec_type::mul(a1, vec_type::loadu(b + k * N + j + vec_size * 2));
170  auto r41 = vec_type::mul(a1, vec_type::loadu(b + k * N + j + vec_size * 3));
171  auto r51 = vec_type::mul(a1, vec_type::loadu(b + k * N + j + vec_size * 4));
172 
173  for (++k; k < K; ++k) {
174  a1 = vec_type::set(a[(i + 0) * K + k]);
175 
176  r11 = vec_type::fmadd(a1, vec_type::loadu(b + k * N + j + vec_size * 0), r11);
177  r21 = vec_type::fmadd(a1, vec_type::loadu(b + k * N + j + vec_size * 1), r21);
178  r31 = vec_type::fmadd(a1, vec_type::loadu(b + k * N + j + vec_size * 2), r31);
179  r41 = vec_type::fmadd(a1, vec_type::loadu(b + k * N + j + vec_size * 3), r41);
180  r51 = vec_type::fmadd(a1, vec_type::loadu(b + k * N + j + vec_size * 4), r51);
181  }
182 
183  vec_type::storeu(c + i * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
184  vec_type::storeu(c + i * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r21));
185  vec_type::storeu(c + i * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r31));
186  vec_type::storeu(c + i * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r41));
187  vec_type::storeu(c + i * N + j + 4 * vec_size, vec_type::mul(alpha_vec, r51));
188  }
189  }
190 
191  // Vectorized loop unrolled four times
192  for (; j + vec_size * 3 < j_end; j += 4 * vec_size) {
193  size_t i = 0;
194 
195  for (; i + 1 < M; i += 2) {
196  size_t k = 0;
197 
198  auto a1 = vec_type::set(a[(i + 0) * K + k]);
199  auto a2 = vec_type::set(a[(i + 1) * K + k]);
200 
201  auto b1 = vec_type::loadu(b + k * N + j + vec_size * 0);
202  auto b2 = vec_type::loadu(b + k * N + j + vec_size * 1);
203  auto b3 = vec_type::loadu(b + k * N + j + vec_size * 2);
204  auto b4 = vec_type::loadu(b + k * N + j + vec_size * 3);
205 
206  auto r11 = vec_type::mul(a1, b1);
207  auto r12 = vec_type::mul(a2, b1);
208 
209  auto r21 = vec_type::mul(a1, b2);
210  auto r22 = vec_type::mul(a2, b2);
211 
212  auto r31 = vec_type::mul(a1, b3);
213  auto r32 = vec_type::mul(a2, b3);
214 
215  auto r41 = vec_type::mul(a1, b4);
216  auto r42 = vec_type::mul(a2, b4);
217 
218  for (++k; k < K; ++k) {
219  a1 = vec_type::set(a[(i + 0) * K + k]);
220  a2 = vec_type::set(a[(i + 1) * K + k]);
221 
222  b1 = vec_type::loadu(b + k * N + j + vec_size * 0);
223  b2 = vec_type::loadu(b + k * N + j + vec_size * 1);
224  b3 = vec_type::loadu(b + k * N + j + vec_size * 2);
225  b4 = vec_type::loadu(b + k * N + j + vec_size * 3);
226 
227  r11 = vec_type::fmadd(a1, b1, r11);
228  r12 = vec_type::fmadd(a2, b1, r12);
229 
230  r21 = vec_type::fmadd(a1, b2, r21);
231  r22 = vec_type::fmadd(a2, b2, r22);
232 
233  r31 = vec_type::fmadd(a1, b3, r31);
234  r32 = vec_type::fmadd(a2, b3, r32);
235 
236  r41 = vec_type::fmadd(a1, b4, r41);
237  r42 = vec_type::fmadd(a2, b4, r42);
238  }
239 
240  vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
241  vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r21));
242  vec_type::storeu(c + (i + 0) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r31));
243  vec_type::storeu(c + (i + 0) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r41));
244 
245  vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r12));
246  vec_type::storeu(c + (i + 1) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r22));
247  vec_type::storeu(c + (i + 1) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r32));
248  vec_type::storeu(c + (i + 1) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r42));
249  }
250 
251  if (i < M) {
252  size_t k = 0;
253 
254  auto a1 = vec_type::set(a[(i + 0) * K + k]);
255 
256  auto r11 = vec_type::mul(a1, vec_type::loadu(b + k * N + j + vec_size * 0));
257  auto r21 = vec_type::mul(a1, vec_type::loadu(b + k * N + j + vec_size * 1));
258  auto r31 = vec_type::mul(a1, vec_type::loadu(b + k * N + j + vec_size * 2));
259  auto r41 = vec_type::mul(a1, vec_type::loadu(b + k * N + j + vec_size * 3));
260 
261  for (++k; k < K; ++k) {
262  a1 = vec_type::set(a[(i + 0) * K + k]);
263 
264  r11 = vec_type::fmadd(a1, vec_type::loadu(b + k * N + j + vec_size * 0), r11);
265  r21 = vec_type::fmadd(a1, vec_type::loadu(b + k * N + j + vec_size * 1), r21);
266  r31 = vec_type::fmadd(a1, vec_type::loadu(b + k * N + j + vec_size * 2), r31);
267  r41 = vec_type::fmadd(a1, vec_type::loadu(b + k * N + j + vec_size * 3), r41);
268  }
269 
270  vec_type::storeu(c + i * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
271  vec_type::storeu(c + i * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r21));
272  vec_type::storeu(c + i * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r31));
273  vec_type::storeu(c + i * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r41));
274  }
275  }
276 
277  // Vectorized loop unrolled three times
278  for (; j + vec_size * 2 < j_end; j += 3 * vec_size) {
279  size_t i = 0;
280 
281  for (; i + 1 < M; i += 2) {
282  size_t k = 0;
283 
284  auto a1 = vec_type::set(a[(i + 0) * K + k]);
285  auto a2 = vec_type::set(a[(i + 1) * K + k]);
286 
287  auto b1 = vec_type::loadu(b + k * N + j + vec_size * 0);
288  auto b2 = vec_type::loadu(b + k * N + j + vec_size * 1);
289  auto b3 = vec_type::loadu(b + k * N + j + vec_size * 2);
290 
291  auto r11 = vec_type::mul(a1, b1);
292  auto r12 = vec_type::mul(a2, b1);
293 
294  auto r21 = vec_type::mul(a1, b2);
295  auto r22 = vec_type::mul(a2, b2);
296 
297  auto r31 = vec_type::mul(a1, b3);
298  auto r32 = vec_type::mul(a2, b3);
299 
300  for (++k; k < K; ++k) {
301  a1 = vec_type::set(a[(i + 0) * K + k]);
302  a2 = vec_type::set(a[(i + 1) * K + k]);
303 
304  b1 = vec_type::loadu(b + k * N + j + vec_size * 0);
305  b2 = vec_type::loadu(b + k * N + j + vec_size * 1);
306  b3 = vec_type::loadu(b + k * N + j + vec_size * 2);
307 
308  r11 = vec_type::fmadd(a1, b1, r11);
309  r12 = vec_type::fmadd(a2, b1, r12);
310 
311  r21 = vec_type::fmadd(a1, b2, r21);
312  r22 = vec_type::fmadd(a2, b2, r22);
313 
314  r31 = vec_type::fmadd(a1, b3, r31);
315  r32 = vec_type::fmadd(a2, b3, r32);
316  }
317 
318  vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
319  vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r21));
320  vec_type::storeu(c + (i + 0) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r31));
321 
322  vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r12));
323  vec_type::storeu(c + (i + 1) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r22));
324  vec_type::storeu(c + (i + 1) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r32));
325  }
326 
327  if (i < M) {
328  size_t k = 0;
329 
330  auto a1 = vec_type::set(a[(i + 0) * K + k]);
331 
332  auto r11 = vec_type::mul(a1, vec_type::loadu(b + k * N + j + vec_size * 0));
333  auto r21 = vec_type::mul(a1, vec_type::loadu(b + k * N + j + vec_size * 1));
334  auto r31 = vec_type::mul(a1, vec_type::loadu(b + k * N + j + vec_size * 2));
335 
336  for (++k; k < K; ++k) {
337  a1 = vec_type::set(a[(i + 0) * K + k]);
338 
339  r11 = vec_type::fmadd(a1, vec_type::loadu(b + k * N + j + vec_size * 0), r11);
340  r21 = vec_type::fmadd(a1, vec_type::loadu(b + k * N + j + vec_size * 1), r21);
341  r31 = vec_type::fmadd(a1, vec_type::loadu(b + k * N + j + vec_size * 2), r31);
342  }
343 
344  vec_type::storeu(c + i * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
345  vec_type::storeu(c + i * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r21));
346  vec_type::storeu(c + i * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r31));
347  }
348  }
349 
350  // Vectorized loop unrolled twice
351  for (; j + vec_size < j_end; j += 2 * vec_size) {
352  size_t i = 0;
353 
354  for (; i + 3 < M; i += 4) {
355  size_t k = 0;
356 
357  auto a1 = vec_type::set(a[(i + 0) * K + k]);
358  auto a2 = vec_type::set(a[(i + 1) * K + k]);
359  auto a3 = vec_type::set(a[(i + 2) * K + k]);
360  auto a4 = vec_type::set(a[(i + 3) * K + k]);
361 
362  auto b1 = vec_type::loadu(b + k * N + j + vec_size * 0);
363  auto b2 = vec_type::loadu(b + k * N + j + vec_size * 1);
364 
365  auto r11 = vec_type::mul(a1, b1);
366  auto r12 = vec_type::mul(a2, b1);
367  auto r13 = vec_type::mul(a3, b1);
368  auto r14 = vec_type::mul(a4, b1);
369 
370  auto r21 = vec_type::mul(a1, b2);
371  auto r22 = vec_type::mul(a2, b2);
372  auto r23 = vec_type::mul(a3, b2);
373  auto r24 = vec_type::mul(a4, b2);
374 
375  for (++k; k < K; ++k) {
376  a1 = vec_type::set(a[(i + 0) * K + k]);
377  a2 = vec_type::set(a[(i + 1) * K + k]);
378  a3 = vec_type::set(a[(i + 2) * K + k]);
379  a4 = vec_type::set(a[(i + 3) * K + k]);
380 
381  b1 = vec_type::loadu(b + k * N + j + vec_size * 0);
382  b2 = vec_type::loadu(b + k * N + j + vec_size * 1);
383 
384  r11 = vec_type::fmadd(a1, b1, r11);
385  r12 = vec_type::fmadd(a2, b1, r12);
386  r13 = vec_type::fmadd(a3, b1, r13);
387  r14 = vec_type::fmadd(a4, b1, r14);
388 
389  r21 = vec_type::fmadd(a1, b2, r21);
390  r22 = vec_type::fmadd(a2, b2, r22);
391  r23 = vec_type::fmadd(a3, b2, r23);
392  r24 = vec_type::fmadd(a4, b2, r24);
393  }
394 
395  vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
396  vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r21));
397 
398  vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r12));
399  vec_type::storeu(c + (i + 1) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r22));
400 
401  vec_type::storeu(c + (i + 2) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r13));
402  vec_type::storeu(c + (i + 2) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r23));
403 
404  vec_type::storeu(c + (i + 3) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r14));
405  vec_type::storeu(c + (i + 3) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r24));
406  }
407 
408  for (; i + 1 < M; i += 2) {
409  size_t k = 0;
410 
411  auto a1 = vec_type::set(a[(i + 0) * K + k]);
412  auto a2 = vec_type::set(a[(i + 1) * K + k]);
413 
414  auto b1 = vec_type::loadu(b + k * N + j + vec_size * 0);
415  auto b2 = vec_type::loadu(b + k * N + j + vec_size * 1);
416 
417  auto r11 = vec_type::mul(a1, b1);
418  auto r12 = vec_type::mul(a2, b1);
419 
420  auto r21 = vec_type::mul(a1, b2);
421  auto r22 = vec_type::mul(a2, b2);
422 
423  for (++k; k < K; ++k) {
424  a1 = vec_type::set(a[(i + 0) * K + k]);
425  a2 = vec_type::set(a[(i + 1) * K + k]);
426 
427  b1 = vec_type::loadu(b + k * N + j + vec_size * 0);
428  b2 = vec_type::loadu(b + k * N + j + vec_size * 1);
429 
430  r11 = vec_type::fmadd(a1, b1, r11);
431  r12 = vec_type::fmadd(a2, b1, r12);
432 
433  r21 = vec_type::fmadd(a1, b2, r21);
434  r22 = vec_type::fmadd(a2, b2, r22);
435  }
436 
437  vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
438  vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r21));
439 
440  vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r12));
441  vec_type::storeu(c + (i + 1) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r22));
442  }
443 
444  if (i < M) {
445  size_t k = 0;
446 
447  auto b1 = vec_type::loadu(b + k * N + j + vec_size * 0);
448  auto b2 = vec_type::loadu(b + k * N + j + vec_size * 1);
449 
450  auto a1 = vec_type::set(a[(i + 0) * K + k]);
451 
452  auto r11 = vec_type::mul(a1, b1);
453  auto r21 = vec_type::mul(a1, b2);
454 
455  for (++k; k < K; ++k) {
456  b1 = vec_type::loadu(b + k * N + j + vec_size * 0);
457  b2 = vec_type::loadu(b + k * N + j + vec_size * 1);
458 
459  a1 = vec_type::set(a[(i + 0) * K + k]);
460 
461  r11 = vec_type::fmadd(a1, b1, r11);
462  r21 = vec_type::fmadd(a1, b2, r21);
463  }
464 
465  vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
466  vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r21));
467  }
468  }
469 
470  // Vectorized loop
471  for (; j < j_end; j += vec_size) {
472  size_t i = 0;
473 
474  for (; i + 1 < M; i += 2) {
475  size_t k = 0;
476 
477  auto a1 = vec_type::set(a[(i + 0) * K + k]);
478  auto a2 = vec_type::set(a[(i + 1) * K + k]);
479 
480  auto b1 = vec_type::loadu(b + k * N + j + vec_size * 0);
481 
482  auto r1 = vec_type::mul(a1, b1);
483  auto r2 = vec_type::mul(a2, b1);
484 
485  for (++k; k < K; ++k) {
486  b1 = vec_type::loadu(b + k * N + j + vec_size * 0);
487 
488  a1 = vec_type::set(a[(i + 0) * K + k]);
489  a2 = vec_type::set(a[(i + 1) * K + k]);
490 
491  r1 = vec_type::fmadd(a1, b1, r1);
492  r2 = vec_type::fmadd(a2, b1, r2);
493  }
494 
495  vec_type::storeu(c + (i + 0) * N + j, vec_type::mul(alpha_vec, r1));
496  vec_type::storeu(c + (i + 1) * N + j, vec_type::mul(alpha_vec, r2));
497  }
498 
499  if (i < M) {
500  size_t k = 0;
501 
502  auto a1 = vec_type::set(a[(i + 0) * K + k]);
503 
504  auto r1 = vec_type::mul(a1, vec_type::loadu(b + k * N + j + vec_size * 0));
505 
506  for (++k; k < K; ++k) {
507  a1 = vec_type::set(a[(i + 0) * K + k]);
508 
509  r1 = vec_type::fmadd(a1, vec_type::loadu(b + k * N + j + vec_size * 0), r1);
510  }
511 
512  vec_type::storeu(c + (i + 0) * N + j, vec_type::mul(alpha_vec, r1));
513  }
514  }
515 
516  // Remainder Loop unrolled by 2
517  for (; j + 1 < N; j += 2) {
518  const size_t j1 = j + 0;
519  const size_t j2 = j + 1;
520 
521  size_t i = 0;
522 
523  for (; i + 1 < M; i += 2) {
524  size_t k = 0;
525 
526  auto r11 = a[(i + 0) * K + k] * b[k * N + j1];
527  auto r21 = a[(i + 0) * K + k] * b[k * N + j2];
528  auto r12 = a[(i + 1) * K + k] * b[k * N + j1];
529  auto r22 = a[(i + 1) * K + k] * b[k * N + j2];
530 
531  for (++k; k < K; ++k) {
532  r11 += a[(i + 0) * K + k] * b[k * N + j1];
533  r21 += a[(i + 0) * K + k] * b[k * N + j2];
534  r12 += a[(i + 1) * K + k] * b[k * N + j1];
535  r22 += a[(i + 1) * K + k] * b[k * N + j2];
536  }
537 
538  c[(i + 0) * N + j1] = alpha * r11;
539  c[(i + 0) * N + j2] = alpha * r21;
540  c[(i + 1) * N + j1] = alpha * r12;
541  c[(i + 1) * N + j2] = alpha * r22;
542  }
543 
544  if (i < M) {
545  size_t k = 0;
546 
547  auto r1 = a[i * K + k] * b[k * N + j1];
548  auto r2 = a[i * K + k] * b[k * N + j2];
549 
550  for (++k; k < K; ++k) {
551  r1 += a[i * K + k] * b[k * N + j1];
552  r2 += a[i * K + k] * b[k * N + j2];
553  }
554 
555  c[i * N + j1] = alpha * r1;
556  c[i * N + j2] = alpha * r2;
557  }
558  }
559 
560  // Final remainder loop iteration
561  if (j < N) {
562  size_t i = 0;
563 
564  for (; i + 1 < M; i += 2) {
565  size_t k = 0;
566 
567  auto r1 = a[(i + 0) * K + k] * b[k * N + j];
568  auto r2 = a[(i + 1) * K + k] * b[k * N + j];
569 
570  for (++k; k < K; ++k) {
571  r1 += a[(i + 0) * K + k] * b[k * N + j];
572  r2 += a[(i + 1) * K + k] * b[k * N + j];
573  }
574 
575  c[(i + 0) * N + j] = alpha * r1;
576  c[(i + 1) * N + j] = alpha * r2;
577  }
578 
579  if (i < M) {
580  size_t k = 0;
581 
582  auto r1 = a[i * K + k] * b[k * N + j];
583 
584  for (++k; k < K; ++k) {
585  r1 += a[i * K + k] * b[k * N + j];
586  }
587 
588  c[i * N + j] = alpha * r1;
589  }
590  }
591 }
592 
600 template <typename V, typename T>
601 void gemm_large_kernel_rr_to_r(const T* a, const T* b, T* ETL_RESTRICT c, size_t M, size_t N, size_t K, T alpha, T beta) {
602  using vec_type = V;
603 
604  static constexpr size_t vec_size = vec_type::template traits<T>::size;
605 
606  const size_t n_block_size = 128;
607  const size_t m_block_size = 64;
608  const size_t k_block_size = 128;
609 
610  auto alpha_vec = vec_type::set(alpha);
611 
612  // Note: There is a small benefit to parallelize this
613  // However, most of the parallel benefit is in larger matrices and the
614  // larger algorithm
615 
616  for (size_t block_j = 0; block_j < N; block_j += n_block_size) {
617  const size_t j_end = std::min(block_j + n_block_size, N);
618 
619  for (size_t block_i = 0; block_i < M; block_i += m_block_size) {
620  const size_t i_end = std::min(block_i + m_block_size, M);
621 
622  if (beta == T(0.0)) {
623  for (size_t i = block_i; i < i_end; ++i) {
624  for (size_t j = block_j; j < j_end; ++j) {
625  c[i * N + j] = 0;
626  }
627  }
628  } else {
629  for (size_t i = block_i; i < i_end; ++i) {
630  for (size_t j = block_j; j < j_end; ++j) {
631  c[i * N + j] = beta * c[i * N + j];
632  }
633  }
634  }
635 
636  for (size_t block_k = 0; block_k < K; block_k += k_block_size) {
637  const size_t k_end = std::min(block_k + k_block_size, K);
638 
639  size_t j = block_j;
640 
641  for (; j + vec_size * 4 - 1 < j_end; j += vec_size * 4) {
642  const size_t j1 = j + vec_size * 1;
643  const size_t j2 = j + vec_size * 2;
644  const size_t j3 = j + vec_size * 3;
645 
646  size_t i = block_i;
647 
648  for (; i + 1 < i_end; i += 2) {
649  auto r11 = vec_type::loadu(c + (i + 0) * N + j);
650  auto r12 = vec_type::loadu(c + (i + 0) * N + j1);
651  auto r13 = vec_type::loadu(c + (i + 0) * N + j2);
652  auto r14 = vec_type::loadu(c + (i + 0) * N + j3);
653 
654  auto r21 = vec_type::loadu(c + (i + 1) * N + j);
655  auto r22 = vec_type::loadu(c + (i + 1) * N + j1);
656  auto r23 = vec_type::loadu(c + (i + 1) * N + j2);
657  auto r24 = vec_type::loadu(c + (i + 1) * N + j3);
658 
659  for (size_t k = block_k; k < k_end; ++k) {
660  auto a1 = vec_type::set(a[(i + 0) * K + k]);
661  auto a2 = vec_type::set(a[(i + 1) * K + k]);
662 
663  auto b1 = vec_type::loadu(b + k * N + j);
664  auto b2 = vec_type::loadu(b + k * N + j1);
665  auto b3 = vec_type::loadu(b + k * N + j2);
666  auto b4 = vec_type::loadu(b + k * N + j3);
667 
668  r11 = vec_type::fmadd(a1, b1, r11);
669  r12 = vec_type::fmadd(a1, b2, r12);
670  r13 = vec_type::fmadd(a1, b3, r13);
671  r14 = vec_type::fmadd(a1, b4, r14);
672 
673  r21 = vec_type::fmadd(a2, b1, r21);
674  r22 = vec_type::fmadd(a2, b2, r22);
675  r23 = vec_type::fmadd(a2, b3, r23);
676  r24 = vec_type::fmadd(a2, b4, r24);
677  }
678 
679  vec_type::storeu(c + (i + 0) * N + j, vec_type::mul(alpha_vec, r11));
680  vec_type::storeu(c + (i + 0) * N + j1, vec_type::mul(alpha_vec, r12));
681  vec_type::storeu(c + (i + 0) * N + j2, vec_type::mul(alpha_vec, r13));
682  vec_type::storeu(c + (i + 0) * N + j3, vec_type::mul(alpha_vec, r14));
683  vec_type::storeu(c + (i + 1) * N + j, vec_type::mul(alpha_vec, r21));
684  vec_type::storeu(c + (i + 1) * N + j1, vec_type::mul(alpha_vec, r22));
685  vec_type::storeu(c + (i + 1) * N + j2, vec_type::mul(alpha_vec, r23));
686  vec_type::storeu(c + (i + 1) * N + j3, vec_type::mul(alpha_vec, r24));
687  }
688 
689  if (i < i_end) {
690  auto r1 = vec_type::loadu(c + (i + 0) * N + j);
691  auto r2 = vec_type::loadu(c + (i + 0) * N + j1);
692  auto r3 = vec_type::loadu(c + (i + 0) * N + j2);
693  auto r4 = vec_type::loadu(c + (i + 0) * N + j3);
694 
695  for (size_t k = block_k; k < k_end; ++k) {
696  auto a1 = vec_type::set(a[(i + 0) * K + k]);
697 
698  auto b1 = vec_type::loadu(b + k * N + j);
699  auto b2 = vec_type::loadu(b + k * N + j1);
700  auto b3 = vec_type::loadu(b + k * N + j2);
701  auto b4 = vec_type::loadu(b + k * N + j3);
702 
703  r1 = vec_type::fmadd(a1, b1, r1);
704  r2 = vec_type::fmadd(a1, b2, r2);
705  r3 = vec_type::fmadd(a1, b3, r3);
706  r4 = vec_type::fmadd(a1, b4, r4);
707  }
708 
709  vec_type::storeu(c + (i + 0) * N + j, vec_type::mul(alpha_vec, r1));
710  vec_type::storeu(c + (i + 0) * N + j1, vec_type::mul(alpha_vec, r2));
711  vec_type::storeu(c + (i + 0) * N + j2, vec_type::mul(alpha_vec, r3));
712  vec_type::storeu(c + (i + 0) * N + j3, vec_type::mul(alpha_vec, r4));
713  }
714  }
715 
716  for (; j + vec_size * 2 - 1 < j_end; j += vec_size * 2) {
717  const size_t j1(j + vec_size);
718 
719  size_t i = block_i;
720 
721  for (; i + 3 < i_end; i += 4) {
722  auto r11 = vec_type::loadu(c + (i + 0) * N + j);
723  auto r12 = vec_type::loadu(c + (i + 0) * N + j1);
724 
725  auto r21 = vec_type::loadu(c + (i + 1) * N + j);
726  auto r22 = vec_type::loadu(c + (i + 1) * N + j1);
727 
728  auto r31 = vec_type::loadu(c + (i + 2) * N + j);
729  auto r32 = vec_type::loadu(c + (i + 2) * N + j1);
730 
731  auto r41 = vec_type::loadu(c + (i + 3) * N + j);
732  auto r42 = vec_type::loadu(c + (i + 3) * N + j1);
733 
734  for (size_t k = block_k; k < k_end; ++k) {
735  auto a1 = vec_type::set(a[(i + 0) * K + k]);
736  auto a2 = vec_type::set(a[(i + 1) * K + k]);
737  auto a3 = vec_type::set(a[(i + 2) * K + k]);
738  auto a4 = vec_type::set(a[(i + 3) * K + k]);
739 
740  auto b1 = vec_type::loadu(b + k * N + j);
741  auto b2 = vec_type::loadu(b + k * N + j1);
742 
743  r11 = vec_type::fmadd(a1, b1, r11);
744  r12 = vec_type::fmadd(a1, b2, r12);
745 
746  r21 = vec_type::fmadd(a2, b1, r21);
747  r22 = vec_type::fmadd(a2, b2, r22);
748 
749  r31 = vec_type::fmadd(a3, b1, r31);
750  r32 = vec_type::fmadd(a3, b2, r32);
751 
752  r41 = vec_type::fmadd(a4, b1, r41);
753  r42 = vec_type::fmadd(a4, b2, r42);
754  }
755 
756  vec_type::storeu(c + (i + 0) * N + j, vec_type::mul(alpha_vec, r11));
757  vec_type::storeu(c + (i + 0) * N + j1, vec_type::mul(alpha_vec, r12));
758  vec_type::storeu(c + (i + 1) * N + j, vec_type::mul(alpha_vec, r21));
759  vec_type::storeu(c + (i + 1) * N + j1, vec_type::mul(alpha_vec, r22));
760  vec_type::storeu(c + (i + 2) * N + j, vec_type::mul(alpha_vec, r31));
761  vec_type::storeu(c + (i + 2) * N + j1, vec_type::mul(alpha_vec, r32));
762  vec_type::storeu(c + (i + 3) * N + j, vec_type::mul(alpha_vec, r41));
763  vec_type::storeu(c + (i + 3) * N + j1, vec_type::mul(alpha_vec, r42));
764  }
765 
766  for (; i + 2 - 1 < i_end; i += 2) {
767  auto r11 = vec_type::loadu(c + (i + 0) * N + j);
768  auto r12 = vec_type::loadu(c + (i + 0) * N + j1);
769 
770  auto r21 = vec_type::loadu(c + (i + 1) * N + j);
771  auto r22 = vec_type::loadu(c + (i + 1) * N + j1);
772 
773  for (size_t k = block_k; k < k_end; ++k) {
774  auto a1 = vec_type::set(a[(i + 0) * K + k]);
775  auto a2 = vec_type::set(a[(i + 1) * K + k]);
776 
777  auto b1 = vec_type::loadu(b + k * N + j);
778  auto b2 = vec_type::loadu(b + k * N + j1);
779 
780  r11 = vec_type::fmadd(a1, b1, r11);
781  r12 = vec_type::fmadd(a1, b2, r12);
782 
783  r21 = vec_type::fmadd(a2, b1, r21);
784  r22 = vec_type::fmadd(a2, b2, r22);
785  }
786 
787  vec_type::storeu(c + (i + 0) * N + j, vec_type::mul(alpha_vec, r11));
788  vec_type::storeu(c + (i + 0) * N + j1, vec_type::mul(alpha_vec, r12));
789  vec_type::storeu(c + (i + 1) * N + j, vec_type::mul(alpha_vec, r21));
790  vec_type::storeu(c + (i + 1) * N + j1, vec_type::mul(alpha_vec, r22));
791  }
792 
793  if (i < i_end) {
794  auto r1 = vec_type::loadu(c + (i + 0) * N + j);
795  auto r2 = vec_type::loadu(c + (i + 0) * N + j1);
796 
797  for (size_t k = block_k; k < k_end; ++k) {
798  auto a1 = vec_type::set(a[(i + 0) * K + k]);
799 
800  auto b1 = vec_type::loadu(b + k * N + j);
801  auto b2 = vec_type::loadu(b + k * N + j1);
802 
803  r1 = vec_type::fmadd(a1, b1, r1);
804  r2 = vec_type::fmadd(a1, b2, r2);
805  }
806 
807  vec_type::storeu(c + (i + 0) * N + j, vec_type::mul(alpha_vec, r1));
808  vec_type::storeu(c + (i + 0) * N + j1, vec_type::mul(alpha_vec, r2));
809  }
810  }
811 
812  for (; j + vec_size - 1 < j_end; j += vec_size) {
813  for (size_t i = block_i; i < i_end; ++i) {
814  auto r1 = vec_type::loadu(c + (i + 0) * N + j);
815 
816  for (size_t k = block_k; k < k_end; ++k) {
817  auto a1 = vec_type::set(a[(i + 0) * K + k]);
818  auto b1 = vec_type::loadu(b + k * N + j);
819  r1 = vec_type::fmadd(a1, b1, r1);
820  }
821 
822  vec_type::storeu(c + (i + 0) * N + j, vec_type::mul(alpha_vec, r1));
823  }
824  }
825 
826  for (; j < j_end; ++j) {
827  for (size_t i = block_i; i < i_end; ++i) {
828  auto value = c[i * N + j];
829 
830  for (size_t k = block_k; k < k_end; ++k) {
831  value += a[i * K + k] * b[k * N + j];
832  }
833 
834  c[i * N + j] = alpha * value;
835  }
836  }
837  }
838  }
839  }
840 }
841 
842 template <size_t vec_size>
843 inline constexpr size_t prev_vec_block(size_t value) noexcept {
844  return value - (value % vec_size);
845 }
846 
854 template <typename V, typename T>
855 void gemm_large_kernel_rr_to_r_temp(const T* a, const T* b, T* ETL_RESTRICT c, size_t M, size_t N, size_t K, T alpha, T beta) {
856  using vec_type = V;
857 
858  constexpr size_t vec_size = vec_type::template traits<T>::size;
859 
860  constexpr size_t K_BLOCK = 112 * (16 / sizeof(T));
861  constexpr size_t J_BLOCK = 96;
862 
863  etl::custom_dyn_matrix<T> A(const_cast<T*>(a), M, K);
864  etl::custom_dyn_matrix<T> B(const_cast<T*>(b), K, N);
865  etl::custom_dyn_matrix<T> C(c, M, N);
866 
867  if (beta == T(0)) {
868  C = 0;
869  } else {
870  C = beta * C;
871  }
872 
873  auto batch_fun_j = [&](const size_t jfirst, const size_t jlast) {
875  etl::dyn_matrix_impl<T, order::ColumnMajor> B2(K_BLOCK, J_BLOCK);
876 
877  auto * A2M = A2.memory_start();
878  auto * B2M = B2.memory_start();
879 
880  size_t kblock = 0;
881  size_t kk = 0;
882 
883  // Main loop
884  for (; kk + vec_size - 1 < K; kk += kblock) {
885  kblock = kk + K_BLOCK <= K ? K_BLOCK : prev_vec_block<vec_size>(K - kk);
886 
887  if (!kblock) {
888  continue;
889  }
890 
891  // Copy A into A2
892  for (size_t iii = 0; iii < M; ++iii) {
893  for (size_t kkk = 0; kkk < kblock; ++kkk) {
894  A2(iii, kkk) = A(iii, kkk + kk);
895  }
896  }
897 
898  size_t jj = jfirst;
899  size_t jblock = 0;
900 
901  for (; jj < jlast; jj += jblock) {
902  jblock = jj + J_BLOCK <= jlast ? J_BLOCK : jlast - jj;
903 
904  // Copy B into B2
905  for (size_t kkk = 0; kkk < kblock; ++kkk) {
906  for (size_t jjj = 0; jjj < jblock; ++jjj) {
907  B2(kkk, jjj) = B(kkk + kk, jjj + jj);
908  }
909  }
910 
911  size_t i = 0;
912 
913  for (; i + 4 < M; i += 5) {
914  size_t j = 0;
915 
916  for (; j + 1 < jblock; j += 2) {
917  size_t k = 0;
918 
919  auto a1 = vec_type::load(A2M + (i + 0) * K_BLOCK + k);
920  auto a2 = vec_type::load(A2M + (i + 1) * K_BLOCK + k);
921  auto a3 = vec_type::load(A2M + (i + 2) * K_BLOCK + k);
922  auto a4 = vec_type::load(A2M + (i + 3) * K_BLOCK + k);
923  auto a5 = vec_type::load(A2M + (i + 4) * K_BLOCK + k);
924 
925  auto b1 = vec_type::load(B2M + (j + 0) * K_BLOCK + k);
926  auto b2 = vec_type::load(B2M + (j + 1) * K_BLOCK + k);
927 
928  auto xmm1 = vec_type::mul(a1, b1);
929  auto xmm2 = vec_type::mul(a1, b2);
930  auto xmm3 = vec_type::mul(a2, b1);
931  auto xmm4 = vec_type::mul(a2, b2);
932  auto xmm5 = vec_type::mul(a3, b1);
933  auto xmm6 = vec_type::mul(a3, b2);
934  auto xmm7 = vec_type::mul(a4, b1);
935  auto xmm8 = vec_type::mul(a4, b2);
936  auto xmm9 = vec_type::mul(a5, b1);
937  auto xmm10 = vec_type::mul(a5, b2);
938 
939  for (k += vec_size; k < kblock; k += vec_size) {
940  a1 = vec_type::load(A2M + (i + 0) * K_BLOCK + k);
941  a2 = vec_type::load(A2M + (i + 1) * K_BLOCK + k);
942  a3 = vec_type::load(A2M + (i + 2) * K_BLOCK + k);
943  a4 = vec_type::load(A2M + (i + 3) * K_BLOCK + k);
944  a5 = vec_type::load(A2M + (i + 4) * K_BLOCK + k);
945 
946  b1 = vec_type::load(B2M + (j + 0) * K_BLOCK + k);
947  b2 = vec_type::load(B2M + (j + 1) * K_BLOCK + k);
948 
949  xmm1 = vec_type::fmadd(a1, b1, xmm1);
950  xmm2 = vec_type::fmadd(a1, b2, xmm2);
951  xmm3 = vec_type::fmadd(a2, b1, xmm3);
952  xmm4 = vec_type::fmadd(a2, b2, xmm4);
953  xmm5 = vec_type::fmadd(a3, b1, xmm5);
954  xmm6 = vec_type::fmadd(a3, b2, xmm6);
955  xmm7 = vec_type::fmadd(a4, b1, xmm7);
956  xmm8 = vec_type::fmadd(a4, b2, xmm8);
957  xmm9 = vec_type::fmadd(a5, b1, xmm9);
958  xmm10 = vec_type::fmadd(a5, b2, xmm10);
959  }
960 
961  C(i + 0, jj + j + 0) += alpha * vec_type::hadd(xmm1);
962  C(i + 0, jj + j + 1) += alpha * vec_type::hadd(xmm2);
963  C(i + 1, jj + j + 0) += alpha * vec_type::hadd(xmm3);
964  C(i + 1, jj + j + 1) += alpha * vec_type::hadd(xmm4);
965  C(i + 2, jj + j + 0) += alpha * vec_type::hadd(xmm5);
966  C(i + 2, jj + j + 1) += alpha * vec_type::hadd(xmm6);
967  C(i + 3, jj + j + 0) += alpha * vec_type::hadd(xmm7);
968  C(i + 3, jj + j + 1) += alpha * vec_type::hadd(xmm8);
969  C(i + 4, jj + j + 0) += alpha * vec_type::hadd(xmm9);
970  C(i + 4, jj + j + 1) += alpha * vec_type::hadd(xmm10);
971  }
972 
973  if (j < jblock) {
974  size_t k = 0;
975 
976  auto a1 = vec_type::load(A2M + (i + 0) * K_BLOCK + k);
977  auto a2 = vec_type::load(A2M + (i + 1) * K_BLOCK + k);
978  auto a3 = vec_type::load(A2M + (i + 2) * K_BLOCK + k);
979  auto a4 = vec_type::load(A2M + (i + 3) * K_BLOCK + k);
980  auto a5 = vec_type::load(A2M + (i + 4) * K_BLOCK + k);
981 
982  auto b1 = vec_type::load(B2M + j * K_BLOCK + k);
983 
984  auto xmm1 = vec_type::mul(a1, b1);
985  auto xmm2 = vec_type::mul(a2, b1);
986  auto xmm3 = vec_type::mul(a3, b1);
987  auto xmm4 = vec_type::mul(a4, b1);
988  auto xmm5 = vec_type::mul(a5, b1);
989 
990  for (k += vec_size; k < kblock; k += vec_size) {
991  a1 = vec_type::load(A2M + (i + 0) * K_BLOCK + k);
992  a2 = vec_type::load(A2M + (i + 1) * K_BLOCK + k);
993  a3 = vec_type::load(A2M + (i + 2) * K_BLOCK + k);
994  a4 = vec_type::load(A2M + (i + 3) * K_BLOCK + k);
995  a5 = vec_type::load(A2M + (i + 4) * K_BLOCK + k);
996 
997  b1 = vec_type::load(B2M + j * K_BLOCK + k);
998 
999  xmm1 = vec_type::fmadd(a1, b1, xmm1);
1000  xmm2 = vec_type::fmadd(a2, b1, xmm2);
1001  xmm3 = vec_type::fmadd(a3, b1, xmm3);
1002  xmm4 = vec_type::fmadd(a4, b1, xmm4);
1003  xmm5 = vec_type::fmadd(a5, b1, xmm5);
1004  }
1005 
1006  C(i + 0, jj + j) += alpha * vec_type::hadd(xmm1);
1007  C(i + 1, jj + j) += alpha * vec_type::hadd(xmm2);
1008  C(i + 2, jj + j) += alpha * vec_type::hadd(xmm3);
1009  C(i + 3, jj + j) += alpha * vec_type::hadd(xmm4);
1010  C(i + 4, jj + j) += alpha * vec_type::hadd(xmm5);
1011  }
1012  }
1013 
1014  for (; i + 1 < M; i += 2) {
1015  size_t j = 0;
1016 
1017  for (; j + 3 < jblock; j += 4) {
1018  size_t k = 0;
1019 
1020  auto a1 = vec_type::load(A2M + (i + 0) * K_BLOCK + k);
1021  auto a2 = vec_type::load(A2M + (i + 1) * K_BLOCK + k);
1022 
1023  auto b1 = vec_type::load(B2M + (j + 0) * K_BLOCK + k);
1024  auto b2 = vec_type::load(B2M + (j + 1) * K_BLOCK + k);
1025  auto b3 = vec_type::load(B2M + (j + 2) * K_BLOCK + k);
1026  auto b4 = vec_type::load(B2M + (j + 3) * K_BLOCK + k);
1027 
1028  auto xmm1 = vec_type::mul(a1, b1);
1029  auto xmm2 = vec_type::mul(a1, b2);
1030  auto xmm3 = vec_type::mul(a1, b3);
1031  auto xmm4 = vec_type::mul(a1, b4);
1032  auto xmm5 = vec_type::mul(a2, b1);
1033  auto xmm6 = vec_type::mul(a2, b2);
1034  auto xmm7 = vec_type::mul(a2, b3);
1035  auto xmm8 = vec_type::mul(a2, b4);
1036 
1037  for (k += vec_size; k < kblock; k += vec_size) {
1038  a1 = vec_type::load(A2M + (i + 0) * K_BLOCK + k);
1039  a2 = vec_type::load(A2M + (i + 1) * K_BLOCK + k);
1040 
1041  b1 = vec_type::load(B2M + (j + 0) * K_BLOCK + k);
1042  b2 = vec_type::load(B2M + (j + 1) * K_BLOCK + k);
1043  b3 = vec_type::load(B2M + (j + 2) * K_BLOCK + k);
1044  b4 = vec_type::load(B2M + (j + 3) * K_BLOCK + k);
1045 
1046  xmm1 = vec_type::fmadd(a1, b1, xmm1);
1047  xmm2 = vec_type::fmadd(a1, b2, xmm2);
1048  xmm3 = vec_type::fmadd(a1, b3, xmm3);
1049  xmm4 = vec_type::fmadd(a1, b4, xmm4);
1050 
1051  xmm5 = vec_type::fmadd(a2, b1, xmm5);
1052  xmm6 = vec_type::fmadd(a2, b2, xmm6);
1053  xmm7 = vec_type::fmadd(a2, b3, xmm7);
1054  xmm8 = vec_type::fmadd(a2, b2, xmm8);
1055  }
1056 
1057  C(i + 0, jj + j + 0) += alpha * vec_type::hadd(xmm1);
1058  C(i + 0, jj + j + 1) += alpha * vec_type::hadd(xmm2);
1059  C(i + 0, jj + j + 2) += alpha * vec_type::hadd(xmm3);
1060  C(i + 0, jj + j + 3) += alpha * vec_type::hadd(xmm4);
1061 
1062  C(i + 1, jj + j + 0) += alpha * vec_type::hadd(xmm5);
1063  C(i + 1, jj + j + 1) += alpha * vec_type::hadd(xmm6);
1064  C(i + 1, jj + j + 2) += alpha * vec_type::hadd(xmm7);
1065  C(i + 1, jj + j + 3) += alpha * vec_type::hadd(xmm8);
1066  }
1067 
1068  for (; j + 1 < jblock; j += 2) {
1069  size_t k = 0;
1070 
1071  auto a1 = vec_type::load(A2M + (i + 0) * K_BLOCK + k);
1072  auto a2 = vec_type::load(A2M + (i + 1) * K_BLOCK + k);
1073 
1074  auto b1 = vec_type::load(B2M + (j + 0) * K_BLOCK + k);
1075  auto b2 = vec_type::load(B2M + (j + 1) * K_BLOCK + k);
1076 
1077  auto xmm1 = vec_type::mul(a1, b1);
1078  auto xmm2 = vec_type::mul(a1, b2);
1079  auto xmm3 = vec_type::mul(a2, b1);
1080  auto xmm4 = vec_type::mul(a2, b2);
1081 
1082  for (k += vec_size; k < kblock; k += vec_size) {
1083  a1 = vec_type::load(A2M + (i + 0) * K_BLOCK + k);
1084  a2 = vec_type::load(A2M + (i + 1) * K_BLOCK + k);
1085 
1086  b1 = vec_type::load(B2M + (j + 0) * K_BLOCK + k);
1087  b2 = vec_type::load(B2M + (j + 1) * K_BLOCK + k);
1088 
1089  xmm1 = vec_type::fmadd(a1, b1, xmm1);
1090  xmm2 = vec_type::fmadd(a1, b2, xmm2);
1091 
1092  xmm3 = vec_type::fmadd(a2, b1, xmm3);
1093  xmm4 = vec_type::fmadd(a2, b2, xmm4);
1094  }
1095 
1096  C(i + 0, jj + j + 0) += alpha * vec_type::hadd(xmm1);
1097  C(i + 0, jj + j + 1) += alpha * vec_type::hadd(xmm2);
1098 
1099  C(i + 1, jj + j + 0) += alpha * vec_type::hadd(xmm3);
1100  C(i + 1, jj + j + 1) += alpha * vec_type::hadd(xmm4);
1101  }
1102 
1103  if (j < jblock) {
1104  size_t k = 0;
1105 
1106  auto a1 = vec_type::load(A2M + (i + 0) * K_BLOCK + k);
1107  auto a2 = vec_type::load(A2M + (i + 1) * K_BLOCK + k);
1108 
1109  auto b1 = vec_type::load(B2M + j * K_BLOCK + k);
1110 
1111  auto xmm1 = vec_type::mul(a1, b1);
1112  auto xmm2 = vec_type::mul(a2, b1);
1113 
1114  for (k += vec_size; k < kblock; k += vec_size) {
1115  a1 = vec_type::load(A2M + (i + 0) * K_BLOCK + k);
1116  a2 = vec_type::load(A2M + (i + 1) * K_BLOCK + k);
1117 
1118  b1 = vec_type::load(B2M + j * K_BLOCK + k);
1119 
1120  xmm1 = vec_type::fmadd(a1, b1, xmm1);
1121  xmm2 = vec_type::fmadd(a2, b1, xmm2);
1122  }
1123 
1124  C(i + 0, jj + j) += alpha * vec_type::hadd(xmm1);
1125  C(i + 1, jj + j) += alpha * vec_type::hadd(xmm2);
1126  }
1127  }
1128 
1129  if (i < M) {
1130  size_t j = 0;
1131 
1132  for (; j + 1 < jblock; j += 2) {
1133  size_t k = 0;
1134 
1135  auto a1 = vec_type::load(A2M + i * K_BLOCK + k);
1136 
1137  auto b1 = vec_type::load(B2M + (j + 0) * K_BLOCK + k);
1138  auto b2 = vec_type::load(B2M + (j + 1) * K_BLOCK + k);
1139 
1140  auto xmm1 = vec_type::mul(a1, b1);
1141  auto xmm2 = vec_type::mul(a1, b2);
1142 
1143  for (k += vec_size; k < kblock; k += vec_size) {
1144  a1 = vec_type::load(A2M + i * K_BLOCK + k);
1145 
1146  b1 = vec_type::load(B2M + (j + 0) * K_BLOCK + k);
1147  b2 = vec_type::load(B2M + (j + 1) * K_BLOCK + k);
1148 
1149  xmm1 = vec_type::fmadd(a1, b1, xmm1);
1150  xmm2 = vec_type::fmadd(a1, b2, xmm2);
1151  }
1152 
1153  C(i, jj + j + 0) += alpha * vec_type::hadd(xmm1);
1154  C(i, jj + j + 1) += alpha * vec_type::hadd(xmm2);
1155  }
1156 
1157  if (j < jblock) {
1158  size_t k = 0;
1159 
1160  auto a1 = vec_type::load(A2M + i * K_BLOCK + k);
1161 
1162  auto b1 = vec_type::load(B2M + j * K_BLOCK + k);
1163 
1164  auto xmm1 = vec_type::mul(a1, b1);
1165 
1166  for (k += vec_size; k < kblock; k += vec_size) {
1167  a1 = vec_type::load(A2M + i * K_BLOCK + k);
1168 
1169  b1 = vec_type::load(B2M + j * K_BLOCK + k);
1170 
1171  xmm1 = vec_type::fmadd(a1, b1, xmm1);
1172  }
1173 
1174  C(i, jj + j) += alpha * vec_type::hadd(xmm1);
1175  }
1176  }
1177  }
1178  }
1179 
1180  // Remainder loop
1181  if (kk < K) {
1182  const size_t kend = K - kk;
1183 
1184  // Copy A into A2
1185  for (size_t iii = 0; iii < M; ++iii) {
1186  for (size_t kkk = 0; kkk < kend; ++kkk) {
1187  A2(iii, kkk) = A(iii, kkk + kk);
1188  }
1189  }
1190 
1191  size_t jj = 0;
1192  size_t jblock = 0;
1193 
1194  for (; jj < jlast; jj += jblock) {
1195  jblock = jj + J_BLOCK <= jlast ? J_BLOCK : jlast - jj;
1196 
1197  // Copy B into B2
1198  for (size_t kkk = 0; kkk < kend; ++kkk) {
1199  for (size_t jjj = 0; jjj < jblock; ++jjj) {
1200  B2(kkk, jjj) = B(kkk + kk, jjj + jj);
1201  }
1202  }
1203 
1204  size_t i = 0;
1205 
1206  for (; i + 4 < M; i += 5) {
1207  size_t j = 0;
1208 
1209  for (; j + 1 < jblock; j += 2) {
1210  for (size_t k = 0; k < kend; ++k) {
1211  C(i + 0, jj + j + 0) += alpha * A2(i + 0, k) * B2(k, j + 0);
1212  C(i + 0, jj + j + 1) += alpha * A2(i + 0, k) * B2(k, j + 1);
1213  C(i + 1, jj + j + 0) += alpha * A2(i + 1, k) * B2(k, j + 0);
1214  C(i + 1, jj + j + 1) += alpha * A2(i + 1, k) * B2(k, j + 1);
1215  C(i + 2, jj + j + 0) += alpha * A2(i + 2, k) * B2(k, j + 0);
1216  C(i + 2, jj + j + 1) += alpha * A2(i + 2, k) * B2(k, j + 1);
1217  C(i + 3, jj + j + 0) += alpha * A2(i + 3, k) * B2(k, j + 0);
1218  C(i + 3, jj + j + 1) += alpha * A2(i + 3, k) * B2(k, j + 1);
1219  C(i + 4, jj + j + 0) += alpha * A2(i + 4, k) * B2(k, j + 0);
1220  C(i + 4, jj + j + 1) += alpha * A2(i + 4, k) * B2(k, j + 1);
1221  }
1222  }
1223 
1224  if (j < jblock) {
1225  for (size_t k = 0; k < kend; ++k) {
1226  C(i + 0, jj + j) += alpha * A2(i + 0, k) * B2(k, j);
1227  C(i + 1, jj + j) += alpha * A2(i + 1, k) * B2(k, j);
1228  C(i + 2, jj + j) += alpha * A2(i + 2, k) * B2(k, j);
1229  C(i + 3, jj + j) += alpha * A2(i + 3, k) * B2(k, j);
1230  C(i + 4, jj + j) += alpha * A2(i + 4, k) * B2(k, j);
1231  }
1232  }
1233  }
1234 
1235  for (; i + 1 < M; i += 2) {
1236  size_t j = 0;
1237 
1238  for (; j + 1 < jblock; j += 2) {
1239  for (size_t k = 0; k < kend; ++k) {
1240  C(i + 0, jj + j + 0) += alpha * A2(i + 0, k) * B2(k, j + 0);
1241  C(i + 0, jj + j + 1) += alpha * A2(i + 0, k) * B2(k, j + 1);
1242  C(i + 1, jj + j + 0) += alpha * A2(i + 1, k) * B2(k, j + 0);
1243  C(i + 1, jj + j + 1) += alpha * A2(i + 1, k) * B2(k, j + 1);
1244  }
1245  }
1246 
1247  if (j < jblock) {
1248  for (size_t k = 0; k < kend; ++k) {
1249  C(i + 0, jj + j) += alpha * A2(i + 0, k) * B2(k, j);
1250  C(i + 1, jj + j) += alpha * A2(i + 1, k) * B2(k, j);
1251  }
1252  }
1253  }
1254 
1255  if (i < M) {
1256  size_t j = 0;
1257 
1258  for (; j + 1 < jblock; j += 2) {
1259  for (size_t k = 0; k < kend; ++k) {
1260  C(i, jj + j + 0) += alpha * A2(i, k) * B2(k, j + 0);
1261  C(i, jj + j + 1) += alpha * A2(i, k) * B2(k, j + 1);
1262  }
1263  }
1264 
1265  if (j < jblock) {
1266  for (size_t k = 0; k < kend; ++k) {
1267  C(i, jj + j) += alpha * A2(i, k) * B2(k, j);
1268  }
1269  }
1270  }
1271  }
1272  }
1273  };
1274 
1275  engine_dispatch_1d(batch_fun_j, 0, N, J_BLOCK);
1276 }
1277 
1290 template <typename T>
1291 void gemm_rr_to_r(const T* a, const T* b, T* c, size_t M, size_t N, size_t K, T alpha) {
1292  cpp_assert(vec_enabled, "At least one vector mode must be enabled for impl::VEC");
1293  cpp_assert(vectorize_impl, "vectorize_impl must be enabled for impl::VEC");
1294 
1295  // Dispatch to the best kernel
1296 
1297  if (K * N <= gemm_rr_small_threshold) {
1298  gemm_small_kernel_rr_to_r<default_vec>(a, b, c, M, N, K, alpha);
1299  } else if (K * N <= gemm_rr_medium_threshold) {
1300  gemm_large_kernel_rr_to_r<default_vec>(a, b, c, M, N, K, alpha, T(0));
1301  } else {
1302  gemm_large_kernel_rr_to_r_temp<default_vec>(a, b, c, M, N, K, alpha, T(0));
1303  }
1304 }
1305 
1306 } //end of namespace etl::impl::vec
void engine_dispatch_1d(Functor &&functor, size_t first, size_t last, [[maybe_unused]] size_t threshold, [[maybe_unused]] size_t n_threads=etl::threads)
Dispatch the elements of a range to a functor in a parallel manner, using the global thread engine...
Definition: parallel_support.hpp:708
constexpr size_t gemm_rr_medium_threshold
The number of elements of B after which we use BLAS-like kernel (for GEMM)
Definition: threshold.hpp:56
Definition: bias_add.hpp:15
constexpr size_t gemm_rr_small_threshold
The number of elements of B after which we use BLAS-like kernel (for GEMM)
Definition: threshold.hpp:55
constexpr bool vectorize_impl
Indicates if the implementations can be automatically vectorized by ETL.
Definition: config.hpp:35
constexpr bool vec_enabled
Indicates if vectorization is available in any format.
Definition: config.hpp:220
typename V::template vec_type< value_type > vec_type
The vectorization type for V.
Definition: dyn_matrix_view.hpp:43
void gemm_large_kernel_rr_to_r(const T *a, const T *b, T *ETL_RESTRICT c, size_t M, size_t N, size_t K, T alpha, T beta)
Optimized version of large GEMM for row major version.
Definition: gemm_rr_to_r.hpp:601
auto load(size_t x) const noexcept
Load several elements of the expression at once.
Definition: dyn_matrix_view.hpp:143
void gemm_large_kernel_rr_to_r_temp(const T *a, const T *b, T *ETL_RESTRICT c, size_t M, size_t N, size_t K, T alpha, T beta)
Optimized version of large GEMM for row major version.
Definition: gemm_rr_to_r.hpp:855
void storeu(vec_type< V > in, size_t i) noexcept
Store several elements in the matrix at once.
Definition: dyn_matrix_view.hpp:187
Matrix with run-time fixed dimensions.
Definition: dyn.hpp:26
auto loadu(size_t x) const noexcept
Load several elements of the expression at once.
Definition: dyn_matrix_view.hpp:154
auto min(L &&lhs, R &&rhs)
Create an expression with the min value of lhs or rhs.
Definition: expression_builder.hpp:77
void gemm_small_kernel_rr_to_r(const T *a, const T *b, T *ETL_RESTRICT c, size_t M, size_t N, size_t K, T alpha)
Optimized version of small GEMM for row major version.
Definition: gemm_rr_to_r.hpp:35
Matrix with run-time fixed dimensions.
Definition: custom_dyn.hpp:27
void gemm_rr_to_r(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Vectorized implementation of row-major matrix - row-major matrix multiplication and assignment into a...
Definition: gemm_rr_to_r.hpp:1291