@@ -158,51 +158,28 @@ def train(
158158 # Check if parameter passed or if set within environ
159159 use_wandb = wandb_check (wandb_project , wandb_watch , wandb_log_model )
160160
161- if deepspeed_zero3 :
162- deepspeed = deepspeed if deepspeed is not None else "./deepspeed_zero3_config.json"
163-
164161 if saved_low_bit_model is not None :
165162 # Load the low bit optimized model if provide the saved path
166- if deepspeed_zero3 :
167- import deepspeed as ds
168- with ds .zero .Init (config_dict_or_path = deepspeed ):
169- model = AutoModelForCausalLM .load_low_bit (
170- saved_low_bit_model ,
171- optimize_model = False ,
172- torch_dtype = torch .bfloat16 ,
173- modules_to_not_convert = ["lm_head" ],
174- trust_remote_code = True ,
175- )
176- else :
177- model = AutoModelForCausalLM .load_low_bit (
178- saved_low_bit_model ,
179- optimize_model = False ,
180- torch_dtype = torch .bfloat16 ,
181- modules_to_not_convert = ["lm_head" ],
182- trust_remote_code = True ,
183- )
163+ model = AutoModelForCausalLM .load_low_bit (
164+ saved_low_bit_model ,
165+ optimize_model = False ,
166+ torch_dtype = torch .bfloat16 ,
167+ modules_to_not_convert = ["lm_head" ],
168+ trust_remote_code = True ,
169+ )
170+ else :
171+ model = AutoModelForCausalLM .from_pretrained (
172+ base_model ,
173+ load_in_low_bit = "bf16" ,
174+ optimize_model = False ,
175+ torch_dtype = torch .bfloat16 ,
176+ modules_to_not_convert = ["lm_head" ],
177+ trust_remote_code = True ,
178+ )
179+
180+ if deepspeed_zero3 :
181+ deepspeed = deepspeed if deepspeed is not None else "./deepspeed_zero3_config.json"
184182 else :
185- if deepspeed_zero3 :
186- import deepspeed as ds
187- with ds .zero .Init (config_dict_or_path = deepspeed ):
188- model = AutoModelForCausalLM .from_pretrained (
189- base_model ,
190- load_in_low_bit = "bf16" ,
191- optimize_model = False ,
192- torch_dtype = torch .bfloat16 ,
193- modules_to_not_convert = ["lm_head" ],
194- trust_remote_code = True ,
195- )
196- else :
197- model = AutoModelForCausalLM .from_pretrained (
198- base_model ,
199- load_in_low_bit = "bf16" ,
200- optimize_model = False ,
201- torch_dtype = torch .bfloat16 ,
202- modules_to_not_convert = ["lm_head" ],
203- trust_remote_code = True ,
204- )
205- if not deepspeed_zero3 :
206183 print (f"Model loaded on rank { os .environ .get ('LOCAL_RANK' )} " )
207184 model = model .to (f'xpu:{ os .environ .get ("LOCAL_RANK" , 0 )} ' )
208185 print (f"Model moved to rank { os .environ .get ('LOCAL_RANK' )} " )
0 commit comments