hipCUB
half.hpp
Go to the documentation of this file.
1 /******************************************************************************
2  * Copyright (c) 2011, Duane Merrill. All rights reserved.
3  * Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  * * Redistributions of source code must retain the above copyright
8  * notice, this list of conditions and the following disclaimer.
9  * * Redistributions in binary form must reproduce the above copyright
10  * notice, this list of conditions and the following disclaimer in the
11  * documentation and/or other materials provided with the distribution.
12  * * Neither the name of the NVIDIA CORPORATION nor the
13  * names of its contributors may be used to endorse or promote products
14  * derived from this software without specific prior written permission.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
20  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  *
27  ******************************************************************************/
28 
29 #pragma once
30 
36  #include <stdint.h>
37  #include <hipcub/util_type.hpp>
38 
39  #if defined(__HIP_PLATFORM_NVIDIA__)
40  #include <cuda_fp16.h>
41  #endif
42 
43  #include <iosfwd>
44 
45 #ifdef __GNUC__
46 // There's a ton of type-punning going on in this file.
47 #pragma GCC diagnostic push
48 #pragma GCC diagnostic ignored "-Wstrict-aliasing"
49 #endif
50 
51 
52 /******************************************************************************
53  * half_t
54  ******************************************************************************/
55 
/**
 * \brief Host-based fp16 data type compatible and convertible with __half.
 *
 * The value is stored as a raw 16-bit pattern (1 sign bit, 5 exponent bits,
 * 10 mantissa bits). Arithmetic and ordering comparisons are carried out by
 * converting to float, operating in float, and (for arithmetic) rounding the
 * result back to half precision.
 */
struct half_t
{
    /// Raw 16-bit storage of the half-precision value.
    uint16_t __x;

    /// Constructor from __half: reinterprets the storage of \p other as a 16-bit word.
    __host__ __device__ __forceinline__
    half_t(const __half &other)
    {
        __x = reinterpret_cast<const uint16_t&>(other);
    }

    /// Constructor from integer (converts through float).
    __host__ __device__ __forceinline__
    half_t(int a)
    {
        *this = half_t(float(a));
    }

    /// Default constructor (leaves the storage uninitialized).
    half_t() = default;

    /// Constructor from float, rounding to nearest with ties to even.
    __host__ __device__ __forceinline__
    half_t(float a)
    {
        // Stolen from Norbert Juffa
        uint32_t ia = *reinterpret_cast<uint32_t*>(&a);
        uint16_t ir;

        // Move the float sign bit (bit 31) into the half sign bit (bit 15).
        ir = (ia >> 16) & 0x8000;

        if ((ia & 0x7f800000) == 0x7f800000)
        {
            // Float exponent is all ones: infinity or NaN.
            if ((ia & 0x7fffffff) == 0x7f800000)
            {
                ir |= 0x7c00; /* infinity */
            }
            else
            {
                ir = 0x7fff; /* canonical NaN */
            }
        }
        else if ((ia & 0x7f800000) >= 0x33000000)
        {
            // Magnitude is at least 2^-25; anything smaller falls through and
            // becomes a signed zero.
            int32_t shift = (int32_t) ((ia >> 23) & 0xff) - 127; // unbiased exponent
            if (shift > 15)
            {
                ir |= 0x7c00; /* infinity */
            }
            else
            {
                ia = (ia & 0x007fffff) | 0x00800000; /* extract mantissa */
                if (shift < -14)
                { /* denormal */
                    ir |= ia >> (-1 - shift);
                    ia = ia << (32 - (-1 - shift));
                }
                else
                { /* normal */
                    ir |= ia >> (24 - 11);
                    ia = ia << (32 - (24 - 11));
                    ir = static_cast<uint16_t>(ir + ((14 + shift) << 10));
                }
                /* IEEE-754 round to nearest, ties to even.
                 * ia now holds the discarded mantissa bits left-aligned in 32 bits,
                 * so 0x80000000 is exactly one half of the last kept bit. */
                if ((ia > 0x80000000) || ((ia == 0x80000000) && (ir & 1)))
                {
                    ir++;
                }
            }
        }

        this->__x = ir;
    }

