rocPRIM
block_radix_rank_match.hpp
1 // Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_RANK_MATCH_HPP_
22 #define ROCPRIM_BLOCK_DETAIL_BLOCK_RANK_MATCH_HPP_
23 
24 #include "../../config.hpp"
25 #include "../../detail/various.hpp"
26 #include "../../functional.hpp"
27 #include "../../types.hpp"
28 
29 #include "../../detail/radix_sort.hpp"
30 
31 #include "../block_scan.hpp"
32 
33 BEGIN_ROCPRIM_NAMESPACE
34 
35 namespace detail
36 {
37 
38 template<unsigned int BlockSizeX,
39  unsigned int RadixBits,
40  unsigned int BlockSizeY = 1,
41  unsigned int BlockSizeZ = 1>
43 {
44  using digit_counter_type = unsigned int;
45 
46  using block_scan_type = ::rocprim::block_scan<digit_counter_type,
47  BlockSizeX,
48  ::rocprim::block_scan_algorithm::using_warp_scan,
49  BlockSizeY,
50  BlockSizeZ>;
51 
52  static constexpr unsigned int block_size = BlockSizeX * BlockSizeY * BlockSizeZ;
53  static constexpr unsigned int radix_digits = 1 << RadixBits;
54 
55  static constexpr unsigned int warp_size = warpSize;
56  // Force the number of warps to an uneven amount to reduce the number of lds bank conflicts.
57  static constexpr unsigned int warps
58  = ::rocprim::detail::ceiling_div(block_size, warp_size) | 1u;
59  // The number of counters that are actively being used.
60  static constexpr unsigned int active_counters = warps * radix_digits;
61  // We want to use a regular block scan to scan the per-warp counters. This requires the
62  // total number of counters to be divisible by the block size. To facilitate this, just add
63  // a bunch of counters that are not otherwise used.
64  static constexpr unsigned int counters_per_thread
65  = ::rocprim::detail::ceiling_div(active_counters, block_size);
66  // The total number of counters, factoring in the unused ones for the block scan.
67  static constexpr unsigned int counters = counters_per_thread * block_size;
68 
69 public:
70  constexpr static unsigned int digits_per_thread
71  = ::rocprim::detail::ceiling_div(radix_digits, block_size);
72 
73 private:
74  struct storage_type_
75  {
76  typename block_scan_type::storage_type block_scan;
77  digit_counter_type counters[counters];
78  };
79 
80  ROCPRIM_DEVICE ROCPRIM_INLINE digit_counter_type&
81  get_digit_counter(const unsigned int digit, const unsigned int warp, storage_type_& storage)
82  {
83  return storage.counters[digit * warps + warp];
84  }
85 
86  template<typename Key, unsigned int ItemsPerThread, typename DigitExtractor>
87  ROCPRIM_DEVICE void rank_keys_impl(const Key (&keys)[ItemsPerThread],
88  unsigned int (&ranks)[ItemsPerThread],
89  storage_type_& storage,
90  DigitExtractor digit_extractor)
91  {
92  const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
93  const unsigned int warp_id = ::rocprim::warp_id();
94 
95  ROCPRIM_UNROLL
96  for(unsigned int i = 0; i < counters_per_thread; ++i)
97  {
98  storage.counters[flat_id * counters_per_thread + i] = 0;
99  }
100 
102 
103  digit_counter_type* digit_counters[ItemsPerThread];
104 
105  ROCPRIM_UNROLL
106  for(unsigned int i = 0; i < ItemsPerThread; ++i)
107  {
108  // Get the digit for this key.
109  const unsigned int digit = digit_extractor(keys[i]);
110 
111  // Get the digit counter for this key on the current warp.
112  digit_counters[i] = &get_digit_counter(digit, warp_id, storage);
113  const digit_counter_type warp_digit_prefix = *digit_counters[i];
114 
115  // Construct a mask of threads in this wave which have the same digit.
117 
118  ROCPRIM_UNROLL
119  for(unsigned int b = 0; b < RadixBits; ++b)
120  {
121  const unsigned int bit_set = digit & (1u << b);
122  const ::rocprim::lane_mask_type bit_set_mask = ::rocprim::ballot(bit_set);
123  peer_mask &= (bit_set ? bit_set_mask : ~bit_set_mask);
124  }
125 
127 
128  // The total number of threads in the warp which also have this digit.
129  const unsigned int digit_count = rocprim::bit_count(peer_mask);
130  // The number of threads in the warp that have the same digit AND whose lane id is lower
131  // than the current thread's.
132  const unsigned int peer_digit_prefix = rocprim::masked_bit_count(peer_mask);
133 
134  // The first thread with a particular digit gets to update the shared counter.
135  if(peer_digit_prefix == 0)
136  {
137  *digit_counters[i] = warp_digit_prefix + digit_count;
138  }
139 
141 
142  // Compute the warp-local rank.
143  ranks[i] = warp_digit_prefix + peer_digit_prefix;
144  }
145 
147 
148  // Scan the per-warp counters to get a rank-offset per warp counter.
149  digit_counter_type scan_counters[counters_per_thread];
150 
151  ROCPRIM_UNROLL
152  for(unsigned int i = 0; i < counters_per_thread; ++i)
153  {
154  scan_counters[i] = storage.counters[flat_id * counters_per_thread + i];
155  }
156 
157  block_scan_type().exclusive_scan(scan_counters, scan_counters, 0, storage.block_scan);
158 
159  ROCPRIM_UNROLL
160  for(unsigned int i = 0; i < counters_per_thread; ++i)
161  {
162  storage.counters[flat_id * counters_per_thread + i] = scan_counters[i];
163  }
164 
166 
167  // Add the per-warp rank counter to get the final rank.
168  ROCPRIM_UNROLL
169  for(unsigned int i = 0; i < ItemsPerThread; ++i)
170  {
171  ranks[i] += *digit_counters[i];
172  }
173  }
174 
175  template<bool Descending, typename Key, unsigned int ItemsPerThread>
176  ROCPRIM_DEVICE void rank_keys_impl(const Key (&keys)[ItemsPerThread],
177  unsigned int (&ranks)[ItemsPerThread],
178  storage_type_& storage,
179  const unsigned int begin_bit,
180  const unsigned int pass_bits)
181  {
182  using key_codec = ::rocprim::detail::radix_key_codec<Key, Descending>;
183  using bit_key_type = typename key_codec::bit_key_type;
184 
185  bit_key_type bit_keys[ItemsPerThread];
186  ROCPRIM_UNROLL
187  for(unsigned int i = 0; i < ItemsPerThread; ++i)
188  {
189  bit_keys[i] = key_codec::encode(keys[i]);
190  }
191 
192  rank_keys_impl(bit_keys,
193  ranks,
194  storage,
195  [begin_bit, pass_bits](const bit_key_type& key)
196  { return key_codec::extract_digit(key, begin_bit, pass_bits); });
197  }
198 
199  template<unsigned int ItemsPerThread>
200  ROCPRIM_DEVICE void digit_prefix_count(unsigned int (&prefix)[digits_per_thread],
201  unsigned int (&counts)[digits_per_thread],
202  storage_type_& storage)
203  {
204  const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
205 
206  ROCPRIM_UNROLL
207  for(unsigned int i = 0; i < digits_per_thread; ++i)
208  {
209  const unsigned int digit = flat_id * digits_per_thread + i;
210  if(radix_digits % block_size == 0 || digit < radix_digits)
211  {
212  // The counter for warp 0 holds the prefix of all the digits at this point.
213  prefix[i] = get_digit_counter(digit, 0, storage);
214  // To find the count, subtract the prefix of the next digit with that of the
215  // current digit.
216  const unsigned int next_prefix = digit + 1 == radix_digits
217  ? block_size * ItemsPerThread
218  : get_digit_counter(digit + 1, 0, storage);
219  counts[i] = next_prefix - prefix[i];
220  }
221  }
222  }
223 
224 public:
225  using storage_type = ::rocprim::detail::raw_storage<storage_type_>;
226 
227  template<typename Key, unsigned ItemsPerThread>
228  ROCPRIM_DEVICE void rank_keys(const Key (&keys)[ItemsPerThread],
229  unsigned int (&ranks)[ItemsPerThread],
230  storage_type& storage,
231  unsigned int begin_bit = 0,
232  unsigned int pass_bits = RadixBits)
233  {
234  rank_keys_impl<false>(keys, ranks, storage.get(), begin_bit, pass_bits);
235  }
236 
237  template<typename Key, unsigned ItemsPerThread>
238  ROCPRIM_DEVICE void rank_keys_desc(const Key (&keys)[ItemsPerThread],
239  unsigned int (&ranks)[ItemsPerThread],
240  storage_type& storage,
241  unsigned int begin_bit = 0,
242  unsigned int pass_bits = RadixBits)
243  {
244  rank_keys_impl<true>(keys, ranks, storage.get(), begin_bit, pass_bits);
245  }
246 
247  template<typename Key, unsigned ItemsPerThread, typename DigitExtractor>
248  ROCPRIM_DEVICE void rank_keys(const Key (&keys)[ItemsPerThread],
249  unsigned int (&ranks)[ItemsPerThread],
250  storage_type& storage,
251  DigitExtractor digit_extractor)
252  {
253  rank_keys_impl(keys, ranks, storage.get(), digit_extractor);
254  }
255 
256  template<typename Key, unsigned ItemsPerThread, typename DigitExtractor>
257  ROCPRIM_DEVICE void rank_keys(const Key (&keys)[ItemsPerThread],
258  unsigned int (&ranks)[ItemsPerThread],
259  storage_type& storage,
260  DigitExtractor digit_extractor,
261  unsigned int (&prefix)[digits_per_thread],
262  unsigned int (&counts)[digits_per_thread])
263  {
264  rank_keys(keys, ranks, storage, digit_extractor);
265  digit_prefix_count<ItemsPerThread>(prefix, counts, storage.get());
266  }
267 };
268 
269 } // namespace detail
270 
271 END_ROCPRIM_NAMESPACE
272 
273 #endif
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int masked_bit_count(lane_mask_type x, unsigned int add=0)
Masked bit count.
Definition: warp.hpp:48
The block_scan class is a block level parallel primitive which provides methods for performing inclus...
Definition: block_scan.hpp:134
Definition: block_radix_rank_match.hpp:42
ROCPRIM_DEVICE ROCPRIM_INLINE void wave_barrier()
Synchronize all threads in the wavefront.
Definition: thread.hpp:235
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
const unsigned int warp_id
Returns warp id in a block (tile).
Definition: benchmark_warp_exchange.cpp:153
ROCPRIM_DEVICE ROCPRIM_INLINE lane_mask_type ballot(int predicate)
Evaluate predicate for all active work-items in the warp and return an integer whose i-th bit is set ...
Definition: warp.hpp:38
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216
unsigned long long int lane_mask_type
The lane_mask_type is an integer that contains one bit per thread.
Definition: types.hpp:164
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int bit_count(unsigned int x)
Bit count.
Definition: bit.hpp:42