diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 2b1bd6e0f48b9..f08712db37bb7 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -1433,8 +1433,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);

+    // a    [ne0, ne01, ne02, ne03]
+    // mask [ne0, ne11, ne12, ne13] | ne11 >= ne01, F16 or F32, optional
+    //
+    // broadcast:
+    //   ne02 % ne12 == 0
+    //   ne03 % ne13 == 0
+    //
     // fused soft_max(a*scale + mask*(ALiBi slope))
-    // mask is optional
     // max_bias = 0.0f for no ALiBi
     GGML_API struct ggml_tensor * ggml_soft_max_ext(
             struct ggml_context * ctx,
@@ -1868,11 +1874,16 @@ extern "C" {

 #define GGML_KQ_MASK_PAD 64

-    // q:    [n_embd_k, n_batch,     n_head,    1]
-    // k:    [n_embd_k, n_kv,        n_head_kv, 1]
-    // v:    [n_embd_v, n_kv,        n_head_kv, 1] !! not transposed !!
-    // mask: [n_kv,     n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
-    // res:  [n_embd_v, n_head,      n_batch,   1] !! permuted !!
+    // q:    [n_embd_k, n_batch,     n_head,    ne3]
+    // k:    [n_embd_k, n_kv,        n_head_kv, ne3]
+    // v:    [n_embd_v, n_kv,        n_head_kv, ne3] !! not transposed !!
+    // mask: [n_kv,     n_batch_pad, ne32,      1  ] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd_v, n_head,      n_batch,   ne3] !! permuted !!
+    //
+    // broadcast:
+    //   n_head % n_head_kv == 0
+    //   ne3    % ne32      == 0
+    //
     GGML_API struct ggml_tensor * ggml_flash_attn_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index d1a0ad374d691..8a3d2026ed898 100755
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -2187,7 +2187,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_SQRT:
         case GGML_OP_CLAMP:
         case GGML_OP_DIAG_MASK_INF:
-        case GGML_OP_SOFT_MAX:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
@@ -2205,6 +2204,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_COUNT_EQUAL:
             return true;
+        case GGML_OP_SOFT_MAX:
+            // TODO: support broadcast
+            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
+            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
         case GGML_OP_FLASH_ATTN_EXT:{
             // derived from [ggml-cuda.cu]
             if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
@@ -2227,6 +2230,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                 // DeepSeek MLA
                 return false;
             }
+            // TODO: support broadcast
+            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
             if (op->src[0]->ne[3] != 1) {
                 return false;
             }
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 9f17ea43c8553..67bc29dbd754c 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -4802,14 +4802,17 @@ static void ggml_compute_forward_soft_max_f32(
     memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));

-    // TODO: handle transposed/permuted matrices
-
     const int ith = params->ith;
     const int nth = params->nth;

     GGML_TENSOR_UNARY_OP_LOCALS

-    //const int64_t ne11 = src1 ? src1->ne[1] : 1;
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;

     // TODO: is this supposed to be ceil instead of floor?
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 @@ -4819,68 +4822,66 @@ static void ggml_compute_forward_soft_max_f32( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; + float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - for (int i1 = ir0; i1 < ir1; i1++) { - // ALiBi - const uint32_t h = (i1/ne01)%ne02; // head - const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; - - float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); - float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); - - // broadcast the mask across rows - ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; - float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; - - ggml_vec_cpy_f32 (nc, wp, sp); - ggml_vec_scale_f32(nc, wp, scale); - if (mp_f32) { - if (use_f16) { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]); - } - } else { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*mp_f32[i]; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + const int64_t i11 = i01; + const int64_t i12 = i02%ne12; + const int64_t i13 = i03%ne13; + + // ALiBi + const uint32_t h = i02; // head + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + + float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + // broadcast the mask across rows + ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL; + float * mp_f32 = src1 ? 
(float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL; + + ggml_vec_cpy_f32 (ne00, wp, sp); + ggml_vec_scale_f32(ne00, wp, scale); + if (mp_f32) { + if (use_f16) { + for (int i = 0; i < ne00; ++i) { + wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]); + } + } else { + for (int i = 0; i < ne00; ++i) { + wp[i] += slope*mp_f32[i]; + } + } } - } - } #ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(wp[i])); - } + for (int i = 0; i < ne00; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(wp[i])); + } #endif - float max = -INFINITY; - ggml_vec_max_f32(nc, &max, wp); + float max = -INFINITY; + ggml_vec_max_f32(ne00, &max, wp); - ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max); - assert(sum > 0.0); + ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max); + assert(sum > 0.0); - sum = 1.0/sum; - ggml_vec_scale_f32(nc, dp, sum); + sum = 1.0/sum; + ggml_vec_scale_f32(ne00, dp, sum); #ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - assert(!isnan(dp[i])); - assert(!isinf(dp[i])); - } + for (int i = 0; i < ne00; ++i) { + assert(!isnan(dp[i])); + assert(!isinf(dp[i])); + } #endif + } + } } } @@ -7151,7 +7152,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type; + ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type; ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float; ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot; ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float; @@ -7183,7 +7184,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( memset(VKQ32, 0, DV*sizeof(float)); } - const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; + const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL; // k indices const int ik3 = iq3 / rk3; diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index cfab2b5ebaccc..075f14a49e9ac 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -32,7 +32,9 @@ typedef void (* fattn_kernel_t)( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -851,7 +853,8 @@ void launch_fattn( scale, max_bias, m0, m1, n_head_log2, logit_softcap, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, + mask ? mask->nb[1] : 0, mask ? 
mask->nb[2] : 0, Q->nb[1], Q->nb[2], Q->nb[3], nb11, nb12, nb13, nb21, nb22, nb23, diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index e230f6d494d77..709589854f0af 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1223,7 +1223,9 @@ static __global__ void flash_attn_ext_f16( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -1288,7 +1290,8 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : + (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1); float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); @@ -1327,7 +1330,8 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : + (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1); float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); @@ -1348,8 +1352,8 @@ static __global__ void flash_attn_ext_f16( GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); - GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 9283560d5c4ee..0c967f178e7b1 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -6,7 +6,7 @@ template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(nwarps*WARP_SIZE, 1) +__launch_bounds__(nwarps*WARP_SIZE, 2) #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_tile_ext_f16( const char * __restrict__ Q, @@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f16( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -64,7 +66,7 @@ static __global__ void flash_attn_tile_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0); const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio)); const half2 * V_h2 = (const 
half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + ne11*ic0; + const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); const int stride_KV2 = nb11 / sizeof(half2); @@ -288,8 +290,8 @@ static __global__ void flash_attn_tile_ext_f16( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 32673adb57fc1..124d5d3e89122 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -6,7 +6,7 @@ template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(nwarps*WARP_SIZE, 1) +__launch_bounds__(nwarps*WARP_SIZE, 2) #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_tile_ext_f32( const char * __restrict__ Q, @@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f32( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -58,8 +60,8 @@ static __global__ void flash_attn_tile_ext_f32( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); @@ -76,7 +78,7 @@ static __global__ void flash_attn_tile_ext_f32( const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0); const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio)); const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + ne11*ic0; + const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); const int stride_KV2 = nb11 / sizeof(half2); diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 35e649cb3c81b..e78fb181919fd 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f16( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -68,7 +70,7 @@ static __global__ void flash_attn_vec_ext_f16( K += nb12*(blockIdx.z / gqa_ratio); V += nb22*(blockIdx.z / gqa_ratio); - const half * maskh = (const half *) mask + ne11*ic0; + const half * maskh = (const half *) (mask + 
nb32*(blockIdx.z % ne32) + nb31*ic0); const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); @@ -342,8 +344,8 @@ static __global__ void flash_attn_vec_ext_f16( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 9539679177969..c22baf41764d1 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f32( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -51,8 +53,8 @@ static __global__ void flash_attn_vec_ext_f32( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); @@ -79,7 +81,8 @@ static __global__ void flash_attn_vec_ext_f32( Q += nb02* blockIdx.z + nb01*ic0; K += nb12*(blockIdx.z / gqa_ratio); V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape - const half * maskh = (const half *) mask + ne11*ic0; + + const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index f3b794c3644c8..c95ca7b1f285f 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -46,7 +46,9 @@ static __global__ void flash_attn_ext_f16( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -94,11 +96,11 @@ static __global__ void flash_attn_ext_f16( constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. 
- const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0); - const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio)); - const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; - const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); + const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0); + const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); + const half2 * mask2 = (const half2 *) maskh; const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); @@ -440,7 +442,7 @@ static __global__ void flash_attn_ext_f16( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b30c13c62f25c..8f99889743748 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3247,6 +3247,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CONT: return op->src[0]->type != GGML_TYPE_BF16; case GGML_OP_DIAG_MASK_INF: + return true; case GGML_OP_SOFT_MAX: return true; case GGML_OP_SOFT_MAX_BACK: { @@ -3295,6 +3296,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (op->src[0]->ne[0] == 192) { return false; } + // TODO: support broadcast + // ref: https://github.com/ggml-org/llama.cpp/pull/14435 if (op->src[0]->ne[3] != 1) { return false; } diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index aac6e0999880a..e0eb921d85d53 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -13,6 +13,29 @@ __device__ float __forceinline__ t2f32(half val) { return __half2float(val); } +struct soft_max_params { + + int64_t nheads; + uint32_t n_head_log2; + int64_t ncols; + int64_t nrows_x; + int64_t nrows_y; + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + int64_t nb11; + int64_t nb12; + int64_t nb13; + + int64_t ne12; + int64_t ne13; + float scale; + float max_bias; + float m0; + float m1; +}; + // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled. // As we want to keep pragma unroll for all other cases we supress the clang transformation warning here. #ifdef __clang__ @@ -21,16 +44,24 @@ __device__ float __forceinline__ t2f32(half val) { #endif // __clang__ template static __global__ void soft_max_f32( - const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, - const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { - const int ncols = ncols_template == 0 ? 
ncols_par : ncols_template; + const float * x, const T * mask, float * dst, const soft_max_params p) { + const int ncols = ncols_template == 0 ? p.ncols : ncols_template; const int tid = threadIdx.x; - const int rowx = blockIdx.x; - const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension + + const int64_t i03 = blockIdx.z; + const int64_t i02 = blockIdx.y; + const int64_t i01 = blockIdx.x; + + //TODO: noncontigous inputs/outputs + const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y; + + const int64_t i11 = i01; + const int64_t i12 = i02 % p.ne12; + const int64_t i13 = i03 % p.ne13; x += int64_t(rowx)*ncols; - mask += int64_t(rowy)*ncols * (mask != nullptr); + mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr); dst += int64_t(rowx)*ncols; const int block_size = block_size_template == 0 ? blockDim.x : block_size_template; @@ -38,7 +69,7 @@ static __global__ void soft_max_f32( const int warp_id = threadIdx.x / WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE; - const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1); + const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1); extern __shared__ float data_soft_max_f32[]; float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication @@ -55,7 +86,7 @@ static __global__ void soft_max_f32( break; } - const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f); + const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f); vals[col] = val; max_val = max(max_val, val); @@ -151,63 +182,60 @@ static __global__ void soft_max_back_f32( } template -static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { +static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) { int nth = WARP_SIZE; + const int64_t ncols_x = params.ncols; + while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); - const dim3 block_nums(nrows_x, 1, 1); + const dim3 block_nums(params.ne01, params.ne02, params.ne03); const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); - const uint32_t n_head = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); // FIXME: this limit could be raised by ~2-4x on Ampere or newer if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { switch (ncols_x) { case 32: soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + (x, mask, dst, params); break; case 64: soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + (x, mask, dst, params); break; case 128: soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + (x, mask, dst, params); break; case 256: soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + (x, mask, dst, params); break; case 512: soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + (x, mask, dst, params); break; case 
1024: soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + (x, mask, dst, params); break; case 2048: soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + (x, mask, dst, params); break; case 4096: soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + (x, mask, dst, params); break; default: soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + (x, mask, dst, params); break; } } else { const size_t nbytes_shared_low = WARP_SIZE*sizeof(float); - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, params); } } @@ -235,10 +263,11 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional - const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); const int64_t nrows_y = src0->ne[1]; + const int64_t ne00 = src0->ne[0]; + float scale = 1.0f; float max_bias = 0.0f; @@ -247,10 +276,44 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + const int64_t nb11 = src1 ? src1->nb[1] : 1; + const int64_t nb12 = src1 ? src1->nb[2] : 1; + const int64_t nb13 = src1 ? src1->nb[3] : 1; + + const int64_t ne12 = src1 ? src1->ne[2] : 1; + const int64_t ne13 = src1 ? src1->ne[3] : 1; + + const uint32_t n_head = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + + soft_max_params params = {}; + params.nheads = src0->ne[2]; + params.n_head_log2 = n_head_log2; + params.ncols = ne00; + params.nrows_x = nrows_x; + params.nrows_y = nrows_y; + params.ne00 = src0->ne[0]; + params.ne01 = src0->ne[1]; + params.ne02 = src0->ne[2]; + params.ne03 = src0->ne[3]; + params.nb11 = nb11; + params.nb12 = nb12; + params.nb13 = nb13; + params.ne12 = ne12; + params.ne13 = ne13; + params.scale = scale; + params.max_bias = max_bias; + params.m0 = m0; + params.m1 = m1; + if (use_f16) { - soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, stream); } else { - soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream); } } diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 260440aedde00..1b6ff02c75866 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -229,7 +229,9 @@ typedef struct { uint64_t nb21; uint64_t nb22; uint64_t nb23; + int32_t ne32; uint64_t nb31; + uint64_t nb32; int32_t ne1; int32_t ne2; float scale; @@ -450,9 +452,21 @@ typedef struct { } ggml_metal_kargs_sum_rows; typedef struct { - int64_t ne00; - int64_t ne01; - int64_t ne02; + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; float scale; float max_bias; float m0; diff --git a/ggml/src/ggml-metal/ggml-metal.m 
b/ggml/src/ggml-metal/ggml-metal.m index 349f0ff998ed1..3fab6bca692c8 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1710,7 +1710,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_MEAN: case GGML_OP_SOFT_MAX: case GGML_OP_GROUP_NORM: - return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); + return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_RMS_NORM: case GGML_OP_L2_NORM: return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); @@ -2573,10 +2573,7 @@ static bool ggml_metal_encode_node( memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); - const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src0->ne[1]; - - const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head = src0->ne[2]; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); @@ -2636,6 +2633,18 @@ static bool ggml_metal_encode_node( /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, /*.scale =*/ scale, /*.max_bias =*/ max_bias, /*.m0 =*/ m0, @@ -2655,7 +2664,7 @@ static bool ggml_metal_encode_node( [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_DIAG_MASK_INF: { @@ -4908,7 +4917,9 @@ static bool ggml_metal_encode_node( /*.nb21 =*/ nb21, /*.nb22 =*/ nb22, /*.nb23 =*/ nb23, + /*.ne32 =*/ ne32, /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, /*.ne1 =*/ ne1, /*.ne2 =*/ ne2, /*.scale =*/ scale, diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 984a0ab503e7d..c79192e8cf727 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1253,24 +1253,28 @@ kernel void kernel_soft_max( device char * dst, constant ggml_metal_kargs_soft_max & args, threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - const int64_t i03 = (tgpig) / (args.ne02*args.ne01); - const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01; - const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); + uint3 tptg[[threads_per_threadgroup]]) { + const int32_t i03 = tgpig.z; + const int32_t i02 = tgpig.y; + const int32_t i01 = tgpig.x; + + const int32_t i13 = i03%args.ne13; + const int32_t i12 = i02%args.ne12; + const int32_t i11 = i01; - device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); - device const T * pmask = src1 != src0 ? 
(device const T *) src1 + i01*args.ne00 : nullptr; - device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); + device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); + device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr; + device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3); float slope = 1.0f; // ALiBi if (args.max_bias > 0.0f) { - const int64_t h = i02; + const int32_t h = i02; const float base = h < args.n_head_log2 ? args.m0 : args.m1; const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; @@ -1281,13 +1285,13 @@ kernel void kernel_soft_max( // parallel max float lmax = -INFINITY; - for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) { lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)); } // find the max value in the block float max_val = simd_max(lmax); - if (ntg > N_SIMDWIDTH) { + if (tptg.x > N_SIMDWIDTH) { if (sgitg == 0) { buf[tiisg] = -INFINITY; } @@ -1306,7 +1310,7 @@ kernel void kernel_soft_max( // parallel sum float lsum = 0.0f; - for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) { const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); lsum += exp_psrc0; pdst[i00] = exp_psrc0; @@ -1318,7 +1322,7 @@ kernel void kernel_soft_max( float sum = simd_sum(lsum); - if (ntg > N_SIMDWIDTH) { + if (tptg.x > N_SIMDWIDTH) { if (sgitg == 0) { buf[tiisg] = 0.0f; } @@ -1337,7 +1341,7 @@ kernel void kernel_soft_max( const float inv_sum = 1.0f/sum; - for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) { pdst[i00] *= inv_sum; } } @@ -1349,23 +1353,27 @@ kernel void kernel_soft_max_4( device char * dst, constant ggml_metal_kargs_soft_max & args, threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - const int64_t i03 = (tgpig) / (args.ne02*args.ne01); - const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01; - const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); + uint3 tptg[[threads_per_threadgroup]]) { + const int32_t i03 = tgpig.z; + const int32_t i02 = tgpig.y; + const int32_t i01 = tgpig.x; + + const int32_t i13 = i03%args.ne13; + const int32_t i12 = i02%args.ne12; + const int32_t i11 = i01; - device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; - device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr; - device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; + device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); + device const T * pmask = src1 != src0 ? 
(device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr; + device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3); float slope = 1.0f; if (args.max_bias > 0.0f) { - const int64_t h = i02; + const int32_t h = i02; const float base = h < args.n_head_log2 ? args.m0 : args.m1; const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; @@ -1376,14 +1384,14 @@ kernel void kernel_soft_max_4( // parallel max float4 lmax4 = -INFINITY; - for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) { lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); float max_val = simd_max(lmax); - if (ntg > N_SIMDWIDTH) { + if (tptg.x > N_SIMDWIDTH) { if (sgitg == 0) { buf[tiisg] = -INFINITY; } @@ -1402,7 +1410,7 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; - for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) { const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; @@ -1416,7 +1424,7 @@ kernel void kernel_soft_max_4( float sum = simd_sum(lsum); - if (ntg > N_SIMDWIDTH) { + if (tptg.x > N_SIMDWIDTH) { if (sgitg == 0) { buf[tiisg] = 0.0f; } @@ -1435,7 +1443,7 @@ kernel void kernel_soft_max_4( const float inv_sum = 1.0f/sum; - for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) { pdst4[i00] *= inv_sum; } } @@ -3709,7 +3717,7 @@ kernel void kernel_flash_attn_ext( // load the mask in shared memory #pragma unroll(Q) for (short j = 0; j < Q; ++j) { - device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31); + device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq3%args.ne32)*args.nb32); const float m = pm[ic + tiisg]; @@ -4195,7 +4203,7 @@ kernel void kernel_flash_attn_ext_vec( const bool has_mask = mask != q; // pointer to the mask - device const half * pm = (device const half *) (mask + iq1*args.nb31); + device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq3%args.ne32)*args.nb32); float slope = 1.0f; diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 9cb36ae99e7f5..50010aee92125 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4370,9 +4370,15 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return true; case GGML_OP_CONT: return op->src[0]->type != GGML_TYPE_BF16; - case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: - return true; + // TODO: support batching + if (op->src[0]->ne[3] != 1) { + return false; + } + // TODO: support broadcast + // ref: https://github.com/ggml-org/llama.cpp/pull/14435 + return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1); + case GGML_OP_DIAG_MASK_INF: case GGML_OP_ROPE: case GGML_OP_IM2COL: return true; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 99be5e45b2af7..aaa7bc7e81e06 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -627,6 +627,7 @@ struct vk_flash_attn_push_constants { uint32_t nev2; uint32_t nev3; uint32_t nem1; + uint32_t nem2; uint32_t nb01; uint32_t 
nb02; @@ -637,7 +638,6 @@ struct vk_flash_attn_push_constants { uint32_t nb21; uint32_t nb22; uint32_t nb23; - uint32_t nb31; float scale; float max_bias; @@ -652,6 +652,7 @@ struct vk_flash_attn_push_constants { uint32_t split_kv; uint32_t k_num; }; +static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128"); struct vk_op_push_constants { uint32_t KX; @@ -743,6 +744,14 @@ struct vk_op_rope_push_constants { struct vk_op_soft_max_push_constants { uint32_t KX; uint32_t KY; + uint32_t ne00; + uint32_t ne01; + uint32_t ne02; + uint32_t ne12; + uint32_t ne13; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; float scale; float max_bias; float m0; @@ -5977,7 +5986,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx GGML_TENSOR_LOCALS(size_t, nb, dst, nb) const uint32_t nem1 = mask ? mask->ne[1] : 0; - const uint32_t nbm1 = mask ? mask->nb[1] : 0; + const uint32_t nem2 = mask ? mask->ne[2] : 0; const uint32_t D = neq0; uint32_t N = neq1; @@ -6140,7 +6149,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Try to use split_k when KV is large enough to be worth the overhead if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) { // Try to run two workgroups per SM. - split_k = ctx->device->shader_core_count * 2 / workgroups_y; + split_k = ctx->device->shader_core_count * 2 / (workgroups_y * workgroups_z); if (split_k > 1) { // Try to evenly split KV into split_k chunks, but it needs to be a multiple // of "align", so recompute split_k based on that. @@ -6150,9 +6159,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } } - // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1) - // and the per-row m and L values (ne1 rows). - const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0; + // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1) + // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows. + const uint64_t split_k_size = split_k > 1 ? 
(D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0; if (split_k_size > ctx->device->max_memory_allocation_size) { GGML_ABORT("Requested preallocation size is too large"); } @@ -6244,11 +6253,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, - nem1, + nem1, nem2, q_stride, (uint32_t)nbq2, (uint32_t)nbq3, k_stride, (uint32_t)nbk2, (uint32_t)nbk3, v_stride, (uint32_t)nbv2, (uint32_t)nbv3, - nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1, gqa_ratio, split_kv, split_k }; @@ -6271,13 +6279,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); ggml_vk_sync_buffers(subctx); - const std::array pc2 = { D, (uint32_t)ne1, split_k }; + const std::array pc2 = { D, (uint32_t)ne1, (uint32_t)ne3, split_k }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, { vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, }, - pc2, { (uint32_t)ne1, 1, 1 }); + pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 }); } else { ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { @@ -7562,7 +7570,13 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const uint32_t nrows_x = (uint32_t)ggml_nrows(src0); const uint32_t nrows_y = (uint32_t)src0->ne[1]; - const uint32_t n_head_kv = nrows_x/nrows_y; + const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u; + const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u; + const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u; + const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u; + const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u; + + const uint32_t n_head_kv = src0->ne[2]; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); @@ -7571,6 +7585,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { ncols, src1 != nullptr ? 
nrows_y : (uint32_t)0, + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], + ne12, ne13, + nb11, nb12, nb13, scale, max_bias, m0, m1, n_head_log2, @@ -10224,6 +10241,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_SCALE: case GGML_OP_PAD: case GGML_OP_DIAG_MASK_INF: + return true; case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ARGSORT: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index ce230a8f7d910..6f80101d1c432 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -99,6 +99,10 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif + uint32_t m_offset = 0; + if (p.nem2 != 1) { + m_offset = (iq3 % p.nem2) * p.nem1 * KV; + } [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { @@ -150,7 +154,7 @@ void main() { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br) { - masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]); + masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); } } barrier(); @@ -277,7 +281,7 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - uint32_t o_offset = D * p.ne1 * split_k_index; + uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num); [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (r < N) { @@ -289,7 +293,7 @@ void main() { } } - o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (r < N) { perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); @@ -311,7 +315,7 @@ void main() { } } - uint32_t o_offset = iq3*p.ne2*p.ne1; + uint32_t o_offset = iq3*p.ne2*p.ne1*D; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp index 61d90e2d8ed21..67b935cf57e50 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp @@ -24,6 +24,7 @@ layout (push_constant) uniform parameter { uint32_t nev2; uint32_t nev3; uint32_t nem1; + uint32_t nem2; uint32_t nb01; uint32_t nb02; @@ -34,7 +35,6 @@ layout (push_constant) uniform parameter { uint32_t nb21; uint32_t nb22; uint32_t nb23; - uint32_t nb31; float scale; float max_bias; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index da478be24fb6e..26fe50c7a81e5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -123,6 +123,10 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif + uint32_t m_offset = 0; + if (p.nem2 != 1) { + m_offset = (iq3 % p.nem2) * p.nem1 * KV; + } [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { @@ -181,7 +185,7 @@ void main() { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { - sfsh[c * sfshstride + r] += 
ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)])); + sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)])); } } barrier(); @@ -300,7 +304,7 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - uint32_t o_offset = D * p.ne1 * split_k_index; + uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num); [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { @@ -312,7 +316,7 @@ void main() { } } - o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); @@ -334,7 +338,7 @@ void main() { } } - uint32_t o_offset = iq3*p.ne2*p.ne1; + uint32_t o_offset = iq3*p.ne2*p.ne1*D; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 6acf67a03a463..cf47bd53e28bb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -130,6 +130,11 @@ void main() { coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2); } + uint32_t m_offset = 0; + if (p.nem2 != 1) { + m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; + } + [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { @@ -155,7 +160,7 @@ void main() { coopmat mv; - coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); S += slopeMat*coopmat(mv); } @@ -229,10 +234,10 @@ void main() { if (p.k_num > 1) { coopmat O_D = coopmat(O); - uint32_t o_offset = D * p.ne1 * split_k_index; + uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num); coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); - o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); return; @@ -250,7 +255,7 @@ void main() { O = Ldiag*O; - uint32_t o_offset = iq3*p.ne2*p.ne1; + uint32_t o_offset = iq3*p.ne2*p.ne1*D; coopmat O_D = coopmat(O); if (p.gqa_ratio > 1) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp index a7e3956854c44..599cef072e931 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -12,6 +12,7 @@ layout (binding = 1) writeonly buffer D {float data_d[];}; layout (push_constant) uniform parameter { uint D; uint N; + uint ne3; uint k_num; } p; @@ -19,13 +20,14 @@ void main() { // Each workgroup handles a row const uint n = gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; + const uint iq3 = gl_WorkGroupID.z; uint D = p.D; uint N = p.N; uint k_num = p.k_num; - uint l_offset = D * N * k_num + n; - uint m_offset = D * N * k_num + N + n; 
+ uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n; + uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n; uint lm_stride = N * 2; // Compute the max m value for the row @@ -49,11 +51,11 @@ void main() { for (uint d = tid; d < D; d += BLOCK_SIZE) { float O = 0.0; [[unroll]] for (uint k = 0; k < k_num; ++k) { - uint o_offset = D * N * k + D * n + d; + uint o_offset = D * N * (k + iq3 * k_num) + D * n + d; float m = data_a[m_offset + k * lm_stride]; O += exp(m - m_max) * data_a[o_offset]; } O *= L; - data_d[D * n + d] = O; + data_d[iq3 * D * N + D * n + d] = O; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp index 51fc2dc7ed406..5bcd3b1e3ddc6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp @@ -6,6 +6,14 @@ layout (push_constant) uniform parameter { uint KX; uint KY; + uint ne00; + uint ne01; + uint ne02; + uint ne12; + uint ne13; + uint nb11; + uint nb12; + uint nb13; float scale; float max_bias; float m0; @@ -31,7 +39,15 @@ shared FLOAT_TYPE vals[BLOCK_SIZE]; void soft_max(uint num_iters) { const uint tid = gl_LocalInvocationID.x; const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; - const uint rowy = (p.KY > 0) ? (rowx % p.KY) : 0; + + const uint32_t i03 = rowx / (p.ne01 * p.ne02); + const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01; + const uint32_t i01 = rowx % p.ne01; + + uint rowy_start = 0; + if (p.KY > 0) { + rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13; + } if (rowx >= p.nrows_x) { return; @@ -41,7 +57,7 @@ void soft_max(uint num_iters) { // ALiBi if (p.max_bias > 0.0f) { - const uint h = rowx/p.KY; // head index + const uint h = (rowx / p.ne01) % p.ne02; // head index const float base = h < p.n_head_log2 ? p.m0 : p.m1; const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; @@ -67,7 +83,7 @@ void soft_max(uint num_iters) { FLOAT_TYPE b = FLOAT_TYPE(0); if (p.KY > 0 && col < p.KX) { - b = data_b[rowy * p.KX + col]; + b = data_b[rowy_start + col]; } FLOAT_TYPE v = a * p.scale + slope * b; @@ -111,7 +127,7 @@ void soft_max(uint num_iters) { if (idx < DATA_CACHE_SIZE) { val = exp(data_cache[idx] - max_val); } else { - val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val); + val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? 
slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val); } sum += val; if (idx < DATA_CACHE_SIZE) { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 3d04f80ef4f90..1c20f66a1a4f0 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3515,9 +3515,11 @@ static struct ggml_tensor * ggml_soft_max_impl( if (mask) { GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(mask)); - GGML_ASSERT(ggml_is_matrix(mask)); + GGML_ASSERT(ggml_is_3d(mask)); GGML_ASSERT(mask->ne[0] == a->ne[0]); GGML_ASSERT(mask->ne[1] >= a->ne[1]); + GGML_ASSERT(a->ne[2]%mask->ne[2] == 0); + GGML_ASSERT(a->ne[3]%mask->ne[3] == 0); } if (max_bias > 0.0f) { @@ -4491,13 +4493,17 @@ struct ggml_tensor * ggml_flash_attn_ext( GGML_ASSERT(ggml_can_mul_mat(k, q)); // TODO: check if vT can be multiplied by (k*qT) + GGML_ASSERT(q->ne[3] == k->ne[3]); + GGML_ASSERT(q->ne[3] == v->ne[3]); + if (mask) { GGML_ASSERT(ggml_is_contiguous(mask)); - GGML_ASSERT(mask->ne[2] == 1); - GGML_ASSERT(mask->ne[3] == 1); + GGML_ASSERT(mask->ne[2] == q->ne[3]); GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big"); //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); + + GGML_ASSERT(q->ne[3] % mask->ne[2] == 0); } if (max_bias > 0.0f) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a233f1f2fd97a..ac38c1015c160 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2525,11 +2525,12 @@ struct test_soft_max : public test_case { const std::array ne; const bool mask; const ggml_type m_prec; + const std::array nr23; // broadcast only dims 2 and 3 const float scale; const float max_bias; std::string vars() override { - return VARS_TO_STR6(type, ne, mask, m_prec, scale, max_bias); + return VARS_TO_STR7(type, ne, mask, m_prec, nr23, scale, max_bias); } // the 1024 test with bias occasionally fails: @@ -2542,18 +2543,19 @@ struct test_soft_max : public test_case { std::array ne = {10, 5, 4, 3}, bool mask = false, ggml_type m_prec = GGML_TYPE_F32, + std::array nr23 = {1, 1}, float scale = 1.0f, float max_bias = 0.0f) - : type(type), ne(ne), mask(mask), m_prec(m_prec), scale(scale), max_bias(max_bias) {} + : type(type), ne(ne), mask(mask), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]); ggml_set_param(a); ggml_set_name(a, "a"); ggml_tensor * mask = nullptr; if (this->mask) { - mask = ggml_new_tensor_2d(ctx, m_prec, ne[0], ne[1]); + mask = ggml_new_tensor_4d(ctx, m_prec, ne[0], ne[1], ne[2], ne[3]); ggml_set_name(mask, "mask"); } @@ -3384,7 +3386,7 @@ struct test_flash_attn_ext : public test_case { const int64_t hsk; // K head size const int64_t hsv; // V head size const int64_t nh; // num heads - const int64_t nr; // repeat in Q, tests for grouped-query attention + const std::array nr23; // repeat in dim 2 and 3, tests for grouped-query attention const int64_t kv; // kv size const int64_t nb; // batch size @@ -3398,7 +3400,7 @@ struct test_flash_attn_ext : public test_case { std::array permute; std::string vars() override { - return VARS_TO_STR12(hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute); + return VARS_TO_STR12(hsk, hsv, nh, nr23, kv, nb, mask, max_bias, 
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index a233f1f2fd97a..ac38c1015c160 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2525,11 +2525,12 @@ struct test_soft_max : public test_case {
     const std::array<int64_t, 4> ne;
     const bool mask;
     const ggml_type m_prec;
+    const std::array<int64_t, 2> nr23; // broadcast only dims 2 and 3
     const float scale;
     const float max_bias;
 
     std::string vars() override {
-        return VARS_TO_STR6(type, ne, mask, m_prec, scale, max_bias);
+        return VARS_TO_STR7(type, ne, mask, m_prec, nr23, scale, max_bias);
     }
 
     // the 1024 test with bias occasionally fails:
@@ -2542,18 +2543,19 @@ struct test_soft_max : public test_case {
             std::array<int64_t, 4> ne = {10, 5, 4, 3},
             bool mask = false,
             ggml_type m_prec = GGML_TYPE_F32,
+            std::array<int64_t, 2> nr23 = {1, 1},
             float scale = 1.0f,
             float max_bias = 0.0f)
-        : type(type), ne(ne), mask(mask), m_prec(m_prec), scale(scale), max_bias(max_bias) {}
+        : type(type), ne(ne), mask(mask), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
         ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * mask = nullptr;
         if (this->mask) {
-            mask = ggml_new_tensor_2d(ctx, m_prec, ne[0], ne[1]);
+            mask = ggml_new_tensor_4d(ctx, m_prec, ne[0], ne[1], ne[2], ne[3]);
             ggml_set_name(mask, "mask");
         }
 
@@ -3384,7 +3386,7 @@ struct test_flash_attn_ext : public test_case {
     const int64_t hsk; // K head size
     const int64_t hsv; // V head size
     const int64_t nh; // num heads
-    const int64_t nr; // repeat in Q, tests for grouped-query attention
+    const std::array<int64_t, 2> nr23; // repeat in dim 2 and 3, tests for grouped-query attention
     const int64_t kv; // kv size
     const int64_t nb; // batch size
 
@@ -3398,7 +3400,7 @@ struct test_flash_attn_ext : public test_case {
     std::array<int32_t, 4> permute;
 
     std::string vars() override {
-        return VARS_TO_STR12(hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
+        return VARS_TO_STR12(hsk, hsv, nh, nr23, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
     }
 
     double max_nmse_err() override {
@@ -3409,13 +3411,13 @@
         GGML_UNUSED(t);
         // Just counting matmul costs:
         // Q*K^T is nb x hsk x kv, P*V is nb x kv x hsv, per head
-        return 2 * nh*nr * nb * (hsk + hsv) * kv;
+        return (2 * nh*nr23[0] * nb * (hsk + hsv) * kv)*nr23[1];
     }
 
-    test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
+    test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, std::array<int64_t, 2> nr23 = {1, 1}, int64_t kv = 96, int64_t nb = 8,
             bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f,
             ggml_prec prec = GGML_PREC_F32, ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
-        : hsk(hsk), hsv(hsv), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
+        : hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
@@ -3434,18 +3436,18 @@ struct test_flash_attn_ext : public test_case {
             return t;
         };
 
-        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr, 1);
+        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0], nr23[1]);
         ggml_set_name(q, "q");
 
-        ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, 1);
+        ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1]);
         ggml_set_name(k, "k");
 
-        ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, 1);
+        ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1]);
         ggml_set_name(v, "v");
 
         ggml_tensor * m = nullptr;
         if (mask) {
-            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
+            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[1], 1);
             ggml_set_name(m, "m");
         }
 
@@ -4524,26 +4526,31 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                 for (int64_t ne1 : {16, 1024}) {
                     if (mask) {
                         for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, scale, max_bias));
-                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, scale, max_bias));
+                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
+                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
+
+                            if (ne0 <= 32 && ne1 <= 32) {
+                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, {3, 1}, scale, max_bias));
+                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {2, 3}, scale, max_bias));
+                            }
                         }
                     } else {
                         /* The precision of mask here doesn't matter as boolean mask is false */
-                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
-                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
+                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, GGML_TYPE_F32, {1, 1}, scale, max_bias));
+                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, {1, 1}, scale, max_bias));
                     }
                 }
             }
         }
     }
 
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 8.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
 
     for (float max_bias : {0.0f, 8.0f}) {
         for (float scale : {1.0f, 0.1f}) {
@@ -4641,20 +4648,23 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             for (float logit_softcap : {0.0f, 10.0f}) {
                 if (hsk != 128 && logit_softcap != 0.0f) continue;
                 for (int nh : { 4, }) {
-                    for (int nr : { 1, 4, 16 }) {
-                        if (nr == 16 && hsk != 128) continue;
-                        for (int kv : { 512, 1024, }) {
-                            if (nr != 1 && kv != 512) continue;
-                            for (int nb : { 1, 3, 32, 35, }) {
-                                for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
-                                    if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
-                                    for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
-                                        test_cases.emplace_back(new test_flash_attn_ext(
-                                            hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
-                                        // run fewer test cases permuted
-                                        if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                    for (int nr3 : { 1, 3, }) {
+                        if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
+                        for (int nr2 : { 1, 4, 16 }) {
+                            if (nr2 == 16 && hsk != 128) continue;
+                            for (int kv : { 512, 1024, }) {
+                                if (nr2 != 1 && kv != 512) continue;
+                                for (int nb : { 1, 3, 32, 35, }) {
+                                    for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
+                                        if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
+                                        for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
                                             test_cases.emplace_back(new test_flash_attn_ext(
-                                                hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                                hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
+                                            // run fewer test cases permuted
+                                            if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                                test_cases.emplace_back(new test_flash_attn_ext(
+                                                    hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                            }
                                         }
                                     }
                                 }
@@ -4697,13 +4707,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
 
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
 
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
@@ -4735,7 +4745,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     for (int kv : { 4096, 8192, 16384, }) {
         for (int hs : { 64, 128, }) {
             for (int nr : { 1, 4, }) {
-                test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, nr, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+                test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, {nr, 1}, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
             }
         }
     }
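The new {nr2, nr3} test dimensions map onto the public API as follows: nr2 repeats the query heads over the KV heads (grouped-query attention) and nr3 adds a dim-3 (sequence) dimension shared by Q, K and V, with the mask gaining a matching dim 2. A minimal sketch of the corresponding graph construction (the helper function and the concrete sizes are illustrative, not taken from the patch):

    #include <math.h>

    #include "ggml.h"

    static struct ggml_tensor * flash_attn_bcast_example(struct ggml_context * ctx) {
        const int64_t hs = 128, n_kv = 512, n_batch = 8;
        const int64_t nh = 4, nr2 = 4, nr3 = 3; // 16 query heads, 4 KV heads, 3 sequences

        struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, n_batch, nh*nr2, nr3);
        struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, n_kv,    nh,     nr3);
        struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, n_kv,    nh,     nr3);

        // the mask is padded in dim 1 as before and its dim 2 matches q->ne[3]
        struct ggml_tensor * m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_kv,
                GGML_PAD(n_batch, GGML_KQ_MASK_PAD), nr3, 1);

        return ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf((float) hs), 0.0f, 0.0f);
    }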