 from ..registry import DMRegistry
 from ..search_space import SampleFunc
 from ..traced_hp import TracedHp
+from .hooks import L2NormHook

 SUPPORTED_MODELS = {GPTModel: "megatron.core.models.gpt.GPTModel"}

@@ -211,37 +212,17 @@ def _setup(self):
         # can be discarded.
         # This limitation might be fixed in OMNIML-180 (Flexible Importance Estimator)
         # where we separate the importance estimation from the dynamic module.
-        self._register_temp_attribute("_activations", None)
-        self.hook_handle = self.linear_fc2.register_forward_hook(self._linear_fc2_forward_hook)
+        max_ffn_size = self.get_hparam("ffn_hidden_size").max
+        assert isinstance(max_ffn_size, int), "ffn_hidden_size.max must be an int"
+        activation_hook = L2NormHook(max_size=max_ffn_size)
+        self._register_temp_attribute("_activation_hook", activation_hook)
+        # TODO: Why is hook_handle removed manually in export() instead of via _register_temp_attribute?
+        self.hook_handle = self.linear_fc2.register_forward_hook(activation_hook)
         ffn_hidden_size.register_importance(self._estimate_importance)

-    def _linear_fc2_forward_hook(self, module, input, output):
-        """Hook to collect activations for importance estimation.
-
-        Activations are computed as mean over seq_len and then squared and summed over batch_size.
-        Later we take the square root of the sum to get the L2 norm.
-        """
-        # Gather input [seq_len, batch_size, ffn_hidden_size] over all TP regions
-        # NOTE: This is not used at the moment since we restrict to TP=1
-        input = gather_from_tensor_model_parallel_region(input[0]).detach()
-
-        # Dont aggregate activations from non-max subnets (e.g. from profiling)
-        if input.shape[-1] != self.get_hparam("ffn_hidden_size").max:
-            return
-
-        input = input.to(torch.float32)  # use full precision to avoid overflow
-        activations = input.abs().mean(dim=0)  # [batch_size, ffn_hidden_size]
-        activations = activations.pow(2).sum(dim=0)  # [ffn_hidden_size]
-        if self._activations is None:
-            self._activations = activations
-        else:
-            self._activations += activations
-
     def _estimate_importance(self) -> TracedHp.Importance:
         """Return the activation magnitude-based importance of the ffn_hidden_size."""
-        assert self._activations is not None, "No activations collected for importance estimation."
-        # Convert squared sum to L2 norm
-        return self._activations.pow(0.5)
+        return self._activation_hook.accumulate()

     def export(self) -> torch.nn.Module:
         """Export the dynamic module to a torch.nn.Module."""
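Side note on the new helper: `L2NormHook` itself lives in `.hooks` and is not shown in this diff. Based on how it is constructed and called above, and on the behavior of the removed `_linear_fc2_forward_hook` (mean over seq_len, squared sum over batch, square root at read time), a minimal sketch could look like the following. The TP gather is omitted since TP=1 is assumed, and the exact interface of the real class may differ.

```python
import torch


class L2NormHook:
    """Sketch of a forward hook that accumulates per-channel L2 activation statistics."""

    def __init__(self, max_size: int):
        self.max_size = max_size  # hidden size of the max subnet
        self._activations = None  # running squared sum, shape [max_size]

    def __call__(self, module, input, output):
        # input[0] is [seq_len, batch_size, hidden]; skip non-max subnets (e.g. during profiling).
        inp = input[0].detach()
        if inp.shape[-1] != self.max_size:
            return
        inp = inp.to(torch.float32)  # full precision to avoid overflow
        acts = inp.abs().mean(dim=0)  # mean over seq_len -> [batch_size, hidden]
        acts = acts.pow(2).sum(dim=0)  # squared sum over batch -> [hidden]
        self._activations = acts if self._activations is None else self._activations + acts

    def accumulate(self) -> torch.Tensor:
        """Return the per-channel L2 norm accumulated over all observed batches."""
        assert self._activations is not None, "No activations collected for importance estimation."
        return self._activations.pow(0.5)
```

Since `register_forward_hook` accepts any callable with a `(module, input, output)` signature, the hook instance can be passed directly, as done in the diff.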
@@ -545,46 +526,21 @@ def _setup(self):
         )

         # register importance estimator for linear_qkv.output_size and linear_proj.input_size
-        self._register_temp_attribute("_activations", None)
-        self.hook_handle = self.linear_proj.register_forward_hook(self._linear_proj_forward_hook)
+        num_heads_per_group_max = self.get_hparam("num_heads_per_group").max
+        num_query_groups_max = self.get_hparam("num_query_groups").max
+        max_size = num_heads_per_group_max * num_query_groups_max * self.config.kv_channels
+        activation_hook = L2NormHook(max_size=max_size)
+        self._register_temp_attribute("_activation_hook", activation_hook)
+        self.hook_handle = self.linear_proj.register_forward_hook(activation_hook)
         # NOTE: num_heads_per_group's slice_order will be of length num_attention_heads to be able to sort heads,
         # otherwise we would only have aggregated importance of heads per group.
         # While enforcing order during `sort_parameters`, we dont check the shape of the slice_order
         num_heads_per_group.register_importance(self._estimate_all_head_importance)
         num_query_groups.register_importance(self._estimate_query_group_importance)

-    def _linear_proj_forward_hook(self, module, input, output):
-        """Hook to collect activations for importance estimation.
-
-        Activations are computed as mean over seq_len and then squared and summed over batch_size.
-        Later we take the square root of the sum to get the L2 norm.
-        """
-        # Gather input [seq_len, batch_size, query_projection_size] over all TP regions
-        # NOTE: This is not used at the moment since we restrict to TP=1
-        input = gather_from_tensor_model_parallel_region(input[0]).detach()
-
-        # Dont aggregate activations from non-max subnets (e.g. from profiling)
-        if (
-            input.shape[-1]
-            != self.get_hparam("num_heads_per_group").max
-            * self.get_hparam("num_query_groups").max
-            * self.config.kv_channels
-        ):
-            return
-
-        input = input.to(torch.float32)  # use full precision to avoid overflow
-        activations = input.abs().mean(dim=0)
-        activations = activations.pow(2).sum(dim=0)  # [query_projection_size]
-        if self._activations is None:
-            self._activations = activations
-        else:
-            self._activations += activations
-
     def _estimate_all_head_importance(self) -> TracedHp.Importance:
         """Return the importance for num_attention_heads (num_heads_per_group * num_query_groups)."""
-        assert self._activations is not None, "No activations collected for importance estimation."
-        # Convert squared sum to L2 norm
-        scores = self._activations.pow(0.5)
+        scores = self._activation_hook.accumulate()
         attn_head_importance = torch.linalg.vector_norm(
             scores.view(
                 self.get_hparam("num_heads_per_group").max
@@ -598,9 +554,7 @@ def _estimate_all_head_importance(self) -> TracedHp.Importance:

     def _estimate_query_group_importance(self) -> TracedHp.Importance:
         """Return the importance of the ``num_query_groups`` hparam."""
-        assert self._activations is not None, "No activations collected for importance estimation."
-        # Convert squared sum to L2 norm
-        scores = self._activations.pow(0.5)
+        scores = self._activation_hook.accumulate()
         group_importance = torch.linalg.vector_norm(
             scores.view(
                 self.get_hparam("num_heads_per_group").max,
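For readers following the reshapes above: the hook accumulates one score per channel of the `linear_proj` input (`num_heads_per_group.max * num_query_groups.max * kv_channels` entries), and the two estimators reduce that vector to per-head and per-group importances. The exact `vector_norm` arguments are cut off by these hunks, but assuming an L2 reduction over each head's `kv_channels` slice (and additionally over the heads within a group for the group importance), the pattern with dummy shapes would be:

```python
import torch

# Dummy shapes for illustration only (not taken from any model config):
heads_per_group, query_groups, kv_channels = 4, 2, 8
scores = torch.rand(heads_per_group * query_groups * kv_channels)  # per-channel L2 scores

# Per-head importance: L2 norm over each head's kv_channels slice -> [num_attention_heads]
attn_head_importance = torch.linalg.vector_norm(
    scores.view(heads_per_group * query_groups, kv_channels), ord=2, dim=1
)
assert attn_head_importance.shape == (heads_per_group * query_groups,)

# Per-group importance: L2 norm over all channels belonging to a query group -> [num_query_groups]
group_importance = torch.linalg.vector_norm(
    scores.view(heads_per_group, query_groups, kv_channels), ord=2, dim=(0, 2)
)
assert group_importance.shape == (query_groups,)
```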
@@ -1353,7 +1307,12 @@ def get_activations_and_layer_scores(
         """Get the per-rank activations and layer scores from the module."""
         local_activations = {}
         for n, m in self.named_modules():
-            if hasattr(m, "_activations"):
+            # New pattern: activations stored in hook
+            if hasattr(m, "_activation_hook") and m._activation_hook._activations is not None:
+                local_activations[n] = m._activation_hook._activations
+            # Legacy pattern: activations stored directly on module.
+            # TODO: remove this once we switch to the new pattern.
+            elif hasattr(m, "_activations") and m._activations is not None:
                 local_activations[n] = m._activations
         activations_per_rank = dist.allgather(
             local_activations, group=get_pipeline_model_parallel_group()
@@ -1385,7 +1344,12 @@ def set_activations_and_layer_scores(
         for layer in self.decoder.layers:
             layer._scores = layer_scores[layer.layer_number]
         for n, m in self.named_modules():
-            if hasattr(m, "_activations"):
+            # New pattern: activations stored in hook
+            if hasattr(m, "_activation_hook"):
+                m._activation_hook._activations = activations_per_rank[rank][n]
+            # Legacy pattern: activations stored directly on module.
+            # TODO: remove this once we switch to the new pattern.
+            elif hasattr(m, "_activations"):
                 m._activations = activations_per_rank[rank][n]

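Both loops above now carry the same new-vs-legacy branching. Purely as an illustration (not part of this change), the two patterns could be funneled through a pair of small accessors such as the hypothetical helpers below, which would keep the TODO cleanup localized once the legacy `_activations` attribute is dropped:

```python
from typing import Optional

import torch


def _get_module_activations(m) -> Optional[torch.Tensor]:
    """Hypothetical helper: read accumulated activations from either storage pattern."""
    hook = getattr(m, "_activation_hook", None)
    if hook is not None and hook._activations is not None:
        return hook._activations  # new pattern: stored on the L2NormHook
    return getattr(m, "_activations", None)  # legacy pattern: stored on the module


def _set_module_activations(m, value: torch.Tensor) -> None:
    """Hypothetical helper: write activations back to whichever pattern the module uses."""
    if hasattr(m, "_activation_hook"):
        m._activation_hook._activations = value
    elif hasattr(m, "_activations"):
        m._activations = value
```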