21 #ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_RANK_BASIC_HPP_    22 #define ROCPRIM_BLOCK_DETAIL_BLOCK_RANK_BASIC_HPP_    24 #include "../../config.hpp"    25 #include "../../detail/various.hpp"    26 #include "../../functional.hpp"    28 #include "../../detail/radix_sort.hpp"    30 #include "../block_scan.hpp"    32 BEGIN_ROCPRIM_NAMESPACE
    37 template<
unsigned int BlockSizeX,
    38          unsigned int RadixBits,
    39          bool         MemoizeOuterScan = 
false,
    40          unsigned int BlockSizeY       = 1,
    41          unsigned int BlockSizeZ       = 1>
    44     using digit_counter_type  = 
unsigned short;
    45     using packed_counter_type = 
unsigned int;
    47     using block_scan_type = ::rocprim::block_scan<packed_counter_type,
    49                                                   ::rocprim::block_scan_algorithm::using_warp_scan,
    53     static constexpr 
unsigned int block_size   = BlockSizeX * BlockSizeY * BlockSizeZ;
    54     static constexpr 
unsigned int radix_digits = 1 << RadixBits;
    55     static constexpr 
unsigned int packing_ratio
    56         = 
sizeof(packed_counter_type) / 
sizeof(digit_counter_type);
    57     static constexpr 
unsigned int column_size = radix_digits / packing_ratio;
    60     static constexpr 
unsigned int digits_per_thread
    61         = ::rocprim::detail::ceiling_div(radix_digits, block_size);
    68             digit_counter_type  digit_counters[block_size * radix_digits];
    69             packed_counter_type packed_counters[block_size * column_size];
    72         typename block_scan_type::storage_type 
block_scan;
    75     ROCPRIM_DEVICE ROCPRIM_INLINE digit_counter_type& get_digit_counter(
const unsigned int digit,
    76                                                                         const unsigned int thread,
    77                                                                         storage_type_&     storage)
    79         const unsigned int column_counter = digit % column_size;
    80         const unsigned int sub_counter    = digit / column_size;
    81         const unsigned int counter
    82             = (column_counter * block_size + thread) * packing_ratio + sub_counter;
    83         return storage.digit_counters[counter];
    86     ROCPRIM_DEVICE ROCPRIM_INLINE 
void reset_counters(
const unsigned int flat_id,
    87                                                       storage_type_&     storage)
    89         for(
unsigned int i = flat_id; i < block_size * column_size; i += block_size)
    91             storage.packed_counters[i] = 0;
    95     ROCPRIM_DEVICE ROCPRIM_INLINE 
