Expression Templates Library (ETL)
sqrt.hpp
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
8 #pragma once
9 
10 #include "etl/impl/egblas/sqrt.hpp"
11 
12 namespace etl {
13 
18 template <typename T>
19 struct sqrt_unary_op {
20  static constexpr bool linear = true;
21  static constexpr bool thread_safe = true;
22 
28  template <vector_mode_t V>
29  static constexpr bool vectorizable = !is_complex_t<T>;
30 
34  template <typename E>
35  static constexpr bool gpu_computable = (is_single_precision_t<T> && impl::egblas::has_ssqrt) || (is_double_precision_t<T> && impl::egblas::has_dsqrt)
36  || (is_complex_single_t<T> && impl::egblas::has_csqrt) || (is_complex_double_t<T> && impl::egblas::has_zsqrt);
37 
42  static constexpr int complexity() {
43  return 12;
44  }
45 
49  template <typename V = default_vec>
50  using vec_type = typename V::template vec_type<T>;
51 
57  static constexpr T apply(const T& x) {
58  return std::sqrt(x);
59  }
60 
67  template <typename V = default_vec>
68  static vec_type<V> load(const vec_type<V>& x) noexcept {
69  return V::sqrt(x);
70  }
71 
79  template <typename X, typename Y>
80  static auto gpu_compute_hint(const X& x, Y& y) noexcept {
81  decltype(auto) t1 = smart_gpu_compute_hint(x, y);
82 
83  auto t2 = force_temporary_gpu_dim_only(t1);
84 
85  T alpha(1.0);
86  impl::egblas::sqrt(etl::size(y), alpha, t1.gpu_memory(), 1, t2.gpu_memory(), 1);
87 
88  return t2;
89  }
96  template <typename X, typename Y>
97  static Y& gpu_compute(const X& x, Y& y) noexcept {
98  decltype(auto) t1 = select_smart_gpu_compute(x, y);
99 
100  T alpha(1.0);
101  impl::egblas::sqrt(etl::size(y), alpha, t1.gpu_memory(), 1, y.gpu_memory(), 1);
102 
103  y.validate_gpu();
104  y.invalidate_cpu();
105 
106  return y;
107  }
108 
113  static std::string desc() noexcept {
114  return "sqrt";
115  }
116 };
117 
121 template <typename TT>
122 struct sqrt_unary_op<etl::complex<TT>> {
123  using T = etl::complex<TT>;
124 
125  static constexpr bool linear = true;
126  static constexpr bool thread_safe = true;
127 
133  template <vector_mode_t V>
134  static constexpr bool vectorizable = false;
135 
139  template <typename E>
140  static constexpr bool gpu_computable = (is_single_precision_t<T> && impl::egblas::has_ssqrt) || (is_double_precision_t<T> && impl::egblas::has_dsqrt)
141  || (is_complex_single_t<T> && impl::egblas::has_csqrt) || (is_complex_double_t<T> && impl::egblas::has_zsqrt);
142 
147  static constexpr int complexity() {
148  return 12;
149  }
150 
156  static constexpr T apply(const T& x) {
157  return etl::sqrt(x);
158  }
159 
167  template <typename X, typename Y>
168  static auto gpu_compute_hint(const X& x, Y& y) noexcept {
169  decltype(auto) t1 = smart_gpu_compute_hint(x, y);
170 
171  auto t2 = force_temporary_gpu_dim_only(t1);
172 
173  T alpha(1.0);
174  impl::egblas::sqrt(etl::size(y), alpha, t1.gpu_memory(), 1, t2.gpu_memory(), 1);
175 
176  return t2;
177  }
184  template <typename X, typename Y>
185  static Y& gpu_compute(const X& x, Y& y) noexcept {
186  decltype(auto) t1 = select_smart_gpu_compute(x, y);
187 
188  T alpha(1.0);
189  impl::egblas::sqrt(etl::size(y), alpha, t1.gpu_memory(), 1, y.gpu_memory(), 1);
190 
191  y.validate_gpu();
192  y.invalidate_cpu();
193 
194  return y;
195  }
196 
201  static std::string desc() noexcept {
202  return "sqrt";
203  }
204 };
205 
206 } //end of namespace etl
Complex number implementation.
Definition: complex.hpp:31
static Y & gpu_compute(const X &x, Y &y) noexcept
Compute the result of the operation using the GPU.
Definition: sqrt.hpp:97
static Y & gpu_compute(const X &x, Y &y) noexcept
Compute the result of the operation using the GPU.
Definition: sqrt.hpp:185
static std::string desc() noexcept
Returns a textual representation of the operator.
Definition: sqrt.hpp:113
auto sqrt(E &&value) -> detail::unary_helper< E, sqrt_unary_op >
Apply square root on each value of the given expression.
Definition: function_expression_builder.hpp:24
decltype(auto) select_smart_gpu_compute(X &x, Y &y)
Compute the expression into a representation that is up to date on the GPU and possibly store this representation...
Definition: helpers.hpp:434
static vec_type< V > load(const vec_type< V > &x) noexcept
Compute several applications of the operator at a time.
Definition: sqrt.hpp:68
EGBLAS wrappers for the sqrt operation.
Root namespace for the ETL library.
Definition: adapter.hpp:15
typename V::template vec_type< T > vec_type
Definition: sqrt.hpp:50
static constexpr int complexity()
Estimate the complexity of the operator.
Definition: sqrt.hpp:42
static constexpr bool vectorizable
Indicates if the expression is vectorizable using the given vector mode.
Definition: sqrt.hpp:29
static constexpr T apply(const T &x)
Apply the unary operator on x.
Definition: sqrt.hpp:156
static constexpr bool thread_safe
Indicates whether the operator is thread safe.
Definition: sqrt.hpp:21
static auto gpu_compute_hint(const X &x, Y &y) noexcept
Compute the result of the operation using the GPU.
Definition: sqrt.hpp:80
static constexpr bool gpu_computable
Indicates if the operator can be computed on GPU.
Definition: sqrt.hpp:35
decltype(auto) force_temporary_gpu_dim_only(E &&expr)
Force a temporary out of the expression, without copying its content.
Definition: temporary.hpp:223
static auto gpu_compute_hint(const X &x, Y &y) noexcept
Compute the result of the operation using the GPU.
Definition: sqrt.hpp:168
Unary operation taking the square root of each value.
Definition: sqrt.hpp:19
static constexpr bool linear
Indicates if the operator is linear.
Definition: sqrt.hpp:20
constexpr size_t size(const E &expr) noexcept
Returns the size of the given ETL expression.
Definition: helpers.hpp:108
static constexpr T apply(const T &x)
Apply the unary operator on x.
Definition: sqrt.hpp:57
decltype(auto) smart_gpu_compute_hint(E &expr, Y &y)
Compute the expression into a representation that is up to date on the GPU.
Definition: helpers.hpp:368
static constexpr int complexity()
Estimate the complexity of the operator.
Definition: sqrt.hpp:147
static std::string desc() noexcept
Returns a textual representation of the operator.
Definition: sqrt.hpp:201