rocPRIM
warp_scan_shared_mem.hpp
1 // Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_WARP_DETAIL_WARP_SCAN_SHARED_MEM_HPP_
22 #define ROCPRIM_WARP_DETAIL_WARP_SCAN_SHARED_MEM_HPP_
23 
24 #include <type_traits>
25 
26 #include "../../config.hpp"
27 #include "../../detail/various.hpp"
28 
29 #include "../../intrinsics.hpp"
30 #include "../../types.hpp"
31 
32 BEGIN_ROCPRIM_NAMESPACE
33 
34 namespace detail
35 {
36 
37 template<
38  class T,
39  unsigned int WarpSize
40 >
42 {
43  struct storage_type_
44  {
45  T threads[WarpSize];
46  };
47 public:
49 
50  template<class BinaryFunction>
51  ROCPRIM_DEVICE ROCPRIM_INLINE
52  void inclusive_scan(T input, T& output,
53  storage_type& storage, BinaryFunction scan_op)
54  {
55  const unsigned int lid = detail::logical_lane_id<WarpSize>();
56  storage_type_& storage_ = storage.get();
57 
58  T me = input;
59  storage_.threads[lid] = me;
61  for(unsigned int i = 1; i < WarpSize; i *= 2)
62  {
63  const bool do_op = lid >= i;
64  if(do_op)
65  {
66  T other = storage_.threads[lid - i];
67  me = scan_op(other, me);
68  }
70  if(do_op)
71  {
72  storage_.threads[lid] = me;
73  }
75  }
76  output = me;
77  }
78 
79  template<class BinaryFunction>
80  ROCPRIM_DEVICE ROCPRIM_INLINE
81  void inclusive_scan(T input, T& output, T& reduction,
82  storage_type& storage, BinaryFunction scan_op)
83  {
84  storage_type_& storage_ = storage.get();
85  inclusive_scan(input, output, storage, scan_op);
86  reduction = storage_.threads[WarpSize - 1];
87  }
88 
89  template<class BinaryFunction>
90  ROCPRIM_DEVICE ROCPRIM_INLINE
91  void exclusive_scan(T input, T& output, T init,
92  storage_type& storage, BinaryFunction scan_op)
93  {
94  inclusive_scan(input, output, storage, scan_op);
95  to_exclusive(output, init, storage, scan_op);
96  }
97 
98  template<class BinaryFunction>
99  ROCPRIM_DEVICE ROCPRIM_INLINE
100  void exclusive_scan(T input, T& output,
101  storage_type& storage, BinaryFunction scan_op)
102  {
103  inclusive_scan(input, output, storage, scan_op);
104  to_exclusive(output, storage);
105  }
106 
107  template<class BinaryFunction>
108  ROCPRIM_DEVICE ROCPRIM_INLINE
109  void exclusive_scan(T input, T& output, T init, T& reduction,
110  storage_type& storage, BinaryFunction scan_op)
111  {
112  storage_type_& storage_ = storage.get();
113  inclusive_scan(input, output, storage, scan_op);
114  reduction = storage_.threads[WarpSize - 1];
115  to_exclusive(output, init, storage, scan_op);
116  }
117 
118  template<class BinaryFunction>
119  ROCPRIM_DEVICE ROCPRIM_INLINE
120  void scan(T input, T& inclusive_output, T& exclusive_output, T init,
121  storage_type& storage, BinaryFunction scan_op)
122  {
123  inclusive_scan(input, inclusive_output, storage, scan_op);
124  to_exclusive(exclusive_output, init, storage, scan_op);
125  }
126 
127  template<class BinaryFunction>
128  ROCPRIM_DEVICE ROCPRIM_INLINE
129  void scan(T input, T& inclusive_output, T& exclusive_output,
130  storage_type& storage, BinaryFunction scan_op)
131  {
132  inclusive_scan(input, inclusive_output, storage, scan_op);
133  to_exclusive(exclusive_output, storage);
134  }
135 
136  template<class BinaryFunction>
137  ROCPRIM_DEVICE ROCPRIM_INLINE
138  void scan(T input, T& inclusive_output, T& exclusive_output, T init, T& reduction,
139  storage_type& storage, BinaryFunction scan_op)
140  {
141  storage_type_& storage_ = storage.get();
142  inclusive_scan(input, inclusive_output, storage, scan_op);
143  reduction = storage_.threads[WarpSize - 1];
145  to_exclusive(exclusive_output, init, storage, scan_op);
146  }
147 
148  ROCPRIM_DEVICE ROCPRIM_INLINE
149  T broadcast(T input, const unsigned int src_lane, storage_type& storage)
150  {
151  storage_type_& storage_ = storage.get();
152  if(src_lane == detail::logical_lane_id<WarpSize>())
153  {
154  storage_.threads[src_lane] = input;
155  }
157  return storage_.threads[src_lane];
158  }
159 
160 protected:
161  ROCPRIM_DEVICE ROCPRIM_INLINE
162  void to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage)
163  {
164  (void) inclusive_input;
165  return to_exclusive(exclusive_output, storage);
166  }
167 
168 private:
169  // Calculate exclusive results base on inclusive scan results in storage.threads[].
170  template<class BinaryFunction>
171  ROCPRIM_DEVICE ROCPRIM_INLINE
172  void to_exclusive(T& exclusive_output, T init,
173  storage_type& storage, BinaryFunction scan_op)
174  {
175  const unsigned int lid = detail::logical_lane_id<WarpSize>();
176  storage_type_& storage_ = storage.get();
177  exclusive_output = init;
178  if(lid != 0)
179  {
180  exclusive_output = scan_op(init, storage_.threads[lid - 1]);
181  }
182  }
183 
184  ROCPRIM_DEVICE ROCPRIM_INLINE
185  void to_exclusive(T& exclusive_output, storage_type& storage)
186  {
187  const unsigned int lid = detail::logical_lane_id<WarpSize>();
188  storage_type_& storage_ = storage.get();
189  if(lid != 0)
190  {
191  exclusive_output = storage_.threads[lid - 1];
192  }
193  }
194 };
195 
196 } // end namespace detail
197 
198 END_ROCPRIM_NAMESPACE
199 
200 #endif // ROCPRIM_WARP_DETAIL_WARP_SCAN_SHARED_MEM_HPP_
Definition: benchmark_block_scan.cpp:63
ROCPRIM_DEVICE ROCPRIM_INLINE void wave_barrier()
Synchronize all threads in the wavefront.
Definition: thread.hpp:235
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
Definition: warp_scan_shared_mem.hpp:41
Definition: benchmark_block_scan.cpp:100