
Commit 889a4bd

netanel-haber authored and venkywonka committed
MTP and derivatives: Align sample state with trtllm sampler sample state (#5675)
This PR moves MTPSampler and derivatives to use the universal seq_slot indexing for sampling. This is the last piece of the puzzle: after this, all of the samplers will use this format. See: 6ee94c7

Signed-off-by: Netanel Haber <[email protected]>
1 parent 584d182 commit 889a4bd

File tree

5 files changed: +112 -107

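Before the per-file diffs, a short illustrative sketch (not part of the diff; the Request stand-in, helper names, and sizes are made up) of the seq_slot indexing convention the commit description refers to: sampled tokens are written at a request's persistent sequence slot rather than at its position in the current batch.

import torch

max_tokens, num_seq_slots, beam_width = 4, 8, 1
new_tokens = torch.zeros(max_tokens, num_seq_slots, beam_width, dtype=torch.int)

class Request:
    def __init__(self, seq_slot: int):
        self.seq_slot = seq_slot

def write_serial(batch, sampled):
    # Old MTP-style serial indexing: the row is the position in this batch,
    # so it changes whenever the batch composition changes.
    for batch_idx, _req in enumerate(batch):
        new_tokens[:, batch_idx, 0] = sampled[batch_idx]

def write_by_seq_slot(batch, sampled):
    # Universal seq_slot indexing: the row is the request's persistent slot,
    # stable across iterations and batch reorderings.
    for batch_idx, req in enumerate(batch):
        new_tokens[:, req.seq_slot, 0] = sampled[batch_idx]

batch = [Request(seq_slot=6), Request(seq_slot=1)]
sampled = torch.randint(0, 100, (len(batch), max_tokens), dtype=torch.int)
write_by_seq_slot(batch, sampled)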

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 6 additions & 0 deletions
@@ -25,6 +25,7 @@
 ExecutorRequest = tllm_executor.Request
 ExecutorResponse = tllm_executor.Response
 ExecutorSamplingConfig = tllm_executor.SamplingConfig
+FinishReason = tllm_executor.FinishReason

 REQUEST_TYPE_MAPPING = {
     tllm_executor.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION:
@@ -319,6 +320,11 @@ def create_response(
     def is_dummy(self):
         return self.is_attention_dp_dummy or self.is_cuda_graph_dummy or self.is_dummy_request

+    def finish_by(self, reason: FinishReason, beam: int) -> None:
+        """CPP finish by reason does not support beam_width > 1"""
+        self.state = LlmRequestState.GENERATION_COMPLETE
+        self.set_finished_reason(reason, beam)
+

 def convert_wordlist(word_list) -> List[List[int]]:
     """Converts a wordlist from format:

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 6 additions & 16 deletions
@@ -21,7 +21,6 @@
 import tqdm

 import tensorrt_llm.bindings.internal.userbuffers as ub
-from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
 from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
 from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
 from tensorrt_llm._utils import (is_trace_enabled, local_mpi_rank,
@@ -1174,16 +1173,6 @@ def _prepare_tp_inputs(
         draft_lens = []
         mrope_config = defaultdict(list)

-        mtp_batch_idx = 0  # Temporary: MTP (and Eagle3OneModel) remain the only samplers to index new_tokens serially
-
-        def py_batch_idx(request: LlmRequest) -> int:
-            if not self.without_logits:
-                return request.seq_slot
-            nonlocal mtp_batch_idx
-            batch_idx = mtp_batch_idx
-            mtp_batch_idx += 1
-            return batch_idx
-
         for request in scheduled_requests.context_requests:
             request_ids.append(request.py_request_id)
             all_prompt_tokens = request.get_tokens(0)
@@ -1213,7 +1202,7 @@ def py_batch_idx(request: LlmRequest) -> int:
                 ) if mrope_rotary_cos_sin.device == 'cpu' else mrope_rotary_cos_sin
                 mrope_config['mrope_rotary_cos_sin'].append(
                     mrope_rotary_cos_sin.to('cuda', non_blocking=True))
-            request.py_batch_idx = py_batch_idx(request)
+            request.py_batch_idx = request.seq_slot

         num_ctx_requests = len(scheduled_requests.context_requests)
         num_ctx_tokens = len(input_ids)
@@ -1295,11 +1284,11 @@ def py_batch_idx(request: LlmRequest) -> int:
                 num_cached_tokens_per_seq.append(past_seen_token_num)
                 request_ids.append(request.py_request_id)
                 # update batch index
-                request.py_batch_idx = py_batch_idx(request)
+                request.py_batch_idx = request.seq_slot
             else:
                 # update batch index
                 previous_batch_idx = request.py_batch_idx
-                request.py_batch_idx = py_batch_idx(request)
+                request.py_batch_idx = request.seq_slot
                 # inputs
                 # overlap scheduler can only support the speculative decoding
                 # methods with a fixed number of draft tokens
@@ -1350,7 +1339,7 @@ def py_batch_idx(request: LlmRequest) -> int:
             prompt_lengths.append(request.py_prompt_len)
             draft_lens.append(0)

-            request.py_batch_idx = py_batch_idx(request)
+            request.py_batch_idx = request.seq_slot

         previous_batch_len = len(previous_batch_indices)

@@ -1387,7 +1376,8 @@ def previous_seq_slots_device():
                 # previous input ids
                 previous_batch_tokens = previous_batch_len * (
                     1 + self.max_draft_len)
-                new_tokens = new_tokens_device[previous_slots, :].flatten()
+                new_tokens = new_tokens_device.transpose(
+                    0, 1)[previous_slots, :].flatten()
                 self.input_ids_cuda[num_tokens:num_tokens +
                                     previous_batch_tokens].copy_(
                                         new_tokens, non_blocking=True)
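A small standalone illustration (made-up sizes, hypothetical names, not code from this file) of why the gather above now transposes: the sampler's new_tokens buffer is laid out as (max_tokens, num_seq_slots, beam_width), so selecting a previous batch by its sequence slots requires the slot dimension to come first.

import torch

max_draft_len, num_seq_slots = 2, 8
# (1 + max_draft_len, num_seq_slots, beam_width=1), as in TorchSampler's store
new_tokens_device = torch.arange((1 + max_draft_len) * num_seq_slots,
                                 dtype=torch.int).reshape(1 + max_draft_len,
                                                          num_seq_slots, 1)
previous_slots = torch.tensor([5, 2])  # seq_slots of the previous batch

# Swap dims 0 and 1 so each selected row holds one request's (1 + max_draft_len) tokens.
gathered = new_tokens_device.transpose(0, 1)[previous_slots, :].flatten()
assert gathered.numel() == len(previous_slots) * (1 + max_draft_len)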

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 22 additions & 18 deletions
@@ -199,6 +199,10 @@ def add_token(request: LlmRequest,
     return new_token


+def int_tensor(shape: tuple[int, ...], device: str = 'cuda') -> torch.Tensor:
+    return torch.empty(shape, dtype=torch.int, device=device)
+
+
 class TorchSampler(Sampler):
     BEAM = 0
     MAX_BEAM_WIDTH = BEAM + 1
@@ -208,6 +212,9 @@ class Store:
         new_tokens: torch.Tensor
         """Shape: See cpp DecoderState.getAllNewTokens()"""

+    def create_store(self) -> Store:
+        return self.Store(new_tokens=int_tensor(self.NEW_TOKENS_SHAPE))
+
     @dataclass(frozen=True, kw_only=True)
     class Args:
         max_seq_len: int
@@ -223,18 +230,16 @@ def __init__(self, args: Args):
         assert args.max_beam_width == self.MAX_BEAM_WIDTH, "TorchSampler only supports beam_width = 1"
         self.num_seq_slots = args.max_num_sequences

+        self.NEW_TOKENS_SHAPE = (self.max_tokens, self.num_seq_slots,
+                                 self.MAX_BEAM_WIDTH)
         # AutoDeploy build creates the sampler in inference mode,
         # which would disallow in-place mutating of new_tokens.
         # So, we temporarily exit inference mode.
         with torch.inference_mode(False):
-            new_tokens = torch.zeros(
-                (self.max_tokens, self.num_seq_slots, self.MAX_BEAM_WIDTH),
-                dtype=torch.int,
-                device='cuda')
-            self.store = self.Store(new_tokens=new_tokens)
-
-    def _meet_max_token_stop_criteria(self, request: LlmRequest,
-                                      num_tokens: int):
+            self.store = self.create_store()
+
+    def _meet_max_token_stop_criteria(self, request: LlmRequest):
+        num_tokens = request.get_num_tokens(self.BEAM)
         return (num_tokens - request.py_orig_prompt_len
                 >= request.py_max_new_tokens) or (num_tokens
                                                   >= self.max_seq_len)
@@ -258,21 +263,20 @@ def _meet_stop_token_criteria(request: LlmRequest):
                     return True
         return False

-    def _handle_stop_criteria(self, request: LlmRequest, new_token: int, *,
-                              beam: int) -> bool:
+    def _handle_stop_criteria(self, request: LlmRequest,
+                              new_token: int) -> bool:
         """Handle stop criteria and set appropriate finish reasons and state.
         Returns True if generation should stop."""
         if new_token == request.py_end_id:
-            request.finish_by_reason(FinishReason.END_ID)
+            request.finish_by(FinishReason.END_ID, self.BEAM)
             return True

-        num_tokens = request.get_num_tokens(beam)
-        if self._meet_max_token_stop_criteria(request, num_tokens):
-            request.finish_by_reason(FinishReason.LENGTH)
+        if self._meet_max_token_stop_criteria(request):
+            request.finish_by(FinishReason.LENGTH, self.BEAM)
             return True

         if self._meet_stop_token_criteria(request):
-            request.finish_by_reason(FinishReason.STOP_WORDS)
+            request.finish_by(FinishReason.STOP_WORDS, self.BEAM)
             return True

         return False
@@ -307,7 +311,7 @@ def process_draft_tokens(self, request: LlmRequest,
                                   new_tokens,
                                   beam=self.BEAM,
                                   step=num_accepted)
-            if self._handle_stop_criteria(request, new_token, beam=self.BEAM):
+            if self._handle_stop_criteria(request, new_token):
                 break
         return num_accepted

@@ -321,15 +325,15 @@ def update_requests(self, state: SampleState) -> None:
             if req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0:
                 continue
             new_token = add_token(req, new_tokens, beam=self.BEAM)
-            stop = self._handle_stop_criteria(req, new_token, beam=self.BEAM)
+            self._handle_stop_criteria(req, new_token)
             self.handle_logits(req, state, beam=self.BEAM, count=1)
             req.py_decoding_iter += 1

         for req in state.scheduled_requests.generation_requests:
             if req.state == LlmRequestState.GENERATION_COMPLETE:
                 continue
             new_token = add_token(req, new_tokens, beam=self.BEAM)
-            stop = self._handle_stop_criteria(req, new_token, beam=self.BEAM)
+            stop = self._handle_stop_criteria(req, new_token)
             processed = 1
             if not stop and len(req.py_draft_tokens) > 0:
                 num_accepted = self.process_draft_tokens(
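The following standalone sketch (a simplified stand-in class; only int_tensor, Store, create_store, and NEW_TOKENS_SHAPE mirror names from the diff) shows the intent of the new store hook: __init__ fixes NEW_TOKENS_SHAPE once, and subclasses such as MTPSampler override create_store() to allocate additional seq_slot-indexed buffers with compatible shapes.

from dataclasses import dataclass

import torch

def int_tensor(shape, device='cpu'):  # the real helper defaults to device='cuda'
    return torch.empty(shape, dtype=torch.int, device=device)

class MiniSampler:
    MAX_BEAM_WIDTH = 1

    @dataclass(frozen=True, kw_only=True)
    class Store:
        new_tokens: torch.Tensor  # (max_tokens, num_seq_slots, beam_width)

    def __init__(self, max_tokens: int, num_seq_slots: int):
        self.NEW_TOKENS_SHAPE = (max_tokens, num_seq_slots, self.MAX_BEAM_WIDTH)
        self.store = self.create_store()

    def create_store(self) -> "MiniSampler.Store":
        # Override point: subclasses can allocate extra buffers keyed by seq_slot.
        return self.Store(new_tokens=int_tensor(self.NEW_TOKENS_SHAPE))

sampler = MiniSampler(max_tokens=3, num_seq_slots=8)
assert sampler.store.new_tokens.shape == (3, 8, 1)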

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 3 additions & 1 deletion
@@ -472,13 +472,15 @@ def draft_decoder(
             Draft token ids. Flattened.
         '''

-        draft_tokens = torch.argmax(logits, dim=-1).type(torch.int32)
+        draft_tokens = torch.argmax(logits, dim=-1)

         # Apply d2t (offsets between draft model dictionary and main model dictionary).
         if hasattr(draft_model.model,
                    "d2t") and draft_model.model.d2t is not None:
             draft_tokens = draft_model.model.d2t[draft_tokens] + draft_tokens

+        draft_tokens = draft_tokens.type(torch.int32)
+
         return draft_tokens

     def prepare_1st_drafter_inputs(
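A toy illustration (made-up vocabulary sizes, not Eagle3 code) of the reordered step above: the argmax result keeps its default int64 dtype through the d2t gather and is narrowed to int32 only once the final token ids exist, which appears to be the motivation for moving the cast.

import torch

logits = torch.randn(4, 16)                 # (num_tokens, draft_vocab_size)
d2t = torch.randint(0, 100, (16,))          # draft-vocab index -> offset into the main vocab

draft_tokens = torch.argmax(logits, dim=-1)      # int64, used directly to index d2t
draft_tokens = d2t[draft_tokens] + draft_tokens  # map draft ids into the main vocabulary
draft_tokens = draft_tokens.type(torch.int32)    # cast after the gather, as in the diff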

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 75 additions & 72 deletions
@@ -4,12 +4,11 @@
 import torch
 from torch import nn

-from tensorrt_llm.bindings.executor import FinishReason
-
 from ..attention_backend import AttentionMetadata
 from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
 from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
-from ..pyexecutor.sampler import SampleState, SampleStateTensors, TorchSampler
+from ..pyexecutor.sampler import (SampleState, SampleStateTensors, TorchSampler,
+                                  add_token, int_tensor)
 from ..pyexecutor.scheduler import ScheduledRequests
 from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode

@@ -249,92 +248,96 @@ class MTPSampler(TorchSampler):
     SampleState = SampleStateMTP

     def __init__(self, args: TorchSampler.Args, *, nextn: int):
-        super().__init__(args)
         self.mapping = None
         self.draft_len = nextn
+        super().__init__(args)

-    def _draft_meet_max_token_stop_criteria(self, request: LlmRequest,
-                                            num_tokens: int, beam_idx: int):
-        if self._meet_max_token_stop_criteria(request, num_tokens):
-            request.state = LlmRequestState.GENERATION_COMPLETE
-            request.set_finished_reason(FinishReason.LENGTH, beam_idx)
+    @dataclass(frozen=True, kw_only=True)
+    class Store(TorchSampler.Store):
+        next_new_tokens: torch.Tensor
+        next_draft_tokens: torch.Tensor
+        new_tokens_lens: torch.Tensor
+
+    def create_store(self) -> Store:
+        num_tokens, seq_slots, _ = self.NEW_TOKENS_SHAPE
+        draft_len = num_tokens - 1
+        assert draft_len == self.draft_len
+        return self.Store(
+            new_tokens=int_tensor(self.NEW_TOKENS_SHAPE),
+            next_new_tokens=int_tensor(self.NEW_TOKENS_SHAPE),
+            next_draft_tokens=int_tensor((seq_slots, draft_len)),
+            new_tokens_lens=int_tensor((seq_slots, )),
+        )
+
+    def _request_common_handling(self, request: LlmRequest,
+                                 next_draft_tokens: list[list[int]]):
+        assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler"
+        assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler"
+        assert not request.py_return_log_probs, "return_log_probs not implemented for MTPSampler"
+        request.py_draft_tokens = next_draft_tokens[request.seq_slot]
+        request.py_decoding_iter += 1

     def update_requests(self, state: SampleStateMTP) -> None:
         assert isinstance(state, SampleStateMTP)

         state.sampler_event.synchronize()
-        new_tokens_list = state.host.new_tokens.tolist()
-        new_tokens_lens_list = state.host.new_tokens_lens.tolist()
+        new_tokens = state.host.new_tokens
+        new_tokens_lens = state.host.new_tokens_lens
         next_draft_tokens_list = state.host.next_draft_tokens.tolist()
-
-        idx = 0
-        beam_idx = 0
-        for request in state.scheduled_requests.context_requests:
-            assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler"
-            assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler"
-            assert not request.py_return_log_probs, "return_log_probs not implemented for MTPSampler"
-            if request.context_remaining_length != 0:
-                idx += 1
+        beam_idx = self.BEAM
+        for req in state.scheduled_requests.context_requests:
+            if req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0:
                 continue
+            new_token = add_token(req, new_tokens, beam=beam_idx)
+            self._handle_stop_criteria(req, new_token)
+            self._request_common_handling(req, next_draft_tokens_list)

-            if request.state != LlmRequestState.GENERATION_COMPLETE:
-                new_token = new_tokens_list[idx][0]
-                num_tokens = request.add_new_token(new_token, beam_idx)
-                should_stop = self._handle_stop_criteria(request,
-                                                         new_token,
-                                                         beam=beam_idx)
-                if self._draft_meet_max_token_stop_criteria(
-                        request, num_tokens, beam_idx):
-                    should_stop = True
-                request.py_draft_tokens = next_draft_tokens_list[idx]
-                request.py_decoding_iter += 1
-            idx += 1
-
-        for request in state.scheduled_requests.generation_requests:
-            assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler"
-            assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler"
-            assert not request.py_return_log_probs, "return_log_probs not implemented for MTPSampler"
-            if request.state != LlmRequestState.GENERATION_COMPLETE:
-                new_tokens = new_tokens_list[idx]
-                num_new_tokens = new_tokens_lens_list[idx]
-                should_stop = False
-                for i in range(num_new_tokens):
-                    new_token = new_tokens[i]
-                    num_tokens = request.add_new_token(new_token, beam_idx)
-                    should_stop = self._handle_stop_criteria(request,
-                                                             new_token,
-                                                             beam=beam_idx)
-                    if should_stop:
-                        break
-                if self._draft_meet_max_token_stop_criteria(
-                        request, num_tokens, beam_idx):
-                    should_stop = True
-                request.py_draft_tokens = next_draft_tokens_list[idx]
-                request.py_rewind_len = self.draft_len - (num_new_tokens - 1)
-                request.py_decoding_iter += 1
-            idx += 1
+        for req in state.scheduled_requests.generation_requests:
+            if req.state == LlmRequestState.GENERATION_COMPLETE:
+                continue
+            num_new_tokens = new_tokens_lens[req.seq_slot]
+            for i in range(num_new_tokens):
+                new_token = add_token(req, new_tokens, beam=beam_idx, step=i)
+                if self._handle_stop_criteria(req, new_token):
+                    break
+            req.py_rewind_len = self.draft_len - (num_new_tokens - 1)
+            self._request_common_handling(req, next_draft_tokens_list)

     def sample_async(self, scheduled_requests: ScheduledRequests,
-                     model_outputs) -> SampleStateMTP:
-        # new_tokens_device: all of the accepted tokens, device tensor
-        # new_tokens_lens_device: the accepted lengths, device tensor
-        # next_draft_tokens_device: predicted draft tokens, device tensor
-        # next_new_tokens_device: input tokens for the next iteration, device tensor
-        new_tokens_device = model_outputs['new_tokens']
-        new_tokens_lens_device = model_outputs['new_tokens_lens']
-        next_draft_tokens_device = model_outputs['next_draft_tokens']
-        next_new_tokens_device = model_outputs['next_new_tokens']
+                     outputs: dict[str, torch.Tensor]) -> SampleStateMTP:
+        # new_tokens_device: accepted tokens, device tensor, shape: batch_size, nextn + 1
+        # new_tokens_lens_device: accepted lengths, device tensor, shape: batch_size
+        # next_draft_tokens_device: predicted draft tokens, device tensor, shape: batch_size, nextn
+        # next_new_tokens_device: input tokens for the next iteration, device tensor, shape: batch_size, nextn + 1
+
+        requests = scheduled_requests.all_requests()
+        slots = torch.as_tensor([r.seq_slot for r in requests])
+        slots = slots.to(device="cuda", non_blocking=True)
+
+        o_new_tokens = outputs['new_tokens'][:len(requests)]
+        o_new_tokens_lens = outputs['new_tokens_lens'][:len(requests)]
+        o_next_draft_tokens = outputs['next_draft_tokens'][:len(requests)]
+        o_next_new_tokens = outputs['next_new_tokens'][:len(requests)]
+
+        new_tokens = self.store.new_tokens
+        next_new_tokens = self.store.next_new_tokens
+        new_tokens_lens = self.store.new_tokens_lens
+        next_draft_tokens = self.store.next_draft_tokens
+
+        new_tokens.squeeze(-1).T.index_copy_(0, slots, o_new_tokens)
+        next_new_tokens.squeeze(-1).T.index_copy_(0, slots, o_next_new_tokens)
+        new_tokens_lens.index_copy_(0, slots, o_new_tokens_lens)
+        next_draft_tokens.index_copy_(0, slots, o_next_draft_tokens)

         device = SampleStateTensorsMTP(
-            new_tokens=next_new_tokens_device,
-            new_tokens_lens=new_tokens_lens_device,
-            next_draft_tokens=next_draft_tokens_device,
+            new_tokens=next_new_tokens,
+            new_tokens_lens=new_tokens_lens,
+            next_draft_tokens=next_draft_tokens,
         )
         host = SampleStateTensorsMTP(
-            new_tokens=new_tokens_device.to('cpu', non_blocking=True),
-            new_tokens_lens=new_tokens_lens_device.to('cpu', non_blocking=True),
-            next_draft_tokens=next_draft_tokens_device.to('cpu',
-                                                          non_blocking=True),
+            new_tokens=new_tokens.to('cpu', non_blocking=True),
+            new_tokens_lens=new_tokens_lens.to('cpu', non_blocking=True),
+            next_draft_tokens=next_draft_tokens.to('cpu', non_blocking=True),
         )
         sampler_event = torch.cuda.Event()
         sampler_event.record()
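To make the scatter in sample_async above concrete, a minimal CPU sketch with made-up sizes (the variable names here are illustrative): squeeze(-1).T views the (max_tokens, num_seq_slots, 1) store as (num_seq_slots, max_tokens), so index_copy_ writes one row per request at its seq_slot.

import torch

max_tokens, num_seq_slots = 3, 8  # max_tokens == nextn + 1
store_new_tokens = torch.zeros(max_tokens, num_seq_slots, 1, dtype=torch.int)

slots = torch.tensor([6, 1])  # seq_slots of the two scheduled requests
o_new_tokens = torch.tensor([[11, 12, 13],
                             [21, 22, 23]], dtype=torch.int)  # (batch_size, nextn + 1)

# Writes row 6 <- [11, 12, 13] and row 1 <- [21, 22, 23] of the slot-major view.
store_new_tokens.squeeze(-1).T.index_copy_(0, slots, o_new_tokens)
assert store_new_tokens[:, 6, 0].tolist() == [11, 12, 13]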
