Expression Templates Library (ETL)
batch_k_scale_expr.hpp
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
8 #pragma once
9 
10 #include "etl/expr/base_temporary_expr.hpp"
11 
13 
14 namespace etl {
15 
20 template <etl_1d A, etl_2d_or_4d B>
21 struct batch_k_scale_expr : base_temporary_expr_bin<batch_k_scale_expr<A, B>, A, B> {
26 
    // True when B (and thus the result) is 4D ([Batch, K, M, N]); false when it is 2D ([Batch, K])
    static constexpr bool D4 = is_4d<B>;

    // The storage order of the expression, taken from the left-hand-side traits
    static constexpr auto storage_order = left_traits::storage_order;

    // Indicates if this expression can be computed on GPU.
    // NOTE(review): the 4D term checks impl::egblas::has_dbatch_k_scale4 while every
    // assign_* dispatch below checks impl::egblas::has_sbatch_k_scale4 — presumably the
    // single- and double-precision kernels are compiled together in egblas, but this
    // mismatch should be confirmed (or the flags made consistent).
    static constexpr bool gpu_computable =
        (!D4 && impl::egblas::has_sbatch_k_scale2 && all_row_major<A, B> && all_floating<A, B>)
        || (D4 && impl::egblas::has_dbatch_k_scale4 && all_row_major<A, B> && all_floating<A, B>);
42  batch_k_scale_expr(A a, B b) : base_type(a, b) {
43  //Nothing else to init
44  }
45 
51  template <same_dimensions<B> C>
52  static void check([[maybe_unused]] const A& a, [[maybe_unused]] const B& b, [[maybe_unused]] const C& c) {
53  if constexpr (D4) {
54  if constexpr (all_fast<A, C>) {
55  static_assert(etl::dim<0, B>() == etl::dim<0, C>(), "Invalid dimensions for batch_k_scale");
56  static_assert(etl::dim<1, B>() == etl::dim<1, C>(), "Invalid dimensions for batch_k_scale");
57  static_assert(etl::dim<2, B>() == etl::dim<2, C>(), "Invalid dimensions for batch_k_scale");
58  static_assert(etl::dim<3, B>() == etl::dim<3, C>(), "Invalid dimensions for batch_k_scale");
59 
60  static_assert(etl::dim<0, A>() == etl::dim<1, B>(), "Invalid dimensions for batch_k_scale");
61  } else {
62  cpp_assert(etl::dim<0>(b) == etl::dim<0>(c), "Invalid dimensions for batch_k_scale");
63  cpp_assert(etl::dim<1>(b) == etl::dim<1>(c), "Invalid dimensions for batch_k_scale");
64  cpp_assert(etl::dim<2>(b) == etl::dim<2>(c), "Invalid dimensions for batch_k_scale");
65  cpp_assert(etl::dim<3>(b) == etl::dim<3>(c), "Invalid dimensions for batch_k_scale");
66 
67  cpp_assert(etl::dim<0>(a) == etl::dim<1>(b), "Invalid dimensions for batch_k_scale");
68  }
69  } else {
70  if constexpr (all_fast<A, C>) {
71  static_assert(etl::dim<0, B>() == etl::dim<0, C>(), "Invalid dimensions for batch_k_scale");
72  static_assert(etl::dim<1, B>() == etl::dim<1, C>(), "Invalid dimensions for batch_k_scale");
73 
74  static_assert(etl::dim<0, A>() == etl::dim<1, B>(), "Invalid dimensions for batch_k_scale");
75  } else {
76  cpp_assert(etl::dim<0>(b) == etl::dim<0>(c), "Invalid dimensions for batch_k_scale");
77  cpp_assert(etl::dim<1>(b) == etl::dim<1>(c), "Invalid dimensions for batch_k_scale");
78 
79  cpp_assert(etl::dim<0>(a) == etl::dim<1>(b), "Invalid dimensions for batch_k_scale");
80  }
81  }
82  }
83 
84  // Assignment functions
85 
    /*!
     * \brief Assign the result of the expression to the given left-hand side,
     * computing lhs(batch, k, ...) = a(k) * b(batch, k, ...).
     * \param lhs The expression to which the result is assigned
     */
    template <etl_expr L>
    void assign_to(L&& lhs) const {
        inc_counter("temp:assign");

        auto& a = this->a();
        auto& b = this->b();

        check(a, b, lhs);

        if constexpr (D4) {
            // 4D case: b and lhs are [Batch, K, M, N], a is [K]
            const auto Batch = etl::dim<0>(lhs);
            const auto K = etl::dim<1>(lhs);
            const auto M = etl::dim<2>(lhs);
            const auto N = etl::dim<3>(lhs);

            if constexpr (impl::egblas::has_sbatch_k_scale4 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                // GPU path: offload the whole computation to the egblas kernel
                decltype(auto) t1 = smart_forward_gpu(a);
                decltype(auto) t2 = smart_forward_gpu(b);

                t1.ensure_gpu_up_to_date();
                t2.ensure_gpu_up_to_date();

                lhs.ensure_gpu_allocated();

                impl::egblas::batch_k_scale(Batch, K, M, N, t2.gpu_memory(), t1.gpu_memory(), lhs.gpu_memory());

                lhs.validate_gpu();
                lhs.invalidate_cpu();
            } else {
                // CPU path
                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();

                // Functor processing the batch range [first, last)
                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    CPU_SECTION {
                        if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                            using vec_type = default_vec;
                            using T = value_t<L>;

                            static constexpr size_t vec_size = vec_type::template traits<T>::size;

                            const auto MN = M * N;

                            for (size_t batch = first; batch < last; ++batch) {
                                for (size_t k = 0; k < K; ++k) {
                                    T ak = a(k);

                                    auto lhs_sub = lhs(batch)(k);
                                    auto b_sub = b(batch)(k);

                                    size_t mn = 0;

                                    // Broadcast the scaling factor a(k) into a vector register
                                    auto a1 = vec_type::set(ak);

                                    // Vectorized loop, unrolled four times
                                    for (; mn + 4 * vec_size - 1 < MN; mn += 4 * vec_size) {
                                        auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                        auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                        auto b3 = b_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                        auto b4 = b_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                        auto r1 = vec_type::mul(a1, b1);
                                        auto r2 = vec_type::mul(a1, b2);
                                        auto r3 = vec_type::mul(a1, b3);
                                        auto r4 = vec_type::mul(a1, b4);

                                        lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                        lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                        lhs_sub.template storeu<vec_type>(r3, mn + 2 * vec_size);
                                        lhs_sub.template storeu<vec_type>(r4, mn + 3 * vec_size);
                                    }

                                    // Vectorized loop, unrolled twice
                                    for (; mn + 2 * vec_size - 1 < MN; mn += 2 * vec_size) {
                                        auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                        auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                        auto r1 = vec_type::mul(a1, b1);
                                        auto r2 = vec_type::mul(a1, b2);

                                        lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                        lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                    }

                                    // Plain vectorized loop
                                    for (; mn + vec_size - 1 < MN; mn += vec_size) {
                                        auto b1 = b_sub.template loadu<vec_type>(mn);

                                        auto r1 = vec_type::mul(a1, b1);

                                        lhs_sub.template storeu<vec_type>(r1, mn);
                                    }

                                    // Scalar loop, unrolled four times
                                    for (; mn + 3 < MN; mn += 4) {
                                        lhs_sub[mn + 0] = ak * b_sub[mn + 0];
                                        lhs_sub[mn + 1] = ak * b_sub[mn + 1];
                                        lhs_sub[mn + 2] = ak * b_sub[mn + 2];
                                        lhs_sub[mn + 3] = ak * b_sub[mn + 3];
                                    }

                                    // Scalar loop, unrolled twice
                                    for (; mn + 1 < MN; mn += 2) {
                                        lhs_sub[mn + 0] = ak * b_sub[mn + 0];
                                        lhs_sub[mn + 1] = ak * b_sub[mn + 1];
                                    }

                                    // Scalar remainder
                                    for (; mn < MN; ++mn) {
                                        lhs_sub[mn] = ak * b_sub[mn];
                                    }
                                }
                            }
                        } else {
                            // Non-vectorized fallback
                            for (size_t batch = first; batch < last; ++batch) {
                                for (size_t k = 0; k < K; ++k) {
                                    for (size_t m = 0; m < M; ++m) {
                                        for (size_t n = 0; n < N; ++n) {
                                            lhs(batch, k, m, n) = a(k) * b(batch, k, m, n);
                                        }
                                    }
                                }
                            }
                        }
                    }
                };

                engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);

                lhs.validate_cpu();
                lhs.invalidate_gpu();
            }
        } else {
            // 2D case: b and lhs are [Batch, K], a is [K]
            const auto Batch = etl::dim<0>(lhs);
            const auto K = etl::dim<1>(lhs);

            if constexpr (impl::egblas::has_sbatch_k_scale2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                // GPU path: offload the whole computation to the egblas kernel
                decltype(auto) t1 = smart_forward_gpu(a);
                decltype(auto) t2 = smart_forward_gpu(b);

                t1.ensure_gpu_up_to_date();
                t2.ensure_gpu_up_to_date();

                lhs.ensure_gpu_allocated();

                impl::egblas::batch_k_scale(Batch, K, t2.gpu_memory(), t1.gpu_memory(), lhs.gpu_memory());

                lhs.validate_gpu();
                lhs.invalidate_cpu();
            } else {
                // CPU path
                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();

                // Functor processing the batch range [first, last)
                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    CPU_SECTION {
                        if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                            using vec_type = default_vec;
                            using T = value_t<L>;

                            static constexpr size_t vec_size = vec_type::template traits<T>::size;

                            for (size_t batch = first; batch < last; ++batch) {
                                size_t k = 0;

                                // Linear offset of row `batch` in row-major storage
                                size_t base = batch * K;

                                // Vectorized loop, unrolled four times
                                for (; k + 4 * vec_size - 1 < K; k += 4 * vec_size) {
                                    auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                    auto a2 = a.template load<vec_type>(k + 1 * vec_size);
                                    auto a3 = a.template load<vec_type>(k + 2 * vec_size);
                                    auto a4 = a.template load<vec_type>(k + 3 * vec_size);

                                    auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                    auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
                                    auto b3 = b.template loadu<vec_type>(base + k + 2 * vec_size);
                                    auto b4 = b.template loadu<vec_type>(base + k + 3 * vec_size);

                                    auto r1 = vec_type::mul(a1, b1);
                                    auto r2 = vec_type::mul(a2, b2);
                                    auto r3 = vec_type::mul(a3, b3);
                                    auto r4 = vec_type::mul(a4, b4);

                                    lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                    lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                                    lhs.template storeu<vec_type>(r3, base + k + 2 * vec_size);
                                    lhs.template storeu<vec_type>(r4, base + k + 3 * vec_size);
                                }

                                // Vectorized loop, unrolled twice
                                for (; k + 2 * vec_size - 1 < K; k += 2 * vec_size) {
                                    auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                    auto a2 = a.template load<vec_type>(k + 1 * vec_size);

                                    auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                    auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);

                                    auto r1 = vec_type::mul(a1, b1);
                                    auto r2 = vec_type::mul(a2, b2);

                                    lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                    lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                                }

                                // Plain vectorized loop
                                for (; k + vec_size - 1 < K; k += vec_size) {
                                    auto a1 = a.template load<vec_type>(k);

                                    auto b1 = b.template loadu<vec_type>(base + k);

                                    auto r1 = vec_type::mul(a1, b1);

                                    lhs.template storeu<vec_type>(r1, base + k);
                                }

                                // Scalar loop, unrolled four times
                                for (; k + 3 < K; k += 4) {
                                    lhs(batch, k + 0) = a(k + 0) * b(batch, k + 0);
                                    lhs(batch, k + 1) = a(k + 1) * b(batch, k + 1);
                                    lhs(batch, k + 2) = a(k + 2) * b(batch, k + 2);
                                    lhs(batch, k + 3) = a(k + 3) * b(batch, k + 3);
                                }

                                // Scalar loop, unrolled twice
                                for (; k + 1 < K; k += 2) {
                                    lhs(batch, k + 0) = a(k + 0) * b(batch, k + 0);
                                    lhs(batch, k + 1) = a(k + 1) * b(batch, k + 1);
                                }

                                // Scalar remainder
                                if (k < K) {
                                    lhs(batch, k) = a(k) * b(batch, k);
                                }
                            }
                        } else {
                            // Non-vectorized fallback
                            for (size_t batch = first; batch < last; ++batch) {
                                for (size_t k = 0; k < K; ++k) {
                                    lhs(batch, k) = a(k) * b(batch, k);
                                }
                            }
                        }
                    }
                };

                engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);

                lhs.validate_cpu();
                lhs.invalidate_gpu();
            }
        }
    }
