
Commit dc2d4c3

Add llama4 scaling for Mistral Large 3.

Signed-off-by: Tracin <[email protected]>
Parent: 7cc16f8

1 file changed: +25 -0 lines changed

tensorrt_llm/_torch/modules/attention.py

@@ -972,6 +972,13 @@ def yarn_get_mscale(scale=1, mscale=1):
             is_neox=pos_embd_params.is_neox,
         )
 
+        if hasattr(config.pretrained_config, 'llama_4_scaling'):
+            self.llama_4_scaling = True
+            self.floor_scale = getattr(config.pretrained_config.llama_4_scaling,
+                                       'original_max_position_embeddings', 8192)
+            self.attn_scale = getattr(config.pretrained_config.llama_4_scaling,
+                                      'beta', 0.1)
+
         if not config.skip_create_weights_in_init:
             self.create_weights()
 
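Note on the hunk above: when the pretrained config carries a llama_4_scaling section, the module enables Llama-4-style attention scaling and reads original_max_position_embeddings (default 8192) and beta (default 0.1) from it. A minimal sketch of how those getattr fallbacks resolve, using a hypothetical SimpleNamespace in place of the real config.pretrained_config.llama_4_scaling object:

from types import SimpleNamespace

# Hypothetical stand-in for config.pretrained_config.llama_4_scaling.
llama_4_scaling = SimpleNamespace(original_max_position_embeddings=8192, beta=0.1)

# Mirrors the getattr() fallbacks in the diff: absent fields keep the defaults.
floor_scale = getattr(llama_4_scaling, 'original_max_position_embeddings', 8192)
attn_scale = getattr(llama_4_scaling, 'beta', 0.1)
print(floor_scale, attn_scale)  # 8192 0.1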
@@ -1114,6 +1121,18 @@ def create_output(self, hidden_states: torch.Tensor, num_contexts: int):
         return hidden_states.new_empty([num_tokens, hidden_size],
                                        dtype=hidden_states.dtype)
 
+    def _attention_scaling(self, q, position_ids):
+
+        def _get_attn_scale(position_ids: torch.Tensor) -> torch.Tensor:
+            positions = position_ids.view(-1)
+            floor = torch.floor((positions + 1.0) / self.floor_scale)
+            attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
+            return attn_scale.unsqueeze(-1)
+
+        attn_scale = _get_attn_scale(position_ids)
+        q = (q * attn_scale).to(q.dtype)
+        return q
+
     def forward_impl(self,
                      position_ids: Optional[torch.Tensor],
                      hidden_states: torch.Tensor,
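The _attention_scaling helper added above multiplies the query by 1 + beta * log(floor((position + 1) / floor_scale) + 1), so the factor stays at 1.0 inside the original context window and grows logarithmically beyond it. A self-contained sketch of the same math on toy tensors (shapes and values are illustrative, not the module's real ones):

import torch

floor_scale, beta = 8192.0, 0.1                 # defaults from the first hunk
position_ids = torch.tensor([0, 4096, 8192, 65536, 1000000])
q = torch.randn(position_ids.numel(), 4)        # toy query tensor, one row per token

floor = torch.floor((position_ids.float() + 1.0) / floor_scale)
attn_scale = torch.log(floor + 1.0) * beta + 1.0   # ~[1.000, 1.000, 1.069, 1.220, 1.481]
q_scaled = (q * attn_scale.unsqueeze(-1)).to(q.dtype)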
@@ -1184,6 +1203,9 @@ def forward_impl(self,
             assert position_ids is not None
             k_pe_ctx = self.apply_rope(q_ctx, k_pe_ctx, position_ids)
 
+            if self.llama_4_scaling:
+                q_ctx = self._attention_scaling(
+                    q_ctx, position_ids[..., :num_ctx_tokens])
             self.forward_context(
                 q_ctx,
                 compressed_kv_ctx,
@@ -1204,6 +1226,9 @@ def forward_impl(self,
             assert position_ids is not None
             k_pe_gen = self.apply_rope(q_gen, k_pe_gen, position_ids)
 
+            if self.llama_4_scaling:
+                q_gen = self._attention_scaling(
+                    q_gen, position_ids[..., num_ctx_tokens:])
             self.forward_absorption_generation(
                 q_gen,
                 compressed_kv_gen,
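In forward_impl the new checks apply the scaling separately to the context and generation queries, with position_ids assumed to be laid out as all context-phase tokens followed by all generation-phase tokens. A small illustration of that slicing convention (the tensor shape and values are made up for the example):

import torch

num_ctx_tokens = 6                                   # e.g. one context request with 6 tokens
position_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 17, 42]])   # trailing entries: generation steps

ctx_positions = position_ids[..., :num_ctx_tokens]   # passed to _attention_scaling for q_ctx
gen_positions = position_ids[..., num_ctx_tokens:]   # passed to _attention_scaling for q_gen
print(ctx_positions.tolist(), gen_positions.tolist())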
