diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index f4c057cb5b7..f312aeffb67 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -296,7 +296,7 @@ def flux_lora_format(cls, mod: ModelOnDisk): from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict sd = mod.load_state_dict(mod.path) - value = flux_format_from_state_dict(sd) + value = flux_format_from_state_dict(sd, mod.metadata()) mod.cache[key] = value return value diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py index 071713316a1..076919a14fe 100644 --- a/invokeai/backend/model_manager/load/model_loaders/lora.py +++ b/invokeai/backend/model_manager/load/model_loaders/lora.py @@ -20,6 +20,10 @@ ModelType, SubModelType, ) +from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import ( + is_state_dict_likely_in_flux_aitoolkit_format, + lora_model_from_flux_aitoolkit_state_dict, +) from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import ( is_state_dict_likely_flux_control, lora_model_from_flux_control_state_dict, @@ -92,6 +96,8 @@ def _load_model( model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict) elif is_state_dict_likely_flux_control(state_dict=state_dict): model = lora_model_from_flux_control_state_dict(state_dict=state_dict) + elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict=state_dict): + model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict) else: raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}") else: diff --git a/invokeai/backend/model_manager/taxonomy.py b/invokeai/backend/model_manager/taxonomy.py index 37c1197cc47..e009b203e14 100644 --- a/invokeai/backend/model_manager/taxonomy.py +++ b/invokeai/backend/model_manager/taxonomy.py @@ -137,6 +137,7 @@ class FluxLoRAFormat(str, Enum): Kohya = "flux.kohya" OneTrainer = "flux.onetrainer" Control = "flux.control" + AIToolkit = "flux.aitoolkit" AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None] diff --git a/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py new file mode 100644 index 00000000000..6ca06a0355f --- /dev/null +++ b/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py @@ -0,0 +1,63 @@ +import json +from dataclasses import dataclass, field +from typing import Any + +import torch + +from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch +from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict +from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import _group_by_layer +from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX +from invokeai.backend.patches.model_patch_raw import ModelPatchRaw +from invokeai.backend.util import InvokeAILogger + + +def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool: + if metadata: + try: + software = json.loads(metadata.get("software", "{}")) + except json.JSONDecodeError: + return False + return software.get("name") == "ai-toolkit" + # metadata got lost somewhere + return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys()) + + +@dataclass +class GroupedStateDict: + transformer: dict[str, Any] = field(default_factory=dict) + # might also grow CLIP and T5 submodels + + +def _group_state_by_submodel(state_dict: dict[str, Any]) -> GroupedStateDict: + logger = InvokeAILogger.get_logger() + grouped = GroupedStateDict() + for key, value in state_dict.items(): + submodel_name, param_name = key.split(".", 1) + match submodel_name: + case "diffusion_model": + grouped.transformer[param_name] = value + case _: + logger.warning(f"Unexpected submodel name: {submodel_name}") + return grouped + + +def _rename_peft_lora_keys(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Renames keys from the PEFT LoRA format to the InvokeAI format.""" + renamed_state_dict = {} + for key, value in state_dict.items(): + renamed_key = key.replace(".lora_A.", ".lora_down.").replace(".lora_B.", ".lora_up.") + renamed_state_dict[renamed_key] = value + return renamed_state_dict + + +def lora_model_from_flux_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw: + state_dict = _rename_peft_lora_keys(state_dict) + by_layer = _group_by_layer(state_dict) + by_model = _group_state_by_submodel(by_layer) + + layers: dict[str, BaseLayerPatch] = {} + for layer_key, layer_state_dict in by_model.transformer.items(): + layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict) + + return ModelPatchRaw(layers=layers) diff --git a/invokeai/backend/patches/lora_conversions/formats.py b/invokeai/backend/patches/lora_conversions/formats.py index 46073100679..94f71e05ee6 100644 --- a/invokeai/backend/patches/lora_conversions/formats.py +++ b/invokeai/backend/patches/lora_conversions/formats.py @@ -1,4 +1,7 @@ from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat +from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import ( + is_state_dict_likely_in_flux_aitoolkit_format, +) from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import ( is_state_dict_likely_in_flux_diffusers_format, @@ -11,7 +14,7 @@ ) -def flux_format_from_state_dict(state_dict): +def flux_format_from_state_dict(state_dict: dict, metadata: dict | None = None) -> FluxLoRAFormat | None: if is_state_dict_likely_in_flux_kohya_format(state_dict): return FluxLoRAFormat.Kohya elif is_state_dict_likely_in_flux_onetrainer_format(state_dict): @@ -20,5 +23,7 @@ def flux_format_from_state_dict(state_dict): return FluxLoRAFormat.Diffusers elif is_state_dict_likely_flux_control(state_dict): return FluxLoRAFormat.Control + elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict, metadata): + return FluxLoRAFormat.AIToolkit else: return None diff --git a/tests/backend/patches/lora_conversions/lora_state_dicts/flux_lora_aitoolkit_format.py b/tests/backend/patches/lora_conversions/lora_state_dicts/flux_lora_aitoolkit_format.py new file mode 100644 index 00000000000..98b278df869 --- /dev/null +++ b/tests/backend/patches/lora_conversions/lora_state_dicts/flux_lora_aitoolkit_format.py @@ -0,0 +1,458 @@ +state_dict_keys = { + "diffusion_model.double_blocks.0.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.0.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.0.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.0.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.0.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.0.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.0.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.0.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.0.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.0.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.0.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.0.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.0.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.0.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.0.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.0.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.1.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.1.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.1.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.1.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.1.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.1.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.1.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.1.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.1.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.1.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.1.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.1.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.1.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.1.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.1.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.1.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.10.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.10.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.10.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.10.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.10.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.10.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.10.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.10.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.10.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.10.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.10.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.10.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.10.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.10.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.10.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.10.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.11.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.11.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.11.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.11.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.11.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.11.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.11.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.11.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.11.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.11.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.11.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.11.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.11.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.11.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.11.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.11.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.12.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.12.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.12.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.12.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.12.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.12.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.12.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.12.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.12.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.12.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.12.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.12.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.12.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.12.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.12.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.12.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.13.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.13.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.13.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.13.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.13.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.13.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.13.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.13.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.13.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.13.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.13.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.13.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.13.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.13.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.13.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.13.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.14.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.14.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.14.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.14.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.14.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.14.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.14.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.14.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.14.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.14.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.14.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.14.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.14.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.14.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.14.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.14.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.15.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.15.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.15.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.15.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.15.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.15.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.15.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.15.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.15.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.15.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.15.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.15.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.15.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.15.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.15.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.15.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.16.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.16.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.16.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.16.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.16.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.16.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.16.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.16.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.16.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.16.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.16.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.16.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.16.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.16.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.16.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.16.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.17.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.17.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.17.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.17.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.17.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.17.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.17.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.17.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.17.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.17.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.17.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.17.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.17.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.17.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.17.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.17.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.18.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.18.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.18.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.18.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.18.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.18.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.18.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.18.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.18.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.18.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.18.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.18.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.18.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.18.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.18.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.18.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.2.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.2.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.2.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.2.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.2.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.2.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.2.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.2.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.2.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.2.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.2.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.2.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.2.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.2.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.2.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.2.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.3.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.3.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.3.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.3.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.3.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.3.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.3.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.3.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.3.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.3.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.3.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.3.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.3.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.3.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.3.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.3.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.4.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.4.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.4.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.4.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.4.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.4.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.4.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.4.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.4.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.4.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.4.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.4.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.4.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.4.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.4.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.4.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.5.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.5.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.5.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.5.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.5.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.5.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.5.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.5.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.5.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.5.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.5.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.5.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.5.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.5.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.5.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.5.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.6.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.6.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.6.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.6.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.6.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.6.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.6.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.6.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.6.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.6.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.6.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.6.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.6.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.6.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.6.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.6.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.7.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.7.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.7.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.7.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.7.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.7.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.7.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.7.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.7.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.7.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.7.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.7.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.7.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.7.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.7.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.7.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.8.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.8.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.8.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.8.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.8.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.8.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.8.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.8.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.8.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.8.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.8.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.8.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.8.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.8.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.8.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.8.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.9.img_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.9.img_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.9.img_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.9.img_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.9.img_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.9.img_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.9.img_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.9.img_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.9.txt_attn.proj.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.9.txt_attn.proj.lora_B.weight": [3072, 16], + "diffusion_model.double_blocks.9.txt_attn.qkv.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.9.txt_attn.qkv.lora_B.weight": [9216, 16], + "diffusion_model.double_blocks.9.txt_mlp.0.lora_A.weight": [16, 3072], + "diffusion_model.double_blocks.9.txt_mlp.0.lora_B.weight": [12288, 16], + "diffusion_model.double_blocks.9.txt_mlp.2.lora_A.weight": [16, 12288], + "diffusion_model.double_blocks.9.txt_mlp.2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.0.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.0.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.0.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.0.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.1.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.1.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.1.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.1.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.10.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.10.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.10.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.10.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.11.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.11.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.11.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.11.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.12.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.12.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.12.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.12.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.13.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.13.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.13.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.13.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.14.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.14.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.14.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.14.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.15.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.15.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.15.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.15.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.16.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.16.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.16.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.16.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.17.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.17.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.17.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.17.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.18.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.18.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.18.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.18.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.19.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.19.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.19.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.19.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.2.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.2.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.2.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.2.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.20.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.20.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.20.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.20.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.21.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.21.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.21.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.21.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.22.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.22.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.22.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.22.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.23.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.23.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.23.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.23.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.24.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.24.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.24.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.24.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.25.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.25.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.25.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.25.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.26.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.26.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.26.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.26.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.27.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.27.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.27.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.27.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.28.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.28.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.28.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.28.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.29.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.29.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.29.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.29.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.3.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.3.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.3.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.3.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.30.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.30.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.30.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.30.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.31.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.31.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.31.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.31.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.32.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.32.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.32.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.32.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.33.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.33.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.33.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.33.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.34.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.34.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.34.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.34.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.35.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.35.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.35.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.35.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.36.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.36.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.36.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.36.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.37.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.37.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.37.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.37.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.4.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.4.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.4.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.4.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.5.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.5.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.5.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.5.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.6.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.6.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.6.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.6.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.7.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.7.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.7.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.7.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.8.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.8.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.8.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.8.linear2.lora_B.weight": [3072, 16], + "diffusion_model.single_blocks.9.linear1.lora_A.weight": [16, 3072], + "diffusion_model.single_blocks.9.linear1.lora_B.weight": [21504, 16], + "diffusion_model.single_blocks.9.linear2.lora_A.weight": [16, 15360], + "diffusion_model.single_blocks.9.linear2.lora_B.weight": [3072, 16], +} diff --git a/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py b/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py new file mode 100644 index 00000000000..ed3e05a9b26 --- /dev/null +++ b/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py @@ -0,0 +1,59 @@ +import accelerate +import pytest + +from invokeai.backend.flux.model import Flux +from invokeai.backend.flux.util import params +from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import ( + _group_state_by_submodel, + is_state_dict_likely_in_flux_aitoolkit_format, + lora_model_from_flux_aitoolkit_state_dict, +) +from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import ( + state_dict_keys as flux_onetrainer_state_dict_keys, +) +from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_aitoolkit_format import ( + state_dict_keys as flux_aitoolkit_state_dict_keys, +) +from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_diffusers_format import ( + state_dict_keys as flux_diffusers_state_dict_keys, +) +from tests.backend.patches.lora_conversions.lora_state_dicts.utils import keys_to_mock_state_dict + + +def test_is_state_dict_likely_in_flux_aitoolkit_format(): + state_dict = keys_to_mock_state_dict(flux_aitoolkit_state_dict_keys) + assert is_state_dict_likely_in_flux_aitoolkit_format(state_dict) + + +@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys, flux_onetrainer_state_dict_keys]) +def test_is_state_dict_likely_in_flux_kohya_format_false(sd_keys: dict[str, list[int]]): + state_dict = keys_to_mock_state_dict(sd_keys) + assert not is_state_dict_likely_in_flux_aitoolkit_format(state_dict) + + +def test_flux_aitoolkit_transformer_state_dict_is_in_invoke_format(): + state_dict = keys_to_mock_state_dict(flux_aitoolkit_state_dict_keys) + converted_state_dict = _group_state_by_submodel(state_dict).transformer + + # Extract the prefixes from the converted state dict (without the lora suffixes) + converted_key_prefixes: list[str] = [] + for k in converted_state_dict.keys(): + k = k.replace(".lora_A.weight", "") + k = k.replace(".lora_B.weight", "") + converted_key_prefixes.append(k) + + # Initialize a FLUX model on the meta device. + with accelerate.init_empty_weights(): + model = Flux(params["flux-schnell"]) + model_keys = set(model.state_dict().keys()) + + for converted_key_prefix in converted_key_prefixes: + assert any(model_key.startswith(converted_key_prefix) for model_key in model_keys), ( + f"'{converted_key_prefix}' did not match any model keys." + ) + + +def test_lora_model_from_flux_aitoolkit_state_dict(): + state_dict = keys_to_mock_state_dict(flux_aitoolkit_state_dict_keys) + + assert lora_model_from_flux_aitoolkit_state_dict(state_dict)