Skip to content

Commit 6d6960b

Browse files
committed
Explicitly track and wait on all futures; makes it possible to scale to > 1 thread correctly in the future
Signed-off-by: Dan Hansen <[email protected]>
1 parent b0e1469 commit 6d6960b

File tree

1 file changed

+44
-22
lines changed

1 file changed

+44
-22
lines changed

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,20 @@ def values(self):
9393
return vars(self).values()
9494

9595

96+
@dataclass(kw_only=True)
97+
class SamplerEvent:
98+
cuda_event: torch.cuda.Event
99+
worker_futures: Optional[list[futures.Future[Any]]] = None
100+
101+
96102
@dataclass(kw_only=True)
97103
class SampleState:
98104
scheduled_requests: ScheduledRequests
99105

100106
device: Optional[SampleStateTensors] = None
101107
host: Optional[SampleStateTensors] = None
102108

103-
sampler_event: Optional[torch.cuda.Event] = None
109+
sampler_event: Optional[SamplerEvent] = None
104110

105111

106112
class Sampler(ABC):
@@ -595,12 +601,15 @@ class AsyncWorkerMixin:
595601
operations will seamlessly run on the main thread
596602
"""
597603

604+
MAX_WORKERS = 1
605+
598606
def _async_worker_active(self) -> bool:
599607
return self._async_worker is not None
600608

601609
def _async_worker_init(self, enable_async_worker: bool):
602610
self.enable_async_worker = enable_async_worker
603611
self._async_worker = None
612+
self._async_worker_futures: list[futures.Future[any]] = []
604613

605614
def async_worker_start(self):
606615
assert self.enable_async_worker
@@ -615,7 +624,7 @@ def _async_worker_initializer(device_id):
615624
torch.cuda.set_stream(torch.cuda.Stream())
616625

617626
self._async_worker = futures.ThreadPoolExecutor(
618-
max_workers=1,
627+
max_workers=self.MAX_WORKERS,
619628
initializer=_async_worker_initializer,
620629
initargs=(torch.cuda.current_device(),),
621630
)
@@ -633,41 +642,54 @@ def _async_worker_run(self, ready: torch.cuda.Event, func, /, *args, **kwargs):
633642
ready.synchronize()
634643

635644
# Do the work
636-
return func(*args, **kwargs)
645+
result = func(*args, **kwargs)
646+
647+
# Work submitted to the async worker is expected to block at the end,
648+
# consistent with the semantics of futures; make sure that we wait for
649+
# everything to complete
650+
torch.cuda.current_stream().synchronize()
651+
652+
return result
637653

638654
def _async_worker_submit(self, func, /, *args, **kwargs):
639655
if self._async_worker_active():
640656
# Record an event on the main thread/stream that we will
641657
# synchronize with on the worker thread/stream
642658
ready = torch.cuda.Event()
643659
ready.record()
644-
return self._async_worker.submit(self._async_worker_run, ready, func, *args, **kwargs)
660+
661+
# Submit the async work
662+
result = self._async_worker.submit(self._async_worker_run, ready, func, *args, **kwargs)
663+
664+
# Save the future, so that we can await it later
665+
self._async_worker_futures.append(result)
666+
667+
return result
645668
else:
646669
# If the async worker is not in use, just execute the function
647670
return func(*args, **kwargs)
648671

649-
def _copy_to_host(self, src: torch.Tensor, pin_memory=False) -> torch.Tensor:
650-
dest = torch.empty_like(src, device="cpu", pin_memory=pin_memory)
672+
def _copy_to_host(self, src: torch.Tensor) -> torch.Tensor:
673+
dest = torch.empty_like(src, device="cpu", pin_memory=True)
651674
self._async_worker_submit(dest.copy_, src, non_blocking=True)
652675
return dest
653676

654-
def _sampler_event_get(self) -> torch.cuda.Event | futures.Future[torch.cuda.Event]:
655-
def _get_sampler_event() -> torch.cuda.Event:
656-
sampler_event = torch.cuda.Event()
657-
sampler_event.record()
658-
return sampler_event
677+
def _sampler_event_get(self) -> SamplerEvent:
678+
cuda_event = torch.cuda.Event()
679+
cuda_event.record()
659680

660-
return self._async_worker_submit(_get_sampler_event)
681+
# Transfer ownership to worker_futures and re-initialize
682+
worker_futures = self._async_worker_futures
683+
self._async_worker_futures = []
684+
685+
return SamplerEvent(cuda_event=cuda_event, worker_futures=worker_futures)
661686

662687
@staticmethod
663-
def _sampler_event_synchronize(
664-
sampler_event: torch.cuda.Event | futures.Future[torch.cuda.Event] | None,
665-
):
688+
def _sampler_event_synchronize(sampler_event: SamplerEvent):
666689
if sampler_event:
667-
if isinstance(sampler_event, futures.Future):
668-
sampler_event.result().synchronize()
669-
else:
670-
sampler_event.synchronize()
690+
if sampler_event.worker_futures:
691+
futures.wait(sampler_event.worker_futures)
692+
sampler_event.cuda_event.synchronize()
671693

672694

673695
class TorchSampler(Sampler, AsyncWorkerMixin):
@@ -1229,7 +1251,7 @@ def sample_async(
12291251
self._write_finish_reasons(
12301252
requests, finish_reasons=finish_reasons, seq_slots=seq_slots, new_tokens=new_tokens
12311253
)
1232-
finish_reasons_host = finish_reasons.to(device="cpu", non_blocking=True)
1254+
finish_reasons_host = self._copy_to_host(finish_reasons)
12331255

12341256
sampler_event = self._sampler_event_get()
12351257
return SampleStateTorch(
@@ -1852,8 +1874,8 @@ def _process_requests(
18521874
logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1
18531875
)
18541876
# Use a single D2H copy to reduce overheads
1855-
topk_vals = self._copy_to_host(topk_vals_cuda, pin_memory=True)
1856-
topk_indices = self._copy_to_host(topk_indices_cuda, pin_memory=True)
1877+
topk_vals = self._copy_to_host(topk_vals_cuda)
1878+
topk_indices = self._copy_to_host(topk_indices_cuda)
18571879
current_offset = 0
18581880
for req_id, steps in zip(
18591881
logprobs_req_indices, req_num_steps[logprobs_req_indices].tolist()

0 commit comments

Comments (0)