feat: support top-k / top-p sampling with trainer-side replay #2601
mikasenghaas wants to merge 7 commits into
Add `top_k` and `top_p` to `TrainSamplingConfig` and replay the same truncation on the trainer side when computing logprobs, so the importance ratio against vLLM's processed logprobs stays unbiased.

- Per-token `completion_top_k` / `completion_top_p` plumbed through `TrainingSample` → `MicroBatch` → `TensorMicroBatch`, mirroring the existing per-token temperature path.
- `apply_top_k_top_p` in `loss.py` mirrors vLLM's pytorch reference (bit-exact against `apply_top_k_top_p_pytorch`). Accepts optional labels and restores their logits after masking so FP-precision boundary cases don't blow up the loss.
- Trainer shift-left-aligns top_k / top_p with the label being predicted, matching how labels are shifted left for next-token prediction.
- Fused LM head rejects non-trivial truncation since the chunked kernel can't find a global top-k threshold without materializing the full logits.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
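The label-restoring guard mentioned in this commit can be illustrated with a small pure-Python sketch. It is not the code from `loss.py`: the helper names (`mask_below_top_k`, `restore_label`, `log_softmax_at`) are hypothetical, and plain lists stand in for tensors. The scenario assumes the sampled token sits just outside the trainer's top-k because of a tiny ranking difference.

```python
import math

NEG_INF = float("-inf")

def log_softmax_at(logits, index):
    """Log-probability of `index` under softmax(logits); -inf entries carry no mass."""
    finite = [x for x in logits if x != NEG_INF]
    m = max(finite)
    lse = m + math.log(sum(math.exp(x - m) for x in finite))
    return logits[index] - lse

def mask_below_top_k(logits, top_k):
    """Set every logit below the k-th largest to -inf (ties with the threshold are kept)."""
    threshold = sorted(logits, reverse=True)[top_k - 1]
    return [x if x >= threshold else NEG_INF for x in logits]

def restore_label(original, masked, label):
    """Hypothetical guard: put the label's original logit back after masking."""
    restored = list(masked)
    restored[label] = original[label]
    return restored

# Trainer-side logits: token 2 was sampled by inference, but here it ranks
# third by a hair, so top_k=2 masks it.
trainer_logits = [2.0, 1.001, 1.0, -3.0]
label = 2

masked = mask_below_top_k(trainer_logits, top_k=2)
unguarded_logprob = log_softmax_at(masked, label)  # -inf: the loss would blow up
guarded_logprob = log_softmax_at(restore_label(trainer_logits, masked, label), label)
```

With the guard, the label's logprob stays finite at the cost of letting one extra token into the normalizer, which is the trade-off the commit message describes.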
Per-token list[float] was overkill: temperature, top_k, and top_p are constant per rollout. Storing scalars cuts wire size, simplifies the trainer (no shift-left alignment, no CP shard for these), and makes the truncation code a clean scalar broadcast.

Trade-off: samples in a single micro batch must share temperature / top_k / top_p (in addition to the existing training_mode constraint). The packer enforces this. In practice all training samples share the same sampling config, so the constraint is non-restrictive.

- TrainingSample: `completion_temperature: float`, `completion_top_k: int`, `completion_top_p: float` (replace the per-token lists).
- MicroBatch: scalar `temperature` / `top_k` / `top_p`.
- TensorMicroBatch: same scalars (Python ints/floats, not tensors).
- Trainer constructs a `temperatures` tensor by broadcasting the scalar across the (possibly CP-sharded) sequence to feed the fused LM head.
- `apply_top_k_top_p` takes scalar k / p: no per-row threshold logic, no passthrough handling.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
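One way to picture the packer constraint described in this commit is a grouping step keyed on the scalar sampling args. The sketch below is only an illustration under assumed names: the `Sample` dataclass, `sampling_key`, and `group_by_sampling_args` are hypothetical and do not come from the PR, and `training_mode` is modeled as a plain string.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Sample:
    """Hypothetical stand-in for a training sample with scalar sampling args."""
    token_ids: tuple
    training_mode: str
    temperature: float
    top_k: int
    top_p: float

def sampling_key(sample):
    """Samples may share a micro batch only if this key matches."""
    return (sample.training_mode, sample.temperature, sample.top_k, sample.top_p)

def group_by_sampling_args(samples):
    """Split samples into groups that agree on the key, preserving arrival order."""
    groups = {}
    for sample in samples:
        groups.setdefault(sampling_key(sample), []).append(sample)
    return list(groups.values())

samples = [
    Sample((1, 2, 3), "rl", 1.0, 20, 0.95),
    Sample((4, 5), "rl", 1.0, 20, 0.95),
    Sample((6, 7, 8), "rl", 0.7, -1, 1.0),  # different sampling config
]
groups = group_by_sampling_args(samples)
```

Here the first two samples land in one group and the third in its own, so each group could be packed into micro batches that carry a single `(temperature, top_k, top_p)` triple.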
`temperature` / `top_k` / `top_p` are properties of the whole rollout, not specifically of the completion tokens. The prefix made sense on `completion_ids` / `completion_mask` / `completion_logprobs` (which are per-completion-token data) but reads oddly on the scalar sampling args. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- ``output["sampling_args"]`` always carries top_p / top_k now (set unconditionally in TrainSamplingConfig.to_sampling_args), so read them directly. Updated tests accordingly.
- Drop the packer docstring paragraph spelling out the scalar-sampling-arg constraint: it is trivially true from the field types.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Same fixture update as for test_trajectories.py — the orchestrator now unconditionally reads ``sampling_args["top_p"]`` and ``sampling_args["extra_body"]["top_k"]`` (the defensive ``.get`` defaults were dropped in d64c62f). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Cursor Bugbot has reviewed your changes and found 2 potential issues.
Reviewed by Cursor Bugbot for commit 0548ac5.
    top_p_mask = cumprobs <= (1.0 - top_p)
    top_p_mask[..., -1] = False  # always keep the top token
    sorted_logits = sorted_logits.masked_fill(top_p_mask, float("-inf"))
    logits = torch.empty_like(logits).scatter_(-1, sorted_idx, sorted_logits)
Top-p masking applied to logits instead of raw probabilities
Medium Severity
The top-p implementation computes `sorted_probs = sorted_logits.softmax(dim=-1)` then uses cumulative probabilities to create a mask, but then applies that mask to `sorted_logits` (the pre-softmax values) using `masked_fill`. After masking some logits to `-inf`, when `selective_log_softmax` later recomputes `log_softmax`, the renormalized probabilities will differ from what vLLM computed because the softmax denominator changes.

In contrast, vLLM's `apply_top_k_top_p_pytorch` determines which tokens to remove via the same cumulative probability logic but the key difference is that vLLM computes `processed_logprobs` on the already-truncated distribution in a single pass, whereas here the softmax used to determine the mask boundary and the final `log_softmax` in `selective_log_softmax` are two separate computations over the same truncated set, which is actually equivalent. On closer inspection the math is consistent since the same tokens end up masked in both the boundary-finding softmax and the final `log_softmax`.

However, there is a real discrepancy: after top-k masking sets some logits to `-inf`, the top-p branch recomputes softmax on these partially-masked logits, finding the cumulative boundary on the renormalized (post-top-k) distribution. If vLLM applies top-p on the original pre-top-k distribution rather than post-top-k, the mask boundaries would differ. The PR claims bit-exact matching against vLLM's reference but vLLM applies both filters to the same original logits distribution simultaneously rather than sequentially.
Bit-exact verified against apply_top_k_top_p_pytorch with both top_k and top_p set (0 mask disagreements, 0 value diff). vLLM applies top-p the same way — sort, mask top-k to -inf, then softmax(post-top-k-sorted_logits) + cumsum + mask top-p tail. The order is equivalent.
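The ordering described in this reply (sort, mask top-k, then softmax over the post-top-k logits, cumulative sum, mask the top-p tail) can be sketched for a single row of logits in plain Python. This is an illustration of the described sequence, not the PR's `apply_top_k_top_p` and not vLLM's code; the function name `truncate_top_k_top_p` and the conventions that `top_k <= 0` and `top_p >= 1.0` disable a filter are assumptions made for the example.

```python
import math

NEG_INF = float("-inf")

def softmax(xs):
    """Numerically stable softmax; -inf entries get probability 0."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def truncate_top_k_top_p(logits, top_k, top_p):
    """Mask logits outside top-k, then outside top-p of the post-top-k distribution."""
    n = len(logits)
    order = sorted(range(n), key=lambda i: logits[i])  # ascending sort
    sorted_logits = [logits[i] for i in order]

    if 0 < top_k < n:
        threshold = sorted_logits[n - top_k]  # k-th largest value
        sorted_logits = [x if x >= threshold else NEG_INF for x in sorted_logits]

    if top_p < 1.0:
        probs = softmax(sorted_logits)  # renormalized after top-k
        cumulative, mask = 0.0, []
        for p in probs:
            cumulative += p
            mask.append(cumulative <= 1.0 - top_p)
        mask[-1] = False  # always keep the top token
        sorted_logits = [NEG_INF if drop else x for x, drop in zip(sorted_logits, mask)]

    # Scatter the sorted values back to their original vocabulary positions.
    out = [0.0] * n
    for position, vocab_index in enumerate(order):
        out[vocab_index] = sorted_logits[position]
    return out

logits = [math.log(p) for p in (0.5, 0.3, 0.15, 0.05)]
truncated = truncate_top_k_top_p(logits, top_k=3, top_p=0.8)
unchanged = truncate_top_k_top_p(logits, top_k=-1, top_p=1.0)
```

In this toy row, top-k removes the 0.05 token, and top-p then removes the 0.15 token because its renormalized cumulative mass (about 0.158) is below `1 - top_p`; only the two largest logits survive.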
Bugbot caught a divergence: ``TrainSamplingConfig.temperature`` allowed ``ge=0`` but the new packer validation rejected ``<= 0``, so a user setting ``temperature: 0`` got valid config parsing but zero samples reaching the trainer. Tighten the field to ``gt=0`` (greedy decoding is undefined for the RL trainer — logits / 0 is NaN) and drop the redundant packer check. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>


Summary
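- Add `top_k` and `top_p` to `TrainSamplingConfig` so they can be enabled for training rollouts (previously only `EvalSamplingConfig` exposed them).
- Plumb the sampling args from rollouts to the trainer as scalars (`temperature` / `top_k` / `top_p` on `TrainingSample` and `MicroBatch`).
- Replay the same truncation on the trainer side when computing logprobs, so the importance ratio against vLLM's `processed_logprobs` stays unbiased.

Closes #2600.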
Why this needs trainer-side replay
prime-rl already configures vLLM with `logprobs_mode="processed_logprobs"` (inference.py:412), so `inference_logprobs` returned by vLLM are computed from the truncated, renormalized distribution after top-k / top-p. If we enable top-k / top-p sampling but the trainer keeps computing logprobs over the full vocab, the importance ratio (trainer / inference) is biased.

Implementation notes
- `apply_top_k_top_p` (in `trainer/rl/loss.py`) mirrors vLLM's `apply_top_k_top_p_pytorch` reference. Verified bit-exact against the reference (same `-inf` mask positions, zero value diff) across top-k only, top-p only, and combined.
- It accepts optional `labels`. We restore the label token's original logit after masking so FP-precision boundary cases (the sampled token falling just outside the trainer's top-k under bf16 forward) don't push its logprob to `-inf` and blow up the masked loss. Without this guard, even step 0 with identical inference / trainer weights hits `Loss: inf` because of bf16 ranking jitter at the boundary.
- Sampling args are scalars on `MicroBatch` (one value per micro batch, not per-token). The packer enforces that samples sharing a micro batch also share `(training_mode, temperature, top_k, top_p)`, which is non-restrictive in practice because a training run shares its sampling config.
- The fused LM head (`FusedOutputLinear`) rejects non-trivial truncation: the chunked kernel streams over vocab chunks and can't find a global top-k threshold without materializing the full logits. The vanilla LM head (the RL default) is unaffected.

Multi-run compatibility
`MultiPacker.pack()` groups samples by `run_idx` and calls `prepare_batch` separately for each run (packer.py:320-330). The scalar-sampling-args constraint then only needs to hold within a single run, which it does naturally since one run = one RL config = one sampling config. Verified end-to-end:

- `idxs=[run_idx] * len(run_samples)` sets `lora_num_tokens[run_idx] = N` correctly.
- … `(temperature, top_k, top_p)` and the trainer applies truncation normally; the label-safety guard keeps logprobs finite and `loss_mask=False` zeros the loss contribution.

Future work: defer-to-vLLM-defaults
`TrainSamplingConfig` currently requires explicit values for `top_k` / `top_p` (defaults `-1` / `1.0`, both disabled). Supporting `None` ("use whatever vLLM applies, which comes from the model's `generation_config.json`") would be nice for ergonomics, but vLLM doesn't echo the resolved sampling params back in its response, so the orchestrator wouldn't know what to record on the `TrainingSample`, and the trainer wouldn't know what to mirror. The clean implementation is to load the model's `GenerationConfig` orchestrator-side at startup, substitute `None` → default value before sending, then both vLLM and the trainer use the same explicit values. Left out for now.

Verification
Ran reverse-text RL with `Qwen3-0.6B` (default generation config: `top_k=20, top_p=0.95`) on 2 GPUs. Mismatch KL is small and stable, on the same order as a no-truncation baseline, confirming the trainer's truncation matches what vLLM did.

For comparison, running with inference-side top-k / top-p but trainer-side truncation disabled gives ~2-3x larger Mismatch KL at step 0 (`0.0025` vs `0.0007`) even with mild truncation (`top_k=20` keeps ~99% of mass on Qwen). With more aggressive truncation the gap would be much wider.

All 33 unit tests in `tests/unit/orchestrator/test_batch.py`, `test_trajectories.py`, `test_teacher_logprobs.py`, and `tests/unit/train/rl/test_packer.py` pass.

🤖 Generated with Claude Code
Note
Medium Risk
Changes sampling parameter plumbing and logprob computation in the RL trainer; mistakes could bias importance ratios or break training when using fused LM heads.
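To make the importance-ratio risk concrete, here is a toy calculation in plain Python. It assumes identical trainer and inference weights, a four-token vocabulary, and top-k only; the helper names (`log_softmax`, `top_k_mask`) are hypothetical and the numbers are illustrative, not measurements from this PR.

```python
import math

NEG_INF = float("-inf")

def log_softmax(logits):
    """Log-probabilities over the row; -inf logits stay at -inf."""
    finite = [x for x in logits if x != NEG_INF]
    m = max(finite)
    lse = m + math.log(sum(math.exp(x - m) for x in finite))
    return [x - lse for x in logits]

def top_k_mask(logits, k):
    """Keep the k largest logits, set the rest to -inf."""
    threshold = sorted(logits, reverse=True)[k - 1]
    return [x if x >= threshold else NEG_INF for x in logits]

# Same weights on both sides, so the raw logits are identical.
logits = [math.log(p) for p in (0.5, 0.3, 0.15, 0.05)]
sampled = 0  # token drawn by the inference engine

# Inference reports the logprob under the truncated, renormalized distribution.
inference_logprob = log_softmax(top_k_mask(logits, 2))[sampled]

# Trainer without replay: full-vocab logprob.
trainer_full = log_softmax(logits)[sampled]
# Trainer with replay: same truncation as inference.
trainer_replay = log_softmax(top_k_mask(logits, 2))[sampled]

ratio_without_replay = math.exp(trainer_full - inference_logprob)
ratio_with_replay = math.exp(trainer_replay - inference_logprob)
```

Without replay the ratio equals the retained probability mass (about 0.8 here) even though the policies are identical; with replay it is 1, which is the property the trainer-side truncation is meant to preserve.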
Overview
Adds train-time `top_k`/`top_p` sampling controls (and enforces `temperature > 0`) and plumbs these sampling args from orchestrator rollouts through `TrainingSample`/`MicroBatch` into the trainer.

Updates batching/packing to treat sampling args as scalar per microbatch (no more per-token temperature lists), splitting microbatches when `(temperature, top_k, top_p, training_mode)` differ.

Implements trainer-side replay of vLLM-style top-k/top-p truncation via `apply_top_k_top_p` before computing `logprobs`, and explicitly rejects truncation when using the fused LM head; tests are updated accordingly.

Reviewed by Cursor Bugbot for commit a12741d. Bugbot is set up for automated code reviews on this repo.