21 #ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_REDUCE_THEN_SCAN_HPP_    22 #define ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_REDUCE_THEN_SCAN_HPP_    24 #include <type_traits>    26 #include "../../config.hpp"    27 #include "../../detail/various.hpp"    29 #include "../../intrinsics.hpp"    30 #include "../../functional.hpp"    32 #include "../../warp/warp_scan.hpp"    34 BEGIN_ROCPRIM_NAMESPACE
    41     unsigned int BlockSizeX,
    42     unsigned int BlockSizeY,
    43     unsigned int BlockSizeZ
    47     static constexpr 
unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
    49     static constexpr 
unsigned int thread_reduction_size_ =
    54     static constexpr 
unsigned int warp_size_ =
    56     using warp_scan_prefix_type = ::rocprim::detail::warp_scan_crosslane<T, warp_size_>;
    59     static constexpr 
unsigned int banks_no_ = ::rocprim::detail::get_lds_banks_no();
    60     static constexpr 
bool has_bank_conflicts_ =
    61         ::rocprim::detail::is_power_of_two(thread_reduction_size_) && thread_reduction_size_ > 1;
    62     static constexpr 
unsigned int bank_conflicts_padding =
    63         has_bank_conflicts_ ? (warp_size_ * thread_reduction_size_ / banks_no_) : 0;
    67         T threads[warp_size_ * thread_reduction_size_ + bank_conflicts_padding];
    73     template<
class BinaryFunction>
    74     ROCPRIM_DEVICE ROCPRIM_INLINE
    78                         BinaryFunction scan_op)
    80         const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
    81         this->inclusive_scan_impl(flat_tid, input, output, storage, scan_op);
    84     template<
class BinaryFunction>
    85     ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
    88                         BinaryFunction scan_op)
    94     template<
class BinaryFunction>
    95     ROCPRIM_DEVICE ROCPRIM_INLINE
   100                         BinaryFunction scan_op)
   102         storage_type_& storage_ = storage.get();
   104         reduction = storage_.threads[index(BlockSize - 1)];
   107     template<
class BinaryFunction>
   108     ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
   112                         BinaryFunction scan_op)
   118     template<
class PrefixCallback, 
class BinaryFunction>
   119     ROCPRIM_DEVICE ROCPRIM_INLINE
   123                         PrefixCallback& prefix_callback_op,
   124                         BinaryFunction scan_op)
   126         const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
   128         storage_type_& storage_ = storage.get();
   129         this->inclusive_scan_impl(flat_tid, input, output, storage, scan_op);
   131         T block_prefix = this->get_block_prefix(
   133             storage_.threads[index(BlockSize - 1)], 
   134             prefix_callback_op, storage
   136         output = scan_op(block_prefix, output);
   139     template<
unsigned int ItemsPerThread, 
class BinaryFunction>
   140     ROCPRIM_DEVICE ROCPRIM_INLINE
   142                         T (&output)[ItemsPerThread],
   144                         BinaryFunction scan_op)
   147         T thread_input = input[0];
   149         for(
unsigned int i = 1; i < ItemsPerThread; i++)
   151             thread_input = scan_op(thread_input, input[i]);
   155         const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
   156         this->exclusive_scan_impl(
   158             thread_input, thread_input, 
   164         output[0] = input[0];
   165         if(flat_tid != 0) output[0] = scan_op(thread_input, input[0]);
   168         for(
unsigned int i = 1; i < ItemsPerThread; i++)
   170             output[i] = scan_op(output[i-1], input[i]);
   174     template<
unsigned int ItemsPerThread, 
class BinaryFunction>
   175     ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
   177                         T (&output)[ItemsPerThread],
   178                         BinaryFunction scan_op)
   184     template<
unsigned int ItemsPerThread, 
class BinaryFunction>
   185     ROCPRIM_DEVICE ROCPRIM_INLINE
   187                         T (&output)[ItemsPerThread],
   190                         BinaryFunction scan_op)
   192         storage_type_& storage_ = storage.get();
   195         reduction = storage_.threads[index(BlockSize - 1)];
   198     template<
