cuda-api-wrappers
Thin C++-flavored wrappers for the CUDA Runtime API
library.hpp
Go to the documentation of this file.
1 
7 #pragma once
8 #ifndef CUDA_API_WRAPPERS_LIBRARY_HPP_
9 #define CUDA_API_WRAPPERS_LIBRARY_HPP_
10 
11 #if CUDA_VERSION >= 12000
12 
13 #include "module.hpp"
14 #include "error.hpp"
15 
16 #if __cplusplus >= 201703L
17 #include <filesystem>
18 #endif
19 
20 namespace cuda {
21 
23 class context_t;
24 class module_t;
25 class library_t;
26 class kernel_t;
28 
29 namespace library {
30 
31 using handle_t = CUlibrary;
32 
33 namespace kernel {
34 
35 using handle_t = CUkernel; // Don't be confused; a context-associated kernel is a CUfunction :-(
36 
37 } // namespace kernel
38 
39 namespace detail_ {
40 
41 using option_t = CUlibraryOption;
42 
43 } // namespace detail_
44 
45 class kernel_t; // A kernel stored within a library; strangely, a context-associated kernel is a CUfunction.
46 
47 namespace detail_ {
48 
49 inline library_t wrap(
50  handle_t handle,
51  bool take_ownership = false) noexcept;
52 
53 inline ::std::string identify(const library::handle_t &handle)
54 {
55  return ::std::string("library ") + cuda::detail_::ptr_as_hex(handle);
56 }
57 
58 ::std::string identify(const library_t &library);
59 
60 } // namespace detail_
61 
68 template <typename ContiguousContainer,
70  cuda::detail_::enable_if_t<cuda::detail_::is_kinda_like_contiguous_container<ContiguousContainer>::value, bool> = true >
71 library_t create(
72  ContiguousContainer library_data,
73  optional<link::options_t> link_options,
74  bool code_is_preserved);
76 
77 
78 namespace detail_ {
79 
80 inline kernel::handle_t get_kernel_in_current_context(handle_t library_handle, const char* name)
81 {
82  library::kernel::handle_t kernel_handle;
83  auto status = cuLibraryGetKernel(&kernel_handle, library_handle, name);
84  throw_if_error_lazy(status, ::std::string{"Failed obtaining kernel "}
85  + name + "' from " + library::detail_::identify(library_handle));
86  return kernel_handle;
87 }
88 
89 inline kernel::handle_t get_kernel(context::handle_t context_handle, handle_t library_handle, const char* name)
90 {
91  CAW_SET_SCOPE_CONTEXT(context_handle);
92  return get_kernel_in_current_context(library_handle, name);
93 }
94 
95 } // namespace detail_
96 
97 inline kernel_t get_kernel(const library_t& library, const char* name);
98 inline kernel_t get_kernel(context_t& context, const library_t& library, const char* name);
99 
100 } // namespace library
101 
102 memory::region_t get_global(const context_t& context, const library_t& library, const char* name);
103 memory::region_t get_managed_region(const library_t& library, const char* name);
104 
105 namespace module {
106 
107 module_t create(const context_t& context, const library_t& library);
108 module_t create(const library_t& library);
109 
110 } // namespace module
111 
112 void* get_unified_function(const context_t& context, const library_t& library, const char* symbol);
113 
118 class library_t {
119 
120 public: // getters
121 
122  library::handle_t handle() const { return handle_; }
123 
135  library::kernel_t get_kernel(const context_t& context, const char* name) const;
136  library::kernel_t get_kernel(const context_t& context, const ::std::string& name) const;
137  library::kernel_t get_kernel(const char* name) const;
138  library::kernel_t get_kernel(const ::std::string& name) const;
139 
140  memory::region_t get_global(const char* name) const
141  {
142  return cuda::get_global(context::current::get(), *this, name);
143  }
144 
145  memory::region_t get_global(const ::std::string& name) const
146  {
147  return get_global(name.c_str());
148  }
149 
150  memory::region_t get_managed(const char* name) const
151  {
152  return cuda::get_managed_region(*this, name);
153  }
154 
155  memory::region_t get_managed(const ::std::string& name) const
156  {
157  return get_managed(name.c_str());
158  }
159 
160 protected: // constructors
161 
162  library_t(library::handle_t handle, bool owning) noexcept
163  : handle_(handle), owning_(owning)
164  { }
165 
166 public: // friendship
167 
168  friend library_t library::detail_::wrap(library::handle_t, bool) noexcept;
169 
170 public: // constructors and destructor
171 
172  library_t(const library_t&) = delete;
173 
174  library_t(library_t&& other) noexcept : library_t(other.handle_, other.owning_)
175  {
176  other.owning_ = false;
177  };
178 
179  ~library_t() noexcept(false)
180  {
181  if (owning_) {
182  auto status = cuLibraryUnload(handle_);
183  throw_if_error_lazy(status, "Failed unloading " + library::detail_::identify(handle_));
184  }
185  }
186 
187 public: // operators
188 
189  library_t& operator=(const library_t&) = delete;
190  library_t& operator=(library_t&& other) noexcept
191  {
192  ::std::swap(handle_, other.handle_);
193  ::std::swap(owning_, other.owning_);
194  return *this;
195  }
196 
197 protected: // data members
198  library::handle_t handle_;
199  bool owning_;
200  // this field is mutable only for enabling move construction; other
201  // than in that case it must not be altered
202 };
203 
204 inline memory::region_t get_global(const context_t& context, const library_t& library, const char* name)
205 {
206  CUdeviceptr dptr;
207  size_t size;
208  auto result = cuLibraryGetGlobal(&dptr, &size, library.handle(), name);
209  throw_if_error_lazy(result,
210  ::std::string("Obtaining the memory address and size for the global object '") + name + "' from "
211  + library::detail_::identify(library) + " in context " + context::detail_::identify(context));
212  return { memory::as_pointer(dptr), size };
213  // Note: Nothing is holding a PC refcount unit here!
214 }
215 
216 // More library item getters
217 namespace library {
218 
219 } // namespace library
220 
221 inline memory::region_t get_managed_region(const library_t& library, const char* name)
222 {
223  memory::device::address_t region_start;
224  size_t region_size;
225  auto status = cuLibraryGetManaged(&region_start, &region_size, library.handle(), name);
226  throw_if_error_lazy(status, ::std::string("Failed obtaining the managed memory region '") + name
227  + "' from " + library::detail_::identify(library));
228  return { memory::as_pointer(region_start), region_size };
229 }
230 
231 namespace module {
232 
236 inline module_t create(const context_t& context, const library_t& library)
237 {
238  CAW_SET_SCOPE_CONTEXT(context.handle());
239  module::handle_t new_handle;
240  auto status = cuLibraryGetModule(&new_handle, library.handle());
241  throw_if_error_lazy(status, ::std::string("Failed creating a module '") +
242  + "' from " + library::detail_::identify(library) + " in " + context::detail_::identify(context));
243  constexpr const bool is_owning { true };
244  return module::detail_::wrap(context.device_id(), context.handle(), new_handle,
245  is_owning, do_hold_primary_context_refcount_unit);
246  // TODO: We could consider adding a variant of this function taking a context&&, and using that
247  // to decide whether or not to hold a PC refcount unit
248 }
249 
250 } // namespace module
251 
252 // I really have no idea what this does!
253 inline void* get_unified_function(const context_t& context, const library_t& library, const char* symbol)
254 {
255  CAW_SET_SCOPE_CONTEXT(context.handle());
256  void* function_ptr;
257  auto status = cuLibraryGetUnifiedFunction(&function_ptr, library.handle(), symbol);
258  throw_if_error_lazy(status, ::std::string("Failed obtaining a pointer for function '") + symbol
259  + "' from " + library::detail_::identify(library) + " in " + context::detail_::identify(context));
260  return function_ptr;
261 }
262 
263 namespace library {
264 
265 namespace detail_ {
266 
267 template <typename Creator, typename DataSource, typename ErrorStringGenerator>
268 library_t create(
269  Creator creator,
270  DataSource data_source,
271  ErrorStringGenerator error_string_generator,
272  const link::options_t& link_options = {},
273  bool code_is_preserved = false)
274 {
275  handle_t new_lib_handle;
276  auto raw_link_opts = link::detail_::marshal(link_options);
277  struct {
278  detail_::option_t options[1];
279  void* values[1];
280  unsigned count;
281  } raw_opts = { { CU_LIBRARY_BINARY_IS_PRESERVED }, { &code_is_preserved }, 1 };
282  auto status = creator(
283  &new_lib_handle, data_source,
284  const_cast<link::detail_::option_t*>(raw_link_opts.options()),
285  const_cast<void**>(raw_link_opts.values()), raw_link_opts.count(),
286  raw_opts.options, raw_opts.values, raw_opts.count
287  );
288  throw_if_error_lazy(status,
289  ::std::string("Failed loading a compiled CUDA code library from ") + error_string_generator());
290  bool do_take_ownership{true};
291  return detail_::wrap(new_lib_handle, do_take_ownership);
292 }
293 
294 } // namespace detail_
295 
307 inline library_t load_from_file(
309  const char* path,
310  const link::options_t& link_options = {},
311  bool code_is_preserved = false)
312 {
313  return detail_::create(
314  cuLibraryLoadFromFile, path,
315  [path]() { return ::std::string("file ") + path; },
316  link_options, code_is_preserved);
317 }
318 
319 inline library_t load_from_file(
320  const ::std::string& path,
321  const link::options_t& link_options = {},
322  bool code_is_preserved = false)
323 {
324  return load_from_file(path.c_str(), link_options, code_is_preserved);
325 }
326 
327 #if __cplusplus >= 201703L
328 
329 inline library_t load_from_file(
330  const ::std::filesystem::path& path,
331  const link::options_t& link_options = {},
332  bool code_is_preserved = false)
333 {
334  return load_from_file(path.c_str(), link_options, code_is_preserved);
335 }
336 
337 #endif
338 
340 namespace detail_ {
341 
342 inline library_t wrap(handle_t handle, bool take_ownership) noexcept
343 {
344  return library_t{handle, take_ownership};
345 }
346 
347 } // namespace detail_
348 
356 inline library_t create(
357  const void* module_data,
358  const link::options_t& link_options = {},
359  bool code_is_preserved = false)
360 {
361  return detail_::create(
362  cuLibraryLoadData, module_data,
363  [module_data]() { return ::std::string("data at ") + cuda::detail_::ptr_as_hex(module_data); },
364  link_options, code_is_preserved);
365 }
366 
367 
368 // TODO: Use an optional to reduce the number of functions here... when the
369 // library starts requiring C++14.
370 
371 namespace detail_ {
372 
373 inline ::std::string identify(const library_t& library)
374 {
375  return identify(library.handle());
376 }
377 
378 } // namespace detail_
379 
380 template <typename ContiguousContainer,
381  cuda::detail_::enable_if_t<cuda::detail_::is_kinda_like_contiguous_container<ContiguousContainer>::value, bool> >
382 library_t create(
383  ContiguousContainer library_data,
384  optional<link::options_t> link_options,
385  bool code_is_preserved)
386 {
387  return create(library_data.data(), link_options, code_is_preserved);
388 }
389 
390 } // namespace library
391 
392 } // namespace cuda
393 
394 #endif // CUDA_VERSION >= 12000
395 
396 #endif // CUDA_API_WRAPPERS_LIBRARY_HPP_
Definitions and functionality wrapping CUDA APIs.
Definition: array.hpp:22
device::id_t count()
Get the number of CUDA devices usable on the system (with the current CUDA library and kernel driver)...
Definition: miscellany.hpp:63
detail_::region_helper< memory::region_t > region_t
A child class of the generic region_t with some managed-memory-specific functionality.
Definition: memory.hpp:1960
STL namespace.
module_t load_from_file(const context_t &context, const char *path)
Load a module from an appropriate compiled or semi-compiled file, allocating all relevant resources f...
Definition: module.hpp:317
#define throw_if_error_lazy(status__,...)
A macro for only throwing an error if we've failed - which also ensures no string is constructed unle...
Definition: error.hpp:316
CUarray handle_t
Raw CUDA driver handle for arrays (of any dimension)
Definition: array.hpp:34
array_t< T, NumDimensions > wrap(device::id_t device_id, context::handle_t context_handle, handle_t handle, dimensions_t< NumDimensions > dimensions) noexcept
Wrap an existing CUDA array in an array_t instance.
Definition: array.hpp:264
void * as_pointer(device::address_t address) noexcept
Definition: types.hpp:700
CUdeviceptr address_t
The numeric type which can represent the range of memory addresses on a CUDA device.
Definition: types.hpp:672