
Commit edefab0

lifuhuang and hyhieu authored
[2/2] Support MHA prefill with FlashAttention 4. (#10937)
Co-authored-by: Hieu Pham <[email protected]>
1 parent 97cd38e commit edefab0

7 files changed (+34, -23 lines)


python/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ dependencies = [
   "scipy",
   "sentencepiece",
   "setproctitle",
-  "sgl-kernel==0.3.14.post1",
+  "sgl-kernel==0.3.15",
   "soundfile==0.13.1",
   "tiktoken",
   "timm==1.0.16",

python/pyproject_other.toml

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ tracing = [
 
 srt = [
   "sglang[runtime_common]",
-  "sgl-kernel==0.3.14.post1",
+  "sgl-kernel==0.3.15",
   "torch==2.8.0",
   "torchaudio==2.8.0",
   "torchvision",

python/sglang/srt/entrypoints/engine.py

Lines changed: 1 addition & 1 deletion
@@ -711,7 +711,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.14",
+            "0.3.15",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
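
For context, the guard above refuses to start when an older sgl-kernel is installed. A minimal sketch of what such a minimum-version check amounts to (illustrative only; the real `assert_pkg_version` helper lives in sglang and may differ, e.g. in how it compares versions):

```python
# Illustrative stand-in for a package-version guard like assert_pkg_version().
# Assumes the `packaging` library is available; not the actual sglang helper.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def check_min_pkg_version(pkg: str, min_ver: str, hint: str) -> None:
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed. {hint}")
    if Version(installed) < Version(min_ver):
        raise RuntimeError(f"{pkg}=={installed} is older than {min_ver}. {hint}")


# check_min_pkg_version(
#     "sgl-kernel", "0.3.15",
#     "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
# )
```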

python/sglang/srt/layers/attention/attention_registry.py

Lines changed: 0 additions & 3 deletions
@@ -129,9 +129,6 @@ def create_flashattention_v3_backend(runner):
 
 @register_attention_backend("fa4")
 def create_flashattention_v4_backend(runner):
-    assert (
-        runner.use_mla_backend
-    ), "FlashAttention v4 Support is at an early stage, only MLA model supported now"
     from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
 
     return FlashAttentionBackend(runner, fa_impl_ver=4)
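
Dropping the assert means the registered "fa4" factory is no longer limited to MLA models; MHA models can now construct the FA4 backend too. A hedged sketch of the decorator-registry pattern this file uses (the registry dict and lookup below are simplified placeholders, not sglang's exact internals):

```python
# Simplified sketch of a decorator-based backend registry; names other than
# register_attention_backend / "fa4" are illustrative placeholders.
ATTENTION_BACKENDS = {}


def register_attention_backend(name):
    def decorator(factory):
        ATTENTION_BACKENDS[name] = factory
        return factory

    return decorator


@register_attention_backend("fa4")
def create_flashattention_v4_backend(runner):
    # With the MLA-only assert removed, this factory also serves MHA models,
    # which is what enables FA4 MHA prefill in this commit.
    from sglang.srt.layers.attention.flashattention_backend import (
        FlashAttentionBackend,
    )

    return FlashAttentionBackend(runner, fa_impl_ver=4)


# A runner would then resolve its backend by name:
# backend = ATTENTION_BACKENDS["fa4"](runner)
```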

python/sglang/srt/layers/attention/flashattention_backend.py

Lines changed: 0 additions & 1 deletion
@@ -754,7 +754,6 @@ def forward_extend(
 
         # Use Flash Attention for prefill
         if not self.use_mla:
-            assert self.fa_impl_ver in [3], "Only FA3 support here"
             # Do multi-head attention
             key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id
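
With the FA3-only assert removed from the MHA prefill path, FA4 can be requested for prefill while decode stays on another backend. A hedged usage sketch, assuming the offline Engine forwards these ServerArgs fields as keyword arguments (model path and sampling parameters are placeholders):

```python
# Hedged sketch: FA4 for MHA prefill, FA3 for decode. Assumes sglang's offline
# Engine passes these keyword arguments through to ServerArgs.
import sglang as sgl

llm = sgl.Engine(
    model_path="openai/gpt-oss-20b",      # placeholder MHA model
    prefill_attention_backend="fa4",      # prefill via FlashAttention 4
    decode_attention_backend="fa3",       # decode stays on FlashAttention 3
)
print(llm.generate("Hello", {"max_new_tokens": 8}))
```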

python/sglang/srt/model_executor/model_runner.py

Lines changed: 3 additions & 9 deletions
@@ -1746,16 +1746,10 @@ def init_attention_backend(self):
 
     def _get_attention_backend(self):
         """Init attention kernel backend."""
-        self.decode_attention_backend_str = (
-            self.server_args.decode_attention_backend
-            if self.server_args.decode_attention_backend
-            else self.server_args.attention_backend
-        )
-        self.prefill_attention_backend_str = (
-            self.server_args.prefill_attention_backend
-            if self.server_args.prefill_attention_backend
-            else self.server_args.attention_backend
+        self.prefill_attention_backend_str, self.decode_attention_backend_str = (
+            self.server_args.get_attention_backends()
         )
+
         if self.decode_attention_backend_str != self.prefill_attention_backend_str:
             from sglang.srt.layers.attention.hybrid_attn_backend import (
                 HybridAttnBackend,
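
The refactor replaces two hand-rolled fallbacks with one call to ServerArgs.get_attention_backends(); when the two resolved names differ, the runner wraps them in a hybrid backend. A simplified sketch of that decision (the factory functions below are placeholders for the real registry and HybridAttnBackend machinery):

```python
# Self-contained sketch of the backend-selection flow; make_backend and
# make_hybrid_backend stand in for sglang's registry and hybrid wrapper.
def make_backend(name):
    return f"<backend:{name}>"


def make_hybrid_backend(prefill_name, decode_name):
    return f"<hybrid prefill={prefill_name} decode={decode_name}>"


def resolve_attention_backend(server_args):
    prefill_str, decode_str = server_args.get_attention_backends()
    if prefill_str != decode_str:
        # e.g. fa4 prefill + fa3 decode -> hybrid wrapper
        return make_hybrid_backend(prefill_str, decode_str)
    return make_backend(prefill_str)
```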

python/sglang/srt/server_args.py

Lines changed: 28 additions & 7 deletions
@@ -464,6 +464,19 @@ class ServerArgs:
     enable_pdmux: bool = False
     sm_group_num: int = 3
 
+    def get_attention_backends(server_args):
+        prefill_attention_backend_str = (
+            server_args.prefill_attention_backend
+            if server_args.prefill_attention_backend
+            else server_args.attention_backend
+        )
+        decode_attention_backend_str = (
+            server_args.decode_attention_backend
+            if server_args.decode_attention_backend
+            else server_args.attention_backend
+        )
+        return prefill_attention_backend_str, decode_attention_backend_str
+
     def __post_init__(self):
         """
         Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
@@ -748,20 +761,28 @@ def _handle_model_specific_adjustments(self):
         hf_config = self.get_hf_config()
         model_arch = hf_config.architectures[0]
         if model_arch in ["GptOssForCausalLM"]:
-            if self.attention_backend is None:
+            if (
+                self.attention_backend is None
+                and self.prefill_attention_backend is None
+                and self.decode_attention_backend is None
+            ):
                 if is_cuda() and is_sm100_supported():
                     self.attention_backend = "trtllm_mha"
                 elif is_cuda() and is_sm90_supported():
                     self.attention_backend = "fa3"
                 else:
                     self.attention_backend = "triton"
-            supported_backends = ["triton", "trtllm_mha", "fa3"]
-            logger.info(
-                f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
-            )
+
+            supported_backends = ["triton", "trtllm_mha", "fa3", "fa4"]
+            prefill_attn_backend, decode_attn_backend = self.get_attention_backends()
             assert (
-                self.attention_backend in supported_backends
-            ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
+                prefill_attn_backend in supported_backends
+                and decode_attn_backend in supported_backends
+            ), (
+                f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got the following backends\n"
+                f"- Prefill: {prefill_attn_backend}\n"
+                f"- Decode: {decode_attn_backend}\n"
+            )
 
             if is_sm100_supported():
                 if not self.enable_dp_attention:
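
The new get_attention_backends() helper centralizes the per-phase fallback: an explicit --prefill-attention-backend or --decode-attention-backend wins, otherwise both phases inherit --attention-backend. A small self-contained demonstration of that rule (FakeArgs is a stand-in for ServerArgs, not the real class):

```python
# Stand-in for ServerArgs that mirrors the fallback rule shown in the diff.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeArgs:
    attention_backend: Optional[str] = None
    prefill_attention_backend: Optional[str] = None
    decode_attention_backend: Optional[str] = None

    def get_attention_backends(self):
        prefill = self.prefill_attention_backend or self.attention_backend
        decode = self.decode_attention_backend or self.attention_backend
        return prefill, decode


print(FakeArgs(attention_backend="fa3").get_attention_backends())
# ('fa3', 'fa3')

print(
    FakeArgs(
        attention_backend="fa3", prefill_attention_backend="fa4"
    ).get_attention_backends()
)
# ('fa4', 'fa3') -> mixed prefill/decode, handled by the hybrid path in model_runner
```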
