@@ -217,26 +217,46 @@ def filter_weights(prefix: str, weights: dict):
217217 if "self_attn.qkv_proj" in name :
218218 # The weights need to be split correctly before sharding to support tp_size >1.
219219 qkv_weight = module_weights ['weight' ][:]
220- q_weight = qkv_weight [:hidden_size , :]
221- k_weight = qkv_weight [hidden_size :hidden_size +
222- num_kv_heads * head_dim , :]
223- v_weight = qkv_weight [hidden_size +
224- num_kv_heads * head_dim :, :]
220+ qk_split_index = hidden_size
221+ kv_split_index = hidden_size + num_kv_heads * head_dim
222+
223+ q_dict = {'weight' : qkv_weight [:qk_split_index , :]}
224+ k_dict = {
225+ 'weight' :
226+ qkv_weight [qk_split_index :kv_split_index , :]
227+ }
228+ v_dict = {'weight' : qkv_weight [kv_split_index :, :]}
225229
226230 # Get the scale factor for the fused QKV projection
227231 qkv_scale = module_weights .get ('weight_scale' , None )
228232
229- q_dict = {'weight' : q_weight }
230- if qkv_scale is not None :
231- q_dict ['weight_scale' ] = qkv_scale
232-
233- k_dict = {'weight' : k_weight }
234233 if qkv_scale is not None :
235- k_dict ['weight_scale' ] = qkv_scale # Use same scale
236-
237- v_dict = {'weight' : v_weight }
238- if qkv_scale is not None :
239- v_dict ['weight_scale' ] = qkv_scale # Use same scale
234+ if qkv_scale .shape and qkv_scale .shape [
235+ 0 ] == qkv_weight .shape [0 ]:
236+ q_dict [
237+ 'weight_scale' ] = qkv_scale [:
238+ qk_split_index , :]
239+ k_dict ['weight_scale' ] = qkv_scale [
240+ qk_split_index :kv_split_index , :]
241+ v_dict ['weight_scale' ] = qkv_scale [
242+ kv_split_index :, :]
243+ else : # use same scale
244+ q_dict ['weight_scale' ] = qkv_scale
245+ k_dict ['weight_scale' ] = qkv_scale
246+ v_dict ['weight_scale' ] = qkv_scale
247+
248+ input_scale = module_weights .get ('input_scale' , None )
249+ if input_scale is not None :
250+ q_dict ['input_scale' ] = input_scale
251+ k_dict ['input_scale' ] = input_scale
252+ v_dict ['input_scale' ] = input_scale
253+
254+ weight_scale_2 = module_weights .get (
255+ 'weight_scale_2' , None )
256+ if weight_scale_2 is not None :
257+ q_dict ['weight_scale_2' ] = weight_scale_2
258+ k_dict ['weight_scale_2' ] = weight_scale_2
259+ v_dict ['weight_scale_2' ] = weight_scale_2
240260
241261 module .load_weights (weights = [q_dict , k_dict , v_dict ])
242262 elif "mlp.gate_up_proj" in name :
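Both the QKV hunk above and the gate/up hunk below rely on the same heuristic: a weight_scale whose leading dimension matches the fused weight's row count is row-wise (e.g. per-block) and must be sliced at the same boundaries as the weight, while an empty or mismatched shape is treated as per-tensor and shared by every shard. Below is a minimal NumPy sketch of that decision; the helper name split_fused_scales and the toy shapes are illustrative assumptions, not part of this change:

    import numpy as np

    def split_fused_scales(scale, weight_rows, boundaries):
        # Hypothetical helper mirroring the shape check in the diff: a
        # leading dim equal to the fused weight's rows means row-wise
        # scales; anything else (e.g. a 0-d per-tensor scale) is shared.
        if scale.shape and scale.shape[0] == weight_rows:
            edges = [0] + boundaries + [weight_rows]
            return [scale[a:b] for a, b in zip(edges, edges[1:])]
        return [scale] * (len(boundaries) + 1)

    hidden_size, num_kv_heads, head_dim = 8, 2, 2
    rows = hidden_size + 2 * num_kv_heads * head_dim
    bounds = [hidden_size, hidden_size + num_kv_heads * head_dim]

    row_wise = np.ones((rows, 1), dtype=np.float32)
    q_s, k_s, v_s = split_fused_scales(row_wise, rows, bounds)
    assert q_s.shape[0] == hidden_size
    assert k_s.shape[0] == v_s.shape[0] == num_kv_heads * head_dim

    per_tensor = np.float32(0.5)  # 0-d: empty shape, so it is shared
    assert all(s is per_tensor for s in
               split_fused_scales(per_tensor, rows, bounds))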
@@ -246,16 +266,33 @@ def filter_weights(prefix: str, weights: dict):
             gate_weight = gate_up_weight[:intermediate_size, :]
             up_weight = gate_up_weight[intermediate_size:, :]
 
-            # Get the scale factors if they exist
-            gate_up_scale = module_weights.get('weight_scale', None)
-
             gate_dict = {'weight': gate_weight}
-            if gate_up_scale is not None:
-                gate_dict['weight_scale'] = gate_up_scale
-
             up_dict = {'weight': up_weight}
+
+            # Get the scale factors if they exist
+            gate_up_scale = module_weights.get('weight_scale', None)
             if gate_up_scale is not None:
-                up_dict['weight_scale'] = gate_up_scale
+                if gate_up_scale.shape and gate_up_scale.shape[0] == gate_up_weight.shape[0]:
+                    gate_dict['weight_scale'] = gate_up_scale[:intermediate_size, :]
+                    up_dict['weight_scale'] = gate_up_scale[intermediate_size:, :]
+                else:  # use same scale
+                    gate_dict['weight_scale'] = gate_up_scale
+                    up_dict['weight_scale'] = gate_up_scale
+
+            input_scale = module_weights.get('input_scale', None)
+            if input_scale is not None:
+                gate_dict['input_scale'] = input_scale
+                up_dict['input_scale'] = input_scale
+
+            weight_scale_2 = module_weights.get('weight_scale_2', None)
+            if weight_scale_2 is not None:
+                gate_dict['weight_scale_2'] = weight_scale_2
+                up_dict['weight_scale_2'] = weight_scale_2
 
             module.load_weights(weights=[gate_dict, up_dict])
         else:
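Why the row-wise branch slices rather than shares: in a per-row quantization scheme each output row dequantizes with its own scale, so a split shard reproduces the fused result only if the scale tensor is partitioned along the same boundary as the weight. A toy check follows; the shapes and the row-wise dequant formula are assumptions for illustration, not taken from this repo:

    import numpy as np

    rows, cols, intermediate_size = 6, 4, 3
    rng = np.random.default_rng(0)
    qweight = rng.integers(-8, 8, (rows, cols)).astype(np.float32)
    scale = rng.random((rows, 1), dtype=np.float32)

    # Dequantize the fused gate_up weight, then each shard with its slice
    # of the scales; concatenating the shards recovers the fused result.
    fused = qweight * scale
    gate = qweight[:intermediate_size] * scale[:intermediate_size]
    up = qweight[intermediate_size:] * scale[intermediate_size:]
    assert np.allclose(fused, np.concatenate([gate, up]))

    # Handing a shard the full fused scale instead would fail outright on
    # the (3, 4) x (6, 1) broadcast, which is why the diff slices it.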