
Commit c7356e7

Implement L2NormHook
Signed-off-by: Daniel Korzekwa <[email protected]>
1 parent 0e69d34 commit c7356e7

File tree

3 files changed, +174 −64 lines changed

modelopt/torch/nas/plugins/hooks.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Forward hooks for activation-based importance estimation in NAS plugins."""
+
+from abc import ABC, abstractmethod
+
+import torch
+from megatron.core.tensor_parallel import gather_from_tensor_model_parallel_region
+from torch import nn
+
+
+class ForwardHook(ABC):
+    """Base class for PyTorch forward hooks.
+
+    This follows the PyTorch forward hook API where the second
+    parameter is 'args' (a tuple of positional arguments passed to forward()).
+
+    Usage:
+        hook = MyHook()
+        module.register_forward_hook(hook)
+    """
+
+    @abstractmethod
+    def __call__(
+        self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor
+    ) -> None:
+        """Forward hook that is called after the module's forward pass.
+
+        Args:
+            module: The module this hook is registered on
+            args: Tuple of positional arguments passed to module.forward()
+            output: The output from module.forward()
+
+        Returns:
+            None (does not modify the output)
+        """
+        ...
+
+
+class L2NormHook(ForwardHook):
+    """Hook for accumulating activation statistics for importance estimation.
+
+    Activations are computed as mean over seq_len and then squared and summed over batch_size.
+    In the accumulate() method we take the square root of the sum to get the L2 norm.
+
+    Args:
+        max_size: Optional maximum expected size to validate against (skips if mismatch).
+            Useful for skipping non-max subnets during profiling.
+    """
+
+    def __init__(self, max_size: int | None = None):
+        """Initialize the L2NormHook."""
+        self.max_size = max_size
+        self._activations: torch.Tensor | None = None
+
+    def __call__(
+        self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor
+    ) -> None:
+        """Accumulate activation statistics from the forward pass."""
+        # Gather input [seq_len, batch_size, hidden_size] over all TP regions
+        # NOTE: This is not used at the moment since we restrict to TP=1
+        input_tensor = gather_from_tensor_model_parallel_region(args[0]).detach()
+
+        # Dont aggregate activations from non-max subnets (e.g. from profiling)
+        if self.max_size is not None and input_tensor.shape[-1] != self.max_size:
+            return
+
+        input_tensor = input_tensor.to(torch.float32)  # use full precision to avoid overflow
+        activations = input_tensor.abs().mean(dim=0)  # [batch_size, hidden_size]
+        activations = activations.pow(2).sum(dim=0)  # [hidden_size]
+
+        if self._activations is None:
+            self._activations = activations
+        else:
+            self._activations += activations
+
+    def accumulate(self) -> torch.Tensor:
+        """Return the accumulated L2 norm of activations.
+
+        Returns:
+            Tensor of accumulated scores, one per channel
+
+        Raises:
+            AssertionError: If no activations have been collected yet
+        """
+        assert self._activations is not None, "No activations collected for importance estimation."
+        # Convert squared sum to L2 norm
+        return self._activations.pow(0.5)
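For context, a minimal usage sketch of the new hook (not part of this commit). Because __call__ routes the input through gather_from_tensor_model_parallel_region, it assumes Megatron's tensor-parallel state is already initialized (the code currently targets TP=1); model, calib_batches, and the layer path are placeholders, and the import path is inferred from the relative import in megatron.py below.

    from modelopt.torch.nas.plugins.hooks import L2NormHook  # path inferred, not confirmed by this diff

    # Score the input channels of linear_fc2, i.e. one score per ffn_hidden_size channel.
    hook = L2NormHook(max_size=model.config.ffn_hidden_size)
    handle = model.decoder.layers[0].mlp.linear_fc2.register_forward_hook(hook)

    for batch in calib_batches:  # placeholder: a few calibration forward passes
        model(**batch)

    scores = hook.accumulate()  # per-channel L2 norm, shape [ffn_hidden_size]
    handle.remove()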

modelopt/torch/nas/plugins/megatron.py

Lines changed: 28 additions & 64 deletions
@@ -75,6 +75,7 @@
 from ..registry import DMRegistry
 from ..search_space import SampleFunc
 from ..traced_hp import TracedHp
+from .hooks import L2NormHook

 SUPPORTED_MODELS = {GPTModel: "megatron.core.models.gpt.GPTModel"}

@@ -211,37 +212,17 @@ def _setup(self):
         # can be discarded.
         # This limitation might be fixed in OMNIML-180 (Flexible Importance Estimator)
         # where we separate the importance estimation from the dynamic module.
-        self._register_temp_attribute("_activations", None)
-        self.hook_handle = self.linear_fc2.register_forward_hook(self._linear_fc2_forward_hook)
+        max_ffn_size = self.get_hparam("ffn_hidden_size").max
+        assert isinstance(max_ffn_size, int), "ffn_hidden_size.max must be an int"
+        activation_hook = L2NormHook(max_size=max_ffn_size)
+        self._register_temp_attribute("_activation_hook", activation_hook)
+        # TODO: confusion: why hook_handle is removed manually in export() and not using _register_temp_attribute?
+        self.hook_handle = self.linear_fc2.register_forward_hook(activation_hook)
         ffn_hidden_size.register_importance(self._estimate_importance)

-    def _linear_fc2_forward_hook(self, module, input, output):
-        """Hook to collect activations for importance estimation.
-
-        Activations are computed as mean over seq_len and then squared and summed over batch_size.
-        Later we take the square root of the sum to get the L2 norm.
-        """
-        # Gather input [seq_len, batch_size, ffn_hidden_size] over all TP regions
-        # NOTE: This is not used at the moment since we restrict to TP=1
-        input = gather_from_tensor_model_parallel_region(input[0]).detach()
-
-        # Dont aggregate activations from non-max subnets (e.g. from profiling)
-        if input.shape[-1] != self.get_hparam("ffn_hidden_size").max:
-            return
-
-        input = input.to(torch.float32)  # use full precision to avoid overflow
-        activations = input.abs().mean(dim=0)  # [batch_size, ffn_hidden_size]
-        activations = activations.pow(2).sum(dim=0)  # [ffn_hidden_size]
-        if self._activations is None:
-            self._activations = activations
-        else:
-            self._activations += activations
-
     def _estimate_importance(self) -> TracedHp.Importance:
         """Return the activation magnitude-based importance of the ffn_hidden_size."""
-        assert self._activations is not None, "No activations collected for importance estimation."
-        # Convert squared sum to L2 norm
-        return self._activations.pow(0.5)
+        return self._activation_hook.accumulate()

     def export(self) -> torch.nn.Module:
         """Export the dynamic module to a torch.nn.Module."""
@@ -545,46 +526,21 @@ def _setup(self):
         )

         # register importance estimator for linear_qkv.output_size and linear_proj.input_size
-        self._register_temp_attribute("_activations", None)
-        self.hook_handle = self.linear_proj.register_forward_hook(self._linear_proj_forward_hook)
+        num_heads_per_group_max = self.get_hparam("num_heads_per_group").max
+        num_query_groups_max = self.get_hparam("num_query_groups").max
+        max_size = num_heads_per_group_max * num_query_groups_max * self.config.kv_channels
+        activation_hook = L2NormHook(max_size=max_size)
+        self._register_temp_attribute("_activation_hook", activation_hook)
+        self.hook_handle = self.linear_proj.register_forward_hook(activation_hook)
         # NOTE: num_heads_per_group's slice_order will be of length num_attention_heads to be able to sort heads,
         # otherwise we would only have aggregated importance of heads per group.
         # While enforcing order during `sort_parameters`, we dont check the shape of the slice_order
         num_heads_per_group.register_importance(self._estimate_all_head_importance)
         num_query_groups.register_importance(self._estimate_query_group_importance)

-    def _linear_proj_forward_hook(self, module, input, output):
-        """Hook to collect activations for importance estimation.
-
-        Activations are computed as mean over seq_len and then squared and summed over batch_size.
-        Later we take the square root of the sum to get the L2 norm.
-        """
-        # Gather input [seq_len, batch_size, query_projection_size] over all TP regions
-        # NOTE: This is not used at the moment since we restrict to TP=1
-        input = gather_from_tensor_model_parallel_region(input[0]).detach()
-
-        # Dont aggregate activations from non-max subnets (e.g. from profiling)
-        if (
-            input.shape[-1]
-            != self.get_hparam("num_heads_per_group").max
-            * self.get_hparam("num_query_groups").max
-            * self.config.kv_channels
-        ):
-            return
-
-        input = input.to(torch.float32)  # use full precision to avoid overflow
-        activations = input.abs().mean(dim=0)
-        activations = activations.pow(2).sum(dim=0)  # [query_projection_size]
-        if self._activations is None:
-            self._activations = activations
-        else:
-            self._activations += activations
-
     def _estimate_all_head_importance(self) -> TracedHp.Importance:
         """Return the importance for num_attention_heads (num_heads_per_group * num_query_groups)."""
-        assert self._activations is not None, "No activations collected for importance estimation."
-        # Convert squared sum to L2 norm
-        scores = self._activations.pow(0.5)
+        scores = self._activation_hook.accumulate()
         attn_head_importance = torch.linalg.vector_norm(
             scores.view(
                 self.get_hparam("num_heads_per_group").max
@@ -598,9 +554,7 @@ def _estimate_all_head_importance(self) -> TracedHp.Importance:

     def _estimate_query_group_importance(self) -> TracedHp.Importance:
         """Return the importance of the ``num_query_groups`` hparam."""
-        assert self._activations is not None, "No activations collected for importance estimation."
-        # Convert squared sum to L2 norm
-        scores = self._activations.pow(0.5)
+        scores = self._activation_hook.accumulate()
         group_importance = torch.linalg.vector_norm(
             scores.view(
                 self.get_hparam("num_heads_per_group").max,
@@ -1353,7 +1307,12 @@ def get_activations_and_layer_scores(
         """Get the per-rank activations and layer scores from the module."""
         local_activations = {}
         for n, m in self.named_modules():
-            if hasattr(m, "_activations"):
+            # New pattern: activations stored in hook
+            if hasattr(m, "_activation_hook") and m._activation_hook._activations is not None:
+                local_activations[n] = m._activation_hook._activations
+            # Legacy pattern: activations stored directly on module.
+            # TODO: remove this once we switch to the new pattern.
+            elif hasattr(m, "_activations") and m._activations is not None:
                 local_activations[n] = m._activations
         activations_per_rank = dist.allgather(
             local_activations, group=get_pipeline_model_parallel_group()
@@ -1385,7 +1344,12 @@ def set_activations_and_layer_scores(
         for layer in self.decoder.layers:
             layer._scores = layer_scores[layer.layer_number]
         for n, m in self.named_modules():
-            if hasattr(m, "_activations"):
+            # New pattern: activations stored in hook
+            if hasattr(m, "_activation_hook"):
+                m._activation_hook._activations = activations_per_rank[rank][n]
+            # Legacy pattern: activations stored directly on module.
+            # TODO: remove this once we switch to the new pattern.
+            elif hasattr(m, "_activations"):
                 m._activations = activations_per_rank[rank][n]

tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py

Lines changed: 46 additions & 0 deletions
@@ -132,6 +132,52 @@ def forward_loop(m):
     assert pruning_scores["layer_scores"]
     assert pruning_scores["activations_per_rank"]

+    # TODO: Simplify it: this unit test is too long,
+    # hard to read (the same set of assertions across different test cases with if-else).
+
+    assert len(pruning_scores["activations_per_rank"]) == 1
+    rank_0_activations = pruning_scores["activations_per_rank"][0]
+
+    # Test case 1: MHA - pruned ffn/4 (num_attention_heads=8, num_query_groups=8, ffn_div=4)
+    if pruned_ffn_div == 4:
+        # Layer scores (these use cosine similarity, independent of FFN activation hook)
+        assert pruning_scores["layer_scores"][1] == pytest.approx(2.1437832713127136, abs=1e-5)
+        assert pruning_scores["layer_scores"][2] == pytest.approx(1.792158305644989, abs=1e-5)
+
+        # Validate decoder.layers.0.mlp activations
+        mlp_0_acts = rank_0_activations["decoder.layers.0.mlp"]
+        assert mlp_0_acts.min().item() == pytest.approx(0.0011843212, abs=1e-5)
+        assert mlp_0_acts.max().item() == pytest.approx(1.0846971273, abs=1e-5)
+        assert mlp_0_acts.mean().item() == pytest.approx(0.0535472594, abs=1e-5)
+
+        # Validate decoder.layers.1.mlp activations
+        mlp_1_acts = rank_0_activations["decoder.layers.1.mlp"]
+        assert mlp_1_acts.min().item() == pytest.approx(0.0002450741, abs=1e-5)
+        assert mlp_1_acts.max().item() == pytest.approx(1.1014972925, abs=1e-5)
+        assert mlp_1_acts.mean().item() == pytest.approx(0.0904172808, abs=1e-5)
+
+    # Test case 2: GQA - pruned attention/2 (num_attention_heads=8, num_query_groups=4, attention_div=2)
+    elif pruned_num_attention_heads_div == 2 and pruned_ffn_div == 1:
+        # Layer scores
+        assert pruning_scores["layer_scores"][1] == pytest.approx(2.1119985580444336, abs=1e-5)
+        assert pruning_scores["layer_scores"][2] == pytest.approx(1.7729830741882324, abs=1e-5)
+
+        # Validate decoder.layers.0.self_attention activations
+        assert "decoder.layers.0.self_attention" in rank_0_activations
+        attn_0_acts = rank_0_activations["decoder.layers.0.self_attention"]
+        assert attn_0_acts.shape == torch.Size([256])
+        assert attn_0_acts.min().item() == pytest.approx(0.03729403391480446, abs=1e-5)
+        assert attn_0_acts.max().item() == pytest.approx(0.3653244972229004, abs=1e-5)
+        assert attn_0_acts.mean().item() == pytest.approx(0.15008458495140076, abs=1e-5)
+
+        # Validate decoder.layers.1.self_attention activations
+        assert "decoder.layers.1.self_attention" in rank_0_activations
+        attn_1_acts = rank_0_activations["decoder.layers.1.self_attention"]
+        assert attn_1_acts.shape == torch.Size([256])
+        assert attn_1_acts.min().item() == pytest.approx(0.140824556350708, abs=1e-5)
+        assert attn_1_acts.max().item() == pytest.approx(1.0845409631729126, abs=1e-5)
+        assert attn_1_acts.mean().item() == pytest.approx(0.4730667173862457, abs=1e-5)
+
     # Assert weights are pruned correctly
     for layer in model.decoder.layers:
         assert layer.mlp.linear_fc1.weight.shape == (
