15 #ifdef ETL_CUBLAS_MODE 17 #include "etl/impl/cublas/cuda.hpp" 24 #ifdef ETL_CUBLAS_MODE 32 template <etl_dma_single_precision A, etl_dma_single_precision B>
33 float dot(
const A& a,
const B& b) {
34 decltype(
auto) handle = start_cublas();
37 b.ensure_gpu_up_to_date();
40 cublas_check(cublasSdot(handle.get(),
etl::size(a), a.
gpu_memory(), 1, b.gpu_memory(), 1, &prod));
47 template <etl_dma_double_precision A, etl_dma_double_precision B>
48 double dot(const A& a, const B& b) {
49 decltype(
auto) handle = start_cublas();
51 a.ensure_gpu_up_to_date();
52 b.ensure_gpu_up_to_date();
55 cublas_check(cublasDdot(handle.get(),
etl::size(a), a.gpu_memory(), 1, b.gpu_memory(), 1, &prod));
64 template <
typename A,
typename B>
66 cpp_unreachable(
"CUBLAS not enabled/available");
value_t< A > dot(const A &a, const B &b)
Returns the dot product of the two given expressions.
Definition: expression_builder.hpp:594
Root namespace for the ETL library.
Definition: adapter.hpp:15
Utility functions for cublas.
void ensure_gpu_up_to_date() const
Copy from the expression memory to the GPU if necessary, so that the GPU copy is up to date.
Definition: dyn_matrix_view.hpp:280
typename decay_traits< E >::value_type value_t
Traits to extract the value type out of an ETL type.
Definition: tmp.hpp:81
value_type * gpu_memory() const noexcept
Return GPU memory of this expression, if any.
Definition: sub_view.hpp:674