21 #ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_RANK_BASIC_HPP_ 22 #define ROCPRIM_BLOCK_DETAIL_BLOCK_RANK_BASIC_HPP_ 24 #include "../../config.hpp" 25 #include "../../detail/various.hpp" 26 #include "../../functional.hpp" 28 #include "../../detail/radix_sort.hpp" 30 #include "../block_scan.hpp" 32 BEGIN_ROCPRIM_NAMESPACE
// Template parameters of the ranking class (the line that names the class is
// not part of this listing).
//   BlockSizeX/Y/Z   - block dimensions; their product is block_size.
//   RadixBits        - radix_digits is 1 << RadixBits; also the default for the
//                      pass_bits argument of rank_keys()/rank_keys_desc().
//   MemoizeOuterScan - when true, scan_counters() copies the thread's counters
//                      into a local array, scans that copy and writes it back.
37 template<
unsigned int BlockSizeX,
38 unsigned int RadixBits,
39 bool MemoizeOuterScan =
false,
40 unsigned int BlockSizeY = 1,
41 unsigned int BlockSizeZ = 1>
// Type of a single counter; there is one per (digit, thread) pair.
44 using digit_counter_type =
unsigned short;
// Wider word that holds packing_ratio digit counters; the scans work on these.
45 using packed_counter_type =
unsigned int;
// Block-level scan over packed counters, using the warp-scan algorithm.
// NOTE(review): the listing skips source lines 48, 50 and 51, so the remaining
// template arguments of this alias are not visible here.
47 using block_scan_type = ::rocprim::block_scan<packed_counter_type,
49 ::rocprim::block_scan_algorithm::using_warp_scan,
// Product of the three block dimensions.
53 static constexpr
unsigned int block_size = BlockSizeX * BlockSizeY * BlockSizeZ;
// Number of values a RadixBits-wide digit can take.
54 static constexpr
unsigned int radix_digits = 1 << RadixBits;
// How many digit counters fit into one packed counter.
55 static constexpr
unsigned int packing_ratio
56 =
sizeof(packed_counter_type) /
sizeof(digit_counter_type);
// Number of packed counters needed per thread to cover all radix_digits.
57 static constexpr
unsigned int column_size = radix_digits / packing_ratio;
// Number of digits each thread reports in digit_prefix_count(); rounded up so
// that block_size threads cover all radix_digits.
60 static constexpr
unsigned int digits_per_thread
61 = ::rocprim::detail::ceiling_div(radix_digits, block_size);
// Counter storage: digit_counters holds one counter per (digit, thread) pair;
// packed_counters holds block_size * column_size words, each packing
// packing_ratio digit counters.
// NOTE(review): get_digit_counter() writes through digit_counters while
// reset_counters()/scan_counters() access packed_counters, so the two arrays
// presumably alias each other (e.g. inside a union). The enclosing declaration
// is not part of this listing; confirm against the original header.
68 digit_counter_type digit_counters[block_size * radix_digits];
69 packed_counter_type packed_counters[block_size * column_size];
// Storage required by block_scan_type; presumably the scratch space passed to
// the exclusive_scan call in scan_block_counters() (that argument is not
// visible in this listing).
72 typename block_scan_type::storage_type
block_scan;
75 ROCPRIM_DEVICE ROCPRIM_INLINE digit_counter_type& get_digit_counter(
const unsigned int digit,
76 const unsigned int thread,
77 storage_type_& storage)
79 const unsigned int column_counter = digit % column_size;
80 const unsigned int sub_counter = digit / column_size;
81 const unsigned int counter
82 = (column_counter * block_size + thread) * packing_ratio + sub_counter;
83 return storage.digit_counters[counter];
86 ROCPRIM_DEVICE ROCPRIM_INLINE
void reset_counters(
const unsigned int flat_id,
87 storage_type_& storage)
89 for(
unsigned int i = flat_id; i < block_size * column_size; i += block_size)
91 storage.packed_counters[i] = 0;
// Replaces the column_size packed counters in `packed_counters` by exclusive
// prefix values, in place.
//   storage         - counter/scan storage of the block
//   packed_counters - the column_size packed counters handled by this thread
95 ROCPRIM_DEVICE ROCPRIM_INLINE
void 96 scan_block_counters(storage_type_& storage, packed_counter_type*
const packed_counters)
// Sum of this thread's packed counters; this is the value fed to the block scan.
98 packed_counter_type block_reduction = 0;
100 for(
unsigned int i = 0; i < column_size; ++i)
102 block_reduction += packed_counters[i];
105 packed_counter_type exclusive_prefix = 0;
106 packed_counter_type reduction;
// NOTE(review): the listing skips source lines 108-111, so the remaining
// arguments of this call are not visible. Presumably it leaves this thread's
// exclusive prefix in exclusive_prefix and the block-wide total in reduction;
// confirm against the original header.
107 block_scan_type().exclusive_scan(block_reduction,
// Add the block-wide total shifted left by whole digit-counter widths, i.e.
// into the higher slot(s) of the packed prefix. Presumably this makes digits
// stored in a higher slot start after everything counted in the lower slot.
// NOTE(review): i doubles each iteration (i <<= 1); with packing_ratio == 2
// the loop body runs exactly once with i == 1.
114 for(
unsigned int i = 1; i < packing_ratio; i <<= 1)
116 exclusive_prefix += reduction << (
sizeof(digit_counter_type) * 8 * i);
// Exclusive running sum: each counter is replaced by the prefix accumulated so
// far, then its old value is added to the prefix.
120 for(
unsigned int i = 0; i < column_size; ++i)
122 packed_counter_type counter = packed_counters[i];
123 packed_counters[i] = exclusive_prefix;
124 exclusive_prefix += counter;
// Scans the column_size consecutive packed counters that start at index
// flat_id * column_size.
//   flat_id - index of the calling thread
//   storage - counter/scan storage of the block
// Note that this range differs from the per-thread layout used by
// get_digit_counter(), so the words read here were written by other threads.
128 ROCPRIM_DEVICE ROCPRIM_INLINE
void scan_counters(
const unsigned int flat_id,
129 storage_type_& storage)
131 packed_counter_type*
const shared_counters
132 = &storage.packed_counters[flat_id * column_size];
// Compile-time choice: scan a thread-local copy or scan the storage directly.
134 if ROCPRIM_IF_CONSTEXPR(MemoizeOuterScan)
136 packed_counter_type local_counters[column_size];
// Copy the counters out of storage into the local array.
138 for(
unsigned int i = 0; i < column_size; ++i)
140 local_counters[i] = shared_counters[i];
143 scan_block_counters(storage, local_counters);
// Write the scanned values back to storage.
146 for(
unsigned int i = 0; i < column_size; ++i)
148 shared_counters[i] = local_counters[i];
// NOTE(review): the listing skips source lines 149-152; the call below is
// presumably the else branch (scan directly in storage when MemoizeOuterScan
// is false). Confirm against the original header.
153 scan_block_counters(storage, shared_counters);
// Computes a rank for every key, using digit_extractor to obtain each key's digit.
//   keys            - the ItemsPerThread keys of the calling thread
//   ranks           - output; ranks[i] receives the rank of keys[i]
//   storage         - counter/scan storage of the block
//   digit_extractor - callable mapping a key to its digit (used as an index
//                     into the counters via get_digit_counter())
157 template<
typename Key,
unsigned int ItemsPerThread,
typename DigitExtractor>
158 ROCPRIM_DEVICE
void rank_keys_impl(
const Key (&keys)[ItemsPerThread],
159 unsigned int (&ranks)[ItemsPerThread],
160 storage_type_& storage,
161 DigitExtractor digit_extractor)
// Counters are digit_counter_type (unsigned short), so the total number of
// items in the block must stay below 1 << 16.
163 static_assert(block_size * ItemsPerThread < 1u << 16,
164 "The maximum amout of items that block_radix_rank can rank is 2**16.");
165 const unsigned int flat_id
166 = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
// Clear the packed words that hold this thread's counters.
168 reset_counters(flat_id, storage);
// thread_prefixes[i]: value of the key's counter before this key was counted.
// digit_counters[i]:  address of that counter, re-read after the scan.
170 digit_counter_type thread_prefixes[ItemsPerThread];
171 digit_counter_type* digit_counters[ItemsPerThread];
// Count each key in this thread's counter for its digit; the post-increment
// yields how many earlier keys of this thread had the same digit.
174 for(
unsigned int i = 0; i < ItemsPerThread; ++i)
176 const unsigned int digit = digit_extractor(keys[i]);
177 digit_counters[i] = &get_digit_counter(digit, flat_id, storage);
178 thread_prefixes[i] = (*digit_counters[i])++;
// NOTE(review): the listing skips source lines 179-182 and 184-187.
// scan_counters() reads and writes counters owned by other threads, so block
// barriers are presumably present in the omitted lines before and after this
// call; confirm against the original header.
183 scan_counters(flat_id, storage);
// Final rank = scanned counter value for the key's digit + the key's position
// among this thread's keys with the same digit.
188 for(
unsigned int i = 0; i < ItemsPerThread; ++i)
190 ranks[i] = thread_prefixes[i] + *digit_counters[i];
// Bit-range overload: encodes the keys with radix_key_codec<Key, Descending>
// and ranks them by the digit that the codec extracts for
// (begin_bit, pass_bits).
//   keys      - the ItemsPerThread keys of the calling thread
//   ranks     - output ranks
//   storage   - counter/scan storage of the block
//   begin_bit - forwarded to key_codec::extract_digit
//   pass_bits - forwarded to key_codec::extract_digit
194 template<
bool Descending,
typename Key,
unsigned int ItemsPerThread>
195 ROCPRIM_DEVICE
void rank_keys_impl(
const Key (&keys)[ItemsPerThread],
196 unsigned int (&ranks)[ItemsPerThread],
197 storage_type_& storage,
198 const unsigned int begin_bit,
199 const unsigned int pass_bits)
201 using key_codec = ::rocprim::detail::radix_key_codec<Key, Descending>;
202 using bit_key_type =
typename key_codec::bit_key_type;
// Encoded representation of every key.
204 bit_key_type bit_keys[ItemsPerThread];
206 for(
unsigned int i = 0; i < ItemsPerThread; ++i)
208 bit_keys[i] = key_codec::encode(keys[i]);
// Delegate to the digit-extractor overload, with a lambda that extracts the
// digit from an encoded key.
// NOTE(review): the listing skips source lines 212-213; the arguments between
// bit_keys and the lambda (presumably ranks and storage) are not visible here.
211 rank_keys_impl(bit_keys,
214 [begin_bit, pass_bits](
const bit_key_type& key)
215 {
return key_codec::extract_digit(key, begin_bit, pass_bits); });
218 template<
unsigned int ItemsPerThread>
219 ROCPRIM_DEVICE
void digit_prefix_count(
unsigned int (&prefix)[digits_per_thread],
220 unsigned int (&counts)[digits_per_thread],
221 storage_type_& storage)
223 const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
226 for(
unsigned int i = 0; i < digits_per_thread; ++i)
228 const unsigned int digit = flat_id * digits_per_thread + i;
229 if(radix_digits % block_size == 0 || digit < radix_digits)
232 prefix[i] = get_digit_counter(digit, 0, storage);
235 const unsigned int next_prefix = digit + 1 == radix_digits
236 ? block_size * ItemsPerThread
237 : get_digit_counter(digit + 1, 0, storage);
238 counts[i] = next_prefix - prefix[i];
// Storage type exposed to callers: storage_type_ wrapped in raw_storage. The
// public methods obtain the underlying storage_type_ through storage.get().
244 using storage_type = ::rocprim::detail::raw_storage<storage_type_>;
246 template<
typename Key,
unsigned ItemsPerThread>
247 ROCPRIM_DEVICE
void rank_keys(
const Key (&keys)[ItemsPerThread],
248 unsigned int (&ranks)[ItemsPerThread],
249 storage_type& storage,
250 unsigned int begin_bit = 0,
251 unsigned int pass_bits = RadixBits)
253 rank_keys_impl<false>(keys, ranks, storage.get(), begin_bit, pass_bits);
256 template<
typename Key,
unsigned ItemsPerThread>
257 ROCPRIM_DEVICE
void rank_keys_desc(
const Key (&keys)[ItemsPerThread],
258 unsigned int (&ranks)[ItemsPerThread],
259 storage_type& storage,
260 unsigned int begin_bit = 0,
261 unsigned int pass_bits = RadixBits)
263 rank_keys_impl<true>(keys, ranks, storage.get(), begin_bit, pass_bits);
266 template<
typename Key,
unsigned ItemsPerThread,
typename DigitExtractor>
267 ROCPRIM_DEVICE
void rank_keys(
const Key (&keys)[ItemsPerThread],
268 unsigned int (&ranks)[ItemsPerThread],
269 storage_type& storage,
270 DigitExtractor digit_extractor)
272 rank_keys_impl(keys, ranks, storage.get(), digit_extractor);
// Descending variant of the digit-extractor overload: every extracted digit d
// is replaced by radix_digits - 1 - d before ranking.
//   keys            - the ItemsPerThread keys of the calling thread
//   ranks           - output ranks
//   storage         - temporary storage of the block
//   digit_extractor - callable mapping a key to its digit
275 template<
typename Key,
unsigned ItemsPerThread,
typename DigitExtractor>
276 ROCPRIM_DEVICE
void rank_keys_desc(
const Key (&keys)[ItemsPerThread],
277 unsigned int (&ranks)[ItemsPerThread],
278 storage_type& storage,
279 DigitExtractor digit_extractor)
// NOTE(review): the listing skips source lines 280-283, so the call that
// receives this lambda is not visible; presumably it is rank_keys_impl() with
// keys, ranks and storage.get(). Confirm against the original header.
284 [&digit_extractor](
const Key& key)
286 const unsigned int digit = digit_extractor(key);
// Mirror the digit so that the largest digit is ranked first.
287 return radix_digits - 1 - digit;
291 template<
typename Key,
unsigned ItemsPerThread,
typename DigitExtractor>
292 ROCPRIM_DEVICE
void rank_keys(
const Key (&keys)[ItemsPerThread],
293 unsigned int (&ranks)[ItemsPerThread],
294 storage_type& storage,
295 DigitExtractor digit_extractor,
296 unsigned int (&prefix)[digits_per_thread],
297 unsigned int (&counts)[digits_per_thread])
299 rank_keys(keys, ranks, storage, digit_extractor);
300 digit_prefix_count<ItemsPerThread>(prefix, counts, storage.get());
306 END_ROCPRIM_NAMESPACE
The block_scan class is a block-level parallel primitive which provides methods for performing inclusive and exclusive scan operations of items partitioned across threads in a block.
Definition: block_scan.hpp:134
Definition: block_radix_rank_basic.hpp:42
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216