    /// Cast to float.
    __host__ __device__ __forceinline__
    operator float() const
    {
        // Stolen from Andrew Kerr

        int sign = ((this->__x >> 15) & 1);
        int exp = ((this->__x >> 10) & 0x1f);
        int mantissa = (this->__x & 0x3ff);
        uint32_t f = 0;

        if (exp > 0 && exp < 31)
        {
            // normal: rebias the exponent from 15 to 127
            exp += 112;
            f = (sign << 31) | (exp << 23) | (mantissa << 13);
        }
        else if (exp == 0)
        {
            if (mantissa)
            {
                // subnormal: shift left until the implicit leading bit appears,
                // lowering the exponent by one for every shift
                exp += 113;
                while ((mantissa & (1 << 10)) == 0)
                {
                    mantissa <<= 1;
                    exp--;
                }
                mantissa &= 0x3ff;
                f = (sign << 31) | (exp << 23) | (mantissa << 13);
            }
            else if (sign)
            {
                f = 0x80000000; // negative zero
            }
            else
            {
                f = 0x0; // zero
            }
        }
        else if (exp == 31)
        {
            if (mantissa)
            {
                f = 0x7fffffff | (sign << 31); // not a number
            }
            else
            {
                f = (0xff << 23) | (sign << 31); // inf
            }
        }
        return *reinterpret_cast<float const *>(&f);
    }


    /// Get raw storage.
    __host__ __device__ __forceinline__
    uint16_t raw() const
    {
        return this->__x;
    }

    /// Equality. Compares raw bit patterns, so +0 and -0 compare unequal and
    /// two NaNs with identical bits compare equal.
    __host__ __device__ __forceinline__
    bool operator ==(const half_t &other) const
    {
        return (this->__x == other.__x);
    }

    /// Inequality. Compares raw bit patterns (see operator==).
    __host__ __device__ __forceinline__
    bool operator !=(const half_t &other) const
    {
        return (this->__x != other.__x);
    }

    /// Assignment by sum (computed in float, rounded back to half).
    __host__ __device__ __forceinline__
    half_t& operator +=(const half_t &rhs)
    {
        *this = half_t(float(*this) + float(rhs));
        return *this;
    }

    /// Multiply (computed in float, rounded back to half).
    /// Const-qualified so that const operands use this overload instead of
    /// silently falling back to built-in float arithmetic.
    __host__ __device__ __forceinline__
    half_t operator*(const half_t &other) const
    {
        return half_t(float(*this) * float(other));
    }

    /// Divide (computed in float, rounded back to half).
    __host__ __device__ __forceinline__ half_t operator/(const half_t& other) const
    {
        return half_t(float(*this) / float(other));
    }

    /// Add (computed in float, rounded back to half).
    /// Const-qualified for the same reason as operator*.
    __host__ __device__ __forceinline__
    half_t operator+(const half_t &other) const
    {
        return half_t(float(*this) + float(other));
    }

    /// Subtract (computed in float, rounded back to half).
    /// Const-qualified for the same reason as operator*.
    __host__ __device__ __forceinline__
    half_t operator-(const half_t &other) const
    {
        return half_t(float(*this) - float(other));
    }

    /// Less-than (compared as floats).
    __host__ __device__ __forceinline__
    bool operator<(const half_t &other) const
    {
        return float(*this) < float(other);
    }

    /// Less-than-equal (compared as floats).
    __host__ __device__ __forceinline__
    bool operator<=(const half_t &other) const
    {
        return float(*this) <= float(other);
    }

    /// Greater-than (compared as floats).
    __host__ __device__ __forceinline__
    bool operator>(const half_t &other) const
    {
        return float(*this) > float(other);
    }

