Skip to content

[opt]opti-qwen2-vl-vit #1004

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Aug 14, 2025
Merged

[opt]opti-qwen2-vl-vit #1004

merged 15 commits into from
Aug 14, 2025

Conversation

SangChengC
Copy link
Contributor

No description provided.

Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @SangChengC, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

I've implemented a series of targeted optimizations across the Qwen2-VL and Qwen2.5-VL vision models. The primary goal was to enhance computational efficiency and improve code maintainability. This involved integrating highly optimized rotary position embeddings from vllm, streamlining the image preprocessing pipeline, and enabling advanced Flash Attention features. These changes collectively aim to boost the performance of visual processing within the system.

Highlights

  • Rotary Position Embedding Optimization: Integrated vllm's optimized rotary position embedding (apply_rotary_emb) into both Qwen2-VL and Qwen2.5-VL vision models, which is expected to significantly enhance the efficiency of positional encoding calculations.
  • Streamlined Image Preprocessing: Refactored and simplified the image preprocessing pipeline in vision_process.py, leading to cleaner code, reduced dependencies, and more efficient image resizing and normalization.
  • Flash Attention v3 Activation: Activated Flash Attention v3 for the Hopper architecture by removing a restrictive flag, allowing the system to leverage this advanced attention mechanism for improved performance in compatible environments.
  • Enhanced Model Initialization and Data Type Management: Improved data type handling and model initialization for Qwen2-VL and Qwen2.5-VL vision transformers, ensuring better type consistency and flexibility during model loading.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces several optimizations for Qwen2-VL and Qwen2.5-VL models, primarily by integrating components from vllm for rotary embeddings and enabling FlashAttention v3. The changes also include significant refactoring of the vision processing pipeline, simplifying the code and improving configurability. However, I've identified a few critical issues related to incorrect application of rotary position embeddings and argument mismatches in function calls that will lead to runtime errors. These need to be addressed to ensure the models function correctly.

Comment on lines 79 to 84
def apply_rotary_pos_emb_vision(self, t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
t_ = t.float()
cos = freqs.cos()
sin = freqs.sin()
output = apply_rotary_emb(t_, cos, sin).type_as(t)
return output

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The implementation of apply_rotary_pos_emb_vision has a shape mismatch issue. The cos and sin tensors derived from freqs have a shape of (sequence_length, head_dim/2), while the input tensor t has a shape of (sequence_length, num_heads, head_dim). The apply_rotary_emb function performs element-wise operations, so cos and sin must be broadcastable to t. The current implementation will fail at runtime.

To fix this, you need to expand cos and sin to be compatible with t's shape.

    def apply_rotary_pos_emb_vision(self, t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
        t_ = t.float()
        cos = freqs.cos()
        sin = freqs.sin()
        cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)
        sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)
        output = apply_rotary_emb(t_, cos, sin).type_as(t)
        return output

Comment on lines 361 to 366
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens_now,
max_seqlen=max_seqlen,
position_embeddings=emb,
)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There's a TypeError waiting to happen here. You are calling blk (an instance of Qwen2_5_VLVisionBlock) with the keyword argument position_embeddings. The Qwen2_5_VLVisionBlock.forward method then passes this argument to self.attn.forward. However, self.attn (an instance of Qwen2_5_VLVisionFlashAttention) does not have a position_embeddings parameter in its forward method.

The correct keyword argument should be rotary_pos_emb. Additionally, the emb tensor you are passing is incorrectly constructed. You should pass rotary_pos_emb directly.

            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens_now,
                max_seqlen=max_seqlen,
                rotary_pos_emb=rotary_pos_emb,
            )

Comment on lines 141 to 146
def apply_rotary_pos_emb_vision(self, t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
t_ = t.float()
cos = freqs.cos()
sin = freqs.sin()
output = apply_rotary_emb(t_, cos, sin).type_as(t)
return output

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This implementation of apply_rotary_pos_emb_vision will cause a shape mismatch error. The cos and sin tensors are not correctly shaped to be broadcastable with the input tensor t inside apply_rotary_emb. The call to this function in forward also uses q.unsqueeze(0), which is likely incorrect and should be removed along with the corresponding squeeze(0).

You need to expand cos and sin to be broadcastable with t.

    def apply_rotary_pos_emb_vision(self, t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
        t_ = t.float()
        cos = freqs.cos()
        sin = freqs.sin()
        cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)
        sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)
        output = apply_rotary_emb(t_, cos, sin).type_as(t)
        return output

Comment on lines 157 to 158
q = self.apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb)
k = self.apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The calls to apply_rotary_pos_emb_vision use unsqueeze(0) on q and k, and then squeeze(0) on the result. This is unnecessary and makes the code harder to follow. The apply_rotary_pos_emb_vision function can be adjusted to work directly with the 3D tensor shape of q and k.

        q = self.apply_rotary_pos_emb_vision(q, rotary_pos_emb)
        k = self.apply_rotary_pos_emb_vision(k, rotary_pos_emb)

@SangChengC SangChengC changed the title 0808-opti-qwen2-vl-vit [opti]opti-qwen2-vl-vit Aug 8, 2025
@SangChengC SangChengC changed the title [opti]opti-qwen2-vl-vit [opt]opti-qwen2-vl-vit Aug 8, 2025
@hiworldwzj hiworldwzj merged commit df6afff into main Aug 14, 2025
1 check passed
@hiworldwzj hiworldwzj deleted the opti-qwen2-vl-vit branch August 14, 2025 08:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants