rocPRIM
test_utils_sort_comparator.hpp
1 // MIT License
2 //
3 // Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
4 //
5 // Permission is hereby granted, free of charge, to any person obtaining a copy
6 // of this software and associated documentation files (the "Software"), to deal
7 // in the Software without restriction, including without limitation the rights
8 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 // copies of the Software, and to permit persons to whom the Software is
10 // furnished to do so, subject to the following conditions:
11 //
12 // The above copyright notice and this permission notice shall be included in all
13 // copies or substantial portions of the Software.
14 //
15 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 // SOFTWARE.
22 
23 #ifndef TEST_UTILS_SORT_COMPARATOR_HPP_
24 #define TEST_UTILS_SORT_COMPARATOR_HPP_
25 
#include <cstring>

#include <rocprim/type_traits.hpp>

#include "test_utils_half.hpp"
#include "test_utils_bfloat16.hpp"
30 
31 namespace test_utils
32 {
33 
// Host-side NaN test. Returns true exactly when `value` compares unequal to
// itself, which for IEEE floating-point values is the case only for NaN.
// The check needs nothing from T except operator!=.
template<class T>
constexpr bool is_floating_nan_host(const T& value)
{
    return value != value;
}
39 
// Primary template of the sort comparator. It is intentionally empty: a usable
// operator() is only provided by the specializations of key_comparator.
//
// Key        - type of the keys being compared.
// Descending - when true, the specializations order keys from high to low.
// StartBit   - index of the lowest key bit that takes part in the comparison.
// EndBit     - one past the highest key bit that takes part in the comparison.
// Enable     - SFINAE hook that partial specializations use to select
//              themselves for a category of key types.
template<class Key,
         bool         Descending,
         unsigned int StartBit,
         unsigned int EndBit,
         class Enable = void>
struct key_comparator
{};
47 
48 template<class Key, bool Descending, unsigned int StartBit, unsigned int EndBit>
49 struct key_comparator<Key,
50  Descending,
51  StartBit,
52  EndBit,
53  typename std::enable_if<rocprim::is_integral<Key>::value>::type>
54 {
55  static constexpr Key radix_mask_upper
56  = EndBit == 8 * sizeof(Key) ? ~Key(0) : (Key(1) << EndBit) - 1;
57  static constexpr Key radix_mask_bottom = (Key(1) << StartBit) - 1;
58  static constexpr Key radix_mask = radix_mask_upper ^ radix_mask_bottom;
59 
60  bool operator()(const Key& lhs, const Key& rhs) const
61  {
62  Key l = lhs & radix_mask;
63  Key r = rhs & radix_mask;
64  return Descending ? (r < l) : (l < r);
65  }
66 };
67 
68 template<class Key, bool Descending, unsigned int StartBit, unsigned int EndBit>
69 struct key_comparator<Key,
70  Descending,
71  StartBit,
72  EndBit,
73  typename std::enable_if<rocprim::is_floating_point<Key>::value>::type>
74 {
75  using unsigned_bits_type = typename rocprim::get_unsigned_bits_type<Key>::unsigned_type;
76 
77  bool operator()(const Key& lhs, const Key& rhs) const
78  {
80  this->to_bits(lhs),
81  this->to_bits(rhs));
82  }
83 
84  unsigned_bits_type to_bits(const Key& key) const
85  {
86  unsigned_bits_type bit_key;
87  memcpy(&bit_key, &key, sizeof(Key));
88 
89  // Remove signed zero, this case is supposed to be treated the same as
90  // unsigned zero in rocprim sorting algorithms.
91  constexpr unsigned_bits_type minus_zero = unsigned_bits_type{1} << (8 * sizeof(Key) - 1);
92  // Positive and negative zero should compare the same.
93  if(bit_key == minus_zero)
94  {
95  bit_key = 0;
96  }
97  // Flip bits mantissa and exponent if the key is negative, so as to make
98  // 'more negative' values compare before 'less negative'.
99  if(bit_key & minus_zero)
100  {
101  bit_key ^= ~minus_zero;
102  }
103  // Make negatives compare before positives.
104  bit_key ^= minus_zero;
105  return bit_key;
106  }
107 };
108 
109 template<class Key, class Value, bool Descending, unsigned int StartBit, unsigned int EndBit>
111 {
112  bool operator()(const std::pair<Key, Value>& lhs, const std::pair<Key, Value>& rhs)
113  {
114  return key_comparator<Key, Descending, StartBit, EndBit>()(lhs.first, rhs.first);
115  }
116 };
117 
118 template <bool Descending>
119 struct key_comparator<rocprim::half, Descending, 0, sizeof(rocprim::half) * 8>
120 {
121  bool operator()(const rocprim::half& lhs, const rocprim::half& rhs)
122  {
123  // HIP's half doesn't have __host__ comparison operators, use floats instead
125  }
126 };
127 
128 template <bool Descending>
129 struct key_comparator<rocprim::bfloat16, Descending, 0, sizeof(rocprim::bfloat16) * 8>
130 {
131  bool operator()(const rocprim::bfloat16& lhs, const rocprim::bfloat16& rhs)
132  {
133  // HIP's bfloat16 doesn't have __host__ comparison operators, use floats instead
135  }
136 };
137 
138 }
139 #endif // TEST_UTILS_SORT_COMPARATOR_HPP_
Definition: test_utils_sort_comparator.hpp:110
Definition: test_utils_custom_float_type.hpp:110
::hip_bfloat16 bfloat16
bfloat16 floating point type
Definition: types.hpp:148
Definition: test_utils_sort_comparator.hpp:45
Definition: bounds_checking_iterator.hpp:24
::__half half
Half-precision floating point type.
Definition: types.hpp:146