Skip to content

Commit 13a6565

Browse files
committed
[None][feat] Make 2-model spec dec use the 1-model kernels (Hopper)
Signed-off-by: Mike Iovine <[email protected]>
1 parent f2ebaf2 commit 13a6565

File tree

2 files changed

+7
-13
lines changed

2 files changed

+7
-13
lines changed

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 5 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55

66
import torch
77

8-
from ..._utils import get_sm_version
98
from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
109
from ..pyexecutor.resource_manager import BaseResourceManager
1110

@@ -113,17 +112,11 @@ def extend_ctx(self, attention_backend: Type[AttentionBackend]):
113112
# 1-model has separate logic for handling draft tokens
114113
return False
115114

116-
if issubclass(attention_backend,
117-
TrtllmAttention) and self.is_mtp_eagle():
118-
# TRTLLM MLA does not work with the chunked context mode.
119-
return False
120-
121-
return not issubclass(attention_backend,
122-
TrtllmAttention) or get_sm_version() != 100
115+
return not issubclass(attention_backend, TrtllmAttention)
123116

124117
def attention_need_spec_dec_mode(
125118
self,
126-
spec_resource_manager: BaseResourceManager,
119+
spec_resource_manager: Optional[BaseResourceManager],
127120
is_draft_model: bool,
128121
attention_backend: Type[AttentionBackend],
129122
use_chain_drafter: bool,
@@ -133,9 +126,10 @@ def attention_need_spec_dec_mode(
133126
If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode).
134127
"""
135128
is_trtllm_attention = issubclass(attention_backend, TrtllmAttention)
136-
return self.is_eagle3_one_model() or (
129+
130+
return self.is_eagle3_one_model() or not is_draft_model or (
137131
self.is_eagle3() and spec_resource_manager.is_first_draft
138-
and is_trtllm_attention and use_chain_drafter and is_draft_model)
132+
and is_trtllm_attention)
139133

140134
@staticmethod
141135
def from_string(name: Optional[str]) -> "SpeculativeDecodingMode":

tests/unittest/_torch/speculative/test_eagle3.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -206,7 +206,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
206206
num_tokens = len(new_tokens)
207207

208208
accept_rate = num_accepted / num_drafted
209-
assert accept_rate > 0.15
209+
assert accept_rate > 0.10
210210

211211
# Output tests
212212
sampling_params = SamplingParams(max_tokens=10, temperature=0)
@@ -252,7 +252,7 @@ def test_llama_eagle3_long_prompt(use_cuda_graph):
252252
speculative_config=spec_config,
253253
max_batch_size=1,
254254
cuda_graph_config=cuda_graph_config,
255-
disable_overlap_scheduler=False)
255+
disable_overlap_scheduler=True)
256256

257257
prompt = [", ".join(str(i) for i in range(1000))]
258258

0 commit comments

Comments (0)