diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 8dc18403608..8750222c137 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -211,6 +211,7 @@ def __init__( self.enable_logprob = False self.max_logprobs = 20 self.logprobs_mode = "raw_logprobs" + self.enable_keep_sampling_mask = False self.redundant_experts_num = 0 self.seed = 0 self.quantization = None diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index d350350f85d..e4d054b9924 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -460,6 +460,14 @@ class EngineArgs: Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values. """ + enable_keep_sampling_mask: bool = False + """ + When enabled, the server returns a sparse index list for each generated token, indicating + which vocabulary positions were retained after top_p/top_k sampling, and streams it to + the client. In MTP (multi-token prediction) scenarios this field is a List[List[int]], + where each inner list contains the retained vocabulary indices for a predicted token. + """ + max_logprobs: int = 20 """ Maximum number of log probabilities to return when `enable_logprob` is True. The default value comes the default for the @@ -901,6 +909,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.enable_logprob, help="Enable output of token-level log probabilities.", ) + model_group.add_argument( + "--enable-keep-sampling-mask", + action="store_true", + default=EngineArgs.enable_keep_sampling_mask, + help=( + "Enable output of sampling mask as a sparse index list over the vocabulary. " + "For non-MTP decoding, this is a list[int] per token step indicating which " + "vocabulary indices were kept after top_p/top_k sampling. " + "For MTP decoding, this is a list[list[int]] per token step, where each inner " + "list corresponds to one MTP group." 
+ ), + ) model_group.add_argument( "--max-logprobs", type=int, diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index dabed9e4342..3eae43c3c32 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -2491,6 +2491,7 @@ def _start_worker_service(self): "moe_gate_fp32": self.cfg.model_config.moe_gate_fp32, "enable_entropy": self.cfg.model_config.enable_entropy, "enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule, + "enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask, } for worker_flag, value in worker_store_true_flag.items(): if value: diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 283693fae8c..44edea80d34 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -655,6 +655,7 @@ def _start_worker_service(self): "enable_entropy": self.cfg.model_config.enable_entropy, "ep_prefill_use_worst_num_tokens": self.cfg.parallel_config.ep_prefill_use_worst_num_tokens, "enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule, + "enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask, } for worker_flag, value in worker_store_true_flag.items(): if value: diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 0e95cd5e1fb..ccab1ac4114 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -727,6 +727,10 @@ class CompletionOutput: delta_message: Optional[DeltaMessage] = None multipart: Optional[list[Any]] = None num_image_tokens: Optional[int] = None + # Sparse indices of retained vocab ids: + # - Non-MTP: list[int] + # - MTP: list[list[int]] + sampling_mask: Optional[Any] = None def to_dict(self): """ @@ -745,6 +749,7 @@ def to_dict(self): "text": self.text, "reasoning_content": self.reasoning_content, "reasoning_token_num": self.reasoning_token_num, + "sampling_mask": self.sampling_mask, } @classmethod diff --git 
a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index 3560f3a8aef..42923623776 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -270,6 +270,8 @@ class ChatCompletionResponseChoice(BaseModel): prompt_logprobs: Optional[PromptLogprobs] = None finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] speculate_metrics: Optional[SpeculateMetrics] = None + # Per-token retained vocab indices from top_p/top_k sampling: List[List[int]], one list of vocab indices per token + sampling_mask: Optional[List[List[int]]] = None class ChatCompletionResponse(BaseModel): @@ -333,6 +335,9 @@ class ChatCompletionResponseStreamChoice(BaseModel): logprobs: Optional[LogProbs] = None draft_logprobs: Optional[LogProbs] = None prompt_logprobs: Optional[PromptLogprobs] = None + # Per-token index list of retained positions after top_p sampling. + # Non-MTP: [[idx, ...]] (1 token/step). MTP: [[idx, ...], ...] (N accepted tokens/step). 
+ sampling_mask: Optional[List[List[int]]] = None finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] = None arrival_time: Optional[float] = None speculate_metrics: Optional[SpeculateMetrics] = None diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index eb106f6550f..55bd37412a0 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -435,6 +435,11 @@ async def chat_completion_stream_generator( delta=delta_message, logprobs=logprobs_res, draft_logprobs=draft_logprobs_res, + sampling_mask=( + self._make_sampling_mask_list(output["sampling_mask"]) + if output.get("sampling_mask") is not None + else None + ), arrival_time=arrival_time, speculate_metrics=output_speculate_metrics, ) @@ -580,6 +585,7 @@ async def chat_completion_full_generator( decoder_base_url=self.tokenizer_base_url, ) prompt_logprobs_res_list = [[] for _ in range(num_choices)] + sampling_mask_list = [[] for _ in range(num_choices)] speculate_metrics = [None for _ in range(num_choices)] choices = [] while num_choices > 0: @@ -660,6 +666,9 @@ async def chat_completion_full_generator( ) if prompt_logprobs_res: prompt_logprobs_res_list[idx].extend(clamp_prompt_logprobs(prompt_logprobs_res)) + output_sampling_mask = output.get("sampling_mask", None) + if output_sampling_mask is not None: + sampling_mask_list[idx].append(self._make_sampling_mask_list(output_sampling_mask)) speculate_metrics[idx] = data["metrics"].get("speculate_metrics", None) if data["finished"]: trace_carrier = data.get("trace_carrier") @@ -695,6 +704,7 @@ async def chat_completion_full_generator( draft_logprob_contents=draft_logprob_contents, response_processor=response_processor, prompt_logprobs_res_list=prompt_logprobs_res_list, + sampling_mask_list=sampling_mask_list, max_tokens=max_tokens, speculate_metrics=speculate_metrics[idx], ) @@ -749,6 +759,7 @@ async def 
_create_chat_completion_choice( logprob_contents: list, draft_logprob_contents: list, prompt_logprobs_res_list: list, + sampling_mask_list: list, response_processor: ChatResponseProcessor, max_tokens: int, speculate_metrics: SpeculateMetrics | None, @@ -787,6 +798,11 @@ async def _create_chat_completion_choice( if prompt_logprobs_res_list[idx]: prompt_logprobs_full_res = prompt_logprobs_res_list[idx] + # Flatten per-step List[List[int]] into a single List[List[int]] over all tokens. + sampling_mask_full_res = None + if sampling_mask_list and sampling_mask_list[idx]: + sampling_mask_full_res = [mask for step in sampling_mask_list[idx] for mask in step] + num_cached_tokens[idx] = data.get("num_cached_tokens", 0) num_input_image_tokens[idx] = data.get("num_input_image_tokens", 0) num_input_video_tokens[idx] = data.get("num_input_video_tokens", 0) @@ -810,6 +826,7 @@ async def _create_chat_completion_choice( logprobs=logprobs_full_res, draft_logprobs=draft_logprobs_full_res, prompt_logprobs=prompt_logprobs_full_res, + sampling_mask=sampling_mask_full_res, finish_reason=finish_reason, speculate_metrics=speculate_metrics, ) @@ -1000,3 +1017,18 @@ def _make_logprob_dict( ) for token_id, logprob, rank, token in zip(logprob_token_ids, logprobs, ranks, decoded_tokens) } + + @staticmethod + def _make_sampling_mask_list(sampling_mask) -> List[List[int]]: + """Wrap sampling_mask into a uniform List[List[int]] format. + + sampling_mask is already in sparse-index form (no bool-to-index conversion needed): + Non-MTP: List[int] (indices for 1 token/step) → [[idx, ...]] + MTP: List[List[int]] (indices for N tokens/step) → [[idx, ...], ...] 
+ """ + assert sampling_mask is not None + if sampling_mask and isinstance(sampling_mask[0], list): + # MTP: already List[List[int]], return as-is + return sampling_mask + # Non-MTP: already List[int], wrap in outer list for uniform format + return [sampling_mask] diff --git a/fastdeploy/model_executor/layers/sample/logprobs.py b/fastdeploy/model_executor/layers/sample/logprobs.py index 559abdb298e..84b02b1a68e 100644 --- a/fastdeploy/model_executor/layers/sample/logprobs.py +++ b/fastdeploy/model_executor/layers/sample/logprobs.py @@ -16,6 +16,7 @@ from typing import Callable, List, Optional, Tuple +import numpy as np import paddle import paddle.nn.functional as F import triton @@ -133,7 +134,7 @@ def build_output_logprobs( is_naive: bool = False, logprobs_mode: str = "default", compute_logprobs_fn: Optional[Callable] = None, -) -> Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor]]: +) -> Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor], Optional[paddle.Tensor]]: """ Build logprobs output for both NAIVE and speculative (MTP/Ngram) modes. @@ -153,15 +154,12 @@ def build_output_logprobs( scaling and top_p normalization. Used when logprobs_mode == "raw_logprobs". 
Returns: - tuple: (logprobs_tensors, cu_batch_token_offset) + tuple: (logprobs_tensors, cu_batch_token_offset, output_logits) """ num_logprobs = sampling_metadata.max_num_logprobs logprobs_tensors = None cu_batch_token_offset = None - if num_logprobs is None: - return logprobs_tensors, cu_batch_token_offset - real_bsz = share_inputs["seq_lens_this_time"].shape[0] if is_naive: @@ -208,6 +206,10 @@ def build_output_logprobs( mask = idx < share_inputs["accept_num"].unsqueeze(1) token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask) + # Adapt for sampling mask + if num_logprobs is None: + return None, None, output_logits + # Compute logprobs with temperature scaling and top_p normalization if logprobs_mode == "raw_logprobs": raw_logprobs = compute_logprobs_fn(output_logits, sampling_metadata) @@ -217,5 +219,28 @@ def build_output_logprobs( raw_logprobs = F.log_softmax(output_logits, axis=-1) logprobs_tensors = gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids) + # output_logits use to compute sampling_mask + return logprobs_tensors, cu_batch_token_offset, output_logits - return logprobs_tensors, cu_batch_token_offset + +def logprobs_renormalize_with_logz(logprobs: paddle.Tensor, logz: np.ndarray, logprobs_tensors: LogprobsTensors): + """ + Renormalize logprobs to match truncated sampling distribution. 
+ Args: + logprobs: tensor [B, max_num_logprobs + 1] + logz: [B], log(sum(probs in candidate set K)) for each request + logprobs_tensors: LogprobsTensors + """ + logz = paddle.to_tensor(logz, dtype=logprobs.dtype) + # Renormalize: log π_masked = log π_full - log Z_K + # Only normalize valid candidates; padding positions use -inf + valid_mask = paddle.isfinite(logprobs) + normalized_logprobs = paddle.where( + valid_mask, logprobs - logz.unsqueeze(1), paddle.full_like(logprobs, float("-inf")) + ) + # Update logprobs_tensors with normalized values + return LogprobsTensors( + logprob_token_ids=logprobs_tensors.logprob_token_ids, + logprobs=normalized_logprobs, + selected_token_ranks=logprobs_tensors.selected_token_ranks, + ) diff --git a/fastdeploy/model_executor/layers/sample/meta_data.py b/fastdeploy/model_executor/layers/sample/meta_data.py index 0d7f6915ab4..e2ecb276957 100644 --- a/fastdeploy/model_executor/layers/sample/meta_data.py +++ b/fastdeploy/model_executor/layers/sample/meta_data.py @@ -66,3 +66,5 @@ class SamplingMetadata: # Add for HPU post-processing seq_lens_encoder: Optional[paddle.Tensor] = None seq_lens_decoder: Optional[paddle.Tensor] = None + # Add for keep sampling mask + keep_sampling_mask: Optional[bool] = None diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 08a33c11096..c0f63ffc136 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -19,6 +19,7 @@ from concurrent.futures import Future, ThreadPoolExecutor from typing import Any, List, Optional +import numpy as np import paddle import paddle.nn.functional as F from paddle import nn @@ -105,6 +106,129 @@ def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_le return top_p_padding, top_k_padding, topp_seed +def _compute_sampling_mask( + probs: paddle.Tensor, + top_p: paddle.Tensor, + top_k: Optional[paddle.Tensor] = None, + 
top_k_list: Optional[list] = None, +) -> tuple[List[np.ndarray], np.ndarray]: + """ + Compute a combined top-k + top-p (nucleus) sampling mask as sparse + retained-token indices. + + Processing order: + 1. Sort probs descending once (shared by top-k and top-p stages). + 2. top-k mask — zero out positions beyond top_k[i] in sorted order. + 3. top-k renorm — renormalise in-place after truncation. + 4. top-p mask — cumsum on the already-sorted renormed probs; no + second argsort needed. + 5. intersect — AND of the two masks, applied on GPU before D2H. + + Either filter can be disabled: + - top-k is skipped when top_k_list is None or all values <= 0. + - top-p[i] >= 1.0 → keep all tokens for that request. + + Args: + probs: [num_reqs, vocab_size] softmax probabilities (GPU). + top_p: [num_reqs, 1] top-p threshold per request (GPU). + top_k: [num_reqs, 1] top-k per request (GPU, int); 0 = disabled. + top_k_list: Python list of top-k values; used to decide whether any + top-k filtering is needed at all. + + Returns: + Tuple of (sparse_indices, logz_per_batch): + - sparse_indices: List of length num_reqs; element i is a 1-D int64 + numpy array of the retained vocab indices for request i. + - logz_per_batch: 1-D numpy array of shape [num_reqs] containing + log(Z_K) where Z_K is the sum of probabilities in the candidate set. + """ + real_bsz = probs.shape[0] + vocab_size = probs.shape[1] + top_p = top_p[:real_bsz] # [B, 1] + + has_top_k = top_k is not None and top_k_list and any(x > 0 for x in top_k_list) + + # ------------------------------------------------------------------ + # Stage 1: single sort — descending by probability. + # sorted_indices / sorted_probs are reused by both top-k and top-p. 
+ # ------------------------------------------------------------------ + sorted_indices = paddle.argsort(probs, axis=-1, descending=True) # [B, V] + sorted_probs = paddle.take_along_axis(probs, sorted_indices, axis=-1) # [B, V] + + # ------------------------------------------------------------------ + # Stage 2: top-k mask (GPU, no D2H) + # ------------------------------------------------------------------ + if has_top_k: + top_k = top_k[:real_bsz] # [B, 1] + # top_k == 0 means "disabled" → keep all columns for that row. + effective_k = paddle.where(top_k > 0, top_k, paddle.full_like(top_k, vocab_size)) + + # Relax: also keep positions whose prob ties with the k-th element. + # boundary index (0-based) = effective_k - 1, clamped to [0, V-1]. + k_idx = (effective_k - 1).clip(min=0).squeeze(-1).astype("int64") # [B] k-th index + batch_idx = paddle.arange(k_idx.shape[0], dtype="int64") # [B] bs index + boundary_prob = sorted_probs[batch_idx, k_idx].unsqueeze(-1) # [B, 1] min_probs in topk candidates + topk_mask = sorted_probs >= boundary_prob # [B, V] True = retained by top-k + + # Zero out tail, then renorm row-wise. + masked_sorted_probs = paddle.where(topk_mask, sorted_probs, paddle.zeros_like(sorted_probs)) + row_sums = masked_sorted_probs.sum(axis=-1, keepdim=True).clip(min=1e-9) + renorm_sorted_probs = masked_sorted_probs / row_sums # [B, V] + else: + topk_mask = None + renorm_sorted_probs = sorted_probs + + # ------------------------------------------------------------------ + # Stage 3: top-p mask on already-sorted renormed probs (no re-sort). + # ------------------------------------------------------------------ + cum_probs = paddle.cumsum(renorm_sorted_probs, axis=-1) # [B, V] + topp_mask = (cum_probs - renorm_sorted_probs) <= top_p # [B, V] + # When top_p[i] >= 1.0, keep the entire row. 
+ topp_mask = paddle.where( + (top_p >= 1.0).expand_as(topp_mask), + paddle.ones_like(topp_mask), + topp_mask, + ) + + # Extend mask to cover sort tie-breaking: include all tokens whose + # probability >= the boundary token's probability (last retained + # in sorted order). In descending-sorted probs this just extends + # the contiguous True block by the run of equal-prob tokens. + k_per_row = topp_mask.astype("int32").sum(axis=-1, keepdim=True) # [B,1] + # boundary_idx = last True position (k-1), clamp for safety + boundary_idx = (k_per_row - 1).clip(min=0) # [B, 1] + boundary_prob = paddle.take_along_axis( + renorm_sorted_probs, + boundary_idx, + axis=-1, + ) # [B, 1] + topp_mask = topp_mask | (renorm_sorted_probs >= boundary_prob) + + # ------------------------------------------------------------------ + # Stage 4: intersect on GPU, then minimal D2H. + # ------------------------------------------------------------------ + final_mask = topk_mask & topp_mask if has_top_k else topp_mask # [B, V] + + k_per_row = final_mask.astype("int32").sum(axis=-1) # [B] + max_k = int(k_per_row.max().item()) + + # ------------------------------------------------------------------ + # Stage 5: compute logZ_K for renormalization + # Z_K = sum(probs[i] * final_mask[i]) for each request i + # logZ_K = log(Z_K), with small constant to avoid log(0) + # ------------------------------------------------------------------ + candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs)) + z_k = candidate_probs.sum(axis=-1) # [B] + logz_per_batch = paddle.log(z_k + 1e-10).cpu().numpy() # [B] + + # Transfer only the leading max_k columns — typically max_k << vocab_size. 
+ indices_window_cpu = sorted_indices[:, :max_k].cpu().numpy() # [B, max_k] + mask_window_cpu = final_mask[:, :max_k].cpu().numpy() # [B, max_k] + + sparse_indices = [indices_window_cpu[i, mask_window_cpu[i]] for i in range(real_bsz)] + return sparse_indices, logz_per_batch + + class GuidedDecoding: """ processor for guided decoding. """ @@ -554,6 +678,19 @@ def forward_cuda( _record_logits_diagnostic(logits, tag="post_penalty_logits", probs=probs) probs = min_p_sampling(probs, sampling_metadata.min_p, sampling_metadata.min_p_list) + + # Compute sampling mask BEFORE top_k_top_p_sampling modifies probs. + # Sparse result: per-request arrays of vocab indices retained by top_k/top_p, plus per-request logZ. + sampling_mask = None + logz_per_batch = None + if sampling_metadata.keep_sampling_mask: + sampling_mask, logz_per_batch = _compute_sampling_mask( + probs, + sampling_metadata.top_p, + top_k=sampling_metadata.top_k, + top_k_list=sampling_metadata.top_k_list, + ) + _, next_tokens = top_k_top_p_sampling( probs, sampling_metadata.top_p, @@ -577,6 +714,8 @@ def forward_cuda( sampled_token_ids=next_tokens, logprobs_tensors=logprobs_tensors, logits=logits, + sampling_mask=sampling_mask, + logz_per_batch=logz_per_batch, ) return sampler_output @@ -1029,9 +1168,10 @@ def forward_cuda( reject_all_drafts, ) + keep_sampling_mask = sampling_metadata.keep_sampling_mask # Build logprobs via unified path (outside of sampling logic) - if sampling_metadata.max_num_logprobs is not None: - logprobs_tensors, cu_batch_token_offset = build_output_logprobs( + if sampling_metadata.max_num_logprobs is not None or keep_sampling_mask: + logprobs_tensors, cu_batch_token_offset, target_logits = build_output_logprobs( logits, sampling_metadata, share_inputs, @@ -1042,6 +1182,33 @@ def forward_cuda( sampler_output.logprobs_tensors = logprobs_tensors if cu_batch_token_offset is not None: sampler_output.cu_batch_token_offset = cu_batch_token_offset.cpu() + if keep_sampling_mask: + real_bsz = 
share_inputs["seq_lens_this_time"].shape[0] + accept_nums = share_inputs["accept_num"][:real_bsz].reshape([-1]) + # Derive target probs from already-extracted target_logits; avoids a second kernel call. + target_probs = F.softmax(target_logits, axis=-1) + # Compute sampling mask at accepted token positions. + # Shape: [total_accepted_tokens, vocab_size], bool (CPU). + # Expand top_p from [batch, 1] to [total_accepted, 1]. + # total_accepted = accept_nums.sum() + accept_top_p = ( + sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(accept_nums).unsqueeze(1) + ) + accept_top_k = None + if ( + sampling_metadata.top_k is not None + and sampling_metadata.top_k_list + and any(x > 0 for x in sampling_metadata.top_k_list) + ): + accept_top_k = ( + sampling_metadata.top_k[:real_bsz].squeeze(1).repeat_interleave(accept_nums).unsqueeze(1) + ) + sampler_output.sampling_mask, sampler_output.logz_per_batch = _compute_sampling_mask( + target_probs, + accept_top_p, + top_k=accept_top_k, + top_k_list=sampling_metadata.top_k_list, + ) return sampler_output def forward_xpu( diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 29fc4235381..9e14133146c 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -22,6 +22,7 @@ from fastdeploy import envs from fastdeploy.config import SpeculativeConfig +from fastdeploy.inter_communicator import ZmqIpcClient from fastdeploy.model_executor.ops.gpu import ( mtp_save_first_token, mtp_save_first_token_with_topk, @@ -114,6 +115,9 @@ from fastdeploy.model_executor.layers.moe.routing_indices_cache import ( RoutingReplayManager, ) +from fastdeploy.model_executor.layers.sample.logprobs import ( + logprobs_renormalize_with_logz, +) from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata from fastdeploy.output.pooler import PoolerOutput, PoolingSequenceGroupOutput from 
fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData @@ -221,6 +225,7 @@ def _build_stream_transfer_data( pooler_outputs: List[PoolingSequenceGroupOutput] = None, logprobs: Optional[LogprobsTensors] = None, prompt_logprobs_list: Optional[LogprobsTensors] = None, + sampling_mask: Optional[List[np.ndarray]] = None, ): """Split output_tokens and output""" @@ -230,6 +235,8 @@ def _build_stream_transfer_data( output_tokens = output_tokens.numpy().reshape([-1]) output_tokens_lists = np.split(output_tokens, output_tokens.shape[0]) + sampling_mask_list = sampling_mask + for bid, output_token_per_sample in enumerate(output_tokens_lists): stream_transfer_data = StreamTransferData( decoder_state=DecoderState.TEXT, tokens=output_token_per_sample, batch_id=bid @@ -238,6 +245,8 @@ def _build_stream_transfer_data( stream_transfer_data.logprobs = logprobs.slice_rows(bid, bid + 1) if prompt_logprobs_list: stream_transfer_data.prompt_logprobs = prompt_logprobs_list[bid] + if sampling_mask_list is not None: + stream_transfer_data.sampling_mask = sampling_mask_list[bid] stream_transfer_datas.append(stream_transfer_data) elif pooler_outputs is not None: for bid, pooler_output in enumerate(pooler_outputs): @@ -373,6 +382,14 @@ def post_process_normal( model_output.is_block_step, ) + # Renormalize logprobs to match truncated sampling distribution (when enabled). 
+ if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None: + sampler_output.logprobs_tensors = logprobs_renormalize_with_logz( + sampler_output.logprobs_tensors.logprobs, + sampler_output.logz_per_batch, + sampler_output.logprobs_tensors, + ) + def save_output_normal( model_output: ModelOutputData, @@ -380,6 +397,7 @@ def save_output_normal( share_inputs: Dict[str, paddle.Tensor], async_output_queue: queue.Queue = None, save_each_rank: bool = False, + sampling_mask_zmq_client: Optional[ZmqIpcClient] = None, ): # Transmit the model's output and stop generation signal via message queue. # In the future, we will abandon this approach. @@ -398,6 +416,7 @@ def save_output_normal( recover_share_inputs_map["sampled_token_ids"], logprobs=sampler_output.logprobs_tensors, prompt_logprobs_list=model_output.prompt_logprobs_list, + sampling_mask=sampler_output.sampling_mask, ) async_output_queue.put(output) else: @@ -434,6 +453,12 @@ def save_output_normal( recover_share_inputs_map["last_preempted_idx"], model_output.mp_rank, ) + # Send sampling_mask via ZMQ side-channel when enabled. + if sampler_output.sampling_mask is not None and model_output.mp_rank == 0: + # sampling_mask is List[np.ndarray] of sparse int indices, one array per request. + mask_dict = {i: arr.tolist() for i, arr in enumerate(sampler_output.sampling_mask)} + + sampling_mask_zmq_client.send_pyobj(mask_dict) share_inputs["last_preempted_idx"][:] = 0 @@ -525,6 +550,14 @@ def post_process_specualate( model_output.max_dec_len, # max_dec_len ) + # Renormalize logprobs to match truncated sampling distribution (when enabled). 
+ if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None: + sampler_output.logprobs_tensors = logprobs_renormalize_with_logz( + sampler_output.logprobs_tensors.logprobs, + sampler_output.logz_per_batch, + sampler_output.logprobs_tensors, + ) + def save_output_specualate( sampler_output: SamplerOutput, @@ -534,6 +567,7 @@ def save_output_specualate( local_rank: int, tensor_parallel_rank: int, save_each_rank: bool = False, + sampling_mask_zmq_client: ZmqIpcClient = None, is_mtp_prefill: bool = False, ): if is_mtp_prefill: @@ -656,6 +690,24 @@ def save_output_specualate( model_output.mp_rank, save_each_rank, ) + # Send sampling_mask via ZMQ side-channel when enabled. + if sampler_output.sampling_mask is not None and model_output.mp_rank == 0: + # sampling_mask is List[np.ndarray] of sparse int indices, length = total_accepted_tokens. + # Group by request using accept_num so each entry is List[np.ndarray] (n arrays per req). + real_bsz = model_output.accept_num.shape[0] + accept_nums = model_output.accept_num[:real_bsz].flatten().tolist() + mask_dict = {} + offset = 0 + total_masks = len(sampler_output.sampling_mask) + for i, n in enumerate(accept_nums): + n = max(int(n), 0) + if n > 0: + # List of n sparse index arrays, one per accepted token + mask_dict[i] = [arr.tolist() for arr in sampler_output.sampling_mask[offset : offset + n]] + offset += n + if offset != total_masks: + raise ValueError(f"sampling_mask length mismatch: expected {offset}, got {total_masks}") + sampling_mask_zmq_client.send_pyobj(mask_dict) share_inputs["last_preempted_idx"][:] = 0 diff --git a/fastdeploy/output/stream_transfer_data.py b/fastdeploy/output/stream_transfer_data.py index b32e01c954f..dce21bb5963 100644 --- a/fastdeploy/output/stream_transfer_data.py +++ b/fastdeploy/output/stream_transfer_data.py @@ -46,3 +46,7 @@ class StreamTransferData: accept_num: Optional[np.array] = None # [num_reqs, hidden_size] pooler_output: Optional[np.array] = None + 
# 1-D int64 numpy array of vocab indices retained by top_p/top_k for + # this request. Sparse format: only retained positions, not a dense + # vocab-sized bool mask. + sampling_mask: Optional[np.array] = None diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 85e54647b7e..7195bf83aa8 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -83,6 +83,14 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.speculative_decoding = self.cfg.speculative_config.method is not None self.use_logprobs = self.cfg.model_config.enable_logprob + self.use_sampling_mask = getattr(self.cfg.model_config, "enable_keep_sampling_mask", False) + if not envs.FD_USE_GET_SAVE_OUTPUT_V1 and self.use_sampling_mask: + rank_id = self.cfg.parallel_config.local_data_parallel_id + port = self.cfg.parallel_config.engine_worker_queue_port[rank_id] + self.sampling_mask_zmq_server = ZmqIpcServer( + name=f"sampling_mask_output_rank_{rank_id}_{port}", mode=zmq.PULL + ) + llm_logger.info(f"create zmq sampling_mask_output_rank_{rank_id}_{port}") self.enable_draft_logprob = self.cfg.speculative_config.enable_draft_logprob if self.speculative_decoding: @@ -357,6 +365,8 @@ def _process_batch_output_use_zmq(self, receive_datas): result.prompt_logprobs = stream_data.prompt_logprobs except Exception as e: llm_logger.warning(f"Failed to parse prompt_logprobs from StreamTransferData: {e}") + if getattr(stream_data, "sampling_mask", None) is not None: + result.outputs.sampling_mask = stream_data.sampling_mask.tolist() if self.tokens_counter[task_id] == 0: if task.messages is not None: result.prompt = task.messages @@ -734,6 +744,15 @@ def _process_batch_output(self): batch = self.output_tokens[1, 0] tokens = tokens[2 : batch + 2] + # Receive sampling constraints per request from ZMQ side-channel (if enabled). 
+ # The worker sends a dict {batch_id: sparse_vocab_indices} each step, + # where the value is a list[int] or list[list[int]] of allowed token ids + sampling_masks_per_request = {} + if self.use_sampling_mask and not envs.FD_USE_GET_SAVE_OUTPUT_V1 and hasattr(self, "sampling_mask_zmq_server"): + _, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=True) + if mask_data is not None and isinstance(mask_data, dict): + sampling_masks_per_request = mask_data + batch_result = list() # reschedule for i in range(batch): @@ -868,6 +887,9 @@ def _process_batch_output(self): result.num_input_image_tokens = task.multimodal_inputs.get("num_input_image_tokens", 0) result.num_input_video_tokens = task.multimodal_inputs.get("num_input_video_tokens", 0) + if self.use_sampling_mask and i in sampling_masks_per_request: + result.outputs.sampling_mask = sampling_masks_per_request[i] + if is_prefill and len(token_ids) > 1: result.outputs.draft_token_ids = copy.deepcopy(token_ids) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 43478e1a817..679711e2970 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -126,6 +126,7 @@ def __init__( self.spec_method = self.fd_config.speculative_config.method self.speculative_decoding = self.spec_method is not None self.enable_logprob = fd_config.model_config.enable_logprob + self.enable_keep_sampling_mask = fd_config.model_config.enable_keep_sampling_mask self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop self.is_pooling_model = self.fd_config.model_config.runner_type == "pooling" self.ori_vocab_size = self.fd_config.model_config.ori_vocab_size @@ -236,6 +237,17 @@ def __init__( # Rollout routing replay config self.routing_replay_manager = None + # ZMQ side-channel for sampling_mask in non-FD_USE_GET_SAVE_OUTPUT_V1 path + self.sampling_mask_zmq_client = None + if not envs.FD_USE_GET_SAVE_OUTPUT_V1 and 
self.enable_keep_sampling_mask: + rank_id = self.parallel_config.local_data_parallel_id + port = self.parallel_config.engine_worker_queue_port[rank_id] + self.sampling_mask_zmq_client = ZmqIpcClient( + name=f"sampling_mask_output_rank_{rank_id}_{port}", mode=zmq.PUSH + ) + self.sampling_mask_zmq_client.connect() + logger.info(f"create send zmq sampling_mask_output_rank_{rank_id}_{port}") + self.zmq_client = None self.async_output_queue = None if envs.FD_USE_GET_SAVE_OUTPUT_V1: @@ -1233,6 +1245,7 @@ def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_p top_p_normalized_logprobs=self.share_inputs["top_p_normalized_logprobs"], logits_processors=self.share_inputs["logits_processors"], share_inputs=self.share_inputs, + keep_sampling_mask=self.enable_keep_sampling_mask, ) return token_num, token_num_event @@ -2486,6 +2499,7 @@ def _save_model_output( local_rank=self.local_rank, tensor_parallel_rank=self.parallel_config.tensor_parallel_rank, save_each_rank=self.parallel_config.use_ep, + sampling_mask_zmq_client=self.sampling_mask_zmq_client, is_mtp_prefill=( self.spec_method == SpecMethod.MTP and self.scheduler_config.splitwise_role == "prefill" ), @@ -2497,6 +2511,7 @@ def _save_model_output( share_inputs=self.share_inputs, async_output_queue=self.async_output_queue, save_each_rank=self.parallel_config.use_ep, + sampling_mask_zmq_client=self.sampling_mask_zmq_client, ) def _pool(self, hidden_states: paddle.Tensor, num_running_requests: int) -> Optional[ModelRunnerOutput]: diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py index 365fec12475..44cc9cb9e16 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -15,8 +15,9 @@ """ from dataclasses import dataclass, field -from typing import NamedTuple, Optional +from typing import List, NamedTuple, Optional +import numpy as np import paddle @@ -178,6 +179,24 @@ class SamplerOutput: token_num_per_batch: Optional[paddle.Tensor] = None cu_batch_token_offset: 
Optional[paddle.Tensor] = None logits: Optional[paddle.Tensor] = None + # Sparse sampling mask for top_p/top_k: + # - Non-speculative decoding: per-request mask. This is a list of length + # num_reqs, where element i is a 1-D int32 numpy array of vocab indices + # retained by top_p/top_k for request i. Replaces the previous dense + # [num_reqs, vocab_size] bool tensor. + # - Speculative decoding: flattened per-accepted-token mask. This may be + # stored as a list aligned with all accepted tokens + # (e.g. length = total_accepted_tokens) and is regrouped by accept_num + # (number of accepted tokens per request) in post-processing before + # being sent back as per-request data. + # Callers MUST NOT assume this is always shaped by num_reqs; they should + # check whether the current path is speculative or non-speculative when + # interpreting the dimension. + sampling_mask: Optional[List[np.ndarray]] = None + # logZ_K for each request: log(sum(probs in candidate set K)) + # Used for renormalizing logprobs to match the truncated sampling distribution. + # Shape: [num_reqs] + logz_per_batch: Optional[np.ndarray] = None @dataclass diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 3f2a1fcf0dd..be4883bdefb 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -1095,6 +1095,16 @@ def parse_args(): help="Maximum tokens per item in mm input.", ) + parser.add_argument( + "--enable_keep_sampling_mask", + "--enable-keep-sampling-mask", + action="store_true", + help=( + "Enable output of keep_sampling_mask as sparse vocab index list per token step " + "(Non-MTP: List[int]; MTP: List[List[int]])." 
+ ), + ) + parser.add_argument( "--num_cpu_blocks", type=int, diff --git a/tests/e2e/test_ernie_21b_mtp.py b/tests/e2e/test_ernie_21b_mtp.py index dc60a213217..0ac4ec789af 100644 --- a/tests/e2e/test_ernie_21b_mtp.py +++ b/tests/e2e/test_ernie_21b_mtp.py @@ -83,6 +83,7 @@ def setup_and_run_server(): json.dumps(speculative_config), "--graph-optimization-config", '{"use_cudagraph":true, "use_unique_memory_pool":true, "draft_model_use_cudagraph":true}', + "--enable-keep-sampling-mask", ] # Start subprocess in new process group @@ -366,3 +367,176 @@ def test_mtp_accept_ratio(api_url): prompt_tokens = chunks[-1]["usage"]["prompt_tokens"] cached_tokens = chunks[-1]["usage"]["prompt_tokens_details"]["cached_tokens"] assert cached_tokens == prompt_tokens // 64 * 64, "cached_tokens数量有问题" + + +def _assert_sampling_mask_format(sampling_mask, max_tokens): + """验证 sampling_mask 字段格式的公共辅助函数。 + + sampling_mask 是 List[List[int]]: + - 外层列表长度 == 生成的 token 数(completion_tokens),对应 MTP 每步可接受多个 token + - 内层列表为保留位置的词汇表索引(int),非空且单调递增 + """ + assert sampling_mask is not None, "sampling_mask 不应为 None" + assert isinstance(sampling_mask, list), "sampling_mask 应为 list" + assert len(sampling_mask) > 0, "sampling_mask 不应为空" + assert len(sampling_mask) <= max_tokens, "sampling_mask 长度不应超过 max_tokens" + + for token_mask in sampling_mask: + assert isinstance(token_mask, list), f"每个 token 的 mask 应为 list,实际: {type(token_mask)}" + assert len(token_mask) > 0, "每个 token 的 mask 不应为空(至少保留采样到的 token)" + for idx in token_mask: + assert isinstance(idx, int), f"mask 中的每个元素应为 int,实际: {type(idx)}" + assert idx >= 0, f"mask 索引不应为负数,实际: {idx}" + + +def test_keep_sampling_mask_stream(api_url): + """测试流式响应中 keep_sampling_mask 功能(MTP 模式)。 + + 验证: + 1. 每个非空 chunk 的 choices[0].sampling_mask 格式为 List[List[int]] + 2. 内层列表包含词汇表保留位置的索引,非空且单调递增 + 3. 
最终 sampling_mask 总长度等于 completion_tokens + """ + max_tokens = 20 + payload = { + "model": "default", + "temperature": 1.0, + "top_p": 0.9, + "seed": 42, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "请用一句话介绍Python语言。"}, + ], + "max_tokens": max_tokens, + "stream": True, + "stream_options": {"include_usage": True}, + } + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + + assert len(chunks) > 1, "流式响应应包含至少两个 chunk" + + all_sampling_masks = [] + for chunk in chunks[:-1]: # 最后一个 chunk 是 usage-only + choice = chunk["choices"][0] + # 仅当 delta 有实际内容时才应携带 sampling_mask(首个 role chunk 内容为空,不含该字段) + has_content = bool(choice.get("delta", {}).get("content")) + mask = choice.get("sampling_mask") + if has_content: + assert mask is not None, f"有内容的 chunk 缺少 sampling_mask 字段: {choice}" + if mask is not None: + assert isinstance(mask, list), f"sampling_mask 应为 list,实际: {type(mask)}" + for token_mask in mask: + assert isinstance(token_mask, list), "每个 token mask 应为 list" + assert len(token_mask) > 0, "每个 token mask 不应为空" + for idx in token_mask: + assert isinstance(idx, int) and idx >= 0, f"mask 索引应为非负 int,实际: {idx}" + all_sampling_masks.extend(mask) + + # 最后一个 chunk 携带 usage 信息 + usage = chunks[-1].get("usage") + if usage: + completion_tokens = usage["completion_tokens"] + assert ( + len(all_sampling_masks) == completion_tokens + ), f"sampling_mask 总长度 {len(all_sampling_masks)} 应等于 completion_tokens {completion_tokens}" + + +def test_keep_sampling_mask_non_stream(api_url): + """测试非流式响应中 keep_sampling_mask 功能(MTP 模式)。 + + 验证: + 1. choices[0].sampling_mask 格式为 List[List[int]] + 2. 长度等于 completion_tokens + 3. 
内层列表包含非负的词汇表索引
"default", + "temperature": 1.0, + "top_p": top_p, + "seed": 42, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "请列举三种编程语言。"}, + ], + "max_tokens": max_tokens, + "stream": False, + } + resp = send_request(url=api_url, payload=payload).json() + mask = resp["choices"][0].get("sampling_mask") + if not mask: + return 0 + return sum(len(m) for m in mask) / len(mask) + + avg_small = get_avg_mask_len(0.1) + avg_large = get_avg_mask_len(0.9) + assert avg_small <= avg_large, f"top_p=0.1 的平均 mask 长度 ({avg_small:.1f}) 应 <= top_p=0.9 ({avg_large:.1f})" diff --git a/tests/entrypoints/openai/test_max_streaming_tokens.py b/tests/entrypoints/openai/test_max_streaming_tokens.py index d98e79b74f2..bd7b6482b09 100644 --- a/tests/entrypoints/openai/test_max_streaming_tokens.py +++ b/tests/entrypoints/openai/test_max_streaming_tokens.py @@ -577,6 +577,7 @@ async def test_create_chat_completion_choice(self): response_processor=mock_response_processor, max_tokens=max_tokens_list[idx], speculate_metrics=None, + sampling_mask_list=None, ) expected = case["expected"] diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 1b33405503f..12f20f39eab 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -398,6 +398,7 @@ async def test_create_chat_completion_choice_audio_recover(self): response_processor=response_processor, max_tokens=2, speculate_metrics=None, + sampling_mask_list=None, ) self.assertEqual(choice.finish_reason, "recover_stop") @@ -421,6 +422,7 @@ async def test_create_chat_completion_choice_audio_recover(self): response_processor=response_processor, max_tokens=2, speculate_metrics=None, + sampling_mask_list=None, ) self.assertEqual(choice_length.finish_reason, "length") diff --git a/tests/metrics/test_new_metrics.py b/tests/metrics/test_new_metrics.py index 030acaf4299..f650d6d7d7c 100644 --- 
a/tests/metrics/test_new_metrics.py +++ b/tests/metrics/test_new_metrics.py @@ -54,6 +54,8 @@ def test_cache_metrics_update_history(self, mock_main_process_metrics): def setUp(self): """为 TokenProcessor 测试设置通用的 mock 对象。""" self.mock_cfg = MagicMock() + self.mock_cfg.parallel_config.local_data_parallel_id = 0 + self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"] self.mock_cached_generated_tokens = MagicMock() self.mock_engine_worker_queue = MagicMock() self.mock_split_connector = MagicMock() diff --git a/tests/output/test_process_batch_draft_tokens.py b/tests/output/test_process_batch_draft_tokens.py index 3686dd1b64b..eef5df62cc9 100644 --- a/tests/output/test_process_batch_draft_tokens.py +++ b/tests/output/test_process_batch_draft_tokens.py @@ -30,6 +30,8 @@ def setUp(self): # 模拟 cfg cfg = MagicMock() cfg.speculative_config = MagicMock() + cfg.parallel_config.local_data_parallel_id = 0 + cfg.parallel_config.engine_worker_queue_port = ["9700"] cfg.speculative_config.method = "mtp" cfg.speculative_config.num_speculative_tokens = 3 cfg.model_config = MagicMock() diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py index 46282cd386a..9398e07d9f5 100644 --- a/tests/output/test_process_batch_output.py +++ b/tests/output/test_process_batch_output.py @@ -166,6 +166,7 @@ def setup_token_processor(self, speculative_decoding=False, use_logprobs=False): processor.total_step_per_request = {} processor.accept_token_num_per_head_per_request = {} processor.accept_token_num_per_head = [0] * MAX_DRAFT_TOKENS + processor.use_sampling_mask = False # processor._recycle_resources = Mock() diff --git a/tests/output/test_process_batch_output_use_zmq.py b/tests/output/test_process_batch_output_use_zmq.py index 07826e6f0eb..8244bb06bbf 100644 --- a/tests/output/test_process_batch_output_use_zmq.py +++ b/tests/output/test_process_batch_output_use_zmq.py @@ -31,6 +31,7 @@ def setUp(self): self.cfg.model_config.enable_logprob = True 
self.cfg.speculative_config.method = None self.cfg.parallel_config.local_data_parallel_id = 0 + self.cfg.parallel_config.engine_worker_queue_port = ["9700"] self.cached_generated_tokens = MagicMock() self.engine_worker_queue = MagicMock() self.split_connector = MagicMock() diff --git a/tests/output/test_token_processor_trace_print.py b/tests/output/test_token_processor_trace_print.py index 9ba9b45dfae..018038143f3 100644 --- a/tests/output/test_token_processor_trace_print.py +++ b/tests/output/test_token_processor_trace_print.py @@ -23,6 +23,8 @@ class TestTokenProcessorMetrics: def setup_method(self): self.mock_cfg = MagicMock() + self.mock_cfg.parallel_config.local_data_parallel_id = 0 + self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"] self.mock_cached_tokens = MagicMock() self.mock_engine_queue = MagicMock() self.mock_split_connector = MagicMock()