From 7a3f7e94ba9d17ba801034ad83e5829ee1ac32ab Mon Sep 17 00:00:00 2001
From: Engininja2 <139037756+Engininja2@users.noreply.github.com>
Date: Thu, 4 Apr 2024 15:09:03 -0600
Subject: [PATCH 1/3] cuda : use amd wave sharing intrinsics for warp_reduce functions

---
 ggml-cuda/common.cuh | 63 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 63 insertions(+)

diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index 44e67e040e16a..dee2cac67e82b 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -315,6 +315,57 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 #endif
     return c;
 }
+
+#ifdef __HIP_PLATFORM_AMD__
+#define AMD_SWIZZLE_MASK(and_mask, or_mask, xor_mask) ((and_mask) | ((or_mask)<<5) | ((xor_mask)<<10)) // 5-bit masks applied sequentially to the thread id
+#define AMD_DPP_ROW_RR(x) (0x120+(x)) // 121-12F - row rotate right by 1-15 threads - a row is 16 threads
+#define hip_move_dppf(src, dpp_ctrl, row_mask, bank_mask, bound_ctrl) \
+    hip_move_dppf_N<(dpp_ctrl), (row_mask), (bank_mask), (bound_ctrl)>((src))
+
+template <int dpp_ctrl, int row_mask, int bank_mask, bool bound_ctrl>
+static __device__ __forceinline__ float hip_move_dppf_N(float x) {
+    typedef union float_b32 {
+        float val;
+        int b32;
+    } float_b32_t;
+    float_b32_t tmp;
+    tmp.val = x;
+    tmp.b32 = __builtin_amdgcn_mov_dpp(tmp.b32, dpp_ctrl, row_mask, bank_mask, bound_ctrl);
+    return tmp.val;
+}
+
+static __device__ __forceinline__ float warp_reduce_sum_impl_amd(float x) {
+    x += __hip_ds_swizzlef(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10)); // swap neighbouring groups of 16 lanes
+    x += hip_move_dppf(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
+    x += hip_move_dppf(x, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
+    x += hip_move_dppf(x, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
+    x += hip_move_dppf(x, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum_impl_amd(float2 a) {
+    a.x += __hip_ds_swizzlef(a.x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10));
+    a.y += __hip_ds_swizzlef(a.y, AMD_SWIZZLE_MASK(0x1F, 0, 0x10));
+    a.x += hip_move_dppf(a.x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
+    a.y += hip_move_dppf(a.y, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
+    a.x += hip_move_dppf(a.x, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
+    a.y += hip_move_dppf(a.y, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
+    a.x += hip_move_dppf(a.x, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
+    a.y += hip_move_dppf(a.y, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
+    a.x += hip_move_dppf(a.x, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
+    a.y += hip_move_dppf(a.y, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
+    return a;
+}
+
+static __device__ __forceinline__ float warp_reduce_max_impl_amd(float x) {
+    x = fmaxf(x, __hip_ds_swizzlef(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10)));
+    x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, false));
+    x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(4), 0xF, 0xF, false));
+    x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(2), 0xF, 0xF, false));
+    x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(1), 0xF, 0xF, false));
+    return x;
+}
+#endif // __HIP_PLATFORM_AMD__
 #endif // defined(GGML_USE_HIPBLAS)
 
 #define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
@@ -349,20 +400,28 @@ static __device__ void no_device_code(
 #endif // __CUDA_ARCH__
 
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    return warp_reduce_sum_impl_amd(x);
+#else
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         x += __shfl_xor_sync(0xffffffff, x, mask, 32);
     }
     return x;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 }
 
 static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    return warp_reduce_sum_impl_amd(a);
+#else
#pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
         a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
     }
     return a;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 }
 
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
@@ -391,11 +450,15 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 }
 
 static __device__ __forceinline__ float warp_reduce_max(float x) {
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    return warp_reduce_max_impl_amd(x);
+#else
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
     }
     return x;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 }
 
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {

From 9e6f2e2affe26f76db0d7ddb8add8715f952ef4b Mon Sep 17 00:00:00 2001
From: Engininja2 <139037756+Engininja2@users.noreply.github.com>
Date: Sat, 11 May 2024 13:38:34 -0600
Subject: [PATCH 2/3] cuda : add amd dpp version of warp_reduce_sum for half2

---
 ggml-cuda/common.cuh | 44 +++++++++++++++++++++++++++++++++++++-------
 1 file changed, 37 insertions(+), 7 deletions(-)

diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index dee2cac67e82b..1114e6af2ad82 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -321,6 +321,9 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 #define AMD_DPP_ROW_RR(x) (0x120+(x)) // 121-12F - row rotate right by 1-15 threads - a row is 16 threads
 #define hip_move_dppf(src, dpp_ctrl, row_mask, bank_mask, bound_ctrl) \
     hip_move_dppf_N<(dpp_ctrl), (row_mask), (bank_mask), (bound_ctrl)>((src))
+#define hip_move_dpph2(src, dpp_ctrl, row_mask, bank_mask, bound_ctrl) \
+    hip_move_dpph2_N<(dpp_ctrl), (row_mask), (bank_mask), (bound_ctrl)>((src))
+#define hip_ds_swizzleh2(src, pattern) hip_ds_swizzleh2_N<(pattern)>((src))
 
 template <int dpp_ctrl, int row_mask, int bank_mask, bool bound_ctrl>
 static __device__ __forceinline__ float hip_move_dppf_N(float x) {
@@ -334,6 +337,30 @@ static __device__ __forceinline__ float hip_move_dppf_N(float x) {
     return tmp.val;
 }
 
+template <int dpp_ctrl, int row_mask, int bank_mask, bool bound_ctrl>
+static __device__ __forceinline__ half2 hip_move_dpph2_N(half2 x) {
+    typedef union half2_b32 {
+        half2 val;
+        int b32;
+    } half2_b32_t;
+    half2_b32_t tmp;
+    tmp.val = x;
+    tmp.b32 = __builtin_amdgcn_mov_dpp(tmp.b32, dpp_ctrl, row_mask, bank_mask, bound_ctrl);
+    return tmp.val;
+}
+
+template <int pattern>
+static __device__ __forceinline__ half2 hip_ds_swizzleh2_N(half2 src) {
+    typedef union half2_b32 {
+        half2 val;
+        int b32;
+    } half2_b32_t;
+    half2_b32_t tmp;
+    tmp.val = src;
+    tmp.b32 = __builtin_amdgcn_ds_swizzle(tmp.b32, pattern);
+    return tmp.val;
+}
+
 static __device__ __forceinline__ float warp_reduce_sum_impl_amd(float x) {
     x += __hip_ds_swizzlef(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10)); // swap neighbouring groups of 16 lanes
     x += hip_move_dppf(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
@@ -357,6 +384,15 @@ static __device__ __forceinline__ float2 warp_reduce_sum_impl_amd(float2 a) {
     return a;
 }
 
+static __device__ __forceinline__ half2 warp_reduce_sum_impl_amd(half2 x) {
+    x += hip_ds_swizzleh2(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10));
+    x += hip_move_dpph2(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
+    x += hip_move_dpph2(x, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
+    x += hip_move_dpph2(x, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
+    x += hip_move_dpph2(x, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
+    return x;
+}
+
 static __device__ __forceinline__ float warp_reduce_max_impl_amd(float x) {
     x = fmaxf(x, __hip_ds_swizzlef(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10)));
     x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, false));
@@ -428,13 +464,7 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #if FP16_AVAILABLE
 
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
-        reinterpret_cast<half&>(a.x) += __low2half(a_other);
-        reinterpret_cast<half&>(a.y) += __high2half(a_other);
-    }
-    return a;
+    return warp_reduce_sum_impl_amd(a);
 #else
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {

From 6c8fdb8e5aac771ddd7e85af7b2c61a02ae46773 Mon Sep 17 00:00:00 2001
From: Engininja2 <139037756+Engininja2@users.noreply.github.com>
Date: Sat, 11 May 2024 21:07:59 -0600
Subject: [PATCH 3/3] adding the components of half2 seems to be compiled faster

---
 ggml-cuda/common.cuh | 25 ++++++++++++++++++-------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index 1114e6af2ad82..a2f822e7788d8 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -384,13 +384,24 @@ static __device__ __forceinline__ float2 warp_reduce_sum_impl_amd(float2 a) {
     return a;
 }
 
-static __device__ __forceinline__ half2 warp_reduce_sum_impl_amd(half2 x) {
-    x += hip_ds_swizzleh2(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10));
-    x += hip_move_dpph2(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
-    x += hip_move_dpph2(x, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
-    x += hip_move_dpph2(x, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
-    x += hip_move_dpph2(x, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
-    return x;
+static __device__ __forceinline__ half2 warp_reduce_sum_impl_amd(half2 a) {
+    half2 tmp;
+    tmp = hip_ds_swizzleh2(a, AMD_SWIZZLE_MASK(0x1F, 0, 0x10));
+    a.data.x += tmp.data.x;
+    a.data.y += tmp.data.y;
+    tmp = hip_move_dpph2(a, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
+    a.data.x += tmp.data.x;
+    a.data.y += tmp.data.y;
+    tmp = hip_move_dpph2(a, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
+    a.data.x += tmp.data.x;
+    a.data.y += tmp.data.y;
+    tmp = hip_move_dpph2(a, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
+    a.data.x += tmp.data.x;
+    a.data.y += tmp.data.y;
+    tmp = hip_move_dpph2(a, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
+    a.data.x += tmp.data.x;
+    a.data.y += tmp.data.y;
+    return a;
 }
 
 static __device__ __forceinline__ float warp_reduce_max_impl_amd(float x) {
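
For reference, the following standalone host-side C++ sketch (separate from the patches themselves; the file and helper names are invented for illustration) models the lane exchanges performed by warp_reduce_sum_impl_amd on a single 32-lane wave. It assumes that ds_swizzle with AMD_SWIZZLE_MASK(0x1F, 0, 0x10) makes lane i read lane i ^ 16, and that a DPP row rotate by N makes each lane read a lane rotated by N within its 16-lane row. The check confirms the property the float, float2 and half2 reductions above rely on: the xor-16 swizzle followed by row rotations of 8, 4, 2 and 1 leaves the complete warp sum in every lane.

// warp_reduce_dpp_sketch.cpp -- host-side illustration only, not part of ggml.
// Simulates the lane traffic of warp_reduce_sum_impl_amd for one 32-lane wave
// and verifies that every lane ends up holding the full warp sum.
#include <array>
#include <cassert>
#include <cstdio>

// ds_swizzle with and=0x1F, or=0, xor=0x10: lane i reads lane i ^ 16,
// i.e. the two neighbouring groups of 16 lanes swap values.
static std::array<float, 32> swizzle_xor16(const std::array<float, 32> & v) {
    std::array<float, 32> out{};
    for (int i = 0; i < 32; ++i) {
        out[i] = v[i ^ 0x10];
    }
    return out;
}

// DPP row rotate by n: each lane reads a lane rotated by n within its
// 16-lane row (modelled as (i - n) mod 16; the rotation direction does
// not change the result of the reduction).
static std::array<float, 32> row_rotate(const std::array<float, 32> & v, int n) {
    std::array<float, 32> out{};
    for (int i = 0; i < 32; ++i) {
        const int row = i & ~0xF;      // base lane of this 16-lane row
        const int src = (i - n) & 0xF; // rotated source lane within the row
        out[i] = v[row + src];
    }
    return out;
}

int main() {
    std::array<float, 32> x{};
    float expected = 0.0f;
    for (int i = 0; i < 32; ++i) {
        x[i] = float(i + 1);
        expected += x[i];
    }

    // Same sequence as warp_reduce_sum_impl_amd: xor-16 swizzle, then row
    // rotations by 8, 4, 2 and 1, accumulating after each exchange.
    std::array<float, 32> t = swizzle_xor16(x);
    for (int i = 0; i < 32; ++i) { x[i] += t[i]; }

    const int shifts[4] = { 8, 4, 2, 1 };
    for (int n : shifts) {
        t = row_rotate(x, n);
        for (int i = 0; i < 32; ++i) { x[i] += t[i]; }
    }

    for (int i = 0; i < 32; ++i) {
        assert(x[i] == expected); // every lane holds the full warp sum
    }
    printf("all 32 lanes hold the warp sum %.1f\n", expected);
    return 0;
}

On a 64-lane wave the same sequence would reduce within each 32-lane half independently, since neither the xor-16 swizzle nor the 16-lane row rotations cross that boundary; this matches the 32-wide warp size ggml's CUDA/HIP kernels assume.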