rocPRIM
block_radix_sort.hpp
1 // Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_BLOCK_BLOCK_RADIX_SORT_HPP_
22 #define ROCPRIM_BLOCK_BLOCK_RADIX_SORT_HPP_
23 
24 #include <type_traits>
25 
26 #include "../config.hpp"
27 #include "../detail/various.hpp"
28 #include "../detail/radix_sort.hpp"
29 #include "../warp/detail/warp_scan_crosslane.hpp"
30 
31 #include "../intrinsics.hpp"
32 #include "../functional.hpp"
33 #include "../types.hpp"
34 
35 #include "block_exchange.hpp"
36 #include "block_radix_rank.hpp"
37 
40 
41 BEGIN_ROCPRIM_NAMESPACE
42 
89 template<
90  class Key,
91  unsigned int BlockSizeX,
92  unsigned int ItemsPerThread,
93  class Value = empty_type,
94  unsigned int BlockSizeY = 1,
95  unsigned int BlockSizeZ = 1
96 >
98 {
99  static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
100  static constexpr bool with_values = !std::is_same<Value, empty_type>::value;
101  static constexpr unsigned int radix_bits_per_pass = 4;
102 
103  using bit_key_type = typename ::rocprim::detail::radix_key_codec<Key>::bit_key_type;
104  using block_rank_type = ::rocprim::block_radix_rank<BlockSizeX,
105  radix_bits_per_pass,
107  BlockSizeY,
108  BlockSizeZ>;
109  using bit_keys_exchange_type = ::rocprim::block_exchange<bit_key_type, BlockSizeX, ItemsPerThread, BlockSizeY, BlockSizeZ>;
110  using values_exchange_type
111  = ::rocprim::block_exchange<Value, BlockSizeX, ItemsPerThread, BlockSizeY, BlockSizeZ>;
112 
113  // Struct used for creating a raw_storage object for this primitive's temporary storage.
114  union storage_type_
115  {
116  typename bit_keys_exchange_type::storage_type bit_keys_exchange;
117  typename values_exchange_type::storage_type values_exchange;
118  typename block_rank_type::storage_type rank;
119  };
120 
121 public:
122 
131  #ifndef DOXYGEN_SHOULD_SKIP_THIS // hides storage_type implementation for Doxygen
133  #else
134  using storage_type = storage_type_; // only for Doxygen
135  #endif
136 
178  ROCPRIM_DEVICE ROCPRIM_INLINE
179  void sort(Key (&keys)[ItemsPerThread],
180  storage_type& storage,
181  unsigned int begin_bit = 0,
182  unsigned int end_bit = 8 * sizeof(Key))
183  {
184  empty_type values[ItemsPerThread];
185  sort_impl<false>(keys, values, storage, begin_bit, end_bit);
186  }
187 
200  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
201  void sort(Key (&keys)[ItemsPerThread],
202  unsigned int begin_bit = 0,
203  unsigned int end_bit = 8 * sizeof(Key))
204  {
205  ROCPRIM_SHARED_MEMORY storage_type storage;
206  sort(keys, storage, begin_bit, end_bit);
207  }
208 
250  ROCPRIM_DEVICE ROCPRIM_INLINE
251  void sort_desc(Key (&keys)[ItemsPerThread],
252  storage_type& storage,
253  unsigned int begin_bit = 0,
254  unsigned int end_bit = 8 * sizeof(Key))
255  {
256  empty_type values[ItemsPerThread];
257  sort_impl<true>(keys, values, storage, begin_bit, end_bit);
258  }
259 
272  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
273  void sort_desc(Key (&keys)[ItemsPerThread],
274  unsigned int begin_bit = 0,
275  unsigned int end_bit = 8 * sizeof(Key))
276  {
277  ROCPRIM_SHARED_MEMORY storage_type storage;
278  sort_desc(keys, storage, begin_bit, end_bit);
279  }
280 
330  template<bool WithValues = with_values>
331  ROCPRIM_DEVICE ROCPRIM_INLINE
332  void sort(Key (&keys)[ItemsPerThread],
333  typename std::enable_if<WithValues, Value>::type (&values)[ItemsPerThread],
334  storage_type& storage,
335  unsigned int begin_bit = 0,
336  unsigned int end_bit = 8 * sizeof(Key))
337  {
338  sort_impl<false>(keys, values, storage, begin_bit, end_bit);
339  }
340 
357  template<bool WithValues = with_values>
358  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
359  void sort(Key (&keys)[ItemsPerThread],
360  typename std::enable_if<WithValues, Value>::type (&values)[ItemsPerThread],
361  unsigned int begin_bit = 0,
362  unsigned int end_bit = 8 * sizeof(Key))
363  {
364  ROCPRIM_SHARED_MEMORY storage_type storage;
365  sort(keys, values, storage, begin_bit, end_bit);
366  }
367 
417  template<bool WithValues = with_values>
418  ROCPRIM_DEVICE ROCPRIM_INLINE
419  void sort_desc(Key (&keys)[ItemsPerThread],
420  typename std::enable_if<WithValues, Value>::type (&values)[ItemsPerThread],
421  storage_type& storage,
422  unsigned int begin_bit = 0,
423  unsigned int end_bit = 8 * sizeof(Key))
424  {
425  sort_impl<true>(keys, values, storage, begin_bit, end_bit);
426  }
427 
444  template<bool WithValues = with_values>
445  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
446  void sort_desc(Key (&keys)[ItemsPerThread],
447  typename std::enable_if<WithValues, Value>::type (&values)[ItemsPerThread],
448  unsigned int begin_bit = 0,
449  unsigned int end_bit = 8 * sizeof(Key))
450  {
451  ROCPRIM_SHARED_MEMORY storage_type storage;
452  sort_desc(keys, values, storage, begin_bit, end_bit);
453  }
454 
497  ROCPRIM_DEVICE ROCPRIM_INLINE
498  void sort_to_striped(Key (&keys)[ItemsPerThread],
499  storage_type& storage,
500  unsigned int begin_bit = 0,
501  unsigned int end_bit = 8 * sizeof(Key))
502  {
503  empty_type values[ItemsPerThread];
504  sort_impl<false, true>(keys, values, storage, begin_bit, end_bit);
505  }
506 
520  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
521  void sort_to_striped(Key (&keys)[ItemsPerThread],
522  unsigned int begin_bit = 0,
523  unsigned int end_bit = 8 * sizeof(Key))
524  {
525  ROCPRIM_SHARED_MEMORY storage_type storage;
526  sort_to_striped(keys, storage, begin_bit, end_bit);
527  }
528 
571  ROCPRIM_DEVICE ROCPRIM_INLINE
572  void sort_desc_to_striped(Key (&keys)[ItemsPerThread],
573  storage_type& storage,
574  unsigned int begin_bit = 0,
575  unsigned int end_bit = 8 * sizeof(Key))
576  {
577  empty_type values[ItemsPerThread];
578  sort_impl<true, true>(keys, values, storage, begin_bit, end_bit);
579  }
580 
594  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
595  void sort_desc_to_striped(Key (&keys)[ItemsPerThread],
596  unsigned int begin_bit = 0,
597  unsigned int end_bit = 8 * sizeof(Key))
598  {
599  ROCPRIM_SHARED_MEMORY storage_type storage;
600  sort_desc_to_striped(keys, storage, begin_bit, end_bit);
601  }
602 
652  template<bool WithValues = with_values>
653  ROCPRIM_DEVICE ROCPRIM_INLINE
654  void sort_to_striped(Key (&keys)[ItemsPerThread],
655  typename std::enable_if<WithValues, Value>::type (&values)[ItemsPerThread],
656  storage_type& storage,
657  unsigned int begin_bit = 0,
658  unsigned int end_bit = 8 * sizeof(Key))
659  {
660  sort_impl<false, true>(keys, values, storage, begin_bit, end_bit);
661  }
662 
677  template<bool WithValues = with_values>
678  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
679  void sort_to_striped(Key (&keys)[ItemsPerThread],
680  typename std::enable_if<WithValues, Value>::type (&values)[ItemsPerThread],
681  unsigned int begin_bit = 0,
682  unsigned int end_bit = 8 * sizeof(Key))
683  {
684  ROCPRIM_SHARED_MEMORY storage_type storage;
685  sort_to_striped(keys, values, storage, begin_bit, end_bit);
686  }
687 
737  template<bool WithValues = with_values>
738  ROCPRIM_DEVICE ROCPRIM_INLINE
739  void sort_desc_to_striped(Key (&keys)[ItemsPerThread],
740  typename std::enable_if<WithValues, Value>::type (&values)[ItemsPerThread],
741  storage_type& storage,
742  unsigned int begin_bit = 0,
743  unsigned int end_bit = 8 * sizeof(Key))
744  {
745  sort_impl<true, true>(keys, values, storage, begin_bit, end_bit);
746  }
747 
762  template<bool WithValues = with_values>
763  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
764  void sort_desc_to_striped(Key (&keys)[ItemsPerThread],
765  typename std::enable_if<WithValues, Value>::type (&values)[ItemsPerThread],
766  unsigned int begin_bit = 0,
767  unsigned int end_bit = 8 * sizeof(Key))
768  {
769  ROCPRIM_SHARED_MEMORY storage_type storage;
770  sort_desc_to_striped(keys, values, storage, begin_bit, end_bit);
771  }
772 
773 private:
774 
775  template<bool Descending, bool ToStriped = false, class SortedValue>
776  ROCPRIM_DEVICE ROCPRIM_INLINE
777  void sort_impl(Key (&keys)[ItemsPerThread],
778  SortedValue (&values)[ItemsPerThread],
779  storage_type& storage,
780  unsigned int begin_bit,
781  unsigned int end_bit)
782  {
783  using key_codec = ::rocprim::detail::radix_key_codec<Key, Descending>;
784 
785  bit_key_type bit_keys[ItemsPerThread];
786  ROCPRIM_UNROLL
787  for(unsigned int i = 0; i < ItemsPerThread; i++)
788  {
789  bit_keys[i] = key_codec::encode(keys[i]);
790  }
791 
792  while(true)
793  {
794  const int pass_bits = min(radix_bits_per_pass, end_bit - begin_bit);
795 
796  unsigned int ranks[ItemsPerThread];
797  block_rank_type().rank_keys(
798  bit_keys,
799  ranks,
800  storage.get().rank,
801  [begin_bit, pass_bits](const bit_key_type& key)
802  { return key_codec::extract_digit(key, begin_bit, pass_bits); });
803  begin_bit += radix_bits_per_pass;
804 
805  exchange_keys(storage, bit_keys, ranks);
806  exchange_values(storage, values, ranks);
807 
808  if(begin_bit >= end_bit)
809  break;
810 
811  // Synchronization required to make bock_rank wait on the next iteration.
813  }
814 
815  if ROCPRIM_IF_CONSTEXPR(ToStriped)
816  {
817  to_striped_keys(storage, bit_keys);
818  to_striped_values(storage, values);
819  }
820 
821  ROCPRIM_UNROLL
822  for(unsigned int i = 0; i < ItemsPerThread; i++)
823  {
824  keys[i] = key_codec::decode(bit_keys[i]);
825  }
826  }
827 
828  ROCPRIM_DEVICE ROCPRIM_INLINE
829  void exchange_keys(storage_type& storage,
830  bit_key_type (&bit_keys)[ItemsPerThread],
831  const unsigned int (&ranks)[ItemsPerThread])
832  {
833  storage_type_& storage_ = storage.get();
834  ::rocprim::syncthreads(); // Storage will be reused (union), synchronization is needed
835  bit_keys_exchange_type().scatter_to_blocked(bit_keys, bit_keys, ranks, storage_.bit_keys_exchange);
836  }
837 
838  template<class SortedValue>
839  ROCPRIM_DEVICE ROCPRIM_INLINE
840  void exchange_values(storage_type& storage,
841  SortedValue (&values)[ItemsPerThread],
842  const unsigned int (&ranks)[ItemsPerThread])
843  {
844  storage_type_& storage_ = storage.get();
845  ::rocprim::syncthreads(); // Storage will be reused (union), synchronization is needed
846  values_exchange_type().scatter_to_blocked(values, values, ranks, storage_.values_exchange);
847  }
848 
849  ROCPRIM_DEVICE ROCPRIM_INLINE
850  void exchange_values(storage_type& storage,
851  empty_type (&values)[ItemsPerThread],
852  const unsigned int (&ranks)[ItemsPerThread])
853  {
854  (void) storage;
855  (void) values;
856  (void) ranks;
857  }
858 
859  ROCPRIM_DEVICE ROCPRIM_INLINE
860  void to_striped_keys(storage_type& storage,
861  bit_key_type (&bit_keys)[ItemsPerThread])
862  {
863  storage_type_& storage_ = storage.get();
864  ::rocprim::syncthreads(); // Storage will be reused (union), synchronization is needed
865  bit_keys_exchange_type().blocked_to_striped(bit_keys, bit_keys, storage_.bit_keys_exchange);
866  }
867 
868  template<class SortedValue>
869  ROCPRIM_DEVICE ROCPRIM_INLINE
870  void to_striped_values(storage_type& storage,
871  SortedValue (&values)[ItemsPerThread])
872  {
873  storage_type_& storage_ = storage.get();
874  ::rocprim::syncthreads(); // Storage will be reused (union), synchronization is needed
875  values_exchange_type().blocked_to_striped(values, values, storage_.values_exchange);
876  }
877 
878  ROCPRIM_DEVICE ROCPRIM_INLINE
879  void to_striped_values(storage_type& storage,
880  empty_type * values)
881  {
882  (void) storage;
883  (void) values;
884  }
885 };
886 
887 END_ROCPRIM_NAMESPACE
888 
890 // end of group blockmodule
891 
892 #endif // ROCPRIM_BLOCK_BLOCK_RADIX_SORT_HPP_
Empty type used as a placeholder, usually used to flag that given template parameter should not be used.
Definition: types.hpp:135
ROCPRIM_DEVICE ROCPRIM_INLINE void sort_desc_to_striped(Key(&keys)[ItemsPerThread], storage_type &storage, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key))
Performs descending radix sort over keys partitioned across threads in a block, results are saved in a striped arrangement.
Definition: block_radix_sort.hpp:572
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort_to_striped(Key(&keys)[ItemsPerThread], typename std::enable_if< WithValues, Value >::type(&values)[ItemsPerThread], unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key))
This is an overloaded member function, provided for convenience. It differs from the above function only in what argument(s) it accepts.
Definition: block_radix_sort.hpp:679
The basic block radix rank algorithm, configured to memoize intermediate values.
ROCPRIM_DEVICE ROCPRIM_INLINE void sort_to_striped(Key(&keys)[ItemsPerThread], typename std::enable_if< WithValues, Value >::type(&values)[ItemsPerThread], storage_type &storage, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key))
Performs ascending radix sort over key-value pairs partitioned across threads in a block, results are saved in a striped arrangement.
Definition: block_radix_sort.hpp:654
ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key(&keys)[ItemsPerThread], typename std::enable_if< WithValues, Value >::type(&values)[ItemsPerThread], storage_type &storage, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key))
Performs ascending radix sort over key-value pairs partitioned across threads in a block.
Definition: block_radix_sort.hpp:332
ROCPRIM_HOST_DEVICE constexpr T min(const T &a, const T &b)
Returns the minimum of its arguments.
Definition: functional.hpp:63
The block_radix_sort class is a block level parallel primitive which provides methods for sorting of items (keys or key-value pairs) partitioned across threads in a block using the radix sort algorithm.
Definition: block_radix_sort.hpp:97
ROCPRIM_DEVICE ROCPRIM_INLINE void sort_desc_to_striped(Key(&keys)[ItemsPerThread], typename std::enable_if< WithValues, Value >::type(&values)[ItemsPerThread], storage_type &storage, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key))
Performs descending radix sort over key-value pairs partitioned across threads in a block, results are saved in a striped arrangement.
Definition: block_radix_sort.hpp:739
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort_desc(Key(&keys)[ItemsPerThread], typename std::enable_if< WithValues, Value >::type(&values)[ItemsPerThread], unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key))
This is an overloaded member function, provided for convenience. It differs from the above function only in what argument(s) it accepts.
Definition: block_radix_sort.hpp:446
ROCPRIM_DEVICE ROCPRIM_INLINE void sort_to_striped(Key(&keys)[ItemsPerThread], storage_type &storage, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key))
Performs ascending radix sort over keys partitioned across threads in a block, results are saved in a striped arrangement.
Definition: block_radix_sort.hpp:498
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216
ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key(&keys)[ItemsPerThread], storage_type &storage, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key))
Performs ascending radix sort over keys partitioned across threads in a block.
Definition: block_radix_sort.hpp:179
ROCPRIM_DEVICE ROCPRIM_INLINE void sort_desc(Key(&keys)[ItemsPerThread], storage_type &storage, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key))
Performs descending radix sort over keys partitioned across threads in a block.
Definition: block_radix_sort.hpp:251
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort(Key(&keys)[ItemsPerThread], unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key))
This is an overloaded member function, provided for convenience. It differs from the above function only in what argument(s) it accepts.
Definition: block_radix_sort.hpp:201
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort_desc_to_striped(Key(&keys)[ItemsPerThread], unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key))
This is an overloaded member function, provided for convenience. It differs from the above function only in what argument(s) it accepts.
Definition: block_radix_sort.hpp:595
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort_to_striped(Key(&keys)[ItemsPerThread], unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key))
This is an overloaded member function, provided for convenience. It differs from the above function only in what argument(s) it accepts.
Definition: block_radix_sort.hpp:521
ROCPRIM_DEVICE ROCPRIM_INLINE void sort_desc(Key(&keys)[ItemsPerThread], typename std::enable_if< WithValues, Value >::type(&values)[ItemsPerThread], storage_type &storage, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key))
Performs descending radix sort over key-value pairs partitioned across threads in a block.
Definition: block_radix_sort.hpp:419
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort(Key(&keys)[ItemsPerThread], typename std::enable_if< WithValues, Value >::type(&values)[ItemsPerThread], unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key))
This is an overloaded member function, provided for convenience. It differs from the above function only in what argument(s) it accepts.
Definition: block_radix_sort.hpp:359
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort_desc(Key(&keys)[ItemsPerThread], unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key))
This is an overloaded member function, provided for convenience. It differs from the above function only in what argument(s) it accepts.
Definition: block_radix_sort.hpp:273
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort_desc_to_striped(Key(&keys)[ItemsPerThread], typename std::enable_if< WithValues, Value >::type(&values)[ItemsPerThread], unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key))
This is an overloaded member function, provided for convenience. It differs from the above function only in what argument(s) it accepts.
Definition: block_radix_sort.hpp:764