Expression Templates Library (ETL)
gemm_cr_to_c.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
14 #pragma once
15 
16 namespace etl::impl::vec {
17 
26 template <typename V, typename T>
27 void gemm_small_kernel_cr_to_c(const T* a, const T* b, T* c, size_t M, size_t N, size_t K, T alpha) {
28  using vec_type = V;
29 
30  static constexpr size_t vec_size = vec_type::template traits<T>::size;
31 
32  auto alpha_vec = vec_type::set(alpha);
33 
34  const auto i_end = prev_multiple(M, vec_size);
35 
36  size_t i = 0;
37 
38  for (; i + 3 * vec_size < i_end; i += 4 * vec_size) {
39  size_t j = 0;
40 
41  for (; j + 1 < N; j += 2) {
42  auto r11 = vec_type::template zero<T>();
43  auto r21 = vec_type::template zero<T>();
44  auto r31 = vec_type::template zero<T>();
45  auto r41 = vec_type::template zero<T>();
46 
47  auto r12 = vec_type::template zero<T>();
48  auto r22 = vec_type::template zero<T>();
49  auto r32 = vec_type::template zero<T>();
50  auto r42 = vec_type::template zero<T>();
51 
52  for (size_t k = 0; k < K; ++k) {
53  auto a1 = vec_type::loadu(a + i + k * M + 0 * vec_size);
54  auto a2 = vec_type::loadu(a + i + k * M + 1 * vec_size);
55  auto a3 = vec_type::loadu(a + i + k * M + 2 * vec_size);
56  auto a4 = vec_type::loadu(a + i + k * M + 3 * vec_size);
57 
58  auto b1 = vec_type::set(b[k * N + (j + 0)]);
59  auto b2 = vec_type::set(b[k * N + (j + 1)]);
60 
61  r11 = vec_type::fmadd(a1, b1, r11);
62  r21 = vec_type::fmadd(a2, b1, r21);
63  r31 = vec_type::fmadd(a3, b1, r31);
64  r41 = vec_type::fmadd(a4, b1, r41);
65 
66  r12 = vec_type::fmadd(a1, b2, r12);
67  r22 = vec_type::fmadd(a2, b2, r22);
68  r32 = vec_type::fmadd(a3, b2, r32);
69  r42 = vec_type::fmadd(a4, b2, r42);
70  }
71 
72  vec_type::storeu(c + i + (j + 0) * M + 0 * vec_size, vec_type::mul(alpha_vec, r11));
73  vec_type::storeu(c + i + (j + 0) * M + 1 * vec_size, vec_type::mul(alpha_vec, r21));
74  vec_type::storeu(c + i + (j + 0) * M + 2 * vec_size, vec_type::mul(alpha_vec, r31));
75  vec_type::storeu(c + i + (j + 0) * M + 3 * vec_size, vec_type::mul(alpha_vec, r41));
76 
77  vec_type::storeu(c + i + (j + 1) * M + 0 * vec_size, vec_type::mul(alpha_vec, r12));
78  vec_type::storeu(c + i + (j + 1) * M + 1 * vec_size, vec_type::mul(alpha_vec, r22));
79  vec_type::storeu(c + i + (j + 1) * M + 2 * vec_size, vec_type::mul(alpha_vec, r32));
80  vec_type::storeu(c + i + (j + 1) * M + 3 * vec_size, vec_type::mul(alpha_vec, r42));
81  }
82 
83  if (j < N) {
84  auto r11 = vec_type::template zero<T>();
85  auto r21 = vec_type::template zero<T>();
86  auto r31 = vec_type::template zero<T>();
87  auto r41 = vec_type::template zero<T>();
88 
89  for (size_t k = 0; k < K; ++k) {
90  auto a1 = vec_type::loadu(a + i + k * M + 0 * vec_size);
91  auto a2 = vec_type::loadu(a + i + k * M + 1 * vec_size);
92  auto a3 = vec_type::loadu(a + i + k * M + 2 * vec_size);
93  auto a4 = vec_type::loadu(a + i + k * M + 3 * vec_size);
94 
95  auto b1 = vec_type::set(b[k * N + j]);
96 
97  r11 = vec_type::fmadd(a1, b1, r11);
98  r21 = vec_type::fmadd(a2, b1, r21);
99  r31 = vec_type::fmadd(a3, b1, r31);
100  r41 = vec_type::fmadd(a4, b1, r41);
101  }
102 
103  vec_type::storeu(c + i + j * M + 0 * vec_size, vec_type::mul(alpha_vec, r11));
104  vec_type::storeu(c + i + j * M + 1 * vec_size, vec_type::mul(alpha_vec, r21));
105  vec_type::storeu(c + i + j * M + 2 * vec_size, vec_type::mul(alpha_vec, r31));
106  vec_type::storeu(c + i + j * M + 3 * vec_size, vec_type::mul(alpha_vec, r41));
107  }
108  }
109 
110  for (; i + 1 * vec_size < i_end; i += 2 * vec_size) {
111  size_t j = 0;
112 
113  for (; j + 1 < N; j += 2) {
114  auto r11 = vec_type::template zero<T>();
115  auto r21 = vec_type::template zero<T>();
116 
117  auto r12 = vec_type::template zero<T>();
118  auto r22 = vec_type::template zero<T>();
119 
120  for (size_t k = 0; k < K; ++k) {
121  auto a1 = vec_type::loadu(a + i + k * M + 0 * vec_size);
122  auto a2 = vec_type::loadu(a + i + k * M + 1 * vec_size);
123 
124  auto b1 = vec_type::set(b[k * N + (j + 0)]);
125  auto b2 = vec_type::set(b[k * N + (j + 1)]);
126 
127  r11 = vec_type::fmadd(a1, b1, r11);
128  r21 = vec_type::fmadd(a2, b1, r21);
129 
130  r12 = vec_type::fmadd(a1, b2, r12);
131  r22 = vec_type::fmadd(a2, b2, r22);
132  }
133 
134  vec_type::storeu(c + i + (j + 0) * M + 0 * vec_size, vec_type::mul(alpha_vec, r11));
135  vec_type::storeu(c + i + (j + 0) * M + 1 * vec_size, vec_type::mul(alpha_vec, r21));
136 
137  vec_type::storeu(c + i + (j + 1) * M + 0 * vec_size, vec_type::mul(alpha_vec, r12));
138  vec_type::storeu(c + i + (j + 1) * M + 1 * vec_size, vec_type::mul(alpha_vec, r22));
139  }
140 
141  if (j < N) {
142  auto r11 = vec_type::template zero<T>();
143  auto r21 = vec_type::template zero<T>();
144 
145  for (size_t k = 0; k < K; ++k) {
146  auto a1 = vec_type::loadu(a + i + k * M + 0 * vec_size);
147  auto a2 = vec_type::loadu(a + i + k * M + 1 * vec_size);
148 
149  auto b1 = vec_type::set(b[k * N + j]);
150 
151  r11 = vec_type::fmadd(a1, b1, r11);
152  r21 = vec_type::fmadd(a2, b1, r21);
153  }
154 
155  vec_type::storeu(c + i + j * M + 0 * vec_size, vec_type::mul(alpha_vec, r11));
156  vec_type::storeu(c + i + j * M + 1 * vec_size, vec_type::mul(alpha_vec, r21));
157  }
158  }
159 
160  for (; i < i_end; i += vec_size) {
161  size_t j = 0;
162 
163  for (; j + 1 < N; j += 2) {
164  auto r11 = vec_type::template zero<T>();
165  auto r12 = vec_type::template zero<T>();
166 
167  for (size_t k = 0; k < K; ++k) {
168  auto a1 = vec_type::loadu(a + i + k * M + 0 * vec_size);
169 
170  auto b1 = vec_type::set(b[k * N + (j + 0)]);
171  auto b2 = vec_type::set(b[k * N + (j + 1)]);
172 
173  r11 = vec_type::fmadd(a1, b1, r11);
174  r12 = vec_type::fmadd(a1, b2, r12);
175  }
176 
177  vec_type::storeu(c + i + (j + 0) * M, vec_type::mul(alpha_vec, r11));
178  vec_type::storeu(c + i + (j + 1) * M, vec_type::mul(alpha_vec, r12));
179  }
180 
181  if (j < N) {
182  auto r11 = vec_type::template zero<T>();
183 
184  for (size_t k = 0; k < K; ++k) {
185  auto a1 = vec_type::loadu(a + i + k * M + 0 * vec_size);
186 
187  auto b1 = vec_type::set(b[k * N + j]);
188 
189  r11 = vec_type::fmadd(a1, b1, r11);
190  }
191 
192  vec_type::storeu(c + i + j * M, vec_type::mul(alpha_vec, r11));
193  }
194  }
195 
196  for (; i < M; ++i) {
197  size_t j = 0;
198 
199  for (; j + 1 < N; j += 2) {
200  auto r11 = T();
201  auto r12 = T();
202 
203  for (size_t k = 0; k < K; ++k) {
204  r11 += a[i + k * M] * b[k * N + (j + 0)];
205  r12 += a[i + k * M] * b[k * N + (j + 1)];
206  }
207 
208  c[i + (j + 0) * M] = alpha * r11;
209  c[i + (j + 1) * M] = alpha * r12;
210  }
211 
212  if (j < N) {
213  auto r11 = T();
214 
215  for (size_t k = 0; k < K; ++k) {
216  r11 += a[i + k * M] * b[k * N + j];
217  }
218 
219  c[i + j * M] = alpha * r11;
220  }
221  }
222 }
223 
232 template <typename V, typename T>
233 void gemm_large_kernel_cr_to_c(const T* a, const T* b, T* c, size_t M, size_t N, size_t K, T alpha) {
234  using vec_type = V;
235 
236  static constexpr size_t vec_size = vec_type::template traits<T>::size;
237 
238  constexpr size_t n_block_size = 128UL;
239  constexpr size_t m_block_size = 64UL;
240  constexpr size_t k_block_size = 128UL;
241 
242  auto alpha_vec = vec_type::set(alpha);
243 
244  for (size_t ii = 0; ii < M; ii += m_block_size) {
245  const size_t i_end = std::min(ii + m_block_size, M);
246  const size_t i_pos = prev_multiple(i_end, vec_size);
247 
248  for (size_t jj = 0; jj < N; jj += n_block_size) {
249  const size_t j_end = std::min(jj + n_block_size, N);
250 
251  for (size_t kk = 0; kk < K; kk += k_block_size) {
252  const size_t k_end = std::min(kk + k_block_size, K);
253 
254  size_t i = ii;
255 
256 #ifdef __clang__
257  for (; i + 3 * vec_size < i_pos; i += 4 * vec_size) {
258  size_t j = jj;
259 
260  for (; j + 1 < j_end; j += 2) {
261  auto r11 = vec_type::loadu(c + i + (j + 0) * M + 0 * vec_size);
262  auto r12 = vec_type::loadu(c + i + (j + 0) * M + 1 * vec_size);
263  auto r13 = vec_type::loadu(c + i + (j + 0) * M + 2 * vec_size);
264  auto r14 = vec_type::loadu(c + i + (j + 0) * M + 3 * vec_size);
265 
266  auto r21 = vec_type::loadu(c + i + (j + 1) * M + 0 * vec_size);
267  auto r22 = vec_type::loadu(c + i + (j + 1) * M + 1 * vec_size);
268  auto r23 = vec_type::loadu(c + i + (j + 1) * M + 2 * vec_size);
269  auto r24 = vec_type::loadu(c + i + (j + 1) * M + 3 * vec_size);
270 
271  for (size_t k = kk; k < k_end; ++k) {
272  auto a1 = vec_type::loadu(a + i + k * M + 0 * vec_size);
273  auto a2 = vec_type::loadu(a + i + k * M + 1 * vec_size);
274  auto a3 = vec_type::loadu(a + i + k * M + 2 * vec_size);
275  auto a4 = vec_type::loadu(a + i + k * M + 3 * vec_size);
276 
277  auto b1 = vec_type::set(b[k * N + (j + 0)]);
278  auto b2 = vec_type::set(b[k * N + (j + 1)]);
279 
280  r11 = vec_type::fmadd(a1, b1, r11);
281  r12 = vec_type::fmadd(a2, b1, r12);
282  r13 = vec_type::fmadd(a3, b1, r13);
283  r14 = vec_type::fmadd(a4, b1, r14);
284 
285  r21 = vec_type::fmadd(a1, b2, r21);
286  r22 = vec_type::fmadd(a2, b2, r22);
287  r23 = vec_type::fmadd(a3, b2, r23);
288  r24 = vec_type::fmadd(a4, b2, r24);
289  }
290 
291  vec_type::storeu(c + i + (j + 0) * M + 0 * vec_size, vec_type::mul(alpha_vec, r11));
292  vec_type::storeu(c + i + (j + 0) * M + 1 * vec_size, vec_type::mul(alpha_vec, r12));
293  vec_type::storeu(c + i + (j + 0) * M + 2 * vec_size, vec_type::mul(alpha_vec, r13));
294  vec_type::storeu(c + i + (j + 0) * M + 3 * vec_size, vec_type::mul(alpha_vec, r14));
295 
296  vec_type::storeu(c + i + (j + 1) * M + 0 * vec_size, vec_type::mul(alpha_vec, r21));
297  vec_type::storeu(c + i + (j + 1) * M + 1 * vec_size, vec_type::mul(alpha_vec, r22));
298  vec_type::storeu(c + i + (j + 1) * M + 2 * vec_size, vec_type::mul(alpha_vec, r23));
299  vec_type::storeu(c + i + (j + 1) * M + 3 * vec_size, vec_type::mul(alpha_vec, r24));
300  }
301 
302  if (j < j_end) {
303  auto r11 = vec_type::loadu(c + i + j * M + 0 * vec_size);
304  auto r12 = vec_type::loadu(c + i + j * M + 1 * vec_size);
305  auto r13 = vec_type::loadu(c + i + j * M + 2 * vec_size);
306  auto r14 = vec_type::loadu(c + i + j * M + 3 * vec_size);
307 
308  for (size_t k = kk; k < k_end; ++k) {
309  auto a1 = vec_type::loadu(a + i + k * M + 0 * vec_size);
310  auto a2 = vec_type::loadu(a + i + k * M + 1 * vec_size);
311  auto a3 = vec_type::loadu(a + i + k * M + 2 * vec_size);
312  auto a4 = vec_type::loadu(a + i + k * M + 3 * vec_size);
313 
314  auto b1 = vec_type::set(b[k * N + j]);
315 
316  r11 = vec_type::fmadd(a1, b1, r11);
317  r12 = vec_type::fmadd(a2, b1, r12);
318  r13 = vec_type::fmadd(a3, b1, r13);
319  r14 = vec_type::fmadd(a4, b1, r14);
320  }
321 
322  vec_type::storeu(c + i + j * M + 0 * vec_size, vec_type::mul(alpha_vec, r11));
323  vec_type::storeu(c + i + j * M + 1 * vec_size, vec_type::mul(alpha_vec, r12));
324  vec_type::storeu(c + i + j * M + 2 * vec_size, vec_type::mul(alpha_vec, r13));
325  vec_type::storeu(c + i + j * M + 3 * vec_size, vec_type::mul(alpha_vec, r14));
326  }
327  }
328 #endif
329 
330  for (; i + 1 * vec_size < i_pos; i += 2 * vec_size) {
331  size_t j = jj;
332 
333  for (; j + 3 < j_end; j += 4) {
334  auto r11 = vec_type::loadu(c + i + (j + 0) * M + 0 * vec_size);
335  auto r12 = vec_type::loadu(c + i + (j + 0) * M + 1 * vec_size);
336 
337  auto r21 = vec_type::loadu(c + i + (j + 1) * M + 0 * vec_size);
338  auto r22 = vec_type::loadu(c + i + (j + 1) * M + 1 * vec_size);
339 
340  auto r31 = vec_type::loadu(c + i + (j + 2) * M + 0 * vec_size);
341  auto r32 = vec_type::loadu(c + i + (j + 2) * M + 1 * vec_size);
342 
343  auto r41 = vec_type::loadu(c + i + (j + 3) * M + 0 * vec_size);
344  auto r42 = vec_type::loadu(c + i + (j + 3) * M + 1 * vec_size);
345 
346  for (size_t k = kk; k < k_end; ++k) {
347  auto a1 = vec_type::loadu(a + i + k * M + 0 * vec_size);
348  auto a2 = vec_type::loadu(a + i + k * M + 1 * vec_size);
349 
350  auto b1 = vec_type::set(b[k * N + (j + 0)]);
351  auto b2 = vec_type::set(b[k * N + (j + 1)]);
352  auto b3 = vec_type::set(b[k * N + (j + 2)]);
353  auto b4 = vec_type::set(b[k * N + (j + 3)]);
354 
355  r11 = vec_type::fmadd(a1, b1, r11);
356  r12 = vec_type::fmadd(a2, b1, r12);
357 
358  r21 = vec_type::fmadd(a1, b2, r21);
359  r22 = vec_type::fmadd(a2, b2, r22);
360 
361  r31 = vec_type::fmadd(a1, b3, r31);
362  r32 = vec_type::fmadd(a2, b3, r32);
363 
364  r41 = vec_type::fmadd(a1, b4, r41);
365  r42 = vec_type::fmadd(a2, b4, r42);
366  }
367 
368  vec_type::storeu(c + i + (j + 0) * M + 0 * vec_size, vec_type::mul(alpha_vec, r11));
369  vec_type::storeu(c + i + (j + 0) * M + 1 * vec_size, vec_type::mul(alpha_vec, r12));
370 
371  vec_type::storeu(c + i + (j + 1) * M + 0 * vec_size, vec_type::mul(alpha_vec, r21));
372  vec_type::storeu(c + i + (j + 1) * M + 1 * vec_size, vec_type::mul(alpha_vec, r22));
373 
374  vec_type::storeu(c + i + (j + 2) * M + 0 * vec_size, vec_type::mul(alpha_vec, r31));
375  vec_type::storeu(c + i + (j + 2) * M + 1 * vec_size, vec_type::mul(alpha_vec, r32));
376 
377  vec_type::storeu(c + i + (j + 3) * M + 0 * vec_size, vec_type::mul(alpha_vec, r41));
378  vec_type::storeu(c + i + (j + 3) * M + 1 * vec_size, vec_type::mul(alpha_vec, r42));
379  }
380 
381  for (; j + 1 < j_end; j += 2) {
382  auto r11 = vec_type::loadu(c + i + (j + 0) * M + 0 * vec_size);
383  auto r12 = vec_type::loadu(c + i + (j + 0) * M + 1 * vec_size);
384 
385  auto r21 = vec_type::loadu(c + i + (j + 1) * M + 0 * vec_size);
386  auto r22 = vec_type::loadu(c + i + (j + 1) * M + 1 * vec_size);
387 
388  for (size_t k = kk; k < k_end; ++k) {
389  auto a1 = vec_type::loadu(a + i + k * M + 0 * vec_size);
390  auto a2 = vec_type::loadu(a + i + k * M + 1 * vec_size);
391 
392  auto b1 = vec_type::set(b[k * N + (j + 0)]);
393  auto b2 = vec_type::set(b[k * N + (j + 1)]);
394 
395  r11 = vec_type::fmadd(a1, b1, r11);
396  r12 = vec_type::fmadd(a2, b1, r12);
397 
398  r21 = vec_type::fmadd(a1, b2, r21);
399  r22 = vec_type::fmadd(a2, b2, r22);
400  }
401 
402  vec_type::storeu(c + i + (j + 0) * M + 0 * vec_size, vec_type::mul(alpha_vec, r11));
403  vec_type::storeu(c + i + (j + 0) * M + 1 * vec_size, vec_type::mul(alpha_vec, r12));
404 
405  vec_type::storeu(c + i + (j + 1) * M + 0 * vec_size, vec_type::mul(alpha_vec, r21));
406  vec_type::storeu(c + i + (j + 1) * M + 1 * vec_size, vec_type::mul(alpha_vec, r22));
407  }
408 
409  if (j < j_end) {
410  auto r11 = vec_type::loadu(c + i + j * M + 0 * vec_size);
411  auto r12 = vec_type::loadu(c + i + j * M + 1 * vec_size);
412 
413  for (size_t k = kk; k < k_end; ++k) {
414  auto a1 = vec_type::loadu(a + i + k * M + 0 * vec_size);
415  auto a2 = vec_type::loadu(a + i + k * M + 1 * vec_size);
416 
417  auto b1 = vec_type::set(b[k * N + j]);
418 
419  r11 = vec_type::fmadd(a1, b1, r11);
420  r12 = vec_type::fmadd(a2, b1, r12);
421  }
422 
423  vec_type::storeu(c + i + j * M + 0 * vec_size, vec_type::mul(alpha_vec, r11));
424  vec_type::storeu(c + i + j * M + 1 * vec_size, vec_type::mul(alpha_vec, r12));
425  }
426  }
427 
428  for (; i < i_pos; i += vec_size) {
429  for (size_t j = jj; j < j_end; ++j) {
430  auto r11 = vec_type::loadu(c + i + j * M + 0 * vec_size);
431 
432  for (size_t k = kk; k < k_end; ++k) {
433  auto a1 = vec_type::loadu(a + i + k * M + 0 * vec_size);
434 
435  auto b1 = vec_type::set(b[k * N + j]);
436 
437  r11 = vec_type::fmadd(a1, b1, r11);
438  }
439 
440  vec_type::storeu(c + i + j * M + 0 * vec_size, vec_type::mul(alpha_vec, r11));
441  }
442  }
443 
444  for (; i < i_end; ++i) {
445  for (size_t j = jj; j < j_end; ++j) {
446  auto r11 = c[i + j * M];
447 
448  for (size_t k = kk; k < k_end; ++k) {
449  r11 += a[i + k * M] * b[k * N + j];
450  }
451 
452  c[i + j * M] = alpha * r11;
453  }
454  }
455  }
456  }
457  }
458 }
459 
472 template <typename T>
473 void gemm_cr_to_c(const T* a, const T* b, T* c, size_t M, size_t N, size_t K, T alpha) {
474  cpp_assert(vec_enabled, "At least one vector mode must be enabled for impl::VEC");
475  cpp_assert(vectorize_impl, "vectorize_impl must be enabled for impl::VEC");
476 
477  if (M * N <= gemm_rr_small_threshold) {
478  gemm_small_kernel_cr_to_c<default_vec>(a, b, c, M, N, K, alpha);
479  } else {
480  direct_fill_n(c, M * N, T(0));
481  gemm_large_kernel_cr_to_c<default_vec>(a, b, c, M, N, K, alpha);
482  }
483 }
484 
485 } //end of namespace etl::impl::vec
void gemm_large_kernel_cr_to_c(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Optimized version of GEMM for assignment of a large Column-Major Matrix - Row Major Matrix to a Column-Major Matrix.
Definition: gemm_cr_to_c.hpp:233
Definition: bias_add.hpp:15
constexpr size_t gemm_rr_small_threshold
The number of elements of B after which we use BLAS-like kernel (for GEMM)
Definition: threshold.hpp:55
constexpr bool vectorize_impl
Indicates if the implementations can be automatically vectorized by ETL.
Definition: config.hpp:35
constexpr bool vec_enabled
Indicates if vectorization is available in any format.
Definition: config.hpp:220
typename V::template vec_type< value_type > vec_type
The vectorization type for V.
Definition: dyn_matrix_view.hpp:43
void gemm_cr_to_c(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Vectorized implementation of column-major matrix - row-major matrix multiplication and assignment into a column-major matrix.
Definition: gemm_cr_to_c.hpp:473
void storeu(vec_type< V > in, size_t i) noexcept
Store several elements in the matrix at once.
Definition: dyn_matrix_view.hpp:187
auto loadu(size_t x) const noexcept
Load several elements of the expression at once.
Definition: dyn_matrix_view.hpp:154
auto min(L &&lhs, R &&rhs)
Create an expression with the min value of lhs or rhs.
Definition: expression_builder.hpp:77
void gemm_small_kernel_cr_to_c(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Optimized version of GEMM for assignment of a small Column-Major Matrix - Row Major Matrix to a Column-Major Matrix.
Definition: gemm_cr_to_c.hpp:27
void direct_fill_n(S *first, size_t n, T value)
Fills the given memory with the given value.
Definition: memory.hpp:57