Expression Templates Library (ETL)
plus.hpp
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
8 #pragma once
9 
10 namespace etl {
11 
// Forward declaration only: the addition traits of this header pattern-match on
// products by naming this operator type, without needing its definition.
template <class T>
struct mul_binary_op;
14 
15 // detect 1.0 * x + y
16 
17 template <typename L, typename R>
19  static constexpr bool value = false;
20 };
21 
22 template <typename T0, typename T1, typename T2, typename RightExpr, typename R>
23 struct is_axpy_left_left_impl<binary_expr<T0, etl::scalar<T1>, etl::mul_binary_op<T2>, RightExpr>, R> {
24  static constexpr bool value = !is_scalar<R>;
25 };
26 
27 // detect x * 1.0 + y
28 
29 template <typename L, typename R>
31  static constexpr bool value = false;
32 };
33 
34 template <typename T0, typename T1, typename T2, typename LeftExpr, typename R>
35 struct is_axpy_left_right_impl<binary_expr<T0, LeftExpr, etl::mul_binary_op<T2>, etl::scalar<T1>>, R> {
36  static constexpr bool value = !is_scalar<R>;
37 };
38 
39 // detect x + 1.0 * y
40 
41 template <typename L, typename R>
43  static constexpr bool value = false;
44 };
45 
46 template <typename T0, typename T1, typename T2, typename RightExpr, typename L>
47 struct is_axpy_right_left_impl<L, binary_expr<T0, etl::scalar<T1>, etl::mul_binary_op<T2>, RightExpr>> {
48  static constexpr bool value = !is_scalar<L> && !is_scalar<RightExpr>;
49 };
50 
51 // detect x + y * 1.0
52 
53 template <typename L, typename R>
55  static constexpr bool value = false;
56 };
57 
58 template <typename T0, typename T1, typename T2, typename LeftExpr, typename L>
59 struct is_axpy_right_right_impl<L, binary_expr<T0, LeftExpr, etl::mul_binary_op<T2>, etl::scalar<T1>>> {
60  static constexpr bool value = !is_scalar<LeftExpr>;
61 };
62 
63 // detect 1.0 * x + 1.0 * y
64 
65 template <typename L, typename R>
67  static constexpr bool value = false;
68 };
69 
70 template <typename LT1, typename LT2, typename LT3, typename LRightExpr, typename RT1, typename RT2, typename RT3, typename RRightExpr>
71 struct is_axpby_left_left_impl<binary_expr<LT1, etl::scalar<LT2>, etl::mul_binary_op<LT3>, LRightExpr>,
72  binary_expr<RT1, etl::scalar<RT2>, etl::mul_binary_op<RT3>, RRightExpr>> {
73  static constexpr bool value = true;
74 };
75 
76 // detect 1.0 * x + y * 1.0
77 
78 template <typename L, typename R>
80  static constexpr bool value = false;
81 };
82 
83 template <typename LT1, typename LT2, typename LT3, typename LRightExpr, typename RT1, typename RT2, typename RT3, typename RLeftExpr>
84 struct is_axpby_left_right_impl<binary_expr<LT1, etl::scalar<LT2>, etl::mul_binary_op<LT3>, LRightExpr>,
85  binary_expr<RT1, RLeftExpr, etl::mul_binary_op<RT3>, etl::scalar<RT2>>> {
86  static constexpr bool value = true;
87 };
88 
89 // detect x * 1.0 + 1.0 * y
90 
91 template <typename L, typename R>
93  static constexpr bool value = false;
94 };
95 
96 template <typename LT1, typename LT2, typename LT3, typename LLeftExpr, typename RT1, typename RT2, typename RT3, typename RRightExpr>
98  binary_expr<RT1, etl::scalar<RT2>, etl::mul_binary_op<RT3>, RRightExpr>> {
99  static constexpr bool value = true;
100 };
101 
102 // detect x * 1.0 + y * 1.0
103 
104 template <typename L, typename R>
106  static constexpr bool value = false;
107 };
108 
109 template <typename LT1, typename LT2, typename LT3, typename LLeftExpr, typename RT1, typename RT2, typename RT3, typename RLeftExpr>
111  binary_expr<RT1, RLeftExpr, etl::mul_binary_op<RT3>, etl::scalar<RT2>>> {
112  static constexpr bool value = true;
113 };
114 
115 // Variable templates helper
116 
117 template <typename L, typename R>
118 static constexpr bool is_axpby_left_left = is_axpby_left_left_impl<L, R>::value;
119 
120 template <typename L, typename R>
121 static constexpr bool is_axpby_left_right = is_axpby_left_right_impl<L, R>::value;
122 
123 template <typename L, typename R>
124 static constexpr bool is_axpby_right_left = is_axpby_right_left_impl<L, R>::value;
125 
126 template <typename L, typename R>
127 static constexpr bool is_axpby_right_right = is_axpby_right_right_impl<L, R>::value;
128 
129 template <typename L, typename R>
130 static constexpr bool is_axpby = is_axpby_left_left<L, R> || is_axpby_right_right<L, R> || is_axpby_left_right<L, R> || is_axpby_right_left<L, R>;
131 
132 template <typename L, typename R>
133 static constexpr bool is_axpy_left_left = is_axpy_left_left_impl<L, R>::value && !is_axpby<L, R>;
134 
135 template <typename L, typename R>
136 static constexpr bool is_axpy_left_right = is_axpy_left_right_impl<L, R>::value && !is_axpby<L, R>;
137 
138 template <typename L, typename R>
139 static constexpr bool is_axpy_right_left = is_axpy_right_left_impl<L, R>::value && !is_axpby<L, R>;
140 
141 template <typename L, typename R>
142 static constexpr bool is_axpy_right_right = is_axpy_right_right_impl<L, R>::value && !is_axpby<L, R>;
143 
144 template <typename L, typename R>
145 static constexpr bool is_axpy = is_axpy_left_left<L, R> || is_axpy_left_right<L, R> || is_axpy_right_left<L, R> || is_axpy_right_right<L, R>;
146 
147 template <typename L, typename R>
148 static constexpr bool is_special_plus = is_axpy<L, R> || is_axpby<L, R>;
149 
153 template <typename T>
// Compile-time properties advertised by this operator.
// NOTE(review): the precise meaning of each flag is defined by the code that
// consumes operator types, which is not visible in this file -- confirm there.
static constexpr bool linear      = true;
static constexpr bool thread_safe = true;
static constexpr bool desc_func   = false;
158 
164  template <vector_mode_t V>
165  static constexpr bool vectorizable = true;
166 
170  template <typename L, typename R>
171  static constexpr bool gpu_computable =
172  ((!is_scalar<L> && !is_scalar<R>)&&((is_single_precision_t<T> && impl::egblas::has_saxpy_3 && impl::egblas::has_saxpby_3)
173  || (is_double_precision_t<T> && impl::egblas::has_daxpy_3 && impl::egblas::has_daxpby_3)
174  || (is_complex_single_t<T> && impl::egblas::has_caxpy_3 && impl::egblas::has_caxpby_3)
175  || (is_complex_double_t<T> && impl::egblas::has_zaxpy_3 && impl::egblas::has_zaxpby_3)))
176  || ((is_scalar<L> != is_scalar<R>)&&((is_single_precision_t<T> && impl::egblas::has_scalar_sadd)
177  || (is_double_precision_t<T> && impl::egblas::has_scalar_dadd)
178  || (is_complex_single_t<T> && impl::egblas::has_scalar_cadd)
179  || (is_complex_double_t<T> && impl::egblas::has_scalar_zadd)));
180 
/*!
 * \brief Complexity estimate of the operator.
 * \return Always 1
 */
