rocPRIM
types.hpp
1 // Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_TYPES_HPP_
22 #define ROCPRIM_TYPES_HPP_
23 
24 #include <type_traits>
25 
26 // Meta configuration for rocPRIM
27 #include "config.hpp"
28 
29 #include "types/future_value.hpp"
30 #include "types/double_buffer.hpp"
31 #include "types/integer_sequence.hpp"
32 #include "types/key_value_pair.hpp"
33 #include "types/tuple.hpp"
34 
37 
38 BEGIN_ROCPRIM_NAMESPACE
39 
namespace detail
{
// Define vector types that will be used by rocPRIM internally.
// We don't use HIP vector types because they don't generate correct
// load/store operations, see https://github.com/RadeonOpenCompute/ROCm/issues/341
//
// For every DEFINE_VECTOR_TYPE(name, base) this generates `name2` and `name4`.
// Each overlays N packed `base` values (`data`) with named components. The
// components use the conventional x, y, z, w order: `.z` is element 2 and
// `.w` is element 3, matching the usual HIP/CUDA vector types.
#ifndef _MSC_VER
    // Clang extended vector: `data` is a true N-wide vector value.
    #define ROCPRIM_DETAIL_EXT_VECTOR(n) __attribute__((ext_vector_type(n)))
#else
    // MSVC has no ext_vector_type. `data` degrades to a single scalar, while
    // size and alignment still come from alignas and the anonymous struct.
    #define ROCPRIM_DETAIL_EXT_VECTOR(n)
#endif

// The final struct deliberately omits its trailing ';' so that the call
// site's ';' terminates it (no stray empty declaration).
#define DEFINE_VECTOR_TYPE(name, base)                               \
    struct alignas(sizeof(base) * 2) name##2                         \
    {                                                                \
        typedef base vector_value_type ROCPRIM_DETAIL_EXT_VECTOR(2); \
        union                                                        \
        {                                                            \
            vector_value_type data;                                  \
            struct                                                   \
            {                                                        \
                base x, y;                                           \
            };                                                       \
        };                                                           \
    };                                                               \
                                                                     \
    struct alignas(sizeof(base) * 4) name##4                         \
    {                                                                \
        typedef base vector_value_type ROCPRIM_DETAIL_EXT_VECTOR(4); \
        union                                                        \
        {                                                            \
            vector_value_type data;                                  \
            struct                                                   \
            {                                                        \
                base x, y, z, w;                                     \
            };                                                       \
        };                                                           \
    }

#ifdef _MSC_VER
    #pragma warning(push)
    #pragma warning(disable : 4201) // nonstandard extension used: nameless struct/union
#endif
DEFINE_VECTOR_TYPE(char, char);
DEFINE_VECTOR_TYPE(short, short);
DEFINE_VECTOR_TYPE(int, int);
DEFINE_VECTOR_TYPE(longlong, long long);
#ifdef _MSC_VER
    #pragma warning(pop)
#endif

// Takes a scalar type T and matches to a vector type based on NumElements.
// Combinations without a specialization below yield `type = void`.
template<class T, unsigned int NumElements>
struct make_vector_type
{
    using type = void;
};

// Maps (base, suffix) to the generated `name##suffix` vector type.
// Like DEFINE_VECTOR_TYPE, it is terminated by the caller's ';'.
#define DEFINE_MAKE_VECTOR_N_TYPE(name, base, suffix) \
    template<>                                        \
    struct make_vector_type<base, suffix>             \
    {                                                 \
        using type = name##suffix;                    \
    }

// One element maps to the scalar itself; 2 and 4 map to the vector types.
#define DEFINE_MAKE_VECTOR_TYPE(name, base)   \
    template<>                                \
    struct make_vector_type<base, 1>          \
    {                                         \
        using type = base;                    \
    };                                        \
    DEFINE_MAKE_VECTOR_N_TYPE(name, base, 2); \
    DEFINE_MAKE_VECTOR_N_TYPE(name, base, 4)

DEFINE_MAKE_VECTOR_TYPE(char, char);
DEFINE_MAKE_VECTOR_TYPE(short, short);
DEFINE_MAKE_VECTOR_TYPE(int, int);
DEFINE_MAKE_VECTOR_TYPE(longlong, long long);

#undef DEFINE_VECTOR_TYPE
#undef DEFINE_MAKE_VECTOR_TYPE
#undef DEFINE_MAKE_VECTOR_N_TYPE
#undef ROCPRIM_DETAIL_EXT_VECTOR

} // end namespace detail
132 
/// \brief Empty type used as a placeholder, usually used to flag that a given
/// template parameter should not be used.
struct empty_type
{
};

/// \brief Binary operator that takes two instances of empty_type and returns a
/// fresh empty_type, i.e. a no-op binary operator over placeholders.
struct empty_binary_op
{
    /// \brief Invocation operator.
    constexpr empty_type operator()(const empty_type&, const empty_type&) const
    {
        return empty_type{};
    }
};
144 
146 using half = ::__half;
148 using bfloat16 = ::hip_bfloat16;
149 
// __AMDGCN_WAVEFRONT_SIZE is expected to come from the device compiler.
// When it is absent (per the original note: not compiling with hipcc, i.e.
// compiling with HIP-CPU), fall back to a 64-wide wavefront.
#ifndef __AMDGCN_WAVEFRONT_SIZE
// TODO: introduce a ROCPRIM-specific macro to query this
#define __AMDGCN_WAVEFRONT_SIZE 64
#endif

/// \brief The lane_mask_type is an integer that contains one bit per thread.
/// Its width follows __AMDGCN_WAVEFRONT_SIZE (32 or 64 lanes).
#if __AMDGCN_WAVEFRONT_SIZE == 32
using lane_mask_type = unsigned int;
#elif __AMDGCN_WAVEFRONT_SIZE == 64
using lane_mask_type = unsigned long long int;
#else
// Fail loudly here instead of leaving lane_mask_type undeclared, which would
// only surface later as confusing errors at its use sites.
#error "rocPRIM: unsupported __AMDGCN_WAVEFRONT_SIZE (expected 32 or 64)"
#endif
166 
168 #ifdef __HIP_CPU_RT__
169 using native_half = half;
170 #else
171 using native_half = _Float16;
172 #endif
173 
175 #ifdef __HIP_CPU_RT__
176 // TODO: Find a better type
177 using native_bfloat16 = bfloat16;
178 #else
180 #endif
181 
182 END_ROCPRIM_NAMESPACE
183 
185 // end of group utilsmodule
186 
187 #endif // ROCPRIM_TYPES_HPP_
Empty type used as a placeholder, usually used to flag that a given template parameter should not be used.
Definition: types.hpp:135
Definition: types.hpp:100
constexpr empty_type operator()(const empty_type &, const empty_type &) const
Invocation operator.
Definition: types.hpp:142
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
::hip_bfloat16 bfloat16
bfloat16 floating point type
Definition: types.hpp:148
bfloat16 native_bfloat16
Native bfloat16 type.
Definition: types.hpp:179
_Float16 native_half
Native half-precision floating point type.
Definition: types.hpp:171
unsigned long long int lane_mask_type
The lane_mask_type is an integer that contains one bit per thread.
Definition: types.hpp:164
Binary operator that takes two instances of empty_type, usually used as nop replacement for the HIP-C...
Definition: types.hpp:139
::__half half
Half-precision floating point type.
Definition: types.hpp:146