Skip to content

Conversation

@awni
Copy link
Member

@awni awni commented Sep 29, 2025

Working generation:

mlx_lm.generate --model mlx_model -p "-" -m 128 < /Volumes/mlr_share/awni/mlx-lm/prompt.txt
==========
Of course. The passage you provided is the opening of Jane Austen's *Pride and Prejudice*, specifically Chapters I, II, and the beginning of Chapter III. Here is a summary of the key events:

**Chapter I:**
- The novel opens with the famous line about a wealthy single man being in want of a wife.
- Mrs. Bennet tells her husband that Netherfield Park has been let to a wealthy young man, Mr. Bingley. She is excited by the prospect of him marrying one of their daughters.
- Mr. Bennet is sarcastic and teasing, while Mrs. Bennet is excitable
==========
Prompt: 2878 tokens, 167.263 tokens-per-sec
Generation: 128 tokens, 7.316 tokens-per-sec
Peak memory: 397.827 GB

@awni
Copy link
Member Author

awni commented Oct 4, 2025

This is functional... but unfortunately quite slow. You pay a penalty for the topk gather for shorter context. And from some micro benchmarks, I think it would take fairly long context for this to be more efficient than the naive attention.

We can maybe improve on the current implementation with a gather_mm for the q @ k[idx].T matmul. But I don't think we have an op for the scores @ v[idx] to do it in one go.

Here's the micro benchmark for reference:

import copy
import mlx.core as mx
from mlx_lm import load
from mlx_lm.models.cache import CacheList, KCache, KVCache

import time

model, _ = load("mlx_model", lazy=True)

attn = model.layers[0].self_attn
#attn.indexer.index_topk = float("inf")

cache = CacheList(KVCache(), KCache())

d = 7168
l = 32768
x = mx.random.normal(shape=(1, l, d)).astype(mx.bfloat16)

attn(x, cache=cache)
mx.eval(cache.state)

q = mx.random.normal(shape=(1, 1, d)).astype(mx.bfloat16)

def fn(x, cache):
    c = copy.deepcopy(cache)
    for _ in range(10):
        x = attn(x, cache=c)
    return x

# warmup
for _ in range(10):
    mx.eval(fn(q, cache))

tic = time.time()
for _ in range(10):
    mx.eval(fn(q, cache))
toc = time.time()
print(toc - tic)

c.extend(other)


class KCache(_BaseCache):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it makes sense to make KVCache general to an arbitrary number of arrays?

@kernelpool
Copy link
Contributor

kernelpool commented Oct 6, 2025

I've been doing some testing on this and the current implementation quickly breaks down if you ask the model to summarize a large function (e.g. 4k tokens long, embedded in the prompt), resulting in incoherent and/or repetitive output.

After looking through other implementations, including deepseek-ai/DeepSeek-V3.2-Exp, I was able to eventually fix the issue (missing sum over heads). But you probably want to review these first, so I've just provided the diff below.

diff --git a/mlx_lm/models/deepseek_v32.py b/mlx_lm/models/deepseek_v32.py
index f811253..1ef421c 100644
--- a/mlx_lm/models/deepseek_v32.py
+++ b/mlx_lm/models/deepseek_v32.py
@@ -101,10 +101,23 @@ class Indexer(nn.Module):
             k = cache.update_and_fetch(k)
         if k.shape[2] <= self.index_topk:
             return None
+
+        # Compute Q @ K^T
+        scores = q @ k.swapaxes(-1, -2)
+
+        # Apply ReLU activation (critical for correct sparse attention)
+        scores = mx.maximum(scores, 0)
+
+        # Compute per-head weights and apply after matmul
         weights = self.weights_proj(x) * (self.n_heads**-0.5)
         weights = (weights * self.softmax_scale).swapaxes(-1, -2)[..., None]
