Skip to content

Commit abef7de

Browse files
committed
Fix some issues discovered in testing
Signed-off-by: Dan Hansen <[email protected]>
1 parent 1756418 commit abef7de

File tree

3 files changed

+14
-7
lines changed

3 files changed

+14
-7
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,8 @@ def shutdown(self):
450450
for key in keys:
451451
del self.virtual_memory_pools[key]
452452
# Stop the sampler's async worker, if it was used
453-
if isinstance(self.sampler, AsyncWorkerMixin):
453+
if (isinstance(self.sampler, AsyncWorkerMixin)
454+
and self.async_worker_enabled()):
454455
self.sampler.async_worker_stop()
455456

456457
def can_enqueue_requests(self) -> bool:

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,8 @@ def update_sampler_max_seq_len(max_seq_len, sampler):
196196

197197

198198
def maybe_start_sampler_async_worker(sampler):
199-
if isinstance(sampler, AsyncWorkerMixin) and sampler.enable_async_worker:
199+
if (isinstance(sampler, AsyncWorkerMixin)
200+
and sampler.async_worker_enabled()):
200201
sampler.async_worker_start()
201202

202203

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -605,12 +605,15 @@ def _async_worker_active(self) -> bool:
605605
return self._async_worker is not None
606606

607607
def _async_worker_init(self, enable_async_worker: bool):
608-
self.enable_async_worker = enable_async_worker
608+
self._enable_async_worker = enable_async_worker
609609
self._async_worker = None
610610
self._async_worker_futures: list[futures.Future[any]] = []
611611

612+
def async_worker_enabled(self):
613+
return hasattr(self, "_enable_async_worker") and self._enable_async_worker
614+
612615
def async_worker_start(self):
613-
assert self.enable_async_worker
616+
assert self.async_worker_enabled()
614617
assert not self._async_worker_active()
615618

616619
def _async_worker_initializer(device_id):
@@ -628,10 +631,12 @@ def _async_worker_initializer(device_id):
628631
)
629632

630633
def async_worker_stop(self):
631-
if self._async_worker_active():
632-
self._async_worker.shutdown(wait=True)
633-
self._async_worker = None
634+
assert self.async_worker_enabled()
635+
assert self._async_worker_active()
636+
self._async_worker.shutdown(wait=True)
637+
self._async_worker = None
634638

639+
@torch.inference_mode()
635640
def _async_worker_run(self, ready: torch.cuda.Event, func, /, *args, **kwargs):
636641
# Make sure the async work takes place after all prior operations on
637642
# the primary stream. synchronize() is intentionally chosen instead of

0 commit comments

Comments (0)