rocPRIM
device_segmented_radix_sort.hpp
1 // Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_RADIX_SORT_HPP_
22 #define ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_RADIX_SORT_HPP_
23 
24 #include <type_traits>
25 #include <iterator>
26 
27 #include "../../config.hpp"
28 #include "../../detail/various.hpp"
29 
30 #include "../../intrinsics.hpp"
31 #include "../../functional.hpp"
32 #include "../../types.hpp"
33 
34 #include "../../block/block_load.hpp"
35 #include "../../block/block_store.hpp"
36 #include "../../block/block_scan.hpp"
37 
38 #include "../../warp/warp_load.hpp"
39 #include "../../warp/warp_sort.hpp"
40 #include "../../warp/warp_store.hpp"
41 
42 #include "../device_segmented_radix_sort_config.hpp"
43 #include "device_radix_sort.hpp"
44 
45 BEGIN_ROCPRIM_NAMESPACE
46 
47 namespace detail
48 {
49 
50 template<
51  class Key,
52  class Value,
53  unsigned int WarpSize,
54  unsigned int BlockSize,
55  unsigned int ItemsPerThread,
56  unsigned int RadixBits,
57  bool Descending
58 >
60 {
61  static constexpr unsigned int radix_size = 1 << RadixBits;
62 
63  using key_type = Key;
64  using value_type = Value;
65 
67  using scan_type = typename ::rocprim::block_scan<unsigned int, radix_size>;
69  BlockSize, ItemsPerThread, RadixBits, Descending,
70  key_type, value_type, unsigned int>;
71 
72 public:
73 
75  {
78  };
79 
80  template<
81  class KeysInputIterator,
82  class KeysOutputIterator,
83  class ValuesInputIterator,
84  class ValuesOutputIterator
85  >
86  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
87  void sort(KeysInputIterator keys_input,
88  key_type * keys_tmp,
89  KeysOutputIterator keys_output,
90  ValuesInputIterator values_input,
91  value_type * values_tmp,
92  ValuesOutputIterator values_output,
93  bool to_output,
94  unsigned int begin_offset,
95  unsigned int end_offset,
96  unsigned int bit,
97  unsigned int begin_bit,
98  unsigned int end_bit,
99  storage_type& storage)
100  {
101  // Handle cases when (end_bit - bit) is not divisible by radix_bits, i.e. the last
102  // iteration has a shorter mask.
103  const unsigned int current_radix_bits = ::rocprim::min(RadixBits, end_bit - bit);
104 
105  const bool is_first_iteration = (bit == begin_bit);
106 
107  if(is_first_iteration)
108  {
109  if(to_output)
110  {
111  sort(
112  keys_input, keys_output, values_input, values_output,
113  begin_offset, end_offset,
114  bit, current_radix_bits,
115  storage
116  );
117  }
118  else
119  {
120  sort(
121  keys_input, keys_tmp, values_input, values_tmp,
122  begin_offset, end_offset,
123  bit, current_radix_bits,
124  storage
125  );
126  }
127  }
128  else
129  {
130  if(to_output)
131  {
132  sort(
133  keys_tmp, keys_output, values_tmp, values_output,
134  begin_offset, end_offset,
135  bit, current_radix_bits,
136  storage
137  );
138  }
139  else
140  {
141  sort(
142  keys_output, keys_tmp, values_output, values_tmp,
143  begin_offset, end_offset,
144  bit, current_radix_bits,
145  storage
146  );
147  }
148  }
149  }
150 
151  // When all iterators are raw pointers, this overload is used to minimize code duplication in the kernel
152  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
153  void sort(key_type * keys_input,
154  key_type * keys_tmp,
155  key_type * keys_output,
156  value_type * values_input,
157  value_type * values_tmp,
158  value_type * values_output,
159  bool to_output,
160  unsigned int begin_offset,
161  unsigned int end_offset,
162  unsigned int bit,
163  unsigned int begin_bit,
164  unsigned int end_bit,
165  storage_type& storage)
166  {
167  // Handle cases when (end_bit - bit) is not divisible by radix_bits, i.e. the last
168  // iteration has a shorter mask.
169  const unsigned int current_radix_bits = ::rocprim::min(RadixBits, end_bit - bit);
170 
171  const bool is_first_iteration = (bit == begin_bit);
172 
173  key_type * current_keys_input;
174  key_type * current_keys_output;
175  value_type * current_values_input;
176  value_type * current_values_output;
177  if(is_first_iteration)
178  {
179  if(to_output)
180  {
181  current_keys_input = keys_input;
182  current_keys_output = keys_output;
183  current_values_input = values_input;
184  current_values_output = values_output;
185  }
186  else
187  {
188  current_keys_input = keys_input;
189  current_keys_output = keys_tmp;
190  current_values_input = values_input;
191  current_values_output = values_tmp;
192  }
193  }
194  else
195  {
196  if(to_output)
197  {
198  current_keys_input = keys_tmp;
199  current_keys_output = keys_output;
200  current_values_input = values_tmp;
201  current_values_output = values_output;
202  }
203  else
204  {
205  current_keys_input = keys_output;
206  current_keys_output = keys_tmp;
207  current_values_input = values_output;
208  current_values_output = values_tmp;
209  }
210  }
211  sort(
212  current_keys_input, current_keys_output, current_values_input, current_values_output,
213  begin_offset, end_offset,
214  bit, current_radix_bits,
215  storage
216  );
217  }
218 
219 private:
220 
221  template<
222  class KeysInputIterator,
223  class KeysOutputIterator,
224  class ValuesInputIterator,
225  class ValuesOutputIterator
226  >
227  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
228  void sort(KeysInputIterator keys_input,
229  KeysOutputIterator keys_output,
230  ValuesInputIterator values_input,
231  ValuesOutputIterator values_output,
232  unsigned int begin_offset,
233  unsigned int end_offset,
234  unsigned int bit,
235  unsigned int current_radix_bits,
236  storage_type& storage)
237  {
238  unsigned int digit_count;
239  count_helper_type().count_digits(
240  keys_input,
241  begin_offset, end_offset,
242  bit, current_radix_bits,
243  storage.count_helper,
244  digit_count
245  );
246 
247  unsigned int digit_start;
248  scan_type().exclusive_scan(digit_count, digit_start, 0);
249  digit_start += begin_offset;
250 
252 
253  sort_and_scatter_helper().sort_and_scatter(
254  keys_input, keys_output, values_input, values_output,
255  begin_offset, end_offset,
256  bit, current_radix_bits,
257  digit_start,
258  storage.sort_and_scatter_helper
259  );
260 
262  }
263 };
264 
265 template<
266  class Key,
267  class Value,
268  unsigned int BlockSize,
269  unsigned int ItemsPerThread,
270  bool Descending
271 >
273 {
274  using key_type = Key;
275  using value_type = Value;
276 
278  using bit_key_type = typename key_codec::bit_key_type;
279  using keys_load_type = ::rocprim::block_load<
280  key_type, BlockSize, ItemsPerThread,
281  ::rocprim::block_load_method::block_load_transpose>;
282  using values_load_type = ::rocprim::block_load<
283  value_type, BlockSize, ItemsPerThread,
284  ::rocprim::block_load_method::block_load_transpose>;
285  using sort_type = ::rocprim::block_radix_sort<key_type, BlockSize, ItemsPerThread, value_type>;
286  using keys_store_type = ::rocprim::block_store<
287  key_type, BlockSize, ItemsPerThread,
288  ::rocprim::block_store_method::block_store_transpose>;
289  using values_store_type = ::rocprim::block_store<
290  value_type, BlockSize, ItemsPerThread,
291  ::rocprim::block_store_method::block_store_transpose>;
292 
293  static constexpr bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;
294 
295 public:
296 
298  {
299  typename keys_load_type::storage_type keys_load;
300  typename values_load_type::storage_type values_load;
301  typename sort_type::storage_type sort;
302  typename keys_store_type::storage_type keys_store;
303  typename values_store_type::storage_type values_store;
304  };
305 
306  template<
307  class KeysInputIterator,
308  class KeysOutputIterator,
309  class ValuesInputIterator,
310  class ValuesOutputIterator
311  >
312  ROCPRIM_DEVICE ROCPRIM_INLINE
313  void sort(KeysInputIterator keys_input,
314  key_type * keys_tmp,
315  KeysOutputIterator keys_output,
316  ValuesInputIterator values_input,
317  value_type * values_tmp,
318  ValuesOutputIterator values_output,
319  bool to_output,
320  unsigned int begin_offset,
321  unsigned int end_offset,
322  unsigned int begin_bit,
323  unsigned int end_bit,
324  storage_type& storage)
325  {
326  if(to_output)
327  {
328  sort(
329  keys_input, keys_output, values_input, values_output,
330  begin_offset, end_offset,
331  begin_bit, end_bit,
332  storage
333  );
334  }
335  else
336  {
337  sort(
338  keys_input, keys_tmp, values_input, values_tmp,
339  begin_offset, end_offset,
340  begin_bit, end_bit,
341  storage
342  );
343  }
344  }
345 
346  // When all iterators are raw pointers, this overload is used to minimize code duplication in the kernel
347  ROCPRIM_DEVICE ROCPRIM_INLINE
348  void sort(key_type * keys_input,
349  key_type * keys_tmp,
350  key_type * keys_output,
351  value_type * values_input,
352  value_type * values_tmp,
353  value_type * values_output,
354  bool to_output,
355  unsigned int begin_offset,
356  unsigned int end_offset,
357  unsigned int begin_bit,
358  unsigned int end_bit,
359  storage_type& storage)
360  {
361  sort(
362  keys_input, (to_output ? keys_output : keys_tmp), values_input, (to_output ? values_output : values_tmp),
363  begin_offset, end_offset,
364  begin_bit, end_bit,
365  storage
366  );
367  }
368 
369  template<
370  class KeysInputIterator,
371  class KeysOutputIterator,
372  class ValuesInputIterator,
373  class ValuesOutputIterator
374  >
375  ROCPRIM_DEVICE ROCPRIM_INLINE
376  bool sort(KeysInputIterator keys_input,
377  KeysOutputIterator keys_output,
378  ValuesInputIterator values_input,
379  ValuesOutputIterator values_output,
380  unsigned int begin_offset,
381  unsigned int end_offset,
382  unsigned int begin_bit,
383  unsigned int end_bit,
384  storage_type& storage)
385  {
386  constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
387 
388  using shorter_single_block_helper = segmented_radix_sort_single_block_helper<
389  key_type, value_type,
390  BlockSize, ItemsPerThread / 2, Descending
391  >;
392 
393  // Segment is longer than supported by this function
394  if(end_offset - begin_offset > items_per_block)
395  {
396  return false;
397  }
398 
399  // Recursively chech if it is possible to sort the segment using fewer items per thread
400  const bool processed_by_shorter =
401  shorter_single_block_helper().sort(
402  keys_input, keys_output, values_input, values_output,
403  begin_offset, end_offset,
404  begin_bit, end_bit,
405  reinterpret_cast<typename shorter_single_block_helper::storage_type&>(storage)
406  );
407  if(processed_by_shorter)
408  {
409  return true;
410  }
411 
412  key_type keys[ItemsPerThread];
413  value_type values[ItemsPerThread];
414  const unsigned int valid_count = end_offset - begin_offset;
415  // Sort will leave "invalid" (out of size) items at the end of the sorted sequence
416  const key_type out_of_bounds = key_codec::decode(bit_key_type(-1));
417  keys_load_type().load(keys_input + begin_offset, keys, valid_count, out_of_bounds, storage.keys_load);
418  if(with_values)
419  {
421  values_load_type().load(values_input + begin_offset, values, valid_count, storage.values_load);
422  }
423 
425  sort_block<Descending>(sort_type(), keys, values, storage.sort, begin_bit, end_bit);
426 
428  keys_store_type().store(keys_output + begin_offset, keys, valid_count, storage.keys_store);
429  if(with_values)
430  {
432  values_store_type().store(values_output + begin_offset, values, valid_count, storage.values_store);
433  }
434 
435  return true;
436  }
437 };
438 
439 template<
440  class Key,
441  class Value,
442  unsigned int BlockSize,
443  bool Descending
444 >
445 class segmented_radix_sort_single_block_helper<Key, Value, BlockSize, 0, Descending>
446 {
447 public:
448 
449  struct storage_type { };
450 
451  template<
452  class KeysInputIterator,
453  class KeysOutputIterator,
454  class ValuesInputIterator,
455  class ValuesOutputIterator
456  >
457  ROCPRIM_DEVICE ROCPRIM_INLINE
458  bool sort(KeysInputIterator,
459  KeysOutputIterator,
460  ValuesInputIterator,
461  ValuesOutputIterator,
462  unsigned int,
463  unsigned int,
464  unsigned int,
465  unsigned int,
466  storage_type&)
467  {
468  // It can't sort anything because ItemsPerThread is 0.
469  // The segment will be sorted by the calles (i.e. using ItemsPerThread = 1)
470  return false;
471  }
472 };
473 
474 template<unsigned int LogicalWarpSize, unsigned int ItemsPerThread, unsigned int BlockSize>
476 {
477  static constexpr unsigned int logical_warp_size = LogicalWarpSize;
478  static constexpr unsigned int items_per_thread = ItemsPerThread;
479  static constexpr unsigned int block_size = BlockSize;
480 };
481 
483 {
484  static constexpr unsigned int logical_warp_size = 1;
485  static constexpr unsigned int items_per_thread = 1;
486  static constexpr unsigned int block_size = 1;
487 };
488 
489 template<class Config>
490 using select_warp_sort_helper_config_small_t
491  = std::conditional_t<std::is_same<DisabledWarpSortConfig, Config>::value,
493  WarpSortHelperConfig<Config::logical_warp_size_small,
494  Config::items_per_thread_small,
495  Config::block_size_small>>;
496 
497 template<class Config>
498 using select_warp_sort_helper_config_medium_t
499  = std::conditional_t<std::is_same<DisabledWarpSortConfig, Config>::value,
500  DisabledWarpSortHelperConfig,
501  WarpSortHelperConfig<Config::logical_warp_size_medium,
502  Config::items_per_thread_medium,
503  Config::block_size_medium>>;
504 
505 template<
506  class Config,
507  class Key,
508  class Value,
509  bool Descending,
510  class Enable = void
511 >
513 {
514  static constexpr unsigned int items_per_warp = 0;
515  using storage_type = ::rocprim::empty_type;
516 
517  template<class... Args>
518  ROCPRIM_DEVICE ROCPRIM_INLINE
519  void sort(Args&&...)
520  {
521  }
522 };
523 
524 template<class Config, class Key, class Value, bool Descending>
526  Config,
527  Key,
528  Value,
529  Descending,
530  std::enable_if_t<!std::is_same<DisabledWarpSortHelperConfig, Config>::value>>
531 {
532  static constexpr unsigned int logical_warp_size = Config::logical_warp_size;
533  static constexpr unsigned int items_per_thread = Config::items_per_thread;
534 
535  using key_type = Key;
536  using value_type = Value;
537  using key_codec = ::rocprim::detail::radix_key_codec<key_type, Descending>;
538  using bit_key_type = typename key_codec::bit_key_type;
539 
540  using keys_load_type = ::rocprim::warp_load<key_type, items_per_thread, logical_warp_size, ::rocprim::warp_load_method::warp_load_striped>;
541  using values_load_type = ::rocprim::warp_load<value_type, items_per_thread, logical_warp_size, ::rocprim::warp_load_method::warp_load_striped>;
542  using keys_store_type = ::rocprim::warp_store<key_type, items_per_thread, logical_warp_size>;
543  using values_store_type = ::rocprim::warp_store<value_type, items_per_thread, logical_warp_size>;
544  template<bool UseRadixMask>
545  using radix_comparator_type = ::rocprim::detail::radix_merge_compare<Descending, UseRadixMask, key_type>;
546  using stable_key_type = ::rocprim::tuple<key_type, unsigned int>;
547  using sort_type = ::rocprim::warp_sort<stable_key_type, logical_warp_size, value_type>;
548 
549  static constexpr bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;
550 
551  template<class ComparatorT>
552  ROCPRIM_DEVICE ROCPRIM_INLINE
553  decltype(auto) make_stable_comparator(ComparatorT comparator)
554  {
555  return [comparator](const stable_key_type& a, const stable_key_type& b) -> bool
556  {
557  const bool ab = comparator(rocprim::get<0>(a), rocprim::get<0>(b));
558  const bool ba = comparator(rocprim::get<0>(b), rocprim::get<0>(a));
559  return ab || (!ba && (rocprim::get<1>(a) < rocprim::get<1>(b)));
560  };
561  }
562 
563 public:
564  static constexpr unsigned int items_per_warp = items_per_thread * logical_warp_size;
565 
566  union storage_type
567  {
568  typename keys_load_type::storage_type keys_load;
569  typename values_load_type::storage_type values_load;
570  typename keys_store_type::storage_type keys_store;
571  typename values_store_type::storage_type values_store;
572  typename sort_type::storage_type sort;
573  };
574 
575  template<
576  class KeysInputIterator,
577  class KeysOutputIterator,
578  class ValuesInputIterator,
579  class ValuesOutputIterator
580  >
581  ROCPRIM_DEVICE ROCPRIM_INLINE
582  void sort(KeysInputIterator keys_input,
583  KeysOutputIterator keys_output,
584  ValuesInputIterator values_input,
585  ValuesOutputIterator values_output,
586  unsigned int begin_offset,
587  unsigned int end_offset,
588  unsigned int begin_bit,
589  unsigned int end_bit,
590  storage_type& storage)
591  {
592  const unsigned int num_items = end_offset - begin_offset;
593  const key_type out_of_bounds = key_codec::decode(bit_key_type(-1));
594 
595  key_type keys[items_per_thread];
596  stable_key_type stable_keys[items_per_thread];
597  value_type values[items_per_thread];
598  keys_load_type().load(keys_input + begin_offset, keys, num_items, out_of_bounds, storage.keys_load);
599 
600  ROCPRIM_UNROLL
601  for(unsigned int i = 0; i < items_per_thread; i++)
602  {
603  ::rocprim::get<0>(stable_keys[i]) = keys[i];
604  ::rocprim::get<1>(stable_keys[i]) =
605  ::rocprim::detail::logical_lane_id<logical_warp_size>() + logical_warp_size * i;
606  }
607 
608  if(with_values)
609  {
611  values_load_type().load(values_input + begin_offset, values, num_items, storage.values_load);
612  }
613 
615  if(begin_bit == 0 && end_bit == 8 * sizeof(key_type))
616  {
617  sort_type().sort(stable_keys,
618  values,
619  storage.sort,
620  make_stable_comparator(radix_comparator_type<false>{}));
621  }
622  else
623  {
624  radix_comparator_type<true> comparator(begin_bit, end_bit - begin_bit);
625  sort_type().sort(stable_keys, values, storage.sort, make_stable_comparator(comparator));
626  }
627 
628  ROCPRIM_UNROLL
629  for(unsigned int i = 0; i < items_per_thread; i++)
630  {
631  keys[i] = ::rocprim::get<0>(stable_keys[i]);
632  }
634  keys_store_type().store(keys_output + begin_offset, keys, num_items, storage.keys_store);
635 
636  if(with_values)
637  {
639  values_store_type().store(values_output + begin_offset, values, num_items, storage.values_store);
640  }
641  }
642 
643  template<
644  class KeysInputIterator,
645  class KeysOutputIterator,
646  class ValuesInputIterator,
647  class ValuesOutputIterator
648  >
649  ROCPRIM_DEVICE ROCPRIM_INLINE
650  void sort(KeysInputIterator keys_input,
651  key_type * keys_tmp,
652  KeysOutputIterator keys_output,
653  ValuesInputIterator values_input,
654  value_type * values_tmp,
655  ValuesOutputIterator values_output,
656  bool to_output,
657  unsigned int begin_offset,
658  unsigned int end_offset,
659  unsigned int begin_bit,
660  unsigned int end_bit,
661  storage_type& storage)
662  {
663  if(to_output)
664  {
665  sort(
666  keys_input, keys_output, values_input, values_output,
667  begin_offset, end_offset,
668  begin_bit, end_bit,
669  storage
670  );
671  }
672  else
673  {
674  sort(
675  keys_input, keys_tmp, values_input, values_tmp,
676  begin_offset, end_offset,
677  begin_bit, end_bit,
678  storage
679  );
680  }
681  }
682 };
683 
684 template<
685  class Config,
686  bool Descending,
687  class KeysInputIterator,
688  class KeysOutputIterator,
689  class ValuesInputIterator,
690  class ValuesOutputIterator,
691  class OffsetIterator
692 >
693 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
694 void segmented_sort(KeysInputIterator keys_input,
695  typename std::iterator_traits<KeysInputIterator>::value_type * keys_tmp,
696  KeysOutputIterator keys_output,
697  ValuesInputIterator values_input,
698  typename std::iterator_traits<ValuesInputIterator>::value_type * values_tmp,
699  ValuesOutputIterator values_output,
700  bool to_output,
701  OffsetIterator begin_offsets,
702  OffsetIterator end_offsets,
703  unsigned int long_iterations,
704  unsigned int short_iterations,
705  unsigned int begin_bit,
706  unsigned int end_bit)
707 {
708  constexpr unsigned int long_radix_bits = Config::long_radix_bits;
709  constexpr unsigned int short_radix_bits = Config::short_radix_bits;
710  constexpr unsigned int block_size = Config::sort::block_size;
711  constexpr unsigned int items_per_thread = Config::sort::items_per_thread;
712  constexpr unsigned int items_per_block = block_size * items_per_thread;
713  constexpr bool warp_sort_enabled = Config::warp_sort_config::enable_unpartitioned_warp_sort;
714 
715  using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
716  using value_type = typename std::iterator_traits<ValuesInputIterator>::value_type;
717 
718  using single_block_helper_type = segmented_radix_sort_single_block_helper<
719  key_type, value_type,
720  block_size, items_per_thread,
721  Descending
722  >;
723  using long_radix_helper_type = segmented_radix_sort_helper<
724  key_type, value_type,
725  ::rocprim::device_warp_size(), block_size, items_per_thread,
726  long_radix_bits, Descending
727  >;
728  using short_radix_helper_type = segmented_radix_sort_helper<
729  key_type, value_type,
730  ::rocprim::device_warp_size(), block_size, items_per_thread,
731  short_radix_bits, Descending
732  >;
733  using warp_sort_helper_type = segmented_warp_sort_helper<
734  select_warp_sort_helper_config_small_t<typename Config::warp_sort_config>,
735  key_type,
736  value_type,
737  Descending>;
738  static constexpr unsigned int items_per_warp = warp_sort_helper_type::items_per_warp;
739 
740  ROCPRIM_SHARED_MEMORY union
741  {
742  typename single_block_helper_type::storage_type single_block_helper;
743  typename long_radix_helper_type::storage_type long_radix_helper;
744  typename short_radix_helper_type::storage_type short_radix_helper;
745  typename warp_sort_helper_type::storage_type warp_sort_helper;
746  } storage;
747 
748  const unsigned int segment_id = ::rocprim::detail::block_id<0>();
749 
750  const unsigned int begin_offset = begin_offsets[segment_id];
751  const unsigned int end_offset = end_offsets[segment_id];
752 
753  // Empty segment
754  if(end_offset <= begin_offset)
755  {
756  return;
757  }
758 
759  if(end_offset - begin_offset > items_per_block)
760  {
761  // Large segment
762  unsigned int bit = begin_bit;
763  for(unsigned int i = 0; i < long_iterations; i++)
764  {
765  long_radix_helper_type().sort(
766  keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output,
767  to_output,
768  begin_offset, end_offset,
769  bit, begin_bit, end_bit,
770  storage.long_radix_helper
771  );
772 
773  to_output = !to_output;
774  bit += long_radix_bits;
775  }
776  for(unsigned int i = 0; i < short_iterations; i++)
777  {
778  short_radix_helper_type().sort(
779  keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output,
780  to_output,
781  begin_offset, end_offset,
782  bit, begin_bit, end_bit,
783  storage.short_radix_helper
784  );
785 
786  to_output = !to_output;
787  bit += short_radix_bits;
788  }
789  }
790  else if(!warp_sort_enabled || end_offset - begin_offset > items_per_warp)
791  {
792  // Small segment
793  single_block_helper_type().sort(
794  keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output,
795  ((long_iterations + short_iterations) % 2 == 0) != to_output,
796  begin_offset, end_offset,
797  begin_bit, end_bit,
798  storage.single_block_helper
799  );
800  }
801  else if(::rocprim::flat_block_thread_id() < Config::warp_sort_config::logical_warp_size_small)
802  {
803  // Single warp segment
804  warp_sort_helper_type().sort(
805  keys_input, keys_tmp, keys_output,
806  values_input, values_tmp, values_output,
807  ((long_iterations + short_iterations) % 2 == 0) != to_output,
808  begin_offset, end_offset,
809  begin_bit, end_bit, storage.warp_sort_helper
810  );
811  }
812 }
813 
814 template<
815  class Config,
816  bool Descending,
817  class KeysInputIterator,
818  class KeysOutputIterator,
819  class ValuesInputIterator,
820  class ValuesOutputIterator,
821  class SegmentIndexIterator,
822  class OffsetIterator
823 >
824 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
825 void segmented_sort_large(KeysInputIterator keys_input,
826  typename std::iterator_traits<KeysInputIterator>::value_type * keys_tmp,
827  KeysOutputIterator keys_output,
828  ValuesInputIterator values_input,
829  typename std::iterator_traits<ValuesInputIterator>::value_type * values_tmp,
830  ValuesOutputIterator values_output,
831  bool to_output,
832  SegmentIndexIterator segment_indices,
833  OffsetIterator begin_offsets,
834  OffsetIterator end_offsets,
835  unsigned int long_iterations,
836  unsigned int short_iterations,
837  unsigned int begin_bit,
838  unsigned int end_bit)
839 {
840  constexpr unsigned int long_radix_bits = Config::long_radix_bits;
841  constexpr unsigned int short_radix_bits = Config::short_radix_bits;
842  constexpr unsigned int block_size = Config::sort::block_size;
843  constexpr unsigned int items_per_thread = Config::sort::items_per_thread;
844  constexpr unsigned int items_per_block = block_size * items_per_thread;
845 
846  using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
847  using value_type = typename std::iterator_traits<ValuesInputIterator>::value_type;
848 
849  using single_block_helper_type = segmented_radix_sort_single_block_helper<
850  key_type, value_type,
851  block_size, items_per_thread,
852  Descending
853  >;
854  using long_radix_helper_type = segmented_radix_sort_helper<
855  key_type, value_type,
856  ::rocprim::device_warp_size(), block_size, items_per_thread,
857  long_radix_bits, Descending
858  >;
859  using short_radix_helper_type = segmented_radix_sort_helper<
860  key_type, value_type,
861  ::rocprim::device_warp_size(), block_size, items_per_thread,
862  short_radix_bits, Descending
863  >;
864 
865  ROCPRIM_SHARED_MEMORY union
866  {
867  typename single_block_helper_type::storage_type single_block_helper;
868  typename long_radix_helper_type::storage_type long_radix_helper;
869  typename short_radix_helper_type::storage_type short_radix_helper;
870  } storage;
871 
872  const unsigned int block_id = ::rocprim::detail::block_id<0>();
873  const unsigned int segment_id = segment_indices[block_id];
874  const unsigned int begin_offset = begin_offsets[segment_id];
875  const unsigned int end_offset = end_offsets[segment_id];
876 
877  if(end_offset <= begin_offset)
878  {
879  return;
880  }
881 
882  if(end_offset - begin_offset > items_per_block)
883  {
884  unsigned int bit = begin_bit;
885  for(unsigned int i = 0; i < long_iterations; i++)
886  {
887  long_radix_helper_type().sort(
888  keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output,
889  to_output,
890  begin_offset, end_offset,
891  bit, begin_bit, end_bit,
892  storage.long_radix_helper
893  );
894 
895  to_output = !to_output;
896  bit += long_radix_bits;
897  }
898  for(unsigned int i = 0; i < short_iterations; i++)
899  {
900  short_radix_helper_type().sort(
901  keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output,
902  to_output,
903  begin_offset, end_offset,
904  bit, begin_bit, end_bit,
905  storage.short_radix_helper
906  );
907 
908  to_output = !to_output;
909  bit += short_radix_bits;
910  }
911  }
912  else
913  {
914  single_block_helper_type().sort(
915  keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output,
916  ((long_iterations + short_iterations) % 2 == 0) != to_output,
917  begin_offset, end_offset,
918  begin_bit, end_bit,
919  storage.single_block_helper
920  );
921  }
922 }
923 
924 template<
925  class Config,
926  bool Descending,
927  class KeysInputIterator,
928  class KeysOutputIterator,
929  class ValuesInputIterator,
930  class ValuesOutputIterator,
931  class SegmentIndexIterator,
932  class OffsetIterator
933 >
934 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
935 void segmented_sort_small(KeysInputIterator keys_input,
936  typename std::iterator_traits<KeysInputIterator>::value_type * keys_tmp,
937  KeysOutputIterator keys_output,
938  ValuesInputIterator values_input,
939  typename std::iterator_traits<ValuesInputIterator>::value_type * values_tmp,
940  ValuesOutputIterator values_output,
941  bool to_output,
942  unsigned int num_segments,
943  SegmentIndexIterator segment_indices,
944  OffsetIterator begin_offsets,
945  OffsetIterator end_offsets,
946  unsigned int begin_bit,
947  unsigned int end_bit)
948 {
949  static constexpr unsigned int block_size = Config::block_size;
950  static constexpr unsigned int logical_warp_size = Config::logical_warp_size;
951  static_assert(block_size % logical_warp_size == 0, "logical_warp_size must be a divisor of block_size");
952  static constexpr unsigned int warps_per_block = block_size / logical_warp_size;
953 
954  using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
955  using value_type = typename std::iterator_traits<ValuesInputIterator>::value_type;
956 
957  using warp_sort_helper_type = segmented_warp_sort_helper<
958  Config, key_type, value_type, Descending
959  >;
960 
961  ROCPRIM_SHARED_MEMORY typename warp_sort_helper_type::storage_type storage;
962 
963  const unsigned int block_id = ::rocprim::detail::block_id<0>();
964  const unsigned int logical_warp_id = ::rocprim::detail::logical_warp_id<logical_warp_size>();
965  const unsigned int segment_index = block_id * warps_per_block + logical_warp_id;
966  if(segment_index >= num_segments)
967  {
968  return;
969  }
970 
971  const unsigned int segment_id = segment_indices[segment_index];
972  const unsigned int begin_offset = begin_offsets[segment_id];
973  const unsigned int end_offset = end_offsets[segment_id];
974  if(end_offset <= begin_offset)
975  {
976  return;
977  }
978  warp_sort_helper_type().sort(
979  keys_input, keys_tmp, keys_output,
980  values_input, values_tmp, values_output,
981  to_output, begin_offset, end_offset,
982  begin_bit, end_bit, storage
983  );
984 }
985 
986 } // end namespace detail
987 
988 END_ROCPRIM_NAMESPACE
989 
990 #endif // ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_RADIX_SORT_HPP_
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int block_id()
Returns block identifier in a multidimensional grid by dimension.
Definition: thread.hpp:258
Definition: device_segmented_radix_sort.hpp:512
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int flat_block_thread_id()
Returns flat (linear, 1D) thread identifier in a multidimensional block (tile).
Definition: thread.hpp:106
Definition: device_radix_sort.hpp:312
Definition: device_segmented_radix_sort.hpp:482
ROCPRIM_DEVICE ROCPRIM_INLINE constexpr unsigned int device_warp_size()
Returns a number of threads in a hardware warp for the actual target.
Definition: thread.hpp:70
ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
Returns the minimum of its arguments.
Definition: functional.hpp:63
ROCPRIM_DEVICE ROCPRIM_INLINE void wave_barrier()
Synchronize all threads in the wavefront.
Definition: thread.hpp:235
Definition: device_radix_sort.hpp:335
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216
Definition: device_segmented_radix_sort.hpp:59
Definition: device_segmented_radix_sort.hpp:74
Definition: radix_sort.hpp:241
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int block_size()
Returns block size in a multidimensional grid by dimension.
Definition: thread.hpp:268
Definition: device_radix_sort.hpp:106
Definition: device_segmented_radix_sort.hpp:272
Definition: device_segmented_radix_sort.hpp:475
Definition: device_radix_sort.hpp:97
Definition: device_segmented_radix_sort.hpp:297