
Commit 32e7deb

refactor and fix bug
Signed-off-by: Shijie Wang <[email protected]>
1 parent f2b255e commit 32e7deb

4 files changed: +301 -215 lines changed


tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 70 additions & 10 deletions
@@ -3,6 +3,8 @@
 
 import torch
 
+from tensorrt_llm.logger import logger
+
 from ..._utils import get_sm_version
 from ...math_utils import pad_up
 from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
@@ -80,11 +82,48 @@ def get_valid_tactics(
         real_k = k * 2
         batch_size = 1
         sf_vec_size = 16
-        # m,k
+
+        # Fixed layout for FP4: A and B are always K-major
         a_major = "k"
-        # n, k
         b_major = "k"
 
+        # Data types
+        ab_dtype = cutlass.Float4E2M1FN
+        c_dtype = cutlass.BFloat16
+
+        # Early exit: Check K dimension alignment
+        # For K-major layout (A and B tensors), K is the major mode (contiguous dimension).
+        # 16-byte alignment requirement: K must be divisible by 32 for FP4 (128 bits / 4 bits = 32)
+        if real_k % 32 != 0:
+            logger.debug(
+                f"CuteDSL: K={real_k} does not meet 16-byte alignment requirement "
+                f"(K%32={real_k%32}, expected 0). Skipping all tactics.")
+            return []
+
+        # Optimize swap_ab candidates based on M and N alignment
+        # swap_ab=False → C is N-major → requires N%8==0 (BF16: 128 bits / 16 bits = 8)
+        # swap_ab=True → C is M-major → requires M%8==0
+        m_aligned = (m % 8 == 0)
+        n_aligned = (n % 8 == 0)
+
+        if not m_aligned and not n_aligned:
+            logger.debug(
+                f"CuteDSL: Neither M={m} nor N={n} meets 16-byte alignment "
+                f"(M%8={m%8}, N%8={n%8}). No valid C layout. Skipping all tactics."
+            )
+            return []
+
+        # Only test swap_ab values that satisfy alignment
+        swap_ab_candidates = []
+        if n_aligned:
+            swap_ab_candidates.append(False)  # N-major layout
+        if m_aligned:
+            swap_ab_candidates.append(True)  # M-major layout
+
+        logger.debug(
+            f"CuteDSL: M={m}(aligned={m_aligned}), N={n}(aligned={n_aligned}), K={real_k}(aligned=True). "
+            f"Testing swap_ab={swap_ab_candidates}")
+
         # full shamoo
         mma_tiler_mn_candidates = [
             (256, 128),
@@ -105,7 +144,6 @@ def get_valid_tactics(
             (4, 2),
             (4, 4),
         ]
-        swap_ab_candidates = [True, False]
 
         valid_tactics = []
         for swap_ab in swap_ab_candidates:
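For reference, the early-exit checks added in the two hunks above all follow from the same 16-byte (128-bit) alignment requirement on the contiguous (major) dimension: FP4 elements are 4 bits, so the K-major A/B operands need K divisible by 128/4 = 32, while BF16 elements are 16 bits, so the C output needs its major dimension (N for swap_ab=False, M for swap_ab=True) divisible by 128/16 = 8. A minimal standalone sketch of that rule; min_contiguous_elems and cutedsl_shape_ok are illustrative helpers, not TensorRT-LLM APIs:

# Illustrative sketch only: these helpers are not part of the codebase.
def min_contiguous_elems(dtype_bits: int, alignment_bytes: int = 16) -> int:
    """Elements needed so the contiguous (major) dimension is 16-byte aligned."""
    return alignment_bytes * 8 // dtype_bits


FP4_BITS, BF16_BITS = 4, 16


def cutedsl_shape_ok(m: int, n: int, k: int) -> bool:
    # A/B are K-major FP4: K must be a multiple of 128 / 4 = 32.
    if k % min_contiguous_elems(FP4_BITS) != 0:
        return False
    # C is BF16: at least one of the N-major (swap_ab=False) or M-major
    # (swap_ab=True) layouts must be aligned, i.e. a multiple of 128 / 16 = 8.
    return (n % min_contiguous_elems(BF16_BITS) == 0
            or m % min_contiguous_elems(BF16_BITS) == 0)


print(cutedsl_shape_ok(m=128, n=2880, k=7168))  # True
print(cutedsl_shape_ok(m=3, n=5, k=7168))       # False: neither M nor N is a multiple of 8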
@@ -120,11 +158,12 @@ def get_valid_tactics(
                         kernel_m = m
                         kernel_n = n
 
+                    # Use can_implement to check all constraints
                     if Sm100BlockScaledPersistentDenseGemmKernel.can_implement(
-                            cutlass.Float4E2M1FN,  # ab_dtype,
+                            ab_dtype,
                             cutlass.Float8E4M3FN,  # sf_dtype
-                            sf_vec_size,  # sf_vec_size,
-                            cutlass.BFloat16,  # c_dtype,
+                            sf_vec_size,
+                            c_dtype,
                             mma_tiler_mn,
                             cluster_shape_mn,
                             kernel_m,
@@ -138,6 +177,9 @@ def get_valid_tactics(
                         valid_tactics.append(
                             (mma_tiler_mn, cluster_shape_mn, swap_ab))
 
+        logger.debug(
+            f"CuteDSL: Found {len(valid_tactics)} valid tactics for M={m}, N={n}, K={real_k}"
+        )
         return valid_tactics
 
     def make_cute_dsl_global_pointer(self, tensor: torch.Tensor, dtype,
@@ -196,9 +238,27 @@ def forward(
         sf_k = pad_up(real_k // sf_vec_size, 4)
         sf_n = pad_up(n, 128)
 
-        # the scaling tensor is 1D. we need to make sure it has been padded to the correct shape
-        assert a_sf_tensor.shape == (sf_m * sf_k, )
-        assert b_sf_tensor.shape == (sf_n * sf_k, )
+        # Reshape scale factors to CuteDSL's expected format
+        # Input format (from CUTLASS/cuBLASLt): (m*k//16,) and (n*k//16,)
+        # CuteDSL format: (sf_m*sf_k,) and (sf_n*sf_k,)
+        # Note: This is just a view change, no memory copy
+        expected_a_sf_size = sf_m * sf_k
+        expected_b_sf_size = sf_n * sf_k
+
+        if a_sf_tensor.numel() != expected_a_sf_size:
+            raise ValueError(
+                f"CuteDSL: act scale factor size mismatch. "
+                f"Expected {expected_a_sf_size} (sf_m={sf_m} * sf_k={sf_k}), "
+                f"got {a_sf_tensor.numel()} for shape M={m}, K={real_k}")
+        if b_sf_tensor.numel() != expected_b_sf_size:
+            raise ValueError(
+                f"CuteDSL: weight scale factor size mismatch. "
+                f"Expected {expected_b_sf_size} (sf_n={sf_n} * sf_k={sf_k}), "
+                f"got {b_sf_tensor.numel()} for shape N={n}, K={real_k}")
+
+        # Reshape to CuteDSL's expected format (just a view, no copy)
+        a_sf_tensor = a_sf_tensor.reshape(sf_m * sf_k)
+        b_sf_tensor = b_sf_tensor.reshape(sf_n * sf_k)
 
         a_ptr = self.make_cute_dsl_global_pointer(a_tensor,
                                                   cutlass.Float4E2M1FN, 32)
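To make the new size check concrete: with sf_vec_size=16 there is one FP8 scale per 16 elements along K, sf_k pads that count to a multiple of 4, and sf_n pads N to a multiple of 128, so the expected flat sizes fall out of pad_up directly. A small worked example, assuming pad_up(x, y) rounds x up to the nearest multiple of y; sf_m is computed earlier in forward, outside this hunk, so the value used below is only an illustrative assumption:

# Worked example of the expected scale-factor sizes validated above.
# Assumes pad_up(x, y) rounds x up to the nearest multiple of y.
def pad_up(x: int, y: int) -> int:
    return (x + y - 1) // y * y


m, n, real_k = 128, 2880, 7168
sf_vec_size = 16

sf_k = pad_up(real_k // sf_vec_size, 4)  # 7168 // 16 = 448, already a multiple of 4
sf_n = pad_up(n, 128)                    # 2880 -> 2944
sf_m = pad_up(m, 128)                    # illustrative assumption for this example

print(sf_k, sf_n, sf_m)                              # 448 2944 128
print("expected act scale numel:", sf_m * sf_k)      # 128 * 448 = 57344
print("expected weight scale numel:", sf_n * sf_k)   # 2944 * 448 = 1318912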
@@ -328,7 +388,7 @@ def cute_dsl_nvfp4_gemm_blackwell(
         "trtllm::cute_dsl_nvfp4_gemm_blackwell",
         [runner],
         CuteDSLNVFP4BlackwellLinear.tuning_config,
-        [input, weight, input_scale, weight_scale],
+        [input, weight, input_scale, weight_scale, alpha],
     )
 
     return runner(

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 55 additions & 3 deletions
@@ -1,5 +1,5 @@
 from functools import lru_cache
-from typing import List, Mapping, Optional, Tuple
+from typing import List, Mapping, Optional, Tuple, Union
 
 import torch
 import triton  # type: ignore[import]
@@ -707,7 +707,29 @@ def get_valid_tactics(self,
         # Add CuteDSL runner if available
         if backend in ["auto", "cutedsl"]:
             if IS_CUTLASS_DSL_AVAILABLE:
-                tactics.append("cutedsl")
+                # Check if CuteDSL actually supports the current shape
+                from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import \
+                    CuteDSLNVFP4BlackwellLinear
+                cutedsl_runner = CuteDSLNVFP4BlackwellLinear(self.output_dtype)
+                cutedsl_tactics = cutedsl_runner.get_valid_tactics(
+                    inputs, profile)
+
+                if cutedsl_tactics:
+                    # CuteDSL supports this shape
+                    tactics.append("cutedsl")
+                elif backend == "cutedsl":
+                    # Explicitly requested CuteDSL but it doesn't support this shape
+                    m, n, k = inputs[0].shape[0], inputs[1].shape[
+                        0], inputs[0].shape[1] * 2
+                    raise ValueError(
+                        f"CuteDSL backend does not support the current shape:\n"
+                        f"  M={m}, N={n}, K={k}\n"
+                        f"CuteDSL requires 16-byte alignment for major (contiguous) dimensions:\n"
+                        f"  - K must be divisible by 32 (FP4 K-major layout): K%32={'0✓' if k % 32 == 0 else str(k%32)+'✗'}\n"
+                        f"  - Or the combination of (M, N, K, tiling, cluster shape) is not supported\n"
+                        f"Please use backend='auto' to automatically select a compatible backend."
+                    )
+                # else: backend='auto' and CuteDSL doesn't support → silently skip
             elif backend == "cutedsl":
                 raise ValueError(
                     "CuteDSL backend is not available. "
@@ -718,11 +740,40 @@ def get_valid_tactics(self,
     def forward(
         self,
         inputs: List[torch.Tensor],
-        tactic: str = "cutlass",
+        tactic: Union[
+            str, int] = "cutlass",  # str: backend name, or int: -1 for fallback
         **kwargs,
     ) -> torch.Tensor:
         act_fp4, weight, act_sf, weight_scale, alpha = inputs
 
+        # Check if a specific backend was requested
+        requested_backend = kwargs.get('backend', 'auto')
+
+        # A specific backend was requested (not 'auto') but we received the fallback tactic.
+        # This can happen on a cache miss, where AutoTuner uses tactic=-1 as the default.
+        if requested_backend != 'auto' and requested_backend != tactic and tactic == -1:
+            # The user explicitly requested a backend, but we are falling back to the default.
+            # Validate that the requested backend supports this shape.
+
+            # Get valid tactics for the requested backend
+            from tensorrt_llm._torch.autotuner import OptimizationProfile
+            valid_tactics = self.get_valid_tactics(inputs,
+                                                   OptimizationProfile(),
+                                                   backend=requested_backend)
+
+            if not valid_tactics or requested_backend not in valid_tactics:
+                # Requested backend doesn't support this shape
+                m, n, k = inputs[0].shape[0], inputs[1].shape[
+                    0], inputs[0].shape[1] * 2
+                raise ValueError(
+                    f"Backend '{requested_backend}' was explicitly requested but does not support the current shape:\n"
+                    f"  M={m}, N={n}, K={k}\n"
+                    f"Please use backend='auto' to automatically select a compatible backend."
+                )
+
+            # The backend supports this shape; use it instead of the fallback
+            tactic = requested_backend
+
         if tactic == "cuda_core":
             # Unswizzle the activation scale factors
             # act_sf is swizzled, need to reverse it for cuda_core_nvfp4_gemm
@@ -844,6 +895,7 @@ def nvfp4_gemm_unified(
     return runner(
         inputs=[act_fp4, weight, act_sf, weight_scale, alpha],
         tactic=best_tactic,
+        backend=backend,
     )
 
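The forward change above guards a corner case: on an AutoTuner cache miss the tactic argument arrives as the fallback value -1 even when the caller pinned a specific backend through the new backend kwarg, so the requested backend must either be re-validated and substituted for the fallback, or rejected with a clear error. A simplified, self-contained sketch of that decision flow; resolve_tactic and the callable it takes are illustrative stand-ins, not the repository's API:

# Simplified sketch of the fallback handling added to forward(); not the repo's API.
from typing import Callable, List, Union


def resolve_tactic(tactic: Union[str, int], requested_backend: str,
                   valid_tactics_fn: Callable[[], List[str]]) -> Union[str, int]:
    # Cache miss: the tuner handed back the fallback tactic (-1) although the
    # caller pinned a specific backend, so re-validate that backend here.
    if requested_backend != "auto" and tactic == -1:
        if requested_backend not in valid_tactics_fn():
            raise ValueError(
                f"Backend '{requested_backend}' was explicitly requested but "
                f"does not support the current shape; use backend='auto'.")
        return requested_backend  # honour the pinned backend instead of -1
    return tactic


print(resolve_tactic(-1, "cutedsl", lambda: ["cutlass", "cutedsl"]))  # cutedsl
print(resolve_tactic("cutlass", "auto", lambda: ["cutlass"]))         # cutlass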

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py

Lines changed: 13 additions & 7 deletions
@@ -300,7 +300,7 @@ def __call__(
         sfa_tensor: cute.Tensor,
         sfb_tensor: cute.Tensor,
         c_tensor: cute.Tensor,
-        alpha: cute.Pointer,  # Changed from cutlass.Float32 to device pointer
+        alpha: cute.Tensor,  # Single-element tensor containing alpha value
         max_active_clusters: cutlass.Constexpr,
         stream: cuda.CUstream,
         epilogue_op: cutlass.Constexpr = lambda x: x,
@@ -571,13 +571,12 @@ def kernel(
         epi_tile: cute.Tile,
         tile_sched_params: utils.PersistentTileSchedulerParams,
         epilogue_op: cutlass.Constexpr,
-        alpha: cute.
-        Pointer,  # Changed from cutlass.Float32 to device pointer
+        alpha: cute.Tensor,  # Single-element tensor containing alpha value
     ):
         """
         GPU device kernel performing the Persistent batched GEMM computation.
         """
-        alpha_value = alpha.load().to(self.c_dtype)
+        alpha_value = alpha[0].to(self.c_dtype)
 
         warp_idx = cute.arch.warp_idx()
         warp_idx = cute.arch.make_warp_uniform(warp_idx)
@@ -1944,7 +1943,8 @@ def __call__(
         a_sf_ptr: cute.Pointer,
         b_sf_ptr: cute.Pointer,
         c_ptr: cute.Pointer,
-        alpha: cute.Pointer,  # Changed from cutlass.Float32 to device pointer
+        alpha: cute.
+        Pointer,  # Device pointer to alpha, will be converted to Tensor
         max_active_clusters: cutlass.Constexpr,
         current_stream: cuda.CUstream,
         swap_ab: cutlass.Constexpr = False,
@@ -1965,7 +1965,7 @@ def __call__(
             a_sf_ptr (cute.Pointer): Pointer to the scale factor tensor for A.
             b_sf_ptr (cute.Pointer): Pointer to the scale factor tensor for B.
             c_ptr (cute.Pointer): Pointer to the C tensor.
-            alpha (cute.Pointer): Pointer to alpha scaling factor on device (avoids CPU-GPU sync).
+            alpha (cute.Pointer): Device pointer to alpha scaling factor (converted to Tensor internally).
             max_active_clusters (cutlass.Constexpr): Maximum number of active
                 clusters.
             current_stream (cuda.CUstream): CUDA stream for the operation.
@@ -2011,11 +2011,17 @@ def __call__(
                 order=(2, 1, 4, 0, 3, 5),
             ))
 
+        # Convert alpha pointer to a single-element cute.Tensor for easier kernel usage
+        # Create a 1D layout with a single element
+        alpha_tensor = cute.make_tensor(alpha,
+                                        layout=cute.make_ordered_layout(
+                                            (1, ), order=(0, )))
+
         Sm100BlockScaledPersistentDenseGemmKernel(
             self.sf_vec_size,
             self.mma_tiler_mn,
             self.cluster_shape_mn,
-        )(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, alpha,
+        )(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, alpha_tensor,
           max_active_clusters, current_stream, epilogue_op)
 
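The net effect of this file's changes is that alpha stays on the device end to end: the host-side __call__ wraps the device pointer in a single-element cute.Tensor, and the kernel reads alpha[0] instead of pointer.load(), so the scale is never copied back to the host. The same idea in plain PyTorch terms, as a rough analogy rather than the CuTe DSL code itself:

# PyTorch analogy only: keeping alpha as a 1-element CUDA tensor avoids the
# host-device synchronization that alpha.item() (or a Python float) would force.
import torch


def scaled_matmul(a: torch.Tensor, b: torch.Tensor,
                  alpha: torch.Tensor) -> torch.Tensor:
    # alpha remains a (1,) device tensor; the multiply broadcasts on the GPU,
    # so the host never blocks to read the scale value.
    return torch.matmul(a, b) * alpha


if torch.cuda.is_available():
    a = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(64, 256, device="cuda", dtype=torch.bfloat16)
    alpha = torch.full((1, ), 0.5, device="cuda", dtype=torch.bfloat16)
    out = scaled_matmul(a, b, alpha)  # launches asynchronously, no sync point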
