def _time(fn, warmup: int = 3, iters: int = 10) -> float:
    """Return the median wall-clock time of ``fn`` in milliseconds.

    Args:
        fn: Zero-argument callable. Its return value is handed to ``mx.eval``
            so the work it describes is executed before the timer is read.
        warmup: Number of untimed calls made before measuring.
        iters: Number of timed calls; must be at least 1.

    Returns:
        The median of the ``iters`` per-call timings in milliseconds. For an
        even number of samples this is the mean of the two middle values
        (the previous implementation returned the upper-middle sample).

    Raises:
        ValueError: If ``iters`` is smaller than 1.
    """
    # Local import keeps this helper self-contained; it runs outside the
    # timed region so it does not affect the measurement.
    from statistics import median

    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    for _ in range(warmup):
        mx.eval(fn())

    samples_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        mx.eval(fn())
        samples_ms.append((time.perf_counter() - start) * 1000)

    return median(samples_ms)
def bench_qla_bwd(B, T, Hk, Hv, K, V, chunk_size, dtype):
    """Time a gradient evaluation of ``chunk_gated_delta_rule``.

    The scalar loss is the sum of the first element returned by
    ``chunk_gated_delta_rule``. Gradients are requested for all five inputs
    (q, k, v, g, beta) so that the whole backward computation is measured.

    Args:
        B, T, Hk, Hv, K, V: Problem sizes forwarded to ``_qla_inputs``.
        chunk_size: Chunk length forwarded to ``chunk_gated_delta_rule``.
        dtype: Element type of the generated inputs.

    Returns:
        ``(ms, tokens_per_second)`` where ``ms`` is the value reported by
        ``_time`` and the throughput is ``B * T`` tokens divided by that time.
    """
    q, k, v, g, beta = _qla_inputs(B, T, Hk, Hv, K, V, dtype)

    def loss(q, k, v, g, beta):
        return chunk_gated_delta_rule(
            q, k, v, g, beta, chunk_size=chunk_size
        )[0].sum()

    # Without ``argnums`` mx.grad differentiates only the first positional
    # argument (q), which would leave the k/v/g/beta gradients out of the
    # measurement. Building the transform once also keeps its construction
    # cost out of the timed region.
    grad_fn = mx.grad(loss, argnums=[0, 1, 2, 3, 4])

    ms = _time(lambda: grad_fn(q, k, v, g, beta))
    return ms, B * T / ms * 1e3
