21 #ifndef ROCPRIM_WARP_DETAIL_WARP_SORT_STABLE_HPP_ 22 #define ROCPRIM_WARP_DETAIL_WARP_SORT_STABLE_HPP_ 24 #include <type_traits> 26 #include "../../config.hpp" 27 #include "../../detail/various.hpp" 29 #include "../../functional.hpp" 30 #include "../../intrinsics.hpp" 32 BEGIN_ROCPRIM_NAMESPACE
37 template<
typename Key,
38 unsigned int BlockSize,
39 unsigned int WarpSize,
40 unsigned int ItemsPerThread,
45 constexpr
static unsigned int items_per_block = BlockSize * ItemsPerThread;
46 constexpr
static bool with_values = !std::is_same<Value, rocprim::empty_type>::value;
48 struct storage_type_keys
50 Key keys[items_per_block];
53 struct storage_type_keys_values
55 Key keys[items_per_block];
56 Value values[items_per_block];
60 = std::conditional_t<with_values, storage_type_keys_values, storage_type_keys>;
63 template<
bool is_incomplete,
typename CompareFunction>
64 ROCPRIM_DEVICE ROCPRIM_INLINE
void thread_sort(Key (&thread_keys)[ItemsPerThread],
65 CompareFunction compare_function,
66 const unsigned int input_size = items_per_block)
69 const auto thread_input_size = thread_offset > input_size ? 0 : input_size - thread_offset;
72 for(
auto i = 0u; i < ItemsPerThread; ++i)
75 for(
auto j = i & 1u; j < ItemsPerThread - 1u; j += 2u)
77 if(j + 1 < thread_input_size
78 && compare_function(thread_keys[j + 1], thread_keys[j]))
87 template<
bool is_incomplete,
typename CompareFunction>
88 ROCPRIM_DEVICE ROCPRIM_INLINE
void thread_sort(Key (&thread_keys)[ItemsPerThread],
89 Value (&thread_values)[ItemsPerThread],
90 CompareFunction compare_function,
91 const unsigned int input_size = items_per_block)
94 const auto thread_input_size = thread_offset > input_size ? 0 : input_size - thread_offset;
97 for(
auto i = 0u; i < ItemsPerThread; ++i)
100 for(
auto j = i & 1u; j < ItemsPerThread - 1u; j += 2u)
102 if(j + 1 < thread_input_size
103 && compare_function(thread_keys[j + 1], thread_keys[j]))
112 template<
bool is_incomplete,
class BinaryFunction>
113 ROCPRIM_DEVICE ROCPRIM_INLINE
void merge_path_merge(Key (&thread_keys)[ItemsPerThread],
114 storage_type_& storage,
115 BinaryFunction compare_function,
116 const unsigned int input_size
123 const auto warp_input_size = warp_offset > input_size ? 0 : input_size - warp_offset;
124 const auto shared_keys = &storage.keys[warp_offset];
127 for(
auto partition_size = 1u; partition_size < WarpSize; partition_size <<= 1u)
130 for(
auto i = 0u; i < ItemsPerThread; ++i)
132 shared_keys[ItemsPerThread * lane + i] = thread_keys[i];
137 const auto size = partition_size * ItemsPerThread;
138 const auto mask = (partition_size * 2) - 1;
140 const auto start = lane & ~mask;
141 const auto keys1_begin = start * ItemsPerThread;
142 const auto keys1_end =
std::min(keys1_begin + size, warp_input_size);
143 const auto keys2_begin = keys1_end;
144 const auto keys2_end =
std::min(keys2_begin + size, warp_input_size);
146 const auto diag =
std::min(ItemsPerThread * (mask & lane), warp_input_size);
147 const auto partition = merge_path(&shared_keys[keys1_begin],
148 &shared_keys[keys2_begin],
149 keys1_end - keys1_begin,
150 keys2_end - keys2_begin,
154 const auto keys1_merge_begin = keys1_begin +
partition;
155 const auto keys2_merge_begin = keys2_begin + diag -
partition;
164 serial_merge(shared_keys, thread_keys, range, compare_function);
170 template<
bool is_incomplete,
class BinaryFunction>
171 ROCPRIM_DEVICE ROCPRIM_INLINE
void merge_path_merge(Key (&thread_keys)[ItemsPerThread],
172 Value (&thread_values)[ItemsPerThread],
173 storage_type_& storage,
174 BinaryFunction compare_function,
175 const unsigned int input_size
182 const auto warp_input_size = warp_offset > input_size ? 0 : input_size - warp_offset;
183 const auto shared_keys = &storage.keys[warp_offset];
184 const auto shared_values = &storage.values[warp_offset];
187 for(
auto partition_size = 1u; partition_size < WarpSize; partition_size <<= 1u)
190 for(
auto i = 0u; i < ItemsPerThread; ++i)
192 shared_keys[ItemsPerThread * lane + i] = thread_keys[i];
193 shared_values[ItemsPerThread * lane + i] = thread_values[i];
198 const auto size = partition_size * ItemsPerThread;
199 const auto mask = (partition_size * 2) - 1;
201 const auto start = lane & ~mask;
202 const auto keys1_begin = start * ItemsPerThread;
203 const auto keys1_end =
std::min(keys1_begin + size, warp_input_size);
204 const auto keys2_begin = keys1_end;
205 const auto keys2_end =
std::min(keys2_begin + size, warp_input_size);
207 const auto diag =
std::min(ItemsPerThread * (mask & lane), warp_input_size);
208 const auto partition = merge_path(&shared_keys[keys1_begin],
209 &shared_keys[keys2_begin],
210 keys1_end - keys1_begin,
211 keys2_end - keys2_begin,
215 const auto keys1_merge_begin = keys1_begin +
partition;
216 const auto keys2_merge_begin = keys2_begin + diag -
partition;
225 serial_merge(shared_keys,
237 static_assert(detail::is_power_of_two(WarpSize),
"WarpSize must be power of 2");
241 template<
class BinaryFunction>
242 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key& thread_key, BinaryFunction compare_function)
245 sort(thread_key, storage, compare_function);
248 template<
class BinaryFunction>
249 ROCPRIM_DEVICE ROCPRIM_INLINE
void 250 sort(Key& thread_key,
storage_type& storage, BinaryFunction compare_function)
252 Key thread_keys[] = {thread_key};
253 sort(thread_keys, storage, compare_function);
256 template<
class BinaryFunction>
257 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key (&thread_keys)[ItemsPerThread],
258 BinaryFunction compare_function)
261 sort(thread_keys, storage, compare_function);
264 template<
class BinaryFunction>
265 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key (&thread_keys)[ItemsPerThread],
267 BinaryFunction compare_function)
269 thread_sort<false>(thread_keys, compare_function);
271 merge_path_merge<false>(thread_keys, storage.get(), compare_function);
275 template<
class BinaryFunction,
class V = Value>
276 ROCPRIM_DEVICE ROCPRIM_INLINE
void 277 sort(Key& thread_key, Value& thread_value, BinaryFunction compare_function)
279 Key thread_keys[] = {thread_key};
280 Value thread_values[] = {thread_value};
281 sort(thread_keys, thread_values, compare_function);
284 template<
class BinaryFunction>
285 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key& thread_key,
288 BinaryFunction compare_function)
290 Key thread_keys[] = {thread_key};
291 Value thread_values[] = {thread_value};
292 sort(thread_keys, thread_values, storage, compare_function);
295 template<
class BinaryFunction>
296 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key (&thread_keys)[ItemsPerThread],
297 Value (&thread_values)[ItemsPerThread],
298 BinaryFunction compare_function)
301 sort(thread_keys, thread_values, storage, compare_function);
304 template<
class BinaryFunction>
305 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key (&thread_keys)[ItemsPerThread],
307 const unsigned int input_size,
308 BinaryFunction compare_function)
310 thread_sort<true>(thread_keys, compare_function, input_size);
312 merge_path_merge<true>(thread_keys, storage.get(), compare_function, input_size);
317 template<
class BinaryFunction>
318 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key (&thread_keys)[ItemsPerThread],
319 Value (&thread_values)[ItemsPerThread],
321 BinaryFunction compare_function)
323 thread_sort<false>(thread_keys, thread_values, compare_function);
325 merge_path_merge<false>(thread_keys, thread_values, storage.get(), compare_function);
329 template<
class BinaryFunction>
330 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key (&thread_keys)[ItemsPerThread],
331 Value (&thread_values)[ItemsPerThread],
333 const unsigned int input_size,
334 BinaryFunction compare_function)
336 thread_sort<true>(thread_keys, thread_values, compare_function, input_size);
338 merge_path_merge<true>(thread_keys,
348 template<
typename Key,
unsigned int BlockSize,
unsigned int WarpSize,
typename Value>
352 constexpr
static unsigned items_per_thread = 1;
364 template<
bool is_incomplete,
typename BinaryFunction>
365 ROCPRIM_DEVICE ROCPRIM_INLINE
int merge_rank(
const unsigned int m,
367 BinaryFunction compare_function,
368 const unsigned int valid_items = BlockSize)
373 const auto n = m * 2;
375 const auto index = lane % n;
377 const auto is_lower = index < m;
379 const auto base = lane - index;
383 auto begin = base + (is_lower ? m : 0);
385 auto end = begin + m;
390 for(
auto i = 1u; i <= m; i <<= 1u)
392 const auto mid = (begin + end) / 2;
396 auto key_a = thread_key;
401 const auto mid_smaller = ((!is_incomplete || (lane < valid_items && mid < valid_items))
402 && compare_function(key_a, key_b))
405 if(mid_smaller && begin != end)
416 return index + begin - m;
420 static_assert(detail::is_power_of_two(WarpSize),
"WarpSize must be power of 2");
424 template<
class BinaryFunction>
425 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key& thread_key, BinaryFunction compare_function)
428 for(
auto i = 1u; i < WarpSize; i <<= 1u)
430 const auto thread_rank = merge_rank<false>(i, thread_key, compare_function);
435 template<
class BinaryFunction>
436 ROCPRIM_DEVICE ROCPRIM_INLINE
void 437 sort(Key& thread_key,
storage_type& storage, BinaryFunction compare_function)
440 sort(thread_key, compare_function);
443 template<
class BinaryFunction>
444 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key (&thread_keys)[items_per_thread],
445 BinaryFunction compare_function)
447 sort(thread_keys[0], compare_function);
450 template<
class BinaryFunction>
451 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key (&thread_keys)[items_per_thread],
453 BinaryFunction compare_function)
455 sort(thread_keys[0], storage, compare_function);
458 template<
class BinaryFunction>
459 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key (&thread_keys)[items_per_thread],
461 const unsigned int input_size,
462 BinaryFunction compare_function)
464 sort(thread_keys[0], storage, input_size, compare_function);
467 template<
class BinaryFunction>
468 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key& thread_key,
470 const unsigned int input_size,
471 BinaryFunction compare_function)
476 const auto warp_input_size = warp_offset > input_size ? 0 : input_size - warp_offset;
479 for(
auto i = 1u; i < WarpSize; i <<= 1u)
481 const auto thread_rank
482 = merge_rank<true>(i, thread_key, compare_function, warp_input_size);
487 template<
class BinaryFunction,
class V = Value>
488 ROCPRIM_DEVICE ROCPRIM_INLINE
typename std::enable_if<(sizeof(V) <= sizeof(int))>::type
489 sort(Key& thread_key, V& thread_value, BinaryFunction compare_function)
492 for(
auto i = 1u; i < WarpSize; i <<= 1u)
494 const auto thread_rank = merge_rank<false>(i, thread_key, compare_function);
500 template<
class BinaryFunction,
class V = Value>
501 ROCPRIM_DEVICE ROCPRIM_INLINE
typename std::enable_if<!(sizeof(V) <= sizeof(int))>::type
502 sort(Key& thread_key, V& thread_value, BinaryFunction compare_function)
506 sort(thread_key, value_index, compare_function);
511 template<
class BinaryFunction>
512 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key& thread_key,
515 BinaryFunction compare_function)
518 sort(compare_function, thread_key, thread_value);
521 template<
class BinaryFunction>
522 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key (&thread_keys)[items_per_thread],
523 Value (&thread_values)[items_per_thread],
524 BinaryFunction compare_function)
526 sort(thread_keys[0], thread_values[0], compare_function);
529 template<
class BinaryFunction>
530 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key (&thread_keys)[items_per_thread],
531 Value (&thread_values)[items_per_thread],
533 BinaryFunction compare_function)
536 sort(thread_keys[0], thread_values[0], compare_function);
539 template<
class BinaryFunction>
540 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key (&thread_keys)[items_per_thread],
541 Value (&thread_values)[items_per_thread],
543 unsigned int input_size,
544 BinaryFunction compare_function)
547 sort(thread_keys[0], thread_values[0], storage, input_size, compare_function);
550 template<
class BinaryFunction,
typename V = Value>
551 ROCPRIM_DEVICE ROCPRIM_INLINE
typename std::enable_if<(sizeof(V) <= sizeof(int))>::type
552 sort(Key& thread_key,
555 unsigned int input_size,
556 BinaryFunction compare_function)
561 const auto warp_input_size = warp_offset > input_size ? 0 : input_size - warp_offset;
564 for(
auto i = 1u; i < WarpSize; i <<= 1u)
566 const auto thread_rank
567 = merge_rank<true>(i, thread_key, compare_function, warp_input_size);
573 template<
class BinaryFunction,
typename V = Value>
574 ROCPRIM_DEVICE ROCPRIM_INLINE
typename std::enable_if<!(sizeof(V) <= sizeof(int))>::type
575 sort(Key& thread_key,
578 unsigned int input_size,
579 BinaryFunction compare_function)
583 sort(thread_key, value_index, storage, input_size, compare_function);
591 END_ROCPRIM_NAMESPACE
593 #endif // ROCPRIM_WARP_DETAIL_WARP_SORT_SHUFFLE_HPP_ ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int flat_block_thread_id()
Returns flat (linear, 1D) thread identifier in a multidimensional block (tile).
Definition: thread.hpp:106
Definition: warp_sort_stable.hpp:42
ROCPRIM_DEVICE ROCPRIM_INLINE T warp_permute(const T &input, const int dst_lane, const int width=device_warp_size())
Permute items across the threads in a warp.
Definition: warp_shuffle.hpp:273
hipError_t partition(void *temporary_storage, size_t &storage_size, InputIterator input, FlagIterator flags, OutputIterator output, SelectedCountOutputIterator selected_count_output, const size_t size, const hipStream_t stream=0, const bool debug_synchronous=false)
Parallel select primitive for device level using range of flags.
Definition: device_partition.hpp:721
ROCPRIM_DEVICE ROCPRIM_INLINE T warp_shuffle(const T &input, const int src_lane, const int width=device_warp_size())
Shuffle for any data type.
Definition: warp_shuffle.hpp:172
ROCPRIM_DEVICE ROCPRIM_INLINE constexpr unsigned int device_warp_size()
Returns a number of threads in a hardware warp for the actual target.
Definition: thread.hpp:70
ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
Returns the minimum of its arguments.
Definition: functional.hpp:63
ROCPRIM_DEVICE ROCPRIM_INLINE void wave_barrier()
Synchronize all threads in the wavefront.
Definition: thread.hpp:235
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
const unsigned int warp_id
Returns warp id in a block (tile).
Definition: benchmark_warp_exchange.cpp:153
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216
Definition: merge_path.hpp:33
ROCPRIM_HOST_DEVICE void swap(T &a, T &b)
Swaps two values.
Definition: functional.hpp:71
Definition: various.hpp:52
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int lane_id()
Returns thread identifier in a warp.
Definition: thread.hpp:93