diff --git a/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py b/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py
index 6985b6574a1..e5dd0163846 100644
--- a/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py
+++ b/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py
@@ -526,7 +526,7 @@ def register_nvfp4_gemm_prologue(custom_pass: PatternMatcherPass):
     alpha_key = KeywordArg('alpha')
     output_dtype_key = KeywordArg('output_dtype')
     to_userbuffers_key = KeywordArg('to_userbuffers')
-    backend_key = KeywordArg('backend')
+    allowed_backends_key = KeywordArg('allowed_backends')
     trtllm_nvfp4_gemm_default = CallFunction(
         torch.ops.trtllm.nvfp4_gemm.default,
         act_fp4_key,
@@ -536,7 +536,7 @@ def register_nvfp4_gemm_prologue(custom_pass: PatternMatcherPass):
         alpha_key,
         output_dtype_key,
         to_userbuffers=to_userbuffers_key,
-        backend=backend_key)
+        allowed_backends=allowed_backends_key)
 
     ub_copy = CallFunction(torch.ops.trtllm.copy_to_userbuffers,
                            trtllm_nvfp4_gemm_default)
@@ -548,7 +548,7 @@ def empty_nvfp4_gemm_prologue_pattern(
         alpha: torch.Tensor,
         output_dtype: torch.dtype,
         to_userbuffers: bool,
-        backend: str,
+        allowed_backends: str,
     ):
         return
 
@@ -560,26 +560,31 @@ def target_nvfp4_gemm_prologue_pattern(
         alpha: torch.Tensor,
         output_dtype: torch.dtype,
         to_userbuffers: bool,
-        backend: str,
+        allowed_backends: str,
     ):
         nvfp4_gemm_output = torch.ops.trtllm.nvfp4_gemm(
             act_fp4, weight, act_sf, weight_scale, alpha, output_dtype,
-            True, backend)
+            True, allowed_backends)
         return nvfp4_gemm_output
 
     def extra_check(match: Match) -> bool:
-        # Validate backend value
-        backend_value = match.kwargs.get('backend')
-        if backend_value is None:
-            # No backend specified, use default - OK
-            return True
-
-        # backend should be a string literal
-        if not isinstance(backend_value, str):
-            return False
-
-        valid_backends = {'auto', 'cutlass', 'cublaslt', 'cutedsl'}
-        return backend_value in valid_backends
+        # Validate allowed_backends if present (now a comma-separated string)
+        allowed_backends_value = match.kwargs.get('allowed_backends')
+        if allowed_backends_value is not None:
+            # allowed_backends should be a comma-separated string
+            if not isinstance(allowed_backends_value, str):
+                return False
+            backends_list = [
+                b.strip() for b in allowed_backends_value.split(',')
+                if b.strip()
+            ]
+            valid_individual = {
+                'cutlass', 'cublaslt', 'cutedsl', 'cuda_core'
+            }
+            if not all(b in valid_individual for b in backends_list):
+                return False
+
+        return True
 
     register_replacement(
         empty_nvfp4_gemm_prologue_pattern,
diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
index d40c1fd5844..003f1378837 100644
--- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -695,28 +695,36 @@ def _(
 class NVFP4GemmUnifiedRunner(TunableRunner):
     runner_dict = dict()
 
-    def __init__(self,
-                 to_userbuffers: bool,
-                 output_dtype: torch.dtype,
-                 backend: str = "auto"):
+    def __init__(self, to_userbuffers: bool, output_dtype: torch.dtype,
+                 allowed_backends: List[str]):
         super().__init__()
         self.to_userbuffers = to_userbuffers
         self.output_dtype = output_dtype
-        self.backend = backend
+        self.allowed_backends = allowed_backends
 
     def unique_id(self):
-        """Include backend in cache key to avoid sharing cache across backends."""
-        return (self.to_userbuffers, self.output_dtype, self.backend)
+        """Include allowed_backends in cache key to avoid sharing cache across different backend configs."""
+        # Convert list to tuple for hashability
+        allowed_tuple = tuple(self.allowed_backends)
+        return (self.to_userbuffers, self.output_dtype, allowed_tuple)
+
+    def _is_backend_allowed(self, backend_name: str) -> bool:
+        """Check if a backend is allowed based on allowed_backends list."""
+        return backend_name in self.allowed_backends
+
+    def _is_only_backend(self, backend_name: str) -> bool:
+        """Check if this is the only backend in allowed_backends (explicitly forced)."""
+        return self.allowed_backends == [backend_name]
 
     def get_valid_tactics(self, inputs: List[torch.Tensor],
                           profile: OptimizationProfile,
                           **kwargs) -> List[Tuple]:
-        # return valid nvfp4 gemm implementations
+        # return valid nvfp4 gemm implementations from allowed_backends
        tactics = []
         act_fp4, weight, act_sf, weight_scale, alpha = inputs
 
-        backend = self.backend
-        if backend in ["auto", "cuda_core"]:
+        # Add CUDA Core backend if available
+        if self._is_backend_allowed("cuda_core"):
             is_cuda_core_supported = False
             m = act_fp4.shape[0]
             sm_version = None
@@ -732,40 +740,39 @@ def get_valid_tactics(self, inputs: List[torch.Tensor],
 
             if is_cuda_core_supported:
                 tactics.append("cuda_core")
-            elif backend == "cuda_core":
-                # Explicitly requested but conditions not met - raise error
+            elif self._is_only_backend("cuda_core"):
+                # Explicitly forced but conditions not met - raise error
                 error_msg = f"CUDA Core backend requires SM >= {CudaCoreNVFP4Runner.MIN_SM_VERSION} and M <= {CudaCoreNVFP4Runner.MAX_M_DIMENSION}. "
                 error_msg += f"Current: SM={sm_version if sm_version else 'N/A'}, M={m}. "
-                error_msg += "Please use backend='auto' or another backend."
+                error_msg += "Please add other backends to allowed_backends."
                 raise ValueError(error_msg)
 
         # Add CUTLASS runner (always available)
-        if backend in ["auto", "cutlass"]:
+        if self._is_backend_allowed("cutlass"):
             tactics.append("cutlass")
 
         # Add cuBLASLt runner if available
-        if backend in ["auto", "cublaslt"]:
+        if self._is_backend_allowed("cublaslt"):
             if IS_CUBLASLT_AVAILABLE:
                 tactics.append("cublaslt")
-            elif backend == "cublaslt":
+            elif self._is_only_backend("cublaslt"):
                 raise ValueError(
                     "cuBLASLt backend is not available. "
-                    "Please check cuBLASLt installation or use backend='auto'.")
+                    "Please check cuBLASLt installation or add other backends to allowed_backends."
+                )
 
         # Add CuteDSL runner if available
-        if backend in ["auto", "cutedsl"]:
+        if self._is_backend_allowed("cutedsl"):
             if IS_CUTLASS_DSL_AVAILABLE:
                 # Check SM version first - CuteDSL NVFP4 only supports SM 100 (B200)
                 sm_version = get_sm_version()
                 if sm_version not in [100, 103]:
-                    if backend == "cutedsl":
-                        # Explicitly requested CuteDSL but SM version not supported
+                    if self._is_only_backend("cutedsl"):
+                        # Explicitly forced CuteDSL but SM version not supported
                         raise ValueError(
                             f"CuteDSL NVFP4 backend requires SM 100 (B200) or SM 103 (B300), but got SM {sm_version}. "
                             f"CuteDSL NVFP4 is not supported on this GPU architecture. "
-                            f"Please use backend='auto' to automatically select a compatible backend."
-                        )
-                    # else: backend='auto' → silently skip CuteDSL
+                            "Please add other backends to allowed_backends.")
                 else:
                     # SM version OK, check if CuteDSL supports the current shape
                     from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import \
@@ -778,8 +785,8 @@ def get_valid_tactics(self, inputs: List[torch.Tensor],
                     if cutedsl_tactics:
                         # CuteDSL supports this shape
                         tactics.append("cutedsl")
-                    elif backend == "cutedsl":
-                        # Explicitly requested CuteDSL but it doesn't support this shape
+                    elif self._is_only_backend("cutedsl"):
+                        # Explicitly forced CuteDSL but it doesn't support this shape
                         m, n, k = inputs[0].shape[0], inputs[1].shape[
                             0], inputs[0].shape[1] * 2
                         raise ValueError(
@@ -788,13 +795,12 @@ def get_valid_tactics(self, inputs: List[torch.Tensor],
                             f"CuteDSL requires 16-byte alignment for major (contiguous) dimensions:\n"
                             f" - K must be divisible by 32 (FP4 K-major layout): K%32={'0✓' if k % 32 == 0 else str(k%32)+'✗'}\n"
                             f" - Or the combination of (M, N, K, tiling, cluster shape) is not supported\n"
-                            f"Please use backend='auto' to automatically select a compatible backend."
-                        )
-                    # else: backend='auto' and CuteDSL doesn't support shape → silently skip
-            elif backend == "cutedsl":
+                            f"Please add other backends to allowed_backends.")
+            elif self._is_only_backend("cutedsl"):
                 raise ValueError(
                     "CuteDSL backend is not available. "
-                    "Please check CuteDSL installation or use backend='auto'.")
+                    "Please check CuteDSL installation or add other backends to allowed_backends."
+                )
 
         return tactics
 
@@ -807,31 +813,23 @@ def forward(
     ) -> torch.Tensor:
         act_fp4, weight, act_sf, weight_scale, alpha = inputs
 
-        requested_backend = self.backend
-
-        # If a specific backend was requested (not 'auto') and we're using fallback tactic
-        # This can happen on cache miss, where AutoTuner uses tactic=-1 as default
-        if requested_backend != 'auto' and requested_backend != tactic and tactic == -1:
-            # User explicitly requested a backend, but we're falling back to default
-            # This might happen on cache miss. We should validate the requested backend supports this shape.
-
-            # Get valid tactics for the requested backend
+        # Handle fallback tactic (-1) on cache miss
+        if tactic == -1:
+            # Get valid tactics and use first available
             from tensorrt_llm._torch.autotuner import OptimizationProfile
             valid_tactics = self.get_valid_tactics(inputs,
                                                    OptimizationProfile())
-
-            if not valid_tactics or requested_backend not in valid_tactics:
-                # Requested backend doesn't support this shape
+            if valid_tactics:
+                # Prefer cutlass as fallback if available, otherwise use first valid tactic
+                tactic = "cutlass" if "cutlass" in valid_tactics else valid_tactics[
+                    0]
+            else:
                 m, n, k = inputs[0].shape[0], inputs[1].shape[
                     0], inputs[0].shape[1] * 2
                 raise ValueError(
-                    f"Backend '{requested_backend}' was explicitly requested but does not support the current shape:\n"
+                    f"No valid backends available for the current shape:\n"
                     f" M={m}, N={n}, K={k}\n"
-                    f"Please use backend='auto' to automatically select a compatible backend."
-                )
-
-            # Backend supports it, use the requested backend instead of fallback
-            tactic = requested_backend
+                    f" Allowed backends: {self.allowed_backends}")
 
         if tactic == "cuda_core":
             # Unswizzle the activation scale factors
@@ -882,11 +880,11 @@ def nvfp4_gemm(
     alpha: torch.Tensor,
     output_dtype: torch.dtype,
     to_userbuffers: bool = False,
-    backend: str = "auto",
+    allowed_backends: str = "cutlass,cublaslt,cuda_core",
 ) -> torch.Tensor:
-    """Unified NVFP4 GEMM with automatic or manual backend selection.
+ """Unified NVFP4 GEMM with automatic backend selection. - This function can automatically choose the best backend or force a specific backend: + This function automatically chooses the best backend from the allowed list: - CUTLASS: Predefined CUTLASS configurations with auto-tuning - cuBLASLt: Heuristic-based algorithms from cuBLASLt library - CuteDSL: Blackwell-optimized persistent kernels (when available and inputs are valid) @@ -894,8 +892,7 @@ def nvfp4_gemm( The AutoTuner profiles all available backends during the first run and caches the best choice for each input shape. Subsequent calls use the cached selection - with zero overhead. In 'auto' mode, backends are only considered if their - requirements are met (e.g., CUDA Core only participates when SM >= 100 and M <= 8). + with zero overhead. Args: act_fp4: Activation tensor [m, k] in FP4 format (packed in uint8) @@ -905,12 +902,10 @@ def nvfp4_gemm( alpha: Scaling factor (as torch.Tensor for CUTLASS/cuBLASLt compatibility) output_dtype: Output data type to_userbuffers: Whether to use user buffers (CUTLASS/cuBLASLt only) - backend: Backend selection, one of: - - 'auto': AutoTuner automatically selects best backend (default) - - 'cutlass': Force use CUTLASS (FP4GemmRunner) - - 'cublaslt': Force use cuBLASLt (CublasLtFP4GemmRunner) - - 'cutedsl': Force use CuteDSL (CuteDSLNVFP4Wrapper) - - 'cuda_core': Force use CUDA Core (CudaCoreNVFP4Runner, requires SM >= 100, M <= 8) + allowed_backends: Comma-separated list of backends to consider for auto-selection. + Default: "cutlass,cublaslt,cuda_core" (excludes cutedsl for faster build) + Add 'cutedsl' for extreme performance at the cost of longer build time. + Valid backends: 'cutlass', 'cublaslt', 'cutedsl', 'cuda_core'. Returns: Output tensor [m, n] with dtype=output_dtype @@ -919,14 +914,26 @@ def nvfp4_gemm( ValueError: If backend is invalid/unavailable """ - # Validate backend parameter - valid_backends = ['auto', 'cutlass', 'cublaslt', 'cutedsl', 'cuda_core'] - if backend not in valid_backends: + valid_individual_backends = {'cutlass', 'cublaslt', 'cutedsl', 'cuda_core'} + + # Parse comma-separated string to list + backends_list = [ + b.strip() for b in allowed_backends.split(',') if b.strip() + ] + + # Validate allowed_backends + invalid_backends = set(backends_list) - valid_individual_backends + if invalid_backends: + raise ValueError( + f"Invalid backends in allowed_backends: {invalid_backends}. " + f"Valid backends are: {sorted(valid_individual_backends)}.") + if not backends_list: raise ValueError( - f"Invalid backend '{backend}'. Must be one of {valid_backends}") + f"allowed_backends cannot be empty. 
" + f"Valid backends are: {sorted(valid_individual_backends)}.") - # Build list of runners based on backend parameter - runner = NVFP4GemmUnifiedRunner(to_userbuffers, output_dtype, backend) + # Build runner with allowed backends + runner = NVFP4GemmUnifiedRunner(to_userbuffers, output_dtype, backends_list) # Use AutoTuner to select best runner and tactic # - For 'auto' mode: compare across all backends, find global optimum @@ -966,7 +973,7 @@ def _( alpha: torch.Tensor, output_dtype: torch.dtype, to_userbuffers: bool = False, - backend: str = "auto", + allowed_backends: str = "cutlass,cublaslt,cuda_core", ) -> torch.Tensor: """Fake implementation for torch.compile support.""" return act_fp4.new_empty((act_fp4.size(0), weight.size(0)), diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index 232683a2765..148ec5e2e3f 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -102,6 +102,11 @@ class ModelConfig(Generic[TConfig]): # If true, use low precision combine in MoE operations (only for NVFP4 quantization) use_low_precision_moe_combine: bool = False + # NVFP4 GEMM backend configuration - list of backends to consider for auto-selection + # Default excludes 'cutedsl' for faster build time. Add 'cutedsl' for extreme perf. + nvfp4_gemm_allowed_backends: List[str] = field( + default_factory=lambda: ['cutlass', 'cublaslt', 'cuda_core']) + allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO # If true, enable min-latency mode. Currently only used for Llama4. diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 613665207cc..5c90a364934 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -28,7 +28,7 @@ from ..._utils import get_sm_version, is_sm_100f from ...models.modeling_utils import QuantConfig -from ..utils import Fp4QuantizedTensor, unswizzle_sf +from ..utils import Fp4QuantizedTensor, get_model_extra_attrs, unswizzle_sf class WeightMode(str, enum.Enum): @@ -937,14 +937,17 @@ def apply(self, module: Linear, input: torch.Tensor, input, module.input_scale, module.scaling_vector_size, False) # Use unified interface - supports CUTLASS, cuBLASLt, CuteDSL - output = torch.ops.trtllm.nvfp4_gemm(act_fp4, - module.weight, - act_sf, - module.weight_scale, - module.alpha, - module.dtype, - to_userbuffers=False, - backend=module.nvfp4_backend) + # Convert list to comma-separated string for torch.compile compatibility + allowed_backends_str = ','.join(module.nvfp4_allowed_backends) + output = torch.ops.trtllm.nvfp4_gemm( + act_fp4, + module.weight, + act_sf, + module.weight_scale, + module.alpha, + module.dtype, + to_userbuffers=False, + allowed_backends=allowed_backends_str) # Take the dim of out_features if padded. Make sure the output is contiguous if output.shape[-1] > module.out_features: output = output[..., :module.out_features].contiguous() @@ -2054,14 +2057,15 @@ def __init__( use_cute_dsl_blockscaling_mm: bool = False, disable_deep_gemm: bool = False, fused_weight_shard_indices_mapping: Optional[dict] = None, - nvfp4_backend: str = "auto", + nvfp4_allowed_backends: Optional[List[str]] = None, ): """ Args: - nvfp4_backend: Backend selection for NVFP4 GEMM operations. - Supported values: "auto", "cutlass", "cublaslt", "cutedsl". - Default is "auto" which automatically selects the best backend. - Can be overridden via TRTLLM_NVFP4_GEMM_BACKEND environment variable. 
+            nvfp4_allowed_backends: List of backends to consider for NVFP4 GEMM auto-selection.
+                Default (via config): ['cutlass', 'cublaslt', 'cuda_core'] - excludes cutedsl for faster build.
+                Add 'cutedsl' for extreme performance at the cost of longer build time.
+                Valid backends: 'cutlass', 'cublaslt', 'cutedsl', 'cuda_core'.
+                Configure via nvfp4_gemm_config.allowed_backends in extra_llm_api_options.yaml.
         """
 
         from ..distributed import AllReduce
@@ -2082,20 +2086,17 @@ def __init__(
         self.disable_deep_gemm = disable_deep_gemm
         self.fused_weight_shard_indices_mapping = fused_weight_shard_indices_mapping
 
-        # Support environment variable override for nvfp4_backend
-        nvfp4_backend_value = os.environ.get('TRTLLM_NVFP4_GEMM_BACKEND',
-                                             nvfp4_backend)
-
-        # Validate backend selection
-        valid_backends = {'auto', 'cutlass', 'cublaslt', 'cutedsl'}
-        if nvfp4_backend_value not in valid_backends:
-            raise ValueError(
-                f"Invalid nvfp4_backend: '{nvfp4_backend_value}'. "
-                f"Supported values are: {', '.join(sorted(valid_backends))}. "
-                f"Set via constructor argument or TRTLLM_NVFP4_GEMM_BACKEND environment variable."
-            )
-
-        self.nvfp4_backend = nvfp4_backend_value
+        # Store NVFP4 GEMM allowed backends configuration
+        # Read from model_extra_attrs if not explicitly provided (allows config via llm_api_options)
+        if nvfp4_allowed_backends is None:
+            model_attrs = get_model_extra_attrs()
+            if model_attrs:
+                nvfp4_allowed_backends = model_attrs.get(
+                    'nvfp4_gemm_allowed_backends')
+        # Default: exclude cutedsl for faster build time
+        self.nvfp4_allowed_backends = nvfp4_allowed_backends or [
+            'cutlass', 'cublaslt', 'cuda_core'
+        ]
 
         local_in_features = in_features
         local_out_features = out_features
diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py
index 679d50615e8..3d054c6b3b7 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_loader.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py
@@ -357,7 +357,13 @@ def _load_and_validate_config(
             moe_disable_finalize_fusion=self.llm_args.moe_config.
             disable_finalize_fusion,
             use_low_precision_moe_combine=self.llm_args.moe_config.
-            use_low_precision_moe_combine)
+            use_low_precision_moe_combine,
+            nvfp4_gemm_allowed_backends=self.llm_args.nvfp4_gemm_config.
+            allowed_backends)
+
+        # Store nvfp4 config in extra_attrs for Linear layer access
+        config.extra_attrs[
+            'nvfp4_gemm_allowed_backends'] = config.nvfp4_gemm_allowed_backends
 
         validate_and_set_kv_cache_quant(config,
                                         self.llm_args.kv_cache_config.dtype)
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 0ff0fb12972..9f154c53f63 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -412,6 +412,34 @@ def from_dict(cls, data: dict):
         return cls(**data)
 
 
+class Nvfp4GemmConfig(StrictBaseModel):
+    """
+    Configuration for NVFP4 GEMM backend selection.
+    """
+    allowed_backends: List[str] = Field(
+        default=['cutlass', 'cublaslt', 'cuda_core'],
+        description="List of backends to consider for auto-selection. "
+        "Default excludes 'cutedsl' for faster build time. "
+        "Add 'cutedsl' for extreme performance at the cost of longer server launch time. "
+        "Valid values: 'cutlass', 'cublaslt', 'cutedsl', 'cuda_core'.")
+
+    @model_validator(mode="after")
+    def validate_allowed_backends(self) -> 'Nvfp4GemmConfig':
+        valid_backends = {'cutlass', 'cublaslt', 'cutedsl', 'cuda_core'}
+        invalid = set(self.allowed_backends) - valid_backends
+        if invalid:
+            raise ValueError(
+                f"Invalid backends in allowed_backends: {invalid}. "
+                f"Valid backends are: {sorted(valid_backends)}")
+        if not self.allowed_backends:
+            raise ValueError("allowed_backends cannot be empty.")
+        return self
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(**data)
+
+
 class AttentionDpConfig(StrictBaseModel):
     """
     Configuration for attention DP.
@@ -2566,6 +2594,11 @@ class TorchLlmArgs(BaseLlmArgs):
                                  description="MoE config.",
                                  status="beta")
 
+    nvfp4_gemm_config: Nvfp4GemmConfig = Field(
+        default_factory=Nvfp4GemmConfig,
+        description="NVFP4 GEMM backend config.",
+        status="beta")
+
     attn_backend: str = Field(default='TRTLLM',
                               description="Attention backend to use.",
                               status="beta")
@@ -3050,6 +3083,7 @@ def update_llm_args_with_extra_dict(
         "speculative_config": DecodingBaseConfig,
         "lora_config": LoraConfig,
         "moe_config": MoeConfig,
+        "nvfp4_gemm_config": Nvfp4GemmConfig,
         "attention_dp_config": AttentionDpConfig,
         "sparse_attention_config": BaseSparseAttentionConfig,
         "kv_cache_config": KvCacheConfig,
diff --git a/tests/unittest/_torch/thop/parallel/test_fp4_linear.py b/tests/unittest/_torch/thop/parallel/test_fp4_linear.py
index a549b52fa4e..cc61e075150 100644
--- a/tests/unittest/_torch/thop/parallel/test_fp4_linear.py
+++ b/tests/unittest/_torch/thop/parallel/test_fp4_linear.py
@@ -35,12 +35,13 @@ def test_fp4_linear(dtype, mnk):
                                                False)
 
     qc = QuantConfig(quant_algo=QuantAlgo.NVFP4)
-    l_fp4 = Linear(in_features=HIDDEN_SIZE,
-                   out_features=OUTPUT_SIZE,
-                   bias=False,
-                   dtype=dtype,
-                   quant_config=qc,
-                   nvfp4_backend='cutlass')  # Force CUTLASS to match reference
+    l_fp4 = Linear(
+        in_features=HIDDEN_SIZE,
+        out_features=OUTPUT_SIZE,
+        bias=False,
+        dtype=dtype,
+        quant_config=qc,
+        nvfp4_allowed_backends=['cutlass'])  # Force CUTLASS to match reference
     assert l_fp4.weight.dtype == fp4_utils.float4_e2m1x2
     assert l_fp4.weight_scale.dtype == fp4_utils.float4_sf_dtype
 
@@ -117,7 +118,7 @@ def test_fp4_linear_cute_dsl(dtype, mnk):
                    bias=False,
                    dtype=dtype,
                    quant_config=qc,
-                   nvfp4_backend='cutedsl')
+                   nvfp4_allowed_backends=['cutedsl'])
     assert l_fp4.weight.dtype == fp4_utils.float4_e2m1x2
     assert l_fp4.weight_scale.dtype == fp4_utils.float4_sf_dtype
 
@@ -180,7 +181,7 @@ def fp4_linear_perf_test(dtype, SEQ_LEN, OUTPUT_SIZE, HIDDEN_SIZE):
                    bias=False,
                    dtype=dtype,
                    quant_config=qc,
-                   nvfp4_backend='cutedsl')
+                   nvfp4_allowed_backends=['cutedsl'])
     assert l_fp4.weight.dtype == fp4_utils.float4_e2m1x2
     assert l_fp4.weight_scale.dtype == fp4_utils.float4_sf_dtype
 
@@ -215,7 +216,8 @@ def fp4_linear_perf_test(dtype, SEQ_LEN, OUTPUT_SIZE, HIDDEN_SIZE):
                        bias=False,
                        dtype=dtype,
                        quant_config=qc,
-                       nvfp4_backend='cutlass')  # Use CUTLASS as reference
+                       nvfp4_allowed_backends=['cutlass'
+                                               ])  # Use CUTLASS as reference
     assert l_fp4_ref.weight.dtype == fp4_utils.float4_e2m1x2
     assert l_fp4_ref.weight_scale.dtype == fp4_utils.float4_sf_dtype
 
@@ -466,18 +468,19 @@ def test_nvfp4_gemm_unified_all_tactics(dtype, mnk):
                                                  alpha=alpha_tensor,
                                                  output_dtype=dtype,
                                                  to_userbuffers=False,
-                                                 backend='cutlass')
+                                                 allowed_backends='cutlass')
 
     # Test auto backend selection with autotuning
     with torch.inference_mode(), autotune():
-        output_auto = torch.ops.trtllm.nvfp4_gemm(act_fp4=x_fp4,
-                                                  weight=w_fp4,
-                                                  act_sf=x_sf_block,
-                                                  weight_scale=w_sf_block,
-                                                  alpha=alpha_tensor,
-                                                  output_dtype=dtype,
-                                                  to_userbuffers=False,
-                                                  backend='auto')
+        output_auto = torch.ops.trtllm.nvfp4_gemm(
+            act_fp4=x_fp4,
+            weight=w_fp4,
+            act_sf=x_sf_block,
+            weight_scale=w_sf_block,
+            alpha=alpha_tensor,
+            output_dtype=dtype,
+            to_userbuffers=False,
+            allowed_backends='cutlass,cublaslt,cuda_core,cutedsl')
 
     AutoTuner.get().print_profiling_cache()
 
@@ -497,14 +500,15 @@ def test_nvfp4_gemm_unified_all_tactics(dtype, mnk):
 
     print(f"\n[Outer Layer] Capturing backend selection tactics...")
     with AutoTuner.get().capture() as outer_capture, torch.inference_mode():
-        output = torch.ops.trtllm.nvfp4_gemm(act_fp4=x_fp4,
-                                             weight=w_fp4,
-                                             act_sf=x_sf_block,
-                                             weight_scale=w_sf_block,
-                                             alpha=alpha_tensor,
-                                             output_dtype=dtype,
-                                             to_userbuffers=False,
-                                             backend='auto')
+        output = torch.ops.trtllm.nvfp4_gemm(
+            act_fp4=x_fp4,
+            weight=w_fp4,
+            act_sf=x_sf_block,
+            weight_scale=w_sf_block,
+            alpha=alpha_tensor,
+            output_dtype=dtype,
+            to_userbuffers=False,
+            allowed_backends='cutlass,cublaslt,cuda_core,cutedsl')
     outer_tactics_list = list(outer_capture)
 
     print(f" Found {len(outer_tactics_list)} outer layer tactics (backends)")
@@ -604,7 +608,7 @@ def test_nvfp4_gemm_unified_all_tactics(dtype, mnk):
                                                  alpha=alpha_tensor,
                                                  output_dtype=dtype,
                                                  to_userbuffers=False,
-                                                 backend='cuda_core')
+                                                 allowed_backends='cuda_core')
 
     torch.testing.assert_close(output_cuda_core,
                                output_ref,
@@ -624,7 +628,7 @@ def test_nvfp4_gemm_unified_all_tactics(dtype, mnk):
     print(f"\n Note: cuda_core has no autotuning (single tactic)")
     print(f" Note: Tested all inner layer tactics for each backend")
     print(
-        f" Outer layer (backend selection) was tested separately with backend='auto'"
+        f" Outer layer (backend selection) was tested separately with all backends allowed"
     )
     print(f"{'='*80}\n")
 
@@ -719,17 +723,18 @@ def test_fp4_linear_cuda_core(dtype, mnk):
                                              alpha=alpha_tensor,
                                             output_dtype=dtype,
                                              to_userbuffers=False,
-                                             backend='cutlass')
+                                             allowed_backends='cutlass')
 
     # Test CUDA Core backend
-    output_cuda_core = torch.ops.trtllm.nvfp4_gemm(act_fp4=x_fp4,
-                                                   weight=w_fp4,
-                                                   act_sf=x_sf_block,
-                                                   weight_scale=w_sf_block,
-                                                   alpha=alpha_tensor,
-                                                   output_dtype=dtype,
-                                                   to_userbuffers=False,
-                                                   backend='cuda_core')
+    output_cuda_core = torch.ops.trtllm.nvfp4_gemm(
+        act_fp4=x_fp4,
+        weight=w_fp4,
+        act_sf=x_sf_block,
+        weight_scale=w_sf_block,
+        alpha=alpha_tensor,
+        output_dtype=dtype,
+        to_userbuffers=False,
+        allowed_backends='cuda_core')
 
     # Compare results
     torch.cuda.synchronize()
diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml
index 7702fe8cdef..6d02fed3975 100644
--- a/tests/unittest/api_stability/references/llm.yaml
+++ b/tests/unittest/api_stability/references/llm.yaml
@@ -115,6 +115,10 @@ methods:
         annotation: tensorrt_llm.llmapi.llm_args.MoeConfig
         status: beta
         default: null
+      nvfp4_gemm_config:
+        annotation: tensorrt_llm.llmapi.llm_args.Nvfp4GemmConfig
+        status: beta
+        default: null
      attn_backend:
         annotation: str
         default: TRTLLM
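
Minimal usage sketch for the new configuration surface, kept separate from the patch itself. The checkpoint path below is a placeholder, and the snippet assumes the standard tensorrt_llm.LLM entry point; Nvfp4GemmConfig, allowed_backends, and the comma-separated form accepted by torch.ops.trtllm.nvfp4_gemm all come from the changes above.

    # Opt in to CuteDSL on top of the default backend set via the LLM API.
    from tensorrt_llm import LLM
    from tensorrt_llm.llmapi.llm_args import Nvfp4GemmConfig

    llm = LLM(
        model="<path-to-nvfp4-checkpoint>",  # placeholder, not part of this diff
        nvfp4_gemm_config=Nvfp4GemmConfig(
            allowed_backends=['cutlass', 'cublaslt', 'cuda_core', 'cutedsl']))

    # Per-call selection on the unified custom op takes the same list as a
    # comma-separated string, e.g.:
    #   torch.ops.trtllm.nvfp4_gemm(..., allowed_backends='cutlass,cublaslt')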