+ """ + mx.random.seed(0) + scale = K ** -0.5 + q = mx.random.normal((B, Hv, T, K)).astype(dtype) + k = mx.random.normal((B, Hk, T, K)).astype(dtype) + v = mx.random.normal((B, Hk, T, V)).astype(dtype) + + ms = _time(lambda: mx.fast.scaled_dot_product_attention( + q, k, v, scale=scale, mask="causal" + )) + return ms, B * T / ms * 1e3 + + +# --------------------------------------------------------------------------- +# Table helpers +# --------------------------------------------------------------------------- + +def _print_qla_table(pass_name, args, dtype): + bench = bench_qla_fwd if pass_name == "fwd" else bench_qla_bwd + header = f"{'T':>6} " + " ".join( + f"cs={cs:>3} ms Ktok/s" for cs in args.chunk_sizes + ) + print(f"\nFlashQLA {pass_name.upper()} " + f"(B={args.B} Hk={args.Hk} Hv={args.Hv} " + f"K={args.K} V={args.V} {args.dtype})") + print(header) + print("-" * len(header)) + for T in args.T: + row = f"{T:>6} " + for cs in args.chunk_sizes: + if cs > T: + row += f"{'--':>10} {'--':>6} " + continue + try: + ms, tps = bench(args.B, T, args.Hk, args.Hv, + args.K, args.V, cs, dtype) + row += f"{ms:>10.2f} {tps/1000:>6.1f} " + except Exception as e: + row += f"{'ERR':>10} {'ERR':>6} " + print(row) + + +def _print_compare_table(args, dtype): + """ + Side-by-side: best FlashQLA chunk_size vs mx.fast.scaled_dot_product_attention. + Runs forward pass only (SDPA has no exposed backward in MLX fast path). 
+ """ + # Pick best chunk_size from a quick pre-sweep + print(f"\nPre-sweeping chunk sizes to find best for T={args.T[0]}...") + best_cs, best_ms = args.chunk_sizes[0], float("inf") + for cs in args.chunk_sizes: + if cs > args.T[0]: + continue + try: + ms, _ = bench_qla_fwd(args.B, args.T[0], args.Hk, args.Hv, + args.K, args.V, cs, dtype) + if ms < best_ms: + best_ms, best_cs = ms, cs + except Exception: + pass + print(f" Best chunk_size = {best_cs}") + + # GQA layout note + print(f"\n SDPA shape: q=[B,{args.Hv},T,{args.K}] " + f"k=v=[B,{args.Hk},T,{args.K}] causal (GQA {args.Hv//args.Hk}:1)") + print(f" FlashQLA: Hk={args.Hk} Hv={args.Hv} K={args.K} V={args.V} " + f"chunk_size={best_cs}") + + col = 14 + header = (f"{'T':>6} {'FlashQLA':>{col}} {'SDPA':>{col}} " + f"{'QLA ms':>8} {'SDPA ms':>8} {'speedup':>8}") + print() + print(header) + print("-" * len(header)) + + for T in args.T: + if best_cs > T: + print(f"{T:>6} (T < chunk_size, skipped)") + continue + + try: + qla_ms, qla_tps = bench_qla_fwd(args.B, T, args.Hk, args.Hv, + args.K, args.V, best_cs, dtype) + except Exception as e: + print(f"{T:>6} QLA ERR: {e}") + continue + + try: + sdpa_ms, sdpa_tps = bench_sdpa_fwd(args.B, T, args.Hk, args.Hv, + args.K, args.V, dtype) + speedup = sdpa_ms / qla_ms + sdpa_str = f"{sdpa_tps/1000:>10.1f} Ktok/s" + sdpa_ms_str = f"{sdpa_ms:>8.2f}" + speedup_str = f"{speedup:>7.2f}x" + except Exception as e: + sdpa_str = f"{'OOM/ERR':>{col}}" + sdpa_ms_str = f"{'--':>8}" + speedup_str = f"{'--':>8}" + + qla_str = f"{qla_tps/1000:>10.1f} Ktok/s" + print(f"{T:>6} {qla_str:>{col}} {sdpa_str:>{col}} " + f"{qla_ms:>8.2f} {sdpa_ms_str} {speedup_str}") + + print() + print(" speedup > 1 means FlashQLA is faster than SDPA") + print(" Note: FlashQLA is O(T) memory; SDPA is O(T²). 
" + "They are architecturally different primitives.") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser( + description="flash_qla_mlx Apple Silicon benchmark", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Qwen-27B forward comparison: + python benchmark/bench_mlx.py \\ + --B 1 --Hk 8 --Hv 32 --K 128 --V 128 \\ + --fwd-only --compare-sdpa \\ + --T 512 1024 2048 4096 8192 16384 +""", + ) + parser.add_argument("--B", type=int, default=1) + parser.add_argument("--Hk", type=int, default=8, help="KV heads (FlashQLA k heads)") + parser.add_argument("--Hv", type=int, default=32, help="Q/output heads (FlashQLA v heads)") + parser.add_argument("--K", type=int, default=128, help="Head dim (key/query)") + parser.add_argument("--V", type=int, default=128, help="Head dim (value)") + parser.add_argument( + "--T", type=int, nargs="+", + default=[512, 1024, 2048, 4096, 8192], + ) + parser.add_argument( + "--chunk-sizes", type=int, nargs="+", default=[32, 64], + help="chunk_size values to sweep in the QLA table", + ) + parser.add_argument("--fwd-only", action="store_true") + parser.add_argument("--compare-sdpa", action="store_true", + help="Add SDPA vs FlashQLA comparison table") + parser.add_argument("--dtype", + choices=["float32", "float16", "bfloat16"], + default="float32") + args = parser.parse_args() + + dtype = {"float32": mx.float32, + "float16": mx.float16, + "bfloat16": mx.bfloat16}[args.dtype] + + passes = ["fwd"] if args.fwd_only else ["fwd", "bwd"] + for p in passes: + _print_qla_table(p, args, dtype) + + if args.compare_sdpa: + _print_compare_table(args, dtype) + + +if __name__ == "__main__": + main() diff --git a/flash_qla_mlx/__init__.py b/flash_qla_mlx/__init__.py new file mode 100644 index 0000000..b9ab370 --- /dev/null +++ b/flash_qla_mlx/__init__.py @@ -0,0 +1,16 @@ +# 
Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +__version__ = "0.1.0" + +from flash_qla_mlx.ops.gated_delta_rule.chunk import ( + chunk_gated_delta_rule_fwd, + chunk_gated_delta_rule_bwd, + chunk_gated_delta_rule, +) + +__all__ = [ + "chunk_gated_delta_rule_fwd", + "chunk_gated_delta_rule_bwd", + "chunk_gated_delta_rule", +] diff --git a/flash_qla_mlx/ops/__init__.py b/flash_qla_mlx/ops/__init__.py new file mode 100644 index 0000000..81120d5 --- /dev/null +++ b/flash_qla_mlx/ops/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +from .gated_delta_rule import chunk_gated_delta_rule + +__all__ = ["chunk_gated_delta_rule"] diff --git a/flash_qla_mlx/ops/gated_delta_rule/__init__.py b/flash_qla_mlx/ops/gated_delta_rule/__init__.py new file mode 100644 index 0000000..48f8668 --- /dev/null +++ b/flash_qla_mlx/ops/gated_delta_rule/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +from .chunk import ( + chunk_gated_delta_rule_fwd, + chunk_gated_delta_rule_bwd, + chunk_gated_delta_rule, +) + +__all__ = [ + "chunk_gated_delta_rule_fwd", + "chunk_gated_delta_rule_bwd", + "chunk_gated_delta_rule", +] diff --git a/flash_qla_mlx/ops/gated_delta_rule/chunk.py b/flash_qla_mlx/ops/gated_delta_rule/chunk.py new file mode 100644 index 0000000..073db35 --- /dev/null +++ b/flash_qla_mlx/ops/gated_delta_rule/chunk.py @@ -0,0 +1,914 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. 
def _chunk_local_cumsum(
    g: mx.array,
    chunk_size: int = 64,
    cu_seqlens: mx.array = None,
    reverse: bool = False,
) -> mx.array:
    """Cumulative sum of ``g`` along the token axis, restarted in every chunk.

    Args:
        g: Three-dimensional array; its shape is read as
            (batch, tokens, heads).
        chunk_size: Number of tokens per chunk.
        cu_seqlens: When given, ``g`` is first expanded with ``unpack`` and the
            result is compressed again with ``pack`` before returning.
        reverse: If True, accumulate from the end of each chunk towards its
            start instead of from start to end.

    Returns:
        Array with the same token count as the (unpacked) input, or its
        packed form when ``cu_seqlens`` is provided.
    """
    is_packed = cu_seqlens is not None
    dense = unpack(g, cu_seqlens) if is_packed else g

    n_batch, n_tokens, n_heads = dense.shape
    # Chunked view; the token axis is split into (num_chunks, chunk_size).
    chunks = pad_and_reshape(dense, dim=1, chunk_size=chunk_size)

    if reverse:
        # Flip inside each chunk, accumulate, then flip back.
        summed = mx.cumsum(chunks[:, :, ::-1, :], axis=2)[:, :, ::-1, :]
    else:
        summed = mx.cumsum(chunks, axis=2)

    # Merge the chunk axes again and drop the rows added by padding.
    result = summed.reshape(n_batch, -1, n_heads)[:, :n_tokens]

    return pack(result, cu_seqlens) if is_packed else result
def _get_mask(size: int, diagonal: int = 0) -> mx.array:
    """Return a cached boolean ``(size, size)`` upper-triangular mask.

    The mask is ``mx.triu`` applied to an all-True matrix with offset
    ``k=diagonal``. Results are memoised in the module-level ``_TRIU_MASKS``
    dictionary, keyed by ``(size, diagonal)``, so each distinct mask is built
    only once.
    """
    cache_key = (size, diagonal)
    cached = _TRIU_MASKS.get(cache_key)
    if cached is None:
        all_true = mx.ones((size, size), dtype=mx.bool_)
        cached = mx.triu(all_true, k=diagonal)
        _TRIU_MASKS[cache_key] = cached
    return cached
