muda
graph.h
1 #pragma once
2 #include <unordered_map>
3 #include <unordered_set>
4 
5 #include <muda/graph/graph_base.h>
6 #include <muda/graph/graph_exec.h>
7 
8 #include <muda/graph/kernel_node.h>
9 #include <muda/graph/memory_node.h>
10 #include <muda/graph/host_node.h>
11 #include <muda/graph/event_node.h>
12 
13 #include <muda/graph/graph_instantiate_flag.h>
14 
15 namespace muda
16 {
17 class Graph
18 {
19  template <typename T>
20  using S = std::shared_ptr<T>;
21  template <typename T>
22  using U = std::unique_ptr<T>;
23 
24  public:
25  Graph();
26  ~Graph();
27 
28  // delete copy
29  Graph(const Graph&) = delete;
30  Graph& operator=(const Graph&) = delete;
31 
32  // move
33  Graph(Graph&&);
34  Graph& operator=(Graph&&);
35 
36 
37  friend class GraphExec;
38  friend class std::shared_ptr<Graph>;
39 
40  MUDA_NODISCARD S<GraphExec> instantiate();
41  MUDA_NODISCARD S<GraphExec> instantiate(Flags<GraphInstantiateFlagBit> flags);
42 
43  template <typename T>
44  S<KernelNode> add_kernel_node(const S<KernelNodeParms<T>>& kernelParms,
45  const std::vector<S<GraphNode>>& deps);
46  template <typename T>
47  S<KernelNode> add_kernel_node(const S<KernelNodeParms<T>>& kernelParms);
48 
49 
50  template <typename T>
51  S<HostNode> add_host_node(const S<HostNodeParms<T>>& hostParms,
52  const std::vector<S<GraphNode>>& deps);
53  template <typename T>
54  S<HostNode> add_host_node(const S<HostNodeParms<T>>& hostParms);
55 
56 
57  S<MemcpyNode> add_memcpy_node(void* dst,
58  const void* src,
59  size_t size_bytes,
60  cudaMemcpyKind kind,
61  const std::vector<S<GraphNode>>& deps);
62  S<MemcpyNode> add_memcpy_node(void* dst, const void* src, size_t size_bytes, cudaMemcpyKind kind);
63  S<MemcpyNode> add_memcpy_node(const cudaMemcpy3DParms& parms);
64  S<MemcpyNode> add_memcpy_node(const cudaMemcpy3DParms& parms,
65  const std::vector<S<GraphNode>>& deps);
66 
67  S<MemsetNode> add_memset_node(const cudaMemsetParams& parms,
68  const std::vector<S<GraphNode>>& deps);
69  S<MemsetNode> add_memset_node(const cudaMemsetParams& parms);
70 
71 
72  S<EventRecordNode> add_event_record_node(cudaEvent_t e,
73  const std::vector<S<GraphNode>>& deps);
74  S<EventRecordNode> add_event_record_node(cudaEvent_t e);
75  S<EventWaitNode> add_event_wait_node(cudaEvent_t e,
76  const std::vector<S<GraphNode>>& deps);
77  S<EventWaitNode> add_event_wait_node(cudaEvent_t e);
78 
79 
80  void add_dependency(S<GraphNode> from, S<GraphNode> to);
81 
82  cudaGraph_t handle() const { return m_handle; }
83  cudaGraph_t handle() { return m_handle; }
84  static auto create() { return std::make_shared<Graph>(); }
85 
86  private:
87  cudaGraph_t m_handle;
88  // keep the ref count > 0 for those whose data should be kept alive for the graph life.
89  std::list<S<NodeParms>> m_cached;
90  static std::vector<cudaGraphNode_t> map_dependencies(const std::vector<S<GraphNode>>& deps);
91 };
92 } // namespace muda
93 
94 #include "details/graph.inl"
Definition: host_node.h:14
Definition: assert.h:13
Definition: graph.h:17
Definition: graph_exec.h:10
Definition: kernel_node.h:14