-
Notifications
You must be signed in to change notification settings - Fork 2.6k
feat(LoRA): support AI Toolkit LoRA for FLUX #8071
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
d2da351
feat(LoRA): support AI Toolkit LoRA for FLUX [WIP]
0ea5581
fix(LoRA): add ai-toolkit to lora loader
ab81a51
WIP!: …they weren't in diffusers format…
a761fcd
test: add some aitoolkit lora tests
0ca7a05
fix: group aitoolkit lora layers
f392e39
fix: group aitoolkit lora layers
b0c3520
fix: move AI Toolkit to the bottom of the detection list
a79c5e3
Merge branch 'main' into feat/aitoolkit-lora
13fdbd8
Merge branch 'refs/heads/main' into feat/aitoolkit-lora
dd32bd2
Merge branch 'refs/heads/main' into feat/aitoolkit-lora
2688724
Merge branch 'main' into feat/aitoolkit-lora
jazzhaiku File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
63 changes: 63 additions & 0 deletions
63
invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import json | ||
from dataclasses import dataclass, field | ||
from typing import Any | ||
|
||
import torch | ||
|
||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch | ||
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict | ||
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import _group_by_layer | ||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX | ||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw | ||
from invokeai.backend.util import InvokeAILogger | ||
|
||
|
||
def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool: | ||
if metadata: | ||
try: | ||
software = json.loads(metadata.get("software", "{}")) | ||
except json.JSONDecodeError: | ||
return False | ||
return software.get("name") == "ai-toolkit" | ||
# metadata got lost somewhere | ||
return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys()) | ||
|
||
|
||
@dataclass | ||
class GroupedStateDict: | ||
transformer: dict[str, Any] = field(default_factory=dict) | ||
# might also grow CLIP and T5 submodels | ||
|
||
|
||
def _group_state_by_submodel(state_dict: dict[str, Any]) -> GroupedStateDict: | ||
logger = InvokeAILogger.get_logger() | ||
grouped = GroupedStateDict() | ||
for key, value in state_dict.items(): | ||
submodel_name, param_name = key.split(".", 1) | ||
match submodel_name: | ||
case "diffusion_model": | ||
grouped.transformer[param_name] = value | ||
case _: | ||
logger.warning(f"Unexpected submodel name: {submodel_name}") | ||
return grouped | ||
|
||
|
||
def _rename_peft_lora_keys(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: | ||
"""Renames keys from the PEFT LoRA format to the InvokeAI format.""" | ||
renamed_state_dict = {} | ||
for key, value in state_dict.items(): | ||
renamed_key = key.replace(".lora_A.", ".lora_down.").replace(".lora_B.", ".lora_up.") | ||
renamed_state_dict[renamed_key] = value | ||
return renamed_state_dict | ||
|
||
|
||
def lora_model_from_flux_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw: | ||
state_dict = _rename_peft_lora_keys(state_dict) | ||
by_layer = _group_by_layer(state_dict) | ||
by_model = _group_state_by_submodel(by_layer) | ||
|
||
layers: dict[str, BaseLayerPatch] = {} | ||
for layer_key, layer_state_dict in by_model.transformer.items(): | ||
layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict) | ||
|
||
return ModelPatchRaw(layers=layers) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.