Expression Templates Library (ETL)
context.hpp
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
8 #pragma once
9 
10 namespace etl {
11 
12 #ifdef ETL_MANUAL_SELECT
13 
/*!
 * \brief Pair an implementation selector value with a flag telling whether it is forced.
 * \tparam T The type of the implementation selector.
 */
template <typename T>
struct forced_impl {
    T impl{};            ///< The selected implementation; value-initialized so a default-constructed object never holds an indeterminate value
    bool forced = false; ///< Whether `impl` is currently forced
};
21 #endif
22 
/*!
 * \brief The contextual configuration of ETL.
 */
struct context {
    bool serial{false};   ///< Force serial execution
    bool parallel{false}; ///< Force parallel execution
    bool cpu{false};      ///< Force CPU evaluation

#ifdef ETL_MANUAL_SELECT
    // Forced-implementation state, one member per selector type.
    // Only present when ETL_MANUAL_SELECT is defined.
    forced_impl<sum_impl> sum_selector;
    forced_impl<pool_impl> pool_selector;
    forced_impl<transpose_impl> transpose_selector;
    forced_impl<dot_impl> dot_selector;
    forced_impl<conv_impl> conv_selector;
    forced_impl<conv_multi_impl> conv_multi_selector;
    forced_impl<conv4_impl> conv4_selector;
    forced_impl<gemm_impl> gemm_selector;
    forced_impl<outer_impl> outer_selector;
    forced_impl<bias_add_impl> bias_add_selector;
    forced_impl<fft_impl> fft_selector;
#endif
};
45 
51  static thread_local context local_context;
52  return local_context;
53 }
54 
/*!
 * \brief Indicates if some implementation is forced in the context.
 * \return true if at least one selector of the current thread's context has its
 *         forced flag set; always false when ETL_MANUAL_SELECT is not defined.
 */
inline bool is_something_forced() {
#ifdef ETL_MANUAL_SELECT
    const auto& ctx = local_context();

    // Gather every forced flag, then look for the first one that is set
    const bool forced_flags[] = {
        ctx.sum_selector.forced,     ctx.pool_selector.forced,       ctx.transpose_selector.forced, ctx.dot_selector.forced,
        ctx.conv_selector.forced,    ctx.conv_multi_selector.forced, ctx.conv4_selector.forced,     ctx.gemm_selector.forced,
        ctx.outer_selector.forced,   ctx.bias_add_selector.forced,   ctx.fft_selector.forced,
    };

    for (const bool flag : forced_flags) {
        if (flag) {
            return true;
        }
    }

    return false;
#else
    return false;
#endif
}
70 
71 namespace detail {
72 
73 #ifdef ETL_MANUAL_SELECT
74 
80 template <typename T>
81 forced_impl<T>& get_forced_impl();
82 
86 template <>
87 inline forced_impl<sum_impl>& get_forced_impl() {
88  return local_context().sum_selector;
89 }
90 
94 template <>
95 inline forced_impl<pool_impl>& get_forced_impl() {
96  return local_context().pool_selector;
97 }
98 
102 template <>
103 inline forced_impl<transpose_impl>& get_forced_impl() {
104  return local_context().transpose_selector;
105 }
106 
110 template <>
111 inline forced_impl<dot_impl>& get_forced_impl() {
112  return local_context().dot_selector;
113 }
114 
118 template <>
119 inline forced_impl<conv_impl>& get_forced_impl() {
120  return local_context().conv_selector;
121 }
122 
126 template <>
127 inline forced_impl<conv_multi_impl>& get_forced_impl() {
128  return local_context().conv_multi_selector;
129 }
130 
134 template <>
135 inline forced_impl<conv4_impl>& get_forced_impl() {
136  return local_context().conv4_selector;
137 }
138 
142 template <>
143 inline forced_impl<gemm_impl>& get_forced_impl() {
144  return local_context().gemm_selector;
145 }
146 
150 template <>
151 inline forced_impl<outer_impl>& get_forced_impl() {
152  return local_context().outer_selector;
153 }
154 
158 template <>
159 inline forced_impl<bias_add_impl>& get_forced_impl() {
160  return local_context().bias_add_selector;
161 }
162 
166 template <>
167 inline forced_impl<fft_impl>& get_forced_impl() {
168  return local_context().fft_selector;
169 }
170 
171 #endif
172 
177  bool old_serial;
178 
185  old_serial = etl::local_context().serial;
186  etl::local_context().serial = true;
187  }
188 
195  etl::local_context().serial = old_serial;
196  }
197 
201  operator bool() {
202  return true;
203  }
204 };
205 
211 
218  old_parallel = etl::local_context().parallel;
219  etl::local_context().parallel = true;
220  }
221 
228  etl::local_context().parallel = old_parallel;
229  }
230 
234  operator bool() {
235  return true;
236  }
237 };
238 
242 struct cpu_context {
243  bool old_cpu;
244 
251  old_cpu = etl::local_context().cpu;
252  etl::local_context().cpu = true;
253  }
254 
261  etl::local_context().cpu = old_cpu;
262  }
263 
267  operator bool() {
268  return true;
269  }
270 };
271 
272 #ifdef ETL_MANUAL_SELECT
273 
278 template <typename Selector, Selector V>
279 struct selected_context {
280  forced_impl<Selector> old_selector;
281 
288  selected_context() {
289  decltype(auto) selector = get_forced_impl<Selector>();
290 
291  old_selector = selector;
292 
293  selector.impl = V;
294  selector.forced = true;
295  }
296 
302  ~selected_context() {
303  get_forced_impl<Selector>() = old_selector;
304  }
305 
309  operator bool() {
310  return true;
311  }
312 };
313 
314 #endif
315 
316 } //end of namespace detail
317 
/*!
 * \brief Introduce a statement during which the serial flag of the current thread's context is set.
 *
 * Expands to an `if` whose condition declares an etl::detail::serial_context;
 * follow it with the statement or block to run, e.g. `SERIAL_SECTION { ... }`.
 */
