
[BUG] LigerFusedLinearCrossEntropyLoss(reduction='none') results in wrong grad, even causes grad=0 #968

@Hevans123

Description

🐛 Describe the bug

With reduction='none', LigerFusedLinearCrossEntropyLoss returns wrong gradients when the per-token loss is multiplied by a designed weight_mask.

  1. When the target labels contain ignore_index, the gradients incorrectly come back as all zeros.
     Code: (screenshot)
     Output: (screenshot)

  2. When multiplying by a designed weight_mask, the gradient is incorrect (inconsistent with LigerCrossEntropyLoss).
     Code: (screenshot)
     Output: (screenshot)

The full reproduce code is uploaded below; a minimal plain-PyTorch sketch of the expected behavior follows first for reference.
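For context, this is what the expected behavior looks like with plain PyTorch (a minimal sketch added for comparison; it assumes the stock torch.nn.CrossEntropyLoss is the ground truth the Liger kernels should match). With reduction='none' and one target set to ignore_index, the non-ignored positions still receive nonzero gradients:

import torch

# Minimal sketch: plain PyTorch behavior the Liger kernels are expected to
# match (an assumption for illustration, not from the original report).
B_T, V = 8, 15
logits = torch.randn(B_T, V, requires_grad=True)
target = torch.randint(0, V, (B_T,))
target[0] = -100  # mark one position as ignore_index

loss = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=-100)(logits, target)
loss.sum().backward()

print(logits.grad[0].abs().sum())   # ~0: the ignored position gets no gradient
print(logits.grad[1:].abs().sum())  # > 0: every other position still does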

Reproduce

import torch
from torch.nn import CrossEntropyLoss

from liger_kernel.transformers import LigerCrossEntropyLoss
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

# B, T, H, V = 2, 2048, 256, 32000
# B, T, H, V = 2, 1, 10, 15
B, T, H, V = 2, 4, 10, 15

ignore_index = -100
reduction = "none"
device = "cuda"
dtype = torch.float32
scalar = 2

atol, rtol = 1e-8, 1e-5

use_ignore_index = False

if use_ignore_index:
    # Passed
    target_ce = LigerCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
    target_flce = LigerFusedLinearCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
    # torch_ce = CrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
else:
    # Also passed
    target_ce = LigerCrossEntropyLoss(reduction=reduction)
    target_flce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)

_tensor = torch.randn(B * T, H, device=device, dtype=dtype) * scalar
lin_weight = torch.randn(V, H, device=device, dtype=dtype)

_input1 = _tensor.detach().clone().requires_grad_(True)
lin_weight1 = lin_weight.clone().detach().requires_grad_(True)
_input1_mul_weight = _input1 @ lin_weight1.transpose(0, 1)

_input2 = _tensor.detach().clone().requires_grad_(True)
lin_weight2 = lin_weight.clone().detach().requires_grad_(True)


target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

# # Assign some random number of elements as ignore_index
# num_elements_to_assign = torch.randint(
#     1, B * T // 2, (1,)
# ).item()  # Random number of elements to set to ignore_index
# indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]  # Randomly select indices
# target[indices_to_assign] = ignore_index

# target[:B*T//2] = ignore_index
target[:1] = ignore_index

if use_ignore_index:
    output = target_ce(_input1_mul_weight, target)
    output2 = target_flce(lin_weight2, _input2, target)
    mask = (target != ignore_index)
    loss1 = (output * mask).sum() / mask.sum()
    loss2 = (output2 * mask).sum() / mask.sum()
else:
    output = target_ce(_input1_mul_weight, target)
    output2 = target_flce(lin_weight2, _input2, target)
    mask = (target != ignore_index).type_as(output2)
    # mask=torch.randn(B*T,device=device, dtype=dtype)
    print(f'weight_mask:{mask}')
    loss1 = (output * mask).sum()
    loss2 = (output2 * mask).sum()
    # loss1=output
    # loss2=output2*output2

print(f'loss1:{loss1}')
print(f'loss2:{loss2}')

loss1.backward(gradient=torch.ones_like(loss1))
loss2.backward(gradient=torch.ones_like(loss2))

print(f'grad1_sum:{torch.sum(torch.abs(_input1.grad))}')

print(f'grad2_sum:{torch.sum(torch.abs(_input2.grad))}')
# assert torch.allclose(loss1, loss2, atol=atol, rtol=rtol)
# assert torch.allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol)
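
As a sanity check (an addition for comparison, not part of the original script), the same masked-sum loss computed through plain torch.nn.CrossEntropyLoss gives the reference gradients; with use_ignore_index = False, grad3_sum should match grad1_sum while grad2_sum is the one that diverges:

# Reference check (a sketch assuming the stock PyTorch loss with
# reduction="none" is the ground truth).
_input3 = _tensor.detach().clone().requires_grad_(True)
lin_weight3 = lin_weight.clone().detach().requires_grad_(True)
logits3 = _input3 @ lin_weight3.transpose(0, 1)

torch_ce = CrossEntropyLoss(ignore_index=ignore_index, reduction="none")
loss3 = (torch_ce(logits3, target) * mask).sum()
loss3.backward()

print(f'grad3_sum:{torch.sum(torch.abs(_input3.grad))}')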

Versions

Environment Versions:

Python version: 3.11.11
Liger Kernel version: 0.6.4
PyTorch version: 2.6.0+cu126
CUDA version: 12.6
Triton version: 3.2.0
Transformers version: 4.56.0
