Expression Templates Library (ETL)
gemm_cc_to_c.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
14 #pragma once
15 
16 namespace etl::impl::vec {
17 
/*!
 * \brief Optimized version of small GEMM for column major version.
 *
 * Computes C = alpha * (A * B). Every element of C is overwritten; the
 * previous content of c is never read.
 *
 * Rows are processed in panels of four vectors, then two vectors, then one
 * vector, and finally one scalar row at a time. Inside each panel the
 * columns of C are handled two at a time, with a single trailing column
 * when N is odd.
 *
 * \tparam V The vector backend used for the loads, stores and arithmetic
 * \tparam T The element type
 *
 * \param a Pointer to A, read as a[i + k * M]
 * \param b Pointer to B, read as b[k + j * K]
 * \param c Pointer to C, written as c[i + j * M]
 * \param M The row count of A and C
 * \param N The column count of B and C
 * \param K The column count of A, which is also the row count of B
 * \param alpha The factor applied to each accumulated value before it is stored
 */
template <typename V, typename T>
void gemm_small_kernel_cc_to_c(const T* a, const T* b, T* c, size_t M, size_t N, size_t K, T alpha) {
    using vec_type = V;

    static constexpr size_t vec_size = vec_type::template traits<T>::size;

    auto alpha_vec = vec_type::set(alpha);

    size_t row = 0;

    // Row panels that are four vectors tall
    for (; row + 4 * vec_size <= M; row += 4 * vec_size) {
        const T* a_panel = a + row;
        T* c_panel       = c + row;

        size_t col = 0;

        // Two columns of C per iteration
        for (; col + 2 <= N; col += 2) {
            const T* b_left  = b + (col + 0) * K;
            const T* b_right = b + (col + 1) * K;

            // accXY: X is the vector index inside the panel, Y the column offset
            auto acc00 = vec_type::template zero<T>();
            auto acc01 = vec_type::template zero<T>();

            auto acc10 = vec_type::template zero<T>();
            auto acc11 = vec_type::template zero<T>();

            auto acc20 = vec_type::template zero<T>();
            auto acc21 = vec_type::template zero<T>();

            auto acc30 = vec_type::template zero<T>();
            auto acc31 = vec_type::template zero<T>();

            for (size_t k = 0; k < K; ++k) {
                const T* a_k = a_panel + k * M;

                auto a0 = vec_type::loadu(a_k + 0 * vec_size);
                auto a1 = vec_type::loadu(a_k + 1 * vec_size);
                auto a2 = vec_type::loadu(a_k + 2 * vec_size);
                auto a3 = vec_type::loadu(a_k + 3 * vec_size);

                auto b_l = vec_type::set(b_left[k]);
                auto b_r = vec_type::set(b_right[k]);

                acc00 = vec_type::fmadd(a0, b_l, acc00);
                acc01 = vec_type::fmadd(a0, b_r, acc01);

                acc10 = vec_type::fmadd(a1, b_l, acc10);
                acc11 = vec_type::fmadd(a1, b_r, acc11);

                acc20 = vec_type::fmadd(a2, b_l, acc20);
                acc21 = vec_type::fmadd(a2, b_r, acc21);

                acc30 = vec_type::fmadd(a3, b_l, acc30);
                acc31 = vec_type::fmadd(a3, b_r, acc31);
            }

            T* c_left  = c_panel + (col + 0) * M;
            T* c_right = c_panel + (col + 1) * M;

            vec_type::storeu(c_left + 0 * vec_size, vec_type::mul(alpha_vec, acc00));
            vec_type::storeu(c_right + 0 * vec_size, vec_type::mul(alpha_vec, acc01));

            vec_type::storeu(c_left + 1 * vec_size, vec_type::mul(alpha_vec, acc10));
            vec_type::storeu(c_right + 1 * vec_size, vec_type::mul(alpha_vec, acc11));

            vec_type::storeu(c_left + 2 * vec_size, vec_type::mul(alpha_vec, acc20));
            vec_type::storeu(c_right + 2 * vec_size, vec_type::mul(alpha_vec, acc21));

            vec_type::storeu(c_left + 3 * vec_size, vec_type::mul(alpha_vec, acc30));
            vec_type::storeu(c_right + 3 * vec_size, vec_type::mul(alpha_vec, acc31));
        }

        // Trailing column when N is odd
        if (col < N) {
            const T* b_col = b + col * K;

            auto acc0 = vec_type::template zero<T>();
            auto acc1 = vec_type::template zero<T>();
            auto acc2 = vec_type::template zero<T>();
            auto acc3 = vec_type::template zero<T>();

            for (size_t k = 0; k < K; ++k) {
                const T* a_k = a_panel + k * M;

                auto a0 = vec_type::loadu(a_k + 0 * vec_size);
                auto a1 = vec_type::loadu(a_k + 1 * vec_size);
                auto a2 = vec_type::loadu(a_k + 2 * vec_size);
                auto a3 = vec_type::loadu(a_k + 3 * vec_size);

                auto b_k = vec_type::set(b_col[k]);

                acc0 = vec_type::fmadd(a0, b_k, acc0);
                acc1 = vec_type::fmadd(a1, b_k, acc1);
                acc2 = vec_type::fmadd(a2, b_k, acc2);
                acc3 = vec_type::fmadd(a3, b_k, acc3);
            }

            T* c_col = c_panel + col * M;

            vec_type::storeu(c_col + 0 * vec_size, vec_type::mul(alpha_vec, acc0));
            vec_type::storeu(c_col + 1 * vec_size, vec_type::mul(alpha_vec, acc1));
            vec_type::storeu(c_col + 2 * vec_size, vec_type::mul(alpha_vec, acc2));
            vec_type::storeu(c_col + 3 * vec_size, vec_type::mul(alpha_vec, acc3));
        }
    }

    // Row panels that are two vectors tall
    for (; row + 2 * vec_size <= M; row += 2 * vec_size) {
        const T* a_panel = a + row;
        T* c_panel       = c + row;

        size_t col = 0;

        // Two columns of C per iteration
        for (; col + 2 <= N; col += 2) {
            const T* b_left  = b + (col + 0) * K;
            const T* b_right = b + (col + 1) * K;

            auto acc00 = vec_type::template zero<T>();
            auto acc01 = vec_type::template zero<T>();

            auto acc10 = vec_type::template zero<T>();
            auto acc11 = vec_type::template zero<T>();

            for (size_t k = 0; k < K; ++k) {
                const T* a_k = a_panel + k * M;

                auto a0 = vec_type::loadu(a_k + 0 * vec_size);
                auto a1 = vec_type::loadu(a_k + 1 * vec_size);

                auto b_l = vec_type::set(b_left[k]);
                auto b_r = vec_type::set(b_right[k]);

                acc00 = vec_type::fmadd(a0, b_l, acc00);
                acc01 = vec_type::fmadd(a0, b_r, acc01);

                acc10 = vec_type::fmadd(a1, b_l, acc10);
                acc11 = vec_type::fmadd(a1, b_r, acc11);
            }

            T* c_left  = c_panel + (col + 0) * M;
            T* c_right = c_panel + (col + 1) * M;

            vec_type::storeu(c_left + 0 * vec_size, vec_type::mul(alpha_vec, acc00));
            vec_type::storeu(c_right + 0 * vec_size, vec_type::mul(alpha_vec, acc01));

            vec_type::storeu(c_left + 1 * vec_size, vec_type::mul(alpha_vec, acc10));
            vec_type::storeu(c_right + 1 * vec_size, vec_type::mul(alpha_vec, acc11));
        }

        // Trailing column when N is odd
        if (col < N) {
            const T* b_col = b + col * K;

            auto acc0 = vec_type::template zero<T>();
            auto acc1 = vec_type::template zero<T>();

            for (size_t k = 0; k < K; ++k) {
                const T* a_k = a_panel + k * M;

                auto a0 = vec_type::loadu(a_k + 0 * vec_size);
                auto a1 = vec_type::loadu(a_k + 1 * vec_size);

                auto b_k = vec_type::set(b_col[k]);

                acc0 = vec_type::fmadd(a0, b_k, acc0);
                acc1 = vec_type::fmadd(a1, b_k, acc1);
            }

            T* c_col = c_panel + col * M;

            vec_type::storeu(c_col + 0 * vec_size, vec_type::mul(alpha_vec, acc0));
            vec_type::storeu(c_col + 1 * vec_size, vec_type::mul(alpha_vec, acc1));
        }
    }

    // Row panels that are a single vector tall
    for (; row + vec_size <= M; row += vec_size) {
        const T* a_panel = a + row;
        T* c_panel       = c + row;

        size_t col = 0;

        // Two columns of C per iteration
        for (; col + 2 <= N; col += 2) {
            const T* b_left  = b + (col + 0) * K;
            const T* b_right = b + (col + 1) * K;

            auto acc_l = vec_type::template zero<T>();
            auto acc_r = vec_type::template zero<T>();

            for (size_t k = 0; k < K; ++k) {
                auto a_k = vec_type::loadu(a_panel + k * M);

                auto b_l = vec_type::set(b_left[k]);
                auto b_r = vec_type::set(b_right[k]);

                acc_l = vec_type::fmadd(a_k, b_l, acc_l);
                acc_r = vec_type::fmadd(a_k, b_r, acc_r);
            }

            vec_type::storeu(c_panel + (col + 0) * M, vec_type::mul(alpha_vec, acc_l));
            vec_type::storeu(c_panel + (col + 1) * M, vec_type::mul(alpha_vec, acc_r));
        }

        // Trailing column when N is odd
        if (col < N) {
            const T* b_col = b + col * K;

            auto acc = vec_type::template zero<T>();

            for (size_t k = 0; k < K; ++k) {
                auto a_k = vec_type::loadu(a_panel + k * M);

                auto b_k = vec_type::set(b_col[k]);

                acc = vec_type::fmadd(a_k, b_k, acc);
            }

            vec_type::storeu(c_panel + col * M, vec_type::mul(alpha_vec, acc));
        }
    }

    // Remaining rows (fewer than one vector), handled with scalar arithmetic
    for (; row < M; ++row) {
        size_t col = 0;

        // Two columns of C per iteration
        for (; col + 1 < N; col += 2) {
            T sum_left(0);
            T sum_right(0);

            for (size_t k = 0; k < K; ++k) {
                sum_left += a[row + k * M] * b[k + (col + 0) * K];
                sum_right += a[row + k * M] * b[k + (col + 1) * K];
            }

            c[row + (col + 0) * M] = alpha * sum_left;
            c[row + (col + 1) * M] = alpha * sum_right;
        }

        // Trailing column when N is odd
        if (col < N) {
            T sum(0);

            for (size_t k = 0; k < K; ++k) {
                sum += a[row + k * M] * b[k + col * K];
            }

            c[row + col * M] = alpha * sum;
        }
    }
}
225 
/*!
 * \brief Optimized version of large GEMM for column major version.
 *
 * Computes C = alpha * (A * B). C is processed in tiles of
 * m_block_size x n_block_size elements. Each tile is first set to zero and
 * is then updated once per block of k_block_size values of k: the running
 * sums are read back from c, extended with the contribution of the current
 * K block and written back to c.
 *
 * alpha is applied only when the last K block of a tile is written. Applying
 * it on every write would scale the sums of the earlier K blocks several
 * times as soon as K is larger than k_block_size.
 *
 * \tparam V The vector backend used for the loads, stores and arithmetic
 * \tparam T The element type
 *
 * \param a Pointer to A, read as a[i + k * M]
 * \param b Pointer to B, read as b[k + j * K]
 * \param c Pointer to C, accessed as c[i + j * M]; its previous content is discarded
 * \param M The row count of A and C
 * \param N The column count of B and C
 * \param K The column count of A, which is also the row count of B
 * \param alpha The factor applied to the product
 */
template <typename V, typename T>
void gemm_large_kernel_cc_to_c(const T* a, const T* b, T* c, size_t M, size_t N, size_t K, T alpha) {
    using vec_type = V;

    static constexpr size_t vec_size = vec_type::template traits<T>::size;

    constexpr size_t m_block_size = 128;
    constexpr size_t n_block_size = 64;
    constexpr size_t k_block_size = 128;

    auto alpha_vec = vec_type::set(alpha);

    for (size_t block_i = 0; block_i < M; block_i += m_block_size) {
        const size_t i_end = std::min(block_i + m_block_size, M);

        for (size_t block_j = 0; block_j < N; block_j += n_block_size) {
            const size_t j_end = std::min(block_j + n_block_size, N);

            // Clear the block: the K blocks below accumulate on top of it
            for (size_t j = block_j; j < j_end; ++j) {
                for (size_t i = block_i; i < i_end; ++i) {
                    c[i + j * M] = 0;
                }
            }

            for (size_t block_k = 0; block_k < K; block_k += k_block_size) {
                const size_t k_end = std::min(block_k + k_block_size, K);

                // c holds unscaled running sums until the last K block is done
                const bool last_k_block = k_end == K;

                // Write an accumulator back to c, scaling by alpha only on the last K block
                auto commit = [&](T* dst, auto acc) {
                    if (last_k_block) {
                        vec_type::storeu(dst, vec_type::mul(alpha_vec, acc));
                    } else {
                        vec_type::storeu(dst, acc);
                    }
                };

                size_t i = block_i;

                // 4x unrolled vectorized inner loop
                for (; i + 4 * vec_size - 1 < i_end; i += 4 * vec_size) {
                    size_t j = block_j;

                    // Two columns of C per iteration
                    for (; j + 1 < j_end; j += 2) {
                        auto r11 = vec_type::loadu(c + i + 0 * vec_size + (j + 0) * M);
                        auto r12 = vec_type::loadu(c + i + 1 * vec_size + (j + 0) * M);
                        auto r13 = vec_type::loadu(c + i + 2 * vec_size + (j + 0) * M);
                        auto r14 = vec_type::loadu(c + i + 3 * vec_size + (j + 0) * M);

                        auto r21 = vec_type::loadu(c + i + 0 * vec_size + (j + 1) * M);
                        auto r22 = vec_type::loadu(c + i + 1 * vec_size + (j + 1) * M);
                        auto r23 = vec_type::loadu(c + i + 2 * vec_size + (j + 1) * M);
                        auto r24 = vec_type::loadu(c + i + 3 * vec_size + (j + 1) * M);

                        for (size_t k = block_k; k < k_end; ++k) {
                            auto a1 = vec_type::loadu(a + i + 0 * vec_size + k * M);
                            auto a2 = vec_type::loadu(a + i + 1 * vec_size + k * M);
                            auto a3 = vec_type::loadu(a + i + 2 * vec_size + k * M);
                            auto a4 = vec_type::loadu(a + i + 3 * vec_size + k * M);

                            auto b1 = vec_type::set(b[k + (j + 0) * K]);
                            auto b2 = vec_type::set(b[k + (j + 1) * K]);

                            r11 = vec_type::fmadd(a1, b1, r11);
                            r12 = vec_type::fmadd(a2, b1, r12);
                            r13 = vec_type::fmadd(a3, b1, r13);
                            r14 = vec_type::fmadd(a4, b1, r14);

                            r21 = vec_type::fmadd(a1, b2, r21);
                            r22 = vec_type::fmadd(a2, b2, r22);
                            r23 = vec_type::fmadd(a3, b2, r23);
                            r24 = vec_type::fmadd(a4, b2, r24);
                        }

                        commit(c + i + 0 * vec_size + (j + 0) * M, r11);
                        commit(c + i + 1 * vec_size + (j + 0) * M, r12);
                        commit(c + i + 2 * vec_size + (j + 0) * M, r13);
                        commit(c + i + 3 * vec_size + (j + 0) * M, r14);

                        commit(c + i + 0 * vec_size + (j + 1) * M, r21);
                        commit(c + i + 1 * vec_size + (j + 1) * M, r22);
                        commit(c + i + 2 * vec_size + (j + 1) * M, r23);
                        commit(c + i + 3 * vec_size + (j + 1) * M, r24);
                    }

                    // Remaining column of the tile, if any
                    for (; j < j_end; ++j) {
                        auto r1 = vec_type::loadu(c + i + 0 * vec_size + j * M);
                        auto r2 = vec_type::loadu(c + i + 1 * vec_size + j * M);
                        auto r3 = vec_type::loadu(c + i + 2 * vec_size + j * M);
                        auto r4 = vec_type::loadu(c + i + 3 * vec_size + j * M);

                        for (size_t k = block_k; k < k_end; ++k) {
                            auto a1 = vec_type::loadu(a + i + 0 * vec_size + k * M);
                            auto a2 = vec_type::loadu(a + i + 1 * vec_size + k * M);
                            auto a3 = vec_type::loadu(a + i + 2 * vec_size + k * M);
                            auto a4 = vec_type::loadu(a + i + 3 * vec_size + k * M);

                            auto b1 = vec_type::set(b[k + j * K]);

                            r1 = vec_type::fmadd(a1, b1, r1);
                            r2 = vec_type::fmadd(a2, b1, r2);
                            r3 = vec_type::fmadd(a3, b1, r3);
                            r4 = vec_type::fmadd(a4, b1, r4);
                        }

                        commit(c + i + 0 * vec_size + j * M, r1);
                        commit(c + i + 1 * vec_size + j * M, r2);
                        commit(c + i + 2 * vec_size + j * M, r3);
                        commit(c + i + 3 * vec_size + j * M, r4);
                    }
                }

                // 2x unrolled vectorized inner loop
                for (; i + 2 * vec_size - 1 < i_end; i += 2 * vec_size) {
                    size_t j = block_j;

                    // Four columns of C per iteration
                    for (; j + 3 < j_end; j += 4) {
                        auto r11 = vec_type::loadu(c + i + 0 * vec_size + (j + 0) * M);
                        auto r12 = vec_type::loadu(c + i + 1 * vec_size + (j + 0) * M);

                        auto r21 = vec_type::loadu(c + i + 0 * vec_size + (j + 1) * M);
                        auto r22 = vec_type::loadu(c + i + 1 * vec_size + (j + 1) * M);

                        auto r31 = vec_type::loadu(c + i + 0 * vec_size + (j + 2) * M);
                        auto r32 = vec_type::loadu(c + i + 1 * vec_size + (j + 2) * M);

                        auto r41 = vec_type::loadu(c + i + 0 * vec_size + (j + 3) * M);
                        auto r42 = vec_type::loadu(c + i + 1 * vec_size + (j + 3) * M);

                        for (size_t k = block_k; k < k_end; ++k) {
                            auto a1 = vec_type::loadu(a + i + 0 * vec_size + k * M);
                            auto a2 = vec_type::loadu(a + i + 1 * vec_size + k * M);

                            auto b1 = vec_type::set(b[k + (j + 0) * K]);
                            auto b2 = vec_type::set(b[k + (j + 1) * K]);
                            auto b3 = vec_type::set(b[k + (j + 2) * K]);
                            auto b4 = vec_type::set(b[k + (j + 3) * K]);

                            r11 = vec_type::fmadd(a1, b1, r11);
                            r12 = vec_type::fmadd(a2, b1, r12);

                            r21 = vec_type::fmadd(a1, b2, r21);
                            r22 = vec_type::fmadd(a2, b2, r22);

                            r31 = vec_type::fmadd(a1, b3, r31);
                            r32 = vec_type::fmadd(a2, b3, r32);

                            r41 = vec_type::fmadd(a1, b4, r41);
                            r42 = vec_type::fmadd(a2, b4, r42);
                        }

                        commit(c + i + 0 * vec_size + (j + 0) * M, r11);
                        commit(c + i + 1 * vec_size + (j + 0) * M, r12);

                        commit(c + i + 0 * vec_size + (j + 1) * M, r21);
                        commit(c + i + 1 * vec_size + (j + 1) * M, r22);

                        commit(c + i + 0 * vec_size + (j + 2) * M, r31);
                        commit(c + i + 1 * vec_size + (j + 2) * M, r32);

                        commit(c + i + 0 * vec_size + (j + 3) * M, r41);
                        commit(c + i + 1 * vec_size + (j + 3) * M, r42);
                    }

                    // Two columns of C per iteration
                    for (; j + 1 < j_end; j += 2) {
                        auto r11 = vec_type::loadu(c + i + 0 * vec_size + (j + 0) * M);
                        auto r12 = vec_type::loadu(c + i + 1 * vec_size + (j + 0) * M);

                        auto r21 = vec_type::loadu(c + i + 0 * vec_size + (j + 1) * M);
                        auto r22 = vec_type::loadu(c + i + 1 * vec_size + (j + 1) * M);

                        for (size_t k = block_k; k < k_end; ++k) {
                            auto a1 = vec_type::loadu(a + i + 0 * vec_size + k * M);
                            auto a2 = vec_type::loadu(a + i + 1 * vec_size + k * M);

                            auto b1 = vec_type::set(b[k + (j + 0) * K]);
                            auto b2 = vec_type::set(b[k + (j + 1) * K]);

                            r11 = vec_type::fmadd(a1, b1, r11);
                            r12 = vec_type::fmadd(a2, b1, r12);

                            r21 = vec_type::fmadd(a1, b2, r21);
                            r22 = vec_type::fmadd(a2, b2, r22);
                        }

                        commit(c + i + 0 * vec_size + (j + 0) * M, r11);
                        commit(c + i + 1 * vec_size + (j + 0) * M, r12);

                        commit(c + i + 0 * vec_size + (j + 1) * M, r21);
                        commit(c + i + 1 * vec_size + (j + 1) * M, r22);
                    }

                    // Remaining column of the tile, if any
                    for (; j < j_end; ++j) {
                        auto r1 = vec_type::loadu(c + i + 0 * vec_size + j * M);
                        auto r2 = vec_type::loadu(c + i + 1 * vec_size + j * M);

                        for (size_t k = block_k; k < k_end; ++k) {
                            auto a1 = vec_type::loadu(a + i + 0 * vec_size + k * M);
                            auto a2 = vec_type::loadu(a + i + 1 * vec_size + k * M);

                            auto b1 = vec_type::set(b[k + j * K]);

                            r1 = vec_type::fmadd(a1, b1, r1);
                            r2 = vec_type::fmadd(a2, b1, r2);
                        }

                        commit(c + i + 0 * vec_size + j * M, r1);
                        commit(c + i + 1 * vec_size + j * M, r2);
                    }
                }

                // Vectorized inner loop
                for (; i + vec_size - 1 < i_end; i += vec_size) {
                    for (size_t j = block_j; j < j_end; ++j) {
                        auto r1 = vec_type::loadu(c + i + j * M);

                        for (size_t k = block_k; k < k_end; ++k) {
                            auto a1 = vec_type::loadu(a + i + k * M);
                            auto b1 = vec_type::set(b[k + j * K]);

                            r1 = vec_type::fmadd(a1, b1, r1);
                        }

                        commit(c + i + j * M, r1);
                    }
                }

                // Remainder inner loop
                for (; i < i_end; ++i) {
                    for (size_t j = block_j; j < j_end; ++j) {
                        auto x = c[i + j * M];

                        for (size_t k = block_k; k < k_end; ++k) {
                            x += a[i + k * M] * b[k + j * K];
                        }

                        // Same rule as the vectorized paths: scale once, on the last K block
                        if (last_k_block) {
                            x = alpha * x;
                        }

                        c[i + j * M] = x;
                    }
                }
            }
        }
    }
}
466 
479 template <typename T>
480 void gemm_cc_to_c(const T* a, const T* b, T* c, size_t M, size_t N, size_t K, T alpha) {
481  cpp_assert(vec_enabled, "At least one vector mode must be enabled for impl::VEC");
482  cpp_assert(vectorize_impl, "vectorize_impl must be enabled for impl::VEC");
483 
484  // Dispatch to the best kernel
485 
486  if (M * N <= gemm_cc_small_threshold) {
487  gemm_small_kernel_cc_to_c<default_vec>(a, b, c, M, N, K, alpha);
488  } else {
489  gemm_large_kernel_cc_to_c<default_vec>(a, b, c, M, N, K, alpha);
490  }
491 }
492 
493 } //end of namespace etl::impl::vec
Definition: bias_add.hpp:15
constexpr bool vectorize_impl
Indicates if the implementations can be automatically vectorized by ETL.
Definition: config.hpp:35
constexpr bool vec_enabled
Indicates if vectorization is available in any format.
Definition: config.hpp:220
typename V::template vec_type< value_type > vec_type
The vectorization type for V.
Definition: dyn_matrix_view.hpp:43
void storeu(vec_type< V > in, size_t i) noexcept
Store several elements in the matrix at once.
Definition: dyn_matrix_view.hpp:187
auto loadu(size_t x) const noexcept
Load several elements of the expression at once.
Definition: dyn_matrix_view.hpp:154
auto min(L &&lhs, R &&rhs)
Create an expression with the min value of lhs or rhs.
Definition: expression_builder.hpp:77
constexpr size_t gemm_cc_small_threshold
The number of elements of B after which we use BLAS-like kernel (for GEMM)
Definition: threshold.hpp:58
void gemm_cc_to_c(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Vectorized implementation of column-major matrix - column-major matrix multiplication and assignment ...
Definition: gemm_cc_to_c.hpp:480
void gemm_small_kernel_cc_to_c(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Optimized version of small GEMM for column major version.
Definition: gemm_cc_to_c.hpp:25
void gemm_large_kernel_cc_to_c(const T *a, const T *b, T *c, size_t M, size_t N, size_t K, T alpha)
Optimized version of large GEMM for column major version.
Definition: gemm_cc_to_c.hpp:233