Expression Templates Library (ETL)
truncated_normal.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
13 #pragma once
14 
#include <chrono>
#include <cmath>   // std::abs
#include <ctime>   // std::time
#include <ostream> // std::ostream
#include <random>  // std::normal_distribution
16 
17 namespace etl {
18 
22 template <typename T = double>
24  using value_type = T;
25 
27  std::normal_distribution<value_type> distribution;
28 
29  static constexpr bool gpu_computable = false;
30 
36  truncated_normal_generator_op(T mean, T stddev) : rand_engine(std::time(nullptr)), distribution(mean, stddev) {}
37 
43  auto x = distribution(rand_engine);
44 
45  while (std::abs(x - distribution.mean()) > 2.0 * distribution.stddev()) {
46  x = distribution(rand_engine);
47  }
48 
49  return x;
50  }
51 
58  friend std::ostream& operator<<(std::ostream& os, [[maybe_unused]] const truncated_normal_generator_op& s) {
59  return os << "TN(0,1)";
60  }
61 };
62 
/*!
 * \brief Generator from a truncated normal distribution using a custom random engine.
 *
 * Values are drawn from N(mean, stddev) and redrawn until they lie within
 * two standard deviations of the mean (rejection sampling).
 *
 * The engine is held by reference: it must outlive this generator, and
 * generating values advances the caller's engine.
 *
 * \tparam G The type of the random engine
 * \tparam T The value type of the generated numbers
 */
template <typename G, typename T = double>
struct truncated_normal_generator_g_op {
    using value_type = T; ///< The value type

    G& rand_engine;                                    ///< The random engine (not owned)
    std::normal_distribution<value_type> distribution; ///< The used distribution

    static constexpr bool gpu_computable = false; ///< Indicates if the operator is computable on GPU

    /*!
     * \brief Construct a new generator with the given mean and standard deviation.
     *
     * \param g The random engine to draw from; must outlive the generator
     * \param mean The mean of the underlying normal distribution
     * \param stddev The standard deviation of the underlying normal distribution
     * (std::normal_distribution requires it to be > 0)
     */
    truncated_normal_generator_g_op(G& g, T mean, T stddev) : rand_engine(g), distribution(mean, stddev) {}

    /*!
     * \brief Generate a new value.
     *
     * Samples are drawn from the normal distribution and any sample farther
     * than two standard deviations from the mean is rejected and redrawn.
     *
     * \return a value in [mean - 2 * stddev, mean + 2 * stddev]
     */
    value_type operator()() {
        const auto mean = distribution.mean();
        // 2.0 is a double on purpose: for float T the bound is evaluated in double, as before
        const auto bound = 2.0 * distribution.stddev();

        value_type x;
        do {
            x = distribution(rand_engine);
        } while (std::abs(x - mean) > bound);

        return x;
    }

    /*!
     * \brief Outputs the given generator to the given stream.
     *
     * Prints the actual parameters of the generator, e.g. "TN(0,1)" for a
     * standard truncated normal.
     *
     * \param os The output stream
     * \param s The generator to output
     * \return os
     */
    friend std::ostream& operator<<(std::ostream& os, const truncated_normal_generator_g_op& s) {
        return os << "TN(" << s.distribution.mean() << "," << s.distribution.stddev() << ")";
    }
};
106 
107 } //end of namespace etl
value_t< E > mean(E &&values)
Returns the mean of all the values contained in the given expression.
Definition: expression_builder.hpp:650
auto s(T &&value)
Force the evaluation of the given expression.
Definition: stop.hpp:18
value_type operator()()
Generate a new value.
Definition: truncated_normal.hpp:42
T value_type
The value type.
Definition: truncated_normal.hpp:24
random_engine rand_engine
The random engine.
Definition: truncated_normal.hpp:26
Generator from a truncated normal distribution (values are restricted to within two standard deviations of the mean).
Definition: truncated_normal.hpp:23
auto abs(E &&value)
Apply absolute on each value of the given expression.
Definition: expression_builder.hpp:54
friend std::ostream & operator<<(std::ostream &os, [[maybe_unused]] const truncated_normal_generator_op &s)
Outputs the given generator to the given stream.
Definition: truncated_normal.hpp:58
Root namespace for the ETL library.
Definition: adapter.hpp:15
value_t< E > stddev(E &&values)
Returns the standard deviation of all the values contained in the given expression.
Definition: expression_builder.hpp:670
std::normal_distribution< value_type > distribution
The used distribution.
Definition: truncated_normal.hpp:27
Generator from a truncated normal distribution (values are restricted to within two standard deviations of the mean) using a custom random engine.
Definition: truncated_normal.hpp:67
T value_type
The value type.
Definition: truncated_normal.hpp:68
G & rand_engine
The random engine.
Definition: truncated_normal.hpp:70
static constexpr bool gpu_computable
Indicates if the operator is computable on GPU.
Definition: truncated_normal.hpp:29
std::mt19937_64 random_engine
The random engine used by the library.
Definition: random.hpp:22
truncated_normal_generator_op(T mean, T stddev)
Construct a new generator with the given mean and standard deviation.
Definition: truncated_normal.hpp:36
value_type operator()()
Generate a new value.
Definition: truncated_normal.hpp:86
std::normal_distribution< value_type > distribution
The used distribution.
Definition: truncated_normal.hpp:71
truncated_normal_generator_g_op(G &g, T mean, T stddev)
Construct a new generator with the given mean and standard deviation.
Definition: truncated_normal.hpp:80
friend std::ostream & operator<<(std::ostream &os, [[maybe_unused]] const truncated_normal_generator_g_op &s)
Outputs the given generator to the given stream.
Definition: truncated_normal.hpp:102