Conversation


@tongyx361 tongyx361 commented Oct 20, 2025

Trying to resolve #1930 (NaNs when many of the logits are masked (-inf)).
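
For context, here is a minimal pure-Python sketch of the failure mode named in the title: in a streaming (online) softmax/logsumexp, when every logit seen so far is -inf, the new running max m_i_new is -inf and the rescale factor exp(m_i - m_i_new) evaluates exp(-inf - (-inf)) = nan, which then propagates. The guard below (substituting 0 for an -inf running max before exponentiating) only illustrates the general technique; the actual Triton kernel change in this PR may differ in detail.

import math

def online_logsumexp(chunks):
    # Streaming logsumexp over chunks of logits, as in an online softmax.
    m = float('-inf')  # running max
    s = 0.0            # running sum of exp(x - m)
    for chunk in chunks:
        m_new = max(m, max(chunk))
        # Without a guard, an all--inf chunk gives m_new = -inf and the
        # rescale exp(m - m_new) = exp(-inf - (-inf)) = exp(nan) = nan.
        m_safe = 0.0 if m_new == float('-inf') else m_new
        s = s * math.exp(m - m_safe) + sum(math.exp(x - m_safe) for x in chunk)
        m = m_new
    return m + math.log(s) if s > 0.0 else float('-inf')

# With the guard this is finite; without m_safe it returns nan.
print(online_logsumexp([[float('-inf')] * 4, [10.0, float('-inf')]]))  # 10.0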

Test

Environment: H800, CUDA 12.6

Warning

Sorry, I don't have access to newer architectures; I'll have to rely on other contributors to help test on those.

FLASH_ATTN_CUDA_ARCHS="90" python setup.py install

Test for the issue

import torch
from flash_attn.ops.triton.cross_entropy import cross_entropy_loss

# All logits masked to -inf except the label's, which drives the
# online-softmax running max to -inf on most blocks (issue #1930).
V = 100000
row = torch.full((V,), float('-inf'), device='cuda', dtype=torch.float32)
i = 50000
row[i] = 10
logits = row.unsqueeze(0)  # [1, V]
labels = torch.tensor([i], device='cuda', dtype=torch.long)

# Softmax puts probability 1 on index i, so the expected loss is 0, not NaN.
loss = cross_entropy_loss(logits, labels)[0]
print(torch.isfinite(loss))
print(loss)

Output:

tensor([True], device='cuda:0')
tensor([0.], device='cuda:0')
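
As a cross-check (continuing the script above), PyTorch's eager cross entropy on the same inputs should give the same value, since log_softmax handles the -inf entries and the softmax puts probability 1 on index i:

import torch.nn.functional as F

# Reference: eager-mode cross entropy on the same logits/labels.
ref = F.cross_entropy(logits, labels)
print(ref)  # expected: tensor(0., device='cuda:0')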

Existing test

$ pytest -q -s tests/losses/test_cross_entropy.py 
...................................................sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...................................................sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss...sss
240 passed, 144 skipped in 10.59s
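
In case it helps, the repro above could be folded into the suite as a regression test; the test name and exact assertions below are hypothetical, not part of this PR:

import torch
from flash_attn.ops.triton.cross_entropy import cross_entropy_loss

def test_cross_entropy_mostly_masked_logits():
    # Regression test for #1930: all logits -inf except the label's.
    V, i = 100000, 50000
    logits = torch.full((1, V), float('-inf'), device='cuda', dtype=torch.float32)
    logits[0, i] = 10
    labels = torch.tensor([i], device='cuda', dtype=torch.long)
    loss = cross_entropy_loss(logits, labels)[0]
    assert torch.isfinite(loss).all()
    assert torch.allclose(loss, torch.zeros_like(loss))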

@tongyx361 tongyx361 changed the title from "fix: nan when m_i_new=-inf" to "fix: nan when m_i_new=-inf in online softmax" Oct 20, 2025
@tongyx361 tongyx361 marked this pull request as draft October 20, 2025 14:38
@tongyx361 tongyx361 marked this pull request as ready for review October 20, 2025 15:33