// NOTE(review): this text looks like a documentation-page extract of muda's
// launch_base.h. Original source line numbers are fused into the content, and
// braces / class heads / some declaration heads are missing. It does not
// compile as-is; restore it from the real header before editing the code.
3 #include <cuda_runtime.h> 4 #include <cuda_runtime_api.h> 5 #include <device_launch_parameters.h> 10 #include <cooperative_groups.h> 12 #include <cuda_profiler_api.h> 13 #include <nvtx3/nvToolsExt.h> 14 #include <nvtx3/nvToolsExtCuda.h> 15 #include <muda/type_traits/type_modifier.h> 16 #include <muda/tools/launch_info_cache.h> 18 #include <muda/check/check_cuda_errors.h> 19 #include <muda/muda_def.h> 20 #include <muda/launch/event.h> 21 #include <muda/launch/kernel_tag.h> 27 inline void stream_error_callback(cudaStream_t stream, cudaError error,
void* userdata)
// stream_error_callback: a (cudaStream_t, cudaError, void*) trampoline.
// It reinterprets `userdata` as a pointer to a
// std::function<void(cudaStream_t, cudaError)> and invokes that function
// with the same stream and error values.
// In this extract the function braces and the declarator that receives the
// cast result (presumably `auto callback =` on a dropped line) are missing.
// NOTE(review): the signature matches cudaStreamCallback_t, so this is
// presumably passed to cudaStreamAddCallback by LaunchCore::callback().
// Confirm the pointed-to std::function outlives the callback and is freed
// afterwards.
30 reinterpret_cast<std::function<void(cudaStream_t, cudaError)>*
>(userdata);
31 (*callback)(stream, error);
36 class ComputeGraphVarBase;
39 class ComputeGraphVar;
// Members of class LaunchCore; the class head and braces are missing from this
// extract. LaunchCore stores its target CUDA stream in m_stream. It declares
// stream-level operations: event record/wait, host callbacks and blocking
// synchronization helpers.

// Shorthand for std::shared_ptr<T>. T is not declared in this extract; the
// `template <typename T>` head appears to be on a dropped line.
45 using S = std::shared_ptr<T>;
// Returns the stream this object targets (m_stream).
// MUDA_GENERIC: presumably usable from both host and device code.
46 MUDA_GENERIC ::cudaStream_t stream()
const {
return m_stream; }
// The CUDA stream all stream-level operations are issued on.
48 ::cudaStream_t m_stream;
// Host-only. NOTE(review): presumably undoes a prior kernel_name(name);
// the definition is not visible here.
49 MUDA_HOST
void pop_kernel_name();
// Static setter/getter for the current kernel name. Being static, they cannot
// touch per-instance state. NOTE(review): the backing storage (global or
// thread-local) is not visible; confirm in details/launch_base.inl.
52 static void kernel_name(std::string_view name);
53 static std::string_view kernel_name();
// Constructs a core bound to `stream`; declared MUDA_NOEXCEPT.
55 MUDA_GENERIC
LaunchCore(::cudaStream_t stream) MUDA_NOEXCEPT;
// Rebinds the target stream by assigning `s` to m_stream.
57 void init_stream(::cudaStream_t s) { m_stream = s; }
// NOTE(review): presumably opens an NVTX range named `name` (nvToolsExt is
// included above). Confirm that the range is closed elsewhere.
59 void push_range(
const std::string& name);
// Records event `e`. `flag` defaults to cudaEventRecordDefault.
// NOTE(review): presumably forwards to cudaEventRecordWithFlags on m_stream.
62 void record(cudaEvent_t e,
int flag = cudaEventRecordDefault);
// Tail of an overload taking a list of ComputeGraphVarBase*; its head (source
// line 63) is missing from this extract.
64 const std::vector<ComputeGraphVarBase*>& vars);
// Head of a variadic-template overload over ViewT...; its declaration is
// missing from this extract.
65 template <
typename... ViewT>
// when()/wait(): each takes an event and a flag defaulting to
// cudaEventWaitDefault. NOTE(review): presumably makes m_stream wait on `e`
// (cudaStreamWaitEvent). How when() differs from wait() is not visible here.
67 void when(cudaEvent_t e,
int flag = cudaEventWaitDefault);
69 void wait(cudaEvent_t e,
int flag = cudaEventWaitDefault);
// Tail of a wait-style overload taking ComputeGraphVarBase*; its head is
// missing from this extract.
71 const std::vector<ComputeGraphVarBase*>& vars);
// Head of a variadic-template overload over ViewT...; its declaration is
// missing from this extract.
72 template <
typename... ViewT>
// Registers a host callback that receives (stream, error).
// NOTE(review): presumably enqueued via stream_error_callback, whose signature
// matches cudaStreamCallback_t. The std::function is taken by const reference,
// so a heap copy must be kept alive until the callback runs; confirm
// ownership in the .inl.
75 void callback(
const std::function<
void(::cudaStream_t, ::cudaError)>& callback);
// Static blocking helpers. NOTE(review): presumably wrap
// cudaEventSynchronize / cudaStreamSynchronize / cudaDeviceSynchronize
// respectively; confirm.
77 static void wait_event(cudaEvent_t event);
78 static void wait_stream(::cudaStream_t stream);
79 static void wait_device();
// Members of class template LaunchBase<T>; the class head and braces are
// missing from this extract. T is the concrete launcher type: derived() casts
// `this` to T&, and the operation wrappers return T& so calls can be chained.

