feat(mlx): port FlashQLA to Apple Silicon #12
Open
0xClandestine wants to merge 11 commits into
Conversation
Two-subprocess design compares flash_qla_mlx against a PyTorch CPU reference across 7 cases covering GQA, initial state, T % 64 != 0, forward-only, and forward+backward. All pass with 0.0000 relative error.

Also fixes two bugs in chunk.py:
- Replace mx.flip (absent in MLX 0.31.2) with [::-1] slice reversal
- Fix broadcast shape in _chunk_dqkwg_bwd (5D vs 4D mismatch)
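A tiny illustration of the flip replacement (array contents are arbitrary; the axis chosen is illustrative):

```python
import mlx.core as mx

x = mx.arange(24).reshape(2, 3, 4)

# mx.flip is not available in MLX 0.31.2; a negative-step slice gives the
# same reversal along the chosen axis.
rev_last = x[..., ::-1]      # reverse the last axis
rev_time = x[:, ::-1, :]     # reverse axis 1
```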
…lation

Each `x[..., i, :i] = new_val` scatter-update was triggering a full copy-on-write of the [B, N, H, C, C] array — 63 copies per forward pass. Replace with a row-by-row matmul accumulation that builds the result functionally, eliminating the redundant allocations and making the computation graph safe for mx.compile tracing.
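A minimal sketch of the shape of this rewrite, assuming a forward-substitution-style update and illustrative names (not the PR's exact code): rows are accumulated in a Python list and concatenated once, so nothing is mutated in place.

```python
import mlx.core as mx

# In-place pattern from the original port (illustrative):
#     for i in range(1, C):
#         x[..., i, :i] = new_val   # copy-on-write of the whole [B, N, H, C, C] array
#
# Functional rewrite: build each row with a matmul over the rows already
# produced, then concatenate once, keeping the graph traceable by mx.compile.
def solve_lower_triangular(L):
    """Hypothetical forward substitution for (I + L)^-1, L strictly lower-triangular."""
    C = L.shape[-1]
    eye = mx.eye(C, dtype=L.dtype)
    rows = [mx.broadcast_to(eye[0:1, :], (*L.shape[:-2], 1, C))]
    for i in range(1, C):
        acc = mx.concatenate(rows, axis=-2)                      # rows 0..i-1
        e_i = mx.broadcast_to(eye[i:i + 1, :], (*L.shape[:-2], 1, C))
        rows.append(e_i - L[..., i:i + 1, :i] @ acc[..., :i, :])  # row i
    return mx.concatenate(rows, axis=-2)
```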
For fixed-shape (non-variable-length) inputs, wrap fwd and bwd with mx.compile so MLX traces through all Python loops once per unique shape and unrolls them into a static Metal graph. This eliminates per-step Python dispatch overhead for both the 63-iteration _kkt_solve loop and the N-chunk _chunk_gdr_fwd/bwd loop. cu_seqlens paths fall back to the uncompiled functions because pack/unpack call .item() on tensor values during execution, which is incompatible with tracing.
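A sketch of the dispatch pattern described above, with a stand-in loop body rather than the real chunked forward:

```python
import mlx.core as mx

def _fwd_impl(q, g):
    # Stand-in for the chunked forward: a Python loop over chunks that
    # mx.compile unrolls into a single static graph at trace time.
    chunk = 64
    out = []
    for n in range(q.shape[1] // chunk):
        sl = slice(n * chunk, (n + 1) * chunk)
        out.append(q[:, sl] * mx.exp(g[:, sl]))
    return mx.concatenate(out, axis=1)

_fwd_compiled = mx.compile(_fwd_impl)

def fwd(q, g, cu_seqlens=None):
    if cu_seqlens is None:
        # Fixed-shape path: traced once per unique input shape, then replayed
        # as one Metal graph with no per-iteration Python dispatch.
        return _fwd_compiled(q, g)
    # Variable-length path: pack/unpack call .item() during execution, which
    # cannot be traced, so the uncompiled function is used instead.
    return _fwd_impl(q, g)
```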
Replace per-batch-element Python loops (.item() calls) with pure MLX gather operations: unpack uses arange+clip+fancy-index, pack uses a broadcast comparison to map each packed position to its (batch, time) source, and prepare_chunk_offsets uses cumsum instead of a Python accumulation loop. Reduces kernel dispatch count from O(B) to O(1) per call. Still requires one .item() call for the output shape (sum_T / max_len), which is an unavoidable consequence of dynamic sequence lengths.
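A minimal sketch of the gather-based unpack, assuming a [sum_T, H, D] packed layout and a Python-int max_len (the one remaining .item()); names are illustrative, and pack/prepare_chunk_offsets are not shown:

```python
import mlx.core as mx

def unpack(packed, cu_seqlens, max_len):
    """Gather-based unpack: [sum_T, H, D] -> [B, max_len, H, D] (hypothetical)."""
    B = cu_seqlens.shape[0] - 1
    starts = cu_seqlens[:-1][:, None]              # [B, 1]
    ends = cu_seqlens[1:][:, None]                 # [B, 1]
    pos = starts + mx.arange(max_len)[None, :]     # [B, max_len] candidate source indices
    idx = mx.minimum(pos, ends - 1)                # clamp padding reads to a valid index
    out = packed[idx]                              # one gather instead of a per-batch loop
    valid = (pos < ends).astype(out.dtype)         # zero out the padded positions
    return out * valid.reshape(B, max_len, *([1] * (out.ndim - 2)))
```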
…enchmark

chunk_size is now a parameter of chunk_gated_delta_rule, chunk_gated_delta_rule_fwd, and chunk_gated_delta_rule_bwd (default 64). This enables hardware-specific tuning — Apple Silicon benchmarks show cs=32 is 13-27% faster for Qwen-27B forward (T>=1024) while cs=64 remains better for backward/training.

Also fixes two pre-existing bugs where pad_and_reshape(A, dim=1) in _w_u_fwd and _chunk_wy_bwd omitted the chunk_size argument, causing silent shape mismatches for any chunk_size other than 64.

benchmark/bench_mlx.py sweeps sequence length and chunk_size for Apple Silicon; run with --B 1 --Hk 8 --Hv 32 --K 128 --V 128 for Qwen-27B dimensions.
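A hypothetical pad_and_reshape illustrating why chunk_size has to reach the helper (the real signature and padding behaviour may differ):

```python
import mlx.core as mx

def pad_and_reshape(x, dim=1, chunk_size=64):
    """Pad axis `dim` to a multiple of chunk_size and split it into
    (num_chunks, chunk_size). Calling it without chunk_size was the source
    of the silent shape mismatch for chunk_size != 64."""
    T = x.shape[dim]
    pad = (-T) % chunk_size
    if pad:
        widths = [(0, 0)] * x.ndim
        widths[dim] = (0, pad)
        x = mx.pad(x, widths)
    new_shape = (*x.shape[:dim], (T + pad) // chunk_size, chunk_size, *x.shape[dim + 1:])
    return x.reshape(*new_shape)
```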
…currence
The recurrence h_i = A_i @ h_{i-1} + b_i supports an associative parallel
prefix scan, but composing operators costs O(K^3) vs O(K^2) sequential.
For K=128 (Qwen-27B) this is 512-1280x more FLOPs across practical sequence
lengths. Sequential with mx.compile unrolling is the correct algorithm.
--compare-sdpa flag adds a side-by-side table of FlashQLA vs mx.fast.scaled_dot_product_attention (GQA causal) at Qwen-27B dimensions. Dimension mapping: FlashQLA Hk->N_kv, Hv->N_q so both produce the same (B, T, Hv, V) output volume.
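A sketch of the SDPA side of the comparison under that mapping, with random inputs at Qwen-27B dimensions (layout and call are illustrative of the reference path, not the benchmark's exact code):

```python
import mlx.core as mx

B, T, Hk, Hv, K, V = 1, 2048, 8, 32, 128, 128
q = mx.random.normal((B, Hv, T, K))   # SDPA layout: [B, n_heads, T, head_dim]
k = mx.random.normal((B, Hk, T, K))   # Hk -> n_kv heads
v = mx.random.normal((B, Hk, T, V))
# GQA: 32 query heads share 8 KV heads, output volume matches (B, T, Hv, V).
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=K ** -0.5, mask="causal")
```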
For GQA configurations (Hk < Hv), the original implementation expands k from Hk to Hv heads before the einsum, computing Hv independent gram products when only Hk are distinct. Factoring out the shared k·k gram matrix (computed at Hk heads) and combining with the Hv-dimensional beta/decay via reshape+broadcast avoids materialising the gqa_ratio-x expanded k tensor and reduces einsum FLOPs by gqa_ratio. For Qwen-27B dims (Hk=8, Hv=32, K=128) this gives a 4× cheaper kkt_fwd einsum, yielding 5-9% forward throughput improvement at T≥1024.
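A minimal sketch of the factoring with illustrative shapes and names; it assumes query heads are grouped contiguously per KV head, and the real _kkt_fwd also applies decay terms and a strictly-lower mask:

```python
import mlx.core as mx

def kkt_gqa(k, beta, gqa_ratio):
    """k: [B, N, Hk, C, K] chunked keys; beta: [B, N, Hv, C], Hv = Hk * gqa_ratio."""
    B, N, Hk, C, K = k.shape
    # Shared gram matrix, computed once per KV head instead of per query head.
    gram = mx.einsum("bnhck,bnhdk->bnhcd", k, k)      # [B, N, Hk, C, C]
    gram = gram[:, :, :, None]                        # [B, N, Hk, 1, C, C]
    # Reshape the Hv-dimensional gate into (Hk, gqa_ratio) groups and broadcast,
    # instead of materialising a gqa_ratio-x expanded copy of k.
    beta = beta.reshape(B, N, Hk, gqa_ratio, C)
    out = beta[..., :, None] * gram                   # [B, N, Hk, gqa_ratio, C, C]
    return out.reshape(B, N, Hk * gqa_ratio, C, C)    # [B, N, Hv, C, C]
```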
Replaces the 63-iteration Python forward-substitution loop in _kkt_solve
with a single mx.fast.metal_kernel dispatch. Each threadgroup handles
one (B,N,H) matrix; C parallel threads each solve one column of
(I+L)^{-1} independently, so no threadgroup_barrier is required.
Also fixes two bugs found during development:
- grid=(batch_total * C) not (batch_total): grid is total threads, not
threadgroups, so number of threadgroups = grid.x / threadgroup.x
- mx.contiguous() required after swapaxes before reshape: non-contiguous
dims cannot be merged without an explicit copy
End-to-end improvement (B=1 Hk=8 Hv=32 K=128 V=128 cs=64):
T=2048: 253→335 Ktok/s fwd (+32%)
T=8192: 241→331 Ktok/s fwd (+37%)
T=16384: 236→339 Ktok/s fwd (+44%)
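For reference, a condensed sketch of the column-parallel dispatch pattern, assuming a flat row-major [batch_total, C, C] float32 input with C baked into the source (the PR's kernel instead reads the strided [B*N, C, H, C] layout):

```python
import mlx.core as mx

C = 64  # chunk size, baked into the Metal source for this sketch

_SOURCE = f"""
    // One threadgroup per [C, C] matrix; thread `col` solves column `col` of (I + L)^-1.
    constexpr int C = {C};
    uint mat = threadgroup_position_in_grid.x;
    uint col = thread_position_in_threadgroup.x;
    const device float* Lm = L + mat * C * C;
    device float* Xm = out + mat * C * C;
    // Threadgroup scratch sidesteps register spills from large thread-local arrays.
    threadgroup float x[{C} * {C}];
    threadgroup float* xc = x + col * C;
    for (int i = 0; i < C; ++i) xc[i] = (i == (int)col) ? 1.0f : 0.0f;
    // Forward substitution; columns are data-independent, so no barrier is needed.
    for (int i = (int)col + 1; i < C; ++i) {{
        float acc = 0.0f;
        for (int k = (int)col; k < i; ++k) acc += Lm[i * C + k] * xc[k];
        xc[i] = -acc;
    }}
    for (int i = 0; i < C; ++i) Xm[i * C + col] = xc[i];
"""

_kernel = mx.fast.metal_kernel(
    name="kkt_solve_sketch", input_names=["L"], output_names=["out"], source=_SOURCE
)

def kkt_solve(L):
    """L: [batch_total, C, C] strictly lower-triangular, float32."""
    batch_total = L.shape[0]
    (X,) = _kernel(
        inputs=[L],
        grid=(batch_total * C, 1, 1),   # grid is total threads, not threadgroups
        threadgroup=(C, 1, 1),          # C threads solve the C columns in parallel
        output_shapes=[L.shape],
        output_dtypes=[L.dtype],
    )
    return X
```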
…rhead

Three optimizations to chunk.py:
1. KKT solve layout copies: the Metal kernel now uses strided row access over the natural [B*N, C, H, C] layout produced by pad_and_reshape, so _kkt_solve no longer needs swapaxes+contiguous+reshape before dispatch or swapaxes+contiguous after. Kernel cache keyed by (chunk_size, H).
2. Mask caching: _get_mask(size, diagonal) caches mx.triu(mx.ones(...)) in a module-level dict, replacing five separate constructions across _kkt_fwd, _chunk_o_fwd, _chunk_dv_bwd, _chunk_dqkwg_bwd, _chunk_wy_bwd.
3. Backward list overhead: replaced dh_list.insert(0, ...) [O(N) per step, O(N²) total] with append + [::-1] reversal, and folded the dv update into a single expression per iteration.
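A sketch of the mask cache described in item 2 (dtype and naming are illustrative):

```python
import mlx.core as mx

# Module-level cache keyed by (size, diagonal); built once, reused by all
# fwd/bwd helpers instead of reconstructing the mask in each of them.
_MASK_CACHE = {}

def _get_mask(size, diagonal):
    key = (size, diagonal)
    if key not in _MASK_CACHE:
        _MASK_CACHE[key] = mx.triu(mx.ones((size, size), dtype=mx.bool_), k=diagonal)
    return _MASK_CACHE[key]
```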
Summary
This PR replaces CUDA-shaped implementation patterns with MLX-native equivalents, fixes latent bugs exposed by the refactor, and adds an Apple Silicon benchmark.
Changes
Performance
- `_kkt_solve` Metal kernel — replaces the 63-iteration Python forward-substitution loop with a single `mx.fast.metal_kernel` dispatch. Each threadgroup handles one (B,N,H) matrix; C parallel threads each solve one column of (I+L)⁻¹ independently (columns are data-independent so no barrier needed). Threadgroup memory is used to avoid a register-spill bug in Metal's compiler for large thread-local arrays at C=64. 2.2× faster for `_kkt_solve` in isolation.
- Strided access over the natural `[B*N, C, H, C]` layout from `pad_and_reshape`, removing 2× `swapaxes` + 2× `mx.contiguous` per forward call.
- `_kkt_fwd` — for GQA configs (Hk < Hv), avoids expanding `k` by `gqa_ratio` before the einsum; computes the shared gram matrix at Hk heads and combines with Hv-dimensional beta/decay via reshape+broadcast (zero-copy). Reduces einsum FLOPs by `gqa_ratio` (4× for Qwen-27B).
- `_kkt_solve` functional rewrite — eliminated 63 full `[B, N, H, C, C]` copy-on-write allocations per forward pass by replacing in-place scatter-updates with a matmul-based row accumulation.
- `mx.compile` for fwd/bwd — wraps `chunk_gated_delta_rule_fwd`/`bwd` in `mx.compile` for the common `cu_seqlens=None` path, unrolling all Python loops into a static Metal graph paid once per unique shape.
- `pack`/`unpack`/`prepare_chunk_offsets` — replaced per-batch-element Python loops and `.item()` calls with MLX gather ops and `cumsum`, reducing kernel dispatch count from O(B) to O(1) per call.
- `_get_mask(size, diagonal)` caches `mx.triu(mx.ones(...))` in a module-level dict, replacing five separate constructions across fwd/bwd helpers.
- Replaced `dh_list.insert(0, ...)` (O(N²) total) with append + `[::-1]` reversal; dv update folded into one expression per iteration.

Bug fixes
- `pad_and_reshape(A, dim=1)` missing `chunk_size` in `_w_u_fwd` and `_chunk_wy_bwd` — both defaulted to 64 regardless of the actual chunk size, causing silent shape mismatches for any `chunk_size != 64`

API
- `chunk_size` exposed as a parameter on `chunk_gated_delta_rule`, `chunk_gated_delta_rule_fwd`, and `chunk_gated_delta_rule_bwd` (default 64, backward-compatible)

Benchmark
- `benchmark/bench_mlx.py` — sweeps sequence lengths and chunk sizes on Apple Silicon; `--compare-sdpa` adds a side-by-side against `mx.fast.scaled_dot_product_attention` (GQA causal)

Benchmark results (Qwen-27B dims, B=1 Hk=8 Hv=32 K=128 V=128, float32, causal)
chunk_size sweep (forward):
`cs=64` is 8–12% faster than `cs=32` for T≥2048.

Cumulative improvement from baseline (initial MLX port, T=2048):
FlashQLA vs `mx.fast.scaled_dot_product_attention` (chunk_size=64, best chunk):

Crossover at ~4K tokens. FlashQLA is 5.0× faster than softmax attention at T=16384, with O(T) vs O(T²) memory.
Notes on parallel scan
The inter-chunk recurrence `h_i = A_i @ h_{i-1} + b_i` supports an associative parallel prefix scan, but composing operators costs O(K³) vs O(K²) sequential. For K=128 this is 512–1280× more FLOPs across practical sequence lengths — sequential is the correct algorithm and is documented as such in the code.

Sub-16K throughput analysis
For T<2048, throughput is dominated by Metal kernel dispatch overhead (~8–10 µs per dispatch). At T=512, cs=64 there are ~220 Metal kernel dispatches in the compiled graph (1 from the kkt_solve Metal kernel + 2×N from the inter-chunk recurrence + bookkeeping) — down from ~280 before, where the kkt_solve loop alone contributed 63 dispatches. The remaining gap between short and long sequences is dominated by the 2×N per-chunk recurrence dispatches; closing it entirely would require fusing the full intra-chunk pipeline into a single custom Metal kernel.
Test plan
python tests/test_mlx.py