Expression Templates Library (ETL)
mul.hpp
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
8 #pragma once
9 
10 namespace etl {
11 
// Forward declaration of the multiplication operator: the detection traits that
// follow pattern-match on etl::mul_binary_op<T> before the struct is defined.
12 template <typename T>
13 struct mul_binary_op;
14 
15 // detect 1.0 * (x * y)
16 
17 template <typename L, typename R>
19  static constexpr bool value = false;
20 };
21 
22 template <typename T0, typename T1, typename T2, typename LeftExpr, typename RightExpr>
23 struct is_axmy_left_impl<etl::scalar<T0>, binary_expr<T1, LeftExpr, etl::mul_binary_op<T2>, RightExpr>> {
24  static constexpr bool value = true;
25 };
26 
27 // detect (x * y) * 1.0
28 
29 template <typename L, typename R>
31  static constexpr bool value = false;
32 };
33 
34 template <typename T0, typename T1, typename T2, typename LeftExpr, typename RightExpr>
35 struct is_axmy_right_impl<binary_expr<T1, LeftExpr, etl::mul_binary_op<T2>, RightExpr>, etl::scalar<T0>> {
36  static constexpr bool value = true;
37 };
38 
39 // detect (1.0 * x) * y
40 
41 template <typename L, typename R>
43  static constexpr bool value = false;
44 };
45 
46 template <typename T0, typename T1, typename T2, typename RightExpr, typename R>
47 struct is_axmy_left_left_impl<binary_expr<T1, etl::scalar<T0>, etl::mul_binary_op<T2>, RightExpr>, R> {
48  static constexpr bool value = true;
49 };
50 
51 // detect (x * 1.0) * y
52 
53 template <typename L, typename R>
55  static constexpr bool value = false;
56 };
57 
58 template <typename T0, typename T1, typename T2, typename LeftExpr, typename R>
59 struct is_axmy_left_right_impl<binary_expr<T1, LeftExpr, etl::mul_binary_op<T2>, etl::scalar<T0>>, R> {
60  static constexpr bool value = !is_scalar<R>;
61 };
62 
63 // detect x * (1.0 * y)
64 
65 template <typename L, typename R>
67  static constexpr bool value = false;
68 };
69 
70 template <typename L, typename T0, typename T1, typename T2, typename RightExpr>
71 struct is_axmy_right_left_impl<L, binary_expr<T0, etl::scalar<T1>, etl::mul_binary_op<T2>, RightExpr>> {
72  static constexpr bool value = !is_scalar<L>;
73 };
74 
75 // detect x * (1.0 * y)
76 
77 template <typename L, typename R>
79  static constexpr bool value = false;
80 };
81 
82 template <typename L, typename T0, typename T1, typename T2, typename RightExpr>
83 struct is_axmy_right_right_impl<L, binary_expr<T0, RightExpr, etl::mul_binary_op<T2>, etl::scalar<T1>>> {
84  static constexpr bool value = true;
85 };
86 
87 // Variable templates helper
88 
89 template <typename L, typename R>
90 static constexpr bool is_axmy_left = is_axmy_left_impl<L, R>::value;
91 
92 template <typename L, typename R>
93 static constexpr bool is_axmy_right = is_axmy_right_impl<L, R>::value;
94 
95 template <typename L, typename R>
96 static constexpr bool is_axmy_left_left = is_axmy_left_left_impl<L, R>::value;
97 
98 template <typename L, typename R>
99 static constexpr bool is_axmy_left_right = is_axmy_left_right_impl<L, R>::value;
100 
101 template <typename L, typename R>
102 static constexpr bool is_axmy_right_left = is_axmy_right_left_impl<L, R>::value;
103 
104 template <typename L, typename R>
105 static constexpr bool is_axmy_right_right = is_axmy_right_right_impl<L, R>::value;
106 
107 template <typename L, typename R>
108 static constexpr bool is_axmy =
109  is_axmy_left<L, R> || is_axmy_right<L, R> || is_axmy_left_left<L, R> || is_axmy_left_right<L, R> || is_axmy_right_left<L, R> || is_axmy_right_right<L, R>;
110 
114 template <typename T>
115 struct mul_binary_op {
116  static constexpr bool linear = true;
117  static constexpr bool thread_safe = true;
118  static constexpr bool desc_func = false;
119 
125  template <vector_mode_t V>
126  static constexpr bool vectorizable = V == vector_mode_t::AVX512 ? !is_complex_t<T> : true;
127 
131  template <typename L, typename R>
132  static constexpr bool gpu_computable = ((!is_scalar<L> && !is_scalar<R>)&&((is_single_precision_t<T> && impl::egblas::has_saxmy_3)
133  || (is_double_precision_t<T> && impl::egblas::has_daxmy_3)
134  || (is_complex_single_t<T> && impl::egblas::has_caxmy_3)
135  || (is_complex_double_t<T> && impl::egblas::has_zaxmy_3)))
136  || ((is_scalar<L> != is_scalar<R>)&&((is_single_precision_t<T> && impl::egblas::has_scalar_smul)
137  || (is_double_precision_t<T> && impl::egblas::has_scalar_dmul)
138  || (is_complex_single_t<T> && impl::egblas::has_scalar_cmul)
139  || (is_complex_double_t<T> && impl::egblas::has_scalar_zmul)));
140 
145  static constexpr int complexity() {
146  return 2;
147  }
148 
152  template <typename V = default_vec>
153  using vec_type = typename V::template vec_type<T>;
154 
161  static constexpr T apply(const T& lhs, const T& rhs) noexcept {
162  return lhs * rhs;
163  }
164 
172  template <typename V = default_vec>
173  static vec_type<V> load(const vec_type<V>& lhs, const vec_type<V>& rhs) noexcept {
174  return V::mul(lhs, rhs);
175  }
176 
185  template <typename L, typename R, typename Y>
186  static auto gpu_compute_hint(const L& lhs, const R& rhs, Y& y) noexcept {
187  auto t3 = force_temporary_gpu_dim_only(y);
188  gpu_compute(lhs, rhs, t3);
189  return t3;
190  }
191 
200  template <typename L, typename R, typename YY>
201  static YY& gpu_compute(const L& lhs, const R& rhs, YY& yy) noexcept {
202  if constexpr (is_axmy_left<L, R>) {
203  auto& rhs_lhs = rhs.get_lhs();
204  auto& rhs_rhs = rhs.get_rhs();
205 
206  decltype(auto) x = smart_gpu_compute_hint(rhs_lhs, yy);
207  decltype(auto) y = smart_gpu_compute_hint(rhs_rhs, yy);
208 
209  constexpr auto incx = gpu_inc<decltype(rhs_lhs)>;
210  constexpr auto incy = gpu_inc<decltype(rhs_rhs)>;
211 
212  impl::egblas::axmy_3(etl::size(yy), lhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
213  } else if constexpr (is_axmy_right<L, R>) {
214  auto& lhs_lhs = lhs.get_lhs();
215  auto& lhs_rhs = lhs.get_rhs();
216 
217  decltype(auto) x = smart_gpu_compute_hint(lhs_lhs, yy);
218  decltype(auto) y = smart_gpu_compute_hint(lhs_rhs, yy);
219 
220  constexpr auto incx = gpu_inc<decltype(lhs_lhs)>;
221  constexpr auto incy = gpu_inc<decltype(lhs_rhs)>;
222 
223  impl::egblas::axmy_3(etl::size(yy), rhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
224  } else if constexpr (is_axmy_left_left<L, R>) {
225  auto& lhs_lhs = lhs.get_lhs();
226  auto& lhs_rhs = lhs.get_rhs();
227 
228  decltype(auto) x = smart_gpu_compute_hint(lhs_rhs, yy);
229  decltype(auto) y = smart_gpu_compute_hint(rhs, yy);
230 
231  constexpr auto incx = gpu_inc<decltype(lhs_rhs)>;
232  constexpr auto incy = gpu_inc<decltype(rhs)>;
233 
234  impl::egblas::axmy_3(etl::size(yy), lhs_lhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
235  } else if constexpr (is_axmy_left_right<L, R>) {
236  auto& lhs_lhs = lhs.get_lhs();
237  auto& lhs_rhs = lhs.get_rhs();
238 
239  decltype(auto) x = smart_gpu_compute_hint(lhs_lhs, yy);
240  decltype(auto) y = smart_gpu_compute_hint(rhs, yy);
241 
242  constexpr auto incx = gpu_inc<decltype(lhs_lhs)>;
243  constexpr auto incy = gpu_inc<decltype(rhs)>;
244 
245  impl::egblas::axmy_3(etl::size(yy), lhs_rhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
246  } else if constexpr (is_axmy_right_left<L, R>) {
247  auto& rhs_lhs = rhs.get_lhs();
248  auto& rhs_rhs = rhs.get_rhs();
249 
250  decltype(auto) x = smart_gpu_compute_hint(lhs, yy);
251  decltype(auto) y = smart_gpu_compute_hint(rhs_rhs, yy);
252 
253  constexpr auto incx = gpu_inc<decltype(lhs)>;
254  constexpr auto incy = gpu_inc<decltype(rhs_rhs)>;
255 
256  impl::egblas::axmy_3(etl::size(yy), rhs_lhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
257  } else if constexpr (is_axmy_right_right<L, R>) {
258  auto& rhs_lhs = rhs.get_lhs();
259  auto& rhs_rhs = rhs.get_rhs();
260 
261  decltype(auto) x = smart_gpu_compute_hint(lhs, yy);
262  decltype(auto) y = smart_gpu_compute_hint(rhs_lhs, yy);
263 
264  constexpr auto incx = gpu_inc<decltype(lhs)>;
265  constexpr auto incy = gpu_inc<decltype(rhs_lhs)>;
266 
267  impl::egblas::axmy_3(etl::size(yy), rhs_rhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
268  } else if constexpr (!is_scalar<L> && !is_scalar<R> && !is_axmy<L, R>) {
269  decltype(auto) x = smart_gpu_compute_hint(lhs, yy);
270  decltype(auto) y = smart_gpu_compute_hint(rhs, yy);
271 
272  constexpr auto incx = gpu_inc<decltype(lhs)>;
273  constexpr auto incy = gpu_inc<decltype(rhs)>;
274 
275  impl::egblas::axmy_3(etl::size(yy), 1, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
276  } else if constexpr (is_scalar<L> && !is_scalar<R> && !is_axmy<L, R>) {
277  smart_gpu_compute(rhs, yy);
278 
279  impl::egblas::scalar_mul(yy.gpu_memory(), etl::size(yy), 1, lhs.value);
280  } else if constexpr (!is_scalar<L> && is_scalar<R> && !is_axmy<L, R>) {
281  smart_gpu_compute(lhs, yy);
282 
283  impl::egblas::scalar_mul(yy.gpu_memory(), etl::size(yy), 1, rhs.value);
284  }
285 
286  yy.validate_gpu();
287  yy.invalidate_cpu();
288 
289  return yy;
290  }
291 
296  static std::string desc() noexcept {
297  return "*";
298  }
299 };
300 
301 } //end of namespace etl
Definition: mul.hpp:54
static vec_type< V > load(const vec_type< V > &lhs, const vec_type< V > &rhs) noexcept
Compute several applications of the operator at a time.
Definition: mul.hpp:173
typename V::template vec_type< T > vec_type
Definition: mul.hpp:153
static YY & gpu_compute(const L &lhs, const R &rhs, YY &yy) noexcept
Compute the result of the operation using the GPU.
Definition: mul.hpp:201
Definition: mul.hpp:42
Definition: mul.hpp:18
static constexpr int complexity()
Estimate the complexity of operator.
Definition: mul.hpp:145
A binary expression.
Definition: binary_expr.hpp:18
Root namespace for the ETL library.
Definition: adapter.hpp:15
Definition: mul.hpp:66
static constexpr T apply(const T &lhs, const T &rhs) noexcept
Apply the binary operator to lhs and rhs.
Definition: mul.hpp:161
static auto gpu_compute_hint(const L &lhs, const R &rhs, Y &y) noexcept
Compute the result of the operation using the GPU.
Definition: mul.hpp:186
Definition: mul.hpp:78
decltype(auto) force_temporary_gpu_dim_only(E &&expr)
Force a temporary out of the expression, without copying its content.
Definition: temporary.hpp:223
Represents a scalar value.
Definition: concepts_base.hpp:19
constexpr size_t size(const E &expr) noexcept
Returns the size of the given ETL expression.
Definition: helpers.hpp:108
Binary operator for scalar multiplication.
Definition: div.hpp:13
decltype(auto) smart_gpu_compute_hint(E &expr, Y &y)
Compute the expression into a representation that is GPU up to date.
Definition: helpers.hpp:368
AVX-512F is the max vectorization available.
static std::string desc() noexcept
Returns a textual representation of the operator.
Definition: mul.hpp:296
decltype(auto) smart_gpu_compute(X &x, Y &y)
Compute the expression into a representation that is GPU up to date and store this representation in ...
Definition: helpers.hpp:397
Definition: mul.hpp:30