Skip to content

Commit 8e0c15d

Browse files
authored
support qwen3 fp8 (#3505)
* support qwen3 fp8 * fix quant
1 parent 6f65b74 commit 8e0c15d

File tree

3 files changed

+127
-5
lines changed

3 files changed

+127
-5
lines changed

lmdeploy/pytorch/models/qwen3.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
403403
('.gate_up_proj', '.up_proj', 1),
404404
]
405405

406+
scale_suffix = '.weight_scale_inv'
406407
params_dict = dict(self.named_parameters())
407408
for name, loaded_weight in weights:
408409
if 'rotary_emb.inv_freq' in name:
@@ -411,6 +412,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
411412
continue
412413
if self.config.tie_word_embeddings and 'lm_head.weight' in name:
413414
continue
415+
if name.endswith(scale_suffix):
416+
name = name[:-len(scale_suffix)] + '.scale'
417+
414418
for (param_name, weight_name, shard_id) in stacked_params_mapping:
415419
if weight_name not in name:
416420
continue

lmdeploy/pytorch/models/qwen3_moe.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ def __init__(self,
176176
dtype: torch.dtype = None,
177177
device: torch.device = None):
178178
super().__init__()
179+
# TODO: zhouxinyu, determine modules_to_not_convert from config file
180+
quantization_config = getattr(config, 'quantization_config', None)
179181
self.layer_idx = layer_idx
180182
self.hidden_dim = config.hidden_size
181183
self.ffn_dim = config.moe_intermediate_size
@@ -206,6 +208,7 @@ def __init__(self,
206208
renormalize=self.renormalize,
207209
dtype=dtype,
208210
device=device,
211+
quant_config=quantization_config,
209212
all_reduce=_all_reduce,
210213
)
211214

@@ -492,6 +495,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
492495
down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')
493496
expert_params_mapping += [gate_param, up_param, down_param]
494497

498+
scale_suffix = '.weight_scale_inv'
495499
params_dict = dict(self.named_parameters())
496500
for name, loaded_weight in weights:
497501
if 'rotary_emb.inv_freq' in name:
@@ -500,6 +504,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
500504
continue
501505
if self.config.tie_word_embeddings and 'lm_head.weight' in name:
502506
continue
507+
if name.endswith(scale_suffix):
508+
name = name[:-len(scale_suffix)] + '.scale'
503509

504510
if '.experts' in name:
505511
self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)

lmdeploy/pytorch/nn/linear.py

Lines changed: 117 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -393,12 +393,15 @@ def __init__(self,
393393
is_tp=is_tp,
394394
dp_gather=dp_gather)
395395
self.weight.weight_loader = self.weight_loader
396+
self.weight._weight_type = 'qweight'
396397
self.scale.weight_loader = self.weight_loader
398+
self.scale._weight_type = 'scales'
397399
self.weight.weight_spliter = self.weight_spliter
398400
self.scale.weight_spliter = self.weight_spliter
399401
if self.bias is not None:
400402
self.bias.weight_loader = self.weight_loader
401403
self.bias.weight_spliter = self.weight_spliter
404+
self.bias._weight_type = 'bias'
402405

