QK logits clipping (non-split version) #1929
Open
BoxiangW wants to merge 33 commits into NVIDIA:main from BoxiangW:boxiangw/muon-clip
+756 −2
Changes from all commits (33 commits, all by BoxiangW):
0ef342c  Add qk-clip
a5d2ea8  w/o EP working
978b99b  Lint
b77a6d6  Fix MLA usage and add logging for max_attention_score
a83b68c  Removed Attention MuonClip since it is not correct. Added MLA MuonClip
eae6a5a  Add TE bug WAR and make sure MLA works
c370cdd  Fix qkclip issue
6f8f77d  Fix bug
527075f  Fix TP usage
3cebc1f  Fix Lint
e3c468f  Fix import and error log
ef995ea  lint
7f5f1e2  Address comments
65fc829  Lint
75ee38c  Added GQA QK Clipping
a1c0ff6  Lint
762bdb5  Remove comment
ba5c073  Rename max_score to max_logit
7917e68  Change name
495f58d  Merge branch 'main' into boxiangw/muon-clip
55cc00d  Lint
f54fbdd  Fix copyright
9095615  Merge branch 'main' into boxiangw/muon-clip
4b490cc  Move qk_clip fucntion into megatron/core
0453ded  Merge branch 'main' into boxiangw/muon-clip
1bb0407  Add tests for qk_clip
b63c573  Lint and copyright
bdda5e9  Add te version checks into PR
bd0436b  Address comments
646aea5  Address comments
4d16029  DP all reduce and switch to mul_ inplace op
2b890d0  Update both main_params and non
95fdba3  Address comments
New file (diff: @@ -0,0 +1,39 @@) adding the `clip_qk` helper:

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import torch

from megatron.core import mpu


def clip_qk(model, log_max_only=False) -> float:
    """
    Clip the QK attention logits to the threshold, recommended for the Muon optimizer.

    Args:
        model: The model whose QK attention logits are clipped, given as a list of model chunks.
        log_max_only: Whether to only log the max attention logit, without updating the weights.

    Returns:
        The maximum attention logit, a float.
    """
    with torch.no_grad():
        log_max_attention_logit = 0
        for model_chunk in model:
            for transformer_layer in model_chunk.module.module.decoder.layers:
                if hasattr(transformer_layer.self_attention, 'clip_qk'):
                    # Take the max attention logit across the data-parallel (and
                    # context-parallel) ranks before comparing and clipping.
                    torch.distributed.all_reduce(
                        transformer_layer.self_attention.core_attention.current_max_attn_logits,
                        op=torch.distributed.ReduceOp.MAX,
                        group=mpu.get_data_parallel_group(with_context_parallel=True),
                    )
                    log_max_attention_logit = max(
                        log_max_attention_logit,
                        torch.max(
                            transformer_layer.self_attention.core_attention.current_max_attn_logits
                        ).item(),
                    )
                    if not log_max_only:
                        transformer_layer.self_attention.clip_qk()

    return log_max_attention_logit
```
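For illustration, here is a minimal sketch (not part of the PR) of how `clip_qk` might be wired into a training loop. The wrapper name `maybe_clip_qk`, the logging details, and the flag names are assumptions; the `clip_qk` import is omitted because the file's exact path under megatron/core is not visible in this view. `model` is the usual list of model chunks, as in the helper above.

```python
import torch

# from megatron.core.<module added by this PR> import clip_qk  # exact path not shown in this diff view


def maybe_clip_qk(model, iteration, clipping_enabled=True, log_interval=100):
    """Hypothetical call site: clip QK logits (or just observe them) once per step."""
    if clipping_enabled:
        # Clip the projections and get the (pre-clip) max attention logit.
        max_logit = clip_qk(model)
    else:
        # Observe only: report the max attention logit without updating weights.
        max_logit = clip_qk(model, log_max_only=True)

    if iteration % log_interval == 0 and torch.distributed.get_rank() == 0:
        print(f"iteration {iteration}: max attention logit = {max_logit:.3f}")

    return max_logit
```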
(The remaining changed files in this PR did not load in this view.)
Review comment:
what if not `is_te_min_version` but `config.qk_clip`? might raise an error

Reply:
Right now it will raise an error if the TE version is wrong.
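For context, a fail-fast guard in that spirit might look like the sketch below. The function name, where such a check would live, and the minimum TE version string are assumptions rather than taken from the PR; the version comparison uses Megatron-LM's `is_te_min_version` helper.

```python
from megatron.core.utils import is_te_min_version

_QK_CLIP_MIN_TE_VERSION = "2.0.0"  # placeholder; the PR's TE version check defines the real minimum


def validate_qk_clip_support(config):
    # Hypothetical guard: raise a clear error up front when QK clipping is
    # requested but the installed Transformer Engine is too old, instead of
    # failing later at runtime.
    if getattr(config, "qk_clip", False) and not is_te_min_version(_QK_CLIP_MIN_TE_VERSION):
        raise RuntimeError(
            "config.qk_clip requires a newer Transformer Engine; "
            "please upgrade TE or disable qk_clip."
        )
```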