30#ifndef _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_
31#define _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_
33#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx1200__) || \
34 defined(__gfx1201__)) && \
35 __HIP_DEVICE_COMPILE__
36#define HIP_FP8_CVT_FAST_PATH 1
38#define HIP_FP8_CVT_FAST_PATH 0
41#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) && __HIP_DEVICE_COMPILE__
42#define HIP_FP8_TYPE_OCP 0
43#define HIP_FP8_TYPE_FNUZ 1
44#elif (defined(__gfx1200__) || defined(__gfx1201__)) && __HIP_DEVICE_COMPILE__
45#define HIP_FP8_TYPE_OCP 1
46#define HIP_FP8_TYPE_FNUZ 0
48#define HIP_FP8_TYPE_FNUZ 1
49#define HIP_FP8_TYPE_OCP 1
52#if defined(__HIPCC_RTC__)
54 #define ENABLE_FNUZ_HIPRTC 1
56 #define ENABLE_FNUZ_HIPRTC 0
59 #define ENABLE_OCP_HIPRTC 1
61 #define ENABLE_OCP_HIPRTC 0
68#if !defined(__HIPCC_RTC__)
69#include <hip/amd_detail/amd_hip_common.h>
73#include "amd_hip_vector_types.h"
74#include "amd_hip_fp16.h"
76#include "hip_assert.h"
77#define __HIP_SCHAR_MAX SCHAR_MAX
78#define __HIP_SCHAR_MIN SCHAR_MIN
79#define __HIP_UCHAR_MAX UCHAR_MAX
80#define __HIP_SHRT_MIN SHRT_MIN
81#define __HIP_SHRT_MAX SHRT_MAX
82#define __HIP_CHAR_MIN CHAR_MIN
83#define __HIP_CHAR_MAX CHAR_MAX
86#define __HIP_SCHAR_MAX __SCHAR_MAX__
87#define __HIP_SCHAR_MIN (-__SCHAR_MAX__ - 1)
88#define __HIP_UCHAR_MAX (__SCHAR_MAX__ * 2 + 1)
89#define __HIP_SHRT_MIN (-__SHRT_MAX__ - 1)
90#define __HIP_SHRT_MAX __SHRT_MAX__
91#ifdef __CHAR_UNSIGNED__
92#define __HIP_CHAR_MIN 0
93#define __HIP_CHAR_MAX __HIP_UCHAR_MAX
95#define __HIP_CHAR_MIN __HIP_SCHAR_MIN
96#define __HIP_CHAR_MAX __SCHAR_MAX__
100#if defined(__HIPCC_RTC__)
101#define __FP8_HOST_DEVICE__ __device__
102#define __FP8_HOST_DEVICE_STATIC__ __FP8_HOST_DEVICE__ static
104#define __FP8_HOST_DEVICE__ __host__ __device__
105#define __FP8_HOST_DEVICE_STATIC__ __FP8_HOST_DEVICE__ static inline
108#define __FP8_HOST__ __host__
109#define __FP8_HOST_STATIC__ __FP8_HOST__ static inline
112#if !defined(__HIPCC_RTC__)
113static_assert(CHAR_BIT == 8,
"byte size should be of 8 bits");
115static_assert(
sizeof(
unsigned char) == 1);
116static_assert(
sizeof(
unsigned short int) == 2);
117static_assert(
sizeof(
unsigned int) == 4);
161#define __assert_ocp_support(interp) \
163 if (interp != __HIP_E4M3 && interp != __HIP_E5M2) { \
164 __hip_assert(false && "type is unsupported by current target device"); \
167#define __assert_fnuz_support(interp) \
169 if (interp != __HIP_E4M3_FNUZ && interp != __HIP_E5M2_FNUZ) { \
170 __hip_assert(false && "type is unsupported by current target device"); \
175#if __HIP_DEVICE_COMPILE__
177 __assert_ocp_support(interp);
180 __assert_fnuz_support(interp);
188template <
typename T,
bool is_fnuz>
189__FP8_HOST_DEVICE_STATIC__
__hip_fp8_storage_t cast_to_f8(T _x,
int wm,
int we,
bool clip =
false,
191 unsigned int rng = 0) {
192 constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
193 constexpr bool is_float = __hip_internal::is_same<T, float>::value;
194 constexpr bool is_double = __hip_internal::is_same<T, double>::value;
195 static_assert(is_half || is_float || is_double,
"Only half, float and double can be cast to f8");
197 const int mfmt = (
sizeof(T) == 8) ? 52 : ((
sizeof(T) == 4) ? 23 : 10);
198 unsigned long long x;
201 x =
reinterpret_cast<unsigned long long&
>(_x);
202 else if (
sizeof(T) == 4)
203 x =
reinterpret_cast<unsigned int&
>(_x);
205 x =
reinterpret_cast<unsigned short int&
>(_x);
208 unsigned long long head, mantissa;
211 unsigned long long fInf, mask;
213 if (
sizeof(T) == 8) {
214 head = x & 0xFFF0000000000000ull;
215 mantissa = x & 0xFFFFFFFFFFFFFull;
216 exponent = (head >> 52) & 0x7FF;
219 fInf = 0x7FF0000000000000ull;
220 mask = 0x7FFFFFFFFFFFFFFFull;
221 }
else if (
sizeof(T) == 4) {
222 head = x & 0xFF800000;
223 mantissa = x & 0x7FFFFF;
224 exponent = (head >> 23) & 0xFF;
231 mantissa = x & 0x3FF;
232 exponent = (head >> 10) & 0x1F;
238 unsigned int signed_inf = 0;
239 unsigned int nan = 0;
241 signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
245 signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
247 signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
249 nan = (sign << 7) + 0x7f;
252 unsigned long long ifmax = 0;
253 if (
sizeof(T) == 8) {
255 ifmax = 0x40EC000000000000ull;
258 ifmax = 0x406E000000000000ull;
260 ifmax = 0x407C000000000000ull;
263 }
else if (
sizeof(T) == 4) {
285 if ((x & fInf) == fInf) {
286 if (is_fnuz)
return signed_inf;
287 return mantissa != 0 ? nan : signed_inf;
290 if ((x & mask) > ifmax) {
305 const int f8_bias = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0);
306 const int f8_denormal_act_exponent = 1 - f8_bias;
311 int act_exponent, f8_exponent, exponent_diff;
320 act_exponent = exponent - bias + 1;
321 exponent_diff = f8_denormal_act_exponent -
324 act_exponent = exponent - bias;
325 if (act_exponent <= f8_denormal_act_exponent) {
331 exponent_diff = f8_denormal_act_exponent - act_exponent;
336 mantissa += (1ull << mfmt);
339 bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
340 (1ull << (mfmt - wm + exponent_diff - 1));
347 if (exponent_diff > 0)
348 mantissa >>= exponent_diff;
349 else if (exponent_diff == -1)
350 mantissa <<= -exponent_diff;
351 bool implicit_one = mantissa & (1ull << mfmt);
354 (act_exponent + exponent_diff) + f8_bias - (implicit_one ? 0 : 1);
357 unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1;
359 mantissa & (1ull << (mfmt - wm));
361 (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
364 if (f8_exponent == 0) {
365 if ((1ull << mfmt) & mantissa) {
369 if ((1ull << (mfmt + 1)) & mantissa) {
375 mantissa >>= (mfmt - wm);
378 const int max_exp = (1 << we) - 1;
379 if (f8_exponent > max_exp) {
381 mantissa = (1 << wm) - 1;
382 f8_exponent = max_exp;
388 if (f8_exponent == 0 && mantissa == 0)
return is_fnuz ? 0 : (sign << 7);
389 mantissa &= (1 << wm) - 1;
390 return (sign << 7) | (f8_exponent << wm) | mantissa;
395template <
typename T,
bool is_fnuz>
398 constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
399 constexpr bool is_float = __hip_internal::is_same<T, float>::value;
400 constexpr bool is_double = __hip_internal::is_same<T, double>::value;
401 static_assert(is_half || is_float || is_double,
"only half, float and double are supported");
403 constexpr int weo = is_half ? 5 : (is_float ? 8 : 11);
404 constexpr int wmo = is_half ? 10 : (is_float ? 23 : 52);
406 T fInf, fNegInf, fNaN, fNeg0, fmax, fmin;
408 const unsigned short int ihInf = 0x7C00;
409 const unsigned short int ihNegInf = 0xFC00;
410 const unsigned short int ihNaN = 0x7C01;
411 const unsigned short int ihNeg0 = 0x8000;
413 const unsigned short int ifmax = 0x7B00;
414 const unsigned short int ifmin = 0xFB00;
415 fInf =
reinterpret_cast<const _Float16&
>(ihInf);
416 fNegInf =
reinterpret_cast<const _Float16&
>(ihNegInf);
417 fNaN =
reinterpret_cast<const _Float16&
>(ihNaN);
418 fNeg0 =
reinterpret_cast<const _Float16&
>(ihNeg0);
419 fmax =
reinterpret_cast<const _Float16&
>(ifmax);
420 fmin =
reinterpret_cast<const _Float16&
>(ifmin);
421 }
else if (is_float) {
422 const unsigned int ifInf = 0x7F800000;
423 const unsigned int ifNegInf = 0xFF800000;
424 const unsigned int ifNaN = 0x7F800001;
425 const unsigned int ifNeg0 = 0x80000000;
427 const unsigned int ifmax = 0x47600000;
428 const unsigned int ifmin = 0xC7600000;
429 fInf =
reinterpret_cast<const float&
>(ifInf);
430 fNegInf =
reinterpret_cast<const float&
>(ifNegInf);
431 fNaN =
reinterpret_cast<const float&
>(ifNaN);
432 fNeg0 =
reinterpret_cast<const float&
>(ifNeg0);
433 fmax =
reinterpret_cast<const float&
>(ifmax);
434 fmin =
reinterpret_cast<const float&
>(ifmin);
435 }
else if (is_double) {
436 const unsigned long long ifInf = 0x7FF0000000000000ull;
437 const unsigned long long ifNegInf = 0xFFF0000000000000ull;
438 const unsigned long long ifNaN = 0x7FF0000000000001ull;
439 const unsigned long long ifNeg0 = 0x8000000000000000ull;
441 const unsigned long long ifmax = 0x40EC000000000000ull;
442 const unsigned long long ifmin = 0xC0EC000000000000ull;
443 fInf =
reinterpret_cast<const double&
>(ifInf);
444 fNegInf =
reinterpret_cast<const double&
>(ifNegInf);
445 fNaN =
reinterpret_cast<const double&
>(ifNaN);
446 fNeg0 =
reinterpret_cast<const double&
>(ifNeg0);
447 fmax =
reinterpret_cast<const double&
>(ifmax);
448 fmin =
reinterpret_cast<const double&
>(ifmin);
455 unsigned long long sign = x >> 7;
456 unsigned long long mantissa = x & ((1 << wm) - 1);
457 int exponent = (x & 0x7F) >> wm;
467 if ((x & 0x7F) == 0x7F) {
470 }
else if ((x & 0x7C) == 0x7C) {
471 if ((x & 0x3) == 0) {
473 return sign ? fmin : fmax;
475 return sign ? fNegInf : fInf;
481 typename __hip_internal::conditional<
482 sizeof(T) == 2,
unsigned short int,
483 typename __hip_internal::conditional<
sizeof(T) == 4,
unsigned int,
484 unsigned long long>::type>::type retval;
486 if (we == 5 && is_half && !is_fnuz) {
488 return reinterpret_cast<const T&
>(retval);
491 const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (is_fnuz ? 1 : 0);
495#if __HIP_DEVICE_COMPILE__
497 int sh = 1 + __clz(mantissa) - (32 - wm);
499 int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
503 mantissa &= ((1ull << wm) - 1);
505 exponent += exp_low_cutoff - 1;
506 mantissa <<= wmo - wm;
510 mantissa |= 1ull << wmo;
511 mantissa >>= 1 - exponent;
516 retval = (sign << 15) | (exponent << 10) | mantissa;
517 else if (
sizeof(T) == 4)
518 retval = (sign << 31) | (exponent << 23) | mantissa;
520 retval = (sign << 63) | (static_cast<unsigned long long>(exponent) << 52) | mantissa;
521 return reinterpret_cast<const T&
>(retval);
524#if HIP_FP8_CVT_FAST_PATH
527template <
bool stochastic_rounding = false>
530 unsigned int rng = 0) {
535 unsigned char i8val[4];
538 unsigned int ival = 0;
543 if ((val.i32val & 0x7F800000) != 0x7F800000) {
544 val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
547 if ((val.i32val & 0x7F800000) != 0x7F800000) {
548 val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
551 if ((val.i32val & 0x7F800000) != 0x7F800000) {
552 val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
557 if (stochastic_rounding) {
559 ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
560 : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0);
562 i8data = val.i8val[0];
565 ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
false)
566 : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival,
false);
568 i8data = val.i8val[0];
576 static_assert(
sizeof(
float2) ==
sizeof(
unsigned int[2]),
"size mismatch");
577 static_assert(
sizeof(
float2) ==
sizeof(
unsigned short[4]),
"size mismatch");
579 unsigned int i32val[2];
580 unsigned short i16val[4];
// Saturation step for the packed (x2) conversion.
// Each lane is clamped to the target format's largest finite magnitude
// (+-240 E4M3 FNUZ, +-448 OCP E4M3, +-57344 E5M2; the interp branch that selects
// each bound is not visible here). Lanes whose float exponent field is all
// ones (Inf/NaN) are skipped so they propagate unchanged.
// FIX: each lane is clamped from its OWN value (fval.y from fval.y, previously
// fval.x), and the hardware pack below consumes the clamped f2val.fval instead
// of the raw input v (previously the clamp was computed and then discarded).
587 if ((f2val.i32val[0] & 0x7F800000) != 0x7F800000) {
588 f2val.fval.x = __builtin_amdgcn_fmed3f(f2val.fval.x, 240.0, -240.0);
590 if ((f2val.i32val[1] & 0x7F800000) != 0x7F800000) {
591 f2val.fval.y = __builtin_amdgcn_fmed3f(f2val.fval.y, 240.0, -240.0);
594 if ((f2val.i32val[0] & 0x7F800000) != 0x7F800000) {
595 f2val.fval.x = __builtin_amdgcn_fmed3f(f2val.fval.x, 448.0, -448.0);
597 if ((f2val.i32val[1] & 0x7F800000) != 0x7F800000) {
598 f2val.fval.y = __builtin_amdgcn_fmed3f(f2val.fval.y, 448.0, -448.0);
601 if ((f2val.i32val[0] & 0x7F800000) != 0x7F800000) {
602 f2val.fval.x = __builtin_amdgcn_fmed3f(f2val.fval.x, 57344.0, -57344.0);
604 if ((f2val.i32val[1] & 0x7F800000) != 0x7F800000) {
605 f2val.fval.y = __builtin_amdgcn_fmed3f(f2val.fval.y, 57344.0, -57344.0);
// Pack both (possibly clamped) lanes. The scalar path likewise converts its
// clamped val.fval rather than the raw argument.
611 ? __builtin_amdgcn_cvt_pk_fp8_f32(f2val.fval.x, f2val.fval.y, 0,
false)
612 : __builtin_amdgcn_cvt_pk_bf8_f32(f2val.fval.x, f2val.fval.y, 0,
false);
621 unsigned char i8val[4];
626 ? __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0)
627 : __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
635 unsigned short i16val[2];
640 ? __builtin_amdgcn_cvt_pk_f32_fp8(val.i32val,
false)
641 : __builtin_amdgcn_cvt_pk_f32_bf8(val.i32val,
false);
642 return float2{f2[0], f2[1]};
650 return static_cast<unsigned char>(a) == 0x80;
655 return (type ==
__HIP_E4M3) ? ((a & 0x7f) == 0x7f)
662 return (type ==
__HIP_E5M2) ? (a & 0x7f) == 0x7c :
false;
675#if HIP_FP8_CVT_FAST_PATH
678 internal::__is_interpret_supported(interp);
679 return internal::cast_to_f8_from_f32<false>(f, sat ==
__HIP_SATFINITE, interp);
681#if HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ
691 return internal::cast_to_f8<float, true>(f, wm, we, sat ==
__HIP_SATFINITE);
695 return internal::cast_to_f8<float, false>(f, wm, we, sat ==
__HIP_SATFINITE);
709#if HIP_FP8_CVT_FAST_PATH
712 internal::__is_interpret_supported(interp);
713 return internal::cast_to_f8x2_from_f32x2(f2, sat ==
__HIP_SATFINITE, interp);
715#if HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ
736#if HIP_FP8_CVT_FAST_PATH
739 internal::__is_interpret_supported(interp);
740#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ
750 return internal::cast_to_f8<double, true>(d, wm, we, sat ==
__HIP_SATFINITE);
754 return internal::cast_to_f8<double, false>(d, wm, we, sat ==
__HIP_SATFINITE);
766#if HIP_FP8_CVT_FAST_PATH
769 internal::__is_interpret_supported(interp);
770#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ
790#if HIP_FP8_CVT_FAST_PATH
794 internal::__is_interpret_supported(interp);
795#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ
804 float fval = __hip_bfloat16(hr);
816#if HIP_FP8_CVT_FAST_PATH
820 internal::__is_interpret_supported(interp);
821#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ
830 float2 f2 = __hip_bfloat162(hr);
841#if HIP_FP8_CVT_FAST_PATH
844 internal::__is_interpret_supported(interp);
845#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ
855 return __half_raw{internal::cast_from_f8<_Float16, true>(x, wm, we)};
857 unsigned int we = interp ==
__HIP_E4M3 ? 4 : 5;
858 unsigned int wm = interp ==
__HIP_E4M3 ? 3 : 2;
859 return __half_raw{internal::cast_from_f8<_Float16, false>(x, wm, we)};
870#if HIP_FP8_CVT_FAST_PATH
873 internal::__is_interpret_supported(interp);
874#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ
881 __half2 ret(
static_cast<__half
>(
896#if HIP_FP8_CVT_FAST_PATH
899 internal::__is_interpret_supported(interp);
900#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ
918#if HIP_FP8_CVT_FAST_PATH
921 internal::__is_interpret_supported(interp);
922#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ
937#if !defined(ENABLE_FNUZ_HIPRTC) || ENABLE_FNUZ_HIPRTC
942 constexpr static unsigned int __we = 4;
943 constexpr static unsigned int __wm = 3;
955 __default_interpret)) {
965 __default_interpret)) {
975 __default_interpret)) {
985 __default_interpret)) {
995 __default_interpret)) {
1005 __default_interpret)) {
1009#if HIP_FP8_TYPE_FNUZ
1018#if HIP_FP8_TYPE_FNUZ
1027#if HIP_FP8_TYPE_FNUZ
1033 __default_interpret)) {
1037#if HIP_FP8_TYPE_FNUZ
1043 __default_interpret)) {
1047#if HIP_FP8_TYPE_FNUZ
1054#if HIP_FP8_TYPE_FNUZ
1055 __FP8_HOST_DEVICE__
operator __half()
const {
1057 __FP8_HOST__
operator __half()
const {
1063#if HIP_FP8_TYPE_FNUZ
1064 __FP8_HOST_DEVICE__
operator __hip_bfloat16()
const {
1066 __FP8_HOST__
operator __hip_bfloat16()
const {
1069 return __hip_bfloat16(f);
1073#if HIP_FP8_TYPE_FNUZ
1074 __FP8_HOST_DEVICE__
operator bool()
const {
1076 __FP8_HOST__
operator bool()
const {
1079 return !(
static_cast<unsigned short>(__x) == 0);
1083#if HIP_FP8_TYPE_FNUZ
1084 __FP8_HOST_DEVICE__
operator char()
const {
1086 __FP8_HOST__
operator char()
const {
1088 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1092 auto fval = internal::cast_from_f8<float, true>(__x, __wm, __we);
1093 auto llval =
static_cast<long long>(fval);
1094 if (llval <= __HIP_CHAR_MIN) {
1095 return __HIP_CHAR_MIN;
1096 }
else if (llval >= __HIP_CHAR_MAX) {
1097 return __HIP_CHAR_MAX;
1099 return static_cast<char>(fval);
1103#if HIP_FP8_TYPE_FNUZ
1104 __FP8_HOST_DEVICE__
operator double()
const {
1106 __FP8_HOST__
operator double()
const {
1108 return internal::cast_from_f8<double, true>(__x, __wm, __we);
1112#if HIP_FP8_TYPE_FNUZ
1113 __FP8_HOST_DEVICE__
operator float()
const {
1115 __FP8_HOST__
operator float()
const {
1117#if HIP_FP8_CVT_FAST_PATH
1118 return internal::cast_to_f32_from_f8(__x, __default_interpret);
1120 return internal::cast_from_f8<float, true>(__x, __wm, __we);
1125#if HIP_FP8_TYPE_FNUZ
1126 __FP8_HOST_DEVICE__
operator int()
const {
1128 __FP8_HOST__
operator int()
const {
1130 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1135 return static_cast<int>(fval);
1139#if HIP_FP8_TYPE_FNUZ
1140 __FP8_HOST_DEVICE__
operator long int()
const {
1142 __FP8_HOST__
operator long int()
const {
1144 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1149 return static_cast<long>(fval);
1153#if HIP_FP8_TYPE_FNUZ
1154 __FP8_HOST_DEVICE__
operator long long int()
const {
1156 __FP8_HOST__
operator long long int()
const {
1158 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1163 return static_cast<long long>(fval);
1167#if HIP_FP8_TYPE_FNUZ
1168 __FP8_HOST_DEVICE__
operator short int()
const {
1170 __FP8_HOST__
operator short int()
const {
1172 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1177 auto llval =
static_cast<long long>(fval);
1178 if (llval <= __HIP_SHRT_MIN) {
1179 return __HIP_SHRT_MIN;
1180 }
else if (llval >= __HIP_SHRT_MAX) {
1181 return __HIP_SHRT_MAX;
1183 return static_cast<short>(fval);
1187#if HIP_FP8_TYPE_FNUZ
1188 __FP8_HOST_DEVICE__
operator signed char()
const {
1190 __FP8_HOST__
operator signed char()
const {
1192 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1197 auto llval =
static_cast<long long>(fval);
1198 if (llval <= __HIP_SCHAR_MIN) {
1199 return __HIP_SCHAR_MIN;
1200 }
else if (llval >= __HIP_SCHAR_MAX) {
1201 return __HIP_SCHAR_MAX;
1203 return static_cast<signed char>(fval);
1207#if HIP_FP8_TYPE_FNUZ
1208 __FP8_HOST_DEVICE__
operator unsigned char()
const {
1210 __FP8_HOST__
operator unsigned char()
const {
1212 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1217 auto llval =
static_cast<long long>(fval);
1220 }
else if (llval >= __HIP_UCHAR_MAX) {
1221 return __HIP_UCHAR_MAX;
1223 return static_cast<unsigned char>(fval);
1227#if HIP_FP8_TYPE_FNUZ
1228 __FP8_HOST_DEVICE__
operator unsigned int()
const {
1230 __FP8_HOST__
operator unsigned int()
const {
1232 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1237 auto llval =
static_cast<long long>(fval);
1241 return static_cast<unsigned int>(fval);
1245#if HIP_FP8_TYPE_FNUZ
1246 __FP8_HOST_DEVICE__
operator unsigned long int()
const {
1248 __FP8_HOST__
operator unsigned long int()
const {
1250 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1255 auto llval =
static_cast<long long>(fval);
1259 return static_cast<unsigned long>(fval);
1263#if HIP_FP8_TYPE_FNUZ
1264 __FP8_HOST_DEVICE__
operator unsigned long long int()
const {
1266 __FP8_HOST__
operator unsigned long long int()
const {
1268 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1273 auto llval =
static_cast<long long>(fval);
1277 return static_cast<unsigned long long>(fval);
1281#if HIP_FP8_TYPE_FNUZ
1282 __FP8_HOST_DEVICE__
operator unsigned short int()
const {
1284 __FP8_HOST__
operator unsigned short int()
const {
1286 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1291 auto llval =
static_cast<long long>(fval);
1295 return static_cast<unsigned short>(fval);
1307 static constexpr unsigned int __we = 4;
1308 static constexpr unsigned int __wm = 3;
1311#if HIP_FP8_TYPE_FNUZ
1320#if HIP_FP8_TYPE_FNUZ
1329#if HIP_FP8_TYPE_FNUZ
1338#if HIP_FP8_TYPE_FNUZ
1347#if HIP_FP8_TYPE_FNUZ
1354#if HIP_FP8_TYPE_FNUZ
1355 __FP8_HOST_DEVICE__
operator __half2()
const {
1357 __FP8_HOST__
operator __half2()
const {
1363#if HIP_FP8_TYPE_FNUZ
1364 __FP8_HOST_DEVICE__
operator float2()
const {
1366 __FP8_HOST__
operator float2()
const {
1368#if HIP_FP8_CVT_FAST_PATH
1369 return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret);
1387 static constexpr unsigned int __we = 4;
1388 static constexpr unsigned int __wm = 3;
1391#if HIP_FP8_TYPE_FNUZ
1398 val.x, __default_saturation, __default_interpret)) |
1400 val.y, __default_saturation, __default_interpret))
1403 val.z, __default_saturation, __default_interpret))
1406 val.w, __default_saturation, __default_interpret))
1411#if HIP_FP8_TYPE_FNUZ
1418 val.x, __default_saturation, __default_interpret)) |
1420 val.y, __default_saturation, __default_interpret))
1423 val.z, __default_saturation, __default_interpret))
1426 val.w, __default_saturation, __default_interpret))
1431#if HIP_FP8_TYPE_FNUZ
1437 reinterpret_cast<unsigned short>(
1439 reinterpret_cast<unsigned short>(
1445#if HIP_FP8_TYPE_FNUZ
1452 high, __default_saturation, __default_interpret)) |
1454 low, __default_saturation, __default_interpret))
1459#if HIP_FP8_TYPE_FNUZ
1466#if HIP_FP8_TYPE_FNUZ
1467 __FP8_HOST_DEVICE__
operator float4()
const {
1469 __FP8_HOST__
operator float4()
const {
1474#if HIP_FP8_CVT_FAST_PATH
1475 float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret);
1476 float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret);
1478 float2 high =
float2(internal::cast_from_f8<float, true>(
1480 internal::cast_from_f8<float, true>(
1482 float2 low =
float2(internal::cast_from_f8<float, true>(
1484 internal::cast_from_f8<float, true>(
1487 return float4(low.x, low.y, high.x, high.y);
1499 static constexpr unsigned int __we = 5;
1500 static constexpr unsigned int __wm = 2;
1507#if HIP_FP8_TYPE_FNUZ
1513 __default_interpret)) {
1517#if HIP_FP8_TYPE_FNUZ
1523 __default_interpret)) {
1527#if HIP_FP8_TYPE_FNUZ
1533 __default_interpret)) {
1537#if HIP_FP8_TYPE_FNUZ
1543 __default_interpret)) {
1547#if HIP_FP8_TYPE_FNUZ
1553 __default_interpret)) {
1557#if HIP_FP8_TYPE_FNUZ
1563 __default_interpret)) {
1567#if HIP_FP8_TYPE_FNUZ
1576#if HIP_FP8_TYPE_FNUZ
1585#if HIP_FP8_TYPE_FNUZ
1591 __default_interpret)) {
1595#if HIP_FP8_TYPE_FNUZ
1601 __default_interpret)) {
1605#if HIP_FP8_TYPE_FNUZ
1612#if HIP_FP8_TYPE_FNUZ
1613 __FP8_HOST_DEVICE__
operator float()
const {
1615 __FP8_HOST__
operator float()
const {
1617#if HIP_FP8_CVT_FAST_PATH
1618 return internal::cast_to_f32_from_f8(__x, __default_interpret);
1620 return internal::cast_from_f8<float, true>(__x, __wm, __we);
1625#if HIP_FP8_TYPE_FNUZ
1626 __FP8_HOST_DEVICE__
operator __half()
const {
1628 __FP8_HOST__
operator __half()
const {
1634#if HIP_FP8_TYPE_FNUZ
1635 __FP8_HOST_DEVICE__
operator __hip_bfloat16()
const {
1637 __FP8_HOST__
operator __hip_bfloat16()
const {
1640 return __hip_bfloat16(f);
1644#if HIP_FP8_TYPE_FNUZ
1645 __FP8_HOST_DEVICE__
operator bool()
const {
1647 __FP8_HOST__
operator bool()
const {
1650 return !(
static_cast<unsigned short>(__x) == 0);
1654#if HIP_FP8_TYPE_FNUZ
1655 __FP8_HOST_DEVICE__
operator char()
const {
1657 __FP8_HOST__
operator char()
const {
1659 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1664 auto llval =
static_cast<long long>(fval);
1665 if (llval <= __HIP_CHAR_MIN) {
1666 return __HIP_CHAR_MIN;
1667 }
else if (llval >= __HIP_CHAR_MAX) {
1668 return __HIP_CHAR_MAX;
1670 return static_cast<char>(fval);
1674#if HIP_FP8_TYPE_FNUZ
1675 __FP8_HOST_DEVICE__
operator double()
const {
1677 __FP8_HOST__
operator double()
const {
1679 return internal::cast_from_f8<double, true>(__x, __wm, __we);
1683#if HIP_FP8_TYPE_FNUZ
1684 __FP8_HOST_DEVICE__
operator int()
const {
1686 __FP8_HOST__
operator int()
const {
1688 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1693 return static_cast<int>(fval);
1697#if HIP_FP8_TYPE_FNUZ
1698 __FP8_HOST_DEVICE__
operator long int()
const {
1700 __FP8_HOST__
operator long int()
const {
1702 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1707 return static_cast<long>(fval);
1711#if HIP_FP8_TYPE_FNUZ
1712 __FP8_HOST_DEVICE__
operator long long int()
const {
1714 __FP8_HOST__
operator long long int()
const {
1716 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1721 return static_cast<long long>(fval);
1725#if HIP_FP8_TYPE_FNUZ
1726 __FP8_HOST_DEVICE__
operator short int()
const {
1728 __FP8_HOST__
operator short int()
const {
1730 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1735 auto llval =
static_cast<long long>(fval);
1736 if (llval <= __HIP_SHRT_MIN) {
1737 return __HIP_SHRT_MIN;
1738 }
else if (llval >= __HIP_SHRT_MAX) {
1739 return __HIP_SHRT_MAX;
1741 return static_cast<short>(fval);
1745#if HIP_FP8_TYPE_FNUZ
1746 __FP8_HOST_DEVICE__
operator signed char()
const {
1748 __FP8_HOST__
operator signed char()
const {
1750 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1755 auto llval =
static_cast<long long>(fval);
1756 if (llval <= __HIP_SCHAR_MIN) {
1757 return __HIP_SCHAR_MIN;
1758 }
else if (llval >= __HIP_SCHAR_MAX) {
1759 return __HIP_SCHAR_MAX;
1761 return static_cast<signed char>(fval);
1765#if HIP_FP8_TYPE_FNUZ
1766 __FP8_HOST_DEVICE__
operator unsigned char()
const {
1768 __FP8_HOST__
operator unsigned char()
const {
1770 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1775 auto llval =
static_cast<long long>(fval);
1778 }
else if (llval >= __HIP_UCHAR_MAX) {
1779 return __HIP_UCHAR_MAX;
1781 return static_cast<unsigned char>(fval);
1785#if HIP_FP8_TYPE_FNUZ
1786 __FP8_HOST_DEVICE__
operator unsigned int()
const {
1788 __FP8_HOST__
operator unsigned int()
const {
1790 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1795 auto llval =
static_cast<long long>(fval);
1799 return static_cast<unsigned int>(fval);
1803#if HIP_FP8_TYPE_FNUZ
1804 __FP8_HOST_DEVICE__
operator unsigned long int()
const {
1806 __FP8_HOST__
operator unsigned long int()
const {
1808 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1813 auto llval =
static_cast<long long>(fval);
1817 return static_cast<unsigned long>(fval);
1821#if HIP_FP8_TYPE_FNUZ
1822 __FP8_HOST_DEVICE__
operator unsigned long long int()
const {
1824 __FP8_HOST__
operator unsigned long long int()
const {
1826 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1831 auto llval =
static_cast<long long>(fval);
1835 return static_cast<unsigned long long>(fval);
1839#if HIP_FP8_TYPE_FNUZ
1840 __FP8_HOST_DEVICE__
operator unsigned short int()
const {
1842 __FP8_HOST__
operator unsigned short int()
const {
1844 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1849 auto llval =
static_cast<long long>(fval);
1853 return static_cast<unsigned short>(fval);
1865 static constexpr unsigned int __we = 5;
1866 static constexpr unsigned int __wm = 2;
1869#if HIP_FP8_TYPE_FNUZ
1878#if HIP_FP8_TYPE_FNUZ
1887#if HIP_FP8_TYPE_FNUZ
1896#if HIP_FP8_TYPE_FNUZ
1905#if HIP_FP8_TYPE_FNUZ
1912#if HIP_FP8_TYPE_FNUZ
1913 __FP8_HOST_DEVICE__
operator __half2()
const {
1915 __FP8_HOST__
operator __half2()
const {
1921#if HIP_FP8_TYPE_FNUZ
1922 __FP8_HOST_DEVICE__
operator float2()
const {
1924 __FP8_HOST__
operator float2()
const {
1926#if HIP_FP8_CVT_FAST_PATH
1927 return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret);
1945 static constexpr unsigned int __we = 5;
1946 static constexpr unsigned int __wm = 2;
1949#if HIP_FP8_TYPE_FNUZ
1956 val.x, __default_saturation, __default_interpret)) |
1958 val.y, __default_saturation, __default_interpret))
1961 val.z, __default_saturation, __default_interpret))
1964 val.w, __default_saturation, __default_interpret))
1969#if HIP_FP8_TYPE_FNUZ
1976 val.x, __default_saturation, __default_interpret)) |
1978 val.y, __default_saturation, __default_interpret))
1981 val.z, __default_saturation, __default_interpret))
1984 val.w, __default_saturation, __default_interpret))
1989#if HIP_FP8_TYPE_FNUZ
1995 reinterpret_cast<unsigned short>(
1997 reinterpret_cast<unsigned short>(
2003#if HIP_FP8_TYPE_FNUZ
2010 high, __default_saturation, __default_interpret)) |
2012 low, __default_saturation, __default_interpret))
2017#if HIP_FP8_TYPE_FNUZ
2024#if HIP_FP8_TYPE_FNUZ
2025 __FP8_HOST_DEVICE__
operator float4()
const {
2027 __FP8_HOST__
operator float4()
const {
2032#if HIP_FP8_CVT_FAST_PATH
2033 float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret);
2034 float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret);
2036 float2 high =
float2(internal::cast_from_f8<float, true>(
2038 internal::cast_from_f8<float, true>(
2040 float2 low =
float2(internal::cast_from_f8<float, true>(
2042 internal::cast_from_f8<float, true>(
2045 return float4(low.x, low.y, high.x, high.y);
2056#if !defined(ENABLE_OCP_HIPRTC) || ENABLE_OCP_HIPRTC
2062 constexpr static unsigned int __we = 4;
2063 constexpr static unsigned int __wm = 3;
2075 __default_interpret)) {
2085 __default_interpret)) {
2091 __default_interpret)) {}
2100 __default_interpret)) {
2110 __default_interpret)) {
2120 __default_interpret)) {
2148 __default_interpret)) {
2158 __default_interpret)) {
2171 __FP8_HOST_DEVICE__
operator __half()
const {
2173 __FP8_HOST__
operator __half()
const {
2180 __FP8_HOST_DEVICE__
operator __hip_bfloat16()
const {
2182 __FP8_HOST__
operator __hip_bfloat16()
const {
2185 return __hip_bfloat16(f);
2190 __FP8_HOST_DEVICE__
operator bool()
const {
2192 __FP8_HOST__
operator bool()
const {
2195 return !(
static_cast<unsigned short>(__x) == 0 ||
static_cast<unsigned short>(__x) == 0x80);
2200 __FP8_HOST_DEVICE__
operator char()
const {
2202 __FP8_HOST__
operator char()
const {
2204 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2208 auto fval = internal::cast_from_f8<float, false>(__x, __wm, __we);
2209 auto llval =
static_cast<long long>(fval);
2210 if (llval <= __HIP_CHAR_MIN) {
2211 return __HIP_CHAR_MIN;
2212 }
else if (llval >= __HIP_CHAR_MAX) {
2213 return __HIP_CHAR_MAX;
2215 return static_cast<char>(fval);
2220 __FP8_HOST_DEVICE__
operator double()
const {
2222 __FP8_HOST__
operator double()
const {
2224 return internal::cast_from_f8<double, false>(__x, __wm, __we);
2229 __FP8_HOST_DEVICE__
operator float()
const {
2231 __FP8_HOST__
operator float()
const {
2233#if HIP_FP8_CVT_FAST_PATH
2234 return internal::cast_to_f32_from_f8(__x, __default_interpret);
2236 return internal::cast_from_f8<float, false>(__x, __wm, __we);
2242 __FP8_HOST_DEVICE__
operator int()
const {
2244 __FP8_HOST__
operator int()
const {
2246 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2251 return static_cast<int>(fval);
2256 __FP8_HOST_DEVICE__
operator long int()
const {
2258 __FP8_HOST__
operator long int()
const {
2260 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2265 return static_cast<long>(fval);
2270 __FP8_HOST_DEVICE__
operator long long int()
const {
2272 __FP8_HOST__
operator long long int()
const {
2274 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2279 return static_cast<long long>(fval);
2284 __FP8_HOST_DEVICE__
operator short int()
const {
2286 __FP8_HOST__
operator short int()
const {
2288 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2293 auto llval =
static_cast<long long>(fval);
2294 if (llval <= __HIP_SHRT_MIN) {
2295 return __HIP_SHRT_MIN;
2296 }
else if (llval >= __HIP_SHRT_MAX) {
2297 return __HIP_SHRT_MAX;
2299 return static_cast<short>(fval);
2304 __FP8_HOST_DEVICE__
operator signed char()
const {
2306 __FP8_HOST__
operator signed char()
const {
2308 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2313 auto llval =
static_cast<long long>(fval);
2314 if (llval <= __HIP_SCHAR_MIN) {
2315 return __HIP_SCHAR_MIN;
2316 }
else if (llval >= __HIP_SCHAR_MAX) {
2317 return __HIP_SCHAR_MAX;
2319 return static_cast<signed char>(fval);
2324 __FP8_HOST_DEVICE__
operator unsigned char()
const {
2326 __FP8_HOST__
operator unsigned char()
const {
2328 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2333 auto llval =
static_cast<long long>(fval);
2336 }
else if (llval >= __HIP_UCHAR_MAX) {
2337 return __HIP_UCHAR_MAX;
2339 return static_cast<unsigned char>(fval);
2344 __FP8_HOST_DEVICE__
operator unsigned int()
const {
2346 __FP8_HOST__
operator unsigned int()
const {
2348 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2353 auto llval =
static_cast<long long>(fval);
2357 return static_cast<unsigned int>(fval);
2362 __FP8_HOST_DEVICE__
operator unsigned long int()
const {
2364 __FP8_HOST__
operator unsigned long int()
const {
2366 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2371 auto llval =
static_cast<long long>(fval);
2375 return static_cast<unsigned long>(fval);
2380 __FP8_HOST_DEVICE__
operator unsigned long long int()
const {
2382 __FP8_HOST__
operator unsigned long long int()
const {
2384 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2389 auto llval =
static_cast<long long>(fval);
2393 return static_cast<unsigned long long>(fval);
2398 __FP8_HOST_DEVICE__
operator unsigned short int()
const {
2400 __FP8_HOST__
operator unsigned short int()
const {
2402 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2407 auto llval =
static_cast<long long>(fval);
2411 return static_cast<unsigned short>(fval);
2423 static constexpr unsigned int __we = 4;
2424 static constexpr unsigned int __wm = 3;
2472 __FP8_HOST_DEVICE__
operator __half2()
const {
2474 __FP8_HOST__
operator __half2()
const {
2481 __FP8_HOST_DEVICE__
operator float2()
const {
2483 __FP8_HOST__
operator float2()
const {
2485#if HIP_FP8_CVT_FAST_PATH
2486 return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret);
2504 static constexpr unsigned int __we = 4;
2505 static constexpr unsigned int __wm = 3;
2516 val.x, __default_saturation, __default_interpret)) |
2518 val.y, __default_saturation, __default_interpret))
2521 val.z, __default_saturation, __default_interpret))
2524 val.w, __default_saturation, __default_interpret))
2536 val.x, __default_saturation, __default_interpret)) |
2538 val.y, __default_saturation, __default_interpret))
2541 val.z, __default_saturation, __default_interpret))
2544 val.w, __default_saturation, __default_interpret))
2552 __FP8_HOST__
__hip_fp8x4_e4m3(
const __hip_bfloat162 low,
const __hip_bfloat162 high)
2555 reinterpret_cast<unsigned short>(
2557 reinterpret_cast<unsigned short>(
2570 high, __default_saturation, __default_interpret)) |
2572 low, __default_saturation, __default_interpret))
2585 __FP8_HOST_DEVICE__
operator float4()
const {
2587 __FP8_HOST__
operator float4()
const {
2592#if HIP_FP8_CVT_FAST_PATH
2593 float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret);
2594 float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret);
2596 float2 high =
float2(internal::cast_from_f8<float, false>(
2598 internal::cast_from_f8<float, false>(
2600 float2 low =
float2(internal::cast_from_f8<float, false>(
2602 internal::cast_from_f8<float, false>(
2605 return float4(low.x, low.y, high.x, high.y);
2617 static constexpr unsigned int __we = 5;
2618 static constexpr unsigned int __wm = 2;
2632 __default_interpret)) {
2642 __default_interpret)) {
2652 __default_interpret)) {
2662 __default_interpret)) {
2672 __default_interpret)) {
2682 __default_interpret)) {
2710 __default_interpret)) {
2720 __default_interpret)) {
2732 __FP8_HOST_DEVICE__
operator float()
const {
2734 __FP8_HOST__
operator float()
const {
2736#if HIP_FP8_CVT_FAST_PATH
2737 return internal::cast_to_f32_from_f8(__x, __default_interpret);
2739 return internal::cast_from_f8<float, false>(__x, __wm, __we,
2746 __FP8_HOST_DEVICE__
operator __half()
const {
2748 __FP8_HOST__
operator __half()
const {
2755 __FP8_HOST_DEVICE__
operator __hip_bfloat16()
const {
2757 __FP8_HOST__
operator __hip_bfloat16()
const {
2760 return __hip_bfloat16(f);
2765 __FP8_HOST_DEVICE__
operator bool()
const {
2767 __FP8_HOST__
operator bool()
const {
2770 return !(
static_cast<unsigned short>(__x) == 0 ||
static_cast<unsigned short>(__x) == 0x80);
2775 __FP8_HOST_DEVICE__
operator char()
const {
2777 __FP8_HOST__
operator char()
const {
2779 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2784 auto llval =
static_cast<long long>(fval);
2785 if (llval <= __HIP_CHAR_MIN) {
2786 return __HIP_CHAR_MIN;
2787 }
else if (llval >= __HIP_CHAR_MAX) {
2788 return __HIP_CHAR_MAX;
2790 return static_cast<char>(fval);
2795 __FP8_HOST_DEVICE__
operator double()
const {
2797 __FP8_HOST__
operator double()
const {
2799 return internal::cast_from_f8<double, false>(__x, __wm, __we,
2805 __FP8_HOST_DEVICE__
operator int()
const {
2807 __FP8_HOST__
operator int()
const {
2809 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2814 return static_cast<int>(fval);
2819 __FP8_HOST_DEVICE__
operator long int()
const {
2821 __FP8_HOST__
operator long int()
const {
2823 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2828 return static_cast<long>(fval);
2833 __FP8_HOST_DEVICE__
operator long long int()
const {
2835 __FP8_HOST__
operator long long int()
const {
2837 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2842 return static_cast<long long>(fval);
2847 __FP8_HOST_DEVICE__
operator short int()
const {
2849 __FP8_HOST__
operator short int()
const {
2851 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2856 auto llval =
static_cast<long long>(fval);
2857 if (llval <= __HIP_SHRT_MIN) {
2858 return __HIP_SHRT_MIN;
2859 }
else if (llval >= __HIP_SHRT_MAX) {
2860 return __HIP_SHRT_MAX;
2862 return static_cast<short>(fval);
2867 __FP8_HOST_DEVICE__
operator signed char()
const {
2869 __FP8_HOST__
operator signed char()
const {
2871 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2876 auto llval =
static_cast<long long>(fval);
2877 if (llval <= __HIP_SCHAR_MIN) {
2878 return __HIP_SCHAR_MIN;
2879 }
else if (llval >= __HIP_SCHAR_MAX) {
2880 return __HIP_SCHAR_MAX;
2882 return static_cast<signed char>(fval);
2887 __FP8_HOST_DEVICE__
operator unsigned char()
const {
2889 __FP8_HOST__
operator unsigned char()
const {
2891 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2896 auto llval =
static_cast<long long>(fval);
2899 }
else if (llval >= __HIP_UCHAR_MAX) {
2900 return __HIP_UCHAR_MAX;
2902 return static_cast<unsigned char>(fval);
2907 __FP8_HOST_DEVICE__
operator unsigned int()
const {
2909 __FP8_HOST__
operator unsigned int()
const {
2911 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2916 auto llval =
static_cast<long long>(fval);
2920 return static_cast<unsigned int>(fval);
2925 __FP8_HOST_DEVICE__
operator unsigned long int()
const {
2927 __FP8_HOST__
operator unsigned long int()
const {
2929 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2934 auto llval =
static_cast<long long>(fval);
2938 return static_cast<unsigned long>(fval);
2943 __FP8_HOST_DEVICE__
operator unsigned long long int()
const {
2945 __FP8_HOST__
operator unsigned long long int()
const {
2947 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2952 auto llval =
static_cast<long long>(fval);
2956 return static_cast<unsigned long long>(fval);
2961 __FP8_HOST_DEVICE__
operator unsigned short int()
const {
2963 __FP8_HOST__
operator unsigned short int()
const {
2965 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2970 auto llval =
static_cast<long long>(fval);
2974 return static_cast<unsigned short>(fval);
2986 static constexpr unsigned int __we = 5;
2987 static constexpr unsigned int __wm = 2;
3035 __FP8_HOST_DEVICE__
operator __half2()
const {
3037 __FP8_HOST__
operator __half2()
const {
3044 __FP8_HOST_DEVICE__
operator float2()
const {
3046 __FP8_HOST__
operator float2()
const {
3048#if HIP_FP8_CVT_FAST_PATH
3049 return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret);
3054 internal::cast_from_f8<float, false>(
static_cast<__hip_fp8_storage_t>(__x >> 8), __wm, __we,
3068 static constexpr unsigned int __we = 5;
3069 static constexpr unsigned int __wm = 2;
3079 val.x, __default_saturation, __default_interpret)) |
3081 val.y, __default_saturation, __default_interpret))
3084 val.z, __default_saturation, __default_interpret))
3087 val.w, __default_saturation, __default_interpret))
3099 val.x, __default_saturation, __default_interpret)) |
3101 val.y, __default_saturation, __default_interpret))
3104 val.z, __default_saturation, __default_interpret))
3107 val.w, __default_saturation, __default_interpret))
3115 __FP8_HOST__
__hip_fp8x4_e5m2(
const __hip_bfloat162 low,
const __hip_bfloat162 high)
3118 reinterpret_cast<unsigned short>(
3120 reinterpret_cast<unsigned short>(
3133 high, __default_saturation, __default_interpret)) |
3135 low, __default_saturation, __default_interpret))
3148 __FP8_HOST_DEVICE__
operator float4()
const {
3150 __FP8_HOST__
operator float4()
const {
3155#if HIP_FP8_CVT_FAST_PATH
3156 float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret);
3157 float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret);
3160 internal::cast_from_f8<float, false>(
3166 internal::cast_from_f8<float, false>(
3169 internal::cast_from_f8<float, false>(
static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm,
3172 return float4(low.x, low.y, high.x, high.y);
hip_bf16.h provides struct for __hip_bfloat16 types
__hip_saturation_t
Describes saturation behavior.
Definition amd_hip_fp8.h:132
@ __HIP_SATFINITE
Definition amd_hip_fp8.h:134
@ __HIP_NOSAT
Definition amd_hip_fp8.h:133
__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_double_to_fp8(const double d, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp)
convert double to __hip_fp8_storage_t
Definition amd_hip_fp8.h:741
__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_bfloat16raw2_to_fp8x2(const __hip_bfloat162_raw hr, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp)
convert double2 to __hip_fp8x2_storage_t
Definition amd_hip_fp8.h:823
__hip_fp8_interpretation_t
Describes FP8 interpretation.
Definition amd_hip_fp8.h:122
@ __HIP_E4M3_FNUZ
Definition amd_hip_fp8.h:125
@ __HIP_E5M2
Definition amd_hip_fp8.h:124
@ __HIP_E4M3
Definition amd_hip_fp8.h:123
@ __HIP_E5M2_FNUZ
Definition amd_hip_fp8.h:126
__FP8_HOST_DEVICE_STATIC__ __half_raw __hip_cvt_fp8_to_halfraw(const __hip_fp8_storage_t x, const __hip_fp8_interpretation_t interp)
convert __hip_fp8_storage_t to __half_raw
Definition amd_hip_fp8.h:847
__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_double2_to_fp8x2(const double2 d2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp)
convert double2 to __hip_fp8x2_storage_t
Definition amd_hip_fp8.h:771
__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_float2_to_fp8x2(const float2 f2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp)
convert float2 to __hip_fp8x2_storage_t
Definition amd_hip_fp8.h:716
unsigned short int __hip_fp8x2_storage_t
type to store two fp8 numbers
Definition amd_hip_fp8.h:148
__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_bfloat16raw_to_fp8(const __hip_bfloat16_raw hr, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp)
convert __hip_bfloat16_raw to __hip_fp8_storage_t
Definition amd_hip_fp8.h:797
__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_halfraw2_to_fp8x2(const __half2_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp)
convert __half2_raw to __hip_fp8x2_storage_t
Definition amd_hip_fp8.h:923
unsigned int __hip_fp8x4_storage_t
type to store four fp8 numbers
Definition amd_hip_fp8.h:155
__FP8_HOST_DEVICE_STATIC__ __half2_raw __hip_cvt_fp8x2_to_halfraw2(const __hip_fp8x2_storage_t x, const __hip_fp8_interpretation_t interp)
convert __hip_fp8x2_storage_t to __half2_raw
Definition amd_hip_fp8.h:875
__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_halfraw_to_fp8(const __half_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp)
convert __half_raw to __hip_fp8_storage_t
Definition amd_hip_fp8.h:901
unsigned char __hip_fp8_storage_t
type to store single fp8 number
Definition amd_hip_fp8.h:141
__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_float_to_fp8(const float f, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp)
convert float to __hip_fp8_storage_t
Definition amd_hip_fp8.h:682
struct representing single fp8 number with e4m3 interpretation
Definition amd_hip_fp8.h:938
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned int val)
Definition amd_hip_fp8.h:990
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz()=default
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const double f)
Definition amd_hip_fp8.h:1010
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __half f)
Definition amd_hip_fp8.h:1038
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const short int val)
Definition amd_hip_fp8.h:970
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const float f)
Definition amd_hip_fp8.h:1019
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __hip_bfloat16 f)
Definition amd_hip_fp8.h:1028
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const long int val)
Definition amd_hip_fp8.h:950
static constexpr __hip_saturation_t __default_saturation
raw storage of fp8 number
Definition amd_hip_fp8.h:940
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned long int val)
Definition amd_hip_fp8.h:980
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned short int val)
Definition amd_hip_fp8.h:1000
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const int val)
Definition amd_hip_fp8.h:960
struct representing two fp8 numbers with e4m3 interpretation
Definition amd_hip_fp8.h:1303
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const float2 val)
Definition amd_hip_fp8.h:1321
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const double2 val)
Definition amd_hip_fp8.h:1312
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __half2 val)
Definition amd_hip_fp8.h:1339
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz()=default
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __hip_bfloat162 val)
Definition amd_hip_fp8.h:1330
struct representing four fp8 numbers with e4m3 interpretation
Definition amd_hip_fp8.h:1383
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
Definition amd_hip_fp8.h:1432
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz()=default
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const double4 val)
Definition amd_hip_fp8.h:1392
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __half2 low, const __half2 high)
Definition amd_hip_fp8.h:1446
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const float4 val)
Definition amd_hip_fp8.h:1412
struct representing one fp8 number with e5m2 interpretation
Definition amd_hip_fp8.h:1495
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned short int val)
Definition amd_hip_fp8.h:1558
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz()=default
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __hip_bfloat16 f)
Definition amd_hip_fp8.h:1586
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const long int val)
Definition amd_hip_fp8.h:1508
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned int val)
Definition amd_hip_fp8.h:1548
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __half f)
Definition amd_hip_fp8.h:1596
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const int val)
Definition amd_hip_fp8.h:1518
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const double f)
Definition amd_hip_fp8.h:1568
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const short int val)
Definition amd_hip_fp8.h:1528
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const float f)
Definition amd_hip_fp8.h:1577
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned long int val)
Definition amd_hip_fp8.h:1538
struct representing two fp8 numbers with e5m2 interpretation
Definition amd_hip_fp8.h:1861
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const float2 val)
Definition amd_hip_fp8.h:1879
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __half2 val)
Definition amd_hip_fp8.h:1897
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __hip_bfloat162 val)
Definition amd_hip_fp8.h:1888
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const double2 val)
Definition amd_hip_fp8.h:1870
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz()=default
struct representing four fp8 numbers with e5m2 interpretation
Definition amd_hip_fp8.h:1941
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
Definition amd_hip_fp8.h:1990
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const float4 val)
Definition amd_hip_fp8.h:1970
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __half2 low, const __half2 high)
Definition amd_hip_fp8.h:2004
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const double4 val)
Definition amd_hip_fp8.h:1950
struct representing ocp fp8 numbers with e4m3 interpretation
Definition amd_hip_fp8.h:2058
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const __hip_bfloat16 f)
Definition amd_hip_fp8.h:2143
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const long int val)
Definition amd_hip_fp8.h:2070
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const unsigned short int val)
Definition amd_hip_fp8.h:2115
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const float f)
Definition amd_hip_fp8.h:2134
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const short int val)
Definition amd_hip_fp8.h:2089
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const __half f)
Definition amd_hip_fp8.h:2153
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const unsigned long int val)
Definition amd_hip_fp8.h:2095
__FP8_HOST_DEVICE__ __hip_fp8_e4m3()=default
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const int val)
Definition amd_hip_fp8.h:2080
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const unsigned int val)
Definition amd_hip_fp8.h:2105
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const double f)
Definition amd_hip_fp8.h:2125
struct representing two ocp fp8 numbers with e4m3 interpretation
Definition amd_hip_fp8.h:2419
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const float2 val)
Definition amd_hip_fp8.h:2438
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3()=default
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const __half2 val)
Definition amd_hip_fp8.h:2456
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const __hip_bfloat162 val)
Definition amd_hip_fp8.h:2447
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const double2 val)
Definition amd_hip_fp8.h:2429
struct representing four ocp fp8 numbers with e4m3 interpretation
Definition amd_hip_fp8.h:2500
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const double4 val)
Definition amd_hip_fp8.h:2510
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const float4 val)
Definition amd_hip_fp8.h:2530
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3()=default
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const __half2 low, const __half2 high)
Definition amd_hip_fp8.h:2564
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const __hip_bfloat162 low, const __hip_bfloat162 high)
Definition amd_hip_fp8.h:2550
struct representing ocp fp8 numbers with e5m2 interpretation
Definition amd_hip_fp8.h:2613
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const int val)
Definition amd_hip_fp8.h:2637
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const short int val)
Definition amd_hip_fp8.h:2647
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const unsigned int val)
Definition amd_hip_fp8.h:2667
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const float f)
Definition amd_hip_fp8.h:2696
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const unsigned short int val)
Definition amd_hip_fp8.h:2677
__FP8_HOST_DEVICE__ __hip_fp8_e5m2()=default
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const unsigned long int val)
Definition amd_hip_fp8.h:2657
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const long int val)
Definition amd_hip_fp8.h:2627
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const double f)
Definition amd_hip_fp8.h:2687
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const __hip_bfloat16 f)
Definition amd_hip_fp8.h:2705
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const __half f)
Definition amd_hip_fp8.h:2715
struct representing two ocp fp8 numbers with e5m2 interpretation
Definition amd_hip_fp8.h:2982
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const __half2 val)
Definition amd_hip_fp8.h:3019
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const double2 val)
Definition amd_hip_fp8.h:2992
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2()=default
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const float2 val)
Definition amd_hip_fp8.h:3001
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const __hip_bfloat162 val)
Definition amd_hip_fp8.h:3010
struct representing four ocp fp8 numbers with e5m2 interpretation
Definition amd_hip_fp8.h:3064
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const __hip_bfloat162 low, const __hip_bfloat162 high)
Definition amd_hip_fp8.h:3113
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const double4 val)
Definition amd_hip_fp8.h:3073
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const __half2 low, const __half2 high)
Definition amd_hip_fp8.h:3127
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const float4 val)
Definition amd_hip_fp8.h:3093
Definition amd_hip_vector_types.h:2035
Definition amd_hip_vector_types.h:2042
Definition amd_hip_vector_types.h:2072
Definition amd_hip_vector_types.h:2079
Definition hip_fp16_gcc.h:7
Definition hip_fp16_gcc.h:11