#include "etl/expr/base_temporary_expr.hpp"

// Expression computing a per-channel scaling of a batched tensor.
template <etl_1d A, etl_2d_or_4d B>
struct batch_k_scale_expr : base_temporary_expr_bin<batch_k_scale_expr<A, B>, A, B> {
    // True if the expression is 4D (Batch, K, M, N) instead of 2D (Batch, K)
    static constexpr bool D4 = is_4d<B>;

    // Indicates if the expression can be directly evaluated on the GPU
    static constexpr bool gpu_computable =
           (!D4 && impl::egblas::has_sbatch_k_scale2 && all_row_major<A, B> && all_floating<A, B>)
        || ( D4 && impl::egblas::has_dbatch_k_scale4 && all_row_major<A, B> && all_floating<A, B>);
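
    // Semantics of the expression, as implemented by the assign functions below:
    //   4D case: lhs(batch, k, m, n) = a(k) * b(batch, k, m, n)
    //   2D case: lhs(batch, k)       = a(k) * b(batch, k)
    // Direct assignment can be offloaded to the egblas batch_k_scale kernels when they
    // are available; otherwise vectorized (or plain scalar) CPU kernels are used.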
    // Validate the dimensions of the batch_k_scale operation
    template <same_dimensions<B> C>
    static void check([[maybe_unused]] const A& a, [[maybe_unused]] const B& b, [[maybe_unused]] const C& c) {
        if constexpr (D4) {
            if constexpr (all_fast<A, C>) {
                static_assert(etl::dim<0, B>() == etl::dim<0, C>(), "Invalid dimensions for batch_k_scale");
                static_assert(etl::dim<1, B>() == etl::dim<1, C>(), "Invalid dimensions for batch_k_scale");
                static_assert(etl::dim<2, B>() == etl::dim<2, C>(), "Invalid dimensions for batch_k_scale");
                static_assert(etl::dim<3, B>() == etl::dim<3, C>(), "Invalid dimensions for batch_k_scale");

                static_assert(etl::dim<0, A>() == etl::dim<1, B>(), "Invalid dimensions for batch_k_scale");
            } else {
                cpp_assert(etl::dim<0>(b) == etl::dim<0>(c), "Invalid dimensions for batch_k_scale");
                cpp_assert(etl::dim<1>(b) == etl::dim<1>(c), "Invalid dimensions for batch_k_scale");
                cpp_assert(etl::dim<2>(b) == etl::dim<2>(c), "Invalid dimensions for batch_k_scale");
                cpp_assert(etl::dim<3>(b) == etl::dim<3>(c), "Invalid dimensions for batch_k_scale");

                cpp_assert(etl::dim<0>(a) == etl::dim<1>(b), "Invalid dimensions for batch_k_scale");
            }
        } else {
            if constexpr (all_fast<A, C>) {
                static_assert(etl::dim<0, B>() == etl::dim<0, C>(), "Invalid dimensions for batch_k_scale");
                static_assert(etl::dim<1, B>() == etl::dim<1, C>(), "Invalid dimensions for batch_k_scale");

                static_assert(etl::dim<0, A>() == etl::dim<1, B>(), "Invalid dimensions for batch_k_scale");
            } else {
                cpp_assert(etl::dim<0>(b) == etl::dim<0>(c), "Invalid dimensions for batch_k_scale");
                cpp_assert(etl::dim<1>(b) == etl::dim<1>(c), "Invalid dimensions for batch_k_scale");

                cpp_assert(etl::dim<0>(a) == etl::dim<1>(b), "Invalid dimensions for batch_k_scale");
            }
        }
    }
    template <etl_expr L>
    void assign_to(L&& lhs) const {
        auto& a = this->a();
        auto& b = this->b();

        if constexpr (D4) {
            const auto Batch = etl::dim<0>(lhs);
            const auto K     = etl::dim<1>(lhs);
            const auto M     = etl::dim<2>(lhs);
            const auto N     = etl::dim<3>(lhs);

            if constexpr (impl::egblas::has_sbatch_k_scale4 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                decltype(auto) t1 = smart_forward_gpu(a);
                decltype(auto) t2 = smart_forward_gpu(b);

                t1.ensure_gpu_up_to_date();
                t2.ensure_gpu_up_to_date();

                lhs.ensure_gpu_allocated();

                impl::egblas::batch_k_scale(Batch, K, M, N, t2.gpu_memory(), t1.gpu_memory(), lhs.gpu_memory());

                lhs.invalidate_cpu();
            } else {
                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();

                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                        using vec_type = default_vec;
                        using T        = value_t<A>;

                        static constexpr size_t vec_size = vec_type::template traits<T>::size;

                        const auto MN = M * N;

                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                auto lhs_sub = lhs(batch)(k);
                                auto b_sub   = b(batch)(k);

                                auto ak = a(k);
                                auto a1 = vec_type::set(ak);

                                size_t mn = 0;

                                // 4x-unrolled vectorized main loop over the M*N plane
                                for (; mn + 4 * vec_size - 1 < MN; mn += 4 * vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                    auto b3 = b_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                    auto b4 = b_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                    auto r1 = vec_type::mul(a1, b1);
                                    auto r2 = vec_type::mul(a1, b2);
                                    auto r3 = vec_type::mul(a1, b3);
                                    auto r4 = vec_type::mul(a1, b4);

                                    lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r3, mn + 2 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r4, mn + 3 * vec_size);
                                }

                                for (; mn + 2 * vec_size - 1 < MN; mn += 2 * vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                    auto r1 = vec_type::mul(a1, b1);
                                    auto r2 = vec_type::mul(a1, b2);

                                    lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                }

                                for (; mn + vec_size - 1 < MN; mn += vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn);

                                    auto r1 = vec_type::mul(a1, b1);

                                    lhs_sub.template storeu<vec_type>(r1, mn);
                                }

                                // Scalar tail
                                for (; mn + 3 < MN; mn += 4) {
                                    lhs_sub[mn + 0] = ak * b_sub[mn + 0];
                                    lhs_sub[mn + 1] = ak * b_sub[mn + 1];
                                    lhs_sub[mn + 2] = ak * b_sub[mn + 2];
                                    lhs_sub[mn + 3] = ak * b_sub[mn + 3];
                                }

                                for (; mn + 1 < MN; mn += 2) {
                                    lhs_sub[mn + 0] = ak * b_sub[mn + 0];
                                    lhs_sub[mn + 1] = ak * b_sub[mn + 1];
                                }

                                for (; mn < MN; ++mn) {
                                    lhs_sub[mn] = ak * b_sub[mn];
                                }
                            }
                        }
                    } else {
                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                for (size_t m = 0; m < M; ++m) {
                                    for (size_t n = 0; n < N; ++n) {
                                        lhs(batch, k, m, n) = a(k) * b(batch, k, m, n);
                                    }
                                }
                            }
                        }
                    }
                };

                // ... (dispatch of batch_fun_b over [0, Batch) elided)

                lhs.invalidate_gpu();
            }
        } else {
            const auto Batch = etl::dim<0>(lhs);
            const auto K     = etl::dim<1>(lhs);

            if constexpr (impl::egblas::has_sbatch_k_scale2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                decltype(auto) t1 = smart_forward_gpu(a);
                decltype(auto) t2 = smart_forward_gpu(b);

                t1.ensure_gpu_up_to_date();
                t2.ensure_gpu_up_to_date();

                lhs.ensure_gpu_allocated();

                impl::egblas::batch_k_scale(Batch, K, t2.gpu_memory(), t1.gpu_memory(), lhs.gpu_memory());

                lhs.invalidate_cpu();
            } else {
                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();

                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                        using vec_type = default_vec;
                        using T        = value_t<A>;

                        static constexpr size_t vec_size = vec_type::template traits<T>::size;

                        for (size_t batch = first; batch < last; ++batch) {
                            size_t k    = 0;
                            size_t base = batch * K;

                            // 4x-unrolled vectorized main loop over the K channels
                            for (; k + 4 * vec_size - 1 < K; k += 4 * vec_size) {
                                auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                auto a2 = a.template load<vec_type>(k + 1 * vec_size);
                                auto a3 = a.template load<vec_type>(k + 2 * vec_size);
                                auto a4 = a.template load<vec_type>(k + 3 * vec_size);

                                auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
                                auto b3 = b.template loadu<vec_type>(base + k + 2 * vec_size);
                                auto b4 = b.template loadu<vec_type>(base + k + 3 * vec_size);

                                auto r1 = vec_type::mul(a1, b1);
                                auto r2 = vec_type::mul(a2, b2);
                                auto r3 = vec_type::mul(a3, b3);
                                auto r4 = vec_type::mul(a4, b4);

                                lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                                lhs.template storeu<vec_type>(r3, base + k + 2 * vec_size);
                                lhs.template storeu<vec_type>(r4, base + k + 3 * vec_size);
                            }

                            for (; k + 2 * vec_size - 1 < K; k += 2 * vec_size) {
                                auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                auto a2 = a.template load<vec_type>(k + 1 * vec_size);

                                auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);

                                auto r1 = vec_type::mul(a1, b1);
                                auto r2 = vec_type::mul(a2, b2);

                                lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                            }

                            for (; k + vec_size - 1 < K; k += vec_size) {
                                auto a1 = a.template load<vec_type>(k);

                                auto b1 = b.template loadu<vec_type>(base + k);

                                auto r1 = vec_type::mul(a1, b1);

                                lhs.template storeu<vec_type>(r1, base + k);
                            }

                            // Scalar tail
                            for (; k + 3 < K; k += 4) {
                                lhs(batch, k + 0) = a(k + 0) * b(batch, k + 0);
                                lhs(batch, k + 1) = a(k + 1) * b(batch, k + 1);
                                lhs(batch, k + 2) = a(k + 2) * b(batch, k + 2);
                                lhs(batch, k + 3) = a(k + 3) * b(batch, k + 3);
                            }

                            for (; k + 1 < K; k += 2) {
                                lhs(batch, k + 0) = a(k + 0) * b(batch, k + 0);
                                lhs(batch, k + 1) = a(k + 1) * b(batch, k + 1);
                            }

                            for (; k < K; ++k) {
                                lhs(batch, k) = a(k) * b(batch, k);
                            }
                        }
                    } else {
                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                lhs(batch, k) = a(k) * b(batch, k);
                            }
                        }
                    }
                };

                // ... (dispatch of batch_fun_b over [0, Batch) elided)

                lhs.invalidate_gpu();
            }
        }
    }
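
    // Add to the given left-hand-side expression (lhs += a(k) * b)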
    template <etl_expr L>
    void assign_add_to(L&& lhs) const {
        auto& a = this->a();
        auto& b = this->b();

        if constexpr (D4) {
            if constexpr (impl::egblas::has_sbatch_k_scale4 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                const auto Batch = etl::dim<0>(lhs);
                const auto K     = etl::dim<1>(lhs);
                const auto M     = etl::dim<2>(lhs);
                const auto N     = etl::dim<3>(lhs);

                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                lhs.ensure_cpu_up_to_date();

                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                        using vec_type = default_vec;
                        using T        = value_t<A>;

                        static constexpr size_t vec_size = vec_type::template traits<T>::size;

                        const auto MN = M * N;

                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                auto lhs_sub = lhs(batch)(k);
                                auto b_sub   = b(batch)(k);

                                auto ak = a(k);
                                auto a1 = vec_type::set(ak);

                                size_t mn = 0;

                                for (; mn + 4 * vec_size - 1 < MN; mn += 4 * vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                    auto b3 = b_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                    auto b4 = b_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                    auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                    auto l3 = lhs_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                    auto l4 = lhs_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                    auto r1 = vec_type::add(l1, vec_type::mul(a1, b1));
                                    auto r2 = vec_type::add(l2, vec_type::mul(a1, b2));
                                    auto r3 = vec_type::add(l3, vec_type::mul(a1, b3));
                                    auto r4 = vec_type::add(l4, vec_type::mul(a1, b4));

                                    lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r3, mn + 2 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r4, mn + 3 * vec_size);
                                }

                                for (; mn + 2 * vec_size - 1 < MN; mn += 2 * vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                    auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                    auto r1 = vec_type::add(l1, vec_type::mul(a1, b1));
                                    auto r2 = vec_type::add(l2, vec_type::mul(a1, b2));

                                    lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                }

                                for (; mn + vec_size - 1 < MN; mn += vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn);

                                    auto l1 = lhs_sub.template loadu<vec_type>(mn);

                                    auto r1 = vec_type::add(l1, vec_type::mul(a1, b1));

                                    lhs_sub.template storeu<vec_type>(r1, mn);
                                }

                                for (; mn + 3 < MN; mn += 4) {
                                    lhs_sub[mn + 0] += ak * b_sub[mn + 0];
                                    lhs_sub[mn + 1] += ak * b_sub[mn + 1];
                                    lhs_sub[mn + 2] += ak * b_sub[mn + 2];
                                    lhs_sub[mn + 3] += ak * b_sub[mn + 3];
                                }

                                for (; mn + 1 < MN; mn += 2) {
                                    lhs_sub[mn + 0] += ak * b_sub[mn + 0];
                                    lhs_sub[mn + 1] += ak * b_sub[mn + 1];
                                }

                                for (; mn < MN; ++mn) {
                                    lhs_sub[mn] += ak * b_sub[mn];
                                }
                            }
                        }
                    } else {
                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                for (size_t m = 0; m < M; ++m) {
                                    for (size_t n = 0; n < N; ++n) {
                                        lhs(batch, k, m, n) += a(k) * b(batch, k, m, n);
                                    }
                                }
                            }
                        }
                    }
                };

                // ... (dispatch of batch_fun_b over [0, Batch) elided)

                lhs.invalidate_gpu();
            }
        } else {
            if constexpr (impl::egblas::has_sbatch_k_scale2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                const auto Batch = etl::dim<0>(lhs);
                const auto K     = etl::dim<1>(lhs);

                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                lhs.ensure_cpu_up_to_date();

                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                        using vec_type = default_vec;
                        using T        = value_t<A>;

                        static constexpr size_t vec_size = vec_type::template traits<T>::size;

                        for (size_t batch = first; batch < last; ++batch) {
                            size_t k    = 0;
                            size_t base = batch * K;

                            for (; k + 4 * vec_size - 1 < K; k += 4 * vec_size) {
                                auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                auto a2 = a.template load<vec_type>(k + 1 * vec_size);
                                auto a3 = a.template load<vec_type>(k + 2 * vec_size);
                                auto a4 = a.template load<vec_type>(k + 3 * vec_size);

                                auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
                                auto b3 = b.template loadu<vec_type>(base + k + 2 * vec_size);
                                auto b4 = b.template loadu<vec_type>(base + k + 3 * vec_size);

                                auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);
                                auto l3 = lhs.template loadu<vec_type>(base + k + 2 * vec_size);
                                auto l4 = lhs.template loadu<vec_type>(base + k + 3 * vec_size);

                                auto r1 = vec_type::add(l1, vec_type::mul(a1, b1));
                                auto r2 = vec_type::add(l2, vec_type::mul(a2, b2));
                                auto r3 = vec_type::add(l3, vec_type::mul(a3, b3));
                                auto r4 = vec_type::add(l4, vec_type::mul(a4, b4));

                                lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                                lhs.template storeu<vec_type>(r3, base + k + 2 * vec_size);
                                lhs.template storeu<vec_type>(r4, base + k + 3 * vec_size);
                            }

                            for (; k + 2 * vec_size - 1 < K; k += 2 * vec_size) {
                                auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                auto a2 = a.template load<vec_type>(k + 1 * vec_size);

                                auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);

                                auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);

                                auto r1 = vec_type::add(l1, vec_type::mul(a1, b1));
                                auto r2 = vec_type::add(l2, vec_type::mul(a2, b2));

                                lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                            }

                            for (; k + vec_size - 1 < K; k += vec_size) {
                                auto a1 = a.template load<vec_type>(k);

                                auto b1 = b.template loadu<vec_type>(base + k);

                                auto l1 = lhs.template loadu<vec_type>(base + k);

                                auto r1 = vec_type::add(l1, vec_type::mul(a1, b1));

                                lhs.template storeu<vec_type>(r1, base + k);
                            }

                            for (; k + 3 < K; k += 4) {
                                lhs(batch, k + 0) += a(k + 0) * b(batch, k + 0);
                                lhs(batch, k + 1) += a(k + 1) * b(batch, k + 1);
                                lhs(batch, k + 2) += a(k + 2) * b(batch, k + 2);
                                lhs(batch, k + 3) += a(k + 3) * b(batch, k + 3);
                            }

                            for (; k + 1 < K; k += 2) {
                                lhs(batch, k + 0) += a(k + 0) * b(batch, k + 0);
                                lhs(batch, k + 1) += a(k + 1) * b(batch, k + 1);
                            }

                            for (; k < K; ++k) {
                                lhs(batch, k) += a(k) * b(batch, k);
                            }
                        }
                    } else {
                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                lhs(batch, k) += a(k) * b(batch, k);
                            }
                        }
                    }
                };

                // ... (dispatch of batch_fun_b over [0, Batch) elided)

                lhs.invalidate_gpu();
            }
        }
    }
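
    // Subtract from the given left-hand-side expression (lhs -= a(k) * b)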
    template <etl_expr L>
    void assign_sub_to(L&& lhs) const {
        auto& a = this->a();
        auto& b = this->b();

        if constexpr (D4) {
            if constexpr (impl::egblas::has_sbatch_k_scale4 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                const auto Batch = etl::dim<0>(lhs);
                const auto K     = etl::dim<1>(lhs);
                const auto M     = etl::dim<2>(lhs);
                const auto N     = etl::dim<3>(lhs);

                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                lhs.ensure_cpu_up_to_date();

                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                        using vec_type = default_vec;
                        using T        = value_t<A>;

                        static constexpr size_t vec_size = vec_type::template traits<T>::size;

                        const auto MN = M * N;

                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                auto lhs_sub = lhs(batch)(k);
                                auto b_sub   = b(batch)(k);

                                auto ak = a(k);
                                auto a1 = vec_type::set(ak);

                                size_t mn = 0;

                                for (; mn + 4 * vec_size - 1 < MN; mn += 4 * vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                    auto b3 = b_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                    auto b4 = b_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                    auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                    auto l3 = lhs_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                    auto l4 = lhs_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                    auto r1 = vec_type::sub(l1, vec_type::mul(a1, b1));
                                    auto r2 = vec_type::sub(l2, vec_type::mul(a1, b2));
                                    auto r3 = vec_type::sub(l3, vec_type::mul(a1, b3));
                                    auto r4 = vec_type::sub(l4, vec_type::mul(a1, b4));

                                    lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r3, mn + 2 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r4, mn + 3 * vec_size);
                                }

                                for (; mn + 2 * vec_size - 1 < MN; mn += 2 * vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                    auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                    auto r1 = vec_type::sub(l1, vec_type::mul(a1, b1));
                                    auto r2 = vec_type::sub(l2, vec_type::mul(a1, b2));

                                    lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                }

                                for (; mn + vec_size - 1 < MN; mn += vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn);

                                    auto l1 = lhs_sub.template loadu<vec_type>(mn);

                                    auto r1 = vec_type::sub(l1, vec_type::mul(a1, b1));

                                    lhs_sub.template storeu<vec_type>(r1, mn);
                                }

                                for (; mn + 3 < MN; mn += 4) {
                                    lhs_sub[mn + 0] -= ak * b_sub[mn + 0];
                                    lhs_sub[mn + 1] -= ak * b_sub[mn + 1];
                                    lhs_sub[mn + 2] -= ak * b_sub[mn + 2];
                                    lhs_sub[mn + 3] -= ak * b_sub[mn + 3];
                                }

                                for (; mn + 1 < MN; mn += 2) {
                                    lhs_sub[mn + 0] -= ak * b_sub[mn + 0];
                                    lhs_sub[mn + 1] -= ak * b_sub[mn + 1];
                                }

                                for (; mn < MN; ++mn) {
                                    lhs_sub[mn] -= ak * b_sub[mn];
                                }
                            }
                        }
                    } else {
                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                for (size_t m = 0; m < M; ++m) {
                                    for (size_t n = 0; n < N; ++n) {
                                        lhs(batch, k, m, n) -= a(k) * b(batch, k, m, n);
                                    }
                                }
                            }
                        }
                    }
                };

                // ... (dispatch of batch_fun_b over [0, Batch) elided)

                lhs.invalidate_gpu();
            }
        } else {
            if constexpr (impl::egblas::has_sbatch_k_scale2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                const auto Batch = etl::dim<0>(lhs);
                const auto K     = etl::dim<1>(lhs);

                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                lhs.ensure_cpu_up_to_date();

                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                        using vec_type = default_vec;
                        using T        = value_t<A>;

                        static constexpr size_t vec_size = vec_type::template traits<T>::size;

                        for (size_t batch = first; batch < last; ++batch) {
                            size_t k    = 0;
                            size_t base = batch * K;

                            for (; k + 4 * vec_size - 1 < K; k += 4 * vec_size) {
                                auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                auto a2 = a.template load<vec_type>(k + 1 * vec_size);
                                auto a3 = a.template load<vec_type>(k + 2 * vec_size);
                                auto a4 = a.template load<vec_type>(k + 3 * vec_size);

                                auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
                                auto b3 = b.template loadu<vec_type>(base + k + 2 * vec_size);
                                auto b4 = b.template loadu<vec_type>(base + k + 3 * vec_size);

                                auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);
                                auto l3 = lhs.template loadu<vec_type>(base + k + 2 * vec_size);
                                auto l4 = lhs.template loadu<vec_type>(base + k + 3 * vec_size);

                                auto r1 = vec_type::sub(l1, vec_type::mul(a1, b1));
                                auto r2 = vec_type::sub(l2, vec_type::mul(a2, b2));
                                auto r3 = vec_type::sub(l3, vec_type::mul(a3, b3));
                                auto r4 = vec_type::sub(l4, vec_type::mul(a4, b4));

                                lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                                lhs.template storeu<vec_type>(r3, base + k + 2 * vec_size);
                                lhs.template storeu<vec_type>(r4, base + k + 3 * vec_size);
                            }

                            for (; k + 2 * vec_size - 1 < K; k += 2 * vec_size) {
                                auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                auto a2 = a.template load<vec_type>(k + 1 * vec_size);

                                auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);

                                auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);

                                auto r1 = vec_type::sub(l1, vec_type::mul(a1, b1));
                                auto r2 = vec_type::sub(l2, vec_type::mul(a2, b2));

                                lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                            }

                            for (; k + vec_size - 1 < K; k += vec_size) {
                                auto a1 = a.template load<vec_type>(k);

                                auto b1 = b.template loadu<vec_type>(base + k);

                                auto l1 = lhs.template loadu<vec_type>(base + k);

                                auto r1 = vec_type::sub(l1, vec_type::mul(a1, b1));

                                lhs.template storeu<vec_type>(r1, base + k);
                            }

                            for (; k + 3 < K; k += 4) {
                                lhs(batch, k + 0) -= a(k + 0) * b(batch, k + 0);
                                lhs(batch, k + 1) -= a(k + 1) * b(batch, k + 1);
                                lhs(batch, k + 2) -= a(k + 2) * b(batch, k + 2);
                                lhs(batch, k + 3) -= a(k + 3) * b(batch, k + 3);
                            }

                            for (; k + 1 < K; k += 2) {
                                lhs(batch, k + 0) -= a(k + 0) * b(batch, k + 0);
                                lhs(batch, k + 1) -= a(k + 1) * b(batch, k + 1);
                            }

                            for (; k < K; ++k) {
                                lhs(batch, k) -= a(k) * b(batch, k);
                            }
                        }
                    } else {
                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                lhs(batch, k) -= a(k) * b(batch, k);
                            }
                        }
                    }
                };

                // ... (dispatch of batch_fun_b over [0, Batch) elided)

                lhs.invalidate_gpu();
            }
        }
    }
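
    // Multiply the given left-hand-side expression (lhs *= a(k) * b)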
    template <etl_expr L>
    void assign_mul_to(L&& lhs) const {
        auto& a = this->a();
        auto& b = this->b();

        if constexpr (D4) {
            if constexpr (impl::egblas::has_sbatch_k_scale4 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                const auto Batch = etl::dim<0>(lhs);
                const auto K     = etl::dim<1>(lhs);
                const auto M     = etl::dim<2>(lhs);
                const auto N     = etl::dim<3>(lhs);

                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                lhs.ensure_cpu_up_to_date();

                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                        using vec_type = default_vec;
                        using T        = value_t<A>;

                        static constexpr size_t vec_size = vec_type::template traits<T>::size;

                        const auto MN = M * N;

                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                auto lhs_sub = lhs(batch)(k);
                                auto b_sub   = b(batch)(k);

                                auto ak = a(k);
                                auto a1 = vec_type::set(ak);

                                size_t mn = 0;

                                for (; mn + 4 * vec_size - 1 < MN; mn += 4 * vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                    auto b3 = b_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                    auto b4 = b_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                    auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                    auto l3 = lhs_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                    auto l4 = lhs_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                    auto r1 = vec_type::mul(l1, vec_type::mul(a1, b1));
                                    auto r2 = vec_type::mul(l2, vec_type::mul(a1, b2));
                                    auto r3 = vec_type::mul(l3, vec_type::mul(a1, b3));
                                    auto r4 = vec_type::mul(l4, vec_type::mul(a1, b4));

                                    lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r3, mn + 2 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r4, mn + 3 * vec_size);
                                }

                                for (; mn + 2 * vec_size - 1 < MN; mn += 2 * vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                    auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                    auto r1 = vec_type::mul(l1, vec_type::mul(a1, b1));
                                    auto r2 = vec_type::mul(l2, vec_type::mul(a1, b2));

                                    lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                }

                                for (; mn + vec_size - 1 < MN; mn += vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn);

                                    auto l1 = lhs_sub.template loadu<vec_type>(mn);

                                    auto r1 = vec_type::mul(l1, vec_type::mul(a1, b1));

                                    lhs_sub.template storeu<vec_type>(r1, mn);
                                }

                                for (; mn + 3 < MN; mn += 4) {
                                    lhs_sub[mn + 0] *= ak * b_sub[mn + 0];
                                    lhs_sub[mn + 1] *= ak * b_sub[mn + 1];
                                    lhs_sub[mn + 2] *= ak * b_sub[mn + 2];
                                    lhs_sub[mn + 3] *= ak * b_sub[mn + 3];
                                }

                                for (; mn + 1 < MN; mn += 2) {
                                    lhs_sub[mn + 0] *= ak * b_sub[mn + 0];
                                    lhs_sub[mn + 1] *= ak * b_sub[mn + 1];
                                }

                                for (; mn < MN; ++mn) {
                                    lhs_sub[mn] *= ak * b_sub[mn];
                                }
                            }
                        }
                    } else {
                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                for (size_t m = 0; m < M; ++m) {
                                    for (size_t n = 0; n < N; ++n) {
                                        lhs(batch, k, m, n) *= a(k) * b(batch, k, m, n);
                                    }
                                }
                            }
                        }
                    }
                };

                // ... (dispatch of batch_fun_b over [0, Batch) elided)

                lhs.invalidate_gpu();
            }
        } else {
            if constexpr (impl::egblas::has_sbatch_k_scale2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                const auto Batch = etl::dim<0>(lhs);
                const auto K     = etl::dim<1>(lhs);

                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                lhs.ensure_cpu_up_to_date();

                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                        using vec_type = default_vec;
                        using T        = value_t<A>;

                        static constexpr size_t vec_size = vec_type::template traits<T>::size;

                        for (size_t batch = first; batch < last; ++batch) {
                            size_t k    = 0;
                            size_t base = batch * K;

                            for (; k + 4 * vec_size - 1 < K; k += 4 * vec_size) {
                                auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                auto a2 = a.template load<vec_type>(k + 1 * vec_size);
                                auto a3 = a.template load<vec_type>(k + 2 * vec_size);
                                auto a4 = a.template load<vec_type>(k + 3 * vec_size);

                                auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
                                auto b3 = b.template loadu<vec_type>(base + k + 2 * vec_size);
                                auto b4 = b.template loadu<vec_type>(base + k + 3 * vec_size);

                                auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);
                                auto l3 = lhs.template loadu<vec_type>(base + k + 2 * vec_size);
                                auto l4 = lhs.template loadu<vec_type>(base + k + 3 * vec_size);

                                auto r1 = vec_type::mul(l1, vec_type::mul(a1, b1));
                                auto r2 = vec_type::mul(l2, vec_type::mul(a2, b2));
                                auto r3 = vec_type::mul(l3, vec_type::mul(a3, b3));
                                auto r4 = vec_type::mul(l4, vec_type::mul(a4, b4));

                                lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                                lhs.template storeu<vec_type>(r3, base + k + 2 * vec_size);
                                lhs.template storeu<vec_type>(r4, base + k + 3 * vec_size);
                            }

                            for (; k + 2 * vec_size - 1 < K; k += 2 * vec_size) {
                                auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                auto a2 = a.template load<vec_type>(k + 1 * vec_size);

                                auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);

                                auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);

                                auto r1 = vec_type::mul(l1, vec_type::mul(a1, b1));
                                auto r2 = vec_type::mul(l2, vec_type::mul(a2, b2));

                                lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                            }

                            for (; k + vec_size - 1 < K; k += vec_size) {
                                auto a1 = a.template load<vec_type>(k);

                                auto b1 = b.template loadu<vec_type>(base + k);

                                auto l1 = lhs.template loadu<vec_type>(base + k);

                                auto r1 = vec_type::mul(l1, vec_type::mul(a1, b1));

                                lhs.template storeu<vec_type>(r1, base + k);
                            }

                            for (; k + 3 < K; k += 4) {
                                lhs(batch, k + 0) *= a(k + 0) * b(batch, k + 0);
                                lhs(batch, k + 1) *= a(k + 1) * b(batch, k + 1);
                                lhs(batch, k + 2) *= a(k + 2) * b(batch, k + 2);
                                lhs(batch, k + 3) *= a(k + 3) * b(batch, k + 3);
                            }

                            for (; k + 1 < K; k += 2) {
                                lhs(batch, k + 0) *= a(k + 0) * b(batch, k + 0);
                                lhs(batch, k + 1) *= a(k + 1) * b(batch, k + 1);
                            }

                            for (; k < K; ++k) {
                                lhs(batch, k) *= a(k) * b(batch, k);
                            }
                        }
                    } else {
                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                lhs(batch, k) *= a(k) * b(batch, k);
                            }
                        }
                    }
                };

                // ... (dispatch of batch_fun_b over [0, Batch) elided)

                lhs.invalidate_gpu();
            }
        }
    }
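
    // Divide the given left-hand-side expression (lhs /= a(k) * b)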
    template <etl_expr L>
    void assign_div_to(L&& lhs) const {
        auto& a = this->a();
        auto& b = this->b();

        if constexpr (D4) {
            if constexpr (impl::egblas::has_sbatch_k_scale4 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                const auto Batch = etl::dim<0>(lhs);
                const auto K     = etl::dim<1>(lhs);
                const auto M     = etl::dim<2>(lhs);
                const auto N     = etl::dim<3>(lhs);

                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                lhs.ensure_cpu_up_to_date();

                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                        using vec_type = default_vec;
                        using T        = value_t<A>;

                        static constexpr size_t vec_size = vec_type::template traits<T>::size;

                        const auto MN = M * N;

                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                auto lhs_sub = lhs(batch)(k);
                                auto b_sub   = b(batch)(k);

                                auto ak = a(k);
                                auto a1 = vec_type::set(ak);

                                size_t mn = 0;

                                for (; mn + 4 * vec_size - 1 < MN; mn += 4 * vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                    auto b3 = b_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                    auto b4 = b_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                    auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);
                                    auto l3 = lhs_sub.template loadu<vec_type>(mn + 2 * vec_size);
                                    auto l4 = lhs_sub.template loadu<vec_type>(mn + 3 * vec_size);

                                    auto r1 = vec_type::div(l1, vec_type::mul(a1, b1));
                                    auto r2 = vec_type::div(l2, vec_type::mul(a1, b2));
                                    auto r3 = vec_type::div(l3, vec_type::mul(a1, b3));
                                    auto r4 = vec_type::div(l4, vec_type::mul(a1, b4));

                                    lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r3, mn + 2 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r4, mn + 3 * vec_size);
                                }

                                for (; mn + 2 * vec_size - 1 < MN; mn += 2 * vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto b2 = b_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                    auto l1 = lhs_sub.template loadu<vec_type>(mn + 0 * vec_size);
                                    auto l2 = lhs_sub.template loadu<vec_type>(mn + 1 * vec_size);

                                    auto r1 = vec_type::div(l1, vec_type::mul(a1, b1));
                                    auto r2 = vec_type::div(l2, vec_type::mul(a1, b2));

                                    lhs_sub.template storeu<vec_type>(r1, mn + 0 * vec_size);
                                    lhs_sub.template storeu<vec_type>(r2, mn + 1 * vec_size);
                                }

                                for (; mn + vec_size - 1 < MN; mn += vec_size) {
                                    auto b1 = b_sub.template loadu<vec_type>(mn);

                                    auto l1 = lhs_sub.template loadu<vec_type>(mn);

                                    auto r1 = vec_type::div(l1, vec_type::mul(a1, b1));

                                    lhs_sub.template storeu<vec_type>(r1, mn);
                                }

                                for (; mn + 3 < MN; mn += 4) {
                                    lhs_sub[mn + 0] /= ak * b_sub[mn + 0];
                                    lhs_sub[mn + 1] /= ak * b_sub[mn + 1];
                                    lhs_sub[mn + 2] /= ak * b_sub[mn + 2];
                                    lhs_sub[mn + 3] /= ak * b_sub[mn + 3];
                                }

                                for (; mn + 1 < MN; mn += 2) {
                                    lhs_sub[mn + 0] /= ak * b_sub[mn + 0];
                                    lhs_sub[mn + 1] /= ak * b_sub[mn + 1];
                                }

                                for (; mn < MN; ++mn) {
                                    lhs_sub[mn] /= ak * b_sub[mn];
                                }
                            }
                        }
                    } else {
                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                for (size_t m = 0; m < M; ++m) {
                                    for (size_t n = 0; n < N; ++n) {
                                        lhs(batch, k, m, n) /= a(k) * b(batch, k, m, n);
                                    }
                                }
                            }
                        }
                    }
                };

                // ... (dispatch of batch_fun_b over [0, Batch) elided)

                lhs.invalidate_gpu();
            }
        } else {
            if constexpr (impl::egblas::has_sbatch_k_scale2 && all_row_major<A, B, L> && all_floating<A, B, L>) {
                const auto Batch = etl::dim<0>(lhs);
                const auto K     = etl::dim<1>(lhs);

                standard_evaluator::pre_assign_rhs(a);
                standard_evaluator::pre_assign_rhs(b);

                a.ensure_cpu_up_to_date();
                b.ensure_cpu_up_to_date();
                lhs.ensure_cpu_up_to_date();

                auto batch_fun_b = [&](const size_t first, const size_t last) {
                    if constexpr (vec_enabled && all_vectorizable<vector_mode, A, L> && all_row_major<A, L>) {
                        using vec_type = default_vec;
                        using T        = value_t<A>;

                        static constexpr size_t vec_size = vec_type::template traits<T>::size;

                        for (size_t batch = first; batch < last; ++batch) {
                            size_t k    = 0;
                            size_t base = batch * K;

                            for (; k + 4 * vec_size - 1 < K; k += 4 * vec_size) {
                                auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                auto a2 = a.template load<vec_type>(k + 1 * vec_size);
                                auto a3 = a.template load<vec_type>(k + 2 * vec_size);
                                auto a4 = a.template load<vec_type>(k + 3 * vec_size);

                                auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);
                                auto b3 = b.template loadu<vec_type>(base + k + 2 * vec_size);
                                auto b4 = b.template loadu<vec_type>(base + k + 3 * vec_size);

                                auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);
                                auto l3 = lhs.template loadu<vec_type>(base + k + 2 * vec_size);
                                auto l4 = lhs.template loadu<vec_type>(base + k + 3 * vec_size);

                                auto r1 = vec_type::div(l1, vec_type::mul(a1, b1));
                                auto r2 = vec_type::div(l2, vec_type::mul(a2, b2));
                                auto r3 = vec_type::div(l3, vec_type::mul(a3, b3));
                                auto r4 = vec_type::div(l4, vec_type::mul(a4, b4));

                                lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                                lhs.template storeu<vec_type>(r3, base + k + 2 * vec_size);
                                lhs.template storeu<vec_type>(r4, base + k + 3 * vec_size);
                            }

                            for (; k + 2 * vec_size - 1 < K; k += 2 * vec_size) {
                                auto a1 = a.template load<vec_type>(k + 0 * vec_size);
                                auto a2 = a.template load<vec_type>(k + 1 * vec_size);

                                auto b1 = b.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto b2 = b.template loadu<vec_type>(base + k + 1 * vec_size);

                                auto l1 = lhs.template loadu<vec_type>(base + k + 0 * vec_size);
                                auto l2 = lhs.template loadu<vec_type>(base + k + 1 * vec_size);

                                auto r1 = vec_type::div(l1, vec_type::mul(a1, b1));
                                auto r2 = vec_type::div(l2, vec_type::mul(a2, b2));

                                lhs.template storeu<vec_type>(r1, base + k + 0 * vec_size);
                                lhs.template storeu<vec_type>(r2, base + k + 1 * vec_size);
                            }

                            for (; k + vec_size - 1 < K; k += vec_size) {
                                auto a1 = a.template load<vec_type>(k);

                                auto b1 = b.template loadu<vec_type>(base + k);

                                auto l1 = lhs.template loadu<vec_type>(base + k);

                                auto r1 = vec_type::div(l1, vec_type::mul(a1, b1));

                                lhs.template storeu<vec_type>(r1, base + k);
                            }

                            for (; k + 3 < K; k += 4) {
                                lhs(batch, k + 0) /= a(k + 0) * b(batch, k + 0);
                                lhs(batch, k + 1) /= a(k + 1) * b(batch, k + 1);
                                lhs(batch, k + 2) /= a(k + 2) * b(batch, k + 2);
                                lhs(batch, k + 3) /= a(k + 3) * b(batch, k + 3);
                            }

                            for (; k + 1 < K; k += 2) {
                                lhs(batch, k + 0) /= a(k + 0) * b(batch, k + 0);
                                lhs(batch, k + 1) /= a(k + 1) * b(batch, k + 1);
                            }

                            for (; k < K; ++k) {
                                lhs(batch, k) /= a(k) * b(batch, k);
                            }
                        }
                    } else {
                        for (size_t batch = first; batch < last; ++batch) {
                            for (size_t k = 0; k < K; ++k) {
                                lhs(batch, k) /= a(k) * b(batch, k);
                            }
                        }
                    }
                };

                // ... (dispatch of batch_fun_b over [0, Batch) elided)

                lhs.invalidate_gpu();
            }
        }
    }
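
    // Modulo the given left-hand-side expression (lhs %= a(k) * b)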
    template <etl_expr L>
    void assign_mod_to(L&& lhs) const {
        auto& a = this->a();
        auto& b = this->b();

        standard_evaluator::pre_assign_rhs(a);
        standard_evaluator::pre_assign_rhs(b);

        a.ensure_cpu_up_to_date();
        b.ensure_cpu_up_to_date();
        lhs.ensure_cpu_up_to_date();

        if constexpr (D4) {
            const auto Batch = etl::dim<0>(lhs);
            const auto K     = etl::dim<1>(lhs);
            const auto M     = etl::dim<2>(lhs);
            const auto N     = etl::dim<3>(lhs);

            auto batch_fun_b = [&](const size_t first, const size_t last) {
                for (size_t batch = first; batch < last; ++batch) {
                    for (size_t k = 0; k < K; ++k) {
                        for (size_t m = 0; m < M; ++m) {
                            for (size_t n = 0; n < N; ++n) {
                                lhs(batch, k, m, n) %= a(k) * b(batch, k, m, n);
                            }
                        }
                    }
                }
            };

            // ... (dispatch of batch_fun_b over [0, Batch) elided)

            lhs.invalidate_gpu();
        } else {
            const auto Batch = etl::dim<0>(lhs);
            const auto K     = etl::dim<1>(lhs);

            auto batch_fun_b = [&](const size_t first, const size_t last) {
                for (size_t batch = first; batch < last; ++batch) {
                    for (size_t k = 0; k < K; ++k) {
                        lhs(batch, k) %= a(k) * b(batch, k);
                    }
                }
            };

            // ... (dispatch of batch_fun_b over [0, Batch) elided)

            lhs.invalidate_gpu();
        }
    }
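
    // Print a representation of the expression on the given stream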
    friend std::ostream& operator<<(std::ostream& os, const batch_k_scale_expr& expr) {
        return os << "batch_k_scale(" << expr._a << "," << expr._b << ")";
    }
};
template <typename A, typename B>
struct etl_traits<etl::batch_k_scale_expr<A, B>> {
    using expr_t     = etl::batch_k_scale_expr<A, B>; ///< The expression type
    using sub_expr_t = std::decay_t<B>;                ///< The sub expression type
    using value_type = value_t<A>;                     ///< The value type of the expression
    using sub_traits = etl_traits<sub_expr_t>;         ///< The traits of the sub expression

    static constexpr bool is_etl         = true;                        ///< Indicates if the type is an ETL expression
    static constexpr bool is_fast        = sub_traits::is_fast;         ///< Indicates if the expression is fast (sizes known at compile time)
    static constexpr bool is_linear      = false;                       ///< Indicates if the expression is linear
    static constexpr bool is_value       = false;                       ///< Indicates if the expression is a value type
    static constexpr bool is_direct      = true;                        ///< Indicates if the expression has direct memory access
    static constexpr bool is_generator   = false;                       ///< Indicates if the expression is a generator
    static constexpr bool is_padded      = false;                       ///< Indicates if the expression is padded
    static constexpr bool is_aligned     = true;                        ///< Indicates if the expression is aligned
    static constexpr bool is_temporary   = true;                        ///< Indicates if the expression needs a temporary
    static constexpr bool gpu_computable = true;                        ///< Indicates if the expression can be computed on GPU
    static constexpr order storage_order = sub_traits::storage_order;   ///< The storage order of the expression

    // Indicates if the expression can be vectorized with the given vector mode
    template <vector_mode_t V>
    static constexpr bool vectorizable = true;

    // Returns the DDth dimension of the expression
    template <size_t DD>
    static constexpr size_t dim() {
        // ...
    }

    // ... (dim(e, d), size(e), size(), dimensions() and complexity() elided)
};
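
// Builds an expression that scales each channel k of b by the factor a(k)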
template <etl_1d A, etl_2d_or_4d B>
batch_k_scale_expr<detail::build_type<A>, detail::build_type<B>> batch_k_scale(const A& a, const B& b) {
    return batch_k_scale_expr<detail::build_type<A>, detail::build_type<B>>{a, b};
}
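
// ---------------------------------------------------------------------------
// Usage sketch (not part of this header). A minimal illustration of how the
// expression is typically consumed, assuming the usual ETL containers
// (etl::fast_matrix) and the top-level "etl/etl.hpp" include; the container
// types and sizes below are only examples.
// ---------------------------------------------------------------------------
#include "etl/etl.hpp"

void batch_k_scale_example() {
    etl::fast_matrix<float, 3> gamma = {0.5f, 1.0f, 2.0f}; // K = 3 per-channel factors
    etl::fast_matrix<float, 2, 3, 4, 4> input;             // (Batch, K, M, N)
    etl::fast_matrix<float, 2, 3, 4, 4> output;

    input = 1.0f;

    // Direct assignment: output(b, k, m, n) = gamma(k) * input(b, k, m, n)
    output = etl::batch_k_scale(gamma, input);

    // Compound assignment goes through assign_add_to / assign_sub_to / ... above
    output += etl::batch_k_scale(gamma, input);
}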