Fix loaded data order bug when resuming from epoch >= 1 #40691
base: main
Conversation
Hi @zach-huggingface @SunMarc: I am just following up on this bug and the proposed fix. It's a one-line change with a reproducible example, and I'd love to help get it merged if possible. Let me know if you have any questions :) Thanks a lot!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks a lot! Can you try to add some tests for this? Maybe you can take a look at this PR for the tests.
src/transformers/trainer.py
Outdated
epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
epoch_dataloader.iteration = epochs_trained
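To make the intent of these two lines concrete, here is a minimal sketch of the resume path (a reconstruction, not the actual Trainer code: the surrounding logic is heavily simplified, and a plain PyTorch DataLoader stands in for the accelerate-prepared dataloader whose iteration attribute drives per-epoch reshuffling):

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import skip_first_batches  # helper from accelerate, also used by the Trainer

dataset = TensorDataset(torch.arange(10))
train_dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

epochs_trained = 1                   # resuming inside the second epoch
steps_trained_in_current_epoch = 2   # two batches of that epoch already consumed

epoch_dataloader = train_dataloader
if steps_trained_in_current_epoch > 0:
    # skip_first_batches wraps the dataloader in a fresh object, so any internal
    # epoch/iteration counter used for reshuffling starts again from 0.
    epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
    # The one-line fix: restore the counter so the remaining batches follow the
    # shuffling of the epoch actually being resumed, not of epoch 0.
    epoch_dataloader.iteration = epochs_trained

for batch in epoch_dataloader:
    print(batch)  # the remaining batches of the resumed epoch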
Indeed, this is probably linked to this also.
I think the simplest fix would be to move the following code just below the if steps_trained_in_current_epoch > 0: condition!
if hasattr(epoch_dataloader, "set_epoch"):
    epoch_dataloader.set_epoch(epoch)
I added the test and the suggested fix. We cannot move this snippet under the if steps_trained_in_current_epoch > 0: condition, as the epoch needs to be set even if no batches have to be skipped. So instead I duplicated the snippet there :)
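For reference, a rough sketch of where the calls end up after this change (my own reconstruction, not the exact Trainer code; the variable names follow the snippets quoted in this thread, and the wrapper function get_epoch_dataloader is hypothetical):

from accelerate import skip_first_batches

def get_epoch_dataloader(train_dataloader, epoch, epochs_trained, steps_trained_in_current_epoch):
    # Hypothetical wrapper sketching the order of operations discussed above.
    epoch_dataloader = train_dataloader
    # Kept in the main epoch loop: the epoch must be set even when no batches are skipped.
    if hasattr(epoch_dataloader, "set_epoch"):
        epoch_dataloader.set_epoch(epoch)
    if steps_trained_in_current_epoch > 0:
        # skip_first_batches returns a new wrapper, so the epoch state is applied again.
        epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
        epoch_dataloader.iteration = epochs_trained
        if hasattr(epoch_dataloader, "set_epoch"):
            epoch_dataloader.set_epoch(epoch)
    return epoch_dataloader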
Hi @SunMarc, thanks for your feedback. Sure, let me check the tests and try to implement one that reproduces the bug and verifies the fix. I will probably work on it tomorrow.
b07715d to 750dd2a
7582e24 to c2bf020
@SunMarc, I think we are close to being able to merge this PR. One can run the added test to check the fix. I tried to follow the contribution guidelines, but I am not super familiar with the commands to run to format the code. Let me know if anything more is needed :)
SunMarc left a comment
Thanks! Left a few comments.
if hasattr(epoch_dataloader, "set_epoch"):
    epoch_dataloader.set_epoch(epoch)
Delete the lines above then, as we are duplicating them right now.
I do not think the lines above should be deleted. I ran tests with and without, and my test only passes if epoch_dataloader sets the epoch both in the main loop over epochs AND after skip_first_batches is called.
Hmmm, that's strange, can you try to investigate a bit why this happens? It's a bit counter-intuitive.
I am exploring this. If I comment out/delete these lines:
if hasattr(epoch_dataloader, "set_epoch"):
    epoch_dataloader.set_epoch(epoch)
Then, for the baseline run, without resuming, I observe that the same randomness is applied at each epoch to sample batches:
Baseline: tensor([2, 6, 1, 8, 4, 5, 0, 9, 3, 7,
2, 6, 1, 8, 4, 5, 0, 9, 3, 7,
2, 6, 1, 8, 4, 5, 0, 9, 3, 7])
For comparison, here is the normal data order (with the current code base) that my PR maintains:
Baseline: tensor([2, 6, 1, 8, 4, 5, 0, 9, 3, 7,
8, 4, 9, 0, 5, 1, 6, 7, 2, 3,
2, 6, 3, 8, 7, 9, 1, 4, 5, 0])
I think this is because of this line:
epoch_dataloader = train_dataloader
which resets epoch_dataloader at every epoch, and thus its iteration attribute to 0.
I tried to move epoch_dataloader = train_dataloader before the loop over epochs (for epoch in range(epochs_trained, num_train_epochs):), but I ran into errors. I guess there was a reason to keep it inside the loop.
Conclusion: we should therefore keep the update of the iteration attribute after epoch_dataloader = train_dataloader.
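To illustrate the repeated-order symptom, here is a small self-contained sketch of an epoch-seeded sampler (a hypothetical stand-in, not the sampler used by the Trainer): if the epoch is never set, the shuffle generator is seeded identically every epoch and the same permutation repeats.

import torch

class EpochSeededSampler:
    """Toy sampler whose shuffle depends on seed + epoch, similar in spirit to DistributedSampler."""
    def __init__(self, num_samples, seed=42):
        self.num_samples, self.seed, self.epoch = num_samples, seed, 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)  # a different shuffle for each epoch
        return iter(torch.randperm(self.num_samples, generator=g).tolist())

sampler = EpochSeededSampler(10)
for epoch in range(3):
    # Comment out the next line and all three epochs print the same permutation,
    # which is the repeated data order shown in the first Baseline tensor above.
    sampler.set_epoch(epoch)
    print(list(sampler))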
Does the above make sense, @SunMarc?
tests/trainer/test_trainer.py
Outdated
final_model_path = os.path.join(final_checkpoint_path, SAFE_WEIGHTS_NAME)
self.assertTrue(os.path.exists(final_model_path), "Final model checkpoint was not saved!")

@require_safetensors
Seems a bit too complicated just for testing this, can you try to make this simpler? Feel free to check the other tests.
I have just simplified the test :) Sorry for the delay.
Please let me know if you have other ideas for simplification.
Hello @SunMarc, just a quick follow-up after 2 weeks :) I hope my updates are good enough now!
SunMarc left a comment
Thanks a lot! Just some nits and we are good to merge.
if hasattr(epoch_dataloader, "set_epoch"):
    epoch_dataloader.set_epoch(epoch)
Hmmm, that's strange, can you try to investigate a bit why this happens? It's a bit counter-intuitive.
Make sure to rebase correctly, the diff is huge right now.
df5523d to 7935b86
I did reset from main. Let me know if anything else is needed.
1c87d02 to 2eb6316
Hi @SunMarc, I believe the rebase on main is correct now.
# Scenario 1: Run baseline training to completion
# 1.1 Run training to completion
set_seed(42)
train_dataset = DummyDataset(size=10)
model_baseline = DummyModel(size=10)
It only creates 3 checkpoints, I think the dataset you created is too small.
Checking this now
Training for 3 epochs over a dataset of 10 points, with a batch size of 2 and no gradient accumulation, correctly leads to the creation of 15 checkpoints. Indeed, an epoch requires 5 steps with bs=2.
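A quick sanity check of that count (assuming the test saves one checkpoint per optimizer step, i.e. save_steps=1, which is my reading of the setup):

import math

dataset_size, batch_size, num_epochs = 10, 2, 3
steps_per_epoch = math.ceil(dataset_size / batch_size)   # 5 steps per epoch
total_steps = steps_per_epoch * num_epochs                # 15 steps overall
print(steps_per_epoch, total_steps)  # 5 15 -> 15 checkpoints when saving every step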
I checked it:
$ pytest tests/trainer/test_trainer.py::TrainerIntegrationTest::test_resume_batch_order -v -s
========================================================================================================================== test session starts ===========================================================================================================================
platform linux -- Python 3.10.12, pytest-8.4.2, pluggy-1.6.0 -- /usr/bin/python3
cachedir: .pytest_cache
hypothesis profile 'default'
rootdir: .../transformers
configfile: pyproject.toml
plugins: rerunfailures-15.1, rich-0.2.0, asyncio-1.2.0, xdist-3.8.0, order-1.3.0, hypothesis-6.140.3, anyio-4.11.0, timeout-2.4.0, cov-6.2.1
asyncio: mode=strict, debug=False, asyncio_default_fixture_loop_scope=function, asyncio_default_test_loop_scope=function
collected 1 item
tests/trainer/test_trainer.py::TrainerIntegrationTest::test_resume_batch_order {'train_runtime': 0.2858, 'train_samples_per_second': 104.973, 'train_steps_per_second': 52.486, 'train_loss': 2.2335431416829428, 'epoch': 3.0}
~~~~ Checking saved checkpoints after full training
checkpoint-13
checkpoint-10
checkpoint-14
checkpoint-15
checkpoint-4
checkpoint-11
checkpoint-6
checkpoint-9
checkpoint-5
checkpoint-3
checkpoint-8
checkpoint-2
checkpoint-7
checkpoint-1
checkpoint-12
{'train_runtime': 0.0479, 'train_samples_per_second': 626.658, 'train_steps_per_second': 313.329, 'train_loss': 1.27561403910319, 'epoch': 3.0}
~~~~ Checking saved checkpoints after resuming from ckpt 7
checkpoint-13
checkpoint-10
checkpoint-14
checkpoint-15
checkpoint-11
checkpoint-9
checkpoint-8
checkpoint-7
checkpoint-12
PASSED
Is there something I misunderstood? Or should I add comments if something is unclear?
Hmmm, I will try on my side again, but I only saved 3 checkpoints. Can you try with transformers main?
I rebased my code on the latest updates I could fetch from main yesterday.
Let me do it again in a fresh env and a fresh branch directly checked out from current transformers main!
I just pulled the latest main again, at commit 561233c.
I ran (inside my container) pip uninstall transformers and then pip install -e ".[dev]" (which output Successfully installed transformers-5.0.0.dev0), and the test gives me the correct output, as I shared above.
I do not know where this issue comes from. Is it the same in the CI?
Hi @SunMarc, any update on this? I will try again from a fresh env tomorrow.
fd14b2a to 1adc7ba
1adc7ba to 136dfc4
@SunMarc to reproduce:
git clone https://github.com/ngazagna-qc/transformers.git
cd transformers
uv venv .my-env
source .my-env/bin/activate
uv pip install -e ".[torch,sentencepiece,tokenizers,vision,integrations,timm,torch-vision,codecarbon,accelerate,video,num2words,mistral-common,chat_template,testing,quality,ja,sklearn,modelcreation]"
git checkout pr-fix-data-order-resumed-epoch
pytest tests/trainer/test_trainer.py::TrainerIntegrationTest::test_resume_batch_order -v -s

Let me know if I can help. I hope we are very close to merging.
What does this PR do?
Fixes #40690
Only 1 line added as described in the original issue.
Question: should we transform https://github.com/ngazagna-qc/transformers/blob/fix-data-order-resumed-epoch/reproduce_wrong_resumed_epoch.py into a test?
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings. -> Not required
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.