muda
compute_graph_accessor.h
1 #pragma once
2 #include <cuda_runtime.h>
3 #include <muda/compute_graph/compute_graph_fwd.h>
4 #include <muda/graph/kernel_node.h>
5 #include <muda/graph/memory_node.h>
6 #include <muda/graph/event_node.h>
7 namespace muda
8 {
9 namespace details
10 {
11  // allow developers to access some internal functions
13  {
14  friend class ComputeGraph;
15  ComputeGraph& m_cg;
16  template <typename T>
17  using S = std::shared_ptr<T>;
18 
19  public:
21 
24 
25  /************************************************************************************
26  *
27  * Graph Add/Update node API
28  *
29  * Automatically add or update a graph node from its parms (distinguished by ComputeGraphPhase)
30  *
31  *************************************************************************************/
32  template <typename T>
33  void set_kernel_node(const S<KernelNodeParms<T>>& kernelParms);
34  void set_memcpy_node(void* dst, const void* src, size_t size_bytes, cudaMemcpyKind kind);
35  void set_memcpy_node(const cudaMemcpy3DParms& parms);
36  void set_memset_node(const cudaMemsetParams& parms);
37  void set_event_record_node(cudaEvent_t event);
38  void set_event_wait_node(cudaEvent_t event);
39  void set_capture_node(cudaGraph_t sub_graph);
40 
41  /************************************************************************************
42  *
43  * Current State Query API
44  *
45  *************************************************************************************/
46  auto current_closure() const
47  -> const std::pair<std::string, ComputeGraphClosure*>&;
48  auto current_closure() -> std::pair<std::string, ComputeGraphClosure*>&;
49  template <typename T>
50  T* current_node();
51  const ComputeGraphNodeBase* current_node() const;
52  ComputeGraphNodeBase* current_node();
53  cudaStream_t current_stream() const;
54  cudaStream_t capture_stream() const;
55 
56  bool is_topo_built() const;
57 
58  /************************************************************************************
59  *
60  * Current State Check API
61  *
62  *************************************************************************************/
63  void check_allow_var_eval() const;
64  void check_allow_node_adding() const;
65 
66  private:
67  friend class muda::ComputeGraphVarBase;
68  void set_var_usage(VarId id, ComputeGraphVarUsage usage);
69 
70  template <typename T>
71  void add_kernel_node(const S<KernelNodeParms<T>>& kernelParms);
72  template <typename T>
73  void update_kernel_node(const S<KernelNodeParms<T>>& kernelParms);
74 
75  void add_memcpy_node(void* dst, const void* src, size_t size_bytes, cudaMemcpyKind kind);
76  void update_memcpy_node(void* dst, const void* src, size_t size_bytes, cudaMemcpyKind kind);
77  void add_memcpy_node(const cudaMemcpy3DParms& parms);
78  void update_memcpy_node(const cudaMemcpy3DParms& parms);
79 
80  void add_memset_node(const cudaMemsetParams& parms);
81  void update_memset_node(const cudaMemsetParams& parms);
82 
83  void add_event_record_node(cudaEvent_t event);
84  void update_event_record_node(cudaEvent_t event);
85 
86  void add_event_wait_node(cudaEvent_t event);
87  void update_event_wait_node(cudaEvent_t event);
88 
89  void add_capture_node(cudaGraph_t sub_graph);
90  void update_capture_node(cudaGraph_t sub_graph);
91 
92  template <typename F>
93  void access_graph(F&& f);
94 
95  template <typename F>
96  void access_graph_exec(F&& f);
97 
98  //auto&& temp_var_usage()
99  //{
100  // return std::move(m_cg.m_temp_node_info.var_usage);
101  //}
102 
103  template <typename NodeType, typename F>
104  NodeType* get_or_create_node(F&& f);
105  };
106 } // namespace details
107 } // namespace muda
108 
109 #include "details/compute_graph_accessor.inl"
Definition: compute_graph_var.h:16
Definition: assert.h:13
Definition: compute_graph_node.h:12
Definition: compute_graph_accessor.h:12
Definition: kernel_node.h:14
Definition: compute_graph.h:37
Definition: compute_graph_var_id.h:5