-
Notifications
You must be signed in to change notification settings - Fork 302
feat: support top-k / top-p sampling with trainer-side replay #2601
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
mikasenghaas
wants to merge
7
commits into
main
Choose a base branch
from
feat/top-k-p-sampling
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+147
−81
Draft
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
9ba6b5b
feat: support top-k / top-p sampling with trainer-side truncation replay
mikasenghaas d064fa0
refactor: store sampling args as scalars on TrainingSample / MicroBatch
mikasenghaas 5f2677c
refactor: drop completion_ prefix on sampling args
mikasenghaas d93ce73
refactor: trim verbose comments around sampling-arg plumbing
mikasenghaas d64c62f
refactor: drop defensive lookups and redundant packer doc
mikasenghaas 0548ac5
test: include top_p / top_k in sft trajectory test fixtures
mikasenghaas a12741d
fix: reject temperature <= 0 at config time
mikasenghaas File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Top-p masking applied to logits instead of raw probabilities
Medium Severity
The top-p implementation computes
sorted_probs = sorted_logits.softmax(dim=-1)then uses cumulative probabilities to create a mask, but then applies that mask tosorted_logits(the pre-softmax values) usingmasked_fill. After masking some logits to-inf, whenselective_log_softmaxlater recomputeslog_softmax, the renormalized probabilities will differ from what vLLM computed because the softmax denominator changes. In contrast, vLLM'sapply_top_k_top_p_pytorchdetermines which tokens to remove via the same cumulative probability logic but the key difference is that vLLM computesprocessed_logprobson the already-truncated distribution in a single pass, whereas here the softmax used to determine the mask boundary and the finallog_softmaxinselective_log_softmaxare two separate computations over the same truncated set—which is actually equivalent. On closer inspection the math is consistent since the same tokens end up masked in both the boundary-finding softmax and the final log_softmax. However, there is a real discrepancy: after top-k masking sets some logits to-inf, the top-p branch recomputes softmax on these partially-masked logits, finding the cumulative boundary on the renormalized (post-top-k) distribution. If vLLM applies top-p on the original pre-top-k distribution rather than post-top-k, the mask boundaries would differ. The PR claims bit-exact matching against vLLM's reference but vLLM applies both filters to the same original logits distribution simultaneously rather than sequentially.Reviewed by Cursor Bugbot for commit 0548ac5. Configure here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bit-exact verified against
apply_top_k_top_p_pytorchwith both top_k and top_p set (0 mask disagreements, 0 value diff). vLLM applies top-p the same way — sort, mask top-k to -inf, thensoftmax(post-top-k-sorted_logits)+ cumsum + mask top-p tail. The order is equivalent.