[OMNIML-2852] [2/n] Add Core Sparse Attention Infrastructure #527

Conversation
Codecov Report ❌ — coverage diff and impacted files:

|          | main   | #527   | +/-    |
|----------|--------|--------|--------|
| Coverage | 74.64% | 74.95% | +0.31% |
| Files    | 183    | 192    | +9     |
| Lines    | 18542  | 18939  | +397   |
| Hits     | 13840  | 14196  | +356   |
| Misses   | 4702   | 4743   | +41    |
(force-pushed 54bfe2c to 0ce1376, then fc9d285 to 5d027e0)
Hi @kaix-nv, could you further split this change? This PR has 3000+ lines of code change and many file moves.
    # Create registry for sparse attention modules
    SparseAttentionRegistry = _DMRegistryCls("SparseAttention", SparseAttentionModule)
Can we use a single registry for all sparsity algorithms and modes, plus a top-level mts.sparsify(model, mode=...), so that all algorithms (e.g. weight or attention sparsification) are invoked through one shared API instead of a separate API per algorithm?
That's good advice. I'll submit a follow-up PR later.
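For illustration, a minimal sketch of the idea being proposed here (not ModelOpt's actual API; the registry, mode names, and signatures below are assumptions): one shared registry keyed by mode name, so a single top-level sparsify() can dispatch to either weight- or attention-sparsity implementations.

```python
# Hypothetical sketch only: a shared mode registry with one sparsify() entry point.
from typing import Callable

import torch.nn as nn

SPARSITY_MODES: dict[str, Callable[[nn.Module, dict], nn.Module]] = {}


def register_mode(name: str):
    """Register a sparsification algorithm under a shared mode name."""

    def decorator(fn: Callable[[nn.Module, dict], nn.Module]):
        SPARSITY_MODES[name] = fn
        return fn

    return decorator


def sparsify(model: nn.Module, mode: str, config: dict | None = None) -> nn.Module:
    """Single entry point: look up the registered algorithm and apply it."""
    if mode not in SPARSITY_MODES:
        raise KeyError(f"Unknown sparsity mode: {mode!r}")
    return SPARSITY_MODES[mode](model, config or {})


@register_mode("skip_softmax")  # placeholder mode name for the attention-sparsity path
def _apply_skip_softmax(model: nn.Module, config: dict) -> nn.Module:
    # The real implementation would wrap attention modules via SparseAttentionRegistry;
    # this stub only illustrates the dispatch pattern.
    return model
```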
Resolved review threads (outdated):
- tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py
- tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py
- tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py
jy-yuan left a comment:
Great work on the overall architecture!
Resolved review thread (outdated): modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py
    total_blocks = (
        num_block_rows * (num_block_rows + 1) // 2  # Causal: N(N+1)/2
Does that mean rows == columns? In other words, we only handle the causal case in self-attention, not cross-attention?
Yes, it's for causal attention in prefill.
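To make the block-count math concrete, here is a small self-contained sketch (helper names are illustrative, not the PR's code) that computes the N(N+1)/2 at-or-below-diagonal blocks of a square, blocked causal mask during prefill:

```python
# Illustrative sketch of the causal block-count formula from the diff above,
# assuming a square (self-attention) mask tiled into block_size x block_size tiles.
import math


def causal_block_count(seq_len: int, block_size: int) -> int:
    """Number of blocks at or below the diagonal in a blocked causal mask."""
    num_block_rows = math.ceil(seq_len / block_size)
    # Block row i (0-indexed) contributes i + 1 blocks, so the total is N(N+1)/2.
    return num_block_rows * (num_block_rows + 1) // 2


seq_len, block_size = 4096, 128
n = math.ceil(seq_len / block_size)      # 32 block rows
total = causal_block_count(seq_len, block_size)
print(total)                             # 32 * 33 // 2 = 528
print(total / (n * n))                   # ~0.52: the causal mask keeps about half the blocks
```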
| "--backend", | ||
| type=str, | ||
| default="pytorch", | ||
| choices=["pytorch", "triton"], |
Is "triton" a TODO?
Yes.
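For context, a self-contained sketch of how the flag above might be wired up (everything beyond the diff'd option is illustrative); since the Triton kernel isn't implemented yet, the sketch guards that choice explicitly:

```python
# Illustrative wiring of the --backend flag; "triton" is accepted by argparse but,
# per this thread, the Triton kernel is still a TODO, so we fail fast if selected.
import argparse

parser = argparse.ArgumentParser(description="Sparse attention benchmark (illustrative)")
parser.add_argument(
    "--backend",
    type=str,
    default="pytorch",
    choices=["pytorch", "triton"],
    help="Kernel backend; 'triton' is a placeholder until a Triton kernel lands.",
)
args = parser.parse_args()

if args.backend == "triton":
    raise NotImplementedError("The Triton backend is a TODO in this PR.")
```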
    method = getattr(module, "_method", "unknown")
    threshold = getattr(module, "_threshold", "N/A")
Does SparseAttentionModule have _method or _threshold attributes?
print_sparse_attention_summary isn't used in this PR, so I've removed it; it will be introduced in the next PR.
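For reference, a hypothetical sketch of what such a summary helper might look like when it returns in a later PR (the body below is an assumption built around the getattr calls in the diff, not the removed implementation):

```python
# Hypothetical sketch of a sparse-attention summary helper; getattr with defaults
# keeps it safe even if a wrapped module lacks the private fields.
import torch.nn as nn


def print_sparse_attention_summary(model: nn.Module, sparse_cls: type) -> None:
    """Print method/threshold for every module that is an instance of sparse_cls."""
    for name, module in model.named_modules():
        if isinstance(module, sparse_cls):
            method = getattr(module, "_method", "unknown")
            threshold = getattr(module, "_threshold", "N/A")
            print(f"{name}: method={method}, threshold={threshold}")
```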
    def restore_sparse_attention_state(model: nn.Module, state_dict: dict[str, Any]):
        """Restore sparse attention state from state dict.

        Args:
            model: Model with sparse attention modules
            state_dict: Saved state dictionary
        """
        for name, module in model.named_modules():
            if isinstance(module, SparseAttentionModule):
                module_name = get_unwrapped_name(name, model)
                if module_name in state_dict:
                    module_state = state_dict[module_name]

                    # Restore method and config
                    if "method" in module_state:
                        module._method = module_state["method"]
                    if "method_config" in module_state:
                        # Restore config attributes
                        for key, val in module_state["method_config"].items():
                            setattr(module, f"_{key}", val)

                    # Re-setup with restored config
                    module._setup()
Do we need to add a test for this?
test_restore_sparse_attention_model covers this function.
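As a rough illustration of the round-trip property such a test checks (the save-side helper is passed in as a parameter because only restore_sparse_attention_state appears in this PR; any save counterpart is an assumption here):

```python
# Sketch of a save/restore round-trip check in the spirit of
# test_restore_sparse_attention_model; save_fn is a stand-in, not a confirmed API.
import copy

import torch.nn as nn


def check_sparse_attention_roundtrip(model: nn.Module, save_fn, restore_fn) -> None:
    state = save_fn(model)            # capture per-module method + method_config
    restored = copy.deepcopy(model)
    restore_fn(restored, state)       # re-applies the config and calls module._setup()
    # The state captured after restore should match the original snapshot.
    assert save_fn(restored) == state
```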
(force-pushed cd6fce2 to 0ca4d20)
@kevalmorabia97 I've addressed the review suggestions. Could you please review and approve the PR so I can move forward with the subsequent PRs? Thanks.
(force-pushed 0ca4d20 to 02182f8)
kevalmorabia97 left a comment:
Only reviewed the high-level structure. I'd suggest someone else take a look at the core sparsity logic, as I'm not familiar with it.
(force-pushed 02182f8 to 8acf333)
Looks great! Should we have a simpler high-level usage which aligns with […]?
I reviewed and approved, thanks! @kevalmorabia97
Good point. Initially, I wanted to separate attention sparsity from weight sparsity, so I chose different APIs for each feature. But using a single API is indeed more consistent; I'll submit a follow-up PR to unify them.
Please revert the license header diff in this file.
## What does this PR do?

**Type of change:** New feature

**Overview:** This PR adds sparse attention support to ModelOpt, applying attention sparsity through a skip-softmax method to enable inference speedups for LLMs.

Key Features:
- Skip softmax support
- Sparse attention config
- Extensible method registry for future sparse attention algorithms
- HuggingFace Transformers integration
- Phase-aware thresholds (separate prefill/decode)

[Design doc](https://docs.google.com/document/d/1OgmTAKkoD4ZSWYXel-FeaQqmI5PtyNhQ4dEuhGiZAQQ/edit?tab=t.0#heading=h.dyp44woziy9x)

## Usage

```python
import torch
import modelopt.torch.sparsity.attention_sparsity as mts
from transformers import AutoModelForCausalLM

# Load model (must use eager attention for softmax patching)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="eager",  # Required!
    torch_dtype=torch.bfloat16,
)

# Use pre-defined configuration
from modelopt.torch.sparsity.attention_sparsity import SKIP_SOFTMAX_DEFAULT

model = mts.sparsify(model, SKIP_SOFTMAX_DEFAULT)
```

## Testing

### Unit Tests

```bash
pytest tests/unit/torch/sparsity/attention_sparsity -v
pytest tests/gpu/torch/sparsity/attention_sparsity -v
pytest tests/examples/llm_sparsity/attention_sparsity -v
```

All passed.

### Accuracy

Benchmark: MMLU
Model: Qwen/Qwen3-4B
Cmd: `python mmlu.py --model_name causal --model_path Qwen/Qwen3-4B --sparse_cfg SKIP_SOFTMAX_DEFAULT`

|                      | MMLU  |
|----------------------|-------|
| BF16                 | 69.96 |
| SKIP_SOFTMAX_DEFAULT | 69.86 |

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

Signed-off-by: Kai Xu <[email protected]>
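For intuition on the skip-softmax idea itself, here is a toy PyTorch sketch (not ModelOpt's implementation; the threshold semantics are an assumption): scores far below the per-row maximum would contribute near-zero softmax weight, so they are masked out and their work skipped.

```python
# Toy illustration of skip softmax (not ModelOpt's actual kernel): zero out softmax
# weights for scores more than `threshold` below the per-row max, since
# exp(-threshold) relative weight is negligible.
import torch


def skip_softmax(scores: torch.Tensor, threshold: float = 10.0) -> torch.Tensor:
    """Softmax that drops entries far below the row maximum."""
    row_max = scores.amax(dim=-1, keepdim=True)
    keep = scores >= (row_max - threshold)            # exp(-10) ≈ 4.5e-5 relative weight
    masked = scores.masked_fill(~keep, float("-inf"))
    return torch.softmax(masked, dim=-1)              # skipped entries become exactly 0


scores = torch.randn(1, 8, 16, 16) * 5                # (batch, heads, q_len, k_len)
probs = skip_softmax(scores)
print((probs == 0).float().mean())                    # fraction of skipped entries
```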