21 #ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_WARP_SCAN_HPP_ 22 #define ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_WARP_SCAN_HPP_ 24 #include <type_traits> 26 #include "../../config.hpp" 27 #include "../../detail/various.hpp" 29 #include "../../intrinsics.hpp" 30 #include "../../functional.hpp" 32 #include "../../warp/warp_scan.hpp" 34 BEGIN_ROCPRIM_NAMESPACE
// Block dimensions along X, Y and Z; used below as compile-time constants.
41 unsigned int BlockSizeX,
42 unsigned int BlockSizeY,
43 unsigned int BlockSizeZ
// Flattened block size: the product of the three dimensions.
47 static constexpr
unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
// Lane count used for the warp-level scan of per-thread values.
// NOTE(review): the initializer is not visible in this listing.
49 static constexpr
unsigned int warp_size_ =
// Number of warp_size_-wide groups needed to cover BlockSize (ceiling division).
52 static constexpr
unsigned int warps_no_ = (BlockSize + warp_size_ - 1) / warp_size_;
// Warp-level scan applied to the per-thread values, warp_size_ lanes wide.
58 using warp_scan_input_type = ::rocprim::detail::warp_scan_crosslane<T, warp_size_>;
// Warp-level scan applied to the per-warp values; its width is
// detail::next_power_of_two(warps_no_).
61 using warp_scan_prefix_type = ::rocprim::detail::warp_scan_crosslane<T, detail::next_power_of_two(warps_no_)>;
// One slot per warp; the scan routines below write and read per-warp results here.
65 T warp_prefixes[warps_no_];
// Entry point that forwards to inclusive_scan_impl, passing this thread's flat
// id within the BlockSizeX x BlockSizeY x BlockSizeZ block together with
// input, output, storage and scan_op.
// NOTE(review): the function name and the parameters other than scan_op are
// not visible in this listing.
82 template<
class BinaryFunction>
83 ROCPRIM_DEVICE ROCPRIM_INLINE
87 BinaryFunction scan_op)
89 this->inclusive_scan_impl(
90 ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
91 input, output, storage, scan_op
95 template<
class BinaryFunction>
96 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
99 BinaryFunction scan_op)
// Variant that additionally fills `reduction`.
// NOTE(review): the scan call itself and any synchronization before the read
// below are not visible in this listing.
105 template<
class BinaryFunction>
106 ROCPRIM_DEVICE ROCPRIM_INLINE
111 BinaryFunction scan_op)
113 storage_type_& storage_ = storage.get();
// `reduction` is read back from the last warp_prefixes slot.
116 reduction = storage_.warp_prefixes[warps_no_ - 1];
119 template<
class BinaryFunction>
120 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
124 BinaryFunction scan_op)
// Scan variant that combines each thread's result with a block-wide prefix
// obtained through prefix_callback_op.
// NOTE(review): the function name, the leading parameters and the first
// arguments of get_block_prefix are not visible in this listing.
130 template<
class PrefixCallback,
class BinaryFunction>
131 ROCPRIM_DEVICE ROCPRIM_INLINE
135 PrefixCallback& prefix_callback_op,
136 BinaryFunction scan_op)
138 const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
140 storage_type_& storage_ = storage.get();
// Step 1: scan input into output.
141 this->inclusive_scan_impl(flat_tid, input, output, storage, scan_op);
// Step 2: obtain the block prefix; the last warp_prefixes slot is handed to
// get_block_prefix together with the callback.
143 T block_prefix = this->get_block_prefix(
145 storage_.warp_prefixes[warps_no_ - 1],
146 prefix_callback_op, storage
// Step 3: prepend the block prefix to this thread's scanned value.
148 output = scan_op(block_prefix, output);
// Array variant: every thread contributes ItemsPerThread items.
// NOTE(review): the function name, the `input`/`storage` parameters and the
// braces/guards are not visible in this listing.
151 template<
unsigned int ItemsPerThread,
class BinaryFunction>
152 ROCPRIM_DEVICE ROCPRIM_INLINE
154 T (&output)[ItemsPerThread],
156 BinaryFunction scan_op)
// Fold this thread's items into one value with scan_op.
159 T thread_input = input[0];
161 for(
unsigned int i = 1; i < ItemsPerThread; i++)
163 thread_input = scan_op(thread_input, input[i]);
167 const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
// thread_input appears twice in the argument list - presumably as both the
// input and the output of the block-level exclusive scan (TODO confirm).
168 this->exclusive_scan_impl(
170 thread_input, thread_input,
// Two alternative values for output[0]: the first item alone, or the first
// item combined with the scanned thread_input.
// NOTE(review): the condition choosing between them is not visible.
176 output[0] = input[0];
179 output[0] = scan_op(thread_input, input[0]);
// Running scan over the thread's remaining items.
184 for(
unsigned int i = 1; i < ItemsPerThread; i++)
186 output[i] = scan_op(output[i-1], input[i]);
190 template<
unsigned int ItemsPerThread,
class BinaryFunction>
191 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
193 T (&output)[ItemsPerThread],
194 BinaryFunction scan_op)
// Array variant that additionally fills `reduction`.
// NOTE(review): the scan call itself and any synchronization before the read
// below are not visible in this listing.
200 template<
unsigned int ItemsPerThread,
class BinaryFunction>
201 ROCPRIM_DEVICE ROCPRIM_INLINE
203 T (&output)[ItemsPerThread],
206 BinaryFunction scan_op)
208 storage_type_& storage_ = storage.get();
// `reduction` is read back from the last warp_prefixes slot.
211 reduction = storage_.warp_prefixes[warps_no_ - 1];
214 template<
unsigned int ItemsPerThread,
class BinaryFunction>
215 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
217 T (&output)[ItemsPerThread],
219 BinaryFunction scan_op)
// Array variant that also applies a block-wide prefix obtained through
// prefix_callback_op.
// NOTE(review): the function name, some parameters, the first arguments of
// get_block_prefix and the guards around the output[0] assignments are not
// visible in this listing.
226 class PrefixCallback,
227 unsigned int ItemsPerThread,
230 ROCPRIM_DEVICE ROCPRIM_INLINE
232 T (&output)[ItemsPerThread],
234 PrefixCallback& prefix_callback_op,
235 BinaryFunction scan_op)
237 storage_type_& storage_ = storage.get();
// Fold this thread's items into one value with scan_op.
239 T thread_input = input[0];
241 for(
unsigned int i = 1; i < ItemsPerThread; i++)
243 thread_input = scan_op(thread_input, input[i]);
247 const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
// thread_input appears twice in the argument list - presumably as both the
// input and the output of the block-level exclusive scan (TODO confirm).
248 this->exclusive_scan_impl(
250 thread_input, thread_input,
// The last warp_prefixes slot is handed to get_block_prefix together with
// the callback.
256 T block_prefix = this->get_block_prefix(
258 storage_.warp_prefixes[warps_no_ - 1],
259 prefix_callback_op, storage
// Two alternative values for output[0]; the selecting condition is not visible.
263 output[0] = input[0];
266 output[0] = scan_op(thread_input, input[0]);
// Prepend the block prefix to the first item's result ...
269 output[0] = scan_op(block_prefix, output[0]);
// ... and carry it through the remaining items via the running scan.
272 for(
unsigned int i = 1; i < ItemsPerThread; i++)
274 output[i] = scan_op(output[i-1], input[i]);
// Entry point that forwards to exclusive_scan_impl, passing this thread's flat
// id within the block together with input, output, init, storage and scan_op.
// NOTE(review): the function name and the parameters other than scan_op are
// not visible in this listing.
278 template<
class BinaryFunction>
279 ROCPRIM_DEVICE ROCPRIM_INLINE
284 BinaryFunction scan_op)
286 this->exclusive_scan_impl(
287 ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
288 input, output, init, storage, scan_op
292 template<
class BinaryFunction>
293 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
297 BinaryFunction scan_op)
301 input, output, init, storage, scan_op
// Variant that takes an initial value and additionally fills `reduction`.
// NOTE(review): the callee receiving the argument list below, and any
// synchronization before `reduction` is read, are not visible in this listing.
305 template<
class BinaryFunction>
306 ROCPRIM_DEVICE ROCPRIM_INLINE
312 BinaryFunction scan_op)
314 storage_type_& storage_ = storage.get();
316 input, output, init, storage, scan_op
// `reduction` is read back from the last warp_prefixes slot.
319 reduction = storage_.warp_prefixes[warps_no_ - 1];
322 template<
class BinaryFunction>
323 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
328 BinaryFunction scan_op)
332 input, output, init, reduction, storage, scan_op
// Scan variant that combines each thread's result with a block-wide prefix
// obtained through prefix_callback_op.
// NOTE(review): the function name, the leading parameters and the first
// arguments of get_block_prefix are not visible in this listing.
336 template<
class PrefixCallback,
class BinaryFunction>
337 ROCPRIM_DEVICE ROCPRIM_INLINE
341 PrefixCallback& prefix_callback_op,
342 BinaryFunction scan_op)
344 const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
346 storage_type_& storage_ = storage.get();
// Step 1: block-level exclusive scan, called without an initial value.
347 this->exclusive_scan_impl(
348 flat_tid, input, output, storage, scan_op
// Step 2: obtain the block prefix; the last warp_prefixes slot is handed to
// get_block_prefix together with the callback.
351 T block_prefix = this->get_block_prefix(
353 storage_.warp_prefixes[warps_no_ - 1],
354 prefix_callback_op, storage
// Step 3: prepend the block prefix to every thread's value; thread 0's output
// is then replaced by the block prefix itself.
356 output = scan_op(block_prefix, output);
357 if(flat_tid == 0) output = block_prefix;
// Array variant: every thread contributes ItemsPerThread items.
// NOTE(review): the function name, several parameters, the declarations of
// `exclusive` and `prev`, and the braces/guards are not visible in this listing.
360 template<
unsigned int ItemsPerThread,
class BinaryFunction>
361 ROCPRIM_DEVICE ROCPRIM_INLINE
363 T (&output)[ItemsPerThread],
366 BinaryFunction scan_op)
// Fold this thread's items into one value with scan_op.
369 T thread_input = input[0];
371 for(
unsigned int i = 1; i < ItemsPerThread; i++)
373 thread_input = scan_op(thread_input, input[i]);
377 const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
// thread_input appears twice in the argument list - presumably as both the
// input and the output of the block-level exclusive scan (TODO confirm).
378 this->exclusive_scan_impl(
380 thread_input, thread_input,
// The scanned thread_input becomes the running value written to output[0].
391 exclusive = thread_input;
393 output[0] = exclusive;
// Each later slot receives the running value extended by `prev`
// (presumably the preceding input item - TODO confirm).
396 for(
unsigned int i = 1; i < ItemsPerThread; i++)
398 exclusive = scan_op(exclusive, prev);
400 output[i] = exclusive;
404 template<
unsigned int ItemsPerThread,
class BinaryFunction>
405 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
407 T (&output)[ItemsPerThread],
409 BinaryFunction scan_op)
// Array variant that additionally fills `reduction`.
// NOTE(review): the scan call itself and any synchronization before the read
// below are not visible in this listing.
415 template<
unsigned int ItemsPerThread,
class BinaryFunction>
416 ROCPRIM_DEVICE ROCPRIM_INLINE
418 T (&output)[ItemsPerThread],
422 BinaryFunction scan_op)
424 storage_type_& storage_ = storage.get();
// `reduction` is read back from the last warp_prefixes slot.
427 reduction = storage_.warp_prefixes[warps_no_ - 1];
// Convenience overload that forwards to the exclusive_scan overload taking
// (input, output, init, reduction, storage, scan_op).
// NOTE(review): where `storage` comes from is not visible in this listing.
430 template<
unsigned int ItemsPerThread,
class BinaryFunction>
431 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
433 T (&output)[ItemsPerThread],
436 BinaryFunction scan_op)
439 this->
exclusive_scan(input, output, init, reduction, storage, scan_op);
// Array variant that also applies a block-wide prefix obtained through
// prefix_callback_op.
// NOTE(review): the function name, some parameters, the first arguments of
// get_block_prefix, the declaration of `prev` and the guards are not visible
// in this listing.
443 class PrefixCallback,
444 unsigned int ItemsPerThread,
447 ROCPRIM_DEVICE ROCPRIM_INLINE
449 T (&output)[ItemsPerThread],
451 PrefixCallback& prefix_callback_op,
452 BinaryFunction scan_op)
454 storage_type_& storage_ = storage.get();
// Fold this thread's items into one value with scan_op.
456 T thread_input = input[0];
458 for(
unsigned int i = 1; i < ItemsPerThread; i++)
460 thread_input = scan_op(thread_input, input[i]);
464 const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
// thread_input appears twice in the argument list - presumably as both the
// input and the output of the block-level exclusive scan (TODO confirm).
465 this->exclusive_scan_impl(
467 thread_input, thread_input,
// The last warp_prefixes slot is handed to get_block_prefix together with
// the callback.
473 T block_prefix = this->get_block_prefix(
475 storage_.warp_prefixes[warps_no_ - 1],
476 prefix_callback_op, storage
// The running value starts as the block prefix alone, or as the block prefix
// combined with the scanned thread_input; the selecting condition is not visible.
481 T exclusive = block_prefix;
484 exclusive = scan_op(block_prefix, thread_input);
486 output[0] = exclusive;
// Each later slot receives the running value extended by `prev`
// (presumably the preceding input item - TODO confirm).
489 for(
unsigned int i = 1; i < ItemsPerThread; i++)
491 exclusive = scan_op(exclusive, prev);
493 output[i] = exclusive;
// Block-level inclusive scan built from a warp-level scan plus per-warp prefixes.
// BlockSize_ defaults to BlockSize.
// NOTE(review): the return type following `auto`, the remaining parameters,
// the definition of warp_id and the guards are not visible in this listing,
// so what distinguishes this overload from its sibling cannot be confirmed here.
498 template<
class BinaryFunction,
unsigned int BlockSize_ = BlockSize>
499 ROCPRIM_DEVICE ROCPRIM_INLINE
500 auto inclusive_scan_impl(
const unsigned int flat_tid,
504 BinaryFunction scan_op)
507 storage_type_& storage_ = storage.get();
// Warp-level inclusive scan of input into output.
509 warp_scan_input_type().inclusive_scan(
511 input, output, scan_op
// Store and scan the per-warp results in storage (see calculate_warp_prefixes).
516 this->calculate_warp_prefixes(flat_tid,
warp_id, output, storage, scan_op);
// Combine this thread's value with the entry of the preceding warp.
// NOTE(review): index warp_id - 1 implies a guard excluding warp 0 - not visible.
521 auto warp_prefix = storage_.warp_prefixes[
warp_id - 1];
522 output = scan_op(warp_prefix, output);
// Second inclusive_scan_impl overload: only a warp-level scan is visible, with
// no call to calculate_warp_prefixes.
// NOTE(review): the return type following `auto` (and hence the condition that
// selects this overload) is not visible in this listing.
527 template<
class BinaryFunction,
unsigned int BlockSize_ = BlockSize>
528 ROCPRIM_DEVICE ROCPRIM_INLINE
529 auto inclusive_scan_impl(
unsigned int flat_tid,
533 BinaryFunction scan_op)
538 storage_type_& storage_ = storage.get();
// Warp-level inclusive scan of input into output.
540 warp_scan_input_type().inclusive_scan(
542 input, output, scan_op
// The last thread of the block stores its result in slot 0.
545 if(flat_tid == BlockSize_ - 1)
547 storage_.warp_prefixes[0] = output;
// Block-level exclusive scan with an initial value, built from a warp-level
// inclusive scan plus per-warp prefixes.
// NOTE(review): the return type following `auto`, the remaining parameters,
// the definition of warp_id and all guards are not visible in this listing.
553 template<
class BinaryFunction,
unsigned int BlockSize_ = BlockSize>
554 ROCPRIM_DEVICE ROCPRIM_INLINE
555 auto exclusive_scan_impl(
const unsigned int flat_tid,
560 BinaryFunction scan_op)
563 storage_type_& storage_ = storage.get();
// Warp-level inclusive scan of input into output.
565 warp_scan_input_type().inclusive_scan(
567 input, output, scan_op
// Store and scan the per-warp results in storage (see calculate_warp_prefixes).
572 this->calculate_warp_prefixes(flat_tid,
warp_id, output, storage, scan_op);
// The warp prefix starts as init and may be replaced by init combined with the
// preceding warp's entry (index warp_id-1 implies a guard excluding warp 0).
576 auto warp_prefix = init;
579 warp_prefix = scan_op(init, storage_.warp_prefixes[
warp_id-1]);
// Two alternative final values: the warp prefix combined with output, or the
// warp prefix alone. NOTE(review): the selecting condition is not visible.
583 output = scan_op(warp_prefix, output);
587 output = warp_prefix;
// Second exclusive_scan_impl overload taking init: only a warp-level scan is
// visible, with no call to calculate_warp_prefixes.
// NOTE(review): the return type following `auto`, the remaining parameters and
// the statements between the visible lines are not shown in this listing.
593 template<
class BinaryFunction,
unsigned int BlockSize_ = BlockSize>
594 ROCPRIM_DEVICE ROCPRIM_INLINE
595 auto exclusive_scan_impl(
const unsigned int flat_tid,
600 BinaryFunction scan_op)
606 storage_type_& storage_ = storage.get();
// Warp-level inclusive scan of input into output.
608 warp_scan_input_type().inclusive_scan(
610 input, output, scan_op
// The last thread of the block stores its result in slot 0.
613 if(flat_tid == BlockSize_ - 1)
615 storage_.warp_prefixes[0] = output;
// Combine init with output; any guard on this statement is not visible.
620 output = scan_op(init, output);
// Block-level exclusive scan without an initial value, built from a warp-level
// inclusive scan plus per-warp prefixes.
// NOTE(review): the return type following `auto`, the remaining parameters,
// the definitions of warp_id and warp_prefix, and all guards are not visible
// in this listing.
629 template<
class BinaryFunction,
unsigned int BlockSize_ = BlockSize>
630 ROCPRIM_DEVICE ROCPRIM_INLINE
631 auto exclusive_scan_impl(
const unsigned int flat_tid,
635 BinaryFunction scan_op)
638 storage_type_& storage_ = storage.get();
// Warp-level inclusive scan of input into output.
640 warp_scan_input_type().inclusive_scan(
642 input, output, scan_op
// Store and scan the per-warp results in storage (see calculate_warp_prefixes).
647 this->calculate_warp_prefixes(flat_tid,
warp_id, output, storage, scan_op);
// Read the preceding warp's entry (index warp_id - 1 implies a guard excluding
// warp 0) and combine it with output.
653 warp_prefix = storage_.warp_prefixes[
warp_id - 1];
654 output = scan_op(warp_prefix, output);
// Alternative final value: the warp prefix alone; the selecting condition is
// not visible.
659 output = warp_prefix;
// Second exclusive_scan_impl overload without init: only a warp-level scan is
// visible, with no call to calculate_warp_prefixes.
// NOTE(review): the return type following `auto`, the remaining parameters and
// the statements that turn the inclusive result into an exclusive one are not
// shown in this listing.
665 template<
class BinaryFunction,
unsigned int BlockSize_ = BlockSize>
666 ROCPRIM_DEVICE ROCPRIM_INLINE
667 auto exclusive_scan_impl(
const unsigned int flat_tid,
671 BinaryFunction scan_op)
676 storage_type_& storage_ = storage.get();
// Warp-level inclusive scan of input into output.
678 warp_scan_input_type().inclusive_scan(
680 input, output, scan_op
// The last thread of the block stores its result in slot 0.
683 if(flat_tid == BlockSize_ - 1)
685 storage_.warp_prefixes[0] = output;
// Collects one value per warp into storage and scans those values in place.
// NOTE(review): the warp_id / inclusive_input / storage parameters and any
// synchronization between the two phases are not visible in this listing.
692 template<
class BinaryFunction,
unsigned int BlockSize_ = BlockSize>
693 ROCPRIM_DEVICE ROCPRIM_INLINE
694 void calculate_warp_prefixes(
const unsigned int flat_tid,
698 BinaryFunction scan_op)
700 storage_type_& storage_ = storage.get();
// Phase 1: the last thread of each warp stores its inclusive value in the
// warp's slot. The min() clamps to BlockSize_ so that a final warp extending
// past the block still has a writer (the block's last thread).
703 if(flat_tid == ::
rocprim::min((warp_id+1) * warp_size_, BlockSize_) - 1)
705 storage_.warp_prefixes[
warp_id] = inclusive_input;
// Phase 2: the first warps_no_ threads each load one slot, run an inclusive
// scan across those values and write the result back to the same slot.
710 if(flat_tid < warps_no_)
712 auto warp_prefix = storage_.warp_prefixes[flat_tid];
713 warp_scan_prefix_type().inclusive_scan(
715 warp_prefix, warp_prefix, scan_op
717 storage_.warp_prefixes[flat_tid] = warp_prefix;
// Obtains the block-wide prefix from the user callback and returns it via the
// last warp_prefixes slot.
// NOTE(review): the `reduction` and `storage` parameters, the guards deciding
// which thread(s) invoke the callback, and any synchronization are not visible
// in this listing.
723 template<
class PrefixCallback>
724 ROCPRIM_DEVICE ROCPRIM_INLINE
725 T get_block_prefix(
const unsigned int flat_tid,
726 const unsigned int warp_id,
728 PrefixCallback& prefix_callback_op,
731 storage_type_& storage_ = storage.get();
// The callback receives `reduction` and returns the block prefix.
734 T block_prefix = prefix_callback_op(reduction);
// The prefix is stored in the last slot ...
738 storage_.warp_prefixes[warps_no_ - 1] = block_prefix;
// ... and the return value is read back from that slot.
742 return storage_.warp_prefixes[warps_no_ - 1];
748 END_ROCPRIM_NAMESPACE
750 #endif // ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_WARP_SCAN_HPP_ Definition: benchmark_block_scan.cpp:63
Definition: block_scan_warp_scan.hpp:45
ROCPRIM_DEVICE ROCPRIM_INLINE constexpr unsigned int device_warp_size()
Returns the number of threads in a hardware warp for the actual target.
Definition: thread.hpp:70
ROCPRIM_DEVICE ROCPRIM_INLINE T warp_shuffle_up(const T &input, const unsigned int delta, const int width=device_warp_size())
Shuffle up for any data type.
Definition: warp_shuffle.hpp:197
ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
Returns the minimum of its arguments.
Definition: functional.hpp:63
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
const unsigned int warp_id
Returns the warp id within a block (tile).
Definition: benchmark_warp_exchange.cpp:153
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronizes all threads in a block (tile).
Definition: thread.hpp:216
Definition: benchmark_block_scan.cpp:100
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int lane_id()
Returns the thread identifier within a warp.
Definition: thread.hpp:93