21 #ifndef ROCPRIM_WARP_DETAIL_WARP_SORT_SHUFFLE_HPP_    22 #define ROCPRIM_WARP_DETAIL_WARP_SORT_SHUFFLE_HPP_    24 #include <type_traits>    26 #include "../../config.hpp"    27 #include "../../detail/various.hpp"    29 #include "../../intrinsics.hpp"    30 #include "../../functional.hpp"    32 BEGIN_ROCPRIM_NAMESPACE
    39     unsigned int WarpSize,
    // swap -- single key + single value, inactive variant.
    // Selected by enable_if when WarpSize is NOT greater than `warp`.
    // Only the `(void) compare_function;` statement of the body is visible in
    // this listing; presumably the other parameters are discarded the same way
    // and the overload does nothing -- TODO confirm against the full source.
    45     template<
int warp, 
class V, 
class BinaryFunction>
    46     ROCPRIM_DEVICE ROCPRIM_INLINE
    47     typename std::enable_if<!(WarpSize > warp)>::type
    48     swap(Key& k, V& v, 
int mask, 
bool dir, BinaryFunction compare_function)
    54         (void) compare_function;
    // swap -- single key + single value, active variant.
    // Selected by enable_if when WarpSize > warp.
    // `k1` is declared on a line missing from this listing (presumably the
    // partner lane's key fetched with an XOR shuffle on `mask` -- TODO confirm).
    // `dir` chooses the operand order of the comparison: (k, k1) when true,
    // (k1, k) when false. How the result is applied to k / v is not visible here.
    57     template<
int warp, 
class V, 
class BinaryFunction>
    58     ROCPRIM_DEVICE ROCPRIM_INLINE
    59     typename std::enable_if<(WarpSize > warp)>::type
    60     swap(Key& k, V& v, 
int mask, 
bool dir, BinaryFunction compare_function)
    64         bool swap = compare_function(dir ? k : k1, dir ? k1 : k);
    // swap -- key array + value array (ItemsPerThread each), inactive variant.
    // Selected by enable_if when WarpSize is NOT greater than `warp`.
    // The start of the template header and some parameters are missing from this
    // listing; the only visible body statement silences the unused comparator.
    76         unsigned int ItemsPerThread
    78     ROCPRIM_DEVICE ROCPRIM_INLINE
    79     typename std::enable_if<!(WarpSize > warp)>::type
    80     swap(Key (&k)[ItemsPerThread],
    81          V (&v)[ItemsPerThread],
    84          BinaryFunction compare_function)
    90         (void) compare_function;
    // swap -- key array + value array (ItemsPerThread each), active variant.
    // Selected by enable_if when WarpSize > warp.
    // For every item index the comparator is evaluated on (k[item], k1[item]),
    // with the operand order flipped when `dir` is false.
    // NOTE(review): `dir`, the mask parameter and the code that fills `k1` are on
    // lines missing from this listing; presumably k1 holds the partner lane's
    // keys -- confirm against the full source.
    97         unsigned int ItemsPerThread
    99     ROCPRIM_DEVICE ROCPRIM_INLINE
   100     typename std::enable_if<(WarpSize > warp)>::type
   101     swap(Key (&k)[ItemsPerThread],
   102          V (&v)[ItemsPerThread],
   105          BinaryFunction compare_function)
   // Scratch keys, one per item of this thread.
   107         Key k1[ItemsPerThread];
   109         for (
unsigned int item = 0; item < ItemsPerThread; item++)
   113            bool swap = compare_function(dir ? k[item] : k1[item], dir ? k1[item] : k[item]);
   // swap -- single key only, inactive variant.
   // Selected by enable_if when WarpSize is NOT greater than `warp`.
   // The only visible body statement silences the unused comparator; presumably
   // the overload is otherwise empty -- TODO confirm.
   122     template<
