23 #ifndef TEST_UTILS_SORT_COMPARATOR_HPP_ 24 #define TEST_UTILS_SORT_COMPARATOR_HPP_ 26 #include <rocprim/type_traits.hpp> 28 #include "test_utils_half.hpp" 29 #include "test_utils_bfloat16.hpp" 35 constexpr
bool is_floating_nan_host(
const T& a)
42 unsigned int StartBit,
48 template<
class Key,
bool Descending,
unsigned int StartBit,
unsigned int EndBit>
53 typename std::enable_if<rocprim::is_integral<Key>::value>::type>
55 static constexpr Key radix_mask_upper
56 = EndBit == 8 *
sizeof(Key) ? ~Key(0) : (Key(1) << EndBit) - 1;
57 static constexpr Key radix_mask_bottom = (Key(1) << StartBit) - 1;
58 static constexpr Key radix_mask = radix_mask_upper ^ radix_mask_bottom;
60 bool operator()(
const Key& lhs,
const Key& rhs)
const 62 Key l = lhs & radix_mask;
63 Key r = rhs & radix_mask;
64 return Descending ? (r < l) : (l < r);
68 template<
class Key,
bool Descending,
unsigned int StartBit,
unsigned int EndBit>
73 typename std::enable_if<rocprim::is_floating_point<Key>::value>::type>
75 using unsigned_bits_type =
typename rocprim::get_unsigned_bits_type<Key>::unsigned_type;
77 bool operator()(
const Key& lhs,
const Key& rhs)
const 84 unsigned_bits_type to_bits(
const Key& key)
const 86 unsigned_bits_type bit_key;
87 memcpy(&bit_key, &key,
sizeof(Key));
91 constexpr unsigned_bits_type minus_zero = unsigned_bits_type{1} << (8 *
sizeof(Key) - 1);
93 if(bit_key == minus_zero)
99 if(bit_key & minus_zero)
101 bit_key ^= ~minus_zero;
104 bit_key ^= minus_zero;
109 template<
class Key,
class Value,
bool Descending,
unsigned int StartBit,
unsigned int EndBit>
112 bool operator()(
const std::pair<Key, Value>& lhs,
const std::pair<Key, Value>& rhs)
118 template <
bool Descending>
128 template <
bool Descending>
139 #endif // TEST_UTILS_SORT_COMPARATOR_HPP_ Definition: test_utils_sort_comparator.hpp:110
Definition: test_utils_custom_float_type.hpp:110
::hip_bfloat16 bfloat16
bfloat16 floating point type
Definition: types.hpp:148
Definition: test_utils_sort_comparator.hpp:45
Definition: bounds_checking_iterator.hpp:24
::__half half
Half-precision floating point type.
Definition: types.hpp:146