static constexpr int complexity() {
    constexpr int unit_cost = 1;
    return unit_cost;
}
188 
192  template <typename V = default_vec>
193  using vec_type = typename V::template vec_type<T>;
194 
201  static constexpr T apply(const T& lhs, const T& rhs) noexcept {
202  return lhs + rhs;
203  }
204 
212  template <typename V = default_vec>
213  static ETL_STRONG_INLINE(vec_type<V>) load(const vec_type<V>& lhs, const vec_type<V>& rhs) noexcept {
214  return V::add(lhs, rhs);
215  }
216 
/*!
 * \brief Compute the addition on GPU into a new temporary.
 *
 * The temporary is obtained from force_temporary_gpu_dim_only(y) and then
 * filled by gpu_compute.
 *
 * \param lhs The left hand side of the addition
 * \param rhs The right hand side of the addition
 * \param y The expression from which the temporary is created
 *
 * \return The temporary holding the result
 */
template <typename L, typename R, typename Y>
static auto gpu_compute_hint(const L& lhs, const R& rhs, Y& y) noexcept {
    auto result = force_temporary_gpu_dim_only(y);

    gpu_compute(lhs, rhs, result);

    return result;
}
231 
240  template <typename L, typename R, typename Y>
241  static Y& gpu_compute(const L& lhs, const R& rhs, Y& yy) noexcept {
242  if constexpr (!is_scalar<L> && !is_scalar<R> && is_axpy_left_left<L, R>) {
243  auto& lhs_lhs = lhs.get_lhs();
244  auto& lhs_rhs = lhs.get_rhs();
245 
246  decltype(auto) x = smart_gpu_compute_hint(lhs_rhs, yy);
247  decltype(auto) y = smart_gpu_compute_hint(rhs, yy);
248 
249  constexpr auto incx = gpu_inc<decltype(lhs_rhs)>;
250  constexpr auto incy = gpu_inc<decltype(rhs)>;
251 
252  impl::egblas::axpy_3(etl::size(yy), lhs_lhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
253  } else if constexpr (!is_scalar<L> && !is_scalar<R> && is_axpy_left_right<L, R>) {
254  auto& lhs_lhs = lhs.get_lhs();
255  auto& lhs_rhs = lhs.get_rhs();
256 
257  decltype(auto) x = smart_gpu_compute_hint(lhs_lhs, yy);
258  decltype(auto) y = smart_gpu_compute_hint(rhs, yy);
259 
260  constexpr auto incx = gpu_inc<decltype(lhs_lhs)>;
261  constexpr auto incy = gpu_inc<decltype(rhs)>;
262 
263  impl::egblas::axpy_3(etl::size(yy), lhs_rhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
264  } else if constexpr (is_axpy_right_left<L, R>) {
265  auto& rhs_lhs = rhs.get_lhs();
266  auto& rhs_rhs = rhs.get_rhs();
267 
268  decltype(auto) x = smart_gpu_compute_hint(rhs_rhs, yy);
269  decltype(auto) y = smart_gpu_compute_hint(lhs, yy);
270 
271  constexpr auto incx = gpu_inc<decltype(rhs_rhs)>;
272  constexpr auto incy = gpu_inc<decltype(lhs)>;
273 
274  impl::egblas::axpy_3(etl::size(yy), rhs_lhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
275  } else if constexpr (is_axpy_right_right<L, R>) {
276  auto& rhs_lhs = rhs.get_lhs();
277  auto& rhs_rhs = rhs.get_rhs();
278 
279  decltype(auto) x = smart_gpu_compute_hint(rhs_lhs, yy);
280  decltype(auto) y = smart_gpu_compute_hint(lhs, yy);
281 
282  constexpr auto incx = gpu_inc<decltype(rhs_lhs)>;
283  constexpr auto incy = gpu_inc<decltype(lhs)>;
284 
285  impl::egblas::axpy_3(etl::size(yy), rhs_rhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
286  } else if constexpr (is_axpby_left_left<L, R>) {
287  auto& lhs_lhs = lhs.get_lhs();
288  auto& lhs_rhs = lhs.get_rhs();
289 
290  auto& rhs_lhs = rhs.get_lhs();
291  auto& rhs_rhs = rhs.get_rhs();
292 
293  decltype(auto) x = smart_gpu_compute_hint(lhs_rhs, yy);
294  decltype(auto) y = smart_gpu_compute_hint(rhs_rhs, yy);
295 
296  constexpr auto incx = gpu_inc<decltype(lhs_rhs)>;
297  constexpr auto incy = gpu_inc<decltype(rhs_rhs)>;
298 
299  impl::egblas::axpby_3(etl::size(yy), lhs_lhs.value, x.gpu_memory(), incx, rhs_lhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
300  } else if constexpr (is_axpby_left_right<L, R>) {
301  auto& lhs_lhs = lhs.get_lhs();
302  auto& lhs_rhs = lhs.get_rhs();
303 
304  auto& rhs_lhs = rhs.get_lhs();
305  auto& rhs_rhs = rhs.get_rhs();
306 
307  decltype(auto) x = smart_gpu_compute_hint(lhs_rhs, yy);
308  decltype(auto) y = smart_gpu_compute_hint(rhs_lhs, yy);
309 
310  constexpr auto incx = gpu_inc<decltype(lhs_rhs)>;
311  constexpr auto incy = gpu_inc<decltype(rhs_lhs)>;
312 
313  impl::egblas::axpby_3(etl::size(yy), lhs_lhs.value, x.gpu_memory(), incx, rhs_rhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
314  } else if constexpr (is_axpby_right_left<L, R>) {
315  auto& lhs_lhs = lhs.get_lhs();
316  auto& lhs_rhs = lhs.get_rhs();
317 
318  auto& rhs_lhs = rhs.get_lhs();
319  auto& rhs_rhs = rhs.get_rhs();
320 
321  decltype(auto) x = smart_gpu_compute_hint(lhs_lhs, yy);
322  decltype(auto) y = smart_gpu_compute_hint(rhs_rhs, yy);
323 
324  constexpr auto incx = gpu_inc<decltype(lhs_lhs)>;
325  constexpr auto incy = gpu_inc<decltype(rhs_rhs)>;
326 
327  impl::egblas::axpby_3(etl::size(yy), lhs_rhs.value, x.gpu_memory(), incx, rhs_lhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
328  } else if constexpr (is_axpby_right_right<L, R>) {
329  auto& lhs_lhs = lhs.get_lhs();
330  auto& lhs_rhs = lhs.get_rhs();
331 
332  auto& rhs_lhs = rhs.get_lhs();
333  auto& rhs_rhs = rhs.get_rhs();
334 
335  decltype(auto) x = smart_gpu_compute_hint(lhs_lhs, yy);
336  decltype(auto) y = smart_gpu_compute_hint(rhs_lhs, yy);
337 
338  constexpr auto incx = gpu_inc<decltype(lhs_lhs)>;
339  constexpr auto incy = gpu_inc<decltype(rhs_lhs)>;
340 
341  impl::egblas::axpby_3(etl::size(yy), lhs_rhs.value, x.gpu_memory(), incx, rhs_rhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
342  } else if constexpr (!is_scalar<L> && !is_scalar<R> && !is_special_plus<L, R>) {
343  decltype(auto) x = smart_gpu_compute_hint(lhs, yy);
344  decltype(auto) y = smart_gpu_compute_hint(rhs, yy);
345 
346  constexpr auto incx = gpu_inc<decltype(lhs)>;
347  constexpr auto incy = gpu_inc<decltype(rhs)>;
348 
349  value_t<L> alpha(1);
350  impl::egblas::axpy_3(etl::size(yy), alpha, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
351  } else if constexpr (is_scalar<L> && !is_scalar<R>) {
352  auto s = lhs.value;
353 
354  smart_gpu_compute(rhs, yy);
355 
356  impl::egblas::scalar_add(yy.gpu_memory(), etl::size(yy), 1, s);
357  } else if constexpr (!is_scalar<L> && is_scalar<R>) {
358  auto s = rhs.value;
359 
360  smart_gpu_compute(lhs, yy);
361 
362  impl::egblas::scalar_add(yy.gpu_memory(), etl::size(yy), 1, s);
363  }
364 
365  yy.validate_gpu();
366  yy.invalidate_cpu();
367 
368  return yy;
369  }
370 
/*!
 * \brief Textual representation of the operator.
 * \return The string "+"
 */
static std::string desc() noexcept {
    return std::string("+");
}
378 };
379 
380 } //end of namespace etl
Definition: plus.hpp:30
auto s(T &&value)
Force the evaluation of the given expression.
Definition: stop.hpp:18
typename V::template vec_type< T > vec_type
Definition: plus.hpp:193
static Y & gpu_compute(const L &lhs, const R &rhs, Y &yy) noexcept
Compute the result of the operation using the GPU.
Definition: plus.hpp:241
Definition: plus.hpp:79
static constexpr T apply(const T &lhs, const T &rhs) noexcept
Apply the binary operator on lhs and rhs.
Definition: plus.hpp:201
A binary expression.
Definition: binary_expr.hpp:18
Definition: plus.hpp:105
auto load(size_t x) const noexcept
Load several elements of the expression at once.
Definition: dyn_matrix_view.hpp:143
Root namespace for the ETL library.
Definition: adapter.hpp:15
Definition: plus.hpp:42
Definition: plus.hpp:66
Definition: plus.hpp:54
static std::string desc() noexcept
Returns a textual representation of the operator.
Definition: plus.hpp:375
Definition: plus.hpp:92
decltype(auto) force_temporary_gpu_dim_only(E &&expr)
Force a temporary out of the expression, without copying its content.
Definition: temporary.hpp:223
Represents a scalar value.
Definition: concepts_base.hpp:19
Definition: plus.hpp:18
static constexpr int complexity()
Estimate the complexity of operator.
Definition: plus.hpp:185
constexpr size_t size(const E &expr) noexcept
Returns the size of the given ETL expression.
Definition: helpers.hpp:108
Binary operator for scalar multiplication.
Definition: div.hpp:13
static auto gpu_compute_hint(const L &lhs, const R &rhs, Y &y) noexcept
Compute the result of the operation using the GPU.
Definition: plus.hpp:226
decltype(auto) smart_gpu_compute_hint(E &expr, Y &y)
Compute the expression into a representation that is GPU up to date.
Definition: helpers.hpp:368
typename decay_traits< E >::value_type value_t
Traits to extract the value type out of an ETL type.
Definition: tmp.hpp:81
decltype(auto) smart_gpu_compute(X &x, Y &y)
Compute the expression into a representation that is GPU up to date and store this representation in ...
Definition: helpers.hpp:397
Binary operator for scalar addition.
Definition: plus.hpp:154