Feature/expose metal flashattn kv #97
matrixsoftwarelimited wants to merge 3 commits into tattn:main
Conversation
Adds four new parameters to `LlamaClient.Parameter` so callers can control Metal GPU offload, Flash Attention, and KV cache quantization on Apple platforms — knobs that were previously baked into llama.cpp defaults and not reachable from Swift.

* `nGpuLayers: Int = -1` — number of model layers to offload to the GPU. `-1` means "all", which is the right call on Apple Silicon's unified memory. Wired through `Model.init` to `llama_model_params.n_gpu_layers`. Simulator override (forced 0) is preserved.
* `flashAttention: Bool = true` — toggles Flash Attention. llama.cpp `b8851` exposes this as a tri-state enum (`auto/disabled/enabled`); the bool maps to explicit `enabled`/`disabled` so behavior is deterministic.
* `kvCacheTypeK` / `kvCacheTypeV: KVCacheType = .f16` — quantization for the KV cache, mapped to `ggml_type` (`GGML_TYPE_F16/Q8_0/Q4_0`). `q8_0` roughly halves cache memory at negligible quality cost, allowing larger contexts at the same RAM footprint.

The `KVCacheType` Swift enum is added to `LlamaClient` as a public type so callers don't need to import the C header directly.

Backwards compatible: defaults match prior behavior. Old call sites compile unchanged.
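For orientation, here is a minimal sketch (not the PR's exact code) of how these knobs could be wired into llama.cpp's C structs. The `Parameter` and `KVCacheType` shapes follow the description above; the `flash_attn_type` field and the `LLAMA_FLASH_ATTN_TYPE_*` constants are taken from recent llama.cpp headers and may be named slightly differently in the pinned `b8851` build.

```swift
// Sketch only: assumes the llama.cpp C module is imported (e.g. `import llama`).

public enum KVCacheType: Sendable {
    case f16, q8_0, q4_0

    /// Map the Swift-facing enum onto ggml's cache element types.
    var ggmlType: ggml_type {
        switch self {
        case .f16:  return GGML_TYPE_F16
        case .q8_0: return GGML_TYPE_Q8_0   // roughly half the cache memory of f16
        case .q4_0: return GGML_TYPE_Q4_0
        }
    }
}

/// Illustrative subset of `LlamaClient.Parameter` showing the new knobs and defaults.
struct Parameter {
    var nGpuLayers: Int = -1              // -1 = offload all layers
    var flashAttention: Bool = true
    var kvCacheTypeK: KVCacheType = .f16
    var kvCacheTypeV: KVCacheType = .f16
}

/// Wire the knobs into the C param structs before creating the model/context.
func apply(_ parameter: Parameter,
           model modelParams: inout llama_model_params,
           context contextParams: inout llama_context_params) {
    #if targetEnvironment(simulator)
    modelParams.n_gpu_layers = 0          // Metal offload stays disabled in the simulator
    #else
    modelParams.n_gpu_layers = parameter.nGpuLayers == -1 ? 999 : Int32(parameter.nGpuLayers)
    #endif

    // Bool -> tri-state: pick an explicit value so behavior is deterministic.
    contextParams.flash_attn_type = parameter.flashAttention
        ? LLAMA_FLASH_ATTN_TYPE_ENABLED
        : LLAMA_FLASH_ATTN_TYPE_DISABLED

    contextParams.type_k = parameter.kvCacheTypeK.ggmlType
    contextParams.type_v = parameter.kvCacheTypeV.ggmlType
}
```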
`Context.clear()` now also resets the prefill batch alongside the KV cache. Without this, a generation that ends via an external stop condition (e.g. stop sequences enforced at the consumer level) leaves the batch struct with `n_tokens > 0`. The Generator does `batch.add` per emitted token and relies on the NEXT iteration's `decode()` to call `batch.clear()`; if the consumer breaks out of `for try await ... in stream` before that next iteration, the cleanup never runs. The subsequent `textStream(...)` call then enters `decode(text:)` with stale `n_tokens`, walks past the end of `seq_id` (which is allocated for exactly `parameter.batch` entries by `llama_batch_init`), and crashes on a force-unwrap of nil at `Batch.swift:20`.

Use `n_gpu_layers = 999` instead of `Int32.max` for the "all layers" sentinel — matches the value used throughout the llama.cpp examples and sidesteps any internal arithmetic that might overflow with INT32_MAX.
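To make the `Context.clear()` change concrete, here is a minimal sketch of the combined reset, assuming the wrapper holds a raw `llama_context` pointer and a reusable `llama_batch`. `llama_get_memory`/`llama_memory_clear` follow the current llama.cpp memory API, and the `n_tokens = 0` assignment is the reset that the wrapper's `batch.clear()` performs.

```swift
/// Sketch only: wipe the KV cache *and* reset the stale prefill batch.
func clear(context: OpaquePointer, batch: inout llama_batch) {
    // Drop all cached KV metadata and buffers.
    llama_memory_clear(llama_get_memory(context), true)
    // Reset the reusable batch so an early `break` out of the token stream
    // can't leave n_tokens > 0 behind for the next generation's prefill.
    batch.n_tokens = 0
}
```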
Round-2 fix (d71786a) added `batch.clear()` to `Context.clear()`, but only the `.plain` case in `LlamaClient.textStream(from:)` calls `context.clear()`. The `.chat` and `.chatTemplate` paths went straight into `messageProcessor.process(...)`, so a stale `batch.n_tokens > 0` from a previous consumer-level early termination (stop sequence or maxTokens break) survived into the next generation's prefill — same crash, different trigger. This change adds a `context.clear()` at the top of all three input cases so the fix applies regardless of which input path the consumer uses.

Reproduced on iOS with Gemma 3 1B Q4_K_M: persona turn 1 (stop sequence) → judge turn 1 (maxTokens=200 break) → persona turn 2 prefill → Batch.swift:20 force-unwrap of nil seq_id slot.
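In outline, the symmetric shape of `textStream(from:)` becomes something like the fragment below. This is a sketch only: the enum payload names, the `tools` argument, and the `process(...)` signature are illustrative stand-ins rather than the package's exact API.

```swift
switch input {
case .plain(let text):
    context.clear()                                    // previously the only cleared path
    prompt = text
case .chat(let messages):
    context.clear()                                    // newly added
    prompt = try messageProcessor.process(messages, tools: tools)
case .chatTemplate(let messages):
    context.clear()                                    // newly added
    prompt = try messageProcessor.process(messages, tools: tools)
}
```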
all done
Code Review
This pull request introduces support for Flash Attention, KV cache quantization, and configurable GPU layer offloading. It also addresses a crash occurring when generations are interrupted by ensuring the batch state is cleared. Feedback focuses on ensuring the Swift-side promptCaches are cleared alongside the KV cache to prevent state mismatches. There are also suggestions to improve the safety of GPU layer parameter conversion and to reconsider the performance impact of full context clearing in chat modes, which may inadvertently disable KV cache reuse.
// of the batch's `seq_id` array (allocated for `parameter.batch`
// entries) and crashes on a force-unwrap of nil. Clearing the batch
// here makes `clear()` safe to call between any two generations.
batch.clear()
The clear() method resets the underlying KV cache using llama_memory_clear, but it does not clear the promptCaches array. This will lead to a state mismatch where subsequent calls to textStream might skip processing prompt chunks that are no longer in the KV cache, resulting in incorrect model output. You should clear the Swift-side cache tracking whenever the KV cache is wiped.
Suggested change:
batch.clear()
promptCaches.removeAll()
#if targetEnvironment(simulator)
model_params.n_gpu_layers = 0
#else
model_params.n_gpu_layers = parameter.nGpuLayers == -1 ? 999 : Int32(parameter.nGpuLayers)
Using Int32(parameter.nGpuLayers) can cause a runtime trap if the value exceeds Int32.max. Additionally, negative values other than -1 are passed directly to llama.cpp, which may not be handled as expected. It's safer to use a clamping conversion and explicitly handle the -1 sentinel while ensuring other negative values are treated as 0 (CPU-only).
Suggested change:
model_params.n_gpu_layers = parameter.nGpuLayers == -1 ? 999 : Int32(clamping: max(0, parameter.nGpuLayers))
// Batch.swift:20. The asymmetry between `.plain` (cleared)
// and `.chat`/`.chatTemplate` (not cleared) was the root
// cause of the residual crash in d71786a.
context.clear()
Calling context.clear() before every generation fixes the crash related to batch.n_tokens, but it also wipes the entire KV cache. This effectively disables the KV cache reuse logic implemented in Context.swift (via promptCaches), forcing the model to re-process the entire conversation history for every message. If KV cache reuse is desired for performance in multi-turn chats, consider only clearing the batch state instead of the entire context memory.
    tools: tools
)
case .chat(let messages):
    context.clear()
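If KV cache reuse matters more than a guaranteed-clean slate, the reviewer's alternative could look roughly like the helper below. `clearBatchOnly` is a hypothetical name, not part of the PR; it keeps the KV cache and the Swift-side `promptCaches` intact and resets only the batch state that causes the crash.

```swift
/// Hypothetical helper: reset only the stale batch between generations,
/// leaving the KV cache (and promptCaches) available for prefix reuse.
func clearBatchOnly(_ batch: inout llama_batch) {
    batch.n_tokens = 0   // same reset that batch.clear() performs; KV cache untouched
}
```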

