
Commit b5e195d

[TRTLLM-6756][chore] Enhance TorchSampler with new setup_sampler_step method and fix bugs
- Introduced a setup_sampler_step method to enable the setup process for disaggregated serving in beam search.
- Updated cache indirection initialization to use torch.zeros, preventing reads of invalid values from cache_indirection.
- Updated the MTPSampler to correctly call TorchSampler functions.
- Fixed _handle_finish_reasons by wrapping finish reasons in the FinishReason class.
- Adjusted the max_lengths_tensor calculation to account for the original prompt length.

Signed-off-by: Stefan Niebler <[email protected]>
1 parent 7cc7260 commit b5e195d

3 files changed: +49 -22 lines changed


tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 32 additions & 13 deletions
@@ -639,7 +639,7 @@ def finish_reasons_list(self) -> FinishReasonsList:
 @dataclass(kw_only=True)
 class SampleStateTorch(SampleState):
     host: SampleStateTensorsHostTorch
-    beam_histories: list[BeamHistory | None]
+    beam_histories: list[BeamHistory | None] | None = None


 class TorchSampler(Sampler):
@@ -691,7 +691,9 @@ def create_store(self) -> Store:
         return self.Store(
             new_tokens=int_tensor(self.NEW_TOKENS_SHAPE),
             finish_reasons=int_tensor(self.NEW_TOKENS_SHAPE),
-            cache_indirection=int_tensor(self.CACHE_INDIRECTION_SHAPE),
+            cache_indirection=torch.zeros(
+                self.CACHE_INDIRECTION_SHAPE, device="cuda", dtype=torch.int
+            ),
             cache_indirection_buffer=int_tensor(self.CACHE_INDIRECTION_SHAPE),
             cum_log_probs=torch.zeros(
                 self.CACHE_INDIRECTION_SHAPE[:-1], device="cuda", dtype=torch.float32
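
The cache_indirection change above swaps an uninitialized integer buffer for a zero-filled one, so code that reads an indirection entry before it has been written sees a valid beam index (0) instead of arbitrary memory. Below is a minimal, self-contained sketch of the difference; the int_tensor helper here is a hypothetical stand-in assumed to allocate uninitialized storage, and the shape is illustrative rather than the real CACHE_INDIRECTION_SHAPE.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical stand-in for the repo's int_tensor helper: uninitialized storage.
def int_tensor(shape):
    return torch.empty(shape, dtype=torch.int, device=device)

shape = (4, 2, 8)  # illustrative (seq_slots, beams, seq_len), not the real constant

uninitialized = int_tensor(shape)                            # contents are arbitrary
zeroed = torch.zeros(shape, dtype=torch.int, device=device)  # contents are defined

# Reading an entry that was never written is only well-defined for the zeroed buffer:
print(zeroed[0, :, 0].tolist())         # always [0, 0]
print(uninitialized[0, :, 0].tolist())  # whatever happened to be in that memory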
@@ -718,7 +720,7 @@ class Args:
         max_num_sequences: int
         max_beam_width: int
         max_total_draft_tokens: int
-        disable_overlap_scheduler: bool
+        disable_overlap_scheduler: bool = False
         disable_flash_infer_sampling: bool = False

     def __init__(self, args: Args):
@@ -873,7 +875,10 @@ def _handle_finish_reasons(
             request.state = LlmRequestState.GENERATION_COMPLETE
             for beam_idx in range(request.sampling_config.beam_width):
                 request.set_finished_reason(
-                    finish_reasons_list[request.py_seq_slot][DEFAULT_STEP_IDX][beam_idx], beam_idx
+                    FinishReason(
+                        finish_reasons_list[request.py_seq_slot][DEFAULT_STEP_IDX][beam_idx]
+                    ),
+                    beam_idx,
                 )
             return True
         return False
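
The fix above wraps the raw integer read out of the finish-reasons buffer in the FinishReason enum before handing it to request.set_finished_reason, which expects the enum rather than a bare int. A small standalone sketch of the pattern follows; the enum and setter are hypothetical stand-ins with illustrative values, not the repo's actual bindings.

from enum import IntEnum

# Hypothetical stand-in for the FinishReason binding; member values are illustrative.
class FinishReason(IntEnum):
    NOT_FINISHED = 0
    END_ID = 1
    STOP_WORDS = 2
    LENGTH = 3

# Hypothetical setter that, like the real one, expects the enum rather than an int.
def set_finished_reason(reason: FinishReason, beam_idx: int) -> None:
    assert isinstance(reason, FinishReason)
    print(f"beam {beam_idx} finished with {reason.name}")

raw = 3  # plain int as read from the tensor-backed finish_reasons list
set_finished_reason(FinishReason(raw), beam_idx=0)  # OK: wrapped in the enum first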
@@ -1069,7 +1074,10 @@ def _process_draft_tokens_tree(
         for idx in eagle_paths[longest_match_path_idx][:longest_accepted_len]:
             add_token(request, new_tokens_list, beam_idx=self.DEFAULT_BEAM_IDX, step=cast(int, idx.item()))
             num_accepted_draft_tokens += 1
-            if self.finish_if_reason(request, finish_reasons, step=num_accepted_draft_tokens):
+            if self.finish_if_reason(request,
+                                     finish_reasons,
+                                     step=num_accepted_draft_tokens,
+                                     beam_idx=DEFAULT_BEAM_IDX,):
                 break

         assert num_accepted_draft_tokens <= longest_accepted_len
@@ -1080,6 +1088,15 @@ def _process_draft_tokens_tree(
         return num_accepted_draft_tokens - 1


+    def setup_sampler_step(self, requests: ScheduledRequests):
+        """Setup the sampler step for the requests
+
+        Args:
+            requests: list[LlmRequest]. The requests to setup the sampler step for
+        """
+        if self._use_beam_search:
+            self._prepare_beam_search(requests)
+
     def _prepare_beam_search(
         self,
         requests: list[LlmRequest],
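
setup_sampler_step gives callers an explicit hook to run per-request preparation (currently beam-search buffer resets) before a sampling step. The sketch below shows how a caller might use it; the surrounding function and variable names are hypothetical, since the executor wiring is not part of this diff.

# Hypothetical caller (not the actual py_executor code): invoke the new hook once
# per iteration, before sampling, so that beam-search state such as
# cache_indirection and cum_log_probs is (re)initialized. With the widened
# condition in _prepare_beam_search below, this also covers disaggregated
# serving, where a generation request arrives with its KV cache transmitted
# instead of being produced by a local context chunk.
def run_sampler_iteration(sampler, scheduled_requests):
    sampler.setup_sampler_step(scheduled_requests)
    # ... sampling and update_requests would follow here ...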
@@ -1090,12 +1107,11 @@ def _prepare_beam_search(
         initialize/reset the buffers for the request
         """
         for request in requests:
-            if (
-                not request.is_finished
-                and request.is_context_init_state
-                and request.is_last_context_chunk
+            if not request.is_finished and (
+                (request.is_context_init_state and request.is_last_context_chunk)
+                or request.is_disagg_generation_transmission_complete
             ):
-                if request.py_num_logprobs > 1:
+                if request.py_return_log_probs and request.py_num_logprobs > 1:
                     raise ValueError("Beam search does not support multiple logprobs")
                 self.store.cache_indirection[request.py_seq_slot, :, request.py_prompt_len].fill_(0)
                 self.store.cum_log_probs[request.py_seq_slot].fill_(0)
@@ -1559,7 +1575,7 @@ def update_requests(
                 or req.context_remaining_length != 0
             ):
                 continue
-            if beam_histories[req_idx] is not None:
+            if beam_histories is not None and beam_histories[req_idx] is not None:
                 self._finalize_beam(
                     req,
                     beam_histories[req_idx],
@@ -1579,7 +1595,7 @@ def update_requests(
             if req.state == LlmRequestState.GENERATION_COMPLETE:
                 continue
             if req.sampling_config.beam_width > 1:
-                if beam_histories[req_idx] is not None:
+                if beam_histories is not None and beam_histories[req_idx] is not None:
                     self._finalize_beam(
                         req,
                         beam_histories[req_idx],
@@ -2206,7 +2222,10 @@ def _are_max_length(self, requests: list[LlmRequest]) -> torch.Tensor:
         )
         max_lengths_tensor = torch.tensor(
             [
-                ([min(req.py_max_new_tokens, self.max_seq_len)] * self.max_beam_width)
+                (
+                    [min(req.py_max_new_tokens, self.max_seq_len - req.orig_prompt_len)]
+                    * self.max_beam_width
+                )
                 for req in requests
             ]
             * self.max_tokens
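
The change to max_lengths_tensor subtracts the original prompt length from the sequence-length cap, so the per-request bound counts only generated tokens rather than the whole sequence. A quick worked example with illustrative numbers:

# Illustrative numbers only.
max_seq_len = 16
py_max_new_tokens = 64
orig_prompt_len = 10

old_bound = min(py_max_new_tokens, max_seq_len)                    # 16 -> prompt + 16 = 26 tokens, overshoots max_seq_len
new_bound = min(py_max_new_tokens, max_seq_len - orig_prompt_len)  # 6  -> prompt + 6 = 16 tokens, exactly max_seq_len

print(old_bound, new_bound)  # 16 6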

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 15 additions & 7 deletions
@@ -236,11 +236,12 @@ def __init__(self, args: TorchSampler.Args, *, nextn: int):

         seq_slots = args.max_num_sequences
         max_tokens = args.max_total_draft_tokens + 1
-        max_beam_width = args.max_beam_width
+        self.max_beam_width = args.max_beam_width

         self.store = self.Store(
-            new_tokens=int_tensor((max_tokens, seq_slots, max_beam_width)),
-            next_new_tokens=int_tensor((max_tokens, seq_slots, max_beam_width)),
+            new_tokens=int_tensor((max_tokens, seq_slots, self.max_beam_width)),
+            next_new_tokens=int_tensor(
+                (max_tokens, seq_slots, self.max_beam_width)),
             next_draft_tokens=int_tensor(
                 (seq_slots, args.max_total_draft_tokens)),
             new_tokens_lens=int_tensor((seq_slots, )),
@@ -271,20 +272,27 @@ def update_requests(
         for req in state.scheduled_requests.context_requests:
             if req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0:
                 continue
-            new_token = add_token(req, new_tokens, beam=beam_idx)
+            new_token = add_token(req, new_tokens, beam_idx=beam_idx)
             TorchSampler._handle_stop_criteria(req,
                                                new_token,
-                                               max_seq_len=self.max_seq_len)
+                                               max_seq_len=self.max_seq_len,
+                                               beam_idx=beam_idx)
             self._request_common_handling(req, next_draft_tokens_list)

         for req in state.scheduled_requests.generation_requests:
             if req.state == LlmRequestState.GENERATION_COMPLETE:
                 continue
             num_new_tokens = new_tokens_lens_list[req.py_seq_slot]
             for i in range(num_new_tokens):
-                new_token = add_token(req, new_tokens, beam=beam_idx, step=i)
+                new_token = add_token(req,
+                                      new_tokens,
+                                      beam_idx=beam_idx,
+                                      step=i)
                 if TorchSampler._handle_stop_criteria(
-                        req, new_token, max_seq_len=self.max_seq_len):
+                        req,
+                        new_token,
+                        max_seq_len=self.max_seq_len,
+                        beam_idx=beam_idx):
                     break
             req.py_num_accepted_draft_tokens = num_new_tokens - 1
             req.py_rewind_len = self.draft_len - req.py_num_accepted_draft_tokens

tests/unittest/_torch/speculative/test_draft_token_tree_verification.py

Lines changed: 2 additions & 2 deletions
@@ -47,8 +47,8 @@ def run_test(eagle_model_dir, max_seq_len, beam_width, use_dynamic_tree,
     ))
     # fill with NOT_FINISHED to ensure that all finish reasons are NOT_FINISHED
     torch_sampler.store.finish_reasons.fill_(FinishReason.NOT_FINISHED.value)
-    finish_reasons_list = torch_sampler.store.finish_reasons[..., 0].to(
-        device="cpu").T.tolist()
+    finish_reasons_list = torch_sampler.store.finish_reasons.to(
+        device="cpu").permute(1, 0, 2).tolist()
     input_new_tokens_list = input_new_tokens.tolist()
     num_accepted_draft_tokens = torch_sampler._process_draft_tokens_tree(
         request=input_request,
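
The test helper previously kept only beam 0 and transposed to a [slot][step] list; it now keeps every beam and permutes to [slot][step][beam], matching how _handle_finish_reasons indexes finish_reasons_list[seq_slot][step][beam]. A small sketch of the reordering, assuming the buffer is laid out as (steps, seq_slots, beams) like the new_tokens tensors elsewhere in this diff:

import torch

steps, seq_slots, beams = 3, 4, 2  # illustrative sizes
finish_reasons = torch.arange(steps * seq_slots * beams).reshape(steps, seq_slots, beams)

old_list = finish_reasons[..., 0].T.tolist()         # beam 0 only, shape (seq_slots, steps)
new_list = finish_reasons.permute(1, 0, 2).tolist()  # all beams, shape (seq_slots, steps, beams)

print(len(new_list), len(new_list[0]), len(new_list[0][0]))  # 4 3 2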
