diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index 1507f96079..f4cc4fb1c1 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -42,8 +42,14 @@ class ModelConfig(BaseModelConfig): class TrainSamplingConfig(BaseConfig): - temperature: float = Field(1.0, ge=0) - """Sampling temperature.""" + temperature: float = Field(1.0, gt=0) + """Sampling temperature. Must be strictly positive — the trainer divides logits by it.""" + + top_p: float = Field(1.0, gt=0, le=1) + """Nucleus sampling threshold. 1.0 disables. When < 1.0, the trainer replays the same truncation when computing logprobs so the importance ratio stays unbiased.""" + + top_k: int = Field(-1, ge=-1) + """Top-k sampling. -1 disables. When > 0, the trainer replays the same truncation when computing logprobs so the importance ratio stays unbiased.""" repetition_penalty: float = Field(1.0, ge=0) """Repetition penalty. Values > 1.0 discourage repetition, < 1.0 encourage it, 1.0 disables.""" @@ -69,7 +75,7 @@ def to_sampling_args(self) -> dict[str, Any]: # Top-level OAI params args: dict[str, Any] = { "temperature": self.temperature, - "top_p": 1.0, + "top_p": self.top_p, "logprobs": True, } if self.max_completion_tokens is not None: @@ -79,6 +85,7 @@ def to_sampling_args(self) -> dict[str, Any]: # vLLM extra_body params extra_body = dict(self.extra_body) + extra_body.setdefault("top_k", self.top_k) if self.min_tokens > 0: extra_body["min_tokens"] = self.min_tokens if self.repetition_penalty != 1.0: @@ -954,7 +961,6 @@ def resolve_env_config(self): for env in self.train.env: env.extra_env_kwargs.update(max_seq_len=self.seq_len) if is_vllm: - env.sampling.extra_body.setdefault("top_k", -1) env.sampling.extra_body.setdefault("min_p", 0.0) env.sampling.extra_body.setdefault("return_token_ids", True) return self diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index a1e0ff9001..c37393b998 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -230,8 +230,10 @@ def interleave_rollout( return None has_error = output["error"] is not None - # this field should be guaranteed because we set temperature in get_sampling_args - temperature = output["sampling_args"]["temperature"] + sampling_args = output["sampling_args"] + temperature = sampling_args["temperature"] + top_p = sampling_args["top_p"] + top_k = sampling_args["extra_body"]["top_k"] def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any] | None: tokens = step["tokens"] @@ -279,7 +281,9 @@ def make_sample(tokens: dict[str, Any]) -> TrainingSample: completion_ids=completion_ids, completion_mask=completion_mask, completion_logprobs=list(tokens["completion_logprobs"]), - completion_temperatures=[temperature] * len(completion_ids), + temperature=float(temperature), + top_k=int(top_k), + top_p=float(top_p), teacher_logprobs=None, advantage=None, env_name=output["env_name"], @@ -296,7 +300,6 @@ def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> Non sample.completion_ids.extend(new_prompt_ids) sample.completion_mask.extend([False] * len(new_prompt_ids)) sample.completion_logprobs.extend([0.0] * len(new_prompt_ids)) - sample.completion_temperatures.extend([temperature] * len(new_prompt_ids)) # Extend with new completion tokens completion_ids = tokens["completion_ids"] @@ -306,7 +309,6 @@ def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> Non else: sample.completion_mask.extend(bool(i) for i in tokens["completion_mask"]) sample.completion_logprobs.extend(tokens["completion_logprobs"]) - sample.completion_temperatures.extend([temperature] * len(completion_ids)) if tokens.get("routed_experts") is not None and sample.routed_experts is not None: step_routed = tokens["routed_experts"] diff --git a/src/prime_rl/trainer/batch.py b/src/prime_rl/trainer/batch.py index 9db4aefd74..221f8bb515 100644 --- a/src/prime_rl/trainer/batch.py +++ b/src/prime_rl/trainer/batch.py @@ -19,11 +19,6 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch assert training_example.env_name != "all", "env_name='all' is reserved for aggregate metric keys" env_names = [training_example.env_name] * len(input_ids) - # Per-token temperatures: prompt tokens use first completion temp (masked out anyway) - # Default to 1.0 if completion is empty (e.g., model generated only tool calls with no text) - prompt_temp = training_example.completion_temperatures[0] if training_example.completion_temperatures else 1.0 - temperatures = [prompt_temp] * len(training_example.prompt_ids) + training_example.completion_temperatures - # Teacher logprobs already cover the full sequence (prompt + completion), # computed via prefill in the orchestrator when a teacher model is configured teacher_logprobs = training_example.teacher_logprobs @@ -36,7 +31,6 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch position_ids = position_ids[:seq_len] advantages = advantages[:seq_len] rewards = rewards[:seq_len] - temperatures = temperatures[:seq_len] if teacher_logprobs is not None: teacher_logprobs = teacher_logprobs[:seq_len] if routed_experts is not None: @@ -52,9 +46,8 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch == len(position_ids) == len(inference_logprobs) == len(rewards) - == len(temperatures) ), ( - f"input_ids: {len(input_ids)}, advantages: {len(advantages)}, loss_mask: {len(loss_mask)}, position_ids: {len(position_ids)}, inference_logprobs: {len(inference_logprobs)}, rewards: {len(rewards)}, temperatures: {len(temperatures)}" + f"input_ids: {len(input_ids)}, advantages: {len(advantages)}, loss_mask: {len(loss_mask)}, position_ids: {len(position_ids)}, inference_logprobs: {len(inference_logprobs)}, rewards: {len(rewards)}" ) if teacher_logprobs is not None: assert len(teacher_logprobs) == len(input_ids), f"teacher_logprobs: {len(teacher_logprobs)}" @@ -77,7 +70,9 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch position_ids=position_ids, inference_logprobs=inference_logprobs, teacher_logprobs=teacher_logprobs, - temperatures=temperatures, + temperature=training_example.temperature, + top_k=training_example.top_k, + top_p=training_example.top_p, rewards=rewards, routed_experts=routed_experts, mm_token_type_ids=mm_token_type_ids, @@ -98,7 +93,6 @@ def packed_samples_into_micro_bs( """ Pack samples into micro_batch efficiently. We follow the First Fit Decreasing algorithm to pack the samples into bins and minimize potential padding while never truncating. - With per-token temperatures, samples can be packed together regardless of their temperature values. NOTE: Multimodal samples (with mm_kwargs) are NOT packed together as they have variable-sized vision data that doesn't pack well. Each multimodal sample becomes its own micro batch. @@ -122,10 +116,12 @@ def packed_samples_into_micro_bs( # Don't pack into multimodal micro batches if _is_multimodal_sample(bin_content): continue - # Check if sequence fits in this bin if ( len(bin_content.input_ids) + len(sample.input_ids) <= max_seq_len and bin_content.training_mode == sample.training_mode + and bin_content.temperature == sample.temperature + and bin_content.top_k == sample.top_k + and bin_content.top_p == sample.top_p ): existing_len = len(bin_content.input_ids) bin_content.input_ids.extend(sample.input_ids) @@ -138,7 +134,6 @@ def packed_samples_into_micro_bs( elif bin_content.rewards is not None: bin_content.rewards.extend([float("nan")] * len(sample.input_ids)) bin_content.inference_logprobs.extend(sample.inference_logprobs) - bin_content.temperatures.extend(sample.temperatures) if sample.teacher_logprobs is not None: if bin_content.teacher_logprobs is None: bin_content.teacher_logprobs = [] @@ -192,8 +187,6 @@ def pad_micro_batch(micro_batch: MicroBatch, pad_to_multiple_of: int) -> MicroBa micro_batch.loss_mask.extend([False] * padding_size) micro_batch.position_ids.extend(list(range(padding_size))) micro_batch.inference_logprobs.extend([0.0] * padding_size) - # Use temperature 1.0 for padding tokens (doesn't matter since loss_mask is False) - micro_batch.temperatures.extend([1.0] * padding_size) if micro_batch.teacher_logprobs is not None: micro_batch.teacher_logprobs.extend([0.0] * padding_size) micro_batch.lora_num_tokens[-1] += ( diff --git a/src/prime_rl/trainer/rl/data.py b/src/prime_rl/trainer/rl/data.py index 73e35159af..8dfb0edf93 100644 --- a/src/prime_rl/trainer/rl/data.py +++ b/src/prime_rl/trainer/rl/data.py @@ -24,7 +24,9 @@ class TensorMicroBatch(TypedDict): inference_logprobs: Float[Tensor, "batch seq"] teacher_logprobs: Float[Tensor, "batch seq"] | None loss_mask: Bool[Tensor, "batch seq"] - temperatures: Float[Tensor, "batch seq"] # Per-token temperatures + temperature: float + top_k: int # -1 disables truncation + top_p: float # 1.0 disables truncation env_names: list[str] # Batch level @@ -112,7 +114,9 @@ def _get_sample_micro_batch(self, generator: torch.Generator) -> TensorMicroBatc "rewards": None, "inference_logprobs": inference_logprobs.unsqueeze(0), "teacher_logprobs": None, - "temperatures": torch.ones(input_ids.shape[0]).unsqueeze(0), + "temperature": 1.0, + "top_k": -1, + "top_p": 1.0, "env_names": ["fake"] * input_ids.shape[0], "loss_mask": loss_mask.unsqueeze(0), "lora_num_tokens": lora_num_tokens, @@ -140,7 +144,9 @@ def _get_micro_batch(self, generator: torch.Generator) -> TensorMicroBatch: "rewards": None, "inference_logprobs": torch.randn(self.seq_len, generator=generator).unsqueeze(0), "teacher_logprobs": None, - "temperatures": torch.ones(self.seq_len).unsqueeze(0), + "temperature": 1.0, + "top_k": -1, + "top_p": 1.0, "env_names": ["fake"] * self.seq_len, "loss_mask": torch.ones(self.seq_len, dtype=torch.bool).unsqueeze(0), "lora_num_tokens": lora_num_tokens, @@ -222,7 +228,9 @@ def _micro_batch_to_tensor(self, micro_batch: MicroBatch) -> TensorMicroBatch: if micro_batch.teacher_logprobs is not None else None, loss_mask=torch.tensor(micro_batch.loss_mask, dtype=torch.bool).unsqueeze(0), - temperatures=torch.tensor(micro_batch.temperatures, dtype=torch.float).unsqueeze(0), + temperature=micro_batch.temperature, + top_k=micro_batch.top_k, + top_p=micro_batch.top_p, env_names=micro_batch.env_names, lora_num_tokens=torch.tensor(micro_batch.lora_num_tokens, dtype=torch.int32), mm_kwargs=mm_kwargs, diff --git a/src/prime_rl/trainer/rl/loss.py b/src/prime_rl/trainer/rl/loss.py index 9a9eb25a63..f8c7dcd3ea 100644 --- a/src/prime_rl/trainer/rl/loss.py +++ b/src/prime_rl/trainer/rl/loss.py @@ -47,6 +47,52 @@ def selective_log_softmax( return torch.gather(logprobs, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) +def apply_top_k_top_p( + logits: Tensor, + top_k: int | None, + top_p: float | None, + labels: Tensor | None = None, +) -> Tensor: + """Mirror vLLM's top-k / top-p truncation so the trainer logits sum over + the same support as the inference-time sample. Bit-exact match against + vLLM's ``apply_top_k_top_p_pytorch``. Scalar k / p; ``k <= 0`` and + ``p >= 1.0`` are no-ops. + + ``labels`` is the safety guard: restore the label's original logit after + masking so FP-precision boundary cases (the sampled token falling just + outside the trainer's top-k) don't push its logprob to ``-inf``. Also + keeps prompt / padding positions finite when the scalar truncation is + applied uniformly. + """ + vocab_size = logits.shape[-1] + do_top_k = top_k is not None and 0 < top_k < vocab_size + do_top_p = top_p is not None and top_p < 1.0 + if not do_top_k and not do_top_p: + return logits + + label_logits = logits.gather(-1, labels.unsqueeze(-1)) if labels is not None else None + + if do_top_k: + top_values, _ = logits.topk(top_k, dim=-1) + threshold = top_values[..., -1:] + logits = logits.masked_fill(logits < threshold, float("-inf")) + + if do_top_p: + # Sort ascending so the low-probability tail is at the front (mirrors vLLM). + sorted_logits, sorted_idx = logits.sort(dim=-1, descending=False) + sorted_probs = sorted_logits.softmax(dim=-1) + cumprobs = sorted_probs.cumsum(dim=-1) + top_p_mask = cumprobs <= (1.0 - top_p) + top_p_mask[..., -1] = False # always keep the top token + sorted_logits = sorted_logits.masked_fill(top_p_mask, float("-inf")) + logits = torch.empty_like(logits).scatter_(-1, sorted_idx, sorted_logits) + + if label_logits is not None: + logits = logits.scatter(-1, labels.unsqueeze(-1), label_logits) + + return logits + + @jaxtyped(typechecker=typechecker) @torch.compile(dynamic=True) def compute_entropy(shifted_logits: Float[Tensor, "batch seq vocab"]) -> Float[Tensor, "batch seq"]: diff --git a/src/prime_rl/trainer/rl/packer.py b/src/prime_rl/trainer/rl/packer.py index cf9dcfa02e..40b1ca972f 100644 --- a/src/prime_rl/trainer/rl/packer.py +++ b/src/prime_rl/trainer/rl/packer.py @@ -164,11 +164,6 @@ def _validate_sample(self, sample: TrainingSample) -> tuple[bool, str | None]: False, f"Run wrote a sample with completion logprobs length != completion ids length ({len(sample.completion_logprobs)} != {len(sample.completion_ids)})", ) - if len(sample.completion_temperatures) != len(sample.completion_ids): - return ( - False, - f"Run wrote a sample with completion temperatures length != completion ids length ({len(sample.completion_temperatures)} != {len(sample.completion_ids)})", - ) if sample_length == 0: return False, "Run wrote a sample with no tokens" if sample_length > self.seq_len: diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index d359bae5e7..07a813993b 100644 --- a/src/prime_rl/trainer/rl/train.py +++ b/src/prime_rl/trainer/rl/train.py @@ -27,6 +27,7 @@ ) from prime_rl.utils.logger import setup_logger from prime_rl.trainer.rl.loss import ( + apply_top_k_top_p, compute_entropy, compute_loss, compute_importance_ratio_and_mismatch_kl, @@ -420,11 +421,12 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: ) set_lora_num_tokens(lora_num_tokens) - temperatures = micro_batch["temperatures"].to("cuda") - - # Shard temperatures for context parallelism if enabled - if cp_enabled: - temperatures = shard_for_cp(temperatures, cp_rank=cp_rank, cp_world_size=cp_size) + temperature = micro_batch["temperature"] + top_k = micro_batch["top_k"] + top_p = micro_batch["top_p"] + truncate_logits = (top_k > 0) or (top_p < 1.0) + # Fused LM head wants a per-token tensor; CP-sharded input_ids gives the right shape. + temperatures = torch.full(input_ids.shape, temperature, dtype=torch.float32, device="cuda") # Forward pass with per-token temperatures with maybe_record_function("forward"), maybe_activation_offloading(config.model.ac_offloading): @@ -445,9 +447,18 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: logits = out["logits"] # Per-token temperature scaling: temperatures is [batch, seq], logits is [batch, seq, vocab] scaled_logits = logits / temperatures.unsqueeze(-1) - out["logprobs"] = selective_log_softmax(scaled_logits, labels) + # Entropy on the full distribution so it stays comparable across truncated / not. out["entropy"] = compute_entropy(scaled_logits) - # else: FusedOutputLinear was used - logprobs already computed with per-token temperatures + scaled_logits = apply_top_k_top_p(scaled_logits, top_k, top_p, labels=labels) + out["logprobs"] = selective_log_softmax(scaled_logits, labels) + else: + # FusedOutputLinear streams over vocab chunks and can't find a global top-k threshold. + if truncate_logits: + raise ValueError( + "top_k / top_p truncation requires the vanilla LM head - set " + "model.fused_lm_head_token_chunk_size = 'disabled' or run with top_k=-1 " + "and top_p=1.0." + ) if cp_enabled: out["logprobs"] = gather_for_cp(out["logprobs"], cp_group) diff --git a/src/prime_rl/transport/types.py b/src/prime_rl/transport/types.py index d4c947224f..54185f30a7 100644 --- a/src/prime_rl/transport/types.py +++ b/src/prime_rl/transport/types.py @@ -23,8 +23,11 @@ class TrainingSample(msgspec.Struct, array_like=True, gc=False, omit_defaults=Tr completion_ids: list[int] completion_mask: list[bool] completion_logprobs: list[float] - completion_temperatures: list[float] # Per-token temperatures used during generation + temperature: float env_name: str + # ``top_k = -1`` and ``top_p = 1.0`` disable truncation. + top_k: int = -1 + top_p: float = 1.0 teacher_logprobs: list[float] | None = None advantage: float | None = None reward: float | None = None @@ -66,8 +69,13 @@ class MicroBatch(msgspec.Struct, array_like=True, gc=False, omit_defaults=True): advantages: list[float] inference_logprobs: list[float] position_ids: list[int] - temperatures: list[float] # Per-token temperatures used during generation + # Scalar sampling args shared by every sample in the micro batch (the + # packer enforces uniformity). The trainer replays the same truncation + # when computing logprobs so the importance ratio stays unbiased. + temperature: float env_names: list[str] + top_k: int = -1 + top_p: float = 1.0 teacher_logprobs: list[float] | None = None lora_num_tokens: list[int] | None = None routed_experts: list[list[list[int]]] | None = None diff --git a/tests/unit/orchestrator/test_batch.py b/tests/unit/orchestrator/test_batch.py index e01089ccf4..7a7942c6df 100644 --- a/tests/unit/orchestrator/test_batch.py +++ b/tests/unit/orchestrator/test_batch.py @@ -17,7 +17,7 @@ def _make_training_example( completion_ids=[3, 4], completion_mask=[True, True], completion_logprobs=[-0.1, -0.2], - completion_temperatures=[temperature, temperature], # Per-token temperatures + temperature=temperature, teacher_logprobs=[0.0, 0.0, 0.0, 0.0], advantage=1.0, env_name=env_name, @@ -35,7 +35,7 @@ def test_training_sample_requires_env_name(): completion_ids=[3, 4], completion_mask=[True, True], completion_logprobs=[-0.1, -0.2], - completion_temperatures=[1.0, 1.0], + temperature=1.0, advantage=1.0, ) @@ -74,8 +74,9 @@ def test_prepare_batch_balances_micro_batches_across_workers( assert sum(1 for loss_mask in batch.loss_mask if loss_mask) == 0 -def test_prepare_batch_packs_different_temperatures(make_training_example): - """With per-token temperatures, samples can be packed together regardless of their temperature values.""" +def test_prepare_batch_does_not_pack_different_temperatures(make_training_example): + """Temperature is a scalar per micro batch, so samples with different + temperatures land in separate micro batches even when they would fit.""" example1 = make_training_example(temperature=0.7, env_name="env-a") example2 = make_training_example(temperature=1.1, env_name="env-b") @@ -88,15 +89,11 @@ def test_prepare_batch_packs_different_temperatures(make_training_example): ) flat_batches = [batch for worker_batches in batches_per_gpu for batch in worker_batches] - # With per-token temperatures, samples can now be packed together - assert len(flat_batches) == 1 - # Each sample has 4 tokens (2 prompt + 2 completion), so 8 total tokens - assert len(flat_batches[0].temperatures) == 8 - # First sample (4 tokens): all get temp 0.7 - assert flat_batches[0].temperatures[:4] == [0.7, 0.7, 0.7, 0.7] - # Second sample (4 tokens): all get temp 1.1 - assert flat_batches[0].temperatures[4:8] == [1.1, 1.1, 1.1, 1.1] - assert flat_batches[0].env_names == ["env-a"] * 4 + ["env-b"] * 4 + assert len(flat_batches) == 2 + assert {b.temperature for b in flat_batches} == {0.7, 1.1} + # Each micro batch holds exactly one sample (4 tokens: 2 prompt + 2 completion) + for batch in flat_batches: + assert len(batch.input_ids) == 4 def test_prepare_sample_propagates_training_mode(make_training_example): @@ -134,7 +131,7 @@ def test_prepare_sample_with_routed_experts(): completion_ids=[3, 4], completion_mask=[True, True], completion_logprobs=[-0.1, -0.2], - completion_temperatures=[1.0, 1.0], + temperature=1.0, advantage=1.0, env_name="test-env", routed_experts=routed_experts, @@ -155,7 +152,7 @@ def test_prepare_sample_truncates_routed_experts(): completion_ids=[3, 4], completion_mask=[True, True], completion_logprobs=[-0.1, -0.2], - completion_temperatures=[1.0, 1.0], + temperature=1.0, advantage=1.0, env_name="test-env", routed_experts=routed_experts, @@ -176,7 +173,7 @@ def test_prepare_sample_none_routed_experts(): completion_ids=[3, 4], completion_mask=[True, True], completion_logprobs=[-0.1, -0.2], - completion_temperatures=[1.0, 1.0], + temperature=1.0, advantage=1.0, env_name="test-env", ) diff --git a/tests/unit/orchestrator/test_sft_trajectories.py b/tests/unit/orchestrator/test_sft_trajectories.py index a65456b674..5dd0b362e5 100644 --- a/tests/unit/orchestrator/test_sft_trajectories.py +++ b/tests/unit/orchestrator/test_sft_trajectories.py @@ -49,7 +49,7 @@ def test_interleave_rollout_missing_tokens_returns_none(): extras={}, ) ], - sampling_args={"temperature": 1.0}, + sampling_args={"temperature": 1.0, "top_p": 1.0, "extra_body": {"top_k": -1}}, error=None, ) @@ -89,7 +89,7 @@ def test_pretokenize_rollout_trajectory_for_sft(): extras={}, ), ], - sampling_args={"temperature": 1.0}, + sampling_args={"temperature": 1.0, "top_p": 1.0, "extra_body": {"top_k": -1}}, error=None, ) diff --git a/tests/unit/orchestrator/test_teacher_logprobs.py b/tests/unit/orchestrator/test_teacher_logprobs.py index d63fdce792..74900bd553 100644 --- a/tests/unit/orchestrator/test_teacher_logprobs.py +++ b/tests/unit/orchestrator/test_teacher_logprobs.py @@ -49,7 +49,7 @@ async def _run(): completion_ids=[2, 3], completion_mask=[True, True], completion_logprobs=[-0.1, -0.2], - completion_temperatures=[1.0, 1.0], + temperature=1.0, env_name="test-env", ) diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py index c29e80976c..60cd959505 100644 --- a/tests/unit/orchestrator/test_trajectories.py +++ b/tests/unit/orchestrator/test_trajectories.py @@ -82,7 +82,7 @@ def single_step_trajectory_output(): extras={}, ) ], - sampling_args={"temperature": 1.0}, + sampling_args={"temperature": 1.0, "top_p": 1.0, "extra_body": {"top_k": -1}}, error=None, ) return output @@ -136,7 +136,7 @@ def multi_step_trajectory_output(): extras={}, ), ], - sampling_args={"temperature": 1.0}, + sampling_args={"temperature": 1.0, "top_p": 1.0, "extra_body": {"top_k": -1}}, error=None, ) return output @@ -194,7 +194,7 @@ def multi_step_trajectory_with_tool_calls_output(): advantage=None, stop_condition=None, metrics={"has_error": 0.0, "tool_calls": 1.0}, - sampling_args={"temperature": 1.0}, + sampling_args={"temperature": 1.0, "top_p": 1.0, "extra_body": {"top_k": -1}}, error=None, ) return output @@ -253,7 +253,7 @@ def multi_step_trajectory_extension_never_holds(): extras={}, ), ], - sampling_args={"temperature": 1.0}, + sampling_args={"temperature": 1.0, "top_p": 1.0, "extra_body": {"top_k": -1}}, error=None, ) return output @@ -312,7 +312,7 @@ def multi_step_trajectory_with_tool_calls_extension_never_holds(): reward=1.0, advantage=None, stop_condition=None, - sampling_args={"temperature": 1.0}, + sampling_args={"temperature": 1.0, "top_p": 1.0, "extra_body": {"top_k": -1}}, metrics={"has_error": 0.0, "tool_calls": 1.0}, error=None, ) @@ -332,7 +332,7 @@ def test_branching_equivalent_multi_step_trajectory(multi_step_trajectory_extens assert rollout.completion_ids == [3, 4] assert rollout.completion_mask == [True, True] assert rollout.completion_logprobs == [-0.1, -0.2] - assert rollout.completion_temperatures == [1.0, 1.0] + assert rollout.temperature == 1.0 # second step rollout = rollouts[1] @@ -341,7 +341,7 @@ def test_branching_equivalent_multi_step_trajectory(multi_step_trajectory_extens assert rollout.completion_ids == [7, 8] assert rollout.completion_mask == [True, True] assert rollout.completion_logprobs == [-0.3, -0.4] - assert rollout.completion_temperatures == [1.0, 1.0] + assert rollout.temperature == 1.0 def test_branching_equivalent_multi_step_trajectory_with_tool_calls( @@ -359,7 +359,7 @@ def test_branching_equivalent_multi_step_trajectory_with_tool_calls( assert rollout.completion_ids == [3, 4] assert rollout.completion_mask == [True, True] assert rollout.completion_logprobs == [-0.1, -0.2] - assert rollout.completion_temperatures == [1.0, 1.0] + assert rollout.temperature == 1.0 # second step rollout = rollouts[1] @@ -368,7 +368,7 @@ def test_branching_equivalent_multi_step_trajectory_with_tool_calls( assert rollout.completion_ids == [7, 8] assert rollout.completion_mask == [True, True] assert rollout.completion_logprobs == [-0.3, -0.4] - assert rollout.completion_temperatures == [1.0, 1.0] + assert rollout.temperature == 1.0 def test_interleave_rollout_single_step_trajectory(single_step_trajectory_output): @@ -383,7 +383,7 @@ def test_interleave_rollout_single_step_trajectory(single_step_trajectory_output assert rollout.completion_ids == [3, 4] assert rollout.completion_mask == [True, True] assert rollout.completion_logprobs == [-0.1, -0.2] - assert rollout.completion_temperatures == [1.0, 1.0] + assert rollout.temperature == 1.0 assert rollout.env_name == "test-env" @@ -399,7 +399,7 @@ def test_interleave_rollout_multi_step_trajectory(multi_step_trajectory_output): assert rollout.completion_mask == [True, True, False, False, True, True] assert rollout.completion_logprobs == [-0.1, -0.2, 0, 0, -0.3, -0.4] # Temperatures: 2 completion tokens at temp 1.0, then 2 prompt tokens at temp 1.0, then 2 completion tokens at temp 1.0 - assert rollout.completion_temperatures == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + assert rollout.temperature == 1.0 def test_interleave_rollout_multi_step_trajectory_with_tool_calls(multi_step_trajectory_with_tool_calls_output): @@ -414,7 +414,7 @@ def test_interleave_rollout_multi_step_trajectory_with_tool_calls(multi_step_tra assert rollout.completion_mask == [True, True, False, False, True, True] assert rollout.completion_logprobs == [-0.1, -0.2, 0, 0, -0.3, -0.4] # Temperatures: 2 completion tokens at temp 1.0, then 2 prompt tokens at temp 1.0, then 2 completion tokens at temp 1.0 - assert rollout.completion_temperatures == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + assert rollout.temperature == 1.0 @pytest.fixture @@ -560,7 +560,7 @@ def five_step_trajectory_with_extension_break(): trajectory_id="1", ), ], - sampling_args={"temperature": 1.0}, + sampling_args={"temperature": 1.0, "top_p": 1.0, "extra_body": {"top_k": -1}}, error=None, ) return output @@ -697,7 +697,7 @@ def interleaved_agents_trajectory(): extras={}, ), ], - sampling_args={"temperature": 1.0}, + sampling_args={"temperature": 1.0, "top_p": 1.0, "extra_body": {"top_k": -1}}, error=None, ) return output @@ -790,7 +790,7 @@ def test_interleave_rollout_error_masks_all_false(): ), ], error="timeout: environment exceeded time limit", - sampling_args={"temperature": 0.8}, + sampling_args={"temperature": 0.8, "top_p": 1.0, "extra_body": {"top_k": -1}}, ) rollouts = interleave_rollout(output) @@ -803,7 +803,7 @@ def test_interleave_rollout_error_masks_all_false(): assert rollout.completion_mask == [False, False, False, False, False, False] # Logprobs and temperatures still present assert rollout.completion_logprobs == [-0.1, -0.2, 0.0, 0.0, -0.3, -0.4] - assert rollout.completion_temperatures == [0.8] * 6 + assert rollout.temperature == 0.8 def test_align_routed_experts_none(): @@ -869,7 +869,7 @@ def test_interleave_rollout_single_step_with_routed_experts(): extras={}, ) ], - sampling_args={"temperature": 1.0}, + sampling_args={"temperature": 1.0, "top_p": 1.0, "extra_body": {"top_k": -1}}, error=None, ) @@ -941,7 +941,7 @@ def test_interleave_rollout_multi_step_with_routed_experts(): extras={}, ), ], - sampling_args={"temperature": 1.0}, + sampling_args={"temperature": 1.0, "top_p": 1.0, "extra_body": {"top_k": -1}}, error=None, ) @@ -982,7 +982,7 @@ def test_interleave_rollout_none_routed_experts_stays_none(): extras={}, ) ], - sampling_args={"temperature": 1.0}, + sampling_args={"temperature": 1.0, "top_p": 1.0, "extra_body": {"top_k": -1}}, error=None, ) @@ -1075,7 +1075,7 @@ def test_interleave_rollout_packs_pixels_from_renderer_mm_data(): extras={}, ), ], - sampling_args={"temperature": 1.0}, + sampling_args={"temperature": 1.0, "top_p": 1.0, "extra_body": {"top_k": -1}}, error=None, ) diff --git a/tests/unit/train/rl/test_packer.py b/tests/unit/train/rl/test_packer.py index 7068e0665a..ae6eedf1a8 100644 --- a/tests/unit/train/rl/test_packer.py +++ b/tests/unit/train/rl/test_packer.py @@ -47,7 +47,7 @@ def make_training_sample() -> TrainingSample: completion_ids=[2], completion_mask=[True], completion_logprobs=[-0.1], - completion_temperatures=[1.0], + temperature=1.0, env_name="test-env", )