Skip to content

Commit d06c49b

Browse files
committed
Setup flux model loading in the UI
1 parent a8a2fc1 commit d06c49b

File tree

22 files changed

+814
-138
lines changed

22 files changed

+814
-138
lines changed

invokeai/app/invocations/fields.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
4040

4141
# region Model Field Types
4242
MainModel = "MainModelField"
43+
FluxMainModel = "FluxMainModelField"
4344
SDXLMainModel = "SDXLMainModelField"
4445
SDXLRefinerModel = "SDXLRefinerModelField"
4546
ONNXModel = "ONNXModelField"
@@ -126,12 +127,14 @@ class FieldDescriptions:
126127
noise = "Noise tensor"
127128
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
128129
unet = "UNet (scheduler, LoRAs)"
130+
transformer = "Transformer"
129131
vae = "VAE"
130132
cond = "Conditioning tensor"
131133
controlnet_model = "ControlNet model to load"
132134
vae_model = "VAE model to load"
133135
lora_model = "LoRA model to load"
134136
main_model = "Main model (UNet, VAE, CLIP) to load"
137+
flux_model = "Flux model (Transformer, VAE, CLIP) to load"
135138
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
136139
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
137140
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"

invokeai/app/invocations/flux_text_to_image.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
from pathlib import Path
22
from typing import Literal
3+
from pydantic import Field
34

45
import torch
56
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
67
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
78
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
9+
from invokeai.app.invocations.model import ModelIdentifierField
810
from optimum.quanto import qfloat8
911
from PIL import Image
1012
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
1113
from transformers.models.auto import AutoModelForTextEncoding
1214

1315
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
14-
from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata
16+
from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata, UIType, Input
1517
from invokeai.app.invocations.primitives import ImageOutput
1618
from invokeai.app.services.shared.invocation_context import InvocationContext
1719
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
@@ -40,6 +42,11 @@ class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
4042
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
4143
"""Text-to-image generation using a FLUX model."""
4244

45+
flux_model: ModelIdentifierField = InputField(
46+
description="The Flux model",
47+
input=Input.Any,
48+
ui_type=UIType.FluxMainModel
49+
)
4350
model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
4451
use_8bit: bool = InputField(
4552
default=False, description="Whether to quantize the transformer model to 8-bit precision."

invokeai/app/invocations/model.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ class CLIPField(BaseModel):
6060
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
6161

6262

63+
64+
class TransformerField(BaseModel):
65+
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
66+
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
67+
68+
6369
class VAEField(BaseModel):
6470
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
6571
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@@ -122,6 +128,49 @@ def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
122128
return ModelIdentifierOutput(model=self.model)
123129

124130

131+
@invocation_output("flux_model_loader_output")
132+
class FluxModelLoaderOutput(BaseInvocationOutput):
133+
"""Flux base model loader output"""
134+
135+
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
136+
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
137+
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
138+
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
139+
140+
141+
@invocation("flux_model_loader", title="Flux Main Model", tags=["model", "flux"], category="model", version="1.0.3")
142+
class FluxModelLoaderInvocation(BaseInvocation):
143+
"""Loads a flux base model, outputting its submodels."""
144+
145+
model: ModelIdentifierField = InputField(
146+
description=FieldDescriptions.flux_model,
147+
ui_type=UIType.FluxMainModel,
148+
input=Input.Direct,
149+
)
150+
151+
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
152+
model_key = self.model.key
153+
154+
# TODO: not found exceptions
155+
if not context.models.exists(model_key):
156+
raise Exception(f"Unknown model: {model_key}")
157+
158+
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
159+
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
160+
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
161+
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
162+
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
163+
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
164+
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
165+
166+
return FluxModelLoaderOutput(
167+
transformer=TransformerField(transformer=transformer, scheduler=scheduler),
168+
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
169+
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
170+
vae=VAEField(vae=vae),
171+
)
172+
173+
125174
@invocation(
126175
"main_model_loader",
127176
title="Main Model",

invokeai/backend/model_manager/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class BaseModelType(str, Enum):
5252
StableDiffusion2 = "sd-2"
5353
StableDiffusionXL = "sdxl"
5454
StableDiffusionXLRefiner = "sdxl-refiner"
55+
Flux = "flux"
5556
# Kandinsky2_1 = "kandinsky-2.1"
5657

5758

@@ -74,6 +75,7 @@ class SubModelType(str, Enum):
7475
"""Submodel type."""
7576

7677
UNet = "unet"
78+
Transformer = "transformer"
7779
TextEncoder = "text_encoder"
7880
TextEncoder2 = "text_encoder_2"
7981
Tokenizer = "tokenizer"

invokeai/backend/model_manager/probe.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class ModelProbe(object):
9595
}
9696

9797
CLASS2TYPE = {
98+
"FluxPipeline": ModelType.Main,
9899
"StableDiffusionPipeline": ModelType.Main,
99100
"StableDiffusionInpaintPipeline": ModelType.Main,
100101
"StableDiffusionXLPipeline": ModelType.Main,
@@ -626,6 +627,10 @@ def get_repo_variant(self) -> ModelRepoVariant:
626627

627628
class PipelineFolderProbe(FolderProbeBase):
628629
def get_base_type(self) -> BaseModelType:
630+
with open(f"{self.model_path}/model_index.json", "r") as file:
631+
conf = json.load(file)
632+
if "_class_name" in conf and conf.get("_class_name") == "FluxPipeline":
633+
return BaseModelType.Flux
629634
with open(self.model_path / "unet" / "config.json", "r") as file:
630635
unet_conf = json.load(file)
631636
if unet_conf["cross_attention_dim"] == 768:

invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
1313
'sd-2': 'teal',
1414
sdxl: 'invokeBlue',
1515
'sdxl-refiner': 'invokeBlue',
16+
flux: 'invokeBlue',
1617
};
1718

1819
const ModelBaseBadge = ({ base }: Props) => {

invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import {
1414
isEnumFieldInputTemplate,
1515
isFloatFieldInputInstance,
1616
isFloatFieldInputTemplate,
17+
isFluxMainModelFieldInputInstance,
18+
isFluxMainModelFieldInputTemplate,
1719
isImageFieldInputInstance,
1820
isImageFieldInputTemplate,
1921
isIntegerFieldInputInstance,
@@ -48,6 +50,7 @@ import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
4850
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
4951
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
5052
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
53+
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
5154
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
5255
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
5356
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
@@ -69,6 +72,7 @@ type InputFieldProps = {
6972
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
7073
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
7174
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
75+
window.console.log("Hit 0")
7276

7377
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
7478
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
@@ -145,6 +149,9 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
145149
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
146150
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
147151
}
152+
if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) {
153+
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
154+
}
148155

149156
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
150157
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
2+
import { useAppDispatch } from 'app/store/storeHooks';
3+
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
4+
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
5+
import type { FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate } from 'features/nodes/types/field';
6+
import { memo, useCallback } from 'react';
7+
import { useFluxModels } from 'services/api/hooks/modelsByType';
8+
import type { MainModelConfig } from 'services/api/types';
9+
10+
import type { FieldComponentProps } from './types';
11+
12+
type Props = FieldComponentProps<FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate>;
13+
14+
const FluxMainModelFieldInputComponent = (props: Props) => {
15+
const { nodeId, field } = props;
16+
const dispatch = useAppDispatch();
17+
const [modelConfigs, { isLoading }] = useFluxModels();
18+
const _onChange = useCallback(
19+
(value: MainModelConfig | null) => {
20+
if (!value) {
21+
return;
22+
}
23+
dispatch(
24+
fieldMainModelValueChanged({
25+
nodeId,
26+
fieldName: field.name,
27+
value,
28+
})
29+
);
30+
},
31+
[dispatch, field.name, nodeId]
32+
);
33+
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
34+
modelConfigs,
35+
onChange: _onChange,
36+
isLoading,
37+
selectedModel: field.value,
38+
});
39+
40+
return (
41+
<Flex w="full" alignItems="center" gap={2}>
42+
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
43+
<Combobox
44+
value={value}
45+
placeholder={placeholder}
46+
options={options}
47+
onChange={onChange}
48+
noOptionsMessage={noOptionsMessage}
49+
/>
50+
</FormControl>
51+
</Flex>
52+
);
53+
};
54+
55+
export default memo(FluxMainModelFieldInputComponent);

invokeai/frontend/web/src/features/nodes/types/common.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
6161
// #endregion
6262

