17 changes: 7 additions & 10 deletions QEfficient/finetune/utils/train_utils.py
@@ -124,10 +124,9 @@ def train(

        if train_config.use_peft and train_config.from_peft_checkpoint:
            intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
            intermediate_step = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1])
            if epoch < intermediate_epoch:
                logger.log_rank_zero(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
                # to bring the count of train_step in sync with where it left off
                total_train_steps += len(train_dataloader)
                continue

        logger.log_rank_zero(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
@@ -149,20 +148,18 @@ def train(

        num_dummy_samples = 0
        for step, batch in enumerate(train_dataloader):
            # total_train_steps indicates the cumulative number of training steps completed across all epochs.
            # When resuming fine-tuning from previously saved checkpoints, total_train_steps indicates the total number of steps trained across the earlier session and the ongoing one.
            total_train_steps = (epoch) * len(train_dataloader) + step
            # resume training from a particular checkpoint, assuming the dataset is not shuffled
            if train_config.use_peft and train_config.from_peft_checkpoint:
                intermediate_step = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1])
                intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
                # to bring the count of train_step in sync with where it left off
                if epoch == intermediate_epoch and step == 0:
                    total_train_steps += intermediate_step
                    logger.log_rank_zero(
                        f"Skipping first {intermediate_step} steps for epoch {epoch + 1}, since fine tuning has already completed for it."
                    )
                if epoch == intermediate_epoch and step < intermediate_step:
                    total_train_steps += 1
                    continue
            total_train_steps += 1

            if train_config.max_train_step > 0 and total_train_steps >= train_config.max_train_step:
                max_steps_reached = True
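
The key change in this hunk is that total_train_steps is now derived from the loop position (`epoch * len(train_dataloader) + step`) rather than incremented manually, so skipped epochs and skipped resume steps no longer need compensating `+= 1` bookkeeping. A standalone sketch of that arithmetic is below; the steps-per-epoch, resume indices, and step cap are hypothetical values, not taken from the config.

```python
# Standalone sketch of the position-based step counter and the resume skip.
# All concrete values here are hypothetical.
steps_per_epoch = 100      # stands in for len(train_dataloader)
intermediate_epoch = 2     # zero-based epoch parsed from the checkpoint path
intermediate_step = 40     # step parsed from the checkpoint path
max_train_step = 0         # <= 0 means no cap, mirroring the "> 0" guard in the diff

for epoch in range(4):
    if epoch < intermediate_epoch:
        continue  # whole epoch already trained; the counter needs no manual bump
    for step in range(steps_per_epoch):
        # cumulative steps across all epochs, correct even immediately after a resume
        total_train_steps = epoch * steps_per_epoch + step
        if epoch == intermediate_epoch and step < intermediate_step:
            continue  # these steps are already covered by the checkpoint
        if max_train_step > 0 and total_train_steps >= max_train_step:
            break
        # ... forward/backward for this step would go here ...
```
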
@@ -235,12 +232,12 @@ def train(
            else:
                num_samples_in_cur_update = len(train_dataloader) % train_config.gradient_accumulation_steps

            loss = loss / num_samples_in_cur_update
            normalized_loss = loss / num_samples_in_cur_update

            if train_config.grad_scaler:
                scaler.scale(loss).backward()  # backward pass
                scaler.scale(normalized_loss).backward()  # backward pass
            else:
                loss.backward()  # backward pass
                normalized_loss.backward()  # backward pass

            if is_optimizer_step:
                if train_config.grad_scaler:
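
The last hunk only renames the scaled loss: `normalized_loss` carries the division by the number of samples in the current update, leaving `loss` untouched for metrics, and the backward pass (with or without the gradient scaler) runs on the normalized copy. Below is a minimal, self-contained sketch of that accumulation pattern; the model, optimizer, synthetic batches, and fixed accumulation count are stand-ins, not the code from this file, and the real code divides by `num_samples_in_cur_update`, which can be smaller for the trailing partial update.

```python
import torch

# Minimal sketch of gradient accumulation with per-update loss normalization.
# Everything concrete here (model, data, divisor) is illustrative only.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
use_grad_scaler = False                                  # mirrors train_config.grad_scaler
scaler = torch.cuda.amp.GradScaler(enabled=use_grad_scaler)
gradient_accumulation_steps = 4                          # fixed divisor for brevity

for step in range(8):
    x, y = torch.randn(2, 4), torch.randn(2, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)

    # Divide by the number of micro-batches in the update so the accumulated
    # gradient matches one large batch; the raw `loss` stays available for logging.
    normalized_loss = loss / gradient_accumulation_steps

    if use_grad_scaler:
        scaler.scale(normalized_loss).backward()
    else:
        normalized_loss.backward()

    if (step + 1) % gradient_accumulation_steps == 0:
        if use_grad_scaler:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad()
```
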