21 #ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_SORT_SHARED_HPP_    22 #define ROCPRIM_BLOCK_DETAIL_BLOCK_SORT_SHARED_HPP_    24 #include <type_traits>    26 #include "../../config.hpp"    27 #include "../../detail/various.hpp"    29 #include "../../intrinsics.hpp"    30 #include "../../functional.hpp"    32 #include "../../warp/warp_sort.hpp"    34 BEGIN_ROCPRIM_NAMESPACE
    41     unsigned int BlockSizeX,
    42     unsigned int BlockSizeY,
    43     unsigned int BlockSizeZ,
    44     unsigned int ItemsPerThread,
    49     static constexpr 
unsigned int BlockSize     = BlockSizeX * BlockSizeY * BlockSizeZ;
    50     static constexpr 
unsigned int ItemsPerBlock = BlockSize * ItemsPerThread;
    52     template<
class KeyType, 
class ValueType>
    55         KeyType   key[BlockSize * ItemsPerThread];
    56         ValueType value[BlockSize * ItemsPerThread];
    59     template<
class KeyType>
    60     struct storage_type_<KeyType, empty_type>
    62         KeyType key[BlockSize * ItemsPerThread];
    68     template <
class BinaryFunction>
    69     ROCPRIM_DEVICE ROCPRIM_INLINE
    70     void sort(Key& thread_key,
    72               BinaryFunction compare_function)
    74         this->sort_impl<BlockSize, ItemsPerThread>(
    75             ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
    81     template<
class BinaryFunction>
    82     ROCPRIM_DEVICE ROCPRIM_INLINE
    83     void sort(Key (&thread_keys)[ItemsPerThread],
    85               BinaryFunction compare_function)
    87         this->sort_impl<BlockSize, ItemsPerThread>(
    88             ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
    94     template<
class BinaryFunction>
    95     ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
    96     void sort(Key& thread_key,
    97               BinaryFunction compare_function)
   100         this->sort(thread_key, storage, compare_function);
   103     template<
class BinaryFunction>
   104     ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
   105     void sort(Key (&thread_keys)[ItemsPerThread],
   106               BinaryFunction compare_function)
   109         this->sort(thread_keys, storage, compare_function);
   112     template<
class BinaryFunction>
   113     ROCPRIM_DEVICE ROCPRIM_INLINE
   114     void sort(Key& thread_key,
   117               BinaryFunction compare_function)
   119         this->sort_impl<BlockSize, ItemsPerThread>(
   120             ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
   127     template<
class BinaryFunction>
   128     ROCPRIM_DEVICE ROCPRIM_INLINE 
void sort(Key (&thread_keys)[ItemsPerThread],
   129                                             Value (&thread_values)[ItemsPerThread],
   131                                             BinaryFunction compare_function)
   133         this->sort_impl<BlockSize, ItemsPerThread>(
   134             ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
   141     template<
class BinaryFunction>
   142     ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
   143     void sort(Key& thread_key,
   145               BinaryFunction compare_function)
   148         this->sort(thread_key, thread_value, storage, compare_function);
   151     template<
class BinaryFunction>
   152     ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
   153     void sort(Key (&thread_keys)[ItemsPerThread],
   154               Value (&thread_values)[ItemsPerThread],
   155               BinaryFunction compare_function)
   158         this->sort(thread_keys, thread_values, storage, compare_function);
   161     template<
class BinaryFunction>
   162     ROCPRIM_DEVICE ROCPRIM_INLINE 
void sort(Key&               thread_key,
   164                                             const unsigned int size,
   165                                             BinaryFunction     compare_function)
   167         this->sort_impl(::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
   174     template<
class BinaryFunction>
   175     ROCPRIM_DEVICE ROCPRIM_INLINE 
void sort(Key (&thread_keys)[ItemsPerThread],
   177                                             const unsigned int size,
   178                                             BinaryFunction     compare_function)
   180         this->sort_impl(::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
   187     template<
class BinaryFunction>
   188     ROCPRIM_DEVICE ROCPRIM_INLINE 
void sort(Key&               thread_key,
   191                                             const unsigned int size,
   192                                             BinaryFunction     compare_function)
   194         this->sort_impl(::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
   202     template<
class BinaryFunction>
   203     ROCPRIM_DEVICE ROCPRIM_INLINE 
void sort(Key (&thread_keys)[ItemsPerThread],
   204                                             Value (&thread_values)[ItemsPerThread],
   206                                             const unsigned int size,
   207                                             BinaryFunction     compare_function)
   209         this->sort_impl(::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
   218     ROCPRIM_DEVICE ROCPRIM_INLINE
   219     void copy_to_shared(Key& k, 
const unsigned int flat_tid, 
storage_type& storage)
   221         storage_type_<Key, Value>& storage_ = storage.get();
   222         storage_.key[flat_tid] = k;
   226     ROCPRIM_DEVICE ROCPRIM_INLINE
   227     void copy_to_shared(Key (&k)[ItemsPerThread], 
const unsigned int flat_tid, 
storage_type& storage) {
   228         storage_type_<Key, Value>& storage_ = storage.get();
   230         for(
unsigned int item = 0; item < ItemsPerThread; ++item) {
   231             storage_.key[item * BlockSize + flat_tid] = k[item];
   236     ROCPRIM_DEVICE ROCPRIM_INLINE
   237     void copy_to_shared(Key& k, Value& v, 
const unsigned int flat_tid, 
storage_type& storage)
   239         storage_type_<Key, Value>& storage_ = storage.get();
   240         storage_.key[flat_tid] = k;
   241         storage_.value[flat_tid] = v;
   245     ROCPRIM_DEVICE ROCPRIM_INLINE
   246     void copy_to_shared(Key (&k)[ItemsPerThread],
   247                         Value (&v)[ItemsPerThread],
   248                         const unsigned int flat_tid,
   251         storage_type_<Key, Value>& storage_ = storage.get();
   253         for(
unsigned int item = 0; item < ItemsPerThread; ++item) {
   254             storage_.key[item * BlockSize + flat_tid]   = k[item];
   255             storage_.value[item * BlockSize + flat_tid] = v[item];
   260     template<
class BinaryFunction>
   261     ROCPRIM_DEVICE ROCPRIM_INLINE
   263               const unsigned int flat_tid,
   264               const unsigned int next_id,
   267               BinaryFunction compare_function)
   269         storage_type_<Key, Value>& storage_ = storage.get();
   270         Key next_key = storage_.key[next_id];
   271         bool compare = (next_id < flat_tid) ? compare_function(key, next_key) : compare_function(next_key, key);
   272         bool swap = compare ^ dir;
   279     template<
class BinaryFunction>
   280     ROCPRIM_DEVICE ROCPRIM_INLINE
   281     void swap(Key (&key)[ItemsPerThread],
   282               const unsigned int flat_tid,
   283               const unsigned int next_id,
   286               BinaryFunction compare_function)
   288         storage_type_<Key, Value>& storage_ = storage.get();
   290         for(
unsigned int item = 0; item < ItemsPerThread; ++item) {
   291             Key next_key = storage_.key[item * BlockSize + next_id];
   292             bool compare = (next_id < flat_tid) ? compare_function(key[item], next_key) : compare_function(next_key, key[item]);
   293             bool swap = compare ^ dir;
   296                 key[item] = next_key;
   301     template<
class BinaryFunction>
   302     ROCPRIM_DEVICE ROCPRIM_INLINE
   305               const unsigned int flat_tid,
   306               const unsigned int next_id,
   309               BinaryFunction compare_function)
   311         storage_type_<Key, Value>& storage_ = storage.get();
   312         Key next_key = storage_.key[next_id];
   313         bool b = next_id < flat_tid;
   314         bool compare = compare_function(b ? key : next_key, b ? next_key : key);
   315         bool swap = compare ^ dir;
   319             value = storage_.value[next_id];
   323     template<
class BinaryFunction>
   324     ROCPRIM_DEVICE ROCPRIM_INLINE
   325     void swap(Key (&key)[ItemsPerThread],
   326               Value (&value)[ItemsPerThread],
   327               const unsigned int flat_tid,
   328               const unsigned int next_id,
   331               BinaryFunction compare_function)
   333         storage_type_<Key, Value>& storage_ = storage.get();
   335         for(
unsigned int item = 0; item < ItemsPerThread; ++item) {
   336             Key next_key = storage_.key[item * BlockSize + next_id];
   337             bool b = next_id < flat_tid;
   338             bool compare = compare_function(b ? key[item] : next_key, b ? next_key : key[item]);
   339             bool swap = compare ^ dir;
   342                 key[item]   = next_key;
   343                 value[item] = storage_.value[item * BlockSize + next_id];
   348     template<
class BinaryFunction>
   349     ROCPRIM_DEVICE ROCPRIM_INLINE 
void swap_oddeven(Key&               key,
   350                                                     const unsigned int next_id,
   352                                                     const unsigned int next_item_id,
   355                                                     BinaryFunction     compare_function)
   357         storage_type_<Key, Value>& storage_ = storage.get();
   358         Key                        next_key = storage_.key[next_item_id * BlockSize + next_id];
   364         bool swap = compare_function(next_key, key);
   375     template<
class BinaryFunction>
   376     ROCPRIM_DEVICE ROCPRIM_INLINE 
void swap_oddeven(Key (&keys)[ItemsPerThread],
   377                                                     const unsigned int next_id,
   378                                                     const unsigned int item,
   379                                                     const unsigned int next_item_id,
   382                                                     BinaryFunction     compare_function)
   384         storage_type_<Key, Value>& storage_ = storage.get();
   385         Key                        next_key = storage_.key[next_item_id * BlockSize + next_id];
   391         bool swap = compare_function(next_key, keys[item]);
   398             keys[item] = next_key;
   402     template<
class BinaryFunction>
   403     ROCPRIM_DEVICE ROCPRIM_INLINE 
void swap_oddeven(Key&               key,
   405                                                     const unsigned int next_id,
   407                                                     const unsigned int next_item_id,
   410                                                     BinaryFunction     compare_function)
   412         storage_type_<Key, Value>& storage_ = storage.get();
   413         Key                        next_key = storage_.key[next_item_id * BlockSize + next_id];
   419         bool swap = compare_function(next_key, key);
   427             value = storage_.value[next_item_id * BlockSize + next_id];
   431     template<
class BinaryFunction>
   432     ROCPRIM_DEVICE ROCPRIM_INLINE 
void swap_oddeven(Key (&keys)[ItemsPerThread],
   433                                                     Value (&values)[ItemsPerThread],
   434                                                     const unsigned int next_id,
   435                                                     const unsigned int item,
   436                                                     const unsigned int next_item_id,
   439                                                     BinaryFunction     compare_function)
   441         storage_type_<Key, Value>& storage_ = storage.get();
   442         Key                        next_key = storage_.key[next_item_id * BlockSize + next_id];
   448         bool swap = compare_function(next_key, keys[item]);
   455             keys[item]   = next_key;
   456             values[item] = storage_.value[next_item_id * BlockSize + next_id];
   462         class BinaryFunction,
   465     ROCPRIM_DEVICE ROCPRIM_INLINE
   466     typename std::enable_if<(Size <= ::rocprim::device_warp_size())>::type
   467     sort_power_two(
const unsigned int flat_tid,
   469                    BinaryFunction compare_function,
   475         ::rocprim::warp_sort<Key, Size, Value> wsort;
   476         wsort.sort(kv..., compare_function);
   479     template<
class BinaryFunction>
   480     ROCPRIM_DEVICE ROCPRIM_INLINE
   481     void warp_swap(Key& k, Value& v, 
int mask, 
bool dir, BinaryFunction compare_function)
   484         bool swap = compare_function(dir ? k : k1, dir ? k1 : k);
   492     template <
class BinaryFunction>
   493     ROCPRIM_DEVICE ROCPRIM_INLINE
   494     void warp_swap(Key (&k)[ItemsPerThread],
   495                    Value (&v)[ItemsPerThread],
   498                    BinaryFunction compare_function)
   501         for(
unsigned int item = 0; item < ItemsPerThread; ++item) {
   503             bool swap = compare_function(dir ? k[item] : k1, dir ? k1 : k[item]);
   512     template<
class BinaryFunction>
   513     ROCPRIM_DEVICE ROCPRIM_INLINE
   514     void warp_swap(Key& k, 
int mask, 
bool dir, BinaryFunction compare_function)
   517         bool swap = compare_function(dir ? k : k1, dir ? k1 : k);
   524     template <
class BinaryFunction>
   525     ROCPRIM_DEVICE ROCPRIM_INLINE
   526     void warp_swap(Key (&k)[ItemsPerThread], 
int mask, 
bool dir, BinaryFunction compare_function)
   529         for(
unsigned int item = 0; item < ItemsPerThread; ++item) {
   531             bool swap = compare_function(dir ? k[item] : k1, dir ? k1 : k[item]);
   539     template <
class BinaryFunction, 
unsigned int Items = ItemsPerThread, 
class... KeyValue>
   540     ROCPRIM_DEVICE ROCPRIM_INLINE
   541     typename std::enable_if<(Items < 2)>::type
   542     thread_merge(
bool , BinaryFunction , KeyValue&... )
   546     template <
class BinaryFunction>
   547     ROCPRIM_DEVICE ROCPRIM_INLINE
   548     void thread_swap(Key (&k)[ItemsPerThread],
   549                      Value (&v)[ItemsPerThread],
   553                      BinaryFunction compare_function)
   555         if(compare_function(k[i], k[j]) == dir)
   565     template <
class BinaryFunction>
   566     ROCPRIM_DEVICE ROCPRIM_INLINE
   567     void thread_swap(Key (&k)[ItemsPerThread],
   571                      BinaryFunction compare_function)
   573         if(compare_function(k[i], k[j]) == dir)
   581     template <
class BinaryFunction, 
class... KeyValue>
   582     ROCPRIM_DEVICE ROCPRIM_INLINE
   583     void thread_shuffle(
unsigned int offset, 
bool dir, BinaryFunction compare_function, KeyValue&... kv)
   586         for(
unsigned base = 0; base < ItemsPerThread; base += 2 * offset)
   591 #if defined(__clang_major__) && __clang_major__ >= 15   592     #pragma clang loop vectorize(disable)   594             for(
unsigned i = 0; i < offset; ++i)
   596                 thread_swap(kv..., dir, base + i, base + i + offset, compare_function);
   601     template <
class BinaryFunction, 
unsigned int Items = ItemsPerThread, 
class... KeyValue>
   602     ROCPRIM_DEVICE ROCPRIM_INLINE
   603     typename std::enable_if<!(Items < 2)>::type
   604     thread_merge(
bool dir, BinaryFunction compare_function, KeyValue&... kv)
   607         for(
unsigned int k = ItemsPerThread / 2; k > 0; k /= 2)
   609             thread_shuffle(k, dir, compare_function, kv...);
   615     template<
unsigned int BS, 
class BinaryFunction, 
class... KeyValue>
   616     ROCPRIM_DEVICE ROCPRIM_INLINE
   618         sort_power_two(
const unsigned int flat_tid,
   620                        BinaryFunction     compare_function,
   624         ::rocprim::warp_sort<Key, ::rocprim::device_warp_size(), Value> wsort;
   625         auto compare_function2 =
   626             [compare_function, warp_id_is_even](
const Key& a, 
const Key& b) 
mutable -> 
bool   628                 auto r = compare_function(a, b);
   633         wsort.sort(kv..., compare_function2);
   638             const bool dir = (flat_tid & (length * 2)) != 0;
   642                 copy_to_shared(kv..., flat_tid, storage);
   643                 swap(kv..., flat_tid, flat_tid ^ k, dir, storage, compare_function);
   650                 const bool length_even = ((detail::logical_lane_id<::rocprim::device_warp_size()>() / k ) % 2 ) == 0;
   651                 const bool local_dir = length_even ? dir : !dir;
   652                 warp_swap(kv..., k, local_dir, compare_function);
   654             thread_merge(dir, compare_function, kv...);
   658     template<
unsigned int BS, 
unsigned int IPT, 
class BinaryFunction, 
class... KeyValue>
   659     ROCPRIM_DEVICE ROCPRIM_INLINE
   660         typename std::enable_if<is_power_of_two(BS) && is_power_of_two(IPT)>::type
   661         sort_impl(
const unsigned int flat_tid,
   663                   BinaryFunction     compare_function,
   666         static constexpr 
unsigned int PairSize = 
sizeof...(KeyValue);
   667         static_assert(PairSize < 3,
   668                       "KeyValue parameter pack can be 1 or 2 elements (key, or key and value)");
   670         sort_power_two<BS>(flat_tid, storage, compare_function, kv...);
   675     template<
bool SizeCheck, 
class BinaryFunction, 
class... KeyValue>
   676     ROCPRIM_DEVICE ROCPRIM_INLINE 
void odd_even_sort(
const unsigned int flat_tid,
   677                                                      const unsigned int size,
   679                                                      BinaryFunction     compare_function,
   682         static constexpr 
unsigned int PairSize = 
sizeof...(KeyValue);
   683         static_assert(PairSize < 3,
   684                       "KeyValue parameter pack can be 1 or 2 elements (key, or key and value)");
   686         if(SizeCheck && size > ItemsPerBlock)
   691         copy_to_shared(kv..., flat_tid, storage);
   693         for(
unsigned int i = 0; i < size; i++)
   695             bool is_even_iter = i % 2 == 0;
   696             for(
unsigned int item = 0; item < ItemsPerThread; ++item)
   699                 unsigned int linear_id   = flat_tid * ItemsPerThread + item;
   700                 bool         is_even_lid = linear_id % 2 == 0;
   703                 unsigned int odd_lid  = is_even_lid ? 
::rocprim::max(linear_id, 1u) - 1
   705                 unsigned int even_lid = is_even_lid ? 
::rocprim::min(linear_id + 1, size - 1)
   709                 unsigned int next_lid = is_even_iter ? even_lid : odd_lid;
   712                 unsigned int next_id      = next_lid / ItemsPerThread;
   713                 unsigned int next_item_id = next_lid % ItemsPerThread;
   716                 if(!SizeCheck || (linear_id < size && next_lid < size))
   722                                  next_lid < linear_id,
   728             copy_to_shared(kv..., flat_tid, storage);
   732     template<
unsigned int BS, 
unsigned int IPT, 
class BinaryFunction, 
class... KeyValue>
   733     ROCPRIM_DEVICE ROCPRIM_INLINE
   734         typename std::enable_if<!is_power_of_two(BS) || !is_power_of_two(IPT)>::type
   735         sort_impl(
const unsigned int flat_tid,
   737                   BinaryFunction     compare_function,
   740         odd_even_sort<false>(flat_tid, ItemsPerBlock, storage, compare_function, kv...);
   743     template<
class BinaryFunction, 
class... KeyValue>
   744     ROCPRIM_DEVICE ROCPRIM_INLINE 
void sort_impl(
const unsigned int flat_tid,
   745                                                  const unsigned int size,
   747                                                  BinaryFunction     compare_function,
   750         odd_even_sort<true>(flat_tid, size, storage, compare_function, kv...);
   756 END_ROCPRIM_NAMESPACE
   758 #endif // ROCPRIM_BLOCK_DETAIL_BLOCK_SORT_SHARED_HPP_ ROCPRIM_HOST_DEVICE constexpr T max(const T &a, const T &b)
Returns the maximum of its arguments. 
Definition: functional.hpp:55
ROCPRIM_DEVICE ROCPRIM_INLINE constexpr unsigned int device_warp_size()
Returns a number of threads in a hardware warp for the actual target. 
Definition: thread.hpp:70
Definition: block_sort_bitonic.hpp:47
ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
Returns the minimum of its arguments. 
Definition: functional.hpp:63
Deprecated: Configuration of device-level scan primitives. 
Definition: block_histogram.hpp:62
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile) 
Definition: thread.hpp:216
ROCPRIM_HOST_DEVICE void swap(T &a, T &b)
Swaps two values. 
Definition: functional.hpp:71
Definition: various.hpp:180
ROCPRIM_DEVICE ROCPRIM_INLINE T warp_shuffle_xor(const T &input, const int lane_mask, const int width=device_warp_size())
Shuffle XOR for any data type. 
Definition: warp_shuffle.hpp:246