Expression Templates Library (ETL)
one_if_max_sub.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
13 #pragma once
14 
15 #ifdef ETL_EGBLAS_MODE
16 
17 #include "etl/impl/cublas/cuda.hpp"
18 
19 #include <egblas.hpp>
20 
21 #endif
22 
23 namespace etl::impl::egblas {
24 
25 #ifdef EGBLAS_HAS_SONE_IF_MAX_SUB
26 static constexpr bool has_sone_if_max_sub = true;
27 #else
28 static constexpr bool has_sone_if_max_sub = false;
29 #endif
30 
40 inline void one_if_max_sub([[maybe_unused]] size_t b,
41  [[maybe_unused]] size_t n,
42  [[maybe_unused]] float alpha,
43  [[maybe_unused]] float* A,
44  [[maybe_unused]] size_t lda,
45  [[maybe_unused]] float* B,
46  [[maybe_unused]] size_t ldb) {
47 #ifdef EGBLAS_HAS_SONE_IF_MAX_SUB
48  inc_counter("egblas");
49  egblas_sone_if_max_sub(b, n, alpha, A, lda, B, ldb);
50 #else
51  cpp_unreachable("Invalid call to egblas::one_if_max_sub");
52 #endif
53 }
54 
55 #ifdef EGBLAS_HAS_DONE_IF_MAX_SUB
56 static constexpr bool has_done_if_max_sub = true;
57 #else
58 static constexpr bool has_done_if_max_sub = false;
59 #endif
60 
70 inline void one_if_max_sub([[maybe_unused]] size_t b,
71  [[maybe_unused]] size_t n,
72  [[maybe_unused]] double alpha,
73  [[maybe_unused]] double* A,
74  [[maybe_unused]] size_t lda,
75  [[maybe_unused]] double* B,
76  [[maybe_unused]] size_t ldb) {
77 #ifdef EGBLAS_HAS_DONE_IF_MAX_SUB
78  inc_counter("egblas");
79  egblas_done_if_max_sub(b, n, alpha, A, lda, B, ldb);
80 #else
81  cpp_unreachable("Invalid call to egblas::one_if_max_sub");
82 #endif
83 }
84 
85 } //end of namespace etl::impl::egblas
Definition: abs.hpp:23
auto one_if_max_sub(const E &value)
Return, for each original position, 1.0 if the value is the maximum of its sub-matrix, 0.0 otherwise.
Definition: expression_builder.hpp:488
void inc_counter([[maybe_unused]] const char *name)
Increase the given counter.
Definition: counters.hpp:25