Skip to content

Conversation

@finbarrtimbers
Copy link
Collaborator

@finbarrtimbers finbarrtimbers commented Dec 1, 2025

This lets us clean it up by using comprehensions and loops instead of making repeated calls.

Also switches the variables in train to use Shazeer style shape suffixes.

Runs:

  • Single GPU GRPO: Beaker
  • Single GPU GRPO with tools: Beaker

Note

Refactors GRPO training to use a typed CollatedBatchData instead of dicts, updates the trainer/train flow and metrics, and migrates PromptRequest/Result types from queue_types to data_types across code and tests.

  • RL/Trainer (grpo_fast.py):
    • Replace dict-based batch inputs with typed CollatedBatchData throughout (prepare_collated_data_for_workers, compute_logprobs, calculate_token_counts, train).
    • Simplify train signature to train(data_BT, pad_token_id); derive minibatching from args.num_mini_batches and operate with shape-suffix variables (e.g., *_BT).
    • Rework logprob/ratio/PG/KL computations and metric aggregation into structured tensors; update one_training_step to pass data_BT.
  • Data Model (data_types.py):
    • Add CollatedBatchData dataclass with tensor lists, slicing, and length helpers.
    • Centralize PromptRequest, GenerationResult, RequestInfo, TokenStatistics here (used repo-wide).
  • Tests:
    • Update expectations to CollatedBatchData (field checks, iteration) and adjust imports.
  • Misc:
    • Migrate imports from open_instruct.queue_types to open_instruct.data_types in code and tests, including vllm_utils.py and benchmarking script.

Written by Cursor Bugbot for commit 9f14736. This will update automatically on new commits. Configure here.

@finbarrtimbers finbarrtimbers changed the title Refactors PolicyRayTrainActor so that the train method uses a dict of data Refactors PolicyTrainerRayProcess so that the train method uses a dict of data Dec 1, 2025
@github-actions
Copy link
Contributor

github-actions bot commented Dec 1, 2025

Documentation Changes Detected

📄 olmo2/index.html
--- site-base/olmo2/index.html	2025-12-01 16:29:10.215190724 +0000
+++ site-pr/olmo2/index.html	2025-12-01 16:29:06.881200320 +0000
@@ -990,7 +990,7 @@
     --local_mini_batch_size 32 \
     --number_samples_per_prompt 16 \
     --local_rollout_batch_size 4 \
-    --kl_estimator kl3 \
+    --kl_estimator 2 \
     --learning_rate 5e-7 \
     --dataset_mixer_list allenai/RLVR-GSM-MATH-IF-Mixed-Constraints 1.0 \
📄 tulu3/index.html
--- site-base/tulu3/index.html	2025-12-01 16:29:10.215190724 +0000
+++ site-pr/tulu3/index.html	2025-12-01 16:29:06.881200320 +0000
@@ -1260,7 +1260,7 @@
 <span class="k">for</span><span class="w"> </span>beta<span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="m">0</span>.01<span class="p">;</span><span class="w"> </span><span class="k">do</span>
 <span class="k">for</span><span class="w"> </span>nspp<span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="m">16</span><span class="p">;</span><span class="w"> </span><span class="k">do</span>
 <span class="k">for</span><span class="w"> </span>m<span class="w"> </span><span class="k">in</span><span class="w"> </span>half-m<span class="w"> </span><span class="p">;</span><span class="w"> </span><span class="k">do</span>
-<span class="k">for</span><span class="w"> </span>kl_estimator<span class="w"> </span><span class="k">in</span><span class="w"> </span>kl3<span class="p">;</span><span class="w"> </span><span class="k">do</span>
+<span class="k">for</span><span class="w"> </span>kl_estimator<span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="m">2</span><span class="p">;</span><span class="w"> </span><span class="k">do</span>
 <span class="nv">local_rollout_batch_size</span><span class="o">=</span><span class="m">4</span>
 <span class="k">if</span><span class="w"> </span><span class="o">[</span><span class="w"> </span><span class="nv">$m</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="s2">&quot;half-m&quot;</span><span class="w"> </span><span class="o">]</span><span class="p">;</span><span class="w"> </span><span class="k">then</span>

