@@ -972,6 +972,13 @@ def yarn_get_mscale(scale=1, mscale=1):
             is_neox=pos_embd_params.is_neox,
         )

+        self.llama_4_scaling = hasattr(config.pretrained_config, 'llama_4_scaling')
+        if self.llama_4_scaling:
+            self.floor_scale = getattr(config.pretrained_config.llama_4_scaling,
+                                       'original_max_position_embeddings', 8192)
+            self.attn_scale = getattr(config.pretrained_config.llama_4_scaling,
+                                      'beta', 0.1)
+
         if not config.skip_create_weights_in_init:
             self.create_weights()

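For context, a minimal sketch of the config shape this init path expects. Only the `llama_4_scaling`, `original_max_position_embeddings`, and `beta` names come from the diff; the use of `SimpleNamespace` and the concrete values are illustrative assumptions.

```python
from types import SimpleNamespace

# Hypothetical stand-in for config.pretrained_config; attribute names mirror
# the getattr() calls in the hunk above, everything else is assumed.
pretrained_config = SimpleNamespace(
    llama_4_scaling=SimpleNamespace(
        original_max_position_embeddings=8192,  # becomes self.floor_scale
        beta=0.1,                                # becomes self.attn_scale
    ))

llama_4_scaling = hasattr(pretrained_config, 'llama_4_scaling')
if llama_4_scaling:
    floor_scale = getattr(pretrained_config.llama_4_scaling,
                          'original_max_position_embeddings', 8192)
    attn_scale = getattr(pretrained_config.llama_4_scaling, 'beta', 0.1)
```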
@@ -1114,6 +1121,18 @@ def create_output(self, hidden_states: torch.Tensor, num_contexts: int):
         return hidden_states.new_empty([num_tokens, hidden_size],
                                        dtype=hidden_states.dtype)

+    def _attention_scaling(self, q, position_ids):
+
+        def _get_attn_scale(position_ids: torch.Tensor) -> torch.Tensor:
+            positions = position_ids.view(-1)
+            floor = torch.floor((positions + 1.0) / self.floor_scale)
+            attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
+            return attn_scale.unsqueeze(-1)
+
+        attn_scale = _get_attn_scale(position_ids)
+        q = (q * attn_scale).to(q.dtype)
+        return q
+
     def forward_impl(self,
                      position_ids: Optional[torch.Tensor],
                      hidden_states: torch.Tensor,
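As a standalone illustration of the scaling rule that `_attention_scaling` applies, here is a minimal sketch; the helper name, tensor shapes, and sample positions are assumptions, while the formula and the 8192 / 0.1 defaults mirror the diff.

```python
import torch

def llama4_attn_scale(position_ids: torch.Tensor,
                      floor_scale: float = 8192.0,
                      beta: float = 0.1) -> torch.Tensor:
    # scale = log(floor((pos + 1) / floor_scale) + 1) * beta + 1
    positions = position_ids.view(-1).float()
    floor = torch.floor((positions + 1.0) / floor_scale)
    return (torch.log(floor + 1.0) * beta + 1.0).unsqueeze(-1)

q = torch.randn(4, 576)  # [num_tokens, num_heads * head_dim], sizes illustrative
position_ids = torch.tensor([0, 4096, 8191, 131072])
q = (q * llama4_attn_scale(position_ids)).to(q.dtype)
# The scale stays 1.0 while floor((pos + 1) / floor_scale) == 0, then grows
# logarithmically with position (about 1.07 at 8191, about 1.28 at 131072).
```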
@@ -1184,6 +1203,9 @@ def forward_impl(self,
             assert position_ids is not None
             k_pe_ctx = self.apply_rope(q_ctx, k_pe_ctx, position_ids)

+            if self.llama_4_scaling:
+                q_ctx = self._attention_scaling(
+                    q_ctx, position_ids[..., :num_ctx_tokens])
             self.forward_context(
                 q_ctx,
                 compressed_kv_ctx,
@@ -1204,6 +1226,9 @@ def forward_impl(self,
             assert position_ids is not None
             k_pe_gen = self.apply_rope(q_gen, k_pe_gen, position_ids)

+            if self.llama_4_scaling:
+                q_gen = self._attention_scaling(
+                    q_gen, position_ids[..., num_ctx_tokens:])
             self.forward_absorption_generation(
                 q_gen,
                 compressed_kv_gen,
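At the two call sites, the same scaling is applied separately to context and generation queries by slicing `position_ids`. Below is a minimal sketch of that split, assuming a flat `[1, num_tokens]` layout with context tokens first (which the slicing above implies); names and sizes are illustrative.

```python
import torch

def _scale(position_ids, floor_scale=8192.0, beta=0.1):
    # Same rule as _attention_scaling above, as a hypothetical free function.
    positions = position_ids.view(-1).float()
    return (torch.log(torch.floor((positions + 1.0) / floor_scale) + 1.0)
            * beta + 1.0).unsqueeze(-1)

num_ctx_tokens, num_gen_tokens = 6, 2
position_ids = torch.arange(num_ctx_tokens + num_gen_tokens).unsqueeze(0)

q_ctx = torch.randn(num_ctx_tokens, 576)   # context-phase queries
q_gen = torch.randn(num_gen_tokens, 576)   # generation-phase queries

q_ctx = (q_ctx * _scale(position_ids[..., :num_ctx_tokens])).to(q_ctx.dtype)
q_gen = (q_gen * _scale(position_ids[..., num_ctx_tokens:])).to(q_gen.dtype)
```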