Support finetuning from a pretrained model #1321
Conversation
@@ -70,6 +75,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # we will need to reinitialize the cache_state_dict.
        self.cache_state_dict = {
            k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
            if k not in excluded_parameters_for_model_only
Changed here because I was getting `RuntimeError: Missing key in checkpoint state_dict: model.freqs_cis` when trying to load a full checkpoint.
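For what it's worth, a small self-contained illustration of the failure mode; the key names are made up for the example and this is not the checkpointer code itself:

```python
# The checkpoint on disk was saved without the freqs_cis buffer, so the set of
# keys the load expects must not list it either, otherwise the checkpointer
# reports a missing key.
excluded_parameters_for_model_only = {"freqs_cis"}

on_disk = {"tok_embeddings.weight": 0.0, "layers.0.attention.wq.weight": 0.0}
expected = {"tok_embeddings.weight": 0.0, "layers.0.attention.wq.weight": 0.0, "freqs_cis": 0.0}

missing = expected.keys() - on_disk.keys()
assert missing == {"freqs_cis"}  # the "Missing key" error above

filtered = {k: v for k, v in expected.items() if k not in excluded_parameters_for_model_only}
assert filtered.keys() - on_disk.keys() == set()  # nothing missing once filtered
```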
@fegin can you take a look at this part?
I can verify that I couldn't load a seed checkpoint before this PR; with this PR I was able to load the seed checkpoint.
Now that the default setting is to save model-only in the last step, `freq_cis` won't exist. I guess this change makes sense, as we just don't need `freq_cis` to be loaded.
@vwxyzjn If you use `--checkpoint.initial_load_path`, you don't need to change anything. A better solution is to change `Checkpointer._states_to_load` to always pop out `freq_cis`, regardless of whether `model_only` is True or not. Since it is a bug anyway, we can change `Checkpointer` to never load `freq_cis` back, even though we are making it a persistent buffer. Right now we only do this when `model_only` is True.

@tianyu-l what do you think about this?
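To make the suggestion concrete, here is a rough sketch of the "always pop" behavior; the function and surrounding names are illustrative, not the actual `Checkpointer._states_to_load`:

```python
# Illustrative only, not the torchtitan implementation: drop the excluded
# buffers from whatever is about to be loaded, no longer gated on model_only.
EXCLUDED_PARAMETERS = {"freqs_cis"}  # mirrors excluded_parameters_for_model_only

def states_to_load(states: dict, model_only: bool) -> dict:
    # model_only is kept in the signature only to show that the exclusion
    # below no longer depends on it.
    states = dict(states)  # avoid mutating the caller's dict
    for key in EXCLUDED_PARAMETERS:
        states.pop(key, None)  # unconditional: freqs_cis is never loaded back
    return states
```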
@fegin I agree we should never load `freqs_cis`.

I think what prevented me from loading a seed checkpoint is the `{"MODEL": state_dict}` wrapping. A seed checkpoint is at step-0; however, it's not saved with `model_only` (currently there's no option to do that). However, according to https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/checkpoint.py#L448, when loading a step-0 checkpoint the code expects model-only.
The reason this PR unblocks me is that it adds the wrapping even for the model-only load.
https://github.com/pytorch/torchtitan/pull/1321/files#diff-27a108fa6d4885d9c66306785cb36029c0b4f5a1542e63ae24e84eb7e9a273d1R571
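To spell out the mismatch with made-up keys (this is how I read the current behavior, not a copy of the checkpointer code):

```python
# A seed checkpoint is saved as a full checkpoint, so the model weights sit
# under the "MODEL" wrapper...
saved_seed_checkpoint = {
    "MODEL": {"tok_embeddings.weight": 0.0, "layers.0.attention.wq.weight": 0.0},
    "OPTIMIZER": {"state": {}, "param_groups": []},
}

# ...but a step-0 load takes the model-only path, which (before this PR)
# expected flat, unwrapped keys:
expected_by_model_only_load = {
    "tok_embeddings.weight": 0.0,
    "layers.0.attention.wq.weight": 0.0,
}

# The key sets don't line up, hence the failure to load a seed checkpoint.
assert saved_seed_checkpoint.keys() != expected_by_model_only_load.keys()
```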
My current solution without code changes:
- set `initial_load_path=checkpoint-xxx`
- set `initial_load_model_weights_only=false` to get the wrapper dict as before
- set `exclude_from_loading=["dataloader", "optimizer", "lr_schedulers", "train_state"]` so that the state dict to load contains only `{MODEL: ...}`
What I didn't understand is why you need to set `initial_load_path` and `initial_load_model_weights_only`. IIUC, if you just put it in `--checkpoint.folder`, it expects a full checkpoint, and you only need to set `exclude_from_loading`, which doesn't sound too bad?
@tianyu-l It's the prefix issue of the keys, not what should be loaded.
I understand. I'm saying that even with the prefix issue, @lkhphuc's use case should have a simpler workflow: just don't use `init_load_path`. But I just realized he might hit `model_only=True` because he named the existing checkpoint step-0? https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/checkpoint.py#L448

I agree we should flatten the `{MODEL:}` part, making it much easier to reason about.
> But I just realized he might hit model_only=True because he named the existing checkpoint to be step-0?

Yes, that's what I've been doing. Since I want to train on a different dataset with a clean optimizer state, it makes more sense for the run to start at step 0.
Plus, using the new `initial_load_path` is much more ergonomic than having to copy or symlink the old checkpoint every time the dump folder changes :)
Flattening the model checkpoint sounds good to me too.
Hi all, sorry for the delay in response. Is there a clear consensus on what to implement?
Also,

> IIUC if you just put it in --checkpoint.folder it's expecting a full checkpoint, and you only need to set exclude_from_loading which doesn't sound too bad?

I think it would work in many setups, but if you have a shared file system this becomes tricky, right? You'd need to make a copy of the initial pre-trained checkpoint every time you launch training.
I think it may be that the modeling code as-is won't be able to finetune sequences > 8192, because it doesn't yet implement the Llama 3 RoPE scaling; see https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json#L21 and the implementation in https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L385. For sequences > 8192 the current code assumes standard RoPE, but AFAIK Llama 3.1+ wasn't trained like this; rather, RoPE was post-scaled to 128K context.
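For reference, here is a sketch of the Llama 3.1 frequency scaling described in the linked HuggingFace code, as I understand it; the function name is mine, and the default parameter values are the ones published in the Llama-3.1-8B `rope_scaling` config, so treat the details as assumptions rather than existing torchtitan code:

```python
import math
import torch

def llama3_scale_inv_freq(
    inv_freq: torch.Tensor,
    factor: float = 8.0,
    low_freq_factor: float = 1.0,
    high_freq_factor: float = 4.0,
    old_context_len: int = 8192,
) -> torch.Tensor:
    """Rescale standard RoPE inverse frequencies for long-context Llama 3.1."""
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    wavelen = 2 * math.pi / inv_freq

    # Long wavelengths (low frequencies) are divided by the scaling factor;
    # short wavelengths (high frequencies) are left untouched.
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)

    # Wavelengths between the two thresholds are smoothly interpolated.
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed = (1 - smooth) * inv_freq / factor + smooth * inv_freq
    is_medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
    return torch.where(is_medium, smoothed, scaled)

# Standard RoPE inverse frequencies for head_dim=128, base=500000 (Llama 3 values),
# then scaled before finetuning beyond the original 8192-token context:
head_dim, base = 128, 500_000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
inv_freq_scaled = llama3_scale_inv_freq(inv_freq)
```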
The model state_dict is unique compared to other state dictionaries (e.g., optimizer). It's the only one that will be exported outside of TorchTitan and imported from other sources. To ensure FQN consistency, we previously removed the prefix during the first checkpoint load and the last checkpoint save. However, this approach has caused confusion among users, despite the available options to control the behavior.

This PR aims to resolve the issue by always flattening the model state dictionary, eliminating the `"MODEL."` prefix from its keys. We decided not to flatten all components due to the risk of key collisions between different components. Instead, this PR only flattens the model state_dict, which is a special case. While this solution isn't perfect, as it introduces different handling for different components, it's a good compromise given the unique nature of the model state_dict.

Also see the discussion in #1321 (comment).

This is the pseudo code for the current state:

```
if model_only:
    state_dict = model.state_dict()
else:
    state_dict = {
        "MODEL": model,
        "OPTIMIZER": optimizer,
        ...
    }
```

This is the pseudo code after this PR lands:

```
state_dict = model.state_dict()
if not model_only:
    state_dict.update(
        {"OPTIMIZER": optimizer, ...}
    )
```

FSDP4 vs. FSDP4 TP2 loss curve with seed checkpoint and `--training.seed=42`:

(screenshot of the loss curves omitted)
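As a side note on the key-collision concern mentioned above, a tiny, hypothetical illustration (the `step` key is invented for the example) of why flattening every component into a single namespace would be riskier than flattening only the model state_dict:

```python
# Hypothetical keys: if every component were flattened into one namespace,
# two components that happen to use the same key would silently collide.
model_sd = {"step": 0.0}         # imagine a model buffer that is named "step"
train_state_sd = {"step": 1234}  # the train state also tracks a "step" counter

flat = {**model_sd, **train_state_sd}
assert len(flat) == 1  # one of the two "step" entries silently wins
```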
Continuation of #1300. Cleaner implementation.