[WIP] deepseek v32 #512
base: main
Conversation
This is functional, but unfortunately quite slow. You pay a penalty for the top-k gather at shorter context lengths, and from some micro-benchmarks I think it would take fairly long context for this to be more efficient than the naive attention. We can maybe improve on the current implementation with a […]

Here's the micro-benchmark for reference:

```python
import copy
import time

import mlx.core as mx
from mlx_lm import load
from mlx_lm.models.cache import CacheList, KCache, KVCache

model, _ = load("mlx_model", lazy=True)
attn = model.layers[0].self_attn
# attn.indexer.index_topk = float("inf")

cache = CacheList(KVCache(), KCache())
d = 7168
l = 32768

# Prefill a long context so decode-time attention has something to gather from
x = mx.random.normal(shape=(1, l, d)).astype(mx.bfloat16)
attn(x, cache=cache)
mx.eval(cache.state)

q = mx.random.normal(shape=(1, 1, d)).astype(mx.bfloat16)

def fn(x, cache):
    c = copy.deepcopy(cache)
    for _ in range(10):
        x = attn(x, cache=c)
    return x

# warmup
for _ in range(10):
    mx.eval(fn(q, cache))

tic = time.time()
for _ in range(10):
    mx.eval(fn(q, cache))
toc = time.time()
print(toc - tic)
```
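For context on the gather step being benchmarked, here is a rough numpy stand-in for the sparse path (hypothetical function and shapes, not the mlx code in this PR): each query selects a subset of key positions, and the keys/values are gathered before attention runs.

```python
import numpy as np

def sparse_attend(q, keys, values, idx):
    """Attend over only a gathered subset of key positions.

    q: (d,) query; keys/values: (L, d); idx: selected positions.
    Illustrative only -- the real implementation uses mx.* ops.
    """
    k_sel = keys[idx]                       # the top-k gather being paid for
    v_sel = values[idx]
    w = (q @ k_sel.T) / np.sqrt(q.shape[-1])
    w = np.exp(w - w.max())                 # numerically stable softmax
    w = w / w.sum()
    return w @ v_sel
```

With `idx` covering every position this reduces to dense attention, which is the baseline the benchmark above is compared against.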
```python
class KCache(_BaseCache):
```
Maybe it makes sense to make KVCache general to an arbitrary number of arrays?
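One way to read that suggestion, sketched as a plain-Python toy (hypothetical class name and API, not the actual mlx_lm code): a single cache class could hold any number of parallel streams, so two streams behave like a KV cache and one stream like a K-only cache.

```python
class NCache:
    """Toy cache generalizing the KV idea to n parallel streams.

    Each call appends one chunk per stream and returns the full
    history of every stream, mimicking update-and-fetch semantics.
    """

    def __init__(self, n):
        self.streams = [[] for _ in range(n)]

    def update_and_fetch(self, *chunks):
        assert len(chunks) == len(self.streams)
        for stream, chunk in zip(self.streams, chunks):
            stream.extend(chunk)
        # Return a copy of each stream's accumulated history
        return [list(stream) for stream in self.streams]
```

An `NCache(2)` stands in for a KV cache, `NCache(1)` for a K-only cache, avoiding a separate class per arity.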
I've been doing some testing on this, and the current implementation quickly breaks down if you ask the model to summarize a large function (e.g. 4k tokens long, embedded in the prompt), resulting in incoherent and/or repetitive output. After looking through other implementations, including deepseek-ai/DeepSeek-V3.2-Exp, I was eventually able to fix the issue (a missing sum over heads). But you probably want to review these changes first, so I've just provided the diff below.
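To illustrate what a missing sum over heads does to this kind of indexer, here is a numpy sketch with assumed shapes (not the actual fix): the top-k token selection should be computed from scores pooled across heads, so all heads share one token set; ranking per head instead gives each head a different, often worse, selection.

```python
import numpy as np

def select_tokens(scores, k):
    """Pick k token positions from per-head indexer scores.

    scores: (n_heads, seq_len). Summing over heads first yields one
    shared selection; omitting the sum would rank tokens per head.
    """
    pooled = scores.sum(axis=0)                    # (seq_len,)
    idx = np.argpartition(-pooled, k - 1)[:k]      # unordered top-k
    return np.sort(idx)                            # keep sequence order
```

This is only a shape-level illustration of the bug class; the in-tree fix operates on the model's actual indexer tensors.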
@kernelpool thanks for the fix. That's awesome! Do you mind sending a PR? I have some questions/comments but they are better made in-line. |
* Fix sparse token selection in deepseek v3.2
* Fix 4D mask input handling and remove unnecessary ones array
Working generation: