Skip to content

Train against raw policy logprobs#2604

Draft
samsja wants to merge 1 commit into
mainfrom
fix/raw-logprobs-policy
Draft

Train against raw policy logprobs#2604
samsja wants to merge 1 commit into
mainfrom
fix/raw-logprobs-policy

Conversation

@samsja
Copy link
Copy Markdown
Member

@samsja samsja commented May 23, 2026

Summary

  • force vLLM to return raw_logprobs
  • stop replaying sampling temperature in trainer policy logprob/entropy computation
  • keep fused LM-head signatures compatible while computing raw policy logprobs, and update fused/vanilla tests accordingly

Validation

  • uv run ruff check packages/prime-rl-configs/src/prime_rl/configs/inference.py src/prime_rl/trainer/rl/train.py src/prime_rl/trainer/models/layers/lm_head.py src/prime_rl/trainer/models/layers/lm_head_gemma.py tests/unit/train/rl/test_fused_lm_head.py tests/unit/train/models/test_nemotron_h_kl.py
  • uv run pytest tests/unit/train/rl/test_fused_lm_head.py -q -m 'not gpu'
  • uv run pytest tests/unit/train/rl/test_fused_lm_head.py tests/unit/train/models/test_nemotron_h_kl.py -q -m gpu
  • uv run python - <<'PY' ... InferenceConfig().to_vllm().logprobs_mode returned raw_logprobs

Reverse-text temp=1.5 smoke was attempted but blocked on this workstation before a valid run completed: the pinned vLLM wheel requires libcudart.so.13, and with a temporary CUDA 13 runtime the local NVIDIA 535.171.04 driver failed with CUDA driver version is insufficient for CUDA runtime version.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant