21 #ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_RADIX_SORT_HPP_ 22 #define ROCPRIM_DEVICE_DETAIL_DEVICE_RADIX_SORT_HPP_ 24 #include <type_traits> 27 #include "../../config.hpp" 28 #include "../../detail/various.hpp" 29 #include "../../detail/radix_sort.hpp" 31 #include "../../intrinsics.hpp" 32 #include "../../functional.hpp" 33 #include "../../types.hpp" 35 #include "../../block/block_discontinuity.hpp" 36 #include "../../block/block_exchange.hpp" 37 #include "../../block/block_load.hpp" 38 #include "../../block/block_load_func.hpp" 39 #include "../../block/block_radix_rank.hpp" 40 #include "../../block/block_radix_sort.hpp" 41 #include "../../block/block_scan.hpp" 42 #include "../../block/block_store_func.hpp" 44 BEGIN_ROCPRIM_NAMESPACE
// Block-level sort of each thread's ItemsPerThread keys together with their
// associated values, restricted to the key bits [begin_bit, end_bit).
//
// Forwards to `sorter.sort_desc(keys, values, storage, begin_bit, end_bit)` for
// descending order and to `sorter.sort(...)` otherwise.
// NOTE(review): the branch selecting between the two calls and the `end_bit`
// parameter declaration are missing from this listing (source lines 58-65).
// Presumably the choice is made on the `Descending` template parameter; confirm
// against the full source.
//
// Template parameters:
//   Descending     - sort-order selector (defaults to false, i.e. ascending).
//   SortType       - block sorter; must expose `storage_type`, `sort` and `sort_desc`.
//   SortKey        - element type of the per-thread key array.
//   SortValue      - element type of the per-thread value array.
//   ItemsPerThread - number of keys/values held by each thread.
51 template<
bool Descending = false,
class SortType,
class SortKey,
class SortValue,
unsigned int ItemsPerThread>
52 ROCPRIM_DEVICE ROCPRIM_INLINE
53 void sort_block(SortType sorter,
54 SortKey (&keys)[ItemsPerThread],
55 SortValue (&values)[ItemsPerThread],
56 typename SortType::storage_type& storage,
57 unsigned int begin_bit,
62 sorter.sort_desc(keys, values, storage, begin_bit, end_bit);
66 sorter.sort(keys, values, storage, begin_bit, end_bit);
// Keys-only overload of sort_block. It is selected when the value array has
// element type ::rocprim::empty_type. The values array is not passed on: the
// key-only `sorter.sort_desc(keys, storage, ...)` / `sorter.sort(keys, storage, ...)`
// overloads are called instead.
// NOTE(review): the Descending branch and the `end_bit` parameter declaration
// are not visible in this listing (source lines 77-85).
70 template<
bool Descending = false,
class SortType,
class SortKey,
unsigned int ItemsPerThread>
71 ROCPRIM_DEVICE ROCPRIM_INLINE
72 void sort_block(SortType sorter,
73 SortKey (&keys)[ItemsPerThread],
74 ::rocprim::empty_type (&values)[ItemsPerThread],
75 typename SortType::storage_type& storage,
76 unsigned int begin_bit,
82 sorter.sort_desc(keys, storage, begin_bit, end_bit);
86 sorter.sort(keys, storage, begin_bit, end_bit);
// Helper that counts how many keys fall into each radix digit bin, using one row
// of counters per warp.
// NOTE(review): the struct header and any further template parameters (source
// lines 89-98) are not visible in this listing.
//
// Template parameters (visible here):
//   WarpSize       - threads per warp; BlockSize / WarpSize counter rows are kept.
//   BlockSize      - threads per block.
//   ItemsPerThread - keys loaded per thread per tile.
//   RadixBits      - bits per digit; there are 1 << RadixBits bins.
91 unsigned int WarpSize,
92 unsigned int BlockSize,
93 unsigned int ItemsPerThread,
94 unsigned int RadixBits,
99 static constexpr
unsigned int radix_size = 1 << RadixBits;
101 static constexpr
unsigned int warp_size = WarpSize;
102 static constexpr
unsigned int warps_no = BlockSize / warp_size;
// Threads with flat_id < radix_size each own one bin (zeroing and final
// reduction below), so there must be at least one thread per bin.
104 static_assert(radix_size <= BlockSize,
"Radix size must not exceed BlockSize");
// One row of bin counters per warp.
108 unsigned int digit_counts[warps_no][radix_size];
// count_digits: counts the digits (at `bit`, `current_radix_bits` wide) of the
// keys in [begin_offset, end_offset), one tile of items_per_block keys at a time.
// It then adds the per-warp totals for bin `flat_id` into `digit_count` on
// threads with flat_id < radix_size.
// NOTE(review): several template/function parameters (IsFull, Offset, storage,
// begin/end offsets, bit) and the block barriers between the phases are not
// visible in this listing. The initialization of `digit_count` is also not visible.
113 class KeysInputIterator,
116 ROCPRIM_DEVICE ROCPRIM_INLINE
117 void count_digits(KeysInputIterator keys_input,
121 unsigned int current_radix_bits,
123 unsigned int& digit_count)
125 constexpr
unsigned int items_per_block = BlockSize * ItemsPerThread;
127 using key_type =
typename std::iterator_traits<KeysInputIterator>::value_type;
130 using bit_key_type =
typename key_codec::bit_key_type;
132 const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
133 const unsigned int warp_id = ::rocprim::warp_id<0, 1, 1>();
// Zero every warp's counter for the bin owned by this thread.
135 if(flat_id < radix_size)
137 for(
unsigned int w = 0; w < warps_no; w++)
139 storage.digit_counts[w][flat_id] = 0;
// Walk the range tile by tile. A tile is full when IsFull is set or it lies
// entirely before end_offset; otherwise only valid_count keys are loaded.
144 for(Offset block_offset = begin_offset; block_offset < end_offset; block_offset += items_per_block)
146 key_type keys[ItemsPerThread];
147 unsigned int valid_count;
150 if(IsFull || (block_offset + items_per_block <= end_offset))
152 valid_count = items_per_block;
153 block_load_direct_striped<BlockSize>(flat_id, keys_input + block_offset, keys);
157 valid_count = end_offset - block_offset;
158 block_load_direct_striped<BlockSize>(flat_id, keys_input + block_offset, keys, valid_count);
// For each key (striped position pos = i * BlockSize + flat_id), narrow
// same_digit_lanes_mask bit by bit to the warp lanes whose digit matches this
// lane's digit.
// NOTE(review): the mask initialization, bit_set_mask (presumably a warp ballot
// on each digit bit) and the validity check on `pos` are not visible. Presumably
// only the first lane of each group (prev_same_digit_count == 0) adds the group
// size to its warp's bin, so a digit is counted once per warp per key slot.
161 for(
unsigned int i = 0; i < ItemsPerThread; i++)
163 const bit_key_type bit_key = key_codec::encode(keys[i]);
164 const unsigned int digit = key_codec::extract_digit(bit_key, bit, current_radix_bits);
165 const unsigned int pos = i * BlockSize + flat_id;
167 for(
unsigned int b = 0; b < RadixBits; b++)
169 const unsigned int bit_set = digit & (1u << b);
171 same_digit_lanes_mask &= (bit_set ? bit_set_mask : ~bit_set_mask);
175 if(prev_same_digit_count == 0)
179 storage.digit_counts[
warp_id][digit] += same_digit_count;
// Reduce the per-warp counters of this thread's bin into digit_count.
186 if(flat_id < radix_size)
188 for(
unsigned int w = 0; w < warps_no; w++)
190 digit_count += storage.digit_counts[w][flat_id];
// Helper that sorts one tile of up to items_per_block keys per block with
// ::rocprim::block_radix_sort. When Value is not ::rocprim::empty_type, the
// associated values are sorted along with the keys.
// NOTE(review): the struct header and the remaining template parameters (Key,
// Value, Descending, ...) are not visible in this listing.
197 unsigned int BlockSize,
198 unsigned int ItemsPerThread,
205 static constexpr
unsigned int items_per_block = BlockSize * ItemsPerThread;
207 using key_type = Key;
208 using value_type = Value;
211 using bit_key_type =
typename key_codec::bit_key_type;
212 using sort_type = ::rocprim::block_radix_sort<key_type, BlockSize, ItemsPerThread, value_type>;
// True when values accompany the keys; value loads/stores are skipped otherwise.
214 static constexpr
bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;
// Temporary storage for the block sort (enclosing declaration not visible).
218 typename sort_type::storage_type sort;
// sort_single: block `flat_block_id` sorts the tile starting at
// flat_block_id * items_per_block on bits [bit, bit + current_radix_bits).
// The block whose id equals size / items_per_block holds the trailing partial
// tile of `size - block_offset` items. All other blocks load and store a full
// tile; for them valid_in_last_block is computed but not needed.
// NOTE(review): the size/bit/storage parameters and the exact load/store calls
// are only partially visible in this listing.
222 class KeysInputIterator,
223 class KeysOutputIterator,
224 class ValuesInputIterator,
225 class ValuesOutputIterator
227 ROCPRIM_DEVICE ROCPRIM_INLINE
228 void sort_single(KeysInputIterator keys_input,
229 KeysOutputIterator keys_output,
230 ValuesInputIterator values_input,
231 ValuesOutputIterator values_output,
234 unsigned int current_radix_bits,
237 const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
238 const unsigned int flat_block_id = ::rocprim::detail::block_id<0>();
239 const unsigned int block_offset = flat_block_id * items_per_block;
240 const bool is_incomplete_block = flat_block_id == (size / items_per_block);
241 const unsigned int valid_in_last_block = size - block_offset;
243 using key_type =
typename std::iterator_traits<KeysInputIterator>::value_type;
246 using bit_key_type =
typename key_codec::bit_key_type;
248 key_type keys[ItemsPerThread];
249 value_type values[ItemsPerThread];
250 if(!is_incomplete_block)
253 if ROCPRIM_IF_CONSTEXPR(with_values)
// Partial tile: missing keys are filled with the key decoded from an all-ones
// bit pattern. Presumably this sorts after every valid key, so the padding falls
// outside the valid_in_last_block items that are stored. This depends on
// key_codec, whose definition is not visible.
260 const key_type out_of_bounds = key_codec::decode(bit_key_type(-1));
262 keys_input + block_offset,
266 if ROCPRIM_IF_CONSTEXPR(with_values)
269 values_input + block_offset,
271 valid_in_last_block);
// Sort the tile on the current digit bits only.
275 sort_block<Descending>(sort_type(), keys, values, storage.sort, bit, bit + current_radix_bits);
278 if(!is_incomplete_block)
281 if ROCPRIM_IF_CONSTEXPR(with_values)
289 keys_output + block_offset,
291 valid_in_last_block);
292 if ROCPRIM_IF_CONSTEXPR(with_values)
295 values_output + block_offset,
297 valid_in_last_block);
// Helper that processes the key range [begin_offset, end_offset) one tile at a
// time within a single block. It sorts each tile on the current digit, then
// scatters the tile's keys (and values, when Value is not ::rocprim::empty_type)
// to their output positions. Per-digit running output offsets are kept in
// `digit_starts` and advanced after every tile.
// NOTE(review): the struct header, remaining template parameters and several
// function parameters (storage, begin/end offsets, bit, digit_start) are not
// visible in this listing.
304 unsigned int BlockSize,
305 unsigned int ItemsPerThread,
306 unsigned int RadixBits,
314 static constexpr
unsigned int items_per_block = BlockSize * ItemsPerThread;
315 static constexpr
unsigned int radix_size = 1 << RadixBits;
317 using key_type = Key;
318 using value_type = Value;
321 using bit_key_type =
typename key_codec::bit_key_type;
322 using keys_load_type = ::rocprim::block_load<
323 key_type, BlockSize, ItemsPerThread,
324 ::rocprim::block_load_method::block_load_transpose>;
325 using values_load_type = ::rocprim::block_load<
326 value_type, BlockSize, ItemsPerThread,
327 ::rocprim::block_load_method::block_load_transpose>;
328 using sort_type = ::rocprim::block_radix_sort<key_type, BlockSize, ItemsPerThread, value_type>;
329 using discontinuity_type = ::rocprim::block_discontinuity<unsigned int, BlockSize>;
330 using bit_keys_exchange_type = ::rocprim::block_exchange<bit_key_type, BlockSize, ItemsPerThread>;
331 using values_exchange_type = ::rocprim::block_exchange<value_type, BlockSize, ItemsPerThread>;
333 static constexpr
bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;
// Temporary storage for the block primitives used by sort_and_scatter
// (the enclosing declaration is not visible).
339 typename keys_load_type::storage_type keys_load;
340 typename values_load_type::storage_type values_load;
341 typename sort_type::storage_type sort;
342 typename discontinuity_type::storage_type discontinuity;
343 typename bit_keys_exchange_type::storage_type bit_keys_exchange;
344 typename values_exchange_type::storage_type values_exchange;
// First/last tile position of each digit's run in the sorted tile. valid_count
// is used as the "digit absent" sentinel.
// NOTE(review): these are unsigned short, so every tile position and the
// sentinel (up to items_per_block) must fit in 16 bits. Confirm this limit is
// enforced elsewhere.
347 unsigned short starts[radix_size];
348 unsigned short ends[radix_size];
// Running output offset of each digit, advanced after every tile.
350 Offset digit_starts[radix_size];
355 class KeysInputIterator,
356 class KeysOutputIterator,
357 class ValuesInputIterator,
358 class ValuesOutputIterator
360 ROCPRIM_DEVICE ROCPRIM_INLINE
361 void sort_and_scatter(KeysInputIterator keys_input,
362 KeysOutputIterator keys_output,
363 ValuesInputIterator values_input,
364 ValuesOutputIterator values_output,
368 unsigned int current_radix_bits,
372 const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
// Each bin-owning thread seeds its digit's running output offset.
374 if(flat_id < radix_size)
376 storage.digit_starts[flat_id] = digit_start;
// One tile per iteration. The final tile may be partial unless IsFull.
379 for(Offset block_offset = begin_offset; block_offset < end_offset; block_offset += items_per_block)
381 key_type keys[ItemsPerThread];
382 value_type values[ItemsPerThread];
383 unsigned int valid_count;
384 if(IsFull || (block_offset + items_per_block <= end_offset))
386 valid_count = items_per_block;
387 keys_load_type().load(keys_input + block_offset, keys, storage.keys_load);
391 values_load_type().load(values_input + block_offset, values, storage.values_load);
396 valid_count = end_offset - block_offset;
// Pad missing keys with the key decoded from an all-ones bit pattern.
398 const key_type out_of_bounds = key_codec::decode(bit_key_type(-1));
399 keys_load_type().load(keys_input + block_offset, keys, valid_count, out_of_bounds, storage.keys_load);
403 values_load_type().load(values_input + block_offset, values, valid_count, storage.values_load);
// Reset every digit's run bounds to the "absent" sentinel.
407 if(flat_id < radix_size)
409 storage.starts[flat_id] = valid_count;
410 storage.ends[flat_id] = valid_count;
414 sort_block<Descending>(sort_type(), keys, values, storage.sort, bit, bit + current_radix_bits);
416 bit_key_type bit_keys[ItemsPerThread];
417 unsigned int digits[ItemsPerThread];
418 for(
unsigned int i = 0; i < ItemsPerThread; i++)
420 bit_keys[i] = key_codec::encode(keys[i]);
421 digits[i] = key_codec::extract_digit(bit_keys[i], bit, current_radix_bits);
// Flag where the digit changes in the sorted (blocked) tile: a head flag marks
// the first item of a digit's run and a tail flag marks the last.
424 bool head_flags[ItemsPerThread];
425 bool tail_flags[ItemsPerThread];
426 ::rocprim::not_equal_to<unsigned int> flag_op;
429 discontinuity_type().flag_heads_and_tails(head_flags, tail_flags, digits, flag_op, storage.discontinuity);
// Record run bounds at blocked positions (pos = flat_id * ItemsPerThread + i).
// NOTE(review): the guards on these stores (presumably head_flags[i] and
// tail_flags[i]) are not visible in this listing.
432 for(
unsigned int i = 0; i < ItemsPerThread; i++)
434 const unsigned int digit = digits[i];
435 const unsigned int pos = flat_id * ItemsPerThread + i;
438 storage.starts[digit] = pos;
442 storage.ends[digit] = pos;
// Convert to a striped arrangement (pos = i * BlockSize + flat_id). Within a
// digit's run, consecutive threads then write consecutive output positions.
449 bit_keys_exchange_type().blocked_to_striped(bit_keys, bit_keys, storage.bit_keys_exchange);
453 values_exchange_type().blocked_to_striped(values, values, storage.values_exchange);
// Scatter: destination = offset within the digit's run + the digit's running
// output offset.
// NOTE(review): the values store is presumably guarded by with_values; the guard
// is not visible.
456 for(
unsigned int i = 0; i < ItemsPerThread; i++)
458 const unsigned int digit = key_codec::extract_digit(bit_keys[i], bit, current_radix_bits);
459 const unsigned int pos = i * BlockSize + flat_id;
460 if(IsFull || (pos < valid_count))
462 const Offset dst = pos - storage.starts[digit] + storage.digit_starts[digit];
463 keys_output[dst] = key_codec::decode(bit_keys[i]);
466 values_output[dst] = values[i];
// Advance each digit's running offset by its number of valid items in this tile.
// start >= valid_count means the digit did not occur among the valid items.
// `end` is clamped to valid_count - 1, presumably because padding keys in a
// partial tile can extend a run past the valid items.
474 if(flat_id < radix_size)
476 const unsigned int digit = flat_id;
477 const unsigned int start = storage.starts[digit];
478 const unsigned int end = storage.ends[digit];
479 if(start < valid_count)
481 storage.digit_starts[digit] += (
::rocprim::min(valid_count - 1, end) - start + 1);
// Device entry point for the single-tile sort. It declares the block's shared
// storage (ROCPRIM_SHARED_MEMORY) for `sort_single_helper` and forwards all
// arguments to that helper's sort_single member.
// NOTE(review): the remaining template parameters (e.g. Descending), the
// size/bit parameters and the full `sort_single_helper` alias are not visible
// in this listing.
489 unsigned int BlockSize,
490 unsigned int ItemsPerThread,
492 class KeysInputIterator,
493 class KeysOutputIterator,
494 class ValuesInputIterator,
495 class ValuesOutputIterator
497 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
498 void sort_single(KeysInputIterator keys_input,
499 KeysOutputIterator keys_output,
500 ValuesInputIterator values_input,
501 ValuesOutputIterator values_output,
504 unsigned int current_radix_bits)
506 using key_type =
typename std::iterator_traits<KeysInputIterator>::value_type;
507 using value_type =
typename std::iterator_traits<ValuesInputIterator>::value_type;
510 BlockSize, ItemsPerThread, Descending,
514 ROCPRIM_SHARED_MEMORY
typename sort_single_helper::storage_type storage;
516 sort_single_helper().template sort_single(
517 keys_input, keys_output, values_input, values_output,
518 size, bit, current_radix_bits,
// Floating-point overload of compare_nan_sensitive. Returns true when `a`
// orders strictly after `b` in a total order over the raw bit patterns, so NaNs
// get a position instead of comparing false against everything.
//
// Steps:
//  1. Reinterpret both operands as bit_key_type.
//  2. Map the pattern consisting of sign_bit alone (presumably -0.0) to 0, so
//     -0.0 and +0.0 compare equal.
//  3. Flip only the sign bit of non-negative patterns and all bits of negative
//     patterns. Unsigned comparison of the results then follows numeric order;
//     for IEEE-754 layouts, positive-signed NaNs land above +inf and
//     negative-signed NaNs below -inf.
// NOTE(review): the template header and the definitions of bit_key_type and
// sign_bit (source lines 527-534) are not visible. Presumably they are an
// unsigned integer of sizeof(T) and its top bit.
524 ROCPRIM_DEVICE ROCPRIM_INLINE
525 auto compare_nan_sensitive(
const T& a,
const T& b)
526 ->
typename std::enable_if<rocprim::is_floating_point<T>::value,
bool>::type
535 auto a_bits = __builtin_bit_cast(bit_key_type, a);
536 auto b_bits = __builtin_bit_cast(bit_key_type, b);
539 a_bits = a_bits == sign_bit ? 0 : a_bits;
540 b_bits = b_bits == sign_bit ? 0 : b_bits;
542 a_bits ^= (sign_bit & a_bits) == 0 ? sign_bit : bit_key_type(-1);
543 b_bits ^= (sign_bit & b_bits) == 0 ? sign_bit : bit_key_type(-1);
546 return a_bits > b_bits;
// Overload of compare_nan_sensitive for T that is not floating point (selected
// through enable_if on !rocprim::is_floating_point<T>).
// NOTE(review): the body is not visible in this listing. Presumably it returns
// a > b, matching the "a orders after b" contract of the floating-point overload.
550 ROCPRIM_DEVICE ROCPRIM_INLINE
551 auto compare_nan_sensitive(
const T& a,
const T& b)
552 ->
typename std::enable_if<!rocprim::is_floating_point<T>::value,
bool>::type
// Comparator call operator that returns compare_nan_sensitive<T>(b, a): true
// when `a` orders strictly before `b`, i.e. an ascending "less than".
// NOTE(review): the enclosing struct header (source lines 563-567) is not visible.
568 ROCPRIM_DEVICE ROCPRIM_INLINE
569 bool operator()(
const T& a,
const T& b)
const 571 return compare_nan_sensitive<T>(b, a);
// Comparator call operator that returns compare_nan_sensitive<T>(a, b): true
// when `a` orders strictly after `b`, i.e. a descending-order comparison.
// NOTE(review): the enclosing struct header is not visible in this listing.
578 ROCPRIM_DEVICE ROCPRIM_INLINE
579 bool operator()(
const T& a,
const T& b)
const 581 return compare_nan_sensitive<T>(a, b);
// Masked comparator (ascending). The constructor builds radix_mask, which
// selects bits [start_bit, start_bit + current_radix_bits) of T, as
// (upper mask) XOR (lower mask).
// NOTE(review): the constructor signature is not visible; presumably it takes
// start_bit and current_radix_bits.
// NOTE(review): if start_bit + current_radix_bits reaches the bit width of the
// promoted type of T(1), the left shift below is undefined behavior. Confirm
// that callers never request the full key width through this comparator.
590 ROCPRIM_HOST_DEVICE ROCPRIM_INLINE
593 T radix_mask_upper = (T(1) << (current_radix_bits + start_bit)) - 1;
594 T radix_mask_bottom = (T(1) << start_bit) - 1;
595 radix_mask = radix_mask_upper ^ radix_mask_bottom;
// Returns true when a's masked bits are smaller than b's (ascending order).
598 ROCPRIM_DEVICE ROCPRIM_INLINE
599 bool operator()(
const T& a,
const T& b)
const 601 const T masked_key_a = a & radix_mask;
602 const T masked_key_b = b & radix_mask;
603 return masked_key_b > masked_key_a;
// Masked comparator (descending). The constructor builds radix_mask, which
// selects bits [start_bit, start_bit + current_radix_bits) of T.
// NOTE(review): the constructor signature is not visible. As with the ascending
// variant, the shift is undefined behavior if start_bit + current_radix_bits
// reaches the width of the promoted type of T(1); confirm callers avoid that.
612 ROCPRIM_HOST_DEVICE ROCPRIM_INLINE
615 T radix_mask_upper = (T(1) << (current_radix_bits + start_bit)) - 1;
616 T radix_mask_bottom = (T(1) << start_bit) - 1;
617 radix_mask = (radix_mask_upper ^ radix_mask_bottom);
// Returns true when a's masked bits are greater than b's (descending order).
620 ROCPRIM_DEVICE ROCPRIM_INLINE
621 bool operator()(
const T& a,
const T& b)
const 623 const T masked_key_a = a & radix_mask;
624 const T masked_key_b = b & radix_mask;
625 return masked_key_a > masked_key_b;
// Comparator specialization for non-integral T, selected through the enable_if
// on !rocprim::is_integral<T>. Its operator() always returns false, so every
// pair compares equivalent.
// NOTE(review): the struct name line and the constructor body are not visible.
// Presumably this exists only so the template can be instantiated for
// non-integral keys; confirm it is never relied on for ordering.
629 template<
bool Descending,
class T>
633 typename std::enable_if<!rocprim::is_integral<T>::value>::type>
638 ROCPRIM_HOST_DEVICE ROCPRIM_INLINE
641 ROCPRIM_DEVICE ROCPRIM_INLINE
642 bool operator()(
const T&,
const T&)
const {
return false; }
// Helper that builds one block's digit histograms for every digit place of the
// key in a single pass, then adds them into global_digit_counts.
//
// Template parameters (visible here):
//   KeyType        - key type; its bit width determines max_digit_places.
//   BlockSize, ItemsPerThread - tile shape (items_per_block keys per block).
//   RadixBits      - bits per digit place.
// NOTE(review): the struct header and any further template parameters are not
// visible in this listing.
645 template<
class KeyType,
646 unsigned int BlockSize,
647 unsigned int ItemsPerThread,
648 unsigned int RadixBits,
652 static constexpr
unsigned int radix_size = 1u << RadixBits;
655 static constexpr
unsigned int max_digit_places
656 = ::rocprim::detail::ceiling_div(
sizeof(KeyType) * 8, RadixBits);
657 static constexpr
unsigned int items_per_block = BlockSize * ItemsPerThread;
658 static constexpr
unsigned int digits_per_thread
659 = ::rocprim::detail::ceiling_div(radix_size, BlockSize);
// Each (place, digit) pair has atomic_stripes separate counters. A thread uses
// stripe flat_id % atomic_stripes, presumably to spread contention on the
// shared-memory atomics.
660 static constexpr
unsigned int atomic_stripes = 4;
// The histogram holds radix_size * max_digit_places * atomic_stripes counters
// of counter_type.
661 static constexpr
unsigned int histogram_counters
662 = radix_size * max_digit_places * atomic_stripes;
664 using counter_type = uint32_t;
666 using bit_key_type =
typename key_codec::bit_key_type;
670 counter_type
histogram[histogram_counters];
// Counter for (stripe, place, digit). The stripes of one digit are adjacent.
673 ROCPRIM_DEVICE ROCPRIM_INLINE counter_type& get_counter(
const unsigned stripe_index,
674 const unsigned int place,
675 const unsigned int digit,
678 return storage.histogram[(place * radix_size + digit) * atomic_stripes + stripe_index];
// Zeroes all counters, with the block's threads striding by BlockSize.
681 ROCPRIM_DEVICE ROCPRIM_INLINE
void clear_histogram(
const unsigned int flat_id,
684 for(
unsigned int i = flat_id; i < histogram_counters; i += BlockSize)
686 storage.histogram[i] = 0;
// Atomically increments this thread's stripe counter for the digit of each valid
// key. Keys are in striped order (pos = i * BlockSize + flat_id).
690 template<
bool IsFull>
691 ROCPRIM_DEVICE
void count_digits_at_place(
const unsigned int flat_id,
692 const unsigned int stripe,
693 const bit_key_type (&bit_keys)[ItemsPerThread],
694 const unsigned int place,
695 const unsigned int start_bit,
696 const unsigned int current_radix_bits,
697 const unsigned int valid_count,
701 for(
unsigned int i = 0; i < ItemsPerThread; ++i)
703 const unsigned int pos = i * BlockSize + flat_id;
704 if(IsFull || pos < valid_count)
706 const unsigned int digit
707 = key_codec::extract_digit(bit_keys[i], start_bit, current_radix_bits);
708 ::rocprim::detail::atomic_add(&get_counter(stripe, place, digit, storage), 1);
// Processes one tile: loads keys (striped; only valid_count keys unless IsFull),
// clears the histogram and counts digits for every place in [begin_bit, end_bit).
// The last place may be narrower than RadixBits. The stripes are then summed and
// each total is atomically added to global_digit_counts[place * radix_size + digit].
// NOTE(review): the block barriers between clearing, counting and reducing, and
// the `place` increment in the reduction loop, are not visible in this listing.
// NOTE(review): the reduction's inner `unsigned int stripe` shadows the outer
// `const unsigned int stripe`; consider renaming one of them.
713 template<
bool IsFull,
class KeysInputIterator,
class Offset>
714 ROCPRIM_DEVICE
void count_digits(KeysInputIterator keys_input,
715 Offset* global_digit_counts,
716 const unsigned int valid_count,
717 const unsigned int begin_bit,
718 const unsigned int end_bit,
721 const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
722 const unsigned int stripe = flat_id % atomic_stripes;
724 KeyType keys[ItemsPerThread];
726 if ROCPRIM_IF_CONSTEXPR(IsFull)
728 block_load_direct_striped<BlockSize>(flat_id, keys_input, keys);
732 block_load_direct_striped<BlockSize>(flat_id, keys_input, keys, valid_count);
736 clear_histogram(flat_id, storage);
741 bit_key_type bit_keys[ItemsPerThread];
743 for(
unsigned int i = 0; i < ItemsPerThread; ++i)
745 bit_keys[i] = key_codec::encode(keys[i]);
748 for(
unsigned int bit = begin_bit, place = 0; bit < end_bit; bit += RadixBits, ++place)
750 count_digits_at_place<IsFull>(flat_id,
755 min(RadixBits, end_bit - bit),
764 unsigned int place = 0;
765 for(
unsigned int bit = begin_bit; bit < end_bit; bit += RadixBits)
767 for(
unsigned int digit = flat_id; digit < radix_size; digit += BlockSize)
769 counter_type total = 0;
772 for(
unsigned int stripe = 0; stripe < atomic_stripes; ++stripe)
774 total += get_counter(stripe, place, digit, storage);
777 ::rocprim::detail::atomic_add(&global_digit_counts[place * radix_size + digit],
// Device body of the onesweep histogram pass. Each block counts the digits of
// one tile into global_digit_counts. Blocks with id < full_blocks take the
// full-tile path. The other branch handles the trailing
// `size - items_per_block * full_blocks` items, presumably in a single extra block.
// NOTE(review): the Offset template parameter, the `size` parameter, the
// count_helper_type alias and the remaining call arguments are not fully
// visible in this listing.
785 template<
unsigned int BlockSize,
786 unsigned int ItemsPerThread,
787 unsigned int RadixBits,
789 class KeysInputIterator,
791 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void onesweep_histograms(KeysInputIterator keys_input,
792 Offset* global_digit_counts,
794 const Offset full_blocks,
795 const unsigned int begin_bit,
796 const unsigned int end_bit)
798 using key_type =
typename std::iterator_traits<KeysInputIterator>::value_type;
799 using count_helper_type
802 constexpr
unsigned int items_per_block = BlockSize * ItemsPerThread;
// block_id is held as Offset, so the tile offset is computed in Offset precision.
804 const Offset
block_id = ::rocprim::detail::block_id<0>();
805 const Offset block_offset = block_id * ItemsPerThread * BlockSize;
807 ROCPRIM_SHARED_MEMORY
typename count_helper_type::storage_type storage;
809 if(block_id < full_blocks)
811 count_helper_type{}.template count_digits<true>(keys_input + block_offset,
820 const unsigned int valid_in_last_block = size - items_per_block * full_blocks;
821 count_helper_type{}.template count_digits<false>(keys_input + block_offset,
// Device body that converts per-place digit counts into exclusive digit offsets.
// Block `digit_place` handles the radix_size counters starting at
// digit_place * radix_size and runs a block-wide exclusive scan with initial
// value 0. Each thread holds ceiling_div(radix_size, BlockSize) of them.
// NOTE(review): the block_scan_type alias and the load/store of `offsets` are
// not visible in this listing. Presumably the results are written back to
// global_digit_offsets at the same positions.
830 template<
unsigned int BlockSize,
unsigned int RadixBits,
class Offset>
831 ROCPRIM_DEVICE
void onesweep_scan_histograms(Offset* global_digit_offsets)
835 constexpr
unsigned int radix_size = 1u << RadixBits;
836 constexpr
unsigned int items_per_thread = ::rocprim::detail::ceiling_div(radix_size, BlockSize);
838 const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
839 const unsigned int digit_place = ::rocprim::detail::block_id<0>();
840 const unsigned int block_offset = digit_place * radix_size;
842 Offset offsets[items_per_thread];
844 block_scan_type{}.exclusive_scan(offsets, offsets, 0);
// Per-(block, digit) look-back state used by onesweep. It is a single 32-bit word:
// the top two bits hold the status (prefix_flag) and the low 30 bits hold a count.
// NOTE(review): the struct header, the EMPTY enumerator (presumably 0) and the
// raw-state constructor used by load() are not visible in this listing.
852 using underlying_type = uint32_t;
854 static constexpr
unsigned int state_bits = 8u *
sizeof(underlying_type);
856 enum prefix_flag : underlying_type
859 PARTIAL = 1u << (state_bits - 2),
860 COMPLETE = 2u << (state_bits - 2)
// status_mask covers the two top bits; value_mask covers the remaining 30.
863 static constexpr underlying_type status_mask = 3u << (state_bits - 2);
864 static constexpr underlying_type value_mask = ~status_mask;
866 underlying_type state;
// Packs status and value into one word.
// NOTE(review): `value` is not masked, so a value >= 2^30 would overwrite the
// status bits. Presumably callers keep the counts below that bound; confirm.
872 ROCPRIM_DEVICE ROCPRIM_INLINE onesweep_lookback_state(prefix_flag status, underlying_type value)
873 : state(static_cast<underlying_type>(status) | value)
876 ROCPRIM_DEVICE ROCPRIM_INLINE underlying_type value()
const 878 return this->state & value_mask;
881 ROCPRIM_DEVICE ROCPRIM_INLINE prefix_flag status()
const 883 return static_cast<prefix_flag
>(this->state & status_mask);
// Atomic read of the whole word, implemented as atomic_add(&ptr->state, 0).
886 ROCPRIM_DEVICE ROCPRIM_INLINE
static onesweep_lookback_state load(onesweep_lookback_state* ptr)
888 underlying_type state = ::rocprim::detail::atomic_add(&ptr->state, 0);
889 return onesweep_lookback_state(state);
// Publishes status and value together with one atomic exchange.
892 ROCPRIM_DEVICE ROCPRIM_INLINE
void store(onesweep_lookback_state* ptr)
const 894 ::rocprim::detail::atomic_exch(&ptr->state, this->state);
// Helper for one onesweep iteration (one digit place). Each block:
//  1. ranks its tile by the current digit with block_radix_rank;
//  2. computes each digit's global output offset via a look-back over
//     lookback_states;
//  3. scatters keys (and values, when Value is not ::rocprim::empty_type) to
//     their output positions.
// NOTE(review): the struct header and remaining template parameters (Key, Value,
// Offset, RadixRankAlgorithm, ...) are not visible in this listing. The block
// barriers between phases are not visible either.
901 unsigned int BlockSize,
902 unsigned int ItemsPerThread,
903 unsigned int RadixBits,
908 static constexpr
unsigned int radix_size = 1u << RadixBits;
909 static constexpr
unsigned int items_per_block = BlockSize * ItemsPerThread;
910 static constexpr
bool with_values = !std::is_same<Value, rocprim::empty_type>::value;
913 using bit_key_type =
typename key_codec::bit_key_type;
914 using radix_rank_type = ::rocprim::block_radix_rank<BlockSize, RadixBits, RadixRankAlgorithm>;
// Initializer not visible. Presumably true for the warp-based rank algorithm,
// which expects keys in warp-striped order.
916 static constexpr
bool load_warp_striped
919 static constexpr
unsigned int digits_per_thread = radix_rank_type::digits_per_thread;
923 typename radix_rank_type::storage_type rank;
// Per-digit offset added to a tile rank to obtain the global output index.
926 Offset global_digit_offsets[radix_size];
// The tile reordered by rank (grouped by current digit) before the global store.
929 bit_key_type ordered_block_keys[items_per_block];
930 Value ordered_block_values[items_per_block];
// onesweep: processes tile block_id (block_offset = block_id * items_per_block).
// The tile holds valid_items keys unless IsFull.
// NOTE(review): the lookback_states and storage parameters are not visible in
// this listing.
937 template<
bool IsFull,
938 class KeysInputIterator,
939 class KeysOutputIterator,
940 class ValuesInputIterator,
941 class ValuesOutputIterator>
942 ROCPRIM_DEVICE
void onesweep(KeysInputIterator keys_input,
943 KeysOutputIterator keys_output,
944 ValuesInputIterator values_input,
945 ValuesOutputIterator values_output,
946 Offset* global_digit_offsets_in,
947 Offset* global_digit_offsets_out,
949 const unsigned int bit,
950 const unsigned int current_radix_bits,
951 const unsigned int valid_items,
954 const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
955 const unsigned int block_id = ::rocprim::detail::block_id<0>();
956 const unsigned int block_offset = block_id * items_per_block;
// Load keys (warp-striped or striped). A partial tile is padded with the key
// decoded from an all-ones bit pattern.
959 Key keys[ItemsPerThread];
960 if ROCPRIM_IF_CONSTEXPR(IsFull)
962 if ROCPRIM_IF_CONSTEXPR(load_warp_striped)
979 const Key out_of_bounds = key_codec::decode(bit_key_type(-1));
980 if ROCPRIM_IF_CONSTEXPR(load_warp_striped)
983 keys_input + block_offset,
991 keys_input + block_offset,
998 bit_key_type bit_keys[ItemsPerThread];
1000 for(
unsigned int i = 0; i < ItemsPerThread; ++i)
1002 bit_keys[i] = key_codec::encode(keys[i]);
// Rank the keys by their current digit. ranks[i] is the key's position in the
// digit-ordered tile. For each digit owned by this thread, exclusive_digit_prefix
// and digit_counts give the digit's start position and item count within the
// tile (several rank_keys arguments are not visible).
1006 unsigned int ranks[ItemsPerThread];
1008 unsigned int exclusive_digit_prefix[digits_per_thread];
1010 unsigned int digit_counts[digits_per_thread];
1011 radix_rank_type{}.rank_keys(
1015 [bit, current_radix_bits](
const bit_key_type& key)
1016 { return key_codec::extract_digit(key, bit, current_radix_bits); },
1017 exclusive_digit_prefix,
1024 for(
unsigned int i = 0; i < ItemsPerThread; ++i)
1026 storage.ordered_block_keys[ranks[i]] = bit_keys[i];
// Look-back for each owned digit. When radix_size is a multiple of BlockSize the
// range check is skipped at compile time (presumably every computed digit is
// then in range).
// 1. Publish this block's state (presumably PARTIAL with digit_counts[i]).
// 2. Walk earlier blocks, spinning while a state is EMPTY and accumulating values.
// 3. Stop at a COMPLETE state (the break is not visible).
// 4. Publish the inclusive prefix (presumably as COMPLETE).
// NOTE(review): the spin-wait relies on earlier blocks eventually publishing;
// presumably lower block ids are guaranteed to make progress first. Confirm.
1034 for(
unsigned int i = 0; i < digits_per_thread; ++i)
1036 const unsigned int digit = flat_id * digits_per_thread + i;
1037 if(radix_size % BlockSize == 0 || digit < radix_size)
1040 = &lookback_states[block_id * radix_size + digit];
1042 .store(block_state);
1044 unsigned int exclusive_prefix = 0;
1045 unsigned int lookback_block_id =
block_id;
1047 while(lookback_block_id > 0)
1049 --lookback_block_id;
1051 = &lookback_states[lookback_block_id * radix_size + digit];
1053 = onesweep_lookback_state::load(lookback_state_ptr);
1054 while(lookback_state.status() == onesweep_lookback_state::EMPTY)
1056 lookback_state = onesweep_lookback_state::load(lookback_state_ptr);
1059 exclusive_prefix += lookback_state.value();
1060 if(lookback_state.status() == onesweep_lookback_state::COMPLETE)
1067 const unsigned int inclusive_digit_prefix = exclusive_prefix + digit_counts[i];
1071 .store(block_state);
// rank + offset = digit's global start + this digit's items in earlier blocks
//                 + position within the digit's run in this tile.
1075 storage.global_digit_offsets[digit]
1076 = global_digit_offsets_in[digit] - exclusive_digit_prefix[i] + exclusive_prefix;
// Store keys: read the digit-ordered tile in striped order
// (rank = i * BlockSize + flat_id) and write to rank + the digit's offset.
1084 for(
unsigned int i = 0; i < ItemsPerThread; ++i)
1086 const unsigned int rank = i * BlockSize + flat_id;
1087 if(IsFull || rank < valid_items)
1089 const bit_key_type bit_key = storage.ordered_block_keys[rank];
1090 const unsigned int digit
1091 = key_codec::extract_digit(bit_key, bit, current_radix_bits);
1092 const Offset global_offset = storage.global_digit_offsets[digit];
1093 keys_output[rank + global_offset] = key_codec::decode(bit_key);
// Values (presumably only when with_values): load them the same way as the keys,
// remember the digit at each striped rank, reorder by ranks[i] and store at
// rank + the digit's offset.
1100 Value values[ItemsPerThread];
1101 if ROCPRIM_IF_CONSTEXPR(IsFull)
1103 if ROCPRIM_IF_CONSTEXPR(load_warp_striped)
1114 if ROCPRIM_IF_CONSTEXPR(load_warp_striped)
1117 values_input + block_offset,
1124 values_input + block_offset,
1132 unsigned int digits[ItemsPerThread];
1134 for(
unsigned int i = 0; i < ItemsPerThread; ++i)
1136 const unsigned int rank = i * BlockSize + flat_id;
1137 if(IsFull || rank < valid_items)
1139 const bit_key_type bit_key = storage.ordered_block_keys[rank];
1140 digits[i] = key_codec::extract_digit(bit_key, bit, current_radix_bits);
1148 for(
unsigned int i = 0; i < ItemsPerThread; ++i)
1150 storage.ordered_block_values[ranks[i]] = values[i];
1157 for(
unsigned int i = 0; i < ItemsPerThread; ++i)
1159 const unsigned int rank = i * BlockSize + flat_id;
1160 if(IsFull || rank < valid_items)
1162 const Value value = storage.ordered_block_values[rank];
1163 const Offset global_offset = storage.global_digit_offsets[digits[i]];
1164 values_output[rank + global_offset] = value;
// The last block of the grid writes, per digit,
// global_digit_offsets_in[digit] + (digit's item count over all blocks of this launch),
// i.e. the position just past the digit's last written item. Presumably this is
// consumed as the input offsets of a subsequent launch. The is_last_block guard
// itself is not visible.
1170 const bool is_last_block = block_id == rocprim::detail::grid_size<0>() - 1;
1174 for(
unsigned int i = 0; i < digits_per_thread; ++i)
1176 const unsigned int digit = flat_id * digits_per_thread + i;
1177 if(radix_size % BlockSize == 0 || digit < radix_size)
1179 global_digit_offsets_out[digit] = storage.global_digit_offsets[digit]
1180 + exclusive_digit_prefix[i] + digit_counts[i];
// Device body of one onesweep iteration. It declares the helper's shared storage
// (ROCPRIM_SHARED_MEMORY). Blocks with id < full_blocks run onesweep<true>;
// otherwise onesweep<false> runs on the trailing
// `size - items_per_block * full_blocks` items.
// NOTE(review): the helper alias, the presumed lookback_states parameter and
// several forwarded arguments are not visible in this listing.
1187 template<
unsigned int BlockSize,
1188 unsigned int ItemsPerThread,
1189 unsigned int RadixBits,
1192 class KeysInputIterator,
1193 class KeysOutputIterator,
1194 class ValuesInputIterator,
1195 class ValuesOutputIterator,
1197 ROCPRIM_DEVICE
void onesweep_iteration(KeysInputIterator keys_input,
1198 KeysOutputIterator keys_output,
1199 ValuesInputIterator values_input,
1200 ValuesOutputIterator values_output,
1201 const unsigned int size,
1202 Offset* global_digit_offsets_in,
1203 Offset* global_digit_offsets_out,
1205 const unsigned int bit,
1206 const unsigned int current_radix_bits,
1207 const unsigned int full_blocks)
1209 using key_type =
typename std::iterator_traits<KeysInputIterator>::value_type;
1210 using value_type =
typename std::iterator_traits<ValuesInputIterator>::value_type;
1219 RadixRankAlgorithm>;
1221 constexpr
unsigned int items_per_block = BlockSize * ItemsPerThread;
1222 const unsigned int block_id = ::rocprim::detail::block_id<0>();
1224 ROCPRIM_SHARED_MEMORY
typename onesweep_iteration_helper_type::storage_type storage;
// keys_input is passed without a block offset; the helper adds
// block_id * items_per_block itself.
1226 if(block_id < full_blocks)
1228 onesweep_iteration_helper_type{}.template onesweep<true>(keys_input,
1232 global_digit_offsets_in,
1233 global_digit_offsets_out,
1242 const unsigned int valid_in_last_block = size - items_per_block * full_blocks;
1243 onesweep_iteration_helper_type{}.template onesweep<false>(keys_input,
1247 global_digit_offsets_in,
1248 global_digit_offsets_out,
1252 valid_in_last_block,
1259 END_ROCPRIM_NAMESPACE
1261 #endif // ROCPRIM_DEVICE_DETAIL_DEVICE_RADIX_SORT_HPP_ ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int block_id()
Returns block identifier in a multidimensional grid by dimension.
Definition: thread.hpp:258
block_radix_rank_algorithm
Available algorithms for the block_radix_rank primitive.
Definition: block_radix_rank.hpp:40
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int masked_bit_count(lane_mask_type x, unsigned int add=0)
Masked bit count.
Definition: warp.hpp:48
Definition: device_radix_sort.hpp:906
The block_scan class is a block level parallel primitive which provides methods for performing inclus...
Definition: block_scan.hpp:134
Definition: device_radix_sort.hpp:312
Definition: device_radix_sort.hpp:216
Definition: device_radix_sort.hpp:848
ROCPRIM_DEVICE ROCPRIM_INLINE constexpr unsigned int device_warp_size()
Returns a number of threads in a hardware warp for the actual target.
Definition: thread.hpp:70
Definition: radix_sort.hpp:101
ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
Returns the minimum of its arguments.
Definition: functional.hpp:63
Definition: device_radix_sort.hpp:563
Definition: device_radix_sort.hpp:335
Definition: device_radix_sort.hpp:921
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
Definition: device_radix_sort.hpp:650
const unsigned int warp_id
Returns warp id in a block (tile).
Definition: benchmark_warp_exchange.cpp:153
ROCPRIM_DEVICE ROCPRIM_INLINE lane_mask_type ballot(int predicate)
Evaluate predicate for all active work-items in the warp and return an integer whose i-th bit is set ...
Definition: warp.hpp:38
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216
ROCPRIM_DEVICE ROCPRIM_INLINE void block_load_direct_warp_striped(unsigned int flat_id, InputIterator block_input, T(&items)[ItemsPerThread])
Loads data from continuous memory into a warp-striped arrangement of items across the thread block...
Definition: block_load_func.hpp:378
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int flat_block_id()
Returns flat (linear, 1D) block identifier in a multidimensional grid.
Definition: thread.hpp:178
Definition: radix_sort.hpp:241
Definition: device_radix_sort.hpp:106
BEGIN_ROCPRIM_NAMESPACE ROCPRIM_DEVICE ROCPRIM_INLINE void block_store_direct_blocked(unsigned int flat_id, OutputIterator block_output, T(&items)[ItemsPerThread])
Stores a blocked arrangement of items from across the thread block into a blocked arrangement on cont...
Definition: block_store_func.hpp:58
Definition: device_radix_sort.hpp:97
Warp-based radix ranking algorithm. Keys and ranks are assumed in warp-striped order for this algorit...
BEGIN_ROCPRIM_NAMESPACE ROCPRIM_DEVICE ROCPRIM_INLINE void block_load_direct_blocked(unsigned int flat_id, InputIterator block_input, T(&items)[ItemsPerThread])
Loads data from continuous memory into a blocked arrangement of items across the thread block...
Definition: block_load_func.hpp:58
Definition: benchmark_block_histogram.cpp:64
Definition: device_radix_sort.hpp:203
unsigned long long int lane_mask_type
The lane_mask_type is an integer that contains one bit per thread.
Definition: types.hpp:164
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int bit_count(unsigned int x)
Bit count.
Definition: bit.hpp:42
Definition: device_radix_sort.hpp:668