Expression Templates Library (ETL)
mse.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
13 #pragma once
14 
15 #ifdef ETL_EGBLAS_MODE
16 
17 #include "etl/impl/cublas/cuda.hpp"
18 
19 #include <egblas.hpp>
20 
21 #endif
22 
23 namespace etl::impl::egblas {
24 
28 #ifdef EGBLAS_HAS_MSE_SLOSS
29 static constexpr bool has_mse_sloss = true;
30 #else
31 static constexpr bool has_mse_sloss = false;
32 #endif
33 
43 inline float mse_loss([[maybe_unused]] size_t n,
44  [[maybe_unused]] float alpha,
45  [[maybe_unused]] const float* A,
46  [[maybe_unused]] size_t lda,
47  [[maybe_unused]] const float* B,
48  [[maybe_unused]] size_t ldb) {
49 #ifdef EGBLAS_HAS_MSE_SLOSS
50  inc_counter("egblas");
51  return egblas_mse_sloss(n, alpha, A, lda, B, ldb);
52 #else
53  cpp_unreachable("Invalid call to egblas::mse_loss");
54 
55  return 0.0;
56 #endif
57 }
58 
62 #ifdef EGBLAS_HAS_MSE_DLOSS
63 static constexpr bool has_mse_dloss = true;
64 #else
65 static constexpr bool has_mse_dloss = false;
66 #endif
67 
77 inline double mse_loss([[maybe_unused]] size_t n,
78  [[maybe_unused]] double alpha,
79  [[maybe_unused]] const double* A,
80  [[maybe_unused]] size_t lda,
81  [[maybe_unused]] const double* B,
82  [[maybe_unused]] size_t ldb) {
83 #ifdef EGBLAS_HAS_MSE_DLOSS
84  inc_counter("egblas");
85  return egblas_mse_dloss(n, alpha, A, lda, B, ldb);
86 #else
87  cpp_unreachable("Invalid call to egblas::mse_loss");
88 
89  return 0.0;
90 #endif
91 }
92 
96 #ifdef EGBLAS_HAS_MSE_SERROR
97 static constexpr bool has_mse_serror = true;
98 #else
99 static constexpr bool has_mse_serror = false;
100 #endif
101 
111 inline float mse_error([[maybe_unused]] size_t n,
112  [[maybe_unused]] float alpha,
113  [[maybe_unused]] const float* A,
114  [[maybe_unused]] size_t lda,
115  [[maybe_unused]] const float* B,
116  [[maybe_unused]] size_t ldb) {
117 #ifdef EGBLAS_HAS_MSE_SERROR
118  inc_counter("egblas");
119  return egblas_mse_serror(n, alpha, A, lda, B, ldb);
120 #else
121  cpp_unreachable("Invalid call to egblas::mse_error");
122 
123  return 0.0;
124 #endif
125 }
126 
130 #ifdef EGBLAS_HAS_MSE_DERROR
131 static constexpr bool has_mse_derror = true;
132 #else
133 static constexpr bool has_mse_derror = false;
134 #endif
135 
145 inline double mse_error([[maybe_unused]] size_t n,
146  [[maybe_unused]] double alpha,
147  [[maybe_unused]] const double* A,
148  [[maybe_unused]] size_t lda,
149  [[maybe_unused]] const double* B,
150  [[maybe_unused]] size_t ldb) {
151 #ifdef EGBLAS_HAS_MSE_DERROR
152  inc_counter("egblas");
153  return egblas_mse_derror(n, alpha, A, lda, B, ldb);
154 #else
155  cpp_unreachable("Invalid call to egblas::mse_error");
156 
157  return 0.0;
158 #endif
159 }
160 
164 #ifdef EGBLAS_HAS_SMSE
165 static constexpr bool has_smse = true;
166 #else
167 static constexpr bool has_smse = false;
168 #endif
169 
179 inline std::pair<float, float> mse([[maybe_unused]] size_t n,
180  [[maybe_unused]] float alpha,
181  [[maybe_unused]] float beta,
182  [[maybe_unused]] const float* A,
183  [[maybe_unused]] size_t lda,
184  [[maybe_unused]] const float* B,
185  [[maybe_unused]] size_t ldb) {
186 #ifdef EGBLAS_HAS_SMSE
187  inc_counter("egblas");
188  return egblas_smse(n, alpha, beta, A, lda, B, ldb);
189 #else
190  cpp_unreachable("Invalid call to egblas::mse");
191 
192  return std::make_pair(0.0f, 0.0f);
193 #endif
194 }
195 
199 #ifdef EGBLAS_HAS_DMSE
200 static constexpr bool has_dmse = true;
201 #else
202 static constexpr bool has_dmse = false;
203 #endif
204 
214 inline std::pair<double, double> mse([[maybe_unused]] size_t n,
215  [[maybe_unused]] double alpha,
216  [[maybe_unused]] double beta,
217  [[maybe_unused]] const double* A,
218  [[maybe_unused]] size_t lda,
219  [[maybe_unused]] const double* B,
220  [[maybe_unused]] size_t ldb) {
221 #ifdef EGBLAS_HAS_DMSE
222  inc_counter("egblas");
223  return egblas_dmse(n, alpha, beta, A, lda, B, ldb);
224 #else
225  cpp_unreachable("Invalid call to egblas::mse");
226 
227  return std::make_pair(0.0, 0.0);
228 #endif
229 }
230 
231 } //end of namespace etl::impl::egblas
Definition: abs.hpp:23
void inc_counter([[maybe_unused]] const char *name)
Increase the given counter.
Definition: counters.hpp:25