// Expression Templates Library (ETL)
// batch_k_scale_plus_expr.hpp
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
8 #pragma once
9 
10 #include "etl/expr/base_temporary_expr.hpp"
11 
12 namespace etl {
13 
14 template <etl_1d A, etl_2d_or_4d B, etl_1d C>
15 struct batch_k_scale_plus_expr : base_temporary_expr_tern<batch_k_scale_plus_expr<A, B, C>, A, B, C> {
20 
21  static constexpr bool D4 = is_4d<B>;
22 
23  static constexpr auto storage_order = left_traits::storage_order;
24 
29  static constexpr bool gpu_computable =
30  (!D4 && impl::egblas::has_sbatch_k_scale_plus2 && all_row_major<A, B> && all_floating<A, B>)
31  || (D4 && impl::egblas::has_dbatch_k_scale_plus4 && all_row_major<A, B> && all_floating<A, B>);
32 
37  batch_k_scale_plus_expr(A a, B b, C c) : base_type(a, b, c) {
38  //Nothing else to init
39  }
40 
46  template <same_dimensions<B> L>
47  static void check([[maybe_unused]] const A& a, [[maybe_unused]] const B& b, [[maybe_unused]] const C& c, [[maybe_unused]] L& lhs) {
48  if constexpr (D4) {
49  if constexpr (all_fast<A, B, C, L>) {
50  static_assert(etl::dim<0, B>() == etl::dim<0, L>(), "Invalid dimensions for batch_k_scale_plus");
51  static_assert(etl::dim<1, B>() == etl::dim<1, L>(), "Invalid dimensions for batch_k_scale_plus");
52  static_assert(etl::dim<2, B>() == etl::dim<2, L>(), "Invalid dimensions for batch_k_scale_plus");
53  static_assert(etl::dim<3, B>() == etl::dim<3, L>(), "Invalid dimensions for batch_k_scale_plus");
54 
55  static_assert(etl::dim<0, A>() == etl::dim<1, B>(), "Invalid dimensions for batch_k_scale_plus");
56  static_assert(etl::dim<0, A>() == etl::dim<0, C>(), "Invalid dimensions for batch_k_scale_plus");
57  } else {
58  cpp_assert(etl::dim<0>(b) == etl::dim<0>(lhs), "Invalid dimensions for batch_k_scale_plus");
59  cpp_assert(etl::dim<1>(b) == etl::dim<1>(lhs), "Invalid dimensions for batch_k_scale_plus");
60  cpp_assert(etl::dim<2>(b) == etl::dim<2>(lhs), "Invalid dimensions for batch_k_scale_plus");
61  cpp_assert(etl::dim<3>(b) == etl::dim<3>(lhs), "Invalid dimensions for batch_k_scale_plus");
62 
63  cpp_assert(etl::dim<0>(a) == etl::dim<1>(b), "Invalid dimensions for batch_k_scale_plus");
64  cpp_assert(etl::dim<0>(a) == etl::dim<0>(c), "Invalid dimensions for batch_k_scale_plus");
65  }
66  } else {
67  if constexpr (all_fast<A, B, C, L>) {
68  static_assert(etl::dim<0, B>() == etl::dim<0, L>(), "Invalid dimensions for batch_k_scale_plus");
69  static_assert(etl::dim<1, B>() == etl::dim<1, L>(), "Invalid dimensions for batch_k_scale_plus");
70 
71  static_assert(etl::dim<0, A>() == etl::dim<1, B>(), "Invalid dimensions for batch_k_scale_plus");
72  static_assert(etl::dim<0, A>() == etl::dim<0, C>(), "Invalid dimensions for batch_k_scale_plus");
73  } else {
74  cpp_assert(etl::dim<0>(b) == etl::dim<0>(lhs), "Invalid dimensions for batch_k_scale_plus");
75  cpp_assert(etl::dim<1>(b) == etl::dim<1>(lhs), "Invalid dimensions for batch_k_scale_plus");
76 
77  cpp_assert(etl::dim<0>(a) == etl::dim<1>(b), "Invalid dimensions for batch_k_scale_plus");
78  cpp_assert(etl::dim<0>(a) == etl::dim<0>(c), "Invalid dimensions for batch_k_scale_plus");
79  }
80  }
81  }
82 
83  // Assignment functions
84 
89  template <etl_expr L>
90  void assign_to(L&& lhs) const {
91  inc_counter("temp:assign");
92 
93  auto& a = this->a();
94  auto& b = this->b();
95  auto& c = this->c();
96 
97  check(a, b, c, lhs);
98 
99  if constexpr (D4) {
100  const auto Batch = etl::dim<0>(lhs);
101  const auto K = etl::dim<1>(lhs);
102  const auto M = etl::dim<2>(lhs);
103  const auto N = etl::dim<3>(lhs);
104 
105  if constexpr (impl::egblas::has_sbatch_k_scale4 && all_row_major<A, B, L> && all_floating<A, B, L>) {
106  decltype(auto) t1 = smart_forward_gpu(a);
107  decltype(auto) t2 = smart_forward_gpu(b);
108  decltype(auto) t3 = smart_forward_gpu(c);
109 
110  t1.ensure_gpu_up_to_date();
111  t2.ensure_gpu_up_to_date();
112  t3.ensure_gpu_up_to_date();
113 
114  lhs.ensure_gpu_allocated();
115 
116  impl::egblas::batch_k_scale_plus(Batch, K, M, N, t2.gpu_memory(), t1.gpu_memory(), t3.gpu_memory(), lhs.gpu_memory());
117 
118  lhs.validate_gpu();
119  lhs.invalidate_cpu();
120  } else {
121  standard_evaluator::pre_assign_rhs(a);
122  standard_evaluator::pre_assign_rhs(b);
123 
124  a.ensure_cpu_up_to_date();
125  b.ensure_cpu_up_to_date();
126  c.ensure_cpu_up_to_date();
127 
128  auto batch_fun_b = [&](const size_t first, const size_t last) {
129  CPU_SECTION {
130  if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
131  using vec_type = default_vec;
132  using T = value_t<L>;
133 
134  static constexpr size_t vec_size = vec_type::template traits<T>::size;
135 
136  const auto MN = M * N;
137 
138  for (size_t batch = first; batch < last; ++batch) {
139  for (size_t k = 0; k < K; ++k) {
140  T ak = a(k);
141  T ck = c(k);
142 
143  auto lhs_sub = lhs(batch)(k);
144  auto b_sub = b(batch)(k);
145 
146  size_t mn = 0;
147 
148  auto a1 = vec_type::set(ak);
149  auto c1 = vec_type::set(ck);
150 
151  for (; mn + 4 * vec_size - 1 < MN; mn += 4 * vec_size) {
152  auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
153  auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
154  auto b3 = b_sub.template loadu<vec_type>(mn + 2 * vec_size);
155  auto b4 = b_sub.template loadu<vec_type>(mn + 3 * vec_size);
156 
157  auto r1 = vec_type::add(vec_type::mul(a1, b1), c1);
158  auto r2 = vec_type::add(vec_type::mul(a1, b2), c1);
159  auto r3 = vec_type::add(vec_type::mul(a1, b3), c1);
160  auto r4 = vec_type::add(vec_type::mul(a1, b4), c1);
161 
162  lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
163  lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
164  lhs_sub.template storeu<vec_type>(r3, mn + 2 * vec_size);
165  lhs_sub.template storeu<vec_type>(r4, mn + 3 * vec_size);
166  }
167 
168  for (; mn + 2 * vec_size - 1 < MN; mn += 2 * vec_size) {
169  auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
170  auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
171 
172  auto r1 = vec_type::add(vec_type::mul(a1, b1), c1);
173  auto r2 = vec_type::add(vec_type::mul(a1, b2), c1);
174 
175  lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
176  lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
177  }
178 
179  for (; mn + vec_size - 1 < MN; mn += vec_size) {
180  auto b1 = b_sub.template loadu<vec_type>(mn);
181 
182  auto r1 = vec_type::add(vec_type::mul(a1, b1), c1);
183 
184  lhs_sub.template storeu<vec_type>(r1, mn);
185  }
186 
187  for (; mn + 3 < MN; mn += 4) {
188  lhs_sub[mn + 0] = ak * b_sub[mn + 0] + ck;
189  lhs_sub[mn + 1] = ak * b_sub[mn + 1] + ck;
190  lhs_sub[mn + 2] = ak * b_sub[mn + 2] + ck;
191  lhs_sub[mn + 3] = ak * b_sub[mn + 3] + ck;
192  }
193 
194  for (; mn + 1 < MN; mn += 2) {
195  lhs_sub[mn + 0] = ak * b_sub[mn + 0] + ck;
196  lhs_sub[mn + 1] = ak * b_sub[mn + 1] + ck;
197  }
198 
199  for (; mn < MN; ++mn) {
200  lhs_sub[mn] = ak * b_sub[mn] + ck;
201  }
202  }
203  }
204  } else {
205  for (size_t batch = first; batch < last; ++batch) {
206  for (size_t k = 0; k < K; ++k) {
207  for (size_t m = 0; m < M; ++m) {
208  for (size_t n = 0; n < N; ++n) {
209  lhs(batch, k, m, n) = a(k) * b(batch, k, m, n) + c(k);
210  }
211  }
212  }
213  }
214  }
215  }
216  };
217 
218  engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);
219 
220  lhs.validate_cpu();
221  lhs.invalidate_gpu();
222  }
223  } else {
224  const auto Batch = etl::dim<0>(lhs);
225  const auto K = etl::dim<1>(lhs);
226 
227  if constexpr (impl::egblas::has_sbatch_k_scale_plus2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
228  decltype(auto) t1 = smart_forward_gpu(a);
229  decltype(auto) t2 = smart_forward_gpu(b);
230  decltype(auto) t3 = smart_forward_gpu(c);
231 
232  t1.ensure_gpu_up_to_date();
233  t2.ensure_gpu_up_to_date();
234  t3.ensure_gpu_up_to_date();
235 
236  lhs.ensure_gpu_allocated();
237 
238  impl::egblas::batch_k_scale_plus(Batch, K, t2.gpu_memory(), t1.gpu_memory(), t3.gpu_memory(), lhs.gpu_memory());
239 
240  lhs.validate_gpu();
241  lhs.invalidate_cpu();
242  } else {
243  standard_evaluator::pre_assign_rhs(a);
244  standard_evaluator::pre_assign_rhs(b);
245 
246  a.ensure_cpu_up_to_date();
247  b.ensure_cpu_up_to_date();
248  c.ensure_cpu_up_to_date();
249 
250  auto batch_fun_b = [&](const size_t first, const size_t last) {
251  CPU_SECTION {
252  if constexpr (vec_enabled && all_vectorizable<vector_mode, A, B, L> && all_row_major<A, B, L>) {
253  using vec_type = default_vec;
254  using T = value_t<L>;
255 
256  static constexpr size_t vec_size = vec_type::template traits<T>::size;
257 
258  for (size_t batch = first; batch < last; ++batch) {
259  size_t k = 0;
260 
261  const size_t base = batch * K;
262 
263  for (; k + 4 * vec_size - 1 < K; k += 4 * vec_size) {
264  auto a1 = a.template load<vec_type>(k + 0 * vec_size);
265  auto a2 = a.template load<vec_type>(k + 1 * vec_size);
266  auto a3 = a.template load<vec_type>(k + 2 * vec_size);
267  auto a4 = a.template load<vec_type>(k + 3 * vec_size);
268 
269  auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
270  auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
271  auto b3 = b.template loadu<vec_type>(base + k + 2 * vec_size);
272  auto b4 = b.template loadu<vec_type>(base + k + 3 * vec_size);
273 
274  auto c1 = c.template loadu<vec_type>(k + 0 * vec_size);
275  auto c2 = c.template loadu<vec_type>(k + 1 * vec_size);
276  auto c3 = c.template loadu<vec_type>(k + 2 * vec_size);
277  auto c4 = c.template loadu<vec_type>(k + 3 * vec_size);
278 
279  auto r1 = vec_type::add(vec_type::mul(a1, b1), c1);
280  auto r2 = vec_type::add(vec_type::mul(a2, b2), c2);
281  auto r3 = vec_type::add(vec_type::mul(a3, b3), c3);
282  auto r4 = vec_type::add(vec_type::mul(a4, b4), c4);
283 
284  lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
285  lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
286  lhs.template storeu<vec_type>(r3, base + k + 2 * vec_size);
287  lhs.template storeu<vec_type>(r4, base + k + 3 * vec_size);
288  }
289 
290  for (; k + 2 * vec_size - 1 < K; k += 2 * vec_size) {
291  auto a1 = a.template load<vec_type>(k + 0 * vec_size);
292  auto a2 = a.template load<vec_type>(k + 1 * vec_size);
293 
294  auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
295  auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
296 
297  auto c1 = c.template loadu<vec_type>(k + 0 * vec_size);
298  auto c2 = c.template loadu<vec_type>(k + 1 * vec_size);
299 
300  auto r1 = vec_type::add(vec_type::mul(a1, b1), c1);
301  auto r2 = vec_type::add(vec_type::mul(a2, b2), c2);
302 
303  lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
304  lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
305  }
306 
307  for (; k + vec_size - 1 < K; k += vec_size) {
308  auto a1 = a.template load<vec_type>(k);
309 
310  auto b1 = b.template loadu<vec_type>(base + k);
311 
312  auto c1 = c.template loadu<vec_type>(k);
313 
314  auto r1 = vec_type::add(vec_type::mul(a1, b1), c1);
315 
316  lhs.template storeu<vec_type>(r1, base + k);
317  }
318 
319  for (; k + 3 < K; k += 4) {
320  lhs(batch, k + 0) = a(k + 0) * b(batch, k + 0) + c(k + 0);
321  lhs(batch, k + 1) = a(k + 1) * b(batch, k + 1) + c(k + 1);
322  lhs(batch, k + 2) = a(k + 2) * b(batch, k + 2) + c(k + 2);
323  lhs(batch, k + 3) = a(k + 3) * b(batch, k + 3) + c(k + 3);
324  }
325 
326  for (; k + 1 < K; k += 2) {
327  lhs(batch, k + 0) = a(k + 0) * b(batch, k + 0) + c(k + 0);
328  lhs(batch, k + 1) = a(k + 1) * b(batch, k + 1) + c(k + 1);
329  }
330 
331  if (k < K) {
332  lhs(batch, k) = a(k) * b(batch, k) + c(k);
333  }
334  }
335  } else {
336  for (size_t batch = first; batch < last; ++batch) {
337  for (size_t k = 0; k < K; ++k) {
338  lhs(batch, k) = a(k) * b(batch, k) + c(k);
339  }
340  }
341  }
342  }
343  };
344 
345  engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);
346 
347  lhs.validate_cpu();
348  lhs.invalidate_gpu();
349  }
350  }
351  }
352 
357  template <etl_expr L>
358  void assign_add_to(L&& lhs) const {
359  auto& a = this->a();
360  auto& b = this->b();
361  auto& c = this->c();
362 
363  check(a, b, c, lhs);
364 
365  if constexpr (D4) {
366  if constexpr (impl::egblas::has_sbatch_k_scale_plus4 && all_row_major<A, B, L> && all_floating<A, B, L>) {
367  std_add_evaluate(*this, lhs);
368  } else {
369  const auto Batch = etl::dim<0>(lhs);
370  const auto K = etl::dim<1>(lhs);
371  const auto M = etl::dim<2>(lhs);
372  const auto N = etl::dim<3>(lhs);
373 
374  standard_evaluator::pre_assign_rhs(a);
375  standard_evaluator::pre_assign_rhs(b);
376 
377  a.ensure_cpu_up_to_date();
378  b.ensure_cpu_up_to_date();
379  c.ensure_cpu_up_to_date();
380  lhs.ensure_cpu_up_to_date();
381 
382  auto batch_fun_b = [&](const size_t first, const size_t last) {
383  CPU_SECTION {
384  if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
385  using vec_type = default_vec;
386  using T = value_t<L>;
387 
388  static constexpr size_t vec_size = vec_type::template traits<T>::size;
389 
390  const auto MN = M * N;
391 
392  for (size_t batch = first; batch < last; ++batch) {
393  for (size_t k = 0; k < K; ++k) {
394  T ak = a(k);
395  T ck = c(k);
396 
397  auto lhs_sub = lhs(batch)(k);
398  auto b_sub = b(batch)(k);
399 
400  size_t mn = 0;
401 
402  auto a1 = vec_type::set(ak);
403  auto c1 = vec_type::set(ck);
404 
405  for (; mn + 4 * vec_size - 1 < MN; mn += 4 * vec_size) {
406  auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
407  auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
408  auto b3 = b_sub.template loadu<vec_type>(mn + 2 * vec_size);
409  auto b4 = b_sub.template loadu<vec_type>(mn + 3 * vec_size);
410 
411  auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
412  auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);
413  auto l3 = lhs_sub.template loadu<vec_type>(mn + 2 * vec_size);
414  auto l4 = lhs_sub.template loadu<vec_type>(mn + 3 * vec_size);
415 
416  auto r1 = vec_type::add(l1, vec_type::add(vec_type::mul(a1, b1), c1));
417  auto r2 = vec_type::add(l2, vec_type::add(vec_type::mul(a1, b2), c1));
418  auto r3 = vec_type::add(l3, vec_type::add(vec_type::mul(a1, b3), c1));
419  auto r4 = vec_type::add(l4, vec_type::add(vec_type::mul(a1, b4), c1));
420 
421  lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
422  lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
423  lhs_sub.template storeu<vec_type>(r3, mn + 2 * vec_size);
424  lhs_sub.template storeu<vec_type>(r4, mn + 3 * vec_size);
425  }
426 
427  for (; mn + 2 * vec_size - 1 < MN; mn += 2 * vec_size) {
428  auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
429  auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
430 
431  auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
432  auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);
433 
434  auto r1 = vec_type::add(l1, vec_type::add(vec_type::mul(a1, b1), c1));
435  auto r2 = vec_type::add(l2, vec_type::add(vec_type::mul(a1, b2), c1));
436 
437  lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
438  lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
439  }
440 
441  for (; mn + vec_size - 1 < MN; mn += vec_size) {
442  auto b1 = b_sub.template loadu<vec_type>(mn);
443 
444  auto l1 = lhs_sub.template loadu<vec_type>(mn);
445 
446  auto r1 = vec_type::add(l1, vec_type::add(vec_type::mul(a1, b1), c1));
447 
448  lhs_sub.template storeu<vec_type>(r1, mn);
449  }
450 
451  for (; mn + 3 < MN; mn += 4) {
452  lhs_sub[mn + 0] += ak * b_sub[mn + 0] + ck;
453  lhs_sub[mn + 1] += ak * b_sub[mn + 1] + ck;
454  lhs_sub[mn + 2] += ak * b_sub[mn + 2] + ck;
455  lhs_sub[mn + 3] += ak * b_sub[mn + 3] + ck;
456  }
457 
458  for (; mn + 1 < MN; mn += 2) {
459  lhs_sub[mn + 0] += ak * b_sub[mn + 0] + ck;
460  lhs_sub[mn + 1] += ak * b_sub[mn + 1] + ck;
461  }
462 
463  for (; mn < MN; ++mn) {
464  lhs_sub[mn] += ak * b_sub[mn] + ck;
465  }
466  }
467  }
468  } else {
469  for (size_t batch = first; batch < last; ++batch) {
470  for (size_t k = 0; k < K; ++k) {
471  for (size_t m = 0; m < M; ++m) {
472  for (size_t n = 0; n < N; ++n) {
473  lhs(batch, k, m, n) += a(k) * b(batch, k, m, n) + c(k);
474  }
475  }
476  }
477  }
478  }
479  }
480  };
481 
482  engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);
483 
484  lhs.validate_cpu();
485  lhs.invalidate_gpu();
486  }
487  } else {
488  if constexpr (impl::egblas::has_sbatch_k_scale_plus2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
489  std_add_evaluate(*this, lhs);
490  } else {
491  const auto Batch = etl::dim<0>(lhs);
492  const auto K = etl::dim<1>(lhs);
493 
494  standard_evaluator::pre_assign_rhs(a);
495  standard_evaluator::pre_assign_rhs(b);
496 
497  a.ensure_cpu_up_to_date();
498  b.ensure_cpu_up_to_date();
499  c.ensure_cpu_up_to_date();
500  lhs.ensure_cpu_up_to_date();
501 
502  auto batch_fun_b = [&](const size_t first, const size_t last) {
503  CPU_SECTION {
504  if constexpr (vec_enabled && all_vectorizable<vector_mode, A, B, L> && all_row_major<A, B, L>) {
505  using vec_type = default_vec;
506  using T = value_t<L>;
507 
508  static constexpr size_t vec_size = vec_type::template traits<T>::size;
509 
510  for (size_t batch = first; batch < last; ++batch) {
511  size_t k = 0;
512 
513  const size_t base = batch * K;
514 
515  for (; k + 4 * vec_size - 1 < K; k += 4 * vec_size) {
516  auto a1 = a.template load<vec_type>(k + 0 * vec_size);
517  auto a2 = a.template load<vec_type>(k + 1 * vec_size);
518  auto a3 = a.template load<vec_type>(k + 2 * vec_size);
519  auto a4 = a.template load<vec_type>(k + 3 * vec_size);
520 
521  auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
522  auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
523  auto b3 = b.template loadu<vec_type>(base + k + 2 * vec_size);
524  auto b4 = b.template loadu<vec_type>(base + k + 3 * vec_size);
525 
526  auto c1 = c.template loadu<vec_type>(k + 0 * vec_size);
527  auto c2 = c.template loadu<vec_type>(k + 1 * vec_size);
528  auto c3 = c.template loadu<vec_type>(k + 2 * vec_size);
529  auto c4 = c.template loadu<vec_type>(k + 3 * vec_size);
530 
531  auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
532  auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);
533  auto l3 = lhs.template loadu<vec_type>(base + k + 2 * vec_size);
534  auto l4 = lhs.template loadu<vec_type>(base + k + 3 * vec_size);
535 
536  auto r1 = vec_type::add(l1, vec_type::add(vec_type::mul(a1, b1), c1));
537  auto r2 = vec_type::add(l2, vec_type::add(vec_type::mul(a2, b2), c2));
538  auto r3 = vec_type::add(l3, vec_type::add(vec_type::mul(a3, b3), c3));
539  auto r4 = vec_type::add(l4, vec_type::add(vec_type::mul(a4, b4), c4));
540 
541  lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
542  lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
543  lhs.template storeu<vec_type>(r3, base + k + 2 * vec_size);
544  lhs.template storeu<vec_type>(r4, base + k + 3 * vec_size);
545  }
546 
547  for (; k + 2 * vec_size - 1 < K; k += 2 * vec_size) {
548  auto a1 = a.template load<vec_type>(k + 0 * vec_size);
549  auto a2 = a.template load<vec_type>(k + 1 * vec_size);
550 
551  auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
552  auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
553 
554  auto c1 = c.template loadu<vec_type>(k + 0 * vec_size);
555  auto c2 = c.template loadu<vec_type>(k + 1 * vec_size);
556 
557  auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
558  auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);
559 
560  auto r1 = vec_type::add(l1, vec_type::add(vec_type::mul(a1, b1), c1));
561  auto r2 = vec_type::add(l2, vec_type::add(vec_type::mul(a2, b2), c2));
562 
563  lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
564  lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
565  }
566 
567  for (; k + vec_size - 1 < K; k += vec_size) {
568  auto a1 = a.template load<vec_type>(k);
569 
570  auto b1 = b.template loadu<vec_type>(base + k);
571 
572  auto c1 = c.template loadu<vec_type>(k);
573 
574  auto l1 = lhs.template loadu<vec_type>(base + k);
575 
576  auto r1 = vec_type::add(l1, vec_type::add(vec_type::mul(a1, b1), c1));
577 
578  lhs.template storeu<vec_type>(r1, base + k);
579  }
580 
581  for (; k + 3 < K; k += 4) {
582  lhs(batch, k + 0) += a(k + 0) * b(batch, k + 0) + c(k + 0);
583  lhs(batch, k + 1) += a(k + 1) * b(batch, k + 1) + c(k + 1);
584  lhs(batch, k + 2) += a(k + 2) * b(batch, k + 2) + c(k + 2);
585  lhs(batch, k + 3) += a(k + 3) * b(batch, k + 3) + c(k + 3);
586  }
587 
588  for (; k + 1 < K; k += 2) {
589  lhs(batch, k + 0) += a(k + 0) * b(batch, k + 0) + c(k + 0);
590  lhs(batch, k + 1) += a(k + 1) * b(batch, k + 1) + c(k + 1);
591  }
592 
593  if (k < K) {
594  lhs(batch, k) += a(k) * b(batch, k) + c(k);
595  }
596  }
597  } else {
598  for (size_t batch = first; batch < last; ++batch) {
599  for (size_t k = 0; k < K; ++k) {
600  lhs(batch, k) += a(k) * b(batch, k) + c(k);
601  }
602  }
603  }
604  }
605  };
606 
607  engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);
608 
609  lhs.validate_cpu();
610  lhs.invalidate_gpu();
611  }
612  }
613  }
614 
619  template <etl_expr L>
620  void assign_sub_to(L&& lhs) const {
621  auto& a = this->a();
622  auto& b = this->b();
623  auto& c = this->c();
624 
625  check(a, b, c, lhs);
626 
627  if constexpr (D4) {
628  if constexpr (impl::egblas::has_sbatch_k_scale_plus2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
629  std_sub_evaluate(*this, lhs);
630  } else {
631  const auto Batch = etl::dim<0>(lhs);
632  const auto K = etl::dim<1>(lhs);
633  const auto M = etl::dim<2>(lhs);
634  const auto N = etl::dim<3>(lhs);
635 
636  standard_evaluator::pre_assign_rhs(a);
637  standard_evaluator::pre_assign_rhs(b);
638 
639  a.ensure_cpu_up_to_date();
640  b.ensure_cpu_up_to_date();
641  c.ensure_cpu_up_to_date();
642  lhs.ensure_cpu_up_to_date();
643 
644  auto batch_fun_b = [&](const size_t first, const size_t last) {
645  CPU_SECTION {
646  if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
647  using vec_type = default_vec;
648  using T = value_t<L>;
649 
650  static constexpr size_t vec_size = vec_type::template traits<T>::size;
651 
652  const auto MN = M * N;
653 
654  for (size_t batch = first; batch < last; ++batch) {
655  for (size_t k = 0; k < K; ++k) {
656  T ak = a(k);
657  T ck = c(k);
658 
659  auto lhs_sub = lhs(batch)(k);
660  auto b_sub = b(batch)(k);
661 
662  size_t mn = 0;
663 
664  auto a1 = vec_type::set(ak);
665  auto c1 = vec_type::set(ck);
666 
667  for (; mn + 4 * vec_size - 1 < MN; mn += 4 * vec_size) {
668  auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
669  auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
670  auto b3 = b_sub.template loadu<vec_type>(mn + 2 * vec_size);
671  auto b4 = b_sub.template loadu<vec_type>(mn + 3 * vec_size);
672 
673  auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
674  auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);
675  auto l3 = lhs_sub.template loadu<vec_type>(mn + 2 * vec_size);
676  auto l4 = lhs_sub.template loadu<vec_type>(mn + 3 * vec_size);
677 
678  auto r1 = vec_type::sub(l1, vec_type::add(vec_type::mul(a1, b1), c1));
679  auto r2 = vec_type::sub(l2, vec_type::add(vec_type::mul(a1, b2), c1));
680  auto r3 = vec_type::sub(l3, vec_type::add(vec_type::mul(a1, b3), c1));
681  auto r4 = vec_type::sub(l4, vec_type::add(vec_type::mul(a1, b4), c1));
682 
683  lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
684  lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
685  lhs_sub.template storeu<vec_type>(r3, mn + 2 * vec_size);
686  lhs_sub.template storeu<vec_type>(r4, mn + 3 * vec_size);
687  }
688 
689  for (; mn + 2 * vec_size - 1 < MN; mn += 2 * vec_size) {
690  auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
691  auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
692 
693  auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
694  auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);
695 
696  auto r1 = vec_type::sub(l1, vec_type::add(vec_type::mul(a1, b1), c1));
697  auto r2 = vec_type::sub(l2, vec_type::add(vec_type::mul(a1, b2), c1));
698 
699  lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
700  lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
701  }
702 
703  for (; mn + vec_size - 1 < MN; mn += vec_size) {
704  auto b1 = b_sub.template loadu<vec_type>(mn);
705 
706  auto l1 = lhs_sub.template loadu<vec_type>(mn);
707 
708  auto r1 = vec_type::sub(l1, vec_type::add(vec_type::mul(a1, b1), c1));
709 
710  lhs_sub.template storeu<vec_type>(r1, mn);
711  }
712 
713  for (; mn + 3 < MN; mn += 4) {
714  lhs_sub[mn + 0] -= ak * b_sub[mn + 0] + ck;
715  lhs_sub[mn + 1] -= ak * b_sub[mn + 1] + ck;
716  lhs_sub[mn + 2] -= ak * b_sub[mn + 2] + ck;
717  lhs_sub[mn + 3] -= ak * b_sub[mn + 3] + ck;
718  }
719 
720  for (; mn + 1 < MN; mn += 2) {
721  lhs_sub[mn + 0] -= ak * b_sub[mn + 0] + ck;
722  lhs_sub[mn + 1] -= ak * b_sub[mn + 1] + ck;
723  }
724 
725  for (; mn < MN; ++mn) {
726  lhs_sub[mn] -= ak * b_sub[mn] + ck;
727  }
728  }
729  }
730  } else {
731  for (size_t batch = first; batch < last; ++batch) {
732  for (size_t k = 0; k < K; ++k) {
733  for (size_t m = 0; m < M; ++m) {
734  for (size_t n = 0; n < N; ++n) {
735  lhs(batch, k, m, n) -= a(k) * b(batch, k, m, n) + c(k);
736  }
737  }
738  }
739  }
740  }
741  }
742  };
743 
744  engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);
745 
746  lhs.validate_cpu();
747  lhs.invalidate_gpu();
748  }
749  } else {
750  if constexpr (impl::egblas::has_sbatch_k_scale_plus2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
751  std_sub_evaluate(*this, lhs);
752  } else {
753  const auto Batch = etl::dim<0>(lhs);
754  const auto K = etl::dim<1>(lhs);
755 
756  standard_evaluator::pre_assign_rhs(a);
757  standard_evaluator::pre_assign_rhs(b);
758 
759  a.ensure_cpu_up_to_date();
760  b.ensure_cpu_up_to_date();
761  c.ensure_cpu_up_to_date();
762  lhs.ensure_cpu_up_to_date();
763 
764  auto batch_fun_b = [&](const size_t first, const size_t last) {
765  CPU_SECTION {
766  if constexpr (vec_enabled && all_vectorizable<vector_mode, A, B, L> && all_row_major<A, B, L>) {
767  using vec_type = default_vec;
768  using T = value_t<L>;
769 
770  static constexpr size_t vec_size = vec_type::template traits<T>::size;
771 
772  for (size_t batch = first; batch < last; ++batch) {
773  size_t k = 0;
774 
775  const size_t base = batch * K;
776 
777  for (; k + 4 * vec_size - 1 < K; k += 4 * vec_size) {
778  auto a1 = a.template load<vec_type>(k + 0 * vec_size);
779  auto a2 = a.template load<vec_type>(k + 1 * vec_size);
780  auto a3 = a.template load<vec_type>(k + 2 * vec_size);
781  auto a4 = a.template load<vec_type>(k + 3 * vec_size);
782 
783  auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
784  auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
785  auto b3 = b.template loadu<vec_type>(base + k + 2 * vec_size);
786  auto b4 = b.template loadu<vec_type>(base + k + 3 * vec_size);
787 
788  auto c1 = c.template loadu<vec_type>(k + 0 * vec_size);
789  auto c2 = c.template loadu<vec_type>(k + 1 * vec_size);
790  auto c3 = c.template loadu<vec_type>(k + 2 * vec_size);
791  auto c4 = c.template loadu<vec_type>(k + 3 * vec_size);
792 
793  auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
794  auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);
795  auto l3 = lhs.template loadu<vec_type>(base + k + 2 * vec_size);
796  auto l4 = lhs.template loadu<vec_type>(base + k + 3 * vec_size);
797 
798  auto r1 = vec_type::sub(l1, vec_type::add(vec_type::mul(a1, b1), c1));
799  auto r2 = vec_type::sub(l2, vec_type::add(vec_type::mul(a2, b2), c2));
800  auto r3 = vec_type::sub(l3, vec_type::add(vec_type::mul(a3, b3), c3));
801  auto r4 = vec_type::sub(l4, vec_type::add(vec_type::mul(a4, b4), c4));
802 
803  lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
804  lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
805  lhs.template storeu<vec_type>(r3, base + k + 2 * vec_size);
806  lhs.template storeu<vec_type>(r4, base + k + 3 * vec_size);
807  }
808 
809  for (; k + 2 * vec_size - 1 < K; k += 2 * vec_size) {
810  auto a1 = a.template load<vec_type>(k + 0 * vec_size);
811  auto a2 = a.template load<vec_type>(k + 1 * vec_size);
812 
813  auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
814  auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
815 
816  auto c1 = c.template loadu<vec_type>(k + 0 * vec_size);
817  auto c2 = c.template loadu<vec_type>(k + 1 * vec_size);
818 
819  auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
820  auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);
821 
822  auto r1 = vec_type::sub(l1, vec_type::add(vec_type::mul(a1, b1), c1));
823  auto r2 = vec_type::sub(l2, vec_type::add(vec_type::mul(a2, b2), c2));
824 
825  lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
826  lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
827  }
828 
829  for (; k + vec_size - 1 < K; k += vec_size) {
830  auto a1 = a.template load<vec_type>(k);
831 
832  auto b1 = b.template loadu<vec_type>(base + k);
833 
834  auto c1 = c.template loadu<vec_type>(k);
835 
836  auto l1 = lhs.template loadu<vec_type>(base + k);
837 
838  auto r1 = vec_type::sub(l1, vec_type::add(vec_type::mul(a1, b1), c1));
839 
840  lhs.template storeu<vec_type>(r1, base + k);
841  }
842 
843  for (; k + 3 < K; k += 4) {
844  lhs(batch, k + 0) -= a(k + 0) * b(batch, k + 0) + c(k + 0);
845  lhs(batch, k + 1) -= a(k + 1) * b(batch, k + 1) + c(k + 1);
846  lhs(batch, k + 2) -= a(k + 2) * b(batch, k + 2) + c(k + 2);
847  lhs(batch, k + 3) -= a(k + 3) * b(batch, k + 3) + c(k + 3);
848  }
849 
850  for (; k + 1 < K; k += 2) {
851  lhs(batch, k + 0) -= a(k + 0) * b(batch, k + 0) + c(k + 0);
852  lhs(batch, k + 1) -= a(k + 1) * b(batch, k + 1) + c(k + 1);
853  }
854 
855  if (k < K) {
856  lhs(batch, k) -= a(k) * b(batch, k) + c(k);
857  }
858  }
859  } else {
860  for (size_t batch = first; batch < last; ++batch) {
861  for (size_t k = 0; k < K; ++k) {
862  lhs(batch, k) -= a(k) * b(batch, k) + c(k);
863  }
864  }
865  }
866  }
867  };
868 
869  engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);
870 
871  lhs.validate_cpu();
872  lhs.invalidate_gpu();
873  }
874  }
875  }
876 
881  template <etl_expr L>
882  void assign_mul_to(L&& lhs) const {
883  auto& a = this->a();
884  auto& b = this->b();
885  auto& c = this->c();
886 
887  check(a, b, c, lhs);
888 
889  if constexpr (D4) {
890  if constexpr (impl::egblas::has_sbatch_k_scale_plus2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
891  std_mul_evaluate(*this, lhs);
892  } else {
893  const auto Batch = etl::dim<0>(lhs);
894  const auto K = etl::dim<1>(lhs);
895  const auto M = etl::dim<2>(lhs);
896  const auto N = etl::dim<3>(lhs);
897 
898  standard_evaluator::pre_assign_rhs(a);
899  standard_evaluator::pre_assign_rhs(b);
900 
901  a.ensure_cpu_up_to_date();
902  b.ensure_cpu_up_to_date();
903  c.ensure_cpu_up_to_date();
904  lhs.ensure_cpu_up_to_date();
905 
906  auto batch_fun_b = [&](const size_t first, const size_t last) {
907  CPU_SECTION {
908  if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
909  using vec_type = default_vec;
910  using T = value_t<L>;
911 
912  static constexpr size_t vec_size = vec_type::template traits<T>::size;
913 
914  const auto MN = M * N;
915 
916  for (size_t batch = first; batch < last; ++batch) {
917  for (size_t k = 0; k < K; ++k) {
918  T ak = a(k);
919  T ck = c(k);
920 
921  auto lhs_sub = lhs(batch)(k);
922  auto b_sub = b(batch)(k);
923 
924  size_t mn = 0;
925 
926  auto a1 = vec_type::set(ak);
927  auto c1 = vec_type::set(ck);
928 
929  for (; mn + 4 * vec_size - 1 < MN; mn += 4 * vec_size) {
930  auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
931  auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
932  auto b3 = b_sub.template loadu<vec_type>(mn + 2 * vec_size);
933  auto b4 = b_sub.template loadu<vec_type>(mn + 3 * vec_size);
934 
935  auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
936  auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);
937  auto l3 = lhs_sub.template loadu<vec_type>(mn + 2 * vec_size);
938  auto l4 = lhs_sub.template loadu<vec_type>(mn + 3 * vec_size);
939 
940  auto r1 = vec_type::mul(l1, vec_type::add(vec_type::mul(a1, b1), c1));
941  auto r2 = vec_type::mul(l2, vec_type::add(vec_type::mul(a1, b2), c1));
942  auto r3 = vec_type::mul(l3, vec_type::add(vec_type::mul(a1, b3), c1));
943  auto r4 = vec_type::mul(l4, vec_type::add(vec_type::mul(a1, b4), c1));
944 
945  lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
946  lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
947  lhs_sub.template storeu<vec_type>(r3, mn + 2 * vec_size);
948  lhs_sub.template storeu<vec_type>(r4, mn + 3 * vec_size);
949  }
950 
951  for (; mn + 2 * vec_size - 1 < MN; mn += 2 * vec_size) {
952  auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
953  auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
954 
955  auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
956  auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);
957 
958  auto r1 = vec_type::mul(l1, vec_type::add(vec_type::mul(a1, b1), c1));
959  auto r2 = vec_type::mul(l2, vec_type::add(vec_type::mul(a1, b2), c1));
960 
961  lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
962  lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
963  }
964 
965  for (; mn + vec_size - 1 < MN; mn += vec_size) {
966  auto b1 = b_sub.template loadu<vec_type>(mn);
967 
968  auto l1 = lhs_sub.template loadu<vec_type>(mn);
969 
970  auto r1 = vec_type::mul(l1, vec_type::add(vec_type::mul(a1, b1), c1));
971 
972  lhs_sub.template storeu<vec_type>(r1, mn);
973  }
974 
975  for (; mn + 3 < MN; mn += 4) {
976  lhs_sub[mn + 0] *= ak * b_sub[mn + 0] + ck;
977  lhs_sub[mn + 1] *= ak * b_sub[mn + 1] + ck;
978  lhs_sub[mn + 2] *= ak * b_sub[mn + 2] + ck;
979  lhs_sub[mn + 3] *= ak * b_sub[mn + 3] + ck;
980  }
981 
982  for (; mn + 1 < MN; mn += 2) {
983  lhs_sub[mn + 0] *= ak * b_sub[mn + 0] + ck;
984  lhs_sub[mn + 1] *= ak * b_sub[mn + 1] + ck;
985  }
986 
987  for (; mn < MN; ++mn) {
988  lhs_sub[mn] *= ak * b_sub[mn] + ck;
989  }
990  }
991  }
992  } else {
993  for (size_t batch = first; batch < last; ++batch) {
994  for (size_t k = 0; k < K; ++k) {
995  for (size_t m = 0; m < M; ++m) {
996  for (size_t n = 0; n < N; ++n) {
997  lhs(batch, k, m, n) *= a(k) * b(batch, k, m, n) + c(k);
998  }
999  }
1000  }
1001  }
1002  }
1003  }
1004  };
1005 
1006  engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);
1007 
1008  lhs.validate_cpu();
1009  lhs.invalidate_gpu();
1010  }
1011  } else {
1012  if constexpr (impl::egblas::has_sbatch_k_scale_plus2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
1013  std_mul_evaluate(*this, lhs);
1014  } else {
1015  const auto Batch = etl::dim<0>(lhs);
1016  const auto K = etl::dim<1>(lhs);
1017 
1018  standard_evaluator::pre_assign_rhs(a);
1019  standard_evaluator::pre_assign_rhs(b);
1020 
1021  a.ensure_cpu_up_to_date();
1022  b.ensure_cpu_up_to_date();
1023  c.ensure_cpu_up_to_date();
1024  lhs.ensure_cpu_up_to_date();
1025 
1026  auto batch_fun_b = [&](const size_t first, const size_t last) {
1027  CPU_SECTION {
1028  if constexpr (vec_enabled && all_vectorizable<vector_mode, A, B, L> && all_row_major<A, B, L>) {
1029  using vec_type = default_vec;
1030  using T = value_t<L>;
1031 
1032  static constexpr size_t vec_size = vec_type::template traits<T>::size;
1033 
1034  for (size_t batch = first; batch < last; ++batch) {
1035  size_t k = 0;
1036 
1037  const size_t base = batch * K;
1038 
1039  for (; k + 4 * vec_size - 1 < K; k += 4 * vec_size) {
1040  auto a1 = a.template load<vec_type>(k + 0 * vec_size);
1041  auto a2 = a.template load<vec_type>(k + 1 * vec_size);
1042  auto a3 = a.template load<vec_type>(k + 2 * vec_size);
1043  auto a4 = a.template load<vec_type>(k + 3 * vec_size);
1044 
1045  auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
1046  auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
1047  auto b3 = b.template loadu<vec_type>(base + k + 2 * vec_size);
1048  auto b4 = b.template loadu<vec_type>(base + k + 3 * vec_size);
1049 
1050  auto c1 = c.template loadu<vec_type>(k + 0 * vec_size);
1051  auto c2 = c.template loadu<vec_type>(k + 1 * vec_size);
1052  auto c3 = c.template loadu<vec_type>(k + 2 * vec_size);
1053  auto c4 = c.template loadu<vec_type>(k + 3 * vec_size);
1054 
1055  auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
1056  auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);
1057  auto l3 = lhs.template loadu<vec_type>(base + k + 2 * vec_size);
1058  auto l4 = lhs.template loadu<vec_type>(base + k + 3 * vec_size);
1059 
1060  auto r1 = vec_type::mul(l1, vec_type::add(vec_type::mul(a1, b1), c1));
1061  auto r2 = vec_type::mul(l2, vec_type::add(vec_type::mul(a2, b2), c2));
1062  auto r3 = vec_type::mul(l3, vec_type::add(vec_type::mul(a3, b3), c3));
1063  auto r4 = vec_type::mul(l4, vec_type::add(vec_type::mul(a4, b4), c4));
1064 
1065  lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
1066  lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
1067  lhs.template storeu<vec_type>(r3, base + k + 2 * vec_size);
1068  lhs.template storeu<vec_type>(r4, base + k + 3 * vec_size);
1069  }
1070 
1071  for (; k + 2 * vec_size - 1 < K; k += 2 * vec_size) {
1072  auto a1 = a.template load<vec_type>(k + 0 * vec_size);
1073  auto a2 = a.template load<vec_type>(k + 1 * vec_size);
1074 
1075  auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
1076  auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
1077 
1078  auto c1 = c.template loadu<vec_type>(k + 0 * vec_size);
1079  auto c2 = c.template loadu<vec_type>(k + 1 * vec_size);
1080 
1081  auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
1082  auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);
1083 
1084  auto r1 = vec_type::mul(l1, vec_type::add(vec_type::mul(a1, b1), c1));
1085  auto r2 = vec_type::mul(l2, vec_type::add(vec_type::mul(a2, b2), c2));
1086 
1087  lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
1088  lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
1089  }
1090 
1091  for (; k + vec_size - 1 < K; k += vec_size) {
1092  auto a1 = a.template load<vec_type>(k);
1093 
1094  auto b1 = b.template loadu<vec_type>(base + k);
1095 
1096  auto c1 = c.template loadu<vec_type>(k);
1097 
1098  auto l1 = lhs.template loadu<vec_type>(base + k);
1099 
1100  auto r1 = vec_type::mul(l1, vec_type::add(vec_type::mul(a1, b1), c1));
1101 
1102  lhs.template storeu<vec_type>(r1, base + k);
1103  }
1104 
1105  for (; k + 3 < K; k += 4) {
1106  lhs(batch, k + 0) *= a(k + 0) * b(batch, k + 0) + c(k + 0);
1107  lhs(batch, k + 1) *= a(k + 1) * b(batch, k + 1) + c(k + 1);
1108  lhs(batch, k + 2) *= a(k + 2) * b(batch, k + 2) + c(k + 2);
1109  lhs(batch, k + 3) *= a(k + 3) * b(batch, k + 3) + c(k + 3);
1110  }
1111 
1112  for (; k + 1 < K; k += 2) {
1113  lhs(batch, k + 0) *= a(k + 0) * b(batch, k + 0) + c(k + 0);
1114  lhs(batch, k + 1) *= a(k + 1) * b(batch, k + 1) + c(k + 1);
1115  }
1116 
1117  if (k < K) {
1118  lhs(batch, k) *= a(k) * b(batch, k) + c(k);
1119  }
1120  }
1121  } else {
1122  for (size_t batch = first; batch < last; ++batch) {
1123  for (size_t k = 0; k < K; ++k) {
1124  lhs(batch, k) *= a(k) * b(batch, k) + c(k);
1125  }
1126  }
1127  }
1128  }
1129  };
1130 
1131  engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);
1132 
1133  lhs.validate_cpu();
1134  lhs.invalidate_gpu();
1135  }
1136  }
1137  }
1138 
1143  template <etl_expr L>
1144  void assign_div_to(L&& lhs) const {
1145  auto& a = this->a();
1146  auto& b = this->b();
1147  auto& c = this->c();
1148 
1149  check(a, b, c, lhs);
1150 
1151  if constexpr (D4) {
1152  if constexpr (impl::egblas::has_sbatch_k_scale_plus2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
1153  std_div_evaluate(*this, lhs);
1154  } else {
1155  const auto Batch = etl::dim<0>(lhs);
1156  const auto K = etl::dim<1>(lhs);
1157  const auto M = etl::dim<2>(lhs);
1158  const auto N = etl::dim<3>(lhs);
1159 
1160  standard_evaluator::pre_assign_rhs(a);
1161  standard_evaluator::pre_assign_rhs(b);
1162 
1163  a.ensure_cpu_up_to_date();
1164  b.ensure_cpu_up_to_date();
1165  c.ensure_cpu_up_to_date();
1166  lhs.ensure_cpu_up_to_date();
1167 
1168  auto batch_fun_b = [&](const size_t first, const size_t last) {
1169  CPU_SECTION {
1170  if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
1171  using vec_type = default_vec;
1172  using T = value_t<L>;
1173 
1174  static constexpr size_t vec_size = vec_type::template traits<T>::size;
1175 
1176  const auto MN = M * N;
1177 
1178  for (size_t batch = first; batch < last; ++batch) {
1179  for (size_t k = 0; k < K; ++k) {
1180  T ak = a(k);
1181  T ck = c(k);
1182 
1183  auto lhs_sub = lhs(batch)(k);
1184  auto b_sub = b(batch)(k);
1185 
1186  size_t mn = 0;
1187 
1188  auto a1 = vec_type::set(ak);
1189  auto c1 = vec_type::set(ck);
1190 
1191  for (; mn + 4 * vec_size - 1 < MN; mn += 4 * vec_size) {
1192  auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
1193  auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
1194  auto b3 = b_sub.template loadu<vec_type>(mn + 2 * vec_size);
1195  auto b4 = b_sub.template loadu<vec_type>(mn + 3 * vec_size);
1196 
1197  auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
1198  auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);
1199  auto l3 = lhs_sub.template loadu<vec_type>(mn + 2 * vec_size);
1200  auto l4 = lhs_sub.template loadu<vec_type>(mn + 3 * vec_size);
1201 
1202  auto r1 = vec_type::div(l1, vec_type::add(vec_type::mul(a1, b1), c1));
1203  auto r2 = vec_type::div(l2, vec_type::add(vec_type::mul(a1, b2), c1));
1204  auto r3 = vec_type::div(l3, vec_type::add(vec_type::mul(a1, b3), c1));
1205  auto r4 = vec_type::div(l4, vec_type::add(vec_type::mul(a1, b4), c1));
1206 
1207  lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
1208  lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
1209  lhs_sub.template storeu<vec_type>(r3, mn + 2 * vec_size);
1210  lhs_sub.template storeu<vec_type>(r4, mn + 3 * vec_size);
1211  }
1212 
1213  for (; mn + 2 * vec_size - 1 < MN; mn += 2 * vec_size) {
1214  auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
1215  auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
1216 
1217  auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
1218  auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);
1219 
1220  auto r1 = vec_type::div(l1, vec_type::add(vec_type::mul(a1, b1), c1));
1221  auto r2 = vec_type::div(l2, vec_type::add(vec_type::mul(a1, b2), c1));
1222 
1223  lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
1224  lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
1225  }
1226 
1227  for (; mn + vec_size - 1 < MN; mn += vec_size) {
1228  auto b1 = b_sub.template loadu<vec_type>(mn);
1229 
1230  auto l1 = lhs_sub.template loadu<vec_type>(mn);
1231 
1232  auto r1 = vec_type::div(l1, vec_type::add(vec_type::mul(a1, b1), c1));
1233 
1234  lhs_sub.template storeu<vec_type>(r1, mn);
1235  }
1236 
1237  for (; mn + 3 < MN; mn += 4) {
1238  lhs_sub[mn + 0] /= ak * b_sub[mn + 0] + ck;
1239  lhs_sub[mn + 1] /= ak * b_sub[mn + 1] + ck;
1240  lhs_sub[mn + 2] /= ak * b_sub[mn + 2] + ck;
1241  lhs_sub[mn + 3] /= ak * b_sub[mn + 3] + ck;
1242  }
1243 
1244  for (; mn + 1 < MN; mn += 2) {
1245  lhs_sub[mn + 0] /= ak * b_sub[mn + 0] + ck;
1246  lhs_sub[mn + 1] /= ak * b_sub[mn + 1] + ck;
1247  }
1248 
1249  for (; mn < MN; ++mn) {
1250  lhs_sub[mn] /= ak * b_sub[mn] + ck;
1251  }
1252  }
1253  }
1254  } else {
1255  for (size_t batch = first; batch < last; ++batch) {
1256  for (size_t k = 0; k < K; ++k) {
1257  for (size_t m = 0; m < M; ++m) {
1258  for (size_t n = 0; n < N; ++n) {
1259  lhs(batch, k, m, n) /= a(k) * b(batch, k, m, n) + c(k);
1260  }
1261  }
1262  }
1263  }
1264  }
1265  }
1266  };
1267 
1268  engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);
1269 
1270  lhs.validate_cpu();
1271  lhs.invalidate_gpu();
1272  }
1273  } else {
1274  if constexpr (impl::egblas::has_sbatch_k_scale_plus2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
1275  std_div_evaluate(*this, lhs);
1276  } else {
1277  const auto Batch = etl::dim<0>(lhs);
1278  const auto K = etl::dim<1>(lhs);
1279 
1280  standard_evaluator::pre_assign_rhs(a);
1281  standard_evaluator::pre_assign_rhs(b);
1282 
1283  a.ensure_cpu_up_to_date();
1284  b.ensure_cpu_up_to_date();
1285  c.ensure_cpu_up_to_date();
1286  lhs.ensure_cpu_up_to_date();
1287 
1288  auto batch_fun_b = [&](const size_t first, const size_t last) {
1289  CPU_SECTION {
1290  if constexpr (vec_enabled && all_vectorizable<vector_mode, A, B, L> && all_row_major<A, B, L>) {
1291  using vec_type = default_vec;
1292  using T = value_t<L>;
1293 
1294  static constexpr size_t vec_size = vec_type::template traits<T>::size;
1295 
1296  for (size_t batch = first; batch < last; ++batch) {
1297  size_t k = 0;
1298 
1299  const size_t base = batch * K;
1300 
1301  for (; k + 4 * vec_size - 1 < K; k += 4 * vec_size) {
1302  auto a1 = a.template load<vec_type>(k + 0 * vec_size);
1303  auto a2 = a.template load<vec_type>(k + 1 * vec_size);
1304  auto a3 = a.template load<vec_type>(k + 2 * vec_size);
1305  auto a4 = a.template load<vec_type>(k + 3 * vec_size);
1306 
1307  auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
1308  auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
1309  auto b3 = b.template loadu<vec_type>(base + k + 2 * vec_size);
1310  auto b4 = b.template loadu<vec_type>(base + k + 3 * vec_size);
1311 
1312  auto c1 = c.template loadu<vec_type>(k + 0 * vec_size);
1313  auto c2 = c.template loadu<vec_type>(k + 1 * vec_size);
1314  auto c3 = c.template loadu<vec_type>(k + 2 * vec_size);
1315  auto c4 = c.template loadu<vec_type>(k + 3 * vec_size);
1316 
1317  auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
1318  auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);
1319  auto l3 = lhs.template loadu<vec_type>(base + k + 2 * vec_size);
1320  auto l4 = lhs.template loadu<vec_type>(base + k + 3 * vec_size);
1321 
1322  auto r1 = vec_type::div(l1, vec_type::add(vec_type::mul(a1, b1), c1));
1323  auto r2 = vec_type::div(l2, vec_type::add(vec_type::mul(a2, b2), c2));
1324  auto r3 = vec_type::div(l3, vec_type::add(vec_type::mul(a3, b3), c3));
1325  auto r4 = vec_type::div(l4, vec_type::add(vec_type::mul(a4, b4), c4));
1326 
1327  lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
1328  lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
1329  lhs.template storeu<vec_type>(r3, base + k + 2 * vec_size);
1330  lhs.template storeu<vec_type>(r4, base + k + 3 * vec_size);
1331  }
1332 
1333  for (; k + 2 * vec_size - 1 < K; k += 2 * vec_size) {
1334  auto a1 = a.template load<vec_type>(k + 0 * vec_size);
1335  auto a2 = a.template load<vec_type>(k + 1 * vec_size);
1336 
1337  auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
1338  auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
1339 
1340  auto c1 = c.template loadu<vec_type>(k + 0 * vec_size);
1341  auto c2 = c.template loadu<vec_type>(k + 1 * vec_size);
1342 
1343  auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
1344  auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);
1345 
1346  auto r1 = vec_type::div(l1, vec_type::add(vec_type::mul(a1, b1), c1));
1347  auto r2 = vec_type::div(l2, vec_type::add(vec_type::mul(a2, b2), c2));
1348 
1349  lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
1350  lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
1351  }
1352 
1353  for (; k + vec_size - 1 < K; k += vec_size) {
1354  auto a1 = a.template load<vec_type>(k);
1355 
1356  auto b1 = b.template loadu<vec_type>(base + k);
1357 
1358  auto c1 = c.template loadu<vec_type>(k);
1359 
1360  auto l1 = lhs.template loadu<vec_type>(base + k);
1361 
1362  auto r1 = vec_type::div(l1, vec_type::add(vec_type::mul(a1, b1), c1));
1363 
1364  lhs.template storeu<vec_type>(r1, base + k);
1365  }
1366 
1367  for (; k + 3 < K; k += 4) {
1368  lhs(batch, k + 0) /= a(k + 0) * b(batch, k + 0) + c(k + 0);
1369  lhs(batch, k + 1) /= a(k + 1) * b(batch, k + 1) + c(k + 1);
1370  lhs(batch, k + 2) /= a(k + 2) * b(batch, k + 2) + c(k + 2);
1371  lhs(batch, k + 3) /= a(k + 3) * b(batch, k + 3) + c(k + 3);
1372  }
1373 
1374  for (; k + 1 < K; k += 2) {
1375  lhs(batch, k + 0) /= a(k + 0) * b(batch, k + 0) + c(k + 0);
1376  lhs(batch, k + 1) /= a(k + 1) * b(batch, k + 1) + c(k + 1);
1377  }
1378 
1379  if (k < K) {
1380  lhs(batch, k) /= a(k) * b(batch, k) + c(k);
1381  }
1382  }
1383  } else {
1384  for (size_t batch = first; batch < last; ++batch) {
1385  for (size_t k = 0; k < K; ++k) {
1386  lhs(batch, k) /= a(k) * b(batch, k) + c(k);
1387  }
1388  }
1389  }
1390  }
1391  };
1392 
1393  engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);
1394 
1395  lhs.validate_cpu();
1396  lhs.invalidate_gpu();
1397  }
1398  }
1399  }
1400 
1405  template <etl_expr L>
1406  void assign_mod_to(L&& lhs) const {
1407  auto& a = this->a();
1408  auto& b = this->b();
1409  auto& c = this->c();
1410 
1411  check(a, b, c, lhs);
1412 
1413  standard_evaluator::pre_assign_rhs(a);
1414  standard_evaluator::pre_assign_rhs(b);
1415 
1416  a.ensure_cpu_up_to_date();
1417  b.ensure_cpu_up_to_date();
1418  c.ensure_cpu_up_to_date();
1419  lhs.ensure_cpu_up_to_date();
1420 
1421  if constexpr (D4) {
1422  const auto Batch = etl::dim<0>(lhs);
1423  const auto K = etl::dim<1>(lhs);
1424  const auto M = etl::dim<2>(lhs);
1425  const auto N = etl::dim<3>(lhs);
1426 
1427  auto batch_fun_b = [&](const size_t first, const size_t last) {
1428  CPU_SECTION {
1429  for (size_t batch = first; batch < last; ++batch) {
1430  for (size_t k = 0; k < K; ++k) {
1431  for (size_t m = 0; m < M; ++m) {
1432  for (size_t n = 0; n < N; ++n) {
1433  lhs(batch, k, m, n) %= a(k) * b(batch, k, m, n) + c(k);
1434  }
1435  }
1436  }
1437  }
1438  }
1439  };
1440 
1441  engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);
1442 
1443  lhs.validate_cpu();
1444  lhs.invalidate_gpu();
1445  } else {
1446  const auto Batch = etl::dim<0>(lhs);
1447  const auto K = etl::dim<1>(lhs);
1448 
1449  auto batch_fun_b = [&](const size_t first, const size_t last) {
1450  CPU_SECTION {
1451  for (size_t batch = first; batch < last; ++batch) {
1452  for (size_t k = 0; k < K; ++k) {
1453  lhs(batch, k) %= a(k) * b(batch, k) + c(k);
1454  }
1455  }
1456  }
1457  };
1458 
1459  engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);
1460 
1461  lhs.validate_cpu();
1462  lhs.invalidate_gpu();
1463  }
1464  }
1465 
1472  friend std::ostream& operator<<(std::ostream& os, const batch_k_scale_plus_expr& expr) {
1473  return os << "batch_k_scale_plus(" << expr._a << "," << expr._b << "," << expr._c << ")";
1474  }
1475 };
1476 
1481 template <typename A, typename B, typename C>
1484  using sub_expr_t = std::decay_t<B>;
1487 
1488  static constexpr bool is_etl = true;
1489  static constexpr bool is_transformer = false;
1490  static constexpr bool is_view = false;
1491  static constexpr bool is_magic_view = false;
1492  static constexpr bool is_fast = sub_traits::is_fast;
1493  static constexpr bool is_linear = false;
1494  static constexpr bool is_thread_safe = true;
1495  static constexpr bool is_value = false;
1496  static constexpr bool is_direct = true;
1497  static constexpr bool is_generator = false;
1498  static constexpr bool is_padded = false;
1499  static constexpr bool is_aligned = true;
1500  static constexpr bool is_temporary = true;
1501  static constexpr bool gpu_computable = true;
1502  static constexpr order storage_order = sub_traits::storage_order;
1503 
1509  template <vector_mode_t V>
1510  static constexpr bool vectorizable = true;
1511 
1516  template <size_t DD>
1517  static constexpr size_t dim() {
1518  return decay_traits<B>::template dim<DD>();
1519  }
1520 
1527  static size_t dim(const expr_t& e, size_t d) {
1528  return etl::dim(e._b, d);
1529  }
1530 
1536  static size_t size(const expr_t& e) {
1537  return etl::size(e._b);
1538  }
1539 
1544  static constexpr size_t size() {
1545  return decay_traits<B>::size();
1546  }
1547 
1552  static constexpr size_t dimensions() {
1553  return decay_traits<B>::dimensions();
1554  }
1555 
1560  static constexpr int complexity() noexcept {
1561  return -1;
1562  }
1563 };
1564 
1565 // Note: This function should not be called directly
1566 // instead, batch_hint(a >> b) should be used
1567 // But this function is used as helpers from batch_hint
1568 
1574 template <etl_1d A, etl_2d_or_4d B, etl_1d C>
1576  return {a, b, c};
1577 }
1578 
1579 } //end of namespace etl
std::add_lvalue_reference_t< B > b()
Returns the sub expression.
Definition: base_temporary_expr.hpp:702
static void check([[maybe_unused]] const A &a, [[maybe_unused]] const B &b, [[maybe_unused]] const C &c, [[maybe_unused]] L &lhs)
Validate the batch_k_scale_plus dimensions.
Definition: batch_k_scale_plus_expr.hpp:47
void assign_mod_to(L &&lhs) const
Modulo the given left-hand-side expression.
Definition: batch_k_scale_plus_expr.hpp:1406
void assign_sub_to(L &&lhs) const
Sub from the given left-hand-side expression.
Definition: batch_k_scale_plus_expr.hpp:620
batch_k_scale_plus_expr< detail::build_type< A >, detail::build_type< B >, detail::build_type< C > > batch_k_scale_plus(const A &a, const B &b, const C &c)
Returns the batch_k_scale_plus expression of the given sub expressions.
Definition: batch_k_scale_plus_expr.hpp:1575
value_t< A > value_type
The type of value of the expression.
Definition: batch_k_scale_plus_expr.hpp:16
void engine_dispatch_1d_serial(Functor &&functor, size_t first, size_t last, size_t threshold, [[maybe_unused]] size_t n_threads=etl::threads)
Dispatch the elements of a range to a functor in a parallel manner, using the global thread engine...
Definition: parallel_support.hpp:734
constexpr bool is_magic_view
Traits indicating if the given ETL type is a magic view expression.
Definition: traits.hpp:311
static constexpr auto storage_order
The sub storage order.
Definition: batch_k_scale_plus_expr.hpp:23
static constexpr size_t size()
Returns the size of the expression.
Definition: batch_k_scale_plus_expr.hpp:1544
constexpr bool vec_enabled
Indicates if vectorization is available in any format.
Definition: config.hpp:220
order
Storage order of a matrix.
Definition: order.hpp:15
static constexpr size_t dimensions()
Returns the number of dimensions of the expression.
Definition: batch_k_scale_plus_expr.hpp:1552
A _a
The first sub expression reference.
Definition: base_temporary_expr.hpp:638
static size_t size(const expr_t &e)
Returns the size of the expression.
Definition: batch_k_scale_plus_expr.hpp:1536
std::add_lvalue_reference_t< A > a()
Returns the sub expression.
Definition: base_temporary_expr.hpp:686
constexpr bool is_fast
Traits to test if the given ETL expresion type is fast (sizes known at compile-time) ...
Definition: traits.hpp:588
batch_k_scale_plus_expr(A a, B b, C c)
Construct a new expression.
Definition: batch_k_scale_plus_expr.hpp:37
typename VV::template vec_type< value_type > vec_type
The vectorization type for VV.
Definition: base_temporary_expr.hpp:107
Traits to get information about ETL types.
Definition: tmp.hpp:68
Root namespace for the ETL library.
Definition: adapter.hpp:15
void assign_div_to(L &&lhs) const
Divide the given left-hand-side expression.
Definition: batch_k_scale_plus_expr.hpp:1144
static constexpr size_t dimensions()
Return the number of dimensions of the expression.
Definition: traits_base.hpp:31
no_vec default_vec
The default vectorization scheme.
Definition: vectorization.hpp:242
auto dim(E &&value, size_t i) -> detail::identity_helper< E, dim_view< detail::build_identity_type< E >, D >>
Return a view representing the ith Dth dimension.
Definition: view_expression_builder.hpp:25
std::conditional_t< is_etl_value< T >, const std::decay_t< T > &, std::decay_t< T > > build_type
Helper to build the type for a sub expression.
Definition: expression_helpers.hpp:24
Abstract base class for temporary ternary expression.
Definition: base_temporary_expr.hpp:634
static size_t dim(const expr_t &e, size_t d)
Returns the dth dimension of the expression.
Definition: batch_k_scale_plus_expr.hpp:1527
void std_mul_evaluate(Expr &&expr, Result &&result)
Compound multiply evaluation of the expr into result.
Definition: evaluator.hpp:1233
void assign_add_to(L &&lhs) const
Add to the given left-hand-side expression.
Definition: batch_k_scale_plus_expr.hpp:358
constexpr bool is_transformer
Traits indicating if the given ETL type is a transformer expression.
Definition: traits.hpp:297
decltype(auto) smart_forward_gpu(E &expr)
Smart forwarding for a temporary expression that will be computed in GPU.
Definition: helpers.hpp:343
void assign_mul_to(L &&lhs) const
Multiply the given left-hand-side expression.
Definition: batch_k_scale_plus_expr.hpp:882
constexpr size_t size(const E &expr) noexcept
Returns the size of the given ETL expression.
Definition: helpers.hpp:108
constexpr bool is_view
Traits indicating if the given ETL type is a view expression.
Definition: traits.hpp:304
void std_sub_evaluate(Expr &&expr, Result &&result)
Compound subtract evaluation of the expr into result.
Definition: evaluator.hpp:1214
std::decay_t< B > sub_expr_t
The sub expression type.
Definition: batch_k_scale_plus_expr.hpp:1484
friend std::ostream & operator<<(std::ostream &os, const batch_k_scale_plus_expr &expr)
Print a representation of the expression on the given stream.
Definition: batch_k_scale_plus_expr.hpp:1472
static constexpr size_t dim()
Returns the DDth dimension of the expression.
Definition: batch_k_scale_plus_expr.hpp:1517
void assign_to(L &&lhs) const
Assign to a matrix of the same storage order.
Definition: batch_k_scale_plus_expr.hpp:90
constexpr bool is_thread_safe
Traits to test if the given ETL expresion type is thread safe.
Definition: traits.hpp:687
static constexpr bool D4
If the expression is 4D (instead of 2D)
Definition: batch_k_scale_plus_expr.hpp:21
B _b
The second sub expression reference.
Definition: base_temporary_expr.hpp:639
Definition: batch_k_scale_plus_expr.hpp:15
C _c
The third sub expression reference.
Definition: base_temporary_expr.hpp:640
value_t< A > value_type
The value type of the expression.
Definition: batch_k_scale_plus_expr.hpp:1486
static constexpr bool gpu_computable
Indicates if the temporary expression can be directly evaluated using only GPU.
Definition: batch_k_scale_plus_expr.hpp:29
typename decay_traits< E >::value_type value_t
Traits to extract the value type out of an ETL type.
Definition: tmp.hpp:81
void std_div_evaluate(Expr &&expr, Result &&result)
Compound divide evaluation of the expr into result.
Definition: evaluator.hpp:1252
void inc_counter([[maybe_unused]] const char *name)
Increase the given counter.
Definition: counters.hpp:25
static constexpr int complexity() noexcept
Estimate the complexity of computation.
Definition: batch_k_scale_plus_expr.hpp:1560
std::add_lvalue_reference_t< C > c()
Returns the sub expression.
Definition: base_temporary_expr.hpp:718
void std_add_evaluate(Expr &&expr, Result &&result)
Compound add evaluation of the expr into result.
Definition: evaluator.hpp:1195