feat(mlx): port FlashQLA to Apple Silicon #12

Open

0xClandestine wants to merge 11 commits into QwenLM:main from 0xClandestine:mlx-native-optimizations

Conversation

0xClandestine commented May 3, 2026

Summary

This PR replaces CUDA-shaped implementation patterns with MLX-native equivalents, fixes latent bugs exposed by the refactor, and adds an Apple Silicon benchmark.

Changes

Performance

  • _kkt_solve Metal kernel — replaces the 63-iteration Python forward-substitution loop with a single mx.fast.metal_kernel dispatch. Each threadgroup handles one (B,N,H) matrix; C parallel threads each solve one column of (I+L)⁻¹ independently (columns are data-independent so no barrier needed). Threadgroup memory is used to avoid a register-spill bug in Metal's compiler for large thread-local arrays at C=64. 2.2× faster for _kkt_solve in isolation.
  • KKT solve layout copies eliminated — kernel now uses strided row access over the natural [B*N, C, H, C] layout from pad_and_reshape, removing 2× swapaxes + 2× mx.contiguous per forward call.
  • GQA gram matrix optimisation in _kkt_fwd — for GQA configs (Hk < Hv), avoids expanding k by gqa_ratio before the einsum; computes the shared gram matrix at Hk heads and combines with Hv-dimensional beta/decay via reshape+broadcast (zero-copy). Reduces einsum FLOPs by gqa_ratio (4× for Qwen-27B).
  • _kkt_solve functional rewrite — eliminated 63 full [B, N, H, C, C] copy-on-write allocations per forward pass by replacing in-place scatter-updates with a matmul-based row accumulation.
  • mx.compile for fwd/bwd — wraps chunk_gated_delta_rule_fwd/bwd in mx.compile for the common cu_seqlens=None path, unrolling all Python loops into a static Metal graph whose trace cost is paid once per unique shape.
  • Vectorized pack/unpack/prepare_chunk_offsets — replaced per-batch-element Python loops and .item() calls with MLX gather ops and cumsum, reducing kernel dispatch count from O(B) to O(1) per call.
  • Triu mask caching — _get_mask(size, diagonal) caches mx.triu(mx.ones(...)) in a module-level dict, replacing five separate constructions across fwd/bwd helpers.
  • Backward list overhead — replaced dh_list.insert(0, ...) (O(N²) total) with append + [::-1] reversal; the dv update is folded into one expression per iteration.
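The column-independence that the _kkt_solve kernel exploits can be illustrated with a small NumPy sketch (NumPy for illustration only; the actual kernel is Metal via mx.fast.metal_kernel). Each "thread" forward-substitutes one column of (I+L)⁻¹ with no dependence on the others, which is why no barrier is needed:

```python
import numpy as np

# Illustrative sketch of the per-column forward substitution.
# L is strictly lower-triangular, so (I + L) has unit diagonal and
# each column of its inverse can be solved independently.
C = 8
rng = np.random.default_rng(0)
L = np.tril(rng.standard_normal((C, C)), k=-1)  # strictly lower-triangular
A = np.eye(C) + L

X = np.zeros((C, C))
for col in range(C):          # independent across "threads" — no barrier
    e = np.eye(C)[:, col]
    x = np.zeros(C)
    for i in range(C):        # sequential within one thread
        x[i] = e[i] - L[i, :i] @ x[:i]
    X[:, col] = x

assert np.allclose(A @ X, np.eye(C))  # X is (I + L)^-1
```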

Bug fixes

  • pad_and_reshape(A, dim=1) missing chunk_size in _w_u_fwd and _chunk_wy_bwd — both defaulted to 64 regardless of the actual chunk size, causing silent shape mismatches for any chunk_size != 64

API

  • chunk_size exposed as a parameter on chunk_gated_delta_rule, chunk_gated_delta_rule_fwd, and chunk_gated_delta_rule_bwd (default 64, backward-compatible)

Benchmark

  • benchmark/bench_mlx.py — sweeps sequence lengths and chunk sizes on Apple Silicon; --compare-sdpa adds a side-by-side against mx.fast.scaled_dot_product_attention (GQA causal)

Benchmark results (Qwen-27B dims, B=1 Hk=8 Hv=32 K=128 V=128, float32, causal)

chunk_size sweep (forward):

     T     cs=32 Ktok/s   cs=64 Ktok/s
   512          ~210-228       ~240-279
  1024           280-283        292-327
  2048           306            332-334
  4096           315-316        347-349
  8192           315-320        340-350
 16384           316-318        ~340-345

cs=64 is 8–12% faster than cs=32 for T≥2048.

Cumulative improvement from baseline (initial MLX port, T=2048):

  Baseline (initial port)          ~240-254 Ktok/s
  + GQA gram matrix optimisation   ~303-316 Ktok/s  (+26%)
  + Metal kernel for _kkt_solve    ~332-336 Ktok/s  (+7%)
  + layout copies / misc           ~332-349 Ktok/s  (+2-4%)
  ─────────────────────────────────────────────────────────
  Total                            ~345 Ktok/s      (+44%)

FlashQLA vs mx.fast.scaled_dot_product_attention (chunk_size=64, best chunk):

     T        FlashQLA            SDPA     QLA ms   SDPA ms   speedup
   512      ~243 Ktok/s      ~1189 Ktok/s    ~1.8     ~0.4     0.24x
  1024      ~303 Ktok/s       ~890 Ktok/s    ~3.4     ~1.2     0.34x
  2048       311 Ktok/s         548 Ktok/s    6.6      3.7     0.57x
  4096       329 Ktok/s         282 Ktok/s   12.5     14.5     1.17x
  8192       322 Ktok/s         146 Ktok/s   25.5     56.3     2.21x
 16384       331 Ktok/s          66 Ktok/s   49.5    247.8     5.00x

Crossover at ~4K tokens. FlashQLA is 5.0× faster than softmax attention at T=16384, with O(T) vs O(T²) memory.

Notes on parallel scan

The inter-chunk recurrence h_i = A_i @ h_{i-1} + b_i supports an associative parallel prefix scan, but composing operators costs O(K³) vs O(K²) sequential. For K=128 this is 512–1280× more FLOPs across practical sequence lengths — sequential is the correct algorithm and is documented as such in the code.
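A back-of-envelope version of the FLOP argument, under the simplifying assumptions that the state is a K-vector and only the dominant matmul terms are counted:

```python
# Rough cost comparison for the recurrence h_i = A_i @ h_{i-1} + b_i
# (simplified sketch: state treated as a K-vector, constants ignored).
K = 128  # head dimension (Qwen-27B)

# One sequential step applies A_i to the state: O(K^2).
seq_step_flops = 2 * K * K

# Composing two affine operators for a parallel scan,
# (A2, b2) ∘ (A1, b1) = (A2 @ A1, A2 @ b1 + b2),
# is dominated by the K x K matmul: O(K^3).
compose_flops = 2 * K ** 3

assert compose_flops // seq_step_flops == K  # each composition is 128x a step
```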

Sub-16K throughput analysis

For T<2048, throughput is Metal kernel dispatch overhead dominated (~8–10µs/dispatch). At T=512, cs=64 there are ~220 Metal kernel dispatches in the compiled graph (1 from the kkt_solve Metal kernel + 2×N from the inter-chunk recurrence + bookkeeping) — down from ~280 before (where the kkt_solve loop alone contributed 63 dispatches). The remaining gap between short and long sequences is dominated by the 2×N per-chunk recurrence dispatches; closing it entirely would require fusing the full intra-chunk pipeline into a single custom Metal kernel.

Test plan

  • All 11 correctness tests pass with 0.0000 relative error
  • python tests/test_mlx.py

Two-subprocess design compares flash_qla_mlx against a PyTorch CPU
reference across 7 cases covering GQA, initial state, T % 64 != 0,
forward-only, and forward+backward. All pass with 0.0000 relative error.

Also fixes two bugs in chunk.py:
- Replace mx.flip (absent in MLX 0.31.2) with [::-1] slice reversal
- Fix broadcast shape in _chunk_dqkwg_bwd (5D vs 4D mismatch)
…lation

Each `x[..., i, :i] = new_val` scatter-update was triggering a full
copy-on-write of the [B, N, H, C, C] array — 63 copies per forward pass.
Replace with a row-by-row matmul accumulation that builds the result
functionally, eliminating the redundant allocations and making the
computation graph safe for mx.compile tracing.
For fixed-shape (non-variable-length) inputs, wrap fwd and bwd with
mx.compile so MLX traces through all Python loops once per unique shape
and unrolls them into a static Metal graph.  This eliminates per-step
Python dispatch overhead for both the 63-iteration _kkt_solve loop and
the N-chunk _chunk_gdr_fwd/bwd loop.

cu_seqlens paths fall back to the uncompiled functions because pack/unpack
call .item() on tensor values during execution, which is incompatible
with tracing.
Replace per-batch-element Python loops (.item() calls) with pure MLX
gather operations: unpack uses arange+clip+fancy-index, pack uses a
broadcast comparison to map each packed position to its (batch, time)
source, and prepare_chunk_offsets uses cumsum instead of a Python
accumulation loop.

Reduces kernel dispatch count from O(B) to O(1) per call.  Still
requires one .item() call for the output shape (sum_T / max_len),
which is an unavoidable consequence of dynamic sequence lengths.
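A NumPy sketch of the cumsum-based prepare_chunk_offsets (the MLX version is analogous with mx.cumsum; the exact signature here is assumed for illustration):

```python
import numpy as np

# Each sequence contributes ceil(length / chunk_size) chunks; cumsum of
# those counts gives the chunk offset of every sequence with no Python loop.
def prepare_chunk_offsets(cu_seqlens, chunk_size):
    lengths = np.diff(cu_seqlens)                      # per-sequence token counts
    n_chunks = -(-lengths // chunk_size)               # ceil division, vectorised
    return np.concatenate([[0], np.cumsum(n_chunks)])

offsets = prepare_chunk_offsets(np.array([0, 100, 164, 300]), 64)
# lengths [100, 64, 136] -> chunk counts [2, 1, 3] -> offsets [0, 2, 3, 6]
assert offsets.tolist() == [0, 2, 3, 6]
```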
…enchmark

chunk_size is now a parameter of chunk_gated_delta_rule,
chunk_gated_delta_rule_fwd, and chunk_gated_delta_rule_bwd (default 64).
This enables hardware-specific tuning — Apple Silicon benchmarks show
cs=32 is 13-27% faster for Qwen-27B forward (T>=1024) while cs=64
remains better for backward/training.

Also fixes two pre-existing bugs where pad_and_reshape(A, dim=1) in
_w_u_fwd and _chunk_wy_bwd omitted the chunk_size argument, causing
silent shape mismatches for any chunk_size other than 64.

benchmark/bench_mlx.py sweeps sequence length and chunk_size for
Apple Silicon; run with --B 1 --Hk 8 --Hv 32 --K 128 --V 128 for
Qwen-27B dimensions.
…currence

The recurrence h_i = A_i @ h_{i-1} + b_i supports an associative parallel
prefix scan, but composing operators costs O(K^3) vs O(K^2) sequential.
For K=128 (Qwen-27B) this is 512-1280x more FLOPs across practical sequence
lengths.  Sequential with mx.compile unrolling is the correct algorithm.
--compare-sdpa flag adds a side-by-side table of FlashQLA vs
mx.fast.scaled_dot_product_attention (GQA causal) at Qwen-27B
dimensions.  Dimension mapping: FlashQLA Hk->N_kv, Hv->N_q so
both produce the same (B, T, Hv, V) output volume.

0xClandestine changed the title from "perf: MLX-native optimizations for Apple Silicon" to "feat(mlx): port FlashQLA to Apple Silicon" on May 3, 2026.

For GQA configurations (Hk < Hv), the original implementation expands k
from Hk to Hv heads before the einsum, computing Hv independent gram
products when only Hk are distinct.  Factoring out the shared k·k gram
matrix (computed at Hk heads) and combining with the Hv-dimensional
beta/decay via reshape+broadcast avoids materialising the gqa_ratio-x
expanded k tensor and reduces einsum FLOPs by gqa_ratio.

For Qwen-27B dims (Hk=8, Hv=32, K=128) this gives a 4× cheaper kkt_fwd
einsum, yielding 5-9% forward throughput improvement at T≥1024.
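A NumPy sketch of the factoring (shapes and sizes assumed for illustration; MLX can broadcast without materialising the expanded tensor, which NumPy's reshape here does not):

```python
import numpy as np

# GQA gram-matrix factoring: compute Hk gram matrices instead of Hv.
B, N, Hk, C, K, ratio = 1, 2, 2, 4, 8, 4
Hv = Hk * ratio
rng = np.random.default_rng(0)
k = rng.standard_normal((B, N, Hk, C, K))

# Baseline: expand k to Hv heads, then compute Hv gram products —
# only Hk of them are distinct.
k_exp = np.repeat(k, ratio, axis=2)                      # [B, N, Hv, C, K]
gram_ref = np.einsum('bnhik,bnhjk->bnhij', k_exp, k_exp)

# Optimised: Hk gram matrices, broadcast up to Hv heads afterwards.
gram = np.einsum('bnhik,bnhjk->bnhij', k, k)             # Hk heads only
gram_bcast = np.broadcast_to(
    gram[:, :, :, None], (B, N, Hk, ratio, C, C)
).reshape(B, N, Hv, C, C)

assert np.allclose(gram_ref, gram_bcast)  # same result, 1/ratio the einsum FLOPs
```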
Replaces the 63-iteration Python forward-substitution loop in _kkt_solve
with a single mx.fast.metal_kernel dispatch.  Each threadgroup handles
one (B,N,H) matrix; C parallel threads each solve one column of
(I+L)^{-1} independently, so no threadgroup_barrier is required.

Also fixes two bugs found during development:
- grid=(batch_total * C) not (batch_total): grid is total threads, not
  threadgroups, so number of threadgroups = grid.x / threadgroup.x
- mx.contiguous() required after swapaxes before reshape: non-contiguous
  dims cannot be merged without an explicit copy

End-to-end improvement (B=1 Hk=8 Hv=32 K=128 V=128 cs=64):
  T=2048:  253→335 Ktok/s fwd  (+32%)
  T=8192:  241→331 Ktok/s fwd  (+37%)
  T=16384: 236→339 Ktok/s fwd  (+44%)
…rhead

Three optimizations to chunk.py:

1. KKT solve layout copies: the Metal kernel now uses strided row access
   over the natural [B*N, C, H, C] layout produced by pad_and_reshape, so
   _kkt_solve no longer needs swapaxes+contiguous+reshape before dispatch
   or swapaxes+contiguous after. Kernel cache keyed by (chunk_size, H).

2. Mask caching: _get_mask(size, diagonal) caches mx.triu(mx.ones(...))
   in a module-level dict, replacing five separate constructions across
   _kkt_fwd, _chunk_o_fwd, _chunk_dv_bwd, _chunk_dqkwg_bwd, _chunk_wy_bwd.

3. Backward list overhead: replaced dh_list.insert(0, ...) [O(N) per
   step, O(N²) total] with append + [::-1] reversal, and folded the dv
   update into a single expression per iteration.
