@@ -241,7 +241,15 @@ def _matmul_persistent_deepgemm(
     dtype = a.dtype
     out = torch.empty((M, N), device=a.device, dtype=dtype)

-    deep_gemm.bf16_gemm_nn(a, b, out)
+    try:
+        deep_gemm.bf16_gemm_nn(a, b, out)
+    except RuntimeError as e:
+        raise RuntimeError(
+            f"DeepGEMM failed for matrix shapes M={M}, N={N}, K={K}. "
+            f"This typically occurs when dimensions are too small for DeepGEMM's TMA descriptors. "
+            f"Consider increasing MIN_DEEPGEMM_DIM in matmul_persistent() or disabling DeepGEMM "
+            f"for small matrices. Original error: {e}"
+        ) from e

     # TODO can this be put in DeepGEMM's `c`?
     if bias is not None:
@@ -253,13 +261,19 @@ def _matmul_persistent_deepgemm(
 def matmul_persistent(
     a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
 ):
+    K, N = b.shape
+
+    # DeepGEMM has minimum dimension requirements for TMA descriptors
+    MIN_DEEPGEMM_DIM = 16
+
     if (
         _ENABLE_MM_DEEPGEMM
         and ENABLE_JIT_DEEPGEMM
         and (a.dtype == torch.bfloat16)
         and (b.dtype == torch.bfloat16)
         and a.is_contiguous()
         and b.transpose(0, 1).is_contiguous()
+        and N >= MIN_DEEPGEMM_DIM
     ):
         if _ENABLE_MM_COMPARISON_TEST:
             out_triton = _matmul_persistent_triton(a=a, b=b, bias=bias)
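A minimal smoke test for the new guard (a sketch, not part of the patch: the import path `matmul` and the bf16 tolerances are illustrative assumptions). With N below MIN_DEEPGEMM_DIM, the DeepGEMM branch should be skipped and the Triton fallback should still match torch.matmul:

    import torch
    from matmul import matmul_persistent  # hypothetical import path for the patched module

    # N = 8 < MIN_DEEPGEMM_DIM (16), so the DeepGEMM branch is skipped and the
    # Triton path handles the small matrix instead of failing on TMA descriptors.
    a = torch.randn(64, 32, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(32, 8, device="cuda", dtype=torch.bfloat16)

    out = matmul_persistent(a, b)
    torch.testing.assert_close(out, a @ b, rtol=2e-2, atol=2e-2)  # loose bf16 tolerances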