// NOTE(review): this is an extracted listing. The leading integers are the line
// numbers of the original file, and several original lines are not visible here
// (the function signature, the declarations of vec_type / T / L2 / R2 / i / j / b,
// the closing braces and the guards around the remainder code). The comments
// below describe only what the visible lines show; anything else is hedged.
//
// The two pragmas save GCC's current optimization options and then compile the
// code that follows with -fno-aggressive-loop-optimizations.
18 #pragma GCC push_options 19 #pragma GCC optimize ("-fno-aggressive-loop-optimizations") 28 template <
typename V,
typename L,
typename R,
typename C>
// Size trait reported by vec_type for T; it is used below as the step of the
// vectorized loops (presumably the number of T elements per vector -- confirm).
33 static constexpr
size_t vec_size = vec_type::template traits<T>::size;
// B: first dimension of lhs; bounds every loop over b (the reduction index) below.
// N: second dimension of result; bounds every loop over j below.
// M: first dimension of result; not referenced in the visible lines.
35 const auto B = etl::dim<0>(lhs);
36 const auto M = etl::dim<0>(result);
37 const auto N = etl::dim<1>(result);
// Presumably makes sure the CPU-side data of both inputs is current before it
// is read -- confirm against the expression interface.
39 lhs.ensure_cpu_up_to_date();
40 rhs.ensure_cpu_up_to_date();
// Kernel for one range of rows of result: i is bounded by `last`.
// NOTE(review): the initialisation of i (presumably from `first`) and the call
// that runs this lambda are not visible in this listing.
45 auto batch_fun_m = [&](
const size_t first,
const size_t last) {
// Main case: two rows of result (i and i + 1) per iteration.
48 for (; i + 1 < last; i += 2) {
// Four columns of result (j .. j + 3) per iteration, i.e. a 2x4 tile.
51 for (; j + 3 < N; j += 4) {
// One vector accumulator per output of the tile:
// xmm1..xmm4 -> row i, columns j..j+3; xmm5..xmm8 -> row i + 1, columns j..j+3.
54 auto xmm1 = vec_type::template zero<T>();
55 auto xmm2 = vec_type::template zero<T>();
56 auto xmm3 = vec_type::template zero<T>();
57 auto xmm4 = vec_type::template zero<T>();
58 auto xmm5 = vec_type::template zero<T>();
59 auto xmm6 = vec_type::template zero<T>();
60 auto xmm7 = vec_type::template zero<T>();
61 auto xmm8 = vec_type::template zero<T>();
// Vectorized reduction over b, vec_size elements per step. L2 and R2 are read
// with unaligned loads at flat offset (index * B + b).
// NOTE(review): L2 / R2 are declared outside the visible lines; the flat offsets
// here and the (b, i) / (b, j) element accesses further down presumably address
// the same elements of an opposite-storage-order temporary -- confirm.
63 for (; b + vec_size - 1 < B; b += vec_size) {
64 auto l1 = L2.template loadu<vec_type>((i + 0) * B + b);
65 auto l2 = L2.template loadu<vec_type>((i + 1) * B + b);
67 auto r1 = R2.template loadu<vec_type>((j + 0) * B + b);
68 auto r2 = R2.template loadu<vec_type>((j + 1) * B + b);
69 auto r3 = R2.template loadu<vec_type>((j + 2) * B + b);
70 auto r4 = R2.template loadu<vec_type>((j + 3) * B + b);
72 xmm1 = vec_type::fmadd(l1, r1, xmm1);
73 xmm2 = vec_type::fmadd(l1, r2, xmm2);
74 xmm3 = vec_type::fmadd(l1, r3, xmm3);
75 xmm4 = vec_type::fmadd(l1, r4, xmm4);
77 xmm5 = vec_type::fmadd(l2, r1, xmm5);
78 xmm6 = vec_type::fmadd(l2, r2, xmm6);
79 xmm7 = vec_type::fmadd(l2, r3, xmm7);
80 xmm8 = vec_type::fmadd(l2, r4, xmm8);
// Collapse every vector accumulator into one scalar of type T. These scalars
// r1..r8 are new variables: the vector operands r1..r4 above were declared
// inside the body of the vectorized loop.
83 T r1 = vec_type::hadd(xmm1);
84 T r2 = vec_type::hadd(xmm2);;
85 T r3 = vec_type::hadd(xmm3);;
86 T r4 = vec_type::hadd(xmm4);;
87 T r5 = vec_type::hadd(xmm5);
88 T r6 = vec_type::hadd(xmm6);;
89 T r7 = vec_type::hadd(xmm7);;
90 T r8 = vec_type::hadd(xmm8);;
// Scalar remainder of the reduction, two values of b per iteration.
92 for (; b + 1 < B; b += 2) {
93 r1 += L2(b + 0, i + 0) * R2(b + 0, j + 0);
94 r1 += L2(b + 1, i + 0) * R2(b + 1, j + 0);
96 r2 += L2(b + 0, i + 0) * R2(b + 0, j + 1);
97 r2 += L2(b + 1, i + 0) * R2(b + 1, j + 1);
99 r3 += L2(b + 0, i + 0) * R2(b + 0, j + 2);
100 r3 += L2(b + 1, i + 0) * R2(b + 1, j + 2);
102 r4 += L2(b + 0, i + 0) * R2(b + 0, j + 3);
103 r4 += L2(b + 1, i + 0) * R2(b + 1, j + 3);
105 r5 += L2(b + 0, i + 1) * R2(b + 0, j + 0);
106 r5 += L2(b + 1, i + 1) * R2(b + 1, j + 0);
108 r6 += L2(b + 0, i + 1) * R2(b + 0, j + 1);
109 r6 += L2(b + 1, i + 1) * R2(b + 1, j + 1);
111 r7 += L2(b + 0, i + 1) * R2(b + 0, j + 2);
112 r7 += L2(b + 1, i + 1) * R2(b + 1, j + 2);
114 r8 += L2(b + 0, i + 1) * R2(b + 0, j + 3);
115 r8 += L2(b + 1, i + 1) * R2(b + 1, j + 3);
// Last single value of b.
// NOTE(review): the guard around this tail (presumably `if (b < B)`) is not
// visible in this listing.
119 r1 += L2(b, i + 0) * R2(b, j + 0);
120 r2 += L2(b, i + 0) * R2(b, j + 1);
121 r3 += L2(b, i + 0) * R2(b, j + 2);
122 r4 += L2(b, i + 0) * R2(b, j + 3);
124 r5 += L2(b, i + 1) * R2(b, j + 0);
125 r6 += L2(b, i + 1) * R2(b, j + 1);
126 r7 += L2(b, i + 1) * R2(b, j + 2);
127 r8 += L2(b, i + 1) * R2(b, j + 3);
// Write the finished 2x4 tile into result.
130 result(i + 0, j + 0) = r1;
131 result(i + 0, j + 1) = r2;
132 result(i + 0, j + 2) = r3;
133 result(i + 0, j + 3) = r4;
135 result(i + 1, j + 0) = r5;
136 result(i + 1, j + 1) = r6;
137 result(i + 1, j + 2) = r7;
138 result(i + 1, j + 3) = r8;
// Column remainder for the current pair of rows: two columns per iteration
// (2x2 tile), same three-stage reduction as above.
141 for (; j + 1 < N; j += 2) {
// xmm1/xmm2 -> row i, columns j / j + 1; xmm3/xmm4 -> row i + 1, columns j / j + 1.
144 auto xmm1 = vec_type::template zero<T>();
145 auto xmm2 = vec_type::template zero<T>();
146 auto xmm3 = vec_type::template zero<T>();
147 auto xmm4 = vec_type::template zero<T>();
149 for (; b + vec_size - 1 < B; b += vec_size) {
150 auto l1 = L2.template loadu<vec_type>((i + 0) * B + b);
151 auto l2 = L2.template loadu<vec_type>((i + 1) * B + b);
153 auto r1 = R2.template loadu<vec_type>((j + 0) * B + b);
154 auto r2 = R2.template loadu<vec_type>((j + 1) * B + b);
156 xmm1 = vec_type::fmadd(l1, r1, xmm1);
157 xmm2 = vec_type::fmadd(l1, r2, xmm2);
159 xmm3 = vec_type::fmadd(l2, r1, xmm3);
160 xmm4 = vec_type::fmadd(l2, r2, xmm4);
163 T r1 = vec_type::hadd(xmm1);
164 T r2 = vec_type::hadd(xmm2);;
165 T r3 = vec_type::hadd(xmm3);;
166 T r4 = vec_type::hadd(xmm4);;
168 for (; b + 1 < B; b += 2) {
169 r1 += L2(b + 0, i + 0) * R2(b + 0, j + 0);
170 r1 += L2(b + 1, i + 0) * R2(b + 1, j + 0);
172 r2 += L2(b + 0, i + 0) * R2(b + 0, j + 1);
173 r2 += L2(b + 1, i + 0) * R2(b + 1, j + 1);
175 r3 += L2(b + 0, i + 1) * R2(b + 0, j + 0);
176 r3 += L2(b + 1, i + 1) * R2(b + 1, j + 0);
178 r4 += L2(b + 0, i + 1) * R2(b + 0, j + 1);
179 r4 += L2(b + 1, i + 1) * R2(b + 1, j + 1);
// Last single value of b (guard not visible in this listing).
183 r1 += L2(b, i + 0) * R2(b, j + 0);
184 r2 += L2(b, i + 0) * R2(b, j + 1);
185 r3 += L2(b, i + 1) * R2(b, j + 0);
186 r4 += L2(b, i + 1) * R2(b, j + 1);
189 result(i + 0, j + 0) = r1;
190 result(i + 0, j + 1) = r2;
192 result(i + 1, j + 0) = r3;
193 result(i + 1, j + 1) = r4;
// Last single column for the current pair of rows (2x1 tile).
// NOTE(review): the enclosing guard (presumably `if (j < N)`) is not visible.
199 auto xmm1 = vec_type::template zero<T>();
200 auto xmm2 = vec_type::template zero<T>();
202 for (; b + vec_size - 1 < B; b += vec_size) {
203 auto l1 = L2.template loadu<vec_type>((i + 0) * B + b);
204 auto l2 = L2.template loadu<vec_type>((i + 1) * B + b);
206 auto r1 = R2.template loadu<vec_type>(j * B + b);
208 xmm1 = vec_type::fmadd(l1, r1, xmm1);
209 xmm2 = vec_type::fmadd(l2, r1, xmm2);
212 T r1 = vec_type::hadd(xmm1);
213 T r2 = vec_type::hadd(xmm2);
215 for (; b + 1 < B; b += 2) {
216 r1 += L2(b + 0, i + 0) * R2(b + 0, j);
217 r1 += L2(b + 1, i + 0) * R2(b + 1, j);
219 r2 += L2(b + 0, i + 1) * R2(b + 0, j);
220 r2 += L2(b + 1, i + 1) * R2(b + 1, j);
// Last single value of b (guard not visible in this listing).
224 r1 += L2(b, i + 0) * R2(b, j);
225 r2 += L2(b, i + 1) * R2(b, j);
228 result(i + 0, j) = r1;
229 result(i + 1, j) = r2;
// Row remainder: a single row i, two columns per iteration (1x2 tile).
// NOTE(review): the guard on i that selects this path (presumably
// `if (i < last)`) is not visible in this listing.
236 for (; j + 1 < N; j += 2) {
239 auto xmm1 = vec_type::template zero<T>();
240 auto xmm2 = vec_type::template zero<T>();
242 for (; b + vec_size - 1 < B; b += vec_size) {
243 auto l1 = L2.template loadu<vec_type>(i * B + b);
245 auto r1 = R2.template loadu<vec_type>((j + 0) * B + b);
246 auto r2 = R2.template loadu<vec_type>((j + 1) * B + b);
248 xmm1 = vec_type::fmadd(l1, r1, xmm1);
249 xmm2 = vec_type::fmadd(l1, r2, xmm2);
252 T r1 = vec_type::hadd(xmm1);
253 T r2 = vec_type::hadd(xmm2);;
255 for (; b + 1 < B; b += 2) {
256 r1 += L2(b + 0, i) * R2(b + 0, j + 0);
257 r1 += L2(b + 1, i) * R2(b + 1, j + 0);
259 r2 += L2(b + 0, i) * R2(b + 0, j + 1);
260 r2 += L2(b + 1, i) * R2(b + 1, j + 1);
// Last single value of b (guard not visible in this listing).
264 r1 += L2(b, i) * R2(b, j + 0);
265 r2 += L2(b, i) * R2(b, j + 1);
268 result(i, j + 0) = r1;
269 result(i, j + 1) = r2;
// Last single row combined with the last single column (1x1).
// NOTE(review): the store of r1 into result for this case is not visible in
// this listing.
275 auto xmm1 = vec_type::template zero<T>();
277 for (; b + vec_size - 1 < B; b += vec_size) {
278 auto l1 = L2.template loadu<vec_type>(i * B + b);
279 auto r1 = R2.template loadu<vec_type>(j * B + b);
281 xmm1 = vec_type::fmadd(l1, r1, xmm1);
284 T r1 = vec_type::hadd(xmm1);
286 for (; b + 1 < B; b += 2) {
287 r1 += L2(b + 0, i) * R2(b + 0, j);
288 r1 += L2(b + 1, i) * R2(b + 1, j);
// Last single value of b (guard not visible in this listing).
292 r1 += L2(b, i) * R2(b, j);
// Presumably flags any GPU copy of result as stale now that result has been
// written through its element accessors above -- confirm.
302 result.invalidate_gpu();
// Restores the GCC optimization options saved by the earlier
// "#pragma GCC push_options", then starts the next template declaration.
// NOTE(review): the function signature (original line 316) and the `else`
// between the two branches are not visible in this listing; the name
// vec::batch_outer is taken from the error message below.
306 #pragma GCC pop_options 315 template <
typename A,
typename B,
typename C>
// Compile-time dispatch: when A, B and C are all vectorizable for vector_mode,
// forward to the vectorized implementation instantiated with default_vec.
317 if constexpr (all_vectorizable<vector_mode, A, B, C>) {
318 batch_outer_impl<default_vec>(lhs, rhs, c);
// Otherwise the call is reported as invalid through cpp_unreachable.
320 cpp_unreachable(
"Invalid call to vec::batch_outer");
void engine_dispatch_1d(Functor &&functor, size_t first, size_t last, [[maybe_unused]] size_t threshold, [[maybe_unused]] size_t n_threads=etl::threads)
Dispatch the elements of a range to a functor in a parallel manner, using the global thread engine...
Definition: parallel_support.hpp:708
void batch_outer_impl(const L &lhs, const R &rhs, C &&result)
Compute the batch outer product of lhs and rhs and store the result in result.
Definition: outer.hpp:29
Definition: bias_add.hpp:15
typename V::template vec_type< value_type > vec_type
The vectorization type for V.
Definition: dyn_matrix_view.hpp:43
batch_outer_product_expr< detail::build_type< A >, detail::build_type< B > > batch_outer(A &&a, B &&b)
Batch Outer product multiplication of two matrices.
Definition: batch_outer_product_expr.hpp:333
decltype(auto) force_temporary_opp(E &&expr)
Force a temporary out of the expression, with opposite storage order.
Definition: temporary.hpp:135
bool engine_select_parallel([[maybe_unused]] size_t n, [[maybe_unused]] size_t threshold=parallel_threshold)
Indicates whether a 1D evaluation should run in parallel.
Definition: parallel_support.hpp:679
typename decay_traits< E >::value_type value_t
Traits to extract the value type out of an ETL type.
Definition: tmp.hpp:81