-
Notifications
You must be signed in to change notification settings - Fork 469
Refactors PolicyTrainerRayProcess so that the train method uses a dict of data
#1240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
PolicyRayTrainActor so that the train method uses a dict of dataPolicyTrainerRayProcess so that the train method uses a dict of data
Documentation Changes Detected📄
|
Documentation Changes Detected📄
|
Documentation Changes Detected📄
|
Documentation Changes Detected📄
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
hamishivi
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Generally LGTM, one nit pick around using dicts to pass stuff around.
- Rename queue_types.py to data_types.py (avoiding stdlib conflict) - Add CollatedBatchData dataclass with __getitem__ and __len__ methods - Update grpo_fast.py to use dataclass instead of dict - Update all imports across affected files 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
This lets us clean it up by using comprehensions and loops instead of making repeated calls.
Also switches the variables in
trainto use Shazeer style shape suffixes.Runs:
Note
Refactors GRPO training to use a typed CollatedBatchData instead of dicts, updates the trainer/train flow and metrics, and migrates PromptRequest/Result types from queue_types to data_types across code and tests.
CollatedBatchDatathroughout (prepare_collated_data_for_workers,compute_logprobs,calculate_token_counts,train).trainsignature totrain(data_BT, pad_token_id); derive minibatching fromargs.num_mini_batchesand operate with shape-suffix variables (e.g.,*_BT).one_training_stepto passdata_BT.CollatedBatchDatadataclass with tensor lists, slicing, and length helpers.PromptRequest,GenerationResult,RequestInfo,TokenStatisticshere (used repo-wide).CollatedBatchData(field checks, iteration) and adjust imports.open_instruct.queue_typestoopen_instruct.data_typesin code and tests, includingvllm_utils.pyand benchmarking script.Written by Cursor Bugbot for commit 9f14736. This will update automatically on new commits. Configure here.