diff --git a/Sources/LocalLLMClientLlama/Context.swift b/Sources/LocalLLMClientLlama/Context.swift
index b0811b6..daa50c5 100644
--- a/Sources/LocalLLMClientLlama/Context.swift
+++ b/Sources/LocalLLMClientLlama/Context.swift
@@ -44,9 +44,24 @@ public final class Context: @unchecked Sendable {
         ctx_params.n_threads = Int32(parameter.numberOfThreads ?? max(1, min(8, ProcessInfo.processInfo.processorCount - 2)))
         ctx_params.n_threads_batch = ctx_params.n_threads
 
+        // Flash Attention — significantly faster on Apple Silicon and uses
+        // less attention-buffer memory. llama.cpp `b8851` exposes this as a
+        // tri-state enum (auto / disabled / enabled). We map our boolean to
+        // explicit enabled/disabled so behavior is deterministic.
+        ctx_params.flash_attn_type = parameter.flashAttention
+            ? LLAMA_FLASH_ATTN_TYPE_ENABLED
+            : LLAMA_FLASH_ATTN_TYPE_DISABLED
+
+        // KV cache quantization. Lower precision halves (or quarters) the
+        // memory cost of the cache, which scales linearly with `n_ctx`.
+        // `f16` keeps full precision (default); `q8_0` and `q4_0` trade a
+        // small amount of quality for substantial memory savings.
+        ctx_params.type_k = Self.ggmlType(for: parameter.kvCacheTypeK)
+        ctx_params.type_v = Self.ggmlType(for: parameter.kvCacheTypeV)
+
         self.parameter = parameter
         self.pauseHandler = PauseHandler(disableAutoPause: parameter.options.disableAutoPause)
-        self.model = try Model(url: url)
+        self.model = try Model(url: url, parameter: parameter)
         self.context = try model.makeAndAllocateContext(with: ctx_params)
         batch = llama_batch_init(Int32(parameter.batch), 0, 1)
         extraEOSTokens = parameter.options.extraEOSTokens
@@ -92,7 +107,28 @@ public final class Context: @unchecked Sendable {
         llama_free(context)
     }
 
+    /// Maps the public Swift KV cache type enum to the underlying GGML type.
+    private static func ggmlType(for type: LlamaClient.KVCacheType) -> ggml_type {
+        switch type {
+        case .f16: return GGML_TYPE_F16
+        case .q8_0: return GGML_TYPE_Q8_0
+        case .q4_0: return GGML_TYPE_Q4_0
+        }
+    }
+
     public func clear() {
+        // Reset the prefill batch as well as the KV cache. Without this, a
+        // generation that was cut short by an external stop condition (e.g.
+        // stop sequences applied at the consumer level) leaves
+        // `batch.n_tokens > 0` because the generator's per-token `batch.add`
+        // is followed by an early `break` in the consumer's `for try await`,
+        // skipping the next `decode()` that would have called `batch.clear()`.
+        // The next `textStream(...)` call's prefill then walks past the end
+        // of the batch's `seq_id` array (allocated for `parameter.batch`
+        // entries) and crashes on a force-unwrap of nil. Clearing the batch
+        // here makes `clear()` safe to call between any two generations.
+        batch.clear()
+
         guard let kv = llama_get_memory(context) else { return }
diff --git a/Sources/LocalLLMClientLlama/LlamaClient.swift b/Sources/LocalLLMClientLlama/LlamaClient.swift
index 14e0320..289b443 100644
--- a/Sources/LocalLLMClientLlama/LlamaClient.swift
+++ b/Sources/LocalLLMClientLlama/LlamaClient.swift
@@ -63,6 +63,16 @@ public final class LlamaClient: LLMClient {
             context.clear()
             try context.decode(text: text)
         case .chatTemplate(let messages):
+            // Match the `.plain` path: reset the prefill batch + KV cache
+            // before each new generation. Without this, a previous
+            // generation that was cut short by an external stop condition
+            // (consumer-level stop sequence or maxTokens break) leaves
+            // stale `batch.n_tokens > 0` and the next prefill walks past
+            // the end of `seq_id`, crashing on the force-unwrap at
+            // Batch.swift:20. The asymmetry between `.plain` (cleared)
+            // and `.chat`/`.chatTemplate` (not cleared) was the root
+            // cause of the residual crash in d71786a.
+            context.clear()
             try messageProcessor.process(
                 templateMessages: messages,
                 context: context,
@@ -70,6 +80,7 @@ public final class LlamaClient: LLMClient {
                 tools: tools
             )
         case .chat(let messages):
+            context.clear()
             try messageProcessor.process(
                 messages: messages,
                 context: context,
diff --git a/Sources/LocalLLMClientLlama/Model.swift b/Sources/LocalLLMClientLlama/Model.swift
index a6493db..3641768 100644
--- a/Sources/LocalLLMClientLlama/Model.swift
+++ b/Sources/LocalLLMClientLlama/Model.swift
@@ -11,10 +11,23 @@ final class Model {
         llama_model_get_vocab(model)
     }
 
-    init(url: URL) throws(LLMError) {
+    init(url: URL, parameter: LlamaClient.Parameter = .default) throws(LLMError) {
         var model_params = llama_model_default_params()
+
+        // GPU layer offload. On Apple Silicon (real device + Mac) the GPU has
+        // unified memory access so offloading "all" layers is the desired
+        // setting. On the iOS Simulator there is no Metal device available
+        // for llama.cpp, so we force CPU-only regardless of the requested
+        // value to avoid runtime failures.
+        //
+        // We use 999 as the "all layers" sentinel (the same value used
+        // throughout the llama.cpp examples). `Int32.max` was tried first
+        // but appears to trigger internal arithmetic edge cases in
+        // `llama_batch` allocation paths on b8851; 999 sidesteps that.
 #if targetEnvironment(simulator)
         model_params.n_gpu_layers = 0
+#else
+        model_params.n_gpu_layers = parameter.nGpuLayers == -1 ? 999 : Int32(parameter.nGpuLayers)
 #endif
         model_params.use_mmap = true
diff --git a/Sources/LocalLLMClientLlama/Parameter.swift b/Sources/LocalLLMClientLlama/Parameter.swift
index 7684342..3e826e5 100644
--- a/Sources/LocalLLMClientLlama/Parameter.swift
+++ b/Sources/LocalLLMClientLlama/Parameter.swift
@@ -17,6 +17,10 @@ public extension LlamaClient {
        ///   - typicalP: Limits sampling based on typical probability. Default is `1`.
        ///   - penaltyLastN: The number of recent tokens to consider for penalty. Default is `64`.
        ///   - penaltyRepeat: The penalty factor for repeating tokens. Default is `1.1`.
+       ///   - nGpuLayers: Number of model layers to offload to the GPU (Metal on Apple platforms). `-1` means "all layers". `0` runs everything on CPU. Default is `-1`.
+       ///   - flashAttention: Enable Flash Attention. On Apple Silicon this is significantly faster and uses less memory. Default is `true`.
+       ///   - kvCacheTypeK: Quantization type for the key half of the KV cache. Lower precision (`.q8_0`, `.q4_0`) reduces memory roughly proportionally, allowing larger context. Default is `.f16` (no quantization).
+       ///   - kvCacheTypeV: Quantization type for the value half of the KV cache. Default is `.f16`.
        ///   - options: Additional options for the Llama client.
        public init(
            context: Int = 2048,
@@ -29,6 +33,10 @@ public extension LlamaClient {
            typicalP: Float = 1,
            penaltyLastN: Int = 64,
            penaltyRepeat: Float = 1.1,
+           nGpuLayers: Int = -1,
+           flashAttention: Bool = true,
+           kvCacheTypeK: KVCacheType = .f16,
+           kvCacheTypeV: KVCacheType = .f16,
            options: Options = .init()
        ) {
            self.context = context
@@ -41,9 +49,13 @@ public extension LlamaClient {
            self.typicalP = typicalP
            self.penaltyLastN = penaltyLastN
            self.penaltyRepeat = penaltyRepeat
+           self.nGpuLayers = nGpuLayers
+           self.flashAttention = flashAttention
+           self.kvCacheTypeK = kvCacheTypeK
+           self.kvCacheTypeV = kvCacheTypeV
            self.options = options
        }
-        
+
        /// The size of the context window in tokens.
        public var context: Int
        /// The random seed for generation. `nil` means a random seed will be used.
@@ -65,6 +77,19 @@ public extension LlamaClient {
        /// The penalty factor for repeating tokens.
        public var penaltyRepeat: Float
 
+       /// Number of model layers to offload to the GPU (Metal on Apple platforms).
+       /// `-1` means "all layers" — typically the desired setting on Apple Silicon
+       /// where the GPU has unified memory access. `0` forces CPU-only execution.
+       public var nGpuLayers: Int
+       /// Enable Flash Attention. On Apple Silicon (M-series, A-series) this is
+       /// significantly faster and uses less memory. Set to `false` only for
+       /// hardware that lacks support or for debugging.
+       public var flashAttention: Bool
+       /// Quantization type for the key half of the KV cache.
+       public var kvCacheTypeK: KVCacheType
+       /// Quantization type for the value half of the KV cache.
+       public var kvCacheTypeV: KVCacheType
+
        /// Additional options for the Llama client.
        public var options: Options
@@ -72,6 +97,18 @@ public extension LlamaClient {
            public static let `default` = Parameter()
        }
 
+   /// KV cache quantization types. Lower precision reduces memory used by the
+   /// KV cache, which scales linearly with `context` size. The quality cost
+   /// is usually negligible at `.q8_0` and only modest at `.q4_0`.
+   enum KVCacheType: String, Sendable, CaseIterable {
+       /// Half precision (16-bit float). Default; no quantization.
+       case f16
+       /// 8-bit quantization. Roughly 50% memory of `.f16`, very low quality cost.
+       case q8_0
+       /// 4-bit quantization. Roughly 25% memory of `.f16`, modest quality cost.
+       case q4_0
+   }
+
    /// Defines additional, less commonly used options for the Llama client.
    struct Options: Sendable {
        /// Initializes a new set of options for the Llama client.
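
Reviewer note: a minimal usage sketch of the new knobs. `LlamaClient.Parameter`, its initializer, and `KVCacheType` come from this diff; the commented-out factory call and `modelURL` are assumptions about the surrounding package API, not part of the change.

```swift
import Foundation
import LocalLLMClientLlama

// Parameter and KVCacheType are introduced/extended by this diff.
// The factory call below is an assumption about the package's existing API.
let parameter = LlamaClient.Parameter(
    context: 8192,           // a larger window is affordable once the KV cache is quantized
    nGpuLayers: -1,          // offload all layers to Metal (default)
    flashAttention: true,    // default; set false only for unsupported hardware or debugging
    kvCacheTypeK: .q8_0,     // ~50% of the f16 K-cache memory
    kvCacheTypeV: .q8_0      // ~50% of the f16 V-cache memory
)
// let modelURL = URL(fileURLWithPath: "/path/to/model.gguf")
// let client = try await LocalLLMClient.llama(url: modelURL, parameter: parameter)
```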
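To make the "scales linearly with `n_ctx`" comment concrete, a back-of-envelope estimate follows. The model dimensions are illustrative (an 8B-class GQA model), not taken from this diff, and the `q8_0`/`q4_0` per-element sizes are approximations derived from GGML's 32-element block layout.

```swift
import Foundation

// Rough KV-cache footprint: 2 caches (K and V) × layers × n_ctx × kvHeads × headDim × bytes/element.
// Dimensions below are illustrative, not from this diff.
let layers = 32, nCtx = 8192, kvHeads = 8, headDim = 128
let elements = 2 * layers * nCtx * kvHeads * headDim        // ≈ 537M elements

let bytesPerElement: [(String, Double)] = [
    ("f16",  2.0),            // full precision
    ("q8_0", 34.0 / 32.0),    // 32-element block: 32 × int8 + fp16 scale
    ("q4_0", 18.0 / 32.0)     // 32-element block: 16 bytes of nibbles + fp16 scale
]

for (name, bytes) in bytesPerElement {
    let gib = Double(elements) * bytes / 1_073_741_824
    print("\(name): \(String(format: "%.2f", gib)) GiB")    // ≈ 1.00 / 0.53 / 0.28 GiB
}
```

Under these assumed dimensions the numbers line up with the "roughly 50% / roughly 25%" figures quoted in the `KVCacheType` doc comments.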
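The `clear()` comments describe a crash triggered by a consumer-level early `break`. A hypothetical reproduction is sketched below; `LocalLLMClient.llama`, `textStream(from:)`, and `generateText(from:)` are assumptions about the package's public surface (only `textStream` is mentioned in the diff's comments), so treat this as an illustration of the control flow rather than a verified test.

```swift
import Foundation
import LocalLLMClient
import LocalLLMClientLlama

// Hypothetical repro of the early-`break` path fixed by clear(); the client API
// names here are assumptions, not part of this diff.
func reproduceEarlyBreak(modelURL: URL) async throws {
    let client = try await LocalLLMClient.llama(url: modelURL, parameter: .default)

    // Consumer-level stop: leave the stream before the generator's next decode(),
    // which previously left batch.n_tokens > 0.
    for try await token in try await client.textStream(from: "List ten fruits") {
        if token.contains("3.") { break }
    }

    // Before this diff the next prefill indexed past seq_id (Batch.swift:20) and
    // crashed on a force-unwrap; clear() now resets the batch as well as the KV cache.
    _ = try await client.generateText(from: "Now list ten vegetables")
}
```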