Skip to content

feat(LoRA): support AI Toolkit LoRA for FLUX #8071

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 16, 2025
2 changes: 1 addition & 1 deletion invokeai/backend/model_manager/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def flux_lora_format(cls, mod: ModelOnDisk):
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict

sd = mod.load_state_dict(mod.path)
value = flux_format_from_state_dict(sd)
value = flux_format_from_state_dict(sd, mod.metadata())
mod.cache[key] = value
return value

Expand Down
6 changes: 6 additions & 0 deletions invokeai/backend/model_manager/load/model_loaders/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
ModelType,
SubModelType,
)
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
is_state_dict_likely_in_flux_aitoolkit_format,
lora_model_from_flux_aitoolkit_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
is_state_dict_likely_flux_control,
lora_model_from_flux_control_state_dict,
Expand Down Expand Up @@ -92,6 +96,8 @@ def _load_model(
model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
elif is_state_dict_likely_flux_control(state_dict=state_dict):
model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict=state_dict):
model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict)
else:
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
else:
Expand Down
1 change: 1 addition & 0 deletions invokeai/backend/model_manager/taxonomy.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class FluxLoRAFormat(str, Enum):
Kohya = "flux.kohya"
OneTrainer = "flux.onetrainer"
Control = "flux.control"
AIToolkit = "flux.aitoolkit"


AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import json
from dataclasses import dataclass, field
from typing import Any

import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import _group_by_layer
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.util import InvokeAILogger


def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool:
if metadata:
try:
software = json.loads(metadata.get("software", "{}"))
except json.JSONDecodeError:
return False
return software.get("name") == "ai-toolkit"
# metadata got lost somewhere
return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys())


@dataclass
class GroupedStateDict:
transformer: dict[str, Any] = field(default_factory=dict)
# might also grow CLIP and T5 submodels


def _group_state_by_submodel(state_dict: dict[str, Any]) -> GroupedStateDict:
logger = InvokeAILogger.get_logger()
grouped = GroupedStateDict()
for key, value in state_dict.items():
submodel_name, param_name = key.split(".", 1)
match submodel_name:
case "diffusion_model":
grouped.transformer[param_name] = value
case _:
logger.warning(f"Unexpected submodel name: {submodel_name}")
return grouped


def _rename_peft_lora_keys(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Renames keys from the PEFT LoRA format to the InvokeAI format."""
renamed_state_dict = {}
for key, value in state_dict.items():
renamed_key = key.replace(".lora_A.", ".lora_down.").replace(".lora_B.", ".lora_up.")
renamed_state_dict[renamed_key] = value
return renamed_state_dict


def lora_model_from_flux_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
state_dict = _rename_peft_lora_keys(state_dict)
by_layer = _group_by_layer(state_dict)
by_model = _group_state_by_submodel(by_layer)

layers: dict[str, BaseLayerPatch] = {}
for layer_key, layer_state_dict in by_model.transformer.items():
layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)

return ModelPatchRaw(layers=layers)
7 changes: 6 additions & 1 deletion invokeai/backend/patches/lora_conversions/formats.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
is_state_dict_likely_in_flux_aitoolkit_format,
)
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
Expand All @@ -11,7 +14,7 @@
)


def flux_format_from_state_dict(state_dict):
def flux_format_from_state_dict(state_dict: dict, metadata: dict | None = None) -> FluxLoRAFormat | None:
if is_state_dict_likely_in_flux_kohya_format(state_dict):
return FluxLoRAFormat.Kohya
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict):
Expand All @@ -20,5 +23,7 @@ def flux_format_from_state_dict(state_dict):
return FluxLoRAFormat.Diffusers
elif is_state_dict_likely_flux_control(state_dict):
return FluxLoRAFormat.Control
elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict, metadata):
return FluxLoRAFormat.AIToolkit
else:
return None
Loading