unsigned int ItemsPerThread, 
class BinaryFunction>
   199     ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
   201                         T (&output)[ItemsPerThread],
   203                         BinaryFunction scan_op)
   210         class PrefixCallback,
   211         unsigned int ItemsPerThread,
   214     ROCPRIM_DEVICE ROCPRIM_INLINE
   216                         T (&output)[ItemsPerThread],
   218                         PrefixCallback& prefix_callback_op,
   219                         BinaryFunction scan_op)
   221         storage_type_& storage_ = storage.get();
   223         T thread_input = input[0];
   225         for(
unsigned int i = 1; i < ItemsPerThread; i++)
   227             thread_input = scan_op(thread_input, input[i]);
   231         const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
   232         this->exclusive_scan_impl(
   234             thread_input, thread_input, 
   240         T block_prefix = this->get_block_prefix(
   242             storage_.threads[index(BlockSize - 1)], 
   243             prefix_callback_op, storage
   247         output[0] = input[0];
   248         if(flat_tid != 0) output[0] = scan_op(thread_input, input[0]);
   250         output[0] = scan_op(block_prefix, output[0]);
   253         for(
unsigned int i = 1; i < ItemsPerThread; i++)
   255             output[i] = scan_op(output[i-1], input[i]);
   259     template<
class BinaryFunction>
   260     ROCPRIM_DEVICE ROCPRIM_INLINE
   265                         BinaryFunction scan_op)
   267         const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
   268         this->exclusive_scan_impl(flat_tid, input, output, init, storage, scan_op);
   271     template<
class BinaryFunction>
   272     ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
   276                         BinaryFunction scan_op)
   282     template<
class BinaryFunction>
   283     ROCPRIM_DEVICE ROCPRIM_INLINE
   289                         BinaryFunction scan_op)
   291         const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
   292         storage_type_& storage_ = storage.get();
   293         this->exclusive_scan_impl(
   294             flat_tid, input, output, init, storage, scan_op
   297         reduction = storage_.threads[index(BlockSize - 1)];
   300     template<
class BinaryFunction>
   301     ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
   306                         BinaryFunction scan_op)
   309         this->
exclusive_scan(input, output, init, reduction, storage, scan_op);
   312     template<
class PrefixCallback, 
class BinaryFunction>
   313     ROCPRIM_DEVICE ROCPRIM_INLINE
   317                         PrefixCallback& prefix_callback_op,
   318                         BinaryFunction scan_op)
   320         const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
   322         storage_type_& storage_ = storage.get();
   323         this->exclusive_scan_impl(
   324             flat_tid, input, output, storage, scan_op
   327         T reduction = storage_.threads[index(BlockSize - 1)];
   329         T block_prefix = this->get_block_prefix(
   330             flat_tid, warp_id, reduction,
   331             prefix_callback_op, storage
   333         output = scan_op(block_prefix, output);
   334         if(flat_tid == 0) output = block_prefix;
   337     template<
unsigned int ItemsPerThread, 
class BinaryFunction>
   338     ROCPRIM_DEVICE ROCPRIM_INLINE
   340                         T (&output)[ItemsPerThread],
   343                         BinaryFunction scan_op)
   346         T thread_input = input[0];
   348         for(
unsigned int i = 1; i < ItemsPerThread; i++)
   350             thread_input = scan_op(thread_input, input[i]);
   354         const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
   355         this->exclusive_scan_impl(
   357             thread_input, thread_input, 
   368             exclusive = thread_input;
   370         output[0] = exclusive;
   372         for(
unsigned int i = 1; i < ItemsPerThread; i++)
   374             exclusive = scan_op(exclusive, prev);
   376             output[i] = exclusive;
   380     template<
unsigned int ItemsPerThread, 
class BinaryFunction>
   381     ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
   383                         T (&output)[ItemsPerThread],
   385                         BinaryFunction scan_op)
   391     template<
unsigned int ItemsPerThread, 
class BinaryFunction>
   392     ROCPRIM_DEVICE ROCPRIM_INLINE
   394                         T (&output)[ItemsPerThread],
   398                         BinaryFunction scan_op)
   400         storage_type_& storage_ = storage.get();
   403         reduction = storage_.threads[index(BlockSize - 1)];
   406     template<
unsigned int ItemsPerThread, 
class BinaryFunction>
   407     ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
   409                         T (&output)[ItemsPerThread],
   412                         BinaryFunction scan_op)
   415         this->
