Skip to content

Commit f3e9336

Browse files
authored
Support weight update for blackwell DeepGEMM (#13324)
1 parent 290fcd8 commit f3e9336

File tree

1 file changed

22 additions and 2 deletions (+22 / -2 lines changed)

1 file changed

22 additions and 2 deletions (+22 / -2 lines changed)

python/sglang/srt/models/deepseek_v2.py

Lines changed: 22 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -105,6 +105,7 @@
105105
block_quant_dequant,
106106
block_quant_to_tensor_quant,
107107
channel_quant_to_tensor_quant,
108+
inverse_transform_scale_ue8m0,
108109
normalize_e4m3fn_to_e4m3fnuz,
109110
quant_weight_ue8m0,
110111
requant_weight_ue8m0_inplace,
@@ -3270,6 +3271,8 @@ def __init__(
32703271
}
32713272
)
32723273
self.capture_aux_hidden_states = False
3274+
self._executed_weight_requant_ue8m0 = False
3275+
32733276
self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp()
32743277
if self.nsa_enable_prefill_cp:
32753278
self.cp_rank = get_attention_tp_rank()
@@ -3443,6 +3446,18 @@ def post_load_weights(self, is_nextn=False, weight_names=None):
34433446
weight = w
34443447
weight_scale = self_attn.kv_b_proj.weight_scale_inv
34453448

3449+
if (
3450+
should_deepgemm_weight_requant_ue8m0(
3451+
weight_block_size=getattr(
3452+
self.quant_config, "weight_block_size", None
3453+
)
3454+
)
3455+
and self._executed_weight_requant_ue8m0
3456+
):
3457+
weight_scale = inverse_transform_scale_ue8m0(
3458+
weight_scale, mn=weight.shape[-2]
3459+
)
3460+
34463461
if (
34473462
_is_cuda
34483463
and weight_block_size[0] == 128
@@ -3553,9 +3568,14 @@ def post_load_weights(self, is_nextn=False, weight_names=None):
35533568
self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
35543569
self_attn.use_deep_gemm_bmm = True
35553570

3556-
if not ENABLE_FLASHINFER_FP8_GEMM and should_deepgemm_weight_requant_ue8m0(
3557-
weight_block_size=getattr(self.quant_config, "weight_block_size", None)
3571+
if (
3572+
not ENABLE_FLASHINFER_FP8_GEMM
3573+
and should_deepgemm_weight_requant_ue8m0(
3574+
weight_block_size=getattr(self.quant_config, "weight_block_size", None)
3575+
)
3576+
and not self._executed_weight_requant_ue8m0
35583577
):
3578+
self._executed_weight_requant_ue8m0 = True
35593579
self._weight_requant_ue8m0(is_nextn)
35603580

35613581
# TODO can move weight_requant_ue8m0 and transform_scale_ue8m0 into Fp8LinearMethod.process_weights_after_loading

0 commit comments

Comments
 (0)