Skip to content

Commit ec0d984

Browse files
authored
[nvbug/5280806][fix] Fix 2 model spec decode flow (#4807)
Signed-off-by: Mike Iovine <[email protected]>
1 parent 9e05613 commit ec0d984

File tree

6 files changed

+37
-25
lines changed

6 files changed

+37
-25
lines changed

examples/pytorch/quickstart_advanced.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def add_llm_args(parser):
110110
parser.add_argument('--spec_decode_nextn', type=int, default=1)
111111
parser.add_argument('--eagle_model_dir', type=str, default=None)
112112
parser.add_argument('--max_matching_ngram_size', type=int, default=5)
113+
parser.add_argument('--use_one_model', default=False, action='store_true')
113114

114115
# Relaxed acceptance
115116
parser.add_argument('--use_relaxed_acceptance_for_thinking',
@@ -139,6 +140,11 @@ def setup_llm(args):
139140
) if args.spec_decode_algo is not None else None
140141

141142
if spec_decode_algo == 'MTP':
143+
if not args.use_one_model:
144+
print(
145+
"MTP only supports one model style spec decode; ignoring default use_one_model=False"
146+
)
147+
142148
spec_config = MTPDecodingConfig(
143149
num_nextn_predict_layers=args.spec_decode_nextn,
144150
use_relaxed_acceptance_for_thinking=args.
@@ -148,7 +154,8 @@ def setup_llm(args):
148154
elif spec_decode_algo == "EAGLE3":
149155
spec_config = EagleDecodingConfig(
150156
max_draft_len=args.spec_decode_nextn,
151-
pytorch_eagle_weights_path=args.eagle_model_dir)
157+
pytorch_eagle_weights_path=args.eagle_model_dir,
158+
eagle3_one_model=args.use_one_model)
152159
elif spec_decode_algo == "NGRAM":
153160
spec_config = NGramDecodingConfig(
154161
prompt_lookup_num_tokens=args.spec_decode_nextn,

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,8 +1242,6 @@ def forward(
12421242

12431243
hidden_states, hidden_states_to_save = self.norm(
12441244
hidden_states, residual)
1245-
if self.spec_config.spec_dec_mode.is_eagle3():
1246-
spec_metadata.maybe_capture_hidden_states(1, hidden_states_to_save)
12471245
return hidden_states, hidden_states_to_save
12481246

12491247

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1229,7 +1229,15 @@ def _prepare_tp_inputs(
12291229
num_draft_tokens = len(request.py_draft_tokens)
12301230
past_seen_token_num = request.max_beam_num_tokens - 1
12311231
draft_lens.append(num_draft_tokens)
1232-
prompt_lengths.append(request.py_prompt_len)
1232+
1233+
if self.is_spec_decode and self.spec_config.spec_dec_mode.extend_ctx(
1234+
self.attn_backend):
1235+
# We're treating the prompt lengths as context requests here, so
1236+
# the the prompt lens should not include the cached tokens.
1237+
prompt_lengths.append(1 + num_draft_tokens)
1238+
else:
1239+
prompt_lengths.append(request.py_prompt_len)
1240+
12331241
sequence_lengths.append(1 + num_draft_tokens)
12341242
gather_ids.extend(
12351243
list(

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import copy
22
from dataclasses import dataclass, field
33
from enum import IntEnum, auto
4-
from typing import Dict, List, Optional
4+
from typing import Dict, List, Optional, Type
55

66
import torch
77

@@ -59,7 +59,7 @@ def need_load_draft_weights(self):
5959
def has_spec_decoder(self):
6060
return self.is_mtp() or self.is_eagle3() or self.is_eagle3_one_model()
6161

62-
def extend_ctx(self, attention_backend: AttentionBackend):
62+
def extend_ctx(self, attention_backend: Type[AttentionBackend]):
6363
"""
6464
If true, treat generation requests with draft tokens as
6565
chunked context requests at the kernel level. Required for
@@ -68,7 +68,7 @@ def extend_ctx(self, attention_backend: AttentionBackend):
6868

6969
# Fixme: only trtllm attention backend supports eagle3 generation-phase kernels on blackwell.
7070
return (self.is_eagle3()
71-
and not (isinstance(attention_backend, TrtllmAttention)
71+
and not (issubclass(attention_backend, TrtllmAttention)
7272
and get_sm_version() == 100)) or self.is_ngram()
7373

7474
@staticmethod

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,6 @@ examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoder] SKIP (http
382382
examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoderplus] SKIP (https://nvbugs/5144931)
383383
unittest/_torch -k "not (modeling or multi_gpu or auto_deploy)" SKIP (https://nvbugs/5280806)
384384
examples/test_whisper.py::test_llm_whisper_general[large-v3-disable_gemm_plugin-disable_attention_plugin-disable_weight_only-float16-nb:1-use_python_runtime] SKIP (https://nvbugs/5244570)
385-
unittest/_torch/speculative/test_eagle3.py SKIP (https://nvbugs/5280806)
386385
triton_server/test_triton_rcca.py::test_mistral_beam_search[rcca_4714407-True-10---False-True-False-0-128-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap--guaranteed_no_evict---1-1-1-False-ensemble] SKIP (https://nvbugs/5240060)
387386
triton_server/test_triton.py::test_triton_extensive[triton-extensive] SKIP
388387
triton_server/test_triton.py::test_gpt_speculative_decoding[gpt-speculative-decoding] SKIP

tests/unittest/_torch/speculative/test_eagle3.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from tensorrt_llm import SamplingParams
99
from tensorrt_llm._torch import LLM
10-
from tensorrt_llm.llmapi import BuildConfig, EagleDecodingConfig, KvCacheConfig
10+
from tensorrt_llm.llmapi import EagleDecodingConfig, KvCacheConfig
1111

1212
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
1313
from utils.llm_data import llm_models_root
@@ -38,20 +38,19 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str):
3838

3939
draft_len = 4
4040
spec_config = EagleDecodingConfig(
41-
max_draft_len=draft_len, pytorch_eagle_weights_path=eagle_model_dir)
42-
43-
build_config = None
44-
if attn_backend == "FLASHINFER":
45-
# TODO: fix max seq len logic in py_executor_creator. We will get
46-
# an illegal memory access if this is not set to a preset value,
47-
# which is definitely not right.
48-
build_config = BuildConfig(max_seq_len=2048)
49-
50-
llm_spec = LLM(model=target_model_dir,
51-
**pytorch_config,
52-
kv_cache_config=kv_cache_config,
53-
speculative_config=spec_config,
54-
build_config=build_config)
41+
max_draft_len=draft_len,
42+
pytorch_eagle_weights_path=eagle_model_dir,
43+
# Llama 3 does not support one model eagle.
44+
eagle3_one_model=False)
45+
46+
llm_spec = LLM(
47+
model=target_model_dir,
48+
**pytorch_config,
49+
kv_cache_config=kv_cache_config,
50+
speculative_config=spec_config,
51+
# TODO: https://nvbugspro.nvidia.com/bug/5319281
52+
max_num_tokens=2048,
53+
max_seq_len=2048)
5554

5655
sampling_params = SamplingParams(
5756
max_tokens=32,
@@ -78,7 +77,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str):
7877
num_tokens = len(new_tokens)
7978

8079
accept_rate = num_accepted / num_drafted
81-
assert accept_rate > 0.25
80+
assert accept_rate > 0.15
8281

8382
prompts = [
8483
"The capital of France is", "The president of the United States is"
@@ -90,7 +89,8 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str):
9089
llm_ref = LLM(model=target_model_dir,
9190
**pytorch_config,
9291
kv_cache_config=kv_cache_config,
93-
build_config=build_config)
92+
max_num_tokens=2048,
93+
max_seq_len=2048)
9494

9595
results_ref = llm_ref.generate(prompts, sampling_params)
9696
generated_text_ref = [result.outputs[0].text for result in results_ref]

0 commit comments

Comments
 (0)