muda
launch_base.h
1 #pragma once
2 #include <cuda.h>
3 #include <cuda_runtime.h>
4 #include <cuda_runtime_api.h>
5 #include <device_launch_parameters.h>
6 
7 #include <string>
8 #include <functional>
9 #include <memory>
10 #include <cooperative_groups.h>
11 
12 #include <cuda_profiler_api.h>
13 #include <nvtx3/nvToolsExt.h>
14 #include <nvtx3/nvToolsExtCuda.h>
15 #include <muda/type_traits/type_modifier.h>
16 #include <muda/tools/launch_info_cache.h>
17 
18 #include <muda/check/check_cuda_errors.h>
19 #include <muda/muda_def.h>
20 #include <muda/launch/event.h>
21 #include <muda/launch/kernel_tag.h>
22 
23 namespace muda
24 {
25 namespace details
26 {
27  inline void stream_error_callback(cudaStream_t stream, cudaError error, void* userdata)
28  {
29  auto callback =
30  reinterpret_cast<std::function<void(cudaStream_t, cudaError)>*>(userdata);
31  (*callback)(stream, error);
32  delete callback;
33  }
34 } // namespace details
35 
// Forward declarations of the compute-graph variable types. This header only
// names them through references and pointers in member signatures, so their
// full definitions are not required here.
template <typename T>
class ComputeGraphVar;

class ComputeGraphVarBase;
40 
42 {
43  protected:
44  template <typename T>
45  using S = std::shared_ptr<T>;
46  MUDA_GENERIC ::cudaStream_t stream() const { return m_stream; }
47 
48  ::cudaStream_t m_stream;
49  MUDA_HOST void pop_kernel_name();
50 
51  public:
52  static void kernel_name(std::string_view name);
53  static std::string_view kernel_name();
54 
55  MUDA_GENERIC LaunchCore(::cudaStream_t stream) MUDA_NOEXCEPT;
56 
57  void init_stream(::cudaStream_t s) { m_stream = s; }
58 
59  void push_range(const std::string& name);
60  void pop_range();
61 
62  void record(cudaEvent_t e, int flag = cudaEventRecordDefault);
63  void record(ComputeGraphVar<cudaEvent_t>& e,
64  const std::vector<ComputeGraphVarBase*>& vars);
65  template <typename... ViewT>
66  void record(ComputeGraphVar<cudaEvent_t>& e, ComputeGraphVar<ViewT>&... vars);
67  void when(cudaEvent_t e, int flag = cudaEventWaitDefault);
68  // let the host wait for the event
69  void wait(cudaEvent_t e, int flag = cudaEventWaitDefault);
70  void wait(const ComputeGraphVar<cudaEvent_t>& e,
71  const std::vector<ComputeGraphVarBase*>& vars);
72  template <typename... ViewT>
73  void wait(const ComputeGraphVar<cudaEvent_t>& e, ComputeGraphVar<ViewT>&... vars);
74  void wait();
75  void callback(const std::function<void(::cudaStream_t, ::cudaError)>& callback);
76 
77  static void wait_event(cudaEvent_t event);
78  static void wait_stream(::cudaStream_t stream);
79  static void wait_device();
80 
81  ~LaunchCore() MUDA_NOEXCEPT;
82 };
83 
84 template <typename T>
85 class LaunchBase : public LaunchCore
86 {
87  template <typename Others>
88  friend class LaunchBase;
89  using Base = LaunchCore;
90 
91  public:
92  using derived_type = T;
93  MUDA_GENERIC LaunchBase(::cudaStream_t stream) MUDA_NOEXCEPT;
94 
95  // create a named scope for better recognization (if you are using some profile tools)
96  // usage:
97  // on(stream)
98  // .push_range("part1")
99  // .next<launch>(1,1).apply(...)
100  // .pop_range()
101  // .wait();
102  T& push_range(const std::string& name);
103  T& pop_range();
104 
105 
106  // create a name for the following kernel launch
107  // viewers will record this name for the sake of better recognization when debugging
108  T& kernel_name(std::string_view name);
109  std::string_view kernel_name() const { return Base::kernel_name(); }
110 
111  // record an event on this point with current stream, you could use .when() to
112  // capture this event for synchronization
113  // flags:
114  // cudaEventRecordDefault : Default event creation flag.
115  // cudaEventRecordExternal : Event is captured in the graph as an external
116  // event node when performing stream capture.
117  T& record(cudaEvent_t e, int flag = cudaEventRecordDefault);
118 
119  T& record(ComputeGraphVar<cudaEvent_t>& e,
120  const std::vector<ComputeGraphVarBase*>& vars);
121 
122  template <typename... ViewT>
124 
125  // let the following kernels wait until the event is triggered
126  // (asynchronize with the host)
127  // usage:
128  // on(stream)
129  // .when(event)
130  // .next<launch>(1,1).apply(...)
131  // .wait();
132  // flags:
133  // cudaEventRecordDefault : Default event creation flag.
134  // cudaEventRecordExternal : Event is captured in the graph as an external
135  // event node when performing stream capture.
136  T& when(cudaEvent_t e, int flag = cudaEventWaitDefault);
137  // let the host wait for the event
138  T& wait(cudaEvent_t e, int flag = cudaEventWaitDefault);
139  T& wait(const ComputeGraphVar<cudaEvent_t>& e,
140  const std::vector<ComputeGraphVarBase*>& vars);
141  template <typename... ViewT>
142  T& wait(const ComputeGraphVar<cudaEvent_t>& e, ComputeGraphVar<ViewT>&... vars);
143 
144 
145  // let the host wait for the current stream
146  T& wait();
147 
148  // register a host callback function, which will be called when all the jobs before
149  // this point are done.
150  T& callback(const std::function<void(::cudaStream_t, ::cudaError)>& callback);
151 
152  template <typename Next>
153  Next next(Next n);
154  template <typename Next, typename... Args>
155  Next next(Args&&... args);
156 
157  ~LaunchBase() MUDA_NOEXCEPT;
158 
159  protected:
160  T& pop_kernel_name();
161 
162  private:
163  T& derived() MUDA_NOEXCEPT { return *(T*)(this); }
164 };
165 
166 class Empty : public LaunchBase<Empty>
167 {
168  public:
169  Empty(::cudaStream_t stream = nullptr)
170  : LaunchBase(stream)
171  {
172  }
173 };
174 
175 Empty on(::cudaStream_t stream);
176 
177 Empty on();
178 
179 void wait_device();
180 void wait_stream(::cudaStream_t stream);
181 void wait_event(cudaEvent_t event);
182 } // namespace muda
183 
184 #include "details/launch_base.inl"
Definition: launch_base.h:41
Definition: launch_base.h:166
Definition: buffer_launch.h:34
Definition: assert.h:13
Definition: launch_base.h:85