     block_quant_dequant,
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
+    inverse_transform_scale_ue8m0,
     normalize_e4m3fn_to_e4m3fnuz,
     quant_weight_ue8m0,
     requant_weight_ue8m0_inplace,
@@ -3270,6 +3271,8 @@ def __init__(
             }
         )
         self.capture_aux_hidden_states = False
+        self._executed_weight_requant_ue8m0 = False
+
         self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp()
         if self.nsa_enable_prefill_cp:
             self.cp_rank = get_attention_tp_rank()
@@ -3443,6 +3446,18 @@ def post_load_weights(self, is_nextn=False, weight_names=None):
                 weight = w
                 weight_scale = self_attn.kv_b_proj.weight_scale_inv

+                if (
+                    should_deepgemm_weight_requant_ue8m0(
+                        weight_block_size=getattr(
+                            self.quant_config, "weight_block_size", None
+                        )
+                    )
+                    and self._executed_weight_requant_ue8m0
+                ):
+                    weight_scale = inverse_transform_scale_ue8m0(
+                        weight_scale, mn=weight.shape[-2]
+                    )
+
                 if (
                     _is_cuda
                     and weight_block_size[0] == 128
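If `post_load_weights` runs again after the UE8M0 requant has already been applied (e.g. on a weight update), `kv_b_proj.weight_scale_inv` is already in the transformed UE8M0 layout, so it has to be converted back to plain FP32 block scales before the dequant below. For illustration, a minimal sketch of such an inverse, assuming the forward transform stores each per-block scale as a biased 8-bit exponent (E8M0 uses bias 127) padded along the block-row dimension; the real `inverse_transform_scale_ue8m0` helper may pack and pad differently:

```python
import torch

def inverse_transform_scale_ue8m0_sketch(packed: torch.Tensor, mn: int) -> torch.Tensor:
    # Hypothetical sketch, not the actual sglang helper.
    # Assumes: scales stored as biased 8-bit exponents (E8M0, bias 127),
    # padded along the block-row dimension, with 128x128 blocks as the
    # surrounding code checks via weight_block_size[0] == 128.
    num_block_rows = (mn + 127) // 128
    exponents = packed.to(torch.int32) - 127
    scales = torch.exp2(exponents.to(torch.float32))
    return scales[..., :num_block_rows, :].contiguous()
```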
@@ -3553,9 +3568,14 @@ def post_load_weights(self, is_nextn=False, weight_names=None):
             self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
             self_attn.use_deep_gemm_bmm = True

-        if not ENABLE_FLASHINFER_FP8_GEMM and should_deepgemm_weight_requant_ue8m0(
-            weight_block_size=getattr(self.quant_config, "weight_block_size", None)
+        if (
+            not ENABLE_FLASHINFER_FP8_GEMM
+            and should_deepgemm_weight_requant_ue8m0(
+                weight_block_size=getattr(self.quant_config, "weight_block_size", None)
+            )
+            and not self._executed_weight_requant_ue8m0
         ):
+            self._executed_weight_requant_ue8m0 = True
             self._weight_requant_ue8m0(is_nextn)

         # TODO can move weight_requant_ue8m0 and transform_scale_ue8m0 into Fp8LinearMethod.process_weights_after_loading
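The `_executed_weight_requant_ue8m0` flag makes the requant run at most once per model instance, so repeated `post_load_weights` calls do not re-apply it to already-requantized weights and already-transformed scales. For context, a minimal sketch of the rounding step such a requant performs, assuming `requant_weight_ue8m0_inplace` follows the usual power-of-two scheme; names and details here are illustrative, not the actual implementation:

```python
import torch

def requant_block_ue8m0_sketch(block_fp8: torch.Tensor, block_scale_inv: torch.Tensor):
    # block_fp8: one 128x128 FP8 tile; block_scale_inv: its scalar dequant scale.
    # Round the scale up to the nearest power of two so it becomes
    # representable in an exponent-only (E8M0) format.
    pow2 = torch.exp2(torch.ceil(torch.log2(block_scale_inv)))
    # Fold the rounding ratio (in (0.5, 1.0]) back into the tile so the
    # dequantized values are preserved: block_fp8 * block_scale_inv == adjusted * pow2.
    adjusted = (block_fp8.to(torch.float32) * (block_scale_inv / pow2)).to(block_fp8.dtype)
    return adjusted, pow2
```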
|