diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md index 2a29f644e6..b87bb1f768 100644 --- a/examples/speculative_decoding/README.md +++ b/examples/speculative_decoding/README.md @@ -73,14 +73,16 @@ This one-line command runs a minimal example workflow of training and exporting For small base models that fit in GPU memory, we can collocate them with draft models and train with the following command: ```bash -./launch_train.sh --model $BASE_MODEL \ - --output_dir $OUTPUT_DIR \ - --data input_conversations/train.jsonl \ - --num_epochs $NUM_EPOCH \ - --eagle_config eagle_config.json +./launch_train.sh \ + --config ../../modelopt_recipes/general/speculative_decoding/eagle3.yaml \ + model.model_name_or_path=meta-llama/Llama-3.2-1B \ + data.data_path=input_conversations/train.jsonl \ + training.output_dir=ckpts/llama-3.2-1b-online ``` -FSDP2 is used by default. To enable context parallelism for long-context training, specify `--cp_size n`. +All default training settings live in `eagle3.yaml`; override any field via OmegaConf dotlist arguments on the command line. + +To enable context parallelism for long-context training, add `training.cp_size=` to the overrides. The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT. 
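The dotlist-override behavior described above can be sketched in plain Python. This is a simplified, hypothetical stand-in (`apply_dotlist` is not part of the repo) for the `OmegaConf.from_dotlist` plus `OmegaConf.merge` path that `main.py` uses; real OmegaConf additionally parses value types, while this sketch keeps values as strings:

```python
def apply_dotlist(cfg: dict, overrides: list[str]) -> dict:
    """Apply 'section.key=value' overrides to a nested dict.

    Simplified stand-in for OmegaConf.from_dotlist + OmegaConf.merge as used
    by main.py; unlike OmegaConf, values here stay plain strings.
    """
    for item in overrides:
        dotted_key, _, value = item.partition("=")
        node = cfg
        *parents, leaf = dotted_key.split(".")
        for part in parents:
            # Descend, creating intermediate sections as needed
            node = node.setdefault(part, {})
        node[leaf] = value
    return cfg


# YAML defaults, then command-line dotlist entries layered on top
base = {"training": {"output_dir": "ckpts/default", "cp_size": 1}}
merged = apply_dotlist(base, ["training.cp_size=2", "training.output_dir=ckpts/run1"])
```

Fields not named in an override keep their YAML defaults; only the listed leaves are replaced.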
## Training Draft Model with Offline Base Model @@ -113,15 +115,14 @@ python collect_hidden_states/compute_hidden_states_hf.py \ ### Train Draft Model with Dumped Hidden States -Once we finish dumping hidden states, launch offline training with an extra `--offline-data` argument: +Once we finish dumping hidden states, launch offline training pointing to the hidden states directory: ```bash -./launch_train.sh --model $BASE_MODEL \ - --output_dir $OUTPUT_DIR \ - --data $DATA \ - --num_epochs $NUM_EPOCH \ - --eagle_config eagle_config.json \ - --offline-data $HIDDEN_STATES_DIR +./launch_train.sh \ + --config ../../modelopt_recipes/general/speculative_decoding/eagle3.yaml \ + model.model_name_or_path=meta-llama/Llama-3.2-1B \ + data.offline_data_path=$HIDDEN_STATES_DIR \ + training.output_dir=ckpts/llama-3.2-1b-offline ``` ## Model Validation @@ -244,13 +245,13 @@ For large scale data generation, please see [SLURM prepare data](SLURM_prepare_d ### Configuring Draft Model -For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. You can override default settings by providing an additional JSON dict. E.g. To use 2-layer eagle with 8192 intermediate size for MLP, set `eagle_config.json` to: +For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. You can override default settings via `eagle.eagle_architecture_config` in the YAML. E.g. 
to use a 2-layer EAGLE head with 8192 intermediate size: -```json -{ - "num_hidden_layers": 2, - "intermediate_size":8192 -} +```yaml +eagle: + eagle_architecture_config: + num_hidden_layers: 2 + intermediate_size: 8192 ``` ### Draft Vocabulary Compression @@ -263,61 +264,26 @@ python scripts/calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`. -Then, simply set `{"draft_vocab_size":32000}` in `eagle_config.json` and include `--draft_vocab_cache ` when running `./launch_train.sh`. The draft model will use this provided vocab table during training and export. +Then, set `eagle.eagle_architecture_config.draft_vocab_size: 32000` and `data.draft_vocab_cache: ` in your YAML. The draft model will use this provided vocab table during training and export. ### Interact with `modelopt.torch.speculative` -`main.py` provides an example for converting a HF base model for speculative decoding and training it. It consists of a few simple steps: -First, load the base model and tokenizer from Hugging Face: - -```python -model = transformers.AutoModelForCausalLM.from_pretrained( - "" -) -``` - -Then, load default eagle config and make necessary overwrites: +`main.py` provides a complete example for converting a HF base model for speculative decoding and training it.
The core steps are loading the base model, converting it with an eagle config dict, and training with HF Trainer: ```python -# Load default config -config = { - "eagle1": EAGLE1_DEFAULT_CFG, - "eagle3": EAGLE3_DEFAULT_CFG, -}[training_args.mode]["config"] - -# overwrite config with custom config -config["eagle_architecture_config"].update({"": ""}) - -# Mandatory: hidden size, vocab size and max position embeddings must match base model -config["eagle_architecture_config"].update( - { - "hidden_size": model.config.hidden_size, - "vocab_size": model.config.vocab_size, - "max_position_embeddings": model.config.max_position_embeddings, - } -) -``` +import modelopt.torch.speculative as mtsp -Then, we convert model to a speculative decoding model: +# Convert base model in-place to an EAGLE speculative decoding model +eagle_cfg = {"eagle_decoder_type": "llama", ...} # fields from EagleConfig +mtsp.convert(model, [("eagle", eagle_cfg)]) -```python -mtsp.convert(model, [("eagle", config)]) +# Train with HF Trainer as usual +trainer = transformers.Trainer(model=model, ...) +trainer.train() +trainer.save_model("") ``` -This will modify the model in-place with eagle training forward, making it compatible with HF trainer: - -```python -# Create a trainer -trainer = transformers.Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) -trainer._move_model_to_device(model, trainer.args.device) - -# Enable HF checkpointing so that the saved model will contain the speculative decoding module -mto.enable_huggingface_checkpointing() - -trainer.train(resume_from_checkpoint=checkpoint) -trainer.save_state() -trainer.save_model("") -``` +See `main.py` for the full example including tokenizer setup, dataset loading, and checkpoint handling. 
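The draft-vocabulary mapping described under Draft Vocabulary Compression can be sketched in plain Python. The token ids below are made up for illustration; the assumed layout, matching the formula `target_token = draft_token + d2t[draft_token]` above, is that entry `i` of `d2t` stores the offset recovering the target-vocab id of draft token `i`:

```python
# Toy compressed draft vocab of 3 tokens: calibration kept these target ids
# (values are invented for the example).
kept_target_ids = [7, 42, 99]

# d2t[i] is the offset that turns draft token i back into its target id,
# i.e. target_token = draft_token + d2t[draft_token].
d2t = [t - i for i, t in enumerate(kept_target_ids)]

draft_token = 1
target_token = draft_token + d2t[draft_token]
assert target_token == kept_target_ids[draft_token]
```

In the real flow the same offsets are stored as a tensor in `d2t.pt` and loaded onto `model.eagle_module.d2t`; the arithmetic is identical.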
## Support Matrix diff --git a/examples/speculative_decoding/eagle_config.json b/examples/speculative_decoding/eagle_config.json deleted file mode 100644 index 2c63c08510..0000000000 --- a/examples/speculative_decoding/eagle_config.json +++ /dev/null @@ -1,2 +0,0 @@ -{ -} diff --git a/examples/speculative_decoding/fsdp_config.json b/examples/speculative_decoding/fsdp_config.json deleted file mode 100644 index 6d934182fe..0000000000 --- a/examples/speculative_decoding/fsdp_config.json +++ /dev/null @@ -1 +0,0 @@ -{"fsdp_version":2} \ No newline at end of file diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index c15b97bdaa..41d71d1417 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -14,277 +14,62 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Usage: +# Single GPU: ./launch_train.sh --config ../../modelopt_recipes/general/speculative_decoding/eagle3.yaml model.model_name_or_path=xxx +# Multi-node: ./launch_train.sh --config ../../modelopt_recipes/general/speculative_decoding/eagle3.yaml --num_nodes 2 --head_node_ip +# With overrides: ./launch_train.sh --config my.yaml model.model_name_or_path=xxx training.output_dir=yyy +# +# Extra key=value args are forwarded as OmegaConf dotlist overrides to main.py. +# All training config (model, data, hyperparams, eagle, fsdp) lives in the YAML file. +# Only multi-node routing args are passed here; mixed_precision is fixed to bf16. 
+ set -eo pipefail +SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" + +CONFIG_FILE="" +NUM_NODES=1 +HEAD_NODE_IP="" +EXTRA_ARGS=() while [ $# -gt 0 ]; do case "$1" in - --training_seq_len*) - if [[ "$1" != *=* ]]; then shift; fi - TRAINING_SEQ_LEN="${1#*=}" - ;; - --model*) - if [[ "$1" != *=* ]]; then shift; fi - MODEL="${1#*=}" - ;; - --data*) - if [[ "$1" != *=* ]]; then shift; fi - DATA="${1#*=}" - ;; - --offline-data*) - if [[ "$1" != *=* ]]; then shift; fi - OFFLINE_DATA_PATH="${1#*=}" - ;; - --mode*) - if [[ "$1" != *=* ]]; then shift; fi - MODE="${1#*=}" - ;; - --eagle_decoder_type*) - if [[ "$1" != *=* ]]; then shift; fi - EAGLE_DECODER_TYPE="${1#*=}" - ;; - --output_dir*) - if [[ "$1" != *=* ]]; then shift; fi - OUTPUT_DIR="${1#*=}" - ;; - --num_epochs*) - if [[ "$1" != *=* ]]; then shift; fi - NUM_EPOCHS="${1#*=}" - ;; - --save_steps*) - if [[ "$1" != *=* ]]; then shift; fi - SAVE_STEPS="${1#*=}" - ;; - --lr*) - if [[ "$1" != *=* ]]; then shift; fi - LR="${1#*=}" - ;; - --train_bs*) - if [[ "$1" != *=* ]]; then shift; fi - TRAIN_BS="${1#*=}" - ;; - --eagle_config*) - if [[ "$1" != *=* ]]; then shift; fi - EAGLE_CONFIG="${1#*=}" - ;; - --disable_tqdm*) - if [[ "$1" != *=* ]]; then shift; fi - DISABLE_TQDM="${1#*=}" - ;; - --vlm_processor*) - if [[ "$1" != *=* ]]; then shift; fi - VLM_PROCESSOR="${1#*=}" - ;; - --vlm_img_dir*) - if [[ "$1" != *=* ]]; then shift; fi - VLM_IMG_DIR="${1#*=}" - ;; - --estimate_ar*) - if [[ "$1" != *=* ]]; then shift; fi - ESTIMATE_AR="${1#*=}" - ;; - --ar_validate_steps*) - if [[ "$1" != *=* ]]; then shift; fi - AR_VALIDATE_STEPS="${1#*=}" - ;; - --num_ttt_steps*) - if [[ "$1" != *=* ]]; then shift; fi - NUM_TTT_STEPS="${1#*=}" - ;; - --cp_size*) - if [[ "$1" != *=* ]]; then shift; fi - CP_SIZE="${1#*=}" - ;; - --dp_size*) - if [[ "$1" != *=* ]]; then shift; fi - DP_SHARD_SIZE="${1#*=}" - ;; - --log_steps*) - if [[ "$1" != *=* ]]; then shift; fi - LOG_STEPS="${1#*=}" - ;; - --draft_vocab_cache*) - if [[ "$1" != *=* ]]; then 
shift; fi - DRAFT_VOCAB_CACHE="${1#*=}" - ;; - --num_nodes*) - if [[ "$1" != *=* ]]; then shift; fi - NUM_NODES="${1#*=}" - ;; - --head_node_ip*) - if [[ "$1" != *=* ]]; then shift; fi - HEAD_NODE_IP="${1#*=}" - ;; - --mix_hidden_states*) - if [[ "$1" != *=* ]]; then shift; fi - MIX_HIDDEN_STATES="${1#*=}" - ;; - --disable_torch_compile*) - if [[ "$1" != *=* ]]; then shift; fi - DISABLE_TORCH_COMPILE="${1#*=}" - ;; - --use_fake_base_for_offline*) - if [[ "$1" != *=* ]]; then shift; fi - USE_FAKE_BASE_FOR_OFFLINE="${1#*=}" - ;; - --trust_remote_code*) - if [[ "$1" != *=* ]]; then shift; fi - TRUST_REMOTE_CODE="${1#*=}" - ;; - --fsdp*) - if [[ "$1" != *=* ]]; then shift; fi - FSDP="${1#*=}" - ;; - *) - >&2 printf "Error: Invalid argument ${1#*=}\n" - exit 1 - ;; + --config*) if [[ "$1" != *=* ]]; then shift; fi; CONFIG_FILE="${1#*=}" ;; + --num_nodes*) if [[ "$1" != *=* ]]; then shift; fi; NUM_NODES="${1#*=}" ;; + --head_node_ip*) if [[ "$1" != *=* ]]; then shift; fi; HEAD_NODE_IP="${1#*=}" ;; + *) EXTRA_ARGS+=("$1") ;; esac shift done -set -x +if [ -z "$CONFIG_FILE" ]; then + >&2 echo "Usage: ./launch_train.sh --config [--num_nodes N] [--head_node_ip IP] [key=value ...]" + exit 1 +fi -SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" -NUM_NODES=${NUM_NODES:-1} -if [[ "$NUM_NODES" != 1 ]]; then - #Multi Node Training +# GPU count detection +if [[ "$NUM_NODES" != "1" ]]; then GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE)) echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)" else - #Single Node Training, GPU can be specified by $CUDA_VISIBLE_DEVICES - TOTAL_GPU=$(python -c "import torch; print(torch.cuda.device_count())") - echo "Total GPUs: $TOTAL_GPU (Single Node Training)" -fi -# Calculate save_steps -DEFAULT_SAVE_STEPS=$((8192 / TOTAL_GPU)) - -MODEL=${MODEL:-"TinyLlama/TinyLlama-1.1B-Chat-v1.0"} -MODE=${MODE:-"eagle3"} 
-EAGLE_DECODER_TYPE=${EAGLE_DECODER_TYPE:-"llama"} -# Set default OUTPUT_DIR to ckpts/{modelname}, where {modelname} is the last part of the model path -MODEL_BASENAME=$(basename "$MODEL") -OUTPUT_DIR=${OUTPUT_DIR:-"ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)"} -NUM_EPOCHS=${NUM_EPOCHS:-1} -SAVE_STEPS=${SAVE_STEPS:-$DEFAULT_SAVE_STEPS} -LR=${LR:-"1e-4"} -TRAIN_BS=${TRAIN_BS:-1} -TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048} -DATA=${DATA:-""} -OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""} -DISABLE_TQDM=${DISABLE_TQDM:-False} -VLM_PROCESSOR=${VLM_PROCESSOR:-} -VLM_IMG_DIR=${VLM_IMG_DIR:-} -AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000} -ESTIMATE_AR=${ESTIMATE_AR:-False} -CP_SIZE=${CP_SIZE:-1} -DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))} -LOG_STEPS=${LOG_STEPS:-100} -DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""} -MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"} -DISABLE_TORCH_COMPILE=${DISABLE_TORCH_COMPILE:-"False"} -NUM_TTT_STEPS=${NUM_TTT_STEPS:-3} - -USE_FAKE_BASE_FOR_OFFLINE=${USE_FAKE_BASE_FOR_OFFLINE:-"False"} -TRUST_REMOTE_CODE=${TRUST_REMOTE_CODE:-"False"} -FSDP=${FSDP:-"False"} - -if [[ "$MODE" == "eagle3" ]]; then - if [[ -n "$EAGLE_CONFIG" ]]; then - SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG" - else - SPECULATIVE_ARGS="" - fi -else - echo "Only eagle3 supported for now!" - exit 1 -fi - -if [[ "$OFFLINE_DATA_PATH" != "" ]]; then - if [[ ! -d "$OFFLINE_DATA_PATH" ]]; then - echo "Offline data path $OFFLINE_DATA_PATH does not exist or is not a directory." 
- exit 1 - else - DATA_ARGS="--offline-data-path $OFFLINE_DATA_PATH --ar_validate_steps -1" - fi -else - DATA_ARGS="--data_path $DATA" -fi - - -if [[ "$VLM_PROCESSOR" != "" ]]; then - VLM_ARGS="--vlm_processor $VLM_PROCESSOR --vlm_img_dir $VLM_IMG_DIR" -else - VLM_ARGS="" + TOTAL_GPU=$(python3 -c "import torch; print(torch.cuda.device_count())") + echo "Total GPUs: $TOTAL_GPU (single node)" fi -if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then - #Use FSDP2 when multi GPU available - FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json" -else - #Otherwise, single GPU training - FSDP_ARGS="" -fi - - -if [[ "$DRAFT_VOCAB_CACHE" != "" ]]; then - DRAFT_VOCAB_CACHE_ARGS="--draft_vocab_cache $DRAFT_VOCAB_CACHE" -else - DRAFT_VOCAB_CACHE_ARGS="" -fi - -if [[ "$NUM_NODES" != 1 ]]; then +# Multi-node routing args (accelerate only; training config comes from the YAML) +MULTI_NODE_ARGS="" +if [[ "$NUM_NODES" != "1" ]]; then MULTI_NODE_ARGS="--num_processes $TOTAL_GPU \ --num_machines $NUM_NODES \ --machine_rank $SLURM_PROCID \ --rdzv_backend c10d \ --main_process_ip $HEAD_NODE_IP \ --main_process_port 29500" -else - MULTI_NODE_ARGS="" fi -# Disable tokenizers parallelism to avoid warning export TOKENIZERS_PARALLELISM=False -CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/main.py \ - --mode $MODE \ - --eagle_decoder_type $EAGLE_DECODER_TYPE \ - --model_name_or_path $MODEL \ - --training_seq_len $TRAINING_SEQ_LEN \ - --dataloader_drop_last True \ - --bf16 True \ - --output_dir $OUTPUT_DIR \ - --num_train_epochs $NUM_EPOCHS \ - --per_device_train_batch_size $TRAIN_BS \ - --per_device_eval_batch_size $TRAIN_BS \ - --gradient_accumulation_steps 1 \ - --do_eval False \ - --eval_accumulation_steps 1 \ - --save_strategy steps \ - --save_steps $SAVE_STEPS \ - --learning_rate $LR \ - --weight_decay 0.0 \ - --warmup_steps 100 \ - --lr_scheduler_type linear \ - --logging_steps $LOG_STEPS \ - --tf32 True \ - $DATA_ARGS \ - 
--disable_tqdm $DISABLE_TQDM \ - --estimate_ar $ESTIMATE_AR \ - --ar_validate_steps $AR_VALIDATE_STEPS \ - --mix_hidden_states $MIX_HIDDEN_STATES \ - --disable_torch_compile $DISABLE_TORCH_COMPILE \ - --use_fake_base_for_offline $USE_FAKE_BASE_FOR_OFFLINE \ - --trust_remote_code $TRUST_REMOTE_CODE \ - $DRAFT_VOCAB_CACHE_ARGS \ - $VLM_ARGS \ - $SPECULATIVE_ARGS \ - $FSDP_ARGS \ - --cp_size $CP_SIZE \ - --dp_shard_size $DP_SHARD_SIZE \ - --num_ttt_steps $NUM_TTT_STEPS \ -" +set -x start_time=$(date +%s) -sh -c "$CMD" -echo "Total time taken: $(( $(date +%s) - $start_time )) seconds" +sh -c "accelerate launch --mixed_precision bf16 $MULTI_NODE_ARGS ${SCRIPT_DIR}/main.py --config $CONFIG_FILE ${EXTRA_ARGS[*]}" +echo "Total time: $(( $(date +%s) - $start_time )) seconds" diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 3369d399c2..694aa3303f 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -29,7 +29,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json +import argparse import os from dataclasses import dataclass, field from typing import Literal @@ -43,6 +43,7 @@ make_eagle_supervised_data_module, patch_ring_attention_for_ttt, ) +from omegaconf import OmegaConf from transformers.trainer_utils import get_last_checkpoint import modelopt.torch.opt as mto @@ -56,12 +57,18 @@ @dataclass class ModelArguments: - model_name_or_path: str | None = field(default="TinyLlama/TinyLlama-1.1B-Chat-v1.0") + model_name_or_path: str | None = field( + default="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + metadata={"help": "HuggingFace model ID or local path to the base model."}, + ) use_fake_base_for_offline: bool = field( - default=False, metadata={"help": "Whether to use fake base for offline training."} + default=False, + metadata={ + "help": "Load model architecture without real base weights. Offline training only." 
+ }, ) trust_remote_code: bool = field( - default=False, metadata={"help": "Whether to trust remote code."} + default=False, metadata={"help": "Trust remote code when loading model."} ) @@ -69,23 +76,18 @@ class ModelArguments: class DataArguments: data_path: str = field( default=None, - metadata={"help": "Path to the training data."}, + metadata={"help": "Path to the online training data."}, ) - eval_data_path: str = field(default=None, metadata={"help": "Path to the evaluation data."}) offline_data_path: str = field( default=None, metadata={ - "help": """Path to the offline training data. Providing this flag sets - `eagle_offline` in the EagleConfig and enables offline training. - The directory should contain many `.pt` files, each containing a pre-processed - data sample. `data_path` should still point to the original conversations file. - """ + "help": "Path to offline training data directory (.pt files). This argument enables offline mode.", }, ) lazy_preprocess: bool = True draft_vocab_cache: str | None = field( default=None, - metadata={"help": "Path to d2t.pt cache file."}, + metadata={"help": "Path to draft vocabulary cache file."}, ) vlm_img_dir: str = field(default=None, metadata={"help": "Path to the VLM image directory."}) vlm_processor: str = field(default=None, metadata={"help": "Path to the VLM processor."}) @@ -93,28 +95,24 @@ class DataArguments: @dataclass class TrainingArguments(transformers.TrainingArguments): - cache_dir: str | None = field(default=None) training_seq_len: int = field( default=2048, metadata={ "help": ( - "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + "Training sequence length. Sequences will be right padded or truncated to this length." 
) }, ) - dataloader_drop_last: bool = field(default=True) - bf16: bool = field(default=True) mode: Literal["eagle3", "medusa"] = "eagle3" estimate_ar: bool = field( - default=False, metadata={"help": "Whether to estimate AR during training for logging."} - ) - ar_validate_steps: int = field(default=1000, metadata={"help": "Steps between AR validation."}) - disable_tqdm: bool = field(default=False, metadata={"help": "Disable tqdm progress bar."}) - remove_unused_columns: bool = field( - default=False, metadata={"help": "Set to False to keep extra args for VLM."} + default=False, metadata={"help": "Whether to estimate AR using training accuracy to log."} ) + ar_validate_steps: int = field(default=1000, metadata={"help": "AR validation interval."}) cp_size: int = field(default=1, metadata={"help": "Context parallelism size."}) - dp_shard_size: int = field(default=1, metadata={"help": "Data parallelism shard size."}) + dp_shard_size: int | None = field( + default=None, + metadata={"help": "Data parallelism shard size. None = auto (total_gpu / cp_size)."}, + ) @dataclass @@ -123,42 +121,70 @@ class MedusaArguments: medusa_num_layers: int | None = field(default=1) -@dataclass -class EagleArguments: - eagle_config: str = field(default=None, metadata={"help": "Path to eagle_config.json"}) - eagle_decoder_type: str = field( - default="llama", - metadata={"help": "The class of eagle decoder to use. Available options: llama, kimik2"}, - ) - mix_hidden_states: bool = field( - default=False, - metadata={"help": "Whether to mix hidden states from previous TTT step."}, - ) - disable_torch_compile: bool = field( - default=False, - metadata={"help": "Disable torch.compile on eagle forward/loss methods."}, - ) - num_ttt_steps: int = field( - default=3, - metadata={"help": "Number of train-time-test steps to use during training."}, - ) +def _parse_cli() -> tuple[str, list[str]]: + """Parse --config (required) from argv; return remaining args as config overrides. 
+ + Extra arguments use OmegaConf dotlist syntax, e.g. + ``model.model_name_or_path=meta-llama/Llama-3.2-1B training.output_dir=ckpts/test``. + """ + p = argparse.ArgumentParser(add_help=False) + p.add_argument("--config", required=True, help="Path to the YAML config file.") + args, overrides = p.parse_known_args() + return args.config, overrides + + +def _load_config(config_path: str, overrides: list[str] = ()) -> tuple[dict, dict]: + """Load training config from a YAML file with sections: model, data, training, eagle. + + *overrides* are OmegaConf dotlist entries (e.g. ``["model.model_name_or_path=xxx"]``) + applied on top of the YAML. + + Returns: + hf_cfg: Flat dict from model/data/training sections, for HfArgumentParser.parse_dict() + eagle_cfg: Eagle section dict (EagleConfig fields), passed directly to mtsp.convert() + """ + merged = OmegaConf.load(config_path) + if overrides: + merged = OmegaConf.merge(merged, OmegaConf.from_dotlist(list(overrides))) + cfg = OmegaConf.to_container(merged, resolve=True) + + # Eagle section maps directly to EagleConfig fields — no field enumeration needed. + # eagle_architecture_config is a nested dict and is included as-is. 
+ eagle_cfg = cfg.get("eagle", {}) + + hf_cfg = { + **cfg.get("model", {}), + **cfg.get("data", {}), + **cfg.get("training", {}), + } + + if hf_cfg.get("dp_shard_size") is None: + cp_size = hf_cfg.get("cp_size", 1) + hf_cfg["dp_shard_size"] = torch.cuda.device_count() // cp_size + + return hf_cfg, eagle_cfg def train(): + config_path, overrides = _parse_cli() + hf_cfg, eagle_cfg = _load_config(config_path, overrides) + parser = transformers.HfArgumentParser( ( ModelArguments, DataArguments, TrainingArguments, MedusaArguments, - EagleArguments, ) ) - model_args, data_args, training_args, medusa_args, eagle_args = ( - parser.parse_args_into_dataclasses() + model_args, data_args, training_args, medusa_args = parser.parse_dict( + hf_cfg, allow_extra_keys=True ) + if not data_args.data_path and not data_args.offline_data_path: - raise ValueError("Either data_path or offline_data_path must be provided.") + raise ValueError( + "Either data.data_path or data.offline_data_path must be set in the config." + ) if training_args.cp_size > 1 or training_args.dp_shard_size > 1: training_args.parallelism_config = ParallelismConfig( cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size @@ -167,7 +193,7 @@ def train(): patch_ring_attention_for_ttt() # Specific patch to accelerate 1.12.0. 
Removable after move to 1.13.0 training_args.parallelism_config.sp_backend = None - print_rank_0(f"arguments: {model_args}, {training_args}, {medusa_args}, {eagle_args}") + print_rank_0(f"arguments: {model_args}, {training_args}, {medusa_args}, eagle_cfg={eagle_cfg}") # Detect checkpoint to resume from last_checkpoint = ( @@ -213,28 +239,17 @@ def train(): } mtsp.convert(model, [("medusa", config)]) elif training_args.mode == "eagle3": - custom_config = ( - json.load(open(eagle_args.eagle_config)) if eagle_args.eagle_config else {} - ) - - config = { - "eagle_decoder_type": eagle_args.eagle_decoder_type, - "eagle_offline": use_offline_training, - "eagle_mix_hidden_states": eagle_args.mix_hidden_states, - "eagle_use_torch_compile": not eagle_args.disable_torch_compile, - "eagle_ttt_steps": eagle_args.num_ttt_steps, - "eagle_architecture_config": custom_config, - } - - mtsp.convert(model, [("eagle", config)]) + # eagle_cfg maps directly to EagleConfig fields; eagle_offline is derived here. 
+ eagle_cfg["eagle_offline"] = use_offline_training + mtsp.convert(model, [("eagle", eagle_cfg)]) - # read draft vocab cache + # Load draft vocab cache if the draft model uses a compressed vocabulary if model.eagle_config.draft_vocab_size < model.eagle_config.vocab_size: if not os.path.isfile(data_args.draft_vocab_cache): raise FileNotFoundError( f"Draft vocab cache provided but not found: {data_args.draft_vocab_cache}" ) - model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache) + model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True) print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.") else: raise Exception(f"{training_args.mode} is not supported!") diff --git a/examples/speculative_decoding/train_eagle3_and_export.sh b/examples/speculative_decoding/train_eagle3_and_export.sh index 0f5fef9354..92ccb6d513 100755 --- a/examples/speculative_decoding/train_eagle3_and_export.sh +++ b/examples/speculative_decoding/train_eagle3_and_export.sh @@ -17,50 +17,37 @@ set -eo pipefail -# Set default values for BASE_MODEL and DATA BASE_MODEL=meta-llama/Llama-3.2-1B-Instruct DATA=input_conversations/train.jsonl -# Parse input arguments --base_model and --data while [[ $# -gt 0 ]]; do - key="$1" - case $key in - --base_model) - BASE_MODEL="$2" - shift; shift - ;; - --data) - DATA="$2" - shift; shift - ;; - --offline_data) - OFFLINE_DATA_PATH="$2" - shift; shift - ;; - *) - echo "Unknown argument: $1" - exit 1 - ;; + case $1 in + --base_model) BASE_MODEL="$2"; shift; shift ;; + --data) DATA="$2"; shift; shift ;; + --offline_data) OFFLINE_DATA_PATH="$2"; shift; shift ;; + *) echo "Unknown argument: $1"; exit 1 ;; esac done -if [[ "$OFFLINE_DATA_PATH" != "" ]]; then - OFFLINE_DATA_ARGS="--offline-data $OFFLINE_DATA_PATH" +MODEL_BASENAME=$(basename "$BASE_MODEL") +OUTPUT_DIR=ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M) +mkdir -p "$OUTPUT_DIR" + +BASE_CFG="$(dirname "$(readlink -f 
"$0")")/../../modelopt_recipes/general/speculative_decoding/eagle3.yaml" + +# Build dotlist overrides +OVERRIDES=( + model.model_name_or_path="$BASE_MODEL" + training.output_dir="$OUTPUT_DIR" +) +if [[ -n "$OFFLINE_DATA_PATH" ]]; then + OVERRIDES+=( data.offline_data_path="$OFFLINE_DATA_PATH" ) else - OFFLINE_DATA_ARGS="" + OVERRIDES+=( data.data_path="$DATA" ) fi -MODEL_BASENAME=$(basename "$BASE_MODEL") - echo "==== [1/3] Training draft model ====" -OUTPUT_DIR=ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M) -mkdir -p "$(dirname "$OUTPUT_DIR")" -./launch_train.sh --model $BASE_MODEL \ - --output_dir $OUTPUT_DIR \ - $OFFLINE_DATA_ARGS \ - --data $DATA \ - --num_epochs 2 \ - --eagle_config eagle_config.json +./launch_train.sh --config "$BASE_CFG" "${OVERRIDES[@]}" echo "==== [2/3] Evaluating ModelOpt checkpoint on MT-Bench ====" python scripts/ar_validate.py --model_path $OUTPUT_DIR diff --git a/modelopt_recipes/general/speculative_decoding/eagle3.yaml b/modelopt_recipes/general/speculative_decoding/eagle3.yaml new file mode 100644 index 0000000000..0d2c8066e2 --- /dev/null +++ b/modelopt_recipes/general/speculative_decoding/eagle3.yaml @@ -0,0 +1,55 @@ +# Base config for EAGLE3 training. Override fields via OmegaConf dotlist on the CLI. 
+ +# maps to ModelArguments (main.py) +model: + trust_remote_code: false + use_fake_base_for_offline: false + +# maps to DataArguments (main.py) +data: + data_path: input_conversations/train.jsonl + draft_vocab_cache: + +# maps to TrainingArguments (main.py) +training: + # --- commonly modified --- + mode: eagle3 + output_dir: + num_train_epochs: 1 + per_device_train_batch_size: 1 + learning_rate: 1.0e-4 + warmup_steps: 1000 + training_seq_len: 2048 + logging_steps: 100 + save_steps: 8192 + cp_size: 1 + disable_tqdm: false + estimate_ar: false + ar_validate_steps: -1 + + # --- rarely modified --- + do_eval: false + lr_scheduler_type: linear + save_strategy: steps + weight_decay: 0.0 + dataloader_drop_last: true + bf16: true + tf32: true + remove_unused_columns: false + +# maps to EagleConfig (modelopt/torch/speculative/config.py). +eagle: + # eagle_offline is derived from data.offline_data_path; do not set here. + eagle_decoder_type: llama + eagle_ttt_steps: 3 + eagle_mix_hidden_states: false + eagle_use_torch_compile: true + eagle_self_logit_distillation: true + eagle_freeze_base_model: true + eagle_loss_decay_factor: 0.9 + eagle_hidden_state_distillation: false + eagle_reuse_base_decoder: false + eagle_report_acc: true + eagle_enable_nvtx: false + # overwrite to modelopt/torch/speculative/eagle/default_config.py + eagle_architecture_config: {} diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py index 271241bcb0..426f6a05e2 100644 --- a/tests/examples/speculative_decoding/test_eagle.py +++ b/tests/examples/speculative_decoding/test_eagle.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import json
 import os
 from pathlib import Path

 import pytest
 import safetensors.torch
 import torch
+import yaml
 from _test_utils.examples.run_command import run_example_command
 from packaging.version import Version
@@ -64,6 +64,14 @@ def generate_offline_pt_data(
     return output_dir


+def _write_eagle_yaml(path: Path, cfg: dict) -> Path:
+    """Write a YAML training config to *path* and return it."""
+    path = Path(path)
+    with open(path, "w") as f:
+        yaml.safe_dump(cfg, f, default_flow_style=False)
+    return path
+
+
 @pytest.fixture(scope="module")
 def eagle_output_dir(tmp_path_factory):
     """Eagle output directory shared in this module."""
@@ -100,7 +108,7 @@ def test_calibrate_draft_vocab(tiny_llama_path, tiny_daring_anteater_path, draft

 # fmt: off
-@pytest.mark.parametrize(("cp_size", "mix_hidden_states"), [(1, "false"), (2, "false"), (1, "true"), (2, "true")])
+@pytest.mark.parametrize(("cp_size", "mix_hidden_states"), [(1, False), (2, False), (1, True), (2, True)])
 def test_llama_eagle3(tiny_llama_path,
                       tiny_daring_anteater_path,
                       tmp_path,
                       eagle_output_dir,
@@ -112,8 +120,8 @@ def test_llama_eagle3(tiny_llama_path,
         pytest.skip("cp_size=2 requires at least 2 GPUs, but only {} found.".format(available_gpus))
     if cp_size == 2 and not Version(torch.__version__) >= Version("2.10.0"):
         pytest.skip("cp_size=2 requires torch 2.10.0")
-    # Create an ultra-tiny EAGLE config for testing to reduce memory usage
-    tiny_eagle_config = {
+
+    tiny_eagle_arch_config = {
         "max_position_embeddings": 128,
         "num_hidden_layers": 1,
         "intermediate_size": 64,
@@ -121,43 +129,48 @@ def test_llama_eagle3(tiny_llama_path,
         "num_key_value_heads": 2,
         "head_dim": 64,
     }
-
-    # Write the tiny config to a temporary file
-    config_file = tmp_path / f"tiny_eagle_config_cp{cp_size}.json"
-    with open(config_file, "w") as f:
-        json.dump(tiny_eagle_config, f)
+    cfg = {
+        "model": {"model_name_or_path": str(tiny_llama_path)},
+        "data": {"data_path": str(tiny_daring_anteater_path)},
+        "training": {
+            "output_dir": str(eagle_output_dir / f"eagle-tinyllama-cp{cp_size}-mix{mix_hidden_states}"),
+            "num_train_epochs": 0.25,
+            "learning_rate": 1e-5,
+            "training_seq_len": 128,
+            "cp_size": cp_size,
+            "per_device_train_batch_size": 1,
+        },
+        "eagle": {
+            "eagle_mix_hidden_states": mix_hidden_states,
+            "eagle_architecture_config": tiny_eagle_arch_config,
+        },
+    }
+    yaml_file = _write_eagle_yaml(tmp_path / f"cfg_cp{cp_size}.yaml", cfg)

     run_example_command(
-        [
-            "./launch_train.sh",
-            "--model", tiny_llama_path,
-            "--data", tiny_daring_anteater_path,
-            "--num_epochs", "0.25",
-            "--lr", "1e-5",
-            "--mode", "eagle3",
-            "--eagle_config", str(config_file),
-            "--output_dir", eagle_output_dir / f"eagle-tinyllama-cp{cp_size}",
-            "--training_seq_len", "128",  # Match max_position_embeddings
-            "--cp_size", str(cp_size),
-            "--mix_hidden_states", mix_hidden_states,
-        ],
+        ["./launch_train.sh", "--config", str(yaml_file)],
         "speculative_decoding",
     )


-def test_resume_training(tiny_daring_anteater_path, eagle_output_dir):
+def test_resume_training(tiny_daring_anteater_path, eagle_output_dir, tmp_path):
     """Test resume training of Eagle3."""
+    checkpoint_dir = eagle_output_dir / "eagle-tinyllama-cp1-mixFalse"
+    cfg = {
+        "model": {"model_name_or_path": str(checkpoint_dir)},
+        "data": {"data_path": str(tiny_daring_anteater_path)},
+        "training": {
+            "output_dir": str(checkpoint_dir),
+            "num_train_epochs": 0.5,
+            "learning_rate": 1e-5,
+            "training_seq_len": 128,
+            "per_device_train_batch_size": 1,
+        },
+        "eagle": {},
+    }
+    yaml_file = _write_eagle_yaml(tmp_path / "resume_cfg.yaml", cfg)
     run_example_command(
-        [
-            "./launch_train.sh",
-            "--model", eagle_output_dir / "eagle-tinyllama-cp1",
-            "--data", tiny_daring_anteater_path,
-            "--num_epochs", "0.5",
-            "--lr", "1e-5",
-            "--mode", "eagle3",
-            "--output_dir", eagle_output_dir / "eagle-tinyllama-cp1",
-            "--training_seq_len", "128",  # Match max_position_embeddings
-        ],
+        ["./launch_train.sh", "--config", str(yaml_file)],
         "speculative_decoding",
     )

@@ -239,7 +252,7 @@ def test_offline_eagle3_training(
         num_aux_layers=min(cfg.num_hidden_layers, 3),
     )

-    tiny_eagle_config = {
+    tiny_eagle_arch_config = {
         "max_position_embeddings": 128,
         "num_hidden_layers": 1,
         "intermediate_size": 64,
@@ -247,27 +260,32 @@ def test_offline_eagle3_training(
         "num_key_value_heads": 2,
         "head_dim": 64,
     }
-    config_file = tmp_path / "tiny_eagle_config_offline.json"
-    with open(config_file, "w") as f:
-        json.dump(tiny_eagle_config, f)
-
-    cmd = [
-        "./launch_train.sh",
-        "--model", model_path,
-        "--data", tiny_daring_anteater_path,
-        "--offline-data", offline_data_dir,
-        "--num_epochs", "0.1",
-        "--lr", "1e-5",
-        "--mode", "eagle3",
-        "--eagle_config", str(config_file),
-        "--output_dir", output_subdir,
-        "--training_seq_len", "64",
-        "--trust_remote_code", "True",
-        "--fsdp", "False",
-    ]
-    if use_fake_base:
-        cmd += ["--use_fake_base_for_offline", "true"]
-    run_example_command(cmd, "speculative_decoding")
+    training_cfg = {
+        "model": {
+            "model_name_or_path": str(model_path),
+            "trust_remote_code": True,
+            "use_fake_base_for_offline": use_fake_base,
+        },
+        "data": {
+            "data_path": str(tiny_daring_anteater_path),
+            "offline_data_path": str(offline_data_dir),
+        },
+        "training": {
+            "output_dir": str(output_subdir),
+            "num_train_epochs": 0.1,
+            "learning_rate": 1e-5,
+            "training_seq_len": 64,
+            "per_device_train_batch_size": 1,
+        },
+        "eagle": {
+            "eagle_architecture_config": tiny_eagle_arch_config,
+        },
+    }
+    yaml_file = _write_eagle_yaml(tmp_path / f"offline_cfg_{model_id}.yaml", training_cfg)
+    run_example_command(
+        ["./launch_train.sh", "--config", str(yaml_file)],
+        "speculative_decoding",
+    )

     assert os.path.exists(output_subdir / "config.json")

@@ -277,9 +295,9 @@ def test_offline_resume_training_kimi(tiny_daring_anteater_path, tmp_path, eagle
     Depends on test_offline_eagle3_training["kimi-k2.5"] having run first.
     Exercises AutoModelForCausalLM.from_pretrained with model_type='fake_base_model'.
     """
-    import transformers
-
     checkpoint_dir = eagle_output_dir / "eagle-Kimi-K2.5-offline"
+
+    import transformers
     config = transformers.AutoConfig.from_pretrained(checkpoint_dir, trust_remote_code=True)

     offline_data_dir = generate_offline_pt_data(
@@ -289,20 +307,27 @@ def test_offline_resume_training_kimi(tiny_daring_anteater_path, tmp_path, eagle
         num_aux_layers=min(config.num_hidden_layers, 3),
     )

+    training_cfg = {
+        "model": {
+            "model_name_or_path": str(checkpoint_dir),
+            "trust_remote_code": True,
+            "use_fake_base_for_offline": True,
+        },
+        "data": {
+            "data_path": str(tiny_daring_anteater_path),
+            "offline_data_path": str(offline_data_dir),
+        },
+        "training": {
+            "output_dir": str(checkpoint_dir),
+            "num_train_epochs": 0.2,
+            "learning_rate": 1e-5,
+            "training_seq_len": 64,
+            "per_device_train_batch_size": 1,
+        },
+        "eagle": {},
+    }
+    yaml_file = _write_eagle_yaml(tmp_path / "resume_kimi_cfg.yaml", training_cfg)
     run_example_command(
-        [
-            "./launch_train.sh",
-            "--model", checkpoint_dir,
-            "--data", tiny_daring_anteater_path,
-            "--offline-data", offline_data_dir,
-            "--num_epochs", "0.2",
-            "--lr", "1e-5",
-            "--mode", "eagle3",
-            "--output_dir", checkpoint_dir,
-            "--training_seq_len", "64",
-            "--trust_remote_code", "True",
-            "--fsdp", "False",
-            "--use_fake_base_for_offline", "true",
-        ],
+        ["./launch_train.sh", "--config", str(yaml_file)],
         "speculative_decoding",
     )
diff --git a/tools/launcher/common/eagle3/offline_training.sh b/tools/launcher/common/eagle3/offline_training.sh
index 09384a499b..630b9e8f70 100644
--- a/tools/launcher/common/eagle3/offline_training.sh
+++ b/tools/launcher/common/eagle3/offline_training.sh
@@ -27,7 +27,6 @@ export PATH=$PATH:/workspace/.local/bin
 trap 'error_handler $0 $LINENO' ERR # ERROR HANDLER

 bash modules/Model-Optimizer/examples/speculative_decoding/launch_train.sh \
-    --model ${HF_MODEL_CKPT} \
     ${@}

 python modules/Model-Optimizer/examples/speculative_decoding/scripts/export_hf_checkpoint.py \
diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/hf_offline_eagle3.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/hf_offline_eagle3.yaml
index 934ab2928e..071cfd03a1 100644
--- a/tools/launcher/examples/Qwen/Qwen3-8B/hf_offline_eagle3.yaml
+++ b/tools/launcher/examples/Qwen/Qwen3-8B/hf_offline_eagle3.yaml
@@ -67,20 +67,13 @@ pipeline:
   task_2:
     script: common/eagle3/offline_training.sh
     args:
-      - --offline-data /scratchspace/offline_hidden_states
-      - --data_path None
-      - --mode eagle3
-      - --num_epochs 1
-      - --lr 3e-4
-      - --save_steps 500000
-      - --output_dir /scratchspace/eagle3
-      - --train_bs 8
-      - --training_seq_len 4096
-      - --eagle_config modules/Model-Optimizer/examples/speculative_decoding/eagle_config.json
-      - --disable_tqdm True
-      - --ar_validate_steps 500000
-    environment:
-      - HF_MODEL_CKPT: <>
+      - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/eagle3.yaml
+      - model.model_name_or_path=<>
+      - data.offline_data_path=/scratchspace/offline_hidden_states
+      - training.output_dir=/scratchspace/eagle3
+      - training.training_seq_len=4096
+      - training.disable_tqdm=true
+      - training.ar_validate_steps=500000
     slurm_config:
       _factory_: "slurm_factory"
       nodes: 1