
Conversation

@kaix-nv
Contributor

@kaix-nv kaix-nv commented Nov 7, 2025

What does this PR do?

Type of change: New feature

Overview:
This PR adds sparse attention support to ModelOpt, applying attention sparsity via the skip-softmax method to enable inference speedups for LLMs.

Key Features:

  • Skip softmax support
  • Sparse attention config
  • Extensible method registry for future sparse attention algorithms
  • HuggingFace Transformers integration
  • Phase-aware thresholds (separate prefill/decode)

Design doc: https://docs.google.com/document/d/1OgmTAKkoD4ZSWYXel-FeaQqmI5PtyNhQ4dEuhGiZAQQ/edit?tab=t.0#heading=h.dyp44woziy9x

Usage

import torch
import modelopt.torch.sparsity.attention_sparsity as mts
from transformers import AutoModelForCausalLM

# Load model (must use eager attention for softmax patching)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="eager",  # Required!
    torch_dtype=torch.bfloat16,
)

# Use pre-defined configuration
from modelopt.torch.sparsity.attention_sparsity import SKIP_SOFTMAX_DEFAULT
model = mts.sparsify(model, SKIP_SOFTMAX_DEFAULT)
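
For intuition, below is a minimal eager-mode sketch of the thresholding idea behind skip softmax. It is illustrative only, not the PR's implementation (which patches the attention softmax and applies separate prefill/decode thresholds), and the exact threshold semantics here are an assumption.

import torch
import torch.nn.functional as F

def skip_softmax_reference(scores: torch.Tensor, threshold: float) -> torch.Tensor:
    """Toy reference: drop scores far below each row's max before softmax.

    scores: [batch, heads, q_len, k_len] pre-softmax attention logits.
    Entries more than `threshold` below the row max contribute at most
    exp(-threshold) to the softmax, so a fused kernel can skip them with
    negligible error.
    """
    row_max = scores.amax(dim=-1, keepdim=True)
    masked = scores.masked_fill(scores < row_max - threshold, float("-inf"))
    return F.softmax(masked, dim=-1)

# Example: keep only entries within 10.0 of each row's max score
probs = skip_softmax_reference(torch.randn(1, 2, 4, 4), threshold=10.0)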

Testing

Unit Test

pytest tests/unit/torch/sparsity/attention_sparsity -v
pytest tests/gpu/torch/sparsity/attention_sparsity -v
pytest tests/examples/llm_sparsity/attention_sparsity -v

ALL PASSED.

Accuracy

Benchmark: MMLU
Model: Qwen/Qwen3-4B
Cmd: python mmlu.py --model_name causal --model_path Qwen/Qwen3-4B --sparse_cfg SKIP_SOFTMAX_DEFAULT

|                      | MMLU  |
|----------------------|-------|
| BF16                 | 69.96 |
| SKIP_SOFTMAX_DEFAULT | 69.86 |

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

@kaix-nv kaix-nv requested a review from a team as a code owner November 7, 2025 07:53
@kaix-nv kaix-nv requested review from realAsma and removed request for realAsma November 7, 2025 07:53
@codecov

codecov bot commented Nov 7, 2025

Codecov Report

❌ Patch coverage is 89.67254% with 41 lines in your changes missing coverage. Please review.
✅ Project coverage is 74.95%. Comparing base (fa84955) to head (cd6fce2).
⚠️ Report is 1 commit behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...ch/sparsity/attention_sparsity/sparse_attention.py | 71.42% | 16 Missing ⚠️ |
| ...pt/torch/sparsity/attention_sparsity/conversion.py | 93.02% | 9 Missing ⚠️ |
| ...y/attention_sparsity/methods/flash_skip_softmax.py | 90.90% | 8 Missing ⚠️ |
| ...ch/sparsity/attention_sparsity/methods/registry.py | 88.88% | 3 Missing ⚠️ |
| modelopt/torch/sparsity/attention_sparsity/mode.py | 90.32% | 3 Missing ⚠️ |
| ...delopt/torch/sparsity/attention_sparsity/config.py | 95.91% | 2 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #527      +/-   ##
==========================================
+ Coverage   74.64%   74.95%   +0.31%     
==========================================
  Files         183      192       +9     
  Lines       18542    18939     +397     
==========================================
+ Hits        13840    14196     +356     
- Misses       4702     4743      +41     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.


@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_core branch 4 times, most recently from 54bfe2c to 0ce1376 Compare November 8, 2025 03:31
@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_core branch from fc9d285 to 5d027e0 Compare November 11, 2025 23:44
@kaix-nv kaix-nv changed the title [2/n] Add Core Sparse Attention Infrastructure [OMNIML-2852][2/n] Add Core Sparse Attention Infrastructure Nov 12, 2025
@kaix-nv kaix-nv changed the title [OMNIML-2852][2/n] Add Core Sparse Attention Infrastructure [OMNIML-2852] [2/n] Add Core Sparse Attention Infrastructure Nov 12, 2025
@cjluo-nv
Collaborator

Hi @kaix-nv, could you further split this code change? This PR has 3000+ lines of changes and many file moves.

@kevalmorabia97 kevalmorabia97 removed the request for review from RalphMao December 1, 2025 19:07


# Create registry for sparse attention modules
SparseAttentionRegistry = _DMRegistryCls("SparseAttention", SparseAttentionModule)
Collaborator

@kevalmorabia97 kevalmorabia97 Dec 1, 2025

Can we use a single registry for all Sparsity algorithms and modes and then use top-level mts.sparsify(model, mode=...) so all algorithms (e.g. weight or attention sparsify) are invoked by single shared API instead of separate API per algorithm?

Contributor Author

This is good advice. I'll submit a follow-up PR later.

Collaborator

@jy-yuan jy-yuan left a comment

Great work on the overall architecture!

Comment on lines +193 to +194
total_blocks = (
    num_block_rows * (num_block_rows + 1) // 2  # Causal: N(N+1)/2
Collaborator

Does that mean rows == columns? In other words, we only support causal self-attention, not cross-attention?

Contributor Author

Yes, it's for causal attention in prefill.
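
For reference, a small worked example of that block count under the causal self-attention assumption (the numbers below are illustrative):

# Causal prefill keeps only the lower-triangular blocks of the score matrix,
# so an N x N block grid has N(N+1)/2 blocks instead of N^2.
num_block_rows = 8
causal_blocks = num_block_rows * (num_block_rows + 1) // 2  # 36
dense_blocks = num_block_rows * num_block_rows              # 64
print(causal_blocks, dense_blocks)  # 36 64 -> the causal mask alone drops ~44% of blocks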

"--backend",
type=str,
default="pytorch",
choices=["pytorch", "triton"],
Collaborator

Is "triton" a TODO?

Contributor Author

Yes, the Triton backend is a TODO.

Comment on lines 351 to 352
method = getattr(module, "_method", "unknown")
threshold = getattr(module, "_threshold", "N/A")
Collaborator

Does SparseAttentionModule have _method or _threshold?

Contributor Author

print_sparse_attention_summary isn't used in this PR, so I've removed it; it will be introduced in the next PR.

Comment on lines +188 to +210
def restore_sparse_attention_state(model: nn.Module, state_dict: dict[str, Any]):
    """Restore sparse attention state from state dict.

    Args:
        model: Model with sparse attention modules
        state_dict: Saved state dictionary
    """
    for name, module in model.named_modules():
        if isinstance(module, SparseAttentionModule):
            module_name = get_unwrapped_name(name, model)
            if module_name in state_dict:
                module_state = state_dict[module_name]

                # Restore method and config
                if "method" in module_state:
                    module._method = module_state["method"]
                if "method_config" in module_state:
                    # Restore config attributes
                    for key, val in module_state["method_config"].items():
                        setattr(module, f"_{key}", val)

                # Re-setup with restored config
                module._setup()
Collaborator

Do we need to add a test for this?

Contributor Author

test_restore_sparse_attention_model covers this function.
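
For readers skimming the thread, a rough sketch of such a round-trip check; the module path, method name, and threshold key below are illustrative assumptions, not the actual test_restore_sparse_attention_model:

# Hypothetical sketch: hand-build a state dict and verify the attributes land on
# the converted module (assumes `sparse_model` came from mts.sparsify and that
# "model.layers.0.self_attn" is a SparseAttentionModule).
state = {
    "model.layers.0.self_attn": {
        "method": "flash_skip_softmax",
        "method_config": {"threshold": 1e-4},
    }
}
restore_sparse_attention_state(sparse_model, state)
attn = sparse_model.get_submodule("model.layers.0.self_attn")
assert attn._method == "flash_skip_softmax"
assert attn._threshold == 1e-4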

@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_core branch from cd6fce2 to 0ca4d20 Compare December 8, 2025 21:32
@kaix-nv
Contributor Author

kaix-nv commented Dec 8, 2025

@kevalmorabia97 I've addressed the review suggestions. Could you please review and approve the PR so I can move forward with the subsequent PRs? Thanks.

@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_core branch from 0ca4d20 to 02182f8 Compare December 9, 2025 00:05
Collaborator

@kevalmorabia97 kevalmorabia97 left a comment

Only reviewed high-level structure. Would suggest someone else take a look at the core sparsity logic as I'm not familiar with it

@kaix-nv
Contributor Author

kaix-nv commented Dec 9, 2025

Only reviewed high-level structure. Would suggest someone else take a look at the core sparsity logic as I'm not familiar with it
I've asked @jy-yuan to review since he's very familiar with the core logic.
@jy-yuan Please approve if you think the PR is in good shape. Thanks.

@kaix-nv kaix-nv force-pushed the kaix/sparse_attention_core branch from 02182f8 to 8acf333 Compare December 9, 2025 15:41
@kaix-nv kaix-nv requested a review from a team as a code owner December 9, 2025 15:41
@realAsma
Contributor

Looks great!

Should we have a simpler high-level usage that aligns with mtq?

# Use pre-defined configuration
model = mts.sparsify(model, mts.SPARSE_ATTEN_SKIP_SOFTMAX_CFG)

@jy-yuan
Collaborator

jy-yuan commented Dec 10, 2025

Only reviewed high-level structure. Would suggest someone else take a look at the core sparsity logic as I'm not familiar with it
I've asked @jy-yuan to review since he's very familiar with the core logic.
@jy-yuan Please approve if you think the PR is in good shape. Thanks.

I reviewed and approve, thanks! @kevalmorabia97

@kaix-nv
Contributor Author

kaix-nv commented Dec 10, 2025

Looks great!

Should we have a simpler high-level usage which aligns with mtq?

# Use pre-defined configuration
model = mts.sparsify(model, mts.SPARSE_ATTEN_SKIP_SOFTMAX_CFG)

Good point. Initially, I wanted to keep attention sparsity separate from weight sparsity, so I chose different APIs for each feature. But a single API is indeed more consistent; I'll submit a follow-up PR to unify them.

@kaix-nv kaix-nv enabled auto-merge (squash) December 11, 2025 00:13
Collaborator

Please revert the license header diff in this file.

@kaix-nv kaix-nv merged commit cd0d185 into main Dec 11, 2025
38 of 40 checks passed
@kaix-nv kaix-nv deleted the kaix/sparse_attention_core branch December 11, 2025 20:06
b7r6 pushed a commit to weyl-ai/Model-Optimizer that referenced this pull request Dec 18, 2025
soodoshll pushed a commit to soodoshll/TensorRT-Model-Optimizer that referenced this pull request Dec 18, 2025