Skip to content

Commit 0b24a61

Browse files
committed
Fix some issues discovered in testing
Signed-off-by: Dan Hansen <[email protected]>
1 parent 6d6960b commit 0b24a61

File tree

3 files changed

+14
-7
lines changed

3 files changed

+14
-7
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,8 @@ def shutdown(self):
450450
for key in keys:
451451
del self.virtual_memory_pools[key]
452452
# Stop the sampler's async worker, if it was used
453-
if isinstance(self.sampler, AsyncWorkerMixin):
453+
if (isinstance(self.sampler, AsyncWorkerMixin)
454+
and self.sampler.async_worker_enabled()):
454455
self.sampler.async_worker_stop()
455456

456457
def can_enqueue_requests(self) -> bool:

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,8 @@ def update_sampler_max_seq_len(max_seq_len, sampler):
196196

197197

198198
def maybe_start_sampler_async_worker(sampler):
199-
if isinstance(sampler, AsyncWorkerMixin) and sampler.enable_async_worker:
199+
if (isinstance(sampler, AsyncWorkerMixin)
200+
and sampler.async_worker_enabled()):
200201
sampler.async_worker_start()
201202

202203

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -607,12 +607,15 @@ def _async_worker_active(self) -> bool:
607607
return self._async_worker is not None
608608

609609
def _async_worker_init(self, enable_async_worker: bool):
610-
self.enable_async_worker = enable_async_worker
610+
self._enable_async_worker = enable_async_worker
611611
self._async_worker = None
612612
self._async_worker_futures: list[futures.Future[any]] = []
613613

614+
def async_worker_enabled(self):
615+
return hasattr(self, "_enable_async_worker") and self._enable_async_worker
616+
614617
def async_worker_start(self):
615-
assert self.enable_async_worker
618+
assert self.async_worker_enabled()
616619
assert not self._async_worker_active()
617620

618621
def _async_worker_initializer(device_id):
@@ -630,10 +633,12 @@ def _async_worker_initializer(device_id):
630633
)
631634

632635
def async_worker_stop(self):
633-
if self._async_worker_active():
634-
self._async_worker.shutdown(wait=True)
635-
self._async_worker = None
636+
assert self.async_worker_enabled()
637+
assert self._async_worker_active()
638+
self._async_worker.shutdown(wait=True)
639+
self._async_worker = None
636640

641+
@torch.inference_mode()
637642
def _async_worker_run(self, ready: torch.cuda.Event, func, /, *args, **kwargs):
638643
# Make sure the async work takes place after all prior operations on
639644
# the primary stream. synchronize() is intentionally chosen instead of

0 commit comments

Comments (0)