Expression Templates Library (ETL)
outer.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
13 #pragma once
14 
15 namespace etl::impl::vec {
16 
17 #ifndef __clang__
18 #pragma GCC push_options
19 #pragma GCC optimize ("-fno-aggressive-loop-optimizations")
20 #endif
21 
28 template <typename V, typename L, typename R, typename C>
29 void batch_outer_impl(const L& lhs, const R& rhs, C&& result) {
30  using vec_type = V;
31  using T = value_t<L>;
32 
33  static constexpr size_t vec_size = vec_type::template traits<T>::size;
34 
35  const auto B = etl::dim<0>(lhs);
36  const auto M = etl::dim<0>(result);
37  const auto N = etl::dim<1>(result);
38 
39  lhs.ensure_cpu_up_to_date();
40  rhs.ensure_cpu_up_to_date();
41 
42  auto L2 = force_temporary_opp(lhs);
43  auto R2 = force_temporary_opp(rhs);
44 
45  auto batch_fun_m = [&](const size_t first, const size_t last) {
46  size_t i = first;
47 
48  for (; i + 1 < last; i += 2) {
49  size_t j = 0;
50 
51  for (; j + 3 < N; j += 4) {
52  size_t b = 0;
53 
54  auto xmm1 = vec_type::template zero<T>();
55  auto xmm2 = vec_type::template zero<T>();
56  auto xmm3 = vec_type::template zero<T>();
57  auto xmm4 = vec_type::template zero<T>();
58  auto xmm5 = vec_type::template zero<T>();
59  auto xmm6 = vec_type::template zero<T>();
60  auto xmm7 = vec_type::template zero<T>();
61  auto xmm8 = vec_type::template zero<T>();
62 
63  for (; b + vec_size - 1 < B; b += vec_size) {
64  auto l1 = L2.template loadu<vec_type>((i + 0) * B + b);
65  auto l2 = L2.template loadu<vec_type>((i + 1) * B + b);
66 
67  auto r1 = R2.template loadu<vec_type>((j + 0) * B + b);
68  auto r2 = R2.template loadu<vec_type>((j + 1) * B + b);
69  auto r3 = R2.template loadu<vec_type>((j + 2) * B + b);
70  auto r4 = R2.template loadu<vec_type>((j + 3) * B + b);
71 
72  xmm1 = vec_type::fmadd(l1, r1, xmm1);
73  xmm2 = vec_type::fmadd(l1, r2, xmm2);
74  xmm3 = vec_type::fmadd(l1, r3, xmm3);
75  xmm4 = vec_type::fmadd(l1, r4, xmm4);
76 
77  xmm5 = vec_type::fmadd(l2, r1, xmm5);
78  xmm6 = vec_type::fmadd(l2, r2, xmm6);
79  xmm7 = vec_type::fmadd(l2, r3, xmm7);
80  xmm8 = vec_type::fmadd(l2, r4, xmm8);
81  }
82 
83  T r1 = vec_type::hadd(xmm1);
84  T r2 = vec_type::hadd(xmm2);;
85  T r3 = vec_type::hadd(xmm3);;
86  T r4 = vec_type::hadd(xmm4);;
87  T r5 = vec_type::hadd(xmm5);
88  T r6 = vec_type::hadd(xmm6);;
89  T r7 = vec_type::hadd(xmm7);;
90  T r8 = vec_type::hadd(xmm8);;
91 
92  for (; b + 1 < B; b += 2) {
93  r1 += L2(b + 0, i + 0) * R2(b + 0, j + 0);
94  r1 += L2(b + 1, i + 0) * R2(b + 1, j + 0);
95 
96  r2 += L2(b + 0, i + 0) * R2(b + 0, j + 1);
97  r2 += L2(b + 1, i + 0) * R2(b + 1, j + 1);
98 
99  r3 += L2(b + 0, i + 0) * R2(b + 0, j + 2);
100  r3 += L2(b + 1, i + 0) * R2(b + 1, j + 2);
101 
102  r4 += L2(b + 0, i + 0) * R2(b + 0, j + 3);
103  r4 += L2(b + 1, i + 0) * R2(b + 1, j + 3);
104 
105  r5 += L2(b + 0, i + 1) * R2(b + 0, j + 0);
106  r5 += L2(b + 1, i + 1) * R2(b + 1, j + 0);
107 
108  r6 += L2(b + 0, i + 1) * R2(b + 0, j + 1);
109  r6 += L2(b + 1, i + 1) * R2(b + 1, j + 1);
110 
111  r7 += L2(b + 0, i + 1) * R2(b + 0, j + 2);
112  r7 += L2(b + 1, i + 1) * R2(b + 1, j + 2);
113 
114  r8 += L2(b + 0, i + 1) * R2(b + 0, j + 3);
115  r8 += L2(b + 1, i + 1) * R2(b + 1, j + 3);
116  }
117 
118  if (b < B) {
119  r1 += L2(b, i + 0) * R2(b, j + 0);
120  r2 += L2(b, i + 0) * R2(b, j + 1);
121  r3 += L2(b, i + 0) * R2(b, j + 2);
122  r4 += L2(b, i + 0) * R2(b, j + 3);
123 
124  r5 += L2(b, i + 1) * R2(b, j + 0);
125  r6 += L2(b, i + 1) * R2(b, j + 1);
126  r7 += L2(b, i + 1) * R2(b, j + 2);
127  r8 += L2(b, i + 1) * R2(b, j + 3);
128  }
129 
130  result(i + 0, j + 0) = r1;
131  result(i + 0, j + 1) = r2;
132  result(i + 0, j + 2) = r3;
133  result(i + 0, j + 3) = r4;
134 
135  result(i + 1, j + 0) = r5;
136  result(i + 1, j + 1) = r6;
137  result(i + 1, j + 2) = r7;
138  result(i + 1, j + 3) = r8;
139  }
140 
141  for (; j + 1 < N; j += 2) {
142  size_t b = 0;
143 
144  auto xmm1 = vec_type::template zero<T>();
145  auto xmm2 = vec_type::template zero<T>();
146  auto xmm3 = vec_type::template zero<T>();
147  auto xmm4 = vec_type::template zero<T>();
148 
149  for (; b + vec_size - 1 < B; b += vec_size) {
150  auto l1 = L2.template loadu<vec_type>((i + 0) * B + b);
151  auto l2 = L2.template loadu<vec_type>((i + 1) * B + b);
152 
153  auto r1 = R2.template loadu<vec_type>((j + 0) * B + b);
154  auto r2 = R2.template loadu<vec_type>((j + 1) * B + b);
155 
156  xmm1 = vec_type::fmadd(l1, r1, xmm1);
157  xmm2 = vec_type::fmadd(l1, r2, xmm2);
158 
159  xmm3 = vec_type::fmadd(l2, r1, xmm3);
160  xmm4 = vec_type::fmadd(l2, r2, xmm4);
161  }
162 
163  T r1 = vec_type::hadd(xmm1);
164  T r2 = vec_type::hadd(xmm2);;
165  T r3 = vec_type::hadd(xmm3);;
166  T r4 = vec_type::hadd(xmm4);;
167 
168  for (; b + 1 < B; b += 2) {
169  r1 += L2(b + 0, i + 0) * R2(b + 0, j + 0);
170  r1 += L2(b + 1, i + 0) * R2(b + 1, j + 0);
171 
172  r2 += L2(b + 0, i + 0) * R2(b + 0, j + 1);
173  r2 += L2(b + 1, i + 0) * R2(b + 1, j + 1);
174 
175  r3 += L2(b + 0, i + 1) * R2(b + 0, j + 0);
176  r3 += L2(b + 1, i + 1) * R2(b + 1, j + 0);
177 
178  r4 += L2(b + 0, i + 1) * R2(b + 0, j + 1);
179  r4 += L2(b + 1, i + 1) * R2(b + 1, j + 1);
180  }
181 
182  if (b < B) {
183  r1 += L2(b, i + 0) * R2(b, j + 0);
184  r2 += L2(b, i + 0) * R2(b, j + 1);
185  r3 += L2(b, i + 1) * R2(b, j + 0);
186  r4 += L2(b, i + 1) * R2(b, j + 1);
187  }
188 
189  result(i + 0, j + 0) = r1;
190  result(i + 0, j + 1) = r2;
191 
192  result(i + 1, j + 0) = r3;
193  result(i + 1, j + 1) = r4;
194  }
195 
196  if (j < N) {
197  size_t b = 0;
198 
199  auto xmm1 = vec_type::template zero<T>();
200  auto xmm2 = vec_type::template zero<T>();
201 
202  for (; b + vec_size - 1 < B; b += vec_size) {
203  auto l1 = L2.template loadu<vec_type>((i + 0) * B + b);
204  auto l2 = L2.template loadu<vec_type>((i + 1) * B + b);
205 
206  auto r1 = R2.template loadu<vec_type>(j * B + b);
207 
208  xmm1 = vec_type::fmadd(l1, r1, xmm1);
209  xmm2 = vec_type::fmadd(l2, r1, xmm2);
210  }
211 
212  T r1 = vec_type::hadd(xmm1);
213  T r2 = vec_type::hadd(xmm2);
214 
215  for (; b + 1 < B; b += 2) {
216  r1 += L2(b + 0, i + 0) * R2(b + 0, j);
217  r1 += L2(b + 1, i + 0) * R2(b + 1, j);
218 
219  r2 += L2(b + 0, i + 1) * R2(b + 0, j);
220  r2 += L2(b + 1, i + 1) * R2(b + 1, j);
221  }
222 
223  if (b < B) {
224  r1 += L2(b, i + 0) * R2(b, j);
225  r2 += L2(b, i + 1) * R2(b, j);
226  }
227 
228  result(i + 0, j) = r1;
229  result(i + 1, j) = r2;
230  }
231  }
232 
233  if (i < last) {
234  size_t j = 0;
235 
236  for (; j + 1 < N; j += 2) {
237  size_t b = 0;
238 
239  auto xmm1 = vec_type::template zero<T>();
240  auto xmm2 = vec_type::template zero<T>();
241 
242  for (; b + vec_size - 1 < B; b += vec_size) {
243  auto l1 = L2.template loadu<vec_type>(i * B + b);
244 
245  auto r1 = R2.template loadu<vec_type>((j + 0) * B + b);
246  auto r2 = R2.template loadu<vec_type>((j + 1) * B + b);
247 
248  xmm1 = vec_type::fmadd(l1, r1, xmm1);
249  xmm2 = vec_type::fmadd(l1, r2, xmm2);
250  }
251 
252  T r1 = vec_type::hadd(xmm1);
253  T r2 = vec_type::hadd(xmm2);;
254 
255  for (; b + 1 < B; b += 2) {
256  r1 += L2(b + 0, i) * R2(b + 0, j + 0);
257  r1 += L2(b + 1, i) * R2(b + 1, j + 0);
258 
259  r2 += L2(b + 0, i) * R2(b + 0, j + 1);
260  r2 += L2(b + 1, i) * R2(b + 1, j + 1);
261  }
262 
263  if (b < B) {
264  r1 += L2(b, i) * R2(b, j + 0);
265  r2 += L2(b, i) * R2(b, j + 1);
266  }
267 
268  result(i, j + 0) = r1;
269  result(i, j + 1) = r2;
270  }
271 
272  for (; j < N; ++j) {
273  size_t b = 0;
274 
275  auto xmm1 = vec_type::template zero<T>();
276 
277  for (; b + vec_size - 1 < B; b += vec_size) {
278  auto l1 = L2.template loadu<vec_type>(i * B + b);
279  auto r1 = R2.template loadu<vec_type>(j * B + b);
280 
281  xmm1 = vec_type::fmadd(l1, r1, xmm1);
282  }
283 
284  T r1 = vec_type::hadd(xmm1);
285 
286  for (; b + 1 < B; b += 2) {
287  r1 += L2(b + 0, i) * R2(b + 0, j);
288  r1 += L2(b + 1, i) * R2(b + 1, j);
289  }
290 
291  if (b < B) {
292  r1 += L2(b, i) * R2(b, j);
293  }
294 
295  result(i, j) = r1;
296  }
297  }
298  };
299 
300  engine_dispatch_1d(batch_fun_m, 0, M, engine_select_parallel(M, 2) && N > 20);
301 
302  result.invalidate_gpu();
303 }
304 
305 #ifndef __clang__
306 #pragma GCC pop_options
307 #endif
308 
315 template <typename A, typename B, typename C>
316 void batch_outer(const A& lhs, const B& rhs, C&& c) {
317  if constexpr (all_vectorizable<vector_mode, A, B, C>) {
318  batch_outer_impl<default_vec>(lhs, rhs, c);
319  } else {
320  cpp_unreachable("Invalid call to vec::batch_outer");
321  }
322 }
323 
324 } //end of namespace etl::impl::vec
void engine_dispatch_1d(Functor &&functor, size_t first, size_t last, [[maybe_unused]] size_t threshold, [[maybe_unused]] size_t n_threads=etl::threads)
Dispatch the elements of a range to a functor in a parallel manner, using the global thread engine...
Definition: parallel_support.hpp:708
void batch_outer_impl(const L &lhs, const R &rhs, C &&result)
Compute the batch outer product of a and b and store the result in c.
Definition: outer.hpp:29
Definition: bias_add.hpp:15
typename V::template vec_type< value_type > vec_type
The vectorization type for V.
Definition: dyn_matrix_view.hpp:43
batch_outer_product_expr< detail::build_type< A >, detail::build_type< B > > batch_outer(A &&a, B &&b)
Batch outer product of two matrices.
Definition: batch_outer_product_expr.hpp:333
decltype(auto) force_temporary_opp(E &&expr)
Force a temporary out of the expression, with opposite storage order.
Definition: temporary.hpp:135
bool engine_select_parallel([[maybe_unused]] size_t n, [[maybe_unused]] size_t threshold=parallel_threshold)
Indicates if a 1D evaluation should run in parallel.
Definition: parallel_support.hpp:679
typename decay_traits< E >::value_type value_t
Traits to extract the value type out of an ETL type.
Definition: tmp.hpp:81