muda
parallel_for.h
/*****************************************************************/
#pragma once
#include <muda/launch/launch_base.h>
#include <muda/launch/kernel_tag.h>
#include <stdexcept>
#include <exception>

namespace muda
{
namespace details
{
    template <typename F>
    class ParallelForCallable
    {
      public:
        F   callable;
        int count;
        template <typename U>
        MUDA_GENERIC ParallelForCallable(U&& callable, int count) MUDA_NOEXCEPT
            : callable(std::forward<U>(callable)),
              count(count)
        {
        }
        // MUDA_GENERIC ~ParallelForCallable() = default;
    };

    template <typename F, typename UserTag>
    MUDA_GLOBAL void parallel_for_kernel(ParallelForCallable<F> f);

    template <typename F, typename UserTag>
    MUDA_GLOBAL void grid_stride_loop_kernel(ParallelForCallable<F> f);
}  // namespace details

enum class ParallelForType : uint32_t
{
    DynamicBlocks,
    GridStrideLoop
};

class ParallelForDetails
{
  public:
    MUDA_NODISCARD MUDA_DEVICE int  active_num_in_block() const MUDA_NOEXCEPT;
    MUDA_NODISCARD MUDA_DEVICE bool is_final_block() const MUDA_NOEXCEPT;
    MUDA_NODISCARD MUDA_DEVICE ParallelForType parallel_for_type() const MUDA_NOEXCEPT
    {
        return m_type;
    }

    MUDA_NODISCARD MUDA_DEVICE int total_num() const MUDA_NOEXCEPT
    {
        return m_total_num;
    }
    MUDA_NODISCARD MUDA_DEVICE operator int() const MUDA_NOEXCEPT
    {
        return m_current_i;
    }

    MUDA_NODISCARD MUDA_DEVICE int i() const MUDA_NOEXCEPT
    {
        return m_current_i;
    }

    MUDA_NODISCARD MUDA_DEVICE int batch_i() const MUDA_NOEXCEPT
    {
        return m_batch_i;
    }

    MUDA_NODISCARD MUDA_DEVICE int total_batch() const MUDA_NOEXCEPT
    {
        return m_total_batch;
    }

  private:
    template <typename F, typename UserTag>
    friend MUDA_GLOBAL void details::parallel_for_kernel(ParallelForCallable<F> f);

    template <typename F, typename UserTag>
    friend MUDA_GLOBAL void details::grid_stride_loop_kernel(ParallelForCallable<F> f);

    MUDA_DEVICE ParallelForDetails(ParallelForType type, int i, int total_num) MUDA_NOEXCEPT
        : m_type(type),
          m_total_num(total_num),
          m_current_i(i)
    {
    }

    ParallelForType m_type;
    int             m_total_num;
    int             m_total_batch         = 1;
    int             m_batch_i             = 0;
    int             m_active_num_in_block = 0;
    int             m_current_i           = 0;
};
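// Note (inferred from the declarations above, not stated in this header): the two
// kernels are friends because they construct the ParallelForDetails and hand it to
// the user callable given to ParallelFor::apply below, so a callable may take either
// an `int` (via the implicit conversion) or a `ParallelForDetails` to additionally
// query batch_i()/total_batch() under the GridStrideLoop strategy.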

using details::grid_stride_loop_kernel;
using details::parallel_for_kernel;

/**
 * @brief A frequently used parallel-for loop; DynamicBlocks and GridStrideLoop strategies are provided.
 */
class ParallelFor : public LaunchBase<ParallelFor>
{
    int    m_grid_dim;
    int    m_block_dim;
    size_t m_shared_mem_size;

  public:
    template <typename F>
    using NodeParms = KernelNodeParms<details::ParallelForCallable<F>>;

    /**
     * @brief Calculate the grid dim automatically to cover the range, and choose the block size automatically.
     */
    MUDA_HOST ParallelFor(size_t shared_mem_size = 0, cudaStream_t stream = nullptr) MUDA_NOEXCEPT
        : LaunchBase(stream),
          m_grid_dim(0),
          m_block_dim(-1),
          m_shared_mem_size(shared_mem_size)
    {
    }

    /**
     * @brief Calculate the grid dim automatically to cover the range; you set the block size manually.
     */
    MUDA_HOST ParallelFor(int blockDim, size_t shared_mem_size = 0, cudaStream_t stream = nullptr) MUDA_NOEXCEPT
        : LaunchBase(stream),
          m_grid_dim(0),
          m_block_dim(blockDim),
          m_shared_mem_size(shared_mem_size)
    {
    }

    /**
     * @brief Use a grid-stride loop to cover the range; you set the grid size and the block size manually.
     */
    MUDA_HOST ParallelFor(int          gridDim,
                          int          blockDim,
                          size_t       shared_mem_size = 0,
                          cudaStream_t stream = nullptr) MUDA_NOEXCEPT
        : LaunchBase(stream),
          m_grid_dim(gridDim),
          m_block_dim(blockDim),
          m_shared_mem_size(shared_mem_size)
    {
    }

    template <typename F, typename UserTag = Default>
    MUDA_HOST ParallelFor& apply(int count, F&& f);

    template <typename F, typename UserTag = Default>
    MUDA_HOST ParallelFor& apply(int count, F&& f, Tag<UserTag>);


    template <typename F, typename UserTag = Default>
    MUDA_HOST MUDA_NODISCARD auto as_node_parms(int count, F&& f) -> S<NodeParms<F>>;

    template <typename F, typename UserTag = Default>
    MUDA_HOST MUDA_NODISCARD auto as_node_parms(int count, F&& f, Tag<UserTag>)
        -> S<NodeParms<F>>;

    MUDA_GENERIC MUDA_NODISCARD static int round_up_blocks(int count, int block_dim) MUDA_NOEXCEPT
    {
        return (count + block_dim - 1) / block_dim;
    }

  public:
    template <typename F, typename UserTag>
    MUDA_HOST void invoke(int count, F&& f);

    template <typename F, typename UserTag>
    MUDA_GENERIC int calculate_block_dim(int count) const MUDA_NOEXCEPT;

    MUDA_GENERIC int calculate_grid_dim(int count) const MUDA_NOEXCEPT;

    static MUDA_GENERIC int calculate_grid_dim(int count, int block_dim) MUDA_NOEXCEPT;

    MUDA_GENERIC void check_input(int count) const MUDA_NOEXCEPT;
};
}  // namespace muda

#include "details/parallel_for.inl"
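
The header only declares the interface; the definitions live in details/parallel_for.inl. Below is a minimal usage sketch, not part of this file: it assumes the header is included as <muda/launch/parallel_for.h> (matching the sibling launch headers above), and the function name, buffer, launch sizes, and stream are illustrative placeholders. Only ParallelFor, apply(), and ParallelForDetails come from this header.

// Usage sketch (illustrative, not part of muda). Requires nvcc with --extended-lambda.
#include <muda/launch/parallel_for.h>  // assumed include path

void scale(float* data, int count, float factor, cudaStream_t stream)
{
    using namespace muda;

    // DynamicBlocks strategy: block size 256, grid dim computed to cover `count`.
    ParallelFor(256, 0, stream)
        .apply(count,
               [=] __device__(int i) { data[i] *= factor; });

    // GridStrideLoop strategy: a fixed 64-block x 128-thread launch strides over the
    // whole range; the callable may take a ParallelForDetails to query the batch info.
    ParallelFor(64, 128, 0, stream)
        .apply(count,
               [=] __device__(const ParallelForDetails& info)
               {
                   int i = info;  // implicit conversion to the current index
                   data[i] *= factor;
               });
}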