From e7f251a4d90a4219cdd3eb3db1f54c925e9c0194 Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Fri, 20 Jun 2025 12:29:59 -0700
Subject: [PATCH] Support finetuning from a pretrained model

---
 docs/finetune.md                     | 42 +++++++++++++++++++++++++++++
 scripts/convert_llama_to_dcp.py      |  8 ------
 torchtitan/components/checkpoint.py  | 17 +++++-------
 3 files changed, 49 insertions(+), 18 deletions(-)
 create mode 100644 docs/finetune.md

diff --git a/docs/finetune.md b/docs/finetune.md
new file mode 100644
index 000000000..bcf760971
--- /dev/null
+++ b/docs/finetune.md
@@ -0,0 +1,42 @@
+## Fine-tuning from an existing checkpoint
+
+You first need to download the Llama checkpoint. Here are the commands:
+
+```bash
+# Configure these paths as needed
+export HF_TOKEN=... # get your HF token from https://huggingface.co/settings/tokens
+export ORIGINAL_MODEL_DIR="tmp"
+export TOKENIZER_DIR="assets/tokenizer"
+export DCP_MODEL_DIR="assets/models/dcp/llama3.1-8B"
+
+# Download the tokenizer and model weights
+rm -rf $ORIGINAL_MODEL_DIR
+huggingface-cli download meta-llama/Llama-3.1-8B original/tokenizer.model --local-dir $ORIGINAL_MODEL_DIR
+huggingface-cli download meta-llama/Llama-3.1-8B original/consolidated.00.pth --local-dir $ORIGINAL_MODEL_DIR
+huggingface-cli download meta-llama/Llama-3.1-8B original/params.json --local-dir $ORIGINAL_MODEL_DIR
+# Convert the model weights to the DCP format and move them and the tokenizer to the target directories
+mkdir -p $TOKENIZER_DIR && cp $ORIGINAL_MODEL_DIR/original/tokenizer.model $TOKENIZER_DIR/Meta-Llama-3.1-8B-tokenizer.model
+python -m scripts.convert_llama_to_dcp $ORIGINAL_MODEL_DIR/original/ $DCP_MODEL_DIR
+```
+
+Then you can fine-tune from the checkpoint:
+
+```bash
+export TOKENIZER_DIR="assets/tokenizer"
+export DCP_MODEL_DIR="assets/models/dcp/llama3.1-8B"
+CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" uv run ./run_train.sh \
+  --model.tokenizer_path $TOKENIZER_DIR/Meta-Llama-3.1-8B-tokenizer.model \
+  --checkpoint.initial_load_path $DCP_MODEL_DIR \
+  --checkpoint.enable_checkpoint
+```
+
+You should see something like this:
+
+```bash
+...
+l batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 1000 (warmup 200).
+[rank0]:[titan] 2025-06-20 19:13:25,465 - root - INFO - Loading the checkpoint from assets/models/dcp/llama3.1-8B.
+[rank0]:[titan] 2025-06-20 19:13:39,662 - root - INFO - [GC] GC collection for checkpoint loading. 0.01 seconds.
+[rank0]:[titan] 2025-06-20 19:13:39,663 - root - INFO - Finished loading the checkpoint in 14.20 seconds.
+[rank0]:[titan] 2025-06-20 19:13:39,663 - root - INFO - Training starts at step 1.
+```
\ No newline at end of file
diff --git a/scripts/convert_llama_to_dcp.py b/scripts/convert_llama_to_dcp.py
index fa415efad..cac1a908e 100644
--- a/scripts/convert_llama_to_dcp.py
+++ b/scripts/convert_llama_to_dcp.py
@@ -10,7 +10,6 @@
 import torch
 import torch.distributed.checkpoint as DCP
 
-from torchtitan.models.llama.model import precompute_freqs_cis
 from torchtitan.tools.logging import init_logger, logger
 
 
@@ -123,13 +122,6 @@ def convert_llama_weights(input_dir, output_dir, max_seq_len: int):
     for i in range(len(shards)):
         del shards[i]["output.weight"]
 
-    # NOTE: precompute freqs_cis because must be persisted by default in torchtitan
-    state_dict["freqs_cis"] = precompute_freqs_cis(
-        dims_per_head,
-        max_seq_len,
-        params.get("rope_theta", 500000),
-    )
-
     logger.info(f"Writing to DCP at '{output_dir}'")
     output_dir.mkdir(parents=True, exist_ok=True)
     storage_writer = DCP.filesystem.FileSystemWriter(output_dir, thread_count=8)
diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index b396418a9..443144066 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -41,6 +41,10 @@
 LR_SCHEDULER = "lr_scheduler"
 DATALOADER = "dataloader"
 TRAIN_STATE = "train_state"
+# For now, we will manually pop the freqs_cis buffer, as we made this permanent
+# temporarily and we don't want to include it in the exported state_dict.
+# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
+excluded_parameters_for_model_only = {"freqs_cis"}
 
 
 class AsyncMode(str, enum.Enum):
@@ -54,6 +58,7 @@ def __init__(self, model: nn.Module | list[nn.Module]) -> None:
         self.model = [model] if isinstance(model, nn.Module) else model
         self.cache_state_dict = {
             k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
+            if k not in excluded_parameters_for_model_only
         }
 
     def state_dict(self) -> dict[str, Any]:
@@ -70,6 +75,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         # we will need to reinitialize the cache_state_dict.
         self.cache_state_dict = {
             k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
+            if k not in excluded_parameters_for_model_only
         }
 
 
@@ -81,12 +87,6 @@ class SaveDone:
     pass
 
 
-# For now, we will manually pop the freqs_cis buffer, as we made this permanent
-# temporarily and we don't want to include it in the exported state_dict.
-# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
-excluded_parameters_for_model_only = {"freqs_cis"}
-
-
 @torch.no_grad()
 def save_with_gc(state, checkpoint_id):
     dcp.save(state, checkpoint_id=checkpoint_id)
@@ -568,10 +568,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
         """
         # For the first step, we will only load the model weights.
        if model_only:
-            sd = self.states[MODEL].state_dict()
-            for k in excluded_parameters_for_model_only:
-                sd.pop(k, None)
-            return sd
+            return {MODEL: self.states[MODEL]}
 
         for exclude_key in self.exclude_from_loading:
             if exclude_key not in self.states:
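If you want to sanity-check the conversion before kicking off training, one quick option (a sketch on my part, not something the patch provides) is to list the tensor names stored in the DCP directory. It assumes the `$DCP_MODEL_DIR` path from `docs/finetune.md` and uses PyTorch's `FileSystemReader` to read the checkpoint metadata:

```python
# Minimal sketch (assumption, not part of this patch): inspect the DCP checkpoint
# produced by scripts.convert_llama_to_dcp before starting a fine-tuning run.
from torch.distributed.checkpoint import FileSystemReader

# Path assumed from docs/finetune.md ($DCP_MODEL_DIR).
reader = FileSystemReader("assets/models/dcp/llama3.1-8B")
metadata = reader.read_metadata()

# Print a few of the stored tensor names to confirm the conversion looks sane.
for name in sorted(metadata.state_dict_metadata)[:10]:
    print(name)
```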
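The `checkpoint.py` change moves the `freqs_cis` exclusion into `ModelWrapper` itself: the cached state dict is built without any key listed in `excluded_parameters_for_model_only`, so `_states_to_load` can simply return `{MODEL: self.states[MODEL]}`. Below is a minimal, self-contained sketch of that filtering pattern; `TinyModel` is a hypothetical stand-in, while `get_model_state_dict` is the same helper torchtitan imports from `torch.distributed.checkpoint.state_dict`.

```python
# Sketch of the filtering pattern introduced in ModelWrapper; TinyModel is a
# hypothetical stand-in for a torchtitan model with a persistent buffer.
import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import get_model_state_dict

# Same exclusion set as in torchtitan/components/checkpoint.py.
excluded_parameters_for_model_only = {"freqs_cis"}


class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # Persistent buffer standing in for the precomputed rotary table.
        self.register_buffer("freqs_cis", torch.ones(4))


model = TinyModel()
# Mirror of the comprehension in ModelWrapper.__init__ / load_state_dict:
# build the state dict while skipping keys that should not be exported.
state_dict = {
    k: v
    for k, v in get_model_state_dict(model).items()
    if k not in excluded_parameters_for_model_only
}
print(sorted(state_dict))  # ['linear.bias', 'linear.weight'], no 'freqs_cis'
```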