
Commit 5f66d6b

minor change
Signed-off-by: Shijie Wang <[email protected]>
1 parent: b6a50c3 · commit: 5f66d6b

File tree

1 file changed: +14 −11 lines


tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 14 additions & 11 deletions
@@ -695,19 +695,26 @@ def _(
 class NVFP4GemmUnifiedRunner(TunableRunner):
     runner_dict = dict()
 
-    def __init__(self, to_userbuffers: bool, output_dtype: torch.dtype):
+    def __init__(self,
+                 to_userbuffers: bool,
+                 output_dtype: torch.dtype,
+                 backend: str = "auto"):
         super().__init__()
         self.to_userbuffers = to_userbuffers
         self.output_dtype = output_dtype
+        self.backend = backend
+
+    def unique_id(self):
+        """Include backend in cache key to avoid sharing cache across backends."""
+        return (self.to_userbuffers, self.output_dtype, self.backend)
 
-    def get_valid_tactics(self,
-                          inputs: List[torch.Tensor],
+    def get_valid_tactics(self, inputs: List[torch.Tensor],
                           profile: OptimizationProfile,
-                          backend: str = "auto",
                           **kwargs) -> List[Tuple]:
         # return valid nvfp4 gemm implementations
         tactics = []
         act_fp4, weight, act_sf, weight_scale, alpha = inputs
+        backend = self.backend
 
         if backend in ["auto", "cuda_core"]:
             is_cuda_core_supported = False
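
The new unique_id() matters because the autotuner caches tuned tactics per runner; if the key omitted the backend, two runners differing only in backend would share a cached tactic. A toy sketch of that idea follows; _ToyRunner, lookup_or_tune, and profiling_cache are illustrative names, not the real AutoTuner API.

# Toy illustration (not the real AutoTuner) of why the cache key must
# include the backend: otherwise runners that differ only in backend
# would share cache entries and reuse each other's tuned tactic.
from typing import Any, Dict, Tuple

profiling_cache: Dict[Tuple, Any] = {}


class _ToyRunner:
    """Stand-in for NVFP4GemmUnifiedRunner; only the cache-key logic matters."""

    def __init__(self, to_userbuffers: bool, output_dtype: str, backend: str = "auto"):
        self.to_userbuffers = to_userbuffers
        self.output_dtype = output_dtype
        self.backend = backend

    def unique_id(self) -> Tuple:
        # Mirrors the change above: backend is part of the key.
        return (self.to_userbuffers, self.output_dtype, self.backend)


def lookup_or_tune(runner: _ToyRunner, shape_key: Tuple[int, ...]) -> str:
    key = (runner.unique_id(), shape_key)
    if key not in profiling_cache:
        # Pretend we profiled this backend for this shape and picked a tactic.
        profiling_cache[key] = f"best_tactic_for_{runner.backend}"
    return profiling_cache[key]


cutlass = _ToyRunner(False, "bf16", backend="cutlass")
cuda_core = _ToyRunner(False, "bf16", backend="cuda_core")
# Distinct cache entries even though the runners are otherwise identical.
assert lookup_or_tune(cutlass, (128, 4096)) != lookup_or_tune(cuda_core, (128, 4096))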
@@ -800,8 +807,7 @@ def forward(
     ) -> torch.Tensor:
         act_fp4, weight, act_sf, weight_scale, alpha = inputs
 
-        # Check if a specific backend was requested
-        requested_backend = kwargs.get('backend', 'auto')
+        requested_backend = self.backend
 
         # If a specific backend was requested (not 'auto') and we're using fallback tactic
         # This can happen on cache miss, where AutoTuner uses tactic=-1 as default
@@ -812,8 +818,7 @@ def forward(
             # Get valid tactics for the requested backend
             from tensorrt_llm._torch.autotuner import OptimizationProfile
             valid_tactics = self.get_valid_tactics(inputs,
-                                                   OptimizationProfile(),
-                                                   backend=requested_backend)
+                                                   OptimizationProfile())
 
             if not valid_tactics or requested_backend not in valid_tactics:
                 # Requested backend doesn't support this shape
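
Read together, the two forward() hunks implement a single guard. A condensed, illustrative restatement is below; check_requested_backend, FALLBACK_TACTIC, and the stub runner are assumed names for this sketch, not the library's code, and get_valid_tactics is stubbed out.

# Condensed sketch of the forward() guard above: the backend now comes from
# the runner itself, and on a cache miss (AutoTuner falls back to tactic -1)
# the runner re-derives its valid tactics and rejects shapes the requested
# backend cannot handle.
FALLBACK_TACTIC = -1


def check_requested_backend(runner, inputs, tactic: int) -> None:
    requested_backend = runner.backend  # previously: kwargs.get('backend', 'auto')
    if requested_backend != "auto" and tactic == FALLBACK_TACTIC:
        # The real code also passes an OptimizationProfile here.
        valid_tactics = runner.get_valid_tactics(inputs)
        if not valid_tactics or requested_backend not in valid_tactics:
            raise ValueError(
                f"Backend '{requested_backend}' does not support this input shape")


class _StubRunner:
    backend = "cuda_core"

    def get_valid_tactics(self, inputs):
        return ["cutlass"]  # pretend only the cutlass path is valid for this shape


try:
    check_requested_backend(_StubRunner(), inputs=[], tactic=FALLBACK_TACTIC)
except ValueError as err:
    print(err)  # -> Backend 'cuda_core' does not support this input shape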
@@ -921,7 +926,7 @@ def nvfp4_gemm(
             f"Invalid backend '{backend}'. Must be one of {valid_backends}")
 
     # Build list of runners based on backend parameter
-    runner = NVFP4GemmUnifiedRunner(to_userbuffers, output_dtype)
+    runner = NVFP4GemmUnifiedRunner(to_userbuffers, output_dtype, backend)
 
     # Use AutoTuner to select best runner and tactic
     # - For 'auto' mode: compare across all backends, find global optimum
@@ -935,7 +940,6 @@ def nvfp4_gemm(
             FP4GemmRunner.
             tuning_config,  # All runners use the same tuning_config
             [act_fp4, weight, act_sf, weight_scale, alpha],
-            backend=backend,
         )
     except IndexError as e:
         # Provide more helpful error message
@@ -950,7 +954,6 @@ def nvfp4_gemm(
     return runner(
         inputs=[act_fp4, weight, act_sf, weight_scale, alpha],
         tactic=best_tactic,
-        backend=backend,
     )
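
End to end, the effect of the commit is that the backend string passed to nvfp4_gemm() is validated once, stored on the single runner, and no longer threaded through the AutoTuner call or the final runner invocation as a kwarg. A minimal sketch of that flow; VALID_BACKENDS, toy_nvfp4_gemm, and the stub runner are assumptions for illustration, not the module's real names.

# Minimal sketch of the post-commit dispatch flow (assumed names throughout).
VALID_BACKENDS = ("auto", "cutlass", "cuda_core")  # assumed set of backends


class _StubUnifiedRunner:
    def __init__(self, to_userbuffers: bool, output_dtype: str, backend: str = "auto"):
        self.to_userbuffers = to_userbuffers
        self.output_dtype = output_dtype
        self.backend = backend  # backend now lives on the runner

    def unique_id(self):
        return (self.to_userbuffers, self.output_dtype, self.backend)


def toy_nvfp4_gemm(backend: str = "auto"):
    if backend not in VALID_BACKENDS:
        raise ValueError(f"Invalid backend '{backend}'. Must be one of {VALID_BACKENDS}")
    runner = _StubUnifiedRunner(False, "bf16", backend)
    best_tactic = -1  # stands in for AutoTuner.choose_one(...); no backend kwarg needed
    return runner, best_tactic  # real code then calls runner(inputs=..., tactic=best_tactic)


runner, tactic = toy_nvfp4_gemm("cutlass")
print(runner.unique_id(), tactic)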