Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fastdeploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def __init__(
self.enable_logprob = False
self.max_logprobs = 20
self.logprobs_mode = "raw_logprobs"
self.enable_keep_sampling_mask = False
self.redundant_experts_num = 0
self.seed = 0
self.quantization = None
Expand Down
20 changes: 20 additions & 0 deletions fastdeploy/engine/args_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,14 @@ class EngineArgs:
Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
"""

enable_keep_sampling_mask: bool = False
"""
When enabled, the server returns a sparse index list for each generated token, indicating
which vocabulary positions were retained after top_p/top_k sampling, and streams it to
the client. In MTP (multi-token prediction) scenarios this field is a List[List[int]],
where each inner list contains the retained vocabulary indices for a predicted token.
"""

max_logprobs: int = 20
"""
Maximum number of log probabilities to return when `enable_logprob` is True. The default value comes the default for the
Expand Down Expand Up @@ -901,6 +909,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=EngineArgs.enable_logprob,
help="Enable output of token-level log probabilities.",
)
model_group.add_argument(
"--enable-keep-sampling-mask",
action="store_true",
default=EngineArgs.enable_keep_sampling_mask,
help=(
Comment thread
zeroRains marked this conversation as resolved.
"Enable output of sampling mask as a sparse index list over the vocabulary. "
"For non-MTP decoding, this is a list[int] per token step indicating which "
"vocabulary indices were kept after top_p/top_k sampling. "
"For MTP decoding, this is a list[list[int]] per token step, where each inner "
"list corresponds to one MTP group."
),
)
model_group.add_argument(
"--max-logprobs",
type=int,
Expand Down
1 change: 1 addition & 0 deletions fastdeploy/engine/common_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2491,6 +2491,7 @@ def _start_worker_service(self):
"moe_gate_fp32": self.cfg.model_config.moe_gate_fp32,
"enable_entropy": self.cfg.model_config.enable_entropy,
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
"enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask,
}
for worker_flag, value in worker_store_true_flag.items():
if value:
Expand Down
1 change: 1 addition & 0 deletions fastdeploy/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,7 @@ def _start_worker_service(self):
"enable_entropy": self.cfg.model_config.enable_entropy,
"ep_prefill_use_worst_num_tokens": self.cfg.parallel_config.ep_prefill_use_worst_num_tokens,
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
"enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask,
}
for worker_flag, value in worker_store_true_flag.items():
if value:
Expand Down
5 changes: 5 additions & 0 deletions fastdeploy/engine/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,10 @@ class CompletionOutput:
delta_message: Optional[DeltaMessage] = None
multipart: Optional[list[Any]] = None
num_image_tokens: Optional[int] = None
# Sparse indices of retained vocab ids:
# - Non-MTP: list[int]
# - MTP: list[list[int]]
sampling_mask: Optional[Any] = None

