From 6bfa48347a073193a4aee42fe8cd7d0ec4c1d1fa Mon Sep 17 00:00:00 2001 From: Trong Tan <67775223+jushg@users.noreply.github.com> Date: Tue, 8 Jul 2025 11:05:39 +0100 Subject: [PATCH 1/2] Update reduce_kernel.h Change the casting template for FP8 types to be compatible with c++20 --- src/device/reduce_kernel.h | 48 +++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/src/device/reduce_kernel.h b/src/device/reduce_kernel.h index d36dfe5a7..2941faa43 100644 --- a/src/device/reduce_kernel.h +++ b/src/device/reduce_kernel.h @@ -256,28 +256,38 @@ struct Apply_Cast { }; #endif -#define EASY_CAST(A, B, EltPerPack, VecA, VecB) \ - template<> \ - struct Apply_Cast { \ - __device__ __forceinline__ static BytePack cast(BytePack a) { \ - return toPack(VecB(fromPack(a))); \ - } \ - }; \ - template<> \ - struct Apply_Cast { \ - __device__ __forceinline__ static BytePack cast(BytePack b) { \ - return toPack(VecA(fromPack(b))); \ - } \ - }; - #if defined(__CUDA_FP8_TYPES_EXIST__) -EASY_CAST(__nv_fp8_e5m2, float, 2, __nv_fp8x2_e5m2, float2) -EASY_CAST(__nv_fp8_e5m2, float, 4, __nv_fp8x4_e5m2, float4) +#define FP8_CAST(FP8_T, VEC2_T, VEC4_T) \ +template<> struct Apply_Cast { \ + __device__ __forceinline__ static BytePack<8> cast(BytePack<2> a) { \ + VEC2_T va = fromPack(a); FP8_T* p = reinterpret_cast(&va); \ + return toPack(make_float2(float(__half(p[0])), float(__half(p[1])))); \ + } \ +}; \ +template<> struct Apply_Cast { \ + __device__ __forceinline__ static BytePack<2> cast(BytePack<8> b) { \ + float2 vb = fromPack(b); VEC2_T va; FP8_T* p = reinterpret_cast(&va); \ + p[0] = FP8_T(__half(vb.x)); p[1] = FP8_T(__half(vb.y)); return toPack(va); \ + } \ +}; \ +template<> struct Apply_Cast { \ + __device__ __forceinline__ static BytePack<16> cast(BytePack<4> a) { \ + VEC4_T va = fromPack(a); FP8_T* p = reinterpret_cast(&va); \ + return toPack(make_float4(float(__half(p[0])), float(__half(p[1])), float(__half(p[2])), float(__half(p[3])))); \ + } \ +}; \ +template<> struct Apply_Cast { \ + __device__ __forceinline__ static BytePack<4> cast(BytePack<16> b) { \ + float4 vb = fromPack(b); VEC4_T va; FP8_T* p = reinterpret_cast(&va); \ + p[0] = FP8_T(__half(vb.x)); p[1] = FP8_T(__half(vb.y)); p[2] = FP8_T(__half(vb.z)); p[3] = FP8_T(__half(vb.w)); \ + return toPack(va); \ + } \ +}; -EASY_CAST(__nv_fp8_e4m3, float, 2, __nv_fp8x2_e4m3, float2) -EASY_CAST(__nv_fp8_e4m3, float, 4, __nv_fp8x4_e4m3, float4) +FP8_CAST(__nv_fp8_e5m2, __nv_fp8x2_e5m2, __nv_fp8x4_e5m2) +FP8_CAST(__nv_fp8_e4m3, __nv_fp8x2_e4m3, __nv_fp8x4_e4m3) +#undef FP8_CAST #endif -#undef EASY_CAST //////////////////////////////////////////////////////////////////////////////// // Apply_Reduce From 51f51741d59c89fa0ac3338faaeec95f68755703 Mon Sep 17 00:00:00 2001 From: Trong Tan <67775223+jushg@users.noreply.github.com> Date: Fri, 11 Jul 2025 11:40:59 +0100 Subject: [PATCH 2/2] Update reduce_kernel.h (simpler) Turn out need much less change --- src/device/reduce_kernel.h | 48 +++++++++++++++----------------------- 1 file changed, 19 insertions(+), 29 deletions(-) diff --git a/src/device/reduce_kernel.h b/src/device/reduce_kernel.h index 2941faa43..af927a924 100644 --- a/src/device/reduce_kernel.h +++ b/src/device/reduce_kernel.h @@ -256,38 +256,28 @@ struct Apply_Cast { }; #endif +#define EASY_CAST(A, B, EltPerPack, VecA, VecB) \ + template<> \ + struct Apply_Cast { \ + __device__ __forceinline__ static BytePack cast(BytePack a) { \ + return toPack(static_cast(fromPack(a))); \ + } \ + }; \ + template<> \ + struct Apply_Cast { \ + __device__ __forceinline__ static BytePack cast(BytePack b) { \ + return toPack(VecA(fromPack(b))); \ + } \ + }; + #if defined(__CUDA_FP8_TYPES_EXIST__) -#define FP8_CAST(FP8_T, VEC2_T, VEC4_T) \ -template<> struct Apply_Cast { \ - __device__ __forceinline__ static BytePack<8> cast(BytePack<2> a) { \ - VEC2_T va = fromPack(a); FP8_T* p = reinterpret_cast(&va); \ - return toPack(make_float2(float(__half(p[0])), float(__half(p[1])))); \ - } \ -}; \ -template<> struct Apply_Cast { \ - __device__ __forceinline__ static BytePack<2> cast(BytePack<8> b) { \ - float2 vb = fromPack(b); VEC2_T va; FP8_T* p = reinterpret_cast(&va); \ - p[0] = FP8_T(__half(vb.x)); p[1] = FP8_T(__half(vb.y)); return toPack(va); \ - } \ -}; \ -template<> struct Apply_Cast { \ - __device__ __forceinline__ static BytePack<16> cast(BytePack<4> a) { \ - VEC4_T va = fromPack(a); FP8_T* p = reinterpret_cast(&va); \ - return toPack(make_float4(float(__half(p[0])), float(__half(p[1])), float(__half(p[2])), float(__half(p[3])))); \ - } \ -}; \ -template<> struct Apply_Cast { \ - __device__ __forceinline__ static BytePack<4> cast(BytePack<16> b) { \ - float4 vb = fromPack(b); VEC4_T va; FP8_T* p = reinterpret_cast(&va); \ - p[0] = FP8_T(__half(vb.x)); p[1] = FP8_T(__half(vb.y)); p[2] = FP8_T(__half(vb.z)); p[3] = FP8_T(__half(vb.w)); \ - return toPack(va); \ - } \ -}; +EASY_CAST(__nv_fp8_e5m2, float, 2, __nv_fp8x2_e5m2, float2) +EASY_CAST(__nv_fp8_e5m2, float, 4, __nv_fp8x4_e5m2, float4) -FP8_CAST(__nv_fp8_e5m2, __nv_fp8x2_e5m2, __nv_fp8x4_e5m2) -FP8_CAST(__nv_fp8_e4m3, __nv_fp8x2_e4m3, __nv_fp8x4_e4m3) -#undef FP8_CAST +EASY_CAST(__nv_fp8_e4m3, float, 2, __nv_fp8x2_e4m3, float2) +EASY_CAST(__nv_fp8_e4m3, float, 4, __nv_fp8x4_e4m3, float4) #endif +#undef EASY_CAST //////////////////////////////////////////////////////////////////////////////// // Apply_Reduce