-        q_scaled = q * weights
-        scores = (q * weights) @ k.swapaxes(-1, -2)
+        scores = scores * weights
+
+        # Sum over heads (matches reference kernel.py:244)
+        # Output shape: o: T.Tensor[(b, m, n), FP32] - no heads dimension
+        scores = scores.sum(axis=1)  # (b, s, k_seq)
+
+        # Apply mask and select top-k indices
         if mask is not None:
             scores = mx.where(mask, scores, -float("inf"))
         return mx.argpartition(scores, kth=-self.index_topk, axis=-1)[
@@ -188,6 +201,8 @@ class DeepseekV32Attention(nn.Module):
 
         if self.q_lora_rank is None:
             q = self.q_proj(x)
+            # Note: DeepSeek v3.2 always uses LoRA, but handle this path defensively
+            qr = x  # Fallback for indexer (shouldn't be needed in practice)
         else:
             qr = self.q_a_layernorm(self.q_a_proj(x))
             q = self.q_b_proj(qr)
@@ -217,26 +232,47 @@ class DeepseekV32Attention(nn.Module):
             keys = mx.concatenate([k_nope, k_pe], axis=-1)
 
         queries = mx.concatenate([q_nope, q_pe], axis=-1)
-        topk_indices = self.indexer(x, qr, mask, cache=cache[1])
+        # Indexer needs 3D mask (B, L, L) or None
+        # create_attention_mask returns: None, "causal", or 2D (L, L)
+        # But handle 4D (B, num_heads, L, L) for tests/edge cases
+        if mask is None or isinstance(mask, str):
+            indexer_mask = None
+        elif mask.ndim == 2:
+            # 2D (L, L) -> 3D (B, L, L)
+            indexer_mask = mx.broadcast_to(mask, (B, *mask.shape))
+        elif mask.ndim == 4:
+            # 4D (B, num_heads, L, L) -> 3D (B, L, L)
+            indexer_mask = mask[:, 0, :, :]
+        else:
+            # Already 3D (B, L, L)
+            indexer_mask = mask
+
+        topk_indices = self.indexer(x, qr, indexer_mask, cache=cache[1])
         if topk_indices is not None:
-            repeats = self.num_heads // self.config.index_n_heads
-            if L == 1:
-                topk_indices = mx.repeat(topk_indices, repeats, axis=1).squeeze(-2)[
-                    ..., None
-                ]
-                keys = mx.take_along_axis(keys, topk_indices, axis=-2)
-                values = mx.take_along_axis(values, topk_indices, axis=-2)
-            else:
-                topk_mask = mx.zeros(
-                    (B, self.config.index_n_heads, *mask.shape[-2:]), mx.bool_
-                )
-                topk_mask = mx.put_along_axis(
-                    topk_mask, topk_indices, mx.array(True), axis=-1
-                )
-                mask = mask & topk_mask
-                mask = mx.repeat(mask, repeats, axis=1)
+            # topk_indices shape: (B, L, topk) - no heads dimension after sum
+            # Reference: scatter topk_indices into -inf mask, then add to existing mask
+            # Create sparse mask (works for both decode and prefill)
+            k_seq = keys.shape[2]
+            # Create boolean sparse mask: True at topk positions
+            sparse_mask = mx.zeros((B, L, k_seq), dtype=mx.bool_)
+            ones = mx.ones(topk_indices.shape, dtype=mx.bool_)
+            sparse_mask = mx.put_along_axis(sparse_mask, topk_indices, ones, axis=-1)
+
+            # Combine with causal mask using AND
+            if mask is not None and not isinstance(mask, str):
+                # Convert to (B, L, k_seq)
+                if mask.ndim == 2:
+                    causal_mask = mx.broadcast_to(mask, sparse_mask.shape)
+                elif mask.ndim == 4:
+                    causal_mask = mask[:, 0, :, :]
+                else:
+                    causal_mask = mask
+                sparse_mask = sparse_mask & causal_mask
+
+            # Broadcast to (B, num_heads, L, k_seq)
+            mask = sparse_mask[:, None, :, :]
         output = scaled_dot_product_attention(
-            queries, keys, values, cache=cache, scale=self.scale, mask=mask
+            queries, keys, values, cache=cache[0], scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output)

@awni
Copy link
Member Author

awni commented Oct 6, 2025

@kernelpool thanks for the fix. That's awesome! Do you mind sending a PR? I have some questions/comments but they are better made in-line.

* Fix sparse token selection in deepseek v3.2

* Fix 4D mask input handling and remove unnecessary ones array
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants