muda
compute_graph_builder.h
1 #pragma once
2 #include <functional>
3 #include <muda/compute_graph/compute_graph_phase.h>
4 #include <muda/compute_graph/compute_graph_fwd.h>
5 #include <functional>
6 
7 namespace muda
8 {
10 {
11  static ComputeGraphBuilder& instance();
12  using Phase = ComputeGraphPhase;
13  using PhaseAction = std::function<void()>;
14  using CaptureAction = std::function<void(cudaStream_t)>;
15 
16  public:
17  static Phase current_phase();
18  static void capture(CaptureAction&& cap);
19  static void capture(std::string_view name, CaptureAction&& cap);
20  static bool is_phase_none();
21  static bool is_phase_serial_launching();
22  static bool is_topo_building();
23  static bool is_building();
24  // return true when no graph is building or the graph is in serial launching mode
25  static bool is_direct_launching();
26  static bool is_caturing();
27 
28 
29  // do_when_direct_launch
30  // do_when_set_node => do_when_add_node & do_when_update_node
31  // if do_when_topo_building_set_node == nullptr, do_when_set_node will be called
32  // if do_when_topo_building_set_node != nullptr, do_when_topo_building_set_node will be called
33  // copy this code to use:
34  /*
35  ComputeGraphBuilder::invoke_phase_actions(
36  [&] // do_when_direct_launch
37  {
38 
39  },
40  [&] // do_when_set_node
41  {
42 
43  },
44  [&] // do_when_topo_building_set_node
45  {
46 
47  });
48  */
49  static void invoke_phase_actions(PhaseAction&& do_when_direct_launch,
50  PhaseAction&& do_when_set_node,
51  PhaseAction&& do_when_topo_building_set_node);
52 
53  // copy this code to use:
54  /*
55  ComputeGraphBuilder::invoke_phase_actions(
56  [&] // do_when_direct_launch
57  {
58 
59  },
60  [&] // do_when_set_node and do_when_topo_building_set_node
61  {
62 
63  });
64  */
65  static void invoke_phase_actions(PhaseAction&& do_when_direct_launch,
66  PhaseAction&& do_when_set_node);
67 
68  // copy this code to use:
69  /*
70  ComputeGraphBuilder::invoke_phase_actions(
71  [&] // do_in_every_phase
72  {
73 
74  });
75  */
76  static void invoke_phase_actions(PhaseAction&& do_in_every_phase);
77 
78  private:
79  friend class ComputeGraph;
80  friend class ComputeGraphVarBase;
81 
82  static void current_graph(ComputeGraph* graph);
83  friend class details::ComputeGraphAccessor;
84  static auto current_graph() { return instance().m_current_graph; }
85 
86  ComputeGraphBuilder() = default;
87  ~ComputeGraphBuilder() = default;
88 
89  ComputeGraph* m_current_graph = nullptr;
90 };
91 } // namespace muda
92 
93 #include "details/compute_graph_builder.inl"
Definition: compute_graph_var.h:16
Definition: assert.h:13
Definition: compute_graph_accessor.h:12
Definition: compute_graph.h:37
Definition: compute_graph_builder.h:9