[Speculative Decoding] Refactor EAGLE3 training to YAML-based config and recipe system#1134
Conversation
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here.
📝 Walkthrough

Migrates the speculative decoding example from CLI-argument-driven training to YAML recipe-driven training. Scripts, README, tests, and examples now use a single YAML config.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant LaunchScript as "launch_train.sh"
    participant Main as "main.py"
    participant Converter as "mtsp.convert"
    participant Trainer as "HF Trainer / accelerate"
    participant Checkpoint as "Checkpoint/Export"
    User->>LaunchScript: provide --config (and optional --model)
    LaunchScript->>Main: accelerate launch main.py --config ...
    Main->>Main: load YAML (__base__ inheritance via OmegaConf)
    Main->>Converter: if eagle3, call mtsp.convert(eagle_cfg, ...)
    Converter-->>Main: converted model (in-place or wrapper)
    Main->>Trainer: build HF args from YAML and start training
    Trainer->>Checkpoint: save checkpoints / final model
    Trainer-->>User: training complete
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks: ✅ 4 passed
Codecov Report: ✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1134       +/-   ##
===========================================
- Coverage   70.19%   54.53%   -15.66%
===========================================
  Files         230      348      +118
  Lines       26044    39766    +13722
===========================================
+ Hits        18281    21686     +3405
- Misses       7763    18080    +10317
===========================================
```

Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Actionable comments posted: 8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/speculative_decoding/main.py (1)
Lines 258-259: ⚠️ Potential issue | 🔴 Critical

Add `weights_only=True` to the `torch.load()` call for security. The `torch.load(data_args.draft_vocab_cache)` at line 258 does not specify `weights_only=True`, which allows arbitrary code execution from malicious pickle files. Since `d2t` is a pure tensor (int64), `weights_only=True` is both safe and compatible.

Proposed fix:

```diff
- model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache)
+ model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/main.py` around lines 258 - 259, The torch.load call that assigns model.eagle_module.d2t from data_args.draft_vocab_cache should pass weights_only=True to avoid executing pickled code; update the load call in the code that sets model.eagle_module.d2t to use torch.load(data_args.draft_vocab_cache, weights_only=True) so only tensor data is deserialized.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/speculative_decoding/launch_train.sh`:
- Around line 49-79: The accelerate launch invocation currently interpolates
unquoted variables into sh -c and omits --num_processes on single-node runs; to
fix, build the command as a bash array (e.g., CMD=()) instead of a single sh -c
string, append required flags using the existing
MULTI_NODE_ARGS/MODEL_ARG/TOTAL_GPU symbols (ensure MULTI_NODE_ARGS always
includes "--num_processes $TOTAL_GPU" even for single-node), and then run the
launch with "${CMD[@]}" so that $CONFIG_FILE, $MODEL, $HEAD_NODE_IP and other
variables are safely quoted and preserved without word-splitting or accidental
expansion.
In `@examples/speculative_decoding/main.py`:
- Line 111: The metadata help string for the dataclass field ar_validate_steps
is incomplete; update the metadata["help"] for ar_validate_steps to a full,
descriptive sentence (e.g., "Number of autoregressive validation steps to run
during evaluation" or similar) so users understand its purpose; locate the
ar_validate_steps field definition and replace the truncated help text with the
completed description.
In `@examples/speculative_decoding/train_eagle3_and_export.sh`:
- Around line 43-48: train_config.yaml is missing the base model identifier so
the generated YAML is not replayable; update the code that writes YAML_FILE
(train_config.yaml) to include the model_name_or_path value (the model used via
the --model override) under model: (e.g., model_name_or_path: "<value>") so the
config fully captures the runtime model selection; ensure the string comes from
the same variable/arg used to parse the --model override and is written when
creating YAML_FILE (preserving YAML_FILE, OUTPUT_DIR, and model_name_or_path
references).
In `@modelopt_recipes/speculative_decoding/kimi_k25_eagle_offline.yaml`:
- Around line 3-6: The recipe currently enables trust_remote_code by default in
the model block (fields model_name_or_path: moonshotai/Kimi-K2.5 and
trust_remote_code: true); change that default to false and instead
document/require an explicit opt-in (e.g., a commented flag or
environment-driven toggle) so users must consciously enable trust_remote_code
for the Kimi recipe; update any README or inline comment near the model
configuration and/or the use_fake_base_for_offline handling so it explains how
to opt in (enable trust_remote_code) when the user intentionally trusts the
model's custom HF code.
In `@modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml`:
- Around line 4-6: Replace the unsafe default by changing the YAML key
trust_remote_code from true to false in the model block of the Llama offline
recipe (the block containing model_name_or_path: meta-llama/Llama-3.2-1B);
update the value so the recipe does not silently enable remote code execution
and leave a brief comment if you want to document that users must opt-in to
enable remote code loading manually.
In `@modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml`:
- Around line 4-6: The YAML enables trust_remote_code for a stock Llama model;
remove or set trust_remote_code to false to avoid executing arbitrary repo code.
Edit the model block that contains model_name_or_path: meta-llama/Llama-3.2-1B
and either delete the trust_remote_code line or change it to trust_remote_code:
false so the pipeline uses the standard transformers implementation rather than
allowing remote code execution.
In `@tests/examples/speculative_decoding/test_eagle.py`:
- Around line 138-149: The test currently writes both mix_hidden_states variants
into the same output directory causing runs to clobber each other; modify the
training output_dir construction (where eagle_output_dir /
f"eagle-tinyllama-cp{cp_size}" is used) to include the mix_hidden_states flag
(e.g., append `_mix{mix_hidden_states}` or similar) so each (cp_size,
mix_hidden_states) combination gets a unique checkpoint directory; update any
references that assume the old path (e.g., test_resume_training) to use the new
per-variant output_dir variable.
- Around line 269-273: Parametrize the trust_remote_code flag instead of
hardcoding True: add a test parameter (default False) named trust_remote_code to
the relevant test cases and use it when writing the model YAML dictionary
(replace the hardcoded "trust_remote_code": True with "trust_remote_code":
trust_remote_code) and when calling AutoConfig.from_pretrained (replace the
hardcoded trust_remote_code=True with trust_remote_code=trust_remote_code);
update only the specific test invocations that require remote code execution to
pass trust_remote_code=True. Ensure the new parameter is included in the pytest
parametrization for the test function(s) that build the YAML/model config so
local models keep trust_remote_code=False while remote-model cases explicitly
set it to True.
---
Outside diff comments:
In `@examples/speculative_decoding/main.py`:
- Around line 258-259: The torch.load call that assigns model.eagle_module.d2t
from data_args.draft_vocab_cache should pass weights_only=True to avoid
executing pickled code; update the load call in the code that sets
model.eagle_module.d2t to use torch.load(data_args.draft_vocab_cache,
weights_only=True) so only tensor data is deserialized.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: c971e970-8ce6-4555-9bd5-f56f417bbb15
📒 Files selected for processing (11)

- examples/speculative_decoding/README.md
- examples/speculative_decoding/eagle_config.json
- examples/speculative_decoding/fsdp_config.json
- examples/speculative_decoding/launch_train.sh
- examples/speculative_decoding/main.py
- examples/speculative_decoding/train_eagle3_and_export.sh
- modelopt_recipes/speculative_decoding/_base_eagle3.yaml
- modelopt_recipes/speculative_decoding/kimi_k25_eagle_offline.yaml
- modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml
- modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml
- tests/examples/speculative_decoding/test_eagle.py
💤 Files with no reviewable changes (2)
- examples/speculative_decoding/eagle_config.json
- examples/speculative_decoding/fsdp_config.json
So does the yaml file encode all the information modelopt needs for the eagle3 training?
Basically yes. The only exception is the accelerate configs (e.g. multi-node settings), which need to be passed in addition to the YAML config. I think they are orthogonal to the "recipe" and are more convenient to set this way, since the node IP is often dynamic on Slurm jobs. Do you think it's better to put them in the YAML as well?
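For illustration, a sketch of what passing the multi-node accelerate flags alongside the YAML recipe might look like. This is a hypothetical example: flag values, `HEAD_NODE_IP`, and `NODE_RANK` are placeholders, not commands taken from this thread.

```shell
#!/usr/bin/env bash
# Hypothetical sketch: multi-node accelerate flags stay on the CLI, while
# everything recipe-related lives in the YAML. HEAD_NODE_IP and NODE_RANK
# stand in for values that Slurm provides dynamically.
HEAD_NODE_IP=${HEAD_NODE_IP:-10.0.0.1}
NODE_RANK=${NODE_RANK:-0}
CMD=(accelerate launch
  --num_machines 2 --machine_rank "$NODE_RANK"
  --main_process_ip "$HEAD_NODE_IP" --main_process_port 29500
  main.py --config modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml)
echo "${CMD[*]}"
```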
ChenhanYu
left a comment
PR Review: Refactor EAGLE3 training to YAML-based config and recipe system
Summary
Clean refactor that replaces ~250 lines of shell argument parsing in launch_train.sh with a YAML-based config system using OmegaConf. Config files support __base__ inheritance, and pre-configured recipes are shipped under modelopt_recipes/speculative_decoding/. The EagleArguments dataclass is removed — eagle config now passes directly from YAML to mtsp.convert(). Tests and README are updated accordingly. Net reduction: -74 lines. The direction is good.
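The `__base__` inheritance mentioned above can be sketched roughly as follows. This is a hypothetical pure-Python approximation for intuition only: the actual implementation uses OmegaConf and merges recursively, while this sketch does a shallow child-overrides-base merge.

```python
def resolve_base(cfg: dict, load_yaml) -> dict:
    """Resolve a recipe's __base__ chain; child keys override base keys.

    `load_yaml` is a hypothetical callable mapping a path to a parsed dict.
    """
    cfg = dict(cfg)  # don't mutate the caller's dict
    base_path = cfg.pop("__base__", None)
    if base_path is None:
        return cfg
    # Recurse so bases can themselves have a __base__ (shallow merge for brevity).
    merged = resolve_base(load_yaml(base_path), load_yaml)
    merged.update(cfg)
    return merged
```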
Findings
1. Missing omegaconf dependency — Blocker
examples/speculative_decoding/main.py:47 — from omegaconf import OmegaConf
This is a new import, but omegaconf is not added to pyproject.toml extras or any requirements.txt. Users will get ImportError unless they happen to have it installed transitively (e.g., via Hydra). Needs to be added as a dependency.
2. trust_remote_code: true on Llama recipes — Security
modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml:7 and llama3_eagle_offline.yaml:7
Llama models don't require trust_remote_code. Given that PR #975 just put effort into removing hardcoded trust_remote_code=True throughout the codebase, shipping recipes with it enabled by default undermines that security improvement. Should be false for Llama recipes.
3. _parse_cli silently ignores unknown args — Migration footgun
examples/speculative_decoding/main.py:133 — args, _ = p.parse_known_args()
Users migrating from the old CLI (e.g., --eagle_config, --mode eagle3, --mix_hidden_states) will have their flags silently ignored with no error or deprecation warning. Consider logging the unknown args.
4. dp_shard_size: 0 magic sentinel — Edge case
examples/speculative_decoding/main.py:168-170 — If torch.cuda.device_count() returns 0 (CPU-only node, CUDA not visible), this produces 0. Guard with gpu_count = torch.cuda.device_count() or 1.
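The suggested guard can be sketched like this. The functions are hypothetical and take the device count as an argument so the sketch runs without torch; the real code would pass in `torch.cuda.device_count()`:

```python
def safe_gpu_count(device_count: int) -> int:
    # `or 1` maps a 0 (CPU-only node, CUDA not visible) to 1, so downstream
    # arithmetic never divides by or produces 0.
    return device_count or 1

def dp_shard_size(device_count: int) -> int:
    # Hypothetical use site: the data-parallel shard size derived from the
    # guarded GPU count is always at least 1.
    return safe_gpu_count(device_count)
```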
5. train_eagle3_and_export.sh YAML not self-contained
Line 103: ./launch_train.sh --config "$YAML_FILE" --model "$BASE_MODEL" — The generated YAML doesn't include model_name_or_path, so the config alone can't reproduce the training run despite the comment saying it's "preserved alongside the checkpoint."
6. Truncated help text
examples/speculative_decoding/main.py:111 — ar_validate_steps help string is "AR validation ." — incomplete.
7. Flat config merge can silently collide
main.py:159-163 merges model, data, and training dicts into one flat dict. If any sections share a key name, the later section silently wins.
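One way to make such collisions loud rather than silent is to check for shared keys before merging. This is a hypothetical helper, not code from the PR:

```python
def merge_sections(**sections: dict) -> dict:
    """Merge named config sections, raising if any two share a key name."""
    merged: dict = {}
    seen: dict[str, str] = {}  # key -> name of the section that first set it
    for name, section in sections.items():
        for key in section:
            if key in seen:
                raise ValueError(
                    f"config key {key!r} defined in both {seen[key]!r} and {name!r}"
                )
            seen[key] = name
        merged.update(section)
    return merged
```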
Overall this is a well-structured refactor. Main action items: add omegaconf dependency, fix trust_remote_code on Llama recipes, and consider logging unknown CLI args for migration safety.
This is an AI-assisted review — human sign-off required before merging.
```python
@dataclass
class EagleArguments:
```
This dataclass is removed. Its arguments are a subset of EagleConfig here. Now we directly parse the "eagle" section in YAML into an EagleConfig object.
I think we can do the same for medusa/dflash
Why don't we just move this file to tools/launcher/examples/moonshotai/Kimi-K2.5/?
We can. It's not clear to me whether modelopt_recipes/ or tool/launcher/examples/ is the best place for these yamls. Curious what you think
Same. Why don't we move this to tools/launcher/examples/meta-llama/Llama-3.2-1B-Instruct?
Comments 2, 3, 6 addressed. Other points seem fine to me.
Actionable comments posted: 1
♻️ Duplicate comments (2)
tests/examples/speculative_decoding/test_eagle.py (2)
Lines 246-273: ⚠️ Potential issue | 🔴 Critical

Remove hardcoded `trust_remote_code=True`; make it explicit and default-safe. Hardcoding `trust_remote_code=True` in both `AutoConfig.from_pretrained(...)` and generated YAML model configs is a CRITICAL security violation in this repo's rules. Parameterize it and default to `False`, enabling `True` only in explicitly justified test cases.

Proposed direction:

```diff
-@pytest.mark.parametrize(
-    ("model_source", "use_fake_base"),
+@pytest.mark.parametrize(
+    ("model_source", "use_fake_base", "trust_remote_code"),
     [
-        (None, False),
-        ("moonshotai/Kimi-K2.5", True),
-        ("moonshotai/Kimi-K2-Thinking", True),
-        ("MiniMaxAI/MiniMax-M2.5", True),
+        (None, False, False),
+        ("moonshotai/Kimi-K2.5", True, True),
+        ("moonshotai/Kimi-K2-Thinking", True, True),
+        ("MiniMaxAI/MiniMax-M2.5", True, True),
     ],
 )
 def test_offline_eagle3_training(..., model_source, use_fake_base, trust_remote_code):
     ...
-    cfg = transformers.AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+    cfg = transformers.AutoConfig.from_pretrained(
+        model_path, trust_remote_code=trust_remote_code
+    )
     ...
-    "trust_remote_code": True,
+    "trust_remote_code": trust_remote_code,
```

Apply the same pattern to `test_offline_resume_training_kimi` instead of hardcoding `True`.

As per coding guidelines: flag `trust_remote_code=True` hardcoded for transformers model or tokenizer loading as a CRITICAL security issue. Code should expose it as a caller-configurable parameter defaulting to `False`, not hardcode `True`.

Also applies to: lines 306-320
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/examples/speculative_decoding/test_eagle.py` around lines 246 - 273, Replace hardcoded trust_remote_code=True usages by adding a test-level parameter (default False) and passing that variable to transformers.AutoConfig.from_pretrained and into the generated training_cfg["model"]["trust_remote_code"]; specifically, introduce a local variable (e.g., trust_remote_code=False) at the top of the test and use it when calling AutoConfig.from_pretrained(...) and when building training_cfg["model"] instead of the literal True, and update any related test variants (e.g., test_offline_resume_training_kimi) to set trust_remote_code=True only when explicitly required.
Line 139: ⚠️ Potential issue | 🟠 Major

Update downstream consumers to the new `-mix...` checkpoint path. Line 139 changed the training output to `...-cp{cp_size}-mix{mix_hidden_states}`, but later tests still read `eagle-tinyllama-cp1` (lines 187 and 201), which can break `test_ar_validate` and `test_export_hf_checkpoint`.

Proposed fix:

```diff
-    "--model_path", eagle_output_dir / "eagle-tinyllama-cp1",
+    "--model_path", eagle_output_dir / "eagle-tinyllama-cp1-mixFalse",
 ...
-    "--model_path", eagle_output_dir / "eagle-tinyllama-cp1",
+    "--model_path", eagle_output_dir / "eagle-tinyllama-cp1-mixFalse",
```

As per coding guidelines: all test coverage checks in PRs must pass for new features and examples.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/examples/speculative_decoding/test_eagle.py` at line 139, Tests still reference the old checkpoint path name; update downstream consumers to use the new output_dir format that includes the mix suffix (constructed via eagle_output_dir / f"eagle-tinyllama-cp{cp_size}-mix{mix_hidden_states}"). Locate usages in tests that read the checkpoint (notably the test functions test_ar_validate and test_export_hf_checkpoint) and replace hardcoded "eagle-tinyllama-cp1" (or similar) with the same formatted path logic or a derived variable from eagle_output_dir and cp_size/mix_hidden_states so both creation and consumption use the identical "-cp{cp_size}-mix{mix_hidden_states}" filename.
🧹 Nitpick comments (1)
examples/speculative_decoding/main.py (1)
Lines 131-134: Consider failing on unknown CLI args instead of silently ignoring them. Lines 131-134 currently accept typos/legacy flags and continue. This can hide misconfiguration in the YAML migration path.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/main.py` around lines 131 - 134, The CLI currently uses p.parse_known_args() which swallows typos/legacy flags; change to p.parse_args() or explicitly fail when unknown is non-empty: after calling parse_known_args(), if unknown is truthy call parser.error(...) or raise SystemExit with a clear message so the script fails fast instead of printing via print_rank_0 and proceeding; update the handling around parse_known_args()/parse_args() and remove the print_rank_0 fallback so callers relying on args.config and args.model get validated input.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/speculative_decoding/main.py`:
- Around line 256-260: The code currently calls
os.path.isfile(data_args.draft_vocab_cache) which raises a TypeError when
data_args.draft_vocab_cache is None; add a guard to check for a truthy/non-None
draft_vocab_cache before performing os.path.isfile. Specifically, in the block
that assigns model.eagle_module.d2t via torch.load, first assert or raise a
clear error if data_args.draft_vocab_cache is falsy (e.g., None or empty string)
with a descriptive message, then proceed to call
os.path.isfile(data_args.draft_vocab_cache) and
torch.load(data_args.draft_vocab_cache, weights_only=True) only when the path
exists.
---
Duplicate comments:
In `@tests/examples/speculative_decoding/test_eagle.py`:
- Around line 246-273: Replace hardcoded trust_remote_code=True usages by adding
a test-level parameter (default False) and passing that variable to
transformers.AutoConfig.from_pretrained and into the generated
training_cfg["model"]["trust_remote_code"]; specifically, introduce a local
variable (e.g., trust_remote_code=False) at the top of the test and use it when
calling AutoConfig.from_pretrained(...) and when building training_cfg["model"]
instead of the literal True, and update any related test variants (e.g.,
test_offline_resume_training_kimi) to set trust_remote_code=True only when
explicitly required.
- Line 139: Tests still reference the old checkpoint path name; update
downstream consumers to use the new output_dir format that includes the mix
suffix (constructed via eagle_output_dir /
f"eagle-tinyllama-cp{cp_size}-mix{mix_hidden_states}"). Locate usages in tests
that read the checkpoint (notably the test functions test_ar_validate and
test_export_hf_checkpoint) and replace hardcoded "eagle-tinyllama-cp1" (or
similar) with the same formatted path logic or a derived variable from
eagle_output_dir and cp_size/mix_hidden_states so both creation and consumption
use the identical "-cp{cp_size}-mix{mix_hidden_states}" filename.
---
Nitpick comments:
In `@examples/speculative_decoding/main.py`:
- Around line 131-134: The CLI currently uses p.parse_known_args() which
swallows typos/legacy flags; change to p.parse_args() or explicitly fail when
unknown is non-empty: after calling parse_known_args(), if unknown is truthy
call parser.error(...) or raise SystemExit with a clear message so the script
fails fast instead of printing via print_rank_0 and proceeding; update the
handling around parse_known_args()/parse_args() and remove the print_rank_0
fallback so callers relying on args.config and args.model get validated input.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: e99433a4-e316-439e-8872-ae66d1eeea96
📒 Files selected for processing (4)

- examples/speculative_decoding/main.py
- modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml
- modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml
- tests/examples/speculative_decoding/test_eagle.py
✅ Files skipped from review due to trivial changes (1)
- modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml
```diff
     if not os.path.isfile(data_args.draft_vocab_cache):
         raise FileNotFoundError(
             f"Draft vocab cache provided but not found: {data_args.draft_vocab_cache}"
         )
-    model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache)
+    model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True)
```
Guard draft_vocab_cache before filesystem checks to avoid TypeError.
On Line 256, os.path.isfile(data_args.draft_vocab_cache) will throw when draft_vocab_cache is None. Fail fast with a clear message before checking file existence.
Proposed fix:

```diff
 # Load draft vocab cache if the draft model uses a compressed vocabulary
 if model.eagle_config.draft_vocab_size < model.eagle_config.vocab_size:
-    if not os.path.isfile(data_args.draft_vocab_cache):
+    if not data_args.draft_vocab_cache:
+        raise ValueError(
+            "data.draft_vocab_cache must be set when draft_vocab_size < vocab_size."
+        )
+    if not os.path.isfile(data_args.draft_vocab_cache):
         raise FileNotFoundError(
             f"Draft vocab cache provided but not found: {data_args.draft_vocab_cache}"
         )
     model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/speculative_decoding/main.py` around lines 256 - 260, The code
currently calls os.path.isfile(data_args.draft_vocab_cache) which raises a
TypeError when data_args.draft_vocab_cache is None; add a guard to check for a
truthy/non-None draft_vocab_cache before performing os.path.isfile.
Specifically, in the block that assigns model.eagle_module.d2t via torch.load,
first assert or raise a clear error if data_args.draft_vocab_cache is falsy
(e.g., None or empty string) with a descriptive message, then proceed to call
os.path.isfile(data_args.draft_vocab_cache) and
torch.load(data_args.draft_vocab_cache, weights_only=True) only when the path
exists.
What does this PR do?
Refactors EAGLE3 training to use a unified YAML-based config system.
Type of change: Refactor
Changes
- `launch_train.sh` now accepts `--config <yaml>` (required) and `--model <path>` (optional override). All other settings live in YAML.
- Recipes ship under `modelopt_recipes/speculative_decoding/` with `__base__` inheritance support.
- Removed `eagle_config.json` and `fsdp_config.json`; architecture config is now nested under `eagle.eagle_architecture_config` in YAML.
- Updated `examples/speculative_decoding/README.md` for the new interface.

Usage
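As an illustration of the shape described above, a hypothetical minimal recipe; the file name, section contents, and values here are examples for this write-up, not one of the shipped recipes:

```yaml
# my_eagle3_recipe.yaml - hypothetical example, not a shipped recipe
__base__: _base_eagle3.yaml        # inherit defaults; keys below override them
model:
  model_name_or_path: meta-llama/Llama-3.2-1B
  trust_remote_code: false         # opt in explicitly only when a model needs it
training:
  output_dir: ./ckpts/llama3-eagle3
```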
Summary by CodeRabbit
Documentation
New Features
Chores
Tests