def to_dict(self):
"""
Expand All @@ -745,6 +749,7 @@ def to_dict(self):
"text": self.text,
"reasoning_content": self.reasoning_content,
"reasoning_token_num": self.reasoning_token_num,
"sampling_mask": self.sampling_mask,
}

@classmethod
Expand Down
5 changes: 5 additions & 0 deletions fastdeploy/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@ class ChatCompletionResponseChoice(BaseModel):
prompt_logprobs: Optional[PromptLogprobs] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]]
speculate_metrics: Optional[SpeculateMetrics] = None
# Per-token retained vocab indices from top_p/top_k sampling: List[List[int]], one list of vocab indices per token
sampling_mask: Optional[List[List[int]]] = None


class ChatCompletionResponse(BaseModel):
Expand Down Expand Up @@ -333,6 +335,9 @@ class ChatCompletionResponseStreamChoice(BaseModel):
logprobs: Optional[LogProbs] = None
draft_logprobs: Optional[LogProbs] = None
prompt_logprobs: Optional[PromptLogprobs] = None
# Per-token index list of retained positions after top_p sampling.
# Non-MTP: [[idx, ...]] (1 token/step). MTP: [[idx, ...], ...] (N accepted tokens/step).
sampling_mask: Optional[List[List[int]]] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] = None
arrival_time: Optional[float] = None
speculate_metrics: Optional[SpeculateMetrics] = None
Expand Down
32 changes: 32 additions & 0 deletions fastdeploy/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,11 @@ async def chat_completion_stream_generator(
delta=delta_message,
logprobs=logprobs_res,
draft_logprobs=draft_logprobs_res,
sampling_mask=(
self._make_sampling_mask_list(output["sampling_mask"])
if output.get("sampling_mask") is not None
else None
),
arrival_time=arrival_time,
speculate_metrics=output_speculate_metrics,
)
Expand Down Expand Up @@ -580,6 +585,7 @@ async def chat_completion_full_generator(
decoder_base_url=self.tokenizer_base_url,
)
prompt_logprobs_res_list = [[] for _ in range(num_choices)]
sampling_mask_list = [[] for _ in range(num_choices)]
speculate_metrics = [None for _ in range(num_choices)]
choices = []
while num_choices > 0:
Expand Down Expand Up @@ -660,6 +666,9 @@ async def chat_completion_full_generator(
)
if prompt_logprobs_res:
prompt_logprobs_res_list[idx].extend(clamp_prompt_logprobs(prompt_logprobs_res))
output_sampling_mask = output.get("sampling_mask", None)

This comment was marked as outdated.

if output_sampling_mask is not None:
sampling_mask_list[idx].append(self._make_sampling_mask_list(output_sampling_mask))
speculate_metrics[idx] = data["metrics"].get("speculate_metrics", None)
if data["finished"]:
trace_carrier = data.get("trace_carrier")
Expand Down Expand Up @@ -695,6 +704,7 @@ async def chat_completion_full_generator(
draft_logprob_contents=draft_logprob_contents,
response_processor=response_processor,
prompt_logprobs_res_list=prompt_logprobs_res_list,
sampling_mask_list=sampling_mask_list,
max_tokens=max_tokens,
speculate_metrics=speculate_metrics[idx],
)
Expand Down Expand Up @@ -749,6 +759,7 @@ async def _create_chat_completion_choice(
logprob_contents: list,
draft_logprob_contents: list,
prompt_logprobs_res_list: list,
sampling_mask_list: list,
response_processor: ChatResponseProcessor,
max_tokens: int,
speculate_metrics: SpeculateMetrics | None,
Expand Down Expand Up @@ -787,6 +798,11 @@ async def _create_chat_completion_choice(
if prompt_logprobs_res_list[idx]:
prompt_logprobs_full_res = prompt_logprobs_res_list[idx]

# Flatten per-step List[List[int]] into a single List[List[int]] over all tokens.
sampling_mask_full_res = None
if sampling_mask_list and sampling_mask_list[idx]:
sampling_mask_full_res = [mask for step in sampling_mask_list[idx] for mask in step]

num_cached_tokens[idx] = data.get("num_cached_tokens", 0)
num_input_image_tokens[idx] = data.get("num_input_image_tokens", 0)
num_input_video_tokens[idx] = data.get("num_input_video_tokens", 0)
Expand All @@ -810,6 +826,7 @@ async def _create_chat_completion_choice(
logprobs=logprobs_full_res,
draft_logprobs=draft_logprobs_full_res,
prompt_logprobs=prompt_logprobs_full_res,
sampling_mask=sampling_mask_full_res,
finish_reason=finish_reason,
speculate_metrics=speculate_metrics,
)
Expand Down Expand Up @@ -1000,3 +1017,18 @@ def _make_logprob_dict(
)
for token_id, logprob, rank, token in zip(logprob_token_ids, logprobs, ranks, decoded_tokens)
}

@staticmethod
def _make_sampling_mask_list(sampling_mask) -> List[List[int]]:
"""Wrap sampling_mask into a uniform List[List[int]] format.

sampling_mask is already in sparse-index form (no bool-to-index conversion needed):
Non-MTP: List[int] (indices for 1 token/step) → [[idx, ...]]
MTP: List[List[int]] (indices for N tokens/step) → [[idx, ...], ...]
"""
assert sampling_mask is not None
if sampling_mask and isinstance(sampling_mask[0], list):
Comment thread
zeroRains marked this conversation as resolved.
Comment thread
zeroRains marked this conversation as resolved.
Comment thread
zeroRains marked this conversation as resolved.
# MTP: already List[List[int]], return as-is
Comment thread
zeroRains marked this conversation as resolved.
return sampling_mask
# Non-MTP: already List[int], wrap in outer list for uniform format
return [sampling_mask]
37 changes: 31 additions & 6 deletions fastdeploy/model_executor/layers/sample/logprobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from typing import Callable, List, Optional, Tuple

import numpy as np
import paddle
import paddle.nn.functional as F
import triton
Expand Down Expand Up @@ -133,7 +134,7 @@ def build_output_logprobs(
is_naive: bool = False,
logprobs_mode: str = "default",
compute_logprobs_fn: Optional[Callable] = None,
) -> Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor]]:
) -> Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor], Optional[paddle.Tensor]]:
"""
Build logprobs output for both NAIVE and speculative (MTP/Ngram) modes.

