Expression Templates Library (ETL)
outer.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
13 #pragma once
14 
15 #ifdef ETL_CUBLAS_MODE
16 
17 #include "etl/impl/cublas/cuda.hpp"
19 
20 #endif
21 
22 namespace etl::impl::cublas {
23 
24 #ifdef ETL_CUBLAS_MODE
25 
32 template <etl_single_precision A, etl_single_precision B, etl_single_precision C>
33 void batch_outer(const A& a, const B& b, C&& c) {
34  decltype(auto) handle = start_cublas();
35 
36  float alpha = 1.0;
37  float beta = 0.0;
38 
39  // This is brain-killing :s
40  // CUBLAS need matrices in column-major order. By switching both
41  // matrices, this is achieved. However, since one of the matrix
42  // needs to be transposed, it must be changed again
43 
45  b.ensure_gpu_up_to_date();
47 
48  cublas_check(cublasSgemm(handle.get(), CUBLAS_OP_N, CUBLAS_OP_T, etl::columns(c), etl::rows(c), etl::rows(b), &alpha, b.gpu_memory(), etl::columns(b),
49  a.gpu_memory(), etl::columns(a), &beta, c.gpu_memory(), etl::columns(b)));
50 
51  c.validate_gpu();
52  c.invalidate_cpu();
53 }
54 
58 template <etl_double_precision A, etl_double_precision B, etl_double_precision C>
59 void batch_outer(const A& a, const B& b, C&& c) {
60  decltype(auto) handle = start_cublas();
61 
62  double alpha = 1.0;
63  double beta = 0.0;
64 
65  a.ensure_gpu_up_to_date();
66  b.ensure_gpu_up_to_date();
67  c.ensure_gpu_allocated();
68 
69  cublas_check(cublasDgemm(handle.get(), CUBLAS_OP_N, CUBLAS_OP_T, etl::columns(c), etl::rows(c), etl::rows(b), &alpha, b.gpu_memory(), etl::columns(b),
70  a.gpu_memory(), etl::columns(a), &beta, c.gpu_memory(), etl::columns(b)));
71 
72  c.validate_gpu();
73  c.invalidate_cpu();
74 }
75 
76 #else
77 
81 template <typename A, typename B, typename C>
82 void batch_outer(const A& /*a*/, const B& /*b*/, C&& /*c*/) {
83  cpp_unreachable("CUBLAS not enabled/available");
84 }
85 
86 #endif
87 
88 } //end of namespace etl::impl::cublas
void ensure_gpu_allocated() const
Ensures that the GPU memory is allocated and that the GPU memory is up to date (to undefined value)...
Definition: sub_view.hpp:717
Definition: axpy.hpp:22
batch_outer_product_expr< detail::build_type< A >, detail::build_type< B > > batch_outer(A &&a, B &&b)
Batch Outer product multiplication of two matrices.
Definition: batch_outer_product_expr.hpp:333
Root namespace for the ETL library.
Definition: adapter.hpp:15
Utility functions for cublas.
size_t columns(const E &expr)
Returns the number of columns of the given ETL expression.
Definition: helpers.hpp:78
void invalidate_cpu() const noexcept
Invalidates the CPU memory.
Definition: sub_view.hpp:688
void ensure_gpu_up_to_date() const
Copy back from the GPU to the expression memory if necessary.
Definition: dyn_matrix_view.hpp:280
void validate_gpu() const noexcept
Validates the GPU memory.
Definition: sub_view.hpp:709
size_t rows(const E &expr)
Returns the number of rows of the given ETL expression.
Definition: helpers.hpp:58
value_type * gpu_memory() const noexcept
Return GPU memory of this expression, if any.
Definition: sub_view.hpp:674