feat: add transformer head pruning to Pruner #29

Open
nathanhubens wants to merge 3 commits into master from feature/head-pruning

@nathanhubens
Collaborator

Summary

  • Add head_pruning_ratio, prune_num_heads, prune_head_dims parameters to Pruner for structured attention head removal
  • Refactor attention detection: _detect_attention_heads() populates num_heads dict without ignoring QKV layers when head pruning is active
  • Add _sync_attention_attrs() that reads from torch-pruning's live num_heads state (fixes stale-snapshot bug in old restore_attention_layers)
  • Auto-enable XOR pattern: head_pruning_ratio > 0 sets prune_num_heads=True, prune_head_dims=False (matching torch-pruning convention)
  • Fix PruneCallback.before_fit to forward example_inputs to Pruner (was falling back to default 224×224, breaking transformer model tracing)
  • Update print_sparsity() to report attention head count changes
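
The auto-enable rule in the summary can be illustrated with a small sketch. `resolve_head_flags` is a hypothetical helper, not a function from this PR; the assumptions are that both flags stay off when `head_pruning_ratio` is 0 and that explicitly passed values take precedence over the automatic ones.

```python
from typing import Optional, Tuple


def resolve_head_flags(
    head_pruning_ratio: float = 0.0,
    prune_num_heads: Optional[bool] = None,
    prune_head_dims: Optional[bool] = None,
) -> Tuple[bool, bool]:
    """Hypothetical helper: derive (prune_num_heads, prune_head_dims)."""
    if head_pruning_ratio > 0:
        # Auto-enable only the flags the caller left unset (assumption).
        if prune_num_heads is None:
            prune_num_heads = True
        if prune_head_dims is None:
            prune_head_dims = False
    # With no head pruning requested, unset flags fall back to False (assumption).
    return bool(prune_num_heads), bool(prune_head_dims)


print(resolve_head_flags(0.5))  # whole heads removed, per-head dims kept
print(resolve_head_flags(0.0))  # head pruning disabled
```

Exactly one flag ends up enabled in the automatic case, which is what "XOR" refers to: heads are removed whole rather than having their per-head dimension shrunk.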

Supported architectures

  • timm-style attention (modules with .qkv Linear + .num_heads): full head pruning support
  • nn.MultiheadAttention: continues to be ignored (torch-pruning's head pruning requires .out_features on QKV layers, which native MHA lacks because it stores a raw in_proj_weight parameter instead)
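
The distinction between the two module styles can be sketched with duck typing. Everything below is illustrative: `FakeLinear`, `TimmStyleAttention`, `NativeStyleMHA`, and `detect_attention_heads` are stand-ins written for this example, not the PR's `_detect_attention_heads`, and plain Python classes replace the real PyTorch modules so the snippet runs without torch. The head-count dict is keyed by the QKV layer object, on the assumption that this mirrors the mapping torch-pruning expects.

```python
class FakeLinear:
    """Stand-in for nn.Linear: only the attributes the check relies on."""
    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features


class TimmStyleAttention:
    """Stand-in for a timm-style block: fused .qkv layer plus .num_heads."""
    def __init__(self, dim, num_heads):
        self.num_heads = num_heads
        self.qkv = FakeLinear(dim, dim * 3)


class NativeStyleMHA:
    """Stand-in for native MHA: raw in_proj_weight, no .qkv layer."""
    def __init__(self, dim, num_heads):
        self.num_heads = num_heads
        self.in_proj_weight = object()


def detect_attention_heads(named_modules):
    """Return ({qkv_layer: num_heads}, [names of skipped attention modules])."""
    num_heads, skipped = {}, []
    for name, module in named_modules:
        qkv = getattr(module, "qkv", None)
        heads = getattr(module, "num_heads", None)
        if heads is None:
            continue  # not an attention module
        if qkv is not None and hasattr(qkv, "out_features"):
            num_heads[qkv] = heads  # prunable: QKV exposes out_features
        else:
            skipped.append(name)    # e.g. native MHA with a raw parameter
    return num_heads, skipped


modules = [("blocks.0.attn", TimmStyleAttention(512, 8)),
           ("decoder.mha", NativeStyleMHA(512, 8))]
num_heads, skipped = detect_attention_heads(modules)
print(list(num_heads.values()), skipped)
```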

Usage

# Standalone
pruner = Pruner(model, pruning_ratio=0.3, context='local', criteria=large_final,
                example_inputs=x, head_pruning_ratio=0.5)
pruner.prune_model()

# With PruneCallback
cb = PruneCallback(pruning_ratio=30, schedule=agp, context='global',
                   criteria=large_final, head_pruning_ratio=0.5)
learn.fit(10, cbs=[cb])

Test plan

  • Unit tests: head detection, auto-enable XOR, ratio normalization, backward compat (CNN unchanged)
  • Slow integration test: head pruning on 4-layer ViT (8→4 heads, param reduction verified)
  • Full nbdev-test suite passes
  • Pre-push validator: nbdev-export + nbdev-clean + clean checkout verified

Enable structured pruning of attention heads via torch-pruning's built-in
head pruning support. Adds head_pruning_ratio, prune_num_heads, and
prune_head_dims parameters with backward-compatible defaults.

- Refactor _detect_attention_heads: populates num_heads without ignoring
  QKV layers when head pruning is active (timm-style .qkv/.qkv_proj)
- Add _sync_attention_attrs: reads from pruner's live num_heads state
  instead of stale snapshot (fixes post-pruning corruption)
- Auto-enable XOR pattern: head_pruning_ratio > 0 sets prune_num_heads=True,
  prune_head_dims=False (following torch-pruning convention)
- Fix PruneCallback: forward example_inputs to Pruner (was using default
  224x224, breaking transformer tracing)
- Update print_sparsity to show head count changes
- Add unit + slow integration tests for head pruning

timm's Attention.forward uses reshape(B, N, C) where C comes from the
input shape. After head pruning, the attention output dimension differs
from the input dimension, causing reshape failures.

Following the official torch-pruning ViT example pattern, the Pruner now
automatically patches timm Attention.forward to use reshape(B, N, -1)
when head pruning is enabled. This makes head pruning work with any
combination of pruning_ratio and head_pruning_ratio.
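
The size mismatch behind this patch can be checked with plain arithmetic. The sketch below does not call timm or torch; `resolve_reshape` is a hypothetical helper that only mimics the element-count check a tensor reshape performs, and the shapes (batch 2, 197 tokens, width 512, 8 heads pruned to 4) are example values chosen for illustration.

```python
from math import prod


def resolve_reshape(numel, shape):
    """Mimic reshape's size check; a single -1 entry is inferred."""
    known = prod(d for d in shape if d != -1)
    if shape.count(-1) == 1:
        if numel % known != 0:
            raise ValueError(f"cannot infer -1 for {shape} from {numel} elements")
        return tuple(numel // known if d == -1 else d for d in shape)
    if known != numel:
        raise ValueError(f"shape {shape} is invalid for input of size {numel}")
    return tuple(shape)


B, N, C = 2, 197, 512        # batch, tokens, width of the block input
head_dim = C // 8            # 64 per head with the original 8 heads
kept_heads = 4               # heads left after head_pruning_ratio=0.5
numel = B * N * kept_heads * head_dim  # elements in the attention output

patched = resolve_reshape(numel, (B, N, -1))  # last dim inferred as 256
try:
    resolve_reshape(numel, (B, N, C))         # original code path
    original_ok = True
except ValueError:
    original_ok = False

print(patched, original_ok)
```

The input width C stays at 512 while the attention output shrinks to 4 × 64 = 256 features per token, so only the `-1` form can absorb the change.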

Also adds a head_pruning tutorial with a prune-then-fine-tune workflow.

- Patch timm Attention.forward to use reshape(B,N,-1) for pruning
  compatibility (from official torch-pruning prune_timm_vit.py example)
- Freeze head pruning after first application to prevent over-pruning
  (torch-pruning computes removal from current count, not original)
- Add tutorials: head_pruning, conv_decomposer, wanda, torchao
- Head pruning tutorial includes prune_every prototype (WIP)
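
The over-pruning that the freeze prevents is easy to see numerically. This is a simplified model, not torch-pruning's code: `heads_after` is a hypothetical function that assumes each application removes `int(current_heads * ratio)` heads, which matches the commit's statement that removal is computed from the current count.

```python
def heads_after(original, ratio, applications, freeze_after_first=False):
    """Simplified model of repeated head pruning on one attention layer."""
    heads = original
    for step in range(applications):
        if freeze_after_first and step > 0:
            break  # head pruning already applied once; leave heads alone
        heads -= int(heads * ratio)  # removal based on the current count
    return heads


unfrozen = heads_after(8, 0.5, applications=3)
frozen = heads_after(8, 0.5, applications=3, freeze_after_first=True)
print(unfrozen, frozen)
```

Under this model, three scheduled pruning steps at ratio 0.5 shrink 8 heads to 4, then 2, then 1, whereas freezing after the first step keeps the intended 4.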