Skip to content

Commit f48aa54

Browse files
committed
add rpc client/server path
Signed-off-by: Erin Ho <[email protected]>

update: remove ray queues

Signed-off-by: Erin Ho <[email protected]>
1 parent 7828245 commit f48aa54

File tree

7 files changed

+381
-221
lines changed

7 files changed

+381
-221
lines changed

tensorrt_llm/bench/dataclasses/reporting.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import json
4+
import os
45
from collections import defaultdict
56
from typing import Any, Dict, List, NamedTuple
67

@@ -507,9 +508,14 @@ def report_statistics(self) -> None:
507508
f"Max Sequence Length:\t{build_cfg['max_seq_len']}\n"
508509
f"\n")
509510
else:
511+
# Check MPI vs RAY and RPC status
512+
comm_backend = "RAY" if os.environ.get(
513+
"TLLM_DISABLE_MPI") == "1" else "MPI"
514+
ray_status = "[RPC]" if os.environ.get(
515+
"TLLM_RAY_USE_RPC") == "1" else "[original]"
510516
backend_info = (
511517
"\n\n===========================================================\n"
512-
"= PYTORCH BACKEND\n"
518+
f"= PYTORCH BACKEND [{comm_backend}] {ray_status}\n"
513519
"===========================================================\n"
514520
f"Model:\t\t\t{engine['model']}\n"
515521
f"Model Path:\t\t{engine['model_path']}\n"

tensorrt_llm/executor/base_worker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,7 @@ def _deduce_max_tokens(request: GenerationRequest,
523523

524524
def submit(self, request: GenerationRequest) -> GenerationResult:
525525
""" Low-level API to the executor. Return a "future" GenerationResult which can be waited. """
526+
# TODO Use this to test error propagation issue with RayExecutor.
526527
self.start()
527528

528529
if self.rank != 0:

tensorrt_llm/executor/ray_executor.py

Lines changed: 166 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import asyncio
2+
import atexit
13
import os
4+
import threading
25
from typing import Any, Dict, List, Optional, Tuple
36

47
try:
@@ -17,11 +20,16 @@
1720
from tensorrt_llm.logger import logger
1821

1922
from .._utils import nvtx_range_debug
23+
from ..llmapi.tracer import global_tracer
24+
from ..llmapi.utils import _SyncQueue, logger_debug
2025
from .executor import GenerationExecutor
2126
from .postproc_worker import PostprocWorkerConfig
2227
from .ray_gpu_worker import RayGPUWorker, RayWorkerWrapper
2328
from .request import GenerationRequest
24-
from .result import GenerationResult, RayAsyncQueue, RaySyncQueue
29+
from .result import GenerationResult
30+
from .rpc import RPCClient
31+
from .rpc.rpc_common import get_unique_ipc_addr
32+
from .utils import ErrorResponse, is_llm_response
2533

2634
__all__ = [
2735
"RayExecutor",
@@ -76,28 +84,25 @@ def __init__(self,
7684
self.master_address = ray.util.get_node_ip_address()
7785
self.master_port = get_free_port()
7886

79-
self.response_queue = RayAsyncQueue.options(runtime_env={
80-
"env_vars": {
81-
"TLLM_DISABLE_MPI": "1"
82-
}
83-
}).remote()
84-
self.response_sync_queue = RaySyncQueue.options(runtime_env={
85-
"env_vars": {
86-
"TLLM_DISABLE_MPI": "1"
87-
}
88-
}).remote()
89-
self.async_response_queue_weakref = self.create_actor_weak_ref(
90-
self.response_queue)
91-
self.sync_response_queue_weakref = self.create_actor_weak_ref(
92-
self.response_sync_queue)
93-
self.response_queue.warmup.remote()
94-
self.response_sync_queue.warmup.remote()
87+
self.rpc_addr = get_unique_ipc_addr()
88+
self.rpc_client = RPCClient(self.rpc_addr)
89+
print(f"====RPC client created at {self.rpc_addr}")
90+
91+
self._results = {}
92+
self._shutdown_event = threading.Event()
93+
self.main_loop_task_obj = None
94+
self.main_loop = None
9595

9696
worker_kwargs = dict(**worker_kwargs,
9797
postproc_worker_config=postproc_worker_config,
98-
is_llm_executor=is_llm_executor)
98+
is_llm_executor=is_llm_executor,
99+
rpc_addr=self.rpc_addr)
99100

100101
self.create_workers(RayGPUWorker, worker_kwargs)
102+
103+
logger.info("Setting up engine via RPC")
104+
self.setup_engine_remote()
105+
self.setup_mainloop()
101106
except Exception as e:
102107
# Clean up the Ray resources early during exception
103108
self.shutdown()
@@ -110,8 +115,103 @@ def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle):
110115
return ray.actor.ActorHandle._deserialization_helper(state,
111116
weak_ref=True)
112117

113-
def use_ray_queue(self) -> bool:
114-
return True
118+
async def _generic_fetch_loop_async(self, fetch_method_name: str,
119+
handler_method, method_name: str):
120+
# TODO copied from GenerationExecutorRpcProxy, need refactoring.
121+
"""Generic method for fetching data in a loop from RPC worker.
122+
123+
Args:
124+
fetch_method_name: Name of the RPC client method to call
125+
handler_method: The handler method to call with the fetched data
126+
method_name: Name of the method for logging
127+
"""
128+
try:
129+
fetch_method = getattr(self.rpc_client, fetch_method_name)
130+
async for data in fetch_method().remote_streaming():
131+
if self._shutdown_event.is_set():
132+
return
133+
handler_method(data)
134+
except asyncio.CancelledError:
135+
logger.debug(f"{method_name} task cancelled")
136+
except Exception as e:
137+
logger.error(f"Error in {method_name}: {e}")
138+
raise
139+
140+
async def _fetch_responses_loop_async(self):
141+
# TODO copied from GenerationExecutorRpcProxy, need refactoring.
142+
await self._generic_fetch_loop_async(
143+
fetch_method_name="fetch_responses_loop_async",
144+
handler_method=self.handle_responses,
145+
method_name="_fetch_responses_loop_async")
146+
147+
def setup_mainloop(self):
148+
# TODO copied from GenerationExecutorRpcProxy, need refactoring.
149+
async def main_loop_task():
150+
await self._fetch_responses_loop_async()
151+
152+
def _run_main_loop_task():
153+
"""Local method to run the main loop task."""
154+
self.main_loop = asyncio.new_event_loop()
155+
asyncio.set_event_loop(self.main_loop)
156+
157+
self.main_loop_task_obj = self.main_loop.create_task(
158+
main_loop_task())
159+
try:
160+
self.main_loop.run_until_complete(self.main_loop_task_obj)
161+
except asyncio.CancelledError:
162+
pass # Task cancellation is expected during shutdown
163+
finally:
164+
self.main_loop.close()
165+
166+
self.main_loop_thread = threading.Thread(target=_run_main_loop_task,
167+
daemon=True,
168+
name="ray_executor_main_loop")
169+
self.main_loop_thread.start()
170+
atexit.register(self.shutdown)
171+
172+
def setup_engine_remote(self):
173+
return self.collective_rpc("setup_engine", non_block=False)
174+
175+
def handle_responses(self, responses: list[GenerationResult]) -> bool:
176+
# TODO copied from GenerationExecutorRpcProxy, need refactoring.
177+
async_queues = []
178+
event_loop = None
179+
180+
def process_res(res: list):
181+
for r in res:
182+
client_id = r.client_id
183+
nonlocal event_loop
184+
nonlocal async_queues
185+
186+
if client_id not in self._results:
187+
logger.warning(
188+
f"Received response for unknown client_id: {client_id}")
189+
continue
190+
191+
queue = self._results[client_id].queue
192+
if isinstance(queue, _SyncQueue):
193+
queue.put_nowait(r)
194+
async_queues.append(queue)
195+
# all the loops are identical
196+
event_loop = event_loop or queue.loop
197+
else:
198+
queue.put(r)
199+
200+
if (is_llm_response(r) and r.result.is_final) or isinstance(
201+
r, ErrorResponse):
202+
self._results.pop(client_id)
203+
204+
# Handle the case where responses might not be a list of lists
205+
if responses and not isinstance(responses[0], list):
206+
# If responses is a flat list, wrap it
207+
responses = [responses]
208+
209+
for res in responses:
210+
global_tracer().log_instant("RPC.get")
211+
process_res(res)
212+
213+
if async_queues:
214+
_SyncQueue.notify_many(event_loop, async_queues)
115215

116216
def create_workers(self, worker_cls, worker_kwargs):
117217
# When set to be a fraction, it allows Ray to schedule
@@ -192,27 +292,27 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
192292
"""
193293
Low-level API to the executor. Return a "future" GenerationResult
194294
which can be waited.
195-
Forwards the request to the workers through the request queue.
295+
Forwards the request to the workers through RPC.
196296
"""
197297
request.set_id(self._get_next_client_id())
198298
logprob_params = self._get_logprob_params(request)
199299

300+
with nvtx_range_debug("rpc_submit"):
301+
self.rpc_client.submit(request).remote(need_response=False)
302+
200303
result = GenerationResult(
201304
request,
202305
background_error_handler=self._handle_background_error,
203306
executor=self,
204307
disaggregated_params=request.disaggregated_params,
205308
logprob_params=logprob_params)
206-
207-
with nvtx_range_debug("request_queue.put"):
208-
self.call_all_ray_workers("enqueue_request",
209-
leader_only=True,
210-
request=request,
211-
async_call=True,
212-
result_wait_queue=result.queue)
309+
self._results[request.id] = result
213310

214311
return result
215312

313+
def start(self):
314+
pass
315+
216316
def report_device_ids(self) -> list[str]:
217317
gpu_ids = self.call_all_ray_workers("report_device_id",
218318
leader_only=False,
@@ -225,12 +325,45 @@ def abort_request(self, request_id: int) -> None:
225325
async_call=False,
226326
request_id=request_id)
227327

328+
# TODO: Use Ray RPC to shutdown RPC server, and then close client
228329
def shutdown(self):
229-
# Release actors
230-
self.response_queue = None
231-
self.response_sync_queue = None
232-
self.async_response_queue_weakref = None
233-
self.sync_response_queue_weakref = None
330+
try:
331+
self.shutdown_impl()
332+
except Exception as e:
333+
# TODO: clean up
334+
print(f"Error shutting down RayExecutor: {e}")
335+
raise e
336+
337+
def shutdown_impl(self):
338+
if self._shutdown_event.is_set():
339+
return
340+
self._shutdown_event.set()
341+
logger_debug(f"Shutting down RayExecutor (RPC mode)", color="yellow")
342+
343+
if hasattr(self, 'rpc_client') and self.rpc_client is not None:
344+
try:
345+
logger_debug("Shutting down RPC remote", color="yellow")
346+
self.call_all_ray_workers("shutdown",
347+
leader_only=False,
348+
async_call=False)
349+
except Exception as e:
350+
logger.warning(f"Error shutting down RPC remote: {e}")
351+
352+
if hasattr(self, 'main_loop') and self.main_loop and hasattr(
353+
self, 'main_loop_task_obj') and self.main_loop_task_obj:
354+
logger_debug("Cancelling main loop task.", color="yellow")
355+
try:
356+
self.main_loop.call_soon_threadsafe(
357+
self.main_loop_task_obj.cancel)
358+
except Exception as e:
359+
logger_debug(f"Error cancelling main loop task: {e}",
360+
color="yellow")
361+
362+
if hasattr(self, 'main_loop_thread'):
363+
self.main_loop_thread.join()
364+
365+
if hasattr(self, 'rpc_client') and self.rpc_client is not None:
366+
self.rpc_client.close()
234367

235368
self.workers = None
236369
if hasattr(self,

0 commit comments

Comments
 (0)