@@ -1,4 +1,7 @@
+import asyncio
+import atexit
 import os
+import threading
 from typing import Any, Dict, List, Optional, Tuple

 try:
@@ -17,11 +20,16 @@
 from tensorrt_llm.logger import logger

 from .._utils import nvtx_range_debug
+from ..llmapi.tracer import global_tracer
+from ..llmapi.utils import _SyncQueue, logger_debug
 from .executor import GenerationExecutor
 from .postproc_worker import PostprocWorkerConfig
 from .ray_gpu_worker import RayGPUWorker, RayWorkerWrapper
 from .request import GenerationRequest
-from .result import GenerationResult, RayAsyncQueue, RaySyncQueue
+from .result import GenerationResult
+from .rpc import RPCClient
+from .rpc.rpc_common import get_unique_ipc_addr
+from .utils import ErrorResponse, is_llm_response

 __all__ = [
     "RayExecutor",
@@ -76,28 +84,25 @@ def __init__(self,
             self.master_address = ray.util.get_node_ip_address()
             self.master_port = get_free_port()

-            self.response_queue = RayAsyncQueue.options(runtime_env={
-                "env_vars": {
-                    "TLLM_DISABLE_MPI": "1"
-                }
-            }).remote()
-            self.response_sync_queue = RaySyncQueue.options(runtime_env={
-                "env_vars": {
-                    "TLLM_DISABLE_MPI": "1"
-                }
-            }).remote()
-            self.async_response_queue_weakref = self.create_actor_weak_ref(
-                self.response_queue)
-            self.sync_response_queue_weakref = self.create_actor_weak_ref(
-                self.response_sync_queue)
-            self.response_queue.warmup.remote()
-            self.response_sync_queue.warmup.remote()
+            self.rpc_addr = get_unique_ipc_addr()
+            self.rpc_client = RPCClient(self.rpc_addr)
+            print(f"====RPC client created at {self.rpc_addr}")
+
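+            # In-flight requests keyed by client_id: submit() registers each
+            # GenerationResult here and handle_responses() routes streamed
+            # responses back to it, popping the entry once the result is final.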
+            self._results = {}
+            self._shutdown_event = threading.Event()
+            self.main_loop_task_obj = None
+            self.main_loop = None

             worker_kwargs = dict(**worker_kwargs,
                                  postproc_worker_config=postproc_worker_config,
-                                 is_llm_executor=is_llm_executor)
+                                 is_llm_executor=is_llm_executor,
+                                 rpc_addr=self.rpc_addr)

             self.create_workers(RayGPUWorker, worker_kwargs)
+
+            logger.info("Setting up engine via RPC")
+            self.setup_engine_remote()
+            self.setup_mainloop()
         except Exception as e:
             # Clean up the Ray resources early during exception
             self.shutdown()
@@ -110,8 +115,103 @@ def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle):
         return ray.actor.ActorHandle._deserialization_helper(state,
                                                               weak_ref=True)

-    def use_ray_queue(self) -> bool:
-        return True
+    async def _generic_fetch_loop_async(self, fetch_method_name: str,
+                                        handler_method, method_name: str):
+        # TODO copied from GenerationExecutorRpcProxy, need refactoring.
+        """Generic method for fetching data in a loop from RPC worker.
+
+        Args:
+            fetch_method_name: Name of the RPC client method to call
+            handler_method: The handler method to call with the fetched data
+            method_name: Name of the method for logging
+        """
+        try:
+            fetch_method = getattr(self.rpc_client, fetch_method_name)
+            async for data in fetch_method().remote_streaming():
+                if self._shutdown_event.is_set():
+                    return
+                handler_method(data)
+        except asyncio.CancelledError:
+            logger.debug(f"{method_name} task cancelled")
+        except Exception as e:
+            logger.error(f"Error in {method_name}: {e}")
+            raise
+
+    async def _fetch_responses_loop_async(self):
+        # TODO copied from GenerationExecutorRpcProxy, need refactoring.
+        await self._generic_fetch_loop_async(
+            fetch_method_name="fetch_responses_loop_async",
+            handler_method=self.handle_responses,
+            method_name="_fetch_responses_loop_async")
+
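+    # The response stream above is consumed from a dedicated daemon thread with
+    # its own event loop (created in setup_mainloop() below), so fetching
+    # responses never blocks the thread that created the executor.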
+    def setup_mainloop(self):
+        # TODO copied from GenerationExecutorRpcProxy, need refactoring.
+        async def main_loop_task():
+            await self._fetch_responses_loop_async()
+
+        def _run_main_loop_task():
+            """Local method to run the main loop task."""
+            self.main_loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(self.main_loop)
+
+            self.main_loop_task_obj = self.main_loop.create_task(
+                main_loop_task())
+            try:
+                self.main_loop.run_until_complete(self.main_loop_task_obj)
+            except asyncio.CancelledError:
+                pass  # Task cancellation is expected during shutdown
+            finally:
+                self.main_loop.close()
+
+        self.main_loop_thread = threading.Thread(target=_run_main_loop_task,
+                                                 daemon=True,
+                                                 name="ray_executor_main_loop")
+        self.main_loop_thread.start()
+        atexit.register(self.shutdown)
+
+    def setup_engine_remote(self):
+        return self.collective_rpc("setup_engine", non_block=False)
+
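+    # handle_responses() below fans each streamed batch out to the per-request
+    # result queues; puts to _SyncQueue instances are batched and their owning
+    # event loop is notified once at the end via _SyncQueue.notify_many().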
+    def handle_responses(self, responses: list[GenerationResult]) -> bool:
+        # TODO copied from GenerationExecutorRpcProxy, need refactoring.
+        async_queues = []
+        event_loop = None
+
+        def process_res(res: list):
+            for r in res:
+                client_id = r.client_id
+                nonlocal event_loop
+                nonlocal async_queues
+
+                if client_id not in self._results:
+                    logger.warning(
+                        f"Received response for unknown client_id: {client_id}")
+                    continue
+
+                queue = self._results[client_id].queue
+                if isinstance(queue, _SyncQueue):
+                    queue.put_nowait(r)
+                    async_queues.append(queue)
+                    # all the loops are identical
+                    event_loop = event_loop or queue.loop
+                else:
+                    queue.put(r)
+
+                if (is_llm_response(r) and r.result.is_final) or isinstance(
+                        r, ErrorResponse):
+                    self._results.pop(client_id)
+
+        # Handle the case where responses might not be a list of lists
+        if responses and not isinstance(responses[0], list):
+            # If responses is a flat list, wrap it
+            responses = [responses]
+
+        for res in responses:
+            global_tracer().log_instant("RPC.get")
+            process_res(res)
+
+        if async_queues:
+            _SyncQueue.notify_many(event_loop, async_queues)

     def create_workers(self, worker_cls, worker_kwargs):
         # When set to be a fraction, it allows Ray to schedule
@@ -192,27 +292,27 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
192292 """
193293 Low-level API to the executor. Return a "future" GenerationResult
194294 which can be waited.
195- Forwards the request to the workers through the request queue .
295+ Forwards the request to the workers through RPC .
196296 """
197297 request .set_id (self ._get_next_client_id ())
198298 logprob_params = self ._get_logprob_params (request )
199299
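+        # The request is submitted fire-and-forget (need_response=False); the
+        # generated outputs come back through the streaming response loop.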
+        with nvtx_range_debug("rpc_submit"):
+            self.rpc_client.submit(request).remote(need_response=False)
+
         result = GenerationResult(
             request,
             background_error_handler=self._handle_background_error,
             executor=self,
             disaggregated_params=request.disaggregated_params,
             logprob_params=logprob_params)
-
-        with nvtx_range_debug("request_queue.put"):
-            self.call_all_ray_workers("enqueue_request",
-                                      leader_only=True,
-                                      request=request,
-                                      async_call=True,
-                                      result_wait_queue=result.queue)
+        self._results[request.id] = result

         return result

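+    # start() is a no-op here; engine setup and the background response loop
+    # are already started in __init__ via setup_engine_remote() and
+    # setup_mainloop().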
+    def start(self):
+        pass
+
     def report_device_ids(self) -> list[str]:
         gpu_ids = self.call_all_ray_workers("report_device_id",
                                             leader_only=False,
@@ -225,12 +325,45 @@ def abort_request(self, request_id: int) -> None:
                                   async_call=False,
                                   request_id=request_id)

+    # TODO: Use Ray RPC to shutdown RPC server, and then close client
     def shutdown(self):
-        # Release actors
-        self.response_queue = None
-        self.response_sync_queue = None
-        self.async_response_queue_weakref = None
-        self.sync_response_queue_weakref = None
+        try:
+            self.shutdown_impl()
+        except Exception as e:
+            # TODO: clean up
+            print(f"Error shutting down RayExecutor: {e}")
+            raise e
+
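+    # Teardown order in shutdown_impl(): set the shutdown event, ask the Ray
+    # workers to shut down, cancel the response-fetch task, join its thread,
+    # and finally close the RPC client.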
+    def shutdown_impl(self):
+        if self._shutdown_event.is_set():
+            return
+        self._shutdown_event.set()
+        logger_debug(f"Shutting down RayExecutor (RPC mode)", color="yellow")
+
+        if hasattr(self, 'rpc_client') and self.rpc_client is not None:
+            try:
+                logger_debug("Shutting down RPC remote", color="yellow")
+                self.call_all_ray_workers("shutdown",
+                                          leader_only=False,
+                                          async_call=False)
+            except Exception as e:
+                logger.warning(f"Error shutting down RPC remote: {e}")
+
+        if hasattr(self, 'main_loop') and self.main_loop and hasattr(
+                self, 'main_loop_task_obj') and self.main_loop_task_obj:
+            logger_debug("Cancelling main loop task.", color="yellow")
+            try:
+                self.main_loop.call_soon_threadsafe(
+                    self.main_loop_task_obj.cancel)
+            except Exception as e:
+                logger_debug(f"Error cancelling main loop task: {e}",
+                             color="yellow")
+
+        if hasattr(self, 'main_loop_thread'):
+            self.main_loop_thread.join()
+
+        if hasattr(self, 'rpc_client') and self.rpc_client is not None:
+            self.rpc_client.close()

         self.workers = None
         if hasattr(self,