
Commit b53961e

[None][feat] Return logprobs incrementally in torch backend (#8785)
Signed-off-by: Dong Cao <[email protected]>
Parent: 9f8d93f

File tree: 2 files changed (+32, −11 lines)

tensorrt_llm/_torch/pyexecutor/llm_request.py (15 additions, 3 deletions)

@@ -1,4 +1,4 @@
-from copy import deepcopy
+from copy import copy, deepcopy
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union

@@ -327,7 +327,8 @@ def generation_logits(self) -> torch.Tensor | None:

     @property
     def log_probs(self) -> list[TokenLogprobs] | None:
-        return self._log_probs and self._log_probs.log_probs
+        return self._log_probs and hasattr(
+            self._log_probs, 'log_probs') and self._log_probs.log_probs

     @property
     def cum_log_probs(self) -> list[float] | None:

@@ -589,10 +590,21 @@ def create_response(self,
         """
         result, is_final = super().create_serialized_result(
             use_fast_logits, mpi_world_rank)
+
+        # Perform a deep copy of py_result._log_probs to eliminate race conditions that may occur between IPC communication and the overwriting of newly generated log_probs in streaming mode.
+        if self.streaming and self.py_result.log_probs and self.sampling_config.beam_width <= 1:
+            py_result = copy(self.py_result)
+            py_result._log_probs = deepcopy(self.py_result._log_probs)
+
+            for log_prob in self.py_result.log_probs:
+                log_prob.clear()
+        else:
+            py_result = self.py_result
+
         return LlmResponse(
             request_id=self.py_request_id
             if self.is_child else self.parent_request_id,
-            result=LlmResult(result, self.py_result, is_final),
+            result=LlmResult(result, py_result, is_final),
             client_id=self.py_client_id) if len(result) > 0 else None

     @property
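
To make the race-avoidance pattern above concrete, here is a minimal, runnable sketch of the copy-and-clear idea. PyResult and snapshot_for_response are simplified stand-ins for illustration, not the actual TensorRT-LLM classes: the response being serialized gets a private deep copy of the logprob storage, and the live storage is drained so that the next streamed response carries only newly generated logprobs.

from copy import copy, deepcopy

class PyResult:
    def __init__(self):
        # one growing list of logprob entries per generated sequence
        self._log_probs = [[-0.1]]

    @property
    def log_probs(self):
        return self._log_probs

def snapshot_for_response(py_result, streaming=True, beam_width=1):
    # Detach the logprobs belonging to this response from the live storage.
    if streaming and py_result.log_probs and beam_width <= 1:
        snapshot = copy(py_result)                            # shallow copy of the wrapper
        snapshot._log_probs = deepcopy(py_result._log_probs)  # private copy, safe across IPC
        for seq_log_probs in py_result.log_probs:
            seq_log_probs.clear()                             # live storage restarts empty
        return snapshot
    return py_result

live = PyResult()
sent = snapshot_for_response(live)
assert sent.log_probs == [[-0.1]]  # the response keeps the values it was built with
assert live.log_probs == [[]]      # the next response will carry only new logprobs

The wrapper itself is only shallow-copied; just the mutable logprob lists get a deep copy, since those are what gets overwritten in streaming mode per the commit message.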

tensorrt_llm/executor/result.py (17 additions, 8 deletions)

@@ -272,6 +272,8 @@ def __init__(self,
         self._done = False
         self.metrics_dict = {}
         self.trace_headers: Optional[dict[str, str]] = None
+        # The torch backend uses the trtllm sampler in beam search mode, but that sampler does not support returning logprobs incrementally.
+        self.use_trtllm_sampler = sampling_params.use_beam_search and sampling_params.best_of > 1

         if ray_queue is not None:
             if has_event_loop():

@@ -378,20 +380,27 @@ def _handle_sequence(self,
             # each streamed response_tensors.log_probs[src_idx]
             # contains a streamwise monotonically growing list of logprobs,
             # so we need to accumulate only the new ones unique to that particular streamed response
-            assert output._last_logprobs_len <= len(
-                response_tensors.log_probs[src_idx]
-            ), (f"_last_logprobs_len ({output._last_logprobs_len}) > log_probs length ("
-                f"{len(response_tensors.log_probs[src_idx])})")
-            output.logprobs += response_tensors.log_probs[src_idx][
-                output._last_logprobs_len:]
+            if self.use_trtllm_sampler:
+                assert output._last_logprobs_len <= len(
+                    response_tensors.log_probs[src_idx]
+                ), (f"_last_logprobs_len ({output._last_logprobs_len}) > log_probs length ("
+                    f"{len(response_tensors.log_probs[src_idx])})")
+                output.logprobs += response_tensors.log_probs[src_idx][
+                    output._last_logprobs_len:]
+            else:
+                output.logprobs += response_tensors.log_probs[src_idx]
+

             # overcome some WAR in the cpp executor
-            if finish_reasons[src_idx] != tllm.FinishReason.CANCELLED:
+            if finish_reasons[
+                    src_idx] != tllm.FinishReason.CANCELLED and self.use_trtllm_sampler:
                 # Check if logprobs is a list (not a dict or other structure)
                 if len(output.logprobs) > output.length:
                     # LlmResult holds a reference to LogProbStorage, which may be updated by the worker before the result is serialized.
                     # Therefore, we treat extra logprobs/logits as expected and only consume what's needed.
                     output.logprobs = output.logprobs[:output.length]
-                assert len(output.logprobs) == output.length
+                assert len(
+                    output.logprobs
+                ) == output.length, f"logprobs length: {len(output.logprobs)} != output.length: {output.length}"

         if response_tensors.generation_logits is not None:
             output.generation_logits = response_tensors.generation_logits[
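
On the consumer side, the two branches in _handle_sequence can be summarized with a small sketch. accumulate_logprobs is a hypothetical helper distilled from the diff above: with the trtllm sampler each streamed payload is cumulative, so only the suffix past the already-consumed length is new, while with the torch sampler each payload now contains only the increment and can be appended directly.

def accumulate_logprobs(acc, last_len, streamed, use_trtllm_sampler):
    # acc: logprobs consumed so far; streamed: the log_probs payload of one response
    if use_trtllm_sampler:
        # cumulative stream: only the suffix past last_len is new
        assert last_len <= len(streamed)
        acc += streamed[last_len:]
    else:
        # incremental stream: every entry is new
        acc += streamed
    return acc

# trtllm sampler (beam search): each payload repeats earlier logprobs
acc = accumulate_logprobs([], 0, [-0.1], use_trtllm_sampler=True)
acc = accumulate_logprobs(acc, 1, [-0.1, -0.5], use_trtllm_sampler=True)
assert acc == [-0.1, -0.5]

# torch sampler after this change: each payload carries only the new logprobs
acc = accumulate_logprobs([], 0, [-0.1], use_trtllm_sampler=False)
acc = accumulate_logprobs(acc, 0, [-0.5], use_trtllm_sampler=False)
assert acc == [-0.1, -0.5]

Both paths converge on the same accumulated list; the difference is only in how much of each payload is new, which is why use_trtllm_sampler (beam search with best_of > 1) gates both the suffix slicing and the length-trimming workaround.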
