[NPU] optimize fused_linear_cross_entropy #1222
Conversation
## Main Modifications

**Introduce a plain fast path for Ascend CE.** Add a new kernel, `liger_cross_entropy_forward_kernel_plain`: a lightweight kernel for vanilla cross-entropy cases without `ce_weight`, softcap, label smoothing, z-loss, token metrics, or token scaling.

**Remove upfront gradient computation in the forward pass.** The original implementation precomputed and cached `grad_input` / `grad_weight` / `grad_bias` during the forward pass. The new implementation only computes the loss and essential statistics, and defers all actual gradient calculation to `LigerFusedLinearCrossEntropyFunction.backward()`.

**Reimplement the backward pass with chunked computation.** Reuse cached forward logits if available; otherwise recompute `input @ weight.T` in the backward pass. The CE backward kernel produces `grad_logits`, and GEMM is then applied to derive `grad_input`, `grad_weight`, and `grad_bias`.

**Add a logits storage limit via `_get_logits_save_limit_bytes()`.** Under the plain fast path, logits are cached when their size is within the device memory budget, avoiding a redundant GEMM in the backward pass. The limit is defined as `min(4 GiB, total_memory / 8)` with a hard minimum of 1 GiB.

**Adjust the token scaling logic.** The legacy implementation gathered valid targets separately. The new logic adopts `safe_targets` combined with masking, and caches the complete `scaling_factors_full` for backward use.

**Remove the `element_mul_kernel` path.** The old backward logic multiplied precomputed forward gradients by the upstream `grad_output`. The new implementation consumes `grad_output` directly inside the CE backward kernel.

## Motivation

The core objective is to optimize the performance and gradient semantics of fused linear cross entropy on Ascend. The legacy implementation placed heavy gradient computation in the forward pass, making the forward stage shoulder backward workload. It also failed to fully utilize the actual `grad_output` passed from autograd, especially for `reduction="none"` scenarios where upstream gradients are vector-valued.

By moving gradient computation back to the backward pass, CE gradients are computed strictly from the real incoming `grad_output`, yielding cleaner and more consistent semantics. Furthermore, vanilla cross-entropy is a high-frequency path in language model training. The newly added plain fast path leverages a larger block size, fewer kernel branches, and reduced Python/kernel-launch overhead. Logits are reused within a controlled memory budget to lower backward recomputation cost. Overall, this change achieves an Ascend-optimized trade-off between memory footprint and GEMM/kernel-launch efficiency.
```python
# Directly load x[y]. Attempting to "extract" it from the scanned blocks
# tends to increase UB/register pressure on Ascend.
x_y = tl.load(logits_row_ptr + y).cast(tl.float32)
```
Do you mean reading it after the for loop might hurt performance?
Load `x[y]` directly, outside the scan loop. Avoiding selection logic inside the block loop reduces unified buffer (UB) and register pressure on Ascend, thereby improving performance.
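The per-row math the plain kernel implements can be written as a NumPy reference: a blocked, numerically stable logsumexp scan over the vocabulary, with `x[y]` loaded once outside the loop as discussed above (an illustrative sketch of the algorithm, not the Triton kernel itself):

```python
import numpy as np

def plain_ce_row(logits_row: np.ndarray, y: int, block_size: int = 1024) -> float:
    """One row of plain CE: loss = logsumexp(x) - x[y], scanned in blocks."""
    m = -np.inf  # running max
    s = 0.0      # running sum of exp(x - m)
    for start in range(0, logits_row.size, block_size):
        blk = logits_row[start:start + block_size].astype(np.float64)
        m_new = max(m, blk.max())
        s = s * np.exp(m - m_new) + np.exp(blk - m_new).sum()
        m = m_new
    x_y = float(logits_row[y])  # direct load, outside the scan loop
    return float((m + np.log(s)) - x_y)
```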
```python
# For plain CE, prefer a larger BLOCK_SIZE (fewer loop iters) using the same
# NPU-oriented tuning as `cross_entropy_forward`'s plain_lm path.
if plain_fast_path:
    tile_shapes = compute_default_tiling_strategy(
        safety_margin=0.9,
        dtype_size=4,
        memory_multiplier=2.5,
        shapes=((V,),),
        tiling_dims=(0,),
    )
    forward_block_size = max(256, tile_shapes[0][0]) if tile_shapes else 8192
```
```python
# If we're in the plain path and bias-free, launch forward once for all rows
# to reduce Python overhead and kernel launch count.
scaling_factors_full = None
logits_for_backward = None
if input_requires_grad and plain_fast_path and bias is None:
    # Save logits for backward when memory allows, avoiding an extra
    # input @ weight.T in backward. This is especially important once BT is
    # large enough that the GEMM dominates the plain CE backward path.
    bytes_per_elem = logits.element_size()
    if BT * V * bytes_per_elem <= _get_logits_save_limit_bytes(device):
        logits_for_backward = logits
```
```python
if plain_fast_path and bias is None:
    if not logits.is_contiguous():
        logits = logits.contiguous()
    if target.stride(-1) != 1:
        target = target.contiguous()
    liger_cross_entropy_forward_kernel_plain[(BT,)](
        X_ptr=logits,
        X_stride=logits.stride(-2),
        Y_ptr=target,
        loss_ptr=loss_1d,
        n_cols=V,
        n_rows=BT,
        ce_stats_ptr=ce_stats,
        ignore_index=ignore_index,
        reduction=reduction,
        BLOCK_SIZE=forward_block_size,
    )
    loss = loss_1d if reduction == "none" else torch.sum(loss_1d)
    return (
        loss,
        None,
        None,
        None,
        loss_1d,
        ce_stats,
        None,
        plain_fast_path,
        logits_for_backward,
    )
```
This fallback makes sense, but shouldn't we inform the user to use cross entropy instead of fused linear cross entropy?
Added logging in `fused_linear_cross_entropy_forward`: if the caller already has logits available, it is recommended to use the standalone cross-entropy operator directly. The warning is triggered only once per process to avoid redundant alerts at every training step.





Summary
Move gradient computation out of forward, add a plain CE fast path, and optionally reuse saved logits to improve Ascend training performance while preserving correct grad_output handling.
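The backward structure this PR moves to, where the CE backward produces `grad_logits` and GEMMs then derive the parameter gradients, can be illustrated as a NumPy reference (a sketch that ignores `ignore_index`, bias, and chunking; names are illustrative):

```python
import numpy as np

def flce_backward_sketch(x, w, target, grad_out):
    """Reference backward for plain fused-linear CE with reduction='none':
    grad_logits = (softmax(logits) - one_hot(target)) * grad_out,
    then GEMMs yield grad_x and grad_w."""
    logits = x @ w.T                                   # recomputed if not cached
    z = logits - logits.max(axis=1, keepdims=True)     # stable softmax
    p = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    p[np.arange(len(target)), target] -= 1.0           # softmax - one-hot
    grad_logits = p * grad_out[:, None]                # scale by upstream grad
    grad_x = grad_logits @ w                           # GEMM 1
    grad_w = grad_logits.T @ x                         # GEMM 2
    return grad_x, grad_w
```

Because the upstream `grad_output` is applied directly to `grad_logits`, vector-valued gradients from `reduction="none"` are handled without a separate `element_mul_kernel` pass.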
Testing Done
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence