Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 32 additions & 64 deletions modelopt/torch/nas/plugins/megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
from megatron.core.transformer.transformer_layer import TransformerLayer

from modelopt.torch.nas.modules import DynamicModuleList
from modelopt.torch.opt.dynamic import DynamicModule
from modelopt.torch.opt.hparam import HPType
from modelopt.torch.opt.searcher import ConstraintsDict
Expand All @@ -77,11 +76,12 @@
ConstraintsRes,
)
from ..hparams.concat import build_concat_hp
from ..modules import _DynamicLayerNorm
from ..modules import DynamicModuleList, _DynamicLayerNorm
from ..modules.utils import get_sliced_tensor, get_sliced_tensor_by_slices
from ..registry import DMRegistry
from ..search_space import SampleFunc
from ..traced_hp import TracedHp
from .megatron_hooks import MegatronL2NormHook

SUPPORTED_MODELS = {GPTModel: "megatron.core.models.gpt.GPTModel"}

Expand Down Expand Up @@ -265,39 +265,19 @@ def _setup(self):
# can be discarded.
# This limitation might be fixed in OMNIML-180 (Flexible Importance Estimator)
# where we separate the importance estimation from the dynamic module.
self._register_temp_attribute("_activations", None)
self.hook_handle = self.linear_fc2.register_forward_hook(self._linear_fc2_forward_hook)
max_ffn_size = int(self.get_hparam(self.hparam_name).max) # type: ignore[arg-type]
activation_hook = MegatronL2NormHook(max_size=max_ffn_size)
self._register_temp_attribute("_activation_hook", activation_hook)
# TODO: confusion: why hook_handle is removed manually in export() and not using _register_temp_attribute?
self.hook_handle = self.linear_fc2.register_forward_hook(activation_hook)
ffn_hidden_size.register_importance(self._estimate_importance)

def _linear_fc2_forward_hook(self, module, input, output):
"""Hook to collect activations for importance estimation.

Activations are computed as mean over seq_len and then squared and summed over batch_size.
Later we take the square root of the sum to get the L2 norm.
"""
# Gather input [seq_len, batch_size, ffn_hidden_size] over all TP regions
# NOTE: This is not used at the moment since we restrict to TP=1
input = gather_from_tensor_model_parallel_region(input[0]).detach()
if input.dim() == 2:
# For sparse experts, there is no batch dimension.
input = input[:, None, :]
# Dont aggregate activations from non-max subnets (e.g. from profiling)
if input.shape[-1] != self.get_hparam(self.hparam_name).max:
return

input = input.to(torch.float32) # use full precision to avoid overflow
activations = input.abs().mean(dim=0) # [batch_size, ffn_hidden_size]
activations = activations.pow(2).sum(dim=0) # [ffn_hidden_size]
if self._activations is None:
self._activations = activations
else:
self._activations += activations

def _estimate_importance(self) -> TracedHp.Importance:
"""Return the activation magnitude-based importance of the ffn_hidden_size."""
assert self._activations is not None, "No activations collected for importance estimation."
# Convert squared sum to L2 norm
return self._activations.pow(0.5)
assert self._activation_hook._activations is not None, (
"No activations collected for importance estimation."
)
return self._activation_hook.accumulate()

def set_hidden_size_hp(self, hidden_size: TracedHp) -> None:
"""Set hidden size for shared expert."""
Expand Down Expand Up @@ -612,46 +592,26 @@ def _setup(self):
)

# register importance estimator for linear_qkv.output_size and linear_proj.input_size
self._register_temp_attribute("_activations", None)
self.hook_handle = self.linear_proj.register_forward_hook(self._linear_proj_forward_hook)
num_heads_per_group_max = int(self.get_hparam("num_heads_per_group").max) # type: ignore[arg-type]
num_query_groups_max = int(self.get_hparam("num_query_groups").max) # type: ignore[arg-type]
max_size = num_heads_per_group_max * num_query_groups_max * self.config.kv_channels
activation_hook = MegatronL2NormHook(max_size=max_size)
self._register_temp_attribute("_activation_hook", activation_hook)
# TODO: confusion: why hook_handle is removed manually in export() and not using _register_temp_attribute?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

even if we register hook_handle as temp attribute, we still need to call hook_handle.remove() to remove the hook so there's no change. Temp attribute will be remove from model i.e. self.hook_handle reference will be dropped but that still doesnt remove the actuall pytorch hook added to the forward

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand now.

self.hook_handle = self.linear_proj.register_forward_hook(activation_hook)
# NOTE: num_heads_per_group's slice_order will be of length num_attention_heads to be able to sort heads,
# otherwise we would only have aggregated importance of heads per group.
# While enforcing order during `sort_parameters`, we dont check the shape of the slice_order
num_heads_per_group.register_importance(self._estimate_all_head_importance)
num_query_groups.register_importance(self._estimate_query_group_importance)

def _linear_proj_forward_hook(self, module, input, output):
"""Hook to collect activations for importance estimation.

Activations are computed as mean over seq_len and then squared and summed over batch_size.
Later we take the square root of the sum to get the L2 norm.
"""
# Gather input [seq_len, batch_size, query_projection_size] over all TP regions
# NOTE: This is not used at the moment since we restrict to TP=1
input = gather_from_tensor_model_parallel_region(input[0]).detach()

# Dont aggregate activations from non-max subnets (e.g. from profiling)
if (
input.shape[-1]
!= self.get_hparam("num_heads_per_group").max
* self.get_hparam("num_query_groups").max
* self.config.kv_channels
):
return

input = input.to(torch.float32) # use full precision to avoid overflow
activations = input.abs().mean(dim=0)
activations = activations.pow(2).sum(dim=0) # [query_projection_size]
if self._activations is None:
self._activations = activations
else:
self._activations += activations

def _estimate_all_head_importance(self) -> TracedHp.Importance:
"""Return the importance for num_attention_heads (num_heads_per_group * num_query_groups)."""
assert self._activations is not None, "No activations collected for importance estimation."
assert self._activation_hook._activations is not None, (
"No activations collected for importance estimation."
)
# Convert squared sum to L2 norm
scores = self._activations.pow(0.5)
scores = self._activation_hook.accumulate()
attn_head_importance = torch.linalg.vector_norm(
scores.view(
self.get_hparam("num_heads_per_group").max
Expand All @@ -665,9 +625,11 @@ def _estimate_all_head_importance(self) -> TracedHp.Importance:

def _estimate_query_group_importance(self) -> TracedHp.Importance:
"""Return the importance of the ``num_query_groups`` hparam."""
assert self._activations is not None, "No activations collected for importance estimation."
assert self._activation_hook._activations is not None, (
"No activations collected for importance estimation."
)
# Convert squared sum to L2 norm
scores = self._activations.pow(0.5)
scores = self._activation_hook.accumulate()
group_importance = torch.linalg.vector_norm(
scores.view(
self.get_hparam("num_heads_per_group").max,
Expand Down Expand Up @@ -1594,8 +1556,11 @@ def get_activations_and_layer_scores(
"""Get the per-rank activations and layer scores from the module."""
local_activations = {}
for n, m in self.named_modules():
# TODO: Remove legacy _activations check once all modules use _activation_hook
if hasattr(m, "_activations"):
local_activations[n] = m._activations
elif hasattr(m, "_activation_hook"):
local_activations[n] = m._activation_hook._activations
activations_per_rank = dist.allgather(
local_activations, group=get_pipeline_model_parallel_group()
)
Expand Down Expand Up @@ -1624,8 +1589,11 @@ def set_activations_and_layer_scores(
for layer in self.decoder.layers:
layer._scores = layer_scores[layer.layer_number]
for n, m in self.named_modules():
# TODO: Remove legacy _activations check once all modules use _activation_hook
if hasattr(m, "_activations"):
m._activations = activations_per_rank[rank][n]
elif hasattr(m, "_activation_hook"):
m._activation_hook._activations = activations_per_rank[rank][n]


def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[int]) -> None:
Expand Down
104 changes: 104 additions & 0 deletions modelopt/torch/nas/plugins/megatron_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Forward hooks for activation-based importance estimation (megatron NAS plugin)."""

from abc import ABC, abstractmethod

import torch
from megatron.core.tensor_parallel import gather_from_tensor_model_parallel_region
from torch import nn


class ForwardHook(ABC):
"""Base class for PyTorch forward hooks.

This follows the PyTorch forward hook API where the second
parameter is 'args' (a tuple of positional arguments passed to forward()).

Usage:
hook = MyHook()
module.register_forward_hook(hook)
"""

@abstractmethod
def __call__(
self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor
) -> None:
"""Forward hook that is called after the module's forward pass.

Args:
module: The module this hook is registered on
args: Tuple of positional arguments passed to module.forward()
output: The output from module.forward()

Returns:
None (does not modify the output)
"""
...


class MegatronL2NormHook(ForwardHook):
"""Hook for accumulating activation statistics for importance estimation.

Activations are computed as mean over seq_len and then squared and summed over batch_size.
In the accumulate() method we take the square root of the sum to get the L2 norm.

Args:
max_size: Optional maximum expected size to validate against (skips if mismatch).
Useful for skipping non-max subnets during profiling.
"""

def __init__(self, max_size: int | None = None):
"""Initialize the L2NormHook."""
self.max_size = max_size
self._activations: torch.Tensor | None = None

def __call__(
self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor
) -> None:
"""Accumulate activation statistics from the forward pass."""
# Gather input [seq_len, batch_size, hidden_size] over all TP regions
# NOTE: This is not used at the moment since we restrict to TP=1
input_tensor = gather_from_tensor_model_parallel_region(args[0]).detach()

if input_tensor.dim() == 2:
# For sparse experts, there is no batch dimension.
input_tensor = input_tensor[:, None, :]

# Dont aggregate activations from non-max subnets (e.g. from profiling)
if self.max_size is not None and input_tensor.shape[-1] != self.max_size:
return

input_tensor = input_tensor.to(torch.float32) # use full precision to avoid overflow
activations = input_tensor.abs().mean(dim=0) # [batch_size, hidden_size]
activations = activations.pow(2).sum(dim=0) # [hidden_size]

if self._activations is None:
self._activations = activations
else:
self._activations += activations

def accumulate(self) -> torch.Tensor:
"""Return the accumulated L2 norm of activations.

Returns:
Tensor of accumulated scores, one per channel

Raises:
AssertionError: If no activations have been collected yet
"""
assert self._activations is not None, "No activations collected for importance estimation."
# Convert squared sum to L2 norm
return self._activations.pow(0.5)
48 changes: 48 additions & 0 deletions tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,12 @@ def _get_model(initialize_megatron=True):
normalization=normalization,
num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage,
num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage,
use_cpu_initialization=True, # Ensure deterministic weight init across CUDA versions
).cuda()
return model

model = _get_model()

sd = model.state_dict()

def forward_loop(m):
Expand Down Expand Up @@ -134,6 +136,52 @@ def forward_loop(m):
assert pruning_scores["layer_scores"]
assert pruning_scores["activations_per_rank"]

# TODO: Simplify it: this unit test is too long,
# hard to read (the same set of assertions across different test cases with if-else).

assert len(pruning_scores["activations_per_rank"]) == 1
rank_0_activations = pruning_scores["activations_per_rank"][0]

# Test case 1: MHA - pruned ffn/4 (num_attention_heads=8, num_query_groups=8, ffn_div=4)
if pruned_ffn_div == 4:
# Layer scores
assert pruning_scores["layer_scores"][1] == pytest.approx(2.0868452191352844, abs=1e-3)
assert pruning_scores["layer_scores"][2] == pytest.approx(1.7638601660728455, abs=1e-3)

# Validate decoder.layers.0.mlp activations
mlp_0_acts = rank_0_activations["decoder.layers.0.mlp"]
assert mlp_0_acts.min().item() == pytest.approx(0.0015609927941114, abs=1e-3)
assert mlp_0_acts.max().item() == pytest.approx(0.3844809532165527, abs=1e-3)
assert mlp_0_acts.mean().item() == pytest.approx(0.0629318505525589, abs=1e-3)

# Validate decoder.layers.1.mlp activations
mlp_1_acts = rank_0_activations["decoder.layers.1.mlp"]
assert mlp_1_acts.min().item() == pytest.approx(0.0001484956446802, abs=1e-3)
assert mlp_1_acts.max().item() == pytest.approx(0.7835369110107422, abs=1e-3)
assert mlp_1_acts.mean().item() == pytest.approx(0.0926810950040817, abs=1e-3)

# Test case 2: GQA - pruned attention/2 (num_attention_heads=8, num_query_groups=4, attention_div=2)
elif pruned_num_attention_heads_div == 2 and pruned_ffn_div == 1:
# Layer scores
assert pruning_scores["layer_scores"][1] == pytest.approx(2.1415508985519409, abs=1e-3)
assert pruning_scores["layer_scores"][2] == pytest.approx(1.7198008894920349, abs=1e-3)

# Validate decoder.layers.0.self_attention activations
assert "decoder.layers.0.self_attention" in rank_0_activations
attn_0_acts = rank_0_activations["decoder.layers.0.self_attention"]
assert attn_0_acts.shape == torch.Size([256])
assert attn_0_acts.min().item() == pytest.approx(0.0409194342792034, abs=1e-3)
assert attn_0_acts.max().item() == pytest.approx(0.5261313319206238, abs=1e-3)
assert attn_0_acts.mean().item() == pytest.approx(0.1613342612981796, abs=1e-3)

# Validate decoder.layers.1.self_attention activations
assert "decoder.layers.1.self_attention" in rank_0_activations
attn_1_acts = rank_0_activations["decoder.layers.1.self_attention"]
assert attn_1_acts.shape == torch.Size([256])
assert attn_1_acts.min().item() == pytest.approx(0.1189328655600548, abs=1e-3)
assert attn_1_acts.max().item() == pytest.approx(1.3832759857177734, abs=1e-3)
assert attn_1_acts.mean().item() == pytest.approx(0.4782669544219971, abs=1e-3)

# Assert weights are pruned correctly
for layer in model.decoder.layers:
assert layer.mlp.linear_fc1.weight.shape == (
Expand Down