import Foundation
import Hub
import MLX
import MLXNN
import MLXRandom
import Tokenizers
import Logging

private let logger = Logger(label: "flux.swift.FluxModelCore")

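/// Bundles the sub-model configurations and tokenizer limits that define a FLUX model variant.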
public struct FluxModelConfiguration {
  public let transformerConfig: MultiModalDiffusionConfiguration
  public let t5Config: T5Configuration
  public let clipConfig: CLIPConfiguration
  public let vaeConfig: VAEConfiguration
  public let t5MaxSequenceLength: Int
  public let clipMaxSequenceLength: Int
  public let clipPaddingToken: Int32

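  /// Preset for the schnell variant: no guidance embeddings and a 256-token T5 context.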
  nonisolated(unsafe) public static let schnell = FluxModelConfiguration(
    transformerConfig: MultiModalDiffusionConfiguration(),
    t5Config: T5Configuration(),
    clipConfig: CLIPConfiguration(),
    vaeConfig: VAEConfiguration(),
    t5MaxSequenceLength: 256,
    clipMaxSequenceLength: 77,
    clipPaddingToken: 49407
  )

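  /// Preset for the dev variant: guidance embeddings enabled and a 512-token T5 context.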
  nonisolated(unsafe) public static let dev = FluxModelConfiguration(
    transformerConfig: MultiModalDiffusionConfiguration(guidanceEmbeds: true),
    t5Config: T5Configuration(),
    clipConfig: CLIPConfiguration(),
    vaeConfig: VAEConfiguration(),
    t5MaxSequenceLength: 512,
    clipMaxSequenceLength: 77,
    clipPaddingToken: 49407
  )

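  /// Preset for the Kontext dev variant, with the T5 and CLIP encoder dimensions spelled out explicitly.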
  nonisolated(unsafe) public static let kontextDev = FluxModelConfiguration(
    transformerConfig: MultiModalDiffusionConfiguration(guidanceEmbeds: true),
    t5Config: T5Configuration(
      vocabSize: 32128,
      dModel: 4096,
      dKv: 64,
      dFf: 10240,
      numHeads: 64,
      numLayers: 24
    ),
    clipConfig: CLIPConfiguration(
      hiddenSize: 768,
      intermediateSize: 3072,
      headDimension: 64,
      batchSize: 1,
      numAttentionHeads: 12,
      positionEmbeddingsCount: 77,
      tokenEmbeddingsCount: 49408,
      numHiddenLayers: 11
    ),
    vaeConfig: VAEConfiguration(),
    t5MaxSequenceLength: 512,
    clipMaxSequenceLength: 77,
    clipPaddingToken: 49407
  )
}

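/// Owns the transformer, VAE, and text encoders for a FLUX pipeline and handles
/// tokenization, weight loading, and latent decoding.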
public class FluxModelCore: @unchecked Sendable {
  public let transformer: MultiModalDiffusionTransformer
  public let vae: VAE
  public let t5Encoder: T5Encoder
  public let clipEncoder: CLIPEncoder

  var clipTokenizer: CLIPTokenizer
  var t5Tokenizer: any Tokenizer

  public let configuration: FluxModelConfiguration
  public var modelDirectory: URL?

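  /// Creates the model stack for a Hub-hosted configuration. Tokenizers are loaded from the
  /// local repo location; weights are loaded separately via `loadWeights(from:dtype:)`.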
  public init(hub: HubApi, fluxConfiguration: FluxConfiguration, modelConfiguration: FluxModelConfiguration) throws {
    self.configuration = modelConfiguration

    let repo = Hub.Repo(id: fluxConfiguration.id)
    let directory = hub.localRepoLocation(repo)

    (self.t5Tokenizer, self.clipTokenizer) = try FLUX.loadTokenizers(directory: directory, hub: hub)

    self.transformer = MultiModalDiffusionTransformer(modelConfiguration.transformerConfig)
    self.vae = VAE(modelConfiguration.vaeConfig)
    self.t5Encoder = T5Encoder(modelConfiguration.t5Config)
    self.clipEncoder = CLIPEncoder(modelConfiguration.clipConfig)
  }

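  /// Creates the model stack from a pre-downloaded (e.g. quantized) model directory.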
  public init(hub: HubApi, modelDirectory: URL, modelConfiguration: FluxModelConfiguration) throws {
    self.configuration = modelConfiguration
    self.modelDirectory = modelDirectory

    logger.info("Initializing from quantized model directory: \(modelDirectory.path)")

    (self.t5Tokenizer, self.clipTokenizer) = try FLUX.loadTokenizers(directory: modelDirectory, hub: hub)

    self.transformer = MultiModalDiffusionTransformer(modelConfiguration.transformerConfig)
    self.vae = VAE(modelConfiguration.vaeConfig)
    self.t5Encoder = T5Encoder(modelConfiguration.t5Config)
    self.clipEncoder = CLIPEncoder(modelConfiguration.clipConfig)
  }

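  /// Loads all component weights from a diffusers-style directory layout
  /// (`transformer/`, `vae/`, `text_encoder/`, `text_encoder_2/`).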
  public func loadWeights(from directory: URL, dtype: DType = .float16) throws {
    self.modelDirectory = directory
    logger.info("Loading weights from: \(directory.path)")
    logger.info("Using dtype: \(dtype)")

    try loadTransformerWeights(from: directory.appending(path: "transformer"), dtype: dtype)
    try loadVAEWeights(from: directory.appending(path: "vae"), dtype: dtype)
    try loadT5EncoderWeights(from: directory.appending(path: "text_encoder_2"), dtype: dtype)
    try loadCLIPEncoderWeights(from: directory.appending(path: "text_encoder"), dtype: dtype)

    logger.info("All weights loaded successfully")
  }

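  /// Loads every `.safetensors` file under the transformer directory, remapping parameter keys
  /// and casting non-bfloat16 tensors to the requested dtype.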
  private func loadTransformerWeights(from directory: URL, dtype: DType) throws {
    var transformerWeights = [String: MLXArray]()

    guard let enumerator = FileManager.default.enumerator(
      at: directory, includingPropertiesForKeys: nil
    ) else {
      throw FluxError.weightsNotFound("Unable to enumerate transformer directory: \(directory)")
    }

    for case let url as URL in enumerator {
      if url.pathExtension == "safetensors" {
        let w = try loadArrays(url: url)
        for (key, value) in w {
          let newKey = FLUX.remapWeightKey(key)
          if value.dtype != .bfloat16 {
            transformerWeights[newKey] = value.asType(dtype)
          } else {
            transformerWeights[newKey] = value
          }
        }
      }
    }
    transformer.update(parameters: ModuleParameters.unflattened(transformerWeights))
  }

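  /// Loads the VAE weights, casting dtypes and transposing 4-D convolution weights from
  /// PyTorch's NCHW layout to the channel-last layout MLX convolutions expect.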
  private func loadVAEWeights(from directory: URL, dtype: DType) throws {
    let vaeURL = directory.appending(path: "diffusion_pytorch_model.safetensors")
    var vaeWeights = try loadArrays(url: vaeURL)

    for (key, value) in vaeWeights {
      var newValue = value
      if newValue.dtype != .bfloat16 {
        newValue = newValue.asType(dtype)
      }
      if newValue.ndim == 4 {
        newValue = newValue.transposed(0, 2, 3, 1)
      }
      vaeWeights[key] = newValue
    }
    vae.update(parameters: ModuleParameters.unflattened(vaeWeights))
  }

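  /// Loads the T5 encoder weights from all `.safetensors` files in the directory and mirrors the
  /// relative attention bias table to the top-level key the encoder module expects.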
  private func loadT5EncoderWeights(from directory: URL, dtype: DType) throws {
    var weights = [String: MLXArray]()

    guard let enumerator = FileManager.default.enumerator(
      at: directory, includingPropertiesForKeys: nil
    ) else {
      throw FluxError.weightsNotFound("Unable to enumerate T5 encoder directory: \(directory)")
    }

    for case let url as URL in enumerator {
      if url.pathExtension == "safetensors" {
        let w = try loadArrays(url: url)
        for (key, value) in w {
          if value.dtype != .bfloat16 {
            weights[key] = value.asType(dtype)
          } else {
            weights[key] = value
          }
        }
      }
    }

    if let relativeAttentionBias = weights[
      "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
    ] {
      weights["relative_attention_bias.weight"] = relativeAttentionBias
    }

    t5Encoder.update(parameters: ModuleParameters.unflattened(weights))
  }

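  /// Loads the CLIP text encoder weights, casting non-bfloat16 tensors to the requested dtype.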
  private func loadCLIPEncoderWeights(from directory: URL, dtype: DType) throws {
    let weightsURL = directory.appending(path: "model.safetensors")
    var weights = try loadArrays(url: weightsURL)

    for (key, value) in weights {
      if value.dtype != .bfloat16 {
        weights[key] = value.asType(dtype)
      }
    }
    clipEncoder.update(parameters: ModuleParameters.unflattened(weights))
  }

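  /// Tokenizes the prompt with both tokenizers, truncates/pads to the configured sequence
  /// lengths, and returns the T5 prompt embeddings and the pooled CLIP embeddings.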
  public func conditionText(prompt: String) -> (MLXArray, MLXArray) {
    let t5Tokens = t5Tokenizer.encode(text: prompt, addSpecialTokens: true)
    let paddedT5Tokens = Array(t5Tokens.prefix(configuration.t5MaxSequenceLength))
      + Array(repeating: 0, count: max(0, configuration.t5MaxSequenceLength - min(t5Tokens.count, configuration.t5MaxSequenceLength)))

    let clipTokens = clipTokenizer.tokenize(text: prompt)
    let paddedClipTokens = Array(clipTokens.prefix(configuration.clipMaxSequenceLength))
      + Array(repeating: configuration.clipPaddingToken, count: max(0, configuration.clipMaxSequenceLength - min(clipTokens.count, configuration.clipMaxSequenceLength)))

    let promptEmbeddings = t5Encoder(MLXArray(paddedT5Tokens)[.newAxis])
    let pooledPromptEmbeddings = clipEncoder(MLXArray(paddedClipTokens)[.newAxis])

    return (promptEmbeddings, pooledPromptEmbeddings)
  }

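  /// Forces evaluation of the lazily materialized module parameters so the first inference call
  /// does not pay the loading cost.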
  public func ensureLoaded() {
    eval(transformer, t5Encoder, clipEncoder)
  }

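  /// Decodes latents through the VAE and rescales the output from [-1, 1] to [0, 1].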
  public func decode(xt: MLXArray) -> MLXArray {
    var x = vae.decode(xt)
    x = clip(x / 2 + 0.5, min: 0, max: 1)
    return x
  }

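  /// Returns a self-contained decoding closure that captures only the VAE, so decoding can
  /// continue without holding a reference to the full model.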
  public func detachedDecoder() -> ImageDecoder {
    let autoencoder = self.vae
    func decode(xt: MLXArray) -> MLXArray {
      var x = autoencoder.decode(xt)
      x = clip(x / 2 + 0.5, min: 0, max: 1)
      return x
    }
    return decode(xt:)
  }
}

extension FluxModelCore: FLUXComponents {}
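
// Example usage (a minimal sketch; the concrete `FluxConfiguration` value, `modelDirectory`,
// and the denoising loop are assumed to be provided elsewhere in the package):
//
//   let core = try FluxModelCore(
//     hub: HubApi(),
//     fluxConfiguration: fluxConfiguration,  // hypothetical checkpoint descriptor
//     modelConfiguration: .schnell
//   )
//   try core.loadWeights(from: modelDirectory)
//   let (promptEmbeddings, pooledEmbeddings) = core.conditionText(prompt: "a watercolor fox")
//   // ...run the denoising loop, then decode the resulting latents:
//   let image = core.decode(xt: latents)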