
Commit 347515a

fix bug
Signed-off-by: Shijie Wang <[email protected]>
1 parent 491d2ea commit 347515a

File tree: 2 files changed (+28, -17 lines)


tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 26 additions & 16 deletions
@@ -475,10 +475,10 @@ def forward(
 
 class CudaCoreNVFP4Runner(TunableRunner):
     """
-    CUDA Core-based NVFP4 GEMM runner on modern architectures.
+    CUDA Core-based NVFP4 GEMM runner.
 
     This runner is available on:
-    - SM >= 100 (Blackwell and newer architectures)
+    - SM >= 100 (Blackwell)
     - M <= 8 (small batch size limitation from kernel template)
     """
 
@@ -836,11 +836,12 @@ def nvfp4_gemm_unified(
     - CUTLASS: Predefined CUTLASS configurations with auto-tuning
     - cuBLASLt: Heuristic-based algorithms from cuBLASLt library
     - CuteDSL: Blackwell-optimized persistent kernels (when available and inputs are valid)
-    - CUDA Core: CUDA Core implementation on SM >= 89 (Ada+), M <= 16 (explicit selection only)
+    - CUDA Core: CUDA Core implementation (requires SM >= 100 and M <= 8)
 
     The AutoTuner profiles all available backends during the first run and caches
     the best choice for each input shape. Subsequent calls use the cached selection
-    with zero overhead.
+    with zero overhead. In 'auto' mode, backends are only considered if their
+    requirements are met (e.g., CUDA Core only participates when SM >= 100 and M <= 8).
 
     Args:
         act_fp4: Activation tensor [m, k] in FP4 format (packed in uint8)
@@ -855,7 +856,7 @@ def nvfp4_gemm_unified(
         - 'cutlass': Force use CUTLASS (FP4GemmRunner)
         - 'cublaslt': Force use cuBLASLt (CublasLtFP4GemmRunner)
         - 'cutedsl': Force use CuteDSL (CuteDSLNVFP4Wrapper)
-        - 'cuda_core': Force use CUDA Core (CudaCoreNVFP4Runner, requires SM >= 89, M <= 16)
+        - 'cuda_core': Force use CUDA Core (CudaCoreNVFP4Runner, requires SM >= 100, M <= 8)
 
     Returns:
         Output tensor [m, n] with dtype=output_dtype
@@ -873,25 +874,34 @@ def nvfp4_gemm_unified(
     # Build list of runners based on backend parameter
     runners = []
 
-    # CUDA Core runner can be enabled via backend='cuda_core' (not in auto mode by default)
-    # This avoids AutoTuner cache incompatibility issues
-    if backend == "cuda_core":
-        # Check if architecture is supported (SM >= 89)
+    # Add CUDA Core runner if conditions are met
+    # Only instantiate when both SM version and M dimension requirements are satisfied
+    if backend in ["auto", "cuda_core"]:
         is_cuda_core_supported = False
+        m = act_fp4.shape[0]
+        sm_version = None
+
         if torch.cuda.is_available():
             capability = torch.cuda.get_device_capability(
                 torch.device('cuda:0'))
             sm_version = capability[0] * 10 + capability[1]
-            is_cuda_core_supported = sm_version >= CudaCoreNVFP4Runner.MIN_SM_VERSION
+            # Check both SM version and M dimension constraints
+            is_cuda_core_supported = (
+                sm_version >= CudaCoreNVFP4Runner.MIN_SM_VERSION
+                and m <= CudaCoreNVFP4Runner.MAX_M_DIMENSION)
 
         if is_cuda_core_supported:
             runners.append(CudaCoreNVFP4Runner(to_userbuffers, output_dtype))
-            logger.debug("CUDA Core runner added to nvfp4_gemm_unified")
-        else:
-            raise ValueError(
-                f"CUDA Core backend requires SM >= {CudaCoreNVFP4Runner.MIN_SM_VERSION} (Ada or newer). "
-                f"Current SM version: {sm_version if torch.cuda.is_available() else 'N/A'}. "
-                f"Please use backend='auto' or another backend.")
+            logger.debug(
+                f"CUDA Core runner added to nvfp4_gemm_unified (SM={sm_version}, M={m})"
+            )
+        elif backend == "cuda_core":
+            # Explicitly requested but conditions not met - raise error
+            error_msg = f"CUDA Core backend requires SM >= {CudaCoreNVFP4Runner.MIN_SM_VERSION} and M <= {CudaCoreNVFP4Runner.MAX_M_DIMENSION}. "
+            error_msg += f"Current: SM={sm_version if sm_version else 'N/A'}, M={m}. "
+            error_msg += "Please use backend='auto' or another backend."
+            raise ValueError(error_msg)
+        # For auto mode: silently skip if conditions not met
 
     # Add CUTLASS runner (always available)
     if backend in ["auto", "cutlass"]:

tests/unittest/_torch/thop/parallel/test_fp4_linear.py

Lines changed: 2 additions & 1 deletion
@@ -39,7 +39,8 @@ def test_fp4_linear(dtype, mnk):
                      out_features=OUTPUT_SIZE,
                      bias=False,
                      dtype=dtype,
-                     quant_config=qc)
+                     quant_config=qc,
+                     nvfp4_backend='cutlass')  # Force CUTLASS to match reference
 
     assert l_fp4.weight.dtype == fp4_utils.float4_e2m1x2
     assert l_fp4.weight_scale.dtype == fp4_utils.float4_sf_dtype
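
The test now pins nvfp4_backend='cutlass' so the comparison against the reference stays deterministic: with 'auto', the tuner may pick a different kernel whose output matches the reference only within a tolerance, not bit-for-bit. An illustrative, standalone demonstration of that effect using plain PyTorch (no TensorRT-LLM APIs; the shapes are arbitrary):

    import torch

    # Two valid implementations of the same GEMM generally agree only within a
    # tolerance, because they accumulate in different orders/precisions.
    torch.manual_seed(0)
    a = torch.randn(8, 64)
    b = torch.randn(64, 32)
    ref = (a.double() @ b.double()).float()  # higher-precision reference, rounded
    out = a @ b                              # plain float32 kernel
    print("allclose:", torch.allclose(out, ref, atol=1e-4))  # True
    print("bit-identical:", torch.equal(out, ref))           # typically False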
