Description
System Info
Architecture: x86_64
OS: Ubuntu 25.04
CUDA: v13
GPU: A single RTX Pro 6000
gpt-oss-120b alone works fine at about 160 tokens/sec decode.
Turning on speculative decoding with gpt-oss-120b-Eagle3 or gpt-oss-120b-Eagle3-v2 immediately puts the server into an infinite loop when using Jan AI as the UI.
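For comparison, the baseline run that reaches ~160 tokens/sec is essentially the serve command shown below minus the extra options file (a sketch of my assumption; only --extra_llm_api_options differs):
trtllm-serve \
/models/original/gpt-oss-120b \
--host 0.0.0.0 \
--port 8000 \
--backend pytorch \
--tp_size 1 \
--ep_size 1 \
--max_batch_size 1 \
--trust_remote_code \
--kv_cache_free_gpu_memory_fraction 0.9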
Models:
pip install huggingface-hub
huggingface-cli download openai/gpt-oss-120b --local-dir /models/original/gpt-oss-120b
huggingface-cli download nvidia/gpt-oss-120b-Eagle3 --local-dir /models/original/gpt-oss-120b-Eagle3
huggingface-cli download nvidia/gpt-oss-120b-Eagle3-v2 --local-dir /models/original/gpt-oss-120b-Eagle3-v2
Container:
docker run --rm --ipc=host -it \
--ulimit stack=67108864 \
--ulimit memlock=-1 \
--gpus all \
-p 8000:8000 \
-e TRTLLM_ENABLE_PDL=1 \
-v /models:/models:rw \
nvcr.io/nvidia/tensorrt-llm/release:1.2.0rc1 \
/bin/bash
low_latency_speculative.yaml:
enable_attention_dp: false
disable_overlap_scheduler: true
enable_autotuner: false
cuda_graph_config:
  max_batch_size: 1
  enable_padding: true
moe_config:
  backend: CUTLASS
speculative_config:
  decoding_type: Eagle
  max_draft_len: 3
  speculative_model_dir: /models/original/gpt-oss-120b-Eagle3-v2/
kv_cache_config:
  enable_block_reuse: false
Serve command:
trtllm-serve \
/models/original/gpt-oss-120b \
--host 0.0.0.0 \
--port 8000 \
--backend pytorch \
--tp_size 1 \
--ep_size 1 \
--max_batch_size 1 \
--trust_remote_code \
--extra_llm_api_options low_latency_speculative.yaml \
--kv_cache_free_gpu_memory_fraction 0.9
When sending any message with Jan UI, the connection hangs indefinitely.
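To take Jan out of the picture, the hang can presumably also be reproduced with a streaming curl request similar to what Jan sends (a sketch; the exact payload Jan uses is an assumption on my part):
curl -N http://127.0.0.1:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "gpt-oss-120b",
    "messages": [{"role": "user", "content": "Say hi in one sentence."}],
    "max_tokens": 64,
    "stream": true
  }'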
root@9f85dd98c613:/app/tensorrt_llm# nvidia-smi
Thu Oct 23 06:11:01 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05 Driver Version: 580.95.05 CUDA Version: 13.0 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:01:00.0 Off | Off |
| 30% 28C P8 16W / 600W | 20MiB / 97887MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+
root@9f85dd98c613:/app/tensorrt_llm# nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2025 NVIDIA Corporation
Built on Wed_Jul_16_07:30:01_PM_PDT_2025
Cuda compilation tools, release 13.0, V13.0.48
Build cuda_13.0.r13.0/compiler.36260728_0
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
See the setup above for the full reproduction steps.
When I try using speculative decoding with log_level set to debug, the serving engine goes into an infinite loop and the log repeats the following:
[TensorRT-LLM][WARNING] Attention workspace size is not enough, increase the size from 33555968 bytes to 33558016 bytes
[10/23/2025-05:46:41] [TRT-LLM] [V] ------before _handle_responses, rank = 0, output = [<tensorrt_llm._torch.pyexecutor.llm_request.LlmRequest object at 0x6fded8053a20>]
[10/23/2025-05:46:41] [TRT-LLM] [V] after gather, rank = 0, responses = [(1, LlmResponse(request_id=0, error_msg=None, result=<tensorrt_llm._torch.pyexecutor.llm_request.LlmResult object at 0x6ff949fb6300>, client_id=2))]
[10/23/2025-05:46:41] [TRT-LLM] [V] has 1 active_request, scheduled 0 context requests and 1 generation requests
[10/23/2025-05:46:41] [TRT-LLM] [V] Detected use_mrope: False
[10/23/2025-05:46:41] [TRT-LLM] [V] Reset Python GC thresholds to default value: (700, 10, 10)
[10/23/2025-05:46:41] [TRT-LLM] [V] Detected use_mrope: False
[10/23/2025-05:46:41] [TRT-LLM] [V] Set Python GC threshold to customized value: 20000
[10/23/2025-05:46:41] [TRT-LLM] [V] Created HarmonyStreamState for request 2
[10/23/2025-05:46:41] [TRT-LLM] [V] ------before _handle_responses, rank = 0, output = [<tensorrt_llm._torch.pyexecutor.llm_request.LlmRequest object at 0x6fded8053a20>]
[10/23/2025-05:46:41] [TRT-LLM] [V] after gather, rank = 0, responses = [(1, LlmResponse(request_id=0, error_msg=None, result=<tensorrt_llm._torch.pyexecutor.llm_request.LlmResult object at 0x6ff949fb5f40>, client_id=2))]
[10/23/2025-05:46:41] [TRT-LLM] [V] has 1 active_request, scheduled 0 context requests and 1 generation requests
[10/23/2025-05:46:41] [TRT-LLM] [V] Reset Python GC thresholds to default value: (700, 10, 10)
[10/23/2025-05:46:41] [TRT-LLM] [V] Set Python GC threshold to customized value: 20000
[10/23/2025-05:46:41] [TRT-LLM] [V] Detected use_mrope: False
[10/23/2025-05:46:41] [TRT-LLM] [V] Detected use_mrope: False
[10/23/2025-05:46:41] [TRT-LLM] [V] ------before _handle_responses, rank = 0, output = [<tensorrt_llm._torch.pyexecutor.llm_request.LlmRequest object at 0x6fded8053a20>]
[10/23/2025-05:46:41] [TRT-LLM] [V] after gather, rank = 0, responses = [(1, LlmResponse(request_id=0, error_msg=None, result=<tensorrt_llm._torch.pyexecutor.llm_request.LlmResult object at 0x6ff949fb6240>, client_id=2))]
[10/23/2025-05:46:41] [TRT-LLM] [V] has 1 active_request, scheduled 0 context requests and 1 generation requests
[10/23/2025-05:46:41] [TRT-LLM] [V] Reset Python GC thresholds to default value: (700, 10, 10)
[10/23/2025-05:46:41] [TRT-LLM] [V] Set Python GC threshold to customized value: 20000
[10/23/2025-05:46:41] [TRT-LLM] [V] Detected use_mrope: False
[10/23/2025-05:46:41] [TRT-LLM] [V] Detected use_mrope: False
[10/23/2025-05:46:41] [TRT-LLM] [V] ------before _handle_responses, rank = 0, output = [<tensorrt_llm._torch.pyexecutor.llm_request.LlmRequest object at 0x6fded8053a20>]
[10/23/2025-05:46:41] [TRT-LLM] [V] after gather, rank = 0, responses = [(1, LlmResponse(request_id=0, error_msg=None, result=<tensorrt_llm._torch.pyexecutor.llm_request.LlmResult object at 0x6ff949fb77d0>, client_id=2))]
[10/23/2025-05:46:41] [TRT-LLM] [V] has 1 active_request, scheduled 0 context requests and 1 generation requests
[10/23/2025-05:46:41] [TRT-LLM] [V] Detected use_mrope: False
[10/23/2025-05:46:41] [TRT-LLM] [V] Reset Python GC thresholds to default value: (700, 10, 10)
[10/23/2025-05:46:41] [TRT-LLM] [V] Set Python GC threshold to customized value: 20000
[10/23/2025-05:46:41] [TRT-LLM] [V] Detected use_mrope: False
[10/23/2025-05:46:41] [TRT-LLM] [V] ------before _handle_responses, rank = 0, output = [<tensorrt_llm._torch.pyexecutor.llm_request.LlmRequest object at 0x6fded8053a20>]
(and this repeats very rapidly, indefinitely)
Expected behavior
The model runs and Jan UI returns results
actual behavior
Infinite loop
additional notes
I also tried running curl manually; here is the result (the request completed, but the output is garbage, and I think the Harmony format is not being applied internally):
zmarty@zmarty-aorus:/models/exl3$ curl -s http://127.0.0.1:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "gpt-oss-120b",
"messages": [{"role":"user","content":"Say hi in one sentence."}],
"max_tokens": 64
}' | jq .
{
"id": "chatcmpl-a324e25e3e57422eb8c49bbb09554d46",
"object": "chat.completion",
"created": 1761196686,
"model": "gpt-oss-120b",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "<|channel|>!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!",
"reasoning_content": null,
"reasoning": null,
"tool_calls": []
},
"logprobs": null,
"finish_reason": "length",
"stop_reason": null,
"mm_embedding_handle": null,
"disaggregated_params": null,
"avg_decoded_tokens_per_iter": null
}
],
"usage": {
"prompt_tokens": 77,
"total_tokens": 141,
"completion_tokens": 64,
"prompt_tokens_details": null
},
"prompt_token_ids": null
}
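As an isolation step, the same request can be retried with only the speculative_config block removed from low_latency_speculative.yaml, which should show whether the <|channel|> garbage output appears only with Eagle3 enabled (a sketch of the reduced config; I am assuming the rest stays unchanged):
enable_attention_dp: false
disable_overlap_scheduler: true
enable_autotuner: false
cuda_graph_config:
  max_batch_size: 1
  enable_padding: true
moe_config:
  backend: CUTLASS
kv_cache_config:
  enable_block_reuse: false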