6363
// #region Model-related schemas
64-
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
64+
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
6565
const zModelType = z.enum([
6666
'main',
6767
'vae',
@@ -76,6 +76,7 @@ const zModelType = z.enum([
7676
]);
7777
const zSubModelType = z.enum([
7878
'unet',
79+
'transformer',
7980
'text_encoder',
8081
'text_encoder_2',
8182
'tokenizer',

invokeai/frontend/web/src/features/nodes/types/constants.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ export const MODEL_TYPES = [
3131
'ControlNetModelField',
3232
'LoRAModelField',
3333
'MainModelField',
34+
'FluxMainModelField',
3435
'SDXLMainModelField',
3536
'SDXLRefinerModelField',
3637
'VaeModelField',
@@ -61,13 +62,15 @@ export const FIELD_COLORS: { [key: string]: string } = {
6162
LatentsField: 'pink.500',
6263
LoRAModelField: 'teal.500',
6364
MainModelField: 'teal.500',
65+
FluxMainModelField: 'teal.500',
6466
SDXLMainModelField: 'teal.500',
6567
SDXLRefinerModelField: 'teal.500',
6668
SpandrelImageToImageModelField: 'teal.500',
6769
StringField: 'yellow.500',
6870
T2IAdapterField: 'teal.500',
6971
T2IAdapterModelField: 'teal.500',
7072
UNetField: 'red.500',
73+
TransformerField: 'red.500',
7174
VAEField: 'blue.500',
7275
VAEModelField: 'teal.500',
7376
};

invokeai/frontend/web/src/features/nodes/types/field.ts

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ const zSDXLMainModelFieldType = zFieldTypeBase.extend({
115115
name: z.literal('SDXLMainModelField'),
116116
originalType: zStatelessFieldType.optional(),
117117
});
118+
const zFluxMainModelFieldType = zFieldTypeBase.extend({
119+
name: z.literal('FluxMainModelField'),
120+
originalType: zStatelessFieldType.optional(),
121+
});
118122
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
119123
name: z.literal('SDXLRefinerModelField'),
120124
originalType: zStatelessFieldType.optional(),
@@ -158,6 +162,7 @@ const zStatefulFieldType = z.union([
158162
zModelIdentifierFieldType,
159163
zMainModelFieldType,
160164
zSDXLMainModelFieldType,
165+
zFluxMainModelFieldType,
161166
zSDXLRefinerModelFieldType,
162167
zVAEModelFieldType,
163168
zLoRAModelFieldType,
@@ -447,6 +452,29 @@ export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMain
447452
zSDXLMainModelFieldInputTemplate.safeParse(val).success;
448453
// #endregion
449454

455+
// #region FluxMainModelField
456+
457+
const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
458+
const zFluxMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
459+
value: zFluxMainModelFieldValue,
460+
});
461+
const zFluxMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
462+
type: zFluxMainModelFieldType,
463+
originalType: zFieldType.optional(),
464+
default: zFluxMainModelFieldValue,
465+
});
466+
const zFluxMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
467+
type: zFluxMainModelFieldType,
468+
});
469+
export type FluxMainModelFieldInputInstance = z.infer<typeof zFluxMainModelFieldInputInstance>;
470+
export type FluxMainModelFieldInputTemplate = z.infer<typeof zFluxMainModelFieldInputTemplate>;
471+
export const isFluxMainModelFieldInputInstance = (val: unknown): val is FluxMainModelFieldInputInstance =>
472+
zFluxMainModelFieldInputInstance.safeParse(val).success;
473+
export const isFluxMainModelFieldInputTemplate = (val: unknown): val is FluxMainModelFieldInputTemplate =>
474+
zFluxMainModelFieldInputTemplate.safeParse(val).success;
475+
476+
// #endregion
477+
450478
// #region SDXLRefinerModelField
451479

452480
/** @alias */ // tells knip to ignore this duplicate export
@@ -693,6 +721,7 @@ export const zStatefulFieldValue = z.union([
693721
zModelIdentifierFieldValue,
694722
zMainModelFieldValue,
695723
zSDXLMainModelFieldValue,
724+
zFluxMainModelFieldValue,
696725
zSDXLRefinerModelFieldValue,
697726
zVAEModelFieldValue,
698727
zLoRAModelFieldValue,
@@ -720,6 +749,7 @@ const zStatefulFieldInputInstance = z.union([
720749
zBoardFieldInputInstance,
721750
zModelIdentifierFieldInputInstance,
722751
zMainModelFieldInputInstance,
752+
zFluxMainModelFieldInputInstance,
723753
zSDXLMainModelFieldInputInstance,
724754
zSDXLRefinerModelFieldInputInstance,
725755
zVAEModelFieldInputInstance,
@@ -749,6 +779,7 @@ const zStatefulFieldInputTemplate = z.union([
749779
zBoardFieldInputTemplate,
750780
zModelIdentifierFieldInputTemplate,
751781
zMainModelFieldInputTemplate,
782+
zFluxMainModelFieldInputTemplate,
752783
zSDXLMainModelFieldInputTemplate,
753784
zSDXLRefinerModelFieldInputTemplate,
754785
zVAEModelFieldInputTemplate,
@@ -779,6 +810,7 @@ const zStatefulFieldOutputTemplate = z.union([
779810
zBoardFieldOutputTemplate,
780811
zModelIdentifierFieldOutputTemplate,
781812
zMainModelFieldOutputTemplate,
813+
zFluxMainModelFieldOutputTemplate,
782814
zSDXLMainModelFieldOutputTemplate,
783815
zSDXLRefinerModelFieldOutputTemplate,
784816
zVAEModelFieldOutputTemplate,

invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ const FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2: {
114114
isCollection: false,
115115
isCollectionOrScalar: false,
116116
},
117+
FluxMainModelField: {
118+
name: 'FluxMainModelField',
119+
isCollection: false,
120+
isCollectionOrScalar: false,
121+
},
117122
SDXLMainModelField: {
118123
name: 'SDXLMainModelField',
119124
isCollection: false,

0 commit comments

Comments
 (0)