rocPRIM
warp_sort_shuffle.hpp
1 // Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_WARP_DETAIL_WARP_SORT_SHUFFLE_HPP_
22 #define ROCPRIM_WARP_DETAIL_WARP_SORT_SHUFFLE_HPP_
23 
24 #include <type_traits>
25 
26 #include "../../config.hpp"
27 #include "../../detail/various.hpp"
28 
29 #include "../../intrinsics.hpp"
30 #include "../../functional.hpp"
31 
32 BEGIN_ROCPRIM_NAMESPACE
33 
34 namespace detail
35 {
36 
37 template<
38  class Key,
39  unsigned int WarpSize,
40  class Value
41 >
43 {
44 private:
45  template<int warp, class V, class BinaryFunction>
46  ROCPRIM_DEVICE ROCPRIM_INLINE
47  typename std::enable_if<!(WarpSize > warp)>::type
48  swap(Key& k, V& v, int mask, bool dir, BinaryFunction compare_function)
49  {
50  (void) k;
51  (void) v;
52  (void) mask;
53  (void) dir;
54  (void) compare_function;
55  }
56 
57  template<int warp, class V, class BinaryFunction>
58  ROCPRIM_DEVICE ROCPRIM_INLINE
59  typename std::enable_if<(WarpSize > warp)>::type
60  swap(Key& k, V& v, int mask, bool dir, BinaryFunction compare_function)
61  {
62  Key k1 = warp_shuffle_xor(k, mask, WarpSize);
63  //V v1 = warp_shuffle_xor(v, mask, WarpSize);
64  bool swap = compare_function(dir ? k : k1, dir ? k1 : k);
65  if (swap)
66  {
67  k = k1;
68  v = warp_shuffle_xor(v, mask, WarpSize);
69  }
70  }
71 
72  template<
73  int warp,
74  class V,
75  class BinaryFunction,
76  unsigned int ItemsPerThread
77  >
78  ROCPRIM_DEVICE ROCPRIM_INLINE
79  typename std::enable_if<!(WarpSize > warp)>::type
80  swap(Key (&k)[ItemsPerThread],
81  V (&v)[ItemsPerThread],
82  int mask,
83  bool dir,
84  BinaryFunction compare_function)
85  {
86  (void) k;
87  (void) v;
88  (void) mask;
89  (void) dir;
90  (void) compare_function;
91  }
92 
93  template<
94  int warp,
95  class V,
96  class BinaryFunction,
97  unsigned int ItemsPerThread
98  >
99  ROCPRIM_DEVICE ROCPRIM_INLINE
100  typename std::enable_if<(WarpSize > warp)>::type
101  swap(Key (&k)[ItemsPerThread],
102  V (&v)[ItemsPerThread],
103  int mask,
104  bool dir,
105  BinaryFunction compare_function)
106  {
107  Key k1[ItemsPerThread];
108  ROCPRIM_UNROLL
109  for (unsigned int item = 0; item < ItemsPerThread; item++)
110  {
111  k1[item]= warp_shuffle_xor(k[item], mask, WarpSize);
112  //V v1 = warp_shuffle_xor(v, mask, WarpSize);
113  bool swap = compare_function(dir ? k[item] : k1[item], dir ? k1[item] : k[item]);
114  if (swap)
115  {
116  k[item] = k1[item];
117  v[item] = warp_shuffle_xor(v[item], mask, WarpSize);
118  }
119  }
120  }
121 
122  template<int warp, class BinaryFunction>
123  ROCPRIM_DEVICE ROCPRIM_INLINE
124  typename std::enable_if<!(WarpSize > warp)>::type
125  swap(Key& k, int mask, bool dir, BinaryFunction compare_function)
126  {
127  (void) k;
128  (void) mask;
129  (void) dir;
130  (void) compare_function;
131  }
132 
133  template<int warp, class BinaryFunction>
134  ROCPRIM_DEVICE ROCPRIM_INLINE
135  typename std::enable_if<(WarpSize > warp)>::type
136  swap(Key& k, int mask, bool dir, BinaryFunction compare_function)
137  {
138  Key k1 = warp_shuffle_xor(k, mask, WarpSize);
139  bool swap = compare_function(dir ? k : k1, dir ? k1 : k);
140  if (swap)
141  {
142  k = k1;
143  }
144  }
145 
146  template<
147  int warp,
148  class BinaryFunction,
149  unsigned int ItemsPerThread
150  >
151  ROCPRIM_DEVICE ROCPRIM_INLINE
152  typename std::enable_if<!(WarpSize > warp)>::type
153  swap(Key (&k)[ItemsPerThread], int mask, bool dir, BinaryFunction compare_function)
154  {
155  (void) k;
156  (void) mask;
157  (void) dir;
158  (void) compare_function;
159  }
160 
161  template<
162  int warp,
163  class BinaryFunction,
164  unsigned int ItemsPerThread
165  >
166  ROCPRIM_DEVICE ROCPRIM_INLINE
167  typename std::enable_if<(WarpSize > warp)>::type
168  swap(Key (&k)[ItemsPerThread], int mask, bool dir, BinaryFunction compare_function)
169  {
170  Key k1[ItemsPerThread];
171  ROCPRIM_UNROLL
172  for (unsigned int item = 0; item < ItemsPerThread; item++)
173  {
174  k1[item]= warp_shuffle_xor(k[item], mask, WarpSize);
175  bool swap = compare_function(dir ? k[item] : k1[item], dir ? k1[item] : k[item]);
176  if (swap)
177  {
178  k[item] = k1[item];
179  }
180  }
181  }
182 
183  template <unsigned int ItemsPerThread, class BinaryFunction>
184  ROCPRIM_DEVICE ROCPRIM_INLINE
185  void thread_swap(Key (&k)[ItemsPerThread],
186  unsigned int i,
187  unsigned int j,
188  bool dir,
189  BinaryFunction compare_function)
190  {
191  if(compare_function(k[i], k[j]) == dir)
192  {
193  Key temp = k[i];
194  k[i] = k[j];
195  k[j] = temp;
196  }
197  }
198 
199  template <unsigned int ItemsPerThread, class V, class BinaryFunction>
200  ROCPRIM_DEVICE ROCPRIM_INLINE
201  void thread_swap(Key (&k)[ItemsPerThread],
202  V (&v)[ItemsPerThread],
203  unsigned int i,
204  unsigned int j,
205  bool dir,
206  BinaryFunction compare_function)
207  {
208  if(compare_function(k[i], k[j]) == dir)
209  {
210  Key k_temp = k[i];
211  k[i] = k[j];
212  k[j] = k_temp;
213  V v_temp = v[i];
214  v[i] = v[j];
215  v[j] = v_temp;
216  }
217  }
218 
219  template <unsigned int ItemsPerThread, class BinaryFunction, class... KeyValue>
220  ROCPRIM_DEVICE ROCPRIM_INLINE
221  void thread_shuffle(unsigned int group_size,
222  unsigned int offset,
223  bool dir,
224  BinaryFunction compare_function,
225  KeyValue&... kv)
226  {
227  ROCPRIM_UNROLL
228  for(unsigned int base = 0; base < ItemsPerThread; base += 2 * offset) {
229  // The local direction must change every group_size items
230  // and is flipped if dir is true
231  const bool local_dir = ((base & group_size) > 0) != dir;
232 
233  ROCPRIM_UNROLL
234 // Workaround to prevent the compiler thinking this is a 'Parallel Loop' on clang 15
235 // because it leads to invalid code generation with `T` = `char` and `ItemsPerthread` = 4
236 #if defined(__clang_major__) && __clang_major__ >= 15
237  #pragma clang loop vectorize(disable)
238 #endif
239  for(unsigned i = 0; i < offset; ++i) {
240  thread_swap(kv..., base + i, base + i + offset, local_dir, compare_function);
241  }
242  }
243  }
244 
245  template <unsigned int ItemsPerThread, class BinaryFunction, class... KeyValue>
246  ROCPRIM_DEVICE ROCPRIM_INLINE
247  void thread_sort(bool dir, BinaryFunction compare_function, KeyValue&... kv)
248  {
249  ROCPRIM_UNROLL
250  for(unsigned int k = 2; k <= ItemsPerThread; k *= 2)
251  {
252  ROCPRIM_UNROLL
253  for(unsigned int j = k / 2; j > 0; j /= 2)
254  {
255  thread_shuffle<ItemsPerThread>(k, j, dir, compare_function, kv...);
256  }
257  }
258  }
259 
260  template <int warp, unsigned int ItemsPerThread, class BinaryFunction, class... KeyValue>
261  ROCPRIM_DEVICE ROCPRIM_INLINE
262  typename std::enable_if<(WarpSize > warp)>::type
263  thread_merge(bool dir, BinaryFunction compare_function, KeyValue&... kv)
264  {
265  ROCPRIM_UNROLL
266  for(unsigned int j = ItemsPerThread / 2; j > 0; j /= 2)
267  {
268  thread_shuffle<ItemsPerThread>(ItemsPerThread, j, dir, compare_function, kv...);
269  }
270  }
271 
272  template <int warp, unsigned int ItemsPerThread, class BinaryFunction, class... KeyValue>
273  ROCPRIM_DEVICE ROCPRIM_INLINE
274  typename std::enable_if<!(WarpSize > warp)>::type
275  thread_merge(bool /*dir*/, BinaryFunction /*compare_function*/, KeyValue&... /*kv*/)
276  {
277  }
278 
279  template<class BinaryFunction, class... KeyValue>
280  ROCPRIM_DEVICE ROCPRIM_INLINE
281  void bitonic_sort(BinaryFunction compare_function, KeyValue&... kv)
282  {
283  static_assert(
284  sizeof...(KeyValue) < 3,
285  "KeyValue parameter pack can 1 or 2 elements (key, or key and value)"
286  );
287 
288  unsigned int id = detail::logical_lane_id<WarpSize>();
289  swap< 2>(kv..., 1, get_bit(id, 1) != get_bit(id, 0), compare_function);
290 
291  swap< 4>(kv..., 2, get_bit(id, 2) != get_bit(id, 1), compare_function);
292  swap< 4>(kv..., 1, get_bit(id, 2) != get_bit(id, 0), compare_function);
293 
294  swap< 8>(kv..., 4, get_bit(id, 3) != get_bit(id, 2), compare_function);
295  swap< 8>(kv..., 2, get_bit(id, 3) != get_bit(id, 1), compare_function);
296  swap< 8>(kv..., 1, get_bit(id, 3) != get_bit(id, 0), compare_function);
297 
298  swap<16>(kv..., 8, get_bit(id, 4) != get_bit(id, 3), compare_function);
299  swap<16>(kv..., 4, get_bit(id, 4) != get_bit(id, 2), compare_function);
300  swap<16>(kv..., 2, get_bit(id, 4) != get_bit(id, 1), compare_function);
301  swap<16>(kv..., 1, get_bit(id, 4) != get_bit(id, 0), compare_function);
302 
303  swap<32>(kv..., 16, get_bit(id, 5) != get_bit(id, 4), compare_function);
304  swap<32>(kv..., 8, get_bit(id, 5) != get_bit(id, 3), compare_function);
305  swap<32>(kv..., 4, get_bit(id, 5) != get_bit(id, 2), compare_function);
306  swap<32>(kv..., 2, get_bit(id, 5) != get_bit(id, 1), compare_function);
307  swap<32>(kv..., 1, get_bit(id, 5) != get_bit(id, 0), compare_function);
308 
309  swap<32>(kv..., 32, get_bit(id, 5) != 0, compare_function);
310  swap<16>(kv..., 16, get_bit(id, 4) != 0, compare_function);
311  swap< 8>(kv..., 8, get_bit(id, 3) != 0, compare_function);
312  swap< 4>(kv..., 4, get_bit(id, 2) != 0, compare_function);
313  swap< 2>(kv..., 2, get_bit(id, 1) != 0, compare_function);
314  swap< 0>(kv..., 1, get_bit(id, 0) != 0, compare_function);
315  }
316 
317  template<
318  unsigned int ItemsPerThread,
319  class BinaryFunction,
320  class... KeyValue
321  >
322  ROCPRIM_DEVICE ROCPRIM_INLINE
323  void bitonic_sort(BinaryFunction compare_function, KeyValue&... kv)
324  {
325  static_assert(
326  sizeof...(KeyValue) < 3,
327  "KeyValue parameter pack can 1 or 2 elements (key, or key and value)"
328  );
329 
330  static_assert(detail::is_power_of_two(ItemsPerThread), "ItemsPerThread must be power of 2");
331 
332  unsigned int id = detail::logical_lane_id<WarpSize>();
333  thread_sort<ItemsPerThread>(get_bit(id, 0) != 0, compare_function, kv...);
334 
335  swap< 2>(kv..., 1, get_bit(id, 1) != get_bit(id, 0), compare_function);
336  thread_merge<2, ItemsPerThread>(get_bit(id, 1) != 0, compare_function, kv...);
337 
338  swap< 4>(kv..., 2, get_bit(id, 2) != get_bit(id, 1), compare_function);
339  swap< 4>(kv..., 1, get_bit(id, 2) != get_bit(id, 0), compare_function);
340  thread_merge<4, ItemsPerThread>(get_bit(id, 2) != 0, compare_function, kv...);
341 
342  swap< 8>(kv..., 4, get_bit(id, 3) != get_bit(id, 2), compare_function);
343  swap< 8>(kv..., 2, get_bit(id, 3) != get_bit(id, 1), compare_function);
344  swap< 8>(kv..., 1, get_bit(id, 3) != get_bit(id, 0), compare_function);
345  thread_merge<8, ItemsPerThread>(get_bit(id, 3) != 0, compare_function, kv...);
346 
347  swap<16>(kv..., 8, get_bit(id, 4) != get_bit(id, 3), compare_function);
348  swap<16>(kv..., 4, get_bit(id, 4) != get_bit(id, 2), compare_function);
349  swap<16>(kv..., 2, get_bit(id, 4) != get_bit(id, 1), compare_function);
350  swap<16>(kv..., 1, get_bit(id, 4) != get_bit(id, 0), compare_function);
351  thread_merge<16, ItemsPerThread>(get_bit(id, 4) != 0, compare_function, kv...);
352 
353  swap<32>(kv..., 16, get_bit(id, 5) != get_bit(id, 4), compare_function);
354  swap<32>(kv..., 8, get_bit(id, 5) != get_bit(id, 3), compare_function);
355  swap<32>(kv..., 4, get_bit(id, 5) != get_bit(id, 2), compare_function);
356  swap<32>(kv..., 2, get_bit(id, 5) != get_bit(id, 1), compare_function);
357  swap<32>(kv..., 1, get_bit(id, 5) != get_bit(id, 0), compare_function);
358  thread_merge<32, ItemsPerThread>(get_bit(id, 5) != 0, compare_function, kv...);
359 
360  swap<32>(kv..., 32, get_bit(id, 5) != 0, compare_function);
361  swap<16>(kv..., 16, get_bit(id, 4) != 0, compare_function);
362  swap< 8>(kv..., 8, get_bit(id, 3) != 0, compare_function);
363  swap< 4>(kv..., 4, get_bit(id, 2) != 0, compare_function);
364  swap< 2>(kv..., 2, get_bit(id, 1) != 0, compare_function);
365  swap< 0>(kv..., 1, get_bit(id, 0) != 0, compare_function);
366  thread_merge<1, ItemsPerThread>(false, compare_function, kv...);
367  }
368 
369 public:
370  static_assert(detail::is_power_of_two(WarpSize), "WarpSize must be power of 2");
371 
372  using storage_type = ::rocprim::detail::empty_storage_type;
373 
374  template<class BinaryFunction>
375  ROCPRIM_DEVICE ROCPRIM_INLINE
376  void sort(Key& thread_value, BinaryFunction compare_function)
377  {
378  // sort by value only
379  bitonic_sort(compare_function, thread_value);
380  }
381 
382  template<class BinaryFunction>
383  ROCPRIM_DEVICE ROCPRIM_INLINE
384  void sort(Key& thread_value, storage_type& storage,
385  BinaryFunction compare_function)
386  {
387  (void) storage;
388  sort(thread_value, compare_function);
389  }
390 
391  template<
392  unsigned int ItemsPerThread,
393  class BinaryFunction
394  >
395  ROCPRIM_DEVICE ROCPRIM_INLINE
396  void sort(Key (&thread_values)[ItemsPerThread],
397  BinaryFunction compare_function)
398  {
399  // sort by value only
400  bitonic_sort<ItemsPerThread>(compare_function, thread_values);
401  }
402 
403  template<
404  unsigned int ItemsPerThread,
405  class BinaryFunction
406  >
407  ROCPRIM_DEVICE ROCPRIM_INLINE
408  void sort(Key (&thread_values)[ItemsPerThread],
409  storage_type& storage,
410  BinaryFunction compare_function)
411  {
412  (void) storage;
413  sort(thread_values, compare_function);
414  }
415 
416  template<class BinaryFunction, class V = Value>
417  ROCPRIM_DEVICE ROCPRIM_INLINE
418  typename std::enable_if<(sizeof(V) <= sizeof(int))>::type
419  sort(Key& thread_key, Value& thread_value,
420  BinaryFunction compare_function)
421  {
422  bitonic_sort(compare_function, thread_key, thread_value);
423  }
424 
425  template<class BinaryFunction, class V = Value>
426  ROCPRIM_DEVICE ROCPRIM_INLINE
427  typename std::enable_if<!(sizeof(V) <= sizeof(int))>::type
428  sort(Key& thread_key, Value& thread_value,
429  BinaryFunction compare_function)
430  {
431  // Instead of passing large values between lanes we pass indices and gather values after sorting.
432  unsigned int v = detail::logical_lane_id<WarpSize>();
433  bitonic_sort(compare_function, thread_key, v);
434  thread_value = warp_shuffle(thread_value, v, WarpSize);
435  }
436 
437  template<class BinaryFunction>
438  ROCPRIM_DEVICE ROCPRIM_INLINE
439  void sort(Key& thread_key, Value& thread_value,
440  storage_type& storage, BinaryFunction compare_function)
441  {
442  (void) storage;
443  sort(compare_function, thread_key, thread_value);
444  }
445 
446  template<
447  unsigned int ItemsPerThread,
448  class BinaryFunction,
449  class V = Value
450  >
451  ROCPRIM_DEVICE ROCPRIM_INLINE
452  typename std::enable_if<(sizeof(V) <= sizeof(int))>::type
453  sort(Key (&thread_keys)[ItemsPerThread],
454  Value (&thread_values)[ItemsPerThread],
455  BinaryFunction compare_function)
456  {
457  bitonic_sort<ItemsPerThread>(compare_function, thread_keys, thread_values);
458  }
459 
460  template<
461  unsigned int ItemsPerThread,
462  class BinaryFunction,
463  class V = Value
464  >
465  ROCPRIM_DEVICE ROCPRIM_INLINE
466  typename std::enable_if<!(sizeof(V) <= sizeof(int))>::type
467  sort(Key (&thread_keys)[ItemsPerThread],
468  Value (&thread_values)[ItemsPerThread],
469  BinaryFunction compare_function)
470  {
471  // Instead of passing large values between lanes we pass indices and gather values after sorting.
472  unsigned int v[ItemsPerThread];
473  ROCPRIM_UNROLL
474  for (unsigned int item = 0; item < ItemsPerThread; item++)
475  {
476  v[item] = ItemsPerThread * detail::logical_lane_id<WarpSize>() + item;
477  }
478 
479  bitonic_sort<ItemsPerThread>(compare_function, thread_keys, v);
480 
481  V copy[ItemsPerThread];
482  ROCPRIM_UNROLL
483  for(unsigned item = 0; item < ItemsPerThread; ++item) {
484  copy[item] = thread_values[item];
485  }
486 
487  ROCPRIM_UNROLL
488  for(unsigned int dst_item = 0; dst_item < ItemsPerThread; ++dst_item) {
489  ROCPRIM_UNROLL
490  for(unsigned src_item = 0; src_item < ItemsPerThread; ++src_item) {
491  V temp = warp_shuffle(copy[src_item], v[dst_item] / ItemsPerThread, WarpSize);
492  if(v[dst_item] % ItemsPerThread == src_item)
493  thread_values[dst_item] = temp;
494  }
495  }
496  }
497 
498  template<
499  unsigned int ItemsPerThread,
500  class BinaryFunction
501  >
502  ROCPRIM_DEVICE ROCPRIM_INLINE
503  void sort(Key (&thread_keys)[ItemsPerThread],
504  Value (&thread_values)[ItemsPerThread],
505  storage_type& storage, BinaryFunction compare_function)
506  {
507  (void) storage;
508  sort(thread_keys, thread_values, compare_function);
509  }
510 };
511 
512 } // end namespace detail
513 
514 END_ROCPRIM_NAMESPACE
515 
516 #endif // ROCPRIM_WARP_DETAIL_WARP_SORT_SHUFFLE_HPP_
Definition: warp_sort_shuffle.hpp:42
ROCPRIM_DEVICE ROCPRIM_INLINE T warp_shuffle(const T &input, const int src_lane, const int width=device_warp_size())
Shuffle for any data type.
Definition: warp_shuffle.hpp:172
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
ROCPRIM_DEVICE ROCPRIM_INLINE T warp_shuffle_xor(const T &input, const int lane_mask, const int width=device_warp_size())
Shuffle XOR for any data type.
Definition: warp_shuffle.hpp:246
ROCPRIM_DEVICE ROCPRIM_INLINE int get_bit(int x, int i)
Returns a single bit at 'i' from 'x'.
Definition: bit.hpp:33