diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index af330c78e300f..8f99889743748 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3249,13 +3249,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_DIAG_MASK_INF:
             return true;
         case GGML_OP_SOFT_MAX:
-            // TODO: support batching
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
+            return true;
         case GGML_OP_SOFT_MAX_BACK: {
             float max_bias = 0.0f;
             memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu
index aac6e0999880a..e0eb921d85d53 100644
--- a/ggml/src/ggml-cuda/softmax.cu
+++ b/ggml/src/ggml-cuda/softmax.cu
@@ -13,6 +13,29 @@ __device__ float __forceinline__ t2f32(half val) {
     return __half2float(val);
 }
 
+struct soft_max_params {
+
+    int64_t nheads;
+    uint32_t n_head_log2;
+    int64_t ncols;
+    int64_t nrows_x;
+    int64_t nrows_y;
+    int64_t ne00;
+    int64_t ne01;
+    int64_t ne02;
+    int64_t ne03;
+    int64_t nb11;
+    int64_t nb12;
+    int64_t nb13;
+
+    int64_t ne12;
+    int64_t ne13;
+    float scale;
+    float max_bias;
+    float m0;
+    float m1;
+};
+
 // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
 // As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
 #ifdef __clang__
@@ -21,16 +44,24 @@ __device__ float __forceinline__ t2f32(half val) {
 #endif // __clang__
 template <bool use_shared, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
-        const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
-        const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
-    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+        const float * x, const T * mask, float * dst, const soft_max_params p) {
+    const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
 
     const int tid = threadIdx.x;
-    const int rowx = blockIdx.x;
-    const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
+
+    const int64_t i03 = blockIdx.z;
+    const int64_t i02 = blockIdx.y;
+    const int64_t i01 = blockIdx.x;
+
+    //TODO: noncontigous inputs/outputs
+    const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
+
+    const int64_t i11 = i01;
+    const int64_t i12 = i02 % p.ne12;
+    const int64_t i13 = i03 % p.ne13;
 
     x    += int64_t(rowx)*ncols;
-    mask += int64_t(rowy)*ncols * (mask != nullptr);
+    mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
     dst  += int64_t(rowx)*ncols;
 
     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
@@ -38,7 +69,7 @@ static __global__ void soft_max_f32(
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
 
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
@@ -55,7 +86,7 @@ static __global__ void soft_max_f32(
             break;
         }
 
-        const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
+        const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
 
         vals[col] = val;
         max_val = max(max_val, val);
@@ -151,63 +182,60 @@ static __global__ void soft_max_back_f32(
 }
 
 template <typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
+    const int64_t ncols_x = params.ncols;
+
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
-    const dim3 block_nums(nrows_x, 1, 1);
+    const dim3 block_nums(params.ne01, params.ne02, params.ne03);
     const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
-    const uint32_t n_head      = nrows_x/nrows_y;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     // FIXME: this limit could be raised by ~2-4x on Ampere or newer
     if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
         switch (ncols_x) {
             case 32:
                 soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 64:
                 soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 128:
                 soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 256:
                 soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 512:
                 soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
            case 1024:
                 soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
            case 2048:
                 soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
            case 4096:
                 soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
            default:
                 soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
        }
    } else {
        const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
    }
 }
 
@@ -235,10 +263,11 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
-    const int64_t ne00    = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
     const int64_t nrows_y = src0->ne[1];
 
+    const int64_t ne00 = src0->ne[0];
+
     float scale    = 1.0f;
     float max_bias = 0.0f;
 
@@ -247,10 +276,44 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;
+
+    const uint32_t n_head      = src0->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+
+    soft_max_params params = {};
+    params.nheads = src0->ne[2];
+    params.n_head_log2 = n_head_log2;
+    params.ncols = ne00;
+    params.nrows_x = nrows_x;
+    params.nrows_y = nrows_y;
+    params.ne00 = src0->ne[0];
+    params.ne01 = src0->ne[1];
+    params.ne02 = src0->ne[2];
+    params.ne03 = src0->ne[3];
+    params.nb11 = nb11;
+    params.nb12 = nb12;
+    params.nb13 = nb13;
+    params.ne12 = ne12;
+    params.ne13 = ne13;
+    params.scale = scale;
+    params.max_bias = max_bias;
+    params.m0 = m0;
+    params.m1 = m1;
+
     if (use_f16) {
-        soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, stream);
     } else {
-        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
     }
 }
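
Note (not part of the patch): as a sketch of what the new launch and indexing do, each block of the (ne01, ne02, ne03) grid handles one row of src0; the mask row for that block is found via i11 = i01, i12 = i02 % ne12, i13 = i03 % ne13 together with the byte strides nb11/nb12/nb13, so the mask broadcasts over dims 2 and 3, and the per-head ALiBi slope uses the same n_head_log2/m0/m1 values the host code computes. The helper names and shapes below are made up for illustration only; the slope formula mirrors get_alibi_slope().

```cpp
// Standalone host-side illustration of the broadcast indexing and ALiBi slope
// used above; all names and example values here are hypothetical.
#include <cmath>
#include <cstdint>
#include <cstdio>

// byte offset of the mask row used by block (i01, i02, i03);
// dims 2 and 3 of the mask broadcast when ne12/ne13 are smaller than ne02/ne03
static int64_t mask_row_offset(int64_t i01, int64_t i02, int64_t i03,
                               int64_t ne12, int64_t ne13,
                               int64_t nb11, int64_t nb12, int64_t nb13) {
    const int64_t i11 = i01;        // rows are never broadcast
    const int64_t i12 = i02 % ne12; // broadcast across heads
    const int64_t i13 = i03 % ne13; // broadcast across batch
    return i11*nb11 + i12*nb12 + i13*nb13;
}

// same formula as get_alibi_slope(max_bias, h, n_head_log2, m0, m1)
static float alibi_slope(float max_bias, uint32_t h, uint32_t n_head_log2, float m0, float m1) {
    if (max_bias <= 0.0f) {
        return 1.0f;
    }
    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
    return powf(base, exph);
}

int main() {
    // example: 8 heads, 2 sequences, mask shared across heads and batch (ne12 = ne13 = 1)
    const int64_t ne02 = 8, ne03 = 2;
    const int64_t ne12 = 1, ne13 = 1;
    // arbitrary example byte strides of the mask tensor
    const int64_t nb11 = 512, nb12 = 4*nb11, nb13 = ne12*nb12;

    const float    max_bias    = 8.0f;
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) ne02));
    const float    m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float    m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    for (int64_t i03 = 0; i03 < ne03; ++i03) {
        for (int64_t i02 = 0; i02 < ne02; ++i02) {
            const int64_t off   = mask_row_offset(/*i01=*/0, i02, i03, ne12, ne13, nb11, nb12, nb13);
            const float   slope = alibi_slope(max_bias, (uint32_t) i02, n_head_log2, m0, m1);
            printf("batch %2lld head %2lld -> mask byte offset %4lld, alibi slope %.4f\n",
                   (long long) i03, (long long) i02, (long long) off, slope);
        }
    }
    return 0;
}
```

With max_bias = 0 the slope collapses to 1.0f, so the mask is added unscaled, matching the non-ALiBi path of the kernel.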