From f629497af258f1407e90d169c395573d0dd93b49 Mon Sep 17 00:00:00 2001 From: shavit Date: Sat, 20 Apr 2024 11:09:22 -0400 Subject: [PATCH 1/6] Read weights from safetensors --- Package.swift | 3 +- Sources/Models/LanguageModel.swift | 79 +++++++++++++ Tests/ModelsTests/LanguageModelTests.swift | 106 ++++++++++++++++++ .../Resources/tensor-1d-int32.safetensors | Bin 0 -> 84 bytes .../Resources/tensor-2d-float64.safetensors | Bin 0 -> 128 bytes .../Resources/tensor-3d.safetensors | Bin 0 -> 128 bytes .../Resources/tensor-4d.safetensors | Bin 0 -> 152 bytes 7 files changed, 187 insertions(+), 1 deletion(-) create mode 100644 Tests/ModelsTests/LanguageModelTests.swift create mode 100644 Tests/ModelsTests/Resources/tensor-1d-int32.safetensors create mode 100644 Tests/ModelsTests/Resources/tensor-2d-float64.safetensors create mode 100644 Tests/ModelsTests/Resources/tensor-3d.safetensors create mode 100644 Tests/ModelsTests/Resources/tensor-4d.safetensors diff --git a/Package.swift b/Package.swift index de98aa6..d563e5c 100644 --- a/Package.swift +++ b/Package.swift @@ -31,6 +31,7 @@ let package = Package( .testTarget(name: "PreTokenizerTests", dependencies: ["Tokenizers", "Hub"]), .testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils"]), .testTarget(name: "NormalizerTests", dependencies: ["Tokenizers", "Hub"]), - .testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"]) + .testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"]), + .testTarget(name: "ModelsTests", dependencies: ["Models", "Hub"], resources: [.process("Resources")]) ] ) diff --git a/Sources/Models/LanguageModel.swift b/Sources/Models/LanguageModel.swift index e46a23c..31f7b2b 100644 --- a/Sources/Models/LanguageModel.swift +++ b/Sources/Models/LanguageModel.swift @@ -21,6 +21,7 @@ public class LanguageModel { struct Configurations { var modelConfig: Config + var modelWeights: ModelWeights var tokenizerConfig: Config? var tokenizerData: Config } @@ -213,3 +214,81 @@ extension LanguageModel: TextGenerationModel { } extension String: Error {} + +// MARK: - Model weights + +struct ModelWeights { + + enum ModelWeightsError: Error { + case notSupported(message: String) + case invalidFile + } + + private let dictionary: [String: MLMultiArray] + + init(_ dictionary: [String: MLMultiArray]) { + self.dictionary = dictionary + } + + subscript(key: String) -> MLMultiArray { dictionary[key]! } + + static func from(fileURL: URL) throws -> ModelWeights { + // TODO: Either this or switch or both + guard fileURL.pathExtension == "safetensors" else { throw ModelWeightsError.notSupported(message: "\(fileURL.pathExtension)") } + /* + switch data.subdata(in: 0..<8).withUnsafeBytes({ $0.load(as: [Int8].self) }) { + case [0x93, 0x4e, 0x55, 0x4d, 0x50, 0x59]: fatalError("mlx is not supported") + case [0x47, 0x47, 0x55, 0x46]: fatalError("gguf is not supported") + default: throw ModelWeightsError.notSupported(message: "found \(data)") + } + */ + let data = try Data(contentsOf: fileURL, options: .mappedIfSafe) + + // Safetensors part + let headerSize: UInt64 = data.subdata(in: 0..<8).withUnsafeBytes({ $0.load(as: UInt64.self) }) + let header = try SafetensorHeader.from(data: data, offset: 8, size: headerSize) + + var dict = [String: MLMultiArray]() + for (key, point) in header { + guard let offsets = point?.dataOffsets, offsets.count >= 2, + let shape = point?.shape as? [NSNumber] + else { continue } + + let strides = shape.dropFirst().reversed().reduce(into: [1]) { acc, a in + acc.insert(acc[0].intValue * a.intValue as NSNumber, at: 0) + } + let start = Data.Index(UInt64(8 + offsets[0]) + headerSize) + let end = Data.Index(UInt64(8 + offsets[1]) + headerSize) + let tensorData = data.subdata(in: start.. [String: Offset?] { + assert(size < data.count) + let decoder = JSONDecoder() + decoder.keyDecodingStrategy = .convertFromSnakeCase + return try decoder.decode([String: Offset?].self, from: data.subdata(in: offset..<(offset + Int(size)))) + } + } +} diff --git a/Tests/ModelsTests/LanguageModelTests.swift b/Tests/ModelsTests/LanguageModelTests.swift new file mode 100644 index 0000000..bce800f --- /dev/null +++ b/Tests/ModelsTests/LanguageModelTests.swift @@ -0,0 +1,106 @@ +import Models + +@testable import Models +@testable import Hub +import XCTest + +class ModelWeightsTests: XCTestCase { + + let downloadDestination: URL = { + FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first!.appending(component: "huggingface-tests") + }() + + var hubApi: HubApi { HubApi(downloadBase: downloadDestination) } + + func testLoadWeightsFromFileURL() async throws { + let repo = "google/bert_uncased_L-2_H-128_A-2" + let modelDir = try await hubApi.snapshot(from: repo) + + let files = try FileManager.default.contentsOfDirectory(at: modelDir, includingPropertiesForKeys: [.isReadableKey]) + XCTAssertTrue(files.contains(where: { $0.lastPathComponent == "config.json" })) + XCTAssertTrue(files.contains(where: { $0.lastPathComponent == "model.safetensors" })) + + let modelFile = modelDir.appending(path: "/model.safetensors") + let weights = try ModelWeights.from(fileURL: modelFile) + XCTAssertEqual(weights["bert.embeddings.LayerNorm.bias"].dataType, .float32) + XCTAssertEqual(weights["bert.embeddings.LayerNorm.bias"].count, 128) + XCTAssertEqual(weights["bert.embeddings.LayerNorm.bias"].shape.count, 1) + + XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"].dataType, .float32) + XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"].count, 3906816) + XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"].shape.count, 2) + + XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"][[0, 0]].floatValue, -0.0041, accuracy: 1e-3) + XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"][[3, 4]].floatValue, 0.0037, accuracy: 1e-3) + XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"][[5, 3]].floatValue, -0.5371, accuracy: 1e-3) + XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"][[7, 8]].floatValue, 0.0460, accuracy: 1e-3) + XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"][[11, 7]].floatValue, -0.0058, accuracy: 1e-3) + } + + func testSafetensorReadTensor1D() throws { + let modelFile = Bundle.module.url(forResource: "tensor-1d-int32", withExtension: "safetensors")! + let weights: ModelWeights = try ModelWeights.from(fileURL: modelFile) + let tensor = weights["embedding"] + XCTAssertEqual(tensor.dataType, .int32) + XCTAssertEqual(tensor[[0]], 1) + XCTAssertEqual(tensor[[1]], 2) + XCTAssertEqual(tensor[[2]], 3) + } + + func testSafetensorReadTensor2D() throws { + let modelFile = Bundle.module.url(forResource: "tensor-2d-float64", withExtension: "safetensors")! + let weights: ModelWeights = try ModelWeights.from(fileURL: modelFile) + let tensor = weights["embedding"] + XCTAssertEqual(tensor.dataType, .float64) + XCTAssertEqual(tensor[[0, 0]], 1) + XCTAssertEqual(tensor[[0, 1]], 2) + XCTAssertEqual(tensor[[0, 2]], 3) + XCTAssertEqual(tensor[[1, 0]], 24) + XCTAssertEqual(tensor[[1, 1]], 25) + XCTAssertEqual(tensor[[1, 2]], 26) + } + + func testSafetensorReadTensor3D() throws { + let modelFile = Bundle.module.url(forResource: "tensor-3d", withExtension: "safetensors")! + let weights: ModelWeights = try ModelWeights.from(fileURL: modelFile) + let tensor = weights["embedding"] + XCTAssertEqual(tensor.dataType, .float32) + XCTAssertEqual(tensor[[0, 0, 0]], 22) + XCTAssertEqual(tensor[[0, 0, 1]], 23) + XCTAssertEqual(tensor[[0, 0, 2]], 24) + XCTAssertEqual(tensor[[0, 1, 0]], 11) + XCTAssertEqual(tensor[[0, 1, 1]], 12) + XCTAssertEqual(tensor[[0, 1, 2]], 13) + XCTAssertEqual(tensor[[1, 0, 0]], 2) + XCTAssertEqual(tensor[[1, 0, 1]], 3) + XCTAssertEqual(tensor[[1, 0, 2]], 4) + XCTAssertEqual(tensor[[1, 1, 0]], 1) + XCTAssertEqual(tensor[[1, 1, 1]], 2) + XCTAssertEqual(tensor[[1, 1, 2]], 3) + } + + func testSafetensorReadTensor4D() throws { + let modelFile = Bundle.module.url(forResource: "tensor-4d", withExtension: "safetensors")! + let weights: ModelWeights = try ModelWeights.from(fileURL: modelFile) + let tensor = weights["embedding"] + XCTAssertEqual(tensor.dataType, .float32) + XCTAssertEqual(tensor[[0, 0, 0, 0]], 11) + XCTAssertEqual(tensor[[0, 0, 0, 1]], 12) + XCTAssertEqual(tensor[[0, 0, 0, 2]], 13) + XCTAssertEqual(tensor[[0, 0, 1, 0]], 1) + XCTAssertEqual(tensor[[0, 0, 1, 1]], 2) + XCTAssertEqual(tensor[[0, 0, 1, 2]], 3) + XCTAssertEqual(tensor[[0, 0, 2, 0]], 4) + XCTAssertEqual(tensor[[0, 0, 2, 1]], 5) + XCTAssertEqual(tensor[[0, 0, 2, 2]], 6) + XCTAssertEqual(tensor[[1, 0, 0, 0]], 22) + XCTAssertEqual(tensor[[1, 0, 0, 1]], 23) + XCTAssertEqual(tensor[[1, 0, 0, 2]], 24) + XCTAssertEqual(tensor[[1, 0, 1, 0]], 15) + XCTAssertEqual(tensor[[1, 0, 1, 1]], 16) + XCTAssertEqual(tensor[[1, 0, 1, 2]], 17) + XCTAssertEqual(tensor[[1, 0, 2, 0]], 17) + XCTAssertEqual(tensor[[1, 0, 2, 1]], 18) + XCTAssertEqual(tensor[[1, 0, 2, 2]], 19) + } +} diff --git a/Tests/ModelsTests/Resources/tensor-1d-int32.safetensors b/Tests/ModelsTests/Resources/tensor-1d-int32.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..b60400284a288a972b32283e2f62a4390e500347 GIT binary patch literal 84 zcmZ=@fPiYH)ZC=hl$6Z8bS0~5rIeD&f>b3dB~N1`B^{;Wj6@JG+BjB6DJ8KaF+M*n btvI!$7${_*V`vmxTdTkbG=>R?nSmGpu=fb3dB{wq@B^{;Wj6@JG+DOMZR!1o%u_Q4* pKP{~|wWJs*XrNb3dB{yRuB^{;Wj6@JG+DOMp$2e9;DJ8Ka zF+M*ntvI!$7$|I@V`33oTdM#93=A6_85ni|@c|$<0AdFq4gg{X2Ot9C1`xJ~@c}>N B9d!Tz literal 0 HcmV?d00001 diff --git a/Tests/ModelsTests/Resources/tensor-4d.safetensors b/Tests/ModelsTests/Resources/tensor-4d.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..0f7a80d58dc60d0ddde9477125c42348722134fc GIT binary patch literal 152 zcmeZZfPiYH)ZC=hl$6Z8bS0~5rIeD&f>b3dB{yRuB^{;Wj6@JG+DOM($56*OR!1o% zu_Q4*KP{~|wWJs*VW4Af6kA)XprF9OVBpBW-~hw{K-^#tG~EG+fVcsO7Xa}AAl?9! W+X2J}fVcpp4v0ITcmj}~;RpaVOCw$Y literal 0 HcmV?d00001 From 8dc29794a3a652cddd70fa4496dcea87f1ccf2ad Mon Sep 17 00:00:00 2001 From: shavit Date: Sat, 20 Apr 2024 11:29:01 -0400 Subject: [PATCH 2/6] Check file extension --- Sources/Models/LanguageModel.swift | 34 +++++++++++++++--------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/Sources/Models/LanguageModel.swift b/Sources/Models/LanguageModel.swift index 31f7b2b..1a4d620 100644 --- a/Sources/Models/LanguageModel.swift +++ b/Sources/Models/LanguageModel.swift @@ -233,21 +233,21 @@ struct ModelWeights { subscript(key: String) -> MLMultiArray { dictionary[key]! } static func from(fileURL: URL) throws -> ModelWeights { - // TODO: Either this or switch or both - guard fileURL.pathExtension == "safetensors" else { throw ModelWeightsError.notSupported(message: "\(fileURL.pathExtension)") } - /* - switch data.subdata(in: 0..<8).withUnsafeBytes({ $0.load(as: [Int8].self) }) { - case [0x93, 0x4e, 0x55, 0x4d, 0x50, 0x59]: fatalError("mlx is not supported") - case [0x47, 0x47, 0x55, 0x46]: fatalError("gguf is not supported") - default: throw ModelWeightsError.notSupported(message: "found \(data)") - } - */ + guard ["safetensors", "gguf", "mlx"].contains(fileURL.pathExtension) + else { throw ModelWeightsError.notSupported(message: "\(fileURL.pathExtension)") } + let data = try Data(contentsOf: fileURL, options: .mappedIfSafe) - - // Safetensors part + switch [UInt8](data.subdata(in: 0..<4)) { + case [0x47, 0x47, 0x55, 0x46]: throw ModelWeightsError.notSupported(message: ("gguf")) + case [0x93, 0x4e, 0x55, 0x4d]: throw ModelWeightsError.notSupported(message: "mlx") // Actually [0x93, 0x4e, 0x55, 0x4d, 0x50, 0x59] + default: return try fromSafetensor(data: data) + } + } + + static private func fromSafetensor(data: Data) throws -> ModelWeights { let headerSize: UInt64 = data.subdata(in: 0..<8).withUnsafeBytes({ $0.load(as: UInt64.self) }) let header = try SafetensorHeader.from(data: data, offset: 8, size: headerSize) - + var dict = [String: MLMultiArray]() for (key, point) in header { guard let offsets = point?.dataOffsets, offsets.count >= 2, @@ -262,7 +262,7 @@ struct ModelWeights { let tensorData = data.subdata(in: start.. [String: Offset?] { - assert(size < data.count) + guard size < data.count else { throw ModelWeightsError.invalidFile } let decoder = JSONDecoder() decoder.keyDecodingStrategy = .convertFromSnakeCase return try decoder.decode([String: Offset?].self, from: data.subdata(in: offset..<(offset + Int(size)))) From e7308831643b5525d2b30996ad768fb37700c734 Mon Sep 17 00:00:00 2001 From: shavit Date: Mon, 22 Apr 2024 14:15:10 -0400 Subject: [PATCH 3/6] Deintegrate Safetensor * Separate Safetensor from the weights * Rename test tensors to include type * Rename ModelWeights to Weights * Throw error for unsupported data types * Remove model weights from LanguageModel.Configurations --- Sources/Models/LanguageModel.swift | 93 ++++++++++-------- Tests/ModelsTests/LanguageModelTests.swift | 16 +-- ...etensors => tensor-3d-float32.safetensors} | Bin ...etensors => tensor-4d-float32.safetensors} | Bin 4 files changed, 59 insertions(+), 50 deletions(-) rename Tests/ModelsTests/Resources/{tensor-3d.safetensors => tensor-3d-float32.safetensors} (100%) rename Tests/ModelsTests/Resources/{tensor-4d.safetensors => tensor-4d-float32.safetensors} (100%) diff --git a/Sources/Models/LanguageModel.swift b/Sources/Models/LanguageModel.swift index 1a4d620..af3c958 100644 --- a/Sources/Models/LanguageModel.swift +++ b/Sources/Models/LanguageModel.swift @@ -21,11 +21,10 @@ public class LanguageModel { struct Configurations { var modelConfig: Config - var modelWeights: ModelWeights var tokenizerConfig: Config? var tokenizerData: Config } - + private var configuration: LanguageModelConfigurationFromHub? = nil private var _tokenizer: Tokenizer? = nil @@ -217,9 +216,9 @@ extension String: Error {} // MARK: - Model weights -struct ModelWeights { +struct Weights { - enum ModelWeightsError: Error { + enum WeightsError: Error { case notSupported(message: String) case invalidFile } @@ -232,63 +231,73 @@ struct ModelWeights { subscript(key: String) -> MLMultiArray { dictionary[key]! } - static func from(fileURL: URL) throws -> ModelWeights { + static func from(fileURL: URL) throws -> Weights { guard ["safetensors", "gguf", "mlx"].contains(fileURL.pathExtension) - else { throw ModelWeightsError.notSupported(message: "\(fileURL.pathExtension)") } + else { throw WeightsError.notSupported(message: "\(fileURL.pathExtension)") } let data = try Data(contentsOf: fileURL, options: .mappedIfSafe) - switch [UInt8](data.subdata(in: 0..<4)) { - case [0x47, 0x47, 0x55, 0x46]: throw ModelWeightsError.notSupported(message: ("gguf")) - case [0x93, 0x4e, 0x55, 0x4d]: throw ModelWeightsError.notSupported(message: "mlx") // Actually [0x93, 0x4e, 0x55, 0x4d, 0x50, 0x59] - default: return try fromSafetensor(data: data) + switch ([UInt8](data.subdata(in: 0..<4)), [UInt8](data.subdata(in: 4..<6))) { + case ([0x47, 0x47, 0x55, 0x46], _): throw WeightsError.notSupported(message: ("gguf")) + case ([0x93, 0x4e, 0x55, 0x4d], [0x50, 0x59]): throw WeightsError.notSupported(message: "mlx") + default: return try Safetensor.from(data: data) } } +} - static private func fromSafetensor(data: Data) throws -> ModelWeights { - let headerSize: UInt64 = data.subdata(in: 0..<8).withUnsafeBytes({ $0.load(as: UInt64.self) }) - let header = try SafetensorHeader.from(data: data, offset: 8, size: headerSize) - - var dict = [String: MLMultiArray]() - for (key, point) in header { - guard let offsets = point?.dataOffsets, offsets.count >= 2, - let shape = point?.shape as? [NSNumber] - else { continue } +struct Safetensor { - let strides = shape.dropFirst().reversed().reduce(into: [1]) { acc, a in - acc.insert(acc[0].intValue * a.intValue as NSNumber, at: 0) - } - let start = Data.Index(UInt64(8 + offsets[0]) + headerSize) - let end = Data.Index(UInt64(8 + offsets[1]) + headerSize) - let tensorData = data.subdata(in: start.. [String: Offset?] { - guard size < data.count else { throw ModelWeightsError.invalidFile } + static func from(data: Data) throws -> [String: Offset?] { let decoder = JSONDecoder() decoder.keyDecodingStrategy = .convertFromSnakeCase - return try decoder.decode([String: Offset?].self, from: data.subdata(in: offset..<(offset + Int(size)))) + return try decoder.decode([String: Offset?].self, from: data) } } + + static func from(data: Data) throws -> Weights { + let headerSize: Int = data.subdata(in: 0..<8).withUnsafeBytes({ $0.load(as: Int.self) }) + guard headerSize < data.count else { throw Error.invalidFile } + let header = try Header.from(data: data.subdata(in: 8..<(headerSize + 8))) + + var dict = [String: MLMultiArray]() + for (key, point) in header { + guard let offsets = point?.dataOffsets, offsets.count == 2, + let shape = point?.shape as? [NSNumber], + let dType = try point?.dataType + else { continue } + + let strides = shape.dropFirst().reversed().reduce(into: [1]) { acc, a in + acc.insert(acc[0].intValue * a.intValue as NSNumber, at: 0) + } + let start = 8 + offsets[0] + headerSize + let end = 8 + offsets[1] + headerSize + let tensorData = data.subdata(in: start.. Date: Tue, 9 Jul 2024 14:09:11 -0400 Subject: [PATCH 4/6] Move Weights to TensorUtils --- Package.swift | 3 +- Sources/Models/LanguageModel.swift | 88 ------------------ Sources/TensorUtils/Weights.swift | 88 ++++++++++++++++++ .../Resources/tensor-1d-int32.safetensors | Bin .../Resources/tensor-2d-float64.safetensors | Bin .../Resources/tensor-3d-float32.safetensors | Bin .../Resources/tensor-4d-float32.safetensors | Bin .../WeightsTests.swift} | 4 +- 8 files changed, 90 insertions(+), 93 deletions(-) create mode 100644 Sources/TensorUtils/Weights.swift rename Tests/{ModelsTests => TensorUtilsTests}/Resources/tensor-1d-int32.safetensors (100%) rename Tests/{ModelsTests => TensorUtilsTests}/Resources/tensor-2d-float64.safetensors (100%) rename Tests/{ModelsTests => TensorUtilsTests}/Resources/tensor-3d-float32.safetensors (100%) rename Tests/{ModelsTests => TensorUtilsTests}/Resources/tensor-4d-float32.safetensors (100%) rename Tests/{ModelsTests/LanguageModelTests.swift => TensorUtilsTests/WeightsTests.swift} (99%) diff --git a/Package.swift b/Package.swift index d563e5c..8d90a52 100644 --- a/Package.swift +++ b/Package.swift @@ -29,9 +29,8 @@ let package = Package( .testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources"), .process("Vocabs")]), .testTarget(name: "HubTests", dependencies: ["Hub"]), .testTarget(name: "PreTokenizerTests", dependencies: ["Tokenizers", "Hub"]), - .testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils"]), + .testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils", "Models", "Hub"], resources: [.process("Resources")]), .testTarget(name: "NormalizerTests", dependencies: ["Tokenizers", "Hub"]), .testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"]), - .testTarget(name: "ModelsTests", dependencies: ["Models", "Hub"], resources: [.process("Resources")]) ] ) diff --git a/Sources/Models/LanguageModel.swift b/Sources/Models/LanguageModel.swift index af3c958..457755a 100644 --- a/Sources/Models/LanguageModel.swift +++ b/Sources/Models/LanguageModel.swift @@ -213,91 +213,3 @@ extension LanguageModel: TextGenerationModel { } extension String: Error {} - -// MARK: - Model weights - -struct Weights { - - enum WeightsError: Error { - case notSupported(message: String) - case invalidFile - } - - private let dictionary: [String: MLMultiArray] - - init(_ dictionary: [String: MLMultiArray]) { - self.dictionary = dictionary - } - - subscript(key: String) -> MLMultiArray { dictionary[key]! } - - static func from(fileURL: URL) throws -> Weights { - guard ["safetensors", "gguf", "mlx"].contains(fileURL.pathExtension) - else { throw WeightsError.notSupported(message: "\(fileURL.pathExtension)") } - - let data = try Data(contentsOf: fileURL, options: .mappedIfSafe) - switch ([UInt8](data.subdata(in: 0..<4)), [UInt8](data.subdata(in: 4..<6))) { - case ([0x47, 0x47, 0x55, 0x46], _): throw WeightsError.notSupported(message: ("gguf")) - case ([0x93, 0x4e, 0x55, 0x4d], [0x50, 0x59]): throw WeightsError.notSupported(message: "mlx") - default: return try Safetensor.from(data: data) - } - } -} - -struct Safetensor { - - typealias Error = Weights.WeightsError - - struct Header { - - struct Offset: Decodable { - let dataOffsets: [Int]? - let dtype: String? - let shape: [Int]? - - /// Unsupported: "I8", "U8", "I16", "U16", "BF16" - var dataType: MLMultiArrayDataType? { - get throws { - switch dtype { - case "I32", "U32": .int32 - case "F16": .float16 - case "F32": .float32 - case "F64", "U64": .float64 - default: throw Error.notSupported(message: "\(dtype ?? "empty")") - } - } - } - } - - static func from(data: Data) throws -> [String: Offset?] { - let decoder = JSONDecoder() - decoder.keyDecodingStrategy = .convertFromSnakeCase - return try decoder.decode([String: Offset?].self, from: data) - } - } - - static func from(data: Data) throws -> Weights { - let headerSize: Int = data.subdata(in: 0..<8).withUnsafeBytes({ $0.load(as: Int.self) }) - guard headerSize < data.count else { throw Error.invalidFile } - let header = try Header.from(data: data.subdata(in: 8..<(headerSize + 8))) - - var dict = [String: MLMultiArray]() - for (key, point) in header { - guard let offsets = point?.dataOffsets, offsets.count == 2, - let shape = point?.shape as? [NSNumber], - let dType = try point?.dataType - else { continue } - - let strides = shape.dropFirst().reversed().reduce(into: [1]) { acc, a in - acc.insert(acc[0].intValue * a.intValue as NSNumber, at: 0) - } - let start = 8 + offsets[0] + headerSize - let end = 8 + offsets[1] + headerSize - let tensorData = data.subdata(in: start.. MLMultiArray { dictionary[key]! } + + static func from(fileURL: URL) throws -> Weights { + guard ["safetensors", "gguf", "mlx"].contains(fileURL.pathExtension) + else { throw WeightsError.notSupported(message: "\(fileURL.pathExtension)") } + + let data = try Data(contentsOf: fileURL, options: .mappedIfSafe) + switch ([UInt8](data.subdata(in: 0..<4)), [UInt8](data.subdata(in: 4..<6))) { + case ([0x47, 0x47, 0x55, 0x46], _): throw WeightsError.notSupported(message: ("gguf")) + case ([0x93, 0x4e, 0x55, 0x4d], [0x50, 0x59]): throw WeightsError.notSupported(message: "mlx") + default: return try Safetensor.from(data: data) + } + } +} + +struct Safetensor { + + typealias Error = Weights.WeightsError + + struct Header { + + struct Offset: Decodable { + let dataOffsets: [Int]? + let dtype: String? + let shape: [Int]? + + /// Unsupported: "I8", "U8", "I16", "U16", "BF16" + var dataType: MLMultiArrayDataType? { + get throws { + switch dtype { + case "I32", "U32": .int32 + case "F16": .float16 + case "F32": .float32 + case "F64", "U64": .float64 + default: throw Error.notSupported(message: "\(dtype ?? "empty")") + } + } + } + } + + static func from(data: Data) throws -> [String: Offset?] { + let decoder = JSONDecoder() + decoder.keyDecodingStrategy = .convertFromSnakeCase + return try decoder.decode([String: Offset?].self, from: data) + } + } + + static func from(data: Data) throws -> Weights { + let headerSize: Int = data.subdata(in: 0..<8).withUnsafeBytes({ $0.load(as: Int.self) }) + guard headerSize < data.count else { throw Error.invalidFile } + let header = try Header.from(data: data.subdata(in: 8..<(headerSize + 8))) + + var dict = [String: MLMultiArray]() + for (key, point) in header { + guard let offsets = point?.dataOffsets, offsets.count == 2, + let shape = point?.shape as? [NSNumber], + let dType = try point?.dataType + else { continue } + + let strides = shape.dropFirst().reversed().reduce(into: [1]) { acc, a in + acc.insert(acc[0].intValue * a.intValue as NSNumber, at: 0) + } + let start = 8 + offsets[0] + headerSize + let end = 8 + offsets[1] + headerSize + let tensorData = data.subdata(in: start.. Date: Sat, 14 Dec 2024 08:16:11 -0500 Subject: [PATCH 5/6] Specify filenames to download in tests. --- Tests/TensorUtilsTests/WeightsTests.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Tests/TensorUtilsTests/WeightsTests.swift b/Tests/TensorUtilsTests/WeightsTests.swift index 5285bcc..802d571 100644 --- a/Tests/TensorUtilsTests/WeightsTests.swift +++ b/Tests/TensorUtilsTests/WeightsTests.swift @@ -12,7 +12,7 @@ class WeightsTests: XCTestCase { func testLoadWeightsFromFileURL() async throws { let repo = "google/bert_uncased_L-2_H-128_A-2" - let modelDir = try await hubApi.snapshot(from: repo) + let modelDir = try await hubApi.snapshot(from: repo, matching: ["config.json", "model.safetensors"]) let files = try FileManager.default.contentsOfDirectory(at: modelDir, includingPropertiesForKeys: [.isReadableKey]) XCTAssertTrue(files.contains(where: { $0.lastPathComponent == "config.json" })) From 15173b09365c4e4b9aef409ec394e1031856860a Mon Sep 17 00:00:00 2001 From: shavit Date: Sat, 14 Dec 2024 08:19:20 -0500 Subject: [PATCH 6/6] Make the weights optional and public Enable safe access to keys. --- Sources/TensorUtils/Weights.swift | 6 ++--- Tests/TensorUtilsTests/WeightsTests.swift | 30 +++++++++++------------ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/Sources/TensorUtils/Weights.swift b/Sources/TensorUtils/Weights.swift index 386baca..2050e01 100644 --- a/Sources/TensorUtils/Weights.swift +++ b/Sources/TensorUtils/Weights.swift @@ -1,7 +1,7 @@ import CoreML -struct Weights { +public struct Weights { enum WeightsError: Error { case notSupported(message: String) @@ -14,9 +14,9 @@ struct Weights { self.dictionary = dictionary } - subscript(key: String) -> MLMultiArray { dictionary[key]! } + subscript(key: String) -> MLMultiArray? { dictionary[key] } - static func from(fileURL: URL) throws -> Weights { + public static func from(fileURL: URL) throws -> Weights { guard ["safetensors", "gguf", "mlx"].contains(fileURL.pathExtension) else { throw WeightsError.notSupported(message: "\(fileURL.pathExtension)") } diff --git a/Tests/TensorUtilsTests/WeightsTests.swift b/Tests/TensorUtilsTests/WeightsTests.swift index 802d571..5d2e478 100644 --- a/Tests/TensorUtilsTests/WeightsTests.swift +++ b/Tests/TensorUtilsTests/WeightsTests.swift @@ -20,25 +20,25 @@ class WeightsTests: XCTestCase { let modelFile = modelDir.appending(path: "/model.safetensors") let weights = try Weights.from(fileURL: modelFile) - XCTAssertEqual(weights["bert.embeddings.LayerNorm.bias"].dataType, .float32) - XCTAssertEqual(weights["bert.embeddings.LayerNorm.bias"].count, 128) - XCTAssertEqual(weights["bert.embeddings.LayerNorm.bias"].shape.count, 1) + XCTAssertEqual(weights["bert.embeddings.LayerNorm.bias"]!.dataType, .float32) + XCTAssertEqual(weights["bert.embeddings.LayerNorm.bias"]!.count, 128) + XCTAssertEqual(weights["bert.embeddings.LayerNorm.bias"]!.shape.count, 1) - XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"].dataType, .float32) - XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"].count, 3906816) - XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"].shape.count, 2) + XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"]!.dataType, .float32) + XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"]!.count, 3906816) + XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"]!.shape.count, 2) - XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"][[0, 0]].floatValue, -0.0041, accuracy: 1e-3) - XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"][[3, 4]].floatValue, 0.0037, accuracy: 1e-3) - XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"][[5, 3]].floatValue, -0.5371, accuracy: 1e-3) - XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"][[7, 8]].floatValue, 0.0460, accuracy: 1e-3) - XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"][[11, 7]].floatValue, -0.0058, accuracy: 1e-3) + XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"]![[0, 0]].floatValue, -0.0041, accuracy: 1e-3) + XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"]![[3, 4]].floatValue, 0.0037, accuracy: 1e-3) + XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"]![[5, 3]].floatValue, -0.5371, accuracy: 1e-3) + XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"]![[7, 8]].floatValue, 0.0460, accuracy: 1e-3) + XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"]![[11, 7]].floatValue, -0.0058, accuracy: 1e-3) } func testSafetensorReadTensor1D() throws { let modelFile = Bundle.module.url(forResource: "tensor-1d-int32", withExtension: "safetensors")! let weights: Weights = try Weights.from(fileURL: modelFile) - let tensor = weights["embedding"] + let tensor = weights["embedding"]! XCTAssertEqual(tensor.dataType, .int32) XCTAssertEqual(tensor[[0]], 1) XCTAssertEqual(tensor[[1]], 2) @@ -48,7 +48,7 @@ class WeightsTests: XCTestCase { func testSafetensorReadTensor2D() throws { let modelFile = Bundle.module.url(forResource: "tensor-2d-float64", withExtension: "safetensors")! let weights: Weights = try Weights.from(fileURL: modelFile) - let tensor = weights["embedding"] + let tensor = weights["embedding"]! XCTAssertEqual(tensor.dataType, .float64) XCTAssertEqual(tensor[[0, 0]], 1) XCTAssertEqual(tensor[[0, 1]], 2) @@ -61,7 +61,7 @@ class WeightsTests: XCTestCase { func testSafetensorReadTensor3D() throws { let modelFile = Bundle.module.url(forResource: "tensor-3d-float32", withExtension: "safetensors")! let weights: Weights = try Weights.from(fileURL: modelFile) - let tensor = weights["embedding"] + let tensor = weights["embedding"]! XCTAssertEqual(tensor.dataType, .float32) XCTAssertEqual(tensor[[0, 0, 0]], 22) XCTAssertEqual(tensor[[0, 0, 1]], 23) @@ -80,7 +80,7 @@ class WeightsTests: XCTestCase { func testSafetensorReadTensor4D() throws { let modelFile = Bundle.module.url(forResource: "tensor-4d-float32", withExtension: "safetensors")! let weights: Weights = try Weights.from(fileURL: modelFile) - let tensor = weights["embedding"] + let tensor = weights["embedding"]! XCTAssertEqual(tensor.dataType, .float32) XCTAssertEqual(tensor[[0, 0, 0, 0]], 11) XCTAssertEqual(tensor[[0, 0, 0, 1]], 12)