+ source = f""" + uint b_total = threadgroup_position_in_grid.x; + uint j = thread_position_in_threadgroup.x; + uint bn = b_total / {H}u; + uint h = b_total % {H}u; + threadgroup float X_tg[{C * C}]; + for (uint i = 0; i < {C}u; i++) {{ + X_tg[i * {C}u + j] = (i == j) ? 1.0f : 0.0f; + }} + for (uint i = 1; i < {C}u; i++) {{ + float acc = X_tg[i * {C}u + j]; + for (uint d = 0; d < i; d++) {{ + acc += (float)x_in[bn * {CHC}u + i * {HC}u + h * {C}u + d] + * X_tg[d * {C}u + j]; + }} + X_tg[i * {C}u + j] = acc; + }} + for (uint i = 0; i < {C}u; i++) {{ + out[bn * {CHC}u + i * {HC}u + h * {C}u + j] = (T)X_tg[i * {C}u + j]; + }} + """ + _KKT_SOLVE_KERNELS[key] = mx.fast.metal_kernel( + name=f"kkt_solve_cs{C}_h{H}", + input_names=["x_in"], + output_names=["out"], + source=source, + ) + return _KKT_SOLVE_KERNELS[key] + + +def _kkt_solve( + x: mx.array, + cu_seqlens: mx.array = None, + chunk_size: int = 64, +) -> mx.array: + if cu_seqlens is not None: + x = unpack(x, cu_seqlens) + + batch_size, num_tokens, num_heads, _ = x.shape + + # pad_and_reshape → [B, N, C, H, C] (contiguous). + # Reshape to [B*N, C, H, C] — valid contiguous merge of leading dims. + # The kernel uses strided row access (stride HC between rows) so no + # swapaxes or mx.contiguous is needed before or after dispatch. 
def _kkt(
    k: mx.array,
    beta: mx.array,
    g: mx.array,
    cu_seqlens: mx.array = None,
    chunk_size: int = 64,
) -> mx.array:
    """Compute the per-chunk inverse A = (I + M)^{-1} of the gated K.K^T system.

    ``M`` is the strictly lower-triangular matrix built by ``_kkt_fwd``:
    ``M[c, d] = beta[c] * (k[c] . k[d]) * exp(g[c] - g[d])`` for ``d < c`` and
    zero elsewhere. ``_kkt_solve`` negates its input before running forward
    substitution, so in the ``(I - L)^{-1}`` notation used previously,
    ``L = -M``.

    Args:
        k: Keys, [B, T, Hk, K].
        beta: Per-token, per-head scaling, [B, T, Hv].
        g: Log-decay values, [B, T, Hv]; the forward entry point passes the
            chunk-local cumulative sum here.
        cu_seqlens: Optional sequence boundaries for packed variable-length
            input; forwarded unchanged to both helpers.
        chunk_size: Tokens per chunk.

    Returns:
        The inverse matrices, one ``chunk_size``-wide row per token and
        value head (last axis has length ``chunk_size``).
    """
    # Step 1: build the masked, beta-scaled, decayed gram matrix M.
    A = _kkt_fwd(k=k, g=g, beta=beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size)
    # Step 2: invert (I + M) chunk by chunk via forward substitution.
    A = _kkt_solve(x=A, cu_seqlens=cu_seqlens, chunk_size=chunk_size)
    return A
