rocPRIM
device_segmented_radix_sort.hpp
1 // Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_DEVICE_DEVICE_SEGMENTED_RADIX_SORT_HPP_
22 #define ROCPRIM_DEVICE_DEVICE_SEGMENTED_RADIX_SORT_HPP_
23 
24 #include <iostream>
25 #include <iterator>
26 #include <type_traits>
27 #include <utility>
28 
29 #include "../config.hpp"
30 #include "../detail/various.hpp"
31 #include "../detail/radix_sort.hpp"
32 
33 #include "../intrinsics.hpp"
34 #include "../functional.hpp"
35 #include "../types.hpp"
36 
37 #include "../block/block_load.hpp"
38 #include "../iterator/counting_iterator.hpp"
39 #include "../iterator/reverse_iterator.hpp"
40 #include "detail/device_segmented_radix_sort.hpp"
41 #include "device_partition.hpp"
42 #include "device_segmented_radix_sort_config.hpp"
43 
46 
47 BEGIN_ROCPRIM_NAMESPACE
48 
49 namespace detail
50 {
51 
52 template<
53  class Config,
54  bool Descending,
55  unsigned int BlockSize,
56  class KeysInputIterator,
57  class KeysOutputIterator,
58  class ValuesInputIterator,
59  class ValuesOutputIterator,
60  class OffsetIterator
61 >
62 ROCPRIM_KERNEL
63 __launch_bounds__(BlockSize)
64 void segmented_sort_kernel(KeysInputIterator keys_input,
65  typename std::iterator_traits<KeysInputIterator>::value_type * keys_tmp,
66  KeysOutputIterator keys_output,
67  ValuesInputIterator values_input,
68  typename std::iterator_traits<ValuesInputIterator>::value_type * values_tmp,
69  ValuesOutputIterator values_output,
70  bool to_output,
71  OffsetIterator begin_offsets,
72  OffsetIterator end_offsets,
73  unsigned int long_iterations,
74  unsigned int short_iterations,
75  unsigned int begin_bit,
76  unsigned int end_bit)
77 {
78  segmented_sort<Config, Descending>(
79  keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output,
80  to_output,
81  begin_offsets, end_offsets,
82  long_iterations, short_iterations,
83  begin_bit, end_bit
84  );
85 }
86 
87 template<
88  class Config,
89  bool Descending,
90  unsigned int BlockSize,
91  class KeysInputIterator,
92  class KeysOutputIterator,
93  class ValuesInputIterator,
94  class ValuesOutputIterator,
95  class SegmentIndexIterator,
96  class OffsetIterator
97 >
98 ROCPRIM_KERNEL
99 __launch_bounds__(BlockSize)
100 void segmented_sort_large_kernel(KeysInputIterator keys_input,
101  typename std::iterator_traits<KeysInputIterator>::value_type * keys_tmp,
102  KeysOutputIterator keys_output,
103  ValuesInputIterator values_input,
104  typename std::iterator_traits<ValuesInputIterator>::value_type * values_tmp,
105  ValuesOutputIterator values_output,
106  bool to_output,
107  SegmentIndexIterator segment_indices,
108  OffsetIterator begin_offsets,
109  OffsetIterator end_offsets,
110  unsigned int long_iterations,
111  unsigned int short_iterations,
112  unsigned int begin_bit,
113  unsigned int end_bit)
114 {
115  segmented_sort_large<Config, Descending>(
116  keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output,
117  to_output, segment_indices,
118  begin_offsets, end_offsets,
119  long_iterations, short_iterations,
120  begin_bit, end_bit
121  );
122 }
123 
124 template<class Config,
125  bool Descending,
126  unsigned int BlockSize,
127  class KeysInputIterator,
128  class KeysOutputIterator,
129  class ValuesInputIterator,
130  class ValuesOutputIterator,
131  class SegmentIndexIterator,
132  class OffsetIterator>
133 ROCPRIM_KERNEL __launch_bounds__(BlockSize) void segmented_sort_small_or_medium_kernel(
134  KeysInputIterator keys_input,
135  typename std::iterator_traits<KeysInputIterator>::value_type* keys_tmp,
136  KeysOutputIterator keys_output,
137  ValuesInputIterator values_input,
138  typename std::iterator_traits<ValuesInputIterator>::value_type* values_tmp,
139  ValuesOutputIterator values_output,
140  bool to_output,
141  unsigned int num_segments,
142  SegmentIndexIterator segment_indices,
143  OffsetIterator begin_offsets,
144  OffsetIterator end_offsets,
145  unsigned int begin_bit,
146  unsigned int end_bit)
147 {
148  segmented_sort_small<Config, Descending>(
149  keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output,
150  to_output, num_segments, segment_indices,
151  begin_offsets, end_offsets,
152  begin_bit, end_bit
153  );
154 }
155 
// Post-launch error check used after each kernel launch in this file.
//
// Reads hipGetLastError() and returns it from the enclosing function if it is not
// hipSuccess. When `debug_synchronous` is true it additionally prints `name` and
// `size`, synchronizes `stream` (returning any error from that), and prints the time
// elapsed since `start` in milliseconds.
//
// Requirements at the expansion site:
//   - `debug_synchronous` and `stream` must be visible names;
//   - the enclosing function must return hipError_t;
//   - `start` must be a std::chrono::high_resolution_clock::time_point.
// The expansion is a complete braced block, so call sites do not append a semicolon.
//
// Local names avoid the double-underscore prefix (reserved for the implementation),
// and the macro arguments are parenthesized so that arbitrary expressions can be passed.
#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \
    { \
        const auto _launch_error = hipGetLastError(); \
        if(_launch_error != hipSuccess) return _launch_error; \
        if(debug_synchronous) \
        { \
            std::cout << (name) << "(" << (size) << ")"; \
            const auto _sync_error = hipStreamSynchronize(stream); \
            if(_sync_error != hipSuccess) return _sync_error; \
            const auto _end = std::chrono::high_resolution_clock::now(); \
            const auto _d = std::chrono::duration_cast<std::chrono::duration<double>>(_end - (start)); \
            std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \
        } \
    }
170 
172 {
173  template<typename InputIterator,
174  typename FirstOutputIterator,
175  typename SecondOutputIterator,
176  typename UnselectedOutputIterator,
177  typename SelectedCountOutputIterator,
178  typename FirstUnaryPredicate,
179  typename SecondUnaryPredicate>
180  hipError_t operator()(void* temporary_storage,
181  size_t& storage_size,
182  InputIterator input,
183  FirstOutputIterator output_first_part,
184  SecondOutputIterator /*output_second_part*/,
185  UnselectedOutputIterator /*output_unselected*/,
186  SelectedCountOutputIterator selected_count_output,
187  const size_t size,
188  FirstUnaryPredicate select_first_part_op,
189  SecondUnaryPredicate /*select_second_part_op*/,
190  const hipStream_t stream,
191  const bool debug_synchronous)
192  {
193  return partition(temporary_storage,
194  storage_size,
195  input,
196  output_first_part,
197  selected_count_output,
198  size,
199  select_first_part_op,
200  stream,
201  debug_synchronous);
202  }
203 };
204 
206 {
207  template<typename InputIterator,
208  typename FirstOutputIterator,
209  typename SecondOutputIterator,
210  typename UnselectedOutputIterator,
211  typename SelectedCountOutputIterator,
212  typename FirstUnaryPredicate,
213  typename SecondUnaryPredicate>
214  hipError_t operator()(void* temporary_storage,
215  size_t& storage_size,
216  InputIterator input,
217  FirstOutputIterator output_first_part,
218  SecondOutputIterator output_second_part,
219  UnselectedOutputIterator output_unselected,
220  SelectedCountOutputIterator selected_count_output,
221  const size_t size,
222  FirstUnaryPredicate select_first_part_op,
223  SecondUnaryPredicate select_second_part_op,
224  const hipStream_t stream,
225  const bool debug_synchronous)
226  {
227  return partition_three_way(temporary_storage,
228  storage_size,
229  input,
230  output_first_part,
231  output_second_part,
232  output_unselected,
233  selected_count_output,
234  size,
235  select_first_part_op,
236  select_second_part_op,
237  stream,
238  debug_synchronous);
239  }
240 };
241 
242 template<
243  class Config,
244  bool Descending,
245  class KeysInputIterator,
246  class KeysOutputIterator,
247  class ValuesInputIterator,
248  class ValuesOutputIterator,
249  class OffsetIterator
250 >
251 inline
252 hipError_t segmented_radix_sort_impl(void * temporary_storage,
253  size_t& storage_size,
254  KeysInputIterator keys_input,
255  typename std::iterator_traits<KeysInputIterator>::value_type * keys_tmp,
256  KeysOutputIterator keys_output,
257  ValuesInputIterator values_input,
258  typename std::iterator_traits<ValuesInputIterator>::value_type * values_tmp,
259  ValuesOutputIterator values_output,
260  unsigned int size,
261  bool& is_result_in_output,
262  unsigned int segments,
263  OffsetIterator begin_offsets,
264  OffsetIterator end_offsets,
265  unsigned int begin_bit,
266  unsigned int end_bit,
267  hipStream_t stream,
268  bool debug_synchronous)
269 {
270  using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
271  using value_type = typename std::iterator_traits<ValuesInputIterator>::value_type;
272  using segment_index_type = unsigned int;
273  using segment_index_iterator = counting_iterator<segment_index_type>;
274 
275  static_assert(
276  std::is_same<key_type, typename std::iterator_traits<KeysOutputIterator>::value_type>::value,
277  "KeysInputIterator and KeysOutputIterator must have the same value_type"
278  );
279  static_assert(
280  std::is_same<value_type, typename std::iterator_traits<ValuesOutputIterator>::value_type>::value,
281  "ValuesInputIterator and ValuesOutputIterator must have the same value_type"
282  );
283 
284  using config = default_or_custom_config<
285  Config,
287  >;
288 
289  static constexpr bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;
290  static constexpr bool partitioning_allowed =
291  !std::is_same<typename config::warp_sort_config, DisabledWarpSortConfig>::value;
292  static constexpr unsigned int max_small_segment_length
293  = config::warp_sort_config::items_per_thread_small
294  * config::warp_sort_config::logical_warp_size_small;
295  static constexpr unsigned int small_segments_per_block
296  = config::warp_sort_config::block_size_small
297  / config::warp_sort_config::logical_warp_size_small;
298  static constexpr unsigned int max_medium_segment_length
299  = config::warp_sort_config::items_per_thread_medium
300  * config::warp_sort_config::logical_warp_size_medium;
301  static constexpr unsigned int medium_segments_per_block
302  = config::warp_sort_config::block_size_medium
303  / config::warp_sort_config::logical_warp_size_medium;
304  static_assert(
305  max_small_segment_length <= max_medium_segment_length,
306  "The max length of small segments cannot be higher than the max length of medium segments");
307  // Don't waste cycles on 3-way partitioning, if the small and medium segments are equal length
308  static constexpr bool three_way_partitioning
309  = max_small_segment_length < max_medium_segment_length;
310  using partitioner_type
311  = std::conditional_t<three_way_partitioning, ThreeWayPartitioner, TwoWayPartitioner>;
312  partitioner_type partitioner;
313 
314  const auto large_segment_selector = [=](const unsigned int segment_index) mutable -> bool
315  {
316  const unsigned int segment_length
317  = end_offsets[segment_index] - begin_offsets[segment_index];
318  return segment_length > max_medium_segment_length;
319  };
320  const auto medium_segment_selector = [=](const unsigned int segment_index) mutable -> bool
321  {
322  const unsigned int segment_length = end_offsets[segment_index] - begin_offsets[segment_index];
323  return segment_length > max_small_segment_length;
324  };
325 
326  const bool with_double_buffer = keys_tmp != nullptr;
327  const unsigned int bits = end_bit - begin_bit;
328  const unsigned int iterations = ::rocprim::detail::ceiling_div(bits, config::long_radix_bits);
329  const bool to_output = with_double_buffer || (iterations - 1) % 2 == 0;
330  is_result_in_output = (iterations % 2 == 0) != to_output;
331  const unsigned int radix_bits_diff = config::long_radix_bits - config::short_radix_bits;
332  const unsigned int short_iterations = radix_bits_diff != 0
333  ? ::rocprim::min(iterations, (config::long_radix_bits * iterations - bits) / radix_bits_diff)
334  : 0;
335  const unsigned int long_iterations = iterations - short_iterations;
336  const bool do_partitioning = partitioning_allowed
337  && segments >= config::warp_sort_config::partitioning_threshold;
338 
339  const size_t medium_segment_indices_size = three_way_partitioning ? segments : 0;
340  static constexpr size_t segment_count_output_size = three_way_partitioning ? 2 : 1;
341  const size_t segment_count_output_bytes
342  = segment_count_output_size * sizeof(segment_index_type);
343 
344  segment_index_type* large_segment_indices_output{};
345  // The total number of large and small segments is not above the number of segments
346  // The same buffer is filled with the large and small indices from both directions
347  auto small_segment_indices_output
348  = make_reverse_iterator(large_segment_indices_output + segments);
349  key_type* keys_tmp_storage;
350  value_type* values_tmp_storage;
351  segment_index_type* medium_segment_indices_output{};
352  segment_index_type* segment_count_output{};
353  size_t partition_storage_size{};
354  void* partition_temporary_storage{};
355 
356  const auto partitioner_result = partitioner(nullptr,
357  partition_storage_size,
358  segment_index_iterator{},
359  large_segment_indices_output,
360  medium_segment_indices_output,
361  small_segment_indices_output,
362  segment_count_output,
363  segments,
364  large_segment_selector,
365  medium_segment_selector,
366  stream,
367  debug_synchronous);
368  if(hipSuccess != partitioner_result)
369  {
370  return partitioner_result;
371  }
372 
373  const hipError_t partition_result = detail::temp_storage::partition(
374  temporary_storage,
375  storage_size,
376  detail::temp_storage::make_linear_partition(
377  // These are required by both partitioning and by sorting.
378  detail::temp_storage::ptr_aligned_array(&large_segment_indices_output, segments),
379  detail::temp_storage::ptr_aligned_array(&medium_segment_indices_output,
380  medium_segment_indices_size),
381  detail::temp_storage::ptr_aligned_array(&segment_count_output,
382  segment_count_output_size),
383  detail::temp_storage::make_union_partition(
384  // Partition temporary storage only needed by partitioning.
385  detail::temp_storage::make_partition(&partition_temporary_storage,
386  partition_storage_size),
387  // Keys/values temporary storage only needed by sorting.
388  detail::temp_storage::make_linear_partition(
389  detail::temp_storage::ptr_aligned_array(&keys_tmp_storage,
390  !with_double_buffer ? size : 0),
391  detail::temp_storage::ptr_aligned_array(
392  &values_tmp_storage,
393  !with_double_buffer && with_values ? size : 0)))));
394  if(partition_result != hipSuccess || temporary_storage == nullptr)
395  {
396  return partition_result;
397  }
398 
399  if(segments == 0u)
400  {
401  return hipSuccess;
402  }
403  if(debug_synchronous)
404  {
405  std::cout << "begin_bit " << begin_bit << '\n';
406  std::cout << "end_bit " << end_bit << '\n';
407  std::cout << "bits " << bits << '\n';
408  std::cout << "segments " << segments << '\n';
409  std::cout << "radix_bits_diff " << radix_bits_diff << '\n';
410  std::cout << "storage_size " << storage_size << '\n';
411  std::cout << "iterations " << iterations << '\n';
412  std::cout << "long_iterations " << long_iterations << '\n';
413  std::cout << "short_iterations " << short_iterations << '\n';
414  std::cout << "do_partitioning " << do_partitioning << '\n';
415  std::cout << "config::sort::block_size: " << config::sort::block_size << '\n';
416  std::cout << "config::sort::items_per_thread: " << config::sort::items_per_thread << '\n';
417  hipError_t error = hipStreamSynchronize(stream);
418  if(error != hipSuccess) return error;
419  }
420 
421  if(!with_double_buffer)
422  {
423  keys_tmp = keys_tmp_storage;
424  values_tmp = values_tmp_storage;
425  }
426  small_segment_indices_output = make_reverse_iterator(large_segment_indices_output + segments);
427 
428  if(do_partitioning)
429  {
430  hipError_t result = partitioner(partition_temporary_storage,
431  partition_storage_size,
432  segment_index_iterator{},
433  large_segment_indices_output,
434  medium_segment_indices_output,
435  small_segment_indices_output,
436  segment_count_output,
437  segments,
438  large_segment_selector,
439  medium_segment_selector,
440  stream,
441  debug_synchronous);
442  if(hipSuccess != result)
443  {
444  return result;
445  }
446  segment_index_type segment_counts[segment_count_output_size]{};
447  result = detail::memcpy_and_sync(&segment_counts,
448  segment_count_output,
449  segment_count_output_bytes,
450  hipMemcpyDeviceToHost,
451  stream);
452  if(hipSuccess != result)
453  {
454  return result;
455  }
456  const auto large_segment_count = segment_counts[0];
457  const auto medium_segment_count = three_way_partitioning ? segment_counts[1] : 0;
458  const auto small_segment_count = segments - large_segment_count - medium_segment_count;
459  if(debug_synchronous)
460  {
461  std::cout << "large_segment_count " << large_segment_count << '\n';
462  std::cout << "medium_segment_count " << medium_segment_count << '\n';
463  std::cout << "small_segment_count " << small_segment_count << '\n';
464  }
465  if(large_segment_count > 0)
466  {
467  std::chrono::high_resolution_clock::time_point start;
468  if(debug_synchronous) start = std::chrono::high_resolution_clock::now();
469  hipLaunchKernelGGL(
470  HIP_KERNEL_NAME(segmented_sort_large_kernel<config, Descending, config::sort::block_size>),
471  dim3(large_segment_count), dim3(config::sort::block_size), 0, stream,
472  keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output,
473  to_output, large_segment_indices_output,
474  begin_offsets, end_offsets,
475  long_iterations, short_iterations,
476  begin_bit, end_bit
477  );
478  ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort:large_segments",
479  large_segment_count,
480  start)
481  }
482  if(three_way_partitioning && medium_segment_count > 0)
483  {
484  const auto medium_segment_grid_size
485  = ::rocprim::detail::ceiling_div(medium_segment_count, medium_segments_per_block);
486  std::chrono::high_resolution_clock::time_point start;
487  if(debug_synchronous)
488  start = std::chrono::high_resolution_clock::now();
489  hipLaunchKernelGGL(
490  HIP_KERNEL_NAME(
491  segmented_sort_small_or_medium_kernel<
492  select_warp_sort_helper_config_medium_t<typename config::warp_sort_config>,
493  Descending,
494  config::warp_sort_config::block_size_medium>),
495  dim3(medium_segment_grid_size),
496  dim3(config::warp_sort_config::block_size_medium),
497  0,
498  stream,
499  keys_input,
500  keys_tmp,
501  keys_output,
502  values_input,
503  values_tmp,
504  values_output,
505  is_result_in_output,
506  medium_segment_count,
507  medium_segment_indices_output,
508  begin_offsets,
509  end_offsets,
510  begin_bit,
511  end_bit);
512  ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort:medium_segments",
513  medium_segment_count,
514  start)
515  }
516  if(small_segment_count > 0)
517  {
518  const auto small_segment_grid_size = ::rocprim::detail::ceiling_div(small_segment_count,
519  small_segments_per_block);
520  std::chrono::high_resolution_clock::time_point start;
521  if(debug_synchronous) start = std::chrono::high_resolution_clock::now();
522  hipLaunchKernelGGL(
523  HIP_KERNEL_NAME(
524  segmented_sort_small_or_medium_kernel<
525  select_warp_sort_helper_config_small_t<typename config::warp_sort_config>,
526  Descending,
527  config::warp_sort_config::block_size_small>),
528  dim3(small_segment_grid_size),
529  dim3(config::warp_sort_config::block_size_small),
530  0,
531  stream,
532  keys_input,
533  keys_tmp,
534  keys_output,
535  values_input,
536  values_tmp,
537  values_output,
538  is_result_in_output,
539  small_segment_count,
540  small_segment_indices_output,
541  begin_offsets,
542  end_offsets,
543  begin_bit,
544  end_bit);
545  ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort:small_segments",
546  small_segment_count,
547  start)
548  }
549  }
550  else
551  {
552  std::chrono::high_resolution_clock::time_point start;
553  if(debug_synchronous) start = std::chrono::high_resolution_clock::now();
554  hipLaunchKernelGGL(
555  HIP_KERNEL_NAME(segmented_sort_kernel<config, Descending, config::sort::block_size>),
556  dim3(segments), dim3(config::sort::block_size), 0, stream,
557  keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output,
558  to_output,
559  begin_offsets, end_offsets,
560  long_iterations, short_iterations,
561  begin_bit, end_bit
562  );
563  ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort", segments, start)
564  }
565  return hipSuccess;
566 }
567 
568 #undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR
569 
570 } // end namespace detail
571 
660 template<
661  class Config = default_config,
662  class KeysInputIterator,
663  class KeysOutputIterator,
664  class OffsetIterator,
665  class Key = typename std::iterator_traits<KeysInputIterator>::value_type
666 >
667 inline
668 hipError_t segmented_radix_sort_keys(void * temporary_storage,
669  size_t& storage_size,
670  KeysInputIterator keys_input,
671  KeysOutputIterator keys_output,
672  unsigned int size,
673  unsigned int segments,
674  OffsetIterator begin_offsets,
675  OffsetIterator end_offsets,
676  unsigned int begin_bit = 0,
677  unsigned int end_bit = 8 * sizeof(Key),
678  hipStream_t stream = 0,
679  bool debug_synchronous = false)
680 {
681  empty_type * values = nullptr;
682  bool ignored;
683  return detail::segmented_radix_sort_impl<Config, false>(
684  temporary_storage, storage_size,
685  keys_input, nullptr, keys_output,
686  values, nullptr, values,
687  size, ignored,
688  segments, begin_offsets, end_offsets,
689  begin_bit, end_bit,
690  stream, debug_synchronous
691  );
692 }
693 
782 template<
783  class Config = default_config,
784  class KeysInputIterator,
785  class KeysOutputIterator,
786  class OffsetIterator,
787  class Key = typename std::iterator_traits<KeysInputIterator>::value_type
788 >
789 inline
790 hipError_t segmented_radix_sort_keys_desc(void * temporary_storage,
791  size_t& storage_size,
792  KeysInputIterator keys_input,
793  KeysOutputIterator keys_output,
794  unsigned int size,
795  unsigned int segments,
796  OffsetIterator begin_offsets,
797  OffsetIterator end_offsets,
798  unsigned int begin_bit = 0,
799  unsigned int end_bit = 8 * sizeof(Key),
800  hipStream_t stream = 0,
801  bool debug_synchronous = false)
802 {
803  empty_type * values = nullptr;
804  bool ignored;
805  return detail::segmented_radix_sort_impl<Config, true>(
806  temporary_storage, storage_size,
807  keys_input, nullptr, keys_output,
808  values, nullptr, values,
809  size, ignored,
810  segments, begin_offsets, end_offsets,
811  begin_bit, end_bit,
812  stream, debug_synchronous
813  );
814 }
815 
920 template<
921  class Config = default_config,
922  class KeysInputIterator,
923  class KeysOutputIterator,
924  class ValuesInputIterator,
925  class ValuesOutputIterator,
926  class OffsetIterator,
927  class Key = typename std::iterator_traits<KeysInputIterator>::value_type
928 >
929 inline
930 hipError_t segmented_radix_sort_pairs(void * temporary_storage,
931  size_t& storage_size,
932  KeysInputIterator keys_input,
933  KeysOutputIterator keys_output,
934  ValuesInputIterator values_input,
935  ValuesOutputIterator values_output,
936  unsigned int size,
937  unsigned int segments,
938  OffsetIterator begin_offsets,
939  OffsetIterator end_offsets,
940  unsigned int begin_bit = 0,
941  unsigned int end_bit = 8 * sizeof(Key),
942  hipStream_t stream = 0,
943  bool debug_synchronous = false)
944 {
945  bool ignored;
946  return detail::segmented_radix_sort_impl<Config, false>(
947  temporary_storage, storage_size,
948  keys_input, nullptr, keys_output,
949  values_input, nullptr, values_output,
950  size, ignored,
951  segments, begin_offsets, end_offsets,
952  begin_bit, end_bit,
953  stream, debug_synchronous
954  );
955 }
956 
1057 template<
1058  class Config = default_config,
1059  class KeysInputIterator,
1060  class KeysOutputIterator,
1061  class ValuesInputIterator,
1062  class ValuesOutputIterator,
1063  class OffsetIterator,
1064  class Key = typename std::iterator_traits<KeysInputIterator>::value_type
1065 >
1066 inline
1067 hipError_t segmented_radix_sort_pairs_desc(void * temporary_storage,
1068  size_t& storage_size,
1069  KeysInputIterator keys_input,
1070  KeysOutputIterator keys_output,
1071  ValuesInputIterator values_input,
1072  ValuesOutputIterator values_output,
1073  unsigned int size,
1074  unsigned int segments,
1075  OffsetIterator begin_offsets,
1076  OffsetIterator end_offsets,
1077  unsigned int begin_bit = 0,
1078  unsigned int end_bit = 8 * sizeof(Key),
1079  hipStream_t stream = 0,
1080  bool debug_synchronous = false)
1081 {
1082  bool ignored;
1083  return detail::segmented_radix_sort_impl<Config, true>(
1084  temporary_storage, storage_size,
1085  keys_input, nullptr, keys_output,
1086  values_input, nullptr, values_output,
1087  size, ignored,
1088  segments, begin_offsets, end_offsets,
1089  begin_bit, end_bit,
1090  stream, debug_synchronous
1091  );
1092 }
1093 
1186 template<
1187  class Config = default_config,
1188  class Key,
1189  class OffsetIterator
1190 >
1191 inline
1192 hipError_t segmented_radix_sort_keys(void * temporary_storage,
1193  size_t& storage_size,
1194  double_buffer<Key>& keys,
1195  unsigned int size,
1196  unsigned int segments,
1197  OffsetIterator begin_offsets,
1198  OffsetIterator end_offsets,
1199  unsigned int begin_bit = 0,
1200  unsigned int end_bit = 8 * sizeof(Key),
1201  hipStream_t stream = 0,
1202  bool debug_synchronous = false)
1203 {
1204  empty_type * values = nullptr;
1205  bool is_result_in_output;
1206  hipError_t error = detail::segmented_radix_sort_impl<Config, false>(
1207  temporary_storage, storage_size,
1208  keys.current(), keys.current(), keys.alternate(),
1209  values, values, values,
1210  size, is_result_in_output,
1211  segments, begin_offsets, end_offsets,
1212  begin_bit, end_bit,
1213  stream, debug_synchronous
1214  );
1215  if(temporary_storage != nullptr && is_result_in_output)
1216  {
1217  keys.swap();
1218  }
1219  return error;
1220 }
1221 
1314 template<
1315  class Config = default_config,
1316  class Key,
1317  class OffsetIterator
1318 >
1319 inline
1320 hipError_t segmented_radix_sort_keys_desc(void * temporary_storage,
1321  size_t& storage_size,
1322  double_buffer<Key>& keys,
1323  unsigned int size,
1324  unsigned int segments,
1325  OffsetIterator begin_offsets,
1326  OffsetIterator end_offsets,
1327  unsigned int begin_bit = 0,
1328  unsigned int end_bit = 8 * sizeof(Key),
1329  hipStream_t stream = 0,
1330  bool debug_synchronous = false)
1331 {
1332  empty_type * values = nullptr;
1333  bool is_result_in_output;
1334  hipError_t error = detail::segmented_radix_sort_impl<Config, true>(
1335  temporary_storage, storage_size,
1336  keys.current(), keys.current(), keys.alternate(),
1337  values, values, values,
1338  size, is_result_in_output,
1339  segments, begin_offsets, end_offsets,
1340  begin_bit, end_bit,
1341  stream, debug_synchronous
1342  );
1343  if(temporary_storage != nullptr && is_result_in_output)
1344  {
1345  keys.swap();
1346  }
1347  return error;
1348 }
1349 
1455 template<
1456  class Config = default_config,
1457  class Key,
1458  class Value,
1459  class OffsetIterator
1460 >
1461 inline
1462 hipError_t segmented_radix_sort_pairs(void * temporary_storage,
1463  size_t& storage_size,
1464  double_buffer<Key>& keys,
1465  double_buffer<Value>& values,
1466  unsigned int size,
1467  unsigned int segments,
1468  OffsetIterator begin_offsets,
1469  OffsetIterator end_offsets,
1470  unsigned int begin_bit = 0,
1471  unsigned int end_bit = 8 * sizeof(Key),
1472  hipStream_t stream = 0,
1473  bool debug_synchronous = false)
1474 {
1475  bool is_result_in_output;
1476  hipError_t error = detail::segmented_radix_sort_impl<Config, false>(
1477  temporary_storage, storage_size,
1478  keys.current(), keys.current(), keys.alternate(),
1479  values.current(), values.current(), values.alternate(),
1480  size, is_result_in_output,
1481  segments, begin_offsets, end_offsets,
1482  begin_bit, end_bit,
1483  stream, debug_synchronous
1484  );
1485  if(temporary_storage != nullptr && is_result_in_output)
1486  {
1487  keys.swap();
1488  values.swap();
1489  }
1490  return error;
1491 }
1492 
1592 template<
1593  class Config = default_config,
1594  class Key,
1595  class Value,
1596  class OffsetIterator
1597 >
1598 inline
1599 hipError_t segmented_radix_sort_pairs_desc(void * temporary_storage,
1600  size_t& storage_size,
1601  double_buffer<Key>& keys,
1602  double_buffer<Value>& values,
1603  unsigned int size,
1604  unsigned int segments,
1605  OffsetIterator begin_offsets,
1606  OffsetIterator end_offsets,
1607  unsigned int begin_bit = 0,
1608  unsigned int end_bit = 8 * sizeof(Key),
1609  hipStream_t stream = 0,
1610  bool debug_synchronous = false)
1611 {
1612  bool is_result_in_output;
1613  hipError_t error = detail::segmented_radix_sort_impl<Config, true>(
1614  temporary_storage, storage_size,
1615  keys.current(), keys.current(), keys.alternate(),
1616  values.current(), values.current(), values.alternate(),
1617  size, is_result_in_output,
1618  segments, begin_offsets, end_offsets,
1619  begin_bit, end_bit,
1620  stream, debug_synchronous
1621  );
1622  if(temporary_storage != nullptr && is_result_in_output)
1623  {
1624  keys.swap();
1625  values.swap();
1626  }
1627  return error;
1628 }
1629 
1630 END_ROCPRIM_NAMESPACE
1631 
1633 // end of group devicemodule
1634 
1635 #endif // ROCPRIM_DEVICE_DEVICE_SEGMENTED_RADIX_SORT_HPP_
Empty type used as a placeholder, usually to flag that a given template parameter should not be used.
Definition: types.hpp:135
Definition: device_segmented_radix_sort_config.hpp:344
hipError_t segmented_radix_sort_pairs_desc(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, ValuesInputIterator values_input, ValuesOutputIterator values_output, unsigned int size, unsigned int segments, OffsetIterator begin_offsets, OffsetIterator end_offsets, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false)
Parallel descending radix sort-by-key primitive for device level.
Definition: device_segmented_radix_sort.hpp:1067
hipError_t partition(void *temporary_storage, size_t &storage_size, InputIterator input, FlagIterator flags, OutputIterator output, SelectedCountOutputIterator selected_count_output, const size_t size, const hipStream_t stream=0, const bool debug_synchronous=false)
Parallel select primitive for device level using range of flags.
Definition: device_partition.hpp:721
hipError_t memcpy_and_sync(void *dst, const void *src, size_t size_bytes, hipMemcpyKind kind, hipStream_t stream)
Copy data from src to dest with stream ordering and synchronization.
Definition: various.hpp:286
This class provides a convenient way to do double buffering.
Definition: double_buffer.hpp:37
hipError_t segmented_radix_sort_keys_desc(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, unsigned int size, unsigned int segments, OffsetIterator begin_offsets, OffsetIterator end_offsets, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false)
Parallel descending radix sort primitive for device level.
Definition: device_segmented_radix_sort.hpp:790
ROCPRIM_HOST_DEVICE reverse_iterator< SourceIterator > make_reverse_iterator(SourceIterator source_iterator)
make_reverse_iterator creates a reverse_iterator wrapping source_iterator.
Definition: reverse_iterator.hpp:204
Special type used to show that the given device-level operation will be executed with optimal configu...
Definition: config_types.hpp:45
hipError_t segmented_radix_sort_keys(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, unsigned int size, unsigned int segments, OffsetIterator begin_offsets, OffsetIterator end_offsets, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false)
Parallel ascending radix sort primitive for device level.
Definition: device_segmented_radix_sort.hpp:668
ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
Returns the minimum of its arguments.
Definition: functional.hpp:63
ROCPRIM_HOST_DEVICE T * current() const
Returns a pointer to the current buffer.
Definition: double_buffer.hpp:69
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
Definition: device_segmented_radix_sort.hpp:171
A random-access input (read-only) iterator over a sequence of consecutive integer values...
Definition: counting_iterator.hpp:51
hipError_t segmented_radix_sort_pairs(void *temporary_storage, size_t &storage_size, KeysInputIterator keys_input, KeysOutputIterator keys_output, ValuesInputIterator values_input, ValuesOutputIterator values_output, unsigned int size, unsigned int segments, OffsetIterator begin_offsets, OffsetIterator end_offsets, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key), hipStream_t stream=0, bool debug_synchronous=false)
Parallel ascending radix sort-by-key primitive for device level.
Definition: device_segmented_radix_sort.hpp:930
Definition: device_segmented_radix_sort.hpp:205
hipError_t partition_three_way(void *temporary_storage, size_t &storage_size, InputIterator input, FirstOutputIterator output_first_part, SecondOutputIterator output_second_part, UnselectedOutputIterator output_unselected, SelectedCountOutputIterator selected_count_output, const size_t size, FirstUnaryPredicate select_first_part_op, SecondUnaryPredicate select_second_part_op, const hipStream_t stream=0, const bool debug_synchronous=false)
Parallel select primitive for device level using two selection predicates.
Definition: device_partition.hpp:1029
ROCPRIM_HOST_DEVICE T * alternate() const
Returns a pointer to the alternate buffer.
Definition: double_buffer.hpp:76
ROCPRIM_HOST_DEVICE void swap()
Swaps the current and alternate buffers.
Definition: double_buffer.hpp:83