17 template <
typename L,
typename R>
19 static constexpr
bool value =
false;
22 template <
typename T0,
typename T1,
typename T2,
typename LeftExpr,
typename RightExpr>
24 static constexpr
bool value =
true;
29 template <
typename L,
typename R>
31 static constexpr
bool value =
false;
34 template <
typename T0,
typename T1,
typename T2,
typename LeftExpr,
typename RightExpr>
36 static constexpr
bool value =
true;
41 template <
typename L,
typename R>
43 static constexpr
bool value =
false;
46 template <
typename T0,
typename T1,
typename T2,
typename RightExpr,
typename R>
48 static constexpr
bool value =
true;
53 template <
typename L,
typename R>
55 static constexpr
bool value =
false;
58 template <
typename T0,
typename T1,
typename T2,
typename LeftExpr,
typename R>
60 static constexpr
bool value = !is_scalar<R>;
65 template <
typename L,
typename R>
67 static constexpr
bool value =
false;
70 template <
typename L,
typename T0,
typename T1,
typename T2,
typename RightExpr>
72 static constexpr
bool value = !is_scalar<L>;
77 template <
typename L,
typename R>
79 static constexpr
bool value =
false;
82 template <
typename L,
typename T0,
typename T1,
typename T2,
typename RightExpr>
84 static constexpr
bool value =
true;
89 template <
typename L,
typename R>
92 template <
typename L,
typename R>
95 template <
typename L,
typename R>
98 template <
typename L,
typename R>
101 template <
typename L,
typename R>
104 template <
typename L,
typename R>
107 template <
typename L,
typename R>
108 static constexpr
bool is_axmy =
109 is_axmy_left<L, R> || is_axmy_right<L, R> || is_axmy_left_left<L, R> || is_axmy_left_right<L, R> || is_axmy_right_left<L, R> || is_axmy_right_right<L, R>;
114 template <
typename T>
116 static constexpr
bool linear =
true;
117 static constexpr
bool thread_safe =
true;
118 static constexpr
bool desc_func =
false;
125 template <vector_mode_t V>
131 template <
typename L,
typename R>
132 static constexpr
bool gpu_computable = ((!is_scalar<L> && !is_scalar<R>)&&((is_single_precision_t<T> && impl::egblas::has_saxmy_3)
133 || (is_double_precision_t<T> && impl::egblas::has_daxmy_3)
134 || (is_complex_single_t<T> && impl::egblas::has_caxmy_3)
135 || (is_complex_double_t<T> && impl::egblas::has_zaxmy_3)))
136 || ((is_scalar<L> != is_scalar<R>)&&((is_single_precision_t<T> && impl::egblas::has_scalar_smul)
137 || (is_double_precision_t<T> && impl::egblas::has_scalar_dmul)
138 || (is_complex_single_t<T> && impl::egblas::has_scalar_cmul)
139 || (is_complex_double_t<T> && impl::egblas::has_scalar_zmul)));
152 template <
typename V = default_vec>
161 static constexpr T
apply(
const T& lhs,
const T& rhs) noexcept {
172 template <
typename V = default_vec>
174 return V::mul(lhs, rhs);
185 template <
typename L,
typename R,
typename Y>
188 gpu_compute(lhs, rhs, t3);
200 template <
typename L,
typename R,
typename YY>
201 static YY&
gpu_compute(
const L& lhs,
const R& rhs, YY& yy) noexcept {
202 if constexpr (is_axmy_left<L, R>) {
203 auto& rhs_lhs = rhs.get_lhs();
204 auto& rhs_rhs = rhs.get_rhs();
209 constexpr
auto incx = gpu_inc<decltype(rhs_lhs)>;
210 constexpr
auto incy = gpu_inc<decltype(rhs_rhs)>;
212 impl::egblas::axmy_3(
etl::size(yy), lhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
213 }
else if constexpr (is_axmy_right<L, R>) {
214 auto& lhs_lhs = lhs.get_lhs();
215 auto& lhs_rhs = lhs.get_rhs();
220 constexpr
auto incx = gpu_inc<decltype(lhs_lhs)>;
221 constexpr
auto incy = gpu_inc<decltype(lhs_rhs)>;
223 impl::egblas::axmy_3(
etl::size(yy), rhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
224 }
else if constexpr (is_axmy_left_left<L, R>) {
225 auto& lhs_lhs = lhs.get_lhs();
226 auto& lhs_rhs = lhs.get_rhs();
231 constexpr
auto incx = gpu_inc<decltype(lhs_rhs)>;
232 constexpr
auto incy = gpu_inc<decltype(rhs)>;
234 impl::egblas::axmy_3(
etl::size(yy), lhs_lhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
235 }
else if constexpr (is_axmy_left_right<L, R>) {
236 auto& lhs_lhs = lhs.get_lhs();
237 auto& lhs_rhs = lhs.get_rhs();
242 constexpr
auto incx = gpu_inc<decltype(lhs_lhs)>;
243 constexpr
auto incy = gpu_inc<decltype(rhs)>;
245 impl::egblas::axmy_3(
etl::size(yy), lhs_rhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
246 }
else if constexpr (is_axmy_right_left<L, R>) {
247 auto& rhs_lhs = rhs.get_lhs();
248 auto& rhs_rhs = rhs.get_rhs();
253 constexpr
auto incx = gpu_inc<decltype(lhs)>;
254 constexpr
auto incy = gpu_inc<decltype(rhs_rhs)>;
256 impl::egblas::axmy_3(
etl::size(yy), rhs_lhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
257 }
else if constexpr (is_axmy_right_right<L, R>) {
258 auto& rhs_lhs = rhs.get_lhs();
259 auto& rhs_rhs = rhs.get_rhs();
264 constexpr
auto incx = gpu_inc<decltype(lhs)>;
265 constexpr
auto incy = gpu_inc<decltype(rhs_lhs)>;
267 impl::egblas::axmy_3(
etl::size(yy), rhs_rhs.value, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
268 }
else if constexpr (!is_scalar<L> && !is_scalar<R> && !is_axmy<L, R>) {
272 constexpr
auto incx = gpu_inc<decltype(lhs)>;
273 constexpr
auto incy = gpu_inc<decltype(rhs)>;
275 impl::egblas::axmy_3(
etl::size(yy), 1, x.gpu_memory(), incx, y.gpu_memory(), incy, yy.gpu_memory(), 1);
276 }
else if constexpr (is_scalar<L> && !is_scalar<R> && !is_axmy<L, R>) {
279 impl::egblas::scalar_mul(yy.gpu_memory(),
etl::size(yy), 1, lhs.value);
280 }
else if constexpr (!is_scalar<L> && is_scalar<R> && !is_axmy<L, R>) {
283 impl::egblas::scalar_mul(yy.gpu_memory(),
etl::size(yy), 1, rhs.value);
296 static std::string
desc() noexcept {
static vec_type< V > load(const vec_type< V > &lhs, const vec_type< V > &rhs) noexcept
Compute several applications of the operator at a time.
Definition: mul.hpp:173
typename V::template vec_type< T > vec_type
Definition: mul.hpp:153
static YY & gpu_compute(const L &lhs, const R &rhs, YY &yy) noexcept
Compute the result of the operation using the GPU.
Definition: mul.hpp:201
static constexpr int complexity()
Estimate the complexity of operator.
Definition: mul.hpp:145
A binary expression.
Definition: binary_expr.hpp:18
Root namespace for the ETL library.
Definition: adapter.hpp:15
static constexpr T apply(const T &lhs, const T &rhs) noexcept
Apply the unary operator on lhs and rhs.
Definition: mul.hpp:161
static auto gpu_compute_hint(const L &lhs, const R &rhs, Y &y) noexcept
Compute the result of the operation using the GPU.
Definition: mul.hpp:186
decltype(auto) force_temporary_gpu_dim_only(E &&expr)
Force a temporary out of the expression, without copying its content.
Definition: temporary.hpp:223
Represents a scalar value.
Definition: concepts_base.hpp:19
constexpr size_t size(const E &expr) noexcept
Returns the size of the given ETL expression.
Definition: helpers.hpp:108
Binary operator for scalar multiplication.
Definition: div.hpp:13
decltype(auto) smart_gpu_compute_hint(E &expr, Y &y)
Compute the expression into a representation that is GPU up to date.
Definition: helpers.hpp:368
AVX-512F is the max vectorization available.
static std::string desc() noexcept
Returns a textual representation of the operator.
Definition: mul.hpp:296
decltype(auto) smart_gpu_compute(X &x, Y &y)
Compute the expression into a representation that is GPU up to date and store this representation in ...
Definition: helpers.hpp:397