from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
- from invokeai.app.invocations.model import ModelIdentifierField
+ from invokeai.app.invocations.model import TransformerField, CLIPField, T5EncoderField, VAEField
from optimum.quanto import qfloat8
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from transformers.models.auto import AutoModelForTextEncoding

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
- from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata, UIType, Input
+ from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata, Input, FieldDescriptions
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
@@ -42,14 +42,24 @@ class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Text-to-image generation using a FLUX model."""

-     flux_model: ModelIdentifierField = InputField(
-         description="The Flux model",
-         input=Input.Any,
-         ui_type=UIType.FluxMainModel
+     transformer: TransformerField = InputField(
+         description=FieldDescriptions.unet,
+         input=Input.Connection,
+         title="Transformer",
    )
-     model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
-     use_8bit: bool = InputField(
-         default=False, description="Whether to quantize the transformer model to 8-bit precision."
+     clip: CLIPField = InputField(
+         title="CLIP",
+         description=FieldDescriptions.clip,
+         input=Input.Connection,
+     )
+     t5Encoder: T5EncoderField = InputField(
+         title="T5EncoderField",
+         description=FieldDescriptions.t5Encoder,
+         input=Input.Connection,
+     )
+     vae: VAEField = InputField(
+         description=FieldDescriptions.vae,
+         input=Input.Connection,
    )
    positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
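The new inputs replace the single flux_model / use_8bit pair with per-submodel connections. Below is a minimal sketch of what these field types appear to carry, inferred only from how they are dereferenced later in this file; the shapes and base class are assumptions, and the real definitions live in invokeai.app.invocations.model.

# Rough shape of the new field types, inferred from usage in this invocation.
# These are assumptions for illustration, not the actual InvokeAI definitions.
from pydantic import BaseModel
from invokeai.app.invocations.model import ModelIdentifierField

class TransformerField(BaseModel):
    transformer: ModelIdentifierField  # FLUX transformer submodel key
    scheduler: ModelIdentifierField    # flow-match scheduler submodel key

class CLIPField(BaseModel):
    tokenizer: ModelIdentifierField
    text_encoder: ModelIdentifierField

class T5EncoderField(BaseModel):
    tokenizer: ModelIdentifierField
    text_encoder: ModelIdentifierField

class VAEField(BaseModel):
    vae: ModelIdentifierField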
@@ -63,45 +73,40 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
-         model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
-
-         t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
-         latents = self._run_diffusion(context, model_path, clip_embeddings, t5_embeddings)
-         image = self._run_vae_decoding(context, model_path, latents)
+         t5_embeddings, clip_embeddings = self._encode_prompt(context)
+         latents = self._run_diffusion(context, clip_embeddings, t5_embeddings)
+         image = self._run_vae_decoding(context, latents)
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)

-     def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
-         # Determine the T5 max sequence length based on the model.
-         if self.model == "flux-schnell":
-             max_seq_len = 256
-         # elif self.model == "flux-dev":
-         #     max_seq_len = 512
-         else:
-             raise ValueError(f"Unknown model: {self.model}")
+     def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
+         # TODO: Determine the T5 max sequence length based on the model.
+         # if self.model == "flux-schnell":
+         max_seq_len = 256
+         # # elif self.model == "flux-dev":
+         # #     max_seq_len = 512
+         # else:
+         #     raise ValueError(f"Unknown model: {self.model}")

-         # Load the CLIP tokenizer.
-         clip_tokenizer_path = flux_model_dir / "tokenizer"
-         clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
-         assert isinstance(clip_tokenizer, CLIPTokenizer)
+         # Load CLIP.
+         clip_tokenizer_info = context.models.load(self.clip.tokenizer)
+         clip_text_encoder_info = context.models.load(self.clip.text_encoder)

-         # Load the T5 tokenizer.
-         t5_tokenizer_path = flux_model_dir / "tokenizer_2"
-         t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
-         assert isinstance(t5_tokenizer, T5TokenizerFast)
+         # Load T5.
+         t5_tokenizer_info = context.models.load(self.t5Encoder.tokenizer)
+         t5_text_encoder_info = context.models.load(self.t5Encoder.text_encoder)

-         clip_text_encoder_path = flux_model_dir / "text_encoder"
-         t5_text_encoder_path = flux_model_dir / "text_encoder_2"
        with (
-             context.models.load_local_model(
-                 model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
-             ) as clip_text_encoder,
-             context.models.load_local_model(
-                 model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
-             ) as t5_text_encoder,
+             clip_text_encoder_info as clip_text_encoder,
+             t5_text_encoder_info as t5_text_encoder,
+             clip_tokenizer_info as clip_tokenizer,
+             t5_tokenizer_info as t5_tokenizer,
        ):
            assert isinstance(clip_text_encoder, CLIPTextModel)
            assert isinstance(t5_text_encoder, T5EncoderModel)
+             assert isinstance(clip_tokenizer, CLIPTokenizer)
+             assert isinstance(t5_tokenizer, T5TokenizerFast)
+
            pipeline = FluxPipeline(
                scheduler=None,
                vae=None,
@@ -114,7 +119,7 @@ def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tu

            # prompt_embeds: T5 embeddings
            # pooled_prompt_embeds: CLIP embeddings
-             prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+             prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
                prompt=self.positive_prompt,
                prompt_2=self.positive_prompt,
                device=TorchDevice.choose_torch_device(),
@@ -128,22 +133,23 @@ def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tu
    def _run_diffusion(
        self,
        context: InvocationContext,
-         flux_model_dir: Path,
        clip_embeddings: torch.Tensor,
        t5_embeddings: torch.Tensor,
    ):
-         scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_model_dir / "scheduler", local_files_only=True)
+         scheduler_info = context.models.load(self.transformer.scheduler)
+         transformer_info = context.models.load(self.transformer.transformer)

        # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
        # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
        # if the cache is not empty.
-         context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
+         # context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)

-         transformer_path = flux_model_dir / "transformer"
-         with context.models.load_local_model(
-             model_path=transformer_path, loader=self._load_flux_transformer
-         ) as transformer:
+         with (
+             transformer_info as transformer,
+             scheduler_info as scheduler
+         ):
            assert isinstance(transformer, FluxTransformer2DModel)
+             assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)

            flux_pipeline_with_transformer = FluxPipeline(
                scheduler=scheduler,
@@ -176,11 +182,10 @@ def _run_diffusion(
    def _run_vae_decoding(
        self,
        context: InvocationContext,
-         flux_model_dir: Path,
        latents: torch.Tensor,
    ) -> Image.Image:
-         vae_path = flux_model_dir / "vae"
-         with context.models.load_local_model(model_path=vae_path, loader=self._load_flux_vae) as vae:
+         vae_info = context.models.load(self.vae.vae)
+         with vae_info as vae:
            assert isinstance(vae, AutoencoderKL)

            flux_pipeline_with_vae = FluxPipeline(
@@ -205,81 +210,3 @@ def _run_vae_decoding(

        assert isinstance(image, Image.Image)
        return image
-
-     @staticmethod
-     def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
-         model = CLIPTextModel.from_pretrained(path, local_files_only=True)
-         assert isinstance(model, CLIPTextModel)
-         return model
-
-     def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
-         if self.use_8bit:
-             model_8bit_path = path / "quantized"
-             if model_8bit_path.exists():
-                 # The quantized model exists, load it.
-                 # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
-                 # something that we should be able to make much faster.
-                 q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)
-
-                 # Access the underlying wrapped model.
-                 # We access the wrapped model, even though it is private, because it simplifies the type checking by
-                 # always returning a T5EncoderModel from this function.
-                 model = q_model._wrapped
-             else:
-                 # The quantized model does not exist yet, quantize and save it.
-                 # TODO(ryand): dtype?
-                 model = T5EncoderModel.from_pretrained(path, local_files_only=True)
-                 assert isinstance(model, T5EncoderModel)
-
-                 q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
-
-                 model_8bit_path.mkdir(parents=True, exist_ok=True)
-                 q_model.save_pretrained(model_8bit_path)
-
-                 # (See earlier comment about accessing the wrapped model.)
-                 model = q_model._wrapped
-         else:
-             model = T5EncoderModel.from_pretrained(path, local_files_only=True)
-
-         assert isinstance(model, T5EncoderModel)
-         return model
-
-     def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
-         if self.use_8bit:
-             model_8bit_path = path / "quantized"
-             if model_8bit_path.exists():
-                 # The quantized model exists, load it.
-                 # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
-                 # something that we should be able to make much faster.
-                 q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
-
-                 # Access the underlying wrapped model.
-                 # We access the wrapped model, even though it is private, because it simplifies the type checking by
-                 # always returning a FluxTransformer2DModel from this function.
-                 model = q_model._wrapped
-             else:
-                 # The quantized model does not exist yet, quantize and save it.
-                 # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
-                 # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
-                 # here.
-                 model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
-                 assert isinstance(model, FluxTransformer2DModel)
-
-                 q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
-
-                 model_8bit_path.mkdir(parents=True, exist_ok=True)
-                 q_model.save_pretrained(model_8bit_path)
-
-                 # (See earlier comment about accessing the wrapped model.)
-                 model = q_model._wrapped
-         else:
-             model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
-
-         assert isinstance(model, FluxTransformer2DModel)
-         return model
-
-     @staticmethod
-     def _load_flux_vae(path: Path) -> AutoencoderKL:
-         model = AutoencoderKL.from_pretrained(path, local_files_only=True)
-         assert isinstance(model, AutoencoderKL)
-         return model
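Taken together, the removed _load_flux_* helpers (path-based from_pretrained loading plus ad-hoc 8-bit quantization) are replaced by cache-managed loads through the invocation context. A condensed sketch of the resulting pattern follows, assuming that each context.models.load(...) call returns a wrapper that is entered as a context manager and yields the loaded submodel, as used in the hunks above.

# Condensed sketch of the replacement loading pattern (assumption: the
# object returned by context.models.load(...) is a context manager that
# yields the loaded submodel while it is held in the model cache).
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
t5_text_encoder_info = context.models.load(self.t5Encoder.text_encoder)
with (
    clip_text_encoder_info as clip_text_encoder,
    t5_text_encoder_info as t5_text_encoder,
):
    assert isinstance(clip_text_encoder, CLIPTextModel)
    assert isinstance(t5_text_encoder, T5EncoderModel)
    # Use the encoders here; they remain resident only while this block runs.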