334 
    /*!
     * \brief Add the result of the expression to the given left-hand side,
     * computing lhs(batch, k, ...) += a(k) * b(batch, k, ...).
     * \param lhs The expression to which the result is added
     */
    template <etl_expr L>
    void assign_add_to(L&& lhs) const {
        auto& a = this->a();
        auto& b = this->b();

        check(a, b, lhs);

        if constexpr (D4) {
            if constexpr (impl::egblas::has_sbatch_k_scale4 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                // GPU-capable case: delegate to the standard compound evaluator
                std_add_evaluate(*this, lhs);
            } else {
                // 4D CPU path: b and lhs are [Batch, K, M, N], a is [K]
                const auto Batch = etl::dim<0>(lhs);
                const auto K = etl::dim<1>(lhs);
                const auto M = etl::dim<2>(lhs);
                const auto N = etl::dim<3>(lhs);

                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                // lhs is read as well as written for a compound assignment
                lhs.ensure_cpu_up_to_date();

                // Functor processing the batch range [first, last)
                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    CPU_SECTION {
                        if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                            using vec_type = default_vec;
                            using T = value_t<L>;

                            static constexpr size_t vec_size = vec_type::template traits<T>::size;

                            const auto MN = M * N;

                            for (size_t batch = first; batch < last; ++batch) {
                                for (size_t k = 0; k < K; ++k) {
                                    T ak = a(k);

                                    auto lhs_sub = lhs(batch)(k);
                                    auto b_sub = b(batch)(k);

                                    size_t mn = 0;

                                    // Broadcast the scaling factor a(k) into a vector register
                                    auto a1 = vec_type::set(ak);

                                    // Vectorized loop, unrolled four times
                                    for (; mn + 4 * vec_size - 1 < MN; mn += 4 * vec_size) {
                                        auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                        auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                        auto b3 = b_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                        auto b4 = b_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                        auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                        auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                        auto l3 = lhs_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                        auto l4 = lhs_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                        auto r1 = vec_type::add(l1, vec_type::mul(a1, b1));
                                        auto r2 = vec_type::add(l2, vec_type::mul(a1, b2));
                                        auto r3 = vec_type::add(l3, vec_type::mul(a1, b3));
                                        auto r4 = vec_type::add(l4, vec_type::mul(a1, b4));

                                        lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                        lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                        lhs_sub.template storeu<vec_type>(r3, mn + 2 * vec_size);
                                        lhs_sub.template storeu<vec_type>(r4, mn + 3 * vec_size);
                                    }

                                    // Vectorized loop, unrolled twice
                                    for (; mn + 2 * vec_size - 1 < MN; mn += 2 * vec_size) {
                                        auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                        auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                        auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                        auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                        auto r1 = vec_type::add(l1, vec_type::mul(a1, b1));
                                        auto r2 = vec_type::add(l2, vec_type::mul(a1, b2));

                                        lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                        lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                    }

                                    // Plain vectorized loop
                                    for (; mn + vec_size - 1 < MN; mn += vec_size) {
                                        auto b1 = b_sub.template loadu<vec_type>(mn);

                                        auto l1 = lhs_sub.template loadu<vec_type>(mn);

                                        auto r1 = vec_type::add(l1, vec_type::mul(a1, b1));

                                        lhs_sub.template storeu<vec_type>(r1, mn);
                                    }

                                    // Scalar loop, unrolled four times
                                    for (; mn + 3 < MN; mn += 4) {
                                        lhs_sub[mn + 0] += ak * b_sub[mn + 0];
                                        lhs_sub[mn + 1] += ak * b_sub[mn + 1];
                                        lhs_sub[mn + 2] += ak * b_sub[mn + 2];
                                        lhs_sub[mn + 3] += ak * b_sub[mn + 3];
                                    }

                                    // Scalar loop, unrolled twice
                                    for (; mn + 1 < MN; mn += 2) {
                                        lhs_sub[mn + 0] += ak * b_sub[mn + 0];
                                        lhs_sub[mn + 1] += ak * b_sub[mn + 1];
                                    }

                                    // Scalar remainder
                                    for (; mn < MN; ++mn) {
                                        lhs_sub[mn] += ak * b_sub[mn];
                                    }
                                }
                            }
                        } else {
                            // Non-vectorized fallback
                            for (size_t batch = first; batch < last; ++batch) {
                                for (size_t k = 0; k < K; ++k) {
                                    for (size_t m = 0; m < M; ++m) {
                                        for (size_t n = 0; n < N; ++n) {
                                            lhs(batch, k, m, n) += a(k) * b(batch, k, m, n);
                                        }
                                    }
                                }
                            }
                        }
                    }
                };

                engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);

                lhs.validate_cpu();
                lhs.invalidate_gpu();
            }
        } else {
            if constexpr (impl::egblas::has_sbatch_k_scale2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                // GPU-capable case: delegate to the standard compound evaluator
                std_add_evaluate(*this, lhs);
            } else {
                // 2D CPU path: b and lhs are [Batch, K], a is [K]
                const auto Batch = etl::dim<0>(lhs);
                const auto K = etl::dim<1>(lhs);

                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                // lhs is read as well as written for a compound assignment
                lhs.ensure_cpu_up_to_date();

                // Functor processing the batch range [first, last)
                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    CPU_SECTION {
                        if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                            using vec_type = default_vec;
                            using T = value_t<L>;

                            static constexpr size_t vec_size = vec_type::template traits<T>::size;

                            for (size_t batch = first; batch < last; ++batch) {
                                size_t k = 0;

                                // Linear offset of row `batch` in row-major storage
                                size_t base = batch * K;

                                // Vectorized loop, unrolled four times
                                for (; k + 4 * vec_size - 1 < K; k += 4 * vec_size) {
                                    auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                    auto a2 = a.template load<vec_type>(k + 1 * vec_size);
                                    auto a3 = a.template load<vec_type>(k + 2 * vec_size);
                                    auto a4 = a.template load<vec_type>(k + 3 * vec_size);

                                    auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                    auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
                                    auto b3 = b.template loadu<vec_type>(base + k + 2 * vec_size);
                                    auto b4 = b.template loadu<vec_type>(base + k + 3 * vec_size);

                                    auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                    auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);
                                    auto l3 = lhs.template loadu<vec_type>(base + k + 2 * vec_size);
                                    auto l4 = lhs.template loadu<vec_type>(base + k + 3 * vec_size);

                                    auto r1 = vec_type::add(l1, vec_type::mul(a1, b1));
                                    auto r2 = vec_type::add(l2, vec_type::mul(a2, b2));
                                    auto r3 = vec_type::add(l3, vec_type::mul(a3, b3));
                                    auto r4 = vec_type::add(l4, vec_type::mul(a4, b4));

                                    lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                    lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                                    lhs.template storeu<vec_type>(r3, base + k + 2 * vec_size);
                                    lhs.template storeu<vec_type>(r4, base + k + 3 * vec_size);
                                }

                                // Vectorized loop, unrolled twice
                                for (; k + 2 * vec_size - 1 < K; k += 2 * vec_size) {
                                    auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                    auto a2 = a.template load<vec_type>(k + 1 * vec_size);

                                    auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                    auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);

                                    auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                    auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);

                                    auto r1 = vec_type::add(l1, vec_type::mul(a1, b1));
                                    auto r2 = vec_type::add(l2, vec_type::mul(a2, b2));

                                    lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                    lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                                }

                                // Plain vectorized loop
                                for (; k + vec_size - 1 < K; k += vec_size) {
                                    auto a1 = a.template load<vec_type>(k);

                                    auto b1 = b.template loadu<vec_type>(base + k);

                                    auto l1 = lhs.template loadu<vec_type>(base + k);

                                    auto r1 = vec_type::add(l1, vec_type::mul(a1, b1));

                                    lhs.template storeu<vec_type>(r1, base + k);
                                }

                                // Scalar loop, unrolled four times
                                for (; k + 3 < K; k += 4) {
                                    lhs(batch, k + 0) += a(k + 0) * b(batch, k + 0);
                                    lhs(batch, k + 1) += a(k + 1) * b(batch, k + 1);
                                    lhs(batch, k + 2) += a(k + 2) * b(batch, k + 2);
                                    lhs(batch, k + 3) += a(k + 3) * b(batch, k + 3);
                                }

                                // Scalar loop, unrolled twice
                                for (; k + 1 < K; k += 2) {
                                    lhs(batch, k + 0) += a(k + 0) * b(batch, k + 0);
                                    lhs(batch, k + 1) += a(k + 1) * b(batch, k + 1);
                                }

                                // Scalar remainder
                                if (k < K) {
                                    lhs(batch, k) += a(k) * b(batch, k);
                                }
                            }
                        } else {
                            // Non-vectorized fallback
                            for (size_t batch = first; batch < last; ++batch) {
                                for (size_t k = 0; k < K; ++k) {
                                    lhs(batch, k) += a(k) * b(batch, k);
                                }
                            }
                        }
                    }
                };

                engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);

                lhs.validate_cpu();
                lhs.invalidate_gpu();
            }
        }
    }
