// Typed wrappers for CUDA graph template nodes: node kinds, per-kind traits
// (driver inserter/setter/getter functions and parameter marshalling), and the
// typed_node_t proxy class. Everything here is guarded by CUDA_VERSION >= 10000.
7 #ifndef CUDA_API_WRAPPERS_TYPED_NODE_HPP 8 #define CUDA_API_WRAPPERS_TYPED_NODE_HPP 10 #if CUDA_VERSION >= 10000 13 #include "../detail/for_each_argument.hpp" 14 #include "../error.hpp" 15 #include "../device.hpp" 16 #include "../event.hpp" 17 #include "../kernel.hpp" 18 #include "../launch_configuration.hpp" 19 #include "../memory_pool.hpp" 24 #include <type_traits> 38 template<
typename... Ts>
/// Builds a vector of type-erased pointers, one per argument, in argument
/// order. The i-th element is the address of the i-th argument object itself.
///
/// @param kernel_arguments the objects whose addresses are collected
/// @return a vector of `void*` holding the addresses of the arguments
///
/// NOTE(review): the arguments are bound by const reference and only their
/// addresses are stored. The returned pointers are valid only while the
/// caller's argument objects are alive; passing temporaries leaves the
/// pointers dangling once the calling full-expression ends.
/// NOTE(review): `const_cast` strips const only to fit the `void*` element
/// type. Presumably nothing writes through these pointers — confirm.
/// NOTE(review): with an empty pack, `{ { } }` appears to produce a
/// one-element vector holding nullptr rather than an empty vector — confirm
/// that this is harmless for zero-parameter kernels.
39 ::std::vector<void*> make_kernel_argument_pointers(
const Ts&... kernel_arguments)
// Outer braces: list-initialization of the return value; inner braces: the
// initializer list of per-argument pointers, expanded over the pack.
44 return { {
const_cast<void *
>(
reinterpret_cast<const void*
>(&kernel_arguments)) ... } };
/// Kinds of CUDA graph nodes. Each enumerator's value is the corresponding
/// driver CUgraphNodeType constant, and the underlying type matches
/// CUgraphNodeType's, so values can be static_cast to CUgraphNodeType directly.
/// Several kinds also have a shorter alias (e.g. `memcpy` for `memory_copy`).
/// Kinds introduced in later CUDA versions exist only when CUDA_VERSION allows.
49 enum class kind_t : ::std::underlying_type<CUgraphNodeType>::type {
50 kernel_launch = CU_GRAPH_NODE_TYPE_KERNEL,
51 memory_copy = CU_GRAPH_NODE_TYPE_MEMCPY, memcpy = memory_copy,
52 memory_set = CU_GRAPH_NODE_TYPE_MEMSET, memset = memory_set,
53 host_function_call = CU_GRAPH_NODE_TYPE_HOST,
54 child_graph = CU_GRAPH_NODE_TYPE_GRAPH,
55 empty = CU_GRAPH_NODE_TYPE_EMPTY,
// Event wait / event record nodes: CUDA >= 11.1 (11010).
56 #if CUDA_VERSION >= 11010 57 wait = CU_GRAPH_NODE_TYPE_WAIT_EVENT, wait_on_event =
wait,
58 event = CU_GRAPH_NODE_TYPE_EVENT_RECORD, record_event = event,
// Memory allocation / free nodes: CUDA >= 11.4 (11040).
// Batch memory-op (barrier) nodes: CUDA >= 11.7 (11070).
59 #endif // CUDA_VERSION >= 11010 65 #if CUDA_VERSION >= 11040 66 memory_allocation = CU_GRAPH_NODE_TYPE_MEM_ALLOC, malloc = memory_allocation,
67 memory_free = CU_GRAPH_NODE_TYPE_MEM_FREE, memfree = memory_free,
68 #endif // CUDA_VERSION >= 11040 69 #if CUDA_VERSION >= 11070 71 memory_barrier = CU_GRAPH_NODE_TYPE_BATCH_MEM_OP,
// Conditional nodes: gated here on CUDA >= 12.3 (12030).
// NOTE(review): kind_traits<kind_t::conditional> in this file is gated on
// CUDA_VERSION >= 12040, so with CUDA 12.3.x this enumerator exists without
// traits — confirm whether the two gates should match.
73 #if CUDA_VERSION >= 12030 74 conditional = CU_GRAPH_NODE_TYPE_CONDITIONAL
// Conditional-node support (CUDA >= 12.3): aliases for the driver's
// conditional-handle types, a helper that creates a conditional handle, and
// the parameter bundle used when adding a conditional node.
75 #endif // CUDA_VERSION >= 12030 78 #if CUDA_VERSION >= 12030 79 namespace conditional {
// Raw driver handle of a conditional value attached to a graph.
81 using handle_t = CUgraphConditionalHandle;
// Driver enumeration of conditional node types.
82 using kind_t = CUgraphConditionalNodeType;
// Type of the initial value a conditional handle is created with.
83 using default_value_t = unsigned;
87 using flags_t = unsigned;
// Creates a new conditional handle for a graph template and context,
// initialized with `default_value` and passing CU_GRAPH_COND_ASSIGN_DEFAULT.
// The function's name and leading parameters are not shown in this listing.
// Presumably this is `conditional::detail_::create`, which is how the
// conditional node traits call it — confirm.
92 default_value_t default_value)
95 auto status = cuGraphConditionalHandleCreate(
96 &result, graph_template, context_handle, default_value, CU_GRAPH_COND_ASSIGN_DEFAULT);
// Parameters for adding a conditional node. Either `handle` names an existing
// conditional handle, or graph_template_handle + context_handle +
// default_value are all set so that a new handle can be created.
// NOTE(review): the conditional traits' marshal() also reads a `kind` member,
// which is not visible in this listing — confirm it is declared here.
103 struct parameters_t {
105 optional<handle_t> handle;
106 optional<context::handle_t> context_handle;
107 optional<template_::handle_t> graph_template_handle;
108 optional<default_value_t> default_value;
// Primary kind_traits template; each node kind gets its own specialization.
112 #endif // CUDA_VERSION >= 12030 120 template<kind_t Kind>
/// Traits for empty nodes, which carry no parameters (parameters_type is
/// nullopt_t). With CUDA >= 12.3 these nodes are added and updated through
/// the generic cuGraphAddNode / cuGraphNodeSetParams API, using a
/// CUgraphNodeParams structure passed by pointer.
124 struct kind_traits<kind_t::empty> {
125 static constexpr
const auto name =
"empty";
126 #if CUDA_VERSION >= 12030 127 using raw_parameters_type = CUgraphNodeParams;
128 static constexpr
const bool inserter_takes_context =
false;
129 static constexpr
const bool inserter_takes_params_by_ptr =
true;
130 using parameters_type = cuda::nullopt_t;
132 static constexpr
const auto inserter = cuGraphAddNode;
133 static constexpr
const auto setter = cuGraphNodeSetParams;
// No getter: querying parameters is not supported for this kind.
134 static constexpr
const auto getter =
nullptr;
// Produces an all-zero CUgraphNodeParams whose `type` is the empty-node type.
// The (empty) parameters argument is ignored.
136 static raw_parameters_type marshal(
const parameters_type&) {
137 raw_parameters_type raw_params;
// Zero-fill first so every field not explicitly set is well-defined.
138 ::std::memset(&raw_params, 0,
sizeof(raw_parameters_type));
139 raw_params.type =
static_cast<CUgraphNodeType
>(kind_t::empty);
/// Traits for child-graph nodes: the node's parameter is an entire graph
/// template (template_t), and the inserter takes the raw parameter by value
/// rather than by pointer. raw_parameters_type (and any setter) is declared
/// on lines not shown in this listing.
142 #endif // CUDA_VERSION >= 12030 146 struct kind_traits<kind_t::child_graph> {
147 static constexpr
const auto name =
"child graph";
149 static constexpr
const bool inserter_takes_params_by_ptr =
false;
150 static constexpr
const bool inserter_takes_context =
false;
151 using parameters_type = template_t;
152 static constexpr
const auto inserter = cuGraphAddChildGraphNode;
154 static constexpr
const auto getter = cuGraphChildGraphNodeGetGraph;
// Only with CUDA >= 11.1. Judging by its name, this presumably updates the
// node within an instantiated (executable) graph — confirm.
155 #if CUDA_VERSION >= 11010 156 static constexpr
const auto instance_setter = cuGraphExecChildGraphNodeSetParams;
// Declared here only; the definition is elsewhere.
159 static raw_parameters_type marshal(
const parameters_type& params);
/// Traits for event-record nodes (CUDA >= 11.1). The parameter is an event_t.
/// Its raw form (declared on a line not shown; presumably the driver event
/// handle type) is passed to the inserter by value.
162 #if CUDA_VERSION >= 11010 164 struct kind_traits<kind_t::record_event> {
165 static constexpr
const auto name =
"record event";
167 static constexpr
const bool inserter_takes_context =
false;
168 static constexpr
const bool inserter_takes_params_by_ptr =
false;
169 using parameters_type = event_t;
170 static constexpr
const auto inserter = cuGraphAddEventRecordNode;
171 static constexpr
const auto setter = cuGraphEventRecordNodeSetEvent;
172 static constexpr
const auto getter = cuGraphEventRecordNodeGetEvent;
173 static constexpr
const auto instance_setter = cuGraphExecEventRecordNodeSetEvent;
// Marshalling an event means just taking its raw handle.
175 static raw_parameters_type marshal(
const parameters_type& params)
177 return params.handle();
/// Traits for event-wait nodes (CUDA >= 11.1). Same shape as the record-event
/// traits: the parameter is an event_t, passed to the driver by value in its
/// raw form.
182 struct kind_traits<kind_t::wait_on_event> {
183 static constexpr
const auto name =
"wait on event";
185 static constexpr
const bool inserter_takes_context =
false;
186 static constexpr
const bool inserter_takes_params_by_ptr =
false;
187 using parameters_type = event_t;
188 static constexpr
const auto inserter = cuGraphAddEventWaitNode;
189 static constexpr
const auto setter = cuGraphEventWaitNodeSetEvent;
190 static constexpr
const auto getter = cuGraphEventWaitNodeGetEvent;
191 static constexpr
const auto instance_setter = cuGraphExecEventWaitNodeSetEvent;
// Marshalling an event means just taking its raw handle.
193 static raw_parameters_type marshal(
const parameters_type& params)
195 return params.handle();
/// Traits for host-function-call nodes. Raw parameters are a
/// CUDA_HOST_NODE_PARAMS passed by pointer. The wrapper-level parameters hold
/// a function pointer and a user-data pointer; the members are declared on
/// lines not shown in this listing.
199 #endif // CUDA_VERSION >= 11010 202 struct kind_traits<kind_t::host_function_call> {
203 static constexpr
const auto name =
"host function call";
204 using raw_parameters_type = CUDA_HOST_NODE_PARAMS;
205 static constexpr
const bool inserter_takes_context =
false;
206 static constexpr
const bool inserter_takes_params_by_ptr =
true;
207 struct parameters_type {
211 static constexpr
const auto inserter = cuGraphAddHostNode;
212 static constexpr
const auto setter = cuGraphHostNodeSetParams;
213 static constexpr
const auto getter = cuGraphHostNodeGetParams;
214 static constexpr
const auto instance_setter = cuGraphExecHostNodeSetParams;
// Aggregate-initializes the raw struct with (function pointer, user data), in that order.
216 static raw_parameters_type marshal(
const parameters_type& params)
218 return { params.function_ptr, params.user_data };
// Convenience overload taking a (callback, void*) pair; the elements map to
// the raw struct in the same order as the overload above.
220 static raw_parameters_type marshal(::std::pair<stream::callback_t, void*> param_pair)
222 return { param_pair.first, param_pair.second };
/// Traits for kernel-launch nodes. The wrapper-level parameters hold:
/// - a kernel (member declared on a line not shown);
/// - a launch configuration;
/// - the argument pointers already marshalled into a vector of `void*`.
/// The raw form is a CUDA_KERNEL_NODE_PARAMS, passed by pointer.
227 struct kind_traits<kind_t::kernel_launch> {
228 static constexpr
const auto name =
"kernel launch";
229 using raw_parameters_type = CUDA_KERNEL_NODE_PARAMS;
230 static constexpr
const bool inserter_takes_context =
false;
231 static constexpr
const bool inserter_takes_params_by_ptr =
true;
232 struct parameters_type {
234 launch_configuration_t launch_config;
235 ::std::vector<void*> marshalled_arguments;
237 static constexpr
const auto inserter = cuGraphAddKernelNode;
238 static constexpr
const auto setter = cuGraphKernelNodeSetParams;
239 static constexpr
const auto getter = cuGraphKernelNodeGetParams;
240 static constexpr
const auto instance_setter = cuGraphExecKernelNodeSetParams;
// Copies the kernel handle, grid/block dimensions and dynamic shared memory
// size into the raw struct. The argument list is passed via kernelParams, and
// `extra` is left null.
// NOTE(review): raw_params is not value-initialized. Any field of
// CUDA_KERNEL_NODE_PARAMS not assigned here (or on the lines not shown) would
// be indeterminate — confirm every field is set.
242 static raw_parameters_type marshal(
const parameters_type& params)
245 raw_parameters_type raw_params;
251 raw_params.func = params.kernel.handle();
253 raw_params.gridDimX = params.launch_config.dimensions.grid.x;
254 raw_params.gridDimY = params.launch_config.dimensions.grid.y;
255 raw_params.gridDimZ = params.launch_config.dimensions.grid.z;
256 raw_params.blockDimX = params.launch_config.dimensions.block.x;
257 raw_params.blockDimY = params.launch_config.dimensions.block.y;
258 raw_params.blockDimZ = params.launch_config.dimensions.block.z;
259 raw_params.sharedMemBytes = params.launch_config.dynamic_shared_memory_size;
// kernelParams points into params.marshalled_arguments' own buffer, so the
// raw struct is only usable while `params` is alive and that vector is not
// reallocated. The const_cast adapts the const data() pointer to the raw
// field's type.
260 raw_params.kernelParams =
const_cast<decltype(raw_params.kernelParams)
>(params.marshalled_arguments.data());
261 raw_params.extra =
nullptr;
/// Traits for memory-allocation nodes (CUDA >= 11.4). The parameters are a
/// (device, size in bytes) pair; the raw form is a CUDA_MEM_ALLOC_NODE_PARAMS,
/// passed by pointer. No setter is among the lines shown.
266 #if CUDA_VERSION >= 11040 268 struct kind_traits<kind_t::memory_allocation> {
269 static constexpr
const auto name =
"memory allocation";
270 using raw_parameters_type = CUDA_MEM_ALLOC_NODE_PARAMS;
271 static constexpr
const bool inserter_takes_context =
false;
272 static constexpr
const bool inserter_takes_params_by_ptr =
true;
273 using parameters_type = ::std::pair<device_t, size_t>;
275 static constexpr
const auto inserter = cuGraphAddMemAllocNode;
277 static constexpr
const auto getter = cuGraphMemAllocNodeGetParams;
// Builds raw allocation parameters:
// - pool properties for device `params.first`, with a non-exportable shared
//   handle kind;
// - no additional access descriptors;
// - `params.second` bytes requested.
279 static raw_parameters_type marshal(
const parameters_type& params)
281 static constexpr
const auto no_export_handle_kind = memory::pool::shared_handle_kind_t::no_export;
282 raw_parameters_type raw_params;
283 raw_params.poolProps = memory::pool::detail_::create_raw_properties<no_export_handle_kind>(params.first.id());
285 raw_params.accessDescs =
nullptr;
286 raw_params.accessDescCount = 0;
287 raw_params.bytesize = params.second;
/// Traits for memset nodes. The raw form is a CUDA_MEMSET_NODE_PARAMS passed
/// by pointer, and the inserter additionally takes a context. The
/// wrapper-level parameters include:
/// - a region;
/// - a value (both declared on lines not shown);
/// - the width of each element in bytes.
294 #endif // CUDA_VERSION >= 11040 297 struct kind_traits<kind_t::memory_set> {
299 static constexpr
const auto name =
"memory set";
300 using raw_parameters_type = CUDA_MEMSET_NODE_PARAMS;
301 static constexpr
const bool inserter_takes_context =
true;
302 static constexpr
const bool inserter_takes_params_by_ptr =
true;
303 struct parameters_type {
305 size_t width_in_bytes;
308 static constexpr
const auto inserter = cuGraphAddMemsetNode;
310 static constexpr
const auto getter = cuGraphMemsetNodeGetParams;
// Validates the element width and value, then builds a one-row (height 1)
// memset covering the region, with region.size() / width_in_bytes elements.
// @throws std::invalid_argument if the width exceeds sizeof(value), or the
//         value does not fit in width_in_bytes bytes.
// NOTE(review): the first error message opens "(" but never closes it.
// NOTE(review): `1lu << (width_in_bytes * CHAR_BIT)` is undefined behavior
// when width_in_bytes equals sizeof(unsigned long). The only guard is
// width <= sizeof(value). For example, if `value` is 4 bytes wide and
// unsigned long is 32 bits (LLP64), width 4 shifts by the full type width.
// Confirm the type of `value`.
// NOTE(review): width_in_bytes == 0 is not rejected. The shift is then by 0,
// and the width computation divides by zero. A region size that is not a
// multiple of the width is silently truncated.
312 static raw_parameters_type marshal(
const parameters_type& params)
314 static constexpr
const size_t max_width =
sizeof(parameters_type::value);
315 if (params.width_in_bytes > max_width) {
316 throw ::std::invalid_argument(
"Unsupported memset value width (maximum is " + ::std::to_string(max_width));
// Smallest value that needs more than width_in_bytes bytes.
318 const unsigned long min_overwide_value = 1lu << (params.width_in_bytes * CHAR_BIT);
319 if (static_cast<unsigned long>(params.value) >= min_overwide_value) {
320 throw ::std::invalid_argument(
"Memset value exceeds specified width");
322 CUDA_MEMSET_NODE_PARAMS raw_params;
325 raw_params.height = 1u;
326 raw_params.value = params.value;
327 raw_params.elementSize =
static_cast<decltype(raw_params.elementSize)
>(params.width_in_bytes);
// Number of elements (not bytes) in the single row.
328 raw_params.width = params.region.size() / params.width_in_bytes;
/// Traits for memory-free nodes (CUDA >= 11.4). The wrapper-level parameter is
/// the pointer to free, as a `void*`. The raw form is a CUdeviceptr passed to
/// the inserter by value. The marshal() body is not shown in this listing.
333 #if CUDA_VERSION >= 11040 335 struct kind_traits<kind_t::memory_free> {
336 static constexpr
const auto name =
"memory free";
337 using raw_parameters_type = CUdeviceptr;
338 static constexpr
const bool inserter_takes_context =
false;
339 static constexpr
const bool inserter_takes_params_by_ptr =
false;
340 using parameters_type =
void*;
341 static constexpr
const auto inserter = cuGraphAddMemFreeNode;
344 static constexpr
const auto getter = cuGraphMemFreeNodeGetParams;
346 static raw_parameters_type marshal(
const parameters_type& params)
/// Traits for memory-barrier nodes (CUDA >= 11.7), implemented as a
/// batch-memory-op node holding a single barrier operation. The parameters are
/// a (context, barrier scope) pair; the raw form is a
/// CUDA_BATCH_MEM_OP_NODE_PARAMS passed by pointer.
351 #endif // CUDA_VERSION >= 11040 353 #if CUDA_VERSION >= 11070 355 struct kind_traits<kind_t::memory_barrier> {
356 static constexpr
const auto name =
"memory barrier";
357 using raw_parameters_type = CUDA_BATCH_MEM_OP_NODE_PARAMS;
358 static constexpr
const bool inserter_takes_context =
false;
359 static constexpr
const bool inserter_takes_params_by_ptr =
true;
360 using parameters_type = ::std::pair<context_t, cuda::memory::barrier_scope_t>;
361 static constexpr
const auto inserter = cuGraphAddBatchMemOpNode;
362 static constexpr
const auto setter = cuGraphBatchMemOpNodeSetParams;
363 static constexpr
const auto getter = cuGraphBatchMemOpNodeGetParams;
// Builds a one-operation batch (count = 1, flags = 0) in the given context.
// The single operation is a barrier whose flags are the requested scope.
// NOTE(review): paramArray is set to the address of the local
// `memory_barrier_op`. If raw_params is returned by value (on the lines not
// shown), that pointer dangles as soon as marshal() returns, and the driver
// would later read a dead stack object. The operation needs storage that
// outlives the raw parameters.
365 static raw_parameters_type marshal(
const parameters_type& params)
367 auto const & context = params.first;
368 raw_parameters_type raw_params;
369 raw_params.count = 1;
370 raw_params.ctx = context.handle();
371 raw_params.flags = 0;
372 CUstreamBatchMemOpParams memory_barrier_op;
// Both the top-level and the memoryBarrier-specific operation fields are set
// to BARRIER. These presumably overlap within a union — confirm.
373 memory_barrier_op.operation = CU_STREAM_MEM_OP_BARRIER;
374 memory_barrier_op.memoryBarrier.operation = CU_STREAM_MEM_OP_BARRIER;
375 auto const & scope = params.second;
376 memory_barrier_op.memoryBarrier.flags =
static_cast<unsigned>(scope);
377 raw_params.paramArray = &memory_barrier_op;
/// Traits for memcpy nodes. The wrapper-level parameters are a 3-D
/// copy_parameters_t; the raw form is a CUDA_MEMCPY3D passed by pointer, and
/// the inserter additionally takes a context.
381 #endif // CUDA_VERSION >= 11070 384 struct kind_traits<kind_t::memory_copy> {
385 static constexpr
const auto name =
"memory copy";
386 using raw_parameters_type = CUDA_MEMCPY3D;
387 static constexpr
const bool inserter_takes_context =
true;
388 static constexpr
const bool inserter_takes_params_by_ptr =
true;
390 using parameters_type = memory::copy_parameters_t<3>;
392 static constexpr
const auto inserter = cuGraphAddMemcpyNode;
393 static constexpr
const auto setter = cuGraphMemcpyNodeSetParams;
394 static constexpr
const auto getter = cuGraphMemcpyNodeGetParams;
// Reinterprets the wrapper parameters as a CUDA_MEMCPY3D and returns a copy.
// NOTE(review): this assumes copy_parameters_t<3> is layout-compatible with
// CUDA_MEMCPY3D — confirm, e.g. with a static_assert on sizes.
// NOTE(review): `params_ptr` is a reference, not a pointer. If the lines not
// shown don't write through it, the const_cast is unnecessary, because
// reinterpreting as a const reference suffices for a by-value return. If they
// do write through it, they mutate the caller's object via a const reference.
396 static raw_parameters_type marshal(
const parameters_type& params) {
397 auto& params_ptr =
const_cast<parameters_type&
>(params);
403 return reinterpret_cast<CUDA_MEMCPY3D&
>(params_ptr);
/// Traits for conditional nodes, gated on CUDA_VERSION >= 12040. They are
/// added through the generic cuGraphAddNode / cuGraphNodeSetParams API with a
/// CUgraphNodeParams passed by pointer, and have no getter.
/// NOTE(review): the `conditional` enumerator and namespace are gated on 12030,
/// but these traits on 12040 — confirm which version is intended.
407 #if CUDA_VERSION >= 12040 409 struct kind_traits<kind_t::conditional> {
410 static constexpr
const auto name =
"conditional";
411 using raw_parameters_type = CUgraphNodeParams;
412 static constexpr
const bool inserter_takes_context =
false;
413 static constexpr
const bool inserter_takes_params_by_ptr =
true;
414 using parameters_type = conditional::parameters_t;
416 static constexpr
const auto inserter = cuGraphAddNode;
417 static constexpr
const auto setter = cuGraphNodeSetParams;
// No getter: querying parameters is not supported for this kind.
418 static constexpr
const auto getter =
nullptr;
// Builds a zero-filled CUgraphNodeParams for a conditional node:
// - `type` is set to the conditional node type;
// - the conditional sub-struct gets the kind, context, handle and size = 1.
// The context is params.context_handle if set, else the current context.
// The handle is params.handle if set; otherwise a new handle is created from
// the template, context and default value.
// @throws std::invalid_argument if neither a handle nor all creation
//         arguments are provided.
// NOTE(review): the fallback to the current context only helps when a
// handle is given. The creation path requires params.context_handle and uses
// it directly, rather than the already-resolved raw_params.conditional.ctx.
420 static raw_parameters_type marshal(
const parameters_type& params) {
421 raw_parameters_type raw_params;
422 ::std::memset(&raw_params, 0,
sizeof(raw_parameters_type));
423 raw_params.type =
static_cast<CUgraphNodeType
>(kind_t::conditional);
424 raw_params.conditional.type = params.kind;
425 raw_params.conditional.ctx = params.context_handle ?
426 params.context_handle.value() : context::current::detail_::get_handle();
427 if (not params.handle and
428 not (params.graph_template_handle and params.context_handle and params.default_value)) {
429 throw ::std::invalid_argument(
430 "Conditional node creation parameters specify neither a pre-existing conditional handle, " 431 "nor the arguments required for its creation");
433 raw_params.conditional.handle = params.handle ? params.handle.value() :
434 conditional::detail_::create(
435 params.graph_template_handle.value(), params.context_handle.value(), params.default_value.value());
436 raw_params.conditional.size = 1;
// NOTE(review): the #endif on the next line closes `#if CUDA_VERSION >= 12040`,
// but its comment reads 12300.
//
// maybe_add_ptr: adapts marshalled raw parameters to the form a kind's
// inserter expects, selected via enable_if on
// kind_traits<Kind>::inserter_takes_params_by_ptr:
// - if that trait is true, it returns a (non-const) pointer to the argument;
// - otherwise it returns the argument itself, by const reference.
// The returned pointer/reference is only valid while `raw_params` is alive.
440 #endif // CUDA_VERSION >= 12300 442 template <kind_t Kind, typename = typename ::std::enable_if<kind_traits<Kind>::inserter_takes_params_by_ptr>::type>
443 typename kind_traits<Kind>::raw_parameters_type *
444 maybe_add_ptr(
const typename kind_traits<Kind>::raw_parameters_type& raw_params)
// const_cast: the driver inserter functions take non-const pointers.
// Presumably they do not modify the parameters — confirm.
446 return const_cast<typename kind_traits<Kind>::raw_parameters_type *
>(&raw_params);
449 template <kind_t Kind, typename = typename ::std::enable_if<not kind_traits<Kind>::inserter_takes_params_by_ptr>::type>
450 const typename kind_traits<Kind>::raw_parameters_type&
451 maybe_add_ptr(
const typename kind_traits<Kind>::raw_parameters_type& raw_params) {
return raw_params; }
/// The wrapper-level (non-raw) parameter type for nodes of kind `Kind`, as
/// defined by that kind's traits.
455 template <kind_t Kind>
using parameters_t =
typename detail_::kind_traits<Kind>::parameters_type;
// Declarations whose bodies are on lines not shown in this listing.
457 template <kind_t Kind>
460 template <kind_t Kind>
/// A graph template node proxy specialized for a single node kind. It extends
/// node_t with a cached copy of the node's wrapper-level parameters
/// (params_), and with per-kind querying and setting of those parameters via
/// kind_traits<Kind>.
463 template <kind_t Kind>
464 class typed_node_t :
public node_t {
465 using parameters_type = parameters_t<Kind>;
467 using traits = detail_::kind_traits<Kind>;
468 using raw_parameters_type =
typename traits::raw_parameters_type;
469 static constexpr
char const *
const name = traits::name;
// Returns the parameters cached in this proxy; unavailable for empty nodes
// (static_assert). The body's return is not shown; presumably it returns
// params_ without querying the driver — confirm.
475 const parameters_type& parameters() const noexcept
477 static_assert(Kind != kind_t::empty,
"Empty CUDA graph nodes don't have parameters");
// Queries the driver for the node's current parameters, refreshes the cache,
// and fails for kinds whose getter is nullptr. It is const yet writes
// params_, hence params_ is mutable.
// NOTE(review): the kind_traits specializations in this file define `getter`
// and `setter`, not `param_getter` / `param_setter`, and none define
// `unmarshal`. Instantiating this member, or set_parameters(), would not
// compile unless those names exist elsewhere — confirm.
// NOTE(review): for kinds whose getter is `nullptr` (type nullptr_t), the
// runtime `if` below does not stop the subsequent call expression from being
// instantiated and rejected by the compiler.
481 parameters_type requery_parameters()
const 483 typename traits::raw_parameters_type raw_params;
484 if (traits::param_getter ==
nullptr) {
486 "Querying parameters is not supported for this kind of node: " + node::detail_::identify(*
this));
488 auto status = traits::param_getter(handle(), &raw_params);
490 params_ = traits::unmarshal(raw_params);
// Marshals `parameters` and passes them to the kind's setter; unavailable for
// empty nodes (static_assert).
// NOTE(review): a pointer to the marshalled parameters is always passed, even
// for kinds whose traits set inserter_takes_params_by_ptr = false (e.g. the
// event kinds); maybe_add_ptr is not used here. Also, several kinds
// (memory_allocation, memory_set, memory_free) show no setter at all —
// confirm.
494 void set_parameters(parameters_t<Kind> parameters)
496 static_assert(Kind != kind_t::empty,
"Empty CUDA graph nodes don't have parameters");
497 auto marshalled_params = traits::marshal(parameters);
498 auto status = traits::param_setter(handle(), &marshalled_params);
// Wraps an existing node handle in a given graph template, taking ownership
// of `parameters` as the cached copy.
506 typed_node_t(
template_::handle_t graph_template_handle, handle_type handle, parameters_type parameters) noexcept
507 : node_t(graph_template_handle, handle), params_(::std::move(parameters)) { }
510 typed_node_t(
const typed_node_t<Kind>&) =
default;
511 typed_node_t(typed_node_t<Kind>&&) noexcept = default;
// Assignment: takes the source by value, then assigns the node_t base part
// and the cached parameters.
// NOTE(review): params_ is copied from `other` rather than moved. For kinds
// whose parameters own memory (e.g. the kernel-launch argument vector), this
// copy may allocate and throw inside a noexcept function, which would call
// std::terminate.
513 typed_node_t<Kind>& operator=(typed_node_t<Kind> other) noexcept
515 node_t::operator=(other);
516 params_ = other.params_;
// Cached wrapper-level parameters. mutable so that the const
// requery_parameters() can refresh them.
520 mutable parameters_t<Kind> params_;
/// Wraps an existing node of kind `Kind` (graph handle + node handle) in a
/// typed_node_t, moving `parameters` into its cached copy. The function's
/// signature is on lines not shown in this listing.
523 template <kind_t Kind>
526 return typed_node_t<Kind>{ graph_handle, handle, ::std::move(parameters) };
/// Bundles a kernel (parameter declared on a line not shown), a launch
/// configuration and an already-built vector of argument pointers into
/// kernel-launch node parameters.
/// NOTE(review): `argument_pointers` is a const reference, so
/// `::std::move(argument_pointers)` has no effect and the vector is copied.
/// Taking it by value would allow callers to move it in.
531 inline node::parameters_t<node::kind_t::kernel_launch>
532 make_launch_primed_kernel(
534 launch_configuration_t launch_config,
535 const ::std::vector<void*>& argument_pointers)
537 return { ::std::move(kernel), ::std::move(launch_config), ::std::move(argument_pointers) };
/// Variadic convenience overload: builds kernel-launch node parameters from a
/// kernel (parameter declared on a line not shown), a launch configuration,
/// and the kernel arguments themselves, collecting the arguments' addresses.
/// NOTE(review): this calls `make_kernel_arg_ptrs`, while the helper declared
/// at the top of this header is `make_kernel_argument_pointers` — confirm the
/// called name exists.
/// NOTE(review): if that helper stores the arguments' addresses (as
/// make_kernel_argument_pointers does), then `kernel_arguments` must outlive
/// every use of the returned parameters. Temporaries passed here would dangle.
540 template <
typename... KernelParameters>
541 node::parameters_t<node::kind_t::kernel_launch>
542 make_launch_primed_kernel(
544 launch_configuration_t launch_config,
545 const KernelParameters&... kernel_arguments)
549 ::std::move(launch_config),
550 make_kernel_arg_ptrs(kernel_arguments...)
558 #endif // CUDA_VERSION >= 10000 560 #endif //CUDA_API_WRAPPERS_TYPED_NODE_HPP Definitions and functionality wrapping CUDA APIs.
Definition: array.hpp:22
detail_::region_helper< memory::region_t > region_t
A child class of the generic region_t with some managed-memory-specific functionality.
Definition: memory.hpp:1960
void wait(const event_t &event)
Have the calling thread wait - either busy-waiting or blocking - and return only after this event has...
Definition: event.hpp:467
CUevent handle_t
The CUDA driver's raw handle for events.
Definition: types.hpp:217
size_t dimensionality_t
The index or number of dimensions of an entity (as opposed to the extent in any dimension) - typicall...
Definition: types.hpp:85
A (base?) class for exceptions raised by CUDA code; these errors are thrown by essentially all CUDA R...
Definition: error.hpp:271
CUstreamCallback callback_t
The CUDA driver's raw handle for a host-side callback function.
Definition: types.hpp:257
#define throw_if_error_lazy(status__,...)
A macro for only throwing an error if we've failed - which also ensures no string is constructed unle...
Definition: error.hpp:316
CUarray handle_t
Raw CUDA driver handle for arrays (of any dimension)
Definition: array.hpp:34
array_t< T, NumDimensions > wrap(device::id_t device_id, context::handle_t context_handle, handle_t handle, dimensions_t< NumDimensions > dimensions) noexcept
Wrap an existing CUDA array in an array_t instance.
Definition: array.hpp:264
address_t address(const void *device_ptr) noexcept
Definition: types.hpp:682
Graph template node proxy (base-)class node_t and supporting code.