@@ -111,6 +111,8 @@ def __init__(
111111 self .bias_fused_dim = 0
112112 self .weight_scale_and_zero_point_fused_dim = 0
113113
114+ self .load_finished : bool = False
115+
114116 def mm (
115117 self , input_tensor : torch .Tensor , out : Optional [torch .Tensor ] = None , use_custom_tensor_mananger : bool = True
116118 ) -> torch .Tensor :
@@ -201,6 +203,7 @@ def load_hf_weights(self, weights):
201203 self .quant_method is not None
202204 and self .mm_param .weight is not None
203205 and self .quant_method .weight_need_quanted (self .mm_param .weight )
206+ and self .load_finished is False
204207 ):
205208 logger .info (f"online quant weight names: { self .weight_names } " )
206209 quantized_weight , weight_scale , weight_zero_point = self .quant_method .quantize (
@@ -211,7 +214,12 @@ def load_hf_weights(self, weights):
211214 self .mm_param .weight_zero_point = weight_zero_point
212215
213216 # repack 操作
214- if self .quant_method is not None and self .mm_param .is_ready () and self .quant_method .params_need_repack ():
217+ if (
218+ self .quant_method is not None
219+ and self .mm_param .is_ready ()
220+ and self .quant_method .params_need_repack ()
221+ and self .load_finished is False
222+ ):
215223 (
216224 self .mm_param .weight ,
217225 self .mm_param .weight_scale ,
@@ -223,8 +231,9 @@ def load_hf_weights(self, weights):
223231 dtype_type = self .data_type_ ,
224232 )
225233
226- if self .mm_param .is_ready ():
234+ if self .mm_param .is_ready () and self . load_finished is False :
227235 self ._to_gpu_device ()
236+ self .load_finished = True
228237
def verify_load(self) -> bool:
    """Report whether this module's parameters have been fully loaded.

    Delegates entirely to ``mm_param.is_ready()``: the module counts as
    loaded exactly when every expected weight tensor is present.
    """
    all_params_present: bool = self.mm_param.is_ready()
    return all_params_present
0 commit comments