Expression Templates Library (ETL)
sqrt.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
13 #pragma once
14 
15 #ifdef ETL_EGBLAS_MODE
16 
17 #include "etl/impl/cublas/cuda.hpp"
18 
19 #include <egblas.hpp>
20 
21 #endif
22 
23 namespace etl::impl::egblas {
24 
/*!
 * \brief Compile-time flag telling whether the single-precision EGBLAS
 * sqrt kernel (egblas_ssqrt) is available, i.e. whether EGBLAS_HAS_SSQRT
 * is defined.
 *
 * Declared `inline constexpr` (C++17) rather than `static constexpr` so that
 * all translation units including this header share a single definition
 * instead of each getting a private internal-linkage copy.
 */
#ifdef EGBLAS_HAS_SSQRT
inline constexpr bool has_ssqrt = true;
#else
inline constexpr bool has_ssqrt = false;
#endif
33 
43 inline void sqrt([[maybe_unused]] size_t n,
44  [[maybe_unused]] float alpha,
45  [[maybe_unused]] float* A,
46  [[maybe_unused]] size_t lda,
47  [[maybe_unused]] float* B,
48  [[maybe_unused]] size_t ldb) {
49 #ifdef EGBLAS_HAS_SSQRT
50  inc_counter("egblas");
51  egblas_ssqrt(n, alpha, A, lda, B, ldb);
52 #else
53  cpp_unreachable("Invalid call to egblas::sqrt");
54 #endif
55 }
56 
/*!
 * \brief Compile-time flag telling whether the double-precision EGBLAS
 * sqrt kernel (egblas_dsqrt) is available, i.e. whether EGBLAS_HAS_DSQRT
 * is defined.
 *
 * Declared `inline constexpr` (C++17) rather than `static constexpr` so that
 * all translation units including this header share a single definition
 * instead of each getting a private internal-linkage copy.
 */
#ifdef EGBLAS_HAS_DSQRT
inline constexpr bool has_dsqrt = true;
#else
inline constexpr bool has_dsqrt = false;
#endif
65 
75 inline void sqrt([[maybe_unused]] size_t n,
76  [[maybe_unused]] double alpha,
77  [[maybe_unused]] double* A,
78  [[maybe_unused]] size_t lda,
79  [[maybe_unused]] double* B,
80  [[maybe_unused]] size_t ldb) {
81 #ifdef EGBLAS_HAS_DSQRT
82  inc_counter("egblas");
83  egblas_dsqrt(n, alpha, A, lda, B, ldb);
84 #else
85  cpp_unreachable("Invalid call to egblas::sqrt");
86 #endif
87 }
88 
/*!
 * \brief Compile-time flag telling whether the single-precision complex
 * EGBLAS sqrt kernel (egblas_csqrt) is available, i.e. whether
 * EGBLAS_HAS_CSQRT is defined.
 *
 * Declared `inline constexpr` (C++17) rather than `static constexpr` so that
 * all translation units including this header share a single definition
 * instead of each getting a private internal-linkage copy.
 */
#ifdef EGBLAS_HAS_CSQRT
inline constexpr bool has_csqrt = true;
#else
inline constexpr bool has_csqrt = false;
#endif
98 
108 inline void sqrt([[maybe_unused]] size_t n,
109  [[maybe_unused]] std::complex<float> alpha,
110  [[maybe_unused]] std::complex<float>* A,
111  [[maybe_unused]] size_t lda,
112  [[maybe_unused]] std::complex<float>* B,
113  [[maybe_unused]] size_t ldb) {
114 #ifdef EGBLAS_HAS_CSQRT
115  inc_counter("egblas");
116  egblas_csqrt(n, complex_cast(alpha), reinterpret_cast<cuComplex*>(A), lda, reinterpret_cast<cuComplex*>(B), ldb);
117 #else
118  cpp_unreachable("Invalid call to egblas::sqrt");
119 #endif
120 }
121 
131 inline void sqrt([[maybe_unused]] size_t n,
132  [[maybe_unused]] etl::complex<float> alpha,
133  [[maybe_unused]] etl::complex<float>* A,
134  [[maybe_unused]] size_t lda,
135  [[maybe_unused]] etl::complex<float>* B,
136  [[maybe_unused]] size_t ldb) {
137 #ifdef EGBLAS_HAS_CSQRT
138  inc_counter("egblas");
139  egblas_csqrt(n, complex_cast(alpha), reinterpret_cast<cuComplex*>(A), lda, reinterpret_cast<cuComplex*>(B), ldb);
140 #else
141  cpp_unreachable("Invalid call to egblas::sqrt");
142 #endif
143 }
144 
/*!
 * \brief Compile-time flag telling whether the double-precision complex
 * EGBLAS sqrt kernel (egblas_zsqrt) is available, i.e. whether
 * EGBLAS_HAS_ZSQRT is defined.
 *
 * Declared `inline constexpr` (C++17) rather than `static constexpr` so that
 * all translation units including this header share a single definition
 * instead of each getting a private internal-linkage copy.
 */
#ifdef EGBLAS_HAS_ZSQRT
inline constexpr bool has_zsqrt = true;
#else
inline constexpr bool has_zsqrt = false;
#endif
154 
164 inline void sqrt([[maybe_unused]] size_t n,
165  [[maybe_unused]] std::complex<double> alpha,
166  [[maybe_unused]] std::complex<double>* A,
167  [[maybe_unused]] size_t lda,
168  [[maybe_unused]] std::complex<double>* B,
169  [[maybe_unused]] size_t ldb) {
170 #ifdef EGBLAS_HAS_ZSQRT
171  inc_counter("egblas");
172  egblas_zsqrt(n, complex_cast(alpha), reinterpret_cast<cuDoubleComplex*>(A), lda, reinterpret_cast<cuDoubleComplex*>(B), ldb);
173 #else
174  cpp_unreachable("Invalid call to egblas::sqrt");
175 #endif
176 }
177 
187 inline void sqrt([[maybe_unused]] size_t n,
188  [[maybe_unused]] etl::complex<double> alpha,
189  [[maybe_unused]] etl::complex<double>* A,
190  [[maybe_unused]] size_t lda,
191  [[maybe_unused]] etl::complex<double>* B,
192  [[maybe_unused]] size_t ldb) {
193 #ifdef EGBLAS_HAS_ZSQRT
194  inc_counter("egblas");
195  egblas_zsqrt(n, complex_cast(alpha), reinterpret_cast<cuDoubleComplex*>(A), lda, reinterpret_cast<cuDoubleComplex*>(B), ldb);
196 #else
197  cpp_unreachable("Invalid call to egblas::sqrt");
198 #endif
199 }
200 
201 } //end of namespace etl::impl::egblas
Complex number implementation.
Definition: complex.hpp:31
auto sqrt(E &&value) -> detail::unary_helper< E, sqrt_unary_op >
Apply square root on each value of the given expression.
Definition: function_expression_builder.hpp:24
Definition: abs.hpp:23
void inc_counter([[maybe_unused]] const char *name)
Increase the given counter.
Definition: counters.hpp:25