
Support finetuning from a pretrained model #1321

Open · wants to merge 1 commit into main

Conversation

@vwxyzjn commented Jun 20, 2025

Continuation of #1300. Cleaner implementation.

@vwxyzjn requested review from tianyu-l, fegin and wwwjn as code owners June 20, 2025 19:31
@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) Jun 20, 2025
@@ -70,6 +75,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    # we will need to reinitialize the cache_state_dict.
    self.cache_state_dict = {
        k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
        if k not in excluded_parameters_for_model_only
    }
Author (vwxyzjn):
Changed here because I was getting `RuntimeError: Missing key in checkpoint state_dict: model.freqs_cis.` when trying to load a full checkpoint.
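A rough, hypothetical illustration of the effect of that filter (the keys and the contents of `excluded_parameters_for_model_only` below are made up for illustration, not torchtitan's actual values):

```python
# Hypothetical illustration only; the real excluded_parameters_for_model_only
# lives in torchtitan, not here.
excluded_parameters_for_model_only = {"freqs_cis"}

sharded_state_dicts = [
    {"tok_embeddings.weight": 0.0, "freqs_cis": 1.0},
    {"layers.0.attention.wq.weight": 2.0},
]

cache_state_dict = {
    k: v
    for sd in sharded_state_dicts
    for k, v in sd.items()
    if k not in excluded_parameters_for_model_only
}

# freqs_cis is dropped from the cached state dict, so the loader no longer
# expects it in the checkpoint and the "Missing key ... freqs_cis" error goes away.
print(cache_state_dict)
```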

Contributor:

@fegin can you take a look at this part?
I can verify that I couldn't load a seed checkpoint before this PR; with this PR I was able to load the seed checkpoint.

Contributor:

Now that the default is to save a model-only checkpoint at the last step, freqs_cis won't exist in it. I guess this change makes sense, since we just don't need freqs_cis to be loaded.

Contributor:

@vwxyzjn If you use --checkpoint.initial_load_path, you don't need to change anything. A better solution is to change Checkpointer._states_to_load to always pop out freqs_cis, regardless of whether model_only is True. Since this is a bug anyway, we can change Checkpointer to never load freqs_cis back, even though we are making it a persistent buffer. Right now we only do this when model_only is True.

@tianyu-l what do you think about this?
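A minimal sketch of that suggestion, assuming a simplified Checkpointer (the method body and the stand-in attributes are made up for illustration; this is not the actual torchtitan implementation):

```python
# Hypothetical sketch, not torchtitan's actual Checkpointer: always drop freqs_cis
# from whatever state we are about to hand to the checkpoint loader, whether or not
# model_only is set.
def _states_to_load(self, model_only: bool) -> dict[str, object]:
    if model_only:
        states = dict(self.model_state_dict)  # stand-in attribute
    else:
        states = {
            k: v
            for k, v in self.states.items()  # stand-in attribute
            if k not in self.exclude_from_loading
        }
    # freqs_cis is deterministically recomputed at model construction, so it never
    # needs to come back from a checkpoint, even though it is a persistent buffer.
    states.pop("freqs_cis", None)
    return states
```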

Contributor:

@fegin
I agree we should never load freqs_cis.

I think what prevented me from loading a seed checkpoint is the {"MODEL": state_dict} wrapping.
A seed checkpoint is saved at step-0, but it is not saved with model_only; currently there's no option to do that. However, according to https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/checkpoint.py#L448, when loading a step-0 checkpoint, the loader expects model-only.

The reason this PR unblocks me is that it adds the wrapping even for a model-only load.
https://github.com/pytorch/torchtitan/pull/1321/files#diff-27a108fa6d4885d9c66306785cb36029c0b4f5a1542e63ae24e84eb7e9a273d1R571
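A rough paraphrase of that mismatch (the dummy keys and the structure below are illustrative assumptions, not the linked code):

```python
# Hypothetical paraphrase of the save/load asymmetry described above.
model_state = {"tok_embeddings.weight": 0.0, "layers.0.attention.wq.weight": 1.0}

# A seed checkpoint (step 0) is currently saved wrapped, like a full checkpoint:
saved = {"MODEL": model_state, "OPTIMIZER": {}, "TRAIN_STATE": {}}

# ...but a step-0 checkpoint is loaded as model-only, i.e. the loader expects the
# flat model state dict, so the wrapped layout does not line up with what it wants.
expected_on_load = model_state

assert set(saved) != set(expected_on_load)  # {"MODEL", ...} vs. flat parameter keys
```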

Contributor:

@lkhphuc

> My current solution without code change:
> set `initial_load_path=checkpoint-xxx`,
> `initial_load_model_weights_only=false` to get the wrapped dict as before, and
> `exclude_from_loading=["dataloader", "optimizer", "lr_schedulers", "train_state"]` so the state dict to load only has `{MODEL: ...}`.

What I didn't understand is why you need to set initial_load_path and initial_load_model_weights_only. IIUC, if you just put it in --checkpoint.folder, it expects a full checkpoint, and then you only need to set exclude_from_loading, which doesn't sound too bad?

Contributor:

@tianyu-l It's the key-prefix issue, not a question of which states should be loaded.

Contributor:

I understand. I'm saying that even with the current prefix issue, @lkhphuc's use case should have a simpler workflow: just don't use initial_load_path. But I just realized he might hit model_only=True because he named the existing checkpoint step-0? https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/checkpoint.py#L448

I agree we should flatten the {MODEL: ...} part; that would make this much easier to reason about.

Contributor:

> But I just realized he might hit model_only=True because he named the existing checkpoint step-0?

Yes, that's what I've been doing. Since I want to train on a different dataset with a clean optimizer, it makes more sense for the run to start at step 0.
Plus, the new initial_load_path is much more ergonomic than having to copy and symlink the old checkpoint every time the dump folder changes :)

Flattening the model checkpoint sounds good to me too.

Author (vwxyzjn):

Hi all, sorry for the delay in response. Is there a clear consensus on what to implement?

Also,

> IIUC, if you just put it in --checkpoint.folder, it expects a full checkpoint, and then you only need to set exclude_from_loading, which doesn't sound too bad?

I think that would work in many setups, but if you have a shared file system it becomes tricky, right? You'd need to make a copy of the initial pretrained checkpoint every time you launch training.

@jquesnelle:

I think it may be that the modeling code as-is won't be able to finetune on sequences > 8192, because as of now it doesn't implement the Llama 3 RoPE scaling; see https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json#L21 and the implementation in https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L385.

For sequences > 8192 it assumes standard RoPE, but AFAIK Llama 3.1+ wasn't trained like that; rather, it was post-scaled to a 128K context.
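For reference, a sketch of that Llama 3 scaling (paraphrasing the linked transformers helper; the function name is made up, and the default values mirror the linked Llama-3.1-8B config; this is not torchtitan code):

```python
import math

import torch


def apply_llama3_rope_scaling(
    inv_freq: torch.Tensor,
    factor: float = 8.0,
    low_freq_factor: float = 1.0,
    high_freq_factor: float = 4.0,
    original_max_position_embeddings: int = 8192,
) -> torch.Tensor:
    """Rescale RoPE inverse frequencies the way Llama 3.1's "llama3" rope_scaling does."""
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor

    wavelen = 2 * math.pi / inv_freq
    # Long wavelengths (low frequencies) are divided by `factor`; short wavelengths
    # (high frequencies) are left untouched.
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
    # Wavelengths in between are smoothly interpolated between the two regimes.
    smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    is_medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
    smoothed = (1 - smooth) * inv_freq / factor + smooth * inv_freq
    return torch.where(is_medium, smoothed, scaled)
```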

fegin added a commit that referenced this pull request Jul 3, 2025
The model state_dict is unique compared to other state dictionaries
(e.g., optimizer). It's the only one that will be exported outside of
TorchTitan and imported from other sources. To ensure FQN consistency,
we previously removed the prefix during the first checkpoint load and
last checkpoint save. However, this approach has caused confusion among
users, despite available options to control behavior.

This PR aims to resolve the issue by always flattening the model state
dictionary, eliminating the `"MODEL."` prefix from its keys. We decided
not to flatten all components due to the risk of key collisions between
different components. Instead, this PR only flattens the model
state_dict, which is a special case.

While this solution isn't perfect, as it introduces different handling
for different components, it's a good compromise given the unique nature
of the model state_dict.

Also see the discussion in
#1321 (comment)


This is the pseudo code for the current state:
```
if model_only:
    state_dict = model.state_dict()
else:
    state_dict = {
        "MODEL": model,
        "OPTIMIZER": optimizer,
        ...
    }
```

This is the pseudo code after this PR is landed:
```
state_dict = model.state_dict()
if not model_only:
    state_dict.update(
        {"OPTIMIZER": optimizer}
         ...
     )
```
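As a concrete, made-up illustration of the resulting key layout:

```python
# Hypothetical key layout before and after the flattening described above;
# the parameter names are placeholders.
before = {
    "MODEL": {"tok_embeddings.weight": 0.0, "layers.0.attention.wq.weight": 1.0},
    "OPTIMIZER": {"state": {}, "param_groups": []},
}

after = {
    # model keys move to the top level, with no "MODEL." prefix
    "tok_embeddings.weight": 0.0,
    "layers.0.attention.wq.weight": 1.0,
    # other components stay namespaced to avoid key collisions
    "OPTIMIZER": {"state": {}, "param_groups": []},
}
```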



[Screenshot: FSDP4 vs. FSDP4 TP2 loss curves with a seed checkpoint and --training.seed=42]