feat: add transformer head pruning to Pruner#29
Open
nathanhubens wants to merge 3 commits into
Open
Conversation
Enable structured pruning of attention heads via torch-pruning's built-in head pruning support. Adds head_pruning_ratio, prune_num_heads, and prune_head_dims parameters with backward-compatible defaults. - Refactor _detect_attention_heads: populates num_heads without ignoring QKV layers when head pruning is active (timm-style .qkv/.qkv_proj) - Add _sync_attention_attrs: reads from pruner's live num_heads state instead of stale snapshot (fixes post-pruning corruption) - Auto-enable XOR pattern: head_pruning_ratio > 0 sets prune_num_heads=True, prune_head_dims=False (following torch-pruning convention) - Fix PruneCallback: forward example_inputs to Pruner (was using default 224x224, breaking transformer tracing) - Update print_sparsity to show head count changes - Add unit + slow integration tests for head pruning
timm's Attention.forward uses reshape(B, N, C) where C comes from the input shape. After head pruning, the attention output dimension differs from the input dimension, causing reshape failures. Following the official torch-pruning ViT example pattern, the Pruner now automatically patches timm Attention.forward to use reshape(B, N, -1) when head pruning is enabled. This makes head pruning work with any combination of pruning_ratio and head_pruning_ratio. Also adds head_pruning tutorial with prune-then-fine-tune workflow.
- Patch timm Attention.forward to use reshape(B,N,-1) for pruning compatibility (from official torch-pruning prune_timm_vit.py example) - Freeze head pruning after first application to prevent over-pruning (torch-pruning computes removal from current count, not original) - Add tutorials: head_pruning, conv_decomposer, wanda, torchao - Head pruning tutorial includes prune_every prototype (WIP)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
head_pruning_ratio,prune_num_heads,prune_head_dimsparameters toPrunerfor structured attention head removal_detect_attention_heads()populatesnum_headsdict without ignoring QKV layers when head pruning is active_sync_attention_attrs()that reads from torch-pruning's livenum_headsstate (fixes stale-snapshot bug in oldrestore_attention_layers)head_pruning_ratio > 0→prune_num_heads=True, prune_head_dims=False(matching torch-pruning convention)PruneCallback.before_fitto forwardexample_inputsto Pruner (was falling back to default 224×224, breaking transformer model tracing)print_sparsity()to report attention head count changesSupported architectures
.qkvLinear +.num_heads): full head pruning supportnn.MultiheadAttention: continues to be ignored (torch-pruning's head pruning requires.out_featureson QKV layers, which native MHA lacks — uses rawin_proj_weightparameter)Usage
Test plan
nbdev-testsuite passes