Showing first 10 lines of diff for each changed file (up to 5 files, excluding search indices).

@github-actions
Copy link
Contributor

github-actions bot commented Dec 1, 2025

Documentation Changes Detected

📄 olmo2/index.html
--- site-base/olmo2/index.html	2025-12-01 16:52:05.401483293 +0000
+++ site-pr/olmo2/index.html	2025-12-01 16:52:02.907489884 +0000
@@ -990,7 +990,7 @@
     --local_mini_batch_size 32 \
     --number_samples_per_prompt 16 \
     --local_rollout_batch_size 4 \
-    --kl_estimator kl3 \
+    --kl_estimator 2 \
     --learning_rate 5e-7 \
     --dataset_mixer_list allenai/RLVR-GSM-MATH-IF-Mixed-Constraints 1.0 \
📄 tulu3/index.html
--- site-base/tulu3/index.html	2025-12-01 16:52:05.401483293 +0000
+++ site-pr/tulu3/index.html	2025-12-01 16:52:02.907489884 +0000
@@ -1260,7 +1260,7 @@
 <span class="k">for</span><span class="w"> </span>beta<span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="m">0</span>.01<span class="p">;</span><span class="w"> </span><span class="k">do</span>
 <span class="k">for</span><span class="w"> </span>nspp<span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="m">16</span><span class="p">;</span><span class="w"> </span><span class="k">do</span>
 <span class="k">for</span><span class="w"> </span>m<span class="w"> </span><span class="k">in</span><span class="w"> </span>half-m<span class="w"> </span><span class="p">;</span><span class="w"> </span><span class="k">do</span>
-<span class="k">for</span><span class="w"> </span>kl_estimator<span class="w"> </span><span class="k">in</span><span class="w"> </span>kl3<span class="p">;</span><span class="w"> </span><span class="k">do</span>
+<span class="k">for</span><span class="w"> </span>kl_estimator<span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="m">2</span><span class="p">;</span><span class="w"> </span><span class="k">do</span>
 <span class="nv">local_rollout_batch_size</span><span class="o">=</span><span class="m">4</span>
 <span class="k">if</span><span class="w"> </span><span class="o">[</span><span class="w"> </span><span class="nv">$m</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="s2">&quot;half-m&quot;</span><span class="w"> </span><span class="o">]</span><span class="p">;</span><span class="w"> </span><span class="k">then</span>

Showing first 10 lines of diff for each changed file (up to 5 files, excluding search indices).

@github-actions
Copy link
Contributor

github-actions bot commented Dec 1, 2025

Documentation Changes Detected

📄 olmo2/index.html
--- site-base/olmo2/index.html	2025-12-01 19:01:45.485977869 +0000
+++ site-pr/olmo2/index.html	2025-12-01 19:01:42.499987013 +0000
@@ -990,7 +990,7 @@
     --local_mini_batch_size 32 \
     --number_samples_per_prompt 16 \
     --local_rollout_batch_size 4 \
-    --kl_estimator kl3 \
+    --kl_estimator 2 \
     --learning_rate 5e-7 \
     --dataset_mixer_list allenai/RLVR-GSM-MATH-IF-Mixed-Constraints 1.0 \
📄 tulu3/index.html
--- site-base/tulu3/index.html	2025-12-01 19:01:45.485977869 +0000
+++ site-pr/tulu3/index.html	2025-12-01 19:01:42.500987010 +0000
@@ -1260,7 +1260,7 @@
 <span class="k">for</span><span class="w"> </span>beta<span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="m">0</span>.01<span class="p">;</span><span class="w"> </span><span class="k">do</span>
 <span class="k">for</span><span class="w"> </span>nspp<span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="m">16</span><span class="p">;</span><span class="w"> </span><span class="k">do</span>
 <span class="k">for</span><span class="w"> </span>m<span class="w"> </span><span class="k">in</span><span class="w"> </span>half-m<span class="w"> </span><span class="p">;</span><span class="w"> </span><span class="k">do</span>
