Description
This bug did not exist in v0.21.0; I think it may have been introduced in v1.0.0rc3. The server slows down and eventually stops responding to requests almost entirely. The degradation appears sooner when there are many client-aborted requests, and high concurrency seems to accelerate it as well, so a race condition is possible. I traced the cause to _handle_canceled_requests:
I added some logging in _handle_canceled_requests and found that self.executor_request_queue.get_canceled_req_ids_size() only ever grows - it never decreases.
TensorRT-LLM/tensorrt_llm/_torch/pyexecutor/py_executor.py
Lines 1913 to 1936 in e2f69c5
def _handle_canceled_requests(self):
    if self.executor_request_queue.get_canceled_req_ids_size() == 0:
        return

    # Remove cancel request in the waiting queue
    self.executor_request_queue.update_waiting_queue()

    for request in self.active_requests:
        req_id = request.py_request_id if not request.is_child else request.parent_request_id
        if req_id not in self.executor_request_queue.get_canceled_req_ids():
            continue
        is_cancelled = self._try_cancel_request(request)
        if is_cancelled:
            # Mark requests as finished, then, we reuse all existing code
            # to clean up the KV cache resources.
            request.finish_by_reason(FinishReason.CANCELLED)
            request.decoding_iter = request.py_decoding_iter

    if self.enable_attention_dp:
        # TODO: revisit the cancel logic of attention dp
        # When enable attention dp, each rank does not have full copy of requests
        # so we need to remove the cancel requests not in the local rank
        self.executor_request_queue.clear_canceled_req_ids()
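For reference, the logging I added was roughly the following (a sketch from memory; the logger name and message format are mine, not part of the repository):

# Sketch of the debug logging I put at the top of _handle_canceled_requests
# (illustrative only; just the first few lines of the method are shown).
import logging

logger = logging.getLogger("cancel_debug")

def _handle_canceled_requests(self):
    # With enable_attention_dp false, this size only ever grows over the
    # lifetime of the server, which is what led me to the gating below.
    logger.info(
        "canceled_req_ids size=%d, active_requests=%d",
        self.executor_request_queue.get_canceled_req_ids_size(),
        len(self.active_requests),
    )
    if self.executor_request_queue.get_canceled_req_ids_size() == 0:
        return
    ...  # rest of the method unchanged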
Adding enable_attention_dp: true to the extra_llm_api_options YAML works around the issue for me.
So my guess is that gating self.executor_request_queue.clear_canceled_req_ids() behind enable_attention_dp is incorrect, or maybe what I'm observing here is just a symptom of a different underlying issue.
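To be concrete, the change I would naively try (untested, and possibly wrong if the IDs are meant to persist until _try_cancel_request succeeds, or to be cleared somewhere else in the non-attention-dp path) is to clear the set unconditionally once the cancellations have been processed:

# Untested sketch of the naive change: move the clear out of the
# enable_attention_dp branch so it always runs.
def _handle_canceled_requests(self):
    if self.executor_request_queue.get_canceled_req_ids_size() == 0:
        return

    self.executor_request_queue.update_waiting_queue()
    ...  # cancellation loop over self.active_requests, unchanged

    if self.enable_attention_dp:
        ...  # any attention-dp-specific handling, if still needed

    # Always clear, so the set cannot grow without bound when
    # attention DP is disabled.
    self.executor_request_queue.clear_canceled_req_ids()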
This is what I'm running, using 4xB200 and nvcr.io/nvidia/tensorrt-llm/release:1.2.0rc0:
cat >/root/data/trtllm-config.yml<<EOF
kv_cache_config:
  enable_block_reuse: true
  dtype: fp8
  free_gpu_memory_fraction: 0.88
  host_cache_size: 60000000000
EOF
trtllm-serve /root/data/nvidia/DeepSeek-R1-FP4 --backend pytorch --max_seq_len 9216 --max_batch_size 4096 --max_num_tokens 16384 --tp_size 4 --trust_remote_code --extra_llm_api_options /root/data/trtllm-config.yml
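The client behavior that seems to trigger it is roughly the following (a sketch, not my exact load generator; the port, model name, prompt, and concurrency are placeholders for my actual setup):

# Sketch of a load pattern that appears to trigger the slowdown: many
# concurrent requests that the client abandons early, so the server sees
# lots of client-aborted requests. Endpoint/model/prompt are illustrative.
import concurrent.futures

import requests

URL = "http://localhost:8000/v1/completions"   # trtllm-serve OpenAI-compatible endpoint (default port assumed)
MODEL = "/root/data/nvidia/DeepSeek-R1-FP4"    # whatever name the server reports for the model

def fire_and_abort(i: int) -> None:
    try:
        # The short timeout makes the client give up and close the connection
        # while the server is still generating, i.e. a client-aborted request.
        requests.post(
            URL,
            json={
                "model": MODEL,
                "prompt": f"question {i}: " + "why? " * 200,
                "max_tokens": 2048,
            },
            timeout=2,
        )
    except requests.exceptions.RequestException:
        pass  # timeouts/aborts are expected here

with concurrent.futures.ThreadPoolExecutor(max_workers=64) as pool:
    list(pool.map(fire_and_abort, range(2000)))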