4 | 4 | import torch |
5 | 5 | from torch import nn |
6 | 6 |
7 | | -from tensorrt_llm.bindings.executor import FinishReason |
8 | | - |
9 | 7 | from ..attention_backend import AttentionMetadata |
10 | 8 | from ..pyexecutor.llm_request import LlmRequest, LlmRequestState |
11 | 9 | from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager |
12 | | -from ..pyexecutor.sampler import SampleState, SampleStateTensors, TorchSampler |
| 10 | +from ..pyexecutor.sampler import (SampleState, SampleStateTensors, TorchSampler, |
| 11 | + add_token, int_tensor) |
13 | 12 | from ..pyexecutor.scheduler import ScheduledRequests |
14 | 13 | from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode |
15 | 14 |
@@ -249,92 +248,96 @@ class MTPSampler(TorchSampler): |
249 | 248 | SampleState = SampleStateMTP |
250 | 249 |
251 | 250 | def __init__(self, args: TorchSampler.Args, *, nextn: int): |
252 | | - super().__init__(args) |
253 | 251 | self.mapping = None |
254 | 252 | self.draft_len = nextn |
| 253 | + super().__init__(args) |
255 | 254 |
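Note on the reordered constructor: the `super().__init__(args)` call now runs after `self.draft_len = nextn` because, presumably, the base `TorchSampler` constructor builds the token store via `create_store()`, and the override below reads `self.draft_len`. A comment-only sketch of the assumed ordering constraint (base-class behavior is not shown in this diff):

```python
# Assumption: TorchSampler.__init__ calls self.create_store(), and the MTP
# override of create_store() asserts against self.draft_len.
#
#   super().__init__(args)   # would reach create_store() -> AttributeError
#   self.draft_len = nextn   # ...because draft_len is not set yet
#
# Hence draft_len is assigned first, then the base constructor runs.
```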
256 | | - def _draft_meet_max_token_stop_criteria(self, request: LlmRequest, |
257 | | - num_tokens: int, beam_idx: int): |
258 | | - if self._meet_max_token_stop_criteria(request, num_tokens): |
259 | | - request.state = LlmRequestState.GENERATION_COMPLETE |
260 | | - request.set_finished_reason(FinishReason.LENGTH, beam_idx) |
| 255 | + @dataclass(frozen=True, kw_only=True) |
| 256 | + class Store(TorchSampler.Store): |
| 257 | + next_new_tokens: torch.Tensor |
| 258 | + next_draft_tokens: torch.Tensor |
| 259 | + new_tokens_lens: torch.Tensor |
| 260 | + |
| 261 | + def create_store(self) -> Store: |
| 262 | + num_tokens, seq_slots, _ = self.NEW_TOKENS_SHAPE |
| 263 | + draft_len = num_tokens - 1 |
| 264 | + assert draft_len == self.draft_len |
| 265 | + return self.Store( |
| 266 | + new_tokens=int_tensor(self.NEW_TOKENS_SHAPE), |
| 267 | + next_new_tokens=int_tensor(self.NEW_TOKENS_SHAPE), |
| 268 | + next_draft_tokens=int_tensor((seq_slots, draft_len)), |
| 269 | + new_tokens_lens=int_tensor((seq_slots, )), |
| 270 | + ) |
| 271 | + |
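For reference, a minimal sketch of the preallocated store layout, assuming `NEW_TOKENS_SHAPE` is `(max_draft_len + 1, num_seq_slots, beam_width)` as implied by the unpacking in `create_store()`; `int_tensor` from `..pyexecutor.sampler` is stood in for by a plain `torch.zeros` here:

```python
import torch

def int_tensor(shape):
    # Stand-in for the helper imported from ..pyexecutor.sampler.
    return torch.zeros(shape, dtype=torch.int32)

nextn, seq_slots, beams = 3, 8, 1                     # example sizes
NEW_TOKENS_SHAPE = (nextn + 1, seq_slots, beams)      # assumed layout

new_tokens        = int_tensor(NEW_TOKENS_SHAPE)      # accepted tokens per step/slot/beam
next_new_tokens   = int_tensor(NEW_TOKENS_SHAPE)      # input tokens for the next iteration
next_draft_tokens = int_tensor((seq_slots, nextn))    # draft tokens proposed per slot
new_tokens_lens   = int_tensor((seq_slots,))          # accepted-token count per slot
```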
| 272 | + def _request_common_handling(self, request: LlmRequest, |
| 273 | + next_draft_tokens: list[list[int]]): |
| 274 | + assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler" |
| 275 | + assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler" |
| 276 | + assert not request.py_return_log_probs, "return_log_probs not implemented for MTPSampler" |
| 277 | + request.py_draft_tokens = next_draft_tokens[request.seq_slot] |
| 278 | + request.py_decoding_iter += 1 |
261 | 279 |
262 | 280 | def update_requests(self, state: SampleStateMTP) -> None: |
263 | 281 | assert isinstance(state, SampleStateMTP) |
264 | 282 |
265 | 283 | state.sampler_event.synchronize() |
266 | | - new_tokens_list = state.host.new_tokens.tolist() |
267 | | - new_tokens_lens_list = state.host.new_tokens_lens.tolist() |
| 284 | + new_tokens = state.host.new_tokens |
| 285 | + new_tokens_lens = state.host.new_tokens_lens |
268 | 286 | next_draft_tokens_list = state.host.next_draft_tokens.tolist() |
269 | | - |
270 | | - idx = 0 |
271 | | - beam_idx = 0 |
272 | | - for request in state.scheduled_requests.context_requests: |
273 | | - assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler" |
274 | | - assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler" |
275 | | - assert not request.py_return_log_probs, "return_log_probs not implemented for MTPSampler" |
276 | | - if request.context_remaining_length != 0: |
277 | | - idx += 1 |
| 287 | + beam_idx = self.BEAM |
| 288 | + for req in state.scheduled_requests.context_requests: |
| 289 | + if req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0: |
278 | 290 | continue |
| 291 | + new_token = add_token(req, new_tokens, beam=beam_idx) |
| 292 | + self._handle_stop_criteria(req, new_token) |
| 293 | + self._request_common_handling(req, next_draft_tokens_list) |
279 | 294 |
280 | | - if request.state != LlmRequestState.GENERATION_COMPLETE: |
281 | | - new_token = new_tokens_list[idx][0] |
282 | | - num_tokens = request.add_new_token(new_token, beam_idx) |
283 | | - should_stop = self._handle_stop_criteria(request, |
284 | | - new_token, |
285 | | - beam=beam_idx) |
286 | | - if self._draft_meet_max_token_stop_criteria( |
287 | | - request, num_tokens, beam_idx): |
288 | | - should_stop = True |
289 | | - request.py_draft_tokens = next_draft_tokens_list[idx] |
290 | | - request.py_decoding_iter += 1 |
291 | | - idx += 1 |
292 | | - |
293 | | - for request in state.scheduled_requests.generation_requests: |
294 | | - assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler" |
295 | | - assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler" |
296 | | - assert not request.py_return_log_probs, "return_log_probs not implemented for MTPSampler" |
297 | | - if request.state != LlmRequestState.GENERATION_COMPLETE: |
298 | | - new_tokens = new_tokens_list[idx] |
299 | | - num_new_tokens = new_tokens_lens_list[idx] |
300 | | - should_stop = False |
301 | | - for i in range(num_new_tokens): |
302 | | - new_token = new_tokens[i] |
303 | | - num_tokens = request.add_new_token(new_token, beam_idx) |
304 | | - should_stop = self._handle_stop_criteria(request, |
305 | | - new_token, |
306 | | - beam=beam_idx) |
307 | | - if should_stop: |
308 | | - break |
309 | | - if self._draft_meet_max_token_stop_criteria( |
310 | | - request, num_tokens, beam_idx): |
311 | | - should_stop = True |
312 | | - request.py_draft_tokens = next_draft_tokens_list[idx] |
313 | | - request.py_rewind_len = self.draft_len - (num_new_tokens - 1) |
314 | | - request.py_decoding_iter += 1 |
315 | | - idx += 1 |
| 295 | + for req in state.scheduled_requests.generation_requests: |
| 296 | + if req.state == LlmRequestState.GENERATION_COMPLETE: |
| 297 | + continue |
| 298 | + num_new_tokens = new_tokens_lens[req.seq_slot] |
| 299 | + for i in range(num_new_tokens): |
| 300 | + new_token = add_token(req, new_tokens, beam=beam_idx, step=i) |
| 301 | + if self._handle_stop_criteria(req, new_token): |
| 302 | + break |
| 303 | + req.py_rewind_len = self.draft_len - (num_new_tokens - 1) |
| 304 | + self._request_common_handling(req, next_draft_tokens_list) |
316 | 305 |
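The generation loop above accepts up to `new_tokens_lens[seq_slot]` tokens per request, breaks early if a stop criterion fires, and records how many speculatively appended draft positions must be rolled back. A worked example of the rewind arithmetic (semantics assumed: `num_new_tokens` counts one freshly sampled token plus the accepted drafts):

```python
draft_len = 3        # nextn draft tokens were proposed last step
num_new_tokens = 2   # 1 new token + 1 accepted draft token
py_rewind_len = draft_len - (num_new_tokens - 1)
assert py_rewind_len == 2   # two rejected draft positions to rewind
```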
317 | 306 | def sample_async(self, scheduled_requests: ScheduledRequests, |
318 | | - model_outputs) -> SampleStateMTP: |
319 | | - # new_tokens_device: all of the accepted tokens, device tensor |
320 | | - # new_tokens_lens_device: the accepted lengths, device tensor |
321 | | - # next_draft_tokens_device: predicted draft tokens, device tensor |
322 | | - # next_new_tokens_device: input tokens for the next iteration, device tensor |
323 | | - new_tokens_device = model_outputs['new_tokens'] |
324 | | - new_tokens_lens_device = model_outputs['new_tokens_lens'] |
325 | | - next_draft_tokens_device = model_outputs['next_draft_tokens'] |
326 | | - next_new_tokens_device = model_outputs['next_new_tokens'] |
| 307 | + outputs: dict[str, torch.Tensor]) -> SampleStateMTP: |
| 308 | + # new_tokens_device: accepted tokens, device tensor, shape: batch_size, nextn + 1 |
| 309 | + # new_tokens_lens_device: accepted lengths, device tensor, shape: batch_size |
| 310 | + # next_draft_tokens_device: predicted draft tokens, device tensor, shape: batch_size, nextn |
| 311 | + # next_new_tokens_device: input tokens for the next iteration, device tensor, shape: batch_size, nextn + 1 |
| 312 | + |
| 313 | + requests = scheduled_requests.all_requests() |
| 314 | + slots = torch.as_tensor([r.seq_slot for r in requests]) |
| 315 | + slots = slots.to(device="cuda", non_blocking=True) |
| 316 | + |
| 317 | + o_new_tokens = outputs['new_tokens'][:len(requests)] |
| 318 | + o_new_tokens_lens = outputs['new_tokens_lens'][:len(requests)] |
| 319 | + o_next_draft_tokens = outputs['next_draft_tokens'][:len(requests)] |
| 320 | + o_next_new_tokens = outputs['next_new_tokens'][:len(requests)] |
| 321 | + |
| 322 | + new_tokens = self.store.new_tokens |
| 323 | + next_new_tokens = self.store.next_new_tokens |
| 324 | + new_tokens_lens = self.store.new_tokens_lens |
| 325 | + next_draft_tokens = self.store.next_draft_tokens |
| 326 | + |
| 327 | + new_tokens.squeeze(-1).T.index_copy_(0, slots, o_new_tokens) |
| 328 | + next_new_tokens.squeeze(-1).T.index_copy_(0, slots, o_next_new_tokens) |
| 329 | + new_tokens_lens.index_copy_(0, slots, o_new_tokens_lens) |
| 330 | + next_draft_tokens.index_copy_(0, slots, o_next_draft_tokens) |
327 | 331 |
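The four `index_copy_` calls scatter the per-request output rows (ordered by batch position) into the slot-indexed store tensors entirely on the GPU, so no host synchronization is needed at this point. A small self-contained sketch of the transposed-view scatter, using example shapes consistent with `create_store()` above:

```python
import torch

num_steps, seq_slots = 4, 8                  # nextn + 1, number of sequence slots
store = torch.zeros(num_steps, seq_slots, 1, dtype=torch.int32)  # like store.new_tokens
o_new_tokens = torch.arange(2 * num_steps, dtype=torch.int32).view(2, num_steps)
slots = torch.tensor([5, 2])                 # seq_slot of each scheduled request

# squeeze(-1).T is a (seq_slots, num_steps) view, so index_copy_ writes each
# request's row straight into its slot inside the preallocated store.
store.squeeze(-1).T.index_copy_(0, slots, o_new_tokens)
assert store[:, 5, 0].tolist() == [0, 1, 2, 3]
assert store[:, 2, 0].tolist() == [4, 5, 6, 7]
```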
328 | 332 | device = SampleStateTensorsMTP( |
329 | | - new_tokens=next_new_tokens_device, |
330 | | - new_tokens_lens=new_tokens_lens_device, |
331 | | - next_draft_tokens=next_draft_tokens_device, |
| 333 | + new_tokens=next_new_tokens, |
| 334 | + new_tokens_lens=new_tokens_lens, |
| 335 | + next_draft_tokens=next_draft_tokens, |
332 | 336 | ) |
333 | 337 | host = SampleStateTensorsMTP( |
334 | | - new_tokens=new_tokens_device.to('cpu', non_blocking=True), |
335 | | - new_tokens_lens=new_tokens_lens_device.to('cpu', non_blocking=True), |
336 | | - next_draft_tokens=next_draft_tokens_device.to('cpu', |
337 | | - non_blocking=True), |
| 338 | + new_tokens=new_tokens.to('cpu', non_blocking=True), |
| 339 | + new_tokens_lens=new_tokens_lens.to('cpu', non_blocking=True), |
| 340 | + next_draft_tokens=next_draft_tokens.to('cpu', non_blocking=True), |
338 | 341 | ) |
339 | 342 | sampler_event = torch.cuda.Event() |
340 | 343 | sampler_event.record() |
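The returned `device` tensors feed the next iteration directly, while the `host` copies are issued with `non_blocking=True` and only read after `state.sampler_event.synchronize()` in `update_requests()`. A sketch of that event pattern (note: a device-to-host `non_blocking` copy is only truly asynchronous with pinned host memory; the event makes the later read safe either way):

```python
import torch

if torch.cuda.is_available():
    device_tokens = torch.randint(0, 100, (8, 4), device="cuda")
    host_tokens = device_tokens.to("cpu", non_blocking=True)  # copy may still be in flight
    event = torch.cuda.Event()
    event.record()           # records a point after the copy was enqueued

    # ... later, before touching host_tokens (as update_requests() does):
    event.synchronize()      # blocks until all work recorded so far has finished
    assert host_tokens.shape == (8, 4)
```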