Expand All @@ -153,15 +154,12 @@ def build_output_logprobs(
scaling and top_p normalization. Used when logprobs_mode == "raw_logprobs".

Returns:
tuple: (logprobs_tensors, cu_batch_token_offset)
tuple: (logprobs_tensors, cu_batch_token_offset, output_logits)
"""
num_logprobs = sampling_metadata.max_num_logprobs
logprobs_tensors = None
cu_batch_token_offset = None

if num_logprobs is None:
return logprobs_tensors, cu_batch_token_offset

real_bsz = share_inputs["seq_lens_this_time"].shape[0]

if is_naive:
Expand Down Expand Up @@ -208,6 +206,10 @@ def build_output_logprobs(
mask = idx < share_inputs["accept_num"].unsqueeze(1)
token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask)

# Adapt for sampling mask
Comment thread
zeroRains marked this conversation as resolved.
if num_logprobs is None:
return None, None, output_logits

This comment was marked as outdated.

This comment was marked as outdated.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

无影响,只有logprobs or sampling_mask为true的时候会执行这个函数,如果只有sampling_mask,仍然需要计算出logits,如果只有logrpobs,那也需要计算logits,我觉得不影响


# Compute logprobs with temperature scaling and top_p normalization
if logprobs_mode == "raw_logprobs":
raw_logprobs = compute_logprobs_fn(output_logits, sampling_metadata)
Expand All @@ -217,5 +219,28 @@ def build_output_logprobs(
raw_logprobs = F.log_softmax(output_logits, axis=-1)

logprobs_tensors = gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
# output_logits use to compute sampling_mask
return logprobs_tensors, cu_batch_token_offset, output_logits

return logprobs_tensors, cu_batch_token_offset

def logprobs_renormalize_with_logz(logprobs: paddle.Tensor, logz: np.ndarray, logprobs_tensors: LogprobsTensors):
    """
    Renormalize logprobs to match the truncated (top_p/top_k) sampling distribution.

    Subtracts the per-request log partition value ``log Z_K`` from every finite
    logprob so the returned values describe probabilities conditioned on the
    truncated candidate set K rather than the full vocabulary.

    Args:
        logprobs: tensor [B, max_num_logprobs + 1]; non-candidate / padding
            positions are expected to already hold -inf.
        logz: [B], log(sum(probs in candidate set K)) for each request.
        logprobs_tensors: LogprobsTensors whose token ids and ranks are reused
            unchanged; only the logprob values are replaced.

    Returns:
        A new LogprobsTensors with renormalized ``logprobs``; token ids and
        selected-token ranks are carried over from the input.

    NOTE(review): the top-k entries in `logprobs_tensors` are taken from the
    full-vocabulary distribution and may include tokens outside the truncated
    candidate set K (e.g. small top_p with large max_logprobs); those entries
    still receive finite renormalized values here — confirm whether they should
    instead be masked to -inf using the sampling mask.
    """
    # Convert the numpy log-partition values onto the same device/dtype as logprobs.
    logz = paddle.to_tensor(logz, dtype=logprobs.dtype)
    # Renormalize: log π_masked = log π_full - log Z_K
    # Only normalize valid candidates; padding positions use -inf
    valid_mask = paddle.isfinite(logprobs)
    normalized_logprobs = paddle.where(
        valid_mask, logprobs - logz.unsqueeze(1), paddle.full_like(logprobs, float("-inf"))
    )
    # Update logprobs_tensors with normalized values
    return LogprobsTensors(
        logprob_token_ids=logprobs_tensors.logprob_token_ids,
        logprobs=normalized_logprobs,
        selected_token_ranks=logprobs_tensors.selected_token_ranks,
    )
Comment on lines +234 to +246
Copy link

Copilot AI Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logprobs_renormalize_with_logz 目前对所有 isfinite 的位置统一做 logprobs - logZ_K,但 logprobs_tensors 里的 top-k 项是从“全量分布”topk 取出的,未必全部落在 top_p/top_k 截断后的候选集合 K 内(尤其当 top_p 很小且 max_logprobs 较大时)。这会导致返回的“重归一化 logprobs”仍包含候选集之外 token 的有限值,不符合截断分布语义。建议结合 sampling_mask(或 candidate set)把不在 K 内的 token logprobs 置为 -inf / None,并仅对 K 内条目做重归一化,或改为直接在截断后的分布上构造 logprobs 输出。

Copilot uses AI. Check for mistakes.
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/layers/sample/meta_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,5 @@ class SamplingMetadata:
# Add for HPU post-processing
seq_lens_encoder: Optional[paddle.Tensor] = None
seq_lens_decoder: Optional[paddle.Tensor] = None
# Add for keep sampling mask
keep_sampling_mask: Optional[bool] = None
Loading
Loading