
Feature/expose metal flashattn kv #97 (Open)

matrixsoftwarelimited wants to merge 3 commits into tattn:main from
matrixsoftwarelimited:feature/expose-metal-flashattn-kv

Conversation

@matrixsoftwarelimited

No description provided.

SLM added 3 commits April 28, 2026 14:12
Adds four new parameters to `LlamaClient.Parameter` so callers can control
Metal GPU offload, Flash Attention, and KV cache quantization on Apple
platforms — knobs that were previously baked into llama.cpp defaults and
not reachable from Swift.

* `nGpuLayers: Int = -1` — number of model layers to offload to the GPU.
  `-1` means "all", which is the right call on Apple Silicon's unified
  memory. Wired through `Model.init` to `llama_model_params.n_gpu_layers`.
  Simulator override (forced 0) is preserved.
* `flashAttention: Bool = true` — toggles Flash Attention. llama.cpp
  `b8851` exposes this as a tri-state enum (`auto/disabled/enabled`); the
  bool maps to explicit `enabled`/`disabled` so behavior is deterministic.
* `kvCacheTypeK` / `kvCacheTypeV: KVCacheType = .f16` — quantization for
  the KV cache, mapped to `ggml_type` (`GGML_TYPE_F16/Q8_0/Q4_0`). `q8_0`
  roughly halves cache memory at negligible quality cost, allowing larger
  contexts at the same RAM footprint.

The `KVCacheType` Swift enum is added to `LlamaClient` as a public type so
callers don't need to import the C header directly.

Backwards compatible: defaults match prior behavior. Old call sites compile
unchanged.
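The Swift-side enum and its ggml mapping can be sketched as a minimal standalone version (a sketch only: the real PR presumably maps through the imported C `ggml_type` enum rather than raw integers; the raw values used here follow ggml's public enum, where `GGML_TYPE_F16 = 1`, `GGML_TYPE_Q4_0 = 2`, and `GGML_TYPE_Q8_0 = 8`):

```swift
// Hypothetical sketch of the KVCacheType described above.
public enum KVCacheType {
    case f16, q8_0, q4_0

    // Raw values of ggml's C `ggml_type` enum:
    // GGML_TYPE_F16 = 1, GGML_TYPE_Q4_0 = 2, GGML_TYPE_Q8_0 = 8.
    var ggmlType: Int32 {
        switch self {
        case .f16:  return 1
        case .q8_0: return 8
        case .q4_0: return 2
        }
    }
}
```

A caller wanting the memory savings described above would pass `.q8_0` for both `kvCacheTypeK` and `kvCacheTypeV`; the default `.f16` keeps the prior behavior.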
`Context.clear()` now also resets the prefill batch alongside the KV cache.

Without this, a generation that ends via an external stop condition (e.g.
stop sequences enforced at the consumer level) leaves the batch struct
with `n_tokens > 0`. The Generator does `batch.add` per emitted token and
relies on the NEXT iteration's `decode()` to call `batch.clear()`; if the
consumer breaks out of `for try await ... in stream` before that next
iteration, the cleanup never runs. The subsequent `textStream(...)` call
then enters `decode(text:)` with stale `n_tokens`, walks past the end of
`seq_id` (which is allocated for exactly `parameter.batch` entries by
`llama_batch_init`), and crashes on a force-unwrap of nil at
`Batch.swift:20`.
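The out-of-bounds mechanism can be modeled with a self-contained toy (all names here are hypothetical; the real `llama_batch` is a C struct whose `seq_id` array is sized once by `llama_batch_init`):

```swift
// Toy model of the invariant the fix restores: `seq_id` has exactly
// `capacity` slots, so an `nTokens` value carried over from an
// interrupted generation eventually indexes past the end.
struct MiniBatch {
    let capacity: Int
    var nTokens = 0
    private var seqID: [Int?]

    init(capacity: Int) {
        self.capacity = capacity
        seqID = Array(repeating: nil, count: capacity)
    }

    // Returns false where the real code would crash on the nil
    // force-unwrap (the Batch.swift:20 analogue).
    mutating func add(_ seq: Int) -> Bool {
        guard nTokens < capacity else { return false }
        seqID[nTokens] = seq
        nTokens += 1
        return true
    }

    mutating func clear() { nTokens = 0 }
}
```

Filling the batch, skipping `clear()`, and adding again reproduces the overrun; calling `clear()` between generations, as the fix does, avoids it.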

Use `n_gpu_layers = 999` instead of `Int32.max` for the "all layers"
sentinel — matches the value used throughout the llama.cpp examples and
sidesteps any internal arithmetic that might overflow with INT32_MAX.
Round-2 fix (d71786a) added `batch.clear()` to `Context.clear()` but only
the `.plain` case in `LlamaClient.textStream(from:)` calls `context.clear()`.
The `.chat` and `.chatTemplate` paths went straight into
`messageProcessor.process(...)`, so a stale `batch.n_tokens > 0` from a
previous consumer-level early termination (stop sequence or maxTokens break)
survived into the next generation's prefill — same crash, different trigger.

Add a `context.clear()` call at the top of all three input cases so the
fix is consumer-symmetric. Reproduced on iOS with Gemma 3 1B Q4_K_M:
persona turn 1 (stop sequence) → judge turn 1 (maxTokens=200 break) →
persona turn 2 prefill → Batch.swift:20 force-unwrap of nil seq_id slot.
@matrixsoftwarelimited
Author

all done

Contributor

@gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces support for Flash Attention, KV cache quantization, and configurable GPU layer offloading. It also addresses a crash occurring when generations are interrupted by ensuring the batch state is cleared. Feedback focuses on ensuring the Swift-side promptCaches are cleared alongside the KV cache to prevent state mismatches. There are also suggestions to improve the safety of GPU layer parameter conversion and to reconsider the performance impact of full context clearing in chat modes, which may inadvertently disable KV cache reuse.

// of the batch's `seq_id` array (allocated for `parameter.batch`
// entries) and crashes on a force-unwrap of nil. Clearing the batch
// here makes `clear()` safe to call between any two generations.
batch.clear()
gemini-code-assist (Bot), severity: high

The clear() method resets the underlying KV cache using llama_memory_clear, but it does not clear the promptCaches array. This will lead to a state mismatch where subsequent calls to textStream might skip processing prompt chunks that are no longer in the KV cache, resulting in incorrect model output. You should clear the Swift-side cache tracking whenever the KV cache is wiped.

Suggested change:
-batch.clear()
+batch.clear()
+promptCaches.removeAll()

#if targetEnvironment(simulator)
model_params.n_gpu_layers = 0
#else
model_params.n_gpu_layers = parameter.nGpuLayers == -1 ? 999 : Int32(parameter.nGpuLayers)
gemini-code-assist (Bot), severity: medium

Using Int32(parameter.nGpuLayers) can cause a runtime trap if the value exceeds Int32.max. Additionally, negative values other than -1 are passed directly to llama.cpp, which may not be handled as expected. It's safer to use a clamping conversion and explicitly handle the -1 sentinel while ensuring other negative values are treated as 0 (CPU-only).

Suggested change:
-model_params.n_gpu_layers = parameter.nGpuLayers == -1 ? 999 : Int32(parameter.nGpuLayers)
+model_params.n_gpu_layers = parameter.nGpuLayers == -1 ? 999 : Int32(clamping: max(0, parameter.nGpuLayers))
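The suggested conversion can be isolated as a small helper to make its edge cases explicit (the helper name is hypothetical; the PR inlines the expression):

```swift
// Maps the public Int knob to llama.cpp's Int32 `n_gpu_layers` field.
// -1 is the "all layers" sentinel (999, matching the llama.cpp examples);
// other negative values become 0 (CPU-only); values above Int32.max
// clamp instead of trapping at runtime.
func gpuLayerCount(_ requested: Int) -> Int32 {
    requested == -1 ? 999 : Int32(clamping: max(0, requested))
}
```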

// Batch.swift:20. The asymmetry between `.plain` (cleared)
// and `.chat`/`.chatTemplate` (not cleared) was the root
// cause of the residual crash in d71786a.
context.clear()
gemini-code-assist (Bot), severity: medium

Calling context.clear() before every generation fixes the crash related to batch.n_tokens, but it also wipes the entire KV cache. This effectively disables the KV cache reuse logic implemented in Context.swift (via promptCaches), forcing the model to re-process the entire conversation history for every message. If KV cache reuse is desired for performance in multi-turn chats, consider only clearing the batch state instead of the entire context memory.

tools: tools
)
case .chat(let messages):
context.clear()
gemini-code-assist (Bot), severity: medium

As noted in the .chatTemplate case, calling context.clear() here wipes the KV cache and disables reuse. Consider if a full reset is necessary or if only clearing the batch state would suffice to fix the reported crash while preserving performance.

@tattn (Owner)

tattn commented Apr 29, 2026

Thanks for the PR.
If you clear the context, won’t it become impossible to continue the conversation like this?

Current main branch: [screenshot]

This PR: [screenshot]