int warp, 
class BinaryFunction>
   123     ROCPRIM_DEVICE ROCPRIM_INLINE
   124     typename std::enable_if<!(WarpSize > warp)>::type
   125     swap(Key& k, 
int mask, 
bool dir, BinaryFunction compare_function)
   130         (void) compare_function;
   // swap -- single key only, active variant.
   // Selected by enable_if when WarpSize > warp.
   // Evaluates the comparator on (k, k1) when `dir` is true and on (k1, k)
   // otherwise. `k1` is declared on a line missing from this listing
   // (presumably the partner lane's key for `mask` -- TODO confirm).
   133     template<
int warp, 
class BinaryFunction>
   134     ROCPRIM_DEVICE ROCPRIM_INLINE
   135     typename std::enable_if<(WarpSize > warp)>::type
   136     swap(Key& k, 
int mask, 
bool dir, BinaryFunction compare_function)
   139         bool swap = compare_function(dir ? k : k1, dir ? k1 : k);
   // swap -- key array only (ItemsPerThread keys), inactive variant.
   // Selected by enable_if when WarpSize is NOT greater than `warp`.
   // The start of the template header is missing from this listing; the only
   // visible body statement silences the unused comparator.
   148         class BinaryFunction,
   149         unsigned int ItemsPerThread
   151     ROCPRIM_DEVICE ROCPRIM_INLINE
   152     typename std::enable_if<!(WarpSize > warp)>::type
   153     swap(Key (&k)[ItemsPerThread], 
int mask, 
bool dir, BinaryFunction compare_function)
   158         (void) compare_function;
   // swap -- key array only (ItemsPerThread keys), active variant.
   // Selected by enable_if when WarpSize > warp.
   // For every item index the comparator is evaluated on (k[item], k1[item]),
   // with the operand order flipped when `dir` is false.
   // NOTE(review): the code that fills `k1` and consumes the comparison result
   // is on lines missing from this listing.
   163         class BinaryFunction,
   164         unsigned int ItemsPerThread
   166     ROCPRIM_DEVICE ROCPRIM_INLINE
   167     typename std::enable_if<(WarpSize > warp)>::type
   168     swap(Key (&k)[ItemsPerThread], 
int mask, 
bool dir, BinaryFunction compare_function)
   // Scratch keys, one per item of this thread.
   170         Key k1[ItemsPerThread];
   172         for (
unsigned int item = 0; item < ItemsPerThread; item++)
   175             bool swap = compare_function(dir ? k[item] : k1[item], dir ? k1[item] : k[item]);
   // thread_swap -- keys only.
   // Compares two items of the same thread, k[i] and k[j]; the guarded statement
   // runs when the comparator's result equals `dir`.
   // NOTE(review): the i / j / dir parameters and the guarded statement are on
   // lines missing from this listing; presumably k[i] and k[j] are exchanged.
   183     template <
unsigned int ItemsPerThread, 
class BinaryFunction>
   184     ROCPRIM_DEVICE ROCPRIM_INLINE
   185     void thread_swap(Key (&k)[ItemsPerThread],
   189                      BinaryFunction compare_function)
   191         if(compare_function(k[i], k[j]) == dir)
   // thread_swap -- keys with associated values.
   // Compares k[i] with k[j]; the guarded statement runs when the comparator's
   // result equals `dir`. Only keys take part in the comparison.
   // NOTE(review): the i / j / dir parameters and the guarded statement are on
   // lines missing from this listing; presumably both the keys and the values
   // at i and j are exchanged.
   199     template <
unsigned int ItemsPerThread, 
class V, 
class BinaryFunction>
   200     ROCPRIM_DEVICE ROCPRIM_INLINE
   201     void thread_swap(Key (&k)[ItemsPerThread],
   202                      V   (&v)[ItemsPerThread],
   206                      BinaryFunction compare_function)
   208         if(compare_function(k[i], k[j]) == dir)
   // thread_shuffle -- one compare-exchange pass over the items of one thread.
   // Items are processed in chunks of 2 * offset starting at `base`; inside a
   // chunk, item base + i is paired with item base + i + offset for i < offset.
   // The direction used for a chunk is `dir`, inverted when the `group_size`
   // bit is set in `base`.
   // The `offset` and `dir` parameters and the `kv` pack are declared on lines
   // missing from this listing.
   219     template <
