[QEff. Finetune]: Correcting num_steps trained as per max_train_step and displaying non scaled loss value on console. #527
Conversation
… scaled loss value on console Signed-off-by: Swati Allabadi <[email protected]>
@@ -265,7 +265,7 @@ def train(
         )
         pbar.set_description(
-            f"Training Epoch: {epoch + 1}/{train_config.num_epochs}, step {step + 1}/{len(train_dataloader)} completed (loss: {loss.detach().float()})"
+            f"Training Epoch: {epoch + 1}/{train_config.num_epochs}, step {step + 1}/{len(train_dataloader)} completed (loss: {(loss * num_samples_in_cur_update).detach().float()})"
For batch_size=4, with the 3rd sample being a padded sample, the per-sample losses are L1, L2, L3, and L4. The loss for L3 is zeroed out because it is a padded sample. Here the loss variable is the average of all 4 losses; to make it the average of the 3 real values, we should multiply the loss by 4 and divide it by 3.
Let me know: is this the problem you are trying to solve? If so, how does your solution help?
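The arithmetic in the scenario above can be sketched in plain Python (the loss values and variable names here are purely illustrative, not from the actual QEfficient code):

```python
# Hypothetical per-sample losses for a batch of 4; the 3rd sample is a
# padded sample whose loss has been zeroed out.
per_sample_loss = [1.2, 0.8, 0.0, 1.0]
batch_size = len(per_sample_loss)
num_real_samples = 3  # samples that actually contribute

# Averaging over the whole batch includes the padded zero, which drags
# the value down:
avg_over_batch = sum(per_sample_loss) / batch_size  # 0.75

# Multiplying by batch_size and dividing by num_real_samples recovers
# the average over the 3 real samples only:
avg_over_real = avg_over_batch * batch_size / num_real_samples  # 1.0
```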
This is not the issue. The case you have described is already taken care of at line #200.
With gradient accumulation, the per-step loss is scaled down by gradient_accumulation_steps, and since the loss is printed on the console only after the step finishes, the scaled-down value was being displayed. The change above corrects that by scaling the value back up before printing.
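The scaling the author describes can be sketched as follows (a minimal illustration with made-up numbers; the variable names other than num_samples_in_cur_update are not from the actual code):

```python
# Minimal sketch of gradient-accumulation loss scaling.
gradient_accumulation_steps = 4

# Raw loss for one micro-batch (illustrative value):
raw_loss = 2.0

# For backward(), the loss is divided by gradient_accumulation_steps so
# that gradients summed over the accumulated micro-batches equal the
# gradient of the true average loss:
loss = raw_loss / gradient_accumulation_steps  # 0.5 is what backward() sees

# Printing `loss` directly therefore shows the scaled-down 0.5.
# Multiplying back by the number of micro-batches in the current update
# restores the human-readable value for the progress bar:
num_samples_in_cur_update = gradient_accumulation_steps
display_loss = loss * num_samples_in_cur_update  # 2.0
```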
Can you prefix the title with "[QEff. Finetune]:" so that we can filter on it whenever needed in the future?
Signed-off-by: Swati Allabadi <[email protected]>
@@ -326,7 +325,7 @@ def train(
         )
         if is_rank_zero():
-            tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
+            tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps - 1)
I think the better design is to use the total_train_steps variable properly. I suggest the changes below to make the design cleaner.
- Remove the update of total_train_steps at L130.
- At L152, at the start of each iteration over the dataloader, update total_train_steps: total_train_steps = len(train_dataloader) * (epoch + 1) + step. This will be valid throughout each step of the dataloader.
- Replace the condition at L157 with the condition at L162; both are doing the same thing.
- Remove the L158 update of total_train_steps.
- Remove the L164 update of total_train_steps.
- The condition at L166 should use >= instead of >.
- At L212, tensorboard should take total_train_steps, not 'total_train_steps - 1'.
- At L328, tensorboard should take total_train_steps, not 'total_train_steps - 1'.
With this there will be no manual +1 or -1 on total_train_steps, which makes it more maintainable and understandable.
Skipped 'Replace the condition at L157 with the condition at L162; both are doing the same thing.' as both are required. Made total_train_steps = len(train_dataloader) * epoch + step, because epoch + 1 would give an incorrect number. Accommodated the rest.
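The step-counter design settled on above can be sketched with a simplified loop (loop bounds and names besides total_train_steps are illustrative):

```python
# Simplified sketch: compute total_train_steps once at the top of each
# dataloader iteration, with no manual +1/-1 adjustments afterwards.
num_epochs = 2
steps_per_epoch = 3  # stands in for len(train_dataloader)

seen = []
for epoch in range(num_epochs):
    for step in range(steps_per_epoch):
        # epoch (not epoch + 1): on the first epoch the counter must
        # start at 0, which is why the suggested epoch + 1 was corrected.
        total_train_steps = steps_per_epoch * epoch + step
        seen.append(total_train_steps)

# The counter runs contiguously across epoch boundaries.
print(seen)  # [0, 1, 2, 3, 4, 5]
```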
LGTM, thanks for the quick fix. :)
…and displaying non scaled loss value on console. (quic#527) Signed-off-by: Swati Allabadi <[email protected]> Co-authored-by: Swati Allabadi <[email protected]>