
Commit 17a7946

fix loading of fp8 models with bf16 weight_scale (#2141)
#2108 assumes that the dtype of `scale` or `scale_inv` is `float32`, while it may be `bfloat16` for some models, such as the fp8 Qwen3 dense models.

Signed-off-by: Youlei Yang <[email protected]>
Parent: baa60de
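
For context, here is a minimal, self-contained sketch of the problem the commit fixes; it is not taken from the repository, and the tensor name and scale value are made up. It shows how a `float32`-only dtype test silently skips `bfloat16` weight scales, so they never get rescaled for the e4m3fn-to-e4m3fnuz conversion.

import torch

# Maximum representable values of the two fp8 formats (448.0 and 240.0).
fp8_e4m3fn_max = torch.finfo(torch.float8_e4m3fn).max
fp8_e4m3fnuz_max = torch.finfo(torch.float8_e4m3fnuz).max

# Hypothetical checkpoint entry: some fp8 checkpoints (e.g. the fp8 Qwen3
# dense models) store their weight scales in bfloat16 rather than float32.
name = "model.layers.0.mlp.down_proj.weight_scale_inv"
scale = torch.tensor([0.0125], dtype=torch.bfloat16)

# Old check (#2108): only float32 scales were rescaled, so bf16 scales were skipped.
old_match = scale.dtype == torch.float32 and "scale" in name.split(".")[-1]

# New check (#2141): bfloat16 scales are rescaled as well.
new_match = (scale.dtype in [torch.float32, torch.bfloat16]
             and "scale" in name.split(".")[-1])

print(old_match, new_match)  # False True

if new_match:
    # Compensate for converting the fp8 weights from e4m3fn to e4m3fnuz.
    scale = scale * (fp8_e4m3fn_max / fp8_e4m3fnuz_max)

The compensation direction follows from the diff: the e4m3fnuz format has a smaller maximum (240 vs 448), so the fp8 weights are scaled down by fp8_e4m3fnuz_max / fp8_e4m3fn_max and the per-tensor scale is multiplied by the inverse ratio to keep the dequantized values unchanged.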

File tree

1 file changed: +4 −4 lines


vllm/model_executor/model_loader/weight_utils.py

Lines changed: 4 additions & 4 deletions
@@ -477,8 +477,8 @@ def safetensors_weights_iterator(
                 if param.dtype == torch.float8_e4m3fn:
                     param = (param.float() * fp8_e4m3fnuz_max /
                              fp8_e4m3fn_max).to(torch.float8_e4m3fnuz)
-                elif param.dtype == torch.float32 and "scale" in name.split(
-                        ".")[-1]:
+                elif param.dtype in [torch.float32, torch.bfloat16
+                                     ] and "scale" in name.split(".")[-1]:
                     param *= fp8_e4m3fn_max / fp8_e4m3fnuz_max
                 yield name, param

@@ -539,8 +539,8 @@ def fastsafetensors_weights_iterator(
                     if t.dtype == torch.float8_e4m3fn:
                         t = (t.float() * fp8_e4m3fnuz_max /
                              fp8_e4m3fn_max).to(torch.float8_e4m3fnuz)
-                    elif t.dtype == torch.float32 and "scale" in k.split(
-                            ".")[-1]:
+                    elif t.dtype in [torch.float32, torch.bfloat16
+                                     ] and "scale" in k.split(".")[-1]:
                         t *= fp8_e4m3fn_max / fp8_e4m3fnuz_max
                     yield k, t
             finally:
