21 #ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_SORT_HPP_ 22 #define ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_SORT_HPP_ 24 #include <type_traits> 27 #include "../../config.hpp" 28 #include "../../detail/various.hpp" 30 #include "../../intrinsics.hpp" 31 #include "../../functional.hpp" 32 #include "../../types.hpp" 34 #include "../../block/block_load.hpp" 35 #include "../../block/block_load_func.hpp" 36 #include "../../block/block_sort.hpp" 37 #include "../../block/block_store.hpp" 39 BEGIN_ROCPRIM_NAMESPACE
// Keys-only store path (fragment).
// NOTE(review): this listing omits several original lines (the enclosing
// declaration, braces and the heads of the store calls), so only the pieces
// that are actually visible are described here.
46 unsigned int BlockSize,
47 unsigned int ItemsPerThread,
// store(): the values iterator and the values array are unnamed parameters,
// so this overload never touches values.
57 template<
class KeysOutputIterator,
class ValuesOutputIterator,
class OffsetT>
58 ROCPRIM_DEVICE ROCPRIM_INLINE
void store(
const OffsetT block_offset,
59 const unsigned int valid_in_last_block,
60 const bool is_incomplete_block,
61 KeysOutputIterator keys_output,
62 ValuesOutputIterator ,
63 Key (&keys)[ItemsPerThread],
64 Value (&)[ItemsPerThread],
65 storage_type& storage)
// Two destinations are visible (presumably one per branch); both write at
// keys_output + block_offset. Presumably the incomplete-block branch also
// forwards valid_in_last_block to bound the write - TODO confirm in the full file.
70 if(is_incomplete_block)
73 keys_output + block_offset,
82 keys_output + block_offset,
// Key/value store path (fragment).
// NOTE(review): the enclosing declaration, braces and the head of the value
// store calls are not visible in this listing.
91 unsigned int BlockSize,
92 unsigned int ItemsPerThread,
// store(): unlike the keys-only variant, both the values iterator and the
// values array are named here and values_output is used below.
105 template <
class KeysOutputIterator,
class ValuesOutputIterator,
class OffsetT>
106 ROCPRIM_DEVICE ROCPRIM_INLINE
107 void store(
const OffsetT block_offset,
108 const unsigned int valid_in_last_block,
109 const bool is_incomplete_block,
110 KeysOutputIterator keys_output,
111 ValuesOutputIterator values_output,
112 Key (&keys)[ItemsPerThread],
113 Value (&values)[ItemsPerThread],
114 storage_type& storage)
// Each visible branch stores keys at keys_output + block_offset and then
// refers to values_output + block_offset, i.e. keys and values share the same
// per-block offset. Presumably the first pair (under the if) is the bounded
// store for an incomplete block - TODO confirm.
119 if(is_incomplete_block)
121 block_store_key_type().store(
122 keys_output + block_offset,
131 values_output + block_offset,
139 block_store_key_type().store(
140 keys_output + block_offset,
148 values_output + block_offset,
// Generic value-permutation helper (fragment); the declaration that these
// template parameters belong to is not visible in this listing.
156 template<
typename Value,
157 unsigned int BlockSize,
158 unsigned int ItemsPerThread,
159 typename Enable =
void>
// permute(), full-block overload: each thread reads its values into a local
// array with block_load_direct_striped. How `ranks` is applied afterwards is
// not visible here (presumably a gather by rank followed by a store to
// values_output - TODO confirm).
173 template<
typename ValuesInputIterator,
typename ValuesOutputIterator>
174 ROCPRIM_DEVICE
void permute(
unsigned int (&ranks)[ItemsPerThread],
175 ValuesInputIterator values_input,
176 ValuesOutputIterator values_output,
180 const auto flat_id = block_thread_id<0>();
181 Value values[ItemsPerThread];
182 block_load_direct_striped<BlockSize>(flat_id, values_input, values);
// permute(), partial-block overload: same load, but valid_in_last_block is
// passed to the load call.
// NOTE(review): the template parameters are declared in the opposite order
// (output iterator first) compared with the overload above.
188 template<
typename ValuesOutputIterator,
typename ValuesInputIterator>
189 ROCPRIM_DEVICE
void permute(
unsigned int (&ranks)[ItemsPerThread],
190 ValuesInputIterator values_input,
191 ValuesOutputIterator values_output,
192 const unsigned int valid_in_last_block,
196 const auto flat_id = block_thread_id<0>();
197 Value values[ItemsPerThread];
198 block_load_direct_striped<BlockSize>(flat_id, values_input, values, valid_in_last_block);
// Variant parameterised only on BlockSize and ItemsPerThread (no Value type).
// NOTE(review): presumably the specialization used when there are no values
// to permute; the declaration line and the function bodies are not visible,
// so this cannot be confirmed from the listing.
205 template<
unsigned int BlockSize,
unsigned int ItemsPerThread>
// permute(), full-block overload; body not visible in this listing.
210 template<
typename ValuesInputIterator,
typename ValuesOutputIterator>
211 ROCPRIM_DEVICE
void permute(
unsigned int (&ranks)[ItemsPerThread],
212 ValuesInputIterator values_input,
213 ValuesOutputIterator values_output,
214 storage_type& storage)
// permute(), partial-block overload. The only visible statement casts
// valid_in_last_block to void, i.e. the parameter is intentionally unused.
222 template<
typename ValuesOutputIterator,
typename ValuesInputIterator>
223 ROCPRIM_DEVICE
void permute(
unsigned int (&ranks)[ItemsPerThread],
224 ValuesInputIterator values_input,
225 ValuesOutputIterator values_output,
226 const unsigned int valid_in_last_block,
227 storage_type& storage)
232 (void)valid_in_last_block;
// Value-permutation variant that stages a whole block of values in an array
// (fragment). Selected via enable_if when Value is trivially copyable and is
// neither a floating-point nor an integral type.
242 template<
typename Value,
unsigned int BlockSize,
unsigned int ItemsPerThread>
246 std::enable_if_t<(std::is_trivially_copyable<Value>::value
247 && !rocprim::is_floating_point<Value>::value
248 && !std::is_integral<Value>::value)>>
250 static constexpr
unsigned int items_per_block = ItemsPerThread * BlockSize;
// Staging buffer with one slot per item of the block.
254 Value values[items_per_block];
// permute(), full-block overload.
// Phase 1: thread flat_id copies input element BlockSize * item + flat_id
//          into the staging buffer at the same index.
// Phase 2: thread flat_id writes output element ItemsPerThread * flat_id + item
//          from the staging slot selected by ranks[item].
// NOTE(review): phase 2 reads slots written by other threads, so a block-wide
// barrier is expected between the loops; that line is not visible here.
259 template<
typename ValuesInputIterator,
typename ValuesOutputIterator>
260 ROCPRIM_DEVICE
void permute(
unsigned int (&ranks)[ItemsPerThread],
261 ValuesInputIterator values_input,
262 ValuesOutputIterator values_output,
266 auto& values_shared = storage_.get().values;
267 const auto flat_id = block_thread_id<0>();
270 for(
unsigned int item = 0; item < ItemsPerThread; ++item)
272 const unsigned int idx = BlockSize * item + flat_id;
273 values_shared[idx] = values_input[idx];
279 for(
unsigned int item = 0; item < ItemsPerThread; ++item)
281 values_output[ItemsPerThread * flat_id + item] = values_shared[ranks[item]];
// permute(), partial-block overload: same two phases, but both the staging
// copy and the output write are skipped for indices >= valid_in_last_block.
285 template<
typename ValuesOutputIterator,
typename ValuesInputIterator>
286 ROCPRIM_DEVICE
void permute(
unsigned int (&ranks)[ItemsPerThread],
287 ValuesInputIterator values_input,
288 ValuesOutputIterator values_output,
289 const unsigned int valid_in_last_block,
293 auto& values_shared = storage_.get().values;
294 const auto flat_id = block_thread_id<0>();
297 for(
unsigned int item = 0; item < ItemsPerThread; ++item)
299 const unsigned int idx = BlockSize * item + flat_id;
300 if(idx < valid_in_last_block)
302 values_shared[idx] = values_input[idx];
309 for(
unsigned int item = 0; item < ItemsPerThread; ++item)
311 if(flat_id * ItemsPerThread + item < valid_in_last_block)
313 values_output[ItemsPerThread * flat_id + item] = values_shared[ranks[item]];
// Block sort variant that makes the comparison stable by sorting
// (key, original index) tuples (fragment; declaration, key loading, key
// storing and the permute call heads are not visible in this listing).
319 template<
typename Key,
321 unsigned int BlockSize,
322 unsigned int ItemsPerThread,
324 typename Enable =
void>
// A key paired with the index it had before sorting.
327 using stable_key_type = rocprim::tuple<Key, unsigned int>;
348 template<
typename KeysInputIterator,
349 typename KeysOutputIterator,
350 typename ValuesInputIterator,
351 typename ValuesOutputIterator,
352 typename BinaryFunction>
353 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
354 void sort(
const unsigned int valid_in_last_block,
355 const bool is_incomplete_block,
356 KeysInputIterator keys_input,
357 KeysOutputIterator keys_output,
358 ValuesInputIterator values_input,
359 ValuesOutputIterator values_output,
360 BinaryFunction compare_function,
366 Key keys[ItemsPerThread];
368 if(is_incomplete_block)
377 const auto flat_id = block_thread_id<0>();
379 stable_key_type stable_keys[ItemsPerThread];
// Tag every key with its pre-sort position flat_id * ItemsPerThread + i.
381 for(
unsigned int i = 0; i < ItemsPerThread; ++i)
383 stable_keys[i] = rocprim::make_tuple(keys[i], flat_id * ItemsPerThread + i);
// Comparator on tagged keys: the visible tail orders by original index when
// neither key compares before the other. The head of the return expression is
// not visible (presumably `return ab` - TODO confirm).
389 auto stable_compare_function
390 = [compare_function](
const stable_key_type& a,
391 const stable_key_type& b) ROCPRIM_FORCE_INLINE
mutable 393 const bool ab = compare_function(rocprim::get<0>(a), rocprim::get<0>(b));
395 || (!compare_function(rocprim::get<0>(b), rocprim::get<0>(a))
396 && (rocprim::get<1>(a) < rocprim::get<1>(b)));
399 if(is_incomplete_block)
// Comparator for a partial block: an item whose original index is
// >= valid_in_last_block is out of bounds. If either operand is out of bounds
// the result is !a_oob, so in-bounds items order before out-of-bounds ones;
// otherwise the stable comparator decides.
403 auto stable_oob_compare_function
404 = [stable_compare_function, valid_in_last_block](
const stable_key_type& a,
405 const stable_key_type& b)
mutable 407 const bool a_oob = rocprim::get<1>(a) >= valid_in_last_block;
408 const bool b_oob = rocprim::get<1>(b) >= valid_in_last_block;
409 return a_oob || b_oob ? !a_oob : stable_compare_function(a, b);
414 sort_type().
sort(stable_keys, storage.sort, stable_oob_compare_function);
// Split the sorted tuples back into keys and the original indices (ranks).
416 unsigned int ranks[ItemsPerThread];
418 for(
unsigned int i = 0; i < ItemsPerThread; ++i)
420 keys[i] = rocprim::get<0>(stable_keys[i]);
421 ranks[i] = rocprim::get<1>(stable_keys[i]);
// Last argument of a call whose head is not visible (presumably the value
// permute using `ranks` - TODO confirm).
430 storage.permute_values);
// Full-block path: same steps with the plain stable comparator.
434 sort_type().
sort(stable_keys, storage.sort, stable_compare_function);
436 unsigned int ranks[ItemsPerThread];
438 for(
unsigned int i = 0; i < ItemsPerThread; ++i)
440 keys[i] = rocprim::get<0>(stable_keys[i]);
441 ranks[i] = rocprim::get<1>(stable_keys[i]);
449 storage.permute_values);
// Keys-only block sort variant (fragment; the declaration, the type aliases
// and the key store calls are not visible in this listing).
454 template<
typename Key,
unsigned int BlockSize,
unsigned int ItemsPerThread>
// sort(): both value iterators are unnamed parameters, so values are ignored.
480 template<
typename KeysInputIterator,
481 typename KeysOutputIterator,
482 typename ValuesInputIterator,
483 typename ValuesOutputIterator,
484 typename BinaryFunction>
485 ROCPRIM_DEVICE
void sort(
unsigned int valid_in_last_block,
486 const bool is_incomplete_block,
487 KeysInputIterator keys_input,
488 KeysOutputIterator keys_output,
489 ValuesInputIterator ,
490 ValuesOutputIterator ,
491 BinaryFunction compare_function,
492 storage_type& storage)
494 Key keys[ItemsPerThread];
// Partial block: load and sort are both given valid_in_last_block.
496 if(is_incomplete_block)
498 keys_load_type().load(keys_input, keys, valid_in_last_block, storage.load_keys);
500 sort_type().
sort(keys, storage.sort, valid_in_last_block, compare_function);
// Full block: unbounded load (the matching sort/store lines are not visible).
506 keys_load_type().load(keys_input, keys, storage.load_keys);
// Key/value block sort variant enabled when sizeof(Value) <= sizeof(int):
// values are loaded, sorted together with the keys and stored directly
// (fragment; the declaration and type aliases are not visible in this listing).
515 #ifndef DOXYGEN_SHOULD_SKIP_THIS 516 template<
typename Key,
typename Value,
unsigned int BlockSize,
unsigned int ItemsPerThread>
522 std::enable_if_t<(sizeof(Value) <= sizeof(int))>>
// Per-stage temporary storage members. Whether they are laid out as a struct
// or a union is not visible here.
544 typename keys_load_type::storage_type load_keys;
545 typename values_load_type::storage_type load_values;
546 typename sort_type::storage_type sort;
547 typename keys_store_type::storage_type store_keys;
548 typename values_store_type::storage_type store_values;
551 template<
typename KeysInputIterator,
552 typename KeysOutputIterator,
553 typename ValuesInputIterator,
554 typename ValuesOutputIterator,
555 typename BinaryFunction>
556 ROCPRIM_DEVICE
void sort(
const unsigned int valid_in_last_block,
557 const bool is_incomplete_block,
558 KeysInputIterator keys_input,
559 KeysOutputIterator keys_output,
560 ValuesInputIterator values_input,
561 ValuesOutputIterator values_output,
562 BinaryFunction compare_function,
563 storage_type& storage)
565 Key keys[ItemsPerThread];
566 Value values[ItemsPerThread];
// Partial block: every stage (load, sort, store) receives valid_in_last_block.
// NOTE(review): any barriers between stages are not visible in this listing.
568 if(is_incomplete_block)
570 keys_load_type().load(keys_input, keys, valid_in_last_block, storage.load_keys);
572 values_load_type().load(values_input, values, valid_in_last_block, storage.load_values);
574 sort_type().sort(keys, values, storage.sort, valid_in_last_block, compare_function);
576 keys_store_type().store(keys_output, keys, valid_in_last_block, storage.store_keys);
578 values_store_type().store(values_output,
581 storage.store_values);
// Full block: same pipeline without the bound.
585 keys_load_type().load(keys_input, keys, storage.load_keys);
587 values_load_type().load(values_input, values, storage.load_values);
589 sort_type().sort(keys, values, storage.sort, compare_function);
591 keys_store_type().store(keys_output, keys, storage.store_keys);
593 values_store_type().store(values_output, values, storage.store_values);
// Key/value block sort variant enabled when sizeof(Value) > sizeof(int):
// instead of moving values through the sort, it sorts (key, rank) pairs and
// then hands the ranks to values_permute_type (fragment; the declaration and
// type aliases are not visible in this listing).
597 template<
typename Key,
typename Value,
unsigned int BlockSize,
unsigned int ItemsPerThread>
603 std::enable_if_t<(sizeof(Value) > sizeof(int))>>
// Per-stage temporary storage members. Whether they are laid out as a struct
// or a union is not visible here.
621 typename keys_load_type::storage_type load_keys;
622 typename sort_type::storage_type sort;
623 typename keys_store_type::storage_type store_keys;
624 typename values_permute_type::storage_type permute_values;
627 template<
typename KeysInputIterator,
628 typename KeysOutputIterator,
629 typename ValuesInputIterator,
630 typename ValuesOutputIterator,
631 typename BinaryFunction>
632 ROCPRIM_DEVICE
void sort(
const unsigned int valid_in_last_block,
633 const bool is_incomplete_block,
634 KeysInputIterator keys_input,
635 KeysOutputIterator keys_output,
636 ValuesInputIterator values_input,
637 ValuesOutputIterator values_output,
638 BinaryFunction compare_function,
641 Key keys[ItemsPerThread];
643 const auto flat_id = block_thread_id<0>();
// ranks[item] starts as the item's own position flat_id * ItemsPerThread + item
// and is carried through the sort as the "value" of each key.
644 unsigned int ranks[ItemsPerThread];
646 for(
unsigned int item = 0; item < ItemsPerThread; ++item)
648 ranks[item] = flat_id * ItemsPerThread + item;
// Partial block: load, sort and key store all receive valid_in_last_block;
// the middle arguments of the permute call are not visible.
651 if(is_incomplete_block)
653 keys_load_type().load(keys_input, keys, valid_in_last_block, storage.load_keys);
655 sort_type().sort(keys, ranks, storage.sort, valid_in_last_block, compare_function);
657 keys_store_type().store(keys_output, keys, valid_in_last_block, storage.store_keys);
658 values_permute_type().permute(ranks,
662 storage.permute_values);
// Full block: same pipeline without the bound.
666 keys_load_type().load(keys_input, keys, storage.load_keys);
668 sort_type().sort(keys, ranks, storage.sort, compare_function);
670 keys_store_type().store(keys_output, keys, storage.store_keys);
671 values_permute_type().permute(ranks,
674 storage.permute_values);
// block_sort_kernel_impl: per-block entry point that sorts one tile of
// BlockSize * ItemsPerThread items (fragment; the OffsetT template parameter,
// the flat_block_id and sort_impl definitions and part of the sort() argument
// list are not visible in this listing).
678 #endif // DOXYGEN_SHOULD_SKIP_THIS 680 template<
unsigned int BlockSize,
681 unsigned int ItemsPerThread,
683 class KeysInputIterator,
684 class KeysOutputIterator,
685 class ValuesInputIterator,
686 class ValuesOutputIterator,
688 class BinaryFunction,
689 class ValueType =
typename std::iterator_traits<ValuesInputIterator>::value_type>
690 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
auto block_sort_kernel_impl(KeysInputIterator keys_input,
691 KeysOutputIterator keys_output,
692 ValuesInputIterator values_input,
693 ValuesOutputIterator values_output,
694 const OffsetT input_size,
695 BinaryFunction compare_function)
697 using key_type =
typename std::iterator_traits<KeysInputIterator>::value_type;
698 using value_type =
typename std::iterator_traits<ValuesInputIterator>::value_type;
701 constexpr
unsigned int items_per_block = BlockSize * ItemsPerThread;
// First item handled by this block.
// NOTE(review): if flat_block_id is a 32-bit unsigned value (its definition is
// not visible), the product is evaluated in 32 bits before it is converted to
// OffsetT - confirm this cannot wrap for large inputs.
703 const OffsetT block_offset = flat_block_id * items_per_block;
// Items remaining from block_offset to the end of the input, narrowed to
// unsigned int. Presumably only consumed when is_incomplete_block is true.
704 const unsigned int valid_in_last_block = input_size - block_offset;
// True only for the block whose index equals input_size / items_per_block,
// i.e. the block that starts at the last multiple of items_per_block.
705 const bool is_incomplete_block = flat_block_id == (input_size / items_per_block);
709 ROCPRIM_SHARED_MEMORY
typename sort_impl::storage_type storage;
// All four iterators are advanced by block_offset before delegating, so the
// implementation works with block-local indices.
711 sort_impl().sort(valid_in_last_block,
713 keys_input + block_offset,
714 keys_output + block_offset,
715 values_input + block_offset,
716 values_output + block_offset,
// block_merge_oddeven_kernel: merges this tile's items into the output by
// searching, for every key, its insertion point in a partner range of the
// input (fragment; the OffsetT template parameter, braces, the load calls, the
// tilegroup_id definition and the bodies of the final loops are not visible
// in this listing).
721 template<
unsigned int BlockSize,
722 unsigned int ItemsPerThread,
723 class KeysInputIterator,
724 class KeysOutputIterator,
725 class ValuesInputIterator,
726 class ValuesOutputIterator,
728 class BinaryFunction>
729 ROCPRIM_DEVICE ROCPRIM_INLINE
void block_merge_oddeven_kernel(KeysInputIterator keys_input,
730 KeysOutputIterator keys_output,
731 ValuesInputIterator values_input,
732 ValuesOutputIterator values_output,
733 const OffsetT input_size,
734 const OffsetT sorted_block_size,
735 BinaryFunction compare_function)
737 using key_type =
typename std::iterator_traits<KeysInputIterator>::value_type;
738 using value_type =
typename std::iterator_traits<ValuesInputIterator>::value_type;
// Values are handled only when the value type is not rocprim::empty_type.
739 constexpr
bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;
741 constexpr
unsigned int items_per_block = BlockSize * ItemsPerThread;
742 const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
743 const unsigned int flat_block_id = ::rocprim::detail::block_id<0>();
744 const bool is_incomplete_block = flat_block_id == (input_size / items_per_block);
// NOTE(review): both operands are unsigned int, so the product is evaluated
// in 32 bits before conversion to OffsetT - confirm this cannot wrap.
747 const OffsetT block_offset = flat_block_id * items_per_block;
748 const OffsetT valid_in_last_block = input_size - block_offset;
// Threads whose first item lies past the end of the input take this branch
// (body not visible; presumably an early return - TODO confirm).
750 const OffsetT thread_offset = flat_id * ItemsPerThread;
751 if(thread_offset >= valid_in_last_block)
756 key_type keys[ItemsPerThread];
757 value_type values[ItemsPerThread];
// Load phase: only the value-load arguments of the bounded branch are visible.
759 if(is_incomplete_block)
763 if ROCPRIM_IF_CONSTEXPR(with_values)
766 values_input + block_offset,
768 valid_in_last_block);
774 if ROCPRIM_IF_CONSTEXPR(with_values)
// Number of tiles per already-sorted block, and a mask derived from it.
// Assumes merged_tiles_number is a power of two for `mask` to be meaningful -
// TODO confirm; the use of `mask` and the definition of tilegroup_id are not
// visible here.
780 const unsigned int merged_tiles_number = sorted_block_size / items_per_block;
781 const unsigned int mask = merged_tiles_number - 1;
784 const unsigned int block_is_odd = merged_tiles_number & tilegroup_id;
785 const OffsetT block_start = tilegroup_id * items_per_block;
// Partner sorted block: the one before this block when block_is_odd is set,
// otherwise the one after. Start and end are clamped to input_size.
786 const OffsetT next_block_start_
787 = block_is_odd ? block_start - sorted_block_size : block_start + sorted_block_size;
788 const OffsetT next_block_start =
min(next_block_start_, input_size);
789 const OffsetT next_block_end =
min(next_block_start + sorted_block_size, input_size);
// No partner block exists: copy this tile's items to the same positions.
// NOTE(review): `id` is unsigned int although block_offset/thread_offset are
// OffsetT, so the index is truncated if OffsetT is wider than 32 bits.
// A bounds check between the `id` computation and the writes of the
// incomplete-block branch is expected but not visible in this listing.
791 if(next_block_start == input_size)
795 if(is_incomplete_block)
798 for(
unsigned int i = 0; i < ItemsPerThread; i++)
800 const unsigned int id = block_offset + thread_offset + i;
803 keys_output[id] = keys[i];
804 if ROCPRIM_IF_CONSTEXPR(with_values)
806 values_output[id] = values[i];
814 for(
unsigned int i = 0; i < ItemsPerThread; i++)
816 const unsigned int id = block_offset + thread_offset + i;
817 keys_output[id] = keys[i];
818 if ROCPRIM_IF_CONSTEXPR(with_values)
820 values_output[id] = values[i];
// left_id is declared outside merge_function and captured by reference, so the
// lower search bound found for item i is reused as the starting bound for
// item i + 1 (presumably valid because each thread's keys are already in
// order - TODO confirm).
827 OffsetT left_id = next_block_start;
// Base output position for this thread; the expression continues on a line
// that is not visible in this listing.
829 const OffsetT dest_offset
830 =
min(block_start, next_block_start) + block_offset + thread_offset - block_start
// merge_function(i): binary search over keys_input in [left_id, right_id) for
// the insertion point of keys[i], then write the item to the output.
833 const auto merge_function = [&](
const unsigned int i)
835 OffsetT right_id = next_block_end;
837 while(left_id < right_id)
839 OffsetT mid_id = (left_id + right_id) / 2;
840 key_type mid_key = keys_input[mid_id];
// Tie handling differs by side: with block_is_odd set, partner keys that do
// not compare after keys[i] are skipped (equal keys count as smaller);
// otherwise only partner keys that compare strictly before keys[i] are skipped.
841 const bool mid_smaller = block_is_odd ? !compare_function(keys[i], mid_key)
842 : compare_function(mid_key, keys[i]);
843 left_id = mid_smaller ? mid_id + 1 : left_id;
844 right_id = mid_smaller ? right_id : mid_id;
// Final position = base offset + item index within the thread + insertion
// point found in the partner block.
847 OffsetT offset = dest_offset + i + left_id;
848 keys_output[offset] = keys[i];
849 if ROCPRIM_IF_CONSTEXPR(with_values)
851 values_output[offset] = values[i];
// Loop bodies are not visible; presumably each iteration calls
// merge_function(i), guarded by the bound check in the incomplete-block case.
855 if(is_incomplete_block)
858 for(
unsigned int i = 0; i < ItemsPerThread; i++)
860 if(thread_offset + i < valid_in_last_block)
869 for(
unsigned int i = 0; i < ItemsPerThread; i++)
878 END_ROCPRIM_NAMESPACE
880 #endif // ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_SORT_HPP_ Empty type used as a placeholder, usually used to flag that given template parameter should not be used.
Definition: types.hpp:135
A merged sort based algorithm which sorts stably.
ROCPRIM_DEVICE ROCPRIM_INLINE void store(OutputIterator block_output, T(&items)[ItemsPerThread])
Stores an arrangement of items from across the thread block into an arrangement on continuous memory...
Definition: block_store.hpp:168
The block_sort class is a block level parallel primitive which provides methods sorting items (keys o...
Definition: block_sort.hpp:151
The block_store class is a block level parallel primitive which provides methods for storing an arran...
Definition: block_store.hpp:134
Definition: device_merge_sort.hpp:340
The block_exchange class is a block level parallel primitive which provides methods for rearranging i...
Definition: block_exchange.hpp:81
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort(Key &thread_key, BinaryFunction compare_function=BinaryFunction())
Block sort for any data type.
Definition: block_sort.hpp:181
Definition: test_utils_custom_float_type.hpp:110
Definition: device_merge_sort.hpp:51
ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
Returns the minimum of its arguments.
Definition: functional.hpp:63
ROCPRIM_DEVICE ROCPRIM_INLINE void load(InputIterator block_input, T(&items)[ItemsPerThread])
Loads data from continuous memory into an arrangement of items across the thread block.
Definition: block_load.hpp:167
Definition: device_merge_sort.hpp:167
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
Definition: device_merge_sort.hpp:325
Definition: device_merge_sort.hpp:160
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216
typename ::rocprim::detail::empty_storage_type storage_type
Struct used to allocate a temporary memory that is required for thread communication during operation...
Definition: block_load.hpp:148
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int flat_block_id()
Returns flat (linear, 1D) block identifier in a multidimensional grid.
Definition: thread.hpp:178
Definition: various.hpp:52
typename base_type::storage_type storage_type
Struct used to allocate a temporary memory that is required for thread communication during operation...
Definition: block_sort.hpp:166
typename ::rocprim::detail::empty_storage_type storage_type
Struct used to allocate a temporary memory that is required for thread communication during operation...
Definition: block_store.hpp:149
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void gather_from_striped(const T(&input)[ItemsPerThread], U(&output)[ItemsPerThread], const Offset(&ranks)[ItemsPerThread])
Gathers items from a striped arrangement based on their ranks across the thread block.
Definition: block_exchange.hpp:425
Definition: device_merge_sort.hpp:619
BEGIN_ROCPRIM_NAMESPACE ROCPRIM_DEVICE ROCPRIM_INLINE void block_load_direct_blocked(unsigned int flat_id, InputIterator block_input, T(&items)[ItemsPerThread])
Loads data from continuous memory into a blocked arrangement of items across the thread block...
Definition: block_load_func.hpp:58
block_sort_algorithm
Available algorithms for block_sort primitive.
Definition: block_sort.hpp:41
The block_load class is a block level parallel primitive which provides methods for loading data from...
Definition: block_load.hpp:133