
Conversation

@tohtana (Contributor) commented Nov 14, 2025

What does this PR do?

With the latest revision of transformers, loading pretrained weights fails with an error when DeepSpeed ZeRO-3 is enabled.
This PR resolves the issue.

Here is the stack trace of the error.

  File "/home/runner/work/DeepSpeed/DeepSpeed/tests/unit/runtime/zero/test_zero_nesting_init.py", line 70, in test_nested_parallel_init
    model = VisionEncoderDecoderModel.from_pretrained(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/DeepSpeed/DeepSpeed/unit-test-venv/lib/python3.12/site-packages/transformers/modeling_utils.py", line 270, in _wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/DeepSpeed/DeepSpeed/unit-test-venv/lib/python3.12/site-packages/transformers/modeling_utils.py", line 4122, in from_pretrained
    model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs = cls._load_pretrained_model(
                                                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/DeepSpeed/DeepSpeed/unit-test-venv/lib/python3.12/site-packages/transformers/modeling_utils.py", line 4236, in _load_pretrained_model
    error_msgs += _load_state_dict_into_zero3_model(model, state_dict)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/runner/work/DeepSpeed/DeepSpeed/unit-test-venv/lib/python3.12/site-packages/transformers/integrations/deepspeed.py", line 302, in _load_state_dict_into_zero3_model
    state_dict = state_dict.copy()
                 ^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'copy'
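
For context, the failure is that _load_state_dict_into_zero3_model unconditionally copies the state_dict it receives, and on this code path the caller passed None. Purely as an illustration (this is not the patch in this PR; the function below is a simplified, hypothetical stand-in), the failure mode and one possible guard look like this:

def load_state_dict_into_zero3_model(model_to_load, state_dict=None):
    # Hypothetical, simplified stand-in for the transformers helper; the real
    # function also gathers ZeRO-3 partitioned parameters before copying weights.
    if state_dict is None:
        # Guard for the case hit in the traceback above, where no materialized
        # state dict is passed in.
        state_dict = {}
    state_dict = state_dict.copy()  # the line that raised AttributeError on None
    error_msgs = []
    # ... the real implementation walks the modules and copies tensors in ...
    return error_msgs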

This is a minimal repro. You can run it with torchrun (an example command is shown after the script).

def main() -> None:
    from transformers import VisionEncoderDecoderModel, __version__ as transformers_version
    from transformers.integrations.deepspeed import HfDeepSpeedConfig

    import deepspeed  # noqa: F401  # ensure DeepSpeed is available inside the process

    ds_config = {
        "train_batch_size": 1,
        "zero_optimization": {
            "stage": 3,
        },
    }
    # Keep the config object alive so transformers detects ZeRO-3.
    dschf = HfDeepSpeedConfig(ds_config)

    print("Attempting to load model with transformers version:", transformers_version)
    VisionEncoderDecoderModel.from_pretrained(
        "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2")
    # Keep the config alive until the end of the function.
    assert dschf is not None


if __name__ == "__main__":
    main()
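
Assuming the script above is saved as repro.py (the filename is arbitrary), it can be launched with, for example:

torchrun --nproc_per_node=1 repro.py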

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Let me mention a few people who contributed to the related commits.

@3outeille (Member) commented:

Tried the repro and indeed there was a bug. LGTM, thanks for the PR!

3outeille enabled auto-merge (squash) November 14, 2025 10:46
@ArthurZucker (Collaborator) commented:

Thanks! We expected it to break and were going to fix it, but thanks a lot for doing it ahead of us, @tohtana!
Actually, we'd want to have DeepSpeed use our parallelism when possible (because we shard in this function), but @3outeille and I will have a look at that once things have stabilized!

ArthurZucker merged commit eddd51e into huggingface:main Nov 14, 2025 (20 of 23 checks passed)
tohtana deleted the fix_z3_init branch November 14, 2025 23:04
@tohtana (Contributor, Author) commented Nov 14, 2025

Thank you @3outeille and @ArthurZucker for merging this so quickly!

Actually, we'd want to have DeepSpeed use our parallelism when possible (because we shard in this function), but @3outeille and I will have a look at that once things have stabilized!

That sounds great! Could you elaborate a bit more on what kind of integration you have in mind?
From the DeepSpeed side, we’re definitely open to extending our APIs to enable better interoperability with HF’s parallelism.

@ArthurZucker (Collaborator) commented:

I need to have a look at how you implement TP today to give a better answer, but mostly if it can leverage the normal from_pretrained path for parallelism it would simplify things! 🤗

@tohtana (Contributor, Author) commented Nov 21, 2025

You mean AutoTP in DeepSpeed? It's a very important feature for us. Please let me know if there is anything we could do on the DeepSpeed side (e.g. adding APIs/tests).
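
For readers who haven't used it: AutoTP is DeepSpeed's automatic tensor parallelism, which shards an already-instantiated Hugging Face model across the launched ranks. As a rough, illustrative sketch of how it is commonly applied on the inference side (argument names have shifted between DeepSpeed releases and the two-rank setup launched via torchrun or the deepspeed launcher is assumed, so treat the call below as an assumption rather than a fixed API):

import torch
import deepspeed
from transformers import AutoModelForCausalLM

# Each rank instantiates the full model first; AutoTP then shards it.
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Automatic tensor parallelism: DeepSpeed replaces eligible linear layers with
# tensor-parallel versions across the ranks started by the launcher.
model = deepspeed.init_inference(
    model,
    mp_size=2,                         # assumed number of tensor-parallel ranks
    dtype=torch.float16,
    replace_with_kernel_inject=False,  # plain AutoTP sharding, no fused kernels
)

The integration question above is essentially whether this sharding step could instead be driven through transformers' normal from_pretrained parallelism path.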
