
Commit aae5d36

Swati Allabadi committed

Minor change about total_train_steps update

Signed-off-by: Swati Allabadi <[email protected]>

1 parent c4c5c11

File tree

1 file changed: +7 −8 lines


QEfficient/finetune/utils/train_utils.py

Lines changed: 7 additions & 8 deletions
@@ -126,8 +126,6 @@ def train(
             intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
             if epoch < intermediate_epoch:
                 logger.log_rank_zero(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
-                # to bring the count of train_step in sync with where it left off
-                total_train_steps += len(train_dataloader)
                 continue

         logger.log_rank_zero(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
@@ -149,21 +147,22 @@ def train(

         num_dummy_samples = 0
         for step, batch in enumerate(train_dataloader):
+            # total_train_steps indicates the cumulative number of training steps completed across all epochs.
+            # When resuming fine-tuning from previously saved checkpoints, total_train_steps indicates the total number of steps trained across the earlier session and the ongoing one.
+            total_train_steps = (epoch) * len(train_dataloader) + step
             # resume training from a particular checkpoint, assuming the dataset is not shuffled
             if train_config.use_peft and train_config.from_peft_checkpoint:
                 intermediate_step = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1])
                 intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
                 # to bring the count of train_step in sync with where it left off
                 if epoch == intermediate_epoch and step == 0:
-                    total_train_steps += intermediate_step
                     logger.log_rank_zero(
                         f"Skipping first {intermediate_step} steps for epoch {epoch + 1}, since fine tuning has already completed for it."
                     )
                 if epoch == intermediate_epoch and step < intermediate_step:
                     continue
-            total_train_steps += 1

-            if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
+            if train_config.max_train_step > 0 and total_train_steps >= train_config.max_train_step:
                 max_steps_reached = True
                 logger.log_rank_zero(
                     "Maximum training steps reached "
@@ -209,7 +208,7 @@ def train(
             total_loss += loss.detach().float()

             if is_rank_zero():
-                tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps - 1)
+                tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps)
             if loss <= train_config.convergence_loss:
                 loss_0_counter += 1
             else:
@@ -264,7 +263,7 @@ def train(
                 )

             pbar.set_description(
-                f"Training Epoch: {epoch + 1}/{train_config.num_epochs}, step {step + 1}/{len(train_dataloader)} completed (loss: {(loss).detach().float()})"
+                f"Training Epoch: {epoch + 1}/{train_config.num_epochs}, step {step + 1}/{len(train_dataloader)} completed (loss: {loss.detach().float()})"
             )
             if train_config.save_metrics:
                 save_to_json(
@@ -325,7 +324,7 @@ def train(
             )

         if is_rank_zero():
-            tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps - 1)
+            tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
         if train_config.save_metrics:
             eval_step_loss.extend(step_loss)
             eval_step_metric.extend(step_metric)
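
The heart of the change is in the second hunk: instead of incrementing total_train_steps in several places (and re-adding the skipped steps when resuming from a checkpoint), the counter is now derived directly from the loop indices. Below is a minimal sketch, not part of the commit, of why the closed-form value stays in sync across a resume; the sizes and the resume point are made up, standing in for len(train_dataloader) and the parsed epoch_*/step_* checkpoint path:

    # Illustration only; hypothetical sizes and resume point.
    steps_per_epoch = 5               # stands in for len(train_dataloader)
    num_epochs = 3
    resume_epoch, resume_step = 1, 2  # hypothetical checkpoint position

    for epoch in range(num_epochs):
        for step in range(steps_per_epoch):
            # Closed-form counter as in the commit: correct even when
            # already-trained epochs/steps are skipped on resume.
            total_train_steps = epoch * steps_per_epoch + step
            if epoch < resume_epoch or (epoch == resume_epoch and step < resume_step):
                continue  # skipping work no longer needs counter bookkeeping
            print(total_train_steps)  # first value printed is 7 == 1 * 5 + 2

Because total_train_steps is now the zero-based index of the step about to run, the companion changes follow naturally: the TensorBoard calls drop their "- 1" offset, and the max_train_step guard switches from ">" to ">=" so the cutoff still fires after the same number of completed steps.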
