muda
kernel_node.h
1 #pragma once
2 #include <muda/graph/graph_base.h>
3 
4 namespace muda
5 {
6 class KernelNode : public GraphNode
7 {
8  public:
9  using this_type = KernelNode;
10  friend class Graph;
11 };
12 
13 template <typename U>
14 class KernelNodeParms : public NodeParms
15 {
16  std::vector<void*> m_args;
17  cudaKernelNodeParams m_parms;
18 
19  public:
20  using this_type = KernelNodeParms;
21  friend class Graph;
22  friend class std::shared_ptr<this_type>;
23  friend class std::unique_ptr<this_type>;
24  friend class std::weak_ptr<this_type>;
25 
26  template <typename... Args>
27  KernelNodeParms(Args&&... args)
28  : kernelParmData(std::forward<Args>(args)...)
29  , m_parms({})
30  {
31  }
32 
33  KernelNodeParms() {}
34  U kernelParmData;
35  auto func() { return m_parms.func; }
36  void func(void* v) { m_parms.func = v; }
37  auto grid_dim() { return m_parms.gridDim; }
38  void grid_dim(const dim3& v) { m_parms.gridDim = v; }
39  auto block_dim() { return m_parms.blockDim; }
40  void block_dim(const dim3& v) { m_parms.blockDim = v; }
41  auto shared_mem_bytes() { return m_parms.sharedMemBytes; }
42  void shared_mem_bytes(unsigned int v) { m_parms.sharedMemBytes = v; }
43  auto kernel_params() { return m_parms.kernelParams; }
44  void kernel_params(const std::vector<void*>& v)
45  {
46  m_args = v;
47  m_parms.kernelParams = m_args.data();
48  }
49  void parse(std::function<std::vector<void*>(U&)> pred)
50  {
51  m_args = pred(kernelParmData);
52  m_parms.kernelParams = m_args.data();
53  }
54  auto extra() { return m_parms.extra; }
55  void extra(void** v) { m_parms.extra = v; }
56 
57  const cudaKernelNodeParams* handle() const { return &m_parms; }
58 };
59 } // namespace muda
Definition: graph_base.h:19
Definition: assert.h:13
Definition: graph_base.h:26
Definition: graph.h:17
Definition: kernel_node.h:14
Definition: kernel_node.h:6