Numerically stable parallel cumsum-based WKV + jax/tf/keras implementations #189
jackd started this conversation in Show and tell
Hi there, love the work. Having dug through the paper in more detail recently, I realized the WKV implementation has some similarities with ongoing work I'm involved in, so I hacked up a proof of concept using keras / keras-nlp here. I've included a theory page, but to summarise:
w**(t-1), it can be expressed as a cumulative sum(v, t), where the actual value z is represented as z = exp(t) * v.
I've included a very rough performance summary, which shows promise.
That said:
This is a side-project of a side-project for me, so while I've enjoyed doing it, I can't afford to spend much longer fine-tuning a backend I understand little about. I'm pretty confident a CUDA implementation based on Thrust's inclusive_scan would be straightforward and would perform considerably better than my Triton implementation, but having never written custom PyTorch bindings, that's a project I'm going to pass on (if anyone decides to take it up, I can offer a basic sketch).
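For anyone who wants the gist without the theory page: the summary of the method is truncated, so what follows is only a minimal sketch of the general idea, not jackd's implementation. It assumes the quantity of interest is an exponentially decayed sum y_t = sum over i <= t of w**(t - i) * x_i, with 0 < w < 1. Each term is stored as a pair (v, t) meaning exp(t) * v. Pairs are added by rescaling both to the larger exponent, so no exp() call can overflow. Because that addition is associative, the cumulative sum can be computed with any inclusive scan. A Hillis-Steele scan is used here as a stand-in for something like Thrust's inclusive_scan. The names add_pairs, inclusive_scan, decayed_sums and decayed_sums_recurrent are hypothetical, and the code is plain Python rather than jax/tf/keras.

```python
import math
from typing import List, Tuple

# A value z is stored as a pair (v, t) meaning z = exp(t) * v.
Pair = Tuple[float, float]


def add_pairs(a: Pair, b: Pair) -> Pair:
    """Numerically stable sum of two (v, t) pairs (hypothetical helper).

    Both values are rescaled to m = max(t_a, t_b), so each exp() argument
    is <= 0: it may underflow to 0.0 but never overflows. The operation is
    associative, which is what makes a parallel scan possible.
    """
    (va, ta), (vb, tb) = a, b
    m = max(ta, tb)
    return (va * math.exp(ta - m) + vb * math.exp(tb - m), m)


def inclusive_scan(xs: List[Pair]) -> List[Pair]:
    """Hillis-Steele inclusive scan: ceil(log2(n)) rounds. Within a round
    every element could be updated in parallel (done sequentially here)."""
    out = list(xs)
    step = 1
    while step < len(out):
        # The right-hand side reads the previous round's list before reassignment.
        out = [out[i] if i < step else add_pairs(out[i - step], out[i])
               for i in range(len(out))]
        step *= 2
    return out


def decayed_sums(xs: List[float], w: float) -> List[float]:
    """y_t = sum_{i<=t} w**(t - i) * x_i computed via a cumulative sum.

    Rewriting as y_t = w**t * cumsum(x_i * w**(-i)) is parallel-friendly, but
    w**(-i) overflows for long sequences. Storing each term as the pair
    (x_i, -i * log w) keeps every intermediate value finite.
    """
    log_w = math.log(w)
    terms = [(x, -i * log_w) for i, x in enumerate(xs)]
    # Multiply each prefix sum by w**t, i.e. add t * log w to its exponent.
    return [v * math.exp(t + i * log_w)
            for i, (v, t) in enumerate(inclusive_scan(terms))]


def decayed_sums_recurrent(xs: List[float], w: float) -> List[float]:
    """Sequential reference: y_t = w * y_{t-1} + x_t."""
    y, out = 0.0, []
    for x in xs:
        y = w * y + x
        out.append(y)
    return out
```

With 5000 ones and w = 0.5, the naive rewrite would need 0.5 ** -4999, which does not fit in a float64. decayed_sums stays finite and agrees with the sequential recurrence. In JAX, an array version of add_pairs (built on jnp.maximum and jnp.exp) could in principle be passed to jax.lax.associative_scan to get a genuinely parallel scan.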
Hope this helps someone :)