
Commit ca03bfe

Modify the alpha in cutedsl to be a device pointer and refactor code
Signed-off-by: Shijie Wang <[email protected]>
1 parent 32d2ad6 commit ca03bfe

3 files changed: +73 −70 lines
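
At a glance: `alpha`, the GEMM output scale, moves from a host-side Python float that was baked into the runner at construction time to a one-element float32 CUDA tensor that is passed alongside the other operands and dereferenced inside the kernel. A minimal before/after sketch condensed from the diffs below (the operand names `a, b, a_sf, b_sf, alpha_tensor` are placeholders, not names from this commit):

    # Before: alpha captured as a host float when the runner is built
    runner = CuteDSLNVFP4BlackwellLinear(alpha, output_dtype)
    out = runner(inputs=[a, b, a_sf, b_sf], tactic=best_tactic)

    # After: alpha travels as a device tensor with the other inputs
    runner = CuteDSLNVFP4BlackwellLinear(output_dtype)
    out = runner(inputs=[a, b, a_sf, b_sf, alpha_tensor], tactic=best_tactic)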

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 32 additions & 27 deletions
@@ -6,8 +6,8 @@
 from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.math_utils import pad_up

-from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
-                         OptimizationProfile, TunableRunner, TuningConfig)
+from ..autotuner import (ConstraintSpec, DynamicTensorSpec, OptimizationProfile,
+                         TunableRunner, TuningConfig)
 from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE
 from ..utils import (fp4_scale_infer_shape,
                      get_last_power_of_2_num_tokens_buckets,
@@ -38,16 +38,21 @@ class CuteDSLNVFP4BlackwellLinear(TunableRunner):
         constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ),
     )

-    def __init__(self, alpha: float, output_dtype: torch.dtype):
+    def __init__(self, output_dtype: torch.dtype):
         super().__init__()
-        self.alpha = alpha
+
+        # Validate output dtype (use proper exception instead of assert)
+        if output_dtype != torch.bfloat16:
+            raise ValueError(
+                f"CuteDSL NVFP4 only supports bfloat16 output, got {output_dtype}"
+            )
         self.output_dtype = output_dtype
-        assert output_dtype == torch.bfloat16

+        # Validate SM version at initialization
         if get_sm_version() != 100:
             raise ValueError(
-                f"SM version {get_sm_version()} is not supported for CuteDSLNVFP4BlackwellLinear, it only supports SM 100"
-            )
+                f"SM version {get_sm_version()} is not supported. "
+                f"CuteDSL NVFP4 requires SM 100 (Blackwell).")

     # rewrite the hash function because the value of self.alpha doesn't affect the tactic.
     def __hash__(self):
@@ -147,6 +152,7 @@ def forward(
         self,
         inputs: List[torch.Tensor],
         tactic,
+        **kwargs,
     ) -> torch.Tensor:
         """
         Performs fp8 blockwise gemm operation using CuTe DSL.
@@ -158,7 +164,6 @@
             inputs[2]: Input scale tensor of shape (k//16, m), dtype: fp8.
             inputs[3]: Weight scale tensor of shape (n, k//16), dtype: fp8.
             inputs[4]: Alpha scaling factor. dtype: float32.
-            inputs[5]: Output dtype, expected to be torch.bfloat16.
             tactic: Tiling and cluster strategy, typically a tuple (mma_tiler_mn, cluster_shape_mn).

         Returns:
@@ -176,7 +181,7 @@
             False,
         ]

-        a_tensor, b_tensor, a_sf_tensor, b_sf_tensor = inputs
+        a_tensor, b_tensor, a_sf_tensor, b_sf_tensor, alpha_tensor = inputs
         m, k, n = a_tensor.shape[0], a_tensor.shape[1], b_tensor.shape[0]
         c_tensor = torch.empty(*(m, n),
                                dtype=self.output_dtype,
@@ -204,6 +209,9 @@
             b_sf_tensor, cutlass.Float8E4M3FN, 16)
         c_ptr = self.make_cute_dsl_global_pointer(c_tensor,
                                                   cutlass.BFloat16, 16)
+        # Create pointer to alpha on device
+        alpha_ptr = self.make_cute_dsl_global_pointer(
+            alpha_tensor, cutlass.Float32, 4)

         # get stream
        torch_stream = torch.cuda.current_stream()
@@ -260,7 +268,7 @@
                 kernel_a_sf_ptr,
                 kernel_b_sf_ptr,
                 c_ptr,
-                self.alpha,
+                alpha_ptr,  # Pass alpha as device pointer
                 max_active_clusters,
                 stream,
                 swap_ab,
@@ -285,7 +293,7 @@
                 kernel_a_sf_ptr,
                 kernel_b_sf_ptr,
                 c_ptr,
-                self.alpha,
+                alpha_ptr,  # Pass alpha as device pointer
                 stream,
             )

@@ -302,34 +310,31 @@ def cute_dsl_nvfp4_gemm_blackwell(
     weight: torch.Tensor,
     input_scale: torch.Tensor,
     weight_scale: torch.Tensor,
-    alpha: float,
+    alpha: torch.Tensor,
     output_dtype: torch.dtype,
 ) -> torch.Tensor:
     """CuteDSL-based NVFP4 GEMM optimized for Blackwell.

-    .. deprecated::
-        Use :func:`torch.ops.trtllm.nvfp4_gemm_unified` instead for automatic
-        backend selection among CUTLASS, cuBLASLt, and CuteDSL based on
-        performance profiling.
+    Note:
+        This function is primarily used internally by nvfp4_gemm_unified.
+        Direct usage is discouraged. Consider using nvfp4_gemm_unified instead
+        for automatic backend selection with better performance.
     """
-    from tensorrt_llm.logger import logger
-    logger.warning_once(
-        "cute_dsl_nvfp4_gemm_blackwell is deprecated. Use nvfp4_gemm_unified instead "
-        "for automatic backend selection with better performance.",
-        key="cute_dsl_nvfp4_gemm_blackwell_deprecated")
+    from tensorrt_llm._torch.autotuner import AutoTuner

     tuner = AutoTuner.get()

-    cute_dsl_nvfp4_gemm_blackwell_runner = CuteDSLNVFP4BlackwellLinear(
-        alpha, output_dtype)
+    runner = CuteDSLNVFP4BlackwellLinear(output_dtype)
+
     _, best_tactic = tuner.choose_one(
         "trtllm::cute_dsl_nvfp4_gemm_blackwell",
-        [cute_dsl_nvfp4_gemm_blackwell_runner],
+        [runner],
         CuteDSLNVFP4BlackwellLinear.tuning_config,
         [input, weight, input_scale, weight_scale],
     )
-    return cute_dsl_nvfp4_gemm_blackwell_runner(
-        inputs=[input, weight, input_scale, weight_scale],
+
+    return runner(
+        inputs=[input, weight, input_scale, weight_scale, alpha],
         tactic=best_tactic,
     )

@@ -339,7 +344,7 @@ def _(
     mat_b: torch.Tensor,
     input_scale: torch.Tensor,
     weight_scale: torch.Tensor,
-    alpha: float,
+    alpha: torch.Tensor,  # Match custom op signature
     output_dtype: torch.dtype,
 ):
     # [m, k]
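
From the caller's side, the visible change is that `alpha` must now arrive as a one-element float32 tensor already resident on the GPU. A hedged sketch of the new calling convention (the operand bindings are placeholders; the real op expects packed NVFP4 inputs and FP8 scale factors with the shapes and dtypes documented in forward() above):

    import torch

    # Placeholder operands, produced by an upstream NVFP4 quantization step.
    input = weight = input_scale = weight_scale = ...  # placeholders

    # alpha lives on the device, which avoids a CPU-GPU sync when the scale
    # is itself produced on the GPU.
    alpha = torch.tensor([0.0125], dtype=torch.float32, device="cuda")

    out = cute_dsl_nvfp4_gemm_blackwell(input, weight, input_scale,
                                        weight_scale, alpha,
                                        output_dtype=torch.bfloat16)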

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 8 additions & 14 deletions
@@ -560,14 +560,11 @@ def nvfp4_gemm_cublaslt(
 ) -> torch.Tensor:
     """cuBLASLt-based NVFP4 GEMM with heuristic-based auto-tuning.

-    .. deprecated::
-        Use :func:`nvfp4_gemm_unified` instead for automatic backend selection
-        among CUTLASS, cuBLASLt, and CuteDSL based on performance profiling.
+    Note:
+        This function is primarily used internally by nvfp4_gemm_unified.
+        Direct usage is discouraged. Consider using nvfp4_gemm_unified instead
+        for automatic backend selection with better performance.
     """
-    logger.warning_once(
-        "nvfp4_gemm_cublaslt is deprecated. Use nvfp4_gemm_unified instead "
-        "for automatic backend selection with better performance.",
-        key="nvfp4_gemm_cublaslt_deprecated")
     tuner = AutoTuner.get()

     # Use CublasLt runner with heuristic-based tuning
@@ -616,14 +613,11 @@ def nvfp4_gemm(
 ) -> torch.Tensor:
     """CUTLASS-based NVFP4 GEMM with auto-tuning.

-    .. deprecated::
-        Use :func:`nvfp4_gemm_unified` instead for automatic backend selection
-        among CUTLASS, cuBLASLt, and CuteDSL based on performance profiling.
+    Note:
+        This function is primarily used internally by nvfp4_gemm_unified.
+        Direct usage is discouraged. Consider using nvfp4_gemm_unified instead
+        for automatic backend selection with better performance.
     """
-    logger.warning_once(
-        "nvfp4_gemm is deprecated. Use nvfp4_gemm_unified instead "
-        "for automatic backend selection with better performance.",
-        key="nvfp4_gemm_deprecated")
     tuner = AutoTuner.get()

     # Use Cutlass runner with predefined configs

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py

Lines changed: 33 additions & 29 deletions
@@ -300,7 +300,7 @@ def __call__(
         sfa_tensor: cute.Tensor,
         sfb_tensor: cute.Tensor,
         c_tensor: cute.Tensor,
-        alpha: cutlass.Float32,
+        alpha: cute.Pointer,  # Changed from cutlass.Float32 to device pointer
         max_active_clusters: cutlass.Constexpr,
         stream: cuda.CUstream,
         epilogue_op: cutlass.Constexpr = lambda x: x,
@@ -548,34 +548,37 @@ class SharedStorage:
     # GPU device kernel
     @cute.kernel
     def kernel(
-        self,
-        tiled_mma: cute.TiledMma,
-        tiled_mma_sfb: cute.TiledMma,
-        tma_atom_a: cute.CopyAtom,
-        mA_mkl: cute.Tensor,
-        tma_atom_b: cute.CopyAtom,
-        mB_nkl: cute.Tensor,
-        tma_atom_sfa: cute.CopyAtom,
-        mSFA_mkl: cute.Tensor,
-        tma_atom_sfb: cute.CopyAtom,
-        mSFB_nkl: cute.Tensor,
-        tma_atom_c: Optional[cute.CopyAtom],
-        mC_mnl: cute.Tensor,
-        cluster_layout_vmnk: cute.Layout,
-        cluster_layout_sfb_vmnk: cute.Layout,
-        a_smem_layout_staged: cute.ComposedLayout,
-        b_smem_layout_staged: cute.ComposedLayout,
-        sfa_smem_layout_staged: cute.Layout,
-        sfb_smem_layout_staged: cute.Layout,
-        c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
-        epi_tile: cute.Tile,
-        tile_sched_params: utils.PersistentTileSchedulerParams,
-        epilogue_op: cutlass.Constexpr,
-        alpha: cutlass.Float32,
+            self,
+            tiled_mma: cute.TiledMma,
+            tiled_mma_sfb: cute.TiledMma,
+            tma_atom_a: cute.CopyAtom,
+            mA_mkl: cute.Tensor,
+            tma_atom_b: cute.CopyAtom,
+            mB_nkl: cute.Tensor,
+            tma_atom_sfa: cute.CopyAtom,
+            mSFA_mkl: cute.Tensor,
+            tma_atom_sfb: cute.CopyAtom,
+            mSFB_nkl: cute.Tensor,
+            tma_atom_c: Optional[cute.CopyAtom],
+            mC_mnl: cute.Tensor,
+            cluster_layout_vmnk: cute.Layout,
+            cluster_layout_sfb_vmnk: cute.Layout,
+            a_smem_layout_staged: cute.ComposedLayout,
+            b_smem_layout_staged: cute.ComposedLayout,
+            sfa_smem_layout_staged: cute.Layout,
+            sfb_smem_layout_staged: cute.Layout,
+            c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None],
+            epi_tile: cute.Tile,
+            tile_sched_params: utils.PersistentTileSchedulerParams,
+            epilogue_op: cutlass.Constexpr,
+            alpha: cute.Pointer,  # Changed from cutlass.Float32 to device pointer
     ):
         """
         GPU device kernel performing the Persistent batched GEMM computation.
         """
+        alpha_value = alpha.load().to(self.c_dtype)
+
         warp_idx = cute.arch.warp_idx()
         warp_idx = cute.arch.make_warp_uniform(warp_idx)

@@ -1248,6 +1251,7 @@ def kernel(
             #
             subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
             num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt
+
             for subtile_idx in cutlass.range(subtile_cnt):
                 #
                 # Load accumulator from tensor memory buffer to register
12591263
# Convert to C type
12601264
#
12611265
acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
1262-
acc_vec = epilogue_op(
1263-
alpha.to(self.c_dtype) * acc_vec.to(self.c_dtype))
1266+
acc_vec = epilogue_op(alpha_value *
1267+
acc_vec.to(self.c_dtype))
12641268
tRS_rC.store(acc_vec)
12651269

12661270
#
@@ -1940,7 +1944,7 @@ def __call__(
         a_sf_ptr: cute.Pointer,
         b_sf_ptr: cute.Pointer,
         c_ptr: cute.Pointer,
-        alpha: cutlass.Float32,
+        alpha: cute.Pointer,  # Changed from cutlass.Float32 to device pointer
         max_active_clusters: cutlass.Constexpr,
         current_stream: cuda.CUstream,
         swap_ab: cutlass.Constexpr = False,
@@ -1961,7 +1965,7 @@
             a_sf_ptr (cute.Pointer): Pointer to the scale factor tensor for A.
             b_sf_ptr (cute.Pointer): Pointer to the scale factor tensor for B.
             c_ptr (cute.Pointer): Pointer to the C tensor.
-            alpha (cutlass.Float32): Scaling factor for the GEMM output.
+            alpha (cute.Pointer): Pointer to alpha scaling factor on device (avoids CPU-GPU sync).
             max_active_clusters (cutlass.Constexpr): Maximum number of active
                 clusters.
             current_stream (cuda.CUstream): CUDA stream for the operation.
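
Taken together, the pointer plumbing is symmetric: the host wraps the one-element alpha tensor as a 4-byte-aligned float32 global-memory pointer, and the kernel dereferences it once before the epilogue loop. A condensed view of the two sides, drawn from the hunks above (helper and variable names as in this commit):

    # Host side (cute_dsl_custom_ops.py): wrap the torch tensor as a
    # float32 global-memory pointer with 4-byte alignment.
    alpha_ptr = self.make_cute_dsl_global_pointer(alpha_tensor,
                                                  cutlass.Float32, 4)

    # Device side (kernel): a single global-memory read of the scalar,
    # which is then reused for every epilogue subtile.
    alpha_value = alpha.load().to(self.c_dtype)
    acc_vec = epilogue_op(alpha_value * acc_vec.to(self.c_dtype))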
