cuda-api-wrappers
Thin C++-flavored wrappers for the CUDA Runtime API
current_context.hpp
Go to the documentation of this file.
1 
4 #pragma once
5 #ifndef CUDA_API_WRAPPERS_CURRENT_CONTEXT_HPP_
6 #define CUDA_API_WRAPPERS_CURRENT_CONTEXT_HPP_
7 
8 #include "error.hpp"
9 #include "constants.hpp"
10 #include "types.hpp"
11 
12 namespace cuda {
13 
15 class device_t;
16 class context_t;
17 namespace device {
18 class primary_context_t;
19 } // namespace device
21 
22 namespace context {
23 
24 namespace current {
25 
30 inline bool exists()
31 {
32  context::handle_t handle;
33  auto status = cuCtxGetCurrent(&handle);
34  if (status == cuda::status::not_yet_initialized) {
35  return false;
36  }
37  throw_if_error_lazy(status, "Failed obtaining the current context's handle");
38  return (handle != context::detail_::none);
39 }
40 
41 
42 namespace detail_ {
49 inline bool is_(handle_t handle)
50 {
51  handle_t current_context_handle;
52  auto status = cuCtxGetCurrent(&current_context_handle);
53  switch(status) {
54  case CUDA_ERROR_NOT_INITIALIZED:
55  case CUDA_ERROR_INVALID_CONTEXT:
56  return false;
57  case CUDA_SUCCESS:
58  return (handle == current_context_handle);
59  default:
60  throw cuda::runtime_error(status,
61  "Failed determining whether there's a current context, or what it is");
62  }
63 }
64 
65 struct status_and_handle_pair {
66  status_t status;
67  handle_t handle;
68 };
69 
78 inline status_and_handle_pair get_with_status()
79 {
80  handle_t handle;
81  auto status = cuCtxGetCurrent(&handle);
82  if (status == status::not_yet_initialized) {
83  handle = context::detail_::none;
84  }
85  return { status, handle };
86 }
87 
94 inline handle_t get_handle()
95 {
96  auto p = get_with_status();
97  throw_if_error_lazy(p.status, "Failed obtaining the current context's handle");
98  return p.handle;
99 }
100 
101 // Note: not calling this get_ since flags are read-only anyway
102 inline context::flags_t get_flags()
103 {
104  context::flags_t result;
105  auto status = cuCtxGetFlags(&result);
106  throw_if_error_lazy(status, "Failed obtaining the current context's flags");
107  // Note: Not sanitizing the flags from having CU_CTX_MAP_HOST set
108  return result;
109 }
110 
111 inline device::id_t get_device_id()
112 {
113  device::id_t device_id;
114  auto result = cuCtxGetDevice(&device_id);
115  throw_if_error_lazy(result, "Failed obtaining the current context's device");
116  return device_id;
117 }
118 
127 inline void push(handle_t context_handle)
128 {
129  auto status = cuCtxPushCurrent(context_handle);
130  throw_if_error_lazy(status, "Failed pushing to the top of the context stack: "
131  + context::detail_::identify(context_handle));
132 }
133 
146 inline bool push_if_not_on_top(handle_t context_handle)
147 {
148  if (get_handle() == context_handle) { return false; }
149  push(context_handle);
150  return true;
151 }
152 
153 inline status_t pop_and_discard_nothrow()
154 {
155  handle_t popped_context_handle;
156  auto status = cuCtxPopCurrent(&popped_context_handle);
157  return status;
158 }
159 
160 inline context::handle_t pop()
161 {
162  handle_t popped_context_handle;
163  auto status = cuCtxPopCurrent(&popped_context_handle);
164  throw_if_error_lazy(status, "Failed popping the current CUDA context");
165  return popped_context_handle;
166 }
167 
168 inline void set(handle_t context_handle)
169 {
170  // Thought about doing this:
171  // if (detail_::get_handle() == context_handle_) { return; }
172  // ... but decided against it.
173  auto status = cuCtxSetCurrent(context_handle);
174  throw_if_error_lazy(status,
175  "Failed setting the current context to " + context::detail_::identify(context_handle));
176 }
177 
178 } // namespace detail_
179 
180 namespace detail_ {
184 class scoped_override_t {
185 public:
186  bool hold_primary_context_ref_unit_;
187  device::id_t device_id_or_0_;
188 
189  explicit scoped_override_t(handle_t context_handle) : scoped_override_t(false, 0, context_handle) {}
190  scoped_override_t(device::id_t device_for_which_context_is_primary, handle_t context_handle)
191  : scoped_override_t(true, device_for_which_context_is_primary, context_handle) {}
192  explicit scoped_override_t(bool hold_primary_context_ref_unit, device::id_t device_id, handle_t context_handle);
193  scoped_override_t(const scoped_override_t&) = delete;
195  scoped_override_t& operator=(const scoped_override_t&) = delete;
196  scoped_override_t& operator=(scoped_override_t&&) = delete;
197  ~scoped_override_t() DESTRUCTOR_EXCEPTION_SPEC;
198 };
199 
201 
/*
 * A shorthand intended for the cuda-api-wrappers implementation itself: it
 * declares a const detail_::scoped_override_t object, named
 * caw_context_for_this_scope_ , constructed from the context handle expression
 * passed to the macro. It is quite usable outside the library as well, but
 * there you would probably rather work with context_t objects; and for safety
 * (e.g. w.r.t. primary device contexts), prefer @ref CUDA_CONTEXT_FOR_THIS_SCOPE
 * instead.
 */
#define CAW_SET_SCOPE_CONTEXT(handle_expr_) \
const ::cuda::context::current::detail_::scoped_override_t caw_context_for_this_scope_(handle_expr_)
210 
217 class scoped_ensurer_t {
218 public:
219  bool context_was_pushed_on_construction;
220 
221  explicit scoped_ensurer_t(bool force_push, handle_t fallback_context_handle)
222  : context_was_pushed_on_construction(force_push)
223  {
224  if (force_push) { push(fallback_context_handle); }
225  }
226 
227  explicit scoped_ensurer_t(handle_t fallback_context_handle)
228  : scoped_ensurer_t(not exists(), fallback_context_handle)
229  {}
230 
231  scoped_ensurer_t(const scoped_ensurer_t&) = delete;
232  scoped_ensurer_t(scoped_ensurer_t&&) = delete;
233 
234  scoped_ensurer_t& operator=(scoped_ensurer_t&&) = delete;
235  scoped_ensurer_t& operator=(const scoped_ensurer_t&) = delete;
236 
237  ~scoped_ensurer_t() { if (context_was_pushed_on_construction) { pop(); } }
238 };
239 
240 } // namespace detail_
241 
255 class scoped_override_t : private detail_::scoped_override_t {
256 protected:
257  using parent = detail_::scoped_override_t;
258 public:
259 
260  explicit scoped_override_t(device::primary_context_t&& primary_context);
261  explicit scoped_override_t(const context_t& context);
262  explicit scoped_override_t(context_t&& context);
263  ~scoped_override_t() = default;
264 };
265 
266 
267 
/**
 * Declares a @ref cuda::context::current::scoped_override_t object, named
 * set_context_for_this_scope , brace-initialized from the macro's argument.
 */
#define CUDA_CONTEXT_FOR_THIS_SCOPE(context_expr_) \
::cuda::context::current::scoped_override_t set_context_for_this_scope{ context_expr_ }
276 
283 inline void synchronize()
284 {
285  auto status = cuCtxSynchronize();
286  if (not is_success(status)) {
287  throw cuda::runtime_error(status, "Failed synchronizing current context");
288  }
289 }
290 
291 namespace detail_ {
292 
293 // Just like context::current::synchronize(), but with an argument
294 // allowing for throwing a more informative exception on failure
295 inline void synchronize(context::handle_t current_context_handle)
296 {
297  auto status = cuCtxSynchronize();
298  if (not is_success(status)) {
299  throw cuda::runtime_error(status,"Failed synchronizing "
300  + context::detail_::identify(current_context_handle));
301  }
302 }
303 
304 // Just like context::current::synchronize(), but with arguments
305 // allowing for throwing a more informative exception on failure
306 inline void synchronize(
307  device::id_t current_context_device_id,
308  context::handle_t current_context_handle)
309 {
310  auto status = cuCtxSynchronize();
311  if (not is_success(status)) {
312  throw cuda::runtime_error(status, "Failed synchronizing "
313  + context::detail_::identify(current_context_handle, current_context_device_id));
314  }
315 }
316 
317 } // namespace detail_
318 
319 } // namespace current
320 
321 } // namespace context
322 
323 } // namespace cuda
324 
325 #endif // CUDA_API_WRAPPERS_CURRENT_CONTEXT_HPP_
Wrapper class for a CUDA context.
Definition: context.hpp:249
Definitions and functionality wrapping CUDA APIs.
Definition: array.hpp:22
CUcontext handle_t
Raw CUDA driver handle for a context; see {context_t}.
Definition: types.hpp:880
A class for holding the primary context of a CUDA device.
Definition: primary_context.hpp:122
CUdevice id_t
Numeric ID of a CUDA device used by the CUDA Runtime API.
Definition: types.hpp:852
bool push_if_not_on_top(const context_t &context)
Push a (reference to a) context onto the top of the context stack - unless that context is already at...
Definition: context.hpp:899
context_t pop()
Pop the top off of the context stack.
Definition: context.hpp:922
A (base?) class for exceptions raised by CUDA code; these errors are thrown by essentially all CUDA R...
Definition: error.hpp:282
void synchronize(const context_t &context)
Waits for all previously-scheduled tasks on all streams (= queues) in a CUDA context to conclude...
Definition: context.hpp:980
A RAII-based mechanism for pushing a context onto the context stack for what remains of the current (...
Definition: current_context.hpp:255
void push(const context_t &context)
Push a (reference to a) context onto the top of the context stack.
Definition: context.hpp:911
#define throw_if_error_lazy(status__,...)
A macro for only throwing an error if we've failed - which also ensures no string is constructed unle...
Definition: error.hpp:327
Facilities for exception-based handling of Runtime and Driver API errors, including a basic exception...
Fundamental CUDA-related constants and enumerations, not dependent on any more complex abstractions...
bool exists()
Determine whether any CUDA context is current, or whether the context stack is empty/uninitialized.
Definition: current_context.hpp:30
Fundamental CUDA-related type definitions.
constexpr bool is_success(status_t status)
Determine whether the API call returning the specified status had succeeded.
Definition: error.hpp:214
CUresult status_t
Indicates either the result (success or error index) of a CUDA Runtime or Driver API call...
Definition: types.hpp:74