Skip to content
Draft
14 changes: 10 additions & 4 deletions packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,14 @@ class ModelConfig(BaseModelConfig):


class TrainSamplingConfig(BaseConfig):
temperature: float = Field(1.0, ge=0)
"""Sampling temperature."""
temperature: float = Field(1.0, gt=0)
"""Sampling temperature. Must be strictly positive — the trainer divides logits by it."""

top_p: float = Field(1.0, gt=0, le=1)
"""Nucleus sampling threshold. 1.0 disables. When < 1.0, the trainer replays the same truncation when computing logprobs so the importance ratio stays unbiased."""

top_k: int = Field(-1, ge=-1)
"""Top-k sampling. -1 disables. When > 0, the trainer replays the same truncation when computing logprobs so the importance ratio stays unbiased."""

repetition_penalty: float = Field(1.0, ge=0)
"""Repetition penalty. Values > 1.0 discourage repetition, < 1.0 encourage it, 1.0 disables."""
Expand All @@ -69,7 +75,7 @@ def to_sampling_args(self) -> dict[str, Any]:
# Top-level OAI params
args: dict[str, Any] = {
"temperature": self.temperature,
"top_p": 1.0,
"top_p": self.top_p,
"logprobs": True,
}
if self.max_completion_tokens is not None:
Expand All @@ -79,6 +85,7 @@ def to_sampling_args(self) -> dict[str, Any]:

# vLLM extra_body params
extra_body = dict(self.extra_body)
extra_body.setdefault("top_k", self.top_k)
if self.min_tokens > 0:
extra_body["min_tokens"] = self.min_tokens
if self.repetition_penalty != 1.0:
Expand Down Expand Up @@ -954,7 +961,6 @@ def resolve_env_config(self):
for env in self.train.env:
env.extra_env_kwargs.update(max_seq_len=self.seq_len)
if is_vllm:
env.sampling.extra_body.setdefault("top_k", -1)
env.sampling.extra_body.setdefault("min_p", 0.0)
env.sampling.extra_body.setdefault("return_token_ids", True)
return self
12 changes: 7 additions & 5 deletions src/prime_rl/orchestrator/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,10 @@ def interleave_rollout(
return None

has_error = output["error"] is not None
# this field should be guaranteed because we set temperature in get_sampling_args
temperature = output["sampling_args"]["temperature"]
sampling_args = output["sampling_args"]
temperature = sampling_args["temperature"]
top_p = sampling_args["top_p"]
top_k = sampling_args["extra_body"]["top_k"]

def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any] | None:
tokens = step["tokens"]
Expand Down Expand Up @@ -279,7 +281,9 @@ def make_sample(tokens: dict[str, Any]) -> TrainingSample:
completion_ids=completion_ids,
completion_mask=completion_mask,
completion_logprobs=list(tokens["completion_logprobs"]),
completion_temperatures=[temperature] * len(completion_ids),
temperature=float(temperature),
top_k=int(top_k),
top_p=float(top_p),
teacher_logprobs=None,
advantage=None,
env_name=output["env_name"],
Expand All @@ -296,7 +300,6 @@ def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> Non
sample.completion_ids.extend(new_prompt_ids)
sample.completion_mask.extend([False] * len(new_prompt_ids))
sample.completion_logprobs.extend([0.0] * len(new_prompt_ids))
sample.completion_temperatures.extend([temperature] * len(new_prompt_ids))

# Extend with new completion tokens
completion_ids = tokens["completion_ids"]
Expand All @@ -306,7 +309,6 @@ def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> Non
else:
sample.completion_mask.extend(bool(i) for i in tokens["completion_mask"])
sample.completion_logprobs.extend(tokens["completion_logprobs"])
sample.completion_temperatures.extend([temperature] * len(completion_ids))

if tokens.get("routed_experts") is not None and sample.routed_experts is not None:
step_routed = tokens["routed_experts"]
Expand Down
21 changes: 7 additions & 14 deletions src/prime_rl/trainer/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,6 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch
assert training_example.env_name != "all", "env_name='all' is reserved for aggregate metric keys"
env_names = [training_example.env_name] * len(input_ids)

