@@ -464,6 +464,19 @@ class ServerArgs:
     enable_pdmux: bool = False
     sm_group_num: int = 3
 
+    def get_attention_backends(server_args):
+        prefill_attention_backend_str = (
+            server_args.prefill_attention_backend
+            if server_args.prefill_attention_backend
+            else server_args.attention_backend
+        )
+        decode_attention_backend_str = (
+            server_args.decode_attention_backend
+            if server_args.decode_attention_backend
+            else server_args.attention_backend
+        )
+        return prefill_attention_backend_str, decode_attention_backend_str
+
     def __post_init__(self):
         """
         Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
@@ -748,20 +761,28 @@ def _handle_model_specific_adjustments(self):
         hf_config = self.get_hf_config()
         model_arch = hf_config.architectures[0]
         if model_arch in ["GptOssForCausalLM"]:
-            if self.attention_backend is None:
+            if (
+                self.attention_backend is None
+                and self.prefill_attention_backend is None
+                and self.decode_attention_backend is None
+            ):
                 if is_cuda() and is_sm100_supported():
                     self.attention_backend = "trtllm_mha"
                 elif is_cuda() and is_sm90_supported():
                     self.attention_backend = "fa3"
                 else:
                     self.attention_backend = "triton"
-            supported_backends = ["triton", "trtllm_mha", "fa3"]
-            logger.info(
-                f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
-            )
+
+            supported_backends = ["triton", "trtllm_mha", "fa3", "fa4"]
+            prefill_attn_backend, decode_attn_backend = self.get_attention_backends()
             assert (
-                self.attention_backend in supported_backends
-            ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
+                prefill_attn_backend in supported_backends
+                and decode_attn_backend in supported_backends
+            ), (
782+ f"GptOssForCausalLM requires one of { supported_backends } attention backend, but got the following backends\n "
783+ f"- Prefill: { prefill_attn_backend } \n "
784+ f"- Decode: { decode_attn_backend } \n "
785+ )
 
         if is_sm100_supported():
             if not self.enable_dp_attention:
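
The fallback behavior of `get_attention_backends` can be exercised in isolation. Below is a minimal sketch (not part of the diff) of the resolution logic: each phase uses its dedicated override when set and otherwise falls back to the shared `attention_backend` value. The `SimpleNamespace` stands in for a `ServerArgs` instance; only the three field names come from the diff, and `or` is used here as a condensed equivalent of the diff's `if`/`else` expressions.

```python
from types import SimpleNamespace


def get_attention_backends(server_args):
    # Per-phase override wins; otherwise fall back to the shared backend.
    prefill = server_args.prefill_attention_backend or server_args.attention_backend
    decode = server_args.decode_attention_backend or server_args.attention_backend
    return prefill, decode


# Decode overridden to trtllm_mha; prefill falls back to the shared fa3.
args = SimpleNamespace(
    attention_backend="fa3",
    prefill_attention_backend=None,
    decode_attention_backend="trtllm_mha",
)
assert get_attention_backends(args) == ("fa3", "trtllm_mha")
```

This is also why the GptOss default-selection branch now checks all three fields: assigning a default to `attention_backend` when either per-phase override is set would silently change the other phase's backend.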