@@ -126,8 +126,6 @@ def train(
126
126
intermediate_epoch = int (train_config .from_peft_checkpoint .split ("/" )[- 2 ].split ("_" )[- 1 ]) - 1
127
127
if epoch < intermediate_epoch :
128
128
logger .log_rank_zero (f"Skipping epoch { epoch + 1 } since fine tuning has already completed for it." )
129
- # to bring the count of train_step in sync with where it left off
130
- total_train_steps += len (train_dataloader )
131
129
continue
132
130
133
131
logger .log_rank_zero (f"Starting epoch { epoch + 1 } /{ train_config .num_epochs } " )
@@ -149,21 +147,22 @@ def train(
149
147
150
148
num_dummy_samples = 0
151
149
for step , batch in enumerate (train_dataloader ):
150
+ # total_train_steps indicates the cumulative number of training steps completed across all epochs.
151
+ # When resuming fine-tuning from previously saved checkpoints, total_train_steps indicates the total number of steps trained across the earlier session and the ongoing one.
152
+ total_train_steps = (epoch ) * len (train_dataloader ) + step
152
153
# resume training from a particular checkpoint, assuming the dataset is not shuffled
153
154
if train_config .use_peft and train_config .from_peft_checkpoint :
154
155
intermediate_step = int (train_config .from_peft_checkpoint .split ("/" )[- 1 ].split ("_" )[- 1 ])
155
156
intermediate_epoch = int (train_config .from_peft_checkpoint .split ("/" )[- 2 ].split ("_" )[- 1 ]) - 1
156
157
# to bring the count of train_step in sync with where it left off
157
158
if epoch == intermediate_epoch and step == 0 :
158
- total_train_steps += intermediate_step
159
159
logger .log_rank_zero (
160
160
f"Skipping first { intermediate_step } steps for epoch { epoch + 1 } , since fine tuning has already completed for it."
161
161
)
162
162
if epoch == intermediate_epoch and step < intermediate_step :
163
163
continue
164
- total_train_steps += 1
165
164
166
- if train_config .max_train_step > 0 and total_train_steps > train_config .max_train_step :
165
+ if train_config .max_train_step > 0 and total_train_steps >= train_config .max_train_step :
167
166
max_steps_reached = True
168
167
logger .log_rank_zero (
169
168
"Maximum training steps reached "
@@ -209,7 +208,7 @@ def train(
209
208
total_loss += loss .detach ().float ()
210
209
211
210
if is_rank_zero ():
212
- tensorboard_updates .add_scalars ("loss" , {"train" : loss }, total_train_steps - 1 )
211
+ tensorboard_updates .add_scalars ("loss" , {"train" : loss }, total_train_steps )
213
212
if loss <= train_config .convergence_loss :
214
213
loss_0_counter += 1
215
214
else :
@@ -264,7 +263,7 @@ def train(
264
263
)
265
264
266
265
pbar .set_description (
267
- f"Training Epoch: { epoch + 1 } /{ train_config .num_epochs } , step { step + 1 } /{ len (train_dataloader )} completed (loss: { ( loss ) .detach ().float ()} )"
266
+ f"Training Epoch: { epoch + 1 } /{ train_config .num_epochs } , step { step + 1 } /{ len (train_dataloader )} completed (loss: { loss .detach ().float ()} )"
268
267
)
269
268
if train_config .save_metrics :
270
269
save_to_json (
@@ -325,7 +324,7 @@ def train(
325
324
)
326
325
327
326
if is_rank_zero ():
328
- tensorboard_updates .add_scalars ("loss" , {"eval" : eval_epoch_loss }, total_train_steps - 1 )
327
+ tensorboard_updates .add_scalars ("loss" , {"eval" : eval_epoch_loss }, total_train_steps )
329
328
if train_config .save_metrics :
330
329
eval_step_loss .extend (step_loss )
331
330
eval_step_metric .extend (step_metric )
0 commit comments