muda
stream.h
#pragma once
#include <cstddef>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <device_launch_parameters.h>

#include <muda/check/check_cuda_errors.h>
#include <muda/tools/temp_buffer.h>
8 
namespace muda
{
template <typename T>
class DeviceBuffer;

/// \brief RAII wrapper for cudaStream
///
/// Holds a `cudaStream_t` handle. Copying is disabled; moving is allowed.
/// The handle converts implicitly to `cudaStream_t`, so a Stream can be passed
/// directly where the CUDA runtime expects a stream.
class Stream
{
    cudaStream_t m_handle = nullptr;

  public:
    /// Stream creation flags; each enumerator carries the matching CUDA
    /// runtime stream flag value.
    enum class Flag : unsigned int
    {
        eDefault     = cudaStreamDefault,
        eNonBlocking = cudaStreamNonBlocking
    };

    /// Creates a stream using flag `f`.
    /// NOTE(review): defined in details/stream.inl; presumably forwards `f` to
    /// cudaStreamCreateWithFlags — confirm.
    MUDA_NODISCARD Stream(Flag f = Flag::eDefault);

    /// NOTE(review): defined in details/stream.inl; presumably destroys the
    /// held handle when it is non-null — confirm.
    ~Stream();

    /// Returns the raw handle; the Stream keeps holding it.
    operator cudaStream_t() const { return m_handle; }
    /// Returns the raw handle; the Stream keeps holding it.
    cudaStream_t view() const { return m_handle; }

    // Copying is disabled so two Stream objects never hold the same handle.
    Stream(const Stream&)            = delete;
    Stream& operator=(const Stream&) = delete;

    // Moving is allowed (definitions in details/stream.inl).
    Stream(Stream&& o) MUDA_NOEXCEPT;
    Stream& operator=(Stream&& o) MUDA_NOEXCEPT;

    /// NOTE(review): defined in details/stream.inl; by name presumably blocks
    /// the host until work queued on this stream has finished — confirm.
    void wait() const;

    /// Stream-capture helpers (definitions in details/stream.inl).
    /// NOTE(review): presumably wrap cudaStreamBeginCapture /
    /// cudaStreamEndCapture; `graph` receives the captured graph — confirm.
    void begin_capture(cudaStreamCaptureMode mode = cudaStreamCaptureModeThreadLocal) const;
    void end_capture(cudaGraph_t* graph) const;

    /// Returns a reference to a shared Stream instance.
    /// NOTE(review): presumably built with the private null-handle constructor
    /// below, i.e. it wraps the default stream — confirm in details/stream.inl.
    static Stream& Default();

    // The four tag types below are device-side objects convertible to a
    // cudaStream_t; their conversions are defined in details/stream.inl.
    // NOTE(review): by name they presumably map to the CUDA device-graph-launch
    // streams (cudaStreamTailLaunch, cudaStreamFireAndForget,
    // cudaStreamGraphTailLaunch, cudaStreamGraphFireAndForget) — confirm.

    class TailLaunch
    {
      public:
        MUDA_DEVICE TailLaunch() {}
        MUDA_DEVICE operator cudaStream_t() const;
    };

    class FireAndForget
    {
      public:
        MUDA_DEVICE FireAndForget() {}
        MUDA_DEVICE operator cudaStream_t() const;
    };

    class GraphTailLaunch
    {
      public:
        MUDA_DEVICE GraphTailLaunch() {}
        MUDA_DEVICE operator cudaStream_t() const;
    };

    class GraphFireAndForget
    {
      public:
        MUDA_DEVICE GraphFireAndForget() {}
        MUDA_DEVICE operator cudaStream_t() const;
    };

    /// Returns a pointer into this stream's scratch buffer (`m_workspace`),
    /// requested at `byte_size` bytes.
    /// NOTE(review): growth and pointer-lifetime rules are defined in
    /// details/stream.inl — confirm whether a later, larger request
    /// invalidates previously returned pointers.
    std::byte* workspace(size_t byte_size);

  private:
    // Builds a Stream around a null handle without creating a new stream.
    Stream(std::nullptr_t)
        : m_handle(nullptr)
    {
    }

    // Backing storage for workspace().
    details::ByteTempBuffer m_workspace;
};


}  // namespace muda
89 
90 #include "details/stream.inl"
Definition: temp_buffer.h:7
Definition: stream.h:56
Definition: stream.h:70
Definition: stream.h:49
Definition: assert.h:13
Definition: stream.h:63
RAII wrapper for cudaStream
Definition: stream.h:17
Definition: kernel_tag.h:9