Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions fasterai/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,26 +223,40 @@
'fasterai/misc/bn_folding.py')},
'fasterai.misc.conv_decomposer': { 'fasterai.misc.conv_decomposer.Conv_Decomposer': ( 'misc/conv_decomposer.html#conv_decomposer',
'fasterai/misc/conv_decomposer.py'),
'fasterai.misc.conv_decomposer.Conv_Decomposer.CP': ( 'misc/conv_decomposer.html#conv_decomposer.cp',
'fasterai/misc/conv_decomposer.py'),
'fasterai.misc.conv_decomposer.Conv_Decomposer.SVD': ( 'misc/conv_decomposer.html#conv_decomposer.svd',
'fasterai/misc/conv_decomposer.py'),
'fasterai.misc.conv_decomposer.Conv_Decomposer.Spatial': ( 'misc/conv_decomposer.html#conv_decomposer.spatial',
'fasterai/misc/conv_decomposer.py'),
'fasterai.misc.conv_decomposer.Conv_Decomposer.Tucker': ( 'misc/conv_decomposer.html#conv_decomposer.tucker',
'fasterai/misc/conv_decomposer.py'),
'fasterai.misc.conv_decomposer.Conv_Decomposer.__init__': ( 'misc/conv_decomposer.html#conv_decomposer.__init__',
'fasterai/misc/conv_decomposer.py'),
'fasterai.misc.conv_decomposer.Conv_Decomposer.decompose': ( 'misc/conv_decomposer.html#conv_decomposer.decompose',
'fasterai/misc/conv_decomposer.py'),
'fasterai.misc.conv_decomposer._mode_unfold': ( 'misc/conv_decomposer.html#_mode_unfold',
'fasterai/misc/conv_decomposer.py'),
'fasterai.misc.conv_decomposer._partial_tucker': ( 'misc/conv_decomposer.html#_partial_tucker',
'fasterai/misc/conv_decomposer.py'),
'fasterai.misc.conv_decomposer._unfold': ( 'misc/conv_decomposer.html#_unfold',
'fasterai/misc/conv_decomposer.py')},
'fasterai/misc/conv_decomposer.py')},
'fasterai.misc.cpu_optimizer': { 'fasterai.misc.cpu_optimizer.accelerate_model_for_cpu': ( 'misc/cpu_optimizer.html#accelerate_model_for_cpu',
'fasterai/misc/cpu_optimizer.py')},
'fasterai/misc/cpu_optimizer.py'),
'fasterai.misc.cpu_optimizer.optimize_for_cpu': ( 'misc/cpu_optimizer.html#optimize_for_cpu',
'fasterai/misc/cpu_optimizer.py')},
'fasterai.misc.fc_decomposer': { 'fasterai.misc.fc_decomposer.FC_Decomposer': ( 'misc/fc_decomposer.html#fc_decomposer',
'fasterai/misc/fc_decomposer.py'),
'fasterai.misc.fc_decomposer.FC_Decomposer.SVD': ( 'misc/fc_decomposer.html#fc_decomposer.svd',
'fasterai/misc/fc_decomposer.py'),
'fasterai.misc.fc_decomposer.FC_Decomposer.__init__': ( 'misc/fc_decomposer.html#fc_decomposer.__init__',
'fasterai/misc/fc_decomposer.py'),
'fasterai.misc.fc_decomposer.FC_Decomposer.decompose': ( 'misc/fc_decomposer.html#fc_decomposer.decompose',
'fasterai/misc/fc_decomposer.py')},
'fasterai/misc/fc_decomposer.py'),
'fasterai.misc.fc_decomposer._collect_activation_rms': ( 'misc/fc_decomposer.html#_collect_activation_rms',
'fasterai/misc/fc_decomposer.py'),
'fasterai.misc.fc_decomposer._rank_from_energy': ( 'misc/fc_decomposer.html#_rank_from_energy',
'fasterai/misc/fc_decomposer.py'),
'fasterai.misc.fc_decomposer._should_decompose': ( 'misc/fc_decomposer.html#_should_decompose',
'fasterai/misc/fc_decomposer.py')},
'fasterai.prune.all': {},
'fasterai.prune.prune_callback': { 'fasterai.prune.prune_callback.PruneCallback': ( 'prune/prune_callback.html#prunecallback',
'fasterai/prune/prune_callback.py'),
Expand Down
3 changes: 2 additions & 1 deletion fasterai/misc/all.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .bn_folding import *
from .fc_decomposer import *
from .conv_decomposer import *
from .conv_decomposer import *
from .cpu_optimizer import *
1 change: 0 additions & 1 deletion fasterai/misc/bn_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# %% ../../nbs/misc/bn_folding.ipynb #productive-preparation
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

# %% ../../nbs/misc/bn_folding.ipynb #83000749
Expand Down
200 changes: 161 additions & 39 deletions fasterai/misc/conv_decomposer.py
Original file line number Diff line number Diff line change
@@ -1,86 +1,208 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/misc/conv_decomposer.ipynb.

# %% auto #0
__all__ = ['Conv_Decomposer']
__all__ = ['VALID_METHODS', 'Conv_Decomposer']

# %% ../../nbs/misc/conv_decomposer.ipynb #imports
import torch
import torch.nn as nn
import copy
from einops import rearrange

# %% ../../nbs/misc/conv_decomposer.ipynb #conv-decomposer
def _unfold(tensor, mode):
"Unfold a tensor along a mode into a matrix"
return tensor.moveaxis(mode, 0).flatten(1)
from .fc_decomposer import _rank_from_energy, _should_decompose

def _partial_tucker(weight, ranks, n_iter=5):
def _mode_unfold(W, mode):
"Unfold a 4D tensor along a mode into a 2D matrix"
return rearrange(W, 'o i h w -> o (i h w)') if mode == 0 else rearrange(W, 'o i h w -> i (o h w)')

def _partial_tucker(weight, ranks, n_iter=10, tol=1e-4):
"Partial Tucker decomposition on modes [0, 1] via alternating SVD (HOOI)"
# Initialize factors from SVD of mode unfoldings
U0 = torch.linalg.svd(_unfold(weight, 0), full_matrices=False)[0][:, :ranks[0]]
U1 = torch.linalg.svd(_unfold(weight, 1), full_matrices=False)[0][:, :ranks[1]]
U0 = torch.linalg.svd(_mode_unfold(weight, 0), full_matrices=False)[0][:, :ranks[0]]
U1 = torch.linalg.svd(_mode_unfold(weight, 1), full_matrices=False)[0][:, :ranks[1]]

for _ in range(n_iter):
# Project out mode 0 using U0, then update U1
U0_prev, U1_prev = U0.clone(), U1.clone()
proj = torch.einsum('oihw, or -> rihw', weight, U0)
U1 = torch.linalg.svd(_unfold(proj, 1), full_matrices=False)[0][:, :ranks[1]]
# Project out mode 1 using U1, then update U0
U1 = torch.linalg.svd(_mode_unfold(proj, 1), full_matrices=False)[0][:, :ranks[1]]
proj = torch.einsum('oihw, is -> oshw', weight, U1)
U0 = torch.linalg.svd(_unfold(proj, 0), full_matrices=False)[0][:, :ranks[0]]
U0 = torch.linalg.svd(_mode_unfold(proj, 0), full_matrices=False)[0][:, :ranks[0]]
if (U0 - U0_prev).norm() + (U1 - U1_prev).norm() < tol: break

# Core = W ×₀ U0ᵀ ×₁ U1ᵀ
core = torch.einsum('oihw, or, is -> rshw', weight, U0, U1)
return core, [U0, U1]

VALID_METHODS = frozenset({'tucker', 'svd', 'spatial', 'cp'})

class Conv_Decomposer:
"Decompose Conv2d layers using Tucker decomposition to reduce parameters and FLOPs"
"Decompose Conv2d layers to reduce parameters and FLOPs"

def __init__(self): pass

def decompose(self,
model: nn.Module, # The model to decompose
percent_removed: float = 0.5, # Fraction of rank to remove per mode [0, 1)
model: nn.Module, # The model to decompose
percent_removed: float = 0.5, # Fraction of rank to remove [0, 1)
method: str = 'tucker', # 'tucker', 'svd', 'spatial', or 'cp'
energy_threshold: float | None = None, # Auto rank via energy retention (0-1)
layers: list[str] | None = None, # Layer names to decompose (None = all eligible)
exclude: list[str] | None = None, # Layer names to skip
n_iter: int = 10, # Max HOOI iterations (tucker only)
tol: float = 1e-4, # HOOI convergence tolerance (tucker only)
) -> nn.Module:
"Recursively decompose all eligible Conv2d layers in the model"
if not (0 <= percent_removed < 1):
"Decompose eligible Conv2d layers using the specified method."
if method not in VALID_METHODS:
raise ValueError(f"method must be one of {VALID_METHODS}, got {method!r}")
if energy_threshold is None and not (0 <= percent_removed < 1):
raise ValueError(f"percent_removed must be in range [0, 1), got {percent_removed}")
if energy_threshold is not None and not (0 < energy_threshold <= 1):
raise ValueError(f"energy_threshold must be in range (0, 1], got {energy_threshold}")

decompose_fn = {'tucker': self.Tucker, 'svd': self.SVD,
'spatial': self.Spatial, 'cp': self.CP}[method]

new_model = copy.deepcopy(model)
for name in list(new_model._modules):
module = new_model._modules[name]
if len(list(module._modules)) > 0:
new_model._modules[name] = self.decompose(module, percent_removed)
elif isinstance(module, nn.Conv2d) and module.groups == 1 and min(module.kernel_size) > 1:
new_model._modules[name] = self.Tucker(module, percent_removed)
for name, module in list(new_model.named_modules()):
if (isinstance(module, nn.Conv2d) and module.groups == 1
and min(module.kernel_size) > 1
and _should_decompose(name, layers, exclude)):
parent_name, _, child_name = name.rpartition('.')
parent = new_model.get_submodule(parent_name) if parent_name else new_model
if method == 'tucker':
replacement = decompose_fn(module, percent_removed, energy_threshold, n_iter, tol)
else:
replacement = decompose_fn(module, percent_removed, energy_threshold)
setattr(parent, child_name, replacement)
return new_model

def Tucker(self,
layer: nn.Conv2d, # The Conv2d layer to decompose
percent_removed: float, # Fraction of rank to remove per mode
def SVD(self,
layer: nn.Conv2d,
percent_removed: float = 0.5,
energy_threshold: float | None = None,
) -> nn.Sequential:
"SVD: 2 layers — spatial at reduced output rank + pointwise expansion"
W = layer.weight.data
C_out, C_in = W.shape[:2]
K = layer.kernel_size

W_2d = rearrange(W, 'o i h w -> o (i h w)')

U, S, Vh = torch.linalg.svd(W_2d, full_matrices=False)
R = _rank_from_energy(S, energy_threshold) if energy_threshold else max(1, int((1 - percent_removed) * min(C_out, C_in)))

W_first = torch.diag(S[:R]) @ Vh[:R]

first = nn.Conv2d(C_in, R, K, stride=layer.stride,
padding=layer.padding, dilation=layer.dilation, bias=False)
first.weight.data = rearrange(W_first, 'r (i h w) -> r i h w', i=C_in, h=K[0], w=K[1])

last = nn.Conv2d(R, C_out, 1, bias=layer.bias is not None)
last.weight.data = rearrange(U[:, :R], 'o r -> o r 1 1')
if layer.bias is not None: last.bias.data = layer.bias.data

return nn.Sequential(first, last)

def Spatial(self,
layer: nn.Conv2d,
percent_removed: float = 0.5,
energy_threshold: float | None = None,
) -> nn.Sequential:
"Perform Tucker decomposition on a single Conv2d layer"
"Spatial separable: 2 layers — K×1 vertical + 1×K horizontal (batched SVD)"
W = layer.weight.data
C_out, C_in = W.shape[:2]
Kh, Kw = layer.kernel_size

W_spatial = rearrange(W, 'o i h w -> (o i) h w')
U_all, S_all, Vh_all = torch.linalg.svd(W_spatial, full_matrices=False)
R = _rank_from_energy(S_all[0], energy_threshold) if energy_threshold else max(1, int((1 - percent_removed) * min(Kh, Kw)))

R_out = max(1, int((1 - percent_removed) * C_out))
R_in = max(1, int((1 - percent_removed) * C_in))
U_scaled = U_all[:, :, :R] * S_all[:, :R].unsqueeze(1).sqrt()
W_vert = rearrange(U_scaled, '(o i) h r -> (o r) i h 1', o=C_out, i=C_in)

Vh_scaled = S_all[:, :R].unsqueeze(2).sqrt() * Vh_all[:, :R, :]
Vh_by_out = rearrange(Vh_scaled, '(o i) r w -> o i r w', o=C_out)
W_horiz = rearrange(Vh_by_out.mean(dim=1), 'o r w -> o r 1 w')

vert = nn.Conv2d(C_in, C_out * R, (Kh, 1),
stride=(layer.stride[0], 1), padding=(layer.padding[0], 0), bias=False)
vert.weight.data = W_vert

horiz = nn.Conv2d(C_out * R, C_out, (1, Kw), groups=C_out,
stride=(1, layer.stride[1]), padding=(0, layer.padding[1]),
bias=layer.bias is not None)
horiz.weight.data = W_horiz
if layer.bias is not None: horiz.bias.data = layer.bias.data

return nn.Sequential(vert, horiz)

def CP(self,
layer: nn.Conv2d,
percent_removed: float = 0.5,
energy_threshold: float | None = None,
) -> nn.Sequential:
"CP: 4 layers — pointwise compress + depthwise vertical + depthwise horizontal + pointwise expand"
W = layer.weight.data
C_out, C_in = W.shape[:2]
Kh, Kw = layer.kernel_size

W_2d = rearrange(W, 'o i h w -> o (i h w)')
U, S, Vh = torch.linalg.svd(W_2d, full_matrices=False)
S0 = torch.linalg.svd(_mode_unfold(W, 0), full_matrices=False)[1]
R = _rank_from_energy(S0, energy_threshold) if energy_threshold else max(1, int((1 - percent_removed) * min(C_out, C_in)))

V_4d = rearrange(Vh[:R], 'r (i h w) -> r i h w', i=C_in, h=Kh, w=Kw)
spatial_avg = V_4d.mean(dim=1)
U_s, S_s, Vh_s = torch.linalg.svd(spatial_avg, full_matrices=False)

W_dw_v = rearrange(U_s[:, :, 0] * S_s[:, 0:1].sqrt(), 'r h -> r 1 h 1')
W_dw_h = rearrange(Vh_s[:, 0, :] * S_s[:, 0:1].sqrt(), 'r w -> r 1 1 w')
channel_norms = V_4d.pow(2).sum(dim=(2, 3)).sqrt()
W_pw_in = rearrange(channel_norms * S[:R].sqrt().unsqueeze(1), 'r i -> r i 1 1')
W_pw_out = rearrange(U[:, :R] * S[:R].sqrt().unsqueeze(0), 'o r -> o r 1 1')

pw_in = nn.Conv2d(C_in, R, 1, bias=False)
pw_in.weight.data = W_pw_in
dw_v = nn.Conv2d(R, R, (Kh, 1), groups=R, stride=(layer.stride[0], 1),
padding=(layer.padding[0], 0), bias=False)
dw_v.weight.data = W_dw_v
dw_h = nn.Conv2d(R, R, (1, Kw), groups=R, stride=(1, layer.stride[1]),
padding=(0, layer.padding[1]), bias=False)
dw_h.weight.data = W_dw_h
pw_out = nn.Conv2d(R, C_out, 1, bias=layer.bias is not None)
pw_out.weight.data = W_pw_out
if layer.bias is not None: pw_out.bias.data = layer.bias.data

return nn.Sequential(pw_in, dw_v, dw_h, pw_out)

def Tucker(self,
layer: nn.Conv2d,
percent_removed: float = 0.5,
energy_threshold: float | None = None,
n_iter: int = 10,
tol: float = 1e-4,
) -> nn.Sequential:
"Tucker: 3 layers — pointwise compress + spatial + pointwise expand"
W = layer.weight.data
C_out, C_in = W.shape[:2]

core, (U_out, U_in) = _partial_tucker(W, [R_out, R_in])
# core: (R_out, R_in, H, W), U_out: (C_out, R_out), U_in: (C_in, R_in)
if energy_threshold is not None:
S0 = torch.linalg.svd(_mode_unfold(W, 0), full_matrices=False)[1]
S1 = torch.linalg.svd(_mode_unfold(W, 1), full_matrices=False)[1]
R_out = _rank_from_energy(S0, energy_threshold)
R_in = _rank_from_energy(S1, energy_threshold)
else:
R_out = max(1, int((1 - percent_removed) * C_out))
R_in = max(1, int((1 - percent_removed) * C_in))
core, (U_out, U_in) = _partial_tucker(W, [R_out, R_in], n_iter=n_iter, tol=tol)

# 1. Pointwise input compression: (C_in → R_in)
first = nn.Conv2d(C_in, R_in, 1, bias=False)
first.weight.data = U_in.t().unsqueeze(-1).unsqueeze(-1)
first.weight.data = rearrange(U_in.t(), 'r i -> r i 1 1')

# 2. Spatial convolution at reduced rank: (R_in → R_out)
middle = nn.Conv2d(R_in, R_out, layer.kernel_size, stride=layer.stride,
padding=layer.padding, dilation=layer.dilation, bias=False)
middle.weight.data = core

# 3. Pointwise output expansion: (R_out → C_out)
last = nn.Conv2d(R_out, C_out, 1, bias=layer.bias is not None)
last.weight.data = U_out.unsqueeze(-1).unsqueeze(-1)
if layer.bias is not None:
last.bias.data = layer.bias.data
last.weight.data = rearrange(U_out, 'o r -> o r 1 1')
if layer.bias is not None: last.bias.data = layer.bias.data

return nn.Sequential(first, middle, last)
Loading
Loading