void    96         scan_block_counters(storage_type_& storage, packed_counter_type* 
const packed_counters)
    98         packed_counter_type block_reduction = 0;
   100         for(
unsigned int i = 0; i < column_size; ++i)
   102             block_reduction += packed_counters[i];
   105         packed_counter_type exclusive_prefix = 0;
   106         packed_counter_type reduction;
   107         block_scan_type().exclusive_scan(block_reduction,
   114         for(
unsigned int i = 1; i < packing_ratio; i <<= 1)
   116             exclusive_prefix += reduction << (
sizeof(digit_counter_type) * 8 * i);
   120         for(
unsigned int i = 0; i < column_size; ++i)
   122             packed_counter_type counter = packed_counters[i];
   123             packed_counters[i]          = exclusive_prefix;
   124             exclusive_prefix += counter;
   128     ROCPRIM_DEVICE ROCPRIM_INLINE 
void scan_counters(
const unsigned int flat_id,
   129                                                      storage_type_&     storage)
   131         packed_counter_type* 
const shared_counters
   132             = &storage.packed_counters[flat_id * column_size];
   134         if ROCPRIM_IF_CONSTEXPR(MemoizeOuterScan)
   136             packed_counter_type local_counters[column_size];
   138             for(
unsigned int i = 0; i < column_size; ++i)
   140                 local_counters[i] = shared_counters[i];
   143             scan_block_counters(storage, local_counters);
   146             for(
unsigned int i = 0; i < column_size; ++i)
   148                 shared_counters[i] = local_counters[i];
   153             scan_block_counters(storage, shared_counters);
   157     template<
typename Key, 
unsigned int ItemsPerThread, 
typename DigitExtractor>
   158     ROCPRIM_DEVICE 
void rank_keys_impl(
const Key (&keys)[ItemsPerThread],
   159                                        unsigned int (&ranks)[ItemsPerThread],
   160                                        storage_type_& storage,
   161                                        DigitExtractor digit_extractor)
   163         static_assert(block_size * ItemsPerThread < 1u << 16,
   164                       "The maximum amout of items that block_radix_rank can rank is 2**16.");
   165         const unsigned int flat_id
   166             = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
   168         reset_counters(flat_id, storage);
   170         digit_counter_type  thread_prefixes[ItemsPerThread];
   171         digit_counter_type* digit_counters[ItemsPerThread];
   174         for(
unsigned int i = 0; i < ItemsPerThread; ++i)
   176             const unsigned int digit = digit_extractor(keys[i]);
   177             digit_counters[i]        = &get_digit_counter(digit, flat_id, storage);
   178             thread_prefixes[i]       = (*digit_counters[i])++;
   183         scan_counters(flat_id, storage);
   188         for(
unsigned int i = 0; i < ItemsPerThread; ++i)
   190             ranks[i] = thread_prefixes[i] + *digit_counters[i];
   194     template<
bool Descending, 
typename Key, 
unsigned int ItemsPerThread>
   195     ROCPRIM_DEVICE 
void rank_keys_impl(
const Key (&keys)[ItemsPerThread],
   196                                        unsigned int (&ranks)[ItemsPerThread],
   197                                        storage_type_&     storage,
   198                                        const unsigned int begin_bit,
   199                                        const unsigned int pass_bits)
   201         using key_codec    = ::rocprim::detail::radix_key_codec<Key, Descending>;
   202         using bit_key_type = 
typename key_codec::bit_key_type;
   204         bit_key_type bit_keys[ItemsPerThread];
   206         for(
unsigned int i = 0; i < ItemsPerThread; ++i)
   208             bit_keys[i] = key_codec::encode(keys[i]);
   211         rank_keys_impl(bit_keys,
   214                        [begin_bit, pass_bits](
const bit_key_type& key)
   215                        { 
return key_codec::extract_digit(key, begin_bit, pass_bits); });
   218     template<
unsigned int ItemsPerThread>
   219     ROCPRIM_DEVICE 
void digit_prefix_count(
unsigned int (&prefix)[digits_per_thread],
   220                                            unsigned int (&counts)[digits_per_thread],
   221                                            storage_type_& storage)
   223         const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
   226         for(
unsigned int i = 0; i < digits_per_thread; ++i)
   228             const unsigned int digit = flat_id * digits_per_thread + i;
   229             if(radix_digits % block_size == 0 || digit < radix_digits)
   232                 prefix[i] = get_digit_counter(digit, 0, storage);
   235                 const unsigned int next_prefix = digit + 1 == radix_digits
   236                                                      ? block_size * ItemsPerThread
   237                                                      : get_digit_counter(digit + 1, 0, storage);
   238                 counts[i]                      = next_prefix - prefix[i];
   244     using storage_type = ::rocprim::detail::raw_storage<storage_type_>;
   246     template<
typename Key, 
unsigned ItemsPerThread>
   247     ROCPRIM_DEVICE 
void rank_keys(
const Key (&keys)[ItemsPerThread],
   248                                   unsigned int (&ranks)[ItemsPerThread],
   249                                   storage_type& storage,
   250                                   unsigned int  begin_bit = 0,
   251                                   unsigned int  pass_bits = RadixBits)
   253         rank_keys_impl<false>(keys, ranks, storage.get(), begin_bit, pass_bits);
   256     template<
typename Key, 
unsigned ItemsPerThread>
   257     ROCPRIM_DEVICE 
void rank_keys_desc(
const Key (&keys)[ItemsPerThread],
   258                                        unsigned int (&ranks)[ItemsPerThread],
   259                                        storage_type& storage,
   260                                        unsigned int  begin_bit = 0,
   261                                        unsigned int  pass_bits = RadixBits)
   263         rank_keys_impl<true>(keys, ranks, storage.get(), begin_bit, pass_bits);
   266     template<
typename Key, 
unsigned ItemsPerThread, 
typename DigitExtractor>
   267     ROCPRIM_DEVICE 
void rank_keys(
const Key (&keys)[ItemsPerThread],
   268                                   unsigned int (&ranks)[ItemsPerThread],
   269                                   storage_type&  storage,
   270                                   DigitExtractor digit_extractor)
   272         rank_keys_impl(keys, ranks, storage.get(), digit_extractor);
   275     template<
typename Key, 
unsigned ItemsPerThread, 
typename DigitExtractor>
   276     ROCPRIM_DEVICE 
void rank_keys_desc(
const Key (&keys)[ItemsPerThread],
   277                                        unsigned int (&ranks)[ItemsPerThread],
   278                                        storage_type&  storage,
   279                                        DigitExtractor digit_extractor)
   284                        [&digit_extractor](
const Key& key)
   286                            const unsigned int digit = digit_extractor(key);
   287                            return radix_digits - 1 - digit;
   291     template<
typename Key, 
unsigned ItemsPerThread, 
typename DigitExtractor>
   292     ROCPRIM_DEVICE 
void rank_keys(
const Key (&keys)[ItemsPerThread],
   293                                   unsigned int (&ranks)[ItemsPerThread],
   294                                   storage_type&  storage,
   295                                   DigitExtractor digit_extractor,
   296                                   unsigned int (&prefix)[digits_per_thread],
   297                                   unsigned int (&counts)[digits_per_thread])
   299         rank_keys(keys, ranks, storage, digit_extractor);
   300         digit_prefix_count<ItemsPerThread>(prefix, counts, storage.get());
   306 END_ROCPRIM_NAMESPACE
 The block_scan class is a block level parallel primitive which provides methods for performing inclus...
Definition: block_scan.hpp:134
Definition: block_radix_rank_basic.hpp:42
Deprecated: Configuration of device-level scan primitives. 
Definition: block_histogram.hpp:62
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile) 
Definition: thread.hpp:216