feat: support top-k / top-p sampling with trainer-side replay#2601

Draft
mikasenghaas wants to merge 7 commits into main from feat/top-k-p-sampling

@mikasenghaas
Member

@mikasenghaas mikasenghaas commented May 22, 2026

Summary

  • Add top_k and top_p to TrainSamplingConfig so they can be enabled for training rollouts (previously only EvalSamplingConfig exposed them).
  • Plumb the sampling args from the orchestrator to the trainer (temperature / top_k / top_p on TrainingSample and MicroBatch).
  • Replay the same truncation on the trainer side when computing logprobs, so the importance ratio against vLLM's processed_logprobs stays unbiased.

Closes #2600.

Why this needs trainer-side replay

prime-rl already configures vLLM with logprobs_mode="processed_logprobs" (inference.py:412), so inference_logprobs returned by vLLM are computed from the truncated, renormalized distribution after top-k / top-p. If we enable top-k / top-p sampling but the trainer keeps computing logprobs over the full vocab, the importance ratio (trainer / inference) is biased.
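
To make the bias concrete, here is a toy illustration in plain Python (not prime-rl or vLLM code; the four-token vocabulary, the logit values, and the helper name `log_softmax` are made up for the example). It assumes identical weights on both sides. Under that assumption, a trainer that skips truncation gets an importance ratio equal to the probability mass kept by top-k instead of 1, while a trainer that replays the truncation recovers a ratio of 1.

```python
import math

NEG_INF = float("-inf")


def log_softmax(logits):
    """Numerically stable log-softmax over a plain list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


# Toy 4-token vocab; inference and trainer are assumed to produce the same logits.
logits = [2.0, 1.0, 0.0, -1.0]
sampled = 0  # index of the token that was sampled

# Inference side with top_k=2: the two smallest logits are masked, then renormalized.
truncated = [2.0, 1.0, NEG_INF, NEG_INF]
inference_lp = log_softmax(truncated)[sampled]

# Trainer without replay: log-softmax over the full vocab.
trainer_full_lp = log_softmax(logits)[sampled]

# Trainer with replay: same mask as inference before the log-softmax.
trainer_replay_lp = log_softmax(truncated)[sampled]

ratio_no_replay = math.exp(trainer_full_lp - inference_lp)
ratio_replay = math.exp(trainer_replay_lp - inference_lp)
print(f"ratio without replay: {ratio_no_replay:.4f}")
print(f"ratio with replay:    {ratio_replay:.4f}")
```

The gap grows as truncation removes more mass, which is why mild settings can look harmless while aggressive ones do not.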

Implementation notes

  • apply_top_k_top_p (in trainer/rl/loss.py) mirrors vLLM's apply_top_k_top_p_pytorch reference. Verified bit-exact against the reference (same -inf mask positions, zero value diff) across top-k only, top-p only, and combined.
  • The helper accepts optional labels. We restore the label token's original logit after masking so FP-precision boundary cases (the sampled token falling just outside the trainer's top-k under bf16 forward) don't push its logprob to -inf and blow up the masked loss. Without this guard, even step 0 with identical inference / trainer weights hits Loss: inf because of bf16 ranking jitter at the boundary.
  • Sampling args are stored as scalars on MicroBatch (one value per micro batch, not per-token). The packer enforces that samples sharing a micro batch also share (training_mode, temperature, top_k, top_p) — non-restrictive in practice because a training run shares its sampling config.
  • The fused LM head (FusedOutputLinear) rejects non-trivial truncation: the chunked kernel streams over vocab chunks and can't find a global top-k threshold without materializing the full logits. The vanilla LM head (the RL default) is unaffected.
  • Entropy is computed on the pre-truncation distribution so the metric stays comparable across runs that do / don't truncate.
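
As a rough illustration of the masking and the label guard described above, here is a single-row sketch in plain Python. It is not the actual `apply_top_k_top_p` (which operates on tensors); `truncate_row` and `softmax` are hypothetical names, and the scheme assumed here is an ascending sort, a top-k mask, then a top-p mask computed on the post-top-k distribution with the top token always kept.

```python
import math

NEG_INF = float("-inf")


def softmax(xs):
    """Stable softmax over a list; entries equal to -inf get probability 0."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def truncate_row(logits, top_k=-1, top_p=1.0, label=None):
    """Mask one row of logits with top-k / top-p (hypothetical helper)."""
    order = sorted(range(len(logits)), key=lambda i: logits[i])  # ascending
    sorted_logits = [logits[i] for i in order]

    if 0 < top_k < len(logits):
        threshold = sorted_logits[-top_k]  # k-th largest logit
        sorted_logits = [x if x >= threshold else NEG_INF for x in sorted_logits]

    if top_p < 1.0:
        probs = softmax(sorted_logits)  # computed on the post-top-k row
        cum = 0.0
        # Skip the last position: it is the top token and is always kept.
        for pos in range(len(sorted_logits) - 1):
            cum += probs[pos]
            if cum <= 1.0 - top_p:
                sorted_logits[pos] = NEG_INF

    # Scatter the masked values back to their original vocab positions.
    out = [0.0] * len(logits)
    for pos, i in enumerate(order):
        out[i] = sorted_logits[pos]

    # Label guard: keep the sampled token finite even if it fell outside the mask.
    if label is not None:
        out[label] = logits[label]
    return out


row = [2.0, 1.0, 0.0, -1.0]
print(truncate_row(row, top_k=2))           # tokens 2 and 3 masked
print(truncate_row(row, top_k=2, label=2))  # token 2 restored by the guard
```

The guard only changes the row when the label would otherwise be masked, so in the common case the result is identical to plain truncation.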

Multi-run compatibility

MultiPacker.pack() groups samples by run_idx and calls prepare_batch separately for each run (packer.py:320-330). The scalar-sampling-args constraint then only needs to hold within a single run, which it does naturally since one run = one RL config = one sampling config. Verified end-to-end:

  • Per-run isolation: each run's scalars flow through to its own micro batches; the trainer broadcasts each batch's scalar inside its loop, so truncation matches what inference applied for that run.
  • LoRA routing preserved: idxs=[run_idx] * len(run_samples) sets lora_num_tokens[run_idx] = N correctly.
  • Mixed envs within a run: if two envs in the same run use different sampling configs, the packer splits cleanly into separate micro batches. Packing efficiency drops slightly in this case, but correctness is preserved.
  • Dummy batches added for worker-count divisibility deepcopy from a real batch, so they inherit the same (temperature, top_k, top_p) and the trainer applies truncation normally — the label-safety guard keeps logprobs finite and loss_mask=False zeros the loss contribution.
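
The splitting behaviour can be pictured with a small sketch. This is not the real packer: `Sample` is a hypothetical stand-in for `TrainingSample`, `group_by_sampling_args` is an illustrative name, and the `"rl"` mode string is a placeholder. It only shows the idea that samples are bucketed by `(training_mode, temperature, top_k, top_p)` before being packed, so each micro batch can carry one scalar per sampling arg.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class Sample:
    """Hypothetical stand-in for TrainingSample (only the fields needed here)."""
    sample_id: int
    training_mode: str
    temperature: float
    top_k: int
    top_p: float


def group_by_sampling_args(samples):
    """Bucket samples so each bucket shares one set of scalar sampling args."""
    groups = defaultdict(list)
    for s in samples:
        key = (s.training_mode, s.temperature, s.top_k, s.top_p)
        groups[key].append(s)
    return dict(groups)


samples = [
    Sample(0, "rl", 1.0, 20, 0.95),
    Sample(1, "rl", 1.0, 20, 0.95),
    Sample(2, "rl", 1.0, -1, 1.0),  # a second env with truncation disabled
]
for key, bucket in group_by_sampling_args(samples).items():
    print(key, [s.sample_id for s in bucket])
```

When every sample in a run shares one sampling config, this produces a single bucket and packing is unaffected.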

Future work: defer-to-vLLM-defaults

TrainSamplingConfig currently requires explicit values for top_k / top_p (defaults -1 / 1.0, both disabled). Supporting None (meaning "use whatever vLLM applies, which comes from the model's generation_config.json") would be nice for ergonomics, but vLLM doesn't echo the resolved sampling params back in its response, so the orchestrator wouldn't know what to record on the TrainingSample, and the trainer wouldn't know what to mirror. The clean implementation is to load the model's GenerationConfig orchestrator-side at startup and substitute each None with the model's default before sending, so that both vLLM and the trainer use the same explicit values. Left out for now.
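
A possible shape for that substitution step, as a sketch only: `resolve_sampling_args` is a hypothetical name, the generation config is passed as a plain dict (as if parsed from generation_config.json), and falling back to the "disabled" values when the model config omits a key is an assumption rather than confirmed vLLM behaviour.

```python
# "Disabled" values, matching the TrainSamplingConfig defaults quoted above.
DISABLED = {"top_k": -1, "top_p": 1.0}


def resolve_sampling_args(top_k, top_p, generation_config):
    """Replace None with the model's default so both sides see explicit values.

    generation_config: dict parsed from the model's generation_config.json.
    Falling back to DISABLED when a key is missing is an assumption.
    """
    resolved = {}
    for name, value in (("top_k", top_k), ("top_p", top_p)):
        if value is None:
            model_default = generation_config.get(name)
            value = model_default if model_default is not None else DISABLED[name]
        resolved[name] = value
    return resolved


# Example: user leaves both unset, model config supplies the defaults.
print(resolve_sampling_args(None, None, {"top_k": 20, "top_p": 0.95}))
# Example: explicit top_k wins, top_p falls back to disabled.
print(resolve_sampling_args(5, None, {}))
```

The resolved dict would then be sent to vLLM and recorded on the sample, so the trainer never has to guess what was applied.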

Verification

Ran reverse-text RL with Qwen3-0.6B (default generation config: top_k=20, top_p=0.95) on 2 GPUs. Mismatch KL is small and stable, on the same order as a no-truncation baseline — confirming the trainer's truncation matches what vLLM did:

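| step | with top_k=20 + top_p=0.95 | baseline (no truncation) |
|------|----------------------------|--------------------------|
| 0    | 0.0007                     | 0.0010                   |
| 1    | 0.0009                     | 0.0012                   |
| 2    | 0.0018                     | 0.0014                   |
| 3    | 0.0071                     | 0.0054                   |
| 4    | 0.0064                     | 0.0045                   |
| 5    | 0.0019                     | 0.0021                   |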

For comparison, running with inference-side top-k / top-p but trainer-side truncation disabled gives a roughly 3.5x larger Mismatch KL at step 0 (0.0025 vs 0.0007), even with mild truncation (top_k=20 keeps ~99% of mass on Qwen). With more aggressive truncation the gap would be much wider.

All 33 unit tests in tests/unit/orchestrator/test_batch.py, test_trajectories.py, test_teacher_logprobs.py, and tests/unit/train/rl/test_packer.py pass.

🤖 Generated with Claude Code


Note

Medium Risk
Changes sampling parameter plumbing and logprob computation in the RL trainer; mistakes could bias importance ratios or break training when using fused LM heads.

Overview
Adds train-time top_k/top_p sampling controls (and enforces temperature > 0) and plumbs these sampling args from orchestrator rollouts through TrainingSample/MicroBatch into the trainer.

Updates batching/packing to treat sampling args as scalar per microbatch (no more per-token temperature lists), splitting microbatches when (temperature, top_k, top_p, training_mode) differ.

Implements trainer-side replay of vLLM-style top-k/top-p truncation via apply_top_k_top_p before computing logprobs, and explicitly rejects truncation when using the fused LM head; tests are updated accordingly.

Reviewed by Cursor Bugbot for commit a12741d. Bugbot is set up for automated code reviews on this repo.

mikasenghaas and others added 5 commits May 22, 2026 20:55
Add `top_k` and `top_p` to `TrainSamplingConfig` and replay the same
truncation on the trainer side when computing logprobs, so the importance
ratio against vLLM's processed logprobs stays unbiased.

- Per-token `completion_top_k` / `completion_top_p` plumbed through
  `TrainingSample` → `MicroBatch` → `TensorMicroBatch`, mirroring the
  existing per-token temperature path.
- `apply_top_k_top_p` in `loss.py` mirrors vLLM's pytorch reference
  (bit-exact against `apply_top_k_top_p_pytorch`). Accepts optional
  labels and restores their logits after masking so FP-precision
  boundary cases don't blow up the loss.
- Trainer shift-left-aligns top_k / top_p with the label being
  predicted, matching how labels are shifted left for next-token
  prediction.
- Fused LM head rejects non-trivial truncation since the chunked
  kernel can't find a global top-k threshold without materializing
  the full logits.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Per-token list[float] was overkill — temperature, top_k, and top_p are
constant per rollout. Storing scalars cuts wire size, simplifies the
trainer (no shift-left alignment, no CP shard for these), and makes the
truncation code a clean scalar broadcast.

Trade-off: samples in a single micro batch must share temperature /
top_k / top_p (in addition to the existing training_mode constraint).
The packer enforces this. In practice all training samples share the
same sampling config, so the constraint is non-restrictive.

- TrainingSample: `completion_temperature: float`, `completion_top_k: int`,
  `completion_top_p: float` (replace the per-token lists).
- MicroBatch: scalar `temperature` / `top_k` / `top_p`.
- TensorMicroBatch: same scalars (Python ints/floats, not tensors).
- Trainer constructs a `temperatures` tensor by broadcasting the scalar
  across the (possibly CP-sharded) sequence to feed the fused LM head.
- `apply_top_k_top_p` takes scalar k / p — no per-row threshold logic,
  no passthrough handling.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`temperature` / `top_k` / `top_p` are properties of the whole rollout,
not specifically of the completion tokens. The prefix made sense on
`completion_ids` / `completion_mask` / `completion_logprobs` (which are
per-completion-token data) but reads oddly on the scalar sampling args.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- ``output["sampling_args"]`` always carries top_p / top_k now (set
  unconditionally in TrainSamplingConfig.to_sampling_args), so read them
  directly. Updated tests accordingly.
- Drop the packer docstring paragraph spelling out the scalar-sampling-arg
  constraint — trivially true from the field types.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@mikasenghaas mikasenghaas requested review from Jackmin801 and samsja May 22, 2026 21:34
@mikasenghaas mikasenghaas marked this pull request as ready for review May 22, 2026 21:37
Same fixture update as for test_trajectories.py — the orchestrator now
unconditionally reads ``sampling_args["top_p"]`` and
``sampling_args["extra_body"]["top_k"]`` (the defensive ``.get`` defaults
were dropped in d64c62f).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
hallerite previously approved these changes May 22, 2026

@cursor cursor Bot left a comment

Cursor Bugbot has reviewed your changes and found 2 potential issues.

Reviewed by Cursor Bugbot for commit 0548ac5.

top_p_mask = cumprobs <= (1.0 - top_p)
top_p_mask[..., -1] = False # always keep the top token
sorted_logits = sorted_logits.masked_fill(top_p_mask, float("-inf"))
logits = torch.empty_like(logits).scatter_(-1, sorted_idx, sorted_logits)

Top-p masking applied to logits instead of raw probabilities

Medium Severity

The top-p implementation computes sorted_probs = sorted_logits.softmax(dim=-1) then uses cumulative probabilities to create a mask, but then applies that mask to sorted_logits (the pre-softmax values) using masked_fill. After masking some logits to -inf, when selective_log_softmax later recomputes log_softmax, the renormalized probabilities will differ from what vLLM computed because the softmax denominator changes.

In contrast, vLLM's apply_top_k_top_p_pytorch determines which tokens to remove via the same cumulative probability logic but the key difference is that vLLM computes processed_logprobs on the already-truncated distribution in a single pass, whereas here the softmax used to determine the mask boundary and the final log_softmax in selective_log_softmax are two separate computations over the same truncated set, which is actually equivalent. On closer inspection the math is consistent since the same tokens end up masked in both the boundary-finding softmax and the final log_softmax.

However, there is a real discrepancy: after top-k masking sets some logits to -inf, the top-p branch recomputes softmax on these partially-masked logits, finding the cumulative boundary on the renormalized (post-top-k) distribution. If vLLM applies top-p on the original pre-top-k distribution rather than post-top-k, the mask boundaries would differ. The PR claims bit-exact matching against vLLM's reference but vLLM applies both filters to the same original logits distribution simultaneously rather than sequentially.

Member Author

Bit-exact verified against apply_top_k_top_p_pytorch with both top_k and top_p set (0 mask disagreements, 0 value diff). vLLM applies top-p the same way — sort, mask top-k to -inf, then softmax(post-top-k-sorted_logits) + cumsum + mask top-p tail. The order is equivalent.

Comment thread src/prime_rl/trainer/rl/packer.py Outdated
Bugbot caught a divergence: ``TrainSamplingConfig.temperature`` allowed
``ge=0`` but the new packer validation rejected ``<= 0``, so a user
setting ``temperature: 0`` got valid config parsing but zero samples
reaching the trainer.

Tighten the field to ``gt=0`` (greedy decoding is undefined for the RL
trainer — logits / 0 is NaN) and drop the redundant packer check.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>