15 #ifdef ETL_EGBLAS_MODE 17 #include "etl/impl/cublas/cuda.hpp" 30 #ifdef EGBLAS_HAS_DROPOUT_PREPARE 31 static constexpr
bool has_dropout_prepare =
true;
33 static constexpr
bool has_dropout_prepare =
false;
41 #ifdef EGBLAS_HAS_DROPOUT_PREPARE 43 return egblas_dropout_prepare();
45 cpp_unreachable(
"Invalid call to egblas::dropout_prepare");
53 #ifdef EGBLAS_HAS_DROPOUT_PREPARE_SEED 54 static constexpr
bool has_dropout_prepare_seed =
true;
56 static constexpr
bool has_dropout_prepare_seed =
false;
65 #ifdef EGBLAS_HAS_DROPOUT_PREPARE_SEED 67 return egblas_dropout_prepare_seed(seed);
69 cpp_unreachable(
"Invalid call to egblas::dropout_prepare_seed");
77 #ifdef EGBLAS_HAS_DROPOUT_RELEASE 78 static constexpr
bool has_dropout_release =
true;
80 static constexpr
bool has_dropout_release =
false;
89 #ifdef EGBLAS_HAS_DROPOUT_RELEASE 91 egblas_dropout_release(state);
93 cpp_unreachable(
"Invalid call to egblas::dropout_release");
102 #ifdef EGBLAS_HAS_SDROPOUT 103 static constexpr
bool has_sdropout =
true;
105 static constexpr
bool has_sdropout =
false;
116 inline void dropout([[maybe_unused]]
size_t n, [[maybe_unused]]
float p, [[maybe_unused]]
float alpha, [[maybe_unused]]
float* A, [[maybe_unused]]
size_t lda) {
117 #ifdef EGBLAS_HAS_SDROPOUT 119 egblas_sdropout(n, p, alpha, A, lda);
121 cpp_unreachable(
"Invalid call to egblas::dropout");
128 #ifdef EGBLAS_HAS_DDROPOUT 129 static constexpr
bool has_ddropout =
true;
131 static constexpr
bool has_ddropout =
false;
143 [[maybe_unused]]
size_t n, [[maybe_unused]]
double p, [[maybe_unused]]
double alpha, [[maybe_unused]]
double* A, [[maybe_unused]]
size_t lda) {
144 #ifdef EGBLAS_HAS_DDROPOUT 146 egblas_ddropout(n, p, alpha, A, lda);
148 cpp_unreachable(
"Invalid call to egblas::dropout");
157 #ifdef EGBLAS_HAS_SDROPOUT_SEED 158 static constexpr
bool has_sdropout_seed =
true;
160 static constexpr
bool has_sdropout_seed =
false;
172 [[maybe_unused]]
float p,
173 [[maybe_unused]]
float alpha,
174 [[maybe_unused]]
float* A,
175 [[maybe_unused]]
size_t lda,
176 [[maybe_unused]]
size_t seed) {
177 #ifdef EGBLAS_HAS_SDROPOUT_SEED 179 egblas_sdropout_seed(n, p, alpha, A, lda, seed);
181 cpp_unreachable(
"Invalid call to egblas::dropout");
188 #ifdef EGBLAS_HAS_DDROPOUT_SEED 189 static constexpr
bool has_ddropout_seed =
true;
191 static constexpr
bool has_ddropout_seed =
false;
203 [[maybe_unused]]
double p,
204 [[maybe_unused]]
double alpha,
205 [[maybe_unused]]
double* A,
206 [[maybe_unused]]
size_t lda,
207 [[maybe_unused]]
size_t seed) {
208 #ifdef EGBLAS_HAS_DDROPOUT_SEED 210 egblas_ddropout_seed(n, p, alpha, A, lda, seed);
212 cpp_unreachable(
"Invalid call to egblas::dropout");
221 #ifdef EGBLAS_HAS_SDROPOUT_STATES 222 static constexpr
bool has_sdropout_states =
true;
224 static constexpr
bool has_sdropout_states =
false;
236 [[maybe_unused]]
float p,
237 [[maybe_unused]]
float alpha,
238 [[maybe_unused]]
float* A,
239 [[maybe_unused]]
size_t lda,
240 [[maybe_unused]]
void* states) {
241 #ifdef EGBLAS_HAS_SDROPOUT_STATES 243 egblas_sdropout_states(n, p, alpha, A, lda, states);
245 cpp_unreachable(
"Invalid call to egblas::dropout_states");
252 #ifdef EGBLAS_HAS_DDROPOUT_STATES 253 static constexpr
bool has_ddropout_states =
true;
255 static constexpr
bool has_ddropout_states =
false;
267 [[maybe_unused]]
double p,
268 [[maybe_unused]]
double alpha,
269 [[maybe_unused]]
double* A,
270 [[maybe_unused]]
size_t lda,
271 [[maybe_unused]]
void* states) {
272 #ifdef EGBLAS_HAS_DDROPOUT_STATES 274 egblas_ddropout_states(n, p, alpha, A, lda, states);
276 cpp_unreachable(
"Invalid call to egblas::dropout_states");
285 #ifdef EGBLAS_HAS_SINV_DROPOUT 286 static constexpr
bool has_sinv_dropout =
true;
288 static constexpr
bool has_sinv_dropout =
false;
300 [[maybe_unused]]
size_t n, [[maybe_unused]]
float p, [[maybe_unused]]
float alpha, [[maybe_unused]]
float* A, [[maybe_unused]]
size_t lda) {
301 #ifdef EGBLAS_HAS_SINV_DROPOUT 303 egblas_sinv_dropout(n, p, alpha, A, lda);
311 #ifdef EGBLAS_HAS_DINV_DROPOUT 312 static constexpr
bool has_dinv_dropout =
true;
314 static constexpr
bool has_dinv_dropout =
false;
326 [[maybe_unused]]
size_t n, [[maybe_unused]]
double p, [[maybe_unused]]
double alpha, [[maybe_unused]]
double* A, [[maybe_unused]]
size_t lda) {
327 #ifdef EGBLAS_HAS_DINV_DROPOUT 329 egblas_dinv_dropout(n, p, alpha, A, lda);
339 #ifdef EGBLAS_HAS_SINV_DROPOUT_SEED 340 static constexpr
bool has_sinv_dropout_seed =
true;
342 static constexpr
bool has_sinv_dropout_seed =
false;
354 [[maybe_unused]]
float p,
355 [[maybe_unused]]
float alpha,
356 [[maybe_unused]]
float* A,
357 [[maybe_unused]]
size_t lda,
358 [[maybe_unused]]
size_t seed) {
359 #ifdef EGBLAS_HAS_SINV_DROPOUT_SEED 361 egblas_sinv_dropout_seed(n, p, alpha, A, lda, seed);
363 cpp_unreachable(
"Invalid call to egblas::inv_dropout");
370 #ifdef EGBLAS_HAS_DINV_DROPOUT_SEED 371 static constexpr
bool has_dinv_dropout_seed =
true;
373 static constexpr
bool has_dinv_dropout_seed =
false;
385 [[maybe_unused]]
double p,
386 [[maybe_unused]]
double alpha,
387 [[maybe_unused]]
double* A,
388 [[maybe_unused]]
size_t lda,
389 [[maybe_unused]]
size_t seed) {
390 #ifdef EGBLAS_HAS_DINV_DROPOUT_SEED 392 egblas_dinv_dropout_seed(n, p, alpha, A, lda, seed);
394 cpp_unreachable(
"Invalid call to egblas::inv_dropout");
403 #ifdef EGBLAS_HAS_SINV_DROPOUT_STATES 404 static constexpr
bool has_sinv_dropout_states =
true;
406 static constexpr
bool has_sinv_dropout_states =
false;
418 [[maybe_unused]]
float p,
419 [[maybe_unused]]
float alpha,
420 [[maybe_unused]]
float* A,
421 [[maybe_unused]]
size_t lda,
422 [[maybe_unused]]
void* states) {
423 #ifdef EGBLAS_HAS_SINV_DROPOUT_STATES 425 egblas_sinv_dropout_states(n, p, alpha, A, lda, states);
427 cpp_unreachable(
"Invalid call to egblas::inv_dropout_states");
434 #ifdef EGBLAS_HAS_DINV_DROPOUT_STATES 435 static constexpr
bool has_dinv_dropout_states =
true;
437 static constexpr
bool has_dinv_dropout_states =
false;
449 [[maybe_unused]]
double p,
450 [[maybe_unused]]
double alpha,
451 [[maybe_unused]]
double* A,
452 [[maybe_unused]]
size_t lda,
453 [[maybe_unused]]
void* states) {
454 #ifdef EGBLAS_HAS_DINV_DROPOUT_STATES 456 egblas_dinv_dropout_states(n, p, alpha, A, lda, states);
458 cpp_unreachable(
"Invalid call to egblas::inv_dropout_states");
void dropout([[maybe_unused]] size_t n, [[maybe_unused]] float p, [[maybe_unused]] float alpha, [[maybe_unused]] float *A, [[maybe_unused]] size_t lda)
Wrapper for single-precision egblas dropout.
Definition: dropout.hpp:116
void dropout_seed([[maybe_unused]] size_t n, [[maybe_unused]] float p, [[maybe_unused]] float alpha, [[maybe_unused]] float *A, [[maybe_unused]] size_t lda, [[maybe_unused]] size_t seed)
Wrapper for single-precision egblas dropout with an explicit seed.
Definition: dropout.hpp:171
void * dropout_prepare_seed([[maybe_unused]] size_t seed)
Prepare random states for dropout with the given seed.
Definition: dropout.hpp:64
void inv_dropout([[maybe_unused]] size_t n, [[maybe_unused]] float p, [[maybe_unused]] float alpha, [[maybe_unused]] float *A, [[maybe_unused]] size_t lda)
Wrapper for single-precision egblas inv_dropout.
Definition: dropout.hpp:299
void dropout_release([[maybe_unused]] void *state)
Release random states previously prepared for dropout.
Definition: dropout.hpp:88
void dropout_states([[maybe_unused]] size_t n, [[maybe_unused]] float p, [[maybe_unused]] float alpha, [[maybe_unused]] float *A, [[maybe_unused]] size_t lda, [[maybe_unused]] void *states)
Wrapper for single-precision egblas dropout using the given random states.
Definition: dropout.hpp:235
void * dropout_prepare()
Prepare random states for dropout.
Definition: dropout.hpp:40
void inv_dropout_states([[maybe_unused]] size_t n, [[maybe_unused]] float p, [[maybe_unused]] float alpha, [[maybe_unused]] float *A, [[maybe_unused]] size_t lda, [[maybe_unused]] void *states)
Wrapper for single-precision egblas inv_dropout using the given random states.
Definition: dropout.hpp:417
void inc_counter([[maybe_unused]] const char *name)
Increase the given counter.
Definition: counters.hpp:25
void inv_dropout_seed([[maybe_unused]] size_t n, [[maybe_unused]] float p, [[maybe_unused]] float alpha, [[maybe_unused]] float *A, [[maybe_unused]] size_t lda, [[maybe_unused]] size_t seed)
Wrapper for single-precision egblas inv_dropout with an explicit seed.
Definition: dropout.hpp:353