21 #ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_PARTITION_HPP_ 22 #define ROCPRIM_DEVICE_DETAIL_DEVICE_PARTITION_HPP_ 24 #include <type_traits> 27 #include "../../detail/various.hpp" 28 #include "../../intrinsics.hpp" 29 #include "../../functional.hpp" 30 #include "../../types.hpp" 32 #include "../../block/block_load.hpp" 33 #include "../../block/block_store.hpp" 34 #include "../../block/block_scan.hpp" 35 #include "../../block/block_discontinuity.hpp" 37 #include "lookback_scan_state.hpp" 38 #include "rocprim/type_traits.hpp" 39 #include "rocprim/types/tuple.hpp" 41 BEGIN_ROCPRIM_NAMESPACE
46 #ifndef DOXYGEN_SHOULD_SKIP_THIS 47 enum class select_method
// partition_block_load_flags, flag-based overload.
// Participates in overload resolution only when SelectMethod == select_method::flag
// (see the enable_if trailing return type). The input-iterator and value-array
// parameters are unnamed, so this overload does not read them.
// NOTE(review): this listing omits several lines of the definition (part of the
// template/parameter list, braces, the remaining load arguments); only the visible
// lines are documented here.
53 #endif // DOXYGEN_SHOULD_SKIP_THIS 55 template<select_method SelectMethod,
56 unsigned int BlockSize,
57 class BlockLoadFlagsType,
58 class BlockDiscontinuityType,
62 unsigned int ItemsPerThread,
66 ROCPRIM_DEVICE ROCPRIM_INLINE
auto 67 partition_block_load_flags(InputIterator ,
68 FlagIterator block_flags,
69 ValueType (&)[ItemsPerThread],
70 bool (&is_selected)[ItemsPerThread],
76 const bool is_global_last_block,
77 const unsigned int valid_in_global_last_block) ->
78 typename std::enable_if<SelectMethod == select_method::flag>::type
// The global last block passes valid_in_global_last_block to the block-level load
// (presumably as the valid-item bound; the other arguments are not visible here).
80 if(is_global_last_block)
82 BlockLoadFlagsType().load(block_flags,
84 valid_in_global_last_block,
// partition_block_load_flags, single-predicate overload.
// Enabled only when SelectMethod == select_method::predicate. Fills is_selected[i]
// with predicate(values[i]) for each of the thread's ItemsPerThread items.
// The input-iterator parameter is unnamed and therefore unused here.
// NOTE(review): the definition of `offset` and the else/brace lines are not part of
// this listing; by analogy with the sibling overloads it is presumably
// block_thread_id * ItemsPerThread — confirm against the full file.
100 template<select_method SelectMethod,
101 unsigned int BlockSize,
102 class BlockLoadFlagsType,
103 class BlockDiscontinuityType,
107 unsigned int ItemsPerThread,
108 class UnaryPredicate,
111 ROCPRIM_DEVICE ROCPRIM_INLINE
auto 112 partition_block_load_flags(InputIterator ,
114 ValueType (&values)[ItemsPerThread],
115 bool (&is_selected)[ItemsPerThread],
116 UnaryPredicate predicate,
121 const bool is_global_last_block,
122 const unsigned int valid_in_global_last_block) ->
123 typename std::enable_if<SelectMethod == select_method::predicate>::type
// Global last block: the predicate is evaluated only for items whose index is below
// valid_in_global_last_block; the remaining slots are marked as not selected.
125 if(is_global_last_block)
129 for(
unsigned int i = 0; i < ItemsPerThread; i++)
131 if((offset + i) < valid_in_global_last_block)
133 is_selected[i] = predicate(values[i]);
137 is_selected[i] =
false;
// Any other block: every item is valid, so the predicate is applied unconditionally.
144 for(
unsigned int i = 0; i < ItemsPerThread; i++)
146 is_selected[i] = predicate(values[i]);
// Functor that wraps a binary inequality operation together with a valid-item count.
// NOTE(review): the struct header and the constructor signature are not part of this
// listing; only the members, the member-initializer list and operator() are visible.
153 template<
class InequalityOp>
// Wrapped comparison and the number of items considered valid.
156 InequalityOp inequality_op;
157 unsigned int valid_count;
// Constructor: stores both arguments in the members of the same name.
159 ROCPRIM_DEVICE ROCPRIM_INLINE
161 : inequality_op(inequality_op), valid_count(valid_count)
// Returns true only if b_index is below valid_count AND inequality_op(a, b) holds.
// Because of && short-circuiting, inequality_op is not invoked when
// b_index >= valid_count.
164 template<
class T,
class U>
165 ROCPRIM_DEVICE ROCPRIM_INLINE
166 bool operator()(
const T& a,
const U& b,
unsigned int b_index)
168 return (b_index < valid_count && inequality_op(a, b));
// partition_block_load_flags, unique overload.
// Enabled only when SelectMethod == select_method::unique. Selection flags are
// produced by BlockDiscontinuityType::flag_heads using storage.discontinuity_values
// as its temporary storage.
// NOTE(review): the branch conditions between the flag_heads calls and most call
// arguments are missing from this listing. The visible structure suggests one pair
// of calls for the first block and one pair (using `predecessor`) for later blocks,
// each split on is_global_last_block — confirm against the full file.
172 template<select_method SelectMethod,
173 unsigned int BlockSize,
174 class BlockLoadFlagsType,
175 class BlockDiscontinuityType,
179 unsigned int ItemsPerThread,
180 class UnaryPredicate,
183 ROCPRIM_DEVICE ROCPRIM_INLINE
auto 184 partition_block_load_flags(InputIterator block_predecessor,
186 ValueType (&values)[ItemsPerThread],
187 bool (&is_selected)[ItemsPerThread],
189 InequalityOp inequality_op,
190 StorageType& storage,
191 const bool is_first_block,
193 const bool is_global_last_block,
194 const unsigned int valid_in_global_last_block) ->
195 typename std::enable_if<SelectMethod == select_method::unique>::type
199 if(is_global_last_block)
201 BlockDiscontinuityType().flag_heads(
205 storage.discontinuity_values);
209 BlockDiscontinuityType().flag_heads(is_selected,
212 storage.discontinuity_values);
// Reads the element at block_predecessor[0]; the caller visible in this file passes
// keys_input + block_offset - 1, i.e. the item just before this block's first item.
217 const ValueType predecessor = block_predecessor[0];
218 if(is_global_last_block)
220 BlockDiscontinuityType().flag_heads(
225 storage.discontinuity_values);
229 BlockDiscontinuityType().flag_heads(is_selected,
233 storage.discontinuity_values);
// Global last block: clear the flag of every slot at or beyond
// valid_in_global_last_block so out-of-range items are never selected.
239 if(is_global_last_block)
241 const auto offset = block_thread_id * ItemsPerThread;
243 for(
unsigned int i = 0; i < ItemsPerThread; i++)
245 if((offset + i) >= valid_in_global_last_block)
247 is_selected[i] =
false;
// partition_block_load_flags, two-predicate (three-way partition) overload.
// Fills two flag rows per item:
//   is_selected[0][i] = select_first_part_op(values[i])
//   is_selected[1][i] = !is_selected[0][i] && select_second_part_op(values[i])
// so an item is never flagged in both rows, and the second predicate is only
// evaluated when the first one rejects the item (&& short-circuit).
// The input-iterator parameter is unnamed and therefore unused here.
// NOTE(review): else/brace lines and some parameters are missing from this listing.
254 template<select_method SelectMethod,
255 unsigned int BlockSize,
256 class BlockLoadFlagsType,
257 class BlockDiscontinuityType,
261 unsigned int ItemsPerThread,
262 class FirstUnaryPredicate,
263 class SecondUnaryPredicate,
266 ROCPRIM_DEVICE ROCPRIM_INLINE
void 267 partition_block_load_flags(InputIterator ,
269 ValueType (&values)[ItemsPerThread],
270 bool (&is_selected)[2][ItemsPerThread],
271 FirstUnaryPredicate select_first_part_op,
272 SecondUnaryPredicate select_second_part_op,
276 const unsigned int block_thread_id,
277 const bool is_global_last_block,
278 const unsigned int valid_in_global_last_block)
// Global last block: only items with index below valid_in_global_last_block are
// tested; the remaining slots get both flags cleared.
280 if(is_global_last_block)
282 const auto offset = block_thread_id * ItemsPerThread;
284 for(
unsigned int i = 0; i < ItemsPerThread; i++)
286 if((offset + i) < valid_in_global_last_block)
288 is_selected[0][i] = select_first_part_op(values[i]);
289 is_selected[1][i] = !is_selected[0][i] && select_second_part_op(values[i]);
293 is_selected[0][i] =
false;
294 is_selected[1][i] =
false;
// Any other block: all items are valid, evaluate the predicates unconditionally.
301 for(
unsigned int i = 0; i < ItemsPerThread; i++)
303 is_selected[0][i] = select_first_part_op(values[i]);
304 is_selected[1][i] = !is_selected[0][i] && select_second_part_op(values[i]);
// partition_scatter, !OnlySelected overload writing everything to get<0>(output).
// Items are first rearranged through `storage` so that selected items occupy
// [0, selected_in_block) and rejected items follow, then written out: selected
// items ascending from selected_output_index, rejected items descending from
// rejected_output_index (note the `rejected_output_index - item_index`).
// NOTE(review): this listing omits the `output`/flat_block_id parameters, the
// definition of item_index in the reload and store loops, the tail of the
// rejected_output_index expression, and whatever sits between the storage write
// and the storage read (a block barrier is required there — confirm it is present).
310 template<
bool OnlySelected,
311 unsigned int BlockSize,
313 unsigned int ItemsPerThread,
316 class ScatterStorageType>
317 ROCPRIM_DEVICE ROCPRIM_INLINE
auto 318 partition_scatter(ValueType (&values)[ItemsPerThread],
319 bool (&is_selected)[ItemsPerThread],
320 OffsetType (&output_indices)[ItemsPerThread],
322 const size_t total_size,
323 const OffsetType selected_prefix,
324 const OffsetType selected_in_block,
325 ScatterStorageType& storage,
328 const bool is_global_last_block,
329 const unsigned int valid_in_global_last_block,
330 size_t (&prev_selected_count_values)[1],
331 size_t prev_processed) ->
typename std::enable_if<!OnlySelected>::type
333 constexpr
unsigned int items_per_block = BlockSize * ItemsPerThread;
336 auto scatter_storage = storage.get();
// Phase 1: each thread places its items at their block-local partitioned position.
// selected_item_index is output_indices[i] relative to this block's prefix;
// rejected items go after the selected_in_block selected ones.
// (Assumes output_indices holds an exclusive prefix sum of the flags — TODO confirm.)
338 for(
unsigned int i = 0; i < ItemsPerThread; i++)
340 unsigned int item_index = (flat_block_thread_id * ItemsPerThread) + i;
341 unsigned int selected_item_index = output_indices[i] - selected_prefix;
342 unsigned int rejected_item_index = (item_index - selected_item_index) + selected_in_block;
344 unsigned int scatter_index = is_selected[i] ? selected_item_index : rejected_item_index;
345 scatter_storage[scatter_index] = values[i];
// Phase 2: read the rearranged items back from the scatter storage.
349 ValueType reloaded_values[ItemsPerThread];
350 for(
unsigned int i = 0; i < ItemsPerThread; i++)
353 reloaded_values[i] = scatter_storage[item_index];
// Maps a block-local partitioned position to its position in the output.
// The lambda captures by value ([=]).
356 const auto calculate_scatter_index = [=](
const unsigned int item_index) ->
size_t 358 const size_t selected_output_index = prev_selected_count_values[0] + selected_prefix;
359 const size_t rejected_output_index = total_size + selected_output_index - prev_processed
360 - flat_block_id * items_per_block + selected_in_block
362 return item_index < selected_in_block ? selected_output_index + item_index
363 : rejected_output_index - item_index;
// Phase 3: store. The global last block skips positions at or beyond
// valid_in_global_last_block; other blocks store every item.
365 if(is_global_last_block)
367 for(
unsigned int i = 0; i < ItemsPerThread; i++)
370 if(item_index < valid_in_global_last_block)
372 get<0>(output)[calculate_scatter_index(item_index)] = reloaded_values[i];
378 for(
unsigned int i = 0; i < ItemsPerThread; i++)
381 get<0>(output)[calculate_scatter_index(item_index)] = reloaded_values[i];
// partition_scatter, !OnlySelected overload writing to two outputs:
// selected items go to get<0>(output), rejected items to get<1>(output), both in
// ascending position order.
// NOTE(review): this listing omits the `output` parameter, the template parameter
// that distinguishes this overload from the single-output one (both use
// enable_if<!OnlySelected>), the definition of item_index in the reload/store
// loops, and whatever sits between the storage write and the storage read
// (a block barrier is required there — confirm it is present).
387 template<
bool OnlySelected,
388 unsigned int BlockSize,
390 unsigned int ItemsPerThread,
394 class ScatterStorageType>
395 ROCPRIM_DEVICE ROCPRIM_INLINE
auto partition_scatter(ValueType (&values)[ItemsPerThread],
396 bool (&is_selected)[ItemsPerThread],
397 OffsetType (&output_indices)[ItemsPerThread],
400 const OffsetType selected_prefix,
401 const OffsetType selected_in_block,
402 ScatterStorageType& storage,
403 const unsigned int flat_block_id,
404 const unsigned int flat_block_thread_id,
405 const bool is_global_last_block,
406 const unsigned int valid_in_global_last_block,
407 size_t (&prev_selected_count_values)[1],
408 size_t prev_processed) ->
409 typename std::enable_if<!OnlySelected>::type
411 constexpr
unsigned int items_per_block = BlockSize * ItemsPerThread;
414 auto scatter_storage = storage.get();
// Phase 1: place each item at its block-local partitioned position — selected
// items in [0, selected_in_block), rejected items after them.
// (Assumes output_indices holds an exclusive prefix sum of the flags — TODO confirm.)
416 for(
unsigned int i = 0; i < ItemsPerThread; i++)
418 unsigned int item_index = (flat_block_thread_id * ItemsPerThread) + i;
419 unsigned int selected_item_index = output_indices[i] - selected_prefix;
420 unsigned int rejected_item_index = (item_index - selected_item_index) + selected_in_block;
422 unsigned int scatter_index = is_selected[i] ? selected_item_index : rejected_item_index;
423 scatter_storage[scatter_index] = values[i];
// Phase 2: read the rearranged items back from the scatter storage.
427 ValueType reloaded_values[ItemsPerThread];
428 for(
unsigned int i = 0; i < ItemsPerThread; i++)
431 reloaded_values[i] = scatter_storage[item_index];
// Stores reloaded_values[i] for block-local position item_index.
// rejected_output_index already has selected_in_block subtracted, so adding an
// item_index >= selected_in_block yields the rank among rejected items.
// The lambda captures by value ([=]).
434 auto save_to_output = [=](
const unsigned int item_index,
const unsigned int i)
436 const size_t selected_output_index = prev_selected_count_values[0] + selected_prefix;
437 const size_t rejected_output_index = prev_processed + flat_block_id * items_per_block
438 - selected_output_index - selected_in_block;
440 if(item_index < selected_in_block)
442 get<0>(output)[selected_output_index + item_index] = reloaded_values[i];
446 get<1>(output)[rejected_output_index + item_index] = reloaded_values[i];
// Phase 3: store. The global last block skips positions at or beyond
// valid_in_global_last_block; other blocks store every item.
450 if(is_global_last_block)
452 for(
unsigned int i = 0; i < ItemsPerThread; i++)
455 if(item_index < valid_in_global_last_block)
457 save_to_output(item_index, i);
463 for(
unsigned int i = 0; i < ItemsPerThread; i++)
466 save_to_output(item_index, i);
// partition_scatter, OnlySelected overload: only selected items are written, all
// to get<0>(output). The last size_t parameter is unnamed and therefore unused.
// Two strategies are visible:
//  - more than BlockSize selected items in the block: compact them through
//    `storage`, then write them out with a loop that starts at
//    flat_block_thread_id and advances by BlockSize;
//  - otherwise: each thread writes its own items directly.
// NOTE(review): this listing omits the `output` parameter, the else/brace lines,
// the guards on lines 504 and 525 of the original file (presumably is_selected[i]
// checks) and whatever sits between the storage write and the storage read
// (a block barrier is required there — confirm it is present).
472 template<
bool OnlySelected,
473 unsigned int BlockSize,
475 unsigned int ItemsPerThread,
479 class ScatterStorageType>
480 ROCPRIM_DEVICE ROCPRIM_INLINE
auto 481 partition_scatter(ValueType (&values)[ItemsPerThread],
482 bool (&is_selected)[ItemsPerThread],
483 OffsetType (&output_indices)[ItemsPerThread],
486 const OffsetType selected_prefix,
487 const OffsetType selected_in_block,
488 ScatterStorageType& storage,
490 const unsigned int flat_block_thread_id,
491 const bool is_global_last_block,
493 size_t (&prev_selected_count_values)[1],
494 size_t ) ->
typename std::enable_if<OnlySelected>::type
496 if(selected_in_block > BlockSize)
499 auto scatter_storage = storage.get();
// Compact: position within the block is output_indices[i] minus the block prefix.
501 for(
unsigned int i = 0; i < ItemsPerThread; i++)
503 unsigned int scatter_index = output_indices[i] - selected_prefix;
506 scatter_storage[scatter_index] = values[i];
// Write-out: thread t handles positions t, t + BlockSize, t + 2*BlockSize, ...
// up to selected_in_block.
512 for(
unsigned int i = flat_block_thread_id; i < selected_in_block; i += BlockSize)
514 get<0>(output)[prev_selected_count_values[0] + selected_prefix + i]
515 = scatter_storage[i];
// Direct path. In the global last block an item is additionally required to have
// output_indices[i] below selected_prefix + selected_in_block.
521 for(
unsigned int i = 0; i < ItemsPerThread; i++)
523 if(!is_global_last_block || output_indices[i] < (selected_prefix + selected_in_block))
527 get<0>(output)[prev_selected_count_values[0] + output_indices[i]] = values[i];
// partition_scatter, three-way overload (is_selected has two rows, offsets carry
// .x and .y components, two previous counters).
// Block-local layout built in `storage`:
//   [0, sx)            items flagged in is_selected[0]   -> get<0>(output)
//   [sx, sx + sy)      items flagged in is_selected[1]   -> get<1>(output)
//   [sx + sy, ...)     unselected items                  -> get<2>(output)
// with sx = selected_in_block.x and sy = selected_in_block.y.
// NOTE(review): this listing omits the `output` parameter, else/brace lines, and
// whatever sits between the storage write and the storage read (a block barrier
// is required there — confirm it is present).
535 template<
bool OnlySelected,
536 unsigned int BlockSize,
538 unsigned int ItemsPerThread,
541 class ScatterStorageType>
542 ROCPRIM_DEVICE ROCPRIM_INLINE
void partition_scatter(ValueType (&values)[ItemsPerThread],
543 bool (&is_selected)[2][ItemsPerThread],
544 OffsetType (&output_indices)[ItemsPerThread],
547 const OffsetType selected_prefix,
548 const OffsetType selected_in_block,
549 ScatterStorageType& storage,
550 const unsigned int flat_block_id,
551 const unsigned int flat_block_thread_id,
552 const bool is_global_last_block,
553 const unsigned int valid_in_global_last_block,
554 size_t (&prev_selected_count_values)[2],
555 size_t prev_processed)
557 constexpr
unsigned int items_per_block = BlockSize * ItemsPerThread;
558 auto scatter_storage = storage.get();
// Output bases. The second and third bases are pre-biased (-sx, and -2*sx - sy
// combined with the -sx already inside second_selected_prefix) so that adding the
// block-local storage position in save_to_output yields the rank within the
// respective output.
559 const size_t first_selected_prefix = prev_selected_count_values[0] + selected_prefix.x;
560 const size_t second_selected_prefix
561 = prev_selected_count_values[1] - selected_in_block.x + selected_prefix.y;
562 const size_t unselected_prefix = prev_processed - first_selected_prefix - second_selected_prefix
563 + items_per_block * flat_block_id - 2 * selected_in_block.x
564 - selected_in_block.y;
// Phase 1: scatter each item to its block-local position.
567 for(
unsigned int i = 0; i < ItemsPerThread; i++)
569 const unsigned int first_selected_item_index = output_indices[i].x - selected_prefix.x;
570 const unsigned int second_selected_item_index = output_indices[i].y - selected_prefix.y
571 + selected_in_block.x;
572 unsigned int scatter_index{};
574 if(is_selected[0][i])
576 scatter_index = first_selected_item_index;
578 else if(is_selected[1][i])
580 scatter_index = second_selected_item_index;
// Unselected item: second_selected_item_index already contains +sx, so the
// expression below reduces to
//   item_index - (first rank) - (second rank) + sx + sy.
584 const unsigned int item_index = (flat_block_thread_id * ItemsPerThread) + i;
585 const unsigned int unselected_item_index = (item_index - first_selected_item_index - second_selected_item_index)
586 + 2*selected_in_block.x + selected_in_block.y;
587 scatter_index = unselected_item_index;
589 scatter_storage[scatter_index] = values[i];
// Phase 2: copy the item at block-local position item_index to the output that
// owns that range of the layout. The lambda captures by value and is `mutable`.
593 auto save_to_output = [=](
const unsigned int item_index)
mutable 595 if(item_index < selected_in_block.x)
597 const size_t first_selected_index = first_selected_prefix + item_index;
598 get<0>(output)[first_selected_index] = scatter_storage[item_index];
600 else if(item_index < selected_in_block.x + selected_in_block.y)
602 const size_t second_selected_index = second_selected_prefix + item_index;
603 get<1>(output)[second_selected_index] = scatter_storage[item_index];
607 const size_t unselected_index = unselected_prefix + item_index;
608 get<2>(output)[unselected_index] = scatter_storage[item_index];
// Threads walk the storage with stride BlockSize (item_index = i * BlockSize +
// thread id). The global last block skips positions at or beyond
// valid_in_global_last_block.
612 if(is_global_last_block)
614 for(
unsigned int i = 0; i < ItemsPerThread; i++)
616 const unsigned int item_index = (i * BlockSize) + flat_block_thread_id;
617 if(item_index < valid_in_global_last_block)
619 save_to_output(item_index);
625 for(
unsigned int i = 0; i < ItemsPerThread; i++)
627 const unsigned int item_index = (i * BlockSize) + flat_block_thread_id;
628 save_to_output(item_index);
// convert_selected_to_indices, scalar overload: turns each boolean selection flag
// into 1 (selected) or 0 (not selected) in output_indices.
// NOTE(review): the opening of the template header and the offset_type template
// parameter are not part of this listing.
634 unsigned int items_per_thread,
637 ROCPRIM_DEVICE ROCPRIM_INLINE
638 void convert_selected_to_indices(offset_type (&output_indices)[items_per_thread],
639 bool (&is_selected)[items_per_thread])
642 for(
unsigned int i = 0; i < items_per_thread; i++)
644 output_indices[i] = is_selected[i] ? 1 : 0;
// convert_selected_to_indices, uint2 overload for the two-row flag array:
// .x receives 1/0 from is_selected[0][i] and .y receives 1/0 from is_selected[1][i].
// NOTE(review): the opening of the template header is not part of this listing.
649 unsigned int items_per_thread
651 ROCPRIM_DEVICE ROCPRIM_INLINE
652 void convert_selected_to_indices(uint2 (&output_indices)[items_per_thread],
653 bool (&is_selected)[2][items_per_thread])
656 for(
unsigned int i = 0; i < items_per_thread; i++)
658 output_indices[i].x = is_selected[0][i] ? 1 : 0;
659 output_indices[i].y = is_selected[1][i] ? 1 : 0;
// store_selected_count, single-counter overload: writes to selected_count[0] the
// sum of the previously accumulated count, the prefix of this block and the
// number selected in this block.
663 template<
class OffsetT>
664 ROCPRIM_DEVICE ROCPRIM_INLINE
void store_selected_count(
size_t* selected_count,
665 size_t (&prev_selected_count_values)[1],
666 const OffsetT selected_prefix,
667 const OffsetT selected_in_block)
669 selected_count[0] = prev_selected_count_values[0] + selected_prefix + selected_in_block;
// store_selected_count, two-counter overload: the .x components update
// selected_count[0] and the .y components update selected_count[1], each as
// previous count + block prefix + count selected in this block.
672 ROCPRIM_DEVICE ROCPRIM_INLINE
void store_selected_count(
size_t* selected_count,
673 size_t (&prev_selected_count_values)[2],
674 const uint2 selected_prefix,
675 const uint2 selected_in_block)
677 selected_count[0] = prev_selected_count_values[0] + selected_prefix.x + selected_in_block.x;
678 selected_count[1] = prev_selected_count_values[1] + selected_prefix.y + selected_in_block.y;
// load_selected_count: copies the first Size entries of prev_selected_count into
// the local array loaded_values. prev_selected_count is dereferenced
// unconditionally, so it must point to at least Size readable elements.
681 template<
unsigned int Size>
682 ROCPRIM_DEVICE
void load_selected_count(
const size_t*
const prev_selected_count,
683 size_t (&loaded_values)[Size])
685 for(
unsigned int i = 0; i < Size; ++i)
687 loaded_values[i] = prev_selected_count[i];
// partition_kernel_impl: per-block body of the partition/select kernel.
// Visible stages, in order:
//   1. load this block's keys,
//   2. compute selection flags via partition_block_load_flags,
//   3. convert flags to 0/1 offsets and scan them (block 0 scans directly and
//      publishes its total; other blocks scan with a prefix op built on
//      OffsetLookbackScanState),
//   4. scatter keys with partition_scatter,
//   5. if the value type is not ::rocprim::empty_type, load and scatter values,
//   6. thread 0 of the last block of the launch stores the selected count(s).
// The number of predicates (sizeof...(UnaryPredicates)) determines the shape of
// is_selected and the number of previous-count entries.
// NOTE(review): this listing omits many lines (several template parameters, the
// `flags`/`storage` declarations, most call arguments, braces and any barriers
// between uses of the shared-memory union); only visible lines are documented.
691 template<select_method SelectMethod,
697 class OutputKeyIterator,
698 class OutputValueIterator,
700 class OffsetLookbackScanState,
701 class... UnaryPredicates>
702 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void 703 partition_kernel_impl(KeyIterator keys_input,
704 ValueIterator values_input,
706 OutputKeyIterator keys_output,
707 OutputValueIterator values_output,
708 size_t* selected_count,
709 size_t* prev_selected_count,
710 size_t prev_processed,
711 const size_t total_size,
712 InequalityOp inequality_op,
713 OffsetLookbackScanState offset_scan_state,
714 const unsigned int number_of_blocks,
715 UnaryPredicates... predicates)
// Compile-time tile geometry taken from Config.
717 constexpr
auto block_size = Config::block_size;
718 constexpr
auto items_per_thread = Config::items_per_thread;
719 constexpr
unsigned int items_per_block = block_size * items_per_thread;
721 using offset_type =
typename OffsetLookbackScanState::value_type;
722 using key_type =
typename std::iterator_traits<KeyIterator>::value_type;
723 using value_type =
typename std::iterator_traits<ValueIterator>::value_type;
// Block-level primitives, each configured through the matching Config member.
726 using block_load_key_type = ::rocprim::block_load<
728 Config::key_block_load_method
730 using block_load_value_type = ::rocprim::block_load<
732 Config::value_block_load_method
734 using block_load_flag_type = ::rocprim::block_load<
736 Config::flag_block_load_method
738 using block_scan_offset_type = ::rocprim::block_scan<
740 Config::block_scan_method
742 using block_discontinuity_key_type = ::rocprim::block_discontinuity<key_type, block_size>;
746 offset_type, OffsetLookbackScanState
750 using exchange_keys_storage_type = key_type[items_per_block];
752 using exchange_values_storage_type = value_type[items_per_block];
// One row of flags for a single predicate, otherwise one row per predicate.
755 using is_selected_type = std::conditional_t<
756 sizeof...(UnaryPredicates) == 1,
757 bool[items_per_thread],
758 bool[
sizeof...(UnaryPredicates)][items_per_thread]>;
// All temporary storages share one union, so their lifetimes overlap in memory.
760 ROCPRIM_SHARED_MEMORY
union 762 raw_exchange_keys_storage_type exchange_keys;
763 raw_exchange_values_storage_type exchange_values;
764 typename block_load_key_type::storage_type load_keys;
765 typename block_load_value_type::storage_type load_values;
766 typename block_load_flag_type::storage_type load_flags;
767 typename block_discontinuity_key_type::storage_type discontinuity_values;
768 typename block_scan_offset_type::storage_type scan_offsets;
// Counts accumulated before this launch, one entry per predicate.
771 size_t prev_selected_count_values[
sizeof...(UnaryPredicates)]{};
772 load_selected_count(prev_selected_count, prev_selected_count_values);
774 const auto flat_block_thread_id = ::rocprim::detail::block_thread_id<0>();
775 const auto flat_block_id = ::rocprim::detail::block_id<0>();
776 const auto block_offset = flat_block_id * items_per_block;
// Number of items left for the final block of the whole input. The size_t
// expression is narrowed to unsigned int; in the visible code the value is only
// consumed on paths guarded by is_global_last_block.
777 const unsigned int valid_in_global_last_block
778 = total_size - prev_processed - items_per_block * (number_of_blocks - 1);
// This launch is the last one if it reaches the end of the input.
779 const bool is_last_launch = total_size <= prev_processed + number_of_blocks * items_per_block;
780 const bool is_global_last_block = is_last_launch && flat_block_id == (number_of_blocks - 1);
782 key_type keys[items_per_thread];
783 is_selected_type is_selected;
784 offset_type output_indices[items_per_thread];
// Stage 1: load keys; the global last block passes valid_in_global_last_block.
787 if(is_global_last_block)
789 block_load_key_type().load(keys_input + block_offset,
791 valid_in_global_last_block,
796 block_load_key_type()
798 keys_input + block_offset,
// Stage 2: selection flags. The first argument points one element before this
// block's first key; is_first_block is true only for block 0 of the first launch.
808 const bool is_first_block = flat_block_id == 0 && prev_processed == 0;
809 partition_block_load_flags<SelectMethod,
811 block_load_flag_type,
812 block_discontinuity_key_type>(keys_input + block_offset - 1,
813 flags + block_offset,
821 is_global_last_block,
822 valid_in_global_last_block);
// Stage 3: flags -> 0/1 offsets, then prefix scan with ::rocprim::plus.
825 convert_selected_to_indices(output_indices, is_selected);
828 offset_type selected_prefix{};
830 offset_type selected_in_block{};
// Block 0 has no predecessor: it scans locally and thread 0 marks its total as
// complete in the look-back state.
833 if(flat_block_id == 0)
835 block_scan_offset_type()
841 storage.scan_offsets,
842 ::rocprim::plus<offset_type>()
844 if(flat_block_thread_id == 0)
846 offset_scan_state.set_complete(flat_block_id, selected_in_block);
// Other blocks scan with a prefix op; afterwards the block total and the prefix
// from preceding blocks are read back from it.
852 ROCPRIM_SHARED_MEMORY
typename offset_scan_prefix_op_type::storage_type storage_prefix_op;
853 auto prefix_op = offset_scan_prefix_op_type(
858 block_scan_offset_type()
862 storage.scan_offsets,
864 ::rocprim::plus<offset_type>()
868 selected_in_block = prefix_op.get_reduction();
869 selected_prefix = prefix_op.get_prefix();
// Stage 4: scatter keys using the exchange_keys member of the shared union.
873 partition_scatter<OnlySelected, block_size>(keys,
880 storage.exchange_keys,
883 is_global_last_block,
884 valid_in_global_last_block,
885 prev_selected_count_values,
// Stage 5: values are processed only when value_type is not ::rocprim::empty_type.
888 static constexpr
bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;
890 if ROCPRIM_IF_CONSTEXPR (with_values) {
891 value_type values[items_per_thread];
894 if(is_global_last_block)
896 block_load_value_type().load(values_input + block_offset,
898 valid_in_global_last_block,
899 storage.load_values);
903 block_load_value_type()
905 values_input + block_offset,
912 partition_scatter<OnlySelected, block_size>(values,
919 storage.exchange_values,
922 is_global_last_block,
923 valid_in_global_last_block,
924 prev_selected_count_values,
// Stage 6: the last block of THIS launch (independent of is_last_launch) has
// thread 0 publish the running selected count(s).
929 const bool is_last_block = flat_block_id == (number_of_blocks - 1);
930 if(is_last_block && flat_block_thread_id == 0)
932 store_selected_count(selected_count,
933 prev_selected_count_values,
941 END_ROCPRIM_NAMESPACE
943 #endif // ROCPRIM_DEVICE_DETAIL_DEVICE_PARTITION_HPP_ ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int flat_block_thread_id()
Returns flat (linear, 1D) thread identifier in a multidimensional block (tile).
Definition: thread.hpp:106
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int block_thread_id()
Returns thread identifier in a multidimensional block (tile) by dimension.
Definition: thread.hpp:248
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int flat_block_id()
Returns flat (linear, 1D) block identifier in a multidimensional grid.
Definition: thread.hpp:178
Fixed-size collection of heterogeneous values.
Definition: tuple.hpp:41
Definition: various.hpp:180
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int block_size()
Returns block size in a multidimensional grid by dimension.
Definition: thread.hpp:268
Definition: device_partition.hpp:154
Definition: lookback_scan_state.hpp:515
hipError_t unique(void *temporary_storage, size_t &storage_size, InputIterator input, OutputIterator output, UniqueCountOutputIterator unique_count_output, const size_t size, EqualityOp equality_op=EqualityOp(), const hipStream_t stream=0, const bool debug_synchronous=false)
Device-level parallel unique primitive.
Definition: device_select.hpp:383