add fa3_mtp #1005
Conversation
Summary of Changes
Hello @WANDY666, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
I've implemented support for Flash Attention 3 (FA3) in Multi-Token Prediction (MTP) mode, primarily targeting the Deepseek2 model. This enhancement aims to significantly boost decoding performance by allowing the model to process and generate multiple tokens more efficiently. The changes include a new specialized FA3 kernel, its integration into the Deepseek2 architecture, and the configuration and benchmarking tools needed to enable and validate this feature.
Highlights
- New Flash Attention 3 MTP Kernel: I've introduced a new Flash Attention 3 (FA3) kernel specifically designed for Multi-Token Prediction (MTP) mode. This kernel, flash_attn_with_kvcache_mtp, is adapted from sgl-project/sglang and is crucial for optimizing decode performance when generating multiple tokens simultaneously.
- Deepseek2 Model Integration: I've integrated the new FA3 MTP kernel into the Deepseek2 model's transformer layer. This involves adding mtp_step and mtp_size parameters and conditionally switching to the new _token_gqa_decode_attention_mtp method when the enable_fa3_mtp flag is active. This ensures the Deepseek2 model can leverage the performance benefits of the new kernel.
- Command-Line Argument and Configuration: I've added a new command-line argument, --enable_fa3_mtp, to api_cli.py and updated the StartArgs configuration. This allows users to easily enable the new FA3 MTP feature from the server's command line, with a validation check to ensure mtp_mode is also set (see the sketch after this list).
- Benchmarking Tools and Script: I've included new benchmarking utilities and a specific benchmark script for the fa3_decode_mtp kernel. These tools, adapted from tile-ai/tilelang, will help evaluate the performance and correctness of the new MTP attention mechanism.
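To make the new flag concrete, here is a minimal argparse sketch of an --enable_fa3_mtp option together with the mtp_mode validation check. Only the names --enable_fa3_mtp, mtp_step, and mtp_mode come from this PR; the parser layout, defaults, the example mtp_mode value, and the error message are assumptions rather than the actual api_cli.py code.

```python
import argparse

# Hypothetical sketch; only --enable_fa3_mtp, --mtp_step, and --mtp_mode are named
# in the PR description. Everything else here is illustrative.
parser = argparse.ArgumentParser(description="LLM server start args (sketch)")
parser.add_argument("--mtp_mode", type=str, default=None,
                    help="Multi-token prediction mode; must be set to use FA3 MTP.")
parser.add_argument("--mtp_step", type=int, default=0,
                    help="Number of extra tokens predicted per decode step.")
parser.add_argument("--enable_fa3_mtp", action="store_true",
                    help="Use the Flash Attention 3 kernel in MTP decode mode.")

# Example invocation; the mtp_mode value is made up for the sketch.
args = parser.parse_args(["--mtp_mode", "deepseek_mtp", "--mtp_step", "1", "--enable_fa3_mtp"])

# Validation check described in the highlights: enable_fa3_mtp requires mtp_mode.
if args.enable_fa3_mtp and args.mtp_mode is None:
    parser.error("--enable_fa3_mtp requires --mtp_mode to be set")

print(args.enable_fa3_mtp, args.mtp_mode, args.mtp_step)
```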
Code Review
This pull request introduces support for fa3_mtp (FlashAttention-3 with Multi-Token Prediction), which is a significant feature enhancement. The changes span dependency management in the Dockerfiles, core model logic for DeepSeek V2, and the addition of a new attention kernel. New command-line arguments and a benchmark test are also included to support and validate this feature. My review focuses on improving code readability, removing redundant function calls, and reducing the Docker image size. The core logic for the new feature appears sound and well integrated.
RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v . && \
    cd flash-attention/hopper/ && python setup.py install
This RUN command can be made more readable, and the final image size can be reduced by cleaning up the cloned repository. Consider breaking the command across lines and removing the LightKernel directory after installation. For example:
RUN git clone https://github.com/ModelTC/LightKernel.git && \
cd LightKernel && \
pip install --no-deps -v . && \
cd flash-attention/hopper/ && \
python setup.py install && \
cd ../../.. && \
rm -rf LightKernel
mtp_step = get_env_start_args().mtp_step
mtp_size = mtp_step + 1
enable_fa3_mtp = get_env_start_args().enable_fa3_mtp
To improve readability and avoid redundant calls, it's better to call get_env_start_args() only once and store the result in a local variable.
Suggested change:
-mtp_step = get_env_start_args().mtp_step
-mtp_size = mtp_step + 1
-enable_fa3_mtp = get_env_start_args().enable_fa3_mtp
+start_args = get_env_start_args()
+mtp_step, enable_fa3_mtp = start_args.mtp_step, start_args.enable_fa3_mtp
+mtp_size = mtp_step + 1
@@ -97,7 +100,11 @@ def _bind_attention(self):
            )
        else:
            self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
        if get_env_start_args().enable_fa3:
        if get_env_start_args().enable_fa3_mtp:
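For context, the sketch below shows in a self-contained way how such a conditional kernel binding can work with functools.partial. It is a mock for illustration: the _token_attention_kernel attribute, the _token_gqa_decode_attention_fa3 method name, and the fake start-args object are assumptions; only enable_fa3, enable_fa3_mtp, and _token_gqa_decode_attention_mtp appear in this PR.

```python
from functools import partial

class _FakeStartArgs:
    """Stand-in for the parsed start args; only the two flags used below."""
    enable_fa3 = False
    enable_fa3_mtp = True

def get_env_start_args():
    return _FakeStartArgs()

class Deepseek2TransformerLayerInfer:
    def _token_gqa_decode_attention_mtp(self):
        return "FA3 MTP decode path"

    def _token_gqa_decode_attention_fa3(self):  # placeholder for the existing FA3 path
        return "FA3 decode path"

    def _bind_attention(self):
        start_args = get_env_start_args()
        if start_args.enable_fa3_mtp:
            # Route token decode attention through the MTP kernel wrapper.
            self._token_attention_kernel = partial(
                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_mtp, self
            )
        elif start_args.enable_fa3:
            self._token_attention_kernel = partial(
                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_fa3, self
            )

layer = Deepseek2TransformerLayerInfer()
layer._bind_attention()
print(layer._token_attention_kernel())  # -> "FA3 MTP decode path"
```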
@@ -69,7 +69,7 @@ def __init__(self, kvargs):
        return

    def _init_inferstate_cls(self):
        if get_env_start_args().enable_fa3:
        if get_env_start_args().enable_fa3 or get_env_start_args().enable_fa3_mtp:
total_seqlens = cache_seqlens.sum().item()
mean_seqlens = cache_seqlens.float().mean().int().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = math.ceil(max_seqlen / 256) * 256  # why align to 256?
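Regarding the (translated) question in that last line: rounding the maximum cache length up to a multiple of 256 most likely keeps the padded KV layout aligned with the kernel's block/page granularity, though that rationale is an inference here rather than something stated in the PR. A small self-contained example of the computation:

```python
import math
import torch

# Illustrative cache lengths for a small decode batch (made-up values).
cache_seqlens = torch.tensor([123, 500, 257], dtype=torch.int32)

total_seqlens = cache_seqlens.sum().item()   # 880
max_seqlen = cache_seqlens.max().item()      # 500
# Round up to the next multiple of 256: 500 -> 512.
max_seqlen_pad = math.ceil(max_seqlen / 256) * 256
print(total_seqlens, max_seqlen, max_seqlen_pad)
```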