
Commit c8f84d2

Manage quantization of models within the loader

committed
1 parent 9bf8ac7 · commit c8f84d2

File tree

8 files changed: +244 -284 lines changed


invokeai/app/invocations/fields.py

Lines changed: 1 addition & 0 deletions
@@ -126,6 +126,7 @@ class FieldDescriptions:
     negative_cond = "Negative conditioning tensor"
     noise = "Noise tensor"
     clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
+    t5Encoder = "T5 tokenizer and text encoder"
     unet = "UNet (scheduler, LoRAs)"
     transformer = "Transformer"
     vae = "VAE"

invokeai/app/invocations/flux_text_to_image.py

Lines changed: 54 additions & 127 deletions
@@ -6,14 +6,14 @@
 from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
-from invokeai.app.invocations.model import ModelIdentifierField
+from invokeai.app.invocations.model import TransformerField, CLIPField, T5EncoderField, VAEField
 from optimum.quanto import qfloat8
 from PIL import Image
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
 from transformers.models.auto import AutoModelForTextEncoding

 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata, UIType, Input
+from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata, Input, FieldDescriptions
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
@@ -42,14 +42,24 @@ class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
 class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Text-to-image generation using a FLUX model."""

-    flux_model: ModelIdentifierField = InputField(
-        description="The Flux model",
-        input=Input.Any,
-        ui_type=UIType.FluxMainModel
+    transformer: TransformerField = InputField(
+        description=FieldDescriptions.unet,
+        input=Input.Connection,
+        title="Transformer",
     )
-    model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
-    use_8bit: bool = InputField(
-        default=False, description="Whether to quantize the transformer model to 8-bit precision."
+    clip: CLIPField = InputField(
+        title="CLIP",
+        description=FieldDescriptions.clip,
+        input=Input.Connection,
+    )
+    t5Encoder: T5EncoderField = InputField(
+        title="T5EncoderField",
+        description=FieldDescriptions.t5Encoder,
+        input=Input.Connection,
+    )
+    vae: VAEField = InputField(
+        description=FieldDescriptions.vae,
+        input=Input.Connection,
     )
     positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
     width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
@@ -63,45 +73,40 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
-
-        t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
-        latents = self._run_diffusion(context, model_path, clip_embeddings, t5_embeddings)
-        image = self._run_vae_decoding(context, model_path, latents)
+        t5_embeddings, clip_embeddings = self._encode_prompt(context)
+        latents = self._run_diffusion(context, clip_embeddings, t5_embeddings)
+        image = self._run_vae_decoding(context, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)

-    def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
-        # Determine the T5 max sequence length based on the model.
-        if self.model == "flux-schnell":
-            max_seq_len = 256
-        # elif self.model == "flux-dev":
-        #     max_seq_len = 512
-        else:
-            raise ValueError(f"Unknown model: {self.model}")
+    def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
+        # TODO: Determine the T5 max sequence length based on the model.
+        # if self.model == "flux-schnell":
+        max_seq_len = 256
+        # # elif self.model == "flux-dev":
+        # #     max_seq_len = 512
+        # else:
+        #     raise ValueError(f"Unknown model: {self.model}")

-        # Load the CLIP tokenizer.
-        clip_tokenizer_path = flux_model_dir / "tokenizer"
-        clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
-        assert isinstance(clip_tokenizer, CLIPTokenizer)
+        # Load CLIP.
+        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
+        clip_text_encoder_info = context.models.load(self.clip.text_encoder)

-        # Load the T5 tokenizer.
-        t5_tokenizer_path = flux_model_dir / "tokenizer_2"
-        t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
-        assert isinstance(t5_tokenizer, T5TokenizerFast)
+        # Load T5.
+        t5_tokenizer_info = context.models.load(self.t5Encoder.tokenizer)
+        t5_text_encoder_info = context.models.load(self.t5Encoder.text_encoder)

-        clip_text_encoder_path = flux_model_dir / "text_encoder"
-        t5_text_encoder_path = flux_model_dir / "text_encoder_2"
         with (
-            context.models.load_local_model(
-                model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
-            ) as clip_text_encoder,
-            context.models.load_local_model(
-                model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
-            ) as t5_text_encoder,
+            clip_text_encoder_info as clip_text_encoder,
+            t5_text_encoder_info as t5_text_encoder,
+            clip_tokenizer_info as clip_tokenizer,
+            t5_tokenizer_info as t5_tokenizer,
         ):
             assert isinstance(clip_text_encoder, CLIPTextModel)
             assert isinstance(t5_text_encoder, T5EncoderModel)
+            assert isinstance(clip_tokenizer, CLIPTokenizer)
+            assert isinstance(t5_tokenizer, T5TokenizerFast)
+
             pipeline = FluxPipeline(
                 scheduler=None,
                 vae=None,
@@ -114,7 +119,7 @@ def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tu

             # prompt_embeds: T5 embeddings
             # pooled_prompt_embeds: CLIP embeddings
-            prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+            prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
                 prompt=self.positive_prompt,
                 prompt_2=self.positive_prompt,
                 device=TorchDevice.choose_torch_device(),
@@ -128,22 +133,23 @@ def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tu
     def _run_diffusion(
         self,
         context: InvocationContext,
-        flux_model_dir: Path,
         clip_embeddings: torch.Tensor,
         t5_embeddings: torch.Tensor,
     ):
-        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_model_dir / "scheduler", local_files_only=True)
+        scheduler_info = context.models.load(self.transformer.scheduler)
+        transformer_info = context.models.load(self.transformer.transformer)

         # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
         # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
         # if the cache is not empty.
-        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
+        # context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)

-        transformer_path = flux_model_dir / "transformer"
-        with context.models.load_local_model(
-            model_path=transformer_path, loader=self._load_flux_transformer
-        ) as transformer:
+        with (
+            transformer_info as transformer,
+            scheduler_info as scheduler
+        ):
             assert isinstance(transformer, FluxTransformer2DModel)
+            assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)

             flux_pipeline_with_transformer = FluxPipeline(
                 scheduler=scheduler,
@@ -176,11 +182,10 @@ def _run_diffusion(
     def _run_vae_decoding(
         self,
         context: InvocationContext,
-        flux_model_dir: Path,
         latents: torch.Tensor,
     ) -> Image.Image:
-        vae_path = flux_model_dir / "vae"
-        with context.models.load_local_model(model_path=vae_path, loader=self._load_flux_vae) as vae:
+        vae_info = context.models.load(self.vae.vae)
+        with vae_info as vae:
             assert isinstance(vae, AutoencoderKL)

             flux_pipeline_with_vae = FluxPipeline(
@@ -205,81 +210,3 @@ def _run_vae_decoding(

         assert isinstance(image, Image.Image)
         return image
-
-    @staticmethod
-    def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
-        model = CLIPTextModel.from_pretrained(path, local_files_only=True)
-        assert isinstance(model, CLIPTextModel)
-        return model
-
-    def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
-        if self.use_8bit:
-            model_8bit_path = path / "quantized"
-            if model_8bit_path.exists():
-                # The quantized model exists, load it.
-                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
-                # something that we should be able to make much faster.
-                q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)
-
-                # Access the underlying wrapped model.
-                # We access the wrapped model, even though it is private, because it simplifies the type checking by
-                # always returning a T5EncoderModel from this function.
-                model = q_model._wrapped
-            else:
-                # The quantized model does not exist yet, quantize and save it.
-                # TODO(ryand): dtype?
-                model = T5EncoderModel.from_pretrained(path, local_files_only=True)
-                assert isinstance(model, T5EncoderModel)
-
-                q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
-
-                model_8bit_path.mkdir(parents=True, exist_ok=True)
-                q_model.save_pretrained(model_8bit_path)
-
-                # (See earlier comment about accessing the wrapped model.)
-                model = q_model._wrapped
-        else:
-            model = T5EncoderModel.from_pretrained(path, local_files_only=True)
-
-        assert isinstance(model, T5EncoderModel)
-        return model
-
-    def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
-        if self.use_8bit:
-            model_8bit_path = path / "quantized"
-            if model_8bit_path.exists():
-                # The quantized model exists, load it.
-                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
-                # something that we should be able to make much faster.
-                q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
-
-                # Access the underlying wrapped model.
-                # We access the wrapped model, even though it is private, because it simplifies the type checking by
-                # always returning a FluxTransformer2DModel from this function.
-                model = q_model._wrapped
-            else:
-                # The quantized model does not exist yet, quantize and save it.
-                # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
-                # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
-                # here.
-                model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
-                assert isinstance(model, FluxTransformer2DModel)
-
-                q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
-
-                model_8bit_path.mkdir(parents=True, exist_ok=True)
-                q_model.save_pretrained(model_8bit_path)
-
-                # (See earlier comment about accessing the wrapped model.)
-                model = q_model._wrapped
-        else:
-            model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
-
-        assert isinstance(model, FluxTransformer2DModel)
-        return model
-
-    @staticmethod
-    def _load_flux_vae(path: Path) -> AutoencoderKL:
-        model = AutoencoderKL.from_pretrained(path, local_files_only=True)
-        assert isinstance(model, AutoencoderKL)
-        return model
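
For context, the helpers deleted above all followed the same quantize-on-first-use pattern, which is exactly the responsibility this commit moves behind the model loader. Below is a condensed sketch of that pattern, assuming the QuantizedFluxTransformer2DModel wrapper defined near the top of this file is in scope; the helper name and on-disk layout are illustrative, not part of the commit.

from pathlib import Path

import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from optimum.quanto import qfloat8


def load_or_quantize_transformer(path: Path, use_8bit: bool) -> FluxTransformer2DModel:
    # Condensed from the removed _load_flux_transformer(); assumes
    # QuantizedFluxTransformer2DModel (a FastQuantizedDiffusersModel subclass) is defined in this module.
    if not use_8bit:
        return FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)

    quantized_path = path / "quantized"
    if quantized_path.exists():
        # A quantized copy was saved on a previous run; reload it (requantize() is the slow step).
        q_model = QuantizedFluxTransformer2DModel.from_pretrained(quantized_path)
    else:
        # First run: quantize to float8 and cache the result next to the original weights.
        model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
        q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
        quantized_path.mkdir(parents=True, exist_ok=True)
        q_model.save_pretrained(quantized_path)

    # Return the wrapped diffusers model so callers always see a FluxTransformer2DModel.
    return q_model._wrapped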

invokeai/app/invocations/model.py

Lines changed: 7 additions & 3 deletions
@@ -65,6 +65,10 @@ class TransformerField(BaseModel):
     transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
     scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")

+class T5EncoderField(BaseModel):
+    tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
+    text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
+

 class VAEField(BaseModel):
     vae: ModelIdentifierField = Field(description="Info to load vae submodel")
@@ -133,8 +137,8 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
     """Flux base model loader output"""

     transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
-    clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
-    clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
+    clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
+    t5Encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5Encoder, title="T5 Encoder")
     vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")


@@ -166,7 +170,7 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
         return FluxModelLoaderOutput(
             transformer=TransformerField(transformer=transformer, scheduler=scheduler),
             clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
-            clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
+            t5Encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=text_encoder2),
             vae=VAEField(vae=vae),
         )

invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py

Lines changed: 6 additions & 1 deletion
@@ -78,7 +78,12 @@ def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelTy

     # TO DO: Add exception handling
     def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin:  # fix with correct type
-        if module in ["diffusers", "transformers"]:
+        if module in [
+            "diffusers",
+            "transformers",
+            "invokeai.backend.quantization.fast_quantized_transformers_model",
+            "invokeai.backend.quantization.fast_quantized_diffusion_model",
+        ]:
             res_type = sys.modules[module]
         else:
             res_type = sys.modules["diffusers"].pipelines
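
The whitelist matters because this method resolves classes from modules looked up in sys.modules rather than importing arbitrary dotted paths, so the two quantization wrapper modules have to be named explicitly here. A generic illustration of that lookup pattern (not the full InvokeAI method, which also falls back to diffusers pipelines):

import importlib
import sys


def resolve_class(module: str, class_name: str) -> type:
    # Make sure the module is imported, then pull the class off it by name.
    importlib.import_module(module)
    return getattr(sys.modules[module], class_name)


# e.g. resolve_class("invokeai.backend.quantization.fast_quantized_diffusion_model",
#                    "FastQuantizedDiffusersModel")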

invokeai/backend/model_manager/load/model_util.py

Lines changed: 8 additions & 1 deletion
@@ -9,7 +9,7 @@
 import torch
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SchedulerMixin
-from transformers import CLIPTokenizer
+from transformers import CLIPTokenizer, T5TokenizerFast

 from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
 from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
@@ -48,6 +48,13 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
         ),
     ):
         return model.calc_size()
+    elif isinstance(
+        model,
+        (
+            T5TokenizerFast,
+        ),
+    ):
+        return len(model)
     else:
         # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
         # supported model types.
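
One note on the new branch: len() on a Hugging Face tokenizer returns its vocabulary size, so this reports a rough token-count figure rather than a size in bytes. A small hedged illustration (the repo id is only an example of a T5 tokenizer source, not something this commit references):

from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")  # example source only
print(len(tokenizer))  # vocabulary size; what calc_model_size_by_data() now returns for this type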

invokeai/backend/quantization/fast_quantized_diffusion_model.py

Lines changed: 8 additions & 6 deletions
@@ -12,15 +12,17 @@
 )
 from optimum.quanto.models import QuantizedDiffusersModel
 from optimum.quanto.models.shared_dict import ShardedStateDict
+from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

 from invokeai.backend.requantize import requantize


 class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
     @classmethod
-    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
+    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike], base_class = FluxTransformer2DModel, **kwargs):
         """We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
-        if cls.base_class is None:
+        base_class = base_class or cls.base_class
+        if base_class is None:
             raise ValueError("The `base_class` attribute needs to be configured.")

         if not is_accelerate_available():
@@ -43,16 +45,16 @@ def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):

         with open(model_config_path, "r", encoding="utf-8") as f:
             original_model_cls_name = json.load(f)["_class_name"]
-        configured_cls_name = cls.base_class.__name__
+        configured_cls_name = base_class.__name__
         if configured_cls_name != original_model_cls_name:
             raise ValueError(
                 f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
             )

         # Create an empty model
-        config = cls.base_class.load_config(model_name_or_path)
+        config = base_class.load_config(model_name_or_path)
         with init_empty_weights():
-            model = cls.base_class.from_config(config)
+            model = base_class.from_config(config)

         # Look for the index of a sharded checkpoint
         checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
@@ -72,6 +74,6 @@ def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
             # Requantize and load quantized weights from state_dict
             requantize(model, state_dict=state_dict, quantization_map=qmap)
             model.eval()
-            return cls(model)
+            return cls(model)._wrapped
         else:
             raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
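
A hedged usage sketch of the modified signature: a caller such as a model loader can now pass the concrete diffusers class instead of relying on a configured cls.base_class, and since from_pretrained() now returns cls(model)._wrapped, the call hands back the plain FluxTransformer2DModel rather than the quanto wrapper. The path is illustrative.

from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel

model = FastQuantizedDiffusersModel.from_pretrained(
    "/models/flux/transformer/quantized",  # illustrative path to previously saved quantized weights
    base_class=FluxTransformer2DModel,
)
assert isinstance(model, FluxTransformer2DModel)  # unwrapped model, ready for the rest of the pipeline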
