Expression Templates Library (ETL)
gemm_cr_to_r.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
14 #pragma once
15 
16 namespace etl::impl::vec {
17 
26 template <typename V, typename T>
27 void gemm_small_kernel_cr_to_r(const T* a, const T* b, T* c, size_t M, size_t N, size_t K, T alpha) {
28  using vec_type = V;
29 
30  static constexpr size_t vec_size = vec_type::template traits<T>::size;
31 
32  auto alpha_vec = vec_type::set(alpha);
33 
34  const auto j_end = prev_multiple(N, vec_size);
35 
36  size_t j = 0;
37 
38  for (; j + 7 * vec_size < j_end; j += 8 * vec_size) {
39  size_t i = 0;
40 
41  for (; i < M; i++) {
42  auto r11 = vec_type::template zero<T>();
43  auto r12 = vec_type::template zero<T>();
44  auto r13 = vec_type::template zero<T>();
45  auto r14 = vec_type::template zero<T>();
46  auto r15 = vec_type::template zero<T>();
47  auto r16 = vec_type::template zero<T>();
48  auto r17 = vec_type::template zero<T>();
49  auto r18 = vec_type::template zero<T>();
50 
51  for (size_t k = 0; k < K; ++k) {
52  auto a1 = vec_type::set(a[(i + 0) + k * M]);
53 
54  auto b1 = vec_type::loadu(b + k * N + j + 0 * vec_size);
55  auto b2 = vec_type::loadu(b + k * N + j + 1 * vec_size);
56  auto b3 = vec_type::loadu(b + k * N + j + 2 * vec_size);
57  auto b4 = vec_type::loadu(b + k * N + j + 3 * vec_size);
58  auto b5 = vec_type::loadu(b + k * N + j + 4 * vec_size);
59  auto b6 = vec_type::loadu(b + k * N + j + 5 * vec_size);
60  auto b7 = vec_type::loadu(b + k * N + j + 6 * vec_size);
61  auto b8 = vec_type::loadu(b + k * N + j + 7 * vec_size);
62 
63  r11 = vec_type::fmadd(a1, b1, r11);
64  r12 = vec_type::fmadd(a1, b2, r12);
65  r13 = vec_type::fmadd(a1, b3, r13);
66  r14 = vec_type::fmadd(a1, b4, r14);
67  r15 = vec_type::fmadd(a1, b5, r15);
68  r16 = vec_type::fmadd(a1, b6, r16);
69  r17 = vec_type::fmadd(a1, b7, r17);
70  r18 = vec_type::fmadd(a1, b8, r18);
71  }
72 
73  vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
74  vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r12));
75  vec_type::storeu(c + (i + 0) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r13));
76  vec_type::storeu(c + (i + 0) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r14));
77  vec_type::storeu(c + (i + 0) * N + j + 4 * vec_size, vec_type::mul(alpha_vec, r15));
78  vec_type::storeu(c + (i + 0) * N + j + 5 * vec_size, vec_type::mul(alpha_vec, r16));
79  vec_type::storeu(c + (i + 0) * N + j + 6 * vec_size, vec_type::mul(alpha_vec, r17));
80  vec_type::storeu(c + (i + 0) * N + j + 7 * vec_size, vec_type::mul(alpha_vec, r18));
81  }
82  }
83 
84  for (; j + 3 * vec_size < j_end; j += 4 * vec_size) {
85  size_t i = 0;
86 
87  for (; i + 1 < M; i += 2) {
88  auto r11 = vec_type::template zero<T>();
89  auto r21 = vec_type::template zero<T>();
90 
91  auto r12 = vec_type::template zero<T>();
92  auto r22 = vec_type::template zero<T>();
93 
94  auto r13 = vec_type::template zero<T>();
95  auto r23 = vec_type::template zero<T>();
96 
97  auto r14 = vec_type::template zero<T>();
98  auto r24 = vec_type::template zero<T>();
99 
100  for (size_t k = 0; k < K; ++k) {
101  auto a1 = vec_type::set(a[(i + 0) + k * M]);
102  auto a2 = vec_type::set(a[(i + 1) + k * M]);
103 
104  auto b1 = vec_type::loadu(b + k * N + j + 0 * vec_size);
105  auto b2 = vec_type::loadu(b + k * N + j + 1 * vec_size);
106  auto b3 = vec_type::loadu(b + k * N + j + 2 * vec_size);
107  auto b4 = vec_type::loadu(b + k * N + j + 3 * vec_size);
108 
109  r11 = vec_type::fmadd(a1, b1, r11);
110  r21 = vec_type::fmadd(a2, b1, r21);
111 
112  r12 = vec_type::fmadd(a1, b2, r12);
113  r22 = vec_type::fmadd(a2, b2, r22);
114 
115  r13 = vec_type::fmadd(a1, b3, r13);
116  r23 = vec_type::fmadd(a2, b3, r23);
117 
118  r14 = vec_type::fmadd(a1, b4, r14);
119  r24 = vec_type::fmadd(a2, b4, r24);
120  }
121 
122  vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
123  vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r21));
124 
125  vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r12));
126  vec_type::storeu(c + (i + 1) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r22));
127 
128  vec_type::storeu(c + (i + 0) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r13));
129  vec_type::storeu(c + (i + 1) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r23));
130 
131  vec_type::storeu(c + (i + 0) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r14));
132  vec_type::storeu(c + (i + 1) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r24));
133  }
134 
135  for (; i < M; i++) {
136  auto r11 = vec_type::template zero<T>();
137  auto r12 = vec_type::template zero<T>();
138  auto r13 = vec_type::template zero<T>();
139  auto r14 = vec_type::template zero<T>();
140 
141  for (size_t k = 0; k < K; ++k) {
142  auto a1 = vec_type::set(a[(i + 0) + k * M]);
143 
144  auto b1 = vec_type::loadu(b + k * N + j + 0 * vec_size);
145  auto b2 = vec_type::loadu(b + k * N + j + 1 * vec_size);
146  auto b3 = vec_type::loadu(b + k * N + j + 2 * vec_size);
147  auto b4 = vec_type::loadu(b + k * N + j + 3 * vec_size);
148 
149  r11 = vec_type::fmadd(a1, b1, r11);
150  r12 = vec_type::fmadd(a1, b2, r12);
151  r13 = vec_type::fmadd(a1, b3, r13);
152  r14 = vec_type::fmadd(a1, b4, r14);
153  }
154 
155  vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
156  vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r12));
157  vec_type::storeu(c + (i + 0) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r13));
158  vec_type::storeu(c + (i + 0) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r14));
159  }
160  }
161 
162  for (; j + 1 * vec_size < j_end; j += 2 * vec_size) {
163  size_t i = 0;
164 
165  for (; i + 1 < M; i += 2) {
166  auto r11 = vec_type::template zero<T>();
167  auto r21 = vec_type::template zero<T>();
168 
169  auto r12 = vec_type::template zero<T>();
170  auto r22 = vec_type::template zero<T>();
171 
172  for (size_t k = 0; k < K; ++k) {
173  auto a1 = vec_type::set(a[(i + 0) + k * M]);
174  auto a2 = vec_type::set(a[(i + 1) + k * M]);
175 
176  auto b1 = vec_type::loadu(b + k * N + j + 0 * vec_size);
177  auto b2 = vec_type::loadu(b + k * N + j + 1 * vec_size);
178 
179  r11 = vec_type::fmadd(a1, b1, r11);
180  r21 = vec_type::fmadd(a2, b1, r21);
181 
182  r12 = vec_type::fmadd(a1, b2, r12);
183  r22 = vec_type::fmadd(a2, b2, r22);
184  }
185 
186  vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
187  vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r21));
188 
189  vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r12));
190  vec_type::storeu(c + (i + 1) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r22));
191  }
192 
193  for (; i < M; i++) {
194  auto r11 = vec_type::template zero<T>();
195  auto r12 = vec_type::template zero<T>();
196 
197  for (size_t k = 0; k < K; ++k) {
198  auto a1 = vec_type::set(a[(i + 0) + k * M]);
199 
200  auto b1 = vec_type::loadu(b + k * N + j + 0 * vec_size);
201  auto b2 = vec_type::loadu(b + k * N + j + 1 * vec_size);
202 
203  r11 = vec_type::fmadd(a1, b1, r11);
204  r12 = vec_type::fmadd(a1, b2, r12);
205  }
206 
207  vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
208  vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r12));
209  }
210  }
211 
212  for (; j < j_end; j += vec_size) {
213  size_t i = 0;
214 
215  for (; i + 1 < M; i += 2) {
216  auto r11 = vec_type::template zero<T>();
217  auto r21 = vec_type::template zero<T>();
218 
219  for (size_t k = 0; k < K; ++k) {
220  auto a1 = vec_type::set(a[(i + 0) + k * M]);
221  auto a2 = vec_type::set(a[(i + 1) + k * M]);
222 
223  auto b1 = vec_type::loadu(b + k * N + j + 0 * vec_size);
224 
225  r11 = vec_type::fmadd(a1, b1, r11);
226  r21 = vec_type::fmadd(a2, b1, r21);
227  }
228 
229  vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
230  vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r21));
231  }
232 
233  if (i < M) {
234  auto r11 = vec_type::template zero<T>();
235 
236  for (size_t k = 0; k < K; ++k) {
237  auto a1 = vec_type::set(a[i + k * M]);
238 
239  auto b1 = vec_type::loadu(b + k * N + j + 0 * vec_size);
240 
241  r11 = vec_type::fmadd(a1, b1, r11);
242  }
243 
244  vec_type::storeu(c + i * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
245  }
246  }
247 
248  for (; j < N; ++j) {
249  size_t i = 0;
250 
251  for (; i + 1 < M; i += 2) {
252  auto r1 = T();
253  auto r2 = T();
254 
255  for (size_t k = 0; k < K; ++k) {
256  r1 += a[(i + 0) + k * M] * b[k * N + j];
257  r2 += a[(i + 1) + k * M] * b[k * N + j];
258  }
259 
260  c[(i + 0) * N + j] = alpha * r1;
261  c[(i + 1) * N + j] = alpha * r2;
262  }
263 
264  if (i < M) {
265  auto r1 = T();
266 
267  for (size_t k = 0; k < K; ++k) {
268  r1 += a[i + k * M] * b[k * N + j];
269  }
270 
271  c[i * N + j] = alpha * r1;
272  }
273  }
274 }
275 
284 template <typename V, typename T>
285 void gemm_large_kernel_cr_to_r(const T* a, const T* b, T* c, size_t M, size_t N, size_t K, T alpha) {
286  using vec_type = V;
287 
288  static constexpr size_t vec_size = vec_type::template traits<T>::size;
289 
290  constexpr size_t n_block_size = 128UL;
291  constexpr size_t m_block_size = 64UL;
292  constexpr size_t k_block_size = 128UL;
293 
294  auto alpha_vec = vec_type::set(alpha);
295 
296  for (size_t jj = 0; jj < N; jj += n_block_size) {
297  const size_t j_end_a = std::min(jj + n_block_size, N);
298  const auto j_end = prev_multiple(j_end_a, vec_size);
299 
300  for (size_t ii = 0; ii < M; ii += m_block_size) {
301  const size_t i_end = std::min(ii + m_block_size, M);
302 
303  for (size_t kk = 0; kk < K; kk += k_block_size) {
304  const size_t k_end = std::min(kk + k_block_size, K);
305 
306  size_t j = jj;
307 
308 #ifdef __clang__
309  for (; j + 3 * vec_size < j_end; j += 4 * vec_size) {
310  size_t i = ii;
311 
312  for (; i + 1 < i_end; i += 2) {
313  auto r11 = vec_type::loadu(c + (i + 0) * N + j + 0 * vec_size);
314  auto r12 = vec_type::loadu(c + (i + 0) * N + j + 1 * vec_size);
315  auto r13 = vec_type::loadu(c + (i + 0) * N + j + 2 * vec_size);
316  auto r14 = vec_type::loadu(c + (i + 0) * N + j + 3 * vec_size);
317 
318  auto r21 = vec_type::loadu(c + (i + 1) * N + j + 0 * vec_size);
319  auto r22 = vec_type::loadu(c + (i + 1) * N + j + 1 * vec_size);
320  auto r23 = vec_type::loadu(c + (i + 1) * N + j + 2 * vec_size);
321  auto r24 = vec_type::loadu(c + (i + 1) * N + j + 3 * vec_size);
322 
323  for (size_t k = kk; k < k_end; ++k) {
324  auto a1 = vec_type::set(a[(i + 0) + k * M]);
325  auto a2 = vec_type::set(a[(i + 1) + k * M]);
326 
327  auto b1 = vec_type::loadu(b + k * N + j + 0 * vec_size);
328  auto b2 = vec_type::loadu(b + k * N + j + 1 * vec_size);
329  auto b3 = vec_type::loadu(b + k * N + j + 2 * vec_size);
330  auto b4 = vec_type::loadu(b + k * N + j + 3 * vec_size);
331 
332  r11 = vec_type::fmadd(a1, b1, r11);
333  r12 = vec_type::fmadd(a1, b2, r12);
334  r13 = vec_type::fmadd(a1, b3, r13);
335  r14 = vec_type::fmadd(a1, b4, r14);
336 
337  r21 = vec_type::fmadd(a2, b1, r21);
338  r22 = vec_type::fmadd(a2, b2, r22);
339  r23 = vec_type::fmadd(a2, b3, r23);
340  r24 = vec_type::fmadd(a2, b4, r24);
341  }
342 
343  vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
344  vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r12));
345  vec_type::storeu(c + (i + 0) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r13));
346  vec_type::storeu(c + (i + 0) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r14));
347 
348  vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r21));
349  vec_type::storeu(c + (i + 1) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r22));
350  vec_type::storeu(c + (i + 1) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r23));
351  vec_type::storeu(c + (i + 1) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r24));
352  }
353 
354  if (i < i_end) {
355  auto r11 = vec_type::loadu(c + (i + 0) * N + j + 0 * vec_size);
356  auto r12 = vec_type::loadu(c + (i + 0) * N + j + 1 * vec_size);
357  auto r13 = vec_type::loadu(c + (i + 0) * N + j + 2 * vec_size);
358  auto r14 = vec_type::loadu(c + (i + 0) * N + j + 3 * vec_size);
359 
360  for (size_t k = kk; k < k_end; ++k) {
361  auto a1 = vec_type::set(a[(i + 0) + k * M]);
362 
363  auto b1 = vec_type::loadu(b + k * N + j + 0 * vec_size);
364  auto b2 = vec_type::loadu(b + k * N + j + 1 * vec_size);
365  auto b3 = vec_type::loadu(b + k * N + j + 2 * vec_size);
366  auto b4 = vec_type::loadu(b + k * N + j + 3 * vec_size);
367 
368  r11 = vec_type::fmadd(a1, b1, r11);
369  r12 = vec_type::fmadd(a1, b2, r12);
370  r13 = vec_type::fmadd(a1, b3, r13);
371  r14 = vec_type::fmadd(a1, b4, r14);
372  }
373 
374  vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
375  vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r12));
376  vec_type::storeu(c + (i + 0) * N + j + 2 * vec_size, vec_type::mul(alpha_vec, r13));
377  vec_type::storeu(c + (i + 0) * N + j + 3 * vec_size, vec_type::mul(alpha_vec, r14));
378  }
379  }
380 #endif
381 
382  for (; j + vec_size < j_end; j += 2 * vec_size) {
383  size_t i = ii;
384 
385  for (; i + 3 < i_end; i += 4) {
386  auto r11 = vec_type::loadu(c + (i + 0) * N + j + 0 * vec_size);
387  auto r12 = vec_type::loadu(c + (i + 0) * N + j + 1 * vec_size);
388 
389  auto r21 = vec_type::loadu(c + (i + 1) * N + j + 0 * vec_size);
390  auto r22 = vec_type::loadu(c + (i + 1) * N + j + 1 * vec_size);
391 
392  auto r31 = vec_type::loadu(c + (i + 2) * N + j + 0 * vec_size);
393  auto r32 = vec_type::loadu(c + (i + 2) * N + j + 1 * vec_size);
394 
395  auto r41 = vec_type::loadu(c + (i + 3) * N + j + 0 * vec_size);
396  auto r42 = vec_type::loadu(c + (i + 3) * N + j + 1 * vec_size);
397 
398  for (size_t k = kk; k < k_end; ++k) {
399  auto a1 = vec_type::set(a[(i + 0) + k * M]);
400  auto a2 = vec_type::set(a[(i + 1) + k * M]);
401  auto a3 = vec_type::set(a[(i + 2) + k * M]);
402  auto a4 = vec_type::set(a[(i + 3) + k * M]);
403 
404  auto b1 = vec_type::loadu(b + k * N + j + 0 * vec_size);
405  auto b2 = vec_type::loadu(b + k * N + j + 1 * vec_size);
406 
407  r11 = vec_type::fmadd(a1, b1, r11);
408  r12 = vec_type::fmadd(a1, b2, r12);
409 
410  r21 = vec_type::fmadd(a2, b1, r21);
411  r22 = vec_type::fmadd(a2, b2, r22);
412 
413  r31 = vec_type::fmadd(a3, b1, r31);
414  r32 = vec_type::fmadd(a3, b2, r32);
415 
416  r41 = vec_type::fmadd(a4, b1, r41);
417  r42 = vec_type::fmadd(a4, b2, r42);
418  }
419 
420  vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
421  vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r12));
422 
423  vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r21));
424  vec_type::storeu(c + (i + 1) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r22));
425 
426  vec_type::storeu(c + (i + 2) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r31));
427  vec_type::storeu(c + (i + 2) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r32));
428 
429  vec_type::storeu(c + (i + 3) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r41));
430  vec_type::storeu(c + (i + 3) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r42));
431  }
432 
433  for (; i + 1 < i_end; i += 2) {
434  auto r11 = vec_type::loadu(c + (i + 0) * N + j + 0 * vec_size);
435  auto r12 = vec_type::loadu(c + (i + 0) * N + j + 1 * vec_size);
436 
437  auto r21 = vec_type::loadu(c + (i + 1) * N + j + 0 * vec_size);
438  auto r22 = vec_type::loadu(c + (i + 1) * N + j + 1 * vec_size);
439 
440  for (size_t k = kk; k < k_end; ++k) {
441  auto a1 = vec_type::set(a[(i + 0) + k * M]);
442  auto a2 = vec_type::set(a[(i + 1) + k * M]);
443 
444  auto b1 = vec_type::loadu(b + k * N + j + 0 * vec_size);
445  auto b2 = vec_type::loadu(b + k * N + j + 1 * vec_size);
446 
447  r11 = vec_type::fmadd(a1, b1, r11);
448  r12 = vec_type::fmadd(a1, b2, r12);
449 
450  r21 = vec_type::fmadd(a2, b1, r21);
451  r22 = vec_type::fmadd(a2, b2, r22);
452  }
453 
454  vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
455  vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r12));
456 
457  vec_type::storeu(c + (i + 1) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r21));
458  vec_type::storeu(c + (i + 1) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r22));
459  }
460 
461  if (i < i_end) {
462  auto r11 = vec_type::loadu(c + (i + 0) * N + j + 0 * vec_size);
463  auto r12 = vec_type::loadu(c + (i + 0) * N + j + 1 * vec_size);
464 
465  for (size_t k = kk; k < k_end; ++k) {
466  auto a1 = vec_type::set(a[(i + 0) + k * M]);
467 
468  auto b1 = vec_type::loadu(b + k * N + j + 0 * vec_size);
469  auto b2 = vec_type::loadu(b + k * N + j + 1 * vec_size);
470 
471  r11 = vec_type::fmadd(a1, b1, r11);
472  r12 = vec_type::fmadd(a1, b2, r12);
473  }
474 
475  vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
476  vec_type::storeu(c + (i + 0) * N + j + 1 * vec_size, vec_type::mul(alpha_vec, r12));
477  }
478  }
479 
480  for (; j < j_end; j += vec_size) {
481  for (size_t i = ii; i < i_end; ++i) {
482  auto r11 = vec_type::loadu(c + (i + 0) * N + j + 0 * vec_size);
483 
484  for (size_t k = kk; k < k_end; ++k) {
485  auto a1 = vec_type::set(a[(i + 0) + k * M]);
486 
487  auto b1 = vec_type::loadu(b + k * N + j + 0 * vec_size);
488 
489  r11 = vec_type::fmadd(a1, b1, r11);
490  }
491 
492  vec_type::storeu(c + (i + 0) * N + j + 0 * vec_size, vec_type::mul(alpha_vec, r11));
493  }
494  }
495 
496  for (; j < j_end_a; ++j) {
497  for (size_t i = ii; i < i_end; ++i) {
498  auto r11 = c[(i + 0) * N + j];
499 
500  for (size_t k = kk; k < k_end; ++k) {
501  r11 += a[(i + 0) + k * M] * b[k * N + j];
502  }
503 
504  c[(i + 0) * N + j] = alpha * r11;
505  }
506  }
507  }
508  }
509  }
510 }
511 
524 template <typename T>
525 void gemm_cr_to_r(const T* a, const T* b, T* c, size_t M, size_t N, size_t K, T alpha) {
526  cpp_assert(vec_enabled, "At least one vector mode must be enabled for impl::VEC");
527  cpp_assert(vectorize_impl, "vectorize_impl must be enabled for impl::VEC");
528 
529  if (M * N <= gemm_rr_small_threshold) {
530  gemm_small_kernel_cr_to_r<default_vec>(a, b, c, M, N, K, alpha);
531  } else {
532  direct_fill_n(c, M * N, T(0));
533  gemm_large_kernel_cr_to_r<default_vec>(a, b, c, M, N, K, alpha);
534  }
535 }
536 
537 } //end of namespace etl::impl::vec
Definition: bias_add.hpp:15
constexpr size_t gemm_rr_small_threshold
The number of elements of B after which we use BLAS-like kernel (for GEMM)
Definition: threshold.hpp:55
constexpr bool vectorize_impl
Indicates if the implementations can be automatically vectorized by ETL.
Definition: config.hpp:35
constexpr bool vec_enabled
Indicates if vectorization is available in any format.
Definition: config.hpp:220
typename V::template vec_type< value_type > vec_type
The vectorization type for V.
Definition: dyn_matrix_view.hpp:43
void gemm_cr_to_r(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Vectorized implementation of column-major matrix - row-major matrix multiplication and assignment int...
Definition: gemm_cr_to_r.hpp:525
void storeu(vec_type< V > in, size_t i) noexcept
Store several elements in the matrix at once.
Definition: dyn_matrix_view.hpp:187
void gemm_large_kernel_cr_to_r(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Optimized version of GEMM for assignment of a large Column-Major Matrix - Row Major Matrix to a Row M...
Definition: gemm_cr_to_r.hpp:285
auto loadu(size_t x) const noexcept
Load several elements of the expression at once.
Definition: dyn_matrix_view.hpp:154
auto min(L &&lhs, R &&rhs)
Create an expression with the min value of lhs or rhs.
Definition: expression_builder.hpp:77
void gemm_small_kernel_cr_to_r(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Optimized version of GEMM for assignment of a small Column-Major Matrix - Row Major Matrix to a Row M...
Definition: gemm_cr_to_r.hpp:27
void direct_fill_n(S *first, size_t n, T value)
Fills the given memory with the given value.
Definition: memory.hpp:57