From 0d35ef0f161703bb7f80c948e3d487a6e5fa7d3b Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 26 Jun 2025 15:08:53 +0800 Subject: [PATCH] ray close wait for forward finish --- lmdeploy/pytorch/engine/executor/base_worker.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/engine/executor/base_worker.py b/lmdeploy/pytorch/engine/executor/base_worker.py index 93870c6bda..2cfa3dbd3f 100644 --- a/lmdeploy/pytorch/engine/executor/base_worker.py +++ b/lmdeploy/pytorch/engine/executor/base_worker.py @@ -50,6 +50,7 @@ def __init__( logger.setLevel(log_level) self.out_que: asyncio.Queue = None self._output_loop: asyncio.Task = None + self._forward_event: asyncio.Event = None def init_process_group(self, rank: int, master_addr: str = None, master_port: str = None): """Initialize process group.""" @@ -136,7 +137,9 @@ def get_input_processor(self): def start(self): """Start engine loop.""" - self.model_agent.start() + self._forward_event = asyncio.Event() + self._forward_event.set() # Set the event to allow forward calls + self.model_agent.start(self._forward_event) event_loop = asyncio.get_event_loop() self.out_que = asyncio.Queue() self._output_loop = event_loop.create_task(self._get_outputs_loop(), name='GetOutputsLoop') @@ -148,6 +151,7 @@ def stop(self): self._output_loop.cancel() async def stop_async(self): + await self._forward_event.wait() # Ensure forward event is set before stopping await self.model_agent.stop_async() if self._output_loop is not None: self._output_loop.cancel()