HIP: Heterogeneous-computing Interface for Portability
amd_hip_fp8.h
1
30#ifndef _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_
31#define _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_
32
33#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) && __HIP_DEVICE_COMPILE__
34#define HIP_FP8_CVT_FAST_PATH 1
35#else
36#define HIP_FP8_CVT_FAST_PATH 0
37#endif
38
39#if !defined(__HIPCC_RTC__)
40#include <hip/amd_detail/amd_hip_common.h>
41#include <climits>
42
43#include "host_defines.h" // __hip_internal::
44#include "amd_hip_vector_types.h" // float2 etc
45#include "amd_hip_fp16.h" // __half_raw
46#include "amd_hip_bf16.h" // bf16
47#include "math_fwd.h" // ocml device functions
48#endif // !defined(__HIPCC_RTC__)
49
50#if defined(__HIPCC_RTC__)
51#define __FP8_HOST_DEVICE__ __device__
52#define __FP8_HOST_DEVICE_STATIC__ __FP8_HOST_DEVICE__ static
53#else
54#define __FP8_HOST_DEVICE__ __host__ __device__
55#define __FP8_HOST_DEVICE_STATIC__ __FP8_HOST_DEVICE__ static inline
56#endif // __HIPCC_RTC__
57
58#if !defined(__HIPCC_RTC__)
59static_assert(CHAR_BIT == 8, "byte size should be 8 bits");
60#endif
61static_assert(sizeof(unsigned char) == 1);
62static_assert(sizeof(unsigned short int) == 2);
63static_assert(sizeof(unsigned int) == 4);
64
65/**
66 * \brief Describes FP8 interpretation.
67 */
68enum __hip_fp8_interpretation_t {
69  __HIP_E4M3_FNUZ, /**< FP8 E4M3 FNUZ interpretation */
70  __HIP_E5M2_FNUZ, /**< FP8 E5M2 FNUZ (BF8) interpretation */
71};
72
73/**
74 * \brief Describes saturation behavior.
75 */
76enum __hip_saturation_t {
77  __HIP_NOSAT,     /**< No saturation */
78  __HIP_SATFINITE, /**< Saturate to finite */
79};
80
85typedef unsigned char __hip_fp8_storage_t;
86
87
92typedef unsigned short int __hip_fp8x2_storage_t;
93
94
99typedef unsigned int __hip_fp8x4_storage_t;
100
101namespace internal {
102// The conversion function is from rocblas
103// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39
104// This has been modified to handle double types as well
105template <typename T, bool negative_zero_nan>
106__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t cast_to_f8(T _x, int wm, int we, bool clip = false,
107 bool stoch = false,
108 unsigned int rng = 0) {
109 constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
110 constexpr bool is_float = __hip_internal::is_same<T, float>::value;
111 constexpr bool is_double = __hip_internal::is_same<T, double>::value;
112 static_assert(is_half || is_float || is_double, "Only half, float and double can be cast to f8");
113
114 const int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
115 unsigned long long x;
116
117 if (sizeof(T) == 8)
118 x = reinterpret_cast<unsigned long long&>(_x);
119 else if (sizeof(T) == 4)
120 x = reinterpret_cast<unsigned int&>(_x);
121 else
122 x = reinterpret_cast<unsigned short int&>(_x);
123
124
125 unsigned long long head, mantissa;
126 int exponent, bias;
127 unsigned int sign;
128
129 if (sizeof(T) == 8) {
130 head = x & 0xFFF0000000000000ull;
131 mantissa = x & 0xFFFFFFFFFFFFFull;
132 exponent = (head >> 52) & 0x7FF;
133 sign = head >> 63;
134 bias = 1023;
135 } else if (sizeof(T) == 4) {
136 head = x & 0xFF800000;
137 mantissa = x & 0x7FFFFF;
138 exponent = (head >> 23) & 0xFF;
139 sign = head >> 31;
140 bias = 127;
141 } else {
142 head = x & 0xFC00;
143 mantissa = x & 0x3FF;
144 exponent = (head >> 10) & 0x1F;
145 sign = head >> 15;
146 bias = 15;
147 }
148
149 unsigned int signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
150
151 // Deal with inf and NaNs
152 if (negative_zero_nan) {
153 if (sizeof(T) == 8) {
154 if ((x & 0x7FF0000000000000ull) == 0x7FF0000000000000ull) return 0x80;
155 } else if (sizeof(T) == 4) {
156 if ((x & 0x7F800000) == 0x7F800000) return 0x80;
157 } else {
158 if ((x & 0x7C00) == 0x7C00) return 0x80;
159 }
160 } else {
161 if (sizeof(T) == 8) {
162 if ((x & 0x7FF0000000000000ull) == 0x7FF0000000000000ull)
163 return signed_inf + (mantissa != 0 ? 1 : 0);
164 } else if (sizeof(T) == 4) {
165 if ((x & 0x7F800000) == 0x7F800000) return signed_inf + (mantissa != 0 ? 1 : 0);
166 } else {
167 if ((x & 0x7C00) == 0x7C00) return signed_inf + (mantissa != 0 ? 1 : 0);
168 }
169 }
170
171 if (x == 0) {
172 return 0;
173 }
174
175  // First, check whether the input is normal or denormal, since only normals carry an
176  // implicit 1. Then adjust the exponent to align with the F8 exponent, shifting the
177  // mantissa accordingly. For stochastic rounding, add rng to the mantissa and truncate;
178  // for RNE nothing is added. Finally, check whether there is a carry and adjust the
179  // exponent and mantissa again if needed. (A worked example follows this function.)
180
181 // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits
182 const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
183 const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
184 // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
185 // f8_exponent is the converted f8 exponent with bias encoding
186 // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
187 // the difference needs to be adjusted and mantissa shifted
188 int act_exponent, f8_exponent, exponent_diff;
189
190 if (exponent == 0) { // fp32/fp16 is in denormal.
191    /* fp32 denormals are below 2^-127, so they are usually not a concern here; the case that
192mostly matters is fp16. Here f8 is usually denormal as well, but there are exceptions: fp16 has
193exponent bias 15 while bf8 with NANOO has exponent bias 16, so some numbers that are fp16
194denormals are bf8 (NANOO) normals - the smallest bf8 (NANOO) normal is 2^-15. fp16 numbers with
195exponent==0 (actual exponent -14) whose highest mantissa bit is 1 are bf8 (NANOO) normals. In
196that case the fp16 mantissa should be shifted left by 1 */
197 act_exponent = exponent - bias + 1;
198 exponent_diff = f8_denormal_act_exponent -
199 act_exponent; // actual exponent is exponent-bias+1 as it is denormal
200 } else { // fp32/fp16 is normal with implicit 1
201 act_exponent = exponent - bias;
202 if (act_exponent <= f8_denormal_act_exponent) {
203      /* This is the case where fp32/fp16 is normal but falls in the f8 denormal range.
204For example, in fp8 nanoo mode the denormal exponent is -7, but an fp32/fp16 value with
205actual exponent -7 is actually larger because of the implicit 1, so it needs to be
206adjusted to -6 and the mantissa shifted right by 1. Thus for fp32/fp16, exponent -8 is
207the cut-off point for conversion to fp8 nanoo */
208 exponent_diff = f8_denormal_act_exponent - act_exponent;
209 } else { // both fp32/fp16 and f8 are in normal range
210 exponent_diff = 0; // exponent_diff=0 does not mean there is no difference for this case,
211 // act_exponent could be larger. Just that it does not need shift mantissa
212 }
213 mantissa += (1ull << mfmt); // Add the implicit 1 into mantissa
214 }
215
216 bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
217 (1ull << (mfmt - wm + exponent_diff - 1));
218  /* This part is a bit tricky. Whether the value is a tie must be judged before the right
219shift, because the shift can strip off residual bits and make a value that is not at the
220midpoint look like one. For example, the fp16 number 0x1002 (0 00100 0000000010) is larger than
221the midpoint, but after shifting right by 4 bits it would look like the midpoint.
222*/
223
224 if (exponent_diff > 0)
225 mantissa >>= exponent_diff;
226 else if (exponent_diff == -1)
227 mantissa <<= -exponent_diff;
228 bool implicit_one = mantissa & (1ull << mfmt);
229  // if there is no implicit 1, the f8 value is denormal and its exponent needs to be adjusted to the denormal exponent
230 f8_exponent =
231 (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
232
233 // Now we have the exponent and mantissa adjusted
234 unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1;
235 bool odd =
236 mantissa & (1ull << (mfmt - wm)); // if the least significant bit that is not truncated is 1
237 mantissa +=
238 (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
239
240 // Now we deal with overflow
241 if (f8_exponent == 0) {
242 if ((1ull << mfmt) & mantissa) {
243 f8_exponent = 1; // denormal overflow to become normal, promote exponent
244 }
245 } else {
246 if ((1ull << (mfmt + 1)) & mantissa) {
247 mantissa >>= 1;
248 f8_exponent++;
249 }
250 }
251
252 mantissa >>= (mfmt - wm);
253
254 // above range: quantize to maximum possible float of the same sign
255 const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
256 if (f8_exponent > max_exp) {
257 if (clip) {
258 mantissa = (1 << wm) - 1;
259 f8_exponent = max_exp;
260 } else {
261 return signed_inf;
262 }
263 }
264
265 if (f8_exponent == 0 && mantissa == 0) return negative_zero_nan ? 0 : (sign << 7);
266 mantissa &= (1 << wm) - 1;
267 return (sign << 7) | (f8_exponent << wm) | mantissa;
268}
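As a worked example of the routine above, here is a hand trace (an editor's sketch) using the e4m3 fnuz parameters that the public conversion API further down passes in (wm = 3, we = 4, negative_zero_nan = true):

// cast_to_f8<float, true>(1.0f, /*wm=*/3, /*we=*/4):
//   1.0f is 0x3F800000: sign = 0, biased exponent = 127, mantissa = 0
//   f8_bias = (1 << (4 - 1)) - 1 + 1 = 8   (fnuz bias is one larger than the IEEE-style bias)
//   act_exponent = 127 - 127 = 0, exponent_diff = 0, implicit 1 added to the mantissa
//   f8_exponent = 0 + 0 + 8 = 8, mantissa >>= (23 - 3) leaves 0 after masking
//   result = (0 << 7) | (8 << 3) | 0 = 0x40
// For e5m2 fnuz (wm = 2, we = 5) the same derivation gives f8_bias = 16.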
269
270// The conversion function is from rocblas
271// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220
272// This has been modified to handle double types as well
273template <typename T, bool negative_zero_nan>
274__FP8_HOST_DEVICE_STATIC__ T cast_from_f8(__hip_fp8_storage_t x, int wm, int we) {
275 constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
276 constexpr bool is_float = __hip_internal::is_same<T, float>::value;
277 constexpr bool is_double = __hip_internal::is_same<T, double>::value;
278 static_assert(is_half || is_float || is_double, "only half, float and double are supported");
279
280 constexpr int weo = is_half ? 5 : (is_float ? 8 : 11);
281 constexpr int wmo = is_half ? 10 : (is_float ? 23 : 52);
282
283 T fInf, fNegInf, fNaN, fNeg0;
284 if (is_half) {
285 const unsigned short int ihInf = 0x7C00;
286 const unsigned short int ihNegInf = 0xFC00;
287 const unsigned short int ihNaN = 0x7C01;
288 const unsigned short int ihNeg0 = 0x8000;
289 fInf = reinterpret_cast<const _Float16&>(ihInf);
290 fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
291 fNaN = reinterpret_cast<const _Float16&>(ihNaN);
292 fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
293 } else if (is_float) {
294 const unsigned int ifInf = 0x7F800000;
295 const unsigned int ifNegInf = 0xFF800000;
296 const unsigned int ifNaN = 0x7F800001;
297 const unsigned int ifNeg0 = 0x80000000;
298 fInf = reinterpret_cast<const float&>(ifInf);
299 fNegInf = reinterpret_cast<const float&>(ifNegInf);
300 fNaN = reinterpret_cast<const float&>(ifNaN);
301 fNeg0 = reinterpret_cast<const float&>(ifNeg0);
302 } else if (is_double) {
303 const unsigned long long ifInf = 0x7FF0000000000000ull;
304 const unsigned long long ifNegInf = 0xFFF0000000000000ull;
305 const unsigned long long ifNaN = 0x7FF0000000000001ull;
306 const unsigned long long ifNeg0 = 0x8000000000000000ull;
307 fInf = reinterpret_cast<const double&>(ifInf);
308 fNegInf = reinterpret_cast<const double&>(ifNegInf);
309 fNaN = reinterpret_cast<const double&>(ifNaN);
310 fNeg0 = reinterpret_cast<const double&>(ifNeg0);
311 }
312
313 if (x == 0) {
314 return 0;
315 }
316
317 unsigned long long sign = x >> 7;
318 unsigned long long mantissa = x & ((1 << wm) - 1);
319 int exponent = (x & 0x7F) >> wm;
320 if (negative_zero_nan) {
321 if (x == 0x80) return fNaN;
322 } else {
323 if (x == 0x80) return fNeg0;
324 if (exponent == ((1 << we) - 1)) return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
325 }
326
327 typename __hip_internal::conditional<
328 sizeof(T) == 2, unsigned short int,
329 typename __hip_internal::conditional<sizeof(T) == 4, unsigned int,
330 unsigned long long>::type>::type retval;
331
332 if (we == 5 && is_half && !negative_zero_nan) {
333 retval = x << 8;
334 return reinterpret_cast<const T&>(retval);
335 }
336
337 const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
338
339 // subnormal input
340 if (exponent == 0) {
341#if __HIP_DEVICE_COMPILE__
342 // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
343 int sh = 1 + __clz(mantissa) - (32 - wm);
344#else
345 int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
346#endif
347 mantissa <<= sh;
348 exponent += 1 - sh;
349 mantissa &= ((1ull << wm) - 1);
350 }
351 exponent += exp_low_cutoff - 1;
352 mantissa <<= wmo - wm;
353
354 // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
355 if (exponent <= 0) {
356 mantissa |= 1 << wmo;
357 mantissa >>= 1 - exponent;
358 exponent = 0;
359 }
360
361 if (sizeof(T) == 2)
362 retval = (sign << 15) | (exponent << 10) | mantissa;
363 else if (sizeof(T) == 4)
364 retval = (sign << 31) | (exponent << 23) | mantissa;
365 else
366 retval = (sign << 63) | (static_cast<unsigned long long>(exponent) << 52) | mantissa;
367 return reinterpret_cast<const T&>(retval);
368}
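A matching decode trace for the function above, under the same e4m3 fnuz parameters (again a hand-worked sketch):

// cast_from_f8<float, true>(0x40, /*wm=*/3, /*we=*/4):
//   sign = 0, exponent field = (0x40 & 0x7F) >> 3 = 8, mantissa = 0
//   exp_low_cutoff = (1 << (8 - 1)) - (1 << (4 - 1)) + 1 - 1 = 120
//   float exponent = 8 + 120 - 1 = 127, mantissa <<= (23 - 3) stays 0
//   retval = (0 << 31) | (127 << 23) | 0 = 0x3F800000 = 1.0f, recovering the encode example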
369
370#if HIP_FP8_CVT_FAST_PATH
371// The conversion function is from rocblas
372// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79
373template <bool stochastic_rounding = false>
374static __device__ __hip_fp8_storage_t cast_to_f8_from_f32(float v, bool saturate,
375                                                          __hip_fp8_interpretation_t interpret,
376                                                          unsigned int rng = 0) {
377 __hip_fp8_storage_t i8data;
378 union {
379 float fval;
380 unsigned int i32val;
381 unsigned char i8val[4]; // NOTE: not endian independent
382 } val;
383
384 unsigned int ival = 0;
385 val.fval = v;
386
387 if (saturate) {
388 if (interpret == __HIP_E4M3_FNUZ) {
389 if ((val.i32val & 0x7F800000) != 0x7F800000) {
390 val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
391 }
392 } else {
393 if ((val.i32val & 0x7F800000) != 0x7F800000) {
394 val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
395 }
396 }
397 }
398
399 if (stochastic_rounding) {
400 ival = interpret == __HIP_E4M3_FNUZ
401 ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
402 : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
403 val.i32val = ival;
404 i8data = val.i8val[0]; // little endian
405 } else { // RNE CVT
406 ival = interpret == __HIP_E4M3_FNUZ
407 ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
408 : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
409 val.i32val = ival;
410 i8data = val.i8val[0];
411 }
412 return i8data;
413}
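The clamp constants passed to __builtin_amdgcn_fmed3f above are the largest finite fnuz values, which follows directly from the encodings handled by cast_to_f8/cast_from_f8:

// e4m3 fnuz: max finite = 0x7F -> 2^(15 - 8)  * (1 + 7/8) = 240
// e5m2 fnuz: max finite = 0x7F -> 2^(31 - 16) * (1 + 3/4) = 57344
// Saturation therefore clamps to +/-240 (fp8) or +/-57344 (bf8) before the hardware
// conversion; non-finite inputs are passed through unmodified.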
414
415static __device__ __hip_fp8x2_storage_t
416cast_to_f8x2_from_f32x2(float2 v, bool saturate, __hip_fp8_interpretation_t interpret) {
417 union {
418 static_assert(sizeof(float2) == sizeof(unsigned int[2]));
419 static_assert(sizeof(float2) == sizeof(unsigned short[4]));
420 float2 fval;
421 unsigned int i32val[2];
422 unsigned short i16val[4];
423 } f2val;
424
425 f2val.fval = v;
426
427 if (saturate) {
428 if ((f2val.i32val[0] & 0x7F800000) != 0x7F800000) {
429 f2val.fval.x = __builtin_amdgcn_fmed3f(f2val.fval.x, 240.0, -240.0);
430 }
431 if ((f2val.i32val[1] & 0x7F800000) != 0x7F800000) {
432      f2val.fval.y = __builtin_amdgcn_fmed3f(f2val.fval.y, 240.0, -240.0);
433 }
434 }
435
436 f2val.i32val[0] = interpret == __HIP_E4M3_FNUZ
437 ? __builtin_amdgcn_cvt_pk_fp8_f32(v.x, v.y, 0, false)
438 : __builtin_amdgcn_cvt_pk_bf8_f32(v.x, v.y, 0, false);
439
440 return static_cast<__hip_fp8x2_storage_t>(f2val.i16val[0]);
441}
442
443static __device__ float cast_to_f32_from_f8(__hip_fp8_storage_t v,
444 __hip_fp8_interpretation_t interpret) {
445 union {
446 unsigned int i32val;
447 unsigned char i8val[4];
448 } val;
449 val.i8val[0] = v;
450
451 float fval = interpret == __HIP_E4M3_FNUZ ? __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0)
452 : __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
453 return fval;
454}
455
456static __device__ float2 cast_to_f32x2_from_f8x2(__hip_fp8x2_storage_t v,
457 __hip_fp8_interpretation_t interpret) {
458 union {
459 unsigned int i32val;
460 unsigned short i16val[2];
461 } val;
462 val.i16val[0] = v;
463
464 auto f2 = interpret == __HIP_E4M3_FNUZ ? __builtin_amdgcn_cvt_pk_f32_fp8(val.i32val, false)
465 : __builtin_amdgcn_cvt_pk_f32_bf8(val.i32val, false);
466 return float2{f2[0], f2[1]};
467}
468#endif // HIP_FP8_CVT_FAST_PATH
469
470/* For fp8 fnuz types, finite and NaN values are supported. Zero is unsigned.
471Inf is not supported; this frees up one additional encoding.
472NaN is represented by 1-0000-000 (e4m3) or 1-00000-00 (e5m2). */
473__FP8_HOST_DEVICE_STATIC__ bool hip_fp8_fnuz_is_nan(__hip_fp8_storage_t a) {
474 return static_cast<unsigned char>(a) == 0x80;
475}
476} // namespace internal
477
486__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_float_to_fp8(
487    const float f, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
488#if HIP_FP8_CVT_FAST_PATH
489 return internal::cast_to_f8_from_f32<false>(f, sat == __HIP_SATFINITE, type);
490#else // HIP_FP8_CVT_FAST_PATH
491 int we = type == __HIP_E4M3_FNUZ ? 4 : 5;
492 int wm = type == __HIP_E4M3_FNUZ ? 3 : 2;
493 return internal::cast_to_f8<float, true>(f, wm, we, sat == __HIP_SATFINITE);
494#endif // HIP_FP8_CVT_FAST_PATH
495}
496
505__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_float2_to_fp8x2(
506    const float2 f2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
507#if HIP_FP8_CVT_FAST_PATH
508 return internal::cast_to_f8x2_from_f32x2(f2, sat == __HIP_SATFINITE, type);
509#else
510 return static_cast<__hip_fp8x2_storage_t>(
511 static_cast<unsigned short int>(__hip_cvt_float_to_fp8(f2.y, sat, type)) << 8 |
512 static_cast<unsigned short int>(__hip_cvt_float_to_fp8(f2.x, sat, type)));
513#endif
514}
515
524__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_double_to_fp8(
525    const double d, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
526 int we = type == __HIP_E4M3_FNUZ ? 4 : 5;
527 int wm = type == __HIP_E4M3_FNUZ ? 3 : 2;
528 return internal::cast_to_f8<double, true>(d, wm, we, sat == __HIP_SATFINITE);
529}
530
539__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_double2_to_fp8x2(
540    const double2 d2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
541 return static_cast<__hip_fp8x2_storage_t>(
542 static_cast<unsigned short int>(__hip_cvt_double_to_fp8(d2.y, sat, type)) << 8 |
543 static_cast<unsigned short int>(__hip_cvt_double_to_fp8(d2.x, sat, type)));
544}
545
554__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t
555__hip_cvt_bfloat16raw_to_fp8(const __hip_bfloat16_raw hr, const __hip_saturation_t sat,
556 const __hip_fp8_interpretation_t type) {
557 float fval = __hip_bfloat16(hr);
558 return __hip_cvt_float_to_fp8(fval, sat, type);
559}
560
569__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t
570__hip_cvt_bfloat16raw2_to_fp8x2(const __hip_bfloat162_raw hr, const __hip_saturation_t sat,
571 const __hip_fp8_interpretation_t type) {
572 float2 f2 = __hip_bfloat162(hr);
573 return __hip_cvt_float2_to_fp8x2(f2, sat, type);
574}
575
583__FP8_HOST_DEVICE_STATIC__ __half_raw
584__hip_cvt_fp8_to_halfraw(const __hip_fp8_storage_t x, const __hip_fp8_interpretation_t type) {
585  unsigned int we = type == __HIP_E4M3_FNUZ ? 4 : 5;
586 unsigned int wm = type == __HIP_E4M3_FNUZ ? 3 : 2;
587 return __half_raw{internal::cast_from_f8<_Float16, true>(x, wm, we)};
588}
589
597__FP8_HOST_DEVICE_STATIC__ __half2_raw
598__hip_cvt_fp8x2_to_halfraw2(const __hip_fp8x2_storage_t x, const __hip_fp8_interpretation_t type) {
599  __half2 ret(static_cast<__half>(
600 __hip_cvt_fp8_to_halfraw(static_cast<__hip_fp8_storage_t>(x & 0xFF), type)),
601 static_cast<__half>(
602 __hip_cvt_fp8_to_halfraw(static_cast<__hip_fp8_storage_t>(x >> 8), type)));
603 return static_cast<__half2_raw>(ret);
604}
605
614__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_halfraw_to_fp8(
615    const __half_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
616 return __hip_cvt_float_to_fp8(__half2float(__half(x)), sat, type);
617}
618
627__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_halfraw2_to_fp8x2(
628    const __half2_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
629 return __hip_cvt_float2_to_fp8x2(__half22float2(__half2(x)), sat, type);
630}
631
636struct __hip_fp8_e4m3_fnuz {
637  __hip_fp8_storage_t __x; /*!< raw storage of fp8 number */
638  constexpr static __hip_saturation_t __default_saturation = __HIP_SATFINITE;
639  constexpr static __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3_FNUZ;
640 constexpr static unsigned int __we = 4;
641 constexpr static unsigned int __wm = 3;
642
643 // TODO: SWDEV-452411
644 // Add cast from unsigned long long, long long to fp8
645
647 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const long int val)
648 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
649 __default_interpret)) {}
650
652 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const int val)
653 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
654 __default_interpret)) {}
655
657 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const short int val)
658 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
659 __default_interpret)) {}
660
662 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned long int val)
663 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
664 __default_interpret)) {}
665
667 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned int val)
668 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
669 __default_interpret)) {}
670
672 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned short int val)
673 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
674 __default_interpret)) {}
675
677 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const double f)
678 : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {}
679
681 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const float f)
682 : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {}
683
685 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __hip_bfloat16 f)
686 : __x(__hip_cvt_float_to_fp8(static_cast<float>(f), __default_saturation,
687 __default_interpret)) {}
688
690 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __half f)
691      : __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), __default_saturation,
692                                     __default_interpret)) {}
693
695 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz() = default;
696
698 __FP8_HOST_DEVICE__ operator __half() const {
699 return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret));
700 }
701
703 __FP8_HOST_DEVICE__ operator __hip_bfloat16() const {
704 float f = *this;
705 return __hip_bfloat16(f);
706 }
707
709 __FP8_HOST_DEVICE__ operator bool() const {
710 // it can be 0x00 (+0.0) since 0x80 will be nan
711 return !(static_cast<unsigned short>(__x) == 0);
712 }
713
715 __FP8_HOST_DEVICE__ operator char() const {
716 if (internal::hip_fp8_fnuz_is_nan(__x)) {
717 return 0;
718 }
719
720 auto fval = internal::cast_from_f8<float, true>(__x, __wm, __we);
721 auto llval = static_cast<long long>(fval);
722 if (llval <= CHAR_MIN) {
723 return CHAR_MIN;
724 } else if (llval >= CHAR_MAX) {
725 return CHAR_MAX;
726 }
727 return static_cast<char>(fval);
728 }
729
731 __FP8_HOST_DEVICE__ operator double() const {
732 return internal::cast_from_f8<double, true>(__x, __wm, __we);
733 }
734
736 __FP8_HOST_DEVICE__ operator float() const {
737#if HIP_FP8_CVT_FAST_PATH
738 return internal::cast_to_f32_from_f8(__x, __default_interpret);
739#else
740 return internal::cast_from_f8<float, true>(__x, __wm, __we);
741#endif
742 }
743
745 __FP8_HOST_DEVICE__ operator int() const {
746 if (internal::hip_fp8_fnuz_is_nan(__x)) {
747 return 0;
748 }
749
750 float fval = *this;
751 return static_cast<int>(fval);
752 }
753
755 __FP8_HOST_DEVICE__ operator long int() const {
756 if (internal::hip_fp8_fnuz_is_nan(__x)) {
757 return 0;
758 }
759
760 float fval = *this;
761 return static_cast<long>(fval);
762 }
763
765 __FP8_HOST_DEVICE__ operator long long int() const {
766 if (internal::hip_fp8_fnuz_is_nan(__x)) {
767 return 0;
768 }
769
770 float fval = *this;
771 return static_cast<long long>(fval);
772 }
773
775 __FP8_HOST_DEVICE__ operator short int() const {
776 if (internal::hip_fp8_fnuz_is_nan(__x)) {
777 return 0;
778 }
779
780 float fval = *this;
781 auto llval = static_cast<long long>(fval);
782 if (llval <= SHRT_MIN) {
783 return SHRT_MIN;
784 } else if (llval >= SHRT_MAX) {
785 return SHRT_MAX;
786 }
787 return static_cast<short>(fval);
788 }
789
791 __FP8_HOST_DEVICE__ operator signed char() const {
792 if (internal::hip_fp8_fnuz_is_nan(__x)) {
793 return 0;
794 }
795
796 float fval = *this;
797 auto llval = static_cast<long long>(fval);
798 if (llval <= SCHAR_MIN) {
799 return SCHAR_MIN;
800 } else if (llval >= SCHAR_MAX) {
801 return SCHAR_MAX;
802 }
803 return static_cast<signed char>(fval);
804 }
805
807 __FP8_HOST_DEVICE__ operator unsigned char() const {
808 if (internal::hip_fp8_fnuz_is_nan(__x)) {
809 return 0;
810 }
811
812 float fval = *this;
813 auto llval = static_cast<long long>(fval);
814 if (llval <= 0) {
815 return 0;
816 } else if (llval >= UCHAR_MAX) {
817 return UCHAR_MAX;
818 }
819 return static_cast<unsigned char>(fval);
820 }
821
823 __FP8_HOST_DEVICE__ operator unsigned int() const {
824 if (internal::hip_fp8_fnuz_is_nan(__x)) {
825 return 0;
826 }
827
828 float fval = *this;
829 auto llval = static_cast<long long>(fval);
830 if (llval <= 0) {
831 return 0;
832 }
833 return static_cast<unsigned int>(fval);
834 }
835
837 __FP8_HOST_DEVICE__ operator unsigned long int() const {
838 if (internal::hip_fp8_fnuz_is_nan(__x)) {
839 return 0;
840 }
841
842 float fval = *this;
843 auto llval = static_cast<long long>(fval);
844 if (llval <= 0) {
845 return 0;
846 }
847 return static_cast<unsigned long>(fval);
848 }
849
851 __FP8_HOST_DEVICE__ operator unsigned long long int() const {
852 if (internal::hip_fp8_fnuz_is_nan(__x)) {
853 return 0;
854 }
855
856 float fval = *this;
857 auto llval = static_cast<long long>(fval);
858 if (llval <= 0) {
859 return 0;
860 }
861 return static_cast<unsigned long long>(fval);
862 }
863
865 __FP8_HOST_DEVICE__ operator unsigned short int() const {
866 if (internal::hip_fp8_fnuz_is_nan(__x)) {
867 return 0;
868 }
869
870 float fval = *this;
871 auto llval = static_cast<long long>(fval);
872 if (llval <= 0) {
873 return 0;
874 }
875 return static_cast<unsigned short>(fval);
876 }
877};
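A minimal usage sketch for the scalar type above; the helper name is illustrative and not part of the header. It encodes a float through the converting constructor and decodes it back through operator float:

__FP8_HOST_DEVICE_STATIC__ float fp8_e4m3_roundtrip_example(float f) {
  __hip_fp8_e4m3_fnuz v(f);      // encodes via __hip_cvt_float_to_fp8 with the struct's defaults
  return static_cast<float>(v);  // decodes via operator float(); fast path on gfx94x devices
}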
878
883struct __hip_fp8x2_e4m3_fnuz {
884  __hip_fp8x2_storage_t __x; /*!< raw storage of two fp8 numbers */
885  static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
886  static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3_FNUZ;
887 static constexpr unsigned int __we = 4;
888 static constexpr unsigned int __wm = 3;
889
891 __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const double2 val)
892 : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
893
895 __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const float2 val)
896 : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
897
899 __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __hip_bfloat162 val)
900 : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
901
903 __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __half2 val)
904 : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
905
907 __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz() = default;
908
910 __FP8_HOST_DEVICE__ operator __half2() const {
911 return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret));
912 }
913
915 __FP8_HOST_DEVICE__ operator float2() const {
916#if HIP_FP8_CVT_FAST_PATH
917 return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret);
918#else
919 return float2(internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x & 0xFF),
920 __wm, __we),
921 internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x >> 8),
922 __wm, __we));
923#endif
924 }
925};
926
931struct __hip_fp8x4_e4m3_fnuz {
932  __hip_fp8x4_storage_t __x; /*!< raw storage of four fp8 numbers */
933  static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
934  static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3_FNUZ;
935 static constexpr unsigned int __we = 4;
936 static constexpr unsigned int __wm = 3;
937
939 __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const double4 val)
940 : __x{reinterpret_cast<__hip_fp8x4_storage_t>(
941 static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
942 val.x, __default_saturation, __default_interpret)) |
943 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
944 val.y, __default_saturation, __default_interpret))
945 << 8 |
946 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
947 val.z, __default_saturation, __default_interpret))
948 << 16 |
949 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
950 val.w, __default_saturation, __default_interpret))
951 << 24))} {}
952
954 __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const float4 val)
955 : __x{reinterpret_cast<__hip_fp8x4_storage_t>(
956 static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
957 val.x, __default_saturation, __default_interpret)) |
958 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
959 val.y, __default_saturation, __default_interpret))
960 << 8 |
961 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
962 val.z, __default_saturation, __default_interpret))
963 << 16 |
964 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
965 val.w, __default_saturation, __default_interpret))
966 << 24))} {}
967
969 __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
970 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast<unsigned int>(
971 reinterpret_cast<unsigned short>(
972 __hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) |
973 reinterpret_cast<unsigned short>(
974 __hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret))
975 << 16))) {}
976
978 __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __half2 low, const __half2 high)
979 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
980 static_cast<unsigned int>(reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
981 high, __default_saturation, __default_interpret)) |
982 reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
983 low, __default_saturation, __default_interpret))
984 << 16))) {}
985
987 __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz() = default;
988
990 __FP8_HOST_DEVICE__ operator float4() const {
991 auto x = __x; // bypass const
992 auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E
993 auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1);
994#if HIP_FP8_CVT_FAST_PATH
995 float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret);
996 float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret);
997#else
998 float2 high = float2(internal::cast_from_f8<float, true>(
999 static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, __we),
1000 internal::cast_from_f8<float, true>(
1001 static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8), __wm, __we));
1002 float2 low = float2(internal::cast_from_f8<float, true>(
1003 static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, __we),
1004 internal::cast_from_f8<float, true>(
1005 static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm, __we));
1006#endif
1007 return float4(low.x, low.y, high.x, high.y);
1008 }
1009};
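A corresponding sketch for the packed four-wide type above (again, the helper name is illustrative): a float4 is quantized into the four bytes of __x and expanded back.

__FP8_HOST_DEVICE_STATIC__ float4 fp8x4_e4m3_roundtrip_example(const float4& f) {
  __hip_fp8x4_e4m3_fnuz v(f);     // packs f.x .. f.w into the low .. high bytes of __x
  return static_cast<float4>(v);  // unpacks via operator float4()
}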
1010
1015struct __hip_fp8_e5m2_fnuz {
1016  __hip_fp8_storage_t __x; /*!< raw storage of one fp8 number */
1017  static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
1018  static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2_FNUZ;
1019 static constexpr unsigned int __we = 5;
1020 static constexpr unsigned int __wm = 2;
1021
1022
1023 // TODO: SWDEV-452411
1024 // Add cast from unsigned long long, long long to fp8
1025
1027 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const long int val)
1028 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1029 __default_interpret)) {}
1030
1032 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const int val)
1033 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1034 __default_interpret)) {}
1035
1037 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const short int val)
1038 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1039 __default_interpret)) {}
1040
1042 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned long int val)
1043 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1044 __default_interpret)) {}
1045
1047 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned int val)
1048 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1049 __default_interpret)) {}
1050
1052 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned short int val)
1053 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1054 __default_interpret)) {}
1055
1057 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const double f)
1058 : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {}
1059
1061 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const float f)
1062 : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {}
1063
1065 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __hip_bfloat16 f)
1066 : __x(__hip_cvt_float_to_fp8(static_cast<float>(f), __default_saturation,
1067 __default_interpret)) {}
1068
1070 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __half f)
1071      : __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), __default_saturation,
1072                                     __default_interpret)) {}
1073
1075 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz() = default;
1076
1078 __FP8_HOST_DEVICE__ operator float() const {
1079#if HIP_FP8_CVT_FAST_PATH
1080 return internal::cast_to_f32_from_f8(__x, __default_interpret);
1081#else
1082 return internal::cast_from_f8<float, true>(__x, __wm, __we);
1083#endif
1084 }
1085
1087 __FP8_HOST_DEVICE__ operator __half() const {
1088 return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret));
1089 }
1090
1092 __FP8_HOST_DEVICE__ operator __hip_bfloat16() const {
1093 float f = *this;
1094 return __hip_bfloat16(f);
1095 }
1096
1098 __FP8_HOST_DEVICE__ operator bool() const {
1099 // it can be 0x00 (+0.0) since 0x80 will be nan
1100 return !(static_cast<unsigned short>(__x) == 0);
1101 }
1102
1104 __FP8_HOST_DEVICE__ operator char() const {
1105 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1106 return 0;
1107 }
1108
1109 float fval = *this;
1110 auto llval = static_cast<long long>(fval);
1111 if (llval <= CHAR_MIN) {
1112 return CHAR_MIN;
1113 } else if (llval >= CHAR_MAX) {
1114 return CHAR_MAX;
1115 }
1116 return static_cast<char>(fval);
1117 }
1118
1120 __FP8_HOST_DEVICE__ operator double() const {
1121 return internal::cast_from_f8<double, true>(__x, __wm, __we);
1122 }
1123
1125 __FP8_HOST_DEVICE__ operator int() const {
1126 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1127 return 0;
1128 }
1129
1130 float fval = *this;
1131 return static_cast<int>(fval);
1132 }
1133
1135 __FP8_HOST_DEVICE__ operator long int() const {
1136 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1137 return 0;
1138 }
1139
1140 float fval = *this;
1141 return static_cast<long>(fval);
1142 }
1143
1145 __FP8_HOST_DEVICE__ operator long long int() const {
1146 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1147 return 0;
1148 }
1149
1150 float fval = *this;
1151 return static_cast<long long>(fval);
1152 }
1153
1155 __FP8_HOST_DEVICE__ operator short int() const {
1156 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1157 return 0;
1158 }
1159
1160 float fval = *this;
1161 auto llval = static_cast<long long>(fval);
1162 if (llval <= SHRT_MIN) {
1163 return SHRT_MIN;
1164 } else if (llval >= SHRT_MAX) {
1165 return SHRT_MAX;
1166 }
1167 return static_cast<short>(fval);
1168 }
1169
1171 __FP8_HOST_DEVICE__ operator signed char() const {
1172 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1173 return 0;
1174 }
1175
1176 float fval = *this;
1177 auto llval = static_cast<long long>(fval);
1178 if (llval <= SCHAR_MIN) {
1179 return SCHAR_MIN;
1180 } else if (llval >= SCHAR_MAX) {
1181 return SCHAR_MAX;
1182 }
1183 return static_cast<signed char>(fval);
1184 }
1185
1187 __FP8_HOST_DEVICE__ operator unsigned char() const {
1188 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1189 return 0;
1190 }
1191
1192 float fval = *this;
1193 auto llval = static_cast<long long>(fval);
1194 if (llval <= 0) {
1195 return 0;
1196 } else if (llval >= UCHAR_MAX) {
1197 return UCHAR_MAX;
1198 }
1199 return static_cast<unsigned char>(fval);
1200 }
1201
1203 __FP8_HOST_DEVICE__ operator unsigned int() const {
1204 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1205 return 0;
1206 }
1207
1208 float fval = *this;
1209 auto llval = static_cast<long long>(fval);
1210 if (llval <= 0) {
1211 return 0;
1212 }
1213 return static_cast<unsigned int>(fval);
1214 }
1215
1217 __FP8_HOST_DEVICE__ operator unsigned long int() const {
1218 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1219 return 0;
1220 }
1221
1222 float fval = *this;
1223 auto llval = static_cast<long long>(fval);
1224 if (llval <= 0) {
1225 return 0;
1226 }
1227 return static_cast<unsigned long>(fval);
1228 }
1229
1231 __FP8_HOST_DEVICE__ operator unsigned long long int() const {
1232 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1233 return 0;
1234 }
1235
1236 float fval = *this;
1237 auto llval = static_cast<long long>(fval);
1238 if (llval <= 0) {
1239 return 0;
1240 }
1241 return static_cast<unsigned long long>(fval);
1242 }
1243
1245 __FP8_HOST_DEVICE__ operator unsigned short int() const {
1246 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1247 return 0;
1248 }
1249
1250 float fval = *this;
1251 auto llval = static_cast<long long>(fval);
1252 if (llval <= 0) {
1253 return 0;
1254 }
1255 return static_cast<unsigned short>(fval);
1256 }
1257};
1258
1263struct __hip_fp8x2_e5m2_fnuz {
1264  __hip_fp8x2_storage_t __x; /*!< raw storage of two fp8 numbers */
1265  static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
1266  static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2_FNUZ;
1267 static constexpr unsigned int __we = 5;
1268 static constexpr unsigned int __wm = 2;
1269
1271 __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const double2 val)
1272 : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
1273
1275 __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const float2 val)
1276 : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
1277
1279 __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __hip_bfloat162 val)
1280 : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
1281
1283 __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __half2 val)
1284 : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
1285
1287 __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz() = default;
1288
1290 __FP8_HOST_DEVICE__ operator __half2() const {
1291 return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret));
1292 }
1293
1295 __FP8_HOST_DEVICE__ operator float2() const {
1296#if HIP_FP8_CVT_FAST_PATH
1297 return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret);
1298#else
1299 return float2(internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x & 0xFF),
1300 __wm, __we),
1301 internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x >> 8),
1302 __wm, __we));
1303#endif
1304 }
1305};
1306
1311struct __hip_fp8x4_e5m2_fnuz {
1312  __hip_fp8x4_storage_t __x; /*!< raw storage of four fp8 numbers */
1313  static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
1314  static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2_FNUZ;
1315 static constexpr unsigned int __we = 5;
1316 static constexpr unsigned int __wm = 2;
1317
1319 __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const double4 val)
1320 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
1321 static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1322 val.x, __default_saturation, __default_interpret)) |
1323 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1324 val.y, __default_saturation, __default_interpret))
1325 << 8 |
1326 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1327 val.z, __default_saturation, __default_interpret))
1328 << 16 |
1329 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1330 val.w, __default_saturation, __default_interpret))
1331 << 24))) {}
1332
1334 __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const float4 val)
1335 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
1336 static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1337 val.x, __default_saturation, __default_interpret)) |
1338 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1339 val.y, __default_saturation, __default_interpret))
1340 << 8 |
1341 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1342 val.z, __default_saturation, __default_interpret))
1343 << 16 |
1344 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1345 val.w, __default_saturation, __default_interpret))
1346 << 24))) {}
1347
1349 __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
1350 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast<unsigned int>(
1351 reinterpret_cast<unsigned short>(
1352 __hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) |
1353 reinterpret_cast<unsigned short>(
1354 __hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret))
1355 << 16))) {}
1356
1358 __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __half2 low, const __half2 high)
1359 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
1360 static_cast<unsigned int>(reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
1361 high, __default_saturation, __default_interpret)) |
1362 reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
1363 low, __default_saturation, __default_interpret))
1364 << 16))) {}
1365
1366 /* default construct fp8x4 e5m2 */
1367 __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz() = default;
1368
1370 __FP8_HOST_DEVICE__ operator float4() const {
1371 auto x = __x; // bypass const
1372 auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E
1373 auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1);
1374#if HIP_FP8_CVT_FAST_PATH
1375 float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret);
1376 float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret);
1377#else
1378 float2 high = float2(internal::cast_from_f8<float, true>(
1379 static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, __we),
1380 internal::cast_from_f8<float, true>(
1381 static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8), __wm, __we));
1382 float2 low = float2(internal::cast_from_f8<float, true>(
1383 static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, __we),
1384 internal::cast_from_f8<float, true>(
1385 static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm, __we));
1386#endif
1387 return float4(low.x, low.y, high.x, high.y);
1388 }
1389};
1390
1391#endif // _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_
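For reference, a minimal host-side sketch of the API defined in this file, assuming the public <hip/hip_fp8.h> wrapper header and compilation with hipcc; the file name and the values used are illustrative:

// fp8_example.cpp (illustrative)
#include <hip/hip_fp8.h>
#include <cstdio>

int main() {
  // Scalar e4m3 fnuz: construct from float, read back as float, inspect the raw byte.
  __hip_fp8_e4m3_fnuz a = 0.5f;
  std::printf("e4m3: raw=0x%02x value=%f\n", static_cast<unsigned>(a.__x),
              static_cast<float>(a));

  // Raw conversion API: same encoding, with explicit saturation and interpretation.
  __hip_fp8_storage_t raw = __hip_cvt_float_to_fp8(0.5f, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
  std::printf("cvt:  raw=0x%02x\n", static_cast<unsigned>(raw));

  // Packed two-wide type: element 0 lives in the low byte of the 16-bit storage.
  __hip_fp8x2_e4m3_fnuz p = make_float2(1.0f, -2.0f);
  float2 back = p;
  std::printf("x=%f y=%f\n", back.x, back.y);
  return 0;
}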