hipCUB
bfloat16.hpp
Go to the documentation of this file.
1 /******************************************************************************
2  * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright
7  * notice, this list of conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright
9  * notice, this list of conditions and the following disclaimer in the
10  * documentation and/or other materials provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the
12  * names of its contributors may be used to endorse or promote products
13  * derived from this software without specific prior written permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18  * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25  *
26  ******************************************************************************/
27 
28 #pragma once
29 
35 #include <stdint.h>
36 #include <hipcub/util_type.hpp>
37 
38 #include <iosfwd>
39 
40 #if defined(__HIP_PLATFORM_NVIDIA__)
41 #include <cuda_bf16.h>
42 #endif
43 
44 #ifdef __GNUC__
45 // There's a ton of type-punning going on in this file.
46 #pragma GCC diagnostic push
47 #pragma GCC diagnostic ignored "-Wstrict-aliasing"
48 #endif
49 
50 
51 /******************************************************************************
52  * bfloat16_t
53  ******************************************************************************/
54 
58 struct bfloat16_t
59 {
60  uint16_t __x;
61 
62 #ifdef __HIP_PLATFORM_AMD__
63 
65  __host__ __device__ __forceinline__
66  bfloat16_t(const hip_bfloat16 &other)
67  {
68  __x = reinterpret_cast<const uint16_t&>(other);
69  }
70 
71 #elif defined(__HIP_PLATFORM_NVIDIA__)
72 
74  __host__ __device__ __forceinline__
75  bfloat16_t(const __nv_bfloat16 &other)
76  {
77  __x = reinterpret_cast<const uint16_t&>(other);
78  }
79 
80 #endif
81 
83  __host__ __device__ __forceinline__
84  bfloat16_t(int a)
85  {
86  *this = bfloat16_t(float(a));
87  }
88 
90  bfloat16_t() = default;
91 
93  __host__ __device__ __forceinline__
94  bfloat16_t(float a)
95  {
96  // Reference:
97  // https://github.com/pytorch/pytorch/blob/44cc873fba5e5ffc4d4d4eef3bd370b653ce1ce1/c10/util/BFloat16.h#L51
98  uint16_t ir;
99  if (a != a) {
100  ir = UINT16_C(0x7FFF);
101  } else {
102  union {
103  uint32_t U32;
104  float F32;
105  };
106 
107  F32 = a;
108  uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
109  ir = static_cast<uint16_t>((U32 + rounding_bias) >> 16);
110  }
111  this->__x = ir;
112  }
113 
114 #ifdef __HIP_PLATFORM_AMD__
115 
117  __host__ __device__ __forceinline__
118  operator hip_bfloat16 () const
119  {
120  return reinterpret_cast<const hip_bfloat16 &>(__x);
121  }
122 
123 #elif defined(__HIP_PLATFORM_NVIDIA__)
124 
126  __host__ __device__ __forceinline__
127  operator __nv_bfloat16() const
128  {
129  return reinterpret_cast<const __nv_bfloat16&>(__x);
130  }
131 
132 #endif
133 
135  __host__ __device__ __forceinline__
136  operator float() const
137  {
138  float f = 0;
139  uint32_t *p = reinterpret_cast<uint32_t *>(&f);
140  *p = uint32_t(__x) << 16;
141  return f;
142  }
143 
145  __host__ __device__ __forceinline__
146  uint16_t raw() const
147  {
148  return this->__x;
149  }
150 
152  __host__ __device__ __forceinline__
153  friend bool operator ==(const bfloat16_t &a, const bfloat16_t &b){
154  return (a.__x == b.__x);
155  }
156 
158  __host__ __device__ __forceinline__
159  bool operator !=(const bfloat16_t &other) const
160  {
161  return (this->__x != other.__x);
162  }
163 
165  __host__ __device__ __forceinline__
167  {
168  *this = bfloat16_t(float(*this) + float(rhs));
169  return *this;
170  }
171 
173  __host__ __device__ __forceinline__
175  {
176  return bfloat16_t(float(*this) * float(other));
177  }
178 
180  __host__ __device__ __forceinline__
182  {
183  return bfloat16_t(float(*this) + float(other));
184  }
185 
187  __host__ __device__ __forceinline__
189  {
190  return bfloat16_t(float(*this) - float(other));
191  }
192 
194  __host__ __device__ __forceinline__
195  bool operator<(const bfloat16_t &other) const
196  {
197  return float(*this) < float(other);
198  }
199 
201  __host__ __device__ __forceinline__
202  bool operator<=(const bfloat16_t &other) const
203  {
204  return float(*this) <= float(other);
205  }
206 
208  __host__ __device__ __forceinline__
209  bool operator>(const bfloat16_t &other) const
210  {
211  return float(*this) > float(other);
212  }
213 
215  __host__ __device__ __forceinline__
216  bool operator>=(const bfloat16_t &other) const
217  {
218  return float(*this) >= float(other);
219  }
220 
222  __host__ __device__ __forceinline__
223  static bfloat16_t max() {
224  uint16_t max_word = 0x7F7F;
225  return reinterpret_cast<bfloat16_t&>(max_word);
226  }
227 
229  __host__ __device__ __forceinline__
230  static bfloat16_t lowest() {
231  uint16_t lowest_word = 0xFF7F;
232  return reinterpret_cast<bfloat16_t&>(lowest_word);
233  }
234 };
235 
236 
237 /******************************************************************************
238  * I/O stream overloads
239  ******************************************************************************/
240 
242 inline std::ostream& operator<<(std::ostream &out, const bfloat16_t &x)
243 {
244  out << (float)x;
245  return out;
246 }
247 
#if defined(__HIP_PLATFORM_NVIDIA__)

    /// Insert a formatted \p __nv_bfloat16 into the output stream (printed as float).
    inline std::ostream& operator<<(std::ostream &out, const __nv_bfloat16 &x)
    {
        const bfloat16_t value(x);
        out << value;
        return out;
    }