581 
    /*!
     * \brief Subtract the result of the expression from the given left-hand side,
     * computing lhs(batch, k, ...) -= a(k) * b(batch, k, ...).
     * \param lhs The expression from which the result is subtracted
     */
    template <etl_expr L>
    void assign_sub_to(L&& lhs) const {
        auto& a = this->a();
        auto& b = this->b();

        check(a, b, lhs);

        if constexpr (D4) {
            if constexpr (impl::egblas::has_sbatch_k_scale4 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                // GPU-capable case: delegate to the standard compound evaluator
                std_sub_evaluate(*this, lhs);
            } else {
                // 4D CPU path: b and lhs are [Batch, K, M, N], a is [K]
                const auto Batch = etl::dim<0>(lhs);
                const auto K = etl::dim<1>(lhs);
                const auto M = etl::dim<2>(lhs);
                const auto N = etl::dim<3>(lhs);

                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                // lhs is read as well as written for a compound assignment
                lhs.ensure_cpu_up_to_date();

                // Functor processing the batch range [first, last)
                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    CPU_SECTION {
                        if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                            using vec_type = default_vec;
                            using T = value_t<L>;

                            static constexpr size_t vec_size = vec_type::template traits<T>::size;

                            const auto MN = M * N;

                            for (size_t batch = first; batch < last; ++batch) {
                                for (size_t k = 0; k < K; ++k) {
                                    T ak = a(k);

                                    auto lhs_sub = lhs(batch)(k);
                                    auto b_sub = b(batch)(k);

                                    size_t mn = 0;

                                    // Broadcast the scaling factor a(k) into a vector register
                                    auto a1 = vec_type::set(ak);

                                    // Vectorized loop, unrolled four times
                                    for (; mn + 4 * vec_size - 1 < MN; mn += 4 * vec_size) {
                                        auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                        auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                        auto b3 = b_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                        auto b4 = b_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                        auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                        auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                        auto l3 = lhs_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                        auto l4 = lhs_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                        auto r1 = vec_type::sub(l1, vec_type::mul(a1, b1));
                                        auto r2 = vec_type::sub(l2, vec_type::mul(a1, b2));
                                        auto r3 = vec_type::sub(l3, vec_type::mul(a1, b3));
                                        auto r4 = vec_type::sub(l4, vec_type::mul(a1, b4));

                                        lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                        lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                        lhs_sub.template storeu<vec_type>(r3, mn + 2 * vec_size);
                                        lhs_sub.template storeu<vec_type>(r4, mn + 3 * vec_size);
                                    }

                                    // Vectorized loop, unrolled twice
                                    for (; mn + 2 * vec_size - 1 < MN; mn += 2 * vec_size) {
                                        auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                        auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                        auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                        auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                        auto r1 = vec_type::sub(l1, vec_type::mul(a1, b1));
                                        auto r2 = vec_type::sub(l2, vec_type::mul(a1, b2));

                                        lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                        lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                    }

                                    // Plain vectorized loop
                                    for (; mn + vec_size - 1 < MN; mn += vec_size) {
                                        auto b1 = b_sub.template loadu<vec_type>(mn);

                                        auto l1 = lhs_sub.template loadu<vec_type>(mn);

                                        auto r1 = vec_type::sub(l1, vec_type::mul(a1, b1));

                                        lhs_sub.template storeu<vec_type>(r1, mn);
                                    }

                                    // Scalar loop, unrolled four times
                                    for (; mn + 3 < MN; mn += 4) {
                                        lhs_sub[mn + 0] -= ak * b_sub[mn + 0];
                                        lhs_sub[mn + 1] -= ak * b_sub[mn + 1];
                                        lhs_sub[mn + 2] -= ak * b_sub[mn + 2];
                                        lhs_sub[mn + 3] -= ak * b_sub[mn + 3];
                                    }

                                    // Scalar loop, unrolled twice
                                    for (; mn + 1 < MN; mn += 2) {
                                        lhs_sub[mn + 0] -= ak * b_sub[mn + 0];
                                        lhs_sub[mn + 1] -= ak * b_sub[mn + 1];
                                    }

                                    // Scalar remainder
                                    for (; mn < MN; ++mn) {
                                        lhs_sub[mn] -= ak * b_sub[mn];
                                    }
                                }
                            }
                        } else {
                            // Non-vectorized fallback
                            for (size_t batch = first; batch < last; ++batch) {
                                for (size_t k = 0; k < K; ++k) {
                                    for (size_t m = 0; m < M; ++m) {
                                        for (size_t n = 0; n < N; ++n) {
                                            lhs(batch, k, m, n) -= a(k) * b(batch, k, m, n);
                                        }
                                    }
                                }
                            }
                        }
                    }
                };

                engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);

                lhs.validate_cpu();
                lhs.invalidate_gpu();
            }
        } else {
            if constexpr (impl::egblas::has_sbatch_k_scale2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                // GPU-capable case: delegate to the standard compound evaluator
                std_sub_evaluate(*this, lhs);
            } else {
                // 2D CPU path: b and lhs are [Batch, K], a is [K]
                const auto Batch = etl::dim<0>(lhs);
                const auto K = etl::dim<1>(lhs);

                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                // lhs is read as well as written for a compound assignment
                lhs.ensure_cpu_up_to_date();

                // Functor processing the batch range [first, last)
                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    CPU_SECTION {
                        if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                            using vec_type = default_vec;
                            using T = value_t<L>;

                            static constexpr size_t vec_size = vec_type::template traits<T>::size;

                            for (size_t batch = first; batch < last; ++batch) {
                                size_t k = 0;

                                // Linear offset of row `batch` in row-major storage
                                size_t base = batch * K;

                                // Vectorized loop, unrolled four times
                                for (; k + 4 * vec_size - 1 < K; k += 4 * vec_size) {
                                    auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                    auto a2 = a.template load<vec_type>(k + 1 * vec_size);
                                    auto a3 = a.template load<vec_type>(k + 2 * vec_size);
                                    auto a4 = a.template load<vec_type>(k + 3 * vec_size);

                                    auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                    auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
                                    auto b3 = b.template loadu<vec_type>(base + k + 2 * vec_size);
                                    auto b4 = b.template loadu<vec_type>(base + k + 3 * vec_size);

                                    auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                    auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);
                                    auto l3 = lhs.template loadu<vec_type>(base + k + 2 * vec_size);
                                    auto l4 = lhs.template loadu<vec_type>(base + k + 3 * vec_size);

                                    auto r1 = vec_type::sub(l1, vec_type::mul(a1, b1));
                                    auto r2 = vec_type::sub(l2, vec_type::mul(a2, b2));
                                    auto r3 = vec_type::sub(l3, vec_type::mul(a3, b3));
                                    auto r4 = vec_type::sub(l4, vec_type::mul(a4, b4));

                                    lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                    lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                                    lhs.template storeu<vec_type>(r3, base + k + 2 * vec_size);
                                    lhs.template storeu<vec_type>(r4, base + k + 3 * vec_size);
                                }

                                // Vectorized loop, unrolled twice
                                for (; k + 2 * vec_size - 1 < K; k += 2 * vec_size) {
                                    auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                    auto a2 = a.template load<vec_type>(k + 1 * vec_size);

                                    auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                    auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);

                                    auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                    auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);

                                    auto r1 = vec_type::sub(l1, vec_type::mul(a1, b1));
                                    auto r2 = vec_type::sub(l2, vec_type::mul(a2, b2));

                                    lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                    lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                                }

                                // Plain vectorized loop
                                for (; k + vec_size - 1 < K; k += vec_size) {
                                    auto a1 = a.template load<vec_type>(k);

                                    auto b1 = b.template loadu<vec_type>(base + k);

                                    auto l1 = lhs.template loadu<vec_type>(base + k);

                                    auto r1 = vec_type::sub(l1, vec_type::mul(a1, b1));

                                    lhs.template storeu<vec_type>(r1, base + k);
                                }

                                // Scalar loop, unrolled four times
                                for (; k + 3 < K; k += 4) {
                                    lhs(batch, k + 0) -= a(k + 0) * b(batch, k + 0);
                                    lhs(batch, k + 1) -= a(k + 1) * b(batch, k + 1);
                                    lhs(batch, k + 2) -= a(k + 2) * b(batch, k + 2);
                                    lhs(batch, k + 3) -= a(k + 3) * b(batch, k + 3);
                                }

                                // Scalar loop, unrolled twice
                                for (; k + 1 < K; k += 2) {
                                    lhs(batch, k + 0) -= a(k + 0) * b(batch, k + 0);
                                    lhs(batch, k + 1) -= a(k + 1) * b(batch, k + 1);
                                }

                                // Scalar remainder
                                if (k < K) {
                                    lhs(batch, k) -= a(k) * b(batch, k);
                                }
                            }
                        } else {
                            // Non-vectorized fallback
                            for (size_t batch = first; batch < last; ++batch) {
                                for (size_t k = 0; k < K; ++k) {
                                    lhs(batch, k) -= a(k) * b(batch, k);
                                }
                            }
                        }
                    }
                };

                engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);

                lhs.validate_cpu();
                lhs.invalidate_gpu();
            }
        }
    }
