diff --git a/fasterai/_modidx.py b/fasterai/_modidx.py index 500ab14..e6f0d9d 100644 --- a/fasterai/_modidx.py +++ b/fasterai/_modidx.py @@ -223,18 +223,26 @@ 'fasterai/misc/bn_folding.py')}, 'fasterai.misc.conv_decomposer': { 'fasterai.misc.conv_decomposer.Conv_Decomposer': ( 'misc/conv_decomposer.html#conv_decomposer', 'fasterai/misc/conv_decomposer.py'), + 'fasterai.misc.conv_decomposer.Conv_Decomposer.CP': ( 'misc/conv_decomposer.html#conv_decomposer.cp', + 'fasterai/misc/conv_decomposer.py'), + 'fasterai.misc.conv_decomposer.Conv_Decomposer.SVD': ( 'misc/conv_decomposer.html#conv_decomposer.svd', + 'fasterai/misc/conv_decomposer.py'), + 'fasterai.misc.conv_decomposer.Conv_Decomposer.Spatial': ( 'misc/conv_decomposer.html#conv_decomposer.spatial', + 'fasterai/misc/conv_decomposer.py'), 'fasterai.misc.conv_decomposer.Conv_Decomposer.Tucker': ( 'misc/conv_decomposer.html#conv_decomposer.tucker', 'fasterai/misc/conv_decomposer.py'), 'fasterai.misc.conv_decomposer.Conv_Decomposer.__init__': ( 'misc/conv_decomposer.html#conv_decomposer.__init__', 'fasterai/misc/conv_decomposer.py'), 'fasterai.misc.conv_decomposer.Conv_Decomposer.decompose': ( 'misc/conv_decomposer.html#conv_decomposer.decompose', 'fasterai/misc/conv_decomposer.py'), + 'fasterai.misc.conv_decomposer._mode_unfold': ( 'misc/conv_decomposer.html#_mode_unfold', + 'fasterai/misc/conv_decomposer.py'), 'fasterai.misc.conv_decomposer._partial_tucker': ( 'misc/conv_decomposer.html#_partial_tucker', - 'fasterai/misc/conv_decomposer.py'), - 'fasterai.misc.conv_decomposer._unfold': ( 'misc/conv_decomposer.html#_unfold', - 'fasterai/misc/conv_decomposer.py')}, + 'fasterai/misc/conv_decomposer.py')}, 'fasterai.misc.cpu_optimizer': { 'fasterai.misc.cpu_optimizer.accelerate_model_for_cpu': ( 'misc/cpu_optimizer.html#accelerate_model_for_cpu', - 'fasterai/misc/cpu_optimizer.py')}, + 'fasterai/misc/cpu_optimizer.py'), + 'fasterai.misc.cpu_optimizer.optimize_for_cpu': ( 'misc/cpu_optimizer.html#optimize_for_cpu', + 
'fasterai/misc/cpu_optimizer.py')}, 'fasterai.misc.fc_decomposer': { 'fasterai.misc.fc_decomposer.FC_Decomposer': ( 'misc/fc_decomposer.html#fc_decomposer', 'fasterai/misc/fc_decomposer.py'), 'fasterai.misc.fc_decomposer.FC_Decomposer.SVD': ( 'misc/fc_decomposer.html#fc_decomposer.svd', @@ -242,7 +250,13 @@ 'fasterai.misc.fc_decomposer.FC_Decomposer.__init__': ( 'misc/fc_decomposer.html#fc_decomposer.__init__', 'fasterai/misc/fc_decomposer.py'), 'fasterai.misc.fc_decomposer.FC_Decomposer.decompose': ( 'misc/fc_decomposer.html#fc_decomposer.decompose', - 'fasterai/misc/fc_decomposer.py')}, + 'fasterai/misc/fc_decomposer.py'), + 'fasterai.misc.fc_decomposer._collect_activation_rms': ( 'misc/fc_decomposer.html#_collect_activation_rms', + 'fasterai/misc/fc_decomposer.py'), + 'fasterai.misc.fc_decomposer._rank_from_energy': ( 'misc/fc_decomposer.html#_rank_from_energy', + 'fasterai/misc/fc_decomposer.py'), + 'fasterai.misc.fc_decomposer._should_decompose': ( 'misc/fc_decomposer.html#_should_decompose', + 'fasterai/misc/fc_decomposer.py')}, 'fasterai.prune.all': {}, 'fasterai.prune.prune_callback': { 'fasterai.prune.prune_callback.PruneCallback': ( 'prune/prune_callback.html#prunecallback', 'fasterai/prune/prune_callback.py'), diff --git a/fasterai/misc/all.py b/fasterai/misc/all.py index 545f64c..f071eec 100644 --- a/fasterai/misc/all.py +++ b/fasterai/misc/all.py @@ -1,3 +1,4 @@ from .bn_folding import * from .fc_decomposer import * -from .conv_decomposer import * \ No newline at end of file +from .conv_decomposer import * +from .cpu_optimizer import * \ No newline at end of file diff --git a/fasterai/misc/bn_folding.py b/fasterai/misc/bn_folding.py index 758b722..6334d33 100644 --- a/fasterai/misc/bn_folding.py +++ b/fasterai/misc/bn_folding.py @@ -6,7 +6,6 @@ # %% ../../nbs/misc/bn_folding.ipynb #productive-preparation import torch import torch.nn as nn -import torch.nn.functional as F import copy # %% ../../nbs/misc/bn_folding.ipynb #83000749 diff --git 
a/fasterai/misc/conv_decomposer.py b/fasterai/misc/conv_decomposer.py index 933f8c2..81e39b7 100644 --- a/fasterai/misc/conv_decomposer.py +++ b/fasterai/misc/conv_decomposer.py @@ -1,86 +1,208 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/misc/conv_decomposer.ipynb. # %% auto #0 -__all__ = ['Conv_Decomposer'] +__all__ = ['VALID_METHODS', 'Conv_Decomposer'] # %% ../../nbs/misc/conv_decomposer.ipynb #imports import torch import torch.nn as nn import copy +from einops import rearrange # %% ../../nbs/misc/conv_decomposer.ipynb #conv-decomposer -def _unfold(tensor, mode): - "Unfold a tensor along a mode into a matrix" - return tensor.moveaxis(mode, 0).flatten(1) +from .fc_decomposer import _rank_from_energy, _should_decompose -def _partial_tucker(weight, ranks, n_iter=5): +def _mode_unfold(W, mode): + "Unfold a 4D tensor along a mode into a 2D matrix" + return rearrange(W, 'o i h w -> o (i h w)') if mode == 0 else rearrange(W, 'o i h w -> i (o h w)') + +def _partial_tucker(weight, ranks, n_iter=10, tol=1e-4): "Partial Tucker decomposition on modes [0, 1] via alternating SVD (HOOI)" - # Initialize factors from SVD of mode unfoldings - U0 = torch.linalg.svd(_unfold(weight, 0), full_matrices=False)[0][:, :ranks[0]] - U1 = torch.linalg.svd(_unfold(weight, 1), full_matrices=False)[0][:, :ranks[1]] + U0 = torch.linalg.svd(_mode_unfold(weight, 0), full_matrices=False)[0][:, :ranks[0]] + U1 = torch.linalg.svd(_mode_unfold(weight, 1), full_matrices=False)[0][:, :ranks[1]] for _ in range(n_iter): - # Project out mode 0 using U0, then update U1 + U0_prev, U1_prev = U0.clone(), U1.clone() proj = torch.einsum('oihw, or -> rihw', weight, U0) - U1 = torch.linalg.svd(_unfold(proj, 1), full_matrices=False)[0][:, :ranks[1]] - # Project out mode 1 using U1, then update U0 + U1 = torch.linalg.svd(_mode_unfold(proj, 1), full_matrices=False)[0][:, :ranks[1]] proj = torch.einsum('oihw, is -> oshw', weight, U1) - U0 = torch.linalg.svd(_unfold(proj, 0), full_matrices=False)[0][:, 
:ranks[0]] + U0 = torch.linalg.svd(_mode_unfold(proj, 0), full_matrices=False)[0][:, :ranks[0]] + if (U0 - U0_prev).norm() + (U1 - U1_prev).norm() < tol: break - # Core = W ×₀ U0ᵀ ×₁ U1ᵀ core = torch.einsum('oihw, or, is -> rshw', weight, U0, U1) return core, [U0, U1] +VALID_METHODS = frozenset({'tucker', 'svd', 'spatial', 'cp'}) class Conv_Decomposer: - "Decompose Conv2d layers using Tucker decomposition to reduce parameters and FLOPs" + "Decompose Conv2d layers to reduce parameters and FLOPs" def __init__(self): pass def decompose(self, - model: nn.Module, # The model to decompose - percent_removed: float = 0.5, # Fraction of rank to remove per mode [0, 1) + model: nn.Module, # The model to decompose + percent_removed: float = 0.5, # Fraction of rank to remove [0, 1) + method: str = 'tucker', # 'tucker', 'svd', 'spatial', or 'cp' + energy_threshold: float | None = None, # Auto rank via energy retention (0-1) + layers: list[str] | None = None, # Layer names to decompose (None = all eligible) + exclude: list[str] | None = None, # Layer names to skip + n_iter: int = 10, # Max HOOI iterations (tucker only) + tol: float = 1e-4, # HOOI convergence tolerance (tucker only) ) -> nn.Module: - "Recursively decompose all eligible Conv2d layers in the model" - if not (0 <= percent_removed < 1): + "Decompose eligible Conv2d layers using the specified method." 
+ if method not in VALID_METHODS: + raise ValueError(f"method must be one of {VALID_METHODS}, got {method!r}") + if energy_threshold is None and not (0 <= percent_removed < 1): raise ValueError(f"percent_removed must be in range [0, 1), got {percent_removed}") + if energy_threshold is not None and not (0 < energy_threshold <= 1): + raise ValueError(f"energy_threshold must be in range (0, 1], got {energy_threshold}") + + decompose_fn = {'tucker': self.Tucker, 'svd': self.SVD, + 'spatial': self.Spatial, 'cp': self.CP}[method] new_model = copy.deepcopy(model) - for name in list(new_model._modules): - module = new_model._modules[name] - if len(list(module._modules)) > 0: - new_model._modules[name] = self.decompose(module, percent_removed) - elif isinstance(module, nn.Conv2d) and module.groups == 1 and min(module.kernel_size) > 1: - new_model._modules[name] = self.Tucker(module, percent_removed) + for name, module in list(new_model.named_modules()): + if (isinstance(module, nn.Conv2d) and module.groups == 1 + and min(module.kernel_size) > 1 + and _should_decompose(name, layers, exclude)): + parent_name, _, child_name = name.rpartition('.') + parent = new_model.get_submodule(parent_name) if parent_name else new_model + if method == 'tucker': + replacement = decompose_fn(module, percent_removed, energy_threshold, n_iter, tol) + else: + replacement = decompose_fn(module, percent_removed, energy_threshold) + setattr(parent, child_name, replacement) return new_model - def Tucker(self, - layer: nn.Conv2d, # The Conv2d layer to decompose - percent_removed: float, # Fraction of rank to remove per mode + def SVD(self, + layer: nn.Conv2d, + percent_removed: float = 0.5, + energy_threshold: float | None = None, + ) -> nn.Sequential: + "SVD: 2 layers — spatial at reduced output rank + pointwise expansion" + W = layer.weight.data + C_out, C_in = W.shape[:2] + K = layer.kernel_size + + W_2d = rearrange(W, 'o i h w -> o (i h w)') + + U, S, Vh = torch.linalg.svd(W_2d, 
full_matrices=False) + R = _rank_from_energy(S, energy_threshold) if energy_threshold else max(1, int((1 - percent_removed) * min(C_out, C_in))) + + W_first = torch.diag(S[:R]) @ Vh[:R] + + first = nn.Conv2d(C_in, R, K, stride=layer.stride, + padding=layer.padding, dilation=layer.dilation, bias=False) + first.weight.data = rearrange(W_first, 'r (i h w) -> r i h w', i=C_in, h=K[0], w=K[1]) + + last = nn.Conv2d(R, C_out, 1, bias=layer.bias is not None) + last.weight.data = rearrange(U[:, :R], 'o r -> o r 1 1') + if layer.bias is not None: last.bias.data = layer.bias.data + + return nn.Sequential(first, last) + + def Spatial(self, + layer: nn.Conv2d, + percent_removed: float = 0.5, + energy_threshold: float | None = None, ) -> nn.Sequential: - "Perform Tucker decomposition on a single Conv2d layer" + "Spatial separable: 2 layers — K×1 vertical + 1×K horizontal (batched SVD)" W = layer.weight.data C_out, C_in = W.shape[:2] + Kh, Kw = layer.kernel_size + + W_spatial = rearrange(W, 'o i h w -> (o i) h w') + U_all, S_all, Vh_all = torch.linalg.svd(W_spatial, full_matrices=False) + R = _rank_from_energy(S_all[0], energy_threshold) if energy_threshold else max(1, int((1 - percent_removed) * min(Kh, Kw))) - R_out = max(1, int((1 - percent_removed) * C_out)) - R_in = max(1, int((1 - percent_removed) * C_in)) + U_scaled = U_all[:, :, :R] * S_all[:, :R].unsqueeze(1).sqrt() + W_vert = rearrange(U_scaled, '(o i) h r -> (o r) i h 1', o=C_out, i=C_in) + + Vh_scaled = S_all[:, :R].unsqueeze(2).sqrt() * Vh_all[:, :R, :] + Vh_by_out = rearrange(Vh_scaled, '(o i) r w -> o i r w', o=C_out) + W_horiz = rearrange(Vh_by_out.mean(dim=1), 'o r w -> o r 1 w') + + vert = nn.Conv2d(C_in, C_out * R, (Kh, 1), + stride=(layer.stride[0], 1), padding=(layer.padding[0], 0), bias=False) + vert.weight.data = W_vert + + horiz = nn.Conv2d(C_out * R, C_out, (1, Kw), groups=C_out, + stride=(1, layer.stride[1]), padding=(0, layer.padding[1]), + bias=layer.bias is not None) + horiz.weight.data = W_horiz + if 
layer.bias is not None: horiz.bias.data = layer.bias.data + + return nn.Sequential(vert, horiz) + + def CP(self, + layer: nn.Conv2d, + percent_removed: float = 0.5, + energy_threshold: float | None = None, + ) -> nn.Sequential: + "CP: 4 layers — pointwise compress + depthwise vertical + depthwise horizontal + pointwise expand" + W = layer.weight.data + C_out, C_in = W.shape[:2] + Kh, Kw = layer.kernel_size + + W_2d = rearrange(W, 'o i h w -> o (i h w)') + U, S, Vh = torch.linalg.svd(W_2d, full_matrices=False) + S0 = torch.linalg.svd(_mode_unfold(W, 0), full_matrices=False)[1] + R = _rank_from_energy(S0, energy_threshold) if energy_threshold else max(1, int((1 - percent_removed) * min(C_out, C_in))) + + V_4d = rearrange(Vh[:R], 'r (i h w) -> r i h w', i=C_in, h=Kh, w=Kw) + spatial_avg = V_4d.mean(dim=1) + U_s, S_s, Vh_s = torch.linalg.svd(spatial_avg, full_matrices=False) + + W_dw_v = rearrange(U_s[:, :, 0] * S_s[:, 0:1].sqrt(), 'r h -> r 1 h 1') + W_dw_h = rearrange(Vh_s[:, 0, :] * S_s[:, 0:1].sqrt(), 'r w -> r 1 1 w') + channel_norms = V_4d.pow(2).sum(dim=(2, 3)).sqrt() + W_pw_in = rearrange(channel_norms * S[:R].sqrt().unsqueeze(1), 'r i -> r i 1 1') + W_pw_out = rearrange(U[:, :R] * S[:R].sqrt().unsqueeze(0), 'o r -> o r 1 1') + + pw_in = nn.Conv2d(C_in, R, 1, bias=False) + pw_in.weight.data = W_pw_in + dw_v = nn.Conv2d(R, R, (Kh, 1), groups=R, stride=(layer.stride[0], 1), + padding=(layer.padding[0], 0), bias=False) + dw_v.weight.data = W_dw_v + dw_h = nn.Conv2d(R, R, (1, Kw), groups=R, stride=(1, layer.stride[1]), + padding=(0, layer.padding[1]), bias=False) + dw_h.weight.data = W_dw_h + pw_out = nn.Conv2d(R, C_out, 1, bias=layer.bias is not None) + pw_out.weight.data = W_pw_out + if layer.bias is not None: pw_out.bias.data = layer.bias.data + + return nn.Sequential(pw_in, dw_v, dw_h, pw_out) + + def Tucker(self, + layer: nn.Conv2d, + percent_removed: float = 0.5, + energy_threshold: float | None = None, + n_iter: int = 10, + tol: float = 1e-4, + ) -> 
nn.Sequential: + "Tucker: 3 layers — pointwise compress + spatial + pointwise expand" + W = layer.weight.data + C_out, C_in = W.shape[:2] - core, (U_out, U_in) = _partial_tucker(W, [R_out, R_in]) - # core: (R_out, R_in, H, W), U_out: (C_out, R_out), U_in: (C_in, R_in) + if energy_threshold is not None: + S0 = torch.linalg.svd(_mode_unfold(W, 0), full_matrices=False)[1] + S1 = torch.linalg.svd(_mode_unfold(W, 1), full_matrices=False)[1] + R_out = _rank_from_energy(S0, energy_threshold) + R_in = _rank_from_energy(S1, energy_threshold) + else: + R_out = max(1, int((1 - percent_removed) * C_out)) + R_in = max(1, int((1 - percent_removed) * C_in)) + core, (U_out, U_in) = _partial_tucker(W, [R_out, R_in], n_iter=n_iter, tol=tol) - # 1. Pointwise input compression: (C_in → R_in) first = nn.Conv2d(C_in, R_in, 1, bias=False) - first.weight.data = U_in.t().unsqueeze(-1).unsqueeze(-1) + first.weight.data = rearrange(U_in.t(), 'r i -> r i 1 1') - # 2. Spatial convolution at reduced rank: (R_in → R_out) middle = nn.Conv2d(R_in, R_out, layer.kernel_size, stride=layer.stride, padding=layer.padding, dilation=layer.dilation, bias=False) middle.weight.data = core - # 3. Pointwise output expansion: (R_out → C_out) last = nn.Conv2d(R_out, C_out, 1, bias=layer.bias is not None) - last.weight.data = U_out.unsqueeze(-1).unsqueeze(-1) - if layer.bias is not None: - last.bias.data = layer.bias.data + last.weight.data = rearrange(U_out, 'o r -> o r 1 1') + if layer.bias is not None: last.bias.data = layer.bias.data return nn.Sequential(first, middle, last) diff --git a/fasterai/misc/cpu_optimizer.py b/fasterai/misc/cpu_optimizer.py index b5fd869..bf5e91c 100644 --- a/fasterai/misc/cpu_optimizer.py +++ b/fasterai/misc/cpu_optimizer.py @@ -1,20 +1,37 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/misc/cpu_optimizer.ipynb. 
# %% auto #0 -__all__ = ['accelerate_model_for_cpu'] +__all__ = ['optimize_for_cpu', 'accelerate_model_for_cpu'] # %% ../../nbs/misc/cpu_optimizer.ipynb #fbbccd4a import torch import torch.nn as nn -from torch.utils.mobile_optimizer import optimize_for_mobile +import warnings # %% ../../nbs/misc/cpu_optimizer.ipynb #6524ac31 -def accelerate_model_for_cpu(model: nn.Module, example_input: torch.Tensor): - model.eval() - example_input = example_input.to(memory_format=torch.channels_last) - - model = model.to(memory_format=torch.channels_last) - model = torch.jit.script(model) - model = optimize_for_mobile(model) +def optimize_for_cpu( + model: nn.Module, # The PyTorch model to optimize + sample: torch.Tensor, # Sample input for tracing (with batch dim) + *, + backend: str = "compile", # "compile" (torch.compile) or "trace" (torch.jit.trace) + compile_mode: str = "default", # torch.compile mode +) -> nn.Module: + "Optimize model for CPU inference via channels-last layout + compilation" + model = model.eval().to(memory_format=torch.channels_last) + sample = sample.to(memory_format=torch.channels_last) + + if backend == "compile": + return torch.compile(model, mode=compile_mode) + elif backend == "trace": + with torch.no_grad(): + return torch.jit.trace(model, sample) + else: + raise ValueError(f"Unknown backend: {backend!r}. 
Use 'compile' or 'trace'.") - return model +def accelerate_model_for_cpu(model: nn.Module, example_input: torch.Tensor): + "Deprecated: use optimize_for_cpu() instead" + warnings.warn( + "accelerate_model_for_cpu is deprecated, use optimize_for_cpu(model, sample) instead", + DeprecationWarning, stacklevel=2, + ) + return optimize_for_cpu(model, example_input, backend="trace") diff --git a/fasterai/misc/fc_decomposer.py b/fasterai/misc/fc_decomposer.py index c1fa113..745cf46 100644 --- a/fasterai/misc/fc_decomposer.py +++ b/fasterai/misc/fc_decomposer.py @@ -6,47 +6,125 @@ # %% ../../nbs/misc/fc_decomposer.ipynb #fbbccd4a import torch import torch.nn as nn -import torch.nn.functional as F import copy # %% ../../nbs/misc/fc_decomposer.ipynb #6524ac31 +def _rank_from_energy(S, threshold): + "Find minimum rank to retain `threshold` fraction of singular value energy" + energy = S.pow(2).cumsum(0) / S.pow(2).sum() + idx = (energy >= threshold).nonzero(as_tuple=True)[0] + return max(1, int(idx[0].item()) + 1) if len(idx) > 0 else S.shape[0] + +def _should_decompose(name, layers=None, exclude=None): + "Check if a named layer should be decomposed" + if exclude and name in exclude: return False + if layers is not None: return name in layers + return True + +def _collect_activation_rms( + model: nn.Module, # Model to calibrate + data, # Tensor, list of batches, or DataLoader + layer_type: type = nn.Linear, # Layer types to hook + n_batches: int = 5, # Max batches to process +) -> dict[nn.Module, torch.Tensor]: + "Collect per-input-channel RMS activation norms via forward hooks" + device = next(model.parameters()).device + state = {} + hooks = [] + for m in model.modules(): + if isinstance(m, layer_type): + state[m] = {'acc': torch.zeros(m.weight.shape[1], device=device), 'n': 0} + def make_hook(module): + def hook(mod, inp): + x = inp[0].detach() + dims = [i for i in range(x.dim()) if i != 1] # keep channel dim + state[module]['acc'] += x.pow(2).sum(dim=dims) + 
state[module]['n'] += x.shape[0] + return hook + hooks.append(m.register_forward_pre_hook(make_hook(m))) + + model.eval() + with torch.no_grad(): + if isinstance(data, torch.Tensor): + model(data.to(device)) + else: + for n, batch in enumerate(data): + if n >= n_batches: break + xb = batch[0] if isinstance(batch, (tuple, list)) else batch + model(xb.as_subclass(torch.Tensor).to(device)) + + for h in hooks: h.remove() + return {m: (s['acc'] / max(s['n'], 1)).sqrt() for m, s in state.items()} + + class FC_Decomposer: "Decompose fully-connected layers using SVD to reduce parameters" - def __init__(self): - pass + def __init__(self): pass def decompose(self, - model: nn.Module, # The model to decompose - percent_removed: float = 0.5 # Fraction of singular values to remove [0, 1) + model: nn.Module, # The model to decompose + percent_removed: float = 0.5, # Fraction of singular values to remove [0, 1) + energy_threshold: float | None = None, # Auto rank: keep this fraction of energy (0-1) + data = None, # Calibration data for ASVD (None = standard SVD) + n_batches: int = 5, # Number of calibration batches + layers: list[str] | None = None, # Layer names to decompose (None = all) + exclude: list[str] | None = None, # Layer names to skip ) -> nn.Module: - "Recursively decompose all Linear layers in the model using SVD" - if not (0 <= percent_removed < 1): + "Decompose Linear layers using SVD. Pass data for activation-aware ASVD." 
+ if energy_threshold is None and not (0 <= percent_removed < 1): raise ValueError(f"percent_removed must be in range [0, 1), got {percent_removed}") + if energy_threshold is not None and not (0 < energy_threshold <= 1): + raise ValueError(f"energy_threshold must be in range (0, 1], got {energy_threshold}") + + # Collect activation stats on ORIGINAL model before deepcopy + scale_map = {} + if data is not None: + rms = _collect_activation_rms(model, data, nn.Linear, n_batches) + # Map by name so we can find them after deepcopy + for name, m in model.named_modules(): + if m in rms: scale_map[name] = rms[m] new_model = copy.deepcopy(model) - module_names = list(new_model._modules) - - for k, name in enumerate(module_names): - if len(list(new_model._modules[name]._modules)) > 0: - new_model._modules[name] = self.decompose(new_model._modules[name], percent_removed) - else: - if isinstance(new_model._modules[name], nn.Linear): - layer = self.SVD(new_model._modules[name], percent_removed) - new_model._modules[name] = layer + for name, module in list(new_model.named_modules()): + if isinstance(module, nn.Linear) and _should_decompose(name, layers, exclude): + scale = scale_map.get(name, None) + parent_name, _, child_name = name.rpartition('.') + parent = new_model.get_submodule(parent_name) if parent_name else new_model + setattr(parent, child_name, self.SVD(module, percent_removed, energy_threshold, scale)) return new_model - def SVD(self, - layer: nn.Linear, # The Linear layer to decompose - percent_removed: float # Fraction of singular values to remove + layer: nn.Linear, # The Linear layer to decompose + percent_removed: float = 0.5, # Fraction of singular values to remove + energy_threshold: float | None = None, # Auto rank via energy retention + scale: torch.Tensor | None = None, # Per-channel activation RMS for ASVD ) -> nn.Sequential: - "Perform SVD decomposition on a single Linear layer" + "Perform SVD decomposition. With scale: activation-aware SVD (ASVD)." 
W = layer.weight.data - U, S, Vh = torch.linalg.svd(W, full_matrices=False) - L = max(1, int((1.-percent_removed) * S.shape[0])) + + # ASVD: scale columns by activation RMS before SVD + if scale is not None: + s = scale.to(W.device) + 1e-6 + W_scaled = W * s.unsqueeze(0) # (out, in) * (1, in) + else: + W_scaled = W + + U, S, Vh = torch.linalg.svd(W_scaled, full_matrices=False) + + if energy_threshold is not None: + L = _rank_from_energy(S, energy_threshold) + else: + L = max(1, int((1.-percent_removed) * S.shape[0])) + W1 = U[:,:L] W2 = torch.diag(S[:L]) @ Vh[:L] + + # ASVD: undo scaling in the first layer's weights + if scale is not None: + s_inv = 1.0 / s + W2 = W2 * s_inv.unsqueeze(0) # (L, in) * (1, in) + layer_1 = nn.Linear(in_features=layer.in_features, out_features=L, bias=False) layer_1.weight.data = W2 diff --git a/nbs/_quarto.yml b/nbs/_quarto.yml index 3490807..972f71b 100644 --- a/nbs/_quarto.yml +++ b/nbs/_quarto.yml @@ -114,6 +114,8 @@ website: contents: - misc/bn_folding.ipynb - misc/fc_decomposer.ipynb + - misc/conv_decomposer.ipynb + - misc/cpu_optimizer.ipynb - section: Export contents: - export/onnx_exporter.ipynb diff --git a/nbs/misc/bn_folding.ipynb b/nbs/misc/bn_folding.ipynb index ea50d7c..8f2f01e 100644 --- a/nbs/misc/bn_folding.ipynb +++ b/nbs/misc/bn_folding.ipynb @@ -45,7 +45,6 @@ "#| export\n", "import torch\n", "import torch.nn as nn\n", - "import torch.nn.functional as F\n", "import copy" ] }, diff --git a/nbs/misc/conv_decomposer.ipynb b/nbs/misc/conv_decomposer.ipynb index 5edb688..810a5da 100644 --- a/nbs/misc/conv_decomposer.ipynb +++ b/nbs/misc/conv_decomposer.ipynb @@ -5,12 +5,7 @@ "id": "frontmatter", "metadata": {}, "source": [ - "---", - "description: Decompose Conv2d layers via Tucker decomposition", - "output-file: conv_decomposer.html", - "title: Conv2d Layers Decomposer", - "skip_showdoc: true", - "---" + "---description: Decompose Conv2d layers via Tucker decompositionoutput-file: conv_decomposer.htmltitle: Conv2d Layers 
Decomposerskip_showdoc: true---" ] }, { @@ -29,13 +24,35 @@ "id": "showdoc-import", "metadata": {}, "outputs": [], - "source": "#| include: false\nfrom nbdev.showdoc import *" + "source": [ + "#| include: false\n", + "from nbdev.showdoc import *" + ] }, { "cell_type": "markdown", "id": "overview", "metadata": {}, - "source": "## Overview\n\nThe `Conv_Decomposer` class reduces model size and FLOPs by factorizing Conv2d layers into three smaller convolutions using Tucker decomposition. This is the Conv2d counterpart of `FC_Decomposer` (which uses SVD for Linear layers).\n\n**How it works:** A Conv2d weight `[C_out, C_in, H, W]` is decomposed into:\n1. `Conv2d(C_in, R_in, 1)` — pointwise input channel compression\n2. `Conv2d(R_in, R_out, (H, W))` — spatial convolution at reduced rank\n3. `Conv2d(R_out, C_out, 1)` — pointwise output channel expansion\n\n### When to Use\n\n| Scenario | Recommendation |\n|----------|----------------|\n| Large 3x3 or larger convolutions | **Highly recommended** — significant FLOP savings |\n| 1x1 pointwise convolutions | Skipped automatically (already minimal) |\n| Depthwise / grouped convolutions | Skipped (Tucker assumes standard convolution) |\n| First layer (C_in=3) | Works but limited benefit |\n| Post-training compression | Fine-tune after decomposition for best accuracy |" + "source": [ + "## Overview\n", + "\n", + "The `Conv_Decomposer` class reduces model size and FLOPs by factorizing Conv2d layers into three smaller convolutions using Tucker decomposition. This is the Conv2d counterpart of `FC_Decomposer` (which uses SVD for Linear layers).\n", + "\n", + "**How it works:** A Conv2d weight `[C_out, C_in, H, W]` is decomposed into:\n", + "1. `Conv2d(C_in, R_in, 1)` — pointwise input channel compression\n", + "2. `Conv2d(R_in, R_out, (H, W))` — spatial convolution at reduced rank\n", + "3. 
`Conv2d(R_out, C_out, 1)` — pointwise output channel expansion\n", + "\n", + "### When to Use\n", + "\n", + "| Scenario | Recommendation |\n", + "|----------|----------------|\n", + "| Large 3x3 or larger convolutions | **Highly recommended** — significant FLOP savings |\n", + "| 1x1 pointwise convolutions | Skipped automatically (already minimal) |\n", + "| Depthwise / grouped convolutions | Skipped (Tucker assumes standard convolution) |\n", + "| First layer (C_in=3) | Works but limited benefit |\n", + "| Post-training compression | Fine-tune after decomposition for best accuracy |" + ] }, { "cell_type": "code", @@ -43,7 +60,7 @@ "id": "imports", "metadata": {}, "outputs": [], - "source": "#| export\nimport torch\nimport torch.nn as nn\nimport copy" + "source": "#| export\nimport torch\nimport torch.nn as nn\nimport copy\nfrom einops import rearrange" }, { "cell_type": "code", @@ -51,7 +68,205 @@ "id": "conv-decomposer", "metadata": {}, "outputs": [], - "source": "#| export\ndef _unfold(tensor, mode):\n \"Unfold a tensor along a mode into a matrix\"\n return tensor.moveaxis(mode, 0).flatten(1)\n\ndef _partial_tucker(weight, ranks, n_iter=5):\n \"Partial Tucker decomposition on modes [0, 1] via alternating SVD (HOOI)\"\n # Initialize factors from SVD of mode unfoldings\n U0 = torch.linalg.svd(_unfold(weight, 0), full_matrices=False)[0][:, :ranks[0]]\n U1 = torch.linalg.svd(_unfold(weight, 1), full_matrices=False)[0][:, :ranks[1]]\n\n for _ in range(n_iter):\n # Project out mode 0 using U0, then update U1\n proj = torch.einsum('oihw, or -> rihw', weight, U0)\n U1 = torch.linalg.svd(_unfold(proj, 1), full_matrices=False)[0][:, :ranks[1]]\n # Project out mode 1 using U1, then update U0\n proj = torch.einsum('oihw, is -> oshw', weight, U1)\n U0 = torch.linalg.svd(_unfold(proj, 0), full_matrices=False)[0][:, :ranks[0]]\n\n # Core = W ×₀ U0ᵀ ×₁ U1ᵀ\n core = torch.einsum('oihw, or, is -> rshw', weight, U0, U1)\n return core, [U0, U1]\n\n\nclass Conv_Decomposer:\n 
\"Decompose Conv2d layers using Tucker decomposition to reduce parameters and FLOPs\"\n\n def __init__(self): pass\n\n def decompose(self,\n model: nn.Module, # The model to decompose\n percent_removed: float = 0.5, # Fraction of rank to remove per mode [0, 1)\n ) -> nn.Module:\n \"Recursively decompose all eligible Conv2d layers in the model\"\n if not (0 <= percent_removed < 1):\n raise ValueError(f\"percent_removed must be in range [0, 1), got {percent_removed}\")\n\n new_model = copy.deepcopy(model)\n for name in list(new_model._modules):\n module = new_model._modules[name]\n if len(list(module._modules)) > 0:\n new_model._modules[name] = self.decompose(module, percent_removed)\n elif isinstance(module, nn.Conv2d) and module.groups == 1 and min(module.kernel_size) > 1:\n new_model._modules[name] = self.Tucker(module, percent_removed)\n return new_model\n\n def Tucker(self,\n layer: nn.Conv2d, # The Conv2d layer to decompose\n percent_removed: float, # Fraction of rank to remove per mode\n ) -> nn.Sequential:\n \"Perform Tucker decomposition on a single Conv2d layer\"\n W = layer.weight.data\n C_out, C_in = W.shape[:2]\n\n R_out = max(1, int((1 - percent_removed) * C_out))\n R_in = max(1, int((1 - percent_removed) * C_in))\n\n core, (U_out, U_in) = _partial_tucker(W, [R_out, R_in])\n # core: (R_out, R_in, H, W), U_out: (C_out, R_out), U_in: (C_in, R_in)\n\n # 1. Pointwise input compression: (C_in → R_in)\n first = nn.Conv2d(C_in, R_in, 1, bias=False)\n first.weight.data = U_in.t().unsqueeze(-1).unsqueeze(-1)\n\n # 2. Spatial convolution at reduced rank: (R_in → R_out)\n middle = nn.Conv2d(R_in, R_out, layer.kernel_size, stride=layer.stride,\n padding=layer.padding, dilation=layer.dilation, bias=False)\n middle.weight.data = core\n\n # 3. 
Pointwise output expansion: (R_out → C_out)\n last = nn.Conv2d(R_out, C_out, 1, bias=layer.bias is not None)\n last.weight.data = U_out.unsqueeze(-1).unsqueeze(-1)\n if layer.bias is not None:\n last.bias.data = layer.bias.data\n\n return nn.Sequential(first, middle, last)" + "source": [ + "#| export\n", + "from fasterai.misc.fc_decomposer import _rank_from_energy, _should_decompose\n", + "\n", + "def _mode_unfold(W, mode):\n", + " \"Unfold a 4D tensor along a mode into a 2D matrix\"\n", + " return rearrange(W, 'o i h w -> o (i h w)') if mode == 0 else rearrange(W, 'o i h w -> i (o h w)')\n", + "\n", + "def _partial_tucker(weight, ranks, n_iter=10, tol=1e-4):\n", + " \"Partial Tucker decomposition on modes [0, 1] via alternating SVD (HOOI)\"\n", + " U0 = torch.linalg.svd(_mode_unfold(weight, 0), full_matrices=False)[0][:, :ranks[0]]\n", + " U1 = torch.linalg.svd(_mode_unfold(weight, 1), full_matrices=False)[0][:, :ranks[1]]\n", + "\n", + " for _ in range(n_iter):\n", + " U0_prev, U1_prev = U0.clone(), U1.clone()\n", + " proj = torch.einsum('oihw, or -> rihw', weight, U0)\n", + " U1 = torch.linalg.svd(_mode_unfold(proj, 1), full_matrices=False)[0][:, :ranks[1]]\n", + " proj = torch.einsum('oihw, is -> oshw', weight, U1)\n", + " U0 = torch.linalg.svd(_mode_unfold(proj, 0), full_matrices=False)[0][:, :ranks[0]]\n", + " if (U0 - U0_prev).norm() + (U1 - U1_prev).norm() < tol: break\n", + "\n", + " core = torch.einsum('oihw, or, is -> rshw', weight, U0, U1)\n", + " return core, [U0, U1]\n", + "\n", + "VALID_METHODS = frozenset({'tucker', 'svd', 'spatial', 'cp'})\n", + "\n", + "class Conv_Decomposer:\n", + " \"Decompose Conv2d layers to reduce parameters and FLOPs\"\n", + "\n", + " def __init__(self): pass\n", + "\n", + " def decompose(self,\n", + " model: nn.Module, # The model to decompose\n", + " percent_removed: float = 0.5, # Fraction of rank to remove [0, 1)\n", + " method: str = 'tucker', # 'tucker', 'svd', 'spatial', or 'cp'\n", + " energy_threshold: float | None 
= None, # Auto rank via energy retention (0-1)\n", + " layers: list[str] | None = None, # Layer names to decompose (None = all eligible)\n", + " exclude: list[str] | None = None, # Layer names to skip\n", + " n_iter: int = 10, # Max HOOI iterations (tucker only)\n", + " tol: float = 1e-4, # HOOI convergence tolerance (tucker only)\n", + " ) -> nn.Module:\n", + " \"Decompose eligible Conv2d layers using the specified method.\"\n", + " if method not in VALID_METHODS:\n", + " raise ValueError(f\"method must be one of {VALID_METHODS}, got {method!r}\")\n", + " if energy_threshold is None and not (0 <= percent_removed < 1):\n", + " raise ValueError(f\"percent_removed must be in range [0, 1), got {percent_removed}\")\n", + " if energy_threshold is not None and not (0 < energy_threshold <= 1):\n", + " raise ValueError(f\"energy_threshold must be in range (0, 1], got {energy_threshold}\")\n", + "\n", + " decompose_fn = {'tucker': self.Tucker, 'svd': self.SVD,\n", + " 'spatial': self.Spatial, 'cp': self.CP}[method]\n", + "\n", + " new_model = copy.deepcopy(model)\n", + " for name, module in list(new_model.named_modules()):\n", + " if (isinstance(module, nn.Conv2d) and module.groups == 1 \n", + " and min(module.kernel_size) > 1\n", + " and _should_decompose(name, layers, exclude)):\n", + " parent_name, _, child_name = name.rpartition('.')\n", + " parent = new_model.get_submodule(parent_name) if parent_name else new_model\n", + " if method == 'tucker':\n", + " replacement = decompose_fn(module, percent_removed, energy_threshold, n_iter, tol)\n", + " else:\n", + " replacement = decompose_fn(module, percent_removed, energy_threshold)\n", + " setattr(parent, child_name, replacement)\n", + " return new_model\n", + "\n", + " def SVD(self,\n", + " layer: nn.Conv2d,\n", + " percent_removed: float = 0.5,\n", + " energy_threshold: float | None = None,\n", + " ) -> nn.Sequential:\n", + " \"SVD: 2 layers — spatial at reduced output rank + pointwise expansion\"\n", + " W = 
layer.weight.data\n", + " C_out, C_in = W.shape[:2]\n", + " K = layer.kernel_size\n", + "\n", + " W_2d = rearrange(W, 'o i h w -> o (i h w)')\n", + "\n", + " U, S, Vh = torch.linalg.svd(W_2d, full_matrices=False)\n", + " R = _rank_from_energy(S, energy_threshold) if energy_threshold else max(1, int((1 - percent_removed) * min(C_out, C_in)))\n", + "\n", + " W_first = torch.diag(S[:R]) @ Vh[:R]\n", + "\n", + " first = nn.Conv2d(C_in, R, K, stride=layer.stride,\n", + " padding=layer.padding, dilation=layer.dilation, bias=False)\n", + " first.weight.data = rearrange(W_first, 'r (i h w) -> r i h w', i=C_in, h=K[0], w=K[1])\n", + "\n", + " last = nn.Conv2d(R, C_out, 1, bias=layer.bias is not None)\n", + " last.weight.data = rearrange(U[:, :R], 'o r -> o r 1 1')\n", + " if layer.bias is not None: last.bias.data = layer.bias.data\n", + "\n", + " return nn.Sequential(first, last)\n", + "\n", + " def Spatial(self,\n", + " layer: nn.Conv2d,\n", + " percent_removed: float = 0.5,\n", + " energy_threshold: float | None = None,\n", + " ) -> nn.Sequential:\n", + " \"Spatial separable: 2 layers — K×1 vertical + 1×K horizontal (batched SVD)\"\n", + " W = layer.weight.data\n", + " C_out, C_in = W.shape[:2]\n", + " Kh, Kw = layer.kernel_size\n", + "\n", + " W_spatial = rearrange(W, 'o i h w -> (o i) h w')\n", + " U_all, S_all, Vh_all = torch.linalg.svd(W_spatial, full_matrices=False)\n", + " R = _rank_from_energy(S_all[0], energy_threshold) if energy_threshold else max(1, int((1 - percent_removed) * min(Kh, Kw)))\n", + "\n", + " U_scaled = U_all[:, :, :R] * S_all[:, :R].unsqueeze(1).sqrt()\n", + " W_vert = rearrange(U_scaled, '(o i) h r -> (o r) i h 1', o=C_out, i=C_in)\n", + "\n", + " Vh_scaled = S_all[:, :R].unsqueeze(2).sqrt() * Vh_all[:, :R, :]\n", + " Vh_by_out = rearrange(Vh_scaled, '(o i) r w -> o i r w', o=C_out)\n", + " W_horiz = rearrange(Vh_by_out.mean(dim=1), 'o r w -> o r 1 w')\n", + "\n", + " vert = nn.Conv2d(C_in, C_out * R, (Kh, 1),\n", + " stride=(layer.stride[0], 1), 
padding=(layer.padding[0], 0), bias=False)\n", + " vert.weight.data = W_vert\n", + "\n", + " horiz = nn.Conv2d(C_out * R, C_out, (1, Kw), groups=C_out,\n", + " stride=(1, layer.stride[1]), padding=(0, layer.padding[1]),\n", + " bias=layer.bias is not None)\n", + " horiz.weight.data = W_horiz\n", + " if layer.bias is not None: horiz.bias.data = layer.bias.data\n", + "\n", + " return nn.Sequential(vert, horiz)\n", + "\n", + " def CP(self,\n", + " layer: nn.Conv2d,\n", + " percent_removed: float = 0.5,\n", + " energy_threshold: float | None = None,\n", + " ) -> nn.Sequential:\n", + " \"CP: 4 layers — pointwise compress + depthwise vertical + depthwise horizontal + pointwise expand\"\n", + " W = layer.weight.data\n", + " C_out, C_in = W.shape[:2]\n", + " Kh, Kw = layer.kernel_size\n", + "\n", + " W_2d = rearrange(W, 'o i h w -> o (i h w)')\n", + " U, S, Vh = torch.linalg.svd(W_2d, full_matrices=False)\n", + " S0 = torch.linalg.svd(_mode_unfold(W, 0), full_matrices=False)[1]\n", + " R = _rank_from_energy(S0, energy_threshold) if energy_threshold else max(1, int((1 - percent_removed) * min(C_out, C_in)))\n", + "\n", + " V_4d = rearrange(Vh[:R], 'r (i h w) -> r i h w', i=C_in, h=Kh, w=Kw)\n", + " spatial_avg = V_4d.mean(dim=1)\n", + " U_s, S_s, Vh_s = torch.linalg.svd(spatial_avg, full_matrices=False)\n", + "\n", + " W_dw_v = rearrange(U_s[:, :, 0] * S_s[:, 0:1].sqrt(), 'r h -> r 1 h 1')\n", + " W_dw_h = rearrange(Vh_s[:, 0, :] * S_s[:, 0:1].sqrt(), 'r w -> r 1 1 w')\n", + " channel_norms = V_4d.pow(2).sum(dim=(2, 3)).sqrt()\n", + " W_pw_in = rearrange(channel_norms * S[:R].sqrt().unsqueeze(1), 'r i -> r i 1 1')\n", + " W_pw_out = rearrange(U[:, :R] * S[:R].sqrt().unsqueeze(0), 'o r -> o r 1 1')\n", + "\n", + " pw_in = nn.Conv2d(C_in, R, 1, bias=False)\n", + " pw_in.weight.data = W_pw_in\n", + " dw_v = nn.Conv2d(R, R, (Kh, 1), groups=R, stride=(layer.stride[0], 1),\n", + " padding=(layer.padding[0], 0), bias=False)\n", + " dw_v.weight.data = W_dw_v\n", + " dw_h = 
nn.Conv2d(R, R, (1, Kw), groups=R, stride=(1, layer.stride[1]),\n", + " padding=(0, layer.padding[1]), bias=False)\n", + " dw_h.weight.data = W_dw_h\n", + " pw_out = nn.Conv2d(R, C_out, 1, bias=layer.bias is not None)\n", + " pw_out.weight.data = W_pw_out\n", + " if layer.bias is not None: pw_out.bias.data = layer.bias.data\n", + "\n", + " return nn.Sequential(pw_in, dw_v, dw_h, pw_out)\n", + "\n", + " def Tucker(self,\n", + " layer: nn.Conv2d,\n", + " percent_removed: float = 0.5,\n", + " energy_threshold: float | None = None,\n", + " n_iter: int = 10,\n", + " tol: float = 1e-4,\n", + " ) -> nn.Sequential:\n", + " \"Tucker: 3 layers — pointwise compress + spatial + pointwise expand\"\n", + " W = layer.weight.data\n", + " C_out, C_in = W.shape[:2]\n", + "\n", + " if energy_threshold is not None:\n", + " S0 = torch.linalg.svd(_mode_unfold(W, 0), full_matrices=False)[1]\n", + " S1 = torch.linalg.svd(_mode_unfold(W, 1), full_matrices=False)[1]\n", + " R_out = _rank_from_energy(S0, energy_threshold)\n", + " R_in = _rank_from_energy(S1, energy_threshold)\n", + " else:\n", + " R_out = max(1, int((1 - percent_removed) * C_out))\n", + " R_in = max(1, int((1 - percent_removed) * C_in))\n", + " core, (U_out, U_in) = _partial_tucker(W, [R_out, R_in], n_iter=n_iter, tol=tol)\n", + "\n", + " first = nn.Conv2d(C_in, R_in, 1, bias=False)\n", + " first.weight.data = rearrange(U_in.t(), 'r i -> r i 1 1')\n", + "\n", + " middle = nn.Conv2d(R_in, R_out, layer.kernel_size, stride=layer.stride,\n", + " padding=layer.padding, dilation=layer.dilation, bias=False)\n", + " middle.weight.data = core\n", + "\n", + " last = nn.Conv2d(R_out, C_out, 1, bias=layer.bias is not None)\n", + " last.weight.data = rearrange(U_out, 'o r -> o r 1 1')\n", + " if layer.bias is not None: last.bias.data = layer.bias.data\n", + "\n", + " return nn.Sequential(first, middle, last)" + ] }, { "cell_type": "code", @@ -77,7 +292,27 @@ "cell_type": "markdown", "id": "usage", "metadata": {}, - "source": "---\n\n## 
Usage Example\n\n```python\nfrom fasterai.misc.conv_decomposer import Conv_Decomposer\nfrom torchvision.models import resnet18\n\nmodel = resnet18(pretrained=True)\ndecomposer = Conv_Decomposer()\ncompressed = decomposer.decompose(model, percent_removed=0.5)\n\n# Check parameter reduction\norig = sum(p.numel() for p in model.parameters())\ncomp = sum(p.numel() for p in compressed.parameters())\nprint(f\"Compression: {orig/comp:.2f}x\")\n```\n\n> **Note:** Tucker decomposition uses an iterative algorithm (HOOI), so even at `percent_removed=0.0` there will be small reconstruction error. Fine-tuning after decomposition is recommended." + "source": [ + "---\n", + "\n", + "## Usage Example\n", + "\n", + "```python\n", + "from fasterai.misc.conv_decomposer import Conv_Decomposer\n", + "from torchvision.models import resnet18\n", + "\n", + "model = resnet18(pretrained=True)\n", + "decomposer = Conv_Decomposer()\n", + "compressed = decomposer.decompose(model, percent_removed=0.5)\n", + "\n", + "# Check parameter reduction\n", + "orig = sum(p.numel() for p in model.parameters())\n", + "comp = sum(p.numel() for p in compressed.parameters())\n", + "print(f\"Compression: {orig/comp:.2f}x\")\n", + "```\n", + "\n", + "> **Note:** Tucker decomposition uses an iterative algorithm (HOOI), so even at `percent_removed=0.0` there will be small reconstruction error. Fine-tuning after decomposition is recommended." 
+ ] }, { "cell_type": "code", @@ -85,7 +320,7 @@ "id": "tests", "metadata": {}, "outputs": [], - "source": "#| hide\nfrom fastcore.test import *\n\ndecomposer = Conv_Decomposer()\n\n# --- Output shape preserved ---\n_m = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 32, 3, padding=1))\n_x = torch.randn(2, 3, 8, 8)\n_m_dec = decomposer.decompose(_m, percent_removed=0.5)\ntest_eq(_m(_x).shape, _m_dec(_x).shape)\n\n# --- percent_removed=0.0 → close reconstruction (HOOI is iterative, not exact) ---\n_m2 = nn.Sequential(nn.Conv2d(16, 32, 3, padding=1))\n_x2 = torch.randn(2, 16, 8, 8)\n_m2_dec = decomposer.decompose(_m2, percent_removed=0.0)\ntest_close(_m2(_x2), _m2_dec(_x2), eps=0.01)\n\n# --- Decomposed structure: Conv2d becomes Sequential of 3 Conv2ds ---\nassert isinstance(_m_dec[0], nn.Sequential)\ntest_eq(len(_m_dec[0]), 3)\ntest_eq(_m_dec[0][0].kernel_size, (1, 1)) # pointwise in\ntest_eq(_m_dec[0][1].kernel_size, (3, 3)) # spatial\ntest_eq(_m_dec[0][2].kernel_size, (1, 1)) # pointwise out\n\n# --- 1x1 convolutions are skipped ---\n_m_pw = nn.Sequential(nn.Conv2d(16, 32, 1))\n_m_pw_dec = decomposer.decompose(_m_pw, percent_removed=0.5)\nassert isinstance(_m_pw_dec[0], nn.Conv2d) # unchanged, not Sequential\n\n# --- Grouped convolutions are skipped ---\n_m_dw = nn.Sequential(nn.Conv2d(16, 16, 3, padding=1, groups=16))\n_m_dw_dec = decomposer.decompose(_m_dw, percent_removed=0.5)\nassert isinstance(_m_dw_dec[0], nn.Conv2d) # unchanged\n\n# --- Minimum rank >= 1 even at extreme removal ---\n_m3 = nn.Sequential(nn.Conv2d(4, 8, 3, padding=1))\n_m3_dec = decomposer.decompose(_m3, percent_removed=0.95)\ntest_eq(_m3_dec[0][0].out_features if hasattr(_m3_dec[0][0], 'out_features') else _m3_dec[0][0].out_channels, max(1, int(0.05 * 4)))\n\n# --- Bias handling: original bias → last layer gets it ---\n_conv_bias = nn.Conv2d(16, 32, 3, padding=1, bias=True)\n_dec_bias = decomposer.Tucker(_conv_bias, 0.5)\nassert _dec_bias[0].bias is None # first: no 
bias\nassert _dec_bias[1].bias is None # middle: no bias\nassert _dec_bias[2].bias is not None # last: has bias\n\n_conv_nobias = nn.Conv2d(16, 32, 3, padding=1, bias=False)\n_dec_nobias = decomposer.Tucker(_conv_nobias, 0.5)\nassert _dec_nobias[2].bias is None # last: no bias\n\n# --- Stride/padding transfer to middle conv only ---\n_conv_stride = nn.Conv2d(16, 32, 3, stride=2, padding=1)\n_dec_stride = decomposer.Tucker(_conv_stride, 0.5)\ntest_eq(_dec_stride[0].stride, (1, 1)) # pointwise: default\ntest_eq(_dec_stride[1].stride, (2, 2)) # middle: from original\ntest_eq(_dec_stride[2].stride, (1, 1)) # pointwise: default\n\n# --- Validation ---\nwith ExceptionExpected(ValueError): decomposer.decompose(nn.Sequential(nn.Conv2d(3, 16, 3)), percent_removed=1.0)\nwith ExceptionExpected(ValueError): decomposer.decompose(nn.Sequential(nn.Conv2d(3, 16, 3)), percent_removed=-0.1)" + "source": "#| hide\nfrom fastcore.test import *\n\ndecomposer = Conv_Decomposer()\n_m = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 32, 3, padding=1))\n_x = torch.randn(2, 3, 8, 8)\n\n# === All methods produce correct output shape ===\nfor method in ['tucker', 'svd', 'spatial', 'cp']:\n _dec = decomposer.decompose(_m, 0.5, method=method)\n test_eq(_m(_x).shape, _dec(_x).shape)\n assert torch.isfinite(_dec(_x)).all(), f\"{method} produced non-finite output\"\n\n# === Tucker: 3 layers (1x1, KxK, 1x1) ===\n_t = decomposer.decompose(_m, 0.5, method='tucker')\ntest_eq(len(_t[0]), 3)\ntest_eq(_t[0][0].kernel_size, (1, 1))\ntest_eq(_t[0][1].kernel_size, (3, 3))\n\n# === SVD: 2 layers (KxK, 1x1) ===\n_s = decomposer.decompose(_m, 0.5, method='svd')\ntest_eq(len(_s[0]), 2)\ntest_eq(_s[0][0].kernel_size, (3, 3))\ntest_eq(_s[0][1].kernel_size, (1, 1))\n\n# === Spatial: 2 layers (Kx1, 1xK) ===\n_sp = decomposer.decompose(_m, 0.5, method='spatial')\ntest_eq(len(_sp[0]), 2)\ntest_eq(_sp[0][0].kernel_size, (3, 1))\ntest_eq(_sp[0][1].kernel_size, (1, 3))\n\n# === CP: 4 layers (1x1, 
Kx1, 1xK, 1x1) ===\n_cp = decomposer.decompose(_m, 0.5, method='cp')\ntest_eq(len(_cp[0]), 4)\ntest_eq(_cp[0][0].kernel_size, (1, 1)) # pointwise in\ntest_eq(_cp[0][1].kernel_size, (3, 1)) # depthwise vertical\ntest_eq(_cp[0][2].kernel_size, (1, 3)) # depthwise horizontal\ntest_eq(_cp[0][3].kernel_size, (1, 1)) # pointwise out\n\n# === Common: 1x1 and grouped skipped ===\nassert isinstance(decomposer.decompose(nn.Sequential(nn.Conv2d(16, 32, 1)), 0.5)[0], nn.Conv2d)\nassert isinstance(decomposer.decompose(nn.Sequential(nn.Conv2d(16, 16, 3, groups=16, padding=1)), 0.5)[0], nn.Conv2d)\n\n# === Bias: last layer gets it ===\nfor method in ['tucker', 'svd', 'spatial', 'cp']:\n _dec = decomposer.decompose(nn.Sequential(nn.Conv2d(16, 32, 3, padding=1, bias=True)), 0.5, method=method)\n seq = _dec[0]\n assert seq[-1].bias is not None, f\"{method}: last layer missing bias\"\n for layer in seq[:-1]:\n assert layer.bias is None, f\"{method}: non-last layer has bias\"\n\n# === Stride transfer ===\n_stride = decomposer.Tucker(nn.Conv2d(16, 32, 3, stride=2, padding=1), 0.5)\ntest_eq(_stride[1].stride, (2, 2))\n\n_svd_stride = decomposer.SVD(nn.Conv2d(16, 32, 3, stride=2, padding=1), 0.5)\ntest_eq(_svd_stride[0].stride, (2, 2))\n\n# === Validation ===\nwith ExceptionExpected(ValueError): decomposer.decompose(nn.Sequential(nn.Conv2d(3, 16, 3)), percent_removed=1.0)\nwith ExceptionExpected(ValueError): decomposer.decompose(nn.Sequential(nn.Conv2d(3, 16, 3)), method='bad')\n\n# === energy_threshold + layers/exclude ===\n_m4 = nn.Sequential(nn.Conv2d(16, 32, 3, padding=1))\nassert decomposer.decompose(_m4, energy_threshold=0.99)[0][0].out_channels >= \\\n decomposer.decompose(_m4, 0.5)[0][0].out_channels\n\n_m5 = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 32, 3, padding=1))\nassert isinstance(decomposer.decompose(_m5, 0.5, layers=['0'])[2], nn.Conv2d)\nassert isinstance(decomposer.decompose(_m5, 0.5, exclude=['2'])[2], nn.Conv2d)" }, { "cell_type": "code", 
@@ -93,13 +328,40 @@ "id": "8q9xqce7y6k", "metadata": {}, "outputs": [], - "source": "#| hide\n#| slow\nfrom torchvision.models import resnet18\n\n# Decompose a real ResNet-18, verify it still works\n_resnet = resnet18(weights=None)\n_resnet.eval()\n_x = torch.randn(2, 3, 64, 64)\n_out_orig = _resnet(_x)\n\n_dec = Conv_Decomposer()\n_resnet_dec = _dec.decompose(_resnet, percent_removed=0.5)\n_resnet_dec.eval()\n_out_dec = _resnet_dec(_x)\n\n# Same output shape\ntest_eq(_out_orig.shape, _out_dec.shape)\n\n# Outputs are finite (no NaN/Inf)\nassert torch.isfinite(_out_dec).all(), \"Decomposed ResNet produced non-finite outputs\"\n\n# Parameter count reduced\n_orig_params = sum(p.numel() for p in _resnet.parameters())\n_dec_params = sum(p.numel() for p in _resnet_dec.parameters())\nassert _dec_params < _orig_params, f\"Expected fewer params: {_dec_params} >= {_orig_params}\"\nprint(f\"ResNet-18: {_orig_params:,} → {_dec_params:,} params ({_orig_params/_dec_params:.2f}x compression)\")" + "source": [ + "#| hide\n", + "#| slow\n", + "from torchvision.models import resnet18\n", + "\n", + "# Decompose a real ResNet-18, verify it still works\n", + "_resnet = resnet18(weights=None)\n", + "_resnet.eval()\n", + "_x = torch.randn(2, 3, 64, 64)\n", + "_out_orig = _resnet(_x)\n", + "\n", + "_dec = Conv_Decomposer()\n", + "_resnet_dec = _dec.decompose(_resnet, percent_removed=0.5)\n", + "_resnet_dec.eval()\n", + "_out_dec = _resnet_dec(_x)\n", + "\n", + "# Same output shape\n", + "test_eq(_out_orig.shape, _out_dec.shape)\n", + "\n", + "# Outputs are finite (no NaN/Inf)\n", + "assert torch.isfinite(_out_dec).all(), \"Decomposed ResNet produced non-finite outputs\"\n", + "\n", + "# Parameter count reduced\n", + "_orig_params = sum(p.numel() for p in _resnet.parameters())\n", + "_dec_params = sum(p.numel() for p in _resnet_dec.parameters())\n", + "assert _dec_params < _orig_params, f\"Expected fewer params: {_dec_params} >= {_orig_params}\"\n", + "print(f\"ResNet-18: {_orig_params:,} 
→ {_dec_params:,} params ({_orig_params/_dec_params:.2f}x compression)\")" + ] }, { "cell_type": "markdown", "id": "seealso", "metadata": {}, - "source": "---\n\n## See Also\n\n- [FC Decomposer](fc_decomposer.html) - SVD decomposition for Linear layers\n- [BN Folding](bn_folding.html) - Fold BatchNorm into preceding Conv/Linear layers\n- [Pruner](../prune/pruner.html) - Structured pruning that removes entire filters" + "source": "## Future Work\n\n- **LayerNorm_Folder**: Fold LayerNorm into adjacent Linear layers for transformer inference (analogous to BN_Folder)\n- **NuclearNormCallback**: Add nuclear norm regularization during training to pre-condition weights for better SVD/Tucker decomposition (Low-Rank Prehab, arxiv 2512.01980)\n- **Latency-aware rank selection**: Use fasterlatency to predict actual speedup at each rank, selecting ranks to hit a target latency budget rather than parameter budget (FLAR-SVD, CVPRW 2025)\n\n---\n\n## See Also\n\n- [FC Decomposer](fc_decomposer.html) - SVD decomposition for Linear layers\n- [BN Folding](bn_folding.html) - Fold BatchNorm into preceding Conv/Linear layers\n- [Pruner](../prune/pruner.html) - Structured pruning that removes entire filters" } ], "metadata": { diff --git a/nbs/misc/cpu_optimizer.ipynb b/nbs/misc/cpu_optimizer.ipynb index b953d8d..426faf7 100644 --- a/nbs/misc/cpu_optimizer.ipynb +++ b/nbs/misc/cpu_optimizer.ipynb @@ -6,11 +6,10 @@ "metadata": {}, "source": [ "---\n", - "description: Further optimize for CPU inference\n", + "description: Optimize models for CPU inference\n", "output-file: cpu_optimizer.html\n", - "title: Further optimize for CPU inference\n", + "title: CPU Optimizer\n", "skip_showdoc: true\n", - "skip_exec: true\n", "---" ] }, @@ -45,7 +44,7 @@ "#| export\n", "import torch\n", "import torch.nn as nn\n", - "from torch.utils.mobile_optimizer import optimize_for_mobile" + "import warnings" ] }, { @@ -55,16 +54,15 @@ "source": [ "## Overview\n", "\n", - "The `accelerate_model_for_cpu` 
function applies optimizations to prepare a PyTorch model for efficient CPU inference. It combines several techniques:\n", + "`optimize_for_cpu` prepares a model for efficient CPU inference by combining:\n", "\n", - "1. **Channels-last memory format**: Optimizes memory layout for CNN operations on CPU\n", - "2. **TorchScript compilation**: JIT compiles the model for faster execution\n", - "3. **Mobile optimization**: Applies `optimize_for_mobile` for operator fusion and other optimizations\n", + "1. **Channels-last memory format** — optimizes layout for CNN operations on CPU\n", + "2. **Compilation** — `torch.compile` (default) or `torch.jit.trace` for operator fusion\n", "\n", - "**When to use:**\n", - "- Deploying models on CPU-only servers\n", - "- Edge deployment without GPU\n", - "- After quantization for maximum CPU performance" + "| Backend | Speed | Compatibility | Best For |\n", + "|---------|-------|---------------|----------|\n", + "| `\"compile\"` | Faster | Most models | Default choice |\n", + "| `\"trace\"` | Good | Requires static shapes | Legacy / mobile |" ] }, { @@ -75,15 +73,32 @@ "outputs": [], "source": [ "#| export\n", - "def accelerate_model_for_cpu(model: nn.Module, example_input: torch.Tensor):\n", - " model.eval()\n", - " example_input = example_input.to(memory_format=torch.channels_last)\n", - " \n", - " model = model.to(memory_format=torch.channels_last)\n", - " model = torch.jit.script(model)\n", - " model = optimize_for_mobile(model)\n", + "def optimize_for_cpu(\n", + " model: nn.Module, # The PyTorch model to optimize\n", + " sample: torch.Tensor, # Sample input for tracing (with batch dim)\n", + " *,\n", + " backend: str = \"compile\", # \"compile\" (torch.compile) or \"trace\" (torch.jit.trace)\n", + " compile_mode: str = \"default\", # torch.compile mode\n", + ") -> nn.Module:\n", + " \"Optimize model for CPU inference via channels-last layout + compilation\"\n", + " model = model.eval().to(memory_format=torch.channels_last)\n", + 
" sample = sample.to(memory_format=torch.channels_last)\n", + "\n", + " if backend == \"compile\":\n", + " return torch.compile(model, mode=compile_mode)\n", + " elif backend == \"trace\":\n", + " with torch.no_grad():\n", + " return torch.jit.trace(model, sample)\n", + " else:\n", + " raise ValueError(f\"Unknown backend: {backend!r}. Use 'compile' or 'trace'.\")\n", "\n", - " return model" + "def accelerate_model_for_cpu(model: nn.Module, example_input: torch.Tensor):\n", + " \"Deprecated: use optimize_for_cpu() instead\"\n", + " warnings.warn(\n", + " \"accelerate_model_for_cpu is deprecated, use optimize_for_cpu(model, sample) instead\",\n", + " DeprecationWarning, stacklevel=2,\n", + " )\n", + " return optimize_for_cpu(model, example_input, backend=\"trace\")" ] }, { @@ -91,51 +106,9 @@ "execution_count": null, "id": "50222d43", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Found permutation search CUDA kernels\n", - "[ASP][Info] permutation_search_kernels can be imported.\n" - ] - }, - { - "data": { - "text/markdown": [ - "---\n", - "\n", - "[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/misc/cpu_optimizer.py#L12){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", - "\n", - "### accelerate_model_for_cpu\n", - "\n", - "```python\n", - "\n", - "def accelerate_model_for_cpu(\n", - " model:Module, example_input:Tensor\n", - "):\n", - "\n", - "\n", - "```" - ], - "text/plain": [ - "```python\n", - "\n", - "def accelerate_model_for_cpu(\n", - " model:Module, example_input:Tensor\n", - "):\n", - "\n", - "\n", - "```" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "show_doc(accelerate_model_for_cpu)" + "show_doc(optimize_for_cpu)" ] }, { @@ -143,41 +116,41 @@ "id": "78818w1gh87", "metadata": {}, "source": [ - "**Parameters:**\n", - "\n", - "- `model`: The PyTorch model to optimize\n", - "- 
`example_input`: A sample input tensor (used for tracing)\n", - "\n", - "**Returns:** An optimized TorchScript model\n", - "\n", - "---\n", - "\n", - "## Usage Example\n", - "\n", "```python\n", - "from fasterai.misc.cpu_optimizer import accelerate_model_for_cpu\n", - "import torch\n", + "from fasterai.misc.cpu_optimizer import optimize_for_cpu\n", "\n", - "# Create example input matching your model's expected shape\n", - "example_input = torch.randn(1, 3, 224, 224)\n", + "model = resnet18(pretrained=True)\n", + "sample = torch.randn(1, 3, 224, 224)\n", "\n", - "# Optimize model for CPU inference\n", - "optimized_model = accelerate_model_for_cpu(model, example_input)\n", + "# Default: torch.compile\n", + "optimized = optimize_for_cpu(model, sample)\n", "\n", - "# Use the optimized model\n", - "with torch.no_grad():\n", - " output = optimized_model(input_tensor)\n", + "# Or JIT trace for mobile/static shapes\n", + "traced = optimize_for_cpu(model, sample, backend=\"trace\")\n", "```\n", "\n", - "**Note:** The returned model is a TorchScript model. Some dynamic Python features may not be supported." + "> **Note:** `accelerate_model_for_cpu` is deprecated. Use `optimize_for_cpu` instead." 
] }, + { + "cell_type": "code", + "execution_count": null, + "id": "test_cpu_opt", + "metadata": {}, + "outputs": [], + "source": "#| hide\nfrom fastcore.test import *\nimport torch, torch.nn as nn\n\n# optimize_for_cpu with trace backend\n_m = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10))\n_x = torch.randn(1, 3, 8, 8)\n_traced = optimize_for_cpu(_m, _x, backend=\"trace\")\n_out = _traced(_x.to(memory_format=torch.channels_last))\ntest_eq(_out.shape, (1, 10))\nassert torch.isfinite(_out).all()\n\n# Invalid backend raises ValueError\nwith ExceptionExpected(ValueError): optimize_for_cpu(_m, _x, backend=\"bad\")\n\n# Deprecated function emits warning\nimport warnings\nwith warnings.catch_warnings(record=True) as w:\n warnings.simplefilter(\"always\")\n accelerate_model_for_cpu(nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU()), torch.randn(1, 3, 8, 8))\n dep_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)]\n assert len(dep_warnings) >= 1, f\"Expected DeprecationWarning, got {[x.category for x in w]}\"" + }, { "cell_type": "markdown", "id": "see_also", "metadata": {}, "source": [ - "---\n\n## See Also\n\n- [BN Folding](bn_folding.html) — Fold batch normalization\n- [ONNX Export](../export/onnx_exporter.html) — Export for deployment" + "---\n", + "\n", + "## See Also\n", + "\n", + "- [BN Folding](bn_folding.html) — Fold batch normalization\n", + "- [ONNX Export](../export/onnx_exporter.html) — Export for deployment" ] } ], diff --git a/nbs/misc/fc_decomposer.ipynb b/nbs/misc/fc_decomposer.ipynb index 6a12fdb..67e4f71 100644 --- a/nbs/misc/fc_decomposer.ipynb +++ b/nbs/misc/fc_decomposer.ipynb @@ -101,7 +101,6 @@ "#| export\n", "import torch\n", "import torch.nn as nn\n", - "import torch.nn.functional as F\n", "import copy" ] }, @@ -113,43 +112,122 @@ "outputs": [], "source": [ "#| export\n", + "def _rank_from_energy(S, threshold):\n", + " \"Find minimum rank to retain `threshold` 
fraction of singular value energy\"\n", + " energy = S.pow(2).cumsum(0) / S.pow(2).sum()\n", + " idx = (energy >= threshold).nonzero(as_tuple=True)[0]\n", + " return max(1, int(idx[0].item()) + 1) if len(idx) > 0 else S.shape[0]\n", + "\n", + "def _should_decompose(name, layers=None, exclude=None):\n", + " \"Check if a named layer should be decomposed\"\n", + " if exclude and name in exclude: return False\n", + " if layers is not None: return name in layers\n", + " return True\n", + "\n", + "def _collect_activation_rms(\n", + " model: nn.Module, # Model to calibrate\n", + " data, # Tensor, list of batches, or DataLoader\n", + " layer_type: type = nn.Linear, # Layer types to hook\n", + " n_batches: int = 5, # Max batches to process\n", + ") -> dict[nn.Module, torch.Tensor]:\n", + " \"Collect per-input-channel RMS activation norms via forward hooks\"\n", + " device = next(model.parameters()).device\n", + " state = {}\n", + " hooks = []\n", + " for m in model.modules():\n", + " if isinstance(m, layer_type):\n", + " state[m] = {'acc': torch.zeros(m.weight.shape[1], device=device), 'n': 0}\n", + " def make_hook(module):\n", + " def hook(mod, inp):\n", + " x = inp[0].detach()\n", + " dims = [i for i in range(x.dim()) if i != 1] # keep channel dim\n", + " state[module]['acc'] += x.pow(2).sum(dim=dims)\n", + " state[module]['n'] += x.shape[0]\n", + " return hook\n", + " hooks.append(m.register_forward_pre_hook(make_hook(m)))\n", + "\n", + " model.eval()\n", + " with torch.no_grad():\n", + " if isinstance(data, torch.Tensor):\n", + " model(data.to(device))\n", + " else:\n", + " for n, batch in enumerate(data):\n", + " if n >= n_batches: break\n", + " xb = batch[0] if isinstance(batch, (tuple, list)) else batch\n", + " model(xb.as_subclass(torch.Tensor).to(device))\n", + "\n", + " for h in hooks: h.remove()\n", + " return {m: (s['acc'] / max(s['n'], 1)).sqrt() for m, s in state.items()}\n", + "\n", + "\n", "class FC_Decomposer:\n", " \"Decompose fully-connected layers using 
SVD to reduce parameters\"\n", "\n", - " def __init__(self):\n", - " pass\n", + " def __init__(self): pass\n", " \n", " def decompose(self, \n", - " model: nn.Module, # The model to decompose\n", - " percent_removed: float = 0.5 # Fraction of singular values to remove [0, 1)\n", + " model: nn.Module, # The model to decompose\n", + " percent_removed: float = 0.5, # Fraction of singular values to remove [0, 1)\n", + " energy_threshold: float | None = None, # Auto rank: keep this fraction of energy (0-1)\n", + " data = None, # Calibration data for ASVD (None = standard SVD)\n", + " n_batches: int = 5, # Number of calibration batches\n", + " layers: list[str] | None = None, # Layer names to decompose (None = all)\n", + " exclude: list[str] | None = None, # Layer names to skip\n", " ) -> nn.Module:\n", - " \"Recursively decompose all Linear layers in the model using SVD\"\n", - " if not (0 <= percent_removed < 1):\n", + " \"Decompose Linear layers using SVD. Pass data for activation-aware ASVD.\"\n", + " if energy_threshold is None and not (0 <= percent_removed < 1):\n", " raise ValueError(f\"percent_removed must be in range [0, 1), got {percent_removed}\")\n", + " if energy_threshold is not None and not (0 < energy_threshold <= 1):\n", + " raise ValueError(f\"energy_threshold must be in range (0, 1], got {energy_threshold}\")\n", + "\n", + " # Collect activation stats on ORIGINAL model before deepcopy\n", + " scale_map = {}\n", + " if data is not None:\n", + " rms = _collect_activation_rms(model, data, nn.Linear, n_batches)\n", + " # Map by name so we can find them after deepcopy\n", + " for name, m in model.named_modules():\n", + " if m in rms: scale_map[name] = rms[m]\n", "\n", " new_model = copy.deepcopy(model)\n", - " module_names = list(new_model._modules)\n", - "\n", - " for k, name in enumerate(module_names):\n", - " if len(list(new_model._modules[name]._modules)) > 0:\n", - " new_model._modules[name] = self.decompose(new_model._modules[name], 
percent_removed)\n", - " else:\n", - " if isinstance(new_model._modules[name], nn.Linear):\n", - " layer = self.SVD(new_model._modules[name], percent_removed)\n", - " new_model._modules[name] = layer\n", + " for name, module in list(new_model.named_modules()):\n", + " if isinstance(module, nn.Linear) and _should_decompose(name, layers, exclude):\n", + " scale = scale_map.get(name, None)\n", + " parent_name, _, child_name = name.rpartition('.')\n", + " parent = new_model.get_submodule(parent_name) if parent_name else new_model\n", + " setattr(parent, child_name, self.SVD(module, percent_removed, energy_threshold, scale))\n", " return new_model\n", "\n", - "\n", " def SVD(self, \n", - " layer: nn.Linear, # The Linear layer to decompose\n", - " percent_removed: float # Fraction of singular values to remove\n", + " layer: nn.Linear, # The Linear layer to decompose\n", + " percent_removed: float = 0.5, # Fraction of singular values to remove\n", + " energy_threshold: float | None = None, # Auto rank via energy retention\n", + " scale: torch.Tensor | None = None, # Per-channel activation RMS for ASVD\n", " ) -> nn.Sequential:\n", - " \"Perform SVD decomposition on a single Linear layer\"\n", + " \"Perform SVD decomposition. 
With scale: activation-aware SVD (ASVD).\"\n", " W = layer.weight.data\n", - " U, S, Vh = torch.linalg.svd(W, full_matrices=False)\n", - " L = max(1, int((1.-percent_removed) * S.shape[0]))\n", + "\n", + " # ASVD: scale columns by activation RMS before SVD\n", + " if scale is not None:\n", + " s = scale.to(W.device) + 1e-6\n", + " W_scaled = W * s.unsqueeze(0) # (out, in) * (1, in)\n", + " else:\n", + " W_scaled = W\n", + "\n", + " U, S, Vh = torch.linalg.svd(W_scaled, full_matrices=False)\n", + "\n", + " if energy_threshold is not None:\n", + " L = _rank_from_energy(S, energy_threshold)\n", + " else:\n", + " L = max(1, int((1.-percent_removed) * S.shape[0]))\n", + "\n", " W1 = U[:,:L]\n", " W2 = torch.diag(S[:L]) @ Vh[:L]\n", + "\n", + " # ASVD: undo scaling in the first layer's weights\n", + " if scale is not None:\n", + " s_inv = 1.0 / s\n", + " W2 = W2 * s_inv.unsqueeze(0) # (L, in) * (1, in)\n", + "\n", " layer_1 = nn.Linear(in_features=layer.in_features, \n", " out_features=L, bias=False)\n", " layer_1.weight.data = W2\n", @@ -279,7 +357,7 @@ "decomposer = FC_Decomposer()\n", "model_dec = decomposer.decompose(model, percent_removed=0.5)\n", "out_dec = model_dec(x)\n", - "test_close(out_orig, out_dec, eps=1.0) # 50% SVD removal has significant reconstruction error\n", + "test_close(out_orig, out_dec, eps=1.0)\n", "\n", "# Decomposed structure: Linear → Sequential(Linear, Linear)\n", "assert isinstance(model_dec[0], nn.Sequential)\n", @@ -292,7 +370,7 @@ "m2_dec = decomposer.decompose(m2, percent_removed=0.0)\n", "test_close(out2, m2_dec(x2), eps=1e-4)\n", "\n", - "# L >= 1 always (even at extreme removal)\n", + "# L >= 1 always\n", "m3 = nn.Sequential(nn.Linear(10, 20))\n", "m3_dec = decomposer.decompose(m3, percent_removed=0.95)\n", "assert m3_dec[0][0].out_features >= 1\n", @@ -300,8 +378,51 @@ "# Invalid percent_removed raises ValueError\n", "with ExceptionExpected(ValueError):\n", " decomposer.decompose(nn.Sequential(nn.Linear(10, 10)), 
percent_removed=1.0)\n", - "with ExceptionExpected(ValueError):\n", - " decomposer.decompose(nn.Sequential(nn.Linear(10, 10)), percent_removed=-0.1)" + "\n", + "# --- energy_threshold ---\n", + "m4 = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))\n", + "m4_99 = decomposer.decompose(m4, energy_threshold=0.99)\n", + "m4_50 = decomposer.decompose(m4, percent_removed=0.5)\n", + "assert m4_99[0][0].out_features >= m4_50[0][0].out_features\n", + "\n", + "# --- layers / exclude ---\n", + "m6 = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))\n", + "m6_sel = decomposer.decompose(m6, 0.5, layers=['0'])\n", + "assert isinstance(m6_sel[0], nn.Sequential)\n", + "assert isinstance(m6_sel[2], nn.Linear)\n", + "\n", + "m7 = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))\n", + "m7_exc = decomposer.decompose(m7, 0.5, exclude=['2'])\n", + "assert isinstance(m7_exc[0], nn.Sequential)\n", + "assert isinstance(m7_exc[2], nn.Linear)\n", + "\n", + "# --- ASVD: activation-aware SVD ---\n", + "m8 = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))\n", + "x8 = torch.randn(16, 32)\n", + "out8 = m8(x8)\n", + "\n", + "# ASVD with calibration data\n", + "m8_asvd = decomposer.decompose(m8, 0.5, data=[x8])\n", + "out8_asvd = m8_asvd(x8)\n", + "\n", + "# Standard SVD for comparison\n", + "m8_svd = decomposer.decompose(m8, 0.5)\n", + "out8_svd = m8_svd(x8)\n", + "\n", + "# Both produce valid outputs\n", + "assert torch.isfinite(out8_asvd).all()\n", + "assert torch.isfinite(out8_svd).all()\n", + "\n", + "# ASVD should have lower reconstruction error on the calibration data\n", + "err_asvd = (out8 - out8_asvd).pow(2).mean()\n", + "err_svd = (out8 - out8_svd).pow(2).mean()\n", + "# Note: on random weights this may not always hold, but scaling should not make things worse\n", + "assert torch.isfinite(err_asvd)\n", + "\n", + "# ASVD with data=None → same as standard SVD\n", + "m9 = nn.Sequential(nn.Linear(10, 20))\n", + "m9_no_data = 
decomposer.decompose(m9, 0.5, data=None)\n", + "assert isinstance(m9_no_data[0], nn.Sequential)" ] }, { diff --git a/nbs/prune/pruner.ipynb b/nbs/prune/pruner.ipynb index c83b93f..085274e 100644 --- a/nbs/prune/pruner.ipynb +++ b/nbs/prune/pruner.ipynb @@ -343,4 +343,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/nbs/tutorials/misc/conv_decomposer.ipynb b/nbs/tutorials/misc/conv_decomposer.ipynb index 9ea4a36..2c515e4 100644 --- a/nbs/tutorials/misc/conv_decomposer.ipynb +++ b/nbs/tutorials/misc/conv_decomposer.ipynb @@ -21,28 +21,7 @@ "source": [ "## Overview\n", "\n", - "**Conv2d Layer Decomposition** uses Tucker decomposition to factorize convolutional layers into three smaller, more efficient convolutions. This is the Conv2d counterpart of FC Decomposition (which uses SVD for Linear layers).\n", - "\n", - "### How It Works\n", - "\n", - "A Conv2d weight tensor $W \\in \\mathbb{R}^{C_{out} \\times C_{in} \\times H \\times W}$ is decomposed into three convolutions:\n", - "\n", - "1. `Conv2d(C_in, R_in, 1)` — pointwise input compression\n", - "2. `Conv2d(R_in, R_out, (H, W))` — spatial convolution at reduced rank\n", - "3. `Conv2d(R_out, C_out, 1)` — pointwise output expansion\n", - "\n", - "Where $R_{in}$ and $R_{out}$ are the Tucker ranks, controlled by `percent_removed`.\n", - "\n", - "### When to Use Conv Decomposition\n", - "\n", - "| Model Type | Conv Layer Size | Recommendation |\n", - "|------------|-----------------|----------------|\n", - "| ResNet-style | Medium 3×3 convolutions | ✅ **Effective** — 2-4x FLOP reduction |\n", - "| VGG-style | Large 3×3 convolutions | ✅ **Highly effective** |\n", - "| MobileNet | Already uses depthwise separable | ❌ Skipped (grouped convolutions) |\n", - "| 1×1 convolutions | Pointwise | ❌ Skipped automatically |\n", - "\n", - "**Key advantage:** Works on any hardware — no sparse kernel requirements." 
+ "`Conv_Decomposer` factorizes Conv2d layers into smaller convolutions using 4 different mathematical decompositions. Each trades off compression, accuracy, and inference overhead differently." ] }, { @@ -95,7 +74,80 @@ "execution_count": null, "id": "train", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
epochtrain_lossvalid_lossaccuracytime
00.5594000.3097990.85656300:02
10.3347570.3535670.85318000:02
20.2519190.2982110.87753700:02
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "learn = vision_learner(dls, resnet18, metrics=accuracy)\n", "learn.unfreeze()\n", @@ -107,9 +159,9 @@ "id": "decompose-header", "metadata": {}, "source": [ - "## 3. Apply Tucker Decomposition\n", + "## 3. Compare Decomposition Methods\n", "\n", - "Use `Conv_Decomposer` to factorize all eligible Conv2d layers:" + "Let's decompose the same model with all 4 methods and compare:" ] }, { @@ -117,22 +169,56 @@ "execution_count": null, "id": "decompose", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Method Layers Params Compress Latency Speedup\n", + "------------------------------------------------------------\n", + "original — 11,704,896 1.0x 6.80ms 1.0x\n", + "svd 2 6,426,195 1.8x 4.92ms 1.4x\n", + "spatial 2 4,388,736 2.7x 5.65ms 1.2x\n", + "tucker 3 4,723,619 2.5x 9.04ms 0.8x\n", + "cp 4 1,897,873 6.2x 8.95ms 0.8x\n" + ] + } + ], "source": [ + "import copy, time\n", + "\n", "def count_params(model):\n", " return sum(p.numel() for p in model.parameters())\n", "\n", + "def measure_latency(model, x, warmup=10, steps=50):\n", + " model.eval()\n", + " with torch.no_grad():\n", + " for _ in range(warmup): model(x)\n", + " if x.is_cuda: torch.cuda.synchronize()\n", + " t0 = time.perf_counter()\n", + " for _ in range(steps): model(x)\n", + " if x.is_cuda: torch.cuda.synchronize()\n", + " return (time.perf_counter() - t0) / steps * 1000 # ms\n", + "\n", + "learn.model = learn.model.cpu()\n", + "\n", "original_params = count_params(learn.model)\n", - "print(f\"Original parameters: {original_params:,}\")\n", + "device = next(learn.model.parameters()).device\n", + "x_bench = torch.randn(8, 3, 64, 64, device=device)\n", + "base_ms = measure_latency(learn.model, x_bench)\n", "\n", - "# Decompose — remove 50% of rank per mode\n", "decomposer = Conv_Decomposer()\n", - "new_model = 
decomposer.decompose(learn.model, percent_removed=0.5)\n", "\n", - "new_params = count_params(new_model)\n", - "print(f\"Decomposed parameters: {new_params:,}\")\n", - "print(f\"Reduction: {(1 - new_params/original_params)*100:.1f}%\")\n", - "print(f\"Compression: {original_params/new_params:.1f}x\")" + "print(f\"{'Method':<10} {'Layers':>6} {'Params':>10} {'Compress':>9} {'Latency':>9} {'Speedup':>8}\")\n", + "print(\"-\" * 60)\n", + "print(f\"{'original':<10} {'—':>6} {original_params:>10,} {'1.0x':>9} {base_ms:>8.2f}ms {'1.0x':>8}\")\n", + "\n", + "for method in ['svd', 'spatial', 'tucker', 'cp']:\n", + " model_dec = decomposer.decompose(copy.deepcopy(learn.model), 0.5, method=method)\n", + " params = count_params(model_dec)\n", + " ms = measure_latency(model_dec, x_bench)\n", + " n_layers = {'svd': 2, 'tucker': 3, 'spatial': 2, 'cp': 4}[method]\n", + " print(f\"{method:<10} {n_layers:>6} {params:>10,} {original_params/params:>8.1f}x {ms:>8.2f}ms {base_ms/ms:>7.1f}x\")" ] }, { @@ -140,19 +226,29 @@ "id": "explain", "metadata": {}, "source": [ - "### What Happened?\n", + "### How Each Method Decomposes a Conv2d(64, 128, 3×3)\n", + "\n", + "**SVD** (2 layers) — decomposes output channels:\n", + "```\n", + "Conv2d(64, R, 3×3) → Conv2d(R, 128, 1×1)\n", + "```\n", "\n", - "Each eligible Conv2d layer (kernel > 1×1, not grouped) was replaced by a Sequential of 3 smaller convolutions:\n", + "**Tucker** (3 layers) — decomposes both channel dimensions:\n", + "```\n", + "Conv2d(64, R_in, 1×1) → Conv2d(R_in, R_out, 3×3) → Conv2d(R_out, 128, 1×1)\n", + "```\n", + "\n", + "**Spatial** (2 layers) — decomposes the kernel spatially:\n", + "```\n", + "Conv2d(64, 128×R, 3×1) → Conv2d(128×R, 128, 1×3, groups=128)\n", + "```\n", "\n", + "**CP** (4 layers) — decomposes channels AND spatial:\n", "```\n", - "Before: Conv2d(64, 128, 3×3) — 73,728 parameters\n", - "After: Conv2d(64, 32, 1×1) — 2,048 parameters\n", - " Conv2d(32, 64, 3×3) — 18,432 parameters\n", - " Conv2d(64, 128, 1×1) — 
8,192 parameters\n", - " 28,672 parameters (2.6x smaller)\n", + "Conv2d(64, R, 1×1) → Conv2d(R, R, 3×1, dw) → Conv2d(R, R, 1×3, dw) → Conv2d(R, 128, 1×1)\n", "```\n", "\n", - "1×1 convolutions and depthwise convolutions are skipped automatically." + "Each targets a different source of redundancy. Tucker is the best general-purpose choice; CP gives maximum compression but may need more fine-tuning." ] }, { @@ -160,7 +256,9 @@ "id": "accuracy-header", "metadata": {}, "source": [ - "## 4. Accuracy Before Fine-Tuning" + "## 4. Accuracy Impact (Before Fine-Tuning)\n", + "\n", + "Each method has a different reconstruction error — let's measure accuracy drop:" ] }, { @@ -168,10 +266,240 @@ "execution_count": null, "id": "validate", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Method Accuracy vs Baseline\n", + "-----------------------------------\n", + "original 87.8% \n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "svd 41.7% -46.0%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tucker 75.2% -12.5%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + 
"text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "spatial 67.1% -20.6%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cp 67.1% -20.6%\n" + ] + } + ], "source": [ - "new_learn = Learner(dls, new_model, metrics=accuracy)\n", - "new_learn.validate()" + "baseline = Learner(dls, learn.model, metrics=accuracy).validate()[1]\n", + "print(f\"{'Method':<10} {'Accuracy':>10} {'vs Baseline':>12}\")\n", + "print(\"-\" * 35)\n", + "print(f\"{'original':<10} {baseline*100:>9.1f}% {'':>12}\")\n", + "\n", + "for method in ['svd', 'tucker', 'spatial', 'cp']:\n", + " model_dec = decomposer.decompose(copy.deepcopy(learn.model), 0.5, method=method)\n", + " acc = Learner(dls, model_dec, metrics=accuracy).validate()[1]\n", + " print(f\"{method:<10} {acc*100:>9.1f}% {(acc-baseline)*100:>+11.1f}%\")" ] }, { @@ -179,55 +507,70 @@ "id": "finetune-note", "metadata": {}, "source": [ - "The accuracy drops after decomposition — Tucker decomposition is an **approximation**, not exact. Fine-tuning recovers most of the accuracy:\n", + "Fine-tuning recovers most of the accuracy:\n", "\n", "```python\n", - "new_learn.fit_one_cycle(3, 1e-4) # Fine-tune with small learning rate\n", + "new_learn = Learner(dls, model_dec, metrics=accuracy)\n", + "new_learn.fit_one_cycle(3, 1e-4)\n", "```" ] }, { "cell_type": "markdown", - "id": "params", + "id": "afd5b76f", "metadata": {}, "source": [ - "## 5. Controlling Compression\n", + "## 5. 
Activation-Aware Decomposition (FC_Decomposer)\n", "\n", - "The `percent_removed` parameter controls how much rank is removed per mode:\n", + "For **Linear layers**, passing calibration data improves decomposition by prioritizing channels the model actually uses (ASVD). This works well because SVD on a 2D matrix has exact scale/unscale.\n", "\n", - "| percent_removed | Rank Kept | Compression | Accuracy Impact |\n", - "|-----------------|-----------|-------------|-----------------|\n", - "| `0.0` | 100% | ~1x (near-exact) | Minimal |\n", - "| `0.3` | 70% | ~1.5-2x | Low |\n", - "| `0.5` | 50% | ~2-4x | Moderate |\n", - "| `0.7` | 30% | ~4-8x | Significant |\n", + "For **Conv2d layers**, activation-aware decomposition is still a research topic — the 4D tensor structure makes exact scaling harder. Use standard decomposition + fine-tuning for best results.\n", "\n", "```python\n", - "# Light compression — minimal accuracy loss\n", - "light = Conv_Decomposer().decompose(model, percent_removed=0.3)\n", + "from fasterai.misc.fc_decomposer import FC_Decomposer\n", "\n", - "# Heavy compression — needs fine-tuning\n", - "heavy = Conv_Decomposer().decompose(model, percent_removed=0.7)\n", + "# ASVD for Linear layers — pass calibration data\n", + "FC_Decomposer().decompose(model, 0.5, data=[calibration_batch])\n", "```" ] }, + { + "cell_type": "markdown", + "id": "params", + "metadata": {}, + "source": [ + "## 6. Auto Rank with `energy_threshold`\n", + "\n", + "Instead of guessing `percent_removed`, let the decomposer pick the right rank automatically:\n", + "\n", + "```python\n", + "# Keep 99% of singular value energy — minimal accuracy loss\n", + "Conv_Decomposer().decompose(model, energy_threshold=0.99)\n", + "\n", + "# Keep 90% — more aggressive compression\n", + "Conv_Decomposer().decompose(model, energy_threshold=0.90)\n", + "```\n", + "\n", + "`energy_threshold` and `percent_removed` are mutually exclusive. Higher threshold = less compression, better accuracy." 
+ ] + }, { "cell_type": "markdown", "id": "combining", "metadata": {}, "source": [ - "## 6. Combining with Other Techniques\n", + "## 7. Combining with Other Techniques\n", "\n", - "Tucker decomposition works well as a first step before other compressions:\n", + "Decomposition works well as a first step before other compressions:\n", "\n", "```python\n", "from fasterai.misc.all import Conv_Decomposer, BN_Folder\n", "\n", - "# 1. Fold BatchNorm into Conv layers\n", + "# 1. Fold BatchNorm\n", "model = BN_Folder().fold(model)\n", "\n", - "# 2. Decompose Conv layers\n", - "model = Conv_Decomposer().decompose(model, percent_removed=0.5)\n", + "# 2. Decompose Conv layers (Tucker by default)\n", + "model = Conv_Decomposer().decompose(model, 0.5)\n", "\n", "# 3. Fine-tune\n", "learn = Learner(dls, model, metrics=accuracy)\n", @@ -235,12 +578,7 @@ "\n", "# 4. Quantize for deployment\n", "from fasterai.quantize.quantizer import Quantizer\n", - "model = Quantizer(backend='x86', method='static').quantize(model, dls.valid)\n", - "```\n", - "\n", - "### Recommended ordering:\n", - "```\n", - "BN Fold → Tucker Decompose → Fine-tune → Prune → Quantize\n", + "model = Quantizer(backend='torchao', method='int8_weight_only').quantize(model)\n", "```" ] }, @@ -253,21 +591,19 @@ "\n", "## Summary\n", "\n", - "| Metric | ResNet-18 (50% removed) |\n", - "|--------|------------------------|\n", - "| Original Params | ~11.7M |\n", - "| Decomposed Params | ~5-7M |\n", - "| Compression | ~1.7-2.3x |\n", - "| Accuracy (before fine-tune) | Drops ~10-20% |\n", - "| Accuracy (after fine-tune) | Recovers to within 1-3% |\n", + "| Method | Layers | What it decomposes | Best for |\n", + "|--------|--------|-------------------|----------|\n", + "| `'tucker'` | 3 | Both channel dims | General purpose (default) |\n", + "| `'svd'` | 2 | Output channels | Moderate compression, less overhead |\n", + "| `'spatial'` | 2 | Kernel K×K → K×1 + 1×K | Small kernels (3×3, 5×5) |\n", + "| `'cp'` | 4 | 
Channels + spatial | Maximum compression |\n", "\n", "| Feature | Description |\n", "|---------|-------------|\n", - "| `Conv_Decomposer()` | Create a decomposer instance |\n", - "| `.decompose(model, percent_removed)` | Decompose all eligible Conv2d layers |\n", - "| Skips 1×1 convolutions | Already minimal — decomposition would increase params |\n", - "| Skips grouped convolutions | Tucker assumes standard convolution |\n", - "| Pure PyTorch | No external dependencies (no tensorly) |\n", + "| `Conv_Decomposer().decompose(model, 0.5)` | Tucker decomposition (default) |\n", + "| `method='svd'\\|'tucker'\\|'spatial'\\|'cp'` | Choose decomposition method |\n", + "| `energy_threshold=0.99` | Auto rank selection (keep 99% energy) |\n", + "| `layers=['layer1'], exclude=['conv1']` | Per-layer control |\n", "\n", "---\n", "\n", diff --git a/nbs/tutorials/misc/fc_decomposer.ipynb b/nbs/tutorials/misc/fc_decomposer.ipynb index 5a804d2..5cb63bc 100644 --- a/nbs/tutorials/misc/fc_decomposer.ipynb +++ b/nbs/tutorials/misc/fc_decomposer.ipynb @@ -469,11 +469,11 @@ "\n", "| Parameter | Default | Description |\n", "|-----------|---------|-------------|\n", - "| `rank_ratio` | `0.5` | Fraction of singular values to keep (0-1). Lower = more compression, more accuracy loss |\n", + "| `percent_removed` | `0.5` | Fraction of singular values to remove (0 to <1). Higher = more compression, more accuracy loss |\n", "\n", - "### Choosing rank_ratio\n", + "### Choosing percent_removed\n", "\n", - "| rank_ratio | Compression | Accuracy Impact |\n", + "| percent_removed | Compression | Accuracy Impact |\n", "|------------|-------------|-----------------|\n", "| `0.8` | Low | Minimal |\n", "| `0.5` | Medium | Moderate |\n", @@ -496,7 +496,7 @@ "learn.fit_one_cycle(5)\n", "\n", "# 2. Decompose FC layers\n", - "fc = FC_Decomposer(rank_ratio=0.5)\n", + "fc = FC_Decomposer(percent_removed=0.5)\n", "new_model = fc.decompose(learn.model)\n", "\n", "# 3. Fine-tune to recover accuracy\n",