66from tensorrt_llm ._utils import get_sm_version
77from tensorrt_llm .math_utils import pad_up
88
9- from ..autotuner import (AutoTuner , ConstraintSpec , DynamicTensorSpec ,
10- OptimizationProfile , TunableRunner , TuningConfig )
9+ from ..autotuner import (ConstraintSpec , DynamicTensorSpec , OptimizationProfile ,
10+ TunableRunner , TuningConfig )
1111from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE
1212from ..utils import (fp4_scale_infer_shape ,
1313 get_last_power_of_2_num_tokens_buckets ,
@@ -38,16 +38,21 @@ class CuteDSLNVFP4BlackwellLinear(TunableRunner):
3838 constraint_specs = (ConstraintSpec (2 , 0 , fp4_scale_infer_shape ), ),
3939 )
4040
41- def __init__ (self , alpha : float , output_dtype : torch .dtype ):
41+ def __init__ (self , output_dtype : torch .dtype ):
4242 super ().__init__ ()
43- self .alpha = alpha
43+
44+ # Validate output dtype (use proper exception instead of assert)
45+ if output_dtype != torch .bfloat16 :
46+ raise ValueError (
47+ f"CuteDSL NVFP4 only supports bfloat16 output, got { output_dtype } "
48+ )
4449 self .output_dtype = output_dtype
45- assert output_dtype == torch .bfloat16
4650
51+ # Validate SM version at initialization
4752 if get_sm_version () != 100 :
4853 raise ValueError (
49- f"SM version { get_sm_version ()} is not supported for CuteDSLNVFP4BlackwellLinear, it only supports SM 100 "
50- )
54+ f"SM version { get_sm_version ()} is not supported. "
55+ f"CuteDSL NVFP4 requires SM 100 (Blackwell)." )
5156
5257 # Override __hash__: alpha is now passed as a tensor input rather than stored on self, and its value doesn't affect the tactic.
5358 def __hash__ (self ):
@@ -147,6 +152,7 @@ def forward(
147152 self ,
148153 inputs : List [torch .Tensor ],
149154 tactic ,
155+ ** kwargs ,
150156 ) -> torch .Tensor :
151157 """
152158 Performs NVFP4 gemm operation using CuTe DSL.
@@ -158,7 +164,6 @@ def forward(
158164 inputs[2]: Input scale tensor of shape (k//16, m), dtype: fp8.
159165 inputs[3]: Weight scale tensor of shape (n, k//16), dtype: fp8.
160166 inputs[4]: Alpha scaling factor. dtype: float32.
161- inputs[5]: Output dtype, expected to be torch.bfloat16.
162167 tactic: Tiling and cluster strategy, typically a tuple (mma_tiler_mn, cluster_shape_mn).
163168
164169 Returns:
@@ -176,7 +181,7 @@ def forward(
176181 False ,
177182 ]
178183
179- a_tensor , b_tensor , a_sf_tensor , b_sf_tensor = inputs
184+ a_tensor , b_tensor , a_sf_tensor , b_sf_tensor , alpha_tensor = inputs
180185 m , k , n = a_tensor .shape [0 ], a_tensor .shape [1 ], b_tensor .shape [0 ]
181186 c_tensor = torch .empty (* (m , n ),
182187 dtype = self .output_dtype ,
@@ -204,6 +209,9 @@ def forward(
204209 b_sf_tensor , cutlass .Float8E4M3FN , 16 )
205210 c_ptr = self .make_cute_dsl_global_pointer (c_tensor ,
206211 cutlass .BFloat16 , 16 )
212+ # Create pointer to alpha on device
213+ alpha_ptr = self .make_cute_dsl_global_pointer (
214+ alpha_tensor , cutlass .Float32 , 4 )
207215
208216 # get stream
209217 torch_stream = torch .cuda .current_stream ()
@@ -260,7 +268,7 @@ def forward(
260268 kernel_a_sf_ptr ,
261269 kernel_b_sf_ptr ,
262270 c_ptr ,
263- self . alpha ,
271+ alpha_ptr , # Pass alpha as device pointer
264272 max_active_clusters ,
265273 stream ,
266274 swap_ab ,
@@ -285,7 +293,7 @@ def forward(
285293 kernel_a_sf_ptr ,
286294 kernel_b_sf_ptr ,
287295 c_ptr ,
288- self . alpha ,
296+ alpha_ptr , # Pass alpha as device pointer
289297 stream ,
290298 )
291299
@@ -302,34 +310,31 @@ def cute_dsl_nvfp4_gemm_blackwell(
302310 weight : torch .Tensor ,
303311 input_scale : torch .Tensor ,
304312 weight_scale : torch .Tensor ,
305- alpha : float ,
313+ alpha : torch . Tensor ,
306314 output_dtype : torch .dtype ,
307315 ) -> torch .Tensor :
308316 """CuteDSL-based NVFP4 GEMM optimized for Blackwell.
309317
310- .. deprecated: :
311- Use :func:`torch.ops.trtllm.nvfp4_gemm_unified` instead for automatic
312- backend selection among CUTLASS, cuBLASLt, and CuteDSL based on
313- performance profiling .
318+ Note :
319+ This function is primarily used internally by nvfp4_gemm_unified.
320+ Direct usage is discouraged. Consider using nvfp4_gemm_unified instead
321+ for automatic backend selection with better performance .
314322 """
315- from tensorrt_llm .logger import logger
316- logger .warning_once (
317- "cute_dsl_nvfp4_gemm_blackwell is deprecated. Use nvfp4_gemm_unified instead "
318- "for automatic backend selection with better performance." ,
319- key = "cute_dsl_nvfp4_gemm_blackwell_deprecated" )
323+ from tensorrt_llm ._torch .autotuner import AutoTuner
320324
321325 tuner = AutoTuner .get ()
322326
323- cute_dsl_nvfp4_gemm_blackwell_runner = CuteDSLNVFP4BlackwellLinear (
324- alpha , output_dtype )
327+ runner = CuteDSLNVFP4BlackwellLinear (output_dtype )
328+
325329 _ , best_tactic = tuner .choose_one (
326330 "trtllm::cute_dsl_nvfp4_gemm_blackwell" ,
327- [cute_dsl_nvfp4_gemm_blackwell_runner ],
331+ [runner ],
328332 CuteDSLNVFP4BlackwellLinear .tuning_config ,
329333 [input , weight , input_scale , weight_scale ],
330334 )
331- return cute_dsl_nvfp4_gemm_blackwell_runner (
332- inputs = [input , weight , input_scale , weight_scale ],
335+
336+ return runner (
337+ inputs = [input , weight , input_scale , weight_scale , alpha ],
333338 tactic = best_tactic ,
334339 )
335340
@@ -339,7 +344,7 @@ def _(
339344 mat_b : torch .Tensor ,
340345 input_scale : torch .Tensor ,
341346 weight_scale : torch .Tensor ,
342- alpha : float ,
347+ alpha : torch . Tensor , # Match custom op signature
343348 output_dtype : torch .dtype ,
344349 ):
345350 # [m, k]
0 commit comments