rocPRIM
warp_scan_dpp.hpp
1 // Copyright (c) 2018-2022 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_WARP_DETAIL_WARP_SCAN_DPP_HPP_
22 #define ROCPRIM_WARP_DETAIL_WARP_SCAN_DPP_HPP_
23 
24 #include <type_traits>
25 
26 #include "../../config.hpp"
27 #include "../../detail/various.hpp"
28 
29 #include "../../intrinsics.hpp"
30 #include "../../types.hpp"
31 
32 BEGIN_ROCPRIM_NAMESPACE
33 
34 namespace detail
35 {
36 
37 template<
38  class T,
39  unsigned int WarpSize
40 >
42 {
43 public:
44  static_assert(detail::is_power_of_two(WarpSize), "WarpSize must be power of 2");
45 
47 
48  template<class BinaryFunction>
49  ROCPRIM_DEVICE ROCPRIM_INLINE
50  void inclusive_scan(T input, T& output, BinaryFunction scan_op)
51  {
52  const unsigned int lane_id = ::rocprim::lane_id();
53  const unsigned int row_lane_id = lane_id % ::rocprim::min(16u, WarpSize);
54 
55  output = input;
56 
57  if(WarpSize > 1)
58  {
59  T t = scan_op(warp_move_dpp<T, 0x111>(output), output); // row_shr:1
60  if(row_lane_id >= 1) output = t;
61  }
62  if(WarpSize > 2)
63  {
64  T t = scan_op(warp_move_dpp<T, 0x112>(output), output); // row_shr:2
65  if(row_lane_id >= 2) output = t;
66  }
67  if(WarpSize > 4)
68  {
69  T t = scan_op(warp_move_dpp<T, 0x114>(output), output); // row_shr:4
70  if(row_lane_id >= 4) output = t;
71  }
72  if(WarpSize > 8)
73  {
74  T t = scan_op(warp_move_dpp<T, 0x118>(output), output); // row_shr:8
75  if(row_lane_id >= 8) output = t;
76  }
77 #if ROCPRIM_NAVI
78  if(WarpSize > 16)
79  {
80  T t = scan_op(warp_swizzle<T, 0x1e0>(output), output); // row_bcast:15
81  if(lane_id % 32 >= 16) output = t;
82  }
83 #else
84  if(WarpSize > 16)
85  {
86  T t = scan_op(warp_move_dpp<T, 0x142>(output), output); // row_bcast:15
87  if(lane_id % 32 >= 16) output = t;
88  }
89  if(WarpSize > 32)
90  {
91  T t = scan_op(warp_move_dpp<T, 0x143>(output), output); // row_bcast:31
92  if(lane_id >= 32) output = t;
93  }
94 #endif
95  }
96 
97  template<class BinaryFunction>
98  ROCPRIM_DEVICE ROCPRIM_INLINE
99  void inclusive_scan(T input, T& output,
100  storage_type& storage, BinaryFunction scan_op)
101  {
102  (void) storage; // disables unused parameter warning
103  inclusive_scan(input, output, scan_op);
104  }
105 
106  template<class BinaryFunction>
107  ROCPRIM_DEVICE ROCPRIM_INLINE
108  void inclusive_scan(T input, T& output, T& reduction,
109  BinaryFunction scan_op)
110  {
111  inclusive_scan(input, output, scan_op);
112  // Broadcast value from the last thread in warp
113  reduction = warp_shuffle(output, WarpSize-1, WarpSize);
114  }
115 
116  template<class BinaryFunction>
117  ROCPRIM_DEVICE ROCPRIM_INLINE
118  void inclusive_scan(T input, T& output, T& reduction,
119  storage_type& storage, BinaryFunction scan_op)
120  {
121  (void) storage;
122  inclusive_scan(input, output, reduction, scan_op);
123  }
124 
125  template<class BinaryFunction>
126  ROCPRIM_DEVICE ROCPRIM_INLINE
127  void exclusive_scan(T input, T& output, T init, BinaryFunction scan_op)
128  {
129  inclusive_scan(input, output, scan_op);
130  // Convert inclusive scan result to exclusive
131  to_exclusive(output, output, init, scan_op);
132  }
133 
134  template<class BinaryFunction>
135  ROCPRIM_DEVICE ROCPRIM_INLINE
136  void exclusive_scan(T input, T& output, T init,
137  storage_type& storage, BinaryFunction scan_op)
138  {
139  (void) storage; // disables unused parameter warning
140  exclusive_scan(input, output, init, scan_op);
141  }
142 
143  template<class BinaryFunction>
144  ROCPRIM_DEVICE ROCPRIM_INLINE
145  void exclusive_scan(T input, T& output,
146  storage_type& storage, BinaryFunction scan_op)
147  {
148  (void) storage; // disables unused parameter warning
149  inclusive_scan(input, output, scan_op);
150  // Convert inclusive scan result to exclusive
151  to_exclusive(output, output);
152  }
153 
154  template<class BinaryFunction>
155  ROCPRIM_DEVICE ROCPRIM_INLINE
156  void exclusive_scan(T input, T& output, T init, T& reduction,
157  BinaryFunction scan_op)
158  {
159  inclusive_scan(input, output, scan_op);
160  // Broadcast value from the last thread in warp
161  reduction = warp_shuffle(output, WarpSize-1, WarpSize);
162  // Convert inclusive scan result to exclusive
163  to_exclusive(output, output, init, scan_op);
164  }
165 
166  template<class BinaryFunction>
167  ROCPRIM_DEVICE ROCPRIM_INLINE
168  void exclusive_scan(T input, T& output, T init, T& reduction,
169  storage_type& storage, BinaryFunction scan_op)
170  {
171  (void) storage;
172  exclusive_scan(input, output, init, reduction, scan_op);
173  }
174 
175  template<class BinaryFunction>
176  ROCPRIM_DEVICE ROCPRIM_INLINE
177  void scan(T input, T& inclusive_output, T& exclusive_output, T init,
178  BinaryFunction scan_op)
179  {
180  inclusive_scan(input, inclusive_output, scan_op);
181  // Convert inclusive scan result to exclusive
182  to_exclusive(inclusive_output, exclusive_output, init, scan_op);
183  }
184 
185  template<class BinaryFunction>
186  ROCPRIM_DEVICE ROCPRIM_INLINE
187  void scan(T input, T& inclusive_output, T& exclusive_output, T init,
188  storage_type& storage, BinaryFunction scan_op)
189  {
190  (void) storage; // disables unused parameter warning
191  scan(input, inclusive_output, exclusive_output, init, scan_op);
192  }
193 
194  template<class BinaryFunction>
195  ROCPRIM_DEVICE ROCPRIM_INLINE
196  void scan(T input, T& inclusive_output, T& exclusive_output,
197  storage_type& storage, BinaryFunction scan_op)
198  {
199  (void) storage; // disables unused parameter warning
200  inclusive_scan(input, inclusive_output, scan_op);
201  // Convert inclusive scan result to exclusive
202  to_exclusive(inclusive_output, exclusive_output);
203  }
204 
205  template<class BinaryFunction>
206  ROCPRIM_DEVICE ROCPRIM_INLINE
207  void scan(T input, T& inclusive_output, T& exclusive_output, T init, T& reduction,
208  BinaryFunction scan_op)
209  {
210  inclusive_scan(input, inclusive_output, scan_op);
211  // Broadcast value from the last thread in warp
212  reduction = warp_shuffle(inclusive_output, WarpSize-1, WarpSize);
213  // Convert inclusive scan result to exclusive
214  to_exclusive(inclusive_output, exclusive_output, init, scan_op);
215  }
216 
217  template<class BinaryFunction>
218  ROCPRIM_DEVICE ROCPRIM_INLINE
219  void scan(T input, T& inclusive_output, T& exclusive_output, T init, T& reduction,
220  storage_type& storage, BinaryFunction scan_op)
221  {
222  (void) storage;
223  scan(input, inclusive_output, exclusive_output, init, reduction, scan_op);
224  }
225 
226  ROCPRIM_DEVICE ROCPRIM_INLINE
227  T broadcast(T input, const unsigned int src_lane, storage_type& storage)
228  {
229  (void) storage;
230  return warp_shuffle(input, src_lane, WarpSize);
231  }
232 
233 protected:
234  ROCPRIM_DEVICE ROCPRIM_INLINE
235  void to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage)
236  {
237  (void) storage;
238  return to_exclusive(inclusive_input, exclusive_output);
239  }
240 
241 private:
242  // Changes inclusive scan results to exclusive scan results
243  template<class BinaryFunction>
244  ROCPRIM_DEVICE ROCPRIM_INLINE
245  void to_exclusive(T inclusive_input, T& exclusive_output, T init,
246  BinaryFunction scan_op)
247  {
248  // include init value in scan results
249  exclusive_output = scan_op(init, inclusive_input);
250  // get exclusive results
251  exclusive_output = warp_shuffle_up(exclusive_output, 1, WarpSize);
252  if(detail::logical_lane_id<WarpSize>() == 0)
253  {
254  exclusive_output = init;
255  }
256  }
257 
258  ROCPRIM_DEVICE ROCPRIM_INLINE
259  void to_exclusive(T inclusive_input, T& exclusive_output)
260  {
261  // shift to get exclusive results
262  exclusive_output = warp_shuffle_up(inclusive_input, 1, WarpSize);
263  }
264 };
265 
266 } // end namespace detail
267 
268 END_ROCPRIM_NAMESPACE
269 
270 #endif // ROCPRIM_WARP_DETAIL_WARP_SCAN_DPP_HPP_
Definition: benchmark_block_scan.cpp:63
ROCPRIM_DEVICE ROCPRIM_INLINE T warp_shuffle(const T &input, const int src_lane, const int width=device_warp_size())
Shuffle for any data type.
Definition: warp_shuffle.hpp:172
ROCPRIM_DEVICE ROCPRIM_INLINE T warp_shuffle_up(const T &input, const unsigned int delta, const int width=device_warp_size())
Shuffle up for any data type.
Definition: warp_shuffle.hpp:197
ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
Returns the minimum of its arguments.
Definition: functional.hpp:63
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
Definition: various.hpp:52
Definition: benchmark_block_scan.cpp:100
Definition: warp_scan_dpp.hpp:41
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int lane_id()
Returns thread identifier in a warp.
Definition: thread.hpp:93