@@ -272,6 +272,8 @@ def __init__(self,
272272 self ._done = False
273273 self .metrics_dict = {}
274274 self .trace_headers : Optional [dict [str , str ]] = None
275+ # torch backend will use trtllm sampler in beam search mode, but it does not support returning logprobs incrementally
276+ self .use_trtllm_sampler = sampling_params .use_beam_search and sampling_params .best_of > 1
275277
276278 if ray_queue is not None :
277279 if has_event_loop ():
@@ -378,20 +380,27 @@ def _handle_sequence(self,
378380 # each streamed response_tensors.log_probs[src_idx]
379381 # contains a streamwise monotonically growing list of logprobs.
380382 # so we need to accumulate only the new ones unique to that particular streamed response
381- assert output ._last_logprobs_len <= len (
382- response_tensors .log_probs [src_idx ]
383- ), (f"_last_logprobs_len ({ output ._last_logprobs_len } ) > log_probs length ("
384- f"{ len (response_tensors .log_probs [src_idx ])} )" )
385- output .logprobs += response_tensors .log_probs [src_idx ][
386- output ._last_logprobs_len :]
383+ if self .use_trtllm_sampler :
384+ assert output ._last_logprobs_len <= len (
385+ response_tensors .log_probs [src_idx ]
386+ ), (f"_last_logprobs_len ({ output ._last_logprobs_len } ) > log_probs length ("
387+ f"{ len (response_tensors .log_probs [src_idx ])} )" )
388+ output .logprobs += response_tensors .log_probs [src_idx ][
389+ output ._last_logprobs_len :]
390+ else :
391+ output .logprobs += response_tensors .log_probs [src_idx ]
392+
387393 # overcome some WAR in the cpp executor
388- if finish_reasons [src_idx ] != tllm .FinishReason .CANCELLED :
394+ if finish_reasons [
395+ src_idx ] != tllm .FinishReason .CANCELLED and self .use_trtllm_sampler :
389396 # Check if logprobs is a list (not a dict or other structure)
390397 if len (output .logprobs ) > output .length :
391398 # LlmResult holds a reference to LogProbStorage, which may be updated by the worker before the result is serialized.
392399 # Therefore, we treat extra logprobs/logits as expected and only consume what's needed.
393400 output .logprobs = output .logprobs [:output .length ]
394- assert len (output .logprobs ) == output .length
401+ assert len (
402+ output .logprobs
403+ ) == output .length , f"logprobs length: { len (output .logprobs )} != output.length: { output .length } "
395404
396405 if response_tensors .generation_logits is not None :
397406 output .generation_logits = response_tensors .generation_logits [
0 commit comments