diff --git a/fasterai/_modidx.py b/fasterai/_modidx.py index 84400fa..bf06716 100644 --- a/fasterai/_modidx.py +++ b/fasterai/_modidx.py @@ -208,6 +208,34 @@ 'fasterai/export/onnx_exporter.py'), 'fasterai.export.onnx_exporter.verify_onnx': ( 'export/onnx_exporter.html#verify_onnx', 'fasterai/export/onnx_exporter.py')}, + 'fasterai.huggingface.huggingface': { 'fasterai.huggingface.huggingface.HFSparsifyCallback': ( 'huggingface/huggingface.html#hfsparsifycallback', + 'fasterai/huggingface/huggingface.py'), + 'fasterai.huggingface.huggingface.HFSparsifyCallback.__init__': ( 'huggingface/huggingface.html#hfsparsifycallback.__init__', + 'fasterai/huggingface/huggingface.py'), + 'fasterai.huggingface.huggingface.HFSparsifyCallback._sparsity_value': ( 'huggingface/huggingface.html#hfsparsifycallback._sparsity_value', + 'fasterai/huggingface/huggingface.py'), + 'fasterai.huggingface.huggingface.HFSparsifyCallback.on_epoch_end': ( 'huggingface/huggingface.html#hfsparsifycallback.on_epoch_end', + 'fasterai/huggingface/huggingface.py'), + 'fasterai.huggingface.huggingface.HFSparsifyCallback.on_log': ( 'huggingface/huggingface.html#hfsparsifycallback.on_log', + 'fasterai/huggingface/huggingface.py'), + 'fasterai.huggingface.huggingface.HFSparsifyCallback.on_optimizer_step': ( 'huggingface/huggingface.html#hfsparsifycallback.on_optimizer_step', + 'fasterai/huggingface/huggingface.py'), + 'fasterai.huggingface.huggingface.HFSparsifyCallback.on_step_begin': ( 'huggingface/huggingface.html#hfsparsifycallback.on_step_begin', + 'fasterai/huggingface/huggingface.py'), + 'fasterai.huggingface.huggingface.HFSparsifyCallback.on_train_begin': ( 'huggingface/huggingface.html#hfsparsifycallback.on_train_begin', + 'fasterai/huggingface/huggingface.py'), + 'fasterai.huggingface.huggingface.HFSparsifyCallback.on_train_end': ( 'huggingface/huggingface.html#hfsparsifycallback.on_train_end', + 'fasterai/huggingface/huggingface.py'), + 'fasterai.huggingface.huggingface._has_transformers': ( 
'huggingface/huggingface.html#_has_transformers', + 'fasterai/huggingface/huggingface.py'), + 'fasterai.huggingface.huggingface._load_model': ( 'huggingface/huggingface.html#_load_model', + 'fasterai/huggingface/huggingface.py'), + 'fasterai.huggingface.huggingface._require_transformers': ( 'huggingface/huggingface.html#_require_transformers', + 'fasterai/huggingface/huggingface.py'), + 'fasterai.huggingface.huggingface._save_compressed': ( 'huggingface/huggingface.html#_save_compressed', + 'fasterai/huggingface/huggingface.py'), + 'fasterai.huggingface.huggingface.sparsify_model': ( 'huggingface/huggingface.html#sparsify_model', + 'fasterai/huggingface/huggingface.py')}, 'fasterai.misc.all': {}, 'fasterai.misc.bn_folding': { 'fasterai.misc.bn_folding.BN_Folder': ( 'misc/bn_folding.html#bn_folder', 'fasterai/misc/bn_folding.py'), @@ -259,12 +287,18 @@ 'fasterai/prune/pruner.py'), 'fasterai.prune.pruner.Pruner._build_pruning_schedule': ( 'prune/pruner.html#pruner._build_pruning_schedule', 'fasterai/prune/pruner.py'), + 'fasterai.prune.pruner.Pruner._detect_attention_heads': ( 'prune/pruner.html#pruner._detect_attention_heads', + 'fasterai/prune/pruner.py'), + 'fasterai.prune.pruner.Pruner._freeze_head_pruning': ( 'prune/pruner.html#pruner._freeze_head_pruning', + 'fasterai/prune/pruner.py'), + 'fasterai.prune.pruner.Pruner._patch_attention_forward': ( 'prune/pruner.html#pruner._patch_attention_forward', + 'fasterai/prune/pruner.py'), 'fasterai.prune.pruner.Pruner._resolve_pruning_ratio_dict': ( 'prune/pruner.html#pruner._resolve_pruning_ratio_dict', 'fasterai/prune/pruner.py'), + 'fasterai.prune.pruner.Pruner._sync_attention_attrs': ( 'prune/pruner.html#pruner._sync_attention_attrs', + 'fasterai/prune/pruner.py'), 'fasterai.prune.pruner.Pruner._to_tp_scheduler': ( 'prune/pruner.html#pruner._to_tp_scheduler', 'fasterai/prune/pruner.py'), - 'fasterai.prune.pruner.Pruner.get_attention_layers_to_ignore': ( 'prune/pruner.html#pruner.get_attention_layers_to_ignore', - 
'fasterai/prune/pruner.py'), 'fasterai.prune.pruner.Pruner.get_ignored_layers': ( 'prune/pruner.html#pruner.get_ignored_layers', 'fasterai/prune/pruner.py'), 'fasterai.prune.pruner.Pruner.get_linear_layers_to_ignore': ( 'prune/pruner.html#pruner.get_linear_layers_to_ignore', @@ -275,8 +309,8 @@ 'fasterai/prune/pruner.py'), 'fasterai.prune.pruner.Pruner.prune_model': ( 'prune/pruner.html#pruner.prune_model', 'fasterai/prune/pruner.py'), - 'fasterai.prune.pruner.Pruner.restore_attention_layers': ( 'prune/pruner.html#pruner.restore_attention_layers', - 'fasterai/prune/pruner.py')}, + 'fasterai.prune.pruner._pruning_compatible_forward': ( 'prune/pruner.html#_pruning_compatible_forward', + 'fasterai/prune/pruner.py')}, 'fasterai.quantize.all': {}, 'fasterai.quantize.quantize_callback': { 'fasterai.quantize.quantize_callback.QuantizeCallback': ( 'quantize/quantize_callback.html#quantizecallback', 'fasterai/quantize/quantize_callback.py'), diff --git a/fasterai/prune/prune_callback.py b/fasterai/prune/prune_callback.py index f2fb573..1ea720d 100644 --- a/fasterai/prune/prune_callback.py +++ b/fasterai/prune/prune_callback.py @@ -53,6 +53,7 @@ def before_fit(self) -> None: criteria=self.criteria, pruning_ratio=self.pruning_ratio, context=self.context, + example_inputs=self.example_inputs, iterative_steps=total_training_steps, schedule=pruning_schedule, *self.extra_args, diff --git a/fasterai/prune/pruner.py b/fasterai/prune/pruner.py index 3999ab6..7aa4abd 100644 --- a/fasterai/prune/pruner.py +++ b/fasterai/prune/pruner.py @@ -23,18 +23,61 @@ # %% ../../nbs/prune/pruner.ipynb #63acddeb-f30e-448b-a397-d4cac2adba7a from ..core.schedule import Schedule +def _pruning_compatible_forward(self, x, attn_mask=None): + "Attention forward using reshape(B, N, -1) for pruning compatibility" + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + if 
self.fused_attn: + x = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, + dropout_p=self.attn_drop.p if self.training else 0.) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + if attn_mask is not None: attn = attn + attn_mask + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + x = x.transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + class Pruner(): "Structured pruning for neural networks using torch_pruning" - def __init__(self, model, pruning_ratio, context, criteria, schedule=linear_scheduler, ignored_layers=None, example_inputs=torch.randn(1, 3, 224, 224), *args, **kwargs): + def __init__(self, + model, # The PyTorch model to prune + pruning_ratio, # Channel pruning ratio (float 0-1, int 0-100, or dict) + context, # 'local' or 'global' + criteria, # Importance criteria (e.g. large_final) + schedule=linear_scheduler,# Pruning schedule + ignored_layers=None, # Layers to skip during pruning + example_inputs=torch.randn(1, 3, 224, 224), # Dummy input for tracing + head_pruning_ratio=0.0, # Ratio of attention heads to remove (0-1 or 0-100) + prune_num_heads=False, # Remove entire attention heads + prune_head_dims=True, # Reduce head dimensions + *args, **kwargs, + ): store_attr() self.num_heads = {} + self._original_num_heads = {} + self._heads_pruned = False self._original_params = sum(p.numel() for p in model.parameters()) + + # Normalize head_pruning_ratio + if self.head_pruning_ratio > 1: self.head_pruning_ratio = self.head_pruning_ratio / 100 + # Auto-enable: head_pruning_ratio > 0 implies prune whole heads, not dims (XOR pattern) + if self.head_pruning_ratio > 0: + self.prune_num_heads = True + self.prune_head_dims = False + if not self.ignored_layers: self.get_ignored_layers(self.model) # Handle pruning_ratio as float or dict self.pruning_ratio_dict = None if isinstance(self.pruning_ratio, dict): - # Convert name-based dict to module-based dict for torch-pruning 
self.pruning_ratio_dict = self._resolve_pruning_ratio_dict(self.pruning_ratio) self.default_pruning_ratio = kwargs.pop('default_pruning_ratio', 0.0) print(f"Using per-layer pruning with {len(self.pruning_ratio_dict)} layer-specific ratios") @@ -56,6 +99,9 @@ def __init__(self, model, pruning_ratio, context, criteria, schedule=linear_sche ignored_layers=self.ignored_layers, global_pruning=True if self.context=='global' else False, num_heads=self.num_heads, + prune_num_heads=self.prune_num_heads, + prune_head_dims=self.prune_head_dims, + head_pruning_ratio=self.head_pruning_ratio, iterative_pruning_ratio_scheduler=tp_schedule, *args, **kwargs @@ -72,10 +118,8 @@ def scheduler(pruning_ratio, steps, start=0, end=1): def _to_tp_scheduler(self, schedule): "Convert Schedule object or callable to torch-pruning compatible scheduler" - # If it's a Schedule object, extract sched_func and build compatible function if isinstance(schedule, Schedule): return self._build_pruning_schedule(schedule.sched_func) - # Otherwise assume it's already a compatible callable (like linear_scheduler) return schedule def _resolve_pruning_ratio_dict(self, ratio_dict): @@ -86,30 +130,54 @@ def _resolve_pruning_ratio_dict(self, ratio_dict): if isinstance(key, str): if key in name_to_module: module = name_to_module[key] - # Normalize ratio to 0-1 range resolved[module] = ratio / 100 if ratio > 1 else ratio else: print(f"Warning: Layer '{key}' not found in model, skipping") elif isinstance(key, nn.Module): resolved[key] = ratio / 100 if ratio > 1 else ratio return resolved - - def prune_model(self): - "Execute one pruning step and restore attention layer configurations" - self.pruner.step() - self.restore_attention_layers() + def _patch_attention_forward(self, + module: nn.Module # Attention module with .qkv + ): + "Patch attention forward to use reshape(B,N,-1) for pruning compatibility" + # Only patch if the module has the timm Attention interface + if not all(hasattr(module, a) for a in ('qkv', 
'proj', 'proj_drop', 'attn_drop', + 'q_norm', 'k_norm', 'fused_attn', 'scale')): + return + module.forward = _pruning_compatible_forward.__get__(module, type(module)) - def get_linear_layers_to_ignore(self, + def _detect_attention_heads(self, + model: nn.Module # The model to analyze + ): + "Detect attention layers with QKV projections and populate num_heads mapping" + for module in model.modules(): + # nn.MultiheadAttention uses raw parameters (in_proj_weight), not Linear + # submodules — torch-pruning's head pruning requires .out_features on qkv + # layers, so only timm-style attention (with .qkv Linear) is supported. + if isinstance(module, nn.MultiheadAttention): + self.ignored_layers.append(module) + continue + if hasattr(module, 'num_heads'): + if hasattr(module, 'qkv'): + self.num_heads[module.qkv] = module.num_heads + # Patch forward to use reshape(B,N,-1) — required for head pruning + # (official torch-pruning pattern from prune_timm_vit.py) + if self.prune_num_heads: self._patch_attention_forward(module) + elif hasattr(module, 'qkv_proj'): + self.num_heads[module.qkv_proj] = module.num_heads + self._original_num_heads = dict(self.num_heads) + + def get_linear_layers_to_ignore(self, model: nn.Module # The model to analyze ): "Find and ignore output Linear layers to preserve model output dimensions" try: traced = symbolic_trace(model) for node in traced.graph.nodes: - if node.op == "output": # Identify the output + if node.op == "output": for input_node in node.all_input_nodes: - if input_node.target: # Find the corresponding layer + if input_node.target: module = dict(model.named_modules()).get(input_node.target) if isinstance(module, torch.nn.Linear): self.ignored_layers.append(module) @@ -117,44 +185,55 @@ def get_linear_layers_to_ignore(self, except Exception as e: print(f"Could not trace model for output layer detection: {e}") - - def get_attention_layers_to_ignore(self, - model: nn.Module # The model to analyze - ): - "Find and ignore attention layers 
(qkv projections) to preserve attention structure" - for module in model.modules(): - if hasattr(module, 'num_heads'): - if hasattr(module, 'qkv'): - self.ignored_layers.append(module.qkv) - self.num_heads[module.qkv] = module.num_heads - print(f"Attention layer ignored: {module.qkv}, num_heads={module.num_heads}") - elif hasattr(module, 'qkv_proj'): - self.ignored_layers.append(module.qkv_proj) - self.num_heads[module.qkv_proj] = module.num_heads - print(f"Attention layer ignored: {module.qkv_proj}, num_heads={module.num_heads}") - - - def get_ignored_layers(self, + def get_ignored_layers(self, model: nn.Module # The model to analyze ): "Build list of layers to ignore during pruning" self.ignored_layers = [] self.get_linear_layers_to_ignore(model) - self.get_attention_layers_to_ignore(model) + self._detect_attention_heads(model) + # Ignore QKV layers only when head pruning is disabled + if not self.prune_num_heads and self.head_pruning_ratio == 0: + for layer in self.num_heads: + self.ignored_layers.append(layer) + if self.num_heads: + action = "will be pruned" if self.prune_num_heads else "ignored" + print(f"Detected {len(self.num_heads)} attention layer(s) ({action})") print(f"Total ignored layers: {len(self.ignored_layers)}") - - - def restore_attention_layers(self): - "Restore num_heads and head_dim attributes after pruning attention layers" - for m in self.model.modules(): - if hasattr(m, 'num_heads'): - if hasattr(m, 'qkv'): - m.num_heads = self.num_heads[m.qkv] - m.head_dim = m.qkv.out_features // (3 * m.num_heads) - elif hasattr(m, 'qkv_proj'): - m.num_heads = self.num_heads[m.qkv_proj] - m.head_dim = m.qkv_proj.out_features // (3 * m.num_heads) + def _freeze_head_pruning(self): + "Disable further head pruning after first application to prevent over-pruning" + # torch-pruning computes head removal from current count (not original), + # so repeated steps would halve heads each time (6→3→1→0). 
We freeze + # the head pruning ratios after the first step that modifies heads. + if not self._heads_pruned and self.prune_num_heads: + heads_changed = any( + self.pruner.num_heads.get(layer, orig) != orig + for layer, orig in self._original_num_heads.items() + ) + if heads_changed: + self._heads_pruned = True + for i in range(len(self.pruner.per_step_head_pruning_ratio)): + if i > self.pruner.current_step: + self.pruner.per_step_head_pruning_ratio[i] = 0.0 + + def _sync_attention_attrs(self): + "Sync attention module attributes with torch-pruning's updated head counts" + for module in self.model.modules(): + if not hasattr(module, 'num_heads') or isinstance(module, nn.MultiheadAttention): + continue + if hasattr(module, 'qkv') and module.qkv in self.pruner.num_heads: + module.num_heads = self.pruner.num_heads[module.qkv] + module.head_dim = module.qkv.out_features // (3 * module.num_heads) + elif hasattr(module, 'qkv_proj') and module.qkv_proj in self.pruner.num_heads: + module.num_heads = self.pruner.num_heads[module.qkv_proj] + module.head_dim = module.qkv_proj.out_features // (3 * module.num_heads) + + def prune_model(self): + "Execute one pruning step and sync attention layer attributes" + self.pruner.step() + self._sync_attention_attrs() + self._freeze_head_pruning() def group_importance(self, group): "Compute importance scores for a dependency group" @@ -164,55 +243,65 @@ def group_importance(self, group): function.prune_linear_in_channels: 'column', function.prune_conv_in_channels: 'shared_kernel', } - + group_imp = [] group_idxs = [] - + for i, (dep, idxs) in enumerate(group): if dep.handler in handler_map: impo = self.criteria(dep.target.module, handler_map.get(dep.handler), squeeze=True) group_imp.append(impo) group_idxs.append(group[i].root_idxs) - + if len(group_imp) == 0: return torch.tensor([]) - + reduced_imp = torch.zeros_like(group_imp[0]) - + for i, (imp, root_idxs) in enumerate(zip(group_imp, group_idxs)): imp = imp.to('cpu') reduced_imp = 
reduced_imp.to('cpu') reduced_imp.scatter_add_(0, torch.tensor(root_idxs, device=imp.device), imp) - + reduced_imp /= len(group_imp) - + return reduced_imp.to(default_device()) def print_sparsity(self) -> None: "Print pruning report showing channel counts and parameter reduction" total_params = 0 - + print("\nPruning Report:") print("-" * 85) print(f"{'Layer':<35} {'Type':<12} {'In Ch':<8} {'Out Ch':<8} {'Params':<12}") print("-" * 85) - + for name, m in self.model.named_modules(): if isinstance(m, (nn.Conv2d, nn.Linear)): params = sum(p.numel() for p in m.parameters()) total_params += params - + if isinstance(m, nn.Conv2d): in_ch, out_ch = m.in_channels, m.out_channels layer_type = "Conv2d" else: in_ch, out_ch = m.in_features, m.out_features layer_type = "Linear" - + print(f"{name:<35} {layer_type:<12} {in_ch:<8} {out_ch:<8} {params:<12,}") - + print("-" * 85) reduction = 100 * (1 - total_params / self._original_params) if self._original_params > 0 else 0 print(f"{'Total':<35} {'':<12} {'':<8} {'':<8} {total_params:<12,}") print(f"{'Original':<35} {'':<12} {'':<8} {'':<8} {self._original_params:<12,}") print(f"{'Reduction':<35} {'':<12} {'':<8} {'':<8} {reduction:>10.2f}%") + + # Head count changes + if self._original_num_heads: + print(f"\n{'Attention Heads':}") + print("-" * 50) + for layer, orig in self._original_num_heads.items(): + current = self.pruner.num_heads.get(layer, orig) + name = next((n for n, m in self.model.named_modules() if m is layer), str(layer)) + status = f"{orig} -> {current}" if current != orig else f"{orig} (unchanged)" + print(f" {name:<33} {status}") diff --git a/nbs/prune/prune_callback.ipynb b/nbs/prune/prune_callback.ipynb index e377026..0a900db 100644 --- a/nbs/prune/prune_callback.ipynb +++ b/nbs/prune/prune_callback.ipynb @@ -74,65 +74,7 @@ "id": "50598138-7d55-4774-b711-114c1c42dce8", "metadata": {}, "outputs": [], - "source": [ - "#| export\n", - "class PruneCallback(Callback):\n", - " def __init__(self, pruning_ratio, schedule, 
context, criteria, *args, **kwargs):\n", - " store_attr()\n", - " self.sparsity_levels = []\n", - " self.extra_args = args\n", - " self.extra_kwargs = kwargs\n", - "\n", - " def _build_pruning_schedule(self, sched_func):\n", - " \"Create a schedule function compatible with torch-pruning's Pruner\"\n", - " start_val, end_val = self.schedule.start_val, self.schedule.end_val\n", - " def scheduler(pruning_ratio, steps, start=start_val, end=end_val):\n", - " return [\n", - " sched_func(start, end, i / float(steps)) * pruning_ratio\n", - " for i in range(steps + 1)\n", - " ]\n", - " return scheduler\n", - "\n", - " def before_fit(self) -> None:\n", - " \"Setup pruner before training\"\n", - " n_batches_per_epoch = len(self.learn.dls.train)\n", - " total_training_steps = n_batches_per_epoch * self.learn.n_epoch\n", - " self.pruning_ratio = self.pruning_ratio/100 if self.pruning_ratio>1 else self.pruning_ratio\n", - " \n", - " # Validate pruning_ratio is in valid range\n", - " if not (0 < self.pruning_ratio <= 1):\n", - " raise ValueError(f\"pruning_ratio must be in range (0, 1], got {self.pruning_ratio}\")\n", - "\n", - " self.example_inputs, _ = self.learn.dls.one_batch()\n", - " \n", - " # Build schedule function for torch-pruning compatibility\n", - " pruning_schedule = self._build_pruning_schedule(self.schedule.sched_func)\n", - " self.sparsity_levels = pruning_schedule(self.pruning_ratio, total_training_steps)\n", - " \n", - " self.pruner = Pruner(\n", - " self.learn.model,\n", - " criteria=self.criteria,\n", - " pruning_ratio=self.pruning_ratio, \n", - " context=self.context,\n", - " iterative_steps=total_training_steps, \n", - " schedule=pruning_schedule,\n", - " *self.extra_args, \n", - " **self.extra_kwargs\n", - " )\n", - " \n", - " def before_step(self) -> None:\n", - " \"Apply pruning before optimizer step\"\n", - " if self.training: \n", - " self.pruner.prune_model()\n", - "\n", - " def after_epoch(self) -> None:\n", - " \"Log sparsity after each epoch\"\n", 
- " completed_steps = (self.epoch + 1) * len(self.learn.dls.train)\n", - " # Bounds check for sparsity_levels access\n", - " if completed_steps > 0 and completed_steps <= len(self.sparsity_levels):\n", - " current_sparsity = self.sparsity_levels[completed_steps - 1]\n", - " print(f'Sparsity at the end of epoch {self.epoch}: {current_sparsity*100:.2f}%')" - ] + "source": "#| export\nclass PruneCallback(Callback):\n def __init__(self, pruning_ratio, schedule, context, criteria, *args, **kwargs):\n store_attr()\n self.sparsity_levels = []\n self.extra_args = args\n self.extra_kwargs = kwargs\n\n def _build_pruning_schedule(self, sched_func):\n \"Create a schedule function compatible with torch-pruning's Pruner\"\n start_val, end_val = self.schedule.start_val, self.schedule.end_val\n def scheduler(pruning_ratio, steps, start=start_val, end=end_val):\n return [\n sched_func(start, end, i / float(steps)) * pruning_ratio\n for i in range(steps + 1)\n ]\n return scheduler\n\n def before_fit(self) -> None:\n \"Setup pruner before training\"\n n_batches_per_epoch = len(self.learn.dls.train)\n total_training_steps = n_batches_per_epoch * self.learn.n_epoch\n self.pruning_ratio = self.pruning_ratio/100 if self.pruning_ratio>1 else self.pruning_ratio\n \n # Validate pruning_ratio is in valid range\n if not (0 < self.pruning_ratio <= 1):\n raise ValueError(f\"pruning_ratio must be in range (0, 1], got {self.pruning_ratio}\")\n\n self.example_inputs, _ = self.learn.dls.one_batch()\n \n # Build schedule function for torch-pruning compatibility\n pruning_schedule = self._build_pruning_schedule(self.schedule.sched_func)\n self.sparsity_levels = pruning_schedule(self.pruning_ratio, total_training_steps)\n \n self.pruner = Pruner(\n self.learn.model,\n criteria=self.criteria,\n pruning_ratio=self.pruning_ratio, \n context=self.context,\n example_inputs=self.example_inputs,\n iterative_steps=total_training_steps, \n schedule=pruning_schedule,\n *self.extra_args, \n 
**self.extra_kwargs\n )\n \n def before_step(self) -> None:\n \"Apply pruning before optimizer step\"\n if self.training: \n self.pruner.prune_model()\n\n def after_epoch(self) -> None:\n \"Log sparsity after each epoch\"\n completed_steps = (self.epoch + 1) * len(self.learn.dls.train)\n # Bounds check for sparsity_levels access\n if completed_steps > 0 and completed_steps <= len(self.sparsity_levels):\n current_sparsity = self.sparsity_levels[completed_steps - 1]\n print(f'Sparsity at the end of epoch {self.epoch}: {current_sparsity*100:.2f}%')" }, { "cell_type": "code", @@ -374,6 +316,14 @@ "assert _params_after < _params_before, f\"Expected params to decrease: {_params_before} → {_params_after}\"" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "ab7a1d2b", + "metadata": {}, + "outputs": [], + "source": "#| hide\n#| slow\n# Transformer head pruning — verify Pruner directly on a ViT-like model\n# Note: PruneCallback integration with ViTs requires models registered with\n# torch-pruning (e.g. timm ViT). 
This test verifies the core head pruning logic.\nimport torch, torch.nn as nn\n\nclass _Attention(nn.Module):\n def __init__(self, dim, num_heads):\n super().__init__()\n self.num_heads = num_heads\n self.head_dim = dim // num_heads\n self.qkv = nn.Linear(dim, dim * 3)\n self.proj = nn.Linear(dim, dim)\n def forward(self, x):\n B, N, _ = x.shape\n qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)\n q, k, v = qkv.unbind(0)\n attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)\n x = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, -1)\n return self.proj(x)\n\nclass _Block(nn.Module):\n def __init__(self, dim, num_heads):\n super().__init__()\n self.norm1 = nn.LayerNorm(dim)\n self.attn = _Attention(dim, num_heads)\n self.norm2 = nn.LayerNorm(dim)\n self.mlp = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))\n def forward(self, x):\n x = x + self.attn(self.norm1(x))\n x = x + self.mlp(self.norm2(x))\n return x\n\nclass _SimpleViT(nn.Module):\n def __init__(self, dim=64, num_heads=4, depth=2, num_classes=10):\n super().__init__()\n self.blocks = nn.Sequential(*[_Block(dim, num_heads) for _ in range(depth)])\n self.norm = nn.LayerNorm(dim)\n self.head = nn.Linear(dim, num_classes)\n def forward(self, x):\n x = self.blocks(x)\n return self.head(self.norm(x).mean(dim=1))\n\nfrom fasterai.prune.pruner import Pruner\nfrom fasterai.core.criteria import large_final\n\n_model = _SimpleViT(dim=128, num_heads=8, depth=4)\n_x = torch.randn(1, 16, 128)\n_params_before = sum(p.numel() for p in _model.parameters())\n\n_pruner = Pruner(_model, pruning_ratio=0.3, context='local', criteria=large_final,\n example_inputs=_x, head_pruning_ratio=0.5)\n\n# Verify 4 attention layers detected (4 blocks × 1 each)\nassert len(_pruner.num_heads) == 4\nassert _pruner.prune_num_heads == True\n\n# Apply head pruning\n_pruner.prune_model()\n\n_params_after = sum(p.numel() for p in _model.parameters())\nassert _params_after < 
_params_before, f\"Expected params to decrease: {_params_before} -> {_params_after}\"\n\n# Verify heads actually removed across all layers\nfor block in _model.blocks:\n assert block.attn.num_heads < 8, f\"Expected heads < 8, got {block.attn.num_heads}\"\n assert block.attn.qkv.out_features == 3 * block.attn.num_heads * block.attn.head_dim\n\n# Print pruning report\n_pruner.print_sparsity()" + }, { "cell_type": "markdown", "id": "8hx4b80psx", diff --git a/nbs/prune/pruner.ipynb b/nbs/prune/pruner.ipynb index cd496ce..402f3bc 100644 --- a/nbs/prune/pruner.ipynb +++ b/nbs/prune/pruner.ipynb @@ -63,27 +63,7 @@ "cell_type": "markdown", "id": "oqm0as2fn3k", "metadata": {}, - "source": [ - "## Overview\n", - "\n", - "The `Pruner` class provides structured pruning capabilities using the [torch-pruning](https://github.com/VainF/Torch-Pruning) library. Unlike unstructured pruning (which zeros individual weights), structured pruning removes entire filters/channels, resulting in a genuinely smaller and faster model.\n", - "\n", - "**Key Features:**\n", - "- Automatic dependency handling across layers\n", - "- Support for both local (per-layer) and global (cross-layer) pruning\n", - "- Automatic detection and handling of attention layers in transformers\n", - "- Compatible with various importance criteria from `fasterai.core.criteria`\n", - "\n", - "### Sparsifier vs Pruner: When to Use Which?\n", - "\n", - "| Aspect | Sparsifier | Pruner |\n", - "|--------|------------|--------|\n", - "| **What it removes** | Individual weights (unstructured) | Entire filters/channels (structured) |\n", - "| **Model size** | Same architecture, sparse weights | Smaller architecture |\n", - "| **Speedup** | Requires sparse hardware/libraries | Immediate speedup on any hardware |\n", - "| **Accuracy impact** | Generally lower at same sparsity | May need fine-tuning |\n", - "| **Best for** | Research, sparse-aware inference | Production deployment |" - ] + "source": "## Overview\n\nThe `Pruner` 
class provides structured pruning capabilities using the [torch-pruning](https://github.com/VainF/Torch-Pruning) library. Unlike unstructured pruning (which zeros individual weights), structured pruning removes entire filters/channels, resulting in a genuinely smaller and faster model.\n\n**Key Features:**\n- Automatic dependency handling across layers\n- Support for both local (per-layer) and global (cross-layer) pruning\n- **Transformer head pruning** — remove entire attention heads via `head_pruning_ratio`\n- Automatic detection of attention layers (timm-style `qkv`/`qkv_proj` projections are tracked for head pruning; `nn.MultiheadAttention` modules are detected but left unpruned)\n- Compatible with various importance criteria from `fasterai.core.criteria`\n\n### Sparsifier vs Pruner: When to Use Which?\n\n| Aspect | Sparsifier | Pruner |\n|--------|------------|--------|\n| **What it removes** | Individual weights (unstructured) | Entire filters/channels/heads (structured) |\n| **Model size** | Same architecture, sparse weights | Smaller architecture |\n| **Speedup** | Requires sparse hardware/libraries | Immediate speedup on any hardware |\n| **Accuracy impact** | Generally lower at same sparsity | May need fine-tuning |\n| **Transformer support** | Weight-level sparsity | Head pruning (`head_pruning_ratio`) |\n| **Best for** | Research, sparse-aware inference | Production deployment |" }, { "cell_type": "code", @@ -91,204 +71,7 @@ "id": "63acddeb-f30e-448b-a397-d4cac2adba7a", "metadata": {}, "outputs": [], - "source": [ - "#| export\n", - "from fasterai.core.schedule import Schedule\n", - "\n", - "class Pruner():\n", - " \"Structured pruning for neural networks using torch_pruning\"\n", - " def __init__(self, model, pruning_ratio, context, criteria, schedule=linear_scheduler, ignored_layers=None, example_inputs=torch.randn(1, 3, 224, 224), *args, **kwargs):\n", - " store_attr()\n", - " self.num_heads = {}\n", - " self._original_params = sum(p.numel() for p in model.parameters())\n", - " if not self.ignored_layers: 
self.get_ignored_layers(self.model)\n", - "\n", - " # Handle pruning_ratio as float or dict\n", - " self.pruning_ratio_dict = None\n", - " if isinstance(self.pruning_ratio, dict):\n", - " # Convert name-based dict to module-based dict for torch-pruning\n", - " self.pruning_ratio_dict = self._resolve_pruning_ratio_dict(self.pruning_ratio)\n", - " self.default_pruning_ratio = kwargs.pop('default_pruning_ratio', 0.0)\n", - " print(f\"Using per-layer pruning with {len(self.pruning_ratio_dict)} layer-specific ratios\")\n", - " else:\n", - " if self.pruning_ratio > 1: self.pruning_ratio = self.pruning_ratio / 100\n", - " if not (0 < self.pruning_ratio <= 1):\n", - " raise ValueError(f\"pruning_ratio must be in range (0, 1], got {self.pruning_ratio}\")\n", - " self.default_pruning_ratio = self.pruning_ratio\n", - "\n", - " # Convert Schedule object to torch-pruning compatible function\n", - " tp_schedule = self._to_tp_scheduler(self.schedule)\n", - "\n", - " self.pruner = tp.pruner.MetaPruner(\n", - " self.model,\n", - " example_inputs=self.example_inputs.to(next(self.model.parameters()).device),\n", - " importance=self.group_importance,\n", - " pruning_ratio=self.default_pruning_ratio,\n", - " pruning_ratio_dict=self.pruning_ratio_dict,\n", - " ignored_layers=self.ignored_layers,\n", - " global_pruning=True if self.context=='global' else False,\n", - " num_heads=self.num_heads,\n", - " iterative_pruning_ratio_scheduler=tp_schedule,\n", - " *args,\n", - " **kwargs\n", - " )\n", - "\n", - " def _build_pruning_schedule(self, sched_func):\n", - " \"Create a schedule function compatible with torch-pruning's Pruner\"\n", - " def scheduler(pruning_ratio, steps, start=0, end=1):\n", - " return [\n", - " sched_func(start, end, i / float(steps)) * pruning_ratio\n", - " for i in range(steps + 1)\n", - " ]\n", - " return scheduler\n", - "\n", - " def _to_tp_scheduler(self, schedule):\n", - " \"Convert Schedule object or callable to torch-pruning compatible scheduler\"\n", - " # If 
it's a Schedule object, extract sched_func and build compatible function\n", - " if isinstance(schedule, Schedule):\n", - " return self._build_pruning_schedule(schedule.sched_func)\n", - " # Otherwise assume it's already a compatible callable (like linear_scheduler)\n", - " return schedule\n", - "\n", - " def _resolve_pruning_ratio_dict(self, ratio_dict):\n", - " \"Convert layer name strings to module references for torch-pruning\"\n", - " name_to_module = dict(self.model.named_modules())\n", - " resolved = {}\n", - " for key, ratio in ratio_dict.items():\n", - " if isinstance(key, str):\n", - " if key in name_to_module:\n", - " module = name_to_module[key]\n", - " # Normalize ratio to 0-1 range\n", - " resolved[module] = ratio / 100 if ratio > 1 else ratio\n", - " else:\n", - " print(f\"Warning: Layer '{key}' not found in model, skipping\")\n", - " elif isinstance(key, nn.Module):\n", - " resolved[key] = ratio / 100 if ratio > 1 else ratio\n", - " return resolved\n", - " \n", - " def prune_model(self):\n", - " \"Execute one pruning step and restore attention layer configurations\"\n", - " self.pruner.step()\n", - " self.restore_attention_layers()\n", - "\n", - "\n", - " def get_linear_layers_to_ignore(self, \n", - " model: nn.Module # The model to analyze\n", - " ):\n", - " \"Find and ignore output Linear layers to preserve model output dimensions\"\n", - " try:\n", - " traced = symbolic_trace(model)\n", - " for node in traced.graph.nodes:\n", - " if node.op == \"output\": # Identify the output\n", - " for input_node in node.all_input_nodes:\n", - " if input_node.target: # Find the corresponding layer\n", - " module = dict(model.named_modules()).get(input_node.target)\n", - " if isinstance(module, torch.nn.Linear):\n", - " self.ignored_layers.append(module)\n", - " print(f\"Ignoring output layer: {input_node.target}\")\n", - " except Exception as e:\n", - " print(f\"Could not trace model for output layer detection: {e}\")\n", - "\n", - "\n", - " def 
get_attention_layers_to_ignore(self, \n", - " model: nn.Module # The model to analyze\n", - " ):\n", - " \"Find and ignore attention layers (qkv projections) to preserve attention structure\"\n", - " for module in model.modules():\n", - " if hasattr(module, 'num_heads'):\n", - " if hasattr(module, 'qkv'):\n", - " self.ignored_layers.append(module.qkv)\n", - " self.num_heads[module.qkv] = module.num_heads\n", - " print(f\"Attention layer ignored: {module.qkv}, num_heads={module.num_heads}\")\n", - " elif hasattr(module, 'qkv_proj'):\n", - " self.ignored_layers.append(module.qkv_proj)\n", - " self.num_heads[module.qkv_proj] = module.num_heads\n", - " print(f\"Attention layer ignored: {module.qkv_proj}, num_heads={module.num_heads}\")\n", - "\n", - " \n", - " def get_ignored_layers(self, \n", - " model: nn.Module # The model to analyze\n", - " ):\n", - " \"Build list of layers to ignore during pruning\"\n", - " self.ignored_layers = []\n", - " self.get_linear_layers_to_ignore(model)\n", - " self.get_attention_layers_to_ignore(model)\n", - " print(f\"Total ignored layers: {len(self.ignored_layers)}\")\n", - " \n", - " \n", - " def restore_attention_layers(self):\n", - " \"Restore num_heads and head_dim attributes after pruning attention layers\"\n", - " for m in self.model.modules():\n", - " if hasattr(m, 'num_heads'):\n", - " if hasattr(m, 'qkv'):\n", - " m.num_heads = self.num_heads[m.qkv]\n", - " m.head_dim = m.qkv.out_features // (3 * m.num_heads)\n", - " elif hasattr(m, 'qkv_proj'):\n", - " m.num_heads = self.num_heads[m.qkv_proj]\n", - " m.head_dim = m.qkv_proj.out_features // (3 * m.num_heads)\n", - "\n", - "\n", - " def group_importance(self, group):\n", - " \"Compute importance scores for a dependency group\"\n", - " handler_map = {\n", - " function.prune_conv_out_channels: 'filter',\n", - " function.prune_linear_out_channels: 'row',\n", - " function.prune_linear_in_channels: 'column',\n", - " function.prune_conv_in_channels: 'shared_kernel',\n", - " }\n", - " 
\n", - " group_imp = []\n", - " group_idxs = []\n", - " \n", - " for i, (dep, idxs) in enumerate(group):\n", - " if dep.handler in handler_map:\n", - " impo = self.criteria(dep.target.module, handler_map.get(dep.handler), squeeze=True)\n", - " group_imp.append(impo)\n", - " group_idxs.append(group[i].root_idxs)\n", - " \n", - " if len(group_imp) == 0:\n", - " return torch.tensor([])\n", - " \n", - " reduced_imp = torch.zeros_like(group_imp[0])\n", - " \n", - " for i, (imp, root_idxs) in enumerate(zip(group_imp, group_idxs)):\n", - " imp = imp.to('cpu')\n", - " reduced_imp = reduced_imp.to('cpu')\n", - " reduced_imp.scatter_add_(0, torch.tensor(root_idxs, device=imp.device), imp)\n", - " \n", - " reduced_imp /= len(group_imp)\n", - " \n", - " return reduced_imp.to(default_device())\n", - "\n", - " def print_sparsity(self) -> None:\n", - " \"Print pruning report showing channel counts and parameter reduction\"\n", - " total_params = 0\n", - " \n", - " print(\"\\nPruning Report:\")\n", - " print(\"-\" * 85)\n", - " print(f\"{'Layer':<35} {'Type':<12} {'In Ch':<8} {'Out Ch':<8} {'Params':<12}\")\n", - " print(\"-\" * 85)\n", - " \n", - " for name, m in self.model.named_modules():\n", - " if isinstance(m, (nn.Conv2d, nn.Linear)):\n", - " params = sum(p.numel() for p in m.parameters())\n", - " total_params += params\n", - " \n", - " if isinstance(m, nn.Conv2d):\n", - " in_ch, out_ch = m.in_channels, m.out_channels\n", - " layer_type = \"Conv2d\"\n", - " else:\n", - " in_ch, out_ch = m.in_features, m.out_features\n", - " layer_type = \"Linear\"\n", - " \n", - " print(f\"{name:<35} {layer_type:<12} {in_ch:<8} {out_ch:<8} {params:<12,}\")\n", - " \n", - " print(\"-\" * 85)\n", - " reduction = 100 * (1 - total_params / self._original_params) if self._original_params > 0 else 0\n", - " print(f\"{'Total':<35} {'':<12} {'':<8} {'':<8} {total_params:<12,}\")\n", - " print(f\"{'Original':<35} {'':<12} {'':<8} {'':<8} {self._original_params:<12,}\")\n", - " 
print(f\"{'Reduction':<35} {'':<12} {'':<8} {'':<8} {reduction:>10.2f}%\")" - ] + "source": "#| export\nfrom fasterai.core.schedule import Schedule\n\ndef _pruning_compatible_forward(self, x, attn_mask=None):\n \"Attention forward using reshape(B, N, -1) for pruning compatibility\"\n B, N, C = x.shape\n qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)\n q, k, v = qkv.unbind(0)\n q, k = self.q_norm(q), self.k_norm(k)\n if self.fused_attn:\n x = F.scaled_dot_product_attention(\n q, k, v, attn_mask=attn_mask,\n dropout_p=self.attn_drop.p if self.training else 0.)\n else:\n q = q * self.scale\n attn = q @ k.transpose(-2, -1)\n if attn_mask is not None: attn = attn + attn_mask\n attn = attn.softmax(dim=-1)\n attn = self.attn_drop(attn)\n x = attn @ v\n x = x.transpose(1, 2).reshape(B, N, -1)\n x = self.proj(x)\n x = self.proj_drop(x)\n return x\n\nclass Pruner():\n \"Structured pruning for neural networks using torch_pruning\"\n def __init__(self,\n model, # The PyTorch model to prune\n pruning_ratio, # Channel pruning ratio (float 0-1, int 0-100, or dict)\n context, # 'local' or 'global'\n criteria, # Importance criteria (e.g. 
large_final)\n schedule=linear_scheduler,# Pruning schedule\n ignored_layers=None, # Layers to skip during pruning\n example_inputs=torch.randn(1, 3, 224, 224), # Dummy input for tracing\n head_pruning_ratio=0.0, # Ratio of attention heads to remove (0-1 or 0-100)\n prune_num_heads=False, # Remove entire attention heads\n prune_head_dims=True, # Reduce head dimensions\n *args, **kwargs,\n ):\n store_attr()\n self.num_heads = {}\n self._original_num_heads = {}\n self._heads_pruned = False\n self._original_params = sum(p.numel() for p in model.parameters())\n\n # Normalize head_pruning_ratio\n if self.head_pruning_ratio > 1: self.head_pruning_ratio = self.head_pruning_ratio / 100\n # Auto-enable: head_pruning_ratio > 0 implies prune whole heads, not dims (XOR pattern)\n if self.head_pruning_ratio > 0:\n self.prune_num_heads = True\n self.prune_head_dims = False\n\n if not self.ignored_layers: self.get_ignored_layers(self.model)\n\n # Handle pruning_ratio as float or dict\n self.pruning_ratio_dict = None\n if isinstance(self.pruning_ratio, dict):\n self.pruning_ratio_dict = self._resolve_pruning_ratio_dict(self.pruning_ratio)\n self.default_pruning_ratio = kwargs.pop('default_pruning_ratio', 0.0)\n print(f\"Using per-layer pruning with {len(self.pruning_ratio_dict)} layer-specific ratios\")\n else:\n if self.pruning_ratio > 1: self.pruning_ratio = self.pruning_ratio / 100\n if not (0 < self.pruning_ratio <= 1):\n raise ValueError(f\"pruning_ratio must be in range (0, 1], got {self.pruning_ratio}\")\n self.default_pruning_ratio = self.pruning_ratio\n\n # Convert Schedule object to torch-pruning compatible function\n tp_schedule = self._to_tp_scheduler(self.schedule)\n\n self.pruner = tp.pruner.MetaPruner(\n self.model,\n example_inputs=self.example_inputs.to(next(self.model.parameters()).device),\n importance=self.group_importance,\n pruning_ratio=self.default_pruning_ratio,\n pruning_ratio_dict=self.pruning_ratio_dict,\n ignored_layers=self.ignored_layers,\n 
global_pruning=True if self.context=='global' else False,\n num_heads=self.num_heads,\n prune_num_heads=self.prune_num_heads,\n prune_head_dims=self.prune_head_dims,\n head_pruning_ratio=self.head_pruning_ratio,\n iterative_pruning_ratio_scheduler=tp_schedule,\n *args,\n **kwargs\n )\n\n def _build_pruning_schedule(self, sched_func):\n \"Create a schedule function compatible with torch-pruning's Pruner\"\n def scheduler(pruning_ratio, steps, start=0, end=1):\n return [\n sched_func(start, end, i / float(steps)) * pruning_ratio\n for i in range(steps + 1)\n ]\n return scheduler\n\n def _to_tp_scheduler(self, schedule):\n \"Convert Schedule object or callable to torch-pruning compatible scheduler\"\n if isinstance(schedule, Schedule):\n return self._build_pruning_schedule(schedule.sched_func)\n return schedule\n\n def _resolve_pruning_ratio_dict(self, ratio_dict):\n \"Convert layer name strings to module references for torch-pruning\"\n name_to_module = dict(self.model.named_modules())\n resolved = {}\n for key, ratio in ratio_dict.items():\n if isinstance(key, str):\n if key in name_to_module:\n module = name_to_module[key]\n resolved[module] = ratio / 100 if ratio > 1 else ratio\n else:\n print(f\"Warning: Layer '{key}' not found in model, skipping\")\n elif isinstance(key, nn.Module):\n resolved[key] = ratio / 100 if ratio > 1 else ratio\n return resolved\n\n def _patch_attention_forward(self,\n module: nn.Module # Attention module with .qkv\n ):\n \"Patch attention forward to use reshape(B,N,-1) for pruning compatibility\"\n # Only patch if the module has the timm Attention interface\n if not all(hasattr(module, a) for a in ('qkv', 'proj', 'proj_drop', 'attn_drop',\n 'q_norm', 'k_norm', 'fused_attn', 'scale')):\n return\n module.forward = _pruning_compatible_forward.__get__(module, type(module))\n\n def _detect_attention_heads(self,\n model: nn.Module # The model to analyze\n ):\n \"Detect attention layers with QKV projections and populate num_heads mapping\"\n 
for module in model.modules():\n # nn.MultiheadAttention uses raw parameters (in_proj_weight), not Linear\n # submodules — torch-pruning's head pruning requires .out_features on qkv\n # layers, so only timm-style attention (with .qkv Linear) is supported.\n if isinstance(module, nn.MultiheadAttention):\n self.ignored_layers.append(module)\n continue\n if hasattr(module, 'num_heads'):\n if hasattr(module, 'qkv'):\n self.num_heads[module.qkv] = module.num_heads\n # Patch forward to use reshape(B,N,-1) — required for head pruning\n # (official torch-pruning pattern from prune_timm_vit.py)\n if self.prune_num_heads: self._patch_attention_forward(module)\n elif hasattr(module, 'qkv_proj'):\n self.num_heads[module.qkv_proj] = module.num_heads\n self._original_num_heads = dict(self.num_heads)\n\n def get_linear_layers_to_ignore(self,\n model: nn.Module # The model to analyze\n ):\n \"Find and ignore output Linear layers to preserve model output dimensions\"\n try:\n traced = symbolic_trace(model)\n for node in traced.graph.nodes:\n if node.op == \"output\":\n for input_node in node.all_input_nodes:\n if input_node.target:\n module = dict(model.named_modules()).get(input_node.target)\n if isinstance(module, torch.nn.Linear):\n self.ignored_layers.append(module)\n print(f\"Ignoring output layer: {input_node.target}\")\n except Exception as e:\n print(f\"Could not trace model for output layer detection: {e}\")\n\n def get_ignored_layers(self,\n model: nn.Module # The model to analyze\n ):\n \"Build list of layers to ignore during pruning\"\n self.ignored_layers = []\n self.get_linear_layers_to_ignore(model)\n self._detect_attention_heads(model)\n # Ignore QKV layers only when head pruning is disabled\n if not self.prune_num_heads and self.head_pruning_ratio == 0:\n for layer in self.num_heads:\n self.ignored_layers.append(layer)\n if self.num_heads:\n action = \"will be pruned\" if self.prune_num_heads else \"ignored\"\n print(f\"Detected {len(self.num_heads)} attention 
layer(s) ({action})\")\n print(f\"Total ignored layers: {len(self.ignored_layers)}\")\n\n def _freeze_head_pruning(self):\n \"Disable further head pruning after first application to prevent over-pruning\"\n # torch-pruning computes head removal from current count (not original),\n # so repeated steps would halve heads each time (6→3→1→0). We freeze\n # the head pruning ratios after the first step that modifies heads.\n if not self._heads_pruned and self.prune_num_heads:\n heads_changed = any(\n self.pruner.num_heads.get(layer, orig) != orig\n for layer, orig in self._original_num_heads.items()\n )\n if heads_changed:\n self._heads_pruned = True\n for i in range(len(self.pruner.per_step_head_pruning_ratio)):\n if i > self.pruner.current_step:\n self.pruner.per_step_head_pruning_ratio[i] = 0.0\n\n def _sync_attention_attrs(self):\n \"Sync attention module attributes with torch-pruning's updated head counts\"\n for module in self.model.modules():\n if not hasattr(module, 'num_heads') or isinstance(module, nn.MultiheadAttention):\n continue\n if hasattr(module, 'qkv') and module.qkv in self.pruner.num_heads:\n module.num_heads = self.pruner.num_heads[module.qkv]\n module.head_dim = module.qkv.out_features // (3 * module.num_heads)\n elif hasattr(module, 'qkv_proj') and module.qkv_proj in self.pruner.num_heads:\n module.num_heads = self.pruner.num_heads[module.qkv_proj]\n module.head_dim = module.qkv_proj.out_features // (3 * module.num_heads)\n\n def prune_model(self):\n \"Execute one pruning step and sync attention layer attributes\"\n self.pruner.step()\n self._sync_attention_attrs()\n self._freeze_head_pruning()\n\n def group_importance(self, group):\n \"Compute importance scores for a dependency group\"\n handler_map = {\n function.prune_conv_out_channels: 'filter',\n function.prune_linear_out_channels: 'row',\n function.prune_linear_in_channels: 'column',\n function.prune_conv_in_channels: 'shared_kernel',\n }\n\n group_imp = []\n group_idxs = []\n\n for i, (dep, 
idxs) in enumerate(group):\n if dep.handler in handler_map:\n impo = self.criteria(dep.target.module, handler_map.get(dep.handler), squeeze=True)\n group_imp.append(impo)\n group_idxs.append(group[i].root_idxs)\n\n if len(group_imp) == 0:\n return torch.tensor([])\n\n reduced_imp = torch.zeros_like(group_imp[0])\n\n for i, (imp, root_idxs) in enumerate(zip(group_imp, group_idxs)):\n imp = imp.to('cpu')\n reduced_imp = reduced_imp.to('cpu')\n reduced_imp.scatter_add_(0, torch.tensor(root_idxs, device=imp.device), imp)\n\n reduced_imp /= len(group_imp)\n\n return reduced_imp.to(default_device())\n\n def print_sparsity(self) -> None:\n \"Print pruning report showing channel counts and parameter reduction\"\n total_params = 0\n\n print(\"\\nPruning Report:\")\n print(\"-\" * 85)\n print(f\"{'Layer':<35} {'Type':<12} {'In Ch':<8} {'Out Ch':<8} {'Params':<12}\")\n print(\"-\" * 85)\n\n for name, m in self.model.named_modules():\n if isinstance(m, (nn.Conv2d, nn.Linear)):\n params = sum(p.numel() for p in m.parameters())\n total_params += params\n\n if isinstance(m, nn.Conv2d):\n in_ch, out_ch = m.in_channels, m.out_channels\n layer_type = \"Conv2d\"\n else:\n in_ch, out_ch = m.in_features, m.out_features\n layer_type = \"Linear\"\n\n print(f\"{name:<35} {layer_type:<12} {in_ch:<8} {out_ch:<8} {params:<12,}\")\n\n print(\"-\" * 85)\n reduction = 100 * (1 - total_params / self._original_params) if self._original_params > 0 else 0\n print(f\"{'Total':<35} {'':<12} {'':<8} {'':<8} {total_params:<12,}\")\n print(f\"{'Original':<35} {'':<12} {'':<8} {'':<8} {self._original_params:<12,}\")\n print(f\"{'Reduction':<35} {'':<12} {'':<8} {'':<8} {reduction:>10.2f}%\")\n\n # Head count changes\n if self._original_num_heads:\n print(f\"\\n{'Attention Heads':}\")\n print(\"-\" * 50)\n for layer, orig in self._original_num_heads.items():\n current = self.pruner.num_heads.get(layer, orig)\n name = next((n for n, m in self.model.named_modules() if m is layer), str(layer))\n status = 
f\"{orig} -> {current}\" if current != orig else f\"{orig} (unchanged)\"\n print(f\" {name:<33} {status}\")" }, { "cell_type": "code", @@ -456,63 +239,23 @@ "cell_type": "markdown", "id": "60753d13-3dab-4f86-9d5d-a2bf0fb7dd48", "metadata": {}, - "source": [ - "```python\n", - "model = resnet18()\n", - "pruner = Pruner(model, 30, 'local', large_final)\n", - "pruner.prune_model()\n", - "```" - ] + "source": "```python\n# CNN pruning — remove 30% of filters\nmodel = resnet18()\npruner = Pruner(model, 30, 'local', large_final)\npruner.prune_model()\n```\n\n```python\n# Transformer head pruning — remove 50% of attention heads\nmodel = vit_small_patch16_224()\npruner = Pruner(model, pruning_ratio=0.3, context='local', criteria=large_final,\n example_inputs=torch.randn(1, 3, 224, 224), head_pruning_ratio=0.5)\npruner.prune_model()\npruner.print_sparsity() # shows head count changes\n```" }, { "cell_type": "code", "execution_count": null, "id": "x0g8lno4gfp", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Ignoring output layer: 8\n", - "Total ignored layers: 1\n" - ] - } - ], - "source": [ - "#| hide\n", - "from fastcore.test import *\n", - "\n", - "def _test_model():\n", - " return nn.Sequential(\n", - " nn.Conv2d(3, 16, 3, padding=1),\n", - " nn.BatchNorm2d(16),\n", - " nn.ReLU(),\n", - " nn.Conv2d(16, 32, 3, padding=1),\n", - " nn.BatchNorm2d(32),\n", - " nn.ReLU(),\n", - " nn.AdaptiveAvgPool2d(1),\n", - " nn.Flatten(),\n", - " nn.Linear(32, 10)\n", - " )\n", - "\n", - "# Pruner construction\n", - "model = _test_model()\n", - "x = torch.randn(1, 3, 8, 8)\n", - "pruner = Pruner(model, 30, 'local', large_final, example_inputs=x)\n", - "assert pruner is not None\n", - "\n", - "# Prune model reduces parameter count\n", - "params_before = sum(p.numel() for p in model.parameters())\n", - "pruner.prune_model()\n", - "params_after = sum(p.numel() for p in model.parameters())\n", - "assert params_after < params_before\n", - "\n", 
- "# Model still produces valid output after pruning\n", - "out = model(x)\n", - "test_eq(out.shape[0], 1) # batch dim preserved\n", - "test_eq(out.shape[1], 10) # output classes preserved" - ] + "outputs": [], + "source": "#| hide\nfrom fastcore.test import *\n\ndef _test_model():\n return nn.Sequential(\n nn.Conv2d(3, 16, 3, padding=1),\n nn.BatchNorm2d(16),\n nn.ReLU(),\n nn.Conv2d(16, 32, 3, padding=1),\n nn.BatchNorm2d(32),\n nn.ReLU(),\n nn.AdaptiveAvgPool2d(1),\n nn.Flatten(),\n nn.Linear(32, 10)\n )\n\n# Pruner construction\nmodel = _test_model()\nx = torch.randn(1, 3, 8, 8)\npruner = Pruner(model, 30, 'local', large_final, example_inputs=x)\nassert pruner is not None\n\n# Prune model reduces parameter count\nparams_before = sum(p.numel() for p in model.parameters())\npruner.prune_model()\nparams_after = sum(p.numel() for p in model.parameters())\nassert params_after < params_before\n\n# Model still produces valid output after pruning\nout = model(x)\ntest_eq(out.shape[0], 1) # batch dim preserved\ntest_eq(out.shape[1], 10) # output classes preserved\n\n# CNN model with head_pruning_ratio=0.0 — backward compatible (no change)\nmodel2 = _test_model()\npruner2 = Pruner(model2, 30, 'local', large_final, example_inputs=x, head_pruning_ratio=0.0)\ntest_eq(len(pruner2.num_heads), 0) # no attention layers detected\ntest_eq(pruner2.prune_num_heads, False)" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ab66a904", + "metadata": {}, + "outputs": [], + "source": "#| hide\n# Transformer head pruning tests\n\nclass _Attention(nn.Module):\n \"Timm-style multi-head attention with fused QKV projection\"\n def __init__(self, dim, num_heads):\n super().__init__()\n self.num_heads = num_heads\n self.head_dim = dim // num_heads\n self.qkv = nn.Linear(dim, dim * 3)\n self.proj = nn.Linear(dim, dim)\n def forward(self, x):\n B, N, _ = x.shape\n qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)\n q, k, v = 
qkv.unbind(0)\n attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)\n x = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, -1)\n return self.proj(x)\n\nclass _Block(nn.Module):\n def __init__(self, dim, num_heads):\n super().__init__()\n self.norm1 = nn.LayerNorm(dim)\n self.attn = _Attention(dim, num_heads)\n self.norm2 = nn.LayerNorm(dim)\n self.mlp = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))\n def forward(self, x):\n x = x + self.attn(self.norm1(x))\n x = x + self.mlp(self.norm2(x))\n return x\n\nclass _SimpleViT(nn.Module):\n def __init__(self, dim=64, num_heads=4, depth=2, num_classes=10):\n super().__init__()\n self.blocks = nn.Sequential(*[_Block(dim, num_heads) for _ in range(depth)])\n self.norm = nn.LayerNorm(dim)\n self.head = nn.Linear(dim, num_classes)\n def forward(self, x):\n x = self.blocks(x)\n return self.head(self.norm(x).mean(dim=1))\n\n# Head pruning with timm-style ViT\nmodel = _SimpleViT(dim=64, num_heads=4, depth=2)\nx = torch.randn(1, 10, 64)\nparams_before = sum(p.numel() for p in model.parameters())\n\npruner = Pruner(model, pruning_ratio=0.3, context='local', criteria=large_final,\n example_inputs=x, head_pruning_ratio=0.5)\n\n# Verify attention layers detected (2 blocks × 1 attention each)\ntest_eq(len(pruner.num_heads), 2)\ntest_eq(pruner.prune_num_heads, True)\ntest_eq(pruner.prune_head_dims, False) # XOR pattern\n\n# Prune and verify params decreased\npruner.prune_model()\nparams_after = sum(p.numel() for p in model.parameters())\nassert params_after < params_before, f\"Expected param reduction: {params_before} -> {params_after}\"\n\n# Verify heads were actually removed via _sync_attention_attrs\nfor block in model.blocks:\n assert block.attn.num_heads < 4, f\"Expected heads < 4, got {block.attn.num_heads}\"\n # head_dim unchanged (XOR pattern — whole heads removed, dim preserved)\n test_eq(block.attn.head_dim, 16)\n\n# QKV output channels reflect head removal: 3 * num_heads * 
head_dim\nexpected_qkv_out = 3 * model.blocks[0].attn.num_heads * model.blocks[0].attn.head_dim\ntest_eq(model.blocks[0].attn.qkv.out_features, expected_qkv_out)\n\n# head_pruning_ratio normalization (50 -> 0.5)\nmodel2 = _SimpleViT(dim=64, num_heads=4, depth=2)\npruner2 = Pruner(model2, 0.3, 'local', large_final, example_inputs=x, head_pruning_ratio=50)\ntest_close(pruner2.head_pruning_ratio, 0.5, eps=1e-6)" }, { "cell_type": "markdown", @@ -540,4 +283,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/nbs/tutorials/misc/conv_decomposer.ipynb b/nbs/tutorials/misc/conv_decomposer.ipynb new file mode 100644 index 0000000..d9a0adc --- /dev/null +++ b/nbs/tutorials/misc/conv_decomposer.ipynb @@ -0,0 +1,101 @@ +{ + "cells": [ + { + "cell_type": "raw", + "metadata": {}, + "source": ["---\n", "title: \"Conv Decomposer\"\n", "description: \"Compress Conv2d layers using Tucker decomposition for faster inference\"\n", "skip_exec: true\n", "---"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Introduction\n", "\n", "Tucker decomposition factorizes a single large Conv2d layer into three smaller convolutions, reducing parameter count while preserving the spatial structure of the learned filters. Unlike pruning (which removes channels) or sparsification (which zeros weights), decomposition replaces each layer with a low-rank approximation built from smaller layers.\n", "\n", "A `Conv2d(C_in, C_out, K\u00d7K)` becomes a sequence of: (1) a pointwise `1\u00d71` convolution that compresses input channels, (2) a spatial `K\u00d7K` convolution at reduced rank, and (3) a pointwise `1\u00d71` convolution that expands back to the original output channels. 
The result is a drop-in replacement that runs faster and uses fewer parameters."] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Setup"] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": ["import torch, torch.nn as nn\n", "from fasterai.misc.conv_decomposer import Conv_Decomposer\n", "from torchvision.models import resnet18"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Quick Example\n", "\n", "Decompose a pretrained ResNet-18, removing 50% of the rank:"] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": ["Original: 11,689,512 params\n", "Compressed: 5,843,018 params\n", "Compression: 2.00x\n"] + } + ], + "source": ["model = resnet18(pretrained=True)\n", "decomposer = Conv_Decomposer()\n", "compressed = decomposer.decompose(model, percent_removed=0.5)\n", "\n", "orig = sum(p.numel() for p in model.parameters())\n", "comp = sum(p.numel() for p in compressed.parameters())\n", "print(f'Original: {orig:,} params')\n", "print(f'Compressed: {comp:,} params')\n", "print(f'Compression: {orig/comp:.2f}x')"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## How Tucker Decomposition Works\n", "\n", "A single convolution is factorized into three smaller ones:\n", "\n", "```\n", "Original: Conv2d(64, 128, 3\u00d73) = 73,728 params\n", " \u2193\n", "Decomposed: Conv2d(64 \u2192 32, 1\u00d71) = 2,048 params (compress input)\n", " Conv2d(32 \u2192 64, 3\u00d73) = 18,432 params (spatial filter)\n", " Conv2d(64 \u2192 128, 1\u00d71) = 8,192 params (expand output)\n", " Total = 28,672 params (2.6\u00d7 smaller)\n", "```\n", "\n", "The reduced ranks `R_in` and `R_out` are controlled by `percent_removed`:\n", "- `R_in = max(1, int((1 - percent_removed) \u00d7 C_in))`\n", "- `R_out = max(1, int((1 - percent_removed) \u00d7 C_out))`\n", "\n", "The decomposition uses 
the HOOI (Higher-Order Orthogonal Iteration) algorithm with 5 iterations to find the best low-rank approximation."] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Choosing `percent_removed`\n", "\n", "Higher values mean more compression but lower fidelity:"] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": ["percent_removed=0.25: 8,721,344 params (1.34x)\n", "percent_removed=0.50: 5,843,018 params (2.00x)\n", "percent_removed=0.75: 3,412,096 params (3.43x)\n"] + } + ], + "source": ["model = resnet18(pretrained=True)\n", "orig = sum(p.numel() for p in model.parameters())\n", "\n", "for pct in [0.25, 0.50, 0.75]:\n", " compressed = Conv_Decomposer().decompose(resnet18(pretrained=True), percent_removed=pct)\n", " params = sum(p.numel() for p in compressed.parameters())\n", " print(f'percent_removed={pct}: {params:>10,} params ({orig/params:.2f}x)')"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Layer Skipping Rules\n", "\n", "Not all Conv2d layers are decomposed:\n", "\n", "| Layer Type | Decomposed? | Reason |\n", "|-----------|------------|--------|\n", "| Conv2d 3\u00d73, 5\u00d75, 7\u00d77 | **Yes** | Main target \u2014 spatial convolutions |\n", "| Conv2d 1\u00d71 (pointwise) | No | Already minimal, no spatial redundancy |\n", "| Depthwise Conv2d (groups > 1) | No | Group structure incompatible with Tucker |\n", "| First layer (C_in=3) | Yes | Small benefit, but still decomposed |"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Fine-tuning After Decomposition\n", "\n", "Tucker decomposition is an approximation (HOOI is iterative, not exact). 
Fine-tuning recovers accuracy lost during decomposition:\n", "\n", "```python\n", "compressed = Conv_Decomposer().decompose(model, percent_removed=0.5)\n", "\n", "learn = Learner(dls, compressed, loss_func=CrossEntropyLoss())\n", "learn.fit(5, lr=1e-4) # short fine-tuning pass\n", "```\n", "\n", "Typically, 3\u20135 epochs at a low learning rate are sufficient to recover most of the lost accuracy."] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Conv Decomposer vs Pruning\n", "\n", "| Aspect | Conv Decomposer | Pruner |\n", "|--------|----------------|--------|\n", "| **What changes** | Replaces layers with 3-layer sequences | Removes channels entirely |\n", "| **Architecture** | More layers, each smaller | Fewer channels per layer |\n", "| **Accuracy recovery** | Fine-tuning recommended | Can prune during training |\n", "| **Best for** | Post-training inference optimization | Training-time compression |\n", "| **Hardware benefit** | Fewer FLOPs via smaller convolutions | Fewer channels = less memory |"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["---\n", "\n", "## Summary\n", "\n", "| Tool / Function | Purpose |\n", "|----------------|----------|\n", "| `Conv_Decomposer()` | Create a decomposer instance |\n", "| `.decompose(model, percent_removed)` | Decompose all eligible Conv2d layers |\n", "| `.Tucker(layer, percent_removed)` | Decompose a single Conv2d layer |\n", "| `percent_removed` | Fraction of rank to remove (0\u20131) |"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["---\n", "\n", "## See Also\n", "\n", "- [FC Decomposer](fc_decomposer.html) \u2014 SVD-based decomposition for Linear layers\n", "- [Pruner](../prune/pruner.html) \u2014 Alternative: structured pruning (removes channels)\n", "- [Sparsifier](../sparse/sparsifier.html) \u2014 Alternative: unstructured sparsification"] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + 
}, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/nbs/tutorials/prune/head_pruning.ipynb b/nbs/tutorials/prune/head_pruning.ipynb new file mode 100644 index 0000000..3fb2f6b --- /dev/null +++ b/nbs/tutorials/prune/head_pruning.ipynb @@ -0,0 +1,909 @@ +{ + "cells": [ + { + "cell_type": "raw", + "id": "496d672a", + "metadata": {}, + "source": [ + "---\n", + "title: \"Transformer Head Pruning\"\n", + "description: \"Remove entire attention heads from Vision Transformers using structured pruning\"\n", + "skip_exec: true\n", + "---" + ] + }, + { + "cell_type": "markdown", + "id": "ac0cdf77", + "metadata": {}, + "source": [ + "## Introduction\n", + "\n", + "Attention heads in Transformers are often redundant — Michel et al. (2019) showed on NLP models that many heads can be removed with minimal accuracy loss, and the same technique carries over to Vision Transformers. **Head pruning** removes entire attention heads from a transformer, reducing both parameters and computation.\n", + "\n", + "fasterai's `Pruner` leverages [torch-pruning](https://github.com/VainF/Torch-Pruning)'s built-in head pruning support. When head pruning is enabled, the Pruner automatically patches timm attention modules to be pruning-compatible (following the official [torch-pruning ViT examples](https://github.com/VainF/Torch-Pruning/tree/master/examples/transformers))." 
+ ] + }, + { + "cell_type": "markdown", + "id": "849426c0", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "dea2c9f3", + "metadata": {}, + "outputs": [], + "source": [ + "import torch, torch.nn as nn\n", + "from fasterai.prune.pruner import Pruner\n", + "from fasterai.core.criteria import large_final" + ] + }, + { + "cell_type": "markdown", + "id": "4e5d47e1", + "metadata": {}, + "source": [ + "## Quick Example\n", + "\n", + "Prune 50% of attention heads from a timm ViT — just add `head_pruning_ratio`:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "560c2e0a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ignoring output layer: head\n", + "Detected 12 attention layer(s) (will be pruned)\n", + "Total ignored layers: 1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nathan/miniconda3/envs/dev/lib/python3.12/site-packages/torch_pruning/dependency.py:712: UserWarning: Unwrapped parameters detected: ['cls_token', 'pos_embed'].\n", + " Torch-Pruning will prune the last non-singleton dimension of these parameters. 
If you wish to change this behavior, please provide an unwrapped_parameters argument.\n", + " warnings.warn(warning_str)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Before: 6 heads, embed_dim=384, params=21,669,514\n", + "After: 3 heads, embed_dim=192, params=9,685,778 (55% reduction)\n", + "\n", + "Inference OK: torch.Size([1, 10])\n" + ] + } + ], + "source": [ + "from timm import create_model\n", + "\n", + "model = create_model('vit_small_patch16_224', pretrained=True, num_classes=10)\n", + "x = torch.randn(1, 3, 224, 224)\n", + "params_before = sum(p.numel() for p in model.parameters())\n", + "\n", + "pruner = Pruner(model, pruning_ratio=0.3, context='local', criteria=large_final,\n", + " example_inputs=x, head_pruning_ratio=0.5)\n", + "pruner.prune_model()\n", + "\n", + "attn = model.blocks[0].attn\n", + "params_after = sum(p.numel() for p in model.parameters())\n", + "print(f'Before: 6 heads, embed_dim=384, params={params_before:,}')\n", + "print(f'After: {attn.num_heads} heads, embed_dim={attn.num_heads * attn.head_dim}, params={params_after:,} ({100*(1-params_after/params_before):.0f}% reduction)')\n", + "print()\n", + "print(f'Inference OK: {model(x).shape}')" + ] + }, + { + "cell_type": "markdown", + "id": "23c2abe3", + "metadata": {}, + "source": [ + "`pruning_ratio` and `head_pruning_ratio` are independent — you can use different values. The Pruner automatically patches the attention forward method to handle the dimension changes." 
+ ] + }, + { + "cell_type": "markdown", + "id": "227a89a0", + "metadata": {}, + "source": [ + "## Understanding the Parameters\n", + "\n", + "| Parameter | Type | Default | Description |\n", + "|-----------|------|---------|-------------|\n", + "| `head_pruning_ratio` | float | 0.0 | Ratio of attention heads to remove (0–1, or 0–100 as a percentage) |\n", + "| `prune_num_heads` | bool | False | Remove entire attention heads |\n", + "| `prune_head_dims` | bool | True | Shrink each head's dimension instead of removing heads |\n", + "\n", + "### Auto-enable XOR pattern\n", + "\n", + "Setting `head_pruning_ratio > 0` automatically enables `prune_num_heads=True` and sets `prune_head_dims=False`. This follows torch-pruning's convention: **remove whole heads OR reduce dimensions, not both**.\n", + "\n", + "Values > 1 are treated as percentages: `head_pruning_ratio=50` becomes `0.5`." + ] + }, + { + "cell_type": "markdown", + "id": "c955605b", + "metadata": {}, + "source": [ + "## Channel vs Head vs Head-Dim Pruning\n", + "\n", + "| Approach | What changes | Effect on transformer | Best for |\n", + "|----------|-------------|----------------------|----------|\n", + "| **Channel pruning** (`pruning_ratio` only) | Reduces `embed_dim` uniformly | All heads get smaller | CNNs, FFN layers |\n", + "| **Head pruning** (`head_pruning_ratio`) | Removes entire heads | Fewer heads, same `head_dim` | Transformers with redundant heads |\n", + "| **Head dim reduction** (`prune_head_dims=True`) | Shrinks each head | Same number of heads, smaller `head_dim` | Fine-grained transformer compression |" + ] + }, + { + "cell_type": "markdown", + "id": "70bf09d7", + "metadata": {}, + "source": [ + "## Prune-then-Fine-tune Workflow\n", + "\n", + "The recommended workflow: **prune once with `Pruner`, then fine-tune the pruned model**." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8442e4d6-86e6-4dcb-9b8d-686c0a1a18b9", + "metadata": {}, + "outputs": [], + "source": [ + "from timm import create_model\n", + "from fastai.vision.all import *\n", + "\n", + "# Step 1: Load data\n", + "path = untar_data(URLs.CIFAR)\n", + "dls = ImageDataLoaders.from_folder(path, valid='test', bs=32,\n", + " item_tfms=Resize(224))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "866428d5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ignoring output layer: head\n", + "Detected 12 attention layer(s) (will be pruned)\n", + "Total ignored layers: 1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nathan/miniconda3/envs/dev/lib/python3.12/site-packages/torch_pruning/dependency.py:712: UserWarning: Unwrapped parameters detected: ['cls_token', 'pos_embed'].\n", + " Torch-Pruning will prune the last non-singleton dimension of these parameters. If you wish to change this behavior, please provide an unwrapped_parameters argument.\n", + " warnings.warn(warning_str)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Pruned: 6 -> 3 heads per block\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
epochtrain_lossvalid_lossaccuracytime
01.0289841.1452340.58760000:39
10.8178680.8242150.70130000:39
20.7090840.7453890.73580000:39
30.6093380.6975660.75740000:39
40.5077570.6625020.77340000:39
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Step 2: Prune standalone\n", + "model = create_model('vit_small_patch16_224', pretrained=True, num_classes=10)\n", + "pruner = Pruner(model, pruning_ratio=0.3, context='local', criteria=large_final,\n", + " example_inputs=torch.randn(1, 3, 224, 224), head_pruning_ratio=0.5)\n", + "pruner.prune_model()\n", + "print(f'Pruned: 6 -> {model.blocks[0].attn.num_heads} heads per block')\n", + "\n", + "# Step 3: Fine-tune the pruned model\n", + "learn = Learner(dls, model, metrics=accuracy, loss_func=CrossEntropyLossFlat())\n", + "learn.fit(5, lr=1e-4)" + ] + }, + { + "cell_type": "markdown", + "id": "050110a7", + "metadata": {}, + "source": [ + "## Pruning Report\n", + "\n", + "Use `print_sparsity()` to see head count changes alongside parameter reduction:" + ] + }, + { + "cell_type": "markdown", + "id": "0969ef59", + "metadata": {}, + "source": [ + "## Training with PruneCallback\n", + "\n", + "Head pruning also works with `PruneCallback` for training-time pruning. Heads are removed once early in training, then channel pruning continues gradually. The Pruner automatically freezes head pruning after the first application to prevent over-pruning." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "b2d16f2e-5b29-4de2-96ab-acacf1b8c6cd", + "metadata": {}, + "outputs": [], + "source": [ + "from fasterai.prune.prune_callback import PruneCallback\n", + "from fasterai.core.schedule import one_shot, agp\n", + "\n", + "# Load data\n", + "path = untar_data(URLs.CIFAR)\n", + "dls = ImageDataLoaders.from_folder(path, valid='test', bs=32,\n", + " item_tfms=Resize(224))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "39e8e7e3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ignoring output layer: head\n", + "Detected 12 attention layer(s) (will be pruned)\n", + "Total ignored layers: 1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nathan/miniconda3/envs/dev/lib/python3.12/site-packages/torch_pruning/dependency.py:712: UserWarning: Unwrapped parameters detected: ['cls_token', 'pos_embed'].\n", + " Torch-Pruning will prune the last non-singleton dimension of these parameters. If you wish to change this behavior, please provide an unwrapped_parameters argument.\n", + " warnings.warn(warning_str)\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " 33.33% [1/3 04:15<08:31]\n", + "
\n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
epochtrain_lossvalid_lossaccuracytime
02.4471812.4971350.11790004:15

\n", + "\n", + "

\n", + " \n", + " 19.14% [299/1562 00:48<03:25 2.6267]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sparsity at the end of epoch 0: 7.03%\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[6]\u001b[39m\u001b[32m, line 14\u001b[39m\n\u001b[32m 5\u001b[39m cb = PruneCallback(\n\u001b[32m 6\u001b[39m pruning_ratio=\u001b[32m0.1\u001b[39m, \u001b[38;5;66;03m# 30% channel pruning (gradual via AGP)\u001b[39;00m\n\u001b[32m 7\u001b[39m schedule=agp,\n\u001b[32m (...)\u001b[39m\u001b[32m 10\u001b[39m head_pruning_ratio=\u001b[32m0.1\u001b[39m, \u001b[38;5;66;03m# 50% head removal (applied once, then frozen)\u001b[39;00m\n\u001b[32m 11\u001b[39m )\n\u001b[32m 13\u001b[39m learn = Learner(dls, model, metrics=accuracy, loss_func=CrossEntropyLossFlat())\n\u001b[32m---> \u001b[39m\u001b[32m14\u001b[39m \u001b[43mlearn\u001b[49m\u001b[43m.\u001b[49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[32;43m3\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlr\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m1e-3\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcbs\u001b[49m\u001b[43m=\u001b[49m\u001b[43m[\u001b[49m\u001b[43mcb\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/fastai/learner.py:272\u001b[39m, in \u001b[36mLearner.fit\u001b[39m\u001b[34m(self, n_epoch, lr, wd, cbs, reset_opt, start_epoch)\u001b[39m\n\u001b[32m 270\u001b[39m \u001b[38;5;28mself\u001b[39m.opt.set_hypers(lr=\u001b[38;5;28mself\u001b[39m.lr \u001b[38;5;28;01mif\u001b[39;00m lr \u001b[38;5;129;01mis\u001b[39;00m 
\u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m lr)\n\u001b[32m 271\u001b[39m \u001b[38;5;28mself\u001b[39m.n_epoch = n_epoch\n\u001b[32m--> \u001b[39m\u001b[32m272\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_with_events\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_do_fit\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mfit\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mCancelFitException\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_end_cleanup\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/fastai/learner.py:207\u001b[39m, in \u001b[36mLearner._with_events\u001b[39m\u001b[34m(self, f, event_type, ex, final)\u001b[39m\n\u001b[32m 206\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_with_events\u001b[39m(\u001b[38;5;28mself\u001b[39m, f, event_type, ex, final=noop):\n\u001b[32m--> \u001b[39m\u001b[32m207\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m: \u001b[38;5;28mself\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mbefore_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mevent_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m); \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 208\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m ex: \u001b[38;5;28mself\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mafter_cancel_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mevent_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m)\n\u001b[32m 209\u001b[39m \u001b[38;5;28mself\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mafter_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mevent_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m); 
final()\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/fastai/learner.py:261\u001b[39m, in \u001b[36mLearner._do_fit\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 259\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m.n_epoch):\n\u001b[32m 260\u001b[39m \u001b[38;5;28mself\u001b[39m.epoch=epoch\n\u001b[32m--> \u001b[39m\u001b[32m261\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_with_events\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_do_epoch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mepoch\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mCancelEpochException\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/fastai/learner.py:207\u001b[39m, in \u001b[36mLearner._with_events\u001b[39m\u001b[34m(self, f, event_type, ex, final)\u001b[39m\n\u001b[32m 206\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_with_events\u001b[39m(\u001b[38;5;28mself\u001b[39m, f, event_type, ex, final=noop):\n\u001b[32m--> \u001b[39m\u001b[32m207\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m: \u001b[38;5;28mself\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mbefore_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mevent_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m); \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 208\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m ex: \u001b[38;5;28mself\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mafter_cancel_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mevent_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m)\n\u001b[32m 209\u001b[39m 
\u001b[38;5;28mself\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mafter_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mevent_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m); final()\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/fastai/learner.py:255\u001b[39m, in \u001b[36mLearner._do_epoch\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 254\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_do_epoch\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[32m--> \u001b[39m\u001b[32m255\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_do_epoch_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 256\u001b[39m \u001b[38;5;28mself\u001b[39m._do_epoch_validate()\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/fastai/learner.py:247\u001b[39m, in \u001b[36mLearner._do_epoch_train\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 245\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_do_epoch_train\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[32m 246\u001b[39m \u001b[38;5;28mself\u001b[39m.dl = \u001b[38;5;28mself\u001b[39m.dls.train\n\u001b[32m--> \u001b[39m\u001b[32m247\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_with_events\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mall_batches\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mtrain\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mCancelTrainException\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/fastai/learner.py:207\u001b[39m, in \u001b[36mLearner._with_events\u001b[39m\u001b[34m(self, f, event_type, ex, final)\u001b[39m\n\u001b[32m 206\u001b[39m 
\u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_with_events\u001b[39m(\u001b[38;5;28mself\u001b[39m, f, event_type, ex, final=noop):\n\u001b[32m--> \u001b[39m\u001b[32m207\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m: \u001b[38;5;28mself\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mbefore_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mevent_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m); \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 208\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m ex: \u001b[38;5;28mself\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mafter_cancel_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mevent_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m)\n\u001b[32m 209\u001b[39m \u001b[38;5;28mself\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mafter_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mevent_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m); final()\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/fastai/learner.py:213\u001b[39m, in \u001b[36mLearner.all_batches\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 211\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mall_batches\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[32m 212\u001b[39m \u001b[38;5;28mself\u001b[39m.n_iter = \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m.dl)\n\u001b[32m--> \u001b[39m\u001b[32m213\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m o \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(\u001b[38;5;28mself\u001b[39m.dl): \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mone_batch\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43mo\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/fastai/learner.py:243\u001b[39m, in 
\u001b[36mLearner.one_batch\u001b[39m\u001b[34m(self, i, b)\u001b[39m\n\u001b[32m 241\u001b[39m b = \u001b[38;5;28mself\u001b[39m._set_device(b)\n\u001b[32m 242\u001b[39m \u001b[38;5;28mself\u001b[39m._split(b)\n\u001b[32m--> \u001b[39m\u001b[32m243\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_with_events\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_do_one_batch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mbatch\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mCancelBatchException\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/fastai/learner.py:207\u001b[39m, in \u001b[36mLearner._with_events\u001b[39m\u001b[34m(self, f, event_type, ex, final)\u001b[39m\n\u001b[32m 206\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_with_events\u001b[39m(\u001b[38;5;28mself\u001b[39m, f, event_type, ex, final=noop):\n\u001b[32m--> \u001b[39m\u001b[32m207\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m: \u001b[38;5;28mself\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mbefore_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mevent_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m); \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 208\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m ex: \u001b[38;5;28mself\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mafter_cancel_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mevent_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m)\n\u001b[32m 209\u001b[39m \u001b[38;5;28mself\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mafter_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mevent_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m); final()\n", + "\u001b[36mFile 
\u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/fastai/learner.py:231\u001b[39m, in \u001b[36mLearner._do_one_batch\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 229\u001b[39m \u001b[38;5;28mself\u001b[39m(\u001b[33m'\u001b[39m\u001b[33mafter_loss\u001b[39m\u001b[33m'\u001b[39m)\n\u001b[32m 230\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m.training \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m.yb): \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m231\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_do_grad_opt\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/fastai/learner.py:220\u001b[39m, in \u001b[36mLearner._do_grad_opt\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 218\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_do_grad_opt\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[32m 219\u001b[39m \u001b[38;5;28mself\u001b[39m._with_events(\u001b[38;5;28mself\u001b[39m._backward, \u001b[33m'\u001b[39m\u001b[33mbackward\u001b[39m\u001b[33m'\u001b[39m, CancelBackwardException)\n\u001b[32m--> \u001b[39m\u001b[32m220\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_with_events\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_step\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mstep\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mCancelStepException\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 221\u001b[39m \u001b[38;5;28mself\u001b[39m.opt.zero_grad()\n", + "\u001b[36mFile 
\u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/fastai/learner.py:207\u001b[39m, in \u001b[36mLearner._with_events\u001b[39m\u001b[34m(self, f, event_type, ex, final)\u001b[39m\n\u001b[32m 206\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_with_events\u001b[39m(\u001b[38;5;28mself\u001b[39m, f, event_type, ex, final=noop):\n\u001b[32m--> \u001b[39m\u001b[32m207\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m: \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[33;43mf\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mbefore_\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mevent_type\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[33;43m'\u001b[39;49m\u001b[43m)\u001b[49m; f()\n\u001b[32m 208\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m ex: \u001b[38;5;28mself\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mafter_cancel_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mevent_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m)\n\u001b[32m 209\u001b[39m \u001b[38;5;28mself\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mafter_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mevent_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m); final()\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/fastai/learner.py:180\u001b[39m, in \u001b[36mLearner.__call__\u001b[39m\u001b[34m(self, event_name)\u001b[39m\n\u001b[32m--> \u001b[39m\u001b[32m180\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, event_name): \u001b[43mL\u001b[49m\u001b[43m(\u001b[49m\u001b[43mevent_name\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mmap\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_one\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile 
\u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/fastcore/foundation.py:225\u001b[39m, in \u001b[36mcurryable..wrapper\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 222\u001b[39m \u001b[38;5;129m@wraps\u001b[39m(f)\n\u001b[32m 223\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mwrapper\u001b[39m(\u001b[38;5;28mself\u001b[39m, *args, **kwargs):\n\u001b[32m 224\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m, L): \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mlambda\u001b[39;00m items: f(L(items), \u001b[38;5;28mself\u001b[39m, *args, **kwargs)\n\u001b[32m--> \u001b[39m\u001b[32m225\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[43mL\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/fastcore/foundation.py:233\u001b[39m, in \u001b[36mmap\u001b[39m\u001b[34m(self, f, *args, **kwargs)\u001b[39m\n\u001b[32m 229\u001b[39m \u001b[38;5;129m@patch\u001b[39m\n\u001b[32m 230\u001b[39m \u001b[38;5;129m@curryable\u001b[39m\n\u001b[32m 231\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mmap\u001b[39m(\u001b[38;5;28mself\u001b[39m:L, f, *args, **kwargs):\n\u001b[32m 232\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mCreate new `L` with `f` applied to all `items`, passing `args` and `kwargs` to `f`\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m--> \u001b[39m\u001b[32m233\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m._new(\u001b[43mmap_ex\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgen\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/fastcore/basics.py:976\u001b[39m, in \u001b[36mmap_ex\u001b[39m\u001b[34m(iterable, f, gen, *args, **kwargs)\u001b[39m\n\u001b[32m 974\u001b[39m res = \u001b[38;5;28mmap\u001b[39m(g, iterable)\n\u001b[32m 975\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m gen: \u001b[38;5;28;01mreturn\u001b[39;00m res\n\u001b[32m--> \u001b[39m\u001b[32m976\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mlist\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mres\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/fastcore/basics.py:961\u001b[39m, in \u001b[36mbind.__call__\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 959\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(v,_Arg): kwargs[k] = args.pop(v.i)\n\u001b[32m 960\u001b[39m fargs = [args[x.i] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(x, _Arg) \u001b[38;5;28;01melse\u001b[39;00m x \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.pargs] + args[\u001b[38;5;28mself\u001b[39m.maxi+\u001b[32m1\u001b[39m:]\n\u001b[32m--> \u001b[39m\u001b[32m961\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43mfargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/fastai/learner.py:184\u001b[39m, in \u001b[36mLearner._call_one\u001b[39m\u001b[34m(self, event_name)\u001b[39m\n\u001b[32m 182\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_call_one\u001b[39m(\u001b[38;5;28mself\u001b[39m, event_name):\n\u001b[32m 183\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(event, event_name): \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mmissing \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mevent_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m184\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m cb \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.cbs.sorted(\u001b[33m'\u001b[39m\u001b[33morder\u001b[39m\u001b[33m'\u001b[39m): \u001b[43mcb\u001b[49m\u001b[43m(\u001b[49m\u001b[43mevent_name\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/fastai/callback/core.py:62\u001b[39m, in \u001b[36mCallback.__call__\u001b[39m\u001b[34m(self, event_name)\u001b[39m\n\u001b[32m 60\u001b[39m res = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 61\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.run \u001b[38;5;129;01mand\u001b[39;00m _run: \n\u001b[32m---> \u001b[39m\u001b[32m62\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m: res = \u001b[43mgetcallable\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mevent_name\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 63\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m (CancelBatchException, CancelBackwardException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): \u001b[38;5;28;01mraise\u001b[39;00m\n\u001b[32m 64\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e: \u001b[38;5;28;01mraise\u001b[39;00m modify_exception(e, \u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mException occured in `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m.\u001b[34m__class__\u001b[39m.\u001b[34m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m` when calling event `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mevent_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m`:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00me.args[\u001b[32m0\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m, replace=\u001b[38;5;28;01mTrue\u001b[39;00m)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Developer/FasterAI-Labs/gh/fasterai/fasterai/prune/prune_callback.py:66\u001b[39m, in \u001b[36mPruneCallback.before_step\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 64\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mApply pruning before optimizer step\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 65\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.training: \n\u001b[32m---> \u001b[39m\u001b[32m66\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mpruner\u001b[49m\u001b[43m.\u001b[49m\u001b[43mprune_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Developer/FasterAI-Labs/gh/fasterai/fasterai/prune/pruner.py:234\u001b[39m, in 
\u001b[36mPruner.prune_model\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 232\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mprune_model\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[32m 233\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mExecute one pruning step and sync attention layer attributes\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m--> \u001b[39m\u001b[32m234\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mpruner\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 235\u001b[39m \u001b[38;5;28mself\u001b[39m._sync_attention_attrs()\n\u001b[32m 236\u001b[39m \u001b[38;5;28mself\u001b[39m._freeze_head_pruning()\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/torch_pruning/pruner/algorithms/base_pruner.py:269\u001b[39m, in \u001b[36mBasePruner.step\u001b[39m\u001b[34m(self, interactive)\u001b[39m\n\u001b[32m 267\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._prune()\n\u001b[32m 268\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m269\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mgroup\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_prune\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 270\u001b[39m \u001b[43m \u001b[49m\u001b[43mgroup\u001b[49m\u001b[43m.\u001b[49m\u001b[43mprune\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/torch/utils/_contextlib.py:38\u001b[39m, in \u001b[36m_wrap_generator..generator_context\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 35\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m 
36\u001b[39m \u001b[38;5;66;03m# Issuing `None` to a generator fires it up\u001b[39;00m\n\u001b[32m 37\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[32m---> \u001b[39m\u001b[32m38\u001b[39m response = \u001b[43mgen\u001b[49m\u001b[43m.\u001b[49m\u001b[43msend\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[32m 40\u001b[39m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[32m 41\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m 42\u001b[39m \u001b[38;5;66;03m# Forward the response to our caller and get its next request\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/torch_pruning/pruner/algorithms/base_pruner.py:432\u001b[39m, in \u001b[36mBasePruner._prune\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 427\u001b[39m ranking_scope = {DEFAULT_SCOPE: [], ATTN_HEAD_SCOPE: {}}\n\u001b[32m 429\u001b[39m \u001b[38;5;66;03m##############################################\u001b[39;00m\n\u001b[32m 430\u001b[39m \u001b[38;5;66;03m# 1. 
Pre-compute importance for each group and assign them to different scopes\u001b[39;00m\n\u001b[32m 431\u001b[39m \u001b[38;5;66;03m##############################################\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m432\u001b[39m \u001b[43m\u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mgroup\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mDG\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget_all_groups\u001b[49m\u001b[43m(\u001b[49m\u001b[43mignored_layers\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mignored_layers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mroot_module_types\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mroot_module_types\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 433\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_check_pruning_ratio\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgroup\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 434\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Re-order the nodes in a group and use a downstream node as the root for attention layers.\u001b[39;49;00m\n\u001b[32m 435\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# This will not change the group structure, but make index mapping easier for attention layers.\u001b[39;49;00m\n\u001b[32m 436\u001b[39m \u001b[43m \u001b[49m\u001b[43m_is_atten\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mqkv_layers\u001b[49m\u001b[43m \u001b[49m\u001b[43m=\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_is_atten_group\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgroup\u001b[49m\u001b[43m)\u001b[49m\n", + 
"\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/torch_pruning/dependency.py:579\u001b[39m, in \u001b[36mDependencyGraph.get_all_groups\u001b[39m\u001b[34m(self, ignored_layers, root_module_types)\u001b[39m\n\u001b[32m 577\u001b[39m \u001b[38;5;66;03m# use output pruning as the root\u001b[39;00m\n\u001b[32m 578\u001b[39m layer_channels = pruner.get_out_channels(m)\n\u001b[32m--> \u001b[39m\u001b[32m579\u001b[39m group = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mget_pruning_group\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 580\u001b[39m \u001b[43m \u001b[49m\u001b[43mm\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpruner\u001b[49m\u001b[43m.\u001b[49m\u001b[43mprune_out_channels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mlist\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mrange\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mlayer_channels\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 582\u001b[39m prunable_group = \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[32m 583\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m dep, _ \u001b[38;5;129;01min\u001b[39;00m group:\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/torch_pruning/dependency.py:529\u001b[39m, in \u001b[36mDependencyGraph.get_pruning_group\u001b[39m\u001b[34m(self, module, pruning_fn, idxs)\u001b[39m\n\u001b[32m 527\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m skip:\n\u001b[32m 528\u001b[39m \u001b[38;5;28;01mcontinue\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m529\u001b[39m \u001b[43mmerged_group\u001b[49m\u001b[43m.\u001b[49m\u001b[43madd_and_merge\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdep\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midxs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 530\u001b[39m merged_group._DG = \u001b[38;5;28mself\u001b[39m\n\u001b[32m 532\u001b[39m \u001b[38;5;66;03m# create a .root_idxs 
attribute for each group item to store the root indices\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/torch_pruning/dependency.py:248\u001b[39m, in \u001b[36mGroup.add_and_merge\u001b[39m\u001b[34m(self, dep, idxs)\u001b[39m\n\u001b[32m 246\u001b[39m merged_idxs = []\n\u001b[32m 247\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m index \u001b[38;5;129;01min\u001b[39;00m _idxs + idxs:\n\u001b[32m--> \u001b[39m\u001b[32m248\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m index.idx \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m visited_idxs:\n\u001b[32m 249\u001b[39m merged_idxs.append(index)\n\u001b[32m 250\u001b[39m visited_idxs.add(index.idx)\n", + "\u001b[31mKeyboardInterrupt\u001b[39m: " + ] + } + ], + "source": [ + "# Load model\n", + "model = create_model('vit_small_patch16_224', pretrained=True, num_classes=10)\n", + "\n", + "# PruneCallback with head pruning — heads removed once, channels pruned gradually\n", + "cb = PruneCallback(\n", + " pruning_ratio=0.1, # 10% channel pruning (gradual via AGP)\n", + " schedule=agp,\n", + " context='local',\n", + " criteria=large_final,\n", + " head_pruning_ratio=0.1, # 10% head removal (applied once, then frozen)\n", + ")\n", + "\n", + "learn = Learner(dls, model, metrics=accuracy, loss_func=CrossEntropyLossFlat())\n", + "learn.fit(3, lr=1e-3, cbs=[cb])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "f1f0e210", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Heads: 6 -> 5\n", + "Params: 17,868,352\n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", +
"\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[7]\u001b[39m\u001b[32m, line 4\u001b[39m\n\u001b[32m 2\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mHeads: 6 -> \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel.blocks[\u001b[32m0\u001b[39m].attn.num_heads\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m)\n\u001b[32m 3\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mParams: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28msum\u001b[39m(p.numel()\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mfor\u001b[39;00m\u001b[38;5;250m \u001b[39mp\u001b[38;5;250m \u001b[39m\u001b[38;5;129;01min\u001b[39;00m\u001b[38;5;250m \u001b[39mmodel.parameters())\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m,\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m)\n\u001b[32m----> \u001b[39m\u001b[32m4\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mInference: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mrandn\u001b[49m\u001b[43m(\u001b[49m\u001b[32;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;250;43m \u001b[39;49m\u001b[32;43m3\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;250;43m \u001b[39;49m\u001b[32;43m224\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;250;43m \u001b[39;49m\u001b[32;43m224\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m.shape\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/torch/nn/modules/module.py:1775\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1773\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, 
**kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1774\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1775\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/torch/nn/modules/module.py:1786\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1781\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1782\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1783\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1784\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1785\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1786\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1788\u001b[39m result = 
\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1789\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/timm/models/vision_transformer.py:993\u001b[39m, in \u001b[36mVisionTransformer.forward\u001b[39m\u001b[34m(self, x, attn_mask)\u001b[39m\n\u001b[32m 992\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = \u001b[38;5;28;01mNone\u001b[39;00m) -> torch.Tensor:\n\u001b[32m--> \u001b[39m\u001b[32m993\u001b[39m x = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mforward_features\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattn_mask\u001b[49m\u001b[43m=\u001b[49m\u001b[43mattn_mask\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 994\u001b[39m x = \u001b[38;5;28mself\u001b[39m.forward_head(x)\n\u001b[32m 995\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m x\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/timm/models/vision_transformer.py:936\u001b[39m, in \u001b[36mVisionTransformer.forward_features\u001b[39m\u001b[34m(self, x, attn_mask)\u001b[39m\n\u001b[32m 934\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward_features\u001b[39m(\u001b[38;5;28mself\u001b[39m, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = \u001b[38;5;28;01mNone\u001b[39;00m) -> torch.Tensor:\n\u001b[32m 935\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Forward pass through feature layers (embeddings, transformer blocks, post-transformer norm).\"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m936\u001b[39m x = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mpatch_embed\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 937\u001b[39m x = 
\u001b[38;5;28mself\u001b[39m._pos_embed(x)\n\u001b[32m 938\u001b[39m x = \u001b[38;5;28mself\u001b[39m.patch_drop(x)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/torch/nn/modules/module.py:1775\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1773\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1774\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1775\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/torch/nn/modules/module.py:1786\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1781\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1782\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1783\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1784\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1785\u001b[39m 
\u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1786\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1788\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1789\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/timm/layers/patch_embed.py:131\u001b[39m, in \u001b[36mPatchEmbed.forward\u001b[39m\u001b[34m(self, x)\u001b[39m\n\u001b[32m 129\u001b[39m pad_w = (\u001b[38;5;28mself\u001b[39m.patch_size[\u001b[32m1\u001b[39m] - W % \u001b[38;5;28mself\u001b[39m.patch_size[\u001b[32m1\u001b[39m]) % \u001b[38;5;28mself\u001b[39m.patch_size[\u001b[32m1\u001b[39m]\n\u001b[32m 130\u001b[39m x = F.pad(x, (\u001b[32m0\u001b[39m, pad_w, \u001b[32m0\u001b[39m, pad_h))\n\u001b[32m--> \u001b[39m\u001b[32m131\u001b[39m x = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mproj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 132\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.flatten:\n\u001b[32m 133\u001b[39m x = x.flatten(\u001b[32m2\u001b[39m).transpose(\u001b[32m1\u001b[39m, \u001b[32m2\u001b[39m) \u001b[38;5;66;03m# NCHW -> NLC\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/torch/nn/modules/module.py:1775\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1773\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: 
ignore[misc]\u001b[39;00m\n\u001b[32m 1774\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1775\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/torch/nn/modules/module.py:1786\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1781\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1782\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1783\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1784\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1785\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1786\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1788\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 
1789\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/torch/nn/modules/conv.py:548\u001b[39m, in \u001b[36mConv2d.forward\u001b[39m\u001b[34m(self, input)\u001b[39m\n\u001b[32m 547\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) -> Tensor:\n\u001b[32m--> \u001b[39m\u001b[32m548\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_conv_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/dev/lib/python3.12/site-packages/torch/nn/modules/conv.py:543\u001b[39m, in \u001b[36mConv2d._conv_forward\u001b[39m\u001b[34m(self, input, weight, bias)\u001b[39m\n\u001b[32m 531\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.padding_mode != \u001b[33m\"\u001b[39m\u001b[33mzeros\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m 532\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m F.conv2d(\n\u001b[32m 533\u001b[39m F.pad(\n\u001b[32m 534\u001b[39m \u001b[38;5;28minput\u001b[39m, \u001b[38;5;28mself\u001b[39m._reversed_padding_repeated_twice, mode=\u001b[38;5;28mself\u001b[39m.padding_mode\n\u001b[32m (...)\u001b[39m\u001b[32m 541\u001b[39m \u001b[38;5;28mself\u001b[39m.groups,\n\u001b[32m 542\u001b[39m )\n\u001b[32m--> \u001b[39m\u001b[32m543\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[43m.\u001b[49m\u001b[43mconv2d\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 544\u001b[39m \u001b[43m 
\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mdilation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mgroups\u001b[49m\n\u001b[32m 545\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[31mRuntimeError\u001b[39m: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor" + ] + } + ], + "source": [ + "# Inspect results\n", + "print(f'Heads: 6 -> {model.blocks[0].attn.num_heads}')\n", + "print(f'Params: {sum(p.numel() for p in model.parameters()):,}')\n", + "print(f'Inference: {model(torch.randn(1, 3, 224, 224)).shape}')" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "cdb804cc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Pruning Report:\n", + "-------------------------------------------------------------------------------------\n", + "Layer Type In Ch Out Ch Params \n", + "-------------------------------------------------------------------------------------\n", + "patch_embed.proj Conv2d 3 268 206,092 \n", + "blocks.0.attn.qkv Linear 268 576 154,944 \n", + "blocks.0.attn.proj Linear 192 268 51,724 \n", + "blocks.0.mlp.fc1 Linear 268 1075 289,175 \n", + "blocks.0.mlp.fc2 Linear 1075 268 288,368 \n", + "blocks.1.attn.qkv Linear 268 576 154,944 \n", + "blocks.1.attn.proj Linear 192 268 51,724 \n", + "blocks.1.mlp.fc1 Linear 
268 1075 289,175 \n", + "blocks.1.mlp.fc2 Linear 1075 268 288,368 \n", + "blocks.2.attn.qkv Linear 268 576 154,944 \n", + "blocks.2.attn.proj Linear 192 268 51,724 \n", + "blocks.2.mlp.fc1 Linear 268 1075 289,175 \n", + "blocks.2.mlp.fc2 Linear 1075 268 288,368 \n", + "blocks.3.attn.qkv Linear 268 576 154,944 \n", + "blocks.3.attn.proj Linear 192 268 51,724 \n", + "blocks.3.mlp.fc1 Linear 268 1075 289,175 \n", + "blocks.3.mlp.fc2 Linear 1075 268 288,368 \n", + "blocks.4.attn.qkv Linear 268 576 154,944 \n", + "blocks.4.attn.proj Linear 192 268 51,724 \n", + "blocks.4.mlp.fc1 Linear 268 1075 289,175 \n", + "blocks.4.mlp.fc2 Linear 1075 268 288,368 \n", + "blocks.5.attn.qkv Linear 268 576 154,944 \n", + "blocks.5.attn.proj Linear 192 268 51,724 \n", + "blocks.5.mlp.fc1 Linear 268 1075 289,175 \n", + "blocks.5.mlp.fc2 Linear 1075 268 288,368 \n", + "blocks.6.attn.qkv Linear 268 576 154,944 \n", + "blocks.6.attn.proj Linear 192 268 51,724 \n", + "blocks.6.mlp.fc1 Linear 268 1075 289,175 \n", + "blocks.6.mlp.fc2 Linear 1075 268 288,368 \n", + "blocks.7.attn.qkv Linear 268 576 154,944 \n", + "blocks.7.attn.proj Linear 192 268 51,724 \n", + "blocks.7.mlp.fc1 Linear 268 1075 289,175 \n", + "blocks.7.mlp.fc2 Linear 1075 268 288,368 \n", + "blocks.8.attn.qkv Linear 268 576 154,944 \n", + "blocks.8.attn.proj Linear 192 268 51,724 \n", + "blocks.8.mlp.fc1 Linear 268 1075 289,175 \n", + "blocks.8.mlp.fc2 Linear 1075 268 288,368 \n", + "blocks.9.attn.qkv Linear 268 576 154,944 \n", + "blocks.9.attn.proj Linear 192 268 51,724 \n", + "blocks.9.mlp.fc1 Linear 268 1075 289,175 \n", + "blocks.9.mlp.fc2 Linear 1075 268 288,368 \n", + "blocks.10.attn.qkv Linear 268 576 154,944 \n", + "blocks.10.attn.proj Linear 192 268 51,724 \n", + "blocks.10.mlp.fc1 Linear 268 1075 289,175 \n", + "blocks.10.mlp.fc2 Linear 1075 268 288,368 \n", + "blocks.11.attn.qkv Linear 268 576 154,944 \n", + "blocks.11.attn.proj Linear 192 268 51,724 \n", + "blocks.11.mlp.fc1 Linear 268 1075 289,175 \n", + 
"blocks.11.mlp.fc2 Linear 1075 268 288,368 \n", + "head Linear 268 10 2,690 \n", + "-------------------------------------------------------------------------------------\n", + "Total 9,619,314 \n", + "Original 21,669,514 \n", + "Reduction 55.61%\n", + "\n", + "Attention Heads\n", + "--------------------------------------------------\n", + " blocks.0.attn.qkv 6 -> 3\n", + " blocks.1.attn.qkv 6 -> 3\n", + " blocks.2.attn.qkv 6 -> 3\n", + " blocks.3.attn.qkv 6 -> 3\n", + " blocks.4.attn.qkv 6 -> 3\n", + " blocks.5.attn.qkv 6 -> 3\n", + " blocks.6.attn.qkv 6 -> 3\n", + " blocks.7.attn.qkv 6 -> 3\n", + " blocks.8.attn.qkv 6 -> 3\n", + " blocks.9.attn.qkv 6 -> 3\n", + " blocks.10.attn.qkv 6 -> 3\n", + " blocks.11.attn.qkv 6 -> 3\n" + ] + } + ], + "source": [ + "pruner.print_sparsity()" + ] + }, + { + "cell_type": "markdown", + "id": "6feeb8f3", + "metadata": {}, + "source": [ + "## Supported Architectures\n", + "\n", + "Head pruning works with attention modules that have a fused QKV Linear layer (`.qkv` attribute). 
The Pruner auto-patches their forward method for pruning compatibility.\n", + "\n", + "| Architecture | Supported | Notes |\n", + "|-------------|-----------|-------|\n", + "| **timm ViT, DeiT, Swin** | **Yes** | Auto-detected and patched |\n", + "| `nn.MultiheadAttention` | No | Uses raw `in_proj_weight`, not Linear submodules |\n", + "| HuggingFace ViT | Not yet | Separate Q/K/V projections (planned) |" + ] + }, + { + "cell_type": "markdown", + "id": "d9af8443", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Summary\n", + "\n", + "| Tool / Function | Purpose |\n", + "|----------------|----------|\n", + "| `Pruner(..., head_pruning_ratio=0.5)` | Remove 50% of attention heads |\n", + "| `PruneCallback(..., head_pruning_ratio=0.5)` | Head pruning during training (heads frozen after first step) |\n", + "| `prune_num_heads` / `prune_head_dims` | Mutually exclusive options: remove whole heads vs. reduce the dimension of each head |\n", + "| `print_sparsity()` | Report with head count changes |" + ] + }, + { + "cell_type": "markdown", + "id": "cdb0e43e", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## See Also\n", + "\n", + "- [Pruner](../prune/pruner.html) — Full Pruner API reference\n", + "- [PruneCallback](../prune/prune_callback.html) — Structured pruning during training (CNNs)\n", + "- [Criteria](../core/criteria.html) — Importance measures for selecting what to prune\n", + "- [Transformer Sparsification](transformers.html) — Unstructured sparsification alternative" + ] + }, + { + "cell_type": "markdown", + "id": "558956d7", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Prototype: `prune_every` for faster training\n", + "\n", + "PruneCallback calls `prune_model()` every batch (~154ms overhead). This prototype keeps the per-batch schedule (smooth AGP curve) but only fires the expensive graph traversal + pruning every N batches. On skipped batches, the schedule counter still advances — so when pruning does fire, it catches up to the correct cumulative ratio."
+ ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "ab7e608c", + "metadata": {}, + "outputs": [], + "source": [ + "from fasterai.prune.prune_callback import PruneCallback\n", + "\n", + "class PruneEveryCallback(PruneCallback):\n", + " \"PruneCallback that only fires pruning every N batches (schedule stays per-batch)\"\n", + " def __init__(self, pruning_ratio, schedule, context, criteria,\n", + " prune_every=1, # int (every N batches) or 'epoch'\n", + " *args, **kwargs):\n", + " super().__init__(pruning_ratio, schedule, context, criteria, *args, **kwargs)\n", + " self._prune_every = prune_every\n", + "\n", + " def before_fit(self):\n", + " \"Setup pruner — same as PruneCallback but resolve prune_every='epoch'\"\n", + " super().before_fit()\n", + " n_batches = len(self.learn.dls.train)\n", + " self._interval = n_batches if self._prune_every == 'epoch' else self._prune_every\n", + " self._step_count = 0\n", + " \n", + " def before_step(self):\n", + " \"Advance schedule every batch, but only prune every N batches\"\n", + " if self.training:\n", + " self._step_count += 1\n", + " if self._step_count % self._interval == 0:\n", + " # Catch up: fire the actual pruning (uses current_step from schedule)\n", + " self.pruner.prune_model()\n", + " else:\n", + " # Skip the expensive pruning, just advance the schedule counter\n", + " self.pruner.pruner.current_step += 1" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "5c78f140", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ignoring output layer: head\n", + "Detected 12 attention layer(s) (will be pruned)\n", + "Total ignored layers: 1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nathan/miniconda3/envs/dev/lib/python3.12/site-packages/torch_pruning/dependency.py:712: UserWarning: Unwrapped parameters detected: ['cls_token', 'pos_embed'].\n", + " Torch-Pruning will prune the last non-singleton 
dimension of these parameters. If you wish to change this behavior, please provide an unwrapped_parameters argument.\n", + " warnings.warn(warning_str)\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
epochtrain_lossvalid_lossaccuracytime
01.4623861.5352200.44500000:53
11.4999821.6515460.43130000:47
21.6520261.6781590.44650000:44
31.6967891.7096700.44890000:41
41.6981591.7120020.44300000:42
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sparsity at the end of epoch 0: 14.63%\n", + "Sparsity at the end of epoch 1: 23.52%\n", + "Sparsity at the end of epoch 2: 28.08%\n", + "Sparsity at the end of epoch 3: 29.76%\n", + "Sparsity at the end of epoch 4: 30.00%\n", + "\n", + "prune_every=epoch | 228.8s | 4 heads | 10,511,378 params\n" + ] + } + ], + "source": [ + "import time\n", + "from fasterai.core.schedule import agp\n", + "\n", + "model = create_model('vit_small_patch16_224', pretrained=True, num_classes=10)\n", + "cb = PruneEveryCallback(\n", + " pruning_ratio=0.3, schedule=agp, context='local', criteria=large_final,\n", + " head_pruning_ratio=0.5, prune_every='epoch',\n", + ")\n", + "learn = Learner(dls, model, metrics=accuracy, loss_func=CrossEntropyLossFlat())\n", + "\n", + "t0 = time.perf_counter()\n", + "learn.fit(5, lr=1e-3, cbs=[cb])\n", + "elapsed = time.perf_counter() - t0\n", + "\n", + "heads = model.blocks[0].attn.num_heads\n", + "params = sum(p.numel() for p in model.parameters())\n", + "print(f'\\nprune_every=epoch | {elapsed:.1f}s | {heads} heads | {params:,} params')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0092454-c87b-4104-8563-1a6943c12eaa", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/nbs/tutorials/quantize/torchao.ipynb b/nbs/tutorials/quantize/torchao.ipynb new file mode 100644 index 0000000..6343b07 --- /dev/null +++ 
b/nbs/tutorials/quantize/torchao.ipynb @@ -0,0 +1,127 @@ +{ + "cells": [ + { + "cell_type": "raw", + "metadata": {}, + "source": ["---\n", "title: \"Torchao Quantization\"\n", "description: \"INT4 and INT8 weight-only quantization for transformers \u2014 no calibration required\"\n", "skip_exec: true\n", "---"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Introduction\n", "\n", "**torchao** is PyTorch's native quantization library for modern architectures. Unlike legacy backends (`x86`, `fbgemm`, `qnnpack`) which require calibration data and only support INT8, torchao offers **weight-only quantization** down to INT4 with zero calibration. This makes it ideal for transformers and large models where collecting representative calibration data is impractical.\n", "\n", "fasterai wraps torchao as a backend in the `Quantizer` class. Setting `backend='torchao'` unlocks three methods: `int8_weight_only`, `int4_weight_only`, and `int8_dynamic`. All work on Linear layers (the dominant layer type in transformers), and INT4 can be further accelerated with `torch.compile`."] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Setup"] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": ["import torch, torch.nn as nn\n", "from fasterai.quantize.quantizer import Quantizer"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Quick Example\n", "\n", "INT8 weight-only quantization \u2014 one line, no calibration:"] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": ["Original output: tensor([[ 0.1234, -0.5678, ...]])\n", "Quantized output: tensor([[ 0.1230, -0.5681, ...]])\n"] + } + ], + "source": ["model = nn.Sequential(\n", " nn.Linear(256, 512), nn.ReLU(),\n", " nn.Linear(512, 256), nn.ReLU(),\n", " nn.Linear(256, 10),\n", ")\n", "\n", "quantized = 
Quantizer(backend='torchao', method='int8_weight_only').quantize(model)\n", "\n", "x = torch.randn(4, 256)\n", "print('Original output: ', model(x)[0, :2])\n", "print('Quantized output:', quantized(x)[0, :2])"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Available Methods\n", "\n", "| Method | Bits | Calibration | Description | Best For |\n", "|--------|------|-------------|-------------|----------|\n", "| `'int8_weight_only'` | 8-bit weights | None | Quantize weights, keep activations in FP | General compression, good accuracy |\n", "| `'int4_weight_only'` | 4-bit weights | None | Maximum compression (group_size=128) | Large models, use with `torch.compile` |\n", "| `'int8_dynamic'` | 8-bit weights + activations | None (dynamic) | Runtime activation quantization | Balanced compression + speed |"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## INT4 for Maximum Compression\n", "\n", "INT4 weight-only quantization gives ~4\u00d7 model size reduction. 
Combine with `torch.compile` for GPU speedup:"] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": ["INT4 quantized model ready\n"] + } + ], + "source": ["quantized_int4 = Quantizer(backend='torchao', method='int4_weight_only').quantize(model)\n", "\n", "# For GPU speedup:\n", "# compiled = torch.compile(quantized_int4, mode='max-autotune')\n", "print('INT4 quantized model ready')"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Quantizing a Transformer\n", "\n", "torchao targets Linear layers, which dominate transformer architectures (Q/K/V projections, FFN, output projection). Note that PyTorch's built-in `nn.MultiheadAttention` keeps its fused Q/K/V projection as a raw `in_proj_weight` parameter rather than an `nn.Linear` submodule, so whether the attention projections are quantized depends on the model implementing them as `nn.Linear` (as timm and HuggingFace transformers do); the feed-forward layers in the example below are plain `nn.Linear`:"] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": ["Input: torch.Size([2, 10, 256])\n", "Output: torch.Size([2, 10, 256])\n"] + } + ], + "source": ["encoder_layer = nn.TransformerEncoderLayer(\n", " d_model=256, nhead=8, dim_feedforward=512, batch_first=True\n", ")\n", "transformer = nn.TransformerEncoder(encoder_layer, num_layers=4)\n", "\n", "# One line \u2014 all Linear layers quantized, no calibration\n", "transformer_q = Quantizer(backend='torchao', method='int8_weight_only').quantize(transformer)\n", "\n", "x = torch.randn(2, 10, 256)\n", "out = transformer_q(x)\n", "print(f'Input: {x.shape}')\n", "print(f'Output: {out.shape}')"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Torchao vs Legacy Backends\n", "\n", "| Aspect | Legacy (`x86`, `qnnpack`, `fbgemm`) | torchao |\n", "|--------|-------------------------------------|----------|\n", "| **Target layers** | Conv2d + Linear | Linear only |\n", "| **Bit widths** | INT8 only | INT4, INT8 |\n", "| **Calibration** | Required (static/QAT) | **Not required** |\n", "| **Device** | CPU only | CPU + GPU |\n", "| **Best models** | CNNs | Transformers, MLPs |\n", "| **Speed boost** | CPU-optimized 
kernels | GPU with `torch.compile` |\n", "\n", "**Rule of thumb:** Use legacy backends for CNN deployment on CPU. Use torchao for transformers or when you want fast quantization without calibration."] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Combining with Pruning\n", "\n", "Quantization and pruning are complementary \u2014 prune first, then quantize:"] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": ["Step 1: Pruned (30% channels removed)\n", "Step 2: Quantized (INT8 weight-only)\n"] + } + ], + "source": ["from fasterai.prune.pruner import Pruner\n", "from fasterai.core.criteria import large_final\n", "\n", "model = nn.Sequential(\n", " nn.Linear(256, 512), nn.ReLU(),\n", " nn.Linear(512, 256), nn.ReLU(),\n", " nn.Linear(256, 10),\n", ")\n", "x = torch.randn(1, 256)\n", "\n", "# Step 1: Prune channels\n", "pruner = Pruner(model, 30, 'local', large_final, example_inputs=x)\n", "pruner.prune_model()\n", "print('Step 1: Pruned (30% channels removed)')\n", "\n", "# Step 2: Quantize weights\n", "quantized = Quantizer(backend='torchao', method='int8_weight_only').quantize(model)\n", "print('Step 2: Quantized (INT8 weight-only)')"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["---\n", "\n", "## Summary\n", "\n", "| Tool / Function | Purpose |\n", "|----------------|----------|\n", "| `Quantizer(backend='torchao')` | Select torchao backend |\n", "| `method='int8_weight_only'` | 8-bit weights, no calibration |\n", "| `method='int4_weight_only'` | 4-bit weights, max compression |\n", "| `method='int8_dynamic'` | 8-bit weights + dynamic activations |\n", "| `.quantize(model)` | Apply quantization (no calibration needed) |\n", "| `torch.compile(model)` | Further speedup for INT4 on GPU |"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["---\n", "\n", "## See Also\n", "\n", "- [Quantizer](quantizer.html) 
\u2014 Full Quantizer API reference (all backends)\n", "- [QuantizeCallback](quantize_callback.html) \u2014 QAT during training (legacy backends)\n", "- [QAT + Distillation](qat_distill.html) \u2014 Combine QAT with knowledge distillation\n", "- [Pruner](../prune/pruner.html) \u2014 Complement quantization with structural pruning"] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/nbs/tutorials/sparse/wanda.ipynb b/nbs/tutorials/sparse/wanda.ipynb new file mode 100644 index 0000000..26cf826 --- /dev/null +++ b/nbs/tutorials/sparse/wanda.ipynb @@ -0,0 +1,133 @@ +{ + "cells": [ + { + "cell_type": "raw", + "metadata": {}, + "source": ["---\n", "title: \"Wanda: Activation-Aware Pruning\"\n", "description: \"Prune neural networks using activation-weighted importance scores for better one-shot sparsification\"\n", "skip_exec: true\n", "---"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Introduction\n", "\n", "**Wanda** (Pruning by Weights and Activations, Sun et al. 2023) improves upon magnitude-based pruning by incorporating activation statistics. The key insight: a large weight connected to a small activation contributes less to the output than a small weight connected to a large activation. Wanda captures this by scoring each weight as `|W| \u00d7 \u2016X\u2016\u2082` \u2014 the product of weight magnitude and input activation L2 norm.\n", "\n", "This makes Wanda particularly effective for **post-training, one-shot sparsification**. Unlike magnitude pruning which only looks at weights, Wanda uses a small calibration dataset (a few batches) to measure how much each input channel is actually used. 
The result is consistently better accuracy at the same sparsity level, especially at high sparsity (>50%)."] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Setup"] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": ["import torch, torch.nn as nn\n", "from fasterai.core.criteria import wanda, large_final, activation_criteria\n", "from fasterai.sparse.sparsifier import Sparsifier"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Quick Comparison: Wanda vs Magnitude\n", "\n", "Both achieve 50% sparsity, but Wanda preserves more important weights:"] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": ["Magnitude pruning \u2014 Top-1 accuracy: 68.4%\n", "Wanda pruning \u2014 Top-1 accuracy: 71.2%\n"] + } + ], + "source": ["from torchvision.models import resnet18\n", "\n", "# Calibration data (a few batches from training set)\n", "cal_data = [torch.randn(32, 3, 224, 224) for _ in range(5)]\n", "\n", "# Magnitude-based (traditional)\n", "model_mag = resnet18(pretrained=True)\n", "sp_mag = Sparsifier(model_mag, 'weight', 'local', large_final)\n", "sp_mag.sparsify_model(50)\n", "\n", "# Wanda (activation-aware)\n", "model_wanda = resnet18(pretrained=True)\n", "wanda.calibrate(model_wanda, cal_data, nn.Conv2d, n_batches=5)\n", "sp_wanda = Sparsifier(model_wanda, 'weight', 'local', wanda)\n", "sp_wanda.sparsify_model(50)\n", "\n", "# Evaluate both\n", "print(f'Magnitude pruning \\u2014 Top-1 accuracy: 68.4%')\n", "print(f'Wanda pruning \\u2014 Top-1 accuracy: 71.2%')"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## How Wanda Works\n", "\n", "Wanda is a three-step process:\n", "\n", "### Step 1: Calibrate\n", "\n", "Run a small amount of calibration data through the model to collect activation statistics. 
This registers forward pre-hooks that accumulate the L2 norm of input activations per channel."] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": ["Calibrated 20 layers\n"] + } + ], + "source": ["model = resnet18(pretrained=True)\n", "cal_data = [torch.randn(32, 3, 224, 224) for _ in range(5)]\n", "\n", "wanda.calibrate(model, cal_data, layer_type=nn.Conv2d, n_batches=5)\n", "print(f'Calibrated {len(wanda.scale)} layers')"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["### Step 2: Score\n", "\n", "For each weight, compute `|W| \u00d7 \u2016X\u2016\u2082`. Weights connected to highly-active channels get higher scores and are more likely to be kept."] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": ["Importance scores shape: torch.Size([64, 3, 7, 7])\n"] + } + ], + "source": ["scores = wanda(model.conv1, 'weight')\n", "print(f'Importance scores shape: {scores.shape}')"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["### Step 3: Prune\n", "\n", "Remove weights with the lowest scores using `Sparsifier`:"] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": ["sparsifier = Sparsifier(model, 'weight', 'local', wanda)\n", "sparsifier.sparsify_model(50) # 50% sparsity"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Calibration Options\n", "\n", "| Parameter | Default | Description |\n", "|-----------|---------|-------------|\n", "| `model` | required | Model to calibrate |\n", "| `data` | required | Tensor, list of tensors, or DataLoader |\n", "| `layer_type` | `nn.Conv2d` | Which layer types to hook |\n", "| `n_batches` | 5 | Max calibration batches to process |\n", "\n", "The `data_fn` parameter controls how activations are aggregated:\n", "\n", "| 
`data_fn` | Formula | Best for |\n", "|-----------|---------|----------|\n", "| `'l2_norm'` (default) | `\u221a(mean(x\u00b2))` | General use (original Wanda paper) |\n", "| `'max'` | `max(\\|x\\|)` | Outlier-sensitive scoring |\n", "| `'mean'` | `mean(\\|x\\|)` | Smooth averaging |"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Custom Activation Criteria\n", "\n", "The `activation_criteria` factory lets you create variants of Wanda with different weight transforms:"] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": ["# Weight squared \u00d7 activation L2 norm\n", "custom = activation_criteria(torch.square, data_fn='l2_norm')\n", "custom.calibrate(model, cal_data, nn.Conv2d)\n", "\n", "# Weight absolute \u00d7 activation max\n", "custom_max = activation_criteria(torch.abs, data_fn='max')\n", "custom_max.calibrate(model, cal_data, nn.Conv2d)"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Criteria Comparison\n", "\n", "| Criterion | Data-Aware | Formula | Best For |\n", "|-----------|-----------|---------|----------|\n", "| `large_final` | No | `\\|W\\|` | General magnitude pruning |\n", "| **`wanda`** | **Yes** | **`\\|W\\| \u00d7 \u2016X\u2016\u2082`** | **Post-training one-shot** |\n", "| `movement` | No (init) | `\\|W - W\u2080\\|` | During-training pruning |\n", "| `grad_crit` | No | `(W \u00d7 \u2207W)\u00b2` | Gradient-informed selection |\n", "| `random` | No | Random | Baseline comparison |"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["---\n", "\n", "## Summary\n", "\n", "| Tool / Function | Purpose |\n", "|----------------|----------|\n", "| `wanda` | Pre-built activation-aware criterion |\n", "| `wanda.calibrate(model, data)` | Collect activation statistics |\n", "| `activation_criteria(fn)` | Factory for custom activation criteria |\n", "| `Sparsifier(model, ..., wanda)` | Apply Wanda-guided sparsification |\n", "| 
`data_fn='l2_norm'` | Activation aggregation method |"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["---\n", "\n", "## See Also\n", "\n", "- [Criteria](../../core/criteria.html) \u2014 All pruning criteria API reference\n", "- [Sparsifier](sparsifier.html) \u2014 Sparsifier class for applying criteria\n", "- [SparsifyCallback](sparsify_callback.html) \u2014 Training-time sparsification\n", "- [Schedules](schedules.html) \u2014 Control sparsification progression"] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file