
add fa3_mtp #1005


Open
WANDY666 wants to merge 10 commits into main

Conversation

WANDY666 (Contributor)

No description provided.

@gemini-code-assist bot left a comment

Summary of Changes

Hello @WANDY666, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

I've implemented support for Flash Attention 3 (FA3) in Multi-Token Prediction (MTP) mode, primarily targeting the Deepseek2 model. This enhancement aims to significantly boost decoding performance by allowing the model to process and generate multiple tokens per step more efficiently. The changes include a new specialized FA3 kernel, its integration into the Deepseek2 architecture, and the necessary configuration and benchmarking tools to enable and validate this feature.

Highlights

  • New Flash Attention 3 MTP Kernel: I've introduced a new Flash Attention 3 (FA3) kernel specifically designed for Multi-Token Prediction (MTP) mode. This kernel, flash_attn_with_kvcache_mtp, is adapted from sgl-project/sglang and is crucial for optimizing decode performance when generating multiple tokens simultaneously.
  • Deepseek2 Model Integration: I've integrated the new FA3 MTP kernel into the Deepseek2 model's transformer layer. This involves adding mtp_step and mtp_size parameters and conditionally switching to the new _token_gqa_decode_attention_mtp method when the enable_fa3_mtp flag is active. This ensures the Deepseek2 model can leverage the performance benefits of the new kernel (a brief sketch follows this list).
  • Command-Line Argument and Configuration: I've added a new command-line argument, --enable_fa3_mtp, to api_cli.py and updated the StartArgs configuration. This allows users to easily enable and utilize the new FA3 MTP feature from the server's command line, with a validation check to ensure mtp_mode is also set.
  • Benchmarking Tools and Script: I've included new benchmarking utilities and a specific benchmark script for the fa3_decode_mtp kernel. These tools, adapted from tile-ai/tilelang, will help in evaluating the performance and correctness of the new MTP attention mechanism.
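
To make the wiring described above concrete, here is a minimal sketch of how the pieces could fit together. It is illustrative only: the import path, the stand-in class name, and the _token_gqa_decode_attention attribute name are assumptions, and the real method bodies are not reproduced.

    from functools import partial

    from lightllm.utils.envs_utils import get_env_start_args  # import path assumed


    class Deepseek2TransformerLayerInferSketch:
        # Illustrative stand-in for Deepseek2TransformerLayerInfer, not the real class.

        def __init__(self):
            start_args = get_env_start_args()
            self.mtp_step = start_args.mtp_step              # extra speculative tokens per step
            self.mtp_size = self.mtp_step + 1                # total tokens handled per decode step
            self.enable_fa3_mtp = start_args.enable_fa3_mtp

        def _bind_attention(self):
            # Route decode attention to the FA3 MTP path when the flag is set.
            # The attribute name on the left-hand side is assumed for illustration.
            if self.enable_fa3_mtp:
                self._token_gqa_decode_attention = partial(
                    Deepseek2TransformerLayerInferSketch._token_gqa_decode_attention_mtp, self
                )

        def _token_gqa_decode_attention_mtp(self, *args, **kwargs):
            # The real implementation dispatches to flash_attn_with_kvcache_mtp; omitted here.
            ...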

@gemini-code-assist bot left a comment

Code Review

This pull request introduces support for fa3_mtp (FlashAttention-3 with Multi-Token Prediction), which is a significant feature enhancement. The changes span dependency management in the Dockerfiles, core model logic for DeepSeek V2, and the addition of a new attention kernel. New command-line arguments and a benchmark test are also included to support and validate this feature. My review focuses on improving code readability and performance by removing redundant function calls, and on reducing the Docker image size. The core logic for the new feature appears sound and well-integrated.

Comment on lines +42 to +43
RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v . && \
cd flash-attention/hopper/ && python setup.py install


Severity: high

This RUN command can be improved for readability and to reduce the final image size by cleaning up the cloned repository. Consider breaking down the command and removing the LightKernel directory after installation. For example:

RUN git clone https://github.com/ModelTC/LightKernel.git && \
    cd LightKernel && \
    pip install --no-deps -v . && \
    cd flash-attention/hopper/ && \
    python setup.py install && \
    cd ../../.. && \
    rm -rf LightKernel


Comment on lines +12 to +14
mtp_step = get_env_start_args().mtp_step
mtp_size = mtp_step + 1
enable_fa3_mtp = get_env_start_args().enable_fa3_mtp


Severity: medium

To improve readability and avoid redundant calls, it's better to call get_env_start_args() only once and store the result in a local variable.

Suggested change
- mtp_step = get_env_start_args().mtp_step
- mtp_size = mtp_step + 1
- enable_fa3_mtp = get_env_start_args().enable_fa3_mtp
+ start_args = get_env_start_args()
+ mtp_step, enable_fa3_mtp = start_args.mtp_step, start_args.enable_fa3_mtp
+ mtp_size = mtp_step + 1

@@ -97,7 +100,11 @@ def _bind_attention(self):
)
else:
self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
- if get_env_start_args().enable_fa3:
+ if get_env_start_args().enable_fa3_mtp:


Severity: medium

The function get_env_start_args() is called multiple times within this if/elif chain (lines 103, 107, 111). To improve performance and readability, consider calling it once before this conditional block and storing the result in a local variable.
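
A possible shape for that refactor, as a sketch only (the branches are abbreviated, and the bodies that bind the concrete attention methods are replaced with placeholders):

    start_args = get_env_start_args()  # single call reused across the chain
    if start_args.enable_fa3_mtp:
        ...  # bind the FA3 MTP decode attention path
    elif start_args.enable_fa3:
        ...  # bind the regular FA3 decode attention path
    else:
        ...  # keep the existing fallback binding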

@@ -69,7 +69,7 @@ def __init__(self, kvargs):
return

def _init_inferstate_cls(self):
- if get_env_start_args().enable_fa3:
+ if get_env_start_args().enable_fa3 or get_env_start_args().enable_fa3_mtp:


Severity: medium

The function get_env_start_args() is called twice on this line. To improve performance and readability, consider calling it once before the if statement and storing the result in a local variable.
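
For example (sketch only; the infer-state class selected inside the branch is not shown in this hunk and is left as a placeholder):

    def _init_inferstate_cls(self):
        start_args = get_env_start_args()  # call once instead of twice
        if start_args.enable_fa3 or start_args.enable_fa3_mtp:
            ...  # select the FA3 infer state class, as in the original branch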

total_seqlens = cache_seqlens.sum().item()
mean_seqlens = cache_seqlens.float().mean().int().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = math.ceil(max_seqlen / 256) * 256 # ? why align to 256


Severity: medium

This comment appears to be a question left over from development. It should be removed or replaced with a proper explanation for clarity.

Suggested change
- max_seqlen_pad = math.ceil(max_seqlen / 256) * 256 # ? why align to 256
+ max_seqlen_pad = math.ceil(max_seqlen / 256) * 256