def _chunk_gdr_fwd(
    k: mx.array,
    w: mx.array,
    u: mx.array,
    g: mx.array,
    initial_state: mx.array = None,
    cu_seqlens: mx.array = None,
    chunk_size: int = 64,
) -> tuple:
    """Run the chunk-by-chunk state recurrence of the gated delta rule.

    Args:
        k: Keys, [B, T, Hk, K]; repeated along the head axis when Hk != Hv.
        w: [B, T, Hv, K] factor subtracted (through the state) from ``u``.
        u: [B, T, Hv, V] values before the state correction.
        g: [B, T, Hv] log-decay as passed by the caller (the public entry
            points pass the chunk-local cumulative sum).
        initial_state: Optional [B, Hv, K, V] starting state; zeros when None.
        cu_seqlens: Optional boundaries for packed variable-length input.
        chunk_size: Tokens per chunk.

    Returns:
        ``(h, vn, last_state)``: ``h`` stacks the state seen at the *entry* of
        every chunk, ``vn`` holds the corrected values per token, and
        ``last_state`` is the state after the final chunk.
    """
    is_packed = cu_seqlens is not None
    if is_packed:
        k, w, u, g = (unpack(t, cu_seqlens) for t in (k, w, u, g))

    n_batch, n_tokens, n_k_heads, dim_k = k.shape
    _, _, n_v_heads, dim_v = u.shape

    if n_k_heads != n_v_heads:
        k = mx.repeat(k, n_v_heads // n_k_heads, axis=2)

    # Chunked views: the token axis becomes (num_chunks, chunk_size).
    k_c, w_c, u_c, g_c = (
        pad_and_reshape(t, dim=1, chunk_size=chunk_size) for t in (k, w, u, g)
    )
    g_c = fill_last_chunk_of_g(g_c, n_tokens, cu_seqlens, chunk_size=chunk_size)

    if initial_state is None:
        state = mx.zeros((n_batch, n_v_heads, dim_k, dim_v), dtype=g_c.dtype)
    else:
        state = initial_state.astype(g_c.dtype)

    # The recurrence is evaluated sequentially, one chunk after another;
    # each step depends on the state produced by the previous one.
    entry_states = []
    corrected_values = []
    for idx in range(k_c.shape[1]):
        g_i = g_c[:, idx]

        entry_states.append(state)
        v_new = u_c[:, idx] - mx.einsum("bchk,bhkv->bchv", w_c[:, idx], state)
        corrected_values.append(v_new)

        # Decay the carried state by the chunk's last g value ...
        state = state * mx.exp(g_i[:, -1, :])[:, :, None, None]
        # ... and add the contribution of this chunk's keys, each decayed
        # from its own position to the end of the chunk.
        k_decayed = k_c[:, idx] * mx.exp(
            g_i[:, -1:, :, None] - g_i[:, :, :, None]
        )
        state = state + mx.einsum("bchk,bchv->bhkv", k_decayed, v_new)

    h = mx.stack(entry_states, axis=1)
    vn = mx.stack(corrected_values, axis=1).reshape(
        n_batch, -1, n_v_heads, dim_v
    )[:, :n_tokens]

    if is_packed:
        chunk_offsets = prepare_chunk_offsets(cu_seqlens, chunk_size)
        vn = pack(vn, cu_seqlens)
        h = pack(h, chunk_offsets)

    return h, vn, state
def _chunk_dv_bwd(
    q: mx.array,
    k: mx.array,
    g: mx.array,
    do: mx.array,
    cu_seqlens: mx.array = None,
    scale: float = None,
    chunk_size: int = 64,
) -> mx.array:
    """Intra-chunk part of the value gradient.

    Rebuilds the decayed, causally masked ``(q * scale) . k`` scores inside
    every chunk and contracts them with ``do`` over the query position.

    Args:
        q, k: [B, T, Hk, K]; repeated along the head axis when Hk != Hv.
        g: [B, T, Hv] log-decay as passed by the caller.
        do: Output gradient, [B, T, Hv, V].
        cu_seqlens: Optional boundaries for packed variable-length input.
        scale: Score scale; defaults to ``K ** -0.5``.
        chunk_size: Tokens per chunk.

    Returns:
        Gradient with the same layout as ``do``.
    """
    is_packed = cu_seqlens is not None
    if is_packed:
        q, k, g, do = (unpack(t, cu_seqlens) for t in (q, k, g, do))

    n_batch, n_tokens, n_k_heads, dim_k = k.shape
    _, _, n_v_heads, dim_v = do.shape

    if n_k_heads != n_v_heads:
        group = n_v_heads // n_k_heads
        q = mx.repeat(q, group, axis=2)
        k = mx.repeat(k, group, axis=2)

    if scale is None:
        scale = dim_k ** (-0.5)

    q_c, k_c, g_c, do_c = (
        pad_and_reshape(t, dim=1, chunk_size=chunk_size) for t in (q, k, g, do)
    )
    q_c = q_c * scale

    # Pairwise decay exp(g[c] - g[d]); entries with d > c (strictly above
    # the diagonal) are zeroed.
    upper = _get_mask(chunk_size, diagonal=1)
    decay = mx.exp(g_c[:, :, :, None, :] - g_c[:, :, None, :, :])
    decay = mx.where(upper[None, None, :, :, None], mx.zeros_like(decay), decay)

    scores = mx.einsum("bnchk,bndhk->bncdh", q_c, k_c) * decay
    grad_v = mx.einsum("bncdh,bnchv->bndhv", scores, do_c)

    # Merge the chunk axes and drop padded rows.
    grad_v = grad_v.reshape(n_batch, -1, n_v_heads, dim_v)[:, :n_tokens]

    return pack(grad_v, cu_seqlens) if is_packed else grad_v
None, + chunk_size: int = 64, +) -> tuple: + if cu_seqlens is not None: + q = unpack(q, cu_seqlens) + k = unpack(k, cu_seqlens) + w = unpack(w, cu_seqlens) + g = unpack(g, cu_seqlens) + do = unpack(do, cu_seqlens) + dv = unpack(dv, cu_seqlens) + + batch_size, num_tokens, num_k_heads, head_dim_k = k.shape + _, _, num_v_heads, head_dim_v = do.shape + + if num_k_heads != num_v_heads: + q = mx.repeat(q, num_v_heads // num_k_heads, axis=2) + k = mx.repeat(k, num_v_heads // num_k_heads, axis=2) + + scale = scale if scale is not None else head_dim_k ** (-0.5) + + q = pad_and_reshape(q, dim=1, chunk_size=chunk_size) + k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) + w = pad_and_reshape(w, dim=1, chunk_size=chunk_size) + g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) + do = pad_and_reshape(do, dim=1, chunk_size=chunk_size) + dv = pad_and_reshape(dv, dim=1, chunk_size=chunk_size) + g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size) + + q = q * scale + + if dht is None: + dstate = mx.zeros( + (batch_size, num_v_heads, head_dim_k, head_dim_v), dtype=g.dtype + ) + else: + dstate = dht.astype(g.dtype) + + dstate_inter = mx.einsum( + "bnchk,bnchv->bnhkv", q * mx.exp(g)[..., None], do + ) + + dh_acc = [] + dv_acc = [] + dv_slices = list(mx.split(dv, dv.shape[1], axis=1)) # list of [B,1,C,Hv,V] + for i in reversed(range(k.shape[1])): + dh_acc.append(dstate) + dv_i = dv_slices[i][:, 0] + mx.einsum( + "bchk,bhkv->bchv", + k[:, i] * mx.exp(g[:, i, -1:, :, None] - g[:, i, :, :, None]), + dstate, + ) + dv_acc.append(dv_i) + dstate = dstate * mx.exp(g[:, i, -1, :])[:, :, None, None] + dstate = ( + dstate + + dstate_inter[:, i] + - mx.einsum("bchk,bchv->bhkv", w[:, i], dv_i) + ) + + dh = mx.stack(dh_acc[::-1], axis=1) + + dh0 = None if h0 is None else dstate + dv = mx.stack(dv_acc[::-1], axis=1).reshape( + batch_size, -1, num_v_heads, head_dim_v + )[:, :num_tokens] + + if cu_seqlens is not None: + chunk_offsets = prepare_chunk_offsets(cu_seqlens, 
chunk_size) + dv = pack(dv, cu_seqlens) + dh = pack(dh, chunk_offsets) + return dh, dh0, dv + + +def _chunk_dqkwg_bwd( + q: mx.array, + k: mx.array, + v: mx.array, + w: mx.array, + g: mx.array, + h: mx.array, + dv: mx.array, + do: mx.array, + dh: mx.array, + cu_seqlens: mx.array = None, + scale: float = None, + chunk_size: int = 64, +) -> tuple: + if cu_seqlens is not None: + chunk_offsets = prepare_chunk_offsets(cu_seqlens, chunk_size) + q = unpack(q, cu_seqlens) + k = unpack(k, cu_seqlens) + v = unpack(v, cu_seqlens) + w = unpack(w, cu_seqlens) + g = unpack(g, cu_seqlens) + do = unpack(do, cu_seqlens) + dv = unpack(dv, cu_seqlens) + h = unpack(h, chunk_offsets) + dh = unpack(dh, chunk_offsets) + + batch_size, num_tokens, num_k_heads, head_dim_k = k.shape + _, _, num_v_heads, head_dim_v = do.shape + + if num_k_heads != num_v_heads: + q = mx.repeat(q, num_v_heads // num_k_heads, axis=2) + k = mx.repeat(k, num_v_heads // num_k_heads, axis=2) + + scale = scale if scale is not None else head_dim_k ** (-0.5) + + q = pad_and_reshape(q, dim=1, chunk_size=chunk_size) + k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) + v = pad_and_reshape(v, dim=1, chunk_size=chunk_size) + w = pad_and_reshape(w, dim=1, chunk_size=chunk_size) + g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) + do = pad_and_reshape(do, dim=1, chunk_size=chunk_size) + dv = pad_and_reshape(dv, dim=1, chunk_size=chunk_size) + g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size) + + mask = _get_mask(chunk_size, diagonal=1) + decay_mask = mx.exp(g[:, :, :, None, :] - g[:, :, None, :, :]) + decay_mask = mx.where(mask[None, None, :, :, None], mx.zeros_like(decay_mask), decay_mask) + + dg_last = (h * dh).sum(axis=-1).sum(axis=-1) # [B, N, Hv] + ds = mx.einsum("bnchv,bndhv->bncdh", do, v) + dq = mx.einsum("bnchv,bnhkv->bnchk", do, h) + dk = mx.einsum("bnchv,bnhkv->bnchk", v, dh) + dw = -mx.einsum("bnchv,bnhkv->bnchk", dv, h) + + g_last = g[:, :, -1] + dg_last = dg_last * mx.exp(g_last) 
def _chunk_wy_bwd(
    k: mx.array,
    v: mx.array,
    beta: mx.array,
    A: mx.array,
    g: mx.array,
    dw: mx.array,
    du: mx.array,
    dk1: mx.array,
    dg1: mx.array,
    cu_seqlens: mx.array = None,
    chunk_size: int = 64,
) -> tuple:
    """Backward pass through the w/u construction and the triangular inverse.

    Propagates ``dw`` and ``du`` back through ``w = A @ (k * beta * exp(g))``
    and ``u = A @ (v * beta)`` and through ``A`` itself, then adds the
    partial gradients ``dk1`` and ``dg1`` computed earlier.

    Args:
        k: Keys, [B, T, Hk, K]; repeated along the head axis when Hk != Hv.
        v: Values, [B, T, Hv, V].
        beta: [B, T, Hv].
        A: Per-chunk inverse from the forward pass; its last axis length is
            used as the chunk length.
        g: [B, T, Hv] log-decay as passed by the caller.
        dw: Gradient w.r.t. ``w``, [B, T, Hv, K].
        du: Gradient w.r.t. ``u``, [B, T, Hv, V].
        dk1: Partial key gradient to be added, [B, T, Hv, K].
        dg1: Partial log-decay gradient to be added, [B, T, Hv].
        cu_seqlens: Optional boundaries for packed variable-length input.
        chunk_size: Accepted for signature parity with the other helpers;
            the chunk length actually used is ``A.shape[-1]``.

    Returns:
        ``(dk, dv, db, dg)`` with ``dk`` of width K and ``dv`` of width V,
        both still carrying Hv heads.
    """
    is_packed = cu_seqlens is not None
    if is_packed:
        k = unpack(k, cu_seqlens)
        v = unpack(v, cu_seqlens)
        beta = unpack(beta, cu_seqlens)
        A = unpack(A, cu_seqlens)
        g = unpack(g, cu_seqlens)
        dw = unpack(dw, cu_seqlens)
        du = unpack(du, cu_seqlens)
        dk1 = unpack(dk1, cu_seqlens)
        dg1 = unpack(dg1, cu_seqlens)

    batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
    _, _, num_v_heads, head_dim_v = v.shape
    chunk_size_A = A.shape[-1]

    if num_k_heads != num_v_heads:
        k = mx.repeat(k, num_v_heads // num_k_heads, axis=2)

    k = pad_and_reshape(k, dim=1, chunk_size=chunk_size_A)
    v = pad_and_reshape(v, dim=1, chunk_size=chunk_size_A)
    beta = pad_and_reshape(beta, dim=1, chunk_size=chunk_size_A)
    A = pad_and_reshape(A, dim=1, chunk_size=chunk_size_A)
    g = pad_and_reshape(g, dim=1, chunk_size=chunk_size_A)
    dw = pad_and_reshape(dw, dim=1, chunk_size=chunk_size_A)
    du = pad_and_reshape(du, dim=1, chunk_size=chunk_size_A)
    dk1 = pad_and_reshape(dk1, dim=1, chunk_size=chunk_size_A)
    dg1 = pad_and_reshape(dg1, dim=1, chunk_size=chunk_size_A)

    exp_g = mx.exp(g)
    beta_exp_g = beta * exp_g

    # --- gradient through w = A @ (k * beta * exp(g)) ---------------------
    dA = mx.einsum("bnchk,bndhk->bnchd", dw, k * beta_exp_g[..., None])
    dk_beta_g = mx.einsum("bnchd,bnchk->bndhk", A, dw)
    dk = dk_beta_g * beta_exp_g[..., None]
    db = (dk_beta_g * k * exp_g[..., None]).sum(axis=-1)
    dg = (dk_beta_g * k * beta_exp_g[..., None]).sum(axis=-1)

    # --- gradient through u = A @ (v * beta) ------------------------------
    dA = dA + mx.einsum("bnchv,bndhv->bnchd", du, v * beta[..., None])
    dv_beta = mx.einsum("bnchd,bnchv->bndhv", A, du)
    dv = dv_beta * beta[..., None]
    db = db + (dv_beta * v).sum(axis=-1)

    # --- gradient through the triangular inverse --------------------------
    mask = _get_mask(chunk_size_A, diagonal=0)
    decay_mask = mx.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
    decay_mask = mx.where(
        mask[None, None, :, :, None], mx.zeros_like(decay_mask), decay_mask
    )
    decay_mask = mx.swapaxes(decay_mask, -2, -1)

    # Only the strictly lower-triangular part of dA is kept before it is
    # sandwiched between the two A factors.
    dA = mx.where(mask[None, None, :, None, :], mx.zeros_like(dA), dA)
    dA = mx.einsum("bndhc,bndhe->bnche", A, dA)
    dA = mx.einsum("bnchd,bnehd->bnche", dA, A)
    dA = -dA * decay_mask

    A_kkt = mx.einsum("bnchk,bndhk->bnchd", k * beta[..., None], k)
    dk_beta = mx.einsum("bnchd,bndhk->bnchk", dA, k)
    db = db + (dk_beta * k).sum(axis=-1)
    dk = dk + mx.einsum("bnchd,bnchk->bndhk", dA, k * beta[..., None])
    dk = dk + dk_beta * beta[..., None]
    dk = dk + dk1

    dA_kkt = dA * A_kkt
    dg = dg + dA_kkt.sum(axis=-1) - mx.swapaxes(dA_kkt.sum(axis=-3), -1, -2)
    dg = dg + dg1

    # Merge the chunk axes and drop padded rows. dv must be reshaped with
    # the *value* head dimension; using the key dimension here produced a
    # wrong shape (or a reshape error) whenever K != V.
    dk = dk.reshape(batch_size, -1, num_v_heads, head_dim_k)[:, :num_tokens]
    dv = dv.reshape(batch_size, -1, num_v_heads, head_dim_v)[:, :num_tokens]
    db = db.reshape(batch_size, -1, num_v_heads)[:, :num_tokens]
    dg = dg.reshape(batch_size, -1, num_v_heads)[:, :num_tokens]

    if is_packed:
        dk = pack(dk, cu_seqlens)
        dv = pack(dv, cu_seqlens)
        db = pack(db, cu_seqlens)
        dg = pack(dg, cu_seqlens)
    return dk, dv, db, dg
def chunk_gated_delta_rule_bwd(
    q: mx.array,
    k: mx.array,
    v: mx.array,
    g: mx.array,
    beta: mx.array,
    A: mx.array,
    do: mx.array,
    dht: mx.array = None,
    scale: float = None,
    initial_state: mx.array = None,
    cu_seqlens: mx.array = None,
    chunk_size: int = 64,
) -> tuple:
    """
    Backward pass of the Gated Delta Rule.

    The forward intermediates are not stored; they are recomputed here from
    the saved inputs before the gradient helpers run.

    Args:
        q, k, v, beta: The tensors that were passed to the forward pass.
        g: The cumulative-sum version of the gate returned by the forward
            (not the raw gate).
        A: The ``A`` tensor returned by the forward pass.
        do: Gradient of the output [B, T, Hv, V].
        dht: Gradient of the final state [B, Hv, K, V], or None.
        scale: Same as forward; defaults to ``k.shape[-1] ** -0.5``.
        initial_state, cu_seqlens: Same as forward.
        chunk_size: Must match the value used in the forward pass.

    Returns:
        (dq, dk, dv, db, dg, dh0). When k has fewer heads than v, dq and dk
        are summed over each head group so they carry k's head count.
    """
    softmax_scale = k.shape[-1] ** -0.5 if scale is None else scale

    # Keyword groups shared by most helper calls below.
    seq_opts = dict(cu_seqlens=cu_seqlens, chunk_size=chunk_size)
    scaled_opts = dict(scale=softmax_scale, **seq_opts)

    # Recompute the forward intermediates.
    w, u = _w_u_fwd(k=k, v=v, beta=beta, A=A, g=g, cu_seqlens=cu_seqlens)
    h, v_new, _ = _chunk_gdr_fwd(
        k=k, w=w, u=u, g=g, initial_state=initial_state, **seq_opts
    )

    # Gradient helpers, chained in the same order as the original code.
    dv_local = _chunk_dv_bwd(q=q, k=k, g=g, do=do, **scaled_opts)
    dh, dh0, dv_full = _chunk_gdr_bwd(
        q=q, k=k, w=w, g=g, do=do, dv=dv_local,
        h0=initial_state, dht=dht,
        **scaled_opts,
    )
    dq, dk_part, dw, dg_part = _chunk_dqkwg_bwd(
        q=q, k=k, v=v_new, w=w, g=g, h=h,
        dv=dv_full, do=do, dh=dh,
        **scaled_opts,
    )
    dk, dv_out, db, dg = _chunk_wy_bwd(
        k=k, v=v, beta=beta, A=A, g=g,
        dw=dw, du=dv_full, dk1=dk_part, dg1=dg_part,
        **seq_opts,
    )

    # Fold the per-v-head gradients of q and k back onto k's head count.
    num_k_heads = k.shape[-2]
    if num_k_heads < v.shape[-2]:
        dq, dk = (_group_reduce_vector(grad, num_k_heads) for grad in (dq, dk))

    # g is a chunk-local cumsum of the raw gate, so its gradient is passed
    # through the reversed chunk-local cumsum.
    dg = _chunk_local_cumsum(
        dg, chunk_size=chunk_size, reverse=True, cu_seqlens=cu_seqlens
    )
    return dq, dk, dv_out, db, dg, dh0
def chunk_gated_delta_rule(
    q: mx.array,
    k: mx.array,
    v: mx.array,
    g: mx.array,
    beta: mx.array,
    scale: float = None,
    initial_state: mx.array = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
    cu_seqlens: mx.array = None,
    head_first: bool = False,
    chunk_size: int = 64,
) -> tuple:
    """
    Gated Delta Rule: end-to-end forward with MLX-native gradient support.

    The forward and backward passes are tied together with
    ``mx.custom_function`` so ``mx.grad`` uses the hand-written backward.

    Args:
        q:     [B, T, Hk, K]
        k:     [B, T, Hk, K]
        v:     [B, T, Hv, V]
        g:     [B, T, Hv]
        beta:  [B, T, Hv]
        scale: Optional softmax scale (defaults to K**-0.5)
        initial_state: Optional [B, Hv, K, V]
        output_final_state: Return final recurrent state
        use_qk_l2norm_in_kernel: L2-normalize q and k before computation
        cu_seqlens: Optional variable-length sequence boundaries; requires
            a batch size of 1
        head_first: Not supported (must be False)
        chunk_size: Tokens per chunk (default 64; tune for your hardware)

    Returns:
        (o, final_state) — final_state is None if output_final_state=False.
        The final state is returned whether or not an initial_state was
        supplied.

    Raises:
        ValueError: If cu_seqlens is given and the batch size is not 1.
    """
    assert not head_first, "head_first=True is not supported."
    assert v.shape[2] % k.shape[2] == 0, (
        "num_v_heads must be divisible by num_k_heads."
    )

    if cu_seqlens is not None and q.shape[0] != 1:
        raise ValueError(
            f"Batch size must be 1 when using cu_seqlens, got {q.shape[0]}."
        )

    if scale is None:
        scale = k.shape[-1] ** -0.5

    if use_qk_l2norm_in_kernel:
        q = l2norm(q)
        k = l2norm(k)

    # The compiled variants are only used on the dense path; the
    # variable-length path stays uncompiled.
    _fwd = _compiled_fwd if cu_seqlens is None else chunk_gated_delta_rule_fwd
    _bwd = _compiled_bwd if cu_seqlens is None else chunk_gated_delta_rule_bwd

    # Non-differentiable settings captured by the closures below.
    shared = dict(scale=scale, cu_seqlens=cu_seqlens, chunk_size=chunk_size)

    if initial_state is not None:
        @mx.custom_function
        def _fn(q, k, v, g, beta, h0):
            g_out, A, o, _, fs = _fwd(
                q=q, k=k, v=v, g=g, beta=beta,
                initial_state=h0,
                output_final_state=True, output_h=False,
                **shared,
            )
            fs_safe = fs if fs is not None else mx.zeros_like(h0)
            # g_out and A are extra outputs so the VJP can reuse them.
            return o, fs_safe, g_out, A

        @_fn.vjp
        def _fn_vjp(primals, cotangents, outputs):
            q, k, v, g, beta, h0 = primals
            do, dfs, _, _ = cotangents
            _, _, g_out, A = outputs
            dq, dk, dv, db, dg, dh0 = _bwd(
                q=q, k=k, v=v, g=g_out, beta=beta, A=A,
                do=do, dht=dfs,
                initial_state=h0,
                **shared,
            )
            # Gradients follow the order of the primals: g comes before beta.
            return dq, dk, dv, dg, db, dh0

        o, fs, _, _ = _fn(q, k, v, g, beta, initial_state)
    else:
        @mx.custom_function
        def _fn(q, k, v, g, beta):
            g_out, A, o, _, fs = _fwd(
                q=q, k=k, v=v, g=g, beta=beta,
                initial_state=None,
                output_final_state=output_final_state, output_h=False,
                **shared,
            )
            if fs is None:
                # Every custom_function output must be an array, so a
                # placeholder stands in when no final state was requested.
                fs = mx.zeros((1,), dtype=q.dtype)
            return o, fs, g_out, A

        @_fn.vjp
        def _fn_vjp(primals, cotangents, outputs):
            q, k, v, g, beta = primals
            do, dfs, _, _ = cotangents
            _, _, g_out, A = outputs
            dq, dk, dv, db, dg, _ = _bwd(
                q=q, k=k, v=v, g=g_out, beta=beta, A=A,
                do=do,
                # The placeholder's cotangent is meaningless; only a real
                # final state contributes a gradient.
                dht=dfs if output_final_state else None,
                initial_state=None,
                **shared,
            )
            return dq, dk, dv, dg, db

        # Previously the state was discarded here, so output_final_state=True
        # without an initial_state returned None.
        o, fs, _, _ = _fn(q, k, v, g, beta)

    return o, (fs if output_final_state else None)
def l2norm(x: mx.array, dim: int = -1, eps: float = 1e-6) -> mx.array:
    """Rescale ``x`` to unit L2 norm along ``dim``.

    ``eps`` is added to the sum of squares before the reciprocal square root
    is taken, which keeps the result finite for all-zero slices.
    """
    squared = x * x
    sum_of_squares = squared.sum(axis=dim, keepdims=True)
    return x * mx.rsqrt(sum_of_squares + eps)
def pad_and_reshape(
    x: mx.array,
    dim: int,
    chunk_size: int = 64,
) -> mx.array:
    """Zero-pad axis ``dim`` to a multiple of ``chunk_size`` and split it.

    The padding is appended at the end of the axis. The axis is then replaced
    by two axes ``(num_chunks, chunk_size)``, so the result has one more
    dimension than ``x``.

    Args:
        x: Array to chunk.
        dim: Axis to chunk. Negative values count from the last axis.
        chunk_size: Length of each chunk.

    Returns:
        The padded array reshaped to ``[..., num_chunks, chunk_size, ...]``.

    Raises:
        IndexError: If ``dim`` is out of range for ``x``.
    """
    seq_len = x.shape[dim]  # also rejects an out-of-range axis
    ndim = len(x.shape)
    # Normalise to a non-negative axis: with a negative index the slice
    # assignment below would insert the new axes without removing the old one.
    axis = dim % ndim

    pad_size = (chunk_size - seq_len % chunk_size) % chunk_size
    if pad_size > 0:
        pad_width = [(0, 0)] * ndim
        pad_width[axis] = (0, pad_size)
        x = mx.pad(x, pad_width)

    shape = list(x.shape)
    shape[axis : axis + 1] = [-1, chunk_size]
    return x.reshape(shape)
def prepare_chunk_offsets(cu_seqlens: mx.array, chunk_size: int) -> mx.array:
    """Return the cumulative chunk count at every sequence boundary.

    Each sequence length (difference of consecutive ``cu_seqlens`` entries)
    is rounded up to a whole number of chunks. The result is an int32 array
    that starts with 0 and is followed by the running total of those chunk
    counts, so it has the same length as ``cu_seqlens``.
    """
    starts = cu_seqlens[:-1]
    ends = cu_seqlens[1:]
    lengths = ends - starts  # [B]
    chunks_per_seq = (lengths + chunk_size - 1) // chunk_size  # round up
    leading_zero = mx.array([0], dtype=mx.int32)
    running_total = mx.cumsum(chunks_per_seq).astype(mx.int32)
    return mx.concatenate([leading_zero, running_total])
def make_inputs(rng, B, T, Hk, Hv, with_h0, with_bwd=False):
    """Draw one float32 test case from ``rng``.

    Returns ``(q, k, v, g, beta, h0, do, dht)``. ``h0`` is None unless
    ``with_h0``; ``do`` is None unless ``with_bwd``; ``dht`` is None unless
    both are set. The arrays are drawn in exactly that order, so a given
    seed reproduces the same case in both sub-modes.
    """
    def draw(*shape):
        return rng.standard_normal(shape).astype(np.float32)

    state_shape = (B, Hv, K_DIM, V_DIM)

    q = draw(B, T, Hk, K_DIM)
    k = draw(B, T, Hk, K_DIM)
    # Unit-length keys keep the KKT triangular solve stable in float32.
    key_norms = np.linalg.norm(k, axis=-1, keepdims=True) + 1e-8
    k /= key_norms
    v = draw(B, T, Hv, V_DIM)

    # g is log-sigmoid(noise) / 16 (always negative); beta is sigmoid(noise).
    g = -np.log1p(np.exp(-draw(B, T, Hv))) / 16
    beta = 1 / (1 + np.exp(-draw(B, T, Hv)))

    h0 = draw(*state_shape) if with_h0 else None
    do = draw(B, T, Hv, V_DIM) if with_bwd else None
    dht = draw(*state_shape) if (with_bwd and with_h0) else None
    return q, k, v, g, beta, h0, do, dht
cs).swapaxes(2, 3) + for i in range(1, cs): + row = Ac[..., i, :i].clone() + sub = Ac[..., :i, :i].clone() + Ac[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + Ac = Ac + torch.eye(cs, dtype=Ac.dtype) + A = Ac.swapaxes(2, 3).reshape(B, -1, Hv, cs)[:, :T] + Ac2 = pad(A, 1) + kb = pad(k * (beta * g.exp()).unsqueeze(-1), 1) + vb = pad(v * beta.unsqueeze(-1), 1) + w = torch.einsum("bnchd,bndhk->bnchk", Ac2, kb).reshape(B, -1, Hv, K)[:, :T] + u = torch.einsum("bnchd,bndhv->bnchv", Ac2, vb).reshape(B, -1, Hv, V)[:, :T] + kc2 = pad(k, 1, cs); wc = pad(w, 1, cs); uc = pad(u, 1, cs) + gc2 = fill_g(pad(g, 1, cs), T, cs) + hs = (torch.zeros(B, Hv, K, V, dtype=g.dtype) + if h0 is None else h0.to(g.dtype)) + hl, vl = [], [] + for i in range(kc2.shape[1]): + hl.append(hs.clone()) + vn = uc[:, i] - torch.einsum("bchk,bhkv->bchv", wc[:, i], hs) + vl.append(vn) + hs = hs * gc2[:, i, -1, :].exp()[:, :, None, None] + hs = hs + torch.einsum( + "bchk,bchv->bhkv", + kc2[:, i] * (gc2[:, i, -1:, :, None] - gc2[:, i, :, :, None]).exp(), + vn, + ) + hc = torch.stack(hl, 1) + vn = torch.stack(vl, 1).reshape(B, -1, Hv, V)[:, :T] + qr = q.repeat_interleave(Hv // Hk, dim=2) if Hk != Hv else q + kr = k + qc = pad(qr, 1, cs) * scale; kc3 = pad(kr, 1, cs) + vnc = pad(vn, 1, cs); gc3 = pad(g, 1, cs) + mask_o = torch.triu(torch.ones(cs, cs, dtype=torch.bool), diagonal=1) + dec = torch.exp(gc3[:, :, :, None, :] - gc3[:, :, None, :, :]) + dec = dec.masked_fill(mask_o[None, None, :, :, None], 0.0) + at = torch.einsum("bnchk,bndhk->bncdh", qc, kc3) * dec + ai = torch.einsum("bnchk,bnhkv->bnchv", qc * gc3.exp().unsqueeze(-1), hc) + o = (ai + torch.einsum("bncdh,bndhv->bnchv", at, vnc)).reshape(B, -1, Hv, V)[:, :T] + return g, o, A, hs + + for idx, case in enumerate(CASES): + B, T, Hk, Hv = case["B"], case["T"], case["Hk"], case["Hv"] + rng = np.random.default_rng(case["seed"]) + scale = K_DIM ** -0.5 + q, k, v, g, beta, h0, do, dht = make_inputs( + rng, B, T, Hk, Hv, case["h0"], case.get("bwd", 
def _check_mlx(data_dir, tol=0.02):
    """Compare flash_qla_mlx against the reference files in ``data_dir``.

    For every ``*.npy`` case file the MLX forward is compared with the stored
    reference, and the MLX backward too when the file has a ``"bwd"`` entry.
    Finally the gradients from ``mx.grad(chunk_gated_delta_rule)`` are
    compared with a direct call to the backward function.

    Args:
        data_dir: Directory holding the ``case_*.npy`` files written by the
            ``--gen-ref`` mode.
        tol: Largest accepted relative error (max abs difference divided by
            the reference's max abs value) for every compared tensor.

    Returns:
        True when every comparison is within ``tol`` and at least one case
        file was found, otherwise False.
    """
    sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
    import mlx.core as mx
    from flash_qla_mlx import (
        chunk_gated_delta_rule_fwd as mlx_fwd,
        chunk_gated_delta_rule_bwd as mlx_bwd,
        chunk_gated_delta_rule,
    )

    def _mx(a):
        return mx.array(a)

    def _np(a):
        mx.eval(a)
        return np.array(a)

    def rel(got, ref):
        g = np.asarray(got, dtype=np.float64)
        r = np.asarray(ref, dtype=np.float64)
        e = float(np.abs(g - r).max() / (np.abs(r).max() + 1e-8))
        # NaN would compare as "not greater than tol"; report it as inf.
        return e if np.isfinite(e) else float("inf")

    def report(label, errs):
        """Print one result line; return True when every error is within tol."""
        failed = [f"{n}={e:.4f}" for n, e in errs.items() if e > tol]
        status = f"FAIL({', '.join(failed)})" if failed else "PASS"
        print(f" {label} " + " ".join(f"{n}={e:.4f}" for n, e in errs.items())
              + f" {status}")
        return not failed

    files = sorted(f for f in os.listdir(data_dir) if f.endswith(".npy"))
    all_passed = True
    if not files:
        # An empty directory must not count as a passing run.
        print(f" no reference cases found in {data_dir}")
        all_passed = False

    for fname in files:
        # NOTE: allow_pickle=True unpickles the file. The data is written by
        # the --gen-ref mode of this script; do not point this at untrusted
        # directories.
        d = np.load(os.path.join(data_dir, fname), allow_pickle=True).item()
        case = d["case"]
        inp = d["inputs"]
        fwd = d["fwd"]
        q, k, v, g, beta = inp["q"], inp["k"], inp["v"], inp["g"], inp["beta"]
        h0 = inp["h0"]
        B, T, Hk, Hv = case["B"], case["T"], case["Hk"], case["Hv"]
        scale = K_DIM ** -0.5

        g_out, A_out, o_out, _, s_out = mlx_fwd(
            q=_mx(q), k=_mx(k), v=_mx(v), g=_mx(g), beta=_mx(beta),
            scale=scale,
            initial_state=_mx(h0) if h0 is not None else None,
            output_final_state=True,
        )
        mx.eval(g_out, A_out, o_out)
        errs_fwd = {
            "o": rel(_np(o_out), fwd["o"]),
            "g": rel(_np(g_out), fwd["g_cumsum"]),
            "A": rel(_np(A_out), fwd["A"]),
        }
        if h0 is not None and s_out is not None:
            mx.eval(s_out)
            errs_fwd["s"] = rel(_np(s_out), fwd["s"])

        tag = f"B={B} T={T} Hk={Hk} Hv={Hv} h0={case['h0']}"
        all_passed = report(f"[fwd] {tag}", errs_fwd) and all_passed

        if "bwd" in d:
            bwd = d["bwd"]
            dq_a, dk_a, dv_a, db_a, dg_a, dh0_a = mlx_bwd(
                q=_mx(q), k=_mx(k), v=_mx(v), g=g_out, beta=_mx(beta), A=A_out,
                do=_mx(bwd["do"]),
                dht=_mx(bwd["dht"]) if bwd["dht"] is not None else None,
                scale=scale,
                initial_state=_mx(h0) if h0 is not None else None,
            )
            mx.eval(dq_a, dk_a, dv_a, db_a, dg_a)
            errs_bwd = {
                "dq": rel(_np(dq_a), bwd["dq"]),
                "dk": rel(_np(dk_a), bwd["dk"]),
                "dv": rel(_np(dv_a), bwd["dv"]),
                "db": rel(_np(db_a), bwd["db"]),
                "dg": rel(_np(dg_a), bwd["dg"]),
            }
            if case["h0"] and dh0_a is not None:
                mx.eval(dh0_a)
                errs_bwd["dh0"] = rel(_np(dh0_a), bwd["dh0"])
            all_passed = report(f"[bwd] {tag}", errs_bwd) and all_passed

    # Autograd smoke: mx.grad(chunk_gated_delta_rule) == mlx_bwd
    rng = np.random.default_rng(42)
    q2, k2, v2, g2, beta2, _, do2, _ = make_inputs(rng, 1, 128, 2, 2, False, True)
    scale2 = K_DIM ** -0.5
    g_o2, A_o2, _, _, _ = mlx_fwd(
        q=_mx(q2), k=_mx(k2), v=_mx(v2), g=_mx(g2), beta=_mx(beta2),
        scale=scale2, output_final_state=False,
    )
    mx.eval(g_o2, A_o2)
    dq_e, dk_e, dv_e, db_e, dg_e, _ = mlx_bwd(
        q=_mx(q2), k=_mx(k2), v=_mx(v2), g=g_o2, beta=_mx(beta2), A=A_o2,
        do=_mx(do2), scale=scale2,
    )
    mx.eval(dq_e, dk_e, dv_e, db_e, dg_e)
    mdo = _mx(do2)

    def loss_fn(q_, k_, v_, g_, b_):
        o, _ = chunk_gated_delta_rule(q_, k_, v_, g_, b_, scale=scale2)
        return (o * mdo).sum()

    # mx.grad returns gradients in argument order: q, k, v, g, beta.
    dq_ag, dk_ag, dv_ag, dg_ag, db_ag = mx.grad(loss_fn, argnums=(0, 1, 2, 3, 4))(
        _mx(q2), _mx(k2), _mx(v2), _mx(g2), _mx(beta2)
    )
    mx.eval(dq_ag, dk_ag, dv_ag, db_ag, dg_ag)
    ag_errs = {
        "dq": rel(_np(dq_e), _np(dq_ag)),
        "dk": rel(_np(dk_e), _np(dk_ag)),
        "dv": rel(_np(dv_e), _np(dv_ag)),
        # db and dg are compared as well: this is the only check that would
        # notice the custom VJP returning them in the wrong order.
        "db": rel(_np(db_e), _np(db_ag)),
        "dg": rel(_np(dg_e), _np(dg_ag)),
    }
    all_passed = report("[autograd]", ag_errs) and all_passed

    return all_passed
print("=" * 64) + + env = {**os.environ, "KMP_DUPLICATE_LIB_OK": "TRUE"} + + with tempfile.TemporaryDirectory() as tmp: + print("\n[1/2] PyTorch CPU reference...") + r = subprocess.run( + [sys.executable, __file__, "--gen-ref", tmp], + capture_output=True, text=True, env=env, + ) + if r.returncode != 0: + print("STDOUT:", r.stdout) + print("STDERR:", r.stderr) + sys.exit(1) + print(r.stdout.strip()) + + print("\n[2/2] flash_qla_mlx (MLX)...") + r = subprocess.run( + [sys.executable, __file__, "--check-mlx", tmp], + capture_output=True, text=True, env=env, + ) + print(r.stdout.strip()) + if r.returncode != 0: + print("STDERR:", r.stderr) + print("\nSome tests FAILED.") + sys.exit(1) + + print("\nAll tests passed.") + + +if __name__ == "__main__": + main()