
Commit dcf5c86

Wong4j and hyukn authored
[None][feat] Unify nvfp4 gemm backend (#8963)
Signed-off-by: Shijie Wang <[email protected]>
Signed-off-by: Yukun He <[email protected]>
Signed-off-by: Shijie <[email protected]>
Co-authored-by: Yukun He <[email protected]>
1 parent d11acee commit dcf5c86

7 files changed: +903 −90 lines changed


tensorrt_llm/_torch/autotuner.py

Lines changed: 5 additions & 2 deletions
@@ -365,11 +365,14 @@ def search_cache(
         Returns:
             A tuple containing:
             [is_cache_hit, runner_id, tactic, stored_profile]
+            runner_id is the index in the current runners list
         """
-        for r in runners:
+        for idx, r in enumerate(runners):
             if (cache_key := self.get_cache_key(custom_op, r, input_shapes,
                                                 tuning_config)) in self.cache:
-                return True, *self.cache[cache_key]
+                # Return the current index in runners list, not the cached runner_id
+                cached_runner_id, tactic, min_time = self.cache[cache_key]
+                return True, idx, tactic, min_time

         return False, *self.fallback_entry()
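For readers skimming the diff, here is a minimal standalone sketch (hypothetical runner classes and cache keys, not the actual `AutoTuner`) of why a cached `runner_id` can go stale when a later lookup passes a shorter or reordered `runners` list, and why `search_cache` now returns the index found during the current lookup:

```python
# Minimal sketch: the cache stores (runner_id, tactic, min_time) from tuning time,
# but a later lookup may pass a runners list that is shorter or ordered differently,
# so the stored runner_id can point at the wrong runner. Returning the index found
# during *this* lookup avoids that.

cache = {
    # cache_key -> (runner_id_at_tuning_time, tactic, min_time_ms)
    ("nvfp4_gemm", "CutlassRunner", (128, 7168)): (2, (128, 64), 0.031),
}

def search_cache(custom_op, runners, input_shape):
    """Return (is_hit, runner_id, tactic, min_time); runner_id indexes `runners`."""
    for idx, runner in enumerate(runners):
        key = (custom_op, type(runner).__name__, input_shape)
        if key in cache:
            _stale_runner_id, tactic, min_time = cache[key]
            return True, idx, tactic, min_time  # idx is valid for this runners list
    return False, -1, None, float("inf")

class CutlassRunner: ...
class CuteDSLRunner: ...

# At tuning time the runner sat at index 2; now only two runners are passed in.
hit, runner_id, tactic, _ = search_cache("nvfp4_gemm",
                                         [CuteDSLRunner(), CutlassRunner()],
                                         (128, 7168))
assert hit and runner_id == 1  # index into the current list, not the stale cached 2
```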

tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py

Lines changed: 37 additions & 7 deletions
@@ -554,11 +554,24 @@ def target_scaled_mm_prologue_pattern(
     )

 def register_nvfp4_gemm_prologue(custom_pass: PatternMatcherPass):
+    act_fp4_key = KeywordArg('act_fp4')
+    weight_key = KeywordArg('weight')
+    act_sf_key = KeywordArg('act_sf')
+    weight_scale_key = KeywordArg('weight_scale')
+    alpha_key = KeywordArg('alpha')
+    output_dtype_key = KeywordArg('output_dtype')
+    to_userbuffers_key = KeywordArg('to_userbuffers')
+    backend_key = KeywordArg('backend')
     trtllm_nvfp4_gemm_default = CallFunction(
-        torch.ops.trtllm.nvfp4_gemm.default, KeywordArg('act_fp4'),
-        KeywordArg('weight'), KeywordArg('act_sf'),
-        KeywordArg('weight_scale'), KeywordArg('alpha'),
-        KeywordArg('output_dtype'))
+        torch.ops.trtllm.nvfp4_gemm.default,
+        act_fp4_key,
+        weight_key,
+        act_sf_key,
+        weight_scale_key,
+        alpha_key,
+        output_dtype_key,
+        to_userbuffers=to_userbuffers_key,
+        backend=backend_key)
     ub_copy = CallFunction(torch.ops.trtllm.copy_to_userbuffers,
                            trtllm_nvfp4_gemm_default)

@@ -569,6 +582,8 @@ def empty_nvfp4_gemm_prologue_pattern(
         weight_scale: torch.Tensor,
         alpha: torch.Tensor,
         output_dtype: torch.dtype,
+        to_userbuffers: bool,
+        backend: str,
     ):
         return

@@ -579,21 +594,36 @@ def target_nvfp4_gemm_prologue_pattern(
         weight_scale: torch.Tensor,
         alpha: torch.Tensor,
         output_dtype: torch.dtype,
+        to_userbuffers: bool,
+        backend: str,
     ):
         nvfp4_gemm_output = torch.ops.trtllm.nvfp4_gemm(
             act_fp4, weight, act_sf, weight_scale, alpha, output_dtype,
-            True)
+            True, backend)
         return nvfp4_gemm_output

-    # No extra check needed as the output dtype of nvfp4_gemm has been verified when
-    # ub_copy is inserted.
+    def extra_check(match: Match) -> bool:
+        # Validate backend value
+        backend_value = match.kwargs.get('backend')
+        if backend_value is None:
+            # No backend specified, use default - OK
+            return True
+
+        # backend should be a string literal
+        if not isinstance(backend_value, str):
+            return False
+
+        valid_backends = {'auto', 'cutlass', 'cublaslt', 'cutedsl'}
+        return backend_value in valid_backends
+
     register_replacement(
         empty_nvfp4_gemm_prologue_pattern,
         target_nvfp4_gemm_prologue_pattern,
         [],
         fwd_only,
         custom_pass,
         search_fn_pattern=ub_copy,
+        extra_check=extra_check,
     )

 def register_mm_prologue(custom_pass: PatternMatcherPass):
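The fusion above rewrites `copy_to_userbuffers(nvfp4_gemm(...))` into a single `nvfp4_gemm(..., True, backend)` call, but only when `extra_check` accepts the captured `backend` keyword. Below is a plain-Python sketch of that gate (a hypothetical free function, not the Inductor `Match` API), useful for seeing which values keep the rewrite enabled:

```python
# Standalone sketch of the gate applied before fusing nvfp4_gemm with the
# userbuffers copy. Hypothetical helper; the real check runs on a pattern Match.

VALID_NVFP4_BACKENDS = {"auto", "cutlass", "cublaslt", "cutedsl"}

def backend_kwarg_is_fusable(backend_value) -> bool:
    """Only fuse when `backend` is absent (default) or a known string literal."""
    if backend_value is None:
        return True                    # no backend given: default is fine
    if not isinstance(backend_value, str):
        return False                   # dynamic / traced value: skip fusion
    return backend_value in VALID_NVFP4_BACKENDS

assert backend_kwarg_is_fusable(None)
assert backend_kwarg_is_fusable("cutedsl")
assert not backend_kwarg_is_fusable("triton")  # unknown backend: leave graph untouched
```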

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 134 additions & 25 deletions
@@ -3,6 +3,8 @@

 import torch

+from tensorrt_llm.logger import logger
+
 from ..._utils import get_sm_version
 from ...math_utils import pad_up
 from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,

@@ -32,7 +34,7 @@
     Sm100BlockScaledPersistentDenseGemmKernel
 from ..cute_dsl_kernels.blackwell.utils import make_ptr

-class CuteDSLNVFP4BlackwellRunner(TunableRunner):
+class CuteDSLNVFP4BlackwellLinear(TunableRunner):
     kernel_class = Sm100BlockScaledPersistentDenseGemmKernel
     kernel_cache = dict()
     tuning_config = TuningConfig(

@@ -43,26 +45,44 @@ class CuteDSLNVFP4BlackwellRunner(TunableRunner):
         use_cold_l2_cache=True,
     )

-    def __init__(self, alpha: float, output_dtype: torch.dtype):
+    def __init__(self,
+                 output_dtype: torch.dtype,
+                 to_userbuffers: bool = False):
         super().__init__()
-        self.alpha = alpha
-        self.output_dtype = output_dtype
-        assert output_dtype == torch.bfloat16

-        if get_sm_version() not in [100, 103]:
+        if output_dtype != torch.bfloat16:
             raise ValueError(
-                f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100 and SM 103"
+                f"CuteDSL NVFP4 only supports bfloat16 output, got {output_dtype}"
             )
+        self.output_dtype = output_dtype
+        self.to_userbuffers = to_userbuffers

     def unique_id(self):
-        return (self.output_dtype, )
+        return (self.output_dtype, self.to_userbuffers)
+
+    def __hash__(self):
+        return hash((self.output_dtype, self.to_userbuffers))
+
+    def __eq__(self, other):
+        if not isinstance(other, self.__class__):
+            return False
+        return self.output_dtype == other.output_dtype and self.to_userbuffers == other.to_userbuffers

     def get_valid_tactics(
         self,
         inputs: List[torch.Tensor],
         profile: OptimizationProfile,
         **kwargs,
     ) -> List[Tuple[int, int]]:
+        # Early exit: Check SM version - CuteDSL NVFP4 only supports SM 100 and SM 103
+        sm_version = get_sm_version()
+        if sm_version not in [100, 103]:
+            logger.debug(
+                f"CuteDSL: SM version {sm_version} is not supported. "
+                f"CuteDSL NVFP4 only supports SM 100 (B200) and SM 103 (B300). Skipping all tactics."
+            )
+            return []
+
         assert inputs[0].dim() == 2
         assert inputs[1].dim() == 2

@@ -73,11 +93,44 @@ def get_valid_tactics(
         real_k = k * 2
         batch_size = 1
         sf_vec_size = 16
-        # m,k
+
+        # Fixed layout for FP4: A and B are always K-major
         a_major = "k"
-        # n, k
         b_major = "k"

+        # Early exit: Check K dimension alignment
+        # For K-major layout (A and B tensors), K is the major mode (contiguous dimension).
+        # 16-byte alignment requirement: K must be divisible by 32 for FP4 (128 bits / 4 bits = 32)
+        if real_k % 32 != 0:
+            logger.debug(
+                f"CuteDSL: K={real_k} does not meet 16-byte alignment requirement "
+                f"(K%32={real_k%32}, expected 0). Skipping all tactics.")
+            return []
+
+        # Optimize swap_ab candidates based on M and N alignment
+        # swap_ab=False → C is N-major → requires N%8==0 (BF16: 128 bits / 16 bits = 8)
+        # swap_ab=True → C is M-major → requires M%8==0
+        m_aligned = (m % 8 == 0)
+        n_aligned = (n % 8 == 0)
+
+        if not m_aligned and not n_aligned:
+            logger.debug(
+                f"CuteDSL: Neither M={m} nor N={n} meets 16-byte alignment "
+                f"(M%8={m%8}, N%8={n%8}). No valid C layout. Skipping all tactics."
+            )
+            return []
+
+        # Only test swap_ab values that satisfy alignment
+        swap_ab_candidates = []
+        if n_aligned:
+            swap_ab_candidates.append(False)  # N-major layout
+        if m_aligned:
+            swap_ab_candidates.append(True)  # M-major layout
+
+        logger.debug(
+            f"CuteDSL: M={m}(aligned={m_aligned}), N={n}(aligned={n_aligned}), K={real_k}(aligned=True). "
+            f"Testing swap_ab={swap_ab_candidates}")
+
         # full shamoo
         mma_tiler_mn_candidates = [
             (128, 64),

@@ -134,6 +187,9 @@ def get_valid_tactics(
                 valid_tactics.append(
                     (mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch))

+        logger.debug(
+            f"CuteDSL: Found {len(valid_tactics)} valid tactics for M={m}, N={n}, K={real_k}"
+        )
         return valid_tactics

     def make_cute_dsl_global_pointer(self, tensor: torch.Tensor, dtype,
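To make the new pruning rules in `get_valid_tactics` easier to follow, here is a small self-contained sketch (a hypothetical helper, not part of the runner class) of the alignment arithmetic: K must be a multiple of 32 because a 16-byte access holds 32 FP4 values, while the BF16 output only needs its contiguous mode (N without `swap_ab`, M with it) to be a multiple of 8:

```python
# Sketch of the tactic-pruning rules added above (hypothetical free function;
# the real logic lives inside the runner's get_valid_tactics).

from typing import List, Optional

def nvfp4_swap_ab_candidates(m: int, n: int, real_k: int) -> Optional[List[bool]]:
    """Return the swap_ab values worth profiling, or None if no 16-byte-aligned
    layout exists for this problem size."""
    # A/B are K-major FP4: one 16-byte (128-bit) access covers 32 FP4 elements.
    if real_k % 32 != 0:
        return None
    # C is BF16: one 16-byte access covers 8 elements along the contiguous mode,
    # which is N when swap_ab=False (N-major C) and M when swap_ab=True (M-major C).
    candidates = []
    if n % 8 == 0:
        candidates.append(False)
    if m % 8 == 0:
        candidates.append(True)
    return candidates or None

print(nvfp4_swap_ab_candidates(m=12, n=7168, real_k=1536))   # [False] -> only N-major C
print(nvfp4_swap_ab_candidates(m=128, n=7168, real_k=1536))  # [False, True]
print(nvfp4_swap_ab_candidates(m=3, n=5, real_k=1536))       # None -> skip CuteDSL
```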
@@ -149,6 +205,7 @@ def forward(
         self,
         inputs: List[torch.Tensor],
         tactic,
+        **kwargs,
     ) -> torch.Tensor:
         """
         Performs fp8 blockwise gemm operation using CuTe DSL.

@@ -160,8 +217,7 @@ def forward(
             inputs[2]: Input scale tensor of shape (k//16, m), dtype: fp8.
             inputs[3]: Weight scale tensor of shape (n, k//16), dtype: fp8.
             inputs[4]: Alpha scaling factor. dtype: float32.
-            inputs[5]: Output dtype, expected to be torch.bfloat16.
-            tactic: Tiling and cluster strategy, typically a tuple (mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch).
+            tactic: Tiling and cluster strategy, typically a tuple (mma_tiler_mn, cluster_shape_mn).

         Returns:
             torch.Tensor: Output tensor of shape (m, n), dtype: bf16.

@@ -179,11 +235,17 @@ def forward(
             False,
         ]

-        a_tensor, b_tensor, a_sf_tensor, b_sf_tensor = inputs
+        a_tensor, b_tensor, a_sf_tensor, b_sf_tensor, alpha_tensor = inputs
         m, k, n = a_tensor.shape[0], a_tensor.shape[1], b_tensor.shape[0]
-        c_tensor = torch.empty(*(m, n),
-                               dtype=self.output_dtype,
-                               device="cuda")
+
+        # Allocate output tensor from UserBuffers or regular CUDA memory
+        if self.to_userbuffers:
+            c_tensor = torch.ops.trtllm.create_userbuffers_tensor(
+                [m, n], self.output_dtype)
+        else:
+            c_tensor = torch.empty(*(m, n),
+                                   dtype=self.output_dtype,
+                                   device="cuda")

         if swap_ab:
             c_tensor = c_tensor.permute(1, 0)

@@ -193,9 +255,27 @@ def forward(
         sf_k = pad_up(real_k // sf_vec_size, 4)
         sf_n = pad_up(n, 128)

-        # the scaling tensor is 1D. we need to make sure it has been padded to the correct shape
-        assert a_sf_tensor.shape == (sf_m * sf_k, )
-        assert b_sf_tensor.shape == (sf_n * sf_k, )
+        # Reshape scale factors to CuteDSL's expected format
+        # Input format (from CUTLASS/cuBLASLt): (m*k//16,) and (n*k//16,)
+        # CuteDSL format: (sf_m*sf_k,) and (sf_n*sf_k,)
+        # Note: This is just a view change, no memory copy
+        expected_a_sf_size = sf_m * sf_k
+        expected_b_sf_size = sf_n * sf_k
+
+        if a_sf_tensor.numel() != expected_a_sf_size:
+            raise ValueError(
+                f"CuteDSL: act scale factor size mismatch. "
+                f"Expected {expected_a_sf_size} (sf_m={sf_m} * sf_k={sf_k}), "
+                f"got {a_sf_tensor.numel()} for shape M={m}, K={real_k}")
+        if b_sf_tensor.numel() != expected_b_sf_size:
+            raise ValueError(
+                f"CuteDSL: weight scale factor size mismatch. "
+                f"Expected {expected_b_sf_size} (sf_n={sf_n} * sf_k={sf_k}), "
+                f"got {b_sf_tensor.numel()} for shape N={n}, K={real_k}")
+
+        # Reshape to CuteDSL's expected format (just a view, no copy)
+        a_sf_tensor = a_sf_tensor.reshape(sf_m * sf_k)
+        b_sf_tensor = b_sf_tensor.reshape(sf_n * sf_k)

         a_ptr = self.make_cute_dsl_global_pointer(a_tensor,
                                                   cutlass.Float4E2M1FN, 32)

@@ -207,6 +287,9 @@ def forward(
             b_sf_tensor, cutlass.Float8E4M3FN, 16)
         c_ptr = self.make_cute_dsl_global_pointer(c_tensor,
                                                   cutlass.BFloat16, 16)
+        # Create pointer to alpha on device
+        alpha_ptr = self.make_cute_dsl_global_pointer(
+            alpha_tensor, cutlass.Float32, 4)

         # get stream
         torch_stream = torch.cuda.current_stream()

@@ -259,7 +342,7 @@ def forward(
             kernel_a_sf_ptr,
             kernel_b_sf_ptr,
             c_ptr,
-            self.alpha,
+            alpha_ptr,  # Pass alpha as device pointer
             max_active_clusters,
             stream,
             swap_ab,

@@ -283,7 +366,7 @@ def forward(
             kernel_a_sf_ptr,
             kernel_b_sf_ptr,
             c_ptr,
-            self.alpha,
+            alpha_ptr,  # Pass alpha as device pointer
             stream,
         )

@@ -300,20 +383,45 @@ def cute_dsl_nvfp4_gemm_blackwell(
     weight: torch.Tensor,
     input_scale: torch.Tensor,
     weight_scale: torch.Tensor,
-    alpha: float,
+    alpha: torch.Tensor,
     output_dtype: torch.dtype,
+    to_userbuffers: bool = False,
 ) -> torch.Tensor:
+    """CuteDSL-based NVFP4 GEMM optimized for Blackwell.
+
+    Args:
+        input: Activation tensor [m, k] in FP4 format (packed in uint8)
+        weight: Weight tensor [n, k] in FP4 format (packed in uint8)
+        input_scale: Activation scale factors
+        weight_scale: Weight scale factors
+        alpha: Scaling factor
+        output_dtype: Output data type (must be bfloat16)
+        to_userbuffers: Whether to allocate output from UserBuffers pool
+
+    Note:
+        This function is primarily used internally by nvfp4_gemm.
+        Direct usage is discouraged. Consider using nvfp4_gemm instead
+        for automatic backend selection with better performance.
+    """
+    # Validate SM version before attempting to use CuteDSL
+    sm_version = get_sm_version()
+    if sm_version not in [100, 103]:
+        raise ValueError(
+            f"CuteDSL NVFP4 backend requires SM 100 (B200) or SM 103 (B300), but got SM {sm_version}. "
+            f"Please use nvfp4_gemm with backend='auto' for automatic backend selection."
+        )

     tuner = AutoTuner.get()

-    runner = CuteDSLNVFP4BlackwellRunner(alpha, output_dtype)
-    inputs = [input, weight, input_scale, weight_scale]
+    runner = CuteDSLNVFP4BlackwellLinear(output_dtype, to_userbuffers)
+    inputs = [input, weight, input_scale, weight_scale, alpha]
     _, best_tactic = tuner.choose_one(
         "trtllm::cute_dsl_nvfp4_gemm_blackwell",
         [runner],
         runner.__class__.tuning_config,
         inputs,
     )
+
     output = runner(inputs, tactic=best_tactic)
     return output

@@ -323,8 +431,9 @@ def _(
     mat_b: torch.Tensor,
     input_scale: torch.Tensor,
     weight_scale: torch.Tensor,
-    alpha: float,
+    alpha: torch.Tensor,  # Match custom op signature
     output_dtype: torch.dtype,
+    to_userbuffers: bool = False,
 ):
     # [m, k]
     shape = list(mat_a.shape)