[Dev] Feature: linear cross entropy fusion #2256
Open · Jianbing-D wants to merge 12 commits into NVIDIA:dev from Jianbing-D:feat-linear-cross-entropy-for-dev
+3,838 −29
Conversation
Commits
- init fused linear cross-entropy interface
- add forward-mainloop and bwd_partial_dlogits kernels (see the sketch after this list for the chunked-backward idea)
- skip TestFusedLinearCrossEntropyOnGptModel for single GPU
- add unit test for linear_cross_entropy on DP
- add unit test for TP
- add sequence-parallel support and its unit test
- fix weight-is-None issue and restore API compatibility
- fix fused linear-CE fusion loss issue
- fix typo in fused_linear_ce Triton kernel
- add sequence_parallel option to compute_language_model_loss_without_logits
- disable linear cross-entropy fusion by default
- remove redundant logits calculations in gpt_model
- merge the linear-cross-entropy-fusion flag and the cross-entropy-fusion flag
- rename compute_output_layer_and_language_model_loss
- remove unused option fused_linear_cross_entropy from transformer_config

All commits Signed-off-by: Jianbing Dong <[email protected]>
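The bwd_partial_dlogits commit above hints at the key backward-pass trick: the gradient of cross-entropy with respect to the logits is softmax(logits) minus a one-hot of the labels, so each logits slice can be recomputed and consumed chunk by chunk. Below is a rough plain-PyTorch illustration of that idea; the function name, chunking strategy, and mean reduction are assumptions for exposition, not the PR's Triton implementation.

```python
import torch

def chunked_linear_ce_backward(hidden, weight, labels, chunk_size=1024):
    """Illustrative sketch of a chunked backward pass (mean-reduced loss);
    not the PR's API.

    Recomputes each [chunk_size, vocab_size] logits slice, forms the partial
    dlogits (softmax - one_hot) for that slice, and accumulates d_hidden and
    d_weight without ever holding the full dlogits tensor.
    """
    num_tokens = hidden.size(0)
    d_hidden = torch.zeros_like(hidden)
    d_weight = torch.zeros_like(weight)
    for start in range(0, num_tokens, chunk_size):
        end = min(start + chunk_size, num_tokens)
        logits = hidden[start:end] @ weight.t()   # recomputed slice only
        d_logits = torch.softmax(logits, dim=-1)
        rows = torch.arange(end - start, device=logits.device)
        d_logits[rows, labels[start:end]] -= 1.0  # softmax - one_hot
        d_logits /= num_tokens                    # mean reduction over tokens
        d_hidden[start:end] = d_logits @ weight
        d_weight += d_logits.t() @ hidden[start:end]
    return d_hidden, d_weight
```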
Author: For the convergence test, please refer to: #2206 (comment)

Author: Linking: #2206
Labels
- dev branch: Dev branch related issues and development
- Expert Review: apply this label to indicate that your PR is ready for expert review.

What does this PR do?
This PR fuses the lm_head output linear layer with the cross-entropy loss so that the intermediate logits tensor is never materialized, reducing the activation-memory footprint.
PR to the main branch: #2206
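To illustrate why this matters: with 8,192 tokens and a 131,072-entry vocabulary, the bf16 logits tensor alone occupies 8192 × 131072 × 2 bytes ≈ 2 GiB per microbatch. A minimal PyTorch sketch of the general technique follows; the function name and Python-level chunking are illustrative assumptions, not the PR's actual Triton kernels. It computes the loss in token chunks so the full [num_tokens, vocab_size] logits tensor never exists at once.

```python
import torch
import torch.nn.functional as F

def chunked_linear_cross_entropy(hidden, weight, labels, chunk_size=1024):
    """Illustrative sketch, not the PR's API.

    hidden: [num_tokens, hidden_size] activations entering lm_head
    weight: [vocab_size, hidden_size] lm_head weight
    labels: [num_tokens] target token ids
    Returns the mean cross-entropy while only ever holding a
    [chunk_size, vocab_size] logits slice in memory.
    """
    losses = []
    for start in range(0, hidden.size(0), chunk_size):
        chunk = hidden[start:start + chunk_size]
        logits = chunk @ weight.t()  # one [chunk_size, vocab_size] slice
        losses.append(F.cross_entropy(logits, labels[start:start + chunk_size],
                                      reduction="none"))
    return torch.cat(losses).mean()
```

In a truly fused implementation the backward pass recomputes these logits slices instead of saving them for autograd; the PR achieves this with Triton kernels (a forward mainloop plus a partial-dlogits backward), but the memory-saving principle is the same.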
Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks
- I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8).

Code review
The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into the `main` branch:
- (Step 1): Add the PR label.
- (Step 2): Collect the expert reviewers' reviews. Attach the Expert Review label when your PR is ready for review; Final Review might get declined if these requirements are not fulfilled.
- (Step 3): Final Review. Attach the Final Review label.
- (Optional Step 4): Cherry-pick into a release branch. If this PR also needs to be merged into core_r* release branches, select Cherry-pick after this PR has been merged to open a new PR into the release branch.

For MRs into the `dev` branch:
The proposed review process for the `dev` branch is under active discussion. MRs are mergeable after one approval by either [email protected] or [email protected].

Merging your PR
Any member of core-adlr and core-nemo will be able to merge your PR.