[KSM] support keep sampling mask #7460
zeroRains wants to merge 5 commits into PaddlePaddle:release/2.6 from
Conversation
Thanks for your contribution!
Pull request overview
This PR adds a keep sampling mask output capability to the inference service: after top_p/top_k truncated sampling, the vocabulary indices retained at each step are returned (streaming included) in a sparse format, enabling client-side interpretability and debugging analysis, along with the corresponding CLI switch and end-to-end tests.
Changes:
- Add the launch argument --enable-keep-sampling-mask and thread the switch through Engine/Worker/Sampler/TokenProcessor/OpenAI Serving.
- Compute the sparse sampling_mask (and logZ) at the sampling stage, and on the non-FD_USE_GET_SAVE_OUTPUT_V1 path send it to token_processor over a ZMQ side channel, then surface it in the OpenAI response (see the sketch after this list).
- Add/update unit tests and e2e tests covering the format and consistency of sampling_mask in streaming and non-streaming responses.
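For intuition, here is a minimal, self-contained sketch of deriving a sparse top_p sampling mask and its logZ. It is illustrative only: compute_sparse_sampling_mask and its input layout are assumptions for this example, not the PR's actual _compute_sampling_mask.

```python
import paddle
import paddle.nn.functional as F

def compute_sparse_sampling_mask(logits: paddle.Tensor, top_p: float):
    """Per request, return the vocab indices kept by top_p truncation plus
    log Z_K, the log of the probability mass of the kept candidate set."""
    log_probs = F.log_softmax(logits, axis=-1)   # [bsz, vocab]
    probs = paddle.exp(log_probs)
    sorted_probs = paddle.sort(probs, axis=-1, descending=True)
    sorted_idx = paddle.argsort(probs, axis=-1, descending=True)
    cum = paddle.cumsum(sorted_probs, axis=-1)
    # Keep the smallest prefix whose cumulative mass reaches top_p; shifting
    # by sorted_probs ensures the token that crosses the threshold is kept.
    keep = (cum - sorted_probs) < top_p          # [bsz, vocab] bool
    masks, logzs = [], []
    for b in range(logits.shape[0]):
        kept = sorted_idx[b][keep[b]]            # sparse vocab indices
        masks.append(kept.numpy())
        logzs.append(paddle.logsumexp(log_probs[b][kept]).item())
    return masks, logzs
```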
Reviewed changes
Copilot reviewed 19 out of 19 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| tests/output/test_process_batch_output.py | Initialize the use_sampling_mask field on the processor constructed for tests. |
| tests/entrypoints/openai/test_max_streaming_tokens.py | Update calls to match the new sampling_mask_list parameter of chat choices. |
| tests/e2e/test_ernie_21b_mtp.py | e2e: enable keep sampling mask in the launch arguments; add streaming/non-streaming/different-top_p verification cases. |
| fastdeploy/worker/worker_process.py | Add --enable-keep-sampling-mask to the worker CLI (with underscore and hyphen aliases). |
| fastdeploy/worker/output.py | Add sampling_mask and logz_per_batch fields to SamplerOutput (sparse mask and logZ). |
| fastdeploy/worker/gpu_model_runner.py | Read the config switch; create the sampling_mask ZMQ client on the non-V1 path; pass keep_sampling_mask in prepare_inputs; forward sampling_mask_zmq_client in save_output. |
| fastdeploy/output/token_processor.py | Add a sampling_mask ZMQ server on the non-V1 path; receive the mask each step and write it into RequestOutput.outputs. |
| fastdeploy/output/stream_transfer_data.py | Add a sampling_mask field to StreamTransferData to carry the sparse mask. |
| fastdeploy/model_executor/pre_and_post_process.py | Add sampling_mask to stream transfer data; add side-channel sends in save_output_*; add a logZ-based logprobs renormalization step. |
| fastdeploy/model_executor/layers/sample/sampler.py | Add _compute_sampling_mask; compute sampling_mask/logZ before sampling on the normal and speculative paths and write them into SamplerOutput. |
| fastdeploy/model_executor/layers/sample/meta_data.py | Add a keep_sampling_mask field to SamplingMetadata. |
| fastdeploy/model_executor/layers/sample/logprobs.py | build_output_logprobs additionally returns output_logits; add logprobs_renormalize_with_logz. |
| fastdeploy/entrypoints/openai/serving_chat.py | Output sampling_mask in stream/full responses; add _make_sampling_mask_list and flatten it when assembling choices. |
| fastdeploy/entrypoints/openai/protocol.py | Add a sampling_mask field (List[List[int]]) to the OpenAI protocol response models. |
| fastdeploy/engine/request.py | Add a sampling_mask field to CompletionOutput and include it in the to_dict output. |
| fastdeploy/engine/engine.py | Add enable_keep_sampling_mask to worker_store_true_flag and forward the switch when launching workers. |
| fastdeploy/engine/common_engine.py | Same as engine.py: forward enable_keep_sampling_mask to the worker launch arguments. |
| fastdeploy/engine/args_utils.py | Add the --enable-keep-sampling-mask argument and its description to EngineArgs/CLI. |
| fastdeploy/config.py | Add the enable_keep_sampling_mask default field to ModelConfig. |
```python
# Group by request using accept_num so each entry is List[np.ndarray] (n arrays per req).
real_bsz = model_output.accept_num.shape[0]
accept_nums = model_output.accept_num[:real_bsz].flatten().tolist()
mask_dict = {}
offset = 0
total_masks = len(sampler_output.sampling_mask)
for i, n in enumerate(accept_nums):
    n = max(int(n), 0)
    if n > 0:
        # List of n sparse index arrays, one per accepted token
        mask_dict[i] = [arr.tolist() for arr in sampler_output.sampling_mask[offset : offset + n]]
    offset += n
if offset != total_masks:
    raise ValueError(f"sampling_mask length mismatch: expected {offset}, got {total_masks}")
```
When the speculative path sends sampling_mask, it groups by model_output.accept_num and keys mask_dict by i, but the output of speculate_save_output(_topk) above is restored to the original batch order via index_to_batch_id + enable_pd_reorder. If PD reorder is enabled, sampler_output.sampling_mask / accept_num / logz_per_batch are not consistently reordered here, so the keys/grouping of mask_dict will disagree with the batch_id seen on the token_processor side. Suggestion: before building mask_dict, recover/reorder accept_num and sampling_mask consistently with the output (reuse recover_share_inputs["accept_num_cpu"] or extend recover_batch_index_for_sampler_output), and reorder logz_per_batch in the same way.
Suggested change:

```python
# Recover it to the same batch order as speculate_save_output(_topk) before grouping by request.
real_bsz = recover_share_inputs["accept_num_cpu"].shape[0]
raw_accept_nums = model_output.accept_num[:real_bsz].flatten().tolist()
recovered_accept_nums = recover_share_inputs["accept_num_cpu"][:real_bsz].flatten().tolist()
total_masks = len(sampler_output.sampling_mask)
sampling_mask_groups = []
offset = 0
for n in raw_accept_nums:
    n = max(int(n), 0)
    sampling_mask_groups.append(sampler_output.sampling_mask[offset : offset + n])
    offset += n
if offset != total_masks:
    raise ValueError(f"sampling_mask length mismatch: expected {offset}, got {total_masks}")
recovered_sampling_mask_groups = [[] for _ in range(real_bsz)]
if model_output.index_to_batch_id is None:
    batch_id_map = list(range(real_bsz))
else:
    batch_id_map = np.asarray(model_output.index_to_batch_id[:real_bsz]).flatten().tolist()
for i, group in enumerate(sampling_mask_groups):
    batch_id = int(batch_id_map[i])
    if batch_id < 0 or batch_id >= real_bsz:
        raise ValueError(f"sampling_mask batch_id out of range: {batch_id}, real_bsz={real_bsz}")
    recovered_sampling_mask_groups[batch_id] = group
mask_dict = {}
for i, n in enumerate(recovered_accept_nums):
    n = max(int(n), 0)
    if len(recovered_sampling_mask_groups[i]) != n:
        raise ValueError(
            f"sampling_mask group size mismatch for batch {i}: "
            f"expected {n}, got {len(recovered_sampling_mask_groups[i])}"
        )
    if n > 0:
        # List of n sparse index arrays, one per accepted token.
        mask_dict[i] = [arr.tolist() for arr in recovered_sampling_mask_groups[i]]
```
```python
logz = paddle.to_tensor(logz, dtype=logprobs.dtype)
# Renormalize: log π_masked = log π_full - log Z_K
# Only normalize valid candidates; padding positions use -inf
valid_mask = paddle.isfinite(logprobs)
normalized_logprobs = paddle.where(
    valid_mask, logprobs - logz.unsqueeze(1), paddle.full_like(logprobs, float("-inf"))
)
# Update logprobs_tensors with normalized values
return LogprobsTensors(
    logprob_token_ids=logprobs_tensors.logprob_token_ids,
    logprobs=normalized_logprobs,
    selected_token_ranks=logprobs_tensors.selected_token_ranks,
)
```
logprobs_renormalize_with_logz currently subtracts logZ_K uniformly at every isfinite position, but the top-k entries in logprobs_tensors are taken from the full distribution's topk and may not all fall inside the candidate set K left after top_p/top_k truncation (especially when top_p is small and max_logprobs is large). As a result, the returned "renormalized logprobs" still contain finite values for tokens outside the candidate set, which violates the truncated-distribution semantics. Suggestion: use sampling_mask (or the candidate set) to set the logprobs of tokens outside K to -inf / None and renormalize only the entries inside K, or construct the logprobs output directly from the truncated distribution.
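A hedged sketch of the suggested direction; the helper name and input layout (per-row sparse index lists for the candidate set) are assumptions, not the PR's API:

```python
import paddle

def renormalize_within_candidates(logprob_token_ids, logprobs, kept_indices_per_row, logz):
    """logprob_token_ids/logprobs: [bsz, k] full-distribution topk outputs;
    kept_indices_per_row: sparse vocab indices of the candidate set K per row;
    logz: [bsz] log Z_K per request."""
    out = logprobs.numpy().copy()
    token_ids = logprob_token_ids.numpy()
    logz_np = logz.numpy()
    for row, kept in enumerate(kept_indices_per_row):
        kept_set = {int(i) for i in kept}
        for col, tok in enumerate(token_ids[row]):
            if int(tok) in kept_set:
                out[row, col] -= float(logz_np[row])  # renormalize inside K
            else:
                out[row, col] = float("-inf")         # outside K: excluded
    return paddle.to_tensor(out)
```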
Codecov Report

❌ Patch coverage is
Additional details and impacted files:

```
@@            Coverage Diff             @@
##           release/2.6    #7460   +/-   ##
==============================================
  Coverage             ?   73.61%
==============================================
  Files                ?      376
  Lines                ?    53120
  Branches             ?     8297
==============================================
  Hits                 ?    39102
  Misses               ?    11265
  Partials             ?     2753
```

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.
PaddlePaddle-bot left a comment

🤖 AI Code Review | 2026-04-17 19:08 CST

📋 Review Summary
PR overview: adds the Keep Sampling Mask (KSM) feature to FastDeploy, returning the retained vocabulary indices as a sparse list during top_p/top_k sampling, with support for logprobs renormalization.
Scope of changes: model_executor/layers/sample/ (core sampling logic), entrypoints/openai/ (API protocol and responses), worker/ (ZMQ channel transport), output/ (result processing).
Impact tags: OP APIServer Engine
📝 PR Convention Check
The [KSM] tag in the PR title [KSM] support keep sampling mask is not in the official tag list. Please use an official tag.
Suggested title (copy-paste ready):
[Feature] Support keep sampling mask for top_p/top_k sampling
Issues
| Level | File | Summary |
|---|---|---|
| 🔴 Bug | pre_and_post_process.py:456 | sampling_mask_zmq_client is not null-checked; paths such as metax that never initialize the ZMQ client will crash |
| 🔴 Bug | pre_and_post_process.py:639 | Same null-value risk in save_output_specualate |
| 🟡 Suggestion | serving_chat.py:669 | The deeply nested sampling_mask_list in the non-streaming path hurts readability; add a comment documenting the data structure |
| 🟡 Suggestion | logprobs.py:211 | build_output_logprobs removed its early return; speculate_get_target_logits still runs even when logprobs is disabled, which has a performance cost |
Overall assessment
The overall KSM design is sound: the sparse-index representation of the sampling mask is efficient, and the ZMQ side-channel transport and logprobs renormalization approach are clear. However, in save_output_normal and save_output_specualate, sampling_mask_zmq_client has no null-value guard, and callers such as metax_model_runner.py do not pass this parameter; enabling --enable-keep-sampling-mask on those hardware platforms will crash outright. This must be fixed before merging.
```python
# sampling_mask is List[np.ndarray] of sparse int indices, one array per request.
mask_dict = {i: arr.tolist() for i, arr in enumerate(sampler_output.sampling_mask)}
sampling_mask_zmq_client.send_pyobj(mask_dict)
```
🔴 Bug: sampling_mask_zmq_client may be None, and calling send_pyobj on it raises AttributeError.
The sampling_mask_zmq_client parameter of save_output_normal defaults to None, but this code only checks sampler_output.sampling_mask is not None and never checks whether the ZMQ client was initialized. metax_model_runner.py calls save_output_normal without passing sampling_mask_zmq_client at all.
If a user enables --enable-keep-sampling-mask on the metax platform and sampler_output.sampling_mask is not None, this will crash outright.
Suggested fix:

```python
if sampler_output.sampling_mask is not None and model_output.mp_rank == 0 and sampling_mask_zmq_client is not None:
```

```python
offset += n
if offset != total_masks:
    raise ValueError(f"sampling_mask length mismatch: expected {offset}, got {total_masks}")
sampling_mask_zmq_client.send_pyobj(mask_dict)
```
🔴 Bug: same issue as save_output_normal; sampling_mask_zmq_client is not null-checked.
Suggested fix:

```python
if sampler_output.sampling_mask is not None and model_output.mp_rank == 0 and sampling_mask_zmq_client is not None:
```
```python
# Adapt for sampling mask
if num_logprobs is None:
    return None, None, output_logits
```
🟡 Suggestion: when only keep_sampling_mask is enabled and enable_logprob is not (num_logprobs is None), build_output_logprobs still executes the full logit-extraction logic (including GPU kernel calls such as speculate_get_target_logits on the speculative path) and only returns early here.
This is needed to obtain output_logits for sampling-mask computation and is logically correct, but a comment should be added here explaining it: even when num_logprobs is None, the preceding logit extraction must still run so the caller receives output_logits for _compute_sampling_mask. Otherwise a future maintainer may mistakenly move the early return back to its original position.
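One possible wording for that comment, placed directly above the early return (the surrounding code is quoted from the diff above):

```python
# NOTE: even when num_logprobs is None, the logit-extraction work above
# (including speculate_get_target_logits on the speculative path) must still
# run, because output_logits is consumed by _compute_sampling_mask when
# keep_sampling_mask is enabled. Do not move this early return back up.
if num_logprobs is None:
    return None, None, output_logits
```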
```python
)
if prompt_logprobs_res:
    prompt_logprobs_res_list[idx].extend(clamp_prompt_logprobs(prompt_logprobs_res))
output_sampling_mask = output.get("sampling_mask", None)
```
🟡 Suggestion: here sampling_mask_list[idx].append(self._make_sampling_mask_list(...)) produces a three-level nested structure: sampling_mask_list[idx] has type List[List[List[int]]], i.e. [step1: [[idx,...],[idx,...]], step2: [[idx,...]], ...].
Later, _create_chat_completion_choice flattens it to List[List[int]] via [mask for step in sampling_mask_list[idx] for mask in step], which is correct but not intuitive. Either add a comment here documenting the data structure and the downstream flattening, or use extend to avoid the extra nesting level:

```python
if output_sampling_mask is not None:
    sampling_mask_list[idx].extend(self._make_sampling_mask_list(output_sampling_mask))
```

This way sampling_mask_list[idx] has type List[List[int]] and no downstream flattening is needed.
Pull request overview
Copilot reviewed 24 out of 24 changed files in this pull request and generated 5 comments.
Comments suppressed due to low confidence (1)
tests/output/test_process_batch_draft_tokens.py:39
- Here cfg.model_config is a MagicMock; if enable_keep_sampling_mask=False is not set explicitly, TokenProcessor may treat keep_sampling_mask as enabled and try to create/bind a ZMQ IPC server (the path contains the fixed "9700"), which easily conflicts under concurrent or repeated test runs. Set
enable_keep_sampling_mask = False on cfg.model_config (unless this case is actually meant to cover the feature and handles socket cleanup/isolation).
```python
# Mock cfg
cfg = MagicMock()
cfg.speculative_config = MagicMock()
cfg.parallel_config.local_data_parallel_id = 0
cfg.parallel_config.engine_worker_queue_port = ["9700"]
cfg.speculative_config.method = "mtp"
cfg.speculative_config.num_speculative_tokens = 3
cfg.model_config = MagicMock()
cfg.model_config.enable_logprob = True
```
```python
# where the value is a list[int] or list[list[int]] of allowed token ids
sampling_masks_per_request = {}
if self.use_sampling_mask and not envs.FD_USE_GET_SAVE_OUTPUT_V1 and hasattr(self, "sampling_mask_zmq_server"):
    _, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=True)
    if mask_data is not None and isinstance(mask_data, dict):
        sampling_masks_per_request = mask_data
```
This synchronously waits on the sampling_mask side-channel message with block=True, with no timeout or fallback path: if the worker never sends (e.g. the client was not created, the send failed, or some runner is not wired to this side channel), TokenProcessor blocks forever and the whole inference pipeline hangs. Switch to non-blocking polling (block=False) and let the step continue when the message is missing, or add a configurable timeout with an error log, to avoid deadlock.
Suggested change:

```python
# where the value is a list[int] or list[list[int]] of allowed token ids.
# Use a non-blocking receive so a missing side-channel message does not
# stall the whole token processing loop.
sampling_masks_per_request = {}
if self.use_sampling_mask and not envs.FD_USE_GET_SAVE_OUTPUT_V1 and hasattr(self, "sampling_mask_zmq_server"):
    mask_data = None
    try:
        _, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=False)
    except zmq.Again:
        mask_data = None
    except Exception:
        llm_logger.exception(
            "Failed to receive sampling_mask side-channel message; "
            "continuing without sampling mask for this step."
        )
        mask_data = None
    if mask_data is not None:
        if isinstance(mask_data, dict):
            sampling_masks_per_request = mask_data
        else:
            llm_logger.warning(
                "Ignore invalid sampling_mask side-channel payload type: %s",
                type(mask_data).__name__,
            )
```
```python
def setup_method(self):
    self.mock_cfg = MagicMock()
    self.mock_cfg.parallel_config.local_data_parallel_id = 0
    self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"]
```
When cfg/model_config here is a MagicMock, getattr(cfg.model_config, "enable_keep_sampling_mask", False) in TokenProcessor.__init__ returns a truthy MagicMock, so the unit test unexpectedly enables keep_sampling_mask and tries to bind the fixed IPC address (/dev/shm/sampling_mask_output_rank_0_9700.socket), which easily produces "Address already in use" errors or resource leaks under parallel/repeated runs. Explicitly set enable_keep_sampling_mask=False on mock_cfg.model_config (or patch envs.FD_USE_GET_SAVE_OUTPUT_V1=True to avoid creating the server).
Suggested change:

```python
self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"]
self.mock_cfg.model_config.enable_keep_sampling_mask = False
```
| """为 TokenProcessor 测试设置通用的 mock 对象。""" | ||
| self.mock_cfg = MagicMock() | ||
| self.mock_cfg.parallel_config.local_data_parallel_id = 0 | ||
| self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"] |
This setUp builds cfg with MagicMock and carries the same risk of enable_keep_sampling_mask being misread as a truthy MagicMock; TokenProcessor may unexpectedly create and bind the sampling_mask ZMQ IPC socket inside the unit test, causing port/file conflicts and test instability. Explicitly set self.mock_cfg.model_config.enable_keep_sampling_mask = False.
Suggested change:

```python
self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"]
self.mock_cfg.model_config.enable_keep_sampling_mask = False
```
```python
def setUp(self):
    self.cfg = MagicMock()
    self.cfg.model_config.enable_logprob = True
    self.cfg.speculative_config.method = None
    self.cfg.parallel_config.local_data_parallel_id = 0
    self.cfg.parallel_config.engine_worker_queue_port = ["9700"]
    self.cached_generated_tokens = MagicMock()
```
This test's cfg is built with MagicMock, so TokenProcessor may read enable_keep_sampling_mask as a truthy MagicMock at initialization and unexpectedly create and bind the sampling_mask ZMQ IPC server (fixed name/port) inside the unit test, causing conflicts between cases or resource leaks. Explicitly set enable_keep_sampling_mask=False on cfg.model_config.
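A minimal fix in the same spirit as the sibling suggestions above:

```python
self.cfg.model_config.enable_keep_sampling_mask = False
```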
```python
# Renormalize logprobs to match truncated sampling distribution (when enabled).
if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None:
    sampler_output.logprobs_tensors = logprobs_renormalize_with_logz(
        sampler_output.logprobs_tensors.logprobs,
        sampler_output.logz_per_batch,
        sampler_output.logprobs_tensors,
    )
```
When renormalizing logprobs here, take care not to double-normalize against the top_p_normalized_logprobs logic in Sampler.compute_logprobs; otherwise, when a request has top_p_normalized_logprobs enabled (top_p != 1.0), logZ is subtracted twice and the returned logprobs are numerically wrong. Decide per request/token whether top_p normalization has already been applied before applying logz_per_batch (or apply it only to rows that have not been normalized).
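One possible guard, sketched under the assumption that per-request top_p values are available on the sampling metadata (sampling_metadata.top_p and the broadcasting below are illustrative, not the PR's code):

```python
if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None:
    # Rows already top_p-normalized by Sampler.compute_logprobs skip the logZ
    # step by subtracting a zero logZ instead of logz_per_batch.
    already_normalized = (sampling_metadata.top_p != 1.0).flatten()
    effective_logz = paddle.where(
        already_normalized,
        paddle.zeros_like(sampler_output.logz_per_batch),  # no-op for these rows
        sampler_output.logz_per_batch,
    )
    sampler_output.logprobs_tensors = logprobs_renormalize_with_logz(
        sampler_output.logprobs_tensors.logprobs,
        effective_logz,
        sampler_output.logprobs_tensors,
    )
```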
Motivation
This PR implements the Keep Sampling Mask (KSM) feature for FastDeploy, returning the list of retained vocabulary indices (in sparse format) during top_p/top_k sampling.
When performing top_p/top_k sampling, the inference engine currently returns only the final sampled token ID and exposes no information about the candidate set used during sampling, which makes client-side interpretability and debugging analysis difficult.
This PR adds a sampling_mask field that records the vocabulary indices retained when each token is sampled (sparse format), and provides candidate-set-based logprobs renormalization.
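In formulas (matching the renormalization comment in the diffs above), with $K$ the candidate set retained by truncation and $\ell_j$ the full-vocabulary log-probabilities:

$$
\log Z_K = \log \sum_{j \in K} e^{\ell_j}, \qquad
\log \pi_{\text{masked}}(j) =
\begin{cases}
\ell_j - \log Z_K, & j \in K, \\
-\infty, & j \notin K.
\end{cases}
$$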
Modifications
- Add the _compute_sampling_mask method in sampler.py.
- Add the launch argument --enable-keep-sampling-mask.
- Add the logZ renormalization function logprobs_renormalize_with_logz in logprobs.py.
- Call the renormalization function in post_process of pre_and_post_process.py.
Usage or Command
Server launch command:
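For illustration only, assuming the standard FastDeploy OpenAI-compatible server entrypoint (the exact command used for this PR is not shown): `python -m fastdeploy.entrypoints.openai.api_server --model <model_path> --enable-keep-sampling-mask`.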
Accuracy Tests
yes
Checklist
- Use one of the official PR tags: [FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]
- Run pre-commit before commit.
- For a PR to a release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.