Support finetuning from a pretrained model #1321
Open
vwxyzjn wants to merge 1 commit into pytorch:main from vwxyzjn:new-finetune
## Fine-tuning from an existing checkpoint

You first need to download the Llama checkpoint. Here are the commands:

```bash
# Configure these paths as needed
export HF_TOKEN=...  # get your HF token from https://huggingface.co/settings/tokens
export ORIGINAL_MODEL_DIR="tmp"
export TOKENIZER_DIR="assets/tokenizer"
export DCP_MODEL_DIR="assets/models/dcp/llama3.1-8B"

# Download the tokenizer and model weights
rm -rf $ORIGINAL_MODEL_DIR
huggingface-cli download meta-llama/Llama-3.1-8B original/tokenizer.model --local-dir $ORIGINAL_MODEL_DIR
huggingface-cli download meta-llama/Llama-3.1-8B original/consolidated.00.pth --local-dir $ORIGINAL_MODEL_DIR
huggingface-cli download meta-llama/Llama-3.1-8B original/params.json --local-dir $ORIGINAL_MODEL_DIR

# Convert the model weights to the DCP format, then move them and the tokenizer to the target directories
mkdir -p $TOKENIZER_DIR && cp $ORIGINAL_MODEL_DIR/original/tokenizer.model $TOKENIZER_DIR/Meta-Llama-3.1-8B-tokenizer.model
python -m scripts.convert_llama_to_dcp $ORIGINAL_MODEL_DIR/original/ $DCP_MODEL_DIR
```

Then you can fine-tune from the checkpoint:

```bash
export TOKENIZER_DIR="assets/tokenizer"
export DCP_MODEL_DIR="assets/models/dcp/llama3.1-8B"
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" uv run ./run_train.sh \
    --model.tokenizer_path $TOKENIZER_DIR/Meta-Llama-3.1-8B-tokenizer.model \
    --checkpoint.initial_load_path $DCP_MODEL_DIR \
    --checkpoint.enable_checkpoint
```

You should see something like this:

```bash
...
l batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 1000 (warmup 200).
[rank0]:[titan] 2025-06-20 19:13:25,465 - root - INFO - Loading the checkpoint from assets/models/dcp/llama3.1-8B.
[rank0]:[titan] 2025-06-20 19:13:39,662 - root - INFO - [GC] GC collection for checkpoint loading. 0.01 seconds.
[rank0]:[titan] 2025-06-20 19:13:39,663 - root - INFO - Finished loading the checkpoint in 14.20 seconds.
[rank0]:[titan] 2025-06-20 19:13:39,663 - root - INFO - Training starts at step 1.
```
Changed here because I was getting `RuntimeError: Missing key in checkpoint state_dict: model.freqs_cis.` when trying to load a full checkpoint.
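For context, this failure mode is easy to reproduce in plain PyTorch (an illustrative sketch, not torchtitan's DCP-based `Checkpointer`; the module and buffer below are made up to mirror the situation):

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        # Analogous to the RoPE table: derived from config, not learned.
        self.register_buffer("freqs_cis", torch.randn(4), persistent=True)

saved = TinyModel().state_dict()
saved.pop("freqs_cis")  # e.g. a checkpoint that was saved without the buffer

model = TinyModel()
try:
    model.load_state_dict(saved)  # strict=True by default
except RuntimeError as err:
    print(err)  # Missing key(s) in state_dict: "freqs_cis".
```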
@fegin can you take a look at this part? I can verify that I couldn't load a seed checkpoint before this PR; with this PR, I was able to load the seed checkpoint.
Now that the default setting is to save `model` only in the last step, `freqs_cis` won't exist. I guess this change makes sense, as we just don't need `freqs_cis` to be loaded.
@vwxyzjn If you use `--checkpoint.initial_load_path`, you don't need to change anything. A better solution is to change `Checkpointer._states_to_load` to always pop out `freqs_cis`, regardless of whether `model_only` is True or not. Since it is a bug anyway, we can change `Checkpointer` to never load `freqs_cis` back, even though we are making it a persistent buffer. Right now we only do this when `model_only` is True.

@tianyu-l what do you think about this?
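A rough sketch of that suggestion (the helper below is hypothetical; torchtitan's actual `Checkpointer._states_to_load` internals may differ):

```python
from typing import Any

# Hypothetical helper, not torchtitan's actual implementation: filter the
# RoPE buffer out of whatever we ask DCP to load, for both full and
# model-only checkpoints.
def states_to_load(state_dict: dict[str, Any]) -> dict[str, Any]:
    # A stale `freqs_cis` is never useful (it is recomputed from the model
    # config at init), and model-only checkpoints may not contain it anyway.
    return {k: v for k, v in state_dict.items() if "freqs_cis" not in k}
```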
@fegin I agree we should never load `freqs_cis`.

I think what prevented me from loading a seed checkpoint is the `{"MODEL": state_dict}` wrapping. A seed checkpoint is at step-0; however, it's not saved with `model_only`, and currently there's no option to do that. Yet according to https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/checkpoint.py#L448, loading a step-0 checkpoint expects model-only.

The reason this PR unblocks me is that it adds the wrapping even for the model-only load: https://github.com/pytorch/torchtitan/pull/1321/files#diff-27a108fa6d4885d9c66306785cb36029c0b4f5a1542e63ae24e84eb7e9a273d1R571
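To make the mismatch concrete, here is an illustrative sketch; the dict layouts are assumed from this description rather than copied from the code, and the unwrap helper is just one possible way to realize the flattening discussed below:

```python
# Layouts assumed from the discussion, not copied from torchtitan:
full_ckpt = {"MODEL": {"tok_embeddings.weight": ...}}  # plus OPTIMIZER, etc.
model_only_ckpt = {"tok_embeddings.weight": ...}       # flat state dict

# A loader expecting the flat layout will not find "tok_embeddings.weight"
# in `full_ckpt`, and vice versa. One way to tolerate both layouts:
def unwrap_model_state(ckpt: dict) -> dict:
    """Unwrap the {"MODEL": ...} nesting if present, else return as-is."""
    return ckpt["MODEL"] if "MODEL" in ckpt else ckpt
```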
@lkhphuc What I didn't understand is why you need to set `initial_load_path` and `initial_load_model_weights_only`. IIUC, if you just put it in `--checkpoint.folder`, it expects a full checkpoint, and you only need to set `exclude_from_loading`, which doesn't sound too bad?
@tianyu-l It's the prefix issue of the keys, not what should be loaded.
I understand. I'm saying that even with the prefix issue, @lkhphuc's use case should have a simpler workflow: just don't use `init_load_path`. But I just realized he might hit `model_only=True` because he named the existing checkpoint step-0? https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/checkpoint.py#L448

I agree we should flatten the `{MODEL:}` part, making it much easier to reason about.
Yes, that's what I've been doing. Since I want to train on a different dataset with a clean optimizer, it makes more sense for the run to start at step 0. Plus, using the new `initial_load_path` is much more ergonomic than having to copy and symlink the old checkpoint every time the dump folder changes :)

Flattening the model checkpoint seems good to me too.
Hi all, sorry for the delayed response. Is there a clear consensus on what to implement?

Also, I think it would work in many setups, but if you have a shared file system this becomes tricky, right? You'd need to make a copy of the initial pre-trained checkpoint every time you launch training.