#include "etl/expr/base_temporary_expr.hpp"

namespace etl {

/*!
 * \brief A temporary ternary expression computing lhs = a(k) * b + c(k) over each batch.
 */
template <etl_1d A, etl_2d_or_4d B, etl_1d C>
struct batch_k_scale_plus_expr : base_temporary_expr_tern<batch_k_scale_plus_expr<A, B, C>, A, B, C> {
    using value_type = value_t<A>;                                    ///< The type of value of the expression
    using this_type  = batch_k_scale_plus_expr<A, B, C>;              ///< The type of this expression
    using base_type  = base_temporary_expr_tern<this_type, A, B, C>;  ///< The base type
    using sub_traits = decay_traits<B>;                               ///< The traits of the sub type

    static constexpr auto storage_order = sub_traits::storage_order; ///< The sub storage order

    /*!
     * \brief Indicates if the expression is 4D (instead of 2D)
     */
    static constexpr bool D4 = is_4d<B>;

    /*!
     * \brief Indicates if the temporary expression can be directly evaluated using only GPU
     */
    static constexpr bool gpu_computable =
           (!D4 && impl::egblas::has_sbatch_k_scale_plus2 && all_row_major<A, B> && all_floating<A, B>)
        || (D4 && impl::egblas::has_sbatch_k_scale_plus4 && all_row_major<A, B> && all_floating<A, B>);

    /*!
     * \brief Construct a new expression
     * \param a The first sub expression
     * \param b The second sub expression
     * \param c The third sub expression
     */
    batch_k_scale_plus_expr(A a, B b, C c) : base_type(a, b, c) {
        // Nothing else to init
    }
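
    // Note on the availability flags above: following the egblas naming
    // convention, the "s" prefix presumably denotes the single-precision
    // kernel, and the trailing 2/4 the dimensionality of the batched tensor
    // (2D [Batch, K] versus 4D [Batch, K, M, N]). The exact set of available
    // kernels depends on how egblas was built.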
    /*!
     * \brief Validate the dimensions of the expression
     * \param a The first sub expression
     * \param b The second sub expression
     * \param c The third sub expression
     * \param lhs The target expression
     */
    template <same_dimensions<B> L>
    static void check([[maybe_unused]] const A& a, [[maybe_unused]] const B& b, [[maybe_unused]] const C& c, [[maybe_unused]] L& lhs) {
        if constexpr (D4) {
            if constexpr (all_fast<A, B, C, L>) {
                static_assert(etl::dim<0, B>() == etl::dim<0, L>(), "Invalid dimensions for batch_k_scale_plus");
                static_assert(etl::dim<1, B>() == etl::dim<1, L>(), "Invalid dimensions for batch_k_scale_plus");
                static_assert(etl::dim<2, B>() == etl::dim<2, L>(), "Invalid dimensions for batch_k_scale_plus");
                static_assert(etl::dim<3, B>() == etl::dim<3, L>(), "Invalid dimensions for batch_k_scale_plus");

                static_assert(etl::dim<0, A>() == etl::dim<1, B>(), "Invalid dimensions for batch_k_scale_plus");
                static_assert(etl::dim<0, A>() == etl::dim<0, C>(), "Invalid dimensions for batch_k_scale_plus");
            } else {
                cpp_assert(etl::dim<0>(b) == etl::dim<0>(lhs), "Invalid dimensions for batch_k_scale_plus");
                cpp_assert(etl::dim<1>(b) == etl::dim<1>(lhs), "Invalid dimensions for batch_k_scale_plus");
                cpp_assert(etl::dim<2>(b) == etl::dim<2>(lhs), "Invalid dimensions for batch_k_scale_plus");
                cpp_assert(etl::dim<3>(b) == etl::dim<3>(lhs), "Invalid dimensions for batch_k_scale_plus");

                cpp_assert(etl::dim<0>(a) == etl::dim<1>(b), "Invalid dimensions for batch_k_scale_plus");
                cpp_assert(etl::dim<0>(a) == etl::dim<0>(c), "Invalid dimensions for batch_k_scale_plus");
            }
        } else {
            if constexpr (all_fast<A, B, C, L>) {
                static_assert(etl::dim<0, B>() == etl::dim<0, L>(), "Invalid dimensions for batch_k_scale_plus");
                static_assert(etl::dim<1, B>() == etl::dim<1, L>(), "Invalid dimensions for batch_k_scale_plus");

                static_assert(etl::dim<0, A>() == etl::dim<1, B>(), "Invalid dimensions for batch_k_scale_plus");
                static_assert(etl::dim<0, A>() == etl::dim<0, C>(), "Invalid dimensions for batch_k_scale_plus");
            } else {
                cpp_assert(etl::dim<0>(b) == etl::dim<0>(lhs), "Invalid dimensions for batch_k_scale_plus");
                cpp_assert(etl::dim<1>(b) == etl::dim<1>(lhs), "Invalid dimensions for batch_k_scale_plus");

                cpp_assert(etl::dim<0>(a) == etl::dim<1>(b), "Invalid dimensions for batch_k_scale_plus");
                cpp_assert(etl::dim<0>(a) == etl::dim<0>(c), "Invalid dimensions for batch_k_scale_plus");
            }
        }
    }
    /*!
     * \brief Assign to a matrix of the same storage order
     * \param lhs The expression to which assign
     */
    template <etl_expr L>
    void assign_to(L&& lhs) const {
        inc_counter("temp:assign");

        auto& a = this->a();
        auto& b = this->b();
        auto& c = this->c();

        check(a, b, c, lhs);

        if constexpr (D4) {
            const auto Batch = etl::dim<0>(lhs);
            const auto K     = etl::dim<1>(lhs);
            const auto M     = etl::dim<2>(lhs);
            const auto N     = etl::dim<3>(lhs);

            if constexpr (impl::egblas::has_sbatch_k_scale_plus4 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                // Note: this assumes t1 = b (batched tensor), t2 = a (scales)
                // and t3 = c (offsets), so that the kernel receives
                // (scales, tensor, offsets, output).
                decltype(auto) t1 = smart_forward_gpu(b);
                decltype(auto) t2 = smart_forward_gpu(a);
                decltype(auto) t3 = smart_forward_gpu(c);

                t1.ensure_gpu_up_to_date();
                t2.ensure_gpu_up_to_date();
                t3.ensure_gpu_up_to_date();

                lhs.ensure_gpu_allocated();

                impl::egblas::batch_k_scale_plus(Batch, K, M, N, t2.gpu_memory(), t1.gpu_memory(), t3.gpu_memory(), lhs.gpu_memory());

                lhs.invalidate_cpu();
            } else {
                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                c.ensure_cpu_up_to_date();

                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                        using vec_type = default_vec;
                        using T        = value_type;

                        static constexpr size_t vec_size = vec_type::template traits<T>::size;

                        const auto MN = M * N;

                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                const auto ak = a(k);
                                const auto ck = c(k);

                                auto lhs_sub = lhs(batch)(k);
                                auto b_sub   = b(batch)(k);

                                size_t mn = 0;

                                auto a1 = vec_type::set(ak);
                                auto c1 = vec_type::set(ck);

                                // 4x unrolled vectorized loop
                                for (; mn + 4 * vec_size - 1 < MN; mn += 4 * vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                    auto b3 = b_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                    auto b4 = b_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                    auto r1 = vec_type::add(vec_type::mul(a1, b1), c1);
                                    auto r2 = vec_type::add(vec_type::mul(a1, b2), c1);
                                    auto r3 = vec_type::add(vec_type::mul(a1, b3), c1);
                                    auto r4 = vec_type::add(vec_type::mul(a1, b4), c1);

                                    lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r3, mn + 2 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r4, mn + 3 * vec_size);
                                }

                                // 2x unrolled vectorized loop
                                for (; mn + 2 * vec_size - 1 < MN; mn += 2 * vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                    auto r1 = vec_type::add(vec_type::mul(a1, b1), c1);
                                    auto r2 = vec_type::add(vec_type::mul(a1, b2), c1);

                                    lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                }

                                // Vectorized remainder loop
                                for (; mn + vec_size - 1 < MN; mn += vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn);

                                    auto r1 = vec_type::add(vec_type::mul(a1, b1), c1);

                                    lhs_sub.template storeu<vec_type>(r1, mn);
                                }

                                // Scalar remainder loops
                                for (; mn + 3 < MN; mn += 4) {
                                    lhs_sub[mn + 0] = ak * b_sub[mn + 0] + ck;
                                    lhs_sub[mn + 1] = ak * b_sub[mn + 1] + ck;
                                    lhs_sub[mn + 2] = ak * b_sub[mn + 2] + ck;
                                    lhs_sub[mn + 3] = ak * b_sub[mn + 3] + ck;
                                }

                                for (; mn + 1 < MN; mn += 2) {
                                    lhs_sub[mn + 0] = ak * b_sub[mn + 0] + ck;
                                    lhs_sub[mn + 1] = ak * b_sub[mn + 1] + ck;
                                }

                                for (; mn < MN; ++mn) {
                                    lhs_sub[mn] = ak * b_sub[mn] + ck;
                                }
                            }
                        }
                    } else {
                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                for (size_t m = 0; m < M; ++m) {
                                    for (size_t n = 0; n < N; ++n) {
                                        lhs(batch, k, m, n) = a(k) * b(batch, k, m, n) + c(k);
                                    }
                                }
                            }
                        }
                    }
                };

                engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);

                lhs.invalidate_gpu();
            }
        } else {
            const auto Batch = etl::dim<0>(lhs);
            const auto K     = etl::dim<1>(lhs);

            if constexpr (impl::egblas::has_sbatch_k_scale_plus2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                decltype(auto) t1 = smart_forward_gpu(b);
                decltype(auto) t2 = smart_forward_gpu(a);
                decltype(auto) t3 = smart_forward_gpu(c);

                t1.ensure_gpu_up_to_date();
                t2.ensure_gpu_up_to_date();
                t3.ensure_gpu_up_to_date();

                lhs.ensure_gpu_allocated();

                impl::egblas::batch_k_scale_plus(Batch, K, t2.gpu_memory(), t1.gpu_memory(), t3.gpu_memory(), lhs.gpu_memory());

                lhs.invalidate_cpu();
            } else {
                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                c.ensure_cpu_up_to_date();

                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    if constexpr (vec_enabled && all_vectorizable<vector_mode, A, B, L> && all_row_major<A, B, L>) {
                        using vec_type = default_vec;
                        using T        = value_type;

                        static constexpr size_t vec_size = vec_type::template traits<T>::size;

                        for (size_t batch = first; batch < last; ++batch) {
                            size_t k = 0;

                            const size_t base = batch * K;

                            // 4x unrolled vectorized loop
                            for (; k + 4 * vec_size - 1 < K; k += 4 * vec_size) {
                                auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                auto a2 = a.template load<vec_type>(k + 1 * vec_size);
                                auto a3 = a.template load<vec_type>(k + 2 * vec_size);
                                auto a4 = a.template load<vec_type>(k + 3 * vec_size);

                                auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
                                auto b3 = b.template loadu<vec_type>(base + k + 2 * vec_size);
                                auto b4 = b.template loadu<vec_type>(base + k + 3 * vec_size);

                                auto c1 = c.template loadu<vec_type>(k + 0 * vec_size);
                                auto c2 = c.template loadu<vec_type>(k + 1 * vec_size);
                                auto c3 = c.template loadu<vec_type>(k + 2 * vec_size);
                                auto c4 = c.template loadu<vec_type>(k + 3 * vec_size);

                                auto r1 = vec_type::add(vec_type::mul(a1, b1), c1);
                                auto r2 = vec_type::add(vec_type::mul(a2, b2), c2);
                                auto r3 = vec_type::add(vec_type::mul(a3, b3), c3);
                                auto r4 = vec_type::add(vec_type::mul(a4, b4), c4);

                                lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                                lhs.template storeu<vec_type>(r3, base + k + 2 * vec_size);
                                lhs.template storeu<vec_type>(r4, base + k + 3 * vec_size);
                            }

                            // 2x unrolled vectorized loop
                            for (; k + 2 * vec_size - 1 < K; k += 2 * vec_size) {
                                auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                auto a2 = a.template load<vec_type>(k + 1 * vec_size);

                                auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);

                                auto c1 = c.template loadu<vec_type>(k + 0 * vec_size);
                                auto c2 = c.template loadu<vec_type>(k + 1 * vec_size);

                                auto r1 = vec_type::add(vec_type::mul(a1, b1), c1);
                                auto r2 = vec_type::add(vec_type::mul(a2, b2), c2);

                                lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                            }

                            // Vectorized remainder loop
                            for (; k + vec_size - 1 < K; k += vec_size) {
                                auto a1 = a.template load<vec_type>(k);
                                auto b1 = b.template loadu<vec_type>(base + k);
                                auto c1 = c.template loadu<vec_type>(k);

                                auto r1 = vec_type::add(vec_type::mul(a1, b1), c1);

                                lhs.template storeu<vec_type>(r1, base + k);
                            }

                            // Scalar remainder loops
                            for (; k + 3 < K; k += 4) {
                                lhs(batch, k + 0) = a(k + 0) * b(batch, k + 0) + c(k + 0);
                                lhs(batch, k + 1) = a(k + 1) * b(batch, k + 1) + c(k + 1);
                                lhs(batch, k + 2) = a(k + 2) * b(batch, k + 2) + c(k + 2);
                                lhs(batch, k + 3) = a(k + 3) * b(batch, k + 3) + c(k + 3);
                            }

                            for (; k + 1 < K; k += 2) {
                                lhs(batch, k + 0) = a(k + 0) * b(batch, k + 0) + c(k + 0);
                                lhs(batch, k + 1) = a(k + 1) * b(batch, k + 1) + c(k + 1);
                            }

                            for (; k < K; ++k) {
                                lhs(batch, k) = a(k) * b(batch, k) + c(k);
                            }
                        }
                    } else {
                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                lhs(batch, k) = a(k) * b(batch, k) + c(k);
                            }
                        }
                    }
                };

                engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);

                lhs.invalidate_gpu();
            }
        }
    }
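
    // The compound assignments below (add/sub/mul/div) share one pattern: when
    // the corresponding egblas kernel is available, the expression is fully
    // GPU-computable and the work is delegated to the generic std_add_evaluate
    // / std_sub_evaluate / std_mul_evaluate / std_div_evaluate path; otherwise
    // the same 4x/2x/1x unrolled SIMD kernel as assign_to is used, with the
    // compound operator applied to the previously loaded lhs values.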
    /*!
     * \brief Add to the given left-hand-side expression
     * \param lhs The expression to which the expression is added
     */
    template <etl_expr L>
    void assign_add_to(L&& lhs) const {
        auto& a = this->a();
        auto& b = this->b();
        auto& c = this->c();

        check(a, b, c, lhs);

        if constexpr (D4) {
            if constexpr (impl::egblas::has_sbatch_k_scale_plus4 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                std_add_evaluate(*this, lhs);
            } else {
                const auto Batch = etl::dim<0>(lhs);
                const auto K     = etl::dim<1>(lhs);
                const auto M     = etl::dim<2>(lhs);
                const auto N     = etl::dim<3>(lhs);

                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                c.ensure_cpu_up_to_date();
                lhs.ensure_cpu_up_to_date();

                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                        using vec_type = default_vec;
                        using T        = value_type;

                        static constexpr size_t vec_size = vec_type::template traits<T>::size;

                        const auto MN = M * N;

                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                const auto ak = a(k);
                                const auto ck = c(k);

                                auto lhs_sub = lhs(batch)(k);
                                auto b_sub   = b(batch)(k);

                                size_t mn = 0;

                                auto a1 = vec_type::set(ak);
                                auto c1 = vec_type::set(ck);

                                // 4x unrolled vectorized loop
                                for (; mn + 4 * vec_size - 1 < MN; mn += 4 * vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                    auto b3 = b_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                    auto b4 = b_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                    auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                    auto l3 = lhs_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                    auto l4 = lhs_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                    auto r1 = vec_type::add(l1, vec_type::add(vec_type::mul(a1, b1), c1));
                                    auto r2 = vec_type::add(l2, vec_type::add(vec_type::mul(a1, b2), c1));
                                    auto r3 = vec_type::add(l3, vec_type::add(vec_type::mul(a1, b3), c1));
                                    auto r4 = vec_type::add(l4, vec_type::add(vec_type::mul(a1, b4), c1));

                                    lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r3, mn + 2 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r4, mn + 3 * vec_size);
                                }

                                // 2x unrolled vectorized loop
                                for (; mn + 2 * vec_size - 1 < MN; mn += 2 * vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                    auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                    auto r1 = vec_type::add(l1, vec_type::add(vec_type::mul(a1, b1), c1));
                                    auto r2 = vec_type::add(l2, vec_type::add(vec_type::mul(a1, b2), c1));

                                    lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                }

                                // Vectorized remainder loop
                                for (; mn + vec_size - 1 < MN; mn += vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn);
                                    auto l1 = lhs_sub.template loadu<vec_type>(mn);

                                    auto r1 = vec_type::add(l1, vec_type::add(vec_type::mul(a1, b1), c1));

                                    lhs_sub.template storeu<vec_type>(r1, mn);
                                }

                                // Scalar remainder loops
                                for (; mn + 3 < MN; mn += 4) {
                                    lhs_sub[mn + 0] += ak * b_sub[mn + 0] + ck;
                                    lhs_sub[mn + 1] += ak * b_sub[mn + 1] + ck;
                                    lhs_sub[mn + 2] += ak * b_sub[mn + 2] + ck;
                                    lhs_sub[mn + 3] += ak * b_sub[mn + 3] + ck;
                                }

                                for (; mn + 1 < MN; mn += 2) {
                                    lhs_sub[mn + 0] += ak * b_sub[mn + 0] + ck;
                                    lhs_sub[mn + 1] += ak * b_sub[mn + 1] + ck;
                                }

                                for (; mn < MN; ++mn) {
                                    lhs_sub[mn] += ak * b_sub[mn] + ck;
                                }
                            }
                        }
                    } else {
                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                for (size_t m = 0; m < M; ++m) {
                                    for (size_t n = 0; n < N; ++n) {
                                        lhs(batch, k, m, n) += a(k) * b(batch, k, m, n) + c(k);
                                    }
                                }
                            }
                        }
                    }
                };

                engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);

                lhs.invalidate_gpu();
            }
        } else {
            if constexpr (impl::egblas::has_sbatch_k_scale_plus2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                std_add_evaluate(*this, lhs);
            } else {
                const auto Batch = etl::dim<0>(lhs);
                const auto K     = etl::dim<1>(lhs);

                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                c.ensure_cpu_up_to_date();
                lhs.ensure_cpu_up_to_date();

                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    if constexpr (vec_enabled && all_vectorizable<vector_mode, A, B, L> && all_row_major<A, B, L>) {
                        using vec_type = default_vec;
                        using T        = value_type;

                        static constexpr size_t vec_size = vec_type::template traits<T>::size;

                        for (size_t batch = first; batch < last; ++batch) {
                            size_t k = 0;

                            const size_t base = batch * K;

                            // 4x unrolled vectorized loop
                            for (; k + 4 * vec_size - 1 < K; k += 4 * vec_size) {
                                auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                auto a2 = a.template load<vec_type>(k + 1 * vec_size);
                                auto a3 = a.template load<vec_type>(k + 2 * vec_size);
                                auto a4 = a.template load<vec_type>(k + 3 * vec_size);

                                auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
                                auto b3 = b.template loadu<vec_type>(base + k + 2 * vec_size);
                                auto b4 = b.template loadu<vec_type>(base + k + 3 * vec_size);

                                auto c1 = c.template loadu<vec_type>(k + 0 * vec_size);
                                auto c2 = c.template loadu<vec_type>(k + 1 * vec_size);
                                auto c3 = c.template loadu<vec_type>(k + 2 * vec_size);
                                auto c4 = c.template loadu<vec_type>(k + 3 * vec_size);

                                auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);
                                auto l3 = lhs.template loadu<vec_type>(base + k + 2 * vec_size);
                                auto l4 = lhs.template loadu<vec_type>(base + k + 3 * vec_size);

                                auto r1 = vec_type::add(l1, vec_type::add(vec_type::mul(a1, b1), c1));
                                auto r2 = vec_type::add(l2, vec_type::add(vec_type::mul(a2, b2), c2));
                                auto r3 = vec_type::add(l3, vec_type::add(vec_type::mul(a3, b3), c3));
                                auto r4 = vec_type::add(l4, vec_type::add(vec_type::mul(a4, b4), c4));

                                lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                                lhs.template storeu<vec_type>(r3, base + k + 2 * vec_size);
                                lhs.template storeu<vec_type>(r4, base + k + 3 * vec_size);
                            }

                            // 2x unrolled vectorized loop
                            for (; k + 2 * vec_size - 1 < K; k += 2 * vec_size) {
                                auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                auto a2 = a.template load<vec_type>(k + 1 * vec_size);

                                auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);

                                auto c1 = c.template loadu<vec_type>(k + 0 * vec_size);
                                auto c2 = c.template loadu<vec_type>(k + 1 * vec_size);

                                auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);

                                auto r1 = vec_type::add(l1, vec_type::add(vec_type::mul(a1, b1), c1));
                                auto r2 = vec_type::add(l2, vec_type::add(vec_type::mul(a2, b2), c2));

                                lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                            }

                            // Vectorized remainder loop
                            for (; k + vec_size - 1 < K; k += vec_size) {
                                auto a1 = a.template load<vec_type>(k);
                                auto b1 = b.template loadu<vec_type>(base + k);
                                auto c1 = c.template loadu<vec_type>(k);
                                auto l1 = lhs.template loadu<vec_type>(base + k);

                                auto r1 = vec_type::add(l1, vec_type::add(vec_type::mul(a1, b1), c1));

                                lhs.template storeu<vec_type>(r1, base + k);
                            }

                            // Scalar remainder loops
                            for (; k + 3 < K; k += 4) {
                                lhs(batch, k + 0) += a(k + 0) * b(batch, k + 0) + c(k + 0);
                                lhs(batch, k + 1) += a(k + 1) * b(batch, k + 1) + c(k + 1);
                                lhs(batch, k + 2) += a(k + 2) * b(batch, k + 2) + c(k + 2);
                                lhs(batch, k + 3) += a(k + 3) * b(batch, k + 3) + c(k + 3);
                            }

                            for (; k + 1 < K; k += 2) {
                                lhs(batch, k + 0) += a(k + 0) * b(batch, k + 0) + c(k + 0);
                                lhs(batch, k + 1) += a(k + 1) * b(batch, k + 1) + c(k + 1);
                            }

                            for (; k < K; ++k) {
                                lhs(batch, k) += a(k) * b(batch, k) + c(k);
                            }
                        }
                    } else {
                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                lhs(batch, k) += a(k) * b(batch, k) + c(k);
                            }
                        }
                    }
                };

                engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);

                lhs.invalidate_gpu();
            }
        }
    }
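
    // assign_sub_to, assign_mul_to and assign_div_to below mirror
    // assign_add_to exactly, substituting vec_type::sub / vec_type::mul /
    // vec_type::div in the vectorized kernels and the -=, *= and /= scalar
    // operators in the remainder and fallback loops.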
    /*!
     * \brief Sub from the given left-hand-side expression
     * \param lhs The expression from which the expression is subtracted
     */
    template <etl_expr L>
    void assign_sub_to(L&& lhs) const {
        auto& a = this->a();
        auto& b = this->b();
        auto& c = this->c();

        check(a, b, c, lhs);

        if constexpr (D4) {
            if constexpr (impl::egblas::has_sbatch_k_scale_plus4 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                std_sub_evaluate(*this, lhs);
            } else {
                const auto Batch = etl::dim<0>(lhs);
                const auto K     = etl::dim<1>(lhs);
                const auto M     = etl::dim<2>(lhs);
                const auto N     = etl::dim<3>(lhs);

                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                c.ensure_cpu_up_to_date();
                lhs.ensure_cpu_up_to_date();

                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                        using vec_type = default_vec;
                        using T        = value_type;

                        static constexpr size_t vec_size = vec_type::template traits<T>::size;

                        const auto MN = M * N;

                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                const auto ak = a(k);
                                const auto ck = c(k);

                                auto lhs_sub = lhs(batch)(k);
                                auto b_sub   = b(batch)(k);

                                size_t mn = 0;

                                auto a1 = vec_type::set(ak);
                                auto c1 = vec_type::set(ck);

                                // 4x unrolled vectorized loop
                                for (; mn + 4 * vec_size - 1 < MN; mn += 4 * vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                    auto b3 = b_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                    auto b4 = b_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                    auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                    auto l3 = lhs_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                    auto l4 = lhs_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                    auto r1 = vec_type::sub(l1, vec_type::add(vec_type::mul(a1, b1), c1));
                                    auto r2 = vec_type::sub(l2, vec_type::add(vec_type::mul(a1, b2), c1));
                                    auto r3 = vec_type::sub(l3, vec_type::add(vec_type::mul(a1, b3), c1));
                                    auto r4 = vec_type::sub(l4, vec_type::add(vec_type::mul(a1, b4), c1));

                                    lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r3, mn + 2 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r4, mn + 3 * vec_size);
                                }

                                // 2x unrolled vectorized loop
                                for (; mn + 2 * vec_size - 1 < MN; mn += 2 * vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                    auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                    auto r1 = vec_type::sub(l1, vec_type::add(vec_type::mul(a1, b1), c1));
                                    auto r2 = vec_type::sub(l2, vec_type::add(vec_type::mul(a1, b2), c1));

                                    lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                }

                                // Vectorized remainder loop
                                for (; mn + vec_size - 1 < MN; mn += vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn);
                                    auto l1 = lhs_sub.template loadu<vec_type>(mn);

                                    auto r1 = vec_type::sub(l1, vec_type::add(vec_type::mul(a1, b1), c1));

                                    lhs_sub.template storeu<vec_type>(r1, mn);
                                }

                                // Scalar remainder loops
                                for (; mn + 3 < MN; mn += 4) {
                                    lhs_sub[mn + 0] -= ak * b_sub[mn + 0] + ck;
                                    lhs_sub[mn + 1] -= ak * b_sub[mn + 1] + ck;
                                    lhs_sub[mn + 2] -= ak * b_sub[mn + 2] + ck;
                                    lhs_sub[mn + 3] -= ak * b_sub[mn + 3] + ck;
                                }

                                for (; mn + 1 < MN; mn += 2) {
                                    lhs_sub[mn + 0] -= ak * b_sub[mn + 0] + ck;
                                    lhs_sub[mn + 1] -= ak * b_sub[mn + 1] + ck;
                                }

                                for (; mn < MN; ++mn) {
                                    lhs_sub[mn] -= ak * b_sub[mn] + ck;
                                }
                            }
                        }
                    } else {
                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                for (size_t m = 0; m < M; ++m) {
                                    for (size_t n = 0; n < N; ++n) {
                                        lhs(batch, k, m, n) -= a(k) * b(batch, k, m, n) + c(k);
                                    }
                                }
                            }
                        }
                    }
                };

                engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);

                lhs.invalidate_gpu();
            }
        } else {
            if constexpr (impl::egblas::has_sbatch_k_scale_plus2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                std_sub_evaluate(*this, lhs);
            } else {
                const auto Batch = etl::dim<0>(lhs);
                const auto K     = etl::dim<1>(lhs);

                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                c.ensure_cpu_up_to_date();
                lhs.ensure_cpu_up_to_date();

                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    if constexpr (vec_enabled && all_vectorizable<vector_mode, A, B, L> && all_row_major<A, B, L>) {
                        using vec_type = default_vec;
                        using T        = value_type;

                        static constexpr size_t vec_size = vec_type::template traits<T>::size;

                        for (size_t batch = first; batch < last; ++batch) {
                            size_t k = 0;

                            const size_t base = batch * K;

                            // 4x unrolled vectorized loop
                            for (; k + 4 * vec_size - 1 < K; k += 4 * vec_size) {
                                auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                auto a2 = a.template load<vec_type>(k + 1 * vec_size);
                                auto a3 = a.template load<vec_type>(k + 2 * vec_size);
                                auto a4 = a.template load<vec_type>(k + 3 * vec_size);

                                auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
                                auto b3 = b.template loadu<vec_type>(base + k + 2 * vec_size);
                                auto b4 = b.template loadu<vec_type>(base + k + 3 * vec_size);

                                auto c1 = c.template loadu<vec_type>(k + 0 * vec_size);
                                auto c2 = c.template loadu<vec_type>(k + 1 * vec_size);
                                auto c3 = c.template loadu<vec_type>(k + 2 * vec_size);
                                auto c4 = c.template loadu<vec_type>(k + 3 * vec_size);

                                auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);
                                auto l3 = lhs.template loadu<vec_type>(base + k + 2 * vec_size);
                                auto l4 = lhs.template loadu<vec_type>(base + k + 3 * vec_size);

                                auto r1 = vec_type::sub(l1, vec_type::add(vec_type::mul(a1, b1), c1));
                                auto r2 = vec_type::sub(l2, vec_type::add(vec_type::mul(a2, b2), c2));
                                auto r3 = vec_type::sub(l3, vec_type::add(vec_type::mul(a3, b3), c3));
                                auto r4 = vec_type::sub(l4, vec_type::add(vec_type::mul(a4, b4), c4));

                                lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                                lhs.template storeu<vec_type>(r3, base + k + 2 * vec_size);
                                lhs.template storeu<vec_type>(r4, base + k + 3 * vec_size);
                            }

                            // 2x unrolled vectorized loop
                            for (; k + 2 * vec_size - 1 < K; k += 2 * vec_size) {
                                auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                auto a2 = a.template load<vec_type>(k + 1 * vec_size);

                                auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);

                                auto c1 = c.template loadu<vec_type>(k + 0 * vec_size);
                                auto c2 = c.template loadu<vec_type>(k + 1 * vec_size);

                                auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);

                                auto r1 = vec_type::sub(l1, vec_type::add(vec_type::mul(a1, b1), c1));
                                auto r2 = vec_type::sub(l2, vec_type::add(vec_type::mul(a2, b2), c2));

                                lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                            }

                            // Vectorized remainder loop
                            for (; k + vec_size - 1 < K; k += vec_size) {
                                auto a1 = a.template load<vec_type>(k);
                                auto b1 = b.template loadu<vec_type>(base + k);
                                auto c1 = c.template loadu<vec_type>(k);
                                auto l1 = lhs.template loadu<vec_type>(base + k);

                                auto r1 = vec_type::sub(l1, vec_type::add(vec_type::mul(a1, b1), c1));

                                lhs.template storeu<vec_type>(r1, base + k);
                            }

                            // Scalar remainder loops
                            for (; k + 3 < K; k += 4) {
                                lhs(batch, k + 0) -= a(k + 0) * b(batch, k + 0) + c(k + 0);
                                lhs(batch, k + 1) -= a(k + 1) * b(batch, k + 1) + c(k + 1);
                                lhs(batch, k + 2) -= a(k + 2) * b(batch, k + 2) + c(k + 2);
                                lhs(batch, k + 3) -= a(k + 3) * b(batch, k + 3) + c(k + 3);
                            }

                            for (; k + 1 < K; k += 2) {
                                lhs(batch, k + 0) -= a(k + 0) * b(batch, k + 0) + c(k + 0);
                                lhs(batch, k + 1) -= a(k + 1) * b(batch, k + 1) + c(k + 1);
                            }

                            for (; k < K; ++k) {
                                lhs(batch, k) -= a(k) * b(batch, k) + c(k);
                            }
                        }
                    } else {
                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                lhs(batch, k) -= a(k) * b(batch, k) + c(k);
                            }
                        }
                    }
                };

                engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);

                lhs.invalidate_gpu();
            }
        }
    }
    /*!
     * \brief Multiply the given left-hand-side expression
     * \param lhs The expression to which the expression is multiplied
     */
    template <etl_expr L>
    void assign_mul_to(L&& lhs) const {
        auto& a = this->a();
        auto& b = this->b();
        auto& c = this->c();

        check(a, b, c, lhs);

        if constexpr (D4) {
            if constexpr (impl::egblas::has_sbatch_k_scale_plus4 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                std_mul_evaluate(*this, lhs);
            } else {
                const auto Batch = etl::dim<0>(lhs);
                const auto K     = etl::dim<1>(lhs);
                const auto M     = etl::dim<2>(lhs);
                const auto N     = etl::dim<3>(lhs);

                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                c.ensure_cpu_up_to_date();
                lhs.ensure_cpu_up_to_date();

                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                        using vec_type = default_vec;
                        using T        = value_type;

                        static constexpr size_t vec_size = vec_type::template traits<T>::size;

                        const auto MN = M * N;

                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                const auto ak = a(k);
                                const auto ck = c(k);

                                auto lhs_sub = lhs(batch)(k);
                                auto b_sub   = b(batch)(k);

                                size_t mn = 0;

                                auto a1 = vec_type::set(ak);
                                auto c1 = vec_type::set(ck);

                                // 4x unrolled vectorized loop
                                for (; mn + 4 * vec_size - 1 < MN; mn += 4 * vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                    auto b3 = b_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                    auto b4 = b_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                    auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                    auto l3 = lhs_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                    auto l4 = lhs_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                    auto r1 = vec_type::mul(l1, vec_type::add(vec_type::mul(a1, b1), c1));
                                    auto r2 = vec_type::mul(l2, vec_type::add(vec_type::mul(a1, b2), c1));
                                    auto r3 = vec_type::mul(l3, vec_type::add(vec_type::mul(a1, b3), c1));
                                    auto r4 = vec_type::mul(l4, vec_type::add(vec_type::mul(a1, b4), c1));

                                    lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r3, mn + 2 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r4, mn + 3 * vec_size);
                                }

                                // 2x unrolled vectorized loop
                                for (; mn + 2 * vec_size - 1 < MN; mn += 2 * vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                    auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                    auto r1 = vec_type::mul(l1, vec_type::add(vec_type::mul(a1, b1), c1));
                                    auto r2 = vec_type::mul(l2, vec_type::add(vec_type::mul(a1, b2), c1));

                                    lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                }

                                // Vectorized remainder loop
                                for (; mn + vec_size - 1 < MN; mn += vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn);
                                    auto l1 = lhs_sub.template loadu<vec_type>(mn);

                                    auto r1 = vec_type::mul(l1, vec_type::add(vec_type::mul(a1, b1), c1));

                                    lhs_sub.template storeu<vec_type>(r1, mn);
                                }

                                // Scalar remainder loops
                                for (; mn + 3 < MN; mn += 4) {
                                    lhs_sub[mn + 0] *= ak * b_sub[mn + 0] + ck;
                                    lhs_sub[mn + 1] *= ak * b_sub[mn + 1] + ck;
                                    lhs_sub[mn + 2] *= ak * b_sub[mn + 2] + ck;
                                    lhs_sub[mn + 3] *= ak * b_sub[mn + 3] + ck;
                                }

                                for (; mn + 1 < MN; mn += 2) {
                                    lhs_sub[mn + 0] *= ak * b_sub[mn + 0] + ck;
                                    lhs_sub[mn + 1] *= ak * b_sub[mn + 1] + ck;
                                }

                                for (; mn < MN; ++mn) {
                                    lhs_sub[mn] *= ak * b_sub[mn] + ck;
                                }
                            }
                        }
                    } else {
                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                for (size_t m = 0; m < M; ++m) {
                                    for (size_t n = 0; n < N; ++n) {
                                        lhs(batch, k, m, n) *= a(k) * b(batch, k, m, n) + c(k);
                                    }
                                }
                            }
                        }
                    }
                };

                engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);

                lhs.invalidate_gpu();
            }
        } else {
            if constexpr (impl::egblas::has_sbatch_k_scale_plus2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                std_mul_evaluate(*this, lhs);
            } else {
                const auto Batch = etl::dim<0>(lhs);
                const auto K     = etl::dim<1>(lhs);

                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                c.ensure_cpu_up_to_date();
                lhs.ensure_cpu_up_to_date();

                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    if constexpr (vec_enabled && all_vectorizable<vector_mode, A, B, L> && all_row_major<A, B, L>) {
                        using vec_type = default_vec;
                        using T        = value_type;

                        static constexpr size_t vec_size = vec_type::template traits<T>::size;

                        for (size_t batch = first; batch < last; ++batch) {
                            size_t k = 0;

                            const size_t base = batch * K;

                            // 4x unrolled vectorized loop
                            for (; k + 4 * vec_size - 1 < K; k += 4 * vec_size) {
                                auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                auto a2 = a.template load<vec_type>(k + 1 * vec_size);
                                auto a3 = a.template load<vec_type>(k + 2 * vec_size);
                                auto a4 = a.template load<vec_type>(k + 3 * vec_size);

                                auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
                                auto b3 = b.template loadu<vec_type>(base + k + 2 * vec_size);
                                auto b4 = b.template loadu<vec_type>(base + k + 3 * vec_size);

                                auto c1 = c.template loadu<vec_type>(k + 0 * vec_size);
                                auto c2 = c.template loadu<vec_type>(k + 1 * vec_size);
                                auto c3 = c.template loadu<vec_type>(k + 2 * vec_size);
                                auto c4 = c.template loadu<vec_type>(k + 3 * vec_size);

                                auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);
                                auto l3 = lhs.template loadu<vec_type>(base + k + 2 * vec_size);
                                auto l4 = lhs.template loadu<vec_type>(base + k + 3 * vec_size);

                                auto r1 = vec_type::mul(l1, vec_type::add(vec_type::mul(a1, b1), c1));
                                auto r2 = vec_type::mul(l2, vec_type::add(vec_type::mul(a2, b2), c2));
                                auto r3 = vec_type::mul(l3, vec_type::add(vec_type::mul(a3, b3), c3));
                                auto r4 = vec_type::mul(l4, vec_type::add(vec_type::mul(a4, b4), c4));

                                lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                                lhs.template storeu<vec_type>(r3, base + k + 2 * vec_size);
                                lhs.template storeu<vec_type>(r4, base + k + 3 * vec_size);
                            }

                            // 2x unrolled vectorized loop
                            for (; k + 2 * vec_size - 1 < K; k += 2 * vec_size) {
                                auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                auto a2 = a.template load<vec_type>(k + 1 * vec_size);

                                auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);

                                auto c1 = c.template loadu<vec_type>(k + 0 * vec_size);
                                auto c2 = c.template loadu<vec_type>(k + 1 * vec_size);

                                auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);

                                auto r1 = vec_type::mul(l1, vec_type::add(vec_type::mul(a1, b1), c1));
                                auto r2 = vec_type::mul(l2, vec_type::add(vec_type::mul(a2, b2), c2));

                                lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                            }

                            // Vectorized remainder loop
                            for (; k + vec_size - 1 < K; k += vec_size) {
                                auto a1 = a.template load<vec_type>(k);
                                auto b1 = b.template loadu<vec_type>(base + k);
                                auto c1 = c.template loadu<vec_type>(k);
                                auto l1 = lhs.template loadu<vec_type>(base + k);

                                auto r1 = vec_type::mul(l1, vec_type::add(vec_type::mul(a1, b1), c1));

                                lhs.template storeu<vec_type>(r1, base + k);
                            }

                            // Scalar remainder loops
                            for (; k + 3 < K; k += 4) {
                                lhs(batch, k + 0) *= a(k + 0) * b(batch, k + 0) + c(k + 0);
                                lhs(batch, k + 1) *= a(k + 1) * b(batch, k + 1) + c(k + 1);
                                lhs(batch, k + 2) *= a(k + 2) * b(batch, k + 2) + c(k + 2);
                                lhs(batch, k + 3) *= a(k + 3) * b(batch, k + 3) + c(k + 3);
                            }

                            for (; k + 1 < K; k += 2) {
                                lhs(batch, k + 0) *= a(k + 0) * b(batch, k + 0) + c(k + 0);
                                lhs(batch, k + 1) *= a(k + 1) * b(batch, k + 1) + c(k + 1);
                            }

                            for (; k < K; ++k) {
                                lhs(batch, k) *= a(k) * b(batch, k) + c(k);
                            }
                        }
                    } else {
                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                lhs(batch, k) *= a(k) * b(batch, k) + c(k);
                            }
                        }
                    }
                };

                engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);

                lhs.invalidate_gpu();
            }
        }
    }
    /*!
     * \brief Divide the given left-hand-side expression
     * \param lhs The expression by which the expression is divided
     */
    template <etl_expr L>
    void assign_div_to(L&& lhs) const {
        auto& a = this->a();
        auto& b = this->b();
        auto& c = this->c();

        check(a, b, c, lhs);

        if constexpr (D4) {
            if constexpr (impl::egblas::has_sbatch_k_scale_plus4 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                std_div_evaluate(*this, lhs);
            } else {
                const auto Batch = etl::dim<0>(lhs);
                const auto K     = etl::dim<1>(lhs);
                const auto M     = etl::dim<2>(lhs);
                const auto N     = etl::dim<3>(lhs);

                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                c.ensure_cpu_up_to_date();
                lhs.ensure_cpu_up_to_date();

                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                        using vec_type = default_vec;
                        using T        = value_type;

                        static constexpr size_t vec_size = vec_type::template traits<T>::size;

                        const auto MN = M * N;

                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                const auto ak = a(k);
                                const auto ck = c(k);

                                auto lhs_sub = lhs(batch)(k);
                                auto b_sub   = b(batch)(k);

                                size_t mn = 0;

                                auto a1 = vec_type::set(ak);
                                auto c1 = vec_type::set(ck);

                                // 4x unrolled vectorized loop
                                for (; mn + 4 * vec_size - 1 < MN; mn += 4 * vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                    auto b3 = b_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                    auto b4 = b_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                    auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                    auto l3 = lhs_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                    auto l4 = lhs_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                    auto r1 = vec_type::div(l1, vec_type::add(vec_type::mul(a1, b1), c1));
                                    auto r2 = vec_type::div(l2, vec_type::add(vec_type::mul(a1, b2), c1));
                                    auto r3 = vec_type::div(l3, vec_type::add(vec_type::mul(a1, b3), c1));
                                    auto r4 = vec_type::div(l4, vec_type::add(vec_type::mul(a1, b4), c1));

                                    lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r3, mn + 2 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r4, mn + 3 * vec_size);
                                }

                                // 2x unrolled vectorized loop
                                for (; mn + 2 * vec_size - 1 < MN; mn += 2 * vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                    auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                    auto r1 = vec_type::div(l1, vec_type::add(vec_type::mul(a1, b1), c1));
                                    auto r2 = vec_type::div(l2, vec_type::add(vec_type::mul(a1, b2), c1));

                                    lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                }

                                // Vectorized remainder loop
                                for (; mn + vec_size - 1 < MN; mn += vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn);
                                    auto l1 = lhs_sub.template loadu<vec_type>(mn);

                                    auto r1 = vec_type::div(l1, vec_type::add(vec_type::mul(a1, b1), c1));

                                    lhs_sub.template storeu<vec_type>(r1, mn);
                                }

                                // Scalar remainder loops
                                for (; mn + 3 < MN; mn += 4) {
                                    lhs_sub[mn + 0] /= ak * b_sub[mn + 0] + ck;
                                    lhs_sub[mn + 1] /= ak * b_sub[mn + 1] + ck;
                                    lhs_sub[mn + 2] /= ak * b_sub[mn + 2] + ck;
                                    lhs_sub[mn + 3] /= ak * b_sub[mn + 3] + ck;
                                }

                                for (; mn + 1 < MN; mn += 2) {
                                    lhs_sub[mn + 0] /= ak * b_sub[mn + 0] + ck;
                                    lhs_sub[mn + 1] /= ak * b_sub[mn + 1] + ck;
                                }

                                for (; mn < MN; ++mn) {
                                    lhs_sub[mn] /= ak * b_sub[mn] + ck;
                                }
                            }
                        }
                    } else {
                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                for (size_t m = 0; m < M; ++m) {
                                    for (size_t n = 0; n < N; ++n) {
                                        lhs(batch, k, m, n) /= a(k) * b(batch, k, m, n) + c(k);
                                    }
                                }
                            }
                        }
                    }
                };

                engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);

                lhs.invalidate_gpu();
            }
        } else {
            if constexpr (impl::egblas::has_sbatch_k_scale_plus2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                std_div_evaluate(*this, lhs);
            } else {
                const auto Batch = etl::dim<0>(lhs);
                const auto K     = etl::dim<1>(lhs);

                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                c.ensure_cpu_up_to_date();
                lhs.ensure_cpu_up_to_date();

                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    if constexpr (vec_enabled && all_vectorizable<vector_mode, A, B, L> && all_row_major<A, B, L>) {
                        using vec_type = default_vec;
                        using T        = value_type;

                        static constexpr size_t vec_size = vec_type::template traits<T>::size;

                        for (size_t batch = first; batch < last; ++batch) {
                            size_t k = 0;

                            const size_t base = batch * K;

                            // 4x unrolled vectorized loop
                            for (; k + 4 * vec_size - 1 < K; k += 4 * vec_size) {
                                auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                auto a2 = a.template load<vec_type>(k + 1 * vec_size);
                                auto a3 = a.template load<vec_type>(k + 2 * vec_size);
                                auto a4 = a.template load<vec_type>(k + 3 * vec_size);

                                auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
                                auto b3 = b.template loadu<vec_type>(base + k + 2 * vec_size);
                                auto b4 = b.template loadu<vec_type>(base + k + 3 * vec_size);

                                auto c1 = c.template loadu<vec_type>(k + 0 * vec_size);
                                auto c2 = c.template loadu<vec_type>(k + 1 * vec_size);
                                auto c3 = c.template loadu<vec_type>(k + 2 * vec_size);
                                auto c4 = c.template loadu<vec_type>(k + 3 * vec_size);

                                auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);
                                auto l3 = lhs.template loadu<vec_type>(base + k + 2 * vec_size);
                                auto l4 = lhs.template loadu<vec_type>(base + k + 3 * vec_size);

                                auto r1 = vec_type::div(l1, vec_type::add(vec_type::mul(a1, b1), c1));
                                auto r2 = vec_type::div(l2, vec_type::add(vec_type::mul(a2, b2), c2));
                                auto r3 = vec_type::div(l3, vec_type::add(vec_type::mul(a3, b3), c3));
                                auto r4 = vec_type::div(l4, vec_type::add(vec_type::mul(a4, b4), c4));

                                lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                                lhs.template storeu<vec_type>(r3, base + k + 2 * vec_size);
                                lhs.template storeu<vec_type>(r4, base + k + 3 * vec_size);
                            }

                            // 2x unrolled vectorized loop
                            for (; k + 2 * vec_size - 1 < K; k += 2 * vec_size) {
                                auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                auto a2 = a.template load<vec_type>(k + 1 * vec_size);

                                auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);

                                auto c1 = c.template loadu<vec_type>(k + 0 * vec_size);
                                auto c2 = c.template loadu<vec_type>(k + 1 * vec_size);

                                auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);

                                auto r1 = vec_type::div(l1, vec_type::add(vec_type::mul(a1, b1), c1));
                                auto r2 = vec_type::div(l2, vec_type::add(vec_type::mul(a2, b2), c2));

                                lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                            }

                            // Vectorized remainder loop
                            for (; k + vec_size - 1 < K; k += vec_size) {
                                auto a1 = a.template load<vec_type>(k);
                                auto b1 = b.template loadu<vec_type>(base + k);
                                auto c1 = c.template loadu<vec_type>(k);
                                auto l1 = lhs.template loadu<vec_type>(base + k);

                                auto r1 = vec_type::div(l1, vec_type::add(vec_type::mul(a1, b1), c1));

                                lhs.template storeu<vec_type>(r1, base + k);
                            }

                            // Scalar remainder loops
                            for (; k + 3 < K; k += 4) {
                                lhs(batch, k + 0) /= a(k + 0) * b(batch, k + 0) + c(k + 0);
                                lhs(batch, k + 1) /= a(k + 1) * b(batch, k + 1) + c(k + 1);
                                lhs(batch, k + 2) /= a(k + 2) * b(batch, k + 2) + c(k + 2);
                                lhs(batch, k + 3) /= a(k + 3) * b(batch, k + 3) + c(k + 3);
                            }

                            for (; k + 1 < K; k += 2) {
                                lhs(batch, k + 0) /= a(k + 0) * b(batch, k + 0) + c(k + 0);
                                lhs(batch, k + 1) /= a(k + 1) * b(batch, k + 1) + c(k + 1);
                            }

                            for (; k < K; ++k) {
                                lhs(batch, k) /= a(k) * b(batch, k) + c(k);
                            }
                        }
                    } else {
                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                lhs(batch, k) /= a(k) * b(batch, k) + c(k);
                            }
                        }
                    }
                };

                engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);

                lhs.invalidate_gpu();
            }
        }
    }
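
    // Modulo has no egblas kernel and no SIMD primitive in the vector
    // backends, so assign_mod_to below uses plain scalar loops only, still
    // dispatched per batch like the other assignments.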
    /*!
     * \brief Modulo the given left-hand-side expression
     * \param lhs The expression to which the expression is modulo-ed
     */
    template <etl_expr L>
    void assign_mod_to(L&& lhs) const {
        auto& a = this->a();
        auto& b = this->b();
        auto& c = this->c();

        check(a, b, c, lhs);

        standard_evaluator::pre_assign_rhs(a);
        standard_evaluator::pre_assign_rhs(b);

        a.ensure_cpu_up_to_date();
        b.ensure_cpu_up_to_date();
        c.ensure_cpu_up_to_date();
        lhs.ensure_cpu_up_to_date();

        if constexpr (D4) {
            const auto Batch = etl::dim<0>(lhs);
            const auto K     = etl::dim<1>(lhs);
            const auto M     = etl::dim<2>(lhs);
            const auto N     = etl::dim<3>(lhs);

            auto batch_fun_b = [&](const size_t first, const size_t last) {
                for (size_t batch = first; batch < last; ++batch) {
                    for (size_t k = 0; k < K; ++k) {
                        for (size_t m = 0; m < M; ++m) {
                            for (size_t n = 0; n < N; ++n) {
                                lhs(batch, k, m, n) %= a(k) * b(batch, k, m, n) + c(k);
                            }
                        }
                    }
                }
            };

            engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);

            lhs.invalidate_gpu();
        } else {
            const auto Batch = etl::dim<0>(lhs);
            const auto K     = etl::dim<1>(lhs);

            auto batch_fun_b = [&](const size_t first, const size_t last) {
                for (size_t batch = first; batch < last; ++batch) {
                    for (size_t k = 0; k < K; ++k) {
                        lhs(batch, k) %= a(k) * b(batch, k) + c(k);
                    }
                }
            };

            engine_dispatch_1d_serial(batch_fun_b, 0, Batch, 2UL);

            lhs.invalidate_gpu();
        }
    }
    /*!
     * \brief Print a representation of the expression on the given stream
     * \param os The output stream
     * \param expr The expression to print
     * \return the output stream
     */
    friend std::ostream& operator<<(std::ostream& os, const batch_k_scale_plus_expr& expr) {
        return os << "batch_k_scale_plus(" << expr._a << "," << expr._b << "," << expr._c << ")";
    }
};
/*!
 * \brief Traits for a batch_k_scale_plus expression
 * \tparam A The first sub expression type
 * \tparam B The second sub expression type
 * \tparam C The third sub expression type
 */
template <typename A, typename B, typename C>
struct etl_traits<etl::batch_k_scale_plus_expr<A, B, C>> {
    using expr_t     = etl::batch_k_scale_plus_expr<A, B, C>; ///< The expression type
    using sub_expr_t = std::decay_t<B>;                       ///< The sub expression type
    using sub_traits = etl_traits<sub_expr_t>;                ///< The sub traits
    using value_type = value_t<A>;                            ///< The value type of the expression

    static constexpr bool is_etl         = true;                      ///< Indicates if the type is an ETL expression
    static constexpr bool is_transformer = false;                     ///< Indicates if the type is a transformer
    static constexpr bool is_view        = false;                     ///< Indicates if the type is a view
    static constexpr bool is_magic_view  = false;                     ///< Indicates if the type is a magic view
    static constexpr bool is_fast        = sub_traits::is_fast;       ///< Indicates if the expression is fast
    static constexpr bool is_linear      = false;                     ///< Indicates if the expression is linear
    static constexpr bool is_thread_safe = true;                      ///< Indicates if the expression is thread safe
    static constexpr bool is_value       = false;                     ///< Indicates if the expression is of value type
    static constexpr bool is_direct      = true;                      ///< Indicates if the expression has direct memory access
    static constexpr bool is_generator   = false;                     ///< Indicates if the expression is a generator
    static constexpr bool is_padded      = false;                     ///< Indicates if the expression is padded
    static constexpr bool is_aligned     = true;                      ///< Indicates if the expression is aligned
    static constexpr bool is_temporary   = true;                      ///< Indicates if the expression needs a evaluator visitor
    static constexpr bool gpu_computable = true;                      ///< Indicates if the expression can be computed on GPU
    static constexpr order storage_order = sub_traits::storage_order; ///< The expression's storage order

    /*!
     * \brief Indicates if the expression is vectorizable using the given vector mode
     * \tparam V The vector mode
     */
    template <vector_mode_t V>
    static constexpr bool vectorizable = true;

    /*!
     * \brief Returns the DDth dimension of the expression
     * \return the DDth dimension of the expression
     */
    template <size_t DD>
    static constexpr size_t dim() {
        return decay_traits<B>::template dim<DD>();
    }

    /*!
     * \brief Returns the dth dimension of the expression
     * \param e The sub expression
     * \param d The dimension to get
     * \return the dth dimension of the expression
     */
    static size_t dim(const expr_t& e, size_t d) {
        return etl::dim(e._b, d);
    }

    /*!
     * \brief Returns the size of the expression
     * \param e The sub expression
     * \return the size of the expression
     */
    static size_t size(const expr_t& e) {
        return etl::size(e._b);
    }

    /*!
     * \brief Returns the size of the expression
     * \return the size of the expression
     */
    static constexpr size_t size() {
        return decay_traits<B>::size();
    }

    /*!
     * \brief Returns the number of dimensions of the expression
     * \return the number of dimensions of the expression
     */
    static constexpr size_t dimensions() {
        return decay_traits<B>::dimensions();
    }

    /*!
     * \brief Estimate the complexity of computation
     * \return An estimation of the complexity of the expression
     */
    static constexpr int complexity() noexcept {
        return -1;
    }
};
/*!
 * \brief Returns an expression computing a(k) * b + c(k) over each batch
 * \param a The scaling vector expression (1D, length K)
 * \param b The batched expression (2D or 4D)
 * \param c The offset vector expression (1D, length K)
 * \return The batch_k_scale_plus expression
 */
template <etl_1d A, etl_2d_or_4d B, etl_1d C>
batch_k_scale_plus_expr<detail::build_type<A>, detail::build_type<B>, detail::build_type<C>> batch_k_scale_plus(const A& a, const B& b, const C& c) {
    return {a, b, c};
}

} //end of namespace etl
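
// Usage sketch (illustrative only; the tensor names and sizes below are
// hypothetical, not part of this header):
//
//     etl::fast_vector<float, 8>           gamma; // a(k): one scale per channel
//     etl::fast_matrix<float, 16, 8, 4, 4> x;     // b: [Batch, K, M, N]
//     etl::fast_vector<float, 8>           beta;  // c(k): one offset per channel
//     etl::fast_matrix<float, 16, 8, 4, 4> y;
//
//     y = etl::batch_k_scale_plus(gamma, x, beta);
//     // y(batch, k, m, n) == gamma(k) * x(batch, k, m, n) + beta(k)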