 
 import torch
 
+from tensorrt_llm.logger import logger
+
 from ..._utils import get_sm_version
 from ...math_utils import pad_up
 from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
@@ ... @@
         Sm100BlockScaledPersistentDenseGemmKernel
     from ..cute_dsl_kernels.blackwell.utils import make_ptr
 
-class CuteDSLNVFP4BlackwellRunner(TunableRunner):
+class CuteDSLNVFP4BlackwellLinear(TunableRunner):
     kernel_class = Sm100BlockScaledPersistentDenseGemmKernel
     kernel_cache = dict()
     tuning_config = TuningConfig(
@@ -43,26 +45,44 @@ class CuteDSLNVFP4BlackwellRunner(TunableRunner):
         use_cold_l2_cache=True,
     )
 
-    def __init__(self, alpha: float, output_dtype: torch.dtype):
+    def __init__(self,
+                 output_dtype: torch.dtype,
+                 to_userbuffers: bool = False):
         super().__init__()
-        self.alpha = alpha
-        self.output_dtype = output_dtype
-        assert output_dtype == torch.bfloat16
 
-        if get_sm_version() not in [100, 103]:
+        if output_dtype != torch.bfloat16:
             raise ValueError(
-                f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100 and SM 103"
+                f"CuteDSL NVFP4 only supports bfloat16 output, got {output_dtype}"
             )
+        self.output_dtype = output_dtype
+        self.to_userbuffers = to_userbuffers
 
     def unique_id(self):
-        return (self.output_dtype, )
+        return (self.output_dtype, self.to_userbuffers)
+
+    def __hash__(self):
+        return hash((self.output_dtype, self.to_userbuffers))
+
+    def __eq__(self, other):
+        if not isinstance(other, self.__class__):
+            return False
+        return self.output_dtype == other.output_dtype and self.to_userbuffers == other.to_userbuffers
 
     def get_valid_tactics(
         self,
         inputs: List[torch.Tensor],
         profile: OptimizationProfile,
         **kwargs,
     ) -> List[Tuple[int, int]]:
+        # Early exit: Check SM version - CuteDSL NVFP4 only supports SM 100 and SM 103
+        sm_version = get_sm_version()
+        if sm_version not in [100, 103]:
+            logger.debug(
+                f"CuteDSL: SM version {sm_version} is not supported. "
+                f"CuteDSL NVFP4 only supports SM 100 (B200) and SM 103 (B300). Skipping all tactics."
+            )
+            return []
+
         assert inputs[0].dim() == 2
         assert inputs[1].dim() == 2
 
@@ -73,11 +93,44 @@ def get_valid_tactics(
         real_k = k * 2
         batch_size = 1
         sf_vec_size = 16
-        # m,k
+
+        # Fixed layout for FP4: A and B are always K-major
         a_major = "k"
-        # n, k
         b_major = "k"
 
+        # Early exit: Check K dimension alignment
+        # For K-major layout (A and B tensors), K is the major mode (contiguous dimension).
+        # 16-byte alignment requirement: K must be divisible by 32 for FP4 (128 bits / 4 bits = 32)
+        if real_k % 32 != 0:
+            logger.debug(
+                f"CuteDSL: K={real_k} does not meet 16-byte alignment requirement "
+                f"(K%32={real_k % 32}, expected 0). Skipping all tactics.")
+            return []
+
+        # Optimize swap_ab candidates based on M and N alignment
+        # swap_ab=False → C is N-major → requires N%8==0 (BF16: 128 bits / 16 bits = 8)
+        # swap_ab=True  → C is M-major → requires M%8==0
+        m_aligned = (m % 8 == 0)
+        n_aligned = (n % 8 == 0)
+
+        if not m_aligned and not n_aligned:
+            logger.debug(
+                f"CuteDSL: Neither M={m} nor N={n} meets 16-byte alignment "
+                f"(M%8={m % 8}, N%8={n % 8}). No valid C layout. Skipping all tactics."
+            )
+            return []
+
+        # Only test swap_ab values that satisfy alignment
+        swap_ab_candidates = []
+        if n_aligned:
+            swap_ab_candidates.append(False)  # N-major layout
+        if m_aligned:
+            swap_ab_candidates.append(True)  # M-major layout
+
+        logger.debug(
+            f"CuteDSL: M={m} (aligned={m_aligned}), N={n} (aligned={n_aligned}), K={real_k} (aligned=True). "
+            f"Testing swap_ab={swap_ab_candidates}")
+
         # full shamoo
         mma_tiler_mn_candidates = [
             (128, 64),
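
Note on the alignment checks above: the 16-byte (128-bit) access requirement on the contiguous dimension translates to K % 32 == 0 for the FP4 operands (128 / 4) and M % 8 == 0 or N % 8 == 0 for the BF16 output (128 / 16). A minimal standalone sketch of the same candidate filtering, with illustrative names that are not part of the patch:

    def swap_ab_candidates_for(m: int, n: int, real_k: int) -> list:
        # Mirrors the early-exit logic in get_valid_tactics above (illustrative only).
        if real_k % 32 != 0:          # FP4 A/B: 128 bits / 4 bits = 32 elements per 16 bytes
            return []
        candidates = []
        if n % 8 == 0:                # swap_ab=False: C is N-major (BF16: 128 / 16 = 8)
            candidates.append(False)
        if m % 8 == 0:                # swap_ab=True: C is M-major
            candidates.append(True)
        return candidates

    # e.g. M=120, N=7168, K=2048 -> [False, True]; M=7, N=7168, K=2048 -> [False]
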
@@ -134,6 +187,9 @@ def get_valid_tactics(
                     valid_tactics.append(
                         (mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch))
 
+        logger.debug(
+            f"CuteDSL: Found {len(valid_tactics)} valid tactics for M={m}, N={n}, K={real_k}"
+        )
         return valid_tactics
 
     def make_cute_dsl_global_pointer(self, tensor: torch.Tensor, dtype,
@@ -149,6 +205,7 @@ def forward(
         self,
         inputs: List[torch.Tensor],
         tactic,
+        **kwargs,
     ) -> torch.Tensor:
         """
         Performs NVFP4 block-scaled gemm using CuTe DSL.
@@ -160,8 +217,7 @@ def forward(
             inputs[2]: Input scale tensor of shape (k//16, m), dtype: fp8.
             inputs[3]: Weight scale tensor of shape (n, k//16), dtype: fp8.
             inputs[4]: Alpha scaling factor. dtype: float32.
-            inputs[5]: Output dtype, expected to be torch.bfloat16.
             tactic: Tiling and cluster strategy, typically a tuple (mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch).
 
         Returns:
             torch.Tensor: Output tensor of shape (m, n), dtype: bf16.
@@ -179,11 +235,17 @@ def forward(
                 False,
             ]
 
-        a_tensor, b_tensor, a_sf_tensor, b_sf_tensor = inputs
+        a_tensor, b_tensor, a_sf_tensor, b_sf_tensor, alpha_tensor = inputs
         m, k, n = a_tensor.shape[0], a_tensor.shape[1], b_tensor.shape[0]
-        c_tensor = torch.empty(*(m, n),
-                               dtype=self.output_dtype,
-                               device="cuda")
+
+        # Allocate output tensor from UserBuffers or regular CUDA memory
+        if self.to_userbuffers:
+            c_tensor = torch.ops.trtllm.create_userbuffers_tensor(
+                [m, n], self.output_dtype)
+        else:
+            c_tensor = torch.empty(*(m, n),
+                                   dtype=self.output_dtype,
+                                   device="cuda")
 
         if swap_ab:
             c_tensor = c_tensor.permute(1, 0)
@@ -193,9 +255,27 @@ def forward(
         sf_k = pad_up(real_k // sf_vec_size, 4)
         sf_n = pad_up(n, 128)
 
-        # the scaling tensor is 1D. we need to make sure it has been padded to the correct shape
-        assert a_sf_tensor.shape == (sf_m * sf_k, )
-        assert b_sf_tensor.shape == (sf_n * sf_k, )
+        # Reshape scale factors to CuteDSL's expected format
+        # Input format (from CUTLASS/cuBLASLt): (m*k//16,) and (n*k//16,)
+        # CuteDSL format: (sf_m*sf_k,) and (sf_n*sf_k,)
+        # Note: This is just a view change, no memory copy
+        expected_a_sf_size = sf_m * sf_k
+        expected_b_sf_size = sf_n * sf_k
+
+        if a_sf_tensor.numel() != expected_a_sf_size:
+            raise ValueError(
+                f"CuteDSL: act scale factor size mismatch. "
+                f"Expected {expected_a_sf_size} (sf_m={sf_m} * sf_k={sf_k}), "
+                f"got {a_sf_tensor.numel()} for shape M={m}, K={real_k}")
+        if b_sf_tensor.numel() != expected_b_sf_size:
+            raise ValueError(
+                f"CuteDSL: weight scale factor size mismatch. "
+                f"Expected {expected_b_sf_size} (sf_n={sf_n} * sf_k={sf_k}), "
+                f"got {b_sf_tensor.numel()} for shape N={n}, K={real_k}")
+
+        # Reshape to CuteDSL's expected format (just a view, no copy)
+        a_sf_tensor = a_sf_tensor.reshape(sf_m * sf_k)
+        b_sf_tensor = b_sf_tensor.reshape(sf_n * sf_k)
 
         a_ptr = self.make_cute_dsl_global_pointer(a_tensor,
                                                   cutlass.Float4E2M1FN, 32)
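
To make the scale-factor size checks above concrete: assuming pad_up(x, multiple) rounds x up to the next multiple, the expected flat sizes for an illustrative shape (not taken from the patch) work out as follows:

    def pad_up(x: int, multiple: int) -> int:
        # Assumed behavior of tensorrt_llm's math_utils.pad_up: round up to a multiple.
        return ((x + multiple - 1) // multiple) * multiple

    m, n, real_k, sf_vec_size = 120, 7168, 2048, 16
    sf_m = pad_up(m, 128)                    # 128
    sf_k = pad_up(real_k // sf_vec_size, 4)  # 2048 // 16 = 128 -> 128
    sf_n = pad_up(n, 128)                    # 7168
    print(sf_m * sf_k, sf_n * sf_k)          # 16384 917504  (expected a_sf / b_sf numel)
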
@@ -207,6 +287,9 @@ def forward(
             b_sf_tensor, cutlass.Float8E4M3FN, 16)
         c_ptr = self.make_cute_dsl_global_pointer(c_tensor,
                                                   cutlass.BFloat16, 16)
+        # Create pointer to alpha on device
+        alpha_ptr = self.make_cute_dsl_global_pointer(
+            alpha_tensor, cutlass.Float32, 4)
 
         # get stream
         torch_stream = torch.cuda.current_stream()
@@ -259,7 +342,7 @@ def forward(
                 kernel_a_sf_ptr,
                 kernel_b_sf_ptr,
                 c_ptr,
-                self.alpha,
+                alpha_ptr,  # Pass alpha as device pointer
                 max_active_clusters,
                 stream,
                 swap_ab,
@@ -283,7 +366,7 @@ def forward(
                 kernel_a_sf_ptr,
                 kernel_b_sf_ptr,
                 c_ptr,
-                self.alpha,
+                alpha_ptr,  # Pass alpha as device pointer
                 stream,
             )
 
@@ -300,20 +383,45 @@ def cute_dsl_nvfp4_gemm_blackwell(
     weight: torch.Tensor,
     input_scale: torch.Tensor,
     weight_scale: torch.Tensor,
-    alpha: float,
+    alpha: torch.Tensor,
     output_dtype: torch.dtype,
+    to_userbuffers: bool = False,
 ) -> torch.Tensor:
+    """CuteDSL-based NVFP4 GEMM optimized for Blackwell.
+
+    Args:
+        input: Activation tensor [m, k] in FP4 format (packed in uint8)
+        weight: Weight tensor [n, k] in FP4 format (packed in uint8)
+        input_scale: Activation scale factors
+        weight_scale: Weight scale factors
+        alpha: Scaling factor (float32 tensor)
+        output_dtype: Output data type (must be bfloat16)
+        to_userbuffers: Whether to allocate the output from the UserBuffers pool
+
+    Note:
+        This function is primarily used internally by nvfp4_gemm.
+        Direct usage is discouraged. Consider using nvfp4_gemm instead
+        for automatic backend selection with better performance.
+    """
+    # Validate SM version before attempting to use CuteDSL
+    sm_version = get_sm_version()
+    if sm_version not in [100, 103]:
+        raise ValueError(
+            f"CuteDSL NVFP4 backend requires SM 100 (B200) or SM 103 (B300), but got SM {sm_version}. "
+            f"Please use nvfp4_gemm with backend='auto' for automatic backend selection."
+        )
 
     tuner = AutoTuner.get()
 
-    runner = CuteDSLNVFP4BlackwellRunner(alpha, output_dtype)
-    inputs = [input, weight, input_scale, weight_scale]
+    runner = CuteDSLNVFP4BlackwellLinear(output_dtype, to_userbuffers)
+    inputs = [input, weight, input_scale, weight_scale, alpha]
     _, best_tactic = tuner.choose_one(
         "trtllm::cute_dsl_nvfp4_gemm_blackwell",
         [runner],
         runner.__class__.tuning_config,
         inputs,
     )
+
     output = runner(inputs, tactic=best_tactic)
     return output
 
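
For orientation only, a hedged sketch of how this entry point is shaped if called directly (shapes, packing, and scale-factor layout below are assumptions for illustration, not taken from this patch; production code should go through nvfp4_gemm as the docstring advises):

    # Hypothetical problem size: M=128, N=7168, real K=2048 (packed K dim = 1024 bytes).
    # act_fp4:  (128, 1024)  uint8 tensor, two FP4 values packed per byte
    # wgt_fp4:  (7168, 1024) uint8 tensor
    # act_sf:   flat FP8-E4M3 scale factors, numel == pad_up(128, 128) * pad_up(2048 // 16, 4)
    # wgt_sf:   flat FP8-E4M3 scale factors, numel == pad_up(7168, 128) * pad_up(2048 // 16, 4)
    # alpha:    one-element float32 CUDA tensor
    out = cute_dsl_nvfp4_gemm_blackwell(
        act_fp4, wgt_fp4, act_sf, wgt_sf, alpha,
        output_dtype=torch.bfloat16,
        to_userbuffers=False,
    )  # -> (128, 7168) bfloat16 tensor
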
@@ -323,8 +431,9 @@ def _(
     mat_b: torch.Tensor,
     input_scale: torch.Tensor,
     weight_scale: torch.Tensor,
-    alpha: float,
+    alpha: torch.Tensor,  # Match custom op signature
     output_dtype: torch.dtype,
+    to_userbuffers: bool = False,
 ):
     # [m, k]
     shape = list(mat_a.shape)