🐛 Describe the bug

With reduction="none", LigerFusedLinearCrossEntropyLoss returns incorrect gradients when the per-token losses are multiplied by a custom weight_mask before reduction, i.e. loss = sum_i(mask_i * ce_i). The reproduction script is below.

Reproduce
import torch
from torch.nn import CrossEntropyLoss
from liger_kernel.transformers import LigerCrossEntropyLoss
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss
# B, T, H, V = 2, 2048, 256, 32000
# B, T, H, V = 2, 1, 10, 15
B, T, H, V = 2, 4, 10, 15
ignore_index = -100
reduction = "none"
device = "cuda"
dtype = torch.float32
scalar = 2
atol, rtol = 1e-8, 1e-5
use_ignore_index = False
if use_ignore_index:
    # Passed
    target_ce = LigerCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
    target_flce = LigerFusedLinearCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
    # torch_ce = CrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
else:
    # Also passed
    target_ce = LigerCrossEntropyLoss(reduction=reduction)
    target_flce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
_tensor = torch.randn(B * T, H, device=device, dtype=dtype) * scalar
lin_weight = torch.randn(V, H, device=device, dtype=dtype)

# Path 1: materialize the logits explicitly and apply LigerCrossEntropyLoss.
_input1 = _tensor.detach().clone().requires_grad_(True)
lin_weight1 = lin_weight.detach().clone().requires_grad_(True)
_input1_mul_weight = _input1 @ lin_weight1.transpose(0, 1)

# Path 2: pass hidden states and weight to LigerFusedLinearCrossEntropyLoss.
_input2 = _tensor.detach().clone().requires_grad_(True)
lin_weight2 = lin_weight.detach().clone().requires_grad_(True)

target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
# Optionally mark a random subset of positions as ignore_index:
# num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
# indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
# target[indices_to_assign] = ignore_index
# target[:B * T // 2] = ignore_index
target[:1] = ignore_index
if use_ignore_index:
    output = target_ce(_input1_mul_weight, target)
    output2 = target_flce(lin_weight2, _input2, target)
    mask = target != ignore_index
    loss1 = (output * mask).sum() / mask.sum()
    loss2 = (output2 * mask).sum() / mask.sum()
else:
    output = target_ce(_input1_mul_weight, target)
    output2 = target_flce(lin_weight2, _input2, target)
    mask = (target != ignore_index).type_as(output2)
    # mask = torch.randn(B * T, device=device, dtype=dtype)
    print(f"weight_mask: {mask}")
    loss1 = (output * mask).sum()
    loss2 = (output2 * mask).sum()
# loss1 = output
# loss2 = output2 * output2
print(f"loss1: {loss1}")
print(f"loss2: {loss2}")
# Both losses are already scalars after .sum(), so a plain backward() is enough.
loss1.backward()
loss2.backward()
print(f"grad1_sum: {torch.sum(torch.abs(_input1.grad))}")
print(f"grad2_sum: {torch.sum(torch.abs(_input2.grad))}")
# assert torch.allclose(loss1, loss2, atol=atol, rtol=rtol)
# assert torch.allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol)
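CrossEntropyLoss is imported above but never used, so for completeness here is a minimal pure-PyTorch baseline under the same setup (the _input3 / lin_weight3 names are mine, added for illustration). Since loss = sum_i(mask_i * ce_i), the upstream gradient reaching token i's per-token loss is exactly mask_i, so masked-out tokens should receive zero gradient:

# Reference path: plain torch.nn.CrossEntropyLoss on explicit logits,
# reusing _tensor, lin_weight, target, and mask from the script above.
_input3 = _tensor.detach().clone().requires_grad_(True)
lin_weight3 = lin_weight.detach().clone().requires_grad_(True)
torch_ce = CrossEntropyLoss(reduction="none")
output3 = torch_ce(_input3 @ lin_weight3.transpose(0, 1), target)
loss3 = (output3 * mask).sum()
loss3.backward()
print(f"grad3_sum: {torch.sum(torch.abs(_input3.grad))}")
# Expected: grad1_sum and grad2_sum both match grad3_sum; per the report,
# the fused path (grad2_sum) deviates.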
Versions
Python version: 3.11.11
Liger Kernel version: 0.6.4
PyTorch version: 2.6.0+cu126
CUDA version: 12.6
Triton version: 3.2.0
Transformers version: 4.56.0