-<span class="k">for</span><span class="w"> </span>kl_estimator<span class="w"> </span><span class="k">in</span><span class="w"> </span>kl3<span class="p">;</span><span class="w"> </span><span class="k">do</span>
+<span class="k">for</span><span class="w"> </span>kl_estimator<span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="m">2</span><span class="p">;</span><span class="w"> </span><span class="k">do</span>
 <span class="nv">local_rollout_batch_size</span><span class="o">=</span><span class="m">4</span>
 <span class="k">if</span><span class="w"> </span><span class="o">[</span><span class="w"> </span><span class="nv">$m</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="s2">&quot;half-m&quot;</span><span class="w"> </span><span class="o">]</span><span class="p">;</span><span class="w"> </span><span class="k">then</span>

Showing first 10 lines of diff for each changed file (up to 5 files, excluding search indices).

@github-actions
Copy link
Contributor

github-actions bot commented Dec 1, 2025

Documentation Changes Detected

📄 olmo2/index.html
--- site-base/olmo2/index.html	2025-12-01 19:04:49.702340373 +0000
+++ site-pr/olmo2/index.html	2025-12-01 19:04:47.382347742 +0000
@@ -990,7 +990,7 @@
     --local_mini_batch_size 32 \
     --number_samples_per_prompt 16 \
     --local_rollout_batch_size 4 \
-    --kl_estimator kl3 \
+    --kl_estimator 2 \
     --learning_rate 5e-7 \
     --dataset_mixer_list allenai/RLVR-GSM-MATH-IF-Mixed-Constraints 1.0 \
📄 tulu3/index.html
--- site-base/tulu3/index.html	2025-12-01 19:04:49.702340373 +0000
+++ site-pr/tulu3/index.html	2025-12-01 19:04:47.382347742 +0000
@@ -1260,7 +1260,7 @@
 <span class="k">for</span><span class="w"> </span>beta<span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="m">0</span>.01<span class="p">;</span><span class="w"> </span><span class="k">do</span>
 <span class="k">for</span><span class="w"> </span>nspp<span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="m">16</span><span class="p">;</span><span class="w"> </span><span class="k">do</span>
 <span class="k">for</span><span class="w"> </span>m<span class="w"> </span><span class="k">in</span><span class="w"> </span>half-m<span class="w"> </span><span class="p">;</span><span class="w"> </span><span class="k">do</span>
-<span class="k">for</span><span class="w"> </span>kl_estimator<span class="w"> </span><span class="k">in</span><span class="w"> </span>kl3<span class="p">;</span><span class="w"> </span><span class="k">do</span>
+<span class="k">for</span><span class="w"> </span>kl_estimator<span class="w"> </span><span class="k">in</span><span class="w"> </span><span class="m">2</span><span class="p">;</span><span class="w"> </span><span class="k">do</span>
 <span class="nv">local_rollout_batch_size</span><span class="o">=</span><span class="m">4</span>
 <span class="k">if</span><span class="w"> </span><span class="o">[</span><span class="w"> </span><span class="nv">$m</span><span class="w"> </span><span class="o">==</span><span class="w"> </span><span class="s2">&quot;half-m&quot;</span><span class="w"> </span><span class="o">]</span><span class="p">;</span><span class="w"> </span><span class="k">then</span>

Showing first 10 lines of diff for each changed file (up to 5 files, excluding search indices).

@finbarrtimbers finbarrtimbers marked this pull request as ready for review December 3, 2025 17:02
Copy link

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Copy link
Collaborator

@hamishivi hamishivi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally LGTM, one nit pick around using dicts to pass stuff around.

- Rename queue_types.py to data_types.py (avoiding stdlib conflict)
- Add CollatedBatchData dataclass with __getitem__ and __len__ methods
- Update grpo_fast.py to use dataclass instead of dict
- Update all imports across affected files

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
@finbarrtimbers finbarrtimbers added this pull request to the merge queue Dec 4, 2025
Merged via the queue into main with commit 79ea62e Dec 4, 2025
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants