21 #ifndef ROCPRIM_WARP_DETAIL_WARP_SORT_SHUFFLE_HPP_ 22 #define ROCPRIM_WARP_DETAIL_WARP_SORT_SHUFFLE_HPP_ 24 #include <type_traits> 26 #include "../../config.hpp" 27 #include "../../detail/various.hpp" 29 #include "../../intrinsics.hpp" 30 #include "../../functional.hpp" 32 BEGIN_ROCPRIM_NAMESPACE
39 unsigned int WarpSize,
// swap (single key + single value), disabled variant.
// Selected by SFINAE when WarpSize <= warp (enable_if<!(WarpSize > warp)>).
// The only statement visible in this listing discards compare_function, so
// the comparator is intentionally unused in this variant.
// NOTE(review): braces and any other interior lines are omitted from this
// listing; presumably the body does nothing else — confirm in the full source.
45 template<
int warp,
class V,
class BinaryFunction>
46 ROCPRIM_DEVICE ROCPRIM_INLINE
47 typename std::enable_if<!(WarpSize > warp)>::type
48 swap(Key& k, V& v,
int mask,
bool dir, BinaryFunction compare_function)
// Marks the comparator as deliberately unused.
54 (void) compare_function;
// swap (single key + single value), active variant.
// Selected by SFINAE when WarpSize > warp.
// NOTE(review): k1 is declared on a line this listing omits — presumably the
// partner lane's key obtained with warp_shuffle_xor using `mask`; confirm
// against the full source.
57 template<
int warp,
class V,
class BinaryFunction>
58 ROCPRIM_DEVICE ROCPRIM_INLINE
59 typename std::enable_if<(WarpSize > warp)>::type
60 swap(Key& k, V& v,
int mask,
bool dir, BinaryFunction compare_function)
// `dir` chooses the operand order of the comparison: (k, k1) when dir is
// true, (k1, k) otherwise. How the result is applied to k and v is on lines
// omitted from this listing — TODO confirm.
64 bool swap = compare_function(dir ? k : k1, dir ? k1 : k);
// swap (arrays of ItemsPerThread keys and values), disabled variant.
// The start of the template header is omitted from this listing.
// Selected by SFINAE when WarpSize <= warp; the one visible statement
// discards compare_function, so the comparator is intentionally unused.
76 unsigned int ItemsPerThread
78 ROCPRIM_DEVICE ROCPRIM_INLINE
79 typename std::enable_if<!(WarpSize > warp)>::type
80 swap(Key (&k)[ItemsPerThread],
81 V (&v)[ItemsPerThread],
84 BinaryFunction compare_function)
// Marks the comparator as deliberately unused.
90 (void) compare_function;
// swap (arrays of ItemsPerThread keys and values), active variant.
// The start of the template header and the mask/dir parameters are omitted
// from this listing. Selected by SFINAE when WarpSize > warp.
97 unsigned int ItemsPerThread
99 ROCPRIM_DEVICE ROCPRIM_INLINE
100 typename std::enable_if<(WarpSize > warp)>::type
101 swap(Key (&k)[ItemsPerThread],
102 V (&v)[ItemsPerThread],
105 BinaryFunction compare_function)
// Per-item scratch keys.
// NOTE(review): presumably filled with the partner lane's keys on an omitted
// line (warp_shuffle_xor with the mask) — confirm.
107 Key k1[ItemsPerThread];
// Every item slot is processed independently.
109 for (
unsigned int item = 0; item < ItemsPerThread; item++)
// Same rule as the scalar variant: dir selects whether the comparison is
// (k[item], k1[item]) or (k1[item], k[item]).
113 bool swap = compare_function(dir ? k[item] : k1[item], dir ? k1[item] : k[item]);
// swap (single key, no value), disabled variant.
// Selected by SFINAE when WarpSize <= warp; the one visible statement
// discards compare_function, so the comparator is intentionally unused.
122 template<
int warp,
class BinaryFunction>
123 ROCPRIM_DEVICE ROCPRIM_INLINE
124 typename std::enable_if<!(WarpSize > warp)>::type
125 swap(Key& k,
int mask,
bool dir, BinaryFunction compare_function)
// Marks the comparator as deliberately unused.
130 (void) compare_function;
// swap (single key, no value), active variant.
// Selected by SFINAE when WarpSize > warp.
// NOTE(review): k1 is declared on a line this listing omits — presumably the
// partner lane's key obtained with warp_shuffle_xor using `mask`; confirm.
133 template<
int warp,
class BinaryFunction>
134 ROCPRIM_DEVICE ROCPRIM_INLINE
135 typename std::enable_if<(WarpSize > warp)>::type
136 swap(Key& k,
int mask,
bool dir, BinaryFunction compare_function)
// dir == true compares (k, k1); dir == false compares (k1, k). The use of
// the result is on omitted lines — TODO confirm.
139 bool swap = compare_function(dir ? k : k1, dir ? k1 : k);
// swap (array of ItemsPerThread keys, no values), disabled variant.
// The start of the template header is omitted from this listing.
// Selected by SFINAE when WarpSize <= warp; the one visible statement
// discards compare_function, so the comparator is intentionally unused.
148 class BinaryFunction,
149 unsigned int ItemsPerThread
151 ROCPRIM_DEVICE ROCPRIM_INLINE
152 typename std::enable_if<!(WarpSize > warp)>::type
153 swap(Key (&k)[ItemsPerThread],
int mask,
bool dir, BinaryFunction compare_function)
// Marks the comparator as deliberately unused.
158 (void) compare_function;
// swap (array of ItemsPerThread keys, no values), active variant.
// The start of the template header is omitted from this listing.
// Selected by SFINAE when WarpSize > warp.
163 class BinaryFunction,
164 unsigned int ItemsPerThread
166 ROCPRIM_DEVICE ROCPRIM_INLINE
167 typename std::enable_if<(WarpSize > warp)>::type
168 swap(Key (&k)[ItemsPerThread],
int mask,
bool dir, BinaryFunction compare_function)
// Per-item scratch keys.
// NOTE(review): presumably filled with the partner lane's keys on an omitted
// line — confirm.
170 Key k1[ItemsPerThread];
// Every item slot is processed independently.
172 for (
unsigned int item = 0; item < ItemsPerThread; item++)
// dir selects whether the comparison is (k[item], k1[item]) or the reverse.
175 bool swap = compare_function(dir ? k[item] : k1[item], dir ? k1[item] : k[item]);
// thread_swap (keys only): conditional exchange between two slots of one
// thread's key array. The index parameters (used below as i and j) and the
// `dir` flag are declared on lines omitted from this listing.
183 template <
unsigned int ItemsPerThread,
class BinaryFunction>
184 ROCPRIM_DEVICE ROCPRIM_INLINE
185 void thread_swap(Key (&k)[ItemsPerThread],
189 BinaryFunction compare_function)
// The guarded statement runs when the comparator's result for (k[i], k[j])
// equals dir.
// NOTE(review): the guarded statement is omitted here — presumably it
// exchanges k[i] and k[j]; confirm.
191 if(compare_function(k[i], k[j]) == dir)
// thread_swap (keys + values): same condition as the key-only overload, with
// a parallel value array. The index parameters (i, j) and `dir` are declared
// on lines omitted from this listing.
199 template <
unsigned int ItemsPerThread,
class V,
class BinaryFunction>
200 ROCPRIM_DEVICE ROCPRIM_INLINE
201 void thread_swap(Key (&k)[ItemsPerThread],
202 V (&v)[ItemsPerThread],
206 BinaryFunction compare_function)
// Only keys take part in the comparison; values are not compared.
// NOTE(review): the guarded statements are omitted — presumably they exchange
// both k[i]/k[j] and v[i]/v[j]; confirm.
208 if(compare_function(k[i], k[j]) == dir)
// thread_shuffle: one compare-exchange pass over a single thread's items.
// Pairs slot (base + i) with slot (base + i + offset) and hands each pair to
// thread_swap. The `offset`, `dir` and `kv` parameters are declared on lines
// omitted from this listing.
219 template <
unsigned int ItemsPerThread,
class BinaryFunction,
class... KeyValue>
220 ROCPRIM_DEVICE ROCPRIM_INLINE
221 void thread_shuffle(
unsigned int group_size,
224 BinaryFunction compare_function,
// Walk the items in blocks of 2 * offset; each block holds `offset` pairs.
228 for(
unsigned int base = 0; base < ItemsPerThread; base += 2 * offset) {
// The direction is inverted for blocks whose base index has the group_size
// bit set, so neighbouring groups get opposite directions.
231 const bool local_dir = ((base & group_size) > 0) != dir;
// Loop vectorization is explicitly disabled on clang >= 15 for the inner loop
// (the matching #endif is on an omitted line).
236 #if defined(__clang_major__) && __clang_major__ >= 15 237 #pragma clang loop vectorize(disable) 239 for(
unsigned i = 0; i < offset; ++i) {
// kv expands to the key array, or to the key and value arrays.
240 thread_swap(kv..., base + i, base + i + offset, local_dir, compare_function);
// thread_sort: orders the ItemsPerThread items held by one thread by running
// thread_shuffle over a doubling group size k and a halving offset j.
// NOTE(review): this loop nest has the shape of a bitonic sorting network;
// `dir` is forwarded to every pass and presumably selects the overall order
// — confirm.
245 template <
unsigned int ItemsPerThread,
class BinaryFunction,
class... KeyValue>
246 ROCPRIM_DEVICE ROCPRIM_INLINE
247 void thread_sort(
bool dir, BinaryFunction compare_function, KeyValue&... kv)
// Group size: 2, 4, 8, ... up to ItemsPerThread.
250 for(
unsigned int k = 2; k <= ItemsPerThread; k *= 2)
// Pair distance within a group: k/2, k/4, ..., 1.
253 for(
unsigned int j = k / 2; j > 0; j /= 2)
255 thread_shuffle<ItemsPerThread>(k, j, dir, compare_function, kv...);
// thread_merge, active variant (enabled when WarpSize > warp).
// Runs thread_shuffle with the group size fixed at ItemsPerThread and the
// pair distance halving from ItemsPerThread / 2 down to 1.
260 template <
int warp,
unsigned int ItemsPerThread,
class BinaryFunction,
class... KeyValue>
261 ROCPRIM_DEVICE ROCPRIM_INLINE
262 typename std::enable_if<(WarpSize > warp)>::type
263 thread_merge(
bool dir, BinaryFunction compare_function, KeyValue&... kv)
// Pair distance: ItemsPerThread/2, ItemsPerThread/4, ..., 1.
266 for(
unsigned int j = ItemsPerThread / 2; j > 0; j /= 2)
268 thread_shuffle<ItemsPerThread>(ItemsPerThread, j, dir, compare_function, kv...);
// thread_merge, disabled variant (selected when WarpSize <= warp).
// All parameters are unnamed, so none of them can be used.
// NOTE(review): the body is omitted from this listing — presumably empty;
// confirm.
272 template <
int warp,
unsigned int ItemsPerThread,
class BinaryFunction,
class... KeyValue>
273 ROCPRIM_DEVICE ROCPRIM_INLINE
274 typename std::enable_if<!(WarpSize > warp)>::type
275 thread_merge(
bool , BinaryFunction , KeyValue&... )
// bitonic_sort (one item per thread).
// kv is either (key) or (key, value). The function issues a fixed sequence of
// swap<N>(kv..., mask, dir, compare_function) calls. Because of the enable_if
// on the swap overloads, a swap<N> call uses the active overload only when
// WarpSize > N and otherwise resolves to the disabled overload.
279 template<
class BinaryFunction,
class... KeyValue>
280 ROCPRIM_DEVICE ROCPRIM_INLINE
281 void bitonic_sort(BinaryFunction compare_function, KeyValue&... kv)
// Compile-time check that the pack has at most two elements (the opening
// `static_assert(` is on a line omitted from this listing).
284 sizeof...(KeyValue) < 3,
// `id` is this thread's logical lane id for a logical warp of WarpSize lanes;
// its individual bits drive the direction of every step below.
285 "KeyValue parameter pack can 1 or 2 elements (key, or key and value)" 288 unsigned int id = detail::logical_lane_id<WarpSize>();
// N = 2: mask 1; direction = bit 1 of id XOR bit 0 of id.
289 swap< 2>(kv..., 1,
get_bit(
id, 1) !=
get_bit(
id, 0), compare_function);
// N = 4: masks 2, 1; direction = bit 2 XOR the bit matching the mask.
291 swap< 4>(kv..., 2,
get_bit(
id, 2) !=
get_bit(
id, 1), compare_function);
292 swap< 4>(kv..., 1,
get_bit(
id, 2) !=
get_bit(
id, 0), compare_function);
// N = 8: masks 4, 2, 1; direction = bit 3 XOR the bit matching the mask.
294 swap< 8>(kv..., 4,
get_bit(
id, 3) !=
get_bit(
id, 2), compare_function);
295 swap< 8>(kv..., 2,
get_bit(
id, 3) !=
get_bit(
id, 1), compare_function);
296 swap< 8>(kv..., 1,
get_bit(
id, 3) !=
get_bit(
id, 0), compare_function);
// N = 16: masks 8, 4, 2, 1; direction = bit 4 XOR the bit matching the mask.
298 swap<16>(kv..., 8,
get_bit(
id, 4) !=
get_bit(
id, 3), compare_function);
299 swap<16>(kv..., 4,
get_bit(
id, 4) !=
get_bit(
id, 2), compare_function);
300 swap<16>(kv..., 2,
get_bit(
id, 4) !=
get_bit(
id, 1), compare_function);
301 swap<16>(kv..., 1,
get_bit(
id, 4) !=
get_bit(
id, 0), compare_function);
// N = 32: masks 16, 8, 4, 2, 1; direction = bit 5 XOR the bit matching the mask.
303 swap<32>(kv..., 16,
get_bit(
id, 5) !=
get_bit(
id, 4), compare_function);
304 swap<32>(kv..., 8,
get_bit(
id, 5) !=
get_bit(
id, 3), compare_function);
305 swap<32>(kv..., 4,
get_bit(
id, 5) !=
get_bit(
id, 2), compare_function);
306 swap<32>(kv..., 2,
get_bit(
id, 5) !=
get_bit(
id, 1), compare_function);
307 swap<32>(kv..., 1,
get_bit(
id, 5) !=
get_bit(
id, 0), compare_function);
// Final pass: masks 32, 16, 8, 4, 2, 1. Here the direction is just the single
// lane-id bit that matches the mask (no XOR with a higher bit). The template
// argument equals the mask except for the last call, which uses 0 so that the
// mask-1 step is enabled for every WarpSize > 0.
309 swap<32>(kv..., 32,
get_bit(
id, 5) != 0, compare_function);
310 swap<16>(kv..., 16,
get_bit(
id, 4) != 0, compare_function);
311 swap< 8>(kv..., 8,
get_bit(
id, 3) != 0, compare_function);
312 swap< 4>(kv..., 4,
get_bit(
id, 2) != 0, compare_function);
313 swap< 2>(kv..., 2,
get_bit(
id, 1) != 0, compare_function);
314 swap< 0>(kv..., 1,
get_bit(
id, 0) != 0, compare_function);
// bitonic_sort (ItemsPerThread items per thread).
// The start of the template header is omitted from this listing. The sequence
// of cross-lane swap<N> steps is the same as in the one-item overload, with
// per-thread work interleaved: thread_sort first, then a thread_merge<N> after
// each group of swap<N> steps. swap<N> and thread_merge<N> use their active
// overloads only when WarpSize > N (see their enable_if conditions).
318 unsigned int ItemsPerThread,
319 class BinaryFunction,
322 ROCPRIM_DEVICE ROCPRIM_INLINE
323 void bitonic_sort(BinaryFunction compare_function, KeyValue&... kv)
// Compile-time check that the pack has at most two elements (the opening
// `static_assert(` is on a line omitted from this listing).
326 sizeof...(KeyValue) < 3,
// ItemsPerThread must be a power of two; enforced at compile time.
327 "KeyValue parameter pack can 1 or 2 elements (key, or key and value)" 330 static_assert(detail::is_power_of_two(ItemsPerThread),
"ItemsPerThread must be power of 2");
// `id` is this thread's logical lane id; its bits drive every direction below.
332 unsigned int id = detail::logical_lane_id<WarpSize>();
// Order this thread's own items first; direction = bit 0 of id.
333 thread_sort<ItemsPerThread>(
get_bit(
id, 0) != 0, compare_function, kv...);
// N = 2: cross-lane step with mask 1, then a per-thread merge with
// direction = bit 1 of id.
335 swap< 2>(kv..., 1,
get_bit(
id, 1) !=
get_bit(
id, 0), compare_function);
336 thread_merge<2, ItemsPerThread>(
get_bit(
id, 1) != 0, compare_function, kv...);
// N = 4: masks 2, 1, then a per-thread merge with direction = bit 2 of id.
338 swap< 4>(kv..., 2,
get_bit(
id, 2) !=
get_bit(
id, 1), compare_function);
339 swap< 4>(kv..., 1,
get_bit(
id, 2) !=
get_bit(
id, 0), compare_function);
340 thread_merge<4, ItemsPerThread>(
get_bit(
id, 2) != 0, compare_function, kv...);
// N = 8: masks 4, 2, 1, then a per-thread merge with direction = bit 3 of id.
342 swap< 8>(kv..., 4,
get_bit(
id, 3) !=
get_bit(
id, 2), compare_function);
343 swap< 8>(kv..., 2,
get_bit(
id, 3) !=
get_bit(
id, 1), compare_function);
344 swap< 8>(kv..., 1,
get_bit(
id, 3) !=
get_bit(
id, 0), compare_function);
345 thread_merge<8, ItemsPerThread>(
get_bit(
id, 3) != 0, compare_function, kv...);
// N = 16: masks 8, 4, 2, 1, then a per-thread merge with direction = bit 4 of id.
347 swap<16>(kv..., 8,
get_bit(
id, 4) !=
get_bit(
id, 3), compare_function);
348 swap<16>(kv..., 4,
get_bit(
id, 4) !=
get_bit(
id, 2), compare_function);
349 swap<16>(kv..., 2,
get_bit(
id, 4) !=
get_bit(
id, 1), compare_function);
350 swap<16>(kv..., 1,
get_bit(
id, 4) !=
get_bit(
id, 0), compare_function);
351 thread_merge<16, ItemsPerThread>(
get_bit(
id, 4) != 0, compare_function, kv...);
// N = 32: masks 16, 8, 4, 2, 1, then a per-thread merge with direction = bit 5 of id.
353 swap<32>(kv..., 16,
get_bit(
id, 5) !=
get_bit(
id, 4), compare_function);
354 swap<32>(kv..., 8,
get_bit(
id, 5) !=
get_bit(
id, 3), compare_function);
355 swap<32>(kv..., 4,
get_bit(
id, 5) !=
get_bit(
id, 2), compare_function);
356 swap<32>(kv..., 2,
get_bit(
id, 5) !=
get_bit(
id, 1), compare_function);
357 swap<32>(kv..., 1,
get_bit(
id, 5) !=
get_bit(
id, 0), compare_function);
358 thread_merge<32, ItemsPerThread>(
get_bit(
id, 5) != 0, compare_function, kv...);
// Final cross-lane pass: masks 32, 16, 8, 4, 2, 1 with the direction taken
// from the single lane-id bit that matches the mask. The last call uses
// template argument 0 so that the mask-1 step is enabled for every WarpSize > 0.
360 swap<32>(kv..., 32,
get_bit(
id, 5) != 0, compare_function);
361 swap<16>(kv..., 16,
get_bit(
id, 4) != 0, compare_function);
362 swap< 8>(kv..., 8,
get_bit(
id, 3) != 0, compare_function);
363 swap< 4>(kv..., 4,
get_bit(
id, 2) != 0, compare_function);
364 swap< 2>(kv..., 2,
get_bit(
id, 1) != 0, compare_function);
365 swap< 0>(kv..., 1,
get_bit(
id, 0) != 0, compare_function);
// Closing per-thread merge with a fixed direction of false; it uses the
// active thread_merge overload whenever WarpSize > 1.
366 thread_merge<1, ItemsPerThread>(
false, compare_function, kv...);
// WarpSize must be a power of two; enforced at compile time.
370 static_assert(detail::is_power_of_two(WarpSize),
"WarpSize must be power of 2");
// storage_type aliases the library's empty_storage_type. The sort overloads
// below that accept a storage_type& forward to the overloads without it.
372 using storage_type = ::rocprim::detail::empty_storage_type;
// sort (one key per thread): delegates to bitonic_sort with the key as the
// only element of the pack. thread_value is updated in place through the
// reference.
374 template<
class BinaryFunction>
375 ROCPRIM_DEVICE ROCPRIM_INLINE
376 void sort(Key& thread_value, BinaryFunction compare_function)
379 bitonic_sort(compare_function, thread_value);
// sort (one key per thread, with storage argument): forwards to the
// overload without storage. `storage` is not referenced by the visible call.
382 template<
class BinaryFunction>
383 ROCPRIM_DEVICE ROCPRIM_INLINE
384 void sort(Key& thread_value, storage_type& storage,
385 BinaryFunction compare_function)
388 sort(thread_value, compare_function);
// sort (ItemsPerThread keys per thread): delegates to the ItemsPerThread
// overload of bitonic_sort with the key array as the only pack element.
// The start of the template header is omitted from this listing.
392 unsigned int ItemsPerThread,
395 ROCPRIM_DEVICE ROCPRIM_INLINE
396 void sort(Key (&thread_values)[ItemsPerThread],
397 BinaryFunction compare_function)
400 bitonic_sort<ItemsPerThread>(compare_function, thread_values);
// sort (ItemsPerThread keys per thread, with storage argument): forwards to
// the overload without storage. `storage` is not referenced by the visible
// call. The start of the template header is omitted from this listing.
404 unsigned int ItemsPerThread,
407 ROCPRIM_DEVICE ROCPRIM_INLINE
408 void sort(Key (&thread_values)[ItemsPerThread],
409 storage_type& storage,
410 BinaryFunction compare_function)
413 sort(thread_values, compare_function);
// sort (one key + one value per thread), small-value variant.
// Enabled when sizeof(V) <= sizeof(int), where V defaults to Value. The value
// is passed through bitonic_sort together with its key.
416 template<
class BinaryFunction,
class V = Value>
417 ROCPRIM_DEVICE ROCPRIM_INLINE
418 typename std::enable_if<(sizeof(V) <= sizeof(int))>::type
419 sort(Key& thread_key, Value& thread_value,
420 BinaryFunction compare_function)
422 bitonic_sort(compare_function, thread_key, thread_value);
// sort (one key + one value per thread), large-value variant.
// Enabled when sizeof(V) > sizeof(int). The value itself is not passed
// through bitonic_sort; an unsigned int holding the lane id is sorted
// alongside the key in its place.
425 template<
class BinaryFunction,
class V = Value>
426 ROCPRIM_DEVICE ROCPRIM_INLINE
427 typename std::enable_if<!(sizeof(V) <= sizeof(int))>::type
428 sort(Key& thread_key, Value& thread_value,
429 BinaryFunction compare_function)
// v starts as this thread's logical lane id.
432 unsigned int v = detail::logical_lane_id<WarpSize>();
// v is sorted together with the key.
// NOTE(review): the step that uses v to update thread_value is on lines
// omitted from this listing — presumably a warp_shuffle from lane v; confirm.
433 bitonic_sort(compare_function, thread_key, v);
437 template<
class BinaryFunction>
438 ROCPRIM_DEVICE ROCPRIM_INLINE
439 void sort(Key& thread_key, Value& thread_value,
440 storage_type& storage, BinaryFunction compare_function)
443 sort(compare_function, thread_key, thread_value);
// sort (ItemsPerThread keys + values per thread), small-value variant.
// Enabled when sizeof(V) <= sizeof(int). Both arrays are passed through the
// ItemsPerThread overload of bitonic_sort together. The start of the template
// header (including the declaration of V) is omitted from this listing.
447 unsigned int ItemsPerThread,
448 class BinaryFunction,
451 ROCPRIM_DEVICE ROCPRIM_INLINE
452 typename std::enable_if<(sizeof(V) <= sizeof(int))>::type
453 sort(Key (&thread_keys)[ItemsPerThread],
454 Value (&thread_values)[ItemsPerThread],
455 BinaryFunction compare_function)
457 bitonic_sort<ItemsPerThread>(compare_function, thread_keys, thread_values);
// sort (ItemsPerThread keys + values per thread), large-value variant.
// Enabled when sizeof(V) > sizeof(int). The values are not passed through
// bitonic_sort. Each item is tagged with an index, the tags are sorted
// together with the keys, and the values are then fetched according to the
// sorted tags. The start of the template header (including the declaration
// of V) is omitted from this listing.
461 unsigned int ItemsPerThread,
462 class BinaryFunction,
465 ROCPRIM_DEVICE ROCPRIM_INLINE
466 typename std::enable_if<!(sizeof(V) <= sizeof(int))>::type
467 sort(Key (&thread_keys)[ItemsPerThread],
468 Value (&thread_values)[ItemsPerThread],
469 BinaryFunction compare_function)
// Tag of each item: ItemsPerThread * lane_id + item, so tag / ItemsPerThread
// recovers the source lane and tag % ItemsPerThread the source slot.
472 unsigned int v[ItemsPerThread];
474 for (
unsigned int item = 0; item < ItemsPerThread; item++)
476 v[item] = ItemsPerThread * detail::logical_lane_id<WarpSize>() + item;
// Sort the keys and carry the tags along with them.
479 bitonic_sort<ItemsPerThread>(compare_function, thread_keys, v);
// Snapshot the original values so that writes to thread_values below cannot
// disturb what other lanes read.
481 V copy[ItemsPerThread];
483 for(
unsigned item = 0; item < ItemsPerThread; ++item) {
484 copy[item] = thread_values[item];
// For every destination slot, try every source slot. The shuffle is issued
// for each src_item regardless of whether this lane needs it; only the result
// whose slot matches the tag is kept.
488 for(
unsigned int dst_item = 0; dst_item < ItemsPerThread; ++dst_item) {
490 for(
unsigned src_item = 0; src_item < ItemsPerThread; ++src_item) {
// Read copy[src_item] from the lane encoded in this destination's tag.
491 V temp =
warp_shuffle(copy[src_item], v[dst_item] / ItemsPerThread, WarpSize);
// Keep the fetched value only if src_item is the slot encoded in the tag.
492 if(v[dst_item] % ItemsPerThread == src_item)
493 thread_values[dst_item] = temp;
// sort (ItemsPerThread keys + values per thread, with storage argument):
// forwards to the overload without storage. `storage` is not referenced by
// the visible call. The start of the template header is omitted from this
// listing.
499 unsigned int ItemsPerThread,
502 ROCPRIM_DEVICE ROCPRIM_INLINE
503 void sort(Key (&thread_keys)[ItemsPerThread],
504 Value (&thread_values)[ItemsPerThread],
505 storage_type& storage, BinaryFunction compare_function)
508 sort(thread_keys, thread_values, compare_function);
514 END_ROCPRIM_NAMESPACE
516 #endif // ROCPRIM_WARP_DETAIL_WARP_SORT_SHUFFLE_HPP_ Definition: warp_sort_shuffle.hpp:42
ROCPRIM_DEVICE ROCPRIM_INLINE T warp_shuffle(const T &input, const int src_lane, const int width=device_warp_size())
Shuffle for any data type.
Definition: warp_shuffle.hpp:172
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
ROCPRIM_DEVICE ROCPRIM_INLINE T warp_shuffle_xor(const T &input, const int lane_mask, const int width=device_warp_size())
Shuffle XOR for any data type.
Definition: warp_shuffle.hpp:246
ROCPRIM_DEVICE ROCPRIM_INLINE int get_bit(int x, int i)
Returns a single bit at 'i' from 'x'.
Definition: bit.hpp:33