
[KSM] support keep sampling mask#7460

Open
zeroRains wants to merge 5 commits into PaddlePaddle:release/2.6 from zeroRains:ksm_2.6

Conversation


zeroRains (Contributor) commented Apr 17, 2026

Motivation

This PR implements the Keep Sampling Mask (KSM) feature for FastDeploy, which returns the list of vocabulary indices retained during top_p/top_k sampling (in sparse format).

When performing top_p/top_k sampling, the inference engine currently returns only the final sampled token ID and provides no information about the candidate set considered during sampling. This leads to:

  1. Limited interpretability: there is no way to know which candidate tokens the model considered when generating each token.
  2. Difficult debugging: it is hard to analyze the actual effect of a sampling strategy (e.g. top_p=0.9, top_k=50).
  3. Restricted downstream applications: custom post-processing logic (e.g. over logprobs) cannot be built on the candidate set.
  4. Incomplete normalization: the returned logprobs are not renormalized over the truncated candidate set.

This PR adds a sampling_mask field that records the vocabulary indices retained when sampling each token (sparse format), and provides logprobs renormalization over the candidate set.
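The renormalization described above can be sketched in a few lines. This is an illustrative toy in plain Python, not the FastDeploy implementation; the function name and inputs are made up for the example:

```python
import math

def renormalize_over_candidates(logits, kept_idx):
    """Log-probabilities over the truncated candidate set only.

    logits: full-vocabulary logit list for one decoding step.
    kept_idx: sparse list of vocab indices kept by top_p/top_k
              (this mirrors what one sampling_mask entry would contain).
    """
    kept = [logits[i] for i in kept_idx]
    m = max(kept)
    # log Z_K = logsumexp over the kept candidates only
    log_z = m + math.log(sum(math.exp(x - m) for x in kept))
    out = [float("-inf")] * len(logits)
    for i in kept_idx:
        out[i] = logits[i] - log_z  # log pi_masked = logit - log Z_K
    return out

# Two of four tokens survive truncation; their probabilities now sum to 1,
# and excluded tokens are -inf.
lp = renormalize_over_candidates([2.0, 1.0, 0.1, -3.0], [0, 1])
```

After renormalization, exponentiating the kept entries yields a proper distribution over the candidate set, which is exactly what point 4 above asks for.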

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)


Modifications

Added a _compute_sampling_mask method in sampler.py.
Added the launch flag --enable-keep-sampling-mask.
Added a logZ-based renormalization function, logprobs_renormalize_with_logz, in logprobs.py.
Called the renormalization function from post_process in pre_and_post_process.py.
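The PR page does not show the body of _compute_sampling_mask; as a rough, hypothetical sketch of what computing a sparse top_p mask involves (names and logic here are illustrative only, not the actual sampler.py code):

```python
import math

def sparse_top_p_mask(logits, top_p):
    """Hypothetical sketch: return the sparse list of vocabulary indices
    kept by top_p truncation for one decoding step."""
    # Softmax over the full vocabulary (shifted by the max for stability).
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    probs = [e / z for e in exps]
    # Sort indices by probability and keep the smallest prefix whose
    # cumulative mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    return sorted(kept)

mask = sparse_top_p_mask([2.0, 1.0, 0.1, -3.0], top_p=0.9)
```

The sparse representation (indices only) is what keeps the per-token payload small compared to a dense vocabulary-length boolean mask.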

Usage or Command

Server launch command:

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
MODEL_PATH="/root/paddlejob/tmpspace/GLM-4.5-Air/"
python -m fastdeploy.entrypoints.openai.api_server \
    --port 9293 \
    --host $(hostname -i) \
    --model "$MODEL_PATH" \
    --disable-custom-all-reduce \
    --tensor-parallel-size 8 \
    --max-model-len 131072 \
    --max-num-seqs 32 \
    --gpu-memory-utilization 0.9 \
    --graph-optimization-config '{"use_cudagraph":true}' \
    --enable-logprob \
    --enable-keep-sampling-mask \
    --speculative-config '{"method":"mtp","num_speculative_tokens":1,"num_model_steps":1,"model":"'$MODEL_PATH'"}'
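Once the server is up, a client can read the new field from the chat response. The payload fragment below is an assumption based on the protocol change described in this PR (sampling_mask: List[List[int]] per choice); the concrete values are invented for illustration:

```python
# Hypothetical response fragment; the field layout follows the protocol
# change described in this PR (sampling_mask: one sparse index list per
# generated token). Values are made up.
response = {
    "choices": [
        {
            "message": {"content": "Hi"},
            "sampling_mask": [[15, 42, 108], [7, 15]],
        }
    ]
}

choice = response["choices"][0]
# Token 0 kept 3 candidates after truncation; token 1 kept 2.
candidate_counts = [len(step) for step in choice["sampling_mask"]]
```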

Accuracy Tests

yes

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings April 17, 2026 06:02

paddle-bot bot commented Apr 17, 2026

Thanks for your contribution!

Contributor

Copilot AI left a comment


Pull request overview

This PR adds keep-sampling-mask output to the inference service: after top_p/top_k truncated sampling, the vocabulary indices retained at each step are returned (including in streaming responses) in sparse form, enabling client-side interpretability and debugging analysis, together with the corresponding CLI switch and end-to-end tests.

Changes:

  • Added the launch flag --enable-keep-sampling-mask and threaded the switch through Engine/Worker/Sampler/TokenProcessor/OpenAI serving.
  • Computed the sparse sampling_mask (and logZ) in the sampling stage; on the non-FD_USE_GET_SAVE_OUTPUT_V1 path it is sent to token_processor over a ZMQ side channel and then emitted in the OpenAI response.
  • Added/updated unit tests and e2e tests covering the format and consistency of sampling_mask in streaming and non-streaming responses.

Reviewed changes

Copilot reviewed 19 out of 19 changed files in this pull request and generated 8 comments.

File — Description
tests/output/test_process_batch_output.py — Initialize the new use_sampling_mask field on the processor constructed for tests.
tests/entrypoints/openai/test_max_streaming_tokens.py — Update calls to match the new sampling_mask_list parameter of the chat choice.
tests/e2e/test_ernie_21b_mtp.py — e2e: enable keep sampling mask in the launch arguments and add streaming/non-streaming/varied-top_p verification cases.
fastdeploy/worker/worker_process.py — Worker CLI adds --enable-keep-sampling-mask (with underscore and hyphen aliases).
fastdeploy/worker/output.py — SamplerOutput adds sampling_mask and logz_per_batch fields (sparse mask and logZ).
fastdeploy/worker/gpu_model_runner.py — Read the config switch; create the sampling_mask ZMQ client on the non-V1 path; pass keep_sampling_mask in prepare_inputs; pass sampling_mask_zmq_client through save_output.
fastdeploy/output/token_processor.py — Add a sampling_mask ZMQ server on the non-V1 path; receive the mask each step and write it into RequestOutput.outputs.
fastdeploy/output/stream_transfer_data.py — StreamTransferData adds a sampling_mask field to carry the sparse mask.
fastdeploy/model_executor/pre_and_post_process.py — Add sampling_mask to stream transfer data; add side-channel sends in save_output_*; add the logZ-based logprobs normalization step.
fastdeploy/model_executor/layers/sample/sampler.py — Add _compute_sampling_mask; the normal and speculative paths compute sampling_mask/logZ before sampling and write them into SamplerOutput.
fastdeploy/model_executor/layers/sample/meta_data.py — SamplingMetadata adds a keep_sampling_mask field.
fastdeploy/model_executor/layers/sample/logprobs.py — build_output_logprobs now also returns output_logits; add logprobs_renormalize_with_logz.
fastdeploy/entrypoints/openai/serving_chat.py — Emit sampling_mask in streaming/full responses; add _make_sampling_mask_list and flatten it when aggregating choices.
fastdeploy/entrypoints/openai/protocol.py — OpenAI protocol response models add a sampling_mask field (List[List[int]]).
fastdeploy/engine/request.py — CompletionOutput adds a sampling_mask field, included in the to_dict output.
fastdeploy/engine/engine.py — worker_store_true_flag adds enable_keep_sampling_mask; the switch is passed through when launching workers.
fastdeploy/engine/common_engine.py — Same as engine.py: pass enable_keep_sampling_mask through to the worker launch arguments.
fastdeploy/engine/args_utils.py — EngineArgs/CLI adds the --enable-keep-sampling-mask argument and its description.
fastdeploy/config.py — ModelConfig adds the enable_keep_sampling_mask default field.

Comment thread fastdeploy/entrypoints/openai/serving_chat.py
Comment thread fastdeploy/model_executor/layers/sample/logprobs.py
Comment thread fastdeploy/worker/output.py
Comment thread fastdeploy/output/stream_transfer_data.py
Comment thread fastdeploy/model_executor/layers/sample/sampler.py Outdated
Comment thread fastdeploy/output/token_processor.py
Comment thread fastdeploy/model_executor/pre_and_post_process.py
Comment thread fastdeploy/model_executor/pre_and_post_process.py
PaddlePaddle-bot

This comment was marked as outdated.

Copilot AI review requested due to automatic review settings April 17, 2026 07:30
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 19 out of 19 changed files in this pull request and generated 5 comments.

Comment on lines +625 to +638
    # Group by request using accept_num so each entry is List[np.ndarray] (n arrays per req).
    real_bsz = model_output.accept_num.shape[0]
    accept_nums = model_output.accept_num[:real_bsz].flatten().tolist()
    mask_dict = {}
    offset = 0
    total_masks = len(sampler_output.sampling_mask)
    for i, n in enumerate(accept_nums):
        n = max(int(n), 0)
        if n > 0:
            # List of n sparse index arrays, one per accepted token
            mask_dict[i] = [arr.tolist() for arr in sampler_output.sampling_mask[offset : offset + n]]
        offset += n
    if offset != total_masks:
        raise ValueError(f"sampling_mask length mismatch: expected {offset}, got {total_masks}")

Copilot AI Apr 17, 2026


The speculative path groups sampling_mask by model_output.accept_num and keys mask_dict by i, but the outputs of speculate_save_output(_topk) above are restored to the original batch order via index_to_batch_id + enable_pd_reorder. With PD reorder enabled, sampler_output.sampling_mask / accept_num / logz_per_batch are not reordered consistently here, so the keys/grouping of mask_dict will not match the batch_id on the token_processor side. Suggestion: before building mask_dict, recover/reorder accept_num and sampling_mask to match the output order (e.g. reuse recover_share_inputs["accept_num_cpu"] or extend recover_batch_index_for_sampler_output), and reorder logz_per_batch accordingly.

Suggested change

Current:

    # Group by request using accept_num so each entry is List[np.ndarray] (n arrays per req).
    real_bsz = model_output.accept_num.shape[0]
    accept_nums = model_output.accept_num[:real_bsz].flatten().tolist()
    mask_dict = {}
    offset = 0
    total_masks = len(sampler_output.sampling_mask)
    for i, n in enumerate(accept_nums):
        n = max(int(n), 0)
        if n > 0:
            # List of n sparse index arrays, one per accepted token
            mask_dict[i] = [arr.tolist() for arr in sampler_output.sampling_mask[offset : offset + n]]
        offset += n
    if offset != total_masks:
        raise ValueError(f"sampling_mask length mismatch: expected {offset}, got {total_masks}")

Proposed:

    # Recover it to the same batch order as speculate_save_output(_topk) before grouping by request.
    real_bsz = recover_share_inputs["accept_num_cpu"].shape[0]
    raw_accept_nums = model_output.accept_num[:real_bsz].flatten().tolist()
    recovered_accept_nums = recover_share_inputs["accept_num_cpu"][:real_bsz].flatten().tolist()
    total_masks = len(sampler_output.sampling_mask)
    sampling_mask_groups = []
    offset = 0
    for n in raw_accept_nums:
        n = max(int(n), 0)
        sampling_mask_groups.append(sampler_output.sampling_mask[offset : offset + n])
        offset += n
    if offset != total_masks:
        raise ValueError(f"sampling_mask length mismatch: expected {offset}, got {total_masks}")
    recovered_sampling_mask_groups = [[] for _ in range(real_bsz)]
    if model_output.index_to_batch_id is None:
        batch_id_map = list(range(real_bsz))
    else:
        batch_id_map = np.asarray(model_output.index_to_batch_id[:real_bsz]).flatten().tolist()
    for i, group in enumerate(sampling_mask_groups):
        batch_id = int(batch_id_map[i])
        if batch_id < 0 or batch_id >= real_bsz:
            raise ValueError(f"sampling_mask batch_id out of range: {batch_id}, real_bsz={real_bsz}")
        recovered_sampling_mask_groups[batch_id] = group
    mask_dict = {}
    for i, n in enumerate(recovered_accept_nums):
        n = max(int(n), 0)
        if len(recovered_sampling_mask_groups[i]) != n:
            raise ValueError(
                f"sampling_mask group size mismatch for batch {i}: "
                f"expected {n}, got {len(recovered_sampling_mask_groups[i])}"
            )
        if n > 0:
            # List of n sparse index arrays, one per accepted token.
            mask_dict[i] = [arr.tolist() for arr in recovered_sampling_mask_groups[i]]
Comment thread fastdeploy/model_executor/layers/sample/sampler.py
Comment on lines +234 to +246
logz = paddle.to_tensor(logz, dtype=logprobs.dtype)
# Renormalize: log π_masked = log π_full - log Z_K
# Only normalize valid candidates; padding positions use -inf
valid_mask = paddle.isfinite(logprobs)
normalized_logprobs = paddle.where(
    valid_mask, logprobs - logz.unsqueeze(1), paddle.full_like(logprobs, float("-inf"))
)
# Update logprobs_tensors with normalized values
return LogprobsTensors(
    logprob_token_ids=logprobs_tensors.logprob_token_ids,
    logprobs=normalized_logprobs,
    selected_token_ranks=logprobs_tensors.selected_token_ranks,
)

Copilot AI Apr 17, 2026


logprobs_renormalize_with_logz currently subtracts logZ_K uniformly at every isfinite position, but the top-k entries in logprobs_tensors are taken from the full distribution and may not all fall inside the candidate set K kept by top_p/top_k truncation (especially when top_p is small and max_logprobs is large). The returned "renormalized logprobs" then still contain finite values for tokens outside the candidate set, which violates the truncated-distribution semantics. Suggestion: use the sampling_mask (or candidate set) to set the logprobs of tokens outside K to -inf / None and renormalize only the entries inside K, or build the logprobs output directly from the truncated distribution.

Comment thread tests/e2e/test_ernie_21b_mtp.py
Comment thread fastdeploy/engine/args_utils.py
PaddlePaddle-bot

This comment was marked as outdated.

@codecov-commenter

codecov-commenter commented Apr 17, 2026

Codecov Report

❌ Patch coverage is 77.53623% with 31 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (release/2.6@185708b). Learn more about missing BASE report.

Files with missing lines Patch % Lines
fastdeploy/model_executor/pre_and_post_process.py 57.69% 6 Missing and 5 partials ⚠️
fastdeploy/entrypoints/openai/serving_chat.py 38.46% 6 Missing and 2 partials ⚠️
fastdeploy/output/token_processor.py 60.00% 4 Missing and 2 partials ⚠️
...astdeploy/model_executor/layers/sample/logprobs.py 55.55% 4 Missing ⚠️
fastdeploy/model_executor/layers/sample/sampler.py 96.22% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@              Coverage Diff               @@
##             release/2.6    #7460   +/-   ##
==============================================
  Coverage               ?   73.61%           
==============================================
  Files                  ?      376           
  Lines                  ?    53120           
  Branches               ?     8297           
==============================================
  Hits                   ?    39102           
  Misses                 ?    11265           
  Partials               ?     2753           
Flag Coverage Δ
GPU 73.61% <77.53%> (?)

Flags with carried forward coverage won't be shown.


PaddlePaddle-bot

This comment was marked as outdated.

Copilot AI review requested due to automatic review settings April 17, 2026 10:58

PaddlePaddle-bot left a comment


🤖 AI Code Review | 2026-04-17 19:08 CST

📋 Review Summary

PR overview: adds the Keep Sampling Mask (KSM) feature to FastDeploy, returning a sparse list of the vocabulary indices retained during top_p/top_k sampling, with support for logprobs renormalization.
Scope: model_executor/layers/sample/ (core sampling logic), entrypoints/openai/ (API protocol and responses), worker/ (ZMQ channel transport), output/ (result processing)
Impact tags: OP, APIServer, Engine

📝 PR Convention Check

The [KSM] tag in the PR title [KSM] support keep sampling mask is not in the official tag list. Please use an official tag.

Suggested title (ready to copy):

  • [Feature] Support keep sampling mask for top_p/top_k sampling

Issues

Severity | File | Summary
🔴 Bug | pre_and_post_process.py:456 | sampling_mask_zmq_client is not checked for None; paths that never initialize the ZMQ client (e.g. metax) will crash
🔴 Bug | pre_and_post_process.py:639 | Same None risk in save_output_specualate
🟡 Suggestion | serving_chat.py:669 | The deeply nested sampling_mask_list in the non-streaming path hurts readability; add a comment describing the data structure
🟡 Suggestion | logprobs.py:211 | build_output_logprobs removed its early return, so speculate_get_target_logits still runs when logprobs is disabled, costing performance

Overall Assessment

The KSM feature is well designed overall: the sparse-index representation of the sampling mask is efficient, and the ZMQ side-channel transport and logZ-based logprobs renormalization are clear. However, in save_output_normal and save_output_specualate, sampling_mask_zmq_client has no None guard, and callers such as metax_model_runner.py do not pass this argument; enabling --enable-keep-sampling-mask on those hardware platforms will crash outright. This must be fixed before merging.

# sampling_mask is List[np.ndarray] of sparse int indices, one array per request.
mask_dict = {i: arr.tolist() for i, arr in enumerate(sampler_output.sampling_mask)}

sampling_mask_zmq_client.send_pyobj(mask_dict)


🔴 Bug: sampling_mask_zmq_client may be None; calling send_pyobj on it raises AttributeError.

In save_output_normal the sampling_mask_zmq_client parameter defaults to None, but this code only checks sampler_output.sampling_mask is not None, not whether the ZMQ client was initialized. metax_model_runner.py calls save_output_normal without passing sampling_mask_zmq_client.

If a user enables --enable-keep-sampling-mask on the metax platform and sampler_output.sampling_mask is not None, this crashes outright.

Suggested fix:

if sampler_output.sampling_mask is not None and model_output.mp_rank == 0 and sampling_mask_zmq_client is not None:

offset += n
if offset != total_masks:
    raise ValueError(f"sampling_mask length mismatch: expected {offset}, got {total_masks}")
sampling_mask_zmq_client.send_pyobj(mask_dict)


🔴 Bug: the same issue as in save_output_normal — sampling_mask_zmq_client is not checked for None.

Suggested fix:

if sampler_output.sampling_mask is not None and model_output.mp_rank == 0 and sampling_mask_zmq_client is not None:


# Adapt for sampling mask
if num_logprobs is None:
    return None, None, output_logits


🟡 Suggestion: when only keep_sampling_mask is enabled and enable_logprob is not (num_logprobs is None), build_output_logprobs still runs the full logit-extraction logic (on the speculative path this includes GPU kernel calls such as speculate_get_target_logits) and only returns early at this point.

That is logically correct — the earlier logit extraction is needed to obtain output_logits for the sampling-mask computation — but a comment here should say so: even when num_logprobs is None, the preceding logit extraction must run so the caller gets output_logits for _compute_sampling_mask. Otherwise a future maintainer may move the early return back to its original position.

)
if prompt_logprobs_res:
    prompt_logprobs_res_list[idx].extend(clamp_prompt_logprobs(prompt_logprobs_res))
output_sampling_mask = output.get("sampling_mask", None)


🟡 Suggestion: sampling_mask_list[idx].append(self._make_sampling_mask_list(...)) here produces a three-level nested structure: sampling_mask_list[idx] has type List[List[List[int]]], i.e. [step1: [[idx,...],[idx,...]], step2: [[idx,...]], ...].

It is later flattened in _create_chat_completion_choice via [mask for step in sampling_mask_list[idx] for mask in step] into List[List[int]] — correct but not obvious. Add a comment here describing the data structure and the later flattening, or use extend to avoid the extra nesting level:

if output_sampling_mask is not None:
    sampling_mask_list[idx].extend(self._make_sampling_mask_list(output_sampling_mask))

That way sampling_mask_list[idx] has type List[List[int]] and no flattening is needed afterwards.
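The append-vs-extend nesting difference the comment describes can be seen in isolation (variable names here are illustrative, not the actual serving_chat.py code):

```python
per_step_masks = [[15, 42], [7]]  # sparse index lists for two generated tokens

appended: list = []
appended.append(per_step_masks)   # adds one extra level: List[List[List[int]]]

extended: list = []
extended.extend(per_step_masks)   # stays flat: List[List[int]]

# The flattening that _create_chat_completion_choice-style code must do
# when append was used:
flattened = [mask for step in appended for mask in step]
```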

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 24 out of 24 changed files in this pull request and generated 5 comments.

Comments suppressed due to low confidence (1)

tests/output/test_process_batch_draft_tokens.py:39

  • cfg.model_config here is a MagicMock; unless enable_keep_sampling_mask=False is set explicitly, TokenProcessor may treat keep_sampling_mask as enabled and try to create/bind a ZMQ IPC server (its path contains the fixed "9700"), which can conflict when tests run concurrently or repeatedly. Suggestion: set enable_keep_sampling_mask = False on cfg.model_config (unless this case is meant to cover the feature and handles socket cleanup/isolation).
        # Mock cfg
        cfg = MagicMock()
        cfg.speculative_config = MagicMock()
        cfg.parallel_config.local_data_parallel_id = 0
        cfg.parallel_config.engine_worker_queue_port = ["9700"]
        cfg.speculative_config.method = "mtp"
        cfg.speculative_config.num_speculative_tokens = 3
        cfg.model_config = MagicMock()
        cfg.model_config.enable_logprob = True

Comment on lines +749 to +755
# where the value is a list[int] or list[list[int]] of allowed token ids
sampling_masks_per_request = {}
if self.use_sampling_mask and not envs.FD_USE_GET_SAVE_OUTPUT_V1 and hasattr(self, "sampling_mask_zmq_server"):
    _, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=True)
    if mask_data is not None and isinstance(mask_data, dict):
        sampling_masks_per_request = mask_data


Copilot AI Apr 17, 2026


This blocks on the sampling_mask side-channel message with block=True and has no timeout/fallback path: if the worker never sends (e.g. the client was not created, a send failed, or some runner is not wired to this side channel), TokenProcessor blocks forever and inference hangs. Suggestion: poll non-blockingly (block=False) and let the step continue when the message is missing, or add a configurable timeout with an error log to avoid deadlock.

Suggested change

Current:

    # where the value is a list[int] or list[list[int]] of allowed token ids
    sampling_masks_per_request = {}
    if self.use_sampling_mask and not envs.FD_USE_GET_SAVE_OUTPUT_V1 and hasattr(self, "sampling_mask_zmq_server"):
        _, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=True)
        if mask_data is not None and isinstance(mask_data, dict):
            sampling_masks_per_request = mask_data

Proposed:

    # where the value is a list[int] or list[list[int]] of allowed token ids.
    # Use a non-blocking receive so a missing side-channel message does not
    # stall the whole token processing loop.
    sampling_masks_per_request = {}
    if self.use_sampling_mask and not envs.FD_USE_GET_SAVE_OUTPUT_V1 and hasattr(self, "sampling_mask_zmq_server"):
        mask_data = None
        try:
            _, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=False)
        except zmq.Again:
            mask_data = None
        except Exception:
            llm_logger.exception(
                "Failed to receive sampling_mask side-channel message; "
                "continuing without sampling mask for this step."
            )
            mask_data = None
        if mask_data is not None:
            if isinstance(mask_data, dict):
                sampling_masks_per_request = mask_data
            else:
                llm_logger.warning(
                    "Ignore invalid sampling_mask side-channel payload type: %s",
                    type(mask_data).__name__,
                )
def setup_method(self):
    self.mock_cfg = MagicMock()
    self.mock_cfg.parallel_config.local_data_parallel_id = 0
    self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"]

Copilot AI Apr 17, 2026


When cfg/model_config is a MagicMock here, getattr(cfg.model_config, "enable_keep_sampling_mask", False) in TokenProcessor.init returns a truthy MagicMock, so the unit test unexpectedly enables keep_sampling_mask and tries to bind the fixed IPC address (/dev/shm/sampling_mask_output_rank_0_9700.socket), which easily hits "Address already in use" or leaks resources under parallel/repeated runs. Suggestion: explicitly set enable_keep_sampling_mask=False on mock_cfg.model_config (or patch envs.FD_USE_GET_SAVE_OUTPUT_V1=True to avoid creating this server).

Suggested change
        self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"]
        self.mock_cfg.model_config.enable_keep_sampling_mask = False

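The MagicMock pitfall described above is easy to reproduce on its own:

```python
from unittest.mock import MagicMock

cfg = MagicMock()
# Attribute access on a MagicMock returns another MagicMock, which is truthy,
# so a getattr(..., default=False) guard does not help:
flag = getattr(cfg.model_config, "enable_keep_sampling_mask", False)
is_enabled = bool(flag)  # True, even though the test never enabled the flag

# Setting the attribute explicitly restores the intended behavior:
cfg.model_config.enable_keep_sampling_mask = False
fixed = bool(getattr(cfg.model_config, "enable_keep_sampling_mask", False))  # False
```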
    """Set up common mock objects for TokenProcessor tests."""
    self.mock_cfg = MagicMock()
    self.mock_cfg.parallel_config.local_data_parallel_id = 0
    self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"]

Copilot AI Apr 17, 2026


This setUp also risks enable_keep_sampling_mask being misread as True from the MagicMock cfg; TokenProcessor may then unexpectedly create and bind the sampling_mask ZMQ IPC socket in unit tests, causing port/file conflicts and flaky tests. Suggestion: explicitly set self.mock_cfg.model_config.enable_keep_sampling_mask = False.

Suggested change
        self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"]
        self.mock_cfg.model_config.enable_keep_sampling_mask = False

Comment on lines 29 to 35
def setUp(self):
    self.cfg = MagicMock()
    self.cfg.model_config.enable_logprob = True
    self.cfg.speculative_config.method = None
    self.cfg.parallel_config.local_data_parallel_id = 0
    self.cfg.parallel_config.engine_worker_queue_port = ["9700"]
    self.cached_generated_tokens = MagicMock()

Copilot AI Apr 17, 2026


This test's cfg is a MagicMock, so TokenProcessor initialization may read enable_keep_sampling_mask as a truthy MagicMock and unexpectedly create and bind the sampling_mask ZMQ IPC server (fixed name/port) in the unit test, causing conflicts between cases or resource leaks. Suggestion: explicitly set enable_keep_sampling_mask=False on cfg.model_config.

Comment on lines +380 to +386
# Renormalize logprobs to match truncated sampling distribution (when enabled).
if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None:
    sampler_output.logprobs_tensors = logprobs_renormalize_with_logz(
        sampler_output.logprobs_tensors.logprobs,
        sampler_output.logz_per_batch,
        sampler_output.logprobs_tensors,
    )

Copilot AI Apr 17, 2026


Renormalizing logprobs here must avoid double normalization with the top_p_normalized_logprobs logic in Sampler.compute_logprobs; otherwise, when a request already enables top_p_normalized_logprobs (top_p != 1.0), logZ is subtracted twice and the returned logprobs are numerically wrong. Suggestion: decide per request/token whether top_p normalization has already been applied before applying logz_per_batch (or apply it only to unnormalized rows).

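A toy example of why subtracting log Z twice is wrong: after one subtraction the values form a proper log-distribution; a second subtraction destroys normalization.

```python
import math

logits = [2.0, 1.0]
m = max(logits)
log_z = m + math.log(sum(math.exp(x - m) for x in logits))

once = [x - log_z for x in logits]    # normalized once: a valid log-distribution
twice = [x - log_z for x in once]     # log Z subtracted a second time

sum_once = sum(math.exp(x) for x in once)    # exactly 1.0
sum_twice = sum(math.exp(x) for x in twice)  # well below 1.0
```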