15 #ifdef ETL_CURAND_MODE 31 using uniform_distribution = std::conditional_t<std::is_floating_point_v<T>, std::uniform_real_distribution<T>, std::uniform_int_distribution<T>>;
36 template <
typename T =
double>
49 && ((is_single_precision_t<T> && impl::egblas::has_scalar_sadd && impl::egblas::has_scalar_smul)
50 || (is_double_precision_t<T> && impl::egblas::has_scalar_dadd && impl::egblas::has_scalar_dmul));
57 uniform_generator_op(T start, T end) : start(start), end(end), rand_engine(std::time(nullptr)), distribution(start, end) {}
67 #ifdef ETL_CURAND_MODE 77 auto t1 = etl::force_temporary_gpu_dim_only_t<T>(y);
79 curandGenerator_t gen;
82 curand_call(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
85 std::uniform_int_distribution<long> seed_dist;
86 curand_call(curandSetPseudoRandomGeneratorSeed(gen, seed_dist(rand_engine)));
89 impl::curand::generate_uniform(gen, t1.gpu_memory(),
etl::size(y));
92 auto s1 = T(end) - T(start);
93 impl::egblas::scalar_mul(t1.gpu_memory(),
etl::size(y), 1, s1);
97 impl::egblas::scalar_add(t1.gpu_memory(),
etl::size(y), 1, s2);
108 template <
typename Y>
109 Y& gpu_compute(Y& y) noexcept {
110 y.ensure_gpu_allocated();
112 curandGenerator_t gen;
115 curand_call(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
118 std::uniform_int_distribution<long> seed_dist;
119 curand_call(curandSetPseudoRandomGeneratorSeed(gen, seed_dist(rand_engine)));
122 impl::curand::generate_uniform(gen, y.gpu_memory(),
etl::size(y));
125 auto s1 = T(end) - T(start);
126 impl::egblas::scalar_mul(y.gpu_memory(),
etl::size(y), 1, s1);
130 impl::egblas::scalar_add(y.gpu_memory(),
etl::size(y), 1, s2);
147 return os <<
"U(0,1)";
154 template <
typename G,
typename T =
double>
167 && ((is_single_precision_t<T> && impl::egblas::has_scalar_sadd && impl::egblas::has_scalar_smul)
168 || (is_double_precision_t<T> && impl::egblas::has_scalar_dadd && impl::egblas::has_scalar_dmul));
185 #ifdef ETL_CURAND_MODE 193 template <
typename Y>
195 auto t1 = etl::force_temporary_gpu_dim_only_t<T>(y);
197 curandGenerator_t gen;
200 curand_call(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
203 std::uniform_int_distribution<long> seed_dist;
204 curand_call(curandSetPseudoRandomGeneratorSeed(gen, seed_dist(rand_engine)));
207 impl::curand::generate_uniform(gen, t1.gpu_memory(),
etl::size(y));
210 auto s1 = T(end) - T(start);
211 impl::egblas::scalar_mul(t1.gpu_memory(),
etl::size(y), 1, s1);
215 impl::egblas::scalar_add(t1.gpu_memory(),
etl::size(y), 1, s2);
226 template <
typename Y>
227 Y& gpu_compute(Y& y) noexcept {
228 y.ensure_gpu_allocated();
230 curandGenerator_t gen;
233 curand_call(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
236 std::uniform_int_distribution<long> seed_dist;
237 curand_call(curandSetPseudoRandomGeneratorSeed(gen, seed_dist(rand_engine)));
240 impl::curand::generate_uniform(gen, y.gpu_memory(),
etl::size(y));
243 auto s1 = T(end) - T(start);
244 impl::egblas::scalar_mul(y.gpu_memory(),
etl::size(y), 1, s1);
248 impl::egblas::scalar_add(y.gpu_memory(),
etl::size(y), 1, s2);
265 return os <<
"U(0,1)";
EGBLAS wrappers for the scalar_mul operation.
auto s(T &&value)
Force the evaluation of the given expression.
Definition: stop.hpp:18
std::conditional_t< std::is_floating_point_v< T >, std::uniform_real_distribution< T >, std::uniform_int_distribution< T > > uniform_distribution
Selector helper to get a uniform_distribution based on the type (real or integral)
Definition: uniform.hpp:31
EGBLAS wrappers for the scalar_add operation.
constexpr bool curand_enabled
Indicates if the NVIDIA CURAND library is available for ETL.
Definition: config.hpp:109
Root namespace for the ETL library.
Definition: adapter.hpp:15
constexpr size_t size(const E &expr) noexcept
Returns the size of the given ETL expression.
Definition: helpers.hpp:108
std::mt19937_64 random_engine
The random engine used by the library.
Definition: random.hpp:22
Utility functions for curand.
const auto & gpu_compute_hint([[maybe_unused]] Y &y) const
Returns a GPU-computed version of this expression.
Definition: sub_view.hpp:653