
Enable fast attention in nanoPET #454

Open · frostedoyster wants to merge 8 commits into main from flash-attention-nanopet
Conversation

frostedoyster (Collaborator) commented Jan 26, 2025


📚 Documentation preview 📚: https://metatrain--454.org.readthedocs.build/en/454/

frostedoyster added the "Discussion" (Issues to be discussed by the contributors) label on Jan 26, 2025
spozdn (Collaborator) commented Jan 26, 2025

I think a proper backward-compatible implementation is on the way. We can discuss it soon.

Update: not "backward compatible" in the sense clarified later; I meant supporting the backward pass.

abmazitov (Contributor) commented

Does it give the same result as the original version of the code? If so, I would be quite surprised.

abmazitov (Contributor) commented

What do you mean by saying that it’s not backward compatible? I don’t see why the backward pass should not work.

Luthaf (Member) commented Jan 27, 2025

> What do you mean by saying that it’s not backward compatible? I don’t see why the backward pass should not work.

Backward compatible in the API sense, i.e. that the same checkpoint will produce the same results without retraining.
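
For reference, a minimal sketch (not part of this PR; shapes and tolerances are illustrative assumptions) of the equivalence check being discussed, comparing masked F.scaled_dot_product_attention against an explicit softmax(QK^T / sqrt(d)) V reference:

```python
# Sketch (not from nanoPET): check that masked scaled_dot_product_attention
# reproduces an explicit softmax(QK^T / sqrt(d)) V computation. Shapes are illustrative.
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 16, 8) for _ in range(3))
# Boolean mask, True = allowed to attend (mostly True so no row is fully masked)
attn_mask = torch.rand(2, 1, 16, 16) > 0.1

sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
scores = scores.masked_fill(~attn_mask, float("-inf"))
manual_out = torch.softmax(scores, dim=-1) @ v

print(torch.allclose(sdpa_out, manual_out, atol=1e-5))  # expected to print True
```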

frostedoyster (Collaborator, Author) commented Jan 28, 2025

@spozdn I think that using a custom attn_mask here

attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

will prevent torch from using flash attention.

You'll find more details here, where they say that they only implement full (all-True) and causal masks: https://github.com/Dao-AILab/flash-attention

And here (https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html), where they say: "All implementations are enabled by default. Scaled dot product attention attempts to automatically select the most optimal implementation based on the inputs."

Anyway, this can be used to check which backend is being used:
https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html#torch.nn.attention.sdpa_kernel

EDIT: Flash attention doesn't support a custom mask, but cuDNN attention does, so we'll go with that.
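
For reference, a minimal sketch (not part of this PR; shapes, dtype and the backend list are illustrative assumptions) of how sdpa_kernel can be used to check which backends accept a custom attn_mask on a CUDA machine:

```python
# Sketch: restrict scaled_dot_product_attention to one backend at a time and see
# which ones accept a custom attn_mask. Assumes PyTorch >= 2.3 (sdpa_kernel) with
# CUDA; the CUDNN_ATTENTION backend needs a fairly recent build.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

device, dtype = "cuda", torch.float16
q, k, v = (torch.randn(1, 8, 128, 64, device=device, dtype=dtype) for _ in range(3))
attn_mask = torch.ones(1, 1, 128, 128, dtype=torch.bool, device=device)  # toy mask

for backend in (SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION,
                SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH):
    try:
        with sdpa_kernel(backend):
            F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        print(f"{backend.name}: works with a custom attn_mask")
    except RuntimeError:
        print(f"{backend.name}: no kernel available for these inputs")
```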

frostedoyster force-pushed the flash-attention-nanopet branch from 4274feb to df0587e on January 31, 2025 08:10
frostedoyster changed the title from "Weird trick to enable flash attention in nanoPET" to "Enable fast attention in nanoPET" on Jan 31, 2025
frostedoyster force-pushed the flash-attention-nanopet branch from df0587e to f87835e on January 31, 2025 08:14
frostedoyster removed the "Discussion" (Issues to be discussed by the contributors) label on Jan 31, 2025
frostedoyster force-pushed the flash-attention-nanopet branch 6 times, most recently from 3372801 to 4541da7, on January 31, 2025 14:14
frostedoyster force-pushed the flash-attention-nanopet branch 2 times, most recently from f5df1b4 to 9f846db, on January 31, 2025 14:43
frostedoyster force-pushed the flash-attention-nanopet branch from 9f846db to 87704fe on January 31, 2025 14:45