21 #ifndef ROCPRIM_DETAIL_RADIX_SORT_HPP_ 22 #define ROCPRIM_DETAIL_RADIX_SORT_HPP_ 24 #include <type_traits> 26 #include "../config.hpp" 27 #include "../type_traits.hpp" 29 BEGIN_ROCPRIM_NAMESPACE
// Primary template of a key codec parameterized on the user key type (Key), a
// bit type of the same size (BitKey; the size must match because encode()
// below bit-casts Key to it) and an Enable selector defaulting to void.
// NOTE(review): the struct name and body (original lines 42-43) are missing
// from this listing.
41 template<
class Key,
class BitKey,
class Enable =
void>
// Header of a specialization (original line 44); its struct name and selecting
// condition (original lines 45-46) were dropped. Its members reuse the key bits
// unchanged, presumably for unsigned integer keys. Confirm against the header.
44 template<
class Key,
class BitKey>
// Integer type the key is bit-cast to; all digits are read from this value.
47 using bit_key_type = BitKey;
49 ROCPRIM_DEVICE ROCPRIM_INLINE
50 static bit_key_type encode(Key key)
52 return __builtin_bit_cast(bit_key_type, key);
55 ROCPRIM_DEVICE ROCPRIM_INLINE
56 static Key decode(bit_key_type bit_key)
58 return __builtin_bit_cast(Key, bit_key);
61 template<
bool Descending>
62 ROCPRIM_DEVICE
static unsigned int 63 extract_digit(bit_key_type bit_key,
unsigned int start,
unsigned int length)
65 unsigned int mask = (1u << length) - 1;
66 return static_cast<unsigned int>(bit_key >> start) & mask;
// Header of a specialization (original line 70); the struct name and selecting
// condition (original lines 71-72) are missing from this listing. Its encode()
// toggles the most significant bit, so this is presumably the codec for signed
// integer keys.
70 template<
class Key,
class BitKey>
73 using bit_key_type = BitKey;
// Mask with only the most significant bit of bit_key_type set.
75 static constexpr bit_key_type sign_bit = bit_key_type(1) << (
sizeof(bit_key_type) * 8 - 1);
77 ROCPRIM_DEVICE ROCPRIM_INLINE
78 static bit_key_type encode(Key key)
80 const bit_key_type bit_key = __builtin_bit_cast(bit_key_type, key);
81 return sign_bit ^ bit_key;
84 ROCPRIM_DEVICE ROCPRIM_INLINE
85 static Key decode(bit_key_type bit_key)
88 return __builtin_bit_cast(Key, bit_key);
91 template<
bool Descending>
92 ROCPRIM_DEVICE
static unsigned int 93 extract_digit(bit_key_type bit_key,
unsigned int start,
unsigned int length)
95 unsigned int mask = (1u << length) - 1;
96 return static_cast<unsigned int>(bit_key >> start) & mask;
// 32-bit pattern: 1 sign, 8 exponent and 23 mantissa bits (IEEE-754 binary32
// layout). NOTE(review): the enclosing struct header is missing from this listing.
106 static constexpr uint32_t sign_bit = 0x80000000;
107 static constexpr uint32_t exponent = 0x7F800000;
108 static constexpr uint32_t mantissa = 0x007FFFFF;
109 using bit_type = uint32_t;
// 64-bit pattern: 1 sign, 11 exponent and 52 mantissa bits (IEEE-754 binary64 layout).
115 static constexpr uint64_t sign_bit = 0x8000000000000000;
116 static constexpr uint64_t exponent = 0x7FF0000000000000;
117 static constexpr uint64_t mantissa = 0x000FFFFFFFFFFFFF;
118 using bit_type = uint64_t;
// 16-bit pattern: 1 sign, 8 exponent and 7 mantissa bits (bfloat16 layout).
124 static constexpr uint16_t sign_bit = 0x8000;
125 static constexpr uint16_t exponent = 0x7F80;
126 static constexpr uint16_t mantissa = 0x007F;
127 using bit_type = uint16_t;
// 16-bit pattern: 1 sign, 5 exponent and 10 mantissa bits (IEEE-754 binary16 layout).
133 static constexpr uint16_t sign_bit = 0x8000;
134 static constexpr uint16_t exponent = 0x7C00;
135 static constexpr uint16_t mantissa = 0x03FF;
136 using bit_type = uint16_t;
139 template<
class Key,
class BitKey>
142 using bit_key_type = BitKey;
146 ROCPRIM_DEVICE ROCPRIM_INLINE
147 static bit_key_type encode(Key key)
149 bit_key_type bit_key = __builtin_bit_cast(bit_key_type, key);
150 bit_key ^= (sign_bit & bit_key) == 0 ? sign_bit : bit_key_type(-1);
154 ROCPRIM_DEVICE ROCPRIM_INLINE
155 static Key decode(bit_key_type bit_key)
157 bit_key ^= (sign_bit & bit_key) == 0 ? bit_key_type(-1) : sign_bit;
158 return __builtin_bit_cast(Key, bit_key);
161 template<
bool Descending>
162 ROCPRIM_DEVICE
static unsigned int 163 extract_digit(bit_key_type bit_key,
unsigned int start,
unsigned int length)
165 unsigned int mask = (1u << length) - 1;
177 if ROCPRIM_IF_CONSTEXPR(Descending)
179 bit_key = bit_key == sign_bit ?
static_cast<bit_key_type
>(~sign_bit) : bit_key;
183 bit_key = bit_key ==
static_cast<bit_key_type
>(~sign_bit) ? sign_bit : bit_key;
185 return static_cast<unsigned int>(bit_key >> start) & mask;
// Fallback primary template (struct name missing from this listing). Any
// instantiation fails to compile with the message below. `sizeof(Key) == 0`
// depends on Key, so the assertion fires only when the template is
// instantiated, not when it is parsed (a complete type never has size 0).
189 template<
class Key,
class Enable =
void>
192 static_assert(
sizeof(Key) == 0,
193 "Only integral and floating point types supported as radix sort keys");
// Fragment of a specialization header (original line 199), selected when
// ::rocprim::is_integral<Key>::value is true. The surrounding lines of that
// declaration are missing from this listing.
199 typename std::enable_if<::rocprim::is_integral<Key>::value>::type
// Bit type for the codec whose encode/decode below take and return bool.
// This is presumably a bool specialization; its header is not visible here.
205 using bit_key_type =
unsigned char;
207 ROCPRIM_DEVICE ROCPRIM_INLINE
208 static bit_key_type encode(
bool key)
210 return static_cast<bit_key_type
>(key);
213 ROCPRIM_DEVICE ROCPRIM_INLINE
214 static bool decode(bit_key_type bit_key)
216 return static_cast<bool>(bit_key);
219 template<
bool Descending>
220 ROCPRIM_DEVICE
static unsigned int 221 extract_digit(bit_key_type bit_key,
unsigned int start,
unsigned int length)
223 unsigned int mask = (1u << length) - 1;
224 return static_cast<unsigned int>(bit_key >> start) & mask;
// Public codec. It delegates to base_type, whose declaration (original lines
// 241-245) is missing from this listing. When Descending is true it inverts
// every bit of the encoded key, so sorting encoded keys in ascending unsigned
// order yields the original keys in descending order.
240 template<
class Key,
bool Descending = false>
// Same bit type as the underlying base codec.
246 using bit_key_type =
typename base_type::bit_key_type;
248 ROCPRIM_DEVICE ROCPRIM_INLINE
249 static bit_key_type encode(Key key)
251 bit_key_type bit_key = base_type::encode(key);
252 return (Descending ? ~bit_key : bit_key);
255 ROCPRIM_DEVICE ROCPRIM_INLINE
256 static Key decode(bit_key_type bit_key)
258 bit_key = (Descending ? ~bit_key : bit_key);
259 return base_type::decode(bit_key);
262 ROCPRIM_DEVICE ROCPRIM_INLINE
263 static unsigned int extract_digit(bit_key_type bit_key,
unsigned int start,
unsigned int radix_bits)
265 return base_type::template extract_digit<Descending>(bit_key, start, radix_bits);
270 END_ROCPRIM_NAMESPACE
272 #endif // ROCPRIM_DETAIL_RADIX_SORT_HPP_ Definition: radix_sort.hpp:140
Definition: radix_sort.hpp:101
Definition: test_utils_custom_float_type.hpp:110
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
Definition: radix_sort.hpp:190
::hip_bfloat16 bfloat16
bfloat16 floating point type
Definition: types.hpp:148
Definition: radix_sort.hpp:241
Definition: radix_sort.hpp:42
::__half half
Half-precision floating point type.
Definition: types.hpp:146