diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py
index ecd9fdae05..0db760585a 100644
--- a/lmdeploy/cli/serve.py
+++ b/lmdeploy/cli/serve.py
@@ -91,6 +91,7 @@ def add_parser_api_server():
         ArgumentHelper.device(pt_group)
         ArgumentHelper.eager_mode(pt_group)
         ArgumentHelper.disable_vision_encoder(pt_group)
+        ArgumentHelper.logprobs_mode(pt_group)
 
         # common engine args
         dtype_act = ArgumentHelper.dtype(pt_group)
@@ -217,6 +218,7 @@ def api_server(args):
                 model_format=args.model_format,
                 hf_overrides=args.hf_overrides,
                 disable_vision_encoder=args.disable_vision_encoder,
+                logprobs_mode=args.logprobs_mode,
             )
         else:
             from lmdeploy.messages import TurbomindEngineConfig
diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py
index c165b01e1e..9787ac79dd 100644
--- a/lmdeploy/cli/utils.py
+++ b/lmdeploy/cli/utils.py
@@ -601,6 +601,15 @@ def disable_vision_encoder(parser):
                             default=False,
                             help='enable metrics system')
 
+    @staticmethod
+    def logprobs_mode(parser):
+        """The mode of logprobs."""
+        parser.add_argument('--logprobs-mode',
+                            type=str,
+                            default=None,
+                            choices=[None, 'raw_logits', 'raw_logprobs'],
+                            help='The mode of logprobs.')
+
 
 # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/utils/__init__.py
 class FlexibleArgumentParser(argparse.ArgumentParser):
diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py
index 215c730ad0..eb897e5682 100644
--- a/lmdeploy/messages.py
+++ b/lmdeploy/messages.py
@@ -331,6 +331,7 @@ class PytorchEngineConfig:
             It can be used to override the default config of the model,
         disable_vision_encoder (bool): Whether to disable loading vision
             encoder. Default to False.
+        logprobs_mode (str): The mode of logprob, options: ['raw_logits', 'raw_logprobs']
     """
     dtype: str = 'auto'
     tp: int = 1
@@ -363,6 +364,7 @@ class PytorchEngineConfig:
     enable_metrics: bool = False
     hf_overrides: Optional[Dict[str, Any]] = None
     disable_vision_encoder: bool = False
+    logprobs_mode: str = None
 
     role: EngineRole = EngineRole.Hybrid
     migration_backend: MigrationBackend = MigrationBackend.DLSlime
diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py
index 2384f51f31..05c716c6a9 100644
--- a/lmdeploy/pytorch/config.py
+++ b/lmdeploy/pytorch/config.py
@@ -293,6 +293,7 @@ class MiscConfig:
     model_format: str = None
     hf_overrides: Dict[str, Any] = None
     disable_vision_encoder: bool = False
+    logprobs_mode: str = None
 
     @classmethod
     def from_engine_config(cls, engine_config: PytorchEngineConfig):
@@ -302,5 +303,6 @@ def from_engine_config(cls, engine_config: PytorchEngineConfig):
                           prefill_interval=engine_config.prefill_interval,
                           model_format=engine_config.model_format,
                           hf_overrides=engine_config.hf_overrides,
-                          disable_vision_encoder=engine_config.disable_vision_encoder)
+                          disable_vision_encoder=engine_config.disable_vision_encoder,
+                          logprobs_mode=engine_config.logprobs_mode)
         return misc_config
diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index 0d1d5c8d16..d98b75c6e0 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -26,6 +26,7 @@
 from .engine_checker import EngineChecker
 from .executor import build_executor
 from .logits_process import SamplingInputs
+from .model_agent import BatchedOutputs
 from .request import Request, RequestManager, RequestType, Response
 
 logger = get_logger('lmdeploy')
@@ -45,6 +46,7 @@ class InferOutput:
     meta: Any = None
     finish: bool = False
     logits: torch.Tensor = None
+    logprobs: torch.Tensor = None
 
     # send cache blocks back for migration in Disaggregated LLM Serving
     # when Prefill Engine is Done.
@@ -813,9 +815,18 @@ def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray,
             msg.update_token_ids(update_token, model_meta=model_meta)
             msg.status = MessageStatus.STOPPED
 
-    def _make_infer_outputs(self, new_token_timestamp: float, next_token_ids: torch.LongTensor, running: SeqList,
-                            logits: torch.Tensor, stopped: torch.Tensor, model_metas: List[Dict[str, Any]]):
+    def _make_infer_outputs(
+        self,
+        batched_outputs: BatchedOutputs,
+        running: SeqList,
+    ):
         """Make infer output."""
+        new_token_timestamp = batched_outputs.new_token_timestamp
+        next_token_ids = batched_outputs.next_token_ids
+        logits = batched_outputs.logits
+        stopped = batched_outputs.stopped
+        model_metas = batched_outputs.model_metas
+        logprobs = batched_outputs.logprobs
 
         seq_length = [seq.num_token_ids for seq in running]
         is_run = [seq.status == MessageStatus.LOCKED for seq in running]
@@ -836,13 +847,21 @@ def _make_infer_outputs(self, new_token_timestamp: float, next_token_ids: torch.
                 cache_block_ids = self.scheduler.block_manager.get_block_table(msg).tolist()
             else:
                 cache_block_ids = None
+
+            # logprobs
+            num_logprobs = msg.sampling_param.num_logprobs
+            cur_logprobs = None
+            if num_logprobs >= 0:
+                cur_logprobs = (logprobs.vals[idx, :num_logprobs + 1], logprobs.indices[idx, :num_logprobs + 1])
+
             req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events)
             out = InferOutput(session_id=session_id,
                               resp=msg.resp,
                               finish=finish,
                               token_ids=token_ids,
                               cache_block_ids=cache_block_ids,
-                              req_metrics=req_metrics)
+                              req_metrics=req_metrics,
+                              logprobs=cur_logprobs)
             outputs[session_id] = out
 
             if msg.return_logits:
@@ -974,12 +993,22 @@ def __log_resps(outputs: List[InferOutput]):
         def __send_resp(out: InferOutput):
             """Send response."""
             resp_type = (ResponseType.FINISH if out.finish else ResponseType.SUCCESS)
+            cur_logprobs = out.logprobs
+            logprobs = None
+            if cur_logprobs is not None:
+                # logprobs to dict
+                vals = cur_logprobs[0].tolist()
+                indices = cur_logprobs[1].tolist()
+                cur_logprobs = dict(zip(indices, vals))
+                logprobs = [] if out.resp.data is None else out.resp.data.get('logprobs', [])
+                logprobs = logprobs + [cur_logprobs]
             self._response(out.resp,
                            resp_type,
                            data=dict(token_ids=out.token_ids,
                                      logits=out.logits,
                                      cache_block_ids=out.cache_block_ids,
-                                     req_metrics=out.req_metrics))
+                                     req_metrics=out.req_metrics,
+                                     logprobs=logprobs))
 
         def __send_resps(step_outputs: List[InferOutput]):
             """Send response callback."""
@@ -1115,8 +1144,8 @@ async def _async_loop_main(
 
             # send output
             out = await self.executor.get_output_async()
-            if len(out) > 0:
-                step_outputs = self._make_infer_outputs(**out, running=running)
+            if out is not None:
+                step_outputs = self._make_infer_outputs(out, running=running)
                 resp_que.put_nowait(step_outputs)
 
             # lock forward event
diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py
index f937c59398..c3331f2a83 100644
--- a/lmdeploy/pytorch/engine/engine_instance.py
+++ b/lmdeploy/pytorch/engine/engine_instance.py
@@ -149,6 +149,7 @@ async def async_stream_infer(self,
 
             cache_block_ids = resp.data.get('cache_block_ids', None) if resp.data else None
             req_metrics = resp.data.get('req_metrics', None) if resp.data else None
+            logprobs = resp.data.get('logprobs', None) if resp.data else None
             if resp.type == ResponseType.SUCCESS:
                 token_ids = resp.data['token_ids'].tolist()
                 num_ids = len(token_ids)
@@ -157,7 +158,8 @@ async def async_stream_infer(self,
                                    token_ids,
                                    num_ids,
                                    cache_block_ids=cache_block_ids,
-                                   req_metrics=req_metrics)
+                                   req_metrics=req_metrics,
+                                   logprobs=logprobs)
             elif resp.type == ResponseType.FINISH:
                 resp_data = resp.data
                 token_ids = resp_data['token_ids'].tolist()
@@ -169,7 +171,8 @@ async def async_stream_infer(self,
                                    num_ids,
                                    logits=logits,
                                    cache_block_ids=cache_block_ids,
-                                   req_metrics=req_metrics)
+                                   req_metrics=req_metrics,
+                                   logprobs=logprobs)
                 break
             else:
                 logger.debug(f'session[{session_id}] failed.')
diff --git a/lmdeploy/pytorch/engine/executor/ray_executor.py b/lmdeploy/pytorch/engine/executor/ray_executor.py
index c60fce37c3..c6ef85e222 100644
--- a/lmdeploy/pytorch/engine/executor/ray_executor.py
+++ b/lmdeploy/pytorch/engine/executor/ray_executor.py
@@ -6,7 +6,6 @@
 import time
 from typing import Any, Dict, List, Optional, Tuple
 
-import numpy as np
 import ray
 import ray.exceptions
 import torch
@@ -326,13 +325,7 @@ def warmup_dist(self):
 
     def pack_output(self, output: Dict):
         """Pack output."""
-        for k, v in output.items():
-            if isinstance(v, torch.Tensor):
-                # fix numpy do not have BFloat16 type
-                if v.dtype is torch.bfloat16:
-                    v = v.to(torch.float16)
-                output[k] = v.numpy()
-        return output
+        return output.to_numpy()
 
     def remote_log_start(self, msg: str):
         """Remote log start."""
@@ -487,10 +480,7 @@ async def _prefetch_outputs(self):
             outs = await self.workers[0].get_outputs.remote()
             logger.debug(f'Receive {len(outs)} outputs from worker[0].')
             for out in outs:
-                # pack pytorch
-                for k, v in out.items():
-                    if isinstance(v, np.ndarray):
-                        out[k] = torch.from_numpy(v)
+                out = out.to_tensor()
                 self.remote_outs.put_nowait(out)
 
     def _prefetch_task_callback(self, task: asyncio.Task):
diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py
index 1332aee73f..214f83256e 100644
--- a/lmdeploy/pytorch/engine/logits_process.py
+++ b/lmdeploy/pytorch/engine/logits_process.py
@@ -126,6 +126,7 @@ class SamplingInputs:
     min_top_p: float = 1.0
     response_formats: Tuple[str] = ()
     logits_processors: List[List[LogitsProcessor]] = None
+    max_num_logprobs: Optional[int] = None
 
     @classmethod
     def from_sampling_params(cls, seqs: List[SchedulerSequence]):
@@ -142,6 +143,7 @@ def from_sampling_params(cls, seqs: List[SchedulerSequence]):
         random_offsets = [None] * batch_size
         response_formats = [None] * batch_size
         logits_processors = [None] * batch_size
+        num_logprobs = [None] * batch_size
 
         def __gather_params():
             """Gather params."""
@@ -164,6 +166,7 @@ def __gather_params():
                 bad_words[idx] = bw
                 stop_words[idx] = sw
                 logits_processors[idx] = param.logits_processors
+                num_logprobs[idx] = param.num_logprobs
 
         def __get_topp(top_p):
             """Get topp."""
@@ -232,6 +235,8 @@ def __get_bad_words(bad_words):
         random_seeds = torch.tensor(random_seeds)
         random_offsets = torch.tensor(random_offsets)
 
+        max_num_logprobs = max(num_logprobs)
+
         sampling_input = cls(
             temperature=temperature,
             bad_words=bad_words,
@@ -248,6 +253,7 @@ def __get_bad_words(bad_words):
             max_top_k=max_top_k,
             min_top_p=min_top_p,
             logits_processors=logits_processors,
+            max_num_logprobs=max_num_logprobs,
         )
         return sampling_input
 
@@ -280,11 +286,13 @@ def __init__(self,
                  sampling_inputs: SamplingInputs,
                  ignore_eos: torch.Tensor,
                  tokenizer: Optional[Tokenizer] = None,
-                 sampling_vocab_size: Optional[int] = None):
+                 sampling_vocab_size: Optional[int] = None,
+                 logprobs_mode: Optional[str] = None):
         self.sampling_inputs: SamplingInputs = sampling_inputs
         self.ignore_eos = ignore_eos
         self.tokenizer = tokenizer
         self.sampling_vocab_size = sampling_vocab_size
+        self.logprobs_mode = logprobs_mode
 
     async def _wait_stream_once(self):
        """Wait stream once."""
@@ -309,6 +317,19 @@ async def __call__(self, all_ids: torch.LongTensor, guided_input_ids: torch.Long
             torch.FloatTensor: The processed prediction scores.
         """
+
+        num_logprobs = self.sampling_inputs.max_num_logprobs
+        # get raw logprobs
+        if num_logprobs < 0:
+            logprobs = None
+        else:
+            if self.logprobs_mode == 'raw_logits':
+                logprobs = scores.clone()
+            elif self.logprobs_mode == 'raw_logprobs':
+                logprobs = scores.log_softmax(dim=-1)
+            else:
+                logprobs = None
+
         sampling_inputs = self.sampling_inputs
         custom_logits_processors = self.sampling_inputs.logits_processors
 
@@ -338,7 +359,7 @@ async def __call__(self, all_ids: torch.LongTensor, guided_input_ids: torch.Long
         if guided_input_ids is not None:
             await self._wait_stream_once()
             scores = _guided_sampling(sampling_inputs.response_formats, scores, guided_input_ids, self.tokenizer)
-        return scores
+        return scores, logprobs
 
     @torch.inference_mode()
     def sampling(self, logits: torch.Tensor):
@@ -384,3 +405,19 @@ def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor):
         else:
             scores, indices = logits.topk(max_topk, dim=1)
         return __random_sampling(scores, indices)
+
+    @torch.inference_mode()
+    def compute_logprobs(self, raw_logprobs: torch.Tensor, token_ids: torch.LongTensor):
+        """Compute logprobs."""
+        if raw_logprobs is None:
+            return None
+
+        indices = token_ids.unsqueeze(-1)
+        logprobs = raw_logprobs.gather(-1, indices)
+        num_logprobs = self.sampling_inputs.max_num_logprobs
+        if num_logprobs > 0:
+            topk_logprobs, topk_indices = raw_logprobs.topk(num_logprobs, dim=-1)
+            logprobs = torch.cat([logprobs, topk_logprobs], dim=-1)
+            indices = torch.cat([indices, topk_indices], dim=-1)
+
+        return logprobs, indices.to(torch.int32)
diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py
index d66d38dc17..3ac76fb417 100644
--- a/lmdeploy/pytorch/engine/model_agent.py
+++ b/lmdeploy/pytorch/engine/model_agent.py
@@ -4,10 +4,12 @@
 import functools
 import time
 from contextlib import asynccontextmanager, contextmanager
+from dataclasses import dataclass, fields
 from multiprocessing.reduction import ForkingPickler
 from os import getenv
 from typing import Any, Dict, List, Optional
 
+import numpy as np
 import torch
 import torch.distributed as dist
 from torch.profiler import ProfilerActivity, profile, record_function
@@ -30,6 +32,81 @@
 logger = get_logger('lmdeploy')
 
 
+@dataclass
+class BatchedLogProbs:
+    vals: torch.Tensor
+    indices: torch.Tensor
+
+    def to_cpu(self):
+        """To cpu."""
+        return BatchedLogProbs(vals=self.vals.cpu(), indices=self.indices.cpu())
+
+    def to_numpy(self):
+        """To numpy."""
+        if self.vals.dtype == torch.bfloat16:
+            np_vals = self.vals
+        else:
+            np_vals = self.vals.detach().numpy()
+        return BatchedLogProbs(vals=np_vals, indices=self.indices.detach().numpy())
+
+    def to_tensor(self):
+        """To tensor."""
+        if isinstance(self.vals, torch.Tensor):
+            vals = self.vals
+        else:
+            vals = torch.from_numpy(self.vals)
+        return BatchedLogProbs(vals=vals, indices=torch.from_numpy(self.indices))
+
+
+@dataclass
+class BatchedOutputs:
+    next_token_ids: torch.Tensor
+    stopped: torch.Tensor
+    logits: Optional[torch.Tensor] = None
+    model_metas: List[Dict[str, Any]] = None
+    logprobs: Optional[BatchedLogProbs] = None
+    new_token_timestamp: int = 0
+
+    def to_cpu(self):
+        """To cpu."""
+        out = dict()
+        for f in fields(self):
+            k = f.name
+            v = getattr(self, k)
+            if isinstance(v, torch.Tensor):
+                v = v.cpu()
+            elif hasattr(v, 'to_cpu'):
+                v = v.to_cpu()
+            out[k] = v
+        return BatchedOutputs(**out)
+
+    def to_numpy(self):
+        """To numpy."""
+        out = dict()
+        for f in fields(self):
+            k = f.name
+            v = getattr(self, k)
+            if isinstance(v, torch.Tensor) and v.dtype != torch.bfloat16:
+                v = v.detach().numpy()
+            elif hasattr(v, 'to_numpy'):
+                v = v.to_numpy()
+            out[k] = v
+        return BatchedOutputs(**out)
+
+    def to_tensor(self):
+        """To tensor."""
+        out = dict()
+        for f in fields(self):
+            k = f.name
+            v = getattr(self, k)
+            if isinstance(v, np.ndarray):
+                v = torch.from_numpy(v)
+            elif hasattr(v, 'to_tensor'):
+                v = v.to_tensor()
+            out[k] = v
+        return BatchedOutputs(**out)
+
+
 class AgentProfiler:
 
     def __init__(self, dist_ctx: DistContext, stream: torch.Stream):
@@ -452,18 +529,24 @@ def __get_last_logits():
         logits_processor = FusedLogitsProcessor(sampling_inputs,
                                                 ignore_eos,
                                                 self.tokenizer,
-                                                sampling_vocab_size=self.sampling_vocab_size)
-        logits = await logits_processor(all_ids, guided_input_ids, split_logits)
+                                                sampling_vocab_size=self.sampling_vocab_size,
+                                                logprobs_mode=self.misc_config.logprobs_mode)
+        logits, raw_logprobs = await logits_processor(all_ids, guided_input_ids, split_logits)
         next_token_ids = logits_processor.sampling(logits)
+        logprobs = logits_processor.compute_logprobs(raw_logprobs, next_token_ids)
+        if logprobs is not None:
+            logprobs = BatchedLogProbs(
+                vals=logprobs[0],
+                indices=logprobs[1],
+            )
 
-        return next_token_ids
+        return next_token_ids, logprobs
 
-    def _push_output(self, output: dict):
+    def _push_output(self, output: BatchedOutputs):
         """Push output."""
         event = torch.cuda.Event()
         event.record()
-        output['event'] = event
-        self._out_que.put_nowait(output)
+        self._out_que.put_nowait((output, event))
 
     def _broadcast_next_token(self, next_token_ids: torch.Tensor, dist_ctx: DistContext = None):
         if dist_ctx is None:
@@ -619,8 +702,8 @@ async def __prepare_dp():
             if need_output:
                 logger.debug(f' rank[{rank}]: Sampling [{idx}].')
                 # sampling
-                next_token_ids = await self.async_sampling_logits(logits, all_ids, guided_input_ids, sampling_inputs,
-                                                                  inputs, num_ignore_eos > 0)
+                next_token_ids, logprobs = await self.async_sampling_logits(logits, all_ids, guided_input_ids,
+                                                                            sampling_inputs, inputs, num_ignore_eos > 0)
                 num_ignore_eos = num_ignore_eos - 1
 
                 # stopping criteria
@@ -631,6 +714,7 @@ async def __prepare_dp():
                 # as it can trigger recompilation on different ranks when using torch.compile.
                 with torch.inference_mode():
                     next_token_ids = torch.zeros_like(num_ignore_eos)
+                    logprobs = None
 
             # broadcast next token for TP > 1
             need_broadcast_next = (dp == 1 and tp > 1 and idx < loop_count - 1)
@@ -643,10 +727,11 @@ async def __prepare_dp():
             if need_output:
                 logger.debug(f' rank[{rank}]: Output [{idx}]')
                 self._push_output(
-                    dict(next_token_ids=next_token_ids,
-                         logits=logits if return_logits else None,
-                         stopped=stopped,
-                         model_metas=model_metas))
+                    BatchedOutputs(next_token_ids=next_token_ids,
+                                   logits=logits if return_logits else None,
+                                   stopped=stopped,
+                                   model_metas=model_metas,
+                                   logprobs=logprobs))
 
             # update for next loop
             if is_decoding and idx < loop_count - 1:
@@ -796,16 +881,12 @@ async def get_output_async(self):
         if out is None:
             return dict()
 
-        event = out.pop('event')
+        out, event = out
         while not event.query():
             await asyncio.sleep(0.001)
         with torch.cuda.stream(self.out_stream), torch.inference_mode(), record_function('outputs_D2H'):
-            out['next_token_ids'] = out['next_token_ids'].cpu()
-            out['stopped'] = out['stopped'].cpu()
-            # MUST be a wall-clock time
-            out['new_token_timestamp'] = time.time()
-            if out['logits'] is not None:
-                out['logits'] = out['logits'].cpu()
+            out = out.to_cpu()
+            out.new_token_timestamp = time.time()
         return out
 
     def _build_model(self):
diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py
index 442e7146f4..c21db48eab 100644
--- a/lmdeploy/pytorch/messages.py
+++ b/lmdeploy/pytorch/messages.py
@@ -53,6 +53,7 @@ class SamplingParam:
     logits_processors: Optional[List[LogitsProcessor]] = None
     out_logits: bool = False
     out_last_hidden_states: bool = False
+    num_logprobs: int = -1
 
     @classmethod
     def from_gen_config(self, gen_config: GenerationConfig):
@@ -110,6 +111,9 @@ def from_gen_config(self, gen_config: GenerationConfig):
                            'a int >=0 and <= `max_new_tokens`,'
                            f' but is {min_new_tokens}')
             min_new_tokens = 0
+        logprobs = gen_config.logprobs
+        if logprobs is None:
+            logprobs = -1
         return SamplingParam(top_p=top_p,
                              top_k=top_k,
                              min_p=min_p,
@@ -123,7 +127,8 @@ def from_gen_config(self, gen_config: GenerationConfig):
                              max_new_tokens=max_new_tokens,
                              min_new_tokens=min_new_tokens,
                              logits_processors=gen_config.logits_processors,
-                             out_logits=(output_logits is not None))
+                             out_logits=(output_logits is not None),
+                             num_logprobs=logprobs)
 
 
 class MessageStatus(enum.Enum):
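
Usage sketch (not part of the patch): a minimal, hedged example of how the new knob is expected to be exercised through the existing `pipeline`/`GenerationConfig` APIs; the model path is illustrative. `logprobs_mode` is the engine-level switch added above ('raw_logprobs' captures log-softmax values, 'raw_logits' the unprocessed logits, both taken before penalties and logits processors run), while the pre-existing `GenerationConfig.logprobs` field selects how many top candidates are returned per generated token and maps onto `num_logprobs`/`max_num_logprobs` in this patch.

from lmdeploy import GenerationConfig, PytorchEngineConfig, pipeline

# Engine-level switch added by this patch; with the default None no logprobs are captured.
backend_config = PytorchEngineConfig(logprobs_mode='raw_logprobs')
pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config)  # illustrative model

# Per-request: ask for the top-5 candidates of every generated token.
gen_config = GenerationConfig(max_new_tokens=32, logprobs=5)
resp = pipe(['Hello, how are you?'], gen_config=gen_config)

# Each generated token yields a dict mapping token id -> logprob, assembled from the
# (vals, indices) pairs built in `compute_logprobs` and serialized in `__send_resp`.
print(resp[0].logprobs)

The same switch is exposed on the CLI as `--logprobs-mode {raw_logits,raw_logprobs}` for `lmdeploy serve api_server` when the PyTorch backend is selected.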