Conversation

@pluesclues
Contributor

@pluesclues pluesclues commented Nov 21, 2025

Introduced chunked_hidden_states_selective_log_softmax for memory efficiency and updated related functions to use it. Removed deprecated sampling-parameter updates and adjusted logit handling in grpo_compute_loss.
Commented out sections of code related to importance sampling and logit processing.
Added a new function grpo_update_SamplingParams to update sampling parameters based on provided arguments. Refactored logit processing to handle chunked inputs and clarified the computation of log probabilities.
Refactored input handling for pixel values and the image grid, adding pre-calculated padding and chunking logic.
Added max_left_pad parameter handling and adjusted the related padding logic in the model's input processing.
Removed unnecessary breakpoints and cleaned up comments.
@gemini-code-assist
Contributor

Summary of Changes

Hello @pluesclues, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces significant architectural changes to how log probabilities are calculated within the GRPO framework, primarily focusing on memory optimization. By implementing a chunked approach for logprob computation and refactoring the GRPO loss function to accept these pre-computed values, the system can now handle larger batch sizes or longer sequence lengths more efficiently. This change centralizes logit manipulation parameters and improves the overall robustness of the GRPO training process.

Highlights

  • Memory-Efficient Logprob Calculation: Introduced a new function, chunked_hidden_states_selective_log_softmax, to compute log probabilities in a memory-efficient manner by processing hidden states in chunks along the batch and sequence length dimensions. This is crucial for handling large models or long sequences.
  • Refactored GRPO Loss Function: The grpo_compute_loss function has been updated to directly receive pre-calculated log probabilities, streamlining its interface and removing redundant logit processing steps.
  • Batch-wise Logprob Generation: The grpo_accumulated_loss function now processes inputs by iterating through individual samples (batch size of 1) to generate log probabilities using the new chunked function, improving memory management during the forward pass.
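
As a rough illustration of the chunking idea behind the first highlight, here is a minimal sketch (names and shapes are assumptions for illustration; the PR's actual function additionally folds in logit scaling, softcapping, and temperature, and is wrapped in torch.compile):

```python
import torch

def chunked_selective_log_softmax(hidden_states, lm_head, index, chunks=4):
    """Gather log-probs of `index` tokens without materializing all logits at once.

    hidden_states: (N, H) flattened over batch and sequence length
    lm_head:       (V, H) output projection weight
    index:         (N,)   token ids whose log-probs we want
    """
    out = []
    for hs, idx in zip(torch.chunk(hidden_states, chunks, dim=0),
                       torch.chunk(index, chunks, dim=0)):
        logits = hs @ lm_head.T                       # (n, V) exists for this chunk only
        logps = torch.log_softmax(logits.float(), dim=-1)
        out.append(logps.gather(-1, idx.unsqueeze(-1)).squeeze(-1))
    return torch.cat(out, dim=0)                      # (N,)
```

Peak memory for the logits drops roughly by a factor of `chunks`, since only an `(N/chunks, V)` slice is live at any time, while the result is numerically equivalent to the unchunked computation.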

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the GRPO loss calculation to be more memory-efficient by chunking log-probability computations. A new function chunked_hidden_states_selective_log_softmax is introduced for this purpose, and the loss calculation is split across the batch dimension.

While the approach is sound, I've found a critical bug in the implementation: grpo_compute_loss is called with a mix of log-probabilities and hidden states, which will lead to incorrect loss values. I've also identified several other issues related to code clarity, duplication, and potential bugs in handling keyword arguments. Please see my detailed comments for suggestions on how to address these points.

prev_max_left_pad = kwargs.get("max_left_pad", None)

#delete this from kwargs so less issues
sampling_per_token_logps = kwargs.pop("sampling_per_token_logps", None)
Contributor


high

This line unconditionally overwrites the sampling_per_token_logps variable set on line 548, which makes the conditional logic on that line ineffective. This is likely a bug. This line should be removed, and line 548 should be modified to use .pop() to correctly handle the value based on the vllm_importance_sampling_correction flag.

Contributor Author


I just changed the popped kwarg to be assigned to _, since this would be a bug; I do not fully recall why that was there and will investigate further.

Comment on lines +65 to +77
# More memory efficient by chunking on (bsz+qlen) dimension
# Exactly equivalent to the above
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_hidden_states_selective_log_softmax(
hidden_states,
lm_head,
index,
chunks=4,
logit_scale_multiply=0.0,
logit_scale_divide=0.0,
logit_softcapping=0.0,
temperature=1.0
):
Contributor


medium

This new function is missing type hints for its parameters and return value, as well as a proper docstring. Adding them would improve readability and maintainability, following Python best practices (PEP 484). The existing comments can be converted into a formal docstring.

Suggested change
# More memory efficient by chunking on (bsz+qlen) dimension
# Exactly equivalent to the above
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_hidden_states_selective_log_softmax(
hidden_states,
lm_head,
index,
chunks=4,
logit_scale_multiply=0.0,
logit_scale_divide=0.0,
logit_softcapping=0.0,
temperature=1.0
):
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def chunked_hidden_states_selective_log_softmax(
hidden_states: torch.Tensor,
lm_head: torch.Tensor,
index: torch.Tensor,
chunks: int = 4,
logit_scale_multiply: float = 0.0,
logit_scale_divide: float = 0.0,
logit_softcapping: float = 0.0,
temperature: float = 1.0,
) -> torch.Tensor:
"""
More memory efficient by chunking on (bsz+qlen) dimension.
Exactly equivalent to the non-chunked version.
"""

pass
RL_REPLACEMENTS["left_pack_padding"] = left_pack_padding

import torch
Contributor


medium

torch is already imported at the top of this file (line 21). This import is redundant and should be removed to keep the code clean and avoid potential confusion.

Contributor Author


I removed this for code cleanliness.

Comment on lines +243 to +245
ref,
new,
old,
Contributor


medium

The parameters ref, new, and old are now expected to be log-probabilities. The previous names (ref_logits, etc.) were also not perfectly accurate, but the new names are even more ambiguous. Consider renaming them to something more descriptive like ref_logps, new_logps, and old_logps to improve code clarity and maintainability.

Suggested change
ref,
new,
old,
ref_logps,
new_logps,
old_logps,

Contributor Author


Changed naming conventions for these just now.

# mark_dynamic(input_ids_j)
# mark_dynamic(mask_j)

#breakpoint()
Contributor


medium

This commented-out breakpoint() seems to be a leftover from a debugging session. It should be removed to keep the codebase clean.

Contributor Author


I removed this breakpoint just now.

Comment on lines +615 to +619
def chunk_optional(tensor, chunks):
if tensor is None:
return [None] * chunks
return torch.chunk(tensor, chunks=chunks, dim=0)

Contributor


medium

The helper function chunk_optional is defined here again. It's a duplicate of the function defined on lines 606-609. This redundant definition should be removed to avoid code duplication.

Contributor Author


I removed one of these chunk_optionals.

@danielhanchen danielhanchen changed the base branch from main to nightly November 27, 2025 03:37
@danielhanchen
Contributor

@codex review


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 573 to +576
max_left_pad = max(left_pad_tokens_per_prompt).item()

if max_left_pad < max(prev_max_left_pad).item() and (logits_to_keep +max_left_pad) != old_logps.shape[1]:
max_left_pad = max(prev_max_left_pad).item()


P1: Guard prev_max_left_pad before max()

When max_left_pad is not supplied in kwargs it defaults to None, but the new padding adjustment uses max(prev_max_left_pad).item() unconditionally. On the first call (or any call without that kwarg) this hits max(None) and raises a TypeError before any loss is computed, breaking training even for valid inputs. Add a None check or default to an empty tensor before calling max().
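
One possible shape for such a guard, as a hedged sketch (the function and variable names here are illustrative, not the PR's actual code):

```python
import torch

def resolve_max_left_pad(left_pad_tokens_per_prompt, prev_max_left_pad,
                         logits_to_keep, old_logps_len):
    # prev_max_left_pad comes from kwargs and may be None on the first call.
    max_left_pad = int(left_pad_tokens_per_prompt.max().item())
    if prev_max_left_pad is not None:
        prev = int(prev_max_left_pad.max().item())
        # Only fall back to the previous pad when it is larger and shapes disagree.
        if max_left_pad < prev and (logits_to_keep + max_left_pad) != old_logps_len:
            max_left_pad = prev
    return max_left_pad
```

With the None check in place, the first call (or any call without the kwarg) simply uses the current batch's padding instead of raising a TypeError.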


)
pass
pass
# pass
Contributor


Remove


prev_max_left_pad = kwargs.get("max_left_pad", None)

#delete this from kwargs so less issues
Contributor


Spacing and capitalize and mention we enable by default


max_left_pad = max(left_pad_tokens_per_prompt).item()

if max_left_pad < max(prev_max_left_pad).item() and (logits_to_keep +max_left_pad) != old_logps.shape[1]:
Contributor


Can you use torch ops?
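
A torch-native variant of the comparison above might look like this (a sketch under assumed 1-D tensor inputs; names are illustrative, and it keeps the result on-device instead of calling .item()):

```python
import torch

def adjust_max_left_pad(left_pad_tokens_per_prompt, prev_max_left_pad,
                        logits_to_keep, old_len):
    cur = left_pad_tokens_per_prompt.max()
    prev = prev_max_left_pad.max()
    # Fall back to the previous pad only when it is larger and shapes disagree.
    use_prev = (cur < prev) & ((logits_to_keep + cur) != old_len)
    return torch.where(use_prev, prev, cur)
```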

new_hidden_states_chunk,
lm_head,
completion_ids,
chunks = 8,
Contributor


Does increasing this reduce VRAM more but makes it slower?
