rocPRIM
device_radix_sort.hpp
1 // Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_RADIX_SORT_HPP_
22 #define ROCPRIM_DEVICE_DETAIL_DEVICE_RADIX_SORT_HPP_
23 
24 #include <type_traits>
25 #include <iterator>
26 
27 #include "../../config.hpp"
28 #include "../../detail/various.hpp"
29 #include "../../detail/radix_sort.hpp"
30 
31 #include "../../intrinsics.hpp"
32 #include "../../functional.hpp"
33 #include "../../types.hpp"
34 
35 #include "../../block/block_discontinuity.hpp"
36 #include "../../block/block_exchange.hpp"
37 #include "../../block/block_load.hpp"
38 #include "../../block/block_load_func.hpp"
39 #include "../../block/block_radix_rank.hpp"
40 #include "../../block/block_radix_sort.hpp"
41 #include "../../block/block_scan.hpp"
42 #include "../../block/block_store_func.hpp"
43 
44 BEGIN_ROCPRIM_NAMESPACE
45 
46 namespace detail
47 {
48 
49 // Wrapping functions that allow one to call proper methods (with or without values)
50 // (a variant with values is enabled only when Value is not empty_type)
51 template<bool Descending = false, class SortType, class SortKey, class SortValue, unsigned int ItemsPerThread>
52 ROCPRIM_DEVICE ROCPRIM_INLINE
53 void sort_block(SortType sorter,
54  SortKey (&keys)[ItemsPerThread],
55  SortValue (&values)[ItemsPerThread],
56  typename SortType::storage_type& storage,
57  unsigned int begin_bit,
58  unsigned int end_bit)
59 {
60  if(Descending)
61  {
62  sorter.sort_desc(keys, values, storage, begin_bit, end_bit);
63  }
64  else
65  {
66  sorter.sort(keys, values, storage, begin_bit, end_bit);
67  }
68 }
69 
70 template<bool Descending = false, class SortType, class SortKey, unsigned int ItemsPerThread>
71 ROCPRIM_DEVICE ROCPRIM_INLINE
72 void sort_block(SortType sorter,
73  SortKey (&keys)[ItemsPerThread],
74  ::rocprim::empty_type (&values)[ItemsPerThread],
75  typename SortType::storage_type& storage,
76  unsigned int begin_bit,
77  unsigned int end_bit)
78 {
79  (void) values;
80  if(Descending)
81  {
82  sorter.sort_desc(keys, storage, begin_bit, end_bit);
83  }
84  else
85  {
86  sorter.sort(keys, storage, begin_bit, end_bit);
87  }
88 }
89 
// Block-level digit counting for one radix pass. count_digits() walks
// [begin_offset, end_offset) of keys_input in tiles of BlockSize * ItemsPerThread keys
// and, for every key, takes the digit returned by
// key_codec::extract_digit(encoded_key, bit, current_radix_bits). Each warp accumulates
// into its own row of storage.digit_counts (one counter per digit); at the end thread i
// (i < radix_size) sums column i over all warp rows and returns it in `digit_count`
// (threads with flat_id >= radix_size get 0).
90 template<
91  unsigned int WarpSize,
92  unsigned int BlockSize,
93  unsigned int ItemsPerThread,
94  unsigned int RadixBits,
95  bool Descending
96 >
// NOTE(review): listing line 97 is missing; it must hold the struct declarator
// (`struct <name>`) for the body below to compile. TODO restore it from upstream.
98 {
99  static constexpr unsigned int radix_size = 1 << RadixBits;
100 
101  static constexpr unsigned int warp_size = WarpSize;
102  static constexpr unsigned int warps_no = BlockSize / warp_size;
// NOTE(review): warps_no is derived from the WarpSize parameter, but this assert checks
// ::rocprim::device_warp_size(), and count_digits selects its row with
// ::rocprim::warp_id<0, 1, 1>(). Confirm WarpSize always equals the device warp size.
103  static_assert(BlockSize % ::rocprim::device_warp_size() == 0, "BlockSize must be divisible by warp size");
104  static_assert(radix_size <= BlockSize, "Radix size must not exceed BlockSize");
105 
// NOTE(review): listing line 106 is missing; it should declare the nested `storage_type`
// (the type of count_digits' `storage` parameter) whose body follows. TODO restore.
107  {
// One counter per (warp, digit) pair.
108  unsigned int digit_counts[warps_no][radix_size];
109  };
110 
111  template<
112  bool IsFull = false,
113  class KeysInputIterator,
114  class Offset
115  >
116  ROCPRIM_DEVICE ROCPRIM_INLINE
117  void count_digits(KeysInputIterator keys_input,
118  Offset begin_offset,
119  Offset end_offset,
120  unsigned int bit,
121  unsigned int current_radix_bits,
122  storage_type& storage,
123  unsigned int& digit_count) // i-th thread will get i-th digit's value
124  {
125  constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
126 
127  using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
128 
129  using key_codec = radix_key_codec<key_type, Descending>;
130  using bit_key_type = typename key_codec::bit_key_type;
131 
132  const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
133  const unsigned int warp_id = ::rocprim::warp_id<0, 1, 1>();
134 
// Threads 0 .. radix_size-1 zero the counter of "their" digit in every warp row.
135  if(flat_id < radix_size)
136  {
137  for(unsigned int w = 0; w < warps_no; w++)
138  {
139  storage.digit_counts[w][flat_id] = 0;
140  }
141  }
// NOTE(review): listing line 142 is missing. The counters are zeroed above by threads
// flat_id < radix_size and then updated below by lanes of every warp, so a block-wide
// barrier (presumably ::rocprim::syncthreads()) is required here. Confirm and restore.
143 
144  for(Offset block_offset = begin_offset; block_offset < end_offset; block_offset += items_per_block)
145  {
146  key_type keys[ItemsPerThread];
147  unsigned int valid_count;
148  // Use loading into a striped arrangement because an order of items is irrelevant,
149  // only totals matter
150  if(IsFull || (block_offset + items_per_block <= end_offset))
151  {
152  valid_count = items_per_block;
153  block_load_direct_striped<BlockSize>(flat_id, keys_input + block_offset, keys);
154  }
155  else
156  {
157  valid_count = end_offset - block_offset;
158  block_load_direct_striped<BlockSize>(flat_id, keys_input + block_offset, keys, valid_count);
159  }
160 
// Warp-cooperative counting: for each key slot, start from a ballot of the in-range
// lanes, then intersect it with the lanes that agree with this lane on each digit bit.
// The result is the set of in-range lanes holding the same digit as this lane.
161  for(unsigned int i = 0; i < ItemsPerThread; i++)
162  {
163  const bit_key_type bit_key = key_codec::encode(keys[i]);
164  const unsigned int digit = key_codec::extract_digit(bit_key, bit, current_radix_bits);
165  const unsigned int pos = i * BlockSize + flat_id;
166  lane_mask_type same_digit_lanes_mask = ::rocprim::ballot(IsFull || (pos < valid_count));
167  for(unsigned int b = 0; b < RadixBits; b++)
168  {
169  const unsigned int bit_set = digit & (1u << b);
170  const lane_mask_type bit_set_mask = ::rocprim::ballot(bit_set);
171  same_digit_lanes_mask &= (bit_set ? bit_set_mask : ~bit_set_mask);
172  }
// Per the comment below, prev_same_digit_count == 0 identifies the first lane holding
// this digit, so exactly one lane per distinct digit adds the warp's total for it.
173  const unsigned int same_digit_count = ::rocprim::bit_count(same_digit_lanes_mask);
174  const unsigned int prev_same_digit_count = ::rocprim::masked_bit_count(same_digit_lanes_mask);
175  if(prev_same_digit_count == 0)
176  {
177  // Write the number of lanes having this digit,
178  // if the current lane is the first (and maybe only) lane with this digit.
179  storage.digit_counts[warp_id][digit] += same_digit_count;
180  }
181  }
182  }
// NOTE(review): listing line 183 is missing. The per-warp rows written above are read
// below by other threads, so a block-wide barrier (presumably ::rocprim::syncthreads())
// is required here. Confirm and restore.
184 
// Thread i < radix_size reduces column i over all warp rows.
185  digit_count = 0;
186  if(flat_id < radix_size)
187  {
188  for(unsigned int w = 0; w < warps_no; w++)
189  {
190  digit_count += storage.digit_counts[w][flat_id];
191  }
192  }
193  }
194 };
195 
196 template<
197  unsigned int BlockSize,
198  unsigned int ItemsPerThread,
199  bool Descending,
200  class Key,
201  class Value
202 >
204 {
205  static constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
206 
207  using key_type = Key;
208  using value_type = Value;
209 
211  using bit_key_type = typename key_codec::bit_key_type;
212  using sort_type = ::rocprim::block_radix_sort<key_type, BlockSize, ItemsPerThread, value_type>;
213 
214  static constexpr bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;
215 
217  {
218  typename sort_type::storage_type sort;
219  };
220 
221  template<
222  class KeysInputIterator,
223  class KeysOutputIterator,
224  class ValuesInputIterator,
225  class ValuesOutputIterator
226  >
227  ROCPRIM_DEVICE ROCPRIM_INLINE
228  void sort_single(KeysInputIterator keys_input,
229  KeysOutputIterator keys_output,
230  ValuesInputIterator values_input,
231  ValuesOutputIterator values_output,
232  unsigned int size,
233  unsigned int bit,
234  unsigned int current_radix_bits,
235  storage_type& storage)
236  {
237  const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
238  const unsigned int flat_block_id = ::rocprim::detail::block_id<0>();
239  const unsigned int block_offset = flat_block_id * items_per_block;
240  const bool is_incomplete_block = flat_block_id == (size / items_per_block);
241  const unsigned int valid_in_last_block = size - block_offset;
242 
243  using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
244 
246  using bit_key_type = typename key_codec::bit_key_type;
247 
248  key_type keys[ItemsPerThread];
249  value_type values[ItemsPerThread];
250  if(!is_incomplete_block)
251  {
252  block_load_direct_blocked(flat_id, keys_input + block_offset, keys);
253  if ROCPRIM_IF_CONSTEXPR(with_values)
254  {
255  block_load_direct_blocked(flat_id, values_input + block_offset, values);
256  }
257  }
258  else
259  {
260  const key_type out_of_bounds = key_codec::decode(bit_key_type(-1));
262  keys_input + block_offset,
263  keys,
264  valid_in_last_block,
265  out_of_bounds);
266  if ROCPRIM_IF_CONSTEXPR(with_values)
267  {
269  values_input + block_offset,
270  values,
271  valid_in_last_block);
272  }
273  }
274 
275  sort_block<Descending>(sort_type(), keys, values, storage.sort, bit, bit + current_radix_bits);
276 
277  // Store keys and values
278  if(!is_incomplete_block)
279  {
280  block_store_direct_blocked(flat_id, keys_output + block_offset, keys);
281  if ROCPRIM_IF_CONSTEXPR(with_values)
282  {
283  block_store_direct_blocked(flat_id, values_output + block_offset, values);
284  }
285  }
286  else
287  {
289  keys_output + block_offset,
290  keys,
291  valid_in_last_block);
292  if ROCPRIM_IF_CONSTEXPR(with_values)
293  {
295  values_output + block_offset,
296  values,
297  valid_in_last_block);
298  }
299  }
300  }
301 };
302 
// Scatter step of one radix pass over [begin_offset, end_offset). Each tile of
// BlockSize * ItemsPerThread keys (and values when Value is not empty_type) is loaded,
// sorted inside the block on bits [bit, bit + current_radix_bits), and every valid item
// is written to keys_output/values_output at
//   (its position inside its digit's run in the sorted tile) + storage.digit_starts[digit].
// Thread i (i < radix_size) must pass the global start offset of digit i in
// `digit_start`; storage.digit_starts is then advanced by each tile's digit counts.
303 template<
304  unsigned int BlockSize,
305  unsigned int ItemsPerThread,
306  unsigned int RadixBits,
307  bool Descending,
308  class Key,
309  class Value,
310  class Offset
311 >
// NOTE(review): listing line 312 is missing; it must hold the struct declarator
// (`struct <name>`) for the body below to compile. TODO restore it from upstream.
313 {
314  static constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
315  static constexpr unsigned int radix_size = 1 << RadixBits;
316 
317  using key_type = Key;
318  using value_type = Value;
319 
// NOTE(review): listing line 320 is missing. `key_codec` is used below but not declared
// in the visible lines; presumably `using key_codec = radix_key_codec<key_type, Descending>;`
// as in the digit-count helper. Confirm and restore.
321  using bit_key_type = typename key_codec::bit_key_type;
322  using keys_load_type = ::rocprim::block_load<
323  key_type, BlockSize, ItemsPerThread,
324  ::rocprim::block_load_method::block_load_transpose>;
325  using values_load_type = ::rocprim::block_load<
326  value_type, BlockSize, ItemsPerThread,
327  ::rocprim::block_load_method::block_load_transpose>;
328  using sort_type = ::rocprim::block_radix_sort<key_type, BlockSize, ItemsPerThread, value_type>;
329  using discontinuity_type = ::rocprim::block_discontinuity<unsigned int, BlockSize>;
330  using bit_keys_exchange_type = ::rocprim::block_exchange<bit_key_type, BlockSize, ItemsPerThread>;
331  using values_exchange_type = ::rocprim::block_exchange<value_type, BlockSize, ItemsPerThread>;
332 
333  static constexpr bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;
334 
// NOTE(review): listing line 335 is missing; it should declare the nested `storage_type`
// (the type of sort_and_scatter's `storage` parameter) whose body follows. TODO restore.
336  {
// The block primitives run one after another and share this union, so every switch from
// one member to another needs a block-wide barrier.
337  union
338  {
339  typename keys_load_type::storage_type keys_load;
340  typename values_load_type::storage_type values_load;
341  typename sort_type::storage_type sort;
342  typename discontinuity_type::storage_type discontinuity;
343  typename bit_keys_exchange_type::storage_type bit_keys_exchange;
344  typename values_exchange_type::storage_type values_exchange;
345  };
346 
// First/last tile position of each digit's run in the sorted tile. Both are set to
// valid_count when the digit is absent.
// NOTE(review): unsigned short assumes BlockSize * ItemsPerThread fits in 16 bits;
// consider a static_assert.
347  unsigned short starts[radix_size];
348  unsigned short ends[radix_size];
349 
// Running global output offset of each digit.
350  Offset digit_starts[radix_size];
351  };
352 
353  template<
354  bool IsFull = false,
355  class KeysInputIterator,
356  class KeysOutputIterator,
357  class ValuesInputIterator,
358  class ValuesOutputIterator
359  >
360  ROCPRIM_DEVICE ROCPRIM_INLINE
361  void sort_and_scatter(KeysInputIterator keys_input,
362  KeysOutputIterator keys_output,
363  ValuesInputIterator values_input,
364  ValuesOutputIterator values_output,
365  Offset begin_offset,
366  Offset end_offset,
367  unsigned int bit,
368  unsigned int current_radix_bits,
369  Offset digit_start, // i-th thread must pass i-th digit's value
370  storage_type& storage)
371  {
372  const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
373 
374  if(flat_id < radix_size)
375  {
376  storage.digit_starts[flat_id] = digit_start;
377  }
378 
379  for(Offset block_offset = begin_offset; block_offset < end_offset; block_offset += items_per_block)
380  {
381  key_type keys[ItemsPerThread];
382  value_type values[ItemsPerThread];
383  unsigned int valid_count;
384  if(IsFull || (block_offset + items_per_block <= end_offset))
385  {
386  valid_count = items_per_block;
387  keys_load_type().load(keys_input + block_offset, keys, storage.keys_load);
388  if(with_values)
389  {
// NOTE(review): listing line 390 is missing. keys_load and values_load alias in the
// union, so a block-wide barrier is presumably required before this load. Confirm.
391  values_load_type().load(values_input + block_offset, values, storage.values_load);
392  }
393  }
394  else
395  {
396  valid_count = end_offset - block_offset;
397  // Sort will leave "invalid" (out of size) items at the end of the sorted sequence
398  const key_type out_of_bounds = key_codec::decode(bit_key_type(-1));
399  keys_load_type().load(keys_input + block_offset, keys, valid_count, out_of_bounds, storage.keys_load);
400  if(with_values)
401  {
// NOTE(review): listing line 402 is missing; same union-reuse barrier as above. Confirm.
403  values_load_type().load(values_input + block_offset, values, valid_count, storage.values_load);
404  }
405  }
406 
407  if(flat_id < radix_size)
408  {
409  storage.starts[flat_id] = valid_count;
410  storage.ends[flat_id] = valid_count;
411  }
412 
// NOTE(review): listing line 413 is missing. The sort below reuses the union storage
// that the loads above used, so a block-wide barrier is presumably required here. Confirm.
414  sort_block<Descending>(sort_type(), keys, values, storage.sort, bit, bit + current_radix_bits);
415 
416  bit_key_type bit_keys[ItemsPerThread];
417  unsigned int digits[ItemsPerThread];
418  for(unsigned int i = 0; i < ItemsPerThread; i++)
419  {
420  bit_keys[i] = key_codec::encode(keys[i]);
421  digits[i] = key_codec::extract_digit(bit_keys[i], bit, current_radix_bits);
422  }
423 
424  bool head_flags[ItemsPerThread];
425  bool tail_flags[ItemsPerThread];
426  ::rocprim::not_equal_to<unsigned int> flag_op;
427 
// NOTE(review): listing line 428 is missing. The discontinuity storage aliases the sort
// storage in the union, so a block-wide barrier is presumably required here. Confirm.
429  discontinuity_type().flag_heads_and_tails(head_flags, tail_flags, digits, flag_op, storage.discontinuity);
430 
431  // Fill start and end position of subsequence for every digit
432  for(unsigned int i = 0; i < ItemsPerThread; i++)
433  {
434  const unsigned int digit = digits[i];
435  const unsigned int pos = flat_id * ItemsPerThread + i;
436  if(head_flags[i])
437  {
438  storage.starts[digit] = pos;
439  }
440  if(tail_flags[i])
441  {
442  storage.ends[digit] = pos;
443  }
444  }
445 
// NOTE(review): listing line 446 is missing. starts/ends written above are read below by
// other threads, and the exchange storage aliases the discontinuity storage, so a
// block-wide barrier is presumably required here. Confirm.
447  // Rearrange to striped arrangement to have faster coalesced writes instead of
448  // scattering of blocked-arranged items
449  bit_keys_exchange_type().blocked_to_striped(bit_keys, bit_keys, storage.bit_keys_exchange);
450  if(with_values)
451  {
// NOTE(review): listing line 452 is missing; values_exchange aliases bit_keys_exchange,
// so a barrier is presumably required before reusing the storage. Confirm.
453  values_exchange_type().blocked_to_striped(values, values, storage.values_exchange);
454  }
455 
// Scatter: in the striped arrangement item i of this thread sits at tile position
// i * BlockSize + flat_id. Its output slot is its offset from the start of its digit's
// run plus that digit's running global start.
456  for(unsigned int i = 0; i < ItemsPerThread; i++)
457  {
458  const unsigned int digit = key_codec::extract_digit(bit_keys[i], bit, current_radix_bits);
459  const unsigned int pos = i * BlockSize + flat_id;
460  if(IsFull || (pos < valid_count))
461  {
462  const Offset dst = pos - storage.starts[digit] + storage.digit_starts[digit];
463  keys_output[dst] = key_codec::decode(bit_keys[i]);
464  if(with_values)
465  {
466  values_output[dst] = values[i];
467  }
468  }
469  }
470 
// NOTE(review): listing line 471 is missing. digit_starts is read by all threads in the
// scatter above and updated below, so a block-wide barrier is presumably required. Confirm.
472 
473  // Accumulate counts of the current block
// The run end is clamped to valid_count - 1 because the out-of-range padding keys,
// loaded above and left at the end of the sorted tile, may extend the last digit's run.
474  if(flat_id < radix_size)
475  {
476  const unsigned int digit = flat_id;
477  const unsigned int start = storage.starts[digit];
478  const unsigned int end = storage.ends[digit];
479  if(start < valid_count)
480  {
481  storage.digit_starts[digit] += (::rocprim::min(valid_count - 1, end) - start + 1);
482  }
483  }
484  }
485  }
486 };
487 
488 template<
489  unsigned int BlockSize,
490  unsigned int ItemsPerThread,
491  bool Descending,
492  class KeysInputIterator,
493  class KeysOutputIterator,
494  class ValuesInputIterator,
495  class ValuesOutputIterator
496 >
497 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
498 void sort_single(KeysInputIterator keys_input,
499  KeysOutputIterator keys_output,
500  ValuesInputIterator values_input,
501  ValuesOutputIterator values_output,
502  unsigned int size,
503  unsigned int bit,
504  unsigned int current_radix_bits)
505 {
506  using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
507  using value_type = typename std::iterator_traits<ValuesInputIterator>::value_type;
508 
509  using sort_single_helper = radix_sort_single_helper<
510  BlockSize, ItemsPerThread, Descending,
511  key_type, value_type
512  >;
513 
514  ROCPRIM_SHARED_MEMORY typename sort_single_helper::storage_type storage;
515 
516  sort_single_helper().template sort_single(
517  keys_input, keys_output, values_input, values_output,
518  size, bit, current_radix_bits,
519  storage
520  );
521 }
522 
// NaN-aware "greater than" for floating-point keys: returns true when `a` orders strictly
// after `b` in the order built below. Both values are reinterpreted as
// float_bit_mask<T>::bit_type bit patterns. -0.0 is folded onto +0.0 so the two zeros
// compare equal. The bits are then transformed (negatives fully inverted, sign bit set on
// non-negatives) and the transformed patterns are compared with `>`.
// Assuming bit_type is an unsigned integer of the same width as T (the code uses
// bit_key_type(-1) as an all-ones mask), this matches numeric order for ordinary values.
// NaNs with the sign bit clear order above +inf, and NaNs with the sign bit set order
// below -inf.
523 template<class T>
524 ROCPRIM_DEVICE ROCPRIM_INLINE
525 auto compare_nan_sensitive(const T& a, const T& b)
526  -> typename std::enable_if<rocprim::is_floating_point<T>::value, bool>::type
527 {
528  // Beware: the performance of this function is extremely vulnerable to refactoring.
529  // Always check benchmark_device_segmented_radix_sort and benchmark_device_radix_sort
530  // when making changes to this function.
531 
532  using bit_key_type = typename float_bit_mask<T>::bit_type;
533  static constexpr auto sign_bit = float_bit_mask<T>::sign_bit;
534 
535  auto a_bits = __builtin_bit_cast(bit_key_type, a);
536  auto b_bits = __builtin_bit_cast(bit_key_type, b);
537 
538  // convert -0.0 to +0.0
539  a_bits = a_bits == sign_bit ? 0 : a_bits;
540  b_bits = b_bits == sign_bit ? 0 : b_bits;
541  // invert negatives, put 1 into sign bit for positives
542  a_bits ^= (sign_bit & a_bits) == 0 ? sign_bit : bit_key_type(-1);
543  b_bits ^= (sign_bit & b_bits) == 0 ? sign_bit : bit_key_type(-1);
544 
545  // sort numbers and NaNs according to their bit representation
546  return a_bits > b_bits;
547 }
548 
549 template<class T>
550 ROCPRIM_DEVICE ROCPRIM_INLINE
551 auto compare_nan_sensitive(const T& a, const T& b)
552  -> typename std::enable_if<!rocprim::is_floating_point<T>::value, bool>::type
553 {
554  return a > b;
555 }
556 
// Comparator that orders keys the way the radix sort does. Descending selects the
// direction. UseRadixMask selects whether only the key bits
// [start_bit, start_bit + current_radix_bits) take part (supported for integral keys).
// The primary template is only declared: the specializations cover every supported
// combination, so an unsupported one fails to compile.
template<
    bool Descending,
    bool UseRadixMask,
    class T,
    class Enable = void
