Commit d43f4a7

Refactor to further compartmentalize worker from core sampler code
Signed-off-by: Dan Hansen <[email protected]>
1 parent 64c8a67 commit d43f4a7

File tree

1 file changed: +78 -115 lines changed


tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 78 additions & 115 deletions
@@ -84,7 +84,7 @@
 
 @dataclass(kw_only=True)
 class SampleStateTensors:
-    new_tokens: torch.Tensor | futures.Future[torch.Tensor]
+    new_tokens: torch.Tensor
     log_probs: torch.Tensor | None = None
 
     def values(self):
@@ -623,22 +623,15 @@ def async_worker_stop(self):
         self._async_worker.shutdown(wait=True)
         self._async_worker = None
 
-    def _async_worker_run(self, ready, func, /, *args, **kwargs):
+    def _async_worker_run(self, ready: torch.cuda.Event, func, /, *args, **kwargs):
         # Make sure the async work takes place after all prior operations on
         # the primary stream. synchronize() is intentionally chosen instead of
         # wait() here; otherwise, blocking copies will stall subsequent CUDA
         # API calls on the main thread
         ready.synchronize()
 
         # Do the work
-        result = func(*args, **kwargs)
-
-        # work submitted to the async worker is expected to block at the end,
-        # consistent with the semantics of futures; make sure that we wait for
-        # everything to complete
-        torch.cuda.current_stream().synchronize()
-
-        return result
+        return func(*args, **kwargs)
 
     def _async_worker_submit(self, func, /, *args, **kwargs):
         if self._async_worker_active():
@@ -651,6 +644,29 @@ def _async_worker_submit(self, func, /, *args, **kwargs):
         # If the async worker is not in use, just execute the function
         return func(*args, **kwargs)
 
+    def _copy_to_host(self, src: torch.Tensor, pin_memory=False) -> torch.Tensor:
+        dest = torch.empty_like(src, device="cpu", pin_memory=pin_memory)
+        self._async_worker_submit(dest.copy_, src, non_blocking=True)
+        return dest
+
+    def _sampler_event_get(self) -> torch.cuda.Event | futures.Future[torch.cuda.Event]:
+        def _get_sampler_event() -> torch.cuda.Event:
+            sampler_event = torch.cuda.Event()
+            sampler_event.record()
+            return sampler_event
+
+        return self._async_worker_submit(_get_sampler_event)
+
+    @staticmethod
+    def _sampler_event_synchronize(
+        sampler_event: torch.cuda.Event | futures.Future[torch.cuda.Event] | None,
+    ):
+        if sampler_event:
+            if isinstance(sampler_event, futures.Future):
+                sampler_event.result().synchronize()
+            else:
+                sampler_event.synchronize()
+
 
 class TorchSampler(Sampler):
     SampleState = SampleStateTorch
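Note (not part of the diff): the contract behind _async_worker_run, _copy_to_host and _sampler_event_get can be reproduced in isolation. The sketch below is a minimal stand-in that assumes a single-thread executor and one shared CUDA stream; the names AsyncCopyHelper, submit, copy_to_host and event are illustrative, not the mixin's real API.

from concurrent import futures

import torch


class AsyncCopyHelper:
    def __init__(self):
        # One worker thread preserves submission order, so an event recorded
        # through the same worker is ordered after every earlier copy.
        self._worker = futures.ThreadPoolExecutor(max_workers=1)

    def _run(self, ready: torch.cuda.Event, func, /, *args, **kwargs):
        # Block the worker thread (not the main thread) until work recorded
        # before submission has finished on the primary stream.
        ready.synchronize()
        return func(*args, **kwargs)

    def submit(self, func, /, *args, **kwargs) -> futures.Future:
        ready = torch.cuda.Event()
        ready.record()
        return self._worker.submit(self._run, ready, func, *args, **kwargs)

    def copy_to_host(self, src: torch.Tensor, pin_memory: bool = False) -> torch.Tensor:
        # Returns immediately; the buffer is only valid once a later event
        # obtained via event() has been synchronized.
        dest = torch.empty_like(src, device="cpu", pin_memory=pin_memory)
        self.submit(dest.copy_, src, non_blocking=True)
        return dest

    def event(self) -> futures.Future:
        def _record() -> torch.cuda.Event:
            ev = torch.cuda.Event()
            ev.record()
            return ev

        return self.submit(_record)


if torch.cuda.is_available():
    helper = AsyncCopyHelper()
    src = torch.randn(4, 8, device="cuda")
    host = helper.copy_to_host(src, pin_memory=True)
    helper.event().result().synchronize()  # fences the copy submitted above
    assert torch.equal(host, src.cpu())

Under these assumptions, the event is recorded by the same worker that enqueued the copies onto the same stream, so synchronizing it is enough to make every earlier host buffer valid; that is why the trailing current_stream().synchronize() in _async_worker_run could be dropped.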
@@ -675,7 +691,7 @@ class Args:
         max_num_sequences: int
         max_beam_width: int
         max_total_draft_tokens: int
-        enable_async_worker: Optional[bool] = False
+        enable_async_worker: bool = False
 
     def __init__(self, args: Args):
        self.max_seq_len = args.max_seq_len
@@ -797,17 +813,8 @@ def handle_logprobs(
         count: int,
     ):
         if request.py_return_log_probs:
-            if self._async_worker_active():
-                # These should be futures if we used the async worker
-                assert isinstance(request.py_topk_logprobs_values, futures.Future)
-                assert isinstance(request.py_topk_logprobs_vals, futures.Future)
-                topk_log_probs_vals = request.py_topk_logprobs_vals.result()
-                topk_log_probs_indices = request.py_topk_logprobs_indices.result()
-            else:
-                topk_log_probs_vals = request.py_topk_logprobs_vals
-                topk_log_probs_indices = request.py_topk_logprobs_indices
-            topk_log_probs_vals = topk_log_probs_vals[:count]
-            topk_log_probs_indices = topk_log_probs_indices[:count]
+            topk_log_probs_vals = request.py_topk_logprobs_vals[:count]
+            topk_log_probs_indices = request.py_topk_logprobs_indices[:count]
 
             token_log_probs = [
                 {
@@ -1011,9 +1018,7 @@ def _tree_sampling_batch(
             new_draft_tokens_cuda.transpose(0, 1).to(torch.int, non_blocking=True).unsqueeze(dim=-1)
         )
 
-        new_draft_tokens_host = self._async_worker_submit(
-            int_new_draft_tokens.to, "cpu", non_blocking=True
-        )
+        new_draft_tokens_host = self._copy_to_host(int_new_draft_tokens)
 
         return new_draft_tokens_host
 
@@ -1130,16 +1135,10 @@ def update_requests(
         resource_manager: Optional[ResourceManager] = None,
     ) -> None:
         assert isinstance(state, SampleStateTorch)
-        if state.sampler_event:
-            state.sampler_event.synchronize()
+        self._sampler_event_synchronize(state.sampler_event)
 
         assert state.host is not None
-
-        if self._async_worker_active():
-            assert isinstance(state.host.new_tokens, futures.Future)
-            new_tokens = state.host.new_tokens.result()
-        else:
-            new_tokens = state.host.new_tokens
+        new_tokens = state.host.new_tokens
         finish_reasons = state.host.finish_reasons_list()
 
         new_tokens_list = new_tokens.tolist()
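For readers tracing why update_requests only needs the sampler event before reading new_tokens: an event recorded after a non-blocking D2H copy on the same stream fences that copy. A tiny illustration, not taken from the commit:

import torch

if torch.cuda.is_available():
    device_tokens = torch.randint(0, 32000, (8, 4), device="cuda")
    host_tokens = torch.empty_like(device_tokens, device="cpu", pin_memory=True)
    host_tokens.copy_(device_tokens, non_blocking=True)

    sampler_event = torch.cuda.Event()
    sampler_event.record()       # ordered after the copy on the current stream

    sampler_event.synchronize()  # returns only once the copy has completed
    tokens_list = host_tokens.tolist()  # now safe to consume on the CPU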
@@ -1219,8 +1218,7 @@ def sample_async(
         )
         finish_reasons_host = finish_reasons.to(device="cpu", non_blocking=True)
 
-        sampler_event = torch.cuda.Event()
-        sampler_event.record()
+        sampler_event = self._sampler_event_get()
         return SampleStateTorch(
             scheduled_requests=scheduled_requests,
             device=SampleStateTensors(new_tokens=new_tokens),
@@ -1480,7 +1478,7 @@ def _dims_canonically_ordered(t: torch.Tensor) -> bool:
         new_tokens_cuda.view(-1, *new_tokens_cuda.shape[2:])[:, beam, ...].scatter_(
             0, batch_dest_indices_1d_cuda, batch_next_tokens_cuda_int
         )
-        new_tokens_host = self._async_worker_submit(new_tokens_cuda.to, "cpu", non_blocking=True)
+        new_tokens_host = self._copy_to_host(new_tokens_cuda)
 
         return new_tokens_host
 
@@ -1840,39 +1838,23 @@ def _process_requests(
             topk_vals_cuda, topk_indices_cuda = torch.topk(
                 logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1
             )
-
-            def _copy_log_probs(
-                requests, req_num_steps, logprobs_req_indices, topk_vals_cuda, topk_indices_cuda
+            # Use a single D2H copy to reduce overheads
+            topk_vals = self._copy_to_host(topk_vals_cuda, pin_memory=True)
+            topk_indices = self._copy_to_host(topk_indices_cuda, pin_memory=True)
+            current_offset = 0
+            for req_id, steps in zip(
+                logprobs_req_indices, req_num_steps[logprobs_req_indices].tolist()
             ):
-                # Use a single D2H copy to reduce overheads
-                topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
-                topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
-                topk_vals.copy_(topk_vals_cuda, non_blocking=True)
-                topk_indices.copy_(topk_indices_cuda, non_blocking=True)
-                current_offset = 0
-                for req_id, steps in zip(
-                    logprobs_req_indices, req_num_steps[logprobs_req_indices].tolist()
-                ):
-                    req = requests[req_id]
-                    next_offset = current_offset + steps
-                    # NB: Assigning views on memory which is being filled
-                    # asynchronously
-                    req.py_topk_logprobs_vals = topk_vals[
-                        current_offset:next_offset, : req.py_num_logprobs
-                    ]
-                    req.py_topk_logprobs_indices = topk_indices[
-                        current_offset:next_offset, : req.py_num_logprobs
-                    ]
-                    current_offset = next_offset
-
-            self._async_worker_submit(
-                _copy_log_probs,
-                requests,
-                req_num_steps,
-                logprobs_req_indices,
-                topk_vals_cuda,
-                topk_indices_cuda,
-            )
+                req = requests[req_id]
+                next_offset = current_offset + steps
+                # NB: Assigning views on memory which is being filled asynchronously
+                req.py_topk_logprobs_vals = topk_vals[
+                    current_offset:next_offset, : req.py_num_logprobs
+                ]
+                req.py_topk_logprobs_indices = topk_indices[
+                    current_offset:next_offset, : req.py_num_logprobs
+                ]
+                current_offset = next_offset
 
             # Perform sampling in batches
             batched_sampling_result = self._sample_batched_by_strategy(
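The new code keeps the single bulk D2H copy and hands each request a view into the shared host buffer, advancing an offset by that request's step count. A standalone sketch of that bookkeeping with made-up inputs (steps_per_request and k are illustrative, not the repository's variables):

import torch

if torch.cuda.is_available():
    steps_per_request = [2, 3, 1]   # decode steps per request (hypothetical)
    k = 4                           # top-k logprobs kept per step (hypothetical)
    total_steps = sum(steps_per_request)

    topk_vals_cuda = torch.randn(total_steps, k, device="cuda")

    # One bulk copy instead of one small copy per request
    topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
    topk_vals.copy_(topk_vals_cuda, non_blocking=True)

    # Per-request results are views into the shared, still-filling buffer
    views, offset = [], 0
    for steps in steps_per_request:
        views.append(topk_vals[offset:offset + steps])
        offset += steps

    torch.cuda.current_stream().synchronize()  # make the views safe to read
    assert sum(v.shape[0] for v in views) == total_steps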
@@ -1934,9 +1916,7 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors):
 class SampleStateTRTLLM(SampleState):
     finalize_events: dict[str, CudaEvent] | None = None
     """`Optional` to accommodate `_forward_step_inter_pp` which creates a `SampleState` without `finalize_events`"""
-    host: Optional[SampleStateTensorsHostTRTLLM | futures.Future[SampleStateTensorsHostTRTLLM]] = (
-        None
-    )
+    host: Optional[SampleStateTensorsHostTRTLLM] = None
 
 
 class TRTLLMSampler(Sampler, AsyncWorkerMixin):
@@ -1959,7 +1939,7 @@ def __init__(
         max_beam_width: int,
         decoding_config: Optional[DecodingConfig] = None,
         kv_cache_config: Optional[KvCacheConfig] = None,
-        enable_async_worker: Optional[bool] = False,
+        enable_async_worker: bool = False,
     ):
         vocab_size = model.config.vocab_size
         num_hidden_layers = model.config.num_hidden_layers
@@ -2158,10 +2138,9 @@ def sample_async(
         )
 
         finalize_events = {}
-        gather_ids = False
-        decoder_state = self.store["decoder_state"]
+        gathered_ids = None
         if beam_width > 1:
-            finished_sum_device = decoder_state.finished_sum
+            finished_sum_device = self.store["decoder_state"].finished_sum
 
             for request in scheduled_requests.all_requests():
                 if request.is_context_init_state:
@@ -2170,41 +2149,31 @@ def sample_async(
                     finalize_events[request.request_id] = self._finalize_request(request, False)
                 elif request.streaming:
                     finalize_events[request.request_id] = self._finalize_request(request, True)
-                    gather_ids = True
-
-        device = SampleStateTensors(new_tokens=decoder_state.all_new_tokens)
-
-        def _copy_tensors_to_host(gather_ids, scheduled_requests, decoder_state):
-            gathered_ids = None
-            if gather_ids:
-                gathered_ids = decoder_state.gathered_ids.to("cpu", non_blocking=True)
-            new_output_tokens = decoder_state.all_new_tokens.to("cpu", non_blocking=True)
-            finished_sum = decoder_state.finished_sum.to("cpu", non_blocking=True)
-            finish_reasons = decoder_state.finish_reasons.to("cpu", non_blocking=True)
-            sequence_lengths = decoder_state.sequence_lengths.to("cpu", non_blocking=True)
-
-            log_probs = None
-            cum_log_probs = None
-            if any(request.py_return_log_probs for request in scheduled_requests.all_requests()):
-                log_probs = decoder_state.log_probs.to("cpu", non_blocking=True)
-                cum_log_probs = decoder_state.cum_log_probs.to("cpu", non_blocking=True)
-
-            return SampleStateTensorsHostTRTLLM(
-                new_tokens=new_output_tokens,
-                finished_sum=finished_sum,
-                finish_reasons=finish_reasons,
-                sequence_lengths=sequence_lengths,
-                log_probs=log_probs,
-                cum_log_probs=cum_log_probs,
-                gathered_ids=gathered_ids,
-            )
-
-        host = self._async_worker_submit(
-            _copy_tensors_to_host, gather_ids, scheduled_requests, decoder_state
+                    gathered_ids = self._copy_to_host(self.store["decoder_state"].gathered_ids)
+        new_output_tokens = self._copy_to_host(self.store["decoder_state"].all_new_tokens)
+        finished_sum = self._copy_to_host(self.store["decoder_state"].finished_sum)
+        finish_reasons = self._copy_to_host(self.store["decoder_state"].finish_reasons)
+        sequence_lengths = self._copy_to_host(self.store["decoder_state"].sequence_lengths)
+
+        log_probs = None
+        cum_log_probs = None
+        if any(request.py_return_log_probs for request in scheduled_requests.all_requests()):
+            log_probs = self._copy_to_host(self.store["decoder_state"].log_probs)
+            cum_log_probs = self._copy_to_host(self.store["decoder_state"].cum_log_probs)
+
+        device = SampleStateTensors(new_tokens=self.store["decoder_state"].all_new_tokens)
+
+        host = SampleStateTensorsHostTRTLLM(
+            new_tokens=new_output_tokens,
+            finished_sum=finished_sum,
+            finish_reasons=finish_reasons,
+            sequence_lengths=sequence_lengths,
+            log_probs=log_probs,
+            cum_log_probs=cum_log_probs,
+            gathered_ids=gathered_ids,
         )
 
-        sampler_event = torch.cuda.Event()
-        sampler_event.record()
+        sampler_event = self._sampler_event_get()
 
         self.micro_batch_idx = (self.micro_batch_idx + 1) % self.num_micro_batches
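The TRTLLM path above now issues one _copy_to_host call per tensor and obtains the sampler event last. That ordering is only meaningful if the async worker runs submissions in FIFO order, which a single-thread executor provides; the diff does not show the executor itself, so treat this as an assumption. A small demonstration of the FIFO property:

import time
from concurrent import futures

order = []
with futures.ThreadPoolExecutor(max_workers=1) as pool:
    for name in ("copy_new_tokens", "copy_finish_reasons", "record_event"):
        # Default argument binds the current name; tasks run strictly in
        # submission order because there is a single worker thread.
        pool.submit(lambda n=name: (time.sleep(0.01), order.append(n)))

assert order == ["copy_new_tokens", "copy_finish_reasons", "record_event"]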

@@ -2228,13 +2197,7 @@ def update_requests(
         if state.scheduled_requests.batch_size == 0:
             return
 
-        if state.sampler_event:
-            state.sampler_event.synchronize()
-
-        if self._async_worker_active():
-            # Wait for and "unpack" the host tensors
-            assert isinstance(state.host, futures.Future)
-            state.host = state.host.result()
+        self._sampler_event_synchronize(state.sampler_event)
 
         beam_width = self.beam_width(state.scheduled_requests.all_requests())
 