rocPRIM
radix_sort.hpp
1 // Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_DETAIL_RADIX_SORT_HPP_
22 #define ROCPRIM_DETAIL_RADIX_SORT_HPP_
23 
24 #include <type_traits>
25 
26 #include "../config.hpp"
27 #include "../type_traits.hpp"
28 
29 BEGIN_ROCPRIM_NAMESPACE
30 namespace detail
31 {
32 
33 // Encode and decode integral and floating point values for radix sort in such a way that preserves
34 // correct order of negative and positive keys (i.e. negative keys go before positive ones,
35 // which is not true for a simple reinterpretation of the key's bits).
36 
37 // Digit extractor takes into account that (+0.0 == -0.0) is true for floats,
38 // so both +0.0 and -0.0 are reflected into the same bit pattern for digit extraction.
39 // Maximum digit length is 32.
40 
// Primary template of the integral key codec. It is intentionally empty:
// the usable implementations are the partial specializations that are
// selected through the Enable parameter for unsigned and signed key types.
template<class Key, class BitKey, class Enable = void>
struct radix_key_codec_integral { };
43 
44 template<class Key, class BitKey>
45 struct radix_key_codec_integral<Key, BitKey, typename std::enable_if<::rocprim::is_unsigned<Key>::value>::type>
46 {
47  using bit_key_type = BitKey;
48 
49  ROCPRIM_DEVICE ROCPRIM_INLINE
50  static bit_key_type encode(Key key)
51  {
52  return __builtin_bit_cast(bit_key_type, key);
53  }
54 
55  ROCPRIM_DEVICE ROCPRIM_INLINE
56  static Key decode(bit_key_type bit_key)
57  {
58  return __builtin_bit_cast(Key, bit_key);
59  }
60 
61  template<bool Descending>
62  ROCPRIM_DEVICE static unsigned int
63  extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length)
64  {
65  unsigned int mask = (1u << length) - 1;
66  return static_cast<unsigned int>(bit_key >> start) & mask;
67  }
68 };
69 
70 template<class Key, class BitKey>
71 struct radix_key_codec_integral<Key, BitKey, typename std::enable_if<::rocprim::is_signed<Key>::value>::type>
72 {
73  using bit_key_type = BitKey;
74 
75  static constexpr bit_key_type sign_bit = bit_key_type(1) << (sizeof(bit_key_type) * 8 - 1);
76 
77  ROCPRIM_DEVICE ROCPRIM_INLINE
78  static bit_key_type encode(Key key)
79  {
80  const bit_key_type bit_key = __builtin_bit_cast(bit_key_type, key);
81  return sign_bit ^ bit_key;
82  }
83 
84  ROCPRIM_DEVICE ROCPRIM_INLINE
85  static Key decode(bit_key_type bit_key)
86  {
87  bit_key ^= sign_bit;
88  return __builtin_bit_cast(Key, bit_key);
89  }
90 
91  template<bool Descending>
92  ROCPRIM_DEVICE static unsigned int
93  extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length)
94  {
95  unsigned int mask = (1u << length) - 1;
96  return static_cast<unsigned int>(bit_key >> start) & mask;
97  }
98 };
99 
// Bit-layout description of a floating point type. Only declared here;
// each supported type provides a full specialization with the members
// sign_bit, exponent, mantissa and bit_type.
template<class Key>
struct float_bit_mask;
102 
103 template<>
104 struct float_bit_mask<float>
105 {
106  static constexpr uint32_t sign_bit = 0x80000000;
107  static constexpr uint32_t exponent = 0x7F800000;
108  static constexpr uint32_t mantissa = 0x007FFFFF;
109  using bit_type = uint32_t;
110 };
111 
112 template<>
113 struct float_bit_mask<double>
114 {
115  static constexpr uint64_t sign_bit = 0x8000000000000000;
116  static constexpr uint64_t exponent = 0x7FF0000000000000;
117  static constexpr uint64_t mantissa = 0x000FFFFFFFFFFFFF;
118  using bit_type = uint64_t;
119 };
120 
121 template<>
123 {
124  static constexpr uint16_t sign_bit = 0x8000;
125  static constexpr uint16_t exponent = 0x7F80;
126  static constexpr uint16_t mantissa = 0x007F;
127  using bit_type = uint16_t;
128 };
129 
130 template<>
132 {
133  static constexpr uint16_t sign_bit = 0x8000;
134  static constexpr uint16_t exponent = 0x7C00;
135  static constexpr uint16_t mantissa = 0x03FF;
136  using bit_type = uint16_t;
137 };
138 
139 template<class Key, class BitKey>
141 {
142  using bit_key_type = BitKey;
143 
144  static constexpr bit_key_type sign_bit = float_bit_mask<Key>::sign_bit;
145 
146  ROCPRIM_DEVICE ROCPRIM_INLINE
147  static bit_key_type encode(Key key)
148  {
149  bit_key_type bit_key = __builtin_bit_cast(bit_key_type, key);
150  bit_key ^= (sign_bit & bit_key) == 0 ? sign_bit : bit_key_type(-1);
151  return bit_key;
152  }
153 
154  ROCPRIM_DEVICE ROCPRIM_INLINE
155  static Key decode(bit_key_type bit_key)
156  {
157  bit_key ^= (sign_bit & bit_key) == 0 ? bit_key_type(-1) : sign_bit;
158  return __builtin_bit_cast(Key, bit_key);
159  }
160 
161  template<bool Descending>
162  ROCPRIM_DEVICE static unsigned int
163  extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length)
164  {
165  unsigned int mask = (1u << length) - 1;
166 
167  // radix_key_codec_floating::encode() maps 0.0 to 0x8000'0000,
168  // and -0.0 to 0x7FFF'FFFF.
169  // radix_key_codec::encode() then flips the bits if descending, yielding:
170  // value | descending | ascending |
171  // ----- | ----------- | ----------- |
172  // 0.0 | 0x7FFF'FFFF | 0x8000'0000 |
173  // -0.0 | 0x8000'0000 | 0x7FFF'FFFF |
174  //
175  // For ascending sort, both should be mapped to 0x8000'0000,
176  // and for descending sort, both should be mapped to 0x7FFF'FFFF.
177  if ROCPRIM_IF_CONSTEXPR(Descending)
178  {
179  bit_key = bit_key == sign_bit ? static_cast<bit_key_type>(~sign_bit) : bit_key;
180  }
181  else
182  {
183  bit_key = bit_key == static_cast<bit_key_type>(~sign_bit) ? sign_bit : bit_key;
184  }
185  return static_cast<unsigned int>(bit_key >> start) & mask;
186  }
187 };
188 
// Primary template: chosen only when no specialization matches Key.
// The assertion depends on Key, so it fires on instantiation and rejects
// unsupported key types with a readable message.
template<class Key, class Enable = void>
struct radix_key_codec_base
{
    static_assert(sizeof(Key) == 0,
                  "Only integral and floating point types supported as radix sort keys");
};
195 
196 template<class Key>
198  Key,
199  typename std::enable_if<::rocprim::is_integral<Key>::value>::type
200 > : radix_key_codec_integral<Key, typename std::make_unsigned<Key>::type> { };
201 
202 template<>
204 {
205  using bit_key_type = unsigned char;
206 
207  ROCPRIM_DEVICE ROCPRIM_INLINE
208  static bit_key_type encode(bool key)
209  {
210  return static_cast<bit_key_type>(key);
211  }
212 
213  ROCPRIM_DEVICE ROCPRIM_INLINE
214  static bool decode(bit_key_type bit_key)
215  {
216  return static_cast<bool>(bit_key);
217  }
218 
219  template<bool Descending>
220  ROCPRIM_DEVICE static unsigned int
221  extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length)
222  {
223  unsigned int mask = (1u << length) - 1;
224  return static_cast<unsigned int>(bit_key >> start) & mask;
225  }
226 };
227 
228 template<>
229 struct radix_key_codec_base<::rocprim::half> : radix_key_codec_floating<::rocprim::half, unsigned short> { };
230 
231 template<>
232 struct radix_key_codec_base<::rocprim::bfloat16> : radix_key_codec_floating<::rocprim::bfloat16, unsigned short> { };
233 
234 template<>
235 struct radix_key_codec_base<float> : radix_key_codec_floating<float, unsigned int> { };
236 
237 template<>
238 struct radix_key_codec_base<double> : radix_key_codec_floating<double, unsigned long long> { };
239 
240 template<class Key, bool Descending = false>
241 class radix_key_codec : protected radix_key_codec_base<Key>
242 {
244 
245 public:
246  using bit_key_type = typename base_type::bit_key_type;
247 
248  ROCPRIM_DEVICE ROCPRIM_INLINE
249  static bit_key_type encode(Key key)
250  {
251  bit_key_type bit_key = base_type::encode(key);
252  return (Descending ? ~bit_key : bit_key);
253  }
254 
255  ROCPRIM_DEVICE ROCPRIM_INLINE
256  static Key decode(bit_key_type bit_key)
257  {
258  bit_key = (Descending ? ~bit_key : bit_key);
259  return base_type::decode(bit_key);
260  }
261 
262  ROCPRIM_DEVICE ROCPRIM_INLINE
263  static unsigned int extract_digit(bit_key_type bit_key, unsigned int start, unsigned int radix_bits)
264  {
265  return base_type::template extract_digit<Descending>(bit_key, start, radix_bits);
266  }
267 };
268 
269 } // end namespace detail
270 END_ROCPRIM_NAMESPACE
271 
272 #endif // ROCPRIM_DETAIL_RADIX_SORT_HPP_
Definition: radix_sort.hpp:140
Definition: radix_sort.hpp:101
Definition: test_utils_custom_float_type.hpp:110
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
Definition: radix_sort.hpp:190
::hip_bfloat16 bfloat16
bfloat16 floating point type
Definition: types.hpp:148
Definition: radix_sort.hpp:241
Definition: radix_sort.hpp:42
::__half half
Half-precision floating point type.
Definition: types.hpp:146