>
struct radix_merge_compare;
564 
565 template<class T>
566 struct radix_merge_compare<false, false, T>
567 {
568  ROCPRIM_DEVICE ROCPRIM_INLINE
569  bool operator()(const T& a, const T& b) const
570  {
571  return compare_nan_sensitive<T>(b, a);
572  }
573 };
574 
575 template<class T>
576 struct radix_merge_compare<true, false, T>
577 {
578  ROCPRIM_DEVICE ROCPRIM_INLINE
579  bool operator()(const T& a, const T& b) const
580  {
581  return compare_nan_sensitive<T>(a, b);
582  }
583 };
584 
585 template<class T>
586 struct radix_merge_compare<false, true, T, typename std::enable_if<rocprim::is_integral<T>::value>::type>
587 {
588  T radix_mask;
589 
590  ROCPRIM_HOST_DEVICE ROCPRIM_INLINE
591  radix_merge_compare(const unsigned int start_bit, const unsigned int current_radix_bits)
592  {
593  T radix_mask_upper = (T(1) << (current_radix_bits + start_bit)) - 1;
594  T radix_mask_bottom = (T(1) << start_bit) - 1;
595  radix_mask = radix_mask_upper ^ radix_mask_bottom;
596  }
597 
598  ROCPRIM_DEVICE ROCPRIM_INLINE
599  bool operator()(const T& a, const T& b) const
600  {
601  const T masked_key_a = a & radix_mask;
602  const T masked_key_b = b & radix_mask;
603  return masked_key_b > masked_key_a;
604  }
605 };
606 
607 template<class T>
608 struct radix_merge_compare<true, true, T, typename std::enable_if<rocprim::is_integral<T>::value>::type>
609 {
610  T radix_mask;
611 
612  ROCPRIM_HOST_DEVICE ROCPRIM_INLINE
613  radix_merge_compare(const unsigned int start_bit, const unsigned int current_radix_bits)
614  {
615  T radix_mask_upper = (T(1) << (current_radix_bits + start_bit)) - 1;
616  T radix_mask_bottom = (T(1) << start_bit) - 1;
617  radix_mask = (radix_mask_upper ^ radix_mask_bottom);
618  }
619 
620  ROCPRIM_DEVICE ROCPRIM_INLINE
621  bool operator()(const T& a, const T& b) const
622  {
623  const T masked_key_a = a & radix_mask;
624  const T masked_key_b = b & radix_mask;
625  return masked_key_a > masked_key_b;
626  }
627 };
628 
629 template<bool Descending, class T>
630 struct radix_merge_compare<Descending,
631  true,
632  T,
633  typename std::enable_if<!rocprim::is_integral<T>::value>::type>
634 {
635  // radix_merge_compare supports masks only for integrals.
636  // even though masks are never used for floating point-types,
637  // it needs to be able to compile.
638  ROCPRIM_HOST_DEVICE ROCPRIM_INLINE
639  radix_merge_compare(const unsigned int, const unsigned int){}
640 
641  ROCPRIM_DEVICE ROCPRIM_INLINE
642  bool operator()(const T&, const T&) const { return false; }
643 };
644 
// Per-block digit histograms for all radix passes at once. count_digits() loads one
// tile of keys and, for every digit place (bit = begin_bit, begin_bit + RadixBits, ...
// below end_bit), counts digits into a shared histogram. Each (place, digit) counter has
// `atomic_stripes` copies, selected by flat_id % atomic_stripes, to spread the shared
// atomics. The stripes are then summed and atomically added to
// global_digit_counts[place * radix_size + digit].
645 template<class KeyType,
646  unsigned int BlockSize,
647  unsigned int ItemsPerThread,
648  unsigned int RadixBits,
649  bool Descending>
// NOTE(review): listing line 650 is missing; it must hold the struct declarator
// (`struct <name>`) for the body below to compile. TODO restore it from upstream.
651 {
652  static constexpr unsigned int radix_size = 1u << RadixBits;
653  // Upper bound, this value does not take into account the actual size of the number of bits
654  // that are to be considered in the radix sort.
655  static constexpr unsigned int max_digit_places
656  = ::rocprim::detail::ceiling_div(sizeof(KeyType) * 8, RadixBits);
657  static constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
// NOTE(review): digits_per_thread is not referenced elsewhere in this struct.
658  static constexpr unsigned int digits_per_thread
659  = ::rocprim::detail::ceiling_div(radix_size, BlockSize);
660  static constexpr unsigned int atomic_stripes = 4;
661  static constexpr unsigned int histogram_counters
662  = radix_size * max_digit_places * atomic_stripes;
663 
664  using counter_type = uint32_t;
// NOTE(review): listing line 665 is missing. `key_codec` is used below but not declared
// in the visible lines; presumably a radix_key_codec<KeyType, Descending> alias. Confirm.
666  using bit_key_type = typename key_codec::bit_key_type;
667 
// NOTE(review): listing line 668 is missing; it should declare the nested `storage_type`
// whose body follows. TODO restore.
669  {
670  counter_type histogram[histogram_counters];
671  };
672 
// Counters are laid out as [place][digit][stripe].
673  ROCPRIM_DEVICE ROCPRIM_INLINE counter_type& get_counter(const unsigned stripe_index,
674  const unsigned int place,
675  const unsigned int digit,
676  storage_type& storage)
677  {
678  return storage.histogram[(place * radix_size + digit) * atomic_stripes + stripe_index];
679  }
680 
// Zeroes all counters, with threads striding over the array by BlockSize.
681  ROCPRIM_DEVICE ROCPRIM_INLINE void clear_histogram(const unsigned int flat_id,
682  storage_type& storage)
683  {
684  for(unsigned int i = flat_id; i < histogram_counters; i += BlockSize)
685  {
686  storage.histogram[i] = 0;
687  }
688  }
689 
// Atomically increments this thread's stripe of the counter for each in-range key's digit
// at `place`. Keys are indexed in striped order (i * BlockSize + flat_id).
690  template<bool IsFull>
691  ROCPRIM_DEVICE void count_digits_at_place(const unsigned int flat_id,
692  const unsigned int stripe,
693  const bit_key_type (&bit_keys)[ItemsPerThread],
694  const unsigned int place,
695  const unsigned int start_bit,
696  const unsigned int current_radix_bits,
697  const unsigned int valid_count,
698  storage_type& storage)
699  {
700  ROCPRIM_UNROLL
701  for(unsigned int i = 0; i < ItemsPerThread; ++i)
702  {
703  const unsigned int pos = i * BlockSize + flat_id;
704  if(IsFull || pos < valid_count)
705  {
706  const unsigned int digit
707  = key_codec::extract_digit(bit_keys[i], start_bit, current_radix_bits);
708  ::rocprim::detail::atomic_add(&get_counter(stripe, place, digit, storage), 1);
709  }
710  }
711  }
712 
// keys_input points at this block's tile; valid_count is the number of valid keys in it.
713  template<bool IsFull, class KeysInputIterator, class Offset>
714  ROCPRIM_DEVICE void count_digits(KeysInputIterator keys_input,
715  Offset* global_digit_counts,
716  const unsigned int valid_count,
717  const unsigned int begin_bit,
718  const unsigned int end_bit,
719  storage_type& storage)
720  {
721  const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
722  const unsigned int stripe = flat_id % atomic_stripes;
723 
724  KeyType keys[ItemsPerThread];
725  // Load using a striped arrangement, the order doesn't matter here.
726  if ROCPRIM_IF_CONSTEXPR(IsFull)
727  {
728  block_load_direct_striped<BlockSize>(flat_id, keys_input, keys);
729  }
730  else
731  {
732  block_load_direct_striped<BlockSize>(flat_id, keys_input, keys, valid_count);
733  }
734 
735  // Initialize shared counters to zero.
736  clear_histogram(flat_id, storage);
737 
// NOTE(review): listing line 738 is missing. The counters are zeroed by strided threads
// above and incremented by arbitrary threads below, so a block-wide barrier (presumably
// ::rocprim::syncthreads()) is required here. Confirm and restore.
739 
740  // Compute a shared histogram for each digit and each place.
741  bit_key_type bit_keys[ItemsPerThread];
742  ROCPRIM_UNROLL
743  for(unsigned int i = 0; i < ItemsPerThread; ++i)
744  {
745  bit_keys[i] = key_codec::encode(keys[i]);
746  }
747 
// The last place may be narrower than RadixBits when (end_bit - begin_bit) is not a
// multiple of it.
748  for(unsigned int bit = begin_bit, place = 0; bit < end_bit; bit += RadixBits, ++place)
749  {
750  count_digits_at_place<IsFull>(flat_id,
751  stripe,
752  bit_keys,
753  place,
754  bit,
755  min(RadixBits, end_bit - bit),
756  valid_count,
757  storage);
758  }
759 
// NOTE(review): listing line 760 is missing. The shared counters are read below after
// being incremented by other threads above, so a block-wide barrier (presumably
// ::rocprim::syncthreads()) is required here. Confirm and restore.
761 
762  // Combine the local histograms into a global histogram.
763 
764  unsigned int place = 0;
765  for(unsigned int bit = begin_bit; bit < end_bit; bit += RadixBits)
766  {
767  for(unsigned int digit = flat_id; digit < radix_size; digit += BlockSize)
768  {
769  counter_type total = 0;
770 
771  ROCPRIM_UNROLL
772  for(unsigned int stripe = 0; stripe < atomic_stripes; ++stripe)
773  {
774  total += get_counter(stripe, place, digit, storage);
775  }
776 
777  ::rocprim::detail::atomic_add(&global_digit_counts[place * radix_size + digit],
778  total);
779  }
780  ++place;
781  }
782  }
783 };
784 
// Device-side histogram pass: each block counts the digits of its tile of
// BlockSize * ItemsPerThread keys at every digit place in [begin_bit, end_bit) and adds
// them to global_digit_counts. Blocks with block_id < full_blocks process a full tile;
// any other block processes the remaining size - items_per_block * full_blocks keys
// (presumably only one such block is launched — confirm against the caller).
785 template<unsigned int BlockSize,
786  unsigned int ItemsPerThread,
787  unsigned int RadixBits,
788  bool Descending,
789  class KeysInputIterator,
790  class Offset>
791 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void onesweep_histograms(KeysInputIterator keys_input,
792  Offset* global_digit_counts,
793  const Offset size,
794  const Offset full_blocks,
795  const unsigned int begin_bit,
796  const unsigned int end_bit)
797 {
798  using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
799  using count_helper_type
// NOTE(review): listing line 800 is missing; it must complete this alias with the
// histogram helper type (instantiated with key_type, BlockSize, ItemsPerThread, RadixBits
// and Descending, presumably). TODO restore it from upstream.
801 
802  constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
803 
804  const Offset block_id = ::rocprim::detail::block_id<0>();
805  const Offset block_offset = block_id * ItemsPerThread * BlockSize;
806 
807  ROCPRIM_SHARED_MEMORY typename count_helper_type::storage_type storage;
808 
809  if(block_id < full_blocks)
810  {
811  count_helper_type{}.template count_digits<true>(keys_input + block_offset,
812  global_digit_counts,
813  items_per_block,
814  begin_bit,
815  end_bit,
816  storage);
817  }
818  else
819  {
820  const unsigned int valid_in_last_block = size - items_per_block * full_blocks;
821  count_helper_type{}.template count_digits<false>(keys_input + block_offset,
822  global_digit_counts,
823  valid_in_last_block,
824  begin_bit,
825  end_bit,
826  storage);
827  }
828 }
829 
830 template<unsigned int BlockSize, unsigned int RadixBits, class Offset>
831 ROCPRIM_DEVICE void onesweep_scan_histograms(Offset* global_digit_offsets)
832 {
833  using block_scan_type = block_scan<Offset, BlockSize>;
834 
835  constexpr unsigned int radix_size = 1u << RadixBits;
836  constexpr unsigned int items_per_thread = ::rocprim::detail::ceiling_div(radix_size, BlockSize);
837 
838  const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
839  const unsigned int digit_place = ::rocprim::detail::block_id<0>();
840  const unsigned int block_offset = digit_place * radix_size;
841 
842  Offset offsets[items_per_thread];
843  block_load_direct_blocked(flat_id, global_digit_offsets + block_offset, offsets, radix_size);
844  block_scan_type{}.exclusive_scan(offsets, offsets, 0);
845  block_store_direct_blocked(flat_id, global_digit_offsets + block_offset, offsets, radix_size);
846 }
847 
849 {
850  // The two most significant bits are used to indicate the status of the prefix - leaving the other 30 bits for the
851  // counter value.
852  using underlying_type = uint32_t;
853 
854  static constexpr unsigned int state_bits = 8u * sizeof(underlying_type);
855 
856  enum prefix_flag : underlying_type
857  {
858  EMPTY = 0,
859  PARTIAL = 1u << (state_bits - 2),
860  COMPLETE = 2u << (state_bits - 2)
861  };
862 
863  static constexpr underlying_type status_mask = 3u << (state_bits - 2);
864  static constexpr underlying_type value_mask = ~status_mask;
865 
866  underlying_type state;
867 
868  ROCPRIM_DEVICE ROCPRIM_INLINE explicit onesweep_lookback_state(underlying_type state)
869  : state(state)
870  {}
871 
872  ROCPRIM_DEVICE ROCPRIM_INLINE onesweep_lookback_state(prefix_flag status, underlying_type value)
873  : state(static_cast<underlying_type>(status) | value)
874  {}
875 
876  ROCPRIM_DEVICE ROCPRIM_INLINE underlying_type value() const
877  {
878  return this->state & value_mask;
879  }
880 
881  ROCPRIM_DEVICE ROCPRIM_INLINE prefix_flag status() const
882  {
883  return static_cast<prefix_flag>(this->state & status_mask);
884  }
885 
886  ROCPRIM_DEVICE ROCPRIM_INLINE static onesweep_lookback_state load(onesweep_lookback_state* ptr)
887  {
888  underlying_type state = ::rocprim::detail::atomic_add(&ptr->state, 0);
889  return onesweep_lookback_state(state);
890  }
891 
892  ROCPRIM_DEVICE ROCPRIM_INLINE void store(onesweep_lookback_state* ptr) const
893  {
894  ::rocprim::detail::atomic_exch(&ptr->state, this->state);
895  }
896 };
897 
898 template<class Key,
899  class Value,
900  class Offset,
901  unsigned int BlockSize,
902  unsigned int ItemsPerThread,
903  unsigned int RadixBits,
904  bool Descending,
905  block_radix_rank_algorithm RadixRankAlgorithm>
907 {
908  static constexpr unsigned int radix_size = 1u << RadixBits;
909  static constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
910  static constexpr bool with_values = !std::is_same<Value, rocprim::empty_type>::value;
911 
913  using bit_key_type = typename key_codec::bit_key_type;
914  using radix_rank_type = ::rocprim::block_radix_rank<BlockSize, RadixBits, RadixRankAlgorithm>;
915 
916  static constexpr bool load_warp_striped
917  = RadixRankAlgorithm == block_radix_rank_algorithm::match;
918 
919  static constexpr unsigned int digits_per_thread = radix_rank_type::digits_per_thread;
920 
922  {
923  typename radix_rank_type::storage_type rank;
924  struct
925  {
926  Offset global_digit_offsets[radix_size];
927  union
928  {
929  bit_key_type ordered_block_keys[items_per_block];
930  Value ordered_block_values[items_per_block];
931  };
932  };
933  };
934 
936 
937  template<bool IsFull,
938  class KeysInputIterator,
939  class KeysOutputIterator,
940  class ValuesInputIterator,
941  class ValuesOutputIterator>
942  ROCPRIM_DEVICE void onesweep(KeysInputIterator keys_input,
943  KeysOutputIterator keys_output,
944  ValuesInputIterator values_input,
945  ValuesOutputIterator values_output,
946  Offset* global_digit_offsets_in,
947  Offset* global_digit_offsets_out,
948  onesweep_lookback_state* lookback_states,
949  const unsigned int bit,
950  const unsigned int current_radix_bits,
951  const unsigned int valid_items,
952  storage_type_& storage)
953  {
954  const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
955  const unsigned int block_id = ::rocprim::detail::block_id<0>();
956  const unsigned int block_offset = block_id * items_per_block;
957 
958  // Load keys into private memory, and encode them to unsigned integers.
959  Key keys[ItemsPerThread];
960  if ROCPRIM_IF_CONSTEXPR(IsFull)
961  {
962  if ROCPRIM_IF_CONSTEXPR(load_warp_striped)
963  {
964  block_load_direct_warp_striped(flat_id, keys_input + block_offset, keys);
965  }
966  else
967  {
968  block_load_direct_blocked(flat_id, keys_input + block_offset, keys);
969  }
970  }
971  else
972  {
973  // Fill the out-of-bounds elements of the key array with the key value with
974  // the largest digit. This will make sure they are sorted (ranked) last, and
975  // thus will be omitted when we compare the item offset against `valid_items` later.
976  // Note that this will lead to an incorrect digit count. Since this is the very last digit,
977  // it does not matter. It does cause the final digit offset to be increased past its end,
978  // but again this does not matter since this is the last iteration in which it will be used anyway.
979  const Key out_of_bounds = key_codec::decode(bit_key_type(-1));
980  if ROCPRIM_IF_CONSTEXPR(load_warp_striped)
981  {
983  keys_input + block_offset,
984  keys,
985  valid_items,
986  out_of_bounds);
987  }
988  else
989  {
991  keys_input + block_offset,
992  keys,
993  valid_items,
994  out_of_bounds);
995  }
996  }
997 
998  bit_key_type bit_keys[ItemsPerThread];
999  ROCPRIM_UNROLL
1000  for(unsigned int i = 0; i < ItemsPerThread; ++i)
1001  {
1002  bit_keys[i] = key_codec::encode(keys[i]);
1003  }
1004 
1005  // Compute the block-based key ranks, the digit counts, and the prefix sum of the digit counts.
1006  unsigned int ranks[ItemsPerThread];
1007  // Tile-wide digit offset
1008  unsigned int exclusive_digit_prefix[digits_per_thread];
1009  // Tile-wide digit count
1010  unsigned int digit_counts[digits_per_thread];
1011  radix_rank_type{}.rank_keys(
1012  bit_keys,
1013  ranks,
1014  storage.rank,
1015  [bit, current_radix_bits](const bit_key_type& key)
1016  { return key_codec::extract_digit(key, bit, current_radix_bits); },
1017  exclusive_digit_prefix,
1018  digit_counts);
1019 
1021 
1022  // Order keys in shared memory.
1023  ROCPRIM_UNROLL
1024  for(unsigned int i = 0; i < ItemsPerThread; ++i)
1025  {
1026  storage.ordered_block_keys[ranks[i]] = bit_keys[i];
1027  }
1028 
1030 
1031  // Compute the global prefix for each histogram.
1032  // At this point `lookback_states` already hold `onesweep_lookback_state::EMPTY`.
1033  ROCPRIM_UNROLL
1034  for(unsigned int i = 0; i < digits_per_thread; ++i)
1035  {
1036  const unsigned int digit = flat_id * digits_per_thread + i;
1037  if(radix_size % BlockSize == 0 || digit < radix_size)
1038  {
1039  onesweep_lookback_state* block_state
1040  = &lookback_states[block_id * radix_size + digit];
1041  onesweep_lookback_state(onesweep_lookback_state::PARTIAL, digit_counts[i])
1042  .store(block_state);
1043 
1044  unsigned int exclusive_prefix = 0;
1045  unsigned int lookback_block_id = block_id;
1046  // The main back tracking loop.
1047  while(lookback_block_id > 0)
1048  {
1049  --lookback_block_id;
1050  onesweep_lookback_state* lookback_state_ptr
1051  = &lookback_states[lookback_block_id * radix_size + digit];
1052  onesweep_lookback_state lookback_state
1053  = onesweep_lookback_state::load(lookback_state_ptr);
1054  while(lookback_state.status() == onesweep_lookback_state::EMPTY)
1055  {
1056  lookback_state = onesweep_lookback_state::load(lookback_state_ptr);
1057  }
1058 
1059  exclusive_prefix += lookback_state.value();
1060  if(lookback_state.status() == onesweep_lookback_state::COMPLETE)
1061  {
1062  break;
1063  }
1064  }
1065 
1066  // Update the state for the current block.
1067  const unsigned int inclusive_digit_prefix = exclusive_prefix + digit_counts[i];
1068  // Note that this should not deadlock, as HSA guarantees that blocks with a lower block ID launch before
1069  // those with a higher block id.
1070  onesweep_lookback_state(onesweep_lookback_state::COMPLETE, inclusive_digit_prefix)
1071  .store(block_state);
1072 
1073  // Subtract the exclusive digit prefix from the global offset here, since we already ordered the keys in shared
1074  // memory.
1075  storage.global_digit_offsets[digit]
1076  = global_digit_offsets_in[digit] - exclusive_digit_prefix[i] + exclusive_prefix;
1077  }
1078  }
1079 
1081 
1082  // Scatter the keys to global memory in a sorted fashion.
1083  ROCPRIM_UNROLL
1084  for(unsigned int i = 0; i < ItemsPerThread; ++i)
1085  {
1086  const unsigned int rank = i * BlockSize + flat_id;
1087  if(IsFull || rank < valid_items)
1088  {
1089  const bit_key_type bit_key = storage.ordered_block_keys[rank];
1090  const unsigned int digit
1091  = key_codec::extract_digit(bit_key, bit, current_radix_bits);
1092  const Offset global_offset = storage.global_digit_offsets[digit];
1093  keys_output[rank + global_offset] = key_codec::decode(bit_key);
1094  }
1095  }
1096 
1097  // Gather and scatter values if necessary.
1098  if(with_values)
1099  {
1100  Value values[ItemsPerThread];
1101  if ROCPRIM_IF_CONSTEXPR(IsFull)
1102  {
1103  if ROCPRIM_IF_CONSTEXPR(load_warp_striped)
1104  {
1105  block_load_direct_warp_striped(flat_id, values_input + block_offset, values);
1106  }
1107  else
1108  {
1109  block_load_direct_blocked(flat_id, values_input + block_offset, values);
1110  }
1111  }
1112  else
1113  {
1114  if ROCPRIM_IF_CONSTEXPR(load_warp_striped)
1115  {
1117  values_input + block_offset,
1118  values,
1119  valid_items);
1120  }
1121  else
1122  {
1123  block_load_direct_blocked(flat_id,
1124  values_input + block_offset,
1125  values,
1126  valid_items);
1127  }
1128  }
1129 
1130  // Compute digits up-front so that we can re-use shared memory between ordered_block_keys and
1131  // ordered_block_values.
1132  unsigned int digits[ItemsPerThread];
1133  ROCPRIM_UNROLL
1134  for(unsigned int i = 0; i < ItemsPerThread; ++i)
1135  {
1136  const unsigned int rank = i * BlockSize + flat_id;
1137  if(IsFull || rank < valid_items)
1138  {
1139  const bit_key_type bit_key = storage.ordered_block_keys[rank];
1140  digits[i] = key_codec::extract_digit(bit_key, bit, current_radix_bits);
1141  }
1142  }
1143 
1145 
1146  // Order values in shared memory
1147  ROCPRIM_UNROLL
1148  for(unsigned int i = 0; i < ItemsPerThread; ++i)
1149  {
1150  storage.ordered_block_values[ranks[i]] = values[i];
1151  }
1152 
1154 
1155  // And scatter the values to global memory.
1156  ROCPRIM_UNROLL
1157  for(unsigned int i = 0; i < ItemsPerThread; ++i)
1158  {
1159  const unsigned int rank = i * BlockSize + flat_id;
1160  if(IsFull || rank < valid_items)
1161  {
1162  const Value value = storage.ordered_block_values[rank];
1163  const Offset global_offset = storage.global_digit_offsets[digits[i]];
1164  values_output[rank + global_offset] = value;
1165  }
1166  }
1167  }
1168 
1169  // Update the global digit offset if we are batching
1170  const bool is_last_block = block_id == rocprim::detail::grid_size<0>() - 1;
1171  if(is_last_block)
1172  {
1173  ROCPRIM_UNROLL
1174  for(unsigned int i = 0; i < digits_per_thread; ++i)
1175  {
1176  const unsigned int digit = flat_id * digits_per_thread + i;
1177  if(radix_size % BlockSize == 0 || digit < radix_size)
1178  {
1179  global_digit_offsets_out[digit] = storage.global_digit_offsets[digit]
1180  + exclusive_digit_prefix[i] + digit_counts[i];
1181  }
1182  }
1183  }
1184  }
1185 };
1186 
1187 template<unsigned int BlockSize,
1188  unsigned int ItemsPerThread,
1189  unsigned int RadixBits,
1190  bool Descending,
1191  block_radix_rank_algorithm RadixRankAlgorithm,
1192  class KeysInputIterator,
1193  class KeysOutputIterator,
1194  class ValuesInputIterator,
1195  class ValuesOutputIterator,
1196  class Offset>
1197 ROCPRIM_DEVICE void onesweep_iteration(KeysInputIterator keys_input,
1198  KeysOutputIterator keys_output,
1199  ValuesInputIterator values_input,
1200  ValuesOutputIterator values_output,
1201  const unsigned int size,
1202  Offset* global_digit_offsets_in,
1203  Offset* global_digit_offsets_out,
1204  onesweep_lookback_state* lookback_states,
1205  const unsigned int bit,
1206  const unsigned int current_radix_bits,
1207  const unsigned int full_blocks)
1208 {
1209  using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
1210  using value_type = typename std::iterator_traits<ValuesInputIterator>::value_type;
1211 
1212  using onesweep_iteration_helper_type = onesweep_iteration_helper<key_type,
1213  value_type,
1214  Offset,
1215  BlockSize,
1216  ItemsPerThread,
1217  RadixBits,
1218  Descending,
1219  RadixRankAlgorithm>;
1220 
1221  constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
1222  const unsigned int block_id = ::rocprim::detail::block_id<0>();
1223 
1224  ROCPRIM_SHARED_MEMORY typename onesweep_iteration_helper_type::storage_type storage;
1225 
1226  if(block_id < full_blocks)
1227  {
1228  onesweep_iteration_helper_type{}.template onesweep<true>(keys_input,
1229  keys_output,
1230  values_input,
1231  values_output,
1232  global_digit_offsets_in,
1233  global_digit_offsets_out,
1234  lookback_states,
1235  bit,
1236  current_radix_bits,
1237  items_per_block,
1238  storage.get());
1239  }
1240  else
1241  {
1242  const unsigned int valid_in_last_block = size - items_per_block * full_blocks;
1243  onesweep_iteration_helper_type{}.template onesweep<false>(keys_input,
1244  keys_output,
1245  values_input,
1246  values_output,
1247  global_digit_offsets_in,
1248  global_digit_offsets_out,
1249  lookback_states,
1250  bit,
1251  current_radix_bits,
1252  valid_in_last_block,
1253  storage.get());
1254  }
1255 }
1256 
1257 } // end namespace detail
1258 
1259 END_ROCPRIM_NAMESPACE
1260 
1261 #endif // ROCPRIM_DEVICE_DETAIL_DEVICE_RADIX_SORT_HPP_
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int block_id()
Returns block identifier in a multidimensional grid by dimension.
Definition: thread.hpp:258
block_radix_rank_algorithm
Available algorithms for the block_radix_rank primitive.
Definition: block_radix_rank.hpp:40
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int masked_bit_count(lane_mask_type x, unsigned int add=0)
Masked bit count.
Definition: warp.hpp:48
Definition: device_radix_sort.hpp:906
The block_scan class is a block level parallel primitive which provides methods for performing inclus...
Definition: block_scan.hpp:134
Definition: device_radix_sort.hpp:312
Definition: device_radix_sort.hpp:216
Definition: device_radix_sort.hpp:848
ROCPRIM_DEVICE ROCPRIM_INLINE constexpr unsigned int device_warp_size()
Returns a number of threads in a hardware warp for the actual target.
Definition: thread.hpp:70
Definition: radix_sort.hpp:101
ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
Returns the minimum of its arguments.
Definition: functional.hpp:63
Definition: device_radix_sort.hpp:563
Definition: device_radix_sort.hpp:335
Definition: device_radix_sort.hpp:921
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
Definition: device_radix_sort.hpp:650
const unsigned int warp_id
Returns warp id in a block (tile).
Definition: benchmark_warp_exchange.cpp:153
ROCPRIM_DEVICE ROCPRIM_INLINE lane_mask_type ballot(int predicate)
Evaluate predicate for all active work-items in the warp and return an integer whose i-th bit is set ...
Definition: warp.hpp:38
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216
ROCPRIM_DEVICE ROCPRIM_INLINE void block_load_direct_warp_striped(unsigned int flat_id, InputIterator block_input, T(&items)[ItemsPerThread])
Loads data from contiguous memory into a warp-striped arrangement of items across the thread block...
Definition: block_load_func.hpp:378
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int flat_block_id()
Returns flat (linear, 1D) block identifier in a multidimensional grid.
Definition: thread.hpp:178
Definition: radix_sort.hpp:241
Definition: device_radix_sort.hpp:106
ROCPRIM_DEVICE ROCPRIM_INLINE void block_store_direct_blocked(unsigned int flat_id, OutputIterator block_output, T(&items)[ItemsPerThread])
Stores a blocked arrangement of items from across the thread block into a blocked arrangement on cont...
Definition: block_store_func.hpp:58
Definition: device_radix_sort.hpp:97
Warp-based radix ranking algorithm. Keys and ranks are assumed in warp-striped order for this algorit...
ROCPRIM_DEVICE ROCPRIM_INLINE void block_load_direct_blocked(unsigned int flat_id, InputIterator block_input, T(&items)[ItemsPerThread])
Loads data from contiguous memory into a blocked arrangement of items across the thread block...
Definition: block_load_func.hpp:58
Definition: benchmark_block_histogram.cpp:64
Definition: device_radix_sort.hpp:203
unsigned long long int lane_mask_type
The lane_mask_type is an integer that contains one bit per thread.
Definition: types.hpp:164
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int bit_count(unsigned int x)
Bit count.
Definition: bit.hpp:42
Definition: device_radix_sort.hpp:668