
Commit 16ab0c6

Explicitly track and wait on all futures; makes it possible to scale to > 1 thread correctly in the future
Signed-off-by: Dan Hansen <[email protected]>
1 parent 52f9670 · commit 16ab0c6
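
In short: every piece of work handed to the async worker is now recorded as a concurrent.futures.Future, and the sampler waits on all of them (in addition to the CUDA event) before results are considered ready, so correctness no longer depends on the executor having exactly one worker thread. A minimal, self-contained sketch of that pattern, illustrative only (AsyncCopyTracker and drain are hypothetical names, not part of tensorrt_llm):

    # Minimal sketch of the "track every future, wait on them all" pattern.
    # Hypothetical names (AsyncCopyTracker, drain); not tensorrt_llm code.
    from concurrent import futures
    from typing import Any, Callable

    class AsyncCopyTracker:
        def __init__(self, max_workers: int = 1) -> None:
            self._executor = futures.ThreadPoolExecutor(max_workers=max_workers)
            self._pending: list[futures.Future[Any]] = []

        def submit(self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> futures.Future[Any]:
            # Remember every future instead of relying on submission order.
            fut = self._executor.submit(fn, *args, **kwargs)
            self._pending.append(fut)
            return fut

        def drain(self) -> None:
            # Take ownership of the outstanding futures and wait on all of
            # them; this step keeps correctness independent of worker count.
            pending, self._pending = self._pending, []
            futures.wait(pending)

    tracker = AsyncCopyTracker(max_workers=2)
    results = [tracker.submit(pow, 2, n) for n in range(8)]
    tracker.drain()  # all submitted work is guaranteed to have finished here
    print([f.result() for f in results])

With max_workers=1 this behaves like the previous implementation; the explicit futures.wait on every tracked future is what the commit adds so the same code stays correct with more workers.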

1 file changed: +44 −22 lines changed

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 44 additions & 22 deletions
@@ -91,14 +91,20 @@ def values(self):
         return vars(self).values()
 
 
+@dataclass(kw_only=True)
+class SamplerEvent:
+    cuda_event: torch.cuda.Event
+    worker_futures: Optional[list[futures.Future[Any]]] = None
+
+
 @dataclass(kw_only=True)
 class SampleState:
     scheduled_requests: ScheduledRequests
 
     device: Optional[SampleStateTensors] = None
     host: Optional[SampleStateTensors] = None
 
-    sampler_event: Optional[torch.cuda.Event] = None
+    sampler_event: Optional[SamplerEvent] = None
 
 
 class Sampler(ABC):
@@ -593,12 +599,15 @@ class AsyncWorkerMixin:
     operations will seamlessly run on the main thread
     """
 
+    MAX_WORKERS = 1
+
     def _async_worker_active(self) -> bool:
         return self._async_worker is not None
 
     def _async_worker_init(self, enable_async_worker: bool):
         self.enable_async_worker = enable_async_worker
         self._async_worker = None
+        self._async_worker_futures: list[futures.Future[any]] = []
 
     def async_worker_start(self):
         assert self.enable_async_worker
@@ -613,7 +622,7 @@ def _async_worker_initializer(device_id):
             torch.cuda.set_stream(torch.cuda.Stream())
 
         self._async_worker = futures.ThreadPoolExecutor(
-            max_workers=1,
+            max_workers=self.MAX_WORKERS,
             initializer=_async_worker_initializer,
             initargs=(torch.cuda.current_device(),),
         )
@@ -631,41 +640,54 @@ def _async_worker_run(self, ready: torch.cuda.Event, func, /, *args, **kwargs):
         ready.synchronize()
 
         # Do the work
-        return func(*args, **kwargs)
+        result = func(*args, **kwargs)
+
+        # Work submitted to the async worker is expected to block at the end,
+        # consistent with the semantics of futures; make sure that we wait for
+        # everything to complete
+        torch.cuda.current_stream().synchronize()
+
+        return result
 
     def _async_worker_submit(self, func, /, *args, **kwargs):
         if self._async_worker_active():
             # Record an event on the main thread/stream that we will
             # synchronize with on the worker thread/stream
             ready = torch.cuda.Event()
             ready.record()
-            return self._async_worker.submit(self._async_worker_run, ready, func, *args, **kwargs)
+
+            # Submit the async work
+            result = self._async_worker.submit(self._async_worker_run, ready, func, *args, **kwargs)
+
+            # Save the future, so that we can await it later
+            self._async_worker_futures.append(result)
+
+            return result
         else:
             # If the async worker is not in use, just execute the function
             return func(*args, **kwargs)
 
-    def _copy_to_host(self, src: torch.Tensor, pin_memory=False) -> torch.Tensor:
-        dest = torch.empty_like(src, device="cpu", pin_memory=pin_memory)
+    def _copy_to_host(self, src: torch.Tensor) -> torch.Tensor:
+        dest = torch.empty_like(src, device="cpu", pin_memory=True)
         self._async_worker_submit(dest.copy_, src, non_blocking=True)
         return dest
 
-    def _sampler_event_get(self) -> torch.cuda.Event | futures.Future[torch.cuda.Event]:
-        def _get_sampler_event() -> torch.cuda.Event:
-            sampler_event = torch.cuda.Event()
-            sampler_event.record()
-            return sampler_event
+    def _sampler_event_get(self) -> SamplerEvent:
+        cuda_event = torch.cuda.Event()
+        cuda_event.record()
 
-        return self._async_worker_submit(_get_sampler_event)
+        # Transfer ownership to worker_futures and re-initialize
+        worker_futures = self._async_worker_futures
+        self._async_worker_futures = []
+
+        return SamplerEvent(cuda_event=cuda_event, worker_futures=worker_futures)
 
     @staticmethod
-    def _sampler_event_synchronize(
-        sampler_event: torch.cuda.Event | futures.Future[torch.cuda.Event] | None,
-    ):
+    def _sampler_event_synchronize(sampler_event: SamplerEvent):
         if sampler_event:
-            if isinstance(sampler_event, futures.Future):
-                sampler_event.result().synchronize()
-            else:
-                sampler_event.synchronize()
+            if sampler_event.worker_futures:
+                futures.wait(sampler_event.worker_futures)
+            sampler_event.cuda_event.synchronize()
 
 
 class TorchSampler(Sampler, AsyncWorkerMixin):
@@ -1216,7 +1238,7 @@ def sample_async(
         self._write_finish_reasons(
             requests, finish_reasons=finish_reasons, seq_slots=seq_slots, new_tokens=new_tokens
         )
-        finish_reasons_host = finish_reasons.to(device="cpu", non_blocking=True)
+        finish_reasons_host = self._copy_to_host(finish_reasons)
 
         sampler_event = self._sampler_event_get()
         return SampleStateTorch(
@@ -1839,8 +1861,8 @@ def _process_requests(
                logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1
            )
            # Use a single D2H copy to reduce overheads
-           topk_vals = self._copy_to_host(topk_vals_cuda, pin_memory=True)
-           topk_indices = self._copy_to_host(topk_indices_cuda, pin_memory=True)
+           topk_vals = self._copy_to_host(topk_vals_cuda)
+           topk_indices = self._copy_to_host(topk_indices_cuda)
            current_offset = 0
            for req_id, steps in zip(
                logprobs_req_indices, req_num_steps[logprobs_req_indices].tolist()
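
Why both waits in the new _sampler_event_synchronize are needed: the async worker runs on its own CUDA stream (see _async_worker_initializer above), so the cuda_event recorded on the main stream does not cover the worker's device-to-host copies. Waiting on the tracked futures first ensures the worker-side tasks, each of which ends with a stream synchronize in _async_worker_run, have completed, and only then is the main-stream event awaited. Roughly, the consumer side reduces to the following sketch (reconstructed from the diff above, not verbatim library code; wait_for_sample_state is a hypothetical helper):

    from concurrent import futures

    def wait_for_sample_state(state) -> None:
        # Sketch mirroring AsyncWorkerMixin._sampler_event_synchronize above.
        event = state.sampler_event
        if event is not None:
            # 1) Wait for every tracked async-worker task, e.g. the pinned-
            #    memory D2H copies issued through _copy_to_host.
            if event.worker_futures:
                futures.wait(event.worker_futures)
            # 2) Then wait for the GPU work recorded on the main stream.
            event.cuda_event.synchronize()
        # After both waits, the host-side tensors in the state are safe to read.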
