Chunk Across Batch and Context length for logprob calculations for grpo #357
base: nightly
Conversation
- Introduced `chunked_hidden_states_selective_log_softmax` for memory efficiency and updated related functions to use it. Removed deprecated sampling-parameter updates and adjusted logit handling in `grpo_compute_loss`.
- Commented out sections of code related to importance sampling and logit processing.
- Added a new function `grpo_update_SamplingParams` to update sampling parameters based on provided arguments. Refactored logit processing to handle chunked inputs and improved clarity in the computation of log probabilities.
- Refactored input handling for pixel values and the image grid, adding pre-calculation of padding and chunking logic.
- Added `max_left_pad` parameter handling and adjusted related logic for padding in the model's input processing.
- Removed unnecessary breakpoints and cleaned up comments.
Summary of Changes

Hello @pluesclues, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces significant architectural changes to how log probabilities are calculated within the GRPO framework, primarily focusing on memory optimization. By implementing a chunked approach for logprob computation and refactoring the GRPO loss function to accept these pre-computed values, the system can now handle larger batch sizes or longer sequence lengths more efficiently. This change centralizes logit-manipulation parameters and improves the overall robustness of the GRPO training process.
Code Review
This pull request refactors the GRPO loss calculation to be more memory-efficient by chunking log-probability computations. A new function chunked_hidden_states_selective_log_softmax is introduced for this purpose, and the loss calculation is split across the batch dimension.
While the approach is sound, I've found a critical bug in the implementation: grpo_compute_loss is called with a mix of log-probabilities and hidden states, which will lead to incorrect loss values. I've also identified several other issues related to code clarity, duplication, and potential bugs in handling keyword arguments. Please see my detailed comments for suggestions on how to address these points.
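The chunking idea the review describes can be sketched as follows. This is not the PR's implementation (shapes and names are assumptions, and the logit-scaling/softcapping options are omitted); it only shows the core pattern: compute the selective log-softmax over the flattened (bsz * qlen) rows in pieces, so the full (N, vocab) logits tensor is never materialized at once.

```python
import torch

def chunked_selective_log_softmax(hidden_states, lm_head, index, chunks = 4):
    # hidden_states: (bsz, qlen, hid); lm_head: (vocab, hid); index: (bsz, qlen)
    flat_h   = hidden_states.reshape(-1, hidden_states.shape[-1])
    flat_idx = index.reshape(-1)
    out = []
    for h, idx in zip(torch.chunk(flat_h,   chunks, dim = 0),
                      torch.chunk(flat_idx, chunks, dim = 0)):
        logits = h @ lm_head.T                       # only (chunk_rows, vocab) lives at once
        logps  = torch.log_softmax(logits.float(), dim = -1)
        out.append(logps.gather(-1, idx.unsqueeze(-1)).squeeze(-1))
    return torch.cat(out).reshape(index.shape)
```

Up to floating-point tolerance this matches the unchunked computation; the peak allocation shrinks because only one (chunk_rows, vocab) logits slice exists at a time.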
unsloth_zoo/rl_replacements.py
Outdated
```python
prev_max_left_pad = kwargs.get("max_left_pad", None)

#delete this from kwargs so less issues
sampling_per_token_logps = kwargs.pop("sampling_per_token_logps", None)
```
This line unconditionally overwrites the sampling_per_token_logps variable set on line 548, which makes the conditional logic on that line ineffective. This is likely a bug. This line should be removed, and line 548 should be modified to use .pop() to correctly handle the value based on the vllm_importance_sampling_correction flag.
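A minimal sketch of the suggested fix (the helper and flag names here are hypothetical, for illustration only): pop the kwarg exactly once, and let the correction flag decide whether the value is kept, so no later line can silently overwrite it.

```python
def take_sampling_logps(kwargs, vllm_importance_sampling_correction):
    # Pop exactly once so the kwarg never leaks further down the call,
    # then gate the value on the flag instead of overwriting it later.
    value = kwargs.pop("sampling_per_token_logps", None)
    return value if vllm_importance_sampling_correction else None
```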
I just changed the popped kwarg to be assigned to `_`, since this would be a bug; I do not fully recall why that was there. I will investigate further.
```python
# More memory efficient by chunking on (bsz+qlen) dimension
# Exactly equivalent to the above
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_hidden_states_selective_log_softmax(
    hidden_states,
    lm_head,
    index,
    chunks=4,
    logit_scale_multiply=0.0,
    logit_scale_divide=0.0,
    logit_softcapping=0.0,
    temperature=1.0
):
```
This new function is missing type hints for its parameters and return value, as well as a proper docstring. Adding them would improve readability and maintainability, following Python best practices (PEP 484). The existing comments can be converted into a formal docstring.
Suggested change:

```python
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_hidden_states_selective_log_softmax(
    hidden_states: torch.Tensor,
    lm_head: torch.Tensor,
    index: torch.Tensor,
    chunks: int = 4,
    logit_scale_multiply: float = 0.0,
    logit_scale_divide: float = 0.0,
    logit_softcapping: float = 0.0,
    temperature: float = 1.0,
) -> torch.Tensor:
    """
    More memory efficient by chunking on (bsz+qlen) dimension.
    Exactly equivalent to the non-chunked version.
    """
```
unsloth_zoo/rl_replacements.py
Outdated
```python
pass
RL_REPLACEMENTS["left_pack_padding"] = left_pack_padding


import torch
```
I removed this for code cleanliness.
```python
    ref,
    new,
    old,
```
The parameters ref, new, and old are now expected to be log-probabilities. The previous names (ref_logits, etc.) were also not perfectly accurate, but the new names are even more ambiguous. Consider renaming them to something more descriptive like ref_logps, new_logps, and old_logps to improve code clarity and maintainability.
Suggested change:

```diff
-    ref,
-    new,
-    old,
+    ref_logps,
+    new_logps,
+    old_logps,
```
Changed the naming conventions for these just now.
unsloth_zoo/rl_replacements.py
Outdated
```python
# mark_dynamic(input_ids_j)
# mark_dynamic(mask_j)

#breakpoint()
```
I removed this breakpoint just now.
```python
def chunk_optional(tensor, chunks):
    if tensor is None:
        return [None] * chunks
    return torch.chunk(tensor, chunks=chunks, dim=0)
```
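For context, this helper fans a possibly-absent tensor into per-chunk placeholders, so one zip-driven loop serves both text-only and multimodal batches. A small usage sketch (tensor names are illustrative):

```python
import torch

def chunk_optional(tensor, chunks):
    # Mirror of the helper under review: None fans out to per-chunk Nones.
    if tensor is None:
        return [None] * chunks
    return torch.chunk(tensor, chunks = chunks, dim = 0)

hidden = torch.randn(8, 4)
pixel_values = None  # e.g. a text-only batch with no image inputs

for h_chunk, p_chunk in zip(torch.chunk(hidden, 4, dim = 0),
                            chunk_optional(pixel_values, 4)):
    # p_chunk is None on text-only batches; a real forward would branch on it
    assert p_chunk is None and h_chunk.shape == (2, 4)
```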
I removed one of these chunk_optionals.
@codex review
💡 Codex Review
Here are some automated review suggestions for this pull request.
```python
max_left_pad = max(left_pad_tokens_per_prompt).item()

if max_left_pad < max(prev_max_left_pad).item() and (logits_to_keep +max_left_pad) != old_logps.shape[1]:
    max_left_pad = max(prev_max_left_pad).item()
```
Guard prev_max_left_pad before max()
When max_left_pad is not supplied in kwargs it defaults to None, but the new padding adjustment uses max(prev_max_left_pad).item() unconditionally. On the first call (or any call without that kwarg) this hits max(None) and raises a TypeError before any loss is computed, breaking training even for valid inputs. Add a None check or default to an empty tensor before calling max().
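A sketch of the guard Codex suggests (the helper name is hypothetical): fall back to zero padding when the kwarg was never supplied, instead of letting `max(None)` raise.

```python
import torch

def prev_pad_or_zero(prev_max_left_pad):
    # None means no previous padding info (e.g. the first call), so treat
    # it as zero padding rather than crashing inside max().
    if prev_max_left_pad is None:
        return 0
    return int(prev_max_left_pad.max().item())
```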
```python
)
pass
pass
# pass
```
Remove these leftover `pass` statements.
```python
prev_max_left_pad = kwargs.get("max_left_pad", None)

#delete this from kwargs so less issues
```
Fix the spacing and capitalization, and mention that we enable this by default.
```python
max_left_pad = max(left_pad_tokens_per_prompt).item()

if max_left_pad < max(prev_max_left_pad).item() and (logits_to_keep +max_left_pad) != old_logps.shape[1]:
```
Can you use torch ops?
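One way to phrase the padding update with torch ops, as the reviewer asks (a sketch: the tensor names mirror the diff, but the surrounding shape condition is omitted). Python's built-in `max()` iterates over the tensor element by element and `.item()` forces a device sync, while tensor-native reductions stay on-device and trace cleanly under `torch.compile`.

```python
import torch

left_pad_tokens_per_prompt = torch.tensor([3, 1, 4])
prev_max_left_pad = torch.tensor([2, 5])

# tensor-native reductions instead of Python max(...).item()
max_left_pad = torch.maximum(left_pad_tokens_per_prompt.max(),
                             prev_max_left_pad.max())
```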
```python
    new_hidden_states_chunk,
    lm_head,
    completion_ids,
    chunks = 8,
```
Does increasing this reduce VRAM further but make it slower?
Relies on: unslothai/unsloth#3628