# Per-token temperatures: prompt tokens use first completion temp (masked out anyway)
# Default to 1.0 if completion is empty (e.g., model generated only tool calls with no text)
prompt_temp = training_example.completion_temperatures[0] if training_example.completion_temperatures else 1.0
temperatures = [prompt_temp] * len(training_example.prompt_ids) + training_example.completion_temperatures

# Teacher logprobs already cover the full sequence (prompt + completion),
# computed via prefill in the orchestrator when a teacher model is configured
teacher_logprobs = training_example.teacher_logprobs
Expand All @@ -36,7 +31,6 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch
position_ids = position_ids[:seq_len]
advantages = advantages[:seq_len]
rewards = rewards[:seq_len]
temperatures = temperatures[:seq_len]
if teacher_logprobs is not None:
teacher_logprobs = teacher_logprobs[:seq_len]
if routed_experts is not None:
Expand All @@ -52,9 +46,8 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch
== len(position_ids)
== len(inference_logprobs)
== len(rewards)
== len(temperatures)
), (
f"input_ids: {len(input_ids)}, advantages: {len(advantages)}, loss_mask: {len(loss_mask)}, position_ids: {len(position_ids)}, inference_logprobs: {len(inference_logprobs)}, rewards: {len(rewards)}, temperatures: {len(temperatures)}"
f"input_ids: {len(input_ids)}, advantages: {len(advantages)}, loss_mask: {len(loss_mask)}, position_ids: {len(position_ids)}, inference_logprobs: {len(inference_logprobs)}, rewards: {len(rewards)}"
)
if teacher_logprobs is not None:
assert len(teacher_logprobs) == len(input_ids), f"teacher_logprobs: {len(teacher_logprobs)}"
Expand All @@ -77,7 +70,9 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch
position_ids=position_ids,
inference_logprobs=inference_logprobs,
teacher_logprobs=teacher_logprobs,
temperatures=temperatures,
temperature=training_example.temperature,
top_k=training_example.top_k,
top_p=training_example.top_p,
rewards=rewards,
routed_experts=routed_experts,
mm_token_type_ids=mm_token_type_ids,
Expand All @@ -98,7 +93,6 @@ def packed_samples_into_micro_bs(
"""
Pack samples into micro_batch efficiently.
We follow the First Fit Decreasing algorithm to pack the samples into bins and minimize potential padding while never truncating.
With per-token temperatures, samples can be packed together regardless of their temperature values.

NOTE: Multimodal samples (with mm_kwargs) are NOT packed together as they have variable-sized
vision data that doesn't pack well. Each multimodal sample becomes its own micro batch.
Expand All @@ -122,10 +116,12 @@ def packed_samples_into_micro_bs(
# Don't pack into multimodal micro batches
if _is_multimodal_sample(bin_content):
continue
# Check if sequence fits in this bin
if (
len(bin_content.input_ids) + len(sample.input_ids) <= max_seq_len
and bin_content.training_mode == sample.training_mode
and bin_content.temperature == sample.temperature
and bin_content.top_k == sample.top_k
and bin_content.top_p == sample.top_p
):
existing_len = len(bin_content.input_ids)
bin_content.input_ids.extend(sample.input_ids)
Expand All @@ -138,7 +134,6 @@ def packed_samples_into_micro_bs(
elif bin_content.rewards is not None:
bin_content.rewards.extend([float("nan")] * len(sample.input_ids))
bin_content.inference_logprobs.extend(sample.inference_logprobs)
bin_content.temperatures.extend(sample.temperatures)
if sample.teacher_logprobs is not None:
if bin_content.teacher_logprobs is None:
bin_content.teacher_logprobs = []
Expand Down Expand Up @@ -192,8 +187,6 @@ def pad_micro_batch(micro_batch: MicroBatch, pad_to_multiple_of: int) -> MicroBa
micro_batch.loss_mask.extend([False] * padding_size)
micro_batch.position_ids.extend(list(range(padding_size)))
micro_batch.inference_logprobs.extend([0.0] * padding_size)
# Use temperature 1.0 for padding tokens (doesn't matter since loss_mask is False)
micro_batch.temperatures.extend([1.0] * padding_size)
if micro_batch.teacher_logprobs is not None:
micro_batch.teacher_logprobs.extend([0.0] * padding_size)
micro_batch.lora_num_tokens[-1] += (
Expand Down
16 changes: 12 additions & 4 deletions src/prime_rl/trainer/rl/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ class TensorMicroBatch(TypedDict):
inference_logprobs: Float[Tensor, "batch seq"]
teacher_logprobs: Float[Tensor, "batch seq"] | None
loss_mask: Bool[Tensor, "batch seq"]
temperatures: Float[Tensor, "batch seq"] # Per-token temperatures
temperature: float
top_k: int # -1 disables truncation
top_p: float # 1.0 disables truncation
env_names: list[str]

# Batch level
Expand Down Expand Up @@ -112,7 +114,9 @@ def _get_sample_micro_batch(self, generator: torch.Generator) -> TensorMicroBatc
"rewards": None,
"inference_logprobs": inference_logprobs.unsqueeze(0),
"teacher_logprobs": None,
"temperatures": torch.ones(input_ids.shape[0]).unsqueeze(0),
"temperature": 1.0,
"top_k": -1,
"top_p": 1.0,
"env_names": ["fake"] * input_ids.shape[0],
"loss_mask": loss_mask.unsqueeze(0),
"lora_num_tokens": lora_num_tokens,
Expand Down Expand Up @@ -140,7 +144,9 @@ def _get_micro_batch(self, generator: torch.Generator) -> TensorMicroBatch:
"rewards": None,
"inference_logprobs": torch.randn(self.seq_len, generator=generator).unsqueeze(0),
"teacher_logprobs": None,
"temperatures": torch.ones(self.seq_len).unsqueeze(0),
"temperature": 1.0,
"top_k": -1,
"top_p": 1.0,
"env_names": ["fake"] * self.seq_len,
"loss_mask": torch.ones(self.seq_len, dtype=torch.bool).unsqueeze(0),
"lora_num_tokens": lora_num_tokens,
Expand Down Expand Up @@ -222,7 +228,9 @@ def _micro_batch_to_tensor(self, micro_batch: MicroBatch) -> TensorMicroBatch:
if micro_batch.teacher_logprobs is not None
else None,
loss_mask=torch.tensor(micro_batch.loss_mask, dtype=torch.bool).unsqueeze(0),
temperatures=torch.tensor(micro_batch.temperatures, dtype=torch.float).unsqueeze(0),
temperature=micro_batch.temperature,
top_k=micro_batch.top_k,
top_p=micro_batch.top_p,
env_names=micro_batch.env_names,
lora_num_tokens=torch.tensor(micro_batch.lora_num_tokens, dtype=torch.int32),
mm_kwargs=mm_kwargs,
Expand Down
46 changes: 46 additions & 0 deletions src/prime_rl/trainer/rl/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,52 @@ def selective_log_softmax(
return torch.gather(logprobs, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)


def apply_top_k_top_p(
logits: Tensor,
top_k: int | None,
top_p: float | None,
labels: Tensor | None = None,
) -> Tensor:
"""Mask logits outside the top-k / top-p support used at sampling time.

Mirror vLLM's top-k / top-p truncation so the trainer logits sum over
the same support as the inference-time sample. Bit-exact match against
vLLM's ``apply_top_k_top_p_pytorch``. Scalar k / p only.

Top-k is applied first; top-p is then evaluated on the logits that
already carry the top-k mask, i.e. on the post-top-k distribution.

Args:
    logits: Logits whose last dimension is the vocabulary.
    top_k: Number of highest logits to keep per position. ``None``,
        ``k <= 0`` and ``k >= vocab_size`` are no-ops. Entries tied with
        the k-th largest value are kept (the mask is a strict ``<``).
    top_p: Nucleus threshold. ``None`` and ``p >= 1.0`` are no-ops. The
        highest-probability token is always kept.
    labels: Optional token ids, indexed along the vocab dimension. This
        is the safety guard: the label's original logit is restored after
        masking so FP-precision boundary cases (the sampled token falling
        just outside the trainer's top-k) don't push its logprob to
        ``-inf``. Also keeps prompt / padding positions finite when the
        scalar truncation is applied uniformly.

Returns:
    The input ``logits`` object itself when neither truncation is active;
    otherwise a new tensor in the original vocab order with truncated
    entries set to ``-inf``.
"""
vocab_size = logits.shape[-1]
# Decide up front which truncations are active so the no-op case returns early.
do_top_k = top_k is not None and 0 < top_k < vocab_size
do_top_p = top_p is not None and top_p < 1.0
if not do_top_k and not do_top_p:
return logits

# Capture the label logits before any masking so they can be restored afterwards.
label_logits = logits.gather(-1, labels.unsqueeze(-1)) if labels is not None else None

if do_top_k:
# The k-th largest value per position is the cut-off; everything strictly below it is dropped.
top_values, _ = logits.topk(top_k, dim=-1)
threshold = top_values[..., -1:]
logits = logits.masked_fill(logits < threshold, float("-inf"))

if do_top_p:
# Sort ascending so the low-probability tail is at the front (mirrors vLLM).
# `logits` already carries the top-k mask here, so the nucleus is measured post-top-k.
sorted_logits, sorted_idx = logits.sort(dim=-1, descending=False)
sorted_probs = sorted_logits.softmax(dim=-1)
cumprobs = sorted_probs.cumsum(dim=-1)
# Drop the tail whose cumulative mass is at most (1 - top_p).
top_p_mask = cumprobs <= (1.0 - top_p)
top_p_mask[..., -1] = False # always keep the top token
sorted_logits = sorted_logits.masked_fill(top_p_mask, float("-inf"))
# Undo the sort: write each value back to its original vocab index.
logits = torch.empty_like(logits).scatter_(-1, sorted_idx, sorted_logits)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Top-p masking applied to logits instead of raw probabilities

Medium Severity

The top-p implementation computes sorted_probs = sorted_logits.softmax(dim=-1), builds a mask from the cumulative probabilities, and then applies that mask to sorted_logits (the pre-softmax values) using masked_fill. Masking logits to -inf changes the softmax denominator when selective_log_softmax later recomputes log_softmax, but this is not itself a problem: vLLM's apply_top_k_top_p_pytorch computes processed_logprobs on the already-truncated distribution in a single pass, whereas here the boundary-finding softmax and the final log_softmax in selective_log_softmax are two separate computations over the same truncated set. The same tokens end up masked in both, so the math is consistent. The potential discrepancy is the order of the two filters: after top-k masking sets some logits to -inf, the top-p branch recomputes softmax on these partially-masked logits, so the cumulative boundary is found on the renormalized (post-top-k) distribution. If vLLM applies top-p on the original pre-top-k distribution rather than post-top-k (that is, applies both filters to the same original logits simultaneously rather than sequentially), the mask boundaries would differ and the PR's claim of a bit-exact match against vLLM's reference would not hold. This should be verified against the vLLM implementation.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 0548ac5. Configure here.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Verified bit-exact against apply_top_k_top_p_pytorch with both top_k and top_p set (0 mask disagreements, 0 value diff). vLLM applies top-p the same way: sort, mask everything outside the top-k to -inf, then softmax over the post-top-k sorted logits, cumsum, and mask the top-p tail. Both implementations therefore evaluate top-p on the post-top-k distribution, so the order of operations is equivalent.


# Put the label's pre-truncation logit back so its log-probability stays finite
# even when the truncation above masked it.
if label_logits is not None:
logits = logits.scatter(-1, labels.unsqueeze(-1), label_logits)

return logits


@jaxtyped(typechecker=typechecker)
@torch.compile(dynamic=True)
def compute_entropy(shifted_logits: Float[Tensor, "batch seq vocab"]) -> Float[Tensor, "batch seq"]:
Expand Down
5 changes: 0 additions & 5 deletions src/prime_rl/trainer/rl/packer.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,6 @@ def _validate_sample(self, sample: TrainingSample) -> tuple[bool, str | None]:
False,
f"Run wrote a sample with completion logprobs length != completion ids length ({len(sample.completion_logprobs)} != {len(sample.completion_ids)})",
)
if len(sample.completion_temperatures) != len(sample.completion_ids):
return (
False,
f"Run wrote a sample with completion temperatures length != completion ids length ({len(sample.completion_temperatures)} != {len(sample.completion_ids)})",
)
if sample_length == 0:
return False, "Run wrote a sample with no tokens"
if sample_length > self.seq_len:
Expand Down
25 changes: 18 additions & 7 deletions src/prime_rl/trainer/rl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)
from prime_rl.utils.logger import setup_logger
from prime_rl.trainer.rl.loss import (
apply_top_k_top_p,
compute_entropy,
compute_loss,
compute_importance_ratio_and_mismatch_kl,
Expand Down Expand Up @@ -420,11 +421,12 @@ def load_run_checkpoint(_optimizer, idx: int) -> None:
)
set_lora_num_tokens(lora_num_tokens)

temperatures = micro_batch["temperatures"].to("cuda")

# Shard temperatures for context parallelism if enabled
if cp_enabled:
temperatures = shard_for_cp(temperatures, cp_rank=cp_rank, cp_world_size=cp_size)
temperature = micro_batch["temperature"]
top_k = micro_batch["top_k"]
top_p = micro_batch["top_p"]
truncate_logits = (top_k > 0) or (top_p < 1.0)
# Fused LM head wants a per-token tensor; CP-sharded input_ids gives the right shape.
temperatures = torch.full(input_ids.shape, temperature, dtype=torch.float32, device="cuda")

# Forward pass with per-token temperatures
with maybe_record_function("forward"), maybe_activation_offloading(config.model.ac_offloading):
Expand All @@ -445,9 +447,18 @@ def load_run_checkpoint(_optimizer, idx: int) -> None:
logits = out["logits"]
# Per-token temperature scaling: temperatures is [batch, seq], logits is [batch, seq, vocab]
scaled_logits = logits / temperatures.unsqueeze(-1)
out["logprobs"] = selective_log_softmax(scaled_logits, labels)
# Entropy on the full distribution so it stays comparable across truncated / not.
out["entropy"] = compute_entropy(scaled_logits)
# else: FusedOutputLinear was used - logprobs already computed with per-token temperatures
scaled_logits = apply_top_k_top_p(scaled_logits, top_k, top_p, labels=labels)
out["logprobs"] = selective_log_softmax(scaled_logits, labels)
else:
# FusedOutputLinear streams over vocab chunks and can't find a global top-k threshold.
if truncate_logits:
raise ValueError(
"top_k / top_p truncation requires the vanilla LM head - set "
"model.fused_lm_head_token_chunk_size = 'disabled' or run with top_k=-1 "
"and top_p=1.0."
)

if cp_enabled:
out["logprobs"] = gather_for_cp(out["logprobs"], cp_group)
Expand Down
12 changes: 10 additions & 2 deletions src/prime_rl/transport/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@ class TrainingSample(msgspec.Struct, array_like=True, gc=False, omit_defaults=Tr
completion_ids: list[int]
completion_mask: list[bool]
completion_logprobs: list[float]
completion_temperatures: list[float] # Per-token temperatures used during generation
temperature: float
env_name: str
# ``top_k = -1`` and ``top_p = 1.0`` disable truncation.
top_k: int = -1
top_p: float = 1.0
teacher_logprobs: list[float] | None = None
advantage: float | None = None
reward: float | None = None
Expand Down Expand Up @@ -66,8 +69,13 @@ class MicroBatch(msgspec.Struct, array_like=True, gc=False, omit_defaults=True):
advantages: list[float]
inference_logprobs: list[float]
position_ids: list[int]
temperatures: list[float] # Per-token temperatures used during generation
# Scalar sampling args shared by every sample in the micro batch (the
# packer enforces uniformity). The trainer replays the same truncation
# when computing logprobs so the importance ratio stays unbiased.
temperature: float
env_names: list[str]
top_k: int = -1
top_p: float = 1.0
teacher_logprobs: list[float] | None = None
lora_num_tokens: list[int] | None = None
routed_experts: list[list[list[int]]] | None = None
Expand Down
Loading
Loading