21 #ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_REDUCE_THEN_SCAN_HPP_ 22 #define ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_REDUCE_THEN_SCAN_HPP_ 24 #include <type_traits> 26 #include "../../config.hpp" 27 #include "../../detail/various.hpp" 29 #include "../../intrinsics.hpp" 30 #include "../../functional.hpp" 32 #include "../../warp/warp_scan.hpp" 34 BEGIN_ROCPRIM_NAMESPACE
41 unsigned int BlockSizeX,
42 unsigned int BlockSizeY,
43 unsigned int BlockSizeZ
// Product of the three block-dimension template parameters.
47 static constexpr
unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
49 static constexpr
unsigned int thread_reduction_size_ =
54 static constexpr
unsigned int warp_size_ =
56 using warp_scan_prefix_type = ::rocprim::detail::warp_scan_crosslane<T, warp_size_>;
// Bank count as reported by ::rocprim::detail::get_lds_banks_no().
59 static constexpr
unsigned int banks_no_ = ::rocprim::detail::get_lds_banks_no();
// True only when thread_reduction_size_ is a power of two and greater than 1.
// Presumably these are the segment sizes for which consecutive threads would
// hit the same bank - TODO confirm against the hardware documentation.
60 static constexpr
bool has_bank_conflicts_ =
61 ::rocprim::detail::is_power_of_two(thread_reduction_size_) && thread_reduction_size_ > 1;
// Extra array elements reserved when padding is enabled: one per banks_no_
// logical elements, matching the n / banks_no_ offset that index() adds.
// Zero when has_bank_conflicts_ is false.
62 static constexpr
unsigned int bank_conflicts_padding =
63 has_bank_conflicts_ ? (warp_size_ * thread_reduction_size_ / banks_no_) : 0;
// Scratch array: warp_size_ * thread_reduction_size_ logical slots plus the
// padding computed above. Logical positions are translated with index().
67 T threads[warp_size_ * thread_reduction_size_ + bank_conflicts_padding];
73 template<
class BinaryFunction>
74 ROCPRIM_DEVICE ROCPRIM_INLINE
78 BinaryFunction scan_op)
80 const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
81 this->inclusive_scan_impl(flat_tid, input, output, storage, scan_op);
84 template<
class BinaryFunction>
85 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
88 BinaryFunction scan_op)
94 template<
class BinaryFunction>
95 ROCPRIM_DEVICE ROCPRIM_INLINE
100 BinaryFunction scan_op)
102 storage_type_& storage_ = storage.get();
104 reduction = storage_.threads[index(BlockSize - 1)];
107 template<
class BinaryFunction>
108 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
112 BinaryFunction scan_op)
118 template<
class PrefixCallback,
class BinaryFunction>
119 ROCPRIM_DEVICE ROCPRIM_INLINE
123 PrefixCallback& prefix_callback_op,
124 BinaryFunction scan_op)
126 const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
128 storage_type_& storage_ = storage.get();
129 this->inclusive_scan_impl(flat_tid, input, output, storage, scan_op);
131 T block_prefix = this->get_block_prefix(
133 storage_.threads[index(BlockSize - 1)],
134 prefix_callback_op, storage
136 output = scan_op(block_prefix, output);
139 template<
unsigned int ItemsPerThread,
class BinaryFunction>
140 ROCPRIM_DEVICE ROCPRIM_INLINE
142 T (&output)[ItemsPerThread],
144 BinaryFunction scan_op)
147 T thread_input = input[0];
149 for(
unsigned int i = 1; i < ItemsPerThread; i++)
151 thread_input = scan_op(thread_input, input[i]);
155 const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
156 this->exclusive_scan_impl(
158 thread_input, thread_input,
164 output[0] = input[0];
165 if(flat_tid != 0) output[0] = scan_op(thread_input, input[0]);
168 for(
unsigned int i = 1; i < ItemsPerThread; i++)
170 output[i] = scan_op(output[i-1], input[i]);
174 template<
unsigned int ItemsPerThread,
class BinaryFunction>
175 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
177 T (&output)[ItemsPerThread],
178 BinaryFunction scan_op)
184 template<
unsigned int ItemsPerThread,
class BinaryFunction>
185 ROCPRIM_DEVICE ROCPRIM_INLINE
187 T (&output)[ItemsPerThread],
190 BinaryFunction scan_op)
192 storage_type_& storage_ = storage.get();
195 reduction = storage_.threads[index(BlockSize - 1)];
198 template<
unsigned int ItemsPerThread,
class BinaryFunction>
199 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
201 T (&output)[ItemsPerThread],
203 BinaryFunction scan_op)
210 class PrefixCallback,
211 unsigned int ItemsPerThread,
214 ROCPRIM_DEVICE ROCPRIM_INLINE
216 T (&output)[ItemsPerThread],
218 PrefixCallback& prefix_callback_op,
219 BinaryFunction scan_op)
221 storage_type_& storage_ = storage.get();
223 T thread_input = input[0];
225 for(
unsigned int i = 1; i < ItemsPerThread; i++)
227 thread_input = scan_op(thread_input, input[i]);
231 const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
232 this->exclusive_scan_impl(
234 thread_input, thread_input,
240 T block_prefix = this->get_block_prefix(
242 storage_.threads[index(BlockSize - 1)],
243 prefix_callback_op, storage
247 output[0] = input[0];
248 if(flat_tid != 0) output[0] = scan_op(thread_input, input[0]);
250 output[0] = scan_op(block_prefix, output[0]);
253 for(
unsigned int i = 1; i < ItemsPerThread; i++)
255 output[i] = scan_op(output[i-1], input[i]);
259 template<
class BinaryFunction>
260 ROCPRIM_DEVICE ROCPRIM_INLINE
265 BinaryFunction scan_op)
267 const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
268 this->exclusive_scan_impl(flat_tid, input, output, init, storage, scan_op);
271 template<
class BinaryFunction>
272 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
276 BinaryFunction scan_op)
282 template<
class BinaryFunction>
283 ROCPRIM_DEVICE ROCPRIM_INLINE
289 BinaryFunction scan_op)
291 const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
292 storage_type_& storage_ = storage.get();
293 this->exclusive_scan_impl(
294 flat_tid, input, output, init, storage, scan_op
297 reduction = storage_.threads[index(BlockSize - 1)];
300 template<
class BinaryFunction>
301 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
306 BinaryFunction scan_op)
309 this->
exclusive_scan(input, output, init, reduction, storage, scan_op);
312 template<
class PrefixCallback,
class BinaryFunction>
313 ROCPRIM_DEVICE ROCPRIM_INLINE
317 PrefixCallback& prefix_callback_op,
318 BinaryFunction scan_op)
320 const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
322 storage_type_& storage_ = storage.get();
323 this->exclusive_scan_impl(
324 flat_tid, input, output, storage, scan_op
327 T reduction = storage_.threads[index(BlockSize - 1)];
329 T block_prefix = this->get_block_prefix(
330 flat_tid, warp_id, reduction,
331 prefix_callback_op, storage
333 output = scan_op(block_prefix, output);
334 if(flat_tid == 0) output = block_prefix;
337 template<
unsigned int ItemsPerThread,
class BinaryFunction>
338 ROCPRIM_DEVICE ROCPRIM_INLINE
340 T (&output)[ItemsPerThread],
343 BinaryFunction scan_op)
346 T thread_input = input[0];
348 for(
unsigned int i = 1; i < ItemsPerThread; i++)
350 thread_input = scan_op(thread_input, input[i]);
354 const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
355 this->exclusive_scan_impl(
357 thread_input, thread_input,
368 exclusive = thread_input;
370 output[0] = exclusive;
372 for(
unsigned int i = 1; i < ItemsPerThread; i++)
374 exclusive = scan_op(exclusive, prev);
376 output[i] = exclusive;
380 template<
unsigned int ItemsPerThread,
class BinaryFunction>
381 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
383 T (&output)[ItemsPerThread],
385 BinaryFunction scan_op)
391 template<
unsigned int ItemsPerThread,
class BinaryFunction>
392 ROCPRIM_DEVICE ROCPRIM_INLINE
394 T (&output)[ItemsPerThread],
398 BinaryFunction scan_op)
400 storage_type_& storage_ = storage.get();
403 reduction = storage_.threads[index(BlockSize - 1)];
406 template<
unsigned int ItemsPerThread,
class BinaryFunction>
407 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
409 T (&output)[ItemsPerThread],
412 BinaryFunction scan_op)
415 this->
exclusive_scan(input, output, init, reduction, storage, scan_op);
419 class PrefixCallback,
420 unsigned int ItemsPerThread,
423 ROCPRIM_DEVICE ROCPRIM_INLINE
425 T (&output)[ItemsPerThread],
427 PrefixCallback& prefix_callback_op,
428 BinaryFunction scan_op)
430 storage_type_& storage_ = storage.get();
432 T thread_input = input[0];
434 for(
unsigned int i = 1; i < ItemsPerThread; i++)
436 thread_input = scan_op(thread_input, input[i]);
440 const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
441 this->exclusive_scan_impl(
443 thread_input, thread_input,
449 T block_prefix = this->get_block_prefix(
451 storage_.threads[index(BlockSize - 1)],
452 prefix_callback_op, storage
457 T exclusive = block_prefix;
460 exclusive = scan_op(block_prefix, thread_input);
462 output[0] = exclusive;
464 for(
unsigned int i = 1; i < ItemsPerThread; i++)
466 exclusive = scan_op(exclusive, prev);
468 output[i] = exclusive;
// Inclusive scan helper: delegates the whole computation to
// inclusive_scan_base() and then copies this thread's result out of the
// scratch array.
//   flat_tid - flat thread id used to select this thread's slot
//   scan_op  - binary operator forwarded to inclusive_scan_base()
// NOTE(review): the remaining parameters (input, output, storage) are used
// below, but their declarations are not visible in this listing.
477 template<
class BinaryFunction>
478 ROCPRIM_DEVICE ROCPRIM_INLINE
479 void inclusive_scan_impl(
const unsigned int flat_tid,
483 BinaryFunction scan_op)
485 storage_type_& storage_ = storage.get();
// After this call the per-thread results are read back from storage_.threads.
488 this->inclusive_scan_base(flat_tid, input, storage, scan_op);
// index() maps the logical slot flat_tid to its (possibly padded) position.
489 output = storage_.threads[index(flat_tid)];
// Core of the scan. Visible steps:
//   1. every thread stores its input at index(flat_tid);
//   2. threads with flat_tid < warp_size_ each fold a segment of
//      thread_reduction_size_ stored values into thread_reduction;
//   3. those partial results go through warp_scan_prefix_type's inclusive_scan;
//   4. each of those threads rewrites its segment with running scan values.
// NOTE(review): this listing omits lines (braces, and whatever sits between
// the numbered lines), so any synchronization or guards between the steps are
// not visible here - confirm against the full source.
494 template<
class BinaryFunction>
495 ROCPRIM_DEVICE ROCPRIM_INLINE
496 void inclusive_scan_base(
const unsigned int flat_tid,
499 BinaryFunction scan_op)
501 storage_type_& storage_ = storage.get();
// Step 1: publish this thread's value at its padded slot.
502 storage_.threads[index(flat_tid)] = input;
// Steps 2-4 are executed only by the first warp_size_ threads.
504 if(flat_tid < warp_size_)
// Segment bounds. Only the segment start goes through index(); elements inside
// the segment are addressed as idx_start + k.
// NOTE(review): that agrees with index() only if a segment never straddles a
// multiple of banks_no_ - presumably guaranteed by the power-of-two condition
// in has_bank_conflicts_; confirm.
506 const unsigned int idx_start = index(flat_tid * thread_reduction_size_);
507 const unsigned int idx_end = idx_start + thread_reduction_size_;
// Step 2: fold the segment left to right with scan_op.
509 T thread_reduction = storage_.threads[idx_start];
511 for(
unsigned int i = idx_start + 1; i < idx_end; i++)
513 thread_reduction = scan_op(
514 thread_reduction, storage_.threads[i]
// Step 3: inclusive scan of the per-segment reductions, computed in place.
519 warp_scan_prefix_type().inclusive_scan(thread_reduction, thread_reduction, scan_op);
// NOTE(review): the statements between listing lines 519 and 523, and the
// condition that selects between lines 523 and 526, are not visible here.
// Line 523 combines thread_reduction with the segment's first stored element;
// line 526 instead uses the thread's own input unchanged.
523 thread_reduction = scan_op(thread_reduction, storage_.threads[idx_start]);
526 thread_reduction = input;
// Step 4: write the running scan back over the segment, element by element.
529 storage_.threads[idx_start] = thread_reduction;
531 for(
unsigned int i = idx_start + 1; i < idx_end; i++)
533 thread_reduction = scan_op(
534 thread_reduction, storage_.threads[i]
536 storage_.threads[i] = thread_reduction;
// Exclusive scan helper, overload whose body uses an initial value `init`.
// Runs inclusive_scan_base() and then, for every thread except flat_tid == 0,
// sets output to scan_op(init, <value stored for thread flat_tid - 1>).
// NOTE(review): what thread 0 receives is decided on a line that is not
// visible in this listing (between lines 553 and 555) - confirm.
542 template<
class BinaryFunction>
543 ROCPRIM_DEVICE ROCPRIM_INLINE
544 void exclusive_scan_impl(
const unsigned int flat_tid,
549 BinaryFunction scan_op)
551 storage_type_& storage_ = storage.get();
553 this->inclusive_scan_base(flat_tid, input, storage, scan_op);
// The flat_tid != 0 guard keeps index(flat_tid-1) from wrapping around
// for thread 0.
555 if(flat_tid != 0) output = scan_op(init, storage_.threads[index(flat_tid-1)]);
// Exclusive scan helper, overload without an initial value.
// Runs inclusive_scan_base() and then reads the value stored for the
// previous thread (logical slot flat_tid - 1) into output.
// NOTE(review): flat_tid is unsigned, so flat_tid - 1 wraps for thread 0; any
// guard for that case would sit on the lines missing between 568 and 571 in
// this listing - confirm against the full source.
558 template<
class BinaryFunction>
559 ROCPRIM_DEVICE ROCPRIM_INLINE
560 void exclusive_scan_impl(
const unsigned int flat_tid,
564 BinaryFunction scan_op)
566 storage_type_& storage_ = storage.get();
568 this->inclusive_scan_base(flat_tid, input, storage, scan_op);
571 output = storage_.threads[index(flat_tid-1)];
// Obtains the block prefix via get_block_prefix() and combines it with `input`
// using scan_op, with the prefix as the left operand; the result is written
// to `output`.
//   flat_tid           - forwarded to get_block_prefix()
//   prefix_callback_op - callback forwarded to get_block_prefix()
//   scan_op            - binary operator used for the final combination
// NOTE(review): warp_id, reduction, input, output and storage are used below,
// but their declarations are not visible in this listing.
576 template<
class PrefixCallback,
class BinaryFunction>
577 ROCPRIM_DEVICE ROCPRIM_INLINE
578 void include_block_prefix(
const unsigned int flat_tid,
583 PrefixCallback& prefix_callback_op,
585 BinaryFunction scan_op)
587 T block_prefix = this->get_block_prefix(
588 flat_tid, warp_id, reduction,
589 prefix_callback_op, storage
591 output = scan_op(block_prefix, input);
// Returns the block prefix produced by prefix_callback_op(reduction).
// The callback result is written to storage_.threads[0] and the function
// returns whatever is read back from that slot.
// NOTE(review): the conditions selecting which thread invokes the callback and
// writes slot 0, and any synchronization before the final read, are on lines
// that are not visible in this listing - presumably a single thread publishes
// the value for the whole block; confirm against the full source.
595 template<
class PrefixCallback>
596 ROCPRIM_DEVICE ROCPRIM_INLINE
597 T get_block_prefix(
const unsigned int flat_tid,
598 const unsigned int warp_id,
600 PrefixCallback& prefix_callback_op,
603 storage_type_& storage_ = storage.get();
606 T block_prefix = prefix_callback_op(reduction);
// Slot 0 of the scratch array is reused to hand the prefix to the readers.
611 storage_.threads[0] = block_prefix;
615 return storage_.threads[0];
// Maps a logical element position n to its position in storage_.threads.
// When has_bank_conflicts_ is true, one extra slot is skipped for every
// banks_no_ logical elements (n + n / banks_no_, integer division);
// otherwise n is returned unchanged.
619 ROCPRIM_DEVICE ROCPRIM_INLINE
620 unsigned int index(
unsigned int n)
const 623 return has_bank_conflicts_ ? (n + (n/banks_no_)) : n;
629 END_ROCPRIM_NAMESPACE
631 #endif // ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_REDUCE_THEN_SCAN_HPP_ Definition: benchmark_block_scan.cpp:63
ROCPRIM_DEVICE ROCPRIM_INLINE constexpr unsigned int device_warp_size()
Returns the number of threads in a hardware warp for the current target.
Definition: thread.hpp:70
Definition: block_scan_reduce_then_scan.hpp:45
ROCPRIM_DEVICE ROCPRIM_INLINE T warp_shuffle_up(const T &input, const unsigned int delta, const int width=device_warp_size())
Shuffle up for any data type.
Definition: warp_shuffle.hpp:197
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
const unsigned int warp_id
Returns the warp id within a block (tile).
Definition: benchmark_warp_exchange.cpp:153
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronizes all threads in a block (tile).
Definition: thread.hpp:216
Definition: benchmark_block_scan.cpp:100