2 #include <cuda_runtime.h> 3 #include <muda/compute_graph/compute_graph_fwd.h> 4 #include <muda/graph/kernel_node.h> 5 #include <muda/graph/memory_node.h> 6 #include <muda/graph/event_node.h> 17 using S = std::shared_ptr<T>;
// set_memcpy_node: two overloads, both returning void. Only the declarations
// are visible in this listing.
//   (dst, src, size_bytes, kind) - raw-pointer form: a writable `dst`, a
//       read-only `src`, a size_t named `size_bytes`, and a cudaMemcpyKind.
//   (parms) - takes a cudaMemcpy3DParms descriptor by const reference.
// NOTE(review): presumably these route to the add_memcpy_node /
// update_memcpy_node declarations further down - confirm in the implementation.
// NOTE(review): the leading "34"/"35" are line numbers carried over from the
// rendered source listing, not C++ tokens.
34 void set_memcpy_node(
void* dst,
const void* src,
size_t size_bytes, cudaMemcpyKind kind);
35 void set_memcpy_node(
const cudaMemcpy3DParms& parms);
// set_memset_node: takes a cudaMemsetParams descriptor by const reference and
// returns void. Declaration only.
// NOTE(review): presumably paired with add_memset_node / update_memset_node
// declared further down - confirm in the implementation.
36 void set_memset_node(
const cudaMemsetParams& parms);
// set_event_record_node: takes a cudaEvent_t handle by value and returns void.
// Declaration only. NOTE(review): ownership/lifetime expectations for `event`
// are not visible here - confirm against the implementation.
37 void set_event_record_node(cudaEvent_t event);
// set_event_wait_node: takes a cudaEvent_t handle by value and returns void.
// Declaration only; same signature shape as set_event_record_node.
38 void set_event_wait_node(cudaEvent_t event);
// set_capture_node: takes a cudaGraph_t named `sub_graph` by value and returns
// void. Declaration only.
// NOTE(review): the parameter name suggests the graph is embedded as a child
// graph - confirm in the implementation.
39 void set_capture_node(cudaGraph_t sub_graph);
// current_closure: const / non-const overload pair written with trailing
// return types.
//   const overload     -> const std::pair<std::string, ComputeGraphClosure*>&
//   non-const overload -> std::pair<std::string, ComputeGraphClosure*>&
// Both return a reference, so the pair is owned elsewhere; the referenced
// object's lifetime is not visible here.
// NOTE(review): the pair's string is presumably the closure's name - confirm.
// NOTE(review): "46", "47" and "48" are listing line numbers. The const
// overload spans listing lines 46-47, which is why "47" sits between `const`
// and `->`.
46 auto current_closure()
const 47 ->
const std::pair<std::string, ComputeGraphClosure*>&;
48 auto current_closure() -> std::pair<std::string, ComputeGraphClosure*>&;
// current_stream: const member returning a cudaStream_t by value.
// Declaration only; which stream counts as "current" is decided by the
// implementation, which is not visible here.
53 cudaStream_t current_stream()
const;
// capture_stream: const member returning a cudaStream_t by value.
// NOTE(review): presumably the stream used for CUDA stream capture - confirm
// in the implementation.
54 cudaStream_t capture_stream()
const;
// is_topo_built: const member returning bool.
// NOTE(review): presumably reports whether the graph topology has already been
// built - confirm in the implementation.
56 bool is_topo_built()
const;
// check_allow_var_eval: const member, no arguments, returns void.
// How a failed check is reported (assert, exception, log) is not visible in
// this listing - see the implementation.
63 void check_allow_var_eval()
const;
// check_allow_node_adding: const member, no arguments, returns void.
// How a failed check is reported (assert, exception, log) is not visible in
// this listing - see the implementation.
64 void check_allow_node_adding()
const;
// set_var_usage: takes a VarId and a ComputeGraphVarUsage, both by value, and
// returns void. Both parameter types are declared outside this listing.
// NOTE(review): presumably records how the variable `id` is used by the node
// being built - confirm in the implementation.
68 void set_var_usage(
VarId id, ComputeGraphVarUsage usage);
// add_memcpy_node / update_memcpy_node (raw-pointer form).
// The two declarations have identical parameter lists: a writable `dst`, a
// read-only `src`, a size_t named `size_bytes`, and a cudaMemcpyKind. Both
// return void.
// NOTE(review): naming suggests add_* creates the node and update_* changes
// the parameters of an existing one - confirm in the implementation.
75 void add_memcpy_node(
void* dst,
const void* src,
size_t size_bytes, cudaMemcpyKind kind);
76 void update_memcpy_node(
void* dst,
const void* src,
size_t size_bytes, cudaMemcpyKind kind);
// add_memcpy_node / update_memcpy_node (descriptor form).
// Both take a cudaMemcpy3DParms by const reference and return void; they
// overload the raw-pointer versions declared just above.
// NOTE(review): the add_/update_ split presumably mirrors the raw-pointer
// overloads - confirm in the implementation.
77 void add_memcpy_node(
const cudaMemcpy3DParms& parms);
78 void update_memcpy_node(
const cudaMemcpy3DParms& parms);
// add_memset_node / update_memset_node.
// Both take a cudaMemsetParams by const reference and return void.
// NOTE(review): naming suggests add_* creates the node and update_* changes
// the parameters of an existing one - confirm in the implementation.
80 void add_memset_node(
const cudaMemsetParams& parms);
81 void update_memset_node(
const cudaMemsetParams& parms);
// add_event_record_node / update_event_record_node.
// Both take a cudaEvent_t by value and return void.
// NOTE(review): naming suggests add_* creates the node and update_* re-targets
// an existing node at `event` - confirm in the implementation.
83 void add_event_record_node(cudaEvent_t event);
84 void update_event_record_node(cudaEvent_t event);
// add_event_wait_node / update_event_wait_node.
// Both take a cudaEvent_t by value and return void.
// NOTE(review): naming suggests add_* creates the node and update_* re-targets
// an existing node at `event` - confirm in the implementation.
86 void add_event_wait_node(cudaEvent_t event);
87 void update_event_wait_node(cudaEvent_t event);
// add_capture_node / update_capture_node.
// Both take a cudaGraph_t named `sub_graph` by value and return void.
// NOTE(review): naming suggests add_* creates the node and update_* swaps in a
// new sub-graph for an existing node - confirm in the implementation.
89 void add_capture_node(cudaGraph_t sub_graph);
90 void update_capture_node(cudaGraph_t sub_graph);
// access_graph: takes `f` as F&& and returns void.
// NOTE(review): `F` is not declared in the visible lines. The listing jumps
// from line 90 to 93, so a `template <typename F>` header was presumably
// dropped - confirm against the real header.
// NOTE(review): presumably invokes `f` with the underlying graph object; the
// callable's expected signature is not visible here.
93 void access_graph(F&& f);
// access_graph_exec: takes `f` as F&& and returns void.
// NOTE(review): `F` is not declared in the visible lines. The listing jumps
// from line 93 to 96, so a `template <typename F>` header was presumably
// dropped - confirm against the real header.
// NOTE(review): presumably invokes `f` with the instantiated (executable)
// graph; the callable's expected signature is not visible here.
96 void access_graph_exec(F&& f);
// get_or_create_node: function template over <NodeType, F>.
// Takes `f` as F&& and returns a raw NodeType*.
// NodeType does not appear in the parameter list, so callers must name it
// explicitly; F can be deduced from the argument.
// NOTE(review): the name suggests an existing node is returned when present
// and `f` is used to create one otherwise - confirm in the implementation.
// Ownership of the returned pointer is not visible in this listing.
103 template <
typename NodeType,
typename F>
104 NodeType* get_or_create_node(F&& f);
109 #include "details/compute_graph_accessor.inl" Definition: compute_graph_var.h:16
Definition: compute_graph_node.h:12
Definition: compute_graph_accessor.h:12
Definition: kernel_node.h:14
Definition: compute_graph.h:37
Definition: compute_graph_var_id.h:5