@@ -695,19 +695,26 @@ def _(
 class NVFP4GemmUnifiedRunner(TunableRunner):
     runner_dict = dict()
 
-    def __init__(self, to_userbuffers: bool, output_dtype: torch.dtype):
+    def __init__(self,
+                 to_userbuffers: bool,
+                 output_dtype: torch.dtype,
+                 backend: str = "auto"):
         super().__init__()
         self.to_userbuffers = to_userbuffers
         self.output_dtype = output_dtype
+        self.backend = backend
+
+    def unique_id(self):
+        """Include backend in cache key to avoid sharing cache across backends."""
+        return (self.to_userbuffers, self.output_dtype, self.backend)
 
-    def get_valid_tactics(self,
-                          inputs: List[torch.Tensor],
+    def get_valid_tactics(self, inputs: List[torch.Tensor],
                           profile: OptimizationProfile,
-                          backend: str = "auto",
                           **kwargs) -> List[Tuple]:
         # return valid nvfp4 gemm implementations
         tactics = []
         act_fp4, weight, act_sf, weight_scale, alpha = inputs
+        backend = self.backend
 
         if backend in ["auto", "cuda_core"]:
             is_cuda_core_supported = False
@@ -800,8 +807,7 @@ def forward(
     ) -> torch.Tensor:
         act_fp4, weight, act_sf, weight_scale, alpha = inputs
 
-        # Check if a specific backend was requested
-        requested_backend = kwargs.get('backend', 'auto')
+        requested_backend = self.backend
 
         # If a specific backend was requested (not 'auto') and we're using fallback tactic
         # This can happen on cache miss, where AutoTuner uses tactic=-1 as default
@@ -812,8 +818,7 @@ def forward(
             # Get valid tactics for the requested backend
             from tensorrt_llm._torch.autotuner import OptimizationProfile
             valid_tactics = self.get_valid_tactics(inputs,
-                                                   OptimizationProfile(),
-                                                   backend=requested_backend)
+                                                   OptimizationProfile())
 
             if not valid_tactics or requested_backend not in valid_tactics:
                 # Requested backend doesn't support this shape
@@ -921,7 +926,7 @@ def nvfp4_gemm(
921926 f"Invalid backend '{ backend } '. Must be one of { valid_backends } " )
922927
923928 # Build list of runners based on backend parameter
924- runner = NVFP4GemmUnifiedRunner (to_userbuffers , output_dtype )
929+ runner = NVFP4GemmUnifiedRunner (to_userbuffers , output_dtype , backend )
925930
926931 # Use AutoTuner to select best runner and tactic
927932 # - For 'auto' mode: compare across all backends, find global optimum
@@ -935,7 +940,6 @@ def nvfp4_gemm(
             FP4GemmRunner.
             tuning_config,  # All runners use the same tuning_config
             [act_fp4, weight, act_sf, weight_scale, alpha],
-            backend=backend,
         )
     except IndexError as e:
         # Provide more helpful error message
@@ -950,7 +954,6 @@ def nvfp4_gemm(
     return runner(
         inputs=[act_fp4, weight, act_sf, weight_scale, alpha],
         tactic=best_tactic,
-        backend=backend,
     )
 

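A note on the `unique_id()` change above: the minimal sketch below illustrates the cache-keying idea. The `ToyRunner` class and `cache` dict are hypothetical stand-ins, not the actual TensorRT-LLM `TunableRunner` or AutoTuner API; they only show why folding `backend` into the key keeps tuned tactics from being shared between backends that use the same dtype and userbuffers settings.

from typing import Any, Dict, Tuple


class ToyRunner:
    """Stand-in for NVFP4GemmUnifiedRunner; only the cache-key logic is modeled."""

    def __init__(self, to_userbuffers: bool, output_dtype: str, backend: str = "auto"):
        self.to_userbuffers = to_userbuffers
        self.output_dtype = output_dtype
        self.backend = backend

    def unique_id(self) -> Tuple:
        # Backend is part of the key, so a "cutlass" runner and a "cuda_core"
        # runner never reuse each other's tuned tactic.
        return (self.to_userbuffers, self.output_dtype, self.backend)


# Toy tuning cache keyed on unique_id(), standing in for the tuner's cache.
cache: Dict[Tuple, Any] = {}
cache[ToyRunner(False, "bf16", backend="cutlass").unique_id()] = "tactic_3"
cache[ToyRunner(False, "bf16", backend="cuda_core").unique_id()] = "tactic_0"
assert len(cache) == 2  # distinct entries; no cross-backend sharing

Without `backend` in the tuple, both runners above would collapse to the same key and the second entry would silently overwrite the first, which is the cache-sharing problem the diff avoids.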