21 #ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_BY_KEY_HPP_ 22 #define ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_BY_KEY_HPP_ 24 #include "device_scan_common.hpp" 25 #include "lookback_scan_state.hpp" 27 #include "../../block/block_discontinuity.hpp" 28 #include "../../block/block_load.hpp" 29 #include "../../block/block_scan.hpp" 30 #include "../../block/block_store.hpp" 31 #include "../../config.hpp" 32 #include "../../detail/binary_op_wrappers.hpp" 33 #include "../../intrinsics/thread.hpp" 34 #include "../../types/tuple.hpp" 36 #include <type_traits> 38 BEGIN_ROCPRIM_NAMESPACE
// NOTE(review): partial extract -- the embedded original line numbers skip, so
// several lines of this definition (struct header, some parameters, the
// tile_successor/tile_predecessor fallbacks, the flag_heads/flag_tails calls and
// the load arguments) are not present in this listing.
//
// Helper that loads one block's keys and values and produces, for every item
// owned by the calling thread, a tuple (value, segment-boundary flag) in
// wrapped_values.
42 template <
bool Exclusive,
44 unsigned int items_per_thread,
52 = ::rocprim::block_load<key_type, block_size, items_per_thread, load_keys_method>;
// NOTE(review): the values loader below is instantiated with load_keys_method,
// not a separate values method. device_scan_by_key_kernel_impl defines
// load_values_method = load_keys_method, so this is currently equivalent.
// Confirm whether a dedicated load_values_method template parameter was intended.
56 using block_load_values
57 = ::rocprim::block_load<result_type, block_size, items_per_thread, load_keys_method>;
// Scratch storage of the key loader and of the value loader.
61 typename block_load_keys::storage_type load;
64 typename block_load_values::storage_type load_values;
// Loads keys/values for this block and fills wrapped_values.
//   keys_input / values_input : start of the key and value sequences; this
//                               block's items begin at flat_block_id * items_per_block.
//   compare                   : key comparison functor; adjacent keys start a new
//                               segment when compare(a, b) is false.
//   initial_value             : in exclusive mode, replaces the value at every
//                               flagged position.
//   starting_block /
//   number_of_blocks          : used to detect the last block and whether a
//                               neighbouring tile key exists.
//   flat_thread_id            : used to bound items in the last (partial) block.
//   wrapped_values            : output, one (value, flag) tuple per item.
// NOTE(review): flat_block_id, size and storage are used below but their
// declarations are among the missing lines (presumably parameters) -- confirm.
74 template <
typename KeyIterator,
typename ValueIterator,
typename CompareFunction>
76 load(KeyIterator keys_input,
77 ValueIterator values_input,
78 CompareFunction compare,
79 const result_type initial_value,
81 const size_t starting_block,
82 const size_t number_of_blocks,
83 const unsigned int flat_thread_id,
85 rocprim::tuple<result_type, bool> (&wrapped_values)[items_per_thread],
88 constexpr
static unsigned int items_per_block = items_per_thread *
block_size;
// Offset this block's views of the inputs.
89 const unsigned int block_offset = flat_block_id * items_per_block;
90 KeyIterator block_keys = keys_input + block_offset;
91 ValueIterator block_values = values_input + block_offset;
93 key_type keys[items_per_thread];
94 result_type values[items_per_thread];
95 bool flags[items_per_thread];
// Boundary predicate (referred to as not_equal below): the negation of compare.
98 = [compare](
const auto& a,
const auto& b)
mutable {
return !compare(a, b); };
// Fills flags[] using a key from the neighbouring tile: the first key of the
// next block (tile_successor) or the last key of the previous block
// (tile_predecessor). Which variant runs is selected in lines not shown here
// (presumably by Exclusive -- TODO confirm).
100 const auto flag_segment_boundaries = [&]() {
// Reading block_keys[items_per_block] (one past this block) is only done when
// this is not the last block, so the element exists.
103 const key_type tile_successor
104 = starting_block + flat_block_id < number_of_blocks - 1
105 ? block_keys[items_per_block]
108 flags, tile_successor, keys, not_equal, storage.keys.flag);
112 const key_type tile_predecessor = starting_block + flat_block_id > 0
116 flags, tile_predecessor, keys, not_equal, storage.keys.flag);
// Full block: every item is in range.
120 if(starting_block + flat_block_id < number_of_blocks - 1)
122 block_load_keys{}.load(
128 flag_segment_boundaries();
132 block_load_values{}.load(
// Pack each value with its flag; exclusive mode substitutes initial_value at
// flagged positions.
139 for(
unsigned int i = 0; i < items_per_thread; ++i) {
140 rocprim::get<0>(wrapped_values[i])
141 = (Exclusive && flags[i]) ? initial_value : values[i];
142 rocprim::get<1>(wrapped_values[i]) = flags[i];
// Last block: only the first valid_in_last_block items of the block are valid.
147 const unsigned int valid_in_last_block
148 =
static_cast<unsigned int>(size - items_per_block * (number_of_blocks - 1));
150 block_load_keys {}.load(
157 flag_segment_boundaries();
161 block_load_values{}.load(
169 for(
unsigned int i = 0; i < items_per_thread; ++i) {
// Items at or beyond valid_in_last_block are out of range; their handling is
// in lines not shown here.
170 if(flat_thread_id * items_per_thread + i >= valid_in_last_block) {
174 rocprim::get<0>(wrapped_values[i])
175 = (Exclusive && flags[i]) ? initial_value : values[i];
176 rocprim::get<1>(wrapped_values[i]) = flags[i];
// NOTE(review): partial extract -- several lines of this definition (struct
// header, some parameters, loop bodies) are missing from this listing.
//
// Helper that drops the segment flag from each (value, flag) tuple and stores
// the resulting values for one block.
183 unsigned int items_per_thread,
184 typename result_type,
188 using block_store_values
189 = ::rocprim::block_store<result_type, block_size, items_per_thread, store_method>;
191 using storage_type =
typename block_store_values::storage_type;
// Stores get<0>() of every wrapped value to output.
//   output                     : start of the output sequence; this block writes at
//                                flat_block_id * items_per_block.
//   starting_block /
//   number_of_blocks           : used to detect the last (possibly partial) block.
//   flat_thread_id             : used to bound items in the last block.
//   wrapped_values             : scanned (value, flag) tuples for this thread.
//   storage                    : scratch storage of the block store.
// NOTE(review): flat_block_id and size are used below but declared in missing
// lines (presumably parameters) -- confirm.
193 template <
typename OutputIterator>
195 store(OutputIterator output,
197 const size_t starting_block,
198 const size_t number_of_blocks,
199 const unsigned int flat_thread_id,
201 const rocprim::tuple<result_type, bool> (&wrapped_values)[items_per_thread],
202 storage_type& storage)
204 constexpr
static unsigned int items_per_block = items_per_thread *
block_size;
205 const unsigned int block_offset = flat_block_id * items_per_block;
206 OutputIterator block_output = output + block_offset;
208 result_type thread_values[items_per_thread];
// Full block: unwrap every item and store the whole block.
210 if(starting_block + flat_block_id < number_of_blocks - 1)
213 for(
unsigned int i = 0; i < items_per_thread; ++i) {
214 thread_values[i] = rocprim::get<0>(wrapped_values[i]);
220 block_store_values {}.store(block_output, thread_values, storage);
// Last block: only valid_in_last_block items are written.
224 const unsigned int valid_in_last_block
225 =
static_cast<unsigned int>(size - items_per_block * (number_of_blocks - 1));
228 for(
unsigned int i = 0; i < items_per_thread; ++i) {
// Out-of-range items are handled in lines not shown here.
229 if(flat_thread_id * items_per_thread + i >= valid_in_last_block) {
233 thread_values[i] = rocprim::get<0>(wrapped_values[i]);
// Bounded store: passes valid_in_last_block as the number of valid items.
239 block_store_values {}.store(
240 block_output, thread_values, valid_in_last_block, storage);
// NOTE(review): partial extract -- several lines of this function (template
// parameters such as ResultType/params, the load/scan/store argument lists,
// the definitions of load_flagged, store_unwrap, wrapped_op and flat_block_id,
// and some braces) are missing from this listing.
//
// Device-side body of a segmented (by-key) scan. Each block:
//   1. loads its keys/values as (value, segment flag) tuples,
//   2. scans them with a look-back block scan over LookbackScanState,
//   3. stores the unwrapped values to output.
// Exclusive selects exclusive vs inclusive scan.
// previous_last_value may be nullptr. When non-null it is used by block 0;
// presumably it carries the last value of a previous launch -- TODO confirm
// with the caller.
245 template<
bool Exclusive,
247 typename KeyInputIterator,
248 typename InputIterator,
249 typename OutputIterator,
251 typename CompareFunction,
252 typename BinaryFunction,
253 typename LookbackScanState>
254 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
void device_scan_by_key_kernel_impl(
255 KeyInputIterator keys,
256 InputIterator values,
257 OutputIterator output,
258 ResultType initial_value,
259 const CompareFunction compare,
260 const BinaryFunction scan_op,
261 LookbackScanState scan_state,
263 const size_t starting_block,
264 const size_t number_of_blocks,
265 const rocprim::tuple<ResultType, bool>*
const previous_last_value)
267 using result_type = ResultType;
// The look-back state must hold the same (value, flag) tuples that are scanned.
268 static_assert(std::is_same<rocprim::tuple<ResultType, bool>,
269 typename LookbackScanState::value_type>::value,
270 "value_type of LookbackScanState must be tuple of result type and flag");
// Compile-time launch parameters.
273 constexpr
auto block_size = params.kernel_config.block_size;
274 constexpr
auto items_per_thread = params.kernel_config.items_per_thread;
275 constexpr
auto load_keys_method = params.block_load_method;
// Values use the same load method as keys.
276 constexpr
auto load_values_method = load_keys_method;
278 using key_type =
typename std::iterator_traits<KeyInputIterator>::value_type;
// Scanned element: (value, segment flag).
288 using wrapped_type = rocprim::tuple<result_type, bool>;
290 using block_scan_type
291 = ::rocprim::block_scan<wrapped_type,
block_size, params.block_scan_method>;
293 constexpr
auto store_method = params.block_store_method;
// Load, scan and store storage share one ROCPRIM_SHARED_MEMORY union.
296 ROCPRIM_SHARED_MEMORY
union 298 typename load_flagged::storage_type load;
299 typename block_scan_type::storage_type scan;
300 typename store_unwrap::storage_type store;
303 const auto flat_thread_id = ::rocprim::detail::block_thread_id<0>();
307 wrapped_type wrapped_values[items_per_thread];
308 load_flagged {}.load(keys,
// Block 0 starts from initial_value (flag false) and completes its scan-state
// entry itself.
324 if(flat_block_id == 0)
326 auto wrapped_initial_value = rocprim::make_tuple(initial_value,
false);
// When a previous last value is supplied: one branch (condition in missing
// lines, presumably Exclusive -- TODO confirm) takes its value as the initial
// value. The else-if branch folds it into thread 0's first item with wrapped_op.
330 if(previous_last_value !=
nullptr)
333 rocprim::get<0>(wrapped_initial_value) = rocprim::get<0>(*previous_last_value);
334 }
else if (flat_thread_id == 0) {
335 wrapped_values[0] = wrapped_op(*previous_last_value, wrapped_values[0]);
339 wrapped_type reduction;
340 lookback_block_scan<Exclusive, block_scan_type>(wrapped_values,
341 wrapped_initial_value,
// One thread publishes block 0's reduction to the look-back state.
346 if(flat_thread_id == 0)
348 scan_state.set_complete(flat_block_id, reduction);
// Other blocks: a prefix operator built from wrapped_op and scan_state
// (type parameters visible below) supplies the block's prefix.
354 decltype(wrapped_op),
355 decltype(scan_state)> {
359 lookback_block_scan<Exclusive, block_scan_type>(
// Strip flags and write the scanned values.
368 store_unwrap {}.store(output,
379 END_ROCPRIM_NAMESPACE
381 #endif // ROCPRIM_DEVICE_DETAIL_DEVICE_SCAN_BY_KEY_HPP_ Definition: device_scan_by_key.hpp:59
Definition: device_scan_by_key.hpp:186
The block_discontinuity class is a block level parallel primitive which provides methods for flagging...
Definition: block_discontinuity.hpp:82
block_store_method
block_store_method enumerates the methods available to store a striped arrangement of items into a bl...
Definition: block_store.hpp:41
Definition: lookback_scan_state.hpp:356
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
ROCPRIM_DEVICE ROCPRIM_INLINE void flag_heads(Flag(&head_flags)[ItemsPerThread], const T(&input)[ItemsPerThread], FlagOp flag_op, storage_type &storage)
Tags head_flags that indicate discontinuities between items partitioned across the thread block...
Definition: block_discontinuity.hpp:156
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216
block_load_method
block_load_method enumerates the methods available to load data from continuous memory into a blocked...
Definition: block_load.hpp:41
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int flat_block_id()
Returns flat (linear, 1D) block identifier in a multidimensional grid.
Definition: thread.hpp:178
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int block_size()
Returns block size in a multidimensional grid by dimension.
Definition: thread.hpp:268
Definition: binary_op_wrappers.hpp:72
Definition: test_device_binary_search.cpp:37
Provides the kernel parameters for exclusive_scan_by_key and inclusive_scan_by_key based on autotuned...
Definition: device_config_helper.hpp:393
Definition: device_scan_by_key.hpp:49
ROCPRIM_DEVICE ROCPRIM_INLINE void flag_tails(Flag(&tail_flags)[ItemsPerThread], const T(&input)[ItemsPerThread], FlagOp flag_op, storage_type &storage)
Tags tail_flags that indicate discontinuities between items partitioned across the thread block...
Definition: block_discontinuity.hpp:304