828 
833  template <etl_expr L>
834  void assign_mul_to(L&& lhs) const {
835  auto& a = this->a();
836  auto& b = this->b();
837 
838  check(a, b, lhs);
839 
840  if constexpr (D4) {
841  if constexpr (impl::egblas::has_sbatch_k_scale4 && all_row_major<A, B, L> && all_floating<A, B, L>) {
842  std_mul_evaluate(*this, lhs);
843  } else {
844  const auto Batch = etl::dim<0>(lhs);
845  const auto K = etl::dim<1>(lhs);
846  const auto M = etl::dim<2>(lhs);
847  const auto N = etl::dim<3>(lhs);
848 
849  standard_evaluator::pre_assign_rhs(a);
850  standard_evaluator::pre_assign_rhs(b);
851 
852  a.ensure_cpu_up_to_date();
853  b.ensure_cpu_up_to_date();
854  lhs.ensure_cpu_up_to_date();
855 
856  auto batch_fun_b = [&](const size_t first, const size_t last) {
857  CPU_SECTION {
858  if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
859  using vec_type = default_vec;
860  using T = value_t<L>;
861 
862  static constexpr size_t vec_size = vec_type::template traits<T>::size;
863 
864  const auto MN = M * N;
865 
866  for (size_t batch = first; batch < last; ++batch) {
867  for (size_t k = 0; k < K; ++k) {
868  T ak = a(k);
869 
870  auto lhs_sub = lhs(batch)(k);
871  auto b_sub = b(batch)(k);
872 
873  size_t mn = 0;
874 
875  auto a1 = vec_type::set(ak);
876 
877  for (; mn + 4 * vec_size - 1 < MN; mn += 4 * vec_size) {
878  auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
879  auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
880  auto b3 = b_sub.template loadu<vec_type>(mn + 2 * vec_size);
881  auto b4 = b_sub.template loadu<vec_type>(mn + 3 * vec_size);
882 
883  auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
884  auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);
885  auto l3 = lhs_sub.template loadu<vec_type>(mn + 2 * vec_size);
886  auto l4 = lhs_sub.template loadu<vec_type>(mn + 3 * vec_size);
887 
888  auto r1 = vec_type::mul(l1, vec_type::mul(a1, b1));
889  auto r2 = vec_type::mul(l2, vec_type::mul(a1, b2));
890  auto r3 = vec_type::mul(l3, vec_type::mul(a1, b3));
891  auto r4 = vec_type::mul(l4, vec_type::mul(a1, b4));
892 
893  lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
894  lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
895  lhs_sub.template storeu<vec_type>(r3, mn + 2 * vec_size);
896  lhs_sub.template storeu<vec_type>(r4, mn + 3 * vec_size);
897  }
898 
899  for (; mn + 2 * vec_size - 1 < MN; mn += 2 * vec_size) {
900  auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
901  auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
902 
903  auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
904  auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);
905 
906  auto r1 = vec_type::mul(l1, vec_type::mul(a1, b1));
907  auto r2 = vec_type::mul(l2, vec_type::mul(a1, b2));
908 
909  lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
910  lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
911  }
912 
913  for (; mn + vec_size - 1 < MN; mn += vec_size) {
914  auto b1 = b_sub.template loadu<vec_type>(mn);
915 
916  auto l1 = lhs_sub.template loadu<vec_type>(mn);
917 
918  auto r1 = vec_type::mul(l1, vec_type::mul(a1, b1));
919 
920  lhs_sub.template storeu<vec_type>(r1, mn);
921  }
922 
923  for (; mn + 3 < MN; mn += 4) {
924  lhs_sub[mn + 0] *= ak * b_sub[mn + 0];
925  lhs_sub[mn + 1] *= ak * b_sub[mn + 1];
926  lhs_sub[mn + 2] *= ak * b_sub[mn + 2];
927  lhs_sub[mn + 3] *= ak * b_sub[mn + 3];
928  }
929 
930  for (; mn + 1 < MN; mn += 2) {
931  lhs_sub[mn + 0] *= ak * b_sub[mn + 0];
932  lhs_sub[mn + 1] *= ak * b_sub[mn + 1];
933  }
934 
935  for (; mn < MN; ++mn) {
936  lhs_sub[mn] *= ak * b_sub[mn];
937  }
938  }
939  }
940  } else {
941  for (size_t batch = first; batch < last; ++batch) {
942  for (size_t k = 0; k < K; ++k) {
943  for (size_t m = 0; m < M; ++m) {
944  for (size_t n = 0; n < N; ++n) {
945  lhs(batch, k, m, n) *= a(k) * b(batch, k, m, n);
946  }
947  }
948  }
949  }
950  }
951  }
952  };
953 
954  engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);
955 
956  lhs.validate_cpu();
957  lhs.invalidate_gpu();
958  }
959  } else {
960  if constexpr (impl::egblas::has_sbatch_k_scale2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
961  std_mul_evaluate(*this, lhs);
962  } else {
963  const auto Batch = etl::dim<0>(lhs);
964  const auto K = etl::dim<1>(lhs);
965 
966  standard_evaluator::pre_assign_rhs(a);
967  standard_evaluator::pre_assign_rhs(b);
968 
969  a.ensure_cpu_up_to_date();
970  b.ensure_cpu_up_to_date();
971  lhs.ensure_cpu_up_to_date();
972 
973  auto batch_fun_b = [&](const size_t first, const size_t last) {
974  CPU_SECTION {
975  if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
976  using vec_type = default_vec;
977  using T = value_t<L>;
978 
979  static constexpr size_t vec_size = vec_type::template traits<T>::size;
980 
981  for (size_t batch = first; batch < last; ++batch) {
982  size_t k = 0;
983 
984  size_t base = batch * K;
985 
986  for (; k + 4 * vec_size - 1 < K; k += 4 * vec_size) {
987  auto a1 = a.template load<vec_type>(k + 0 * vec_size);
988  auto a2 = a.template load<vec_type>(k + 1 * vec_size);
989  auto a3 = a.template load<vec_type>(k + 2 * vec_size);
990  auto a4 = a.template load<vec_type>(k + 3 * vec_size);
991 
992  auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
993  auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
994  auto b3 = b.template loadu<vec_type>(base + k + 2 * vec_size);
995  auto b4 = b.template loadu<vec_type>(base + k + 3 * vec_size);
996 
997  auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
998  auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);
999  auto l3 = lhs.template loadu<vec_type>(base + k + 2 * vec_size);
1000  auto l4 = lhs.template loadu<vec_type>(base + k + 3 * vec_size);
1001 
1002  auto r1 = vec_type::mul(l1, vec_type::mul(a1, b1));
1003  auto r2 = vec_type::mul(l2, vec_type::mul(a2, b2));
1004  auto r3 = vec_type::mul(l3, vec_type::mul(a3, b3));
1005  auto r4 = vec_type::mul(l4, vec_type::mul(a4, b4));
1006 
1007  lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
1008  lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
1009  lhs.template storeu<vec_type>(r3, base + k + 2 * vec_size);
1010  lhs.template storeu<vec_type>(r4, base + k + 3 * vec_size);
1011  }
1012 
1013  for (; k + 2 * vec_size - 1 < K; k += 2 * vec_size) {
1014  auto a1 = a.template load<vec_type>(k + 0 * vec_size);
1015  auto a2 = a.template load<vec_type>(k + 1 * vec_size);
1016 
1017  auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
1018  auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
1019 
1020  auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
1021  auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);
1022 
1023  auto r1 = vec_type::mul(l1, vec_type::mul(a1, b1));
1024  auto r2 = vec_type::mul(l2, vec_type::mul(a2, b2));
1025 
1026  lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
1027  lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
1028  }
1029 
1030  for (; k + vec_size - 1 < K; k += vec_size) {
1031  auto a1 = a.template load<vec_type>(k);
1032 
1033  auto b1 = b.template loadu<vec_type>(base + k);
1034 
1035  auto l1 = lhs.template loadu<vec_type>(base + k);
1036 
1037  auto r1 = vec_type::mul(l1, vec_type::mul(a1, b1));
1038 
1039  lhs.template storeu<vec_type>(r1, base + k);
1040  }
1041 
1042  for (; k + 3 < K; k += 4) {
1043  lhs(batch, k + 0) *= a(k + 0) * b(batch, k + 0);
1044  lhs(batch, k + 1) *= a(k + 1) * b(batch, k + 1);
1045  lhs(batch, k + 2) *= a(k + 2) * b(batch, k + 2);
1046  lhs(batch, k + 3) *= a(k + 3) * b(batch, k + 3);
1047  }
1048 
1049  for (; k + 1 < K; k += 2) {
1050  lhs(batch, k + 0) *= a(k + 0) * b(batch, k + 0);
1051  lhs(batch, k + 1) *= a(k + 1) * b(batch, k + 1);
1052  }
1053 
1054  if (k < K) {
1055  lhs(batch, k) *= a(k) * b(batch, k);
1056  }
1057  }
1058  } else {
1059  for (size_t batch = first; batch < last; ++batch) {
1060  for (size_t k = 0; k < K; ++k) {
1061  lhs(batch, k) *= a(k) * b(batch, k);
1062  }
1063  }
1064  }
1065  }
1066  };
1067 
1068  engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);
1069 
1070  lhs.validate_cpu();
1071  lhs.invalidate_gpu();
1072  }
1073  }
1074  }
1075 
    /*!
     * \brief Divide the given left-hand-side expression by this expression.
     *
     * Computes, element-wise, lhs /= a(k) * b(...), where a(k) is the scale
     * factor of channel k. Dispatches batches over the thread engine.
     *
     * \param lhs The left-hand-side expression (same dimensions as b)
     */
    template <etl_expr L>
    void assign_div_to(L&& lhs) const {
        auto& a = this->a();
        auto& b = this->b();

        // Validate (statically when sizes are known) the operand dimensions
        check(a, b, lhs);

        if constexpr (D4) {
            // 4D case: lhs(batch, k, m, n) /= a(k) * b(batch, k, m, n)
            if constexpr (impl::egblas::has_sbatch_k_scale4 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                // NOTE(review): only the single-precision (s) kernel flag is tested although
                // all_floating also admits double - presumably s/d availability always matches; verify
                std_div_evaluate(*this, lhs);
            } else {
                const auto Batch = etl::dim<0>(lhs);
                const auto K = etl::dim<1>(lhs);
                const auto M = etl::dim<2>(lhs);
                const auto N = etl::dim<3>(lhs);

                // Force evaluation of the sub expressions into concrete storage
                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                // The computation is done on CPU: make sure all operands are up to date there
                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                lhs.ensure_cpu_up_to_date();

                // Functor processing the batches in the range [first, last)
                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    CPU_SECTION {
                        if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                            using vec_type = default_vec;
                            using T = value_t<L>;

                            static constexpr size_t vec_size = vec_type::template traits<T>::size;

                            const auto MN = M * N;

                            for (size_t batch = first; batch < last; ++batch) {
                                for (size_t k = 0; k < K; ++k) {
                                    T ak = a(k);

                                    // Views on the contiguous (M, N) plane of channel k
                                    auto lhs_sub = lhs(batch)(k);
                                    auto b_sub = b(batch)(k);

                                    size_t mn = 0;

                                    // Broadcast a(k) once for the whole plane
                                    auto a1 = vec_type::set(ak);

                                    // Main vectorized loop, unrolled by a factor of 4
                                    for (; mn + 4 * vec_size - 1 < MN; mn += 4 * vec_size) {
                                        auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                        auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                        auto b3 = b_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                        auto b4 = b_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                        auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                        auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                        auto l3 = lhs_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                        auto l4 = lhs_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                        auto r1 = vec_type::div(l1, vec_type::mul(a1, b1));
                                        auto r2 = vec_type::div(l2, vec_type::mul(a1, b2));
                                        auto r3 = vec_type::div(l3, vec_type::mul(a1, b3));
                                        auto r4 = vec_type::div(l4, vec_type::mul(a1, b4));

                                        lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                        lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                        lhs_sub.template storeu<vec_type>(r3, mn + 2 * vec_size);
                                        lhs_sub.template storeu<vec_type>(r4, mn + 3 * vec_size);
                                    }

                                    // Vectorized loop, unrolled by a factor of 2
                                    for (; mn + 2 * vec_size - 1 < MN; mn += 2 * vec_size) {
                                        auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                        auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                        auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                        auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                        auto r1 = vec_type::div(l1, vec_type::mul(a1, b1));
                                        auto r2 = vec_type::div(l2, vec_type::mul(a1, b2));

                                        lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                        lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                    }

                                    // Plain vectorized loop
                                    for (; mn + vec_size - 1 < MN; mn += vec_size) {
                                        auto b1 = b_sub.template loadu<vec_type>(mn);

                                        auto l1 = lhs_sub.template loadu<vec_type>(mn);

                                        auto r1 = vec_type::div(l1, vec_type::mul(a1, b1));

                                        lhs_sub.template storeu<vec_type>(r1, mn);
                                    }

                                    // Scalar remainder, unrolled by 4
                                    for (; mn + 3 < MN; mn += 4) {
                                        lhs_sub[mn + 0] /= ak * b_sub[mn + 0];
                                        lhs_sub[mn + 1] /= ak * b_sub[mn + 1];
                                        lhs_sub[mn + 2] /= ak * b_sub[mn + 2];
                                        lhs_sub[mn + 3] /= ak * b_sub[mn + 3];
                                    }

                                    // Scalar remainder, unrolled by 2
                                    for (; mn + 1 < MN; mn += 2) {
                                        lhs_sub[mn + 0] /= ak * b_sub[mn + 0];
                                        lhs_sub[mn + 1] /= ak * b_sub[mn + 1];
                                    }

                                    // Final scalar elements
                                    for (; mn < MN; ++mn) {
                                        lhs_sub[mn] /= ak * b_sub[mn];
                                    }
                                }
                            }
                        } else {
                            // Non-vectorized fallback
                            for (size_t batch = first; batch < last; ++batch) {
                                for (size_t k = 0; k < K; ++k) {
                                    for (size_t m = 0; m < M; ++m) {
                                        for (size_t n = 0; n < N; ++n) {
                                            lhs(batch, k, m, n) /= a(k) * b(batch, k, m, n);
                                        }
                                    }
                                }
                            }
                        }
                    }
                };

                engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);

                // The result was computed on CPU; mark GPU copy as stale
                lhs.validate_cpu();
                lhs.invalidate_gpu();
            }
        } else {
            // 2D case: lhs(batch, k) /= a(k) * b(batch, k)
            if constexpr (impl::egblas::has_sbatch_k_scale2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                std_div_evaluate(*this, lhs);
            } else {
                const auto Batch = etl::dim<0>(lhs);
                const auto K = etl::dim<1>(lhs);

                // Force evaluation of the sub expressions into concrete storage
                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                // The computation is done on CPU: make sure all operands are up to date there
                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                lhs.ensure_cpu_up_to_date();

                // Functor processing the batches in the range [first, last)
                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    CPU_SECTION {
                        if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                            using vec_type = default_vec;
                            using T = value_t<L>;

                            static constexpr size_t vec_size = vec_type::template traits<T>::size;

                            for (size_t batch = first; batch < last; ++batch) {
                                size_t k = 0;

                                // Flat offset of row `batch` (row-major storage)
                                size_t base = batch * K;

                                // Main vectorized loop, unrolled by a factor of 4
                                for (; k + 4 * vec_size - 1 < K; k += 4 * vec_size) {
                                    auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                    auto a2 = a.template load<vec_type>(k + 1 * vec_size);
                                    auto a3 = a.template load<vec_type>(k + 2 * vec_size);
                                    auto a4 = a.template load<vec_type>(k + 3 * vec_size);

                                    auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                    auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
                                    auto b3 = b.template loadu<vec_type>(base + k + 2 * vec_size);
                                    auto b4 = b.template loadu<vec_type>(base + k + 3 * vec_size);

                                    auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                    auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);
                                    auto l3 = lhs.template loadu<vec_type>(base + k + 2 * vec_size);
                                    auto l4 = lhs.template loadu<vec_type>(base + k + 3 * vec_size);

                                    auto r1 = vec_type::div(l1, vec_type::mul(a1, b1));
                                    auto r2 = vec_type::div(l2, vec_type::mul(a2, b2));
                                    auto r3 = vec_type::div(l3, vec_type::mul(a3, b3));
                                    auto r4 = vec_type::div(l4, vec_type::mul(a4, b4));

                                    lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                    lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                                    lhs.template storeu<vec_type>(r3, base + k + 2 * vec_size);
                                    lhs.template storeu<vec_type>(r4, base + k + 3 * vec_size);
                                }

                                // Vectorized loop, unrolled by a factor of 2
                                for (; k + 2 * vec_size - 1 < K; k += 2 * vec_size) {
                                    auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                    auto a2 = a.template load<vec_type>(k + 1 * vec_size);

                                    auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                    auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);

                                    auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                    auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);

                                    auto r1 = vec_type::div(l1, vec_type::mul(a1, b1));
                                    auto r2 = vec_type::div(l2, vec_type::mul(a2, b2));

                                    lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                    lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                                }

                                // Plain vectorized loop
                                for (; k + vec_size - 1 < K; k += vec_size) {
                                    auto a1 = a.template load<vec_type>(k);

                                    auto b1 = b.template loadu<vec_type>(base + k);

                                    auto l1 = lhs.template loadu<vec_type>(base + k);

                                    auto r1 = vec_type::div(l1, vec_type::mul(a1, b1));

                                    lhs.template storeu<vec_type>(r1, base + k);
                                }

                                // Scalar remainder, unrolled by 4
                                for (; k + 3 < K; k += 4) {
                                    lhs(batch, k + 0) /= a(k + 0) * b(batch, k + 0);
                                    lhs(batch, k + 1) /= a(k + 1) * b(batch, k + 1);
                                    lhs(batch, k + 2) /= a(k + 2) * b(batch, k + 2);
                                    lhs(batch, k + 3) /= a(k + 3) * b(batch, k + 3);
                                }

                                // Scalar remainder, unrolled by 2
                                for (; k + 1 < K; k += 2) {
                                    lhs(batch, k + 0) /= a(k + 0) * b(batch, k + 0);
                                    lhs(batch, k + 1) /= a(k + 1) * b(batch, k + 1);
                                }

                                // At most one element remains after the unroll-by-2 loop
                                if (k < K) {
                                    lhs(batch, k) /= a(k) * b(batch, k);
                                }
                            }
                        } else {
                            // Non-vectorized fallback
                            for (size_t batch = first; batch < last; ++batch) {
                                for (size_t k = 0; k < K; ++k) {
                                    lhs(batch, k) /= a(k) * b(batch, k);
                                }
                            }
                        }
                    }
                };

                engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);

                // The result was computed on CPU; mark GPU copy as stale
                lhs.validate_cpu();
                lhs.invalidate_gpu();
            }
        }
    }
