rocPRIM
block_scan_warp_scan.hpp
1 // Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_WARP_SCAN_HPP_
22 #define ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_WARP_SCAN_HPP_
23 
24 #include <type_traits>
25 
26 #include "../../config.hpp"
27 #include "../../detail/various.hpp"
28 
29 #include "../../intrinsics.hpp"
30 #include "../../functional.hpp"
31 
32 #include "../../warp/warp_scan.hpp"
33 
34 BEGIN_ROCPRIM_NAMESPACE
35 
36 namespace detail
37 {
38 
39 template<
40  class T,
41  unsigned int BlockSizeX,
42  unsigned int BlockSizeY,
43  unsigned int BlockSizeZ
44 >
46 {
47  static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
48  // Select warp size
49  static constexpr unsigned int warp_size_ =
50  detail::get_min_warp_size(BlockSize, ::rocprim::device_warp_size());
51  // Number of warps in block
52  static constexpr unsigned int warps_no_ = (BlockSize + warp_size_ - 1) / warp_size_;
53 
54  // typedef of warp_scan primitive that will be used to perform warp-level
55  // inclusive/exclusive scan operations on input values.
56  // warp_scan_crosslane is an implementation of warp_scan that does not need storage,
57  // but requires logical warp size to be a power of two.
58  using warp_scan_input_type = ::rocprim::detail::warp_scan_crosslane<T, warp_size_>;
59  // typedef of warp_scan primitive that will be used to get prefix values for
60  // each warp (scanned carry-outs from warps before it).
61  using warp_scan_prefix_type = ::rocprim::detail::warp_scan_crosslane<T, detail::next_power_of_two(warps_no_)>;
62 
63  struct storage_type_
64  {
65  T warp_prefixes[warps_no_];
66  // ---------- Shared memory optimisation ----------
67  // Since warp_scan_input and warp_scan_prefix are typedef of warp_scan_crosslane,
68  // we don't need to allocate any temporary memory for them.
69  // If we just use warp_scan, we would need to add following union to this struct:
70  // union
71  // {
72  // typename warp_scan_input::storage_type wscan[warps_no_];
73  // typename warp_scan_prefix::storage_type wprefix_scan;
74  // };
75  // and use storage_.wscan[warp_id] and storage.wprefix_scan when calling
76  // warp_scan_input().inclusive_scan(..) and warp_scan_prefix().inclusive_scan(..).
77  };
78 
79 public:
81 
82  template<class BinaryFunction>
83  ROCPRIM_DEVICE ROCPRIM_INLINE
84  void inclusive_scan(T input,
85  T& output,
86  storage_type& storage,
87  BinaryFunction scan_op)
88  {
89  this->inclusive_scan_impl(
90  ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
91  input, output, storage, scan_op
92  );
93  }
94 
95  template<class BinaryFunction>
96  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
97  void inclusive_scan(T input,
98  T& output,
99  BinaryFunction scan_op)
100  {
101  ROCPRIM_SHARED_MEMORY storage_type storage;
102  this->inclusive_scan(input, output, storage, scan_op);
103  }
104 
105  template<class BinaryFunction>
106  ROCPRIM_DEVICE ROCPRIM_INLINE
107  void inclusive_scan(T input,
108  T& output,
109  T& reduction,
110  storage_type& storage,
111  BinaryFunction scan_op)
112  {
113  storage_type_& storage_ = storage.get();
114  this->inclusive_scan(input, output, storage, scan_op);
115  // Save reduction result
116  reduction = storage_.warp_prefixes[warps_no_ - 1];
117  }
118 
119  template<class BinaryFunction>
120  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
121  void inclusive_scan(T input,
122  T& output,
123  T& reduction,
124  BinaryFunction scan_op)
125  {
126  ROCPRIM_SHARED_MEMORY storage_type storage;
127  this->inclusive_scan(input, output, reduction, storage, scan_op);
128  }
129 
130  template<class PrefixCallback, class BinaryFunction>
131  ROCPRIM_DEVICE ROCPRIM_INLINE
132  void inclusive_scan(T input,
133  T& output,
134  storage_type& storage,
135  PrefixCallback& prefix_callback_op,
136  BinaryFunction scan_op)
137  {
138  const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
139  const auto warp_id = ::rocprim::warp_id(flat_tid);
140  storage_type_& storage_ = storage.get();
141  this->inclusive_scan_impl(flat_tid, input, output, storage, scan_op);
142  // Include block prefix (this operation overwrites storage_.warp_prefixes[warps_no_ - 1])
143  T block_prefix = this->get_block_prefix(
144  flat_tid, warp_id,
145  storage_.warp_prefixes[warps_no_ - 1], // block reduction
146  prefix_callback_op, storage
147  );
148  output = scan_op(block_prefix, output);
149  }
150 
151  template<unsigned int ItemsPerThread, class BinaryFunction>
152  ROCPRIM_DEVICE ROCPRIM_INLINE
153  void inclusive_scan(T (&input)[ItemsPerThread],
154  T (&output)[ItemsPerThread],
155  storage_type& storage,
156  BinaryFunction scan_op)
157  {
158  // Reduce thread items
159  T thread_input = input[0];
160  ROCPRIM_UNROLL
161  for(unsigned int i = 1; i < ItemsPerThread; i++)
162  {
163  thread_input = scan_op(thread_input, input[i]);
164  }
165 
166  // Scan of reduced values to get prefixes
167  const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
168  this->exclusive_scan_impl(
169  flat_tid,
170  thread_input, thread_input, // input, output
171  storage,
172  scan_op
173  );
174 
175  // Include prefix (first thread does not have prefix)
176  output[0] = input[0];
177  if(flat_tid != 0)
178  {
179  output[0] = scan_op(thread_input, input[0]);
180  }
181 
182  // Final thread-local scan
183  ROCPRIM_UNROLL
184  for(unsigned int i = 1; i < ItemsPerThread; i++)
185  {
186  output[i] = scan_op(output[i-1], input[i]);
187  }
188  }
189 
190  template<unsigned int ItemsPerThread, class BinaryFunction>
191  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
192  void inclusive_scan(T (&input)[ItemsPerThread],
193  T (&output)[ItemsPerThread],
194  BinaryFunction scan_op)
195  {
196  ROCPRIM_SHARED_MEMORY storage_type storage;
197  this->inclusive_scan(input, output, storage, scan_op);
198  }
199 
200  template<unsigned int ItemsPerThread, class BinaryFunction>
201  ROCPRIM_DEVICE ROCPRIM_INLINE
202  void inclusive_scan(T (&input)[ItemsPerThread],
203  T (&output)[ItemsPerThread],
204  T& reduction,
205  storage_type& storage,
206  BinaryFunction scan_op)
207  {
208  storage_type_& storage_ = storage.get();
209  this->inclusive_scan(input, output, storage, scan_op);
210  // Save reduction result
211  reduction = storage_.warp_prefixes[warps_no_ - 1];
212  }
213 
214  template<unsigned int ItemsPerThread, class BinaryFunction>
215  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
216  void inclusive_scan(T (&input)[ItemsPerThread],
217  T (&output)[ItemsPerThread],
218  T& reduction,
219  BinaryFunction scan_op)
220  {
221  ROCPRIM_SHARED_MEMORY storage_type storage;
222  this->inclusive_scan(input, output, reduction, storage, scan_op);
223  }
224 
225  template<
226  class PrefixCallback,
227  unsigned int ItemsPerThread,
228  class BinaryFunction
229  >
230  ROCPRIM_DEVICE ROCPRIM_INLINE
231  void inclusive_scan(T (&input)[ItemsPerThread],
232  T (&output)[ItemsPerThread],
233  storage_type& storage,
234  PrefixCallback& prefix_callback_op,
235  BinaryFunction scan_op)
236  {
237  storage_type_& storage_ = storage.get();
238  // Reduce thread items
239  T thread_input = input[0];
240  ROCPRIM_UNROLL
241  for(unsigned int i = 1; i < ItemsPerThread; i++)
242  {
243  thread_input = scan_op(thread_input, input[i]);
244  }
245 
246  // Scan of reduced values to get prefixes
247  const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
248  this->exclusive_scan_impl(
249  flat_tid,
250  thread_input, thread_input, // input, output
251  storage,
252  scan_op
253  );
254 
255  // this operation overwrites storage_.warp_prefixes[warps_no_ - 1]
256  T block_prefix = this->get_block_prefix(
257  flat_tid, ::rocprim::warp_id(flat_tid),
258  storage_.warp_prefixes[warps_no_ - 1], // block reduction
259  prefix_callback_op, storage
260  );
261 
262  // Include prefix (first thread does not have prefix)
263  output[0] = input[0];
264  if(flat_tid != 0)
265  {
266  output[0] = scan_op(thread_input, input[0]);
267  }
268  // Include block prefix
269  output[0] = scan_op(block_prefix, output[0]);
270  // Final thread-local scan
271  ROCPRIM_UNROLL
272  for(unsigned int i = 1; i < ItemsPerThread; i++)
273  {
274  output[i] = scan_op(output[i-1], input[i]);
275  }
276  }
277 
278  template<class BinaryFunction>
279  ROCPRIM_DEVICE ROCPRIM_INLINE
280  void exclusive_scan(T input,
281  T& output,
282  T init,
283  storage_type& storage,
284  BinaryFunction scan_op)
285  {
286  this->exclusive_scan_impl(
287  ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>(),
288  input, output, init, storage, scan_op
289  );
290  }
291 
292  template<class BinaryFunction>
293  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
294  void exclusive_scan(T input,
295  T& output,
296  T init,
297  BinaryFunction scan_op)
298  {
299  ROCPRIM_SHARED_MEMORY storage_type storage;
300  this->exclusive_scan(
301  input, output, init, storage, scan_op
302  );
303  }
304 
305  template<class BinaryFunction>
306  ROCPRIM_DEVICE ROCPRIM_INLINE
307  void exclusive_scan(T input,
308  T& output,
309  T init,
310  T& reduction,
311  storage_type& storage,
312  BinaryFunction scan_op)
313  {
314  storage_type_& storage_ = storage.get();
315  this->exclusive_scan(
316  input, output, init, storage, scan_op
317  );
318  // Save reduction result
319  reduction = storage_.warp_prefixes[warps_no_ - 1];
320  }
321 
322  template<class BinaryFunction>
323  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
324  void exclusive_scan(T input,
325  T& output,
326  T init,
327  T& reduction,
328  BinaryFunction scan_op)
329  {
330  ROCPRIM_SHARED_MEMORY storage_type storage;
331  this->exclusive_scan(
332  input, output, init, reduction, storage, scan_op
333  );
334  }
335 
336  template<class PrefixCallback, class BinaryFunction>
337  ROCPRIM_DEVICE ROCPRIM_INLINE
338  void exclusive_scan(T input,
339  T& output,
340  storage_type& storage,
341  PrefixCallback& prefix_callback_op,
342  BinaryFunction scan_op)
343  {
344  const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
345  const auto warp_id = ::rocprim::warp_id(flat_tid);
346  storage_type_& storage_ = storage.get();
347  this->exclusive_scan_impl(
348  flat_tid, input, output, storage, scan_op
349  );
350  // Include block prefix (this operation overwrites storage_.warp_prefixes[warps_no_ - 1])
351  T block_prefix = this->get_block_prefix(
352  flat_tid, warp_id,
353  storage_.warp_prefixes[warps_no_ - 1], // block reduction
354  prefix_callback_op, storage
355  );
356  output = scan_op(block_prefix, output);
357  if(flat_tid == 0) output = block_prefix;
358  }
359 
360  template<unsigned int ItemsPerThread, class BinaryFunction>
361  ROCPRIM_DEVICE ROCPRIM_INLINE
362  void exclusive_scan(T (&input)[ItemsPerThread],
363  T (&output)[ItemsPerThread],
364  T init,
365  storage_type& storage,
366  BinaryFunction scan_op)
367  {
368  // Reduce thread items
369  T thread_input = input[0];
370  ROCPRIM_UNROLL
371  for(unsigned int i = 1; i < ItemsPerThread; i++)
372  {
373  thread_input = scan_op(thread_input, input[i]);
374  }
375 
376  // Scan of reduced values to get prefixes
377  const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
378  this->exclusive_scan_impl(
379  flat_tid,
380  thread_input, thread_input, // input, output
381  init,
382  storage,
383  scan_op
384  );
385 
386  // Include init value
387  T prev = input[0];
388  T exclusive = init;
389  if(flat_tid != 0)
390  {
391  exclusive = thread_input;
392  }
393  output[0] = exclusive;
394 
395  ROCPRIM_UNROLL
396  for(unsigned int i = 1; i < ItemsPerThread; i++)
397  {
398  exclusive = scan_op(exclusive, prev);
399  prev = input[i];
400  output[i] = exclusive;
401  }
402  }
403 
404  template<unsigned int ItemsPerThread, class BinaryFunction>
405  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
406  void exclusive_scan(T (&input)[ItemsPerThread],
407  T (&output)[ItemsPerThread],
408  T init,
409  BinaryFunction scan_op)
410  {
411  ROCPRIM_SHARED_MEMORY storage_type storage;
412  this->exclusive_scan(input, output, init, storage, scan_op);
413  }
414 
415  template<unsigned int ItemsPerThread, class BinaryFunction>
416  ROCPRIM_DEVICE ROCPRIM_INLINE
417  void exclusive_scan(T (&input)[ItemsPerThread],
418  T (&output)[ItemsPerThread],
419  T init,
420  T& reduction,
421  storage_type& storage,
422  BinaryFunction scan_op)
423  {
424  storage_type_& storage_ = storage.get();
425  this->exclusive_scan(input, output, init, storage, scan_op);
426  // Save reduction result
427  reduction = storage_.warp_prefixes[warps_no_ - 1];
428  }
429 
430  template<unsigned int ItemsPerThread, class BinaryFunction>
431  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
432  void exclusive_scan(T (&input)[ItemsPerThread],
433  T (&output)[ItemsPerThread],
434  T init,
435  T& reduction,
436  BinaryFunction scan_op)
437  {
438  ROCPRIM_SHARED_MEMORY storage_type storage;
439  this->exclusive_scan(input, output, init, reduction, storage, scan_op);
440  }
441 
442  template<
443  class PrefixCallback,
444  unsigned int ItemsPerThread,
445  class BinaryFunction
446  >
447  ROCPRIM_DEVICE ROCPRIM_INLINE
448  void exclusive_scan(T (&input)[ItemsPerThread],
449  T (&output)[ItemsPerThread],
450  storage_type& storage,
451  PrefixCallback& prefix_callback_op,
452  BinaryFunction scan_op)
453  {
454  storage_type_& storage_ = storage.get();
455  // Reduce thread items
456  T thread_input = input[0];
457  ROCPRIM_UNROLL
458  for(unsigned int i = 1; i < ItemsPerThread; i++)
459  {
460  thread_input = scan_op(thread_input, input[i]);
461  }
462 
463  // Scan of reduced values to get prefixes
464  const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
465  this->exclusive_scan_impl(
466  flat_tid,
467  thread_input, thread_input, // input, output
468  storage,
469  scan_op
470  );
471 
472  // this operation overwrites storage_.warp_prefixes[warps_no_ - 1]
473  T block_prefix = this->get_block_prefix(
474  flat_tid, ::rocprim::warp_id(flat_tid),
475  storage_.warp_prefixes[warps_no_ - 1], // block reduction
476  prefix_callback_op, storage
477  );
478 
479  // Include init value and block prefix
480  T prev = input[0];
481  T exclusive = block_prefix;
482  if(flat_tid != 0)
483  {
484  exclusive = scan_op(block_prefix, thread_input);
485  }
486  output[0] = exclusive;
487 
488  ROCPRIM_UNROLL
489  for(unsigned int i = 1; i < ItemsPerThread; i++)
490  {
491  exclusive = scan_op(exclusive, prev);
492  prev = input[i];
493  output[i] = exclusive;
494  }
495  }
496 
497 private:
498  template<class BinaryFunction, unsigned int BlockSize_ = BlockSize>
499  ROCPRIM_DEVICE ROCPRIM_INLINE
500  auto inclusive_scan_impl(const unsigned int flat_tid,
501  T input,
502  T& output,
503  storage_type& storage,
504  BinaryFunction scan_op)
505  -> typename std::enable_if<(BlockSize_ > ::rocprim::device_warp_size())>::type
506  {
507  storage_type_& storage_ = storage.get();
508  // Perform warp scan
509  warp_scan_input_type().inclusive_scan(
510  // not using shared mem, see note in storage_type
511  input, output, scan_op
512  );
513 
514  // i-th warp will have its prefix stored in storage_.warp_prefixes[i-1]
515  const auto warp_id = ::rocprim::warp_id(flat_tid);
516  this->calculate_warp_prefixes(flat_tid, warp_id, output, storage, scan_op);
517 
518  // Use warp prefix to calculate the final scan results for every thread
519  if(warp_id != 0)
520  {
521  auto warp_prefix = storage_.warp_prefixes[warp_id - 1];
522  output = scan_op(warp_prefix, output);
523  }
524  }
525 
526  // When BlockSize is less than warp_size we dont need the extra prefix calculations.
527  template<class BinaryFunction, unsigned int BlockSize_ = BlockSize>
528  ROCPRIM_DEVICE ROCPRIM_INLINE
529  auto inclusive_scan_impl(unsigned int flat_tid,
530  T input,
531  T& output,
532  storage_type& storage,
533  BinaryFunction scan_op)
534  -> typename std::enable_if<!(BlockSize_ > ::rocprim::device_warp_size())>::type
535  {
536  (void) storage;
537  (void) flat_tid;
538  storage_type_& storage_ = storage.get();
539  // Perform warp scan
540  warp_scan_input_type().inclusive_scan(
541  // not using shared mem, see note in storage_type
542  input, output, scan_op
543  );
544 
545  if(flat_tid == BlockSize_ - 1)
546  {
547  storage_.warp_prefixes[0] = output;
548  }
550  }
551 
552  // Exclusive scan with initial value when BlockSize is bigger than warp_size
553  template<class BinaryFunction, unsigned int BlockSize_ = BlockSize>
554  ROCPRIM_DEVICE ROCPRIM_INLINE
555  auto exclusive_scan_impl(const unsigned int flat_tid,
556  T input,
557  T& output,
558  T init,
559  storage_type& storage,
560  BinaryFunction scan_op)
561  -> typename std::enable_if<(BlockSize_ > ::rocprim::device_warp_size())>::type
562  {
563  storage_type_& storage_ = storage.get();
564  // Perform warp scan on input values
565  warp_scan_input_type().inclusive_scan(
566  // not using shared mem, see note in storage_type
567  input, output, scan_op
568  );
569 
570  // i-th warp will have its prefix stored in storage_.warp_prefixes[i-1]
571  const auto warp_id = ::rocprim::warp_id(flat_tid);
572  this->calculate_warp_prefixes(flat_tid, warp_id, output, storage, scan_op);
573 
574  // Include initial value in warp prefixes, and fix warp prefixes
575  // for exclusive scan (first warp prefix is init)
576  auto warp_prefix = init;
577  if(warp_id != 0)
578  {
579  warp_prefix = scan_op(init, storage_.warp_prefixes[warp_id-1]);
580  }
581 
582  // Use warp prefix to calculate the final scan results for every thread
583  output = scan_op(warp_prefix, output); // include warp prefix in scan results
584  output = warp_shuffle_up(output, 1, warp_size_); // shift to get exclusive results
585  if(::rocprim::lane_id() == 0)
586  {
587  output = warp_prefix;
588  }
589  }
590 
591  // Exclusive scan with initial value when BlockSize is less than warp_size.
592  // When BlockSize is less than warp_size we dont need the extra prefix calculations.
593  template<class BinaryFunction, unsigned int BlockSize_ = BlockSize>
594  ROCPRIM_DEVICE ROCPRIM_INLINE
595  auto exclusive_scan_impl(const unsigned int flat_tid,
596  T input,
597  T& output,
598  T init,
599  storage_type& storage,
600  BinaryFunction scan_op)
601  -> typename std::enable_if<!(BlockSize_ > ::rocprim::device_warp_size())>::type
602  {
603  (void) flat_tid;
604  (void) storage;
605  (void) init;
606  storage_type_& storage_ = storage.get();
607  // Perform warp scan on input values
608  warp_scan_input_type().inclusive_scan(
609  // not using shared mem, see note in storage_type
610  input, output, scan_op
611  );
612 
613  if(flat_tid == BlockSize_ - 1)
614  {
615  storage_.warp_prefixes[0] = output;
616  }
618 
619  // Use warp prefix to calculate the final scan results for every thread
620  output = scan_op(init, output); // include warp prefix in scan results
621  output = warp_shuffle_up(output, 1, warp_size_); // shift to get exclusive results
622  if(::rocprim::lane_id() == 0)
623  {
624  output = init;
625  }
626  }
627 
628  // Exclusive scan with unknown initial value
629  template<class BinaryFunction, unsigned int BlockSize_ = BlockSize>
630  ROCPRIM_DEVICE ROCPRIM_INLINE
631  auto exclusive_scan_impl(const unsigned int flat_tid,
632  T input,
633  T& output,
634  storage_type& storage,
635  BinaryFunction scan_op)
636  -> typename std::enable_if<(BlockSize_ > ::rocprim::device_warp_size())>::type
637  {
638  storage_type_& storage_ = storage.get();
639  // Perform warp scan on input values
640  warp_scan_input_type().inclusive_scan(
641  // not using shared mem, see note in storage_type
642  input, output, scan_op
643  );
644 
645  // i-th warp will have its prefix stored in storage_.warp_prefixes[i-1]
646  const auto warp_id = ::rocprim::warp_id(flat_tid);
647  this->calculate_warp_prefixes(flat_tid, warp_id, output, storage, scan_op);
648 
649  // Use warp prefix to calculate the final scan results for every thread
650  T warp_prefix;
651  if(warp_id != 0)
652  {
653  warp_prefix = storage_.warp_prefixes[warp_id - 1];
654  output = scan_op(warp_prefix, output);
655  }
656  output = warp_shuffle_up(output, 1, warp_size_); // shift to get exclusive results
657  if(::rocprim::lane_id() == 0)
658  {
659  output = warp_prefix;
660  }
661  }
662 
663  // Exclusive scan with unknown initial value, when BlockSize less than warp_size.
664  // When BlockSize is less than warp_size we dont need the extra prefix calculations.
665  template<class BinaryFunction, unsigned int BlockSize_ = BlockSize>
666  ROCPRIM_DEVICE ROCPRIM_INLINE
667  auto exclusive_scan_impl(const unsigned int flat_tid,
668  T input,
669  T& output,
670  storage_type& storage,
671  BinaryFunction scan_op)
672  -> typename std::enable_if<!(BlockSize_ > ::rocprim::device_warp_size())>::type
673  {
674  (void) flat_tid;
675  (void) storage;
676  storage_type_& storage_ = storage.get();
677  // Perform warp scan on input values
678  warp_scan_input_type().inclusive_scan(
679  // not using shared mem, see note in storage_type
680  input, output, scan_op
681  );
682 
683  if(flat_tid == BlockSize_ - 1)
684  {
685  storage_.warp_prefixes[0] = output;
686  }
688  output = warp_shuffle_up(output, 1, warp_size_); // shift to get exclusive results
689  }
690 
691  // i-th warp will have its prefix stored in storage_.warp_prefixes[i-1]
692  template<class BinaryFunction, unsigned int BlockSize_ = BlockSize>
693  ROCPRIM_DEVICE ROCPRIM_INLINE
694  void calculate_warp_prefixes(const unsigned int flat_tid,
695  const unsigned int warp_id,
696  T inclusive_input,
697  storage_type& storage,
698  BinaryFunction scan_op)
699  {
700  storage_type_& storage_ = storage.get();
701  // Save the warp reduction result, that is the scan result
702  // for last element in each warp
703  if(flat_tid == ::rocprim::min((warp_id+1) * warp_size_, BlockSize_) - 1)
704  {
705  storage_.warp_prefixes[warp_id] = inclusive_input;
706  }
708 
709  // Scan the warp reduction results and store in storage_.warp_prefixes
710  if(flat_tid < warps_no_)
711  {
712  auto warp_prefix = storage_.warp_prefixes[flat_tid];
713  warp_scan_prefix_type().inclusive_scan(
714  // not using shared mem, see note in storage_type
715  warp_prefix, warp_prefix, scan_op
716  );
717  storage_.warp_prefixes[flat_tid] = warp_prefix;
718  }
720  }
721 
722  // THIS OVERWRITES storage_.warp_prefixes[warps_no_ - 1]
723  template<class PrefixCallback>
724  ROCPRIM_DEVICE ROCPRIM_INLINE
725  T get_block_prefix(const unsigned int flat_tid,
726  const unsigned int warp_id,
727  const T reduction,
728  PrefixCallback& prefix_callback_op,
729  storage_type& storage)
730  {
731  storage_type_& storage_ = storage.get();
732  if(warp_id == 0)
733  {
734  T block_prefix = prefix_callback_op(reduction);
735  if(flat_tid == 0)
736  {
737  // Reuse storage_.warp_prefixes[warps_no_ - 1] to store block prefix
738  storage_.warp_prefixes[warps_no_ - 1] = block_prefix;
739  }
740  }
742  return storage_.warp_prefixes[warps_no_ - 1];
743  }
744 };
745 
746 } // end namespace detail
747 
748 END_ROCPRIM_NAMESPACE
749 
750 #endif // ROCPRIM_BLOCK_DETAIL_BLOCK_SCAN_WARP_SCAN_HPP_
Definition: benchmark_block_scan.cpp:63
Definition: block_scan_warp_scan.hpp:45
ROCPRIM_DEVICE ROCPRIM_INLINE constexpr unsigned int device_warp_size()
Returns a number of threads in a hardware warp for the actual target.
Definition: thread.hpp:70
ROCPRIM_DEVICE ROCPRIM_INLINE T warp_shuffle_up(const T &input, const unsigned int delta, const int width=device_warp_size())
Shuffle up for any data type.
Definition: warp_shuffle.hpp:197
ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
Returns the minimum of its arguments.
Definition: functional.hpp:63
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
const unsigned int warp_id
Returns warp id in a block (tile).
Definition: benchmark_warp_exchange.cpp:153
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216
Definition: benchmark_block_scan.cpp:100
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int lane_id()
Returns thread identifier in a warp.
Definition: thread.hpp:93