21 #ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_RANK_MATCH_HPP_ 22 #define ROCPRIM_BLOCK_DETAIL_BLOCK_RANK_MATCH_HPP_ 24 #include "../../config.hpp" 25 #include "../../detail/various.hpp" 26 #include "../../functional.hpp" 27 #include "../../types.hpp" 29 #include "../../detail/radix_sort.hpp" 31 #include "../block_scan.hpp" 33 BEGIN_ROCPRIM_NAMESPACE
// Class template parameters:
//   BlockSizeX, BlockSizeY, BlockSizeZ - block dimensions; their product is `block_size`.
//   RadixBits - width in bits of the digit being ranked (1 << RadixBits distinct digits).
38 template<
unsigned int BlockSizeX,
39 unsigned int RadixBits,
40 unsigned int BlockSizeY = 1,
41 unsigned int BlockSizeZ = 1>
// NOTE(review): the class-head line that follows this parameter list is not visible in this listing.
// Type of the per-(digit, warp) counters and of the values scanned by block_scan_type.
44 using digit_counter_type =
unsigned int;
// Block-wide scan (warp-scan algorithm) used to turn per-(digit, warp) counts into offsets.
// NOTE(review): the remaining template arguments of this alias are not visible in this listing.
46 using block_scan_type = ::rocprim::block_scan<digit_counter_type,
48 ::rocprim::block_scan_algorithm::using_warp_scan,
52 static constexpr
unsigned int block_size = BlockSizeX * BlockSizeY * BlockSizeZ;
53 static constexpr
unsigned int radix_digits = 1 << RadixBits;
55 static constexpr
unsigned int warp_size = warpSize;
57 static constexpr
unsigned int warps
58 = ::rocprim::detail::ceiling_div(block_size, warp_size) | 1u;
60 static constexpr
unsigned int active_counters = warps * radix_digits;
64 static constexpr
unsigned int counters_per_thread
65 = ::rocprim::detail::ceiling_div(active_counters, block_size);
67 static constexpr
unsigned int counters = counters_per_thread * block_size;
70 constexpr
static unsigned int digits_per_thread
71 = ::rocprim::detail::ceiling_div(radix_digits, block_size);
76 typename block_scan_type::storage_type
block_scan;
77 digit_counter_type counters[counters];
80 ROCPRIM_DEVICE ROCPRIM_INLINE digit_counter_type&
81 get_digit_counter(
const unsigned int digit,
const unsigned int warp, storage_type_& storage)
83 return storage.counters[digit * warps + warp];
// Generic ranking core: ranks each key by the digit that `digit_extractor(key)` returns.
// The visible statements form four phases:
//   1. every thread zeroes its contiguous slice of counters_per_thread counters;
//   2. per key, the (digit, warp) counter is located and the key's offset among the
//      same-digit keys of its warp is recorded in `ranks`;
//   3. an exclusive block-wide scan runs over all counters, and the result is written back in place;
//   4. each key's scanned (digit, warp) base offset is added to its in-warp offset.
// keys            - keys to rank; only read through digit_extractor.
// ranks           - [out] rank of each key.
// storage         - counters and block_scan storage, indexed by all threads of the block via flat_id.
// digit_extractor - maps a key to a digit. The digit is used as a counter index with no bounds
//                   check, so it must be < radix_digits.
// NOTE(review): this listing omits several lines. Among them are the declarations of warp_id,
// peer_mask, bit_set_mask, digit_count and peer_digit_prefix, and any barriers between the
// phases. The comments below describe only the visible statements.
86 template<
typename Key,
unsigned int ItemsPerThread,
typename DigitExtractor>
87 ROCPRIM_DEVICE
void rank_keys_impl(
const Key (&keys)[ItemsPerThread],
88 unsigned int (&ranks)[ItemsPerThread],
89 storage_type_& storage,
90 DigitExtractor digit_extractor)
// Linear index of the calling thread within the block.
92 const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
// Phase 1: zero this thread's slice of counters (padding entries included).
96 for(
unsigned int i = 0; i < counters_per_thread; ++i)
98 storage.counters[flat_id * counters_per_thread + i] = 0;
// Counter used by each key, kept so phase 4 can read its scanned value.
103 digit_counter_type* digit_counters[ItemsPerThread];
// Phase 2: handle the thread's keys one at a time.
106 for(
unsigned int i = 0; i < ItemsPerThread; ++i)
109 const unsigned int digit = digit_extractor(keys[i]);
// NOTE(review): warp_id is not declared in the visible lines; presumably the calling thread's warp index.
112 digit_counters[i] = &get_digit_counter(digit, warp_id, storage);
// Value of this (digit, warp) counter before the current key is counted.
113 const digit_counter_type warp_digit_prefix = *digit_counters[i];
// Narrow peer_mask one digit bit at a time to the lanes whose digit agrees on that bit.
// NOTE(review): bit_set_mask and the initial peer_mask are not visible; presumably a warp ballot
// of bit_set and an all-lanes mask, respectively.
119 for(
unsigned int b = 0; b < RadixBits; ++b)
121 const unsigned int bit_set = digit & (1u << b);
123 peer_mask &= (bit_set ? bit_set_mask : ~bit_set_mask);
// Only the lane whose peer_digit_prefix is 0 writes the updated count back to the counter.
// NOTE(review): digit_count and peer_digit_prefix are not declared here; presumably the number of
// peers with this digit and this lane's position among them.
135 if(peer_digit_prefix == 0)
137 *digit_counters[i] = warp_digit_prefix + digit_count;
// Offset of this key among the same-digit keys of its warp.
143 ranks[i] = warp_digit_prefix + peer_digit_prefix;
// Phase 3: load this thread's slice of counters.
149 digit_counter_type scan_counters[counters_per_thread];
152 for(
unsigned int i = 0; i < counters_per_thread; ++i)
154 scan_counters[i] = storage.counters[flat_id * counters_per_thread + i];
// Exclusive scan with initial value 0 across the block. This assumes block_scan treats the
// per-thread arrays in blocked order, so the scan follows the counters' array order. Because that
// order is digit-major, each (digit, warp) entry becomes the number of keys with a smaller digit
// plus the number of same-digit keys in lower-numbered warps.
157 block_scan_type().exclusive_scan(scan_counters, scan_counters, 0, storage.block_scan);
// Write the scanned offsets back so any thread can read any (digit, warp) entry.
160 for(
unsigned int i = 0; i < counters_per_thread; ++i)
162 storage.counters[flat_id * counters_per_thread + i] = scan_counters[i];
// Phase 4: add the scanned base offset of each key's (digit, warp) counter.
169 for(
unsigned int i = 0; i < ItemsPerThread; ++i)
171 ranks[i] += *digit_counters[i];
175 template<
bool Descending,
typename Key,
unsigned int ItemsPerThread>
176 ROCPRIM_DEVICE
void rank_keys_impl(
const Key (&keys)[ItemsPerThread],
177 unsigned int (&ranks)[ItemsPerThread],
178 storage_type_& storage,
179 const unsigned int begin_bit,
180 const unsigned int pass_bits)
182 using key_codec = ::rocprim::detail::radix_key_codec<Key, Descending>;
183 using bit_key_type =
typename key_codec::bit_key_type;
185 bit_key_type bit_keys[ItemsPerThread];
187 for(
unsigned int i = 0; i < ItemsPerThread; ++i)
189 bit_keys[i] = key_codec::encode(keys[i]);
192 rank_keys_impl(bit_keys,
195 [begin_bit, pass_bits](
const bit_key_type& key)
196 {
return key_codec::extract_digit(key, begin_bit, pass_bits); });
199 template<
unsigned int ItemsPerThread>
200 ROCPRIM_DEVICE
void digit_prefix_count(
unsigned int (&prefix)[digits_per_thread],
201 unsigned int (&counts)[digits_per_thread],
202 storage_type_& storage)
204 const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
207 for(
unsigned int i = 0; i < digits_per_thread; ++i)
209 const unsigned int digit = flat_id * digits_per_thread + i;
210 if(radix_digits % block_size == 0 || digit < radix_digits)
213 prefix[i] = get_digit_counter(digit, 0, storage);
216 const unsigned int next_prefix = digit + 1 == radix_digits
217 ? block_size * ItemsPerThread
218 : get_digit_counter(digit + 1, 0, storage);
219 counts[i] = next_prefix - prefix[i];
225 using storage_type = ::rocprim::detail::raw_storage<storage_type_>;
227 template<
typename Key,
unsigned ItemsPerThread>
228 ROCPRIM_DEVICE
void rank_keys(
const Key (&keys)[ItemsPerThread],
229 unsigned int (&ranks)[ItemsPerThread],
230 storage_type& storage,
231 unsigned int begin_bit = 0,
232 unsigned int pass_bits = RadixBits)
234 rank_keys_impl<false>(keys, ranks, storage.get(), begin_bit, pass_bits);
237 template<
typename Key,
unsigned ItemsPerThread>
238 ROCPRIM_DEVICE
void rank_keys_desc(
const Key (&keys)[ItemsPerThread],
239 unsigned int (&ranks)[ItemsPerThread],
240 storage_type& storage,
241 unsigned int begin_bit = 0,
242 unsigned int pass_bits = RadixBits)
244 rank_keys_impl<true>(keys, ranks, storage.get(), begin_bit, pass_bits);
247 template<
typename Key,
unsigned ItemsPerThread,
typename DigitExtractor>
248 ROCPRIM_DEVICE
void rank_keys(
const Key (&keys)[ItemsPerThread],
249 unsigned int (&ranks)[ItemsPerThread],
250 storage_type& storage,
251 DigitExtractor digit_extractor)
253 rank_keys_impl(keys, ranks, storage.get(), digit_extractor);
256 template<
typename Key,
unsigned ItemsPerThread,
typename DigitExtractor>
257 ROCPRIM_DEVICE
void rank_keys(
const Key (&keys)[ItemsPerThread],
258 unsigned int (&ranks)[ItemsPerThread],
259 storage_type& storage,
260 DigitExtractor digit_extractor,
261 unsigned int (&prefix)[digits_per_thread],
262 unsigned int (&counts)[digits_per_thread])
264 rank_keys(keys, ranks, storage, digit_extractor);
265 digit_prefix_count<ItemsPerThread>(prefix, counts, storage.get());
271 END_ROCPRIM_NAMESPACE
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int masked_bit_count(lane_mask_type x, unsigned int add=0)
Masked bit count.
Definition: warp.hpp:48
The block_scan class is a block level parallel primitive which provides methods for performing inclus...
Definition: block_scan.hpp:134
Definition: block_radix_rank_match.hpp:42
ROCPRIM_DEVICE ROCPRIM_INLINE void wave_barrier()
Synchronize all threads in the wavefront.
Definition: thread.hpp:235
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
const unsigned int warp_id
Returns warp id in a block (tile).
Definition: benchmark_warp_exchange.cpp:153
ROCPRIM_DEVICE ROCPRIM_INLINE lane_mask_type ballot(int predicate)
Evaluate predicate for all active work-items in the warp and return an integer whose i-th bit is set ...
Definition: warp.hpp:38
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216
unsigned long long int lane_mask_type
The lane_mask_type is an integer that contains one bit per thread.
Definition: types.hpp:164
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int bit_count(unsigned int x)
Bit count.
Definition: bit.hpp:42