rocPRIM
device_partition.hpp
1 // Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_PARTITION_HPP_
22 #define ROCPRIM_DEVICE_DETAIL_DEVICE_PARTITION_HPP_
23 
24 #include <type_traits>
25 #include <iterator>
26 
27 #include "../../detail/various.hpp"
28 #include "../../intrinsics.hpp"
29 #include "../../functional.hpp"
30 #include "../../types.hpp"
31 
32 #include "../../block/block_load.hpp"
33 #include "../../block/block_store.hpp"
34 #include "../../block/block_scan.hpp"
35 #include "../../block/block_discontinuity.hpp"
36 
37 #include "lookback_scan_state.hpp"
38 #include "rocprim/type_traits.hpp"
39 #include "rocprim/types/tuple.hpp"
40 
41 BEGIN_ROCPRIM_NAMESPACE
42 
43 namespace detail
44 {
45 
#ifndef DOXYGEN_SHOULD_SKIP_THIS
// Strategy used by the partition/select kernel to decide which items are selected.
enum class select_method
{
    // Selection is read from a separate sequence of boolean flags.
    flag = 0,
    // Selection is computed by applying a unary predicate to every item.
    predicate,
    // An item is selected when the inequality operator reports that it differs
    // from its predecessor (head of a run).
    unique
};
#endif // DOXYGEN_SHOULD_SKIP_THIS
54 
55 template<select_method SelectMethod,
56  unsigned int BlockSize,
57  class BlockLoadFlagsType,
58  class BlockDiscontinuityType,
59  class InputIterator,
60  class FlagIterator,
61  class ValueType,
62  unsigned int ItemsPerThread,
63  class UnaryPredicate,
64  class InequalityOp,
65  class StorageType>
66 ROCPRIM_DEVICE ROCPRIM_INLINE auto
67  partition_block_load_flags(InputIterator /* block_predecessor */,
68  FlagIterator block_flags,
69  ValueType (&/* values */)[ItemsPerThread],
70  bool (&is_selected)[ItemsPerThread],
71  UnaryPredicate /* predicate */,
72  InequalityOp /* inequality_op */,
73  StorageType& storage,
74  const bool /* is_first_block */,
75  const unsigned int /* block_thread_id */,
76  const bool is_global_last_block,
77  const unsigned int valid_in_global_last_block) ->
78  typename std::enable_if<SelectMethod == select_method::flag>::type
79 {
80  if(is_global_last_block) // last block
81  {
82  BlockLoadFlagsType().load(block_flags,
83  is_selected,
84  valid_in_global_last_block,
85  false,
86  storage.load_flags);
87  }
88  else
89  {
90  BlockLoadFlagsType()
91  .load(
92  block_flags,
93  is_selected,
94  storage.load_flags
95  );
96  }
97  ::rocprim::syncthreads(); // sync threads to reuse shared memory
98 }
99 
100 template<select_method SelectMethod,
101  unsigned int BlockSize,
102  class BlockLoadFlagsType,
103  class BlockDiscontinuityType,
104  class InputIterator,
105  class FlagIterator,
106  class ValueType,
107  unsigned int ItemsPerThread,
108  class UnaryPredicate,
109  class InequalityOp,
110  class StorageType>
111 ROCPRIM_DEVICE ROCPRIM_INLINE auto
112  partition_block_load_flags(InputIterator /* block_predecessor */,
113  FlagIterator /* block_flags */,
114  ValueType (&values)[ItemsPerThread],
115  bool (&is_selected)[ItemsPerThread],
116  UnaryPredicate predicate,
117  InequalityOp /* inequality_op */,
118  StorageType& /* storage */,
119  const bool /* is_first_block */,
120  const unsigned int block_thread_id,
121  const bool is_global_last_block,
122  const unsigned int valid_in_global_last_block) ->
123  typename std::enable_if<SelectMethod == select_method::predicate>::type
124 {
125  if(is_global_last_block) // last block
126  {
127  const auto offset = block_thread_id * ItemsPerThread;
128  ROCPRIM_UNROLL
129  for(unsigned int i = 0; i < ItemsPerThread; i++)
130  {
131  if((offset + i) < valid_in_global_last_block)
132  {
133  is_selected[i] = predicate(values[i]);
134  }
135  else
136  {
137  is_selected[i] = false;
138  }
139  }
140  }
141  else
142  {
143  ROCPRIM_UNROLL
144  for(unsigned int i = 0; i < ItemsPerThread; i++)
145  {
146  is_selected[i] = predicate(values[i]);
147  }
148  }
149 }
150 
151 // This wrapper processes only part of items and flags (valid_count - 1)th item (for tails)
152 // and (valid_count)th item (for heads), all items after valid_count are unflagged.
153 template<class InequalityOp>
155 {
156  InequalityOp inequality_op;
157  unsigned int valid_count;
158 
159  ROCPRIM_DEVICE ROCPRIM_INLINE
160  guarded_inequality_op(InequalityOp inequality_op, unsigned int valid_count)
161  : inequality_op(inequality_op), valid_count(valid_count)
162  {}
163 
164  template<class T, class U>
165  ROCPRIM_DEVICE ROCPRIM_INLINE
166  bool operator()(const T& a, const U& b, unsigned int b_index)
167  {
168  return (b_index < valid_count && inequality_op(a, b));
169  }
170 };
171 
172 template<select_method SelectMethod,
173  unsigned int BlockSize,
174  class BlockLoadFlagsType,
175  class BlockDiscontinuityType,
176  class InputIterator,
177  class FlagIterator,
178  class ValueType,
179  unsigned int ItemsPerThread,
180  class UnaryPredicate,
181  class InequalityOp,
182  class StorageType>
183 ROCPRIM_DEVICE ROCPRIM_INLINE auto
184  partition_block_load_flags(InputIterator block_predecessor,
185  FlagIterator /* block_flags */,
186  ValueType (&values)[ItemsPerThread],
187  bool (&is_selected)[ItemsPerThread],
188  UnaryPredicate /* predicate */,
189  InequalityOp inequality_op,
190  StorageType& storage,
191  const bool is_first_block,
192  const unsigned int block_thread_id,
193  const bool is_global_last_block,
194  const unsigned int valid_in_global_last_block) ->
195  typename std::enable_if<SelectMethod == select_method::unique>::type
196 {
197  if(is_first_block)
198  {
199  if(is_global_last_block)
200  {
201  BlockDiscontinuityType().flag_heads(
202  is_selected,
203  values,
204  guarded_inequality_op<InequalityOp>(inequality_op, valid_in_global_last_block),
205  storage.discontinuity_values);
206  }
207  else
208  {
209  BlockDiscontinuityType().flag_heads(is_selected,
210  values,
211  inequality_op,
212  storage.discontinuity_values);
213  }
214  }
215  else
216  {
217  const ValueType predecessor = block_predecessor[0];
218  if(is_global_last_block)
219  {
220  BlockDiscontinuityType().flag_heads(
221  is_selected,
222  predecessor,
223  values,
224  guarded_inequality_op<InequalityOp>(inequality_op, valid_in_global_last_block),
225  storage.discontinuity_values);
226  }
227  else
228  {
229  BlockDiscontinuityType().flag_heads(is_selected,
230  predecessor,
231  values,
232  inequality_op,
233  storage.discontinuity_values);
234  }
235  }
236 
237 
238  // Set is_selected for invalid items to false
239  if(is_global_last_block)
240  {
241  const auto offset = block_thread_id * ItemsPerThread;
242  ROCPRIM_UNROLL
243  for(unsigned int i = 0; i < ItemsPerThread; i++)
244  {
245  if((offset + i) >= valid_in_global_last_block)
246  {
247  is_selected[i] = false;
248  }
249  }
250  }
251  ::rocprim::syncthreads(); // sync threads to reuse shared memory
252 }
253 
254 template<select_method SelectMethod,
255  unsigned int BlockSize,
256  class BlockLoadFlagsType,
257  class BlockDiscontinuityType,
258  class InputIterator,
259  class FlagIterator,
260  class ValueType,
261  unsigned int ItemsPerThread,
262  class FirstUnaryPredicate,
263  class SecondUnaryPredicate,
264  class InequalityOp,
265  class StorageType>
266 ROCPRIM_DEVICE ROCPRIM_INLINE void
267  partition_block_load_flags(InputIterator /*block_predecessor*/,
268  FlagIterator /* block_flags */,
269  ValueType (&values)[ItemsPerThread],
270  bool (&is_selected)[2][ItemsPerThread],
271  FirstUnaryPredicate select_first_part_op,
272  SecondUnaryPredicate select_second_part_op,
273  InequalityOp /*inequality_op*/,
274  StorageType& /*storage*/,
275  const unsigned int /*block_id*/,
276  const unsigned int block_thread_id,
277  const bool is_global_last_block,
278  const unsigned int valid_in_global_last_block)
279 {
280  if(is_global_last_block)
281  {
282  const auto offset = block_thread_id * ItemsPerThread;
283  ROCPRIM_UNROLL
284  for(unsigned int i = 0; i < ItemsPerThread; i++)
285  {
286  if((offset + i) < valid_in_global_last_block)
287  {
288  is_selected[0][i] = select_first_part_op(values[i]);
289  is_selected[1][i] = !is_selected[0][i] && select_second_part_op(values[i]);
290  }
291  else
292  {
293  is_selected[0][i] = false;
294  is_selected[1][i] = false;
295  }
296  }
297  }
298  else
299  {
300  ROCPRIM_UNROLL
301  for(unsigned int i = 0; i < ItemsPerThread; i++)
302  {
303  is_selected[0][i] = select_first_part_op(values[i]);
304  is_selected[1][i] = !is_selected[0][i] && select_second_part_op(values[i]);
305  }
306  }
307 }
308 
309 // two-way partition into one iterator
310 template<bool OnlySelected,
311  unsigned int BlockSize,
312  class ValueType,
313  unsigned int ItemsPerThread,
314  class OffsetType,
315  class SelectType,
316  class ScatterStorageType>
317 ROCPRIM_DEVICE ROCPRIM_INLINE auto
318  partition_scatter(ValueType (&values)[ItemsPerThread],
319  bool (&is_selected)[ItemsPerThread],
320  OffsetType (&output_indices)[ItemsPerThread],
322  const size_t total_size,
323  const OffsetType selected_prefix,
324  const OffsetType selected_in_block,
325  ScatterStorageType& storage,
326  const unsigned int flat_block_id,
327  const unsigned int flat_block_thread_id,
328  const bool is_global_last_block,
329  const unsigned int valid_in_global_last_block,
330  size_t (&prev_selected_count_values)[1],
331  size_t prev_processed) -> typename std::enable_if<!OnlySelected>::type
332 {
333  constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
334 
335  // Scatter selected/rejected values to shared memory
336  auto scatter_storage = storage.get();
337  ROCPRIM_UNROLL
338  for(unsigned int i = 0; i < ItemsPerThread; i++)
339  {
340  unsigned int item_index = (flat_block_thread_id * ItemsPerThread) + i;
341  unsigned int selected_item_index = output_indices[i] - selected_prefix;
342  unsigned int rejected_item_index = (item_index - selected_item_index) + selected_in_block;
343  // index of item in scatter_storage
344  unsigned int scatter_index = is_selected[i] ? selected_item_index : rejected_item_index;
345  scatter_storage[scatter_index] = values[i];
346  }
347  ::rocprim::syncthreads(); // sync threads to reuse shared memory
348 
349  ValueType reloaded_values[ItemsPerThread];
350  for(unsigned int i = 0; i < ItemsPerThread; i++)
351  {
352  const unsigned int item_index = i * BlockSize + flat_block_thread_id;
353  reloaded_values[i] = scatter_storage[item_index];
354  }
355 
356  const auto calculate_scatter_index = [=](const unsigned int item_index) -> size_t
357  {
358  const size_t selected_output_index = prev_selected_count_values[0] + selected_prefix;
359  const size_t rejected_output_index = total_size + selected_output_index - prev_processed
360  - flat_block_id * items_per_block + selected_in_block
361  - 1;
362  return item_index < selected_in_block ? selected_output_index + item_index
363  : rejected_output_index - item_index;
364  };
365  if(is_global_last_block)
366  {
367  for(unsigned int i = 0; i < ItemsPerThread; i++)
368  {
369  const unsigned int item_index = i * BlockSize + flat_block_thread_id;
370  if(item_index < valid_in_global_last_block)
371  {
372  get<0>(output)[calculate_scatter_index(item_index)] = reloaded_values[i];
373  }
374  }
375  }
376  else
377  {
378  for(unsigned int i = 0; i < ItemsPerThread; i++)
379  {
380  const unsigned int item_index = i * BlockSize + flat_block_thread_id;
381  get<0>(output)[calculate_scatter_index(item_index)] = reloaded_values[i];
382  }
383  }
384 }
385 
386 // two-way partition into two iterators
387 template<bool OnlySelected,
388  unsigned int BlockSize,
389  class ValueType,
390  unsigned int ItemsPerThread,
391  class OffsetType,
392  class SelectType,
393  class RejectType,
394  class ScatterStorageType>
395 ROCPRIM_DEVICE ROCPRIM_INLINE auto partition_scatter(ValueType (&values)[ItemsPerThread],
396  bool (&is_selected)[ItemsPerThread],
397  OffsetType (&output_indices)[ItemsPerThread],
399  const size_t /*total_size*/,
400  const OffsetType selected_prefix,
401  const OffsetType selected_in_block,
402  ScatterStorageType& storage,
403  const unsigned int flat_block_id,
404  const unsigned int flat_block_thread_id,
405  const bool is_global_last_block,
406  const unsigned int valid_in_global_last_block,
407  size_t (&prev_selected_count_values)[1],
408  size_t prev_processed) ->
409  typename std::enable_if<!OnlySelected>::type
410 {
411  constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
412 
413  // Scatter selected/rejected values to shared memory
414  auto scatter_storage = storage.get();
415  ROCPRIM_UNROLL
416  for(unsigned int i = 0; i < ItemsPerThread; i++)
417  {
418  unsigned int item_index = (flat_block_thread_id * ItemsPerThread) + i;
419  unsigned int selected_item_index = output_indices[i] - selected_prefix;
420  unsigned int rejected_item_index = (item_index - selected_item_index) + selected_in_block;
421  // index of item in scatter_storage
422  unsigned int scatter_index = is_selected[i] ? selected_item_index : rejected_item_index;
423  scatter_storage[scatter_index] = values[i];
424  }
425  ::rocprim::syncthreads(); // sync threads to reuse shared memory
426 
427  ValueType reloaded_values[ItemsPerThread];
428  for(unsigned int i = 0; i < ItemsPerThread; i++)
429  {
430  const unsigned int item_index = i * BlockSize + flat_block_thread_id;
431  reloaded_values[i] = scatter_storage[item_index];
432  }
433 
434  auto save_to_output = [=](const unsigned int item_index, const unsigned int i)
435  {
436  const size_t selected_output_index = prev_selected_count_values[0] + selected_prefix;
437  const size_t rejected_output_index = prev_processed + flat_block_id * items_per_block
438  - selected_output_index - selected_in_block;
439 
440  if(item_index < selected_in_block)
441  {
442  get<0>(output)[selected_output_index + item_index] = reloaded_values[i];
443  }
444  else
445  {
446  get<1>(output)[rejected_output_index + item_index] = reloaded_values[i];
447  }
448  };
449 
450  if(is_global_last_block)
451  {
452  for(unsigned int i = 0; i < ItemsPerThread; i++)
453  {
454  const unsigned int item_index = i * BlockSize + flat_block_thread_id;
455  if(item_index < valid_in_global_last_block)
456  {
457  save_to_output(item_index, i);
458  }
459  }
460  }
461  else
462  {
463  for(unsigned int i = 0; i < ItemsPerThread; i++)
464  {
465  const unsigned int item_index = i * BlockSize + flat_block_thread_id;
466  save_to_output(item_index, i);
467  }
468  }
469 }
470 
471 // two-way partition, selection only
472 template<bool OnlySelected,
473  unsigned int BlockSize,
474  class ValueType,
475  unsigned int ItemsPerThread,
476  class OffsetType,
477  class SelectType,
478  class RejectType,
479  class ScatterStorageType>
480 ROCPRIM_DEVICE ROCPRIM_INLINE auto
481  partition_scatter(ValueType (&values)[ItemsPerThread],
482  bool (&is_selected)[ItemsPerThread],
483  OffsetType (&output_indices)[ItemsPerThread],
485  const size_t /*total_size*/,
486  const OffsetType selected_prefix,
487  const OffsetType selected_in_block,
488  ScatterStorageType& storage,
489  const unsigned int /*flat_block_id*/,
490  const unsigned int flat_block_thread_id,
491  const bool is_global_last_block,
492  const unsigned int /*valid_in_global_last_block*/,
493  size_t (&prev_selected_count_values)[1],
494  size_t /*prev_processed*/) -> typename std::enable_if<OnlySelected>::type
495 {
496  if(selected_in_block > BlockSize)
497  {
498  // Scatter selected values to shared memory
499  auto scatter_storage = storage.get();
500  ROCPRIM_UNROLL
501  for(unsigned int i = 0; i < ItemsPerThread; i++)
502  {
503  unsigned int scatter_index = output_indices[i] - selected_prefix;
504  if(is_selected[i])
505  {
506  scatter_storage[scatter_index] = values[i];
507  }
508  }
509  ::rocprim::syncthreads(); // sync threads to reuse shared memory
510 
511  // Coalesced write from shared memory to global memory
512  for(unsigned int i = flat_block_thread_id; i < selected_in_block; i += BlockSize)
513  {
514  get<0>(output)[prev_selected_count_values[0] + selected_prefix + i]
515  = scatter_storage[i];
516  }
517  }
518  else
519  {
520  ROCPRIM_UNROLL
521  for(unsigned int i = 0; i < ItemsPerThread; i++)
522  {
523  if(!is_global_last_block || output_indices[i] < (selected_prefix + selected_in_block))
524  {
525  if(is_selected[i])
526  {
527  get<0>(output)[prev_selected_count_values[0] + output_indices[i]] = values[i];
528  }
529  }
530  }
531  }
532 }
533 
534 // three-way partition
535 template<bool OnlySelected,
536  unsigned int BlockSize,
537  class ValueType,
538  unsigned int ItemsPerThread,
539  class OffsetType,
540  class OutputType,
541  class ScatterStorageType>
542 ROCPRIM_DEVICE ROCPRIM_INLINE void partition_scatter(ValueType (&values)[ItemsPerThread],
543  bool (&is_selected)[2][ItemsPerThread],
544  OffsetType (&output_indices)[ItemsPerThread],
545  OutputType output,
546  const size_t /*total_size*/,
547  const OffsetType selected_prefix,
548  const OffsetType selected_in_block,
549  ScatterStorageType& storage,
550  const unsigned int flat_block_id,
551  const unsigned int flat_block_thread_id,
552  const bool is_global_last_block,
553  const unsigned int valid_in_global_last_block,
554  size_t (&prev_selected_count_values)[2],
555  size_t prev_processed)
556 {
557  constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
558  auto scatter_storage = storage.get();
559  const size_t first_selected_prefix = prev_selected_count_values[0] + selected_prefix.x;
560  const size_t second_selected_prefix
561  = prev_selected_count_values[1] - selected_in_block.x + selected_prefix.y;
562  const size_t unselected_prefix = prev_processed - first_selected_prefix - second_selected_prefix
563  + items_per_block * flat_block_id - 2 * selected_in_block.x
564  - selected_in_block.y;
565 
566  ROCPRIM_UNROLL
567  for(unsigned int i = 0; i < ItemsPerThread; i++)
568  {
569  const unsigned int first_selected_item_index = output_indices[i].x - selected_prefix.x;
570  const unsigned int second_selected_item_index = output_indices[i].y - selected_prefix.y
571  + selected_in_block.x;
572  unsigned int scatter_index{};
573 
574  if(is_selected[0][i])
575  {
576  scatter_index = first_selected_item_index;
577  }
578  else if(is_selected[1][i])
579  {
580  scatter_index = second_selected_item_index;
581  }
582  else
583  {
584  const unsigned int item_index = (flat_block_thread_id * ItemsPerThread) + i;
585  const unsigned int unselected_item_index = (item_index - first_selected_item_index - second_selected_item_index)
586  + 2*selected_in_block.x + selected_in_block.y;
587  scatter_index = unselected_item_index;
588  }
589  scatter_storage[scatter_index] = values[i];
590  }
592 
593  auto save_to_output = [=](const unsigned int item_index) mutable
594  {
595  if(item_index < selected_in_block.x)
596  {
597  const size_t first_selected_index = first_selected_prefix + item_index;
598  get<0>(output)[first_selected_index] = scatter_storage[item_index];
599  }
600  else if(item_index < selected_in_block.x + selected_in_block.y)
601  {
602  const size_t second_selected_index = second_selected_prefix + item_index;
603  get<1>(output)[second_selected_index] = scatter_storage[item_index];
604  }
605  else
606  {
607  const size_t unselected_index = unselected_prefix + item_index;
608  get<2>(output)[unselected_index] = scatter_storage[item_index];
609  }
610  };
611 
612  if(is_global_last_block)
613  {
614  for(unsigned int i = 0; i < ItemsPerThread; i++)
615  {
616  const unsigned int item_index = (i * BlockSize) + flat_block_thread_id;
617  if(item_index < valid_in_global_last_block)
618  {
619  save_to_output(item_index);
620  }
621  }
622  }
623  else
624  {
625  for(unsigned int i = 0; i < ItemsPerThread; i++)
626  {
627  const unsigned int item_index = (i * BlockSize) + flat_block_thread_id;
628  save_to_output(item_index);
629  }
630  }
631 }
632 
633 template<
634  unsigned int items_per_thread,
635  class offset_type
636 >
637 ROCPRIM_DEVICE ROCPRIM_INLINE
638 void convert_selected_to_indices(offset_type (&output_indices)[items_per_thread],
639  bool (&is_selected)[items_per_thread])
640 {
641  ROCPRIM_UNROLL
642  for(unsigned int i = 0; i < items_per_thread; i++)
643  {
644  output_indices[i] = is_selected[i] ? 1 : 0;
645  }
646 }
647 
648 template<
649  unsigned int items_per_thread
650 >
651 ROCPRIM_DEVICE ROCPRIM_INLINE
652 void convert_selected_to_indices(uint2 (&output_indices)[items_per_thread],
653  bool (&is_selected)[2][items_per_thread])
654 {
655  ROCPRIM_UNROLL
656  for(unsigned int i = 0; i < items_per_thread; i++)
657  {
658  output_indices[i].x = is_selected[0][i] ? 1 : 0;
659  output_indices[i].y = is_selected[1][i] ? 1 : 0;
660  }
661 }
662 
663 template<class OffsetT>
664 ROCPRIM_DEVICE ROCPRIM_INLINE void store_selected_count(size_t* selected_count,
665  size_t (&prev_selected_count_values)[1],
666  const OffsetT selected_prefix,
667  const OffsetT selected_in_block)
668 {
669  selected_count[0] = prev_selected_count_values[0] + selected_prefix + selected_in_block;
670 }
671 
672 ROCPRIM_DEVICE ROCPRIM_INLINE void store_selected_count(size_t* selected_count,
673  size_t (&prev_selected_count_values)[2],
674  const uint2 selected_prefix,
675  const uint2 selected_in_block)
676 {
677  selected_count[0] = prev_selected_count_values[0] + selected_prefix.x + selected_in_block.x;
678  selected_count[1] = prev_selected_count_values[1] + selected_prefix.y + selected_in_block.y;
679 }
680 
681 template<unsigned int Size>
682 ROCPRIM_DEVICE void load_selected_count(const size_t* const prev_selected_count,
683  size_t (&loaded_values)[Size])
684 {
685  for(unsigned int i = 0; i < Size; ++i)
686  {
687  loaded_values[i] = prev_selected_count[i];
688  }
689 }
690 
691 template<select_method SelectMethod,
692  bool OnlySelected,
693  class Config,
694  class KeyIterator,
695  class ValueIterator, // Can be rocprim::empty_type* if key only
696  class FlagIterator,
697  class OutputKeyIterator,
698  class OutputValueIterator,
699  class InequalityOp,
700  class OffsetLookbackScanState,
701  class... UnaryPredicates>
702 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void
703  partition_kernel_impl(KeyIterator keys_input,
704  ValueIterator values_input,
705  FlagIterator flags,
706  OutputKeyIterator keys_output,
707  OutputValueIterator values_output,
708  size_t* selected_count,
709  size_t* prev_selected_count,
710  size_t prev_processed,
711  const size_t total_size,
712  InequalityOp inequality_op,
713  OffsetLookbackScanState offset_scan_state,
714  const unsigned int number_of_blocks,
715  UnaryPredicates... predicates)
716 {
717  constexpr auto block_size = Config::block_size;
718  constexpr auto items_per_thread = Config::items_per_thread;
719  constexpr unsigned int items_per_block = block_size * items_per_thread;
720 
721  using offset_type = typename OffsetLookbackScanState::value_type;
722  using key_type = typename std::iterator_traits<KeyIterator>::value_type;
723  using value_type = typename std::iterator_traits<ValueIterator>::value_type;
724 
725  // Block primitives
726  using block_load_key_type = ::rocprim::block_load<
727  key_type, block_size, items_per_thread,
728  Config::key_block_load_method
729  >;
730  using block_load_value_type = ::rocprim::block_load<
731  value_type, block_size, items_per_thread,
732  Config::value_block_load_method
733  >;
734  using block_load_flag_type = ::rocprim::block_load<
735  bool, block_size, items_per_thread,
736  Config::flag_block_load_method
737  >;
738  using block_scan_offset_type = ::rocprim::block_scan<
739  offset_type, block_size,
740  Config::block_scan_method
741  >;
742  using block_discontinuity_key_type = ::rocprim::block_discontinuity<key_type, block_size>;
743 
744  // Offset prefix operation type
745  using offset_scan_prefix_op_type = offset_lookback_scan_prefix_op<
746  offset_type, OffsetLookbackScanState
747  >;
748 
749  // Memory required for 2-phase scatter
750  using exchange_keys_storage_type = key_type[items_per_block];
751  using raw_exchange_keys_storage_type = typename detail::raw_storage<exchange_keys_storage_type>;
752  using exchange_values_storage_type = value_type[items_per_block];
753  using raw_exchange_values_storage_type = typename detail::raw_storage<exchange_values_storage_type>;
754 
755  using is_selected_type = std::conditional_t<
756  sizeof...(UnaryPredicates) == 1,
757  bool[items_per_thread],
758  bool[sizeof...(UnaryPredicates)][items_per_thread]>;
759 
760  ROCPRIM_SHARED_MEMORY union
761  {
762  raw_exchange_keys_storage_type exchange_keys;
763  raw_exchange_values_storage_type exchange_values;
764  typename block_load_key_type::storage_type load_keys;
765  typename block_load_value_type::storage_type load_values;
766  typename block_load_flag_type::storage_type load_flags;
767  typename block_discontinuity_key_type::storage_type discontinuity_values;
768  typename block_scan_offset_type::storage_type scan_offsets;
769  } storage;
770 
771  size_t prev_selected_count_values[sizeof...(UnaryPredicates)]{};
772  load_selected_count(prev_selected_count, prev_selected_count_values);
773 
774  const auto flat_block_thread_id = ::rocprim::detail::block_thread_id<0>();
775  const auto flat_block_id = ::rocprim::detail::block_id<0>();
776  const auto block_offset = flat_block_id * items_per_block;
777  const unsigned int valid_in_global_last_block
778  = total_size - prev_processed - items_per_block * (number_of_blocks - 1);
779  const bool is_last_launch = total_size <= prev_processed + number_of_blocks * items_per_block;
780  const bool is_global_last_block = is_last_launch && flat_block_id == (number_of_blocks - 1);
781 
782  key_type keys[items_per_thread];
783  is_selected_type is_selected;
784  offset_type output_indices[items_per_thread];
785 
786  // Load input values into values
787  if(is_global_last_block)
788  {
789  block_load_key_type().load(keys_input + block_offset,
790  keys,
791  valid_in_global_last_block,
792  storage.load_keys);
793  }
794  else
795  {
796  block_load_key_type()
797  .load(
798  keys_input + block_offset,
799  keys,
800  storage.load_keys
801  );
802  }
803  ::rocprim::syncthreads(); // sync threads to reuse shared memory
804 
805  // Load selection flags into is_selected, generate them using
806  // input value and selection predicate, or generate them using
807  // block_discontinuity primitive
808  const bool is_first_block = flat_block_id == 0 && prev_processed == 0;
809  partition_block_load_flags<SelectMethod,
810  block_size,
811  block_load_flag_type,
812  block_discontinuity_key_type>(keys_input + block_offset - 1,
813  flags + block_offset,
814  keys,
815  is_selected,
816  predicates...,
817  inequality_op,
818  storage,
819  is_first_block,
821  is_global_last_block,
822  valid_in_global_last_block);
823 
824  // Convert true/false is_selected flags to 0s and 1s
825  convert_selected_to_indices(output_indices, is_selected);
826 
827  // Number of selected values in previous blocks
828  offset_type selected_prefix{};
829  // Number of selected values in this block
830  offset_type selected_in_block{};
831 
832  // Calculate number of selected values in block and their indices
833  if(flat_block_id == 0)
834  {
835  block_scan_offset_type()
836  .exclusive_scan(
837  output_indices,
838  output_indices,
839  offset_type{},
840  selected_in_block,
841  storage.scan_offsets,
842  ::rocprim::plus<offset_type>()
843  );
844  if(flat_block_thread_id == 0)
845  {
846  offset_scan_state.set_complete(flat_block_id, selected_in_block);
847  }
848  ::rocprim::syncthreads(); // sync threads to reuse shared memory
849  }
850  else
851  {
852  ROCPRIM_SHARED_MEMORY typename offset_scan_prefix_op_type::storage_type storage_prefix_op;
853  auto prefix_op = offset_scan_prefix_op_type(
854  flat_block_id,
855  offset_scan_state,
856  storage_prefix_op
857  );
858  block_scan_offset_type()
859  .exclusive_scan(
860  output_indices,
861  output_indices,
862  storage.scan_offsets,
863  prefix_op,
864  ::rocprim::plus<offset_type>()
865  );
866  ::rocprim::syncthreads(); // sync threads to reuse shared memory
867 
868  selected_in_block = prefix_op.get_reduction();
869  selected_prefix = prefix_op.get_prefix();
870  }
871 
872  // Scatter selected and rejected values
873  partition_scatter<OnlySelected, block_size>(keys,
874  is_selected,
875  output_indices,
876  keys_output,
877  total_size,
878  selected_prefix,
879  selected_in_block,
880  storage.exchange_keys,
883  is_global_last_block,
884  valid_in_global_last_block,
885  prev_selected_count_values,
886  prev_processed);
887 
888  static constexpr bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;
889 
890  if ROCPRIM_IF_CONSTEXPR (with_values) {
891  value_type values[items_per_thread];
892 
893  ::rocprim::syncthreads(); // sync threads to reuse shared memory
894  if(is_global_last_block)
895  {
896  block_load_value_type().load(values_input + block_offset,
897  values,
898  valid_in_global_last_block,
899  storage.load_values);
900  }
901  else
902  {
903  block_load_value_type()
904  .load(
905  values_input + block_offset,
906  values,
907  storage.load_values
908  );
909  }
910  ::rocprim::syncthreads(); // sync threads to reuse shared memory
911 
912  partition_scatter<OnlySelected, block_size>(values,
913  is_selected,
914  output_indices,
915  values_output,
916  total_size,
917  selected_prefix,
918  selected_in_block,
919  storage.exchange_values,
922  is_global_last_block,
923  valid_in_global_last_block,
924  prev_selected_count_values,
925  prev_processed);
926  }
927 
928  // Last block in grid stores number of selected values
929  const bool is_last_block = flat_block_id == (number_of_blocks - 1);
930  if(is_last_block && flat_block_thread_id == 0)
931  {
932  store_selected_count(selected_count,
933  prev_selected_count_values,
934  selected_prefix,
935  selected_in_block);
936  }
937 }
938 
939 } // end of detail namespace
940 
941 END_ROCPRIM_NAMESPACE
942 
943 #endif // ROCPRIM_DEVICE_DETAIL_DEVICE_PARTITION_HPP_
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int flat_block_thread_id()
Returns flat (linear, 1D) thread identifier in a multidimensional block (tile).
Definition: thread.hpp:106
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int block_thread_id()
Returns thread identifier in a multidimensional block (tile) by dimension.
Definition: thread.hpp:248
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int flat_block_id()
Returns flat (linear, 1D) block identifier in a multidimensional grid.
Definition: thread.hpp:178
Fixed-size collection of heterogeneous values.
Definition: tuple.hpp:41
Definition: various.hpp:180
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int block_size()
Returns block size in a multidimensional grid by dimension.
Definition: thread.hpp:268
Definition: device_partition.hpp:154
Definition: lookback_scan_state.hpp:515
hipError_t unique(void *temporary_storage, size_t &storage_size, InputIterator input, OutputIterator output, UniqueCountOutputIterator unique_count_output, const size_t size, EqualityOp equality_op=EqualityOp(), const hipStream_t stream=0, const bool debug_synchronous=false)
Device-level parallel unique primitive.
Definition: device_select.hpp:383