Commit 8902514

[refactor] Adjust TRTLLMSampler to use sampling requests in sampling state
Signed-off-by: Robin Kobus <[email protected]>
1 parent db315e0 commit 8902514

1 file changed

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 28 additions & 32 deletions
@@ -1988,9 +1988,12 @@ def sample_async(
         if beam_width > 1:
             self._update_cache_indirection_buffer(scheduled_requests)

+        decoder_input_buffers = self.store["decoder_input_buffers"][self.micro_batch_idx]
+        decoder_state = self.store["decoder_state"]
+
         make_decoding_batch_input(
-            self.store["decoder_input_buffers"][self.micro_batch_idx],
-            self.store["decoder_state"],
+            decoder_input_buffers,
+            decoder_state,
             scheduled_requests.context_requests,
             scheduled_requests.generation_requests,
             model_outputs["logits"],
@@ -2000,35 +2003,40 @@ def sample_async(
         )

         self.algs.decoder.forward_async(
-            self.store["decoder_state"],
-            self.store["decoder_input_buffers"][self.micro_batch_idx],
+            decoder_state,
+            decoder_input_buffers,
         )

+        finished_context_requests = [
+            req for req in scheduled_requests.context_requests if req.is_last_context_chunk
+        ]
+        sampling_requests = finished_context_requests + scheduled_requests.generation_requests
+
         finalize_events = {}
         gathered_ids = None
         if beam_width > 1:
-            finished_sum_device = self.store["decoder_state"].finished_sum
+            finished_sum_device = decoder_state.finished_sum

-            for request in scheduled_requests.all_requests():
+            for request in sampling_requests:
                 if request.is_context_init_state:
                     continue
                 if finished_sum_device[request.seq_slot] == beam_width:
                     finalize_events[request.request_id] = self._finalize_request(request, False)
                 elif request.streaming:
                     finalize_events[request.request_id] = self._finalize_request(request, True)
-            gathered_ids = self.store["decoder_state"].gathered_ids.to("cpu", non_blocking=True)
-        new_output_tokens = self.store["decoder_state"].all_new_tokens.to("cpu", non_blocking=True)
-        finished_sum = self.store["decoder_state"].finished_sum.to("cpu", non_blocking=True)
-        finish_reasons = self.store["decoder_state"].finish_reasons.to("cpu", non_blocking=True)
-        sequence_lengths = self.store["decoder_state"].sequence_lengths.to("cpu", non_blocking=True)
+            gathered_ids = decoder_state.gathered_ids.to("cpu", non_blocking=True)
+        new_output_tokens = decoder_state.all_new_tokens.to("cpu", non_blocking=True)
+        finished_sum = decoder_state.finished_sum.to("cpu", non_blocking=True)
+        finish_reasons = decoder_state.finish_reasons.to("cpu", non_blocking=True)
+        sequence_lengths = decoder_state.sequence_lengths.to("cpu", non_blocking=True)

         log_probs = None
         cum_log_probs = None
-        if any(request.py_return_log_probs for request in scheduled_requests.all_requests()):
-            log_probs = self.store["decoder_state"].log_probs.to("cpu", non_blocking=True)
-            cum_log_probs = self.store["decoder_state"].cum_log_probs.to("cpu", non_blocking=True)
+        if any(request.py_return_log_probs for request in sampling_requests):
+            log_probs = decoder_state.log_probs.to("cpu", non_blocking=True)
+            cum_log_probs = decoder_state.cum_log_probs.to("cpu", non_blocking=True)

-        device = SampleStateTensors(new_tokens=self.store["decoder_state"].all_new_tokens)
+        device = SampleStateTensors(new_tokens=decoder_state.all_new_tokens)

         host = SampleStateTensorsHostTRTLLM(
             new_tokens=new_output_tokens,
@@ -2046,7 +2054,7 @@ def sample_async(
         self.micro_batch_idx = (self.micro_batch_idx + 1) % self.num_micro_batches

         return SampleStateTRTLLM(
-            scheduled_requests=scheduled_requests,
+            requests=sampling_requests,
             device=device,
             host=host,
             sampler_event=sampler_event,
@@ -2062,13 +2070,13 @@ def update_requests(
     ):
         # resource_manager will not be used in this function, just for interface consistency.
         assert isinstance(state, SampleStateTRTLLM)
-        if state.scheduled_requests.batch_size == 0:
+        if len(state.requests) == 0:
             return

         if state.sampler_event:
             state.sampler_event.synchronize()

-        beam_width = self.beam_width(state.scheduled_requests.all_requests())
+        beam_width = self.beam_width(state.requests)

         if beam_width == 1 and self.MAX_DECODING_TOKENS == 1:
             self.update_requests_single_beam_single_step(state)
@@ -2087,13 +2095,7 @@ def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
             state.host.cum_log_probs.tolist() if state.host.cum_log_probs is not None else None
         )

-        reqs = [
-            r for r in state.scheduled_requests.context_requests if not r.is_context_init_state
-        ] + [
-            r
-            for r in state.scheduled_requests.generation_requests
-            if not r.is_generation_complete_state
-        ]
+        reqs = [r for r in state.requests if not r.is_generation_complete_state]

         reqs_with_new_tokens = [
             r for r in reqs if (sequence_lengths_host_data[r.py_seq_slot] > r.get_num_tokens(0))
@@ -2148,13 +2150,7 @@ def update_requests_multiple_beams_or_drafting(
         log_probs_host = state.host.log_probs.tolist() if state.host.log_probs is not None else None
         finalize_events = state.finalize_events

-        reqs = [
-            r for r in state.scheduled_requests.context_requests if not r.is_context_init_state
-        ] + [
-            r
-            for r in state.scheduled_requests.generation_requests
-            if not r.is_generation_complete_state
-        ]
+        reqs = [r for r in state.requests if not r.is_generation_complete_state]

         for request in reqs:
             seq_slot = request.py_seq_slot
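
The core of the change: instead of carrying the full ScheduledRequests batch inside the sample state, sample_async derives a flat list of sampling requests (context requests that just finished their last context chunk, plus all generation requests) and stores that list in SampleStateTRTLLM, so update_requests and its helpers filter a single list. The snippet below is a minimal, self-contained sketch of that pattern; the class and attribute names (Request, ScheduledBatch, SampleState, active_requests) are simplified stand-ins invented for illustration, not the real TensorRT-LLM types.

# Minimal sketch of the refactor pattern (hypothetical stand-in classes):
# derive a flat "sampling requests" list once and carry it in the sample state
# instead of the whole scheduled batch.
from dataclasses import dataclass, field
from typing import List


@dataclass
class Request:
    # Assumed fields, for illustration only.
    request_id: int
    is_last_context_chunk: bool = True
    is_generation_complete_state: bool = False


@dataclass
class ScheduledBatch:
    # Stand-in for the scheduler's context/generation split.
    context_requests: List[Request] = field(default_factory=list)
    generation_requests: List[Request] = field(default_factory=list)


@dataclass
class SampleState:
    # After the refactor, the state holds the flat request list directly.
    requests: List[Request]


def build_sample_state(scheduled: ScheduledBatch) -> SampleState:
    # Only context requests that just processed their last chunk produce a
    # sampled token, so only those join the generation requests in the state.
    finished_context = [r for r in scheduled.context_requests if r.is_last_context_chunk]
    return SampleState(requests=finished_context + scheduled.generation_requests)


def active_requests(state: SampleState) -> List[Request]:
    # One comprehension replaces the old two-part filter over the context and
    # generation request lists.
    if len(state.requests) == 0:
        return []
    return [r for r in state.requests if not r.is_generation_complete_state]


if __name__ == "__main__":
    batch = ScheduledBatch(
        context_requests=[Request(1, is_last_context_chunk=False), Request(2)],
        generation_requests=[Request(3), Request(4, is_generation_complete_state=True)],
    )
    state = build_sample_state(batch)
    print([r.request_id for r in active_requests(state)])  # -> [2, 3]

Caching decoder_state and decoder_input_buffers in local variables, as the diff also does, only shortens the repeated self.store[...] lookups and does not change behavior.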