    /// Greater-than-equal (compared as floats).
    __host__ __device__ __forceinline__
    bool operator>=(const half_t &other) const
    {
        return float(*this) >= float(other);
    }

    /// numeric_traits<half_t>::max — bit pattern 0x7BFF (sign 0, exponent 30,
    /// mantissa all ones).
    __host__ __device__ __forceinline__
    static half_t max() {
        half_t result;
        result.__x = 0x7BFF;
        return result;
    }

    /// numeric_traits<half_t>::lowest — bit pattern 0xFBFF (max() with the
    /// sign bit set).
    __host__ __device__ __forceinline__
    static half_t lowest() {
        half_t result;
        result.__x = 0xFBFF;
        return result;
    }
};
286 
287 
288 /******************************************************************************
289  * I/O stream overloads
290  ******************************************************************************/
291 
293 inline std::ostream& operator<<(std::ostream &out, const half_t &x)
294 {
295  out << (float)x;
296  return out;
297 }
298 
299 /******************************************************************************
300  * Traits overloads
301  ******************************************************************************/
302 
303 template <>
304 struct hipcub::FpLimits<half_t>
305 {
306  static __host__ __device__ __forceinline__ half_t Max() { return half_t::max(); }
307 
308  static __host__ __device__ __forceinline__ half_t Lowest() { return half_t::lowest(); }
309 };
310 
311 template <> struct hipcub::NumericTraits<half_t> : hipcub::BaseTraits<FLOATING_POINT, true, false, unsigned short, half_t> {};
312 
313 
314 #ifdef __GNUC__
315 #pragma GCC diagnostic pop
316 #endif
__host__ __device__ __forceinline__ bool operator>(const half_t &other) const
Greater-than.
Definition: half.hpp:260
__host__ __device__ __forceinline__ bool operator>=(const half_t &other) const
Greater-than-equal.
Definition: half.hpp:267
__host__ __device__ __forceinline__ half_t(int a)
Constructor from integer.
Definition: half.hpp:72
__host__ __device__ __forceinline__ half_t operator+(const half_t &other)
Add.
Definition: half.hpp:232
__host__ __device__ __forceinline__ uint16_t raw() const
Get raw storage.
Definition: half.hpp:190
__host__ __device__ __forceinline__ bool operator<=(const half_t &other) const
Less-than-equal.
Definition: half.hpp:253
half_t()=default
Default constructor.
Definition: half.hpp:78
__host__ __device__ __forceinline__ half_t & operator+=(const half_t &rhs)
Assignment by sum.
Definition: half.hpp:211
std::ostream & operator<<(std::ostream &out, const half_t &x)
Insert formatted half_t into the output stream.
Definition: half.hpp:293
__host__ __device__ __forceinline__ half_t(float a)
Constructor from float.
Definition: half.hpp:82
Host-based fp16 data type compatible and convertible with __half.
Definition: half.hpp:59
__host__ __device__ __forceinline__ half_t(const __half &other)
Constructor from __half.
Definition: half.hpp:65
__host__ __device__ static __forceinline__ half_t lowest()
numeric_traits<half_t>::lowest
Definition: half.hpp:281
__host__ __device__ __forceinline__ bool operator==(const half_t &other) const
Equality.
Definition: half.hpp:197
__host__ __device__ __forceinline__ bool operator!=(const half_t &other) const
Inequality.
Definition: half.hpp:204
__host__ __device__ __forceinline__ half_t operator/(const half_t &other) const
Divide.
Definition: half.hpp:225
__host__ __device__ __forceinline__ half_t operator-(const half_t &other)
Subtract.
Definition: half.hpp:239
__host__ __device__ static __forceinline__ half_t max()
numeric_traits<half_t>::max
Definition: half.hpp:274
__host__ __device__ __forceinline__ bool operator<(const half_t &other) const
Less-than.
Definition: half.hpp:246
__host__ __device__ __forceinline__ half_t operator*(const half_t &other)
Multiply.
Definition: half.hpp:219