rocPRIM
block_radix_rank_basic.hpp
1 // Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_RANK_BASIC_HPP_
22 #define ROCPRIM_BLOCK_DETAIL_BLOCK_RANK_BASIC_HPP_
23 
24 #include "../../config.hpp"
25 #include "../../detail/various.hpp"
26 #include "../../functional.hpp"
27 
28 #include "../../detail/radix_sort.hpp"
29 
30 #include "../block_scan.hpp"
31 
32 BEGIN_ROCPRIM_NAMESPACE
33 
34 namespace detail
35 {
36 
37 template<unsigned int BlockSizeX,
38  unsigned int RadixBits,
39  bool MemoizeOuterScan = false,
40  unsigned int BlockSizeY = 1,
41  unsigned int BlockSizeZ = 1>
43 {
44  using digit_counter_type = unsigned short;
45  using packed_counter_type = unsigned int;
46 
47  using block_scan_type = ::rocprim::block_scan<packed_counter_type,
48  BlockSizeX,
49  ::rocprim::block_scan_algorithm::using_warp_scan,
50  BlockSizeY,
51  BlockSizeZ>;
52 
53  static constexpr unsigned int block_size = BlockSizeX * BlockSizeY * BlockSizeZ;
54  static constexpr unsigned int radix_digits = 1 << RadixBits;
55  static constexpr unsigned int packing_ratio
56  = sizeof(packed_counter_type) / sizeof(digit_counter_type);
57  static constexpr unsigned int column_size = radix_digits / packing_ratio;
58 
59 public:
60  static constexpr unsigned int digits_per_thread
61  = ::rocprim::detail::ceiling_div(radix_digits, block_size);
62 
63 private:
64  struct storage_type_
65  {
66  union
67  {
68  digit_counter_type digit_counters[block_size * radix_digits];
69  packed_counter_type packed_counters[block_size * column_size];
70  };
71 
72  typename block_scan_type::storage_type block_scan;
73  };
74 
75  ROCPRIM_DEVICE ROCPRIM_INLINE digit_counter_type& get_digit_counter(const unsigned int digit,
76  const unsigned int thread,
77  storage_type_& storage)
78  {
79  const unsigned int column_counter = digit % column_size;
80  const unsigned int sub_counter = digit / column_size;
81  const unsigned int counter
82  = (column_counter * block_size + thread) * packing_ratio + sub_counter;
83  return storage.digit_counters[counter];
84  };
85 
86  ROCPRIM_DEVICE ROCPRIM_INLINE void reset_counters(const unsigned int flat_id,
87  storage_type_& storage)
88  {
89  for(unsigned int i = flat_id; i < block_size * column_size; i += block_size)
90  {
91  storage.packed_counters[i] = 0;
92  }
93  }
94 
95  ROCPRIM_DEVICE ROCPRIM_INLINE void
96  scan_block_counters(storage_type_& storage, packed_counter_type* const packed_counters)
97  {
98  packed_counter_type block_reduction = 0;
99  ROCPRIM_UNROLL
100  for(unsigned int i = 0; i < column_size; ++i)
101  {
102  block_reduction += packed_counters[i];
103  }
104 
105  packed_counter_type exclusive_prefix = 0;
106  packed_counter_type reduction;
107  block_scan_type().exclusive_scan(block_reduction,
108  exclusive_prefix,
109  0,
110  reduction,
111  storage.block_scan);
112 
113  ROCPRIM_UNROLL
114  for(unsigned int i = 1; i < packing_ratio; i <<= 1)
115  {
116  exclusive_prefix += reduction << (sizeof(digit_counter_type) * 8 * i);
117  }
118 
119  ROCPRIM_UNROLL
120  for(unsigned int i = 0; i < column_size; ++i)
121  {
122  packed_counter_type counter = packed_counters[i];
123  packed_counters[i] = exclusive_prefix;
124  exclusive_prefix += counter;
125  }
126  }
127 
128  ROCPRIM_DEVICE ROCPRIM_INLINE void scan_counters(const unsigned int flat_id,
129  storage_type_& storage)
130  {
131  packed_counter_type* const shared_counters
132  = &storage.packed_counters[flat_id * column_size];
133 
134  if ROCPRIM_IF_CONSTEXPR(MemoizeOuterScan)
135  {
136  packed_counter_type local_counters[column_size];
137  ROCPRIM_UNROLL
138  for(unsigned int i = 0; i < column_size; ++i)
139  {
140  local_counters[i] = shared_counters[i];
141  }
142 
143  scan_block_counters(storage, local_counters);
144 
145  ROCPRIM_UNROLL
146  for(unsigned int i = 0; i < column_size; ++i)
147  {
148  shared_counters[i] = local_counters[i];
149  }
150  }
151  else
152  {
153  scan_block_counters(storage, shared_counters);
154  }
155  }
156 
157  template<typename Key, unsigned int ItemsPerThread, typename DigitExtractor>
158  ROCPRIM_DEVICE void rank_keys_impl(const Key (&keys)[ItemsPerThread],
159  unsigned int (&ranks)[ItemsPerThread],
160  storage_type_& storage,
161  DigitExtractor digit_extractor)
162  {
163  static_assert(block_size * ItemsPerThread < 1u << 16,
164  "The maximum amout of items that block_radix_rank can rank is 2**16.");
165  const unsigned int flat_id
166  = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
167 
168  reset_counters(flat_id, storage);
169 
170  digit_counter_type thread_prefixes[ItemsPerThread];
171  digit_counter_type* digit_counters[ItemsPerThread];
172 
173  ROCPRIM_UNROLL
174  for(unsigned int i = 0; i < ItemsPerThread; ++i)
175  {
176  const unsigned int digit = digit_extractor(keys[i]);
177  digit_counters[i] = &get_digit_counter(digit, flat_id, storage);
178  thread_prefixes[i] = (*digit_counters[i])++;
179  }
180 
182 
183  scan_counters(flat_id, storage);
184 
186 
187  ROCPRIM_UNROLL
188  for(unsigned int i = 0; i < ItemsPerThread; ++i)
189  {
190  ranks[i] = thread_prefixes[i] + *digit_counters[i];
191  }
192  }
193 
194  template<bool Descending, typename Key, unsigned int ItemsPerThread>
195  ROCPRIM_DEVICE void rank_keys_impl(const Key (&keys)[ItemsPerThread],
196  unsigned int (&ranks)[ItemsPerThread],
197  storage_type_& storage,
198  const unsigned int begin_bit,
199  const unsigned int pass_bits)
200  {
201  using key_codec = ::rocprim::detail::radix_key_codec<Key, Descending>;
202  using bit_key_type = typename key_codec::bit_key_type;
203 
204  bit_key_type bit_keys[ItemsPerThread];
205  ROCPRIM_UNROLL
206  for(unsigned int i = 0; i < ItemsPerThread; ++i)
207  {
208  bit_keys[i] = key_codec::encode(keys[i]);
209  }
210 
211  rank_keys_impl(bit_keys,
212  ranks,
213  storage,
214  [begin_bit, pass_bits](const bit_key_type& key)
215  { return key_codec::extract_digit(key, begin_bit, pass_bits); });
216  }
217 
218  template<unsigned int ItemsPerThread>
219  ROCPRIM_DEVICE void digit_prefix_count(unsigned int (&prefix)[digits_per_thread],
220  unsigned int (&counts)[digits_per_thread],
221  storage_type_& storage)
222  {
223  const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
224 
225  ROCPRIM_UNROLL
226  for(unsigned int i = 0; i < digits_per_thread; ++i)
227  {
228  const unsigned int digit = flat_id * digits_per_thread + i;
229  if(radix_digits % block_size == 0 || digit < radix_digits)
230  {
231  // The counter for thread 0 holds the prefix of all the digits at this point.
232  prefix[i] = get_digit_counter(digit, 0, storage);
233  // To find the count, subtract the prefix of the next digit with that of the
234  // current digit.
235  const unsigned int next_prefix = digit + 1 == radix_digits
236  ? block_size * ItemsPerThread
237  : get_digit_counter(digit + 1, 0, storage);
238  counts[i] = next_prefix - prefix[i];
239  }
240  }
241  }
242 
243 public:
244  using storage_type = ::rocprim::detail::raw_storage<storage_type_>;
245 
246  template<typename Key, unsigned ItemsPerThread>
247  ROCPRIM_DEVICE void rank_keys(const Key (&keys)[ItemsPerThread],
248  unsigned int (&ranks)[ItemsPerThread],
249  storage_type& storage,
250  unsigned int begin_bit = 0,
251  unsigned int pass_bits = RadixBits)
252  {
253  rank_keys_impl<false>(keys, ranks, storage.get(), begin_bit, pass_bits);
254  }
255 
256  template<typename Key, unsigned ItemsPerThread>
257  ROCPRIM_DEVICE void rank_keys_desc(const Key (&keys)[ItemsPerThread],
258  unsigned int (&ranks)[ItemsPerThread],
259  storage_type& storage,
260  unsigned int begin_bit = 0,
261  unsigned int pass_bits = RadixBits)
262  {
263  rank_keys_impl<true>(keys, ranks, storage.get(), begin_bit, pass_bits);
264  }
265 
266  template<typename Key, unsigned ItemsPerThread, typename DigitExtractor>
267  ROCPRIM_DEVICE void rank_keys(const Key (&keys)[ItemsPerThread],
268  unsigned int (&ranks)[ItemsPerThread],
269  storage_type& storage,
270  DigitExtractor digit_extractor)
271  {
272  rank_keys_impl(keys, ranks, storage.get(), digit_extractor);
273  }
274 
275  template<typename Key, unsigned ItemsPerThread, typename DigitExtractor>
276  ROCPRIM_DEVICE void rank_keys_desc(const Key (&keys)[ItemsPerThread],
277  unsigned int (&ranks)[ItemsPerThread],
278  storage_type& storage,
279  DigitExtractor digit_extractor)
280  {
281  rank_keys_impl(keys,
282  ranks,
283  storage.get(),
284  [&digit_extractor](const Key& key)
285  {
286  const unsigned int digit = digit_extractor(key);
287  return radix_digits - 1 - digit;
288  });
289  }
290 
291  template<typename Key, unsigned ItemsPerThread, typename DigitExtractor>
292  ROCPRIM_DEVICE void rank_keys(const Key (&keys)[ItemsPerThread],
293  unsigned int (&ranks)[ItemsPerThread],
294  storage_type& storage,
295  DigitExtractor digit_extractor,
296  unsigned int (&prefix)[digits_per_thread],
297  unsigned int (&counts)[digits_per_thread])
298  {
299  rank_keys(keys, ranks, storage, digit_extractor);
300  digit_prefix_count<ItemsPerThread>(prefix, counts, storage.get());
301  }
302 };
303 
304 } // namespace detail
305 
306 END_ROCPRIM_NAMESPACE
307 
308 #endif
The block_scan class is a block level parallel primitive which provides methods for performing inclusive and exclusive scan operations of items partitioned across threads in a block.
Definition: block_scan.hpp:134
Definition: block_radix_rank_basic.hpp:42
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216