muda
compute_graph_var.h
1 #pragma once
2 #include <string>
3 #include <set>
4 #include <map>
5 #include <muda/launch/event.h>
6 #include <muda/mstl/span.h>
7 #include <muda/type_traits/type_modifier.h>
8 #include <muda/compute_graph/compute_graph_closure_id.h>
9 #include <muda/compute_graph/compute_graph_var_usage.h>
10 #include <muda/compute_graph/compute_graph_var_id.h>
11 #include <muda/compute_graph/graphviz_options.h>
12 #include <muda/compute_graph/compute_graph_fwd.h>
13 
14 namespace muda
15 {
17 {
18  std::string_view m_name;
19  ComputeGraphVarManager* m_var_manager = nullptr;
20  VarId m_var_id;
21  bool m_is_valid;
22 
23  public:
24  std::string_view name() const MUDA_NOEXCEPT { return m_name; }
25  VarId var_id() const MUDA_NOEXCEPT { return m_var_id; }
26  bool is_valid() const MUDA_NOEXCEPT { return m_is_valid; }
27  void update();
28  Event::QueryResult query();
29  bool is_using();
30  void sync();
31  virtual void graphviz_def(std::ostream& os,
32  const ComputeGraphGraphvizOptions& options) const;
33  virtual void graphviz_id(std::ostream& os, const ComputeGraphGraphvizOptions& options) const;
34 
35  protected:
36  template <typename RWView>
37  RWView _eval(const RWView& view);
38  template <typename ROView>
39  ROView _ceval(ROView& view) const;
40 
41  friend class ComputeGraph;
42  friend class ComputeGraphVarManager;
43 
45  std::string_view name,
46  VarId var_id) MUDA_NOEXCEPT : m_var_manager(var_manager),
47  m_name(name),
48  m_var_id(var_id),
49  m_is_valid(false)
50  {
51  }
52 
54  std::string_view name,
55  VarId var_id,
56  bool is_valid) MUDA_NOEXCEPT : m_var_manager(var_manager),
57  m_name(name),
58  m_var_id(var_id),
59  m_is_valid(is_valid)
60  {
61  }
62 
63  virtual ~ComputeGraphVarBase() = default;
64 
65 
66  void base_update();
67 
68  friend class LaunchCore;
69 
70  mutable std::set<ClosureId> m_closure_ids;
71 
72  private:
73  void _building_eval(ComputeGraphVarUsage usage) const;
74  void base_building_eval();
75  void base_building_ceval() const;
76  void remove_related_closure_infos(ComputeGraph* graph);
77 
78  class RelatedClosureInfo
79  {
80  public:
81  ComputeGraph* graph;
82  std::set<ClosureId> closure_ids;
83  };
84 
85  mutable std::map<ComputeGraph*, RelatedClosureInfo> m_related_closure_infos;
86 };
87 
88 template <typename T>
90 {
91  public:
92  static_assert(!std::is_const_v<T>, "T must not be const");
93  using ROViewer = read_only_viewer_t<T>;
94  using RWViewer = T;
95  static_assert(std::is_convertible_v<RWViewer, ROViewer>,
96  "RWViewer must be convertible to ROView");
97 
98  protected:
99  friend class ComputeGraph;
100  friend class ComputeGraphVarManager;
101 
102  using ComputeGraphVarBase::ComputeGraphVarBase;
103 
104  ComputeGraphVar(ComputeGraphVarManager* var_manager, std::string_view name, VarId var_id) MUDA_NOEXCEPT
105  : ComputeGraphVarBase(var_manager, name, var_id)
106  {
107  }
108 
110  std::string_view name,
111  VarId var_id,
112  const T& init_value) MUDA_NOEXCEPT
113  : ComputeGraphVarBase(var_manager, name, var_id, true),
114  m_value(init_value)
115  {
116  }
117 
118  virtual ~ComputeGraphVar() = default;
119 
120  public:
121  RWViewer eval() { return _eval(m_value); }
122  ROViewer ceval() const { return _ceval(m_value); }
123 
124  operator ROViewer() const { return ceval(); }
125  operator RWViewer() { return eval(); }
126 
127  void update(const RWViewer& view);
128  ComputeGraphVar<T>& operator=(const RWViewer& view);
129  virtual void graphviz_def(std::ostream& os,
130  const ComputeGraphGraphvizOptions& options) const override;
131 
132  private:
133  RWViewer m_value;
134 };
135 
136 // for host memory
137 template <typename T>
139 {
140  using type = const T*;
141 };
142 template <typename T>
143 struct read_write_viewer<const T*>
144 {
145  using type = T*;
146 };
147 
148 // for cuda event
149 template <>
150 struct read_only_viewer<cudaEvent_t>
151 {
152  using type = cudaEvent_t;
153 };
154 template <>
155 struct read_write_viewer<cudaEvent_t>
156 {
157  using type = cudaEvent_t;
158 };
159 
160 } // namespace muda
161 
162 
163 #include "details/compute_graph_var.inl"
Definition: type_modifier.h:21
Definition: type_modifier.h:27
Definition: launch_base.h:41
Definition: compute_graph_var.h:16
Definition: compute_graph_var_manager.h:14
Definition: buffer_launch.h:34
Definition: assert.h:13
Definition: graphviz_options.h:5
Definition: compute_graph.h:37
QueryResult
Definition: event.h:27
Definition: compute_graph_var_id.h:5