
Commit faf1780

hchings authored and greg-kwasniewski1 committed
[TRTLLM-8988][feat] Unify MPI & Ray's req/response handling with RPC Client/Server (NVIDIA#8765)
Signed-off-by: Erin Ho <[email protected]>
1 parent 4f5d752 commit faf1780

File tree

9 files changed, +522 -320 lines changed


tensorrt_llm/_utils.py

Lines changed: 7 additions & 0 deletions
@@ -524,6 +524,13 @@ def mpi_disabled() -> bool:
     return os.environ.get("TLLM_DISABLE_MPI") == "1"
 
 
+def ray_use_rpc() -> bool:
+    """True if TLLM_RAY_USE_RPC is set to "1", False otherwise.
+    # TODO: deprecate this once Ray is fully moved to use RPC client/server.
+    """
+    return os.environ.get("TLLM_RAY_USE_RPC") == "1"
+
+
 def mpi_rank():
     if mpi_disabled():
         try:
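
Both helpers just read environment variables, so the new toggle can be exercised directly. A minimal sketch, assuming only what the hunk above shows (both functions live in tensorrt_llm._utils and tensorrt_llm is importable):

# Sketch only: exercising the toggles shown above.
import os

from tensorrt_llm._utils import mpi_disabled, ray_use_rpc

os.environ["TLLM_DISABLE_MPI"] = "1"   # take the Ray path instead of MPI
os.environ["TLLM_RAY_USE_RPC"] = "1"   # opt in to the RPC client/server path
assert mpi_disabled() and ray_use_rpc()

os.environ["TLLM_RAY_USE_RPC"] = "0"   # any value other than "1" disables it...
assert not ray_use_rpc()               # ...and the executor keeps the Ray queue path
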

tensorrt_llm/executor/ray_executor.py

Lines changed: 135 additions & 59 deletions
@@ -13,22 +13,23 @@
                                       placement_group)
 
 from tensorrt_llm._ray_utils import unwrap_ray_errors
-from tensorrt_llm._utils import get_free_port
+from tensorrt_llm._utils import get_free_port, nvtx_range_debug, ray_use_rpc
 from tensorrt_llm.logger import logger
 
-from .._utils import nvtx_range_debug
+from ..llmapi.utils import logger_debug
 from .executor import GenerationExecutor
 from .postproc_worker import PostprocWorkerConfig
 from .ray_gpu_worker import RayGPUWorker, RayWorkerWrapper
 from .request import GenerationRequest
 from .result import GenerationResult, RayAsyncQueue, RaySyncQueue
+from .rpc_proxy import RpcExecutorMixin
 
 __all__ = [
     "RayExecutor",
 ]
 
 
-class RayExecutor(GenerationExecutor):
+class RayExecutor(RpcExecutorMixin, GenerationExecutor):
 
     def __init__(self,
                  worker_kwargs: Dict,

@@ -75,44 +76,44 @@ def __init__(self,
             self.tp_size = tp_size
             self.master_address = ray.util.get_node_ip_address()
             self.master_port = get_free_port()
-
-            self.response_queue = RayAsyncQueue.options(runtime_env={
-                "env_vars": {
-                    "TLLM_DISABLE_MPI": "1"
-                }
-            }).remote()
-            self.response_sync_queue = RaySyncQueue.options(runtime_env={
-                "env_vars": {
-                    "TLLM_DISABLE_MPI": "1"
-                }
-            }).remote()
-            self.async_response_queue_weakref = self.create_actor_weak_ref(
-                self.response_queue)
-            self.sync_response_queue_weakref = self.create_actor_weak_ref(
-                self.response_sync_queue)
-            self.response_queue.warmup.remote()
-            self.response_sync_queue.warmup.remote()
+            self.use_rpc = ray_use_rpc()
 
             worker_kwargs = dict(**worker_kwargs,
                                  postproc_worker_config=postproc_worker_config,
                                  is_llm_executor=is_llm_executor)
 
-            self.create_workers(RayGPUWorker, worker_kwargs)
+            if self.use_rpc:
+                self.init_rpc_executor()
+                worker_kwargs['rpc_addr'] = self.rpc_addr
+                self.create_workers(RayGPUWorker, worker_kwargs)
+                self.setup_engine_remote()
+                self.setup_mainloop(tasks=[self._fetch_responses_loop_async],
+                                    thread_name="ray_executor_main_loop")
+                logger.info(f"Connecting to RPC server at {self.rpc_addr}")
+            else:
+                self.response_queue = RayAsyncQueue.options(runtime_env={
+                    "env_vars": {
+                        "TLLM_DISABLE_MPI": "1"
+                    }
+                }).remote()
+                self.response_sync_queue = RaySyncQueue.options(runtime_env={
+                    "env_vars": {
+                        "TLLM_DISABLE_MPI": "1"
+                    }
+                }).remote()
+                self.async_response_queue_weakref = self.create_actor_weak_ref(
+                    self.response_queue)
+                self.sync_response_queue_weakref = self.create_actor_weak_ref(
+                    self.response_sync_queue)
+                self.response_queue.warmup.remote()
+                self.response_sync_queue.warmup.remote()
+                self.create_workers(RayGPUWorker, worker_kwargs)
+
         except Exception as e:
-            # Clean up the Ray resources early during exception
             self.shutdown()
             logger.error(f"Failed to initialize RayExecutor: {e}")
             raise e
 
-    @staticmethod
-    def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle):
-        state, _, _ = actor_handle._serialization_helper()
-        return ray.actor.ActorHandle._deserialization_helper(state,
-                                                             weak_ref=True)
-
-    def use_ray_queue(self) -> bool:
-        return True
-
     def create_workers(self, worker_cls, worker_kwargs):
         # When set to be a fraction, it allows Ray to schedule
         # multiple actors on a single GPU for colocate use cases.

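
In RPC mode the constructor leans on RpcExecutorMixin helpers (init_rpc_executor, setup_mainloop, _fetch_responses_loop_async) whose bodies are not part of this diff. As an illustrative sketch only, not the mixin's real implementation: setup_mainloop-style wiring typically pins the response-fetching coroutine to an event loop owned by a dedicated thread, which is what later lets shutdown() cancel it through main_loop.call_soon_threadsafe and join main_loop_thread. The class and parameter names below are invented, and a single coro_fn stands in for the tasks list used above.

# Illustrative sketch, not the actual RpcExecutorMixin implementation.
import asyncio
import threading


class _MainloopSketch:  # hypothetical stand-in for the mixin's loop wiring

    def setup_mainloop(self, coro_fn, thread_name):
        # Own a dedicated event loop so the fetch loop never blocks the caller.
        self.main_loop = asyncio.new_event_loop()

        def _run():
            asyncio.set_event_loop(self.main_loop)
            # Keep the task handle so shutdown() can cancel it thread-safely via
            # self.main_loop.call_soon_threadsafe(self.main_loop_task_obj.cancel).
            self.main_loop_task_obj = self.main_loop.create_task(coro_fn())
            try:
                self.main_loop.run_until_complete(self.main_loop_task_obj)
            except asyncio.CancelledError:
                pass  # expected when shutdown() cancels the task
            finally:
                self.main_loop.close()

        self.main_loop_thread = threading.Thread(target=_run,
                                                 name=thread_name,
                                                 daemon=True)
        self.main_loop_thread.start()
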
@@ -188,49 +189,118 @@ def collective_rpc(self,
                                              **kwargs))
         return refs if non_block else ray.get(refs)
 
-    def submit(self, request: GenerationRequest) -> GenerationResult:
+    def submit(self, request: "GenerationRequest") -> "GenerationResult":
         """
-            Low-level API to the executor. Return a "future" GenerationResult
-            which can be waited.
-            Forwards the request to the workers through the request queue.
+        Low-level API to the executor. Return a "future" GenerationResult
+        which can be waited.
+        Forwards the request to the workers through RPC or Ray queues depending on mode.
         """
         request.set_id(self._get_next_client_id())
         logprob_params = self._get_logprob_params(request)
 
-        result = GenerationResult(
-            request,
-            background_error_handler=self._handle_background_error,
-            executor=self,
-            disaggregated_params=request.disaggregated_params,
-            logprob_params=logprob_params)
-
-        with nvtx_range_debug("request_queue.put"):
-            self.call_all_ray_workers("enqueue_request",
-                                      leader_only=True,
-                                      request=request,
-                                      async_call=True,
-                                      result_wait_queue=result.queue)
+        if self.use_rpc:
+            with nvtx_range_debug("rpc_submit"):
+                self.rpc_client.submit(request).remote(need_response=False)
+
+            result = GenerationResult(
+                request,
+                background_error_handler=self._handle_background_error,
+                executor=self,
+                disaggregated_params=request.disaggregated_params,
+                logprob_params=logprob_params)
+            self._results[request.id] = result
+        else:
+            result = GenerationResult(
+                request,
+                background_error_handler=self._handle_background_error,
+                executor=self,
+                disaggregated_params=request.disaggregated_params,
+                logprob_params=logprob_params)
+
+            with nvtx_range_debug("request_queue.put"):
+                self.call_all_ray_workers("enqueue_request",
+                                          leader_only=True,
+                                          request=request,
+                                          async_call=True,
+                                          result_wait_queue=result.queue)
 
         return result
 
+    def start(self):
+        pass
+
+    def setup_engine_remote(self):
+        return self.collective_rpc("setup_engine", non_block=False)
+
     def report_device_ids(self) -> list[str]:
         gpu_ids = self.call_all_ray_workers("report_device_id",
                                             leader_only=False,
                                             async_call=False)
         return sorted(gpu_ids)
 
+    def use_ray_queue(self) -> bool:
+        return not self.use_rpc
+
     def abort_request(self, request_id: int) -> None:
         self.call_all_ray_workers("abort_request",
                                   leader_only=True,
                                   async_call=False,
                                   request_id=request_id)
 
     def shutdown(self):
-        # Release actors
-        self.response_queue = None
-        self.response_sync_queue = None
-        self.async_response_queue_weakref = None
-        self.sync_response_queue_weakref = None
+        if hasattr(self, '_shutdown_event') and self._shutdown_event.is_set():
+            return
+        if hasattr(self, '_shutdown_event'):
+            self._shutdown_event.set()
+
+        mode_str = "RPC mode" if self.use_rpc else "Ray queue mode"
+        logger_debug(f"Shutting down RayExecutor ({mode_str})", color="yellow")
+
+        if self.use_rpc:
+            if hasattr(self, 'main_loop') and self.main_loop and hasattr(
+                    self, 'main_loop_task_obj') and self.main_loop_task_obj:
+                logger_debug("Cancelling main loop task.", color="yellow")
+                try:
+                    self.main_loop.call_soon_threadsafe(
+                        self.main_loop_task_obj.cancel)
+                except Exception as e:
+                    logger_debug(f"Error cancelling main loop task: {e}",
+                                 color="yellow")
+
+            if hasattr(self, 'main_loop_thread'):
+                self.main_loop_thread.join()
+
+            # Then, shutdown the workers
+            if hasattr(self, 'workers') and self.workers is not None:
+                try:
+                    logger_debug("Shutting down RPC remote", color="yellow")
+                    shutdown_refs = [
+                        worker.shutdown.remote() for worker in self.workers
+                    ]
+                    # Add timeout to prevent indefinite hanging
+                    ray.get(shutdown_refs, timeout=30.0)
+                except ray.exceptions.GetTimeoutError:
+                    logger.warning(
+                        "Timeout waiting for workers to shutdown after 30 seconds"
+                    )
+                except Exception as e:
+                    logger.warning(f"Error shutting down RPC remote: {e}")
+
+            if hasattr(self, 'rpc_client') and self.rpc_client is not None:
+                try:
+                    self.rpc_client.close()
+                except Exception as e:
+                    # Suppress errors during RPC client shutdown
+                    # These can occur if the client is already closed or if there are
+                    # pending operations that get cancelled during cleanup
+                    logger_debug(
+                        f"Suppressed error during RPC client close: {e}")
+        else:
+            # Release actors
+            self.response_queue = None
+            self.response_sync_queue = None
+            self.async_response_queue_weakref = None
+            self.sync_response_queue_weakref = None
 
         self.workers = None
         if hasattr(self,

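
The reworked shutdown() above is written to be safe to call more than once. The guard only uses a _shutdown_event, so its creation presumably lives in RpcExecutorMixin (an assumption; it is not shown in this diff). A stripped-down sketch of that run-once pattern:

# Minimal sketch of the run-once guard; _shutdown_event is assumed to be a
# threading.Event created by the RPC mixin (its definition is not in this diff).
import threading


class _ShutdownGuardSketch:

    def __init__(self):
        self._shutdown_event = threading.Event()

    def shutdown(self):
        if self._shutdown_event.is_set():
            return              # later calls are harmless no-ops
        self._shutdown_event.set()
        # ...cancel the main loop, stop workers, then close the RPC client...
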
@@ -246,12 +316,6 @@ def shutdown(self):
             logger.debug("Shutting down Ray cluster")
             ray.shutdown()
 
-    @property
-    def enable_postprocess_parallel(self) -> bool:
-        ret = super().enable_postprocess_parallel
-        assert ret == False, "Postprocess parallel is not supported in RayExecutor"
-        return ret
-
     def _get_placement_group(self,
                              tp_size: int) -> Tuple[PlacementGroup, List[int]]:
         """

@@ -317,3 +381,15 @@ def _get_placement_group(self,
         pg = placement_group(bundles, strategy=strategy)
 
         return pg, bundle_indices
+
+    @property
+    def enable_postprocess_parallel(self) -> bool:
+        ret = super().enable_postprocess_parallel
+        assert ret == False, "Postprocess parallel is not supported in RayExecutor"
+        return ret
+
+    @staticmethod
+    def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle):
+        state, _, _ = actor_handle._serialization_helper()
+        return ray.actor.ActorHandle._deserialization_helper(state,
+                                                             weak_ref=True)
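
The relocated create_actor_weak_ref builds a handle that does not keep the target actor alive, which is presumably why the executor caches weak references to its queue actors in the non-RPC path. A usage sketch under that assumption (the _EchoActor class is invented for illustration; only create_actor_weak_ref comes from the code above):

# Sketch only: what a weak actor handle buys you. _EchoActor is hypothetical.
import ray

from tensorrt_llm.executor.ray_executor import RayExecutor


@ray.remote
class _EchoActor:
    def ping(self):
        return "pong"


ray.init(ignore_reinit_error=True)
actor = _EchoActor.remote()
weak = RayExecutor.create_actor_weak_ref(actor)  # does not pin the actor alive
assert ray.get(weak.ping.remote()) == "pong"     # usable while `actor` still exists
del actor                                        # dropping the strong handle lets Ray reclaim the actor
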
