Expression Templates Library (ETL)
minus.hpp
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
8 #pragma once
9 
10 namespace etl {
11 
15 template <typename T>
/*!
 * \brief Indicates if the operator is linear or not.
 */
static constexpr bool linear{true};

/*!
 * \brief Indicates if the operator is thread safe or not.
 */
static constexpr bool thread_safe{true};

/*!
 * \brief Indicates if the description must be printed as function.
 */
static constexpr bool desc_func{false};
20 
26  template <vector_mode_t V>
27  static constexpr bool vectorizable = true;
28 
32  template <typename L, typename R>
33  static constexpr bool gpu_computable =
34  ((!is_scalar<L> && !is_scalar<R>)&&((is_single_precision_t<T> && impl::egblas::has_saxpy_3 && impl::egblas::has_saxpby_3)
35  || (is_double_precision_t<T> && impl::egblas::has_daxpy_3 && impl::egblas::has_daxpby_3)
36  || (is_complex_single_t<T> && impl::egblas::has_caxpy_3 && impl::egblas::has_caxpby_3)
37  || (is_complex_double_t<T> && impl::egblas::has_zaxpy_3 && impl::egblas::has_zaxpby_3)))
38  || ((is_scalar<L> != is_scalar<R>)&&((is_single_precision_t<T> && impl::egblas::has_scalar_sadd && impl::egblas::has_scalar_smul)
39  || (is_double_precision_t<T> && impl::egblas::has_scalar_dadd && impl::egblas::has_scalar_dmul)
40  || (is_complex_single_t<T> && impl::egblas::has_scalar_cadd && impl::egblas::has_scalar_cmul)
41  || (is_complex_double_t<T> && impl::egblas::has_scalar_zadd && impl::egblas::has_scalar_zmul)));
42 
/*!
 * \brief Estimate the complexity of operator
 * \return An estimation of the complexity of the operator (always 1)
 */
static constexpr int complexity() {
    constexpr int unit_cost = 1;
    return unit_cost;
}
50 
/*!
 * \brief Vector type that V exposes for values of type T, that is
 * V::vec_type<T>.
 * \tparam V The type providing the vec_type member template
 * (default_vec unless specified)
 */
54  template <typename V = default_vec>
55  using vec_type = typename V::template vec_type<T>;
56 
63  static constexpr T apply(const T& lhs, const T& rhs) noexcept {
64  return lhs - rhs;
65  }
66 
74  template <typename V = default_vec>
75  static vec_type<V> load(const vec_type<V>& lhs, const vec_type<V>& rhs) noexcept {
76  return V::sub(lhs, rhs);
77  }
78 
/*!
 * \brief Compute the result of the operation using the GPU
 *
 * A temporary is obtained from y through force_temporary_gpu_dim_only,
 * the operation is computed into that temporary with gpu_compute and
 * the temporary is returned by value.
 *
 * \param lhs The left hand side
 * \param rhs The right hand side
 * \param y The expression from which the temporary is created
 *
 * \return The temporary holding the result of the operation
 */
template <typename L, typename R, typename Y>
static auto gpu_compute_hint(const L& lhs, const R& rhs, Y& y) noexcept {
    auto result = force_temporary_gpu_dim_only(y);

    gpu_compute(lhs, rhs, result);

    return result;
}
93 
102  template <typename L, typename R, typename Y>
103  static Y& gpu_compute(const L& lhs, const R& rhs, Y& yy) noexcept {
104  if constexpr (is_axpy_right_left<L, R>) {
105  auto& rhs_lhs = rhs.get_lhs();
106  auto& rhs_rhs = rhs.get_rhs();
107 
108  decltype(auto) x = smart_gpu_compute_hint(rhs_rhs, yy);
109  decltype(auto) y = smart_gpu_compute_hint(lhs, yy);
110 
111  constexpr auto incx = gpu_inc<decltype(rhs_rhs)>;
112  constexpr auto incy = gpu_inc<decltype(lhs)>;
113 
114  impl::egblas::axpy_3(etl::size(yy), -rhs_lhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
115  } else if constexpr (is_axpy_right_right<L, R>) {
116  auto& rhs_lhs = rhs.get_lhs();
117  auto& rhs_rhs = rhs.get_rhs();
118 
119  decltype(auto) x = smart_gpu_compute_hint(rhs_lhs, yy);
120  decltype(auto) y = smart_gpu_compute_hint(lhs, yy);
121 
122  constexpr auto incx = gpu_inc<decltype(rhs_lhs)>;
123  constexpr auto incy = gpu_inc<decltype(lhs)>;
124 
125  impl::egblas::axpy_3(etl::size(yy), -rhs_rhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
126  } else if constexpr (is_axpy_left_left<L, R>) {
127  auto& lhs_lhs = lhs.get_lhs();
128  auto& lhs_rhs = lhs.get_rhs();
129 
130  decltype(auto) x = smart_gpu_compute_hint(lhs_rhs, yy);
131  decltype(auto) y = smart_gpu_compute_hint(rhs, yy);
132 
133  constexpr auto incx = gpu_inc<decltype(lhs_rhs)>;
134  constexpr auto incy = gpu_inc<decltype(rhs)>;
135 
136  impl::egblas::axpby_3(etl::size(yy), lhs_lhs.value, x.gpu_memory(), incx, T(-1), y.gpu_memory(), incy, yy.gpu_memory(), 1);
137  } else if constexpr (is_axpy_left_right<L, R>) {
138  auto& lhs_lhs = lhs.get_lhs();
139  auto& lhs_rhs = lhs.get_rhs();
140 
141  decltype(auto) x = smart_gpu_compute_hint(lhs_lhs, yy);
142  decltype(auto) y = smart_gpu_compute_hint(rhs, yy);
143 
144  constexpr auto incx = gpu_inc<decltype(lhs_lhs)>;
145  constexpr auto incy = gpu_inc<decltype(rhs)>;
146 
147  impl::egblas::axpby_3(etl::size(yy), lhs_rhs.value, x.gpu_memory(), incx, T(-1), y.gpu_memory(), incy, yy.gpu_memory(), 1);
148  } else if constexpr (is_axpby_left_left<L, R>) {
149  auto& lhs_lhs = lhs.get_lhs();
150  auto& lhs_rhs = lhs.get_rhs();
151 
152  auto& rhs_lhs = rhs.get_lhs();
153  auto& rhs_rhs = rhs.get_rhs();
154 
155  decltype(auto) x = smart_gpu_compute_hint(lhs_rhs, yy);
156  decltype(auto) y = smart_gpu_compute_hint(rhs_rhs, yy);
157 
158  constexpr auto incx = gpu_inc<decltype(lhs_rhs)>;
159  constexpr auto incy = gpu_inc<decltype(rhs_rhs)>;
160 
161  impl::egblas::axpby_3(etl::size(yy), lhs_lhs.value, x.gpu_memory(), incx, T(-1) * rhs_lhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
162  } else if constexpr (is_axpby_left_right<L, R>) {
163  auto& lhs_lhs = lhs.get_lhs();
164  auto& lhs_rhs = lhs.get_rhs();
165 
166  auto& rhs_lhs = rhs.get_lhs();
167  auto& rhs_rhs = rhs.get_rhs();
168 
169  decltype(auto) x = smart_gpu_compute_hint(lhs_rhs, yy);
170  decltype(auto) y = smart_gpu_compute_hint(rhs_lhs, yy);
171 
172  constexpr auto incx = gpu_inc<decltype(lhs_rhs)>;
173  constexpr auto incy = gpu_inc<decltype(rhs_lhs)>;
174 
175  impl::egblas::axpby_3(etl::size(yy), lhs_lhs.value, x.gpu_memory(), incx, T(-1) * rhs_rhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
176  } else if constexpr (is_axpby_right_left<L, R>) {
177  auto& lhs_lhs = lhs.get_lhs();
178  auto& lhs_rhs = lhs.get_rhs();
179 
180  auto& rhs_lhs = rhs.get_lhs();
181  auto& rhs_rhs = rhs.get_rhs();
182 
183  decltype(auto) x = smart_gpu_compute_hint(lhs_lhs, yy);
184  decltype(auto) y = smart_gpu_compute_hint(rhs_rhs, yy);
185 
186  constexpr auto incx = gpu_inc<decltype(lhs_lhs)>;
187  constexpr auto incy = gpu_inc<decltype(rhs_rhs)>;
188 
189  impl::egblas::axpby_3(etl::size(yy), lhs_rhs.value, x.gpu_memory(), incx, T(-1) * rhs_lhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
190  } else if constexpr (is_axpby_right_right<L, R>) {
191  auto& lhs_lhs = lhs.get_lhs();
192  auto& lhs_rhs = lhs.get_rhs();
193 
194  auto& rhs_lhs = rhs.get_lhs();
195  auto& rhs_rhs = rhs.get_rhs();
196 
197  decltype(auto) x = smart_gpu_compute_hint(lhs_lhs, yy);
198  decltype(auto) y = smart_gpu_compute_hint(rhs_lhs, yy);
199 
200  constexpr auto incx = gpu_inc<decltype(lhs_lhs)>;
201  constexpr auto incy = gpu_inc<decltype(rhs_lhs)>;
202 
203  impl::egblas::axpby_3(etl::size(yy), lhs_rhs.value, x.gpu_memory(), incx, T(-1) * rhs_rhs.value, y.gpu_memory(), incy, yy.gpu_memory(), 1);
204  } else if constexpr (!is_scalar<L> && !is_scalar<R> && !is_special_plus<L, R>) {
205  decltype(auto) x = smart_gpu_compute_hint(lhs, yy);
206  decltype(auto) y = smart_gpu_compute_hint(rhs, yy);
207 
208  constexpr auto incx = gpu_inc<decltype(lhs)>;
209  constexpr auto incy = gpu_inc<decltype(rhs)>;
210 
211  impl::egblas::axpy_3(etl::size(yy), value_t<L>(-1), y.gpu_memory(), incy, x.gpu_memory(), incx, yy.gpu_memory(), 1);
212  } else if constexpr (!is_scalar<L> && is_scalar<R>) {
213  auto s = -rhs.value;
214 
215  smart_gpu_compute(lhs, yy);
216 
217  impl::egblas::scalar_add(yy.gpu_memory(), etl::size(yy), 1, s);
218  } else if constexpr (is_scalar<L> && !is_scalar<R>) {
219  auto s = lhs.value;
220 
221  smart_gpu_compute(rhs, yy);
222 
223  impl::egblas::scalar_mul(yy.gpu_memory(), etl::size(yy), 1, value_t<L>(-1));
224  impl::egblas::scalar_add(yy.gpu_memory(), etl::size(yy), 1, s);
225  }
226 
227  yy.validate_gpu();
228  yy.invalidate_cpu();
229 
230  return yy;
231  }
232 
/*!
 * \brief Returns a textual representation of the operator
 * \return a string representing the operator
 */
static std::string desc() noexcept {
    return std::string("-");
}
240 };
241 
242 } //end of namespace etl
static std::string desc() noexcept
Returns a textual representation of the operator.
Definition: minus.hpp:237
auto s(T &&value)
Force the evaluation of the given expression.
Definition: stop.hpp:18
static constexpr int complexity()
Estimate the complexity of operator.
Definition: minus.hpp:47
typename V::template vec_type< T > vec_type
Definition: minus.hpp:55
Binary operator for scalar subtraction.
Definition: minus.hpp:16
static constexpr bool gpu_computable
Indicates if the operator can be computed on GPU.
Definition: minus.hpp:33
static vec_type< V > load(const vec_type< V > &lhs, const vec_type< V > &rhs) noexcept
Compute several applications of the operator at a time.
Definition: minus.hpp:75
static constexpr T apply(const T &lhs, const T &rhs) noexcept
Apply the binary operator on lhs and rhs.
Definition: minus.hpp:63
static constexpr bool desc_func
Indicates if the description must be printed as function.
Definition: minus.hpp:19
static constexpr bool linear
Indicates if the operator is linear or not.
Definition: minus.hpp:17
Root namespace for the ETL library.
Definition: adapter.hpp:15
decltype(auto) force_temporary_gpu_dim_only(E &&expr)
Force a temporary out of the expression, without copying its content.
Definition: temporary.hpp:223
static constexpr bool vectorizable
Indicates if the expression is vectorizable using the given vector mode.
Definition: minus.hpp:27
static auto gpu_compute_hint(const L &lhs, const R &rhs, Y &y) noexcept
Compute the result of the operation using the GPU.
Definition: minus.hpp:88
constexpr size_t size(const E &expr) noexcept
Returns the size of the given ETL expression.
Definition: helpers.hpp:108
decltype(auto) smart_gpu_compute_hint(E &expr, Y &y)
Compute the expression into a representation that is GPU up to date.
Definition: helpers.hpp:368
typename decay_traits< E >::value_type value_t
Traits to extract the value type out of an ETL type.
Definition: tmp.hpp:81
static constexpr bool thread_safe
Indicates if the operator is thread safe or not.
Definition: minus.hpp:18
static Y & gpu_compute(const L &lhs, const R &rhs, Y &yy) noexcept
Compute the result of the operation using the GPU.
Definition: minus.hpp:103
decltype(auto) smart_gpu_compute(X &x, Y &y)
Compute the expression into a representation that is GPU up to date and store this representation in ...
Definition: helpers.hpp:397