403406
def _get_io_features(self, in_features: int, out_features: int, colwise: bool):
404407
"""get io features."""
@@ -419,7 +422,8 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor,
419422
"""weight loader."""
420423
world_size, rank = _get_tp_world_rank(self.is_tp)
421424
shard_idx = self.out_names_map[shard_id]
422-
if loaded_weight.dim() == 2 and loaded_weight.dtype == torch.float32:
425+
if loaded_weight.dim() == 2 and loaded_weight.dtype != self.fp8_dtype:
426+
loaded_weight = loaded_weight.to(torch.float32)
423427
all_out_features = [feats // self.block_size for feats in self.all_out_features]
424428
param_w = param.data.split(all_out_features, 0)[shard_idx]
425429
else:
@@ -430,14 +434,93 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor,
430434

431435
def weight_spliter(self, loaded_weight: torch.Tensor):
432436
"""weight spliter."""
433-
if loaded_weight.dim() == 2 and loaded_weight.dtype == torch.float32:
437+
if loaded_weight.dim() == 2 and loaded_weight.dtype != self.fp8_dtype:
434438
return loaded_weight.split(self.scale_split_section, dim=0)
435439
return loaded_weight.split(self.split_section, dim=0)
436440

437441
def weight_spliter_lora_b(self, loaded_weight: torch.Tensor):
    """Split a LoRA-B weight along dim 0 using ``self.split_section``.

    Args:
        loaded_weight: tensor whose first dimension is partitioned.

    Returns:
        The tuple of chunks produced by ``Tensor.split``, one per entry of
        ``self.split_section``.
    """
    sections = self.split_section
    return loaded_weight.split(sections, dim=0)
439443

440444

445+
class QKVBlockedF8Linear(MergedBlockedF8Linear, QKVMixin):
    """Fused q/k/v projection built on the merged blocked-fp8 linear.

    The three projections are stored as consecutive sections (named ``'q'``,
    ``'k'`` and ``'v'``) of one merged parameter.  Scale tensors use the same
    sections with every length divided by ``block_size``.
    """

    def __init__(self,
                 in_features: int,
                 num_q_heads: int,
                 num_kv_heads: int,
                 head_size: int,
                 head_size_v: int,
                 bias: bool = False,
                 fp8_dtype: torch.dtype = torch.float8_e4m3fn,
                 dtype: Optional[torch.dtype] = None,
                 device: Optional[torch.device] = None,
                 is_tp: bool = True,
                 dp_gather: bool = False,
                 num_replicate_kv_heads: int = 1):
        """Compute the q/k/v section sizes and initialise the merged linear.

        Args:
            in_features: input feature size, forwarded to the parent class.
            num_q_heads: number of query heads before ``_update_num_heads``.
            num_kv_heads: number of key/value heads before
                ``_update_num_heads``.
            head_size: per-head size used for the q and k sections.
            head_size_v: per-head size used for the v section.
            bias: forwarded to the parent class.
            fp8_dtype: fp8 dtype forwarded to the parent class.
            dtype: forwarded to the parent class.
            device: forwarded to the parent class.
            is_tp: forwarded to the parent class and used by the weight
                loader to look up the tensor-parallel rank.
            dp_gather: forwarded to the parent class.
            num_replicate_kv_heads: divisor applied to the rank when picking
                the k/v slice in ``weight_loader``.
        """
        # These are assigned before the parent initialiser runs, exactly as
        # the helpers below and the parent constructor may rely on them.
        self.is_tp = is_tp
        self.block_size = 128

        # Split sections are derived from the head counts as passed in,
        # i.e. before ``_update_num_heads`` adjusts them.
        self.qkv_split_section = self._get_qkv_out_features(
            num_q_heads,
            num_kv_heads,
            head_size,
            head_size_v,
            num_replicate_kv_heads,
        )

        num_q_heads, num_kv_heads = self._update_num_heads(num_q_heads, num_kv_heads)
        merged_out_features = self._get_qkv_out_features(num_q_heads, num_kv_heads, head_size, head_size_v)

        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size
        self.head_size_v = head_size_v
        self.num_replicate_kv_heads = num_replicate_kv_heads

        super().__init__(in_features,
                         merged_out_features,
                         dtype=dtype,
                         fp8_dtype=fp8_dtype,
                         bias=bias,
                         device=device,
                         is_tp=is_tp,
                         out_names=('q', 'k', 'v'),
                         dp_gather=dp_gather)

    def _update_all_out_features(self, all_out_features: List[int], replicate: Optional[List[bool]]):
        """Return ``all_out_features`` untouched; ``replicate`` is ignored."""
        return all_out_features

    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any):
        """Copy this rank's slice of a q, k or v checkpoint tensor into ``param``.

        Args:
            param: destination parameter; its ``_weight_type`` attribute
                selects the scale path when it equals ``'scales'``.
            loaded_weight: checkpoint tensor for a single projection.
            shard_id: key of ``self.out_names_map`` (``'q'``, ``'k'`` or
                ``'v'``) naming the destination section.
        """
        _, rank = _get_tp_world_rank(self.is_tp)
        shard_idx = self.out_names_map[shard_id]

        if shard_id == 'q':
            head_count = self.num_q_heads
            slice_rank = rank
        else:
            head_count = self.num_kv_heads
            # k/v may be duplicated across ranks when tp_size > num_kv_heads,
            # so several consecutive ranks read the same slice.
            slice_rank = rank // self.num_replicate_kv_heads

        if shard_id in ['q', 'k']:
            head_dim = self.head_size
        else:
            head_dim = self.head_size_v

        slice_len = head_count * head_dim
        section_sizes = self.all_out_features

        if param._weight_type == 'scales':
            # Scales are cast to float32 and indexed in units of block_size.
            loaded_weight = loaded_weight.to(torch.float32)
            section_sizes = [size // self.block_size for size in section_sizes]
            slice_len //= self.block_size

        local_slice = loaded_weight.narrow(dim=0, start=slice_rank * slice_len, length=slice_len)
        destination = param.data.split(section_sizes, 0)[shard_idx]
        destination.copy_(local_slice)

    def weight_spliter(self, loaded_weight: torch.Tensor, layout: str = 'default'):
        """Split a fused qkv tensor into its q, k and v parts along dim 0.

        Args:
            loaded_weight: fused tensor.  A 2-D tensor whose dtype is not
                ``self.fp8_dtype`` is split with the block-scaled sections.
            layout: split layout; only ``'default'`` is accepted.

        Returns:
            The tuple of chunks produced by ``Tensor.split``.
        """
        _check_qkv_split_layout(layout)
        assert layout == 'default'

        uses_scale_sections = loaded_weight.dim() == 2 and loaded_weight.dtype != self.fp8_dtype
        if uses_scale_sections:
            sections = [sec // self.block_size for sec in self.qkv_split_section]
        else:
            sections = self.qkv_split_section
        return loaded_weight.split(sections, dim=0)
522+
523+
441524
class AwqLinear(nn.Module):
442525
"""w4a16 linear."""
443526

@@ -1591,10 +1674,19 @@ def build_qkv_proj(in_features: int,
15911674
dtype: Optional[torch.dtype] = None,
15921675
device: Optional[torch.device] = None,
15931676
is_tp: bool = True,
1594-
num_replicate_kv_heads: int = 1):
1677+
num_replicate_kv_heads: int = 1,
1678+
dp_disable_tp: bool = False,
1679+
all_reduce: bool = False,
1680+
dp_gather: bool = False):
15951681
"""build qkv proj."""
1596-
if is_tp:
1597-
is_tp, _ = _get_dp_tp_meta()
1682+
if dp_disable_tp and is_tp:
1683+
is_tp, _ = _get_dp_tp_meta(all_reduce)
1684+
elif is_tp:
1685+
is_tp = get_tp_world_rank()[0] > 1
1686+
1687+
if dp_gather:
1688+
assert not dp_disable_tp
1689+
dp_gather = _get_dp_gather(is_tp)
15981690

15991691
if head_size_v is None:
16001692
head_size_v = head_size
@@ -1642,6 +1734,26 @@ def build_qkv_proj(in_features: int,
16421734
is_tp=is_tp,
16431735
num_replicate_kv_heads=num_replicate_kv_heads,
16441736
quant_dtype=quant_dtype)
1737+
if quant_method == 'fp8':
1738+
fmt = quant_config.get('fmt', 'e4m3')
1739+
if fmt == 'e4m3':
1740+
fp8_dtype = torch.float8_e4m3fn
1741+
elif fmt == 'e5m2':
1742+
fp8_dtype = torch.float8_e5m2
1743+
else:
1744+
raise TypeError(f'Unsupported fp8 fmt: {fmt}')
1745+
return QKVBlockedF8Linear(in_features=in_features,
1746+
num_q_heads=num_q_heads,
1747+
num_kv_heads=num_kv_heads,
1748+
head_size=head_size,
1749+
head_size_v=head_size_v,
1750+
bias=bias,
1751+
fp8_dtype=fp8_dtype,
1752+
dtype=dtype,
1753+
device=device,
1754+
is_tp=is_tp,
1755+
dp_gather=dp_gather,
1756+
num_replicate_kv_heads=num_replicate_kv_heads)
16451757
else:
16461758
raise RuntimeError(f'Unsupported quant method: {quant_method}')
16471759

0 commit comments

Comments
 (0)