From cc10c81fb2d474302709bfd7df64325a0892ad25 Mon Sep 17 00:00:00 2001 From: "clandestine.eth" <96172957+0xClandestine@users.noreply.github.com> Date: Thu, 30 Apr 2026 02:58:33 -0400 Subject: [PATCH 01/11] feat: add flash_qla_mlx for Apple Silicon --- flash_qla_mlx/__init__.py | 16 + flash_qla_mlx/ops/__init__.py | 6 + .../ops/gated_delta_rule/__init__.py | 14 + flash_qla_mlx/ops/gated_delta_rule/chunk.py | 822 ++++++++++++++++++ flash_qla_mlx/utils/__init__.py | 20 + flash_qla_mlx/utils/math.py | 9 + flash_qla_mlx/utils/pack.py | 147 ++++ setup.py | 3 + 8 files changed, 1037 insertions(+) create mode 100644 flash_qla_mlx/__init__.py create mode 100644 flash_qla_mlx/ops/__init__.py create mode 100644 flash_qla_mlx/ops/gated_delta_rule/__init__.py create mode 100644 flash_qla_mlx/ops/gated_delta_rule/chunk.py create mode 100644 flash_qla_mlx/utils/__init__.py create mode 100644 flash_qla_mlx/utils/math.py create mode 100644 flash_qla_mlx/utils/pack.py diff --git a/flash_qla_mlx/__init__.py b/flash_qla_mlx/__init__.py new file mode 100644 index 0000000..b9ab370 --- /dev/null +++ b/flash_qla_mlx/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +__version__ = "0.1.0" + +from flash_qla_mlx.ops.gated_delta_rule.chunk import ( + chunk_gated_delta_rule_fwd, + chunk_gated_delta_rule_bwd, + chunk_gated_delta_rule, +) + +__all__ = [ + "chunk_gated_delta_rule_fwd", + "chunk_gated_delta_rule_bwd", + "chunk_gated_delta_rule", +] diff --git a/flash_qla_mlx/ops/__init__.py b/flash_qla_mlx/ops/__init__.py new file mode 100644 index 0000000..81120d5 --- /dev/null +++ b/flash_qla_mlx/ops/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +from .gated_delta_rule import chunk_gated_delta_rule + +__all__ = ["chunk_gated_delta_rule"] diff --git a/flash_qla_mlx/ops/gated_delta_rule/__init__.py b/flash_qla_mlx/ops/gated_delta_rule/__init__.py new file mode 100644 index 0000000..48f8668 --- /dev/null +++ b/flash_qla_mlx/ops/gated_delta_rule/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +from .chunk import ( + chunk_gated_delta_rule_fwd, + chunk_gated_delta_rule_bwd, + chunk_gated_delta_rule, +) + +__all__ = [ + "chunk_gated_delta_rule_fwd", + "chunk_gated_delta_rule_bwd", + "chunk_gated_delta_rule", +] diff --git a/flash_qla_mlx/ops/gated_delta_rule/chunk.py b/flash_qla_mlx/ops/gated_delta_rule/chunk.py new file mode 100644 index 0000000..61586f9 --- /dev/null +++ b/flash_qla_mlx/ops/gated_delta_rule/chunk.py @@ -0,0 +1,822 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +import mlx.core as mx + +from flash_qla_mlx.utils import ( + pack, + unpack, + pad_and_reshape, + fill_last_chunk_of_g, + prepare_chunk_offsets, + l2norm, +) + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _chunk_local_cumsum( + g: mx.array, + chunk_size: int = 64, + cu_seqlens: mx.array = None, + reverse: bool = False, +) -> mx.array: + if cu_seqlens is not None: + g = unpack(g, cu_seqlens) + + batch_size, num_tokens, num_heads = g.shape + g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, H] + + if reverse: + g = mx.flip(g, axis=2) + g = mx.cumsum(g, axis=2) + g = mx.flip(g, axis=2) + else: + g = mx.cumsum(g, axis=2) + + g = g.reshape(batch_size, -1, num_heads)[:, :num_tokens] + + if cu_seqlens is not None: + g = pack(g, cu_seqlens) + return g + + +def _kkt_fwd( + k: mx.array, + g: mx.array, + beta: mx.array, + cu_seqlens: mx.array = None, + chunk_size: int = 64, +) -> mx.array: + if cu_seqlens is not None: + k = unpack(k, cu_seqlens) + g = unpack(g, cu_seqlens) + beta = unpack(beta, cu_seqlens) + + batch_size, num_tokens, num_k_heads, head_dim = k.shape + num_v_heads = g.shape[-1] + + if num_k_heads != num_v_heads: + k = mx.repeat(k, num_v_heads // num_k_heads, axis=2) + + k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, H, K] + g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, H] + beta = pad_and_reshape(beta, dim=1, chunk_size=chunk_size) # [B, N, C, H] + + mask = mx.triu(mx.ones((chunk_size, chunk_size), dtype=mx.bool_), k=0) + decay_mask = mx.exp(g[:, :, :, None, :] - g[:, :, None, :, :]) # [B, N, C, C, H] + decay_mask = mx.where(mask[None, None, :, :, None], mx.zeros_like(decay_mask), decay_mask) + + # attn[b, n, c, h, d] = (k_beta[b,n,c,h,:] . k[b,n,d,h,:]) * decay[b,n,c,d,h] + attn = mx.einsum( + "bnchk,bndhk->bnchd", k * beta[:, :, :, :, None], k + ) * mx.swapaxes(decay_mask, -2, -1) + attn = attn.reshape(batch_size, -1, num_v_heads, chunk_size)[:, :num_tokens] + + if cu_seqlens is not None: + attn = pack(attn, cu_seqlens) + return attn + + +def _kkt_solve( + x: mx.array, + cu_seqlens: mx.array = None, + chunk_size: int = 64, +) -> mx.array: + if cu_seqlens is not None: + x = unpack(x, cu_seqlens) + + batch_size, num_tokens, num_heads, _ = x.shape + + # x: [B, T, H, D] -> [B, N, H, C, D] (negated, lower-tri solve) + x = -pad_and_reshape(x, dim=1, chunk_size=chunk_size) # [B, N, C, H, D] + x = mx.swapaxes(x, 2, 3) # [B, N, H, C, D] + + for i in range(1, chunk_size): + row = x[..., i, :i] # [B, N, H, i] + sub = x[..., :i, :i] # [B, N, H, i, i] + new_val = row + (row[..., None] * sub).sum(axis=-2) + x[..., i, :i] = new_val + + x = x + mx.eye(chunk_size, dtype=x.dtype) + x = mx.swapaxes(x, 2, 3) # [B, N, C, H, D] + x = x.reshape(batch_size, -1, num_heads, chunk_size)[:, :num_tokens] + + if cu_seqlens is not None: + x = pack(x, cu_seqlens) + return x + + +def _kkt( + k: mx.array, + beta: mx.array, + g: mx.array, + cu_seqlens: mx.array = None, + chunk_size: int = 64, +) -> mx.array: + """Compute A = (I - L)^{-1} where L encodes the gated KKT system.""" + A = _kkt_fwd(k=k, g=g, beta=beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size) + A = _kkt_solve(x=A, cu_seqlens=cu_seqlens, chunk_size=chunk_size) + return A + + +def _w_u_fwd( + k: mx.array, + v: mx.array, + beta: mx.array, + A: mx.array, + g: mx.array, + cu_seqlens: mx.array = None, +) -> tuple: + if cu_seqlens is not None: + k = unpack(k, cu_seqlens) + v = unpack(v, cu_seqlens) + A = unpack(A, cu_seqlens) + beta = unpack(beta, cu_seqlens) + g = unpack(g, cu_seqlens) + + batch_size, num_tokens, _, chunk_size = A.shape + _, _, num_k_heads, head_dim_k = k.shape + _, _, num_v_heads, head_dim_v = v.shape + + if num_k_heads != num_v_heads: + k = mx.repeat(k, num_v_heads // num_k_heads, axis=2) + + k_beta = pad_and_reshape( + k * (beta * mx.exp(g))[..., None], dim=1, chunk_size=chunk_size + ) # [B, N, C, Hv, K] + v_beta = pad_and_reshape( + v * beta[..., None], dim=1, chunk_size=chunk_size + ) # [B, N, C, Hv, V] + A = pad_and_reshape(A, dim=1) # [B, N, C, Hv, D] + + w = mx.einsum("bnchd,bndhk->bnchk", A, k_beta).reshape( + batch_size, -1, num_v_heads, head_dim_k + )[:, :num_tokens] + u = mx.einsum("bnchd,bndhk->bnchk", A, v_beta).reshape( + batch_size, -1, num_v_heads, head_dim_v + )[:, :num_tokens] + + if cu_seqlens is not None: + w = pack(w, cu_seqlens) + u = pack(u, cu_seqlens) + return w, u + + +def _chunk_gdr_fwd( + k: mx.array, + w: mx.array, + u: mx.array, + g: mx.array, + initial_state: mx.array = None, + cu_seqlens: mx.array = None, + chunk_size: int = 64, +) -> tuple: + if cu_seqlens is not None: + k = unpack(k, cu_seqlens) + w = unpack(w, cu_seqlens) + u = unpack(u, cu_seqlens) + g = unpack(g, cu_seqlens) + + batch_size, num_tokens, num_k_heads, head_dim_k = k.shape + _, _, num_v_heads, head_dim_v = u.shape + + if num_k_heads != num_v_heads: + k = mx.repeat(k, num_v_heads // num_k_heads, axis=2) + + k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K] + w = pad_and_reshape(w, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K] + u = pad_and_reshape(u, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V] + g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, Hv] + g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size) + + if initial_state is None: + last_state = mx.zeros( + (batch_size, num_v_heads, head_dim_k, head_dim_v), dtype=g.dtype + ) + else: + last_state = initial_state.astype(g.dtype) + + h_list, vn_list = [], [] + for i in range(k.shape[1]): + h_list.append(last_state) + v_new = u[:, i] - mx.einsum("bchk,bhkv->bchv", w[:, i], last_state) + vn_list.append(v_new) + last_state = last_state * mx.exp(g[:, i, -1, :])[:, :, None, None] + last_state = last_state + mx.einsum( + "bchk,bchv->bhkv", + k[:, i] * mx.exp(g[:, i, -1:, :, None] - g[:, i, :, :, None]), + v_new, + ) + + h = mx.stack(h_list, axis=1) + vn = ( + mx.stack(vn_list, axis=1) + .reshape(batch_size, -1, num_v_heads, head_dim_v)[:, :num_tokens] + ) + + if cu_seqlens is not None: + chunk_offsets = prepare_chunk_offsets(cu_seqlens, chunk_size) + vn = pack(vn, cu_seqlens) + h = pack(h, chunk_offsets) + + return h, vn, last_state + + +def _chunk_o_fwd( + q: mx.array, + k: mx.array, + v: mx.array, + h: mx.array, + g: mx.array, + cu_seqlens: mx.array = None, + scale: float = None, + chunk_size: int = 64, +) -> mx.array: + if cu_seqlens is not None: + chunk_offsets = prepare_chunk_offsets(cu_seqlens, chunk_size) + q = unpack(q, cu_seqlens) + k = unpack(k, cu_seqlens) + v = unpack(v, cu_seqlens) + g = unpack(g, cu_seqlens) + h = unpack(h, chunk_offsets) + + batch_size, num_tokens, num_k_heads, head_dim_k = k.shape + _, _, num_v_heads, head_dim_v = v.shape + + if num_k_heads != num_v_heads: + q = mx.repeat(q, num_v_heads // num_k_heads, axis=2) + k = mx.repeat(k, num_v_heads // num_k_heads, axis=2) + + scale = scale if scale is not None else head_dim_k ** (-0.5) + + q = pad_and_reshape(q, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K] + k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) + v = pad_and_reshape(v, dim=1, chunk_size=chunk_size) + g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) + + q = q * scale + + mask = mx.triu( + mx.ones((chunk_size, chunk_size), dtype=mx.bool_), k=1 + ) + decay_mask = mx.exp(g[:, :, :, None, :] - g[:, :, None, :, :]) + decay_mask = mx.where(mask[None, None, :, :, None], mx.zeros_like(decay_mask), decay_mask) + + attn = mx.einsum("bnchk,bndhk->bncdh", q, k) * decay_mask + attn_inter = mx.einsum("bnchk,bnhkv->bnchv", q * mx.exp(g)[..., None], h) + o = attn_inter + mx.einsum("bncdh,bndhv->bnchv", attn, v) + + o = o.reshape(batch_size, -1, num_v_heads, head_dim_v)[:, :num_tokens] + if cu_seqlens is not None: + o = pack(o, cu_seqlens) + return o + + +# --------------------------------------------------------------------------- +# Backward helpers +# --------------------------------------------------------------------------- + + +def _chunk_dv_bwd( + q: mx.array, + k: mx.array, + g: mx.array, + do: mx.array, + cu_seqlens: mx.array = None, + scale: float = None, + chunk_size: int = 64, +) -> mx.array: + if cu_seqlens is not None: + q = unpack(q, cu_seqlens) + k = unpack(k, cu_seqlens) + g = unpack(g, cu_seqlens) + do = unpack(do, cu_seqlens) + + batch_size, num_tokens, num_k_heads, head_dim_k = k.shape + _, _, num_v_heads, head_dim_v = do.shape + + if num_k_heads != num_v_heads: + q = mx.repeat(q, num_v_heads // num_k_heads, axis=2) + k = mx.repeat(k, num_v_heads // num_k_heads, axis=2) + + scale = scale if scale is not None else head_dim_k ** (-0.5) + + q = pad_and_reshape(q, dim=1, chunk_size=chunk_size) + k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) + g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) + do = pad_and_reshape(do, dim=1, chunk_size=chunk_size) + + q = q * scale + + mask = mx.triu(mx.ones((chunk_size, chunk_size), dtype=mx.bool_), k=1) + decay_mask = mx.exp(g[:, :, :, None, :] - g[:, :, None, :, :]) + decay_mask = mx.where(mask[None, None, :, :, None], mx.zeros_like(decay_mask), decay_mask) + + attn = mx.einsum("bnchk,bndhk->bncdh", q, k) * decay_mask + dv = mx.einsum("bncdh,bnchv->bndhv", attn, do) + + dv = dv.reshape(batch_size, -1, num_v_heads, head_dim_v)[:, :num_tokens] + if cu_seqlens is not None: + dv = pack(dv, cu_seqlens) + return dv + + +def _chunk_gdr_bwd( + q: mx.array, + k: mx.array, + w: mx.array, + g: mx.array, + do: mx.array, + dv: mx.array, + h0: mx.array = None, + dht: mx.array = None, + cu_seqlens: mx.array = None, + scale: float = None, + chunk_size: int = 64, +) -> tuple: + if cu_seqlens is not None: + q = unpack(q, cu_seqlens) + k = unpack(k, cu_seqlens) + w = unpack(w, cu_seqlens) + g = unpack(g, cu_seqlens) + do = unpack(do, cu_seqlens) + dv = unpack(dv, cu_seqlens) + + batch_size, num_tokens, num_k_heads, head_dim_k = k.shape + _, _, num_v_heads, head_dim_v = do.shape + + if num_k_heads != num_v_heads: + q = mx.repeat(q, num_v_heads // num_k_heads, axis=2) + k = mx.repeat(k, num_v_heads // num_k_heads, axis=2) + + scale = scale if scale is not None else head_dim_k ** (-0.5) + + q = pad_and_reshape(q, dim=1, chunk_size=chunk_size) + k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) + w = pad_and_reshape(w, dim=1, chunk_size=chunk_size) + g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) + do = pad_and_reshape(do, dim=1, chunk_size=chunk_size) + dv = pad_and_reshape(dv, dim=1, chunk_size=chunk_size) + g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size) + + q = q * scale + + if dht is None: + dstate = mx.zeros( + (batch_size, num_v_heads, head_dim_k, head_dim_v), dtype=g.dtype + ) + else: + dstate = dht.astype(g.dtype) + + dstate_inter = mx.einsum( + "bnchk,bnchv->bnhkv", q * mx.exp(g)[..., None], do + ) + + dh_list = [] + dv_list = list(mx.split(dv, dv.shape[1], axis=1)) # list of [B,1,C,Hv,V] + for i in reversed(range(k.shape[1])): + dh_list.insert(0, dstate) + dv_i = dv_list[i][:, 0] # [B, C, Hv, V] + dv_i = dv_i + mx.einsum( + "bchk,bhkv->bchv", + k[:, i] * mx.exp(g[:, i, -1:, :, None] - g[:, i, :, :, None]), + dstate, + ) + dv_list[i] = dv_i + dstate = dstate * mx.exp(g[:, i, -1, :])[:, :, None, None] + dstate = ( + dstate + + dstate_inter[:, i] + - mx.einsum("bchk,bchv->bhkv", w[:, i], dv_i) + ) + + dh = mx.stack(dh_list, axis=1) + + dh0 = None if h0 is None else dstate + dv = mx.stack(dv_list, axis=1).reshape( + batch_size, -1, num_v_heads, head_dim_v + )[:, :num_tokens] + + if cu_seqlens is not None: + chunk_offsets = prepare_chunk_offsets(cu_seqlens, chunk_size) + dv = pack(dv, cu_seqlens) + dh = pack(dh, chunk_offsets) + return dh, dh0, dv + + +def _chunk_dqkwg_bwd( + q: mx.array, + k: mx.array, + v: mx.array, + w: mx.array, + g: mx.array, + h: mx.array, + dv: mx.array, + do: mx.array, + dh: mx.array, + cu_seqlens: mx.array = None, + scale: float = None, + chunk_size: int = 64, +) -> tuple: + if cu_seqlens is not None: + chunk_offsets = prepare_chunk_offsets(cu_seqlens, chunk_size) + q = unpack(q, cu_seqlens) + k = unpack(k, cu_seqlens) + v = unpack(v, cu_seqlens) + w = unpack(w, cu_seqlens) + g = unpack(g, cu_seqlens) + do = unpack(do, cu_seqlens) + dv = unpack(dv, cu_seqlens) + h = unpack(h, chunk_offsets) + dh = unpack(dh, chunk_offsets) + + batch_size, num_tokens, num_k_heads, head_dim_k = k.shape + _, _, num_v_heads, head_dim_v = do.shape + + if num_k_heads != num_v_heads: + q = mx.repeat(q, num_v_heads // num_k_heads, axis=2) + k = mx.repeat(k, num_v_heads // num_k_heads, axis=2) + + scale = scale if scale is not None else head_dim_k ** (-0.5) + + q = pad_and_reshape(q, dim=1, chunk_size=chunk_size) + k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) + v = pad_and_reshape(v, dim=1, chunk_size=chunk_size) + w = pad_and_reshape(w, dim=1, chunk_size=chunk_size) + g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) + do = pad_and_reshape(do, dim=1, chunk_size=chunk_size) + dv = pad_and_reshape(dv, dim=1, chunk_size=chunk_size) + g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size) + + mask = mx.triu(mx.ones((chunk_size, chunk_size), dtype=mx.bool_), k=1) + decay_mask = mx.exp(g[:, :, :, None, :] - g[:, :, None, :, :]) + decay_mask = mx.where(mask[None, None, :, :, None], mx.zeros_like(decay_mask), decay_mask) + + dg_last = (h * dh).sum(axis=-1).sum(axis=-1) # [B, N, Hv] + ds = mx.einsum("bnchv,bndhv->bncdh", do, v) + dq = mx.einsum("bnchv,bnhkv->bnchk", do, h) + dk = mx.einsum("bnchv,bnhkv->bnchk", v, dh) + dw = -mx.einsum("bnchv,bnhkv->bnchk", dv, h) + + g_last = g[:, :, -1] + dg_last = dg_last * mx.exp(g_last) + dq = dq * mx.exp(g)[..., None] * scale + dg = (q * dq).sum(axis=-1) + dk = dk * mx.exp(g_last[:, :, None, :, None] - g)[..., None] + dg = dg - (k * dk).sum(axis=-1) + dg_last = dg_last + (k * dk).sum(axis=-1).sum(axis=-2) + ds = ds * decay_mask * scale + ds2 = ds * mx.einsum("bnchk,bndhk->bncdh", q, k) + dg = dg + ds2.sum(axis=-2) + dg = dg - ds2.sum(axis=-3) + dq = dq + mx.einsum("bncdh,bndhk->bnchk", ds, k) + dk = dk + mx.einsum("bncdh,bnchk->bndhk", ds, q) + dg[:, :, -1] = dg[:, :, -1] + dg_last + + dg = fill_last_chunk_of_g( + dg, num_tokens, cu_seqlens, chunk_size=chunk_size, reverse=True + ) + dq = dq.reshape(batch_size, -1, num_v_heads, head_dim_k)[:, :num_tokens] + dk = dk.reshape(batch_size, -1, num_v_heads, head_dim_k)[:, :num_tokens] + dw = dw.reshape(batch_size, -1, num_v_heads, head_dim_k)[:, :num_tokens] + dg = dg.reshape(batch_size, -1, num_v_heads)[:, :num_tokens] + + if cu_seqlens is not None: + dq = pack(dq, cu_seqlens) + dk = pack(dk, cu_seqlens) + dw = pack(dw, cu_seqlens) + dg = pack(dg, cu_seqlens) + return dq, dk, dw, dg + + +def _chunk_wy_bwd( + k: mx.array, + v: mx.array, + beta: mx.array, + A: mx.array, + g: mx.array, + dw: mx.array, + du: mx.array, + dk1: mx.array, + dg1: mx.array, + cu_seqlens: mx.array = None, + chunk_size: int = 64, +) -> tuple: + if cu_seqlens is not None: + k = unpack(k, cu_seqlens) + v = unpack(v, cu_seqlens) + beta = unpack(beta, cu_seqlens) + A = unpack(A, cu_seqlens) + g = unpack(g, cu_seqlens) + dw = unpack(dw, cu_seqlens) + du = unpack(du, cu_seqlens) + dk1 = unpack(dk1, cu_seqlens) + dg1 = unpack(dg1, cu_seqlens) + + batch_size, num_tokens, num_k_heads, head_dim_k = k.shape + _, _, num_v_heads, head_dim_v = v.shape + chunk_size_A = A.shape[-1] + + if num_k_heads != num_v_heads: + k = mx.repeat(k, num_v_heads // num_k_heads, axis=2) + + k = pad_and_reshape(k, dim=1, chunk_size=chunk_size_A) + v = pad_and_reshape(v, dim=1, chunk_size=chunk_size_A) + beta = pad_and_reshape(beta, dim=1, chunk_size=chunk_size_A) + A = pad_and_reshape(A, dim=1) + g = pad_and_reshape(g, dim=1, chunk_size=chunk_size_A) + dw = pad_and_reshape(dw, dim=1, chunk_size=chunk_size_A) + du = pad_and_reshape(du, dim=1, chunk_size=chunk_size_A) + dk1 = pad_and_reshape(dk1, dim=1, chunk_size=chunk_size_A) + dg1 = pad_and_reshape(dg1, dim=1, chunk_size=chunk_size_A) + + dA = mx.einsum("bnchk,bndhk->bnchd", dw, k * (beta * mx.exp(g))[..., None]) + dk_beta_g = mx.einsum("bnchd,bnchk->bndhk", A, dw) + dk = dk_beta_g * (beta * mx.exp(g))[..., None] + db = (dk_beta_g * k * mx.exp(g)[..., None]).sum(axis=-1) + dg = (dk_beta_g * k * (mx.exp(g) * beta)[..., None]).sum(axis=-1) + + dA = dA + mx.einsum("bnchv,bndhv->bnchd", du, v * beta[..., None]) + dv_beta = mx.einsum("bnchd,bnchv->bndhv", A, du) + dv = dv_beta * beta[..., None] + db = db + (dv_beta * v).sum(axis=-1) + + mask = mx.triu(mx.ones((chunk_size_A, chunk_size_A), dtype=mx.bool_), k=0) + decay_mask = mx.exp(g[:, :, :, None, :] - g[:, :, None, :, :]) + decay_mask = mx.where( + mask[None, None, :, :, None], mx.zeros_like(decay_mask), decay_mask + ) + decay_mask = mx.swapaxes(decay_mask, -2, -1) + + dA = mx.where(mask[None, None, :, None, :], mx.zeros_like(dA), dA) + dA = mx.einsum("bndhc,bndhe->bnche", A, dA) + dA = mx.einsum("bnchd,bnehd->bnche", dA, A) + dA = -dA * decay_mask + + A_kkt = mx.einsum("bnchk,bndhk->bnchd", k * beta[..., None], k) + dk_beta = mx.einsum("bnchd,bndhk->bnchk", dA, k) + db = db + (dk_beta * k).sum(axis=-1) + dk = dk + mx.einsum("bnchd,bnchk->bndhk", dA, k * beta[..., None]) + dk = dk + dk_beta * beta[..., None] + dk = dk + dk1 + + dg = dg + (dA * A_kkt).sum(axis=-1) - mx.swapaxes( + (dA * A_kkt).sum(axis=-3), -1, -2 + ) + dg = dg + dg1 + + dk = dk.reshape(batch_size, -1, num_v_heads, head_dim_k)[:, :num_tokens] + dv = dv.reshape(batch_size, -1, num_v_heads, head_dim_k)[:, :num_tokens] + db = db.reshape(batch_size, -1, num_v_heads)[:, :num_tokens] + dg = dg.reshape(batch_size, -1, num_v_heads)[:, :num_tokens] + + if cu_seqlens is not None: + dk = pack(dk, cu_seqlens) + dv = pack(dv, cu_seqlens) + db = pack(db, cu_seqlens) + dg = pack(dg, cu_seqlens) + return dk, dv, db, dg + + +def _group_reduce_vector(buffer: mx.array, Hg: int) -> mx.array: + batch_size, num_tokens, H, K = buffer.shape + return buffer.reshape(batch_size, num_tokens, Hg, H // Hg, K).sum(axis=3) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def chunk_gated_delta_rule_fwd( + q: mx.array, + k: mx.array, + v: mx.array, + g: mx.array, + beta: mx.array, + scale: float = None, + initial_state: mx.array = None, + cu_seqlens: mx.array = None, + output_final_state: bool = True, + output_h: bool = False, +) -> tuple: + """ + Forward pass of the Gated Delta Rule. + + Args: + q: [B, T, Hk, K] + k: [B, T, Hk, K] + v: [B, T, Hv, V] + g: [B, T, Hv] (log-decay, negative values) + beta: [B, T, Hv] + scale: Optional softmax scale (defaults to K**-0.5) + initial_state: Optional [B, Hv, K, V] + cu_seqlens: Optional [S+1] int32, for variable-length inputs + output_final_state: Whether to return the final state + output_h: Whether to return the chunk-level states h + + Returns: + (g_cumsum, A, o, h, final_state) + h and final_state may be None if not requested. + """ + if scale is None: + scale = k.shape[-1] ** -0.5 + + chunk_size = 64 + + g_cumsum = _chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens) + A = _kkt(k=k, beta=beta, g=g_cumsum, cu_seqlens=cu_seqlens, chunk_size=chunk_size) + w, u = _w_u_fwd(k=k, v=v, beta=beta, A=A, g=g_cumsum, cu_seqlens=cu_seqlens) + h, vn, final_state = _chunk_gdr_fwd( + k=k, w=w, u=u, g=g_cumsum, + initial_state=initial_state, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, + ) + o = _chunk_o_fwd( + q=q, k=k, v=vn, h=h, g=g_cumsum, + cu_seqlens=cu_seqlens, scale=scale, chunk_size=chunk_size, + ) + + return ( + g_cumsum, + A, + o, + h if output_h else None, + final_state if output_final_state else None, + ) + + +def chunk_gated_delta_rule_bwd( + q: mx.array, + k: mx.array, + v: mx.array, + g: mx.array, + beta: mx.array, + A: mx.array, + do: mx.array, + dht: mx.array = None, + scale: float = None, + initial_state: mx.array = None, + cu_seqlens: mx.array = None, +) -> tuple: + """ + Backward pass of the Gated Delta Rule. + + Args: + q, k, v, g, beta, A: Saved from forward (g is the cumsum version) + do: Gradient of the output [B, T, Hv, V] + dht: Gradient of the final state [B, Hv, K, V] or None + scale, initial_state, cu_seqlens: same as forward + + Returns: + (dq, dk, dv, db, dg, dh0) + """ + if scale is None: + scale = k.shape[-1] ** -0.5 + + chunk_size = 64 + + w, u = _w_u_fwd(k=k, v=v, beta=beta, A=A, g=g, cu_seqlens=cu_seqlens) + h, vn, _ = _chunk_gdr_fwd( + k=k, w=w, u=u, g=g, + initial_state=initial_state, cu_seqlens=cu_seqlens, chunk_size=chunk_size, + ) + dv = _chunk_dv_bwd( + q=q, k=k, g=g, do=do, cu_seqlens=cu_seqlens, scale=scale, chunk_size=chunk_size + ) + dh, dh0, dv = _chunk_gdr_bwd( + q=q, k=k, w=w, g=g, do=do, dv=dv, + h0=initial_state, dht=dht, + cu_seqlens=cu_seqlens, scale=scale, chunk_size=chunk_size, + ) + dq, dk1, dw, dg1 = _chunk_dqkwg_bwd( + q=q, k=k, v=vn, w=w, g=g, h=h, + dv=dv, do=do, dh=dh, + cu_seqlens=cu_seqlens, scale=scale, chunk_size=chunk_size, + ) + dk, dv_out, db, dg = _chunk_wy_bwd( + k=k, v=v, beta=beta, A=A, g=g, + dw=dw, du=dv, dk1=dk1, dg1=dg1, + cu_seqlens=cu_seqlens, chunk_size=chunk_size, + ) + + Hg, H = k.shape[-2], v.shape[-2] + if Hg < H: + dq = _group_reduce_vector(dq, Hg) + dk = _group_reduce_vector(dk, Hg) + + dg = _chunk_local_cumsum(dg, chunk_size=chunk_size, reverse=True, cu_seqlens=cu_seqlens) + return dq, dk, dv_out, db, dg, dh0 + + +def chunk_gated_delta_rule( + q: mx.array, + k: mx.array, + v: mx.array, + g: mx.array, + beta: mx.array, + scale: float = None, + initial_state: mx.array = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + cu_seqlens: mx.array = None, + head_first: bool = False, +) -> tuple: + """ + Gated Delta Rule: end-to-end forward with MLX-native gradient support. + + Args: + q: [B, T, Hk, K] + k: [B, T, Hk, K] + v: [B, T, Hv, V] + g: [B, T, Hv] + beta: [B, T, Hv] + scale: Optional softmax scale + initial_state: Optional [B, Hv, K, V] + output_final_state: Return final recurrent state + use_qk_l2norm_in_kernel: L2-normalize q and k before computation + cu_seqlens: Optional variable-length sequence boundaries + head_first: Not supported (must be False) + + Returns: + (o, final_state) — final_state is None if output_final_state=False + """ + assert not head_first, "head_first=True is not supported." + assert v.shape[2] % k.shape[2] == 0, ( + "num_v_heads must be divisible by num_k_heads." + ) + + if cu_seqlens is not None and q.shape[0] != 1: + raise ValueError( + f"Batch size must be 1 when using cu_seqlens, got {q.shape[0]}." + ) + + if scale is None: + scale = k.shape[-1] ** -0.5 + + if use_qk_l2norm_in_kernel: + q = l2norm(q) + k = l2norm(k) + + # Define custom forward/backward for correct gradient computation. + # We use closures to capture scale, cu_seqlens, and the presence of h0. + _scale = scale + _cu_seqlens = cu_seqlens + + if initial_state is not None: + @mx.custom_function + def _fn(q, k, v, g, beta, h0): + g_out, A, o, _, fs = chunk_gated_delta_rule_fwd( + q=q, k=k, v=v, g=g, beta=beta, + scale=_scale, initial_state=h0, + cu_seqlens=_cu_seqlens, + output_final_state=True, output_h=False, + ) + fs_safe = fs if fs is not None else mx.zeros_like(h0) + return o, fs_safe, g_out, A + + @_fn.vjp + def _fn_vjp(primals, cotangents, outputs): + q, k, v, g, beta, h0 = primals + do, dfs, _, _ = cotangents + _, _, g_out, A = outputs + dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd( + q=q, k=k, v=v, g=g_out, beta=beta, A=A, + do=do, dht=dfs, scale=_scale, + initial_state=h0, cu_seqlens=_cu_seqlens, + ) + return dq, dk, dv, dg, db, dh0 + + o, fs, _, _ = _fn(q, k, v, g, beta, initial_state) + else: + @mx.custom_function + def _fn(q, k, v, g, beta): + g_out, A, o, _, fs = chunk_gated_delta_rule_fwd( + q=q, k=k, v=v, g=g, beta=beta, + scale=_scale, initial_state=None, + cu_seqlens=_cu_seqlens, + output_final_state=output_final_state, output_h=False, + ) + _dummy = mx.zeros((1,), dtype=q.dtype) + return o, _dummy, g_out, A + + @_fn.vjp + def _fn_vjp(primals, cotangents, outputs): + q, k, v, g, beta = primals + do, _, _, _ = cotangents + _, _, g_out, A = outputs + dq, dk, dv, db, dg, _ = chunk_gated_delta_rule_bwd( + q=q, k=k, v=v, g=g_out, beta=beta, A=A, + do=do, dht=None, scale=_scale, + initial_state=None, cu_seqlens=_cu_seqlens, + ) + return dq, dk, dv, dg, db + + o, _, _, _ = _fn(q, k, v, g, beta) + fs = None + + return o, fs if output_final_state else None diff --git a/flash_qla_mlx/utils/__init__.py b/flash_qla_mlx/utils/__init__.py new file mode 100644 index 0000000..c8952e2 --- /dev/null +++ b/flash_qla_mlx/utils/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +from .math import l2norm +from .pack import ( + pack, + unpack, + pad_and_reshape, + fill_last_chunk_of_g, + prepare_chunk_offsets, +) + +__all__ = [ + "l2norm", + "pack", + "unpack", + "pad_and_reshape", + "fill_last_chunk_of_g", + "prepare_chunk_offsets", +] diff --git a/flash_qla_mlx/utils/math.py b/flash_qla_mlx/utils/math.py new file mode 100644 index 0000000..13dfe38 --- /dev/null +++ b/flash_qla_mlx/utils/math.py @@ -0,0 +1,9 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +import mlx.core as mx + + +def l2norm(x: mx.array, dim: int = -1, eps: float = 1e-6) -> mx.array: + inv_norm = mx.rsqrt((x * x).sum(axis=dim, keepdims=True) + eps) + return x * inv_norm diff --git a/flash_qla_mlx/utils/pack.py b/flash_qla_mlx/utils/pack.py new file mode 100644 index 0000000..dd244fd --- /dev/null +++ b/flash_qla_mlx/utils/pack.py @@ -0,0 +1,147 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +import mlx.core as mx + + +def unpack( + x: mx.array, # [1, sum_T, *dims] + cu_seqlens: mx.array, +) -> mx.array: + assert x.shape[0] == 1 + batch_size = cu_seqlens.shape[0] - 1 + seqlens = [ + int((cu_seqlens[i + 1] - cu_seqlens[i]).item()) for i in range(batch_size) + ] + max_len = max(seqlens) + rest = x.shape[2:] + + parts = [] + for i in range(batch_size): + start = int(cu_seqlens[i].item()) + end = int(cu_seqlens[i + 1].item()) + chunk = x[0, start:end] + if end - start < max_len: + pad_width = [(0, max_len - (end - start))] + [(0, 0)] * len(rest) + chunk = mx.pad(chunk, pad_width) + parts.append(chunk) + return mx.stack(parts, axis=0) # [B, max_len, *dims] + + +def pack( + x: mx.array, # [B, max_T, *dims] + cu_seqlens: mx.array, +) -> mx.array: + batch_size = cu_seqlens.shape[0] - 1 + parts = [] + for i in range(batch_size): + start = int(cu_seqlens[i].item()) + end = int(cu_seqlens[i + 1].item()) + parts.append(x[i, : end - start]) + packed = mx.concatenate(parts, axis=0) # [sum_T, *dims] + return packed[None] # [1, sum_T, *dims] + + +def pad_and_reshape( + x: mx.array, + dim: int, + chunk_size: int = 64, +) -> mx.array: + seq_len = x.shape[dim] + pad_size = (chunk_size - seq_len % chunk_size) % chunk_size + if pad_size > 0: + pad_width = [(0, 0)] * len(x.shape) + pad_width[dim] = (0, pad_size) + x = mx.pad(x, pad_width) + shape = list(x.shape) + shape[dim : dim + 1] = [-1, chunk_size] + return x.reshape(shape) + + +def fill_last_chunk_of_g( + g: mx.array, # [B, N, C, Hv] + num_tokens: int, + cu_seqlens: mx.array = None, + chunk_size: int = 64, + reverse: bool = False, +) -> mx.array: + if cu_seqlens is None: + last_chunk_size = num_tokens % chunk_size + if last_chunk_size > 0: + B, N, C, Hv = g.shape + g_last = g[:, -1, :, :] # [B, C, Hv] + if reverse: + update = ( + g_last[:, last_chunk_size - 1 : last_chunk_size, :] + + g_last[:, -1:, :] + ) + new_last = mx.concatenate( + [ + g_last[:, : last_chunk_size - 1, :], + update, + g_last[:, last_chunk_size:, :], + ], + axis=1, + ) + else: + fill_val = mx.broadcast_to( + g_last[:, last_chunk_size - 1 : last_chunk_size, :], + (B, C - last_chunk_size, Hv), + ) + new_last = mx.concatenate( + [g_last[:, :last_chunk_size, :], fill_val], axis=1 + ) + g = mx.concatenate([g[:, :-1, :, :], new_last[:, None, :, :]], axis=1) + else: + batch_size = cu_seqlens.shape[0] - 1 + new_g_list = [] + for i in range(batch_size): + start = int(cu_seqlens[i].item()) + end = int(cu_seqlens[i + 1].item()) + seqlen = end - start + last_chunk_idx = seqlen // chunk_size + lcs = seqlen % chunk_size + g_i = g[i] # [N, C, Hv] + if lcs > 0: + g_i_last = g_i[last_chunk_idx] # [C, Hv] + if reverse: + update = ( + g_i_last[lcs - 1 : lcs, :] + g_i_last[-1:, :] + ) + new_chunk = mx.concatenate( + [ + g_i_last[: lcs - 1, :], + update, + g_i_last[lcs:, :], + ], + axis=0, + ) + else: + fill_val = mx.broadcast_to( + g_i_last[lcs - 1 : lcs, :], + (chunk_size - lcs, g_i_last.shape[-1]), + ) + new_chunk = mx.concatenate( + [g_i_last[:lcs, :], fill_val], axis=0 + ) + g_i = mx.concatenate( + [ + g_i[:last_chunk_idx, :, :], + new_chunk[None, :, :], + g_i[last_chunk_idx + 1 :, :, :], + ], + axis=0, + ) + new_g_list.append(g_i) + g = mx.stack(new_g_list, axis=0) + return g + + +def prepare_chunk_offsets(cu_seqlens: mx.array, chunk_size: int) -> mx.array: + batch_size = cu_seqlens.shape[0] - 1 + offsets = [0] + for i in range(batch_size): + seqlen = int((cu_seqlens[i + 1] - cu_seqlens[i]).item()) + n_chunks = (seqlen + chunk_size - 1) // chunk_size + offsets.append(offsets[-1] + n_chunks) + return mx.array(offsets, dtype=mx.int32) diff --git a/setup.py b/setup.py index 2c6c516..8d1181f 100644 --- a/setup.py +++ b/setup.py @@ -27,5 +27,8 @@ "tilelang==0.1.8", "apache-tvm-ffi==0.1.9", ], + extras_require={ + "mlx": ["mlx>=0.24.0"], + }, zip_safe=False, ) From 8d4e0492f5a49439e0c026f0ed80a9e00bc72084 Mon Sep 17 00:00:00 2001 From: "clandestine.eth" <96172957+0xClandestine@users.noreply.github.com> Date: Thu, 30 Apr 2026 18:49:04 -0500 Subject: [PATCH 02/11] test: add numerical correctness tests for flash_qla_mlx (Apple Silicon) Two-subprocess design compares flash_qla_mlx against a PyTorch CPU reference across 7 cases covering GQA, initial state, T % 64 != 0, forward-only, and forward+backward. All pass with 0.0000 relative error. Also fixes two bugs in chunk.py: - Replace mx.flip (absent in MLX 0.31.2) with [::-1] slice reversal - Fix broadcast shape in _chunk_dqkwg_bwd (5D vs 4D mismatch) --- flash_qla_mlx/ops/gated_delta_rule/chunk.py | 6 +- tests/test_mlx.py | 348 ++++++++++++++++++++ 2 files changed, 351 insertions(+), 3 deletions(-) create mode 100644 tests/test_mlx.py diff --git a/flash_qla_mlx/ops/gated_delta_rule/chunk.py b/flash_qla_mlx/ops/gated_delta_rule/chunk.py index 61586f9..b70b2f8 100644 --- a/flash_qla_mlx/ops/gated_delta_rule/chunk.py +++ b/flash_qla_mlx/ops/gated_delta_rule/chunk.py @@ -31,9 +31,9 @@ def _chunk_local_cumsum( g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, H] if reverse: - g = mx.flip(g, axis=2) + g = g[:, :, ::-1, :] g = mx.cumsum(g, axis=2) - g = mx.flip(g, axis=2) + g = g[:, :, ::-1, :] else: g = mx.cumsum(g, axis=2) @@ -466,7 +466,7 @@ def _chunk_dqkwg_bwd( dg_last = dg_last * mx.exp(g_last) dq = dq * mx.exp(g)[..., None] * scale dg = (q * dq).sum(axis=-1) - dk = dk * mx.exp(g_last[:, :, None, :, None] - g)[..., None] + dk = dk * mx.exp(g_last[:, :, None, :] - g)[..., None] dg = dg - (k * dk).sum(axis=-1) dg_last = dg_last + (k * dk).sum(axis=-1).sum(axis=-2) ds = ds * decay_mask * scale diff --git a/tests/test_mlx.py b/tests/test_mlx.py new file mode 100644 index 0000000..8ba0a26 --- /dev/null +++ b/tests/test_mlx.py @@ -0,0 +1,348 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] +# +# Numerical correctness tests for flash_qla_mlx (Apple Silicon / MLX). +# Compares against a PyTorch CPU reference. No CUDA required. +# +# Run: +# python tests/test_mlx.py +# +# Design: two sequential subprocesses share a temp directory. +# --gen-ref DIR : PyTorch-only process — computes reference, writes .npy files +# --check-mlx DIR: MLX-only process — loads .npy files, runs flash_qla_mlx +# Loading PyTorch and MLX in the same process causes macOS memory conflicts +# (OMP dual-runtime + Metal/CPU allocator interference), so they run separately. + +import os +import sys +import subprocess +import tempfile +# On macOS with Python 3.14 + PyTorch 2.11, numpy's OMP runtime must not load +# before torch's OMP runtime — the reverse order causes memory corruption on +# cumsum and other ops. Pre-load torch first when we're in gen-ref mode. +if "--gen-ref" in sys.argv: + import torch as _torch_preload # noqa: F401 — sets up OMP before numpy +import numpy as np + +# =========================================================================== +# Shared test-case definitions (used in both sub-modes) +# =========================================================================== + +CASES = [ + # forward + backward + dict(B=2, T=256, Hk=4, Hv=4, h0=False, seed=0), + dict(B=1, T=256, Hk=4, Hv=8, h0=False, seed=1), # GQA + dict(B=2, T=128, Hk=4, Hv=4, h0=True, seed=2), + dict(B=1, T=192, Hk=2, Hv=2, h0=False, seed=3), # T not divisible by 64 + dict(B=1, T=128, Hk=2, Hv=2, h0=False, seed=5, bwd=True), + dict(B=1, T=128, Hk=2, Hv=4, h0=False, seed=6, bwd=True), # GQA bwd + dict(B=1, T=128, Hk=2, Hv=2, h0=True, seed=7, bwd=True), +] +K_DIM = V_DIM = 128 + + +def make_inputs(rng, B, T, Hk, Hv, with_h0, with_bwd=False): + q = rng.standard_normal((B, T, Hk, K_DIM)).astype(np.float32) + k = rng.standard_normal((B, T, Hk, K_DIM)).astype(np.float32) + # L2-normalize k so KKT triangular solve stays stable in float32. + k /= np.linalg.norm(k, axis=-1, keepdims=True) + 1e-8 + v = rng.standard_normal((B, T, Hv, V_DIM)).astype(np.float32) + g = (-np.log1p(np.exp(-rng.standard_normal((B, T, Hv)).astype(np.float32))) / 16) + beta = (1 / (1 + np.exp(-rng.standard_normal((B, T, Hv)).astype(np.float32)))) + h0 = (rng.standard_normal((B, Hv, K_DIM, V_DIM)).astype(np.float32) + if with_h0 else None) + do = (rng.standard_normal((B, T, Hv, V_DIM)).astype(np.float32) + if with_bwd else None) + dht = (rng.standard_normal((B, Hv, K_DIM, V_DIM)).astype(np.float32) + if (with_bwd and with_h0) else None) + return q, k, v, g, beta, h0, do, dht + + +# =========================================================================== +# Mode: --gen-ref (pure PyTorch, no MLX) +# =========================================================================== + +def _gen_ref(data_dir): + import torch + + def pad(x, dim, cs=64): + amt = (cs - x.shape[dim] % cs) % cs + if amt > 0: # F.pad with 6-tuple on 4D tensors segfaults in Py3.14 + torch2.11 when amt==0 + zeros = [0] * (2 * (x.dim() - 1 - dim)) + x = torch.nn.functional.pad(x, (*zeros, 0, amt)) + return x.reshape(list(x.shape[:dim]) + [-1, cs] + list(x.shape[dim + 1:])) + + def fill_g(g, T, cs=64): + lcs = T % cs + if lcs: + g = g.clone() + g[:, -1, lcs:] = g[:, -1, lcs - 1:lcs] + return g + + def ref_fwd(q, k, v, g, beta, scale, h0=None, cs=64): + B, T, Hk, K = q.shape + _, _, Hv, V = v.shape + g = pad(g, 1, cs).cumsum(2).reshape(B, -1, Hv)[:, :T] + if Hk != Hv: + k = k.repeat_interleave(Hv // Hk, dim=2) + kc = pad(k, 1, cs); gc = pad(g, 1, cs); bc = pad(beta, 1, cs) + mask_u = torch.triu(torch.ones(cs, cs, dtype=torch.bool)) + decay = torch.exp(gc[:, :, :, None, :] - gc[:, :, None, :, :]) + decay = decay.masked_fill(mask_u[None, None, :, :, None], 0.0) + A = (torch.einsum("bnchk,bndhk->bnchd", kc * bc.unsqueeze(-1), kc) + * decay.swapaxes(-2, -1)).reshape(B, -1, Hv, cs)[:, :T] + Ac = -pad(A, 1, cs).swapaxes(2, 3) + for i in range(1, cs): + row = Ac[..., i, :i].clone() + sub = Ac[..., :i, :i].clone() + Ac[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + Ac = Ac + torch.eye(cs, dtype=Ac.dtype) + A = Ac.swapaxes(2, 3).reshape(B, -1, Hv, cs)[:, :T] + Ac2 = pad(A, 1) + kb = pad(k * (beta * g.exp()).unsqueeze(-1), 1) + vb = pad(v * beta.unsqueeze(-1), 1) + w = torch.einsum("bnchd,bndhk->bnchk", Ac2, kb).reshape(B, -1, Hv, K)[:, :T] + u = torch.einsum("bnchd,bndhv->bnchv", Ac2, vb).reshape(B, -1, Hv, V)[:, :T] + kc2 = pad(k, 1, cs); wc = pad(w, 1, cs); uc = pad(u, 1, cs) + gc2 = fill_g(pad(g, 1, cs), T, cs) + hs = (torch.zeros(B, Hv, K, V, dtype=g.dtype) + if h0 is None else h0.to(g.dtype)) + hl, vl = [], [] + for i in range(kc2.shape[1]): + hl.append(hs.clone()) + vn = uc[:, i] - torch.einsum("bchk,bhkv->bchv", wc[:, i], hs) + vl.append(vn) + hs = hs * gc2[:, i, -1, :].exp()[:, :, None, None] + hs = hs + torch.einsum( + "bchk,bchv->bhkv", + kc2[:, i] * (gc2[:, i, -1:, :, None] - gc2[:, i, :, :, None]).exp(), + vn, + ) + hc = torch.stack(hl, 1) + vn = torch.stack(vl, 1).reshape(B, -1, Hv, V)[:, :T] + qr = q.repeat_interleave(Hv // Hk, dim=2) if Hk != Hv else q + kr = k + qc = pad(qr, 1, cs) * scale; kc3 = pad(kr, 1, cs) + vnc = pad(vn, 1, cs); gc3 = pad(g, 1, cs) + mask_o = torch.triu(torch.ones(cs, cs, dtype=torch.bool), diagonal=1) + dec = torch.exp(gc3[:, :, :, None, :] - gc3[:, :, None, :, :]) + dec = dec.masked_fill(mask_o[None, None, :, :, None], 0.0) + at = torch.einsum("bnchk,bndhk->bncdh", qc, kc3) * dec + ai = torch.einsum("bnchk,bnhkv->bnchv", qc * gc3.exp().unsqueeze(-1), hc) + o = (ai + torch.einsum("bncdh,bndhv->bnchv", at, vnc)).reshape(B, -1, Hv, V)[:, :T] + return g, o, A, hs + + for idx, case in enumerate(CASES): + B, T, Hk, Hv = case["B"], case["T"], case["Hk"], case["Hv"] + rng = np.random.default_rng(case["seed"]) + scale = K_DIM ** -0.5 + q, k, v, g, beta, h0, do, dht = make_inputs( + rng, B, T, Hk, Hv, case["h0"], case.get("bwd", False) + ) + tq = torch.from_numpy(q.copy()); tk = torch.from_numpy(k.copy()) + tv = torch.from_numpy(v.copy()); tg = torch.from_numpy(g.copy()) + tb = torch.from_numpy(beta.copy()) + th0 = torch.from_numpy(h0.copy()) if h0 is not None else None + g_ref, o_ref, A_ref, s_ref = ref_fwd(tq, tk, tv, tg, tb, scale, th0) + d = { + "case": case, + "inputs": dict(q=q, k=k, v=v, g=g, beta=beta, h0=h0), + "fwd": dict( + g_cumsum=g_ref.numpy(), o=o_ref.numpy(), + A=A_ref.numpy(), s=s_ref.numpy(), + ), + } + if case.get("bwd"): + tq2 = torch.from_numpy(q.copy()).requires_grad_(True) + tk2 = torch.from_numpy(k.copy()).requires_grad_(True) + tv2 = torch.from_numpy(v.copy()).requires_grad_(True) + tg2 = torch.from_numpy(g.copy()).requires_grad_(True) + tb2 = torch.from_numpy(beta.copy()).requires_grad_(True) + th02 = (torch.from_numpy(h0.copy()).requires_grad_(True) + if h0 is not None else None) + _, o2, _, s2 = ref_fwd(tq2, tk2, tv2, tg2, tb2, scale, th02) + loss = (o2 * torch.from_numpy(do.copy())).sum() + if dht is not None: + loss = loss + (s2 * torch.from_numpy(dht.copy())).sum() + loss.backward() + d["bwd"] = dict( + do=do, dht=dht, + dq=tq2.grad.numpy(), dk=tk2.grad.numpy(), + dv=tv2.grad.numpy(), db=tb2.grad.numpy(), dg=tg2.grad.numpy(), + dh0=th02.grad.numpy() if th02 is not None else None, + ) + np.save(os.path.join(data_dir, f"case_{idx}.npy"), d, allow_pickle=True) + print(f"[gen-ref] wrote {len(CASES)} cases to {data_dir}") + + +# =========================================================================== +# Mode: --check-mlx (pure MLX, no PyTorch) +# =========================================================================== + +def _check_mlx(data_dir, tol=0.02): + sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) + import mlx.core as mx + from flash_qla_mlx import ( + chunk_gated_delta_rule_fwd as mlx_fwd, + chunk_gated_delta_rule_bwd as mlx_bwd, + chunk_gated_delta_rule, + ) + + def _mx(a): return mx.array(a) + def _np(a): mx.eval(a); return np.array(a) + def rel(got, ref): + g = np.asarray(got, dtype=np.float64) + r = np.asarray(ref, dtype=np.float64) + e = float(np.abs(g - r).max() / (np.abs(r).max() + 1e-8)) + return e if np.isfinite(e) else float("inf") + + files = sorted(f for f in os.listdir(data_dir) if f.endswith(".npy")) + all_passed = True + + for fname in files: + d = np.load(os.path.join(data_dir, fname), allow_pickle=True).item() + case = d["case"] + inp = d["inputs"] + fwd = d["fwd"] + q, k, v, g, beta = inp["q"], inp["k"], inp["v"], inp["g"], inp["beta"] + h0 = inp["h0"] + B, T, Hk, Hv = case["B"], case["T"], case["Hk"], case["Hv"] + scale = K_DIM ** -0.5 + + g_out, A_out, o_out, _, s_out = mlx_fwd( + q=_mx(q), k=_mx(k), v=_mx(v), g=_mx(g), beta=_mx(beta), + scale=scale, + initial_state=_mx(h0) if h0 is not None else None, + output_final_state=True, + ) + mx.eval(g_out, A_out, o_out) + errs_fwd = { + "o": rel(_np(o_out), fwd["o"]), + "g": rel(_np(g_out), fwd["g_cumsum"]), + "A": rel(_np(A_out), fwd["A"]), + } + if h0 is not None and s_out is not None: + mx.eval(s_out) + errs_fwd["s"] = rel(_np(s_out), fwd["s"]) + + tag = f"B={B} T={T} Hk={Hk} Hv={Hv} h0={case['h0']}" + status = "PASS" + for n, e in errs_fwd.items(): + if e > tol: + status = f"FAIL({n}={e:.4f})" + all_passed = False + print(f" [fwd] {tag} " + " ".join(f"{n}={e:.4f}" for n, e in errs_fwd.items()) + + f" {status}") + + if "bwd" in d: + bwd = d["bwd"] + dq_a, dk_a, dv_a, db_a, dg_a, dh0_a = mlx_bwd( + q=_mx(q), k=_mx(k), v=_mx(v), g=g_out, beta=_mx(beta), A=A_out, + do=_mx(bwd["do"]), + dht=_mx(bwd["dht"]) if bwd["dht"] is not None else None, + scale=scale, + initial_state=_mx(h0) if h0 is not None else None, + ) + mx.eval(dq_a, dk_a, dv_a, db_a, dg_a) + errs_bwd = { + "dq": rel(_np(dq_a), bwd["dq"]), + "dk": rel(_np(dk_a), bwd["dk"]), + "dv": rel(_np(dv_a), bwd["dv"]), + "db": rel(_np(db_a), bwd["db"]), + "dg": rel(_np(dg_a), bwd["dg"]), + } + if case["h0"] and dh0_a is not None: + mx.eval(dh0_a) + errs_bwd["dh0"] = rel(_np(dh0_a), bwd["dh0"]) + status = "PASS" + for n, e in errs_bwd.items(): + if e > tol: + status = f"FAIL({n}={e:.4f})" + all_passed = False + print(f" [bwd] {tag} " + " ".join(f"{n}={e:.4f}" for n, e in errs_bwd.items()) + + f" {status}") + + # Autograd smoke: mx.grad(chunk_gated_delta_rule) == mlx_bwd + rng = np.random.default_rng(42) + q2, k2, v2, g2, beta2, _, do2, _ = make_inputs(rng, 1, 128, 2, 2, False, True) + scale2 = K_DIM ** -0.5 + g_o2, A_o2, _, _, _ = mlx_fwd( + q=_mx(q2), k=_mx(k2), v=_mx(v2), g=_mx(g2), beta=_mx(beta2), + scale=scale2, output_final_state=False, + ) + mx.eval(g_o2, A_o2) + dq_e, dk_e, dv_e, db_e, dg_e, _ = mlx_bwd( + q=_mx(q2), k=_mx(k2), v=_mx(v2), g=g_o2, beta=_mx(beta2), A=A_o2, + do=_mx(do2), scale=scale2, + ) + mx.eval(dq_e, dk_e, dv_e, db_e, dg_e) + mdo = _mx(do2) + def loss_fn(q_, k_, v_, g_, b_): + o, _ = chunk_gated_delta_rule(q_, k_, v_, g_, b_, scale=scale2) + return (o * mdo).sum() + dq_ag, dk_ag, dv_ag, dg_ag, db_ag = mx.grad(loss_fn, argnums=(0,1,2,3,4))( + _mx(q2), _mx(k2), _mx(v2), _mx(g2), _mx(beta2) + ) + mx.eval(dq_ag, dk_ag, dv_ag, db_ag, dg_ag) + ag_errs = { + "dq": rel(_np(dq_e), _np(dq_ag)), + "dk": rel(_np(dk_e), _np(dk_ag)), + "dv": rel(_np(dv_e), _np(dv_ag)), + } + status = "PASS" + for n, e in ag_errs.items(): + if e > 0.02: + status = f"FAIL({n}={e:.4f})" + all_passed = False + print(f" [autograd] " + " ".join(f"{n}={e:.4f}" for n, e in ag_errs.items()) + + f" {status}") + + return all_passed + + +# =========================================================================== +# Orchestrator +# =========================================================================== + +def main(): + if "--gen-ref" in sys.argv: + _gen_ref(sys.argv[sys.argv.index("--gen-ref") + 1]) + return + if "--check-mlx" in sys.argv: + ok = _check_mlx(sys.argv[sys.argv.index("--check-mlx") + 1]) + sys.exit(0 if ok else 1) + + print("=" * 64) + print("flash_qla_mlx numerical correctness (PyTorch CPU reference)") + print("=" * 64) + + env = {**os.environ, "KMP_DUPLICATE_LIB_OK": "TRUE"} + + with tempfile.TemporaryDirectory() as tmp: + print("\n[1/2] PyTorch CPU reference...") + r = subprocess.run( + [sys.executable, __file__, "--gen-ref", tmp], + capture_output=True, text=True, env=env, + ) + if r.returncode != 0: + print("STDOUT:", r.stdout) + print("STDERR:", r.stderr) + sys.exit(1) + print(r.stdout.strip()) + + print("\n[2/2] flash_qla_mlx (MLX)...") + r = subprocess.run( + [sys.executable, __file__, "--check-mlx", tmp], + capture_output=True, text=True, env=env, + ) + print(r.stdout.strip()) + if r.returncode != 0: + print("STDERR:", r.stderr) + print("\nSome tests FAILED.") + sys.exit(1) + + print("\nAll tests passed.") + + +if __name__ == "__main__": + main() From 735687220271cdc96e0bc699435855b9d7e618fb Mon Sep 17 00:00:00 2001 From: "clandestine.eth" <96172957+0xClandestine@users.noreply.github.com> Date: Sun, 3 May 2026 13:33:03 -0500 Subject: [PATCH 03/11] refactor: replace _kkt_solve in-place mutation with functional accumulation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Each `x[..., i, :i] = new_val` scatter-update was triggering a full copy-on-write of the [B, N, H, C, C] array — 63 copies per forward pass. Replace with a row-by-row matmul accumulation that builds the result functionally, eliminating the redundant allocations and making the computation graph safe for mx.compile tracing. --- flash_qla_mlx/ops/gated_delta_rule/chunk.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/flash_qla_mlx/ops/gated_delta_rule/chunk.py b/flash_qla_mlx/ops/gated_delta_rule/chunk.py index b70b2f8..9ca3b30 100644 --- a/flash_qla_mlx/ops/gated_delta_rule/chunk.py +++ b/flash_qla_mlx/ops/gated_delta_rule/chunk.py @@ -95,13 +95,18 @@ def _kkt_solve( x = -pad_and_reshape(x, dim=1, chunk_size=chunk_size) # [B, N, C, H, D] x = mx.swapaxes(x, 2, 3) # [B, N, H, C, D] + # Forward substitution without in-place mutation. + # Accumulate solved rows into a growing matrix [B, N, H, i, C]. + accumulated = x[..., 0:1, :] # row 0 unchanged, [B, N, H, 1, C] for i in range(1, chunk_size): - row = x[..., i, :i] # [B, N, H, i] - sub = x[..., :i, :i] # [B, N, H, i, i] - new_val = row + (row[..., None] * sub).sum(axis=-2) - x[..., i, :i] = new_val - - x = x + mx.eye(chunk_size, dtype=x.dtype) + sub_i = accumulated[..., :i] # [B, N, H, i, i] + row_i = x[..., i:i+1, :i] # [B, N, H, 1, i] + new_val = row_i + mx.matmul(row_i, sub_i) # [B, N, H, 1, i] + pad = mx.zeros((*new_val.shape[:-1], chunk_size - i), dtype=x.dtype) + new_row = mx.concatenate([new_val, pad], axis=-1) # [B, N, H, 1, C] + accumulated = mx.concatenate([accumulated, new_row], axis=-2) + + x = accumulated + mx.eye(chunk_size, dtype=x.dtype) x = mx.swapaxes(x, 2, 3) # [B, N, C, H, D] x = x.reshape(batch_size, -1, num_heads, chunk_size)[:, :num_tokens] From e3d59ad2e609fe5499d65981aacda237552ec78f Mon Sep 17 00:00:00 2001 From: "clandestine.eth" <96172957+0xClandestine@users.noreply.github.com> Date: Sun, 3 May 2026 13:41:55 -0500 Subject: [PATCH 04/11] perf: add mx.compile for fwd/bwd when cu_seqlens=None For fixed-shape (non-variable-length) inputs, wrap fwd and bwd with mx.compile so MLX traces through all Python loops once per unique shape and unrolls them into a static Metal graph. This eliminates per-step Python dispatch overhead for both the 63-iteration _kkt_solve loop and the N-chunk _chunk_gdr_fwd/bwd loop. cu_seqlens paths fall back to the uncompiled functions because pack/unpack call .item() on tensor values during execution, which is incompatible with tracing. --- flash_qla_mlx/ops/gated_delta_rule/chunk.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/flash_qla_mlx/ops/gated_delta_rule/chunk.py b/flash_qla_mlx/ops/gated_delta_rule/chunk.py index 9ca3b30..f7d9e5b 100644 --- a/flash_qla_mlx/ops/gated_delta_rule/chunk.py +++ b/flash_qla_mlx/ops/gated_delta_rule/chunk.py @@ -718,6 +718,14 @@ def chunk_gated_delta_rule_bwd( return dq, dk, dv_out, db, dg, dh0 +# Compiled variants for the common cu_seqlens=None path. +# mx.compile traces all Python loops once at first call per shape, unrolling +# them into a static Metal graph. Falls back to uncompiled for variable-length +# (cu_seqlens) inputs because those paths call .item() during tracing. +_compiled_fwd = mx.compile(chunk_gated_delta_rule_fwd) +_compiled_bwd = mx.compile(chunk_gated_delta_rule_bwd) + + def chunk_gated_delta_rule( q: mx.array, k: mx.array, @@ -771,11 +779,13 @@ def chunk_gated_delta_rule( # We use closures to capture scale, cu_seqlens, and the presence of h0. _scale = scale _cu_seqlens = cu_seqlens + _fwd = _compiled_fwd if cu_seqlens is None else chunk_gated_delta_rule_fwd + _bwd = _compiled_bwd if cu_seqlens is None else chunk_gated_delta_rule_bwd if initial_state is not None: @mx.custom_function def _fn(q, k, v, g, beta, h0): - g_out, A, o, _, fs = chunk_gated_delta_rule_fwd( + g_out, A, o, _, fs = _fwd( q=q, k=k, v=v, g=g, beta=beta, scale=_scale, initial_state=h0, cu_seqlens=_cu_seqlens, @@ -789,7 +799,7 @@ def _fn_vjp(primals, cotangents, outputs): q, k, v, g, beta, h0 = primals do, dfs, _, _ = cotangents _, _, g_out, A = outputs - dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd( + dq, dk, dv, db, dg, dh0 = _bwd( q=q, k=k, v=v, g=g_out, beta=beta, A=A, do=do, dht=dfs, scale=_scale, initial_state=h0, cu_seqlens=_cu_seqlens, @@ -800,7 +810,7 @@ def _fn_vjp(primals, cotangents, outputs): else: @mx.custom_function def _fn(q, k, v, g, beta): - g_out, A, o, _, fs = chunk_gated_delta_rule_fwd( + g_out, A, o, _, fs = _fwd( q=q, k=k, v=v, g=g, beta=beta, scale=_scale, initial_state=None, cu_seqlens=_cu_seqlens, @@ -814,7 +824,7 @@ def _fn_vjp(primals, cotangents, outputs): q, k, v, g, beta = primals do, _, _, _ = cotangents _, _, g_out, A = outputs - dq, dk, dv, db, dg, _ = chunk_gated_delta_rule_bwd( + dq, dk, dv, db, dg, _ = _bwd( q=q, k=k, v=v, g=g_out, beta=beta, A=A, do=do, dht=None, scale=_scale, initial_state=None, cu_seqlens=_cu_seqlens, From ef5fbf0c3243da6b3c01b623091101cf6f34f29c Mon Sep 17 00:00:00 2001 From: "clandestine.eth" <96172957+0xClandestine@users.noreply.github.com> Date: Sun, 3 May 2026 14:02:41 -0500 Subject: [PATCH 05/11] perf: vectorize pack/unpack/prepare_chunk_offsets with gather ops Replace per-batch-element Python loops (.item() calls) with pure MLX gather operations: unpack uses arange+clip+fancy-index, pack uses a broadcast comparison to map each packed position to its (batch, time) source, and prepare_chunk_offsets uses cumsum instead of a Python accumulation loop. Reduces kernel dispatch count from O(B) to O(1) per call. Still requires one .item() call for the output shape (sum_T / max_len), which is an unavoidable consequence of dynamic sequence lengths. --- flash_qla_mlx/utils/pack.py | 61 ++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 32 deletions(-) diff --git a/flash_qla_mlx/utils/pack.py b/flash_qla_mlx/utils/pack.py index dd244fd..1b30f7d 100644 --- a/flash_qla_mlx/utils/pack.py +++ b/flash_qla_mlx/utils/pack.py @@ -9,37 +9,35 @@ def unpack( cu_seqlens: mx.array, ) -> mx.array: assert x.shape[0] == 1 - batch_size = cu_seqlens.shape[0] - 1 - seqlens = [ - int((cu_seqlens[i + 1] - cu_seqlens[i]).item()) for i in range(batch_size) - ] - max_len = max(seqlens) - rest = x.shape[2:] - - parts = [] - for i in range(batch_size): - start = int(cu_seqlens[i].item()) - end = int(cu_seqlens[i + 1].item()) - chunk = x[0, start:end] - if end - start < max_len: - pad_width = [(0, max_len - (end - start))] + [(0, 0)] * len(rest) - chunk = mx.pad(chunk, pad_width) - parts.append(chunk) - return mx.stack(parts, axis=0) # [B, max_len, *dims] + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] # [B] + max_len = int(mx.max(seqlens).item()) + + t = mx.arange(max_len, dtype=mx.int32) # [max_len] + src = mx.expand_dims(cu_seqlens[:-1], 1) + t # [B, max_len] + src = mx.clip(src, 0, x.shape[1] - 1) + + out = x[0][src] # [B, max_len, *dims] + + valid = t < mx.expand_dims(seqlens, 1) # [B, max_len] + for _ in x.shape[2:]: + valid = mx.expand_dims(valid, -1) + return mx.where(valid, out, mx.zeros_like(out)) def pack( x: mx.array, # [B, max_T, *dims] cu_seqlens: mx.array, ) -> mx.array: - batch_size = cu_seqlens.shape[0] - 1 - parts = [] - for i in range(batch_size): - start = int(cu_seqlens[i].item()) - end = int(cu_seqlens[i + 1].item()) - parts.append(x[i, : end - start]) - packed = mx.concatenate(parts, axis=0) # [sum_T, *dims] - return packed[None] # [1, sum_T, *dims] + sum_T = int(cu_seqlens[-1].item()) + + i = mx.arange(sum_T, dtype=mx.int32) # [sum_T] + # b_idx[i] = batch element that packed position i belongs to + b_idx = ( + mx.expand_dims(i, 1) >= mx.expand_dims(cu_seqlens[1:], 0) + ).sum(axis=1).astype(mx.int32) # [sum_T] + t_idx = (i - cu_seqlens[b_idx]).astype(mx.int32) # [sum_T] + + return x[b_idx, t_idx][None] # [1, sum_T, *dims] def pad_and_reshape( @@ -138,10 +136,9 @@ def fill_last_chunk_of_g( def prepare_chunk_offsets(cu_seqlens: mx.array, chunk_size: int) -> mx.array: - batch_size = cu_seqlens.shape[0] - 1 - offsets = [0] - for i in range(batch_size): - seqlen = int((cu_seqlens[i + 1] - cu_seqlens[i]).item()) - n_chunks = (seqlen + chunk_size - 1) // chunk_size - offsets.append(offsets[-1] + n_chunks) - return mx.array(offsets, dtype=mx.int32) + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] # [B] + n_chunks = (seqlens + chunk_size - 1) // chunk_size # ceiling division + return mx.concatenate([ + mx.array([0], dtype=mx.int32), + mx.cumsum(n_chunks).astype(mx.int32), + ]) From 274d445876f6a654490c91f83bbdba80c2db08eb Mon Sep 17 00:00:00 2001 From: "clandestine.eth" <96172957+0xClandestine@users.noreply.github.com> Date: Sun, 3 May 2026 14:10:16 -0500 Subject: [PATCH 06/11] feat: expose chunk_size in public API, fix chunk_size bugs, add MLX benchmark MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit chunk_size is now a parameter of chunk_gated_delta_rule, chunk_gated_delta_rule_fwd, and chunk_gated_delta_rule_bwd (default 64). This enables hardware-specific tuning — Apple Silicon benchmarks show cs=32 is 13-27% faster for Qwen-27B forward (T>=1024) while cs=64 remains better for backward/training. Also fixes two pre-existing bugs where pad_and_reshape(A, dim=1) in _w_u_fwd and _chunk_wy_bwd omitted the chunk_size argument, causing silent shape mismatches for any chunk_size other than 64. benchmark/bench_mlx.py sweeps sequence length and chunk_size for Apple Silicon; run with --B 1 --Hk 8 --Hv 32 --K 128 --V 128 for Qwen-27B dimensions. --- benchmark/bench_mlx.py | 155 ++++++++++++++++++++ flash_qla_mlx/ops/gated_delta_rule/chunk.py | 19 ++- 2 files changed, 168 insertions(+), 6 deletions(-) create mode 100644 benchmark/bench_mlx.py diff --git a/benchmark/bench_mlx.py b/benchmark/bench_mlx.py new file mode 100644 index 0000000..52e7820 --- /dev/null +++ b/benchmark/bench_mlx.py @@ -0,0 +1,155 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] +# +# Apple Silicon benchmark for flash_qla_mlx. +# Measures forward and backward throughput across sequence lengths and +# chunk sizes to find the optimal chunk_size for your hardware. +# +# Usage: +# python benchmark/bench_mlx.py # default sweep +# python benchmark/bench_mlx.py --fwd-only # skip backward +# python benchmark/bench_mlx.py --T 512 1024 2048 +# python benchmark/bench_mlx.py --chunk-sizes 32 64 128 + +import argparse +import time + +import mlx.core as mx + +from flash_qla_mlx import chunk_gated_delta_rule + + +# --------------------------------------------------------------------------- +# Timing helpers +# --------------------------------------------------------------------------- + +def _warmup_and_time(fn, *args, warmup: int = 3, iters: int = 10) -> float: + """Return median wall-clock time in ms over `iters` runs.""" + for _ in range(warmup): + out = fn(*args) + mx.eval(out) + + times = [] + for _ in range(iters): + t0 = time.perf_counter() + out = fn(*args) + mx.eval(out) + times.append((time.perf_counter() - t0) * 1000) + + times.sort() + return times[len(times) // 2] # median + + +# --------------------------------------------------------------------------- +# Benchmark runners +# --------------------------------------------------------------------------- + +def bench_fwd(B, T, Hk, Hv, K, V, chunk_size, dtype): + mx.random.seed(0) + q = mx.random.normal((B, T, Hk, K)).astype(dtype) + k = mx.random.normal((B, T, Hk, K)).astype(dtype) + v = mx.random.normal((B, T, Hv, V)).astype(dtype) + g = -mx.abs(mx.random.normal((B, T, Hv))).astype(dtype) + beta = mx.sigmoid(mx.random.normal((B, T, Hv))).astype(dtype) + + def run(): + o, _ = chunk_gated_delta_rule(q, k, v, g, beta, chunk_size=chunk_size) + return o + + ms = _warmup_and_time(run) + toks = B * T + return ms, toks / ms * 1e3 # tokens/s + + +def bench_bwd(B, T, Hk, Hv, K, V, chunk_size, dtype): + mx.random.seed(0) + q = mx.random.normal((B, T, Hk, K)).astype(dtype) + k = mx.random.normal((B, T, Hk, K)).astype(dtype) + v = mx.random.normal((B, T, Hv, V)).astype(dtype) + g = -mx.abs(mx.random.normal((B, T, Hv))).astype(dtype) + beta = mx.sigmoid(mx.random.normal((B, T, Hv))).astype(dtype) + + def run(): + loss, _ = chunk_gated_delta_rule(q, k, v, g, beta, chunk_size=chunk_size) + grad = mx.grad( + lambda q, k, v, g, beta: chunk_gated_delta_rule( + q, k, v, g, beta, chunk_size=chunk_size + )[0].sum() + )(q, k, v, g, beta) + return grad + + ms = _warmup_and_time(run) + toks = B * T + return ms, toks / ms * 1e3 + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser(description="flash_qla_mlx Apple Silicon benchmark") + parser.add_argument("--B", type=int, default=2, help="Batch size") + parser.add_argument("--Hk", type=int, default=4, help="Key/query heads") + parser.add_argument("--Hv", type=int, default=4, help="Value heads") + parser.add_argument("--K", type=int, default=64, help="Key/query head dim") + parser.add_argument("--V", type=int, default=64, help="Value head dim") + parser.add_argument( + "--T", type=int, nargs="+", default=[256, 512, 1024, 2048, 4096], + help="Sequence lengths to benchmark", + ) + parser.add_argument( + "--chunk-sizes", type=int, nargs="+", default=[32, 64, 128], + help="chunk_size values to sweep", + ) + parser.add_argument("--fwd-only", action="store_true", help="Skip backward pass") + parser.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], + default="float32") + args = parser.parse_args() + + dtype_map = {"float32": mx.float32, "float16": mx.float16, "bfloat16": mx.bfloat16} + dtype = dtype_map[args.dtype] + + passes = ["fwd"] if args.fwd_only else ["fwd", "bwd"] + bench_fns = {"fwd": bench_fwd, "bwd": bench_bwd} + + for pass_name in passes: + bench = bench_fns[pass_name] + col_w = 12 + header = f"{'T':>6} " + " ".join( + f"cs={cs:>3} ms Ktok/s" for cs in args.chunk_sizes + ) + sep = "-" * len(header) + + print(f"\n{pass_name.upper()} pass " + f"(B={args.B} Hk={args.Hk} Hv={args.Hv} " + f"K={args.K} V={args.V} dtype={args.dtype})") + print(header) + print(sep) + + for T in args.T: + if T % max(args.chunk_sizes) != 0: + # pad check: skip sizes where T < chunk_size + valid_cs = [cs for cs in args.chunk_sizes if cs <= T] + else: + valid_cs = args.chunk_sizes + + row = f"{T:>6} " + for cs in args.chunk_sizes: + if cs > T: + row += f"{'--':>8} {'--':>6} " + continue + try: + ms, tps = bench( + args.B, T, args.Hk, args.Hv, args.K, args.V, cs, dtype + ) + row += f"{ms:>8.2f} {tps/1000:>6.1f} " + except Exception as e: + row += f"{'ERR':>8} {'':>6} " + print(row) + + print() + + +if __name__ == "__main__": + main() diff --git a/flash_qla_mlx/ops/gated_delta_rule/chunk.py b/flash_qla_mlx/ops/gated_delta_rule/chunk.py index f7d9e5b..473338e 100644 --- a/flash_qla_mlx/ops/gated_delta_rule/chunk.py +++ b/flash_qla_mlx/ops/gated_delta_rule/chunk.py @@ -156,7 +156,7 @@ def _w_u_fwd( v_beta = pad_and_reshape( v * beta[..., None], dim=1, chunk_size=chunk_size ) # [B, N, C, Hv, V] - A = pad_and_reshape(A, dim=1) # [B, N, C, Hv, D] + A = pad_and_reshape(A, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, D] w = mx.einsum("bnchd,bndhk->bnchk", A, k_beta).reshape( batch_size, -1, num_v_heads, head_dim_k @@ -532,7 +532,7 @@ def _chunk_wy_bwd( k = pad_and_reshape(k, dim=1, chunk_size=chunk_size_A) v = pad_and_reshape(v, dim=1, chunk_size=chunk_size_A) beta = pad_and_reshape(beta, dim=1, chunk_size=chunk_size_A) - A = pad_and_reshape(A, dim=1) + A = pad_and_reshape(A, dim=1, chunk_size=chunk_size_A) g = pad_and_reshape(g, dim=1, chunk_size=chunk_size_A) dw = pad_and_reshape(dw, dim=1, chunk_size=chunk_size_A) du = pad_and_reshape(du, dim=1, chunk_size=chunk_size_A) @@ -608,6 +608,7 @@ def chunk_gated_delta_rule_fwd( cu_seqlens: mx.array = None, output_final_state: bool = True, output_h: bool = False, + chunk_size: int = 64, ) -> tuple: """ Forward pass of the Gated Delta Rule. @@ -623,6 +624,7 @@ def chunk_gated_delta_rule_fwd( cu_seqlens: Optional [S+1] int32, for variable-length inputs output_final_state: Whether to return the final state output_h: Whether to return the chunk-level states h + chunk_size: Tokens per chunk (default 64; tune for your hardware) Returns: (g_cumsum, A, o, h, final_state) @@ -631,8 +633,6 @@ def chunk_gated_delta_rule_fwd( if scale is None: scale = k.shape[-1] ** -0.5 - chunk_size = 64 - g_cumsum = _chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens) A = _kkt(k=k, beta=beta, g=g_cumsum, cu_seqlens=cu_seqlens, chunk_size=chunk_size) w, u = _w_u_fwd(k=k, v=v, beta=beta, A=A, g=g_cumsum, cu_seqlens=cu_seqlens) @@ -667,6 +667,7 @@ def chunk_gated_delta_rule_bwd( scale: float = None, initial_state: mx.array = None, cu_seqlens: mx.array = None, + chunk_size: int = 64, ) -> tuple: """ Backward pass of the Gated Delta Rule. @@ -676,6 +677,7 @@ def chunk_gated_delta_rule_bwd( do: Gradient of the output [B, T, Hv, V] dht: Gradient of the final state [B, Hv, K, V] or None scale, initial_state, cu_seqlens: same as forward + chunk_size: Must match the value used in the forward pass Returns: (dq, dk, dv, db, dg, dh0) @@ -683,8 +685,6 @@ def chunk_gated_delta_rule_bwd( if scale is None: scale = k.shape[-1] ** -0.5 - chunk_size = 64 - w, u = _w_u_fwd(k=k, v=v, beta=beta, A=A, g=g, cu_seqlens=cu_seqlens) h, vn, _ = _chunk_gdr_fwd( k=k, w=w, u=u, g=g, @@ -738,6 +738,7 @@ def chunk_gated_delta_rule( use_qk_l2norm_in_kernel: bool = False, cu_seqlens: mx.array = None, head_first: bool = False, + chunk_size: int = 64, ) -> tuple: """ Gated Delta Rule: end-to-end forward with MLX-native gradient support. @@ -754,6 +755,7 @@ def chunk_gated_delta_rule( use_qk_l2norm_in_kernel: L2-normalize q and k before computation cu_seqlens: Optional variable-length sequence boundaries head_first: Not supported (must be False) + chunk_size: Tokens per chunk (default 64; tune for your hardware) Returns: (o, final_state) — final_state is None if output_final_state=False @@ -779,6 +781,7 @@ def chunk_gated_delta_rule( # We use closures to capture scale, cu_seqlens, and the presence of h0. _scale = scale _cu_seqlens = cu_seqlens + _chunk_size = chunk_size _fwd = _compiled_fwd if cu_seqlens is None else chunk_gated_delta_rule_fwd _bwd = _compiled_bwd if cu_seqlens is None else chunk_gated_delta_rule_bwd @@ -790,6 +793,7 @@ def _fn(q, k, v, g, beta, h0): scale=_scale, initial_state=h0, cu_seqlens=_cu_seqlens, output_final_state=True, output_h=False, + chunk_size=_chunk_size, ) fs_safe = fs if fs is not None else mx.zeros_like(h0) return o, fs_safe, g_out, A @@ -803,6 +807,7 @@ def _fn_vjp(primals, cotangents, outputs): q=q, k=k, v=v, g=g_out, beta=beta, A=A, do=do, dht=dfs, scale=_scale, initial_state=h0, cu_seqlens=_cu_seqlens, + chunk_size=_chunk_size, ) return dq, dk, dv, dg, db, dh0 @@ -815,6 +820,7 @@ def _fn(q, k, v, g, beta): scale=_scale, initial_state=None, cu_seqlens=_cu_seqlens, output_final_state=output_final_state, output_h=False, + chunk_size=_chunk_size, ) _dummy = mx.zeros((1,), dtype=q.dtype) return o, _dummy, g_out, A @@ -828,6 +834,7 @@ def _fn_vjp(primals, cotangents, outputs): q=q, k=k, v=v, g=g_out, beta=beta, A=A, do=do, dht=None, scale=_scale, initial_state=None, cu_seqlens=_cu_seqlens, + chunk_size=_chunk_size, ) return dq, dk, dv, dg, db From 1b1544719b0c4f46fee0550fe09d562614dc78df Mon Sep 17 00:00:00 2001 From: "clandestine.eth" <96172957+0xClandestine@users.noreply.github.com> Date: Sun, 3 May 2026 14:11:17 -0500 Subject: [PATCH 07/11] docs: document why parallel scan is not beneficial for inter-chunk recurrence The recurrence h_i = A_i @ h_{i-1} + b_i supports an associative parallel prefix scan, but composing operators costs O(K^3) vs O(K^2) sequential. For K=128 (Qwen-27B) this is 512-1280x more FLOPs across practical sequence lengths. Sequential with mx.compile unrolling is the correct algorithm. --- flash_qla_mlx/ops/gated_delta_rule/chunk.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/flash_qla_mlx/ops/gated_delta_rule/chunk.py b/flash_qla_mlx/ops/gated_delta_rule/chunk.py index 473338e..a2876b2 100644 --- a/flash_qla_mlx/ops/gated_delta_rule/chunk.py +++ b/flash_qla_mlx/ops/gated_delta_rule/chunk.py @@ -205,6 +205,12 @@ def _chunk_gdr_fwd( else: last_state = initial_state.astype(g.dtype) + # Inter-chunk recurrence: h_i = A_i @ h_{i-1} + b_i (linear recurrence). + # A parallel prefix scan is theoretically possible but costs O(K^3) per + # operator composition vs O(K^2) sequential, giving 512-1280x more FLOPs + # for typical head dims (K=64-128) and sequence lengths. Sequential is + # the correct algorithm here. With mx.compile, this loop is unrolled into + # a static Metal graph so Python dispatch overhead is paid only once. h_list, vn_list = [], [] for i in range(k.shape[1]): h_list.append(last_state) From 320b8a70db3549c43e5f6a2fdcd3450c8718fe21 Mon Sep 17 00:00:00 2001 From: "clandestine.eth" <96172957+0xClandestine@users.noreply.github.com> Date: Sun, 3 May 2026 14:14:44 -0500 Subject: [PATCH 08/11] feat: add SDPA comparison to MLX benchmark --compare-sdpa flag adds a side-by-side table of FlashQLA vs mx.fast.scaled_dot_product_attention (GQA causal) at Qwen-27B dimensions. Dimension mapping: FlashQLA Hk->N_kv, Hv->N_q so both produce the same (B, T, Hv, V) output volume. --- benchmark/bench_mlx.py | 289 ++++++++++++++++++++++++++++------------- 1 file changed, 197 insertions(+), 92 deletions(-) diff --git a/benchmark/bench_mlx.py b/benchmark/bench_mlx.py index 52e7820..50e06b4 100644 --- a/benchmark/bench_mlx.py +++ b/benchmark/bench_mlx.py @@ -2,16 +2,23 @@ # Licensed under The MIT License [see LICENSE for details] # # Apple Silicon benchmark for flash_qla_mlx. -# Measures forward and backward throughput across sequence lengths and -# chunk sizes to find the optimal chunk_size for your hardware. +# Measures forward/backward throughput across sequence lengths and chunk sizes, +# and optionally compares against mx.fast.scaled_dot_product_attention. # # Usage: -# python benchmark/bench_mlx.py # default sweep -# python benchmark/bench_mlx.py --fwd-only # skip backward -# python benchmark/bench_mlx.py --T 512 1024 2048 +# # chunk-size sweep (default) +# python benchmark/bench_mlx.py --B 1 --Hk 8 --Hv 32 --K 128 --V 128 +# +# # Qwen-27B forward-only comparison vs softmax attention +# python benchmark/bench_mlx.py --B 1 --Hk 8 --Hv 32 --K 128 --V 128 \ +# --fwd-only --compare-sdpa +# +# # Short options +# python benchmark/bench_mlx.py --fwd-only --T 512 1024 2048 4096 8192 # python benchmark/bench_mlx.py --chunk-sizes 32 64 128 import argparse +import math import time import mlx.core as mx @@ -20,67 +27,184 @@ # --------------------------------------------------------------------------- -# Timing helpers +# Timing helper # --------------------------------------------------------------------------- -def _warmup_and_time(fn, *args, warmup: int = 3, iters: int = 10) -> float: - """Return median wall-clock time in ms over `iters` runs.""" +def _time(fn, warmup: int = 3, iters: int = 10) -> float: + """Return median wall-clock time in ms.""" for _ in range(warmup): - out = fn(*args) - mx.eval(out) - + mx.eval(fn()) times = [] for _ in range(iters): t0 = time.perf_counter() - out = fn(*args) - mx.eval(out) + mx.eval(fn()) times.append((time.perf_counter() - t0) * 1000) - times.sort() - return times[len(times) // 2] # median + return times[len(times) // 2] # --------------------------------------------------------------------------- -# Benchmark runners +# FlashQLA runners # --------------------------------------------------------------------------- -def bench_fwd(B, T, Hk, Hv, K, V, chunk_size, dtype): +def _qla_inputs(B, T, Hk, Hv, K, V, dtype): mx.random.seed(0) - q = mx.random.normal((B, T, Hk, K)).astype(dtype) - k = mx.random.normal((B, T, Hk, K)).astype(dtype) - v = mx.random.normal((B, T, Hv, V)).astype(dtype) - g = -mx.abs(mx.random.normal((B, T, Hv))).astype(dtype) - beta = mx.sigmoid(mx.random.normal((B, T, Hv))).astype(dtype) + return ( + mx.random.normal((B, T, Hk, K)).astype(dtype), # q + mx.random.normal((B, T, Hk, K)).astype(dtype), # k + mx.random.normal((B, T, Hv, V)).astype(dtype), # v + -mx.abs(mx.random.normal((B, T, Hv))).astype(dtype), # g + mx.sigmoid(mx.random.normal((B, T, Hv))).astype(dtype), # beta + ) - def run(): - o, _ = chunk_gated_delta_rule(q, k, v, g, beta, chunk_size=chunk_size) - return o - ms = _warmup_and_time(run) - toks = B * T - return ms, toks / ms * 1e3 # tokens/s +def bench_qla_fwd(B, T, Hk, Hv, K, V, chunk_size, dtype): + q, k, v, g, beta = _qla_inputs(B, T, Hk, Hv, K, V, dtype) + ms = _time(lambda: chunk_gated_delta_rule(q, k, v, g, beta, + chunk_size=chunk_size)[0]) + return ms, B * T / ms * 1e3 -def bench_bwd(B, T, Hk, Hv, K, V, chunk_size, dtype): - mx.random.seed(0) - q = mx.random.normal((B, T, Hk, K)).astype(dtype) - k = mx.random.normal((B, T, Hk, K)).astype(dtype) - v = mx.random.normal((B, T, Hv, V)).astype(dtype) - g = -mx.abs(mx.random.normal((B, T, Hv))).astype(dtype) - beta = mx.sigmoid(mx.random.normal((B, T, Hv))).astype(dtype) +def bench_qla_bwd(B, T, Hk, Hv, K, V, chunk_size, dtype): + q, k, v, g, beta = _qla_inputs(B, T, Hk, Hv, K, V, dtype) def run(): - loss, _ = chunk_gated_delta_rule(q, k, v, g, beta, chunk_size=chunk_size) - grad = mx.grad( + return mx.grad( lambda q, k, v, g, beta: chunk_gated_delta_rule( q, k, v, g, beta, chunk_size=chunk_size )[0].sum() )(q, k, v, g, beta) - return grad - ms = _warmup_and_time(run) - toks = B * T - return ms, toks / ms * 1e3 + ms = _time(run) + return ms, B * T / ms * 1e3 + + +# --------------------------------------------------------------------------- +# SDPA runner +# +# Dimension mapping from FlashQLA to GQA softmax attention (Qwen-27B style): +# FlashQLA Hk -> SDPA N_kv (KV heads, e.g. 8) +# FlashQLA Hv -> SDPA N_q (query heads, e.g. 32) +# FlashQLA K -> SDPA D (head dim, e.g. 128) +# +# mx.fast.scaled_dot_product_attention expects [B, H, T, D] layout. +# Causal mask is used for a fair comparison (FlashQLA is always causal). +# --------------------------------------------------------------------------- + +def bench_sdpa_fwd(B, T, Hk, Hv, K, V, dtype): + """ + Benchmark mx.fast.scaled_dot_product_attention with GQA: + q : [B, Hv, T, K] — Hv query heads + k : [B, Hk, T, K] — Hk KV heads + v : [B, Hk, T, V] — Hk KV heads + This matches the output dimensionality of FlashQLA (B, T, Hv, V). + """ + mx.random.seed(0) + scale = K ** -0.5 + q = mx.random.normal((B, Hv, T, K)).astype(dtype) + k = mx.random.normal((B, Hk, T, K)).astype(dtype) + v = mx.random.normal((B, Hk, T, V)).astype(dtype) + + ms = _time(lambda: mx.fast.scaled_dot_product_attention( + q, k, v, scale=scale, mask="causal" + )) + return ms, B * T / ms * 1e3 + + +# --------------------------------------------------------------------------- +# Table helpers +# --------------------------------------------------------------------------- + +def _print_qla_table(pass_name, args, dtype): + bench = bench_qla_fwd if pass_name == "fwd" else bench_qla_bwd + header = f"{'T':>6} " + " ".join( + f"cs={cs:>3} ms Ktok/s" for cs in args.chunk_sizes + ) + print(f"\nFlashQLA {pass_name.upper()} " + f"(B={args.B} Hk={args.Hk} Hv={args.Hv} " + f"K={args.K} V={args.V} {args.dtype})") + print(header) + print("-" * len(header)) + for T in args.T: + row = f"{T:>6} " + for cs in args.chunk_sizes: + if cs > T: + row += f"{'--':>10} {'--':>6} " + continue + try: + ms, tps = bench(args.B, T, args.Hk, args.Hv, + args.K, args.V, cs, dtype) + row += f"{ms:>10.2f} {tps/1000:>6.1f} " + except Exception as e: + row += f"{'ERR':>10} {'ERR':>6} " + print(row) + + +def _print_compare_table(args, dtype): + """ + Side-by-side: best FlashQLA chunk_size vs mx.fast.scaled_dot_product_attention. + Runs forward pass only (SDPA has no exposed backward in MLX fast path). + """ + # Pick best chunk_size from a quick pre-sweep + print(f"\nPre-sweeping chunk sizes to find best for T={args.T[0]}...") + best_cs, best_ms = args.chunk_sizes[0], float("inf") + for cs in args.chunk_sizes: + if cs > args.T[0]: + continue + try: + ms, _ = bench_qla_fwd(args.B, args.T[0], args.Hk, args.Hv, + args.K, args.V, cs, dtype) + if ms < best_ms: + best_ms, best_cs = ms, cs + except Exception: + pass + print(f" Best chunk_size = {best_cs}") + + # GQA layout note + print(f"\n SDPA shape: q=[B,{args.Hv},T,{args.K}] " + f"k=v=[B,{args.Hk},T,{args.K}] causal (GQA {args.Hv//args.Hk}:1)") + print(f" FlashQLA: Hk={args.Hk} Hv={args.Hv} K={args.K} V={args.V} " + f"chunk_size={best_cs}") + + col = 14 + header = (f"{'T':>6} {'FlashQLA':>{col}} {'SDPA':>{col}} " + f"{'QLA ms':>8} {'SDPA ms':>8} {'speedup':>8}") + print() + print(header) + print("-" * len(header)) + + for T in args.T: + if best_cs > T: + print(f"{T:>6} (T < chunk_size, skipped)") + continue + + try: + qla_ms, qla_tps = bench_qla_fwd(args.B, T, args.Hk, args.Hv, + args.K, args.V, best_cs, dtype) + except Exception as e: + print(f"{T:>6} QLA ERR: {e}") + continue + + try: + sdpa_ms, sdpa_tps = bench_sdpa_fwd(args.B, T, args.Hk, args.Hv, + args.K, args.V, dtype) + speedup = sdpa_ms / qla_ms + sdpa_str = f"{sdpa_tps/1000:>10.1f} Ktok/s" + sdpa_ms_str = f"{sdpa_ms:>8.2f}" + speedup_str = f"{speedup:>7.2f}x" + except Exception as e: + sdpa_str = f"{'OOM/ERR':>{col}}" + sdpa_ms_str = f"{'--':>8}" + speedup_str = f"{'--':>8}" + + qla_str = f"{qla_tps/1000:>10.1f} Ktok/s" + print(f"{T:>6} {qla_str:>{col}} {sdpa_str:>{col}} " + f"{qla_ms:>8.2f} {sdpa_ms_str} {speedup_str}") + + print() + print(" speedup > 1 means FlashQLA is faster than SDPA") + print(" Note: FlashQLA is O(T) memory; SDPA is O(T²). " + "They are architecturally different primitives.") # --------------------------------------------------------------------------- @@ -88,67 +212,48 @@ def run(): # --------------------------------------------------------------------------- def main(): - parser = argparse.ArgumentParser(description="flash_qla_mlx Apple Silicon benchmark") - parser.add_argument("--B", type=int, default=2, help="Batch size") - parser.add_argument("--Hk", type=int, default=4, help="Key/query heads") - parser.add_argument("--Hv", type=int, default=4, help="Value heads") - parser.add_argument("--K", type=int, default=64, help="Key/query head dim") - parser.add_argument("--V", type=int, default=64, help="Value head dim") + parser = argparse.ArgumentParser( + description="flash_qla_mlx Apple Silicon benchmark", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Qwen-27B forward comparison: + python benchmark/bench_mlx.py \\ + --B 1 --Hk 8 --Hv 32 --K 128 --V 128 \\ + --fwd-only --compare-sdpa \\ + --T 512 1024 2048 4096 8192 16384 +""", + ) + parser.add_argument("--B", type=int, default=1) + parser.add_argument("--Hk", type=int, default=8, help="KV heads (FlashQLA k heads)") + parser.add_argument("--Hv", type=int, default=32, help="Q/output heads (FlashQLA v heads)") + parser.add_argument("--K", type=int, default=128, help="Head dim (key/query)") + parser.add_argument("--V", type=int, default=128, help="Head dim (value)") parser.add_argument( - "--T", type=int, nargs="+", default=[256, 512, 1024, 2048, 4096], - help="Sequence lengths to benchmark", + "--T", type=int, nargs="+", + default=[512, 1024, 2048, 4096, 8192], ) parser.add_argument( - "--chunk-sizes", type=int, nargs="+", default=[32, 64, 128], - help="chunk_size values to sweep", + "--chunk-sizes", type=int, nargs="+", default=[32, 64], + help="chunk_size values to sweep in the QLA table", ) - parser.add_argument("--fwd-only", action="store_true", help="Skip backward pass") - parser.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], + parser.add_argument("--fwd-only", action="store_true") + parser.add_argument("--compare-sdpa", action="store_true", + help="Add SDPA vs FlashQLA comparison table") + parser.add_argument("--dtype", + choices=["float32", "float16", "bfloat16"], default="float32") args = parser.parse_args() - dtype_map = {"float32": mx.float32, "float16": mx.float16, "bfloat16": mx.bfloat16} - dtype = dtype_map[args.dtype] + dtype = {"float32": mx.float32, + "float16": mx.float16, + "bfloat16": mx.bfloat16}[args.dtype] passes = ["fwd"] if args.fwd_only else ["fwd", "bwd"] - bench_fns = {"fwd": bench_fwd, "bwd": bench_bwd} - - for pass_name in passes: - bench = bench_fns[pass_name] - col_w = 12 - header = f"{'T':>6} " + " ".join( - f"cs={cs:>3} ms Ktok/s" for cs in args.chunk_sizes - ) - sep = "-" * len(header) - - print(f"\n{pass_name.upper()} pass " - f"(B={args.B} Hk={args.Hk} Hv={args.Hv} " - f"K={args.K} V={args.V} dtype={args.dtype})") - print(header) - print(sep) - - for T in args.T: - if T % max(args.chunk_sizes) != 0: - # pad check: skip sizes where T < chunk_size - valid_cs = [cs for cs in args.chunk_sizes if cs <= T] - else: - valid_cs = args.chunk_sizes - - row = f"{T:>6} " - for cs in args.chunk_sizes: - if cs > T: - row += f"{'--':>8} {'--':>6} " - continue - try: - ms, tps = bench( - args.B, T, args.Hk, args.Hv, args.K, args.V, cs, dtype - ) - row += f"{ms:>8.2f} {tps/1000:>6.1f} " - except Exception as e: - row += f"{'ERR':>8} {'':>6} " - print(row) + for p in passes: + _print_qla_table(p, args, dtype) - print() + if args.compare_sdpa: + _print_compare_table(args, dtype) if __name__ == "__main__": From b374359aeeaf32cd368e7ea2732cace9d0acdd2e Mon Sep 17 00:00:00 2001 From: "clandestine.eth" <96172957+0xClandestine@users.noreply.github.com> Date: Sun, 3 May 2026 14:39:15 -0500 Subject: [PATCH 09/11] perf: reduce kkt_fwd einsum FLOPs for GQA by computing gram at Hk heads MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit For GQA configurations (Hk < Hv), the original implementation expands k from Hk to Hv heads before the einsum, computing Hv independent gram products when only Hk are distinct. Factoring out the shared k·k gram matrix (computed at Hk heads) and combining with the Hv-dimensional beta/decay via reshape+broadcast avoids materialising the gqa_ratio-x expanded k tensor and reduces einsum FLOPs by gqa_ratio. For Qwen-27B dims (Hk=8, Hv=32, K=128) this gives a 4× cheaper kkt_fwd einsum, yielding 5-9% forward throughput improvement at T≥1024. --- flash_qla_mlx/ops/gated_delta_rule/chunk.py | 31 +++++++++++++-------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/flash_qla_mlx/ops/gated_delta_rule/chunk.py b/flash_qla_mlx/ops/gated_delta_rule/chunk.py index a2876b2..02944b8 100644 --- a/flash_qla_mlx/ops/gated_delta_rule/chunk.py +++ b/flash_qla_mlx/ops/gated_delta_rule/chunk.py @@ -58,22 +58,31 @@ def _kkt_fwd( batch_size, num_tokens, num_k_heads, head_dim = k.shape num_v_heads = g.shape[-1] + gqa_ratio = num_v_heads // num_k_heads - if num_k_heads != num_v_heads: - k = mx.repeat(k, num_v_heads // num_k_heads, axis=2) - - k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, H, K] - g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, H] - beta = pad_and_reshape(beta, dim=1, chunk_size=chunk_size) # [B, N, C, H] + k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, Hk, K] + g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, Hv] + beta = pad_and_reshape(beta, dim=1, chunk_size=chunk_size) # [B, N, C, Hv] mask = mx.triu(mx.ones((chunk_size, chunk_size), dtype=mx.bool_), k=0) - decay_mask = mx.exp(g[:, :, :, None, :] - g[:, :, None, :, :]) # [B, N, C, C, H] + decay_mask = mx.exp(g[:, :, :, None, :] - g[:, :, None, :, :]) # [B, N, C, C, Hv] decay_mask = mx.where(mask[None, None, :, :, None], mx.zeros_like(decay_mask), decay_mask) - # attn[b, n, c, h, d] = (k_beta[b,n,c,h,:] . k[b,n,d,h,:]) * decay[b,n,c,d,h] - attn = mx.einsum( - "bnchk,bndhk->bnchd", k * beta[:, :, :, :, None], k - ) * mx.swapaxes(decay_mask, -2, -1) + # Compute gram matrix using Hk heads only, then combine with Hv-dimensional + # beta/decay via a reshape+broadcast (zero-copy) to avoid expanding k by gqa_ratio. + # attn[c,h,d] = beta[c,h] * (k[c,g,:] . k[d,g,:]) * decay[c,d,h] g = h // gqa_ratio + gram = mx.einsum("bnchk,bndhk->bnchd", k, k) # [B, N, C, Hk, D] + if gqa_ratio > 1: + B, N = gram.shape[:2] + # beta_decay: [B, N, C, Hk, gqa_ratio, D] (no copy of gram needed) + beta_decay = ( + beta[:, :, :, :, None] * mx.swapaxes(decay_mask, -2, -1) + ).reshape(B, N, chunk_size, num_k_heads, gqa_ratio, chunk_size) + attn = (gram[:, :, :, :, None, :] * beta_decay).reshape( + B, N, chunk_size, num_v_heads, chunk_size + ) + else: + attn = gram * beta[:, :, :, :, None] * mx.swapaxes(decay_mask, -2, -1) attn = attn.reshape(batch_size, -1, num_v_heads, chunk_size)[:, :num_tokens] if cu_seqlens is not None: From 03b0976ba3215a1c4940c3a8ca7a95a5e072867d Mon Sep 17 00:00:00 2001 From: "clandestine.eth" <96172957+0xClandestine@users.noreply.github.com> Date: Sun, 3 May 2026 15:21:15 -0500 Subject: [PATCH 10/11] perf: replace _kkt_solve Python loop with Metal kernel (2.2x speedup) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the 63-iteration Python forward-substitution loop in _kkt_solve with a single mx.fast.metal_kernel dispatch. Each threadgroup handles one (B,N,H) matrix; C parallel threads each solve one column of (I+L)^{-1} independently, so no threadgroup_barrier is required. Also fixes two bugs found during development: - grid=(batch_total * C) not (batch_total): grid is total threads, not threadgroups, so number of threadgroups = grid.x / threadgroup.x - mx.contiguous() required after swapaxes before reshape: non-contiguous dims cannot be merged without an explicit copy End-to-end improvement (B=1 Hk=8 Hv=32 K=128 V=128 cs=64): T=2048: 253→335 Ktok/s fwd (+32%) T=8192: 241→331 Ktok/s fwd (+37%) T=16384: 236→339 Ktok/s fwd (+44%) --- flash_qla_mlx/ops/gated_delta_rule/chunk.py | 88 +++++++++++++++++---- 1 file changed, 71 insertions(+), 17 deletions(-) diff --git a/flash_qla_mlx/ops/gated_delta_rule/chunk.py b/flash_qla_mlx/ops/gated_delta_rule/chunk.py index 02944b8..e91a7c7 100644 --- a/flash_qla_mlx/ops/gated_delta_rule/chunk.py +++ b/flash_qla_mlx/ops/gated_delta_rule/chunk.py @@ -90,6 +90,54 @@ def _kkt_fwd( return attn +# Module-level cache for Metal kkt_solve kernels, keyed by chunk_size. +# Each kernel solves (I - x_in)^{-1} for a batch of C×C lower-triangular +# matrices in a single GPU dispatch: one thread per (batch, column) pair, +# computing that column of the inverse via forward substitution. +_KKT_SOLVE_KERNELS: dict = {} + + +def _get_kkt_solve_kernel(chunk_size: int): + if chunk_size not in _KKT_SOLVE_KERNELS: + C = chunk_size + # Each threadgroup handles one (batch, n, h) element; C threads handle + # C columns in parallel. Columns are independent so no barrier is needed. + # Threadgroup memory avoids the register-spill bug seen with T x_col[C] + # at C=64 in Metal's shader compiler. + source = f""" + uint b = threadgroup_position_in_grid.x; + uint j = thread_position_in_threadgroup.x; + + // Column j of (I - x_in)^{{-1}} via forward substitution. + // x_in holds strictly-lower-triangular entries -L (negative values). + // X[i, j] = delta(i==j) + sum_{{d [B, N, H, C, D] (negated, lower-tri solve) - x = -pad_and_reshape(x, dim=1, chunk_size=chunk_size) # [B, N, C, H, D] - x = mx.swapaxes(x, 2, 3) # [B, N, H, C, D] - - # Forward substitution without in-place mutation. - # Accumulate solved rows into a growing matrix [B, N, H, i, C]. - accumulated = x[..., 0:1, :] # row 0 unchanged, [B, N, H, 1, C] - for i in range(1, chunk_size): - sub_i = accumulated[..., :i] # [B, N, H, i, i] - row_i = x[..., i:i+1, :i] # [B, N, H, 1, i] - new_val = row_i + mx.matmul(row_i, sub_i) # [B, N, H, 1, i] - pad = mx.zeros((*new_val.shape[:-1], chunk_size - i), dtype=x.dtype) - new_row = mx.concatenate([new_val, pad], axis=-1) # [B, N, H, 1, C] - accumulated = mx.concatenate([accumulated, new_row], axis=-2) - - x = accumulated + mx.eye(chunk_size, dtype=x.dtype) - x = mx.swapaxes(x, 2, 3) # [B, N, C, H, D] + # x: [B, T, H, D] -> [B, N, H, C, C] (negated, lower-tri solve) + x = -pad_and_reshape(x, dim=1, chunk_size=chunk_size) # [B, N, C, H, C] + x = mx.swapaxes(x, 2, 3) # [B, N, H, C, C] + + B, N, H, C, _ = x.shape + batch_total = B * N * H + + # Single Metal dispatch: one threadgroup per batch element, C threads per + # group each solving one column of (I - x)^{-1} independently. + # Replaces the (chunk_size - 1)-step Python loop with one GPU kernel call. + # NOTE: grid = total threads (not threadgroups). With threadgroup=(C,1,1), + # number of threadgroups = grid.x / C = batch_total. So grid.x = batch_total * C. + x = _get_kkt_solve_kernel(C)( + inputs=[mx.contiguous(x).reshape(batch_total, C, C)], + template=[("T", x.dtype)], + grid=(batch_total * C, 1, 1), + threadgroup=(C, 1, 1), + output_shapes=[(batch_total, C, C)], + output_dtypes=[x.dtype], + )[0].reshape(B, N, H, C, C) + + x = mx.swapaxes(x, 2, 3) # [B, N, C, H, C] — non-contiguous after swapaxes + x = mx.contiguous(x) # merge dims (N,C) requires C-contiguous layout x = x.reshape(batch_size, -1, num_heads, chunk_size)[:, :num_tokens] if cu_seqlens is not None: From 3a4fbf70ab3128c95e9ba1c8bb8cc587927d11e8 Mon Sep 17 00:00:00 2001 From: "clandestine.eth" <96172957+0xClandestine@users.noreply.github.com> Date: Sun, 3 May 2026 15:49:58 -0500 Subject: [PATCH 11/11] perf: eliminate KKT layout copies, cache triu masks, fix bwd list overhead MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three optimizations to chunk.py: 1. KKT solve layout copies: the Metal kernel now uses strided row access over the natural [B*N, C, H, C] layout produced by pad_and_reshape, so _kkt_solve no longer needs swapaxes+contiguous+reshape before dispatch or swapaxes+contiguous after. Kernel cache keyed by (chunk_size, H). 2. Mask caching: _get_mask(size, diagonal) caches mx.triu(mx.ones(...)) in a module-level dict, replacing five separate constructions across _kkt_fwd, _chunk_o_fwd, _chunk_dv_bwd, _chunk_dqkwg_bwd, _chunk_wy_bwd. 3. Backward list overhead: replaced dh_list.insert(0, ...) [O(N) per step, O(N²) total] with append + [::-1] reversal, and folded the dv update into a single expression per iteration. --- flash_qla_mlx/ops/gated_delta_rule/chunk.py | 111 ++++++++++---------- 1 file changed, 56 insertions(+), 55 deletions(-) diff --git a/flash_qla_mlx/ops/gated_delta_rule/chunk.py b/flash_qla_mlx/ops/gated_delta_rule/chunk.py index e91a7c7..073db35 100644 --- a/flash_qla_mlx/ops/gated_delta_rule/chunk.py +++ b/flash_qla_mlx/ops/gated_delta_rule/chunk.py @@ -64,7 +64,7 @@ def _kkt_fwd( g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, Hv] beta = pad_and_reshape(beta, dim=1, chunk_size=chunk_size) # [B, N, C, Hv] - mask = mx.triu(mx.ones((chunk_size, chunk_size), dtype=mx.bool_), k=0) + mask = _get_mask(chunk_size, diagonal=0) decay_mask = mx.exp(g[:, :, :, None, :] - g[:, :, None, :, :]) # [B, N, C, C, Hv] decay_mask = mx.where(mask[None, None, :, :, None], mx.zeros_like(decay_mask), decay_mask) @@ -90,52 +90,63 @@ def _kkt_fwd( return attn -# Module-level cache for Metal kkt_solve kernels, keyed by chunk_size. -# Each kernel solves (I - x_in)^{-1} for a batch of C×C lower-triangular -# matrices in a single GPU dispatch: one thread per (batch, column) pair, -# computing that column of the inverse via forward substitution. +# Module-level caches. _KKT_SOLVE_KERNELS: dict = {} +_TRIU_MASKS: dict = {} -def _get_kkt_solve_kernel(chunk_size: int): - if chunk_size not in _KKT_SOLVE_KERNELS: +def _get_mask(size: int, diagonal: int = 0) -> mx.array: + key = (size, diagonal) + if key not in _TRIU_MASKS: + _TRIU_MASKS[key] = mx.triu(mx.ones((size, size), dtype=mx.bool_), k=diagonal) + return _TRIU_MASKS[key] + + +def _get_kkt_solve_kernel(chunk_size: int, num_heads: int): + """Return a cached Metal kernel for the KKT triangular solve. + + Input/output layout: [B*N, C, H, C] (the natural output of pad_and_reshape). + Each of the B*N*H matrices is accessed with stride H*C between rows so we + never need swapaxes or mx.contiguous — the reshape B*N from [B,N,...] is + always a valid contiguous merge of leading dims. + """ + key = (chunk_size, num_heads) + if key not in _KKT_SOLVE_KERNELS: C = chunk_size - # Each threadgroup handles one (batch, n, h) element; C threads handle - # C columns in parallel. Columns are independent so no barrier is needed. - # Threadgroup memory avoids the register-spill bug seen with T x_col[C] - # at C=64 in Metal's shader compiler. + H = num_heads + HC = H * C + CHC = C * H * C + # Threadgroup b_total maps to matrix (bn, h): bn = b_total / H, h = b_total % H. + # Row i of that matrix: flat offset bn*CHC + i*HC + h*C (stride HC between rows). + # Column j is thread index (one thread per column, parallel and independent). source = f""" - uint b = threadgroup_position_in_grid.x; + uint b_total = threadgroup_position_in_grid.x; uint j = thread_position_in_threadgroup.x; - - // Column j of (I - x_in)^{{-1}} via forward substitution. - // x_in holds strictly-lower-triangular entries -L (negative values). - // X[i, j] = delta(i==j) + sum_{{d [B, N, H, C, C] (negated, lower-tri solve) + # pad_and_reshape → [B, N, C, H, C] (contiguous). + # Reshape to [B*N, C, H, C] — valid contiguous merge of leading dims. + # The kernel uses strided row access (stride HC between rows) so no + # swapaxes or mx.contiguous is needed before or after dispatch. x = -pad_and_reshape(x, dim=1, chunk_size=chunk_size) # [B, N, C, H, C] - x = mx.swapaxes(x, 2, 3) # [B, N, H, C, C] - - B, N, H, C, _ = x.shape + B, N, C, H, _ = x.shape batch_total = B * N * H - # Single Metal dispatch: one threadgroup per batch element, C threads per - # group each solving one column of (I - x)^{-1} independently. - # Replaces the (chunk_size - 1)-step Python loop with one GPU kernel call. - # NOTE: grid = total threads (not threadgroups). With threadgroup=(C,1,1), - # number of threadgroups = grid.x / C = batch_total. So grid.x = batch_total * C. - x = _get_kkt_solve_kernel(C)( - inputs=[mx.contiguous(x).reshape(batch_total, C, C)], + x = _get_kkt_solve_kernel(C, H)( + inputs=[x.reshape(B * N, C, H, C)], template=[("T", x.dtype)], grid=(batch_total * C, 1, 1), threadgroup=(C, 1, 1), - output_shapes=[(batch_total, C, C)], + output_shapes=[(B * N, C, H, C)], output_dtypes=[x.dtype], - )[0].reshape(B, N, H, C, C) - - x = mx.swapaxes(x, 2, 3) # [B, N, C, H, C] — non-contiguous after swapaxes - x = mx.contiguous(x) # merge dims (N,C) requires C-contiguous layout - x = x.reshape(batch_size, -1, num_heads, chunk_size)[:, :num_tokens] + )[0].reshape(batch_size, -1, num_heads, chunk_size)[:, :num_tokens] if cu_seqlens is not None: x = pack(x, cu_seqlens) @@ -334,9 +337,7 @@ def _chunk_o_fwd( q = q * scale - mask = mx.triu( - mx.ones((chunk_size, chunk_size), dtype=mx.bool_), k=1 - ) + mask = _get_mask(chunk_size, diagonal=1) decay_mask = mx.exp(g[:, :, :, None, :] - g[:, :, None, :, :]) decay_mask = mx.where(mask[None, None, :, :, None], mx.zeros_like(decay_mask), decay_mask) @@ -386,7 +387,7 @@ def _chunk_dv_bwd( q = q * scale - mask = mx.triu(mx.ones((chunk_size, chunk_size), dtype=mx.bool_), k=1) + mask = _get_mask(chunk_size, diagonal=1) decay_mask = mx.exp(g[:, :, :, None, :] - g[:, :, None, :, :]) decay_mask = mx.where(mask[None, None, :, :, None], mx.zeros_like(decay_mask), decay_mask) @@ -450,17 +451,17 @@ def _chunk_gdr_bwd( "bnchk,bnchv->bnhkv", q * mx.exp(g)[..., None], do ) - dh_list = [] - dv_list = list(mx.split(dv, dv.shape[1], axis=1)) # list of [B,1,C,Hv,V] + dh_acc = [] + dv_acc = [] + dv_slices = list(mx.split(dv, dv.shape[1], axis=1)) # list of [B,1,C,Hv,V] for i in reversed(range(k.shape[1])): - dh_list.insert(0, dstate) - dv_i = dv_list[i][:, 0] # [B, C, Hv, V] - dv_i = dv_i + mx.einsum( + dh_acc.append(dstate) + dv_i = dv_slices[i][:, 0] + mx.einsum( "bchk,bhkv->bchv", k[:, i] * mx.exp(g[:, i, -1:, :, None] - g[:, i, :, :, None]), dstate, ) - dv_list[i] = dv_i + dv_acc.append(dv_i) dstate = dstate * mx.exp(g[:, i, -1, :])[:, :, None, None] dstate = ( dstate @@ -468,10 +469,10 @@ def _chunk_gdr_bwd( - mx.einsum("bchk,bchv->bhkv", w[:, i], dv_i) ) - dh = mx.stack(dh_list, axis=1) + dh = mx.stack(dh_acc[::-1], axis=1) dh0 = None if h0 is None else dstate - dv = mx.stack(dv_list, axis=1).reshape( + dv = mx.stack(dv_acc[::-1], axis=1).reshape( batch_size, -1, num_v_heads, head_dim_v )[:, :num_tokens] @@ -526,7 +527,7 @@ def _chunk_dqkwg_bwd( dv = pad_and_reshape(dv, dim=1, chunk_size=chunk_size) g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size) - mask = mx.triu(mx.ones((chunk_size, chunk_size), dtype=mx.bool_), k=1) + mask = _get_mask(chunk_size, diagonal=1) decay_mask = mx.exp(g[:, :, :, None, :] - g[:, :, None, :, :]) decay_mask = mx.where(mask[None, None, :, :, None], mx.zeros_like(decay_mask), decay_mask) @@ -619,7 +620,7 @@ def _chunk_wy_bwd( dv = dv_beta * beta[..., None] db = db + (dv_beta * v).sum(axis=-1) - mask = mx.triu(mx.ones((chunk_size_A, chunk_size_A), dtype=mx.bool_), k=0) + mask = _get_mask(chunk_size_A, diagonal=0) decay_mask = mx.exp(g[:, :, :, None, :] - g[:, :, None, :, :]) decay_mask = mx.where( mask[None, None, :, :, None], mx.zeros_like(decay_mask), decay_mask