                                       placement_group)
 
 from tensorrt_llm._ray_utils import unwrap_ray_errors
-from tensorrt_llm._utils import get_free_port
+from tensorrt_llm._utils import get_free_port, nvtx_range_debug, ray_use_rpc
 from tensorrt_llm.logger import logger
 
-from .._utils import nvtx_range_debug
+from ..llmapi.utils import logger_debug
 from .executor import GenerationExecutor
 from .postproc_worker import PostprocWorkerConfig
 from .ray_gpu_worker import RayGPUWorker, RayWorkerWrapper
 from .request import GenerationRequest
 from .result import GenerationResult, RayAsyncQueue, RaySyncQueue
+from .rpc_proxy import RpcExecutorMixin
 
 __all__ = [
     "RayExecutor",
 ]
 
 
-class RayExecutor(GenerationExecutor):
+class RayExecutor(RpcExecutorMixin, GenerationExecutor):
 
     def __init__(self,
                  worker_kwargs: Dict,
@@ -75,44 +76,44 @@ def __init__(self,
             self.tp_size = tp_size
             self.master_address = ray.util.get_node_ip_address()
             self.master_port = get_free_port()
-
-            self.response_queue = RayAsyncQueue.options(runtime_env={
-                "env_vars": {
-                    "TLLM_DISABLE_MPI": "1"
-                }
-            }).remote()
-            self.response_sync_queue = RaySyncQueue.options(runtime_env={
-                "env_vars": {
-                    "TLLM_DISABLE_MPI": "1"
-                }
-            }).remote()
-            self.async_response_queue_weakref = self.create_actor_weak_ref(
-                self.response_queue)
-            self.sync_response_queue_weakref = self.create_actor_weak_ref(
-                self.response_sync_queue)
-            self.response_queue.warmup.remote()
-            self.response_sync_queue.warmup.remote()
+            self.use_rpc = ray_use_rpc()
 
             worker_kwargs = dict(**worker_kwargs,
                                  postproc_worker_config=postproc_worker_config,
                                  is_llm_executor=is_llm_executor)
 
-            self.create_workers(RayGPUWorker, worker_kwargs)
+            if self.use_rpc:
+                self.init_rpc_executor()
+                worker_kwargs['rpc_addr'] = self.rpc_addr
+                self.create_workers(RayGPUWorker, worker_kwargs)
+                self.setup_engine_remote()
+                self.setup_mainloop(tasks=[self._fetch_responses_loop_async],
+                                    thread_name="ray_executor_main_loop")
+                logger.info(f"Connecting to RPC server at {self.rpc_addr}")
+            else:
+                self.response_queue = RayAsyncQueue.options(runtime_env={
+                    "env_vars": {
+                        "TLLM_DISABLE_MPI": "1"
+                    }
+                }).remote()
+                self.response_sync_queue = RaySyncQueue.options(runtime_env={
+                    "env_vars": {
+                        "TLLM_DISABLE_MPI": "1"
+                    }
+                }).remote()
+                self.async_response_queue_weakref = self.create_actor_weak_ref(
+                    self.response_queue)
+                self.sync_response_queue_weakref = self.create_actor_weak_ref(
+                    self.response_sync_queue)
+                self.response_queue.warmup.remote()
+                self.response_sync_queue.warmup.remote()
+                self.create_workers(RayGPUWorker, worker_kwargs)
+
         except Exception as e:
-            # Clean up the Ray resources early during exception
             self.shutdown()
             logger.error(f"Failed to initialize RayExecutor: {e}")
             raise e
 
-    @staticmethod
-    def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle):
-        state, _, _ = actor_handle._serialization_helper()
-        return ray.actor.ActorHandle._deserialization_helper(state,
-                                                             weak_ref=True)
-
-    def use_ray_queue(self) -> bool:
-        return True
-
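The constructor now branches on `ray_use_rpc()` from `tensorrt_llm._utils` to pick between the new RPC transport and the original Ray-queue path. A minimal sketch of such a toggle, assuming it is driven by an environment variable (the variable name below is an assumption, not taken from this diff):

import os


def ray_use_rpc() -> bool:
    # Hypothetical toggle: the real helper lives in tensorrt_llm._utils and may
    # read a differently named variable; "TLLM_RAY_USE_RPC" is an assumption.
    return os.getenv("TLLM_RAY_USE_RPC", "0").strip().lower() in ("1", "true", "yes")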
     def create_workers(self, worker_cls, worker_kwargs):
         # When set to be a fraction, it allows Ray to schedule
         # multiple actors on a single GPU for colocate use cases.
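Both the queue actors above and the workers created here rely on Ray's `.options(...)` call: `runtime_env={"env_vars": ...}` injects environment variables into the actor process (as done for `TLLM_DISABLE_MPI`), and a fractional GPU request lets Ray colocate several actors on one device. A self-contained illustration with a toy actor (the `EchoActor` class is illustrative only, not part of TensorRT-LLM):

import ray


@ray.remote
class EchoActor:
    """Toy actor used only to demonstrate the .options(...) pattern."""

    def env(self, name: str) -> str:
        import os
        return os.getenv(name, "")


if __name__ == "__main__":
    ray.init()
    # Inject an env var into the actor's process, mirroring TLLM_DISABLE_MPI above.
    actor = EchoActor.options(runtime_env={
        "env_vars": {
            "TLLM_DISABLE_MPI": "1"
        }
    }).remote()
    assert ray.get(actor.env.remote("TLLM_DISABLE_MPI")) == "1"
    # A fractional num_gpus (e.g. 0.5) would let two such actors share one GPU,
    # which is the colocation case the comment above refers to.
    ray.shutdown()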
@@ -188,49 +189,118 @@ def collective_rpc(self,
                                              **kwargs))
         return refs if non_block else ray.get(refs)
 
-    def submit(self, request: GenerationRequest) -> GenerationResult:
+    def submit(self, request: "GenerationRequest") -> "GenerationResult":
         """
-        Low-level API to the executor. Return a "future" GenerationResult
-        which can be waited.
-        Forwards the request to the workers through the request queue.
+        Low-level API to the executor. Return a "future" GenerationResult
+        which can be waited.
+        Forwards the request to the workers through RPC or Ray queues depending on mode.
         """
         request.set_id(self._get_next_client_id())
         logprob_params = self._get_logprob_params(request)
 
-        result = GenerationResult(
-            request,
-            background_error_handler=self._handle_background_error,
-            executor=self,
-            disaggregated_params=request.disaggregated_params,
-            logprob_params=logprob_params)
-
-        with nvtx_range_debug("request_queue.put"):
-            self.call_all_ray_workers("enqueue_request",
-                                      leader_only=True,
-                                      request=request,
-                                      async_call=True,
-                                      result_wait_queue=result.queue)
+        if self.use_rpc:
+            with nvtx_range_debug("rpc_submit"):
+                self.rpc_client.submit(request).remote(need_response=False)
+
+            result = GenerationResult(
+                request,
+                background_error_handler=self._handle_background_error,
+                executor=self,
+                disaggregated_params=request.disaggregated_params,
+                logprob_params=logprob_params)
+            self._results[request.id] = result
+        else:
+            result = GenerationResult(
+                request,
+                background_error_handler=self._handle_background_error,
+                executor=self,
+                disaggregated_params=request.disaggregated_params,
+                logprob_params=logprob_params)
+
+            with nvtx_range_debug("request_queue.put"):
+                self.call_all_ray_workers("enqueue_request",
+                                          leader_only=True,
+                                          request=request,
+                                          async_call=True,
+                                          result_wait_queue=result.queue)
 
         return result
 
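In RPC mode, `submit()` only registers the result in `self._results` and pushes the request to the RPC client; delivery happens in the background `_fetch_responses_loop_async` task started in `__init__`. The following is a simplified stand-in for that dispatch idea, not the actual `RpcExecutorMixin` code; all names except `_results` are invented for illustration:

import asyncio
from typing import Any, AsyncIterator, Dict, Tuple


class TinyDispatcher:
    """Routes responses fetched by a background loop to per-request queues."""

    def __init__(self) -> None:
        self._results: Dict[int, asyncio.Queue] = {}

    def submit(self, request_id: int) -> asyncio.Queue:
        # Mirrors "self._results[request.id] = result" in submit() above.
        queue: asyncio.Queue = asyncio.Queue()
        self._results[request_id] = queue
        return queue

    async def fetch_responses_loop(
            self, responses: AsyncIterator[Tuple[int, Any]]) -> None:
        # Stand-in for _fetch_responses_loop_async(): push every incoming
        # response onto the queue of the request that produced it.
        async for request_id, payload in responses:
            await self._results[request_id].put(payload)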
+    def start(self):
+        pass
+
+    def setup_engine_remote(self):
+        return self.collective_rpc("setup_engine", non_block=False)
+
     def report_device_ids(self) -> list[str]:
         gpu_ids = self.call_all_ray_workers("report_device_id",
                                             leader_only=False,
                                             async_call=False)
         return sorted(gpu_ids)
 
+    def use_ray_queue(self) -> bool:
+        return not self.use_rpc
+
     def abort_request(self, request_id: int) -> None:
         self.call_all_ray_workers("abort_request",
                                   leader_only=True,
                                   async_call=False,
                                   request_id=request_id)
 
     def shutdown(self):
-        # Release actors
-        self.response_queue = None
-        self.response_sync_queue = None
-        self.async_response_queue_weakref = None
-        self.sync_response_queue_weakref = None
+        if hasattr(self, '_shutdown_event') and self._shutdown_event.is_set():
+            return
+        if hasattr(self, '_shutdown_event'):
+            self._shutdown_event.set()
+
+        mode_str = "RPC mode" if self.use_rpc else "Ray queue mode"
+        logger_debug(f"Shutting down RayExecutor ({mode_str})", color="yellow")
+
+        if self.use_rpc:
+            if hasattr(self, 'main_loop') and self.main_loop and hasattr(
+                    self, 'main_loop_task_obj') and self.main_loop_task_obj:
+                logger_debug("Cancelling main loop task.", color="yellow")
+                try:
+                    self.main_loop.call_soon_threadsafe(
+                        self.main_loop_task_obj.cancel)
+                except Exception as e:
+                    logger_debug(f"Error cancelling main loop task: {e}",
+                                 color="yellow")
+
+            if hasattr(self, 'main_loop_thread'):
+                self.main_loop_thread.join()
+
+            # Then, shutdown the workers
+            if hasattr(self, 'workers') and self.workers is not None:
+                try:
+                    logger_debug("Shutting down RPC remote", color="yellow")
+                    shutdown_refs = [
+                        worker.shutdown.remote() for worker in self.workers
+                    ]
+                    # Add timeout to prevent indefinite hanging
+                    ray.get(shutdown_refs, timeout=30.0)
+                except ray.exceptions.GetTimeoutError:
+                    logger.warning(
+                        "Timeout waiting for workers to shutdown after 30 seconds"
+                    )
+                except Exception as e:
+                    logger.warning(f"Error shutting down RPC remote: {e}")
+
+            if hasattr(self, 'rpc_client') and self.rpc_client is not None:
+                try:
+                    self.rpc_client.close()
+                except Exception as e:
+                    # Suppress errors during RPC client shutdown.
+                    # These can occur if the client is already closed or if there are
+                    # pending operations that get cancelled during cleanup.
+                    logger_debug(
+                        f"Suppressed error during RPC client close: {e}")
+        else:
+            # Release actors
+            self.response_queue = None
+            self.response_sync_queue = None
+            self.async_response_queue_weakref = None
+            self.sync_response_queue_weakref = None
 
         self.workers = None
         if hasattr(self,
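The reworked `shutdown()` is guarded so repeated calls are no-ops, using a `_shutdown_event` that is presumably created by `RpcExecutorMixin` (its origin is not shown in this diff), and it bounds worker teardown with `ray.get(..., timeout=30.0)`. A condensed sketch of that guard-then-teardown shape, assuming a plain `threading.Event`:

import threading


class ShutdownOnce:
    """Sketch of an idempotent shutdown guard; teardown steps are stubbed out."""

    def __init__(self) -> None:
        self._shutdown_event = threading.Event()

    def shutdown(self) -> None:
        if self._shutdown_event.is_set():
            return  # second and later calls are no-ops
        self._shutdown_event.set()
        # ... cancel the background loop, join its thread, then release
        # workers / queues / the RPC client, each wrapped in try/except ...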
@@ -246,12 +316,6 @@ def shutdown(self):
             logger.debug("Shutting down Ray cluster")
             ray.shutdown()
 
-    @property
-    def enable_postprocess_parallel(self) -> bool:
-        ret = super().enable_postprocess_parallel
-        assert ret == False, "Postprocess parallel is not supported in RayExecutor"
-        return ret
-
     def _get_placement_group(self,
                              tp_size: int) -> Tuple[PlacementGroup, List[int]]:
         """
@@ -317,3 +381,15 @@ def _get_placement_group(self,
         pg = placement_group(bundles, strategy=strategy)
 
         return pg, bundle_indices
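`_get_placement_group()` ultimately reserves one resource bundle per rank through `ray.util.placement_group`. A minimal, self-contained example of that API; the bundle shape, the `PACK` strategy, and the `tp_size` value are illustrative, not the exact values the production code computes:

import ray
from ray.util.placement_group import placement_group

ray.init()
tp_size = 2
# One bundle per tensor-parallel rank; real bundles may request different resources.
bundles = [{"CPU": 1, "GPU": 1} for _ in range(tp_size)]
pg = placement_group(bundles, strategy="PACK")
ray.get(pg.ready())  # block until every bundle has been reserved
bundle_indices = list(range(tp_size))
ray.shutdown()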
+
+    @property
+    def enable_postprocess_parallel(self) -> bool:
+        ret = super().enable_postprocess_parallel
+        assert ret == False, "Postprocess parallel is not supported in RayExecutor"
+        return ret
+
+    @staticmethod
+    def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle):
+        state, _, _ = actor_handle._serialization_helper()
+        return ray.actor.ActorHandle._deserialization_helper(state,
+                                                             weak_ref=True)