36 #include <hipcub/util_type.hpp> 40 #if defined(__HIP_PLATFORM_NVIDIA__) 41 #include <cuda_bf16.h> 46 #pragma GCC diagnostic push 47 #pragma GCC diagnostic ignored "-Wstrict-aliasing" 62 #ifdef __HIP_PLATFORM_AMD__ 65 __host__ __device__ __forceinline__
68 __x =
reinterpret_cast<const uint16_t&
>(other);
71 #elif defined(__HIP_PLATFORM_NVIDIA__) 74 __host__ __device__ __forceinline__
77 __x =
reinterpret_cast<const uint16_t&
>(other);
83 __host__ __device__ __forceinline__
93 __host__ __device__ __forceinline__
100 ir = UINT16_C(0x7FFF);
108 uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
109 ir =
static_cast<uint16_t
>((U32 + rounding_bias) >> 16);
114 #ifdef __HIP_PLATFORM_AMD__ 117 __host__ __device__ __forceinline__
118 operator hip_bfloat16 ()
const 120 return reinterpret_cast<const hip_bfloat16 &
>(__x);
123 #elif defined(__HIP_PLATFORM_NVIDIA__) 126 __host__ __device__ __forceinline__
127 operator __nv_bfloat16()
const 129 return reinterpret_cast<const __nv_bfloat16&
>(__x);
135 __host__ __device__ __forceinline__
136 operator float()
const 139 uint32_t *p =
reinterpret_cast<uint32_t *
>(&f);
140 *p = uint32_t(__x) << 16;
145 __host__ __device__ __forceinline__
152 __host__ __device__ __forceinline__
154 return (a.__x == b.__x);
158 __host__ __device__ __forceinline__
161 return (this->__x != other.__x);
165 __host__ __device__ __forceinline__
168 *
this =
bfloat16_t(
float(*
this) +
float(rhs));
173 __host__ __device__ __forceinline__
176 return bfloat16_t(
float(*
this) *
float(other));
180 __host__ __device__ __forceinline__
183 return bfloat16_t(
float(*
this) +
float(other));
187 __host__ __device__ __forceinline__
190 return bfloat16_t(
float(*
this) -
float(other));
194 __host__ __device__ __forceinline__
197 return float(*
this) < float(other);
201 __host__ __device__ __forceinline__
204 return float(*
this) <= float(other);
208 __host__ __device__ __forceinline__
211 return float(*
this) > float(other);
215 __host__ __device__ __forceinline__
218 return float(*
this) >= float(other);
222 __host__ __device__ __forceinline__
224 uint16_t max_word = 0x7F7F;
225 return reinterpret_cast<bfloat16_t&
>(max_word);
229 __host__ __device__ __forceinline__
231 uint16_t lowest_word = 0xFF7F;
232 return reinterpret_cast<bfloat16_t&
>(lowest_word);
248 #if defined(__HIP_PLATFORM_NVIDIA__) 251 inline std::ostream&
operator<<(std::ostream &out,
const __nv_bfloat16 &x)
273 template <>
struct hipcub::NumericTraits<
bfloat16_t> : hipcub::BaseTraits<FLOATING_POINT, true, false, unsigned short, bfloat16_t> {};
276 #pragma GCC diagnostic pop __host__ __device__ __forceinline__ bfloat16_t & operator+=(const bfloat16_t &rhs)
Assignment by sum.
Definition: bfloat16.hpp:166
__host__ __device__ __forceinline__ bool operator<=(const bfloat16_t &other) const
Less-than-equal.
Definition: bfloat16.hpp:202
Host-based bfloat16 data type compatible and convertible with __nv_bfloat16 or hip_bfloat16.
Definition: bfloat16.hpp:58
Definition: thread_operators.hpp:105
__host__ __device__ __forceinline__ uint16_t raw() const
Get raw storage.
Definition: bfloat16.hpp:146
__host__ __device__ __forceinline__ bfloat16_t(float a)
Constructor from float.
Definition: bfloat16.hpp:94
__host__ __device__ __forceinline__ bfloat16_t operator+(const bfloat16_t &other)
Add.
Definition: bfloat16.hpp:181
__host__ __device__ __forceinline__ bool operator>=(const bfloat16_t &other) const
Greater-than-equal.
Definition: bfloat16.hpp:216
__host__ __device__ __forceinline__ bool operator>(const bfloat16_t &other) const
Greater-than.
Definition: bfloat16.hpp:209
bfloat16_t()=default
Default constructor.
__host__ __device__ __forceinline__ bfloat16_t(int a)
Constructor from integer.
Definition: bfloat16.hpp:84
__host__ __device__ __forceinline__ friend bool operator==(const bfloat16_t &a, const bfloat16_t &b)
Equality.
Definition: bfloat16.hpp:153
__host__ __device__ static __forceinline__ bfloat16_t max()
numeric_traits<bfloat16_t>::max
Definition: bfloat16.hpp:223
__host__ __device__ __forceinline__ bfloat16_t operator*(const bfloat16_t &other)
Multiply.
Definition: bfloat16.hpp:174
std::ostream & operator<<(std::ostream &out, const bfloat16_t &x)
Insert formatted bfloat16_t into the output stream.
Definition: bfloat16.hpp:242
__host__ __device__ static __forceinline__ bfloat16_t lowest()
numeric_traits<bfloat16_t>::lowest
Definition: bfloat16.hpp:230
__host__ __device__ __forceinline__ bfloat16_t operator-(const bfloat16_t &other)
Subtract.
Definition: bfloat16.hpp:188
__host__ __device__ __forceinline__ bool operator<(const bfloat16_t &other) const
Less-than.
Definition: bfloat16.hpp:195
__host__ __device__ __forceinline__ bool operator!=(const bfloat16_t &other) const
Inequality.
Definition: bfloat16.hpp:159