24 template <
typename T =
double>
29 std::shared_ptr<void*> states;
37 impl::egblas::has_dropout_prepare && impl::egblas::has_dropout_release
38 && ((is_single_precision_t<T> && impl::egblas::has_sdropout_states) || (is_double_precision_t<T> && impl::egblas::has_ddropout_states));
44 if constexpr (impl::egblas::has_dropout_prepare) {
45 states = std::make_shared<void*>();
46 *states = impl::egblas::dropout_prepare();
73 impl::egblas::dropout_states(
etl::size(y), probability, T(1), t1.gpu_memory(), 1, *states);
85 impl::egblas::dropout_states(
etl::size(y), probability, T(1), y.gpu_memory(), 1, *states);
107 template <
typename G,
typename T =
double>
113 std::shared_ptr<void*> states;
120 impl::egblas::has_dropout_prepare_seed && impl::egblas::has_dropout_release
121 && ((is_single_precision_t<T> && impl::egblas::has_sdropout_states) || (is_double_precision_t<T> && impl::egblas::has_ddropout_states));
129 if constexpr (impl::egblas::has_dropout_prepare) {
130 states = std::make_shared<void*>();
132 std::uniform_int_distribution<long> seed_dist;
133 *states = impl::egblas::dropout_prepare_seed(seed_dist(rand_engine));
156 template <
typename Y>
160 impl::egblas::dropout_states(
etl::size(y), probability, T(1), t1.gpu_memory(), 1, states);
170 template <
typename Y>
172 impl::egblas::dropout_states(
etl::size(y), probability, T(1), y.gpu_memory(), 1, states);
auto s(T &&value)
Force the evaluation of the given expression.
Definition: stop.hpp:18
dropout_distribution< value_type > distribution
The used distribution.
Definition: state_dropout_mask.hpp:114
value_type operator()()
Generate a new value.
Definition: state_dropout_mask.hpp:54
EGBLAS wrappers for the dropout operation.
const T probability
The dropout probability.
Definition: state_dropout_mask.hpp:28
Generator from a uniform distribution.
Definition: state_dropout_mask.hpp:25
friend std::ostream & operator<<(std::ostream &os, const state_dropout_mask_generator_g_op &s)
Outputs the given generator to the given stream.
Definition: state_dropout_mask.hpp:186
auto gpu_compute_hint(Y &y) noexcept
Compute the result of the operation using the GPU.
Definition: state_dropout_mask.hpp:70
dropout_distribution< value_type > distribution
The used distribution.
Definition: state_dropout_mask.hpp:31
std::conditional_t< std::is_floating_point_v< T >, std::uniform_real_distribution< T >, std::uniform_int_distribution< T > > dropout_distribution
Selector helper to get a dropout_distribution based on the type (floating-point or integral).
Definition: dropout_mask.hpp:26
random_engine rand_engine
The random engine.
Definition: state_dropout_mask.hpp:30
Root namespace for the ETL library.
Definition: adapter.hpp:15
const T probability
The dropout probability.
Definition: state_dropout_mask.hpp:111
Y & gpu_compute(Y &y) noexcept
Compute the result of the operation using the GPU.
Definition: state_dropout_mask.hpp:171
Generator from a uniform distribution using a custom random engine.
Definition: state_dropout_mask.hpp:108
static constexpr bool gpu_computable
Indicates if the operator can be computed on GPU.
Definition: state_dropout_mask.hpp:36
G & rand_engine
The random engine.
Definition: state_dropout_mask.hpp:112
auto gpu_compute_hint(Y &y) noexcept
Compute the result of the operation using the GPU.
Definition: state_dropout_mask.hpp:157
state_dropout_mask_generator_op(T probability)
Construct a new generator with the given dropout probability.
Definition: state_dropout_mask.hpp:43
decltype(auto) force_temporary_gpu_dim_only(E &&expr)
Force a temporary out of the expression, without copying its content.
Definition: temporary.hpp:223
constexpr size_t size(const E &expr) noexcept
Returns the size of the given ETL expression.
Definition: helpers.hpp:108
Y & gpu_compute(Y &y) noexcept
Compute the result of the operation using the GPU.
Definition: state_dropout_mask.hpp:84
state_dropout_mask_generator_g_op(G &g, T probability)
Construct a new generator with the given random engine and dropout probability.
Definition: state_dropout_mask.hpp:128
value_type operator()()
Generate a new value.
Definition: state_dropout_mask.hpp:141
T value_type
The value type.
Definition: state_dropout_mask.hpp:26
std::mt19937_64 random_engine
The random engine used by the library.
Definition: random.hpp:22
T value_type
The value type.
Definition: state_dropout_mask.hpp:109
friend std::ostream & operator<<(std::ostream &os, const state_dropout_mask_generator_op &s)
Outputs the given generator to the given stream.
Definition: state_dropout_mask.hpp:99