muda
graph_exec.h
1 #pragma once
2 #include <muda/graph/graph_base.h>
3 #include <muda/graph/kernel_node.h>
4 #include <muda/graph/memory_node.h>
5 #include <muda/graph/event_node.h>
6 #include <muda/graph/graph_viewer.h>
7 
8 namespace muda
9 {
10 class GraphExec
11 {
12  template <typename T>
13  using S = std::shared_ptr<T>;
14  template <typename T>
15  using U = std::unique_ptr<T>;
16  cudaGraphExec_t m_handle;
18 
19  public:
20  friend class Graph;
21 
22  GraphExec();
23 
24  // delete copy
25  GraphExec(const GraphExec&) = delete;
26  GraphExec& operator=(const GraphExec&) = delete;
27 
28  // move
29  GraphExec(GraphExec&& other);
30  GraphExec& operator=(GraphExec&& other);
31 
32  void upload(cudaStream_t stream = nullptr);
33 
34  void launch(cudaStream_t stream = nullptr);
35 
36  template <typename T>
37  void set_kernel_node_parms(S<KernelNode> node, const S<KernelNodeParms<T>>& new_parms);
38 
39 
40  void set_memcpy_node_parms(S<MemcpyNode> node,
41  void* dst,
42  const void* src,
43  size_t size_bytes,
44  cudaMemcpyKind kind);
45  void set_memcpy_node_parms(S<MemcpyNode> node, const cudaMemcpy3DParms& parms);
46  void set_memset_node_parms(S<MemsetNode> node, const cudaMemsetParams& parms);
47 
48 
49  void set_event_record_node_parms(S<EventRecordNode> node, cudaEvent_t event);
50  void set_event_wait_node_parms(S<EventWaitNode> node, cudaEvent_t event);
51 
52  ~GraphExec();
53 
54  cudaGraphExec_t handle() const { return m_handle; }
55 
56  GraphViewer viewer() const;
57  private:
58  // keep the ref count > 0 for those whose data should be kept alive for the graph life.
59  std::list<S<NodeParms>> m_cached;
60 };
61 } // namespace muda
62 
63 #include "details/graph_exec.inl"
Definition: graph_viewer.h:7
Definition: assert.h:13
Definition: graph.h:17
Definition: graph_exec.h:10
Definition: kernel_node.h:14