rocPRIM
warp_sort.hpp
1 // Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_WARP_WARP_SORT_HPP_
22 #define ROCPRIM_WARP_WARP_SORT_HPP_
23 
24 #include <type_traits>
25 
26 #include "../config.hpp"
27 #include "../detail/various.hpp"
28 
29 #include "../intrinsics.hpp"
30 #include "../functional.hpp"
31 
32 #include "detail/warp_sort_shuffle.hpp"
33 
36 
37 BEGIN_ROCPRIM_NAMESPACE
38 
94 template<
95  class Key,
96  unsigned int WarpSize = device_warp_size(),
97  class Value = empty_type
98 >
99 class warp_sort : detail::warp_sort_shuffle<Key, WarpSize, Value>
100 {
102 
103  // Check if WarpSize is valid for the targets
104  static_assert(WarpSize <= ROCPRIM_MAX_WARP_SIZE, "WarpSize can't be greater than hardware warp size.");
105 
106 public:
115  typedef typename base_type::storage_type storage_type;
116 
127  template<class BinaryFunction = ::rocprim::less<Key>, unsigned int FunctionWarpSize = WarpSize>
128  ROCPRIM_DEVICE ROCPRIM_INLINE
129  auto sort(Key& thread_key,
130  BinaryFunction compare_function = BinaryFunction())
131  -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type
132  {
133  base_type::sort(thread_key, compare_function);
134  }
135 
138  template<class BinaryFunction = ::rocprim::less<Key>, unsigned int FunctionWarpSize = WarpSize>
139  ROCPRIM_DEVICE ROCPRIM_INLINE
140  auto sort(Key& ,
141  BinaryFunction compare_function = BinaryFunction())
142  -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type
143  {
144  (void) compare_function; // disables unused parameter warning
145  ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort.");
146  return;
147  }
148 
159  template<
160  unsigned int ItemsPerThread,
161  class BinaryFunction = ::rocprim::less<Key>,
162  unsigned int FunctionWarpSize = WarpSize
163  >
164  ROCPRIM_DEVICE ROCPRIM_INLINE
165  auto sort(Key (&thread_keys)[ItemsPerThread],
166  BinaryFunction compare_function = BinaryFunction())
167  -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type
168  {
169  base_type::sort(thread_keys, compare_function);
170  }
171 
174  template<
175  unsigned int ItemsPerThread,
176  class BinaryFunction = ::rocprim::less<Key>,
177  unsigned int FunctionWarpSize = WarpSize
178  >
179  ROCPRIM_DEVICE ROCPRIM_INLINE
180  auto sort(Key (&thread_keys)[ItemsPerThread],
181  BinaryFunction compare_function = BinaryFunction())
182  -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type
183  {
184  (void) thread_keys; // disables unused parameter warning
185  (void) compare_function; // disables unused parameter warning
186  ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort.");
187  return;
188  }
189 
218  template<class BinaryFunction = ::rocprim::less<Key>, unsigned int FunctionWarpSize = WarpSize>
219  ROCPRIM_DEVICE ROCPRIM_INLINE
220  auto sort(Key& thread_key,
221  storage_type& storage,
222  BinaryFunction compare_function = BinaryFunction())
223  -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type
224  {
225  base_type::sort(
226  thread_key, storage, compare_function
227  );
228  }
229 
232  template<class BinaryFunction = ::rocprim::less<Key>, unsigned int FunctionWarpSize = WarpSize>
233  ROCPRIM_DEVICE ROCPRIM_INLINE
234  auto sort(Key& ,
235  storage_type& ,
236  BinaryFunction compare_function = BinaryFunction())
237  -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type
238  {
239  (void) compare_function; // disables unused parameter warning
240  ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort.");
241  return;
242  }
243 
244 
273  template<
274  unsigned int ItemsPerThread,
275  class BinaryFunction = ::rocprim::less<Key>,
276  unsigned int FunctionWarpSize = WarpSize
277  >
278  ROCPRIM_DEVICE ROCPRIM_INLINE
279  auto sort(Key (&thread_keys)[ItemsPerThread],
280  storage_type& storage,
281  BinaryFunction compare_function = BinaryFunction())
282  -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type
283  {
284  base_type::sort(
285  thread_keys, storage, compare_function
286  );
287  }
288 
291  template<
292  unsigned int ItemsPerThread,
293  class BinaryFunction = ::rocprim::less<Key>,
294  unsigned int FunctionWarpSize = WarpSize
295  >
296  ROCPRIM_DEVICE ROCPRIM_INLINE
297  auto sort(Key (&thread_keys)[ItemsPerThread],
298  storage_type& ,
299  BinaryFunction compare_function = BinaryFunction())
300  -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type
301  {
302  (void) thread_keys; // disables unused parameter warning
303  (void) compare_function; // disables unused parameter warning
304  ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort.");
305  return;
306  }
307 
319  template<class BinaryFunction = ::rocprim::less<Key>, unsigned int FunctionWarpSize = WarpSize>
320  ROCPRIM_DEVICE ROCPRIM_INLINE
321  auto sort(Key& thread_key,
322  Value& thread_value,
323  BinaryFunction compare_function = BinaryFunction())
324  -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type
325  {
326  base_type::sort(
327  thread_key, thread_value, compare_function
328  );
329  }
330 
333  template<class BinaryFunction = ::rocprim::less<Key>, unsigned int FunctionWarpSize = WarpSize>
334  ROCPRIM_DEVICE ROCPRIM_INLINE
335  auto sort(Key& ,
336  Value& ,
337  BinaryFunction compare_function = BinaryFunction())
338  -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type
339  {
340  (void) compare_function; // disables unused parameter warning
341  ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort.");
342  return;
343  }
344 
356  template<
357  unsigned int ItemsPerThread,
358  class BinaryFunction = ::rocprim::less<Key>,
359  unsigned int FunctionWarpSize = WarpSize
360  >
361  ROCPRIM_DEVICE ROCPRIM_INLINE
362  auto sort(Key (&thread_keys)[ItemsPerThread],
363  Value (&thread_values)[ItemsPerThread],
364  BinaryFunction compare_function = BinaryFunction())
365  -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type
366  {
367  base_type::sort(
368  thread_keys, thread_values, compare_function
369  );
370  }
371 
374  template<
375  unsigned int ItemsPerThread,
376  class BinaryFunction = ::rocprim::less<Key>,
377  unsigned int FunctionWarpSize = WarpSize
378  >
379  ROCPRIM_DEVICE ROCPRIM_INLINE
380  auto sort(Key (&thread_keys)[ItemsPerThread],
381  Value (&thread_values)[ItemsPerThread],
382  BinaryFunction compare_function = BinaryFunction())
383  -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type
384  {
385  (void) thread_keys; // disables unused parameter warning
386  (void) thread_values; // disables unused parameter warning
387  (void) compare_function; // disables unused parameter warning
388  ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort.");
389  return;
390  }
391 
421  template<class BinaryFunction = ::rocprim::less<Key>, unsigned int FunctionWarpSize = WarpSize>
422  ROCPRIM_DEVICE ROCPRIM_INLINE
423  auto sort(Key& thread_key,
424  Value& thread_value,
425  storage_type& storage,
426  BinaryFunction compare_function = BinaryFunction())
427  -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type
428  {
429  base_type::sort(
430  thread_key, thread_value, storage, compare_function
431  );
432  }
433 
436  template<class BinaryFunction = ::rocprim::less<Key>, unsigned int FunctionWarpSize = WarpSize>
437  ROCPRIM_DEVICE ROCPRIM_INLINE
438  auto sort(Key& ,
439  Value& ,
440  storage_type& ,
441  BinaryFunction compare_function = BinaryFunction())
442  -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type
443  {
444  (void) compare_function; // disables unused parameter warning
445  ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort.");
446  return;
447  }
448 
449 
479  template<
480  unsigned int ItemsPerThread,
481  class BinaryFunction = ::rocprim::less<Key>,
482  unsigned int FunctionWarpSize = WarpSize
483  >
484  ROCPRIM_DEVICE ROCPRIM_INLINE
485  auto sort(Key (&thread_keys)[ItemsPerThread],
486  Value (&thread_values)[ItemsPerThread],
487  storage_type& storage,
488  BinaryFunction compare_function = BinaryFunction())
489  -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type
490  {
491  base_type::sort(
492  thread_keys, thread_values, storage, compare_function
493  );
494  }
495 
498  template<
499  unsigned int ItemsPerThread,
500  class BinaryFunction = ::rocprim::less<Key>,
501  unsigned int FunctionWarpSize = WarpSize
502  >
503  ROCPRIM_DEVICE ROCPRIM_INLINE
504  auto sort(Key (&thread_keys)[ItemsPerThread],
505  Value (&thread_values)[ItemsPerThread],
506  storage_type& ,
507  BinaryFunction compare_function = BinaryFunction())
508  -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type
509  {
510  (void) thread_keys; // disables unused parameter warning
511  (void) thread_values; // disables unused parameter warning
512  (void) compare_function; // disables unused parameter warning
513  ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. Aborting warp sort.");
514  return;
515  }
516 };
517 
518 END_ROCPRIM_NAMESPACE
519 
521 // end of group warpmodule
522 
523 #endif // ROCPRIM_WARP_WARP_SORT_HPP_
Empty type used as a placeholder, usually used to flag that given template parameter should not be used.
Definition: types.hpp:135
ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key &thread_key, Value &thread_value, storage_type &storage, BinaryFunction compare_function=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize<=__AMDGCN_WAVEFRONT_SIZE), void >::type
Warp sort by key for any data type using temporary storage.
Definition: warp_sort.hpp:423
Definition: warp_sort_shuffle.hpp:42
ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key &thread_key, Value &thread_value, BinaryFunction compare_function=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize<=__AMDGCN_WAVEFRONT_SIZE), void >::type
Warp sort by key for any data type.
Definition: warp_sort.hpp:321
ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key(&thread_keys)[ItemsPerThread], storage_type &, BinaryFunction compare_function=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void >::type
Warp sort for any data type using temporary storage.
Definition: warp_sort.hpp:297
ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key(&thread_keys)[ItemsPerThread], Value(&thread_values)[ItemsPerThread], storage_type &storage, BinaryFunction compare_function=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize<=__AMDGCN_WAVEFRONT_SIZE), void >::type
Warp sort by key for any data type using temporary storage.
Definition: warp_sort.hpp:485
ROCPRIM_DEVICE ROCPRIM_INLINE constexpr unsigned int device_warp_size()
Returns a number of threads in a hardware warp for the actual target.
Definition: thread.hpp:70
ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key &, storage_type &, BinaryFunction compare_function=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void >::type
Warp sort for any data type using temporary storage.
Definition: warp_sort.hpp:234
ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key(&thread_keys)[ItemsPerThread], Value(&thread_values)[ItemsPerThread], BinaryFunction compare_function=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize<=__AMDGCN_WAVEFRONT_SIZE), void >::type
Warp sort by key for any data type.
Definition: warp_sort.hpp:362
base_type::storage_type storage_type
Struct used to allocate a temporary memory that is required for thread communication during operations provided by related parallel primitive.
Definition: warp_sort.hpp:104
The warp_sort class provides warp-wide methods for computing a parallel sort of items across thread warps.
Definition: warp_sort.hpp:99
ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key(&thread_keys)[ItemsPerThread], Value(&thread_values)[ItemsPerThread], storage_type &, BinaryFunction compare_function=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void >::type
Warp sort by key for any data type using temporary storage.
Definition: warp_sort.hpp:504
ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key &, Value &, storage_type &, BinaryFunction compare_function=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void >::type
Warp sort by key for any data type using temporary storage.
Definition: warp_sort.hpp:438
ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key(&thread_keys)[ItemsPerThread], storage_type &storage, BinaryFunction compare_function=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize<=__AMDGCN_WAVEFRONT_SIZE), void >::type
Warp sort for any data type using temporary storage.
Definition: warp_sort.hpp:279
ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key(&thread_keys)[ItemsPerThread], Value(&thread_values)[ItemsPerThread], BinaryFunction compare_function=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void >::type
Warp sort by key for any data type.
Definition: warp_sort.hpp:380
ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key &thread_key, BinaryFunction compare_function=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize<=__AMDGCN_WAVEFRONT_SIZE), void >::type
Warp sort for any data type.
Definition: warp_sort.hpp:129
ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key &thread_key, storage_type &storage, BinaryFunction compare_function=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize<=__AMDGCN_WAVEFRONT_SIZE), void >::type
Warp sort for any data type using temporary storage.
Definition: warp_sort.hpp:220
ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key(&thread_keys)[ItemsPerThread], BinaryFunction compare_function=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize<=__AMDGCN_WAVEFRONT_SIZE), void >::type
Warp sort for any data type.
Definition: warp_sort.hpp:165
#define ROCPRIM_PRINT_ERROR_ONCE(message)
Prints the supplied error message only once (using only one of the active threads).
Definition: functional.hpp:42
ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key(&thread_keys)[ItemsPerThread], BinaryFunction compare_function=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void >::type
Warp sort for any data type.
Definition: warp_sort.hpp:180
ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key &, Value &, BinaryFunction compare_function=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void >::type
Warp sort by key for any data type.
Definition: warp_sort.hpp:335
ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key &, BinaryFunction compare_function=BinaryFunction()) -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void >::type
Warp sort for any data type.
Definition: warp_sort.hpp:140