@@ -25,7 +25,6 @@ def __init__(
 
         self.qmax = torch.finfo(torch.float8_e4m3fn).max
         self.qmin = torch.finfo(torch.float8_e4m3fn).min
-        self.layer_num = layer_num
         self.total_head_num = head_num * dist.get_world_size() if dist.is_initialized() else head_num
         self.count = 0
         self.scales = None
@@ -45,7 +44,13 @@ def __init__(
             self.scales_list = cfg["scales"]
             self.scales = torch.tensor(self.scales_list, dtype=torch.float32, device="cuda").view(cfg["scales_shape"])
             if not get_env_start_args().enable_fa3:
-                self.scales = torch.repeat_interleave(self.scales, self.head_num, dim=-1)
+                self.scales = torch.repeat_interleave(self.scales, head_num, dim=-1)
+            elif cfg["num_head"] > self.total_head_num:
+                factor = cfg["num_head"] // self.total_head_num
+                self.scales = self.scales[..., ::factor].contiguous()
+            elif cfg["num_head"] < self.total_head_num:
+                factor = self.total_head_num // cfg["num_head"]
+                self.scales = torch.repeat_interleave(self.scales, factor, dim=-1).contiguous()
             if get_env_start_args().enable_fa3 and dist.is_initialized() and dist.get_world_size() > 1:
                 half_head = self.total_head_num // 2
                 start_head = dist.get_rank() * head_num
@@ -77,7 +82,7 @@ def _load_and_check_config(self):
                 raise ValueError(
                     f"num_layers {cfg['num_layers']} in config " f"not match current layer_num {self.layer_num}"
                 )
-            if cfg["num_head"] != self.total_head_num:
+            if cfg["num_head"] % self.total_head_num != 0 and self.total_head_num % cfg["num_head"] != 0:
                 raise ValueError(
                     f"num_head {cfg['num_head']} in config " f"not match current model head num {self.total_head_num}"
                 )