[NPU] optimize fused_linear_cross_entropy #1222

Merged: Tcc0403 merged 1 commit into linkedin:main from sunyi0505:fused_linear_cross_entropy on May 15, 2026

Conversation

@sunyi0505 (Contributor) commented May 13, 2026

Summary

Move gradient computation out of forward, add a plain CE fast path, and optionally reuse saved logits to improve Ascend training performance while preserving correct grad_output handling.

Testing Done

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@sunyi0505 (Contributor, Author):

Main Modifications

Introduce plain fast path for Ascend CE:

Add a new kernel, liger_cross_entropy_forward_kernel_plain: a lightweight kernel for vanilla cross-entropy cases without ce_weight, softcap, label smoothing, z loss, token metrics, or token scaling.
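The dispatch condition for this fast path can be sketched as a simple predicate: the plain kernel handles only vanilla CE, so every optional feature must be disabled. This is an illustrative, hypothetical helper (the parameter names mirror common Liger CE options but are assumptions, not the PR's exact signature):

```python
def is_plain_ce_fast_path(ce_weight, lse_square_scale, label_smoothing,
                          softcap, return_z_loss, return_token_metrics=False,
                          token_scale=None):
    """Illustrative eligibility check (hypothetical helper, not the PR's
    actual code): the plain kernel only covers vanilla cross-entropy, so
    every optional feature must be off/None."""
    return (
        ce_weight is None               # no per-class weighting
        and lse_square_scale == 0.0     # no z loss
        and label_smoothing == 0.0
        and softcap is None
        and not return_z_loss
        and not return_token_metrics
        and token_scale is None         # no token scaling
    )
```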

Remove upfront gradient computation in forward pass

The original implementation precomputed and cached grad_input / grad_weight / grad_bias during the forward pass. The new implementation only computes loss and essential statistics, and offloads all actual gradient calculation to LigerFusedLinearCrossEntropyFunction.backward().

Reimplement backward pass with chunked computation

Reuse cached forward logits if available; otherwise recompute input @ weight.T in the backward pass. The CE backward kernel produces grad_logits, and GEMM is then applied to derive grad_input, grad_weight, and grad_bias.
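The backward math being described reduces to: grad_logits = softmax(logits) − one_hot(target), then two GEMMs. A minimal NumPy sketch of that flow (illustrative only; the real implementation is a chunked Triton kernel, and function/variable names here are assumptions):

```python
import numpy as np

def plain_ce_backward(X, W, target, logits=None):
    """Sketch of the backward flow described above: reuse cached logits when
    available, otherwise recompute X @ W.T; then turn grad_logits into
    grad_input / grad_weight via GEMM. Not the actual kernel."""
    if logits is None:
        logits = X @ W.T                              # recompute (BT, V)
    # d(loss)/d(logits) for plain CE: softmax(logits) - one_hot(target)
    z = logits - logits.max(axis=1, keepdims=True)    # stabilized softmax
    grad_logits = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    grad_logits[np.arange(len(target)), target] -= 1.0
    grad_input = grad_logits @ W                      # (BT, H)
    grad_weight = grad_logits.T @ X                   # (V, H)
    return grad_input, grad_weight
```

Whether logits are cached or recomputed, the result is identical; caching only trades memory for the extra GEMM.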

Add logits storage limit via _get_logits_save_limit_bytes()

Under the plain fast path, logits are cached when their size is within device memory budget to avoid redundant GEMM in the backward pass. The limit is defined as min(4GiB, total_memory / 8) with a hard minimum of 1GiB.
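The stated policy, min(4 GiB, total_memory / 8) with a 1 GiB floor, can be written directly (a sketch of the described limit; the real `_get_logits_save_limit_bytes()` presumably queries the device's total memory itself):

```python
GiB = 1 << 30

def logits_save_limit_bytes(total_memory_bytes):
    """Sketch of the described policy for _get_logits_save_limit_bytes():
    cap the logits cache at min(4 GiB, total / 8), never below 1 GiB."""
    return max(1 * GiB, min(4 * GiB, total_memory_bytes // 8))
```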

Adjust token scaling logic

The legacy implementation gathered valid targets separately. The new logic adopts safe_targets combined with masking, and caches the complete scaling_factors_full for backward usage.
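The safe_targets-plus-mask idea: clamp ignored positions to a valid class index so gathers stay in-bounds, and keep a mask that zeroes their contribution afterwards. An illustrative NumPy version (helper name and the choice of class 0 as the clamp value are assumptions):

```python
import numpy as np

def make_safe_targets(target, ignore_index=-100):
    """Illustrative version of the safe_targets + masking scheme: replace
    ignore_index entries with class 0 so indexing is always valid, and
    return a mask (the 'scaling_factors_full' cached for backward) that
    zeroes out those positions' loss/gradient contributions."""
    valid = target != ignore_index
    safe_targets = np.where(valid, target, 0)
    scaling_factors_full = valid.astype(np.float32)
    return safe_targets, scaling_factors_full
```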

Remove the element_mul_kernel path

The old backward logic multiplied precomputed forward gradients by the upstream grad_output. The new implementation consumes grad_output directly inside the CE backward kernel.
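The key semantic difference is how grad_output enters: instead of multiplying precomputed forward gradients by it afterwards, the backward kernel scales grad_logits by it directly, which handles vector-valued upstream gradients under reduction="none". A minimal sketch of that scaling (illustrative; the real work happens inside the Triton kernel):

```python
import numpy as np

def apply_grad_output(grad_logits, grad_output, reduction):
    """Sketch of consuming grad_output inside the CE backward instead of
    the removed element_mul_kernel pass: a scalar for 'mean'/'sum', a
    per-token vector for reduction='none'."""
    if reduction == "none":
        # vector-valued upstream gradient: one scale per token row
        return grad_logits * np.asarray(grad_output)[:, None]
    return grad_logits * grad_output  # scalar upstream gradient
```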

Motivation

The core objective is to optimize the performance and gradient semantics of fused linear cross entropy on Ascend.

The legacy implementation placed heavy gradient computation in the forward pass, forcing the forward stage to carry backward workload. It also failed to fully utilize the actual grad_output passed from autograd, especially in reduction="none" scenarios where the upstream gradient is vector-valued. By moving gradient computation back to the backward pass, CE gradients are computed strictly from the real incoming grad_output, yielding cleaner and more consistent semantics.

Furthermore, vanilla cross-entropy is a high-frequency path in language model training. The newly added plain fast path leverages larger block size, fewer kernel branches, and reduced Python/kernel launch overhead. Logits are reused within a controlled memory budget to lower backward recomputation cost. Overall, this change achieves an Ascend-optimized trade-off between memory footprint and GEMM / kernel launch efficiency.

Benchmark screenshots attached in the PR: forward, backward, full, no_grad_forward, and unit-test results.

@sunyi0505 sunyi0505 force-pushed the fused_linear_cross_entropy branch from ac85ea6 to 0779a47 Compare May 13, 2026 09:25
Comment on lines +241 to +243
# Directly load x[y]. Attempting to "extract" it from the scanned blocks
# tends to increase UB/register pressure on Ascend.
x_y = tl.load(logits_row_ptr + y).cast(tl.float32)
Collaborator:

Do you mean that reading it after the for loop might hurt performance?

Contributor Author:

x[y] is loaded directly, outside the scan loop. Embedding the selection logic into the block loop tends to increase Unified Buffer (UB) and register pressure on Ascend, so keeping it out improves performance.

Three outdated comment threads on src/liger_kernel/ops/backends/_ascend/ops/fused_linear_cross_entropy.py
Comment on lines +109 to +161
# For plain CE, prefer a larger BLOCK_SIZE (fewer loop iters) using the same
# NPU-oriented tuning as `cross_entropy_forward`'s plain_lm path.
if plain_fast_path:
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9,
dtype_size=4,
memory_multiplier=2.5,
shapes=((V,),),
tiling_dims=(0,),
)
forward_block_size = max(256, tile_shapes[0][0]) if tile_shapes else 8192

# If we're in the plain path and bias-free, launch forward once for all rows
# to reduce Python overhead and kernel launch count.
scaling_factors_full = None
logits_for_backward = None
if input_requires_grad and plain_fast_path and bias is None:
# Save logits for backward when memory allows, avoiding an extra
# input @ weight.T in backward. This is especially important once BT is
# large enough that the GEMM dominates the plain CE backward path.
bytes_per_elem = logits.element_size()
if BT * V * bytes_per_elem <= _get_logits_save_limit_bytes(device):
logits_for_backward = logits

if plain_fast_path and bias is None:
if not logits.is_contiguous():
logits = logits.contiguous()
if target.stride(-1) != 1:
target = target.contiguous()
liger_cross_entropy_forward_kernel_plain[(BT,)](
X_ptr=logits,
X_stride=logits.stride(-2),
Y_ptr=target,
loss_ptr=loss_1d,
n_cols=V,
n_rows=BT,
ce_stats_ptr=ce_stats,
ignore_index=ignore_index,
reduction=reduction,
BLOCK_SIZE=forward_block_size,
)
loss = loss_1d if reduction == "none" else torch.sum(loss_1d)
return (
loss,
None,
None,
None,
loss_1d,
ce_stats,
None,
plain_fast_path,
logits_for_backward,
)
Collaborator:

This fallback makes sense, but shouldn't we inform the user to use cross entropy instead of fused linear cross entropy?

Contributor Author:

Added logging in fused_linear_cross_entropy_forward: if the caller already has logits available, it is recommended to use the standalone cross-entropy operator directly. The warning is triggered only once per process to avoid redundant alerts at every training step.

@sunyi0505 sunyi0505 force-pushed the fused_linear_cross_entropy branch 2 times, most recently from 3816fea to f1ecd1c Compare May 14, 2026 03:08
@sunyi0505 sunyi0505 requested a review from Tcc0403 May 14, 2026 03:10
@sunyi0505 sunyi0505 force-pushed the fused_linear_cross_entropy branch from f1ecd1c to e9c8f01 Compare May 15, 2026 01:38
@sunyi0505 sunyi0505 force-pushed the fused_linear_cross_entropy branch from e9c8f01 to 4e7b81f Compare May 15, 2026 01:38
@Tcc0403 (Collaborator) left a comment:

lgtm

@Tcc0403 Tcc0403 added this pull request to the merge queue May 15, 2026
Merged via the queue into linkedin:main with commit d5bf025 May 15, 2026
5 of 7 checks passed