Expression Templates Library (ETL)
div.hpp
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
8 #pragma once
9 
10 namespace etl {
11 
/*!
 * \brief Forward declaration of the multiplication operator.
 *
 * The pattern-matching traits below need to name mul_binary_op in order to
 * recognize scalar products such as (a * x).
 */
template <typename T>
struct mul_binary_op;

15 // detect x / (1.0 * y)
16 
17 template <typename L, typename R>
19  static constexpr bool value = false;
20 };
21 
22 template <typename T0, typename T1, typename T2, typename RightExpr, typename L>
23 struct is_axdy_right_left_impl<L, binary_expr<T0, etl::scalar<T1>, etl::mul_binary_op<T2>, RightExpr>> {
24  static constexpr bool value = true;
25 };
26 
27 // detect x / (y * 1.0)
28 
29 template <typename L, typename R>
31  static constexpr bool value = false;
32 };
33 
34 template <typename T0, typename T1, typename T2, typename RightExpr, typename L>
35 struct is_axdy_right_right_impl<L, binary_expr<T0, RightExpr, etl::mul_binary_op<T2>, etl::scalar<T1>>> {
36  static constexpr bool value = true;
37 };
38 
39 // detect (1.0 * x) / y
40 
41 template <typename L, typename R>
43  static constexpr bool value = false;
44 };
45 
46 template <typename T0, typename T1, typename T2, typename RightExpr, typename R>
47 struct is_axdy_left_left_impl<binary_expr<T0, etl::scalar<T1>, etl::mul_binary_op<T2>, RightExpr>, R> {
48  static constexpr bool value = true;
49 };
50 
51 // detect (1.0 * x) / y
52 
53 template <typename L, typename R>
55  static constexpr bool value = false;
56 };
57 
58 template <typename T0, typename T1, typename T2, typename RightExpr, typename R>
59 struct is_axdy_left_right_impl<binary_expr<T0, RightExpr, etl::mul_binary_op<T2>, etl::scalar<T1>>, R> {
60  static constexpr bool value = true;
61 };
62 
63 // detect x / (1.0 + y)
64 
65 template <typename L, typename R>
67  static constexpr bool value = false;
68 };
69 
70 template <typename T0, typename T1, typename T2, typename RightExpr, typename L>
71 struct is_axdbpy_left_impl<L, binary_expr<T0, etl::scalar<T1>, etl::plus_binary_op<T2>, RightExpr>> {
72  static constexpr bool value = true;
73 };
74 
75 // detect x / (y + 1.0)
76 
77 template <typename L, typename R>
79  static constexpr bool value = false;
80 };
81 
82 template <typename T0, typename T1, typename T2, typename RightExpr, typename L>
83 struct is_axdbpy_right_impl<L, binary_expr<T0, RightExpr, etl::plus_binary_op<T2>, etl::scalar<T1>>> {
84  static constexpr bool value = true;
85 };
86 
87 // detect (1.0 * x) / (1.0 + y)
88 
89 template <typename L, typename R>
91  static constexpr bool value = false;
92 };
93 
94 template <typename T0, typename T1, typename T2, typename T3, typename T4, typename T5, typename R1, typename R2>
95 struct is_axdbpy_left_left_impl<binary_expr<T0, etl::scalar<T1>, etl::mul_binary_op<T2>, R1>, binary_expr<T3, etl::scalar<T4>, etl::plus_binary_op<T5>, R2>> {
96  static constexpr bool value = true;
97 };
98 
99 // detect (1.0 * x) / (y + 1.0)
100 
101 template <typename L, typename R>
103  static constexpr bool value = false;
104 };
105 
106 template <typename T0, typename T1, typename T2, typename T3, typename T4, typename T5, typename R1, typename R2>
107 struct is_axdbpy_left_right_impl<binary_expr<T0, etl::scalar<T1>, etl::mul_binary_op<T2>, R1>, binary_expr<T3, R2, etl::plus_binary_op<T5>, etl::scalar<T4>>> {
108  static constexpr bool value = true;
109 };
110 
111 // detect (x * 1.0) / (1.0 + y)
112 
113 template <typename L, typename R>
115  static constexpr bool value = false;
116 };
117 
118 template <typename T0, typename T1, typename T2, typename T3, typename T4, typename T5, typename R1, typename R2>
119 struct is_axdbpy_right_left_impl<binary_expr<T0, R1, etl::mul_binary_op<T2>, etl::scalar<T1>>, binary_expr<T3, etl::scalar<T4>, etl::plus_binary_op<T5>, R2>> {
120  static constexpr bool value = true;
121 };
122 
123 // detect (x * 1.0) / (y + 1.0)
124 
125 template <typename L, typename R>
127  static constexpr bool value = false;
128 };
129 
130 template <typename T0, typename T1, typename T2, typename T3, typename T4, typename T5, typename R1, typename R2>
131 struct is_axdbpy_right_right_impl<binary_expr<T0, R1, etl::mul_binary_op<T2>, etl::scalar<T1>>, binary_expr<T3, R2, etl::plus_binary_op<T5>, etl::scalar<T4>>> {
132  static constexpr bool value = true;
133 };
134 
135 // detect (1.0 + x) / (1.0 + y)
136 
137 template <typename L, typename R>
139  static constexpr bool value = false;
140 };
141 
142 template <typename T0, typename T1, typename T2, typename T3, typename T4, typename T5, typename R1, typename R2>
143 struct is_apxdbpy_left_left_impl<binary_expr<T0, etl::scalar<T1>, etl::plus_binary_op<T2>, R1>, binary_expr<T3, etl::scalar<T4>, etl::plus_binary_op<T5>, R2>> {
144  static constexpr bool value = true;
145 };
146 
147 // detect (1.0 + x) / (y + 1.0)
148 
149 template <typename L, typename R>
151  static constexpr bool value = false;
152 };
153 
154 template <typename T0, typename T1, typename T2, typename T3, typename T4, typename T5, typename R1, typename R2>
156  binary_expr<T3, R2, etl::plus_binary_op<T5>, etl::scalar<T4>>> {
157  static constexpr bool value = true;
158 };
159 
160 // detect (x + 1.0) / (1.0 + y)
161 
162 template <typename L, typename R>
164  static constexpr bool value = false;
165 };
166 
167 template <typename T0, typename T1, typename T2, typename T3, typename T4, typename T5, typename R1, typename R2>
169  binary_expr<T3, etl::scalar<T4>, etl::plus_binary_op<T5>, R2>> {
170  static constexpr bool value = true;
171 };
172 
173 // detect (x + 1.0) / (y + 1.0)
174 
175 template <typename L, typename R>
177  static constexpr bool value = false;
178 };
179 
180 template <typename T0, typename T1, typename T2, typename T3, typename T4, typename T5, typename R1, typename R2>
182  binary_expr<T3, R2, etl::plus_binary_op<T5>, etl::scalar<T4>>> {
183  static constexpr bool value = true;
184 };
185 
186 // detect (1.0 + x) / y
187 
188 template <typename L, typename R>
190  static constexpr bool value = false;
191 };
192 
193 template <typename T3, typename T4, typename T5, typename R1, typename R2>
195  static constexpr bool value = true;
196 };
197 
198 // detect (x + 1.0) / y
199 
200 template <typename L, typename R>
202  static constexpr bool value = false;
203 };
204 
205 template <typename T3, typename T4, typename T5, typename R1, typename R2>
207  static constexpr bool value = true;
208 };
209 
210 // detect (1.0 + x) / (1.0 * y)
211 
212 template <typename L, typename R>
214  static constexpr bool value = false;
215 };
216 
217 template <typename T0, typename T1, typename T2, typename T3, typename T4, typename T5, typename R1, typename R2>
218 struct is_apxdby_left_left_impl<binary_expr<T0, etl::scalar<T1>, etl::plus_binary_op<T2>, R1>, binary_expr<T3, etl::scalar<T4>, etl::mul_binary_op<T5>, R2>> {
219  static constexpr bool value = true;
220 };
221 
222 // detect (1.0 + x) / (y * 1.0)
223 
224 template <typename L, typename R>
226  static constexpr bool value = false;
227 };
228 
229 template <typename T0, typename T1, typename T2, typename T3, typename T4, typename T5, typename R1, typename R2>
230 struct is_apxdby_left_right_impl<binary_expr<T0, etl::scalar<T1>, etl::plus_binary_op<T2>, R1>, binary_expr<T3, R2, etl::mul_binary_op<T5>, etl::scalar<T4>>> {
231  static constexpr bool value = true;
232 };
233 
234 // detect (x + 1.0) / (1.0 * y)
235 
236 template <typename L, typename R>
238  static constexpr bool value = false;
239 };
240 
241 template <typename T0, typename T1, typename T2, typename T3, typename T4, typename T5, typename R1, typename R2>
242 struct is_apxdby_right_left_impl<binary_expr<T0, R1, etl::plus_binary_op<T2>, etl::scalar<T1>>, binary_expr<T3, etl::scalar<T4>, etl::mul_binary_op<T5>, R2>> {
243  static constexpr bool value = true;
244 };
245 
246 // detect (x + 1.0) / (y * 1.0)
247 
248 template <typename L, typename R>
250  static constexpr bool value = false;
251 };
252 
253 template <typename T0, typename T1, typename T2, typename T3, typename T4, typename T5, typename R1, typename R2>
254 struct is_apxdby_right_right_impl<binary_expr<T0, R1, etl::plus_binary_op<T2>, etl::scalar<T1>>, binary_expr<T3, R2, etl::mul_binary_op<T5>, etl::scalar<T4>>> {
255  static constexpr bool value = true;
256 };
257 
258 // Variable templates helper
259 
260 template <typename L, typename R>
261 static constexpr bool is_apxdbpy_left_left = is_apxdbpy_left_left_impl<L, R>::value;
262 
263 template <typename L, typename R>
264 static constexpr bool is_apxdbpy_left_right = is_apxdbpy_left_right_impl<L, R>::value;
265 
266 template <typename L, typename R>
267 static constexpr bool is_apxdbpy_right_left = is_apxdbpy_right_left_impl<L, R>::value;
268 
269 template <typename L, typename R>
270 static constexpr bool is_apxdbpy_right_right = is_apxdbpy_right_right_impl<L, R>::value;
271 
272 template <typename L, typename R>
273 static constexpr bool is_apxdbpy = is_apxdbpy_left_left<L, R> || is_apxdbpy_left_right<L, R> || is_apxdbpy_right_left<L, R> || is_apxdbpy_right_right<L, R>;
274 
275 template <typename L, typename R>
276 static constexpr bool is_apxdby_left_left = is_apxdby_left_left_impl<L, R>::value;
277 
278 template <typename L, typename R>
279 static constexpr bool is_apxdby_left_right = is_apxdby_left_right_impl<L, R>::value;
280 
281 template <typename L, typename R>
282 static constexpr bool is_apxdby_right_left = is_apxdby_right_left_impl<L, R>::value;
283 
284 template <typename L, typename R>
285 static constexpr bool is_apxdby_right_right = is_apxdby_right_right_impl<L, R>::value;
286 
287 template <typename L, typename R>
288 static constexpr bool is_apxdby_left =
290  && !is_apxdbpy<L, R> && !is_apxdby_left_left<L, R> && !is_apxdby_left_right<L, R> && !is_apxdby_right_left<L, R> && !is_apxdby_right_right<L, R>;
291 
292 template <typename L, typename R>
293 static constexpr bool is_apxdby_right =
295  && !is_apxdbpy<L, R> && !is_apxdby_left_left<L, R> && !is_apxdby_left_right<L, R> && !is_apxdby_right_left<L, R> && !is_apxdby_right_right<L, R>;
296 
297 template <typename L, typename R>
298 static constexpr bool is_apxdby =
299  is_apxdby_left<
300  L,
301  R> || is_apxdby_right<L, R> || is_apxdby_left_left<L, R> || is_apxdby_left_right<L, R> || is_apxdby_right_left<L, R> || is_apxdby_right_right<L, R>;
302 
303 template <typename L, typename R>
304 static constexpr bool is_axdbpy_left_left = is_axdbpy_left_left_impl<L, R>::value;
305 
306 template <typename L, typename R>
307 static constexpr bool is_axdbpy_left_right = is_axdbpy_left_right_impl<L, R>::value;
308 
309 template <typename L, typename R>
310 static constexpr bool is_axdbpy_right_left = is_axdbpy_right_left_impl<L, R>::value;
311 
312 template <typename L, typename R>
313 static constexpr bool is_axdbpy_right_right = is_axdbpy_right_right_impl<L, R>::value;
314 
315 template <typename L, typename R>
316 static constexpr bool is_axdbpy_left = is_axdbpy_left_impl<L, R>::value && !is_axdbpy_left_left<L, R> && !is_axdbpy_right_left<L, R> && !is_apxdbpy<L, R>;
317 
318 template <typename L, typename R>
319 static constexpr bool is_axdbpy_right = is_axdbpy_right_impl<L, R>::value && !is_axdbpy_left_right<L, R> && !is_axdbpy_right_right<L, R> && !is_apxdbpy<L, R>;
320 
321 template <typename L, typename R>
322 static constexpr bool is_axdbpy =
323  is_axdbpy_left<
324  L,
325  R> || is_axdbpy_right<L, R> || is_axdbpy_left_left<L, R> || is_axdbpy_left_right<L, R> || is_axdbpy_right_left<L, R> || is_axdbpy_right_right<L, R>;
326 
327 template <typename L, typename R>
328 static constexpr bool is_axdy_right_left = is_axdy_right_left_impl<L, R>::value && !is_axdbpy<L, R> && !is_apxdby<L, R>;
329 
330 template <typename L, typename R>
331 static constexpr bool is_axdy_right_right = is_axdy_right_right_impl<L, R>::value && !is_axdbpy<L, R> && !is_apxdby<L, R>;
332 
333 template <typename L, typename R>
334 static constexpr bool is_axdy_left_left = is_axdy_left_left_impl<L, R>::value && !is_axdbpy<L, R> && !is_apxdby<L, R>;
335 
336 template <typename L, typename R>
337 static constexpr bool is_axdy_left_right = is_axdy_left_right_impl<L, R>::value && !is_axdbpy<L, R> && !is_apxdby<L, R>;
338 
339 template <typename L, typename R>
340 static constexpr bool is_axdy = is_axdy_right_left<L, R> || is_axdy_right_right<L, R> || is_axdy_left_left<L, R> || is_axdy_left_right<L, R>;
341 
342 template <typename L, typename R>
343 static constexpr bool is_special_div = is_axdy<L, R> || is_axdbpy<L, R> || is_apxdbpy<L, R> || is_apxdby<L, R>;
344 
348 template <typename T>
350  static constexpr bool linear = true;
351  static constexpr bool thread_safe = true;
352  static constexpr bool desc_func = false;
353 
360  template <vector_mode_t V>
361  static constexpr bool vectorizable = is_floating_t<T> || (is_complex_t<T> && V != vector_mode_t::AVX512);
362 
366  template <typename L, typename R>
367  static constexpr bool gpu_computable =
368  ((!is_scalar<L> && !is_scalar<R>)&&((is_single_precision_t<T> && impl::egblas::has_saxdy_3 && impl::egblas::has_saxdbpy_3
369  && impl::egblas::has_sapxdbpy_3 && impl::egblas::has_sapxdby_3)
370  || (is_double_precision_t<T> && impl::egblas::has_daxdy_3 && impl::egblas::has_daxdbpy_3
371  && impl::egblas::has_dapxdbpy_3 && impl::egblas::has_dapxdby_3)
372  || (is_complex_single_t<T> && impl::egblas::has_caxdy_3 && impl::egblas::has_caxdbpy_3
373  && impl::egblas::has_capxdbpy_3 && impl::egblas::has_capxdby_3)
374  || (is_complex_double_t<T> && impl::egblas::has_zaxdy_3 && impl::egblas::has_zaxdbpy_3
375  && impl::egblas::has_zapxdbpy_3 && impl::egblas::has_zapxdby_3)))
376  || ((is_scalar<L> != is_scalar<R>)&&((is_single_precision_t<T> && impl::egblas::has_scalar_smul && impl::egblas::has_scalar_sdiv)
377  || (is_double_precision_t<T> && impl::egblas::has_scalar_dmul && impl::egblas::has_scalar_ddiv)
378  || (is_complex_single_t<T> && impl::egblas::has_scalar_cmul && impl::egblas::has_scalar_cdiv)
379  || (is_complex_double_t<T> && impl::egblas::has_scalar_zmul && impl::egblas::has_scalar_zdiv)));
380 
385  static constexpr int complexity() {
386  return 1;
387  }
388 
392  template <typename V = default_vec>
393  using vec_type = typename V::template vec_type<T>;
394 
401  static constexpr T apply(const T& lhs, const T& rhs) noexcept {
402  return lhs / rhs;
403  }
404 
412  template <typename V = default_vec>
413  static vec_type<V> load(const vec_type<V>& lhs, const vec_type<V>& rhs) noexcept {
414  return V::div(lhs, rhs);
415  }
416 
425  template <typename L, typename R, typename Y>
426  static auto gpu_compute_hint(const L& lhs, const R& rhs, Y& y) noexcept {
427  auto t3 = force_temporary_gpu_dim_only(y);
428  gpu_compute(lhs, rhs, t3);
429  return t3;
430  }
431 
440  template <typename L, typename R, typename Y>
441  static Y& gpu_compute(const L& lhs, const R& rhs, Y& yy) noexcept {
442  if constexpr (is_axdy_right_left<L, R>) {
443  auto& rhs_lhs = rhs.get_lhs();
444  auto& rhs_rhs = rhs.get_rhs();
445 
446  decltype(auto) x = smart_gpu_compute_hint(lhs, yy);
447  decltype(auto) y = smart_gpu_compute_hint(rhs_rhs, yy);
448 
449  constexpr auto incx = gpu_inc<decltype(lhs)>;
450  constexpr auto incy = gpu_inc<decltype(rhs_rhs)>;
451 
452  impl::egblas::axdy_3(etl::size(yy), rhs_lhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
453  } else if constexpr (is_axdy_right_right<L, R>) {
454  auto& rhs_lhs = rhs.get_lhs();
455  auto& rhs_rhs = rhs.get_rhs();
456 
457  decltype(auto) x = smart_gpu_compute_hint(lhs, yy);
458  decltype(auto) y = smart_gpu_compute_hint(rhs_lhs, yy);
459 
460  constexpr auto incx = gpu_inc<decltype(lhs)>;
461  constexpr auto incy = gpu_inc<decltype(rhs_lhs)>;
462 
463  impl::egblas::axdy_3(etl::size(yy), rhs_rhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
464  } else if constexpr (is_axdy_left_left<L, R>) {
465  auto& lhs_lhs = lhs.get_lhs();
466  auto& lhs_rhs = lhs.get_rhs();
467 
468  decltype(auto) x = smart_gpu_compute_hint(lhs_rhs, yy);
469  decltype(auto) y = smart_gpu_compute_hint(rhs, yy);
470 
471  constexpr auto incx = gpu_inc<decltype(lhs_rhs)>;
472  constexpr auto incy = gpu_inc<decltype(rhs)>;
473 
474  impl::egblas::axdy_3(etl::size(yy), T(1) / lhs_lhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
475  } else if constexpr (is_axdy_left_right<L, R>) {
476  auto& lhs_lhs = lhs.get_lhs();
477  auto& lhs_rhs = lhs.get_rhs();
478 
479  decltype(auto) x = smart_gpu_compute_hint(lhs_lhs, yy);
480  decltype(auto) y = smart_gpu_compute_hint(rhs, yy);
481 
482  constexpr auto incx = gpu_inc<decltype(lhs_lhs)>;
483  constexpr auto incy = gpu_inc<decltype(rhs)>;
484 
485  impl::egblas::axdy_3(etl::size(yy), T(1) / lhs_rhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
486  } else if constexpr (is_axdbpy_left<L, R>) {
487  auto& rhs_lhs = rhs.get_lhs();
488  auto& rhs_rhs = rhs.get_rhs();
489 
490  decltype(auto) x = smart_gpu_compute_hint(lhs, yy);
491  decltype(auto) y = smart_gpu_compute_hint(rhs_rhs, yy);
492 
493  constexpr auto incx = gpu_inc<decltype(lhs)>;
494  constexpr auto incy = gpu_inc<decltype(rhs_rhs)>;
495 
496  impl::egblas::axdbpy_3(etl::size(yy), T(1), x.gpu_memory(), incx, rhs_lhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
497  } else if constexpr (is_axdbpy_right<L, R>) {
498  auto& rhs_lhs = rhs.get_lhs();
499  auto& rhs_rhs = rhs.get_rhs();
500 
501  decltype(auto) x = smart_gpu_compute_hint(lhs, yy);
502  decltype(auto) y = smart_gpu_compute_hint(rhs_lhs, yy);
503 
504  constexpr auto incx = gpu_inc<decltype(lhs)>;
505  constexpr auto incy = gpu_inc<decltype(rhs_lhs)>;
506 
507  impl::egblas::axdbpy_3(etl::size(yy), T(1), x.gpu_memory(), incx, rhs_rhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
508  } else if constexpr (is_axdbpy_left_left<L, R>) {
509  auto& lhs_lhs = lhs.get_lhs();
510  auto& lhs_rhs = lhs.get_rhs();
511 
512  auto& rhs_lhs = rhs.get_lhs();
513  auto& rhs_rhs = rhs.get_rhs();
514 
515  decltype(auto) x = smart_gpu_compute_hint(lhs_rhs, yy);
516  decltype(auto) y = smart_gpu_compute_hint(rhs_rhs, yy);
517 
518  constexpr auto incx = gpu_inc<decltype(lhs_rhs)>;
519  constexpr auto incy = gpu_inc<decltype(rhs_rhs)>;
520 
521  impl::egblas::axdbpy_3(etl::size(yy), lhs_lhs.value, x.gpu_memory(), incx, rhs_lhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
522  } else if constexpr (is_axdbpy_left_right<L, R>) {
523  auto& lhs_lhs = lhs.get_lhs();
524  auto& lhs_rhs = lhs.get_rhs();
525 
526  auto& rhs_lhs = rhs.get_lhs();
527  auto& rhs_rhs = rhs.get_rhs();
528 
529  decltype(auto) x = smart_gpu_compute_hint(lhs_rhs, yy);
530  decltype(auto) y = smart_gpu_compute_hint(rhs_lhs, yy);
531 
532  constexpr auto incx = gpu_inc<decltype(lhs_rhs)>;
533  constexpr auto incy = gpu_inc<decltype(rhs_lhs)>;
534 
535  impl::egblas::axdbpy_3(etl::size(yy), lhs_lhs.value, x.gpu_memory(), incx, rhs_rhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
536  } else if constexpr (is_axdbpy_right_left<L, R>) {
537  auto& lhs_lhs = lhs.get_lhs();
538  auto& lhs_rhs = lhs.get_rhs();
539 
540  auto& rhs_lhs = rhs.get_lhs();
541  auto& rhs_rhs = rhs.get_rhs();
542 
543  decltype(auto) x = smart_gpu_compute_hint(lhs_lhs, yy);
544  decltype(auto) y = smart_gpu_compute_hint(rhs_rhs, yy);
545 
546  constexpr auto incx = gpu_inc<decltype(lhs_lhs)>;
547  constexpr auto incy = gpu_inc<decltype(rhs_rhs)>;
548 
549  impl::egblas::axdbpy_3(etl::size(yy), lhs_rhs.value, x.gpu_memory(), incx, rhs_lhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
550  } else if constexpr (is_axdbpy_right_right<L, R>) {
551  auto& lhs_lhs = lhs.get_lhs();
552  auto& lhs_rhs = lhs.get_rhs();
553 
554  auto& rhs_lhs = rhs.get_lhs();
555  auto& rhs_rhs = rhs.get_rhs();
556 
557  decltype(auto) x = smart_gpu_compute_hint(lhs_lhs, yy);
558  decltype(auto) y = smart_gpu_compute_hint(rhs_lhs, yy);
559 
560  constexpr auto incx = gpu_inc<decltype(lhs_lhs)>;
561  constexpr auto incy = gpu_inc<decltype(rhs_lhs)>;
562 
563  impl::egblas::axdbpy_3(etl::size(yy), lhs_rhs.value, x.gpu_memory(), incx, rhs_rhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
564  } else if constexpr (is_apxdbpy_left_left<L, R>) {
565  auto& lhs_lhs = lhs.get_lhs();
566  auto& lhs_rhs = lhs.get_rhs();
567 
568  auto& rhs_lhs = rhs.get_lhs();
569  auto& rhs_rhs = rhs.get_rhs();
570 
571  decltype(auto) x = smart_gpu_compute_hint(lhs_rhs, yy);
572  decltype(auto) y = smart_gpu_compute_hint(rhs_rhs, yy);
573 
574  constexpr auto incx = gpu_inc<decltype(lhs_rhs)>;
575  constexpr auto incy = gpu_inc<decltype(rhs_rhs)>;
576 
577  impl::egblas::apxdbpy_3(etl::size(yy), lhs_lhs.value, x.gpu_memory(), incx, rhs_lhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
578  } else if constexpr (is_apxdbpy_left_right<L, R>) {
579  auto& lhs_lhs = lhs.get_lhs();
580  auto& lhs_rhs = lhs.get_rhs();
581 
582  auto& rhs_lhs = rhs.get_lhs();
583  auto& rhs_rhs = rhs.get_rhs();
584 
585  decltype(auto) x = smart_gpu_compute_hint(lhs_rhs, yy);
586  decltype(auto) y = smart_gpu_compute_hint(rhs_lhs, yy);
587 
588  constexpr auto incx = gpu_inc<decltype(lhs_rhs)>;
589  constexpr auto incy = gpu_inc<decltype(rhs_lhs)>;
590 
591  impl::egblas::apxdbpy_3(etl::size(yy), lhs_lhs.value, x.gpu_memory(), incx, rhs_rhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
592  } else if constexpr (is_apxdbpy_right_left<L, R>) {
593  auto& lhs_lhs = lhs.get_lhs();
594  auto& lhs_rhs = lhs.get_rhs();
595 
596  auto& rhs_lhs = rhs.get_lhs();
597  auto& rhs_rhs = rhs.get_rhs();
598 
599  decltype(auto) x = smart_gpu_compute_hint(lhs_lhs, yy);
600  decltype(auto) y = smart_gpu_compute_hint(rhs_rhs, yy);
601 
602  constexpr auto incx = gpu_inc<decltype(lhs_lhs)>;
603  constexpr auto incy = gpu_inc<decltype(rhs_rhs)>;
604 
605  impl::egblas::apxdbpy_3(etl::size(yy), lhs_rhs.value, x.gpu_memory(), incx, rhs_lhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
606  } else if constexpr (is_apxdbpy_right_right<L, R>) {
607  auto& lhs_lhs = lhs.get_lhs();
608  auto& lhs_rhs = lhs.get_rhs();
609 
610  auto& rhs_lhs = rhs.get_lhs();
611  auto& rhs_rhs = rhs.get_rhs();
612 
613  decltype(auto) x = smart_gpu_compute_hint(lhs_lhs, yy);
614  decltype(auto) y = smart_gpu_compute_hint(rhs_lhs, yy);
615 
616  constexpr auto incx = gpu_inc<decltype(lhs_lhs)>;
617  constexpr auto incy = gpu_inc<decltype(rhs_lhs)>;
618 
619  impl::egblas::apxdbpy_3(etl::size(yy), lhs_rhs.value, x.gpu_memory(), incx, rhs_rhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
620  } else if constexpr (is_apxdby_left<L, R>) {
621  auto& lhs_lhs = lhs.get_lhs();
622  auto& lhs_rhs = lhs.get_rhs();
623 
624  decltype(auto) x = smart_gpu_compute_hint(lhs_rhs, yy);
625  decltype(auto) y = smart_gpu_compute_hint(rhs, yy);
626 
627  constexpr auto incx = gpu_inc<decltype(lhs_rhs)>;
628  constexpr auto incy = gpu_inc<decltype(rhs)>;
629 
630  impl::egblas::apxdby_3(etl::size(yy), lhs_lhs.value, x.gpu_memory(), incx, T(1), y.gpu_memory(), incy, yy.gpu_memory(), 1);
631  } else if constexpr (is_apxdby_right<L, R>) {
632  auto& lhs_lhs = lhs.get_lhs();
633  auto& lhs_rhs = lhs.get_rhs();
634 
635  decltype(auto) x = smart_gpu_compute_hint(lhs_lhs, yy);
636  decltype(auto) y = smart_gpu_compute_hint(rhs, yy);
637 
638  constexpr auto incx = gpu_inc<decltype(lhs_lhs)>;
639  constexpr auto incy = gpu_inc<decltype(rhs)>;
640 
641  impl::egblas::apxdby_3(etl::size(yy), lhs_rhs.value, x.gpu_memory(), incx, T(1), y.gpu_memory(), incy, yy.gpu_memory(), 1);
642  } else if constexpr (is_apxdby_left_left<L, R>) {
643  auto& lhs_lhs = lhs.get_lhs();
644  auto& lhs_rhs = lhs.get_rhs();
645 
646  auto& rhs_lhs = rhs.get_lhs();
647  auto& rhs_rhs = rhs.get_rhs();
648 
649  decltype(auto) x = smart_gpu_compute_hint(lhs_rhs, yy);
650  decltype(auto) y = smart_gpu_compute_hint(rhs_rhs, yy);
651 
652  constexpr auto incx = gpu_inc<decltype(lhs_rhs)>;
653  constexpr auto incy = gpu_inc<decltype(rhs_rhs)>;
654 
655  impl::egblas::apxdby_3(etl::size(yy), lhs_lhs.value, x.gpu_memory(), incx, rhs_lhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
656  } else if constexpr (is_apxdby_left_right<L, R>) {
657  auto& lhs_lhs = lhs.get_lhs();
658  auto& lhs_rhs = lhs.get_rhs();
659 
660  auto& rhs_lhs = rhs.get_lhs();
661  auto& rhs_rhs = rhs.get_rhs();
662 
663  decltype(auto) x = smart_gpu_compute_hint(lhs_rhs, yy);
664  decltype(auto) y = smart_gpu_compute_hint(rhs_lhs, yy);
665 
666  constexpr auto incx = gpu_inc<decltype(lhs_rhs)>;
667  constexpr auto incy = gpu_inc<decltype(rhs_lhs)>;
668 
669  impl::egblas::apxdby_3(etl::size(yy), lhs_lhs.value, x.gpu_memory(), incx, rhs_rhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
670  } else if constexpr (is_apxdby_right_left<L, R>) {
671  auto& lhs_lhs = lhs.get_lhs();
672  auto& lhs_rhs = lhs.get_rhs();
673 
674  auto& rhs_lhs = rhs.get_lhs();
675  auto& rhs_rhs = rhs.get_rhs();
676 
677  decltype(auto) x = smart_gpu_compute_hint(lhs_lhs, yy);
678  decltype(auto) y = smart_gpu_compute_hint(rhs_rhs, yy);
679 
680  constexpr auto incx = gpu_inc<decltype(lhs_lhs)>;
681  constexpr auto incy = gpu_inc<decltype(rhs_rhs)>;
682 
683  impl::egblas::apxdby_3(etl::size(yy), lhs_rhs.value, x.gpu_memory(), incx, rhs_lhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
684  } else if constexpr (is_apxdby_right_right<L, R>) {
685  auto& lhs_lhs = lhs.get_lhs();
686  auto& lhs_rhs = lhs.get_rhs();
687 
688  auto& rhs_lhs = rhs.get_lhs();
689  auto& rhs_rhs = rhs.get_rhs();
690 
691  decltype(auto) x = smart_gpu_compute_hint(lhs_lhs, yy);
692  decltype(auto) y = smart_gpu_compute_hint(rhs_lhs, yy);
693 
694  constexpr auto incx = gpu_inc<decltype(lhs_lhs)>;
695  constexpr auto incy = gpu_inc<decltype(rhs_lhs)>;
696 
697  impl::egblas::apxdby_3(etl::size(yy), lhs_rhs.value, x.gpu_memory(), incx, rhs_rhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
698  } else if constexpr (!is_scalar<L> && !is_scalar<R> && !is_special_div<L, R>) {
699  decltype(auto) x = smart_gpu_compute_hint(lhs, yy);
700  decltype(auto) y = smart_gpu_compute_hint(rhs, yy);
701 
702  constexpr auto incx = gpu_inc<decltype(lhs)>;
703  constexpr auto incy = gpu_inc<decltype(rhs)>;
704 
705  value_t<L> alpha(1);
706  impl::egblas::axdy_3(etl::size(yy), alpha, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
707  } else if constexpr (!is_scalar<L> && is_scalar<R> && !is_special_div<L, R>) {
708  auto s = T(1) / rhs.value;
709 
710  smart_gpu_compute(lhs, yy);
711 
712  impl::egblas::scalar_mul(yy.gpu_memory(), etl::size(yy), 1, s);
713  } else if constexpr (is_scalar<L> && !is_scalar<R> && !is_special_div<L, R>) {
714  auto s = lhs.value;
715 
716  smart_gpu_compute(rhs, yy);
717 
718  impl::egblas::scalar_div(s, yy.gpu_memory(), etl::size(yy), 1);
719  }
720 
721  yy.validate_gpu();
722  yy.invalidate_cpu();
723 
724  return yy;
725  }
726 
731  static std::string desc() noexcept {
732  return "/";
733  }
734 };
735 
736 } //end of namespace etl
Definition: div.hpp:126
auto s(T &&value)
Force the evaluation of the given expression.
Definition: stop.hpp:18
Definition: div.hpp:102
typename V::template vec_type< T > vec_type
Definition: div.hpp:393
static auto gpu_compute_hint(const L &lhs, const R &rhs, Y &y) noexcept
Compute the result of the operation using the GPU.
Definition: div.hpp:426
static std::string desc() noexcept
Returns a textual representation of the operator.
Definition: div.hpp:731
Definition: div.hpp:30
static Y & gpu_compute(const L &lhs, const R &rhs, Y &yy) noexcept
Compute the result of the operation using the GPU.
Definition: div.hpp:441
Definition: div.hpp:42
Binary operator for scalar division.
Definition: div.hpp:349
A binary expression.
Definition: binary_expr.hpp:18
Root namespace for the ETL library.
Definition: adapter.hpp:15
Definition: div.hpp:176
Definition: div.hpp:150
Definition: div.hpp:138
decltype(auto) force_temporary_gpu_dim_only(E &&expr)
Force a temporary out of the expression, without copying its content.
Definition: temporary.hpp:223
Represents a scalar value.
Definition: concepts_base.hpp:19
constexpr size_t size(const E &expr) noexcept
Returns the size of the given ETL expression.
Definition: helpers.hpp:108
Definition: div.hpp:163
Definition: div.hpp:249
Binary operator for scalar multiplication.
Definition: div.hpp:13
Definition: div.hpp:225
static vec_type< V > load(const vec_type< V > &lhs, const vec_type< V > &rhs) noexcept
Compute several applications of the operator at a time.
Definition: div.hpp:413
Definition: div.hpp:237
Definition: div.hpp:78
Definition: div.hpp:54
Definition: div.hpp:213
Definition: div.hpp:18
Definition: div.hpp:90
Definition: div.hpp:189
decltype(auto) smart_gpu_compute_hint(E &expr, Y &y)
Compute the expression into a representation that is GPU up to date.
Definition: helpers.hpp:368
static constexpr T apply(const T &lhs, const T &rhs) noexcept
Apply the binary operator on lhs and rhs.
Definition: div.hpp:401
AVX-512F is the max vectorization available.
typename decay_traits< E >::value_type value_t
Traits to extract the value type out of an ETL type.
Definition: tmp.hpp:81
decltype(auto) smart_gpu_compute(X &x, Y &y)
Compute the expression into a representation that is GPU up to date and store this representation in ...
Definition: helpers.hpp:397
Definition: div.hpp:114
static constexpr int complexity()
Estimate the complexity of operator.
Definition: div.hpp:385
Definition: div.hpp:66
Binary operator for scalar addition.
Definition: plus.hpp:154
Definition: div.hpp:201