#endif
257 
258 
259 
260 
261 /******************************************************************************
262  * Traits overloads
263  ******************************************************************************/
264 
265 template <>
266 struct hipcub::FpLimits<bfloat16_t>
267 {
268  static __host__ __device__ __forceinline__ bfloat16_t Max() { return bfloat16_t::max(); }
269 
270  static __host__ __device__ __forceinline__ bfloat16_t Lowest() { return bfloat16_t::lowest(); }
271 };
272 
273 template <> struct hipcub::NumericTraits<bfloat16_t> : hipcub::BaseTraits<FLOATING_POINT, true, false, unsigned short, bfloat16_t> {};
274 
275 #ifdef __GNUC__
276 #pragma GCC diagnostic pop
277 #endif
__host__ __device__ __forceinline__ bfloat16_t & operator+=(const bfloat16_t &rhs)
Assignment by sum.
Definition: bfloat16.hpp:166
__host__ __device__ __forceinline__ bool operator<=(const bfloat16_t &other) const
Less-than-equal.
Definition: bfloat16.hpp:202
Host-based bfloat16 data type compatible and convertible with __nv_bfloat16 or hip_bfloat16.
Definition: bfloat16.hpp:58
Definition: thread_operators.hpp:105
__host__ __device__ __forceinline__ uint16_t raw() const
Get raw storage.
Definition: bfloat16.hpp:146
__host__ __device__ __forceinline__ bfloat16_t(float a)
Constructor from float.
Definition: bfloat16.hpp:94
__host__ __device__ __forceinline__ bfloat16_t operator+(const bfloat16_t &other)
Add.
Definition: bfloat16.hpp:181
__host__ __device__ __forceinline__ bool operator>=(const bfloat16_t &other) const
Greater-than-equal.
Definition: bfloat16.hpp:216
__host__ __device__ __forceinline__ bool operator>(const bfloat16_t &other) const
Greater-than.
Definition: bfloat16.hpp:209
bfloat16_t()=default
Default constructor.
__host__ __device__ __forceinline__ bfloat16_t(int a)
Constructor from integer.
Definition: bfloat16.hpp:84
__host__ __device__ __forceinline__ friend bool operator==(const bfloat16_t &a, const bfloat16_t &b)
Equality.
Definition: bfloat16.hpp:153
__host__ __device__ static __forceinline__ bfloat16_t max()
numeric_traits<bfloat16_t>::max
Definition: bfloat16.hpp:223
__host__ __device__ __forceinline__ bfloat16_t operator*(const bfloat16_t &other)
Multiply.
Definition: bfloat16.hpp:174
std::ostream & operator<<(std::ostream &out, const bfloat16_t &x)
Insert formatted bfloat16_t into the output stream.
Definition: bfloat16.hpp:242
__host__ __device__ static __forceinline__ bfloat16_t lowest()
numeric_traits<bfloat16_t>::lowest
Definition: bfloat16.hpp:230
__host__ __device__ __forceinline__ bfloat16_t operator-(const bfloat16_t &other)
Subtract.
Definition: bfloat16.hpp:188
__host__ __device__ __forceinline__ bool operator<(const bfloat16_t &other) const
Less-than.
Definition: bfloat16.hpp:195
__host__ __device__ __forceinline__ bool operator!=(const bfloat16_t &other) const
Inequality.
Definition: bfloat16.hpp:159