Expression Templates Library (ETL)
axpy.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
13 #pragma once
14 
#ifdef ETL_CUBLAS_MODE

#include <limits>

#include "etl/impl/cublas/cuda.hpp"

#endif
21 
namespace etl::impl::cublas {

#ifdef ETL_CUBLAS_MODE

/*!
 * \brief Verify that the length and the strides of an axpy call can be
 * forwarded to the 32-bit `int` parameters of the cuBLAS API.
 *
 * cuBLAS takes `int` for the vector length and the strides. Converting a larger
 * size_t would silently truncate the value and make cuBLAS process the wrong
 * number of elements. When a value does not fit, CUBLAS_STATUS_INVALID_VALUE is
 * reported through cublas_check.
 *
 * \param n The number of elements
 * \param lda The stride between consecutive elements of A
 * \param ldb The stride between consecutive elements of B
 * \return true if all three values fit in an int, false otherwise
 */
inline bool cublas_axpy_valid_dims(size_t n, size_t lda, size_t ldb) {
    constexpr auto int_max = static_cast<size_t>(std::numeric_limits<int>::max());

    if (n <= int_max && lda <= int_max && ldb <= int_max) {
        return true;
    }

    // Route the failure through the same error-reporting path as real cuBLAS errors
    cublas_check(CUBLAS_STATUS_INVALID_VALUE);
    return false;
}

/*!
 * \brief Compute B = alpha * A + B on single-precision vectors (cublasSaxpy).
 *
 * \param handle The cuBLAS handle
 * \param n The number of elements to process
 * \param alpha Pointer to the scaling factor. It must be a host or device pointer according to the handle's pointer mode.
 * \param A The input vector
 * \param lda The stride between consecutive elements of A (forwarded as incx)
 * \param B The input/output vector
 * \param ldb The stride between consecutive elements of B (forwarded as incy)
 */
inline void cublas_axpy(cublasHandle_t handle, size_t n, const float* alpha, const float* A, size_t lda, float* B, size_t ldb) {
    if (!cublas_axpy_valid_dims(n, lda, ldb)) {
        return;
    }

    cublas_check(cublasSaxpy(handle, static_cast<int>(n), alpha, A, static_cast<int>(lda), B, static_cast<int>(ldb)));
}

/*!
 * \brief Compute B = alpha * A + B on double-precision vectors (cublasDaxpy).
 *
 * \param handle The cuBLAS handle
 * \param n The number of elements to process
 * \param alpha Pointer to the scaling factor. It must be a host or device pointer according to the handle's pointer mode.
 * \param A The input vector
 * \param lda The stride between consecutive elements of A (forwarded as incx)
 * \param B The input/output vector
 * \param ldb The stride between consecutive elements of B (forwarded as incy)
 */
inline void cublas_axpy(cublasHandle_t handle, size_t n, const double* alpha, const double* A, size_t lda, double* B, size_t ldb) {
    if (!cublas_axpy_valid_dims(n, lda, ldb)) {
        return;
    }

    cublas_check(cublasDaxpy(handle, static_cast<int>(n), alpha, A, static_cast<int>(lda), B, static_cast<int>(ldb)));
}

/*!
 * \brief Compute B = alpha * A + B on single-precision std::complex vectors (cublasCaxpy).
 *
 * The reinterpret_cast relies on std::complex<float> being laid out as two
 * contiguous floats (real, imaginary), the same layout as cuComplex.
 *
 * \param handle The cuBLAS handle
 * \param n The number of elements to process
 * \param alpha Pointer to the scaling factor. It must be a host or device pointer according to the handle's pointer mode.
 * \param A The input vector
 * \param lda The stride between consecutive elements of A (forwarded as incx)
 * \param B The input/output vector
 * \param ldb The stride between consecutive elements of B (forwarded as incy)
 */
inline void cublas_axpy(
    cublasHandle_t handle, size_t n, const std::complex<float>* alpha, const std::complex<float>* A, size_t lda, std::complex<float>* B, size_t ldb) {
    if (!cublas_axpy_valid_dims(n, lda, ldb)) {
        return;
    }

    cublas_check(cublasCaxpy(handle, static_cast<int>(n), reinterpret_cast<const cuComplex*>(alpha), reinterpret_cast<const cuComplex*>(A),
                             static_cast<int>(lda), reinterpret_cast<cuComplex*>(B), static_cast<int>(ldb)));
}

/*!
 * \brief Compute B = alpha * A + B on double-precision std::complex vectors (cublasZaxpy).
 *
 * The reinterpret_cast relies on std::complex<double> being laid out as two
 * contiguous doubles (real, imaginary), the same layout as cuDoubleComplex.
 *
 * \param handle The cuBLAS handle
 * \param n The number of elements to process
 * \param alpha Pointer to the scaling factor. It must be a host or device pointer according to the handle's pointer mode.
 * \param A The input vector
 * \param lda The stride between consecutive elements of A (forwarded as incx)
 * \param B The input/output vector
 * \param ldb The stride between consecutive elements of B (forwarded as incy)
 */
inline void cublas_axpy(
    cublasHandle_t handle, size_t n, const std::complex<double>* alpha, const std::complex<double>* A, size_t lda, std::complex<double>* B, size_t ldb) {
    if (!cublas_axpy_valid_dims(n, lda, ldb)) {
        return;
    }

    cublas_check(cublasZaxpy(handle, static_cast<int>(n), reinterpret_cast<const cuDoubleComplex*>(alpha), reinterpret_cast<const cuDoubleComplex*>(A),
                             static_cast<int>(lda), reinterpret_cast<cuDoubleComplex*>(B), static_cast<int>(ldb)));
}

/*!
 * \brief Compute B = alpha * A + B on single-precision etl::complex vectors (cublasCaxpy).
 *
 * NOTE(review): the reinterpret_cast assumes etl::complex<float> stores the real
 * and imaginary parts as two contiguous floats (same layout as cuComplex).
 * TODO confirm against complex.hpp.
 *
 * \param handle The cuBLAS handle
 * \param n The number of elements to process
 * \param alpha Pointer to the scaling factor. It must be a host or device pointer according to the handle's pointer mode.
 * \param A The input vector
 * \param lda The stride between consecutive elements of A (forwarded as incx)
 * \param B The input/output vector
 * \param ldb The stride between consecutive elements of B (forwarded as incy)
 */
inline void cublas_axpy(
    cublasHandle_t handle, size_t n, const etl::complex<float>* alpha, const etl::complex<float>* A, size_t lda, etl::complex<float>* B, size_t ldb) {
    if (!cublas_axpy_valid_dims(n, lda, ldb)) {
        return;
    }

    cublas_check(cublasCaxpy(handle, static_cast<int>(n), reinterpret_cast<const cuComplex*>(alpha), reinterpret_cast<const cuComplex*>(A),
                             static_cast<int>(lda), reinterpret_cast<cuComplex*>(B), static_cast<int>(ldb)));
}

/*!
 * \brief Compute B = alpha * A + B on double-precision etl::complex vectors (cublasZaxpy).
 *
 * NOTE(review): the reinterpret_cast assumes etl::complex<double> stores the real
 * and imaginary parts as two contiguous doubles (same layout as cuDoubleComplex).
 * TODO confirm against complex.hpp.
 *
 * \param handle The cuBLAS handle
 * \param n The number of elements to process
 * \param alpha Pointer to the scaling factor. It must be a host or device pointer according to the handle's pointer mode.
 * \param A The input vector
 * \param lda The stride between consecutive elements of A (forwarded as incx)
 * \param B The input/output vector
 * \param ldb The stride between consecutive elements of B (forwarded as incy)
 */
inline void cublas_axpy(
    cublasHandle_t handle, size_t n, const etl::complex<double>* alpha, const etl::complex<double>* A, size_t lda, etl::complex<double>* B, size_t ldb) {
    if (!cublas_axpy_valid_dims(n, lda, ldb)) {
        return;
    }

    cublas_check(cublasZaxpy(handle, static_cast<int>(n), reinterpret_cast<const cuDoubleComplex*>(alpha), reinterpret_cast<const cuDoubleComplex*>(A),
                             static_cast<int>(lda), reinterpret_cast<cuDoubleComplex*>(B), static_cast<int>(ldb)));
}

#endif

} //end of namespace etl::impl::cublas
Complex number implementation.
Definition: complex.hpp:31
Definition: axpy.hpp:22
Utility functions for cublas.