Conversation


@Jianbing-D Jianbing-D commented Nov 14, 2025

What does this PR do?

This PR introduces an implementation that fuses the lm_head linear layer with the cross-entropy loss, avoiding materialization of the intermediate logits tensor and thereby reducing the memory footprint.

PR to the main branch: #2206

⚠️ For major changes (either in lines of code or in their impact), please make sure to first share and discuss a design doc with the team.

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers' reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either [email protected] or [email protected].

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

Jianbing-D and others added 10 commits November 14, 2025 00:59
init fused linear cross-entropy interface
* add forward-mainloop and bwd_partial_dlogits kernel

Signed-off-by: Jianbing Dong <[email protected]>

* skip TestFusedLinearCrossEntropyOnGptModel for single GPU

Signed-off-by: Jianbing Dong <[email protected]>

* added unit-test for linear_cross_entropy on dp

Signed-off-by: Jianbing Dong <[email protected]>

---------

Signed-off-by: Jianbing Dong <[email protected]>
* added unit-test for TP

Signed-off-by: Jianbing Dong <[email protected]>

* add sequence-parallel and its unit-test

Signed-off-by: Jianbing Dong <[email protected]>

---------

Signed-off-by: Jianbing Dong <[email protected]>
* 1. fix weight is None issue
2. API compatible fix

* 1. fix weight is None issue
2. API compatible fix

* fix fused linear-ce fusion loss issue

* fix typo in fused_linear_ce triton

* 1. fix weight is None issue
2. API compatible fix

* fix fused linear-ce fusion loss issue

* add sequence_parallel option on compute_language_model_loss_without_logits

* Linear cross-entropy fusion is not used by default.
* Remove redundant logits calculations in gpt_model

* Merge the linear-cross-entropy-fusion flag and the cross-entropy-fusion flag
Signed-off-by: Jianbing Dong <[email protected]>
* rename compute_output_layer_and_language_model_loss

* remove used option fused_linear_cross_entropy in transformer_config

copy-pr-bot bot commented Nov 14, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@Jianbing-D
Author

Details about this feature

Training an LLM typically involves a two-stage pipeline at the output layer: hidden states are projected into vocabulary logits via a linear transformation (the lm_head layer), followed by a cross-entropy loss computation against the target tokens. While conceptually simple, this workflow incurs substantial overhead. The intermediate logits tensor, whose size is proportional to batch size, sequence length, and vocabulary size, must be fully materialized in GPU memory, even though only one target token per position is ultimately used. This leads to a significant memory footprint and bandwidth consumption, limiting scalability and slowing training throughput. The following code snippet illustrates that workflow:

import torch

hidden_state = ...  # shape = [batch, seqlen, dim]
weight = lm_head.weight  # shape = [vocabsize, dim]
labels = ...  # shape = [batch, seqlen]

# The full logits tensor of shape [batch, seqlen, vocabsize] is materialized here.
logits = hidden_state @ weight.T
loss = torch.nn.functional.cross_entropy(
    logits.view(-1, logits.shape[-1]), labels.view(-1), reduction="none"
)

On top of the local logits tensor, other parallelism techniques may need additional intermediate buffers to collect the full logits across all GPUs. For example, the following snippet is a tensor-parallel (TP) compatible loss computation composed of native torch ops:

import torch
import torch.distributed as dist

tp_rank = ...        # this rank's index within the tensor-parallel group
tp_world_size = ...  # number of ranks the vocabulary dimension is sharded over
tp_group = ...       # tensor-parallel process group

hidden = ...  # shape = [batch * seqlen, dim]
weight = ...  # local shard of lm_head.weight, shape = [vocabsize // tp_world_size, dim]
labels = ...  # shape = [batch * seqlen]

# Local logits cover only this rank's vocabulary shard.
logits = hidden @ weight.T

# Gather the full-vocabulary logits from all TP ranks into one buffer.
whole_logits = torch.empty(
    (logits.shape[0], logits.shape[-1] * tp_world_size),
    dtype=logits.dtype,
    device=logits.device,
)
whole_logits_ref = [
    whole_logits[..., i * logits.shape[-1] : (i + 1) * logits.shape[-1]]
    for i in range(tp_world_size)
]
dist.all_gather(whole_logits_ref, logits, group=tp_group)

loss = torch.nn.functional.cross_entropy(
    whole_logits.view(-1, whole_logits.shape[-1]), labels.view(-1), reduction="none"
)

By fusing the linear projection and the cross-entropy loss into a single operation, this PR avoids materializing the intermediate logits tensor:

hidden_state = ...  # shape = [batch, seqlen, dim]
weight = lm_head.weight  # shape = [vocabsize, dim]
labels = ...  # shape = [batch, seqlen]

# Fused op from this PR: the [batch, seqlen, vocabsize] logits tensor is never materialized.
loss = linear_cross_entropy(hidden_state, weight, labels, reduction="none")

This saves at least 2 × batch × seqlen × vocabsize (2bsv) elements of memory:

  • in the forward pass, there is no need to materialize the logits tensor, whose shape is [batch, seqlen, vocabsize]
  • in the backward pass, there is no need to materialize the gradient of the logits tensor, whose shape is also [batch, seqlen, vocabsize]
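
To make the 2bsv figure concrete, here is a rough back-of-the-envelope calculation; the batch size, sequence length, and vocabulary size below are made-up example values, not a configuration taken from this PR:

# Hypothetical example sizes, used only to illustrate the 2bsv saving.
batch, seqlen, vocabsize = 4, 8192, 131072
bytes_per_elem = 2  # BF16/FP16

# Forward logits plus backward grad-of-logits that the unfused path must hold.
saved_bytes = 2 * batch * seqlen * vocabsize * bytes_per_elem
print(f"memory avoided by the fusion: ~{saved_bytes / 1024**3:.1f} GiB")  # ~16.0 GiB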

Functionalities

def linear_cross_entropy(
    hidden: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    tp_group: typing.Optional[torch.distributed.ProcessGroup] = None,
    reduction: str = "mean",
    ignore_index: int = -100,
    sequence_parallel: bool = False,
) -> torch.Tensor
  • The input tensors are in BF16 or FP16 format; accumulation and other internal logic are performed in FP32 to avoid precision problems.
  • It supports Data-Parallel, Tensor-Parallel along vocabsize, and Sequence-Parallel along seqlen:
    1. when tp_group is None, it works in DP mode
    2. when tp_group is not None and sequence_parallel is False, it works in TP mode
    3. when tp_group is not None and sequence_parallel is True, it works in SP mode
  • It supports specifying ignore_index, as native torch cross-entropy does.
  • It supports specifying the reduction method, as native torch cross-entropy does.
  • It is optimized for the latest NVIDIA Blackwell GPUs.
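
A minimal usage sketch based on the signature above. The import path is an assumption made for illustration (the real module path is whatever this PR defines), and the shard layouts noted in the comments simply restate the DP/TP/SP mode descriptions from the list:

import torch

# NOTE: hypothetical import path, shown only for illustration.
from megatron.core.fusions.fused_linear_cross_entropy import linear_cross_entropy

batch, seqlen, dim, vocabsize = 4, 1024, 4096, 131072

# DP mode: tp_group is None, every rank holds the full [vocabsize, dim] weight.
hidden = torch.randn(batch, seqlen, dim, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(vocabsize, dim, dtype=torch.bfloat16, device="cuda")
labels = torch.randint(0, vocabsize, (batch, seqlen), device="cuda")
loss = linear_cross_entropy(hidden, weight, labels, reduction="mean")

# TP mode: pass a tensor-parallel process group; weight is this rank's vocab shard.
# loss = linear_cross_entropy(hidden, weight_shard, labels, tp_group=tp_group, reduction="mean")

# SP mode: additionally shard hidden along seqlen and set sequence_parallel=True.
# loss = linear_cross_entropy(hidden_shard, weight_shard, labels, tp_group=tp_group,
#                             sequence_parallel=True, reduction="mean")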

Performance and Storage

In DP mode, this PR delivers both a performance boost and a storage reduction in the benchmarked configuration:
[benchmark results attached as an image in the original PR]

You may try the following steps to reproduce it:

# start a Megatron image on GB200
$ pip install nvidia-cutlass-dsl==4.2.1
$ pip install PyGithub
$ pytest -s -v tests/unit_tests/fusions/test_fused_linear_cross_entropy.py
$ torchrun --nproc_per_node=4 --nnodes=1 -m pytest -s -v tests/unit_tests/fusions/test_fused_linear_cross_entropy.py

@Jianbing-D
Author

For the convergence test, please refer to: #2206 (comment)

@Jianbing-D
Author

Linking: #2206

@Jianbing-D Jianbing-D marked this pull request as ready for review November 14, 2025 10:03
@Jianbing-D Jianbing-D requested review from a team as code owners November 14, 2025 10:03
@yaox12 yaox12 added the dev branch and Expert Review labels Nov 17, 2025
@yaox12 yaox12 changed the title Feat linear cross entropy for dev [Dev] Feature: linear cross entropy fusion Nov 17, 2025
@yanring yanring requested a review from lhb8125 November 17, 2025 03:31