muda
compute_graph.h
1 #pragma once
2 #include <map>
3 #include <functional>
4 #include <set>
5 #include <muda/launch/stream.h>
6 #include <muda/launch/event.h>
7 #include <muda/mstl/span.h>
8 #include <muda/graph/graph.h>
9 #include <muda/graph/graph_viewer.h>
10 #include <muda/compute_graph/compute_graph_flag.h>
11 #include <muda/compute_graph/compute_graph_phase.h>
12 #include <muda/compute_graph/compute_graph_node_type.h>
13 #include <muda/compute_graph/compute_graph_node_id.h>
14 #include <muda/compute_graph/compute_graph_closure_id.h>
15 #include <muda/compute_graph/compute_graph_var_id.h>
16 #include <muda/compute_graph/compute_graph_var_usage.h>
17 #include <muda/compute_graph/compute_graph_dependency.h>
18 #include <muda/compute_graph/graphviz_options.h>
19 #include <muda/compute_graph/compute_graph_fwd.h>
20 
21 namespace muda
22 {
// Implementation-detail types used by ComputeGraph's private var bookkeeping
// (ComputeGraph::m_global_to_local_var_id and ComputeGraph::m_related_vars).
23 namespace details
24 {
// Graph-local id for a var referenced by a ComputeGraph; the graph maps each global
// VarId to one of these. Inherits every constructor of U64IdWithType (presumably a
// strongly typed 64-bit id wrapper, per its name — confirm in id_with_type.h).
25  class LocalVarId : public U64IdWithType
26  {
27  using U64IdWithType::U64IdWithType;
28  };
// NOTE(review): listing line 29 (the class head of the next type) is missing from this
// scrape; from its use as the element type of ComputeGraph::m_related_vars this is
// presumably `class LocalVarInfo`.
// Record pairing a graph-local id with the var it refers to. `id` is value-initialized
// and `var` defaults to nullptr; `var` is a raw pointer and who owns the pointee is not
// shown here — TODO confirm.
30  {
31  public:
32  LocalVarId id{};
33  ComputeGraphVarBase* var = nullptr;
34  };
35 } // namespace details
36 
// ComputeGraph: collects named nodes (create_node) and stream-capture closures (capture),
// keeps the dependencies between them (m_deps), and is backed by a muda Graph (m_graph)
// plus a shared GraphExec (m_graph_exec) for build()/launch(). Non-copyable, non-movable.
// NOTE(review): the class-head line (listing line 37) is missing from this scrape; the
// deleted special members below identify this body as `class ComputeGraph`.
// NOTE(review): this header uses std::unordered_map, std::vector, std::string,
// std::string_view, std::unique_ptr / std::shared_ptr and std::ostream, but its visible
// include list names only <map>, <functional> and <set> from the standard library. It
// relies on transitive includes; consider including <unordered_map>, <vector>, <string>,
// <string_view>, <memory> and <iosfwd> directly.
38 {
39  public:
// NOTE(review): listing line 40 (the nested class head) is missing; per the constructor
// declared below, this body is `AddNodeProxy`.
// Temporary returned by create_node(). It stores a reference to the graph and the node
// name. operator<< is rvalue-qualified (&&), so it only applies to the temporary itself,
// e.g. `cg.create_node("n") << [&] { ... };`, and it returns the graph for chaining.
// Presumably it forwards to the private ComputeGraph::add_node (AddNodeProxy is a friend).
// Confirm this in the implementation. Because it holds ComputeGraph&, it must not outlive
// the graph.
41  {
42  ComputeGraph& m_cg;
43  std::string m_node_name;
44 
45  public:
46  AddNodeProxy(ComputeGraph& cg, std::string_view node_name);
47  ComputeGraph& operator<<(std::function<void()>&& f) &&;
48  };
// Dependency edge-direction convention: if node A depends on node B, the edge goes B -> A.
// NOTE(review): listing line 50, which this comment presumably annotated, is missing here.
49  // A depends on B : from B to A
51 
// NOTE(review): listing line 52 (the nested class head) is missing; per the constructor
// and destructor declared below, this body is `GraphPhaseGuard`.
// Scope guard bound to a ComputeGraph and a ComputeGraphPhase. It presumably sets the
// graph's current phase (m_current_graph_phase) on construction and restores it in the
// destructor. Confirm this in the implementation. Because it holds ComputeGraph&, it must
// not outlive the graph.
53  {
54  ComputeGraph& m_cg;
55 
56  public:
57  GraphPhaseGuard(ComputeGraph& cg, ComputeGraphPhase phase);
58  ~GraphPhaseGuard();
59  };
60 
// Both nested helpers above keep a ComputeGraph& (m_cg), so the graph's address has to
// stay stable. That is presumably why copy and move are both disabled.
61  // delete copy
62  ComputeGraph(const ComputeGraph&) = delete;
63  ComputeGraph& operator=(const ComputeGraph&) = delete;
64 
65  // delete move
66  ComputeGraph(ComputeGraph&&) = delete;
67  ComputeGraph& operator=(ComputeGraph&&) = delete;
68 
69  private:
// Disabled draft of a per-node var-usage record, kept commented out.
70  //class TempNodeInfo
71  //{
72  // public:
73  // std::map<VarId, ComputeGraphVarUsage> var_usage;
74  //};
// Shorthand aliases for owning smart pointers. S<> backs m_graph_exec; U<> is not used by
// any declaration visible in this listing.
75  template <typename T>
76  using U = std::unique_ptr<T>;
77  template <typename T>
78  using S = std::shared_ptr<T>;
79 
80  friend class ComputeGraphVarBase;
81 
// Underlying muda Graph and its executable instance; m_graph_exec starts as nullptr.
82  Graph m_graph;
83  S<GraphExec> m_graph_exec{nullptr};
84 
// Raw cudaGraph_t handles keyed by NodeId value (presumably child graphs produced by
// capture nodes). Their release is not visible in this header. Confirm that ~ComputeGraph
// destroys them.
85  std::unordered_map<NodeId::value_type, cudaGraph_t> m_sub_graphs;
86 
// (name, closure) pairs. Closures are held by raw pointer, and ownership is not
// established in this header.
87  std::vector<std::pair<std::string, ComputeGraphClosure*>> m_closures;
88 
// Var bookkeeping: maps each global VarId seen by this graph to a graph-local id, while
// m_related_vars holds the (local id, var) records. emplace_related_var() presumably
// registers a var in both. Confirm this in the implementation.
89  std::map<VarId, details::LocalVarId> m_global_to_local_var_id;
90  std::vector<details::LocalVarInfo> m_related_vars;
91  void emplace_related_var(ComputeGraphVarBase* var);
92 
93 
// Nodes (raw pointers; ownership not shown here), nodes grouped into nested lists
// (presumably one list per closure — confirm), and the dependency edges between nodes.
94  std::vector<ComputeGraphNodeBase*> m_nodes;
95  std::vector<std::vector<ComputeGraphNodeBase*>> m_graph_nodes;
96  std::vector<Dependency> m_deps;
97 
// Update flags, presumably one per closure (per the name). They are stored as int rather
// than bool, possibly to avoid std::vector<bool>'s proxy references. Confirm both points.
98  std::vector<int> m_closure_need_update;
// Associated var manager: a raw pointer, nullptr by default.
99  ComputeGraphVarManager* m_var_manager = nullptr;
100 
101  friend class ComputeGraphVarManager;
102 
// Event presumably recorded at launch and polled by query() (confirm). m_event_result is
// mutable, presumably so the const query() can cache the last observed status. It starts
// as Event::QueryResult::eFinished.
103  Event m_event;
104  mutable Event::QueryResult m_event_result = Event::QueryResult::eFinished;
// NOTE(review): listing line 105 is missing from this scrape.
106 
107  public:
// NOTE(review): listing line 108 is missing from this scrape. It is the start of the
// constructor declaration, including any leading parameter(s). The visible trailing
// parameters are:
//   name : graph name, presumably stored in m_name and returned by name(); defaults to "graph".
//   flag : ComputeGraphFlag; defaults to ComputeGraphFlag::HostLaunch.
109  std::string_view name = "graph",
110  ComputeGraphFlag flag = ComputeGraphFlag::HostLaunch);
111 
112  ~ComputeGraph();
113 
114  /**************************************************************
115  *
116  * Info API
117  *
118  ***************************************************************/
119 
// Returns a view of m_name. The view is only valid while this graph is alive.
120  std::string_view name() const { return m_name; }
121 
122  /**************************************************************
123  *
124  * GraphNode API
125  *
126  ***************************************************************/
127 
// Begins adding a node named node_name. Supply the node body with `<<` on the returned
// AddNodeProxy, which is usable only as a temporary.
128  AddNodeProxy create_node(std::string_view node_name);
129 
130 
131  /**************************************************************
132  *
133  * Graph Launch API
134  *
135  ***************************************************************/
136 
137  void update();
138 
139  void build();
140 
// s defaults to nullptr, i.e. the CUDA default stream. single_stream = true presumably
// selects the serial_launch() path. Confirm this in the implementation.
141  void launch(bool single_stream, cudaStream_t s = nullptr);
142 
// Shorthand for launch(/*single_stream=*/false, s).
143  void launch(cudaStream_t s = nullptr) { return launch(false, s); }
144 
145  /**************************************************************
146  *
147  * Graph Event Query API
148  *
149  ***************************************************************/
150 
// Reports the completion status tracked via m_event / m_event_result.
151  Event::QueryResult query() const;
152 
153  /**************************************************************
154  *
155  * Graph Closure Capture Node API
156  *
157  ***************************************************************/
158 
// Adds a closure whose body f receives a cudaStream_t, presumably the stream being
// captured; the second overload also names the closure. Var eval() is not allowed inside
// f (see m_is_in_capture_func).
159  void capture(std::function<void(cudaStream_t)>&& f);
160  void capture(std::string_view name, std::function<void(cudaStream_t)>&& f);
161 
162  /**************************************************************
163  *
164  * Graph Visualization API
165  *
166  ***************************************************************/
167 
// Writes a Graphviz description of this graph (presumably DOT text) to o, formatted
// according to options.
168  void graphviz(std::ostream& o, const ComputeGraphGraphvizOptions& options = {});
169 
170  /**************************************************************
171  *
172  * Graph Viewer API
173  *
174  ***************************************************************/
175 
176  GraphViewer viewer();
177 
// NOTE(review): this conversion is not `explicit`, so a ComputeGraph converts implicitly
// (via viewer()) wherever a GraphViewer is expected.
178  operator GraphViewer() { return viewer(); }
179 
// Build/launch internals; their behaviour lives in the implementation files.
180  private: // internal method
181  void topo_build();
182 
183  void cuda_graph_add_deps();
184 
185  void build_deps();
186 
187  void serial_launch();
188 
189  void _update();
190 
191  void check_vars_valid();
192 
// Presumably the target of AddNodeProxy::operator<<, hence the friend declaration.
193  friend class AddNodeProxy;
194  ComputeGraph& add_node(std::string&& name, const std::function<void()>& f);
195 
// Returns `count` dependencies starting at index `begin` (presumably a slice of m_deps).
196  friend class ComputeGraphNodeBase;
197  friend class ComputeGraphClosure;
198  span<const Dependency> dep_span(size_t begin, size_t count) const;
199 
// Sets / clears a "current graph". clear_current_graph() is static, so that state
// presumably lives in static or thread-local storage. Confirm this in the implementation.
200  void set_current_graph_as_this();
201 
202  static void clear_current_graph();
203 
// A single Stream shared through a static accessor, presumably used for stream capture.
204  static Stream& shared_capture_stream();
205 
// Read-only accessors over the builder-facing state below. The trailing ';' after two of
// the inline bodies is an empty member-declaration and is harmless.
206  friend class ComputeGraphBuilder;
207  ClosureId current_closure_id() const { return m_current_closure_id; };
208 
209  NodeId current_node_id() const { return m_current_node_id; };
210 
211  size_t current_access_index() const { return m_access_graph_index; }
212 
213  ComputeGraphPhase current_graph_phase() const;
214 
// NOTE(review): listing line 216 is missing from this scrape.
215  private: // internal data
217  std::string m_name;
218  bool m_need_update = false;
219  ClosureId m_current_closure_id;
220  NodeId m_current_node_id;
221  ComputeGraphPhase m_current_graph_phase = ComputeGraphPhase::None;
222  bool m_allow_access_graph = false;
223  size_t m_access_graph_index = 0;
224  bool m_allow_node_adding = true;
225  // TempNodeInfo m_temp_node_info;
226  cudaStream_t m_current_single_stream = nullptr;
227  bool m_is_capturing = false;
228  // in capture func, we don't allow any var eval()
229  bool m_is_in_capture_func = false;
230  // if we have already built the topo, we don't do that again
231  bool m_is_topo_built = false;
232 };
233 } // namespace muda
234 
235 #include "details/compute_graph.inl"
Definition: compute_graph_node_id.h:5
Definition: graph_viewer.h:7
Definition: id_with_type.h:9
Definition: compute_graph.h:29
Definition: compute_graph_var.h:16
Definition: compute_graph_dependency.h:5
Definition: compute_graph_var_manager.h:14
Definition: compute_graph.h:40
Definition: compute_graph_closure_id.h:5
Definition: assert.h:13
Definition: compute_graph_node.h:12
Definition: graphviz_options.h:5
RAII wrapper for cudaEvent
Definition: event.h:14
Definition: graph.h:17
Definition: compute_graph_accessor.h:12
RAII wrapper for cudaStream
Definition: stream.h:17
Definition: compute_graph.h:37
Definition: compute_graph_closure.h:14
The event has been recorded.
QueryResult
Definition: event.h:27
Definition: compute_graph_builder.h:9
Definition: compute_graph.h:25
Definition: compute_graph.h:52