add flashinfer-trtllm-ragged-prefill-attn #1099
base: main
Conversation
Summary of Changes
Hello @SangChengC, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the Deepseek2 model by integrating FlashInfer and TRTLLM for ragged prefill attention. It introduces conditional kernel selection based on an environment variable, allowing users to switch between the FlashInfer and TRTLLM kernels. The changes modify the model's structure and attention mechanisms to improve performance and flexibility.
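As a rough sketch of how such an environment-variable toggle is typically wired up (the variable name ENABLE_TRTLLM_RAGGED_PREFILL, the bind_context_attention helper, and the fallback method name below are hypothetical, not taken from this PR):

import os

# Hypothetical toggle between the FlashInfer and TRTLLM ragged prefill kernels;
# the actual variable and method names used by the PR may differ.
ENABLE_TRTLLM_RAGGED_PREFILL = os.getenv("ENABLE_TRTLLM_RAGGED_PREFILL", "0") == "1"

def bind_context_attention(layer):
    # Choose the prefill attention implementation once when the layer is built,
    # so the hot path does not re-read the environment on every forward pass.
    if ENABLE_TRTLLM_RAGGED_PREFILL:
        layer.context_attention = layer._context_attention_trtllm_ragged_with_CC
    else:
        layer.context_attention = layer._context_attention_flashinfer_with_CC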
Code Review
This pull request adds support for flashinfer-trtllm-ragged-prefill-attn. The changes introduce a new environment variable to toggle this feature and implement the corresponding attention kernels for both normal and fp8 modes. The implementation is largely correct, but I've identified some areas for improvement. Specifically, there's significant code duplication in the new attention functions that should be refactored for better maintainability. Additionally, there's a piece of unreachable code that needs to be removed.
def _context_attention_trtllm_ragged_with_CC_fp8(
    self,
    q: torch.Tensor,
    kv,
    infer_state: Deepseek2FlashInferStateInfo,
    layer_weight: Deepseek2TransformerLayerWeight,
    out=None,
) -> torch.Tensor:
    k_nope, k_rope, v = self._decompress_kv(
        kv,
        infer_state,
        layer_weight,
        True,
        infer_state.total_token_num,
        infer_state.b_seq_len,
        infer_state.max_value_in_b_seq_len,
        infer_state.b1_kv_start_loc,
    )
    o_tensor = (
        self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out
    )
    k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)

    seq_lens = infer_state.b_seq_len.int()
    cum_seq_lens = infer_state.b1_cu_q_seq_len.int()
    max_seq_len = int(seq_lens.max().item())

    o = flashinfer.prefill.trtllm_ragged_attention_deepseek(
        query=q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
        key=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
        value=v.view(-1, self.tp_v_head_num_, self.v_head_dim),
        workspace_buffer=infer_state.flashinfer_extra_state.workspace_buffer,
        seq_lens=seq_lens,
        max_q_len=max_seq_len,
        max_kv_len=max_seq_len,
        bmm1_scale=self.softmax_scale,
        bmm2_scale=1.0,
        o_sf_scale=1.0,
        batch_size=infer_state.batch_size,
        window_left=-1,
        cum_seq_lens_q=cum_seq_lens,
        cum_seq_lens_kv=cum_seq_lens,
        enable_pdl=False,
        is_causal=True,
        return_lse=False,
    )
    o_tensor.copy_(o)
    return o_tensor
    return q
The method _context_attention_trtllm_ragged_with_CC_fp8 is nearly identical to _context_attention_trtllm_ragged_with_CC. The only difference is the boolean value True passed for the is_fp8 parameter to self._decompress_kv. This significant code duplication makes the code harder to maintain and more prone to errors if one function is updated and the other is not.
To improve maintainability, consider refactoring these two methods into a single private helper that accepts an is_fp8 boolean parameter.
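For example, a minimal sketch of that refactor, reusing the names from the diff above (illustrative only; the _impl helper name is hypothetical, not the PR's code):

def _context_attention_trtllm_ragged_with_CC_impl(
    self,
    q: torch.Tensor,
    kv,
    infer_state: Deepseek2FlashInferStateInfo,
    layer_weight: Deepseek2TransformerLayerWeight,
    out=None,
    is_fp8: bool = False,
) -> torch.Tensor:
    # Shared body: identical to both existing methods except for the is_fp8
    # flag forwarded to _decompress_kv.
    k_nope, k_rope, v = self._decompress_kv(
        kv,
        infer_state,
        layer_weight,
        is_fp8,
        infer_state.total_token_num,
        infer_state.b_seq_len,
        infer_state.max_value_in_b_seq_len,
        infer_state.b1_kv_start_loc,
    )
    ...  # remainder unchanged from the diff above

def _context_attention_trtllm_ragged_with_CC(self, q, kv, infer_state, layer_weight, out=None):
    return self._context_attention_trtllm_ragged_with_CC_impl(q, kv, infer_state, layer_weight, out, is_fp8=False)

def _context_attention_trtllm_ragged_with_CC_fp8(self, q, kv, infer_state, layer_weight, out=None):
    return self._context_attention_trtllm_ragged_with_CC_impl(q, kv, infer_state, layer_weight, out, is_fp8=True)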
No description provided.