@@ -107,6 +107,8 @@ def train(
     gradient_checkpointing: bool = False,
     deepspeed: str = None,
     training_mode: str = "lora",
+    deepspeed_zero3: bool = False,
+    save_checkpoint: bool = True,
 ):
     invalidInputError(training_mode == "lora",
                       f"This example is for lora training mode, but got training_mode={training_mode}.")
@@ -136,6 +138,8 @@ def train(
         f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
         f"prompt template: {prompt_template_name}\n"
         f"training_mode: {training_mode}\n"
+        f"deepspeed_zero3: {deepspeed_zero3}\n"
+        f"save_checkpoint: {save_checkpoint}\n"
     )
     assert (
         base_model
@@ -154,28 +158,54 @@ def train(
     # Check if parameter passed or if set within environ
     use_wandb = wandb_check(wandb_project, wandb_watch, wandb_log_model)
 
+    if deepspeed_zero3:
+        deepspeed = deepspeed if deepspeed is not None else "./deepspeed_zero3_config.json"
+
     if saved_low_bit_model is not None:
         # Load the low bit optimized model if provide the saved path
-        model = AutoModelForCausalLM.load_low_bit(
-            saved_low_bit_model,
-            optimize_model=False,
-            torch_dtype=torch.bfloat16,
-            modules_to_not_convert=["lm_head"],
-            trust_remote_code=True,
-        )
+        if deepspeed_zero3:
+            import deepspeed as ds
+            with ds.zero.Init(config_dict_or_path=deepspeed):
+                model = AutoModelForCausalLM.load_low_bit(
+                    saved_low_bit_model,
+                    optimize_model=False,
+                    torch_dtype=torch.bfloat16,
+                    modules_to_not_convert=["lm_head"],
+                    trust_remote_code=True,
+                )
+        else:
+            model = AutoModelForCausalLM.load_low_bit(
+                saved_low_bit_model,
+                optimize_model=False,
+                torch_dtype=torch.bfloat16,
+                modules_to_not_convert=["lm_head"],
+                trust_remote_code=True,
+            )
     else:
-        model = AutoModelForCausalLM.from_pretrained(
-            base_model,
-            load_in_low_bit="bf16",
-            optimize_model=False,
-            torch_dtype=torch.bfloat16,
-            modules_to_not_convert=["lm_head"],
-            trust_remote_code=True,
-        )
-
-    print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
-    model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
-    print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")
+        if deepspeed_zero3:
+            import deepspeed as ds
+            with ds.zero.Init(config_dict_or_path=deepspeed):
+                model = AutoModelForCausalLM.from_pretrained(
+                    base_model,
+                    load_in_low_bit="bf16",
+                    optimize_model=False,
+                    torch_dtype=torch.bfloat16,
+                    modules_to_not_convert=["lm_head"],
+                    trust_remote_code=True,
+                )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                base_model,
+                load_in_low_bit="bf16",
+                optimize_model=False,
+                torch_dtype=torch.bfloat16,
+                modules_to_not_convert=["lm_head"],
+                trust_remote_code=True,
+            )
+    if not deepspeed_zero3:
+        print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
+        model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
+        print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")
 
     tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
     print(f"Tokenizer loaded on rank {os.environ.get('LOCAL_RANK')}")
@@ -234,12 +264,12 @@ def train(
             logging_steps=1,
             optim="adamw_torch",
             evaluation_strategy="steps" if val_set_size > 0 else "no",
-            save_strategy="steps",
+            save_strategy="steps" if save_checkpoint else "no",
             eval_steps=100 if val_set_size > 0 else None,
             save_steps=100,
             output_dir=output_dir,
             save_total_limit=100,
-            load_best_model_at_end=True if val_set_size > 0 else False,
+            load_best_model_at_end=True if val_set_size > 0 and save_checkpoint else False,
             ddp_find_unused_parameters=False if ddp else None,
             group_by_length=group_by_length,
             report_to="wandb" if use_wandb else None,
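
With save_strategy="no", the Trainer writes no intermediate checkpoints, so load_best_model_at_end must also be disabled (there is no saved checkpoint to reload); that is exactly what the second change in this hunk encodes. A hypothetical invocation exercising the new flags, assuming train() is called directly (how the script wires up its CLI is outside this diff), might look like:

    # Hypothetical call; only training_mode, deepspeed_zero3, and save_checkpoint
    # are taken from this diff. The other arguments are placeholders.
    train(
        base_model="meta-llama/Llama-2-7b-hf",  # placeholder model id
        output_dir="./lora-output",             # placeholder path
        training_mode="lora",
        deepspeed_zero3=True,   # wrap model creation in ds.zero.Init, skip manual .to("xpu:...")
        save_checkpoint=False,  # save_strategy becomes "no": no intermediate checkpoints
    )
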