unsigned int ItemsPerThread, 
class BinaryFunction, 
class... KeyValue>
   220     ROCPRIM_DEVICE ROCPRIM_INLINE
   221     void thread_shuffle(
unsigned int   group_size,
   224                         BinaryFunction compare_function,
   228         for(
unsigned int base = 0; base < ItemsPerThread; base += 2 * offset) {
   231             const bool local_dir = ((base & group_size) > 0) != dir;
   // With clang >= 15 vectorization of the inner loop is explicitly disabled.
   236 #if defined(__clang_major__) && __clang_major__ >= 15   237     #pragma clang loop vectorize(disable)   239             for(
unsigned i = 0; i < offset; ++i) {
   240                 thread_swap(kv..., base + i, base + i + offset, local_dir, compare_function);
   // thread_sort -- orders the ItemsPerThread items held by one thread.
   // The outer loop doubles the group size k from 2 up to ItemsPerThread; for each
   // k the inner loop halves the pairing distance j from k / 2 down to 1 and runs
   // one thread_shuffle pass with (group_size = k, offset = j).
   // `dir` is forwarded unchanged to every pass.
   245     template <
unsigned int ItemsPerThread, 
class BinaryFunction, 
class... KeyValue>
   246     ROCPRIM_DEVICE ROCPRIM_INLINE 
   247     void thread_sort(
bool dir, BinaryFunction compare_function, KeyValue&... kv)
   250         for(
unsigned int k = 2; k <= ItemsPerThread; k *= 2)
   253             for(
unsigned int j = k / 2; j > 0; j /= 2)
   255                 thread_shuffle<ItemsPerThread>(k, j, dir, compare_function, kv...);
   // thread_merge -- active variant, selected when WarpSize > warp.
   // Runs thread_shuffle passes over the thread's items with the group size fixed
   // at ItemsPerThread and the pairing distance j halving from ItemsPerThread / 2
   // down to 1, all in direction `dir`.
   260     template <
int warp, 
unsigned int ItemsPerThread, 
class BinaryFunction, 
class... KeyValue>
   261     ROCPRIM_DEVICE ROCPRIM_INLINE
   262     typename std::enable_if<(WarpSize > warp)>::type
   263     thread_merge(
bool dir, BinaryFunction compare_function, KeyValue&... kv)
   266         for(
unsigned int j = ItemsPerThread / 2; j > 0; j /= 2)
   268             thread_shuffle<ItemsPerThread>(ItemsPerThread, j, dir, compare_function, kv...);
   // thread_merge -- inactive variant, selected when WarpSize is NOT greater
   // than `warp`. All parameters are unnamed, so none of them can be used; the
   // body is not visible in this listing (presumably empty -- TODO confirm).
   272     template <
int warp, 
unsigned int ItemsPerThread, 
class BinaryFunction, 
class... KeyValue>
   273     ROCPRIM_DEVICE ROCPRIM_INLINE
   274     typename std::enable_if<!(WarpSize > warp)>::type
   275     thread_merge(
bool , BinaryFunction , KeyValue&... )
   // bitonic_sort -- one item per thread.
   // `kv` is either a key, or a key and a value; the static_assert rejects packs
   // of three or more elements.
   // The body is a fixed sequence of swap<W>(kv..., mask, dir, compare_function)
   // calls. Per the enable_if on the swap overloads, a call takes the active
   // overload only when WarpSize > W. `mask` and `dir` are derived per call from
   // bits of the logical lane id `id`.
   // NOTE(review): `sizeof...(KeyValue) < 3` also admits an empty pack, which the
   // assertion message does not mention.
   279     template<
class BinaryFunction, 
class... KeyValue>
   280     ROCPRIM_DEVICE ROCPRIM_INLINE
   281     void bitonic_sort(BinaryFunction compare_function, KeyValue&... kv)
   284             sizeof...(KeyValue) < 3,
   285             "KeyValue parameter pack can 1 or 2 elements (key, or key and value)"   288         unsigned int id = detail::logical_lane_id<WarpSize>();
   // Step tagged 2: mask 1, direction from lane-id bits 1 and 0.
   289         swap< 2>(kv..., 1, 
get_bit(
id, 1) != 
get_bit(
id, 0), compare_function);
   // Steps tagged 4: masks 2, 1; direction compares bit 2 with bit 1 / bit 0.
   291         swap< 4>(kv..., 2, 
get_bit(
id, 2) != 
get_bit(
id, 1), compare_function);
   292         swap< 4>(kv..., 1, 
get_bit(
id, 2) != 
get_bit(
id, 0), compare_function);
   // Steps tagged 8: masks 4, 2, 1; direction compares bit 3 with bits 2..0.
   294         swap< 8>(kv..., 4, 
get_bit(
id, 3) != 
get_bit(
id, 2), compare_function);
   295         swap< 8>(kv..., 2, 
get_bit(
id, 3) != 
get_bit(
id, 1), compare_function);
   296         swap< 8>(kv..., 1, 
get_bit(
id, 3) != 
get_bit(
id, 0), compare_function);
   // Steps tagged 16: masks 8..1; direction compares bit 4 with bits 3..0.
   298         swap<16>(kv..., 8, 
get_bit(
id, 4) != 
get_bit(
id, 3), compare_function);
   299         swap<16>(kv..., 4, 
get_bit(
id, 4) != 
get_bit(
id, 2), compare_function);
   300         swap<16>(kv..., 2, 
get_bit(
id, 4) != 
get_bit(
id, 1), compare_function);
   301         swap<16>(kv..., 1, 
get_bit(
id, 4) != 
get_bit(
id, 0), compare_function);
   // Steps tagged 32: masks 16..1; direction compares bit 5 with bits 4..0.
   303         swap<32>(kv..., 16, 
get_bit(
id, 5) != 
get_bit(
id, 4), compare_function);
   304         swap<32>(kv..., 8,  
get_bit(
id, 5) != 
get_bit(
id, 3), compare_function);
   305         swap<32>(kv..., 4,  
get_bit(
id, 5) != 
get_bit(
id, 2), compare_function);
   306         swap<32>(kv..., 2,  
get_bit(
id, 5) != 
get_bit(
id, 1), compare_function);
   307         swap<32>(kv..., 1,  
get_bit(
id, 5) != 
get_bit(
id, 0), compare_function);
   // Final group: masks 32 down to 1, direction taken from a single lane-id bit
   // (the bit matching the mask). The last call is tagged 0.
   309         swap<32>(kv..., 32, 
get_bit(
id, 5) != 0, compare_function);
   310         swap<16>(kv..., 16, 
get_bit(
id, 4) != 0, compare_function);
   311         swap< 8>(kv..., 8,  
get_bit(
id, 3) != 0, compare_function);
   312         swap< 4>(kv..., 4,  
get_bit(
id, 2) != 0, compare_function);
   313         swap< 2>(kv..., 2,  
get_bit(
id, 1) != 0, compare_function);
   314         swap< 0>(kv..., 1,  
get_bit(
id, 0) != 0, compare_function);
   // bitonic_sort -- ItemsPerThread items per thread.
   // `kv` is either a key array, or a key array and a value array; ItemsPerThread
   // must be a power of two (static_assert).
   // Sequence: thread_sort orders each thread's own items, then each group of
   // cross-lane swap<W> calls is followed by a thread_merge<W> over the thread's
   // items. Per their enable_if conditions, swap<W> and thread_merge<W> take
   // their active overloads only when WarpSize > W.
   // The start of the template header is missing from this listing.
   318         unsigned int ItemsPerThread,
   319         class BinaryFunction,
   322     ROCPRIM_DEVICE ROCPRIM_INLINE
   323     void bitonic_sort(BinaryFunction compare_function, KeyValue&... kv)
   326             sizeof...(KeyValue) < 3,
   327             "KeyValue parameter pack can 1 or 2 elements (key, or key and value)"   330         static_assert(detail::is_power_of_two(ItemsPerThread), 
"ItemsPerThread must be power of 2");
   332         unsigned int id = detail::logical_lane_id<WarpSize>();
   // Per-thread sort; direction taken from lane-id bit 0.
   333         thread_sort<ItemsPerThread>(
get_bit(
id, 0) != 0, compare_function, kv...);
   // Stage tagged 2.
   335         swap< 2>(kv..., 1, 
get_bit(
id, 1) != 
get_bit(
id, 0), compare_function);
   336         thread_merge<2, ItemsPerThread>(
get_bit(
id, 1) != 0, compare_function, kv...);
   // Stage tagged 4.
   338         swap< 4>(kv..., 2, 
get_bit(
id, 2) != 
get_bit(
id, 1), compare_function);
   339         swap< 4>(kv..., 1, 
get_bit(
id, 2) != 
get_bit(
id, 0), compare_function);
   340         thread_merge<4, ItemsPerThread>(
get_bit(
id, 2) != 0, compare_function, kv...);
   // Stage tagged 8.
   342         swap< 8>(kv..., 4, 
get_bit(
id, 3) != 
get_bit(
id, 2), compare_function);
   343         swap< 8>(kv..., 2, 
get_bit(
id, 3) != 
get_bit(
id, 1), compare_function);
   344         swap< 8>(kv..., 1, 
get_bit(
id, 3) != 
get_bit(
id, 0), compare_function);
   345         thread_merge<8, ItemsPerThread>(
get_bit(
id, 3) != 0, compare_function, kv...);
   // Stage tagged 16.
   347         swap<16>(kv..., 8, 
get_bit(
id, 4) != 
get_bit(
id, 3), compare_function);
   348         swap<16>(kv..., 4, 
get_bit(
id, 4) != 
get_bit(
id, 2), compare_function);
   349         swap<16>(kv..., 2, 
get_bit(
id, 4) != 
get_bit(
id, 1), compare_function);
   350         swap<16>(kv..., 1, 
get_bit(
id, 4) != 
get_bit(
id, 0), compare_function);
   351         thread_merge<16, ItemsPerThread>(
get_bit(
id, 4) != 0, compare_function, kv...);
   // Stage tagged 32.
   353         swap<32>(kv..., 16, 
get_bit(
id, 5) != 
get_bit(
id, 4), compare_function);
   354         swap<32>(kv..., 8,  
get_bit(
id, 5) != 
get_bit(
id, 3), compare_function);
   355         swap<32>(kv..., 4,  
get_bit(
id, 5) != 
get_bit(
id, 2), compare_function);
   356         swap<32>(kv..., 2,  
get_bit(
id, 5) != 
get_bit(
id, 1), compare_function);
   357         swap<32>(kv..., 1,  
get_bit(
id, 5) != 
get_bit(
id, 0), compare_function);
   358         thread_merge<32, ItemsPerThread>(
get_bit(
id, 5) != 0, compare_function, kv...);
   // Final group: masks 32 down to 1, direction from the single lane-id bit
   // matching the mask, followed by a last per-thread merge with dir = false.
   360         swap<32>(kv..., 32, 
get_bit(
id, 5) != 0, compare_function);
   361         swap<16>(kv..., 16, 
get_bit(
id, 4) != 0, compare_function);
   362         swap< 8>(kv..., 8,  
get_bit(
id, 3) != 0, compare_function);
   363         swap< 4>(kv..., 4,  
get_bit(
id, 2) != 0, compare_function);
   364         swap< 2>(kv..., 2,  
get_bit(
id, 1) != 0, compare_function);
   365         swap< 0>(kv..., 1,  
get_bit(
id, 0) != 0, compare_function);
   366         thread_merge<1, ItemsPerThread>(
false, compare_function, kv...);
   // Compile-time requirement: the logical warp size must be a power of two.
   370     static_assert(detail::is_power_of_two(WarpSize), 
"WarpSize must be power of 2");
   // Storage type taken by the sort overloads that accept a `storage` argument;
   // it aliases the library's empty_storage_type.
   372     using storage_type = ::rocprim::detail::empty_storage_type;
   // sort -- keys only, one key per thread.
   // Forwards the key to the single-item bitonic_sort with the given comparator;
   // thread_value is updated in place (passed by reference).
   374     template<
class BinaryFunction>
   375     ROCPRIM_DEVICE ROCPRIM_INLINE
   376     void sort(Key& thread_value, BinaryFunction compare_function)
   379         bitonic_sort(compare_function, thread_value);
   // sort -- keys only, one key per thread, overload accepting `storage`.
   // The visible body does not use `storage`; it forwards to the overload that
   // takes only the key and the comparator.
   382     template<
class BinaryFunction>
   383     ROCPRIM_DEVICE ROCPRIM_INLINE
   384     void sort(Key& thread_value, storage_type& storage,
   385               BinaryFunction compare_function)
   388         sort(thread_value, compare_function);
   // sort -- keys only, ItemsPerThread keys per thread.
   // Forwards the key array to the multi-item bitonic_sort. The remaining
   // template parameters are on lines missing from this listing.
   392         unsigned int ItemsPerThread,
   395     ROCPRIM_DEVICE ROCPRIM_INLINE
   396     void sort(Key (&thread_values)[ItemsPerThread],
   397               BinaryFunction compare_function)
   400         bitonic_sort<ItemsPerThread>(compare_function, thread_values);
   // sort -- keys only, ItemsPerThread keys per thread, overload accepting
   // `storage`. The visible body does not use `storage`; it forwards to the
   // overload that takes only the key array and the comparator.
   404         unsigned int ItemsPerThread,
   407     ROCPRIM_DEVICE ROCPRIM_INLINE
   408     void sort(Key (&thread_values)[ItemsPerThread],
   409               storage_type& storage,
   410               BinaryFunction compare_function)
   413         sort(thread_values, compare_function);
   // sort -- key + value, one pair per thread, small-value variant.
   // Selected when sizeof(V) <= sizeof(int) (V defaults to Value): the value is
   // passed through bitonic_sort together with its key.
   416     template<
class BinaryFunction, 
class V = Value>
   417     ROCPRIM_DEVICE ROCPRIM_INLINE
   418     typename std::enable_if<(sizeof(V) <= sizeof(int))>::type
   419     sort(Key& thread_key, Value& thread_value,
   420          BinaryFunction compare_function)
   422         bitonic_sort(compare_function, thread_key, thread_value);
   // sort -- key + value, one pair per thread, large-value variant.
   // Selected when sizeof(V) > sizeof(int). Instead of the value itself, the
   // thread's logical lane id `v` is sorted alongside the key.
   // NOTE(review): the statement that uses `v` afterwards is on a line missing
   // from this listing; presumably thread_value is then fetched from lane `v`
   // -- confirm against the full source.
   425     template<
class BinaryFunction, 
class V = Value>
   426     ROCPRIM_DEVICE ROCPRIM_INLINE
   427     typename std::enable_if<!(sizeof(V) <= sizeof(int))>::type
   428     sort(Key& thread_key, Value& thread_value,
   429          BinaryFunction compare_function)
   432         unsigned int v = detail::logical_lane_id<WarpSize>();
   433         bitonic_sort(compare_function, thread_key, v);
   // sort -- key + value, one pair per thread, overload accepting `storage`.
   // The visible body does not use `storage`; it forwards to the
   // (thread_key, thread_value, compare_function) overload.
   // The arguments are passed in the order that overload declares them
   // (key, value, comparator). Passing the comparator first matches no sort
   // overload in this class and fails to resolve once the template is
   // instantiated.
   437     template<
class BinaryFunction>
   438     ROCPRIM_DEVICE ROCPRIM_INLINE
   439     void sort(Key& thread_key, Value& thread_value,
   440               storage_type& storage, BinaryFunction compare_function)
   443         sort(thread_key, thread_value, compare_function);
   // sort -- key array + value array, small-value variant.
   // Selected when sizeof(V) <= sizeof(int): the value array is passed through
   // the multi-item bitonic_sort together with the key array.
   // Part of the template header is missing from this listing.
   447         unsigned int ItemsPerThread,
   448         class BinaryFunction,
   451     ROCPRIM_DEVICE ROCPRIM_INLINE
   452     typename std::enable_if<(sizeof(V) <= sizeof(int))>::type
   453     sort(Key (&thread_keys)[ItemsPerThread],
   454          Value (&thread_values)[ItemsPerThread],
   455          BinaryFunction compare_function)
   457         bitonic_sort<ItemsPerThread>(compare_function, thread_keys, thread_values);
   // sort -- key array + value array, large-value variant.
   // Selected when sizeof(V) > sizeof(int). The keys are sorted together with
   // compact indices instead of the values; the values are then gathered with
   // warp shuffles according to the sorted indices.
   // Part of the template header is missing from this listing.
   461         unsigned int ItemsPerThread,
   462         class BinaryFunction,
   465     ROCPRIM_DEVICE ROCPRIM_INLINE
   466     typename std::enable_if<!(sizeof(V) <= sizeof(int))>::type
   467     sort(Key (&thread_keys)[ItemsPerThread],
   468          Value (&thread_values)[ItemsPerThread],
   469          BinaryFunction compare_function)
   // Index of every item: ItemsPerThread * lane id + position within the thread.
   472         unsigned int v[ItemsPerThread];
   474         for (
unsigned int item = 0; item < ItemsPerThread; item++)
   476             v[item] = ItemsPerThread * detail::logical_lane_id<WarpSize>() + item;
   // Sort the keys, carrying the indices along as the "values".
   479         bitonic_sort<ItemsPerThread>(compare_function, thread_keys, v);
   // Snapshot of the original values, used as the shuffle source because
   // thread_values is overwritten below.
   481         V copy[ItemsPerThread];
   483         for(
unsigned item = 0; item < ItemsPerThread; ++item) {
   484             copy[item] = thread_values[item];
   // For each destination slot, shuffle every source position from lane
   // v[dst_item] / ItemsPerThread and keep the one whose position equals
   // v[dst_item] % ItemsPerThread.
   488         for(
unsigned int dst_item = 0; dst_item < ItemsPerThread; ++dst_item) {
   490             for(
unsigned src_item = 0; src_item < ItemsPerThread; ++src_item) {
   491                 V temp = 
warp_shuffle(copy[src_item], v[dst_item] / ItemsPerThread, WarpSize);
   492                 if(v[dst_item] % ItemsPerThread == src_item)
   493                     thread_values[dst_item] = temp;
   // sort -- key array + value array, overload accepting `storage`.
   // The visible body does not use `storage`; it forwards to the overload that
   // takes (thread_keys, thread_values, compare_function).
   // Part of the template header is missing from this listing.
   499         unsigned int ItemsPerThread,
   502     ROCPRIM_DEVICE ROCPRIM_INLINE
   503     void sort(Key (&thread_keys)[ItemsPerThread],
   504               Value (&thread_values)[ItemsPerThread],
   505               storage_type& storage, BinaryFunction compare_function)
   508         sort(thread_keys, thread_values, compare_function);
   514 END_ROCPRIM_NAMESPACE
   516 #endif // ROCPRIM_WARP_DETAIL_WARP_SORT_SHUFFLE_HPP_ Definition: warp_sort_shuffle.hpp:42
ROCPRIM_DEVICE ROCPRIM_INLINE T warp_shuffle(const T &input, const int src_lane, const int width=device_warp_size())
Shuffle for any data type. 
Definition: warp_shuffle.hpp:172
Deprecated: Configuration of device-level scan primitives. 
Definition: block_histogram.hpp:62
ROCPRIM_DEVICE ROCPRIM_INLINE T warp_shuffle_xor(const T &input, const int lane_mask, const int width=device_warp_size())
Shuffle XOR for any data type. 
Definition: warp_shuffle.hpp:246
ROCPRIM_DEVICE ROCPRIM_INLINE int get_bit(int x, int i)
Returns a single bit at 'i' from 'x'. 
Definition: bit.hpp:33