Commit 8a43734

re-submit 12911 but relax the requirement for deepgemm (#13226)
1 parent: f0021c0

1 file changed

python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py

Lines changed: 15 additions & 1 deletion
@@ -241,7 +241,15 @@ def _matmul_persistent_deepgemm(
     dtype = a.dtype
     out = torch.empty((M, N), device=a.device, dtype=dtype)
 
-    deep_gemm.bf16_gemm_nn(a, b, out)
+    try:
+        deep_gemm.bf16_gemm_nn(a, b, out)
+    except RuntimeError as e:
+        raise RuntimeError(
+            f"DeepGEMM failed for matrix shapes M={M}, N={N}, K={K}. "
+            f"This typically occurs when dimensions are too small for DeepGEMM's TMA descriptors. "
+            f"Consider increasing MIN_DEEPGEMM_DIM in matmul_persistent() or disabling DeepGEMM "
+            f"for small matrices. Original error: {e}"
+        ) from e
 
     # TODO can this be put in DeepGEMM's `c`?
     if bias is not None:
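
For reference, the "raise ... from e" form chains the original DeepGEMM exception, so the low-level TMA failure stays visible as __cause__ in the traceback while the new message adds the shape context. A minimal standalone sketch of the same pattern (the failing_gemm helper is hypothetical, standing in for deep_gemm.bf16_gemm_nn):

import torch

def failing_gemm(a, b, out):
    # Hypothetical stand-in for deep_gemm.bf16_gemm_nn; assume it raises
    # RuntimeError when TMA descriptor constraints are violated.
    raise RuntimeError("TMA descriptor creation failed")

def gemm_with_context(a, b):
    M, K = a.shape
    _, N = b.shape
    out = torch.empty((M, N), device=a.device, dtype=a.dtype)
    try:
        failing_gemm(a, b, out)
    except RuntimeError as e:
        # "from e" preserves the original exception as __cause__, so the
        # shape-annotated message is added without losing the root cause.
        raise RuntimeError(f"GEMM failed for M={M}, N={N}, K={K}") from e
    return out

Calling gemm_with_context(torch.randn(4, 8), torch.randn(8, 2)) raises the shape-annotated RuntimeError with the original error chained beneath it.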
@@ -253,13 +261,19 @@ def _matmul_persistent_deepgemm(
 def matmul_persistent(
     a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
 ):
+    K, N = b.shape
+
+    # DeepGEMM has minimum dimension requirements for TMA descriptors
+    MIN_DEEPGEMM_DIM = 16
+
     if (
         _ENABLE_MM_DEEPGEMM
         and ENABLE_JIT_DEEPGEMM
         and (a.dtype == torch.bfloat16)
         and (b.dtype == torch.bfloat16)
         and a.is_contiguous()
         and b.transpose(0, 1).is_contiguous()
+        and N >= MIN_DEEPGEMM_DIM
     ):
         if _ENABLE_MM_COMPARISON_TEST:
             out_triton = _matmul_persistent_triton(a=a, b=b, bias=bias)
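
The net effect of the new guard: bfloat16 inputs whose N dimension falls below MIN_DEEPGEMM_DIM now skip DeepGEMM and take the existing non-DeepGEMM path instead of failing during TMA descriptor creation. A hedged usage sketch (assumes a CUDA device, the module's DeepGEMM flags enabled, and that matmul_persistent is importable from this file; shapes are illustrative):

import torch
from sglang.srt.batch_invariant_ops.batch_invariant_ops import matmul_persistent

a = torch.randn(64, 32, device="cuda", dtype=torch.bfloat16)

# N = 8 < MIN_DEEPGEMM_DIM (16): the dispatch condition is False, so this
# call should route to the fallback path rather than raising.
b_small = torch.randn(8, 32, device="cuda", dtype=torch.bfloat16).t()
out_small = matmul_persistent(a, b_small)

# N = 128 >= 16: eligible for DeepGEMM when the other conditions hold.
# b is built as a transposed view so b.transpose(0, 1) is contiguous,
# matching the dispatch check above.
b_large = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16).t()
out_large = matmul_persistent(a, b_large)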
