Enable fast attention in nanoPET #454
base: main
Conversation
I think a proper backward-compatible implementation is on the way. We can discuss it soon. Upd.: not backward compatible in the sense clarified later; I meant supporting the backward pass.
Does it give the same result as the original version of the code? If so, I would be quite surprised.
What do you mean by saying that it's not backward compatible? I don't see why the backward pass should not work.
Backward compatible in the API sense, i.e. that the same checkpoint will produce the same results without retraining.
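As a minimal sketch (hypothetical, not the PR's actual test) of the check discussed above, one could compare torch's fused `scaled_dot_product_attention` against a plain softmax-attention implementation, in both forward and backward, to confirm that an existing checkpoint would produce the same results up to floating-point tolerance. The shapes and the boolean mask below are placeholders.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(2, 4, 16, 32, requires_grad=True)
k = torch.randn(2, 4, 16, 32, requires_grad=True)
v = torch.randn(2, 4, 16, 32, requires_grad=True)
mask = torch.ones(2, 4, 16, 16, dtype=torch.bool)  # hypothetical mask: True = attend

# Reference: explicit softmax attention with the same boolean-mask semantics
scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
scores = scores.masked_fill(~mask, float("-inf"))
ref = scores.softmax(dim=-1) @ v

# Fused path
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Forward agreement (up to floating-point tolerance)
print(torch.allclose(ref, fused, atol=1e-6))

# Backward agreement on the gradients of q, k, v
g_ref = torch.autograd.grad(ref.sum(), (q, k, v), retain_graph=True)
g_fused = torch.autograd.grad(fused.sum(), (q, k, v))
print(all(torch.allclose(a, b, atol=1e-6) for a, b in zip(g_ref, g_fused)))
```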
@spozdn I think that using a custom attn_mask here will prevent torch from using flash attention. You'll find more details in the flash-attention repository (https://github.com/Dao-AILab/flash-attention), where they say that they only implement full (all-True) and causal masks, and in the torch.nn.functional.scaled_dot_product_attention documentation (https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html). Anyway, the snippet below can be used to check which backend is used. EDIT: Flash attention doesn't support a custom mask, but cuDNN attention does, so we'll go with it.
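A minimal sketch of such a check, assuming a recent PyTorch (2.4+, where `SDPBackend.CUDNN_ATTENTION` is available) and a CUDA device: restrict SDPA to a single backend with the `sdpa_kernel` context manager and see whether the call with a custom mask succeeds or raises. The tensor shapes and the all-True mask are placeholders, not the nanoPET code.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

device = "cuda"  # requires a CUDA device
q = torch.randn(2, 8, 128, 64, device=device, dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device=device, dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device=device, dtype=torch.float16)
# Hypothetical custom boolean mask (True = attend), broadcastable to the scores
attn_mask = torch.ones(2, 8, 128, 128, device=device, dtype=torch.bool)

for backend in (SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION):
    try:
        # Restrict SDPA to this single backend; it raises if the backend
        # cannot handle the given inputs (e.g. a non-causal custom mask).
        with sdpa_kernel(backend):
            F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        print(f"{backend.name}: usable with this mask")
    except RuntimeError as err:
        print(f"{backend.name}: not usable ({err})")
```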
Force-pushed from 4274feb to df0587e
Force-pushed from df0587e to f87835e
Force-pushed from 3372801 to 4541da7
Force-pushed from f5df1b4 to 9f846db
Force-pushed from 9f846db to 87704fe
📚 Documentation preview 📚: https://metatrain--454.org.readthedocs.build/en/454/