Skip to content

Commit 205d297

Browse files
committed
resolve review comments
Signed-off-by: Shijie Wang <[email protected]>
1 parent e2b8e24 commit 205d297

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -50,8 +50,6 @@ def __init__(self,
5050
to_userbuffers: bool = False):
5151
super().__init__()
5252

53-
54-
5553
if output_dtype != torch.bfloat16:
5654
raise ValueError(
5755
f"CuteDSL NVFP4 only supports bfloat16 output, got {output_dtype}"
@@ -242,7 +240,7 @@ def forward(
242240

243241
# Allocate output tensor from UserBuffers or regular CUDA memory
244242
if self.to_userbuffers:
245-
c_tensor, _ = torch.ops.trtllm.create_userbuffers_tensor(
243+
c_tensor = torch.ops.trtllm.create_userbuffers_tensor(
246244
[m, n], self.output_dtype)
247245
else:
248246
c_tensor = torch.empty(*(m, n),

tensorrt_llm/_torch/modules/linear.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -914,9 +914,6 @@ def apply(self, module: Linear, input: torch.Tensor,
914914
act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
915915
input, module.input_scale, module.scaling_vector_size, False)
916916

917-
# Backend selection: 'auto' (default) | 'cutlass' | 'cublaslt' | 'cutedsl'
918-
backend = getattr(module, 'nvfp4_backend', 'auto')
919-
920917
# Use unified interface - supports CUTLASS, cuBLASLt, CuteDSL
921918
output = torch.ops.trtllm.nvfp4_gemm(act_fp4,
922919
module.weight,
@@ -925,7 +922,7 @@ def apply(self, module: Linear, input: torch.Tensor,
925922
module.alpha,
926923
module.dtype,
927924
to_userbuffers=False,
928-
backend=backend)
925+
backend=module.nvfp4_backend)
929926
# Take the dim of out_features if padded. Make sure the output is contiguous
930927
if output.shape[-1] > module.out_features:
931928
output = output[..., :module.out_features].contiguous()
@@ -2000,6 +1997,12 @@ def __init__(
20001997
fused_weight_shard_indices_mapping: Optional[dict] = None,
20011998
nvfp4_backend: str = "auto",
20021999
):
2000+
"""
2001+
Args:
2002+
nvfp4_backend: Backend selection for NVFP4 GEMM operations.
2003+
Supported values: "auto", "cutlass", "cublaslt", "cutedsl".
2004+
Default is "auto" which automatically selects the best backend.
2005+
"""
20032006
from ..distributed import AllReduce
20042007

20052008
super().__init__()

0 commit comments

Comments
 (0)