Skip to content

Commit 8364678

Browse files
committed
Support to_userbuffers in CuteDSL and fix nvfp4_gemm pattern matching
Signed-off-by: Shijie Wang <[email protected]>
1 parent 22d3130 commit 8364678

File tree

3 files changed

+59
-13
lines changed

3 files changed

+59
-13
lines changed

tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py

Lines changed: 50 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -554,11 +554,24 @@ def target_scaled_mm_prologue_pattern(
554554
)
555555

556556
def register_nvfp4_gemm_prologue(custom_pass: PatternMatcherPass):
557+
act_fp4_key = KeywordArg('act_fp4')
558+
weight_key = KeywordArg('weight')
559+
act_sf_key = KeywordArg('act_sf')
560+
weight_scale_key = KeywordArg('weight_scale')
561+
alpha_key = KeywordArg('alpha')
562+
output_dtype_key = KeywordArg('output_dtype')
563+
to_userbuffers_key = KeywordArg('to_userbuffers')
564+
backend_key = KeywordArg('backend')
557565
trtllm_nvfp4_gemm_default = CallFunction(
558-
torch.ops.trtllm.nvfp4_gemm.default, KeywordArg('act_fp4'),
559-
KeywordArg('weight'), KeywordArg('act_sf'),
560-
KeywordArg('weight_scale'), KeywordArg('alpha'),
561-
KeywordArg('output_dtype'))
566+
torch.ops.trtllm.nvfp4_gemm.default,
567+
act_fp4_key,
568+
weight_key,
569+
act_sf_key,
570+
weight_scale_key,
571+
alpha_key,
572+
output_dtype_key,
573+
to_userbuffers=to_userbuffers_key,
574+
backend=backend_key)
562575
ub_copy = CallFunction(torch.ops.trtllm.copy_to_userbuffers,
563576
trtllm_nvfp4_gemm_default)
564577

@@ -569,6 +582,8 @@ def empty_nvfp4_gemm_prologue_pattern(
569582
weight_scale: torch.Tensor,
570583
alpha: torch.Tensor,
571584
output_dtype: torch.dtype,
585+
to_userbuffers: bool,
586+
backend: str,
572587
):
573588
return
574589

@@ -579,18 +594,45 @@ def target_nvfp4_gemm_prologue_pattern(
579594
weight_scale: torch.Tensor,
580595
alpha: torch.Tensor,
581596
output_dtype: torch.dtype,
597+
to_userbuffers: bool,
598+
backend: str,
582599
):
583600
nvfp4_gemm_output = torch.ops.trtllm.nvfp4_gemm(
584601
act_fp4, weight, act_sf, weight_scale, alpha, output_dtype,
585-
True)
602+
True, backend)
586603
return nvfp4_gemm_output
587604

588-
# No extra check needed as the output dtype of nvfp4_gemm has been verified when
589-
# ub_copy is inserted.
605+
def extra_check(match: Match) -> bool:
    """Accept a match only when its ``backend`` argument is a known backend.

    The value is read from ``match.kwargs['backend']`` and handled as
    follows:

    * missing / ``None``: accepted, no backend was specified;
    * a plain ``str``: accepted iff it is one of the known backend names;
    * an object exposing ``meta['val']``: accepted iff that value is a
      ``str`` naming a known backend;
    * anything else: rejected.

    An earlier revision also compared the object's ``target`` attribute
    against the backend names. That comparison is dropped here.
    NOTE(review): for a graph node, ``target`` appears to identify the
    op / placeholder / attribute name rather than the argument's value, so
    a node that merely happened to be *named* like a backend would have
    been accepted with an unknown runtime value -- confirm against the
    torch.fx ``Node`` documentation.

    Args:
        match: The pattern-matcher match being validated; only its
            ``kwargs`` mapping is consulted.

    Returns:
        ``True`` if the replacement may be applied, ``False`` otherwise.
    """
    backend = match.kwargs.get('backend')
    if backend is None:
        # No backend specified, use default - OK
        return True

    valid_backends = {'auto', 'cutlass', 'cublaslt', 'cutedsl'}

    # Python literal passed directly as the keyword argument.
    if isinstance(backend, str):
        return backend in valid_backends

    # Object carrying its value in ``meta['val']``.
    meta = getattr(backend, 'meta', None)
    if meta is not None and 'val' in meta:
        value = meta['val']
        # The isinstance guard also keeps unhashable values away from the
        # set membership test.
        return isinstance(value, str) and value in valid_backends

    # Value cannot be determined statically - reject to be safe.
    return False
631+
590632
register_replacement(
591633
empty_nvfp4_gemm_prologue_pattern,
592634
target_nvfp4_gemm_prologue_pattern,
593-
[],
635+
[extra_check],
594636
fwd_only,
595637
custom_pass,
596638
search_fn_pattern=ub_copy,

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -242,8 +242,7 @@ def forward(
242242

243243
# Allocate output tensor from UserBuffers or regular CUDA memory
244244
if self.to_userbuffers:
245-
from tensorrt_llm.bindings import torch_ext
246-
c_tensor, _ = torch_ext.create_userbuffers_tensor(
245+
c_tensor, _ = torch.ops.trtllm.create_userbuffers_tensor(
247246
[m, n], self.output_dtype)
248247
else:
249248
c_tensor = torch.empty(*(m, n),

tensorrt_llm/_torch/modules/linear.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -918,9 +918,14 @@ def apply(self, module: Linear, input: torch.Tensor,
918918
backend = getattr(module, 'nvfp4_backend', 'auto')
919919

920920
# Use unified interface - supports CUTLASS, cuBLASLt, CuteDSL
921-
output = torch.ops.trtllm.nvfp4_gemm(act_fp4, module.weight, act_sf,
922-
module.weight_scale, module.alpha,
923-
module.dtype, False, backend)
921+
output = torch.ops.trtllm.nvfp4_gemm(act_fp4,
922+
module.weight,
923+
act_sf,
924+
module.weight_scale,
925+
module.alpha,
926+
module.dtype,
927+
to_userbuffers=False,
928+
backend=backend)
924929
# Take the dim of out_features if padded. Make sure the output is contiguous
925930
if output.shape[-1] > module.out_features:
926931
output = output[..., :module.out_features].contiguous()

0 commit comments

Comments
 (0)