#define SERIAL_SECTION if (auto etl_serial_context__ = etl::detail::serial_context{})
322 
/*!
 * \brief Introduce a statement during which the parallel flag of the current thread's context is set.
 *
 * Expands to an `if` whose condition declares an etl::detail::parallel_context;
 * follow it with the statement or block to run, e.g. `PARALLEL_SECTION { ... }`.
 */
#define PARALLEL_SECTION if (auto etl_parallel_context__ = etl::detail::parallel_context{})
327 
/*!
 * \brief Introduce a statement during which the cpu flag of the current thread's context is set.
 *
 * Expands to an `if` whose condition declares an etl::detail::cpu_context;
 * follow it with the statement or block to run, e.g. `CPU_SECTION { ... }`.
 */
#define CPU_SECTION if (auto etl_cpu_context__ = etl::detail::cpu_context{})
332 
333 #ifdef ETL_MANUAL_SELECT
334 
/*!
 * \brief Introduce a statement during which the implementation value v is forced.
 *
 * Expands to an `if` whose condition declares an
 * etl::detail::selected_context<decltype(v), v>, so the type of v picks the
 * selector and v itself is the forced value.
 */
#define SELECTED_SECTION(v) if (auto etl_selected_context__ = etl::detail::selected_context<decltype(v), v>{})
339 
340 #endif
341 
342 } //end of namespace etl
~cpu_context()
Destruct a cpu context.
Definition: context.hpp:260
bool serial
Force serial execution.
Definition: context.hpp:27
parallel_context()
Default construct a parallel context.
Definition: context.hpp:217
bool old_cpu
The previous value of cpu.
Definition: context.hpp:243
etl::context
The contextual configuration of ETL.
Definition: context.hpp:26
bool parallel
Force parallel execution.
Definition: context.hpp:28
serial_context()
Default construct a serial context.
Definition: context.hpp:184
bool old_serial
The previous value of serial.
Definition: context.hpp:177
etl
Root namespace for the ETL library.
Definition: adapter.hpp:15
context & local_context()
Return the configuration context of the current thread.
Definition: context.hpp:50
etl::detail::serial_context
RAII helper for setting the context to serial.
Definition: context.hpp:176
bool cpu
Force CPU evaluation.
Definition: context.hpp:29
cpu_context()
Default construct a cpu context.
Definition: context.hpp:250
auto parallel(Expr &&expr) -> parallel_expr< detail::build_type< Expr >>
Create a parallel expression wrapping the given expression.
Definition: wrapper_expression_builder.hpp:79
etl::detail::cpu_context
RAII helper for setting the context to cpu.
Definition: context.hpp:242
etl::detail::parallel_context
RAII helper for setting the context to parallel.
Definition: context.hpp:209
bool is_something_forced()
Indicates if some implementation is forced in the context.
Definition: context.hpp:60
auto serial(Expr &&expr) -> serial_expr< detail::build_type< Expr >>
Create a serial expression wrapping the given expression.
Definition: wrapper_expression_builder.hpp:66
~parallel_context()
Destruct a parallel context.
Definition: context.hpp:227
~serial_context()
Destruct a serial context.
Definition: context.hpp:194
bool old_parallel
The previous value of parallel.
Definition: context.hpp:210