diff --git a/fasterai/_modidx.py b/fasterai/_modidx.py
index 500ab14..e6f0d9d 100644
--- a/fasterai/_modidx.py
+++ b/fasterai/_modidx.py
@@ -223,18 +223,26 @@
'fasterai/misc/bn_folding.py')},
'fasterai.misc.conv_decomposer': { 'fasterai.misc.conv_decomposer.Conv_Decomposer': ( 'misc/conv_decomposer.html#conv_decomposer',
'fasterai/misc/conv_decomposer.py'),
+ 'fasterai.misc.conv_decomposer.Conv_Decomposer.CP': ( 'misc/conv_decomposer.html#conv_decomposer.cp',
+ 'fasterai/misc/conv_decomposer.py'),
+ 'fasterai.misc.conv_decomposer.Conv_Decomposer.SVD': ( 'misc/conv_decomposer.html#conv_decomposer.svd',
+ 'fasterai/misc/conv_decomposer.py'),
+ 'fasterai.misc.conv_decomposer.Conv_Decomposer.Spatial': ( 'misc/conv_decomposer.html#conv_decomposer.spatial',
+ 'fasterai/misc/conv_decomposer.py'),
'fasterai.misc.conv_decomposer.Conv_Decomposer.Tucker': ( 'misc/conv_decomposer.html#conv_decomposer.tucker',
'fasterai/misc/conv_decomposer.py'),
'fasterai.misc.conv_decomposer.Conv_Decomposer.__init__': ( 'misc/conv_decomposer.html#conv_decomposer.__init__',
'fasterai/misc/conv_decomposer.py'),
'fasterai.misc.conv_decomposer.Conv_Decomposer.decompose': ( 'misc/conv_decomposer.html#conv_decomposer.decompose',
'fasterai/misc/conv_decomposer.py'),
+ 'fasterai.misc.conv_decomposer._mode_unfold': ( 'misc/conv_decomposer.html#_mode_unfold',
+ 'fasterai/misc/conv_decomposer.py'),
'fasterai.misc.conv_decomposer._partial_tucker': ( 'misc/conv_decomposer.html#_partial_tucker',
- 'fasterai/misc/conv_decomposer.py'),
- 'fasterai.misc.conv_decomposer._unfold': ( 'misc/conv_decomposer.html#_unfold',
- 'fasterai/misc/conv_decomposer.py')},
+ 'fasterai/misc/conv_decomposer.py')},
'fasterai.misc.cpu_optimizer': { 'fasterai.misc.cpu_optimizer.accelerate_model_for_cpu': ( 'misc/cpu_optimizer.html#accelerate_model_for_cpu',
- 'fasterai/misc/cpu_optimizer.py')},
+ 'fasterai/misc/cpu_optimizer.py'),
+ 'fasterai.misc.cpu_optimizer.optimize_for_cpu': ( 'misc/cpu_optimizer.html#optimize_for_cpu',
+ 'fasterai/misc/cpu_optimizer.py')},
'fasterai.misc.fc_decomposer': { 'fasterai.misc.fc_decomposer.FC_Decomposer': ( 'misc/fc_decomposer.html#fc_decomposer',
'fasterai/misc/fc_decomposer.py'),
'fasterai.misc.fc_decomposer.FC_Decomposer.SVD': ( 'misc/fc_decomposer.html#fc_decomposer.svd',
@@ -242,7 +250,13 @@
'fasterai.misc.fc_decomposer.FC_Decomposer.__init__': ( 'misc/fc_decomposer.html#fc_decomposer.__init__',
'fasterai/misc/fc_decomposer.py'),
'fasterai.misc.fc_decomposer.FC_Decomposer.decompose': ( 'misc/fc_decomposer.html#fc_decomposer.decompose',
- 'fasterai/misc/fc_decomposer.py')},
+ 'fasterai/misc/fc_decomposer.py'),
+ 'fasterai.misc.fc_decomposer._collect_activation_rms': ( 'misc/fc_decomposer.html#_collect_activation_rms',
+ 'fasterai/misc/fc_decomposer.py'),
+ 'fasterai.misc.fc_decomposer._rank_from_energy': ( 'misc/fc_decomposer.html#_rank_from_energy',
+ 'fasterai/misc/fc_decomposer.py'),
+ 'fasterai.misc.fc_decomposer._should_decompose': ( 'misc/fc_decomposer.html#_should_decompose',
+ 'fasterai/misc/fc_decomposer.py')},
'fasterai.prune.all': {},
'fasterai.prune.prune_callback': { 'fasterai.prune.prune_callback.PruneCallback': ( 'prune/prune_callback.html#prunecallback',
'fasterai/prune/prune_callback.py'),
diff --git a/fasterai/misc/all.py b/fasterai/misc/all.py
index 545f64c..f071eec 100644
--- a/fasterai/misc/all.py
+++ b/fasterai/misc/all.py
@@ -1,3 +1,4 @@
from .bn_folding import *
from .fc_decomposer import *
-from .conv_decomposer import *
\ No newline at end of file
+from .conv_decomposer import *
+from .cpu_optimizer import *
\ No newline at end of file
diff --git a/fasterai/misc/bn_folding.py b/fasterai/misc/bn_folding.py
index 758b722..6334d33 100644
--- a/fasterai/misc/bn_folding.py
+++ b/fasterai/misc/bn_folding.py
@@ -6,7 +6,6 @@
# %% ../../nbs/misc/bn_folding.ipynb #productive-preparation
import torch
import torch.nn as nn
-import torch.nn.functional as F
import copy
# %% ../../nbs/misc/bn_folding.ipynb #83000749
diff --git a/fasterai/misc/conv_decomposer.py b/fasterai/misc/conv_decomposer.py
index 933f8c2..81e39b7 100644
--- a/fasterai/misc/conv_decomposer.py
+++ b/fasterai/misc/conv_decomposer.py
@@ -1,86 +1,208 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/misc/conv_decomposer.ipynb.
# %% auto #0
-__all__ = ['Conv_Decomposer']
+__all__ = ['VALID_METHODS', 'Conv_Decomposer']
# %% ../../nbs/misc/conv_decomposer.ipynb #imports
import torch
import torch.nn as nn
import copy
+from einops import rearrange
# %% ../../nbs/misc/conv_decomposer.ipynb #conv-decomposer
-def _unfold(tensor, mode):
- "Unfold a tensor along a mode into a matrix"
- return tensor.moveaxis(mode, 0).flatten(1)
+from .fc_decomposer import _rank_from_energy, _should_decompose
-def _partial_tucker(weight, ranks, n_iter=5):
+def _mode_unfold(W, mode):
+ "Unfold a 4D tensor along a mode into a 2D matrix"
+ return rearrange(W, 'o i h w -> o (i h w)') if mode == 0 else rearrange(W, 'o i h w -> i (o h w)')
+
+def _partial_tucker(weight, ranks, n_iter=10, tol=1e-4):
"Partial Tucker decomposition on modes [0, 1] via alternating SVD (HOOI)"
- # Initialize factors from SVD of mode unfoldings
- U0 = torch.linalg.svd(_unfold(weight, 0), full_matrices=False)[0][:, :ranks[0]]
- U1 = torch.linalg.svd(_unfold(weight, 1), full_matrices=False)[0][:, :ranks[1]]
+ U0 = torch.linalg.svd(_mode_unfold(weight, 0), full_matrices=False)[0][:, :ranks[0]]
+ U1 = torch.linalg.svd(_mode_unfold(weight, 1), full_matrices=False)[0][:, :ranks[1]]
for _ in range(n_iter):
- # Project out mode 0 using U0, then update U1
+ U0_prev, U1_prev = U0.clone(), U1.clone()
proj = torch.einsum('oihw, or -> rihw', weight, U0)
- U1 = torch.linalg.svd(_unfold(proj, 1), full_matrices=False)[0][:, :ranks[1]]
- # Project out mode 1 using U1, then update U0
+ U1 = torch.linalg.svd(_mode_unfold(proj, 1), full_matrices=False)[0][:, :ranks[1]]
proj = torch.einsum('oihw, is -> oshw', weight, U1)
- U0 = torch.linalg.svd(_unfold(proj, 0), full_matrices=False)[0][:, :ranks[0]]
+ U0 = torch.linalg.svd(_mode_unfold(proj, 0), full_matrices=False)[0][:, :ranks[0]]
+ if (U0 - U0_prev).norm() + (U1 - U1_prev).norm() < tol: break
- # Core = W ×₀ U0ᵀ ×₁ U1ᵀ
core = torch.einsum('oihw, or, is -> rshw', weight, U0, U1)
return core, [U0, U1]
+VALID_METHODS = frozenset({'tucker', 'svd', 'spatial', 'cp'})
class Conv_Decomposer:
- "Decompose Conv2d layers using Tucker decomposition to reduce parameters and FLOPs"
+ "Decompose Conv2d layers to reduce parameters and FLOPs"
def __init__(self): pass
def decompose(self,
- model: nn.Module, # The model to decompose
- percent_removed: float = 0.5, # Fraction of rank to remove per mode [0, 1)
+ model: nn.Module, # The model to decompose
+ percent_removed: float = 0.5, # Fraction of rank to remove [0, 1)
+ method: str = 'tucker', # 'tucker', 'svd', 'spatial', or 'cp'
+ energy_threshold: float | None = None, # Auto rank via energy retention (0-1)
+ layers: list[str] | None = None, # Layer names to decompose (None = all eligible)
+ exclude: list[str] | None = None, # Layer names to skip
+ n_iter: int = 10, # Max HOOI iterations (tucker only)
+ tol: float = 1e-4, # HOOI convergence tolerance (tucker only)
) -> nn.Module:
- "Recursively decompose all eligible Conv2d layers in the model"
- if not (0 <= percent_removed < 1):
+ "Decompose eligible Conv2d layers using the specified method."
+ if method not in VALID_METHODS:
+ raise ValueError(f"method must be one of {VALID_METHODS}, got {method!r}")
+ if energy_threshold is None and not (0 <= percent_removed < 1):
raise ValueError(f"percent_removed must be in range [0, 1), got {percent_removed}")
+ if energy_threshold is not None and not (0 < energy_threshold <= 1):
+ raise ValueError(f"energy_threshold must be in range (0, 1], got {energy_threshold}")
+
+ decompose_fn = {'tucker': self.Tucker, 'svd': self.SVD,
+ 'spatial': self.Spatial, 'cp': self.CP}[method]
new_model = copy.deepcopy(model)
- for name in list(new_model._modules):
- module = new_model._modules[name]
- if len(list(module._modules)) > 0:
- new_model._modules[name] = self.decompose(module, percent_removed)
- elif isinstance(module, nn.Conv2d) and module.groups == 1 and min(module.kernel_size) > 1:
- new_model._modules[name] = self.Tucker(module, percent_removed)
+ for name, module in list(new_model.named_modules()):
+ if (isinstance(module, nn.Conv2d) and module.groups == 1
+ and min(module.kernel_size) > 1
+ and _should_decompose(name, layers, exclude)):
+ parent_name, _, child_name = name.rpartition('.')
+ parent = new_model.get_submodule(parent_name) if parent_name else new_model
+ if method == 'tucker':
+ replacement = decompose_fn(module, percent_removed, energy_threshold, n_iter, tol)
+ else:
+ replacement = decompose_fn(module, percent_removed, energy_threshold)
+ setattr(parent, child_name, replacement)
return new_model
- def Tucker(self,
- layer: nn.Conv2d, # The Conv2d layer to decompose
- percent_removed: float, # Fraction of rank to remove per mode
+ def SVD(self,
+ layer: nn.Conv2d,
+ percent_removed: float = 0.5,
+ energy_threshold: float | None = None,
+ ) -> nn.Sequential:
+ "SVD: 2 layers — spatial at reduced output rank + pointwise expansion"
+ W = layer.weight.data
+ C_out, C_in = W.shape[:2]
+ K = layer.kernel_size
+
+ W_2d = rearrange(W, 'o i h w -> o (i h w)')
+
+ U, S, Vh = torch.linalg.svd(W_2d, full_matrices=False)
+ R = _rank_from_energy(S, energy_threshold) if energy_threshold else max(1, int((1 - percent_removed) * min(C_out, C_in)))
+
+ W_first = torch.diag(S[:R]) @ Vh[:R]
+
+ first = nn.Conv2d(C_in, R, K, stride=layer.stride,
+ padding=layer.padding, dilation=layer.dilation, bias=False)
+ first.weight.data = rearrange(W_first, 'r (i h w) -> r i h w', i=C_in, h=K[0], w=K[1])
+
+ last = nn.Conv2d(R, C_out, 1, bias=layer.bias is not None)
+ last.weight.data = rearrange(U[:, :R], 'o r -> o r 1 1')
+ if layer.bias is not None: last.bias.data = layer.bias.data
+
+ return nn.Sequential(first, last)
+
+ def Spatial(self,
+ layer: nn.Conv2d,
+ percent_removed: float = 0.5,
+ energy_threshold: float | None = None,
) -> nn.Sequential:
- "Perform Tucker decomposition on a single Conv2d layer"
+ "Spatial separable: 2 layers — K×1 vertical + 1×K horizontal (batched SVD)"
W = layer.weight.data
C_out, C_in = W.shape[:2]
+ Kh, Kw = layer.kernel_size
+
+ W_spatial = rearrange(W, 'o i h w -> (o i) h w')
+ U_all, S_all, Vh_all = torch.linalg.svd(W_spatial, full_matrices=False)
+ R = _rank_from_energy(S_all[0], energy_threshold) if energy_threshold else max(1, int((1 - percent_removed) * min(Kh, Kw)))
- R_out = max(1, int((1 - percent_removed) * C_out))
- R_in = max(1, int((1 - percent_removed) * C_in))
+ U_scaled = U_all[:, :, :R] * S_all[:, :R].unsqueeze(1).sqrt()
+ W_vert = rearrange(U_scaled, '(o i) h r -> (o r) i h 1', o=C_out, i=C_in)
+
+ Vh_scaled = S_all[:, :R].unsqueeze(2).sqrt() * Vh_all[:, :R, :]
+ Vh_by_out = rearrange(Vh_scaled, '(o i) r w -> o i r w', o=C_out)
+ W_horiz = rearrange(Vh_by_out.mean(dim=1), 'o r w -> o r 1 w')
+
+ vert = nn.Conv2d(C_in, C_out * R, (Kh, 1),
+ stride=(layer.stride[0], 1), padding=(layer.padding[0], 0), bias=False)
+ vert.weight.data = W_vert
+
+ horiz = nn.Conv2d(C_out * R, C_out, (1, Kw), groups=C_out,
+ stride=(1, layer.stride[1]), padding=(0, layer.padding[1]),
+ bias=layer.bias is not None)
+ horiz.weight.data = W_horiz
+ if layer.bias is not None: horiz.bias.data = layer.bias.data
+
+ return nn.Sequential(vert, horiz)
+
+ def CP(self,
+ layer: nn.Conv2d,
+ percent_removed: float = 0.5,
+ energy_threshold: float | None = None,
+ ) -> nn.Sequential:
+ "CP: 4 layers — pointwise compress + depthwise vertical + depthwise horizontal + pointwise expand"
+ W = layer.weight.data
+ C_out, C_in = W.shape[:2]
+ Kh, Kw = layer.kernel_size
+
+ W_2d = rearrange(W, 'o i h w -> o (i h w)')
+ U, S, Vh = torch.linalg.svd(W_2d, full_matrices=False)
+ S0 = torch.linalg.svd(_mode_unfold(W, 0), full_matrices=False)[1]
+ R = _rank_from_energy(S0, energy_threshold) if energy_threshold else max(1, int((1 - percent_removed) * min(C_out, C_in)))
+
+ V_4d = rearrange(Vh[:R], 'r (i h w) -> r i h w', i=C_in, h=Kh, w=Kw)
+ spatial_avg = V_4d.mean(dim=1)
+ U_s, S_s, Vh_s = torch.linalg.svd(spatial_avg, full_matrices=False)
+
+ W_dw_v = rearrange(U_s[:, :, 0] * S_s[:, 0:1].sqrt(), 'r h -> r 1 h 1')
+ W_dw_h = rearrange(Vh_s[:, 0, :] * S_s[:, 0:1].sqrt(), 'r w -> r 1 1 w')
+ channel_norms = V_4d.pow(2).sum(dim=(2, 3)).sqrt()
+ W_pw_in = rearrange(channel_norms * S[:R].sqrt().unsqueeze(1), 'r i -> r i 1 1')
+ W_pw_out = rearrange(U[:, :R] * S[:R].sqrt().unsqueeze(0), 'o r -> o r 1 1')
+
+ pw_in = nn.Conv2d(C_in, R, 1, bias=False)
+ pw_in.weight.data = W_pw_in
+ dw_v = nn.Conv2d(R, R, (Kh, 1), groups=R, stride=(layer.stride[0], 1),
+ padding=(layer.padding[0], 0), bias=False)
+ dw_v.weight.data = W_dw_v
+ dw_h = nn.Conv2d(R, R, (1, Kw), groups=R, stride=(1, layer.stride[1]),
+ padding=(0, layer.padding[1]), bias=False)
+ dw_h.weight.data = W_dw_h
+ pw_out = nn.Conv2d(R, C_out, 1, bias=layer.bias is not None)
+ pw_out.weight.data = W_pw_out
+ if layer.bias is not None: pw_out.bias.data = layer.bias.data
+
+ return nn.Sequential(pw_in, dw_v, dw_h, pw_out)
+
+ def Tucker(self,
+ layer: nn.Conv2d,
+ percent_removed: float = 0.5,
+ energy_threshold: float | None = None,
+ n_iter: int = 10,
+ tol: float = 1e-4,
+ ) -> nn.Sequential:
+ "Tucker: 3 layers — pointwise compress + spatial + pointwise expand"
+ W = layer.weight.data
+ C_out, C_in = W.shape[:2]
- core, (U_out, U_in) = _partial_tucker(W, [R_out, R_in])
- # core: (R_out, R_in, H, W), U_out: (C_out, R_out), U_in: (C_in, R_in)
+ if energy_threshold is not None:
+ S0 = torch.linalg.svd(_mode_unfold(W, 0), full_matrices=False)[1]
+ S1 = torch.linalg.svd(_mode_unfold(W, 1), full_matrices=False)[1]
+ R_out = _rank_from_energy(S0, energy_threshold)
+ R_in = _rank_from_energy(S1, energy_threshold)
+ else:
+ R_out = max(1, int((1 - percent_removed) * C_out))
+ R_in = max(1, int((1 - percent_removed) * C_in))
+ core, (U_out, U_in) = _partial_tucker(W, [R_out, R_in], n_iter=n_iter, tol=tol)
- # 1. Pointwise input compression: (C_in → R_in)
first = nn.Conv2d(C_in, R_in, 1, bias=False)
- first.weight.data = U_in.t().unsqueeze(-1).unsqueeze(-1)
+ first.weight.data = rearrange(U_in.t(), 'r i -> r i 1 1')
- # 2. Spatial convolution at reduced rank: (R_in → R_out)
middle = nn.Conv2d(R_in, R_out, layer.kernel_size, stride=layer.stride,
padding=layer.padding, dilation=layer.dilation, bias=False)
middle.weight.data = core
- # 3. Pointwise output expansion: (R_out → C_out)
last = nn.Conv2d(R_out, C_out, 1, bias=layer.bias is not None)
- last.weight.data = U_out.unsqueeze(-1).unsqueeze(-1)
- if layer.bias is not None:
- last.bias.data = layer.bias.data
+ last.weight.data = rearrange(U_out, 'o r -> o r 1 1')
+ if layer.bias is not None: last.bias.data = layer.bias.data
return nn.Sequential(first, middle, last)
diff --git a/fasterai/misc/cpu_optimizer.py b/fasterai/misc/cpu_optimizer.py
index b5fd869..bf5e91c 100644
--- a/fasterai/misc/cpu_optimizer.py
+++ b/fasterai/misc/cpu_optimizer.py
@@ -1,20 +1,37 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/misc/cpu_optimizer.ipynb.
# %% auto #0
-__all__ = ['accelerate_model_for_cpu']
+__all__ = ['optimize_for_cpu', 'accelerate_model_for_cpu']
# %% ../../nbs/misc/cpu_optimizer.ipynb #fbbccd4a
import torch
import torch.nn as nn
-from torch.utils.mobile_optimizer import optimize_for_mobile
+import warnings
# %% ../../nbs/misc/cpu_optimizer.ipynb #6524ac31
-def accelerate_model_for_cpu(model: nn.Module, example_input: torch.Tensor):
- model.eval()
- example_input = example_input.to(memory_format=torch.channels_last)
-
- model = model.to(memory_format=torch.channels_last)
- model = torch.jit.script(model)
- model = optimize_for_mobile(model)
+def optimize_for_cpu(
+ model: nn.Module, # The PyTorch model to optimize
+ sample: torch.Tensor, # Sample input for tracing (with batch dim)
+ *,
+ backend: str = "compile", # "compile" (torch.compile) or "trace" (torch.jit.trace)
+ compile_mode: str = "default", # torch.compile mode
+) -> nn.Module:
+ "Optimize model for CPU inference via channels-last layout + compilation"
+ model = model.eval().to(memory_format=torch.channels_last)
+ sample = sample.to(memory_format=torch.channels_last)
+
+ if backend == "compile":
+ return torch.compile(model, mode=compile_mode)
+ elif backend == "trace":
+ with torch.no_grad():
+ return torch.jit.trace(model, sample)
+ else:
+ raise ValueError(f"Unknown backend: {backend!r}. Use 'compile' or 'trace'.")
- return model
+def accelerate_model_for_cpu(model: nn.Module, example_input: torch.Tensor):
+ "Deprecated: use optimize_for_cpu() instead"
+ warnings.warn(
+ "accelerate_model_for_cpu is deprecated, use optimize_for_cpu(model, sample) instead",
+ DeprecationWarning, stacklevel=2,
+ )
+ return optimize_for_cpu(model, example_input, backend="trace")
diff --git a/fasterai/misc/fc_decomposer.py b/fasterai/misc/fc_decomposer.py
index c1fa113..745cf46 100644
--- a/fasterai/misc/fc_decomposer.py
+++ b/fasterai/misc/fc_decomposer.py
@@ -6,47 +6,125 @@
# %% ../../nbs/misc/fc_decomposer.ipynb #fbbccd4a
import torch
import torch.nn as nn
-import torch.nn.functional as F
import copy
# %% ../../nbs/misc/fc_decomposer.ipynb #6524ac31
+def _rank_from_energy(S, threshold):
+ "Find minimum rank to retain `threshold` fraction of singular value energy"
+ energy = S.pow(2).cumsum(0) / S.pow(2).sum()
+ idx = (energy >= threshold).nonzero(as_tuple=True)[0]
+ return max(1, int(idx[0].item()) + 1) if len(idx) > 0 else S.shape[0]
+
+def _should_decompose(name, layers=None, exclude=None):
+ "Check if a named layer should be decomposed"
+ if exclude and name in exclude: return False
+ if layers is not None: return name in layers
+ return True
+
+def _collect_activation_rms(
+ model: nn.Module, # Model to calibrate
+ data, # Tensor, list of batches, or DataLoader
+ layer_type: type = nn.Linear, # Layer types to hook
+ n_batches: int = 5, # Max batches to process
+) -> dict[nn.Module, torch.Tensor]:
+ "Collect per-input-channel RMS activation norms via forward hooks"
+ device = next(model.parameters()).device
+ state = {}
+ hooks = []
+ for m in model.modules():
+ if isinstance(m, layer_type):
+ state[m] = {'acc': torch.zeros(m.weight.shape[1], device=device), 'n': 0}
+ def make_hook(module):
+ def hook(mod, inp):
+ x = inp[0].detach()
+ dims = [i for i in range(x.dim()) if i != 1] # keep channel dim
+ state[module]['acc'] += x.pow(2).sum(dim=dims)
+ state[module]['n'] += x.shape[0]
+ return hook
+ hooks.append(m.register_forward_pre_hook(make_hook(m)))
+
+ model.eval()
+ with torch.no_grad():
+ if isinstance(data, torch.Tensor):
+ model(data.to(device))
+ else:
+ for n, batch in enumerate(data):
+ if n >= n_batches: break
+ xb = batch[0] if isinstance(batch, (tuple, list)) else batch
+ model(xb.as_subclass(torch.Tensor).to(device))
+
+ for h in hooks: h.remove()
+ return {m: (s['acc'] / max(s['n'], 1)).sqrt() for m, s in state.items()}
+
+
class FC_Decomposer:
"Decompose fully-connected layers using SVD to reduce parameters"
- def __init__(self):
- pass
+ def __init__(self): pass
def decompose(self,
- model: nn.Module, # The model to decompose
- percent_removed: float = 0.5 # Fraction of singular values to remove [0, 1)
+ model: nn.Module, # The model to decompose
+ percent_removed: float = 0.5, # Fraction of singular values to remove [0, 1)
+ energy_threshold: float | None = None, # Auto rank: keep this fraction of energy (0-1)
+ data = None, # Calibration data for ASVD (None = standard SVD)
+ n_batches: int = 5, # Number of calibration batches
+ layers: list[str] | None = None, # Layer names to decompose (None = all)
+ exclude: list[str] | None = None, # Layer names to skip
) -> nn.Module:
- "Recursively decompose all Linear layers in the model using SVD"
- if not (0 <= percent_removed < 1):
+ "Decompose Linear layers using SVD. Pass data for activation-aware ASVD."
+ if energy_threshold is None and not (0 <= percent_removed < 1):
raise ValueError(f"percent_removed must be in range [0, 1), got {percent_removed}")
+ if energy_threshold is not None and not (0 < energy_threshold <= 1):
+ raise ValueError(f"energy_threshold must be in range (0, 1], got {energy_threshold}")
+
+ # Collect activation stats on ORIGINAL model before deepcopy
+ scale_map = {}
+ if data is not None:
+ rms = _collect_activation_rms(model, data, nn.Linear, n_batches)
+ # Map by name so we can find them after deepcopy
+ for name, m in model.named_modules():
+ if m in rms: scale_map[name] = rms[m]
new_model = copy.deepcopy(model)
- module_names = list(new_model._modules)
-
- for k, name in enumerate(module_names):
- if len(list(new_model._modules[name]._modules)) > 0:
- new_model._modules[name] = self.decompose(new_model._modules[name], percent_removed)
- else:
- if isinstance(new_model._modules[name], nn.Linear):
- layer = self.SVD(new_model._modules[name], percent_removed)
- new_model._modules[name] = layer
+ for name, module in list(new_model.named_modules()):
+ if isinstance(module, nn.Linear) and _should_decompose(name, layers, exclude):
+ scale = scale_map.get(name, None)
+ parent_name, _, child_name = name.rpartition('.')
+ parent = new_model.get_submodule(parent_name) if parent_name else new_model
+ setattr(parent, child_name, self.SVD(module, percent_removed, energy_threshold, scale))
return new_model
-
def SVD(self,
- layer: nn.Linear, # The Linear layer to decompose
- percent_removed: float # Fraction of singular values to remove
+ layer: nn.Linear, # The Linear layer to decompose
+ percent_removed: float = 0.5, # Fraction of singular values to remove
+ energy_threshold: float | None = None, # Auto rank via energy retention
+ scale: torch.Tensor | None = None, # Per-channel activation RMS for ASVD
) -> nn.Sequential:
- "Perform SVD decomposition on a single Linear layer"
+ "Perform SVD decomposition. With scale: activation-aware SVD (ASVD)."
W = layer.weight.data
- U, S, Vh = torch.linalg.svd(W, full_matrices=False)
- L = max(1, int((1.-percent_removed) * S.shape[0]))
+
+ # ASVD: scale columns by activation RMS before SVD
+ if scale is not None:
+ s = scale.to(W.device) + 1e-6
+ W_scaled = W * s.unsqueeze(0) # (out, in) * (1, in)
+ else:
+ W_scaled = W
+
+ U, S, Vh = torch.linalg.svd(W_scaled, full_matrices=False)
+
+ if energy_threshold is not None:
+ L = _rank_from_energy(S, energy_threshold)
+ else:
+ L = max(1, int((1.-percent_removed) * S.shape[0]))
+
W1 = U[:,:L]
W2 = torch.diag(S[:L]) @ Vh[:L]
+
+ # ASVD: undo scaling in the first layer's weights
+ if scale is not None:
+ s_inv = 1.0 / s
+ W2 = W2 * s_inv.unsqueeze(0) # (L, in) * (1, in)
+
layer_1 = nn.Linear(in_features=layer.in_features,
out_features=L, bias=False)
layer_1.weight.data = W2
diff --git a/nbs/_quarto.yml b/nbs/_quarto.yml
index 3490807..972f71b 100644
--- a/nbs/_quarto.yml
+++ b/nbs/_quarto.yml
@@ -114,6 +114,8 @@ website:
contents:
- misc/bn_folding.ipynb
- misc/fc_decomposer.ipynb
+ - misc/conv_decomposer.ipynb
+ - misc/cpu_optimizer.ipynb
- section: Export
contents:
- export/onnx_exporter.ipynb
diff --git a/nbs/misc/bn_folding.ipynb b/nbs/misc/bn_folding.ipynb
index ea50d7c..8f2f01e 100644
--- a/nbs/misc/bn_folding.ipynb
+++ b/nbs/misc/bn_folding.ipynb
@@ -45,7 +45,6 @@
"#| export\n",
"import torch\n",
"import torch.nn as nn\n",
- "import torch.nn.functional as F\n",
"import copy"
]
},
diff --git a/nbs/misc/conv_decomposer.ipynb b/nbs/misc/conv_decomposer.ipynb
index 5edb688..810a5da 100644
--- a/nbs/misc/conv_decomposer.ipynb
+++ b/nbs/misc/conv_decomposer.ipynb
@@ -5,12 +5,7 @@
"id": "frontmatter",
"metadata": {},
"source": [
- "---",
- "description: Decompose Conv2d layers via Tucker decomposition",
- "output-file: conv_decomposer.html",
- "title: Conv2d Layers Decomposer",
- "skip_showdoc: true",
- "---"
+ "---\ndescription: Decompose Conv2d layers via Tucker decomposition\noutput-file: conv_decomposer.html\ntitle: Conv2d Layers Decomposer\nskip_showdoc: true\n---"
]
},
{
@@ -29,13 +24,35 @@
"id": "showdoc-import",
"metadata": {},
"outputs": [],
- "source": "#| include: false\nfrom nbdev.showdoc import *"
+ "source": [
+ "#| include: false\n",
+ "from nbdev.showdoc import *"
+ ]
},
{
"cell_type": "markdown",
"id": "overview",
"metadata": {},
- "source": "## Overview\n\nThe `Conv_Decomposer` class reduces model size and FLOPs by factorizing Conv2d layers into three smaller convolutions using Tucker decomposition. This is the Conv2d counterpart of `FC_Decomposer` (which uses SVD for Linear layers).\n\n**How it works:** A Conv2d weight `[C_out, C_in, H, W]` is decomposed into:\n1. `Conv2d(C_in, R_in, 1)` — pointwise input channel compression\n2. `Conv2d(R_in, R_out, (H, W))` — spatial convolution at reduced rank\n3. `Conv2d(R_out, C_out, 1)` — pointwise output channel expansion\n\n### When to Use\n\n| Scenario | Recommendation |\n|----------|----------------|\n| Large 3x3 or larger convolutions | **Highly recommended** — significant FLOP savings |\n| 1x1 pointwise convolutions | Skipped automatically (already minimal) |\n| Depthwise / grouped convolutions | Skipped (Tucker assumes standard convolution) |\n| First layer (C_in=3) | Works but limited benefit |\n| Post-training compression | Fine-tune after decomposition for best accuracy |"
+ "source": [
+ "## Overview\n",
+ "\n",
+ "The `Conv_Decomposer` class reduces model size and FLOPs by factorizing Conv2d layers into three smaller convolutions using Tucker decomposition. This is the Conv2d counterpart of `FC_Decomposer` (which uses SVD for Linear layers).\n",
+ "\n",
+ "**How it works:** A Conv2d weight `[C_out, C_in, H, W]` is decomposed into:\n",
+ "1. `Conv2d(C_in, R_in, 1)` — pointwise input channel compression\n",
+ "2. `Conv2d(R_in, R_out, (H, W))` — spatial convolution at reduced rank\n",
+ "3. `Conv2d(R_out, C_out, 1)` — pointwise output channel expansion\n",
+ "\n",
+ "### When to Use\n",
+ "\n",
+ "| Scenario | Recommendation |\n",
+ "|----------|----------------|\n",
+ "| Large 3x3 or larger convolutions | **Highly recommended** — significant FLOP savings |\n",
+ "| 1x1 pointwise convolutions | Skipped automatically (already minimal) |\n",
+ "| Depthwise / grouped convolutions | Skipped (Tucker assumes standard convolution) |\n",
+ "| First layer (C_in=3) | Works but limited benefit |\n",
+ "| Post-training compression | Fine-tune after decomposition for best accuracy |"
+ ]
},
{
"cell_type": "code",
@@ -43,7 +60,7 @@
"id": "imports",
"metadata": {},
"outputs": [],
- "source": "#| export\nimport torch\nimport torch.nn as nn\nimport copy"
+ "source": "#| export\nimport torch\nimport torch.nn as nn\nimport copy\nfrom einops import rearrange"
},
{
"cell_type": "code",
@@ -51,7 +68,205 @@
"id": "conv-decomposer",
"metadata": {},
"outputs": [],
- "source": "#| export\ndef _unfold(tensor, mode):\n \"Unfold a tensor along a mode into a matrix\"\n return tensor.moveaxis(mode, 0).flatten(1)\n\ndef _partial_tucker(weight, ranks, n_iter=5):\n \"Partial Tucker decomposition on modes [0, 1] via alternating SVD (HOOI)\"\n # Initialize factors from SVD of mode unfoldings\n U0 = torch.linalg.svd(_unfold(weight, 0), full_matrices=False)[0][:, :ranks[0]]\n U1 = torch.linalg.svd(_unfold(weight, 1), full_matrices=False)[0][:, :ranks[1]]\n\n for _ in range(n_iter):\n # Project out mode 0 using U0, then update U1\n proj = torch.einsum('oihw, or -> rihw', weight, U0)\n U1 = torch.linalg.svd(_unfold(proj, 1), full_matrices=False)[0][:, :ranks[1]]\n # Project out mode 1 using U1, then update U0\n proj = torch.einsum('oihw, is -> oshw', weight, U1)\n U0 = torch.linalg.svd(_unfold(proj, 0), full_matrices=False)[0][:, :ranks[0]]\n\n # Core = W ×₀ U0ᵀ ×₁ U1ᵀ\n core = torch.einsum('oihw, or, is -> rshw', weight, U0, U1)\n return core, [U0, U1]\n\n\nclass Conv_Decomposer:\n \"Decompose Conv2d layers using Tucker decomposition to reduce parameters and FLOPs\"\n\n def __init__(self): pass\n\n def decompose(self,\n model: nn.Module, # The model to decompose\n percent_removed: float = 0.5, # Fraction of rank to remove per mode [0, 1)\n ) -> nn.Module:\n \"Recursively decompose all eligible Conv2d layers in the model\"\n if not (0 <= percent_removed < 1):\n raise ValueError(f\"percent_removed must be in range [0, 1), got {percent_removed}\")\n\n new_model = copy.deepcopy(model)\n for name in list(new_model._modules):\n module = new_model._modules[name]\n if len(list(module._modules)) > 0:\n new_model._modules[name] = self.decompose(module, percent_removed)\n elif isinstance(module, nn.Conv2d) and module.groups == 1 and min(module.kernel_size) > 1:\n new_model._modules[name] = self.Tucker(module, percent_removed)\n return new_model\n\n def Tucker(self,\n layer: nn.Conv2d, # The Conv2d layer to decompose\n percent_removed: float, # 
Fraction of rank to remove per mode\n ) -> nn.Sequential:\n \"Perform Tucker decomposition on a single Conv2d layer\"\n W = layer.weight.data\n C_out, C_in = W.shape[:2]\n\n R_out = max(1, int((1 - percent_removed) * C_out))\n R_in = max(1, int((1 - percent_removed) * C_in))\n\n core, (U_out, U_in) = _partial_tucker(W, [R_out, R_in])\n # core: (R_out, R_in, H, W), U_out: (C_out, R_out), U_in: (C_in, R_in)\n\n # 1. Pointwise input compression: (C_in → R_in)\n first = nn.Conv2d(C_in, R_in, 1, bias=False)\n first.weight.data = U_in.t().unsqueeze(-1).unsqueeze(-1)\n\n # 2. Spatial convolution at reduced rank: (R_in → R_out)\n middle = nn.Conv2d(R_in, R_out, layer.kernel_size, stride=layer.stride,\n padding=layer.padding, dilation=layer.dilation, bias=False)\n middle.weight.data = core\n\n # 3. Pointwise output expansion: (R_out → C_out)\n last = nn.Conv2d(R_out, C_out, 1, bias=layer.bias is not None)\n last.weight.data = U_out.unsqueeze(-1).unsqueeze(-1)\n if layer.bias is not None:\n last.bias.data = layer.bias.data\n\n return nn.Sequential(first, middle, last)"
+ "source": [
+ "#| export\n",
+ "from fasterai.misc.fc_decomposer import _rank_from_energy, _should_decompose\n",
+ "\n",
+ "def _mode_unfold(W, mode):\n",
+ " \"Unfold a 4D tensor along a mode into a 2D matrix\"\n",
+ " return rearrange(W, 'o i h w -> o (i h w)') if mode == 0 else rearrange(W, 'o i h w -> i (o h w)')\n",
+ "\n",
+ "def _partial_tucker(weight, ranks, n_iter=10, tol=1e-4):\n",
+ " \"Partial Tucker decomposition on modes [0, 1] via alternating SVD (HOOI)\"\n",
+ " U0 = torch.linalg.svd(_mode_unfold(weight, 0), full_matrices=False)[0][:, :ranks[0]]\n",
+ " U1 = torch.linalg.svd(_mode_unfold(weight, 1), full_matrices=False)[0][:, :ranks[1]]\n",
+ "\n",
+ " for _ in range(n_iter):\n",
+ " U0_prev, U1_prev = U0.clone(), U1.clone()\n",
+ " proj = torch.einsum('oihw, or -> rihw', weight, U0)\n",
+ " U1 = torch.linalg.svd(_mode_unfold(proj, 1), full_matrices=False)[0][:, :ranks[1]]\n",
+ " proj = torch.einsum('oihw, is -> oshw', weight, U1)\n",
+ " U0 = torch.linalg.svd(_mode_unfold(proj, 0), full_matrices=False)[0][:, :ranks[0]]\n",
+ " if (U0 - U0_prev).norm() + (U1 - U1_prev).norm() < tol: break\n",
+ "\n",
+ " core = torch.einsum('oihw, or, is -> rshw', weight, U0, U1)\n",
+ " return core, [U0, U1]\n",
+ "\n",
+ "VALID_METHODS = frozenset({'tucker', 'svd', 'spatial', 'cp'})\n",
+ "\n",
+ "class Conv_Decomposer:\n",
+ " \"Decompose Conv2d layers to reduce parameters and FLOPs\"\n",
+ "\n",
+ " def __init__(self): pass\n",
+ "\n",
+ " def decompose(self,\n",
+ " model: nn.Module, # The model to decompose\n",
+ " percent_removed: float = 0.5, # Fraction of rank to remove [0, 1)\n",
+ " method: str = 'tucker', # 'tucker', 'svd', 'spatial', or 'cp'\n",
+ " energy_threshold: float | None = None, # Auto rank via energy retention (0-1)\n",
+ " layers: list[str] | None = None, # Layer names to decompose (None = all eligible)\n",
+ " exclude: list[str] | None = None, # Layer names to skip\n",
+ " n_iter: int = 10, # Max HOOI iterations (tucker only)\n",
+ " tol: float = 1e-4, # HOOI convergence tolerance (tucker only)\n",
+ " ) -> nn.Module:\n",
+ " \"Decompose eligible Conv2d layers using the specified method.\"\n",
+ " if method not in VALID_METHODS:\n",
+ " raise ValueError(f\"method must be one of {VALID_METHODS}, got {method!r}\")\n",
+ " if energy_threshold is None and not (0 <= percent_removed < 1):\n",
+ " raise ValueError(f\"percent_removed must be in range [0, 1), got {percent_removed}\")\n",
+ " if energy_threshold is not None and not (0 < energy_threshold <= 1):\n",
+ " raise ValueError(f\"energy_threshold must be in range (0, 1], got {energy_threshold}\")\n",
+ "\n",
+ " decompose_fn = {'tucker': self.Tucker, 'svd': self.SVD,\n",
+ " 'spatial': self.Spatial, 'cp': self.CP}[method]\n",
+ "\n",
+ " new_model = copy.deepcopy(model)\n",
+ " for name, module in list(new_model.named_modules()):\n",
+ " if (isinstance(module, nn.Conv2d) and module.groups == 1 \n",
+ " and min(module.kernel_size) > 1\n",
+ " and _should_decompose(name, layers, exclude)):\n",
+ " parent_name, _, child_name = name.rpartition('.')\n",
+ " parent = new_model.get_submodule(parent_name) if parent_name else new_model\n",
+ " if method == 'tucker':\n",
+ " replacement = decompose_fn(module, percent_removed, energy_threshold, n_iter, tol)\n",
+ " else:\n",
+ " replacement = decompose_fn(module, percent_removed, energy_threshold)\n",
+ " setattr(parent, child_name, replacement)\n",
+ " return new_model\n",
+ "\n",
+ " def SVD(self,\n",
+ " layer: nn.Conv2d,\n",
+ " percent_removed: float = 0.5,\n",
+ " energy_threshold: float | None = None,\n",
+ " ) -> nn.Sequential:\n",
+ " \"SVD: 2 layers — spatial at reduced output rank + pointwise expansion\"\n",
+ " W = layer.weight.data\n",
+ " C_out, C_in = W.shape[:2]\n",
+ " K = layer.kernel_size\n",
+ "\n",
+ " W_2d = rearrange(W, 'o i h w -> o (i h w)')\n",
+ "\n",
+ " U, S, Vh = torch.linalg.svd(W_2d, full_matrices=False)\n",
+ " R = _rank_from_energy(S, energy_threshold) if energy_threshold else max(1, int((1 - percent_removed) * min(C_out, C_in)))\n",
+ "\n",
+ " W_first = torch.diag(S[:R]) @ Vh[:R]\n",
+ "\n",
+ " first = nn.Conv2d(C_in, R, K, stride=layer.stride,\n",
+ " padding=layer.padding, dilation=layer.dilation, bias=False)\n",
+ " first.weight.data = rearrange(W_first, 'r (i h w) -> r i h w', i=C_in, h=K[0], w=K[1])\n",
+ "\n",
+ " last = nn.Conv2d(R, C_out, 1, bias=layer.bias is not None)\n",
+ " last.weight.data = rearrange(U[:, :R], 'o r -> o r 1 1')\n",
+ " if layer.bias is not None: last.bias.data = layer.bias.data\n",
+ "\n",
+ " return nn.Sequential(first, last)\n",
+ "\n",
+ "    def Spatial(self,\n",
+ "        layer: nn.Conv2d,                       # Conv2d layer to factorize\n",
+ "        percent_removed: float = 0.5,           # Fraction of spatial rank to drop (used when energy_threshold is unset)\n",
+ "        energy_threshold: float | None = None,  # If set, choose the rank from retained singular-value energy\n",
+ "    ) -> nn.Sequential:\n",
+ "        \"Spatial separable: 2 layers — K×1 vertical + 1×K horizontal (batched SVD)\"\n",
+ "        W = layer.weight.data\n",
+ "        C_out, C_in = W.shape[:2]\n",
+ "        Kh, Kw = layer.kernel_size\n",
+ "\n",
+ "        W_spatial = rearrange(W, 'o i h w -> (o i) h w')\n",
+ "        U_all, S_all, Vh_all = torch.linalg.svd(W_spatial, full_matrices=False)  # one SVD per (o, i) kernel\n",
+ "        R = _rank_from_energy(S_all.pow(2).sum(dim=0).sqrt(), energy_threshold) if energy_threshold else max(1, int((1 - percent_removed) * min(Kh, Kw)))  # energy pooled over every kernel, not only kernel 0\n",
+ "\n",
+ "        U_scaled = U_all[:, :, :R] * S_all[:, :R].unsqueeze(1).sqrt()\n",
+ "        W_vert = rearrange(U_scaled, '(o i) h r -> (o r) i h 1', o=C_out, i=C_in)\n",
+ "\n",
+ "        Vh_scaled = S_all[:, :R].unsqueeze(2).sqrt() * Vh_all[:, :R, :]\n",
+ "        Vh_by_out = rearrange(Vh_scaled, '(o i) r w -> o i r w', o=C_out)\n",
+ "        W_horiz = rearrange(Vh_by_out.mean(dim=1), 'o r w -> o r 1 w')  # horizontal factors are averaged over input channels\n",
+ "\n",
+ "        vert = nn.Conv2d(C_in, C_out * R, (Kh, 1), stride=(layer.stride[0], 1),\n",
+ "                         padding=(layer.padding[0], 0), dilation=(layer.dilation[0], 1), bias=False)\n",
+ "        vert.weight.data = W_vert\n",
+ "\n",
+ "        horiz = nn.Conv2d(C_out * R, C_out, (1, Kw), groups=C_out,\n",
+ "                          stride=(1, layer.stride[1]), padding=(0, layer.padding[1]),\n",
+ "                          dilation=(1, layer.dilation[1]), bias=layer.bias is not None)  # dilation split per axis, like stride/padding\n",
+ "        horiz.weight.data = W_horiz\n",
+ "        if layer.bias is not None: horiz.bias.data = layer.bias.data\n",
+ "\n",
+ "        return nn.Sequential(vert, horiz)\n",
+ "\n",
+ "    def CP(self,\n",
+ "        layer: nn.Conv2d,                       # Conv2d layer to factorize\n",
+ "        percent_removed: float = 0.5,           # Fraction of rank to drop (used when energy_threshold is unset)\n",
+ "        energy_threshold: float | None = None,  # If set, choose the rank from retained singular-value energy\n",
+ "    ) -> nn.Sequential:\n",
+ "        \"CP: 4 layers — pointwise compress + depthwise vertical + depthwise horizontal + pointwise expand\"\n",
+ "        W = layer.weight.data\n",
+ "        C_out, C_in = W.shape[:2]\n",
+ "        Kh, Kw = layer.kernel_size\n",
+ "\n",
+ "        W_2d = rearrange(W, 'o i h w -> o (i h w)')\n",
+ "        U, S, Vh = torch.linalg.svd(W_2d, full_matrices=False)\n",
+ "        S0 = S  # W_2d is exactly the mode-0 unfolding of W, so reuse its singular values instead of a second SVD\n",
+ "        R = _rank_from_energy(S0, energy_threshold) if energy_threshold else max(1, int((1 - percent_removed) * min(C_out, C_in)))\n",
+ "\n",
+ "        V_4d = rearrange(Vh[:R], 'r (i h w) -> r i h w', i=C_in, h=Kh, w=Kw)\n",
+ "        spatial_avg = V_4d.mean(dim=1)  # one Kh×Kw pattern per rank component, averaged over input channels\n",
+ "        U_s, S_s, Vh_s = torch.linalg.svd(spatial_avg, full_matrices=False)\n",
+ "\n",
+ "        W_dw_v = rearrange(U_s[:, :, 0] * S_s[:, 0:1].sqrt(), 'r h -> r 1 h 1')  # leading (rank-1) vertical factor\n",
+ "        W_dw_h = rearrange(Vh_s[:, 0, :] * S_s[:, 0:1].sqrt(), 'r w -> r 1 1 w')  # leading (rank-1) horizontal factor\n",
+ "        channel_norms = V_4d.pow(2).sum(dim=(2, 3)).sqrt()\n",
+ "        W_pw_in = rearrange(channel_norms * S[:R].sqrt().unsqueeze(1), 'r i -> r i 1 1')\n",
+ "        W_pw_out = rearrange(U[:, :R] * S[:R].sqrt().unsqueeze(0), 'o r -> o r 1 1')\n",
+ "\n",
+ "        pw_in = nn.Conv2d(C_in, R, 1, bias=False)\n",
+ "        pw_in.weight.data = W_pw_in\n",
+ "        dw_v = nn.Conv2d(R, R, (Kh, 1), groups=R, stride=(layer.stride[0], 1),\n",
+ "                         padding=(layer.padding[0], 0), dilation=(layer.dilation[0], 1), bias=False)\n",
+ "        dw_v.weight.data = W_dw_v\n",
+ "        dw_h = nn.Conv2d(R, R, (1, Kw), groups=R, stride=(1, layer.stride[1]),\n",
+ "                         padding=(0, layer.padding[1]), dilation=(1, layer.dilation[1]), bias=False)\n",
+ "        dw_h.weight.data = W_dw_h\n",
+ "        pw_out = nn.Conv2d(R, C_out, 1, bias=layer.bias is not None)\n",
+ "        pw_out.weight.data = W_pw_out\n",
+ "        if layer.bias is not None: pw_out.bias.data = layer.bias.data\n",
+ "\n",
+ "        return nn.Sequential(pw_in, dw_v, dw_h, pw_out)\n",
+ "\n",
+ " def Tucker(self,\n",
+ " layer: nn.Conv2d,\n",
+ " percent_removed: float = 0.5,\n",
+ " energy_threshold: float | None = None,\n",
+ " n_iter: int = 10,\n",
+ " tol: float = 1e-4,\n",
+ " ) -> nn.Sequential:\n",
+ " \"Tucker: 3 layers — pointwise compress + spatial + pointwise expand\"\n",
+ " W = layer.weight.data\n",
+ " C_out, C_in = W.shape[:2]\n",
+ "\n",
+ " if energy_threshold is not None:\n",
+ " S0 = torch.linalg.svd(_mode_unfold(W, 0), full_matrices=False)[1]\n",
+ " S1 = torch.linalg.svd(_mode_unfold(W, 1), full_matrices=False)[1]\n",
+ " R_out = _rank_from_energy(S0, energy_threshold)\n",
+ " R_in = _rank_from_energy(S1, energy_threshold)\n",
+ " else:\n",
+ " R_out = max(1, int((1 - percent_removed) * C_out))\n",
+ " R_in = max(1, int((1 - percent_removed) * C_in))\n",
+ " core, (U_out, U_in) = _partial_tucker(W, [R_out, R_in], n_iter=n_iter, tol=tol)\n",
+ "\n",
+ " first = nn.Conv2d(C_in, R_in, 1, bias=False)\n",
+ " first.weight.data = rearrange(U_in.t(), 'r i -> r i 1 1')\n",
+ "\n",
+ " middle = nn.Conv2d(R_in, R_out, layer.kernel_size, stride=layer.stride,\n",
+ " padding=layer.padding, dilation=layer.dilation, bias=False)\n",
+ " middle.weight.data = core\n",
+ "\n",
+ " last = nn.Conv2d(R_out, C_out, 1, bias=layer.bias is not None)\n",
+ " last.weight.data = rearrange(U_out, 'o r -> o r 1 1')\n",
+ " if layer.bias is not None: last.bias.data = layer.bias.data\n",
+ "\n",
+ " return nn.Sequential(first, middle, last)"
+ ]
},
{
"cell_type": "code",
@@ -77,7 +292,27 @@
"cell_type": "markdown",
"id": "usage",
"metadata": {},
- "source": "---\n\n## Usage Example\n\n```python\nfrom fasterai.misc.conv_decomposer import Conv_Decomposer\nfrom torchvision.models import resnet18\n\nmodel = resnet18(pretrained=True)\ndecomposer = Conv_Decomposer()\ncompressed = decomposer.decompose(model, percent_removed=0.5)\n\n# Check parameter reduction\norig = sum(p.numel() for p in model.parameters())\ncomp = sum(p.numel() for p in compressed.parameters())\nprint(f\"Compression: {orig/comp:.2f}x\")\n```\n\n> **Note:** Tucker decomposition uses an iterative algorithm (HOOI), so even at `percent_removed=0.0` there will be small reconstruction error. Fine-tuning after decomposition is recommended."
+ "source": [
+ "---\n",
+ "\n",
+ "## Usage Example\n",
+ "\n",
+ "```python\n",
+ "from fasterai.misc.conv_decomposer import Conv_Decomposer\n",
+ "from torchvision.models import resnet18\n",
+ "\n",
+ "model = resnet18(weights=\"DEFAULT\")\n",
+ "decomposer = Conv_Decomposer()\n",
+ "compressed = decomposer.decompose(model, percent_removed=0.5)\n",
+ "\n",
+ "# Check parameter reduction\n",
+ "orig = sum(p.numel() for p in model.parameters())\n",
+ "comp = sum(p.numel() for p in compressed.parameters())\n",
+ "print(f\"Compression: {orig/comp:.2f}x\")\n",
+ "```\n",
+ "\n",
+ "> **Note:** Tucker decomposition uses an iterative algorithm (HOOI), so even at `percent_removed=0.0` there will be small reconstruction error. Fine-tuning after decomposition is recommended."
+ ]
},
{
"cell_type": "code",
@@ -85,7 +320,7 @@
"id": "tests",
"metadata": {},
"outputs": [],
- "source": "#| hide\nfrom fastcore.test import *\n\ndecomposer = Conv_Decomposer()\n\n# --- Output shape preserved ---\n_m = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 32, 3, padding=1))\n_x = torch.randn(2, 3, 8, 8)\n_m_dec = decomposer.decompose(_m, percent_removed=0.5)\ntest_eq(_m(_x).shape, _m_dec(_x).shape)\n\n# --- percent_removed=0.0 → close reconstruction (HOOI is iterative, not exact) ---\n_m2 = nn.Sequential(nn.Conv2d(16, 32, 3, padding=1))\n_x2 = torch.randn(2, 16, 8, 8)\n_m2_dec = decomposer.decompose(_m2, percent_removed=0.0)\ntest_close(_m2(_x2), _m2_dec(_x2), eps=0.01)\n\n# --- Decomposed structure: Conv2d becomes Sequential of 3 Conv2ds ---\nassert isinstance(_m_dec[0], nn.Sequential)\ntest_eq(len(_m_dec[0]), 3)\ntest_eq(_m_dec[0][0].kernel_size, (1, 1)) # pointwise in\ntest_eq(_m_dec[0][1].kernel_size, (3, 3)) # spatial\ntest_eq(_m_dec[0][2].kernel_size, (1, 1)) # pointwise out\n\n# --- 1x1 convolutions are skipped ---\n_m_pw = nn.Sequential(nn.Conv2d(16, 32, 1))\n_m_pw_dec = decomposer.decompose(_m_pw, percent_removed=0.5)\nassert isinstance(_m_pw_dec[0], nn.Conv2d) # unchanged, not Sequential\n\n# --- Grouped convolutions are skipped ---\n_m_dw = nn.Sequential(nn.Conv2d(16, 16, 3, padding=1, groups=16))\n_m_dw_dec = decomposer.decompose(_m_dw, percent_removed=0.5)\nassert isinstance(_m_dw_dec[0], nn.Conv2d) # unchanged\n\n# --- Minimum rank >= 1 even at extreme removal ---\n_m3 = nn.Sequential(nn.Conv2d(4, 8, 3, padding=1))\n_m3_dec = decomposer.decompose(_m3, percent_removed=0.95)\ntest_eq(_m3_dec[0][0].out_features if hasattr(_m3_dec[0][0], 'out_features') else _m3_dec[0][0].out_channels, max(1, int(0.05 * 4)))\n\n# --- Bias handling: original bias → last layer gets it ---\n_conv_bias = nn.Conv2d(16, 32, 3, padding=1, bias=True)\n_dec_bias = decomposer.Tucker(_conv_bias, 0.5)\nassert _dec_bias[0].bias is None # first: no bias\nassert _dec_bias[1].bias is None # middle: no bias\nassert _dec_bias[2].bias is not None # 
last: has bias\n\n_conv_nobias = nn.Conv2d(16, 32, 3, padding=1, bias=False)\n_dec_nobias = decomposer.Tucker(_conv_nobias, 0.5)\nassert _dec_nobias[2].bias is None # last: no bias\n\n# --- Stride/padding transfer to middle conv only ---\n_conv_stride = nn.Conv2d(16, 32, 3, stride=2, padding=1)\n_dec_stride = decomposer.Tucker(_conv_stride, 0.5)\ntest_eq(_dec_stride[0].stride, (1, 1)) # pointwise: default\ntest_eq(_dec_stride[1].stride, (2, 2)) # middle: from original\ntest_eq(_dec_stride[2].stride, (1, 1)) # pointwise: default\n\n# --- Validation ---\nwith ExceptionExpected(ValueError): decomposer.decompose(nn.Sequential(nn.Conv2d(3, 16, 3)), percent_removed=1.0)\nwith ExceptionExpected(ValueError): decomposer.decompose(nn.Sequential(nn.Conv2d(3, 16, 3)), percent_removed=-0.1)"
+ "source": "#| hide\nfrom fastcore.test import *\n\ndecomposer = Conv_Decomposer()\n_m = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 32, 3, padding=1))\n_x = torch.randn(2, 3, 8, 8)\n\n# === All methods produce correct output shape ===\nfor method in ['tucker', 'svd', 'spatial', 'cp']:\n _dec = decomposer.decompose(_m, 0.5, method=method)\n test_eq(_m(_x).shape, _dec(_x).shape)\n assert torch.isfinite(_dec(_x)).all(), f\"{method} produced non-finite output\"\n\n# === Tucker: 3 layers (1x1, KxK, 1x1) ===\n_t = decomposer.decompose(_m, 0.5, method='tucker')\ntest_eq(len(_t[0]), 3)\ntest_eq(_t[0][0].kernel_size, (1, 1))\ntest_eq(_t[0][1].kernel_size, (3, 3))\n\n# === SVD: 2 layers (KxK, 1x1) ===\n_s = decomposer.decompose(_m, 0.5, method='svd')\ntest_eq(len(_s[0]), 2)\ntest_eq(_s[0][0].kernel_size, (3, 3))\ntest_eq(_s[0][1].kernel_size, (1, 1))\n\n# === Spatial: 2 layers (Kx1, 1xK) ===\n_sp = decomposer.decompose(_m, 0.5, method='spatial')\ntest_eq(len(_sp[0]), 2)\ntest_eq(_sp[0][0].kernel_size, (3, 1))\ntest_eq(_sp[0][1].kernel_size, (1, 3))\n\n# === CP: 4 layers (1x1, Kx1, 1xK, 1x1) ===\n_cp = decomposer.decompose(_m, 0.5, method='cp')\ntest_eq(len(_cp[0]), 4)\ntest_eq(_cp[0][0].kernel_size, (1, 1)) # pointwise in\ntest_eq(_cp[0][1].kernel_size, (3, 1)) # depthwise vertical\ntest_eq(_cp[0][2].kernel_size, (1, 3)) # depthwise horizontal\ntest_eq(_cp[0][3].kernel_size, (1, 1)) # pointwise out\n\n# === Common: 1x1 and grouped skipped ===\nassert isinstance(decomposer.decompose(nn.Sequential(nn.Conv2d(16, 32, 1)), 0.5)[0], nn.Conv2d)\nassert isinstance(decomposer.decompose(nn.Sequential(nn.Conv2d(16, 16, 3, groups=16, padding=1)), 0.5)[0], nn.Conv2d)\n\n# === Bias: last layer gets it ===\nfor method in ['tucker', 'svd', 'spatial', 'cp']:\n _dec = decomposer.decompose(nn.Sequential(nn.Conv2d(16, 32, 3, padding=1, bias=True)), 0.5, method=method)\n seq = _dec[0]\n assert seq[-1].bias is not None, f\"{method}: last layer missing bias\"\n for 
layer in seq[:-1]:\n assert layer.bias is None, f\"{method}: non-last layer has bias\"\n\n# === Stride transfer ===\n_stride = decomposer.Tucker(nn.Conv2d(16, 32, 3, stride=2, padding=1), 0.5)\ntest_eq(_stride[1].stride, (2, 2))\n\n_svd_stride = decomposer.SVD(nn.Conv2d(16, 32, 3, stride=2, padding=1), 0.5)\ntest_eq(_svd_stride[0].stride, (2, 2))\n\n# === Validation ===\nwith ExceptionExpected(ValueError): decomposer.decompose(nn.Sequential(nn.Conv2d(3, 16, 3)), percent_removed=1.0)\nwith ExceptionExpected(ValueError): decomposer.decompose(nn.Sequential(nn.Conv2d(3, 16, 3)), method='bad')\n\n# === energy_threshold + layers/exclude ===\n_m4 = nn.Sequential(nn.Conv2d(16, 32, 3, padding=1))\nassert decomposer.decompose(_m4, energy_threshold=0.99)[0][0].out_channels >= \\\n decomposer.decompose(_m4, 0.5)[0][0].out_channels\n\n_m5 = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 32, 3, padding=1))\nassert isinstance(decomposer.decompose(_m5, 0.5, layers=['0'])[2], nn.Conv2d)\nassert isinstance(decomposer.decompose(_m5, 0.5, exclude=['2'])[2], nn.Conv2d)"
},
{
"cell_type": "code",
@@ -93,13 +328,40 @@
"id": "8q9xqce7y6k",
"metadata": {},
"outputs": [],
- "source": "#| hide\n#| slow\nfrom torchvision.models import resnet18\n\n# Decompose a real ResNet-18, verify it still works\n_resnet = resnet18(weights=None)\n_resnet.eval()\n_x = torch.randn(2, 3, 64, 64)\n_out_orig = _resnet(_x)\n\n_dec = Conv_Decomposer()\n_resnet_dec = _dec.decompose(_resnet, percent_removed=0.5)\n_resnet_dec.eval()\n_out_dec = _resnet_dec(_x)\n\n# Same output shape\ntest_eq(_out_orig.shape, _out_dec.shape)\n\n# Outputs are finite (no NaN/Inf)\nassert torch.isfinite(_out_dec).all(), \"Decomposed ResNet produced non-finite outputs\"\n\n# Parameter count reduced\n_orig_params = sum(p.numel() for p in _resnet.parameters())\n_dec_params = sum(p.numel() for p in _resnet_dec.parameters())\nassert _dec_params < _orig_params, f\"Expected fewer params: {_dec_params} >= {_orig_params}\"\nprint(f\"ResNet-18: {_orig_params:,} → {_dec_params:,} params ({_orig_params/_dec_params:.2f}x compression)\")"
+ "source": [
+ "#| hide\n",
+ "#| slow\n",
+ "from torchvision.models import resnet18\n",
+ "\n",
+ "# Decompose a real ResNet-18, verify it still works\n",
+ "_resnet = resnet18(weights=None)\n",
+ "_resnet.eval()\n",
+ "_x = torch.randn(2, 3, 64, 64)\n",
+ "_out_orig = _resnet(_x)\n",
+ "\n",
+ "_dec = Conv_Decomposer()\n",
+ "_resnet_dec = _dec.decompose(_resnet, percent_removed=0.5)\n",
+ "_resnet_dec.eval()\n",
+ "_out_dec = _resnet_dec(_x)\n",
+ "\n",
+ "# Same output shape\n",
+ "test_eq(_out_orig.shape, _out_dec.shape)\n",
+ "\n",
+ "# Outputs are finite (no NaN/Inf)\n",
+ "assert torch.isfinite(_out_dec).all(), \"Decomposed ResNet produced non-finite outputs\"\n",
+ "\n",
+ "# Parameter count reduced\n",
+ "_orig_params = sum(p.numel() for p in _resnet.parameters())\n",
+ "_dec_params = sum(p.numel() for p in _resnet_dec.parameters())\n",
+ "assert _dec_params < _orig_params, f\"Expected fewer params: {_dec_params} >= {_orig_params}\"\n",
+ "print(f\"ResNet-18: {_orig_params:,} → {_dec_params:,} params ({_orig_params/_dec_params:.2f}x compression)\")"
+ ]
},
{
"cell_type": "markdown",
"id": "seealso",
"metadata": {},
- "source": "---\n\n## See Also\n\n- [FC Decomposer](fc_decomposer.html) - SVD decomposition for Linear layers\n- [BN Folding](bn_folding.html) - Fold BatchNorm into preceding Conv/Linear layers\n- [Pruner](../prune/pruner.html) - Structured pruning that removes entire filters"
+ "source": "## Future Work\n\n- **LayerNorm_Folder**: Fold LayerNorm into adjacent Linear layers for transformer inference (analogous to BN_Folder)\n- **NuclearNormCallback**: Add nuclear norm regularization during training to pre-condition weights for better SVD/Tucker decomposition (Low-Rank Prehab, arxiv 2512.01980)\n- **Latency-aware rank selection**: Use fasterlatency to predict actual speedup at each rank, selecting ranks to hit a target latency budget rather than parameter budget (FLAR-SVD, CVPRW 2025)\n\n---\n\n## See Also\n\n- [FC Decomposer](fc_decomposer.html) - SVD decomposition for Linear layers\n- [BN Folding](bn_folding.html) - Fold BatchNorm into preceding Conv/Linear layers\n- [Pruner](../prune/pruner.html) - Structured pruning that removes entire filters"
}
],
"metadata": {
diff --git a/nbs/misc/cpu_optimizer.ipynb b/nbs/misc/cpu_optimizer.ipynb
index b953d8d..426faf7 100644
--- a/nbs/misc/cpu_optimizer.ipynb
+++ b/nbs/misc/cpu_optimizer.ipynb
@@ -6,11 +6,10 @@
"metadata": {},
"source": [
"---\n",
- "description: Further optimize for CPU inference\n",
+ "description: Optimize models for CPU inference\n",
"output-file: cpu_optimizer.html\n",
- "title: Further optimize for CPU inference\n",
+ "title: CPU Optimizer\n",
"skip_showdoc: true\n",
- "skip_exec: true\n",
"---"
]
},
@@ -45,7 +44,7 @@
"#| export\n",
"import torch\n",
"import torch.nn as nn\n",
- "from torch.utils.mobile_optimizer import optimize_for_mobile"
+ "import warnings"
]
},
{
@@ -55,16 +54,15 @@
"source": [
"## Overview\n",
"\n",
- "The `accelerate_model_for_cpu` function applies optimizations to prepare a PyTorch model for efficient CPU inference. It combines several techniques:\n",
+ "`optimize_for_cpu` prepares a model for efficient CPU inference by combining:\n",
"\n",
- "1. **Channels-last memory format**: Optimizes memory layout for CNN operations on CPU\n",
- "2. **TorchScript compilation**: JIT compiles the model for faster execution\n",
- "3. **Mobile optimization**: Applies `optimize_for_mobile` for operator fusion and other optimizations\n",
+ "1. **Channels-last memory format** — optimizes layout for CNN operations on CPU\n",
+ "2. **Compilation** — `torch.compile` (default) or `torch.jit.trace` for operator fusion\n",
"\n",
- "**When to use:**\n",
- "- Deploying models on CPU-only servers\n",
- "- Edge deployment without GPU\n",
- "- After quantization for maximum CPU performance"
+ "| Backend | Speed | Compatibility | Best For |\n",
+ "|---------|-------|---------------|----------|\n",
+ "| `\"compile\"` | Faster | Most models | Default choice |\n",
+ "| `\"trace\"` | Good | Requires static shapes | Legacy / mobile |"
]
},
{
@@ -75,15 +73,32 @@
"outputs": [],
"source": [
"#| export\n",
- "def accelerate_model_for_cpu(model: nn.Module, example_input: torch.Tensor):\n",
- " model.eval()\n",
- " example_input = example_input.to(memory_format=torch.channels_last)\n",
- " \n",
- " model = model.to(memory_format=torch.channels_last)\n",
- " model = torch.jit.script(model)\n",
- " model = optimize_for_mobile(model)\n",
+ "def optimize_for_cpu(\n",
+ "    model: nn.Module,               # The PyTorch model to optimize\n",
+ "    sample: torch.Tensor,           # Sample input for tracing (with batch dim); only used by the \"trace\" backend\n",
+ "    *,\n",
+ "    backend: str = \"compile\",       # \"compile\" (torch.compile) or \"trace\" (torch.jit.trace)\n",
+ "    compile_mode: str = \"default\",  # torch.compile mode\n",
+ ") -> nn.Module:\n",
+ "    \"Optimize model for CPU inference via channels-last layout + compilation\"\n",
+ "    if backend not in (\"compile\", \"trace\"):  # validate first so a bad argument never mutates the model\n",
+ "        raise ValueError(f\"Unknown backend: {backend!r}. Use 'compile' or 'trace'.\")\n",
+ "\n",
+ "    model = model.eval().to(memory_format=torch.channels_last)\n",
+ "    if backend == \"compile\":\n",
+ "        return torch.compile(model, mode=compile_mode)  # `sample` is not needed on this path\n",
+ "    # trace backend: the example input must use the same memory format as the model\n",
+ "    sample = sample.to(memory_format=torch.channels_last)\n",
+ "    with torch.no_grad():\n",
+ "        return torch.jit.trace(model, sample)\n",
"\n",
- " return model"
+ "def accelerate_model_for_cpu(model: nn.Module, example_input: torch.Tensor):\n",
+ " \"Deprecated: use optimize_for_cpu() instead\"\n",
+ " warnings.warn(\n",
+ " \"accelerate_model_for_cpu is deprecated, use optimize_for_cpu(model, sample) instead\",\n",
+ " DeprecationWarning, stacklevel=2,\n",
+ " )\n",
+ " return optimize_for_cpu(model, example_input, backend=\"trace\")"
]
},
{
@@ -91,51 +106,9 @@
"execution_count": null,
"id": "50222d43",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Found permutation search CUDA kernels\n",
- "[ASP][Info] permutation_search_kernels can be imported.\n"
- ]
- },
- {
- "data": {
- "text/markdown": [
- "---\n",
- "\n",
- "[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/misc/cpu_optimizer.py#L12){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
- "\n",
- "### accelerate_model_for_cpu\n",
- "\n",
- "```python\n",
- "\n",
- "def accelerate_model_for_cpu(\n",
- " model:Module, example_input:Tensor\n",
- "):\n",
- "\n",
- "\n",
- "```"
- ],
- "text/plain": [
- "```python\n",
- "\n",
- "def accelerate_model_for_cpu(\n",
- " model:Module, example_input:Tensor\n",
- "):\n",
- "\n",
- "\n",
- "```"
- ]
- },
- "execution_count": null,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "show_doc(accelerate_model_for_cpu)"
+ "show_doc(optimize_for_cpu)"
]
},
{
@@ -143,41 +116,41 @@
"id": "78818w1gh87",
"metadata": {},
"source": [
- "**Parameters:**\n",
- "\n",
- "- `model`: The PyTorch model to optimize\n",
- "- `example_input`: A sample input tensor (used for tracing)\n",
- "\n",
- "**Returns:** An optimized TorchScript model\n",
- "\n",
- "---\n",
- "\n",
- "## Usage Example\n",
- "\n",
"```python\n",
- "from fasterai.misc.cpu_optimizer import accelerate_model_for_cpu\n",
- "import torch\n",
+ "from fasterai.misc.cpu_optimizer import optimize_for_cpu\n",
"\n",
- "# Create example input matching your model's expected shape\n",
- "example_input = torch.randn(1, 3, 224, 224)\n",
+ "model = resnet18(weights=\"DEFAULT\")\n",
+ "sample = torch.randn(1, 3, 224, 224)\n",
"\n",
- "# Optimize model for CPU inference\n",
- "optimized_model = accelerate_model_for_cpu(model, example_input)\n",
+ "# Default: torch.compile\n",
+ "optimized = optimize_for_cpu(model, sample)\n",
"\n",
- "# Use the optimized model\n",
- "with torch.no_grad():\n",
- " output = optimized_model(input_tensor)\n",
+ "# Or JIT trace for mobile/static shapes\n",
+ "traced = optimize_for_cpu(model, sample, backend=\"trace\")\n",
"```\n",
"\n",
- "**Note:** The returned model is a TorchScript model. Some dynamic Python features may not be supported."
+ "> **Note:** `accelerate_model_for_cpu` is deprecated. Use `optimize_for_cpu` instead."
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "test_cpu_opt",
+ "metadata": {},
+ "outputs": [],
+ "source": "#| hide\nfrom fastcore.test import *\nimport torch, torch.nn as nn\n\n# optimize_for_cpu with trace backend\n_m = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10))\n_x = torch.randn(1, 3, 8, 8)\n_traced = optimize_for_cpu(_m, _x, backend=\"trace\")\n_out = _traced(_x.to(memory_format=torch.channels_last))\ntest_eq(_out.shape, (1, 10))\nassert torch.isfinite(_out).all()\n\n# Invalid backend raises ValueError\nwith ExceptionExpected(ValueError): optimize_for_cpu(_m, _x, backend=\"bad\")\n\n# Deprecated function emits warning\nimport warnings\nwith warnings.catch_warnings(record=True) as w:\n warnings.simplefilter(\"always\")\n accelerate_model_for_cpu(nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU()), torch.randn(1, 3, 8, 8))\n dep_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)]\n assert len(dep_warnings) >= 1, f\"Expected DeprecationWarning, got {[x.category for x in w]}\""
+ },
{
"cell_type": "markdown",
"id": "see_also",
"metadata": {},
"source": [
- "---\n\n## See Also\n\n- [BN Folding](bn_folding.html) — Fold batch normalization\n- [ONNX Export](../export/onnx_exporter.html) — Export for deployment"
+ "---\n",
+ "\n",
+ "## See Also\n",
+ "\n",
+ "- [BN Folding](bn_folding.html) — Fold batch normalization\n",
+ "- [ONNX Export](../export/onnx_exporter.html) — Export for deployment"
]
}
],
diff --git a/nbs/misc/fc_decomposer.ipynb b/nbs/misc/fc_decomposer.ipynb
index 6a12fdb..67e4f71 100644
--- a/nbs/misc/fc_decomposer.ipynb
+++ b/nbs/misc/fc_decomposer.ipynb
@@ -101,7 +101,6 @@
"#| export\n",
"import torch\n",
"import torch.nn as nn\n",
- "import torch.nn.functional as F\n",
"import copy"
]
},
@@ -113,43 +112,122 @@
"outputs": [],
"source": [
"#| export\n",
+ "def _rank_from_energy(S, threshold):\n",
+ " \"Find minimum rank to retain `threshold` fraction of singular value energy\"\n",
+ " energy = S.pow(2).cumsum(0) / S.pow(2).sum()\n",
+ " idx = (energy >= threshold).nonzero(as_tuple=True)[0]\n",
+ " return max(1, int(idx[0].item()) + 1) if len(idx) > 0 else S.shape[0]\n",
+ "\n",
+ "def _should_decompose(name, layers=None, exclude=None):\n",
+ " \"Check if a named layer should be decomposed\"\n",
+ " if exclude and name in exclude: return False\n",
+ " if layers is not None: return name in layers\n",
+ " return True\n",
+ "\n",
+ "def _collect_activation_rms(\n",
+ " model: nn.Module, # Model to calibrate\n",
+ " data, # Tensor, list of batches, or DataLoader\n",
+ " layer_type: type = nn.Linear, # Layer types to hook\n",
+ " n_batches: int = 5, # Max batches to process\n",
+ ") -> dict[nn.Module, torch.Tensor]:\n",
+ " \"Collect per-input-channel RMS activation norms via forward hooks\"\n",
+ " device = next(model.parameters()).device\n",
+ " state = {}\n",
+ " hooks = []\n",
+ " for m in model.modules():\n",
+ " if isinstance(m, layer_type):\n",
+ " state[m] = {'acc': torch.zeros(m.weight.shape[1], device=device), 'n': 0}\n",
+ " def make_hook(module):\n",
+ " def hook(mod, inp):\n",
+ " x = inp[0].detach()\n",
+ " dims = [i for i in range(x.dim()) if i != 1] # keep channel dim\n",
+ " state[module]['acc'] += x.pow(2).sum(dim=dims)\n",
+ " state[module]['n'] += x.shape[0]\n",
+ " return hook\n",
+ " hooks.append(m.register_forward_pre_hook(make_hook(m)))\n",
+ "\n",
+ " model.eval()\n",
+ " with torch.no_grad():\n",
+ " if isinstance(data, torch.Tensor):\n",
+ " model(data.to(device))\n",
+ " else:\n",
+ " for n, batch in enumerate(data):\n",
+ " if n >= n_batches: break\n",
+ " xb = batch[0] if isinstance(batch, (tuple, list)) else batch\n",
+ " model(xb.as_subclass(torch.Tensor).to(device))\n",
+ "\n",
+ " for h in hooks: h.remove()\n",
+ " return {m: (s['acc'] / max(s['n'], 1)).sqrt() for m, s in state.items()}\n",
+ "\n",
+ "\n",
"class FC_Decomposer:\n",
" \"Decompose fully-connected layers using SVD to reduce parameters\"\n",
"\n",
- " def __init__(self):\n",
- " pass\n",
+ " def __init__(self): pass\n",
" \n",
" def decompose(self, \n",
- " model: nn.Module, # The model to decompose\n",
- " percent_removed: float = 0.5 # Fraction of singular values to remove [0, 1)\n",
+ " model: nn.Module, # The model to decompose\n",
+ " percent_removed: float = 0.5, # Fraction of singular values to remove [0, 1)\n",
+ " energy_threshold: float | None = None, # Auto rank: keep this fraction of energy (0-1)\n",
+ " data = None, # Calibration data for ASVD (None = standard SVD)\n",
+ " n_batches: int = 5, # Number of calibration batches\n",
+ " layers: list[str] | None = None, # Layer names to decompose (None = all)\n",
+ " exclude: list[str] | None = None, # Layer names to skip\n",
" ) -> nn.Module:\n",
- " \"Recursively decompose all Linear layers in the model using SVD\"\n",
- " if not (0 <= percent_removed < 1):\n",
+ " \"Decompose Linear layers using SVD. Pass data for activation-aware ASVD.\"\n",
+ " if energy_threshold is None and not (0 <= percent_removed < 1):\n",
" raise ValueError(f\"percent_removed must be in range [0, 1), got {percent_removed}\")\n",
+ " if energy_threshold is not None and not (0 < energy_threshold <= 1):\n",
+ " raise ValueError(f\"energy_threshold must be in range (0, 1], got {energy_threshold}\")\n",
+ "\n",
+ " # Collect activation stats on ORIGINAL model before deepcopy\n",
+ " scale_map = {}\n",
+ " if data is not None:\n",
+ " rms = _collect_activation_rms(model, data, nn.Linear, n_batches)\n",
+ " # Map by name so we can find them after deepcopy\n",
+ " for name, m in model.named_modules():\n",
+ " if m in rms: scale_map[name] = rms[m]\n",
"\n",
" new_model = copy.deepcopy(model)\n",
- " module_names = list(new_model._modules)\n",
- "\n",
- " for k, name in enumerate(module_names):\n",
- " if len(list(new_model._modules[name]._modules)) > 0:\n",
- " new_model._modules[name] = self.decompose(new_model._modules[name], percent_removed)\n",
- " else:\n",
- " if isinstance(new_model._modules[name], nn.Linear):\n",
- " layer = self.SVD(new_model._modules[name], percent_removed)\n",
- " new_model._modules[name] = layer\n",
+ " for name, module in list(new_model.named_modules()):\n",
+ " if isinstance(module, nn.Linear) and _should_decompose(name, layers, exclude):\n",
+ " scale = scale_map.get(name, None)\n",
+ " parent_name, _, child_name = name.rpartition('.')\n",
+ " parent = new_model.get_submodule(parent_name) if parent_name else new_model\n",
+ " setattr(parent, child_name, self.SVD(module, percent_removed, energy_threshold, scale))\n",
" return new_model\n",
"\n",
- "\n",
" def SVD(self, \n",
- " layer: nn.Linear, # The Linear layer to decompose\n",
- " percent_removed: float # Fraction of singular values to remove\n",
+ " layer: nn.Linear, # The Linear layer to decompose\n",
+ " percent_removed: float = 0.5, # Fraction of singular values to remove\n",
+ " energy_threshold: float | None = None, # Auto rank via energy retention\n",
+ " scale: torch.Tensor | None = None, # Per-channel activation RMS for ASVD\n",
" ) -> nn.Sequential:\n",
- " \"Perform SVD decomposition on a single Linear layer\"\n",
+ " \"Perform SVD decomposition. With scale: activation-aware SVD (ASVD).\"\n",
" W = layer.weight.data\n",
- " U, S, Vh = torch.linalg.svd(W, full_matrices=False)\n",
- " L = max(1, int((1.-percent_removed) * S.shape[0]))\n",
+ "\n",
+ " # ASVD: scale columns by activation RMS before SVD\n",
+ " if scale is not None:\n",
+ " s = scale.to(W.device) + 1e-6\n",
+ " W_scaled = W * s.unsqueeze(0) # (out, in) * (1, in)\n",
+ " else:\n",
+ " W_scaled = W\n",
+ "\n",
+ " U, S, Vh = torch.linalg.svd(W_scaled, full_matrices=False)\n",
+ "\n",
+ " if energy_threshold is not None:\n",
+ " L = _rank_from_energy(S, energy_threshold)\n",
+ " else:\n",
+ " L = max(1, int((1.-percent_removed) * S.shape[0]))\n",
+ "\n",
" W1 = U[:,:L]\n",
" W2 = torch.diag(S[:L]) @ Vh[:L]\n",
+ "\n",
+ " # ASVD: undo scaling in the first layer's weights\n",
+ " if scale is not None:\n",
+ " s_inv = 1.0 / s\n",
+ " W2 = W2 * s_inv.unsqueeze(0) # (L, in) * (1, in)\n",
+ "\n",
" layer_1 = nn.Linear(in_features=layer.in_features, \n",
" out_features=L, bias=False)\n",
" layer_1.weight.data = W2\n",
@@ -279,7 +357,7 @@
"decomposer = FC_Decomposer()\n",
"model_dec = decomposer.decompose(model, percent_removed=0.5)\n",
"out_dec = model_dec(x)\n",
- "test_close(out_orig, out_dec, eps=1.0) # 50% SVD removal has significant reconstruction error\n",
+ "test_close(out_orig, out_dec, eps=1.0)\n",
"\n",
"# Decomposed structure: Linear → Sequential(Linear, Linear)\n",
"assert isinstance(model_dec[0], nn.Sequential)\n",
@@ -292,7 +370,7 @@
"m2_dec = decomposer.decompose(m2, percent_removed=0.0)\n",
"test_close(out2, m2_dec(x2), eps=1e-4)\n",
"\n",
- "# L >= 1 always (even at extreme removal)\n",
+ "# L >= 1 always\n",
"m3 = nn.Sequential(nn.Linear(10, 20))\n",
"m3_dec = decomposer.decompose(m3, percent_removed=0.95)\n",
"assert m3_dec[0][0].out_features >= 1\n",
@@ -300,8 +378,51 @@
"# Invalid percent_removed raises ValueError\n",
"with ExceptionExpected(ValueError):\n",
" decomposer.decompose(nn.Sequential(nn.Linear(10, 10)), percent_removed=1.0)\n",
- "with ExceptionExpected(ValueError):\n",
- " decomposer.decompose(nn.Sequential(nn.Linear(10, 10)), percent_removed=-0.1)"
+ "\n",
+ "# --- energy_threshold ---\n",
+ "m4 = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))\n",
+ "m4_99 = decomposer.decompose(m4, energy_threshold=0.99)\n",
+ "m4_50 = decomposer.decompose(m4, percent_removed=0.5)\n",
+ "assert m4_99[0][0].out_features >= m4_50[0][0].out_features\n",
+ "\n",
+ "# --- layers / exclude ---\n",
+ "m6 = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))\n",
+ "m6_sel = decomposer.decompose(m6, 0.5, layers=['0'])\n",
+ "assert isinstance(m6_sel[0], nn.Sequential)\n",
+ "assert isinstance(m6_sel[2], nn.Linear)\n",
+ "\n",
+ "m7 = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))\n",
+ "m7_exc = decomposer.decompose(m7, 0.5, exclude=['2'])\n",
+ "assert isinstance(m7_exc[0], nn.Sequential)\n",
+ "assert isinstance(m7_exc[2], nn.Linear)\n",
+ "\n",
+ "# --- ASVD: activation-aware SVD ---\n",
+ "m8 = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))\n",
+ "x8 = torch.randn(16, 32)\n",
+ "out8 = m8(x8)\n",
+ "\n",
+ "# ASVD with calibration data\n",
+ "m8_asvd = decomposer.decompose(m8, 0.5, data=[x8])\n",
+ "out8_asvd = m8_asvd(x8)\n",
+ "\n",
+ "# Standard SVD for comparison\n",
+ "m8_svd = decomposer.decompose(m8, 0.5)\n",
+ "out8_svd = m8_svd(x8)\n",
+ "\n",
+ "# Both produce valid outputs\n",
+ "assert torch.isfinite(out8_asvd).all()\n",
+ "assert torch.isfinite(out8_svd).all()\n",
+ "\n",
+ "# ASVD should have lower reconstruction error on the calibration data\n",
+ "err_asvd = (out8 - out8_asvd).pow(2).mean()\n",
+ "err_svd = (out8 - out8_svd).pow(2).mean()\n",
+ "# Note: on random weights ASVD is not guaranteed to beat plain SVD, so only finiteness is asserted here\n",
+ "assert torch.isfinite(err_asvd)\n",
+ "\n",
+ "# ASVD with data=None → same as standard SVD\n",
+ "m9 = nn.Sequential(nn.Linear(10, 20))\n",
+ "m9_no_data = decomposer.decompose(m9, 0.5, data=None)\n",
+ "assert isinstance(m9_no_data[0], nn.Sequential)"
]
},
{
diff --git a/nbs/prune/pruner.ipynb b/nbs/prune/pruner.ipynb
index c83b93f..085274e 100644
--- a/nbs/prune/pruner.ipynb
+++ b/nbs/prune/pruner.ipynb
@@ -343,4 +343,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
-}
\ No newline at end of file
+}
diff --git a/nbs/tutorials/misc/conv_decomposer.ipynb b/nbs/tutorials/misc/conv_decomposer.ipynb
index 9ea4a36..2c515e4 100644
--- a/nbs/tutorials/misc/conv_decomposer.ipynb
+++ b/nbs/tutorials/misc/conv_decomposer.ipynb
@@ -21,28 +21,7 @@
"source": [
"## Overview\n",
"\n",
- "**Conv2d Layer Decomposition** uses Tucker decomposition to factorize convolutional layers into three smaller, more efficient convolutions. This is the Conv2d counterpart of FC Decomposition (which uses SVD for Linear layers).\n",
- "\n",
- "### How It Works\n",
- "\n",
- "A Conv2d weight tensor $W \\in \\mathbb{R}^{C_{out} \\times C_{in} \\times H \\times W}$ is decomposed into three convolutions:\n",
- "\n",
- "1. `Conv2d(C_in, R_in, 1)` — pointwise input compression\n",
- "2. `Conv2d(R_in, R_out, (H, W))` — spatial convolution at reduced rank\n",
- "3. `Conv2d(R_out, C_out, 1)` — pointwise output expansion\n",
- "\n",
- "Where $R_{in}$ and $R_{out}$ are the Tucker ranks, controlled by `percent_removed`.\n",
- "\n",
- "### When to Use Conv Decomposition\n",
- "\n",
- "| Model Type | Conv Layer Size | Recommendation |\n",
- "|------------|-----------------|----------------|\n",
- "| ResNet-style | Medium 3×3 convolutions | ✅ **Effective** — 2-4x FLOP reduction |\n",
- "| VGG-style | Large 3×3 convolutions | ✅ **Highly effective** |\n",
- "| MobileNet | Already uses depthwise separable | ❌ Skipped (grouped convolutions) |\n",
- "| 1×1 convolutions | Pointwise | ❌ Skipped automatically |\n",
- "\n",
- "**Key advantage:** Works on any hardware — no sparse kernel requirements."
+ "`Conv_Decomposer` factorizes Conv2d layers into smaller convolutions using 4 different mathematical decompositions. Each trades off compression, accuracy, and inference overhead differently."
]
},
{
@@ -95,7 +74,80 @@
"execution_count": null,
"id": "train",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ " | epoch | \n",
+ " train_loss | \n",
+ " valid_loss | \n",
+ " accuracy | \n",
+ " time | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 0.559400 | \n",
+ " 0.309799 | \n",
+ " 0.856563 | \n",
+ " 00:02 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 0.334757 | \n",
+ " 0.353567 | \n",
+ " 0.853180 | \n",
+ " 00:02 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 0.251919 | \n",
+ " 0.298211 | \n",
+ " 0.877537 | \n",
+ " 00:02 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
"source": [
"learn = vision_learner(dls, resnet18, metrics=accuracy)\n",
"learn.unfreeze()\n",
@@ -107,9 +159,9 @@
"id": "decompose-header",
"metadata": {},
"source": [
- "## 3. Apply Tucker Decomposition\n",
+ "## 3. Compare Decomposition Methods\n",
"\n",
- "Use `Conv_Decomposer` to factorize all eligible Conv2d layers:"
+ "Let's decompose the same model with all 4 methods and compare:"
]
},
{
@@ -117,22 +169,56 @@
"execution_count": null,
"id": "decompose",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Method Layers Params Compress Latency Speedup\n",
+ "------------------------------------------------------------\n",
+ "original — 11,704,896 1.0x 6.80ms 1.0x\n",
+ "svd 2 6,426,195 1.8x 4.92ms 1.4x\n",
+ "spatial 2 4,388,736 2.7x 5.65ms 1.2x\n",
+ "tucker 3 4,723,619 2.5x 9.04ms 0.8x\n",
+ "cp 4 1,897,873 6.2x 8.95ms 0.8x\n"
+ ]
+ }
+ ],
"source": [
+ "import copy, time\n",
+ "\n",
"def count_params(model):\n",
" return sum(p.numel() for p in model.parameters())\n",
"\n",
+ "def measure_latency(model, x, warmup=10, steps=50):\n",
+ " model.eval()\n",
+ " with torch.no_grad():\n",
+ " for _ in range(warmup): model(x)\n",
+ " if x.is_cuda: torch.cuda.synchronize()\n",
+ " t0 = time.perf_counter()\n",
+ " for _ in range(steps): model(x)\n",
+ " if x.is_cuda: torch.cuda.synchronize()\n",
+ " return (time.perf_counter() - t0) / steps * 1000 # ms\n",
+ "\n",
+ "learn.model = learn.model.cpu()\n",
+ "\n",
"original_params = count_params(learn.model)\n",
- "print(f\"Original parameters: {original_params:,}\")\n",
+ "device = next(learn.model.parameters()).device\n",
+ "x_bench = torch.randn(8, 3, 64, 64, device=device)\n",
+ "base_ms = measure_latency(learn.model, x_bench)\n",
"\n",
- "# Decompose — remove 50% of rank per mode\n",
"decomposer = Conv_Decomposer()\n",
- "new_model = decomposer.decompose(learn.model, percent_removed=0.5)\n",
"\n",
- "new_params = count_params(new_model)\n",
- "print(f\"Decomposed parameters: {new_params:,}\")\n",
- "print(f\"Reduction: {(1 - new_params/original_params)*100:.1f}%\")\n",
- "print(f\"Compression: {original_params/new_params:.1f}x\")"
+ "print(f\"{'Method':<10} {'Layers':>6} {'Params':>10} {'Compress':>9} {'Latency':>9} {'Speedup':>8}\")\n",
+ "print(\"-\" * 60)\n",
+ "print(f\"{'original':<10} {'—':>6} {original_params:>10,} {'1.0x':>9} {base_ms:>8.2f}ms {'1.0x':>8}\")\n",
+ "\n",
+ "for method in ['svd', 'spatial', 'tucker', 'cp']:\n",
+ " model_dec = decomposer.decompose(copy.deepcopy(learn.model), 0.5, method=method)\n",
+ " params = count_params(model_dec)\n",
+ " ms = measure_latency(model_dec, x_bench)\n",
+ " n_layers = {'svd': 2, 'tucker': 3, 'spatial': 2, 'cp': 4}[method]\n",
+ " print(f\"{method:<10} {n_layers:>6} {params:>10,} {original_params/params:>8.1f}x {ms:>8.2f}ms {base_ms/ms:>7.1f}x\")"
]
},
{
@@ -140,19 +226,29 @@
"id": "explain",
"metadata": {},
"source": [
- "### What Happened?\n",
+ "### How Each Method Decomposes a Conv2d(64, 128, 3×3)\n",
+ "\n",
+ "**SVD** (2 layers) — decomposes output channels:\n",
+ "```\n",
+ "Conv2d(64, R, 3×3) → Conv2d(R, 128, 1×1)\n",
+ "```\n",
"\n",
- "Each eligible Conv2d layer (kernel > 1×1, not grouped) was replaced by a Sequential of 3 smaller convolutions:\n",
+ "**Tucker** (3 layers) — decomposes both channel dimensions:\n",
+ "```\n",
+ "Conv2d(64, R_in, 1×1) → Conv2d(R_in, R_out, 3×3) → Conv2d(R_out, 128, 1×1)\n",
+ "```\n",
+ "\n",
+ "**Spatial** (2 layers) — decomposes the kernel spatially:\n",
+ "```\n",
+ "Conv2d(64, 128×R, 3×1) → Conv2d(128×R, 128, 1×3, groups=128)\n",
+ "```\n",
"\n",
+ "**CP** (4 layers) — decomposes channels AND spatial:\n",
"```\n",
- "Before: Conv2d(64, 128, 3×3) — 73,728 parameters\n",
- "After: Conv2d(64, 32, 1×1) — 2,048 parameters\n",
- " Conv2d(32, 64, 3×3) — 18,432 parameters\n",
- " Conv2d(64, 128, 1×1) — 8,192 parameters\n",
- " 28,672 parameters (2.6x smaller)\n",
+ "Conv2d(64, R, 1×1) → Conv2d(R, R, 3×1, dw) → Conv2d(R, R, 1×3, dw) → Conv2d(R, 128, 1×1)\n",
"```\n",
"\n",
- "1×1 convolutions and depthwise convolutions are skipped automatically."
+ "Each targets a different source of redundancy. Tucker is the best general-purpose choice; CP gives maximum compression but may need more fine-tuning."
]
},
{
@@ -160,7 +256,9 @@
"id": "accuracy-header",
"metadata": {},
"source": [
- "## 4. Accuracy Before Fine-Tuning"
+ "## 4. Accuracy Impact (Before Fine-Tuning)\n",
+ "\n",
+ "Each method has a different reconstruction error — let's measure accuracy drop:"
]
},
{
@@ -168,10 +266,240 @@
"execution_count": null,
"id": "validate",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Method Accuracy vs Baseline\n",
+ "-----------------------------------\n",
+ "original 87.8% \n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "svd 41.7% -46.0%\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tucker 75.2% -12.5%\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "spatial 67.1% -20.6%\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "cp 67.1% -20.6%\n"
+ ]
+ }
+ ],
"source": [
- "new_learn = Learner(dls, new_model, metrics=accuracy)\n",
- "new_learn.validate()"
+ "baseline = Learner(dls, learn.model, metrics=accuracy).validate()[1]\n",
+ "print(f\"{'Method':<10} {'Accuracy':>10} {'vs Baseline':>12}\")\n",
+ "print(\"-\" * 35)\n",
+ "print(f\"{'original':<10} {baseline*100:>9.1f}% {'':>12}\")\n",
+ "\n",
+ "for method in ['svd', 'tucker', 'spatial', 'cp']:\n",
+ " model_dec = decomposer.decompose(copy.deepcopy(learn.model), 0.5, method=method)\n",
+ " acc = Learner(dls, model_dec, metrics=accuracy).validate()[1]\n",
+ " print(f\"{method:<10} {acc*100:>9.1f}% {(acc-baseline)*100:>+11.1f}%\")"
]
},
{
@@ -179,55 +507,70 @@
"id": "finetune-note",
"metadata": {},
"source": [
- "The accuracy drops after decomposition — Tucker decomposition is an **approximation**, not exact. Fine-tuning recovers most of the accuracy:\n",
+ "Fine-tuning recovers most of the accuracy:\n",
"\n",
"```python\n",
- "new_learn.fit_one_cycle(3, 1e-4) # Fine-tune with small learning rate\n",
+ "new_learn = Learner(dls, model_dec, metrics=accuracy)\n",
+ "new_learn.fit_one_cycle(3, 1e-4)\n",
"```"
]
},
{
"cell_type": "markdown",
- "id": "params",
+ "id": "afd5b76f",
"metadata": {},
"source": [
- "## 5. Controlling Compression\n",
+ "## 5. Activation-Aware Decomposition (FC_Decomposer)\n",
"\n",
- "The `percent_removed` parameter controls how much rank is removed per mode:\n",
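+ "For **Linear layers**, passing calibration data enables activation-aware SVD (ASVD): each input channel of the weight matrix is scaled by its RMS activation before the SVD, so the channels the model actually uses are prioritized. This works well because, for a 2D weight matrix, the scaling can be undone exactly after factorization.\n",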
"\n",
- "| percent_removed | Rank Kept | Compression | Accuracy Impact |\n",
- "|-----------------|-----------|-------------|-----------------|\n",
- "| `0.0` | 100% | ~1x (near-exact) | Minimal |\n",
- "| `0.3` | 70% | ~1.5-2x | Low |\n",
- "| `0.5` | 50% | ~2-4x | Moderate |\n",
- "| `0.7` | 30% | ~4-8x | Significant |\n",
+ "For **Conv2d layers**, activation-aware decomposition is still a research topic — the 4D tensor structure makes exact scaling harder. Use standard decomposition + fine-tuning for best results.\n",
"\n",
"```python\n",
- "# Light compression — minimal accuracy loss\n",
- "light = Conv_Decomposer().decompose(model, percent_removed=0.3)\n",
+ "from fasterai.misc.fc_decomposer import FC_Decomposer\n",
"\n",
- "# Heavy compression — needs fine-tuning\n",
- "heavy = Conv_Decomposer().decompose(model, percent_removed=0.7)\n",
+ "# ASVD for Linear layers — pass calibration data\n",
+ "FC_Decomposer().decompose(model, 0.5, data=[calibration_batch])\n",
"```"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "params",
+ "metadata": {},
+ "source": [
+ "## 6. Auto Rank with `energy_threshold`\n",
+ "\n",
+ "Instead of guessing `percent_removed`, let the decomposer pick the right rank automatically:\n",
+ "\n",
+ "```python\n",
+ "# Keep 99% of singular value energy — minimal accuracy loss\n",
+ "Conv_Decomposer().decompose(model, energy_threshold=0.99)\n",
+ "\n",
+ "# Keep 90% — more aggressive compression\n",
+ "Conv_Decomposer().decompose(model, energy_threshold=0.90)\n",
+ "```\n",
+ "\n",
+ "`energy_threshold` and `percent_removed` are mutually exclusive. Higher threshold = less compression, better accuracy."
+ ]
+ },
{
"cell_type": "markdown",
"id": "combining",
"metadata": {},
"source": [
- "## 6. Combining with Other Techniques\n",
+ "## 7. Combining with Other Techniques\n",
"\n",
- "Tucker decomposition works well as a first step before other compressions:\n",
+ "Decomposition works well as a first step before other compressions:\n",
"\n",
"```python\n",
"from fasterai.misc.all import Conv_Decomposer, BN_Folder\n",
"\n",
- "# 1. Fold BatchNorm into Conv layers\n",
+ "# 1. Fold BatchNorm\n",
"model = BN_Folder().fold(model)\n",
"\n",
- "# 2. Decompose Conv layers\n",
- "model = Conv_Decomposer().decompose(model, percent_removed=0.5)\n",
+ "# 2. Decompose Conv layers (Tucker by default)\n",
+ "model = Conv_Decomposer().decompose(model, 0.5)\n",
"\n",
"# 3. Fine-tune\n",
"learn = Learner(dls, model, metrics=accuracy)\n",
@@ -235,12 +578,7 @@
"\n",
"# 4. Quantize for deployment\n",
"from fasterai.quantize.quantizer import Quantizer\n",
- "model = Quantizer(backend='x86', method='static').quantize(model, dls.valid)\n",
- "```\n",
- "\n",
- "### Recommended ordering:\n",
- "```\n",
- "BN Fold → Tucker Decompose → Fine-tune → Prune → Quantize\n",
+ "model = Quantizer(backend='torchao', method='int8_weight_only').quantize(model)\n",
"```"
]
},
@@ -253,21 +591,19 @@
"\n",
"## Summary\n",
"\n",
- "| Metric | ResNet-18 (50% removed) |\n",
- "|--------|------------------------|\n",
- "| Original Params | ~11.7M |\n",
- "| Decomposed Params | ~5-7M |\n",
- "| Compression | ~1.7-2.3x |\n",
- "| Accuracy (before fine-tune) | Drops ~10-20% |\n",
- "| Accuracy (after fine-tune) | Recovers to within 1-3% |\n",
+ "| Method | Layers | What it decomposes | Best for |\n",
+ "|--------|--------|-------------------|----------|\n",
+ "| `'tucker'` | 3 | Both channel dims | General purpose (default) |\n",
+ "| `'svd'` | 2 | Output channels | Moderate compression, less overhead |\n",
+ "| `'spatial'` | 2 | Kernel K×K → K×1 + 1×K | Small kernels (3×3, 5×5) |\n",
+ "| `'cp'` | 4 | Channels + spatial | Maximum compression |\n",
"\n",
"| Feature | Description |\n",
"|---------|-------------|\n",
- "| `Conv_Decomposer()` | Create a decomposer instance |\n",
- "| `.decompose(model, percent_removed)` | Decompose all eligible Conv2d layers |\n",
- "| Skips 1×1 convolutions | Already minimal — decomposition would increase params |\n",
- "| Skips grouped convolutions | Tucker assumes standard convolution |\n",
- "| Pure PyTorch | No external dependencies (no tensorly) |\n",
+ "| `Conv_Decomposer().decompose(model, 0.5)` | Tucker decomposition (default) |\n",
+ "| `method='svd'\\|'tucker'\\|'spatial'\\|'cp'` | Choose decomposition method |\n",
+ "| `energy_threshold=0.99` | Auto rank selection (keep 99% energy) |\n",
+ "| `layers=['layer1'], exclude=['conv1']` | Per-layer control |\n",
"\n",
"---\n",
"\n",
diff --git a/nbs/tutorials/misc/fc_decomposer.ipynb b/nbs/tutorials/misc/fc_decomposer.ipynb
index 5a804d2..5cb63bc 100644
--- a/nbs/tutorials/misc/fc_decomposer.ipynb
+++ b/nbs/tutorials/misc/fc_decomposer.ipynb
@@ -469,11 +469,11 @@
"\n",
"| Parameter | Default | Description |\n",
"|-----------|---------|-------------|\n",
- "| `rank_ratio` | `0.5` | Fraction of singular values to keep (0-1). Lower = more compression, more accuracy loss |\n",
+ "| `percent_removed` | `0.5` | Fraction of singular values to remove, in [0, 1). Higher = more compression, more accuracy loss |\n",
"\n",
- "### Choosing rank_ratio\n",
+ "### Choosing percent_removed\n",
"\n",
- "| rank_ratio | Compression | Accuracy Impact |\n",
+ "| Rank kept (`1 - percent_removed`) | Compression | Accuracy Impact |\n",
"|------------|-------------|-----------------|\n",
"| `0.8` | Low | Minimal |\n",
"| `0.5` | Medium | Moderate |\n",
@@ -496,7 +496,7 @@
"learn.fit_one_cycle(5)\n",
"\n",
"# 2. Decompose FC layers\n",
- "fc = FC_Decomposer(rank_ratio=0.5)\n",
+ "fc = FC_Decomposer()  # takes no arguments; decompose() defaults to percent_removed=0.5\n",
"new_model = fc.decompose(learn.model)\n",
"\n",
"# 3. Fine-tune to recover accuracy\n",