
Commit a37d885

cuda : use amd wave sharing intrinsics for warp_reduce functions

1 parent: 54ea069

File tree: 1 file changed (+64, -0)


ggml-cuda/common.cuh

Lines changed: 64 additions & 0 deletions
@@ -248,27 +248,87 @@ static __device__ void no_device_code(
     GGML_UNUSED(no_device_code); // suppress unused function warning
 }
 
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#define AMD_SWIZZLE_MASK(and_mask, or_mask, xor_mask) ((and_mask) | ((or_mask)<<5) | ((xor_mask)<<10)) // 5-bit masks applied sequentially to the thread id
+#define AMD_DPP_ROW_RR(x) (0x120+(x)) // 121-12F - row rotate right by 1-15 threads - a row is 16 threads
+#define hip_move_dppf(src, dpp_ctrl, row_mask, bank_mask, bound_ctrl) \
+    hip_move_dppf_N<(dpp_ctrl), (row_mask), (bank_mask), (bound_ctrl)>((src))
+
+template <int dpp_ctrl, int row_mask, int bank_mask, bool bound_ctrl>
+static __device__ __forceinline__ float hip_move_dppf_N(float x) {
+    typedef union float_b32 {
+        float val;
+        int   b32;
+    } float_b32_t;
+    float_b32_t tmp;
+    tmp.val = x;
+    tmp.b32 = __builtin_amdgcn_mov_dpp(tmp.b32, dpp_ctrl, row_mask, bank_mask, bound_ctrl);
+    return tmp.val;
+}
+
+static __device__ __forceinline__ float warp_reduce_sum_impl_amd(float x) {
+    x += __hip_ds_swizzlef(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10)); // swap neighbouring groups of 16 lanes
+    x += hip_move_dppf(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
+    x += hip_move_dppf(x, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
+    x += hip_move_dppf(x, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
+    x += hip_move_dppf(x, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum_impl_amd(float2 a) {
+    a.x += __hip_ds_swizzlef(a.x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10));
+    a.y += __hip_ds_swizzlef(a.y, AMD_SWIZZLE_MASK(0x1F, 0, 0x10));
+    a.x += hip_move_dppf(a.x, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
+    a.y += hip_move_dppf(a.y, AMD_DPP_ROW_RR(8), 0xF, 0xF, true);
+    a.x += hip_move_dppf(a.x, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
+    a.y += hip_move_dppf(a.y, AMD_DPP_ROW_RR(4), 0xF, 0xF, true);
+    a.x += hip_move_dppf(a.x, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
+    a.y += hip_move_dppf(a.y, AMD_DPP_ROW_RR(2), 0xF, 0xF, true);
+    a.x += hip_move_dppf(a.x, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
+    a.y += hip_move_dppf(a.y, AMD_DPP_ROW_RR(1), 0xF, 0xF, true);
+    return a;
+}
+
+static __device__ __forceinline__ float warp_reduce_max_impl_amd(float x) {
+    x = fmaxf(x, __hip_ds_swizzlef(x, AMD_SWIZZLE_MASK(0x1F, 0, 0x10)));
+    x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(8), 0xF, 0xF, false));
+    x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(4), 0xF, 0xF, false));
+    x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(2), 0xF, 0xF, false));
+    x = fmaxf(x, hip_move_dppf(x, AMD_DPP_ROW_RR(1), 0xF, 0xF, false));
+    return x;
+}
+
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
 #ifdef __CUDA_ARCH__
 #define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
 #else
 #define NO_DEVICE_CODE //GGML_ASSERT(false && "NO_DEVICE_CODE not valid in host code.")
 #endif // __CUDA_ARCH__
 
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    return warp_reduce_sum_impl_amd(x);
+#else
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         x += __shfl_xor_sync(0xffffffff, x, mask, 32);
     }
     return x;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 }
 
 static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    return warp_reduce_sum_impl_amd(a);
+#else
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
         a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
     }
     return a;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 }
 
 #ifdef GGML_CUDA_F16
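
As context for the masks above: in the bit-mask mode of ds_swizzle documented in AMD's GCN ISA, each lane's source lane within a 32-wide group is computed by applying the three 5-bit fields in sequence, AND then OR then XOR, which is exactly the layout AMD_SWIZZLE_MASK packs. A small host-side model of that selection (illustrative names, not part of the commit):

// Source-lane selection for ds_swizzle's bit-mask mode (host-side sketch).
static int swizzle_src_lane(int lane, int and_mask, int or_mask, int xor_mask) {
    return ((lane & and_mask) | or_mask) ^ xor_mask; // 5-bit fields, applied in order
}

// AMD_SWIZZLE_MASK(0x1F, 0, 0x10) therefore selects lane ^ 0x10: lanes 0-15
// exchange values with lanes 16-31, i.e. neighbouring groups of 16 lanes swap.
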
@@ -287,11 +347,15 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // GGML_CUDA_F16
 
 static __device__ __forceinline__ float warp_reduce_max(float x) {
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    return warp_reduce_max_impl_amd(x);
+#else
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
     }
     return x;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 }
 
 //static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
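
To see why the initial xor-0x10 swap plus four row rotates leave every lane of a 32-lane group holding the full sum, here is a host-side simulation of the lane arithmetic; it models DPP row_ror:N as a rotate right by N within each 16-lane row. A sketch for illustration only, not code from the commit:

#include <cstdio>

int main() {
    const int ROW = 16; // a DPP row is 16 threads
    float x[32];
    float y[32];
    for (int i = 0; i < 32; ++i) {
        x[i] = (float) i; // distinct per-lane inputs, total 0+1+...+31 = 496
    }

    // __hip_ds_swizzlef with AMD_SWIZZLE_MASK(0x1F, 0, 0x10): lane i reads lane i^0x10
    for (int i = 0; i < 32; ++i) { y[i] = x[i ^ 0x10]; }
    for (int i = 0; i < 32; ++i) { x[i] += y[i]; }

    // hip_move_dppf with AMD_DPP_ROW_RR(r): lane i reads lane (i-r) mod 16 of its row
    for (int r = 8; r > 0; r >>= 1) {
        for (int i = 0; i < 32; ++i) {
            y[i] = x[(i/ROW)*ROW + ((i%ROW) - r + ROW) % ROW];
        }
        for (int i = 0; i < 32; ++i) { x[i] += y[i]; }
    }

    for (int i = 0; i < 32; ++i) {
        printf("lane %2d: %g\n", i, x[i]); // prints 496 for every lane
    }
    return 0;
}
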

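The float_b32 union in hip_move_dppf_N exists because __builtin_amdgcn_mov_dpp moves 32-bit integer data, so the float has to travel as its raw bits. A memcpy-based sketch of the same round trip, where move_bits is a hypothetical stand-in for the builtin:

#include <cstring>

// Bit-for-bit round trip of a float through an int, as hip_move_dppf_N does
// with its union; memcpy is the standard-sanctioned way to type-pun on the
// host. move_bits stands in for __builtin_amdgcn_mov_dpp (sketch only).
static float move_float_bits(float x, int (*move_bits)(int)) {
    int b;
    std::memcpy(&b, &x, sizeof b);     // float -> raw 32 bits
    b = move_bits(b);                  // the cross-lane move operates on ints
    float out;
    std::memcpy(&out, &b, sizeof out); // raw bits -> float
    return out;
}
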
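Call sites do not change with this commit; for reference, a minimal usage sketch of warp_reduce_sum in a ggml-style kernel (the kernel name and launch shape here are hypothetical, not from the repository):

// One 32-thread warp per block sums one row of x; after warp_reduce_sum every
// lane holds the row total (via the AMD path or the __shfl_xor_sync fallback),
// so lane 0 writes the result.
static __global__ void row_sum_sketch(const float * x, float * dst, const int ncols) {
    float sum = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += 32) {
        sum += x[blockIdx.x*ncols + col];
    }
    sum = warp_reduce_sum(sum);
    if (threadIdx.x == 0) {
        dst[blockIdx.x] = sum;
    }
}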