Changes from all commits (33 commits)
0ef342c
Add qk-clip
BoxiangW Sep 25, 2025
a5d2ea8
w/o EP working
BoxiangW Sep 26, 2025
978b99b
Lint
BoxiangW Sep 26, 2025
b77a6d6
Fix MLA usage and add logging for max_attention_score
BoxiangW Sep 29, 2025
a83b68c
Removed Attention MuonClip since it is not correct. Added MLA MuonClip
BoxiangW Oct 1, 2025
eae6a5a
Add TE bug WAR and make sure MLA works
BoxiangW Oct 1, 2025
c370cdd
Fix qkclip issue
BoxiangW Oct 9, 2025
6f8f77d
Fix bug
BoxiangW Oct 9, 2025
527075f
Fix TP usage
BoxiangW Oct 10, 2025
3cebc1f
Fix Lint
BoxiangW Oct 10, 2025
e3c468f
Fix import and error log
BoxiangW Oct 10, 2025
ef995ea
lint
BoxiangW Oct 10, 2025
7f5f1e2
Address comments
BoxiangW Oct 16, 2025
65fc829
Lint
BoxiangW Oct 16, 2025
75ee38c
Added GQA QK Clipping
BoxiangW Oct 17, 2025
a1c0ff6
Lint
BoxiangW Oct 17, 2025
762bdb5
Remove comment
BoxiangW Oct 24, 2025
ba5c073
Rename max_score to max_logit
BoxiangW Oct 24, 2025
7917e68
Change name
BoxiangW Oct 24, 2025
495f58d
Merge branch 'main' into boxiangw/muon-clip
BoxiangW Nov 3, 2025
55cc00d
Lint
BoxiangW Nov 3, 2025
f54fbdd
Fix copyright
BoxiangW Nov 5, 2025
9095615
Merge branch 'main' into boxiangw/muon-clip
BoxiangW Nov 5, 2025
4b490cc
Move qk_clip function into megatron/core
BoxiangW Nov 8, 2025
0453ded
Merge branch 'main' into boxiangw/muon-clip
BoxiangW Nov 8, 2025
1bb0407
Add tests for qk_clip
BoxiangW Nov 10, 2025
b63c573
Lint and copyright
BoxiangW Nov 10, 2025
bdda5e9
Add te version checks into PR
BoxiangW Nov 10, 2025
bd0436b
Address comments
BoxiangW Nov 12, 2025
646aea5
Address comments
BoxiangW Nov 12, 2025
4d16029
DP all reduce and switch to mul_ inplace op
BoxiangW Nov 17, 2025
2b890d0
Update both main_params and non
BoxiangW Nov 17, 2025
95fdba3
Address comments
BoxiangW Nov 18, 2025
24 changes: 24 additions & 0 deletions megatron/core/extensions/transformer_engine.py
@@ -989,6 +989,14 @@ def __init__(
self.kept_packed_seq_params.discard("cu_seqlens_q_padded")
self.kept_packed_seq_params.discard("cu_seqlens_kv_padded")

if config.qk_clip or config.log_max_attention_logit:
# qk-clip is only supported in TE 2.9.0 and later
assert is_te_min_version("2.9.0"), "qk-clip is only supported in TE 2.9.0 and later"

# TE 2.9.0 introduces return_max_logit, which exposes the max attention logits needed for qk-clip
extra_kwargs["return_max_logit"] = True
self.current_max_attn_logits = None
Review comment (Contributor): what if not is_te_min_version but config.qk_clip? might raise an error

Reply (Contributor, author): Right now it will raise an error if the TE version is wrong.


super().__init__(
num_attention_heads=self.config.num_attention_heads,
kv_channels=kv_channels,
@@ -1058,6 +1066,22 @@ def forward(
**attention_bias_kwargs,
**packed_seq_kwargs,
)

if self.config.qk_clip or self.config.log_max_attention_logit:
# qk-clip is only supported in TE 2.9.0 and later
assert is_te_min_version("2.9.0"), "qk-clip is only supported in TE 2.9.0 and later"

# TE returns (attention output, per-head max attention logits); Q/K weights are updated outside the TE attention API
core_attn_out, batch_max_attention_logits = core_attn_out

# Track the running max of the attention logits across microbatches
if self.current_max_attn_logits is None:
self.current_max_attn_logits = batch_max_attention_logits
else:
self.current_max_attn_logits = torch.max(
self.current_max_attn_logits, batch_max_attention_logits
)

else:
core_attn_out = super().forward(
query, key, value, attention_mask, **attention_bias_kwargs, **packed_seq_kwargs
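
The forward path above keeps a running per-head maximum of the attention logits across microbatches so that a single clipping decision can be made at the optimizer step. A minimal, self-contained sketch of that accumulation pattern (the RunningMaxTracker name and the shapes are illustrative only, not part of the PR):

import torch

class RunningMaxTracker:
    """Toy stand-in for the current_max_attn_logits bookkeeping (illustrative only)."""

    def __init__(self):
        # Lazily initialized, mirroring current_max_attn_logits = None above.
        self.current_max = None

    def update(self, batch_max: torch.Tensor) -> None:
        # Keep the element-wise maximum seen so far across microbatches.
        if self.current_max is None:
            self.current_max = batch_max.clone()
        else:
            self.current_max = torch.max(self.current_max, batch_max)

    def pop(self) -> torch.Tensor:
        # Read and reset, the way clip_qk() consumes the statistic once per step.
        value, self.current_max = self.current_max, None
        return value

tracker = RunningMaxTracker()
for _ in range(4):  # pretend microbatches
    per_head_max = torch.rand(8) * 200.0  # 8 heads, fake per-head max logits
    tracker.update(per_head_max)
print(tracker.pop())  # per-head maximum over all microbatches
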
39 changes: 39 additions & 0 deletions megatron/core/optimizer/qk_clip.py
@@ -0,0 +1,39 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import torch

from megatron.core import mpu


def clip_qk(model, log_max_only=False) -> float:
"""
Clip the QK attention logits to the threshold; recommended for the Muon optimizer.

Args:
model: The model whose QK attention logits are clipped, given as a list of model chunks.
log_max_only: Whether to only log the max attention logit without updating the weights.

Returns:
The maximum attention logit, a float.
"""

with torch.no_grad():
log_max_attention_logit = 0
for model_chunk in model:
for transformer_layer in model_chunk.module.module.decoder.layers:
if hasattr(transformer_layer.self_attention, 'clip_qk'):
torch.distributed.all_reduce(
transformer_layer.self_attention.core_attention.current_max_attn_logits,
op=torch.distributed.ReduceOp.MAX,
group=mpu.get_data_parallel_group(with_context_parallel=True),
)
log_max_attention_logit = max(
log_max_attention_logit,
torch.max(
transformer_layer.self_attention.core_attention.current_max_attn_logits
).item(),
)
if not log_max_only:
transformer_layer.self_attention.clip_qk()

return log_max_attention_logit
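
For orientation, a hedged sketch of where clip_qk fits in a training iteration: it is meant to run once per step, after the optimizer update, so the rescaled Q/K weights take effect before the next forward pass. Apart from the clip_qk import, every name below (training_iteration, forward_backward_step, model_chunks) is a placeholder rather than a Megatron-LM API:

from megatron.core.optimizer.qk_clip import clip_qk

def training_iteration(model_chunks, optimizer, forward_backward_step, config):
    # 1) Usual forward/backward over the microbatches; the TE attention path
    #    records per-head max attention logits as a side effect.
    loss = forward_backward_step(model_chunks)

    # 2) Optimizer update (e.g. Muon), which may push attention logits up.
    optimizer.step()
    optimizer.zero_grad()

    # 3) Rescale Q/K projection weights wherever a head exceeded the threshold,
    #    or only log the global max logit when clipping itself is disabled.
    max_logit = clip_qk(model_chunks, log_max_only=not config.qk_clip)
    return loss, max_logit
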
104 changes: 104 additions & 0 deletions megatron/core/transformer/attention.py
@@ -937,6 +937,13 @@ def set_for_recompute_input_layernorm(self):
"""Set the attention layer for recompute input_layernorm. Only needed for fp8."""
raise NotImplementedError("set_for_recompute_input_layernorm is not implemented.")

def clip_qk(self):
"""
QK Clipping rescales the query and key projection weights to keep the attention
logits from exploding.
"""
raise NotImplementedError("clip_qk is not implemented.")


class SelfAttention(Attention):
"""Self-attention layer class
@@ -1145,6 +1152,103 @@ def set_for_recompute_input_layernorm(self):

set_save_original_input(self.linear_qkv)

def clip_qk(self):
"""
QK Clipping rescales the query and key projection weights to keep the attention
logits from exploding. This function is experimental for GQA.
"""
if not self.config.qk_clip:
raise ValueError("qk_clip option needs to be enabled")

if self.core_attention.current_max_attn_logits is None:
raise ValueError("current_max_attn_logits is None")

assert self.core_attention.current_max_attn_logits.shape == (
self.num_attention_heads_per_partition,
), f"current_max_attn_logits shape is not ({self.num_attention_heads_per_partition}, ) \
but {self.core_attention.current_max_attn_logits.shape}"

grouped_max_attn_logits = torch.max(
self.core_attention.current_max_attn_logits.view(
self.num_query_groups_per_partition, -1
),
dim=1,
).values

# only update the weight if any head has
# current_max_attn_logits > qk_clip_threshold
if torch.any(grouped_max_attn_logits > self.config.qk_clip_threshold):
# Use num_query_groups_per_partition for tensor parallel scenarios

# qk_clip_balancing_eta (g, 1, 1)
assert grouped_max_attn_logits.shape == (
self.num_query_groups_per_partition,
), f"current_max_attn_logits shape is not ({self.num_query_groups_per_partition},) \
but {grouped_max_attn_logits.shape}"
self.qk_clip_balancing_eta = torch.clamp(
self.config.qk_clip_threshold / grouped_max_attn_logits, max=1.0
).view(self.num_query_groups_per_partition, 1, 1)
assert torch.all(self.qk_clip_balancing_eta <= 1.0)

# Handle different weight access patterns (main_param vs direct access)
if hasattr(self.linear_qkv.weight, 'main_param'):
self.linear_qkv.weight.main_param.data.copy_(
self._clip_linear_qkv(self.linear_qkv.weight.main_param.data)
)

self.linear_qkv.weight.data.copy_(self._clip_linear_qkv(self.linear_qkv.weight.data))

# reset current_max_attn_logits
self.core_attention.current_max_attn_logits = None

def _clip_linear_qkv(self, weight):
"""Apply qkclip to linear_qkv layer"""
# Reshape to (g, (query_projection_size + 2 * kv_projection_size) // g, -1)
weight_reshaped = weight.view(
self.num_query_groups_per_partition,
(self.query_projection_size + 2 * self.kv_projection_size)
// self.num_query_groups_per_partition,
-1,
)

# Split each group into query, key and value parts:
# (g, a, -1), (g, b, -1) and (g, b, -1)
weight_q = weight_reshaped[
:, : self.query_projection_size // self.num_query_groups_per_partition, :
]
weight_k = weight_reshaped[
:,
self.query_projection_size
// self.num_query_groups_per_partition : (
self.query_projection_size + self.kv_projection_size
)
// self.num_query_groups_per_partition,
:,
]
weight_v = weight_reshaped[
:,
(self.query_projection_size + self.kv_projection_size)
// self.num_query_groups_per_partition :,
:,
]

# extend the qk_clip_balancing_eta to the same shape as weight_q and weight_k
self.qk_clip_balancing_eta_extended = self.qk_clip_balancing_eta.repeat(
1, weight_q.size(1), 1
)

# Clipping
weight_q.mul_(torch.pow(self.qk_clip_balancing_eta_extended, self.config.qk_clip_alpha))
weight_k.mul_(torch.pow(self.qk_clip_balancing_eta, 1 - self.config.qk_clip_alpha))

# Concatenate back and reshape to original shape
weight_updated = torch.cat([weight_q, weight_k, weight_v], dim=1)
weight_updated = weight_updated.view(
self.query_projection_size + 2 * self.kv_projection_size, -1
)

return weight_updated


class CrossAttention(Attention):
"""Cross-attention layer class
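
To make the scaling in _clip_linear_qkv concrete, here is a self-contained toy version of the per-group eta computation on plain tensors (the group count, head dimensions, and the separate q/k tensors are invented for the example; the real code operates on the fused QKV weight):

import torch

threshold, alpha = 100.0, 0.5
num_groups = 4

# Pretend per-group maxima of the attention logits collected during forward.
grouped_max_logits = torch.tensor([80.0, 150.0, 100.0, 400.0])

# eta = min(threshold / max_logit, 1.0): only groups above the threshold shrink.
eta = torch.clamp(threshold / grouped_max_logits, max=1.0)  # shape (g,)
print(eta)  # tensor([1.0000, 0.6667, 1.0000, 0.2500])

q = torch.randn(num_groups, 16, 64)  # fake per-group query projection weights
k = torch.randn(num_groups, 16, 64)  # fake per-group key projection weights

# Q takes eta**alpha and K takes eta**(1 - alpha), so each group's Q.K logit
# scales by eta**alpha * eta**(1 - alpha) == eta, capping it near the threshold.
q.mul_(eta.view(-1, 1, 1) ** alpha)
k.mul_(eta.view(-1, 1, 1) ** (1 - alpha))
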
120 changes: 120 additions & 0 deletions megatron/core/transformer/multi_latent_attention.py
@@ -917,3 +917,123 @@ def set_for_recompute_input_layernorm(self):
if self.config.q_lora_rank is not None:
set_save_original_input(self.linear_q_down_proj)
set_save_original_input(self.linear_kv_down_proj)

def clip_qk(self):
"""
QK Clipping rescales the query and key projection weights to keep the attention
logits from exploding. Per MuonClip usage, the weights are updated by calling this
function after the Muon optimizer step.
"""

if not self.config.qk_clip:
raise ValueError("qk_clip option needs to be enabled")

if self.core_attention.current_max_attn_logits is None:
raise ValueError("current_max_attn_logits is None")

# Check if we're in absorption mode
if self.cache_mla_latents and not hasattr(self, 'linear_kv_up_proj'):
raise ValueError(
"qk_clip is not supported when cache_mla_latents is enabled and absorption is "
"active. The linear_kv_up_proj layer has been deleted during absorption "
"preparation."
)

assert self.core_attention.current_max_attn_logits.shape == (
self.num_attention_heads_per_partition,
), f"current_max_attn_logits shape is not ({self.num_attention_heads_per_partition}, ) \
but {self.core_attention.current_max_attn_logits.shape}"

# only update the weight if any head has
# current_max_attn_logits > qk_clip_threshold
if torch.any(self.core_attention.current_max_attn_logits > self.config.qk_clip_threshold):
# Use num_attention_heads_per_partition for tensor parallel scenarios

# qk_clip_balancing_eta (n, 1, 1)
assert self.core_attention.current_max_attn_logits.shape == (
self.num_attention_heads_per_partition,
), f"current_max_attn_logits shape is not ({self.num_attention_heads_per_partition},) \
but {self.core_attention.current_max_attn_logits.shape}"
self.qk_clip_balancing_eta = torch.clamp(
self.config.qk_clip_threshold / self.core_attention.current_max_attn_logits, max=1.0
).view(self.num_attention_heads_per_partition, 1, 1)
assert torch.all(self.qk_clip_balancing_eta <= 1.0)

# Update q side weight, keep qk_pos_emb_head_dim side weight unchanged
if self.config.q_lora_rank is None:
q_proj_weight = self.linear_q_proj.weight
else:
q_proj_weight = self.linear_q_up_proj.weight

# Handle different weight access patterns (main_param vs direct access)
if hasattr(q_proj_weight, 'main_param'):
q_proj_weight.main_param.data.copy_(
self._clip_q_proj_weight(q_proj_weight.main_param.data)
)
q_proj_weight.data.copy_(self._clip_q_proj_weight(q_proj_weight.data))

# Update k side weight, keep v side weight unchanged
kv_proj_weight = self.linear_kv_up_proj.weight

# Handle different weight access patterns
if hasattr(kv_proj_weight, 'main_param'):
kv_proj_weight.main_param.data.copy_(
self._clip_kv_proj_weight(kv_proj_weight.main_param.data)
)
kv_proj_weight.data.copy_(self._clip_kv_proj_weight(kv_proj_weight.data))

# reset current_max_attn_logits
self.core_attention.current_max_attn_logits = None

def _clip_q_proj_weight(self, weight):
"""Clip q_proj_weight"""
# Reshape to (n, a + b, -1)
weight_reshaped = weight.view(
self.num_attention_heads_per_partition,
self.config.qk_head_dim + self.config.qk_pos_emb_head_dim,
-1,
)

# Split into qk_head_dim and qk_pos_emb_head_dim parts: (n, a, -1) and (n, b, -1)
weight_q_nope = weight_reshaped[:, : self.config.qk_head_dim, :]
weight_q_pe = weight_reshaped[:, self.config.qk_head_dim :, :]

# Clipping
weight_q_nope.mul_(torch.pow(self.qk_clip_balancing_eta, self.config.qk_clip_alpha))
weight_q_pe.mul_(self.qk_clip_balancing_eta)

# Concatenate back and reshape to original shape
weight_q_updated = torch.cat([weight_q_nope, weight_q_pe], dim=1)
weight_q_updated = weight_q_updated.view(
self.num_attention_heads_per_partition
* (self.config.qk_head_dim + self.config.qk_pos_emb_head_dim),
-1,
)

return weight_q_updated

def _clip_kv_proj_weight(self, weight):
"""Clip kv_proj_weight"""
# shape: (n, qk_head_dim + v_head_dim, kv_lora_rank)
weight_reshaped = weight.view(
self.num_attention_heads_per_partition,
self.config.qk_head_dim + self.config.v_head_dim,
-1,
)

# Split into qk_head_dim and v_head_dim parts: (n, a, -1) and (n, b, -1)
weight_k = weight_reshaped[:, : self.config.qk_head_dim, :]
weight_v = weight_reshaped[:, self.config.qk_head_dim :, :]

# Clipping
weight_k.mul_(torch.pow(self.qk_clip_balancing_eta, 1 - self.config.qk_clip_alpha))

# Concatenate back and reshape to original shape
weight_kv_updated = torch.cat([weight_k, weight_v], dim=1)
weight_kv_updated = weight_kv_updated.view(
self.num_attention_heads_per_partition
* (self.config.qk_head_dim + self.config.v_head_dim),
-1,
)

return weight_kv_updated
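
As a sanity check on the MLA split above: the no-position part of the logit picks up eta**alpha from Q and eta**(1 - alpha) from K, while the rotary part (whose key comes from the shared, unclipped position embedding) takes the full eta on the Q side, so every contribution to the attention logit scales by exactly eta. A toy verification with made-up head dimensions:

import torch

torch.manual_seed(0)
eta, alpha = torch.tensor(0.25), 0.5

q_nope, q_pe = torch.randn(128), torch.randn(64)  # per-head query pieces
k_nope, k_pe = torch.randn(128), torch.randn(64)  # k_pe is shared and not clipped

logit_before = q_nope @ k_nope + q_pe @ k_pe

# Scaling as in _clip_q_proj_weight / _clip_kv_proj_weight above.
q_nope_c = q_nope * eta**alpha
q_pe_c = q_pe * eta                     # the q rope part takes the full eta
k_nope_c = k_nope * eta ** (1 - alpha)

logit_after = q_nope_c @ k_nope_c + q_pe_c @ k_pe

# The clipped logit is exactly eta times the original logit.
torch.testing.assert_close(logit_after, eta * logit_before)
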
13 changes: 13 additions & 0 deletions megatron/core/transformer/transformer_config.py
@@ -192,6 +192,19 @@ class TransformerConfig(ModelParallelConfig):
qk_layernorm: bool = False
"""Whether to apply `normalization` type of normalization to the query and key embeddings."""

qk_clip: bool = False
"""Whether to clip the query and key weights. Needed for Muon MLA Model training."""

qk_clip_alpha: float = 0.5
"""The balancing alpha for qk-clip. Q = Q * (eta ** alpha)"""

qk_clip_threshold: float = 100
"""The balancing threshold for qk-clip. eta = min(threshold / max_attention_logits, 1.0)"""

log_max_attention_logit: bool = False
"""Whether to log the max attention logit across whole model. Decoupled from qk_clip,
defualts to False. Setting qk_clip will automatically log the max logit"""

test_mode: bool = False
"""Whether to run real-time tests."""

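
A hedged sketch of enabling these fields on a TransformerConfig; the non-qk values (layer count, hidden size, head count) are arbitrary placeholders rather than a recommended setup, and any other fields your model needs still have to be set as usual:

from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=2,               # placeholder model shape
    hidden_size=256,
    num_attention_heads=8,
    qk_clip=True,               # rescale Q/K weights after the optimizer step
    qk_clip_alpha=0.5,          # Q scaled by eta**alpha, K by eta**(1 - alpha)
    qk_clip_threshold=100.0,    # eta = min(threshold / max_attention_logit, 1.0)
)
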
24 changes: 24 additions & 0 deletions megatron/training/arguments.py
@@ -977,6 +977,19 @@ def validate_args(args, defaults={}):
if args.add_bias_linear:
args.add_qkv_bias = True

if args.qk_clip:
assert is_te_min_version("2.9.0"), \
'--qk-clip is only supported with TE >= 2.9.0.'
assert 0.0 < args.qk_clip_alpha < 1.0, \
'--qk-clip-alpha must be between 0.0 and 1.0 when using --qk-clip.'
assert args.qk_clip_threshold > 0, \
'--qk-clip-threshold must be greater than 0 when using --qk-clip.'

# decoupled log max attention logit check
if args.log_max_attention_logit:
assert is_te_min_version("2.9.0"), \
'--log-max-attention-logit is only supported with TE >= 2.9.0.'

# Retro checks.
if args.retro_add_retriever:

@@ -1205,6 +1218,9 @@ def validate_args(args, defaults={}):
assert (
args.recompute_granularity != 'full'
), 'recompute_granularity must not be full when CUDA Graphs are enabled.'

if args.multi_latent_attention:
assert not args.group_query_attention, "Group query attention is mutually exclusive with multi latent attention."

# Print arguments.
_print_args("arguments", args)
@@ -1864,6 +1880,8 @@ def _add_logging_args(parser):
group.add_argument('--log-world-size-to-tensorboard',
action='store_true',
help='Enable world size logging to tensorboard.')
group.add_argument('--log-max-attention-logit', action='store_true',
help='Enable max attention logit logging to tensorboard.')
group.add_argument('--wandb-project', type=str, default='',
help='The wandb project name. Ignore wandb by default.')
group.add_argument('--wandb-entity', type=str, default='',
@@ -2206,6 +2224,12 @@ def _add_training_args(parser):
group.add_argument('--add-qkv-bias', action='store_true',
help='Enable bias only in the QKV linear layers',
dest='add_qkv_bias')
group.add_argument('--qk-clip', action='store_true',
help='Whether to use qk-clip for training stabilization, strongly recommended for Muon.')
group.add_argument('--qk-clip-alpha', type=float, default=0.5,
help='The balancing alpha for qk-clip.')
group.add_argument('--qk-clip-threshold', type=float, default=100,
help='The balancing threshold for qk-clip.')
group.add_argument('--optimizer', type=str, default='adam',
choices=['adam', 'sgd'],
help='Optimizer function')