// rocPRIM
// block_sort_bitonic.hpp
1 // Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_SORT_SHARED_HPP_
22 #define ROCPRIM_BLOCK_DETAIL_BLOCK_SORT_SHARED_HPP_
23 
24 #include <type_traits>
25 
26 #include "../../config.hpp"
27 #include "../../detail/various.hpp"
28 
29 #include "../../intrinsics.hpp"
30 #include "../../functional.hpp"
31 
32 #include "../../warp/warp_sort.hpp"
33 
34 BEGIN_ROCPRIM_NAMESPACE
35 
36 namespace detail
37 {
38 
39 template<
40  class Key,
41  unsigned int BlockSizeX,
42  unsigned int BlockSizeY,
43  unsigned int BlockSizeZ,
44  unsigned int ItemsPerThread,
45  class Value
46 >
48 {
49  static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
50  static constexpr unsigned int ItemsPerBlock = BlockSize * ItemsPerThread;
51 
52  template<class KeyType, class ValueType>
53  struct storage_type_
54  {
55  KeyType key[BlockSize * ItemsPerThread];
56  ValueType value[BlockSize * ItemsPerThread];
57  };
58 
59  template<class KeyType>
60  struct storage_type_<KeyType, empty_type>
61  {
62  KeyType key[BlockSize * ItemsPerThread];
63  };
64 
65 public:
67 
68  template <class BinaryFunction>
69  ROCPRIM_DEVICE ROCPRIM_INLINE
70  void sort(Key& thread_key,
71  storage_type& storage,
72  BinaryFunction compare_function)
73  {
74  this->sort_impl<BlockSize, ItemsPerThread>(
75  ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
76  storage,
77  compare_function,
78  thread_key);
79  }
80 
81  template<class BinaryFunction>
82  ROCPRIM_DEVICE ROCPRIM_INLINE
83  void sort(Key (&thread_keys)[ItemsPerThread],
84  storage_type& storage,
85  BinaryFunction compare_function)
86  {
87  this->sort_impl<BlockSize, ItemsPerThread>(
88  ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
89  storage,
90  compare_function,
91  thread_keys);
92  }
93 
94  template<class BinaryFunction>
95  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
96  void sort(Key& thread_key,
97  BinaryFunction compare_function)
98  {
99  ROCPRIM_SHARED_MEMORY storage_type storage;
100  this->sort(thread_key, storage, compare_function);
101  }
102 
103  template<class BinaryFunction>
104  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
105  void sort(Key (&thread_keys)[ItemsPerThread],
106  BinaryFunction compare_function)
107  {
108  ROCPRIM_SHARED_MEMORY storage_type storage;
109  this->sort(thread_keys, storage, compare_function);
110  }
111 
112  template<class BinaryFunction>
113  ROCPRIM_DEVICE ROCPRIM_INLINE
114  void sort(Key& thread_key,
115  Value& thread_value,
116  storage_type& storage,
117  BinaryFunction compare_function)
118  {
119  this->sort_impl<BlockSize, ItemsPerThread>(
120  ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
121  storage,
122  compare_function,
123  thread_key,
124  thread_value);
125  }
126 
127  template<class BinaryFunction>
128  ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key (&thread_keys)[ItemsPerThread],
129  Value (&thread_values)[ItemsPerThread],
130  storage_type& storage,
131  BinaryFunction compare_function)
132  {
133  this->sort_impl<BlockSize, ItemsPerThread>(
134  ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
135  storage,
136  compare_function,
137  thread_keys,
138  thread_values);
139  }
140 
141  template<class BinaryFunction>
142  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
143  void sort(Key& thread_key,
144  Value& thread_value,
145  BinaryFunction compare_function)
146  {
147  ROCPRIM_SHARED_MEMORY storage_type storage;
148  this->sort(thread_key, thread_value, storage, compare_function);
149  }
150 
151  template<class BinaryFunction>
152  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
153  void sort(Key (&thread_keys)[ItemsPerThread],
154  Value (&thread_values)[ItemsPerThread],
155  BinaryFunction compare_function)
156  {
157  ROCPRIM_SHARED_MEMORY storage_type storage;
158  this->sort(thread_keys, thread_values, storage, compare_function);
159  }
160 
161  template<class BinaryFunction>
162  ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key& thread_key,
163  storage_type& storage,
164  const unsigned int size,
165  BinaryFunction compare_function)
166  {
167  this->sort_impl(::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
168  size,
169  storage,
170  compare_function,
171  thread_key);
172  }
173 
174  template<class BinaryFunction>
175  ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key (&thread_keys)[ItemsPerThread],
176  storage_type& storage,
177  const unsigned int size,
178  BinaryFunction compare_function)
179  {
180  this->sort_impl(::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
181  size,
182  storage,
183  compare_function,
184  thread_keys);
185  }
186 
187  template<class BinaryFunction>
188  ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key& thread_key,
189  Value& thread_value,
190  storage_type& storage,
191  const unsigned int size,
192  BinaryFunction compare_function)
193  {
194  this->sort_impl(::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
195  size,
196  storage,
197  compare_function,
198  thread_key,
199  thread_value);
200  }
201 
202  template<class BinaryFunction>
203  ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key (&thread_keys)[ItemsPerThread],
204  Value (&thread_values)[ItemsPerThread],
205  storage_type& storage,
206  const unsigned int size,
207  BinaryFunction compare_function)
208  {
209  this->sort_impl(::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
210  size,
211  storage,
212  compare_function,
213  thread_keys,
214  thread_values);
215  }
216 
217 private:
218  ROCPRIM_DEVICE ROCPRIM_INLINE
219  void copy_to_shared(Key& k, const unsigned int flat_tid, storage_type& storage)
220  {
221  storage_type_<Key, Value>& storage_ = storage.get();
222  storage_.key[flat_tid] = k;
224  }
225 
226  ROCPRIM_DEVICE ROCPRIM_INLINE
227  void copy_to_shared(Key (&k)[ItemsPerThread], const unsigned int flat_tid, storage_type& storage) {
228  storage_type_<Key, Value>& storage_ = storage.get();
229  ROCPRIM_UNROLL
230  for(unsigned int item = 0; item < ItemsPerThread; ++item) {
231  storage_.key[item * BlockSize + flat_tid] = k[item];
232  }
234  }
235 
236  ROCPRIM_DEVICE ROCPRIM_INLINE
237  void copy_to_shared(Key& k, Value& v, const unsigned int flat_tid, storage_type& storage)
238  {
239  storage_type_<Key, Value>& storage_ = storage.get();
240  storage_.key[flat_tid] = k;
241  storage_.value[flat_tid] = v;
243  }
244 
245  ROCPRIM_DEVICE ROCPRIM_INLINE
246  void copy_to_shared(Key (&k)[ItemsPerThread],
247  Value (&v)[ItemsPerThread],
248  const unsigned int flat_tid,
249  storage_type& storage)
250  {
251  storage_type_<Key, Value>& storage_ = storage.get();
252  ROCPRIM_UNROLL
253  for(unsigned int item = 0; item < ItemsPerThread; ++item) {
254  storage_.key[item * BlockSize + flat_tid] = k[item];
255  storage_.value[item * BlockSize + flat_tid] = v[item];
256  }
258  }
259 
260  template<class BinaryFunction>
261  ROCPRIM_DEVICE ROCPRIM_INLINE
262  void swap(Key& key,
263  const unsigned int flat_tid,
264  const unsigned int next_id,
265  const bool dir,
266  storage_type& storage,
267  BinaryFunction compare_function)
268  {
269  storage_type_<Key, Value>& storage_ = storage.get();
270  Key next_key = storage_.key[next_id];
271  bool compare = (next_id < flat_tid) ? compare_function(key, next_key) : compare_function(next_key, key);
272  bool swap = compare ^ dir;
273  if(swap)
274  {
275  key = next_key;
276  }
277  }
278 
279  template<class BinaryFunction>
280  ROCPRIM_DEVICE ROCPRIM_INLINE
281  void swap(Key (&key)[ItemsPerThread],
282  const unsigned int flat_tid,
283  const unsigned int next_id,
284  const bool dir,
285  storage_type& storage,
286  BinaryFunction compare_function)
287  {
288  storage_type_<Key, Value>& storage_ = storage.get();
289  ROCPRIM_UNROLL
290  for(unsigned int item = 0; item < ItemsPerThread; ++item) {
291  Key next_key = storage_.key[item * BlockSize + next_id];
292  bool compare = (next_id < flat_tid) ? compare_function(key[item], next_key) : compare_function(next_key, key[item]);
293  bool swap = compare ^ dir;
294  if(swap)
295  {
296  key[item] = next_key;
297  }
298  }
299  }
300 
301  template<class BinaryFunction>
302  ROCPRIM_DEVICE ROCPRIM_INLINE
303  void swap(Key& key,
304  Value& value,
305  const unsigned int flat_tid,
306  const unsigned int next_id,
307  const bool dir,
308  storage_type& storage,
309  BinaryFunction compare_function)
310  {
311  storage_type_<Key, Value>& storage_ = storage.get();
312  Key next_key = storage_.key[next_id];
313  bool b = next_id < flat_tid;
314  bool compare = compare_function(b ? key : next_key, b ? next_key : key);
315  bool swap = compare ^ dir;
316  if(swap)
317  {
318  key = next_key;
319  value = storage_.value[next_id];
320  }
321  }
322 
323  template<class BinaryFunction>
324  ROCPRIM_DEVICE ROCPRIM_INLINE
325  void swap(Key (&key)[ItemsPerThread],
326  Value (&value)[ItemsPerThread],
327  const unsigned int flat_tid,
328  const unsigned int next_id,
329  const bool dir,
330  storage_type& storage,
331  BinaryFunction compare_function)
332  {
333  storage_type_<Key, Value>& storage_ = storage.get();
334  ROCPRIM_UNROLL
335  for(unsigned int item = 0; item < ItemsPerThread; ++item) {
336  Key next_key = storage_.key[item * BlockSize + next_id];
337  bool b = next_id < flat_tid;
338  bool compare = compare_function(b ? key[item] : next_key, b ? next_key : key[item]);
339  bool swap = compare ^ dir;
340  if(swap)
341  {
342  key[item] = next_key;
343  value[item] = storage_.value[item * BlockSize + next_id];
344  }
345  }
346  }
347 
348  template<class BinaryFunction>
349  ROCPRIM_DEVICE ROCPRIM_INLINE void swap_oddeven(Key& key,
350  const unsigned int next_id,
351  const unsigned int /* item */,
352  const unsigned int next_item_id,
353  bool dir,
354  storage_type& storage,
355  BinaryFunction compare_function)
356  {
357  storage_type_<Key, Value>& storage_ = storage.get();
358  Key next_key = storage_.key[next_item_id * BlockSize + next_id];
359  // swap items instead of branching for compare function, to achieve superior perf (ROCm 5.3)
360  if(dir)
361  {
362  rocprim::swap(next_key, key);
363  }
364  bool swap = compare_function(next_key, key);
365  if(dir)
366  {
367  rocprim::swap(next_key, key);
368  }
369  if(swap)
370  {
371  key = next_key;
372  }
373  }
374 
375  template<class BinaryFunction>
376  ROCPRIM_DEVICE ROCPRIM_INLINE void swap_oddeven(Key (&keys)[ItemsPerThread],
377  const unsigned int next_id,
378  const unsigned int item,
379  const unsigned int next_item_id,
380  bool dir,
381  storage_type& storage,
382  BinaryFunction compare_function)
383  {
384  storage_type_<Key, Value>& storage_ = storage.get();
385  Key next_key = storage_.key[next_item_id * BlockSize + next_id];
386  // swap items instead of branching for compare function, to achieve superior perf (ROCm 5.3)
387  if(dir)
388  {
389  rocprim::swap(next_key, keys[item]);
390  }
391  bool swap = compare_function(next_key, keys[item]);
392  if(dir)
393  {
394  rocprim::swap(next_key, keys[item]);
395  }
396  if(swap)
397  {
398  keys[item] = next_key;
399  }
400  }
401 
402  template<class BinaryFunction>
403  ROCPRIM_DEVICE ROCPRIM_INLINE void swap_oddeven(Key& key,
404  Value& value,
405  const unsigned int next_id,
406  const unsigned int /* item */,
407  const unsigned int next_item_id,
408  bool dir,
409  storage_type& storage,
410  BinaryFunction compare_function)
411  {
412  storage_type_<Key, Value>& storage_ = storage.get();
413  Key next_key = storage_.key[next_item_id * BlockSize + next_id];
414  // swap items instead of branching for compare function, to achieve superior perf (ROCm 5.3)
415  if(dir)
416  {
417  rocprim::swap(next_key, key);
418  }
419  bool swap = compare_function(next_key, key);
420  if(dir)
421  {
422  rocprim::swap(next_key, key);
423  }
424  if(swap)
425  {
426  key = next_key;
427  value = storage_.value[next_item_id * BlockSize + next_id];
428  }
429  }
430 
431  template<class BinaryFunction>
432  ROCPRIM_DEVICE ROCPRIM_INLINE void swap_oddeven(Key (&keys)[ItemsPerThread],
433  Value (&values)[ItemsPerThread],
434  const unsigned int next_id,
435  const unsigned int item,
436  const unsigned int next_item_id,
437  bool dir,
438  storage_type& storage,
439  BinaryFunction compare_function)
440  {
441  storage_type_<Key, Value>& storage_ = storage.get();
442  Key next_key = storage_.key[next_item_id * BlockSize + next_id];
443  // swap items instead of branching for compare function, to achieve superior perf (ROCm 5.3)
444  if(dir)
445  {
446  rocprim::swap(next_key, keys[item]);
447  }
448  bool swap = compare_function(next_key, keys[item]);
449  if(dir)
450  {
451  rocprim::swap(next_key, keys[item]);
452  }
453  if(swap)
454  {
455  keys[item] = next_key;
456  values[item] = storage_.value[next_item_id * BlockSize + next_id];
457  }
458  }
459 
460  template<
461  unsigned int Size,
462  class BinaryFunction,
463  class... KeyValue
464  >
465  ROCPRIM_DEVICE ROCPRIM_INLINE
466  typename std::enable_if<(Size <= ::rocprim::device_warp_size())>::type
467  sort_power_two(const unsigned int flat_tid,
468  storage_type& storage,
469  BinaryFunction compare_function,
470  KeyValue&... kv)
471  {
472  (void) flat_tid;
473  (void) storage;
474 
475  ::rocprim::warp_sort<Key, Size, Value> wsort;
476  wsort.sort(kv..., compare_function);
477  }
478 
479  template<class BinaryFunction>
480  ROCPRIM_DEVICE ROCPRIM_INLINE
481  void warp_swap(Key& k, Value& v, int mask, bool dir, BinaryFunction compare_function)
482  {
483  Key k1 = warp_shuffle_xor(k, mask);
484  bool swap = compare_function(dir ? k : k1, dir ? k1 : k);
485  if (swap)
486  {
487  k = k1;
488  v = warp_shuffle_xor(v, mask);
489  }
490  }
491 
492  template <class BinaryFunction>
493  ROCPRIM_DEVICE ROCPRIM_INLINE
494  void warp_swap(Key (&k)[ItemsPerThread],
495  Value (&v)[ItemsPerThread],
496  int mask,
497  bool dir,
498  BinaryFunction compare_function)
499  {
500  ROCPRIM_UNROLL
501  for(unsigned int item = 0; item < ItemsPerThread; ++item) {
502  Key k1 = warp_shuffle_xor(k[item], mask);
503  bool swap = compare_function(dir ? k[item] : k1, dir ? k1 : k[item]);
504  if (swap)
505  {
506  k[item] = k1;
507  v[item] = warp_shuffle_xor(v[item], mask);
508  }
509  }
510  }
511 
512  template<class BinaryFunction>
513  ROCPRIM_DEVICE ROCPRIM_INLINE
514  void warp_swap(Key& k, int mask, bool dir, BinaryFunction compare_function)
515  {
516  Key k1 = warp_shuffle_xor(k, mask);
517  bool swap = compare_function(dir ? k : k1, dir ? k1 : k);
518  if (swap)
519  {
520  k = k1;
521  }
522  }
523 
524  template <class BinaryFunction>
525  ROCPRIM_DEVICE ROCPRIM_INLINE
526  void warp_swap(Key (&k)[ItemsPerThread], int mask, bool dir, BinaryFunction compare_function)
527  {
528  ROCPRIM_UNROLL
529  for(unsigned int item = 0; item < ItemsPerThread; ++item) {
530  Key k1 = warp_shuffle_xor(k[item], mask);
531  bool swap = compare_function(dir ? k[item] : k1, dir ? k1 : k[item]);
532  if (swap)
533  {
534  k[item] = k1;
535  }
536  }
537  }
538 
539  template <class BinaryFunction, unsigned int Items = ItemsPerThread, class... KeyValue>
540  ROCPRIM_DEVICE ROCPRIM_INLINE
541  typename std::enable_if<(Items < 2)>::type
542  thread_merge(bool /*dir*/, BinaryFunction /*compare_function*/, KeyValue&... /*kv*/)
543  {
544  }
545 
546  template <class BinaryFunction>
547  ROCPRIM_DEVICE ROCPRIM_INLINE
548  void thread_swap(Key (&k)[ItemsPerThread],
549  Value (&v)[ItemsPerThread],
550  bool dir,
551  unsigned int i,
552  unsigned int j,
553  BinaryFunction compare_function)
554  {
555  if(compare_function(k[i], k[j]) == dir)
556  {
557  Key k_temp = k[i];
558  k[i] = k[j];
559  k[j] = k_temp;
560  Value v_temp = v[i];
561  v[i] = v[j];
562  v[j] = v_temp;
563  }
564  }
565  template <class BinaryFunction>
566  ROCPRIM_DEVICE ROCPRIM_INLINE
567  void thread_swap(Key (&k)[ItemsPerThread],
568  bool dir,
569  unsigned int i,
570  unsigned int j,
571  BinaryFunction compare_function)
572  {
573  if(compare_function(k[i], k[j]) == dir)
574  {
575  Key k_temp = k[i];
576  k[i] = k[j];
577  k[j] = k_temp;
578  }
579  }
580 
581  template <class BinaryFunction, class... KeyValue>
582  ROCPRIM_DEVICE ROCPRIM_INLINE
583  void thread_shuffle(unsigned int offset, bool dir, BinaryFunction compare_function, KeyValue&... kv)
584  {
585  ROCPRIM_UNROLL
586  for(unsigned base = 0; base < ItemsPerThread; base += 2 * offset)
587  {
588  ROCPRIM_UNROLL
589 // Workaround to prevent the compiler thinking this is a 'Parallel Loop' on clang 15
590 // because it leads to invalid code generation with `T` = `char` and `ItemsPerthread` = 4
591 #if defined(__clang_major__) && __clang_major__ >= 15
592  #pragma clang loop vectorize(disable)
593 #endif
594  for(unsigned i = 0; i < offset; ++i)
595  {
596  thread_swap(kv..., dir, base + i, base + i + offset, compare_function);
597  }
598  }
599  }
600 
601  template <class BinaryFunction, unsigned int Items = ItemsPerThread, class... KeyValue>
602  ROCPRIM_DEVICE ROCPRIM_INLINE
603  typename std::enable_if<!(Items < 2)>::type
604  thread_merge(bool dir, BinaryFunction compare_function, KeyValue&... kv)
605  {
606  ROCPRIM_UNROLL
607  for(unsigned int k = ItemsPerThread / 2; k > 0; k /= 2)
608  {
609  thread_shuffle(k, dir, compare_function, kv...);
610  }
611  }
612 
615  template<unsigned int BS, class BinaryFunction, class... KeyValue>
616  ROCPRIM_DEVICE ROCPRIM_INLINE
617  typename std::enable_if<(BS > ::rocprim::device_warp_size())>::type
618  sort_power_two(const unsigned int flat_tid,
619  storage_type& storage,
620  BinaryFunction compare_function,
621  KeyValue&... kv)
622  {
623  const auto warp_id_is_even = ((flat_tid / ::rocprim::device_warp_size()) % 2) == 0;
624  ::rocprim::warp_sort<Key, ::rocprim::device_warp_size(), Value> wsort;
625  auto compare_function2 =
626  [compare_function, warp_id_is_even](const Key& a, const Key& b) mutable -> bool
627  {
628  auto r = compare_function(a, b);
629  if(warp_id_is_even)
630  return r;
631  return !r;
632  };
633  wsort.sort(kv..., compare_function2);
634 
635  ROCPRIM_UNROLL
636  for(unsigned int length = ::rocprim::device_warp_size(); length < BS; length *= 2)
637  {
638  const bool dir = (flat_tid & (length * 2)) != 0;
639  ROCPRIM_UNROLL
640  for(unsigned int k = length; k > ::rocprim::device_warp_size() / 2; k /= 2)
641  {
642  copy_to_shared(kv..., flat_tid, storage);
643  swap(kv..., flat_tid, flat_tid ^ k, dir, storage, compare_function);
645  }
646 
647  ROCPRIM_UNROLL
648  for(unsigned int k = ::rocprim::device_warp_size() / 2; k > 0; k /= 2)
649  {
650  const bool length_even = ((detail::logical_lane_id<::rocprim::device_warp_size()>() / k ) % 2 ) == 0;
651  const bool local_dir = length_even ? dir : !dir;
652  warp_swap(kv..., k, local_dir, compare_function);
653  }
654  thread_merge(dir, compare_function, kv...);
655  }
656  }
657 
658  template<unsigned int BS, unsigned int IPT, class BinaryFunction, class... KeyValue>
659  ROCPRIM_DEVICE ROCPRIM_INLINE
660  typename std::enable_if<is_power_of_two(BS) && is_power_of_two(IPT)>::type
661  sort_impl(const unsigned int flat_tid,
662  storage_type& storage,
663  BinaryFunction compare_function,
664  KeyValue&... kv)
665  {
666  static constexpr unsigned int PairSize = sizeof...(KeyValue);
667  static_assert(PairSize < 3,
668  "KeyValue parameter pack can be 1 or 2 elements (key, or key and value)");
669 
670  sort_power_two<BS>(flat_tid, storage, compare_function, kv...);
671  }
672 
675  template<bool SizeCheck, class BinaryFunction, class... KeyValue>
676  ROCPRIM_DEVICE ROCPRIM_INLINE void odd_even_sort(const unsigned int flat_tid,
677  const unsigned int size,
678  storage_type& storage,
679  BinaryFunction compare_function,
680  KeyValue&... kv)
681  {
682  static constexpr unsigned int PairSize = sizeof...(KeyValue);
683  static_assert(PairSize < 3,
684  "KeyValue parameter pack can be 1 or 2 elements (key, or key and value)");
685 
686  if(SizeCheck && size > ItemsPerBlock)
687  {
688  return;
689  }
690 
691  copy_to_shared(kv..., flat_tid, storage);
692 
693  for(unsigned int i = 0; i < size; i++)
694  {
695  bool is_even_iter = i % 2 == 0;
696  for(unsigned int item = 0; item < ItemsPerThread; ++item)
697  {
698  // the element in the original array that key[item] represents
699  unsigned int linear_id = flat_tid * ItemsPerThread + item;
700  bool is_even_lid = linear_id % 2 == 0;
701 
702  // one up/down from the linear_id
703  unsigned int odd_lid = is_even_lid ? ::rocprim::max(linear_id, 1u) - 1
704  : ::rocprim::min(linear_id + 1, size - 1);
705  unsigned int even_lid = is_even_lid ? ::rocprim::min(linear_id + 1, size - 1)
706  : ::rocprim::max(linear_id, 1u) - 1;
707 
708  // determine if the odd or even index must be used
709  unsigned int next_lid = is_even_iter ? even_lid : odd_lid;
710 
711  // map the linear_id back to item and thread id for indexing shared memory
712  unsigned int next_id = next_lid / ItemsPerThread;
713  unsigned int next_item_id = next_lid % ItemsPerThread;
714 
715  // prevent calling the compare function with out-of-bounds items
716  if(!SizeCheck || (linear_id < size && next_lid < size))
717  {
718  swap_oddeven(kv...,
719  next_id,
720  item,
721  next_item_id,
722  next_lid < linear_id,
723  storage,
724  compare_function);
725  }
726  }
728  copy_to_shared(kv..., flat_tid, storage);
729  }
730  }
731 
732  template<unsigned int BS, unsigned int IPT, class BinaryFunction, class... KeyValue>
733  ROCPRIM_DEVICE ROCPRIM_INLINE
734  typename std::enable_if<!is_power_of_two(BS) || !is_power_of_two(IPT)>::type
735  sort_impl(const unsigned int flat_tid,
736  storage_type& storage,
737  BinaryFunction compare_function,
738  KeyValue&... kv)
739  {
740  odd_even_sort<false>(flat_tid, ItemsPerBlock, storage, compare_function, kv...);
741  }
742 
743  template<class BinaryFunction, class... KeyValue>
744  ROCPRIM_DEVICE ROCPRIM_INLINE void sort_impl(const unsigned int flat_tid,
745  const unsigned int size,
746  storage_type& storage,
747  BinaryFunction compare_function,
748  KeyValue&... kv)
749  {
750  odd_even_sort<true>(flat_tid, size, storage, compare_function, kv...);
751  }
752 };
753 
754 } // end namespace detail
755 
756 END_ROCPRIM_NAMESPACE
757 
758 #endif // ROCPRIM_BLOCK_DETAIL_BLOCK_SORT_SHARED_HPP_
// Referenced symbols (documentation tooltips from the source listing):
//
// ROCPRIM_HOST_DEVICE constexpr T max(const T &a, const T &b)
//   Returns the maximum of its arguments.
//   Definition: functional.hpp:55
// ROCPRIM_DEVICE ROCPRIM_INLINE constexpr unsigned int device_warp_size()
//   Returns a number of threads in a hardware warp for the actual target.
//   Definition: thread.hpp:70
// Definition: block_sort_bitonic.hpp:47
// ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
//   Returns the minimum of its arguments.
//   Definition: functional.hpp:63
// Deprecated: Configuration of device-level scan primitives.
//   Definition: block_histogram.hpp:62
// ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
//   Synchronize all threads in a block (tile)
//   Definition: thread.hpp:216
// ROCPRIM_HOST_DEVICE void swap(T &a, T &b)
//   Swaps two values.
//   Definition: functional.hpp:71
// Definition: various.hpp:180
// ROCPRIM_DEVICE ROCPRIM_INLINE T warp_shuffle_xor(const T &input, const int lane_mask, const int width=device_warp_size())
//   Shuffle XOR for any data type.
//   Definition: warp_shuffle.hpp:246