Support finetuning from a pretrained model #1321

Open · wants to merge 1 commit into main
42 changes: 42 additions & 0 deletions docs/finetune.md
@@ -0,0 +1,42 @@
## Fine-tuning from an existing checkpoint

You first need to download the Llama 3.1 tokenizer and model checkpoint. Here are the commands:

```bash
# Configure these paths as needed
export HF_TOKEN=... # get your HF token from https://huggingface.co/settings/tokens
export ORIGINAL_MODEL_DIR="tmp"
export TOKENIZER_DIR="assets/tokenizer"
export DCP_MODEL_DIR="assets/models/dcp/llama3.1-8B"

# Download the tokenizer and model weights
rm -rf $ORIGINAL_MODEL_DIR
huggingface-cli download meta-llama/Llama-3.1-8B original/tokenizer.model --local-dir $ORIGINAL_MODEL_DIR
huggingface-cli download meta-llama/Llama-3.1-8B original/consolidated.00.pth --local-dir $ORIGINAL_MODEL_DIR
huggingface-cli download meta-llama/Llama-3.1-8B original/params.json --local-dir $ORIGINAL_MODEL_DIR
# Copy the tokenizer and convert the model weights to DCP format, placing them in the target directories
mkdir -p $TOKENIZER_DIR && cp $ORIGINAL_MODEL_DIR/original/tokenizer.model $TOKENIZER_DIR/Meta-Llama-3.1-8B-tokenizer.model
python -m scripts.convert_llama_to_dcp $ORIGINAL_MODEL_DIR/original/ $DCP_MODEL_DIR
```
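
Optionally, you can sanity-check the conversion before training. The snippet below is a minimal sketch (not part of this PR) that uses `FileSystemReader` from `torch.distributed.checkpoint` to list the keys stored in the converted checkpoint; it assumes the same `$DCP_MODEL_DIR` path as above.

```python
# Optional sanity check: list the keys stored in the converted DCP checkpoint.
import torch.distributed.checkpoint as dcp

reader = dcp.FileSystemReader("assets/models/dcp/llama3.1-8B")  # $DCP_MODEL_DIR
metadata = reader.read_metadata()
for key in sorted(metadata.state_dict_metadata):
    print(key)  # should correspond to torchtitan's Llama state dict key names
```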

Then you can fine-tune from the checkpoint:

```bash
export TOKENIZER_DIR="assets/tokenizer"
export DCP_MODEL_DIR="assets/models/dcp/llama3.1-8B"
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" uv run ./run_train.sh \
--model.tokenizer_path $TOKENIZER_DIR/Meta-Llama-3.1-8B-tokenizer.model \
--checkpoint.initial_load_path $DCP_MODEL_DIR \
--checkpoint.enable_checkpoint
```

You should see something like this:

```bash
...
l batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 1000 (warmup 200).
[rank0]:[titan] 2025-06-20 19:13:25,465 - root - INFO - Loading the checkpoint from assets/models/dcp/llama3.1-8B.
[rank0]:[titan] 2025-06-20 19:13:39,662 - root - INFO - [GC] GC collection for checkpoint loading. 0.01 seconds.
[rank0]:[titan] 2025-06-20 19:13:39,663 - root - INFO - Finished loading the checkpoint in 14.20 seconds.
[rank0]:[titan] 2025-06-20 19:13:39,663 - root - INFO - Training starts at step 1.
```
8 changes: 0 additions & 8 deletions scripts/convert_llama_to_dcp.py
@@ -10,7 +10,6 @@

import torch
import torch.distributed.checkpoint as DCP
from torchtitan.models.llama.model import precompute_freqs_cis
from torchtitan.tools.logging import init_logger, logger


@@ -123,13 +122,6 @@ def convert_llama_weights(input_dir, output_dir, max_seq_len: int):
for i in range(len(shards)):
del shards[i]["output.weight"]

# NOTE: precompute freqs_cis because must be persisted by default in torchtitan
state_dict["freqs_cis"] = precompute_freqs_cis(
dims_per_head,
max_seq_len,
params.get("rope_theta", 500000),
)

logger.info(f"Writing to DCP at '{output_dir}'")
output_dir.mkdir(parents=True, exist_ok=True)
storage_writer = DCP.filesystem.FileSystemWriter(output_dir, thread_count=8)
17 changes: 7 additions & 10 deletions torchtitan/components/checkpoint.py
@@ -41,6 +41,10 @@
LR_SCHEDULER = "lr_scheduler"
DATALOADER = "dataloader"
TRAIN_STATE = "train_state"
# For now, we will manually pop the freqs_cis buffer, as we made this permanent
# temporarily and we don't want to include it in the exported state_dict.
# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
excluded_parameters_for_model_only = {"freqs_cis"}


class AsyncMode(str, enum.Enum):
@@ -54,6 +58,7 @@ def __init__(self, model: nn.Module | list[nn.Module]) -> None:
self.model = [model] if isinstance(model, nn.Module) else model
self.cache_state_dict = {
k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
if k not in excluded_parameters_for_model_only
}

def state_dict(self) -> dict[str, Any]:
@@ -70,6 +75,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
# we will need to reinitialize the cache_state_dict.
self.cache_state_dict = {
k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
if k not in excluded_parameters_for_model_only
Author:

Changed here because I was getting `RuntimeError: Missing key in checkpoint state_dict: model.freqs_cis.` when trying to load a full checkpoint.

Contributor:

@fegin can you take a look at this part?
I can verify that I couldn't load a seed checkpoint before this PR; with this PR I was able to load the seed checkpoint.

Contributor:

Now that the default setting is to save the model only at the last step, freqs_cis won't exist in the checkpoint. I guess this change makes sense, as we just don't need freqs_cis to be loaded.

Contributor:

@vwxyzjn If you use --checkpoint.initial_load_path, you don't need to change anything. A better solution is to change Checkpointer._states_to_load to always pop freqs_cis, regardless of whether model_only is True or not. Since it is a bug anyway, we can change Checkpointer to never load freqs_cis back, even though we are making it a persistent buffer. Right now we only do this when model_only is True.

@tianyu-l what do you think about this?
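
For reference, here is a minimal sketch of that suggestion (an assumption about the direction, not code from this PR): filter freqs_cis out of whatever `_states_to_load` returns, in both branches.

```python
# Sketch only: always drop freqs_cis before the state dict is handed to DCP,
# whether or not model_only is set.
def _states_to_load(self, model_only: bool) -> dict[str, Any]:
    model_sd = self.states[MODEL].state_dict()
    for k in excluded_parameters_for_model_only:
        model_sd.pop(k, None)  # never load freqs_cis back, even from full checkpoints
    if model_only:
        return model_sd
    # For a full checkpoint, the MODEL entry would likewise need freqs_cis stripped
    # before being returned alongside optimizer/lr_scheduler/dataloader/train_state.
    ...
```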

Contributor:

@fegin
I agree we should never load freqs_cis.

I think what prevented me from loading a seed checkpoint is the {"MODEL": state_dict} wrapping.
A seed checkpoint is at step-0; however, it's not saved with model_only, and currently there's no option to do that.
However, according to https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/checkpoint.py#L448, when loading a step-0 checkpoint it expects model-only.

The reason this PR unblocks me is that it adds the wrapping even for a model-only load.
https://github.com/pytorch/torchtitan/pull/1321/files#diff-27a108fa6d4885d9c66306785cb36029c0b4f5a1542e63ae24e84eb7e9a273d1R571
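
To make the prefix issue concrete, here is a hypothetical illustration (the tensor names are made up; only the "model." prefix matters):

```python
# A checkpoint saved with the {MODEL: ...} wrapping stores its keys flattened
# under a "model." prefix:
saved_keys = ["model.tok_embeddings.weight", "model.layers.0.attention.wq.weight"]
# The old model-only load path passed a flat state dict, so DCP looked for:
requested_keys = ["tok_embeddings.weight", "layers.0.attention.wq.weight"]
# With no overlap between the two, DCP reports the requested keys as missing.
```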

Contributor:

@lkhphuc

My current solution without a code change:
- set `initial_load_path=checkpoint-xxx`
- set `initial_load_model_weights_only=false` to get the wrapper dict as before
- set `exclude_from_loading=["dataloader", "optimizer", "lr_schedulers", "train_state"]` so the state dict to load contains only `{MODEL: ...}`

What I didn't understand is why you need to set initial_load_path and initial_load_model_weights_only. IIUC, if you just put it in --checkpoint.folder, it expects a full checkpoint, and you only need to set exclude_from_loading, which doesn't sound too bad?

Contributor:

@tianyu-l It's the key-prefix issue, not what should be loaded.

Contributor:

I understand. I'm saying that even with the prefix issue, @lkhphuc's use case should have a simpler workflow: just don't use initial_load_path. But I just realized he might hit model_only=True because he named the existing checkpoint step-0? https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/checkpoint.py#L448

I agree we should flatten the {MODEL:} part, making it much easier to reason about.

Contributor:

> But I just realized he might hit model_only=True because he named the existing checkpoint step-0?

Yes, that's what I've been doing. Since I want to train on a different dataset with a clean optimizer, it makes more sense for the run to start at step 0.
Plus, using the new initial_load_path is much more ergonomic than having to copy and symlink the old checkpoint every time the dump folder changes :)

Flattening the model checkpoint seems good to me too.

Author:

Hi all, sorry for the delayed response. Is there a clear consensus on what to implement?

Also,

> IIUC, if you just put it in --checkpoint.folder, it expects a full checkpoint, and you only need to set exclude_from_loading, which doesn't sound too bad?

I think it would work in many setups, but if you have a shared file system this becomes tricky, right? You'd need to make a copy of the initial pre-trained checkpoint every time you launch training.

}


@@ -81,12 +87,6 @@ class SaveDone:
pass


# For now, we will manually pop the freqs_cis buffer, as we made this permanent
# temporarily and we don't want to include it in the exported state_dict.
# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
excluded_parameters_for_model_only = {"freqs_cis"}


@torch.no_grad()
def save_with_gc(state, checkpoint_id):
dcp.save(state, checkpoint_id=checkpoint_id)
@@ -568,10 +568,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
"""
# For the first step, we will only load the model weights.
if model_only:
sd = self.states[MODEL].state_dict()
for k in excluded_parameters_for_model_only:
sd.pop(k, None)
return sd
return {MODEL: self.states[MODEL]}

for exclude_key in self.exclude_from_loading:
if exclude_key not in self.states: