5 #include <muda/launch/stream.h> 6 #include <muda/launch/event.h> 7 #include <muda/mstl/span.h> 8 #include <muda/graph/graph.h> 9 #include <muda/graph/graph_viewer.h> 10 #include <muda/compute_graph/compute_graph_flag.h> 11 #include <muda/compute_graph/compute_graph_phase.h> 12 #include <muda/compute_graph/compute_graph_node_type.h> 13 #include <muda/compute_graph/compute_graph_node_id.h> 14 #include <muda/compute_graph/compute_graph_closure_id.h> 15 #include <muda/compute_graph/compute_graph_var_id.h> 16 #include <muda/compute_graph/compute_graph_var_usage.h> 17 #include <muda/compute_graph/compute_graph_dependency.h> 18 #include <muda/compute_graph/graphviz_options.h> 19 #include <muda/compute_graph/compute_graph_fwd.h> 27 using U64IdWithType::U64IdWithType;
43 std::string m_node_name;
76 using U = std::unique_ptr<T>;
78 using S = std::shared_ptr<T>;
83 S<GraphExec> m_graph_exec{
nullptr};
85 std::unordered_map<NodeId::value_type, cudaGraph_t> m_sub_graphs;
87 std::vector<std::pair<std::string, ComputeGraphClosure*>> m_closures;
89 std::map<VarId, details::LocalVarId> m_global_to_local_var_id;
90 std::vector<details::LocalVarInfo> m_related_vars;
94 std::vector<ComputeGraphNodeBase*> m_nodes;
95 std::vector<std::vector<ComputeGraphNodeBase*>> m_graph_nodes;
96 std::vector<Dependency> m_deps;
98 std::vector<int> m_closure_need_update;
109 std::string_view name =
"graph",
110 ComputeGraphFlag flag = ComputeGraphFlag::HostLaunch);
120 std::string_view name()
const {
return m_name; }
141 void launch(
bool single_stream, cudaStream_t s =
nullptr);
143 void launch(cudaStream_t s =
nullptr) {
return launch(
false, s); }
159 void capture(std::function<
void(cudaStream_t)>&& f);
160 void capture(std::string_view name, std::function<
void(cudaStream_t)>&& f);
183 void cuda_graph_add_deps();
187 void serial_launch();
191 void check_vars_valid();
194 ComputeGraph& add_node(std::string&& name,
const std::function<
void()>& f);
198 span<const Dependency> dep_span(
size_t begin,
size_t count)
const;
200 void set_current_graph_as_this();
202 static void clear_current_graph();
204 static Stream& shared_capture_stream();
207 ClosureId current_closure_id()
const {
return m_current_closure_id; };
209 NodeId current_node_id()
const {
return m_current_node_id; };
211 size_t current_access_index()
const {
return m_access_graph_index; }
213 ComputeGraphPhase current_graph_phase()
const;
218 bool m_need_update =
false;
221 ComputeGraphPhase m_current_graph_phase = ComputeGraphPhase::None;
222 bool m_allow_access_graph =
false;
223 size_t m_access_graph_index = 0;
224 bool m_allow_node_adding =
true;
226 cudaStream_t m_current_single_stream =
nullptr;
227 bool m_is_capturing =
false;
229 bool m_is_in_capture_func =
false;
231 bool m_is_topo_built =
false;
235 #include "details/compute_graph.inl" Definition: compute_graph_node_id.h:5
Definition: graph_viewer.h:7
Definition: id_with_type.h:9
Definition: compute_graph.h:29
Definition: compute_graph_var.h:16
Definition: compute_graph_dependency.h:5
Definition: compute_graph_var_manager.h:14
Definition: compute_graph.h:40
Definition: compute_graph_closure_id.h:5
Definition: compute_graph_node.h:12
Definition: graphviz_options.h:5
RAII wrapper for cudaEvent
Definition: event.h:14
Definition: compute_graph_accessor.h:12
RAII wrapper for cudaStream
Definition: stream.h:17
Definition: compute_graph.h:37
Definition: compute_graph_closure.h:14
The event has been recorded.
QueryResult
Definition: event.h:27
Definition: compute_graph_builder.h:9
Definition: compute_graph.h:25
Definition: compute_graph.h:52