Commit bc52b6c
Feat: Support VLLM one-model eagle ckpt; Add unit tests; (#573)
## What does this PR do?

**Type of change:** New feature, new tests

**Overview:**

- Add a conversion script for eagle3 llm-compressor-style checkpoints
- Jira Ticket: https://jirasw.nvidia.com/browse/OMNIML-2866
- Add unit tests for `ar_validate.py`, `export_hf_checkpoint.py`, and `convert_to_vllm_ckpt.py`

## Usage

```bash
python scripts/convert_to_vllm_ckpt.py --input <eagle3 ckpt> --verifier <base model> --output <path>
```

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

---------

Signed-off-by: h-guo18 <[email protected]>
1 parent 1d0ee04 commit bc52b6c

File tree

8 files changed: +248 -7 lines changed

examples/speculative_decoding/README.md

Lines changed: 14 additions & 4 deletions
@@ -129,15 +129,15 @@ Once we finish dumping hidden states, launch offline training with an extra `--o
 For online training checkpoints, we can run in-framework evaluation on MT-bench:
 
 ```bash
-python ar_validate.py --model_path $ONLINE_CKPT
+python scripts/ar_validate.py --model_path $ONLINE_CKPT
 ```
 
 **Note**: In-framework evaluation is supported only for online training. For offline training checkpoints, please export the model and evaluate it using serving frameworks.
 
 ## Export
 
 ```bash
-python export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH
+python scripts/export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH
 ```
 
 This exports the model from a ModelOpt checkpoint to a deployment-compatible format.

@@ -175,6 +175,16 @@ kv_cache_config:
 
 Please refer to [TRT-LLM Doc: Speculative Decoding](https://nvidia.github.io/TensorRT-LLM/examples/llm_speculative_decoding.html) for detailed usage.
 
+### vLLM
+
+Please refer to [vLLM Doc: Speculative Decoding](https://docs.vllm.ai/en/latest/features/spec_decode/) for detailed usage.
+
+Optionally, you can convert the exported checkpoint to contain target model information, which is accepted by vLLM to simplify deployment:
+
+```bash
+python scripts/convert_to_vllm_ckpt.py --input <exported_ckpt> --verifier <target_model> --output <output_dir>
+```
+
 ### SGLang
 
 Please refer to [SGLang Doc: Speculative Decoding](https://docs.sglang.ai/advanced_features/speculative_decoding.html#EAGLE-3-Decoding) for detailed usage.

@@ -227,7 +237,7 @@ Note: Add `--quantization=modelopt` flag for quantized models.
 Then, we generate conversations with the base model using prompts from Daring-Anteater:
 
 ```bash
-python server_generate.py --data_path input_conversations/daring-anteater.jsonl --output_path synthetic/train.jsonl
+python scripts/server_generate.py --data_path input_conversations/daring-anteater.jsonl --output_path synthetic/train.jsonl
 ```
 
 To add a system prompt, use the `--system_prompt <system_prompt_text>` argument.

@@ -239,7 +249,7 @@ For large scale data generation, please see [SLURM prepare data](SLURM_prepare_d
 We can optionally use a smaller vocab size for the draft model for faster training and inference. For example, Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most frequently occurring tokens in our training set:
 
 ```bash
-python calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data input_conversations/daring-anteater.jsonl --draft_vocab_size 32000 --save_dir draft_vocab_cache
+python scripts/calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data input_conversations/daring-anteater.jsonl --draft_vocab_size 32000 --save_dir draft_vocab_cache
 ```
 
 This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft tokens to target tokens. During inference, draft tokens can be mapped back to target tokens via `target_token = draft_token + d2t[draft_token]`.

examples/speculative_decoding/eagle_utils.py

Lines changed: 1 addition & 1 deletion
@@ -21,9 +21,9 @@
 import numpy as np
 import torch
 import transformers
-from ar_validate import validate_ar
 from datasets import load_dataset
 from PIL import Image
+from scripts.ar_validate import validate_ar
 from torch.utils.data import Dataset
 from transformers import AutoProcessor, Trainer, TrainerCallback
 from transformers.trainer_pt_utils import LabelSmoother
examples/speculative_decoding/scripts/convert_to_vllm_ckpt.py (new file)

Lines changed: 148 additions & 0 deletions

```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convert a TRTLLM eagle checkpoint to a VLLM-compatible one-model checkpoint."""

import argparse
import json
import os
import shutil
from copy import deepcopy

VLLM_EAGLE3_ONE_CKPT_CFG_TEMPLATE = {
    "architectures": ["Eagle3Speculator"],
    "auto_map": {"": "eagle3.Eagle3SpeculatorConfig"},
    "draft_vocab_size": None,
    "has_no_defaults_at_init": False,
    "norm_before_residual": True,
    "speculators_config": {
        "algorithm": "eagle3",
        "default_proposal_method": "greedy",
        "proposal_methods": [
            {
                "accept_tolerance": 0.0,
                "proposal_type": "greedy",
                "speculative_tokens": 3,
                "verifier_accept_k": 1,
            }
        ],
        "verifier": {"architectures": [""], "name_or_path": ""},
    },
    "speculators_model_type": "eagle3",
    "speculators_version": "0.1.0.dev14",
    "target_hidden_size": None,
    "torch_dtype": None,
    "transformer_layer_config": {
        "attention_bias": None,
        "attention_dropout": None,
        "head_dim": None,
        "hidden_act": None,
        "hidden_size": None,
        "initializer_range": None,
        "intermediate_size": None,
        "max_position_embeddings": None,
        "mlp_bias": None,
        "model_type": None,
        "num_attention_heads": None,
        "num_hidden_layers": None,
        "num_key_value_heads": None,
        "pretraining_tp": None,
        "rms_norm_eps": None,
        "rope_scaling": None,
        "rope_theta": None,
        "use_cache": True,
        "vocab_size": None,
    },
    "transformers_version": None,
}


def convert_to_eagle3_speculator_config(
    draft_cfg,
    verifier_name_or_path,
    template_cfg=VLLM_EAGLE3_ONE_CKPT_CFG_TEMPLATE,
):
    """Convert a draft model config and a verifier model config to an Eagle3Speculator config."""
    verifier_config_path = os.path.join(verifier_name_or_path, "config.json")
    with open(verifier_config_path, encoding="utf-8") as verifier_cfg_file:
        verifier_cfg = json.load(verifier_cfg_file)

    speculator_config = deepcopy(template_cfg)

    try:
        # Update speculators_config separately to avoid type conflicts
        speculator_config["speculators_config"].update(
            {
                "verifier": {
                    "architectures": verifier_cfg["architectures"],
                    "name_or_path": verifier_name_or_path,
                },
            }
        )

        # Update other fields
        speculator_config.update(
            {
                "draft_vocab_size": draft_cfg["draft_vocab_size"],
                "target_hidden_size": verifier_cfg["hidden_size"],
                "torch_dtype": draft_cfg["torch_dtype"],
                "transformer_layer_config": {
                    k: draft_cfg[k] for k in template_cfg["transformer_layer_config"]
                },
                "transformers_version": draft_cfg["transformers_version"],
            }
        )
    except Exception as e:
        raise Exception(f"Error converting draft config: {e}")

    return speculator_config


def main():
    parser = argparse.ArgumentParser(
        description="Convert TRTLLM eagle checkpoint to VLLM compatible one-model checkpoint."
    )
    parser.add_argument("--input", help="Path to TRTLLM eagle checkpoint.")
    parser.add_argument("--verifier", help="Name or path to the verifier model.")
    parser.add_argument("--output", help="Save path for converted vllm one-model checkpoint.")

    args = parser.parse_args()

    with open(os.path.join(args.input, "config.json"), encoding="utf-8") as f:
        original_draft_cfg = json.load(f)

    converted_cfg = convert_to_eagle3_speculator_config(
        original_draft_cfg,
        args.verifier,
    )

    # Write the converted config to the output directory
    os.makedirs(args.output, exist_ok=True)
    with open(os.path.join(args.output, "config.json"), "w", encoding="utf-8") as f:
        json.dump(converted_cfg, f, indent=2, ensure_ascii=False)

    # Copy the model.safetensors file from input dir to output dir
    input_model_path = os.path.join(args.input, "model.safetensors")
    output_model_path = os.path.join(args.output, "model.safetensors")
    shutil.copyfile(input_model_path, output_model_path)


if __name__ == "__main__":
    main()
```
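For illustration (not part of the commit), a sketch of how `convert_to_eagle3_speculator_config` fills the template, using toy draft and verifier configs; every value and path below is hypothetical, and the import assumes you run from the `scripts/` directory:

```python
# Exercise the converter with toy configs to show the field mapping.
import json
import tempfile
from pathlib import Path

from convert_to_vllm_ckpt import (
    VLLM_EAGLE3_ONE_CKPT_CFG_TEMPLATE,
    convert_to_eagle3_speculator_config,
)

# Hypothetical draft config: cover every transformer_layer_config key the
# template expects, plus the three top-level fields the converter reads.
draft_cfg = {k: 0 for k in VLLM_EAGLE3_ONE_CKPT_CFG_TEMPLATE["transformer_layer_config"]}
draft_cfg.update(
    {"draft_vocab_size": 32000, "torch_dtype": "bfloat16", "transformers_version": "4.45.0"}
)

with tempfile.TemporaryDirectory() as verifier_dir:
    # Hypothetical verifier config.json; only "architectures" and "hidden_size" are read.
    (Path(verifier_dir) / "config.json").write_text(
        json.dumps({"architectures": ["LlamaForCausalLM"], "hidden_size": 4096})
    )
    cfg = convert_to_eagle3_speculator_config(draft_cfg, verifier_dir)
    assert cfg["target_hidden_size"] == 4096
    assert cfg["speculators_config"]["verifier"]["name_or_path"] == verifier_dir
```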

examples/speculative_decoding/server_generate.py renamed to examples/speculative_decoding/scripts/server_generate.py

Lines changed: 15 additions & 0 deletions
@@ -1,3 +1,18 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # Adapted from: https://github.com/FasterDecoding/Medusa/blob/e2a5d20/data_generation/generate.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

tests/examples/speculative_decoding/test_eagle.py

Lines changed: 70 additions & 2 deletions
@@ -15,11 +15,22 @@
 
 import json
 
+import pytest
+import safetensors.torch
 from _test_utils.examples.run_command import run_example_command
 
+from modelopt.torch.export.plugins.hf_spec_export import EAGLE_MODELOPT_TO_OFFICIAL
+
+
+@pytest.fixture(scope="module")
+def eagle_output_dir(tmp_path_factory):
+    """Eagle output directory shared in this module."""
+    return tmp_path_factory.mktemp("eagle_output_dir")
+
 
 # fmt: off
-def test_llama_eagle3(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_path):
+def test_llama_eagle3(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_path, eagle_output_dir):
+    """Test Eagle3 training with a tiny llama model."""
     # Create an ultra-tiny EAGLE config for testing to reduce memory usage
     tiny_eagle_config = {
         "max_position_embeddings": 128,

@@ -45,8 +56,65 @@ def test_llama_eagle3(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_
             "--num_gpu", str(num_gpus),
             "--mode", "eagle3",
             "--eagle_config", str(config_file),
-            "--output_dir", tmp_path / "eagle-tinyllama",
+            "--output_dir", eagle_output_dir / "eagle-tinyllama",
             "--training_seq_len", "128",  # Match max_position_embeddings
         ],
         "speculative_decoding",
     )
+
+
+def test_ar_validate(eagle_output_dir):
+    """Test in-framework AR evaluation."""
+    run_example_command(
+        [
+            "python", "./scripts/ar_validate.py",
+            "--model_path", eagle_output_dir / "eagle-tinyllama",
+            "--osl", "20",
+            "--num_samples", "10",
+            "--steps", "3",
+        ],
+        "speculative_decoding",
+    )
+
+
+def test_export_hf_checkpoint(eagle_output_dir):
+    """Test export of Eagle3 checkpoint."""
+    run_example_command(
+        [
+            "python", "./scripts/export_hf_checkpoint.py",
+            "--model_path", eagle_output_dir / "eagle-tinyllama",
+            "--export_path", eagle_output_dir / "eagle-tinyllama-export",
+        ],
+        "speculative_decoding",
+    )
+    # Check the exported checkpoint has the required keys
+    state_dict = safetensors.torch.load_file(eagle_output_dir / "eagle-tinyllama-export" / "model.safetensors")
+    for required_key in EAGLE_MODELOPT_TO_OFFICIAL["required"].values():
+        assert required_key in state_dict, f"Missing key '{required_key}' in state_dict"
+
+
+def test_convert_to_vllm_ckpt(tiny_llama_path, eagle_output_dir):
+    """Test conversion of Eagle3 checkpoint to VLLM one-model checkpoint."""
+    run_example_command(
+        [
+            "python", "./scripts/convert_to_vllm_ckpt.py",
+            "--input", eagle_output_dir / "eagle-tinyllama-export",
+            "--verifier", tiny_llama_path,
+            "--output", eagle_output_dir / "eagle-tinyllama-export-vllm-one-ckpt",
+        ],
+        "speculative_decoding",
+    )
+
+
+@pytest.mark.skip(reason="Needs dataset conversion to role-content format; consolidate data loading first.")
+def test_calibrate_draft_vocab(tiny_llama_path, tiny_daring_anteater_path, tmp_path):
+    """Test calibration of draft vocabulary."""
+    run_example_command(
+        [
+            "python", "./scripts/calibrate_draft_vocab.py",
+            "--model", tiny_llama_path,
+            "--data", tiny_daring_anteater_path,
+            "--draft_vocab_size", "100",
+            "--save_dir", tmp_path / "draft_vocab_cache",
+        ],
+        "speculative_decoding",
+    )
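A usage note on these tests (an observation about the structure, not stated in the commit): the export and conversion tests consume artifacts written by the training test through the module-scoped `eagle_output_dir` fixture, so the module is meant to run as a whole. A sketch of invoking it programmatically, assuming the repo's test environment is set up:

```python
# Run the whole test module so the module-scoped fixture is shared across tests;
# equivalent to `pytest -v tests/examples/speculative_decoding/test_eagle.py`.
import pytest

raise SystemExit(pytest.main(["-v", "tests/examples/speculative_decoding/test_eagle.py"]))
```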
