
support logprobs #3852


Merged
merged 4 commits into from
Aug 22, 2025
2 changes: 2 additions & 0 deletions lmdeploy/cli/serve.py
@@ -91,6 +91,7 @@ def add_parser_api_server():
ArgumentHelper.device(pt_group)
ArgumentHelper.eager_mode(pt_group)
ArgumentHelper.disable_vision_encoder(pt_group)
ArgumentHelper.logprobs_mode(pt_group)

# common engine args
dtype_act = ArgumentHelper.dtype(pt_group)
@@ -217,6 +218,7 @@ def api_server(args):
model_format=args.model_format,
hf_overrides=args.hf_overrides,
disable_vision_encoder=args.disable_vision_encoder,
logprobs_mode=args.logprobs_mode,
)
else:
from lmdeploy.messages import TurbomindEngineConfig
9 changes: 9 additions & 0 deletions lmdeploy/cli/utils.py
@@ -601,6 +601,15 @@ def disable_vision_encoder(parser):
default=False,
help='enable metrics system')

@staticmethod
def logprobs_mode(parser):
"""The mode of logprobs."""
parser.add_argument('--logprobs-mode',
type=str,
default=None,
choices=[None, 'raw_logits', 'raw_logprobs'],
help='The mode of logprobs.')


# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/utils/__init__.py
class FlexibleArgumentParser(argparse.ArgumentParser):
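For context, launching the server with something like `lmdeploy serve api_server <model> --logprobs-mode raw_logprobs` (invocation shown here as an assumption, not part of this diff) exercises the new option as an ordinary argparse flag. A minimal standalone sketch of its behavior:

```python
import argparse

parser = argparse.ArgumentParser()
# same option that ArgumentHelper.logprobs_mode() registers above
parser.add_argument('--logprobs-mode',
                    type=str,
                    default=None,
                    choices=[None, 'raw_logits', 'raw_logprobs'],
                    help='The mode of logprobs.')

# a valid value passes through unchanged
args = parser.parse_args(['--logprobs-mode', 'raw_logprobs'])
print(args.logprobs_mode)  # raw_logprobs

# omitting the flag keeps the default of None, i.e. no logprobs mode is set
print(parser.parse_args([]).logprobs_mode)  # None
```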
2 changes: 2 additions & 0 deletions lmdeploy/messages.py
@@ -331,6 +331,7 @@ class PytorchEngineConfig:
It can be used to override the default config of the model,
disable_vision_encoder (bool): Whether to disable loading vision
encoder. Default to False.
logprobs_mode (str): The mode of logprobs. Options: ['raw_logits', 'raw_logprobs']
"""
dtype: str = 'auto'
tp: int = 1
@@ -363,6 +364,7 @@ class PytorchEngineConfig:
enable_metrics: bool = False
hf_overrides: Optional[Dict[str, Any]] = None
disable_vision_encoder: bool = False
logprobs_mode: str = None

role: EngineRole = EngineRole.Hybrid
migration_backend: MigrationBackend = MigrationBackend.DLSlime
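As a usage sketch of the new config field (hedged: the model name is arbitrary, and the `logprobs` fields on `GenerationConfig` and the returned response object are assumed to follow the existing lmdeploy API rather than being added by this diff):

```python
from lmdeploy import GenerationConfig, PytorchEngineConfig, pipeline

# 'raw_logprobs' captures log-softmaxed scores, 'raw_logits' the raw logits
backend_config = PytorchEngineConfig(logprobs_mode='raw_logprobs')
pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config)

# ask for the top-5 logprobs of every generated token
gen_config = GenerationConfig(max_new_tokens=32, logprobs=5)
resp = pipe(['Hello!'], gen_config=gen_config)[0]
print(resp.text)
print(resp.logprobs)  # one {token_id: logprob} dict per generated token
```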
4 changes: 3 additions & 1 deletion lmdeploy/pytorch/config.py
@@ -293,6 +293,7 @@ class MiscConfig:
model_format: str = None
hf_overrides: Dict[str, Any] = None
disable_vision_encoder: bool = False
logprobs_mode: str = None

@classmethod
def from_engine_config(cls, engine_config: PytorchEngineConfig):
@@ -302,5 +303,6 @@ def from_engine_config(cls, engine_config: PytorchEngineConfig):
prefill_interval=engine_config.prefill_interval,
model_format=engine_config.model_format,
hf_overrides=engine_config.hf_overrides,
disable_vision_encoder=engine_config.disable_vision_encoder)
disable_vision_encoder=engine_config.disable_vision_encoder,
logprobs_mode=engine_config.logprobs_mode)
return misc_config
41 changes: 35 additions & 6 deletions lmdeploy/pytorch/engine/engine.py
@@ -26,6 +26,7 @@
from .engine_checker import EngineChecker
from .executor import build_executor
from .logits_process import SamplingInputs
from .model_agent import BatchedOutputs
from .request import Request, RequestManager, RequestType, Response

logger = get_logger('lmdeploy')
@@ -45,6 +46,7 @@ class InferOutput:
meta: Any = None
finish: bool = False
logits: torch.Tensor = None
logprobs: torch.Tensor = None

# send cache blocks back for migration in Disaggregated LLM Serving
# when Prefill Engine is Done.
@@ -813,9 +815,18 @@ def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray,
msg.update_token_ids(update_token, model_meta=model_meta)
msg.status = MessageStatus.STOPPED

def _make_infer_outputs(self, new_token_timestamp: float, next_token_ids: torch.LongTensor, running: SeqList,
logits: torch.Tensor, stopped: torch.Tensor, model_metas: List[Dict[str, Any]]):
def _make_infer_outputs(
self,
batched_outputs: BatchedOutputs,
running: SeqList,
):
"""Make infer output."""
new_token_timestamp = batched_outputs.new_token_timestamp
next_token_ids = batched_outputs.next_token_ids
logits = batched_outputs.logits
stopped = batched_outputs.stopped
model_metas = batched_outputs.model_metas
logprobs = batched_outputs.logprobs

seq_length = [seq.num_token_ids for seq in running]
is_run = [seq.status == MessageStatus.LOCKED for seq in running]
@@ -836,13 +847,21 @@ def _make_infer_outputs(self, new_token_timestamp: float, next_token_ids: torch.
cache_block_ids = self.scheduler.block_manager.get_block_table(msg).tolist()
else:
cache_block_ids = None

# logprobs
num_logprobs = msg.sampling_param.num_logprobs
cur_logprobs = None
if num_logprobs >= 0:
cur_logprobs = (logprobs.vals[idx, :num_logprobs + 1], logprobs.indices[idx, :num_logprobs + 1])

req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events)
out = InferOutput(session_id=session_id,
resp=msg.resp,
finish=finish,
token_ids=token_ids,
cache_block_ids=cache_block_ids,
req_metrics=req_metrics)
req_metrics=req_metrics,
logprobs=cur_logprobs)
outputs[session_id] = out

if msg.return_logits:
@@ -974,12 +993,22 @@ def __log_resps(outputs: List[InferOutput]):
def __send_resp(out: InferOutput):
"""Send response."""
resp_type = (ResponseType.FINISH if out.finish else ResponseType.SUCCESS)
cur_logprobs = out.logprobs
logprobs = None
if cur_logprobs is not None:
# logprobs to dict
vals = cur_logprobs[0].tolist()
indices = cur_logprobs[1].tolist()
cur_logprobs = dict(zip(indices, vals))
logprobs = [] if out.resp.data is None else out.resp.data.get('logprobs', [])
logprobs = logprobs + [cur_logprobs]
self._response(out.resp,
resp_type,
data=dict(token_ids=out.token_ids,
logits=out.logits,
cache_block_ids=out.cache_block_ids,
req_metrics=out.req_metrics))
req_metrics=out.req_metrics,
logprobs=logprobs))

def __send_resps(step_outputs: List[InferOutput]):
"""Send response callback."""
@@ -1115,8 +1144,8 @@ async def _async_loop_main(

# send output
out = await self.executor.get_output_async()
if len(out) > 0:
step_outputs = self._make_infer_outputs(**out, running=running)
if out is not None:
step_outputs = self._make_infer_outputs(out, running=running)
resp_que.put_nowait(step_outputs)

# lock forward event
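The per-step hand-off in `__send_resp` above turns the `(values, indices)` pair selected in `_make_infer_outputs` into a plain dict and appends it to whatever was already accumulated on the response. A small sketch of that conversion (tensor contents are made up):

```python
import torch

# top (num_logprobs + 1) values and the matching token ids for one step
vals = torch.tensor([-0.5, -2.0, -3.5])
indices = torch.tensor([42, 7, 99])

cur_logprobs = dict(zip(indices.tolist(), vals.tolist()))
print(cur_logprobs)  # {42: -0.5, 7: -2.0, 99: -3.5}

# the dict is appended to any logprobs already present on the response data
logprobs = []
logprobs = logprobs + [cur_logprobs]
```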
7 changes: 5 additions & 2 deletions lmdeploy/pytorch/engine/engine_instance.py
@@ -149,6 +149,7 @@ async def async_stream_infer(self,

cache_block_ids = resp.data.get('cache_block_ids', None) if resp.data else None
req_metrics = resp.data.get('req_metrics', None) if resp.data else None
logprobs = resp.data.get('logprobs', None) if resp.data else None
if resp.type == ResponseType.SUCCESS:
token_ids = resp.data['token_ids'].tolist()
num_ids = len(token_ids)
@@ -157,7 +158,8 @@
token_ids,
num_ids,
cache_block_ids=cache_block_ids,
req_metrics=req_metrics)
req_metrics=req_metrics,
logprobs=logprobs)
elif resp.type == ResponseType.FINISH:
resp_data = resp.data
token_ids = resp_data['token_ids'].tolist()
@@ -169,7 +171,8 @@
num_ids,
logits=logits,
cache_block_ids=cache_block_ids,
req_metrics=req_metrics)
req_metrics=req_metrics,
logprobs=logprobs)
break
else:
logger.debug(f'session[{session_id}] failed.')
14 changes: 2 additions & 12 deletions lmdeploy/pytorch/engine/executor/ray_executor.py
@@ -6,7 +6,6 @@
import time
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import ray
import ray.exceptions
import torch
@@ -326,13 +325,7 @@ def warmup_dist(self):

def pack_output(self, output: Dict):
"""Pack output."""
for k, v in output.items():
if isinstance(v, torch.Tensor):
# fix numpy do not have BFloat16 type
if v.dtype is torch.bfloat16:
v = v.to(torch.float16)
output[k] = v.numpy()
return output
return output.to_numpy()

def remote_log_start(self, msg: str):
"""Remote log start."""
@@ -487,10 +480,7 @@ async def _prefetch_outputs(self):
outs = await self.workers[0].get_outputs.remote()
logger.debug(f'Receive {len(outs)} outputs from worker[0].')
for out in outs:
# pack pytorch
for k, v in out.items():
if isinstance(v, np.ndarray):
out[k] = torch.from_numpy(v)
out = out.to_tensor()
self.remote_outs.put_nowait(out)

def _prefetch_task_callback(self, task: asyncio.Task):
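The manual packing removed above (casting bfloat16 to float16 before `.numpy()`) is presumably folded into the new `to_numpy()` / `to_tensor()` helpers on the batched output object; the underlying constraint is simply that numpy has no bfloat16 dtype. A quick illustration (the exact exception type may vary across torch versions):

```python
import torch

t = torch.randn(4, dtype=torch.bfloat16)

try:
    t.numpy()  # numpy has no bfloat16 dtype, so this fails
except (TypeError, RuntimeError) as err:
    print(err)

# casting to float16 first, as the removed loop did, keeps the payload compact
arr = t.to(torch.float16).numpy()
print(arr.dtype)  # float16
```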
41 changes: 39 additions & 2 deletions lmdeploy/pytorch/engine/logits_process.py
@@ -126,6 +126,7 @@ class SamplingInputs:
min_top_p: float = 1.0
response_formats: Tuple[str] = ()
logits_processors: List[List[LogitsProcessor]] = None
max_num_logprobs: Optional[int] = None

@classmethod
def from_sampling_params(cls, seqs: List[SchedulerSequence]):
@@ -142,6 +143,7 @@ def from_sampling_params(cls, seqs: List[SchedulerSequence]):
random_offsets = [None] * batch_size
response_formats = [None] * batch_size
logits_processors = [None] * batch_size
num_logprobs = [None] * batch_size

def __gather_params():
"""Gather params."""
@@ -164,6 +166,7 @@ def __gather_params():
bad_words[idx] = bw
stop_words[idx] = sw
logits_processors[idx] = param.logits_processors
num_logprobs[idx] = param.num_logprobs

def __get_topp(top_p):
"""Get topp."""
@@ -232,6 +235,8 @@ def __get_bad_words(bad_words):
random_seeds = torch.tensor(random_seeds)
random_offsets = torch.tensor(random_offsets)

max_num_logprobs = max(num_logprobs)

sampling_input = cls(
temperature=temperature,
bad_words=bad_words,
@@ -248,6 +253,7 @@ def __get_bad_words(bad_words):
max_top_k=max_top_k,
min_top_p=min_top_p,
logits_processors=logits_processors,
max_num_logprobs=max_num_logprobs,
)
return sampling_input

@@ -280,11 +286,13 @@ def __init__(self,
sampling_inputs: SamplingInputs,
ignore_eos: torch.Tensor,
tokenizer: Optional[Tokenizer] = None,
sampling_vocab_size: Optional[int] = None):
sampling_vocab_size: Optional[int] = None,
logprobs_mode: Optional[str] = None):
self.sampling_inputs: SamplingInputs = sampling_inputs
self.ignore_eos = ignore_eos
self.tokenizer = tokenizer
self.sampling_vocab_size = sampling_vocab_size
self.logprobs_mode = logprobs_mode

async def _wait_stream_once(self):
"""Wait stream once."""
@@ -309,6 +317,19 @@ async def __call__(self, all_ids: torch.LongTensor, guided_input_ids: torch.Long
torch.FloatTensor: The processed prediction scores.

"""

num_logprobs = self.sampling_inputs.max_num_logprobs
# get raw logprobs
if num_logprobs < 0:
logprobs = None
else:
if self.logprobs_mode == 'raw_logits':
logprobs = scores.clone()
elif self.logprobs_mode == 'raw_logprobs':
logprobs = scores.log_softmax(dim=-1)
else:
logprobs = None

sampling_inputs = self.sampling_inputs

custom_logits_processors = self.sampling_inputs.logits_processors
@@ -338,7 +359,7 @@
if guided_input_ids is not None:
await self._wait_stream_once()
scores = _guided_sampling(sampling_inputs.response_formats, scores, guided_input_ids, self.tokenizer)
return scores
return scores, logprobs

@torch.inference_mode()
def sampling(self, logits: torch.Tensor):
@@ -384,3 +405,19 @@ def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor):
else:
scores, indices = logits.topk(max_topk, dim=1)
return __random_sampling(scores, indices)

@torch.inference_mode()
def compute_logprobs(self, raw_logprobs: torch.Tensor, token_ids: torch.LongTensor):
"""Compute logprobs."""
if raw_logprobs is None:
return None

indices = token_ids.unsqueeze(-1)
logprobs = raw_logprobs.gather(-1, indices)
num_logprobs = self.sampling_inputs.max_num_logprobs
if num_logprobs > 0:
topk_logprobs, topk_indices = raw_logprobs.topk(num_logprobs, dim=-1)
logprobs = torch.cat([logprobs, topk_logprobs], dim=-1)
indices = torch.cat([indices, topk_indices], dim=-1)

return logprobs, indices.to(torch.int32)
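Putting the two new pieces of `FusedLogitsProcessor` together: the configured mode decides what gets captured before sampling, and `compute_logprobs` then gathers the sampled token's value plus the top-k. A self-contained sketch with made-up shapes:

```python
import torch

scores = torch.randn(2, 10)          # raw model logits, [batch, vocab]
logprobs_mode = 'raw_logprobs'

# mode selection, mirroring FusedLogitsProcessor.__call__ above
if logprobs_mode == 'raw_logits':
    raw_logprobs = scores.clone()
elif logprobs_mode == 'raw_logprobs':
    raw_logprobs = scores.log_softmax(dim=-1)
else:
    raw_logprobs = None

# compute_logprobs: sampled-token logprob first, then the top-k appended
token_ids = torch.tensor([3, 7])     # one sampled token per sequence
num_logprobs = 2
indices = token_ids.unsqueeze(-1)
logprobs = raw_logprobs.gather(-1, indices)
topk_logprobs, topk_indices = raw_logprobs.topk(num_logprobs, dim=-1)
logprobs = torch.cat([logprobs, topk_logprobs], dim=-1)
indices = torch.cat([indices, topk_indices], dim=-1)
print(logprobs.shape, indices.shape)  # torch.Size([2, 3]) torch.Size([2, 3])
```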