Expression Templates Library (ETL)
dropout.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
13 #pragma once
14 
15 #ifdef ETL_EGBLAS_MODE
16 
17 #include "etl/impl/cublas/cuda.hpp"
18 
19 #include <egblas.hpp>
20 
21 #endif
22 
23 namespace etl::impl::egblas {
24 
25 // dropout prepare
26 
30 #ifdef EGBLAS_HAS_DROPOUT_PREPARE
31 static constexpr bool has_dropout_prepare = true;
32 #else
33 static constexpr bool has_dropout_prepare = false;
34 #endif
35 
40 inline void* dropout_prepare() {
41 #ifdef EGBLAS_HAS_DROPOUT_PREPARE
42  inc_counter("egblas");
43  return egblas_dropout_prepare();
44 #else
45  cpp_unreachable("Invalid call to egblas::dropout_prepare");
46  return nullptr;
47 #endif
48 }
49 
53 #ifdef EGBLAS_HAS_DROPOUT_PREPARE_SEED
54 static constexpr bool has_dropout_prepare_seed = true;
55 #else
56 static constexpr bool has_dropout_prepare_seed = false;
57 #endif
58 
64 inline void* dropout_prepare_seed([[maybe_unused]] size_t seed) {
65 #ifdef EGBLAS_HAS_DROPOUT_PREPARE_SEED
66  inc_counter("egblas");
67  return egblas_dropout_prepare_seed(seed);
68 #else
69  cpp_unreachable("Invalid call to egblas::dropout_prepare_seed");
70  return nullptr;
71 #endif
72 }
73 
77 #ifdef EGBLAS_HAS_DROPOUT_RELEASE
78 static constexpr bool has_dropout_release = true;
79 #else
80 static constexpr bool has_dropout_release = false;
81 #endif
82 
88 inline void dropout_release([[maybe_unused]] void* state) {
89 #ifdef EGBLAS_HAS_DROPOUT_RELEASE
90  inc_counter("egblas");
91  egblas_dropout_release(state);
92 #else
93  cpp_unreachable("Invalid call to egblas::dropout_release");
94 #endif
95 }
96 
97 // dropout
98 
102 #ifdef EGBLAS_HAS_SDROPOUT
103 static constexpr bool has_sdropout = true;
104 #else
105 static constexpr bool has_sdropout = false;
106 #endif
107 
116 inline void dropout([[maybe_unused]] size_t n, [[maybe_unused]] float p, [[maybe_unused]] float alpha, [[maybe_unused]] float* A, [[maybe_unused]] size_t lda) {
117 #ifdef EGBLAS_HAS_SDROPOUT
118  inc_counter("egblas");
119  egblas_sdropout(n, p, alpha, A, lda);
120 #else
121  cpp_unreachable("Invalid call to egblas::dropout");
122 #endif
123 }
124 
128 #ifdef EGBLAS_HAS_DDROPOUT
129 static constexpr bool has_ddropout = true;
130 #else
131 static constexpr bool has_ddropout = false;
132 #endif
133 
142 inline void dropout(
143  [[maybe_unused]] size_t n, [[maybe_unused]] double p, [[maybe_unused]] double alpha, [[maybe_unused]] double* A, [[maybe_unused]] size_t lda) {
144 #ifdef EGBLAS_HAS_DDROPOUT
145  inc_counter("egblas");
146  egblas_ddropout(n, p, alpha, A, lda);
147 #else
148  cpp_unreachable("Invalid call to egblas::dropout");
149 #endif
150 }
151 
152 // dropout_seed
153 
157 #ifdef EGBLAS_HAS_SDROPOUT_SEED
158 static constexpr bool has_sdropout_seed = true;
159 #else
160 static constexpr bool has_sdropout_seed = false;
161 #endif
162 
171 inline void dropout_seed([[maybe_unused]] size_t n,
172  [[maybe_unused]] float p,
173  [[maybe_unused]] float alpha,
174  [[maybe_unused]] float* A,
175  [[maybe_unused]] size_t lda,
176  [[maybe_unused]] size_t seed) {
177 #ifdef EGBLAS_HAS_SDROPOUT_SEED
178  inc_counter("egblas");
179  egblas_sdropout_seed(n, p, alpha, A, lda, seed);
180 #else
181  cpp_unreachable("Invalid call to egblas::dropout");
182 #endif
183 }
184 
188 #ifdef EGBLAS_HAS_DDROPOUT_SEED
189 static constexpr bool has_ddropout_seed = true;
190 #else
191 static constexpr bool has_ddropout_seed = false;
192 #endif
193 
202 inline void dropout_seed([[maybe_unused]] size_t n,
203  [[maybe_unused]] double p,
204  [[maybe_unused]] double alpha,
205  [[maybe_unused]] double* A,
206  [[maybe_unused]] size_t lda,
207  [[maybe_unused]] size_t seed) {
208 #ifdef EGBLAS_HAS_DDROPOUT_SEED
209  inc_counter("egblas");
210  egblas_ddropout_seed(n, p, alpha, A, lda, seed);
211 #else
212  cpp_unreachable("Invalid call to egblas::dropout");
213 #endif
214 }
215 
216 // dropout_states
217 
221 #ifdef EGBLAS_HAS_SDROPOUT_STATES
222 static constexpr bool has_sdropout_states = true;
223 #else
224 static constexpr bool has_sdropout_states = false;
225 #endif
226 
235 inline void dropout_states([[maybe_unused]] size_t n,
236  [[maybe_unused]] float p,
237  [[maybe_unused]] float alpha,
238  [[maybe_unused]] float* A,
239  [[maybe_unused]] size_t lda,
240  [[maybe_unused]] void* states) {
241 #ifdef EGBLAS_HAS_SDROPOUT_STATES
242  inc_counter("egblas");
243  egblas_sdropout_states(n, p, alpha, A, lda, states);
244 #else
245  cpp_unreachable("Invalid call to egblas::dropout_states");
246 #endif
247 }
248 
252 #ifdef EGBLAS_HAS_DDROPOUT_STATES
253 static constexpr bool has_ddropout_states = true;
254 #else
255 static constexpr bool has_ddropout_states = false;
256 #endif
257 
266 inline void dropout_states([[maybe_unused]] size_t n,
267  [[maybe_unused]] double p,
268  [[maybe_unused]] double alpha,
269  [[maybe_unused]] double* A,
270  [[maybe_unused]] size_t lda,
271  [[maybe_unused]] void* states) {
272 #ifdef EGBLAS_HAS_DDROPOUT_STATES
273  inc_counter("egblas");
274  egblas_ddropout_states(n, p, alpha, A, lda, states);
275 #else
276  cpp_unreachable("Invalid call to egblas::dropout_states");
277 #endif
278 }
279 
280 // inverted dropout
281 
285 #ifdef EGBLAS_HAS_SINV_DROPOUT
286 static constexpr bool has_sinv_dropout = true;
287 #else
288 static constexpr bool has_sinv_dropout = false;
289 #endif
290 
299 inline void inv_dropout(
300  [[maybe_unused]] size_t n, [[maybe_unused]] float p, [[maybe_unused]] float alpha, [[maybe_unused]] float* A, [[maybe_unused]] size_t lda) {
301 #ifdef EGBLAS_HAS_SINV_DROPOUT
302  inc_counter("egblas");
303  egblas_sinv_dropout(n, p, alpha, A, lda);
304 #else
305 #endif
306 }
307 
311 #ifdef EGBLAS_HAS_DINV_DROPOUT
312 static constexpr bool has_dinv_dropout = true;
313 #else
314 static constexpr bool has_dinv_dropout = false;
315 #endif
316 
325 inline void inv_dropout(
326  [[maybe_unused]] size_t n, [[maybe_unused]] double p, [[maybe_unused]] double alpha, [[maybe_unused]] double* A, [[maybe_unused]] size_t lda) {
327 #ifdef EGBLAS_HAS_DINV_DROPOUT
328  inc_counter("egblas");
329  egblas_dinv_dropout(n, p, alpha, A, lda);
330 #else
331 #endif
332 }
333 
334 // inv_dropout_seed
335 
339 #ifdef EGBLAS_HAS_SINV_DROPOUT_SEED
340 static constexpr bool has_sinv_dropout_seed = true;
341 #else
342 static constexpr bool has_sinv_dropout_seed = false;
343 #endif
344 
353 inline void inv_dropout_seed([[maybe_unused]] size_t n,
354  [[maybe_unused]] float p,
355  [[maybe_unused]] float alpha,
356  [[maybe_unused]] float* A,
357  [[maybe_unused]] size_t lda,
358  [[maybe_unused]] size_t seed) {
359 #ifdef EGBLAS_HAS_SINV_DROPOUT_SEED
360  inc_counter("egblas");
361  egblas_sinv_dropout_seed(n, p, alpha, A, lda, seed);
362 #else
363  cpp_unreachable("Invalid call to egblas::inv_dropout");
364 #endif
365 }
366 
370 #ifdef EGBLAS_HAS_DINV_DROPOUT_SEED
371 static constexpr bool has_dinv_dropout_seed = true;
372 #else
373 static constexpr bool has_dinv_dropout_seed = false;
374 #endif
375 
384 inline void inv_dropout_seed([[maybe_unused]] size_t n,
385  [[maybe_unused]] double p,
386  [[maybe_unused]] double alpha,
387  [[maybe_unused]] double* A,
388  [[maybe_unused]] size_t lda,
389  [[maybe_unused]] size_t seed) {
390 #ifdef EGBLAS_HAS_DINV_DROPOUT_SEED
391  inc_counter("egblas");
392  egblas_dinv_dropout_seed(n, p, alpha, A, lda, seed);
393 #else
394  cpp_unreachable("Invalid call to egblas::inv_dropout");
395 #endif
396 }
397 
398 // inv_dropout_states
399 
403 #ifdef EGBLAS_HAS_SINV_DROPOUT_STATES
404 static constexpr bool has_sinv_dropout_states = true;
405 #else
406 static constexpr bool has_sinv_dropout_states = false;
407 #endif
408 
417 inline void inv_dropout_states([[maybe_unused]] size_t n,
418  [[maybe_unused]] float p,
419  [[maybe_unused]] float alpha,
420  [[maybe_unused]] float* A,
421  [[maybe_unused]] size_t lda,
422  [[maybe_unused]] void* states) {
423 #ifdef EGBLAS_HAS_SINV_DROPOUT_STATES
424  inc_counter("egblas");
425  egblas_sinv_dropout_states(n, p, alpha, A, lda, states);
426 #else
427  cpp_unreachable("Invalid call to egblas::inv_dropout_states");
428 #endif
429 }
430 
434 #ifdef EGBLAS_HAS_DINV_DROPOUT_STATES
435 static constexpr bool has_dinv_dropout_states = true;
436 #else
437 static constexpr bool has_dinv_dropout_states = false;
438 #endif
439 
448 inline void inv_dropout_states([[maybe_unused]] size_t n,
449  [[maybe_unused]] double p,
450  [[maybe_unused]] double alpha,
451  [[maybe_unused]] double* A,
452  [[maybe_unused]] size_t lda,
453  [[maybe_unused]] void* states) {
454 #ifdef EGBLAS_HAS_DINV_DROPOUT_STATES
455  inc_counter("egblas");
456  egblas_dinv_dropout_states(n, p, alpha, A, lda, states);
457 #else
458  cpp_unreachable("Invalid call to egblas::inv_dropout_states");
459 #endif
460 }
461 
462 } //end of namespace etl::impl::egblas
Definition: abs.hpp:23
void dropout([[maybe_unused]] size_t n, [[maybe_unused]] float p, [[maybe_unused]] float alpha, [[maybe_unused]] float *A, [[maybe_unused]] size_t lda)
Wrapper for single-precision egblas dropout.
Definition: dropout.hpp:116
void dropout_seed([[maybe_unused]] size_t n, [[maybe_unused]] float p, [[maybe_unused]] float alpha, [[maybe_unused]] float *A, [[maybe_unused]] size_t lda, [[maybe_unused]] size_t seed)
Wrapper for single-precision egblas dropout with an explicit seed.
Definition: dropout.hpp:171
void * dropout_prepare_seed([[maybe_unused]] size_t seed)
Prepare random states for dropout with the given seed.
Definition: dropout.hpp:64
void inv_dropout([[maybe_unused]] size_t n, [[maybe_unused]] float p, [[maybe_unused]] float alpha, [[maybe_unused]] float *A, [[maybe_unused]] size_t lda)
Wrapper for single-precision egblas inverted dropout.
Definition: dropout.hpp:299
void dropout_release([[maybe_unused]] void *state)
Release random states previously prepared for dropout.
Definition: dropout.hpp:88
void dropout_states([[maybe_unused]] size_t n, [[maybe_unused]] float p, [[maybe_unused]] float alpha, [[maybe_unused]] float *A, [[maybe_unused]] size_t lda, [[maybe_unused]] void *states)
Wrapper for single-precision egblas dropout using prepared random states.
Definition: dropout.hpp:235
void * dropout_prepare()
Prepare random states for dropout.
Definition: dropout.hpp:40
void inv_dropout_states([[maybe_unused]] size_t n, [[maybe_unused]] float p, [[maybe_unused]] float alpha, [[maybe_unused]] float *A, [[maybe_unused]] size_t lda, [[maybe_unused]] void *states)
Wrapper for single-precision egblas inverted dropout using prepared random states.
Definition: dropout.hpp:417
void inc_counter([[maybe_unused]] const char *name)
Increase the given counter.
Definition: counters.hpp:25
void inv_dropout_seed([[maybe_unused]] size_t n, [[maybe_unused]] float p, [[maybe_unused]] float alpha, [[maybe_unused]] float *A, [[maybe_unused]] size_t lda, [[maybe_unused]] size_t seed)
Wrapper for single-precision egblas inverted dropout with an explicit seed.
Definition: dropout.hpp:353