Skip to content

Commit 6151a4c

Browse files
authored
[None][feat] Add simple optimizations for MTP 2-model (#9176)
Signed-off-by: Mike Iovine <[email protected]>
1 parent 24f5cd7 commit 6151a4c

File tree

2 files changed

+1
-4
lines changed

2 files changed

+1
-4
lines changed

tensorrt_llm/_torch/speculative/drafting_loops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def save_metadata_state(attn_metadata: AttentionMetadata,
5959
spec_metadata.eagle3_resource_manager.is_first_draft = True
6060

6161

62+
@torch.compile(options={'max-autotune': True})
6263
def prepare_for_generation(attn_metadata: AttentionMetadata,
6364
spec_metadata: SpecMetadata,
6465
position_ids: torch.Tensor) -> torch.Tensor:

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,6 @@ def needs_kv_cache_rewind(self):
6767
) or self.is_ngram()
6868

6969
def support_overlap_scheduler(self):
70-
# TODO: fix accuracy issue
71-
if self.is_mtp_eagle():
72-
return False
73-
7470
return self.is_mtp_one_model() or self.is_eagle3_one_model(
7571
) or self.has_draft_model()
7672

0 commit comments

Comments
 (0)