Changes from all commits (23 commits)
3f89ea9
add: DFlash block diffusion speculative decoding
ChenhanYu Mar 27, 2026
190cb3a
fix: rewrite DFlash to match SpecForge reference
ChenhanYu Mar 28, 2026
b7a2a7b
fix: correct mask_token_id and base model forward dispatch
ChenhanYu Mar 29, 2026
a310d96
add: auto-detect mask_token_id for DFlash across model families
ChenhanYu Mar 29, 2026
972dfaa
fix: prevent DDP deadlock during AR validation
ChenhanYu Mar 29, 2026
6c4eb80
fix: avoid DynamicModule dispatch loop in forward/training paths
ChenhanYu Mar 29, 2026
2c42363
fix: revert training/eval to super().forward() matching EAGLE pattern
ChenhanYu Mar 30, 2026
a279960
fix: DDP deadlock when no valid loss positions on a rank
ChenhanYu Mar 30, 2026
cbddc30
add: logit distillation option for DFlash training
ChenhanYu Mar 30, 2026
c53a66a
fix: print training accuracy to console at each log step
ChenhanYu Mar 30, 2026
2eabf57
fix: use response-only loss mask for DFlash training
ChenhanYu Mar 31, 2026
2a16232
fix: apply assistant_masks to labels in LanguageDataCollator
ChenhanYu Mar 31, 2026
e3b9930
fix: robust response-only loss mask via regex assistant span detection
ChenhanYu Mar 31, 2026
07066c2
docs: add DFlash section to speculative decoding README
ChenhanYu Mar 31, 2026
a32de63
fix: resolve DFlash components from base model architecture
ChenhanYu Mar 31, 2026
6a6a9ca
fix: enable response-only loss mask for DFlash training
ChenhanYu Mar 31, 2026
a777849
add: DFlash launcher example for Qwen3-8B
ChenhanYu Apr 1, 2026
2c56aca
fix: inline values in DFlash launcher YAML for --yaml compatibility
ChenhanYu Apr 1, 2026
306fc3e
add: unit tests for DFlash speculative decoding
ChenhanYu Apr 1, 2026
c4a3ecb
fix: add docstrings to DFlash classes for coverage check
ChenhanYu Apr 1, 2026
1c23ced
add: AR validation step to DFlash launcher pipeline
ChenhanYu Apr 1, 2026
38450b0
fix: split DFlash tests into CPU (unit) and GPU tests
ChenhanYu Apr 1, 2026
4c2fc77
fix: correct DFlash attention mask test for reverse-causal pattern
ChenhanYu Apr 1, 2026
95 changes: 88 additions & 7 deletions examples/speculative_decoding/README.md
@@ -319,15 +319,96 @@ trainer.save_state()
trainer.save_model("<path to the output directory>")
```

## DFlash: Block Diffusion for Flash Speculative Decoding

DFlash ([arXiv:2602.06036](https://arxiv.org/abs/2602.06036)) is a parallel speculative decoding method that predicts multiple tokens simultaneously using block diffusion. Unlike autoregressive methods (EAGLE, Medusa) that draft one token at a time, DFlash predicts an entire block of tokens in parallel, then iteratively denoises them.

### Architecture

DFlash uses three key mechanisms:

- **Feature Fusion**: Multi-layer hidden states from the target model are projected via a fully-connected layer and RMSNorm to create context features
- **KV Injection**: Context features are injected as K/V in every draft decoder layer, while Q comes from the noise embeddings. QK-Norm (RMSNorm on Q and K before RoPE) stabilizes attention
- **Parallel Drafting**: Within each block of size B, unknown positions are filled with the `mask_token_id` token; only the first position of each block holds a real token. The attention mask lets noise tokens attend to all context tokens from previous blocks, and causally to earlier positions within the same block
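
The parallel-drafting attention pattern described above can be sketched as a standalone illustration. This follows the causal-within-block description literally; the helper name `build_dflash_mask` and the exact layout are assumptions, not the library's implementation, and the real mask pattern may differ:

```python
import torch

def build_dflash_mask(num_ctx: int, block_size: int) -> torch.Tensor:
    """Boolean attention mask for one draft block of `block_size` noise tokens.

    Each noise token may attend to every context token from previous blocks,
    and causally to earlier positions within its own block.
    True = attend, False = masked out.
    """
    total = num_ctx + block_size
    mask = torch.zeros(block_size, total, dtype=torch.bool)
    mask[:, :num_ctx] = True  # full visibility of prior-block context
    # causal pattern inside the current block
    mask[:, num_ctx:] = torch.tril(torch.ones(block_size, block_size, dtype=torch.bool))
    return mask

# 4 context tokens already drafted, current block of 3 noise tokens
m = build_dflash_mask(num_ctx=4, block_size=3)
```

Each row of `m` is one noise position's visibility: all four context columns are open, and within the block a position sees only itself and earlier block positions.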

### Training

```bash
./launch_train.sh --model $BASE_MODEL \
--output_dir $OUTPUT_DIR \
--data input_conversations/train.jsonl \
--num_epochs $NUM_EPOCH \
--mode dflash \
--dflash_block_size 16 \
--dflash_num_layers 5
```

Key arguments:

| Flag | Default | Description |
|------|---------|-------------|
| `--mode dflash` | - | Enable DFlash mode |
| `--dflash_block_size` | 16 | Block size for parallel prediction |
| `--dflash_num_layers` | 5 | Number of decoder layers in draft module |
| `--dflash_config` | None | Path to JSON config for custom architecture |
| `--dflash_mask_token_id` | auto | Mask token ID (auto-detected from model) |
| `--dflash_disable_torch_compile` | False | Disable torch.compile |
| `--dflash_use_logit_distillation` | False | Use KD from target model logits instead of hard CE |

### mask_token_id

The `mask_token_id` is critical for DFlash training and inference. It must be consistent between training and deployment. Auto-detection logic:

| Model Family | mask_token_id | Source |
|-------------|---------------|--------|
| Qwen3.5 | 248070 | Built-in `[MASK]` token |
| Qwen3 (8B) | 151643 | `eos_token_id` |
| Llama 3 | 128002 | `reserved_special_token_0` |
| Others | `pad_token_id` | Fallback |

Override with `--dflash_mask_token_id <id>` if auto-detection is incorrect.
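
The detection order in the table above can be sketched as follows. This is a hypothetical helper (`detect_mask_token_id` is not a real API, and the real logic inspects the model config rather than a bare string):

```python
from types import SimpleNamespace

def detect_mask_token_id(model_type: str, tokenizer) -> int:
    """Mirror the auto-detection table: family-specific IDs first,
    then fall back to pad_token_id."""
    if model_type == "qwen3.5":
        return 248070                    # built-in [MASK] token
    if model_type == "qwen3":
        return tokenizer.eos_token_id    # 151643 for Qwen3-8B
    if model_type.startswith("llama3"):
        return 128002                    # reserved_special_token_0
    return tokenizer.pad_token_id        # fallback

# dummy tokenizer standing in for a real Qwen3-8B tokenizer
tok = SimpleNamespace(eos_token_id=151643, pad_token_id=0)
```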

### Configuring Draft Model

Similar to EAGLE, provide a JSON config to customize the draft architecture:

```json
{
"num_hidden_layers": 5,
"rms_norm_eps": 1e-6
}
```

Model dimensions (hidden_size, num_attention_heads, etc.) are automatically inherited from the base model.
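
The merge behavior can be sketched like this (hypothetical helper; the actual convert path does the equivalent `setdefault` merge internally):

```python
import json

def merge_dflash_config(user_config_path, base_model_config, num_layers_flag=5):
    """User JSON overrides win; anything unspecified inherits from the base model."""
    custom = {}
    if user_config_path:
        with open(user_config_path) as f:
            custom = json.load(f)
    custom.setdefault("num_hidden_layers", num_layers_flag)
    # dimensions not specified by the user fall back to the base model
    for key in ("hidden_size", "num_attention_heads", "num_key_value_heads"):
        custom.setdefault(key, base_model_config[key])
    return custom

base = {"hidden_size": 4096, "num_attention_heads": 32, "num_key_value_heads": 8}
cfg = merge_dflash_config(None, base)
```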

### Current Status (WIP)

| Feature | Status |
|---------|--------|
| Architecture (Feature Fusion, KV Injection, Parallel Drafting) | Working |
| Online training with HF Trainer | Working |
| Inference / AR validation (`pseudo_speculative_generate`) | Working |
| z-lab checkpoint loading and inference (AR 7-9) | Working |
| Logit distillation option | Working |
| Response-only loss masking | Working |
| DDP training | Working (with `find_unused_parameters=True`) |

**Known gap**: Training with ModelOpt achieves ~35% per-token accuracy (matching SpecForge's ~30%), but acceptance rate (AR) is lower than SpecForge-trained checkpoints (1.15 vs 1.95). Investigation shows the **data pipeline** differs significantly:

- SpecForge uses its own tokenizer template with system prompt and response-only loss mask
- ModelOpt's `LanguageDataCollator` uses `apply_chat_template` with different formatting

Aligning the data pipeline is the next step to close the AR gap.
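
The response-only loss mask both pipelines aim for can be sketched as follows. This is a minimal illustration: `response_only_labels` is a hypothetical helper, and the real collator derives the assistant mask from its chat template (e.g. via `apply_chat_template(..., return_assistant_tokens_mask=True)`):

```python
import torch

IGNORE_INDEX = -100  # positions excluded from cross-entropy loss

def response_only_labels(input_ids: torch.Tensor, assistant_mask: torch.Tensor) -> torch.Tensor:
    """Keep loss only on assistant-response tokens; mask everything else.

    `assistant_mask` is 1 where a token belongs to an assistant span.
    """
    labels = input_ids.clone()
    labels[assistant_mask == 0] = IGNORE_INDEX
    return labels

ids = torch.tensor([10, 11, 12, 13, 14])
mask = torch.tensor([0, 0, 1, 1, 0])  # tokens 12, 13 are the assistant response
labels = response_only_labels(ids, mask)
```

System-prompt and user tokens become `-100`, so only the response contributes to the draft-model loss.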

## Support Matrix

| Model | Medusa | EAGLE1/2 | EAGLE3 |
| :---: | :---: | :---: | :---: |
| LLAMA 2 | ✅ | ✅ | ✅ |
| LLAMA 3, 3.1 | ✅ | ✅ | ✅ |
| Mistral | ✅ | ✅ | ✅ |
| Phi 3 | ✅ | ✅ | ✅ |
| QWen 1.5,2,2.5,3 | ✅ | ✅ | ✅ |
| Model | Medusa | EAGLE1/2 | EAGLE3 | DFlash |
| :---: | :---: | :---: | :---: | :---: |
| LLAMA 2 | ✅ | ✅ | ✅ | ✅ |
| LLAMA 3, 3.1 | ✅ | ✅ | ✅ | ✅ |
| Mistral | ✅ | ✅ | ✅ | ✅ |
| Phi 3 | ✅ | ✅ | ✅ | ✅ |
| QWen 1.5,2,2.5,3 | ✅ | ✅ | ✅ | ✅ |

## Speculation Module Checkpoints

44 changes: 30 additions & 14 deletions examples/speculative_decoding/eagle_utils.py
@@ -139,6 +139,7 @@ def make_eagle_supervised_data_module(
tokenizer: transformers.PreTrainedTokenizer,
data_args,
train_len=None,
answer_only_loss=False,
) -> dict:
if data_args.offline_data_path is None:
train_dataset = ShardedDataset("json", data_files=data_args.data_path)
Expand All @@ -148,6 +149,7 @@ def make_eagle_supervised_data_module(
tokenizer=tokenizer,
train_len=train_len,
return_labels=True,
answer_only_loss=answer_only_loss,
)
else:
data_collator = VisionLanguageDataCollator(
@@ -203,6 +205,12 @@ def on_log(self, args, state, control, **kwargs):
if not hasattr(state, "training_accs") or len(state.training_accs) == 0:
return control
average_acc = np.mean(state.training_accs, axis=0)
# Always print accuracy to console
try:
acc_str = ", ".join(f"{a:.4f}" for a in np.array(average_acc).flatten())
print_rank_0(f"Step {state.global_step} Training Acc: [{acc_str}]")
except Exception:
print_rank_0(f"Step {state.global_step} Training Acc: {average_acc}")
if self.estimate_ar:
# Calculate mean training AR since last log
# NOTE: This is only an estimate of the real AR.
@@ -235,23 +243,31 @@ def on_log(self, args, state, control, **kwargs):
return control

def on_step_end(self, args, state, control, **kwargs):
"""Run AR validation periodically, if available."""
"""Run AR validation periodically, if available.

Only runs on rank 0 to avoid DDP deadlock — other ranks skip and
synchronize via barrier.
"""
if self.ar_validate_steps <= 0:
return control
if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0:
print_rank_0("Running AR validation...")
try:
ars = validate_ar(
model=kwargs["model"],
tokenizer=kwargs["processing_class"],
ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
device=kwargs["model"].device,
)
print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
if wandb and is_master():
wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
except Exception:
print_rank_0("AR validation not available.")
if is_master():
print_rank_0("Running AR validation...")
try:
ars = validate_ar(
model=kwargs["model"],
tokenizer=kwargs["processing_class"],
ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
device=kwargs["model"].device,
)
print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
if wandb:
wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
except Exception:
print_rank_0("AR validation not available.")
# Barrier to synchronize all ranks after validation
if torch.distributed.is_initialized():
torch.distributed.barrier()
Comment on lines +254 to +270 (Contributor)

⚠️ Potential issue | 🟡 Minor
Move dataset loading outside the master-only validation block for efficiency.

The validation code is safe as-is: validate_ar() contains no distributed collectives, and pop_and_gather_aux_hiddens() only performs local tensor concatenation (not a distributed gather). The barrier() correctly synchronizes ranks after rank-0 validation.

However, reloading the validation dataset inside the master-only block wastes idle time on non-master ranks. Load the dataset once before the validation check:

```python
if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0:
    ds = load_dataset("HuggingFaceH4/mt_bench_prompts")["train"]
    if is_master():
        print_rank_0("Running AR validation...")
        try:
            ars = validate_ar(
                model=kwargs["model"],
                tokenizer=kwargs["processing_class"],
                ds=ds,
                device=kwargs["model"].device,
            )
            print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
            if wandb:
                wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
        except Exception:
            print_rank_0("AR validation not available.")
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
```

return control


44 changes: 37 additions & 7 deletions examples/speculative_decoding/launch_train.sh
@@ -134,6 +134,22 @@ while [ $# -gt 0 ]; do
if [[ "$1" != *=* ]]; then shift; fi
FSDP="${1#*=}"
;;
--dflash_block_size*)
if [[ "$1" != *=* ]]; then shift; fi
DFLASH_BLOCK_SIZE="${1#*=}"
;;
--dflash_num_layers*)
if [[ "$1" != *=* ]]; then shift; fi
DFLASH_NUM_LAYERS="${1#*=}"
;;
--dflash_config*)
if [[ "$1" != *=* ]]; then shift; fi
DFLASH_CONFIG="${1#*=}"
;;
--dflash_mask_token_id*)
if [[ "$1" != *=* ]]; then shift; fi
DFLASH_MASK_TOKEN_ID="${1#*=}"
;;
*)
>&2 printf "Error: Invalid argument ${1#*=}\n"
exit 1
@@ -195,8 +211,20 @@ if [[ "$MODE" == "eagle3" ]]; then
else
SPECULATIVE_ARGS=""
fi
elif [[ "$MODE" == "dflash" ]]; then
DFLASH_BLOCK_SIZE=${DFLASH_BLOCK_SIZE:-16}
DFLASH_NUM_LAYERS=${DFLASH_NUM_LAYERS:-5}
SPECULATIVE_ARGS="--dflash_block_size $DFLASH_BLOCK_SIZE --dflash_num_layers $DFLASH_NUM_LAYERS --dflash_disable_torch_compile"
if [[ -n "$DFLASH_CONFIG" ]]; then
SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_config $DFLASH_CONFIG"
fi
if [[ -n "$DFLASH_MASK_TOKEN_ID" ]]; then
SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_mask_token_id $DFLASH_MASK_TOKEN_ID"
fi
# DFlash uses DDP instead of FSDP
FSDP_ARGS="--ddp_find_unused_parameters True --ddp_timeout 300"
else
echo "Only eagle3 supported for now!"
echo "Unsupported mode: $MODE. Supported: eagle3, dflash"
exit 1
fi

@@ -218,12 +246,14 @@ else
VLM_ARGS=""
fi

if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then
#Use FSDP2 when multi GPU available
FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json"
else
#Otherwise, single GPU training
FSDP_ARGS=""
if [[ "$MODE" != "dflash" ]]; then
if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then
#Use FSDP2 when multi GPU available
FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json"
else
#Otherwise, single GPU training
FSDP_ARGS=""
fi
fi


54 changes: 50 additions & 4 deletions examples/speculative_decoding/main.py
@@ -104,7 +104,7 @@ class TrainingArguments(transformers.TrainingArguments):
)
dataloader_drop_last: bool = field(default=True)
bf16: bool = field(default=True)
mode: Literal["eagle3", "medusa"] = "eagle3"
mode: Literal["eagle3", "medusa", "dflash"] = "eagle3"
estimate_ar: bool = field(
default=False, metadata={"help": "Whether to estimate AR during training for logging."}
)
@@ -144,6 +144,32 @@ class EagleArguments:
)


@dataclass
class DFlashArguments:
dflash_block_size: int = field(
default=16, metadata={"help": "Block size for DFlash parallel prediction."}
)
dflash_num_layers: int = field(
default=5, metadata={"help": "Number of decoder layers in the DFlash draft module."}
)
dflash_config: str | None = field(default=None, metadata={"help": "Path to dflash_config.json"})
dflash_disable_torch_compile: bool = field(
default=False,
metadata={"help": "Disable torch.compile on DFlash forward/loss methods."},
)
dflash_mask_token_id: int | None = field(
default=None,
metadata={"help": "Mask token ID for DFlash. If not set, auto-detected from model."},
)
dflash_use_logit_distillation: bool = field(
default=False,
metadata={
"help": "Use logit distillation (KD from target model) instead of hard CE. "
"Enables training with data not synthesized by the target model."
},
)


def train():
parser = transformers.HfArgumentParser(
(
@@ -152,9 +178,10 @@ def train():
TrainingArguments,
MedusaArguments,
EagleArguments,
DFlashArguments,
)
)
model_args, data_args, training_args, medusa_args, eagle_args = (
model_args, data_args, training_args, medusa_args, eagle_args, dflash_args = (
parser.parse_args_into_dataclasses()
)
if not data_args.data_path and not data_args.offline_data_path:
@@ -236,13 +263,32 @@ def train():
)
model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache)
print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.")
elif training_args.mode == "dflash":
custom_config = (
json.load(open(dflash_args.dflash_config)) if dflash_args.dflash_config else {}
)
custom_config.setdefault("num_hidden_layers", dflash_args.dflash_num_layers)
if dflash_args.dflash_mask_token_id is not None:
custom_config["mask_token_id"] = dflash_args.dflash_mask_token_id

config = {
"dflash_block_size": dflash_args.dflash_block_size,
"dflash_use_torch_compile": not dflash_args.dflash_disable_torch_compile,
"dflash_self_logit_distillation": dflash_args.dflash_use_logit_distillation,
"dflash_architecture_config": custom_config,
}

mtsp.convert(model, [("dflash", config)])
else:
raise Exception(f"{training_args.mode} is not supported!")

print_rank_0("Loading dataset...")
if training_args.mode == "eagle3":
if training_args.mode in ("eagle3", "dflash"):
data_module = make_eagle_supervised_data_module(
tokenizer, data_args, train_len=training_args.training_seq_len
tokenizer,
data_args,
train_len=training_args.training_seq_len,
answer_only_loss=(training_args.mode == "dflash"),
)

trainer = EagleTrainerWithAccLog(
48 changes: 48 additions & 0 deletions modelopt/torch/speculative/config.py
@@ -46,6 +46,54 @@
}


def _get_dflash_default_config():
from .dflash.default_config import default_dflash_config

return default_dflash_config


DFLASH_DEFAULT_CFG = {
"algorithm": "dflash",
"config": {
"dflash_architecture_config": {}, # merged with default at convert time
},
}


class DFlashConfig(ModeloptBaseConfig):
"""DFlash config for block-wise parallel speculative decoding."""

dflash_block_size: int = ModeloptField(
default=16,
description="Block size for parallel prediction. Draft predicts this many tokens per block.",
)

dflash_freeze_base_model: bool = ModeloptField(
default=True, description="Whether to freeze base model during DFlash module training."
)

dflash_self_logit_distillation: bool = ModeloptField(
default=True, description="Whether to use logit distillation from base model."
)

dflash_loss_decay_factor: float = ModeloptField(
default=0.9, description="Decay factor for per-block loss weighting."
)

dflash_report_acc: bool = ModeloptField(
default=True, description="Whether to report eval accuracy."
)

dflash_architecture_config: dict = ModeloptField(
default={}, description="Config for the DFlash draft module architecture."
)

dflash_use_torch_compile: bool = ModeloptField(
default=True,
description="Whether to use torch.compile on DFlash forward/loss methods.",
)
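
For illustration, the per-block weighting implied by `dflash_loss_decay_factor` might look like the following sketch (the normalization and exact scheme are assumptions; the actual weighting lives in the DFlash loss):

```python
def per_block_loss_weights(block_size: int, decay: float = 0.9) -> list[float]:
    """Later positions in a block contribute less to the loss,
    down-weighted geometrically by `decay` and normalized to sum to 1."""
    raw = [decay**i for i in range(block_size)]
    total = sum(raw)
    return [w / total for w in raw]

w = per_block_loss_weights(4)  # weights for a block of 4 draft positions
```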


class MedusaConfig(ModeloptBaseConfig):
"""Medusa config."""