1322 
1327  template <etl_expr L>
1328  void assign_mod_to(L&& lhs) const {
1329  auto& a = this->a();
1330  auto& b = this->b();
1331 
1332  check(a, b, lhs);
1333 
1334  standard_evaluator::pre_assign_rhs(a);
1335  standard_evaluator::pre_assign_rhs(b);
1336 
1337  a.ensure_cpu_up_to_date();
1338  b.ensure_cpu_up_to_date();
1339  lhs.ensure_cpu_up_to_date();
1340 
1341  if constexpr (D4) {
1342  const auto Batch = etl::dim<0>(lhs);
1343  const auto K = etl::dim<1>(lhs);
1344  const auto M = etl::dim<2>(lhs);
1345  const auto N = etl::dim<3>(lhs);
1346 
1347  auto batch_fun_b = [&](const size_t first, const size_t last) {
1348  CPU_SECTION {
1349  for (size_t batch = first; batch < last; ++batch) {
1350  for (size_t k = 0; k < K; ++k) {
1351  for (size_t m = 0; m < M; ++m) {
1352  for (size_t n = 0; n < N; ++n) {
1353  lhs(batch, k, m, n) %= a(k) * b(batch, k, m, n);
1354  }
1355  }
1356  }
1357  }
1358  }
1359  };
1360 
1361  engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);
1362 
1363  lhs.validate_cpu();
1364  lhs.invalidate_gpu();
1365  } else {
1366  const auto Batch = etl::dim<0>(lhs);
1367  const auto K = etl::dim<1>(lhs);
1368 
1369  auto batch_fun_b = [&](const size_t first, const size_t last) {
1370  CPU_SECTION {
1371  for (size_t batch = first; batch < last; ++batch) {
1372  for (size_t k = 0; k < K; ++k) {
1373  lhs(batch, k) %= a(k) * b(batch, k);
1374  }
1375  }
1376  }
1377  };
1378 
1379  engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);
1380 
1381  lhs.validate_cpu();
1382  lhs.invalidate_gpu();
1383  }
1384  }
1385 
1392  friend std::ostream& operator<<(std::ostream& os, const batch_k_scale_expr& expr) {
1393  return os << "batch_k_scale(" << expr._a << "," << expr._b << ")";
1394  }
1395 };
1396 
1401 template <typename A, typename B>
1404  using sub_expr_t = std::decay_t<B>;
1407 
1408  static constexpr bool is_etl = true;
1409  static constexpr bool is_transformer = false;
1410  static constexpr bool is_view = false;
1411  static constexpr bool is_magic_view = false;
1412  static constexpr bool is_fast = sub_traits::is_fast;
1413  static constexpr bool is_linear = false;
1414  static constexpr bool is_thread_safe = true;
1415  static constexpr bool is_value = false;
1416  static constexpr bool is_direct = true;
1417  static constexpr bool is_generator = false;
1418  static constexpr bool is_padded = false;
1419  static constexpr bool is_aligned = true;
1420  static constexpr bool is_temporary = true;
1421  static constexpr bool gpu_computable = true;
1422  static constexpr order storage_order = sub_traits::storage_order;
1423 
1429  template <vector_mode_t V>
1430  static constexpr bool vectorizable = true;
1431 
1436  template <size_t DD>
1437  static constexpr size_t dim() {
1438  return decay_traits<B>::template dim<DD>();
1439  }
1440 
1447  static size_t dim(const expr_t& e, size_t d) {
1448  return etl::dim(e._b, d);
1449  }
1450 
1456  static size_t size(const expr_t& e) {
1457  return etl::size(e._b);
1458  }
1459 
1464  static constexpr size_t size() {
1465  return decay_traits<B>::size();
1466  }
1467 
1472  static constexpr size_t dimensions() {
1473  return decay_traits<B>::dimensions();
1474  }
1475 
1480  static constexpr int complexity() noexcept {
1481  return -1;
1482  }
1483 };
1484 
1485 // Note: This function should not be called directly
1486 // instead, batch_hint(a >> b) should be used
1487 // But this function is used as helpers from batch_hint
1488 
1494 template <etl_1d A, etl_2d_or_4d B>
1496  return {a, b};
1497 }
1498 
1499 } //end of namespace etl
static constexpr size_t size()
Returns the size of the expression.
Definition: batch_k_scale_expr.hpp:1464
void assign_mul_to(L &&lhs) const
Multiply the given left-hand-side expression.
Definition: batch_k_scale_expr.hpp:834
static constexpr auto storage_order
The sub storage order.
Definition: batch_k_scale_expr.hpp:29
value_t< A > value_type
The value type of the expression.
Definition: batch_k_scale_expr.hpp:1406
friend std::ostream & operator<<(std::ostream &os, const batch_k_scale_expr &expr)
Print a representation of the expression on the given stream.
Definition: batch_k_scale_expr.hpp:1392
batch_k_scale_expr< detail::build_type< A >, detail::build_type< B > > batch_k_scale(const A &a, const B &b)
Builds a batch_k_scale expression scaling each channel k of b by a(k).
Definition: batch_k_scale_expr.hpp:1495
value_t< A > value_type
The type of value of the expression.
Definition: batch_k_scale_expr.hpp:22
A batch_k_scale expression (per-channel scaling of a batched expression).
Definition: batch_k_scale_expr.hpp:21
B _b
The sub expression reference.
Definition: base_temporary_expr.hpp:534
void engine_dispatch_1d_serial(Functor &&functor, size_t first, size_t last, size_t threshold, [[maybe_unused]] size_t n_threads=etl::threads)
Dispatch the elements of a range to a functor in a parallel manner, using the global thread engine...
Definition: parallel_support.hpp:734
constexpr bool is_magic_view
Traits indicating if the given ETL type is a magic view expression.
Definition: traits.hpp:311
A _a
The sub expression reference.
Definition: base_temporary_expr.hpp:533
static constexpr int complexity() noexcept
Estimate the complexity of computation.
Definition: batch_k_scale_expr.hpp:1480
static constexpr size_t dim()
Returns the DDth dimension of the expression.
Definition: batch_k_scale_expr.hpp:1437
constexpr bool vec_enabled
Indicates if vectorization is available in any format.
Definition: config.hpp:220
order
Storage order of a matrix.
Definition: order.hpp:15
Abstract base class for temporary binary expression.
Definition: base_temporary_expr.hpp:529
std::add_lvalue_reference_t< B > b()
Returns the sub expression.
Definition: base_temporary_expr.hpp:593
static constexpr bool gpu_computable
Indicates if the temporary expression can be directly evaluated using only GPU.
Definition: batch_k_scale_expr.hpp:35
constexpr bool is_fast
Traits to test if the given ETL expresion type is fast (sizes known at compile-time) ...
Definition: traits.hpp:588
typename VV::template vec_type< value_type > vec_type
The vectorization type for VV.
Definition: base_temporary_expr.hpp:107
Traits to get information about ETL types.
Definition: tmp.hpp:68
Root namespace for the ETL library.
Definition: adapter.hpp:15
void assign_div_to(L &&lhs) const
Divide the given left-hand-side expression.
Definition: batch_k_scale_expr.hpp:1081
static constexpr size_t dimensions()
Return the number of dimensions of the expression.
Definition: traits_base.hpp:31
EGBLAS wrappers for the batch_k_scale operation.
no_vec default_vec
The default vectorization scheme.
Definition: vectorization.hpp:242
auto dim(E &&value, size_t i) -> detail::identity_helper< E, dim_view< detail::build_identity_type< E >, D >>
Return a view representing the ith Dth dimension.
Definition: view_expression_builder.hpp:25
std::conditional_t< is_etl_value< T >, const std::decay_t< T > &, std::decay_t< T > > build_type
Helper to build the type for a sub expression.
Definition: expression_helpers.hpp:24
batch_k_scale_expr(A a, B b)
Construct a new expression.
Definition: batch_k_scale_expr.hpp:42
std::decay_t< B > sub_expr_t
The sub expression type.
Definition: batch_k_scale_expr.hpp:1404
void std_mul_evaluate(Expr &&expr, Result &&result)
Compound multiply evaluation of the expr into result.
Definition: evaluator.hpp:1233
constexpr bool is_transformer
Traits indicating if the given ETL type is a transformer expression.
Definition: traits.hpp:297
decltype(auto) smart_forward_gpu(E &expr)
Smart forwarding for a temporary expression that will be computed in GPU.
Definition: helpers.hpp:343
constexpr size_t size(const E &expr) noexcept
Returns the size of the given ETL expression.
Definition: helpers.hpp:108
constexpr bool is_view
Traits indicating if the given ETL type is a view expression.
Definition: traits.hpp:304
void assign_sub_to(L &&lhs) const
Sub from the given left-hand-side expression.
Definition: batch_k_scale_expr.hpp:587
static size_t size(const expr_t &e)
Returns the size of the expression.
Definition: batch_k_scale_expr.hpp:1456
void assign_add_to(L &&lhs) const
Add to the given left-hand-side expression.
Definition: batch_k_scale_expr.hpp:340
static constexpr bool D4
If the expression is 4D (instead of 2D)
Definition: batch_k_scale_expr.hpp:27
void std_sub_evaluate(Expr &&expr, Result &&result)
Compound subtract evaluation of the expr into result.
Definition: evaluator.hpp:1214
static size_t dim(const expr_t &e, size_t d)
Returns the dth dimension of the expression.
Definition: batch_k_scale_expr.hpp:1447
constexpr bool is_thread_safe
Traits to test if the given ETL expresion type is thread safe.
Definition: traits.hpp:687
static void check([[maybe_unused]] const A &a, [[maybe_unused]] const B &b, [[maybe_unused]] const C &c)
Validate the batch_k_scale dimensions.
Definition: batch_k_scale_expr.hpp:52
typename decay_traits< E >::value_type value_t
Traits to extract the value type out of an ETL type.
Definition: tmp.hpp:81
void std_div_evaluate(Expr &&expr, Result &&result)
Compound divide evaluation of the expr into result.
Definition: evaluator.hpp:1252
void assign_mod_to(L &&lhs) const
Modulo the given left-hand-side expression.
Definition: batch_k_scale_expr.hpp:1328
void inc_counter([[maybe_unused]] const char *name)
Increase the given counter.
Definition: counters.hpp:25
std::add_lvalue_reference_t< A > a()
Returns the sub expression.
Definition: base_temporary_expr.hpp:577
static constexpr size_t dimensions()
Returns the number of dimensions of the expression.
Definition: batch_k_scale_expr.hpp:1472
void std_add_evaluate(Expr &&expr, Result &&result)
Compound add evaluation of the expr into result.
Definition: evaluator.hpp:1195
void assign_to(L &&lhs) const
Assign to a matrix of the same storage order.
Definition: batch_k_scale_expr.hpp:91