21 #ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_SORT_SHARED_HPP_ 22 #define ROCPRIM_BLOCK_DETAIL_BLOCK_SORT_SHARED_HPP_ 24 #include <type_traits> 26 #include "../../config.hpp" 27 #include "../../detail/various.hpp" 29 #include "../../intrinsics.hpp" 30 #include "../../functional.hpp" 32 #include "../../warp/warp_sort.hpp" 34 BEGIN_ROCPRIM_NAMESPACE
41 unsigned int BlockSizeX,
42 unsigned int BlockSizeY,
43 unsigned int BlockSizeZ,
44 unsigned int ItemsPerThread,
49 static constexpr
unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
50 static constexpr
unsigned int ItemsPerBlock = BlockSize * ItemsPerThread;
52 template<
class KeyType,
class ValueType>
55 KeyType key[BlockSize * ItemsPerThread];
56 ValueType value[BlockSize * ItemsPerThread];
59 template<
class KeyType>
60 struct storage_type_<KeyType, empty_type>
62 KeyType key[BlockSize * ItemsPerThread];
68 template <
class BinaryFunction>
69 ROCPRIM_DEVICE ROCPRIM_INLINE
70 void sort(Key& thread_key,
72 BinaryFunction compare_function)
74 this->sort_impl<BlockSize, ItemsPerThread>(
75 ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
81 template<
class BinaryFunction>
82 ROCPRIM_DEVICE ROCPRIM_INLINE
83 void sort(Key (&thread_keys)[ItemsPerThread],
85 BinaryFunction compare_function)
87 this->sort_impl<BlockSize, ItemsPerThread>(
88 ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
94 template<
class BinaryFunction>
95 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
96 void sort(Key& thread_key,
97 BinaryFunction compare_function)
100 this->sort(thread_key, storage, compare_function);
103 template<
class BinaryFunction>
104 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
105 void sort(Key (&thread_keys)[ItemsPerThread],
106 BinaryFunction compare_function)
109 this->sort(thread_keys, storage, compare_function);
112 template<
class BinaryFunction>
113 ROCPRIM_DEVICE ROCPRIM_INLINE
114 void sort(Key& thread_key,
117 BinaryFunction compare_function)
119 this->sort_impl<BlockSize, ItemsPerThread>(
120 ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
127 template<
class BinaryFunction>
128 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key (&thread_keys)[ItemsPerThread],
129 Value (&thread_values)[ItemsPerThread],
131 BinaryFunction compare_function)
133 this->sort_impl<BlockSize, ItemsPerThread>(
134 ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
141 template<
class BinaryFunction>
142 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
143 void sort(Key& thread_key,
145 BinaryFunction compare_function)
148 this->sort(thread_key, thread_value, storage, compare_function);
151 template<
class BinaryFunction>
152 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
153 void sort(Key (&thread_keys)[ItemsPerThread],
154 Value (&thread_values)[ItemsPerThread],
155 BinaryFunction compare_function)
158 this->sort(thread_keys, thread_values, storage, compare_function);
161 template<
class BinaryFunction>
162 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key& thread_key,
164 const unsigned int size,
165 BinaryFunction compare_function)
167 this->sort_impl(::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
174 template<
class BinaryFunction>
175 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key (&thread_keys)[ItemsPerThread],
177 const unsigned int size,
178 BinaryFunction compare_function)
180 this->sort_impl(::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
187 template<
class BinaryFunction>
188 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key& thread_key,
191 const unsigned int size,
192 BinaryFunction compare_function)
194 this->sort_impl(::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
202 template<
class BinaryFunction>
203 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort(Key (&thread_keys)[ItemsPerThread],
204 Value (&thread_values)[ItemsPerThread],
206 const unsigned int size,
207 BinaryFunction compare_function)
209 this->sort_impl(::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
218 ROCPRIM_DEVICE ROCPRIM_INLINE
219 void copy_to_shared(Key& k,
const unsigned int flat_tid,
storage_type& storage)
221 storage_type_<Key, Value>& storage_ = storage.get();
222 storage_.key[flat_tid] = k;
226 ROCPRIM_DEVICE ROCPRIM_INLINE
227 void copy_to_shared(Key (&k)[ItemsPerThread],
const unsigned int flat_tid,
storage_type& storage) {
228 storage_type_<Key, Value>& storage_ = storage.get();
230 for(
unsigned int item = 0; item < ItemsPerThread; ++item) {
231 storage_.key[item * BlockSize + flat_tid] = k[item];
236 ROCPRIM_DEVICE ROCPRIM_INLINE
237 void copy_to_shared(Key& k, Value& v,
const unsigned int flat_tid,
storage_type& storage)
239 storage_type_<Key, Value>& storage_ = storage.get();
240 storage_.key[flat_tid] = k;
241 storage_.value[flat_tid] = v;
245 ROCPRIM_DEVICE ROCPRIM_INLINE
246 void copy_to_shared(Key (&k)[ItemsPerThread],
247 Value (&v)[ItemsPerThread],
248 const unsigned int flat_tid,
251 storage_type_<Key, Value>& storage_ = storage.get();
253 for(
unsigned int item = 0; item < ItemsPerThread; ++item) {
254 storage_.key[item * BlockSize + flat_tid] = k[item];
255 storage_.value[item * BlockSize + flat_tid] = v[item];
260 template<
class BinaryFunction>
261 ROCPRIM_DEVICE ROCPRIM_INLINE
263 const unsigned int flat_tid,
264 const unsigned int next_id,
267 BinaryFunction compare_function)
269 storage_type_<Key, Value>& storage_ = storage.get();
270 Key next_key = storage_.key[next_id];
271 bool compare = (next_id < flat_tid) ? compare_function(key, next_key) : compare_function(next_key, key);
272 bool swap = compare ^ dir;
279 template<
class BinaryFunction>
280 ROCPRIM_DEVICE ROCPRIM_INLINE
281 void swap(Key (&key)[ItemsPerThread],
282 const unsigned int flat_tid,
283 const unsigned int next_id,
286 BinaryFunction compare_function)
288 storage_type_<Key, Value>& storage_ = storage.get();
290 for(
unsigned int item = 0; item < ItemsPerThread; ++item) {
291 Key next_key = storage_.key[item * BlockSize + next_id];
292 bool compare = (next_id < flat_tid) ? compare_function(key[item], next_key) : compare_function(next_key, key[item]);
293 bool swap = compare ^ dir;
296 key[item] = next_key;
301 template<
class BinaryFunction>
302 ROCPRIM_DEVICE ROCPRIM_INLINE
305 const unsigned int flat_tid,
306 const unsigned int next_id,
309 BinaryFunction compare_function)
311 storage_type_<Key, Value>& storage_ = storage.get();
312 Key next_key = storage_.key[next_id];
313 bool b = next_id < flat_tid;
314 bool compare = compare_function(b ? key : next_key, b ? next_key : key);
315 bool swap = compare ^ dir;
319 value = storage_.value[next_id];
323 template<
class BinaryFunction>
324 ROCPRIM_DEVICE ROCPRIM_INLINE
325 void swap(Key (&key)[ItemsPerThread],
326 Value (&value)[ItemsPerThread],
327 const unsigned int flat_tid,
328 const unsigned int next_id,
331 BinaryFunction compare_function)
333 storage_type_<Key, Value>& storage_ = storage.get();
335 for(
unsigned int item = 0; item < ItemsPerThread; ++item) {
336 Key next_key = storage_.key[item * BlockSize + next_id];
337 bool b = next_id < flat_tid;
338 bool compare = compare_function(b ? key[item] : next_key, b ? next_key : key[item]);
339 bool swap = compare ^ dir;
342 key[item] = next_key;
343 value[item] = storage_.value[item * BlockSize + next_id];
348 template<
class BinaryFunction>
349 ROCPRIM_DEVICE ROCPRIM_INLINE
void swap_oddeven(Key& key,
350 const unsigned int next_id,
352 const unsigned int next_item_id,
355 BinaryFunction compare_function)
357 storage_type_<Key, Value>& storage_ = storage.get();
358 Key next_key = storage_.key[next_item_id * BlockSize + next_id];
364 bool swap = compare_function(next_key, key);
375 template<
class BinaryFunction>
376 ROCPRIM_DEVICE ROCPRIM_INLINE
void swap_oddeven(Key (&keys)[ItemsPerThread],
377 const unsigned int next_id,
378 const unsigned int item,
379 const unsigned int next_item_id,
382 BinaryFunction compare_function)
384 storage_type_<Key, Value>& storage_ = storage.get();
385 Key next_key = storage_.key[next_item_id * BlockSize + next_id];
391 bool swap = compare_function(next_key, keys[item]);
398 keys[item] = next_key;
402 template<
class BinaryFunction>
403 ROCPRIM_DEVICE ROCPRIM_INLINE
void swap_oddeven(Key& key,
405 const unsigned int next_id,
407 const unsigned int next_item_id,
410 BinaryFunction compare_function)
412 storage_type_<Key, Value>& storage_ = storage.get();
413 Key next_key = storage_.key[next_item_id * BlockSize + next_id];
419 bool swap = compare_function(next_key, key);
427 value = storage_.value[next_item_id * BlockSize + next_id];
431 template<
class BinaryFunction>
432 ROCPRIM_DEVICE ROCPRIM_INLINE
void swap_oddeven(Key (&keys)[ItemsPerThread],
433 Value (&values)[ItemsPerThread],
434 const unsigned int next_id,
435 const unsigned int item,
436 const unsigned int next_item_id,
439 BinaryFunction compare_function)
441 storage_type_<Key, Value>& storage_ = storage.get();
442 Key next_key = storage_.key[next_item_id * BlockSize + next_id];
448 bool swap = compare_function(next_key, keys[item]);
455 keys[item] = next_key;
456 values[item] = storage_.value[next_item_id * BlockSize + next_id];
462 class BinaryFunction,
465 ROCPRIM_DEVICE ROCPRIM_INLINE
466 typename std::enable_if<(Size <= ::rocprim::device_warp_size())>::type
467 sort_power_two(
const unsigned int flat_tid,
469 BinaryFunction compare_function,
475 ::rocprim::warp_sort<Key, Size, Value> wsort;
476 wsort.sort(kv..., compare_function);
479 template<
class BinaryFunction>
480 ROCPRIM_DEVICE ROCPRIM_INLINE
481 void warp_swap(Key& k, Value& v,
int mask,
bool dir, BinaryFunction compare_function)
484 bool swap = compare_function(dir ? k : k1, dir ? k1 : k);
492 template <
class BinaryFunction>
493 ROCPRIM_DEVICE ROCPRIM_INLINE
494 void warp_swap(Key (&k)[ItemsPerThread],
495 Value (&v)[ItemsPerThread],
498 BinaryFunction compare_function)
501 for(
unsigned int item = 0; item < ItemsPerThread; ++item) {
503 bool swap = compare_function(dir ? k[item] : k1, dir ? k1 : k[item]);
512 template<
class BinaryFunction>
513 ROCPRIM_DEVICE ROCPRIM_INLINE
514 void warp_swap(Key& k,
int mask,
bool dir, BinaryFunction compare_function)
517 bool swap = compare_function(dir ? k : k1, dir ? k1 : k);
524 template <
class BinaryFunction>
525 ROCPRIM_DEVICE ROCPRIM_INLINE
526 void warp_swap(Key (&k)[ItemsPerThread],
int mask,
bool dir, BinaryFunction compare_function)
529 for(
unsigned int item = 0; item < ItemsPerThread; ++item) {
531 bool swap = compare_function(dir ? k[item] : k1, dir ? k1 : k[item]);
539 template <
class BinaryFunction,
unsigned int Items = ItemsPerThread,
class... KeyValue>
540 ROCPRIM_DEVICE ROCPRIM_INLINE
541 typename std::enable_if<(Items < 2)>::type
542 thread_merge(
bool , BinaryFunction , KeyValue&... )
546 template <
class BinaryFunction>
547 ROCPRIM_DEVICE ROCPRIM_INLINE
548 void thread_swap(Key (&k)[ItemsPerThread],
549 Value (&v)[ItemsPerThread],
553 BinaryFunction compare_function)
555 if(compare_function(k[i], k[j]) == dir)
565 template <
class BinaryFunction>
566 ROCPRIM_DEVICE ROCPRIM_INLINE
567 void thread_swap(Key (&k)[ItemsPerThread],
571 BinaryFunction compare_function)
573 if(compare_function(k[i], k[j]) == dir)
581 template <
class BinaryFunction,
class... KeyValue>
582 ROCPRIM_DEVICE ROCPRIM_INLINE
583 void thread_shuffle(
unsigned int offset,
bool dir, BinaryFunction compare_function, KeyValue&... kv)
586 for(
unsigned base = 0; base < ItemsPerThread; base += 2 * offset)
591 #if defined(__clang_major__) && __clang_major__ >= 15 592 #pragma clang loop vectorize(disable) 594 for(
unsigned i = 0; i < offset; ++i)
596 thread_swap(kv..., dir, base + i, base + i + offset, compare_function);
601 template <
class BinaryFunction,
unsigned int Items = ItemsPerThread,
class... KeyValue>
602 ROCPRIM_DEVICE ROCPRIM_INLINE
603 typename std::enable_if<!(Items < 2)>::type
604 thread_merge(
bool dir, BinaryFunction compare_function, KeyValue&... kv)
607 for(
unsigned int k = ItemsPerThread / 2; k > 0; k /= 2)
609 thread_shuffle(k, dir, compare_function, kv...);
615 template<
unsigned int BS,
class BinaryFunction,
class... KeyValue>
616 ROCPRIM_DEVICE ROCPRIM_INLINE
618 sort_power_two(
const unsigned int flat_tid,
620 BinaryFunction compare_function,
624 ::rocprim::warp_sort<Key, ::rocprim::device_warp_size(), Value> wsort;
625 auto compare_function2 =
626 [compare_function, warp_id_is_even](
const Key& a,
const Key& b)
mutable ->
bool 628 auto r = compare_function(a, b);
633 wsort.sort(kv..., compare_function2);
638 const bool dir = (flat_tid & (length * 2)) != 0;
642 copy_to_shared(kv..., flat_tid, storage);
643 swap(kv..., flat_tid, flat_tid ^ k, dir, storage, compare_function);
650 const bool length_even = ((detail::logical_lane_id<::rocprim::device_warp_size()>() / k ) % 2 ) == 0;
651 const bool local_dir = length_even ? dir : !dir;
652 warp_swap(kv..., k, local_dir, compare_function);
654 thread_merge(dir, compare_function, kv...);
658 template<
unsigned int BS,
unsigned int IPT,
class BinaryFunction,
class... KeyValue>
659 ROCPRIM_DEVICE ROCPRIM_INLINE
660 typename std::enable_if<is_power_of_two(BS) && is_power_of_two(IPT)>::type
661 sort_impl(
const unsigned int flat_tid,
663 BinaryFunction compare_function,
666 static constexpr
unsigned int PairSize =
sizeof...(KeyValue);
667 static_assert(PairSize < 3,
668 "KeyValue parameter pack can be 1 or 2 elements (key, or key and value)");
670 sort_power_two<BS>(flat_tid, storage, compare_function, kv...);
675 template<
bool SizeCheck,
class BinaryFunction,
class... KeyValue>
676 ROCPRIM_DEVICE ROCPRIM_INLINE
void odd_even_sort(
const unsigned int flat_tid,
677 const unsigned int size,
679 BinaryFunction compare_function,
682 static constexpr
unsigned int PairSize =
sizeof...(KeyValue);
683 static_assert(PairSize < 3,
684 "KeyValue parameter pack can be 1 or 2 elements (key, or key and value)");
686 if(SizeCheck && size > ItemsPerBlock)
691 copy_to_shared(kv..., flat_tid, storage);
693 for(
unsigned int i = 0; i < size; i++)
695 bool is_even_iter = i % 2 == 0;
696 for(
unsigned int item = 0; item < ItemsPerThread; ++item)
699 unsigned int linear_id = flat_tid * ItemsPerThread + item;
700 bool is_even_lid = linear_id % 2 == 0;
703 unsigned int odd_lid = is_even_lid ?
::rocprim::max(linear_id, 1u) - 1
705 unsigned int even_lid = is_even_lid ?
::rocprim::min(linear_id + 1, size - 1)
709 unsigned int next_lid = is_even_iter ? even_lid : odd_lid;
712 unsigned int next_id = next_lid / ItemsPerThread;
713 unsigned int next_item_id = next_lid % ItemsPerThread;
716 if(!SizeCheck || (linear_id < size && next_lid < size))
722 next_lid < linear_id,
728 copy_to_shared(kv..., flat_tid, storage);
732 template<
unsigned int BS,
unsigned int IPT,
class BinaryFunction,
class... KeyValue>
733 ROCPRIM_DEVICE ROCPRIM_INLINE
734 typename std::enable_if<!is_power_of_two(BS) || !is_power_of_two(IPT)>::type
735 sort_impl(
const unsigned int flat_tid,
737 BinaryFunction compare_function,
740 odd_even_sort<false>(flat_tid, ItemsPerBlock, storage, compare_function, kv...);
743 template<
class BinaryFunction,
class... KeyValue>
744 ROCPRIM_DEVICE ROCPRIM_INLINE
void sort_impl(
const unsigned int flat_tid,
745 const unsigned int size,
747 BinaryFunction compare_function,
750 odd_even_sort<true>(flat_tid, size, storage, compare_function, kv...);
756 END_ROCPRIM_NAMESPACE
758 #endif // ROCPRIM_BLOCK_DETAIL_BLOCK_SORT_SHARED_HPP_ ROCPRIM_HOST_DEVICE constexpr T max(const T &a, const T &b)
Returns the maximum of its arguments.
Definition: functional.hpp:55
ROCPRIM_DEVICE ROCPRIM_INLINE constexpr unsigned int device_warp_size()
Returns a number of threads in a hardware warp for the actual target.
Definition: thread.hpp:70
Definition: block_sort_bitonic.hpp:47
ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
Returns the minimum of its arguments.
Definition: functional.hpp:63
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216
ROCPRIM_HOST_DEVICE void swap(T &a, T &b)
Swaps two values.
Definition: functional.hpp:71
Definition: various.hpp:180
ROCPRIM_DEVICE ROCPRIM_INLINE T warp_shuffle_xor(const T &input, const int lane_mask, const int width=device_warp_size())
Shuffle XOR for any data type.
Definition: warp_shuffle.hpp:246