Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 64 additions & 3 deletions Examples/CoreMLLLMChat/CoreMLLLMChat/LLMRunner.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,34 @@ import Foundation
import UIKit
#endif

/// User-selectable compute-unit preference for the next `loadModel` call.
/// The choice is baked into each chunk's `MLModelConfiguration` at load
/// time, so changing it only takes effect after a reload.
enum ComputeMode: String, CaseIterable, Identifiable {
    case aneOnly
    case gpuOnly
    case gpuPrefill
    case splitChunk3
    case all

    /// UserDefaults key shared between `LLMRunner` (reader) and
    /// `ModelPickerView` (writer) — lets the picker's selection reach the
    /// next load without any direct wiring between the two.
    static let storageKey = "LLMRunner.computeMode"

    var id: String { rawValue }

    /// Short human-readable name shown in the compute-units picker.
    var label: String {
        switch self {
        case .aneOnly: "ANE"
        case .gpuOnly: "GPU"
        case .gpuPrefill: "ANE + GPU prefill"
        case .splitChunk3: "ANE + c3→GPU (spike)"
        case .all: "All"
        }
    }
}

/// Thin @Observable wrapper around CoreMLLLM for the chat app.
///
/// Delegates all inference to the CoreMLLLM package. Adds app-specific
Expand All @@ -20,6 +48,14 @@ final class LLMRunner {
var hasAudio = false
var maxAudioDuration: TimeInterval = 10.0

/// Compute-unit preference for the next load. The persisted value in
/// UserDefaults (written by `ModelPickerView`'s picker, under
/// `ComputeMode.storageKey`) is the source of truth; a missing or
/// unrecognised value falls back to `.aneOnly`. `loadModel` reads this
/// to decide `computeUnits:` and the env gates.
var computeMode: ComputeMode {
    guard
        let stored = UserDefaults.standard.string(forKey: ComputeMode.storageKey),
        let mode = ComputeMode(rawValue: stored)
    else {
        return .aneOnly
    }
    return mode
}

// MTP speculation metrics
var mtpAcceptanceRate: Double = 0
var mtpTokensPerRound: Double = 0
Expand All @@ -30,7 +66,7 @@ final class LLMRunner {
var crossVocabTokensPerCycle: Double = 0

private var llm: CoreMLLLM?
private var modelFolderURL: URL?
private(set) var modelFolderURL: URL?

// MARK: - Loading

Expand Down Expand Up @@ -62,7 +98,9 @@ final class LLMRunner {
modelFolderURL = folder
loadingStatus = "Loading..."

llm = try await CoreMLLLM.load(from: folder) { [weak self] status in
let units = Self.applyComputeMode(computeMode)

llm = try await CoreMLLLM.load(from: folder, computeUnits: units) { [weak self] status in
Task { @MainActor in
self?.loadingStatus = status
}
Expand All @@ -74,7 +112,30 @@ final class LLMRunner {
maxAudioDuration = llm!.maxAudioDuration
isLoaded = true
loadingStatus = "Ready"
print("[LLMRunner] loaded: vision=\(hasVision) audio=\(hasAudio) model=\(modelName)")
print("[LLMRunner] loaded: vision=\(hasVision) audio=\(hasAudio) model=\(modelName) compute=\(computeMode.label)")
}

/// Translate a `ComputeMode` into the `MLComputeUnits` value handed to
/// `CoreMLLLM.load` plus the env gates (`GPU_PREFILL`,
/// `COMPUTE_UNIT_SPLIT`) that `ChunkedEngine.load` observes. Both gates
/// are forced to "0" up front so flipping modes never leaves a stale
/// flag armed.
/// NOTE(review): the engine reads `ProcessInfo.processInfo.environment`;
/// Foundation may snapshot the environment on first access — confirm
/// these `setenv` writes are still visible when the engine loads.
@discardableResult
private static func applyComputeMode(_ mode: ComputeMode) -> MLComputeUnits {
    // Clear both gates first; only the selected mode re-arms one below.
    for gate in ["GPU_PREFILL", "COMPUTE_UNIT_SPLIT"] {
        setenv(gate, "0", 1)
    }
    switch mode {
    case .aneOnly:
        return .cpuAndNeuralEngine
    case .gpuOnly:
        return .cpuAndGPU
    case .gpuPrefill:
        // Chunks stay on ANE; the gate presumably steers prefill to GPU
        // inside the engine (per the "ANE + GPU prefill" label).
        setenv("GPU_PREFILL", "1", 1)
        return .cpuAndNeuralEngine
    case .splitChunk3:
        // Spike: chunk3 loads on .cpuAndGPU while the rest inherit ANE
        // (see ChunkedEngine's COMPUTE_UNIT_SPLIT handling).
        setenv("COMPUTE_UNIT_SPLIT", "1", 1)
        return .cpuAndNeuralEngine
    case .all:
        return .all
    }
}

// MARK: - Generation
Expand Down
17 changes: 17 additions & 0 deletions Examples/CoreMLLLMChat/CoreMLLLMChat/ModelPickerView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,26 @@ struct ModelPickerView: View {
let downloader = ModelDownloader.shared
let onModelReady: (URL) -> Void

// Picked value is read by `LLMRunner.loadModel` at the next load
// via UserDefaults (same key as `ComputeMode.storageKey`). Applies
// to both `Load` on a downloaded model and fresh `Download` flows.
@AppStorage(ComputeMode.storageKey) private var computeMode: ComputeMode = .aneOnly

var body: some View {
NavigationStack {
List {
Section("Compute Units") {
Picker("Units", selection: $computeMode) {
ForEach(ComputeMode.allCases) { mode in
Text(mode.label).tag(mode)
}
}
.pickerStyle(.menu)
Text("Applied at Load time. Changing this without reloading has no effect.")
.font(.caption2)
.foregroundStyle(.secondary)
}

Section("Available Models") {
ForEach(downloader.availableModels) { model in
let _ = downloader.refreshTrigger
Expand Down
99 changes: 97 additions & 2 deletions Sources/CoreMLLLM/ChunkedEngine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,22 @@ final class ChunkedEngine {
}
}

// Phase D1b spike: optionally move chunk3 to .cpuAndGPU while other
// chunks stay on the inherited compute unit (usually ANE via .all).
// Hypothesis: distinct-compute-unit chunks go through distinct drivers
// and can overlap, where pure-ANE chunks serialise (see PR #75).
// Gated by COMPUTE_UNIT_SPLIT=1. Default-off, zero behaviour change.
let splitEnabled = ProcessInfo.processInfo.environment["COMPUTE_UNIT_SPLIT"] == "1"
let splitTarget = ProcessInfo.processInfo.environment["COMPUTE_UNIT_SPLIT_CHUNK"] ?? "chunk3"
let splitConfig: MLModelConfiguration = {
let c = MLModelConfiguration()
c.computeUnits = .cpuAndGPU
return c
}()
if splitEnabled {
print("[Spike] COMPUTE_UNIT_SPLIT=1 — \(splitTarget) will load on .cpuAndGPU")
}

func findModel(_ name: String) -> URL? {
// For .mlmodelc we require coremldata.bin alongside the directory
// — a half-populated directory (e.g. stray prefill_chunk with only
Expand All @@ -171,9 +187,11 @@ final class ChunkedEngine {
throw CoreMLLLMError.modelNotFound(name)
}
let t0 = CFAbsoluteTimeGetCurrent()
let m = try MLModel(contentsOf: url, configuration: cfg)
let effectiveCfg = (splitEnabled && name == splitTarget) ? splitConfig : cfg
let m = try MLModel(contentsOf: url, configuration: effectiveCfg)
let dt = CFAbsoluteTimeGetCurrent() - t0
print("[Load] \(name) done in \(String(format: "%.1f", dt))s")
print("[Load] \(name) done in \(String(format: "%.1f", dt))s" +
(splitEnabled && name == splitTarget ? " (.cpuAndGPU)" : ""))
return m
}

Expand Down Expand Up @@ -407,6 +425,14 @@ final class ChunkedEngine {
engine.reset()
print("[Load] ANE prewarm (4 steps) done in \(String(format: "%.2f", CFAbsoluteTimeGetCurrent() - warmT0))s")

// Phase D1b spike: one-shot c2/c3 cross-compute-unit overlap probe.
// Only runs when COMPUTE_UNIT_SPLIT=1 (so chunk3 is on .cpuAndGPU).
// Mirrors PR #75's probe pattern but measures c2 (ANE) vs c3 (GPU).
if splitEnabled {
try engine.runComputeUnitSplitProbe()
engine.reset()
}

return engine
}

Expand Down Expand Up @@ -1585,6 +1611,75 @@ final class ChunkedEngine {
/// Builds the drafter's full-context causal attention mask for `position`.
/// Thin convenience wrapper over `makeCausalMask`, pinned to
/// `config.contextLength`.
func makeDrafterFullMask(position: Int) throws -> MLMultiArray {
    return try makeCausalMask(position: position, length: config.contextLength)
}

// MARK: - Phase D1b spike: compute-unit-split concurrency probe

/// One-shot probe: can c2 (ANE) and c3 (.cpuAndGPU) predictions overlap
/// on separate DispatchQueues when they go through distinct drivers?
/// Mirrors PR #75's pattern but pairs c2/c3 for the split experiment.
///
/// Only meaningful when chunk3 was actually loaded on `.cpuAndGPU`
/// (COMPUTE_UNIT_SPLIT=1) — the caller in `load` gates on that flag.
/// Prints per-chunk serial times, a sequential baseline, the parallel
/// time, an overlap factor (1.0 = full overlap, 0.0 = serialised) and a
/// verdict. The predictions mutate engine state, so the caller resets
/// the engine afterwards.
/// - Throws: any error from mask/RoPE construction or `prediction`.
func runComputeUnitSplitProbe() throws {
    print("[Spike] Running compute-unit-split probe (c2 ANE vs c3 .cpuAndGPU)")
    // Probe a single step at token position 1.
    let p = 1
    // Shorthand: wrap an MLMultiArray as a feature value.
    let fv: (MLMultiArray) -> MLFeatureValue = { MLFeatureValue(multiArray: $0) }
    // Masks + RoPE lookups common to all three chunk input dictionaries.
    let rope: [String: MLFeatureValue] = [
        "causal_mask_full": fv(try makeCausalMask(position: p, length: config.contextLength)),
        "causal_mask_sliding": fv(try makeSlidingCausalMask(position: p, W: config.slidingWindow)),
        "update_mask": fv(try makeUpdateMask(position: p, length: config.contextLength)),
        "cos_s": fv(try lookupRoPE(table: cosSlidingTable, position: p, dim: 256)),
        "sin_s": fv(try lookupRoPE(table: sinSlidingTable, position: p, dim: 256)),
        "cos_f": fv(try lookupRoPE(table: cosFullTable, position: p, dim: 512)),
        "sin_f": fv(try lookupRoPE(table: sinFullTable, position: p, dim: 512))]
    // Run chunk1 once for real hidden states / per-layer values to feed c2.
    var d1 = rope
    d1["hidden_states"] = fv(try embedTokens.lookup(0, shape: [1, 1, NSNumber(value: config.hiddenSize)]))
    d1["per_layer_raw"] = fv(try lookupPerLayerRaw(tokenID: 0))
    d1["K_sliding_in"] = fv(kSliding1); d1["V_sliding_in"] = fv(vSliding1)
    d1["K_full_in"] = fv(kFull1); d1["V_full_in"] = fv(vFull1)
    let o1 = try chunk1.prediction(from: try MLDictionaryFeatureProvider(dictionary: d1))
    let plc = o1.featureValue(for: "per_layer_combined_out")!.multiArrayValue!
    // chunk2 inputs, built from chunk1's outputs.
    var d2 = rope
    d2["hidden_states"] = fv(o1.featureValue(for: "hidden_states_out")!.multiArrayValue!)
    d2["per_layer_combined"] = fv(plc)
    d2["K_sliding_in"] = fv(kSliding2); d2["V_sliding_in"] = fv(vSliding2)
    d2["K_full_in"] = fv(kFull2); d2["V_full_in"] = fv(vFull2)
    let c2Inputs = try MLDictionaryFeatureProvider(dictionary: d2)
    let o2 = try chunk2.prediction(from: c2Inputs)
    // chunk3 inputs, built from chunk2's outputs (incl. kv13/kv14 pass-through).
    var d3 = rope
    d3["hidden_states"] = fv(o2.featureValue(for: "hidden_states_out")!.multiArrayValue!)
    d3["per_layer_combined"] = fv(plc)
    for k in ["kv13_k", "kv13_v", "kv14_k", "kv14_v"] { d3[k] = fv(o2.featureValue(for: k)!.multiArrayValue!) }
    let c3Inputs = try MLDictionaryFeatureProvider(dictionary: d3)
    // Warm both chunks once so first-call overhead doesn't skew timings.
    _ = try chunk2.prediction(from: c2Inputs); _ = try chunk3.prediction(from: c3Inputs) // warm
    let trials = 10
    // Average wall-clock seconds per call of `block` over `trials` runs.
    func time(_ block: () throws -> Void) rethrows -> Double {
        let t0 = CFAbsoluteTimeGetCurrent()
        for _ in 0..<trials { try block() }
        return (CFAbsoluteTimeGetCurrent() - t0) / Double(trials)
    }
    // Baselines: each chunk alone, then both back-to-back on this thread.
    let sC2 = try time { _ = try self.chunk2.prediction(from: c2Inputs) }
    let sC3 = try time { _ = try self.chunk3.prediction(from: c3Inputs) }
    let seq = try time {
        _ = try self.chunk2.prediction(from: c2Inputs); _ = try self.chunk3.prediction(from: c3Inputs)
    }
    // Parallel case: one serial queue per chunk, joined with a DispatchGroup.
    let q2 = DispatchQueue(label: "spike.c2", qos: .userInitiated)
    let q3 = DispatchQueue(label: "spike.c3", qos: .userInitiated)
    let par = try time {
        // NOTE(review): e2/e3 are written from different queues; the
        // group's enter/leave/wait pairing should order those writes
        // before the read below — confirm clean under TSan.
        let g = DispatchGroup(); var e2: Error?; var e3: Error?
        g.enter(); q2.async { do { _ = try self.chunk2.prediction(from: c2Inputs) } catch { e2 = error }; g.leave() }
        g.enter(); q3.async { do { _ = try self.chunk3.prediction(from: c3Inputs) } catch { e3 = error }; g.leave() }
        g.wait(); if let e = e2 ?? e3 { throw e }
    }
    // overlap = fraction of the theoretically available saving realised:
    // sum = fully serial cost, ideal = perfect-overlap cost; the 1e-6
    // floor guards the divide when sum and ideal are (near-)equal.
    let ideal = max(sC2, sC3), sum = sC2 + sC3
    let overlap = (sum - par) / max(sum - ideal, 1e-6)
    print(String(format: "[Spike] c2_serial=%.2fms c3_serial=%.2fms seq_both=%.2fms parallel=%.2fms",
                 sC2 * 1000, sC3 * 1000, seq * 1000, par * 1000))
    print(String(format: "[Spike] ideal_parallel=%.2fms sum=%.2fms overlap_factor=%.2f (1.0=full, 0.0=serial)",
                 ideal * 1000, sum * 1000, overlap))
    // Thresholds are heuristics for go/no-go on the full split design.
    let verdict = overlap > 0.5 ? "strong overlap — pursue full compute-unit-split implementation"
        : overlap > 0.30 ? "meaningful overlap — evaluate full 4-way assignment"
        : overlap > 0.15 ? "partial overlap — marginal, likely net-neutral after GPU deficit"
        : "no overlap — cross-compute-unit also serializes at system level"
    print("[Spike] VERDICT: \(verdict).")
}
}

// MARK: - SpeculativeTarget conformance
Expand Down
Loading