Expression Templates Library (ETL)
dot.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
13 #pragma once
14 
15 #ifdef ETL_CUBLAS_MODE
16 
17 #include "etl/impl/cublas/cuda.hpp"
19 
20 #endif
21 
22 namespace etl::impl::cublas {
23 
24 #ifdef ETL_CUBLAS_MODE
25 
32 template <etl_dma_single_precision A, etl_dma_single_precision B>
33 float dot(const A& a, const B& b) {
34  decltype(auto) handle = start_cublas();
35 
37  b.ensure_gpu_up_to_date();
38 
39  float prod = 0.0;
40  cublas_check(cublasSdot(handle.get(), etl::size(a), a.gpu_memory(), 1, b.gpu_memory(), 1, &prod));
41  return prod;
42 }
43 
47 template <etl_dma_double_precision A, etl_dma_double_precision B>
48 double dot(const A& a, const B& b) {
49  decltype(auto) handle = start_cublas();
50 
51  a.ensure_gpu_up_to_date();
52  b.ensure_gpu_up_to_date();
53 
54  double prod = 0.0;
55  cublas_check(cublasDdot(handle.get(), etl::size(a), a.gpu_memory(), 1, b.gpu_memory(), 1, &prod));
56  return prod;
57 }
58 
59 #else
60 
64 template <typename A, typename B>
65 value_t<A> dot(const A& /*a*/, const B& /*b*/) {
66  cpp_unreachable("CUBLAS not enabled/available");
67  return 0.0;
68 }
69 
70 #endif
71 
72 } //end of namespace etl::impl::cublas
value_t< A > dot(const A &a, const B &b)
Returns the dot product of the two given expressions.
Definition: expression_builder.hpp:594
Definition: axpy.hpp:22
Root namespace for the ETL library.
Definition: adapter.hpp:15
Utility functions for cublas.
void ensure_gpu_up_to_date() const
Ensure the GPU copy of the expression is up to date, copying from the expression memory to the GPU if necessary.
Definition: dyn_matrix_view.hpp:280
typename decay_traits< E >::value_type value_t
Traits to extract the value type out of an ETL type.
Definition: tmp.hpp:81
value_type * gpu_memory() const noexcept
Return GPU memory of this expression, if any.
Definition: sub_view.hpp:674