Skip to content

Commit a7f534c

Browse files
ivanfioravanti, Goekdeniz-Guelmez, and awni
authored
Gated-Delta Fused Kernel (Qwen3Next) (#454)
* apply gating in recurrent_gated_delta_rule
* update cache with new state
* preallocate outputs in recurrent_gated_delta_rule
* feat(kernel): gated-delta kernel scaffolding with CPU fallbacks and tests; integrate in Qwen3Next behind flag
* feat(kernel): implement Metal kernel for gated delta prefill with time iteration to optimize performance
* faster single time step kernel
* use kernel for prefill
* version bump

Co-authored-by: Goekdeniz-Guelmez <[email protected]>
Co-authored-by: Awni Hannun <[email protected]>
1 parent 04d6d92 commit a7f534c

File tree

5 files changed

+253
-69
lines changed

5 files changed

+253
-69
lines changed

mlx_lm/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# Copyright © 2023-2025 Apple Inc.
22

3-
__version__ = "0.27.1"
3+
__version__ = "0.28.0"

mlx_lm/models/gated_delta.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
from functools import partial
2+
from typing import Optional, Tuple
3+
4+
import mlx.core as mx
5+
import mlx.nn as nn
6+
7+
8+
@partial(mx.compile, shapeless=True)
def compute_g(A_log, a, dt_bias):
    """Compute the per-step decay gate: exp(-exp(A_log) * softplus(a + dt_bias))."""
    decay_rate = mx.exp(A_log.astype(mx.float32))
    dt = nn.softplus(a + dt_bias).astype(A_log.dtype)
    return mx.exp(-decay_rate * dt)
13+
14+
15+
def _make_gated_delta_kernel():
    """Build the fused Metal kernel for the gated delta recurrence.

    Returns None when Metal is unavailable so the caller can fall back to the
    ops-based implementation. The kernel iterates over the T time steps
    on-device, keeping each thread's slice of the recurrent state in registers
    (one simdgroup of 32 threads cooperates per (head, Dv) output element, so
    Dk must be a multiple of 32 — see n_per_t).
    """
    if not mx.metal.is_available():
        return None
    # NOTE: this is a runtime string compiled by mx.fast.metal_kernel; the
    # template parameters (InT, Dk, Dv, Hk, Hv) are substituted at call time.
    source = """
        auto n = thread_position_in_grid.z;
        auto b_idx = n / Hv;
        auto hv_idx = n % Hv;
        auto hk_idx = hv_idx / (Hv / Hk);
        constexpr int n_per_t = Dk / 32;

        // q, k: [B, T, Hk, Dk]
        auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk;
        auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk;

        // v, y: [B, T, Hv, Dv]
        auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv;
        y += b_idx * T * Hv * Dv + hv_idx * Dv;

        auto dk_idx = thread_position_in_threadgroup.x;
        auto dv_idx = thread_position_in_grid.y;

        // state_in, state_out: [B, Hv, Dv, Dk]
        auto i_state = state_in + (n * Dv + dv_idx) * Dk;
        auto o_state = state_out + (n * Dv + dv_idx) * Dk;

        float state[n_per_t];
        for (int i = 0; i < n_per_t; ++i) {
          auto s_idx = n_per_t * dk_idx + i;
          state[i] = static_cast<float>(i_state[s_idx]);
        }

        // beta, g: [B, T, Hv]
        auto g_ = g + b_idx * T * Hv;
        auto beta_ = beta + b_idx * T * Hv;

        for (int t = 0; t < T; ++t) {
          float kv_mem = 0.0f;
          for (int i = 0; i < n_per_t; ++i) {
            auto s_idx = n_per_t * dk_idx + i;
            state[i] = state[i] * g_[hv_idx];
            kv_mem += state[i] * k_[s_idx];
          }
          kv_mem = simd_sum(kv_mem);

          auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx];

          float out = 0.0f;
          for (int i = 0; i < n_per_t; ++i) {
            auto s_idx = n_per_t * dk_idx + i;
            state[i] = state[i] + k_[s_idx] * delta;
            out += state[i] * q_[s_idx];
          }
          out = simd_sum(out);
          if (thread_index_in_simdgroup == 0) {
            y[dv_idx] = static_cast<InT>(out);
          }
          // Increment data pointers to next time step
          q_ += Hk * Dk;
          k_ += Hk * Dk;
          v_ += Hv * Dv;
          y += Hv * Dv;
          g_ += Hv;
          beta_ += Hv;
        }
        for (int i = 0; i < n_per_t; ++i) {
          auto s_idx = n_per_t * dk_idx + i;
          o_state[s_idx] = static_cast<InT>(state[i]);
        }
    """
    return mx.fast.metal_kernel(
        name="gated_delta_step",
        input_names=["q", "k", "v", "g", "beta", "state_in", "T"],
        output_names=["y", "state_out"],
    )
90+
91+
92+
_gated_delta_kernel = _make_gated_delta_kernel()
93+
94+
95+
def _gated_delta_step_ops(
    q: mx.array,
    k: mx.array,
    v: mx.array,
    g: mx.array,
    beta: mx.array,
    state: mx.array,
) -> Tuple[mx.array, mx.array]:
    """
    Ops-based reference implementation for a single recurrent step.

    Shapes:
    - q, k: [B, H, Dk]
    - v: [B, H, Dv]
    - g, beta: [B, H]
    - state: [B, H, Dv, Dk]
    Returns:
    - y: [B, H, Dv]
    - new_state: [B, H, Dv, Dk]
    """
    # Broadcast the key over the value dimension: [B, H, 1, Dk].
    k_row = k[..., None, :]

    # Apply the per-head decay gate to the running state.
    decayed = state * g[..., None, None]

    # What the decayed state currently recalls for this key.
    recalled = (decayed * k_row).sum(axis=-1)  # [B, H, Dv]

    # Delta-rule correction toward the target value, scaled by beta.
    correction = (v - recalled) * beta[..., None]  # [B, H, Dv]
    new_state = decayed + k_row * correction[..., None]

    # Read out along the key dimension with the query.
    y = (new_state * q[..., None, :]).sum(axis=-1)  # [B, H, Dv]
    return y, new_state
124+
125+
126+
def gated_delta_kernel(
    q: mx.array,
    k: mx.array,
    v: mx.array,
    g: mx.array,
    beta: mx.array,
    state: mx.array,
) -> Tuple[mx.array, mx.array]:
    """
    Run the fused Metal kernel over a full sequence.

    Shapes:
    - q, k: [B, T, Hk, Dk]
    - v: [B, T, Hv, Dv]
    - g, beta: [B, T, Hv]
    - state: [B, Hv, Dv, Dk]
    Returns:
    - y: [B, T, Hv, Dv]
    - state: [B, Hv, Dv, Dk]

    NOTE(review): the kernel computes Dk / 32 elements per thread, so this
    presumably requires Dk to be a multiple of 32 — confirm for new models.
    """
    B, T, Hk, Dk = k.shape
    Hv, Dv = v.shape[2:]
    input_type = q.dtype
    return _gated_delta_kernel(
        inputs=[q, k, v, g, beta, state, T],
        template=[
            ("InT", input_type),
            ("Dk", Dk),
            ("Dv", Dv),
            ("Hk", Hk),
            ("Hv", Hv),
        ],
        # One simdgroup (32 threads) per (value-dim, batch*head) element;
        # the z dimension enumerates B * Hv recurrent states.
        grid=(32, Dv, B * Hv),
        threadgroup=(32, 4, 1),
        output_shapes=[(B, T, Hv, Dv), state.shape],
        output_dtypes=[input_type, input_type],
    )
151+
152+
153+
def gated_delta_ops(
    q: mx.array,
    k: mx.array,
    v: mx.array,
    g: mx.array,
    beta: mx.array,
    state: Optional[mx.array] = None,
) -> Tuple[mx.array, mx.array]:
    """
    Ops-based reference implementation for prompt prefill (sequential loop).

    Shapes:
    - q, k: [B, T, Hk, Dk]
    - v: [B, T, Hv, Dv]
    - g, beta: [B, T, Hv]
    - state: [B, Hv, Dv, Dk]  (allocated as zeros when None)
    Returns:
    - y: [B, T, Hv, Dv]
    - state: [B, Hv, Dv, Dk]
    """
    B, T, Hk, Dk = q.shape
    Hv, Dv = v.shape[-2:]
    if state is None:
        state = mx.zeros((B, Hv, Dv, Dk), dtype=q.dtype)

    # GQA-style head sharing: replicate each query/key head across the
    # value heads it serves so the per-step math is per-value-head.
    if (repeat_factor := Hv // Hk) > 1:
        q = mx.repeat(q, repeat_factor, -2)
        k = mx.repeat(k, repeat_factor, -2)

    # The recurrence is inherently sequential in time.
    ys = []
    for t in range(T):
        y, state = _gated_delta_step_ops(
            q[:, t],
            k[:, t],
            v[:, t],
            g[:, t],
            beta[:, t],
            state,
        )
        ys.append(y)
    y = mx.stack(ys, axis=1)
    return y, state
195+
196+
197+
def gated_delta_update(
    q: mx.array,
    k: mx.array,
    v: mx.array,
    a: mx.array,
    b: mx.array,
    A_log: mx.array,
    dt_bias: mx.array,
    state: Optional[mx.array] = None,
) -> Tuple[mx.array, mx.array]:
    """
    Apply the gated delta rule over a sequence, dispatching to the fused
    Metal kernel on GPU and to the ops-based implementation otherwise.

    Shapes:
    - q, k: [B, T, Hk, Dk]
    - v: [B, T, Hv, Dv]
    - a, b: [B, T, Hv]
    - state: [B, Hv, Dv, Dk]  (allocated as zeros when None)
    Returns:
    - y: [B, T, Hv, Dv]
    - state: [B, Hv, Dv, Dk]
    """
    beta = mx.sigmoid(b)
    g = compute_g(A_log, a, dt_bias)
    # Fix: the original had a redundant nested `if state is None:` inside
    # this branch; one check suffices.
    if state is None:
        B, _, _, Dk = q.shape
        Hv, Dv = v.shape[-2:]
        state = mx.zeros((B, Hv, Dv, Dk), dtype=q.dtype)

    if mx.default_device() != mx.gpu or not mx.metal.is_available():
        return gated_delta_ops(q, k, v, g, beta, state)
    else:
        return gated_delta_kernel(q, k, v, g, beta, state)

mlx_lm/models/qwen3_next.py

Lines changed: 7 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
# Copyright © 2025 Apple Inc.
22

33
from dataclasses import dataclass
4-
from functools import partial
54
from typing import Any, Dict, List, Optional, Tuple, Union
65

76
import mlx.core as mx
87
import mlx.nn as nn
98

109
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
1110
from .cache import KVCache, MambaCache
11+
from .gated_delta import gated_delta_update
1212
from .rope_utils import initialize_rope
1313
from .switch_layers import SwitchGLU
1414

@@ -45,52 +45,6 @@ class ModelArgs(BaseModelArgs):
4545
full_attention_interval: int = 4
4646

4747

48-
@partial(mx.compile, shapeless=True)
49-
def compute_g(A_log, a, dt_bias):
50-
return mx.exp(-mx.exp(A_log.astype(mx.float32)) * nn.softplus(a + dt_bias)).astype(
51-
A_log.dtype
52-
)
53-
54-
55-
def recurrent_gated_delta_rule(
56-
query: mx.array,
57-
key: mx.array,
58-
value: mx.array,
59-
a: mx.array,
60-
b: mx.array,
61-
A_log: mx.array,
62-
dt_bias: mx.array,
63-
state: mx.array,
64-
use_qk_l2norm_in_kernel: bool = False,
65-
) -> Tuple[mx.array, mx.array]:
66-
B, S, Hk, Dk = key.shape
67-
Hv, Dv = value.shape[2:]
68-
inv_scale = Dk**-0.5
69-
70-
if use_qk_l2norm_in_kernel:
71-
query = (inv_scale**2) * mx.fast.rms_norm(query, None, 1e-6)
72-
key = inv_scale * mx.fast.rms_norm(key, None, 1e-6)
73-
else:
74-
query = inv_scale * query
75-
76-
input_type = query.dtype
77-
if (repeat_factor := Hv // Hk) > 1:
78-
query = mx.repeat(query, repeat_factor, 2)
79-
key = mx.repeat(key, repeat_factor, 2)
80-
81-
beta = mx.sigmoid(b)
82-
g = compute_g(A_log, a, dt_bias)
83-
84-
out = mx.zeros((B, S, Hv, Dv), dtype=input_type)
85-
for i in range(S):
86-
state *= g[:, i, :, None, None]
87-
kv_mem = (state * key[:, i, :, :, None]).sum(axis=-2)
88-
delta = (value[:, i] - kv_mem) * beta[:, i, :, None]
89-
state += key[:, i, :, :, None] * delta[..., None, :]
90-
out[:, i] = (state * query[:, i, :, :, None]).sum(axis=-2)
91-
return out, state
92-
93-
9448
class Qwen3NextRMSNormGated(nn.Module):
9549
def __init__(self, hidden_size: int, eps: float = 1e-6):
9650
super().__init__()
@@ -297,25 +251,14 @@ def __call__(
297251
)
298252
]
299253

300-
if cache is not None and cache[1] is not None:
254+
if cache is not None:
301255
state = cache[1]
302-
else:
303-
state = mx.zeros(
304-
(B, self.num_v_heads, self.head_k_dim, self.head_v_dim),
305-
dtype=inputs.dtype,
306-
)
307256

308-
out, state = recurrent_gated_delta_rule(
309-
q,
310-
k,
311-
v,
312-
a,
313-
b,
314-
self.A_log,
315-
self.dt_bias,
316-
state,
317-
use_qk_l2norm_in_kernel=True,
318-
)
257+
inv_scale = k.shape[-1] ** -0.5
258+
q = (inv_scale**2) * mx.fast.rms_norm(q, None, 1e-6)
259+
k = inv_scale * mx.fast.rms_norm(k, None, 1e-6)
260+
261+
out, state = gated_delta_update(q, k, v, a, b, self.A_log, self.dt_bias, state)
319262

320263
if cache is not None:
321264
cache[1] = state

mlx_lm/models/ssm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,18 @@ def make_ssm_kernel():
3939
float acc = 0.0;
4040
auto x_ = static_cast<float>(x[d_idx]);
4141
42-
for (int i = 0; i < n_per_t; ++i) {{
42+
for (int i = 0; i < n_per_t; ++i) {
4343
auto s_idx = n_per_t * ds_idx + i;
4444
auto idx = d_idx * Ds + s_idx;
4545
auto dB_by_x = x_ * dt_ * static_cast<float>(B_[s_idx]);
4646
auto state = dA * i_state[idx] + dB_by_x;
4747
o_state[idx] = static_cast<T>(state);
4848
acc += state * C_[s_idx];
49-
}}
49+
}
5050
acc = simd_sum(acc);
51-
if (thread_index_in_simdgroup == 0) {{
51+
if (thread_index_in_simdgroup == 0) {
5252
out[d_idx] = static_cast<T>(acc + x_ * D[h_idx]);
53-
}}
53+
}
5454
"""
5555
return mx.fast.metal_kernel(
5656
name="ssm_kernel",

tests/test_models.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from mlx_lm.models import rope_utils
1111
from mlx_lm.models.base import create_causal_mask, scaled_dot_product_attention
1212
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
13+
from mlx_lm.models.gated_delta import gated_delta_kernel, gated_delta_ops
1314
from mlx_lm.models.ssm import ssm_attn, ssm_update
1415

1516

@@ -1847,6 +1848,27 @@ def test_ssm_masked(self):
18471848
self.assertTrue(mx.allclose(out, out_m, atol=1e-4, rtol=1e-4))
18481849
self.assertTrue(mx.allclose(out_state, out_state_m, atol=1e-4, rtol=1e-4))
18491850

1851+
def test_gated_delta(self):
1852+
for B in [1, 2]:
1853+
for T in [1, 2]:
1854+
B = 1
1855+
Hk = 16
1856+
Hv = 32
1857+
Dk = 128
1858+
Dv = 128
1859+
1860+
q = mx.random.normal(shape=(B, T, Hk, Dk))
1861+
k = mx.random.normal(shape=(B, T, Hk, Dk))
1862+
v = mx.random.normal(shape=(B, T, Hv, Dv))
1863+
g = mx.random.normal(shape=(B, T, Hv))
1864+
beta = mx.random.normal(shape=(B, T, Hv))
1865+
state = mx.random.normal(shape=(B, Hv, Dk, Dv))
1866+
1867+
y_op, st_op = gated_delta_ops(q, k, v, g, beta, state)
1868+
y_c, st_c = gated_delta_kernel(q, k, v, g, beta, state)
1869+
self.assertTrue(mx.allclose(y_op, y_c, rtol=1e-4, atol=1e-4))
1870+
self.assertTrue(mx.allclose(st_op, st_c, rtol=1e-4, atol=1e-3))
1871+
18501872

18511873
if __name__ == "__main__":
18521874
unittest.main()

0 commit comments

Comments
 (0)