diff --git a/common/arg.cpp b/common/arg.cpp index a1fd4c9651b9c..bbd85cd2370bf 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1760,6 +1760,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.sampling.xtc_threshold = std::stof(value); } ).set_sparam()); + add_opt(common_arg( + {"--smoothing-factor"}, "N", + string_format("smoothing factor (default: %.1f, 0.0 = disabled)", (double)params.sampling.smoothing_factor), + [](common_params & params, const std::string & value) { + params.sampling.smoothing_factor = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--smoothing-curve"}, "N", + string_format("smoothing curve (default: %.1f, 1.0 = disabled)", (double)params.sampling.smoothing_curve), + [](common_params & params, const std::string & value) { + params.sampling.smoothing_curve = std::stof(value); + } + ).set_sparam()); add_opt(common_arg( {"--typical"}, "N", string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sampling.typ_p), diff --git a/common/common.h b/common/common.h index dea34267c8de8..491775091a7e3 100644 --- a/common/common.h +++ b/common/common.h @@ -96,6 +96,7 @@ enum common_sampler_type { COMMON_SAMPLER_TYPE_INFILL = 9, COMMON_SAMPLER_TYPE_PENALTIES = 10, COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11, + COMMON_SAMPLER_TYPE_SMOOTHING = 12, }; // dimensionality reduction methods, used by cvector-generator @@ -135,6 +136,8 @@ struct common_params_sampling { float min_p = 0.05f; // 0.0 = disabled float xtc_probability = 0.00f; // 0.0 = disabled float xtc_threshold = 0.10f; // > 0.5 disables XTC + float smoothing_factor = 0.0f; // controls the quadratic adjustment in smooth / quadratic sampling (0.0 = disabled) + float smoothing_curve = 1.0f; // controls the cubic transformation curve for smoothing / quadratic sampling. float typ_p = 1.00f; // typical_p, 1.0 = disabled float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities float dynatemp_range = 0.00f; // 0.0 = disabled @@ -167,6 +170,7 @@ struct common_params_sampling { COMMON_SAMPLER_TYPE_TOP_P, COMMON_SAMPLER_TYPE_MIN_P, COMMON_SAMPLER_TYPE_XTC, + COMMON_SAMPLER_TYPE_SMOOTHING, COMMON_SAMPLER_TYPE_TEMPERATURE, }; diff --git a/common/sampling.cpp b/common/sampling.cpp index 28705e24c0b71..55ce23b1a3e51 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -136,11 +136,11 @@ std::string common_params_sampling::print() const { snprintf(result, sizeof(result), "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n" - "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n" + "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, smoothing_factor = %.3f, smoothing_curve = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n" "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", penalty_last_n, penalty_repeat, penalty_freq, penalty_present, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, - top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp, + top_k, top_p, min_p, xtc_probability, xtc_threshold, smoothing_factor, smoothing_curve, typ_p, top_n_sigma, temp, mirostat, mirostat_eta, mirostat_tau); return std::string(result); @@ -258,6 +258,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co case COMMON_SAMPLER_TYPE_XTC: llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); break; + case COMMON_SAMPLER_TYPE_SMOOTHING: + llama_sampler_chain_add(result->chain, llama_sampler_init_smoothing (params.smoothing_factor, params.smoothing_curve)); + break; case COMMON_SAMPLER_TYPE_TYPICAL_P: llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); break; @@ -479,6 +482,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) { case COMMON_SAMPLER_TYPE_XTC: return 'x'; case COMMON_SAMPLER_TYPE_INFILL: return 'i'; case COMMON_SAMPLER_TYPE_PENALTIES: return 'e'; + case COMMON_SAMPLER_TYPE_SMOOTHING: return 'q'; default : return '?'; } } @@ -495,6 +499,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) { case COMMON_SAMPLER_TYPE_XTC: return "xtc"; case COMMON_SAMPLER_TYPE_INFILL: return "infill"; case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties"; + case COMMON_SAMPLER_TYPE_SMOOTHING: return "smoothing"; default : return ""; } } @@ -509,6 +514,7 @@ std::vector common_sampler_types_from_names(const std::vect { "min_p", COMMON_SAMPLER_TYPE_MIN_P }, { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE }, { "xtc", COMMON_SAMPLER_TYPE_XTC }, + { "smoothing", COMMON_SAMPLER_TYPE_SMOOTHING}, { "infill", COMMON_SAMPLER_TYPE_INFILL }, { "penalties", COMMON_SAMPLER_TYPE_PENALTIES }, }; @@ -525,6 +531,7 @@ std::vector common_sampler_types_from_names(const std::vect { "typ-p", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "typ", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "min-p", COMMON_SAMPLER_TYPE_MIN_P }, + { "quadratic", COMMON_SAMPLER_TYPE_SMOOTHING}, { "temp", COMMON_SAMPLER_TYPE_TEMPERATURE }, }; @@ -560,6 +567,7 @@ std::vector common_sampler_types_from_chars(const std::stri { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_SMOOTHING), COMMON_SAMPLER_TYPE_SMOOTHING}, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES }, }; diff --git a/docs/multimodal.md b/docs/multimodal.md index 6a5d2b342b850..80014ba1cef6d 100644 --- a/docs/multimodal.md +++ b/docs/multimodal.md @@ -31,7 +31,7 @@ llama-server -hf ggml-org/gemma-3-4b-it-GGUF --no-mmproj-offload ## Pre-quantized models -These are ready-to-use models, most of them come with `Q4_K_M` quantization by default. +These are ready-to-use models, most of them come with `Q4_K_M` quantization by default. They can be found at the Hugging Face page of the ggml-org: https://huggingface.co/ggml-org Replaces the `(tool_name)` with the name of binary you want to use. For example, `llama-mtmd-cli` or `llama-server` diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index b7180d5955c29..a4fbd823638fa 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -678,10 +678,14 @@ void launch_fattn( ) { constexpr int ncols = ncols1 * ncols2; + const bool is_mla = DV == 512; // TODO better parameterization + const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; + GGML_ASSERT(V || is_mla); + const ggml_tensor * mask = dst->src[3]; ggml_tensor * KQV = dst; @@ -689,6 +693,10 @@ void launch_fattn( GGML_ASSERT(Q->type == GGML_TYPE_F32); GGML_ASSERT(KQV->type == GGML_TYPE_F32); + GGML_ASSERT( Q->nb[0] == ggml_element_size(Q)); + GGML_ASSERT( K->nb[0] == ggml_element_size(K)); + GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V)); + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); @@ -713,10 +721,10 @@ void launch_fattn( size_t nb12 = K->nb[2]; size_t nb13 = K->nb[3]; - const char * V_data = (const char *) V->data; - size_t nb21 = V->nb[1]; - size_t nb22 = V->nb[2]; - size_t nb23 = V->nb[3]; + const char * V_data = V ? (const char *) V->data : nullptr; + size_t nb21 = V ? V->nb[1] : nb11; + size_t nb22 = V ? V->nb[2] : nb12; + size_t nb23 = V ? V->nb[3] : nb13; if (need_f16_K && K->type != GGML_TYPE_F16) { GGML_ASSERT(ggml_is_contiguously_allocated(K)); @@ -733,7 +741,7 @@ void launch_fattn( nb13 = nb13*bs*sizeof(half)/ts; } - if (need_f16_V && V->type != GGML_TYPE_F16) { + if (V && need_f16_V && V->type != GGML_TYPE_F16) { GGML_ASSERT(ggml_is_contiguously_allocated(V)); V_f16.alloc(ggml_nelements(V)); to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 491780abd4062..be0329d0e0c09 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -33,9 +33,30 @@ struct fattn_mma_f16_config< 64, 64> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 32; - static constexpr int nbatch_V2 = 32; - static constexpr int nbatch_combine = 32; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 32; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 32; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 32; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 32; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 32; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 32; + } }; template <> @@ -44,9 +65,30 @@ struct fattn_mma_f16_config< 80, 80> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 40; - static constexpr int nbatch_V2 = 40; - static constexpr int nbatch_combine = 40; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 40; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 40; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 40; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 40; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 40; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 40; + } }; template <> @@ -55,9 +97,30 @@ struct fattn_mma_f16_config< 96, 96> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 48; - static constexpr int nbatch_V2 = 48; - static constexpr int nbatch_combine = 48; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 48; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 48; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 48; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 48; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 48; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 48; + } }; template <> @@ -66,9 +129,30 @@ struct fattn_mma_f16_config<112, 112> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 56; - static constexpr int nbatch_V2 = 56; - static constexpr int nbatch_combine = 56; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 56; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 56; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 56; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 56; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 56; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 56; + } }; template <> @@ -77,9 +161,30 @@ struct fattn_mma_f16_config<128, 128> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 64; - static constexpr int nbatch_V2 = 64; - static constexpr int nbatch_combine = 64; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 64; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 64; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 64; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 64; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 64; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 64; + } }; template <> @@ -88,9 +193,38 @@ struct fattn_mma_f16_config<256, 256> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 128; - static constexpr int nbatch_V2 = 128; - static constexpr int nbatch_combine = 128; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 128; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 128; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 128; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 128; + } + + static int get_nbatch_combine_host(const int cc, const int ncols) { + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { + return ncols <= 16 ? 128 : 64; + } + return 64; + } + + static constexpr __device__ int get_nbatch_combine_device(int ncols) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING + return ncols <= 16 ? 128 : 64; +#else + GGML_UNUSED(ncols); + return 128; +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING + } }; template <> @@ -99,9 +233,44 @@ struct fattn_mma_f16_config<576, 512> { static constexpr int nwarps_max = 8; static constexpr bool Q_in_reg = false; static constexpr int nstages_target = 1; - static constexpr int nbatch_K2 = 160; - static constexpr int nbatch_V2 = 128; - static constexpr int nbatch_combine = 128; + + static int get_nbatch_K2_host(const int cc, const int ncols) { + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { + return ncols <= 16 ? 96 : 160; + } + return ncols <= 16 ? 288 : 160; + } + + static constexpr __device__ int get_nbatch_K2_device(int ncols) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING + return ncols <= 16 ? 96 : 160; +#else + return ncols <= 16 ? 288 : 160; +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING + } + + static int get_nbatch_V2_host(const int cc, const int ncols) { + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { + return ncols <= 16 ? 64 : 128; + } + return ncols <= 16 ? 256 : 128; + } + + static constexpr __device__ int get_nbatch_V2_device(int ncols) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING + return ncols <= 16 ? 64 : 128; +#else + return ncols <= 16 ? 256 : 128; +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 128; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 128; + } }; // ------------------------------------------------------------------------------------------------------------------ @@ -120,7 +289,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); - auto load = [&] __device__ (const int n) { + auto load = [&] __device__ (auto n) { const int stride_k = WARP_SIZE >> n; const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); @@ -223,7 +392,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( } } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -261,10 +430,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int ncols = ncols1 * ncols2; + constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); + constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); - constexpr int stride_tile_Q = DKQ/2 + 4; - constexpr int stride_tile_K = c::nbatch_K2 + 4; - constexpr int stride_tile_V = c::nbatch_V2 + 4; + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = nbatch_K2 + 4; + + static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); + constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; const int k_VKQ_0 = kb0 * c::nbatch_fa; tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles]; @@ -275,12 +449,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; if constexpr (nstages > 1) { - static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); + static_assert(!mla, "multi-stage loading not implemented for MLA"); + static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); constexpr bool use_cp_async = true; cp_async_wait_all(); __syncthreads(); flash_attn_ext_f16_load_tile - (V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V); + (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V); } else { constexpr bool use_cp_async = nstages == 1; if (ncols2 > 1 || mask_h2) { @@ -289,8 +464,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } #pragma unroll - for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) { - const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2; + for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) { + const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2; const int k0_diff = k0_stop - k0_start; if (nstages <= 1) { @@ -537,16 +712,21 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask); } flash_attn_ext_f16_load_tile - (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K); + (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K); } } + + // For MLA K and V have the same data. + // Therefore, iterate over V in reverse and re-use the data if possible. + static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented"); + constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV; #pragma unroll - for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) { - const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV; - const int i0_diff = i0_stop - i0_start; + for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) { + const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0; + const int i0_diff = i0_stop - i0_start; - if (nstages <= 1) { + if (nstages <= 1 && i0_start < reusable_cutoff) { constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); @@ -555,6 +735,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } __syncthreads(); } + const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2; // Calculate VKQ tile: #pragma unroll @@ -565,7 +746,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k0 = k00 + (threadIdx.y % np)*tile_A::J; tile_A A; - load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); + load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); if (ntiles == 1) { mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); } else { @@ -596,7 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #endif // NEW_MMA_AVAILABLE } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -632,13 +813,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); + constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); - constexpr int stride_tile_Q = DKQ/2 + 4; - constexpr int stride_tile_K = c::nbatch_K2 + 4; - constexpr int stride_tile_V = c::nbatch_V2 + 4; + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = nbatch_K2 + 4; + static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); + constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; extern __shared__ half2 tile_Q[]; @@ -726,26 +910,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // Preload mask and K data for first iteration when using cp_async with multiple stages: if constexpr (nstages > 1) { - static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); + static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); constexpr bool use_cp_async = true; if (ncols2 > 1 || mask_h2) { flash_attn_ext_f16_load_mask (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask); } flash_attn_ext_f16_load_tile - (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K); + (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K); } // Iterate over ne11 == previous tokens: for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { constexpr bool last_iter = false; - flash_attn_ext_f16_iter + flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); } { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. constexpr bool last_iter = true; - flash_attn_ext_f16_iter + flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); } @@ -774,7 +958,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. // So also write VKQ accumulators to shared memory in column-major format if np == 1. - constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4; + constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols); constexpr int tile_stride = nbatch_combine + 4; static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine"); @@ -1012,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #endif // NEW_MMA_AVAILABLE } -template +template __launch_bounds__(nwarps*WARP_SIZE, 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, @@ -1057,6 +1241,14 @@ static __global__ void flash_attn_ext_f16( NO_DEVICE_CODE; return; } +#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING + if (ncols1*ncols2 > 32) { + NO_DEVICE_CODE; + return; + } +#endif __CUDA_ARCH__ == GGML_CUDA_CC_TURING + + static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV"); typedef fattn_mma_f16_config c; @@ -1067,9 +1259,10 @@ static __global__ void flash_attn_ext_f16( const int stride_Q1 = nb01 / sizeof(float2); const int stride_Q2 = nb02 / sizeof(float2); const int stride_K = nb11 / sizeof(half2); - const int stride_V = nb21 / sizeof(half2); const int stride_mask = nb31 / sizeof(half2); + const int stride_V = mla ? stride_K : nb21 / sizeof(half2); + const int iter_k = ne11 / FATTN_KQ_STRIDE; const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; @@ -1092,10 +1285,11 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; @@ -1104,12 +1298,12 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } @@ -1130,10 +1324,11 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; @@ -1141,7 +1336,7 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else @@ -1167,10 +1362,6 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml typedef fattn_mma_f16_config c; - constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2; - constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2; - constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine; - const int nstages = cp_async_available(cc) ? c::nstages_target : 0; constexpr int ncols = ncols1 * ncols2; @@ -1180,15 +1371,21 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I; constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max; + constexpr bool mla = DKQ == 576; + + const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols); + const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols); + const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols); + static_assert(DKQ % tile_B::J == 0, "bad DKQ"); static_assert(DV % tile_A::J == 0, "bad DV"); static_assert(ncols % cols_per_warp == 0, "bad ncols"); - const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2); - const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2); - const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); - const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); - const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); + const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); + const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); + const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage; @@ -1202,7 +1399,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1213,7 +1410,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 9c5c803d02bc7..6bc0096cc65e6 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -10,6 +10,7 @@ template static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const ggml_tensor * Q = dst->src[0]; if constexpr (ncols2 <= 8) { @@ -24,7 +25,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con return; } - if (Q->ne[1] <= 32/ncols2) { + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b4b85abcda9e3..02dc8c12dbd8c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3222,7 +3222,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g #endif // FLASH_ATTN_AVAILABLE if (op->src[1]->ne[0] != op->src[2]->ne[0]) { const int cc = ggml_cuda_info().devices[dev_ctx->device].cc; - if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) { + if (!new_mma_available(cc)) { return false; } const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2]; diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index e1cf843de1a65..2db5b4ab0f09c 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -122,6 +122,7 @@ void ggml_cuda_mul_mat_q( const int64_t s13 = src1->nb[3] / ts_src1; quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream); + CUDA_CHECK(cudaGetLastError()); } const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int)); @@ -205,6 +206,7 @@ void ggml_cuda_mul_mat_q( const int64_t s13 = src1->nb[2] / ts_src1; quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); + CUDA_CHECK(cudaGetLastError()); } const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int)); diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index cb93181455d47..a0b03a740d74c 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -56,13 +56,13 @@ static __global__ void quantize_mmq_q8_1( constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32; constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32; - const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4; + const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4; if (i0 >= ne0) { return; } - const int64_t i1 = blockIdx.y; + const int64_t i1 = blockIdx.x; const int64_t i2 = blockIdx.z % ne2; const int64_t i3 = blockIdx.z / ne2; @@ -75,8 +75,8 @@ static __global__ void quantize_mmq_q8_1( block_q8_1_mmq * y = (block_q8_1_mmq *) vy; - const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel - const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y; // block index in channel + const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel + const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel const int64_t iqs = i0 % (4*QK8_1); // quant index in block // Load 4 floats per thread and calculate max. abs. value between them: @@ -166,8 +166,9 @@ void quantize_mmq_q8_1_cuda( GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(ne0 % (4*QK8_1) == 0); - const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); - const dim3 num_blocks(block_num_x, ne1, ne2*ne3); + // ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid: + const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); + const dim3 num_blocks(ne1, block_num_y, ne2*ne3); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1); switch (mmq_get_q8_1_ds_layout(type_src0)) { case MMQ_Q8_1_DS_LAYOUT_D4: diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 31816219c06fd..16e10a9f399d5 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -15,6 +15,32 @@ function(detect_host_compiler) set(HOST_CXX_COMPILER "${HOST_CXX_COMPILER}" PARENT_SCOPE) endfunction() +# Function to test shader extension support +# Parameters: +# EXTENSION_NAME - Name of the extension to test (e.g., "GL_EXT_integer_dot_product") +# TEST_SHADER_FILE - Path to the test shader file +# RESULT_VARIABLE - Name of the variable to set (ON/OFF) based on test result +function(test_shader_extension_support EXTENSION_NAME TEST_SHADER_FILE RESULT_VARIABLE) + execute_process( + COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${TEST_SHADER_FILE}" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error + ) + + if (${glslc_error} MATCHES ".*extension not supported: ${EXTENSION_NAME}.*") + message(STATUS "${EXTENSION_NAME} not supported by glslc") + set(${RESULT_VARIABLE} OFF PARENT_SCOPE) + else() + message(STATUS "${EXTENSION_NAME} supported by glslc") + set(${RESULT_VARIABLE} ON PARENT_SCOPE) + add_compile_definitions(${RESULT_VARIABLE}) + + # Ensure the extension support is forwarded to vulkan-shaders-gen + list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -D${RESULT_VARIABLE}=ON) + set(VULKAN_SHADER_GEN_CMAKE_ARGS "${VULKAN_SHADER_GEN_CMAKE_ARGS}" PARENT_SCOPE) + endif() +endfunction() + if (Vulkan_FOUND) message(STATUS "Vulkan found") @@ -23,69 +49,35 @@ if (Vulkan_FOUND) ../../include/ggml-vulkan.h ) - # Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported. - # If it's not, there will be an error to stderr. - # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) - - if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*") - message(STATUS "GL_KHR_cooperative_matrix not supported by glslc") - set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT OFF) - else() - message(STATUS "GL_KHR_cooperative_matrix supported by glslc") - set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT ON) - add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) - endif() - - # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported. - # If it's not, there will be an error to stderr. - # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) - - if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*") - message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc") - set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT OFF) - else() - message(STATUS "GL_NV_cooperative_matrix2 supported by glslc") - set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT ON) - add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - endif() + set(VULKAN_SHADER_GEN_CMAKE_ARGS + -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR} + -DCMAKE_RUNTIME_OUTPUT_DIRECTORY=${CMAKE_RUNTIME_OUTPUT_DIRECTORY} + ) - # Compile a test shader to determine whether GL_EXT_integer_dot_product is supported. - # If it's not, there will be an error to stderr. - # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) + # Test all shader extensions + test_shader_extension_support( + "GL_KHR_cooperative_matrix" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp" + "GGML_VULKAN_COOPMAT_GLSLC_SUPPORT" + ) - if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_integer_dot_product.*") - message(STATUS "GL_EXT_integer_dot_product not supported by glslc") - set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT OFF) - else() - message(STATUS "GL_EXT_integer_dot_product supported by glslc") - set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT ON) - add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - endif() + test_shader_extension_support( + "GL_NV_cooperative_matrix2" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" + "GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT" + ) - # Compile a test shader to determine whether GL_EXT_bfloat16 is supported. - # If it's not, there will be an error to stderr. - # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) + test_shader_extension_support( + "GL_EXT_integer_dot_product" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp" + "GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT" + ) - if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_bfloat16.*") - message(STATUS "GL_EXT_bfloat16 not supported by glslc") - set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT OFF) - else() - message(STATUS "GL_EXT_bfloat16 supported by glslc") - set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT ON) - add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) - endif() + test_shader_extension_support( + "GL_EXT_bfloat16" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp" + "GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT" + ) target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan) target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) @@ -124,16 +116,8 @@ if (Vulkan_FOUND) add_compile_definitions(GGML_VULKAN_RUN_TESTS) endif() - if (NOT CMAKE_CROSSCOMPILING) - add_subdirectory(vulkan-shaders) - if (MSVC) - foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES}) - string(TOUPPER ${CONFIG} CONFIG) - set_target_properties(vulkan-shaders-gen PROPERTIES - RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) - endforeach() - endif() - else() + # Set up toolchain for host compilation whether cross-compiling or not + if (CMAKE_CROSSCOMPILING) if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN) set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN}) else() @@ -146,25 +130,31 @@ if (Vulkan_FOUND) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY) set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake) endif() - message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}") + else() + # For non-cross-compiling, use empty toolchain (use host compiler) + set(HOST_CMAKE_TOOLCHAIN_FILE "") + endif() + + # Always use ExternalProject_Add approach + include(ExternalProject) - include(ExternalProject) - # Native build through ExternalProject_Add - ExternalProject_Add( - vulkan-shaders-gen - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders - CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE} - -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR} - -DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT} - -DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT} - -DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT} - -DGGML_VULKAN_BFLOAT16_GLSLC_SUPPORT=${GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT} - BUILD_COMMAND ${CMAKE_COMMAND} --build . - INSTALL_COMMAND ${CMAKE_COMMAND} --install . - INSTALL_DIR ${CMAKE_BINARY_DIR} - ) - ExternalProject_Add_StepTargets(vulkan-shaders-gen build install) + # Add toolchain file if cross-compiling + if (CMAKE_CROSSCOMPILING) + list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE}) + message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}") endif() + + # Native build through ExternalProject_Add + ExternalProject_Add( + vulkan-shaders-gen + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders + CMAKE_ARGS ${VULKAN_SHADER_GEN_CMAKE_ARGS} + BUILD_COMMAND ${CMAKE_COMMAND} --build . + INSTALL_COMMAND ${CMAKE_COMMAND} --install . + INSTALL_DIR ${CMAKE_BINARY_DIR} + ) + ExternalProject_Add_StepTargets(vulkan-shaders-gen build install) + set (_ggml_vk_host_suffix $,.exe,>) set (_ggml_vk_genshaders_cmd ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/vulkan-shaders-gen${_ggml_vk_host_suffix}) set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp) @@ -175,9 +165,8 @@ if (Vulkan_FOUND) file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp") set (_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen) - if (CMAKE_CROSSCOMPILING) - set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install) - endif() + # Add build and install dependencies for all builds + set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install) add_custom_command( OUTPUT ${_ggml_vk_header} diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index e2b357fdc1531..0856a1122832d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -288,6 +288,9 @@ struct vk_device_struct { bool coopmat_acc_f32_support {}; bool coopmat_acc_f16_support {}; bool coopmat_bf16_support {}; + bool coopmat_support_16x16x16_f16acc {}; + bool coopmat_support_16x16x16_f32acc {}; + bool coopmat1_fa_support {}; uint32_t coopmat_m; uint32_t coopmat_n; uint32_t coopmat_k; @@ -410,6 +413,13 @@ struct vk_device_struct { vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2]; @@ -1588,19 +1598,36 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events ); } +enum FaCodePath { + FA_SCALAR, + FA_COOPMAT1, + FA_COOPMAT2, +}; + // number of rows/cols for flash attention shader static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; static constexpr uint32_t scalar_flash_attention_num_large_rows = 8; -static uint32_t get_fa_num_small_rows(bool scalar) { - return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows; +// The FA coopmat1 shader assumes 16x16x16 matrix multiply support. +// 128 threads split into four subgroups, each subgroup does 1/4 +// of the Bc dimension. +static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16; +static constexpr uint32_t scalar_flash_attention_Bc = 64; +static constexpr uint32_t scalar_flash_attention_workgroup_size = 128; + +static uint32_t get_fa_num_small_rows(FaCodePath path) { + if (path == FA_COOPMAT2) { + return flash_attention_num_small_rows; + } else { + return scalar_flash_attention_num_small_rows; + } } -static std::array fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { +static std::array fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { GGML_UNUSED(clamp); - if (scalar) { + if (path == FA_SCALAR) { if (small_rows) { return {scalar_flash_attention_num_small_rows, 64}; } else { @@ -1608,9 +1635,17 @@ static std::array fa_rows_cols(bool scalar, uint32_t D, uint32_t cl } } + if (path == FA_COOPMAT1) { + if (small_rows) { + return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc}; + } else { + return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc}; + } + } + // small rows, large cols if (small_rows) { - return {get_fa_num_small_rows(scalar), 32}; + return {get_fa_num_small_rows(FA_COOPMAT2), 32}; } // small cols to reduce register count @@ -1907,17 +1942,19 @@ static void ggml_vk_load_shaders(vk_device& device) { parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); }; - auto const &fa_wg_denoms = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { - return {fa_rows_cols(scalar, D, clamp, type, small_rows)[0], 1, 1}; + auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { + return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1}; }; - auto const &fa_spec_constants = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { + auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { // For large number of rows, 128 invocations seems to work best. // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we // can't use 256 for D==80. // For scalar, use 128 (arbitrary) - uint32_t wg_size = scalar ? 128 : ((small_rows && (D % 32) == 0) ? 256 : 128); - auto rows_cols = fa_rows_cols(scalar, D, clamp, type, small_rows); + uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1) + ? scalar_flash_attention_workgroup_size + : ((small_rows && (D % 32) == 0) ? 256 : 128); + auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows); // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. @@ -1929,36 +1966,43 @@ static void ggml_vk_load_shaders(vk_device& device) { return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split}; }; -#define CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, D) \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \ - -#define CREATE_FA(TYPE, NAMELC, SCALAR, SUFFIX) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 64) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 80) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 96) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 112) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 128) \ - CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 256) - - CREATE_FA(GGML_TYPE_F16, f16, true, ) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, true, ) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, true, ) +#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + +#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256) + + CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (device->coopmat1_fa_support) { + CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) + } +#endif #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (device->coopmat2) { - CREATE_FA(GGML_TYPE_F16, f16, false, _cm2) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, false, _cm2) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, false, _cm2) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, false, _cm2) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, false, _cm2) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, false, _cm2) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, false, _cm2) + CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2) } #endif #undef CREATE_FA2 @@ -2041,17 +2085,17 @@ static void ggml_vk_load_shaders(vk_device& device) { // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ // Create 2 variants, {f16,f32} accumulator #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ @@ -3009,6 +3053,11 @@ static vk_device ggml_vk_get_device(size_t idx) { #if defined(VK_KHR_cooperative_matrix) device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; + + // coopmat1 fa shader currently assumes 32 invocations per subgroup + device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support && + device->subgroup_size_control && device->subgroup_min_size <= 32 && + device->subgroup_max_size >= 32; #endif if (coopmat2_support) { @@ -3143,6 +3192,9 @@ static vk_device ggml_vk_get_device(size_t idx) { // Only enable if shape is identical device->coopmat_acc_f32_support = true; } + if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) { + device->coopmat_support_16x16x16_f32acc = true; + } } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 && (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) { // coopmat sizes not set yet @@ -3155,6 +3207,9 @@ static vk_device ggml_vk_get_device(size_t idx) { // Only enable if shape is identical device->coopmat_acc_f16_support = true; } + if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) { + device->coopmat_support_16x16x16_f16acc = true; + } } } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 && (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 && @@ -5688,6 +5743,36 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) { + // Needs to be kept up to date on shader changes + const uint32_t wg_size = scalar_flash_attention_workgroup_size; + const uint32_t Br = scalar_flash_attention_num_large_rows; + const uint32_t Bc = scalar_flash_attention_Bc; + + const uint32_t acctype = f32acc ? 4 : 2; + const uint32_t f16vec4 = 8; + + const uint32_t tmpsh = wg_size * sizeof(float); + const uint32_t tmpshv4 = wg_size * 4 * acctype; + + const uint32_t Qf = Br * (D / 4 + 2) * f16vec4; + + const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br; + const uint32_t sfsh = Bc * sfshstride * acctype; + + const uint32_t kshstride = D / 4 + 2; + const uint32_t ksh = Bc * kshstride * f16vec4; + + const uint32_t slope = Br * sizeof(float); + + const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); + + return supported; +} + static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3]; std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3]; @@ -5738,7 +5823,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx assert(q->type == GGML_TYPE_F32); assert(k->type == v->type); - bool scalar = !ctx->device->coopmat2; + FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 : + ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; + + if (path == FA_COOPMAT1) { + const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || + (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); + + const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32); + + if (!coopmat_shape_supported || !coopmat_shmem_supported) { + path = FA_SCALAR; + } + } uint32_t gqa_ratio = 1; uint32_t qk_ratio = neq2 / nek2; @@ -5746,9 +5843,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t workgroups_y = (uint32_t)neq2; uint32_t workgroups_z = (uint32_t)neq3; - // For scalar FA, we can use the "large" size to accommodate qga. - // For coopmat FA, we always use the small size (which is still pretty large for gqa). - const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false); + // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. + // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). + uint32_t max_gqa; + switch (path) { + case FA_SCALAR: + case FA_COOPMAT1: + // We may switch from coopmat1 to scalar, so use the scalar limit for both + max_gqa = scalar_flash_attention_num_large_rows; + break; + case FA_COOPMAT2: + max_gqa = get_fa_num_small_rows(FA_COOPMAT2); + break; + default: + GGML_ASSERT(0); + } if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) { @@ -5761,11 +5870,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } vk_pipeline *pipelines; - // XXX TODO other backends may be changing accumulator precision to default to f32 soon - bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32; - bool small_rows = N <= get_fa_num_small_rows(scalar); + bool small_rows = N <= get_fa_num_small_rows(path); - if (scalar) { + if (small_rows && path == FA_COOPMAT1) { + path = FA_SCALAR; + } + + bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; + + switch (path) { + case FA_SCALAR: switch (D) { case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; @@ -5777,7 +5891,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx GGML_ASSERT(!"unsupported D value"); return; } - } else { + break; + case FA_COOPMAT1: + switch (D) { + case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break; + case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break; + case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break; + case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break; + case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break; + case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break; + default: + GGML_ASSERT(!"unsupported D value"); + return; + } + break; + case FA_COOPMAT2: switch (D) { case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break; case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break; @@ -5789,6 +5917,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx GGML_ASSERT(!"unsupported D value"); return; } + break; + default: + GGML_ASSERT(0); } assert(pipelines); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt index ad13f69b3fb1a..e60e9d1e5b5c5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -5,18 +5,35 @@ find_package (Threads REQUIRED) if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + message(STATUS "Enabling coopmat glslc support") endif() if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + message(STATUS "Enabling coopmat2 glslc support") endif() if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + message(STATUS "Enabling dot glslc support") endif() if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + message(STATUS "Enabling bfloat16 glslc support") endif() + set(TARGET vulkan-shaders-gen) add_executable(${TARGET} vulkan-shaders-gen.cpp) install(TARGETS ${TARGET} RUNTIME) target_compile_features(${TARGET} PRIVATE cxx_std_17) target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads) + +# Configure output directories for MSVC builds +if(MSVC) + # Get the main project's runtime output directory if possible + if(DEFINED CMAKE_RUNTIME_OUTPUT_DIRECTORY) + foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES}) + string(TOUPPER ${CONFIG} CONFIG) + set_target_properties(${TARGET} PROPERTIES + RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) + endforeach() + endif() +endif() diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index e6545160d5394..1683557681439 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -12,6 +12,7 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +layout (constant_id = 0) const uint32_t WorkGroupSize = 128; layout (constant_id = 1) const uint32_t Br = 1; layout (constant_id = 2) const uint32_t Bc = 32; layout (constant_id = 3) const uint32_t D = 32; @@ -19,7 +20,7 @@ layout (constant_id = 3) const uint32_t D = 32; layout (constant_id = 5) const uint32_t D_split = 16; const uint32_t D_per_thread = D / D_split; -const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split; +const uint32_t cols_per_iter = WorkGroupSize / D_split; const uint32_t cols_per_thread = Bc / cols_per_iter; layout (push_constant) uniform parameter { @@ -134,8 +135,8 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i return ACC_TYPE(pow(base, ACC_TYPE(exph))); } -shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; -shared vec4 tmpshv4[gl_WorkGroupSize.x]; +shared FLOAT_TYPE tmpsh[WorkGroupSize]; +shared vec4 tmpshv4[WorkGroupSize]; shared float masksh[Bc][Br]; shared vec4 Qf[Br][D / 4]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp new file mode 100644 index 0000000000000..8b86b623bd94d --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -0,0 +1,506 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable + +#include "types.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (constant_id = 1) const uint32_t Br = 1; +layout (constant_id = 2) const uint32_t Bc = 32; +layout (constant_id = 3) const uint32_t D = 32; + +layout (constant_id = 5) const uint32_t D_split = 16; + +const uint32_t D_per_thread = D / D_split; +const uint32_t row_split = 4; +const uint32_t rows_per_thread = Br / row_split; +const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; +const uint32_t cols_per_thread = Bc / cols_per_iter; + +layout (push_constant) uniform parameter { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t nb21; + uint32_t nb22; + uint32_t nb23; + uint32_t nb31; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask; + uint32_t n_head_log2; + float m0; + float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; +} p; + +layout (binding = 0) readonly buffer Q {float data_q[];}; +layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; +layout (binding = 1) readonly buffer K {float16_t data_k[];}; +layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; +layout (binding = 2) readonly buffer V {float16_t data_v[];}; +layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; +layout (binding = 3) readonly buffer M {float16_t data_m[];}; +layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; + +#if defined(A_TYPE_PACKED16) +#define BINDING_IDX_K 0 +#define BINDING_IDX_V 1 +layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; +#endif + +#if defined(DATA_A_Q4_0) +#define BLOCK_BYTE_SIZE 18 + +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); +} +#endif + +#if defined(DATA_A_Q8_0) +#define BLOCK_BYTE_SIZE 34 +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + uint32_t offset = (iq2 + r) * D + c; + data_o[o_offset + offset] = D_TYPE(elem); + return elem; +} + +// Store column zero. This is used to save per-row m and L values for split_k. +ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c == 0) { + uint32_t offset = iq2 + r; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Load the slope matrix, indexed by Q's dimension 2. +ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r % p.gqa_ratio); + + const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); + + return ACC_TYPE(pow(base, ACC_TYPE(exph))); +} + +// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd +const uint32_t MatBr = 16; +const uint32_t MatBc = 16; + +shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; +shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x]; + +const uint32_t qstride = D / 4 + 2; // in units of f16vec4 +shared f16vec4 Qf[Br * qstride]; + +// Avoid padding for D==256 to make it fit in 48KB shmem. +const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br; +shared ACC_TYPE sfsh[Bc * sfshstride]; + +const uint32_t kshstride = D / 4 + 2; // in units of f16vec4 +shared f16vec4 ksh[Bc * kshstride]; + +shared float slope[Br]; + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + const uint32_t tid = gl_LocalInvocationIndex; + const uint32_t N = p.N; + const uint32_t KV = p.KV; + + const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; + const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; + const uint32_t d_tid = gl_LocalInvocationIndex % D_split; + const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split; + +#define tile_row(r) (row_tid * rows_per_thread + (r)) + + uint32_t i = gl_WorkGroupID.x; + uint32_t split_k_index = 0; + + if (p.k_num > 1) { + i = 0; + split_k_index = gl_WorkGroupID.x; + } + + const uint32_t Tr = CEIL_DIV(N, Br); + + const uint32_t start_j = split_k_index * p.split_kv / Bc; + const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); + + // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. + // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. + const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; + const uint32_t iq3 = gl_WorkGroupID.z; + + // broadcast factors + const uint32_t rk2 = p.neq2/p.nek2; + const uint32_t rk3 = p.neq3/p.nek3; + + const uint32_t rv2 = p.neq2/p.nev2; + const uint32_t rv3 = p.neq3/p.nev3; + + // k indices + const uint32_t ik3 = iq3 / rk3; + const uint32_t ik2 = iq2 / rk2; + + // v indices + const uint32_t iv3 = iq3 / rv3; + const uint32_t iv2 = iq2 / rv2; + + // nb?1 are already divided by the type size and are in units of elements. + // When using grouped query attention, Q is indexed by iq2, so the stride + // should be nb02 (which is in bytes). + uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; + uint32_t k_stride = p.nb11; + uint32_t v_stride = p.nb21; + // When using grouped query attention, all rows use the same mask (stride 0). + // "p.gqa_ratio >> 16" is just a roundabout way of writing zero + // that prevents the compiler from folding the "&" through the select + // and breaking the alignment detection. + uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; + + uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + + [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (D / 4); + uint32_t r = (idx + tid) / (D / 4); + if (r < Br && d < D / 4 && + i * Br + r < N) { + Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); + } + } + barrier(); + + ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4]; + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = ACC_TYPEV4(0.0); + } + } + + float Lf[rows_per_thread], Mf[rows_per_thread]; + + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. + const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lf[r] = 0; + Mf[r] = NEG_FLT_MAX_OVER_2; + } + + // ALiBi + if (p.max_bias > 0.0f) { + if (tid < Br) { + uint r = tid; + slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); + } + barrier(); + } else { + if (tid < Br) { + uint r = tid; + slope[r] = 1.0; + } + barrier(); + } + +#if BLOCK_SIZE > 1 + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; +#else + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; +#endif + + [[dont_unroll]] + for (uint32_t j = start_j; j < end_j; ++j) { + + [[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (D / 4); + uint32_t c = (idx + tid) / (D / 4); + if (c < Bc && d < D / 4) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); +#else + f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); +#endif + + ksh[c * kshstride + d] = K_Tf; + } + } + barrier(); + + // K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br + // Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 + // This is written transposed in order to allow for N being 8 if implementations need it + coopmat SfMat = coopmat(0); + coopmat KMat; + coopmat QMat; + + for (uint32_t d = 0; d < D / 16; ++d) { + coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); + + uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; + coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + + SfMat = coopMatMulAdd(KMat, QMat, SfMat); + } + + uint coord = gl_SubgroupID * MatBc * sfshstride; + coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor); + barrier(); + + if (p.logit_softcap != 0.0f) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / Br; + uint32_t r = (idx + tid) % Br; + if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { + sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r])); + } + } + barrier(); + } + + if (p.mask != 0) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { + sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)])); + } + } + barrier(); + } + + float eMf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride]; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride])); + } + float Moldf = Mf[r]; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf, Moldf); + eMf[r] = exp(Moldf - Mf[r]); + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = float16_t(eMf[r]) * Of[r][d]; + } + } + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lf[r] = eMf[r]*Lf[r]; + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + float Pf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]); + Lf[r] += Pf[r]; + } + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); +#else + vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] += float16_t(Pf[r]) * ACC_TYPEV4(Vf); + } + } + } + + barrier(); + } + + // reduce across threads + + float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + FLOAT_TYPE M = Mf[r]; + tmpsh[tid] = M; + // Compute max across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + M = max(M, tmpsh[tid ^ s]); + barrier(); + tmpsh[tid] = M; + barrier(); + } + rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Moldf[r] = Mf[r]; + + // M = max(rowmax, Mold) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf[r], Moldf[r]); + eMf[r] = exp(Moldf[r] - Mf[r]); + + Lf[r] = eMf[r]*Lf[r]; + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + FLOAT_TYPE L = Lf[r]; + tmpsh[tid] = L; + // Compute sum across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + L += tmpsh[tid ^ s]; + barrier(); + tmpsh[tid] = L; + barrier(); + } + Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + + Of[r][d] = float16_t(eMf[r]) * Of[r][d]; + tmpshv4[tid] = Of[r][d]; + + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + Of[r][d] += tmpshv4[tid ^ s]; + barrier(); + tmpshv4[tid] = Of[r][d]; + barrier(); + } + Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + } + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. + if (p.k_num > 1) { + uint32_t o_offset = D * p.ne1 * split_k_index; + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + } + } + } + } + + o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } + } + + return; + } + + float Lfrcp[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lfrcp[r] = 1.0 / Lf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] *= float16_t(Lfrcp[r]); + } + } + + uint32_t o_offset = iq3*p.ne2*p.ne1; + + if (p.gqa_ratio > 1) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + } + } + } + } + } else { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (i * Br + tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + } + } + } + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index d196137eb292e..9361e2ac83b0f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -215,7 +215,7 @@ static std::mutex compile_count_mutex; static std::condition_variable compile_count_cond; void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { - std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); + std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); std::string out_fname = join_paths(output_dir, name + ".spv"); std::string in_path = join_paths(input_dir, in_fname); @@ -424,6 +424,7 @@ void process_shaders() { // flash attention for (const auto& f16acc : {false, true}) { std::string acctype = f16acc ? "float16_t" : "float"; + std::string acctypev4 = f16acc ? "f16vec4" : "vec4"; for (const auto& tname : type_names) { if (tname == "f32") { @@ -440,6 +441,16 @@ void process_shaders() { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); } +#endif +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); + } #endif if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", diff --git a/include/llama.h b/include/llama.h index 99e5fba244fcc..0cb048d2bcab9 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1263,6 +1263,9 @@ extern "C" { /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed); + /// @details Smoothing sampling as described in https://github.com/ggml-org/llama.cpp/pull/6445 + LLAMA_API struct llama_sampler * llama_sampler_init_smoothing (float factor, float curve); + /// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641 LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float n); diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py index fc93bf62a67fa..a1013c3b7a66d 100755 --- a/scripts/compare-llama-bench.py +++ b/scripts/compare-llama-bench.py @@ -113,7 +113,7 @@ help_s = ( "Columns to add to the table. " "Accepts a comma-separated list of values. " - f"Legal values: {', '.join(KEY_PROPERTIES[:-2])}. " + f"Legal values: {', '.join(KEY_PROPERTIES[:-3])}. " "Defaults to model name (model_type) and CPU and/or GPU name (cpu_info, gpu_info) " "plus any column where not all data points are the same. " "If the columns are manually specified, then the results for each unique combination of the " @@ -505,7 +505,7 @@ def valid_format(data_files: list[str]) -> bool: show = known_args.show.split(",") unknown_cols = [] for prop in show: - if prop not in KEY_PROPERTIES[:-2]: # Last two values are n_prompt, n_gen. + if prop not in KEY_PROPERTIES[:-3]: # Last three values are n_prompt, n_gen, n_depth. unknown_cols.append(prop) if unknown_cols: logger.error(f"Unknown values for --show: {', '.join(unknown_cols)}") diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 62246c10dab08..a3b84a6a82e74 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1704,10 +1704,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { } } - LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); llama_kv_cache * kv_self = static_cast(memory.get()); - kv_self->state_write(io); + if (kv_self != nullptr) { + LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); + kv_self->state_write(io); + } return io.n_bytes(); } diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 4cce51668b42d..ee5208334d3b0 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -822,13 +822,18 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps mappings.reserve(files.size()); mmaps_used.reserve(files.size()); for (const auto & file : files) { - auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU)); - if (!reg) { - throw std::runtime_error(format("%s: no CPU backend found", __func__)); + bool is_numa = false; + + auto * dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (dev) { + auto * reg = ggml_backend_dev_backend_reg(dev); + auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa"); + if (is_numa_fn) { + is_numa = is_numa_fn(); + } } - auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa"); - std::unique_ptr mapping = std::make_unique(file.get(), prefetch ? -1 : 0, is_numa_fn()); + std::unique_ptr mapping = std::make_unique(file.get(), prefetch ? -1 : 0, is_numa); mmaps_used.emplace_back(mapping->size(), 0); if (mlock_mmaps) { std::unique_ptr mlock_mmap(new llama_mlock()); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index f652f4b861d1f..7fd094b63f269 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -12218,6 +12218,9 @@ struct llm_build_granite : public llm_graph_context { // inp_pos - built only if rope enabled ggml_tensor * inp_pos = nullptr; + if (use_rope) { + inp_pos = build_inp_pos(); + } auto * inp_attn = build_attn_inp_kv_unified(); @@ -12260,10 +12263,6 @@ struct llm_build_granite : public llm_graph_context { Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); if (use_rope) { - - if (!inp_pos) { - inp_pos = build_inp_pos(); - } ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 804b11e0a943e..aef7443cee0f0 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1206,6 +1206,68 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, ); } +// smoothing + +struct llama_sampler_smoothing { + const float factor; + const float curve; +}; + +static const char * llama_sampler_smoothing_name(const struct llama_sampler * /*smpl*/) { + return "smoothing"; +} + +static void llama_sampler_smoothing_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_smoothing *) smpl->ctx; + + // no need to do anything if there is only one (or zero) candidates + if (cur_p->size <= 1) { + return; + } + + if (ctx->factor > 0.0f) { + llama_sampler_softmax_impl(cur_p); + float h = cur_p->data[0].logit; // Find the maximum logit for h to be added after the transformation + + // Apply the modified quadratic transformation using the smoothing_factor and smoothing_curve + for (size_t i = 0; i < cur_p->size; ++i) { + float logit_shifted = cur_p->data[i].logit - h; + float k = (3 - ctx->curve) / 2; + float s = (ctx->curve - 1) / 2; + cur_p->data[i].logit = -(k * ctx->factor * logit_shifted * logit_shifted) + (s * ctx->factor * logit_shifted * logit_shifted * logit_shifted) + h; + } + llama_sampler_softmax_impl(cur_p); + } +} + +static struct llama_sampler * llama_sampler_smoothing_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_smoothing *) smpl->ctx; + return llama_sampler_init_smoothing(ctx->factor, ctx->curve); +} + +static void llama_sampler_smoothing_free(struct llama_sampler * smpl) { + delete (llama_sampler_smoothing *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_smoothing_i = { + /* .name = */ llama_sampler_smoothing_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_smoothing_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_smoothing_clone, + /* .free = */ llama_sampler_smoothing_free, +}; + +struct llama_sampler * llama_sampler_init_smoothing(float factor, float curve) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_smoothing_i, + /* .ctx = */ new llama_sampler_smoothing { + /* .smoothing_factor = */ factor, + /* .smoothing_curve = */ curve + } + ); +} + // mirostat struct llama_sampler_mirostat { diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 60ac62b385f35..90f68718fad53 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -126,6 +126,17 @@ static void test_xtc(const std::vector & probs, const std::vector tester.check(); } +static void test_smoothing(const std::vector & probs, const std::vector & probs_expected, float factor, float curve) { + sampler_tester tester(probs, probs_expected); + + DUMP(&tester.cur_p); + tester.apply(llama_sampler_init_smoothing(factor, curve)); + tester.apply(llama_sampler_init_dist (0)); + DUMP(&tester.cur_p); + + tester.check(); +} + static void test_typical(const std::vector & probs, const std::vector & probs_expected, float p) { sampler_tester tester(probs, probs_expected); @@ -341,6 +352,9 @@ int main(void) { printf("XTC should not:\n"); test_xtc({0.4f, 0.3f, 0.2f, 0.1f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0.99f, 0.39f); + test_smoothing({0.1f, 0.2f, 0.3f, 0.4f}, {0.372382f, 0.342804f, 0.230319f, 0.054495f}, 1.0f, 1.0f); + test_smoothing({0.1f, 0.2f, 0.3f, 0.4f}, {0.368339f, 0.349226f, 0.245247f, 0.037188f}, 1.0f, 2.0f); + test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f); test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f); diff --git a/tools/main/README.md b/tools/main/README.md index 4f16ad6b2b10e..5e028adee2df8 100644 --- a/tools/main/README.md +++ b/tools/main/README.md @@ -296,6 +296,21 @@ Being experimental and unique, XTC is disabled by default. The recommended combi Example usage: `--xtc-probability 0.5 --xtc-threshold 0.1` +### Smoothing / Quadratic Sampling + +- `--smoothing-factor N`: Set the smoothing factor for smoothing / quadratic sampling (default: 0.0). +- `--smoothing-curve N`: Set the cubic transformation curve for smoothing / quadratic sampling (default: 1.0). + +(Source: https://github.com/ggml-org/llama.cpp/pull/6445) + +Smoothing / Quadratic Sampling, as described in the [original PR](https://github.com/ggml-org/llama.cpp/pull/6445), is a sampler that changes the probability distribution of tokens in a non-linear fashion. This sampler does not remote any tokens; instead, it tweaks the original logit scores of each token based on the distance from the topmost logit using quadratic transformation. This can be viewed as an alternative to Temperature that scales differently while still punishing extreme outlier tokens. + +By performing a non-linear transformation on token logits, we can effectively avoid biasing towards the topmost token if there is a group of similar probability tokens at the top, thus creating more variance. Higher values of `smoothing factor` would result in more deterministic output, while lower values would boost the creativity of the model. "Smoothing Factor" values of 0.2-0.3 are generally thought to be good for creative writing. It is worth noting that a smoothing factor value of `0.0` disables the sampler completely. + +`smoothing curve` is a second hyperparameter that adds a cubic transformation on top of the original quadratic one, and can "help make lower `smoothing factor` values work if the curve is set higher. A smoothing curve value of `1.0` is equivalant of using just quadratic transformation. + +This sampler is not mutually exclusive with Temperature, they can be used together. + ### Top-nσ Sampling - `--top-nsigma N`: Limit the next token selection to a subset of tokens with pre-softmax logits that are within n * σ less than the max logit (default: -1, -1 = disabled). diff --git a/tools/server/README.md b/tools/server/README.md index 7b944c35ba479..9a9ffe3d482aa 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -105,7 +105,7 @@ The project is under active development, and we are [looking for feedback and co | Argument | Explanation | | -------- | ----------- | -| `--samplers SAMPLERS` | samplers that will be used for generation in the order, separated by ';'
(default: penalties;dry;top_n_sigma;top_k;typ_p;top_p;min_p;xtc;temperature) | +| `--samplers SAMPLERS` | samplers that will be used for generation in the order, separated by ';'
(default: penalties;dry;top_n_sigma;top_k;typ_p;top_p;min_p;xtc;smoothing;temperature) | | `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) | | `--sampling-seq, --sampler-seq SEQUENCE` | simplified sequence for samplers that will be used (default: edskypmxt) | | `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) | @@ -115,6 +115,8 @@ The project is under active development, and we are [looking for feedback and co | `--min-p N` | min-p sampling (default: 0.1, 0.0 = disabled) | | `--xtc-probability N` | xtc probability (default: 0.0, 0.0 = disabled) | | `--xtc-threshold N` | xtc threshold (default: 0.1, 1.0 = disabled) | +| `--smoothing-factor N` | smoothing factor (default: 0.0, 0.0 = disabled) | +| `--smoothing-curve N` | smoothing curve (default: 1.0, 1.0 = disabled) | | `--typical N` | locally typical sampling, parameter p (default: 1.0, 1.0 = disabled) | | `--repeat-last-n N` | last n tokens to consider for penalize (default: 64, 0 = disabled, -1 = ctx_size) | | `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled) | @@ -447,6 +449,10 @@ These words will not be included in the completion, so make sure to add them to `xtc_threshold`: Set a minimum probability threshold for tokens to be removed via XTC sampler. Default: `0.1` (> `0.5` disables XTC) +`smoothing_factor`: Set the smoothing factor for smoothing / quadratic sampling. Default: `0.0`, which is disabled. + +`smoothing_curve`: Set the cubic transformation curve for smoothing / quadratic sampling. Default: `1.0`, which makes it behaves like the original quadratic sampler. + `mirostat`: Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0. `mirostat_tau`: Set the Mirostat target entropy, parameter tau. Default: `5.0` @@ -477,7 +483,7 @@ These words will not be included in the completion, so make sure to add them to `return_tokens`: Return the raw generated token ids in the `tokens` field. Otherwise `tokens` remains empty. Default: `false` -`samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. Default: `["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]` - these are all the available values. +`samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. Default: `["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "smoothing", "temperature"]` - these are all the available values. `timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false` @@ -726,6 +732,8 @@ This endpoint is public (no API key check). By default, it is read-only. To make "min_p": 0.05000000074505806, "xtc_probability": 0.0, "xtc_threshold": 0.10000000149011612, + "smoothing_factor": 0.0, + "smoothing_curve": 1.0, "typical_p": 1.0, "repeat_last_n": 64, "repeat_penalty": 1.0, @@ -760,6 +768,7 @@ This endpoint is public (no API key check). By default, it is read-only. To make "top_p", "min_p", "xtc", + "smoothing", "temperature" ], "speculative.n_max": 16, @@ -872,6 +881,8 @@ Example: "min_p": 0.05000000074505806, "xtc_probability": 0.0, "xtc_threshold": 0.10000000149011612, + "smoothing_factor": 0.0, + "smoothing_curve": 1.0, "typical_p": 1.0, "repeat_last_n": 64, "repeat_penalty": 1.0, @@ -906,6 +917,7 @@ Example: "top_p", "min_p", "xtc", + "smoothing", "temperature" ], "speculative.n_max": 16, @@ -1040,7 +1052,7 @@ To know the `id` of the adapter, use GET `/lora-adapters` Returns information about the loaded model. See [OpenAI Models API documentation](https://platform.openai.com/docs/api-reference/models). -The returned list always has one single element. +The returned list always has one single element. The `meta` field can be `null` (for example, while the model is still loading). By default, model `id` field is the path to model file, specified via `-m`. You can set a custom value for model `id` field via `--alias` argument. For example, `--alias gpt-4o-mini`. diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz deleted file mode 100644 index d7363e13eb25a..0000000000000 Binary files a/tools/server/public/index.html.gz and /dev/null differ diff --git a/tools/server/public_legacy/index-new.html b/tools/server/public_legacy/index-new.html index cbfbbdf2806fa..95741ab87d375 100644 --- a/tools/server/public_legacy/index-new.html +++ b/tools/server/public_legacy/index-new.html @@ -48,6 +48,8 @@ min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4 xtc_probability: 0.0, // 0 = disabled; xtc_threshold: 0.1, // > 0.5 disables XTC; + smoothing_factor: 0.0, // 0 = disabled; + smoothing_curve: 1.0, // > 0.5 = disabled typical_p: 1.0, // 1.0 = disabled presence_penalty: 0.0, // 0.0 = disabled frequency_penalty: 0.0, // 0.0 = disabled @@ -841,6 +843,8 @@ ${FloatField({ label: "Typical-P", title: "Activates local typical sampling, a method used to limit the prediction of tokens that are atypical in the current context. The parameter p controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })} ${FloatField({ label: "XTC probability", title: "Sets the chance for token removal (checked once on sampler start)", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })} ${FloatField({ label: "XTC threshold", title: "Sets a minimum probability threshold for tokens to be removed", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })} + ${FloatField({ label: "Smoothing Factor", max: 10.0, min: 0.0, name: "smoothing_factor", step: 0.01, value: params.value.smoothing_factor })} + ${FloatField({ label: "Smoothing Curve", max: 10.0, min: 0.0, name: "smoothing_curve", step: 0.01, value: params.value.smoothing_curve })} ${FloatField({ label: "DRY Penalty Multiplier", title: "Set the DRY repetition penalty multiplier. Default is 0.0, which disables DRY.", max: 5.0, min: 0.0, name: "dry_multiplier", step: 0.01, value: params.value.dry_multiplier })} ${FloatField({ label: "DRY Base", title: "Set the DRY repetition penalty base value. Default is 1.75", max: 3.0, min: 1.0, name: "dry_base", step: 0.01, value: params.value.dry_base })} ${IntField({ label: "DRY Allowed Length", title: "Tokens that extend repetition beyond this receive exponentially increasing penalty. Default is 2", max: 10, min: 1, step: 1, name: "dry_allowed_length", value: params.value.dry_allowed_length })} @@ -1143,6 +1147,8 @@

llama.cpp

min_p: { snapValue: 0.05, snapRangeMultiplier: 2 }, xtc_probability: { snapValue: 0.0, snapRangeMultiplier: 4 }, xtc_threshold: { snapValue: 0.5, snapRangeMultiplier: 4 }, + smoothing_factor: { snapValue: 0.0, snapRangeMultiplier: 4 }, + smoothing_curve: { snapValue: 0.0, snapRangeMultiplier: 4 }, top_p: { snapValue: 1.0, snapRangeMultiplier: 4 }, typical_p: { snapValue: 1.0, snapRangeMultiplier: 4 }, repeat_penalty: { snapValue: 1.0, snapRangeMultiplier: 4 }, diff --git a/tools/server/public_legacy/index.html b/tools/server/public_legacy/index.html index 75f39330a789d..f6a94527d1c81 100644 --- a/tools/server/public_legacy/index.html +++ b/tools/server/public_legacy/index.html @@ -312,6 +312,8 @@ min_p: 0.05, // 0 = disabled xtc_probability: 0.0, // 0 = disabled; xtc_threshold: 0.1, // > 0.5 disables XTC; + smoothing_factor: 0.0, // 0 = disabled; + smoothing_curve: 1.0, // > 0.5 = disabled typical_p: 1.0, // 1.0 = disabled presence_penalty: 0.0, // 0.0 = disabled frequency_penalty: 0.0, // 0.0 = disabled @@ -1021,6 +1023,8 @@ ${IntField({ label: "DRY Penalty Last N", max: 2048, min: -1, step: 16, name: "dry_penalty_last_n", value: params.value.dry_penalty_last_n })} ${FloatField({ label: "XTC probability", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })} ${FloatField({ label: "XTC threshold", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })} + ${FloatField({ label: "Smoothing Factor", max: 10.0, min: 0.0, name: "smoothing_factor", step: 0.01, value: params.value.smoothing_factor })} + ${FloatField({ label: "Smoothing Curve", max: 10.0, min: 0.0, name: "smoothing_curve", step: 0.01, value: params.value.smoothing_curve })}
diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 7169ffdceebf9..72fccb8d74079 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -150,6 +150,8 @@ struct slot_params { {"top_n_sigma", sampling.top_n_sigma}, {"xtc_probability", sampling.xtc_probability}, {"xtc_threshold", sampling.xtc_threshold}, + {"smoothing_factor", sampling.smoothing_factor}, + {"smoothing_curve", sampling.smoothing_curve}, {"typical_p", sampling.typ_p}, {"repeat_last_n", sampling.penalty_last_n}, {"repeat_penalty", sampling.penalty_repeat}, @@ -253,6 +255,8 @@ struct server_task { params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); + params.sampling.smoothing_factor = json_value(data, "smoothing_factor", defaults.sampling.smoothing_factor); + params.sampling.smoothing_curve = json_value(data, "smoothing_curve", defaults.sampling.smoothing_curve); params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); @@ -2951,7 +2955,8 @@ struct server_context { llama_kv_self_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard); llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard); - if (slot.params.cache_prompt) { + // add generated tokens to cache + { llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { new_tokens[i - n_discard] = new_tokens[i]; @@ -2996,10 +3001,7 @@ struct server_context { common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); slot.n_past += 1; - - if (slot.params.cache_prompt) { - slot.cache_tokens.push_back(slot.sampled); - } + slot.cache_tokens.push_back(slot.sampled); SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n", slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated); @@ -3171,6 +3173,11 @@ struct server_context { SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past); } + } else { + // if we don't cache the prompt, we have to remove the entire KV cache + llama_kv_self_seq_rm(ctx, slot.id, 0, -1); + slot.n_past = 0; + slot.cache_tokens.clear(); } } @@ -3204,7 +3211,7 @@ struct server_context { SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past); // remove the non-common part from the cache - slot.cache_tokens.resize(slot.n_past); + slot.cache_tokens.keep_first(slot.n_past); // check if we should process the image if (slot.n_past < slot.n_prompt_tokens @@ -3221,7 +3228,8 @@ struct server_context { continue; } - if (slot.params.cache_prompt) { + // add the image chunk to cache + { const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past); slot.cache_tokens.push_back(chunk.get()); // copy } @@ -3242,9 +3250,7 @@ struct server_context { const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); - if (slot.params.cache_prompt) { - slot.cache_tokens.push_back(cur_tok); - } + slot.cache_tokens.push_back(cur_tok); slot.n_prompt_tokens_processed++; slot.n_past++; @@ -3705,6 +3711,9 @@ int main(int argc, char ** argv) { if (req.path == "/" || tmp.back() == "html") { res.set_content(reinterpret_cast(loading_html), loading_html_len, "text/html; charset=utf-8"); res.status = 503; + } else if (req.path == "/models" || req.path == "/v1/models") { + // allow the models endpoint to be accessed during loading + return true; } else { res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); } @@ -4363,7 +4372,13 @@ int main(int argc, char ** argv) { res_ok(res, {{ "prompt", std::move(data.at("prompt")) }}); }; - const auto handle_models = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { + const auto handle_models = [¶ms, &ctx_server, &state, &res_ok](const httplib::Request &, httplib::Response & res) { + server_state current_state = state.load(); + json model_meta = nullptr; + if (current_state == SERVER_STATE_READY) { + model_meta = ctx_server.model_meta(); + } + json models = { {"object", "list"}, {"data", { @@ -4372,7 +4387,7 @@ int main(int argc, char ** argv) { {"object", "model"}, {"created", std::time(0)}, {"owned_by", "llamacpp"}, - {"meta", ctx_server.model_meta()} + {"meta", model_meta}, }, }} }; diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py index 0ed5b99bef4e4..4099c4e25cd6e 100644 --- a/tools/server/tests/unit/test_completion.py +++ b/tools/server/tests/unit/test_completion.py @@ -196,6 +196,18 @@ def test_cache_vs_nocache_prompt(): assert res_cache.body["content"] == res_no_cache.body["content"] +def test_nocache_long_input_prompt(): + global server + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is"*32, + "seed": 42, + "temperature": 1.0, + "cache_prompt": False, + }) + assert res.status_code == 200 + + def test_completion_with_tokens_input(): global server server.temperature = 0.0 diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index b8d140e3f051c..45193c17cfd98 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -1153,7 +1153,7 @@ struct server_tokens { tokens.clear(); } - void resize(size_t n) { + void keep_first(size_t n) { GGML_ASSERT(n <= tokens.size()); if (has_mtmd) { // we throw an error if we try to remove a token in the middle of an image diff --git a/tools/server/webui/package-lock.json b/tools/server/webui/package-lock.json index 2c23a7580b38c..a4a9380c64512 100644 --- a/tools/server/webui/package-lock.json +++ b/tools/server/webui/package-lock.json @@ -44,6 +44,7 @@ "eslint": "^9.17.0", "eslint-plugin-react-hooks": "^5.0.0", "eslint-plugin-react-refresh": "^0.4.16", + "fflate": "^0.8.2", "globals": "^15.14.0", "prettier": "^3.4.2", "sass-embedded": "^1.83.4", @@ -2802,6 +2803,13 @@ "reusify": "^1.0.4" } }, + "node_modules/fflate": { + "version": "0.8.2", + "resolved": "https://registry.npmjs.org/fflate/-/fflate-0.8.2.tgz", + "integrity": "sha512-cPJU47OaAoCbg0pBvzsgpTPhmhqI5eJjh/JIu8tPj5q+T7iLvW/JAYUqmE7KOB4R1ZyEhzBaIQpQpardBF5z8A==", + "dev": true, + "license": "MIT" + }, "node_modules/file-entry-cache": { "version": "8.0.0", "resolved": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-8.0.0.tgz", diff --git a/tools/server/webui/package.json b/tools/server/webui/package.json index ab1b920bdc5d6..1927f696fbbb9 100644 --- a/tools/server/webui/package.json +++ b/tools/server/webui/package.json @@ -5,7 +5,7 @@ "type": "module", "scripts": { "dev": "vite", - "build": "tsc -b && vite build", + "build": "npm run format && tsc -b && vite build", "format": "eslint . && prettier --write .", "lint": "eslint .", "preview": "vite preview" @@ -47,6 +47,7 @@ "eslint": "^9.17.0", "eslint-plugin-react-hooks": "^5.0.0", "eslint-plugin-react-refresh": "^0.4.16", + "fflate": "^0.8.2", "globals": "^15.14.0", "prettier": "^3.4.2", "sass-embedded": "^1.83.4", diff --git a/tools/server/webui/src/Config.ts b/tools/server/webui/src/Config.ts index 5eef608cb96cf..a591639d79e9e 100644 --- a/tools/server/webui/src/Config.ts +++ b/tools/server/webui/src/Config.ts @@ -17,7 +17,7 @@ export const CONFIG_DEFAULT = { showThoughtInProgress: false, excludeThoughtOnReq: true, // make sure these default values are in sync with `common.h` - samplers: 'edkypmxt', + samplers: 'edkypmxqt', temperature: 0.8, dynatemp_range: 0.0, dynatemp_exponent: 1.0, @@ -26,6 +26,8 @@ export const CONFIG_DEFAULT = { min_p: 0.05, xtc_probability: 0.0, xtc_threshold: 0.1, + smoothing_factor: 0.0, + smoothing_curve: 1.0, typical_p: 1.0, repeat_last_n: 64, repeat_penalty: 1.0, @@ -44,7 +46,7 @@ export const CONFIG_INFO: Record = { apiKey: 'Set the API Key if you are using --api-key option for the server.', systemMessage: 'The starting message that defines how model should behave.', samplers: - 'The order at which samplers are applied, in simplified way. Default is "dkypmxt": dry->top_k->typ_p->top_p->min_p->xtc->temperature', + 'The order at which samplers are applied, in simplified way. Default is "dkypmxqt": dry->top_k->typ_p->top_p->min_p->xtc->smoothing->temperature', temperature: 'Controls the randomness of the generated text by affecting the probability distribution of the output tokens. Higher = more random, lower = more focused.', dynatemp_range: @@ -60,6 +62,10 @@ export const CONFIG_INFO: Record = { 'XTC sampler cuts out top tokens; this parameter controls the chance of cutting tokens at all. 0 disables XTC.', xtc_threshold: 'XTC sampler cuts out top tokens; this parameter controls the token probability that is required to cut that token.', + smoothing_factor: + 'Controls the quadratic adjustment in smooth / quadratic sampler', + smoothing_curve: + 'Controls the cubic transformation curve for smoothing / quadratic sampler.', typical_p: 'Sorts and limits tokens based on the difference between log-probability and entropy.', repeat_last_n: 'Last n tokens to consider for penalizing repetition', diff --git a/tools/server/webui/src/components/ChatScreen.tsx b/tools/server/webui/src/components/ChatScreen.tsx index b645a494d6853..7d53fe8ac27ee 100644 --- a/tools/server/webui/src/components/ChatScreen.tsx +++ b/tools/server/webui/src/components/ChatScreen.tsx @@ -1,4 +1,4 @@ -import { useEffect, useMemo, useRef, useState } from 'react'; +import { ClipboardEvent, useEffect, useMemo, useRef, useState } from 'react'; import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context'; import ChatMessage from './ChatMessage'; import { CanvasType, Message, PendingMessage } from '../utils/types'; @@ -328,6 +328,17 @@ function ChatInput({ {({ getRootProps, getInputProps }) => (
) => { + const files = Array.from(e.clipboardData.items) + .filter((item) => item.kind === 'file') + .map((item) => item.getAsFile()) + .filter((file) => file !== null); + + if (files.length > 0) { + e.preventDefault(); + extraContext.onFileAdded(files); + } + }} {...getRootProps()} > {!isGenerating && ( diff --git a/tools/server/webui/src/components/SettingDialog.tsx b/tools/server/webui/src/components/SettingDialog.tsx index b0044d25403b5..afcbd15307f01 100644 --- a/tools/server/webui/src/components/SettingDialog.tsx +++ b/tools/server/webui/src/components/SettingDialog.tsx @@ -29,6 +29,8 @@ const SAMPLER_KEYS: SettKey[] = [ 'typical_p', 'xtc_probability', 'xtc_threshold', + 'smoothing_factor', + 'smoothing_curve', ]; const PENALTY_KEYS: SettKey[] = [ 'repeat_last_n', diff --git a/tools/server/webui/src/utils/app.context.tsx b/tools/server/webui/src/utils/app.context.tsx index 96cffd95aba7c..f7aa35c0e2d75 100644 --- a/tools/server/webui/src/utils/app.context.tsx +++ b/tools/server/webui/src/utils/app.context.tsx @@ -219,6 +219,8 @@ export const AppContextProvider = ({ typical_p: config.typical_p, xtc_probability: config.xtc_probability, xtc_threshold: config.xtc_threshold, + smoothing_factor: config.smoothing_factor, + smoothing_curve: config.smoothing_curve, repeat_last_n: config.repeat_last_n, repeat_penalty: config.repeat_penalty, presence_penalty: config.presence_penalty, diff --git a/tools/server/webui/vite.config.ts b/tools/server/webui/vite.config.ts index 366df3b751c58..fba4da645fa09 100644 --- a/tools/server/webui/vite.config.ts +++ b/tools/server/webui/vite.config.ts @@ -3,7 +3,7 @@ import react from '@vitejs/plugin-react'; import { viteSingleFile } from 'vite-plugin-singlefile'; import path from 'node:path'; import fs from 'node:fs'; -import zlib from 'node:zlib'; +import * as fflate from 'fflate'; /* eslint-disable */ @@ -33,9 +33,10 @@ const BUILD_PLUGINS = [ }, writeBundle() { const outputIndexHtml = path.join(config.build.outDir, 'index.html'); - const content = + let content = GUIDE_FOR_FRONTEND + '\n' + fs.readFileSync(outputIndexHtml, 'utf-8'); - const compressed = zlib.gzipSync(Buffer.from(content, 'utf-8'), { + content = content.replace(/\r/g, ''); // remove windows-style line endings + const compressed = fflate.gzipSync(Buffer.from(content, 'utf-8'), { level: 9, });