// Head of a member template over `Others`; its declaration (source lines
// 88-91) is missing from this extract.
87 template <
typename Others>
// Exposes the concrete launcher type.
92 using derived_type = T;
// Constructs a launcher bound to `stream`; declared MUDA_NOEXCEPT.
93 MUDA_GENERIC
LaunchBase(::cudaStream_t stream) MUDA_NOEXCEPT;
// Chaining variant of push_range: returns T&.
// NOTE(review): presumably forwards to the base class's push_range.
102 T& push_range(
const std::string& name);
// Sets the kernel name; returns T& for chaining.
108 T& kernel_name(std::string_view name);
// Returns Base::kernel_name(). The `Base` alias is declared on a dropped line,
// presumably naming LaunchCore.
109 std::string_view kernel_name()
const {
return Base::kernel_name(); }
// Chaining variant of record(event, flag); flag defaults to
// cudaEventRecordDefault.
117 T& record(cudaEvent_t e,
int flag = cudaEventRecordDefault);
// Tail of a record-style overload taking ComputeGraphVarBase*; its head is
// missing from this extract.
120 const std::vector<ComputeGraphVarBase*>& vars);
// Head of a variadic-template overload over ViewT...; its declaration is
// missing from this extract.
122 template <
typename... ViewT>
// Chaining variants of when()/wait(); flag defaults to cudaEventWaitDefault.
136 T& when(cudaEvent_t e,
int flag = cudaEventWaitDefault);
138 T& wait(cudaEvent_t e,
int flag = cudaEventWaitDefault);
// Tail of a wait-style overload taking ComputeGraphVarBase*; its head is
// missing from this extract.
140 const std::vector<ComputeGraphVarBase*>& vars);
// Head of a variadic-template overload over ViewT...; its declaration is
// missing from this extract.
141 template <
typename... ViewT>
// Chaining variant of callback(); the callback receives (stream, error).
150 T& callback(
const std::function<
void(::cudaStream_t, ::cudaError)>& callback);
// Head of a `next` overload over Next; its declaration (source line 153) is
// missing from this extract.
152 template <
typename Next>
// Returns a Next built from forwarded `args`.
// NOTE(review): presumably hands off to another launcher; confirm in the .inl.
154 template <
typename Next,
typename... Args>
155 Next next(Args&&... args);
// Chaining variant of pop_kernel_name(); returns T&.
160 T& pop_kernel_name();
// Downcasts `this` to the concrete launcher type (CRTP-style, assuming T
// derives from LaunchBase<T>).
// NOTE(review): this is a C-style cast. static_cast<T*>(this) would reject an
// unrelated T at compile time instead of silently reinterpreting. However, a
// C-style cast also bypasses access checks, so confirm every launcher
// inherits publicly before switching.
163 T& derived() MUDA_NOEXCEPT {
return *(T*)(
this); }
// Members of class Empty; the class head and braces are missing from this
// extract. The constructor takes a stream that defaults to nullptr. Its
// initializer list and body (source lines 170-174) are missing here.
169 Empty(::cudaStream_t stream =
nullptr)
// Returns an Empty taking `stream`.
// NOTE(review): presumably used to start a chain on another stream; confirm.
175 Empty on(::cudaStream_t stream);
// Wait helpers. Unlike LaunchCore's wait_stream/wait_event, they are declared
// without `static` in this extract.
// NOTE(review): presumably block the host on the given stream/event; confirm.
180 void wait_stream(::cudaStream_t stream);
181 void wait_event(cudaEvent_t event);
184 #include "details/launch_base.inl" Definition: launch_base.h:41
Definition: launch_base.h:166
Definition: buffer_launch.h:34
Definition: launch_base.h:85