37 #include <hipcub/util_type.hpp> 39 #if defined(__HIP_PLATFORM_NVIDIA__) 40 #include <cuda_fp16.h> 47 #pragma GCC diagnostic push 48 #pragma GCC diagnostic ignored "-Wstrict-aliasing" 64 __host__ __device__ __forceinline__
67 __x =
reinterpret_cast<const uint16_t&
>(other);
71 __host__ __device__ __forceinline__
81 __host__ __device__ __forceinline__
85 uint32_t ia = *
reinterpret_cast<uint32_t*
>(&a);
88 ir = (ia >> 16) & 0x8000;
90 if ((ia & 0x7f800000) == 0x7f800000)
92 if ((ia & 0x7fffffff) == 0x7f800000)
101 else if ((ia & 0x7f800000) >= 0x33000000)
103 int32_t shift = (int32_t) ((ia >> 23) & 0xff) - 127;
110 ia = (ia & 0x007fffff) | 0x00800000;
113 ir |= ia >> (-1 - shift);
114 ia = ia << (32 - (-1 - shift));
118 ir |= ia >> (24 - 11);
119 ia = ia << (32 - (24 - 11));
120 ir =
static_cast<uint16_t
>(ir + ((14 + shift) << 10));
123 if ((ia > 0x80000000) || ((ia == 0x80000000) && (ir & 1)))
134 __host__ __device__ __forceinline__
135 operator float()
const 139 int sign = ((this->__x >> 15) & 1);
140 int exp = ((this->__x >> 10) & 0x1f);
141 int mantissa = (this->__x & 0x3ff);
144 if (exp > 0 && exp < 31)
148 f = (sign << 31) | (exp << 23) | (mantissa << 13);
156 while ((mantissa & (1 << 10)) == 0)
162 f = (sign << 31) | (exp << 23) | (mantissa << 13);
177 f = 0x7fffffff | (sign << 31);
181 f = (0xff << 23) | (sign << 31);
184 return *
reinterpret_cast<float const *
>(&f);
189 __host__ __device__ __forceinline__
196 __host__ __device__ __forceinline__
199 return (this->__x == other.__x);
203 __host__ __device__ __forceinline__
206 return (this->__x != other.__x);
210 __host__ __device__ __forceinline__
213 *
this =
half_t(
float(*
this) +
float(rhs));
218 __host__ __device__ __forceinline__
221 return half_t(
float(*
this) *
float(other));
227 return half_t(
float(*
this) /
float(other));
231 __host__ __device__ __forceinline__
234 return half_t(
float(*
this) +
float(other));
238 __host__ __device__ __forceinline__
241 return half_t(
float(*
this) -
float(other));
245 __host__ __device__ __forceinline__
248 return float(*
this) < float(other);
252 __host__ __device__ __forceinline__
255 return float(*
this) <= float(other);
259 __host__ __device__ __forceinline__
262 return float(*
this) > float(other);
266 __host__ __device__ __forceinline__
269 return float(*
this) >= float(other);
273 __host__ __device__ __forceinline__
275 uint16_t max_word = 0x7BFF;
276 return reinterpret_cast<half_t&
>(max_word);
280 __host__ __device__ __forceinline__
282 uint16_t lowest_word = 0xFBFF;
283 return reinterpret_cast<half_t&
>(lowest_word);
311 template <>
struct hipcub::NumericTraits<
half_t> : hipcub::BaseTraits<FLOATING_POINT, true, false, unsigned short, half_t> {};
315 #pragma GCC diagnostic pop __host__ __device__ __forceinline__ bool operator>(const half_t &other) const
Greater-than.
Definition: half.hpp:260
__host__ __device__ __forceinline__ bool operator>=(const half_t &other) const
Greater-than-equal.
Definition: half.hpp:267
__host__ __device__ __forceinline__ half_t(int a)
Constructor from integer.
Definition: half.hpp:72
__host__ __device__ __forceinline__ half_t operator+(const half_t &other)
Add.
Definition: half.hpp:232
__host__ __device__ __forceinline__ uint16_t raw() const
Get raw storage.
Definition: half.hpp:190
__host__ __device__ __forceinline__ bool operator<=(const half_t &other) const
Less-than-equal.
Definition: half.hpp:253
half_t()=default
Default constructor.
Definition: thread_operators.hpp:105
__host__ __device__ __forceinline__ half_t & operator+=(const half_t &rhs)
Assignment by sum.
Definition: half.hpp:211
std::ostream & operator<<(std::ostream &out, const half_t &x)
Insert formatted half_t into the output stream.
Definition: half.hpp:293
__host__ __device__ __forceinline__ half_t(float a)
Constructor from float.
Definition: half.hpp:82
Host-based fp16 data type compatible and convertible with __half.
Definition: half.hpp:59
__host__ __device__ __forceinline__ half_t(const __half &other)
Constructor from __half.
Definition: half.hpp:65
__host__ __device__ static __forceinline__ half_t lowest()
numeric_traits<half_t>::lowest
Definition: half.hpp:281
__host__ __device__ __forceinline__ bool operator==(const half_t &other) const
Equality.
Definition: half.hpp:197
__host__ __device__ __forceinline__ bool operator!=(const half_t &other) const
Inequality.
Definition: half.hpp:204
__host__ __device__ __forceinline__ half_t operator/(const half_t &other) const
Divide.
Definition: half.hpp:225
__host__ __device__ __forceinline__ half_t operator-(const half_t &other)
Subtract.
Definition: half.hpp:239
__host__ __device__ static __forceinline__ half_t max()
numeric_traits<half_t>::max
Definition: half.hpp:274
__host__ __device__ __forceinline__ bool operator<(const half_t &other) const
Less-than.
Definition: half.hpp:246
__host__ __device__ __forceinline__ half_t operator*(const half_t &other)
Multiply.
Definition: half.hpp:219