39 changes: 22 additions & 17 deletions tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py
@@ -526,7 +526,7 @@ def register_nvfp4_gemm_prologue(custom_pass: PatternMatcherPass):
alpha_key = KeywordArg('alpha')
output_dtype_key = KeywordArg('output_dtype')
to_userbuffers_key = KeywordArg('to_userbuffers')
backend_key = KeywordArg('backend')
allowed_backends_key = KeywordArg('allowed_backends')
trtllm_nvfp4_gemm_default = CallFunction(
torch.ops.trtllm.nvfp4_gemm.default,
act_fp4_key,
@@ -536,7 +536,7 @@ def register_nvfp4_gemm_prologue(custom_pass: PatternMatcherPass):
alpha_key,
output_dtype_key,
to_userbuffers=to_userbuffers_key,
backend=backend_key)
allowed_backends=allowed_backends_key)
ub_copy = CallFunction(torch.ops.trtllm.copy_to_userbuffers,
trtllm_nvfp4_gemm_default)

@@ -548,7 +548,7 @@ def empty_nvfp4_gemm_prologue_pattern(
alpha: torch.Tensor,
output_dtype: torch.dtype,
to_userbuffers: bool,
backend: str,
allowed_backends: str,
):
return

@@ -560,26 +560,31 @@ def target_nvfp4_gemm_prologue_pattern(
alpha: torch.Tensor,
output_dtype: torch.dtype,
to_userbuffers: bool,
backend: str,
allowed_backends: str,
):
nvfp4_gemm_output = torch.ops.trtllm.nvfp4_gemm(
act_fp4, weight, act_sf, weight_scale, alpha, output_dtype,
True, backend)
True, allowed_backends)
return nvfp4_gemm_output

def extra_check(match: Match) -> bool:
# Validate backend value
backend_value = match.kwargs.get('backend')
if backend_value is None:
# No backend specified, use default - OK
return True

# backend should be a string literal
if not isinstance(backend_value, str):
return False

valid_backends = {'auto', 'cutlass', 'cublaslt', 'cutedsl'}
return backend_value in valid_backends
# Validate allowed_backends if present (now a comma-separated string)
allowed_backends_value = match.kwargs.get('allowed_backends')
if allowed_backends_value is not None:
# allowed_backends should be a comma-separated string
if not isinstance(allowed_backends_value, str):
return False
backends_list = [
b.strip() for b in allowed_backends_value.split(',')
if b.strip()
]
valid_individual = {
'cutlass', 'cublaslt', 'cutedsl', 'cuda_core'
}
if not all(b in valid_individual for b in backends_list):
return False

return True

register_replacement(
empty_nvfp4_gemm_prologue_pattern,
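To make the new contract concrete: `extra_check` above accepts any comma-separated combination of the four known backend names, with whitespace and empty entries tolerated. Below is a minimal, self-contained sketch of that parsing/validation rule; the helper `parse_allowed_backends` is hypothetical and not part of this PR.

```python
from typing import List

# Hypothetical helper mirroring the extra_check logic above; not part of this PR.
VALID_BACKENDS = {"cutlass", "cublaslt", "cutedsl", "cuda_core"}

def parse_allowed_backends(allowed_backends: str) -> List[str]:
    """Split a comma-separated backend string, tolerating whitespace and blank entries."""
    backends = [b.strip() for b in allowed_backends.split(",") if b.strip()]
    unknown = set(backends) - VALID_BACKENDS
    if unknown:
        raise ValueError(f"Unknown backends {sorted(unknown)}; "
                         f"valid choices are {sorted(VALID_BACKENDS)}")
    return backends

print(parse_allowed_backends("cutlass, cublaslt,,cuda_core"))
# ['cutlass', 'cublaslt', 'cuda_core']
```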
139 changes: 73 additions & 66 deletions tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -695,28 +695,36 @@ def _(
class NVFP4GemmUnifiedRunner(TunableRunner):
runner_dict = dict()

def __init__(self,
to_userbuffers: bool,
output_dtype: torch.dtype,
backend: str = "auto"):
def __init__(self, to_userbuffers: bool, output_dtype: torch.dtype,
allowed_backends: List[str]):
super().__init__()
self.to_userbuffers = to_userbuffers
self.output_dtype = output_dtype
self.backend = backend
self.allowed_backends = allowed_backends

def unique_id(self):
"""Include backend in cache key to avoid sharing cache across backends."""
return (self.to_userbuffers, self.output_dtype, self.backend)
"""Include allowed_backends in cache key to avoid sharing cache across different backend configs."""
# Convert list to tuple for hashability
allowed_tuple = tuple(self.allowed_backends)
return (self.to_userbuffers, self.output_dtype, allowed_tuple)

def _is_backend_allowed(self, backend_name: str) -> bool:
"""Check if a backend is allowed based on allowed_backends list."""
return backend_name in self.allowed_backends

def _is_only_backend(self, backend_name: str) -> bool:
"""Check if this is the only backend in allowed_backends (explicitly forced)."""
return self.allowed_backends == [backend_name]

def get_valid_tactics(self, inputs: List[torch.Tensor],
profile: OptimizationProfile,
**kwargs) -> List[Tuple]:
# return valid nvfp4 gemm implementations
# return valid nvfp4 gemm implementations from allowed_backends
tactics = []
act_fp4, weight, act_sf, weight_scale, alpha = inputs
backend = self.backend

if backend in ["auto", "cuda_core"]:
# Add CUDA Core backend if available
if self._is_backend_allowed("cuda_core"):
is_cuda_core_supported = False
m = act_fp4.shape[0]
sm_version = None
@@ -732,40 +740,39 @@ def get_valid_tactics(self, inputs: List[torch.Tensor],

if is_cuda_core_supported:
tactics.append("cuda_core")
elif backend == "cuda_core":
# Explicitly requested but conditions not met - raise error
elif self._is_only_backend("cuda_core"):
# Explicitly forced but conditions not met - raise error
error_msg = f"CUDA Core backend requires SM >= {CudaCoreNVFP4Runner.MIN_SM_VERSION} and M <= {CudaCoreNVFP4Runner.MAX_M_DIMENSION}. "
error_msg += f"Current: SM={sm_version if sm_version else 'N/A'}, M={m}. "
error_msg += "Please use backend='auto' or another backend."
error_msg += "Please add other backends to allowed_backends."
raise ValueError(error_msg)

# Add CUTLASS runner (always available)
if backend in ["auto", "cutlass"]:
if self._is_backend_allowed("cutlass"):
tactics.append("cutlass")

# Add cuBLASLt runner if available
if backend in ["auto", "cublaslt"]:
if self._is_backend_allowed("cublaslt"):
if IS_CUBLASLT_AVAILABLE:
tactics.append("cublaslt")
elif backend == "cublaslt":
elif self._is_only_backend("cublaslt"):
raise ValueError(
"cuBLASLt backend is not available. "
"Please check cuBLASLt installation or use backend='auto'.")
"Please check cuBLASLt installation or add other backends to allowed_backends."
)

# Add CuteDSL runner if available
if backend in ["auto", "cutedsl"]:
if self._is_backend_allowed("cutedsl"):
if IS_CUTLASS_DSL_AVAILABLE:
# Check SM version first - CuteDSL NVFP4 only supports SM 100 (B200)
sm_version = get_sm_version()
if sm_version not in [100, 103]:
if backend == "cutedsl":
# Explicitly requested CuteDSL but SM version not supported
if self._is_only_backend("cutedsl"):
# Explicitly forced CuteDSL but SM version not supported
raise ValueError(
f"CuteDSL NVFP4 backend requires SM 100 (B200) or SM 103 (B300), but got SM {sm_version}. "
f"CuteDSL NVFP4 is not supported on this GPU architecture. "
f"Please use backend='auto' to automatically select a compatible backend."
)
# else: backend='auto' → silently skip CuteDSL
"Please add other backends to allowed_backends.")
else:
# SM version OK, check if CuteDSL supports the current shape
from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import \
@@ -778,8 +785,8 @@ def get_valid_tactics(self, inputs: List[torch.Tensor],
if cutedsl_tactics:
# CuteDSL supports this shape
tactics.append("cutedsl")
elif backend == "cutedsl":
# Explicitly requested CuteDSL but it doesn't support this shape
elif self._is_only_backend("cutedsl"):
# Explicitly forced CuteDSL but it doesn't support this shape
m, n, k = inputs[0].shape[0], inputs[1].shape[
0], inputs[0].shape[1] * 2
raise ValueError(
@@ -788,13 +795,12 @@ def get_valid_tactics(self, inputs: List[torch.Tensor],
f"CuteDSL requires 16-byte alignment for major (contiguous) dimensions:\n"
f" - K must be divisible by 32 (FP4 K-major layout): K%32={'0✓' if k % 32 == 0 else str(k%32)+'✗'}\n"
f" - Or the combination of (M, N, K, tiling, cluster shape) is not supported\n"
f"Please use backend='auto' to automatically select a compatible backend."
)
# else: backend='auto' and CuteDSL doesn't support shape → silently skip
elif backend == "cutedsl":
f"Please add other backends to allowed_backends.")
elif self._is_only_backend("cutedsl"):
raise ValueError(
"CuteDSL backend is not available. "
"Please check CuteDSL installation or use backend='auto'.")
"Please check CuteDSL installation or add other backends to allowed_backends."
)

return tactics
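
The recurring pattern in `get_valid_tactics` is: a backend listed in `allowed_backends` is silently skipped when its requirements are not met, but raises when it is the only entry, i.e. explicitly forced. A toy sketch of that skip-unless-forced policy (hypothetical function and a made-up availability flag, not code from this PR):

```python
from typing import List

def pick_tactics(allowed_backends: List[str], cublaslt_available: bool) -> List[str]:
    """Toy model of the skip-unless-forced policy used in get_valid_tactics."""
    tactics = []
    if "cutlass" in allowed_backends:            # always buildable
        tactics.append("cutlass")
    if "cublaslt" in allowed_backends:
        if cublaslt_available:
            tactics.append("cublaslt")
        elif allowed_backends == ["cublaslt"]:   # forced but unavailable -> hard error
            raise ValueError("cuBLASLt forced but not available; "
                             "add other backends to allowed_backends.")
    return tactics

print(pick_tactics(["cutlass", "cublaslt"], cublaslt_available=False))  # ['cutlass']
```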

@@ -807,31 +813,23 @@ def forward(
) -> torch.Tensor:
act_fp4, weight, act_sf, weight_scale, alpha = inputs

requested_backend = self.backend

# If a specific backend was requested (not 'auto') and we're using fallback tactic
# This can happen on cache miss, where AutoTuner uses tactic=-1 as default
if requested_backend != 'auto' and requested_backend != tactic and tactic == -1:
# User explicitly requested a backend, but we're falling back to default
# This might happen on cache miss. We should validate the requested backend supports this shape.

# Get valid tactics for the requested backend
# Handle fallback tactic (-1) on cache miss
if tactic == -1:
# Get valid tactics and use first available
from tensorrt_llm._torch.autotuner import OptimizationProfile
valid_tactics = self.get_valid_tactics(inputs,
OptimizationProfile())

if not valid_tactics or requested_backend not in valid_tactics:
# Requested backend doesn't support this shape
if valid_tactics:
# Prefer cutlass as fallback if available, otherwise use first valid tactic
tactic = "cutlass" if "cutlass" in valid_tactics else valid_tactics[
0]
else:
m, n, k = inputs[0].shape[0], inputs[1].shape[
0], inputs[0].shape[1] * 2
raise ValueError(
f"Backend '{requested_backend}' was explicitly requested but does not support the current shape:\n"
f"No valid backends available for the current shape:\n"
f" M={m}, N={n}, K={k}\n"
f"Please use backend='auto' to automatically select a compatible backend."
)

# Backend supports it, use the requested backend instead of fallback
tactic = requested_backend
f" Allowed backends: {self.allowed_backends}")

if tactic == "cuda_core":
# Unswizzle the activation scale factors
@@ -882,20 +880,19 @@ def nvfp4_gemm(
alpha: torch.Tensor,
output_dtype: torch.dtype,
to_userbuffers: bool = False,
backend: str = "auto",
allowed_backends: str = "cutlass,cublaslt,cuda_core",
) -> torch.Tensor:
"""Unified NVFP4 GEMM with automatic or manual backend selection.
"""Unified NVFP4 GEMM with automatic backend selection.

This function can automatically choose the best backend or force a specific backend:
This function automatically chooses the best backend from the allowed list:
- CUTLASS: Predefined CUTLASS configurations with auto-tuning
- cuBLASLt: Heuristic-based algorithms from cuBLASLt library
- CuteDSL: Blackwell-optimized persistent kernels (when available and inputs are valid)
- CUDA Core: CUDA Core implementation (requires SM >= 100 and M <= 8)

The AutoTuner profiles all available backends during the first run and caches
the best choice for each input shape. Subsequent calls use the cached selection
with zero overhead. In 'auto' mode, backends are only considered if their
requirements are met (e.g., CUDA Core only participates when SM >= 100 and M <= 8).
with zero overhead.

Args:
act_fp4: Activation tensor [m, k] in FP4 format (packed in uint8)
@@ -905,12 +902,10 @@ def nvfp4_gemm(
alpha: Scaling factor (as torch.Tensor for CUTLASS/cuBLASLt compatibility)
output_dtype: Output data type
to_userbuffers: Whether to use user buffers (CUTLASS/cuBLASLt only)
backend: Backend selection, one of:
- 'auto': AutoTuner automatically selects best backend (default)
- 'cutlass': Force use CUTLASS (FP4GemmRunner)
- 'cublaslt': Force use cuBLASLt (CublasLtFP4GemmRunner)
- 'cutedsl': Force use CuteDSL (CuteDSLNVFP4Wrapper)
- 'cuda_core': Force use CUDA Core (CudaCoreNVFP4Runner, requires SM >= 100, M <= 8)
allowed_backends: Comma-separated list of backends to consider for auto-selection.
Default: "cutlass,cublaslt,cuda_core" (excludes cutedsl for faster build)
Add 'cutedsl' for extreme performance at the cost of longer build time.
Valid backends: 'cutlass', 'cublaslt', 'cutedsl', 'cuda_core'.

Returns:
Output tensor [m, n] with dtype=output_dtype
@@ -919,14 +914,26 @@ def nvfp4_gemm(
ValueError: If backend is invalid/unavailable
"""

# Validate backend parameter
valid_backends = ['auto', 'cutlass', 'cublaslt', 'cutedsl', 'cuda_core']
if backend not in valid_backends:
valid_individual_backends = {'cutlass', 'cublaslt', 'cutedsl', 'cuda_core'}

# Parse comma-separated string to list
backends_list = [
b.strip() for b in allowed_backends.split(',') if b.strip()
]

# Validate allowed_backends
invalid_backends = set(backends_list) - valid_individual_backends
if invalid_backends:
raise ValueError(
f"Invalid backends in allowed_backends: {invalid_backends}. "
f"Valid backends are: {sorted(valid_individual_backends)}.")
if not backends_list:
raise ValueError(
f"Invalid backend '{backend}'. Must be one of {valid_backends}")
f"allowed_backends cannot be empty. "
f"Valid backends are: {sorted(valid_individual_backends)}.")

# Build list of runners based on backend parameter
runner = NVFP4GemmUnifiedRunner(to_userbuffers, output_dtype, backend)
# Build runner with allowed backends
runner = NVFP4GemmUnifiedRunner(to_userbuffers, output_dtype, backends_list)

# Use AutoTuner to select best runner and tactic
    # - Compare across all allowed backends, find global optimum
@@ -966,7 +973,7 @@ def _(
alpha: torch.Tensor,
output_dtype: torch.dtype,
to_userbuffers: bool = False,
backend: str = "auto",
allowed_backends: str = "cutlass,cublaslt,cuda_core",
) -> torch.Tensor:
"""Fake implementation for torch.compile support."""
return act_fp4.new_empty((act_fp4.size(0), weight.size(0)),
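Putting the new signature together, a call site might look like the sketch below. Only the argument order and the `allowed_backends` keyword are taken from the diff above; the tensor shapes, dtypes, and scale-factor layout are placeholders, so treat this as an illustration of the interface rather than a working NVFP4 pipeline.

```python
import torch

# Dummy tensors with plausible shapes/dtypes only - they do not hold real FP4
# payloads or valid block scale factors, so this snippet illustrates the call
# signature, not an end-to-end NVFP4 GEMM.
m, n, k = 8, 4096, 4096
act_fp4 = torch.empty((m, k // 2), dtype=torch.uint8, device="cuda")        # packed FP4 activations
weight_fp4 = torch.empty((n, k // 2), dtype=torch.uint8, device="cuda")     # packed FP4 weights
act_sf = torch.empty((m, k // 16), dtype=torch.uint8, device="cuda")        # placeholder scale factors
weight_scale = torch.empty((n, k // 16), dtype=torch.uint8, device="cuda")  # placeholder scale factors
alpha = torch.tensor(1.0, device="cuda")

out = torch.ops.trtllm.nvfp4_gemm(
    act_fp4, weight_fp4, act_sf, weight_scale, alpha,
    torch.bfloat16,                                         # output_dtype
    to_userbuffers=False,
    allowed_backends="cutlass,cublaslt,cutedsl,cuda_core",  # opt in to CuteDSL as well
)
```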
5 changes: 5 additions & 0 deletions tensorrt_llm/_torch/model_config.py
@@ -102,6 +102,11 @@ class ModelConfig(Generic[TConfig]):
# If true, use low precision combine in MoE operations (only for NVFP4 quantization)
use_low_precision_moe_combine: bool = False

# NVFP4 GEMM backend configuration - list of backends to consider for auto-selection
# Default excludes 'cutedsl' for faster build time. Add 'cutedsl' for extreme perf.
nvfp4_gemm_allowed_backends: List[str] = field(
default_factory=lambda: ['cutlass', 'cublaslt', 'cuda_core'])

allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO

# If true, enable min-latency mode. Currently only used for Llama4.
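Note the type difference: `ModelConfig.nvfp4_gemm_allowed_backends` is a Python list, while `torch.ops.trtllm.nvfp4_gemm` takes a comma-separated string, so whatever wires the two together presumably joins the list (the wiring itself is not part of this diff). A trivial sketch of that mapping:

```python
# The new ModelConfig default, as a list ...
nvfp4_gemm_allowed_backends = ["cutlass", "cublaslt", "cuda_core"]

# ... and the string form expected by the custom op's allowed_backends argument.
allowed_backends_str = ",".join(nvfp4_gemm_allowed_backends)
assert allowed_backends_str == "cutlass,cublaslt,cuda_core"
```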