rocPRIM
device_merge_sort.hpp
1 // Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_SORT_HPP_
22 #define ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_SORT_HPP_
23 
24 #include <type_traits>
25 #include <iterator>
26 
27 #include "../../config.hpp"
28 #include "../../detail/various.hpp"
29 
30 #include "../../intrinsics.hpp"
31 #include "../../functional.hpp"
32 #include "../../types.hpp"
33 
34 #include "../../block/block_load.hpp"
35 #include "../../block/block_load_func.hpp"
36 #include "../../block/block_sort.hpp"
37 #include "../../block/block_store.hpp"
38 
39 BEGIN_ROCPRIM_NAMESPACE
40 
41 namespace detail
42 {
43 
// block_store_impl (primary template): writes one block's keys to keys_output + block_offset.
// In store() the values iterator and the values array are unnamed (commented out), so this
// template never writes values; WithValues == true is handled by the specialization that
// follows in this file.
// NOTE(review): this rendering skips source lines 51, 53, 68, 72 and 81. Because of that the
// struct head, the definition of block_store_type, the statement after the "Synchronize"
// comment, and the heads of the two store calls are not visible here.
44 template<
45  bool WithValues,
46  unsigned int BlockSize,
47  unsigned int ItemsPerThread,
48  class Key,
49  class Value
50 >
52  using block_store_type
54 
    // The temporary storage required by this store is exactly that of block_store_type.
55  using storage_type = typename block_store_type::storage_type;
56 
    // Parameters:
    //   block_offset        - offset of this block's first item in keys_output.
    //   valid_in_last_block - number of valid items; passed only on the incomplete-block path.
    //   is_incomplete_block - selects the call that receives valid_in_last_block.
    //   keys_output         - destination iterator for keys.
    //   keys                - this thread's keys.
    //   storage             - temporary storage forwarded to the store call.
57  template<class KeysOutputIterator, class ValuesOutputIterator, class OffsetT>
58  ROCPRIM_DEVICE ROCPRIM_INLINE void store(const OffsetT block_offset,
59  const unsigned int valid_in_last_block,
60  const bool is_incomplete_block,
61  KeysOutputIterator keys_output,
62  ValuesOutputIterator /*values_output*/,
63  Key (&keys)[ItemsPerThread],
64  Value (&/*values*/)[ItemsPerThread],
65  storage_type& storage)
66  {
67  // Synchronize before reusing shared memory
69 
        // Incomplete (last) block: the call also receives valid_in_last_block.
        // Otherwise the call is made without an item count.
70  if(is_incomplete_block)
71  {
73  keys_output + block_offset,
74  keys,
75  valid_in_last_block,
76  storage
77  );
78  }
79  else
80  {
82  keys_output + block_offset,
83  keys,
84  storage
85  );
86  }
87  }
88 };
89 
// block_store_impl specialization for WithValues == true. It stores the block's keys to
// keys_output + block_offset, then its values to values_output + block_offset. The keys store
// uses storage.keys and the values store uses storage.values, which are members of the union
// storage_type.
// NOTE(review): this rendering skips source lines 97, 98, 101, 117, 128, 130, 145 and 147.
// Not visible here: the definitions of block_store_key_type / block_store_value_type, the
// first member of the union (presumably `keys`, since storage.keys is used below), the
// statement after each "Synchronize" comment, and the heads of the value store calls.
90 template<
91  unsigned int BlockSize,
92  unsigned int ItemsPerThread,
93  class Key,
94  class Value
95 >
96 struct block_store_impl<true, BlockSize, ItemsPerThread, Key, Value> {
99 
    // Keys and values storage share memory (union); the two stores run one after the other.
100  union storage_type {
102  typename block_store_value_type::storage_type values;
103  };
104 
    // Same parameters as the primary template's store(), plus values_output / values.
    // valid_in_last_block is passed to both calls on the incomplete-block path only.
105  template <class KeysOutputIterator, class ValuesOutputIterator, class OffsetT>
106  ROCPRIM_DEVICE ROCPRIM_INLINE
107  void store(const OffsetT block_offset,
108  const unsigned int valid_in_last_block,
109  const bool is_incomplete_block,
110  KeysOutputIterator keys_output,
111  ValuesOutputIterator values_output,
112  Key (&keys)[ItemsPerThread],
113  Value (&values)[ItemsPerThread],
114  storage_type& storage)
115  {
116  // Synchronize before reusing shared memory
118 
119  if(is_incomplete_block)
120  {
121  block_store_key_type().store(
122  keys_output + block_offset,
123  keys,
124  valid_in_last_block,
125  storage.keys
126  );
127 
129 
131  values_output + block_offset,
132  values,
133  valid_in_last_block,
134  storage.values
135  );
136  }
137  else
138  {
139  block_store_key_type().store(
140  keys_output + block_offset,
141  keys,
142  storage.keys
143  );
144 
146 
148  values_output + block_offset,
149  values,
150  storage.values
151  );
152  }
153  }
154 };
155 
// block_permute_values_impl (generic version): writes values in the order given by `ranks`.
// `ranks[i]` is the in-block input index of the value that belongs at this thread's i-th item.
// The block_sort_impl callers obtain it from the sorted (key, original index) pairs.
// Each permute() does the following:
//   1. Synchronize.
//   2. Load values with block_load_direct_striped.
//   3. Call values_exchange_type::gather_from_striped with `ranks`.
//   4. Synchronize.
//   5. Store through values_store_type.
// NOTE(review): this rendering skips source lines 160, 162, 165 and 167. The struct head, the
// definitions of values_exchange_type / values_store_type and the head of storage_type are
// not visible, so whether storage_type is a struct or a union cannot be told from here.
156 template<typename Value,
157  unsigned int BlockSize,
158  unsigned int ItemsPerThread,
159  typename Enable = void>
161 {
163 
164  using values_store_type
166 
168  {
169  typename values_exchange_type::storage_type exchange;
170  typename values_store_type::storage_type store;
171  };
172 
    // Full-block version: all BlockSize * ItemsPerThread values are loaded and stored.
173  template<typename ValuesInputIterator, typename ValuesOutputIterator>
174  ROCPRIM_DEVICE void permute(unsigned int (&ranks)[ItemsPerThread],
175  ValuesInputIterator values_input,
176  ValuesOutputIterator values_output,
177  storage_type& storage)
178  {
179  syncthreads();
180  const auto flat_id = block_thread_id<0>();
181  Value values[ItemsPerThread];
182  block_load_direct_striped<BlockSize>(flat_id, values_input, values);
183  values_exchange_type().gather_from_striped(values, values, ranks, storage.exchange);
        // Barrier before the store reuses storage (storage.store may alias storage.exchange
        // if storage_type is a union; its head is not visible here).
184  syncthreads();
185  values_store_type().store(values_output, values, storage.store);
186  }
187 
    // Partial-block version: valid_in_last_block bounds both the load and the store.
    // The template parameters are listed in the opposite order to the overload above.
    // This is harmless because both are deduced from the arguments.
188  template<typename ValuesOutputIterator, typename ValuesInputIterator>
189  ROCPRIM_DEVICE void permute(unsigned int (&ranks)[ItemsPerThread],
190  ValuesInputIterator values_input,
191  ValuesOutputIterator values_output,
192  const unsigned int valid_in_last_block,
193  storage_type& storage)
194  {
195  syncthreads();
196  const auto flat_id = block_thread_id<0>();
197  Value values[ItemsPerThread];
198  block_load_direct_striped<BlockSize>(flat_id, values_input, values, valid_in_last_block);
199  values_exchange_type().gather_from_striped(values, values, ranks, storage.exchange);
200  syncthreads();
201  values_store_type().store(values_output, values, valid_in_last_block, storage.store);
202  }
203 };
204 
// block_permute_values_impl for rocprim::empty_type values (keys-only sort).
// Both permute() overloads are no-ops; the casts to void only silence unused-parameter
// warnings.
// NOTE(review): source line 208, presumably the storage_type definition, is missing from this
// rendering.
205 template<unsigned int BlockSize, unsigned int ItemsPerThread>
206 struct block_permute_values_impl<rocprim::empty_type, BlockSize, ItemsPerThread>
207 {
209 
    // Full-block overload: does nothing.
210  template<typename ValuesInputIterator, typename ValuesOutputIterator>
211  ROCPRIM_DEVICE void permute(unsigned int (&ranks)[ItemsPerThread],
212  ValuesInputIterator values_input,
213  ValuesOutputIterator values_output,
214  storage_type& storage)
215  {
216  (void)ranks;
217  (void)values_input;
218  (void)values_output;
219  (void)storage;
220  }
221 
    // Partial-block overload: does nothing.
222  template<typename ValuesOutputIterator, typename ValuesInputIterator>
223  ROCPRIM_DEVICE void permute(unsigned int (&ranks)[ItemsPerThread],
224  ValuesInputIterator values_input,
225  ValuesOutputIterator values_output,
226  const unsigned int valid_in_last_block,
227  storage_type& storage)
228  {
229  (void)ranks;
230  (void)values_input;
231  (void)values_output;
232  (void)valid_in_last_block;
233  (void)storage;
234  }
235 };
236 
237 // The specialization below exists because the compiler creates slow code for
238 // ValueTypes with misaligned datastructures in them (e.g. custom_char_double)
239 // when storing/loading those ValueTypes to/from registers.
240 // Thus this is a temporary workaround.
241 // TODO: Check if also the case for small types like this.
// It applies to trivially copyable Value types that are neither floating point nor integral.
// Instead of register-level exchanges it copies values through the `values` array of
// storage_type_:
//   - Every thread copies values_input[BlockSize * item + flat_id] into the same array index.
//   - After a barrier, every thread writes values_output[ItemsPerThread * flat_id + item]
//     from array index ranks[item].
// When called from block_sort_kernel_impl, this storage lives in the ROCPRIM_SHARED_MEMORY
// storage declared there.
// NOTE(review): this rendering skips source lines 243 and 257. Line 243 is the head of the
// specialization. Line 257 presumably defines storage_type as a wrapper whose get() returns a
// storage_type_.
242 template<typename Value, unsigned int BlockSize, unsigned int ItemsPerThread>
244  BlockSize,
245  ItemsPerThread,
246  std::enable_if_t<(std::is_trivially_copyable<Value>::value
247  && !rocprim::is_floating_point<Value>::value
248  && !std::is_integral<Value>::value)>>
249 {
250  static constexpr unsigned int items_per_block = ItemsPerThread * BlockSize;
251 
    // One slot per item of the block.
252  struct storage_type_
253  {
254  Value values[items_per_block];
255  };
256 
258 
    // Full-block overload.
259  template<typename ValuesInputIterator, typename ValuesOutputIterator>
260  ROCPRIM_DEVICE void permute(unsigned int (&ranks)[ItemsPerThread],
261  ValuesInputIterator values_input,
262  ValuesOutputIterator values_output,
263  storage_type& storage_)
264  {
        // Barrier so that previous users of the shared storage have finished.
265  syncthreads();
266  auto& values_shared = storage_.get().values;
267  const auto flat_id = block_thread_id<0>();
268 
        // Copy the block's values into the array using a strided index (BlockSize * item + flat_id).
269  ROCPRIM_UNROLL
270  for(unsigned int item = 0; item < ItemsPerThread; ++item)
271  {
272  const unsigned int idx = BlockSize * item + flat_id;
273  values_shared[idx] = values_input[idx];
274  }
275 
276  syncthreads();
277 
        // Each thread writes ItemsPerThread consecutive outputs, gathering by rank.
278  ROCPRIM_UNROLL
279  for(unsigned int item = 0; item < ItemsPerThread; ++item)
280  {
281  values_output[ItemsPerThread * flat_id + item] = values_shared[ranks[item]];
282  }
283  }
284 
    // Partial-block overload: only indices below valid_in_last_block are read and written.
    // The gather assumes that ranks[item] < valid_in_last_block for every valid output slot;
    // otherwise it would read an array slot that was not filled. The block_sort_impl callers
    // sort out-of-bounds items last.
285  template<typename ValuesOutputIterator, typename ValuesInputIterator>
286  ROCPRIM_DEVICE void permute(unsigned int (&ranks)[ItemsPerThread],
287  ValuesInputIterator values_input,
288  ValuesOutputIterator values_output,
289  const unsigned int valid_in_last_block,
290  storage_type& storage_)
291  {
292  syncthreads();
293  auto& values_shared = storage_.get().values;
294  const auto flat_id = block_thread_id<0>();
295 
296  ROCPRIM_UNROLL
297  for(unsigned int item = 0; item < ItemsPerThread; ++item)
298  {
299  const unsigned int idx = BlockSize * item + flat_id;
300  if(idx < valid_in_last_block)
301  {
302  values_shared[idx] = values_input[idx];
303  }
304  }
305 
306  syncthreads();
307 
308  ROCPRIM_UNROLL
309  for(unsigned int item = 0; item < ItemsPerThread; ++item)
310  {
311  if(flat_id * ItemsPerThread + item < valid_in_last_block)
312  {
313  values_output[ItemsPerThread * flat_id + item] = values_shared[ranks[item]];
314  }
315  }
316  }
317 };
318 
// block_sort_impl (primary template): sorts one block of keys and permutes its values
// accordingly.
// Stability: each key is paired with the index flat_id * ItemsPerThread + i into a
// stable_key_type tuple. The comparator breaks ties between equal keys by that index.
// After sorting, the index of each tuple is used as a rank to permute the values.
// The index formula assumes keys_load_type yields a blocked arrangement.
// NOTE(review): this rendering skips source lines 323, 325, 330, 333, 336, 338 and 340.
// Not visible here: one template parameter, the struct head, the definitions of
// keys_load_type / sort_type / keys_store_type / values_permute_type, and the head of
// storage_type (struct or union cannot be told from here).
319 template<typename Key,
320  typename Value,
321  unsigned int BlockSize,
322  unsigned int ItemsPerThread,
324  typename Enable = void>
326 {
327  using stable_key_type = rocprim::tuple<Key, unsigned int>;
328 
329  using keys_load_type
331 
332  using sort_type
334 
335  using keys_store_type
337 
339 
341  {
342  typename keys_load_type::storage_type load_keys;
343  typename sort_type::storage_type sort;
344  typename keys_store_type::storage_type store_keys;
345  typename values_permute_type::storage_type permute_values;
346  };
347 
    // Sorts the block starting at keys_input / values_input and writes the result to
    // keys_output / values_output.
    //   valid_in_last_block - number of valid items; used only when is_incomplete_block.
    //   compare_function    - key comparator; returns true if the first key goes first.
348  template<typename KeysInputIterator,
349  typename KeysOutputIterator,
350  typename ValuesInputIterator,
351  typename ValuesOutputIterator,
352  typename BinaryFunction>
353  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
354  void sort(const unsigned int valid_in_last_block,
355  const bool is_incomplete_block,
356  KeysInputIterator keys_input,
357  KeysOutputIterator keys_output,
358  ValuesInputIterator values_input,
359  ValuesOutputIterator values_output,
360  BinaryFunction compare_function,
361  storage_type& storage)
362  {
363  // By default, the block sort algorithm is not stable. We can make it stable
364  // by adding an index to each key.
365 
366  Key keys[ItemsPerThread];
367 
        // In an incomplete block, keys in slots at or beyond valid_in_last_block are never
        // handed to compare_function: the out-of-bounds comparator below short-circuits
        // them by index.
368  if(is_incomplete_block)
369  {
370  keys_load_type().load(keys_input, keys, valid_in_last_block, storage.load_keys);
371  }
372  else
373  {
374  keys_load_type().load(keys_input, keys, storage.load_keys);
375  }
376 
377  const auto flat_id = block_thread_id<0>();
378 
379  stable_key_type stable_keys[ItemsPerThread];
380  ROCPRIM_UNROLL
381  for(unsigned int i = 0; i < ItemsPerThread; ++i)
382  {
383  stable_keys[i] = rocprim::make_tuple(keys[i], flat_id * ItemsPerThread + i);
384  }
385 
        // Barrier before the sort reuses storage that the load used.
386  syncthreads();
387 
388  // Special compare function that enforces sorting is stable.
        // a goes before b if a's key is less than b's, or if the keys are equivalent
        // (neither is less) and a's original index is smaller.
389  auto stable_compare_function
390  = [compare_function](const stable_key_type& a,
391  const stable_key_type& b) ROCPRIM_FORCE_INLINE mutable
392  {
393  const bool ab = compare_function(rocprim::get<0>(a), rocprim::get<0>(b));
394  return ab
395  || (!compare_function(rocprim::get<0>(b), rocprim::get<0>(a))
396  && (rocprim::get<1>(a) < rocprim::get<1>(b)));
397  };
398 
399  if(is_incomplete_block)
400  {
401  // Special compare function that enforces sorting is stable, and that out-of-bounds elements
402  // are not compared.
            // If either element is out of bounds, the result is !a_oob. An in-bounds element
            // therefore precedes an out-of-bounds one, so all out-of-bounds elements end up last.
403  auto stable_oob_compare_function
404  = [stable_compare_function, valid_in_last_block](const stable_key_type& a,
405  const stable_key_type& b) mutable
406  {
407  const bool a_oob = rocprim::get<1>(a) >= valid_in_last_block;
408  const bool b_oob = rocprim::get<1>(b) >= valid_in_last_block;
409  return a_oob || b_oob ? !a_oob : stable_compare_function(a, b);
410  };
411 
412  // Note: rocprim::block_sort with an algorithm that is not stable_merge_sort does not implement sorting
413  // a misaligned amount of items.
414  sort_type().sort(stable_keys, storage.sort, stable_oob_compare_function);
415 
            // Split the sorted tuples back into keys and ranks (original in-block indices).
416  unsigned int ranks[ItemsPerThread];
417  ROCPRIM_UNROLL
418  for(unsigned int i = 0; i < ItemsPerThread; ++i)
419  {
420  keys[i] = rocprim::get<0>(stable_keys[i]);
421  ranks[i] = rocprim::get<1>(stable_keys[i]);
422  }
423 
424  syncthreads();
425  keys_store_type().store(keys_output, keys, valid_in_last_block, storage.store_keys);
426  values_permute_type().permute(ranks,
427  values_input,
428  values_output,
429  valid_in_last_block,
430  storage.permute_values);
431  }
432  else
433  {
434  sort_type().sort(stable_keys, storage.sort, stable_compare_function);
435 
436  unsigned int ranks[ItemsPerThread];
437  ROCPRIM_UNROLL
438  for(unsigned int i = 0; i < ItemsPerThread; ++i)
439  {
440  keys[i] = rocprim::get<0>(stable_keys[i]);
441  ranks[i] = rocprim::get<1>(stable_keys[i]);
442  }
443 
444  syncthreads();
445  keys_store_type().store(keys_output, keys, storage.store_keys);
446  values_permute_type().permute(ranks,
447  values_input,
448  values_output,
449  storage.permute_values);
450  }
451  }
452 };
453 
// block_sort_impl specialization for keys-only sorting. sort() ignores values_input and
// values_output; their names are commented out. It performs three steps: load the keys,
// sort them with sort_type, and store them. Each phase uses a different member of the union
// storage_type, so the phases are separated by syncthreads().
// NOTE(review): this rendering skips source lines 456, 459, 462, 468 and 471. Not visible
// here: two specialization arguments (presumably the Value argument and one more), the
// load/store type definitions, and the last template argument of sort_type.
454 template<typename Key, unsigned int BlockSize, unsigned int ItemsPerThread>
455 struct block_sort_impl<Key,
457  BlockSize,
458  ItemsPerThread,
460 {
461  using keys_load_type
463 
464  using sort_type = block_sort<Key,
465  BlockSize,
466  ItemsPerThread,
467  rocprim::empty_type,
469 
470  using keys_store_type
472 
    // The load, sort and store phases share memory.
473  union storage_type
474  {
475  typename keys_load_type::storage_type load_keys;
476  typename sort_type::storage_type sort;
477  typename keys_store_type::storage_type store_keys;
478  };
479 
    // valid_in_last_block is forwarded to load/sort/store only when is_incomplete_block.
480  template<typename KeysInputIterator,
481  typename KeysOutputIterator,
482  typename ValuesInputIterator,
483  typename ValuesOutputIterator,
484  typename BinaryFunction>
485  ROCPRIM_DEVICE void sort(unsigned int valid_in_last_block,
486  const bool is_incomplete_block,
487  KeysInputIterator keys_input,
488  KeysOutputIterator keys_output,
489  ValuesInputIterator /*values_input*/,
490  ValuesOutputIterator /*values_output*/,
491  BinaryFunction compare_function,
492  storage_type& storage)
493  {
494  Key keys[ItemsPerThread];
495 
496  if(is_incomplete_block)
497  {
498  keys_load_type().load(keys_input, keys, valid_in_last_block, storage.load_keys);
499  syncthreads();
500  sort_type().sort(keys, storage.sort, valid_in_last_block, compare_function);
501  syncthreads();
502  keys_store_type().store(keys_output, keys, valid_in_last_block, storage.store_keys);
503  }
504  else
505  {
506  keys_load_type().load(keys_input, keys, storage.load_keys);
507  syncthreads();
508  sort_type().sort(keys, storage.sort, compare_function);
509  syncthreads();
510  keys_store_type().store(keys_output, keys, storage.store_keys);
511  }
512  }
513 };
514 
515 #ifndef DOXYGEN_SHOULD_SKIP_THIS
// block_sort_impl specialization selected when sizeof(Value) <= sizeof(int).
// Keys and values are both loaded into registers and sorted together by sort_type, which has
// Value as its value type. They are then stored separately. All five phases use members of
// one union storage_type, so every phase is separated from the next by syncthreads().
// NOTE(review): this rendering skips source lines 521, 525, 528, 534, 537 and 540. Not
// visible here: one specialization argument, the load/store type definitions, and the last
// template argument of sort_type.
516 template<typename Key, typename Value, unsigned int BlockSize, unsigned int ItemsPerThread>
517 struct block_sort_impl<Key,
518  Value,
519  BlockSize,
520  ItemsPerThread,
522  std::enable_if_t<(sizeof(Value) <= sizeof(int))>>
523 {
524  using keys_load_type
526 
527  using values_load_type
529 
530  using sort_type = block_sort<Key,
531  BlockSize,
532  ItemsPerThread,
533  Value,
535 
536  using keys_store_type
538 
539  using values_store_type
541 
542  union storage_type
543  {
544  typename keys_load_type::storage_type load_keys;
545  typename values_load_type::storage_type load_values;
546  typename sort_type::storage_type sort;
547  typename keys_store_type::storage_type store_keys;
548  typename values_store_type::storage_type store_values;
549  };
550 
    // valid_in_last_block is forwarded to every phase only when is_incomplete_block.
551  template<typename KeysInputIterator,
552  typename KeysOutputIterator,
553  typename ValuesInputIterator,
554  typename ValuesOutputIterator,
555  typename BinaryFunction>
556  ROCPRIM_DEVICE void sort(const unsigned int valid_in_last_block,
557  const bool is_incomplete_block,
558  KeysInputIterator keys_input,
559  KeysOutputIterator keys_output,
560  ValuesInputIterator values_input,
561  ValuesOutputIterator values_output,
562  BinaryFunction compare_function,
563  storage_type& storage)
564  {
565  Key keys[ItemsPerThread];
566  Value values[ItemsPerThread];
567 
568  if(is_incomplete_block)
569  {
570  keys_load_type().load(keys_input, keys, valid_in_last_block, storage.load_keys);
571  syncthreads();
572  values_load_type().load(values_input, values, valid_in_last_block, storage.load_values);
573  syncthreads();
574  sort_type().sort(keys, values, storage.sort, valid_in_last_block, compare_function);
575  syncthreads();
576  keys_store_type().store(keys_output, keys, valid_in_last_block, storage.store_keys);
577  syncthreads();
578  values_store_type().store(values_output,
579  values,
580  valid_in_last_block,
581  storage.store_values);
582  }
583  else
584  {
585  keys_load_type().load(keys_input, keys, storage.load_keys);
586  syncthreads();
587  values_load_type().load(values_input, values, storage.load_values);
588  syncthreads();
589  sort_type().sort(keys, values, storage.sort, compare_function);
590  syncthreads();
591  keys_store_type().store(keys_output, keys, storage.store_keys);
592  syncthreads();
593  values_store_type().store(values_output, values, storage.store_values);
594  }
595  }
596 };
// block_sort_impl specialization selected when sizeof(Value) > sizeof(int).
// Instead of moving large values through the sort, it works in three steps:
//   1. Sort the keys together with unsigned int ranks, initialised to each item's in-block
//      index flat_id * ItemsPerThread + item.
//   2. Store the sorted keys.
//   3. Move the values from values_input to values_output by rank, using
//      values_permute_type::permute.
// No barrier is needed between the key store and permute(): every visible permute()
// implementation starts with syncthreads().
// NOTE(review): this rendering skips source lines 602, 606, 612, 615, 617 and 619. Not
// visible here: one specialization argument, the definitions of keys_load_type /
// keys_store_type / values_permute_type, the last template argument of sort_type, and the
// head of storage_type.
597 template<typename Key, typename Value, unsigned int BlockSize, unsigned int ItemsPerThread>
598 struct block_sort_impl<Key,
599  Value,
600  BlockSize,
601  ItemsPerThread,
603  std::enable_if_t<(sizeof(Value) > sizeof(int))>>
604 {
605  using keys_load_type
607 
608  using sort_type = block_sort<Key,
609  BlockSize,
610  ItemsPerThread,
611  unsigned int,
613 
614  using keys_store_type
616 
618 
620  {
621  typename keys_load_type::storage_type load_keys;
622  typename sort_type::storage_type sort;
623  typename keys_store_type::storage_type store_keys;
624  typename values_permute_type::storage_type permute_values;
625  };
626 
    // valid_in_last_block is forwarded to each phase only when is_incomplete_block.
627  template<typename KeysInputIterator,
628  typename KeysOutputIterator,
629  typename ValuesInputIterator,
630  typename ValuesOutputIterator,
631  typename BinaryFunction>
632  ROCPRIM_DEVICE void sort(const unsigned int valid_in_last_block,
633  const bool is_incomplete_block,
634  KeysInputIterator keys_input,
635  KeysOutputIterator keys_output,
636  ValuesInputIterator values_input,
637  ValuesOutputIterator values_output,
638  BinaryFunction compare_function,
639  storage_type& storage)
640  {
641  Key keys[ItemsPerThread];
642 
        // ranks[item] starts as the item's in-block index and travels with its key.
643  const auto flat_id = block_thread_id<0>();
644  unsigned int ranks[ItemsPerThread];
645  ROCPRIM_UNROLL
646  for(unsigned int item = 0; item < ItemsPerThread; ++item)
647  {
648  ranks[item] = flat_id * ItemsPerThread + item;
649  }
650 
651  if(is_incomplete_block)
652  {
653  keys_load_type().load(keys_input, keys, valid_in_last_block, storage.load_keys);
654  syncthreads();
655  sort_type().sort(keys, ranks, storage.sort, valid_in_last_block, compare_function);
656  syncthreads();
657  keys_store_type().store(keys_output, keys, valid_in_last_block, storage.store_keys);
658  values_permute_type().permute(ranks,
659  values_input,
660  values_output,
661  valid_in_last_block,
662  storage.permute_values);
663  }
664  else
665  {
666  keys_load_type().load(keys_input, keys, storage.load_keys);
667  syncthreads();
668  sort_type().sort(keys, ranks, storage.sort, compare_function);
669  syncthreads();
670  keys_store_type().store(keys_output, keys, storage.store_keys);
671  values_permute_type().permute(ranks,
672  values_input,
673  values_output,
674  storage.permute_values);
675  }
676  }
677 };
678 #endif // DOXYGEN_SHOULD_SKIP_THIS
679 
// block_sort_kernel_impl: device-side body of the block-sort pass.
// Each thread block sorts BlockSize * ItemsPerThread consecutive items, starting at
// block_offset, with a block_sort_impl (sort_impl), using storage declared
// ROCPRIM_SHARED_MEMORY. All four iterators are advanced by block_offset before the call.
// The return type is `auto` with no return statement, so it deduces to void.
// NOTE(review): this rendering skips source lines 682 (one template parameter) and 707 (the
// definition of sort_impl). The defaulted ValueType template parameter is not referenced on
// any visible line.
// NOTE(review): flat_block_id * items_per_block is evaluated in unsigned int before it is
// converted to OffsetT, so block_offset can wrap for inputs of 2^32 or more items even when
// OffsetT is 64-bit.
680 template<unsigned int BlockSize,
681  unsigned int ItemsPerThread,
683  class KeysInputIterator,
684  class KeysOutputIterator,
685  class ValuesInputIterator,
686  class ValuesOutputIterator,
687  class OffsetT,
688  class BinaryFunction,
689  class ValueType = typename std::iterator_traits<ValuesInputIterator>::value_type>
690 ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE auto block_sort_kernel_impl(KeysInputIterator keys_input,
691  KeysOutputIterator keys_output,
692  ValuesInputIterator values_input,
693  ValuesOutputIterator values_output,
694  const OffsetT input_size,
695  BinaryFunction compare_function)
696 {
697  using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
698  using value_type = typename std::iterator_traits<ValuesInputIterator>::value_type;
699 
700  const unsigned int flat_block_id = block_id<0>();
701  constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
702 
703  const OffsetT block_offset = flat_block_id * items_per_block;
    // Only meaningful for the last block. Every visible block_sort_impl::sort() reads it only
    // when is_incomplete_block is true.
704  const unsigned int valid_in_last_block = input_size - block_offset;
    // True only for the last block, and only when input_size is not a multiple of
    // items_per_block.
705  const bool is_incomplete_block = flat_block_id == (input_size / items_per_block);
706 
708 
709  ROCPRIM_SHARED_MEMORY typename sort_impl::storage_type storage;
710 
711  sort_impl().sort(valid_in_last_block,
712  is_incomplete_block,
713  keys_input + block_offset,
714  keys_output + block_offset,
715  values_input + block_offset,
716  values_output + block_offset,
717  compare_function,
718  storage);
719 }
720 
721 template<unsigned int BlockSize,
722  unsigned int ItemsPerThread,
723  class KeysInputIterator,
724  class KeysOutputIterator,
725  class ValuesInputIterator,
726  class ValuesOutputIterator,
727  class OffsetT,
728  class BinaryFunction>
729 ROCPRIM_DEVICE ROCPRIM_INLINE void block_merge_oddeven_kernel(KeysInputIterator keys_input,
730  KeysOutputIterator keys_output,
731  ValuesInputIterator values_input,
732  ValuesOutputIterator values_output,
733  const OffsetT input_size,
734  const OffsetT sorted_block_size,
735  BinaryFunction compare_function)
736 {
737  using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
738  using value_type = typename std::iterator_traits<ValuesInputIterator>::value_type;
739  constexpr bool with_values = !std::is_same<value_type, ::rocprim::empty_type>::value;
740 
741  constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
742  const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
743  const unsigned int flat_block_id = ::rocprim::detail::block_id<0>();
744  const bool is_incomplete_block = flat_block_id == (input_size / items_per_block);
745  // ^ bounds-checking: if input_size is not a multiple of items_per_block and
746  // this is the last block: true, false otherwise
747  const OffsetT block_offset = flat_block_id * items_per_block;
748  const OffsetT valid_in_last_block = input_size - block_offset;
749 
750  const OffsetT thread_offset = flat_id * ItemsPerThread;
751  if(thread_offset >= valid_in_last_block)
752  {
753  return;
754  }
755 
756  key_type keys[ItemsPerThread];
757  value_type values[ItemsPerThread];
758 
759  if(is_incomplete_block)
760  {
761  block_load_direct_blocked(flat_id, keys_input + block_offset, keys, valid_in_last_block);
762 
763  if ROCPRIM_IF_CONSTEXPR(with_values)
764  {
766  values_input + block_offset,
767  values,
768  valid_in_last_block);
769  }
770  }
771  else
772  {
773  block_load_direct_blocked(flat_id, keys_input + block_offset, keys);
774  if ROCPRIM_IF_CONSTEXPR(with_values)
775  {
776  block_load_direct_blocked(flat_id, values_input + block_offset, values);
777  }
778  }
779 
780  const unsigned int merged_tiles_number = sorted_block_size / items_per_block;
781  const unsigned int mask = merged_tiles_number - 1;
782  // tilegroup_id is the id of the input sorted_block
783  const unsigned int tilegroup_id = ~mask & flat_block_id;
784  const unsigned int block_is_odd = merged_tiles_number & tilegroup_id;
785  const OffsetT block_start = tilegroup_id * items_per_block;
786  const OffsetT next_block_start_
787  = block_is_odd ? block_start - sorted_block_size : block_start + sorted_block_size;
788  const OffsetT next_block_start = min(next_block_start_, input_size);
789  const OffsetT next_block_end = min(next_block_start + sorted_block_size, input_size);
790 
791  if(next_block_start == input_size)
792  {
793  // In this case, no merging needs to happen and
794  // block_is_odd will always be false here
795  if(is_incomplete_block)
796  {
797  ROCPRIM_UNROLL
798  for(unsigned int i = 0; i < ItemsPerThread; i++)
799  {
800  const unsigned int id = block_offset + thread_offset + i;
801  if(id < input_size)
802  {
803  keys_output[id] = keys[i];
804  if ROCPRIM_IF_CONSTEXPR(with_values)
805  {
806  values_output[id] = values[i];
807  }
808  }
809  }
810  }
811  else
812  {
813  ROCPRIM_UNROLL
814  for(unsigned int i = 0; i < ItemsPerThread; i++)
815  {
816  const unsigned int id = block_offset + thread_offset + i;
817  keys_output[id] = keys[i];
818  if ROCPRIM_IF_CONSTEXPR(with_values)
819  {
820  values_output[id] = values[i];
821  }
822  }
823  }
824  return;
825  }
826 
827  OffsetT left_id = next_block_start;
828 
829  const OffsetT dest_offset
830  = min(block_start, next_block_start) + block_offset + thread_offset - block_start
831  - next_block_start; // Destination offset (base+source+partial target calculation)
832 
833  const auto merge_function = [&](const unsigned int i)
834  {
835  OffsetT right_id = next_block_end;
836 
837  while(left_id < right_id)
838  {
839  OffsetT mid_id = (left_id + right_id) / 2;
840  key_type mid_key = keys_input[mid_id];
841  const bool mid_smaller = block_is_odd ? !compare_function(keys[i], mid_key)
842  : compare_function(mid_key, keys[i]);
843  left_id = mid_smaller ? mid_id + 1 : left_id;
844  right_id = mid_smaller ? right_id : mid_id;
845  }
846 
847  OffsetT offset = dest_offset + i + left_id; // Destination offset (target calculation)
848  keys_output[offset] = keys[i];
849  if ROCPRIM_IF_CONSTEXPR(with_values)
850  {
851  values_output[offset] = values[i];
852  }
853  };
854 
855  if(is_incomplete_block)
856  {
857  ROCPRIM_UNROLL
858  for(unsigned int i = 0; i < ItemsPerThread; i++)
859  {
860  if(thread_offset + i < valid_in_last_block)
861  {
862  merge_function(i);
863  }
864  }
865  }
866  else
867  {
868  ROCPRIM_UNROLL
869  for(unsigned int i = 0; i < ItemsPerThread; i++)
870  {
871  merge_function(i);
872  }
873  }
874 }
875 
876 } // end of detail namespace
877 
878 END_ROCPRIM_NAMESPACE
879 
880 #endif // ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_SORT_HPP_
Empty type used as a placeholder, usually used to flag that given template parameter should not be us...
Definition: types.hpp:135
A merged sort based algorithm which sorts stably.
ROCPRIM_DEVICE ROCPRIM_INLINE void store(OutputIterator block_output, T(&items)[ItemsPerThread])
Stores an arrangement of items from across the thread block into an arrangement on continuous memory...
Definition: block_store.hpp:168
The block_sort class is a block level parallel primitive which provides methods sorting items (keys o...
Definition: block_sort.hpp:151
The block_store class is a block level parallel primitive which provides methods for storing an arran...
Definition: block_store.hpp:134
Definition: device_merge_sort.hpp:340
The block_exchange class is a block level parallel primitive which provides methods for rearranging i...
Definition: block_exchange.hpp:81
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort(Key &thread_key, BinaryFunction compare_function=BinaryFunction())
Block sort for any data type.
Definition: block_sort.hpp:181
Definition: test_utils_custom_float_type.hpp:110
Definition: device_merge_sort.hpp:51
ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
Returns the minimum of its arguments.
Definition: functional.hpp:63
ROCPRIM_DEVICE ROCPRIM_INLINE void load(InputIterator block_input, T(&items)[ItemsPerThread])
Loads data from continuous memory into an arrangement of items across the thread block.
Definition: block_load.hpp:167
Definition: device_merge_sort.hpp:167
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
Definition: device_merge_sort.hpp:325
Definition: device_merge_sort.hpp:160
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216
typename ::rocprim::detail::empty_storage_type storage_type
Struct used to allocate a temporary memory that is required for thread communication during operation...
Definition: block_load.hpp:148
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int flat_block_id()
Returns flat (linear, 1D) block identifier in a multidimensional grid.
Definition: thread.hpp:178
Definition: various.hpp:52
typename base_type::storage_type storage_type
Struct used to allocate a temporary memory that is required for thread communication during operation...
Definition: block_sort.hpp:166
typename ::rocprim::detail::empty_storage_type storage_type
Struct used to allocate a temporary memory that is required for thread communication during operation...
Definition: block_store.hpp:149
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void gather_from_striped(const T(&input)[ItemsPerThread], U(&output)[ItemsPerThread], const Offset(&ranks)[ItemsPerThread])
Gathers items from a striped arrangement based on their ranks across the thread block.
Definition: block_exchange.hpp:425
Definition: device_merge_sort.hpp:619
BEGIN_ROCPRIM_NAMESPACE ROCPRIM_DEVICE ROCPRIM_INLINE void block_load_direct_blocked(unsigned int flat_id, InputIterator block_input, T(&items)[ItemsPerThread])
Loads data from continuous memory into a blocked arrangement of items across the thread block...
Definition: block_load_func.hpp:58
block_sort_algorithm
Available algorithms for block_sort primitive.
Definition: block_sort.hpp:41
The block_load class is a block level parallel primitive which provides methods for loading data from...
Definition: block_load.hpp:133