@@ -393,12 +393,15 @@ def __init__(self,
                          is_tp=is_tp,
                          dp_gather=dp_gather)
         self.weight.weight_loader = self.weight_loader
+        self.weight._weight_type = 'qweight'
         self.scale.weight_loader = self.weight_loader
+        self.scale._weight_type = 'scales'
         self.weight.weight_spliter = self.weight_spliter
         self.scale.weight_spliter = self.weight_spliter
         if self.bias is not None:
             self.bias.weight_loader = self.weight_loader
             self.bias.weight_spliter = self.weight_spliter
+            self.bias._weight_type = 'bias'

     def _get_io_features(self, in_features: int, out_features: int, colwise: bool):
         """get io features."""
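The `_weight_type` tags added above let a shared loader tell the fp8 `qweight` apart from its block scales and bias without guessing from dtype or shape. A minimal sketch of the idea; the dispatcher below is illustrative, not the module's actual implementation:

```python
import torch

# Hypothetical dispatcher keyed on the `_weight_type` tag set in __init__.
def load_by_tag(param: torch.nn.Parameter, loaded: torch.Tensor) -> None:
    tag = getattr(param, '_weight_type', 'qweight')
    if tag == 'scales':
        # block scales are kept in float32 regardless of checkpoint dtype
        loaded = loaded.to(torch.float32)
    param.data.copy_(loaded)
```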
@@ -419,7 +422,8 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor,
         """weight loader."""
         world_size, rank = _get_tp_world_rank(self.is_tp)
         shard_idx = self.out_names_map[shard_id]
-        if loaded_weight.dim() == 2 and loaded_weight.dtype == torch.float32:
+        if loaded_weight.dim() == 2 and loaded_weight.dtype != self.fp8_dtype:
+            loaded_weight = loaded_weight.to(torch.float32)
             all_out_features = [feats // self.block_size for feats in self.all_out_features]
             param_w = param.data.split(all_out_features, 0)[shard_idx]
         else:
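This condition change makes scale detection robust to checkpoints that store scales in half precision: any 2-D tensor whose dtype differs from the fp8 weight dtype is treated as block scales and upcast to float32, instead of matching only exact-float32 tensors. A small illustration, assuming 128x128 quantization blocks:

```python
import torch

# Assumed sizes for illustration only.
out_features, in_features, block_size = 1024, 4096, 128
qweight = torch.zeros(out_features, in_features, dtype=torch.float8_e4m3fn)
# scales may arrive as bfloat16, so a `== torch.float32` check would miss them
scales = torch.zeros(out_features // block_size, in_features // block_size,
                     dtype=torch.bfloat16)
assert scales.dim() == 2 and scales.dtype != qweight.dtype  # detected as scales
```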
@@ -430,14 +434,93 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor,

     def weight_spliter(self, loaded_weight: torch.Tensor):
         """weight spliter."""
-        if loaded_weight.dim() == 2 and loaded_weight.dtype == torch.float32:
+        if loaded_weight.dim() == 2 and loaded_weight.dtype != self.fp8_dtype:
             return loaded_weight.split(self.scale_split_section, dim=0)
         return loaded_weight.split(self.split_section, dim=0)

     def weight_spliter_lora_b(self, loaded_weight: torch.Tensor):
         return loaded_weight.split(self.split_section, dim=0)


+class QKVBlockedF8Linear(MergedBlockedF8Linear, QKVMixin):
+    """qkv blockedf8 linear."""
+
+    def __init__(self,
+                 in_features: int,
+                 num_q_heads: int,
+                 num_kv_heads: int,
+                 head_size: int,
+                 head_size_v: int,
+                 bias: bool = False,
+                 fp8_dtype: torch.dtype = torch.float8_e4m3fn,
+                 dtype: Optional[torch.dtype] = None,
+                 device: Optional[torch.device] = None,
+                 is_tp: bool = True,
+                 dp_gather: bool = False,
+                 num_replicate_kv_heads: int = 1):
+        self.is_tp = is_tp
+        self.block_size = 128
+        self.qkv_split_section = self._get_qkv_out_features(num_q_heads, num_kv_heads, head_size, head_size_v,
+                                                            num_replicate_kv_heads)
+
+        num_q_heads, num_kv_heads = self._update_num_heads(num_q_heads, num_kv_heads)
+        all_out_features = self._get_qkv_out_features(num_q_heads, num_kv_heads, head_size, head_size_v)
+        out_names = ('q', 'k', 'v')
+        self.num_q_heads = num_q_heads
+        self.num_kv_heads = num_kv_heads
+        self.head_size = head_size
+        self.head_size_v = head_size_v
+        self.num_replicate_kv_heads = num_replicate_kv_heads
+
+        super().__init__(in_features,
+                         all_out_features,
+                         dtype=dtype,
+                         fp8_dtype=fp8_dtype,
+                         bias=bias,
+                         device=device,
+                         is_tp=is_tp,
+                         out_names=out_names,
+                         dp_gather=dp_gather)
+
+    def _update_all_out_features(self, all_out_features: List[int], replicate: Optional[List[bool]]):
+        """update all out features."""
+        return all_out_features
+
+    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any):
+        """weight loader."""
+        _, rank = _get_tp_world_rank(self.is_tp)
+        shard_idx = self.out_names_map[shard_id]
+
+        num_head = self.num_q_heads if shard_id == 'q' \
+            else self.num_kv_heads
+        head_dim = self.head_size if shard_id in ['q', 'k'] \
+            else self.head_size_v
+        # duplicate k/v when tp_size > num_kv_heads
+        rank_idx = rank if shard_id == 'q' \
+            else rank // self.num_replicate_kv_heads
+        sec_len = num_head * head_dim
+        all_out_features = self.all_out_features
+        if param._weight_type == 'scales':
+            loaded_weight = loaded_weight.to(torch.float32)
+            all_out_features = [sec // self.block_size for sec in all_out_features]
+            sec_len = sec_len // self.block_size
+
+        sec_start = rank_idx * sec_len
+
+        loaded_weight = loaded_weight.narrow(dim=0, start=sec_start, length=sec_len)
+        param_w = param.data.split(all_out_features, 0)[shard_idx]
+        param_w.copy_(loaded_weight)
+
+    def weight_spliter(self, loaded_weight: torch.Tensor, layout: str = 'default'):
+        """weight spliter."""
+        _check_qkv_split_layout(layout)
+        assert layout == 'default'
+        qkv_split_section = self.qkv_split_section
+        if loaded_weight.dim() == 2 and loaded_weight.dtype != self.fp8_dtype:
+            qkv_split_section = [sec // self.block_size for sec in qkv_split_section]
+        return loaded_weight.split(qkv_split_section, dim=0)
+
+
 class AwqLinear(nn.Module):
     """w4a16 linear."""

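The slice arithmetic in the new `weight_loader` is worth tracing once by hand: with `num_replicate_kv_heads > 1`, several TP ranks read the same source rows of `k`/`v`, and for a `scales` parameter both the section length and the per-output splits shrink by `block_size`. A standalone sketch with made-up sizes:

```python
# Illustrative numbers only: tp_size = 8 ranks sharing 2 per-rank KV heads.
num_kv_heads, head_size, block_size = 2, 128, 128
num_replicate_kv_heads = 4

sec_len = num_kv_heads * head_size             # rows of k (or v) per rank
for rank in range(8):
    rank_idx = rank // num_replicate_kv_heads  # ranks 0-3 -> 0, ranks 4-7 -> 1
    sec_start = rank_idx * sec_len             # narrow(dim=0, sec_start, sec_len)

# For a 'scales' param the same math runs in block units:
scale_sec_len = sec_len // block_size          # 2 rows of scales per rank
```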
@@ -1591,10 +1674,19 @@ def build_qkv_proj(in_features: int,
                    dtype: Optional[torch.dtype] = None,
                    device: Optional[torch.device] = None,
                    is_tp: bool = True,
-                   num_replicate_kv_heads: int = 1):
+                   num_replicate_kv_heads: int = 1,
+                   dp_disable_tp: bool = False,
+                   all_reduce: bool = False,
+                   dp_gather: bool = False):
     """build qkv proj."""
-    if is_tp:
-        is_tp, _ = _get_dp_tp_meta()
+    if dp_disable_tp and is_tp:
+        is_tp, _ = _get_dp_tp_meta(all_reduce)
+    elif is_tp:
+        is_tp = get_tp_world_rank()[0] > 1
+
+    if dp_gather:
+        assert not dp_disable_tp
+        dp_gather = _get_dp_gather(is_tp)

     if head_size_v is None:
         head_size_v = head_size
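The reworked preamble separates three concerns: `dp_disable_tp` derives TP enablement from the DP/TP meta (optionally with an all-reduce), plain `is_tp` now also requires a TP world size greater than 1, and `dp_gather` is only honored when TP is not DP-disabled. Assumed call shapes, for illustration only (other parameters keep their defaults, and a distributed context is required to actually run):

```python
# Hypothetical calls sketching the new flags.
qkv = build_qkv_proj(4096, num_q_heads=32, num_kv_heads=8, head_size=128,
                     dp_disable_tp=True, all_reduce=False)  # TP from DP meta
qkv = build_qkv_proj(4096, num_q_heads=32, num_kv_heads=8, head_size=128,
                     dp_gather=True)  # asserts dp_disable_tp is False
```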
@@ -1642,6 +1734,26 @@ def build_qkv_proj(in_features: int,
                            is_tp=is_tp,
                            num_replicate_kv_heads=num_replicate_kv_heads,
                            quant_dtype=quant_dtype)
+    if quant_method == 'fp8':
+        fmt = quant_config.get('fmt', 'e4m3')
+        if fmt == 'e4m3':
+            fp8_dtype = torch.float8_e4m3fn
+        elif fmt == 'e5m2':
+            fp8_dtype = torch.float8_e5m2
+        else:
+            raise TypeError(f'Unsupported fp8 fmt: {fmt}')
+        return QKVBlockedF8Linear(in_features=in_features,
+                                  num_q_heads=num_q_heads,
+                                  num_kv_heads=num_kv_heads,
+                                  head_size=head_size,
+                                  head_size_v=head_size_v,
+                                  bias=bias,
+                                  fp8_dtype=fp8_dtype,
+                                  dtype=dtype,
+                                  device=device,
+                                  is_tp=is_tp,
+                                  dp_gather=dp_gather,
+                                  num_replicate_kv_heads=num_replicate_kv_heads)
     else:
         raise RuntimeError(f'Unsupported quant method: {quant_method}')

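With this branch in place, a checkpoint whose quant config declares the fp8 method is routed to `QKVBlockedF8Linear`. A sketch of the config fields the diff reads, assuming `quant_method` comes from the same config dict (other keys are checkpoint-specific):

```python
quant_config = {'quant_method': 'fp8', 'fmt': 'e4m3'}
# fmt == 'e4m3' -> torch.float8_e4m3fn
# fmt == 'e5m2' -> torch.float8_e5m2
# any other fmt -> TypeError(f'Unsupported fp8 fmt: {fmt}')
```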