-
Notifications
You must be signed in to change notification settings - Fork 740
[KSM] support keep sampling mask #7460
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: release/2.6
Are you sure you want to change the base?
Changes from all commits
9283c71
83b7333
1a16f7d
59c3b94
f9a6e44
e8cb3dc
6530c94
6731ab5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |
|
|
||
| from typing import Callable, List, Optional, Tuple | ||
|
|
||
| import numpy as np | ||
| import paddle | ||
| import paddle.nn.functional as F | ||
| import triton | ||
|
|
@@ -133,7 +134,7 @@ def build_output_logprobs( | |
| is_naive: bool = False, | ||
| logprobs_mode: str = "default", | ||
| compute_logprobs_fn: Optional[Callable] = None, | ||
| ) -> Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor]]: | ||
| ) -> Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor], Optional[paddle.Tensor]]: | ||
| """ | ||
| Build logprobs output for both NAIVE and speculative (MTP/Ngram) modes. | ||
|
|
||
|
|
@@ -153,15 +154,12 @@ def build_output_logprobs( | |
| scaling and top_p normalization. Used when logprobs_mode == "raw_logprobs". | ||
|
|
||
| Returns: | ||
| tuple: (logprobs_tensors, cu_batch_token_offset) | ||
| tuple: (logprobs_tensors, cu_batch_token_offset, output_logits) | ||
| """ | ||
| num_logprobs = sampling_metadata.max_num_logprobs | ||
| logprobs_tensors = None | ||
| cu_batch_token_offset = None | ||
|
|
||
| if num_logprobs is None: | ||
| return logprobs_tensors, cu_batch_token_offset | ||
|
|
||
| real_bsz = share_inputs["seq_lens_this_time"].shape[0] | ||
|
|
||
| if is_naive: | ||
|
|
@@ -208,6 +206,10 @@ def build_output_logprobs( | |
| mask = idx < share_inputs["accept_num"].unsqueeze(1) | ||
| token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask) | ||
|
|
||
| # Adapt for sampling mask | ||
|
zeroRains marked this conversation as resolved.
|
||
| if num_logprobs is None: | ||
| return None, None, output_logits | ||
This comment was marked as outdated.
Sorry, something went wrong.
This comment was marked as outdated.
Sorry, something went wrong.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 无影响,只有logprobs or sampling_mask为true的时候会执行这个函数,如果只有sampling_mask,仍然需要计算出logits,如果只有logprobs,那也需要计算logits,我觉得不影响 |
||
|
|
||
| # Compute logprobs with temperature scaling and top_p normalization | ||
| if logprobs_mode == "raw_logprobs": | ||
| raw_logprobs = compute_logprobs_fn(output_logits, sampling_metadata) | ||
|
|
@@ -217,5 +219,28 @@ def build_output_logprobs( | |
| raw_logprobs = F.log_softmax(output_logits, axis=-1) | ||
|
|
||
| logprobs_tensors = gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids) | ||
| # output_logits is used to compute sampling_mask | ||
| return logprobs_tensors, cu_batch_token_offset, output_logits | ||
|
|
||
| return logprobs_tensors, cu_batch_token_offset | ||
|
|
||
| def logprobs_renormalize_with_logz(logprobs: paddle.Tensor, logz: np.ndarray, logprobs_tensors: LogprobsTensors): | ||
| """ | ||
| Renormalize logprobs to match truncated sampling distribution. | ||
| Args: | ||
| logprobs: tensor [B, max_num_logprobs + 1] | ||
| logz: [B], log(sum(probs in candidate set K)) for each request | ||
| logprobs_tensors: LogprobsTensors | ||
| """ | ||
| logz = paddle.to_tensor(logz, dtype=logprobs.dtype) | ||
| # Renormalize: log π_masked = log π_full - log Z_K | ||
| # Only normalize valid candidates; padding positions use -inf | ||
| valid_mask = paddle.isfinite(logprobs) | ||
| normalized_logprobs = paddle.where( | ||
| valid_mask, logprobs - logz.unsqueeze(1), paddle.full_like(logprobs, float("-inf")) | ||
|
zeroRains marked this conversation as resolved.
|
||
| ) | ||
| # Update logprobs_tensors with normalized values | ||
| return LogprobsTensors( | ||
| logprob_token_ids=logprobs_tensors.logprob_token_ids, | ||
| logprobs=normalized_logprobs, | ||
| selected_token_ranks=logprobs_tensors.selected_token_ranks, | ||
| ) | ||
|
Comment on lines
+234
to
+246
|
||
Uh oh!
There was an error while loading. Please reload this page.