exclusive_scan(input, output, init, reduction, storage, scan_op);
   419         class PrefixCallback,
   420         unsigned int ItemsPerThread,
   423     ROCPRIM_DEVICE ROCPRIM_INLINE
   425                         T (&output)[ItemsPerThread],
   427                         PrefixCallback& prefix_callback_op,
   428                         BinaryFunction scan_op)
   430         storage_type_& storage_ = storage.get();
   432         T thread_input = input[0];
   434         for(
unsigned int i = 1; i < ItemsPerThread; i++)
   436             thread_input = scan_op(thread_input, input[i]);
   440         const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
   441         this->exclusive_scan_impl(
   443             thread_input, thread_input, 
   449         T block_prefix = this->get_block_prefix(
   451             storage_.threads[index(BlockSize - 1)], 
   452             prefix_callback_op, storage
   457         T exclusive = block_prefix;
   460             exclusive = scan_op(block_prefix, thread_input);
   462         output[0] = exclusive;
   464         for(
unsigned int i = 1; i < ItemsPerThread; i++)
   466             exclusive = scan_op(exclusive, prev);
   468             output[i] = exclusive;
   477     template<
class BinaryFunction>
   478     ROCPRIM_DEVICE ROCPRIM_INLINE
   479     void inclusive_scan_impl(
const unsigned int flat_tid,
   483                              BinaryFunction scan_op)
   485         storage_type_& storage_ = storage.get();
   488         this->inclusive_scan_base(flat_tid, input, storage, scan_op);
   489         output = storage_.threads[index(flat_tid)];
   494     template<
class BinaryFunction>
   495     ROCPRIM_DEVICE ROCPRIM_INLINE
   496     void inclusive_scan_base(
const unsigned int flat_tid,
   499                              BinaryFunction scan_op)
   501         storage_type_& storage_ = storage.get();
   502         storage_.threads[index(flat_tid)] = input;
   504         if(flat_tid < warp_size_)
   506             const unsigned int idx_start = index(flat_tid * thread_reduction_size_);
   507             const unsigned int idx_end = idx_start + thread_reduction_size_;
   509             T thread_reduction = storage_.threads[idx_start];
   511             for(
unsigned int i = idx_start + 1; i < idx_end; i++)
   513                 thread_reduction = scan_op(
   514                     thread_reduction, storage_.threads[i]
   519             warp_scan_prefix_type().inclusive_scan(thread_reduction, thread_reduction, scan_op);
   523             thread_reduction = scan_op(thread_reduction, storage_.threads[idx_start]);
   526                 thread_reduction = input;
   529             storage_.threads[idx_start] = thread_reduction;
   531             for(
unsigned int i = idx_start + 1; i < idx_end; i++)
   533                 thread_reduction = scan_op(
   534                     thread_reduction, storage_.threads[i]
   536                 storage_.threads[i] = thread_reduction;
   542     template<
class BinaryFunction>
   543     ROCPRIM_DEVICE ROCPRIM_INLINE
   544     void exclusive_scan_impl(
const unsigned int flat_tid,
   549                              BinaryFunction scan_op)
   551         storage_type_& storage_ = storage.get();
   553         this->inclusive_scan_base(flat_tid, input, storage, scan_op);
   555         if(flat_tid != 0) output = scan_op(init, storage_.threads[index(flat_tid-1)]);
   558     template<
class BinaryFunction>
   559     ROCPRIM_DEVICE ROCPRIM_INLINE
   560     void exclusive_scan_impl(
const unsigned int flat_tid,
   564                              BinaryFunction scan_op)
   566         storage_type_& storage_ = storage.get();
   568         this->inclusive_scan_base(flat_tid, input, storage, scan_op);
   571             output = storage_.threads[index(flat_tid-1)];
   576     template<
class PrefixCallback, 
class BinaryFunction>
   577     ROCPRIM_DEVICE ROCPRIM_INLINE
   578     void include_block_prefix(
const unsigned int flat_tid,
   583                               PrefixCallback& prefix_callback_op,
   585                               BinaryFunction scan_op)
   587         T block_prefix = this->get_block_prefix(
   588             flat_tid, warp_id, reduction,
   589             prefix_callback_op, storage
   591         output = scan_op(block_prefix, input);
   595     template<
class PrefixCallback>
   596     ROCPRIM_DEVICE ROCPRIM_INLINE
   597     T get_block_prefix(
const unsigned int flat_tid,
   598                        const unsigned int warp_id,
   600                        PrefixCallback& prefix_callback_op,
   603         storage_type_& storage_ = storage.get();
   606             T block_prefix = prefix_callback_op(reduction);
   611                 storage_.threads[0] = block_prefix;
   615         return storage_.threads[0];
   619     ROCPRIM_DEVICE ROCPRIM_INLINE
   620     unsigned int index(
unsigned int n)
 const   623         return has_bank_conflicts_ ? (n + (n/banks_no_)) : n;
   629 END_ROCPRIM_NAMESPACE
   631 #endif // ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_REDUCE_THEN_SCAN_HPP_ Definition: benchmark_block_scan.cpp:63
ROCPRIM_DEVICE ROCPRIM_INLINE constexpr unsigned int device_warp_size()
Returns the number of threads in a hardware warp for the actual target.
Definition: thread.hpp:70
Definition: block_scan_reduce_then_scan.hpp:45
ROCPRIM_DEVICE ROCPRIM_INLINE T warp_shuffle_up(const T &input, const unsigned int delta, const int width=device_warp_size())
Shuffle up for any data type. 
Definition: warp_shuffle.hpp:197
Deprecated: Configuration of device-level scan primitives. 
Definition: block_histogram.hpp:62
const unsigned int warp_id
Returns the warp ID within a block (tile).
Definition: benchmark_warp_exchange.cpp:153
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronizes all threads in a block (tile).
Definition: thread.hpp:216
Definition: benchmark_block_scan.cpp:100