Expression Templates Library (ETL)
cce.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
13 #pragma once
14 
15 #ifdef ETL_EGBLAS_MODE
16 
17 #include "etl/impl/cublas/cuda.hpp"
18 
19 #include <egblas.hpp>
20 
21 #endif
22 
23 namespace etl::impl::egblas {
24 
28 #ifdef EGBLAS_HAS_CCE_SLOSS
29 static constexpr bool has_cce_sloss = true;
30 #else
31 static constexpr bool has_cce_sloss = false;
32 #endif
33 
43 inline float cce_loss([[maybe_unused]] size_t n,
44  [[maybe_unused]] float alpha,
45  [[maybe_unused]] float* A,
46  [[maybe_unused]] size_t lda,
47  [[maybe_unused]] float* B,
48  [[maybe_unused]] size_t ldb) {
49 #ifdef EGBLAS_HAS_CCE_SLOSS
50  inc_counter("egblas");
51  return egblas_cce_sloss(n, alpha, A, lda, B, ldb);
52 #else
53  cpp_unreachable("Invalid call to egblas::cce_loss");
54 
55  return 0.0;
56 #endif
57 }
58 
62 #ifdef EGBLAS_HAS_CCE_DLOSS
63 static constexpr bool has_cce_dloss = true;
64 #else
65 static constexpr bool has_cce_dloss = false;
66 #endif
67 
77 inline double cce_loss([[maybe_unused]] size_t n,
78  [[maybe_unused]] double alpha,
79  [[maybe_unused]] double* A,
80  [[maybe_unused]] size_t lda,
81  [[maybe_unused]] double* B,
82  [[maybe_unused]] size_t ldb) {
83 #ifdef EGBLAS_HAS_CCE_DLOSS
84  inc_counter("egblas");
85  return egblas_cce_dloss(n, alpha, A, lda, B, ldb);
86 #else
87  cpp_unreachable("Invalid call to egblas::cce_loss");
88 
89  return 0.0;
90 #endif
91 }
92 
96 #ifdef EGBLAS_HAS_CCE_SERROR
97 static constexpr bool has_cce_serror = true;
98 #else
99 static constexpr bool has_cce_serror = false;
100 #endif
101 
111 inline float cce_error(
112  [[maybe_unused]] size_t n, [[maybe_unused]] size_t m, [[maybe_unused]] float alpha, [[maybe_unused]] float* A, [[maybe_unused]] float* B) {
113 #ifdef EGBLAS_HAS_CCE_SERROR
114  inc_counter("egblas");
115  return egblas_cce_serror(n, m, alpha, A, B);
116 #else
117  cpp_unreachable("Invalid call to egblas::cce_error");
118 
119  return 0.0;
120 #endif
121 }
122 
126 #ifdef EGBLAS_HAS_CCE_DERROR
127 static constexpr bool has_cce_derror = true;
128 #else
129 static constexpr bool has_cce_derror = false;
130 #endif
131 
141 inline double cce_error(
142  [[maybe_unused]] size_t n, [[maybe_unused]] size_t m, [[maybe_unused]] double alpha, [[maybe_unused]] double* A, [[maybe_unused]] double* B) {
143 #ifdef EGBLAS_HAS_CCE_DERROR
144  inc_counter("egblas");
145  return egblas_cce_derror(n, m, alpha, A, B);
146 #else
147  cpp_unreachable("Invalid call to egblas::cce_error");
148 
149  return 0.0;
150 #endif
151 }
152 
156 #ifdef EGBLAS_HAS_SCCE
157 static constexpr bool has_scce = true;
158 #else
159 static constexpr bool has_scce = false;
160 #endif
161 
171 inline std::pair<float, float> cce([[maybe_unused]] size_t n,
172  [[maybe_unused]] size_t m,
173  [[maybe_unused]] float alpha,
174  [[maybe_unused]] float beta,
175  [[maybe_unused]] float* A,
176  [[maybe_unused]] float* B) {
177 #ifdef EGBLAS_HAS_SCCE
178  inc_counter("egblas");
179  return egblas_scce(n, m, alpha, beta, A, B);
180 #else
181  cpp_unreachable("Invalid call to egblas::cce");
182 
183  return std::make_pair(0.0f, 0.0f);
184 #endif
185 }
186 
190 #ifdef EGBLAS_HAS_DCCE
191 static constexpr bool has_dcce = true;
192 #else
193 static constexpr bool has_dcce = false;
194 #endif
195 
205 inline std::pair<double, double> cce([[maybe_unused]] size_t n,
206  [[maybe_unused]] size_t m,
207  [[maybe_unused]] double alpha,
208  [[maybe_unused]] double beta,
209  [[maybe_unused]] double* A,
210  [[maybe_unused]] double* B) {
211 #ifdef EGBLAS_HAS_DCCE
212  inc_counter("egblas");
213  return egblas_dcce(n, m, alpha, beta, A, B);
214 #else
215  cpp_unreachable("Invalid call to egblas::cce");
216 
217  return std::make_pair(0.0, 0.0);
218 #endif
219 }
220 
221 } //end of namespace etl::impl::egblas
Definition: abs.hpp:23
void inc_counter([[maybe_unused]] const char *name)
Increase the given counter.
Definition: counters.hpp:25