muda
compute_graph_var_manager.h
1 #pragma once
#include <driver_types.h>
#include <iosfwd>
#include <memory>
#include <string>
#include <string_view>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <memory>
#include <muda/mstl/span.h>
#include <muda/compute_graph/compute_graph_flag.h>
#include <muda/compute_graph/compute_graph_fwd.h>
#include <muda/compute_graph/graphviz_options.h>
12 namespace muda
13 {
15 {
16  template <typename T>
17  using S = std::shared_ptr<T>;
18 
19  public:
20  ComputeGraphVarManager() = default;
22 
23  S<ComputeGraph> create_graph(std::string_view name = "graph",
24  ComputeGraphFlag flags = {});
25 
26 
27  /**************************************************************
28  *
29  * GraphVar API
30  *
31  ***************************************************************/
32  template <typename T>
33  ComputeGraphVar<T>& create_var(std::string_view name);
34  template <typename T>
35  ComputeGraphVar<T>& create_var(std::string_view name, const T& init_value);
36  template <typename T>
37  ComputeGraphVar<T>* find_var(std::string_view name);
38 
39  bool is_using() const;
40  void sync() const;
41  void sync_on(cudaStream_t stream) const;
42 
43  template <typename... T>
44  bool is_using(const ComputeGraphVar<T>&... vars) const;
45  template <typename... T>
46  void sync(const ComputeGraphVar<T>&... vars) const;
47  template <typename... T>
48  void sync_on(cudaStream_t stream, const ComputeGraphVar<T>&... vars) const;
49 
50  bool is_using(const span<const ComputeGraphVarBase*> vars) const;
51  void sync(const span<const ComputeGraphVarBase*> vars) const;
52  void sync_on(cudaStream_t stream, const span<const ComputeGraphVarBase*> vars) const;
53 
54  const auto& graphs() const { return m_graphs; }
55  void graphviz(std::ostream& os, const ComputeGraphGraphvizOptions& options = {}) const;
56 
57  private:
58  friend class ComputeGraph;
59  friend class ComputeGraphNodeBase;
60  friend class ComputeGraphClosure;
61  std::vector<ComputeGraph*> unique_graphs(span<const ComputeGraphVarBase*> vars) const;
62  std::unordered_map<std::string, ComputeGraphVarBase*> m_vars_map;
63  std::vector<ComputeGraphVarBase*> m_vars;
64  std::unordered_set<ComputeGraph*> m_graphs;
65  span<const ComputeGraphVarBase*> var_span() const;
66 };
67 } // namespace muda
68 
69 #include "details/compute_graph_var_manager.inl"
Definition: compute_graph_var_manager.h:14
Definition: buffer_launch.h:34
Definition: assert.h:13
Definition: compute_graph_node.h:12
Definition: graphviz_options.h:5
Definition: compute_graph.h:37
Definition: compute_graph_closure.h:14