
Conversation

@ngazagna-qc

@ngazagna-qc ngazagna-qc commented Sep 4, 2025

What does this PR do?

Fixes #40690

Only one line is added, as described in the original issue.

Question: should we transform https://github.com/ngazagna-qc/transformers/blob/fix-data-order-resumed-epoch/reproduce_wrong_resumed_epoch.py into a test?

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings. -> Not required
  • Did you write any new necessary tests? -> To discuss

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ngazagna-qc
Author

Hi @zach-huggingface @SunMarc: I am just following up on this bug and the proposed fix. It's a one-line change with a reproducible example, and I'd love to help get it merged if possible. Let me know if you have any questions :)

Thanks a lot!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@SunMarc SunMarc left a comment

Thanks a lot! Can you try to add some tests for this? Maybe you can have a look at this PR for the tests.

Comment on lines 2605 to 2606
epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
epoch_dataloader.iteration = epochs_trained
Member

Indeed, this is probably linked to this also.

Member

I think the simplest fix would be to move the following code just below the if steps_trained_in_current_epoch > 0: condition!

            if hasattr(epoch_dataloader, "set_epoch"):
                epoch_dataloader.set_epoch(epoch)

Author

I think the simplest fix would be to move the following code just below the if steps_trained_in_current_epoch > 0: condition!

            if hasattr(epoch_dataloader, "set_epoch"):
                epoch_dataloader.set_epoch(epoch)

I added the suggested test and fix, but we cannot move this snippet under the if steps_trained_in_current_epoch > 0: condition, as the epoch needs to be set even if no batch has to be skipped. So instead I copied the snippet there :)
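
For clarity, the resulting structure looks roughly like this (an illustrative sketch, not the exact transformers code; skip_first_batches, train_dataloader, etc. stand in for the real objects used inside the trainer's training loop):

    def run_epochs(train_dataloader, epochs_trained, num_train_epochs,
                   steps_trained_in_current_epoch, skip_first_batches, train_step):
        for epoch in range(epochs_trained, num_train_epochs):
            epoch_dataloader = train_dataloader
            # kept here: the epoch must be set even when no batches are skipped
            if hasattr(epoch_dataloader, "set_epoch"):
                epoch_dataloader.set_epoch(epoch)
            if steps_trained_in_current_epoch > 0:
                # resuming mid-epoch: drop the batches already consumed before the checkpoint
                epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
                steps_trained_in_current_epoch = 0
                # copied here: the wrapper returned by skip_first_batches also needs the
                # epoch, otherwise it shuffles with the default (epoch 0) seed
                if hasattr(epoch_dataloader, "set_epoch"):
                    epoch_dataloader.set_epoch(epoch)
            for batch in epoch_dataloader:
                train_step(batch)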

@ngazagna-qc
Author

Hi @SunMarc, thanks for your feedback. Sure, let me check the tests and try to implement one that catches the bug and verifies the fix. I will probably work on it tomorrow.

@ngazagna-qc ngazagna-qc force-pushed the pr-fix-data-order-resumed-epoch branch from b07715d to 750dd2a Compare September 25, 2025 11:17
@ngazagna-qc ngazagna-qc reopened this Sep 25, 2025
@ngazagna-qc ngazagna-qc force-pushed the pr-fix-data-order-resumed-epoch branch 2 times, most recently from 7582e24 to c2bf020 Compare September 29, 2025 10:59
@ngazagna-qc
Author

ngazagna-qc commented Sep 29, 2025

@SunMarc, I think we are close to being able to merge this PR:

  • I updated the trainer under the if steps_trained_in_current_epoch > 0: condition as we discussed
  • I wrote the test in tests/trainer/test_trainer.py::TrainerIntegrationTest::test_resume_batch_order
  • I confirmed that without the fix the test fails, and with the fix it passes.

One can run pytest -v -s tests/trainer/test_trainer.py::TrainerIntegrationTest::test_resume_batch_order to see the output and better understand the issue (a rough sketch of the general shape of such a test is included at the end of this comment).

I tried to follow the contribution guidelines but I am not super familiar with the commands to run to format the code. Let me know if anything more is needed :)
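
For reference, here is a rough sketch of the general shape of such a resume test (illustrative only, not the exact test added in this PR; the helper classes and the weight-comparison check below are simplified stand-ins). The idea is that if resuming restores the data order and RNG state correctly, a run resumed from a mid-training checkpoint ends with the same weights as the uninterrupted baseline run:

    import os
    import tempfile

    import torch
    from torch.utils.data import Dataset

    from transformers import Trainer, TrainingArguments, set_seed


    class TinyRegressionDataset(Dataset):
        def __init__(self, size=10):
            self.x = torch.arange(size, dtype=torch.float32).unsqueeze(-1)

        def __len__(self):
            return len(self.x)

        def __getitem__(self, i):
            return {"input_ids": self.x[i], "labels": self.x[i]}


    class TinyRegressionModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(1, 1)

        def forward(self, input_ids=None, labels=None):
            return {"loss": torch.nn.functional.mse_loss(self.linear(input_ids), labels)}


    def run_training(output_dir, resume_from=None):
        set_seed(42)
        args = TrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=2,
            num_train_epochs=3,
            save_strategy="steps",
            save_steps=1,  # one checkpoint per optimization step
            report_to=[],
        )
        trainer = Trainer(model=TinyRegressionModel(), args=args, train_dataset=TinyRegressionDataset())
        trainer.train(resume_from_checkpoint=resume_from)
        return {k: v.clone() for k, v in trainer.model.state_dict().items()}


    with tempfile.TemporaryDirectory() as full_dir, tempfile.TemporaryDirectory() as resumed_dir:
        baseline = run_training(full_dir)  # 3 epochs x 5 steps -> checkpoint-1 ... checkpoint-15
        resumed = run_training(resumed_dir, resume_from=os.path.join(full_dir, "checkpoint-7"))
        # With the data order preserved on resume, both runs should end with identical weights.
        for name, tensor in baseline.items():
            torch.testing.assert_close(tensor, resumed[name])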

Member

@SunMarc SunMarc left a comment

Thanks! Left a few comments.

Comment on lines +2601 to +2418
if hasattr(epoch_dataloader, "set_epoch"):
epoch_dataloader.set_epoch(epoch)
Member

Delete the lines above then, as we are duplicating the lines right now.

Author

@ngazagna-qc ngazagna-qc Oct 10, 2025

I do not think the lines above should be deleted. I ran tests with and without, and my test only passes if the epoch is set on epoch_dataloader both in the main loop over epochs AND after skip_first_batches is called.

Member

Hmmm, that's strange, can you try to investigate a bit why this happens? It's a bit counter-intuitive.

Author

@ngazagna-qc ngazagna-qc Oct 23, 2025

I am exploring this. If I comment out/delete these lines

            if hasattr(epoch_dataloader, "set_epoch"):
                epoch_dataloader.set_epoch(epoch)

then, for the baseline run (without resuming), I observe that the same randomness is applied at each epoch to sample batches:

Baseline: tensor([2, 6, 1, 8, 4, 5, 0, 9, 3, 7, 
                  2, 6, 1, 8, 4, 5, 0, 9, 3, 7, 
                  2, 6, 1, 8, 4, 5, 0, 9, 3, 7])

For comparison, here is the normal data order (with the current code base) that my PR maintains:

Baseline: tensor([2, 6, 1, 8, 4, 5, 0, 9, 3, 7, 
                  8, 4, 9, 0, 5, 1, 6, 7, 2, 3, 
                  2, 6, 3, 8, 7, 9, 1, 4, 5, 0])

I think this is because of this line:

            epoch_dataloader = train_dataloader

which resets epoch_dataloader at every epoch, and thus resets its iteration attribute to 0.

I tried to put epoch_dataloader = train_dataloader before the loop over epochs (for epoch in range(epochs_trained, num_train_epochs):), but I ran into errors. I guess there was a reason to place it inside the loop.

Conclusion: we should keep updating the iteration attribute after epoch_dataloader = train_dataloader.
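
As a standalone illustration of this behaviour (not the Trainer's exact mechanism, just the general idea of epoch-seeded shuffling, shown here with torch's DistributedSampler):

    import torch
    from torch.utils.data import DistributedSampler, TensorDataset

    dataset = TensorDataset(torch.arange(10))
    sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True, seed=42)

    # Without set_epoch, every pass replays the same permutation
    # (like the first Baseline tensor above).
    print([list(sampler) for _ in range(3)])

    # With set_epoch, each epoch gets its own order (like the normal data order).
    orders = []
    for epoch in range(3):
        sampler.set_epoch(epoch)
        orders.append(list(sampler))
    print(orders)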

Author

Does the above make sense, @SunMarc?

final_model_path = os.path.join(final_checkpoint_path, SAFE_WEIGHTS_NAME)
self.assertTrue(os.path.exists(final_model_path), "Final model checkpoint was not saved!")

@require_safetensors
Member

Seems a bit too complicated just for testing this, can you try to make it simpler? Feel free to check the other tests.

Author

I have just simplified the test :) Sorry for the delay

Please let me know if you have other ideas for simplification.

@ngazagna-qc
Author

Hello @SunMarc, just a quick follow-up after 2 weeks :) I hope my updates are good enough now!

Member

@SunMarc SunMarc left a comment

Thanks a lot! Just some nits and we are good to merge.


@SunMarc
Member

SunMarc commented Oct 23, 2025

Make sure to rebase correctly, the diff is huge right now.

@ngazagna-qc ngazagna-qc force-pushed the pr-fix-data-order-resumed-epoch branch from df5523d to 7935b86 Compare October 24, 2025 11:52
@ngazagna-qc ngazagna-qc reopened this Oct 24, 2025
@ngazagna-qc
Author

ngazagna-qc commented Oct 24, 2025

Make sure to rebase correctly, the diff is huge right now.

I did a reset from main. Let me know if anything else is needed.

@ngazagna-qc ngazagna-qc force-pushed the pr-fix-data-order-resumed-epoch branch 2 times, most recently from 1c87d02 to 2eb6316 Compare October 28, 2025 10:11
@ngazagna-qc
Author

Hi @SunMarc, I believe the rebase on main is correct now.

Comment on lines +5162 to +5166
# Scenario 1: Run baseline training to completion
# 1.1 Run training to completion
set_seed(42)
train_dataset = DummyDataset(size=10)
model_baseline = DummyModel(size=10)
Member

It only creates 3 checkpoints; I think the dataset you created is too small.

Author

Checking this now

Author

Training for 3 epochs over a dataset of 10 points, with a batch size of 2 and no gradient accumulation, correctly leads to the creation of 15 checkpoints. Indeed, an epoch requires 5 steps with bs=2.
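
In other words (assuming no drop_last and, as in the output below, one checkpoint saved per optimization step):

    import math

    steps_per_epoch = math.ceil(10 / 2)      # 10 samples, batch size 2 -> 5 steps per epoch
    total_checkpoints = steps_per_epoch * 3  # 3 epochs, one checkpoint per step
    assert total_checkpoints == 15           # checkpoint-1 ... checkpoint-15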

I checked it:

$ pytest tests/trainer/test_trainer.py::TrainerIntegrationTest::test_resume_batch_order -v -s   
========================================================================================================================== test session starts ===========================================================================================================================
platform linux -- Python 3.10.12, pytest-8.4.2, pluggy-1.6.0 -- /usr/bin/python3
cachedir: .pytest_cache
hypothesis profile 'default'
rootdir: .../transformers
configfile: pyproject.toml
plugins: rerunfailures-15.1, rich-0.2.0, asyncio-1.2.0, xdist-3.8.0, order-1.3.0, hypothesis-6.140.3, anyio-4.11.0, timeout-2.4.0, cov-6.2.1
asyncio: mode=strict, debug=False, asyncio_default_fixture_loop_scope=function, asyncio_default_test_loop_scope=function
collected 1 item                                                                                                                                                                                                                                                         

tests/trainer/test_trainer.py::TrainerIntegrationTest::test_resume_batch_order {'train_runtime': 0.2858, 'train_samples_per_second': 104.973, 'train_steps_per_second': 52.486, 'train_loss': 2.2335431416829428, 'epoch': 3.0}
~~~~ Checking saved checkpoints after full training
checkpoint-13
checkpoint-10
checkpoint-14
checkpoint-15
checkpoint-4
checkpoint-11
checkpoint-6
checkpoint-9
checkpoint-5
checkpoint-3
checkpoint-8
checkpoint-2
checkpoint-7
checkpoint-1
checkpoint-12
{'train_runtime': 0.0479, 'train_samples_per_second': 626.658, 'train_steps_per_second': 313.329, 'train_loss': 1.27561403910319, 'epoch': 3.0}
~~~~ Checking saved checkpoints after resuming from ckpt 7
checkpoint-13
checkpoint-10
checkpoint-14
checkpoint-15
checkpoint-11
checkpoint-9
checkpoint-8
checkpoint-7
checkpoint-12
PASSED

Is there something I misunderstood? Or should I add comments if something is unclear?

Member

Hmmm, I will try on my side again, but I only got 3 checkpoints saved. Can you try with transformers main?

Author

@ngazagna-qc ngazagna-qc Nov 5, 2025

I rebased my code on the latest updates I could fetch from main yesterday.

Let me do it again in a fresh env and a fresh branch directly checked out from current transformers main!

Author

I just pulled the latest main again, at this commit: 561233c.
I ran (inside my container) pip uninstall transformers and then pip install -e ".[dev]" (which output Successfully installed transformers-5.0.0.dev0), and the test gives me the correct output, as I shared above.

I do not know where this issue comes from. Is it the same in the CI?

Author

Hi @SunMarc, any update on this? I will try again from a fresh env tomorrow.

@ngazagna-qc ngazagna-qc force-pushed the pr-fix-data-order-resumed-epoch branch 2 times, most recently from fd14b2a to 1adc7ba Compare November 5, 2025 15:34
@ngazagna-qc ngazagna-qc force-pushed the pr-fix-data-order-resumed-epoch branch from 1adc7ba to 136dfc4 Compare November 21, 2025 12:24
@ngazagna-qc
Author

ngazagna-qc commented Nov 21, 2025

@SunMarc to reproduce:

git clone https://github.com/ngazagna-qc/transformers.git
cd transformers
uv venv .my-env
source .my-env/bin/activate
uv pip install -e ".[torch,sentencepiece,tokenizers,vision,integrations,timm,torch-vision,codecarbon,accelerate,video,num2words,mistral-common,chat_template,testing,quality,ja,sklearn,modelcreation]"
git checkout pr-fix-data-order-resumed-epoch
pytest tests/trainer/test_trainer.py::TrainerIntegrationTest::test_resume_batch_order -v -s

Let me know if I can help. I hope we are very close to merging.



Development

Successfully merging this pull request may close these issues.

Batches loaded from wrong epoch when resuming from second epoch
