Skip to content

Commit b625da5

Browse files
committed
refine the pr
Signed-off-by: Superjomn <[email protected]>
1 parent 4510602 commit b625da5

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

tensorrt_llm/executor/rpc_worker.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,17 @@ class RpcWorker(BaseWorker):
3636
- `shutdown`: Shutdown the worker.
3737
"""
3838

39-
# Number of RPC server workers
39+
# Default number of RPC server workers
4040
# Increased to handle concurrent requests and prevent thread pool exhaustion
4141
# Need enough workers for: submit requests + fetch_responses + other operations
42-
NUM_WORKERS = 32
42+
# Can be overridden via constructor parameter
43+
DEFAULT_NUM_WORKERS = 32
44+
45+
# Default timeout for fetch_responses in seconds
46+
# This is a short timeout to prevent blocking the event loop while still allowing
47+
# responses to be fetched efficiently. The value is tuned to balance responsiveness
48+
# and CPU usage. Can be overridden via constructor parameter.
49+
DEFAULT_FETCH_TIMEOUT = 0.1
4350

4451
def __init__(
4552
self,
@@ -51,6 +58,8 @@ def __init__(
5158
hf_model_dir: Optional[Path] = None,
5259
tokenizer: Optional[TokenizerBase] = None,
5360
llm_args: Optional[BaseLlmArgs] = None,
61+
num_workers: Optional[int] = None,
62+
fetch_timeout: Optional[float] = None,
5463
) -> None:
5564
super().__init__(
5665
engine=engine,
@@ -63,6 +72,12 @@ def __init__(
6372
llm_args=llm_args,
6473
)
6574

75+
# Configure number of RPC workers
76+
self.num_workers = num_workers if num_workers is not None else self.DEFAULT_NUM_WORKERS
77+
78+
# Configure fetch timeout
79+
self._fetch_timeout = fetch_timeout if fetch_timeout is not None else self.DEFAULT_FETCH_TIMEOUT
80+
6681
# Extract garbage_collection_gen0_threshold from llm_args if available
6782
self.garbage_collection_gen0_threshold = (
6883
llm_args.garbage_collection_gen0_threshold if llm_args is not None
@@ -95,7 +110,9 @@ def fetch_responses(self, timeout: Optional[float] = None) -> list:
95110
color="orange",
96111
category="Worker"):
97112
# NOTE: This is a blocking call, it will wait for the responses to be available.
98-
responses = super().await_responses(timeout=0.1)
113+
# Use the configured fetch timeout if no timeout is provided
114+
actual_timeout = timeout if timeout is not None else self._fetch_timeout
115+
responses = super().await_responses(timeout=actual_timeout)
99116
self._await_response_helper.responses_handler(responses)
100117
logger_debug(f"[worker] Fetched {len(responses)} responses",
101118
color="green")
@@ -248,11 +265,11 @@ def main_task(
248265

249266
else:
250267
logger_debug(
251-
f"[worker] Worker {mpi_rank()} is creating the RPC service",
268+
f"[worker] Worker {mpi_rank()} is creating the RPC service with {worker.num_workers} workers",
252269
color="yellow")
253270
# Step 2: Create the RPC service, it will expose all the APIs of the worker as remote call to the client
254271
# Set num_workers to larger than 1 since there are some streaming tasks runs infinitely, such as await_responses_async.
255-
rpc_server = RPCServer(worker, num_workers=RpcWorker.NUM_WORKERS)
272+
rpc_server = RPCServer(worker, num_workers=worker.num_workers)
256273
rpc_server.bind(rpc_addr)
257274
rpc_server.start()
258275
logger_debug(f"[worker] RPC server {mpi_rank()} is started",

0 commit comments

Comments (0)