diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 000000000..496f08f8a --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,21 @@ +# Most of the stuff is Staxs +# core is somewhat general +* @Stax124 +/core/ @Stax124 @gabe56f +/frontend/ @Stax124 @gabe56f + +# Stax-specific +/core/* @Stax124 +/core/inference/esrgan/ @Stax124 + +# Gabe-specific stuff +/libs/ @gabe56f +/core/scheduling/ @gabe56f +/core/optimizations/ @gabe56f +/core/inference/injectables/ @gabe56f +/core/inference/utilities/kohya_hires.py @gabe56f +/core/inference/utilities/anisotropic.py @gabe56f +/core/inference/utilities/cfg.py @gabe56f +/core/inference/utilities/sag/ @gabe56f +/core/inference/utilities/prompt_expansion/ @gabe56f +/core/inference/onnx/ @gabe56f \ No newline at end of file diff --git a/.github/workflows/docker_build_tag.yml b/.github/workflows/docker_build_tag.yml new file mode 100644 index 000000000..09b1f4d40 --- /dev/null +++ b/.github/workflows/docker_build_tag.yml @@ -0,0 +1,27 @@ +name: Docker Build on Tag Push + +on: + push: + tags: + - "v*" + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Build and push + uses: docker/build-push-action@v5 + with: + context: . + file: docker/cuda/dockerfile + push: true + tags: ${{ secrets.DOCKERHUB_USERNAME }}/volta:${{ github.ref_name }}-cuda diff --git a/.gitignore b/.gitignore index 6bc19436a..c9d9b40b5 100644 --- a/.gitignore +++ b/.gitignore @@ -48,12 +48,12 @@ cover/ # Diffusers convert files out traced_unet/ -onnx +/onnx converted # Docker -test.docker-compose.yml -test-no-mount.docker-compose.yml +*test.docker-compose.yml +*test-no-mount.docker-compose.yml # Docs node_modules/ diff --git a/.vscode/launch.json b/.vscode/launch.json index c88c13bdd..07b6e4776 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -1,12 +1,12 @@ { - "configurations": [ - { - "name": "Python: File", - "type": "python", - "request": "launch", - "program": "main.py", - "args": ["--log-level=DEBUG"], - "justMyCode": false - } - ] + "configurations": [ + { + "name": "VoltaML API Debug", + "type": "python", + "request": "launch", + "program": "main.py", + "args": ["--log-level=DEBUG"], + "justMyCode": false + } + ] } diff --git a/.vscode/settings.json b/.vscode/settings.json index 00d3fda6c..2ef4fd2b7 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -4,11 +4,8 @@ "python.testing.pytestEnabled": true, "python.analysis.typeCheckingMode": "basic", "python.languageServer": "Pylance", - "rust-analyzer.linkedProjects": [ - "./manager/Cargo.toml" - ], + "rust-analyzer.linkedProjects": ["./manager/Cargo.toml"], "[python]": { "editor.defaultFormatter": "ms-python.black-formatter" - }, - "python.formatting.provider": "none" + } } diff --git a/api/app.py b/api/app.py index 8260c8905..eed283404 100644 --- a/api/app.py +++ b/api/app.py @@ -19,7 +19,9 @@ from api.websockets.data import Data from api.websockets.notification import Notification from core import shared +from core.files import get_full_model_path from core.types import InferenceBackend +from core.utils import determine_model_type logger = logging.getLogger(__name__) @@ -47,12 +49,33 @@ async def validation_exception_handler(_request: Request, exc: RequestValidation logger.debug(exc) + if 
exc._error_cache is not None and exc._error_cache[0]["loc"][0] == "body": + from core.config._config import Configuration + + default_value = Configuration() + keys = [str(i) for i in exc._error_cache[0]["loc"][1:]] # type: ignore + current_value = exc._error_cache[0]["ctx"]["given"] # type: ignore + + # Traverse the config object to find the correct value + for key in keys: + default_value = getattr(default_value, key) + + websocket_manager.broadcast_sync( + data=Data( + data={ + "default_value": default_value, + "key": keys, + "current_value": current_value, + }, + data_type="incorrect_settings_value", + ) + ) + try: - why = str(exc).split(":")[1].strip() - await websocket_manager.broadcast( + websocket_manager.broadcast_sync( data=Notification( severity="error", - message=f"Validation error: {why}", + message="Validation error", title="Validation Error", ) ) @@ -130,7 +153,9 @@ async def startup_event(): for model in config.api.autoloaded_models: if model in [i.path for i in all_models]: backend: InferenceBackend = [i.backend for i in all_models if i.path == model][0] # type: ignore - gpu.load_model(model, backend) + model_type = determine_model_type(get_full_model_path(model))[1] + + gpu.load_model(model, backend, type=model_type) else: logger.warning(f"Autoloaded model {model} not found, skipping") diff --git a/api/routes/models.py b/api/routes/models.py index 9986d109e..2aacff959 100644 --- a/api/routes/models.py +++ b/api/routes/models.py @@ -20,10 +20,11 @@ DeleteModelRequest, InferenceBackend, ModelResponse, + PyTorchModelBase, TextualInversionLoadRequest, VaeLoadRequest, ) -from core.utils import download_file +from core.utils import determine_model_type, download_file router = APIRouter(tags=["models"]) logger = logging.getLogger(__name__) @@ -62,9 +63,11 @@ def list_loaded_models() -> List[ModelResponse]: loaded_models = [] for model_id in gpu.loaded_models: + name, type_, stage = determine_model_type(get_full_model_path(model_id)) + loaded_models.append( ModelResponse( - name=Path(model_id).name + name=name if (".ckpt" in model_id) or (".safetensors" in model_id) else model_id, backend=gpu.loaded_models[model_id].backend, @@ -75,6 +78,8 @@ def list_loaded_models() -> List[ModelResponse]: "textual_inversions", [] ), valid=True, + stage=stage, + type=type_, ) ) @@ -92,11 +97,12 @@ def list_available_models() -> List[ModelResponse]: def load_model( model: str, backend: InferenceBackend, + type: PyTorchModelBase, ): "Loads a model into memory" try: - gpu.load_model(model, backend) + gpu.load_model(model, backend, type) websocket_manager.broadcast_sync(data=Data(data_type="refresh_models", data={})) except torch.cuda.OutOfMemoryError: # type: ignore @@ -106,7 +112,7 @@ def load_model( @router.post("/unload") -async def unload_model(model: str): +def unload_model(model: str): "Unloads a model from memory" gpu.unload(model) @@ -125,7 +131,7 @@ def unload_all_models(): @router.post("/load-vae") -async def load_vae(req: VaeLoadRequest): +def load_vae(req: VaeLoadRequest): "Load a VAE into a model" gpu.load_vae(req) @@ -134,7 +140,7 @@ async def load_vae(req: VaeLoadRequest): @router.post("/load-textual-inversion") -async def load_textual_inversion(req: TextualInversionLoadRequest): +def load_textual_inversion(req: TextualInversionLoadRequest): "Load a LoRA model into a model" gpu.load_textual_inversion(req) @@ -143,7 +149,7 @@ async def load_textual_inversion(req: TextualInversionLoadRequest): @router.post("/memory-cleanup") -async def cleanup(): +def cleanup(): "Free up memory 
manually" gpu.memory_cleanup() @@ -151,7 +157,7 @@ async def cleanup(): @router.post("/download") -async def download_model(model: str): +def download_model(model: str): "Download a model to the cache" gpu.download_huggingface_model(model) @@ -243,7 +249,7 @@ def delete_model(req: DeleteModelRequest): @router.post("/download-model") def download_checkpoint( - link: str, model_type: Literal["Checkpoint", "TextualInversion", "LORA"] + link: str, model_type: Literal["Checkpoint", "TextualInversion", "LORA", "VAE"] ) -> str: "Download a model from a link and return the path to the downloaded file." @@ -254,7 +260,12 @@ def download_checkpoint( folder = "textual-inversion" elif mtype == "lora": folder = "lora" + elif mtype == "vae": + folder = "vae" else: raise ValueError(f"Unknown model type {mtype}") - return download_file(link, Path("data") / folder, True).as_posix() + saved_path = download_file(link, Path("data") / folder, True).as_posix() + websocket_manager.broadcast_sync(Data(data_type="refresh_models", data={})) + + return saved_path diff --git a/api/routes/outputs.py b/api/routes/outputs.py index d6b391397..ea9f596c9 100644 --- a/api/routes/outputs.py +++ b/api/routes/outputs.py @@ -12,7 +12,7 @@ thread_pool = ThreadPoolExecutor() logger = logging.getLogger(__name__) -valid_extensions = ["png", "jpeg", "webp"] +valid_extensions = ["png", "jpeg", "webp", "gif"] def sort_images(images: List[Dict[str, Any]]) -> List[Dict[str, Any]]: diff --git a/api/routes/settings.py b/api/routes/settings.py index 32ac65b69..fe4225d63 100644 --- a/api/routes/settings.py +++ b/api/routes/settings.py @@ -4,7 +4,7 @@ from fastapi import APIRouter from core import config -from core.config.config import update_config +from core.config._config import update_config router = APIRouter(tags=["settings"]) diff --git a/core/config/__init__.py b/core/config/__init__.py index 2d52a925a..cc2962f88 100644 --- a/core/config/__init__.py +++ b/core/config/__init__.py @@ -1,14 +1,16 @@ from pathlib import Path -from diffusers.utils.constants import DIFFUSERS_CACHE +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE as DIFFUSERS_CACHE -from .config import ( +from ._config import ( Configuration, - Img2ImgConfig, - Txt2ImgConfig, load_config, save_config, ) +from .default_settings import ( + Txt2ImgConfig, + Img2ImgConfig, +) config = load_config() diff --git a/core/config/_config.py b/core/config/_config.py new file mode 100644 index 000000000..f45d5ae45 --- /dev/null +++ b/core/config/_config.py @@ -0,0 +1,80 @@ +import logging +from dataclasses import Field, dataclass, field, fields + +from dataclasses_json import CatchAll, DataClassJsonMixin, Undefined, dataclass_json + +from core.config.samplers.sampler_config import SamplerConfig + +from .api_settings import APIConfig +from .bot_settings import BotConfig +from .default_settings import ( + AITemplateConfig, + ControlNetConfig, + Img2ImgConfig, + InpaintingConfig, + ONNXConfig, + Txt2ImgConfig, + UpscaleConfig, +) +from .flags_settings import FlagsConfig +from .frontend_settings import FrontendConfig +from .interrogator_settings import InterrogatorConfig + +logger = logging.getLogger(__name__) + + +@dataclass_json(undefined=Undefined.INCLUDE) +@dataclass +class Configuration(DataClassJsonMixin): + "Main configuration class for the application" + + txt2img: Txt2ImgConfig = field(default_factory=Txt2ImgConfig) + img2img: Img2ImgConfig = field(default_factory=Img2ImgConfig) + inpainting: InpaintingConfig = field(default_factory=InpaintingConfig) + controlnet: 
ControlNetConfig = field(default_factory=ControlNetConfig) + upscale: UpscaleConfig = field(default_factory=UpscaleConfig) + api: APIConfig = field(default_factory=APIConfig) + interrogator: InterrogatorConfig = field(default_factory=InterrogatorConfig) + aitemplate: AITemplateConfig = field(default_factory=AITemplateConfig) + onnx: ONNXConfig = field(default_factory=ONNXConfig) + bot: BotConfig = field(default_factory=BotConfig) + frontend: FrontendConfig = field(default_factory=FrontendConfig) + sampler_config: SamplerConfig = field(default_factory=SamplerConfig) + flags: FlagsConfig = field(default_factory=FlagsConfig) + extra: CatchAll = field(default_factory=dict) + + +def save_config(config: Configuration): + "Save the configuration to a file" + + logger.info("Saving configuration to data/settings.json") + + with open("data/settings.json", "w", encoding="utf-8") as f: + f.write(config.to_json(ensure_ascii=False, indent=4)) + + +def update_config(config: Configuration, new_config: Configuration): + "Update the configuration with new values instead of overwriting the pointer" + + for cls_field in fields(new_config): + assert isinstance(cls_field, Field) + setattr(config, cls_field.name, getattr(new_config, cls_field.name)) + + +def load_config(): + "Load the configuration from a file" + + logger.info("Loading configuration from data/settings.json") + + try: + with open("data/settings.json", "r", encoding="utf-8") as f: + config = Configuration.from_json(f.read()) + logger.info("Configuration loaded from data/settings.json") + return config + + except FileNotFoundError: + logger.info("data/settings.json not found, creating a new one") + config = Configuration() + save_config(config) + logger.info("Configuration saved to data/settings.json") + return config diff --git a/core/config/api_settings.py b/core/config/api_settings.py new file mode 100644 index 000000000..f1389dfc5 --- /dev/null +++ b/core/config/api_settings.py @@ -0,0 +1,248 @@ +from dataclasses import dataclass, field +from typing import Dict, List, Literal, Union + +import torch + +from core.flags import LatentScaleModel + + +@dataclass +class APIConfig: + "Configuration for the API" + + # Autoload + autoloaded_textual_inversions: List[str] = field(default_factory=list) + autoloaded_models: List[str] = field(default_factory=list) + autoloaded_vae: Dict[str, str] = field(default_factory=dict) + + # Websockets and intervals + websocket_sync_interval: float = 0.02 + websocket_perf_interval: float = 1.0 + enable_websocket_logging: bool = True + + # TomeSD + use_tomesd: bool = False # really extreme, probably will have to wait around until tome improves a bit + tomesd_ratio: float = 0.25 # had to tone this down, 0.4 is too big of a context loss even on short prompts + tomesd_downsample_layers: Literal[1, 2, 4, 8] = 1 + + # General optimizations + autocast: bool = False + attention_processor: Literal[ + "xformers", "sdpa", "cross-attention", "subquadratic", "multihead" + ] = "sdpa" + fuse_attention: bool = True + subquadratic_size: int = 512 + attention_slicing: Union[int, Literal["auto", "disabled"]] = "auto" + channels_last: bool = True + trace_model: bool = False + clear_memory_policy: Literal["always", "after_disconnect", "never"] = "always" + offload: Literal["disabled", "model", "module"] = "disabled" + data_type: Literal[ + "float32", "float16", "bfloat16", "float8_e4m3fn", "float8_e5m2" + ] = "float16" + use_minimal_sdxl_pipeline: bool = False # slower, but works better + + # VRAM optimizations + # whether to run both 
parts of CFG>1 generations in one call. Increases VRAM usage during inference, + # halves inference speed for most -- newer than 10xx -- cards. + batch_cond_uncond: bool = True + # Whether to cache the weight of tensors for LoRA loading during float8 inference. + # Improves how LoRAs work when data-type is a subset of FP8, but increases system RAM usage to 2x. + # ONLY WORKS WHEN DATA_TYPE=FLOAT8 + cache_fp16_weight: bool = False # only works on float8. Used for LoRAs. + + # Approximation-type optimizations ("ruin" quality for big boosts in performance) + # According to the following paper: https://arxiv.org/pdf/2312.09608.pdf (Faster-Diffusion) + # ControlNet infers can be "effectively skipped" after a certain point, because their control + # on the images loosens up after a while. + # + # value: from which "percentage" of the diffusion process should we skip controlnet inferring. + # default: 1.0, don't skip at all. + # IMPORTANT: the first min(5, n-3) steps WILL always have controlnet inferring on to avoid completely + # disabling controlnet + approximate_controlnet: float = 1.0 + + # According to the following paper: https://browse.arxiv.org/html/2312.12487v1 (Adaptive Guidance) + # even stopping it naively (without implementing AG, which I (Gabe) might implement later) + # doesn't murder image quality. I'd compare it to how TensorRT "mutates" images. + # Probably shouldn't be set to anything below 0.75 + # + # value: from which "percentage" of the diffusion process should we stop inferring uncond. + # default: 1.0, don't stop inferring at all. + cfg_uncond_tau: float = 1.0 + + # Won't implement timestep pararellization since it has a tendency to "explode" VRAM usage, and I + # don't know if I'm ready for issues that could bring down the line... + # source: https://arxiv.org/pdf/2312.09608.pdf (Faster-Diffusion) + + # According to the following paper: https://arxiv.org/pdf/2312.09608.pdf (Faster-Diffusion) + # Dropping encode/decode -- effectively caching it -- and creating new values every nth step + # produces negligible quality loss whilst boosting performance by 1/drop_encode_decode%. + # + # value: "off" disables this. "on" infers the first 5 steps ALWAYS as full-quality ones, and then every 5th. + # default: "off," due to quality loss, same reason as to why HyperTile is off by default. + drop_encode_decode: Union[ + int, # if int, drop every x-th step excluding the first 5 + Literal["off", "on"], # on = first 5 + every 5th + ] = "off" # "on" results in a ~30% increase in performance, without ruining quality too much -- highly depends on seed + + deepcache_cache_interval: int = 1 # cache every x-th layer, 1 = disabled + + # CUDA specific optimizations + reduced_precision: bool = False + cudnn_benchmark: bool = False + deterministic_generation: bool = False + + # Device settings + device: str = "cuda:0" + # Where to load the models onto first. By default, diffusers has a kind of stupid way of + # first loading models to cpu and then onto the device. + load_location: Union[str, Literal["on-device", "cpu"]] = "on-device" + # Load models using streaming instead of mmaping. Mmaping is faster 99% of the time, + # however WSL2s drvfs fucks these things up and just streaming the loading is usually faster. + # Only case where users may run into these issues is when loading from the host NTFS drive from WSL2, or when having some cloud + # service set as their drive. + # Evem then, highly recommend leaving this on False. 
+ stream_load: bool = False + + # Critical + enable_shutdown: bool = True + + # CLIP + clip_skip: int = 1 + clip_quantization: Literal["full", "int8", "int4"] = "full" + + huggingface_style_parsing: bool = False + + # Saving + save_path_template: str = "{folder}/{prompt}/{id}-{index}.{extension}" + image_extension: Literal["png", "webp", "jpeg"] = "png" + image_quality: int = 95 + image_return_format: Literal["bytes", "base64"] = "base64" + + # Grid + disable_grid: bool = False + + # Torch compile + torch_compile: bool = False + torch_compile_fullgraph: bool = False + torch_compile_dynamic: bool = False + torch_compile_backend: str = "inductor" + torch_compile_mode: Literal[ + "default", + "reduce-overhead", + "max-autotune", + ] = "reduce-overhead" + + sfast_compile: bool = False + sfast_xformers: bool = True + sfast_triton: bool = True + sfast_cuda_graph: bool = True + + # Hypertile + hypertile: bool = False + hypertile_unet_chunk: int = 512 + + # Kohya Deep-Shrink + deepshrink_enabled: bool = True + deepshrink_depth_1: int = 3 # -1 to 12; steps of 1 + deepshrink_stop_at_1: float = 0.15 # 0 to 0.5; steps of 0.01 + + deepshrink_depth_2: int = 4 # -1 to 12; steps of 1 + deepshrink_stop_at_2: float = 0.30 # 0 to 0.5; steps of 0.01 + + deepshrink_scaler: LatentScaleModel = "bislerp" + deepshrink_base_scale: float = 0.5 # 0.05 to 1.0; steps of 0.05 + deepshrink_early_out: bool = True + + # K_Diffusion + sgm_noise_multiplier: bool = False # also known as "alternate DDIM ODE" + kdiffusers_quantization: bool = True # improves sampling quality + + # K_Diffusion & Diffusers + # What to do with refiner: + # - "joint:" instead of creating a new sampler, it uses the refiner inside of the main loop, + # replacing the unet with the refiners unet after a certain number of steps have + # been processed. This improves consistency and generation quality. + # - "separate:" creates a new pipeline for refiner and does the refining there on the final + # latents of the image. This can introduce some artifacts/lose context. + sdxl_refiner: Literal["joint", "separate"] = "separate" + + # "philox" is what a "cuda" generator would be, except, it's on cpu + generator: Literal["device", "cpu", "philox"] = "device" + + # VAE + live_preview_method: Literal[ + "disabled", + "approximation", + "taesd", + "full", # TODO: isn't supported yet. + ] = "approximation" + live_preview_delay: float = 2.0 + vae_slicing: bool = True + vae_tiling: bool = True + upcast_vae: bool = False # Fixes issues on 10xx-series and RX cards + # Somewhat fixes extraordinarily high CFG values. Does also change output composition, so + # best to leave on off by default. TODO: write docs for this? + apply_unsharp_mask: bool = False + # Rescales CFG to a known good value when CFG is higher than this number. Set to "off" to disable. 
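+    # For example, with the default threshold of 10.0 a request sent with cfg_scale=15
+    # would be rescaled, while cfg_scale=7 passes through unchanged.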
+ cfg_rescale_threshold: Union[float, Literal["off"]] = 10.0 + + # Prompt expansion (very, and I mean VERYYYY heavily inspired/copied from lllyasviel/Fooocus) + prompt_to_prompt: bool = False + prompt_to_prompt_model: Literal[ + "lllyasviel/Fooocus-Expansion", + "daspartho/prompt-extend", + "succinctly/text2image-prompt-generator", + "Gustavosta/MagicPrompt-Stable-Diffusion", + "Ar4ikov/gpt2-medium-650k-stable-diffusion-prompt-generator", + ] = "lllyasviel/Fooocus-Expansion" + prompt_to_prompt_device: Literal["cpu", "gpu"] = "gpu" + + # Free U + free_u: bool = False + free_u_s1: float = 0.9 + free_u_s2: float = 0.2 + free_u_b1: float = 1.2 + free_u_b2: float = 1.4 + + @property + def dtype(self) -> torch.dtype: + "Return selected data type" + return getattr(torch, self.data_type) + + @property + def load_device(self) -> torch.device: + "Device to use for loading models onto." + return ( + torch.device(self.device) + if self.load_location == "on-device" + else torch.device(self.load_location) + ) + + @property + def load_dtype(self) -> torch.dtype: + "Data type for loading models." + dtype = self.dtype + if "float8" in self.data_type: + from core.shared_dependent import gpu + + if self.device == "cpu": + if "bfloat16" in gpu.capabilities.supported_precisions_cpu: + dtype = torch.bfloat16 + else: + dtype = torch.float32 + else: + if "float16" in gpu.capabilities.supported_precisions_gpu: + dtype = torch.float16 + else: + dtype = torch.float32 + return dtype + + @property + def overwrite_generator(self) -> bool: + "Whether the generator needs to be overwritten with 'cpu.'" + + return any( + map(lambda x: x in self.device, ["mps", "directml", "vulkan", "intel"]) + ) diff --git a/core/config/bot_settings.py b/core/config/bot_settings.py new file mode 100644 index 000000000..4f21406ba --- /dev/null +++ b/core/config/bot_settings.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass + +from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers + + +@dataclass +class BotConfig: + "Configuration for the bot" + + default_scheduler: KarrasDiffusionSchedulers = ( + KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler + ) + verbose: bool = False + use_default_negative_prompt: bool = True diff --git a/core/config/config.py b/core/config/config.py deleted file mode 100644 index 6bafd093d..000000000 --- a/core/config/config.py +++ /dev/null @@ -1,335 +0,0 @@ -import logging -import multiprocessing -from dataclasses import Field, dataclass, field, fields -from typing import Dict, List, Literal, Optional, Union - -import torch -from dataclasses_json import CatchAll, DataClassJsonMixin, Undefined, dataclass_json -from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers - -from core.config.samplers.sampler_config import SamplerConfig -from core.flags import HighResFixFlag -from core.types import SigmaScheduler - -logger = logging.getLogger(__name__) - - -@dataclass -class BaseDiffusionMixin: - width: int = 512 - height: int = 512 - batch_count: int = 1 - batch_size: int = 1 - seed: int = -1 - cfg_scale: int = 7 - steps: int = 40 - prompt: str = "" - negative_prompt: str = "" - sampler: Union[ - int, str - ] = KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler.value - sigmas: SigmaScheduler = "automatic" - - -@dataclass -class QuantDict: - vae_decoder: Optional[bool] = None - vae_encoder: Optional[bool] = None - unet: Optional[bool] = None - text_encoder: Optional[bool] = None - - -@dataclass -class Txt2ImgConfig(BaseDiffusionMixin): - "Configuration for the text to 
image pipeline" - - self_attention_scale: float = 0.0 - - -@dataclass -class Img2ImgConfig(BaseDiffusionMixin): - "Configuration for the image to image pipeline" - - resize_method: int = 0 - denoising_strength: float = 0.6 - self_attention_scale: float = 0.0 - - -@dataclass -class InpaintingConfig(BaseDiffusionMixin): - "Configuration for the inpainting pipeline" - - self_attention_scale: float = 0.0 - - -@dataclass -class ControlNetConfig(BaseDiffusionMixin): - "Configuration for the inpainting pipeline" - - controlnet: str = "lllyasviel/sd-controlnet-canny" - controlnet_conditioning_scale: float = 1.0 - detection_resolution: int = 512 - is_preprocessed: bool = False - save_preprocessed: bool = False - return_preprocessed: bool = True - - -@dataclass -class UpscaleConfig: - "Configuration for the RealESRGAN upscaler" - - model: str = "RealESRGAN_x4plus_anime_6B" - upscale_factor: int = 4 - tile_size: int = field(default=128) - tile_padding: int = field(default=10) - - -@dataclass -class APIConfig: - "Configuration for the API" - - # Autoload - autoloaded_textual_inversions: List[str] = field(default_factory=list) - autoloaded_models: List[str] = field(default_factory=list) - autoloaded_vae: Dict[str, str] = field(default_factory=dict) - - # Websockets and intervals - websocket_sync_interval: float = 0.02 - websocket_perf_interval: float = 1.0 - enable_websocket_logging: bool = True - - # TomeSD - use_tomesd: bool = False # really extreme, probably will have to wait around until tome improves a bit - tomesd_ratio: float = 0.25 # had to tone this down, 0.4 is too big of a context loss even on short prompts - tomesd_downsample_layers: Literal[1, 2, 4, 8] = 1 - - # General optimizations - autocast: bool = False - attention_processor: Literal[ - "xformers", "sdpa", "cross-attention", "subquadratic", "multihead" - ] = "sdpa" - subquadratic_size: int = 512 - attention_slicing: Union[int, Literal["auto", "disabled"]] = "disabled" - channels_last: bool = True - trace_model: bool = False - clear_memory_policy: Literal["always", "after_disconnect", "never"] = "always" - offload: Literal["module", "model", "disabled"] = "disabled" - data_type: Literal["float32", "float16", "bfloat16"] = "float16" - dont_merge_latents: bool = ( - False # Will drop performance, but could help with some VRAM issues - ) - - # CUDA specific optimizations - reduced_precision: bool = False - cudnn_benchmark: bool = False - deterministic_generation: bool = False - - # Device settings - device: str = "cuda:0" - - # Critical - enable_shutdown: bool = True - - # CLIP - clip_skip: int = 1 - clip_quantization: Literal["full", "int8", "int4"] = "full" - - huggingface_style_parsing: bool = False - - # Saving - save_path_template: str = "{folder}/{prompt}/{id}-{index}.{extension}" - image_extension: Literal["png", "webp", "jpeg"] = "png" - image_quality: int = 95 - image_return_format: Literal["bytes", "base64"] = "base64" - - # Grid - disable_grid: bool = False - - # Torch compile - torch_compile: bool = False - torch_compile_fullgraph: bool = False - torch_compile_dynamic: bool = False - torch_compile_backend: str = "inductor" - torch_compile_mode: Literal[ - "default", - "reduce-overhead", - "max-autotune", - ] = "reduce-overhead" - - sfast_compile: bool = False - sfast_xformers: bool = True - sfast_triton: bool = True - sfast_cuda_graph: bool = True - - # Hypertile - hypertile: bool = False - hypertile_unet_chunk: int = 256 - - # K_Diffusion - sgm_noise_multiplier: bool = False # also known as "alternate DDIM ODE" - 
kdiffusers_quantization: bool = True # improves sampling quality - - # "philox" is what a "cuda" generator would be, except, it's on cpu - generator: Literal["device", "cpu", "philox"] = "device" - - # VAE - live_preview_method: Literal["disabled", "approximation", "taesd"] = "approximation" - live_preview_delay: float = 2.0 - vae_slicing: bool = True - vae_tiling: bool = False - - # Prompt expansion (very, and I mean VERYYYY heavily inspired/copied from lllyasviel/Fooocus) - prompt_to_prompt: bool = False - prompt_to_prompt_model: Literal[ - "lllyasviel/Fooocus-Expansion", - "daspartho/prompt-extend", - "succinctly/text2image-prompt-generator", - "Gustavosta/MagicPrompt-Stable-Diffusion", - "Ar4ikov/gpt2-medium-650k-stable-diffusion-prompt-generator", - ] = "lllyasviel/Fooocus-Expansion" - prompt_to_prompt_device: Literal["cpu", "gpu"] = "gpu" - - # Free U - free_u: bool = False - free_u_s1: float = 0.9 - free_u_s2: float = 0.2 - free_u_b1: float = 1.2 - free_u_b2: float = 1.4 - - @property - def dtype(self): - "Return selected data type" - if self.data_type == "bfloat16": - return torch.bfloat16 - if self.data_type == "float16": - return torch.float16 - return torch.float32 - - @property - def overwrite_generator(self) -> bool: - "Whether the generator needs to be overwritten with 'cpu.'" - - return any( - map(lambda x: x in self.device, ["mps", "directml", "vulkan", "intel"]) - ) - - -@dataclass -class AITemplateConfig: - "Configuration for model inference and acceleration" - - num_threads: int = field(default=min(multiprocessing.cpu_count() - 1, 8)) - - -@dataclass -class ONNXConfig: - "Configuration for ONNX acceleration" - - quant_dict: QuantDict = field(default_factory=QuantDict) - - -@dataclass -class BotConfig: - "Configuration for the bot" - - default_scheduler: KarrasDiffusionSchedulers = ( - KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler - ) - verbose: bool = False - use_default_negative_prompt: bool = True - - -@dataclass -class InterrogatorConfig: - "Configuration for interrogation models" - - # set to "Salesforce/blip-image-captioning-base" for an extra gig of vram - caption_model: str = "Salesforce/blip-image-captioning-large" - visualizer_model: str = "ViT-L-14/openai" - - offload_captioner: bool = False - offload_visualizer: bool = False - - chunk_size: int = 2048 # set to 1024 for lower vram usage - flavor_intermediate_count: int = 2048 # set to 1024 for lower vram usage - - flamingo_model: str = "dhansmair/flamingo-mini" - - caption_max_length: int = 32 - - -@dataclass -class FrontendConfig: - "Configuration for the frontend" - - theme: str = "dark" - background_image_override: str = "" - enable_theme_editor: bool = False - image_browser_columns: int = 5 - on_change_timer: int = 0 - nsfw_ok_threshold: int = 0 - disable_analytics: bool = False - - -@dataclass -class FlagsConfig: - "Configuration for flags" - - highres: HighResFixFlag = field(default_factory=HighResFixFlag) - - -@dataclass_json(undefined=Undefined.INCLUDE) -@dataclass -class Configuration(DataClassJsonMixin): - "Main configuration class for the application" - - txt2img: Txt2ImgConfig = field(default_factory=Txt2ImgConfig) - img2img: Img2ImgConfig = field(default_factory=Img2ImgConfig) - inpainting: InpaintingConfig = field(default_factory=InpaintingConfig) - controlnet: ControlNetConfig = field(default_factory=ControlNetConfig) - upscale: UpscaleConfig = field(default_factory=UpscaleConfig) - api: APIConfig = field(default_factory=APIConfig) - interrogator: InterrogatorConfig = 
field(default_factory=InterrogatorConfig) - aitemplate: AITemplateConfig = field(default_factory=AITemplateConfig) - onnx: ONNXConfig = field(default_factory=ONNXConfig) - bot: BotConfig = field(default_factory=BotConfig) - frontend: FrontendConfig = field(default_factory=FrontendConfig) - flags: FlagsConfig = field(default_factory=FlagsConfig) - sampler_config: SamplerConfig = field(default_factory=SamplerConfig) - extra: CatchAll = field(default_factory=dict) - - -def save_config(config: Configuration): - "Save the configuration to a file" - - logger.info("Saving configuration to data/settings.json") - - with open("data/settings.json", "w", encoding="utf-8") as f: - f.write(config.to_json(ensure_ascii=False, indent=4)) - - -def update_config(config: Configuration, new_config: Configuration): - "Update the configuration with new values instead of overwriting the pointer" - - for cls_field in fields(new_config): - assert isinstance(cls_field, Field) - setattr(config, cls_field.name, getattr(new_config, cls_field.name)) - - -def load_config(): - "Load the configuration from a file" - - logger.info("Loading configuration from data/settings.json") - - try: - with open("data/settings.json", "r", encoding="utf-8") as f: - config = Configuration.from_json(f.read()) - logger.info("Configuration loaded from data/settings.json") - return config - - except FileNotFoundError: - logger.info("data/settings.json not found, creating a new one") - config = Configuration() - save_config(config) - logger.info("Configuration saved to data/settings.json") - return config diff --git a/core/config/default_settings.py b/core/config/default_settings.py new file mode 100644 index 000000000..e177c1cfc --- /dev/null +++ b/core/config/default_settings.py @@ -0,0 +1,108 @@ +import multiprocessing +from dataclasses import dataclass, field +from typing import Optional, Union + +from core.flags import ( + ADetailerFlag, + DeepshrinkFlag, + HighResFixFlag, + ScalecrafterFlag, + UpscaleFlag, +) +from core.types import SigmaScheduler + + +@dataclass +class QuantDict: + "Configuration for ONNX quantization" + + vae_decoder: Optional[bool] = None + vae_encoder: Optional[bool] = None + unet: Optional[bool] = None + text_encoder: Optional[bool] = None + + +@dataclass +class BaseDiffusionMixin: + width: int = 512 + height: int = 512 + seed: int = -1 + cfg_scale: int = 7 + steps: int = 25 + prompt: str = "" + negative_prompt: str = ( + "(worst quality, low quality:1.4), monochrome, (interlocked fingers:1.2)" + ) + sampler: Union[int, str] = "dpmpp_2m" + sigmas: SigmaScheduler = "exponential" + batch_count: int = 1 + batch_size: int = 1 + + # Flags + highres: HighResFixFlag = field(default_factory=HighResFixFlag) + upscale: UpscaleFlag = field(default_factory=UpscaleFlag) + deepshrink: DeepshrinkFlag = field(default_factory=DeepshrinkFlag) + scalecrafter: ScalecrafterFlag = field(default_factory=ScalecrafterFlag) + adetailer: ADetailerFlag = field(default_factory=ADetailerFlag) + + +@dataclass +class Txt2ImgConfig(BaseDiffusionMixin): + "Configuration for the text to image pipeline" + + self_attention_scale: float = 0.0 + + +@dataclass +class Img2ImgConfig(BaseDiffusionMixin): + "Configuration for the image to image pipeline" + + resize_method: int = 0 + denoising_strength: float = 0.6 + self_attention_scale: float = 0.0 + + +@dataclass +class InpaintingConfig(BaseDiffusionMixin): + "Configuration for the inpainting pipeline" + + self_attention_scale: float = 0.0 + strength: float = 0.6 + + +@dataclass +class 
ControlNetConfig(BaseDiffusionMixin): + "Configuration for the inpainting pipeline" + + self_attention_scale: float = 0.0 + + controlnet: str = "lllyasviel/sd-controlnet-canny" + controlnet_conditioning_scale: float = 1.0 + detection_resolution: int = 512 + is_preprocessed: bool = False + save_preprocessed: bool = False + return_preprocessed: bool = True + + +@dataclass +class UpscaleConfig: + "Configuration for the RealESRGAN upscaler" + + model: str = "RealESRGAN_x4plus_anime_6B" + upscale_factor: int = 4 + tile_size: int = field(default=128) + tile_padding: int = field(default=10) + + +@dataclass +class AITemplateConfig: + "Configuration for model inference and acceleration" + + num_threads: int = field(default=min(multiprocessing.cpu_count() - 1, 8)) + + +@dataclass +class ONNXConfig: + "Configuration for ONNX acceleration" + + quant_dict: QuantDict = field(default_factory=QuantDict) diff --git a/core/config/flags_settings.py b/core/config/flags_settings.py new file mode 100644 index 000000000..b921fba7b --- /dev/null +++ b/core/config/flags_settings.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass, field + +from core.flags import SDXLFlag, SDXLRefinerFlag + + +@dataclass +class FlagsConfig: + "Configuration for flags" + + refiner: SDXLRefinerFlag = field(default_factory=SDXLRefinerFlag) + sdxl: SDXLFlag = field(default_factory=SDXLFlag) diff --git a/core/config/frontend_settings.py b/core/config/frontend_settings.py new file mode 100644 index 000000000..801c7b814 --- /dev/null +++ b/core/config/frontend_settings.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass + + +@dataclass +class FrontendConfig: + "Configuration for the frontend" + + theme: str = "dark" + background_image_override: str = "" + enable_theme_editor: bool = False + image_browser_columns: int = 5 + on_change_timer: int = 0 + nsfw_ok_threshold: int = 0 + disable_analytics: bool = False diff --git a/core/config/interrogator_settings.py b/core/config/interrogator_settings.py new file mode 100644 index 000000000..bfadc0c10 --- /dev/null +++ b/core/config/interrogator_settings.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + + +@dataclass +class InterrogatorConfig: + "Configuration for interrogation models" + + # set to "Salesforce/blip-image-captioning-base" for an extra gig of vram + caption_model: str = "Salesforce/blip-image-captioning-large" + visualizer_model: str = "ViT-L-14/openai" + + offload_captioner: bool = False + offload_visualizer: bool = False + + chunk_size: int = 2048 # set to 1024 for lower vram usage + flavor_intermediate_count: int = 2048 # set to 1024 for lower vram usage + + flamingo_model: str = "dhansmair/flamingo-mini" + + caption_max_length: int = 32 diff --git a/core/config/samplers/kdiffusion_sampler_config.py b/core/config/samplers/kdiffusion_sampler_config.py index 3ef57153c..7858bcfd4 100644 --- a/core/config/samplers/kdiffusion_sampler_config.py +++ b/core/config/samplers/kdiffusion_sampler_config.py @@ -34,6 +34,12 @@ class Heun(BaseMixin): s_noise: Optional[float] = None +@dataclass +class Heunpp(BaseMixin): + s_churn: Optional[float] = None + s_noise: Optional[float] = None + + @dataclass class DPM_2(BaseMixin): s_churn: Optional[float] = None diff --git a/core/config/samplers/sampler_config.py b/core/config/samplers/sampler_config.py index a9355c7cd..4f6d29dbf 100644 --- a/core/config/samplers/sampler_config.py +++ b/core/config/samplers/sampler_config.py @@ -15,6 +15,7 @@ Euler, Euler_a, Heun, + Heunpp, ) @@ -130,6 +131,7 @@ class SamplerConfig: euler: Euler = 
field(default_factory=Euler) lms: LMS = field(default_factory=LMS) heun: Heun = field(default_factory=Heun) + heunpp: Heunpp = field(default_factory=Heunpp) dpm_fast: DPM_fast = field(default_factory=DPM_fast) dpm_adaptive: DPM_adaptive = field(default_factory=DPM_adaptive) dpm2: DPM_2 = field(default_factory=DPM_2) diff --git a/core/files.py b/core/files.py index 52f47c845..2d0e7b1bc 100644 --- a/core/files.py +++ b/core/files.py @@ -1,18 +1,19 @@ import logging import os from pathlib import Path -from typing import List, Optional, Union +from typing import List, Union -from diffusers.utils.constants import DIFFUSERS_CACHE +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE as DIFFUSERS_CACHE from huggingface_hub.file_download import repo_folder_name from core.types import ModelResponse +from core.utils import determine_model_type logger = logging.getLogger(__name__) class CachedModelList: - "List of models downloaded for PyTorch and (or) converted to TRT" + "List of models that user has downloaded" def __init__(self): self.paths = { @@ -53,23 +54,28 @@ def pytorch(self) -> List[ModelResponse]: # Skip if it is not a huggingface model if "model" not in model_name: continue + parsed_model_name: str = "/".join(model_name.split("--")[1:3]) - name: str = "/".join(model_name.split("--")[1:3]) try: - models.append( - ModelResponse( - name=name, - path=name, - backend="PyTorch", - vae="default", - valid=is_valid_diffusers_model(get_full_model_path(name)), - state="not loaded", - ) - ) - except ValueError: - logger.debug(f"Invalid model {name}, skipping...") + full_path = get_full_model_path(parsed_model_name) + except ValueError as e: + logger.debug(f"Model {parsed_model_name} is not valid: {e}") continue + _name, base, stage = determine_model_type(full_path) + models.append( + ModelResponse( + name=parsed_model_name, + path=parsed_model_name, + backend="PyTorch", + type=base, + stage=stage, + vae="default", + valid=is_valid_diffusers_model(full_path), + state="not loaded", + ) + ) + # Localy stored models logger.debug(f"Looking for local models in '{self.paths['checkpoints']}'") for model_path in self.paths["checkpoints"].rglob("*"): @@ -79,10 +85,12 @@ def pytorch(self) -> List[ModelResponse]: if not model_path.joinpath("model_index.json").exists(): continue + name, base, stage = determine_model_type(model_path) + # Assuming that model is in Diffusers format models.append( ModelResponse( - name=model_path.name, + name=name, path=model_path.relative_to( self.paths["checkpoints"] ).as_posix(), @@ -90,6 +98,8 @@ def pytorch(self) -> List[ModelResponse]: vae="default", valid=is_valid_diffusers_model(model_path), state="not loaded", + type=base, + stage=stage, ) ) elif ( @@ -98,29 +108,34 @@ def pytorch(self) -> List[ModelResponse]: model_path.parent.joinpath("model_index.json").exists() or model_path.parent.parent.joinpath("model_index.json").exists() ): + if ".ckpt" == model_path.suffix: + name, base, stage = model_path.name, "SD1.x", "first_stage" + else: + name, base, stage = determine_model_type(model_path) + # Assuming that model is in Checkpoint / Safetensors format models.append( ModelResponse( - name=model_path.name, + name=name, path=model_path.relative_to( self.paths["checkpoints"] ).as_posix(), backend="PyTorch", vae="default", valid=True, + type=base, + stage=stage, state="not loaded", ) ) else: # Junk file, notify user - logger.debug( - f"Found junk file {model_path} in {self.paths['checkpoints']}, skipping..." 
- ) + logger.debug(f"Found junk file {model_path}, skipping...") return models def aitemplate(self) -> List[ModelResponse]: - "List of models converted to TRT" + "List of models converted to AITempalte" models: List[ModelResponse] = [] @@ -423,17 +438,18 @@ def diffusers_storage_name(repo_id: str, repo_type: str = "model") -> str: ) -def current_diffusers_ref(path: str, revision: str = "main") -> Optional[str]: +def current_diffusers_ref(path: str, revision: str = "main") -> str: "Return the current ref of the diffusers model" rev_path = os.path.join(path, "refs", revision) snapshot_path = os.path.join(path, "snapshots") if not os.path.exists(rev_path) or not os.path.exists(snapshot_path): - return None + raise ValueError( + f"Ref path {rev_path} or snapshot path {snapshot_path} not found" + ) snapshots = os.listdir(snapshot_path) - ref = "" with open(os.path.join(path, "refs", revision), "r", encoding="utf-8") as f: ref = f.read().strip().split(":")[0] @@ -442,6 +458,10 @@ def current_diffusers_ref(path: str, revision: str = "main") -> Optional[str]: if ref.startswith(snapshot): return snapshot + raise ValueError( + f"Ref {ref} found in {snapshot_path} for revision {revision}, but ref path does not exist" + ) + def get_full_model_path( repo_id: str, @@ -476,7 +496,7 @@ def get_full_model_path( ref = current_diffusers_ref(storage, revision) if not ref: - raise ValueError("No ref found") + raise ValueError(f"No ref found for {repo_id}") if diffusers_skip_ref_follow: return Path(storage) diff --git a/core/flags.py b/core/flags.py index 8e6d72942..5aef2fe91 100644 --- a/core/flags.py +++ b/core/flags.py @@ -1,14 +1,15 @@ -from dataclasses import dataclass -from typing import Literal +from dataclasses import dataclass, field +from typing import Dict, Literal, List, Union from dataclasses_json.api import DataClassJsonMixin +from core.types import SigmaScheduler + LatentScaleModel = Literal[ "nearest", "area", "bilinear", - "bislerp-original", - "bislerp-tortured", + "bislerp", "bicubic", "nearest-exact", ] @@ -23,6 +24,8 @@ class Flag: class HighResFixFlag(Flag, DataClassJsonMixin): "Flag to fix high resolution images" + enabled: bool = False # For storing in json + scale: float = 2 mode: Literal["latent", "image"] = "latent" @@ -30,9 +33,159 @@ class HighResFixFlag(Flag, DataClassJsonMixin): image_upscaler: str = "RealESRGAN_x4plus_anime_6B" # Latent Upscaling - latent_scale_mode: LatentScaleModel = "bislerp-tortured" + latent_scale_mode: LatentScaleModel = "bislerp" antialiased: bool = False # Img2img strength: float = 0.7 steps: int = 50 + + +@dataclass +class DeepshrinkFlag(Flag, DataClassJsonMixin): + "Flag for deepshrink" + + enabled: bool = False # For storing in json + + depth_1: int = 3 # -1 to 12; steps of 1 + stop_at_1: float = 0.15 # 0 to 0.5; steps of 0.01 + + depth_2: int = 4 # -1 to 12; steps of 1 + stop_at_2: float = 0.30 # 0 to 0.5; steps of 0.01 + + scaler: LatentScaleModel = "bislerp" + base_scale: float = 0.5 # 0.05 to 1.0; steps of 0.05 + early_out: bool = False + + +@dataclass +class ScalecrafterFlag(Flag, DataClassJsonMixin): + "Flag for Scalecrafter settings" + + enabled: bool = False # For storing in json + + base: str = "sd15" + # In other words: allow untested/"unsafe" resolutions like "1234x4321" + unsafe_resolutions: bool = True + # May produce more "appealing" images, but will triple, or even quadruple memory usage. 
+ disperse: bool = False + + +@dataclass +class XLOriginalSize: + width: int = 1024 + height: int = 1024 + + +@dataclass +class SDXLFlag(Flag, DataClassJsonMixin): + "Flag for SDXL settings" + + original_size: XLOriginalSize = field(default_factory=XLOriginalSize) + + +@dataclass +class SDXLRefinerFlag(Flag, DataClassJsonMixin): + "Flag for SDXL refiners" + + steps: int = 50 + strength: float = 0.3 + model: str = "" + aesthetic_score: float = 6.0 + negative_aesthetic_score: float = 2.5 + + +@dataclass +class AnimateDiffFlag(Flag, DataClassJsonMixin): + "Flag for AnimateDiff" + + motion_model: str = "" + frames: int = 16 + fps: int = 10 # not working + + # Depends on seed whether or not it works??? Weird... investigate later... + # Probably self-explanatory, but increases generation time to {freeinit_iterations}x. + freeinit_iterations: int = -1 # -1 to disable, 5 recommended + freeinit_fast_sampling: bool = ( + False # decreases quality, but reduces generation time by ~60% + ) + freeinit_method: Literal["butterworth", "gaussian", "ideal", "box"] = "butterworth" + freeinit_n: int = 4 + freeinit_ds: float = 0.25 + freeinit_dt: float = 0.25 + + # Big maybes: + # - https://github.com/omerbt/TokenFlow + # - https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/commit/5bbcae4d226e8f298a8b204e9cc9b2dd41fbe417 + + # TODO: steal code from here: + # - https://github.com/guoyww/AnimateDiff/pull/132/files + + # only active when (frames > context_size) --> sliding context window. + context_size: int = 16 + frame_stride: int = 2 + frame_overlap: int = 4 + context_scheduler: Literal["uniform", "uniform_constant", "uniform_v2"] = "uniform" + + closed_loop: bool = False + + # increase processing time for decreased memory usage + chunk_feed_forward: int = -1 # -1 for disable, 0 for batch, 1 for sequence + chunk_feed_size: Union[Literal["auto"], int] = -1 + + input_video: str = "" # not working + init_image: str = "" # not working + video_controlnets: List[str] = field(default_factory=list) # not working + + # PIA is a new technique using a 9-channel unet3d instead of the traditional 4-channel unet3d. + # Very basic rundown of what it does -- same principle as 9-channel inpaint, however the masks + # "opacity" or rather, "weight" changes based on how far along are we in the animation. Starts out with + # relatively strong control and loosens it up, giving animation over to the motion module. + # + # In theory it improves animation quality by a large margin. 
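+    # The pia_* fields below configure this: the checkpoint path for the 9-channel weights,
+    # the conditioning frame, and the motion strength and type.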
+ use_pia: bool = True + pia_checkpont: str = "data/pia/pia.ckpt" + pia_cond_frame: int = 0 + pia_motion: int = 2 # 0 - 2 - motion settings, 0 is lowest, 3 is highest + pia_motion_type: Literal["normal", "closed_loop", "style_transfer"] = "normal" + + +@dataclass +class UpscaleFlag(Flag, DataClassJsonMixin): + "Flag for upscaling" + + enabled: bool = False # For storing in json + + upscale_factor: float = field(default=4) + tile_size: int = field(default=128) + tile_padding: int = field(default=10) + model: str = field(default="RealESRGAN_x4plus_anime_6B") + + +@dataclass +class ADetailerFlag(Flag, DataClassJsonMixin): + "Flag for ADetailer settings" + + # I hate pydantic + sampler_settings: Dict = field(default_factory=dict) + prompt_to_prompt_settings: Dict = field(default_factory=dict) + + enabled: bool = field(default=False) # For storing in json + + # Inpainting + image: Union[bytes, str, None] = field(default=None) + mask_image: Union[bytes, str, None] = field(default=None) + scheduler: Union[int, str] = "dpmpp_2m" + steps: int = field(default=40) + cfg_scale: float = field(default=7) + self_attention_scale: float = field(default=1.0) + sigmas: SigmaScheduler = field(default="exponential") + seed: int = field(default=0) + strength: float = field(default=0.45) + + # ADetailer specific + mask_dilation: int = field(default=4) + mask_blur: int = field(default=4) + mask_padding: int = field(default=32) + iterations: int = field(default=1) + upscale: int = field(default=2) diff --git a/core/gpu.py b/core/gpu.py index 16f84914e..9373770cc 100644 --- a/core/gpu.py +++ b/core/gpu.py @@ -2,6 +2,7 @@ import math import multiprocessing import time +from dataclasses import asdict from importlib.util import find_spec from pathlib import Path from typing import TYPE_CHECKING, Dict, List, Union @@ -16,25 +17,36 @@ from core import shared from core.config import config from core.errors import InferenceInterruptedError, ModelNotLoadedError +from core.flags import ADetailerFlag, HighResFixFlag, UpscaleFlag from core.inference.ait import AITemplateStableDiffusion from core.inference.esrgan import RealESRGAN, Upscaler -from core.inference.functions import download_model, is_ipex_available +from core.inference.functions import is_ipex_available from core.inference.pytorch import PyTorchStableDiffusion +from core.inference.sdxl import SDXLStableDiffusion +from core.inference.utilities.latents import scale_latents from core.interrogation.base_interrogator import InterrogationResult from core.optimizations import is_hypertile_available from core.png_metadata import save_images from core.queue import Queue from core.types import ( + ADetailerQueueEntry, AITemplateBuildRequest, AITemplateDynamicBuildRequest, Capabilities, ControlNetQueueEntry, + Img2imgData, + Img2ImgQueueEntry, InferenceBackend, InferenceJob, + InpaintData, + InpaintQueueEntry, InterrogatorQueueEntry, Job, ONNXBuildRequest, + PyTorchModelBase, TextualInversionLoadRequest, + Txt2ImgQueueEntry, + UpscaleData, UpscaleQueueEntry, VaeLoadRequest, ) @@ -57,6 +69,7 @@ def __init__(self) -> None: PyTorchStableDiffusion, "AITemplateStableDiffusion", "OnnxStableDiffusion", + "SDXLStableDiffusion", ], ] = {} self.capabilities = self._get_capabilities() @@ -98,19 +111,25 @@ def _get_capabilities(self) -> Capabilities: for device in [torch.device("cpu"), torch.device(config.api.device)]: support_map[device.type] = [] for dt in test_suite: - dtype = getattr(torch, dt) - a = torch.tensor([1.0], device=device, dtype=dtype) - b = torch.tensor([2.0], 
device=device, dtype=dtype) try: + dtype = getattr(torch, dt) + a = torch.tensor([1.0], device=device, dtype=dtype) + b = torch.tensor([2.0], device=device, dtype=dtype) torch.matmul(a, b) support_map[device.type].append(dt) except RuntimeError: pass + except AssertionError: + pass for t, s in support_map.items(): if t == "cpu": - cap.supported_precisions_cpu = ["float32"] + s + cap.supported_precisions_cpu = ( + ["float32"] + s + ["float8_e4m3fn", "float8_e5m2"] + ) else: - cap.supported_precisions_gpu = ["float32"] + s + cap.supported_precisions_gpu = ( + ["float32"] + s + ["float8_e4m3fn", "float8_e5m2"] + ) try: cap.supported_torch_compile_backends = ( torch._dynamo.list_backends() # type: ignore @@ -171,72 +190,248 @@ def vram_used(self) -> float: index = torch.device(config.api.device).index return torch.cuda.memory_allocated(index) / 1024**2 - def generate( - self, - job: InferenceJob, - ): - "Generate images from the queue" + def highres_flag( + self, job: Job, images: Union[List[Image.Image], torch.Tensor] + ) -> List[Image.Image]: + flag = job.flags["highres_fix"] + flag = HighResFixFlag.from_dict(flag) - job = preprocess_job(job) + if flag.mode == "latent": + assert isinstance(images, (torch.Tensor, torch.FloatTensor)) + latents = images - def generate_thread_call(job: Job) -> List[Image.Image]: - try: - model: Union[ - PyTorchStableDiffusion, - AITemplateStableDiffusion, - "OnnxStableDiffusion", - ] = self.loaded_models[job.model] - except KeyError as err: - websocket_manager.broadcast_sync( - Notification( - "error", - "Model not loaded", - f"Model {job.model} is not loaded, please load it first", + latents = scale_latents( + latents=latents, + scale=flag.scale, + latent_scale_mode=flag.latent_scale_mode, + ) + + height = latents.shape[2] * 8 + width = latents.shape[3] * 8 + output_images = latents + else: + from core.shared_dependent import gpu + + assert isinstance(images, List) + output_images = [] + + for image in images: + output: tuple[Image.Image, float] = gpu.upscale( + UpscaleQueueEntry( + data=UpscaleData( + id=job.data.id, + # FastAPI validation error, we need to do this so that we can pass in a PIL image + image=image, # type: ignore + upscale_factor=flag.scale, + ), + model=flag.image_upscaler, + save_image=False, ) ) + output_images.append(output[0]) + + output_images = output_images[0] # type: ignore + height = int(flag.scale * job.data.height) + width = int(flag.scale * job.data.width) + + data = Img2imgData( + prompt=job.data.prompt, + negative_prompt=job.data.negative_prompt, + image=output_images, # type: ignore + scheduler=job.data.scheduler, + batch_count=job.data.batch_count, + batch_size=job.data.batch_size, + strength=flag.strength, + steps=flag.steps, + guidance_scale=job.data.guidance_scale, + prompt_to_prompt_settings=job.data.prompt_to_prompt_settings, + seed=job.data.seed, + self_attention_scale=job.data.self_attention_scale, + sigmas=job.data.sigmas, + sampler_settings=job.data.sampler_settings, + height=height, + width=width, + ) + + img2img_job = Img2ImgQueueEntry( + data=data, + model=job.model, + ) + + result: List[Image.Image] = self.run_inference(img2img_job) + return result + + def upscale_flag(self, job: Job, images: List[Image.Image]) -> List[Image.Image]: + logger.debug("Upscaling image") + + flag = UpscaleFlag(**job.flags["upscale"]) + + final_images = [] + for image in images: + upscale_job = UpscaleQueueEntry( + data=UpscaleData( + image=image, # type: ignore # Pydantic would cry if we extend the union + 
upscale_factor=flag.upscale_factor, + tile_padding=flag.tile_padding, + tile_size=flag.tile_size, + ), + model=flag.model, + ) + + final_images.append(self.upscale(upscale_job)[0]) + + return final_images + + def adetailer_flag(self, job: Job, images: List[Image.Image]) -> List[Image.Image]: + logger.debug("Running ADetailer") + + flag = ADetailerFlag(**job.flags["adetailer"]) + data = asdict(flag) + mask_blur = data.pop("mask_blur") + mask_dilation = data.pop("mask_dilation") + mask_padding = data.pop("mask_padding") + iterations = data.pop("iterations") + upscale = data.pop("upscale") + data.pop("enabled", None) + + data["prompt"] = job.data.prompt + data["negative_prompt"] = job.data.negative_prompt - logger.debug("Model not loaded on any GPU. Raising error") - raise ModelNotLoadedError(f"Model {job.model} is not loaded") from err + data = InpaintData(**data) - shared.interrupt = False + assert data is not None - if job.flags: - logger.debug(f"Job flags: {job.flags}") + final_images = [] + for image in images: + data.image = image # type: ignore + data.prompt = job.data.prompt + data.negative_prompt = job.data.negative_prompt - steps = job.data.steps + adetailer_job = ADetailerQueueEntry( + data=data, + mask_blur=mask_blur, + mask_dilation=mask_dilation, + mask_padding=mask_padding, + iterations=iterations, + upscale=upscale, + model=job.model, + ) + + final_images.extend(self.run_inference(adetailer_job)) + + return final_images + + def postprocess( + self, job: Job, images: Union[List[Image.Image], torch.Tensor] + ) -> List[Image.Image]: + "Postprocess images" + + logger.debug(f"Postprocessing flags: {job.flags}") + + if "highres_fix" in job.flags: + images = self.highres_flag(job, images) + + if "adetailer" in job.flags: + assert isinstance(images, list) + images = self.adetailer_flag(job, images) + + if "upscale" in job.flags: + assert isinstance(images, list) + images = self.upscale_flag(job, images) + + assert isinstance(images, list) + return images + + def set_callback_target(self, job: Job): + "Set the callback target for the job, updates the shared object and also returns the target" + + if isinstance(job, Txt2ImgQueueEntry): + target = "txt2img" + elif isinstance(job, Img2ImgQueueEntry): + target = "img2img" + elif isinstance(job, ControlNetQueueEntry): + target = "controlnet" + elif isinstance(job, InpaintQueueEntry): + target = "inpainting" + else: + raise ValueError("Unknown job type") + + shared.current_method = target + return target + + def run_inference(self, job: Job) -> List[Image.Image]: + try: + model: Union[ + PyTorchStableDiffusion, + AITemplateStableDiffusion, + SDXLStableDiffusion, + "OnnxStableDiffusion", + ] = self.loaded_models[job.model] + except KeyError as err: + websocket_manager.broadcast_sync( + Notification( + "error", + "Model not loaded", + f"Model {job.model} is not loaded, please load it first", + ) + ) - strength: float = getattr(job.data, "strength", 1.0) - steps = math.floor(steps * strength) + logger.debug("Model not loaded on any GPU. 
Raising error") + raise ModelNotLoadedError(f"Model {job.model} is not loaded") from err - shared.current_done_steps = 0 + shared.interrupt = False - if not isinstance(job, ControlNetQueueEntry): - from core import shared_dependent + if job.flags: + logger.debug(f"Job flags: {job.flags}") - if shared_dependent.cached_controlnet_preprocessor is not None: - # Wipe cached controlnet preprocessor - shared_dependent.cached_controlnet_preprocessor = None - self.memory_cleanup() + steps = job.data.steps - # shared.current_model = model + strength: float = getattr(job.data, "strength", 1.0) + steps = math.floor(steps * strength) - if isinstance(model, PyTorchStableDiffusion): - logger.debug("Generating with PyTorch") - images: List[Image.Image] = model.generate(job) - elif isinstance(model, AITemplateStableDiffusion): - logger.debug("Generating with AITemplate") - images: List[Image.Image] = model.generate(job) + shared.current_done_steps = 0 + + if not isinstance(job, ControlNetQueueEntry): + from core import shared_dependent + + if shared_dependent.cached_controlnet_preprocessor is not None: + # Wipe cached controlnet preprocessor + shared_dependent.cached_controlnet_preprocessor = None + self.memory_cleanup() + + if isinstance(model, PyTorchStableDiffusion): + logger.debug("Generating with SD PyTorch") + shared.current_model = "SD1.x" + images: Union[List[Image.Image], torch.Tensor] = model.generate(job) + elif isinstance(model, SDXLStableDiffusion): + logger.debug("Generating with SDXL (PyTorch)") + shared.current_model = "SDXL" + images: Union[List[Image.Image], torch.Tensor] = model.generate(job) + elif isinstance(model, AITemplateStableDiffusion): + logger.debug("Generating with SD AITemplate") + images: Union[List[Image.Image], torch.Tensor] = model.generate(job) + else: + from core.inference.onnx import OnnxStableDiffusion + + if isinstance(model, OnnxStableDiffusion): + logger.debug("Generating with SD ONNX") + images: Union[List[Image.Image], torch.Tensor] = model.generate(job) else: - from core.inference.onnx import OnnxStableDiffusion + raise NotImplementedError("Unknown model type") - if isinstance(model, OnnxStableDiffusion): - logger.debug("Generating with ONNX") - images: List[Image.Image] = model.generate(job) - else: - raise NotImplementedError("Unknown model type") + self.memory_cleanup() - self.memory_cleanup() - return images + # Run postprocessing + images = self.postprocess(job, images) + return images + + def generate( + self, + job: InferenceJob, + ): + "Generate images from the queue" + + job = preprocess_job(job) try: # Wait for turn in the queue @@ -246,7 +441,8 @@ def generate_thread_call(job: Job) -> List[Image.Image]: # Generate images try: - generated_images = generate_thread_call(job) + self.set_callback_target(job) + generated_images = self.run_inference(job) assert generated_images is not None @@ -327,6 +523,7 @@ def load_model( self, model: str, backend: InferenceBackend, + type: PyTorchModelBase, ): "Load a model into memory" @@ -391,6 +588,22 @@ def load_model_thread_call( pt_model = OnnxStableDiffusion(model_id=model) self.loaded_models[model] = pt_model + elif type == "SDXL": + logger.debug("Selecting SDXL") + + websocket_manager.broadcast_sync( + Notification( + "info", + "SDXL", + f"Loading {model} into memory, this may take a while", + ) + ) + + sdxl_model = SDXLStableDiffusion( + model_id=model, + device=config.api.device, + ) + self.loaded_models[model] = sdxl_model else: logger.debug("Selecting PyTorch") @@ -581,7 +794,9 @@ def 
model_to_f16_thread_call(): def download_huggingface_model(self, model: str): "Download a model from the internet." - download_model(model) + from diffusers.pipelines.pipeline_utils import DiffusionPipeline + + DiffusionPipeline.download(model, resume_download=True) def load_vae(self, req: VaeLoadRequest): "Change the models VAE" @@ -589,10 +804,10 @@ def load_vae(self, req: VaeLoadRequest): if req.model in self.loaded_models: internal_model = self.loaded_models[req.model] - if isinstance(internal_model, PyTorchStableDiffusion): + if hasattr(internal_model, "change_vae"): logger.info(f"Loading VAE model: {req.vae}") - internal_model.change_vae(req.vae) + internal_model.change_vae(req.vae) # type: ignore websocket_manager.broadcast_sync( Notification( @@ -625,6 +840,20 @@ def load_textual_inversion(self, req: TextualInversionLoadRequest): f"Textual inversion model {req.textual_inversion} loaded", ) ) + if isinstance(internal_model, SDXLStableDiffusion): + logger.info(f"Loading textual inversion model: {req.textual_inversion}") + + internal_model.load_textual_inversion(req.textual_inversion) + + websocket_manager.broadcast_sync( + Notification( + "success", + "Textual inversion model loaded", + f"Textual inversion model {req.textual_inversion} loaded", + ) + ) + else: + logger.warning(f"Model {req.model} does not support textual inversion") else: websocket_manager.broadcast_sync( diff --git a/core/inference/adetailer/adetailer.py b/core/inference/adetailer/adetailer.py new file mode 100644 index 000000000..fce5f33e2 --- /dev/null +++ b/core/inference/adetailer/adetailer.py @@ -0,0 +1,141 @@ +# Taken from https://github.com/Bing-su/asdff +# Origial author: Bing-su +# Modified by: Stax124 + +import functools +import logging +from typing import Any, Callable, Iterable, List, Optional + +from asdff.utils import ( + ADOutput, + bbox_padding, + composite, + mask_dilate, + mask_gaussian_blur, +) +from asdff.yolo import yolo_detector +from PIL import Image, ImageOps + +from core.types import InpaintQueueEntry +from core.utils import convert_to_image + +logger = logging.getLogger(__name__) + +DetectorType = Callable[[Image.Image], Optional[List[Image.Image]]] + + +def ordinal(n: int) -> str: + d = {1: "st", 2: "nd", 3: "rd"} + return str(n) + ("th" if 11 <= n % 100 <= 13 else d.get(n % 10, "th")) + + +class ADetailer: + def get_default_detector(self, model_path: Optional[str] = None): + if model_path is not None: + return functools.partial(yolo_detector, model_path=model_path) + + return yolo_detector + + def generate( + self, + fn: Any, + inpaint_entry: InpaintQueueEntry, + detectors: DetectorType | Iterable[DetectorType] | None = None, + mask_dilation: int = 4, + mask_blur: int = 4, + mask_padding: int = 32, + iterations: int = 1, + upscale: int = 2, + yolo_model: Optional[str] = None, + ) -> ADOutput: + if detectors is None: + detectors = [self.get_default_detector(yolo_model)] + elif not isinstance(detectors, Iterable): + detectors = [detectors] + + input_image = convert_to_image(inpaint_entry.data.image).convert("RGB") + + init_images = [] + final_images = [] + + init_images.append(input_image.copy()) + final_image = None + + for j, detector in enumerate(detectors): + masks = detector(input_image) + if masks is None: + logger.info(f"No object detected with {ordinal(j + 1)} detector.") + continue + + for k, mask in enumerate(masks): + mask = mask.convert("L") + mask = mask_dilate(mask, mask_dilation) + bbox = mask.getbbox() + if bbox is None: + logger.info(f"No object in {ordinal(k + 1)} mask.") 
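# --- Editorial sketch, not part of the patch -----------------------------------------
# ADetailer.generate() above walks every detector, takes each mask it returns, skips
# masks whose bounding box is empty, and refines the boxed region one mask at a time,
# feeding the composited result back in as the next input image. A condensed outline;
# the detector/refine callables are stand-ins, not the asdff API used by the patch.
from typing import Callable, Iterable, List, Optional
from PIL import Image

def detail_pass(
    image: Image.Image,
    detectors: Iterable[Callable[[Image.Image], Optional[List[Image.Image]]]],
    refine: Callable[[Image.Image, Image.Image, tuple], Image.Image],
) -> Image.Image:
    for detector in detectors:
        for mask in detector(image) or []:
            bbox = mask.convert("L").getbbox()
            if bbox is None:
                continue  # empty mask: the detector found nothing usable here
            # Crop around bbox, inpaint, and composite back (see process_inpainting below).
            image = refine(image, mask, bbox)
    return image
# --------------------------------------------------------------------------------------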
+ continue + mask = mask_gaussian_blur(mask, mask_blur) + bbox_padded = bbox_padding(bbox, input_image.size, mask_padding) + inverted_mask = ImageOps.invert(mask) + + for _i in range(iterations): + inpaint_output: List[Image.Image] = self.process_inpainting( + fn, + inpaint_entry, + input_image, + inverted_mask, + bbox_padded, + upscale=upscale, + ) + + inpaint_image: Image.Image = inpaint_output[0] # type: ignore + + final_image = composite( + input_image, + mask, + inpaint_image, + bbox_padded, + ) + + input_image = final_image + + assert final_image is not None + final_images.append(final_image) + + return ADOutput(images=final_images, init_images=init_images) + + def process_inpainting( + self, + fn: Callable, + inpaint_entry: InpaintQueueEntry, + init_image: Image.Image, + mask: Image.Image, + bbox_padded: tuple[int, int, int, int], + upscale: int = 2, + ): # -> tuple[PipelineImageInput, Any | None] | StableDiffusionPipelineOutput: + crop_image = init_image.crop(bbox_padded) + crop_mask = mask.crop(bbox_padded) + + # Get the current size of the images + width, height = crop_image.size + + # Calculate the new size + new_size = (int(width * upscale), int(height * upscale)) + + # Resize the images + crop_image = crop_image.resize(new_size, resample=Image.LANCZOS) + crop_mask = crop_mask.resize(new_size, resample=Image.LANCZOS) + + inpaint_entry.data.image = crop_image # type: ignore + inpaint_entry.data.mask_image = crop_mask # type: ignore + + inpaint_entry.data.width = crop_image.width + inpaint_entry.data.height = crop_image.height + + out: List[Image.Image] = fn(inpaint_entry) + # Resize back to original size + out_resized = [] + for image in out: + out_resized.append(image.resize((width, height), resample=Image.LANCZOS)) + + return out_resized diff --git a/core/inference/ait/aitemplate.py b/core/inference/ait/aitemplate.py index 3b9fee669..d090a91fc 100644 --- a/core/inference/ait/aitemplate.py +++ b/core/inference/ait/aitemplate.py @@ -3,7 +3,7 @@ from typing import Any, List, Literal, Optional, Tuple, Union import torch -from diffusers.models.autoencoder_kl import AutoencoderKL +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from diffusers.models.controlnet import ControlNetModel from diffusers.models.unet_2d_condition import UNet2DConditionModel from PIL import Image @@ -14,13 +14,10 @@ from api import websocket_manager from api.websockets.data import Data -from core import shared from core.config import config -from core.flags import HighResFixFlag from core.inference.ait.pipeline import StableDiffusionAITPipeline from core.inference.base_model import InferenceModel -from core.inference.functions import load_pytorch_pipeline -from core.inference.utilities.latents import scale_latents +from core.inference.functions import get_output_type, load_pytorch_pipeline from core.inference_callbacks import callback from core.types import ( Backend, @@ -35,7 +32,6 @@ from ..utilities import ( change_scheduler, create_generator, - get_weighted_text_embeddings, image_to_controlnet_input, init_ait_module, ) @@ -215,7 +211,7 @@ def manage_optional_components( cn = ControlNetModel.from_pretrained( target_controlnet, resume_download=True, - torch_dtype=config.api.dtype, + torch_dtype=config.api.load_dtype, ) assert isinstance(cn, ControlNetModel) @@ -227,7 +223,7 @@ def manage_optional_components( "Optimization: xformers not available, enabling attention slicing instead" ) - cn.to(device=torch.device(self.device), dtype=config.api.dtype) + 
cn.to(device=torch.device(self.device), dtype=config.api.load_dtype) self.controlnet = cn self.current_controlnet = target_controlnet @@ -268,7 +264,7 @@ def create_pipe( ) return pipe - def generate(self, job: Job) -> List[Image.Image]: + def generate(self, job: Job) -> Union[List[Image.Image], torch.Tensor]: logging.info(f"Adding job {job.data.id} to queue") if isinstance(job, Txt2ImgQueueEntry): @@ -287,7 +283,7 @@ def generate(self, job: Job) -> List[Image.Image]: def txt2img( self, job: Txt2ImgQueueEntry, - ) -> List[Image.Image]: + ) -> Union[List[Image.Image], torch.Tensor]: "Generates images from text" pipe = self.create_pipe( scheduler=(job.data.scheduler, job.data.sigmas), @@ -296,88 +292,54 @@ def txt2img( generator = create_generator(seed=job.data.seed) - total_images: List[Image.Image] = [] - shared.current_method = "txt2img" + total_images: Union[List[Image.Image], torch.Tensor] = [] + output_type = get_output_type(job) for _ in tqdm(range(job.data.batch_count), desc="Queue", position=1): - output_type = "pil" - - if "highres_fix" in job.flags: - output_type = "latent" - - prompt_embeds, negative_prompt_embeds = get_weighted_text_embeddings( - pipe, job.data.prompt, job.data.negative_prompt - ) data = pipe( generator=generator, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, + prompt=job.data.prompt, + negative_prompt=job.data.negative_prompt, height=job.data.height, width=job.data.width, num_inference_steps=job.data.steps, guidance_scale=job.data.guidance_scale, - negative_prompt=job.data.negative_prompt, output_type=output_type, callback=callback, num_images_per_prompt=job.data.batch_size, ) - if output_type == "latent": - latents = data[0] # type: ignore - assert isinstance(latents, (torch.Tensor, torch.FloatTensor)) - - flag = job.flags["highres_fix"] - flag = HighResFixFlag.from_dict(flag) - - latents = scale_latents( - latents=latents, - scale=flag.scale, - latent_scale_mode=flag.latent_scale_mode, - ) + images: Union[List[Image.Image], torch.Tensor] = data[0] # type: ignore - data = pipe( - generator=generator, - prompt=job.data.prompt, - image=latents, - height=latents.shape[2] * 8, - width=latents.shape[3] * 8, - num_inference_steps=flag.steps, - guidance_scale=job.data.guidance_scale, - self_attention_scale=job.data.self_attention_scale, - negative_prompt=job.data.negative_prompt, - output_type="pil", - callback=callback, - strength=flag.strength, - return_dict=False, - num_images_per_prompt=job.data.batch_size, + if not isinstance(images, List): + total_images = images + else: + assert isinstance(total_images, List) + total_images.extend(images) + + if isinstance(total_images, List): + websocket_manager.broadcast_sync( + data=Data( + data_type="txt2img", + data={ + "progress": 0, + "current_step": 0, + "total_steps": 0, + "image": convert_images_to_base64_grid( + total_images, # type: ignore + quality=config.api.image_quality, + image_format=config.api.image_extension, + ), + }, ) - - images: list[Image.Image] = data[0] # type: ignore - - total_images.extend(images) - - websocket_manager.broadcast_sync( - data=Data( - data_type="txt2img", - data={ - "progress": 0, - "current_step": 0, - "total_steps": 0, - "image": convert_images_to_base64_grid( - total_images, - quality=config.api.image_quality, - image_format=config.api.image_extension, - ), - }, ) - ) return total_images def img2img( self, job: Img2ImgQueueEntry, - ) -> List[Image.Image]: + ) -> Union[List[Image.Image], torch.Tensor]: "Generates images from images" pipe = 
self.create_pipe( scheduler=(job.data.scheduler, job.data.sigmas), @@ -386,58 +348,62 @@ def img2img( generator = create_generator(seed=job.data.seed) - input_image = convert_to_image(job.data.image) - input_image = resize(input_image, job.data.width, job.data.height) + # Preprocess the image + if isinstance(job.data.image, (str, bytes, Image.Image)): + input_image = convert_to_image(job.data.image) + input_image = resize(input_image, job.data.width, job.data.height) + else: + input_image = job.data.image - total_images: List[Image.Image] = [] - shared.current_method = "img2img" + total_images: Union[List[Image.Image], torch.Tensor] = [] + output_type = get_output_type(job) for _ in tqdm(range(job.data.batch_count), desc="Queue", position=1): - prompt_embeds, negative_prompt_embeds = get_weighted_text_embeddings( - pipe, job.data.prompt, job.data.negative_prompt - ) data = pipe( generator=generator, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, + prompt=job.data.prompt, + negative_prompt=job.data.negative_prompt, image=input_image, # type: ignore num_inference_steps=job.data.steps, guidance_scale=job.data.guidance_scale, - negative_prompt=job.data.negative_prompt, - output_type="pil", + output_type=output_type, callback=callback, strength=job.data.strength, # type: ignore return_dict=False, num_images_per_prompt=job.data.batch_size, ) - images = data[0] - assert isinstance(images, List) - - total_images.extend(images) - - websocket_manager.broadcast_sync( - data=Data( - data_type="img2img", - data={ - "progress": 0, - "current_step": 0, - "total_steps": 0, - "image": convert_images_to_base64_grid( - total_images, - quality=config.api.image_quality, - image_format=config.api.image_extension, - ), - }, + images: Union[List[Image.Image], torch.Tensor] = data[0] # type: ignore + + if not isinstance(images, List): + total_images = images + else: + assert isinstance(total_images, List) + total_images.extend(images) + + if isinstance(total_images, List): + websocket_manager.broadcast_sync( + data=Data( + data_type="img2img", + data={ + "progress": 0, + "current_step": 0, + "total_steps": 0, + "image": convert_images_to_base64_grid( + total_images, # type: ignore + quality=config.api.image_quality, + image_format=config.api.image_extension, + ), + }, + ) ) - ) return total_images def controlnet2img( self, job: ControlNetQueueEntry, - ) -> List[Image.Image]: + ) -> Union[List[Image.Image], torch.Tensor]: "Generates images from images" pipe = self.create_pipe( controlnet=job.data.controlnet, @@ -454,22 +420,18 @@ def controlnet2img( if not job.data.is_preprocessed: input_image = image_to_controlnet_input(input_image, job.data) - total_images: List[Image.Image] = [input_image] - shared.current_method = "controlnet" + total_images: Union[List[Image.Image], torch.Tensor] = [] + output_type = get_output_type(job) for _ in tqdm(range(job.data.batch_count), desc="Queue", position=1): - prompt_embeds, negative_prompt_embeds = get_weighted_text_embeddings( - pipe, job.data.prompt, job.data.negative_prompt - ) data = pipe( generator=generator, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, + prompt=job.data.prompt, + negative_prompt=job.data.negative_prompt, image=input_image, # type: ignore num_inference_steps=job.data.steps, guidance_scale=job.data.guidance_scale, - negative_prompt=job.data.negative_prompt, - output_type="pil", + output_type=output_type, callback=callback, return_dict=False, num_images_per_prompt=job.data.batch_size, @@ 
-478,27 +440,31 @@ def controlnet2img( width=job.data.width, ) - images = data[0] - assert isinstance(images, List) - - total_images.extend(images) - - websocket_manager.broadcast_sync( - data=Data( - data_type="controlnet", - data={ - "progress": 0, - "current_step": 0, - "total_steps": 0, - "image": convert_images_to_base64_grid( - total_images - if job.data.return_preprocessed - else total_images[1:], - quality=config.api.image_quality, - image_format=config.api.image_extension, - ), - }, + images: Union[List[Image.Image], torch.Tensor] = data[0] # type: ignore + + if not isinstance(images, List): + total_images = images + else: + assert isinstance(total_images, List) + total_images.extend(images) + + if isinstance(total_images, List): + websocket_manager.broadcast_sync( + data=Data( + data_type="controlnet", + data={ + "progress": 0, + "current_step": 0, + "total_steps": 0, + "image": convert_images_to_base64_grid( + total_images # type: ignore + if job.data.return_preprocessed + else total_images[1:], + quality=config.api.image_quality, + image_format=config.api.image_extension, + ), + }, + ) ) - ) return total_images diff --git a/core/inference/ait/pipeline.py b/core/inference/ait/pipeline.py index d5d30986c..aafa911ae 100644 --- a/core/inference/ait/pipeline.py +++ b/core/inference/ait/pipeline.py @@ -19,7 +19,7 @@ from typing import Any, Callable, List, Optional, Union import torch -from diffusers.models.autoencoder_kl import AutoencoderKL +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from diffusers.models.controlnet import ControlNetModel from diffusers.models.unet_2d_condition import UNet2DConditionModel from diffusers.pipelines.stable_diffusion.pipeline_output import ( @@ -34,6 +34,7 @@ from tqdm import tqdm from transformers.models.clip import CLIPTextModel, CLIPTokenizer +from core.config import config from core.inference.functions import is_aitemplate_available from core.inference.utilities import ( get_timesteps, @@ -271,7 +272,7 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 if prompt is not None: - prompt_embeds, negative_prompt_embeds = get_weighted_text_embeddings( + prompt_embeds, _, negative_prompt_embeds, _ = get_weighted_text_embeddings( self, prompt=prompt, uncond_prompt=negative_prompt, @@ -324,11 +325,15 @@ def __call__( prompt_embeds.dtype, self.device, generator, + None, latents, align_to=64, ) extra_step_kwargs = prepare_extra_step_kwargs( - self.scheduler, eta, generator=generator + self.scheduler, + eta, + generator=generator, + device=torch.device(config.api.device), ) # Necessary for controlnet to function text_embeddings = text_embeddings.half() @@ -341,7 +346,9 @@ def __call__( - float(i / len(timesteps) < 0.0 or (i + 1) / len(timesteps) > 1.0) ) - def do_denoise(x, t, call: Callable) -> torch.Tensor: + def do_denoise( + x, t, call: Callable, change_source: Callable[[Callable], None] + ) -> torch.Tensor: latent_model_input = ( torch.cat([x] * 2) if do_classifier_free_guidance else x ) @@ -447,6 +454,11 @@ def do_denoise(x, t, call: Callable) -> torch.Tensor: callback_steps=1, ) else: + s = self.unet + + def change(src): + nonlocal s + s = src def _call(*args, **kwargs): if len(args) == 3: @@ -461,7 +473,7 @@ def _call(*args, **kwargs): ) for i, t in enumerate(tqdm(timesteps, desc="AITemplate")): - latents = do_denoise(latents, t, _call) # type: ignore + latents = do_denoise(latents, t, _call, change) # type: ignore # call the callback, if provided if callback is not None: diff --git 
a/core/inference/functions.py b/core/inference/functions.py index 23f4e29db..db5d154d8 100644 --- a/core/inference/functions.py +++ b/core/inference/functions.py @@ -2,34 +2,29 @@ import json import logging import os -from functools import partialmethod +from functools import partial from importlib.util import find_spec from pathlib import Path from typing import Any, Dict, Optional, Tuple, Union import requests import torch -from diffusers.models.autoencoder_kl import AutoencoderKL -from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( - assign_to_checkpoint, - conv_attn_to_linear, - create_vae_diffusers_config, - download_from_original_stable_diffusion_ckpt, - renew_vae_attention_paths, - renew_vae_resnet_paths, -) +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( StableDiffusionPipeline, ) from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.utils.constants import ( CONFIG_NAME, - DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, ) -from diffusers.utils.hub_utils import HF_HUB_OFFLINE +from huggingface_hub.constants import ( + HUGGINGFACE_HUB_CACHE as DIFFUSERS_CACHE, + HF_HUB_OFFLINE, +) from huggingface_hub import model_info # type: ignore from huggingface_hub._snapshot_download import snapshot_download from huggingface_hub.file_download import hf_hub_download @@ -42,11 +37,18 @@ from omegaconf import OmegaConf from packaging import version from requests import HTTPError -from transformers import CLIPTextModel +from transformers.models.clip.modeling_clip import BaseModelOutput from core.config import config from core.files import get_full_model_path +from core.flags import HighResFixFlag from core.optimizations import compile_sfast +from core.types import Job +from .utilities.convert_from_ckpt import ( + download_from_original_stable_diffusion_ckpt, + convert_ldm_vae_checkpoint, + create_vae_diffusers_config, +) logger = logging.getLogger(__name__) config_name = "model_index.json" @@ -311,7 +313,7 @@ def download_model( def is_safetensors_compatible(info: ModelInfo) -> bool: "Check if the model is compatible with safetensors" - filenames = set(sibling.rfilename for sibling in info.siblings) + filenames = set(sibling.rfilename for sibling in info.siblings) # type: ignore pt_filenames = set(filename for filename in filenames if filename.endswith(".bin")) safetensors_compatible = any(file.endswith(".safetensors") for file in filenames) for pt_filename in pt_filenames: @@ -352,269 +354,114 @@ def load_pytorch_pipeline( else: logger.info("Loading model as checkpoint") - # This function does not inherit the channels so we need to hack it like this - in_channels = 9 if "inpaint" in model_id_or_path.casefold() else 4 - - cl = StableDiffusionPipeline - # I never knew this existed, but this is pretty handy :) - cl.__init__ = partialmethod(cl.__init__, requires_safety_checker=False) # type: ignore try: pipe = download_from_original_stable_diffusion_ckpt( str(get_full_model_path(model_id_or_path)), - pipeline_class=cl, # type: ignore from_safetensors=use_safetensors, extract_ema=True, - load_safety_checker=False, - num_in_channels=in_channels, ) except KeyError: pipe = download_from_original_stable_diffusion_ckpt( str(get_full_model_path(model_id_or_path)), - pipeline_class=cl, # type: ignore from_safetensors=use_safetensors, extract_ema=False, - 
load_safety_checker=False, - num_in_channels=in_channels, ) else: - pipe = StableDiffusionPipeline.from_pretrained( + pipe = DiffusionPipeline.from_pretrained( pretrained_model_name_or_path=get_full_model_path(model_id_or_path), - torch_dtype=config.api.dtype, + torch_dtype=config.api.load_dtype, safety_checker=None, feature_extractor=None, low_cpu_mem_usage=True, ) - assert isinstance(pipe, StableDiffusionPipeline) - logger.debug(f"Loaded {model_id_or_path} with {config.api.data_type}") - assert isinstance(pipe, StableDiffusionPipeline) + for name, text_encoder in [x for x in vars(pipe).items() if "text_encoder" in x[0]]: + if text_encoder is not None: + + def new_forward( + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + bober=None, + ): + output_hidden_states = True + original = bober.old_forward( # type: ignore + inputs_embeds, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = (_ := original[1])[: len(_) - config.api.clip_skip] + last_hidden_state = hidden_states[-1] - # AIT freaks out if any of these are lost - if not is_for_aitemplate: - conf = pipe.text_encoder.config - conf.num_hidden_layers = 13 - config.api.clip_skip - pipe.text_encoder = CLIPTextModel.from_pretrained( - None, config=conf, state_dict=pipe.text_encoder.state_dict() - ) - if config.api.clip_quantization != "full": - from transformers import BitsAndBytesConfig - from transformers.utils.bitsandbytes import ( - get_keys_to_not_convert, - replace_with_bnb_linear, - set_module_quantized_tensor_to_device, - ) + attentions = original[2] if output_attentions else None - state_dict = pipe.text_encoder.state_dict() # type: ignore - bnbconfig = BitsAndBytesConfig( - load_in_8bit=config.api.clip_quantization == "int8", - load_in_4bit=config.api.clip_quantization == "int4", - ) + if not return_dict: + return last_hidden_state, hidden_states, attentions + return BaseModelOutput( + last_hidden_state=last_hidden_state, + hidden_states=hidden_states, + attentions=attentions, + ) - dont_convert = get_keys_to_not_convert(pipe.text_encoder) - pipe.text_encoder = replace_with_bnb_linear( - pipe.text_encoder.to(config.api.device, config.api.dtype), # type: ignore - dont_convert, - quantization_config=bnbconfig, - ) + if config.api.clip_quantization != "full": + from transformers import BitsAndBytesConfig + from transformers.utils.bitsandbytes import ( + get_keys_to_not_convert, + replace_with_bnb_linear, + set_module_quantized_tensor_to_device, + ) - pipe.text_encoder.is_loaded_in_8bit = True - pipe.text_encoder.is_quantized = True + state_dict = text_encoder.state_dict() # type: ignore + bnbconfig = BitsAndBytesConfig( + load_in_8bit=config.api.clip_quantization == "int8", + load_in_4bit=config.api.clip_quantization == "int4", + ) - # This shouldn't even be needed, but diffusers likes meta tensors a bit too much - # Not that I don't see their purpose, it's just less general - for k, v in state_dict.items(): - set_module_quantized_tensor_to_device( - pipe.text_encoder, k, config.api.device, v + dont_convert = get_keys_to_not_convert(text_encoder) + text_encoder.is_loaded_in_8bit = True # type: ignore + text_encoder.is_quantized = True # type: ignore + nt = 
replace_with_bnb_linear( + pipe.text_encoder.to(config.api.device, config.api.load_dtype), # type: ignore + dont_convert, + quantization_config=bnbconfig, ) - del state_dict, dont_convert - del conf + + # This shouldn't even be needed, but diffusers likes meta tensors a bit too much + # Not that I don't see their purpose, it's just less general + for k, v in state_dict.items(): + set_module_quantized_tensor_to_device(nt, k, config.api.device, v) + setattr(pipe, name, nt) + del state_dict, dont_convert + + text_encoder.text_model.encoder.old_forward = text_encoder.text_model.encoder.forward # type: ignore + # fuck you python + # enjoy bober + text_encoder.text_model.encoder.forward = partial(new_forward, bober=text_encoder.text_model.encoder) # type: ignore + logger.debug(f"Overwritten {name}s final_layer_norm.") if optimize: from core.optimizations import optimize_model optimize_model( - pipe=pipe, + pipe=pipe, # type: ignore device=device, is_for_aitemplate=is_for_aitemplate, ) + if config.api.sfast_compile: + pipe = compile_sfast(pipe) else: - pipe.to(device) - - if config.api.sfast_compile: - pipe = compile_sfast(pipe) - - return pipe - - -def _custom_convert_ldm_vae_checkpoint(checkpoint, conf): - vae_state_dict = checkpoint - - new_checkpoint = {} - - new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] - new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] - new_checkpoint["encoder.conv_out.weight"] = vae_state_dict[ - "encoder.conv_out.weight" - ] - new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] - new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict[ - "encoder.norm_out.weight" - ] - new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict[ - "encoder.norm_out.bias" - ] - - new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] - new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] - new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[ - "decoder.conv_out.weight" - ] - new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] - new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[ - "decoder.norm_out.weight" - ] - new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[ - "decoder.norm_out.bias" - ] - - new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] - new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] - new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] - new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] - - # Retrieves the keys for the encoder down blocks only - num_down_blocks = len( - { - ".".join(layer.split(".")[:3]) - for layer in vae_state_dict - if "encoder.down" in layer - } - ) - down_blocks = { - layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] - for layer_id in range(num_down_blocks) - } - - # Retrieves the keys for the decoder up blocks only - num_up_blocks = len( - { - ".".join(layer.split(".")[:3]) - for layer in vae_state_dict - if "decoder.up" in layer - } - ) - up_blocks = { - layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] - for layer_id in range(num_up_blocks) - } - - for i in range(num_down_blocks): - resnets = [ - key - for key in down_blocks[i] - if f"down.{i}" in key and f"down.{i}.downsample" not in key - ] - - if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[ - 
f"encoder.down_blocks.{i}.downsamplers.0.conv.weight" - ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight") - new_checkpoint[ - f"encoder.down_blocks.{i}.downsamplers.0.conv.bias" - ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias") - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=conf, - ) - - mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=conf, - ) - - mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=conf, - ) - conv_attn_to_linear(new_checkpoint) - - for i in range(num_up_blocks): - block_id = num_up_blocks - 1 - i - resnets = [ - key - for key in up_blocks[block_id] - if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key - ] + pipe.to(device, config.api.load_dtype) - if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[ - f"decoder.up_blocks.{i}.upsamplers.0.conv.weight" - ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"] - new_checkpoint[ - f"decoder.up_blocks.{i}.upsamplers.0.conv.bias" - ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=conf, - ) - - mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=conf, - ) - - mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=conf, - ) - conv_attn_to_linear(new_checkpoint) - return new_checkpoint + return pipe # type: ignore def convert_vaept_to_diffusers(path: str) -> AutoencoderKL: @@ -628,23 +475,36 @@ def convert_vaept_to_diffusers(path: str) -> AutoencoderKL: original_config = OmegaConf.load(io_obj) image_size = 512 - device = "cuda" if torch.cuda.is_available() else "cpu" if path.endswith("safetensors"): - from safetensors import safe_open + from safetensors.torch import load_file - checkpoint = {} - with safe_open(path, framework="pt", device="cpu") as f: - for key in f.keys(): - 
checkpoint[key] = f.get_tensor(key) + dev = str(config.api.load_device) + if "cuda" in dev: + dev = int(dev.split(":")[1]) + checkpoint = load_file(path, device=dev) # type: ignore else: - checkpoint = torch.load(path, map_location=device)["state_dict"] + checkpoint = torch.load( + path, + map_location=lambda storage, _: storage.to( + device=config.api.load_device, dtype=config.api.load_dtype + ), + )["state_dict"] # Convert the VAE model. vae_config = create_vae_diffusers_config(original_config, image_size=image_size) - converted_vae_checkpoint = _custom_convert_ldm_vae_checkpoint( - checkpoint, vae_config - ) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) vae = AutoencoderKL(**vae_config) vae.load_state_dict(converted_vae_checkpoint) return vae + + +def get_output_type(job: Job): + return ( + "latent" + if ( + "highres_fix" in job.flags + and HighResFixFlag(**job.flags["highres_fix"]).mode == "latent" + ) + else "pil" + ) diff --git a/core/inference/injectables/__init__.py b/core/inference/injectables/__init__.py index 08a09b188..3cb621699 100644 --- a/core/inference/injectables/__init__.py +++ b/core/inference/injectables/__init__.py @@ -1,5 +1,7 @@ from pathlib import Path from typing import Dict, List, Optional, Union +import re +import logging from diffusers.models.lora import LoRACompatibleConv import torch @@ -13,6 +15,9 @@ LoRACompatibleConv.old_forward = LoRACompatibleConv.forward # type: ignore +logger = logging.getLogger(__name__) + + def load_lora_utilities(pipe): "Reset/redirect Linear and Conv2ds forward to the lora processor" if hasattr(pipe, "lora_injector"): @@ -43,6 +48,7 @@ def __init__(self): super().__init__() self.managers: List[HookObject] = [] self.modules = {} + self.pipe = None self.device: torch.device = None # type: ignore self.dtype: torch.dtype = None # type: ignore @@ -59,54 +65,22 @@ def _get_target_modules( target_replace_modules: List[str], ): target_modules = [] + for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: for child_name, child_module in module.named_modules(): - is_linear = isinstance(child_module, torch.nn.Linear) - is_conv2d = isinstance(child_module, torch.nn.Conv2d) - if not (is_linear or is_conv2d): + if not any( + [ + x in child_module.__class__.__name__ + for x in ["Linear", "Conv"] + ] + ): continue - # retarded change, revert pliz diffusers - name = name.replace(".", "_") - name = name.replace("input_blocks", "down_blocks") - name = name.replace("middle_block", "mid_block") - name = name.replace("output_blocks", "out_blocks") - - name = name.replace("to_out_0_lora", "to_out_lora") - name = name.replace("emb_layers", "time_emb_proj") - - name = name.replace("q_proj_lora", "to_q_lora") - name = name.replace("k_proj_lora", "to_k_lora") - name = name.replace("v_proj_lora", "to_v_lora") - name = name.replace("out_proj_lora", "to_out_lora") - - # Prepare for SDXL - if "emb" in name: - import re - - pattern = r"\_\d+(?=\D*$)" - name = re.sub(pattern, "", name, count=1) - if "in_layers_2" in name: - name = name.replace("in_layers_2", "conv1") - if "out_layers_3" in name: - name = name.replace("out_layers_3", "conv2") - if "downsamplers" in name or "upsamplers" in name: - name = name.replace("op", "conv") - if "skip" in name: - name = name.replace("skip_connection", "conv_shortcut") - - if "transformer_blocks" in name: - if ( - "attn1" in name or "attn2" in name - ) and "processor" not in name: - name = name.replace("attn1", "attn1_processor") - name = 
name.replace("attn2", "attn2_processor") - elif "mlp" in name: - name = name.replace("_lora_", "_lora_linear_layer_") - lora_name = prefix + "." + name + "." + child_name + lora_name = prefix + "_" + name + "_" + child_name lora_name = lora_name.replace(".", "_") target_modules.append((lora_name, child_module)) + # print("---") return target_modules def _load_state_dict(self, file: Union[Path, str]) -> Dict[str, torch.nn.Module]: @@ -118,6 +92,15 @@ def _load_state_dict(self, file: Union[Path, str]) -> Dict[str, torch.nn.Module] state_dict = load_file(file) else: state_dict = torch.load(file) # .bin, .pt, .ckpt... + + if hasattr(self.pipe, "text_encoder_2"): + logger.debug("Mapping SGM") + unet_config = self.pipe.unet.config # type: ignore + state_dict = self._maybe_map_sgm_blocks_to_diffusers( + state_dict, unet_config + ) + state_dict = self._convert_kohya_lora_to_diffusers(state_dict) + return state_dict # type: ignore @torch.no_grad() @@ -148,6 +131,7 @@ def diffusers_lora_forward( def install_hooks(self, pipe): """Install LoRAHook to the pipe""" assert len(self.modules) == 0 + self.pipe = pipe text_encoder_targets = [] if hasattr(pipe, "text_encoder_2"): text_encoder_targets = ( @@ -156,7 +140,7 @@ def install_hooks(self, pipe): pipe.text_encoder_2, "lora_te2", ["CLIPAttention", "CLIPMLP"] ) + self._get_target_modules( - pipe.text_encoder, "lora_te1", ["CLIPAttention", "CLIPMLP"] + pipe.text_encoder, "lora_te", ["CLIPAttention", "CLIPMLP"] ) ) else: @@ -183,7 +167,7 @@ def install_hooks(self, pipe): self.change_forwards() self.device = config.api.device # type: ignore - self.dtype = pipe.unet.dtype + self.dtype = config.api.load_dtype # Temporary, TODO: replace this with something sensible def apply_lycoris( @@ -222,6 +206,292 @@ def apply_lora( lora.alpha = alpha if alpha else 1.0 self.managers[0].containers[file.name] = lora + def _convert_kohya_lora_to_diffusers(self, state_dict): + unet_state_dict = {} + te_state_dict = {} + te2_state_dict = {} + network_alphas = {} + + # every down weight has a corresponding up weight and potentially an alpha weight + lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")] + for key in lora_keys: + lora_name, lora_key = key.split(".", 1) + lora_name_up = lora_name + ".lora_up.weight" + lora_name_alpha = lora_name + ".alpha" + + if lora_name.startswith("lora_unet_"): + diffusers_name = lora_name.replace("lora_unet_", "").replace("_", ".") + + if "input.blocks" in diffusers_name: + diffusers_name = diffusers_name.replace( + "input.blocks", "down.blocks" + ) + + if "middle.block" in diffusers_name: + diffusers_name = diffusers_name.replace("middle.block", "mid.block") + if "output.blocks" in diffusers_name: + diffusers_name = diffusers_name.replace( + "output.blocks", "up.blocks" + ) + + diffusers_name = diffusers_name.replace("emb.layers", "time.emb.proj") + + # SDXL specificity. + if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name: + pattern = r"\.\d+(?=\D*$)" + diffusers_name = re.sub(pattern, "", diffusers_name, count=1) + if ".in." in diffusers_name: + diffusers_name = diffusers_name.replace("in.layers.2", "conv1") + if ".out." in diffusers_name: + diffusers_name = diffusers_name.replace("out.layers.3", "conv2") + if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name: + diffusers_name = diffusers_name.replace("op", "conv") + if "skip" in diffusers_name: + diffusers_name = diffusers_name.replace( + "skip.connection", "conv.shortcut" + ) + + # LyCORIS specificity. 
+ if "time.emb.proj" in diffusers_name: + diffusers_name = diffusers_name.replace( + "time.emb.proj", "time.emb.proj" + ) + + # General coverage. + if "transformer_blocks" in diffusers_name: + if "attn1" in diffusers_name or "attn2" in diffusers_name: + diffusers_name = diffusers_name.replace(".", "_") + unet_state_dict[ + diffusers_name + "." + lora_key + ] = state_dict.pop(key) + unet_state_dict[ + (diffusers_name + "." + lora_key).replace("_down.", "_up.") + ] = state_dict.pop(lora_name_up) + elif "ff" in diffusers_name: + diffusers_name = diffusers_name.replace(".", "_") + unet_state_dict[ + diffusers_name + "." + lora_key + ] = state_dict.pop(key) + unet_state_dict[ + (diffusers_name + "." + lora_key).replace("_down.", "_up.") + ] = state_dict.pop(lora_name_up) + elif any(key in diffusers_name for key in ("proj.in", "proj.out")): + diffusers_name = diffusers_name.replace(".", "_") + unet_state_dict[diffusers_name + "." + lora_key] = state_dict.pop( + key + ) + unet_state_dict[ + (diffusers_name + "." + lora_key).replace("_down.", "_up.") + ] = state_dict.pop(lora_name_up) + else: + diffusers_name = diffusers_name.replace(".", "_") + unet_state_dict[diffusers_name + "." + lora_key] = state_dict.pop( + key + ) + unet_state_dict[ + (diffusers_name + "." + lora_key).replace("_down.", "_up.") + ] = state_dict.pop(lora_name_up) + + elif lora_name.startswith("lora_te2_"): + diffusers_name = key.replace("lora_te2_", "") + if "self_attn" in diffusers_name: + te2_state_dict[diffusers_name + "." + lora_key] = state_dict.pop( + key + ) + te2_state_dict[ + (diffusers_name + "." + lora_key).replace("_down.", "_up.") + ] = state_dict.pop(lora_name_up) + elif "mlp" in diffusers_name: + # Be aware that this is the new diffusers convention and the rest of the code might + # not utilize it yet. + diffusers_name = diffusers_name.replace( + ".lora.", ".lora_linear_layer." + ) + diffusers_name = diffusers_name.replace(".", "_") + te2_state_dict[diffusers_name + "." + lora_key] = state_dict.pop( + key + ) + te2_state_dict[ + (diffusers_name + "." + lora_key).replace("_down.", "_up.") + ] = state_dict.pop(lora_name_up) + elif lora_name.startswith("lora_te"): + diffusers_name = "_".join(key.split("_")[2:]) + if "self_attn" in diffusers_name: + te_state_dict[diffusers_name + "." + lora_key] = state_dict.pop(key) + te_state_dict[ + (diffusers_name + "." + lora_key).replace("_down.", "_up.") + ] = state_dict.pop(lora_name_up) + elif "mlp" in diffusers_name: + # Be aware that this is the new diffusers convention and the rest of the code might + # not utilize it yet. + diffusers_name = diffusers_name.replace( + ".lora.", ".lora_linear_layer." + ) + diffusers_name = diffusers_name.replace(".", "_") + te_state_dict[diffusers_name + "." + lora_key] = state_dict.pop(key) + te_state_dict[ + (diffusers_name + "." + lora_key).replace("_down.", "_up.") + ] = state_dict.pop(lora_name_up) + # Rename the alphas so that they can be mapped appropriately. 
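# --- Editorial sketch, not part of the patch -----------------------------------------
# The conversion above rewrites kohya/SGM-style LoRA key prefixes into diffusers-style
# module paths before splitting out the up/down pairs and alphas. A toy version of just
# the UNet segment renaming (the real method also covers both text encoders, the alpha
# renames below, and the SDXL-specific edge cases):
def kohya_unet_key_to_diffusers(lora_name: str) -> str:
    name = lora_name.replace("lora_unet_", "").replace("_", ".")
    for old, new in (
        ("input.blocks", "down.blocks"),
        ("middle.block", "mid.block"),
        ("output.blocks", "up.blocks"),
        ("emb.layers", "time.emb.proj"),
    ):
        name = name.replace(old, new)
    return name.replace(".", "_")

# e.g. kohya_unet_key_to_diffusers("lora_unet_input_blocks_1_0_emb_layers_1")
#      -> "down_blocks_1_0_time_emb_proj_1"
# --------------------------------------------------------------------------------------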
+ if lora_name_alpha in state_dict: + alpha = state_dict.pop(lora_name_alpha).item() + if lora_name_alpha.startswith("lora_unet_"): + prefix = "lora_unet_" + elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")): + prefix = "lora_te_" + else: + prefix = "lora_te2_" + new_name = prefix + diffusers_name.split("_lora")[0] + ".alpha" # type: ignore + network_alphas.update({new_name: alpha}) + + unet_state_dict = { + f"lora_unet_{module_name}": params + for module_name, params in unet_state_dict.items() + } + te_state_dict = { + f"lora_te_{module_name}": params + for module_name, params in te_state_dict.items() + } + te2_state_dict = ( + { + f"lora_te2_{module_name}": params + for module_name, params in te2_state_dict.items() + } + if len(te2_state_dict) > 0 + else None + ) + if te2_state_dict is not None: + te_state_dict.update(te2_state_dict) + + new_state_dict = {**unet_state_dict, **te_state_dict, **network_alphas} + return new_state_dict + + def _maybe_map_sgm_blocks_to_diffusers( + self, state_dict, unet_config, delimiter="_", block_slice_pos=5 + ): + # 1. get all state_dict_keys + all_keys = list(state_dict.keys()) + sgm_patterns = ["input_blocks", "middle_block", "output_blocks"] + + # 2. check if needs remapping, if not return original dict + is_in_sgm_format = False + for key in all_keys: + if any(p in key for p in sgm_patterns): + is_in_sgm_format = True + break + + if not is_in_sgm_format: + return state_dict + + # 3. Else remap from SGM patterns + new_state_dict = {} + inner_block_map = ["resnets", "attentions", "upsamplers"] + + # Retrieves # of down, mid and up blocks + input_block_ids, middle_block_ids, output_block_ids = set(), set(), set() + + for layer in all_keys: + if "text" in layer: + new_state_dict[layer] = state_dict.pop(layer) + else: + layer_id = int(layer.split(delimiter)[:block_slice_pos][-1]) + if sgm_patterns[0] in layer: + input_block_ids.add(layer_id) + elif sgm_patterns[1] in layer: + middle_block_ids.add(layer_id) + elif sgm_patterns[2] in layer: + output_block_ids.add(layer_id) + else: + raise ValueError( + f"Checkpoint not supported because layer {layer} not supported." 
+ ) + + input_blocks = { + layer_id: [ + key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key + ] + for layer_id in input_block_ids + } + middle_blocks = { + layer_id: [ + key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key + ] + for layer_id in middle_block_ids + } + output_blocks = { + layer_id: [ + key + for key in state_dict + if f"output_blocks{delimiter}{layer_id}" in key + ] + for layer_id in output_block_ids + } + + # Rename keys accordingly + for i in input_block_ids: + block_id = (i - 1) // (unet_config.layers_per_block + 1) + layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1) + + for key in input_blocks[i]: + inner_block_id = int(key.split(delimiter)[block_slice_pos]) + inner_block_key = ( + inner_block_map[inner_block_id] + if "op" not in key + else "downsamplers" + ) + inner_layers_in_block = ( + str(layer_in_block_id) if "op" not in key else "0" + ) + new_key = delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + + [str(block_id), inner_block_key, inner_layers_in_block] + + key.split(delimiter)[block_slice_pos + 1 :] + ) + new_state_dict[new_key] = state_dict.pop(key) + + for i in middle_block_ids: + key_part = None + if i == 0: + key_part = [inner_block_map[0], "0"] + elif i == 1: + key_part = [inner_block_map[1], "0"] + elif i == 2: + key_part = [inner_block_map[0], "1"] + else: + raise ValueError(f"Invalid middle block id {i}.") + + for key in middle_blocks[i]: + new_key = delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + + key_part + + key.split(delimiter)[block_slice_pos:] + ) + new_state_dict[new_key] = state_dict.pop(key) + + for i in output_block_ids: + block_id = i // (unet_config.layers_per_block + 1) + layer_in_block_id = i % (unet_config.layers_per_block + 1) + + for key in output_blocks[i]: + inner_block_id = int(key.split(delimiter)[block_slice_pos]) + inner_block_key = inner_block_map[inner_block_id] + inner_layers_in_block = ( + str(layer_in_block_id) if inner_block_id < 2 else "0" + ) + new_key = delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + + [str(block_id), inner_block_key, inner_layers_in_block] + + key.split(delimiter)[block_slice_pos + 1 :] + ) + new_state_dict[new_key] = state_dict.pop(key) + + if len(state_dict) > 0: + raise ValueError( + "At this point all state dict entries have to be converted." + ) + + return new_state_dict + def remove_lora(self, file: Union[Path, str]): """Remove the individual LoRA from the pipe.""" if not isinstance(file, Path): diff --git a/core/inference/injectables/lora.py b/core/inference/injectables/lora.py index c88593cc7..59cf2fda8 100644 --- a/core/inference/injectables/lora.py +++ b/core/inference/injectables/lora.py @@ -1,4 +1,5 @@ from typing import Dict, Union +import logging import torch @@ -6,6 +7,8 @@ from .utils import HookObject +logger = logging.getLogger(__name__) + class LoRAModule(object): "Main module per LoRA object." @@ -22,7 +25,7 @@ class LoRAUpDown(object): def __init__(self) -> None: self.down: Union[torch.nn.Conv2d, torch.nn.Linear] = None # type: ignore self.up: Union[torch.nn.Conv2d, torch.nn.Linear] = None # type: ignore - self.alpha: float = 0.5 + self.alpha: float = 1.0 # why was it 0.5??? class LoRAManager(HookObject): @@ -49,11 +52,16 @@ def load( ) -> LoRAModule: lora = LoRAModule(name) + missing = 0 + for k, v in state_dict.items(): key, lora_key = k.split(".", 1) module = modules.get(key, None) + if module is None: - print(key, lora_key) + # Big problem! Something broke, and that's BADDDDD!!! 
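# --- Editorial sketch, not part of the patch -----------------------------------------
# The SGM -> diffusers remapping above turns a flat input_blocks index into a
# (down_block, layer) pair using the UNet's layers_per_block, exactly as in the loop
# over input_block_ids. With the common layers_per_block = 2, indices 1-3 map to down
# block 0, 4-6 to block 1, and so on; the formula assumes i >= 1.
def sgm_input_block_to_diffusers(i: int, layers_per_block: int = 2) -> tuple:
    block_id = (i - 1) // (layers_per_block + 1)
    layer_in_block = (i - 1) % (layers_per_block + 1)
    return block_id, layer_in_block

assert sgm_input_block_to_diffusers(4) == (1, 0)
# --------------------------------------------------------------------------------------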
+ logger.debug(f"Couldn't find {key}.{lora_key}") + missing += 1 continue lora_module = lora.modules.get(key, None) if lora_module is None: @@ -61,21 +69,30 @@ def load( lora.modules[key] = lora_module # type: ignore if lora_key == "alpha": - lora_module.alpha = v.item() # type: ignore + lora_module.alpha = v # type: ignore + continue + if isinstance(v, float): + # Probably loaded wrong, or lora is broken: just ignore, it's gonna be fine... + logger.debug( + f"{key}.{lora_key} has for whatever reason a float here, when it shouldn't..." + ) continue if isinstance(module, torch.nn.Linear): module = torch.nn.Linear(v.shape[1], v.shape[0], bias=False) # type: ignore else: - module = torch.nn.Conv2d(v.shape[1], v.shape[0], (1, 1), bias=False) # type: ignore + module = torch.nn.Conv2d(v.shape[1], v.shape[0], v.shape[2], v.shape[3], bias=False) # type: ignore with torch.no_grad(): module.weight.copy_(v, True) # type: ignore - module.to(device=torch.device("cpu"), dtype=config.api.dtype) + module.to(device=torch.device("cpu"), dtype=config.api.load_dtype) if lora_key == "lora_up.weight": lora_module.up = module else: lora_module.down = module - # print(*lora.modules.keys(), sep="\n") + if missing != 0: + logger.error( + f"Uh oh! Something went wrong loading the lora. If the output looks completely whack, contact us on discord! Missing keys: {missing}." + ) return lora def apply_hooks(self, p: Union[torch.nn.Conv2d, torch.nn.Linear]) -> None: @@ -92,13 +109,20 @@ def apply_hooks(self, p: Union[torch.nn.Conv2d, torch.nn.Linear]) -> None: p.weight.copy_(weights_backup) layer_name = getattr(p, "layer_name", None) + skipped = False for _, lora in self.containers.items(): module: LoRAModule = lora.modules.get(layer_name, None) # type: ignore if module is None: continue + if module.up is None: # type: ignore + skipped = True + continue with torch.no_grad(): - up = module.up.weight.to(p.weight.device, dtype=p.weight.dtype) # type: ignore - down = module.down.weight.to(p.weight.device, dtype=p.weight.dtype) # type: ignore + weight = p.weight.clone() + if hasattr(p, "fp16_weight"): + weight = p.fp16_weight.clone() # type: ignore + up = module.up.weight.to(p.weight.device, dtype=weight.dtype) # type: ignore + down = module.down.weight.to(p.weight.device, dtype=weight.dtype) # type: ignore if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): updown = ( @@ -108,13 +132,20 @@ def apply_hooks(self, p: Union[torch.nn.Conv2d, torch.nn.Linear]) -> None: ) else: updown = up @ down - p.weight += ( + + if len(weight.shape) == 4 and weight.shape[1] == 9: + # inpainting model. 
zero pad updown to make channel[1] 4 -> 9 + updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) + weight += ( updown * lora.alpha * ( - module.alpha / module.up.weight.shape[1] # type: ignore + module.alpha / module.down.weight.shape[0] # type: ignore if module.alpha else 1.0 ) ) + p.weight.copy_(weight.to(p.weight.dtype)) + if skipped: + logger.warn(f"Broken weight on {getattr(p, 'layer_name', 'UNKNOWN')}.") setattr(p, "lora_current_names", wanted_names) diff --git a/core/inference/injectables/lycoris.py b/core/inference/injectables/lycoris.py index 128bb1826..88bc99fe6 100644 --- a/core/inference/injectables/lycoris.py +++ b/core/inference/injectables/lycoris.py @@ -177,7 +177,7 @@ def load( lyco_module.scale = v.item() # type: ignore continue if lyco_key == "diff": - v = v.to(device=torch.device("cpu"), dtype=config.api.dtype) + v = v.to(device=torch.device("cpu"), dtype=config.api.load_dtype) v.requires_grad_(False) lyco_module = FullModule() # type: ignore lyco_module.weight = v # type: ignore @@ -199,7 +199,7 @@ def load( lyco_module.bias[0], # type: ignore lyco_module.bias[1], # type: ignore tuple(lyco_module.bias[2]), # type: ignore - ).to(device=torch.device("cpu"), dtype=config.api.dtype) + ).to(device=torch.device("cpu"), dtype=config.api.load_dtype) lyco_module.bias.requires_grad_(False) continue if lyco_key in CON_KEY: @@ -251,7 +251,7 @@ def load( v = v.reshape(module.weight.shape) # type: ignore module.weight.copy_(v) # type: ignore - module.to(device=torch.device("cpu"), dtype=config.api.dtype) # type: ignore + module.to(device=torch.device("cpu"), dtype=config.api.load_dtype) # type: ignore module.requires_grad_(False) # type: ignore if lyco_key == "lora_up.weight" or lyco_key == "dyn_up": @@ -274,7 +274,7 @@ def load( if hasattr(sd_module, "weight"): lyco_module.shape = sd_module.weight.shape # type: ignore - v = v.to(device=torch.device("cpu"), dtype=config.api.dtype) + v = v.to(device=torch.device("cpu"), dtype=config.api.load_dtype) v.requires_grad_(False) if lyco_key == "hada_w1_a": @@ -298,7 +298,7 @@ def load( lyco.modules[key] = lyco_module if lyco_key == "weight": - lyco_module.w = v.to(torch.device("cpu"), dtype=config.api.dtype) # type: ignore + lyco_module.w = v.to(torch.device("cpu"), dtype=config.api.load_dtype) # type: ignore elif lyco_key == "on_input": lyco_module.on_input = v # type: ignore elif lyco_key in KRON_KEY: @@ -312,7 +312,7 @@ def load( if hasattr(sd_module, "weight"): lyco_module.shape = sd_module.weight.shape # type: ignore - v = v.to(device=torch.device("cpu"), dtype=config.api.dtype) + v = v.to(device=torch.device("cpu"), dtype=config.api.load_dtype) v.requires_grad_(False) if lyco_key == "lokr_w1": diff --git a/core/inference/injectables/textual_inversion.py b/core/inference/injectables/textual_inversion.py new file mode 100644 index 000000000..e6976a3ca --- /dev/null +++ b/core/inference/injectables/textual_inversion.py @@ -0,0 +1,92 @@ +from copy import deepcopy + +from diffusers.loaders import ( + TextualInversionLoaderMixin, + load_textual_inversion_state_dicts, +) +import torch +from transformers import PreTrainedTokenizer, PreTrainedModel + + +def maybe_convert_prompt(prompt: str, tokenizer: PreTrainedTokenizer) -> str: + tokens = tokenizer.tokenize(prompt) + unique_tokens = set(tokens) + for token in unique_tokens: + if token in tokenizer.added_tokens_encoder: + replacement = token + i = 1 + while f"{token}_{i}" in tokenizer.added_tokens_encoder: + replacement += f" {token}_{i}" + i += 1 + prompt = prompt.replace(token, 
replacement) + return prompt + + +def unload(token: str, tokenizer: PreTrainedTokenizer, text_encoder: PreTrainedModel): + load_map = text_encoder.change_map if hasattr(text_encoder, "change_map") else [] + input_embedding: torch.Tensor = text_encoder.get_input_embeddings().weight + device, dtype = text_encoder.device, text_encoder.dtype + + if token in load_map: + token_id: int = tokenizer.convert_tokens_to_ids(token) # type: ignore + tokenizer.added_tokens_encoder.pop(token) + input_embedding.data = torch.cat( + (input_embedding.data[:token_id], input_embedding.data[token_id + 1 :]) + ) + text_encoder.resize_token_embeddings(len(tokenizer)) + load_map.remove(token) + + input_embedding.to(device, dtype) + setattr(text_encoder, "change_map", load_map) + + +def unload_all(tokenizer: PreTrainedTokenizer, text_encoder: PreTrainedModel): + load_map = text_encoder.change_map if hasattr(text_encoder, "change_map") else [] + input_embedding: torch.Tensor = text_encoder.get_input_embeddings().weight + device, dtype = text_encoder.device, text_encoder.dtype + + for token in deepcopy(load_map): + token_id: int = tokenizer.convert_tokens_to_ids(token) # type: ignore + tokenizer.added_tokens_encoder.pop(token) + input_embedding.data = torch.cat( + (input_embedding.data[:token_id], input_embedding.data[token_id + 1 :]) + ) + text_encoder.resize_token_embeddings(len(tokenizer)) + load_map.remove(token) + + input_embedding.to(device, dtype) + setattr(text_encoder, "change_map", load_map) + + +def load( + model: str, + token: str, + tokenizer: PreTrainedTokenizer, + text_encoder: PreTrainedModel, +): + state_dicts = load_textual_inversion_state_dicts(model) + + token, embeddings = TextualInversionLoaderMixin._retrieve_tokens_and_embeddings( + [token], state_dicts, tokenizer # type: ignore + ) + tokens, embeddings = TextualInversionLoaderMixin._retrieve_tokens_and_embeddings( + token, embeddings, tokenizer + ) + + device, dtype = text_encoder.device, text_encoder.dtype + + load_map = text_encoder.change_map if hasattr(text_encoder, "change_map") else [] + input_embedding: torch.Tensor = text_encoder.get_input_embeddings().weight + + def load(token, embedding): + tokenizer.add_tokens(token) + token_id = tokenizer.convert_tokens_to_ids(token) + input_embedding.data[token_id] = embedding + text_encoder.resize_token_embeddings(len(tokenizer)) + load_map.append(token) + + for _token, embedding in zip(tokens, embeddings): + load(_token, embedding) + + input_embedding.to(device, dtype) + setattr(text_encoder, "change_map", load_map) diff --git a/core/inference/onnx/pipeline.py b/core/inference/onnx/pipeline.py index 8ee22dbef..93d74fd1c 100644 --- a/core/inference/onnx/pipeline.py +++ b/core/inference/onnx/pipeline.py @@ -9,14 +9,14 @@ from dataclasses import fields from pathlib import Path from time import time -from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union import numpy as np import torch from accelerate import init_empty_weights, load_checkpoint_and_dispatch from accelerate.utils import set_module_tensor_to_device from diffusers.models.attention_processor import AttnProcessor -from diffusers.models.autoencoder_kl import AutoencoderKL, AutoencoderKLOutput +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from diffusers.models.unet_2d_condition import UNet2DConditionModel from diffusers.models.vae import DecoderOutput from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE @@ -105,8 
+105,8 @@ class AutoencoderKLWrapper(AutoencoderKL): def encode(self, x) -> Tuple: # pylint: disable=arguments-differ x = x.to(self.device, dtype=self.dtype) - outputs: AutoencoderKLOutput = AutoencoderKL.encode(self, x, True) # type: ignore - return (outputs.latent_dist.sample().to(self.device, dtype=self.dtype),) + outputs = AutoencoderKL.encode(self, x, True) # type: ignore + return (outputs.latent_dist.sample().to(self.device, dtype=self.dtype),) # type: ignore def decode(self, z) -> Tuple: # pylint: disable=arguments-differ z = z.to(self.device, dtype=self.dtype) @@ -609,7 +609,7 @@ def convert_text_encoder( main_folder / "text_encoder" ) assert isinstance(text_encoder, CLIPTextModelWrapper) - text_encoder.to(dtype=dtype, device=device) + text_encoder.to(dtype=dtype, device=device) # type: ignore text_encoder.eval() num_tokens = text_encoder.config.max_position_embeddings @@ -715,10 +715,10 @@ def convert_unet( del unet self.memory_cleanup() if needs_collate: - unet = onnx.load( # type: ignore pylint: disable=undefined-variable + unet = onnx.load( # type: ignore # noqa: F821 str((unet_out_path / "unet.onnx").absolute().as_posix()) ) - onnx.save_model( # type: ignore pylint: disable=undefined-variable + onnx.save_model( # type: ignore # noqa: F821 unet, str((output_folder / "unet.onnx").absolute().as_posix()), save_as_external_data=True, @@ -767,7 +767,7 @@ def fn_recursive_attn_processor( logger.info("Compiling cross-attention into model") set_attn_processor(vae, AttnProcessor()) - vae.forward = lambda sample: vae.encode(sample)[0] # type: ignore + vae.forward = lambda sample: vae.encode(sample)[0] # type: ignore # noqa: F821 onnx_export( vae, model_args=( @@ -1031,7 +1031,9 @@ def _init_latent_model(latents, do_classifier_free_guidance, t): logger.debug("timestep start") rt = time() - def do_inference(x: torch.Tensor, t: torch.Tensor, call) -> torch.Tensor: + def do_inference( + x, t, call: Callable, change_source: Callable[[Callable], None] + ) -> torch.Tensor: if kw is not None: latent_model_input = kw(x.numpy(), do_classifier_free_guidance, t) else: @@ -1086,6 +1088,11 @@ def do_inference(x: torch.Tensor, t: torch.Tensor, call) -> torch.Tensor: 1, ).numpy() else: + s = self.unet + + def change(src): + nonlocal s + s = src def _call(*args, **kwargs): if len(args) == 3: @@ -1101,7 +1108,7 @@ def _call(*args, **kwargs): )[0] for i, t in enumerate(tqdm(timesteps)): - latents = do_inference(latents, t, _call) + latents = do_inference(latents, t, _call, change) logger.debug("timestep end (%.2fs)", time() - rt) return latents diff --git a/core/inference/pytorch/pipeline.py b/core/inference/pytorch/pipeline.py index 134518dcb..704fd0eac 100644 --- a/core/inference/pytorch/pipeline.py +++ b/core/inference/pytorch/pipeline.py @@ -2,10 +2,13 @@ from contextlib import ExitStack from typing import Any, Callable, Dict, List, Literal, Optional, Union +import inspect +import math import PIL import torch -from diffusers.models.autoencoder_kl import AutoencoderKL +from diffusers.models.adapter import MultiAdapter +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from diffusers.models.controlnet import ControlNetModel from diffusers.models.unet_2d_condition import UNet2DConditionModel from diffusers.pipelines.stable_diffusion.pipeline_output import ( @@ -15,30 +18,50 @@ StableDiffusionPipeline, ) from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler +from diffusers.schedulers.scheduling_ddim import DDIMScheduler from diffusers.schedulers.scheduling_utils 
import SchedulerMixin from diffusers.utils import logging +from PIL import Image from tqdm import tqdm from transformers.models.clip import CLIPTextModel, CLIPTokenizer from core.config import config +from core.flags import DeepshrinkFlag, ScalecrafterFlag from core.inference.utilities import ( + ScalecrafterSettings, + calculate_cfg, full_vae, + get_scalecrafter_config, get_timesteps, get_weighted_text_embeddings, + modify_kohya, numpy_to_pil, pad_tensor, + post_scalecrafter, + postprocess_kohya, prepare_extra_step_kwargs, prepare_image, prepare_latents, prepare_mask_and_masked_image, prepare_mask_latents, + preprocess_adapter_image, preprocess_image, + setup_scalecrafter, + step_scalecrafter, +) +from core.inference.utilities.animatediff import ( + get_context_scheduler, + nil_scheduler, + freeinit_filter, + freeinit_mix, + prepare_mask_coef_by_statistics, ) from core.inference.utilities.philox import PhiloxGenerator -from core.optimizations import inference_context +from core.flags import AnimateDiffFlag +from core.optimizations import ensure_correct_device, inference_context, unload_all from core.scheduling import KdiffusionSchedulerAdapter -from .sag import CrossAttnStoreProcessor, pred_epsilon, pred_x0, sag_masking +from ..utilities.sag import CrossAttnStoreProcessor, calculate_sag # ------------------------------------------------------------------------------ @@ -85,6 +108,7 @@ def __init__( feature_extractor: Any = None, requires_safety_checker: bool = False, controlnet: Optional[ControlNetModel] = None, + image_encoder: Any = None, ): super().__init__( vae=vae, @@ -122,19 +146,43 @@ def _execution_device(self): `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): # type: ignore - return self.device - for module in self.unet.modules(): # type: ignore - if ( - hasattr(module, "_hf_hook") - and hasattr( - module._hf_hook, - "execution_device", - ) - and module._hf_hook.execution_device is not None # type: ignore - ): - return torch.device(module._hf_hook.execution_device) # type: ignore - return self.device + return torch.device(config.api.device) + + def _default_height_width(self, height, width, image): + if image is None: + return height, width + + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. 
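
(For orientation, a minimal hypothetical sketch of the round-down-to-multiple idiom used by _default_height_width just below; the factor of 8 is only an assumed example value for adapter.downscale_factor, not something stated in this patch.)

    def round_down_to_multiple(value: int, factor: int = 8) -> int:
        # Drop the remainder so the result is a whole multiple of `factor`,
        # keeping the image size aligned with the adapter's feature-map grid.
        return (value // factor) * factor

    assert round_down_to_multiple(517) == 512
    assert round_down_to_multiple(512) == 512
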
+ while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[-2] + + # round down to nearest multiple of `self.adapter.downscale_factor` + if hasattr(self, "adapter") and self.adapter is not None: + height = ( + height // self.adapter.downscale_factor + ) * self.adapter.downscale_factor + + if width is None: + if isinstance(image, Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[-1] + + # round down to nearest multiple of `self.adapter.downscale_factor` + if hasattr(self, "adapter") and self.adapter is not None: + width = ( + width // self.adapter.downscale_factor + ) * self.adapter.downscale_factor + + return height, width def _encode_prompt( self, @@ -167,6 +215,8 @@ def _encode_prompt( """ batch_size = len(prompt) if isinstance(prompt, list) else 1 + ensure_correct_device(self.text_encoder) + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) logger.debug(f"Post textual prompt: {prompt}") @@ -184,7 +234,7 @@ def _encode_prompt( " the batch size of `prompt`." ) - text_embeddings, uncond_embeddings = get_weighted_text_embeddings( + text_embeddings, _, uncond_embeddings, _ = get_weighted_text_embeddings( pipe=self.parent, prompt=prompt, uncond_prompt=negative_prompt if do_classifier_free_guidance else None, @@ -228,15 +278,6 @@ def _check_inputs(self, prompt, strength, callback_steps): f" {type(callback_steps)}." ) - def _decode_latents(self, latents, height, width): - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample # type: ignore - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - img = image[:, :height, :width, :] - return img - @torch.no_grad() def __call__( self, @@ -264,6 +305,11 @@ def __call__( callback_steps: int = 1, seed: int = 0, prompt_expansion_settings: Optional[Dict] = None, + adapter_conditioning_scale: Union[float, List[float]] = 1.0, + adapter_conditioning_factor: float = 1.0, + animatediff: Optional[AnimateDiffFlag] = None, + deepshrink: Optional[DeepshrinkFlag] = None, + scalecrafter: Optional[ScalecrafterFlag] = None, # type: ignore ): r""" Function invoked when calling the pipeline for generation. @@ -339,38 +385,52 @@ def __call__( list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ - - with inference_context(self.unet, self.vae, height, width) as inf: + with inference_context( + self.unet, self.vae, height, width, [animatediff] + ) as inf: # 0. 
Modify unet and vae to the (optionally) modified versions from inf self.unet = inf.unet # type: ignore self.vae = inf.vae # type: ignore + if scalecrafter is not None: + unsafe = scalecrafter.unsafe_resolutions # type: ignore + scalecrafter: ScalecrafterSettings = get_scalecrafter_config("sd15", height, width, scalecrafter.disperse) # type: ignore + logger.info( + f'Applying ScaleCrafter with (base="{scalecrafter.base}", res="{scalecrafter.height * 8}x{scalecrafter.width * 8}", dis="{scalecrafter.disperse is not None}")' + ) + if not unsafe and ( + (scalecrafter.height * 8) != height + or (scalecrafter.width * 8) != width + ): + height, width = scalecrafter.height * 8, scalecrafter.width * 8 + + height, width = self._default_height_width(height, width, image) + # 1. Check inputs. Raise error if not correct self._check_inputs(prompt, strength, callback_steps) if hasattr(self, "controlnet"): global_pool_conditions = self.controlnet.config.global_pool_conditions # type: ignore guess_mode = guess_mode or global_pool_conditions - num_channels_unet = self.unet.config.in_channels # type: ignore + num_channels_unet = self.unet.config["in_channels"] # 2. Define call parameters batch_size = 1 if isinstance(prompt, str) else len(prompt) device = self._execution_device + latents_device = torch.device("cpu") if animatediff is not None else device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 split_latents_into_two = ( - config.api.dont_merge_latents and do_classifier_free_guidance - ) - do_self_attention_guidance = self_attention_scale > 0.0 and not isinstance( - self.scheduler, KdiffusionSchedulerAdapter + not config.api.batch_cond_uncond and do_classifier_free_guidance ) + do_self_attention_guidance = self_attention_scale > 0.0 # 3. Encode input prompt text_embeddings = self._encode_prompt( prompt, - self.unet.dtype, + config.api.load_dtype, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, @@ -380,46 +440,50 @@ def __call__( ).to(device) dtype = text_embeddings.dtype + adapter_input = None # type: ignore + if hasattr(self, "adapter"): + if isinstance(self.adapter, MultiAdapter): + adapter_input: list = [] # type: ignore + + if not isinstance(adapter_conditioning_scale, list): + adapter_conditioning_scale = [ + adapter_conditioning_scale * len(image) + ] + + for oi in image: + oi = preprocess_adapter_image(oi, height, width) + oi = oi.to(device, dtype) # type: ignore + adapter_input.append(oi) # type: ignore + else: + adapter_input: torch.Tensor = preprocess_adapter_image( # type: ignore + adapter_input, height, width + ) + adapter_input.to(device, dtype) + # 4. 
Preprocess image and mask if isinstance(image, PIL.Image.Image): # type: ignore - width, height = image.size # type: ignore - if not hasattr(self, "controlnet"): - image = preprocess_image(image) + if animatediff is not None and animatediff.use_pia: + mask_image = image + image = None else: - image = prepare_image( - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - ) + width, height = image.size # type: ignore + if not hasattr(self, "controlnet"): + image = preprocess_image(image) + else: + image = prepare_image( + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + ) if image is not None: image = image.to(device=self.device, dtype=dtype) - if mask_image is not None: - mask, masked_image, _ = prepare_mask_and_masked_image( - image, mask_image, height, width - ) - mask, masked_image_latents = prepare_mask_latents( - mask, - masked_image, - batch_size * num_images_per_prompt, # type: ignore - height, - width, - dtype, - device, - do_classifier_free_guidance, - self.vae, - self.vae_scale_factor, - self.vae.config.scaling_factor, # type: ignore - generator=generator, - ) - else: - mask = None # 5. set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) # type: ignore + self.scheduler.set_timesteps(num_inference_steps, device=latents_device) # type: ignore timesteps, num_inference_steps = get_timesteps( self.scheduler, num_inference_steps, @@ -441,14 +505,86 @@ def __call__( height, width, dtype, - device, + latents_device, generator, - latents, - latent_channels=None if mask is None else self.vae.config.latent_channels, # type: ignore + latents=latents, + latent_channels=None, + frames=None if animatediff is None else animatediff.frames, ) - # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = prepare_extra_step_kwargs(self.scheduler, eta, generator) # type: ignore + if mask_image is not None: + if animatediff is not None and animatediff.use_pia: + assert latents is not None + # fmt: off + mask_image = preprocess_image(mask_image) + mask_image = mask_image.to(device=self.vae.device, dtype=self.vae.dtype) + image_latent = self.vae.encode(mask_image).latent_dist.sample(generator) + image_latent = image_latent.to(device="cpu", dtype=torch.float32) + image_latent = torch.nn.functional.interpolate(image_latent, size=(latents.shape[-2], latents.shape[-1])) + image_latent = image_latent * 0.18215 + image_latent = image_latent.to(device=latents_device, dtype=latents.dtype) + mask = torch.zeros(latents.shape[0], 1, latents.shape[2], latents.shape[3], latents.shape[4]) \ + .to(device=latents_device, dtype=latents.dtype) + + sim_range = animatediff.pia_motion + if animatediff.pia_motion_type == "closed_loop" or animatediff.closed_loop: + sim_range += 3 + elif animatediff.pia_motion_type == "style_transfer": + sim_range = -1 * sim_range - 1 + + mask_coef = prepare_mask_coef_by_statistics(animatediff.frames, animatediff.pia_cond_frame, sim_range) + + masked_image = torch.zeros(latents.shape[0], 4, latents.shape[2], latents.shape[3], latents.shape[4]) \ + .to(device=latents_device, dtype=latents.dtype) + for f in range(animatediff.frames): + mask[:, :, f, :, :] = mask_coef[f] + masked_image[:, :, f, :, :] = image_latent.clone() + # fmt: on + else: + mask, masked_image, _ = prepare_mask_and_masked_image( + image, mask_image, height, width + ) + mask, masked_image_latents = prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, # type: ignore + height, + width, + dtype, + device, + do_classifier_free_guidance, + self.vae, + self.vae_scale_factor, + self.vae.config.scaling_factor, # type: ignore + generator=generator, + ) + else: + mask = None + + assert latents is not None + + # 7. Prepare extra step kwargs. 
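
(As a rough sketch of how the mask tensors prepared above are consumed: a standard SD1.5-style inpainting UNet expects 9 input channels, i.e. the 4 noisy latent channels, a 1-channel mask, and the 4 VAE-encoded masked-image channels concatenated together. The 64x64 latent size below assumes a 512x512 image and is illustrative only.)

    import torch

    latents = torch.randn(1, 4, 64, 64)               # noisy latents
    mask = torch.zeros(1, 1, 64, 64)                  # inpainting mask, resized to latent resolution
    masked_image_latents = torch.randn(1, 4, 64, 64)  # VAE-encoded masked image

    # 4 + 1 + 4 = 9 channels, matching the `num_channels_unet == 9` branch in the denoising loop
    unet_input = torch.cat([latents, mask, masked_image_latents], dim=1)
    assert unet_input.shape == (1, 9, 64, 64)
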
+ extra_step_kwargs = prepare_extra_step_kwargs(self.scheduler, eta, generator, latents_device) # type: ignore + + setup_scalecrafter(self.unet, scalecrafter) # type: ignore + + if hasattr(self, "adapter"): + if isinstance(self.adapter, MultiAdapter): + adapter_state = self.adapter( + adapter_input, adapter_conditioning_scale + ) + for k, v in enumerate(adapter_state): + adapter_state[k] = v + else: + adapter_state = self.adapter(adapter_input) + for k, v in enumerate(adapter_state): + adapter_state[k] = v * adapter_conditioning_scale + if num_images_per_prompt > 1: # type: ignore + for k, v in enumerate(adapter_state): + adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1) + if do_classifier_free_guidance: + for k, v in enumerate(adapter_state): + adapter_state[k] = torch.cat([v] * 2, dim=0) controlnet_keep = [] if hasattr(self, "controlnet"): @@ -472,131 +608,346 @@ def get_map_size(_, __, output): -2: ] # output.sample.shape[-2:] in older diffusers + cutoff = num_inference_steps * adapter_conditioning_factor + + j = 0 + idx = False + + context_scheduler = ( + get_context_scheduler(animatediff.context_scheduler) + if animatediff is not None + else nil_scheduler + ) + context_args = [] + iteration_count = 1 + freq_filter = None + f_scheduler: DDIMScheduler = None # type: ignore + f_fast_sampling = False + if animatediff is not None: + context_args = [ + animatediff.frames, + animatediff.context_size, + animatediff.frame_stride, + animatediff.frame_overlap, + animatediff.closed_loop, + ] + if animatediff.freeinit_iterations != -1: + iteration_count = animatediff.freeinit_iterations + f_fast_sampling = animatediff.freeinit_fast_sampling + # FreeInit only works with DDIM, so... just use DDIM :) + f_scheduler = DDIMScheduler.from_config( # type: ignore + { + "beta_start": 0.00085, + "beta_end": 0.012, + "beta_scheduler": "linear", + } + ) + freq_filter = freeinit_filter( + latents.shape, + device=self.unet.device, + params={ + "method": animatediff.freeinit_method, + "n": animatediff.freeinit_n, + "d_t": animatediff.freeinit_dt, + "d_s": animatediff.freeinit_ds, + }, + ) + + classify = do_classifier_free_guidance + prv_feature = None + + # 8. 
Denoising loop def do_denoise( x: torch.Tensor, t: torch.IntTensor, - call: Callable, + call: Callable[..., torch.Tensor], + change_source: Callable[[Callable], None], ): + nonlocal j, idx, do_classifier_free_guidance, prv_feature # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([x] * 2) if do_classifier_free_guidance and not split_latents_into_two else x # type: ignore + assert context_scheduler is not None + + self.unet = modify_kohya(self.unet, j, num_inference_steps, deepshrink) + + tau = min(j / num_inference_steps, 1.0) + + can_controlnet = ( + tau <= config.api.approximate_controlnet + or math.floor(tau * 100) % 5 == 0 + ) + do_classifier_free_guidance = classify and ( + tau <= config.api.cfg_uncond_tau + or config.api.cfg_uncond_tau == 1.0 + or animatediff is not None ) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # type: ignore - if num_channels_unet == 9: - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) # type: ignore + noise_pred, counter = None, None + if animatediff is not None: + assert latents is not None + noise_pred = torch.zeros( + ( + x.shape[0] * (2 if do_classifier_free_guidance else 1), + *x.shape[1:], + ), + device=latents.device, + dtype=latents.dtype, + ) + counter = torch.zeros( + (1, 1, animatediff.frames, 1, 1), + device=latents.device, + dtype=latents.dtype, + ) - # predict the noise residual - if not hasattr(self, "controlnet"): - if split_latents_into_two: - uncond, cond = text_embeddings.chunk(2) - noise_pred_text = call(latent_model_input, t, cond=cond) - noise_pred_uncond = call(latent_model_input, t, cond=uncond) - else: - noise_pred = call( # type: ignore - latent_model_input, - t, - cond=text_embeddings, - ) - else: - assert self.controlnet is not None + self.unet = step_scalecrafter( + self.unet, scalecrafter, j, num_inference_steps + ) - if guess_mode and do_classifier_free_guidance: - # Infer ControlNet only for the conditional batch. 
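
(A rough, self-contained illustration of what the noise_pred and counter buffers allocated above are for in the AnimateDiff path: the UNet is run over overlapping windows of frames and the overlapping predictions are averaged. The window size and stride below are made-up values, and the random tensor stands in for a real UNet call; this is not the project's actual context scheduler.)

    import torch

    frames = 16
    latents = torch.randn(1, 4, frames, 64, 64)   # (batch, channels, frames, height, width)
    noise_pred = torch.zeros_like(latents)
    counter = torch.zeros(1, 1, frames, 1, 1)

    # Overlapping windows of 8 frames, advancing 4 frames at a time (assumed values).
    for start in range(0, frames, 4):
        context = [(start + i) % frames for i in range(8)]
        window_pred = torch.randn_like(latents[:, :, context])  # stand-in for the UNet output
        noise_pred[:, :, context] += window_pred
        counter[:, :, context] += 1

    averaged = noise_pred / counter  # each frame is averaged over every window that included it
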
- control_model_input = x - control_model_input = self.scheduler.scale_model_input(control_model_input, t).half() # type: ignore - controlnet_prompt_embeds = text_embeddings.chunk(2)[1] + for context in context_scheduler(j, *context_args): + if animatediff is not None: + latent_model_input = x[:, :, context].repeat( + 2 if do_classifier_free_guidance else 1, 1, 1, 1, 1 + ) + if mask is not None: + assert masked_image is not None + # fmt: off + latent_mask = torch.cat([mask[:, :, context]] * 2) if do_classifier_free_guidance else mask[:, :, context] + latent_masked_image = torch.cat([masked_image[:, :, context]] * 2) if do_classifier_free_guidance else masked_image[:, :, context] + # fmt: on else: - control_model_input = latent_model_input - controlnet_prompt_embeds = text_embeddings + latent_model_input = ( + torch.cat([x] * 2) if do_classifier_free_guidance and not split_latents_into_two else x # type: ignore + ) + if mask is not None: + latent_mask = mask + latent_masked_image = masked_image_latents # type: ignore + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # type: ignore + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, latent_mask, latent_masked_image], dim=1) # type: ignore + latent_model_input = latent_model_input.to(device=latents_device) + + # predict the noise residual + down_intrablock_additional_residuals = None + if hasattr(self, "adapter") and self.adapter is not None: + if j < cutoff: + assert adapter_state is not None + down_intrablock_additional_residuals = [ + state.clone() for state in adapter_state + ] - cond_scale = controlnet_conditioning_scale * controlnet_keep[i] + down_block_res_samples, mid_block_res_sample = None, None + if ( + hasattr(self, "controlnet") and self.controlnet is not None + ) and can_controlnet: + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = x + control_model_input = self.scheduler.scale_model_input(control_model_input, t).half() # type: ignore + controlnet_prompt_embeds = text_embeddings.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = text_embeddings + + cond_scale = controlnet_conditioning_scale * controlnet_keep[j] + + change_source(self.controlnet) + down_block_res_samples, mid_block_res_sample = call( + control_model_input, + t, + cond=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + ) - ( - down_block_res_samples, - mid_block_res_sample, - ) = self.controlnet( - control_model_input, - t, - encoder_hidden_states=controlnet_prompt_embeds, - controlnet_cond=image, - conditioning_scale=cond_scale, - guess_mode=guess_mode, - return_dict=False, + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [ + torch.cat([torch.zeros_like(d), d]) + for d in down_block_res_samples + ] + mid_block_res_sample = torch.cat( + [ + torch.zeros_like(mid_block_res_sample), + mid_block_res_sample, + ] + ) + + change_source(self.unet) + kwargs = set( + inspect.signature(self.unet.forward).parameters.keys() # type: ignore ) - if guess_mode and do_classifier_free_guidance: - # Infered ControlNet only for the conditional batch. 
- # To apply the output of ControlNet to both the unconditional and conditional batches, - # add 0 to the unconditional batch to keep it unchanged. - down_block_res_samples = [ - torch.cat([torch.zeros_like(d), d]) - for d in down_block_res_samples - ] - mid_block_res_sample = torch.cat( - [ - torch.zeros_like(mid_block_res_sample), - mid_block_res_sample, - ] - ) - if split_latents_into_two: + _kwargs = { + "down_block_additional_residuals": down_block_res_samples, + "mid_block_additional_residual": mid_block_res_sample, + "down_intrablock_additional_residuals": down_intrablock_additional_residuals, + "order": j, + "drop_encode_decode": config.api.drop_encode_decode != "off", + "quick_replicate": config.api.deepcache_cache_interval > 1, + "replicate_prv_feature": prv_feature, + } + if split_latents_into_two and do_classifier_free_guidance: uncond, cond = text_embeddings.chunk(2) - noise_pred_text = call(latent_model_input, t, cond=cond) - noise_pred_uncond = call(latent_model_input, t, cond=uncond) - else: - noise_pred = call( # type: ignore - latent_model_input, - t, - cond=text_embeddings, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, + uncond_down, uncond_mid = None, None + cond_down, cond_mid = None, None + uncond_intra, cond_intra = None, None + + if down_block_res_samples is not None: + uncond_down, cond_down = down_block_res_samples.chunk(2) # type: ignore + uncond_mid, cond_mid = mid_block_res_sample.chunk(2) # type: ignore + + if down_intrablock_additional_residuals is not None: + uncond_intra, cond_intra = [], [] + for s in down_intrablock_additional_residuals: + unc, cnd = s.chunk(2) + uncond_intra.append(unc) + cond_intra.append(cnd) + + _kwargs.update( + { + "down_block_additional_residuals": cond_down, + "mid_block_additional_residual": cond_mid, + "down_intrablock_additional_residuals": cond_intra, + } + ) + for kw, _ in _kwargs.copy().items(): + if kw not in kwargs: + del _kwargs[kw] + + if animatediff is not None: + assert noise_pred is not None + assert counter is not None + # fmt: off + # This is an abomination with formatting enabled + lmi = latent_model_input.to(dtype=dtype, device=device) + noise_pred[1, :, context] = noise_pred[1, :, context] \ + + call( + lmi, + t, + cond=cond, + **_kwargs, + )[0].to(dtype=noise_pred.dtype, device=noise_pred.device) + counter[1, :, context] = counter[1, :, context] + 1 + # fmt: on + else: + noise_pred_text = call( + latent_model_input, t, cond=cond, **_kwargs + ) + + _kwargs.update( + { + "down_block_additional_residuals": uncond_down, + "mid_block_additional_residual": uncond_mid, + "down_intrablock_additional_residuals": uncond_intra, + } ) + for kw, _ in _kwargs.copy().items(): + if kw not in kwargs: + del _kwargs[kw] + if animatediff is not None: + assert noise_pred is not None + assert counter is not None + # fmt: off + # This is an abomination with formatting enabled + lmi = latent_model_input.to(dtype=dtype, device=device) + noise_pred[0, :, context] = noise_pred[0, :, context] \ + + call( + lmi, + t, + cond=uncond, + **_kwargs, + )[0].to(dtype=noise_pred.dtype, device=noise_pred.device) + counter[0, :, context] = counter[0, :, context] + 1 + # fmt: on + else: + noise_pred_uncond = call( + latent_model_input, t, cond=uncond, **_kwargs + ) + else: + for kw, _ in _kwargs.copy().items(): + if kw not in kwargs: + del _kwargs[kw] + + if animatediff is not None: + assert noise_pred is not None + assert counter is not None + # fmt: off + # This is an abomination with 
formatting enabled + lmi = latent_model_input.to(dtype=dtype, device=device) + noise_pred[:, :, context] = noise_pred[:, :, context] \ + + call( + lmi, + t, + cond=text_embeddings, + **_kwargs, + )[0].to(dtype=noise_pred.dtype, device=noise_pred.device) + counter[:, :, context] = counter[:, :, context] + 1 + # fmt: on + else: + noise_pred = call( # type: ignore + latent_model_input, + t, + cond=text_embeddings, + **_kwargs, + ) + + self.unet, noise_pred_vanilla = post_scalecrafter( + self.unet, + scalecrafter, + j, + num_inference_steps, + call, + latent_model_input, # type: ignore + t, + cond=text_embeddings, + down_block_additional_residuals=down_block_res_samples, # type: ignore + mid_block_additional_residual=mid_block_res_sample, # type: ignore + down_intrablock_additional_residuals=down_intrablock_additional_residuals, # type: ignore + ) # perform guidance if do_classifier_free_guidance: if not split_latents_into_two: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) # type: ignore - noise_pred = noise_pred_uncond + guidance_scale * ( # type: ignore - noise_pred_text - noise_pred_uncond # type: ignore - ) # type: ignore + if animatediff is not None: + assert noise_pred is not None + assert counter is not None + noise_pred_uncond, noise_pred_text = (noise_pred / counter).chunk(2) # type: ignore + else: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) # type: ignore + noise_pred = calculate_cfg( + j, + noise_pred_text, # type: ignore + noise_pred_uncond, # type: ignore + guidance_scale, + t, + additional_pred=noise_pred_vanilla, + ) if do_self_attention_guidance: - if do_classifier_free_guidance: - pred = pred_x0(self, x, noise_pred_uncond, t) # type: ignore - uncond_attn, cond_attn = store_processor.attention_probs.chunk(2) # type: ignore - degraded_latents = sag_masking( - self, pred, uncond_attn, map_size, t, pred_epsilon(self, x, noise_pred_uncond, t) # type: ignore - ) - uncond_emb, _ = text_embeddings.chunk(2) - # predict the noise residual - # this probably could have been done better but honestly fuck this - degraded_prep = call( # type: ignore - degraded_latents.to(dtype=self.unet.dtype), - t, - cond=uncond_emb, - ) - noise_pred += self_attention_scale * (noise_pred_uncond - degraded_prep) # type: ignore - else: - pred = pred_x0(self, x, noise_pred, t) # type: ignore - cond_attn = store_processor.attention_probs # type: ignore - degraded_latents = sag_masking( - self, - pred, - cond_attn, - map_size, - t, - pred_epsilon(self, x, noise_pred, t), # type: ignore - ) - # predict the noise residual - degraded_prep = call( # type: ignore - degraded_latents.to(dtype=self.unet.dtype), - t, - cond=text_embeddings, - ) - noise_pred += self_attention_scale * (noise_pred - degraded_prep) # type: ignore + if not do_classifier_free_guidance: + noise_pred_uncond = noise_pred # type: ignore + noise_pred += calculate_sag( # type: ignore + self, + call, + store_processor, # type: ignore + x, + noise_pred_uncond, # type: ignore + t, + map_size, # type: ignore + text_embeddings, + self_attention_scale, + guidance_scale, + dtype, + drop_encode_decode=config.api.drop_encode_decode != "off", + order=j, + ) if not isinstance(self.scheduler, KdiffusionSchedulerAdapter): # compute the previous noisy sample x_t -> x_t-1 + assert noise_pred is not None x = self.scheduler.step( # type: ignore noise_pred, t.to(noise_pred.device), x.to(noise_pred.device), **extra_step_kwargs # type: ignore ).prev_sample # type: ignore @@ -609,392 +960,120 @@ def do_denoise( init_mask = mask[:1] init_mask 
= pad_tensor(init_mask, 8, (x.shape[2], x.shape[3])) - if i < len(timesteps) - 1: - noise_timestep = timesteps[i + 1] - init_latents_proper = self.scheduler.add_noise( - init_latents_proper, noise, torch.tensor([noise_timestep]) # type: ignore - ) - x = (1 - init_mask) * init_latents_proper + init_mask * x # type: ignore + + j += 1 + self.unet = postprocess_kohya(self.unet) # type: ignore return x # 8. Denoising loop + ensure_correct_device(self.unet) + latents = latents.to(dtype=dtype) # type: ignore + if image_latents is not None: + image_latents = image_latents.to(dtype=dtype) # type: ignore + with ExitStack() as gs: if do_self_attention_guidance: gs.enter_context(self.unet.mid_block.attentions[0].register_forward_hook(get_map_size)) # type: ignore - if isinstance(self.scheduler, KdiffusionSchedulerAdapter): - latents = self.scheduler.do_inference( - latents, # type: ignore - generator=generator, - call=self.unet, # type: ignore - apply_model=do_denoise, - callback=callback, - callback_steps=callback_steps, - ) - else: - - def _call(*args, **kwargs): - if len(args) == 3: - encoder_hidden_states = args[-1] - args = args[:2] - if kwargs.get("cond", None) is not None: - encoder_hidden_states = kwargs.pop("cond") - return self.unet( - *args, - encoder_hidden_states=encoder_hidden_states, # type: ignore - return_dict=True, - **kwargs, - )[0] - - for i, t in enumerate(tqdm(timesteps, desc="PyTorch")): - latents = do_denoise(latents, t, _call) # type: ignore - - # call the callback, if provided - if i % callback_steps == 0: - if callback is not None: - callback(i, t, latents) # type: ignore - if ( - is_cancelled_callback is not None - and is_cancelled_callback() - ): - return None + for iter in range(iteration_count): + if hasattr(self.scheduler, "_step_index"): + self.scheduler._step_index = None # type: ignore + if freq_filter is not None: + if iter == 0: + assert latents is not None + initial_noise = latents.detach().clone() + else: + assert latents is not None and f_scheduler is not None + diffuse_timestep = ( + f_scheduler.config["num_train_timesteps"] - 1 + ) + diffuse_timesteps = torch.full( + (batch_size,), int(diffuse_timestep) + ) + diffuse_timesteps = diffuse_timesteps.long() + z_T = f_scheduler.add_noise( + original_samples=latents.to(device), # type: ignore + noise=initial_noise.to(device), # type: ignore + timesteps=diffuse_timesteps.to(device), # type: ignore + ) + z_rand = torch.randn(latents.shape, device=device) + lt = latents.dtype + latents = freeinit_mix( + z_T.to(dtype=torch.float32), z_rand, LPF=freq_filter + ) + latents = latents.to(dtype=lt) # type: ignore + if f_fast_sampling: + curr_inf = int( + num_inference_steps / iteration_count * (iter + 1) + ) + self.scheduler.set_timesteps(curr_inf, device=device) + timesteps = self.scheduler.timesteps + + if isinstance(self.scheduler, KdiffusionSchedulerAdapter): + latents = self.scheduler.do_inference( + latents, # type: ignore + device=latents_device, + generator=generator, + call=self.unet, # type: ignore + apply_model=do_denoise, + callback=callback, + callback_steps=callback_steps, + ) + else: + s = self.unet + + def change(src): + nonlocal s + s = src + + def _call(*args, **kwargs): + if len(args) == 3: + encoder_hidden_states = args[-1] + args = args[:2] + if kwargs.get("cond", None) is not None: + encoder_hidden_states = kwargs.pop("cond") + ret = s( + *args, + encoder_hidden_states=encoder_hidden_states, # type: ignore + return_dict=False, + **kwargs, + ) + if isinstance(s, UNet2DConditionModel): + return ret[0] + 
return ret + + for i, t in enumerate(tqdm(timesteps, desc="PyTorch")): + latents = do_denoise(latents, t, _call, change) # type: ignore + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) # type: ignore + if ( + is_cancelled_callback is not None + and is_cancelled_callback() + ): + return None + + latents = latents.to(device=device) # type: ignore # 9. Post-processing if output_type == "latent": + unload_all() return latents, False - image = full_vae(latents, overwrite=lambda sample: self.vae.decode(sample).sample, height=height, width=width) # type: ignore + converted_image = full_vae(latents, self.vae, height=height, width=width) # type: ignore # 11. Convert to PIL if output_type == "pil": - image = numpy_to_pil(image) - - if hasattr(self, "final_offload_hook"): - self.final_offload_hook.offload() # type: ignore + converted_image = numpy_to_pil(converted_image) - if not return_dict: - return image, False + unload_all() - return StableDiffusionPipelineOutput( - images=image, nsfw_content_detected=False # type: ignore - ) + if not return_dict: + return converted_image, False # type: ignore - def text2img( - self, - prompt: Union[str, List[str]], - generator: Union[PhiloxGenerator, torch.Generator], - negative_prompt: Optional[Union[str, List[str]]] = None, - height: int = 512, - width: int = 512, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - self_attention_scale: float = 0.0, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - latents: Optional[torch.FloatTensor] = None, - max_embeddings_multiples: Optional[int] = 100, - output_type: Literal["pil", "latent"] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: int = 1, - seed: int = 1, - prompt_expansion_settings: Optional[Dict] = None, - ): - r""" - Function for text-to-image generation. - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - height (`int`, *optional*, defaults to 512): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. 
- generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - max_embeddings_multiples (`int`, *optional*, defaults to `100`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - is_cancelled_callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. If the function returns - `True`, the inference will be cancelled. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. 
- """ - return self.__call__( - prompt=prompt, - generator=generator, - negative_prompt=negative_prompt, - height=height, - width=width, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - self_attention_scale=self_attention_scale, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - latents=latents, - max_embeddings_multiples=max_embeddings_multiples, - output_type=output_type, - return_dict=return_dict, - callback=callback, - is_cancelled_callback=is_cancelled_callback, - callback_steps=callback_steps, - seed=seed, - prompt_expansion_settings=prompt_expansion_settings, - ) - - def img2img( - self, - image: Union[torch.FloatTensor, PIL.Image.Image], # type: ignore - prompt: Union[str, List[str]], - generator: Union[PhiloxGenerator, torch.Generator], - height: int = 512, - width: int = 512, - negative_prompt: Optional[Union[str, List[str]]] = None, - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - self_attention_scale: float = 0.0, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - max_embeddings_multiples: Optional[int] = 100, - output_type: Literal["pil", "latent"] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: int = 1, - seed: int = 1, - prompt_expansion_settings: Optional[Dict] = None, - ): - r""" - Function for image-to-image generation. - Args: - image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. - `image` will be used as a starting point, adding more noise to it the larger the `strength`. The - number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added - noise will be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This parameter will be modulated by `strength`. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. 
- generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - max_embeddings_multiples (`int`, *optional*, defaults to `100`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - is_cancelled_callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. If the function returns - `True`, the inference will be cancelled. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - return self.__call__( - prompt=prompt, - generator=generator, - negative_prompt=negative_prompt, - image=image, - height=height, - width=width, - num_inference_steps=num_inference_steps, # type: ignore - guidance_scale=guidance_scale, # type: ignore - self_attention_scale=self_attention_scale, - strength=strength, - num_images_per_prompt=num_images_per_prompt, - eta=eta, # type: ignore - max_embeddings_multiples=max_embeddings_multiples, - output_type=output_type, - return_dict=return_dict, - callback=callback, - is_cancelled_callback=is_cancelled_callback, - callback_steps=callback_steps, - seed=seed, - prompt_expansion_settings=prompt_expansion_settings, - ) - - def inpaint( - self, - image: Union[torch.FloatTensor, PIL.Image.Image], # type: ignore - mask_image: Union[torch.FloatTensor, PIL.Image.Image], # type: ignore - prompt: Union[str, List[str]], - generator: Union[PhiloxGenerator, torch.Generator], - negative_prompt: Optional[Union[str, List[str]]] = None, - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - self_attention_scale: float = 0.0, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - max_embeddings_multiples: Optional[int] = 100, - output_type: Literal["pil", "latent"] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: int = 1, - width: int = 512, - height: int = 512, - seed: int = 1, - prompt_expansion_settings: Optional[Dict] = None, - ): - r""" - Function for inpaint. 
- Args: - image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. This is the image whose masked region will be inpainted. - mask_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be - replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a - PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should - contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` - is 1, the denoising process will be run on the masked area for the full number of iterations specified - in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more - noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. - num_inference_steps (`int`, *optional*, defaults to 50): - The reference number of denoising steps. More denoising steps usually lead to a higher quality image at - the expense of slower inference. This parameter will be modulated by `strength`, as explained above. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - max_embeddings_multiples (`int`, *optional*, defaults to `100`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 
- is_cancelled_callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. If the function returns - `True`, the inference will be cancelled. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - width (`int`, *optional*, defaults to 512): - The width of the generated image. - height (`int`, *optional*, defaults to 512): - The height of the generated image. - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - return self.__call__( - prompt=prompt, - generator=generator, - negative_prompt=negative_prompt, - image=image, - mask_image=mask_image, - num_inference_steps=num_inference_steps, # type: ignore - guidance_scale=guidance_scale, # type: ignore - self_attention_scale=self_attention_scale, - strength=strength, - num_images_per_prompt=num_images_per_prompt, - eta=eta, # type: ignore - max_embeddings_multiples=max_embeddings_multiples, - output_type=output_type, - return_dict=return_dict, - callback=callback, - is_cancelled_callback=is_cancelled_callback, - callback_steps=callback_steps, - width=width, - height=height, - seed=seed, - prompt_expansion_settings=prompt_expansion_settings, + return StableDiffusionPipelineOutput( + images=converted_image, nsfw_content_detected=False # type: ignore ) diff --git a/core/inference/pytorch/pytorch.py b/core/inference/pytorch/pytorch.py index 70a54429f..8c19656e3 100755 --- a/core/inference/pytorch/pytorch.py +++ b/core/inference/pytorch/pytorch.py @@ -5,7 +5,7 @@ import requests import torch -from diffusers.models.autoencoder_kl import AutoencoderKL +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from diffusers.models.controlnet import ControlNetModel from diffusers.models.modeling_utils import ModelMixin from diffusers.models.unet_2d_condition import UNet2DConditionModel @@ -22,18 +22,23 @@ from api.websockets.notification import Notification from core import shared from core.config import config -from core.flags import HighResFixFlag +from core.flags import AnimateDiffFlag, DeepshrinkFlag, ScalecrafterFlag from core.inference.base_model import InferenceModel -from core.inference.functions import convert_vaept_to_diffusers, load_pytorch_pipeline +from core.inference.functions import ( + convert_vaept_to_diffusers, + get_output_type, + load_pytorch_pipeline, +) from core.inference.pytorch.pipeline import StableDiffusionLongPromptWeightingPipeline from core.inference.utilities import ( change_scheduler, create_generator, image_to_controlnet_input, - scale_latents, ) from core.inference_callbacks import callback +from core.optimizations import optimize_vae from core.types import ( + ADetailerQueueEntry, Backend, ControlNetQueueEntry, Img2ImgQueueEntry, @@ -41,8 +46,6 @@ Job, SigmaScheduler, Txt2ImgQueueEntry, - UpscaleData, - UpscaleQueueEntry, ) from core.utils import convert_images_to_base64_grid, convert_to_image, resize @@ -70,10 +73,6 @@ def __init__( self.text_encoder: CLIPTextModel self.tokenizer: CLIPTokenizer self.scheduler: 
Any - self.feature_extractor: Any - self.requires_safety_checker: bool - self.safety_checker: Any - self.image_encoder: Any self.controlnet: Optional[ControlNetModel] = None self.current_controlnet: str = "" @@ -119,7 +118,7 @@ def load(self): self.load_textual_inversion(textural_inversion) except Exception as e: logger.warning( - f"Failed to load textual inversion {textural_inversion}: {e}" + f"({e.__class__.__name__}) Failed to load textual inversion {textural_inversion}: {e}" ) websocket_manager.broadcast_sync( Notification( @@ -142,6 +141,13 @@ def change_vae(self, vae: str) -> None: setattr(self, "original_vae", self.vae) old_vae = getattr(self, "original_vae") + # Not sure what I needed this for, but whatever + dtype = config.api.load_dtype + device = self.unet.device + + if hasattr(self.text_encoder, "v_offload_device"): + device = torch.device("cpu") + if vae == "default": self.vae = old_vae else: @@ -150,9 +156,7 @@ def change_vae(self, vae: str) -> None: f"https://huggingface.co/{vae}/raw/main/config.json" ).json()["_class_name"] cont = getattr(importlib.import_module("diffusers"), cont) - self.vae = cont.from_pretrained(vae).to( - device=old_vae.device, dtype=old_vae.dtype - ) + self.vae = cont.from_pretrained(vae).to(device, dtype) if not hasattr(self.vae.config, "block_out_channels"): setattr( self.vae.config, @@ -169,21 +173,17 @@ def change_vae(self, vae: str) -> None: if Path(vae).is_dir(): self.vae = ModelMixin.from_pretrained(vae) # type: ignore else: - self.vae = convert_vaept_to_diffusers(vae).to( - device=old_vae.device, dtype=old_vae.dtype - ) + self.vae = convert_vaept_to_diffusers(vae).to(device, dtype) else: raise FileNotFoundError(f"{vae} is not a valid path") - if isinstance(self.vae, AutoencoderKL): - if config.api.vae_slicing: - self.vae.enable_slicing() - if config.api.vae_tiling: - self.vae.enable_tiling() + # Check if text_encoder has v_offload_device, because it always + # gets wholly offloaded instead of being sequentially offloaded + if hasattr(self.text_encoder, "v_offload_device"): + from core.optimizations.offload import set_offload - logger.info(f"Successfully changed vae to {vae} of type {type(self.vae)}") - - # This is at the end 'cause I've read horror stories about pythons prefetch system + self.vae = set_offload(self.vae, torch.device(config.api.device)) # type: ignore + self.vae = optimize_vae(self.vae) # type: ignore self.vae_path = vae def unload(self) -> None: @@ -195,9 +195,6 @@ def unload(self) -> None: self.text_encoder, self.tokenizer, self.scheduler, - self.feature_extractor, - self.requires_safety_checker, - self.safety_checker, ) if hasattr(self, "image_encoder"): @@ -231,7 +228,7 @@ def manage_optional_components( load_lora_utilities(self) if not variations: - self.image_encoder = None + self.image_encoder = None # type: ignore if self.current_controlnet != target_controlnet: logging.debug(f"Old: {self.current_controlnet}, New: {target_controlnet}") @@ -249,7 +246,7 @@ def manage_optional_components( cn = ControlNetModel.from_pretrained( target_controlnet, resume_download=True, - torch_dtype=config.api.dtype, + torch_dtype=config.api.load_dtype, ) assert isinstance(cn, ControlNetModel) @@ -299,7 +296,7 @@ def create_pipe( return pipe - def txt2img(self, job: Txt2ImgQueueEntry) -> List[Image.Image]: + def txt2img(self, job: Txt2ImgQueueEntry) -> Union[List[Image.Image], torch.Tensor]: "Generate an image from a prompt" pipe = self.create_pipe( @@ -309,132 +306,110 @@ def txt2img(self, job: Txt2ImgQueueEntry) -> List[Image.Image]: 
generator = create_generator(job.data.seed) - total_images: List[Image.Image] = [] - shared.current_method = "txt2img" + total_images: Union[List[Image.Image], torch.Tensor] = [] + output_type = get_output_type(job) - for _ in tqdm(range(job.data.batch_count), desc="Queue", position=1): - output_type = ( - "latent" - if ( - "highres_fix" in job.flags - and HighResFixFlag(**job.flags["highres_fix"]).mode == "latent" - ) - else "pil" - ) + deepshrink = None + if "deepshrink" in job.flags: + deepshrink = DeepshrinkFlag.from_dict(job.flags["deepshrink"]) - data = pipe.text2img( - generator=generator, - prompt=job.data.prompt, - height=job.data.height, - width=job.data.width, - num_inference_steps=job.data.steps, - guidance_scale=job.data.guidance_scale, - self_attention_scale=job.data.self_attention_scale, - negative_prompt=job.data.negative_prompt, - output_type=output_type, - callback=callback, - num_images_per_prompt=job.data.batch_size, - seed=job.data.seed, - prompt_expansion_settings=job.data.prompt_to_prompt_settings, - ) - - if "highres_fix" in job.flags: - flag = job.flags["highres_fix"] - flag = HighResFixFlag.from_dict(flag) - - if flag.mode == "latent": - latents = data[0] # type: ignore - assert isinstance(latents, (torch.Tensor, torch.FloatTensor)) + scalecrafter = None + if "scalecrafter" in job.flags: + scalecrafter = ScalecrafterFlag.from_dict(job.flags["scalecrafter"]) - latents = scale_latents( - latents=latents, - scale=flag.scale, - latent_scale_mode=flag.latent_scale_mode, - ) + animatediff = AnimateDiffFlag( + motion_model="data/motion-models/v3_sd15_mm.ckpt", + ) + if "animatediff" in job.flags: + animatediff = AnimateDiffFlag.from_dict(job.flags["animatediff"]) - data = pipe.img2img( - generator=generator, - prompt=job.data.prompt, - image=latents, - height=latents.shape[2] * 8, - width=latents.shape[3] * 8, - num_inference_steps=flag.steps, - guidance_scale=job.data.guidance_scale, - self_attention_scale=job.data.self_attention_scale, - negative_prompt=job.data.negative_prompt, - output_type="pil", - callback=callback, - strength=flag.strength, - return_dict=False, - num_images_per_prompt=job.data.batch_size, - seed=job.data.seed, - prompt_expansion_settings=job.data.prompt_to_prompt_settings, - ) + for _ in tqdm(range(job.data.batch_count), desc="Queue", position=1): + if animatediff is not None and animatediff.use_pia: + base_image = pipe( # type: ignore + generator=generator, + prompt=job.data.prompt, + height=job.data.height, + width=job.data.width, + num_inference_steps=job.data.steps, + guidance_scale=job.data.guidance_scale, + self_attention_scale=job.data.self_attention_scale, + negative_prompt=job.data.negative_prompt, + output_type=output_type, + callback=callback, + num_images_per_prompt=job.data.batch_size, + seed=job.data.seed, + prompt_expansion_settings=job.data.prompt_to_prompt_settings, + deepshrink=deepshrink, + scalecrafter=scalecrafter, + animatediff=None, + )[0][0] + data = pipe( + generator=generator, + prompt=job.data.prompt, + image=base_image, + height=job.data.height, + width=job.data.width, + num_inference_steps=job.data.steps, + guidance_scale=job.data.guidance_scale, + self_attention_scale=job.data.self_attention_scale, + negative_prompt=job.data.negative_prompt, + output_type=output_type, + callback=callback, + num_images_per_prompt=job.data.batch_size, + seed=job.data.seed, + prompt_expansion_settings=job.data.prompt_to_prompt_settings, + deepshrink=deepshrink, + scalecrafter=scalecrafter, + animatediff=animatediff, + ) + else: + data = 
pipe( + generator=generator, + prompt=job.data.prompt, + height=job.data.height, + width=job.data.width, + num_inference_steps=job.data.steps, + guidance_scale=job.data.guidance_scale, + self_attention_scale=job.data.self_attention_scale, + negative_prompt=job.data.negative_prompt, + output_type=output_type, + callback=callback, + num_images_per_prompt=job.data.batch_size, + seed=job.data.seed, + prompt_expansion_settings=job.data.prompt_to_prompt_settings, + deepshrink=deepshrink, + scalecrafter=scalecrafter, + animatediff=animatediff, + ) - else: - from core.shared_dependent import gpu - - images = data[0] # type: ignore - assert isinstance(images, List) - - upscaled_images = [] - for image in images: - output: tuple[Image.Image, float] = gpu.upscale( - UpscaleQueueEntry( - data=UpscaleData( - id=job.data.id, - # FastAPI validation error, we need to do this so that we can pass in a PIL image - image=image, # type: ignore - upscale_factor=flag.scale, - ), - model=flag.image_upscaler, - save_image=False, - ) - ) - upscaled_images.append(output[0]) - - data = pipe.img2img( - generator=generator, - prompt=job.data.prompt, - image=upscaled_images[0], - height=int(flag.scale * job.data.height), - width=int(flag.scale * job.data.width), - num_inference_steps=flag.steps, - guidance_scale=job.data.guidance_scale, - self_attention_scale=job.data.self_attention_scale, - negative_prompt=job.data.negative_prompt, - output_type="pil", - callback=callback, - strength=flag.strength, - return_dict=False, - num_images_per_prompt=job.data.batch_size, - seed=job.data.seed, - prompt_expansion_settings=job.data.prompt_to_prompt_settings, - ) + images: Union[List[Image.Image], torch.Tensor] = data[0] # type: ignore - images: list[Image.Image] = data[0] # type: ignore - - total_images.extend(images) - - websocket_manager.broadcast_sync( - data=Data( - data_type="txt2img", - data={ - "progress": 0, - "current_step": 0, - "total_steps": 0, - "image": convert_images_to_base64_grid( - total_images, - quality=config.api.image_quality, - image_format=config.api.image_extension, - ), - }, + if not isinstance(images, List): + total_images = images + else: + assert isinstance(total_images, List) + total_images.extend(images) + + if isinstance(total_images, List): + websocket_manager.broadcast_sync( + data=Data( + data_type=shared.current_method or "txt2img", + data={ + "progress": 0, + "current_step": 0, + "total_steps": 0, + "image": convert_images_to_base64_grid( + total_images, # type: ignore + quality=config.api.image_quality, + image_format=config.api.image_extension, + ), + }, + ) ) - ) return total_images - def img2img(self, job: Img2ImgQueueEntry) -> List[Image.Image]: + def img2img(self, job: Img2ImgQueueEntry) -> Union[List[Image.Image], torch.Tensor]: "Generate an image from an image" pipe = self.create_pipe( @@ -445,14 +420,27 @@ def img2img(self, job: Img2ImgQueueEntry) -> List[Image.Image]: generator = create_generator(job.data.seed) # Preprocess the image - input_image = convert_to_image(job.data.image) - input_image = resize(input_image, job.data.width, job.data.height) + if isinstance(job.data.image, (str, bytes, Image.Image)): + input_image = convert_to_image(job.data.image) + input_image = resize(input_image, job.data.width, job.data.height) + else: + input_image = job.data.image + + total_images: Union[List[Image.Image], torch.Tensor] = [] + output_type = get_output_type(job) + + deepshrink = None + if "deepshrink" in job.flags: + deepshrink = DeepshrinkFlag.from_dict(job.flags["deepshrink"]) - 
total_images: List[Image.Image] = [] - shared.current_method = "img2img" + animatediff = AnimateDiffFlag( + motion_model="data/motion-models/v3_sd15_mm.ckpt", + ) + if "animatediff" in job.flags: + animatediff = AnimateDiffFlag.from_dict(job.flags["animatediff"]) for _ in tqdm(range(job.data.batch_count), desc="Queue", position=1): - data = pipe.img2img( + data = pipe( generator=generator, prompt=job.data.prompt, image=input_image, @@ -462,42 +450,45 @@ def img2img(self, job: Img2ImgQueueEntry) -> List[Image.Image]: guidance_scale=job.data.guidance_scale, self_attention_scale=job.data.self_attention_scale, negative_prompt=job.data.negative_prompt, - output_type="pil", + output_type=output_type, callback=callback, strength=job.data.strength, return_dict=False, num_images_per_prompt=job.data.batch_size, seed=job.data.seed, prompt_expansion_settings=job.data.prompt_to_prompt_settings, + deepshrink=deepshrink, + animatediff=animatediff, ) - if not data: - raise ValueError("No data returned from pipeline") - - images = data[0] - assert isinstance(images, List) - - total_images.extend(images) - - websocket_manager.broadcast_sync( - data=Data( - data_type="img2img", - data={ - "progress": 0, - "current_step": 0, - "total_steps": 0, - "image": convert_images_to_base64_grid( - total_images, - quality=config.api.image_quality, - image_format=config.api.image_extension, - ), - }, + images: Union[List[Image.Image], torch.Tensor] = data[0] # type: ignore + + if not isinstance(images, List): + total_images = images + else: + assert isinstance(total_images, List) + total_images.extend(images) + + if isinstance(total_images, List): + websocket_manager.broadcast_sync( + data=Data( + data_type=shared.current_method or "img2img", + data={ + "progress": 0, + "current_step": 0, + "total_steps": 0, + "image": convert_images_to_base64_grid( + total_images, # type: ignore + quality=config.api.image_quality, + image_format=config.api.image_extension, + ), + }, + ) ) - ) return total_images - def inpaint(self, job: InpaintQueueEntry) -> List[Image.Image]: + def inpaint(self, job: InpaintQueueEntry) -> Union[List[Image.Image], torch.Tensor]: "Generate an image from an image" pipe = self.create_pipe( @@ -515,11 +506,15 @@ def inpaint(self, job: InpaintQueueEntry) -> List[Image.Image]: input_mask_image = ImageOps.invert(input_mask_image) input_mask_image = resize(input_mask_image, job.data.width, job.data.height) - total_images: List[Image.Image] = [] - shared.current_method = "inpainting" + total_images: Union[List[Image.Image], torch.Tensor] = [] + output_type = get_output_type(job) + + deepshrink = None + if "deepshrink" in job.flags: + deepshrink = DeepshrinkFlag.from_dict(job.flags["deepshrink"]) for _ in tqdm(range(job.data.batch_count), desc="Queue", position=1): - data = pipe.inpaint( + data = pipe( generator=generator, prompt=job.data.prompt, image=input_image, @@ -528,7 +523,7 @@ def inpaint(self, job: InpaintQueueEntry) -> List[Image.Image]: guidance_scale=job.data.guidance_scale, self_attention_scale=job.data.self_attention_scale, negative_prompt=job.data.negative_prompt, - output_type="pil", + output_type=output_type, callback=callback, return_dict=False, num_images_per_prompt=job.data.batch_size, @@ -536,35 +531,40 @@ def inpaint(self, job: InpaintQueueEntry) -> List[Image.Image]: height=job.data.height, seed=job.data.seed, prompt_expansion_settings=job.data.prompt_to_prompt_settings, + deepshrink=deepshrink, + strength=job.data.strength, ) - if not data: - raise ValueError("No data returned from 
pipeline") - - images = data[0] - assert isinstance(images, List) - - total_images.extend(images) - - websocket_manager.broadcast_sync( - data=Data( - data_type="inpainting", - data={ - "progress": 0, - "current_step": 0, - "total_steps": 0, - "image": convert_images_to_base64_grid( - total_images, - quality=config.api.image_quality, - image_format=config.api.image_extension, - ), - }, + images: Union[List[Image.Image], torch.Tensor] = data[0] # type: ignore + + if not isinstance(images, List): + total_images = images + else: + assert isinstance(total_images, List) + total_images.extend(images) + + if isinstance(total_images, List): + websocket_manager.broadcast_sync( + data=Data( + data_type=shared.current_method or "inpainting", + data={ + "progress": 0, + "current_step": 0, + "total_steps": 0, + "image": convert_images_to_base64_grid( + total_images, # type: ignore + quality=config.api.image_quality, + image_format=config.api.image_extension, + ), + }, + ) ) - ) return total_images - def controlnet2img(self, job: ControlNetQueueEntry) -> List[Image.Image]: + def controlnet2img( + self, job: ControlNetQueueEntry + ) -> Union[List[Image.Image], torch.Tensor]: "Generate an image from an image and controlnet conditioning" if config.api.trace_model is True: @@ -590,8 +590,8 @@ def controlnet2img(self, job: ControlNetQueueEntry) -> List[Image.Image]: input_image = image_to_controlnet_input(input_image, job.data) logger.debug(f"Preprocessed image size: {input_image.size}") - total_images: List[Image.Image] = [input_image] - shared.current_method = "controlnet" + total_images: Union[List[Image.Image], torch.Tensor] = [] + output_type = get_output_type(job) for _ in tqdm(range(job.data.batch_count), desc="Queue", position=1): data = pipe( @@ -601,7 +601,7 @@ def controlnet2img(self, job: ControlNetQueueEntry) -> List[Image.Image]: image=input_image, num_inference_steps=job.data.steps, guidance_scale=job.data.guidance_scale, - output_type="pil", + output_type=output_type, callback=callback, return_dict=False, num_images_per_prompt=job.data.batch_size, @@ -612,32 +612,66 @@ def controlnet2img(self, job: ControlNetQueueEntry) -> List[Image.Image]: prompt_expansion_settings=job.data.prompt_to_prompt_settings, ) - images = data[0] # type: ignore - assert isinstance(images, List) - - total_images.extend(images) # type: ignore - - websocket_manager.broadcast_sync( - data=Data( - data_type="controlnet", - data={ - "progress": 0, - "current_step": 0, - "total_steps": 0, - "image": convert_images_to_base64_grid( - total_images - if job.data.return_preprocessed - else total_images[1:], - quality=config.api.image_quality, - image_format=config.api.image_extension, - ), - }, + images: Union[List[Image.Image], torch.Tensor] = data[0] # type: ignore + + if not isinstance(images, List): + total_images = images + else: + assert isinstance(total_images, List) + total_images.extend(images) + + if job.data.return_preprocessed and isinstance(total_images, List): + total_images.append(input_image) + + if isinstance(total_images, List): + websocket_manager.broadcast_sync( + data=Data( + data_type=shared.current_method or "controlnet", + data={ + "progress": 0, + "current_step": 0, + "total_steps": 0, + "image": convert_images_to_base64_grid( + total_images # type: ignore + if job.data.return_preprocessed + else total_images[1:], + quality=config.api.image_quality, + image_format=config.api.image_extension, + ), + }, + ) ) - ) return total_images - def generate(self, job: Job): + def adetailer( + self, + job: 
ADetailerQueueEntry, + ): + from ..adetailer.adetailer import ADetailer + + data = job.data + assert data is not None + + entry = InpaintQueueEntry( + data=data, + model=job.model, + save_image=job.save_image, + ) + + output = ADetailer().generate( + fn=self.inpaint, + inpaint_entry=entry, + mask_dilation=job.mask_dilation, + mask_blur=job.mask_blur, + mask_padding=job.mask_padding, + upscale=job.upscale, + iterations=job.iterations, + ) + + return [*output.images, *output.init_images] + + def generate(self, job: Job) -> Union[List[Image.Image], torch.Tensor]: "Generate images from the queue" logging.info(f"Adding job {job.data.id} to queue") @@ -651,6 +685,8 @@ def generate(self, job: Job): images = self.inpaint(job) elif isinstance(job, ControlNetQueueEntry): images = self.controlnet2img(job) + elif isinstance(job, ADetailerQueueEntry): + images = self.adetailer(job) else: raise ValueError("Invalid job type for this pipeline") except Exception as e: @@ -716,7 +752,13 @@ def load_textual_inversion(self, textual_inversion: str): token = Path(textual_inversion).stem logger.info(f"Loading token {token} for textual inversion model") - pipe.load_textual_inversion(textual_inversion, token=token) + try: + pipe.load_textual_inversion(textual_inversion, token=token) + except ValueError as e: + if "Loaded state dictonary is incorrect" in str(e): + logger.info(f"Assuming {textual_inversion} is for SDXL, skipping") + return + raise e self.textual_inversions.append(textual_inversion) logger.info(f"Textual inversion model {textual_inversion} loaded successfully") diff --git a/core/inference/pytorch/sag/__init__.py b/core/inference/pytorch/sag/__init__.py deleted file mode 100644 index d31c422ae..000000000 --- a/core/inference/pytorch/sag/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .cross_attn import CrossAttnStoreProcessor -from .sag_utils import pred_epsilon, pred_x0, sag_masking diff --git a/core/inference/pytorch/sag/cross_attn.py b/core/inference/pytorch/sag/cross_attn.py deleted file mode 100644 index 2a30cd88c..000000000 --- a/core/inference/pytorch/sag/cross_attn.py +++ /dev/null @@ -1,46 +0,0 @@ -import torch - - -class CrossAttnStoreProcessor: - "Modified Cross Attention Processor with capabilities to store probabilities." 
- - def __init__(self): - self.attention_probs = None - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask( - attention_mask, sequence_length, batch_size - ) - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states( - encoder_hidden_states - ) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - self.attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(self.attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states diff --git a/core/inference/sdxl/__init__.py b/core/inference/sdxl/__init__.py new file mode 100644 index 000000000..6717b06f2 --- /dev/null +++ b/core/inference/sdxl/__init__.py @@ -0,0 +1 @@ +from .sdxl import SDXLStableDiffusion diff --git a/core/inference/sdxl/pipeline.py b/core/inference/sdxl/pipeline.py new file mode 100644 index 000000000..cb5ec13a2 --- /dev/null +++ b/core/inference/sdxl/pipeline.py @@ -0,0 +1,765 @@ +import logging +from contextlib import ExitStack +from typing import Callable, List, Literal, Optional, Union +import inspect + +import PIL +import torch +from diffusers.models.adapter import MultiAdapter +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from diffusers.models.unet_2d_condition import UNet2DConditionModel +from diffusers.pipelines.stable_diffusion.pipeline_output import ( + StableDiffusionPipelineOutput, +) +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import ( + StableDiffusionXLPipeline, +) +from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from PIL import Image +from tqdm import tqdm +from transformers.models.clip import ( + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, +) + +from core.config import config +from core.flags import SDXLRefinerFlag, DeepshrinkFlag, ScalecrafterFlag +from core.inference.utilities import ( + calculate_cfg, + full_vae, + get_timesteps, + get_weighted_text_embeddings, + numpy_to_pil, + philox, + prepare_extra_step_kwargs, + prepare_latents, + prepare_mask_and_masked_image, + prepare_mask_latents, + preprocess_adapter_image, + preprocess_image, + preprocess_mask, + sag, + modify_kohya, + postprocess_kohya, + get_scalecrafter_config, + post_scalecrafter, + step_scalecrafter, + setup_scalecrafter, + ScalecrafterSettings, +) +from core.optimizations import ensure_correct_device, inference_context, unload_all +from core.scheduling import KdiffusionSchedulerAdapter + +# ------------------------------------------------------------------------------ + +logger = logging.getLogger(__name__) + + +class StableDiffusionXLLongPromptWeightingPipeline(StableDiffusionXLPipeline): + def __init__( + self, + parent, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: 
SchedulerMixin, + aesthetic_score: bool, + force_zeros: bool, + ): + super().__init__( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, # type: ignore + ) + self.__init__additional__() + + self.parent = parent + self.aesthetic_score: bool = aesthetic_score + self.force_zeros: bool = force_zeros + self.vae: AutoencoderKL + self.text_encoder: CLIPTextModel + self.text_encoder_2: CLIPTextModelWithProjection + self.tokenizer: CLIPTokenizer + self.tokenizer_2: CLIPTokenizer + self.unet: UNet2DConditionModel + self.scheduler: LMSDiscreteScheduler + + def __init__additional__(self): + if not hasattr(self, "vae_scale_factor"): + setattr( + self, + "vae_scale_factor", + 2 ** (len(self.vae.config.block_out_channels) - 1), # type: ignore + ) + + def _default_height_width(self, height, width, image): + if image is None: + return height, width + + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[-2] + + # round down to nearest multiple of `self.adapter.downscale_factor` + if hasattr(self, "adapter") and self.adapter is not None: + height = ( + height // self.adapter.downscale_factor + ) * self.adapter.downscale_factor + + if width is None: + if isinstance(image, Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[-1] + + # round down to nearest multiple of `self.adapter.downscale_factor` + if hasattr(self, "adapter") and self.adapter is not None: + width = ( + width // self.adapter.downscale_factor + ) * self.adapter.downscale_factor + + return height, width + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. 
+ """ + return torch.device(config.api.device) + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + negative_prompt, + max_embeddings_multiples, + seed, + prompt_expansion_settings=None, + ): + if negative_prompt == "": + negative_prompt = None + + prompts = [prompt, prompt] + negative_prompts = [negative_prompt, negative_prompt] + tokenizers = ( + [self.tokenizer, self.tokenizer_2] + if self.tokenizer is not None + else [self.tokenizer_2] + ) + text_encoders = ( + [self.text_encoder, self.text_encoder_2] + if self.text_encoder is not None + else [self.text_encoder_2] + ) + + prompt_embeds_list = [] + uncond_embeds_list = [] + for prompt, negative_prompt, tokenizer, text_encoder in zip( + prompts, negative_prompts, tokenizers, text_encoders + ): + ensure_correct_device(text_encoder) + prompt = self.maybe_convert_prompt(prompt, tokenizer) + logger.debug(f"Post textual prompt: {prompt}") + + if negative_prompt is not None: + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + logger.debug(f"Post textual negative_prompt: {negative_prompt}") + + ( + text_embeddings, + pooled_embeddings, + uncond_embeddings, + uncond_pooled_embeddings, + ) = get_weighted_text_embeddings( + pipe=self.parent, # type: ignore + prompt=prompt, + uncond_prompt="" if negative_prompt is None and not self.force_zeros else negative_prompt, # type: ignore + max_embeddings_multiples=max_embeddings_multiples, + seed=seed, + prompt_expansion_settings=prompt_expansion_settings, + tokenizer=tokenizer, + text_encoder=text_encoder, + ) + if negative_prompt is None and self.force_zeros: + uncond_embeddings = torch.zeros_like(text_embeddings) + uncond_pooled_embeddings = torch.zeros_like(pooled_embeddings) + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) + + bs_embed, seq_len, _ = uncond_embeddings.shape # type: ignore + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) # type: ignore + uncond_embeddings = uncond_embeddings.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) + + prompt_embeds_list.append(text_embeddings) + uncond_embeds_list.append(uncond_embeddings) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + uncond_embeds = torch.concat(uncond_embeds_list, dim=-1) + + bs_embed = pooled_embeddings.shape[0] # type: ignore + pooled_embeddings = pooled_embeddings.repeat(1, num_images_per_prompt) # type: ignore + pooled_embeddings = pooled_embeddings.view(bs_embed * num_images_per_prompt, -1) + + bs_embed = uncond_pooled_embeddings.shape[0] # type: ignore + uncond_pooled_embeddings = uncond_pooled_embeddings.repeat(1, num_images_per_prompt) # type: ignore + uncond_pooled_embeddings = uncond_pooled_embeddings.view( + bs_embed * num_images_per_prompt, -1 + ) + + # Only the last one is necessary + return prompt_embeds.to(device), uncond_embeds.to(device), pooled_embeddings.to(device), uncond_pooled_embeddings.to(device) # type: ignore + + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype, + ): + if self.aesthetic_score: + add_time_ids = list( + original_size + crops_coords_top_left + (aesthetic_score,) + ) + add_neg_time_ids = list( + original_size + crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + 
add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim # type: ignore + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features # type: ignore + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim # type: ignore + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim # type: ignore + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
+ ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + @torch.no_grad() + def __call__( + self, + prompt: str, + generator: Union[torch.Generator, philox.PhiloxGenerator], + seed: int, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + original_size: Optional[List[int]] = [1024, 1024], + negative_prompt: Optional[str] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, # type: ignore + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, # type: ignore + self_attention_scale: float = 0.0, + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + strength: float = 0.8, + num_images_per_prompt: int = 1, + eta: float = 0.0, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 100, + output_type: Literal["pil", "latent"] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: int = 1, + prompt_expansion_settings=None, + adapter_conditioning_scale: Union[float, List[float]] = 1.0, + adapter_conditioning_factor: float = 1.0, + refiner: Optional[SDXLRefinerFlag] = None, + refiner_model: Optional["StableDiffusionXLLongPromptWeightingPipeline"] = None, + deepshrink: Optional[DeepshrinkFlag] = None, + scalecrafter: Optional[ScalecrafterFlag] = None, # type: ignore + ): + if original_size is None: + original_size = [height, width] + + if config.api.torch_compile: + self.unet = torch.compile( + self.unet, + fullgraph=config.api.torch_compile_fullgraph, + dynamic=config.api.torch_compile_dynamic, + mode=config.api.torch_compile_mode, + ) # type: ignore + + # 0. Default height and width to unet + with inference_context(self.unet, self.vae, height, width) as context: # type: ignore + self.unet = context.unet # type: ignore + self.vae = context.vae # type: ignore + + if scalecrafter is not None: + unsafe = scalecrafter.unsafe_resolutions # type: ignore + scalecrafter: ScalecrafterSettings = get_scalecrafter_config("sd15", height, width, scalecrafter.disperse) # type: ignore + logger.info( + f'Applying ScaleCrafter with (base="{scalecrafter.base}", res="{scalecrafter.height * 8}x{scalecrafter.width * 8}", dis="{scalecrafter.disperse is not None}")' + ) + if not unsafe and ( + (scalecrafter.height * 8) != height + or (scalecrafter.width * 8) != width + ): + height, width = scalecrafter.height * 8, scalecrafter.width * 8 + + refiner_steps = 10000 + if refiner_model is not None: + assert refiner is not None + num_inference_steps += refiner.steps + refiner_steps = num_inference_steps // (refiner.strength + 1) + + refiner_model = refiner_model.unet # type: ignore + aesthetic_score = refiner.aesthetic_score + negative_aesthetic_score = refiner.negative_aesthetic_score + + original_size = tuple(original_size) # type: ignore + height, width = self._default_height_width(height, width, image) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
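+ # `do_classifier_free_guidance` enables CFG only for guidance scales above 1, and
+ # `split_latents_into_two` runs the conditional and unconditional UNet passes as two
+ # separate calls (instead of a single batched call) when `batch_cond_uncond` is
+ # disabled, trading some speed for a lower peak memory footprint.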
+ do_classifier_free_guidance = guidance_scale > 1.0 + do_self_attention_guidance = self_attention_scale > 1.0 + split_latents_into_two = ( + not config.api.batch_cond_uncond and do_classifier_free_guidance + ) + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + negative_prompt, + max_embeddings_multiples, + seed, + prompt_expansion_settings=prompt_expansion_settings, + ) + dtype = prompt_embeds.dtype + + adapter_input = None # type: ignore + if hasattr(self, "adapter"): + if isinstance(self.adapter, MultiAdapter): + adapter_input: list = [] # type: ignore + + if not isinstance(adapter_conditioning_scale, list): + adapter_conditioning_scale = [ + adapter_conditioning_scale * len(image) + ] + + for oi in image: + oi = preprocess_adapter_image(oi, height, width) + oi = oi.to(device, dtype) # type: ignore + adapter_input.append(oi) # type: ignore + else: + adapter_input: torch.Tensor = preprocess_adapter_image( # type: ignore + adapter_input, height, width + ) + adapter_input.to(device, dtype) + + # 4. Preprocess image and mask + if isinstance(image, Image.Image): + image = preprocess_image(image) + if image is not None: + image = image.to(device=device, dtype=dtype) + if isinstance(mask_image, Image.Image): + mask_image = preprocess_mask(mask_image) + mask_image = mask_image.to(device=device, dtype=dtype) + if mask_image is not None: + mask, masked_image, _ = prepare_mask_and_masked_image( + image, mask_image, height, width + ) + mask, _ = prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + dtype, + device, + do_classifier_free_guidance, + self.vae, + self.vae_scale_factor, + self.vae.config.scaling_factor, # type: ignore + generator=generator, + ) + else: + mask = None + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) # type: ignore + timesteps, num_inference_steps = get_timesteps( + self.scheduler, + num_inference_steps, + strength, + device, + image is None or hasattr(self, "controlnet"), + ) + if isinstance(self.scheduler, KdiffusionSchedulerAdapter): + self.scheduler.timesteps = timesteps + self.scheduler.steps = num_inference_steps + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents, init_latents_orig, noise = prepare_latents( + self, # type: ignore + image, + latent_timestep, + batch_size * num_images_per_prompt, + height, + width, + dtype, + device, + generator, + None, + latents, + ) + + # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = prepare_extra_step_kwargs( + scheduler=self.scheduler, generator=generator, eta=eta, device=device + ) + + setup_scalecrafter(self.unet, scalecrafter) # type: ignore + + if hasattr(self, "adapter"): + if isinstance(self.adapter, MultiAdapter): + adapter_state = self.adapter( + adapter_input, adapter_conditioning_scale + ) + for k, v in enumerate(adapter_state): + adapter_state[k] = v + else: + adapter_state = self.adapter(adapter_input) + for k, v in enumerate(adapter_state): + adapter_state[k] = v * adapter_conditioning_scale + if num_images_per_prompt > 1: + for k, v in enumerate(adapter_state): + adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1) + if do_classifier_free_guidance: + for k, v in enumerate(adapter_state): + adapter_state[k] = torch.cat([v] * 2, dim=0) + add_text_embeds = pooled_prompt_embeds + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + (0, 0), + (height, width), + aesthetic_score, + negative_aesthetic_score, + dtype, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat( + [negative_prompt_embeds, prompt_embeds], dim=0 + ) + add_text_embeds = torch.cat( + [negative_pooled_prompt_embeds, add_text_embeds], dim=0 + ) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) # type: ignore + + cutoff = num_inference_steps * adapter_conditioning_factor + # 8. Denoising loop + j = 0 + un = self.unet + + if do_self_attention_guidance: + store_processor = sag.CrossAttnStoreProcessor() + self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor # type: ignore + + map_size = None + + def get_map_size(_, __, output): + nonlocal map_size + map_size = output[0].shape[ + -2: + ] # output.sample.shape[-2:] in older diffusers + + classify = do_classifier_free_guidance + + def do_denoise( + x: torch.Tensor, + t: torch.IntTensor, + call: Callable[..., torch.Tensor], + change_source: Callable[[Callable], None], + ) -> torch.Tensor: + nonlocal j, un, do_classifier_free_guidance + + un = modify_kohya(un, j, num_inference_steps, deepshrink) # type: ignore + un = step_scalecrafter(un, scalecrafter, j, num_inference_steps) + + tau = j / num_inference_steps + + do_classifier_free_guidance = ( + classify and tau <= config.api.cfg_uncond_tau + ) + + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([x] * 2) if do_classifier_free_guidance and not split_latents_into_two else x # type: ignore + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # type: ignore + + if j >= refiner_steps: + assert refiner_model is not None + un = refiner_model + + down_intrablock_additional_residuals = None + if hasattr(self, "adapter") and self.adapter is not None: + if j < cutoff: + assert adapter_state is not None + down_intrablock_additional_residuals = [ + state.clone() for state in adapter_state + ] + + change_source(un) + kwargs = set( + inspect.signature(un.forward).parameters.keys() # type: ignore + ) + ensure_correct_device(un) # type: ignore + + _kwargs = { + "added_cond_kwargs": { + "text_embeds": add_text_embeds, + "time_ids": add_time_ids, + }, + "down_intrablock_additional_residuals": down_intrablock_additional_residuals, + "order": j, + "drop_encode_decode": 
config.api.drop_encode_decode != "off", + } + if split_latents_into_two: + uncond, cond = prompt_embeds.chunk(2) + uncond_text, cond_text = add_text_embeds.chunk(2) + uncond_time, cond_time = add_time_ids.chunk(2) + uncond_intra, cond_intra = None, None + if down_intrablock_additional_residuals is not None: + uncond_intra, cond_intra = [], [] + for s in down_intrablock_additional_residuals: + unc, cnd = s.chunk(2) + uncond_intra.append(unc) + cond_intra.append(cnd) + + added_cond_kwargs = { + "text_embeds": cond_text, + "time_ids": cond_time, + } + added_uncond_kwargs = { + "text_embeds": uncond_text, + "time_ids": uncond_time, + } + + _kwargs.update( + { + "added_cond_kwargs": added_cond_kwargs, + "down_intrablock_additional_residuals": cond_intra, + } + ) + for kw, _ in _kwargs.copy().items(): + if kw not in kwargs: + del _kwargs[kw] + noise_pred_text = call(latent_model_input, t, cond=cond, **_kwargs) + + _kwargs.update( + { + "added_cond_kwargs": added_uncond_kwargs, + "down_intrablock_additional_residuals": uncond_intra, + } + ) + for kw, _ in _kwargs.copy().items(): + if kw not in kwargs: + del _kwargs[kw] + noise_pred_uncond = call( + latent_model_input, t, cond=uncond, **_kwargs + ) + else: + for kw, _ in _kwargs.copy().items(): + if kw not in kwargs: + del _kwargs[kw] + noise_pred = call(latent_model_input, t, cond=prompt_embeds, **_kwargs) # type: ignore + + un, noise_pred_vanilla = post_scalecrafter( + self.unet, + scalecrafter, + j, + num_inference_steps, + call, + latent_model_input, + t, + cond=prompt_embeds, + down_intrablock_additional_residuals=down_intrablock_additional_residuals, + ) + + # perform guidance + if do_classifier_free_guidance: + if not split_latents_into_two: + # if isinstance(noise_pred, list): # type: ignore + # noise_pred = noise_pred[0] + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) # type: ignore + noise_pred = calculate_cfg( + j, noise_pred_text, noise_pred_uncond, guidance_scale, t, additional_pred=noise_pred_vanilla # type: ignore + ) + + if do_self_attention_guidance: + if not do_classifier_free_guidance: + noise_pred_uncond = noise_pred # type: ignore + noise_pred += sag.calculate_sag( # type: ignore + self, + call, + store_processor, # type: ignore + x, + noise_pred_uncond, # type: ignore + t, + map_size, # type: ignore + prompt_embeds, + self_attention_scale, + guidance_scale, + config.api.load_dtype, + ) + + if not isinstance(self.scheduler, KdiffusionSchedulerAdapter): + # compute the previous noisy sample x_t -> x_t-1 + x = self.scheduler.step( # type: ignore + noise_pred, t.to(noise_pred.device), x.to(noise_pred.device), **extra_step_kwargs # type: ignore + ).prev_sample # type: ignore + else: + x = noise_pred # type: ignore + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise( # type: ignore + init_latents_orig, noise, torch.tensor([t]) # type: ignore + ) + x = (init_latents_proper * mask) + (x * (1 - mask)) # type: ignore + j += 1 + un = postprocess_kohya(un) + return x + + if do_self_attention_guidance: + pass + + ensure_correct_device(self.unet) + latents = latents.to(dtype=dtype) # type: ignore + if init_latents_orig is not None: + init_latents_orig = init_latents_orig.to(dtype=dtype) + with ExitStack() as gs: + if do_self_attention_guidance: + gs.enter_context(self.unet.mid_block.attentions[0].register_forward_hook(get_map_size)) # type: ignore + + if isinstance(self.scheduler, KdiffusionSchedulerAdapter): + latents = self.scheduler.do_inference( + latents, # type: ignore + 
generator=generator, + call=self.unet, # type: ignore + apply_model=do_denoise, + callback=callback, + callback_steps=callback_steps, + ) + else: + s = self.unet + + def change(src): + nonlocal s + s = src + + def _call(*args, **kwargs): + if len(args) == 3: + encoder_hidden_states = args[-1] + args = args[:2] + if kwargs.get("cond", None) is not None: + encoder_hidden_states = kwargs.pop("cond") + ret = s( + *args, + encoder_hidden_states=encoder_hidden_states, # type: ignore + return_dict=False, + **kwargs, + ) + if isinstance(s, UNet2DConditionModel): + return ret[0] + return ret + + for i, t in enumerate(tqdm(timesteps, desc="SDXL")): + latents = do_denoise(latents, t, _call, change) # type: ignore + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) # type: ignore + if ( + is_cancelled_callback is not None + and is_cancelled_callback() + ): + return None + + # 9. Post-processing + if output_type == "latent": + unload_all() + return latents, False + + image = full_vae(latents, self.vae, height=height, width=width) # type: ignore + + # 11. Convert to PIL + if output_type == "pil": + image = numpy_to_pil(image) # type: ignore + + unload_all() + + if not return_dict: + return image, False + + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=False # type: ignore + ) diff --git a/core/inference/sdxl/sdxl.py b/core/inference/sdxl/sdxl.py new file mode 100644 index 000000000..e1fbd8ffd --- /dev/null +++ b/core/inference/sdxl/sdxl.py @@ -0,0 +1,668 @@ +import logging +from pathlib import Path +from typing import Any, List, Optional, Tuple, Union + +import torch +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from diffusers.models.unet_2d_condition import UNet2DConditionModel +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import ( + StableDiffusionXLPipeline, +) +from PIL import Image, ImageOps +from safetensors.torch import load_file +from tqdm import tqdm +from transformers.models.clip.modeling_clip import ( + CLIPTextModel, + CLIPTextModelWithProjection, +) +from transformers.models.clip.tokenization_clip import CLIPTokenizer + +from api import websocket_manager +from api.websockets import Data +from api.websockets.notification import Notification +from core import shared +from core.config import config +from core.files import get_full_model_path +from core.flags import SDXLFlag, SDXLRefinerFlag, DeepshrinkFlag, ScalecrafterFlag +from core.inference.base_model import InferenceModel +from core.inference.functions import ( + convert_vaept_to_diffusers, + get_output_type, + load_pytorch_pipeline, +) +from core.inference.utilities import change_scheduler, create_generator +from core.inference_callbacks import callback +from core.optimizations import optimize_vae +from core.types import ( + ADetailerQueueEntry, + Backend, + Img2ImgQueueEntry, + InpaintQueueEntry, + Job, + SigmaScheduler, + Txt2ImgQueueEntry, +) +from core.utils import convert_images_to_base64_grid, convert_to_image, resize + +from .pipeline import StableDiffusionXLLongPromptWeightingPipeline + +logger = logging.getLogger(__name__) + + +class SDXLStableDiffusion(InferenceModel): + "High level model wrapper for SDXL models" + + def __init__( + self, + model_id: str, + device: str = "cuda", + autoload: bool = True, + bare: bool = False, + ) -> None: + super().__init__(model_id, device) + + self.backend: Backend = "PyTorch" + self.type = "SDXL" + self.bare: bool = bare + + # Components + 
self.vae: AutoencoderKL + self.unet: UNet2DConditionModel + self.text_encoder: CLIPTextModel + self.text_encoder_2: CLIPTextModelWithProjection + self.tokenizer: CLIPTokenizer + self.tokenizer_2: CLIPTokenizer + self.force_zeros: bool + self.aesthetic_score: bool + self.scheduler: Any + self.final_offload_hook: Any = None + + self.image_encoder: Any + + self.vae_path: str = "default" + self.unload_loras: List[str] = [] + self.unload_lycoris: List[str] = [] + + self.textual_inversions: List[str] = [] + + if autoload: + self.load() + + def load(self): + "Load the model from HuggingFace" + + logger.info(f"Loading {self.model_id} with {config.api.data_type}") + + pipe = load_pytorch_pipeline( + self.model_id, + device=self.device, + optimize=not self.bare, + ) + + self.vae = pipe.vae # type: ignore + self.unet = pipe.unet # type: ignore + self.text_encoder = pipe.text_encoder # type: ignore + self.text_encoder_2 = pipe.text_encoder_2 # type: ignore + self.tokenizer = pipe.tokenizer # type: ignore + self.tokenizer_2 = pipe.tokenizer_2 # type: ignore + self.scheduler = pipe.scheduler # type: ignore + if hasattr(pipe.config, "requires_aesthetics_score"): + self.aesthetic_score = pipe.config.requires_aesthetics_score # type: ignore + else: + self.aesthetic_score = False + self.force_zeros = pipe.config.force_zeros_for_empty_prompt # type: ignore + + if not self.bare: + # Autoload textual inversions + for textural_inversion in config.api.autoloaded_textual_inversions: + try: + self.load_textual_inversion(textural_inversion) + except Exception as e: + logger.warning( + f"({e.__class__.__name__}) Failed to load textual inversion {textural_inversion}: {e}" + ) + websocket_manager.broadcast_sync( + Notification( + severity="error", + message=f"Failed to load textual inversion: {textural_inversion}", + title="Autoload Error", + ) + ) + + # Free up memory + del pipe + self.memory_cleanup() + + def change_vae(self, vae: str) -> None: + "Change the vae to the one specified" + + if self.vae_path == "default": + setattr(self, "original_vae", self.vae) + + old_vae = getattr(self, "original_vae") + dtype = config.api.load_dtype + device = self.unet.device + + if hasattr(self.text_encoder, "v_offload_device"): + device = torch.device("cpu") + + if vae == "default": + self.vae = old_vae + else: + full_path = get_full_model_path(vae) + if full_path.is_dir(): + self.vae = AutoencoderKL.from_pretrained(full_path).to( # type: ignore + device=device, dtype=dtype + ) + else: + self.vae = convert_vaept_to_diffusers(full_path.as_posix()).to( + device=device, dtype=dtype + ) + + # Check if text_encoder has v_offload_device, because it always + # gets wholly offloaded instead of being sequentially offloaded + if hasattr(self.text_encoder, "v_offload_device"): + from core.optimizations.offload import set_offload + + self.vae = set_offload(self.vae, torch.device(config.api.device)) # type: ignore + self.vae = optimize_vae(self.vae) # type: ignore + self.vae_path = vae + + def unload(self) -> None: + "Unload the model from memory" + + del ( + self.vae, + self.unet, + self.text_encoder, + self.text_encoder_2, + self.tokenizer, + self.tokenizer_2, + self.scheduler, + self.aesthetic_score, + self.force_zeros, + ) + + if hasattr(self, "original_vae"): + del self.original_vae # type: ignore + + self.memory_cleanup() + + def load_refiner(self, refiner: SDXLRefinerFlag, job) -> Tuple[Any, Any]: + from core.shared_dependent import gpu + + unload = False + + if refiner.model not in gpu.loaded_models: + gpu.load_model(refiner.model, 
"PyTorch", "SDXL") + unload = True + model: SDXLStableDiffusion = gpu.loaded_models[refiner.model] # type: ignore + if config.api.clear_memory_policy == "always": + self.memory_cleanup() + pipe = model.create_pipe( + scheduler=(job.data.scheduler, job.data.sigmas), + sampler_settings=job.data.sampler_settings, + ) + unl = lambda: "" + if unload: + + def unll(): + nonlocal model, refiner + + del model + gpu.unload(refiner.model) + + unl = unll + + return pipe, unl + + def create_pipe( + self, + controlnet: Optional[str] = "", + scheduler: Optional[Tuple[Any, SigmaScheduler]] = None, + sampler_settings: Optional[dict] = None, + ) -> StableDiffusionXLLongPromptWeightingPipeline: + "Create an LWP-XL pipeline" + + # self.manage_optional_components(target_controlnet=controlnet or "") + + pipe = StableDiffusionXLLongPromptWeightingPipeline( + parent=self, + vae=self.vae, + unet=self.unet, # type: ignore + text_encoder=self.text_encoder, + text_encoder_2=self.text_encoder_2, + tokenizer=self.tokenizer, + tokenizer_2=self.tokenizer_2, + scheduler=self.scheduler, + force_zeros=self.force_zeros, + aesthetic_score=self.aesthetic_score, + ) + pipe.parent = self + + if scheduler: + change_scheduler( + model=pipe, + scheduler=scheduler[0], # type: ignore + sigma_type=scheduler[1], + sampler_settings=sampler_settings, + ) + + return pipe + + def txt2img(self, job: Txt2ImgQueueEntry) -> Union[List[Image.Image], torch.Tensor]: + "Generate an image from a prompt" + + pipe = self.create_pipe( + scheduler=(job.data.scheduler, job.data.sigmas), + sampler_settings=job.data.sampler_settings, + ) + generator = create_generator(job.data.seed) + + total_images: Union[List[Image.Image], torch.Tensor] = [] + output_type = get_output_type(job) + + deepshrink = None + if "deepshrink" in job.flags: + deepshrink = DeepshrinkFlag.from_dict(job.flags["deepshrink"]) + + scalecrafter = None + if "scalecrafter" in job.flags: + scalecrafter = ScalecrafterFlag.from_dict(job.flags["scalecrafter"]) + + for _ in tqdm(range(job.data.batch_count), desc="Queue", position=1): + xl_flag = None + if "sdxl" in job.flags: + xl_flag = SDXLFlag.from_dict(job.flags["sdxl"]) + + refiner = None + if "refiner" in job.flags: + output_type = "latent" + refiner = SDXLRefinerFlag.from_dict(job.flags["refiner"]) + + refiner_model, unload = None, lambda: "" + if config.api.sdxl_refiner == "joint" and refiner is not None: + refiner_model, unload = self.load_refiner(refiner, job) + + original_size = None + if xl_flag: + original_size = [ + xl_flag.original_size.height, + xl_flag.original_size.width, + ] + + data = pipe( + original_size=original_size, + generator=generator, + prompt=job.data.prompt, + height=job.data.height, + width=job.data.width, + num_inference_steps=job.data.steps, + guidance_scale=job.data.guidance_scale, + negative_prompt=job.data.negative_prompt, + output_type=output_type, + callback=callback, + num_images_per_prompt=job.data.batch_size, + seed=job.data.seed, + self_attention_scale=job.data.self_attention_scale, + prompt_expansion_settings=job.data.prompt_to_prompt_settings, + refiner=refiner, + refiner_model=refiner_model, + deepshrink=deepshrink, + scalecrafter=scalecrafter, + ) + + if refiner is not None and config.api.sdxl_refiner == "separate": + latents: torch.FloatTensor = data[0] # type: ignore + + refiner_model, unload = self.load_refiner(refiner, job) + + data = pipe( + aesthetic_score=refiner.aesthetic_score, + negative_aesthetic_score=refiner.negative_aesthetic_score, + original_size=original_size, + image=latents, + 
generator=generator, + prompt=job.data.prompt, + height=job.data.height, + width=job.data.width, + strength=refiner.strength, + num_inference_steps=refiner.steps, + guidance_scale=job.data.guidance_scale, + negative_prompt=job.data.negative_prompt, + callback=callback, + num_images_per_prompt=job.data.batch_size, + return_dict=False, + output_type=output_type, + seed=job.data.seed, + self_attention_scale=job.data.self_attention_scale, + prompt_expansion_settings=job.data.prompt_to_prompt_settings, + ) + + images: Union[List[Image.Image], torch.Tensor] = data[0] # type: ignore + + if not isinstance(images, List): + total_images = images + else: + assert isinstance(total_images, List) + total_images.extend(images) + + unload() + + if isinstance(total_images, List): + websocket_manager.broadcast_sync( + data=Data( + data_type=shared.current_method or "txt2img", + data={ + "progress": 0, + "current_step": 0, + "total_steps": 0, + "image": convert_images_to_base64_grid( + total_images, quality=90, image_format="webp" # type: ignore + ), + }, + ) + ) + + return total_images + + def img2img(self, job: Img2ImgQueueEntry) -> Union[List[Image.Image], torch.Tensor]: + "Generate an image from an image" + + pipe = self.create_pipe( + scheduler=(job.data.scheduler, job.data.sigmas), + sampler_settings=job.data.sampler_settings, + ) + generator = create_generator(job.data.seed) + + # Preprocess the image + if isinstance(job.data.image, (str, bytes, Image.Image)): + input_image = convert_to_image(job.data.image) + input_image = resize(input_image, job.data.width, job.data.height) + else: + input_image = job.data.image + + total_images: Union[List[Image.Image], torch.Tensor] = [] + output_type = get_output_type(job) + + xl_flag = None + if "sdxl" in job.flags: + xl_flag = SDXLFlag.from_dict(job.flags["sdxl"]) + + original_size = None + if xl_flag: + original_size = [ + xl_flag.original_size.height, + xl_flag.original_size.width, + ] + + deepshrink = None + if "deepshrink" in job.flags: + deepshrink = DeepshrinkFlag.from_dict(job.flags["deepshrink"]) + + for _ in tqdm(range(job.data.batch_count), desc="Queue", position=1): + data = pipe( + original_size=original_size, + generator=generator, + prompt=job.data.prompt, + image=input_image, # type: ignore + num_inference_steps=job.data.steps, + guidance_scale=job.data.guidance_scale, + negative_prompt=job.data.negative_prompt, + width=job.data.width, + height=job.data.height, + output_type=output_type, + callback=callback, + strength=job.data.strength, + return_dict=False, + num_images_per_prompt=job.data.batch_size, + seed=job.data.seed, + self_attention_scale=job.data.self_attention_scale, + prompt_expansion_settings=job.data.prompt_to_prompt_settings, + deepshrink=deepshrink, + ) + + images: Union[List[Image.Image], torch.Tensor] = data[0] # type: ignore + + if not isinstance(images, List): + total_images = images + else: + assert isinstance(total_images, List) + total_images.extend(images) + + if isinstance(total_images, List): + websocket_manager.broadcast_sync( + data=Data( + data_type=shared.current_method or "img2img", + data={ + "progress": 0, + "current_step": 0, + "total_steps": 0, + "image": convert_images_to_base64_grid( + total_images, quality=90, image_format="webp" # type: ignore + ), + }, + ) + ) + + return total_images + + def inpaint(self, job: InpaintQueueEntry) -> Union[List[Image.Image], torch.Tensor]: + "Generate an image from an image" + + pipe = self.create_pipe( + scheduler=(job.data.scheduler, job.data.sigmas), + 
sampler_settings=job.data.sampler_settings, + ) + generator = create_generator(job.data.seed) + + # Preprocess images + input_image = convert_to_image(job.data.image).convert("RGB") + input_image = resize(input_image, job.data.width, job.data.height) + + input_mask_image = convert_to_image(job.data.mask_image).convert("RGB") + input_mask_image = ImageOps.invert(input_mask_image) + input_mask_image = resize(input_mask_image, job.data.width, job.data.height) + + total_images: Union[List[Image.Image], torch.Tensor] = [] + output_type = get_output_type(job) + + xl_flag = None + if "sdxl" in job.flags: + xl_flag = SDXLFlag.from_dict(job.flags["sdxl"]) + + original_size = None + if xl_flag: + original_size = [ + xl_flag.original_size.height, + xl_flag.original_size.width, + ] + + deepshrink = None + if "deepshrink" in job.flags: + deepshrink = DeepshrinkFlag.from_dict(job.flags["deepshrink"]) + + for _ in tqdm(range(job.data.batch_count), desc="Queue", position=1): + data = pipe( + original_size=original_size, + generator=generator, + prompt=job.data.prompt, + image=input_image, + mask_image=input_mask_image, + num_inference_steps=job.data.steps, + guidance_scale=job.data.guidance_scale, + negative_prompt=job.data.negative_prompt, + output_type=output_type, + callback=callback, + return_dict=False, + num_images_per_prompt=job.data.batch_size, + width=job.data.width, + height=job.data.height, + seed=job.data.seed, + strength=job.data.strength, + self_attention_scale=job.data.self_attention_scale, + prompt_expansion_settings=job.data.prompt_to_prompt_settings, + deepshrink=deepshrink, + ) + + images: Union[List[Image.Image], torch.Tensor] = data[0] # type: ignore + + if not isinstance(images, List): + total_images = images + else: + assert isinstance(total_images, List) + total_images.extend(images) + + if isinstance(total_images, List): + websocket_manager.broadcast_sync( + data=Data( + data_type=shared.current_method or "inpainting", + data={ + "progress": 0, + "current_step": 0, + "total_steps": 0, + "image": convert_images_to_base64_grid( + total_images, quality=90, image_format="webp" # type: ignore + ), + }, + ) + ) + + return total_images + + def adetailer( + self, + job: ADetailerQueueEntry, + ): + from ..adetailer.adetailer import ADetailer + + data = job.data + assert data is not None + + entry = InpaintQueueEntry( + data=data, + model=job.model, + save_image=job.save_image, + ) + + output = ADetailer().generate( + fn=self.inpaint, + inpaint_entry=entry, + mask_dilation=job.mask_dilation, + mask_blur=job.mask_blur, + mask_padding=job.mask_padding, + upscale=job.upscale, + iterations=job.iterations, + ) + + return [*output.images, *output.init_images] + + def generate(self, job: Job) -> Union[List[Image.Image], torch.Tensor]: + "Generate images from the queue" + + logging.info(f"Adding job {job.data.id} to queue") + self.memory_cleanup() + + try: + if isinstance(job, Txt2ImgQueueEntry): + images = self.txt2img(job) + elif isinstance(job, Img2ImgQueueEntry): + images = self.img2img(job) + elif isinstance(job, InpaintQueueEntry): + images = self.inpaint(job) + elif isinstance(job, ADetailerQueueEntry): + images = self.adetailer(job) + else: + raise ValueError("Invalid job type for this pipeline") + except Exception as e: + self.memory_cleanup() + raise e + if len(self.unload_loras) != 0: + for lora in self.unload_loras: + try: + self.lora_injector.remove_lora(lora) # type: ignore + logger.debug(f"Unloading LoRA: {lora}") + except KeyError: + pass + self.unload_loras.clear() + if 
len(self.unload_lycoris) != 0: # type: ignore + for lora in self.unload_lycoris: # type: ignore + try: + self.lora_injector.remove_lycoris(lora) # type: ignore + logger.debug(f"Unloading LyCORIS: {lora}") + except KeyError: + pass + + # Clean memory and return images + self.memory_cleanup() + return images + + def save(self, path: str = "converted", safetensors: bool = False): + "Dump current pipeline to specified path" + + pipe = StableDiffusionXLPipeline( + vae=self.vae, + unet=self.unet, + text_encoder_2=self.text_encoder_2, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + tokenizer_2=self.tokenizer_2, + scheduler=self.scheduler, + ) + + pipe.save_pretrained(path, safe_serialization=safetensors) + + def load_textual_inversion(self, textual_inversion: str): + "Inject a textual inversion model into the pipeline" + + logger.info( + f"Loading textual inversion model {textual_inversion} onto {self.model_id}..." + ) + + if any(textual_inversion in lora for lora in self.textual_inversions): + logger.info( + f"Textual inversion model {textual_inversion} already loaded onto {self.model_id}" + ) + return + + pipe = StableDiffusionXLPipeline( + vae=self.vae, + unet=self.unet, + text_encoder=self.text_encoder, + text_encoder_2=self.text_encoder_2, + tokenizer=self.tokenizer, + tokenizer_2=self.tokenizer_2, + scheduler=self.scheduler, + ) + pipe.parent = self + + token = Path(textual_inversion).stem + logger.info(f"Loading token {token} for textual inversion model") + + state_dict = load_file(textual_inversion) + + try: + pipe.load_textual_inversion( + state_dict["clip_g"], # type: ignore + token=token, + text_encoder=pipe.text_encoder_2, + tokenizer=pipe.tokenizer_2, + ) + pipe.load_textual_inversion( + state_dict["clip_l"], # type: ignore + token=token, + text_encoder=pipe.text_encoder, + tokenizer=pipe.tokenizer, + ) + except KeyError: + logger.info(f"Assuming {textual_inversion} is for non SDXL model, skipping") + return + + self.textual_inversions.append(textual_inversion) + logger.info(f"Textual inversion model {textual_inversion} loaded successfully") + logger.debug(f"All added tokens: {self.tokenizer.added_tokens_encoder}") + + def tokenize(self, text: str): + "Return the vocabulary of the tokenizer" + + return [i.replace("", "") for i in self.tokenizer.tokenize(text=text)] diff --git a/core/inference/utilities/__init__.py b/core/inference/utilities/__init__.py index 4951f184b..0fbbceef8 100644 --- a/core/inference/utilities/__init__.py +++ b/core/inference/utilities/__init__.py @@ -6,6 +6,7 @@ prepare_latents, prepare_mask_and_masked_image, prepare_mask_latents, + preprocess_adapter_image, preprocess_image, preprocess_mask, scale_latents, @@ -15,3 +16,13 @@ from .random import create_generator, randn, randn_like from .vae import taesd, full_vae, cheap_approximation, numpy_to_pil, decode_latents from .prompt_expansion import download_model, expand +from .cfg import calculate_cfg +from .unet_patches import _dummy +from .kohya_hires import post_process as postprocess_kohya, modify_unet as modify_kohya +from .scalecrafter import ( + ScalecrafterSettings, + find_config_closest_to as get_scalecrafter_config, + post_scale as post_scalecrafter, + scale as step_scalecrafter, + scale_setup as setup_scalecrafter, +) diff --git a/core/inference/utilities/animatediff/__init__.py b/core/inference/utilities/animatediff/__init__.py new file mode 100644 index 000000000..3031ce34d --- /dev/null +++ b/core/inference/utilities/animatediff/__init__.py @@ -0,0 +1,142 @@ +# AnimateDiff is NCFHW 
(batch, channels, frames, height, width) + +from typing import Optional + +import numpy as np + +from .freeinit import freq_mix_3d as freeinit_mix, get_freq_filter as freeinit_filter +from .models.unet import UNet3DConditionModel +from .pia.masking import prepare_mask_coef_by_statistics + + +def ordered_halving(val): + "Returns fraction that has denominator that is a power of 2" + + bin_str = f"{val:064b}" + bin_flip = bin_str[::-1] + as_int = int(bin_flip, 2) + final = as_int / (1 << 64) + return final + + +# Generator that returns lists of latent indeces to diffuse on +def uniform( + step: int = 0, + num_frames: int = 0, + context_size: Optional[int] = None, + context_stride: int = 3, + context_overlap: int = 4, + closed_loop: bool = True, +): + if num_frames <= context_size: # type: ignore + yield list(range(num_frames)) + return + + context_stride = min( + context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1 # type: ignore + ) + + for context_step in 1 << np.arange(context_stride): + pad = int(round(num_frames * ordered_halving(step))) + for j in range( + int(ordered_halving(step) * context_step) + pad, + num_frames + pad + (0 if closed_loop else -context_overlap), + (context_size * context_step - context_overlap), + ): + yield [ + e % num_frames + for e in range(j, j + context_size * context_step, context_step) + ] + + +def uniform_v2( + step: int = 0, + num_frames: int = 0, + context_size: Optional[int] = None, + context_stride: int = 3, + context_overlap: int = 4, + closed_loop: bool = True, +): + if num_frames <= context_size: # type: ignore + yield list(range(num_frames)) + return + + context_stride = min( + context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1 # type: ignore + ) + + pad = int(round(num_frames * ordered_halving(step))) + for context_step in 1 << np.arange(context_stride): + j_initial = int(ordered_halving(step) * context_step) + pad + for j in range( + j_initial, + num_frames + pad - context_overlap, + (context_size * context_step - context_overlap), + ): + if context_size * context_step > num_frames: + # On the final context_step, + # ensure no frame appears in the window twice + yield [e % num_frames for e in range(j, j + num_frames, context_step)] + continue + j = j % num_frames + if j > (j + context_size * context_step) % num_frames and not closed_loop: + yield [e for e in range(j, num_frames, context_step)] + j_stop = (j + context_size * context_step) % num_frames + # When ((num_frames % (context_size - context_overlap)+context_overlap) % context_size != 0, + # This can cause 'superflous' runs where all frames in + # a context window have already been processed during + # the first context window of this stride and step. 
+ # While the following commented if should prevent this, + # I believe leaving it in is more correct as it maintains + # the total conditional passes per frame over a large total steps + # if j_stop > context_overlap: + yield [e for e in range(0, j_stop, context_step)] + continue + yield [ + e % num_frames + for e in range(j, j + context_size * context_step, context_step) + ] + + +def uniform_constant( + step: int = 0, + num_frames: int = 0, + context_size: Optional[int] = None, + context_stride: int = 3, + context_overlap: int = 4, + closed_loop: bool = True, +): + if num_frames <= context_size: # type: ignore + yield list(range(num_frames)) + return + + context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1) # type: ignore + + for context_step in 1 << np.arange(context_stride): + pad = int(round(num_frames * ordered_halving(step))) + for j in range( + int(ordered_halving(step) * context_step) + pad, + num_frames + pad + (0 if closed_loop else -context_overlap), + (context_size * context_step - context_overlap), + ): + skip_this_window = False + prev_val = -1 + to_yield = [] + for e in range(j, j + context_size * context_step, context_step): + e = e % num_frames + if not closed_loop and e < prev_val: + skip_this_window = True + break + to_yield.append(e) + prev_val = e + if skip_this_window: + continue + yield to_yield + + +def get_context_scheduler(name: str): + return globals().get(name, nil_scheduler) + + +def nil_scheduler(*args, **kwargs): + yield 0 diff --git a/core/inference/utilities/animatediff/freeinit.py b/core/inference/utilities/animatediff/freeinit.py new file mode 100644 index 000000000..6bd9cf8b5 --- /dev/null +++ b/core/inference/utilities/animatediff/freeinit.py @@ -0,0 +1,162 @@ +import torch +import math + + +def freq_mix_3d(x, noise, LPF): + """ + Noise reinitialization. + + Args: + x: diffused latent + noise: randomly sampled noise + LPF: low pass filter + """ + # FFT + x_freq = torch.fft.fftn(x, dim=(-3, -2, -1)) + x_freq = torch.fft.fftshift(x_freq, dim=(-3, -2, -1)) + noise_freq = torch.fft.fftn(noise, dim=(-3, -2, -1)) + noise_freq = torch.fft.fftshift(noise_freq, dim=(-3, -2, -1)) + + # frequency mix + HPF = 1 - LPF + x_freq_low = x_freq * LPF + noise_freq_high = noise_freq * HPF + x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain + + # IFFT + x_freq_mixed = torch.fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1)) + x_mixed = torch.fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real + + return x_mixed + + +def get_freq_filter(shape, device, params: dict): + """ + Form the frequency filter for noise reinitialization. + + Args: + shape: shape of latent (B, C, T, H, W) + params: filter parameters + """ + if params["method"] == "gaussian": + return gaussian_low_pass_filter( + shape=shape, d_s=params["d_s"], d_t=params["d_t"] + ).to(device) + elif params["method"] == "ideal": + return ideal_low_pass_filter( + shape=shape, d_s=params["d_s"], d_t=params["d_t"] + ).to(device) + elif params["method"] == "box": + return box_low_pass_filter( + shape=shape, d_s=params["d_s"], d_t=params["d_t"] + ).to(device) + elif params["method"] == "butterworth": + return butterworth_low_pass_filter( + shape=shape, n=params["n"], d_s=params["d_s"], d_t=params["d_t"] + ).to(device) + else: + raise NotImplementedError + + +def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25): + """ + Compute the gaussian low pass filter mask. 
+ + Args: + shape: shape of the filter (volume) + d_s: normalized stop frequency for spatial dimensions (0.0-1.0) + d_t: normalized stop frequency for temporal dimension (0.0-1.0) + """ + T, H, W = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + if d_s == 0 or d_t == 0: + return mask + for t in range(T): + for h in range(H): + for w in range(W): + d_square = ( + ((d_s / d_t) * (2 * t / T - 1)) ** 2 + + (2 * h / H - 1) ** 2 + + (2 * w / W - 1) ** 2 + ) + mask[..., t, h, w] = math.exp(-1 / (2 * d_s**2) * d_square) + return mask + + +def butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25): + """ + Compute the butterworth low pass filter mask. + + Args: + shape: shape of the filter (volume) + n: order of the filter, larger n ~ ideal, smaller n ~ gaussian + d_s: normalized stop frequency for spatial dimensions (0.0-1.0) + d_t: normalized stop frequency for temporal dimension (0.0-1.0) + """ + T, H, W = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + if d_s == 0 or d_t == 0: + return mask + for t in range(T): + for h in range(H): + for w in range(W): + d_square = ( + ((d_s / d_t) * (2 * t / T - 1)) ** 2 + + (2 * h / H - 1) ** 2 + + (2 * w / W - 1) ** 2 + ) + mask[..., t, h, w] = 1 / (1 + (d_square / d_s**2) ** n) + return mask + + +def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25): + """ + Compute the ideal low pass filter mask. + + Args: + shape: shape of the filter (volume) + d_s: normalized stop frequency for spatial dimensions (0.0-1.0) + d_t: normalized stop frequency for temporal dimension (0.0-1.0) + """ + T, H, W = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + if d_s == 0 or d_t == 0: + return mask + for t in range(T): + for h in range(H): + for w in range(W): + d_square = ( + ((d_s / d_t) * (2 * t / T - 1)) ** 2 + + (2 * h / H - 1) ** 2 + + (2 * w / W - 1) ** 2 + ) + mask[..., t, h, w] = 1 if d_square <= d_s * 2 else 0 + return mask + + +def box_low_pass_filter(shape, d_s=0.25, d_t=0.25): + """ + Compute the ideal low pass filter mask (approximated version). 
+ + Args: + shape: shape of the filter (volume) + d_s: normalized stop frequency for spatial dimensions (0.0-1.0) + d_t: normalized stop frequency for temporal dimension (0.0-1.0) + """ + T, H, W = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + if d_s == 0 or d_t == 0: + return mask + + threshold_s = round(int(H // 2) * d_s) + threshold_t = round(T // 2 * d_t) + + cframe, crow, ccol = T // 2, H // 2, W // 2 + mask[ + ..., + cframe - threshold_t : cframe + threshold_t, + crow - threshold_s : crow + threshold_s, + ccol - threshold_s : ccol + threshold_s, + ] = 1.0 + + return mask diff --git a/core/inference/utilities/animatediff/models/__init__.py b/core/inference/utilities/animatediff/models/__init__.py new file mode 100644 index 000000000..752607f84 --- /dev/null +++ b/core/inference/utilities/animatediff/models/__init__.py @@ -0,0 +1 @@ +MMV2_DIM_KEY = "up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe" diff --git a/core/inference/utilities/animatediff/models/attention.py b/core/inference/utilities/animatediff/models/attention.py new file mode 100644 index 000000000..d11252fcb --- /dev/null +++ b/core/inference/utilities/animatediff/models/attention.py @@ -0,0 +1,310 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py + +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput # type: ignore +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import FeedForward, AdaLayerNorm # type: ignore +from diffusers.models.attention_processor import Attention + +from einops import rearrange, repeat + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class Transformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True # type: ignore + ) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) # type: ignore + else: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 # type: ignore + ) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + 
cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) # type: ignore + else: + self.proj_out = nn.Conv2d( + inner_dim, in_channels, kernel_size=1, stride=1, padding=0 # type: ignore + ) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + return_dict: bool = True, + ): + # Input + assert ( + hidden_states.dim() == 5 + ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + encoder_hidden_states = repeat( + encoder_hidden_states, "b n c -> (b f) n c", f=video_length + ) + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * weight, inner_dim + ) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * weight, inner_dim + ) + hidden_states = self.proj_in(hidden_states) + + # Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + video_length=video_length, + ) + + # Output + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + + output = hidden_states + residual + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + self.unet_use_cross_frame_attention = unet_use_cross_frame_attention + self.unet_use_temporal_attention = unet_use_temporal_attention + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + self.norm1 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) # type: ignore + if self.use_ada_layer_norm + else nn.LayerNorm(dim) + ) + + # Cross-Attn + if cross_attention_dim is not None: + self.attn2 = 
Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + else: + self.attn2 = None + + if cross_attention_dim is not None: + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) # type: ignore + if self.use_ada_layer_norm + else nn.LayerNorm(dim) + ) + else: + self.norm2 = None + + # Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + + # Temp-Attn + assert unet_use_temporal_attention is not None + if unet_use_temporal_attention: + self.attn_temp = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + nn.init.zeros_(self.attn_temp.to_out[0].weight.data) + self.norm_temp = ( + AdaLayerNorm(dim, num_embeds_ada_norm) # type: ignore + if self.use_ada_layer_norm + else nn.LayerNorm(dim) + ) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + attention_mask=None, + video_length=None, + ): + # SparseCausal-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm1(hidden_states) + ) + if self.unet_use_cross_frame_attention: + hidden_states = ( + self.attn1( + norm_hidden_states, + attention_mask=attention_mask, + video_length=video_length, + ) + + hidden_states + ) + else: + hidden_states = ( + self.attn1(norm_hidden_states, attention_mask=attention_mask) + + hidden_states + ) + + if self.attn2 is not None: + # Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) # type: ignore + if self.use_ada_layer_norm + else self.norm2(hidden_states) # type: ignore + ) + hidden_states = ( + self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + + hidden_states + ) + + # Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + # Temporal-Attention + if self.unet_use_temporal_attention: + d = hidden_states.shape[1] + hidden_states = rearrange( + hidden_states, "(b f) d c -> (b d) f c", f=video_length + ) + norm_hidden_states = ( + self.norm_temp(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm_temp(hidden_states) + ) + hidden_states = self.attn_temp(norm_hidden_states) + hidden_states + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states diff --git a/core/inference/utilities/animatediff/models/motion_module.py b/core/inference/utilities/animatediff/models/motion_module.py new file mode 100644 index 000000000..6f8450979 --- /dev/null +++ b/core/inference/utilities/animatediff/models/motion_module.py @@ -0,0 +1,391 @@ +from dataclasses import dataclass + +import torch +from torch import nn +from diffusers.utils import BaseOutput # type: ignore +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import FeedForward +from diffusers.models.attention_processor import ( + Attention as CrossAttention, + AttnProcessor, + AttnProcessor2_0, + XFormersAttnProcessor, +) + +from einops import rearrange, repeat +import math + + +def zero_module(module): + # Zero out the parameters of a module and return it. 
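To illustrate why zero_module is applied to the temporal transformer's proj_out (zero_initialize=True in VanillaTemporalModule further down): with a zeroed output projection the residual connection contributes nothing, so the freshly inflated UNet behaves like the underlying 2D model until real motion weights are loaded. A minimal sketch, assuming only torch:

import torch
from torch import nn

proj_out = nn.Linear(8, 8)
for p in proj_out.parameters():  # same zeroing that zero_module performs
    p.detach().zero_()

x = torch.randn(2, 8)
# With zeroed weights and bias the residual branch is a no-op: x + proj_out(x) == x.
assert torch.allclose(x + proj_out(x), x)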
+ for p in module.parameters(): + p.detach().zero_() + return module + + +@dataclass +class TemporalTransformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict): + if motion_module_type == "Vanilla": + return VanillaTemporalModule( + in_channels=in_channels, + **motion_module_kwargs, + ) + else: + raise ValueError + + +class VanillaTemporalModule(nn.Module): + def __init__( + self, + in_channels, + num_attention_heads=8, + num_transformer_block=2, + attention_block_types=("Temporal_Self", "Temporal_Self"), + cross_frame_attention_mode=None, + temporal_position_encoding=False, + temporal_position_encoding_max_len=24, + temporal_attention_dim_div=1, + zero_initialize=True, + ): + super().__init__() + + self.temporal_transformer = TemporalTransformer3DModel( + in_channels=in_channels, + num_attention_heads=num_attention_heads, + attention_head_dim=in_channels + // num_attention_heads + // temporal_attention_dim_div, + num_layers=num_transformer_block, + attention_block_types=attention_block_types, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + + if zero_initialize: + self.temporal_transformer.proj_out = zero_module( + self.temporal_transformer.proj_out + ) + + def forward( + self, + input_tensor, + temb, + encoder_hidden_states, + attention_mask=None, + anchor_frame_idx=None, + ): + hidden_states = input_tensor + hidden_states = self.temporal_transformer( + hidden_states, encoder_hidden_states, attention_mask + ) + + output = hidden_states + return output + + +class TemporalTransformer3DModel(nn.Module): + def __init__( + self, + in_channels, + num_attention_heads, + attention_head_dim, + num_layers, + attention_block_types=( + "Temporal_Self", + "Temporal_Self", + ), + dropout=0.0, + norm_num_groups=32, + cross_attention_dim=768, + activation_fn="geglu", + attention_bias=False, + upcast_attention=False, + cross_frame_attention_mode=None, + temporal_position_encoding=False, + temporal_position_encoding_max_len=24, + ): + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + TemporalTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + attention_block_types=attention_block_types, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + upcast_attention=upcast_attention, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + for d in range(num_layers) + ] + ) + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + assert ( + hidden_states.dim() == 5 + ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 
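The temporal blocks repeatedly fold the frame axis into the batch axis so that 2D layers (norms, projections, attention) can process every frame as if it were an independent image, then unfold afterwards. A small, self-contained shape check of that rearrange round trip, assuming only torch and einops:

import torch
from einops import rearrange

video = torch.randn(2, 4, 16, 32, 32)  # (batch, channels, frames, height, width)
frames_as_batch = rearrange(video, "b c f h w -> (b f) c h w")
assert frames_as_batch.shape == (2 * 16, 4, 32, 32)  # frames are now just a larger batch
restored = rearrange(frames_as_batch, "(b f) c h w -> b c f h w", f=16)
assert torch.equal(video, restored)  # the folding is lossless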
+ video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * weight, inner_dim + ) + hidden_states = self.proj_in(hidden_states) + + # Transformer Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + video_length=video_length, + ) + + # output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + + output = hidden_states + residual + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + + return output + + +class TemporalTransformerBlock(nn.Module): + def __init__( + self, + dim, + num_attention_heads, + attention_head_dim, + attention_block_types=( + "Temporal_Self", + "Temporal_Self", + ), + dropout=0.0, + norm_num_groups=32, + cross_attention_dim=768, + activation_fn="geglu", + attention_bias=False, + upcast_attention=False, + cross_frame_attention_mode=None, + temporal_position_encoding=False, + temporal_position_encoding_max_len=24, + ): + super().__init__() + + attention_blocks = [] + norms = [] + + for block_name in attention_block_types: + attention_blocks.append( + VersatileAttention( + attention_mode=block_name.split("_")[0], + cross_attention_dim=cross_attention_dim + if block_name.endswith("_Cross") + else None, + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + ) + norms.append(nn.LayerNorm(dim)) + + self.attention_blocks = nn.ModuleList(attention_blocks) + self.norms = nn.ModuleList(norms) + + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.ff_norm = nn.LayerNorm(dim) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + video_length=None, + ): + for attention_block, norm in zip(self.attention_blocks, self.norms): + norm_hidden_states = norm(hidden_states) + hidden_states = ( + attention_block( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states + if attention_block.is_cross_attention + else None, + video_length=video_length, + ) + + hidden_states + ) + + hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states + + output = hidden_states + return output + + +class PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.0, max_len=24): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) + ) + pe = torch.zeros(1, max_len, d_model) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x): + x = x + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class VersatileAttention(CrossAttention): + def __init__( + self, + attention_mode=None, + cross_frame_attention_mode=None, + temporal_position_encoding=False, + 
temporal_position_encoding_max_len=24, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + assert attention_mode == "Temporal" + + self.attention_mode = attention_mode + self.is_cross_attention = kwargs["cross_attention_dim"] is not None + + self.pos_encoder = ( + PositionalEncoding( + kwargs["query_dim"], + dropout=0.0, + max_len=temporal_position_encoding_max_len, + ) + if (temporal_position_encoding and attention_mode == "Temporal") + else None + ) + + def extra_repr(self): + return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" + + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, attention_op=None + ): + if use_memory_efficient_attention_xformers: + if not is_xformers_available(): + raise ModuleNotFoundError( + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + assert xformers is not None + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + + # XFormersAttnProcessor corrupts video generation and work with Pytorch 1.13. + # Pytorch 2.0.1 AttnProcessor works the same as XFormersAttnProcessor in Pytorch 1.13. + # You don't need XFormersAttnProcessor here. + processor = XFormersAttnProcessor( + attention_op=attention_op, + ) + else: + processor = AttnProcessor() + if hasattr(torch.nn.functional, "scaled_dot_product_attention"): + processor = AttnProcessor2_0() + + self.set_processor(processor) # type: ignore + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + video_length=None, + **cross_attention_kwargs, + ): + if self.attention_mode == "Temporal": + d = hidden_states.shape[1] + hidden_states = rearrange( + hidden_states, "(b f) d c -> (b d) f c", f=video_length + ) + + if self.pos_encoder is not None: + hidden_states = self.pos_encoder(hidden_states) + + encoder_hidden_states = ( + repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) + if encoder_hidden_states is not None + else encoder_hidden_states + ) + else: + raise NotImplementedError + + hidden_states = self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + if self.attention_mode == "Temporal": + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states diff --git a/core/inference/utilities/animatediff/models/resnet.py b/core/inference/utilities/animatediff/models/resnet.py new file mode 100644 index 000000000..45397933d --- /dev/null +++ b/core/inference/utilities/animatediff/models/resnet.py @@ -0,0 +1,253 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + + +class InflatedConv3d(nn.Conv2d): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", 
f=video_length) + + return x + + +class InflatedGroupNorm(nn.GroupNorm): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class Upsample3D(nn.Module): + def __init__( + self, + channels, + use_conv=False, + use_conv_transpose=False, + out_channels=None, + name="conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + if use_conv_transpose: + raise NotImplementedError + elif use_conv: + self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + raise NotImplementedError + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate( + hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest" + ) + else: + hidden_states = F.interpolate( + hidden_states, size=output_size, mode="nearest" + ) + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + # if self.use_conv: + # if self.name == "conv": + # hidden_states = self.conv(hidden_states) + # else: + # hidden_states = self.Conv2d_0(hidden_states) + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class Downsample3D(nn.Module): + def __init__( + self, channels, use_conv=False, out_channels=None, padding=1, name="conv" + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + self.conv = InflatedConv3d( + self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + raise NotImplementedError + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + raise NotImplementedError + + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + use_inflated_groupnorm=False, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + assert 
use_inflated_groupnorm is not None + if use_inflated_groupnorm: + self.norm1 = InflatedGroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + else: + self.norm1 = torch.nn.GroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + + self.conv1 = InflatedConv3d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError( + f"unknown time_embedding_norm : {self.time_embedding_norm} " + ) + + self.time_emb_proj = torch.nn.Linear( + temb_channels, time_emb_proj_out_channels + ) + else: + self.time_emb_proj = None + + if use_inflated_groupnorm: + self.norm2 = InflatedGroupNorm( + num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True + ) + else: + self.norm2 = torch.nn.GroupNorm( + num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True + ) + + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = InflatedConv3d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = ( + self.in_channels != self.out_channels + if use_in_shortcut is None + else use_in_shortcut + ) + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = InflatedConv3d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] # type: ignore + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class Mish(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) diff --git a/core/inference/utilities/animatediff/models/unet.py b/core/inference/utilities/animatediff/models/unet.py new file mode 100644 index 000000000..2de9650b2 --- /dev/null +++ b/core/inference/utilities/animatediff/models/unet.py @@ -0,0 +1,861 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py + +from copy import deepcopy +from dataclasses import dataclass +from pathlib import Path +from typing import List, Optional, Tuple, Union +import logging + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers 
import ModelMixin # type: ignore +from diffusers.utils import BaseOutput # type: ignore +from diffusers.models.embeddings import TimestepEmbedding, Timesteps + +from core.config import config +from core.inference.utilities.load import load_checkpoint +from .unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, +) +from ..pia.inflate import patch_conv3d +from .resnet import InflatedConv3d, InflatedGroupNorm +from . import MMV2_DIM_KEY + + +logger = logging.getLogger(__name__) + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( # type: ignore + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + mid_block_type: str = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( # type: ignore + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), # type: ignore + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + addition_embed_type: Optional[str] = None, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + use_inflated_groupnorm=False, + # Additional + use_motion_module=False, + motion_module_resolutions=(1, 2, 4, 8), + motion_module_mid_block=False, + motion_module_decoder_only=False, + motion_module_type=None, + motion_module_kwargs={}, + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = InflatedConv3d( + in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1) + ) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) # type: ignore + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) # type: ignore 
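The down-block loop that follows derives each block's channel width from block_out_channels, tags it with a resolution factor res = 2**i, and attaches a motion module only when that factor appears in motion_module_resolutions (and motion_module_decoder_only is not set). A standalone sketch of that bookkeeping, using the default values from the signature above:

block_out_channels = (320, 640, 1280, 1280)
down_block_types = ("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")
motion_module_resolutions = (1, 2, 4, 8)
attention_head_dim = (8,) * len(down_block_types)  # scalar broadcast to one entry per block, as above

output_channel = block_out_channels[0]
for i, block_type in enumerate(down_block_types):
    res = 2 ** i  # 1, 2, 4, 8: downsampling factor reached by this block
    input_channel, output_channel = output_channel, block_out_channels[i]
    use_motion = res in motion_module_resolutions  # assuming use_motion_module=True, decoder_only=False
    print(f"{block_type}: {input_channel}->{output_channel}, heads={attention_head_dim[i]}, motion={use_motion}")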
+ + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + res = 2**i + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], # type: ignore + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], # type: ignore + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, # type: ignore + use_motion_module=use_motion_module + and (res in motion_module_resolutions) + and (not motion_module_decoder_only), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], # type: ignore + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, # type: ignore + use_motion_module=use_motion_module and motion_module_mid_block, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the videos + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) # type: ignore + only_cross_attention = list(reversed(only_cross_attention)) # type: ignore + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + res = 2 ** (3 - i) + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[ + min(i + 1, len(block_out_channels) - 1) + ] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + 
cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], # type: ignore + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, # type: ignore + use_motion_module=use_motion_module + and (res in motion_module_resolutions), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if use_inflated_groupnorm: + self.conv_norm_out = InflatedGroupNorm( + num_channels=block_out_channels[0], + num_groups=norm_num_groups, + eps=norm_eps, + ) + else: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], + num_groups=norm_num_groups, + eps=norm_eps, + ) + self.conv_act = nn.SiLU() + self.conv_out = InflatedConv3d( + block_out_channels[0], out_channels, kernel_size=3, padding=1 + ) + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> dict: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, module: torch.nn.Module, processors: dict + ): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor( + return_deprecated_lora=True + ) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor, _remove_lora=False): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor, _remove_lora=_remove_lora) + else: + module.set_processor( + processor.pop(f"{name}.processor"), _remove_lora=_remove_lora + ) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def enable_forward_chunking( + self, chunk_size: Optional[int] = None, dim: int = 0 + ) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward( + module: torch.nn.Module, chunk_size: int, dim: int + ): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = ( + num_slicable_layers * [slice_size] + if not isinstance(slice_size, list) + else slice_size + ) + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. 
Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice( + module: torch.nn.Module, slice_size: List[int] + ): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance( + module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D) + ): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + # support controlnet + added_cond_kwargs: dict = {}, + down_intrablock_additional_residuals: Optional[torch.Tensor] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + drop_encode_decode: bool = False, + order: int = 0, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. 
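# NOTE (hedged illustration, assuming the usual three-upsampler layout):
# default_overall_up_factor = 2**3 = 8, so a 64x64 latent (512x512 image) is
# already a clean multiple and nothing special happens; a hypothetical 60x68
# latent would trip the modulo check below, set forward_upsample_size = True,
# and forward explicit `upsample_size` values to the up blocks.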
+ default_overall_up_factor = 2**self.num_upsamplers + + # 1 - default + # 2 - deepcache + # 3 - faster-diffusion + method = 1 + if drop_encode_decode: + method = 3 + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.debug("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # center input if necessary + if self.config["center_input_sample"]: + sample = 2 * sample - 1.0 # type: ignore + + # time + if isinstance(timestep, list): + timesteps = timestep[0] + else: + timesteps = timestep + if not torch.is_tensor(timesteps) and (not isinstance(timesteps, list)): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif (not isinstance(timesteps, list)) and len(timesteps.shape) == 0: # type: ignore + timesteps = timesteps[None].to(sample.device) # type: ignore + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + if (not isinstance(timesteps, list)) and len(timesteps.shape) == 1: # type: ignore + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) # type: ignore + elif isinstance(timesteps, list): + # timesteps list, such as [981,961,941] + from core.inference.utilities.faster_diffusion import warp_timestep + + timesteps = warp_timestep(timesteps, sample.shape[0]).to(sample.device) # type: ignore + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. 
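# Example (assuming fp16 inference): self.dtype is torch.float16, so the fp32
# sinusoidal projection produced by time_proj above is downcast before it is
# fed to the time_embedding module, which shares the UNet's weight dtype.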
+ t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + aug_emb = None + + if self.config["addition_embed_type"] == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config["addition_embed_type"] == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) # type: ignore + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) # type: ignore + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) # type: ignore + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb if aug_emb is not None else emb + + def downsample(downsample_block, additional_residuals: dict): + nonlocal sample, emb, encoder_hidden_states, attention_mask, down_intrablock_additional_residuals, down_block_res_samples + if ( + hasattr(downsample_block, "has_cross_attention") + and downsample_block.has_cross_attention + ): + # For t2i-adapter CrossAttnDownBlock2D + if is_adapter and len(down_intrablock_additional_residuals) > 0: # type: ignore + additional_residuals[ + "additional_residuals" + ] = down_intrablock_additional_residuals.pop( # type: ignore + 0 + ) + + sample, res_samples = downsample_block( + hidden_states=sample, # type: ignore + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_adapter and len(down_intrablock_additional_residuals) > 0: # type: ignore + sample += down_intrablock_additional_residuals.pop(0) # type: ignore + + down_block_res_samples += res_samples + + def upsample(upsample_block, i, length, additional={}): + nonlocal self, down_block_res_samples, forward_upsample_size, sample, emb, encoder_hidden_states, upsample_size, attention_mask + + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[length:] + down_block_res_samples = down_block_res_samples[:length] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if ( + hasattr(upsample_block, "has_cross_attention") + and upsample_block.has_cross_attention + ): + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + **additional, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) + + is_controlnet = ( + mid_block_additional_residual is not None + and down_block_additional_residuals is not None + ) + is_adapter = down_intrablock_additional_residuals is not None + + if method == 1: + sample = self.conv_in(sample) + down_block_res_samples = (sample,) + for 
downsample_block in self.down_blocks: + downsample(downsample_block, {}) + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals # type: ignore + ): + down_block_res_sample = ( + down_block_res_sample + down_block_additional_residual + ) + new_down_block_res_samples = new_down_block_res_samples + ( + down_block_res_sample, + ) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if ( + hasattr(self.mid_block, "has_cross_attention") + and self.mid_block.has_cross_attention + ): + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 # type: ignore + and sample.shape == down_intrablock_additional_residuals[0].shape # type: ignore + ): + sample += down_intrablock_additional_residuals.pop(0) # type: ignore + + if is_controlnet: + sample = sample + mid_block_additional_residual # type: ignore + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + upsample(upsample_block, i, -len(upsample_block.resnets)) + elif method == 2: + raise NotImplementedError( + "DeepCache isn't implemented yet, I don't even have a clue how you got this error either..." + ) + else: + assert order is not None + + mod = config.api.drop_encode_decode + + cond = order <= 5 or order % 5 == 0 + if isinstance(mod, int): + cond = order <= 5 or order % mod == 0 + + if cond: + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + downsample(downsample_block, {}) + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals # type: ignore + ): + down_block_res_sample = ( + down_block_res_sample + down_block_additional_residual + ) + new_down_block_res_samples = new_down_block_res_samples + ( + down_block_res_sample, + ) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if ( + hasattr(self.mid_block, "has_cross_attention") + and self.mid_block.has_cross_attention + ): + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 # type: ignore + and sample.shape == down_intrablock_additional_residuals[0].shape # type: ignore + ): + sample += down_intrablock_additional_residuals.pop(0) # type: ignore + + if is_controlnet: + sample = sample + mid_block_additional_residual # type: ignore + + # 4.5. 
save features + setattr(self, "skip_feature", deepcopy(down_block_res_samples)) + setattr(self, "toup_feature", sample.detach().clone()) + else: + down_block_res_samples = self.skip_feature + sample = self.toup_feature + + for i, upsample_block in enumerate(self.up_blocks): + upsample(upsample_block, i, -len(upsample_block.resnets)) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) + + def convert_to_pia(self, checkpoint_name: str): + return patch_conv3d(self, checkpoint_name) + + @classmethod + def from_pretrained_2d( + cls, + unet, + motion_module_path, + is_sdxl: bool = False, + ): + motion_module_path = Path(motion_module_path) + state_dict = unet.state_dict() + + # load the motion module weights + if motion_module_path.exists() and motion_module_path.is_file(): + motion_state_dict = load_checkpoint( + motion_module_path.as_posix(), + motion_module_path.suffix.lower() == ".safetensors", + ) + else: + raise FileNotFoundError( + f"no motion module weights found in {motion_module_path}" + ) + + # merge the state dicts + motion_state_dict.update(state_dict) + + # check if we have a v1 or v2 motion module + motion_up_dim = motion_state_dict[MMV2_DIM_KEY].shape + unet_additional_kwargs = { + "unet_use_cross_frame_attention": False, + "unet_use_temporal_attention": False, + "use_motion_module": True, + "motion_module_resolutions": [1, 2, 4, 8], + "motion_module_mid_block": False, + "motion_module_decoder_only": False, + "motion_module_type": "Vanilla", + "motion_module_kwargs": { + "num_attention_heads": 8, + "num_transformer_block": 1, + "attention_block_types": ["Temporal_Self", "Temporal_Self"], + "temporal_position_encoding": True, + "temporal_position_encoding_max_len": 24, + "temporal_attention_dim_div": 1, + }, + } + if motion_up_dim[1] != 24: + logger.debug("Detected V2 motion module") + if unet_additional_kwargs: + motion_module_kwargs = unet_additional_kwargs.pop( + "motion_module_kwargs", {} + ) + motion_module_kwargs[ + "temporal_position_encoding_max_len" + ] = motion_up_dim[1] + unet_additional_kwargs["motion_module_kwargs"] = motion_module_kwargs + else: + unet_additional_kwargs = { + "motion_module_kwargs": { + "temporal_position_encoding_max_len": motion_up_dim[2] + } + } + + unet_config = dict(unet.config) + unet_config["_class_name"] = cls.__name__ + if is_sdxl: + unet_config["down_block_types"] = [ + "DownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + ] + unet_config["up_block_types"] = [ + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "UpBlock3D", + ] + else: + unet_config["down_block_types"] = [ + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ] + unet_config["up_block_types"] = [ + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + ] + unet_config["mid_block_type"] = "UNetMidBlock3DCrossAttn" + + if unet_additional_kwargs is None: + unet_additional_kwargs = {} + model: torch.nn.Module = cls.from_config(unet_config, **unet_additional_kwargs) # type: ignore + + # load the weights into the model + m, u = model.load_state_dict(motion_state_dict, strict=False) + logger.debug(f"### missing keys: {len(m)}; ### unexpected keys: {len(u)};") + + params = [ + p.numel() if "temporal" in n.lower() else 0 + for n, p in model.named_parameters() + ] + logger.info(f"Loaded {sum(params) / 1e6}M-parameter motion module") + + 
return model diff --git a/core/inference/utilities/animatediff/models/unet_blocks.py b/core/inference/utilities/animatediff/models/unet_blocks.py new file mode 100644 index 000000000..19d200de7 --- /dev/null +++ b/core/inference/utilities/animatediff/models/unet_blocks.py @@ -0,0 +1,841 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py + +import torch +from torch import nn + +from .attention import Transformer3DModel +from .resnet import Downsample3D, ResnetBlock3D, Upsample3D +from .motion_module import get_motion_module + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, +): + down_block_type = ( + down_block_type[7:] + if down_block_type.startswith("UNetRes") + else down_block_type + ) + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, # type: ignore + downsample_padding=downsample_padding, # type: ignore + resnet_time_scale_shift=resnet_time_scale_shift, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnDownBlock3D" + ) + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, # type: ignore + downsample_padding=downsample_padding, # type: ignore + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + unet_use_cross_frame_attention=False, + 
unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, +): + up_block_type = ( + up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + ) + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, # type: ignore + resnet_time_scale_shift=resnet_time_scale_shift, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnUpBlock3D" + ) + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, # type: ignore + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, # type: ignore + ) + ] + attentions = [] + motion_modules = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // 
attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=in_channels, + motion_module_type=motion_module_type, # type: ignore + motion_module_kwargs=motion_module_kwargs, # type: ignore + ) + if use_motion_module + else None + ) + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, # type: ignore + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + def forward( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None + ): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet, motion_module in zip( + self.attentions, self.resnets[1:], self.motion_modules # type: ignore + ): + hidden_states = attn( + hidden_states, encoder_hidden_states=encoder_hidden_states + ).sample + hidden_states = ( + motion_module( + hidden_states, temb, encoder_hidden_states=encoder_hidden_states + ) + if motion_module is not None + else hidden_states + ) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, # type: ignore + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + 
only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, # type: ignore + motion_module_kwargs=motion_module_kwargs, # type: ignore + ) + if use_motion_module + else None + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None + ): + output_states = () + + for resnet, attn, motion_module in zip( + self.resnets, self.attentions, self.motion_modules + ): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( # type: ignore + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( # type: ignore + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint( # type: ignore + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + encoder_hidden_states, + ) + + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, encoder_hidden_states=encoder_hidden_states + ).sample + + # add motion module + hidden_states = ( + motion_module( + hidden_states, temb, encoder_hidden_states=encoder_hidden_states + ) + if motion_module is not None + else hidden_states + ) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, # type: ignore + ) + ) + motion_modules.append( + 
get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, # type: ignore + motion_module_kwargs=motion_module_kwargs, # type: ignore + ) + if use_motion_module + else None + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = () + + for resnet, motion_module in zip(self.resnets, self.motion_modules): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( # type: ignore + create_custom_forward(resnet), hidden_states, temb + ) + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint( # type: ignore + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + encoder_hidden_states, + ) + else: + hidden_states = resnet(hidden_states, temb) + + # add motion module + hidden_states = ( + motion_module( + hidden_states, temb, encoder_hidden_states=encoder_hidden_states + ) + if motion_module is not None + else hidden_states + ) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + unet_use_cross_frame_attention=False, + unet_use_temporal_attention=False, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, # type: ignore + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, 
+ norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, # type: ignore + motion_module_kwargs=motion_module_kwargs, # type: ignore + ) + if use_motion_module + else None + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + ): + for resnet, attn, motion_module in zip( + self.resnets, self.attentions, self.motion_modules + ): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( # type: ignore + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( # type: ignore + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint( # type: ignore + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + encoder_hidden_states, + ) + + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, encoder_hidden_states=encoder_hidden_states + ).sample + + # add motion module + hidden_states = ( + motion_module( + hidden_states, temb, encoder_hidden_states=encoder_hidden_states + ) + if motion_module is not None + else hidden_states + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + use_inflated_groupnorm=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + 
groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, # type: ignore + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, # type: ignore + motion_module_kwargs=motion_module_kwargs, # type: ignore + ) + if use_motion_module + else None + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + upsample_size=None, + encoder_hidden_states=None, + ): + for resnet, motion_module in zip(self.resnets, self.motion_modules): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( # type: ignore + create_custom_forward(resnet), hidden_states, temb + ) + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint( # type: ignore + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + encoder_hidden_states, + ) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = ( + motion_module( + hidden_states, temb, encoder_hidden_states=encoder_hidden_states + ) + if motion_module is not None + else hidden_states + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/core/inference/utilities/animatediff/pia/inflate.py b/core/inference/utilities/animatediff/pia/inflate.py new file mode 100644 index 000000000..364b46057 --- /dev/null +++ b/core/inference/utilities/animatediff/pia/inflate.py @@ -0,0 +1,36 @@ +from typing import TYPE_CHECKING +import logging + +import torch + +from core.inference.utilities.load import load_checkpoint +from ..models.resnet import InflatedConv3d + +if TYPE_CHECKING: + from ..models.unet import UNet3DConditionModel + + +logger = logging.getLogger(__name__) + + +def patch_conv3d(unet: "UNet3DConditionModel", pia_path: str) -> "UNet3DConditionModel": + old_weight, old_bias = unet.conv_in.weight, unet.conv_in.bias + new_conv = InflatedConv3d( + 9, # 9 channels + old_weight.shape[0], + kernel_size=unet.conv_in.kernel_size, # type: ignore + stride=unet.conv_in.stride, # type: ignore + padding=unet.conv_in.padding, # type: ignore + bias=True if old_bias is not None else False, + ) + param = torch.zeros((320, 5, 3, 3), requires_grad=True) + new_conv.weight = torch.nn.Parameter(torch.cat([old_weight, param], dim=1)) + if old_bias is not None: + new_conv.bias = old_bias + unet.conv_in = new_conv + unet.config["in_channels"] = 9 + + checkpoint = load_checkpoint(pia_path, pia_path.endswith("safetensors")) + m, u = unet.load_state_dict(checkpoint, strict=False) + logger.debug(f"Missing keys: {m}, unexpected: {u}") + return unet diff --git 
a/core/inference/utilities/animatediff/pia/masking.py b/core/inference/utilities/animatediff/pia/masking.py new file mode 100644 index 000000000..a71cbcef4 --- /dev/null +++ b/core/inference/utilities/animatediff/pia/masking.py @@ -0,0 +1,57 @@ +from typing import List + +# fmt: off +RANGE_LIST = [ + [1.0, 0.9, 0.85, 0.85, 0.85, 0.8], # 0 Small Motion + [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75], # Moderate Motion + [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5], # Large Motion + [1.0, 0.7, 0.65, 0.65, 0.6, 0.6, 0.6, 0.55, 0.5, 0.5, 0.45, 0.45, 0.4], # ULTRA Large Motion + [1.0 , 0.9 , 0.85, 0.85, 0.85, 0.8 , 0.8 , 0.8 , 0.8 , 0.8 , 0.8 , 0.8 , 0.85, 0.85, 0.9 , 1.0 ], # Loop + [1.0 , 0.8 , 0.8 , 0.8 , 0.79, 0.78, 0.75, 0.75, 0.75, 0.75, 0.75, 0.78, 0.79, 0.8 , 0.8 , 1.0 ], # Loop + [1.0 , 0.8 , 0.7 , 0.7 , 0.7 , 0.7 , 0.6 , 0.5 , 0.5 , 0.6 , 0.7 , 0.7 , 0.7 , 0.7 , 0.8 , 1.0 ], # Loop + [1.0 , 0.7 , 0.6 , 0.6 , 0.6 , 0.6 , 0.5 , 0.4 , 0.4 , 0.5 , 0.6 , 0.6 , 0.6 , 0.6 , 0.7 , 1.0 ], # Loop + [0.4, 0.1], # Style Transfer ULTRA Large Motion + [0.5, 0.2], # Style Transfer Large Motion + [0.5, 0.4, 0.4, 0.4, 0.35, 0.35, 0.3, 0.25, 0.2], # Style Transfer Moderate Motion + [0.5, 0.4, 0.4, 0.4, 0.35, 0.3], # Style Transfer Candidate Small Motion +] +# fmt: on + + +def prepare_mask_coef( + video_length: int, cond_frame: int, sim_range: List[float] = [0.2, 1.0] +): + assert ( + len(sim_range) == 2 + ), "sim_range should has the length of 2, including the min and max similarity" + + assert video_length > 1, "video_length should be greater than 1" + + assert video_length > cond_frame, "video_length should be greater than cond_frame" + + diff = abs(sim_range[0] - sim_range[1]) / (video_length - 1) + coef = [1.0] * video_length + for f in range(video_length): + f_diff = diff * abs(cond_frame - f) + f_diff = 1 - f_diff + coef[f] *= f_diff + + return coef + + +def prepare_mask_coef_by_statistics(video_length: int, cond_frame: int, sim_range: int): + assert video_length > 0, "video_length should be greater than 0" + + assert video_length > cond_frame, "video_length should be greater than cond_frame" + + range_list = RANGE_LIST + + assert sim_range < len(range_list), f"sim_range type{sim_range} not implemented" + + coef = range_list[sim_range] + coef = coef + ([coef[-1]] * (video_length - len(coef))) + + order = [abs(i - cond_frame) for i in range(video_length)] + coef = [coef[order[i]] for i in range(video_length)] + + return coef diff --git a/core/inference/utilities/anisotropic.py b/core/inference/utilities/anisotropic.py new file mode 100644 index 000000000..2a7325b96 --- /dev/null +++ b/core/inference/utilities/anisotropic.py @@ -0,0 +1,256 @@ +# Taken from lllyasviel/Fooocus +# Show some love over at https://github.com/lllyasviel/Fooocus/ + +from typing import Union, Tuple, Optional + +import torch + + +def _compute_zero_padding(kernel_size: Union[Tuple[int, int], int]) -> Tuple[int, int]: + ky, kx = _unpack_2d_ks(kernel_size) + return (ky - 1) // 2, (kx - 1) // 2 + + +def _unpack_2d_ks(kernel_size: Union[Tuple[int, int], int]) -> Tuple[int, int]: + if isinstance(kernel_size, int): + ky = kx = kernel_size + else: + ky, kx = kernel_size + + ky = int(ky) + kx = int(kx) + return ky, kx + + +def gaussian( + window_size: int, + sigma: Union[torch.Tensor, float], + *, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + batch_size = sigma.shape[0] # type: ignore + + x = ( + torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) # 
type: ignore + - window_size // 2 + ).expand(batch_size, -1) + + if window_size % 2 == 0: + x = x + 0.5 + + gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) # type: ignore + + return gauss / gauss.sum(-1, keepdim=True) + + +def simple_gaussian_2d(img, kernel_size, sigma): + "Blurs an image with gaussian blur." + ksize_half = (kernel_size - 1) * 0.5 + + x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + + x_kernel = pdf / pdf.sum() + x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) + + kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) + kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) + + padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] + + img = torch.nn.functional.pad(img, padding, mode="reflect") + img = torch.nn.functional.conv2d(img, kernel2d, groups=img.shape[-3]) + + return img + + +def get_gaussian_kernel2d( + kernel_size: Union[Tuple[int, int], int], + sigma: Union[Tuple[float, float], torch.Tensor], + force_even: bool = False, + *, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + sigma = torch.Tensor([[sigma, sigma]]).to(device=device, dtype=dtype) # type: ignore + + ksize_y, ksize_x = _unpack_2d_ks(kernel_size) + sigma_y, sigma_x = sigma[:, 0, None], sigma[:, 1, None] + + kernel_y = get_gaussian_kernel1d( + ksize_y, sigma_y, force_even, device=device, dtype=dtype + )[..., None] + kernel_x = get_gaussian_kernel1d( + ksize_x, sigma_x, force_even, device=device, dtype=dtype + )[..., None] + + return kernel_y * kernel_x.view(-1, 1, ksize_x) + + +def gaussian_blur2d( + input: torch.Tensor, + kernel_size: Union[Tuple[int, int], int], + sigma: Union[Tuple[float, float], torch.Tensor], +) -> torch.Tensor: + if isinstance(sigma, tuple): + sigma = torch.tensor([sigma], device=input.device, dtype=input.dtype) + else: + sigma = sigma.to(device=input.device, dtype=input.dtype) + + ky, kx = _unpack_2d_ks(kernel_size) + bs = sigma.shape[0] + kernel_x = get_gaussian_kernel1d(kx, sigma[:, 1].view(bs, 1)) + kernel_y = get_gaussian_kernel1d(ky, sigma[:, 0].view(bs, 1)) + out = filter2d_separable(input, kernel_x, kernel_y) + + return out + + +def filter2d_separable( + input: torch.Tensor, + kernel_x: torch.Tensor, + kernel_y: torch.Tensor, +) -> torch.Tensor: + out_x = filter2d(input, kernel_x[..., None, :]) + out = filter2d(out_x, kernel_y[..., None]) + return out + + +def filter2d( + input: torch.Tensor, + kernel: torch.Tensor, +) -> torch.Tensor: + b, c, h, w = input.shape + tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) + + tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) + height, width = tmp_kernel.shape[-2:] + + # pad the input tensor + padding_shape: list[int] = _compute_padding([height, width]) + input = torch.nn.functional.pad(input, padding_shape, mode="reflect") + + # kernel and input tensor reshape to align element-wise or batch-wise params + tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) + input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) + + # convolve the tensor with the kernel. 
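# groups=tmp_kernel.size(0) makes this a depthwise-style convolution: each
# channel of the reshaped input is convolved with its own copy of the kernel,
# and the result is viewed back to (b, c, h, w) right after.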
+ output = torch.nn.functional.conv2d( + input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1 + ) + out = output.view(b, c, h, w) + + return out + + +def unsharp_mask( + input: torch.Tensor, + kernel_size: Union[Tuple[int, int], int], + sigma: Union[Tuple[float, float], torch.Tensor], +) -> torch.Tensor: + data_blur: torch.Tensor = gaussian_blur2d(input, kernel_size, sigma) + data_sharpened: torch.Tensor = input + (input - data_blur) + return data_sharpened + + +def get_gaussian_kernel1d( + kernel_size: int, + sigma: Union[float, torch.Tensor], + force_even: bool = False, + *, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + return gaussian(kernel_size, sigma, device=device, dtype=dtype) + + +def _compute_padding(kernel_size: list[int]) -> list[int]: + computed = [k - 1 for k in kernel_size] + out_padding = 2 * len(kernel_size) * [0] + + for i in range(len(kernel_size)): + computed_tmp = computed[-(i + 1)] + + pad_front = computed_tmp // 2 + pad_rear = computed_tmp - pad_front + + out_padding[2 * i + 0] = pad_front + out_padding[2 * i + 1] = pad_rear + + return out_padding + + +def _bilateral_blur( + input: torch.Tensor, + guidance: Optional[torch.Tensor], + kernel_size: Union[Tuple[int, int], int], + sigma_color: Union[float, torch.Tensor], + sigma_space: Union[Tuple[float, float], torch.Tensor], + border_type: str = "reflect", + color_distance_type: str = "l1", +) -> torch.Tensor: + if isinstance(sigma_color, torch.Tensor): + sigma_color = sigma_color.to(device=input.device, dtype=input.dtype).view( + -1, 1, 1, 1, 1 + ) + + ky, kx = _unpack_2d_ks(kernel_size) + pad_y, pad_x = _compute_zero_padding(kernel_size) + + padded_input = torch.nn.functional.pad( + input, (pad_x, pad_x, pad_y, pad_y), mode=border_type + ) + unfolded_input = ( + padded_input.unfold(2, ky, 1).unfold(3, kx, 1).flatten(-2) + ) # (B, C, H, W, Ky x Kx) + + if guidance is None: + guidance = input + unfolded_guidance = unfolded_input + else: + padded_guidance = torch.nn.functional.pad( + guidance, (pad_x, pad_x, pad_y, pad_y), mode=border_type + ) + unfolded_guidance = ( + padded_guidance.unfold(2, ky, 1).unfold(3, kx, 1).flatten(-2) + ) # (B, C, H, W, Ky x Kx) + + diff = unfolded_guidance - guidance.unsqueeze(-1) + if color_distance_type == "l1": + color_distance_sq = diff.abs().sum(1, keepdim=True).square() + elif color_distance_type == "l2": + color_distance_sq = diff.square().sum(1, keepdim=True) + else: + raise ValueError("color_distance_type only acceps l1 or l2") + color_kernel = ( + -0.5 / sigma_color**2 * color_distance_sq + ).exp() # (B, 1, H, W, Ky x Kx) + + space_kernel = get_gaussian_kernel2d( + kernel_size, sigma_space, device=input.device, dtype=input.dtype # type: ignore + ) + space_kernel = space_kernel.view(-1, 1, 1, 1, kx * ky) + + kernel = space_kernel * color_kernel + out = (unfolded_input * kernel).sum(-1) / kernel.sum(-1) + return out + + +def adaptive_anisotropic_filter(x, g=None): + if g is None: + g = x + s, m = torch.std_mean(g, dim=(1, 2, 3), keepdim=True) + s = s + 1e-5 + guidance = (g - m) / s + y = _bilateral_blur( + x, + guidance, + kernel_size=(13, 13), + sigma_color=3.0, + sigma_space=3.0, # type: ignore + border_type="reflect", + color_distance_type="l1", + ) + return y diff --git a/core/inference/utilities/cfg.py b/core/inference/utilities/cfg.py new file mode 100644 index 000000000..2ad30e545 --- /dev/null +++ b/core/inference/utilities/cfg.py @@ -0,0 +1,118 @@ +from typing import Optional + +import torch + +import 
k_diffusion +from .anisotropic import adaptive_anisotropic_filter, unsharp_mask +from core.config import config + +cfg_x0, cfg_s, cfg_cin, eps_record = None, None, None, None + + +def patched_ddpm_denoiser_forward(self, input, sigma, **kwargs): + global cfg_x0, cfg_s, cfg_cin, eps_record + + c_out, c_in = [ + k_diffusion.utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma) + ] + cfg_x0, cfg_s, cfg_cin = input, c_out, c_in + + c_in, c_out = c_in.to(device=input.device), c_out.to(device=input.device) + + eps = self.get_eps( + input * c_in, + self.sigma_to_t(sigma.to(device=self.log_sigmas.device)).to(input.device), + **kwargs, + ) + + if not isinstance(eps, torch.Tensor): + return eps[0] * c_out + input + else: + if eps.shape != input.shape: + eps = torch.nn.functional.interpolate( + eps, (input.shape[2], input.shape[3]), mode="bilinear" + ) + return eps * c_out + input + + +def patched_vddpm_denoiser_forward(self, input, sigma, **kwargs): + global cfg_x0, cfg_s, cfg_cin + + c_skip, c_out, c_in = [ + k_diffusion.utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma) + ] + cfg_x0, cfg_s, cfg_cin = input, c_out, c_in + + c_skip, c_out, c_in = ( + c_skip.to(device=input.device), + c_out.to(device=input.device), + c_in.to(device=input.device), + ) + + v = self.get_v( + input * c_in, + self.sigma_to_t(sigma.to(device=self.log_sigmas.device)).to(input.device), + **kwargs, + ) + + if not isinstance(v, torch.Tensor): + return v[0] * c_out + input * c_skip + else: + if v.shape != input.shape: + v = torch.nn.functional.interpolate( + v, (input.shape[2], input.shape[3]), mode="bilinear" + ) + return v * c_out + input * c_skip + + +k_diffusion.external.DiscreteEpsDDPMDenoiser.forward = patched_ddpm_denoiser_forward +k_diffusion.external.DiscreteVDDPMDenoiser.forward = patched_vddpm_denoiser_forward + + +def calculate_cfg( + i: int, + cond: torch.Tensor, + uncond: torch.Tensor, + cfg: float, + timestep: torch.IntTensor, + additional_pred: Optional[torch.Tensor], +): + if config.api.apply_unsharp_mask: + cc = uncond + cfg * (cond - uncond) + + MIX_FACTOR = 0.003 + cond_scale_factor = min(0.02 * cfg, 0.65) + usm_sigma = torch.clamp(1 + timestep * cond_scale_factor, min=1e-6) + sharpened = unsharp_mask(cond, (3, 3), (usm_sigma, usm_sigma)) # type: ignore + + return cc + (sharpened - cc) * MIX_FACTOR + + if config.api.cfg_rescale_threshold == "off": + if additional_pred is not None: + additional_pred, _ = additional_pred.chunk(2) + uncond = additional_pred + return uncond + cfg * (cond - uncond) + + if config.api.cfg_rescale_threshold <= cfg: + global cfg_x0, cfg_s + + if cfg_x0.shape[0] == 2: # type: ignore + cfg_x0, _ = cfg_x0.chunk(2) # type: ignore + + positive_x0 = cond * cfg_s + cfg_x0 + t = 1.0 - (timestep / 999.0)[:, None, None, None].clone() + # Magic number: 2.0 is "sharpness" + alpha = 0.001 * 2.0 * t + + positive_eps_degraded = adaptive_anisotropic_filter(x=cond, g=positive_x0) + cond = positive_eps_degraded * alpha + cond * (1.0 - alpha) + + reps = (uncond + cfg * (cond - uncond)) * t + # Magic number: 0.7 is "base cfg" + mimicked = (uncond + 0.7 * (cond - uncond)) * (1 - t) + return reps + mimicked + + if additional_pred is not None: + additional_pred, _ = additional_pred.chunk(2) + uncond = additional_pred + return uncond + cfg * (cond - uncond) diff --git a/core/inference/utilities/convert_from_ckpt.py b/core/inference/utilities/convert_from_ckpt.py new file mode 100644 index 000000000..2b31b902f --- /dev/null +++ b/core/inference/utilities/convert_from_ckpt.py @@ 
-0,0 +1,1391 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from io import BytesIO +from typing import Dict, Optional, Union +import logging + +import requests +import torch +from omegaconf import OmegaConf +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from transformers import ( + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) +from diffusers import ( + StableDiffusionPipeline, # type: ignore + StableDiffusionXLPipeline, # type: ignore +) +from diffusers.models import ( + AutoencoderKL, # type: ignore + UNet2DConditionModel, # type: ignore +) +from diffusers.schedulers import ( + DDIMScheduler, # type: ignore + EulerDiscreteScheduler, # type: ignore +) +from diffusers.pipelines.pipeline_utils import DiffusionPipeline + +from core.config import config as volta_config +from core.optimizations.sdxl_unet import UNet2DConditionModel as SDXLUNet2D +from .load import load_checkpoint + +logger = logging.getLogger(__name__) + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. 
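For example, shave_segments("a.b.c.d", 1) returns "b.c.d" and
shave_segments("a.b.c.d", -1) returns "a.b.c".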
+ """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments( + new_item, n_shave_prefix_segments=n_shave_prefix_segments + ) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments( + new_item, n_shave_prefix_segments=n_shave_prefix_segments + ) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments( + new_item, n_shave_prefix_segments=n_shave_prefix_segments + ) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, + checkpoint, + old_checkpoint, + attention_paths_to_split=None, + additional_replacements=None, + config=None, +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + # Splits the attention layers into three variables. 
+ if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 # type: ignore + + old_tensor = old_tensor.reshape( + (num_heads, 3 * channels // num_heads) + old_tensor.shape[1:] + ) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if ( + attention_paths_to_split is not None + and new_path in attention_paths_to_split + ): + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + is_attn_weight = "proj_attn.weight" in new_path or ( + "attentions" in new_path and "to_" in new_path + ) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): + """ + Creates a config for the diffusers based on the config of the LDM model. 
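+
+    Block types are derived from `attention_resolutions`, channel counts from
+    `model_channels` * `channel_mult`, and the sample size from `image_size`
+    divided by the VAE scale factor.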
+ """ + if controlnet: + unet_params = original_config.model.params.control_stage_config.params + else: + if ( + "unet_config" in original_config.model.params + and original_config.model.params.unet_config is not None + ): + unet_params = original_config.model.params.unet_config.params + else: + unet_params = original_config.model.params.network_config.params + + vae_params = original_config.model.params.first_stage_config.params.ddconfig + + block_out_channels = [ + unet_params.model_channels * mult for mult in unet_params.channel_mult + ] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = ( + "CrossAttnDownBlock2D" + if resolution in unet_params.attention_resolutions + else "DownBlock2D" + ) + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = ( + "CrossAttnUpBlock2D" + if resolution in unet_params.attention_resolutions + else "UpBlock2D" + ) + up_block_types.append(block_type) + resolution //= 2 + + if unet_params.transformer_depth is not None: + transformer_layers_per_block = ( + unet_params.transformer_depth + if isinstance(unet_params.transformer_depth, int) + else list(unet_params.transformer_depth) + ) + else: + transformer_layers_per_block = 1 + + vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + + head_dim = unet_params.num_heads if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params.use_linear_in_transformer + if "use_linear_in_transformer" in unet_params + else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim_mult = unet_params.model_channels // unet_params.num_head_channels + head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)] + + class_embed_type = None + addition_embed_type = None + addition_time_embed_dim = None + projection_class_embeddings_input_dim = None + context_dim = None + + if unet_params.context_dim is not None: + context_dim = ( + unet_params.context_dim + if isinstance(unet_params.context_dim, int) + else unet_params.context_dim[0] + ) + + if "num_classes" in unet_params: + if unet_params.num_classes == "sequential": + if context_dim in [2048, 1280]: + # SDXL + addition_embed_type = "text_time" + addition_time_embed_dim = 256 + else: + class_embed_type = "projection" + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params.adm_in_channels + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params.in_channels, + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params.num_res_blocks, + "cross_attention_dim": context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "addition_embed_type": addition_embed_type, + "addition_time_embed_dim": addition_time_embed_dim, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "transformer_layers_per_block": transformer_layers_per_block, + } + + if "disable_self_attentions" in unet_params: + config["only_cross_attention"] = unet_params.disable_self_attentions + + if "num_classes" in unet_params and isinstance(unet_params.num_classes, int): + config["num_class_embeds"] = unet_params.num_classes + + if controlnet: + config["conditioning_channels"] = unet_params.hint_channels + else: + 
config["out_channels"] = unet_params.out_channels + config["up_block_types"] = tuple(up_block_types) + + return config + + +def create_vae_diffusers_config(original_config, image_size: int): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + vae_params = original_config.model.params.first_stage_config.params.ddconfig + _ = original_config.model.params.first_stage_config.params.embed_dim + + block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = { + "sample_size": image_size, + "in_channels": vae_params.in_channels, + "out_channels": vae_params.out_ch, + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params.z_channels, + "layers_per_block": vae_params.num_res_blocks, + } + return config + + +def convert_ldm_unet_checkpoint( + checkpoint, + config, + path=None, + extract_ema=False, + skip_extract_state_dict=False, +): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + if skip_extract_state_dict: + unet_state_dict = checkpoint + else: + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") + logger.warning( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop( + flat_ema_key + ) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + logger.warning( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict[ + "time_embed.0.weight" + ] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict[ + "time_embed.0.bias" + ] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict[ + "time_embed.2.weight" + ] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict[ + "time_embed.2.bias" + ] + + if config["class_embed_type"] is None: + # No parameters to port + ... 
+ elif ( + config["class_embed_type"] == "timestep" + or config["class_embed_type"] == "projection" + ): + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict[ + "label_emb.0.0.weight" + ] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict[ + "label_emb.0.0.bias" + ] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict[ + "label_emb.0.2.weight" + ] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict[ + "label_emb.0.2.bias" + ] + else: + raise NotImplementedError( + f"Not implemented `class_embed_type`: {config['class_embed_type']}" + ) + + if config["addition_embed_type"] == "text_time": + new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict[ + "label_emb.0.0.weight" + ] + new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict[ + "label_emb.0.0.bias" + ] + new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict[ + "label_emb.0.2.weight" + ] + new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict[ + "label_emb.0.2.bias" + ] + + # Relevant to StableDiffusionUpscalePipeline + if "num_class_embeds" in config: + if (config["num_class_embeds"] is not None) and ( + "label_emb.weight" in unet_state_dict + ): + new_checkpoint["class_embedding.weight"] = unet_state_dict[ + "label_emb.weight" + ] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len( + { + ".".join(layer.split(".")[:2]) + for layer in unet_state_dict + if "input_blocks" in layer + } + ) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len( + { + ".".join(layer.split(".")[:2]) + for layer in unet_state_dict + if "middle_block" in layer + } + ) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len( + { + ".".join(layer.split(".")[:2]) + for layer in unet_state_dict + if "output_blocks" in layer + } + ) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key + for key in input_blocks[i] + if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[ + f"down_blocks.{block_id}.downsamplers.0.conv.weight" + ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.weight") + new_checkpoint[ + f"down_blocks.{block_id}.downsamplers.0.conv.bias" + ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias") + + paths = renew_resnet_paths(resnets) + meta_path = { + "old": f"input_blocks.{i}.0", + "new": 
f"down_blocks.{block_id}.resnets.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + + meta_path = { + "old": f"input_blocks.{i}.1", + "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [ + key for key in output_blocks[i] if f"output_blocks.{i}.1" in key + ] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = { + "old": f"output_blocks.{i}.0", + "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index( + ["conv.bias", "conv.weight"] + ) + new_checkpoint[ + f"up_blocks.{block_id}.upsamplers.0.conv.weight" + ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"] + new_checkpoint[ + f"up_blocks.{block_id}.upsamplers.0.conv.bias" + ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"] + + # Clear attentions as they have been attributed above. 
+ if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, + ) + else: + resnet_0_paths = renew_resnet_paths( + output_block_layers, n_shave_prefix_segments=1 + ) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join( + [ + "up_blocks", + str(block_id), + "resnets", + str(layer_in_block_id), + path["new"], + ] + ) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + keys = list(checkpoint.keys()) + vae_key = ( + "first_stage_model." + if any(k.startswith("first_stage_model.") for k in keys) + else "" + ) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict[ + "encoder.conv_out.weight" + ] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict[ + "encoder.norm_out.weight" + ] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict[ + "encoder.norm_out.bias" + ] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[ + "decoder.conv_out.weight" + ] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[ + "decoder.norm_out.weight" + ] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[ + "decoder.norm_out.bias" + ] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len( + { + ".".join(layer.split(".")[:3]) + for layer in vae_state_dict + if "encoder.down" in layer + } + ) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] + for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len( + { + ".".join(layer.split(".")[:3]) + for layer in vae_state_dict + if "decoder.up" in layer + } + ) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] + for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [ + key + for key in down_blocks[i] + if f"down.{i}" in key and f"down.{i}.downsample" not in key + ] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[ + f"encoder.down_blocks.{i}.downsamplers.0.conv.weight" + ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight") + 
new_checkpoint[ + f"encoder.down_blocks.{i}.downsamplers.0.conv.bias" + ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias") + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key + for key in up_blocks[block_id] + if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[ + f"decoder.up_blocks.{i}.upsamplers.0.conv.weight" + ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"] + new_checkpoint[ + f"decoder.up_blocks.{i}.upsamplers.0.conv.bias" + ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None): + if text_encoder is None: + config_name = "openai/clip-vit-large-patch14" + config = CLIPTextConfig.from_pretrained( + config_name, local_files_only=local_files_only + ) + + with init_empty_weights(): + text_model = CLIPTextModel(config) # type: ignore + else: + text_model = text_encoder + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + remove_prefixes = [ + "cond_stage_model.transformer", + "conditioner.embedders.0.transformer", + ] + + for key in keys: + for prefix in remove_prefixes: + if key.startswith(prefix): + 
text_model_dict[key[len(prefix + ".") :]] = checkpoint[key] + + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device( + text_model, + param_name, + volta_config.api.load_device, # never set dtype here, it screws things up + # dtype=volta_config.api.load_dtype, + value=param, + ) + + return text_model + + +textenc_conversion_lst = [ + ("positional_embedding", "text_model.embeddings.position_embedding.weight"), + ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"), + ("ln_final.weight", "text_model.final_layer_norm.weight"), + ("ln_final.bias", "text_model.final_layer_norm.bias"), + ("text_projection", "text_projection.weight"), +] +textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} + +textenc_transformer_conversion_lst = [ + # (stable-diffusion, HF Diffusers) + ("resblocks.", "text_model.encoder.layers."), + ("ln_1", "layer_norm1"), + ("ln_2", "layer_norm2"), + (".c_fc.", ".fc1."), + (".c_proj.", ".fc2."), + (".attn", ".self_attn"), + ("ln_final.", "transformer.text_model.final_layer_norm."), + ( + "token_embedding.weight", + "transformer.text_model.embeddings.token_embedding.weight", + ), + ( + "positional_embedding", + "transformer.text_model.embeddings.position_embedding.weight", + ), +] +protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} +textenc_pattern = re.compile("|".join(protected.keys())) + + +def convert_open_clip_checkpoint( + checkpoint, + config_name, + prefix="cond_stage_model.model.", + has_projection=False, + local_files_only=False, + **config_kwargs, +): + config = CLIPTextConfig.from_pretrained( + config_name, **config_kwargs, local_files_only=local_files_only + ) + + with init_empty_weights(): + text_model = ( + CLIPTextModelWithProjection(config) # type: ignore + if has_projection + else CLIPTextModel(config) # type: ignore + ) + + keys = list(checkpoint.keys()) + + keys_to_ignore = [] + if ( + config_name == "stabilityai/stable-diffusion-2" + and config.num_hidden_layers == 23 + ): + # make sure to remove all keys > 22 + keys_to_ignore += [ + k + for k in keys + if k.startswith("cond_stage_model.model.transformer.resblocks.23") + ] + keys_to_ignore += ["cond_stage_model.model.text_projection"] + + text_model_dict = {} + + if prefix + "text_projection" in checkpoint: + d_model = int(checkpoint[prefix + "text_projection"].shape[0]) + else: + d_model = 1024 + + text_model_dict[ + "text_model.embeddings.position_ids" + ] = text_model.text_model.embeddings.get_buffer("position_ids") + + for key in keys: + if key in keys_to_ignore: + continue + if key[len(prefix) :] in textenc_conversion_map: + if key.endswith("text_projection"): + value = checkpoint[key].T.contiguous() + else: + value = checkpoint[key] + + text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value + + if key.startswith(prefix + "transformer."): + new_key = key[len(prefix + "transformer.") :] + if new_key.endswith(".in_proj_weight"): + new_key = new_key[: -len(".in_proj_weight")] + new_key = textenc_pattern.sub( + lambda m: protected[re.escape(m.group(0))], new_key + ) + text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][ + :d_model, : + ] + text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][ + d_model : d_model * 2, : + ] + text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][ + d_model * 2 :, : + ] + elif new_key.endswith(".in_proj_bias"): + new_key = new_key[: -len(".in_proj_bias")] + new_key = textenc_pattern.sub( + lambda m: protected[re.escape(m.group(0))], new_key + 
) + text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] + text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][ + d_model : d_model * 2 + ] + text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][ + d_model * 2 : + ] + else: + new_key = textenc_pattern.sub( + lambda m: protected[re.escape(m.group(0))], new_key + ) + + text_model_dict[new_key] = checkpoint[key] + + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device( + text_model, + param_name, + volta_config.api.load_device, # don't set dtype here + # dtype=volta_config.api.load_dtype, + value=param, + ) + + return text_model + + +def stable_unclip_image_encoder(original_config, local_files_only=False): + """ + Returns the image processor and clip image encoder for the img2img unclip pipeline. + + We currently know of two types of stable unclip models which separately use the clip and the openclip image + encoders. + """ + + image_embedder_config = original_config.model.params.embedder_config + + sd_clip_image_embedder_class = image_embedder_config.target + sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] + + if sd_clip_image_embedder_class == "ClipImageEmbedder": + clip_model_name = image_embedder_config.params.model + + if clip_model_name == "ViT-L/14": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + else: + raise NotImplementedError( + f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}" + ) + + elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", local_files_only=local_files_only + ) + else: + raise NotImplementedError( + f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" + ) + + return feature_extractor, image_encoder + + +def download_from_original_stable_diffusion_ckpt( + checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]], + original_config_file: str = None, # type: ignore + image_size: Optional[int] = None, + prediction_type: str = None, # type: ignore + model_type: str = None, # type: ignore + extract_ema: bool = False, + scheduler_type: str = "pndm", + num_in_channels: Optional[int] = None, + upcast_attention: Optional[bool] = None, + from_safetensors: bool = False, + local_files_only=False, +) -> DiffusionPipeline: + if prediction_type == "v-prediction": + prediction_type = "v_prediction" + + checkpoint = checkpoint_path_or_dict + if isinstance(checkpoint, str): + checkpoint = load_checkpoint(checkpoint, from_safetensors) + + # print( + # *list(map(lambda i: str(i[0]) + " " + str(i[1].shape), checkpoint.items())), + # sep="\n", + # ) + + # Sometimes models don't have the global_step item + if "global_step" in checkpoint: + global_step = checkpoint["global_step"] + else: + logger.debug("global_step key not found in model") + global_step = None + + # NOTE: this while loop isn't great but this controlnet checkpoint has one additional + # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] # type: ignore + + if original_config_file is None: + key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + key_name_sd_xl_base = ( + 
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias" + ) + key_name_sd_xl_refiner = ( + "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias" + ) + + # model_type = "v1" + config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + + if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024: # type: ignore + # model_type = "v2" + config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" + if global_step == 110000: + # v2.1 needs to upcast attention + upcast_attention = True + elif key_name_sd_xl_base in checkpoint: + # only base xl has two text embedders + config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" + elif key_name_sd_xl_refiner in checkpoint: + # only refiner xl has embedder and one text embedders + config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml" + + original_config_file = BytesIO(requests.get(config_url).content) # type: ignore + + original_config = OmegaConf.load(original_config_file) + + # Convert the text model. + if ( + model_type is None + and "cond_stage_config" in original_config.model.params + and original_config.model.params.cond_stage_config is not None + ): + model_type = original_config.model.params.cond_stage_config.target.split(".")[ + -1 + ] + logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") + elif model_type is None and original_config.model.params.network_config is not None: + if original_config.model.params.network_config.params.context_dim == 2048: + model_type = "SDXL" + else: + model_type = "SDXL-Refiner" + if image_size is None: + image_size = 1024 + + # Check if we have a SDXL or SD model and initialize default pipeline + pipeline_class = StableDiffusionPipeline # type: ignore + if model_type in ["SDXL", "SDXL-Refiner"]: + pipeline_class = StableDiffusionXLPipeline # type: ignore + + conv_in_weight = checkpoint.get( # type: ignore + "model.diffusion_model.input_blocks.0.0.weight", None + ) + if conv_in_weight is None: + num_in_channels = 4 + else: + num_in_channels = conv_in_weight.shape[1] + + if "unet_config" in original_config.model.params: + original_config["model"]["params"]["unet_config"]["params"][ # type: ignore + "in_channels" + ] = num_in_channels + + if ( + "parameterization" in original_config["model"]["params"] # type: ignore + and original_config["model"]["params"]["parameterization"] == "v" # type: ignore + ): + if prediction_type is None: + # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` + # as it relies on a brittle global step parameter here + prediction_type = "epsilon" if global_step == 875000 else "v_prediction" + if image_size is None: + # NOTE: For stable diffusion 2 base one has to pass `image_size==512` + # as it relies on a brittle global step parameter here + image_size = 512 if global_step == 875000 else 768 + else: + if prediction_type is None: + prediction_type = "epsilon" + if image_size is None: + image_size = 512 + + num_train_timesteps = ( + getattr(original_config.model.params, "timesteps", None) or 1000 + ) + + if model_type in ["SDXL", "SDXL-Refiner"]: + scheduler_dict = { + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "beta_end": 0.012, + "interpolation_type": "linear", + "num_train_timesteps": 
num_train_timesteps, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "set_alpha_to_one": False, + "skip_prk_steps": True, + "steps_offset": 1, + "timestep_spacing": "leading", + } + scheduler = EulerDiscreteScheduler.from_config(scheduler_dict) + scheduler_type = "euler" + else: + beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02 + beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085 + scheduler = DDIMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + steps_offset=1, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + ) + # make sure scheduler works correctly with DDIM + scheduler.register_to_config(clip_sample=False) # type: ignore + + if scheduler_type == "euler": + scheduler = EulerDiscreteScheduler.from_config(scheduler.config) # type: ignore + elif scheduler_type == "ddim": + scheduler = scheduler + + # Convert the UNet2DConditionModel model. + unet_config = create_unet_diffusers_config(original_config, image_size=image_size) # type: ignore + unet_config["upcast_attention"] = upcast_attention + + path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else "" + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, unet_config, path=path, extract_ema=extract_ema + ) + + with init_empty_weights(): + if ( + model_type in ["SDXL", "SDXL-Refiner"] + and volta_config.api.use_minimal_sdxl_pipeline + ): + unet = SDXLUNet2D() + else: + unet = UNet2DConditionModel(**unet_config) + + if model_type not in ["SDXL", "SDXL-Refiner"]: # SBM Delay this. + for param_name, param in converted_unet_checkpoint.items(): + set_module_tensor_to_device( + unet, + param_name, + volta_config.api.load_device, + dtype=volta_config.api.load_dtype, + value=param, + ) + + # Convert the VAE model. 
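+    # The VAE scaling factor comes from the LDM config's `scale_factor` when
+    # present, otherwise the Stable Diffusion default of 0.18215 is used.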
+ vae_config = create_vae_diffusers_config(original_config, image_size=image_size) # type: ignore + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + if ( + "model" in original_config + and "params" in original_config.model + and "scale_factor" in original_config.model.params + ): + vae_scaling_factor = original_config.model.params.scale_factor + else: + vae_scaling_factor = 0.18215 # default SD scaling factor + + vae_config["scaling_factor"] = vae_scaling_factor + + with init_empty_weights(): + vae = AutoencoderKL(**vae_config) + + for param_name, param in converted_vae_checkpoint.items(): + set_module_tensor_to_device( + vae, + param_name, + volta_config.api.load_device, + dtype=volta_config.api.load_dtype, + value=param, + ) + + if model_type == "FrozenOpenCLIPEmbedder": + config_name = "stabilityai/stable-diffusion-2" + config_kwargs = {"subfolder": "text_encoder"} + + text_model = convert_open_clip_checkpoint( + checkpoint, + config_name, + local_files_only=local_files_only, + **config_kwargs, # type: ignore + ) + + tokenizer = CLIPTokenizer.from_pretrained( + "stabilityai/stable-diffusion-2", + subfolder="tokenizer", + local_files_only=local_files_only, + ) + + pipe = pipeline_class( # type: ignore + vae=vae, + text_encoder=text_model, # type: ignore + tokenizer=tokenizer, + unet=unet, # type: ignore + scheduler=scheduler, # type: ignore + safety_checker=None, # type: ignore + feature_extractor=None, # type: ignore + ) + elif model_type == "FrozenCLIPEmbedder": + text_model = convert_ldm_clip_checkpoint( + checkpoint, local_files_only=local_files_only, text_encoder=None + ) + tokenizer = CLIPTokenizer.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + + pipe = pipeline_class( # type: ignore + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, # type: ignore + scheduler=scheduler, # type: ignore + safety_checker=None, # type: ignore + feature_extractor=None, # type: ignore + ) + else: + is_refiner = model_type == "SDXL-Refiner" + + tokenizer = None + text_encoder = None + if not is_refiner: + tokenizer = CLIPTokenizer.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + text_encoder = convert_ldm_clip_checkpoint( + checkpoint, local_files_only=local_files_only + ) + + tokenizer_2 = CLIPTokenizer.from_pretrained( + "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", + pad_token="!", + local_files_only=local_files_only, + ) + + config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + config_kwargs = {"projection_dim": 1280} + prefix = ( + "conditioner.embedders.0.model." + if is_refiner + else "conditioner.embedders.1.model." + ) + + text_encoder_2 = convert_open_clip_checkpoint( + checkpoint, + config_name, + prefix=prefix, + has_projection=True, + local_files_only=local_files_only, + **config_kwargs, + ) + + for ( + param_name, + param, + ) in converted_unet_checkpoint.items(): # SBM Now move model to cpu. 
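+            # The SDXL UNet weights were intentionally skipped above; assign them
+            # now that both text encoders have been converted.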
+ set_module_tensor_to_device( + unet, + param_name, + volta_config.api.load_device, + dtype=volta_config.api.load_dtype, + value=param, + ) + pipeline_kwargs = { + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "text_encoder_2": text_encoder_2, + "tokenizer_2": tokenizer_2, + "unet": unet, + "scheduler": scheduler, + } + + if is_refiner: + pipeline_kwargs.update({"requires_aesthetics_score": is_refiner}) + else: + pipeline_kwargs.update({"force_zeros_for_empty_prompt": False}) + + pipe = pipeline_class(**pipeline_kwargs) # type: ignore + + return pipe # type: ignore diff --git a/core/inference/utilities/kohya_hires.py b/core/inference/utilities/kohya_hires.py new file mode 100644 index 000000000..98eee7024 --- /dev/null +++ b/core/inference/utilities/kohya_hires.py @@ -0,0 +1,139 @@ +from typing import Tuple, Optional +from functools import partial + +from diffusers import UNet2DConditionModel # type: ignore +from diffusers.models.unet_2d_blocks import CrossAttnUpBlock2D, UpBlock2D +import torch + +from core.flags import DeepshrinkFlag +from .latents import scale_latents + +step_limit = 0 + + +def nf( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + *args, + **kwargs, +) -> torch.FloatTensor: + mode = "bilinear" + if hasattr(self, "kohya_scaler"): + mode = self.kohya_scaler + if mode == "bislerp": + mode = "bilinear" + out = list(res_hidden_states_tuple) + for i, o in enumerate(out): + if o.shape[2] != hidden_states.shape[2]: + out[i] = torch.nn.functional.interpolate( + o, + ( + hidden_states.shape[2], + hidden_states.shape[3], + ), + mode=mode, + ) + res_hidden_states_tuple = tuple(out) + + return self.nn_forward( + *args, + hidden_states=hidden_states, + res_hidden_states_tuple=res_hidden_states_tuple, + **kwargs, + ) + + +CrossAttnUpBlock2D.nn_forward = CrossAttnUpBlock2D.forward # type: ignore +UpBlock2D.nn_forward = UpBlock2D.forward # type: ignore +CrossAttnUpBlock2D.forward = nf +UpBlock2D.forward = nf + + +def _round(x, y): + return y * round(x / y) + + +def modify_unet( + unet: UNet2DConditionModel, + step: int, + total_steps: int, + flag: Optional[DeepshrinkFlag] = None, +) -> UNet2DConditionModel: + if flag is None: + return unet + + global step_limit + + s1, s2 = flag.stop_at_1, flag.stop_at_2 + if s1 > s2: + s2 = s1 + p1 = (s1, flag.depth_1 - 1) + p2 = (s2, flag.depth_2 - 1) + + if step < step_limit: + return unet + + for s, d in [p1, p2]: + out_d = d if flag.early_out else -(d + 1) + out_d = min(out_d, len(unet.up_blocks) - 1) + if step < total_steps * s: + if not hasattr(unet.down_blocks[d], "kohya_scale"): + for block, scale in [ + (unet.down_blocks[d], flag.base_scale), + (unet.up_blocks[out_d], 1.0 / flag.base_scale), + ]: + setattr(block, "kohya_scale", scale) + setattr(block, "kohya_scaler", flag.scaler) + setattr(block, "_orignal_forawrd", block.forward) + + def new_forawrd(self, hidden_states, *args, **kwargs): + hidden_states = scale_latents( + hidden_states, + self.kohya_scale, + self.kohya_scaler, + False, + ) + if "scale" in kwargs: + kwargs.pop("scale") + return self._orignal_forawrd(hidden_states, *args, **kwargs) + + block.forward = partial(new_forawrd, block) + # In case someone wants to work on smooth scaling + # The double comments are there 'cause of attempts made before + # else: + # scale_ratio = step / (total_steps * s) + # downscale = min( + # (1 - flag.base_scale) * scale_ratio + flag.base_scale, + # ) + # # upscale = _round( + # upscale = (1.0 / flag.base_scale) * 
(flag.base_scale / downscale) #, 0.25 + # # ) + # unet.down_blocks[d].kohya_scale = downscale # _round(downscale, 0.2) # type: ignore + # unet.up_blocks[out_d].kohya_scale = upscale # type: ignore + # print( + # unet.down_blocks[d].kohya_scale, unet.up_blocks[out_d].kohya_scale + # ) + return unet + elif hasattr(unet.down_blocks[d], "kohya_scale") and ( + p1[1] != p2[1] or s == p2[0] + ): + unet.down_blocks[d].forward = unet.down_blocks[d]._orignal_forawrd + if hasattr(unet.up_blocks[out_d], "_orignal_forawrd"): + unet.up_blocks[out_d].forward = unet.up_blocks[out_d]._orignal_forawrd + step_limit = step + return unet + + +def post_process(unet: UNet2DConditionModel) -> UNet2DConditionModel: + for i, b in enumerate(unet.down_blocks): + if hasattr(b, "kohya_scale"): + unet.down_blocks[i].forward = b._orignal_forawrd + for i, b in enumerate(unet.up_blocks): + if hasattr(b, "kohya_scale"): + unet.up_blocks[i].forward = b._orignal_forawrd + + global step_limit + + step_limit = 0 + return unet diff --git a/core/inference/utilities/latents.py b/core/inference/utilities/latents.py index 0efb2cf8c..fecd0d4f2 100644 --- a/core/inference/utilities/latents.py +++ b/core/inference/utilities/latents.py @@ -1,3 +1,4 @@ +from contextlib import ExitStack import logging import math from time import time @@ -6,7 +7,7 @@ import numpy as np import torch import torch.nn.functional as F -from diffusers.models import vae as diffusers_vae +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL as diffusers_vae from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( StableDiffusionPipeline, ) @@ -18,6 +19,7 @@ from core.inference.utilities.philox import PhiloxGenerator from .random import randn +from core.optimizations.autocast_utils import autocast logger = logging.getLogger(__name__) @@ -149,7 +151,7 @@ def prepare_mask_latents( ) mask = mask.to(device=device, dtype=dtype) - masked_image = masked_image.to(device=device, dtype=dtype) + masked_image = masked_image.to(device=device, dtype=vae.dtype) masked_image_latents = vae_scaling_factor * vae.encode( masked_image ).latent_dist.sample(generator=generator) @@ -216,6 +218,36 @@ def prepare_image( return image +def preprocess_adapter_image(image, height, width): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, Image.Image): + image = [image] + + if isinstance(image[0], Image.Image): + image = [ + np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) + for i in image + ] + image = [ + i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image + ] # expand [h, w] or [h, w, c] to [b, h, w, c] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + if image[0].ndim == 3: # type: ignore + image = torch.stack(image, dim=0) # type: ignore + elif image[0].ndim == 4: # type: ignore + image = torch.cat(image, dim=0) # type: ignore + else: + raise ValueError( + f"Invalid image tensor! 
Expecting image tensor with 3 or 4 dimension, but recive: {image[0].ndim}" # type: ignore + ) + return image + + def preprocess_mask(mask): mask = mask.convert("L") # w, h = mask.size @@ -242,17 +274,27 @@ def prepare_latents( dtype: torch.dtype, device: torch.device, generator: Union[PhiloxGenerator, torch.Generator], + frames: Optional[int] = None, latents=None, latent_channels: Optional[int] = None, align_to: int = 1, ): if image is None: - shape = ( - batch_size, - pipe.unet.config.in_channels, # type: ignore - (math.ceil(height / align_to) * align_to) // pipe.vae_scale_factor, # type: ignore - (math.ceil(width / align_to) * align_to) // pipe.vae_scale_factor, # type: ignore - ) + if frames is not None: + shape = ( + batch_size, + pipe.unet.config.in_channels, + frames, + (math.ceil(height / align_to) * align_to) // pipe.vae_scale_factor, # type: ignore + (math.ceil(width / align_to) * align_to) // pipe.vae_scale_factor, # type: ignore + ) + else: + shape = ( + batch_size, + pipe.unet.config.in_channels, # type: ignore + (math.ceil(height / align_to) * align_to) // pipe.vae_scale_factor, # type: ignore + (math.ceil(width / align_to) * align_to) // pipe.vae_scale_factor, # type: ignore + ) if latents is None: # randn does not work reproducibly on mps @@ -265,15 +307,26 @@ def prepare_latents( latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler - latents = latents * pipe.scheduler.init_noise_sigma # type: ignore + sigma = pipe.scheduler.init_noise_sigma + if isinstance(sigma, torch.Tensor): + sigma = sigma.to(dtype=latents.dtype, device=latents.device) + latents = latents * sigma # type: ignore + if frames is not None: + latents = latents.to(memory_format=torch.channels_last_3d) # type: ignore return latents, None, None else: if image.shape[1] != 4: image = pad_tensor(image, pipe.vae_scale_factor) - init_latent_dist = pipe.vae.encode(image.to(config.api.device, dtype=pipe.vae.dtype)).latent_dist # type: ignore - init_latents = init_latent_dist.sample(generator=generator) + with ExitStack() as gs: + if pipe.vae.config["force_upcast"] or config.api.upcast_vae: + gs.enter_context(autocast(dtype=torch.float32)) + init_latent_dist = pipe.vae.encode(image.to(config.api.device, dtype=pipe.vae.dtype)).latent_dist # type: ignore + + if pipe.vae.config["force_upcast"] or config.api.upcast_vae: + gs.enter_context(autocast(dtype=config.api.load_dtype)) + init_latents = init_latent_dist.sample(generator=generator) # type: ignore init_latents = 0.18215 * init_latents - init_latents = torch.cat([init_latents] * batch_size, dim=0) + init_latents = torch.cat([init_latents] * batch_size, dim=0) # type: ignore else: logger.debug("Skipping VAE encode, already have latents") init_latents = image @@ -283,7 +336,7 @@ def prepare_latents( init_latents_orig = init_latents shape = init_latents.shape - if latent_channels is not None: + if latent_channels is not None and latent_channels != shape[1]: shape = ( batch_size, latent_channels, # type: ignore @@ -297,77 +350,7 @@ def prepare_latents( return latents, init_latents_orig, noise -def bislerp_original(samples, width, height): - shape = list(samples.shape) - width_scale = (shape[3]) / (width) - height_scale = (shape[2]) / (height) - - shape[3] = width - shape[2] = height - out1 = torch.empty( - shape, dtype=samples.dtype, layout=samples.layout, device=samples.device - ) - - def algorithm(in1, in2, t): - dims = in1.shape - val = t - - # flatten to batches - low = in1.reshape(dims[0], -1) - high = 
in2.reshape(dims[0], -1) - - low_weight = torch.norm(low, dim=1, keepdim=True) - low_weight[low_weight == 0] = 0.0000000001 - low_norm = low / low_weight - high_weight = torch.norm(high, dim=1, keepdim=True) - high_weight[high_weight == 0] = 0.0000000001 - high_norm = high / high_weight - - dot_prod = (low_norm * high_norm).sum(1) - dot_prod[dot_prod > 0.9995] = 0.9995 - dot_prod[dot_prod < -0.9995] = -0.9995 - omega = torch.acos(dot_prod) - so = torch.sin(omega) - res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low_norm + ( - torch.sin(val * omega) / so - ).unsqueeze(1) * high_norm - res *= low_weight * (1.0 - val) + high_weight * val - return res.reshape(dims) - - for x_dest in range(shape[3]): - for y_dest in range(shape[2]): - y = (y_dest + 0.5) * height_scale - 0.5 - x = (x_dest + 0.5) * width_scale - 0.5 - - x1 = max(math.floor(x), 0) - x2 = min(x1 + 1, samples.shape[3] - 1) - wx = x - math.floor(x) - - y1 = max(math.floor(y), 0) - y2 = min(y1 + 1, samples.shape[2] - 1) - wy = y - math.floor(y) - - in1 = samples[:, :, y1, x1] - in2 = samples[:, :, y1, x2] - in3 = samples[:, :, y2, x1] - in4 = samples[:, :, y2, x2] - - if (x1 == x2) and (y1 == y2): - out_value = in1 - elif x1 == x2: - out_value = algorithm(in1, in3, wy) - elif y1 == y2: - out_value = algorithm(in1, in2, wx) - else: - o1 = algorithm(in1, in2, wx) - o2 = algorithm(in3, in4, wx) - out_value = algorithm(o1, o2, wy) - - out1[:, :, y_dest, x_dest] = out_value - return out1 - - -def bislerp_gabeified(samples, width, height): +def bislerp(samples, width, height): device = samples.device def slerp(b1, b2, r): @@ -487,7 +470,6 @@ def scale_latents( "Interpolate the latents to the desired scale." s = time() - logger.debug(f"Scaling latents with shape {list(latents.shape)}, scale: {scale}") # Scale and round to multiple of 32 @@ -495,10 +477,8 @@ def scale_latents( height_truncated = int(latents.shape[3] * scale) # Scale the latents - if latent_scale_mode == "bislerp-tortured": - interpolated = bislerp_gabeified(latents, height_truncated, width_truncated) - elif latent_scale_mode == "bislerp-original": - interpolated = bislerp_original(latents, height_truncated, width_truncated) + if latent_scale_mode == "bislerp": + interpolated = bislerp(latents, height_truncated, width_truncated) else: interpolated = F.interpolate( latents, diff --git a/core/inference/utilities/load.py b/core/inference/utilities/load.py new file mode 100644 index 000000000..1462ce805 --- /dev/null +++ b/core/inference/utilities/load.py @@ -0,0 +1,53 @@ +from typing import Dict +from io import BytesIO +import logging +from time import perf_counter as time + +import torch +from safetensors.torch import load, load_file + +from core.config import config + + +logger = logging.getLogger(__name__) + + +def load_checkpoint(path: str, from_safetensors: bool) -> Dict[str, torch.Tensor]: + now = time() + if from_safetensors: + dev = str(config.api.load_device) + if "cuda" in dev: + dev = int(dev.split(":")[1]) + + if config.api.stream_load: + with open(path, "rb") as f: + checkpoint = load(f.read()) + checkpoint = { + k: v.to(device=config.api.load_device) for k, v in checkpoint.items() + } + else: + checkpoint = load_file(path, device=dev) # type: ignore + else: + if config.api.stream_load: + with open(path, "rb") as f: + buffer = BytesIO(f.read()) + checkpoint = torch.load(buffer, map_location=config.api.load_device) + else: + try: + checkpoint = torch.load( + path, + mmap=True, + weights_only=True, + map_location=config.api.load_device, + ) + except 
RuntimeError: + # File is really old / wasn't saved with "_use_new_zipfile_serialization=True" + # so we cannot mmap. + checkpoint = torch.load( + path, + mmap=False, + weights_only=True, + map_location=config.api.load_device, + ) + logger.debug(f'Loading "{path}" took {round(time() - now, 2)}s.') + return checkpoint diff --git a/core/inference/utilities/lwp.py b/core/inference/utilities/lwp.py index d5b37d127..c7d13747c 100644 --- a/core/inference/utilities/lwp.py +++ b/core/inference/utilities/lwp.py @@ -35,6 +35,7 @@ re.X, ) + special_parser = re.compile( r"\<(lora|ti):([^\:\(\)\<\>\[\]]+)(?::[\s]*([+-]?(?:[0-9]*[.])?[0-9]+))?\>|\<(lora|ti):(http[^\(\)\<\>\[\]]+\/[^:]+)(?::[\s]*([+-]?(?:[0-9]*[.])?[0-9]+))?\>" ) @@ -171,9 +172,7 @@ def multiply_range(start_position, multiplier): return res -def get_prompts_with_weights( - pipe: StableDiffusionPipeline, prompt: List[str], max_length: int -): +def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int): r""" Tokenize a list of prompts and return its tokens with weights of each token. @@ -188,7 +187,7 @@ def get_prompts_with_weights( text_weight = [] for word, weight in texts_and_weights: # tokenize and discard the starting and the ending token - token = pipe.tokenizer(word, max_length=max_length, truncation=True).input_ids[1:-1] # type: ignore + token = tokenizer(word, max_length=max_length, truncation=True).input_ids[1:-1] # type: ignore text_token += token # copy the weight by length of token text_weight += [weight] * len(token) @@ -249,7 +248,8 @@ def get_unweighted_text_embeddings( text_input: torch.Tensor, chunk_length: int, no_boseos_middle: Optional[bool] = True, -): + text_encoder=None, +) -> Tuple[torch.Tensor, torch.Tensor]: """ When the length of tokens is a multiple of the capacity of the text encoder, it should be split into chunks and sent to the text encoder individually. 
@@ -259,41 +259,82 @@ def get_unweighted_text_embeddings( max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) - if max_embeddings_multiples > 1: - text_embeddings = [] - for i in range(max_embeddings_multiples): - # extract the i-th chunk - text_input_chunk = text_input[ - :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2 - ].clone() - - # cover the head and the tail by the starting and the ending tokens - text_input_chunk[:, 0] = text_input[0, 0] - text_input_chunk[:, -1] = text_input[0, -1] + if not hasattr(pipe, "text_encoder_2"): + if max_embeddings_multiples > 1: + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[ + :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2 + ].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + text_input_chunk[:, -1] = text_input[0, -1] + if hasattr(pipe, "clip_inference"): + text_embedding = pipe.clip_inference(text_input_chunk) + else: + text_embedding = pipe.text_encoder(text_input_chunk)[0] # type: ignore + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) # type: ignore + else: if hasattr(pipe, "clip_inference"): - text_embedding = pipe.clip_inference(text_input_chunk) + text_embeddings = pipe.clip_inference(text_input) else: - text_embedding = pipe.text_encoder(text_input_chunk)[0] # type: ignore - - if no_boseos_middle: - if i == 0: - # discard the ending token - text_embedding = text_embedding[:, :-1] - elif i == max_embeddings_multiples - 1: - # discard the starting token - text_embedding = text_embedding[:, 1:] - else: - # discard both starting and ending tokens - text_embedding = text_embedding[:, 1:-1] - - text_embeddings.append(text_embedding) - text_embeddings = torch.concat(text_embeddings, axis=1) # type: ignore + text_embeddings = pipe.text_encoder(text_input)[0] # type: ignore + return text_embeddings, None # type: ignore else: - if hasattr(pipe, "clip_inference"): - text_embeddings = pipe.clip_inference(text_input) + if max_embeddings_multiples > 1: + text_embeddings = [] + hidden_states = [] + for i in range(max_embeddings_multiples): + text_input_chunk = text_input[ + :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2 + ].clone() + + text_input_chunk[:, 0] = text_input[0, 0] + text_input_chunk[:, -1] = text_input[0, -1] + text_embedding = text_encoder( # type: ignore + text_input_chunk, output_hidden_states=True + ) + + if no_boseos_middle: + if i == 0: + text_embedding.hidden_states[-2] = text_embedding.hidden_states[ + -2 + ][:, :-1] + elif i == max_embeddings_multiples - 1: + text_embedding.hidden_states[-2] = text_embedding.hidden_states[ + -2 + ][:, 1:] + else: + text_embedding.hidden_states[-2] = text_embedding.hidden_states[ + -2 + ][:, 1:-1] + text_embeddings.append(text_embedding) + text_embeddings = torch.concat([x.hidden_states[-2] for x in text_embeddings], axis=1) # type: ignore + # Temporary, but hey, at least it works :) + # TODO: try and fix this monstrosity :/ + hidden_states = text_embeddings[-1][0].unsqueeze(0) # type: ignore + # text_embeddings = 
torch.Tensor(hidden_states.shape[0]) else: - text_embeddings = pipe.text_encoder(text_input)[0] # type: ignore - return text_embeddings + text_embeddings = text_encoder(text_input, output_hidden_states=True) # type: ignore + hidden_states = text_embeddings[0] + text_embeddings = text_embeddings.hidden_states[-2] + logger.debug(f"{hidden_states.shape} {text_embeddings.shape}") + return text_embeddings, hidden_states def get_weighted_text_embeddings( @@ -306,6 +347,8 @@ def get_weighted_text_embeddings( skip_weighting: Optional[bool] = False, seed: int = -1, prompt_expansion_settings: Optional[Dict] = None, + text_encoder=None, + tokenizer=None, ): r""" Prompts can be assigned with local weights using brackets. For example, @@ -334,7 +377,10 @@ def get_weighted_text_embeddings( """ prompt_expansion_settings = prompt_expansion_settings or {} - max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 # type: ignore + tokenizer = tokenizer or pipe.tokenizer + text_encoder = text_encoder or pipe.text_encoder + + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 # type: ignore if isinstance(prompt, str): prompt = [prompt] @@ -436,18 +482,18 @@ def get_weighted_text_embeddings( if not skip_parsing: prompt_tokens, prompt_weights = get_prompts_with_weights( - pipe, prompt, max_length - 2 + tokenizer, prompt, max_length - 2 ) if uncond_prompt is not None: if isinstance(uncond_prompt, str): uncond_prompt = [uncond_prompt] uncond_tokens, uncond_weights = get_prompts_with_weights( - pipe, uncond_prompt, max_length - 2 + tokenizer, uncond_prompt, max_length - 2 ) else: prompt_tokens = [ token[1:-1] - for token in pipe.tokenizer( # type: ignore + for token in tokenizer( # type: ignore prompt, max_length=max_length, truncation=True ).input_ids ] @@ -457,7 +503,7 @@ def get_weighted_text_embeddings( uncond_prompt = [uncond_prompt] uncond_tokens = [ token[1:-1] - for token in pipe.tokenizer( # type: ignore + for token in tokenizer( # type: ignore uncond_prompt, max_length=max_length, truncation=True ).input_ids ] @@ -470,14 +516,14 @@ def get_weighted_text_embeddings( max_embeddings_multiples = min( max_embeddings_multiples, # type: ignore - (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1, # type: ignore + (max_length - 1) // (tokenizer.model_max_length - 2) + 1, # type: ignore ) max_embeddings_multiples = max(1, max_embeddings_multiples) # type: ignore - max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 # type: ignore + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 # type: ignore # pad the length of tokens and weights - bos = pipe.tokenizer.bos_token_id # type: ignore - eos = pipe.tokenizer.eos_token_id # type: ignore + bos = tokenizer.bos_token_id # type: ignore + eos = tokenizer.eos_token_id # type: ignore prompt_tokens, prompt_weights = pad_tokens_and_weights( prompt_tokens, prompt_weights, @@ -485,10 +531,10 @@ def get_weighted_text_embeddings( bos, eos, no_boseos_middle=no_boseos_middle, # type: ignore - chunk_length=pipe.tokenizer.model_max_length, # type: ignore + chunk_length=tokenizer.model_max_length, # type: ignore ) prompt_tokens = torch.tensor( - prompt_tokens, dtype=torch.long, device=pipe.device if hasattr(pipe, "clip_inference") else pipe.text_encoder.device # type: ignore + prompt_tokens, dtype=torch.long, device=pipe.device if hasattr(pipe, "clip_inference") else text_encoder.device # type: ignore ) if uncond_prompt is not None: uncond_tokens, uncond_weights = 
pad_tokens_and_weights( @@ -498,41 +544,43 @@ def get_weighted_text_embeddings( bos, eos, no_boseos_middle=no_boseos_middle, # type: ignore - chunk_length=pipe.tokenizer.model_max_length, # type: ignore + chunk_length=tokenizer.model_max_length, # type: ignore ) uncond_tokens = torch.tensor( - uncond_tokens, dtype=torch.long, device=pipe.device if hasattr(pipe, "clip_inference") else pipe.text_encoder.device # type: ignore + uncond_tokens, dtype=torch.long, device=pipe.device if hasattr(pipe, "clip_inference") else text_encoder.device # type: ignore ) # get the embeddings - text_embeddings = get_unweighted_text_embeddings( + text_embeddings, hidden_states = get_unweighted_text_embeddings( pipe, # type: ignore prompt_tokens, - pipe.tokenizer.model_max_length, # type: ignore + tokenizer.model_max_length, # type: ignore no_boseos_middle=no_boseos_middle, + text_encoder=text_encoder, ) prompt_weights = torch.tensor( - prompt_weights, dtype=text_embeddings.dtype, device=pipe.device if hasattr(pipe, "clip_inference") else pipe.text_encoder.device # type: ignore + prompt_weights, dtype=text_embeddings.dtype, device=pipe.device if hasattr(pipe, "clip_inference") else text_encoder.device # type: ignore ) if uncond_prompt is not None: - uncond_embeddings = get_unweighted_text_embeddings( + uncond_embeddings, uncond_hidden_states = get_unweighted_text_embeddings( pipe, # type: ignore uncond_tokens, # type: ignore - pipe.tokenizer.model_max_length, # type: ignore + tokenizer.model_max_length, # type: ignore no_boseos_middle=no_boseos_middle, + text_encoder=text_encoder, ) uncond_weights = torch.tensor( - uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device if hasattr(pipe, "clip_inference") else pipe.text_encoder.device # type: ignore + uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device if hasattr(pipe, "clip_inference") else text_encoder.device # type: ignore ) # assign weights to the prompts and normalize in the sense of mean if (not skip_parsing) and (not skip_weighting): previous_mean = ( - text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) # type: ignore ) text_embeddings *= prompt_weights.unsqueeze(-1) current_mean = ( - text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) # type: ignore ) text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) if uncond_prompt is not None: @@ -552,5 +600,5 @@ def get_weighted_text_embeddings( ) if uncond_prompt is not None: - return text_embeddings, uncond_embeddings # type: ignore - return text_embeddings, None + return text_embeddings, hidden_states, uncond_embeddings, uncond_hidden_states # type: ignore + return text_embeddings, hidden_states, None, None diff --git a/core/inference/utilities/prompt_expansion/downloader.py b/core/inference/utilities/prompt_expansion/downloader.py index 59523da42..89cc70a9d 100644 --- a/core/inference/utilities/prompt_expansion/downloader.py +++ b/core/inference/utilities/prompt_expansion/downloader.py @@ -40,9 +40,13 @@ def download_model(): folder.mkdir() if isinstance(d, tuple): download_file(d[0], folder, add_filename=True) - os.rename( - (folder / "fooocus_expansion.bin").absolute().resolve().as_posix(), - (folder / "pytorch_model.bin").absolute().resolve().as_posix(), - ) + try: + # Should fix weird cases where it thinks d is a tuple while it is not??? 
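As background for the weighting block updated in `get_weighted_text_embeddings` above: token embeddings are multiplied by their parsed weights and then rescaled so the tensor's overall mean is unchanged, which keeps the result in the magnitude range the text encoder produced. A minimal sketch, assuming embeddings of shape `(batch, tokens, dim)` and weights of shape `(batch, tokens)`; the helper name is illustrative.

```python
import torch

def apply_prompt_weights(text_embeddings: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # scale each token embedding by its weight, then restore the original global mean
    previous_mean = text_embeddings.float().mean(dim=[-2, -1]).to(text_embeddings.dtype)
    weighted = text_embeddings * weights.unsqueeze(-1)
    current_mean = weighted.float().mean(dim=[-2, -1]).to(weighted.dtype)
    return weighted * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
```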
+ os.rename( + (folder / "fooocus_expansion.bin").absolute().resolve().as_posix(), + (folder / "pytorch_model.bin").absolute().resolve().as_posix(), + ) + except Exception: + pass else: download_file(d, folder, add_filename=True) diff --git a/core/inference/utilities/prompt_expansion/expand.py b/core/inference/utilities/prompt_expansion/expand.py index 000c3493f..7416ef0f9 100644 --- a/core/inference/utilities/prompt_expansion/expand.py +++ b/core/inference/utilities/prompt_expansion/expand.py @@ -98,7 +98,7 @@ def _device_dtype(prompt_to_prompt_device) -> Tuple[torch.device, torch.dtype]: if prompt_to_prompt_device == "gpu" else torch.device("cpu") ) - dtype = config.api.dtype if prompt_to_prompt_device == "gpu" else torch.float32 + dtype = config.api.load_dtype if prompt_to_prompt_device == "gpu" else torch.float32 return (device, dtype) @@ -128,7 +128,7 @@ def _load(prompt_to_prompt_model, prompt_to_prompt_device): _GPT.eval() device, dtype = _device_dtype(prompt_to_prompt_device) - _GPT = _GPT.to(device=device, dtype=dtype) + _GPT = _GPT.to(device=device, dtype=dtype) # type: ignore @torch.inference_mode() diff --git a/core/inference/utilities/sag/__init__.py b/core/inference/utilities/sag/__init__.py new file mode 100644 index 000000000..881ea1922 --- /dev/null +++ b/core/inference/utilities/sag/__init__.py @@ -0,0 +1,62 @@ +from typing import Callable + +import torch + +from core.scheduling import KdiffusionSchedulerAdapter +from .cross_attn import CrossAttnStoreProcessor +from .sag_utils import pred_epsilon, pred_x0, sag_masking +from .kdiff import calculate_sag as kdiff +from .diffusers import calculate_sag as diff + + +def calculate_sag( + pipe, + call: Callable, + store_processor, + latent: torch.Tensor, + noise_pred_uncond: torch.Tensor, + timestep: torch.IntTensor, + map_size: tuple, + text_embeddings: torch.Tensor, + scale: float, + cfg: float, + dtype: torch.dtype, + **additional_kwargs, +) -> torch.Tensor: + new_kwargs = {} + for kw, arg in additional_kwargs.items(): + if arg is not None and isinstance(arg, torch.Tensor): + if arg.shape[0] != 1: + arg, _ = arg.chunk(2) + new_kwargs[kw] = arg + + if isinstance(pipe.scheduler, KdiffusionSchedulerAdapter): + return kdiff( + pipe, + call, + store_processor, + latent, + noise_pred_uncond, + timestep, + map_size, + text_embeddings, + scale, + cfg, + dtype, + **new_kwargs, + ) + else: + return diff( + pipe, + call, + store_processor, + latent, + noise_pred_uncond, + timestep, + map_size, + text_embeddings, + scale, + cfg, + dtype, + **new_kwargs, + ) diff --git a/core/inference/utilities/sag/cross_attn.py b/core/inference/utilities/sag/cross_attn.py new file mode 100644 index 000000000..f699c9bcd --- /dev/null +++ b/core/inference/utilities/sag/cross_attn.py @@ -0,0 +1,190 @@ +import torch + +from core.config import config + + +class CrossAttnStoreProcessor: + "Modified Cross Attention Processor with capabilities to store probabilities." 
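A hedged usage sketch for the SAG pieces introduced here, modelled on how diffusers' own SAG pipeline wires this up rather than on code in this diff (the attention-module path, `unet_call`, and the surrounding variable names are assumptions): the store processor is attached to the mid-block's first self-attention so the conditional pass records an attention map, and the tensor returned by `calculate_sag` is then added on top of the usual classifier-free-guidance combination.

```python
import torch
from core.inference.utilities.sag import CrossAttnStoreProcessor, calculate_sag

def sag_guided_noise(pipe, unet_call, latents, latent_input, t, embeds, map_size,
                     cfg_scale: float, sag_scale: float) -> torch.Tensor:
    # attach the store processor so the next UNet pass records its attention map
    # (attribute path mirrors diffusers' SAG pipeline and may differ per model)
    store = CrossAttnStoreProcessor()
    pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store

    # classifier-free guidance on a doubled batch (ordering is illustrative)
    noise_pred_uncond, noise_pred_cond = unet_call(latent_input, t, cond=embeds).chunk(2)
    noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)

    # add the self-attention-guidance correction term computed by the new helper
    return noise_pred + calculate_sag(
        pipe, unet_call, store, latents, noise_pred_uncond,
        t, map_size, embeds, sag_scale, cfg_scale, latents.dtype,
    )
```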
+ + def __init__(self): + self.attention_map = None + self.use_sdpa = any( + [x == config.api.attention_processor for x in ["sdpa", "xformers"]] + ) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + if not self.use_sdpa: + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm( + hidden_states.transpose(1, 2) + ).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + self.attention_map = ( + attention_probs.reshape(batch_size, -1, *attention_probs.shape[1:]) + .mean(1) + .sum(1) + ) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) # type: ignore + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + else: + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm( + hidden_states.transpose(1, 2) + ).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + 
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + ) + + dummy_value_eye = torch.eye( + value.shape[2], device=value.device, dtype=value.dtype + ).expand(batch_size, attn.heads, -1, -1) + self.attention_map = ( + torch.nn.functional.scaled_dot_product_attention( + query, + key, + dummy_value_eye, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + ) + .mean(1) + .sum(1) + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) # type: ignore + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states diff --git a/core/inference/utilities/sag/diffusers.py b/core/inference/utilities/sag/diffusers.py new file mode 100644 index 000000000..ddb4c4b37 --- /dev/null +++ b/core/inference/utilities/sag/diffusers.py @@ -0,0 +1,38 @@ +from typing import Callable + +import torch + +from .sag_utils import sag_masking, pred_epsilon, pred_x0 + + +def calculate_sag( + pipe, + call: Callable, + store_processor, + latent: torch.Tensor, + noise: torch.Tensor, + timestep: torch.IntTensor, + map_size: tuple, + text_embeddings: torch.Tensor, + scale: float, + cfg: float, + dtype: torch.dtype, + **additional_kwargs, +) -> torch.Tensor: + pred: torch.Tensor = pred_x0(pipe, latent, noise, timestep) + if cfg > 1: + cond_attn, _ = store_processor.attention_map.chunk(2) + text_embeddings, _ = text_embeddings.chunk(2) + else: + cond_attn = store_processor.attention_map + + eps = pred_epsilon(pipe, latent, noise, timestep) + degraded: torch.Tensor = sag_masking(pipe, pred, cond_attn, map_size, timestep, eps) + + degraded_prep = call( + degraded.to(dtype=dtype), + timestep, + cond=text_embeddings, + **additional_kwargs, + ) + return scale * (noise - degraded_prep) diff --git a/core/inference/utilities/sag/kdiff.py b/core/inference/utilities/sag/kdiff.py new file mode 100644 index 000000000..bb7bbd4a5 --- /dev/null +++ b/core/inference/utilities/sag/kdiff.py @@ -0,0 +1,42 @@ +from typing import Callable + +import torch + +from .sag_utils import sag_masking + + +def calculate_sag( + pipe, + call: Callable, + store_processor, + latent: torch.Tensor, + noise: torch.Tensor, + timestep: torch.IntTensor, + map_size: tuple, + text_embeddings: torch.Tensor, + scale: float, + cfg: float, + dtype: torch.dtype, + **additional_kwargs, +) -> torch.Tensor: + pred: torch.Tensor = noise # noise is already pred_x0 with kdiff + if cfg > 1: + cond_attn, _ = store_processor.attention_map.chunk(2) + text_embeddings, _ = text_embeddings.chunk(2) + else: + cond_attn = store_processor.attention_map + + degraded: torch.Tensor = sag_masking(pipe, pred, cond_attn, map_size, timestep, 0) + + # messed up the order of these two, spent half an hour looking for problems. 
+ # Epsilon + compensation = noise - degraded + degraded = degraded - (noise - latent) + + degraded_pred = call( + degraded.to(dtype=dtype), + timestep, + cond=text_embeddings, + **additional_kwargs, + ) + return (noise - (degraded_pred + compensation)) * scale diff --git a/core/inference/pytorch/sag/sag_utils.py b/core/inference/utilities/sag/sag_utils.py similarity index 70% rename from core/inference/pytorch/sag/sag_utils.py rename to core/inference/utilities/sag/sag_utils.py index 2213940b4..f8ecec58f 100644 --- a/core/inference/pytorch/sag/sag_utils.py +++ b/core/inference/utilities/sag/sag_utils.py @@ -1,6 +1,8 @@ import torch import torch.nn.functional as F +from ..anisotropic import simple_gaussian_2d + def pred_x0(pipe, sample, model_output, timestep): """ @@ -35,18 +37,19 @@ def pred_x0(pipe, sample, model_output, timestep): return pred_original_sample -def sag_masking(pipe, original_latents, attn_map, map_size, t, eps): +def sag_masking(pipe, original_latents: torch.Tensor, attn_map, map_size, t, eps): "sag_masking" # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf - _, hw1, hw2 = attn_map.shape - b, latent_channel, latent_h, latent_w = original_latents.shape + if original_latents.dim() == 5: + b, latent_channel, _, latent_h, latent_w = original_latents.shape + else: + b, latent_channel, latent_h, latent_w = original_latents.shape h = pipe.unet.config.attention_head_dim if isinstance(h, list): h = h[-1] # Produce attention mask - attn_map = attn_map.reshape(b, h, hw1, hw2) - attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0 + attn_mask = attn_map > 1.0 attn_mask = ( attn_mask.reshape(b, map_size[0], map_size[1]) .unsqueeze(1) @@ -56,13 +59,14 @@ def sag_masking(pipe, original_latents, attn_map, map_size, t, eps): attn_mask = F.interpolate(attn_mask, (latent_h, latent_w)) # Blur according to the self-attention mask - degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0) + degraded_latents = simple_gaussian_2d(original_latents, kernel_size=9, sigma=1.0) # type: ignore degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask) # Noise it again to match the noise level - degraded_latents = pipe.scheduler.add_noise( - degraded_latents, noise=eps, timesteps=torch.tensor([t]) - ) + if isinstance(eps, torch.Tensor): + degraded_latents = pipe.scheduler.add_noise( + degraded_latents, noise=eps, timesteps=torch.tensor([t]) + ) return degraded_latents @@ -89,25 +93,3 @@ def pred_epsilon(pipe, sample, model_output, timestep): ) return pred_eps - - -def gaussian_blur_2d(img, kernel_size, sigma): - "Blurs an image with gaussian blur." 
- ksize_half = (kernel_size - 1) * 0.5 - - x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) - - pdf = torch.exp(-0.5 * (x / sigma).pow(2)) - - x_kernel = pdf / pdf.sum() - x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) - - kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) - kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) - - padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] - - img = F.pad(img, padding, mode="reflect") - img = F.conv2d(img, kernel2d, groups=img.shape[-3]) - - return img diff --git a/core/inference/utilities/scalecrafter.py b/core/inference/utilities/scalecrafter.py new file mode 100644 index 000000000..e4c779655 --- /dev/null +++ b/core/inference/utilities/scalecrafter.py @@ -0,0 +1,393 @@ +from copy import deepcopy +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple +import os +import yaml +import math +from pathlib import Path + +from diffusers.models.unet_2d_condition import UNet2DConditionModel +import torch +import scipy + + +@dataclass +class ScalecrafterSettings: + inflate_tau: float + ndcfg_tau: float + dilate_tau: float + + progressive: bool + + dilation_settings: Dict[str, float] + ndcfg_dilate_settings: Dict[str, float] + disperse_list: List[str] + disperse: Optional[torch.Tensor] + + height: int = 0 + width: int = 0 + base: str = "sd15" + + +SCALECRAFTER_DIR = Path("data/scalecrafter") + + +_unet_inflate, _unet_inflate_vanilla = None, None +_backup_forwards = dict() + + +def find_config_closest_to( + base: str, height: int, width: int, disperse: bool = False +) -> ScalecrafterSettings: + """Find ScaleCrafter config for specified SDxx version closest to provided resolution.""" + # Normalize base to the format in SCALECRAFTER_CONFIG + base = base.replace(".", "").lower() + if base == "sd1x": + base = "sd15" + elif base == "sd2x": + base = "sd21" + + resolutions = [ + x + for x in SCALECRAFTER_CONFIG + if x.base == base and ((x.disperse is not None) == disperse) + ] + + # If there are no resolutions for said base, use default one. + if len(resolutions) == 0: + resolutions = [SCALECRAFTER_CONFIG[0]] + + # Map resolutions to a tuple of (name, resolution -> h*w) + resolutions = [ + (x, abs((x.height * x.width * 64) - (height * width))) for x in resolutions + ] + + # Read the settings of the one with the lowest resolution. 
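For orientation, a hedged sketch of how this module is intended to be driven from a denoising loop (only the imported names come from this file; `unet_call` and the loop structure are illustrative): a config is chosen by resolution, kernels are optionally inflated once up front, and each step wraps the UNet call with `scale`/`post_scale` so the re-dilated convolutions are only active for the configured fraction of the schedule.

```python
from core.inference.utilities.scalecrafter import (
    find_config_closest_to,
    post_scale,
    scale,
    scale_setup,
)

def denoise_with_scalecrafter(unet, unet_call, timesteps, height: int, width: int):
    # pick the shipped config whose tuned resolution is closest to the requested size
    settings = find_config_closest_to("sd15", height, width)
    scale_setup(unet, settings)  # pre-inflates kernels when a disperse transform is configured

    total_steps = len(timesteps)
    for step, t in enumerate(timesteps):
        # re-dilated convolutions stay active while step / total_steps < dilate_tau
        unet = scale(unet, settings, step, total_steps)
        noise_pred = unet_call(unet, t)
        # restore the original conv forwards; optionally also get the vanilla prediction
        unet, noise_pred_vanilla = post_scale(
            unet, settings, step, total_steps, unet_call, unet, t
        )
        # ...combine noise_pred / noise_pred_vanilla and step the scheduler here...
```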
+ return min(resolutions, key=lambda x: x[1])[0] + + +class ReDilateConvProcessor: + "Conv2d with support for up-/downscaling latents" + + def __init__( + self, + module: torch.nn.Conv2d, + pf_factor: float = 1.0, + mode: str = "bilinear", + activate: bool = True, + ): + self.dilation = math.ceil(pf_factor) + self.factor = float(self.dilation / pf_factor) + self.module = module + self.mode = mode + self.activate = activate + + def __call__( + self, input: torch.Tensor, scale: float, *args, **kwargs + ) -> torch.Tensor: + if len(args) > 0: + print(len(args)) + print("".join(map(str, map(type, args)))) + if self.activate: + ori_dilation, ori_padding = self.module.dilation, self.module.padding + inflation_kernel_size = (self.module.weight.shape[-1] - 3) // 2 + self.module.dilation, self.module.padding = self.dilation, ( # type: ignore + self.dilation * (1 + inflation_kernel_size), + self.dilation * (1 + inflation_kernel_size), + ) + ori_size, new_size = ( + ( + int(input.shape[-2] / self.module.stride[0]), + int(input.shape[-1] / self.module.stride[1]), + ), + ( + round(input.shape[-2] * self.factor), + round(input.shape[-1] * self.factor), + ), + ) + input = torch.nn.functional.interpolate( + input, size=new_size, mode=self.mode + ) + input = self.module._conv_forward( + input, self.module.weight, self.module.bias + ) + self.module.dilation, self.module.padding = ori_dilation, ori_padding + result = torch.nn.functional.interpolate( + input, size=ori_size, mode=self.mode + ) + return result + else: + return self.module._conv_forward( + input, self.module.weight, self.module.bias + ) + + +def inflate_kernels( + unet: UNet2DConditionModel, + inflate_conv_list: list, + inflation_transform: torch.Tensor, +) -> UNet2DConditionModel: + def replace_module(module: torch.nn.Module, name: List[str], index: list, value): + if len(name) == 1 and len(index) == 0: + setattr(module, name[0], value) + return module + + current_name, next_name = name[0], name[1:] + current_index, next_index = int(index[0]), index[1:] + replace = getattr(module, current_name) + replace[current_index] = replace_module( + replace[current_index], next_name, next_index, value + ) + setattr(module, current_name, replace) + return module + + inflation_transform.to(dtype=unet.dtype, device=unet.device) + + for name, module in unet.named_modules(): + if name in inflate_conv_list: + weight, bias = module.weight.detach(), module.bias.detach() + (i, o, *_), kernel_size = ( + weight.shape, + int(math.sqrt(inflation_transform.shape[0])), + ) + transformed_weight = torch.einsum( + "mn, ion -> iom", + inflation_transform.to(dtype=weight.dtype), + weight.view(i, o, -1), + ) + conv = torch.nn.Conv2d( + o, + i, + (kernel_size, kernel_size), + stride=module.stride, + padding=module.padding, + device=weight.device, + dtype=weight.dtype, + ) + conv.weight.detach().copy_( + transformed_weight.view(i, o, kernel_size, kernel_size) + ) + conv.bias.detach().copy_(bias) # type: ignore + + sub_names = name.split(".") + if name.startswith("mid_block"): + names, indexes = sub_names[1::2], sub_names[2::2] + unet.mid_block = replace_module(unet.mid_block, names, indexes, conv) # type: ignore + else: + names, indexes = sub_names[0::2], sub_names[1::2] + replace_module(unet, names, indexes, conv) + return unet + + +def scale_setup(unet: UNet2DConditionModel, settings: Optional[ScalecrafterSettings]): + global _unet_inflate, _unet_inflate_vanilla + + if _unet_inflate_vanilla is not None: + del _unet_inflate_vanilla, _unet_inflate + + if settings is None: + 
return + + if settings.disperse is not None: + if len(settings.disperse_list) != 0: + _unet_inflate = deepcopy(unet) + _unet_inflate = inflate_kernels( + _unet_inflate, settings.disperse_list, settings.disperse + ) + if settings.ndcfg_tau > 0: + _unet_inflate_vanilla = deepcopy(unet) + _unet_inflate_vanilla = inflate_kernels( + _unet_inflate_vanilla, settings.disperse_list, settings.disperse + ) + + +def scale( + unet: UNet2DConditionModel, + settings: Optional[ScalecrafterSettings], + step: int, + total_steps: int, +) -> UNet2DConditionModel: + if settings is None: + return unet + + global _backup_forwards, _unet_inflate, _unet_inflate_vanilla + + tau = step / total_steps + inflate = settings.inflate_tau < tau and settings.disperse is not None + + if inflate: + unet = _unet_inflate # type: ignore + + for name, module in unet.named_modules(): + if settings.dilation_settings is not None: + if name in settings.dilation_settings.keys(): + _backup_forwards[name] = module.forward + dilate = settings.dilation_settings[name] + if settings.progressive: + dilate = max( + math.ceil( + dilate * ((settings.dilate_tau - tau) / settings.dilate_tau) + ), + 2, + ) + if tau < settings.inflate_tau and name in settings.disperse_list: + dilate = dilate / 2 + module.forward = ReDilateConvProcessor( # type: ignore + module, dilate, mode="bilinear", activate=tau < settings.dilate_tau # type: ignore + ) + + return unet + + +def post_scale( + unet: UNet2DConditionModel, + settings: Optional[ScalecrafterSettings], + step: int, + total_steps: int, + call, + *args, + **kwargs, +) -> Tuple[UNet2DConditionModel, Optional[torch.Tensor]]: + if settings is None: + return unet, None + + global _backup_forwards + for name, module in unet.named_modules(): + if name in _backup_forwards.keys(): + module.forward = _backup_forwards[name] + _backup_forwards.clear() + + tau = step / total_steps + noise_pred_vanilla = None + if tau < settings.ndcfg_tau: + inflate = settings.inflate_tau < tau and settings.disperse is not None + + if inflate: + unet = _unet_inflate_vanilla # type: ignore + + for name, module in unet.named_modules(): + if name in settings.ndcfg_dilate_settings.keys(): + _backup_forwards[name] = module.forward + dilate = settings.ndcfg_dilate_settings[name] + if settings.progressive: + dilate = max( + math.ceil( + dilate * ((settings.ndcfg_tau - tau) / settings.ndcfg_tau) + ), + 2, + ) + if tau < settings.inflate_tau and name in settings.disperse_list: + dilate = dilate / 2 + module.forward = ReDilateConvProcessor( # type: ignore + module, dilate, mode="bilinear", activate=tau < settings.ndcfg_tau # type: ignore + ) + noise_pred_vanilla = call(*args, **kwargs) + + for name, module in unet.named_modules(): + if name in _backup_forwards.keys(): + module.forward = _backup_forwards[name] + _backup_forwards.clear() + + return unet, noise_pred_vanilla + + +class ScaledAttnProcessor: + r""" + Default processor for performing attention-related computations. 
+ """ + + def __init__(self, processor, test_res, train_res): + self.processor = processor + self.test_res = test_res + self.train_res = train_res + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + input_ndim = hidden_states.ndim + if encoder_hidden_states is None: + if input_ndim == 4: + _, _, height, width = hidden_states.shape + sequence_length = height * width + else: + _, sequence_length, _ = hidden_states.shape + + test_train_ratio = float(self.test_res / self.train_res) + train_sequence_length = sequence_length / test_train_ratio + scale_factor = math.log(sequence_length, train_sequence_length) ** 0.5 + else: + scale_factor = 1 + + original_scale = attn.scale + attn.scale = attn.scale * scale_factor + hidden_states = self.processor( + attn, hidden_states, encoder_hidden_states, attention_mask, temb + ) + attn.scale = original_scale + return hidden_states + + +def read_settings(config_name: str): + file = SCALECRAFTER_DIR / "configs" / config_name + with open(file, "r") as f: + config = yaml.safe_load(f) + # 0. Default height and width to unet + base = config_name.split("_")[0].strip().lower().replace(".", "") + steps = config["num_inference_steps"] + height = config["latent_height"] + width = config["latent_width"] + inflate_tau = config["inflate_tau"] / steps + ndcfg_tau = config["ndcfg_tau"] / steps + dilate_tau = config["dilate_tau"] / steps + progressive = config["progressive"] + + dilate_settings = dict() + if config["dilate_settings"] is not None: + with open(os.path.join(SCALECRAFTER_DIR, config["dilate_settings"])) as f: + for line in f.readlines(): + name, dilate = line.strip().split(":") + dilate_settings[name] = float(dilate) + + ndcfg_dilate_settings = dict() + if config["ndcfg_dilate_settings"] is not None: + with open(os.path.join(SCALECRAFTER_DIR, config["ndcfg_dilate_settings"])) as f: + for line in f.readlines(): + name, dilate = line.strip().split(":") + ndcfg_dilate_settings[name] = float(dilate) + + inflate_settings = list() + if config["disperse_settings"] is not None: + with open(os.path.join(SCALECRAFTER_DIR, config["disperse_settings"])) as f: + inflate_settings = list(map(lambda x: x.strip(), f.readlines())) + + disperse = None + if config["disperse_transform"] is not None: + disperse = scipy.io.loadmat( + os.path.join(SCALECRAFTER_DIR, config["disperse_transform"]) + )["R"] + disperse = torch.tensor(disperse, device="cpu") + + return ScalecrafterSettings( + inflate_tau, + ndcfg_tau, + dilate_tau, + # -- + progressive, + # -- + dilate_settings, + ndcfg_dilate_settings, + inflate_settings, + # -- + disperse=disperse, + height=height, + width=width, + base=base, + ) + + +SCALECRAFTER_CONFIG = list(map(read_settings, os.listdir(SCALECRAFTER_DIR / "configs"))) diff --git a/core/inference/utilities/scheduling.py b/core/inference/utilities/scheduling.py index 860de8a49..9d6ec1129 100644 --- a/core/inference/utilities/scheduling.py +++ b/core/inference/utilities/scheduling.py @@ -14,6 +14,7 @@ from core.config import config from core.inference.utilities.philox import PhiloxGenerator from core.scheduling import KdiffusionSchedulerAdapter, create_sampler +from core.scheduling.custom.sasolver import SASolverScheduler from core.types import PyTorchModelType, SigmaScheduler from core.utils import unwrap_enum @@ -51,6 +52,7 @@ def prepare_extra_step_kwargs( scheduler: SchedulerMixin, eta: Optional[float], generator: Union[PhiloxGenerator, torch.Generator], + device: torch.device, ): """prepare extra 
kwargs for the scheduler step, since not all schedulers have the same signature eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -73,7 +75,9 @@ def prepare_extra_step_kwargs( in set(inspect.signature(scheduler.step).parameters.keys()) # type: ignore and config.api.generator != "philox" ) - if accepts_generator: + if accepts_generator and ( + hasattr(generator, "device") and generator.device == device # type: ignore + ): extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -110,17 +114,20 @@ def change_scheduler( else: sched = DDIMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler") # type: ignore - new_scheduler = create_sampler( - alphas_cumprod=sched.alphas_cumprod, # type: ignore - denoiser_enable_quantization=True, - sampler=scheduler, - sigma_type=sigma_type, - eta_noise_seed_delta=0, - sigma_always_discard_next_to_last=False, - sigma_use_old_karras_scheduler=False, - device=model.unet.device, # type: ignore - dtype=model.unet.dtype, # type: ignore - sampler_settings=sampler_settings, - ) + if scheduler == "sasolver": + new_scheduler = SASolverScheduler.from_config(config=configuration) # type: ignore + else: + new_scheduler = create_sampler( + alphas_cumprod=sched.alphas_cumprod, # type: ignore + denoiser_enable_quantization=config.api.kdiffusers_quantization, + sampler=scheduler, + sigma_type=sigma_type, + eta_noise_seed_delta=0, + sigma_always_discard_next_to_last=False, + sigma_use_old_karras_scheduler=False, + device=torch.device(config.api.device), # type: ignore + dtype=config.api.load_dtype, # type: ignore + sampler_settings=sampler_settings, + ) model.scheduler = new_scheduler # type: ignore return new_scheduler # type: ignore diff --git a/core/inference/utilities/unet_patches.py b/core/inference/utilities/unet_patches.py new file mode 100644 index 000000000..ab4a72bdd --- /dev/null +++ b/core/inference/utilities/unet_patches.py @@ -0,0 +1,665 @@ +from copy import deepcopy +from typing import Any, Dict, Optional, Tuple, Union, List + +from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D +from diffusers.models.unet_2d_condition import UNet2DConditionModel +from diffusers.utils.torch_utils import apply_freeu +import torch + +_dummy = None # here so I can import this + + +def _unet_new_forward( + self: UNet2DConditionModel, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int, List[torch.Tensor]], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + quick_replicate: bool = False, # whether to turn on deepcache + drop_encode_decode: bool = False, + replicate_prv_feature: Optional[List[torch.Tensor]] = None, + cache_layer_id: Optional[int] = None, + cache_block_id: Optional[int] = None, + order: Optional[int] = None, + return_dict: bool = True, +) -> Tuple: + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). 
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # 1 - default + # 2 - deepcache + # 3 - faster-diffusion + method = 1 + if quick_replicate: + method = 2 + elif drop_encode_decode: + method = 3 + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = ( + 1 - encoder_attention_mask.to(sample.dtype) + ) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config["center_input_sample"]: + sample = 2 * sample - 1.0 # type: ignore + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: # type: ignore + timesteps = timesteps[None].to(sample.device) # type: ignore + + if len(timesteps.shape) == 1: # type: ignore + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) # type: ignore + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if self.config["class_embed_type"] == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. 
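Stepping back from the embedding plumbing for a moment: the `quick_replicate` branch selected near the top of this forward implements a DeepCache-style protocol, where the first call runs the whole UNet and returns a cached up-block feature, and later calls feed that feature back to skip most of the down and mid blocks. A hedged sketch of a driving loop (the refresh interval, cache ids, and surrounding variables are illustrative, not taken from this diff):

```python
# Assumes `unet` has been patched by importing this module and that `latents`,
# `timesteps` and `embeds` come from the surrounding pipeline.
prv_features = None
cache_interval = 3  # illustrative: how often the cached feature is refreshed
for i, t in enumerate(timesteps):
    # run the full network on the first step and every `cache_interval` steps,
    # otherwise reuse the cached up-block feature
    full_pass = prv_features is None or i % cache_interval == 0
    noise_pred, prv_features = unet(
        latents,
        t,
        encoder_hidden_states=embeds,
        quick_replicate=True,
        replicate_prv_feature=None if full_pass else prv_features,
        cache_layer_id=0,
        cache_block_id=0,
    )
    # ...scheduler.step(noise_pred, t, latents) would update `latents` here...
```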
+ class_labels = class_labels.to(dtype=sample.dtype) # type: ignore + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config["class_embeddings_concat"]: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config["addition_embed_type"] == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config["addition_embed_type"] == "text_image": + # Kandinsky 2.1 - style + image_embs = added_cond_kwargs.get("image_embeds") # type: ignore + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) # type: ignore + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config["addition_embed_type"] == "text_time": + # SDXL - style + text_embeds = added_cond_kwargs.get("text_embeds") # type: ignore + if "time_ids" not in added_cond_kwargs: # type: ignore + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") # type: ignore + time_embeds = self.add_time_proj(time_ids.flatten()) # type: ignore + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) # type: ignore + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) # type: ignore + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config["addition_embed_type"] == "image": + # Kandinsky 2.2 - style + image_embs = added_cond_kwargs.get("image_embeds") # type: ignore + aug_emb = self.add_embedding(image_embs) + elif self.config["addition_embed_type"] == "image_hint": + # Kandinsky 2.2 - style + image_embs = added_cond_kwargs.get("image_embeds") # type: ignore + hint = added_cond_kwargs.get("hint") # type: ignore + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) # type: ignore + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if ( + self.encoder_hid_proj is not None + and self.config["encoder_hid_dim_type"] == "text_proj" + ): + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif ( + self.encoder_hid_proj is not None + and self.config["encoder_hid_dim_type"] == "text_image_proj" + ): + # Kadinsky 2.1 - style + image_embeds = added_cond_kwargs.get("image_embeds") # type: ignore + encoder_hidden_states = self.encoder_hid_proj( + encoder_hidden_states, image_embeds + ) + elif ( + self.encoder_hid_proj is not None + and self.config["encoder_hid_dim_type"] == "image_proj" + ): + # Kandinsky 2.2 - style + image_embeds = added_cond_kwargs.get("image_embeds") # type: ignore + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif ( + self.encoder_hid_proj is not None + and self.config["encoder_hid_dim_type"] == "ip_image_proj" + ): + image_embeds = added_cond_kwargs.get("image_embeds") # type: ignore + image_embeds = self.encoder_hid_proj(image_embeds).to( + encoder_hidden_states.dtype + ) + encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1) + + def downsample(downsample_block, additional_residuals: dict): + nonlocal sample, emb, encoder_hidden_states, attention_mask, cross_attention_kwargs, encoder_attention_mask, down_intrablock_additional_residuals, down_block_res_samples + if ( + hasattr(downsample_block, "has_cross_attention") + and downsample_block.has_cross_attention + ): + # For t2i-adapter CrossAttnDownBlock2D 
+ if is_adapter and len(down_intrablock_additional_residuals) > 0: # type: ignore + additional_residuals[ + "additional_residuals" + ] = down_intrablock_additional_residuals.pop( # type: ignore + 0 + ) + + sample, res_samples = downsample_block( + hidden_states=sample, # type: ignore + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, temb=emb, scale=1.0 + ) + if is_adapter and len(down_intrablock_additional_residuals) > 0: # type: ignore + sample += down_intrablock_additional_residuals.pop(0) # type: ignore + + down_block_res_samples += res_samples + + prv_f = replicate_prv_feature + needs_prv = prv_f is None and method == 2 + + def upsample(upsample_block, i, length, additional={}): + nonlocal self, cache_block_id, needs_prv, prv_f, cache_layer_id, down_block_res_samples, forward_upsample_size, sample, emb, encoder_hidden_states, cross_attention_kwargs, upsample_size, attention_mask, encoder_attention_mask + + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[length:] + down_block_res_samples = down_block_res_samples[:length] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if ( + hasattr(upsample_block, "has_cross_attention") + and upsample_block.has_cross_attention + ): + sample, current_record_f = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + needs_prv=needs_prv, + **additional, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + scale=1.0, + ) + current_record_f = None + if ( + needs_prv + and cache_layer_id is not None + and current_record_f is not None + and i == len(self.up_blocks) - cache_layer_id - 1 + ): + assert cache_block_id is not None + prv_f = current_record_f[-cache_block_id - 1] + + is_controlnet = ( + mid_block_additional_residual is not None + and down_block_additional_residuals is not None + ) + is_adapter = down_intrablock_additional_residuals is not None + + if method == 3: + assert order is not None + from core.config import config + + mod = config.api.drop_encode_decode + + # ipow = int(np.sqrt(9 + 8 * order)) + cond = order <= 5 or order % 5 == 0 + if isinstance(mod, int): + # First 5 steps always full cond, just to make sure samples aren't being wasted + cond = order <= 5 or order % mod == 0 + + if cond: + # 2. pre-process + sample = self.conv_in(sample) + + # 3. 
down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + downsample(downsample_block, {}) + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals # type: ignore + ): + down_block_res_sample = ( + down_block_res_sample + down_block_additional_residual + ) + new_down_block_res_samples = new_down_block_res_samples + ( + down_block_res_sample, + ) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if ( + hasattr(self.mid_block, "has_cross_attention") + and self.mid_block.has_cross_attention + ): + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 # type: ignore + and sample.shape == down_intrablock_additional_residuals[0].shape # type: ignore + ): + sample += down_intrablock_additional_residuals.pop(0) # type: ignore + + if is_controlnet: + sample = sample + mid_block_additional_residual # type: ignore + + # 4.5. save features + setattr(self, "skip_feature", deepcopy(down_block_res_samples)) + setattr(self, "toup_feature", sample.detach().clone()) + else: + down_block_res_samples = self.skip_feature + sample = self.toup_feature + + for i, upsample_block in enumerate(self.up_blocks): + upsample(upsample_block, i, -len(upsample_block.resnets)) + else: + # 2. pre-process + sample = self.conv_in(sample) + + # 2.5 GLIGEN position net + if ( + cross_attention_kwargs is not None + and cross_attention_kwargs.get("gligen", None) is not None + ): + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = { + "objs": self.position_net(**gligen_args) + } + + # 3. down + down_block_res_samples = (sample,) + if method == 1 or (method == 2 and replicate_prv_feature is None): + for downsample_block in self.down_blocks: + downsample(downsample_block, {}) + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals # type: ignore + ): + down_block_res_sample = ( + down_block_res_sample + down_block_additional_residual + ) + new_down_block_res_samples = new_down_block_res_samples + ( + down_block_res_sample, + ) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if ( + hasattr(self.mid_block, "has_cross_attention") + and self.mid_block.has_cross_attention + ): + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 # type: ignore + and sample.shape == down_intrablock_additional_residuals[0].shape # type: ignore + ): + sample += down_intrablock_additional_residuals.pop(0) # type: ignore + + if is_controlnet: + sample = sample + mid_block_additional_residual # type: ignore + + # 5. 
up + for i, upsample_block in enumerate(self.up_blocks): + upsample(upsample_block, i, -len(upsample_block.resnets)) + elif method == 2 and replicate_prv_feature is not None: + assert ( + cache_layer_id is not None + and cache_block_id is not None + and replicate_prv_feature is not None + ) + # Down + for i, downsample_block in enumerate(self.down_blocks): + if i > cache_layer_id: + break + downsample( + downsample_block, + { + "exist_block_number": cache_block_id + if i == cache_layer_id + else None + }, + ) + + # Skip mid_block + + # Up + sample = replicate_prv_feature # type: ignore + if cache_block_id == len(self.down_blocks[cache_layer_id].attentions): + cache_block_id = 0 + cache_layer_id += 1 + else: + cache_block_id += 1 + + for i, upsample_block in enumerate(self.up_blocks): + if i < len(self.up_blocks) - 1 - cache_layer_id: + continue + + if i == len(self.up_blocks) - 1 - cache_layer_id: + length = cache_block_id + 1 + else: + length = len(upsample_block.resnets) + + upsample( + upsample_block, + i, + -length, + { + "enter_block_number": cache_block_id + if i == len(self.up_blocks) - 1 - cache_layer_id + else None + }, + ) + prv_f = replicate_prv_feature + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) # type: ignore + sample = self.conv_out(sample) + + return ( + sample, + prv_f, + ) + + +# Changes: added enter_block_number +def _up_new_forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + enter_block_number: Optional[int] = None, + needs_prv: bool = False, +) -> Tuple[torch.FloatTensor, List]: + prv_f = [] + lora_scale = ( + cross_attention_kwargs.get("scale", 1.0) + if cross_attention_kwargs is not None + else 1.0 + ) + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): + if ( + enter_block_number is not None + and i < len(self.resnets) - enter_block_number - 1 + ): + continue + + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( # type: ignore + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + if needs_prv: + prv_f.append(hidden_states) + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) # type: ignore + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} + hidden_states = torch.utils.checkpoint.checkpoint( # type: ignore + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + 
cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) + + return hidden_states, prv_f + + +# Changes: added exist_block_number +def _down_new_forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + exist_block_number: Optional[int] = None, + additional_residuals: Optional[torch.FloatTensor] = None, +) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + + lora_scale = ( + cross_attention_kwargs.get("scale", 1.0) + if cross_attention_kwargs is not None + else 1.0 + ) + + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} + hidden_states = torch.utils.checkpoint.checkpoint( # type: ignore + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals # type: ignore + + output_states = output_states + (hidden_states,) + if ( + exist_block_number is not None + and len(output_states) == exist_block_number + 1 + ): + return hidden_states, output_states + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale=lora_scale) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +CrossAttnUpBlock2D.forward = _up_new_forward # type: ignore +CrossAttnDownBlock2D.forward = _down_new_forward # type: ignore +UNet2DConditionModel.forward = _unet_new_forward # type: ignore diff --git a/core/inference/utilities/vae.py b/core/inference/utilities/vae.py index ab6cf8b5c..2c8a5fd2a 100644 --- a/core/inference/utilities/vae.py +++ b/core/inference/utilities/vae.py @@ -1,24 +1,28 @@ 
-from typing import Callable, Optional +from contextlib import ExitStack +from typing import Callable, Optional, Union, List +from io import BytesIO import numpy as np import torch from PIL import Image +from core import shared from core.config import config +from core.optimizations import autocast, ensure_correct_device taesd_model = None def taesd( samples: torch.Tensor, height: Optional[int] = None, width: Optional[int] = None -) -> torch.Tensor: +) -> np.ndarray: global taesd_model if taesd_model is None: from diffusers.models.autoencoder_tiny import AutoencoderTiny model = "madebyollin/taesd" - if False: # TODO: if is_sdxl: + if shared.current_model == "SDXL": model = "madebyollin/taesdxl" taesd_model = AutoencoderTiny.from_pretrained( model, torch_dtype=torch.float16 @@ -34,19 +38,18 @@ def taesd( ) -def cheap_approximation(sample: torch.Tensor) -> Image.Image: +def cheap_approximation(sample: torch.Tensor) -> Union[BytesIO, Image.Image]: "Convert a tensor of latents to RGB" # Credit to Automatic111 stable-diffusion-webui # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2 - coeffs = [ [0.298, 0.207, 0.208], [0.187, 0.286, 0.173], [-0.158, 0.189, 0.264], [-0.184, -0.271, -0.473], ] - if False: # TODO: if is_sdxl: + if shared.current_model == "SDXL": coeffs = [ [0.3448, 0.4168, 0.4395], [-0.1953, -0.0290, 0.0250], @@ -54,7 +57,27 @@ def cheap_approximation(sample: torch.Tensor) -> Image.Image: [-0.3730, -0.2499, -0.2088], ] coeffs = torch.tensor(coeffs, dtype=torch.float32, device="cpu") - + if sample.dim() == 4: + decoded_rgb = torch.einsum( + "lfxy,lr -> frxy", sample.to(torch.float32).to("cpu"), coeffs + ) + decoded_rgb = torch.clamp((decoded_rgb + 1.0) / 2.0, min=0.0, max=1.0) + + decoded_rgb = 255.0 * np.moveaxis(decoded_rgb.cpu().numpy(), 1, -1) + decoded_rgb = decoded_rgb.astype(np.uint8) + + buffer = BytesIO() + images = [Image.fromarray(frame) for frame in decoded_rgb] + images[0].save( + buffer, + "gif", + save_all=True, + append_images=images[1:], + loop=0, + optimize=True, + subrectangles=True, + ) + return buffer decoded_rgb = torch.einsum( "lxy,lr -> rxy", sample.to(torch.float32).to("cpu"), coeffs ) @@ -67,16 +90,39 @@ def cheap_approximation(sample: torch.Tensor) -> Image.Image: def full_vae( samples: torch.Tensor, - overwrite: Callable[[torch.Tensor], torch.Tensor], + vae, height: Optional[int] = None, width: Optional[int] = None, -) -> torch.Tensor: - return decode_latents( - overwrite, - samples, - height or samples[0].shape[1] * 8, - width or samples[0].shape[2] * 8, - ) +) -> np.ndarray: + ensure_correct_device(vae) # type: ignore + + def decode(sample): + with ExitStack() as gs: + if vae.config["force_upcast"] or config.api.upcast_vae: + gs.enter_context(autocast(dtype=torch.float32)) + return vae.decode(sample, return_dict=False)[0] + + if samples.dim() == 5: + return torch.tensor( + np.array( # this is here since torch thinks itll be faster like this + [ + decode_latents( + decode, # type: ignore + samples[x].permute(1, 0, 2, 3), + height or samples[0].shape[1] * 8, + width or samples[0].shape[2] * 8, + ) + for x in range(samples.shape[0]) + ] + ) + ).numpy() + else: + return decode_latents( + decode, # type: ignore + samples, + height or samples[0].shape[1] * 8, + width or samples[0].shape[2] * 8, + ) def decode_latents( @@ -84,9 +130,10 @@ def decode_latents( latents: torch.Tensor, height: int, width: int, -) -> torch.Tensor: + scaling_factor: float = 0.18215, +) -> np.ndarray: "Decode latents" - latents = 1 / 0.18215 * 
latents + latents = 1 / scaling_factor * latents image = decode_lambda(latents) # type: ignore image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 @@ -95,17 +142,34 @@ def decode_latents( return img -def numpy_to_pil(images): +def numpy_to_pil(images: np.ndarray) -> List[Union[BytesIO, Image.Image]]: """ Convert a numpy image or a batch of images to a PIL image. """ - if images.ndim == 3: - images = images[None, ...] - images = (images * 255).round().astype("uint8") - if images.shape[-1] == 1: - # special case for grayscale (single channel) images - pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + pil_images: List[Union[BytesIO, Image.Image]] = [] + if images.ndim == 5: + for image in images: + frames_done: List[Image.Image] = [] + for frame in image: + frame: np.ndarray = (frame * 255).round().astype(np.uint8) + frames_done.append(Image.fromarray(frame)) + + buffer = BytesIO() + frames_done[0].save( + buffer, + "gif", + save_all=True, + append_images=frames_done[1:], + loop=0, + optimize=True, + subrectangles=True, + ) + + pil_images.append(buffer) else: + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype(np.uint8) pil_images = [Image.fromarray(image) for image in images] return pil_images diff --git a/core/inference_callbacks.py b/core/inference_callbacks.py index bf73938db..e783ddb65 100644 --- a/core/inference_callbacks.py +++ b/core/inference_callbacks.py @@ -1,5 +1,6 @@ import time -from typing import List +from typing import Union, List +from io import BytesIO import torch from PIL import Image @@ -31,14 +32,14 @@ def callback(step: int, _timestep: int, tensor: torch.Tensor): (time.time() - last_image_time > config.api.live_preview_delay) ) - images: List[Image.Image] = [] + images: List[Union[BytesIO, Image.Image]] = [] if send_image: last_image_time = time.time() if config.api.live_preview_method == "approximation": for t in range(tensor.shape[0]): images.append(cheap_approximation(tensor[t])) else: - for img in numpy_to_pil(taesd(tensor)): + for img in numpy_to_pil(taesd(tensor)): # type: ignore images.append(img) websocket_manager.broadcast_sync( @@ -53,7 +54,7 @@ def callback(step: int, _timestep: int, tensor: torch.Tensor): "image": convert_images_to_base64_grid( images, quality=60, image_format="webp" ) - if send_image + if len(images) > 0 else "", }, ) diff --git a/core/install_requirements.py b/core/install_requirements.py index a918e254e..ad9979efb 100644 --- a/core/install_requirements.py +++ b/core/install_requirements.py @@ -14,8 +14,9 @@ "opencv-contrib-python-headless": "cv2", "fastapi-analytics": "api_analytics", "cuda-python": "cuda", - "open_clip_torch": "open_clip", + "open-clip-torch": "open_clip", "python-multipart": "multipart", + "invisible-watermark": "imwatermark", "discord.py": "discord", "HyperTile": "hyper-tile", "stable-fast": "sfast", @@ -159,11 +160,26 @@ class PytorchDistribution: "-m", "pip", "install", - "torch==1.13.0a0", - "torchvision==0.14.1a0", - "intel_extension_for_pytorch==1.13.120+xpu", - "-f", - "https://developer.intel.com/ipex-whl-stable-xpu", + "torch==2.0.0a0+gitc6a572f", + "torchvision==0.14.1a0+5e8e2f1", + "intel-extension-for-pytorch==2.0.110+gitba7f6c1", + "--extra-index-url", + "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/", + ], + ), + PytorchDistribution( + windows_supported=True, + name="intel", + check_command=["test", "-f", 
'"/etc/OpenCL/vendors/intel.icd"'], + success_message="Intel check success, assuming user has an Intel (i)GPU", + install_command=[ + sys.executable, + "-m", + "pip", + "install", + "https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl", + "https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl", + "https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl", ], ), PytorchDistribution( @@ -218,7 +234,7 @@ class PytorchDistribution: ] -def install_deps(force_distribution: int = -1): +def install_deps(force_distribution: Union[int, str] = -1): "Install necessary requirements for inference" # Install pytorch @@ -234,17 +250,30 @@ def install_deps(force_distribution: int = -1): x for x in _pytorch_distributions if x.name == force_distribution.lower() + and (x.windows_supported if platform.system() == "Windows" else True) ][0] logger.info("Installing PyTorch") if platform.system() == "Darwin": subprocess.check_call( - [sys.executable, "-m", "pip", "install", "torch==2.0.0", "torchvision"] + [sys.executable, "-m", "pip", "install", "torch==2.1.0", "torchvision"] ) else: - for c in _pytorch_distributions: - if ( - (c.windows_supported if platform.system() == "Windows" else True) - and ( + if forced_distribution is not None: + # User forced a specific distribution + + logger.info(forced_distribution.success_message) + if isinstance(forced_distribution.install_command[0], list): + for cmd in forced_distribution.install_command: + subprocess.check_call(cmd) + else: + subprocess.check_call(forced_distribution.install_command) # type: ignore + else: + # Automatically detect pytorch distribution + + for c in _pytorch_distributions: + if ( + c.windows_supported if platform.system() == "Windows" else True + ) and ( ( subprocess.run( c.check_command, @@ -254,15 +283,14 @@ def install_deps(force_distribution: int = -1): ).returncode == 0 ) - ) - ) or c == forced_distribution: - logger.info(c.success_message) - if isinstance(c.install_command[0], list): - for cmd in c.install_command: - subprocess.check_call(cmd) - else: - subprocess.check_call(c.install_command) # type: ignore - break + ): + logger.info(c.success_message) + if isinstance(c.install_command[0], list): + for cmd in c.install_command: + subprocess.check_call(cmd) + else: + subprocess.check_call(c.install_command) # type: ignore + break # Install other requirements install_requirements("requirements/pytorch.txt") @@ -409,8 +437,13 @@ def check_valid_python_version(): print("Please consider switching to an older release to use volta!") raise RuntimeError("Unsupported Python version") elif minor < 9: - print("The python release you are currently using is older than our") - print("official supported version! Please consider updating to Python 3.11!") + print("--------------------------------------------------------") + print("| The python release you are currently using is older |") + print("| than our official supported version! Please consider |") + print("| updating to Python 3.11! |") + print("| |") + print("| Issues will most likely be IGNORED! 
|") + print("--------------------------------------------------------") def is_up_to_date(): diff --git a/core/interrogation/clip.py b/core/interrogation/clip.py index 371891280..7fd1a86f3 100644 --- a/core/interrogation/clip.py +++ b/core/interrogation/clip.py @@ -45,7 +45,7 @@ def __init__( self.caption_processor: AutoProcessor self.clip_model = None self.clip_preprocess = None - self.dtype: torch.dtype = config.api.dtype + self.dtype: torch.dtype = config.api.load_dtype if autoload: self.load() diff --git a/core/interrogation/deepdanbooru.py b/core/interrogation/deepdanbooru.py index e7df3f7e3..c4ac46490 100644 --- a/core/interrogation/deepdanbooru.py +++ b/core/interrogation/deepdanbooru.py @@ -39,7 +39,7 @@ def __init__( self.tags = [] self.model: DeepDanbooruModel self.model_location = Path("data") / "models" / "deepdanbooru.pt" - self.dtype = torch.quint8 if quantized else config.api.dtype + self.dtype = torch.quint8 if quantized else config.api.load_dtype self.device: torch.device if isinstance(self.device, str): self.device = torch.device(device) diff --git a/core/interrogation/flamingo.py b/core/interrogation/flamingo.py index 40a432043..7fc1f2749 100644 --- a/core/interrogation/flamingo.py +++ b/core/interrogation/flamingo.py @@ -17,7 +17,7 @@ def __init__(self, device: Union[str, torch.device] = "cuda"): super().__init__(device) self.device = device - self.dtype = config.api.dtype + self.dtype = config.api.load_dtype self.model: FlamingoModel self.processor: FlamingoProcessor @@ -25,7 +25,7 @@ def load(self): model = FlamingoModel.from_pretrained(config.interrogator.flamingo_model) assert isinstance(model, FlamingoModel) self.model = model - self.model.to(self.device, dtype=self.dtype) + self.model.to(self.device, dtype=self.dtype) # type: ignore self.model.eval() self.processor = FlamingoProcessor(self.model.config) diff --git a/core/optimizations/__init__.py b/core/optimizations/__init__.py index 8b6f062b2..334c21be5 100644 --- a/core/optimizations/__init__.py +++ b/core/optimizations/__init__.py @@ -1,13 +1,19 @@ from .autocast_utils import autocast, without_autocast from .context_manager import inference_context, InferenceContext -from .pytorch_optimizations import optimize_model +from .pytorch_optimizations import optimize_model, optimize_vae +from .upcast import upcast_vae +from .offload import ensure_correct_device, unload_all from .hypertile import is_hypertile_available, hypertile from .compile.stable_fast import compile as compile_sfast __all__ = [ "optimize_model", + "optimize_vae", "without_autocast", "autocast", + "upcast_vae", + "ensure_correct_device", + "unload_all", "inference_context", "InferenceContext", "is_hypertile_available", diff --git a/core/optimizations/attn/__init__.py b/core/optimizations/attn/__init__.py index 1001140e9..a9695fbfc 100644 --- a/core/optimizations/attn/__init__.py +++ b/core/optimizations/attn/__init__.py @@ -1,7 +1,11 @@ import logging import torch -from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0 +from diffusers.models.attention_processor import ( + Attention, + AttnProcessor, + AttnProcessor2_0, +) from diffusers.utils.import_utils import is_xformers_available from packaging import version @@ -39,15 +43,19 @@ def _xf(pipe): ) is None, # --- + "flashattention": lambda p: apply_flash_attention(p.unet) is None, + # --- "multihead": lambda p: apply_multihead_attention(p.unet) is None, # --- "flash-attn": lambda p: apply_flash_attention(p.unet) is None, } -def set_attention_processor(pipe): +def 
set_attention_processor(pipe, fused: bool = True, silent: bool = False): "Set attention processor to the first one available/the one set in the config" + logger.disabled = silent + res = False try: curr_processor = list(ATTENTION_PROCESSORS.keys()).index( @@ -62,11 +70,27 @@ def set_attention_processor(pipe): logger.info( f"Optimization: Enabled {attention_processors_list[curr_processor][0]} attention" ) + if fused: + b = True + for attn_processor in pipe.unet.attn_processors.values(): + if "Added" in attn_processor.__class__.__name__: + b = False + if b: + n = 0 + for module in pipe.unet.modules(): + if isinstance(module, Attention): + if hasattr(module, "fuse_projections"): + n += 1 + module.fuse_projections(fuse=True) + if n != 0: + logger.info(f"Optimization: Fused {n} attention modules") + curr_processor = (curr_processor + 1) % len(attention_processors_list) __all__ = [ "apply_subquadratic_attention", "apply_multihead_attention", + "apply_flash_attention", "set_attention_processor", ] diff --git a/core/optimizations/autocast_utils.py b/core/optimizations/autocast_utils.py index 127afc049..32d25204c 100644 --- a/core/optimizations/autocast_utils.py +++ b/core/optimizations/autocast_utils.py @@ -42,7 +42,7 @@ def autocast( global _initialized_directml - if dtype == torch.float32 or disable: + if disable: return contextlib.nullcontext() if "privateuseone" in config.api.device: if not _initialized_directml: diff --git a/core/optimizations/compile/trace_utils.py b/core/optimizations/compile/trace_utils.py index 5515250ce..c57c29f79 100644 --- a/core/optimizations/compile/trace_utils.py +++ b/core/optimizations/compile/trace_utils.py @@ -1,6 +1,7 @@ import logging import warnings from typing import Tuple +from functools import partial import torch from diffusers.models.unet_2d_condition import UNet2DConditionOutput @@ -42,32 +43,76 @@ def forward( def warmup( - model: torch.nn.Module, amount: int, dtype: torch.dtype, device: torch.device + model: torch.nn.Module, + amount: int, + dtype: torch.dtype, + device: torch.device, + silent: bool = False, ) -> None: "Warms up model with amount generated sample inputs." 
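+    # Run `amount` forward passes with generated sample inputs under
+    # torch.inference_mode(); the tqdm bar is hidden when `silent` is True.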
model.eval() with torch.inference_mode(): - for _ in tqdm(range(amount), desc="Warming up"): + for _ in tqdm(range(amount), disable=silent, desc="Warming up"): model(*generate_inputs(dtype, device)) +def trace_ipex( + model: torch.nn.Module, + dtype: torch.dtype, + device: torch.device, + cpu: dict, + silent: bool = False, +) -> Tuple[torch.nn.Module, bool]: + from core.inference.functions import is_ipex_available + + logger.disabled = silent + + if is_ipex_available(): + import intel_extension_for_pytorch as ipex + + logger.info("Optimization: Running IPEX optimizations") + + if config.api.channels_last: + ipex.enable_auto_channels_last() + else: + ipex.disable_auto_channels_last() + ipex.enable_onednn_fusion(True) + ipex.set_fp32_math_mode( + ipex.FP32MathMode.BF32 + if "AMD" not in cpu["VendorId"] + else ipex.FP32MathMode.FP32 + ) + model = ipex.optimize( + model, # type: ignore + dtype=dtype, + auto_kernel_selection=True, + sample_input=generate_inputs(dtype, device), + concat_linear=True, + graph_mode=True, + ) + return model, True + else: + return model, False + + def trace_model( model: torch.nn.Module, dtype: torch.dtype, device: torch.device, iterations: int = 25, - ipex: bool = False, + silent: bool = False, ) -> torch.nn.Module: "Traces the model for inference" + logger.disabled = silent + og = model - from functools import partial if model.forward.__code__.co_argcount > 3: model.forward = partial(model.forward, return_dict=False) - warmup(model, iterations, dtype, device) - if config.api.channels_last and not ipex: + warmup(model, iterations, dtype, device, silent=silent) + if config.api.channels_last: model.to(memory_format=torch.channels_last) # type: ignore logger.debug("Starting trace") with warnings.catch_warnings(): @@ -77,7 +122,7 @@ def trace_model( model = torch.jit.trace(model, generate_inputs(dtype, device), check_trace=False) # type: ignore model = torch.jit.freeze(model) # type: ignore logger.debug("Tracing done") - warmup(model, iterations // 5, dtype, device) + warmup(model, iterations // 5, dtype, device, silent=silent) model.in_channels = og.in_channels model.dtype = og.dtype diff --git a/core/optimizations/context_manager.py b/core/optimizations/context_manager.py index e02143035..878610ed5 100644 --- a/core/optimizations/context_manager.py +++ b/core/optimizations/context_manager.py @@ -1,25 +1,116 @@ from contextlib import ExitStack +from types import TracebackType +from typing import List, Optional, TypeVar import torch +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL +from diffusers.models.unet_2d_condition import UNet2DConditionModel from core.config import config +from core.flags import AnimateDiffFlag, Flag from .autocast_utils import autocast from .hypertile import is_hypertile_available, hypertile +T = TypeVar("T", bound=type[Flag]) + + class InferenceContext(ExitStack): """inference context""" + old_device: torch.device + old_dtype: torch.dtype + old_unet: torch.nn.Module | None = None unet: torch.nn.Module vae: torch.nn.Module + flags: List[Optional[Flag]] = [] + components: dict = {} + + def to(self, device: str, dtype: torch.dtype, memory_format): + self.vae.to(device=device, dtype=dtype, memory_format=memory_format) # type: ignore + self.unet.to(device=device, dtype=dtype, memory_format=memory_format) # type: ignore + + def enable_freeu(self, s1, s2, b1, b2): + if hasattr(self.unet, "enable_freeu"): + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + def get_flag(self, _type: T) -> Optional[T]: + try: + return [x for x 
in self.flags if isinstance(x, _type)].pop() # type: ignore + except IndexError: + return None + def __exit__( + self, + __exc_type: type[BaseException] | None, + __exc_value: BaseException | None, + __traceback: TracebackType | None, + ) -> bool: + ret = super().__exit__(__exc_type, __exc_value, __traceback) + if self.old_unet is not None: + self.old_unet.to(device=self.old_device, dtype=self.old_dtype) + return ret -def inference_context(unet, vae, height, width) -> InferenceContext: + def enable_xformers_memory_efficient_attention(self): + self.unet.enable_xformers_memory_efficient_attention() + + +def inference_context( + unet: UNet2DConditionModel, + vae: AutoencoderKL, + height: int, + width: int, + flags: List[Optional[Flag]] = [], +) -> InferenceContext: "Helper function for centralizing context management" s = InferenceContext() s.unet = unet s.vae = vae - s.enter_context(autocast(unet.dtype, disable=config.api.autocast)) + s.components = {"unet": unet, "vae": vae} + s.enter_context( + autocast( + config.api.load_dtype, + disable=config.api.autocast and not unet.force_autocast, + ) + ) + s.flags = flags + + s.old_device = s.unet.device + s.old_dtype = s.unet.dtype + + animatediff = s.get_flag(AnimateDiffFlag) + if animatediff is not None: + from core.inference.utilities.animatediff import UNet3DConditionModel + from .pytorch_optimizations import optimize_model + from core.shared_dependent import gpu + + s.unet.to("cpu") + s.old_unet = s.unet + s.unet = UNet3DConditionModel.from_pretrained_2d( # type: ignore + s.unet, animatediff.motion_model # type: ignore + ) + + if animatediff.use_pia: + s.unet = s.unet.convert_to_pia(animatediff.pia_checkpont) + + if config.api.clear_memory_policy == "always": + gpu.memory_cleanup() + + optimize_model(s, config.api.device, silent=True) # type: ignore + + if animatediff.chunk_feed_forward != -1: + # TODO: do auto batch calculation + # for now, "auto" is 1. + batch_size = ( + 1 + if animatediff.chunk_feed_size == "auto" + else animatediff.chunk_feed_size + ) + s.unet.enable_forward_chunking( + chunk_size=batch_size, dim=animatediff.chunk_feed_forward + ) + + s.components.update({"unet": s.unet}) if is_hypertile_available() and config.api.hypertile: s.enter_context(hypertile(unet, height, width)) if config.api.torch_compile: diff --git a/core/optimizations/dtype.py b/core/optimizations/dtype.py new file mode 100644 index 000000000..b4ced4a8c --- /dev/null +++ b/core/optimizations/dtype.py @@ -0,0 +1,81 @@ +import logging +from typing import Union + +import torch +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( + StableDiffusionPipeline, +) +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import ( + StableDiffusionXLPipeline, +) + +from core.config import config + +try: + force_autocast = [torch.float8_e4m3fn, torch.float8_e5m2] +except AttributeError: + force_autocast = [] + +logger = logging.getLogger(__name__) + + +def cast( + pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline], + device: str, + dtype: torch.dtype, + offload: bool, + silent: bool = False, +): + logger.disabled = silent + + # Change the order of the channels to be more efficient for the GPU + # DirectML only supports contiguous memory format + memory_format = torch.preserve_format + if config.api.channels_last: + if "privateuseone" in device: + logger.warn( + "Optimization: Skipping channels_last, since DirectML doesn't support it." 
+ ) + else: + memory_format = torch.channels_last + if hasattr(pipe, "unet"): + pipe.unet.to(memory_format=memory_format) + if hasattr(pipe, "vae"): + pipe.vae.to(memory_format=memory_format) + logger.info("Optimization: Enabled channels_last memory format") + + pipe.unet.force_autocast = dtype in force_autocast # type: ignore + if pipe.unet.force_autocast: + for b in [x for x in pipe.components.values() if hasattr(x, "modules")]: # type: ignore + if "CLIP" in b.__class__.__name__: + b.to(device=None if offload else device, dtype=config.api.load_dtype) + else: + for module in b.modules(): + if any( + [ + x + for x in ["Conv", "Linear"] # 'cause LoRACompatibleConv + if x in module.__class__.__name__ + ] + ): + if hasattr(module, "fp16_weight"): + del module.fp16_weight + if config.api.cache_fp16_weight: + module.fp16_weight = module.weight.clone().half() + module.to( + device=None if offload else device, + dtype=dtype, + ) + else: + module.to( + device=None if offload else device, + dtype=config.api.load_dtype, + ) + if not config.api.autocast: + logger.info("Optimization: Forcing autocast on due to float8 weights.") + else: + pipe.to( + device=None if offload else device, dtype=dtype, memory_format=memory_format + ) + + return pipe diff --git a/core/optimizations/offload.py b/core/optimizations/offload.py new file mode 100644 index 000000000..7cd26e93e --- /dev/null +++ b/core/optimizations/offload.py @@ -0,0 +1,51 @@ +# pylint: disable=global-statement + +from typing import Optional +import logging + +from accelerate import cpu_offload +import torch + +from core.config import config + +logger = logging.getLogger(__name__) +_module: torch.nn.Module = None # type: ignore + + +def unload_all(): + global _module + if _module is not None: + _module.cpu() + _module = None # type: ignore + + +def ensure_correct_device(module: torch.nn.Module): + global _module + if _module is not None: + if module.__class__.__name__ == _module.__class__.__name__: + return + logger.debug(f"Transferring {_module.__class__.__name__} to cpu.") + _module.cpu() + _module = None # type: ignore + if hasattr(module, "v_offload_device"): + device = getattr(module, "v_offload_device", config.api.device) + + logger.debug(f"Transferring {module.__class__.__name__} to {str(device)}.") + module.to(device=torch.device(device)) + _module = module + else: + logger.debug(f"Don't need to do anything with {module.__class__.__name__}.") + + +def set_offload( + module: torch.nn.Module, device: torch.device, offload_type: Optional[str] = None +): + offload = offload_type or config.api.offload + if offload == "module": + class_name = module.__class__.__name__ + if "CLIP" not in class_name and "Autoencoder" not in class_name: + return cpu_offload( + module, device, offload_buffers=len(module._parameters) > 0 + ) + setattr(module, "v_offload_device", device) + return module diff --git a/core/optimizations/pytorch_optimizations.py b/core/optimizations/pytorch_optimizations.py index 76a9341f4..e08f96b10 100644 --- a/core/optimizations/pytorch_optimizations.py +++ b/core/optimizations/pytorch_optimizations.py @@ -1,18 +1,22 @@ import logging -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( StableDiffusionPipeline, ) -from diffusers.utils.import_utils import is_accelerate_available +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl 
import ( + StableDiffusionXLPipeline, +) from core.config import config -from core.files import get_full_model_path from .attn import set_attention_processor -from .compile.trace_utils import generate_inputs, trace_model +from .compile.trace_utils import trace_ipex, trace_model +from .dtype import cast +from .offload import set_offload +from .upcast import upcast_vae logger = logging.getLogger(__name__) @@ -21,13 +25,14 @@ def optimize_model( - pipe: StableDiffusionPipeline, + pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline], device, is_for_aitemplate: bool = False, + silent: bool = False, ) -> None: "Optimize the model for inference." - from core.inference.functions import is_ipex_available + logger.disabled = silent # Tuple[Supported, Enabled by default, Enabled] hardware_scheduling = experimental_check_hardware_scheduling() @@ -48,26 +53,23 @@ def optimize_model( ) offload = ( - config.api.offload - if (is_pytorch_pipe(pipe) and not is_for_aitemplate) - else None + config.api.offload != "disabled" + and is_pytorch_pipe(pipe) + and not is_for_aitemplate + ) + can_offload = ( + any(map(lambda x: x not in config.api.device, ["cpu", "vulkan", "mps"])) + and offload ) - can_offload = any( - map(lambda x: x not in config.api.device, ["cpu", "vulkan", "mps"]) - ) and (offload != "disabled" and offload is not None) - # Took me an hour to understand why CPU stopped working... - # Turns out AMD just lacks support for BF16... - # Not mad, not mad at all... to be fair, I'm just disappointed - if not can_offload and not is_for_aitemplate: - pipe.to(device, torch_dtype=config.api.dtype) + pipe = cast(pipe, device, config.api.dtype, can_offload, silent=silent) if "cuda" in config.api.device and not is_for_aitemplate: supports_tf = supports_tf32(device) if config.api.reduced_precision: if supports_tf: - logger.info("Optimization: Enabled all reduced precision operations") torch.set_float32_matmul_precision("medium") + logger.info("Optimization: Enabled all reduced precision operations") else: logger.warning( "Optimization: Device capability is not higher than 8.0, skipping most of reduction" @@ -89,6 +91,9 @@ def optimize_model( logger.info("Optimization: CUDNN benchmark enabled") torch.backends.cudnn.benchmark = config.api.cudnn_benchmark # type: ignore + if is_pytorch_pipe(pipe): + pipe.vae = optimize_vae(pipe.vae, silent=silent) + # Attention slicing that should save VRAM (but is slower) slicing = config.api.attention_slicing if slicing != "disabled" and is_pytorch_pipe(pipe) and not is_for_aitemplate: @@ -99,79 +104,27 @@ def optimize_model( pipe.enable_attention_slicing(slicing) logger.info(f"Optimization: Enabled attention slicing ({slicing})") - # Change the order of the channels to be more efficient for the GPU - # DirectML only supports contiguous memory format - # Disable for IPEX as well, they don't like torch's way of setting memory format - if ( - config.api.channels_last - and "privateuseone" not in config.api.device - and (not is_ipex_available() and "cpu" not in config.api.device) - and not is_for_aitemplate - ): - pipe.unet.to(memory_format=torch.channels_last) # type: ignore - pipe.vae.to(memory_format=torch.channels_last) # type: ignore - logger.info("Optimization: Enabled channels_last memory format") - # xFormers and SPDA if not is_for_aitemplate: - set_attention_processor(pipe) + set_attention_processor(pipe, silent=silent) if config.api.autocast: logger.info("Optimization: Enabled autocast") if can_offload: - if not is_accelerate_available(): - logger.warning( - 
"Optimization: Offload is not available, because accelerate is not installed" - ) - else: - if offload == "model": - # Offload to CPU - from accelerate import cpu_offload_with_hook - - if "cuda" in config.api.device: - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - hook = None - - for cpu_offloaded_model in [ - pipe.text_encoder, - pipe.unet, - pipe.vae, - ]: - _, hook = cpu_offload_with_hook( - cpu_offloaded_model, device, prev_module_hook=hook - ) - pipe.final_offload_hook = hook - setattr(pipe.vae, "main_device", True) - setattr(pipe.unet, "main_device", True) - logger.info("Optimization: Offloaded model parts to CPU.") - - elif offload == "module": - # Enable sequential offload - from accelerate import cpu_offload, disk_offload - - for m in [ - pipe.vae, - pipe.unet, - ]: - if USE_DISK_OFFLOAD: - # If USE_DISK_OFFLOAD toggle set (idk why anyone would do this, but it's nice to support stuff - # like this in case anyone wants to try running this on fuck knows what) - # then offload to disk. - disk_offload( - m, - str( - get_full_model_path("offload-dir", model_folder="temp") - / m.__name__ - ), - device, - offload_buffers=True, - ) - else: - cpu_offload(m, device, offload_buffers=True) - - logger.info("Optimization: Enabled sequential offload") + # Offload to CPU + + for model_name in [ + "text_encoder", + "text_encoder_2", + "unet", + "vae", + ]: + cpu_offloaded_model = getattr(pipe, model_name, None) + if cpu_offloaded_model is not None: + cpu_offloaded_model = set_offload(cpu_offloaded_model, device) + setattr(pipe, model_name, cpu_offloaded_model) + logger.info("Optimization: Offloaded model parts to CPU.") if config.api.free_u: pipe.enable_freeu( @@ -181,14 +134,6 @@ def optimize_model( b2=config.api.free_u_b2, ) - if config.api.vae_slicing: - pipe.enable_vae_slicing() - logger.info("Optimization: Enabled VAE slicing") - - if config.api.vae_tiling: - pipe.enable_vae_tiling() - logger.info("Optimization: Enabled VAE tiling") - if config.api.use_tomesd and not is_for_aitemplate: try: import tomesd @@ -212,45 +157,14 @@ def optimize_model( f"Running on an {cpu['VendorId']} device. Used threads: {torch.get_num_threads()}-{torch.get_num_interop_threads()} / {cpu['num_virtual_cores']}" ) - if is_ipex_available(): - import intel_extension_for_pytorch as ipex - - logger.info("Optimization: Running IPEX optimizations") - - if config.api.channels_last: - ipex.enable_auto_channels_last() - else: - ipex.disable_auto_channels_last() - ipex.enable_onednn_fusion(True) - ipex.set_fp32_math_mode( - ipex.FP32MathMode.BF32 - if "AMD" not in cpu["VendorId"] - else ipex.FP32MathMode.FP32 - ) - pipe.unet = ipex.optimize( - pipe.unet, # type: ignore - dtype=config.api.dtype, - auto_kernel_selection=True, - sample_input=generate_inputs(config.api.dtype, device), - ) - ipexed = True + pipe.unet, ipexed = trace_ipex( + pipe.unet, config.api.load_dtype, device, cpu, silent=silent + ) if config.api.trace_model and not ipexed and not is_for_aitemplate: logger.info("Optimization: Tracing model.") logger.warning("This will break controlnet and loras!") - if config.api.attention_processor == "xformers": - logger.warning( - "Skipping tracing because xformers used for attention processor. Please change to SDPA to enable tracing." 
- ) - else: - pipe.unet = trace_model(pipe.unet, config.api.dtype, device) # type: ignore - elif is_ipex_available() and config.api.trace_model and not is_for_aitemplate: - logger.warning( - "Skipping tracing because IPEX optimizations have already been done" - ) - logger.warning( - "This is a temporary measure, tracing will work with IPEX-enabled devices later on" - ) + pipe.unet = trace_model(pipe.unet, config.api.load_dtype, device, silent=silent) # type: ignore if config.api.torch_compile and not is_for_aitemplate: if config.api.attention_processor == "xformers": @@ -267,13 +181,6 @@ def optimize_model( "mode": config.api.torch_compile_mode, }, ) - # Wrong place! - # pipe.unet = torch.compile( - # pipe.unet, - # fullgraph=config.api.torch_compile_fullgraph, - # dynamic=config.api.torch_compile_dynamic, - # mode=config.api.torch_compile_mode, - # ) def supports_tf32(device: Optional[torch.device] = None) -> bool: @@ -303,4 +210,20 @@ def experimental_check_hardware_scheduling() -> Tuple[int, int, int]: def is_pytorch_pipe(pipe): "Checks if the pipe is a pytorch pipe" - return issubclass(pipe.__class__, (DiffusionPipeline)) + from .context_manager import InferenceContext + + return issubclass(pipe.__class__, (DiffusionPipeline, InferenceContext)) + + +def optimize_vae(vae, silent: bool = False): + "Optimize a VAE according to config defined in data/settings.json" + vae = upcast_vae(vae, silent=silent) + + if hasattr(vae, "enable_slicing") and config.api.vae_slicing: + vae.enable_slicing() + logger.info("Optimization: Enabled VAE slicing") + + if config.api.vae_tiling and hasattr(vae, "enable_tiling"): + vae.enable_tiling() + logger.info("Optimization: Enabled VAE tiling") + return vae diff --git a/core/optimizations/sdxl_unet.py b/core/optimizations/sdxl_unet.py new file mode 100644 index 000000000..1b056648d --- /dev/null +++ b/core/optimizations/sdxl_unet.py @@ -0,0 +1,590 @@ +# Obviously modified from the original source code +# https://github.com/huggingface/diffusers +# So has APACHE 2.0 license + +# Author : Simo Ryu + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +import inspect + +from collections import namedtuple + +# SDXL + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int = 320): + super().__init__() + self.num_channels = num_channels + + def forward(self, timesteps): + half_dim = self.num_channels // 2 + exponent = -math.log(10000) * torch.arange( + half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - 0.0) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + sin_emb = torch.sin(emb) + cos_emb = torch.cos(emb) + emb = torch.cat([cos_emb, sin_emb], dim=-1) + + return emb + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_features, out_features): + super(TimestepEmbedding, self).__init__() + self.linear_1 = nn.Linear(in_features, out_features, bias=True) + self.act = nn.SiLU() + self.linear_2 = nn.Linear(out_features, out_features, bias=True) + + def forward(self, sample): + sample = self.linear_1(sample) + sample = self.act(sample) + sample = self.linear_2(sample) + + return sample + + +class ResnetBlock2D(nn.Module): + def __init__(self, in_channels, out_channels, conv_shortcut=True): + super(ResnetBlock2D, self).__init__() + self.norm1 = nn.GroupNorm(32, in_channels, eps=1e-05, affine=True) + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.time_emb_proj = nn.Linear(1280, 
out_channels, bias=True) + self.norm2 = nn.GroupNorm(32, out_channels, eps=1e-05, affine=True) + self.dropout = nn.Dropout(p=0.0, inplace=False) + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.nonlinearity = nn.SiLU() + self.conv_shortcut = None + if conv_shortcut: + self.conv_shortcut = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1 + ) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + temb = self.nonlinearity(temb) + temb = self.time_emb_proj(temb)[:, :, None, None] + hidden_states = hidden_states + temb + hidden_states = self.norm2(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +class Attention(nn.Module): + def __init__( + self, inner_dim, cross_attention_dim=None, num_heads=None, dropout=0.0 + ): + super(Attention, self).__init__() + if num_heads is None: + self.head_dim = 64 + self.num_heads = inner_dim // self.head_dim + else: + self.num_heads = num_heads + self.head_dim = inner_dim // num_heads + + self.scale = self.head_dim**-0.5 + if cross_attention_dim is None: + cross_attention_dim = inner_dim + self.to_q = nn.Linear(inner_dim, inner_dim, bias=False) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) + + self.to_out = nn.ModuleList( + [nn.Linear(inner_dim, inner_dim), nn.Dropout(dropout, inplace=False)] + ) + + def forward(self, hidden_states, encoder_hidden_states=None): + q = self.to_q(hidden_states) + k = ( + self.to_k(encoder_hidden_states) + if encoder_hidden_states is not None + else self.to_k(hidden_states) + ) + v = ( + self.to_v(encoder_hidden_states) + if encoder_hidden_states is not None + else self.to_v(hidden_states) + ) + b, t, c = q.size() + + q = q.view(q.size(0), q.size(1), self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(k.size(0), k.size(1), self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(v.size(0), v.size(1), self.num_heads, self.head_dim).transpose(1, 2) + + scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale + attn_weights = torch.softmax(scores, dim=-1) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(1, 2).contiguous().view(b, t, c) + + for layer in self.to_out: + attn_output = layer(attn_output) + + return attn_output + + +class GEGLU(nn.Module): + def __init__(self, in_features, out_features): + super(GEGLU, self).__init__() + self.proj = nn.Linear(in_features, out_features * 2, bias=True) + + def forward(self, x): + x_proj = self.proj(x) + x1, x2 = x_proj.chunk(2, dim=-1) + return x1 * torch.nn.functional.gelu(x2) + + +class FeedForward(nn.Module): + def __init__(self, in_features, out_features): + super(FeedForward, self).__init__() + + self.net = nn.ModuleList( + [ + GEGLU(in_features, out_features * 4), + nn.Dropout(p=0.0, inplace=False), + nn.Linear(out_features * 4, out_features, bias=True), + ] + ) + + def forward(self, x): + for layer in self.net: + x = layer(x) + return x + + +class BasicTransformerBlock(nn.Module): + def __init__(self, hidden_size): + super(BasicTransformerBlock, 
self).__init__() + self.norm1 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True) + self.attn1 = Attention(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True) + self.attn2 = Attention(hidden_size, 2048) + self.norm3 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True) + self.ff = FeedForward(hidden_size, hidden_size) + + def forward(self, x, encoder_hidden_states=None): + residual = x + + x = self.norm1(x) + x = self.attn1(x) + x = x + residual + + residual = x + + x = self.norm2(x) + if encoder_hidden_states is not None: + x = self.attn2(x, encoder_hidden_states) + else: + x = self.attn2(x) + x = x + residual + + residual = x + + x = self.norm3(x) + x = self.ff(x) + x = x + residual + return x + + +class Transformer2DModel(nn.Module): + def __init__(self, in_channels, out_channels, n_layers): + super(Transformer2DModel, self).__init__() + self.norm = nn.GroupNorm(32, in_channels, eps=1e-06, affine=True) + self.proj_in = nn.Linear(in_channels, out_channels, bias=True) + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(out_channels) for _ in range(n_layers)] + ) + self.proj_out = nn.Linear(out_channels, out_channels, bias=True) + + def forward(self, hidden_states, encoder_hidden_states=None): + batch, _, height, width = hidden_states.shape + res = hidden_states + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * width, inner_dim + ) + hidden_states = self.proj_in(hidden_states) + + for block in self.transformer_blocks: + hidden_states = block(hidden_states, encoder_hidden_states) + + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, width, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + + return hidden_states + res + + +class Downsample2D(nn.Module): + def __init__(self, in_channels, out_channels): + super(Downsample2D, self).__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=2, padding=1 + ) + + def forward(self, x): + return self.conv(x) + + +class Upsample2D(nn.Module): + def __init__(self, in_channels, out_channels): + super(Upsample2D, self).__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + return self.conv(x) + + +class DownBlock2D(nn.Module): + def __init__(self, in_channels, out_channels): + super(DownBlock2D, self).__init__() + self.resnets = nn.ModuleList( + [ + ResnetBlock2D(in_channels, out_channels, conv_shortcut=False), + ResnetBlock2D(out_channels, out_channels, conv_shortcut=False), + ] + ) + self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)]) + + def forward(self, hidden_states, temb): + output_states = [] + for module in self.resnets: + hidden_states = module(hidden_states, temb) + output_states.append(hidden_states) + + hidden_states = self.downsamplers[0](hidden_states) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class CrossAttnDownBlock2D(nn.Module): + def __init__(self, in_channels, out_channels, n_layers, has_downsamplers=True): + super(CrossAttnDownBlock2D, self).__init__() + self.attentions = nn.ModuleList( + [ + Transformer2DModel(out_channels, out_channels, n_layers), + Transformer2DModel(out_channels, out_channels, n_layers), + ] + ) + self.resnets = nn.ModuleList( + [ + 
ResnetBlock2D(in_channels, out_channels), + ResnetBlock2D(out_channels, out_channels, conv_shortcut=False), + ] + ) + self.downsamplers = None + if has_downsamplers: + self.downsamplers = nn.ModuleList( + [Downsample2D(out_channels, out_channels)] + ) + + def forward(self, hidden_states, temb, encoder_hidden_states): + output_states = [] + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + output_states.append(hidden_states) + + if self.downsamplers is not None: + hidden_states = self.downsamplers[0](hidden_states) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__(self, in_channels, out_channels, prev_output_channel, n_layers): + super(CrossAttnUpBlock2D, self).__init__() + self.attentions = nn.ModuleList( + [ + Transformer2DModel(out_channels, out_channels, n_layers), + Transformer2DModel(out_channels, out_channels, n_layers), + Transformer2DModel(out_channels, out_channels, n_layers), + ] + ) + self.resnets = nn.ModuleList( + [ + ResnetBlock2D(prev_output_channel + out_channels, out_channels), + ResnetBlock2D(2 * out_channels, out_channels), + ResnetBlock2D(out_channels + in_channels, out_channels), + ] + ) + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)]) + + def forward( + self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class UpBlock2D(nn.Module): + def __init__(self, in_channels, out_channels, prev_output_channel): + super(UpBlock2D, self).__init__() + self.resnets = nn.ModuleList( + [ + ResnetBlock2D(out_channels + prev_output_channel, out_channels), + ResnetBlock2D(out_channels * 2, out_channels), + ResnetBlock2D(out_channels + in_channels, out_channels), + ] + ) + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + for resnet in self.resnets: + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__(self, in_features): + super(UNetMidBlock2DCrossAttn, self).__init__() + self.attentions = nn.ModuleList( + [Transformer2DModel(in_features, in_features, n_layers=10)] + ) + self.resnets = nn.ModuleList( + [ + ResnetBlock2D(in_features, in_features, conv_shortcut=False), + ResnetBlock2D(in_features, in_features, conv_shortcut=False), + ] + ) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): # type: ignore + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + hidden_states = resnet(hidden_states, temb) + + return 
hidden_states + + +class UNet2DConditionModel(nn.Module): + def __init__(self): + super(UNet2DConditionModel, self).__init__() + + # This is needed to imitate huggingface config behavior + # has nothing to do with the model itself + # remove this if you don't use diffuser's pipeline + self.config = namedtuple( + "config", "in_channels addition_time_embed_dim sample_size" + ) + self.config.in_channels = 4 # type: ignore + self.config.addition_time_embed_dim = 256 # type: ignore + self.config.sample_size = 128 # type: ignore + + self.conv_in = nn.Conv2d(4, 320, kernel_size=3, stride=1, padding=1) + self.time_proj = Timesteps() + self.time_embedding = TimestepEmbedding(in_features=320, out_features=1280) + self.add_time_proj = Timesteps(256) + self.add_embedding = TimestepEmbedding(in_features=2816, out_features=1280) + self.down_blocks = nn.ModuleList( + [ + DownBlock2D(in_channels=320, out_channels=320), + CrossAttnDownBlock2D(in_channels=320, out_channels=640, n_layers=2), + CrossAttnDownBlock2D( + in_channels=640, + out_channels=1280, + n_layers=10, + has_downsamplers=False, + ), + ] + ) + self.up_blocks = nn.ModuleList( + [ + CrossAttnUpBlock2D( + in_channels=640, + out_channels=1280, + prev_output_channel=1280, + n_layers=10, + ), + CrossAttnUpBlock2D( + in_channels=320, + out_channels=640, + prev_output_channel=1280, + n_layers=2, + ), + UpBlock2D(in_channels=320, out_channels=320, prev_output_channel=640), + ] + ) + self.mid_block = UNetMidBlock2DCrossAttn(1280) + self.conv_norm_out = nn.GroupNorm(32, 320, eps=1e-05, affine=True) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(320, 4, kernel_size=3, stride=1, padding=1) + + @classmethod + def _get_signature_keys(cls, obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = { + k: v for k, v in parameters.items() if v.default == inspect._empty + } + expected_modules = set(required_parameters.keys()) - {"self"} + + return expected_modules + + @property + def device(self) -> torch.device: + r""" + Returns: + `torch.device`: The torch device on which the pipeline is located. + """ + module_names = self._get_signature_keys(self) + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.device + + return torch.device("cpu") + + @property + def dtype(self) -> torch.dtype: + r""" + Returns: + `torch.dtype`: The torch dtype on which the pipeline is located. 
+ """ + module_names = self._get_signature_keys(self) + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + for module in modules: + return module.dtype + + return torch.float32 + + @property + def attn_processors(self): + return {} + + def set_attn_processor(self, *args, **kwargs): + pass + + def forward( + self, sample, timesteps, encoder_hidden_states, added_cond_kwargs, **kwargs + ): + # Implement the forward pass through the model + timesteps = timesteps.expand(sample.shape[0]) + t_emb = self.time_proj(timesteps).to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb) + + text_embeds = added_cond_kwargs.get("text_embeds") + time_ids = added_cond_kwargs.get("time_ids") + + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb + + sample = self.conv_in(sample) + + # 3. down + s0 = sample + sample, [s1, s2, s3] = self.down_blocks[0]( + sample, + temb=emb, + ) + + sample, [s4, s5, s6] = self.down_blocks[1]( + sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + + sample, [s7, s8] = self.down_blocks[2]( + sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + + # 4. mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states + ) + + # 5. up + sample = self.up_blocks[0]( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=[s6, s7, s8], + encoder_hidden_states=encoder_hidden_states, + ) + + sample = self.up_blocks[1]( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=[s3, s4, s5], + encoder_hidden_states=encoder_hidden_states, + ) + + sample = self.up_blocks[2]( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=[s0, s1, s2], + ) + + # 6. 
post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return [sample] diff --git a/core/optimizations/upcast.py b/core/optimizations/upcast.py new file mode 100644 index 000000000..a592d1c25 --- /dev/null +++ b/core/optimizations/upcast.py @@ -0,0 +1,35 @@ +import logging + +import torch +from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL + +from core.config import config + +logger = logging.getLogger(__name__) + + +def upcast_vae(vae: AutoencoderKL, silent: bool = False): + logger.disabled = silent + if ( + vae.config["force_upcast"] or config.api.upcast_vae + ) and vae.dtype == torch.float16: + dtype = vae.dtype + logger.info( + 'Upcasting VAE to FP32 (vae["force_upcast"] OR config.api.upcast_vae)' + ) + vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + vae.decoder.mid_block.attentions[0].processor, # type: ignore + ( + AttnProcessor2_0, + XFormersAttnProcessor, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + vae.post_quant_conv.to(dtype=dtype) + vae.decoder.conv_in.to(dtype=dtype) + vae.decoder.mid_block.to(dtype=dtype) # type: ignore + return vae diff --git a/core/png_metadata.py b/core/png_metadata.py index 51c426e19..674d5a11d 100644 --- a/core/png_metadata.py +++ b/core/png_metadata.py @@ -34,6 +34,7 @@ def create_metadata( UpscaleQueueEntry, ], index: int, + extension: str, ): "Return image with metadata burned into it" @@ -51,7 +52,7 @@ def write_metadata_text(key: str): def write_metadata_exif(key: str): exif_meta_dict[key] = str(unwrap_enum_name(data.__dict__.get(key, ""))) - if config.api.image_extension == "png": + if extension == "png": for key in fields(data): if key.name not in ("image", "mask_image"): write_metadata_text(key.name) @@ -65,7 +66,7 @@ def write_metadata_exif(key: str): elif isinstance(job, Img2ImgQueueEntry): procedure = "img2img" elif isinstance(job, InpaintQueueEntry): - procedure = "inpaint" + procedure = "inpainting" elif isinstance(job, ControlNetQueueEntry): procedure = "control_net" elif isinstance(job, UpscaleQueueEntry): @@ -73,7 +74,7 @@ def write_metadata_exif(key: str): else: procedure = "unknown" - if config.api.image_extension == "png": + if extension == "png": text_metadata.add_text("procedure", procedure) text_metadata.add_text("model", job.model) user_comment: bytes = b"" # for type checking @@ -85,7 +86,7 @@ def write_metadata_exif(key: str): json.dumps(exif_meta_dict, ensure_ascii=False), encoding="unicode" ) - return text_metadata if config.api.image_extension == "png" else user_comment + return text_metadata if extension == "png" else user_comment def save_images( @@ -138,7 +139,8 @@ def save_images( else: folder = "img2img" - metadata = create_metadata(job, i) + extension = "gif" if isinstance(image, BytesIO) else config.api.image_extension + metadata = create_metadata(job, i, extension=extension) if job.save_image == "r2": # Save into Cloudflare R2 bucket @@ -148,7 +150,7 @@ def save_images( filename = f"{job.data.id}-{i}.png" image_bytes = BytesIO() - image.save(image_bytes, pnginfo=metadata, format=config.api.image_extension) + image.save(image_bytes, pnginfo=metadata, format=extension) image_bytes.seek(0) url = r2.upload_file(file=image_bytes, filename=filename) @@ -168,7 +170,7 @@ def save_images( if not isinstance(job, 
UpscaleQueueEntry) else "0", "index": i, - "extension": config.api.image_extension, + "extension": extension, } ) @@ -179,24 +181,30 @@ def save_images( with path.open("wb") as f: logger.debug(f"Saving image to {path.as_posix()}") - if config.api.image_extension == "png": + if extension == "png": image.save(f, pnginfo=metadata) else: - # ! This is using 2 filesystem calls, find a way to save directly to disk with metadata properly inserted - # Save the image - image.save(f, quality=config.api.image_quality) - - # Insert metadata - exif_metadata = { - "0th": {}, - "Exif": Image.Exif(), - "GPS": {}, - "Interop": {}, - "1st": {}, - } - exif_metadata["Exif"][piexif.ExifIFD.UserComment] = metadata - exif_bytes = piexif.dump(exif_metadata) - piexif.insert(exif_bytes, path.as_posix()) + buffer = BytesIO() + if extension == "gif": + buffer: BytesIO = image # type: ignore + else: + image.save(buffer, quality=config.api.image_quality) + + if extension == "gif": + buffer.seek(0) + f.write(buffer.getbuffer()) + else: + # Insert metadata + exif_metadata = { + "0th": {}, + "Exif": Image.Exif(), + "GPS": {}, + "Interop": {}, + "1st": {}, + } + exif_metadata["Exif"][piexif.ExifIFD.UserComment] = metadata + exif_bytes = piexif.dump(exif_metadata) + piexif.insert(exif_bytes, buffer, f) return urls diff --git a/core/scheduling/adapter/k_adapter.py b/core/scheduling/adapter/k_adapter.py index 618a731a4..31cec61bc 100644 --- a/core/scheduling/adapter/k_adapter.py +++ b/core/scheduling/adapter/k_adapter.py @@ -24,7 +24,11 @@ class KdiffusionSchedulerAdapter: denoiser: Denoiser # diffusers compat - config: dict = {"steps_offset": 0, "prediction_type": "epsilon"} + config: dict = { + "steps_offset": 0, + "prediction_type": "epsilon", + "num_train_timesteps": 1000, + } # should really be "sigmas," but for compatibility with diffusers # it's named timesteps. @@ -135,14 +139,27 @@ def do_inference( self, x: torch.Tensor, call: Callable, - apply_model: Callable[..., torch.Tensor], + apply_model: Callable[ + [ + torch.Tensor, + torch.IntTensor, + Callable[..., torch.Tensor], + Callable[[Callable], None], + ], + torch.Tensor, + ], generator: Union[PhiloxGenerator, torch.Generator], callback, callback_steps, + device: torch.device = None, # type: ignore ) -> torch.Tensor: "Run inference function provided with denoiser." 
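+        # `change_source` swaps the callable that the denoiser's inner model
+        # delegates to, so the sampling loop can redirect model calls mid-run
+        # (e.g. to a ControlNet-wrapped UNet); `apply_model` itself still goes
+        # through `self.denoiser`.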
- apply_model = functools.partial(apply_model, call=self.denoiser) - self.denoiser.inner_model.callable = call + + def change_source(src): + self.denoiser.inner_model.callable = src + + apply_model = functools.partial(apply_model, call=self.denoiser, change_source=change_source) # type: ignore + change_source(call) def callback_func(data): if callback is not None and data["i"] % callback_steps == 0: @@ -168,7 +185,7 @@ def noiser(sigma=None, sigma_next=None): "model": apply_model, "x": x, "callback": callback_func, - "sigmas": self.timesteps, + "sigmas": self.timesteps.to(device=x.device), "sigma_min": self.denoiser.sigmas[0].item(), # type: ignore "sigma_max": self.denoiser.sigmas[-1].item(), # type: ignore "noise_sampler": create_noise_sampler(), diff --git a/core/scheduling/adapter/unipc_adapter.py b/core/scheduling/adapter/unipc_adapter.py index c93dd42c2..277b551c3 100644 --- a/core/scheduling/adapter/unipc_adapter.py +++ b/core/scheduling/adapter/unipc_adapter.py @@ -2,6 +2,7 @@ from typing import Any, Callable, Optional, Union import torch +from diffusers.models.unet_2d_condition import UNet2DConditionModel from core.inference.utilities.philox import PhiloxGenerator @@ -114,40 +115,45 @@ def do_inference( callback_steps, optional_device: Optional[torch.device] = None, optional_dtype: Optional[torch.dtype] = None, + device: torch.device = None, # type: ignore ) -> torch.Tensor: device = optional_device or call.device dtype = optional_dtype or call.dtype + unet_or_controlnet = call + def noise_pred_fn(x, t_continuous, cond=None, **model_kwargs): # Was originally get_model_input_time(t_continous) # but "schedule" is ALWAYS "discrete," so we can skip it :) t_input = (t_continuous - 1.0 / self.scheduler.total_N) * 1000 if cond is None: - output = call( + output = unet_or_controlnet( x.to(device=device, dtype=dtype), t_input.to(device=device, dtype=dtype), - return_dict=True, + return_dict=False, **model_kwargs, - )[0] + ) + if isinstance(unet_or_controlnet, UNet2DConditionModel): + output = output[0] else: - output = call(x.to(device=device, dtype=dtype), t_input.to(device=device, dtype=dtype), return_dict=True, encoder_hidden_states=cond, **model_kwargs)[0] # type: ignore - if self.model_type == "noise": - return output - elif self.model_type == "x_start": - alpha_t, sigma_t = self.scheduler.marginal_alpha( - t_continuous - ), self.scheduler.marginal_std(t_continuous) - return (x - alpha_t * output) / sigma_t - elif self.model_type == "v": - alpha_t, sigma_t = self.scheduler.marginal_alpha( - t_continuous - ), self.scheduler.marginal_std(t_continuous) - return alpha_t * output + sigma_t * x - elif self.model_type == "score": - sigma_t = self.scheduler.marginal_std(t_continuous) - return -sigma_t * output - - apply_model = functools.partial(apply_model, call=noise_pred_fn) + output = unet_or_controlnet( + x.to(device=device, dtype=dtype), + t_input.to(device=device, dtype=dtype), + encoder_hidden_states=cond, + return_dict=False, + **model_kwargs, + ) + if isinstance(unet_or_controlnet, UNet2DConditionModel): + output = output[0] + return output + + def change_source(src): + nonlocal unet_or_controlnet + unet_or_controlnet = src + + apply_model = functools.partial( + apply_model, call=noise_pred_fn, change_source=change_source + ) # predict_x0=True -> algorithm_type="data_prediction" # predict_x0=False -> algorithm_type="noise_prediction" diff --git a/core/scheduling/custom/heunpp.py b/core/scheduling/custom/heunpp.py new file mode 100644 index 000000000..634127209 --- /dev/null +++ 
b/core/scheduling/custom/heunpp.py @@ -0,0 +1,82 @@ +import torch +from tqdm import trange +from k_diffusion.sampling import to_d + + +@torch.no_grad() +def sample_heunpp2( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + s_churn=0.0, + s_tmin=0.0, + s_tmax=float("inf"), + s_noise=1.0, +): + """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" + # https://github.com/Carzit/sd-webui-samplers-scheduler-for-v1.6/blob/main/scripts/ksampler.py#L356 + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + s_end = sigmas[-1] + for i in trange(len(sigmas) - 1, disable=disable): + gamma = ( + min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) + if s_tmin <= sigmas[i] <= s_tmax + else 0.0 + ) + eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigma_hat, + "denoised": denoised, + } + ) + dt = sigmas[i + 1] - sigma_hat + if sigmas[i + 1] == s_end: + # Euler method + x = x + d * dt + elif sigmas[i + 2] == s_end: + # Heun's method + x_2 = x + d * dt + denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) + d_2 = to_d(x_2, sigmas[i + 1], denoised_2) + + w = 2 * sigmas[0] + w2 = sigmas[i + 1] / w + w1 = 1 - w2 + + d_prime = d * w1 + d_2 * w2 + + x = x + d_prime * dt + + else: + # Heun++ + x_2 = x + d * dt + denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) + d_2 = to_d(x_2, sigmas[i + 1], denoised_2) + dt_2 = sigmas[i + 2] - sigmas[i + 1] + + x_3 = x_2 + d_2 * dt_2 + denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args) + d_3 = to_d(x_3, sigmas[i + 2], denoised_3) + + w = 3 * sigmas[0] + w2 = sigmas[i + 1] / w + w3 = sigmas[i + 2] / w + w1 = 1 - w2 - w3 + + d_prime = w1 * d + w2 * d_2 + w3 * d_3 + x = x + d_prime * dt + return x diff --git a/core/scheduling/custom/lcm.py b/core/scheduling/custom/lcm.py new file mode 100644 index 000000000..ca2de8ad5 --- /dev/null +++ b/core/scheduling/custom/lcm.py @@ -0,0 +1,35 @@ +from k_diffusion.sampling import default_noise_sampler +import torch +from tqdm import trange + + +@torch.no_grad() +def sample_lcm( + model, + x: torch.Tensor, + sigmas, + extra_args=None, + callback=None, + disable=None, + noise_sampler=None, +): + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) + + x = denoised + if sigmas[i + 1] > 0: + x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1]) + return x diff --git a/core/scheduling/custom/sasolver.py b/core/scheduling/custom/sasolver.py new file mode 100644 index 000000000..3a2afc25a --- /dev/null +++ b/core/scheduling/custom/sasolver.py @@ -0,0 +1,1104 @@ +# type: ignore + +# Copyright 2023 Shuchen Xue, etc. in University of Chinese Academy of Sciences Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: check https://arxiv.org/abs/2309.05019 +# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py + +import math +from typing import List, Optional, Tuple, Union, Callable + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils.torch_utils import randn_tensor +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput, +) + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class SASolverScheduler(SchedulerMixin, ConfigMixin): + """ + `SASolverScheduler` is a fast dedicated high-order solver for diffusion SDEs. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + predictor_order (`int`, defaults to 2): + The predictor order which can be `1` or `2` or `3` or '4'. 
It is recommended to use `predictor_order=2` for guided + sampling, and `predictor_order=3` for unconditional sampling. + corrector_order (`int`, defaults to 2): + The corrector order which can be `1` or `2` or `3` or '4'. It is recommended to use `corrector_order=2` for guided + sampling, and `corrector_order=3` for unconditional sampling. + predictor_corrector_mode (`str`, defaults to `PEC`): + The predictor-corrector mode can be `PEC` or 'PECE'. It is recommended to use `PEC` mode for fast + sampling, and `PECE` for high-quality sampling (PECE needs around twice model evaluations as PEC). + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `data_prediction`): + Algorithm type for the solver; can be `data_prediction` or `noise_prediction`. It is recommended to use `data_prediction` + with `solver_order=2` for guided sampling like in Stable Diffusion. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Default = True. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable + Diffusion. 
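For orientation, a brief construction sketch (not part of the original docstring): tau_func controls how much stochasticity the SA-Solver injects per timestep, with tau(t) > 0 giving noisy SDE-like updates and tau(t) == 0 reducing the update to a deterministic one. The band values below mirror the default in __init__; everything else is an illustrative assumption:

from functools import partial

def tau_in_band(t, low=200, high=800):
    # noise only in the middle of the trajectory, deterministic elsewhere
    return 1.0 if low <= t <= high else 0.0

scheduler = SASolverScheduler(
    predictor_order=2,               # recommended for guided sampling
    corrector_order=2,
    predictor_corrector_mode="PEC",  # "PECE" needs roughly twice the model evaluations
    algorithm_type="data_prediction",
    tau_func=partial(tau_in_band, low=200, high=800),
)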
+ """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + predictor_order: int = 2, + corrector_order: int = 2, + predictor_corrector_mode: str = "PEC", + prediction_type: str = "epsilon", + tau_func: Callable = lambda t: 1 if t >= 200 and t <= 800 else 0, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "data_prediction", + lower_order_final: bool = True, + use_karras_sigmas: Optional[bool] = False, + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace( + beta_start, beta_end, num_train_timesteps, dtype=torch.float32 + ) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError( + f"{beta_schedule} does is not implemented for {self.__class__}" + ) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + if algorithm_type not in ["data_prediction", "noise_prediction"]: + raise NotImplementedError( + f"{algorithm_type} does is not implemented for {self.__class__}" + ) + + # setable values + self.num_inference_steps = None + timesteps = np.linspace( + 0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32 + )[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.timestep_list = [None] * max(predictor_order, corrector_order - 1) + self.model_outputs = [None] * max(predictor_order, corrector_order - 1) + + self.tau_func = tau_func + self.predict_x0 = algorithm_type == "data_prediction" + self.lower_order_nums = 0 + self.last_sample = None + + def set_timesteps( + self, num_inference_steps: int = None, device: Union[str, torch.device] = None + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. 
+ clipped_idx = torch.searchsorted( + torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped + ) + last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item() + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, last_timestep - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(0, num_inference_steps + 1) * step_ratio) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) + ) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = self._convert_to_karras( + in_sigmas=sigmas, num_inference_steps=num_inference_steps + ) + timesteps = np.array( + [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas] + ).round() + timesteps = np.flip(timesteps).copy().astype(np.int64) + + self.sigmas = torch.from_numpy(sigmas) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * max(self.config.predictor_order, self.config.corrector_order - 1) + self.lower_order_nums = 0 + self.last_sample = None + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
+ + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = ( + sample.float() + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = ( + torch.clamp(sample, -s, s) / s + ) # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = ( + np.cumsum((dists >= 0), axis=0) + .argmax(axis=0) + .clip(max=log_sigmas.shape[0] - 2) + ) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras( + self, in_sigmas: torch.FloatTensor, num_inference_steps + ) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def convert_model_output( + self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + ) -> torch.FloatTensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + + + Args: + model_output (`torch.FloatTensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.FloatTensor`: + The converted model output. + """ + + # SA-Solver_data_prediction needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["data_prediction"]: + if self.config.prediction_type == "epsilon": + # SA-Solver only needs the "mean" output. 
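The branches below simply re-express the network output as an x0 (or epsilon) estimate under the VP parameterization x_t = alpha_t * x0 + sigma_t * eps. A short verification sketch with placeholder tensors (the values are arbitrary; only alpha_t**2 + sigma_t**2 == 1 matters):

import torch

alpha_t, sigma_t = torch.tensor(0.8), torch.tensor(0.6)  # VP schedule: alpha_t**2 + sigma_t**2 == 1
x0 = torch.randn(1, 4, 8, 8)
eps = torch.randn_like(x0)
sample = alpha_t * x0 + sigma_t * eps                    # the noisy sample x_t

# prediction_type == "epsilon": the model predicts eps, recover x0
x0_from_eps = (sample - sigma_t * eps) / alpha_t

# prediction_type == "v_prediction": the model predicts v = alpha_t * eps - sigma_t * x0
v = alpha_t * eps - sigma_t * x0
x0_from_v = alpha_t * sample - sigma_t * v

assert torch.allclose(x0_from_eps, x0, atol=1e-5)
assert torch.allclose(x0_from_v, x0, atol=1e-5)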
+ if self.config.variance_type in ["learned", "learned_range"]: + model_output = model_output[:, :3] + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the SASolverScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # SA-Solver_noise_prediction needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type in ["noise_prediction"]: + if self.config.prediction_type == "epsilon": + # SA-Solver only needs the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output + elif self.config.prediction_type == "sample": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = (sample - alpha_t * model_output) / sigma_t + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = alpha_t * model_output + sigma_t * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the SASolverScheduler." + ) + + if self.config.thresholding: + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + + def get_coefficients_exponential_negative( + self, order, interval_start, interval_end + ): + """ + Calculate the integral of exp(-x) * x^order dx from interval_start to interval_end + """ + assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3" + + if order == 0: + return torch.exp(-interval_end) * ( + torch.exp(interval_end - interval_start) - 1 + ) + elif order == 1: + return torch.exp(-interval_end) * ( + (interval_start + 1) * torch.exp(interval_end - interval_start) + - (interval_end + 1) + ) + elif order == 2: + return torch.exp(-interval_end) * ( + (interval_start**2 + 2 * interval_start + 2) + * torch.exp(interval_end - interval_start) + - (interval_end**2 + 2 * interval_end + 2) + ) + elif order == 3: + return torch.exp(-interval_end) * ( + (interval_start**3 + 3 * interval_start**2 + 6 * interval_start + 6) + * torch.exp(interval_end - interval_start) + - (interval_end**3 + 3 * interval_end**2 + 6 * interval_end + 6) + ) + + def get_coefficients_exponential_positive( + self, order, interval_start, interval_end, tau + ): + """ + Calculate the integral of exp(x(1+tau^2)) * x^order dx from interval_start to interval_end + """ + assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3" + + # after change of variable(cov) + interval_end_cov = (1 + tau**2) * interval_end + interval_start_cov = (1 + tau**2) * interval_start + + if order == 0: + return ( + torch.exp(interval_end_cov) + * (1 - torch.exp(-(interval_end_cov - interval_start_cov))) + / ((1 + tau**2)) + ) + elif order == 1: + return ( + torch.exp(interval_end_cov) + * ( + (interval_end_cov - 1) + - 
(interval_start_cov - 1) + * torch.exp(-(interval_end_cov - interval_start_cov)) + ) + / ((1 + tau**2) ** 2) + ) + elif order == 2: + return ( + torch.exp(interval_end_cov) + * ( + (interval_end_cov**2 - 2 * interval_end_cov + 2) + - (interval_start_cov**2 - 2 * interval_start_cov + 2) + * torch.exp(-(interval_end_cov - interval_start_cov)) + ) + / ((1 + tau**2) ** 3) + ) + elif order == 3: + return ( + torch.exp(interval_end_cov) + * ( + ( + interval_end_cov**3 + - 3 * interval_end_cov**2 + + 6 * interval_end_cov + - 6 + ) + - ( + interval_start_cov**3 + - 3 * interval_start_cov**2 + + 6 * interval_start_cov + - 6 + ) + * torch.exp(-(interval_end_cov - interval_start_cov)) + ) + / ((1 + tau**2) ** 4) + ) + + def lagrange_polynomial_coefficient(self, order, lambda_list): + """ + Calculate the coefficient of lagrange polynomial + """ + + assert order in [0, 1, 2, 3] + assert order == len(lambda_list) - 1 + if order == 0: + return [[1]] + elif order == 1: + return [ + [ + 1 / (lambda_list[0] - lambda_list[1]), + -lambda_list[1] / (lambda_list[0] - lambda_list[1]), + ], + [ + 1 / (lambda_list[1] - lambda_list[0]), + -lambda_list[0] / (lambda_list[1] - lambda_list[0]), + ], + ] + elif order == 2: + denominator1 = (lambda_list[0] - lambda_list[1]) * ( + lambda_list[0] - lambda_list[2] + ) + denominator2 = (lambda_list[1] - lambda_list[0]) * ( + lambda_list[1] - lambda_list[2] + ) + denominator3 = (lambda_list[2] - lambda_list[0]) * ( + lambda_list[2] - lambda_list[1] + ) + return [ + [ + 1 / denominator1, + (-lambda_list[1] - lambda_list[2]) / denominator1, + lambda_list[1] * lambda_list[2] / denominator1, + ], + [ + 1 / denominator2, + (-lambda_list[0] - lambda_list[2]) / denominator2, + lambda_list[0] * lambda_list[2] / denominator2, + ], + [ + 1 / denominator3, + (-lambda_list[0] - lambda_list[1]) / denominator3, + lambda_list[0] * lambda_list[1] / denominator3, + ], + ] + elif order == 3: + denominator1 = ( + (lambda_list[0] - lambda_list[1]) + * (lambda_list[0] - lambda_list[2]) + * (lambda_list[0] - lambda_list[3]) + ) + denominator2 = ( + (lambda_list[1] - lambda_list[0]) + * (lambda_list[1] - lambda_list[2]) + * (lambda_list[1] - lambda_list[3]) + ) + denominator3 = ( + (lambda_list[2] - lambda_list[0]) + * (lambda_list[2] - lambda_list[1]) + * (lambda_list[2] - lambda_list[3]) + ) + denominator4 = ( + (lambda_list[3] - lambda_list[0]) + * (lambda_list[3] - lambda_list[1]) + * (lambda_list[3] - lambda_list[2]) + ) + return [ + [ + 1 / denominator1, + (-lambda_list[1] - lambda_list[2] - lambda_list[3]) / denominator1, + ( + lambda_list[1] * lambda_list[2] + + lambda_list[1] * lambda_list[3] + + lambda_list[2] * lambda_list[3] + ) + / denominator1, + (-lambda_list[1] * lambda_list[2] * lambda_list[3]) / denominator1, + ], + [ + 1 / denominator2, + (-lambda_list[0] - lambda_list[2] - lambda_list[3]) / denominator2, + ( + lambda_list[0] * lambda_list[2] + + lambda_list[0] * lambda_list[3] + + lambda_list[2] * lambda_list[3] + ) + / denominator2, + (-lambda_list[0] * lambda_list[2] * lambda_list[3]) / denominator2, + ], + [ + 1 / denominator3, + (-lambda_list[0] - lambda_list[1] - lambda_list[3]) / denominator3, + ( + lambda_list[0] * lambda_list[1] + + lambda_list[0] * lambda_list[3] + + lambda_list[1] * lambda_list[3] + ) + / denominator3, + (-lambda_list[0] * lambda_list[1] * lambda_list[3]) / denominator3, + ], + [ + 1 / denominator4, + (-lambda_list[0] - lambda_list[1] - lambda_list[2]) / denominator4, + ( + lambda_list[0] * lambda_list[1] + + lambda_list[0] * lambda_list[2] + + 
lambda_list[1] * lambda_list[2] + ) + / denominator4, + (-lambda_list[0] * lambda_list[1] * lambda_list[2]) / denominator4, + ], + ] + + def get_coefficients_fn( + self, order, interval_start, interval_end, lambda_list, tau + ): + assert order in [1, 2, 3, 4] + assert order == len( + lambda_list + ), "the length of lambda list must be equal to the order" + coefficients = [] + lagrange_coefficient = self.lagrange_polynomial_coefficient( + order - 1, lambda_list + ) + for i in range(order): + coefficient = 0 + for j in range(order): + if self.predict_x0: + coefficient += lagrange_coefficient[i][ + j + ] * self.get_coefficients_exponential_positive( + order - 1 - j, interval_start, interval_end, tau + ) + else: + coefficient += lagrange_coefficient[i][ + j + ] * self.get_coefficients_exponential_negative( + order - 1 - j, interval_start, interval_end + ) + coefficients.append(coefficient) + assert ( + len(coefficients) == order + ), "the length of coefficients does not match the order" + return coefficients + + def stochastic_adams_bashforth_update( + self, + model_output: torch.FloatTensor, + prev_timestep: int, + sample: torch.FloatTensor, + noise: torch.FloatTensor, + order: int, + tau: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + One step for the SA-Predictor. + + Args: + model_output (`torch.FloatTensor`): + The direct output from the learned diffusion model at the current timestep. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + order (`int`): + The order of SA-Predictor at this timestep. + + Returns: + `torch.FloatTensor`: + The sample tensor at the previous timestep. + """ + + assert noise is not None + timestep_list = self.timestep_list + model_output_list = self.model_outputs + s0, t = self.timestep_list[-1], prev_timestep + lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + gradient_part = torch.zeros_like(sample) + h = lambda_t - lambda_s0 + lambda_list = [] + + for i in range(order): + lambda_list.append(self.lambda_t[timestep_list[-(i + 1)]]) + + gradient_coefficients = self.get_coefficients_fn( + order, lambda_s0, lambda_t, lambda_list, tau + ) + + x = sample + + if self.predict_x0: + if ( + order == 2 + ): ## if order = 2 we do a modification that does not influence the convergence order similar to unipc. Note: This is used only for few steps sampling. + # The added term is O(h^3). Empirically we find it will slightly improve the image quality. 
+ # ODE case + # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2])) + # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2])) + gradient_coefficients[0] += ( + 1.0 + * torch.exp((1 + tau**2) * lambda_t) + * ( + h**2 / 2 + - (h * (1 + tau**2) - 1 + torch.exp((1 + tau**2) * (-h))) + / ((1 + tau**2) ** 2) + ) + / ( + self.lambda_t[timestep_list[-1]] + - self.lambda_t[timestep_list[-2]] + ) + ) + gradient_coefficients[1] -= ( + 1.0 + * torch.exp((1 + tau**2) * lambda_t) + * ( + h**2 / 2 + - (h * (1 + tau**2) - 1 + torch.exp((1 + tau**2) * (-h))) + / ((1 + tau**2) ** 2) + ) + / ( + self.lambda_t[timestep_list[-1]] + - self.lambda_t[timestep_list[-2]] + ) + ) + + for i in range(order): + if self.predict_x0: + gradient_part += ( + (1 + tau**2) + * sigma_t + * torch.exp(-(tau**2) * lambda_t) + * gradient_coefficients[i] + * model_output_list[-(i + 1)] + ) + else: + gradient_part += ( + -(1 + tau**2) + * alpha_t + * gradient_coefficients[i] + * model_output_list[-(i + 1)] + ) + + if self.predict_x0: + noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau**2 * h)) * noise + else: + noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise + + if self.predict_x0: + x_t = ( + torch.exp(-(tau**2) * h) * (sigma_t / sigma_s0) * x + + gradient_part + + noise_part + ) + else: + x_t = (alpha_t / alpha_s0) * x + gradient_part + noise_part + + x_t = x_t.to(x.dtype) + return x_t + + def stochastic_adams_moulton_update( + self, + this_model_output: torch.FloatTensor, + this_timestep: int, + last_sample: torch.FloatTensor, + last_noise: torch.FloatTensor, + this_sample: torch.FloatTensor, + order: int, + tau: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + One step for the SA-Corrector. + + Args: + this_model_output (`torch.FloatTensor`): + The model outputs at `x_t`. + this_timestep (`int`): + The current timestep `t`. + last_sample (`torch.FloatTensor`): + The generated sample before the last predictor `x_{t-1}`. + this_sample (`torch.FloatTensor`): + The generated sample after the last predictor `x_{t}`. + order (`int`): + The order of SA-Corrector at this step. + + Returns: + `torch.FloatTensor`: + The corrected sample tensor at the current timestep. + """ + + assert last_noise is not None + timestep_list = self.timestep_list + model_output_list = self.model_outputs + s0, t = self.timestep_list[-1], this_timestep + lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + gradient_part = torch.zeros_like(this_sample) + h = lambda_t - lambda_s0 + t_list = timestep_list + [this_timestep] + lambda_list = [] + for i in range(order): + lambda_list.append(self.lambda_t[t_list[-(i + 1)]]) + + model_prev_list = model_output_list + [this_model_output] + + gradient_coefficients = self.get_coefficients_fn( + order, lambda_s0, lambda_t, lambda_list, tau + ) + + x = last_sample + + if self.predict_x0: + if ( + order == 2 + ): ## if order = 2 we do a modification that does not influence the convergence order similar to UniPC. Note: This is used only for few steps sampling. + # The added term is O(h^3). Empirically we find it will slightly improve the image quality. 
+ # ODE case + # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h) + # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h) + gradient_coefficients[0] += ( + 1.0 + * torch.exp((1 + tau**2) * lambda_t) + * ( + h / 2 + - (h * (1 + tau**2) - 1 + torch.exp((1 + tau**2) * (-h))) + / ((1 + tau**2) ** 2 * h) + ) + ) + gradient_coefficients[1] -= ( + 1.0 + * torch.exp((1 + tau**2) * lambda_t) + * ( + h / 2 + - (h * (1 + tau**2) - 1 + torch.exp((1 + tau**2) * (-h))) + / ((1 + tau**2) ** 2 * h) + ) + ) + + for i in range(order): + if self.predict_x0: + gradient_part += ( + (1 + tau**2) + * sigma_t + * torch.exp(-(tau**2) * lambda_t) + * gradient_coefficients[i] + * model_prev_list[-(i + 1)] + ) + else: + gradient_part += ( + -(1 + tau**2) + * alpha_t + * gradient_coefficients[i] + * model_prev_list[-(i + 1)] + ) + + if self.predict_x0: + noise_part = ( + sigma_t * torch.sqrt(1 - torch.exp(-2 * tau**2 * h)) * last_noise + ) + else: + noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * last_noise + + if self.predict_x0: + x_t = ( + torch.exp(-(tau**2) * h) * (sigma_t / sigma_s0) * x + + gradient_part + + noise_part + ) + else: + x_t = (alpha_t / alpha_s0) * x + gradient_part + noise_part + + x_t = x_t.to(x.dtype) + return x_t + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the SA-Solver. + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
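step() is meant to drop into the usual diffusers-style sampling loop. A minimal, hypothetical loop for orientation (unet, cond and the latent shape are placeholders, not part of this diff):

import torch

scheduler = SASolverScheduler(predictor_order=2, corrector_order=2)
scheduler.set_timesteps(num_inference_steps=25, device="cuda")

latents = torch.randn(1, 4, 64, 64, device="cuda") * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    latent_in = scheduler.scale_model_input(latents, t)                 # identity here, kept for API parity
    noise_pred = unet(latent_in, t, encoder_hidden_states=cond).sample  # hypothetical UNet call
    latents = scheduler.step(noise_pred, t, latents).prev_sample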
+ + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero() + if len(step_index) == 0: + step_index = len(self.timesteps) - 1 + else: + step_index = step_index.item() + + use_corrector = step_index > 0 and self.last_sample is not None + + model_output_convert = self.convert_model_output(model_output, timestep, sample) + + if use_corrector: + current_tau = self.tau_func(self.timestep_list[-1]) + sample = self.stochastic_adams_moulton_update( + this_model_output=model_output_convert, + this_timestep=timestep, + last_sample=self.last_sample, + last_noise=self.last_noise, + this_sample=sample, + order=self.this_corrector_order, + tau=current_tau, + ) + + prev_timestep = ( + 0 + if step_index == len(self.timesteps) - 1 + else self.timesteps[step_index + 1] + ) + + for i in range( + max(self.config.predictor_order, self.config.corrector_order - 1) - 1 + ): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype, + ) + + if self.config.lower_order_final: + this_predictor_order = min( + self.config.predictor_order, len(self.timesteps) - step_index + ) + this_corrector_order = min( + self.config.corrector_order, len(self.timesteps) - step_index + 1 + ) + else: + this_predictor_order = self.config.predictor_order + this_corrector_order = self.config.corrector_order + + self.this_predictor_order = min( + this_predictor_order, self.lower_order_nums + 1 + ) # warmup for multistep + self.this_corrector_order = min( + this_corrector_order, self.lower_order_nums + 2 + ) # warmup for multistep + assert self.this_predictor_order > 0 + assert self.this_corrector_order > 0 + + self.last_sample = sample + self.last_noise = noise + + current_tau = self.tau_func(self.timestep_list[-1]) + prev_sample = self.stochastic_adams_bashforth_update( + model_output=model_output_convert, + prev_timestep=prev_timestep, + sample=sample, + noise=noise, + order=self.this_predictor_order, + tau=current_tau, + ) + + if self.lower_order_nums < max( + self.config.predictor_order, self.config.corrector_order - 1 + ): + self.lower_order_nums += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input( + self, sample: torch.FloatTensor, *args, **kwargs + ) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): + The input sample. + + Returns: + `torch.FloatTensor`: + A scaled input sample. 
+ """ + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to( + device=original_samples.device, dtype=original_samples.dtype + ) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = ( + sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + ) + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/core/scheduling/denoiser.py b/core/scheduling/denoiser.py index bfecffb3f..b5ad46abd 100644 --- a/core/scheduling/denoiser.py +++ b/core/scheduling/denoiser.py @@ -1,4 +1,5 @@ import torch +from diffusers.models.unet_2d_condition import UNet2DConditionModel from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser from .types import Denoiser @@ -17,7 +18,10 @@ def apply_model(self, *args, **kwargs) -> torch.Tensor: if kwargs.get("cond", None) is not None: encoder_hidden_states = kwargs.pop("cond") if isinstance(self.callable, torch.nn.Module): - return self.callable(*args, encoder_hidden_states=encoder_hidden_states, return_dict=True, **kwargs)[0] # type: ignore + ret = self.callable(*args, encoder_hidden_states=encoder_hidden_states, return_dict=False, **kwargs) # type: ignore + if isinstance(self.callable, UNet2DConditionModel): + return ret[0] + return ret else: return self.callable(*args, encoder_hidden_states=encoder_hidden_states, **kwargs) # type: ignore diff --git a/core/scheduling/scheduling.py b/core/scheduling/scheduling.py index 0ca46aff9..4510bddca 100644 --- a/core/scheduling/scheduling.py +++ b/core/scheduling/scheduling.py @@ -10,6 +10,8 @@ from .adapter.unipc_adapter import UnipcSchedulerAdapter from .custom.dpmpp_2m import sample_dpmpp_2mV2 from .custom.restart import restart_sampler +from .custom.heunpp import sample_heunpp2 +from .custom.lcm import sample_lcm from .denoiser import create_denoiser logger = logging.getLogger(__name__) @@ -76,8 +78,10 @@ "sample_dpmpp_3m_sde", {"brownian_noise": True}, ), + ("heunpp", sample_heunpp2, {}), ("unipc_multistep", "unipc", {}), ("restart", restart_sampler, {}), + ("lcm", sample_lcm, {}), ] diff --git a/core/scheduling/unipc/unipc.py b/core/scheduling/unipc/unipc.py index 23f8a34b1..5d82d1dc1 100644 --- a/core/scheduling/unipc/unipc.py +++ b/core/scheduling/unipc/unipc.py @@ -44,6 +44,7 @@ def __init__( We support both data_prediction and noise_prediction. 
""" + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) self.noise_schedule = noise_schedule diff --git a/core/shared.py b/core/shared.py index 8530d61b6..579b045dd 100644 --- a/core/shared.py +++ b/core/shared.py @@ -1,15 +1,15 @@ import asyncio from concurrent.futures import ThreadPoolExecutor -from typing import TYPE_CHECKING, List, Optional, Union, Literal +from typing import TYPE_CHECKING, List, Optional, Literal + +from .types import PyTorchModelBase if TYPE_CHECKING: from uvicorn import Server - from core.inference.pytorch import PyTorchStableDiffusion - amd: bool = False all_gpus: List = [] -current_model: Union["PyTorchStableDiffusion", None] = None +current_model: Optional[PyTorchModelBase] = None current_method: Literal[None, "txt2img", "img2img", "inpainting", "controlnet"] = None current_steps: int = 50 current_done_steps: int = 0 diff --git a/core/types.py b/core/types.py index 5133fce7b..ab47412c6 100644 --- a/core/types.py +++ b/core/types.py @@ -17,7 +17,7 @@ ) from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers -InferenceBackend = Literal["PyTorch", "AITemplate", "ONNX"] +InferenceBackend = Literal["PyTorch", "AITemplate", "SDXL", "ONNX"] SigmaScheduler = Literal["automatic", "karras", "exponential", "polyexponential", "vp"] Backend = Literal[ "PyTorch", @@ -31,7 +31,18 @@ "Upscaler", "GPT", # for prompt-expansion ] -ImageFormats = Literal["png", "jpeg", "webp"] +PyTorchModelBase = Literal[ + "SD1.x", + "SD2.x", + "SDXL", + "Kandinsky 2.1", + "Kandinsky 2.2", + "Wuerstchen", + "IF", + "Unknown", +] +PyTorchModelStage = Literal["text_encoding", "first_stage", "last_stage"] +ImageFormats = Literal["png", "jpeg", "webp", "gif"] @dataclass @@ -101,7 +112,7 @@ class Img2imgData: @dataclass class InpaintData: - "Dataclass for the data of an img2img request" + "Dataclass for the data of an inpainting request" prompt: str image: Union[bytes, str] @@ -118,6 +129,7 @@ class InpaintData: seed: int = field(default=0) batch_size: int = field(default=1) batch_count: int = field(default=1) + strength: float = field(default=0.6) sampler_settings: Dict = field(default_factory=dict) prompt_to_prompt_settings: Dict = field(default_factory=dict) @@ -136,6 +148,7 @@ class ControlNetData: height: int = field(default=512) steps: int = field(default=25) guidance_scale: float = field(default=7) + self_attention_scale: float = field(default=0.0) sigmas: SigmaScheduler = field(default="automatic") seed: int = field(default=0) batch_size: int = field(default=1) @@ -209,6 +222,20 @@ class UpscaleQueueEntry(Job): data: UpscaleData +@dataclass +class ADetailerQueueEntry(Job): + "Dataclass for an ADetailer job" + + data: Optional[InpaintData] + + # Adetailer specific flags + mask_dilation: int = 4 + mask_blur: int = 4 + mask_padding: int = 32 + iterations: int = 1 + upscale: int = 2 + + @dataclass class QuantizationDict: "Dataclass for quantization parameters" @@ -280,6 +307,8 @@ class ModelResponse: vae: str state: Literal["loading", "loaded", "not loaded"] = field(default="not loaded") textual_inversions: List[str] = field(default_factory=list) + type: PyTorchModelBase = "SD1.x" + stage: PyTorchModelStage = "last_stage" @dataclass diff --git a/core/utils.py b/core/utils.py index 725e575ae..665c9572d 100644 --- a/core/utils.py +++ b/core/utils.py @@ -1,9 +1,11 @@ import asyncio import base64 +import json import logging import math import os import re +import struct from enum import Enum from io import BytesIO from pathlib import Path @@ -15,13 +17,7 @@ from tqdm 
import tqdm from core.thread import ThreadWithReturnValue -from core.types import ( - ControlNetQueueEntry, - ImageFormats, - Img2ImgQueueEntry, - InpaintQueueEntry, - Txt2ImgQueueEntry, -) +from core.types import ImageFormats, InferenceJob, PyTorchModelBase, PyTorchModelStage logger = logging.getLogger(__name__) content_disposition_regex = re.compile(r"filename=[\"]?([^\";\n]+)[\"]?") @@ -52,9 +48,14 @@ def get_grid_dimension(length: int) -> Tuple[int, int]: def convert_image_to_stream( - image: Image.Image, quality: int = 95, _format: ImageFormats = "webp" + image: Union[BytesIO, Image.Image], + quality: int = 95, + _format: ImageFormats = "webp", ) -> BytesIO: "Convert an image to a stream of bytes" + if isinstance(image, BytesIO): + image.seek(0) + return image stream = BytesIO() image.save(stream, format=_format, quality=quality) @@ -90,8 +91,86 @@ def convert_to_image( raise ValueError(f"Type {type(image)} not supported yet") +def determine_model_type( + file: Path, +) -> Tuple[str, PyTorchModelBase, PyTorchModelStage]: + name = file.name + model_type: PyTorchModelBase = "Unknown" + model_stage: PyTorchModelStage = "last_stage" + if file.suffix == ".safetensors": + with open(file, "rb") as f: + length = struct.unpack("= imgs[0].size[0] dim = get_grid_dimension(len(imgs)) if landscape: @@ -154,15 +236,19 @@ def image_grid(imgs: List[Image.Image]): def convert_images_to_base64_grid( - images: List[Image.Image], + images: List[Union[BytesIO, Image.Image]], quality: int = 95, image_format: ImageFormats = "png", ) -> str: "Convert a list of images to a list of base64 strings" - return convert_image_to_base64( - image_grid(images), quality=quality, image_format=image_format - ) + if isinstance(images[0], BytesIO): + quality = max(quality - 20, 20) + return convert_image_to_base64(images[0], quality=quality, image_format="gif") + else: + return convert_image_to_base64( + image_grid(images), quality=quality, image_format=image_format # type: ignore + ) def resize(image: Image.Image, w: int, h: int): @@ -206,7 +292,14 @@ def download_file(url: str, file: Path, add_filename: bool = False): if add_filename: file = file / file_name - total = int(r.headers["Content-Length"]) + + try: + total = int(r.headers["Content-Length"]) + except KeyError: + total = None + logger.warning( + "Content-Length header not found, progress bar will not work" + ) if file.exists(): logger.debug(f"File {file.as_posix()} already exists, skipping") @@ -224,20 +317,5 @@ def download_file(url: str, file: Path, add_filename: bool = False): return file -def preprocess_job( - job: Union[ - Txt2ImgQueueEntry, Img2ImgQueueEntry, InpaintQueueEntry, ControlNetQueueEntry - ] -): - if not isinstance(job, ControlNetQueueEntry): - # SAG does not work with KDiffusion schedulers - try: - int(unwrap_enum(job.data.scheduler)) - except ValueError: - if job.data.self_attention_scale > 0: - logger.warning( - f"Scheduler {job.data.scheduler} does not support SAG, setting to 0" - ) - job.data.self_attention_scale = 0 - +def preprocess_job(job: InferenceJob): return job diff --git a/data/motion-models/.gitkeep b/data/motion-models/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/data/pia/.gitkeep b/data/pia/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/data/scalecrafter/assets/dilate_settings/all_valid_convs.txt b/data/scalecrafter/assets/dilate_settings/all_valid_convs.txt new file mode 100644 index 000000000..d8030476b --- /dev/null +++ 
b/data/scalecrafter/assets/dilate_settings/all_valid_convs.txt @@ -0,0 +1,50 @@ +down_blocks.0.resnets.0.conv1 +down_blocks.0.resnets.0.conv2 +down_blocks.0.resnets.1.conv1 +down_blocks.0.resnets.1.conv2 +down_blocks.0.downsamplers.0.conv +down_blocks.1.resnets.0.conv1 +down_blocks.1.resnets.0.conv2 +down_blocks.1.resnets.1.conv1 +down_blocks.1.resnets.1.conv2 +down_blocks.1.downsamplers.0.conv +down_blocks.2.resnets.0.conv1 +down_blocks.2.resnets.0.conv2 +down_blocks.2.resnets.1.conv1 +down_blocks.2.resnets.1.conv2 +down_blocks.2.downsamplers.0.conv +down_blocks.3.resnets.0.conv1 +down_blocks.3.resnets.0.conv2 +down_blocks.3.resnets.1.conv1 +down_blocks.3.resnets.1.conv2 +up_blocks.0.resnets.0.conv1 +up_blocks.0.resnets.0.conv2 +up_blocks.0.resnets.1.conv1 +up_blocks.0.resnets.1.conv2 +up_blocks.0.resnets.2.conv1 +up_blocks.0.resnets.2.conv2 +up_blocks.0.upsamplers.0.conv +up_blocks.1.resnets.0.conv1 +up_blocks.1.resnets.0.conv2 +up_blocks.1.resnets.1.conv1 +up_blocks.1.resnets.1.conv2 +up_blocks.1.resnets.2.conv1 +up_blocks.1.resnets.2.conv2 +up_blocks.1.upsamplers.0.conv +up_blocks.2.resnets.0.conv1 +up_blocks.2.resnets.0.conv2 +up_blocks.2.resnets.1.conv1 +up_blocks.2.resnets.1.conv2 +up_blocks.2.resnets.2.conv1 +up_blocks.2.resnets.2.conv2 +up_blocks.2.upsamplers.0.conv +up_blocks.3.resnets.0.conv1 +up_blocks.3.resnets.0.conv2 +up_blocks.3.resnets.1.conv1 +up_blocks.3.resnets.1.conv2 +up_blocks.3.resnets.2.conv1 +up_blocks.3.resnets.2.conv2 +mid_block.resnets.0.conv1 +mid_block.resnets.0.conv2 +mid_block.resnets.1.conv1 +mid_block.resnets.1.conv2 diff --git a/data/scalecrafter/assets/dilate_settings/sd1.5_1024x1024.txt b/data/scalecrafter/assets/dilate_settings/sd1.5_1024x1024.txt new file mode 100644 index 000000000..e3c699030 --- /dev/null +++ b/data/scalecrafter/assets/dilate_settings/sd1.5_1024x1024.txt @@ -0,0 +1,39 @@ +down_blocks.1.resnets.0.conv1:2 +down_blocks.1.resnets.0.conv2:2 +down_blocks.1.resnets.1.conv1:2 +down_blocks.1.resnets.1.conv2:2 +down_blocks.1.downsamplers.0.conv:2 +down_blocks.2.resnets.0.conv1:2 +down_blocks.2.resnets.0.conv2:2 +down_blocks.2.resnets.1.conv1:2 +down_blocks.2.resnets.1.conv2:2 +down_blocks.2.downsamplers.0.conv:2 +down_blocks.3.resnets.0.conv1:2 +down_blocks.3.resnets.0.conv2:2 +down_blocks.3.resnets.1.conv1:2 +down_blocks.3.resnets.1.conv2:2 +up_blocks.0.resnets.0.conv1:2 +up_blocks.0.resnets.0.conv2:2 +up_blocks.0.resnets.1.conv1:2 +up_blocks.0.resnets.1.conv2:2 +up_blocks.0.resnets.2.conv1:2 +up_blocks.0.resnets.2.conv2:2 +up_blocks.0.upsamplers.0.conv:2 +up_blocks.1.resnets.0.conv1:2 +up_blocks.1.resnets.0.conv2:2 +up_blocks.1.resnets.1.conv1:2 +up_blocks.1.resnets.1.conv2:2 +up_blocks.1.resnets.2.conv1:2 +up_blocks.1.resnets.2.conv2:2 +up_blocks.1.upsamplers.0.conv:2 +up_blocks.2.resnets.0.conv1:2 +up_blocks.2.resnets.0.conv2:2 +up_blocks.2.resnets.1.conv1:2 +up_blocks.2.resnets.1.conv2:2 +up_blocks.2.resnets.2.conv1:2 +up_blocks.2.resnets.2.conv2:2 +up_blocks.2.upsamplers.0.conv:2 +mid_block.resnets.0.conv1:2 +mid_block.resnets.0.conv2:2 +mid_block.resnets.1.conv1:2 +mid_block.resnets.1.conv2:2 diff --git a/data/scalecrafter/assets/dilate_settings/sd1.5_1280x1280.txt b/data/scalecrafter/assets/dilate_settings/sd1.5_1280x1280.txt new file mode 100644 index 000000000..22769476e --- /dev/null +++ b/data/scalecrafter/assets/dilate_settings/sd1.5_1280x1280.txt @@ -0,0 +1,39 @@ +down_blocks.1.resnets.0.conv1:2.5 +down_blocks.1.resnets.0.conv2:2.5 +down_blocks.1.resnets.1.conv1:2.5 +down_blocks.1.resnets.1.conv2:2.5 
+down_blocks.1.downsamplers.0.conv:2.5 +down_blocks.2.resnets.0.conv1:2.5 +down_blocks.2.resnets.0.conv2:2.5 +down_blocks.2.resnets.1.conv1:2.5 +down_blocks.2.resnets.1.conv2:2.5 +down_blocks.2.downsamplers.0.conv:2.5 +down_blocks.3.resnets.0.conv1:2.5 +down_blocks.3.resnets.0.conv2:2.5 +down_blocks.3.resnets.1.conv1:2.5 +down_blocks.3.resnets.1.conv2:2.5 +up_blocks.0.resnets.0.conv1:2.5 +up_blocks.0.resnets.0.conv2:2.5 +up_blocks.0.resnets.1.conv1:2.5 +up_blocks.0.resnets.1.conv2:2.5 +up_blocks.0.resnets.2.conv1:2.5 +up_blocks.0.resnets.2.conv2:2.5 +up_blocks.0.upsamplers.0.conv:2.5 +up_blocks.1.resnets.0.conv1:2.5 +up_blocks.1.resnets.0.conv2:2.5 +up_blocks.1.resnets.1.conv1:2.5 +up_blocks.1.resnets.1.conv2:2.5 +up_blocks.1.resnets.2.conv1:2.5 +up_blocks.1.resnets.2.conv2:2.5 +up_blocks.1.upsamplers.0.conv:2.5 +up_blocks.2.resnets.0.conv1:2.5 +up_blocks.2.resnets.0.conv2:2.5 +up_blocks.2.resnets.1.conv1:2.5 +up_blocks.2.resnets.1.conv2:2.5 +up_blocks.2.resnets.2.conv1:2.5 +up_blocks.2.resnets.2.conv2:2.5 +up_blocks.2.upsamplers.0.conv:2.5 +mid_block.resnets.0.conv1:2.5 +mid_block.resnets.0.conv2:2.5 +mid_block.resnets.1.conv1:2.5 +mid_block.resnets.1.conv2:2.5 diff --git a/data/scalecrafter/assets/dilate_settings/sd1.5_2048x1024.txt b/data/scalecrafter/assets/dilate_settings/sd1.5_2048x1024.txt new file mode 100644 index 000000000..dd2089694 --- /dev/null +++ b/data/scalecrafter/assets/dilate_settings/sd1.5_2048x1024.txt @@ -0,0 +1,39 @@ +down_blocks.1.resnets.0.conv1:2 +down_blocks.1.resnets.0.conv2:2 +down_blocks.1.resnets.1.conv1:2 +down_blocks.1.resnets.1.conv2:2 +down_blocks.1.downsamplers.0.conv:2 +down_blocks.2.resnets.0.conv1:3 +down_blocks.2.resnets.0.conv2:3 +down_blocks.2.resnets.1.conv1:3 +down_blocks.2.resnets.1.conv2:3 +down_blocks.2.downsamplers.0.conv:3 +down_blocks.3.resnets.0.conv1:4 +down_blocks.3.resnets.0.conv2:4 +down_blocks.3.resnets.1.conv1:4 +down_blocks.3.resnets.1.conv2:4 +up_blocks.0.resnets.0.conv1:4 +up_blocks.0.resnets.0.conv2:4 +up_blocks.0.resnets.1.conv1:4 +up_blocks.0.resnets.1.conv2:4 +up_blocks.0.resnets.2.conv1:4 +up_blocks.0.resnets.2.conv2:4 +up_blocks.0.upsamplers.0.conv:4 +up_blocks.1.resnets.0.conv1:3 +up_blocks.1.resnets.0.conv2:3 +up_blocks.1.resnets.1.conv1:3 +up_blocks.1.resnets.1.conv2:3 +up_blocks.1.resnets.2.conv1:3 +up_blocks.1.resnets.2.conv2:3 +up_blocks.1.upsamplers.0.conv:3 +up_blocks.2.resnets.0.conv1:2 +up_blocks.2.resnets.0.conv2:2 +up_blocks.2.resnets.1.conv1:2 +up_blocks.2.resnets.1.conv2:2 +up_blocks.2.resnets.2.conv1:2 +up_blocks.2.resnets.2.conv2:2 +up_blocks.2.upsamplers.0.conv:2 +mid_block.resnets.0.conv1:4 +mid_block.resnets.0.conv2:4 +mid_block.resnets.1.conv1:4 +mid_block.resnets.1.conv2:4 diff --git a/data/scalecrafter/assets/dilate_settings/sd1.5_2048x2048.txt b/data/scalecrafter/assets/dilate_settings/sd1.5_2048x2048.txt new file mode 100644 index 000000000..dd2089694 --- /dev/null +++ b/data/scalecrafter/assets/dilate_settings/sd1.5_2048x2048.txt @@ -0,0 +1,39 @@ +down_blocks.1.resnets.0.conv1:2 +down_blocks.1.resnets.0.conv2:2 +down_blocks.1.resnets.1.conv1:2 +down_blocks.1.resnets.1.conv2:2 +down_blocks.1.downsamplers.0.conv:2 +down_blocks.2.resnets.0.conv1:3 +down_blocks.2.resnets.0.conv2:3 +down_blocks.2.resnets.1.conv1:3 +down_blocks.2.resnets.1.conv2:3 +down_blocks.2.downsamplers.0.conv:3 +down_blocks.3.resnets.0.conv1:4 +down_blocks.3.resnets.0.conv2:4 +down_blocks.3.resnets.1.conv1:4 +down_blocks.3.resnets.1.conv2:4 +up_blocks.0.resnets.0.conv1:4 +up_blocks.0.resnets.0.conv2:4 +up_blocks.0.resnets.1.conv1:4 
+up_blocks.0.resnets.1.conv2:4 +up_blocks.0.resnets.2.conv1:4 +up_blocks.0.resnets.2.conv2:4 +up_blocks.0.upsamplers.0.conv:4 +up_blocks.1.resnets.0.conv1:3 +up_blocks.1.resnets.0.conv2:3 +up_blocks.1.resnets.1.conv1:3 +up_blocks.1.resnets.1.conv2:3 +up_blocks.1.resnets.2.conv1:3 +up_blocks.1.resnets.2.conv2:3 +up_blocks.1.upsamplers.0.conv:3 +up_blocks.2.resnets.0.conv1:2 +up_blocks.2.resnets.0.conv2:2 +up_blocks.2.resnets.1.conv1:2 +up_blocks.2.resnets.1.conv2:2 +up_blocks.2.resnets.2.conv1:2 +up_blocks.2.resnets.2.conv2:2 +up_blocks.2.upsamplers.0.conv:2 +mid_block.resnets.0.conv1:4 +mid_block.resnets.0.conv2:4 +mid_block.resnets.1.conv1:4 +mid_block.resnets.1.conv2:4 diff --git a/data/scalecrafter/assets/dilate_settings/sd2.1_1024x1024.txt b/data/scalecrafter/assets/dilate_settings/sd2.1_1024x1024.txt new file mode 100644 index 000000000..e3c699030 --- /dev/null +++ b/data/scalecrafter/assets/dilate_settings/sd2.1_1024x1024.txt @@ -0,0 +1,39 @@ +down_blocks.1.resnets.0.conv1:2 +down_blocks.1.resnets.0.conv2:2 +down_blocks.1.resnets.1.conv1:2 +down_blocks.1.resnets.1.conv2:2 +down_blocks.1.downsamplers.0.conv:2 +down_blocks.2.resnets.0.conv1:2 +down_blocks.2.resnets.0.conv2:2 +down_blocks.2.resnets.1.conv1:2 +down_blocks.2.resnets.1.conv2:2 +down_blocks.2.downsamplers.0.conv:2 +down_blocks.3.resnets.0.conv1:2 +down_blocks.3.resnets.0.conv2:2 +down_blocks.3.resnets.1.conv1:2 +down_blocks.3.resnets.1.conv2:2 +up_blocks.0.resnets.0.conv1:2 +up_blocks.0.resnets.0.conv2:2 +up_blocks.0.resnets.1.conv1:2 +up_blocks.0.resnets.1.conv2:2 +up_blocks.0.resnets.2.conv1:2 +up_blocks.0.resnets.2.conv2:2 +up_blocks.0.upsamplers.0.conv:2 +up_blocks.1.resnets.0.conv1:2 +up_blocks.1.resnets.0.conv2:2 +up_blocks.1.resnets.1.conv1:2 +up_blocks.1.resnets.1.conv2:2 +up_blocks.1.resnets.2.conv1:2 +up_blocks.1.resnets.2.conv2:2 +up_blocks.1.upsamplers.0.conv:2 +up_blocks.2.resnets.0.conv1:2 +up_blocks.2.resnets.0.conv2:2 +up_blocks.2.resnets.1.conv1:2 +up_blocks.2.resnets.1.conv2:2 +up_blocks.2.resnets.2.conv1:2 +up_blocks.2.resnets.2.conv2:2 +up_blocks.2.upsamplers.0.conv:2 +mid_block.resnets.0.conv1:2 +mid_block.resnets.0.conv2:2 +mid_block.resnets.1.conv1:2 +mid_block.resnets.1.conv2:2 diff --git a/data/scalecrafter/assets/dilate_settings/sd2.1_1280x1280.txt b/data/scalecrafter/assets/dilate_settings/sd2.1_1280x1280.txt new file mode 100644 index 000000000..22769476e --- /dev/null +++ b/data/scalecrafter/assets/dilate_settings/sd2.1_1280x1280.txt @@ -0,0 +1,39 @@ +down_blocks.1.resnets.0.conv1:2.5 +down_blocks.1.resnets.0.conv2:2.5 +down_blocks.1.resnets.1.conv1:2.5 +down_blocks.1.resnets.1.conv2:2.5 +down_blocks.1.downsamplers.0.conv:2.5 +down_blocks.2.resnets.0.conv1:2.5 +down_blocks.2.resnets.0.conv2:2.5 +down_blocks.2.resnets.1.conv1:2.5 +down_blocks.2.resnets.1.conv2:2.5 +down_blocks.2.downsamplers.0.conv:2.5 +down_blocks.3.resnets.0.conv1:2.5 +down_blocks.3.resnets.0.conv2:2.5 +down_blocks.3.resnets.1.conv1:2.5 +down_blocks.3.resnets.1.conv2:2.5 +up_blocks.0.resnets.0.conv1:2.5 +up_blocks.0.resnets.0.conv2:2.5 +up_blocks.0.resnets.1.conv1:2.5 +up_blocks.0.resnets.1.conv2:2.5 +up_blocks.0.resnets.2.conv1:2.5 +up_blocks.0.resnets.2.conv2:2.5 +up_blocks.0.upsamplers.0.conv:2.5 +up_blocks.1.resnets.0.conv1:2.5 +up_blocks.1.resnets.0.conv2:2.5 +up_blocks.1.resnets.1.conv1:2.5 +up_blocks.1.resnets.1.conv2:2.5 +up_blocks.1.resnets.2.conv1:2.5 +up_blocks.1.resnets.2.conv2:2.5 +up_blocks.1.upsamplers.0.conv:2.5 +up_blocks.2.resnets.0.conv1:2.5 +up_blocks.2.resnets.0.conv2:2.5 +up_blocks.2.resnets.1.conv1:2.5 
+up_blocks.2.resnets.1.conv2:2.5 +up_blocks.2.resnets.2.conv1:2.5 +up_blocks.2.resnets.2.conv2:2.5 +up_blocks.2.upsamplers.0.conv:2.5 +mid_block.resnets.0.conv1:2.5 +mid_block.resnets.0.conv2:2.5 +mid_block.resnets.1.conv1:2.5 +mid_block.resnets.1.conv2:2.5 diff --git a/data/scalecrafter/assets/dilate_settings/sd2.1_2048x1024.txt b/data/scalecrafter/assets/dilate_settings/sd2.1_2048x1024.txt new file mode 100644 index 000000000..dd2089694 --- /dev/null +++ b/data/scalecrafter/assets/dilate_settings/sd2.1_2048x1024.txt @@ -0,0 +1,39 @@ +down_blocks.1.resnets.0.conv1:2 +down_blocks.1.resnets.0.conv2:2 +down_blocks.1.resnets.1.conv1:2 +down_blocks.1.resnets.1.conv2:2 +down_blocks.1.downsamplers.0.conv:2 +down_blocks.2.resnets.0.conv1:3 +down_blocks.2.resnets.0.conv2:3 +down_blocks.2.resnets.1.conv1:3 +down_blocks.2.resnets.1.conv2:3 +down_blocks.2.downsamplers.0.conv:3 +down_blocks.3.resnets.0.conv1:4 +down_blocks.3.resnets.0.conv2:4 +down_blocks.3.resnets.1.conv1:4 +down_blocks.3.resnets.1.conv2:4 +up_blocks.0.resnets.0.conv1:4 +up_blocks.0.resnets.0.conv2:4 +up_blocks.0.resnets.1.conv1:4 +up_blocks.0.resnets.1.conv2:4 +up_blocks.0.resnets.2.conv1:4 +up_blocks.0.resnets.2.conv2:4 +up_blocks.0.upsamplers.0.conv:4 +up_blocks.1.resnets.0.conv1:3 +up_blocks.1.resnets.0.conv2:3 +up_blocks.1.resnets.1.conv1:3 +up_blocks.1.resnets.1.conv2:3 +up_blocks.1.resnets.2.conv1:3 +up_blocks.1.resnets.2.conv2:3 +up_blocks.1.upsamplers.0.conv:3 +up_blocks.2.resnets.0.conv1:2 +up_blocks.2.resnets.0.conv2:2 +up_blocks.2.resnets.1.conv1:2 +up_blocks.2.resnets.1.conv2:2 +up_blocks.2.resnets.2.conv1:2 +up_blocks.2.resnets.2.conv2:2 +up_blocks.2.upsamplers.0.conv:2 +mid_block.resnets.0.conv1:4 +mid_block.resnets.0.conv2:4 +mid_block.resnets.1.conv1:4 +mid_block.resnets.1.conv2:4 diff --git a/data/scalecrafter/assets/dilate_settings/sd2.1_2048x2048.txt b/data/scalecrafter/assets/dilate_settings/sd2.1_2048x2048.txt new file mode 100644 index 000000000..e3423b1f0 --- /dev/null +++ b/data/scalecrafter/assets/dilate_settings/sd2.1_2048x2048.txt @@ -0,0 +1,40 @@ +down_blocks.2.resnets.0.conv1:3 +down_blocks.2.resnets.0.conv2:3 +down_blocks.2.resnets.1.conv1:3 +down_blocks.2.resnets.1.conv2:3 +down_blocks.2.downsamplers.0.conv:4 +down_blocks.3.resnets.0.conv1:4 +down_blocks.3.resnets.0.conv2:4 +down_blocks.3.resnets.1.conv1:4 +down_blocks.3.resnets.1.conv2:4 +up_blocks.0.resnets.0.conv1:4 +up_blocks.0.resnets.0.conv2:4 +up_blocks.0.resnets.1.conv1:4 +up_blocks.0.resnets.1.conv2:4 +up_blocks.0.resnets.2.conv1:4 +up_blocks.0.resnets.2.conv2:4 +up_blocks.0.upsamplers.0.conv:4 +up_blocks.1.resnets.0.conv1:4 +up_blocks.1.resnets.0.conv2:4 +up_blocks.1.resnets.1.conv1:4 +up_blocks.1.resnets.1.conv2:4 +up_blocks.1.resnets.2.conv1:4 +up_blocks.1.resnets.2.conv2:4 +up_blocks.1.upsamplers.0.conv:4 +up_blocks.2.resnets.0.conv1:3 +up_blocks.2.resnets.0.conv2:3 +up_blocks.2.resnets.1.conv1:3 +up_blocks.2.resnets.1.conv2:3 +up_blocks.2.resnets.2.conv1:3 +up_blocks.2.resnets.2.conv2:3 +up_blocks.2.upsamplers.0.conv:3 +up_blocks.3.resnets.0.conv1:2 +up_blocks.3.resnets.0.conv2:2 +up_blocks.3.resnets.1.conv1:2 +up_blocks.3.resnets.1.conv2:2 +up_blocks.3.resnets.2.conv1:2 +up_blocks.3.resnets.2.conv2:2 +mid_block.resnets.0.conv1:4 +mid_block.resnets.0.conv2:4 +mid_block.resnets.1.conv1:4 +mid_block.resnets.1.conv2:4 diff --git a/data/scalecrafter/assets/dilate_settings/sdxl_2048x2048.txt b/data/scalecrafter/assets/dilate_settings/sdxl_2048x2048.txt new file mode 100644 index 000000000..e2a9e8576 --- /dev/null +++ 
b/data/scalecrafter/assets/dilate_settings/sdxl_2048x2048.txt @@ -0,0 +1,15 @@ +down_blocks.3.resnets.0.conv1:2 +down_blocks.3.resnets.0.conv2:2 +down_blocks.3.resnets.1.conv1:2 +down_blocks.3.resnets.1.conv2:2 +up_blocks.0.resnets.0.conv1:2 +up_blocks.0.resnets.0.conv2:2 +up_blocks.0.resnets.1.conv1:2 +up_blocks.0.resnets.1.conv2:2 +up_blocks.0.resnets.2.conv1:2 +up_blocks.0.resnets.2.conv2:2 +up_blocks.0.upsamplers.0.conv:2 +mid_block.resnets.0.conv1:2 +mid_block.resnets.0.conv2:2 +mid_block.resnets.1.conv1:2 +mid_block.resnets.1.conv2:2 diff --git a/data/scalecrafter/assets/dilate_settings/sdxl_2560x2560.txt b/data/scalecrafter/assets/dilate_settings/sdxl_2560x2560.txt new file mode 100644 index 000000000..72af3b53e --- /dev/null +++ b/data/scalecrafter/assets/dilate_settings/sdxl_2560x2560.txt @@ -0,0 +1,15 @@ +down_blocks.3.resnets.0.conv1:2.5 +down_blocks.3.resnets.0.conv2:2.5 +down_blocks.3.resnets.1.conv1:2.5 +down_blocks.3.resnets.1.conv2:2.5 +up_blocks.0.resnets.0.conv1:2.5 +up_blocks.0.resnets.0.conv2:2.5 +up_blocks.0.resnets.1.conv1:2.5 +up_blocks.0.resnets.1.conv2:2.5 +up_blocks.0.resnets.2.conv1:2.5 +up_blocks.0.resnets.2.conv2:2.5 +up_blocks.0.upsamplers.0.conv:2.5 +mid_block.resnets.0.conv1:2.5 +mid_block.resnets.0.conv2:2.5 +mid_block.resnets.1.conv1:2.5 +mid_block.resnets.1.conv2:2.5 diff --git a/data/scalecrafter/assets/dilate_settings/sdxl_4096x2048.txt b/data/scalecrafter/assets/dilate_settings/sdxl_4096x2048.txt new file mode 100644 index 000000000..43da89cf5 --- /dev/null +++ b/data/scalecrafter/assets/dilate_settings/sdxl_4096x2048.txt @@ -0,0 +1,34 @@ +down_blocks.2.resnets.0.conv1:3 +down_blocks.2.resnets.0.conv2:3 +down_blocks.2.resnets.1.conv1:3 +down_blocks.2.resnets.1.conv2:3 +down_blocks.2.downsamplers.0.conv:3 +down_blocks.3.resnets.0.conv1:4 +down_blocks.3.resnets.0.conv2:4 +down_blocks.3.resnets.1.conv1:4 +down_blocks.3.resnets.1.conv2:4 +up_blocks.0.resnets.0.conv1:4 +up_blocks.0.resnets.0.conv2:4 +up_blocks.0.resnets.1.conv1:4 +up_blocks.0.resnets.1.conv2:4 +up_blocks.0.resnets.2.conv1:4 +up_blocks.0.resnets.2.conv2:4 +up_blocks.0.upsamplers.0.conv:4 +up_blocks.1.resnets.0.conv1:3 +up_blocks.1.resnets.0.conv2:3 +up_blocks.1.resnets.1.conv1:3 +up_blocks.1.resnets.1.conv2:3 +up_blocks.1.resnets.2.conv1:3 +up_blocks.1.resnets.2.conv2:3 +up_blocks.1.upsamplers.0.conv:3 +up_blocks.2.resnets.0.conv1:2 +up_blocks.2.resnets.0.conv2:2 +up_blocks.2.resnets.1.conv1:2 +up_blocks.2.resnets.1.conv2:2 +up_blocks.2.resnets.2.conv1:2 +up_blocks.2.resnets.2.conv2:2 +up_blocks.2.upsamplers.0.conv:2 +mid_block.resnets.0.conv1:4 +mid_block.resnets.0.conv2:4 +mid_block.resnets.1.conv1:4 +mid_block.resnets.1.conv2:4 diff --git a/data/scalecrafter/assets/dilate_settings/sdxl_4096x4096.txt b/data/scalecrafter/assets/dilate_settings/sdxl_4096x4096.txt new file mode 100644 index 000000000..047cb71e9 --- /dev/null +++ b/data/scalecrafter/assets/dilate_settings/sdxl_4096x4096.txt @@ -0,0 +1,34 @@ +down_blocks.2.resnets.0.conv1:4 +down_blocks.2.resnets.0.conv2:4 +down_blocks.2.resnets.1.conv1:4 +down_blocks.2.resnets.1.conv2:4 +down_blocks.2.downsamplers.0.conv:4 +down_blocks.3.resnets.0.conv1:4 +down_blocks.3.resnets.0.conv2:4 +down_blocks.3.resnets.1.conv1:4 +down_blocks.3.resnets.1.conv2:4 +up_blocks.0.resnets.0.conv1:4 +up_blocks.0.resnets.0.conv2:4 +up_blocks.0.resnets.1.conv1:4 +up_blocks.0.resnets.1.conv2:4 +up_blocks.0.resnets.2.conv1:4 +up_blocks.0.resnets.2.conv2:4 +up_blocks.0.upsamplers.0.conv:4 +up_blocks.1.resnets.0.conv1:3 +up_blocks.1.resnets.0.conv2:3 
+up_blocks.1.resnets.1.conv1:3 +up_blocks.1.resnets.1.conv2:3 +up_blocks.1.resnets.2.conv1:3 +up_blocks.1.resnets.2.conv2:3 +up_blocks.1.upsamplers.0.conv:3 +up_blocks.2.resnets.0.conv1:2 +up_blocks.2.resnets.0.conv2:2 +up_blocks.2.resnets.1.conv1:2 +up_blocks.2.resnets.1.conv2:2 +up_blocks.2.resnets.2.conv1:2 +up_blocks.2.resnets.2.conv2:2 +up_blocks.2.upsamplers.0.conv:2 +mid_block.resnets.0.conv1:4 +mid_block.resnets.0.conv2:4 +mid_block.resnets.1.conv1:4 +mid_block.resnets.1.conv2:4 diff --git a/data/scalecrafter/assets/disperse_settings/sd1.5_2048x2048.txt b/data/scalecrafter/assets/disperse_settings/sd1.5_2048x2048.txt new file mode 100644 index 000000000..81296cf48 --- /dev/null +++ b/data/scalecrafter/assets/disperse_settings/sd1.5_2048x2048.txt @@ -0,0 +1,23 @@ +down_blocks.2.downsamplers.0.conv +down_blocks.3.resnets.0.conv1 +down_blocks.3.resnets.0.conv2 +down_blocks.3.resnets.1.conv1 +down_blocks.3.resnets.1.conv2 +up_blocks.0.resnets.0.conv1 +up_blocks.0.resnets.0.conv2 +up_blocks.0.resnets.1.conv1 +up_blocks.0.resnets.1.conv2 +up_blocks.0.resnets.2.conv1 +up_blocks.0.resnets.2.conv2 +up_blocks.0.upsamplers.0.conv +up_blocks.1.resnets.0.conv1 +up_blocks.1.resnets.0.conv2 +up_blocks.1.resnets.1.conv1 +up_blocks.1.resnets.1.conv2 +up_blocks.1.resnets.2.conv1 +up_blocks.1.resnets.2.conv2 +up_blocks.1.upsamplers.0.conv +mid_block.resnets.0.conv1 +mid_block.resnets.0.conv2 +mid_block.resnets.1.conv1 +mid_block.resnets.1.conv2 diff --git a/data/scalecrafter/assets/disperse_settings/sd2.1_2048x2048.txt b/data/scalecrafter/assets/disperse_settings/sd2.1_2048x2048.txt new file mode 100644 index 000000000..81296cf48 --- /dev/null +++ b/data/scalecrafter/assets/disperse_settings/sd2.1_2048x2048.txt @@ -0,0 +1,23 @@ +down_blocks.2.downsamplers.0.conv +down_blocks.3.resnets.0.conv1 +down_blocks.3.resnets.0.conv2 +down_blocks.3.resnets.1.conv1 +down_blocks.3.resnets.1.conv2 +up_blocks.0.resnets.0.conv1 +up_blocks.0.resnets.0.conv2 +up_blocks.0.resnets.1.conv1 +up_blocks.0.resnets.1.conv2 +up_blocks.0.resnets.2.conv1 +up_blocks.0.resnets.2.conv2 +up_blocks.0.upsamplers.0.conv +up_blocks.1.resnets.0.conv1 +up_blocks.1.resnets.0.conv2 +up_blocks.1.resnets.1.conv1 +up_blocks.1.resnets.1.conv2 +up_blocks.1.resnets.2.conv1 +up_blocks.1.resnets.2.conv2 +up_blocks.1.upsamplers.0.conv +mid_block.resnets.0.conv1 +mid_block.resnets.0.conv2 +mid_block.resnets.1.conv1 +mid_block.resnets.1.conv2 diff --git a/data/scalecrafter/assets/disperse_settings/sdxl_4096x4096.txt b/data/scalecrafter/assets/disperse_settings/sdxl_4096x4096.txt new file mode 100644 index 000000000..5d27a4f67 --- /dev/null +++ b/data/scalecrafter/assets/disperse_settings/sdxl_4096x4096.txt @@ -0,0 +1,16 @@ +down_blocks.2.resnets.0.conv1 +down_blocks.2.resnets.0.conv2 +down_blocks.2.resnets.1.conv1 +down_blocks.2.resnets.1.conv2 +down_blocks.2.downsamplers.0.conv +up_blocks.0.resnets.0.conv1 +up_blocks.0.resnets.0.conv2 +up_blocks.0.resnets.1.conv1 +up_blocks.0.resnets.1.conv2 +up_blocks.0.resnets.2.conv1 +up_blocks.0.resnets.2.conv2 +up_blocks.0.upsamplers.0.conv +mid_block.resnets.0.conv1 +mid_block.resnets.0.conv2 +mid_block.resnets.1.conv1 +mid_block.resnets.1.conv2 \ No newline at end of file diff --git a/data/scalecrafter/assets/inflate_settings/sd1.5_2048x2048.txt b/data/scalecrafter/assets/inflate_settings/sd1.5_2048x2048.txt new file mode 100644 index 000000000..113001f3e --- /dev/null +++ b/data/scalecrafter/assets/inflate_settings/sd1.5_2048x2048.txt @@ -0,0 +1,39 @@ +down_blocks.1.resnets.0.conv1 
+down_blocks.1.resnets.0.conv2 +down_blocks.1.resnets.1.conv1 +down_blocks.1.resnets.1.conv2 +down_blocks.1.downsamplers.0.conv +down_blocks.2.resnets.0.conv1 +down_blocks.2.resnets.0.conv2 +down_blocks.2.resnets.1.conv1 +down_blocks.2.resnets.1.conv2 +down_blocks.2.downsamplers.0.conv +down_blocks.3.resnets.0.conv1 +down_blocks.3.resnets.0.conv2 +down_blocks.3.resnets.1.conv1 +down_blocks.3.resnets.1.conv2 +up_blocks.0.resnets.0.conv1 +up_blocks.0.resnets.0.conv2 +up_blocks.0.resnets.1.conv1 +up_blocks.0.resnets.1.conv2 +up_blocks.0.resnets.2.conv1 +up_blocks.0.resnets.2.conv2 +up_blocks.0.upsamplers.0.conv +up_blocks.1.resnets.0.conv1 +up_blocks.1.resnets.0.conv2 +up_blocks.1.resnets.1.conv1 +up_blocks.1.resnets.1.conv2 +up_blocks.1.resnets.2.conv1 +up_blocks.1.resnets.2.conv2 +up_blocks.1.upsamplers.0.conv +up_blocks.2.resnets.0.conv1 +up_blocks.2.resnets.0.conv2 +up_blocks.2.resnets.1.conv1 +up_blocks.2.resnets.1.conv2 +up_blocks.2.resnets.2.conv1 +up_blocks.2.resnets.2.conv2 +up_blocks.2.upsamplers.0.conv +mid_block.resnets.0.conv1 +mid_block.resnets.0.conv2 +mid_block.resnets.1.conv1 +mid_block.resnets.1.conv2 diff --git a/data/scalecrafter/assets/inflate_settings/sdxl_4096x4096.txt b/data/scalecrafter/assets/inflate_settings/sdxl_4096x4096.txt new file mode 100644 index 000000000..e69de29bb diff --git a/data/scalecrafter/assets/ndcfg_dilate_settings/sd1.5_2048x1024.txt b/data/scalecrafter/assets/ndcfg_dilate_settings/sd1.5_2048x1024.txt new file mode 100644 index 000000000..199b42b44 --- /dev/null +++ b/data/scalecrafter/assets/ndcfg_dilate_settings/sd1.5_2048x1024.txt @@ -0,0 +1,27 @@ +down_blocks.2.resnets.0.conv1:3 +down_blocks.2.resnets.0.conv2:3 +down_blocks.2.resnets.1.conv1:3 +down_blocks.2.resnets.1.conv2:3 +down_blocks.2.downsamplers.0.conv:3 +down_blocks.3.resnets.0.conv1:4 +down_blocks.3.resnets.0.conv2:4 +down_blocks.3.resnets.1.conv1:4 +down_blocks.3.resnets.1.conv2:4 +up_blocks.0.resnets.0.conv1:4 +up_blocks.0.resnets.0.conv2:4 +up_blocks.0.resnets.1.conv1:4 +up_blocks.0.resnets.1.conv2:4 +up_blocks.0.resnets.2.conv1:4 +up_blocks.0.resnets.2.conv2:4 +up_blocks.0.upsamplers.0.conv:4 +up_blocks.1.resnets.0.conv1:3 +up_blocks.1.resnets.0.conv2:3 +up_blocks.1.resnets.1.conv1:3 +up_blocks.1.resnets.1.conv2:3 +up_blocks.1.resnets.2.conv1:3 +up_blocks.1.resnets.2.conv2:3 +up_blocks.1.upsamplers.0.conv:3 +mid_block.resnets.0.conv1:4 +mid_block.resnets.0.conv2:4 +mid_block.resnets.1.conv1:4 +mid_block.resnets.1.conv2:4 diff --git a/data/scalecrafter/assets/ndcfg_dilate_settings/sd1.5_2048x2048.txt b/data/scalecrafter/assets/ndcfg_dilate_settings/sd1.5_2048x2048.txt new file mode 100644 index 000000000..199b42b44 --- /dev/null +++ b/data/scalecrafter/assets/ndcfg_dilate_settings/sd1.5_2048x2048.txt @@ -0,0 +1,27 @@ +down_blocks.2.resnets.0.conv1:3 +down_blocks.2.resnets.0.conv2:3 +down_blocks.2.resnets.1.conv1:3 +down_blocks.2.resnets.1.conv2:3 +down_blocks.2.downsamplers.0.conv:3 +down_blocks.3.resnets.0.conv1:4 +down_blocks.3.resnets.0.conv2:4 +down_blocks.3.resnets.1.conv1:4 +down_blocks.3.resnets.1.conv2:4 +up_blocks.0.resnets.0.conv1:4 +up_blocks.0.resnets.0.conv2:4 +up_blocks.0.resnets.1.conv1:4 +up_blocks.0.resnets.1.conv2:4 +up_blocks.0.resnets.2.conv1:4 +up_blocks.0.resnets.2.conv2:4 +up_blocks.0.upsamplers.0.conv:4 +up_blocks.1.resnets.0.conv1:3 +up_blocks.1.resnets.0.conv2:3 +up_blocks.1.resnets.1.conv1:3 +up_blocks.1.resnets.1.conv2:3 +up_blocks.1.resnets.2.conv1:3 +up_blocks.1.resnets.2.conv2:3 +up_blocks.1.upsamplers.0.conv:3 +mid_block.resnets.0.conv1:4 
+mid_block.resnets.0.conv2:4 +mid_block.resnets.1.conv1:4 +mid_block.resnets.1.conv2:4 diff --git a/data/scalecrafter/assets/ndcfg_dilate_settings/sd2.1_2048x1024.txt b/data/scalecrafter/assets/ndcfg_dilate_settings/sd2.1_2048x1024.txt new file mode 100644 index 000000000..0b48eb4d0 --- /dev/null +++ b/data/scalecrafter/assets/ndcfg_dilate_settings/sd2.1_2048x1024.txt @@ -0,0 +1,34 @@ +down_blocks.2.resnets.0.conv1:3 +down_blocks.2.resnets.0.conv2:3 +down_blocks.2.resnets.1.conv1:3 +down_blocks.2.resnets.1.conv2:3 +down_blocks.2.downsamplers.0.conv:4 +down_blocks.3.resnets.0.conv1:4 +down_blocks.3.resnets.0.conv2:4 +down_blocks.3.resnets.1.conv1:4 +down_blocks.3.resnets.1.conv2:4 +up_blocks.0.resnets.0.conv1:4 +up_blocks.0.resnets.0.conv2:4 +up_blocks.0.resnets.1.conv1:4 +up_blocks.0.resnets.1.conv2:4 +up_blocks.0.resnets.2.conv1:4 +up_blocks.0.resnets.2.conv2:4 +up_blocks.0.upsamplers.0.conv:4 +up_blocks.1.resnets.0.conv1:4 +up_blocks.1.resnets.0.conv2:4 +up_blocks.1.resnets.1.conv1:4 +up_blocks.1.resnets.1.conv2:4 +up_blocks.1.resnets.2.conv1:4 +up_blocks.1.resnets.2.conv2:4 +up_blocks.1.upsamplers.0.conv:4 +up_blocks.2.resnets.0.conv1:3 +up_blocks.2.resnets.0.conv2:3 +up_blocks.2.resnets.1.conv1:3 +up_blocks.2.resnets.1.conv2:3 +up_blocks.2.resnets.2.conv1:3 +up_blocks.2.resnets.2.conv2:3 +up_blocks.2.upsamplers.0.conv:3 +mid_block.resnets.0.conv1:4 +mid_block.resnets.0.conv2:4 +mid_block.resnets.1.conv1:4 +mid_block.resnets.1.conv2:4 diff --git a/data/scalecrafter/assets/ndcfg_dilate_settings/sd2.1_2048x2048.txt b/data/scalecrafter/assets/ndcfg_dilate_settings/sd2.1_2048x2048.txt new file mode 100644 index 000000000..0b48eb4d0 --- /dev/null +++ b/data/scalecrafter/assets/ndcfg_dilate_settings/sd2.1_2048x2048.txt @@ -0,0 +1,34 @@ +down_blocks.2.resnets.0.conv1:3 +down_blocks.2.resnets.0.conv2:3 +down_blocks.2.resnets.1.conv1:3 +down_blocks.2.resnets.1.conv2:3 +down_blocks.2.downsamplers.0.conv:4 +down_blocks.3.resnets.0.conv1:4 +down_blocks.3.resnets.0.conv2:4 +down_blocks.3.resnets.1.conv1:4 +down_blocks.3.resnets.1.conv2:4 +up_blocks.0.resnets.0.conv1:4 +up_blocks.0.resnets.0.conv2:4 +up_blocks.0.resnets.1.conv1:4 +up_blocks.0.resnets.1.conv2:4 +up_blocks.0.resnets.2.conv1:4 +up_blocks.0.resnets.2.conv2:4 +up_blocks.0.upsamplers.0.conv:4 +up_blocks.1.resnets.0.conv1:4 +up_blocks.1.resnets.0.conv2:4 +up_blocks.1.resnets.1.conv1:4 +up_blocks.1.resnets.1.conv2:4 +up_blocks.1.resnets.2.conv1:4 +up_blocks.1.resnets.2.conv2:4 +up_blocks.1.upsamplers.0.conv:4 +up_blocks.2.resnets.0.conv1:3 +up_blocks.2.resnets.0.conv2:3 +up_blocks.2.resnets.1.conv1:3 +up_blocks.2.resnets.1.conv2:3 +up_blocks.2.resnets.2.conv1:3 +up_blocks.2.resnets.2.conv2:3 +up_blocks.2.upsamplers.0.conv:3 +mid_block.resnets.0.conv1:4 +mid_block.resnets.0.conv2:4 +mid_block.resnets.1.conv1:4 +mid_block.resnets.1.conv2:4 diff --git a/data/scalecrafter/assets/ndcfg_dilate_settings/sdxl_4096x2048.txt b/data/scalecrafter/assets/ndcfg_dilate_settings/sdxl_4096x2048.txt new file mode 100644 index 000000000..e2a9e8576 --- /dev/null +++ b/data/scalecrafter/assets/ndcfg_dilate_settings/sdxl_4096x2048.txt @@ -0,0 +1,15 @@ +down_blocks.3.resnets.0.conv1:2 +down_blocks.3.resnets.0.conv2:2 +down_blocks.3.resnets.1.conv1:2 +down_blocks.3.resnets.1.conv2:2 +up_blocks.0.resnets.0.conv1:2 +up_blocks.0.resnets.0.conv2:2 +up_blocks.0.resnets.1.conv1:2 +up_blocks.0.resnets.1.conv2:2 +up_blocks.0.resnets.2.conv1:2 +up_blocks.0.resnets.2.conv2:2 +up_blocks.0.upsamplers.0.conv:2 +mid_block.resnets.0.conv1:2 +mid_block.resnets.0.conv2:2 
+mid_block.resnets.1.conv1:2 +mid_block.resnets.1.conv2:2 diff --git a/data/scalecrafter/assets/ndcfg_dilate_settings/sdxl_4096x4096.txt b/data/scalecrafter/assets/ndcfg_dilate_settings/sdxl_4096x4096.txt new file mode 100644 index 000000000..e2a9e8576 --- /dev/null +++ b/data/scalecrafter/assets/ndcfg_dilate_settings/sdxl_4096x4096.txt @@ -0,0 +1,15 @@ +down_blocks.3.resnets.0.conv1:2 +down_blocks.3.resnets.0.conv2:2 +down_blocks.3.resnets.1.conv1:2 +down_blocks.3.resnets.1.conv2:2 +up_blocks.0.resnets.0.conv1:2 +up_blocks.0.resnets.0.conv2:2 +up_blocks.0.resnets.1.conv1:2 +up_blocks.0.resnets.1.conv2:2 +up_blocks.0.resnets.2.conv1:2 +up_blocks.0.resnets.2.conv2:2 +up_blocks.0.upsamplers.0.conv:2 +mid_block.resnets.0.conv1:2 +mid_block.resnets.0.conv2:2 +mid_block.resnets.1.conv1:2 +mid_block.resnets.1.conv2:2 diff --git a/data/scalecrafter/configs/sd1.5_1024x1024.yaml b/data/scalecrafter/configs/sd1.5_1024x1024.yaml new file mode 100644 index 000000000..50ec3f33a --- /dev/null +++ b/data/scalecrafter/configs/sd1.5_1024x1024.yaml @@ -0,0 +1,13 @@ +ndcfg_tau: 0 +dilate_tau: 30 +inflate_tau: 0 +dilate_settings: ./assets/dilate_settings/sd1.5_1024x1024.txt +ndcfg_dilate_settings: ~ +disperse_settings: ~ +disperse_transform: ~ +progressive: false +num_inference_steps: 50 +inference_batch_size: 4 +num_iters_per_prompt: 1 +latent_height: 128 +latent_width: 128 \ No newline at end of file diff --git a/data/scalecrafter/configs/sd1.5_1280x1280.yaml b/data/scalecrafter/configs/sd1.5_1280x1280.yaml new file mode 100644 index 000000000..3baebf614 --- /dev/null +++ b/data/scalecrafter/configs/sd1.5_1280x1280.yaml @@ -0,0 +1,13 @@ +ndcfg_tau: 0 +dilate_tau: 35 +inflate_tau: 0 +dilate_settings: ./assets/dilate_settings/sd1.5_1280x1280.txt +ndcfg_dilate_settings: ~ +disperse_settings: ~ +disperse_transform: ~ +progressive: false +num_inference_steps: 50 +inference_batch_size: 4 +num_iters_per_prompt: 1 +latent_height: 160 +latent_width: 160 \ No newline at end of file diff --git a/data/scalecrafter/configs/sd1.5_2048x1024.yaml b/data/scalecrafter/configs/sd1.5_2048x1024.yaml new file mode 100644 index 000000000..00b2bfa71 --- /dev/null +++ b/data/scalecrafter/configs/sd1.5_2048x1024.yaml @@ -0,0 +1,13 @@ +ndcfg_tau: 35 +dilate_tau: 35 +inflate_tau: 0 +dilate_settings: ./assets/dilate_settings/sd1.5_2048x1024.txt +ndcfg_dilate_settings: ./assets/ndcfg_dilate_settings/sd1.5_2048x1024.txt +disperse_settings: ~ +disperse_transform: ~ +progressive: true +num_inference_steps: 50 +inference_batch_size: 4 +num_iters_per_prompt: 1 +latent_height: 128 +latent_width: 256 \ No newline at end of file diff --git a/data/scalecrafter/configs/sd1.5_2048x2048.yaml b/data/scalecrafter/configs/sd1.5_2048x2048.yaml new file mode 100644 index 000000000..3496cb5df --- /dev/null +++ b/data/scalecrafter/configs/sd1.5_2048x2048.yaml @@ -0,0 +1,13 @@ +ndcfg_tau: 35 +dilate_tau: 35 +inflate_tau: 0 +dilate_settings: ./assets/dilate_settings/sd1.5_2048x2048.txt +ndcfg_dilate_settings: ./assets/ndcfg_dilate_settings/sd1.5_2048x2048.txt +disperse_settings: ~ +disperse_transform: ~ +progressive: true +num_inference_steps: 50 +inference_batch_size: 4 +num_iters_per_prompt: 1 +latent_height: 256 +latent_width: 256 \ No newline at end of file diff --git a/data/scalecrafter/configs/sd1.5_2048x2048_disperse.yaml b/data/scalecrafter/configs/sd1.5_2048x2048_disperse.yaml new file mode 100644 index 000000000..3f3f7107d --- /dev/null +++ b/data/scalecrafter/configs/sd1.5_2048x2048_disperse.yaml @@ -0,0 +1,13 @@ +ndcfg_tau: 35 +dilate_tau: 35 
+inflate_tau: 0 +dilate_settings: ./assets/dilate_settings/sd1.5_2048x2048.txt +ndcfg_dilate_settings: ./assets/ndcfg_dilate_settings/sd1.5_2048x2048.txt +disperse_settings: ./assets/disperse_settings/sd1.5_2048x2048.txt +disperse_transform: ./transforms/R20to1_new.mat +progressive: true +num_inference_steps: 50 +inference_batch_size: 4 +num_iters_per_prompt: 1 +latent_height: 256 +latent_width: 256 \ No newline at end of file diff --git a/data/scalecrafter/configs/sd2.1_1024x1024.yaml b/data/scalecrafter/configs/sd2.1_1024x1024.yaml new file mode 100644 index 000000000..75697dd60 --- /dev/null +++ b/data/scalecrafter/configs/sd2.1_1024x1024.yaml @@ -0,0 +1,13 @@ +ndcfg_tau: 0 +dilate_tau: 20 +inflate_tau: 0 +dilate_settings: ./assets/dilate_settings/sd2.1_1024x1024.txt +ndcfg_dilate_settings: ~ +disperse_settings: ~ +disperse_transform: ~ +progressive: false +num_inference_steps: 50 +inference_batch_size: 4 +num_iters_per_prompt: 1 +latent_height: 128 +latent_width: 128 \ No newline at end of file diff --git a/data/scalecrafter/configs/sd2.1_1280x1280.yaml b/data/scalecrafter/configs/sd2.1_1280x1280.yaml new file mode 100644 index 000000000..4628a4f66 --- /dev/null +++ b/data/scalecrafter/configs/sd2.1_1280x1280.yaml @@ -0,0 +1,13 @@ +ndcfg_tau: 0 +dilate_tau: 30 +inflate_tau: 0 +dilate_settings: ./assets/dilate_settings/sd2.1_1280x1280.txt +ndcfg_dilate_settings: ~ +disperse_settings: ~ +disperse_transform: ~ +progressive: false +num_inference_steps: 50 +inference_batch_size: 4 +num_iters_per_prompt: 1 +latent_height: 160 +latent_width: 160 \ No newline at end of file diff --git a/data/scalecrafter/configs/sd2.1_2048x1024.yaml b/data/scalecrafter/configs/sd2.1_2048x1024.yaml new file mode 100644 index 000000000..50109ff0a --- /dev/null +++ b/data/scalecrafter/configs/sd2.1_2048x1024.yaml @@ -0,0 +1,13 @@ +ndcfg_tau: 0 +dilate_tau: 37 +inflate_tau: 0 +dilate_settings: ./assets/dilate_settings/sd2.1_2048x1024.txt +ndcfg_dilate_settings: ./assets/ndcfg_dilate_settings/sd2.1_2048x1024.txt +disperse_settings: ~ +disperse_transform: ~ +progressive: true +num_inference_steps: 50 +inference_batch_size: 4 +num_iters_per_prompt: 1 +latent_height: 128 +latent_width: 256 \ No newline at end of file diff --git a/data/scalecrafter/configs/sd2.1_2048x2048.yaml b/data/scalecrafter/configs/sd2.1_2048x2048.yaml new file mode 100644 index 000000000..fec34a0f3 --- /dev/null +++ b/data/scalecrafter/configs/sd2.1_2048x2048.yaml @@ -0,0 +1,13 @@ +ndcfg_tau: 37 +dilate_tau: 37 +inflate_tau: 0 +dilate_settings: ./assets/dilate_settings/sd2.1_2048x2048.txt +ndcfg_dilate_settings: ./assets/ndcfg_dilate_settings/sd2.1_2048x2048.txt +disperse_settings: ~ +disperse_transform: ~ +progressive: true +num_inference_steps: 50 +inference_batch_size: 1 +num_iters_per_prompt: 1 +latent_height: 256 +latent_width: 256 \ No newline at end of file diff --git a/data/scalecrafter/configs/sd2.1_2048x2048_disperse.yaml b/data/scalecrafter/configs/sd2.1_2048x2048_disperse.yaml new file mode 100644 index 000000000..7eb513176 --- /dev/null +++ b/data/scalecrafter/configs/sd2.1_2048x2048_disperse.yaml @@ -0,0 +1,13 @@ +ndcfg_tau: 37 +dilate_tau: 37 +inflate_tau: 0 +dilate_settings: ./assets/dilate_settings/sd2.1_2048x2048.txt +ndcfg_dilate_settings: ./assets/ndcfg_dilate_settings/sd2.1_2048x2048.txt +disperse_settings: ./assets/disperse_settings/sd2.1_2048x2048.txt +disperse_transform: ./transforms/R20to1_new.mat +progressive: true +num_inference_steps: 50 +inference_batch_size: 1 +num_iters_per_prompt: 1 +latent_height: 256 
+latent_width: 256 \ No newline at end of file diff --git a/data/scalecrafter/configs/sdxl_2048x2048.yaml b/data/scalecrafter/configs/sdxl_2048x2048.yaml new file mode 100644 index 000000000..2ffd386f8 --- /dev/null +++ b/data/scalecrafter/configs/sdxl_2048x2048.yaml @@ -0,0 +1,16 @@ +ndcfg_tau: 0 +dilate_tau: 30 +inflate_tau: 0 +sdedit_tau: 0 +dilate_settings: ./assets/dilate_settings/sdxl_2048x2048.txt +ndcfg_dilate_settings: ~ +disperse_settings: ~ +disperse_transform: ~ +progressive: false +num_inference_steps: 50 +inference_batch_size: 4 +num_iters_per_prompt: 1 +latent_height: 256 +latent_width: 256 +pixel_height: 2048 +pixel_width: 2048 diff --git a/data/scalecrafter/configs/sdxl_2560x2560.yaml b/data/scalecrafter/configs/sdxl_2560x2560.yaml new file mode 100644 index 000000000..e2a81b5d4 --- /dev/null +++ b/data/scalecrafter/configs/sdxl_2560x2560.yaml @@ -0,0 +1,16 @@ +ndcfg_tau: 0 +dilate_tau: 30 +inflate_tau: 0 +sdedit_tau: 0 +dilate_settings: ./assets/dilate_settings/sdxl_2560x2560.txt +ndcfg_dilate_settings: ~ +disperse_settings: ~ +disperse_transform: ~ +progressive: false +num_inference_steps: 50 +inference_batch_size: 4 +num_iters_per_prompt: 1 +latent_height: 320 +latent_width: 320 +pixel_height: 2560 +pixel_width: 2560 diff --git a/data/scalecrafter/configs/sdxl_4096x2048.yaml b/data/scalecrafter/configs/sdxl_4096x2048.yaml new file mode 100644 index 000000000..82cfee805 --- /dev/null +++ b/data/scalecrafter/configs/sdxl_4096x2048.yaml @@ -0,0 +1,16 @@ +ndcfg_tau: 35 +dilate_tau: 35 +inflate_tau: 0 +sdedit_tau: 0 +dilate_settings: ./assets/dilate_settings/sdxl_4096x2048.txt +ndcfg_dilate_settings: ./assets/ndcfg_dilate_settings/sdxl_4096x2048.txt +disperse_settings: ~ +disperse_transform: ~ +progressive: true +num_inference_steps: 50 +inference_batch_size: 1 +num_iters_per_prompt: 1 +latent_height: 256 +latent_width: 512 +pixel_height: 2048 +pixel_width: 4096 diff --git a/data/scalecrafter/configs/sdxl_4096x4096.yaml b/data/scalecrafter/configs/sdxl_4096x4096.yaml new file mode 100644 index 000000000..15240ebb2 --- /dev/null +++ b/data/scalecrafter/configs/sdxl_4096x4096.yaml @@ -0,0 +1,16 @@ +ndcfg_tau: 35 +dilate_tau: 35 +inflate_tau: 0 +sdedit_tau: 0 +dilate_settings: ./assets/dilate_settings/sdxl_4096x4096.txt +ndcfg_dilate_settings: ./assets/ndcfg_dilate_settings/sdxl_4096x4096.txt +disperse_settings: ~ +disperse_transform: ~ +progressive: true +num_inference_steps: 50 +inference_batch_size: 1 +num_iters_per_prompt: 1 +latent_height: 512 +latent_width: 512 +pixel_height: 4096 +pixel_width: 4096 diff --git a/data/scalecrafter/configs/sdxl_4096x4096_disperse.yaml b/data/scalecrafter/configs/sdxl_4096x4096_disperse.yaml new file mode 100644 index 000000000..b2764f36c --- /dev/null +++ b/data/scalecrafter/configs/sdxl_4096x4096_disperse.yaml @@ -0,0 +1,16 @@ +ndcfg_tau: 35 +dilate_tau: 35 +inflate_tau: 0 +sdedit_tau: 0 +dilate_settings: ./assets/dilate_settings/sdxl_4096x4096.txt +ndcfg_dilate_settings: ./assets/ndcfg_dilate_settings/sdxl_4096x4096.txt +disperse_settings: ./assets/disperse_settings/sdxl_4096x4096.txt +disperse_transform: ./transforms/R20to1_new.mat +progressive: true +num_inference_steps: 50 +inference_batch_size: 1 +num_iters_per_prompt: 1 +latent_height: 512 +latent_width: 512 +pixel_height: 4096 +pixel_width: 4096 diff --git a/data/scalecrafter/disperse/bilinear_upsample_symbolic.m b/data/scalecrafter/disperse/bilinear_upsample_symbolic.m new file mode 100644 index 000000000..636aedeb9 --- /dev/null +++ 
b/data/scalecrafter/disperse/bilinear_upsample_symbolic.m @@ -0,0 +1,43 @@ +function output = bilinear_upsample_symbolic(input, upsample_factor) + % Convert input to symbolic variables if they are not already + if ~isa(input, 'sym') + input = sym(input); + end + + % Get the dimensions of the input matrix + [input_rows, input_cols] = size(input); + + % Calculate the dimensions of the output matrix + output_rows = upsample_factor * (input_rows - 1) + 1; + output_cols = upsample_factor * (input_cols - 1) + 1; + + % Initialize the output matrix with zeros + output = sym(zeros(output_rows, output_cols)); + + % Perform the 2D bilinear upsampling + for i = 1:output_rows + for j = 1:output_cols + % Calculate the corresponding input coordinates (1-indexed) + input_row = (i - 1) / upsample_factor + 1; + input_col = (j - 1) / upsample_factor + 1; + + % Find the surrounding input pixel coordinates + row1 = floor(input_row); + row2 = ceil(input_row); + col1 = floor(input_col); + col2 = ceil(input_col); + + % Calculate the interpolation weights + alpha = input_row - row1; + beta = input_col - col1; + + % Perform bilinear interpolation + if row1 > 0 && row2 <= input_rows && col1 > 0 && col2 <= input_cols + output(i, j) = (1 - alpha) * (1 - beta) * input(row1, col1) + ... + (1 - alpha) * beta * input(row1, col2) + ... + alpha * (1 - beta) * input(row2, col1) + ... + alpha * beta * input(row2, col2); + end + end + end +end \ No newline at end of file diff --git a/data/scalecrafter/disperse/conv2d.m b/data/scalecrafter/disperse/conv2d.m new file mode 100644 index 000000000..70394aeda --- /dev/null +++ b/data/scalecrafter/disperse/conv2d.m @@ -0,0 +1,31 @@ +function out = conv2d(input, kernel, padding) + % Get the dimensions of the input and kernel + [input_rows, input_cols] = size(input); + [kernel_rows, kernel_cols] = size(kernel); + + % Calculate the output dimensions with padding + output_rows = input_rows + 2 * padding - kernel_rows + 1; + output_cols = input_cols + 2 * padding - kernel_cols + 1; + + % Initialize the padded input with zeros + padded_input = sym(zeros(input_rows + 2 * padding, input_cols + 2 * padding)); + + % Fill the padded input with the original input values + padded_input(padding + 1 : padding + input_rows, padding + 1 : padding + input_cols) = input; + + % Initialize the output matrix with zeros + out = sym(zeros(output_rows, output_cols)); + + % Perform the 2D convolution + for m = 1 : output_rows + for n = 1 : output_cols + temp_sum = 0; + for k = 1 : kernel_rows + for l = 1 : kernel_cols + temp_sum = temp_sum + kernel(k, l) * padded_input(m + k - 1, n + l - 1); + end + end + out(m, n) = temp_sum; + end + end +end \ No newline at end of file diff --git a/data/scalecrafter/disperse/kernel_disperse.m b/data/scalecrafter/disperse/kernel_disperse.m new file mode 100644 index 000000000..93ba87c70 --- /dev/null +++ b/data/scalecrafter/disperse/kernel_disperse.m @@ -0,0 +1,89 @@ +function R = kernel_disperse(smallSize, largeSize, inputSize, scale, eta, verbose) + % Solve the convolution dispersion transform + % Params: + % smallSize: size of the input kernel (i.e. 3) + % largeSize: size of the output kernel (i.e. 5) + % inputSize: size of the input feature (i.e. 7) + % scale: perception field enlarge scale (i.e. 2) + % eta: the weight combining structue-level and pixel-level + % calibration (i.e. 
0.05) + % verbose: whether to deliver a visualization + % Outputs: + % R: dispersion linear transform + if ~exist('verbose', 'var'), verbose = false; end + + % Initialize kernel and inputs + % R = sym_kernel('r', largeSize ^ 2 , smallSize ^ 2); + smallKernel = sym_kernel('a', smallSize, smallSize); + largeKernel = sym_kernel('b', largeSize, largeSize); + inputFeature = sym_kernel('x', inputSize, inputSize); + + % Compute structure-level calibration + interFeature = bilinear_upsample_symbolic(inputFeature, scale); + smallOutput = conv2d(inputFeature, smallKernel, (smallSize - 1) / 2); + largeOutput = conv2d(interFeature, largeKernel, (largeSize - 1) / 2); + smallOutput = bilinear_upsample_symbolic(smallOutput, scale); + + % Compute loss and get the equation set + structError = largeOutput - smallOutput; + structError = reshape(transpose(structError), [], 1); + % structError = structError(13); + equations = []; + for input = reshape(transpose(inputFeature), 1, []) + equations = [equations; diff(structError, input)]; + end + equations = equations(equations ~= 0); + + equationNum = size(equations, 1); + structLHSCoeff = sym(zeros([equationNum, largeSize ^ 2])); + loopIndex = 0; + for element = reshape(transpose(largeKernel), 1, []) + loopIndex = loopIndex + 1; + structLHSCoeff(1:end, loopIndex) = diff(equations, element); + end + termLHS = structLHSCoeff * reshape(transpose(largeKernel), [], 1); + structRHSCoeff = termLHS - equations; + + % Compute pixel-level calibration + smallOutput = conv2d(inputFeature, smallKernel, (smallSize - 1) / 2); + largeOutput = conv2d(inputFeature, largeKernel, (largeSize - 1) / 2); + + pixelError = largeOutput - smallOutput; + pixelError = reshape(transpose(pixelError), [], 1); + % pixelError = pixelError(5); + equations = []; + for input = reshape(transpose(inputFeature), 1, []) + equations = [equations; diff(pixelError, input)]; + end + equations = equations(equations ~= 0); + + equationNum = size(equations, 1); + pixelLHSCoeff = sym(zeros([equationNum, largeSize ^ 2])); + loopIndex = 0; + for element = reshape(transpose(largeKernel), 1, []) + loopIndex = loopIndex + 1; + pixelLHSCoeff(1:end, loopIndex) = diff(equations, element); + end + termLHS = pixelLHSCoeff * reshape(transpose(largeKernel), [], 1); + pixelRHSCoeff = termLHS - equations; + + % Solve the least square problem + A = [structLHSCoeff; eta * pixelLHSCoeff]; + b = [structRHSCoeff; eta * pixelRHSCoeff]; + x = (transpose(A) * A) \ (transpose(A) * b); + x = vpa(x); + R = zeros([largeSize ^ 2, smallSize ^ 2]); + loopIndex = 0; + for element = reshape(transpose(smallKernel), 1, []) + loopIndex = loopIndex + 1; + R(1:end, loopIndex) = diff(x, element); + end + + if verbose + largeKernel = R * ones([smallSize ^ 2, 1]); + largeKernel = transpose(reshape(largeKernel, largeSize, largeSize)); + heatmap(figure, largeKernel); + title("Dispersed conv. 
kernel provided a small kernel filled with one"); + end +end + diff --git a/data/scalecrafter/disperse/sym_kernel.m b/data/scalecrafter/disperse/sym_kernel.m new file mode 100644 index 000000000..2946816b3 --- /dev/null +++ b/data/scalecrafter/disperse/sym_kernel.m @@ -0,0 +1,10 @@ +function kernel = sym_kernel(symbol, height, width) + kernel = sym(zeros([height, width])); + for i = 1:height + for j = 1:width + index = (i - 1) * width + j; + kernel(i, j) = str2sym(sprintf("%s%d", symbol, index)); + end + end +end + diff --git a/data/scalecrafter/transforms/R20to1.mat b/data/scalecrafter/transforms/R20to1.mat new file mode 100644 index 000000000..136c7ad31 Binary files /dev/null and b/data/scalecrafter/transforms/R20to1.mat differ diff --git a/data/scalecrafter/transforms/R20to1_new.mat b/data/scalecrafter/transforms/R20to1_new.mat new file mode 100644 index 000000000..5ae953680 Binary files /dev/null and b/data/scalecrafter/transforms/R20to1_new.mat differ diff --git a/data/scalecrafter/transforms/R2to1.mat b/data/scalecrafter/transforms/R2to1.mat new file mode 100644 index 000000000..c612b1caf Binary files /dev/null and b/data/scalecrafter/transforms/R2to1.mat differ diff --git a/data/themes/dark.json b/data/themes/dark.json index 219e4c5e2..5cbc18c8e 100644 --- a/data/themes/dark.json +++ b/data/themes/dark.json @@ -18,5 +18,8 @@ }, "Tabs": { "colorSegment": "rgba(24, 24, 28, 0.6)" + }, + "Drawer": { + "color": "rgba(44, 44, 50, 0)" } } \ No newline at end of file diff --git a/docker/ait-no-mount.docker-compose.yml b/docker/ait-no-mount.docker-compose.yml new file mode 100644 index 000000000..022e51513 --- /dev/null +++ b/docker/ait-no-mount.docker-compose.yml @@ -0,0 +1,29 @@ +version: "3.7" + +services: + voltaml: + image: stax124/volta:experimental-ait + environment: + # General + - HUGGINGFACE_TOKEN=${HUGGINGFACE_TOKEN:-} + - LOG_LEVEL=${LOG_LEVEL:-INFO} + - EXTRA_ARGS=${EXTRA_ARGS:-} + + # Extra api keys + - FASTAPI_ANALYTICS_KEY=${FASTAPI_ANALYTICS_KEY:-} + - DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-} + + # R2 + - R2_ENDPOINT=${R2_ENDPOINT:-} + - R2_BUCKET_NAME=${R2_BUCKET_NAME:-} + - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-} + - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-} + - R2_DEV_ADDRESS=${R2_DEV_ADDRESS:-} + ports: + - "5003:5003" + deploy: + resources: + reservations: + devices: + - driver: nvidia + capabilities: ["gpu"] diff --git a/docker/ait.docker-compose.yml b/docker/ait.docker-compose.yml new file mode 100644 index 000000000..c8ccac7c0 --- /dev/null +++ b/docker/ait.docker-compose.yml @@ -0,0 +1,32 @@ +version: "3.7" + +services: + voltaml: + image: stax124/volta:experimental-ait + environment: + # General + - HUGGINGFACE_TOKEN=${HUGGINGFACE_TOKEN:-} + - LOG_LEVEL=${LOG_LEVEL:-INFO} + - EXTRA_ARGS=${EXTRA_ARGS:-} + + # Extra api keys + - FASTAPI_ANALYTICS_KEY=${FASTAPI_ANALYTICS_KEY:-} + - DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-} + + # R2 + - R2_ENDPOINT=${R2_ENDPOINT:-} + - R2_BUCKET_NAME=${R2_BUCKET_NAME:-} + - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-} + - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-} + - R2_DEV_ADDRESS=${R2_DEV_ADDRESS:-} + volumes: + - ${HOME}/voltaML/data:/app/data # XXX is the path to the folder where all the outputs will be saved + - ${HOME}/.cache/huggingface:/root/.cache/huggingface # YYY is path to your home folder (you may need to change the YYY/.cache/huggingface to YYY\.cache\huggingface on Windows) + ports: + - "5003:5003" + deploy: + resources: + reservations: + devices: + - driver: nvidia + capabilities: ["gpu"] diff --git 
a/dockerfile b/docker/ait/dockerfile similarity index 57% rename from dockerfile rename to docker/ait/dockerfile index bd42c1d43..09ff746ca 100644 --- a/dockerfile +++ b/docker/ait/dockerfile @@ -1,28 +1,39 @@ -FROM stax124/aitemplate:latest +FROM stax124/ait:torch2.1.1-cuda11.8-ubuntu22.04-devel ENV DEBIAN_FRONTEND=noninteractive +# Basic dependencies RUN apt update && apt install curl -y -RUN curl -sL https://deb.nodesource.com/setup_18.x | bash +RUN apt install time git -y +RUN apt install python3 python3-pip -y +RUN pip install --upgrade pip +RUN apt install -y ca-certificates curl gnupg +# Set up Node.js and Yarn +RUN mkdir -p /etc/apt/keyrings +RUN curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg +RUN echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" | tee /etc/apt/sources.list.d/nodesource.list +RUN apt update RUN apt install nodejs -y - RUN npm i -g yarn -RUN apt install time git -y -RUN pip install --upgrade pip +# Set up working directory and copy requirement definitions WORKDIR /app - COPY requirements /app/requirements +# PyTorch goes first to avoid redownloads +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip pip install torch torchvision torchaudio + +# Other Python dependencies +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip pip install python-dotenv requests RUN --mount=type=cache,mode=0755,target=/root/.cache/pip pip install -r requirements/api.txt RUN --mount=type=cache,mode=0755,target=/root/.cache/pip pip install -r requirements/bot.txt RUN --mount=type=cache,mode=0755,target=/root/.cache/pip pip install -r requirements/pytorch.txt RUN --mount=type=cache,mode=0755,target=/root/.cache/pip pip install -r requirements/interrogation.txt -RUN --mount=type=cache,mode=0755,target=/root/.cache/pip pip install python-dotenv COPY . 
/app +# Install frontend dependencies and build the frontend RUN --mount=type=cache,mode=0755,target=/app/frontend/node_modules cd frontend && yarn install && yarn build RUN rm -rf frontend/node_modules diff --git a/docker/cuda-no-mount.docker-compose.yml b/docker/cuda-no-mount.docker-compose.yml new file mode 100644 index 000000000..fe8aa9753 --- /dev/null +++ b/docker/cuda-no-mount.docker-compose.yml @@ -0,0 +1,29 @@ +version: "3.7" + +services: + voltaml: + image: stax124/volta:experimental-cuda + environment: + # General + - HUGGINGFACE_TOKEN=${HUGGINGFACE_TOKEN:-} + - LOG_LEVEL=${LOG_LEVEL:-INFO} + - EXTRA_ARGS=${EXTRA_ARGS:-} + + # Extra api keys + - FASTAPI_ANALYTICS_KEY=${FASTAPI_ANALYTICS_KEY:-} + - DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-} + + # R2 + - R2_ENDPOINT=${R2_ENDPOINT:-} + - R2_BUCKET_NAME=${R2_BUCKET_NAME:-} + - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-} + - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-} + - R2_DEV_ADDRESS=${R2_DEV_ADDRESS:-} + ports: + - "5003:5003" + deploy: + resources: + reservations: + devices: + - driver: nvidia + capabilities: ["gpu"] diff --git a/docker/cuda.docker-compose.yml b/docker/cuda.docker-compose.yml new file mode 100644 index 000000000..f71744385 --- /dev/null +++ b/docker/cuda.docker-compose.yml @@ -0,0 +1,32 @@ +version: "3.7" + +services: + voltaml: + image: stax124/volta:experimental-cuda + environment: + # General + - HUGGINGFACE_TOKEN=${HUGGINGFACE_TOKEN:-} + - LOG_LEVEL=${LOG_LEVEL:-INFO} + - EXTRA_ARGS=${EXTRA_ARGS:-} + + # Extra api keys + - FASTAPI_ANALYTICS_KEY=${FASTAPI_ANALYTICS_KEY:-} + - DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-} + + # R2 + - R2_ENDPOINT=${R2_ENDPOINT:-} + - R2_BUCKET_NAME=${R2_BUCKET_NAME:-} + - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-} + - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-} + - R2_DEV_ADDRESS=${R2_DEV_ADDRESS:-} + volumes: + - ${HOME}/voltaML/data:/app/data # XXX is the path to the folder where all the outputs will be saved + - ${HOME}/.cache/huggingface:/root/.cache/huggingface # YYY is path to your home folder (you may need to change the YYY/.cache/huggingface to YYY\.cache\huggingface on Windows) + ports: + - "5003:5003" + deploy: + resources: + reservations: + devices: + - driver: nvidia + capabilities: ["gpu"] diff --git a/docker/cuda/dockerfile b/docker/cuda/dockerfile new file mode 100644 index 000000000..ba9468359 --- /dev/null +++ b/docker/cuda/dockerfile @@ -0,0 +1,46 @@ +FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 + +ENV DEBIAN_FRONTEND=noninteractive + +# Basic dependencies +RUN apt update && apt install curl -y +RUN apt install time git -y +RUN apt install python3 python3-pip -y +RUN pip install --upgrade pip +RUN apt install -y ca-certificates curl gnupg + +# Set up Node.js and Yarn +RUN mkdir -p /etc/apt/keyrings +RUN curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg +RUN echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" | tee /etc/apt/sources.list.d/nodesource.list +RUN apt update +RUN apt install nodejs -y +RUN npm i -g yarn + +# Set up working directory and copy requirement definitions +WORKDIR /app +COPY requirements /app/requirements + +# PyTorch goes first to avoid redownloads +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip pip install torch torchvision torchaudio + +# Other Python dependencies +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip pip install python-dotenv requests +RUN 
--mount=type=cache,mode=0755,target=/root/.cache/pip pip install -r requirements/api.txt +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip pip install -r requirements/bot.txt +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip pip install -r requirements/pytorch.txt +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip pip install -r requirements/interrogation.txt + +COPY . /app + +# Install frontend dependencies and build the frontend +RUN --mount=type=cache,mode=0755,target=/app/frontend/node_modules cd frontend && yarn install && yarn build +RUN rm -rf frontend/node_modules + +# Remove caches +RUN rm -rf /root/.cache +RUN rm -rf /usr/local/share/.cache + +# Run the server +RUN chmod +x scripts/start.sh +ENTRYPOINT ["bash", "./scripts/start.sh"] diff --git a/docs/settings/reproducibility.md b/docs/settings/reproducibility.md new file mode 100644 index 000000000..1d41a7b83 --- /dev/null +++ b/docs/settings/reproducibility.md @@ -0,0 +1,67 @@ +# Reproducibility & Generation + +Reproducibility settings are settings that change the generation output. These changes range from small to large, with small meaning, for example, that a few lines look sharper. + +## Device + +Changing the device to the correct one -- that being, your fastest available GPU -- can not only improve performance, but also change how the images look. Something generated using DirectML on an AMD card won't EVER look the same as something generated with CUDA. + +## Data type + +Generally, changing the data type to a lower-precision (lower number) one will improve performance; however, when taken to extreme degrees (Volta doesn't have this implemented), image quality starts to suffer. `16-bit float` or `16-bit bfloat` is generally as low as anyone should need to go. + +## Deterministic generation + +PyTorch, and therefore Volta, is non-deterministic by design - that is, not 100% reproducible. This can raise a few issues: generations using the exact same parameters **MAY NOT** come out the same. Turning this setting on should fix these issues. + +## SGM Noise Multiplier + +SGM Noise multiplier changes how noise is calculated. This is only useful for reproducing previously created images. From a more technical standpoint: this changes noising to mimic SDXL's noise creation. **Only useful on `SD1.x`.** + +### On vs. off + + + + + + +## Quantization in KDiff samplers + +Quantization in K-samplers helps the samplers create sharper, more defined lines. This is another one of those _"small, but useful"_ changes. + +### On vs. 
off + + + + + + +## Generator + + diff --git a/docs/static/settings/reproducibility/quant_on.webp b/docs/static/settings/reproducibility/quant_on.webp new file mode 100644 index 000000000..bf5a1f3e5 Binary files /dev/null and b/docs/static/settings/reproducibility/quant_on.webp differ diff --git a/docs/static/settings/reproducibility/sgm_off.webp b/docs/static/settings/reproducibility/sgm_off.webp new file mode 100644 index 000000000..6621f57a8 Binary files /dev/null and b/docs/static/settings/reproducibility/sgm_off.webp differ diff --git a/docs/static/settings/reproducibility/sgm_on.webp b/docs/static/settings/reproducibility/sgm_on.webp new file mode 100644 index 000000000..a079f3dc6 Binary files /dev/null and b/docs/static/settings/reproducibility/sgm_on.webp differ diff --git a/frontend/dist/assets/404View.js b/frontend/dist/assets/404View.js index cb2d76290..7615194eb 100644 --- a/frontend/dist/assets/404View.js +++ b/frontend/dist/assets/404View.js @@ -1,4 +1,4 @@ -import { d as defineComponent, o as openBlock, j as createElementBlock, g as createVNode, w as withCtx, h as unref, c3 as NResult, n as NCard } from "./index.js"; +import { d as defineComponent, o as openBlock, a as createElementBlock, e as createVNode, w as withCtx, f as unref, c6 as NResult, m as NCard } from "./index.js"; const _hoisted_1 = { style: { "width": "100vw", "height": "100vh", "display": "flex", "align-items": "center", "justify-content": "center", "backdrop-filter": "blur(4px)" } }; const _sfc_main = /* @__PURE__ */ defineComponent({ __name: "404View", diff --git a/frontend/dist/assets/AboutView.js b/frontend/dist/assets/AboutView.js index fee3296e9..5ce1297f6 100644 --- a/frontend/dist/assets/AboutView.js +++ b/frontend/dist/assets/AboutView.js @@ -1,4 +1,4 @@ -import { _ as _export_sfc, j as createElementBlock, o as openBlock, f as createBaseVNode } from "./index.js"; +import { _ as _export_sfc, a as createElementBlock, o as openBlock, b as createBaseVNode } from "./index.js"; const _sfc_main = {}; const _hoisted_1 = { class: "about" }; const _hoisted_2 = /* @__PURE__ */ createBaseVNode("h1", null, "This is an about page", -1); diff --git a/frontend/dist/assets/AccelerateView.js b/frontend/dist/assets/AccelerateView.js index c10dc0930..81d5ff96c 100644 --- a/frontend/dist/assets/AccelerateView.js +++ b/frontend/dist/assets/AccelerateView.js @@ -1,6 +1,7 @@ -import { Q as cB, ab as cM, aa as c, at as cE, aT as iconSwitchTransition, ac as cNotM, d as defineComponent, S as useConfig, ag as useRtl, T as useTheme, a3 as provide, y as h, aw as flatten, ax as getSlot, P as createInjectionKey, bg as stepsLight, R as inject, a_ as throwError, c as computed, ah as createKey, Y as useThemeClass, a1 as call, av as resolveWrappedSlot, ai as resolveSlot, aI as NIconSwitchTransition, aj as NBaseIcon, bh as FinishedIcon, bi as ErrorIcon, p as useMessage, a as useState, z as ref, o as openBlock, j as createElementBlock, g as createVNode, w as withCtx, h as unref, N as NSpace, n as NCard, f as createBaseVNode, i as NSelect, A as NButton, k as createTextVNode, bd as NModal, t as serverUrl, u as useSettings, e as createBlock, D as NTabPane, E as NTabs } from "./index.js"; -import { a as NSlider, N as NSwitch } from "./Switch.js"; +import { R as cB, ac as cM, ab as c, au as cE, aU as iconSwitchTransition, ad as cNotM, d as defineComponent, T as useConfig, ah as useRtl, U as useTheme, a4 as provide, A as h, ax as flatten, ay as getSlot, Q as createInjectionKey, bk as stepsLight, S as inject, a$ as throwError, c as computed, ai 
as createKey, Z as useThemeClass, a2 as call, aw as resolveWrappedSlot, aj as resolveSlot, aJ as NIconSwitchTransition, ak as NBaseIcon, bl as FinishedIcon, bm as ErrorIcon, r as useMessage, l as useState, B as ref, o as openBlock, a as createElementBlock, e as createVNode, w as withCtx, f as unref, j as NSpace, m as NCard, b as createBaseVNode, q as NSelect, C as NButton, h as createTextVNode, bh as NModal, x as serverUrl, u as useSettings, g as createBlock, n as NTabPane, p as NTabs } from "./index.js"; +import { N as NSlider } from "./Slider.js"; import { N as NInputNumber } from "./InputNumber.js"; +import { N as NSwitch } from "./Switch.js"; const style = cB("steps", ` width: 100%; display: flex; diff --git a/frontend/dist/assets/CloudUpload.js b/frontend/dist/assets/CloudUpload.js index 78a6dea82..68c86e860 100644 --- a/frontend/dist/assets/CloudUpload.js +++ b/frontend/dist/assets/CloudUpload.js @@ -1,4 +1,4 @@ -import { d as defineComponent, o as openBlock, j as createElementBlock, f as createBaseVNode } from "./index.js"; +import { d as defineComponent, o as openBlock, a as createElementBlock, b as createBaseVNode } from "./index.js"; const _hoisted_1 = { xmlns: "http://www.w3.org/2000/svg", "xmlns:xlink": "http://www.w3.org/1999/xlink", diff --git a/frontend/dist/assets/DescriptionsItem.js b/frontend/dist/assets/DescriptionsItem.js index 27c00366e..1c31a5408 100644 --- a/frontend/dist/assets/DescriptionsItem.js +++ b/frontend/dist/assets/DescriptionsItem.js @@ -1,4 +1,4 @@ -import { aa as c, Q as cB, ac as cNotM, ab as cM, at as cE, aU as insideModal, aV as insidePopover, d as defineComponent, S as useConfig, T as useTheme, c as computed, ah as createKey, Y as useThemeClass, bS as useCompitable, aw as flatten, y as h, aQ as repeat, ax as getSlot, bT as descriptionsLight } from "./index.js"; +import { ab as c, R as cB, ad as cNotM, ac as cM, au as cE, aV as insideModal, aW as insidePopover, d as defineComponent, T as useConfig, U as useTheme, c as computed, ai as createKey, Z as useThemeClass, bX as useCompitable, ax as flatten, A as h, aR as repeat, ay as getSlot, bY as descriptionsLight } from "./index.js"; function getVNodeChildren(vNode, slotName = "default", fallback = []) { const { children } = vNode; if (children !== null && typeof children === "object" && !Array.isArray(children)) { diff --git a/frontend/dist/assets/ExtraView.js b/frontend/dist/assets/ExtraView.js index 51509edbf..55f5c75e5 100644 --- a/frontend/dist/assets/ExtraView.js +++ b/frontend/dist/assets/ExtraView.js @@ -1,4 +1,4 @@ -import { _ as _export_sfc, d as defineComponent, a as useState, o as openBlock, e as createBlock, w as withCtx, h as unref, g as createVNode, D as NTabPane, E as NTabs } from "./index.js"; +import { _ as _export_sfc, d as defineComponent, l as useState, o as openBlock, g as createBlock, w as withCtx, f as unref, e as createVNode, n as NTabPane, p as NTabs } from "./index.js"; const _sfc_main$2 = {}; function _sfc_render$1(_ctx, _cache) { return "Autofill manager"; diff --git a/frontend/dist/assets/GenerateSection.vue_vue_type_script_setup_true_lang.js b/frontend/dist/assets/GenerateSection.vue_vue_type_script_setup_true_lang.js index a3d8d99d5..46752b12a 100644 --- a/frontend/dist/assets/GenerateSection.vue_vue_type_script_setup_true_lang.js +++ b/frontend/dist/assets/GenerateSection.vue_vue_type_script_setup_true_lang.js @@ -1,4 +1,4 @@ -import { d as defineComponent, o as openBlock, j as createElementBlock, f as createBaseVNode, a as useState, u as useSettings, z as ref, b9 as 
onMounted, q as onUnmounted, t as serverUrl, e as createBlock, w as withCtx, g as createVNode, h as unref, r as NGi, A as NButton, B as NIcon, k as createTextVNode, s as NGrid, bV as NAlert, m as createCommentVNode, n as NCard } from "./index.js"; +import { d as defineComponent, o as openBlock, a as createElementBlock, b as createBaseVNode, l as useState, u as useSettings, B as ref, bb as onMounted, s as onUnmounted, x as serverUrl, g as createBlock, w as withCtx, e as createVNode, f as unref, t as NGi, C as NButton, D as NIcon, h as createTextVNode, v as NGrid, bK as NAlert, k as createCommentVNode, m as NCard } from "./index.js"; const _hoisted_1$1 = { xmlns: "http://www.w3.org/2000/svg", "xmlns:xlink": "http://www.w3.org/1999/xlink", diff --git a/frontend/dist/assets/GridOutline.js b/frontend/dist/assets/GridOutline.js deleted file mode 100644 index aaa851714..000000000 --- a/frontend/dist/assets/GridOutline.js +++ /dev/null @@ -1,92 +0,0 @@ -import { d as defineComponent, o as openBlock, j as createElementBlock, f as createBaseVNode } from "./index.js"; -const _hoisted_1 = { - xmlns: "http://www.w3.org/2000/svg", - "xmlns:xlink": "http://www.w3.org/1999/xlink", - viewBox: "0 0 512 512" -}; -const _hoisted_2 = /* @__PURE__ */ createBaseVNode( - "rect", - { - x: "48", - y: "48", - width: "176", - height: "176", - rx: "20", - ry: "20", - fill: "none", - stroke: "currentColor", - "stroke-linecap": "round", - "stroke-linejoin": "round", - "stroke-width": "32" - }, - null, - -1 - /* HOISTED */ -); -const _hoisted_3 = /* @__PURE__ */ createBaseVNode( - "rect", - { - x: "288", - y: "48", - width: "176", - height: "176", - rx: "20", - ry: "20", - fill: "none", - stroke: "currentColor", - "stroke-linecap": "round", - "stroke-linejoin": "round", - "stroke-width": "32" - }, - null, - -1 - /* HOISTED */ -); -const _hoisted_4 = /* @__PURE__ */ createBaseVNode( - "rect", - { - x: "48", - y: "288", - width: "176", - height: "176", - rx: "20", - ry: "20", - fill: "none", - stroke: "currentColor", - "stroke-linecap": "round", - "stroke-linejoin": "round", - "stroke-width": "32" - }, - null, - -1 - /* HOISTED */ -); -const _hoisted_5 = /* @__PURE__ */ createBaseVNode( - "rect", - { - x: "288", - y: "288", - width: "176", - height: "176", - rx: "20", - ry: "20", - fill: "none", - stroke: "currentColor", - "stroke-linecap": "round", - "stroke-linejoin": "round", - "stroke-width": "32" - }, - null, - -1 - /* HOISTED */ -); -const _hoisted_6 = [_hoisted_2, _hoisted_3, _hoisted_4, _hoisted_5]; -const GridOutline = defineComponent({ - name: "GridOutline", - render: function render(_ctx, _cache) { - return openBlock(), createElementBlock("svg", _hoisted_1, _hoisted_6); - } -}); -export { - GridOutline as G -}; diff --git a/frontend/dist/assets/Image2ImageView.js b/frontend/dist/assets/Image2ImageView.js index 7233bee1c..176efb54b 100644 --- a/frontend/dist/assets/Image2ImageView.js +++ b/frontend/dist/assets/Image2ImageView.js @@ -1,12 +1,13 @@ -import { d as defineComponent, o as openBlock, j as createElementBlock, f as createBaseVNode, a as useState, u as useSettings, p as useMessage, q as onUnmounted, g as createVNode, w as withCtx, h as unref, r as NGi, n as NCard, N as NSpace, l as NTooltip, k as createTextVNode, i as NSelect, s as NGrid, t as serverUrl, v as pushScopeId, x as popScopeId, _ as _export_sfc, m as createCommentVNode, y as h, z as ref, A as NButton, B as NIcon, e as createBlock, C as toDisplayString, D as NTabPane, E as NTabs } from "./index.js"; -import { B as BurnerClock, P as Prompt, _ as 
_sfc_main$5, a as _sfc_main$6, b as _sfc_main$9 } from "./clock.js"; -import { _ as _sfc_main$7 } from "./GenerateSection.vue_vue_type_script_setup_true_lang.js"; -import { _ as _sfc_main$8 } from "./ImageOutput.vue_vue_type_script_setup_true_lang.js"; +import { d as defineComponent, o as openBlock, a as createElementBlock, b as createBaseVNode, l as useState, u as useSettings, r as useMessage, s as onUnmounted, e as createVNode, w as withCtx, f as unref, t as NGi, m as NCard, j as NSpace, N as NTooltip, h as createTextVNode, q as NSelect, v as NGrid, x as serverUrl, y as pushScopeId, z as popScopeId, _ as _export_sfc, A as h, B as ref, C as NButton, D as NIcon, g as createBlock, E as toDisplayString, n as NTabPane, p as NTabs } from "./index.js"; +import { B as BurnerClock, P as Prompt, b as _sfc_main$5, _ as _sfc_main$6, a as _sfc_main$7, c as _sfc_main$8, d as _sfc_main$d } from "./clock.js"; +import { _ as _sfc_main$b } from "./GenerateSection.vue_vue_type_script_setup_true_lang.js"; +import { _ as _sfc_main$c } from "./ImageOutput.vue_vue_type_script_setup_true_lang.js"; import { I as ImageUpload } from "./ImageUpload.js"; -import { _ as _sfc_main$4 } from "./SamplerPicker.vue_vue_type_script_setup_true_lang.js"; +import { _ as _sfc_main$4, a as _sfc_main$9, b as _sfc_main$a } from "./Upscale.vue_vue_type_script_setup_true_lang.js"; import { v as v4 } from "./v4.js"; -import { a as NSlider, N as NSwitch } from "./Switch.js"; +import { N as NSlider } from "./Slider.js"; import { N as NInputNumber } from "./InputNumber.js"; +import { N as NSwitch } from "./Switch.js"; import "./DescriptionsItem.js"; import "./SendOutputTo.vue_vue_type_script_setup_true_lang.js"; import "./TrashBin.js"; @@ -145,7 +146,7 @@ const TrashBinSharp = defineComponent({ return openBlock(), createElementBlock("svg", _hoisted_1$3, _hoisted_6$3); } }); -const _withScopeId$2 = (n) => (pushScopeId("data-v-efacc8fd"), n = n(), popScopeId(), n); +const _withScopeId$2 = (n) => (pushScopeId("data-v-d4ff54ab"), n = n(), popScopeId(), n); const _hoisted_1$2 = { style: { "margin": "0 12px" } }; const _hoisted_2$2 = { class: "flex-container" }; const _hoisted_3$2 = /* @__PURE__ */ _withScopeId$2(() => /* @__PURE__ */ createBaseVNode("p", { style: { "margin-right": "12px", "width": "150px" } }, "ControlNet", -1)); @@ -154,23 +155,20 @@ const _hoisted_5$2 = { class: "flex-container" }; const _hoisted_6$2 = /* @__PURE__ */ _withScopeId$2(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Steps", -1)); const _hoisted_7$2 = /* @__PURE__ */ _withScopeId$2(() => /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "We recommend using 20-50 steps for most images.", -1)); const _hoisted_8$2 = { class: "flex-container" }; -const _hoisted_9$2 = /* @__PURE__ */ _withScopeId$2(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "CFG Scale", -1)); -const _hoisted_10$2 = /* @__PURE__ */ _withScopeId$2(() => /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "We recommend using 3-15 for most images.", -1)); -const _hoisted_11$2 = { class: "flex-container" }; -const _hoisted_12$2 = /* @__PURE__ */ _withScopeId$2(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Batch Count", -1)); -const _hoisted_13$2 = { class: "flex-container" }; -const _hoisted_14$2 = /* @__PURE__ */ _withScopeId$2(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "ControlNet Conditioning Scale", -1)); -const _hoisted_15$2 = { class: "flex-container" }; -const _hoisted_16$2 = 
/* @__PURE__ */ _withScopeId$2(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Detection resolution", -1)); -const _hoisted_17$2 = { class: "flex-container" }; -const _hoisted_18$1 = /* @__PURE__ */ _withScopeId$2(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Seed", -1)); -const _hoisted_19$1 = /* @__PURE__ */ _withScopeId$2(() => /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "For random seed use -1.", -1)); -const _hoisted_20$1 = { class: "flex-container" }; -const _hoisted_21$1 = /* @__PURE__ */ _withScopeId$2(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Is Preprocessed", -1)); -const _hoisted_22$1 = { class: "flex-container" }; -const _hoisted_23$1 = /* @__PURE__ */ _withScopeId$2(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Save Preprocessed", -1)); -const _hoisted_24$1 = { class: "flex-container" }; -const _hoisted_25 = /* @__PURE__ */ _withScopeId$2(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Return Preprocessed", -1)); +const _hoisted_9$2 = /* @__PURE__ */ _withScopeId$2(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Batch Count", -1)); +const _hoisted_10$2 = { class: "flex-container" }; +const _hoisted_11$2 = /* @__PURE__ */ _withScopeId$2(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "ControlNet Conditioning Scale", -1)); +const _hoisted_12$1 = { class: "flex-container" }; +const _hoisted_13$1 = /* @__PURE__ */ _withScopeId$2(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Detection resolution", -1)); +const _hoisted_14$1 = { class: "flex-container" }; +const _hoisted_15$1 = /* @__PURE__ */ _withScopeId$2(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Seed", -1)); +const _hoisted_16$1 = /* @__PURE__ */ _withScopeId$2(() => /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "For random seed use -1.", -1)); +const _hoisted_17$1 = { class: "flex-container" }; +const _hoisted_18$1 = /* @__PURE__ */ _withScopeId$2(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Is Preprocessed", -1)); +const _hoisted_19$1 = { class: "flex-container" }; +const _hoisted_20$1 = /* @__PURE__ */ _withScopeId$2(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Save Preprocessed", -1)); +const _hoisted_21 = { class: "flex-container" }; +const _hoisted_22 = /* @__PURE__ */ _withScopeId$2(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Return Preprocessed", -1)); const _sfc_main$3 = /* @__PURE__ */ defineComponent({ __name: "ControlNet", setup(__props) { @@ -231,7 +229,28 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ save_preprocessed: settings.data.settings.controlnet.save_preprocessed, return_preprocessed: settings.data.settings.controlnet.return_preprocessed }, - model: (_a = settings.data.settings.model) == null ? void 0 : _a.path + model: (_a = settings.data.settings.model) == null ? void 0 : _a.path, + flags: { + ...settings.data.settings.controlnet.highres.enabled ? 
{ + highres_fix: { + mode: settings.data.settings.controlnet.highres.mode, + image_upscaler: settings.data.settings.controlnet.highres.image_upscaler, + scale: settings.data.settings.controlnet.highres.scale, + latent_scale_mode: settings.data.settings.controlnet.highres.latent_scale_mode, + strength: settings.data.settings.controlnet.highres.strength, + steps: settings.data.settings.controlnet.highres.steps, + antialiased: settings.data.settings.controlnet.highres.antialiased + } + } : {}, + ...settings.data.settings.controlnet.upscale.enabled ? { + upscale: { + upscale_factor: settings.data.settings.controlnet.upscale.upscale_factor, + tile_size: settings.data.settings.controlnet.upscale.tile_size, + tile_padding: settings.data.settings.controlnet.upscale.tile_padding, + model: settings.data.settings.controlnet.upscale.model + } + } : {} + } }) }).then((res) => { if (!res.ok) { @@ -340,40 +359,13 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ max: 300 }, null, 8, ["value"]) ]), + createVNode(unref(_sfc_main$6), { tab: "controlnet" }), + createVNode(unref(_sfc_main$7), { tab: "controlnet" }), createBaseVNode("div", _hoisted_8$2, [ createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { trigger: withCtx(() => [ _hoisted_9$2 ]), - default: withCtx(() => [ - createTextVNode(' Guidance scale indicates how much should model stay close to the prompt. Higher values might be exactly what you want, but generated images might have some artefacts. Lower values indicates that model can "dream" about this prompt more. '), - _hoisted_10$2 - ]), - _: 1 - }), - createVNode(unref(NSlider), { - value: unref(settings).data.settings.controlnet.cfg_scale, - "onUpdate:value": _cache[4] || (_cache[4] = ($event) => unref(settings).data.settings.controlnet.cfg_scale = $event), - min: 1, - max: 30, - step: 0.5, - style: { "margin-right": "12px" } - }, null, 8, ["value"]), - createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.controlnet.cfg_scale, - "onUpdate:value": _cache[5] || (_cache[5] = ($event) => unref(settings).data.settings.controlnet.cfg_scale = $event), - size: "small", - style: { "min-width": "96px", "width": "96px" }, - min: 1, - max: 30, - step: 0.5 - }, null, 8, ["value"]) - ]), - createBaseVNode("div", _hoisted_11$2, [ - createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { - trigger: withCtx(() => [ - _hoisted_12$2 - ]), default: withCtx(() => [ createTextVNode(" Number of images to generate after each other. 
") ]), @@ -381,27 +373,27 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ }), createVNode(unref(NSlider), { value: unref(settings).data.settings.controlnet.batch_count, - "onUpdate:value": _cache[6] || (_cache[6] = ($event) => unref(settings).data.settings.controlnet.batch_count = $event), + "onUpdate:value": _cache[4] || (_cache[4] = ($event) => unref(settings).data.settings.controlnet.batch_count = $event), min: 1, max: 9, style: { "margin-right": "12px" } }, null, 8, ["value"]), createVNode(unref(NInputNumber), { value: unref(settings).data.settings.controlnet.batch_count, - "onUpdate:value": _cache[7] || (_cache[7] = ($event) => unref(settings).data.settings.controlnet.batch_count = $event), + "onUpdate:value": _cache[5] || (_cache[5] = ($event) => unref(settings).data.settings.controlnet.batch_count = $event), size: "small", style: { "min-width": "96px", "width": "96px" }, min: 1, max: 9 }, null, 8, ["value"]) ]), - createVNode(unref(_sfc_main$6), { + createVNode(unref(_sfc_main$8), { "batch-size-object": unref(settings).data.settings.controlnet }, null, 8, ["batch-size-object"]), - createBaseVNode("div", _hoisted_13$2, [ + createBaseVNode("div", _hoisted_10$2, [ createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { trigger: withCtx(() => [ - _hoisted_14$2 + _hoisted_11$2 ]), default: withCtx(() => [ createTextVNode(" How much should the ControlNet affect the image. ") @@ -410,7 +402,7 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ }), createVNode(unref(NSlider), { value: unref(settings).data.settings.controlnet.controlnet_conditioning_scale, - "onUpdate:value": _cache[8] || (_cache[8] = ($event) => unref(settings).data.settings.controlnet.controlnet_conditioning_scale = $event), + "onUpdate:value": _cache[6] || (_cache[6] = ($event) => unref(settings).data.settings.controlnet.controlnet_conditioning_scale = $event), min: 0.1, max: 2, style: { "margin-right": "12px" }, @@ -418,7 +410,7 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ }, null, 8, ["value"]), createVNode(unref(NInputNumber), { value: unref(settings).data.settings.controlnet.controlnet_conditioning_scale, - "onUpdate:value": _cache[9] || (_cache[9] = ($event) => unref(settings).data.settings.controlnet.controlnet_conditioning_scale = $event), + "onUpdate:value": _cache[7] || (_cache[7] = ($event) => unref(settings).data.settings.controlnet.controlnet_conditioning_scale = $event), size: "small", style: { "min-width": "96px", "width": "96px" }, min: 0.1, @@ -426,10 +418,10 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ step: 0.025 }, null, 8, ["value"]) ]), - createBaseVNode("div", _hoisted_15$2, [ + createBaseVNode("div", _hoisted_12$1, [ createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { trigger: withCtx(() => [ - _hoisted_16$2 + _hoisted_13$1 ]), default: withCtx(() => [ createTextVNode(" What resolution to use for the image processing. This process does not affect the final result but can affect the quality of the ControlNet processing. 
") @@ -438,7 +430,7 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ }), createVNode(unref(NSlider), { value: unref(settings).data.settings.controlnet.detection_resolution, - "onUpdate:value": _cache[10] || (_cache[10] = ($event) => unref(settings).data.settings.controlnet.detection_resolution = $event), + "onUpdate:value": _cache[8] || (_cache[8] = ($event) => unref(settings).data.settings.controlnet.detection_resolution = $event), min: 128, max: 2048, style: { "margin-right": "12px" }, @@ -446,7 +438,7 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ }, null, 8, ["value"]), createVNode(unref(NInputNumber), { value: unref(settings).data.settings.controlnet.detection_resolution, - "onUpdate:value": _cache[11] || (_cache[11] = ($event) => unref(settings).data.settings.controlnet.detection_resolution = $event), + "onUpdate:value": _cache[9] || (_cache[9] = ($event) => unref(settings).data.settings.controlnet.detection_resolution = $event), size: "small", style: { "min-width": "96px", "width": "96px" }, min: 128, @@ -454,45 +446,45 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ step: 8 }, null, 8, ["value"]) ]), - createBaseVNode("div", _hoisted_17$2, [ + createBaseVNode("div", _hoisted_14$1, [ createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { trigger: withCtx(() => [ - _hoisted_18$1 + _hoisted_15$1 ]), default: withCtx(() => [ createTextVNode(" Seed is a number that represents the starting canvas of your image. If you want to create the same image as your friend, you can use the same settings and seed to do so. "), - _hoisted_19$1 + _hoisted_16$1 ]), _: 1 }), createVNode(unref(NInputNumber), { value: unref(settings).data.settings.controlnet.seed, - "onUpdate:value": _cache[12] || (_cache[12] = ($event) => unref(settings).data.settings.controlnet.seed = $event), + "onUpdate:value": _cache[10] || (_cache[10] = ($event) => unref(settings).data.settings.controlnet.seed = $event), size: "small", min: -1, max: 999999999999, style: { "flex-grow": "1" } }, null, 8, ["value"]) ]), - createBaseVNode("div", _hoisted_20$1, [ - _hoisted_21$1, + createBaseVNode("div", _hoisted_17$1, [ + _hoisted_18$1, createVNode(unref(NSwitch), { value: unref(settings).data.settings.controlnet.is_preprocessed, - "onUpdate:value": _cache[13] || (_cache[13] = ($event) => unref(settings).data.settings.controlnet.is_preprocessed = $event) + "onUpdate:value": _cache[11] || (_cache[11] = ($event) => unref(settings).data.settings.controlnet.is_preprocessed = $event) }, null, 8, ["value"]) ]), - createBaseVNode("div", _hoisted_22$1, [ - _hoisted_23$1, + createBaseVNode("div", _hoisted_19$1, [ + _hoisted_20$1, createVNode(unref(NSwitch), { value: unref(settings).data.settings.controlnet.save_preprocessed, - "onUpdate:value": _cache[14] || (_cache[14] = ($event) => unref(settings).data.settings.controlnet.save_preprocessed = $event) + "onUpdate:value": _cache[12] || (_cache[12] = ($event) => unref(settings).data.settings.controlnet.save_preprocessed = $event) }, null, 8, ["value"]) ]), - createBaseVNode("div", _hoisted_24$1, [ - _hoisted_25, + createBaseVNode("div", _hoisted_21, [ + _hoisted_22, createVNode(unref(NSwitch), { value: unref(settings).data.settings.controlnet.return_preprocessed, - "onUpdate:value": _cache[15] || (_cache[15] = ($event) => unref(settings).data.settings.controlnet.return_preprocessed = $event) + "onUpdate:value": _cache[13] || (_cache[13] = ($event) => unref(settings).data.settings.controlnet.return_preprocessed = $event) }, null, 8, ["value"]) ]) ]), @@ 
-500,20 +492,22 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ }) ]), _: 1 - }) + }), + createVNode(unref(_sfc_main$9), { tab: "controlnet" }), + createVNode(unref(_sfc_main$a), { tab: "controlnet" }) ]), _: 1 }), createVNode(unref(NGi), null, { default: withCtx(() => [ - createVNode(unref(_sfc_main$7), { generate }), - createVNode(unref(_sfc_main$8), { + createVNode(unref(_sfc_main$b), { generate }), + createVNode(unref(_sfc_main$c), { "current-image": unref(global).state.controlnet.currentImage, images: unref(global).state.controlnet.images, data: unref(settings).data.settings.controlnet, - onImageClicked: _cache[16] || (_cache[16] = ($event) => unref(global).state.controlnet.currentImage = $event) + onImageClicked: _cache[14] || (_cache[14] = ($event) => unref(global).state.controlnet.currentImage = $event) }, null, 8, ["current-image", "images", "data"]), - createVNode(unref(_sfc_main$9), { + createVNode(unref(_sfc_main$d), { style: { "margin-top": "12px" }, "gen-data": unref(global).state.controlnet.genData }, null, 8, ["gen-data"]) @@ -527,28 +521,19 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ }; } }); -const ControlNet = /* @__PURE__ */ _export_sfc(_sfc_main$3, [["__scopeId", "data-v-efacc8fd"]]); -const _withScopeId$1 = (n) => (pushScopeId("data-v-9c556ef8"), n = n(), popScopeId(), n); +const ControlNet = /* @__PURE__ */ _export_sfc(_sfc_main$3, [["__scopeId", "data-v-d4ff54ab"]]); +const _withScopeId$1 = (n) => (pushScopeId("data-v-a4145f6c"), n = n(), popScopeId(), n); const _hoisted_1$1 = { style: { "margin": "0 12px" } }; const _hoisted_2$1 = { class: "flex-container" }; const _hoisted_3$1 = /* @__PURE__ */ _withScopeId$1(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Steps", -1)); const _hoisted_4$1 = /* @__PURE__ */ _withScopeId$1(() => /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "We recommend using 20-50 steps for most images.", -1)); const _hoisted_5$1 = { class: "flex-container" }; -const _hoisted_6$1 = /* @__PURE__ */ _withScopeId$1(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "CFG Scale", -1)); -const _hoisted_7$1 = /* @__PURE__ */ _withScopeId$1(() => /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "We recommend using 3-15 for most images.", -1)); -const _hoisted_8$1 = { - key: 0, - class: "flex-container" -}; -const _hoisted_9$1 = /* @__PURE__ */ _withScopeId$1(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Self Attention Scale", -1)); -const _hoisted_10$1 = /* @__PURE__ */ _withScopeId$1(() => /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "PyTorch ONLY.", -1)); -const _hoisted_11$1 = { class: "flex-container" }; -const _hoisted_12$1 = /* @__PURE__ */ _withScopeId$1(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Batch Count", -1)); -const _hoisted_13$1 = { class: "flex-container" }; -const _hoisted_14$1 = /* @__PURE__ */ _withScopeId$1(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Denoising Strength", -1)); -const _hoisted_15$1 = { class: "flex-container" }; -const _hoisted_16$1 = /* @__PURE__ */ _withScopeId$1(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Seed", -1)); -const _hoisted_17$1 = /* @__PURE__ */ _withScopeId$1(() => /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "For random seed use -1.", -1)); +const _hoisted_6$1 = /* @__PURE__ */ _withScopeId$1(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Batch Count", 
-1)); +const _hoisted_7$1 = { class: "flex-container" }; +const _hoisted_8$1 = /* @__PURE__ */ _withScopeId$1(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Denoising Strength", -1)); +const _hoisted_9$1 = { class: "flex-container" }; +const _hoisted_10$1 = /* @__PURE__ */ _withScopeId$1(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Seed", -1)); +const _hoisted_11$1 = /* @__PURE__ */ _withScopeId$1(() => /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "For random seed use -1.", -1)); const _sfc_main$2 = /* @__PURE__ */ defineComponent({ __name: "Img2Img", setup(__props) { @@ -601,7 +586,41 @@ const _sfc_main$2 = /* @__PURE__ */ defineComponent({ prompt_to_prompt: settings.data.settings.api.prompt_to_prompt } }, - model: (_a = settings.data.settings.model) == null ? void 0 : _a.path + ...settings.data.settings.img2img.deepshrink.enabled ? { + flags: { + deepshrink: { + early_out: settings.data.settings.img2img.deepshrink.early_out, + depth_1: settings.data.settings.img2img.deepshrink.depth_1, + stop_at_1: settings.data.settings.img2img.deepshrink.stop_at_1, + depth_2: settings.data.settings.img2img.deepshrink.depth_2, + stop_at_2: settings.data.settings.img2img.deepshrink.stop_at_2, + scaler: settings.data.settings.img2img.deepshrink.scaler, + base_scale: settings.data.settings.img2img.deepshrink.base_scale + } + } + } : {}, + model: (_a = settings.data.settings.model) == null ? void 0 : _a.path, + flags: { + ...settings.data.settings.img2img.highres.enabled ? { + highres_fix: { + mode: settings.data.settings.img2img.highres.mode, + image_upscaler: settings.data.settings.img2img.highres.image_upscaler, + scale: settings.data.settings.img2img.highres.scale, + latent_scale_mode: settings.data.settings.img2img.highres.latent_scale_mode, + strength: settings.data.settings.img2img.highres.strength, + steps: settings.data.settings.img2img.highres.steps, + antialiased: settings.data.settings.img2img.highres.antialiased + } + } : {}, + ...settings.data.settings.img2img.upscale.enabled ? { + upscale: { + upscale_factor: settings.data.settings.img2img.upscale.upscale_factor, + tile_size: settings.data.settings.img2img.upscale.tile_size, + tile_padding: settings.data.settings.img2img.upscale.tile_padding, + model: settings.data.settings.img2img.upscale.model + } + } : {} + } }) }).then((res) => { if (!res.ok) { @@ -657,196 +676,139 @@ const _sfc_main$2 = /* @__PURE__ */ defineComponent({ vertical: "", class: "left-container" }, { - default: withCtx(() => { - var _a; - return [ - createVNode(unref(Prompt), { tab: "img2img" }), - createVNode(unref(_sfc_main$4), { type: "img2img" }), - createVNode(unref(_sfc_main$5), { - "dimensions-object": unref(settings).data.settings.img2img - }, null, 8, ["dimensions-object"]), - createBaseVNode("div", _hoisted_2$1, [ - createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { - trigger: withCtx(() => [ - _hoisted_3$1 - ]), - default: withCtx(() => [ - createTextVNode(" Number of steps to take in the diffusion process. Higher values will result in more detailed images but will take longer to generate. There is also a point of diminishing returns around 100 steps. 
"), - _hoisted_4$1 - ]), - _: 1 - }), - createVNode(unref(NSlider), { - value: unref(settings).data.settings.img2img.steps, - "onUpdate:value": _cache[1] || (_cache[1] = ($event) => unref(settings).data.settings.img2img.steps = $event), - min: 5, - max: 300, - style: { "margin-right": "12px" } - }, null, 8, ["value"]), - createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.img2img.steps, - "onUpdate:value": _cache[2] || (_cache[2] = ($event) => unref(settings).data.settings.img2img.steps = $event), - size: "small", - style: { "min-width": "96px", "width": "96px" }, - min: 5, - max: 300 - }, null, 8, ["value"]) - ]), - createBaseVNode("div", _hoisted_5$1, [ - createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { - trigger: withCtx(() => [ - _hoisted_6$1 - ]), - default: withCtx(() => [ - createTextVNode(' Guidance scale indicates how much should model stay close to the prompt. Higher values might be exactly what you want, but generated images might have some artefacts. Lower values indicates that model can "dream" about this prompt more. '), - _hoisted_7$1 - ]), - _: 1 - }), - createVNode(unref(NSlider), { - value: unref(settings).data.settings.img2img.cfg_scale, - "onUpdate:value": _cache[3] || (_cache[3] = ($event) => unref(settings).data.settings.img2img.cfg_scale = $event), - min: 1, - max: 30, - step: 0.5, - style: { "margin-right": "12px" } - }, null, 8, ["value"]), - createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.img2img.cfg_scale, - "onUpdate:value": _cache[4] || (_cache[4] = ($event) => unref(settings).data.settings.img2img.cfg_scale = $event), - size: "small", - style: { "min-width": "96px", "width": "96px" }, - min: 1, - max: 30, - step: 0.5 - }, null, 8, ["value"]) - ]), - Number.isInteger(unref(settings).data.settings.img2img.sampler) && ((_a = unref(settings).data.settings.model) == null ? void 0 : _a.backend) === "PyTorch" ? (openBlock(), createElementBlock("div", _hoisted_8$1, [ - createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { - trigger: withCtx(() => [ - _hoisted_9$1 - ]), - default: withCtx(() => [ - _hoisted_10$1, - createTextVNode(" If self attention is >0, SAG will guide the model and improve the quality of the image at the cost of speed. Higher values will follow the guidance more closely, which can lead to better, more sharp and detailed outputs. ") - ]), - _: 1 - }), - createVNode(unref(NSlider), { - value: unref(settings).data.settings.img2img.self_attention_scale, - "onUpdate:value": _cache[5] || (_cache[5] = ($event) => unref(settings).data.settings.img2img.self_attention_scale = $event), - min: 0, - max: 1, - step: 0.05, - style: { "margin-right": "12px" } - }, null, 8, ["value"]), - createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.img2img.self_attention_scale, - "onUpdate:value": _cache[6] || (_cache[6] = ($event) => unref(settings).data.settings.img2img.self_attention_scale = $event), - size: "small", - style: { "min-width": "96px", "width": "96px" }, - min: 0, - max: 1, - step: 0.05 - }, null, 8, ["value"]) - ])) : createCommentVNode("", true), - createBaseVNode("div", _hoisted_11$1, [ - createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { - trigger: withCtx(() => [ - _hoisted_12$1 - ]), - default: withCtx(() => [ - createTextVNode(" Number of images to generate after each other. 
") - ]), - _: 1 - }), - createVNode(unref(NSlider), { - value: unref(settings).data.settings.img2img.batch_count, - "onUpdate:value": _cache[7] || (_cache[7] = ($event) => unref(settings).data.settings.img2img.batch_count = $event), - min: 1, - max: 9, - style: { "margin-right": "12px" } - }, null, 8, ["value"]), - createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.img2img.batch_count, - "onUpdate:value": _cache[8] || (_cache[8] = ($event) => unref(settings).data.settings.img2img.batch_count = $event), - size: "small", - style: { "min-width": "96px", "width": "96px" }, - min: 1, - max: 9 - }, null, 8, ["value"]) - ]), - createVNode(unref(_sfc_main$6), { - "batch-size-object": unref(settings).data.settings.img2img - }, null, 8, ["batch-size-object"]), - createBaseVNode("div", _hoisted_13$1, [ - createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { - trigger: withCtx(() => [ - _hoisted_14$1 - ]), - default: withCtx(() => [ - createTextVNode(" Lower values will stick more to the original image, 0.5-0.75 is ideal ") - ]), - _: 1 - }), - createVNode(unref(NSlider), { - value: unref(settings).data.settings.img2img.denoising_strength, - "onUpdate:value": _cache[9] || (_cache[9] = ($event) => unref(settings).data.settings.img2img.denoising_strength = $event), - min: 0.1, - max: 1, - style: { "margin-right": "12px" }, - step: 0.025 - }, null, 8, ["value"]), - createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.img2img.denoising_strength, - "onUpdate:value": _cache[10] || (_cache[10] = ($event) => unref(settings).data.settings.img2img.denoising_strength = $event), - size: "small", - style: { "min-width": "96px", "width": "96px" }, - min: 0.1, - max: 1, - step: 0.025 - }, null, 8, ["value"]) - ]), - createBaseVNode("div", _hoisted_15$1, [ - createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { - trigger: withCtx(() => [ - _hoisted_16$1 - ]), - default: withCtx(() => [ - createTextVNode(" Seed is a number that represents the starting canvas of your image. If you want to create the same image as your friend, you can use the same settings and seed to do so. "), - _hoisted_17$1 - ]), - _: 1 - }), - createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.img2img.seed, - "onUpdate:value": _cache[11] || (_cache[11] = ($event) => unref(settings).data.settings.img2img.seed = $event), - size: "small", - min: -1, - max: 999999999999, - style: { "flex-grow": "1" } - }, null, 8, ["value"]) - ]) - ]; - }), + default: withCtx(() => [ + createVNode(unref(Prompt), { tab: "img2img" }), + createVNode(unref(_sfc_main$4), { type: "img2img" }), + createVNode(unref(_sfc_main$5), { + "dimensions-object": unref(settings).data.settings.img2img + }, null, 8, ["dimensions-object"]), + createBaseVNode("div", _hoisted_2$1, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_3$1 + ]), + default: withCtx(() => [ + createTextVNode(" Number of steps to take in the diffusion process. Higher values will result in more detailed images but will take longer to generate. There is also a point of diminishing returns around 100 steps. 
"), + _hoisted_4$1 + ]), + _: 1 + }), + createVNode(unref(NSlider), { + value: unref(settings).data.settings.img2img.steps, + "onUpdate:value": _cache[1] || (_cache[1] = ($event) => unref(settings).data.settings.img2img.steps = $event), + min: 5, + max: 300, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: unref(settings).data.settings.img2img.steps, + "onUpdate:value": _cache[2] || (_cache[2] = ($event) => unref(settings).data.settings.img2img.steps = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" }, + min: 5, + max: 300 + }, null, 8, ["value"]) + ]), + createVNode(unref(_sfc_main$6), { tab: "img2img" }), + createVNode(unref(_sfc_main$7), { tab: "img2img" }), + createBaseVNode("div", _hoisted_5$1, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_6$1 + ]), + default: withCtx(() => [ + createTextVNode(" Number of images to generate after each other. ") + ]), + _: 1 + }), + createVNode(unref(NSlider), { + value: unref(settings).data.settings.img2img.batch_count, + "onUpdate:value": _cache[3] || (_cache[3] = ($event) => unref(settings).data.settings.img2img.batch_count = $event), + min: 1, + max: 9, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: unref(settings).data.settings.img2img.batch_count, + "onUpdate:value": _cache[4] || (_cache[4] = ($event) => unref(settings).data.settings.img2img.batch_count = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" }, + min: 1, + max: 9 + }, null, 8, ["value"]) + ]), + createVNode(unref(_sfc_main$8), { + "batch-size-object": unref(settings).data.settings.img2img + }, null, 8, ["batch-size-object"]), + createBaseVNode("div", _hoisted_7$1, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_8$1 + ]), + default: withCtx(() => [ + createTextVNode(" Lower values will stick more to the original image, 0.5-0.75 is ideal ") + ]), + _: 1 + }), + createVNode(unref(NSlider), { + value: unref(settings).data.settings.img2img.denoising_strength, + "onUpdate:value": _cache[5] || (_cache[5] = ($event) => unref(settings).data.settings.img2img.denoising_strength = $event), + min: 0.1, + max: 1, + style: { "margin-right": "12px" }, + step: 0.025 + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: unref(settings).data.settings.img2img.denoising_strength, + "onUpdate:value": _cache[6] || (_cache[6] = ($event) => unref(settings).data.settings.img2img.denoising_strength = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" }, + min: 0.1, + max: 1, + step: 0.025 + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_9$1, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_10$1 + ]), + default: withCtx(() => [ + createTextVNode(" Seed is a number that represents the starting canvas of your image. If you want to create the same image as your friend, you can use the same settings and seed to do so. 
"), + _hoisted_11$1 + ]), + _: 1 + }), + createVNode(unref(NInputNumber), { + value: unref(settings).data.settings.img2img.seed, + "onUpdate:value": _cache[7] || (_cache[7] = ($event) => unref(settings).data.settings.img2img.seed = $event), + size: "small", + min: -1, + max: 999999999999, + style: { "flex-grow": "1" } + }, null, 8, ["value"]) + ]) + ]), _: 1 }) ]), _: 1 - }) + }), + createVNode(unref(_sfc_main$9), { tab: "img2img" }), + createVNode(unref(_sfc_main$a), { tab: "img2img" }) ]), _: 1 }), createVNode(unref(NGi), null, { default: withCtx(() => [ - createVNode(unref(_sfc_main$7), { generate }), - createVNode(unref(_sfc_main$8), { + createVNode(unref(_sfc_main$b), { generate }), + createVNode(unref(_sfc_main$c), { "current-image": unref(global).state.img2img.currentImage, images: unref(global).state.img2img.images, data: unref(settings).data.settings.img2img, - onImageClicked: _cache[12] || (_cache[12] = ($event) => unref(global).state.img2img.currentImage = $event) + onImageClicked: _cache[8] || (_cache[8] = ($event) => unref(global).state.img2img.currentImage = $event) }, null, 8, ["current-image", "images", "data"]), - createVNode(unref(_sfc_main$9), { + createVNode(unref(_sfc_main$d), { style: { "margin-top": "12px" }, "gen-data": unref(global).state.img2img.genData }, null, 8, ["gen-data"]) @@ -860,7 +822,7 @@ const _sfc_main$2 = /* @__PURE__ */ defineComponent({ }; } }); -const ImageToImage = /* @__PURE__ */ _export_sfc(_sfc_main$2, [["__scopeId", "data-v-9c556ef8"]]); +const ImageToImage = /* @__PURE__ */ _export_sfc(_sfc_main$2, [["__scopeId", "data-v-a4145f6c"]]); var VueDrawingCanvas = /* @__PURE__ */ defineComponent({ name: "VueDrawingCanvas", props: { @@ -1436,7 +1398,7 @@ var VueDrawingCanvas = /* @__PURE__ */ defineComponent({ }); } }); -const _withScopeId = (n) => (pushScopeId("data-v-7963dde9"), n = n(), popScopeId(), n); +const _withScopeId = (n) => (pushScopeId("data-v-23b19530"), n = n(), popScopeId(), n); const _hoisted_1 = { style: { "margin": "0 12px" } }; const _hoisted_2 = { style: { "display": "inline-flex", "align-items": "center" } }; const _hoisted_3 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("svg", { @@ -1460,21 +1422,14 @@ const _hoisted_9 = { class: "flex-container" }; const _hoisted_10 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Steps", -1)); const _hoisted_11 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "We recommend using 20-50 steps for most images.", -1)); const _hoisted_12 = { class: "flex-container" }; -const _hoisted_13 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "CFG Scale", -1)); -const _hoisted_14 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "We recommend using 3-15 for most images.", -1)); -const _hoisted_15 = { - key: 0, - class: "flex-container" -}; -const _hoisted_16 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Self Attention Scale", -1)); -const _hoisted_17 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "PyTorch ONLY.", -1)); +const _hoisted_13 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Strength", -1)); +const _hoisted_14 = { class: "flex-container" }; +const _hoisted_15 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ 
createBaseVNode("p", { class: "slider-label" }, "Batch Count", -1)); +const _hoisted_16 = { class: "flex-container" }; +const _hoisted_17 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Batch Size", -1)); const _hoisted_18 = { class: "flex-container" }; -const _hoisted_19 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Batch Count", -1)); -const _hoisted_20 = { class: "flex-container" }; -const _hoisted_21 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Batch Size", -1)); -const _hoisted_22 = { class: "flex-container" }; -const _hoisted_23 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Seed", -1)); -const _hoisted_24 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "For random seed use -1.", -1)); +const _hoisted_19 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Seed", -1)); +const _hoisted_20 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "For random seed use -1.", -1)); const _sfc_main$1 = /* @__PURE__ */ defineComponent({ __name: "Inpainting", setup(__props) { @@ -1525,7 +1480,41 @@ const _sfc_main$1 = /* @__PURE__ */ defineComponent({ prompt_to_prompt: settings.data.settings.api.prompt_to_prompt } }, - model: (_a = settings.data.settings.model) == null ? void 0 : _a.path + ...settings.data.settings.inpainting.deepshrink.enabled ? { + flags: { + deepshrink: { + early_out: settings.data.settings.inpainting.deepshrink.early_out, + depth_1: settings.data.settings.inpainting.deepshrink.depth_1, + stop_at_1: settings.data.settings.inpainting.deepshrink.stop_at_1, + depth_2: settings.data.settings.inpainting.deepshrink.depth_2, + stop_at_2: settings.data.settings.inpainting.deepshrink.stop_at_2, + scaler: settings.data.settings.inpainting.deepshrink.scaler, + base_scale: settings.data.settings.inpainting.deepshrink.base_scale + } + } + } : {}, + model: (_a = settings.data.settings.model) == null ? void 0 : _a.path, + flags: { + ...settings.data.settings.inpainting.highres.enabled ? { + highres_fix: { + mode: settings.data.settings.inpainting.highres.mode, + image_upscaler: settings.data.settings.inpainting.highres.image_upscaler, + scale: settings.data.settings.inpainting.highres.scale, + latent_scale_mode: settings.data.settings.inpainting.highres.latent_scale_mode, + strength: settings.data.settings.inpainting.highres.strength, + steps: settings.data.settings.inpainting.highres.steps, + antialiased: settings.data.settings.inpainting.highres.antialiased + } + } : {}, + ...settings.data.settings.inpainting.upscale.enabled ? 
{ + upscale: { + upscale_factor: settings.data.settings.inpainting.upscale.upscale_factor, + tile_size: settings.data.settings.inpainting.upscale.tile_size, + tile_padding: settings.data.settings.inpainting.upscale.tile_padding, + model: settings.data.settings.inpainting.upscale.model + } + } : {} + } }) }).then((res) => { if (!res.ok) { @@ -1782,228 +1771,199 @@ const _sfc_main$1 = /* @__PURE__ */ defineComponent({ vertical: "", class: "left-container" }, { - default: withCtx(() => { - var _a; - return [ - createVNode(unref(Prompt), { tab: "inpainting" }), - createVNode(unref(_sfc_main$4), { type: "inpainting" }), - createBaseVNode("div", _hoisted_5, [ - _hoisted_6, - createVNode(unref(NSlider), { - value: unref(settings).data.settings.inpainting.width, - "onUpdate:value": _cache[2] || (_cache[2] = ($event) => unref(settings).data.settings.inpainting.width = $event), - min: 128, - max: 2048, - step: 8, - style: { "margin-right": "12px" } - }, null, 8, ["value"]), - createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.inpainting.width, - "onUpdate:value": _cache[3] || (_cache[3] = ($event) => unref(settings).data.settings.inpainting.width = $event), - size: "small", - style: { "min-width": "96px", "width": "96px" }, - step: 8, - min: 128, - max: 2048 - }, null, 8, ["value"]) - ]), - createBaseVNode("div", _hoisted_7, [ - _hoisted_8, - createVNode(unref(NSlider), { - value: unref(settings).data.settings.inpainting.height, - "onUpdate:value": _cache[4] || (_cache[4] = ($event) => unref(settings).data.settings.inpainting.height = $event), - min: 128, - max: 2048, - step: 8, - style: { "margin-right": "12px" } - }, null, 8, ["value"]), - createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.inpainting.height, - "onUpdate:value": _cache[5] || (_cache[5] = ($event) => unref(settings).data.settings.inpainting.height = $event), - size: "small", - style: { "min-width": "96px", "width": "96px" }, - step: 8, - min: 128, - max: 2048 - }, null, 8, ["value"]) - ]), - createBaseVNode("div", _hoisted_9, [ - createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { - trigger: withCtx(() => [ - _hoisted_10 - ]), - default: withCtx(() => [ - createTextVNode(" Number of steps to take in the diffusion process. Higher values will result in more detailed images but will take longer to generate. There is also a point of diminishing returns around 100 steps. "), - _hoisted_11 - ]), - _: 1 - }), - createVNode(unref(NSlider), { - value: unref(settings).data.settings.inpainting.steps, - "onUpdate:value": _cache[6] || (_cache[6] = ($event) => unref(settings).data.settings.inpainting.steps = $event), - min: 5, - max: 300, - style: { "margin-right": "12px" } - }, null, 8, ["value"]), - createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.inpainting.steps, - "onUpdate:value": _cache[7] || (_cache[7] = ($event) => unref(settings).data.settings.inpainting.steps = $event), - size: "small", - style: { "min-width": "96px", "width": "96px" }, - min: 5, - max: 300 - }, null, 8, ["value"]) - ]), - createBaseVNode("div", _hoisted_12, [ - createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { - trigger: withCtx(() => [ - _hoisted_13 - ]), - default: withCtx(() => [ - createTextVNode(' Guidance scale indicates how much should model stay close to the prompt. Higher values might be exactly what you want, but generated images might have some artefacts. Lower values indicates that model can "dream" about this prompt more. 
'), - _hoisted_14 - ]), - _: 1 - }), - createVNode(unref(NSlider), { - value: unref(settings).data.settings.inpainting.cfg_scale, - "onUpdate:value": _cache[8] || (_cache[8] = ($event) => unref(settings).data.settings.inpainting.cfg_scale = $event), - min: 1, - max: 30, - step: 0.5, - style: { "margin-right": "12px" } - }, null, 8, ["value"]), - createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.inpainting.cfg_scale, - "onUpdate:value": _cache[9] || (_cache[9] = ($event) => unref(settings).data.settings.inpainting.cfg_scale = $event), - size: "small", - style: { "min-width": "96px", "width": "96px" }, - min: 1, - max: 30, - step: 0.5 - }, null, 8, ["value"]) - ]), - Number.isInteger(unref(settings).data.settings.inpainting.sampler) && ((_a = unref(settings).data.settings.model) == null ? void 0 : _a.backend) === "PyTorch" ? (openBlock(), createElementBlock("div", _hoisted_15, [ - createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { - trigger: withCtx(() => [ - _hoisted_16 - ]), - default: withCtx(() => [ - _hoisted_17, - createTextVNode(" If self attention is >0, SAG will guide the model and improve the quality of the image at the cost of speed. Higher values will follow the guidance more closely, which can lead to better, more sharp and detailed outputs. ") - ]), - _: 1 - }), - createVNode(unref(NSlider), { - value: unref(settings).data.settings.inpainting.self_attention_scale, - "onUpdate:value": _cache[10] || (_cache[10] = ($event) => unref(settings).data.settings.inpainting.self_attention_scale = $event), - min: 0, - max: 1, - step: 0.05, - style: { "margin-right": "12px" } - }, null, 8, ["value"]), - createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.inpainting.self_attention_scale, - "onUpdate:value": _cache[11] || (_cache[11] = ($event) => unref(settings).data.settings.inpainting.self_attention_scale = $event), - size: "small", - style: { "min-width": "96px", "width": "96px" }, - min: 0, - max: 1, - step: 0.05 - }, null, 8, ["value"]) - ])) : createCommentVNode("", true), - createBaseVNode("div", _hoisted_18, [ - createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { - trigger: withCtx(() => [ - _hoisted_19 - ]), - default: withCtx(() => [ - createTextVNode(" Number of images to generate after each other. ") - ]), - _: 1 - }), - createVNode(unref(NSlider), { - value: unref(settings).data.settings.inpainting.batch_count, - "onUpdate:value": _cache[12] || (_cache[12] = ($event) => unref(settings).data.settings.inpainting.batch_count = $event), - min: 1, - max: 9, - style: { "margin-right": "12px" } - }, null, 8, ["value"]), - createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.inpainting.batch_count, - "onUpdate:value": _cache[13] || (_cache[13] = ($event) => unref(settings).data.settings.inpainting.batch_count = $event), - size: "small", - style: { "min-width": "96px", "width": "96px" }, - min: 1, - max: 9 - }, null, 8, ["value"]) - ]), - createBaseVNode("div", _hoisted_20, [ - createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { - trigger: withCtx(() => [ - _hoisted_21 - ]), - default: withCtx(() => [ - createTextVNode(" Number of images to generate in paralel. 
") - ]), - _: 1 - }), - createVNode(unref(NSlider), { - value: unref(settings).data.settings.inpainting.batch_size, - "onUpdate:value": _cache[14] || (_cache[14] = ($event) => unref(settings).data.settings.inpainting.batch_size = $event), - min: 1, - max: 9, - style: { "margin-right": "12px" } - }, null, 8, ["value"]), - createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.inpainting.batch_size, - "onUpdate:value": _cache[15] || (_cache[15] = ($event) => unref(settings).data.settings.inpainting.batch_size = $event), - size: "small", - style: { "min-width": "96px", "width": "96px" }, - min: 1, - max: 9 - }, null, 8, ["value"]) - ]), - createBaseVNode("div", _hoisted_22, [ - createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { - trigger: withCtx(() => [ - _hoisted_23 - ]), - default: withCtx(() => [ - createTextVNode(" Seed is a number that represents the starting canvas of your image. If you want to create the same image as your friend, you can use the same settings and seed to do so. "), - _hoisted_24 - ]), - _: 1 - }), - createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.inpainting.seed, - "onUpdate:value": _cache[16] || (_cache[16] = ($event) => unref(settings).data.settings.inpainting.seed = $event), - size: "small", - min: -1, - max: 999999999999, - style: { "flex-grow": "1" } - }, null, 8, ["value"]) - ]) - ]; - }), + default: withCtx(() => [ + createVNode(unref(Prompt), { tab: "inpainting" }), + createVNode(unref(_sfc_main$4), { type: "inpainting" }), + createBaseVNode("div", _hoisted_5, [ + _hoisted_6, + createVNode(unref(NSlider), { + value: unref(settings).data.settings.inpainting.width, + "onUpdate:value": _cache[2] || (_cache[2] = ($event) => unref(settings).data.settings.inpainting.width = $event), + min: 128, + max: 2048, + step: 8, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: unref(settings).data.settings.inpainting.width, + "onUpdate:value": _cache[3] || (_cache[3] = ($event) => unref(settings).data.settings.inpainting.width = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" }, + step: 8, + min: 128, + max: 2048 + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_7, [ + _hoisted_8, + createVNode(unref(NSlider), { + value: unref(settings).data.settings.inpainting.height, + "onUpdate:value": _cache[4] || (_cache[4] = ($event) => unref(settings).data.settings.inpainting.height = $event), + min: 128, + max: 2048, + step: 8, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: unref(settings).data.settings.inpainting.height, + "onUpdate:value": _cache[5] || (_cache[5] = ($event) => unref(settings).data.settings.inpainting.height = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" }, + step: 8, + min: 128, + max: 2048 + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_9, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_10 + ]), + default: withCtx(() => [ + createTextVNode(" Number of steps to take in the diffusion process. Higher values will result in more detailed images but will take longer to generate. There is also a point of diminishing returns around 100 steps. 
"), + _hoisted_11 + ]), + _: 1 + }), + createVNode(unref(NSlider), { + value: unref(settings).data.settings.inpainting.steps, + "onUpdate:value": _cache[6] || (_cache[6] = ($event) => unref(settings).data.settings.inpainting.steps = $event), + min: 5, + max: 300, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: unref(settings).data.settings.inpainting.steps, + "onUpdate:value": _cache[7] || (_cache[7] = ($event) => unref(settings).data.settings.inpainting.steps = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" }, + min: 5, + max: 300 + }, null, 8, ["value"]) + ]), + createVNode(unref(_sfc_main$6), { tab: "inpainting" }), + createVNode(unref(_sfc_main$7), { tab: "inpainting" }), + createBaseVNode("div", _hoisted_12, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_13 + ]), + default: withCtx(() => [ + createTextVNode(" How much should the masked are be changed from the original ") + ]), + _: 1 + }), + createVNode(unref(NSlider), { + value: unref(settings).data.settings.inpainting.strength, + "onUpdate:value": _cache[8] || (_cache[8] = ($event) => unref(settings).data.settings.inpainting.strength = $event), + min: 0, + max: 1, + step: 0.01, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: unref(settings).data.settings.inpainting.strength, + "onUpdate:value": _cache[9] || (_cache[9] = ($event) => unref(settings).data.settings.inpainting.strength = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" }, + min: 0, + max: 1, + step: 0.01 + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_14, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_15 + ]), + default: withCtx(() => [ + createTextVNode(" Number of images to generate after each other. ") + ]), + _: 1 + }), + createVNode(unref(NSlider), { + value: unref(settings).data.settings.inpainting.batch_count, + "onUpdate:value": _cache[10] || (_cache[10] = ($event) => unref(settings).data.settings.inpainting.batch_count = $event), + min: 1, + max: 9, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: unref(settings).data.settings.inpainting.batch_count, + "onUpdate:value": _cache[11] || (_cache[11] = ($event) => unref(settings).data.settings.inpainting.batch_count = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" }, + min: 1, + max: 9 + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_16, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_17 + ]), + default: withCtx(() => [ + createTextVNode(" Number of images to generate in paralel. 
") + ]), + _: 1 + }), + createVNode(unref(NSlider), { + value: unref(settings).data.settings.inpainting.batch_size, + "onUpdate:value": _cache[12] || (_cache[12] = ($event) => unref(settings).data.settings.inpainting.batch_size = $event), + min: 1, + max: 9, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: unref(settings).data.settings.inpainting.batch_size, + "onUpdate:value": _cache[13] || (_cache[13] = ($event) => unref(settings).data.settings.inpainting.batch_size = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" }, + min: 1, + max: 9 + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_18, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_19 + ]), + default: withCtx(() => [ + createTextVNode(" Seed is a number that represents the starting canvas of your image. If you want to create the same image as your friend, you can use the same settings and seed to do so. "), + _hoisted_20 + ]), + _: 1 + }), + createVNode(unref(NInputNumber), { + value: unref(settings).data.settings.inpainting.seed, + "onUpdate:value": _cache[14] || (_cache[14] = ($event) => unref(settings).data.settings.inpainting.seed = $event), + size: "small", + min: -1, + max: 999999999999, + style: { "flex-grow": "1" } + }, null, 8, ["value"]) + ]) + ]), _: 1 }) ]), _: 1 - }) + }), + createVNode(unref(_sfc_main$9), { tab: "inpainting" }), + createVNode(unref(_sfc_main$a), { tab: "inpainting" }) ]), _: 1 }), createVNode(unref(NGi), null, { default: withCtx(() => [ - createVNode(unref(_sfc_main$7), { generate }), - createVNode(unref(_sfc_main$8), { + createVNode(unref(_sfc_main$b), { generate }), + createVNode(unref(_sfc_main$c), { "current-image": unref(global).state.inpainting.currentImage, images: unref(global).state.inpainting.images, data: unref(settings).data.settings.inpainting, - onImageClicked: _cache[17] || (_cache[17] = ($event) => unref(global).state.inpainting.currentImage = $event) + onImageClicked: _cache[15] || (_cache[15] = ($event) => unref(global).state.inpainting.currentImage = $event) }, null, 8, ["current-image", "images", "data"]), - createVNode(unref(_sfc_main$9), { + createVNode(unref(_sfc_main$d), { style: { "margin-top": "12px" }, "gen-data": unref(global).state.inpainting.genData }, null, 8, ["gen-data"]) @@ -2017,7 +1977,7 @@ const _sfc_main$1 = /* @__PURE__ */ defineComponent({ }; } }); -const Inpainting = /* @__PURE__ */ _export_sfc(_sfc_main$1, [["__scopeId", "data-v-7963dde9"]]); +const Inpainting = /* @__PURE__ */ _export_sfc(_sfc_main$1, [["__scopeId", "data-v-23b19530"]]); const _sfc_main = /* @__PURE__ */ defineComponent({ __name: "Image2ImageView", setup(__props) { diff --git a/frontend/dist/assets/ImageBrowserView.js b/frontend/dist/assets/ImageBrowserView.js index e600d20c5..6214a48df 100644 --- a/frontend/dist/assets/ImageBrowserView.js +++ b/frontend/dist/assets/ImageBrowserView.js @@ -1,9 +1,97 @@ -import { d as defineComponent, b6 as useCssVars, a as useState, u as useSettings, R as inject, z as ref, c as computed, bI as urlFromPath, b8 as reactive, b9 as onMounted, q as onUnmounted, o as openBlock, j as createElementBlock, f as createBaseVNode, g as createVNode, h as unref, w as withCtx, F as Fragment, L as renderList, b7 as themeOverridesKey, t as serverUrl, J as NInput, B as NIcon, bd as NModal, s as NGrid, r as NGi, A as NButton, k as createTextVNode, M as NScrollbar, e as createBlock, by as convertToTextString, C as 
toDisplayString, m as createCommentVNode, bJ as diffusersSchedulerTuple, _ as _export_sfc } from "./index.js"; +import { d as defineComponent, o as openBlock, a as createElementBlock, b as createBaseVNode, b8 as useCssVars, l as useState, u as useSettings, S as inject, B as ref, c as computed, bN as urlFromPath, ba as reactive, bb as onMounted, s as onUnmounted, e as createVNode, f as unref, w as withCtx, F as Fragment, M as renderList, b9 as themeOverridesKey, x as serverUrl, J as NInput, D as NIcon, bh as NModal, v as NGrid, t as NGi, C as NButton, h as createTextVNode, O as NScrollbar, g as createBlock, bC as convertToTextString, E as toDisplayString, k as createCommentVNode, bO as diffusersSchedulerTuple, _ as _export_sfc } from "./index.js"; import { D as Download, _ as _sfc_main$1 } from "./SendOutputTo.vue_vue_type_script_setup_true_lang.js"; -import { G as GridOutline } from "./GridOutline.js"; import { N as NImage, T as TrashBin } from "./TrashBin.js"; -import { a as NSlider } from "./Switch.js"; +import { N as NSlider } from "./Slider.js"; import { N as NDescriptionsItem, a as NDescriptions } from "./DescriptionsItem.js"; +import "./Switch.js"; +const _hoisted_1$1 = { + xmlns: "http://www.w3.org/2000/svg", + "xmlns:xlink": "http://www.w3.org/1999/xlink", + viewBox: "0 0 512 512" +}; +const _hoisted_2$1 = /* @__PURE__ */ createBaseVNode( + "rect", + { + x: "48", + y: "48", + width: "176", + height: "176", + rx: "20", + ry: "20", + fill: "none", + stroke: "currentColor", + "stroke-linecap": "round", + "stroke-linejoin": "round", + "stroke-width": "32" + }, + null, + -1 + /* HOISTED */ +); +const _hoisted_3$1 = /* @__PURE__ */ createBaseVNode( + "rect", + { + x: "288", + y: "48", + width: "176", + height: "176", + rx: "20", + ry: "20", + fill: "none", + stroke: "currentColor", + "stroke-linecap": "round", + "stroke-linejoin": "round", + "stroke-width": "32" + }, + null, + -1 + /* HOISTED */ +); +const _hoisted_4$1 = /* @__PURE__ */ createBaseVNode( + "rect", + { + x: "48", + y: "288", + width: "176", + height: "176", + rx: "20", + ry: "20", + fill: "none", + stroke: "currentColor", + "stroke-linecap": "round", + "stroke-linejoin": "round", + "stroke-width": "32" + }, + null, + -1 + /* HOISTED */ +); +const _hoisted_5 = /* @__PURE__ */ createBaseVNode( + "rect", + { + x: "288", + y: "288", + width: "176", + height: "176", + rx: "20", + ry: "20", + fill: "none", + stroke: "currentColor", + "stroke-linecap": "round", + "stroke-linejoin": "round", + "stroke-width": "32" + }, + null, + -1 + /* HOISTED */ +); +const _hoisted_6 = [_hoisted_2$1, _hoisted_3$1, _hoisted_4$1, _hoisted_5]; +const GridOutline = defineComponent({ + name: "GridOutline", + render: function render(_ctx, _cache) { + return openBlock(), createElementBlock("svg", _hoisted_1$1, _hoisted_6); + } +}); const _hoisted_1 = { style: { "width": "calc(100vw - 98px)", "height": "48px", "border-bottom": "#505050 1px solid", "margin-top": "52px", "display": "flex", "justify-content": "end", "align-items": "center", "padding-right": "24px", "position": "fixed", "top": "0", "z-index": "1" }, class: "top-bar" diff --git a/frontend/dist/assets/ImageOutput.vue_vue_type_script_setup_true_lang.js b/frontend/dist/assets/ImageOutput.vue_vue_type_script_setup_true_lang.js index 57bc1e1e5..d3bb917d7 100644 --- a/frontend/dist/assets/ImageOutput.vue_vue_type_script_setup_true_lang.js +++ b/frontend/dist/assets/ImageOutput.vue_vue_type_script_setup_true_lang.js @@ -1,4 +1,4 @@ -import { d as defineComponent, z as ref, a as useState, o as 
openBlock, e as createBlock, w as withCtx, h as unref, r as NGi, g as createVNode, B as NIcon, k as createTextVNode, A as NButton, m as createCommentVNode, s as NGrid, c as computed, f as createBaseVNode, j as createElementBlock, F as Fragment, L as renderList, M as NScrollbar, n as NCard } from "./index.js"; +import { d as defineComponent, B as ref, l as useState, o as openBlock, g as createBlock, w as withCtx, f as unref, t as NGi, e as createVNode, D as NIcon, h as createTextVNode, C as NButton, k as createCommentVNode, v as NGrid, c as computed, b as createBaseVNode, a as createElementBlock, F as Fragment, M as renderList, O as NScrollbar, m as NCard } from "./index.js"; import { D as Download, _ as _sfc_main$2 } from "./SendOutputTo.vue_vue_type_script_setup_true_lang.js"; import { T as TrashBin, N as NImage } from "./TrashBin.js"; const _sfc_main$1 = /* @__PURE__ */ defineComponent({ diff --git a/frontend/dist/assets/ImageProcessingView.js b/frontend/dist/assets/ImageProcessingView.js index 413029b4f..416c2946d 100644 --- a/frontend/dist/assets/ImageProcessingView.js +++ b/frontend/dist/assets/ImageProcessingView.js @@ -1,24 +1,24 @@ -import { d as defineComponent, a as useState, u as useSettings, p as useMessage, c as computed, b as upscalerOptions, o as openBlock, j as createElementBlock, g as createVNode, w as withCtx, h as unref, r as NGi, n as NCard, N as NSpace, f as createBaseVNode, i as NSelect, l as NTooltip, k as createTextVNode, s as NGrid, t as serverUrl, v as pushScopeId, x as popScopeId, _ as _export_sfc, e as createBlock, D as NTabPane, E as NTabs } from "./index.js"; +import { d as defineComponent, l as useState, u as useSettings, r as useMessage, c as computed, L as upscalerOptions, o as openBlock, a as createElementBlock, e as createVNode, w as withCtx, f as unref, t as NGi, m as NCard, j as NSpace, b as createBaseVNode, q as NSelect, N as NTooltip, h as createTextVNode, v as NGrid, x as serverUrl, g as createBlock, n as NTabPane, p as NTabs } from "./index.js"; import { _ as _sfc_main$2 } from "./GenerateSection.vue_vue_type_script_setup_true_lang.js"; import { _ as _sfc_main$3 } from "./ImageOutput.vue_vue_type_script_setup_true_lang.js"; import { I as ImageUpload } from "./ImageUpload.js"; -import { a as NSlider } from "./Switch.js"; +import { N as NSlider } from "./Slider.js"; import { N as NInputNumber } from "./InputNumber.js"; import "./SendOutputTo.vue_vue_type_script_setup_true_lang.js"; +import "./Switch.js"; import "./TrashBin.js"; import "./CloudUpload.js"; -const _withScopeId = (n) => (pushScopeId("data-v-5358ed01"), n = n(), popScopeId(), n); const _hoisted_1 = { style: { "margin": "0 12px" } }; const _hoisted_2 = { class: "flex-container" }; -const _hoisted_3 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Model", -1)); +const _hoisted_3 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Model", -1); const _hoisted_4 = { class: "flex-container" }; -const _hoisted_5 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Scale Factor", -1)); +const _hoisted_5 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Scale Factor", -1); const _hoisted_6 = { class: "flex-container" }; -const _hoisted_7 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Tile Size", -1)); +const _hoisted_7 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Tile Size", -1); const 
_hoisted_8 = { class: "flex-container" }; -const _hoisted_9 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Tile Padding", -1)); +const _hoisted_9 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Tile Padding", -1); const _sfc_main$1 = /* @__PURE__ */ defineComponent({ - __name: "Upscale", + __name: "ESRGAN", setup(__props) { const global = useState(); const settings = useSettings(); @@ -208,7 +208,6 @@ const _sfc_main$1 = /* @__PURE__ */ defineComponent({ }; } }); -const Upscale = /* @__PURE__ */ _export_sfc(_sfc_main$1, [["__scopeId", "data-v-5358ed01"]]); const _sfc_main = /* @__PURE__ */ defineComponent({ __name: "ImageProcessingView", setup(__props) { @@ -225,7 +224,7 @@ const _sfc_main = /* @__PURE__ */ defineComponent({ name: "upscale" }, { default: withCtx(() => [ - createVNode(unref(Upscale)) + createVNode(unref(_sfc_main$1)) ]), _: 1 }) diff --git a/frontend/dist/assets/ImageUpload.js b/frontend/dist/assets/ImageUpload.js index 892dc106b..5753ae711 100644 --- a/frontend/dist/assets/ImageUpload.js +++ b/frontend/dist/assets/ImageUpload.js @@ -1,4 +1,4 @@ -import { d as defineComponent, z as ref, c as computed, b9 as onMounted, o as openBlock, e as createBlock, w as withCtx, f as createBaseVNode, bU as withModifiers, j as createElementBlock, g as createVNode, h as unref, B as NIcon, C as toDisplayString, n as NCard, v as pushScopeId, x as popScopeId, _ as _export_sfc } from "./index.js"; +import { d as defineComponent, B as ref, c as computed, bb as onMounted, o as openBlock, g as createBlock, w as withCtx, b as createBaseVNode, be as withModifiers, a as createElementBlock, e as createVNode, f as unref, D as NIcon, E as toDisplayString, m as NCard, y as pushScopeId, z as popScopeId, _ as _export_sfc } from "./index.js"; import { C as CloudUpload } from "./CloudUpload.js"; const _withScopeId = (n) => (pushScopeId("data-v-9ed1514f"), n = n(), popScopeId(), n); const _hoisted_1 = { class: "image-container" }; diff --git a/frontend/dist/assets/InputNumber.js b/frontend/dist/assets/InputNumber.js index 65f5be9cc..6d62eb20c 100644 --- a/frontend/dist/assets/InputNumber.js +++ b/frontend/dist/assets/InputNumber.js @@ -1,4 +1,4 @@ -import { d as defineComponent, y as h, aa as c, Q as cB, S as useConfig, T as useTheme, ad as useLocale, ar as useFormItem, z as ref, X as toRef, ae as useMergedState, as as useMemo, K as watch, ag as useRtl, c as computed, bW as rgba, J as NInput, av as resolveWrappedSlot, bX as inputNumberLight, aD as on, ai as resolveSlot, aj as NBaseIcon, bY as XButton, a$ as AddIcon, a1 as call, W as nextTick } from "./index.js"; +import { d as defineComponent, A as h, ab as c, R as cB, T as useConfig, U as useTheme, ae as useLocale, as as useFormItem, B as ref, Y as toRef, af as useMergedState, at as useMemo, K as watch, ah as useRtl, c as computed, bZ as rgba, J as NInput, aw as resolveWrappedSlot, b_ as inputNumberLight, aE as on, aj as resolveSlot, ak as NBaseIcon, b$ as XButton, b0 as AddIcon, a2 as call, X as nextTick } from "./index.js"; const RemoveIcon = defineComponent({ name: "Remove", render() { diff --git a/frontend/dist/assets/ModelPopup.vue_vue_type_script_setup_true_lang.js b/frontend/dist/assets/ModelPopup.vue_vue_type_script_setup_true_lang.js index d63e0de48..86370f87f 100644 --- a/frontend/dist/assets/ModelPopup.vue_vue_type_script_setup_true_lang.js +++ b/frontend/dist/assets/ModelPopup.vue_vue_type_script_setup_true_lang.js @@ -1,4 +1,4 @@ -import { bj as upperFirst, bk as 
toString, bl as createCompounder, bm as cloneVNode, a3 as provide, P as createInjectionKey, R as inject, a_ as throwError, d as defineComponent, S as useConfig, z as ref, bn as onBeforeUpdate, y as h, bo as indexMap, c as computed, b9 as onMounted, aB as onBeforeUnmount, Q as cB, at as cE, aa as c, ab as cM, a4 as keep, ae as useMergedState, X as toRef, af as watchEffect, bp as onUpdated, K as watch, W as nextTick, T as useTheme, Y as useThemeClass, aw as flatten, aL as VResizeObserver, bq as resolveSlotWithProps, br as withDirectives, bs as vShow, aX as Transition, ba as normalizeStyle, bt as getPreciseEventTarget, aD as on, aC as off, bu as carouselLight, ac as cNotM, ar as useFormItem, ah as createKey, bv as color2Class, L as renderList, aj as NBaseIcon, bw as rateLight, a1 as call, p as useMessage, u as useSettings, b8 as reactive, o as openBlock, e as createBlock, w as withCtx, g as createVNode, j as createElementBlock, h as unref, D as NTabPane, s as NGrid, r as NGi, F as Fragment, f as createBaseVNode, n as NCard, k as createTextVNode, C as toDisplayString, bx as NTag, i as NSelect, A as NButton, E as NTabs, bd as NModal, t as serverUrl } from "./index.js"; +import { bn as upperFirst, bo as toString, bp as createCompounder, bq as cloneVNode, a4 as provide, Q as createInjectionKey, S as inject, a$ as throwError, d as defineComponent, T as useConfig, B as ref, br as onBeforeUpdate, A as h, bs as indexMap, c as computed, bb as onMounted, aC as onBeforeUnmount, R as cB, au as cE, ab as c, ac as cM, a5 as keep, af as useMergedState, Y as toRef, ag as watchEffect, bt as onUpdated, K as watch, X as nextTick, U as useTheme, Z as useThemeClass, ax as flatten, aN as VResizeObserver, bu as resolveSlotWithProps, bv as withDirectives, bw as vShow, aY as Transition, bc as normalizeStyle, bx as getPreciseEventTarget, aE as on, aD as off, by as carouselLight, ad as cNotM, as as useFormItem, ai as createKey, bz as color2Class, M as renderList, ak as NBaseIcon, bA as rateLight, a2 as call, r as useMessage, u as useSettings, ba as reactive, o as openBlock, g as createBlock, w as withCtx, e as createVNode, a as createElementBlock, f as unref, n as NTabPane, v as NGrid, t as NGi, F as Fragment, b as createBaseVNode, m as NCard, h as createTextVNode, E as toDisplayString, bB as NTag, q as NSelect, C as NButton, p as NTabs, bh as NModal, x as serverUrl } from "./index.js"; import { a as NDescriptions, N as NDescriptionsItem } from "./DescriptionsItem.js"; function capitalize(string) { return upperFirst(toString(string).toLowerCase()); @@ -1697,6 +1697,7 @@ const _sfc_main = /* @__PURE__ */ defineComponent({ } }); export { + NRate as N, _sfc_main as _, nsfwIndex as n }; diff --git a/frontend/dist/assets/ModelsView.js b/frontend/dist/assets/ModelsView.js index 0356485f0..1d85d47ff 100644 --- a/frontend/dist/assets/ModelsView.js +++ b/frontend/dist/assets/ModelsView.js @@ -1,8 +1,7 @@ -import { d as defineComponent, y as h, O as replaceable, P as createInjectionKey, Q as cB, R as inject, S as useConfig, T as useTheme, U as popselectLight, c as computed, V as createTreeMate, K as watch, W as nextTick, X as toRef, Y as useThemeClass, Z as NInternalSelectMenu, $ as createTmOptions, a0 as happensIn, a1 as call, a2 as keysOf, z as ref, a3 as provide, a4 as keep, a5 as createRefSetter, a6 as mergeEventHandlers, a7 as omit, a8 as NPopover, a9 as popoverBaseProps, aa as c, ab as cM, ac as cNotM, ad as useLocale, ae as useMergedState, af as watchEffect, ag as useRtl, ah as createKey, ai as resolveSlot, J as NInput, i 
as NSelect, F as Fragment, aj as NBaseIcon, ak as useAdjustedTo, al as paginationLight, am as useMergedClsPrefix, an as ellipsisLight, ao as onDeactivated, l as NTooltip, ap as mergeProps, aq as useStyle, ar as useFormItem, as as useMemo, at as cE, au as radioLight, av as resolveWrappedSlot, aw as flatten$1, ax as getSlot, ay as depx, az as formatLength, A as NButton, aA as NScrollbar, aB as onBeforeUnmount, aC as off, aD as on, aE as ChevronDownIcon, aF as NDropdown, aG as pxfy, aH as get, aI as NIconSwitchTransition, aJ as NBaseLoading, aK as ChevronRightIcon, q as onUnmounted, aL as VResizeObserver, aM as warn, aN as cssrAnchorMetaName, aO as VVirtualList, aP as NEmpty, aQ as repeat, aR as beforeNextFrameOnce, aS as fadeInScaleUpTransition, aT as iconSwitchTransition, aU as insideModal, aV as insidePopover, aW as createId, aX as Transition, aY as dataTableLight, aZ as loadingBarApiInjectionKey, a_ as throwError, a$ as AddIcon, b0 as NProgress, b1 as NFadeInExpandTransition, b2 as EyeIcon, b3 as fadeInHeightExpandTransition, b4 as Teleport, b5 as uploadLight, o as openBlock, j as createElementBlock, f as createBaseVNode, b6 as useCssVars, h as unref, u as useSettings, b7 as themeOverridesKey, b8 as reactive, b9 as onMounted, g as createVNode, w as withCtx, B as NIcon, L as renderList, ba as normalizeStyle, k as createTextVNode, C as toDisplayString, bb as NText, m as createCommentVNode, _ as _export_sfc, a as useState, p as useMessage, bc as huggingfaceModelsFile, n as NCard, t as serverUrl, v as pushScopeId, x as popScopeId, N as NSpace, bd as NModal, e as createBlock, r as NGi, s as NGrid, be as NDivider, bf as Backends, D as NTabPane, E as NTabs } from "./index.js"; -import { _ as _sfc_main$5, n as nsfwIndex } from "./ModelPopup.vue_vue_type_script_setup_true_lang.js"; -import { G as GridOutline } from "./GridOutline.js"; -import { a as NSlider, N as NSwitch } from "./Switch.js"; +import { d as defineComponent, A as h, P as replaceable, Q as createInjectionKey, R as cB, S as inject, T as useConfig, U as useTheme, V as popselectLight, c as computed, W as createTreeMate, K as watch, X as nextTick, Y as toRef, Z as useThemeClass, $ as NInternalSelectMenu, a0 as createTmOptions, a1 as happensIn, a2 as call, a3 as keysOf, B as ref, a4 as provide, a5 as keep, a6 as createRefSetter, a7 as mergeEventHandlers, a8 as omit, a9 as NPopover, aa as popoverBaseProps, ab as c, ac as cM, ad as cNotM, ae as useLocale, af as useMergedState, ag as watchEffect, ah as useRtl, ai as createKey, aj as resolveSlot, J as NInput, q as NSelect, F as Fragment, ak as NBaseIcon, al as useAdjustedTo, am as paginationLight, an as useMergedClsPrefix, ao as ellipsisLight, ap as onDeactivated, N as NTooltip, aq as mergeProps, ar as useStyle, as as useFormItem, at as useMemo, au as cE, av as radioLight, aw as resolveWrappedSlot, ax as flatten$1, ay as getSlot, az as depx, aA as formatLength, C as NButton, aB as NScrollbar, aC as onBeforeUnmount, aD as off, aE as on, aF as ChevronDownIcon, aG as NDropdown, aH as pxfy, aI as get, aJ as NIconSwitchTransition, aK as NBaseLoading, aL as ChevronRightIcon, aM as cssrAnchorMetaName, s as onUnmounted, aN as VResizeObserver, aO as warn, aP as VVirtualList, aQ as NEmpty, aR as repeat, aS as beforeNextFrameOnce, aT as fadeInScaleUpTransition, aU as iconSwitchTransition, aV as insideModal, aW as insidePopover, aX as createId, aY as Transition, aZ as dataTableLight, a_ as loadingBarApiInjectionKey, a$ as throwError, b0 as AddIcon, b1 as NProgress, b2 as NFadeInExpandTransition, b3 as 
EyeIcon, b4 as fadeInHeightExpandTransition, b5 as Teleport, b6 as uploadLight, o as openBlock, a as createElementBlock, b as createBaseVNode, b7 as createStaticVNode, b8 as useCssVars, f as unref, b9 as themeOverridesKey, ba as reactive, bb as onMounted, e as createVNode, w as withCtx, D as NIcon, M as renderList, _ as _export_sfc, u as useSettings, k as createCommentVNode, bc as normalizeStyle, h as createTextVNode, E as toDisplayString, bd as NText, g as createBlock, be as withModifiers, l as useState, r as useMessage, bf as huggingfaceModelsFile, m as NCard, x as serverUrl, y as pushScopeId, z as popScopeId, bg as Menu, j as NSpace, bh as NModal, t as NGi, v as NGrid, bi as NDivider, bj as Backends, n as NTabPane, p as NTabs } from "./index.js"; +import { _ as _sfc_main$6, n as nsfwIndex, N as NRate } from "./ModelPopup.vue_vue_type_script_setup_true_lang.js"; import { N as NCheckboxGroup, a as NCheckbox, S as Settings } from "./Settings.js"; +import { N as NSwitch } from "./Switch.js"; import { g as getFilesFromEntries, i as isImageFile, N as NImage, d as download, a as NImageGroup, c as createSettledFileInfo, e as environmentSupportFile, m as matchType, b as createImageDataUrl, T as TrashBin } from "./TrashBin.js"; import { C as CloudUpload } from "./CloudUpload.js"; import "./DescriptionsItem.js"; @@ -1398,8 +1397,8 @@ const RenderSorter = defineComponent({ } }, render() { - const { render: render4, order } = this; - return render4({ + const { render: render7, order } = this; + return render7({ order }); } @@ -1575,8 +1574,8 @@ const RenderFilter = defineComponent({ } }, render() { - const { render: render4, active, show } = this; - return render4({ + const { render: render7, active, show } = this; + return render7({ active, show }); @@ -2786,9 +2785,9 @@ const Cell = defineComponent({ render() { const { isSummary, column, row, renderCell } = this; let cell; - const { render: render4, key, ellipsis } = column; - if (render4 && !isSummary) { - cell = render4(row, this.index); + const { render: render7, key, ellipsis } = column; + if (render7 && !isSummary) { + cell = render7(row, this.index); } else { if (isSummary) { cell = row[key].value; @@ -6398,98 +6397,164 @@ const NUpload = defineComponent({ ); } }); -const _hoisted_1$6 = { +const _hoisted_1$a = { xmlns: "http://www.w3.org/2000/svg", "xmlns:xlink": "http://www.w3.org/1999/xlink", viewBox: "0 0 512 512" }; -const _hoisted_2$6 = /* @__PURE__ */ createBaseVNode( +const _hoisted_2$a = /* @__PURE__ */ createBaseVNode( "path", { - d: "M261.56 101.28a8 8 0 0 0-11.06 0L66.4 277.15a8 8 0 0 0-2.47 5.79L63.9 448a32 32 0 0 0 32 32H192a16 16 0 0 0 16-16V328a8 8 0 0 1 8-8h80a8 8 0 0 1 8 8v136a16 16 0 0 0 16 16h96.06a32 32 0 0 0 32-32V282.94a8 8 0 0 0-2.47-5.79z", - fill: "currentColor" + fill: "none", + stroke: "currentColor", + "stroke-linecap": "round", + "stroke-linejoin": "round", + "stroke-width": "48", + d: "M112 268l144 144l144-144" }, null, -1 /* HOISTED */ ); -const _hoisted_3$6 = /* @__PURE__ */ createBaseVNode( +const _hoisted_3$8 = /* @__PURE__ */ createBaseVNode( "path", { - d: "M490.91 244.15l-74.8-71.56V64a16 16 0 0 0-16-16h-48a16 16 0 0 0-16 16v32l-57.92-55.38C272.77 35.14 264.71 32 256 32c-8.68 0-16.72 3.14-22.14 8.63l-212.7 203.5c-6.22 6-7 15.87-1.34 22.37A16 16 0 0 0 43 267.56L250.5 69.28a8 8 0 0 1 11.06 0l207.52 198.28a16 16 0 0 0 22.59-.44c6.14-6.36 5.63-16.86-.76-22.97z", - fill: "currentColor" + fill: "none", + stroke: "currentColor", + "stroke-linecap": "round", + "stroke-linejoin": "round", + "stroke-width": "48", + 
d: "M256 392V100" }, null, -1 /* HOISTED */ ); -const _hoisted_4$5 = [_hoisted_2$6, _hoisted_3$6]; -const Home = defineComponent({ - name: "Home", +const _hoisted_4$6 = [_hoisted_2$a, _hoisted_3$8]; +const ArrowDownOutline = defineComponent({ + name: "ArrowDownOutline", render: function render(_ctx, _cache) { - return openBlock(), createElementBlock("svg", _hoisted_1$6, _hoisted_4$5); + return openBlock(), createElementBlock("svg", _hoisted_1$a, _hoisted_4$6); } }); -const _hoisted_1$5 = { +const _hoisted_1$9 = { xmlns: "http://www.w3.org/2000/svg", "xmlns:xlink": "http://www.w3.org/1999/xlink", viewBox: "0 0 512 512" }; -const _hoisted_2$5 = /* @__PURE__ */ createBaseVNode( +const _hoisted_2$9 = /* @__PURE__ */ createStaticVNode('', 5); +const _hoisted_7$1 = [_hoisted_2$9]; +const EyeOffOutline = defineComponent({ + name: "EyeOffOutline", + render: function render2(_ctx, _cache) { + return openBlock(), createElementBlock("svg", _hoisted_1$9, _hoisted_7$1); + } +}); +const _hoisted_1$8 = { + xmlns: "http://www.w3.org/2000/svg", + "xmlns:xlink": "http://www.w3.org/1999/xlink", + viewBox: "0 0 512 512" +}; +const _hoisted_2$8 = /* @__PURE__ */ createBaseVNode( "path", { + d: "M255.66 112c-77.94 0-157.89 45.11-220.83 135.33a16 16 0 0 0-.27 17.77C82.92 340.8 161.8 400 255.66 400c92.84 0 173.34-59.38 221.79-135.25a16.14 16.14 0 0 0 0-17.47C428.89 172.28 347.8 112 255.66 112z", fill: "none", stroke: "currentColor", "stroke-linecap": "round", - "stroke-miterlimit": "10", - "stroke-width": "48", - d: "M88 152h336" + "stroke-linejoin": "round", + "stroke-width": "32" }, null, -1 /* HOISTED */ ); -const _hoisted_3$5 = /* @__PURE__ */ createBaseVNode( - "path", +const _hoisted_3$7 = /* @__PURE__ */ createBaseVNode( + "circle", { + cx: "256", + cy: "256", + r: "80", fill: "none", stroke: "currentColor", - "stroke-linecap": "round", "stroke-miterlimit": "10", - "stroke-width": "48", - d: "M88 256h336" + "stroke-width": "32" }, null, -1 /* HOISTED */ ); -const _hoisted_4$4 = /* @__PURE__ */ createBaseVNode( +const _hoisted_4$5 = [_hoisted_2$8, _hoisted_3$7]; +const EyeOutline = defineComponent({ + name: "EyeOutline", + render: function render3(_ctx, _cache) { + return openBlock(), createElementBlock("svg", _hoisted_1$8, _hoisted_4$5); + } +}); +const _hoisted_1$7 = { + xmlns: "http://www.w3.org/2000/svg", + "xmlns:xlink": "http://www.w3.org/1999/xlink", + viewBox: "0 0 512 512" +}; +const _hoisted_2$7 = /* @__PURE__ */ createBaseVNode( "path", { + d: "M352.92 80C288 80 256 144 256 144s-32-64-96.92-64c-52.76 0-94.54 44.14-95.08 96.81c-1.1 109.33 86.73 187.08 183 252.42a16 16 0 0 0 18 0c96.26-65.34 184.09-143.09 183-252.42c-.54-52.67-42.32-96.81-95.08-96.81z", fill: "none", stroke: "currentColor", "stroke-linecap": "round", - "stroke-miterlimit": "10", - "stroke-width": "48", - d: "M88 360h336" + "stroke-linejoin": "round", + "stroke-width": "32" }, null, -1 /* HOISTED */ ); -const _hoisted_5$2 = [_hoisted_2$5, _hoisted_3$5, _hoisted_4$4]; -const Menu = defineComponent({ - name: "Menu", - render: function render2(_ctx, _cache) { - return openBlock(), createElementBlock("svg", _hoisted_1$5, _hoisted_5$2); +const _hoisted_3$6 = [_hoisted_2$7]; +const HeartOutline = defineComponent({ + name: "HeartOutline", + render: function render4(_ctx, _cache) { + return openBlock(), createElementBlock("svg", _hoisted_1$7, _hoisted_3$6); } }); -const _hoisted_1$4 = { +const _hoisted_1$6 = { + xmlns: "http://www.w3.org/2000/svg", + "xmlns:xlink": "http://www.w3.org/1999/xlink", + viewBox: "0 0 512 512" +}; +const 
_hoisted_2$6 = /* @__PURE__ */ createBaseVNode( + "path", + { + d: "M261.56 101.28a8 8 0 0 0-11.06 0L66.4 277.15a8 8 0 0 0-2.47 5.79L63.9 448a32 32 0 0 0 32 32H192a16 16 0 0 0 16-16V328a8 8 0 0 1 8-8h80a8 8 0 0 1 8 8v136a16 16 0 0 0 16 16h96.06a32 32 0 0 0 32-32V282.94a8 8 0 0 0-2.47-5.79z", + fill: "currentColor" + }, + null, + -1 + /* HOISTED */ +); +const _hoisted_3$5 = /* @__PURE__ */ createBaseVNode( + "path", + { + d: "M490.91 244.15l-74.8-71.56V64a16 16 0 0 0-16-16h-48a16 16 0 0 0-16 16v32l-57.92-55.38C272.77 35.14 264.71 32 256 32c-8.68 0-16.72 3.14-22.14 8.63l-212.7 203.5c-6.22 6-7 15.87-1.34 22.37A16 16 0 0 0 43 267.56L250.5 69.28a8 8 0 0 1 11.06 0l207.52 198.28a16 16 0 0 0 22.59-.44c6.14-6.36 5.63-16.86-.76-22.97z", + fill: "currentColor" + }, + null, + -1 + /* HOISTED */ +); +const _hoisted_4$4 = [_hoisted_2$6, _hoisted_3$5]; +const Home = defineComponent({ + name: "Home", + render: function render5(_ctx, _cache) { + return openBlock(), createElementBlock("svg", _hoisted_1$6, _hoisted_4$4); + } +}); +const _hoisted_1$5 = { xmlns: "http://www.w3.org/2000/svg", "xmlns:xlink": "http://www.w3.org/1999/xlink", viewBox: "0 0 512 512" }; -const _hoisted_2$4 = /* @__PURE__ */ createBaseVNode( +const _hoisted_2$5 = /* @__PURE__ */ createBaseVNode( "path", { d: "M221.09 64a157.09 157.09 0 1 0 157.09 157.09A157.1 157.1 0 0 0 221.09 64z", @@ -6516,36 +6581,30 @@ const _hoisted_3$4 = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_4$3 = [_hoisted_2$4, _hoisted_3$4]; +const _hoisted_4$3 = [_hoisted_2$5, _hoisted_3$4]; const SearchOutline = defineComponent({ name: "SearchOutline", - render: function render3(_ctx, _cache) { - return openBlock(), createElementBlock("svg", _hoisted_1$4, _hoisted_4$3); + render: function render6(_ctx, _cache) { + return openBlock(), createElementBlock("svg", _hoisted_1$5, _hoisted_4$3); } }); -const _hoisted_1$3 = { +const _hoisted_1$4 = { style: { "width": "calc(100vw - 98px)", "height": "48px", "border-bottom": "#505050 1px solid", "display": "flex", "justify-content": "end", "align-items": "center", "padding-right": "24px", "position": "sticky", "top": "52px", "z-index": "1" }, class: "top-bar" }; -const _hoisted_2$3 = { +const _hoisted_2$4 = { class: "main-container", style: { "margin": "12px", "margin-top": "8px" } }; -const _hoisted_3$3 = { class: "image-grid" }; -const _hoisted_4$2 = { key: 0 }; -const _hoisted_5$1 = ["src", "onClick"]; -const _hoisted_6$1 = { style: { "position": "absolute", "width": "100%", "bottom": "0", "padding": "0 8px", "min-height": "32px", "overflow": "hidden", "box-sizing": "border-box", "backdrop-filter": "blur(12px)" } }; -const _sfc_main$4 = /* @__PURE__ */ defineComponent({ +const _sfc_main$5 = /* @__PURE__ */ defineComponent({ __name: "CivitAIDownload", setup(__props) { useCssVars((_ctx) => { var _a, _b; return { - "6b1de230": unref(settings).data.settings.frontend.image_browser_columns, - "a55b21d8": (_b = (_a = unref(theme)) == null ? void 0 : _a.Card) == null ? void 0 : _b.color + "66e6d45b": (_b = (_a = unref(theme)) == null ? void 0 : _a.Card) == null ? 
void 0 : _b.color }; }); - const settings = useSettings(); const theme = inject(themeOverridesKey); const loadingLock = ref(false); const currentPage = ref(1); @@ -6553,32 +6612,18 @@ const _sfc_main$4 = /* @__PURE__ */ defineComponent({ const types = ref(""); const currentModel = ref(null); const showModal = ref(false); + const gridRef = ref(null); const scrollComponent = ref(null); const itemFilter = ref(""); - const gridColumnRefs = ref([]); - const currentColumn = ref(0); - const currentRowIndex = ref(0); + const currentIndex = ref(0); const loadingBar = useLoadingBar(); - function imgClick(column_index, item_index) { - currentRowIndex.value = item_index; - currentColumn.value = column_index; - const item = columns.value[column_index][item_index]; + function imgClick(item_index) { + const item = modelData[item_index]; + currentIndex.value = item_index; currentModel.value = item; showModal.value = true; } const modelData = reactive([]); - const columns = computed(() => { - const cols = []; - for (let i = 0; i < settings.data.settings.frontend.image_browser_columns; i++) { - cols.push([]); - } - for (let i = 0; i < modelData.length; i++) { - cols[i % settings.data.settings.frontend.image_browser_columns].push( - modelData[i] - ); - } - return cols; - }); async function refreshImages() { currentPage.value = 1; modelData.splice(0, modelData.length); @@ -6604,16 +6649,11 @@ const _sfc_main$4 = /* @__PURE__ */ defineComponent({ return; } let minBox = 0; - for (const col of gridColumnRefs.value) { - const lastImg = col.childNodes.item( - col.childNodes.length - 2 - ); - const bottombbox = lastImg.getBoundingClientRect().bottom; - if (minBox === 0) { - minBox = bottombbox; - } else if (bottombbox < minBox) { - minBox = bottombbox; - } + const bottombbox = gridRef.value.getBoundingClientRect().bottom; + if (minBox === 0) { + minBox = bottombbox; + } else if (bottombbox < minBox) { + minBox = bottombbox; } if (minBox - 50 < window.innerHeight) { if (loadingLock.value) { @@ -6646,18 +6686,16 @@ const _sfc_main$4 = /* @__PURE__ */ defineComponent({ } }; function moveImage(direction) { - const numColumns = settings.data.settings.frontend.image_browser_columns; + if (currentModel.value === null) { + return; + } if (direction === -1) { - if (currentColumn.value > 0) { - imgClick(currentColumn.value - 1, currentRowIndex.value); - } else { - imgClick(numColumns - 1, currentRowIndex.value - 1); + if (currentIndex.value > 0) { + imgClick(currentIndex.value - 1); } } else if (direction === 1) { - if (currentColumn.value < numColumns - 1) { - imgClick(currentColumn.value + 1, currentRowIndex.value); - } else { - imgClick(0, currentRowIndex.value + 1); + if (currentIndex.value < modelData.length - 1) { + imgClick(currentIndex.value + 1); } } } @@ -6694,12 +6732,12 @@ const _sfc_main$4 = /* @__PURE__ */ defineComponent({ refreshImages(); return (_ctx, _cache) => { return openBlock(), createElementBlock(Fragment, null, [ - createVNode(unref(_sfc_main$5), { + createVNode(unref(_sfc_main$6), { model: currentModel.value, "show-modal": showModal.value, "onUpdate:showModal": _cache[0] || (_cache[0] = ($event) => showModal.value = $event) }, null, 8, ["model", "show-modal"]), - createBaseVNode("div", _hoisted_1$3, [ + createBaseVNode("div", _hoisted_1$4, [ createVNode(unref(NInput), { value: itemFilter.value, "onUpdate:value": _cache[1] || (_cache[1] = ($event) => itemFilter.value = $event), @@ -6729,28 +6767,20 @@ const _sfc_main$4 = /* @__PURE__ */ defineComponent({ value: types.value, "onUpdate:value": _cache[3] 
|| (_cache[3] = ($event) => types.value = $event), options: [ - { - value: "", - label: "All" - }, + { value: "", label: "All" }, { value: "Checkpoint", label: "Checkpoint" }, - { - value: "TextualInversion", - label: "Textual Inversion" - }, - { - value: "LORA", - label: "LORA" - } + { value: "TextualInversion", label: "Textual Inversion" }, + { value: "LORA", label: "LORA" }, + { value: "VAE", label: "VAE" } ], style: { "margin-right": "4px" } }, null, 8, ["value"]), createVNode(unref(NButton), { onClick: refreshImages, - style: { "margin-right": "24px" }, + style: { "margin-right": "24px", "padding": "0px 48px" }, type: "primary" }, { default: withCtx(() => [ @@ -6762,78 +6792,169 @@ const _sfc_main$4 = /* @__PURE__ */ defineComponent({ }) ]), _: 1 - }), - createVNode(unref(NIcon), { - style: { "margin-right": "12px" }, - size: "22" - }, { - default: withCtx(() => [ - createVNode(unref(GridOutline)) - ]), - _: 1 - }), - createVNode(unref(NSlider), { - style: { "width": "30vw" }, - min: 1, - max: 10, - value: unref(settings).data.settings.frontend.image_browser_columns, - "onUpdate:value": _cache[4] || (_cache[4] = ($event) => unref(settings).data.settings.frontend.image_browser_columns = $event) - }, null, 8, ["value"]) + }) ]), - createBaseVNode("div", _hoisted_2$3, [ + createBaseVNode("div", _hoisted_2$4, [ createBaseVNode("div", { ref_key: "scrollComponent", ref: scrollComponent }, [ - createBaseVNode("div", _hoisted_3$3, [ - (openBlock(true), createElementBlock(Fragment, null, renderList(columns.value, (column, column_index) => { + createBaseVNode("div", { + class: "image-grid", + ref_key: "gridRef", + ref: gridRef + }, [ + (openBlock(true), createElementBlock(Fragment, null, renderList(modelData, (item, item_index) => { return openBlock(), createElementBlock("div", { - key: column_index, - class: "image-column", - ref_for: true, - ref_key: "gridColumnRefs", - ref: gridColumnRefs + key: item_index, + style: { "border-radius": "20px", "position": "relative", "border": "1px solid #505050", "overflow": "hidden", "margin-bottom": "8px" } }, [ - (openBlock(true), createElementBlock(Fragment, null, renderList(column, (item, item_index) => { - var _a; - return openBlock(), createElementBlock("div", { - key: item_index, - style: { "border-radius": "20px", "position": "relative", "border": "1px solid #505050", "overflow": "hidden", "margin-bottom": "8px" } - }, [ - ((_a = item.modelVersions[0].images[0]) == null ? void 0 : _a.url) ? (openBlock(), createElementBlock("div", _hoisted_4$2, [ - createBaseVNode("img", { - src: item.modelVersions[0].images[0].url, - style: normalizeStyle({ - width: "100%", - height: "auto", - minHeight: "200px", - cursor: "pointer", - borderRadius: "8px", - filter: unref(nsfwIndex)(item.modelVersions[0].images[0].nsfw) > unref(settings).data.settings.frontend.nsfw_ok_threshold ? 
"blur(12px)" : "none" - }), - onClick: ($event) => imgClick(column_index, item_index) - }, null, 12, _hoisted_5$1), - createBaseVNode("div", _hoisted_6$1, [ - createVNode(unref(NText), { depth: 2 }, { - default: withCtx(() => [ - createTextVNode(toDisplayString(item.name), 1) - ]), - _: 2 - }, 1024) - ]) - ])) : createCommentVNode("", true) - ]); - }), 128)) + createVNode(unref(_sfc_main$4), { + item, + item_index, + onImgClick: imgClick + }, null, 8, ["item", "item_index"]) ]); }), 128)) - ]) + ], 512) ], 512) ]) ], 64); }; } }); -const CivitAIDownload = /* @__PURE__ */ _export_sfc(_sfc_main$4, [["__scopeId", "data-v-e10a07d2"]]); +const CivitAIDownload = /* @__PURE__ */ _export_sfc(_sfc_main$5, [["__scopeId", "data-v-89afc237"]]); +const _hoisted_1$3 = ["src"]; +const _hoisted_2$3 = { style: { "position": "absolute", "width": "100%", "bottom": "0", "padding": "0 8px 0 12px", "min-height": "32px", "overflow": "hidden", "box-sizing": "border-box", "backdrop-filter": "blur(12px)", "background-color": "rgba(0, 0, 0, 0.3)" } }; +const _hoisted_3$3 = { style: { "display": "flex", "flex-direction": "column" } }; +const _hoisted_4$2 = { style: { "display": "flex", "justify-content": "space-between", "align-items": "center" } }; +const _hoisted_5$1 = { style: { "display": "flex", "align-items": "center" } }; +const _sfc_main$4 = /* @__PURE__ */ defineComponent({ + __name: "CivitAIModelImage", + props: { + item: { + type: Object, + required: true + }, + item_index: { + type: Number, + required: true + } + }, + emits: ["imgClick"], + setup(__props, { emit }) { + const props = __props; + const settings = useSettings(); + const filterOverride = ref(false); + function convertToShortValue(count) { + if (count < 1e3) + return count; + if (count < 1e6) + return `${(count / 1e3).toFixed(1)}k`; + return `${(count / 1e6).toFixed(1)}m`; + } + return (_ctx, _cache) => { + var _a; + return openBlock(), createElementBlock("div", { + style: { "height": "500px", "cursor": "pointer" }, + onClick: _cache[1] || (_cache[1] = ($event) => emit("imgClick", props.item_index)) + }, [ + createBaseVNode("div", { + style: normalizeStyle({ + filter: unref(nsfwIndex)(props.item.modelVersions[0].images[0].nsfw) > unref(settings).data.settings.frontend.nsfw_ok_threshold && !filterOverride.value ? "blur(32px)" : "none", + width: "100%", + height: "100%" + }) + }, [ + ((_a = props.item.modelVersions[0].images[0]) == null ? void 0 : _a.url) ? 
(openBlock(), createElementBlock("img", { + key: 0, + src: props.item.modelVersions[0].images[0].url, + style: { + width: "100%", + height: "100%", + objectFit: "cover" + } + }, null, 8, _hoisted_1$3)) : createCommentVNode("", true) + ], 4), + createBaseVNode("div", _hoisted_2$3, [ + createBaseVNode("div", _hoisted_3$3, [ + createBaseVNode("div", _hoisted_4$2, [ + createVNode(unref(NRate), { + value: props.item.stats.rating, + "allow-half": "", + size: "small" + }, null, 8, ["value"]), + createBaseVNode("div", _hoisted_5$1, [ + createVNode(unref(NIcon), { + color: "white", + size: "18" + }, { + default: withCtx(() => [ + createVNode(unref(ArrowDownOutline)) + ]), + _: 1 + }), + createVNode(unref(NText), { + size: "small", + style: { "color": "white" } + }, { + default: withCtx(() => [ + createTextVNode(toDisplayString(convertToShortValue(props.item.stats.downloadCount)), 1) + ]), + _: 1 + }), + createVNode(unref(NIcon), { + color: "white", + size: "18", + style: { "margin-left": "4px" } + }, { + default: withCtx(() => [ + createVNode(unref(HeartOutline)) + ]), + _: 1 + }), + createVNode(unref(NText), { + size: "small", + style: { "color": "white" } + }, { + default: withCtx(() => [ + createTextVNode(toDisplayString(convertToShortValue(props.item.stats.favoriteCount)), 1) + ]), + _: 1 + }) + ]) + ]), + createVNode(unref(NText), { depth: 2 }, { + default: withCtx(() => [ + createTextVNode(toDisplayString(__props.item.name), 1) + ]), + _: 1 + }) + ]) + ]), + unref(nsfwIndex)(props.item.modelVersions[0].images[0].nsfw) > unref(settings).data.settings.frontend.nsfw_ok_threshold ? (openBlock(), createBlock(unref(NButton), { + key: 0, + type: "error", + style: { "position": "absolute", "top": "0", "right": "0px", "padding": "0 16px 0 12px", "overflow": "hidden", "box-sizing": "border-box", "border-radius": "0px 0px 0px 8px" }, + onClick: _cache[0] || (_cache[0] = withModifiers(($event) => filterOverride.value = !filterOverride.value, ["stop"])) + }, { + default: withCtx(() => [ + createVNode(unref(NIcon), { + color: "white", + size: "18" + }, { + default: withCtx(() => [ + filterOverride.value ? 
(openBlock(), createBlock(unref(EyeOffOutline), { key: 0 })) : (openBlock(), createBlock(unref(EyeOutline), { key: 1 })) + ]), + _: 1 + }) + ]), + _: 1 + })) : createCommentVNode("", true) + ]); + }; + } +}); const _withScopeId = (n) => (pushScopeId("data-v-b405f046"), n = n(), popScopeId(), n); const _hoisted_1$2 = { style: { "margin": "18px" } }; const _hoisted_2$2 = { style: { "width": "100%", "display": "inline-flex", "justify-content": "space-between", "align-items": "center" } }; diff --git a/frontend/dist/assets/SendOutputTo.vue_vue_type_script_setup_true_lang.js b/frontend/dist/assets/SendOutputTo.vue_vue_type_script_setup_true_lang.js index beac36b50..4dfd8d921 100644 --- a/frontend/dist/assets/SendOutputTo.vue_vue_type_script_setup_true_lang.js +++ b/frontend/dist/assets/SendOutputTo.vue_vue_type_script_setup_true_lang.js @@ -1,4 +1,4 @@ -import { d as defineComponent, o as openBlock, j as createElementBlock, f as createBaseVNode, bK as useRouter, u as useSettings, a as useState, z as ref, b8 as reactive, K as watch, c as computed, g as createVNode, w as withCtx, h as unref, n as NCard, A as NButton, k as createTextVNode, M as NScrollbar, F as Fragment, L as renderList, C as toDisplayString, be as NDivider, bd as NModal, e as createBlock, r as NGi, s as NGrid, m as createCommentVNode } from "./index.js"; +import { d as defineComponent, o as openBlock, a as createElementBlock, b as createBaseVNode, bP as useRouter, u as useSettings, l as useState, B as ref, ba as reactive, K as watch, c as computed, e as createVNode, w as withCtx, f as unref, m as NCard, C as NButton, h as createTextVNode, O as NScrollbar, F as Fragment, M as renderList, E as toDisplayString, bi as NDivider, bh as NModal, g as createBlock, t as NGi, v as NGrid, k as createCommentVNode } from "./index.js"; import { N as NSwitch } from "./Switch.js"; const _hoisted_1$3 = { xmlns: "http://www.w3.org/2000/svg", diff --git a/frontend/dist/assets/Settings.js b/frontend/dist/assets/Settings.js index a780d5357..28edde7de 100644 --- a/frontend/dist/assets/Settings.js +++ b/frontend/dist/assets/Settings.js @@ -1,4 +1,4 @@ -import { y as h, d as defineComponent, S as useConfig, ar as useFormItem, z as ref, c as computed, ae as useMergedState, a3 as provide, X as toRef, P as createInjectionKey, a1 as call, aa as c, Q as cB, ab as cM, at as cE, aT as iconSwitchTransition, aU as insideModal, aV as insidePopover, R as inject, as as useMemo, T as useTheme, bH as checkboxLight, ag as useRtl, ah as createKey, Y as useThemeClass, aW as createId, av as resolveWrappedSlot, aI as NIconSwitchTransition, aD as on, o as openBlock, j as createElementBlock, f as createBaseVNode } from "./index.js"; +import { A as h, d as defineComponent, T as useConfig, as as useFormItem, B as ref, c as computed, af as useMergedState, a4 as provide, Y as toRef, Q as createInjectionKey, a2 as call, ab as c, R as cB, ac as cM, au as cE, aU as iconSwitchTransition, aV as insideModal, aW as insidePopover, S as inject, at as useMemo, U as useTheme, bM as checkboxLight, ah as useRtl, ai as createKey, Z as useThemeClass, aX as createId, aw as resolveWrappedSlot, aJ as NIconSwitchTransition, aE as on, o as openBlock, a as createElementBlock, b as createBaseVNode } from "./index.js"; const CheckMark = h( "svg", { viewBox: "0 0 64 64", class: "check-icon" }, diff --git a/frontend/dist/assets/SettingsView.js b/frontend/dist/assets/SettingsView.js index da6bca715..cb6a2717a 100644 --- a/frontend/dist/assets/SettingsView.js +++ b/frontend/dist/assets/SettingsView.js @@ 
-1,9 +1,10 @@ -import { d as defineComponent, u as useSettings, o as openBlock, e as createBlock, w as withCtx, g as createVNode, h as unref, J as NInput, i as NSelect, n as NCard, b8 as reactive, c as computed, by as convertToTextString, z as ref, t as serverUrl, K as watch, a as useState, f as createBaseVNode, j as createElementBlock, L as renderList, bb as NText, k as createTextVNode, C as toDisplayString, F as Fragment, m as createCommentVNode, D as NTabPane, E as NTabs, R as inject, bz as themeKey, A as NButton, l as NTooltip, p as useMessage, bA as useNotification, q as onUnmounted, bB as defaultSettings } from "./index.js"; -import { a as NFormItem, _ as _sfc_main$h, N as NForm } from "./SamplerPicker.vue_vue_type_script_setup_true_lang.js"; -import { N as NSwitch, a as NSlider } from "./Switch.js"; +import { d as defineComponent, u as useSettings, o as openBlock, g as createBlock, w as withCtx, e as createVNode, f as unref, J as NInput, q as NSelect, m as NCard, ba as reactive, c as computed, bC as convertToTextString, B as ref, x as serverUrl, K as watch, l as useState, b as createBaseVNode, a as createElementBlock, M as renderList, bd as NText, h as createTextVNode, E as toDisplayString, F as Fragment, k as createCommentVNode, n as NTabPane, p as NTabs, S as inject, bD as themeKey, C as NButton, N as NTooltip, r as useMessage, bE as useNotification, s as onUnmounted, bF as defaultSettings } from "./index.js"; +import { c as NFormItem, _ as _sfc_main$g, b as _sfc_main$h, a as _sfc_main$i, N as NForm } from "./Upscale.vue_vue_type_script_setup_true_lang.js"; +import { N as NSwitch } from "./Switch.js"; import { N as NInputNumber } from "./InputNumber.js"; +import { N as NSlider } from "./Slider.js"; import "./Settings.js"; -const _sfc_main$g = /* @__PURE__ */ defineComponent({ +const _sfc_main$f = /* @__PURE__ */ defineComponent({ __name: "ControlNetSettings", setup(__props) { const settings = useSettings(); @@ -177,9 +178,17 @@ const _sfc_main$g = /* @__PURE__ */ defineComponent({ ]), _: 1 }), - createVNode(_sfc_main$h, { + createVNode(unref(_sfc_main$g), { type: "controlnet", target: "defaultSettings" + }), + createVNode(unref(_sfc_main$h), { + tab: "controlnet", + target: "defaultSettings" + }), + createVNode(unref(_sfc_main$i), { + tab: "controlnet", + target: "defaultSettings" }) ]), _: 1 @@ -190,7 +199,7 @@ const _sfc_main$g = /* @__PURE__ */ defineComponent({ }; } }); -const _sfc_main$f = /* @__PURE__ */ defineComponent({ +const _sfc_main$e = /* @__PURE__ */ defineComponent({ __name: "ImageBrowserSettings", setup(__props) { const settings = useSettings(); @@ -220,7 +229,7 @@ const _sfc_main$f = /* @__PURE__ */ defineComponent({ }; } }); -const _sfc_main$e = /* @__PURE__ */ defineComponent({ +const _sfc_main$d = /* @__PURE__ */ defineComponent({ __name: "ImageToImageSettings", setup(__props) { const settings = useSettings(); @@ -354,9 +363,17 @@ const _sfc_main$e = /* @__PURE__ */ defineComponent({ ]), _: 1 }), - createVNode(_sfc_main$h, { + createVNode(unref(_sfc_main$g), { type: "img2img", target: "defaultSettings" + }), + createVNode(unref(_sfc_main$h), { + tab: "img2img", + target: "defaultSettings" + }), + createVNode(unref(_sfc_main$i), { + tab: "img2img", + target: "defaultSettings" }) ]), _: 1 @@ -367,7 +384,7 @@ const _sfc_main$e = /* @__PURE__ */ defineComponent({ }; } }); -const _sfc_main$d = /* @__PURE__ */ defineComponent({ +const _sfc_main$c = /* @__PURE__ */ defineComponent({ __name: "InpaintingSettings", setup(__props) { const settings = useSettings(); @@ 
-488,9 +505,17 @@ const _sfc_main$d = /* @__PURE__ */ defineComponent({ ]), _: 1 }), - createVNode(_sfc_main$h, { + createVNode(unref(_sfc_main$g), { type: "inpainting", target: "defaultSettings" + }), + createVNode(unref(_sfc_main$h), { + tab: "inpainting", + target: "defaultSettings" + }), + createVNode(unref(_sfc_main$i), { + tab: "inpainting", + target: "defaultSettings" }) ]), _: 1 @@ -501,7 +526,7 @@ const _sfc_main$d = /* @__PURE__ */ defineComponent({ }; } }); -const _sfc_main$c = /* @__PURE__ */ defineComponent({ +const _sfc_main$b = /* @__PURE__ */ defineComponent({ __name: "TextToImageSettings", setup(__props) { const settings = useSettings(); @@ -622,9 +647,17 @@ const _sfc_main$c = /* @__PURE__ */ defineComponent({ ]), _: 1 }), - createVNode(_sfc_main$h, { + createVNode(unref(_sfc_main$g), { type: "txt2img", target: "defaultSettings" + }), + createVNode(unref(_sfc_main$h), { + tab: "txt2img", + target: "defaultSettings" + }), + createVNode(unref(_sfc_main$i), { + tab: "txt2img", + target: "defaultSettings" }) ]), _: 1 @@ -635,7 +668,7 @@ const _sfc_main$c = /* @__PURE__ */ defineComponent({ }; } }); -const _sfc_main$b = /* @__PURE__ */ defineComponent({ +const _sfc_main$a = /* @__PURE__ */ defineComponent({ __name: "ThemeSettings", setup(__props) { const settings = useSettings(); @@ -711,7 +744,7 @@ const _sfc_main$b = /* @__PURE__ */ defineComponent({ } }); const _hoisted_1$3 = { style: { "width": "100%" } }; -const _sfc_main$a = /* @__PURE__ */ defineComponent({ +const _sfc_main$9 = /* @__PURE__ */ defineComponent({ __name: "AutoloadSettings", setup(__props) { const settings = useSettings(); @@ -836,7 +869,7 @@ const _sfc_main$a = /* @__PURE__ */ defineComponent({ }; } }); -const _sfc_main$9 = /* @__PURE__ */ defineComponent({ +const _sfc_main$8 = /* @__PURE__ */ defineComponent({ __name: "BotSettings", setup(__props) { const settings = useSettings(); @@ -891,7 +924,7 @@ const _sfc_main$9 = /* @__PURE__ */ defineComponent({ }; } }); -const _sfc_main$8 = /* @__PURE__ */ defineComponent({ +const _sfc_main$7 = /* @__PURE__ */ defineComponent({ __name: "FilesSettings", setup(__props) { const settings = useSettings(); @@ -970,114 +1003,6 @@ const _sfc_main$8 = /* @__PURE__ */ defineComponent({ }; } }); -const _sfc_main$7 = /* @__PURE__ */ defineComponent({ - __name: "FlagsSettings", - setup(__props) { - const settings = useSettings(); - return (_ctx, _cache) => { - return openBlock(), createBlock(unref(NCard), { title: "Hi-res fix" }, { - default: withCtx(() => [ - createVNode(unref(NForm), null, { - default: withCtx(() => [ - createVNode(unref(NFormItem), { - label: "Scale", - "label-placement": "left" - }, { - default: withCtx(() => [ - createVNode(unref(NInputNumber), { - value: unref(settings).defaultSettings.flags.highres.scale, - "onUpdate:value": _cache[0] || (_cache[0] = ($event) => unref(settings).defaultSettings.flags.highres.scale = $event) - }, null, 8, ["value"]) - ]), - _: 1 - }), - createVNode(unref(NFormItem), { - label: "Scaling Mode", - "label-placement": "left" - }, { - default: withCtx(() => [ - createVNode(unref(NSelect), { - options: [ - { - label: "Nearest", - value: "nearest" - }, - { - label: "Linear", - value: "linear" - }, - { - label: "Bilinear", - value: "bilinear" - }, - { - label: "Bicubic", - value: "bicubic" - }, - { - label: "Bislerp (Original, slow)", - value: "bislerp-original" - }, - { - label: "Bislerp (Tortured, fast)", - value: "bislerp-tortured" - }, - { - label: "Nearest Exact", - value: "nearest-exact" - } - ], - value: 
unref(settings).defaultSettings.flags.highres.latent_scale_mode, - "onUpdate:value": _cache[1] || (_cache[1] = ($event) => unref(settings).defaultSettings.flags.highres.latent_scale_mode = $event) - }, null, 8, ["options", "value"]) - ]), - _: 1 - }), - createVNode(unref(NFormItem), { - label: "Strength", - "label-placement": "left" - }, { - default: withCtx(() => [ - createVNode(unref(NInputNumber), { - value: unref(settings).defaultSettings.flags.highres.strength, - "onUpdate:value": _cache[2] || (_cache[2] = ($event) => unref(settings).defaultSettings.flags.highres.strength = $event) - }, null, 8, ["value"]) - ]), - _: 1 - }), - createVNode(unref(NFormItem), { - label: "Steps", - "label-placement": "left" - }, { - default: withCtx(() => [ - createVNode(unref(NInputNumber), { - value: unref(settings).defaultSettings.flags.highres.steps, - "onUpdate:value": _cache[3] || (_cache[3] = ($event) => unref(settings).defaultSettings.flags.highres.steps = $event) - }, null, 8, ["value"]) - ]), - _: 1 - }), - createVNode(unref(NFormItem), { - label: "Antialiased", - "label-placement": "left" - }, { - default: withCtx(() => [ - createVNode(unref(NSwitch), { - value: unref(settings).defaultSettings.flags.highres.antialiased, - "onUpdate:value": _cache[4] || (_cache[4] = ($event) => unref(settings).defaultSettings.flags.highres.antialiased = $event) - }, null, 8, ["value"]) - ]), - _: 1 - }) - ]), - _: 1 - }) - ]), - _: 1 - }); - }; - } -}); const _sfc_main$6 = /* @__PURE__ */ defineComponent({ __name: "FrontendSettings", setup(__props) { @@ -1086,31 +1011,31 @@ const _sfc_main$6 = /* @__PURE__ */ defineComponent({ default: withCtx(() => [ createVNode(unref(NTabPane), { name: "Text to Image" }, { default: withCtx(() => [ - createVNode(unref(_sfc_main$c)) + createVNode(unref(_sfc_main$b)) ]), _: 1 }), createVNode(unref(NTabPane), { name: "Image to Image" }, { default: withCtx(() => [ - createVNode(unref(_sfc_main$e)) + createVNode(unref(_sfc_main$d)) ]), _: 1 }), createVNode(unref(NTabPane), { name: "ControlNet" }, { default: withCtx(() => [ - createVNode(unref(_sfc_main$g)) + createVNode(unref(_sfc_main$f)) ]), _: 1 }), createVNode(unref(NTabPane), { name: "Inpainting" }, { default: withCtx(() => [ - createVNode(unref(_sfc_main$d)) + createVNode(unref(_sfc_main$c)) ]), _: 1 }), createVNode(unref(NTabPane), { name: "Image Browser" }, { default: withCtx(() => [ - createVNode(unref(_sfc_main$f)) + createVNode(unref(_sfc_main$e)) ]), _: 1 }) @@ -1237,7 +1162,7 @@ const _hoisted_1$2 = { key: 0, class: "flex-container" }; -const _hoisted_2$1 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Subquadratic chunk size (affects VRAM usage)", -1); +const _hoisted_2$2 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Subquadratic chunk size (affects VRAM usage)", -1); const _hoisted_3$1 = { "flex-direction": "row" }; const _hoisted_4$1 = { key: 1 }; const _hoisted_5$1 = { key: 2 }; @@ -1319,7 +1244,7 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ _: 1 }), unref(settings).defaultSettings.api.attention_processor == "subquadratic" ? 
(openBlock(), createElementBlock("div", _hoisted_1$2, [ - _hoisted_2$1, + _hoisted_2$2, createVNode(unref(NSlider), { value: unref(settings).defaultSettings.api.subquadratic_size, "onUpdate:value": _cache[2] || (_cache[2] = ($event) => unref(settings).defaultSettings.api.subquadratic_size = $event), @@ -1638,18 +1563,42 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ } }); const _hoisted_1$1 = { key: 1 }; -const _hoisted_2 = { class: "flex-container" }; +const _hoisted_2$1 = { class: "flex-container" }; const _hoisted_3 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Hypertile UNet chunk size", -1); const _hoisted_4 = /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, 'PyTorch ONLY. Recommended sizes are 1/4th your desired resolution or plain "256."', -1); const _hoisted_5 = /* @__PURE__ */ createBaseVNode("b", null, "LARGE (1024x1024+)", -1); const _hoisted_6 = { key: 2 }; const _hoisted_7 = { key: 3 }; -const _hoisted_8 = { key: 4 }; +const _hoisted_8 = { key: 0 }; +const _hoisted_9 = { style: { "margin-bottom": "12px", "display": "flex", "flex-direction": "row", "flex-wrap": "wrap", "gap": "8px 0" } }; const _sfc_main$2 = /* @__PURE__ */ defineComponent({ __name: "ReproducibilitySettings", setup(__props) { const settings = useSettings(); const global = useState(); + const enabledCfg = computed({ + get() { + return settings.defaultSettings.api.cfg_rescale_threshold != "off"; + }, + set(value) { + if (!value) { + settings.defaultSettings.api.cfg_rescale_threshold = "off"; + } else { + settings.defaultSettings.api.cfg_rescale_threshold = 10; + } + } + }); + const cfgRescaleValue = computed({ + get() { + if (settings.defaultSettings.api.cfg_rescale_threshold == "off") { + return 1; + } + return settings.defaultSettings.api.cfg_rescale_threshold; + }, + set(value) { + settings.defaultSettings.api.cfg_rescale_threshold = value; + } + }); const availableDtypes = computed(() => { if (settings.defaultSettings.api.device.includes("cpu")) { return global.state.capabilities.supported_precisions_cpu.map((value) => { @@ -1661,6 +1610,12 @@ const _sfc_main$2 = /* @__PURE__ */ defineComponent({ case "float16": description = "16-bit float"; break; + case "float8_e5m2": + description = "8-bit float (5-data)"; + break; + case "float8_e4m3fn": + description = "8-bit float (4-data)"; + break; default: description = "16-bit bfloat"; } @@ -1676,6 +1631,12 @@ const _sfc_main$2 = /* @__PURE__ */ defineComponent({ case "float16": description = "16-bit float"; break; + case "float8_e5m2": + description = "8-bit float (5-data)"; + break; + case "float8_e4m3fn": + description = "8-bit float (4-data)"; + break; default: description = "16-bit bfloat"; } @@ -1820,7 +1781,7 @@ const _sfc_main$2 = /* @__PURE__ */ defineComponent({ _: 1 }), unref(settings).defaultSettings.api.hypertile ? (openBlock(), createElementBlock("div", _hoisted_1$1, [ - createBaseVNode("div", _hoisted_2, [ + createBaseVNode("div", _hoisted_2$1, [ createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { trigger: withCtx(() => [ _hoisted_3 @@ -2004,60 +1965,178 @@ const _sfc_main$2 = /* @__PURE__ */ defineComponent({ ]), _: 1 }), - unref(settings).defaultSettings.api.free_u ? 
(openBlock(), createElementBlock("div", _hoisted_8, [ - createVNode(unref(NFormItem), { - label: "Free U S1", - "label-placement": "left" - }, { - default: withCtx(() => [ - createVNode(unref(NInputNumber), { - value: unref(settings).defaultSettings.api.free_u_s1, - "onUpdate:value": _cache[19] || (_cache[19] = ($event) => unref(settings).defaultSettings.api.free_u_s1 = $event), - step: 0.01 - }, null, 8, ["value"]) - ]), - _: 1 - }), - createVNode(unref(NFormItem), { - label: "Free U S2", - "label-placement": "left" - }, { - default: withCtx(() => [ - createVNode(unref(NInputNumber), { - value: unref(settings).defaultSettings.api.free_u_s2, - "onUpdate:value": _cache[20] || (_cache[20] = ($event) => unref(settings).defaultSettings.api.free_u_s2 = $event), - step: 0.01 - }, null, 8, ["value"]) - ]), - _: 1 - }), - createVNode(unref(NFormItem), { - label: "Free U B1", - "label-placement": "left" - }, { - default: withCtx(() => [ - createVNode(unref(NInputNumber), { - value: unref(settings).defaultSettings.api.free_u_b1, - "onUpdate:value": _cache[21] || (_cache[21] = ($event) => unref(settings).defaultSettings.api.free_u_b1 = $event), - step: 0.01 - }, null, 8, ["value"]) - ]), - _: 1 - }), - createVNode(unref(NFormItem), { - label: "Free U B2", - "label-placement": "left" - }, { - default: withCtx(() => [ - createVNode(unref(NInputNumber), { - value: unref(settings).defaultSettings.api.free_u_b2, - "onUpdate:value": _cache[22] || (_cache[22] = ($event) => unref(settings).defaultSettings.api.free_u_b2 = $event), - step: 0.01 - }, null, 8, ["value"]) - ]), - _: 1 - }) - ])) : createCommentVNode("", true) + createVNode(unref(NCard), { + bordered: false, + style: { "margin-bottom": "12px" } + }, { + default: withCtx(() => [ + unref(settings).defaultSettings.api.free_u ? 
(openBlock(), createElementBlock("div", _hoisted_8, [ + createBaseVNode("div", _hoisted_9, [ + createVNode(unref(NButton), { + style: { "margin-left": "12px" }, + ghost: "", + type: "info", + onClick: _cache[19] || (_cache[19] = () => { + unref(settings).defaultSettings.api.free_u_b1 = 1.3; + unref(settings).defaultSettings.api.free_u_b2 = 1.4; + unref(settings).defaultSettings.api.free_u_s1 = 0.9; + unref(settings).defaultSettings.api.free_u_s2 = 0.2; + }) + }, { + default: withCtx(() => [ + createTextVNode(" Apply SD 1.4 Defaults ") + ]), + _: 1 + }), + createVNode(unref(NButton), { + style: { "margin-left": "12px" }, + ghost: "", + type: "warning", + onClick: _cache[20] || (_cache[20] = () => { + unref(settings).defaultSettings.api.free_u_b1 = 1.5; + unref(settings).defaultSettings.api.free_u_b2 = 1.6; + unref(settings).defaultSettings.api.free_u_s1 = 0.9; + unref(settings).defaultSettings.api.free_u_s2 = 0.2; + }) + }, { + default: withCtx(() => [ + createTextVNode(" Apply SD 1.5 Defaults ") + ]), + _: 1 + }), + createVNode(unref(NButton), { + style: { "margin-left": "12px" }, + ghost: "", + type: "success", + onClick: _cache[21] || (_cache[21] = () => { + unref(settings).defaultSettings.api.free_u_b1 = 1.4; + unref(settings).defaultSettings.api.free_u_b2 = 1.6; + unref(settings).defaultSettings.api.free_u_s1 = 0.9; + unref(settings).defaultSettings.api.free_u_s2 = 0.2; + }) + }, { + default: withCtx(() => [ + createTextVNode(" Apply SD 2.1 Defaults ") + ]), + _: 1 + }), + createVNode(unref(NButton), { + style: { "margin-left": "12px" }, + ghost: "", + type: "error", + onClick: _cache[22] || (_cache[22] = () => { + unref(settings).defaultSettings.api.free_u_b1 = 1.3; + unref(settings).defaultSettings.api.free_u_b2 = 1.4; + unref(settings).defaultSettings.api.free_u_s1 = 0.9; + unref(settings).defaultSettings.api.free_u_s2 = 0.2; + }) + }, { + default: withCtx(() => [ + createTextVNode(" Apply SDXL Defaults ") + ]), + _: 1 + }) + ]), + createVNode(unref(NFormItem), { + label: "Free U B1", + "label-placement": "left" + }, { + default: withCtx(() => [ + createVNode(unref(NInputNumber), { + value: unref(settings).defaultSettings.api.free_u_b1, + "onUpdate:value": _cache[23] || (_cache[23] = ($event) => unref(settings).defaultSettings.api.free_u_b1 = $event), + step: 0.01 + }, null, 8, ["value"]) + ]), + _: 1 + }), + createVNode(unref(NFormItem), { + label: "Free U B2", + "label-placement": "left" + }, { + default: withCtx(() => [ + createVNode(unref(NInputNumber), { + value: unref(settings).defaultSettings.api.free_u_b2, + "onUpdate:value": _cache[24] || (_cache[24] = ($event) => unref(settings).defaultSettings.api.free_u_b2 = $event), + step: 0.01 + }, null, 8, ["value"]) + ]), + _: 1 + }), + createVNode(unref(NFormItem), { + label: "Free U S1", + "label-placement": "left" + }, { + default: withCtx(() => [ + createVNode(unref(NInputNumber), { + value: unref(settings).defaultSettings.api.free_u_s1, + "onUpdate:value": _cache[25] || (_cache[25] = ($event) => unref(settings).defaultSettings.api.free_u_s1 = $event), + step: 0.01 + }, null, 8, ["value"]) + ]), + _: 1 + }), + createVNode(unref(NFormItem), { + label: "Free U S2", + "label-placement": "left" + }, { + default: withCtx(() => [ + createVNode(unref(NInputNumber), { + value: unref(settings).defaultSettings.api.free_u_s2, + "onUpdate:value": _cache[26] || (_cache[26] = ($event) => unref(settings).defaultSettings.api.free_u_s2 = $event), + step: 0.01 + }, null, 8, ["value"]) + ]), + _: 1 + }) + ])) : createCommentVNode("", true) + ]), + 
_: 1 + }), + createVNode(unref(NFormItem), { + label: "Upcast VAE", + "label-placement": "left" + }, { + default: withCtx(() => [ + createVNode(unref(NSwitch), { + value: unref(settings).defaultSettings.api.upcast_vae, + "onUpdate:value": _cache[27] || (_cache[27] = ($event) => unref(settings).defaultSettings.api.upcast_vae = $event) + }, null, 8, ["value"]) + ]), + _: 1 + }), + createVNode(unref(NFormItem), { + label: "Apply unsharp mask", + "label-placement": "left" + }, { + default: withCtx(() => [ + createVNode(unref(NSwitch), { + value: unref(settings).defaultSettings.api.apply_unsharp_mask, + "onUpdate:value": _cache[28] || (_cache[28] = ($event) => unref(settings).defaultSettings.api.apply_unsharp_mask = $event) + }, null, 8, ["value"]) + ]), + _: 1 + }), + createVNode(unref(NFormItem), { + label: "CFG Rescale Threshold", + "label-placement": "left" + }, { + default: withCtx(() => [ + createVNode(unref(NSlider), { + value: cfgRescaleValue.value, + "onUpdate:value": _cache[29] || (_cache[29] = ($event) => cfgRescaleValue.value = $event), + disabled: !enabledCfg.value, + min: 2, + max: 30, + step: 0.5 + }, null, 8, ["value", "disabled"]), + createVNode(unref(NSwitch), { + value: enabledCfg.value, + "onUpdate:value": _cache[30] || (_cache[30] = ($event) => enabledCfg.value = $event) + }, null, 8, ["value"]) + ]), + _: 1 + }) ]), _: 1 }); @@ -2136,6 +2215,7 @@ const _sfc_main$1 = /* @__PURE__ */ defineComponent({ } }); const _hoisted_1 = { class: "main-container" }; +const _hoisted_2 = { style: { "width": "100%", "display": "flex", "justify-content": "end", "margin-bottom": "12px" } }; const _sfc_main = /* @__PURE__ */ defineComponent({ __name: "SettingsView", setup(__props) { @@ -2154,25 +2234,16 @@ const _sfc_main = /* @__PURE__ */ defineComponent({ } function saveSettings() { saving.value = true; - fetch(`${serverUrl}/api/settings/save`, { - method: "POST", - headers: { - "Content-Type": "application/json" - }, - body: JSON.stringify(settings.defaultSettings) - }).then((res) => { - if (res.status === 200) { - message.success("Settings saved successfully"); - } else { - res.json().then((data) => { - message.error("Error while saving settings"); - notification.create({ - title: "Error while saving settings", - content: data.message, - type: "error" - }); - }); - } + settings.saveSettings().then(() => { + message.success("Settings saved"); + }).catch((e) => { + message.error("Failed to save settings"); + notification.create({ + title: "Failed to save settings", + content: e, + type: "error" + }); + }).finally(() => { saving.value = false; }); } @@ -2181,43 +2252,43 @@ const _sfc_main = /* @__PURE__ */ defineComponent({ }); return (_ctx, _cache) => { return openBlock(), createElementBlock("div", _hoisted_1, [ + createBaseVNode("div", _hoisted_2, [ + createVNode(unref(NButton), { + type: "error", + ghost: "", + style: { "margin-right": "12px" }, + onClick: resetSettings + }, { + default: withCtx(() => [ + createTextVNode("Reset Settings") + ]), + _: 1 + }), + createVNode(unref(NButton), { + type: "success", + ghost: "", + onClick: saveSettings, + loading: saving.value + }, { + default: withCtx(() => [ + createTextVNode("Save Settings") + ]), + _: 1 + }, 8, ["loading"]) + ]), createVNode(unref(NCard), null, { default: withCtx(() => [ createVNode(unref(NTabs), null, { - suffix: withCtx(() => [ - createVNode(unref(NButton), { - type: "error", - ghost: "", - style: { "margin-right": "12px" }, - onClick: resetSettings - }, { - default: withCtx(() => [ - createTextVNode("Reset Settings") - ]), 
- _: 1 - }), - createVNode(unref(NButton), { - type: "success", - ghost: "", - onClick: saveSettings, - loading: saving.value - }, { - default: withCtx(() => [ - createTextVNode("Save Settings") - ]), - _: 1 - }, 8, ["loading"]) - ]), default: withCtx(() => [ createVNode(unref(NTabPane), { name: "Autoload" }, { default: withCtx(() => [ - createVNode(unref(_sfc_main$a)) + createVNode(unref(_sfc_main$9)) ]), _: 1 }), createVNode(unref(NTabPane), { name: "Files & Saving" }, { default: withCtx(() => [ - createVNode(unref(_sfc_main$8)) + createVNode(unref(_sfc_main$7)) ]), _: 1 }), @@ -2247,7 +2318,7 @@ const _sfc_main = /* @__PURE__ */ defineComponent({ }), createVNode(unref(NTabPane), { name: "Bot" }, { default: withCtx(() => [ - createVNode(unref(_sfc_main$9)) + createVNode(unref(_sfc_main$8)) ]), _: 1 }), @@ -2257,15 +2328,9 @@ const _sfc_main = /* @__PURE__ */ defineComponent({ ]), _: 1 }), - createVNode(unref(NTabPane), { name: "Flags" }, { - default: withCtx(() => [ - createVNode(unref(_sfc_main$7)) - ]), - _: 1 - }), createVNode(unref(NTabPane), { name: "Theme" }, { default: withCtx(() => [ - createVNode(unref(_sfc_main$b)) + createVNode(unref(_sfc_main$a)) ]), _: 1 }), diff --git a/frontend/dist/assets/Slider.js b/frontend/dist/assets/Slider.js new file mode 100644 index 000000000..5d516c602 --- /dev/null +++ b/frontend/dist/assets/Slider.js @@ -0,0 +1,727 @@ +import { B as ref, br as onBeforeUpdate, ab as c, R as cB, ac as cM, au as cE, aT as fadeInScaleUpTransition, aV as insideModal, aW as insidePopover, d as defineComponent, T as useConfig, U as useTheme, as as useFormItem, c as computed, Y as toRef, af as useMergedState, K as watch, X as nextTick, aC as onBeforeUnmount, Z as useThemeClass, bT as isMounted, al as useAdjustedTo, A as h, c0 as VBinder, c1 as VTarget, aj as resolveSlot, c2 as VFollower, aY as Transition, c3 as sliderLight, aE as on, aD as off, a2 as call } from "./index.js"; +function isTouchEvent(e) { + return window.TouchEvent && e instanceof window.TouchEvent; +} +function useRefs() { + const refs = ref(/* @__PURE__ */ new Map()); + const setRefs = (index) => (el) => { + refs.value.set(index, el); + }; + onBeforeUpdate(() => { + refs.value.clear(); + }); + return [refs, setRefs]; +} +const style = c([cB("slider", ` + display: block; + padding: calc((var(--n-handle-size) - var(--n-rail-height)) / 2) 0; + position: relative; + z-index: 0; + width: 100%; + cursor: pointer; + user-select: none; + -webkit-user-select: none; + `, [cM("reverse", [cB("slider-handles", [cB("slider-handle-wrapper", ` + transform: translate(50%, -50%); + `)]), cB("slider-dots", [cB("slider-dot", ` + transform: translateX(50%, -50%); + `)]), cM("vertical", [cB("slider-handles", [cB("slider-handle-wrapper", ` + transform: translate(-50%, -50%); + `)]), cB("slider-marks", [cB("slider-mark", ` + transform: translateY(calc(-50% + var(--n-dot-height) / 2)); + `)]), cB("slider-dots", [cB("slider-dot", ` + transform: translateX(-50%) translateY(0); + `)])])]), cM("vertical", ` + padding: 0 calc((var(--n-handle-size) - var(--n-rail-height)) / 2); + width: var(--n-rail-width-vertical); + height: 100%; + `, [cB("slider-handles", ` + top: calc(var(--n-handle-size) / 2); + right: 0; + bottom: calc(var(--n-handle-size) / 2); + left: 0; + `, [cB("slider-handle-wrapper", ` + top: unset; + left: 50%; + transform: translate(-50%, 50%); + `)]), cB("slider-rail", ` + height: 100%; + `, [cE("fill", ` + top: unset; + right: 0; + bottom: unset; + left: 0; + `)]), cM("with-mark", ` + width: 
var(--n-rail-width-vertical); + margin: 0 32px 0 8px; + `), cB("slider-marks", ` + top: calc(var(--n-handle-size) / 2); + right: unset; + bottom: calc(var(--n-handle-size) / 2); + left: 22px; + font-size: var(--n-mark-font-size); + `, [cB("slider-mark", ` + transform: translateY(50%); + white-space: nowrap; + `)]), cB("slider-dots", ` + top: calc(var(--n-handle-size) / 2); + right: unset; + bottom: calc(var(--n-handle-size) / 2); + left: 50%; + `, [cB("slider-dot", ` + transform: translateX(-50%) translateY(50%); + `)])]), cM("disabled", ` + cursor: not-allowed; + opacity: var(--n-opacity-disabled); + `, [cB("slider-handle", ` + cursor: not-allowed; + `)]), cM("with-mark", ` + width: 100%; + margin: 8px 0 32px 0; + `), c("&:hover", [cB("slider-rail", { + backgroundColor: "var(--n-rail-color-hover)" +}, [cE("fill", { + backgroundColor: "var(--n-fill-color-hover)" +})]), cB("slider-handle", { + boxShadow: "var(--n-handle-box-shadow-hover)" +})]), cM("active", [cB("slider-rail", { + backgroundColor: "var(--n-rail-color-hover)" +}, [cE("fill", { + backgroundColor: "var(--n-fill-color-hover)" +})]), cB("slider-handle", { + boxShadow: "var(--n-handle-box-shadow-hover)" +})]), cB("slider-marks", ` + position: absolute; + top: 18px; + left: calc(var(--n-handle-size) / 2); + right: calc(var(--n-handle-size) / 2); + `, [cB("slider-mark", ` + position: absolute; + transform: translateX(-50%); + white-space: nowrap; + `)]), cB("slider-rail", ` + width: 100%; + position: relative; + height: var(--n-rail-height); + background-color: var(--n-rail-color); + transition: background-color .3s var(--n-bezier); + border-radius: calc(var(--n-rail-height) / 2); + `, [cE("fill", ` + position: absolute; + top: 0; + bottom: 0; + border-radius: calc(var(--n-rail-height) / 2); + transition: background-color .3s var(--n-bezier); + background-color: var(--n-fill-color); + `)]), cB("slider-handles", ` + position: absolute; + top: 0; + right: calc(var(--n-handle-size) / 2); + bottom: 0; + left: calc(var(--n-handle-size) / 2); + `, [cB("slider-handle-wrapper", ` + outline: none; + position: absolute; + top: 50%; + transform: translate(-50%, -50%); + cursor: pointer; + display: flex; + `, [cB("slider-handle", ` + height: var(--n-handle-size); + width: var(--n-handle-size); + border-radius: 50%; + overflow: hidden; + transition: box-shadow .2s var(--n-bezier), background-color .3s var(--n-bezier); + background-color: var(--n-handle-color); + box-shadow: var(--n-handle-box-shadow); + `, [c("&:hover", ` + box-shadow: var(--n-handle-box-shadow-hover); + `)]), c("&:focus", [cB("slider-handle", ` + box-shadow: var(--n-handle-box-shadow-focus); + `, [c("&:hover", ` + box-shadow: var(--n-handle-box-shadow-active); + `)])])])]), cB("slider-dots", ` + position: absolute; + top: 50%; + left: calc(var(--n-handle-size) / 2); + right: calc(var(--n-handle-size) / 2); + `, [cM("transition-disabled", [cB("slider-dot", "transition: none;")]), cB("slider-dot", ` + transition: + border-color .3s var(--n-bezier), + box-shadow .3s var(--n-bezier), + background-color .3s var(--n-bezier); + position: absolute; + transform: translate(-50%, -50%); + height: var(--n-dot-height); + width: var(--n-dot-width); + border-radius: var(--n-dot-border-radius); + overflow: hidden; + box-sizing: border-box; + border: var(--n-dot-border); + background-color: var(--n-dot-color); + `, [cM("active", "border: var(--n-dot-border-active);")])])]), cB("slider-handle-indicator", ` + font-size: var(--n-font-size); + padding: 6px 10px; + border-radius: 
var(--n-indicator-border-radius); + color: var(--n-indicator-text-color); + background-color: var(--n-indicator-color); + box-shadow: var(--n-indicator-box-shadow); + `, [fadeInScaleUpTransition()]), cB("slider-handle-indicator", ` + font-size: var(--n-font-size); + padding: 6px 10px; + border-radius: var(--n-indicator-border-radius); + color: var(--n-indicator-text-color); + background-color: var(--n-indicator-color); + box-shadow: var(--n-indicator-box-shadow); + `, [cM("top", ` + margin-bottom: 12px; + `), cM("right", ` + margin-left: 12px; + `), cM("bottom", ` + margin-top: 12px; + `), cM("left", ` + margin-right: 12px; + `), fadeInScaleUpTransition()]), insideModal(cB("slider", [cB("slider-dot", "background-color: var(--n-dot-color-modal);")])), insidePopover(cB("slider", [cB("slider-dot", "background-color: var(--n-dot-color-popover);")]))]); +const eventButtonLeft = 0; +const sliderProps = Object.assign(Object.assign({}, useTheme.props), { to: useAdjustedTo.propTo, defaultValue: { + type: [Number, Array], + default: 0 +}, marks: Object, disabled: { + type: Boolean, + default: void 0 +}, formatTooltip: Function, keyboard: { + type: Boolean, + default: true +}, min: { + type: Number, + default: 0 +}, max: { + type: Number, + default: 100 +}, step: { + type: [Number, String], + default: 1 +}, range: Boolean, value: [Number, Array], placement: String, showTooltip: { + type: Boolean, + default: void 0 +}, tooltip: { + type: Boolean, + default: true +}, vertical: Boolean, reverse: Boolean, "onUpdate:value": [Function, Array], onUpdateValue: [Function, Array] }); +const NSlider = defineComponent({ + name: "Slider", + props: sliderProps, + setup(props) { + const { mergedClsPrefixRef, namespaceRef, inlineThemeDisabled } = useConfig(props); + const themeRef = useTheme("Slider", "-slider", style, sliderLight, props, mergedClsPrefixRef); + const handleRailRef = ref(null); + const [handleRefs, setHandleRefs] = useRefs(); + const [followerRefs, setFollowerRefs] = useRefs(); + const followerEnabledIndexSetRef = ref(/* @__PURE__ */ new Set()); + const formItem = useFormItem(props); + const { mergedDisabledRef } = formItem; + const precisionRef = computed(() => { + const { step } = props; + if (Number(step) <= 0 || step === "mark") + return 0; + const stepString = step.toString(); + let precision = 0; + if (stepString.includes(".")) { + precision = stepString.length - stepString.indexOf(".") - 1; + } + return precision; + }); + const uncontrolledValueRef = ref(props.defaultValue); + const controlledValueRef = toRef(props, "value"); + const mergedValueRef = useMergedState(controlledValueRef, uncontrolledValueRef); + const arrifiedValueRef = computed(() => { + const { value: mergedValue } = mergedValueRef; + return (props.range ? mergedValue : [mergedValue]).map(clampValue); + }); + const handleCountExceeds2Ref = computed(() => arrifiedValueRef.value.length > 2); + const mergedPlacementRef = computed(() => { + return props.placement === void 0 ? props.vertical ? "right" : "top" : props.placement; + }); + const markValuesRef = computed(() => { + const { marks } = props; + return marks ? Object.keys(marks).map(parseFloat) : null; + }); + const activeIndexRef = ref(-1); + const previousIndexRef = ref(-1); + const hoverIndexRef = ref(-1); + const draggingRef = ref(false); + const dotTransitionDisabledRef = ref(false); + const styleDirectionRef = computed(() => { + const { vertical, reverse } = props; + const left = reverse ? "right" : "left"; + const bottom = reverse ? 
"top" : "bottom"; + return vertical ? bottom : left; + }); + const fillStyleRef = computed(() => { + if (handleCountExceeds2Ref.value) + return; + const values = arrifiedValueRef.value; + const start = valueToPercentage(props.range ? Math.min(...values) : props.min); + const end = valueToPercentage(props.range ? Math.max(...values) : values[0]); + const { value: styleDirection } = styleDirectionRef; + return props.vertical ? { + [styleDirection]: `${start}%`, + height: `${end - start}%` + } : { + [styleDirection]: `${start}%`, + width: `${end - start}%` + }; + }); + const markInfosRef = computed(() => { + const mergedMarks = []; + const { marks } = props; + if (marks) { + const orderValues = arrifiedValueRef.value.slice(); + orderValues.sort((a, b) => a - b); + const { value: styleDirection } = styleDirectionRef; + const { value: handleCountExceeds2 } = handleCountExceeds2Ref; + const { range } = props; + const isActive = handleCountExceeds2 ? () => false : (num) => range ? num >= orderValues[0] && num <= orderValues[orderValues.length - 1] : num <= orderValues[0]; + for (const key of Object.keys(marks)) { + const num = Number(key); + mergedMarks.push({ + active: isActive(num), + label: marks[key], + style: { + [styleDirection]: `${valueToPercentage(num)}%` + } + }); + } + } + return mergedMarks; + }); + function getHandleStyle(value, index) { + const percentage = valueToPercentage(value); + const { value: styleDirection } = styleDirectionRef; + return { + [styleDirection]: `${percentage}%`, + zIndex: index === activeIndexRef.value ? 1 : 0 + }; + } + function isShowTooltip(index) { + return props.showTooltip || hoverIndexRef.value === index || activeIndexRef.value === index && draggingRef.value; + } + function shouldKeepTooltipTransition(index) { + if (!draggingRef.value) + return true; + return !(activeIndexRef.value === index && previousIndexRef.value === index); + } + function focusActiveHandle(index) { + var _a; + if (~index) { + activeIndexRef.value = index; + (_a = handleRefs.value.get(index)) === null || _a === void 0 ? void 0 : _a.focus(); + } + } + function syncPosition() { + followerRefs.value.forEach((inst, index) => { + if (isShowTooltip(index)) + inst.syncPosition(); + }); + } + function doUpdateValue(value) { + const { "onUpdate:value": _onUpdateValue, onUpdateValue } = props; + const { nTriggerFormInput, nTriggerFormChange } = formItem; + if (onUpdateValue) + call(onUpdateValue, value); + if (_onUpdateValue) + call(_onUpdateValue, value); + uncontrolledValueRef.value = value; + nTriggerFormInput(); + nTriggerFormChange(); + } + function dispatchValueUpdate(value) { + const { range } = props; + if (range) { + if (Array.isArray(value)) { + const { value: oldValues } = arrifiedValueRef; + if (value.join() !== oldValues.join()) { + doUpdateValue(value); + } + } + } else if (!Array.isArray(value)) { + const oldValue = arrifiedValueRef.value[0]; + if (oldValue !== value) { + doUpdateValue(value); + } + } + } + function doDispatchValue(value, index) { + if (props.range) { + const values = arrifiedValueRef.value.slice(); + values.splice(index, 1, value); + dispatchValueUpdate(values); + } else { + dispatchValueUpdate(value); + } + } + function sanitizeValue(value, currentValue, stepBuffer) { + const stepping = stepBuffer !== void 0; + if (!stepBuffer) { + stepBuffer = value - currentValue > 0 ? 
1 : -1; + } + const markValues = markValuesRef.value || []; + const { step } = props; + if (step === "mark") { + const closestMark2 = getClosestMark(value, markValues.concat(currentValue), stepping ? stepBuffer : void 0); + return closestMark2 ? closestMark2.value : currentValue; + } + if (step <= 0) + return currentValue; + const { value: precision } = precisionRef; + let closestMark; + if (stepping) { + const currentStep = Number((currentValue / step).toFixed(precision)); + const actualStep = Math.floor(currentStep); + const leftStep = currentStep > actualStep ? actualStep : actualStep - 1; + const rightStep = currentStep < actualStep ? actualStep : actualStep + 1; + closestMark = getClosestMark(currentValue, [ + Number((leftStep * step).toFixed(precision)), + Number((rightStep * step).toFixed(precision)), + ...markValues + ], stepBuffer); + } else { + const roundValue = getRoundValue(value); + closestMark = getClosestMark(value, [...markValues, roundValue]); + } + return closestMark ? clampValue(closestMark.value) : currentValue; + } + function clampValue(value) { + return Math.min(props.max, Math.max(props.min, value)); + } + function valueToPercentage(value) { + const { max, min } = props; + return (value - min) / (max - min) * 100; + } + function percentageToValue(percentage) { + const { max, min } = props; + return min + (max - min) * percentage; + } + function getRoundValue(value) { + const { step, min } = props; + if (Number(step) <= 0 || step === "mark") + return value; + const newValue = Math.round((value - min) / step) * step + min; + return Number(newValue.toFixed(precisionRef.value)); + } + function getClosestMark(currentValue, markValues = markValuesRef.value, buffer) { + if (!(markValues === null || markValues === void 0 ? void 0 : markValues.length)) + return null; + let closestMark = null; + let index = -1; + while (++index < markValues.length) { + const diff = markValues[index] - currentValue; + const distance = Math.abs(diff); + if ( + // find marks in the same direction + (buffer === void 0 || diff * buffer > 0) && (closestMark === null || distance < closestMark.distance) + ) { + closestMark = { + index, + distance, + value: markValues[index] + }; + } + } + return closestMark; + } + function getPointValue(event) { + const railEl = handleRailRef.value; + if (!railEl) + return; + const touchEvent = isTouchEvent(event) ? event.touches[0] : event; + const railRect = railEl.getBoundingClientRect(); + let percentage; + if (props.vertical) { + percentage = (railRect.bottom - touchEvent.clientY) / railRect.height; + } else { + percentage = (touchEvent.clientX - railRect.left) / railRect.width; + } + if (props.reverse) { + percentage = 1 - percentage; + } + return percentageToValue(percentage); + } + function handleRailKeyDown(e) { + if (mergedDisabledRef.value || !props.keyboard) + return; + const { vertical, reverse } = props; + switch (e.key) { + case "ArrowUp": + e.preventDefault(); + handleStepValue(vertical && reverse ? -1 : 1); + break; + case "ArrowRight": + e.preventDefault(); + handleStepValue(!vertical && reverse ? -1 : 1); + break; + case "ArrowDown": + e.preventDefault(); + handleStepValue(vertical && reverse ? 1 : -1); + break; + case "ArrowLeft": + e.preventDefault(); + handleStepValue(!vertical && reverse ? 
1 : -1); + break; + } + } + function handleStepValue(ratio) { + const activeIndex = activeIndexRef.value; + if (activeIndex === -1) + return; + const { step } = props; + const currentValue = arrifiedValueRef.value[activeIndex]; + const nextValue = Number(step) <= 0 || step === "mark" ? currentValue : currentValue + step * ratio; + doDispatchValue( + // Avoid the number of value does not change when `step` is null + sanitizeValue(nextValue, currentValue, ratio > 0 ? 1 : -1), + activeIndex + ); + } + function handleRailMouseDown(event) { + var _a, _b; + if (mergedDisabledRef.value) + return; + if (!isTouchEvent(event) && event.button !== eventButtonLeft) { + return; + } + const pointValue = getPointValue(event); + if (pointValue === void 0) + return; + const values = arrifiedValueRef.value.slice(); + const activeIndex = props.range ? (_b = (_a = getClosestMark(pointValue, values)) === null || _a === void 0 ? void 0 : _a.index) !== null && _b !== void 0 ? _b : -1 : 0; + if (activeIndex !== -1) { + event.preventDefault(); + focusActiveHandle(activeIndex); + startDragging(); + doDispatchValue(sanitizeValue(pointValue, arrifiedValueRef.value[activeIndex]), activeIndex); + } + } + function startDragging() { + if (!draggingRef.value) { + draggingRef.value = true; + on("touchend", document, handleMouseUp); + on("mouseup", document, handleMouseUp); + on("touchmove", document, handleMouseMove); + on("mousemove", document, handleMouseMove); + } + } + function stopDragging() { + if (draggingRef.value) { + draggingRef.value = false; + off("touchend", document, handleMouseUp); + off("mouseup", document, handleMouseUp); + off("touchmove", document, handleMouseMove); + off("mousemove", document, handleMouseMove); + } + } + function handleMouseMove(event) { + const { value: activeIndex } = activeIndexRef; + if (!draggingRef.value || activeIndex === -1) { + stopDragging(); + return; + } + const pointValue = getPointValue(event); + doDispatchValue(sanitizeValue(pointValue, arrifiedValueRef.value[activeIndex]), activeIndex); + } + function handleMouseUp() { + stopDragging(); + } + function handleHandleFocus(index) { + activeIndexRef.value = index; + if (!mergedDisabledRef.value) { + hoverIndexRef.value = index; + } + } + function handleHandleBlur(index) { + if (activeIndexRef.value === index) { + activeIndexRef.value = -1; + stopDragging(); + } + if (hoverIndexRef.value === index) { + hoverIndexRef.value = -1; + } + } + function handleHandleMouseEnter(index) { + hoverIndexRef.value = index; + } + function handleHandleMouseLeave(index) { + if (hoverIndexRef.value === index) { + hoverIndexRef.value = -1; + } + } + watch(activeIndexRef, (_, previous) => void nextTick(() => previousIndexRef.value = previous)); + watch(mergedValueRef, () => { + if (props.marks) { + if (dotTransitionDisabledRef.value) + return; + dotTransitionDisabledRef.value = true; + void nextTick(() => { + dotTransitionDisabledRef.value = false; + }); + } + void nextTick(syncPosition); + }); + onBeforeUnmount(() => { + stopDragging(); + }); + const cssVarsRef = computed(() => { + const { self: { markFontSize, railColor, railColorHover, fillColor, fillColorHover, handleColor, opacityDisabled, dotColor, dotColorModal, handleBoxShadow, handleBoxShadowHover, handleBoxShadowActive, handleBoxShadowFocus, dotBorder, dotBoxShadow, railHeight, railWidthVertical, handleSize, dotHeight, dotWidth, dotBorderRadius, fontSize, dotBorderActive, dotColorPopover }, common: { cubicBezierEaseInOut } } = themeRef.value; + return { + "--n-bezier": 
cubicBezierEaseInOut, + "--n-dot-border": dotBorder, + "--n-dot-border-active": dotBorderActive, + "--n-dot-border-radius": dotBorderRadius, + "--n-dot-box-shadow": dotBoxShadow, + "--n-dot-color": dotColor, + "--n-dot-color-modal": dotColorModal, + "--n-dot-color-popover": dotColorPopover, + "--n-dot-height": dotHeight, + "--n-dot-width": dotWidth, + "--n-fill-color": fillColor, + "--n-fill-color-hover": fillColorHover, + "--n-font-size": fontSize, + "--n-handle-box-shadow": handleBoxShadow, + "--n-handle-box-shadow-active": handleBoxShadowActive, + "--n-handle-box-shadow-focus": handleBoxShadowFocus, + "--n-handle-box-shadow-hover": handleBoxShadowHover, + "--n-handle-color": handleColor, + "--n-handle-size": handleSize, + "--n-opacity-disabled": opacityDisabled, + "--n-rail-color": railColor, + "--n-rail-color-hover": railColorHover, + "--n-rail-height": railHeight, + "--n-rail-width-vertical": railWidthVertical, + "--n-mark-font-size": markFontSize + }; + }); + const themeClassHandle = inlineThemeDisabled ? useThemeClass("slider", void 0, cssVarsRef, props) : void 0; + const indicatorCssVarsRef = computed(() => { + const { self: { fontSize, indicatorColor, indicatorBoxShadow, indicatorTextColor, indicatorBorderRadius } } = themeRef.value; + return { + "--n-font-size": fontSize, + "--n-indicator-border-radius": indicatorBorderRadius, + "--n-indicator-box-shadow": indicatorBoxShadow, + "--n-indicator-color": indicatorColor, + "--n-indicator-text-color": indicatorTextColor + }; + }); + const indicatorThemeClassHandle = inlineThemeDisabled ? useThemeClass("slider-indicator", void 0, indicatorCssVarsRef, props) : void 0; + return { + mergedClsPrefix: mergedClsPrefixRef, + namespace: namespaceRef, + uncontrolledValue: uncontrolledValueRef, + mergedValue: mergedValueRef, + mergedDisabled: mergedDisabledRef, + mergedPlacement: mergedPlacementRef, + isMounted: isMounted(), + adjustedTo: useAdjustedTo(props), + dotTransitionDisabled: dotTransitionDisabledRef, + markInfos: markInfosRef, + isShowTooltip, + shouldKeepTooltipTransition, + handleRailRef, + setHandleRefs, + setFollowerRefs, + fillStyle: fillStyleRef, + getHandleStyle, + activeIndex: activeIndexRef, + arrifiedValues: arrifiedValueRef, + followerEnabledIndexSet: followerEnabledIndexSetRef, + handleRailMouseDown, + handleHandleFocus, + handleHandleBlur, + handleHandleMouseEnter, + handleHandleMouseLeave, + handleRailKeyDown, + indicatorCssVars: inlineThemeDisabled ? void 0 : indicatorCssVarsRef, + indicatorThemeClass: indicatorThemeClassHandle === null || indicatorThemeClassHandle === void 0 ? void 0 : indicatorThemeClassHandle.themeClass, + indicatorOnRender: indicatorThemeClassHandle === null || indicatorThemeClassHandle === void 0 ? void 0 : indicatorThemeClassHandle.onRender, + cssVars: inlineThemeDisabled ? void 0 : cssVarsRef, + themeClass: themeClassHandle === null || themeClassHandle === void 0 ? void 0 : themeClassHandle.themeClass, + onRender: themeClassHandle === null || themeClassHandle === void 0 ? void 0 : themeClassHandle.onRender + }; + }, + render() { + var _a; + const { mergedClsPrefix, themeClass, formatTooltip } = this; + (_a = this.onRender) === null || _a === void 0 ? 
void 0 : _a.call(this); + return h( + "div", + { class: [ + `${mergedClsPrefix}-slider`, + themeClass, + { + [`${mergedClsPrefix}-slider--disabled`]: this.mergedDisabled, + [`${mergedClsPrefix}-slider--active`]: this.activeIndex !== -1, + [`${mergedClsPrefix}-slider--with-mark`]: this.marks, + [`${mergedClsPrefix}-slider--vertical`]: this.vertical, + [`${mergedClsPrefix}-slider--reverse`]: this.reverse + } + ], style: this.cssVars, onKeydown: this.handleRailKeyDown, onMousedown: this.handleRailMouseDown, onTouchstart: this.handleRailMouseDown }, + h( + "div", + { class: `${mergedClsPrefix}-slider-rail` }, + h("div", { class: `${mergedClsPrefix}-slider-rail__fill`, style: this.fillStyle }), + this.marks ? h("div", { class: [ + `${mergedClsPrefix}-slider-dots`, + this.dotTransitionDisabled && `${mergedClsPrefix}-slider-dots--transition-disabled` + ] }, this.markInfos.map((mark) => h("div", { key: mark.label, class: [ + `${mergedClsPrefix}-slider-dot`, + { + [`${mergedClsPrefix}-slider-dot--active`]: mark.active + } + ], style: mark.style }))) : null, + h("div", { ref: "handleRailRef", class: `${mergedClsPrefix}-slider-handles` }, this.arrifiedValues.map((value, index) => { + const showTooltip = this.isShowTooltip(index); + return h(VBinder, null, { + default: () => [ + h(VTarget, null, { + default: () => h("div", { ref: this.setHandleRefs(index), class: `${mergedClsPrefix}-slider-handle-wrapper`, tabindex: this.mergedDisabled ? -1 : 0, style: this.getHandleStyle(value, index), onFocus: () => { + this.handleHandleFocus(index); + }, onBlur: () => { + this.handleHandleBlur(index); + }, onMouseenter: () => { + this.handleHandleMouseEnter(index); + }, onMouseleave: () => { + this.handleHandleMouseLeave(index); + } }, resolveSlot(this.$slots.thumb, () => [ + h("div", { class: `${mergedClsPrefix}-slider-handle` }) + ])) + }), + this.tooltip && h(VFollower, { ref: this.setFollowerRefs(index), show: showTooltip, to: this.adjustedTo, enabled: this.showTooltip && !this.range || this.followerEnabledIndexSet.has(index), teleportDisabled: this.adjustedTo === useAdjustedTo.tdkey, placement: this.mergedPlacement, containerClass: this.namespace }, { + default: () => h(Transition, { name: "fade-in-scale-up-transition", appear: this.isMounted, css: this.shouldKeepTooltipTransition(index), onEnter: () => { + this.followerEnabledIndexSet.add(index); + }, onAfterLeave: () => { + this.followerEnabledIndexSet.delete(index); + } }, { + default: () => { + var _a2; + if (showTooltip) { + (_a2 = this.indicatorOnRender) === null || _a2 === void 0 ? void 0 : _a2.call(this); + return h("div", { class: [ + `${mergedClsPrefix}-slider-handle-indicator`, + this.indicatorThemeClass, + `${mergedClsPrefix}-slider-handle-indicator--${this.mergedPlacement}` + ], style: this.indicatorCssVars }, typeof formatTooltip === "function" ? formatTooltip(value) : value); + } + return null; + } + }) + }) + ] + }); + })), + this.marks ? 
h("div", { class: `${mergedClsPrefix}-slider-marks` }, this.markInfos.map((mark) => h("div", { key: mark.label, class: `${mergedClsPrefix}-slider-mark`, style: mark.style }, mark.label))) : null + ) + ); + } +}); +export { + NSlider as N +}; diff --git a/frontend/dist/assets/Switch.js b/frontend/dist/assets/Switch.js index 27efe4dec..391f8d5cd 100644 --- a/frontend/dist/assets/Switch.js +++ b/frontend/dist/assets/Switch.js @@ -1,727 +1,4 @@ -import { z as ref, bn as onBeforeUpdate, aa as c, Q as cB, ab as cM, at as cE, aS as fadeInScaleUpTransition, aU as insideModal, aV as insidePopover, d as defineComponent, S as useConfig, T as useTheme, ar as useFormItem, c as computed, X as toRef, ae as useMergedState, K as watch, W as nextTick, aB as onBeforeUnmount, Y as useThemeClass, bO as isMounted, ak as useAdjustedTo, y as h, bZ as VBinder, b_ as VTarget, ai as resolveSlot, b$ as VFollower, aX as Transition, c0 as sliderLight, aD as on, aC as off, a1 as call, aT as iconSwitchTransition, ac as cNotM, ah as createKey, aG as pxfy, ay as depx, c1 as isSlotEmpty, av as resolveWrappedSlot, c2 as switchLight, aI as NIconSwitchTransition, aJ as NBaseLoading } from "./index.js"; -function isTouchEvent(e) { - return window.TouchEvent && e instanceof window.TouchEvent; -} -function useRefs() { - const refs = ref(/* @__PURE__ */ new Map()); - const setRefs = (index) => (el) => { - refs.value.set(index, el); - }; - onBeforeUpdate(() => { - refs.value.clear(); - }); - return [refs, setRefs]; -} -const style$1 = c([cB("slider", ` - display: block; - padding: calc((var(--n-handle-size) - var(--n-rail-height)) / 2) 0; - position: relative; - z-index: 0; - width: 100%; - cursor: pointer; - user-select: none; - -webkit-user-select: none; - `, [cM("reverse", [cB("slider-handles", [cB("slider-handle-wrapper", ` - transform: translate(50%, -50%); - `)]), cB("slider-dots", [cB("slider-dot", ` - transform: translateX(50%, -50%); - `)]), cM("vertical", [cB("slider-handles", [cB("slider-handle-wrapper", ` - transform: translate(-50%, -50%); - `)]), cB("slider-marks", [cB("slider-mark", ` - transform: translateY(calc(-50% + var(--n-dot-height) / 2)); - `)]), cB("slider-dots", [cB("slider-dot", ` - transform: translateX(-50%) translateY(0); - `)])])]), cM("vertical", ` - padding: 0 calc((var(--n-handle-size) - var(--n-rail-height)) / 2); - width: var(--n-rail-width-vertical); - height: 100%; - `, [cB("slider-handles", ` - top: calc(var(--n-handle-size) / 2); - right: 0; - bottom: calc(var(--n-handle-size) / 2); - left: 0; - `, [cB("slider-handle-wrapper", ` - top: unset; - left: 50%; - transform: translate(-50%, 50%); - `)]), cB("slider-rail", ` - height: 100%; - `, [cE("fill", ` - top: unset; - right: 0; - bottom: unset; - left: 0; - `)]), cM("with-mark", ` - width: var(--n-rail-width-vertical); - margin: 0 32px 0 8px; - `), cB("slider-marks", ` - top: calc(var(--n-handle-size) / 2); - right: unset; - bottom: calc(var(--n-handle-size) / 2); - left: 22px; - font-size: var(--n-mark-font-size); - `, [cB("slider-mark", ` - transform: translateY(50%); - white-space: nowrap; - `)]), cB("slider-dots", ` - top: calc(var(--n-handle-size) / 2); - right: unset; - bottom: calc(var(--n-handle-size) / 2); - left: 50%; - `, [cB("slider-dot", ` - transform: translateX(-50%) translateY(50%); - `)])]), cM("disabled", ` - cursor: not-allowed; - opacity: var(--n-opacity-disabled); - `, [cB("slider-handle", ` - cursor: not-allowed; - `)]), cM("with-mark", ` - width: 100%; - margin: 8px 0 32px 0; - `), c("&:hover", [cB("slider-rail", { - 
backgroundColor: "var(--n-rail-color-hover)" -}, [cE("fill", { - backgroundColor: "var(--n-fill-color-hover)" -})]), cB("slider-handle", { - boxShadow: "var(--n-handle-box-shadow-hover)" -})]), cM("active", [cB("slider-rail", { - backgroundColor: "var(--n-rail-color-hover)" -}, [cE("fill", { - backgroundColor: "var(--n-fill-color-hover)" -})]), cB("slider-handle", { - boxShadow: "var(--n-handle-box-shadow-hover)" -})]), cB("slider-marks", ` - position: absolute; - top: 18px; - left: calc(var(--n-handle-size) / 2); - right: calc(var(--n-handle-size) / 2); - `, [cB("slider-mark", ` - position: absolute; - transform: translateX(-50%); - white-space: nowrap; - `)]), cB("slider-rail", ` - width: 100%; - position: relative; - height: var(--n-rail-height); - background-color: var(--n-rail-color); - transition: background-color .3s var(--n-bezier); - border-radius: calc(var(--n-rail-height) / 2); - `, [cE("fill", ` - position: absolute; - top: 0; - bottom: 0; - border-radius: calc(var(--n-rail-height) / 2); - transition: background-color .3s var(--n-bezier); - background-color: var(--n-fill-color); - `)]), cB("slider-handles", ` - position: absolute; - top: 0; - right: calc(var(--n-handle-size) / 2); - bottom: 0; - left: calc(var(--n-handle-size) / 2); - `, [cB("slider-handle-wrapper", ` - outline: none; - position: absolute; - top: 50%; - transform: translate(-50%, -50%); - cursor: pointer; - display: flex; - `, [cB("slider-handle", ` - height: var(--n-handle-size); - width: var(--n-handle-size); - border-radius: 50%; - overflow: hidden; - transition: box-shadow .2s var(--n-bezier), background-color .3s var(--n-bezier); - background-color: var(--n-handle-color); - box-shadow: var(--n-handle-box-shadow); - `, [c("&:hover", ` - box-shadow: var(--n-handle-box-shadow-hover); - `)]), c("&:focus", [cB("slider-handle", ` - box-shadow: var(--n-handle-box-shadow-focus); - `, [c("&:hover", ` - box-shadow: var(--n-handle-box-shadow-active); - `)])])])]), cB("slider-dots", ` - position: absolute; - top: 50%; - left: calc(var(--n-handle-size) / 2); - right: calc(var(--n-handle-size) / 2); - `, [cM("transition-disabled", [cB("slider-dot", "transition: none;")]), cB("slider-dot", ` - transition: - border-color .3s var(--n-bezier), - box-shadow .3s var(--n-bezier), - background-color .3s var(--n-bezier); - position: absolute; - transform: translate(-50%, -50%); - height: var(--n-dot-height); - width: var(--n-dot-width); - border-radius: var(--n-dot-border-radius); - overflow: hidden; - box-sizing: border-box; - border: var(--n-dot-border); - background-color: var(--n-dot-color); - `, [cM("active", "border: var(--n-dot-border-active);")])])]), cB("slider-handle-indicator", ` - font-size: var(--n-font-size); - padding: 6px 10px; - border-radius: var(--n-indicator-border-radius); - color: var(--n-indicator-text-color); - background-color: var(--n-indicator-color); - box-shadow: var(--n-indicator-box-shadow); - `, [fadeInScaleUpTransition()]), cB("slider-handle-indicator", ` - font-size: var(--n-font-size); - padding: 6px 10px; - border-radius: var(--n-indicator-border-radius); - color: var(--n-indicator-text-color); - background-color: var(--n-indicator-color); - box-shadow: var(--n-indicator-box-shadow); - `, [cM("top", ` - margin-bottom: 12px; - `), cM("right", ` - margin-left: 12px; - `), cM("bottom", ` - margin-top: 12px; - `), cM("left", ` - margin-right: 12px; - `), fadeInScaleUpTransition()]), insideModal(cB("slider", [cB("slider-dot", "background-color: var(--n-dot-color-modal);")])), 
insidePopover(cB("slider", [cB("slider-dot", "background-color: var(--n-dot-color-popover);")]))]); -const eventButtonLeft = 0; -const sliderProps = Object.assign(Object.assign({}, useTheme.props), { to: useAdjustedTo.propTo, defaultValue: { - type: [Number, Array], - default: 0 -}, marks: Object, disabled: { - type: Boolean, - default: void 0 -}, formatTooltip: Function, keyboard: { - type: Boolean, - default: true -}, min: { - type: Number, - default: 0 -}, max: { - type: Number, - default: 100 -}, step: { - type: [Number, String], - default: 1 -}, range: Boolean, value: [Number, Array], placement: String, showTooltip: { - type: Boolean, - default: void 0 -}, tooltip: { - type: Boolean, - default: true -}, vertical: Boolean, reverse: Boolean, "onUpdate:value": [Function, Array], onUpdateValue: [Function, Array] }); -const NSlider = defineComponent({ - name: "Slider", - props: sliderProps, - setup(props) { - const { mergedClsPrefixRef, namespaceRef, inlineThemeDisabled } = useConfig(props); - const themeRef = useTheme("Slider", "-slider", style$1, sliderLight, props, mergedClsPrefixRef); - const handleRailRef = ref(null); - const [handleRefs, setHandleRefs] = useRefs(); - const [followerRefs, setFollowerRefs] = useRefs(); - const followerEnabledIndexSetRef = ref(/* @__PURE__ */ new Set()); - const formItem = useFormItem(props); - const { mergedDisabledRef } = formItem; - const precisionRef = computed(() => { - const { step } = props; - if (Number(step) <= 0 || step === "mark") - return 0; - const stepString = step.toString(); - let precision = 0; - if (stepString.includes(".")) { - precision = stepString.length - stepString.indexOf(".") - 1; - } - return precision; - }); - const uncontrolledValueRef = ref(props.defaultValue); - const controlledValueRef = toRef(props, "value"); - const mergedValueRef = useMergedState(controlledValueRef, uncontrolledValueRef); - const arrifiedValueRef = computed(() => { - const { value: mergedValue } = mergedValueRef; - return (props.range ? mergedValue : [mergedValue]).map(clampValue); - }); - const handleCountExceeds2Ref = computed(() => arrifiedValueRef.value.length > 2); - const mergedPlacementRef = computed(() => { - return props.placement === void 0 ? props.vertical ? "right" : "top" : props.placement; - }); - const markValuesRef = computed(() => { - const { marks } = props; - return marks ? Object.keys(marks).map(parseFloat) : null; - }); - const activeIndexRef = ref(-1); - const previousIndexRef = ref(-1); - const hoverIndexRef = ref(-1); - const draggingRef = ref(false); - const dotTransitionDisabledRef = ref(false); - const styleDirectionRef = computed(() => { - const { vertical, reverse } = props; - const left = reverse ? "right" : "left"; - const bottom = reverse ? "top" : "bottom"; - return vertical ? bottom : left; - }); - const fillStyleRef = computed(() => { - if (handleCountExceeds2Ref.value) - return; - const values = arrifiedValueRef.value; - const start = valueToPercentage(props.range ? Math.min(...values) : props.min); - const end = valueToPercentage(props.range ? Math.max(...values) : values[0]); - const { value: styleDirection } = styleDirectionRef; - return props.vertical ? 
{ - [styleDirection]: `${start}%`, - height: `${end - start}%` - } : { - [styleDirection]: `${start}%`, - width: `${end - start}%` - }; - }); - const markInfosRef = computed(() => { - const mergedMarks = []; - const { marks } = props; - if (marks) { - const orderValues = arrifiedValueRef.value.slice(); - orderValues.sort((a, b) => a - b); - const { value: styleDirection } = styleDirectionRef; - const { value: handleCountExceeds2 } = handleCountExceeds2Ref; - const { range } = props; - const isActive = handleCountExceeds2 ? () => false : (num) => range ? num >= orderValues[0] && num <= orderValues[orderValues.length - 1] : num <= orderValues[0]; - for (const key of Object.keys(marks)) { - const num = Number(key); - mergedMarks.push({ - active: isActive(num), - label: marks[key], - style: { - [styleDirection]: `${valueToPercentage(num)}%` - } - }); - } - } - return mergedMarks; - }); - function getHandleStyle(value, index) { - const percentage = valueToPercentage(value); - const { value: styleDirection } = styleDirectionRef; - return { - [styleDirection]: `${percentage}%`, - zIndex: index === activeIndexRef.value ? 1 : 0 - }; - } - function isShowTooltip(index) { - return props.showTooltip || hoverIndexRef.value === index || activeIndexRef.value === index && draggingRef.value; - } - function shouldKeepTooltipTransition(index) { - if (!draggingRef.value) - return true; - return !(activeIndexRef.value === index && previousIndexRef.value === index); - } - function focusActiveHandle(index) { - var _a; - if (~index) { - activeIndexRef.value = index; - (_a = handleRefs.value.get(index)) === null || _a === void 0 ? void 0 : _a.focus(); - } - } - function syncPosition() { - followerRefs.value.forEach((inst, index) => { - if (isShowTooltip(index)) - inst.syncPosition(); - }); - } - function doUpdateValue(value) { - const { "onUpdate:value": _onUpdateValue, onUpdateValue } = props; - const { nTriggerFormInput, nTriggerFormChange } = formItem; - if (onUpdateValue) - call(onUpdateValue, value); - if (_onUpdateValue) - call(_onUpdateValue, value); - uncontrolledValueRef.value = value; - nTriggerFormInput(); - nTriggerFormChange(); - } - function dispatchValueUpdate(value) { - const { range } = props; - if (range) { - if (Array.isArray(value)) { - const { value: oldValues } = arrifiedValueRef; - if (value.join() !== oldValues.join()) { - doUpdateValue(value); - } - } - } else if (!Array.isArray(value)) { - const oldValue = arrifiedValueRef.value[0]; - if (oldValue !== value) { - doUpdateValue(value); - } - } - } - function doDispatchValue(value, index) { - if (props.range) { - const values = arrifiedValueRef.value.slice(); - values.splice(index, 1, value); - dispatchValueUpdate(values); - } else { - dispatchValueUpdate(value); - } - } - function sanitizeValue(value, currentValue, stepBuffer) { - const stepping = stepBuffer !== void 0; - if (!stepBuffer) { - stepBuffer = value - currentValue > 0 ? 1 : -1; - } - const markValues = markValuesRef.value || []; - const { step } = props; - if (step === "mark") { - const closestMark2 = getClosestMark(value, markValues.concat(currentValue), stepping ? stepBuffer : void 0); - return closestMark2 ? closestMark2.value : currentValue; - } - if (step <= 0) - return currentValue; - const { value: precision } = precisionRef; - let closestMark; - if (stepping) { - const currentStep = Number((currentValue / step).toFixed(precision)); - const actualStep = Math.floor(currentStep); - const leftStep = currentStep > actualStep ? 
actualStep : actualStep - 1; - const rightStep = currentStep < actualStep ? actualStep : actualStep + 1; - closestMark = getClosestMark(currentValue, [ - Number((leftStep * step).toFixed(precision)), - Number((rightStep * step).toFixed(precision)), - ...markValues - ], stepBuffer); - } else { - const roundValue = getRoundValue(value); - closestMark = getClosestMark(value, [...markValues, roundValue]); - } - return closestMark ? clampValue(closestMark.value) : currentValue; - } - function clampValue(value) { - return Math.min(props.max, Math.max(props.min, value)); - } - function valueToPercentage(value) { - const { max, min } = props; - return (value - min) / (max - min) * 100; - } - function percentageToValue(percentage) { - const { max, min } = props; - return min + (max - min) * percentage; - } - function getRoundValue(value) { - const { step, min } = props; - if (Number(step) <= 0 || step === "mark") - return value; - const newValue = Math.round((value - min) / step) * step + min; - return Number(newValue.toFixed(precisionRef.value)); - } - function getClosestMark(currentValue, markValues = markValuesRef.value, buffer) { - if (!(markValues === null || markValues === void 0 ? void 0 : markValues.length)) - return null; - let closestMark = null; - let index = -1; - while (++index < markValues.length) { - const diff = markValues[index] - currentValue; - const distance = Math.abs(diff); - if ( - // find marks in the same direction - (buffer === void 0 || diff * buffer > 0) && (closestMark === null || distance < closestMark.distance) - ) { - closestMark = { - index, - distance, - value: markValues[index] - }; - } - } - return closestMark; - } - function getPointValue(event) { - const railEl = handleRailRef.value; - if (!railEl) - return; - const touchEvent = isTouchEvent(event) ? event.touches[0] : event; - const railRect = railEl.getBoundingClientRect(); - let percentage; - if (props.vertical) { - percentage = (railRect.bottom - touchEvent.clientY) / railRect.height; - } else { - percentage = (touchEvent.clientX - railRect.left) / railRect.width; - } - if (props.reverse) { - percentage = 1 - percentage; - } - return percentageToValue(percentage); - } - function handleRailKeyDown(e) { - if (mergedDisabledRef.value || !props.keyboard) - return; - const { vertical, reverse } = props; - switch (e.key) { - case "ArrowUp": - e.preventDefault(); - handleStepValue(vertical && reverse ? -1 : 1); - break; - case "ArrowRight": - e.preventDefault(); - handleStepValue(!vertical && reverse ? -1 : 1); - break; - case "ArrowDown": - e.preventDefault(); - handleStepValue(vertical && reverse ? 1 : -1); - break; - case "ArrowLeft": - e.preventDefault(); - handleStepValue(!vertical && reverse ? 1 : -1); - break; - } - } - function handleStepValue(ratio) { - const activeIndex = activeIndexRef.value; - if (activeIndex === -1) - return; - const { step } = props; - const currentValue = arrifiedValueRef.value[activeIndex]; - const nextValue = Number(step) <= 0 || step === "mark" ? currentValue : currentValue + step * ratio; - doDispatchValue( - // Avoid the number of value does not change when `step` is null - sanitizeValue(nextValue, currentValue, ratio > 0 ? 
1 : -1), - activeIndex - ); - } - function handleRailMouseDown(event) { - var _a, _b; - if (mergedDisabledRef.value) - return; - if (!isTouchEvent(event) && event.button !== eventButtonLeft) { - return; - } - const pointValue = getPointValue(event); - if (pointValue === void 0) - return; - const values = arrifiedValueRef.value.slice(); - const activeIndex = props.range ? (_b = (_a = getClosestMark(pointValue, values)) === null || _a === void 0 ? void 0 : _a.index) !== null && _b !== void 0 ? _b : -1 : 0; - if (activeIndex !== -1) { - event.preventDefault(); - focusActiveHandle(activeIndex); - startDragging(); - doDispatchValue(sanitizeValue(pointValue, arrifiedValueRef.value[activeIndex]), activeIndex); - } - } - function startDragging() { - if (!draggingRef.value) { - draggingRef.value = true; - on("touchend", document, handleMouseUp); - on("mouseup", document, handleMouseUp); - on("touchmove", document, handleMouseMove); - on("mousemove", document, handleMouseMove); - } - } - function stopDragging() { - if (draggingRef.value) { - draggingRef.value = false; - off("touchend", document, handleMouseUp); - off("mouseup", document, handleMouseUp); - off("touchmove", document, handleMouseMove); - off("mousemove", document, handleMouseMove); - } - } - function handleMouseMove(event) { - const { value: activeIndex } = activeIndexRef; - if (!draggingRef.value || activeIndex === -1) { - stopDragging(); - return; - } - const pointValue = getPointValue(event); - doDispatchValue(sanitizeValue(pointValue, arrifiedValueRef.value[activeIndex]), activeIndex); - } - function handleMouseUp() { - stopDragging(); - } - function handleHandleFocus(index) { - activeIndexRef.value = index; - if (!mergedDisabledRef.value) { - hoverIndexRef.value = index; - } - } - function handleHandleBlur(index) { - if (activeIndexRef.value === index) { - activeIndexRef.value = -1; - stopDragging(); - } - if (hoverIndexRef.value === index) { - hoverIndexRef.value = -1; - } - } - function handleHandleMouseEnter(index) { - hoverIndexRef.value = index; - } - function handleHandleMouseLeave(index) { - if (hoverIndexRef.value === index) { - hoverIndexRef.value = -1; - } - } - watch(activeIndexRef, (_, previous) => void nextTick(() => previousIndexRef.value = previous)); - watch(mergedValueRef, () => { - if (props.marks) { - if (dotTransitionDisabledRef.value) - return; - dotTransitionDisabledRef.value = true; - void nextTick(() => { - dotTransitionDisabledRef.value = false; - }); - } - void nextTick(syncPosition); - }); - onBeforeUnmount(() => { - stopDragging(); - }); - const cssVarsRef = computed(() => { - const { self: { markFontSize, railColor, railColorHover, fillColor, fillColorHover, handleColor, opacityDisabled, dotColor, dotColorModal, handleBoxShadow, handleBoxShadowHover, handleBoxShadowActive, handleBoxShadowFocus, dotBorder, dotBoxShadow, railHeight, railWidthVertical, handleSize, dotHeight, dotWidth, dotBorderRadius, fontSize, dotBorderActive, dotColorPopover }, common: { cubicBezierEaseInOut } } = themeRef.value; - return { - "--n-bezier": cubicBezierEaseInOut, - "--n-dot-border": dotBorder, - "--n-dot-border-active": dotBorderActive, - "--n-dot-border-radius": dotBorderRadius, - "--n-dot-box-shadow": dotBoxShadow, - "--n-dot-color": dotColor, - "--n-dot-color-modal": dotColorModal, - "--n-dot-color-popover": dotColorPopover, - "--n-dot-height": dotHeight, - "--n-dot-width": dotWidth, - "--n-fill-color": fillColor, - "--n-fill-color-hover": fillColorHover, - "--n-font-size": fontSize, - "--n-handle-box-shadow": 
handleBoxShadow, - "--n-handle-box-shadow-active": handleBoxShadowActive, - "--n-handle-box-shadow-focus": handleBoxShadowFocus, - "--n-handle-box-shadow-hover": handleBoxShadowHover, - "--n-handle-color": handleColor, - "--n-handle-size": handleSize, - "--n-opacity-disabled": opacityDisabled, - "--n-rail-color": railColor, - "--n-rail-color-hover": railColorHover, - "--n-rail-height": railHeight, - "--n-rail-width-vertical": railWidthVertical, - "--n-mark-font-size": markFontSize - }; - }); - const themeClassHandle = inlineThemeDisabled ? useThemeClass("slider", void 0, cssVarsRef, props) : void 0; - const indicatorCssVarsRef = computed(() => { - const { self: { fontSize, indicatorColor, indicatorBoxShadow, indicatorTextColor, indicatorBorderRadius } } = themeRef.value; - return { - "--n-font-size": fontSize, - "--n-indicator-border-radius": indicatorBorderRadius, - "--n-indicator-box-shadow": indicatorBoxShadow, - "--n-indicator-color": indicatorColor, - "--n-indicator-text-color": indicatorTextColor - }; - }); - const indicatorThemeClassHandle = inlineThemeDisabled ? useThemeClass("slider-indicator", void 0, indicatorCssVarsRef, props) : void 0; - return { - mergedClsPrefix: mergedClsPrefixRef, - namespace: namespaceRef, - uncontrolledValue: uncontrolledValueRef, - mergedValue: mergedValueRef, - mergedDisabled: mergedDisabledRef, - mergedPlacement: mergedPlacementRef, - isMounted: isMounted(), - adjustedTo: useAdjustedTo(props), - dotTransitionDisabled: dotTransitionDisabledRef, - markInfos: markInfosRef, - isShowTooltip, - shouldKeepTooltipTransition, - handleRailRef, - setHandleRefs, - setFollowerRefs, - fillStyle: fillStyleRef, - getHandleStyle, - activeIndex: activeIndexRef, - arrifiedValues: arrifiedValueRef, - followerEnabledIndexSet: followerEnabledIndexSetRef, - handleRailMouseDown, - handleHandleFocus, - handleHandleBlur, - handleHandleMouseEnter, - handleHandleMouseLeave, - handleRailKeyDown, - indicatorCssVars: inlineThemeDisabled ? void 0 : indicatorCssVarsRef, - indicatorThemeClass: indicatorThemeClassHandle === null || indicatorThemeClassHandle === void 0 ? void 0 : indicatorThemeClassHandle.themeClass, - indicatorOnRender: indicatorThemeClassHandle === null || indicatorThemeClassHandle === void 0 ? void 0 : indicatorThemeClassHandle.onRender, - cssVars: inlineThemeDisabled ? void 0 : cssVarsRef, - themeClass: themeClassHandle === null || themeClassHandle === void 0 ? void 0 : themeClassHandle.themeClass, - onRender: themeClassHandle === null || themeClassHandle === void 0 ? void 0 : themeClassHandle.onRender - }; - }, - render() { - var _a; - const { mergedClsPrefix, themeClass, formatTooltip } = this; - (_a = this.onRender) === null || _a === void 0 ? void 0 : _a.call(this); - return h( - "div", - { class: [ - `${mergedClsPrefix}-slider`, - themeClass, - { - [`${mergedClsPrefix}-slider--disabled`]: this.mergedDisabled, - [`${mergedClsPrefix}-slider--active`]: this.activeIndex !== -1, - [`${mergedClsPrefix}-slider--with-mark`]: this.marks, - [`${mergedClsPrefix}-slider--vertical`]: this.vertical, - [`${mergedClsPrefix}-slider--reverse`]: this.reverse - } - ], style: this.cssVars, onKeydown: this.handleRailKeyDown, onMousedown: this.handleRailMouseDown, onTouchstart: this.handleRailMouseDown }, - h( - "div", - { class: `${mergedClsPrefix}-slider-rail` }, - h("div", { class: `${mergedClsPrefix}-slider-rail__fill`, style: this.fillStyle }), - this.marks ? 
h("div", { class: [ - `${mergedClsPrefix}-slider-dots`, - this.dotTransitionDisabled && `${mergedClsPrefix}-slider-dots--transition-disabled` - ] }, this.markInfos.map((mark) => h("div", { key: mark.label, class: [ - `${mergedClsPrefix}-slider-dot`, - { - [`${mergedClsPrefix}-slider-dot--active`]: mark.active - } - ], style: mark.style }))) : null, - h("div", { ref: "handleRailRef", class: `${mergedClsPrefix}-slider-handles` }, this.arrifiedValues.map((value, index) => { - const showTooltip = this.isShowTooltip(index); - return h(VBinder, null, { - default: () => [ - h(VTarget, null, { - default: () => h("div", { ref: this.setHandleRefs(index), class: `${mergedClsPrefix}-slider-handle-wrapper`, tabindex: this.mergedDisabled ? -1 : 0, style: this.getHandleStyle(value, index), onFocus: () => { - this.handleHandleFocus(index); - }, onBlur: () => { - this.handleHandleBlur(index); - }, onMouseenter: () => { - this.handleHandleMouseEnter(index); - }, onMouseleave: () => { - this.handleHandleMouseLeave(index); - } }, resolveSlot(this.$slots.thumb, () => [ - h("div", { class: `${mergedClsPrefix}-slider-handle` }) - ])) - }), - this.tooltip && h(VFollower, { ref: this.setFollowerRefs(index), show: showTooltip, to: this.adjustedTo, enabled: this.showTooltip && !this.range || this.followerEnabledIndexSet.has(index), teleportDisabled: this.adjustedTo === useAdjustedTo.tdkey, placement: this.mergedPlacement, containerClass: this.namespace }, { - default: () => h(Transition, { name: "fade-in-scale-up-transition", appear: this.isMounted, css: this.shouldKeepTooltipTransition(index), onEnter: () => { - this.followerEnabledIndexSet.add(index); - }, onAfterLeave: () => { - this.followerEnabledIndexSet.delete(index); - } }, { - default: () => { - var _a2; - if (showTooltip) { - (_a2 = this.indicatorOnRender) === null || _a2 === void 0 ? void 0 : _a2.call(this); - return h("div", { class: [ - `${mergedClsPrefix}-slider-handle-indicator`, - this.indicatorThemeClass, - `${mergedClsPrefix}-slider-handle-indicator--${this.mergedPlacement}` - ], style: this.indicatorCssVars }, typeof formatTooltip === "function" ? formatTooltip(value) : value); - } - return null; - } - }) - }) - ] - }); - })), - this.marks ? 
h("div", { class: `${mergedClsPrefix}-slider-marks` }, this.markInfos.map((mark) => h("div", { key: mark.label, class: `${mergedClsPrefix}-slider-mark`, style: mark.style }, mark.label))) : null - ) - ); - } -}); +import { R as cB, au as cE, aU as iconSwitchTransition, ab as c, ac as cM, ad as cNotM, d as defineComponent, T as useConfig, U as useTheme, as as useFormItem, B as ref, Y as toRef, af as useMergedState, c as computed, ai as createKey, aH as pxfy, az as depx, Z as useThemeClass, c4 as isSlotEmpty, A as h, aw as resolveWrappedSlot, c5 as switchLight, aJ as NIconSwitchTransition, aK as NBaseLoading, a2 as call } from "./index.js"; const style = cB("switch", ` height: var(--n-height); min-width: var(--n-width); @@ -1077,6 +354,5 @@ const NSwitch = defineComponent({ } }); export { - NSwitch as N, - NSlider as a + NSwitch as N }; diff --git a/frontend/dist/assets/TaggerView.js b/frontend/dist/assets/TaggerView.js index 154f4dc90..d78f5fc00 100644 --- a/frontend/dist/assets/TaggerView.js +++ b/frontend/dist/assets/TaggerView.js @@ -1,9 +1,10 @@ -import { d as defineComponent, a as useState, u as useSettings, p as useMessage, z as ref, c as computed, G as spaceRegex, o as openBlock, j as createElementBlock, g as createVNode, w as withCtx, h as unref, r as NGi, n as NCard, N as NSpace, f as createBaseVNode, i as NSelect, l as NTooltip, k as createTextVNode, J as NInput, C as toDisplayString, s as NGrid, t as serverUrl, v as pushScopeId, x as popScopeId, _ as _export_sfc } from "./index.js"; +import { d as defineComponent, l as useState, u as useSettings, r as useMessage, B as ref, c as computed, G as spaceRegex, o as openBlock, a as createElementBlock, e as createVNode, w as withCtx, f as unref, t as NGi, m as NCard, j as NSpace, b as createBaseVNode, q as NSelect, N as NTooltip, h as createTextVNode, J as NInput, E as toDisplayString, v as NGrid, x as serverUrl, y as pushScopeId, z as popScopeId, _ as _export_sfc } from "./index.js"; import { _ as _sfc_main$1 } from "./GenerateSection.vue_vue_type_script_setup_true_lang.js"; import { I as ImageUpload } from "./ImageUpload.js"; import { v as v4 } from "./v4.js"; -import { a as NSlider, N as NSwitch } from "./Switch.js"; +import { N as NSlider } from "./Slider.js"; import { N as NInputNumber } from "./InputNumber.js"; +import { N as NSwitch } from "./Switch.js"; import "./CloudUpload.js"; const _withScopeId = (n) => (pushScopeId("data-v-94d16b9f"), n = n(), popScopeId(), n); const _hoisted_1 = { class: "main-container" }; diff --git a/frontend/dist/assets/TestView.js b/frontend/dist/assets/TestView.js index ff8c3ffbd..a1078c6fb 100644 --- a/frontend/dist/assets/TestView.js +++ b/frontend/dist/assets/TestView.js @@ -1,4 +1,4 @@ -import { d as defineComponent, z as ref, o as openBlock, e as createBlock, h as unref } from "./index.js"; +import { d as defineComponent, B as ref, o as openBlock, g as createBlock, f as unref } from "./index.js"; import { _ as _sfc_main$1 } from "./ModelPopup.vue_vue_type_script_setup_true_lang.js"; import "./DescriptionsItem.js"; const _sfc_main = /* @__PURE__ */ defineComponent({ diff --git a/frontend/dist/assets/TextToImageView.js b/frontend/dist/assets/TextToImageView.js index 9501b1c55..2076fbf38 100644 --- a/frontend/dist/assets/TextToImageView.js +++ b/frontend/dist/assets/TextToImageView.js @@ -1,182 +1,551 @@ -import { d as defineComponent, u as useSettings, a as useState, c as computed, b as upscalerOptions, o as openBlock, e as createBlock, w as withCtx, f as createBaseVNode, g as createVNode, h as 
unref, N as NSpace, i as NSelect, j as createElementBlock, k as createTextVNode, l as NTooltip, m as createCommentVNode, n as NCard, p as useMessage, q as onUnmounted, r as NGi, s as NGrid, t as serverUrl } from "./index.js"; -import { _ as _sfc_main$6 } from "./GenerateSection.vue_vue_type_script_setup_true_lang.js"; -import { _ as _sfc_main$7 } from "./ImageOutput.vue_vue_type_script_setup_true_lang.js"; -import { B as BurnerClock, P as Prompt, _ as _sfc_main$4, a as _sfc_main$5, b as _sfc_main$8 } from "./clock.js"; -import { N as NSwitch, a as NSlider } from "./Switch.js"; +import { d as defineComponent, u as useSettings, c as computed, o as openBlock, a as createElementBlock, b as createBaseVNode, e as createVNode, f as unref, g as createBlock, w as withCtx, h as createTextVNode, N as NTooltip, i as isDev, j as NSpace, k as createCommentVNode, F as Fragment, l as useState, m as NCard, n as NTabPane, p as NTabs, q as NSelect, r as useMessage, s as onUnmounted, t as NGi, v as NGrid, x as serverUrl } from "./index.js"; +import { _ as _sfc_main$d } from "./GenerateSection.vue_vue_type_script_setup_true_lang.js"; +import { _ as _sfc_main$e } from "./ImageOutput.vue_vue_type_script_setup_true_lang.js"; +import { _ as _sfc_main$7, a as _sfc_main$8, B as BurnerClock, P as Prompt, b as _sfc_main$9, c as _sfc_main$a, d as _sfc_main$f } from "./clock.js"; +import { _ as _sfc_main$6, a as _sfc_main$b, b as _sfc_main$c } from "./Upscale.vue_vue_type_script_setup_true_lang.js"; +import { N as NSwitch } from "./Switch.js"; +import { N as NSlider } from "./Slider.js"; import { N as NInputNumber } from "./InputNumber.js"; -import { _ as _sfc_main$3 } from "./SamplerPicker.vue_vue_type_script_setup_true_lang.js"; import { v as v4 } from "./v4.js"; import "./SendOutputTo.vue_vue_type_script_setup_true_lang.js"; import "./TrashBin.js"; import "./DescriptionsItem.js"; import "./Settings.js"; +const _hoisted_1$3 = { class: "flex-container" }; +const _hoisted_2$3 = /* @__PURE__ */ createBaseVNode("div", { class: "slider-label" }, [ + /* @__PURE__ */ createBaseVNode("p", null, "Enabled") +], -1); +const _hoisted_3$3 = { class: "flex-container" }; +const _hoisted_4$3 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Steps", -1); +const _hoisted_5$3 = /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "We recommend using 20-50 steps for most images.", -1); +const _hoisted_6$3 = { class: "flex-container" }; +const _hoisted_7$3 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Seed", -1); +const _hoisted_8$2 = /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "For random seed use -1.", -1); +const _hoisted_9$2 = { class: "flex-container" }; +const _hoisted_10$1 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Strength", -1); +const _hoisted_11$1 = { class: "flex-container" }; +const _hoisted_12$1 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Mask Dilation", -1); +const _hoisted_13$1 = { class: "flex-container" }; +const _hoisted_14$1 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Mask Blur", -1); +const _hoisted_15$1 = { class: "flex-container" }; +const _hoisted_16$1 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Mask Padding", -1); +const _hoisted_17 = { class: "flex-container" }; +const _hoisted_18 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Iterations", -1); +const _hoisted_19 = { class: "flex-container" }; +const _hoisted_20 = /* @__PURE__ */ 
createBaseVNode("p", { class: "slider-label" }, "Upscale", -1); +const _sfc_main$5 = /* @__PURE__ */ defineComponent({ + __name: "ADetailer", + props: { + tab: { + type: String, + required: true + }, + target: { + type: String, + required: false, + default: "settings" + } + }, + setup(__props) { + const props = __props; + const settings = useSettings(); + const target = computed(() => { + if (props.target === "settings") { + return settings.data.settings; + } + return settings.defaultSettings; + }); + return (_ctx, _cache) => { + return openBlock(), createElementBlock(Fragment, null, [ + createBaseVNode("div", _hoisted_1$3, [ + _hoisted_2$3, + createVNode(unref(NSwitch), { + value: target.value[props.tab].adetailer.enabled, + "onUpdate:value": _cache[0] || (_cache[0] = ($event) => target.value[props.tab].adetailer.enabled = $event) + }, null, 8, ["value"]) + ]), + target.value[props.tab].adetailer.enabled ? (openBlock(), createBlock(unref(NSpace), { + key: 0, + vertical: "", + class: "left-container", + "builtin-theme-overrides": { + gapMedium: "0 12px" + } + }, { + default: withCtx(() => [ + createVNode(unref(_sfc_main$6), { type: "inpainting" }), + createBaseVNode("div", _hoisted_3$3, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_4$3 + ]), + default: withCtx(() => [ + createTextVNode(" Number of steps to take in the diffusion process. Higher values will result in more detailed images but will take longer to generate. There is also a point of diminishing returns around 100 steps. "), + _hoisted_5$3 + ]), + _: 1 + }), + createVNode(unref(NSlider), { + value: target.value[props.tab].adetailer.steps, + "onUpdate:value": _cache[1] || (_cache[1] = ($event) => target.value[props.tab].adetailer.steps = $event), + min: 5, + max: 300, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: target.value[props.tab].adetailer.steps, + "onUpdate:value": _cache[2] || (_cache[2] = ($event) => target.value[props.tab].adetailer.steps = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" } + }, null, 8, ["value"]) + ]), + createVNode(unref(_sfc_main$7), { + tab: "inpainting", + target: "adetailer" + }), + createVNode(unref(_sfc_main$8), { + tab: "inpainting", + target: "adetailer" + }), + createBaseVNode("div", _hoisted_6$3, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_7$3 + ]), + default: withCtx(() => [ + createTextVNode(" Seed is a number that represents the starting canvas of your image. If you want to create the same image as your friend, you can use the same settings and seed to do so. 
"), + _hoisted_8$2 + ]), + _: 1 + }), + createVNode(unref(NInputNumber), { + value: target.value[props.tab].adetailer.seed, + "onUpdate:value": _cache[3] || (_cache[3] = ($event) => target.value[props.tab].adetailer.seed = $event), + size: "small", + min: -1, + max: 999999999999, + style: { "flex-grow": "1" } + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_9$2, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_10$1 + ]), + default: withCtx(() => [ + createTextVNode(" How much should the masked are be changed from the original ") + ]), + _: 1 + }), + createVNode(unref(NSlider), { + value: target.value[props.tab].adetailer.strength, + "onUpdate:value": _cache[4] || (_cache[4] = ($event) => target.value[props.tab].adetailer.strength = $event), + min: 0, + max: 1, + step: 0.01, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: target.value[props.tab].adetailer.strength, + "onUpdate:value": _cache[5] || (_cache[5] = ($event) => target.value[props.tab].adetailer.strength = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" }, + min: 0, + max: 1, + step: 0.01 + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_11$1, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_12$1 + ]), + default: withCtx(() => [ + createTextVNode(" Expands bright pixels in the mask to cover more of the image. ") + ]), + _: 1 + }), + createVNode(unref(NInputNumber), { + value: target.value[props.tab].adetailer.mask_dilation, + "onUpdate:value": _cache[6] || (_cache[6] = ($event) => target.value[props.tab].adetailer.mask_dilation = $event), + size: "small", + min: 0, + style: { "flex-grow": "1" } + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_13$1, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_14$1 + ]), + default: withCtx(() => [ + createTextVNode(" Makes for a smooth transition between masked and unmasked areas. ") + ]), + _: 1 + }), + createVNode(unref(NInputNumber), { + value: target.value[props.tab].adetailer.mask_blur, + "onUpdate:value": _cache[7] || (_cache[7] = ($event) => target.value[props.tab].adetailer.mask_blur = $event), + size: "small", + min: 0, + style: { "flex-grow": "1" } + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_15$1, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_16$1 + ]), + default: withCtx(() => [ + createTextVNode(" Image will be cropped to the mask size plus padding. More padding might mean smoother transitions but slower generation. ") + ]), + _: 1 + }), + createVNode(unref(NInputNumber), { + value: target.value[props.tab].adetailer.mask_padding, + "onUpdate:value": _cache[8] || (_cache[8] = ($event) => target.value[props.tab].adetailer.mask_padding = $event), + size: "small", + min: 0, + style: { "flex-grow": "1" } + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_17, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_18 + ]), + default: withCtx(() => [ + createTextVNode(" Iterations should increase the quality of the image at the cost of time. 
") + ]), + _: 1 + }), + createVNode(unref(NInputNumber), { + value: target.value[props.tab].adetailer.iterations, + "onUpdate:value": _cache[9] || (_cache[9] = ($event) => target.value[props.tab].adetailer.iterations = $event), + disabled: !unref(isDev), + size: "small", + min: 1, + style: { "flex-grow": "1" } + }, null, 8, ["value", "disabled"]) + ]), + createBaseVNode("div", _hoisted_19, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_20 + ]), + default: withCtx(() => [ + createTextVNode(" Hom much should the image be upscaled before processing. This increases the quality of the image at the cost of time as bigger canvas can usually hold more detail. ") + ]), + _: 1 + }), + createVNode(unref(NSlider), { + value: target.value[props.tab].adetailer.upscale, + "onUpdate:value": _cache[10] || (_cache[10] = ($event) => target.value[props.tab].adetailer.upscale = $event), + min: 1, + max: 4, + step: 0.1, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: target.value[props.tab].adetailer.upscale, + "onUpdate:value": _cache[11] || (_cache[11] = ($event) => target.value[props.tab].adetailer.upscale = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" }, + min: 1, + max: 4, + step: 0.1 + }, null, 8, ["value"]) + ]) + ]), + _: 1 + })) : createCommentVNode("", true) + ], 64); + }; + } +}); +const _hoisted_1$2 = { class: "flex-container" }; +const _hoisted_2$2 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Enabled", -1); +const _hoisted_3$2 = { key: 0 }; +const _hoisted_4$2 = { class: "flex-container" }; +const _hoisted_5$2 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Width", -1); +const _hoisted_6$2 = { class: "flex-container" }; +const _hoisted_7$2 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Height", -1); +const _sfc_main$4 = /* @__PURE__ */ defineComponent({ + __name: "ResizeFromDimensionsInput", + setup(__props) { + const settings = useSettings(); + const global = useState(); + return (_ctx, _cache) => { + return openBlock(), createBlock(unref(NCard), { + title: "Resize from", + class: "generate-extra-card" + }, { + default: withCtx(() => [ + createBaseVNode("div", _hoisted_1$2, [ + _hoisted_2$2, + createVNode(unref(NSwitch), { + value: unref(global).state.txt2img.sdxl_resize, + "onUpdate:value": _cache[0] || (_cache[0] = ($event) => unref(global).state.txt2img.sdxl_resize = $event) + }, null, 8, ["value"]) + ]), + unref(global).state.txt2img.sdxl_resize ? 
(openBlock(), createElementBlock("div", _hoisted_3$2, [ + createBaseVNode("div", _hoisted_4$2, [ + _hoisted_5$2, + createVNode(unref(NSlider), { + value: unref(settings).data.settings.flags.sdxl.original_size.width, + "onUpdate:value": _cache[1] || (_cache[1] = ($event) => unref(settings).data.settings.flags.sdxl.original_size.width = $event), + min: 128, + max: 2048, + step: 1, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: unref(settings).data.settings.flags.sdxl.original_size.width, + "onUpdate:value": _cache[2] || (_cache[2] = ($event) => unref(settings).data.settings.flags.sdxl.original_size.width = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" }, + step: 1 + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_6$2, [ + _hoisted_7$2, + createVNode(unref(NSlider), { + value: unref(settings).data.settings.flags.sdxl.original_size.height, + "onUpdate:value": _cache[3] || (_cache[3] = ($event) => unref(settings).data.settings.flags.sdxl.original_size.height = $event), + min: 128, + max: 2048, + step: 1, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: unref(settings).data.settings.flags.sdxl.original_size.height, + "onUpdate:value": _cache[4] || (_cache[4] = ($event) => unref(settings).data.settings.flags.sdxl.original_size.height = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" }, + step: 1 + }, null, 8, ["value"]) + ]) + ])) : createCommentVNode("", true) + ]), + _: 1 + }); + }; + } +}); +const _sfc_main$3 = /* @__PURE__ */ defineComponent({ + __name: "Restoration", + props: { + tab: { + type: String, + required: true + }, + target: { + type: String, + required: false, + default: "settings" + } + }, + setup(__props) { + const props = __props; + return (_ctx, _cache) => { + return openBlock(), createBlock(unref(NCard), { + title: "Restoration", + class: "generate-extra-card" + }, { + default: withCtx(() => [ + createVNode(unref(NTabs), { + animated: "", + type: "segment" + }, { + default: withCtx(() => [ + createVNode(unref(NTabPane), { + tab: "ADetailer", + name: "adetailer" + }, { + default: withCtx(() => [ + createVNode(unref(_sfc_main$5), { + tab: props.tab, + target: props.target + }, null, 8, ["tab", "target"]) + ]), + _: 1 + }) + ]), + _: 1 + }) + ]), + _: 1 + }); + }; + } +}); const _hoisted_1$1 = { class: "flex-container" }; const _hoisted_2$1 = /* @__PURE__ */ createBaseVNode("div", { class: "slider-label" }, [ /* @__PURE__ */ createBaseVNode("p", null, "Enabled") ], -1); const _hoisted_3$1 = { class: "flex-container" }; -const _hoisted_4$1 = /* @__PURE__ */ createBaseVNode("div", { class: "slider-label" }, [ - /* @__PURE__ */ createBaseVNode("p", null, "Mode") -], -1); -const _hoisted_5$1 = { key: 0 }; +const _hoisted_4$1 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Refiner model", -1); +const _hoisted_5$1 = /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, " Generally, the refiner that came with your model is bound to generate the best results. 
", -1); const _hoisted_6$1 = { class: "flex-container" }; -const _hoisted_7$1 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Upscaler", -1); -const _hoisted_8$1 = { key: 1 }; +const _hoisted_7$1 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Steps", -1); +const _hoisted_8$1 = /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "We recommend using 20-50 steps for most images.", -1); const _hoisted_9$1 = { class: "flex-container" }; -const _hoisted_10$1 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Antialiased", -1); -const _hoisted_11$1 = { class: "flex-container" }; -const _hoisted_12$1 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Latent Mode", -1); -const _hoisted_13$1 = { class: "flex-container" }; -const _hoisted_14$1 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Steps", -1); -const _hoisted_15 = /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "We recommend using 20-50 steps for most images.", -1); -const _hoisted_16 = { class: "flex-container" }; -const _hoisted_17 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Scale", -1); -const _hoisted_18 = { class: "flex-container" }; -const _hoisted_19 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Strength", -1); +const _hoisted_10 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Aesthetic Score", -1); +const _hoisted_11 = /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "Generally best to keep it around 6.", -1); +const _hoisted_12 = { class: "flex-container" }; +const _hoisted_13 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Negative Aesthetic Score", -1); +const _hoisted_14 = /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "Generally best to keep it around 3.", -1); +const _hoisted_15 = { class: "flex-container" }; +const _hoisted_16 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Strength", -1); const _sfc_main$2 = /* @__PURE__ */ defineComponent({ - __name: "HighResFix", + __name: "XLRefiner", setup(__props) { const settings = useSettings(); const global = useState(); - const imageUpscalerOptions = computed(() => { - const localModels = global.state.models.filter( - (model) => model.backend === "Upscaler" && !(upscalerOptions.map((option) => option.label).indexOf(model.name) !== -1) - ).map((model) => ({ - label: model.name, - value: model.path - })); - return [...upscalerOptions, ...localModels]; + const refinerModels = computed(() => { + return global.state.models.filter((model) => model.type === "SDXL").map((model) => { + return { + label: model.name, + value: model.name + }; + }); }); - const latentUpscalerOptions = [ - { label: "Nearest", value: "nearest" }, - { label: "Nearest exact", value: "nearest-exact" }, - { label: "Area", value: "area" }, - { label: "Bilinear", value: "bilinear" }, - { label: "Bicubic", value: "bicubic" }, - { - label: "Bislerp (Original, slow)", - value: "bislerp-original" - }, - { - label: "Bislerp (Tortured, fast)", - value: "bislerp-tortured" - } - ]; + async function onRefinerChange(modelStr) { + settings.data.settings.flags.refiner.model = modelStr; + } return (_ctx, _cache) => { - return openBlock(), createBlock(unref(NCard), { title: "Highres fix" }, { + return openBlock(), createBlock(unref(NCard), { + title: "SDXL Refiner", + class: "generate-extra-card" + }, { default: withCtx(() => [ createBaseVNode("div", _hoisted_1$1, [ _hoisted_2$1, 
createVNode(unref(NSwitch), { - value: unref(global).state.txt2img.highres, - "onUpdate:value": _cache[0] || (_cache[0] = ($event) => unref(global).state.txt2img.highres = $event) + value: unref(global).state.txt2img.refiner, + "onUpdate:value": _cache[0] || (_cache[0] = ($event) => unref(global).state.txt2img.refiner = $event) }, null, 8, ["value"]) ]), - unref(global).state.txt2img.highres ? (openBlock(), createBlock(unref(NSpace), { + unref(global).state.txt2img.refiner ? (openBlock(), createBlock(unref(NSpace), { key: 0, vertical: "", class: "left-container" }, { default: withCtx(() => [ createBaseVNode("div", _hoisted_3$1, [ - _hoisted_4$1, + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_4$1 + ]), + default: withCtx(() => [ + createTextVNode(" The SDXL-Refiner model to use for this step of diffusion. "), + _hoisted_5$1 + ]), + _: 1 + }), createVNode(unref(NSelect), { - value: unref(settings).data.settings.flags.highres.mode, - "onUpdate:value": _cache[1] || (_cache[1] = ($event) => unref(settings).data.settings.flags.highres.mode = $event), - options: [ - { label: "Latent", value: "latent" }, - { label: "Image", value: "image" } - ] - }, null, 8, ["value"]) + options: refinerModels.value, + placeholder: "None", + "onUpdate:value": onRefinerChange, + value: unref(settings).data.settings.flags.refiner.model !== null ? unref(settings).data.settings.flags.refiner.model : "" + }, null, 8, ["options", "value"]) ]), - unref(settings).data.settings.flags.highres.mode === "image" ? (openBlock(), createElementBlock("div", _hoisted_5$1, [ - createBaseVNode("div", _hoisted_6$1, [ - _hoisted_7$1, - createVNode(unref(NSelect), { - value: unref(settings).data.settings.flags.highres.image_upscaler, - "onUpdate:value": _cache[2] || (_cache[2] = ($event) => unref(settings).data.settings.flags.highres.image_upscaler = $event), - size: "small", - style: { "flex-grow": "1" }, - filterable: "", - options: imageUpscalerOptions.value - }, null, 8, ["value", "options"]) - ]) - ])) : (openBlock(), createElementBlock("div", _hoisted_8$1, [ - createBaseVNode("div", _hoisted_9$1, [ - _hoisted_10$1, - createVNode(unref(NSwitch), { - value: unref(settings).data.settings.flags.highres.antialiased, - "onUpdate:value": _cache[3] || (_cache[3] = ($event) => unref(settings).data.settings.flags.highres.antialiased = $event) - }, null, 8, ["value"]) - ]), - createBaseVNode("div", _hoisted_11$1, [ - _hoisted_12$1, - createVNode(unref(NSelect), { - value: unref(settings).data.settings.flags.highres.latent_scale_mode, - "onUpdate:value": _cache[4] || (_cache[4] = ($event) => unref(settings).data.settings.flags.highres.latent_scale_mode = $event), - size: "small", - style: { "flex-grow": "1" }, - filterable: "", - options: latentUpscalerOptions - }, null, 8, ["value"]) - ]) - ])), - createBaseVNode("div", _hoisted_13$1, [ + createBaseVNode("div", _hoisted_6$1, [ createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { trigger: withCtx(() => [ - _hoisted_14$1 + _hoisted_7$1 ]), default: withCtx(() => [ createTextVNode(" Number of steps to take in the diffusion process. Higher values will result in more detailed images but will take longer to generate. There is also a point of diminishing returns around 100 steps. 
"), - _hoisted_15 + _hoisted_8$1 ]), _: 1 }), createVNode(unref(NSlider), { - value: unref(settings).data.settings.flags.highres.steps, - "onUpdate:value": _cache[5] || (_cache[5] = ($event) => unref(settings).data.settings.flags.highres.steps = $event), + value: unref(settings).data.settings.flags.refiner.steps, + "onUpdate:value": _cache[1] || (_cache[1] = ($event) => unref(settings).data.settings.flags.refiner.steps = $event), min: 5, max: 300, style: { "margin-right": "12px" } }, null, 8, ["value"]), createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.flags.highres.steps, - "onUpdate:value": _cache[6] || (_cache[6] = ($event) => unref(settings).data.settings.flags.highres.steps = $event), + value: unref(settings).data.settings.flags.refiner.steps, + "onUpdate:value": _cache[2] || (_cache[2] = ($event) => unref(settings).data.settings.flags.refiner.steps = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" } + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_9$1, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_10 + ]), + default: withCtx(() => [ + createTextVNode(' Generally higher numbers will produce "more professional" images. '), + _hoisted_11 + ]), + _: 1 + }), + createVNode(unref(NSlider), { + value: unref(settings).data.settings.flags.refiner.aesthetic_score, + "onUpdate:value": _cache[3] || (_cache[3] = ($event) => unref(settings).data.settings.flags.refiner.aesthetic_score = $event), + min: 0, + max: 10, + step: 0.5, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: unref(settings).data.settings.flags.refiner.aesthetic_score, + "onUpdate:value": _cache[4] || (_cache[4] = ($event) => unref(settings).data.settings.flags.refiner.aesthetic_score = $event), + min: 0, + max: 10, + step: 0.25, size: "small", style: { "min-width": "96px", "width": "96px" } }, null, 8, ["value"]) ]), - createBaseVNode("div", _hoisted_16, [ - _hoisted_17, + createBaseVNode("div", _hoisted_12, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_13 + ]), + default: withCtx(() => [ + createTextVNode(" Makes sense to keep this lower than aesthetic score. 
"), + _hoisted_14 + ]), + _: 1 + }), createVNode(unref(NSlider), { - value: unref(settings).data.settings.flags.highres.scale, - "onUpdate:value": _cache[7] || (_cache[7] = ($event) => unref(settings).data.settings.flags.highres.scale = $event), - min: 1, - max: 8, - step: 0.1, + value: unref(settings).data.settings.flags.refiner.negative_aesthetic_score, + "onUpdate:value": _cache[5] || (_cache[5] = ($event) => unref(settings).data.settings.flags.refiner.negative_aesthetic_score = $event), + min: 0, + max: 10, + step: 0.5, style: { "margin-right": "12px" } }, null, 8, ["value"]), createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.flags.highres.scale, - "onUpdate:value": _cache[8] || (_cache[8] = ($event) => unref(settings).data.settings.flags.highres.scale = $event), + value: unref(settings).data.settings.flags.refiner.negative_aesthetic_score, + "onUpdate:value": _cache[6] || (_cache[6] = ($event) => unref(settings).data.settings.flags.refiner.negative_aesthetic_score = $event), + min: 0, + max: 10, + step: 0.25, size: "small", - style: { "min-width": "96px", "width": "96px" }, - step: 0.1 + style: { "min-width": "96px", "width": "96px" } }, null, 8, ["value"]) ]), - createBaseVNode("div", _hoisted_18, [ - _hoisted_19, + createBaseVNode("div", _hoisted_15, [ + _hoisted_16, createVNode(unref(NSlider), { - value: unref(settings).data.settings.flags.highres.strength, - "onUpdate:value": _cache[9] || (_cache[9] = ($event) => unref(settings).data.settings.flags.highres.strength = $event), + value: unref(settings).data.settings.flags.refiner.strength, + "onUpdate:value": _cache[7] || (_cache[7] = ($event) => unref(settings).data.settings.flags.refiner.strength = $event), min: 0.1, max: 0.9, step: 0.05, style: { "margin-right": "12px" } }, null, 8, ["value"]), createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.flags.highres.strength, - "onUpdate:value": _cache[10] || (_cache[10] = ($event) => unref(settings).data.settings.flags.highres.strength = $event), + value: unref(settings).data.settings.flags.refiner.strength, + "onUpdate:value": _cache[8] || (_cache[8] = ($event) => unref(settings).data.settings.flags.refiner.strength = $event), size: "small", style: { "min-width": "96px", "width": "96px" }, min: 0.1, @@ -198,24 +567,20 @@ const _hoisted_2 = { class: "flex-container" }; const _hoisted_3 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Steps", -1); const _hoisted_4 = /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "We recommend using 20-50 steps for most images.", -1); const _hoisted_5 = { class: "flex-container" }; -const _hoisted_6 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "CFG Scale", -1); -const _hoisted_7 = /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "We recommend using 3-15 for most images.", -1); -const _hoisted_8 = { - key: 0, - class: "flex-container" -}; -const _hoisted_9 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Self Attention Scale", -1); -const _hoisted_10 = { class: "flex-container" }; -const _hoisted_11 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Batch Count", -1); -const _hoisted_12 = { class: "flex-container" }; -const _hoisted_13 = /* @__PURE__ */ createBaseVNode("p", { style: { "margin-right": "12px", "width": "75px" } }, "Seed", -1); -const _hoisted_14 = /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "For random seed use -1.", -1); +const _hoisted_6 = /* @__PURE__ */ createBaseVNode("p", { 
class: "slider-label" }, "Batch Count", -1); +const _hoisted_7 = { class: "flex-container" }; +const _hoisted_8 = /* @__PURE__ */ createBaseVNode("p", { style: { "margin-right": "12px", "width": "75px" } }, "Seed", -1); +const _hoisted_9 = /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "For random seed use -1.", -1); const _sfc_main$1 = /* @__PURE__ */ defineComponent({ __name: "Txt2Img", setup(__props) { const global = useState(); const settings = useSettings(); const messageHandler = useMessage(); + const isSelectedModelSDXL = computed(() => { + var _a; + return ((_a = settings.data.settings.model) == null ? void 0 : _a.type) === "SDXL"; + }); const checkSeed = (seed) => { if (seed === -1) { seed = Math.floor(Math.random() * 999999999999); @@ -223,7 +588,7 @@ const _sfc_main$1 = /* @__PURE__ */ defineComponent({ return seed; }; const generate = () => { - var _a; + var _a, _b; if (settings.data.settings.txt2img.seed === null) { messageHandler.error("Please set a seed"); return; @@ -260,17 +625,77 @@ const _sfc_main$1 = /* @__PURE__ */ defineComponent({ model: (_a = settings.data.settings.model) == null ? void 0 : _a.path, backend: "PyTorch", autoload: false, - flags: global.state.txt2img.highres ? { - highres_fix: { - mode: settings.data.settings.flags.highres.mode, - image_upscaler: settings.data.settings.flags.highres.image_upscaler, - scale: settings.data.settings.flags.highres.scale, - latent_scale_mode: settings.data.settings.flags.highres.latent_scale_mode, - strength: settings.data.settings.flags.highres.strength, - steps: settings.data.settings.flags.highres.steps, - antialiased: settings.data.settings.flags.highres.antialiased - } - } : {} + flags: { + ...isSelectedModelSDXL.value && global.state.txt2img.sdxl_resize ? { + sdxl: { + original_size: { + width: settings.data.settings.flags.sdxl.original_size.width, + height: settings.data.settings.flags.sdxl.original_size.height + } + } + } : {}, + ...settings.data.settings.txt2img.highres.enabled ? { + highres_fix: { + mode: settings.data.settings.txt2img.highres.mode, + image_upscaler: settings.data.settings.txt2img.highres.image_upscaler, + scale: settings.data.settings.txt2img.highres.scale, + latent_scale_mode: settings.data.settings.txt2img.highres.latent_scale_mode, + strength: settings.data.settings.txt2img.highres.strength, + steps: settings.data.settings.txt2img.highres.steps, + antialiased: settings.data.settings.txt2img.highres.antialiased + } + } : global.state.txt2img.refiner ? { + refiner: { + model: settings.data.settings.flags.refiner.model, + aesthetic_score: settings.data.settings.flags.refiner.aesthetic_score, + negative_aesthetic_score: settings.data.settings.flags.refiner.negative_aesthetic_score, + steps: settings.data.settings.flags.refiner.steps, + strength: settings.data.settings.flags.refiner.strength + } + } : {}, + ...settings.data.settings.txt2img.deepshrink.enabled ? { + deepshrink: { + early_out: settings.data.settings.txt2img.deepshrink.early_out, + depth_1: settings.data.settings.txt2img.deepshrink.depth_1, + stop_at_1: settings.data.settings.txt2img.deepshrink.stop_at_1, + depth_2: settings.data.settings.txt2img.deepshrink.depth_2, + stop_at_2: settings.data.settings.txt2img.deepshrink.stop_at_2, + scaler: settings.data.settings.txt2img.deepshrink.scaler, + base_scale: settings.data.settings.txt2img.deepshrink.base_scale + } + } : {}, + ...settings.data.settings.txt2img.scalecrafter.enabled ? 
{ + scalecrafter: { + unsafe_resolutions: settings.data.settings.txt2img.scalecrafter.unsafe_resolutions, + base: (_b = settings.data.settings.model) == null ? void 0 : _b.type, + disperse: settings.data.settings.txt2img.scalecrafter.disperse + } + } : {}, + ...settings.data.settings.txt2img.upscale.enabled ? { + upscale: { + upscale_factor: settings.data.settings.txt2img.upscale.upscale_factor, + tile_size: settings.data.settings.txt2img.upscale.tile_size, + tile_padding: settings.data.settings.txt2img.upscale.tile_padding, + model: settings.data.settings.txt2img.upscale.model + } + } : {}, + ...settings.data.settings.txt2img.adetailer.enabled ? { + adetailer: { + cfg_scale: settings.data.settings.txt2img.adetailer.cfg_scale, + mask_blur: settings.data.settings.txt2img.adetailer.mask_blur, + mask_dilation: settings.data.settings.txt2img.adetailer.mask_dilation, + mask_padding: settings.data.settings.txt2img.adetailer.mask_padding, + iterations: settings.data.settings.txt2img.adetailer.iterations, + upscale: settings.data.settings.txt2img.adetailer.upscale, + scheduler: settings.data.settings.txt2img.adetailer.sampler, + strength: settings.data.settings.txt2img.adetailer.strength, + seed: settings.data.settings.txt2img.adetailer.seed, + self_attention_scale: settings.data.settings.txt2img.adetailer.self_attention_scale, + sigmas: settings.data.settings.txt2img.adetailer.sigmas, + steps: settings.data.settings.txt2img.adetailer.steps + } + } : {} + } }) }).then((res) => { if (!res.ok) { @@ -317,158 +742,111 @@ const _sfc_main$1 = /* @__PURE__ */ defineComponent({ vertical: "", class: "left-container" }, { - default: withCtx(() => { - var _a; - return [ - createVNode(unref(Prompt), { tab: "txt2img" }), - createVNode(unref(_sfc_main$3), { type: "txt2img" }), - createVNode(unref(_sfc_main$4), { - "dimensions-object": unref(settings).data.settings.txt2img - }, null, 8, ["dimensions-object"]), - createBaseVNode("div", _hoisted_2, [ - createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { - trigger: withCtx(() => [ - _hoisted_3 - ]), - default: withCtx(() => [ - createTextVNode(" Number of steps to take in the diffusion process. Higher values will result in more detailed images but will take longer to generate. There is also a point of diminishing returns around 100 steps. "), - _hoisted_4 - ]), - _: 1 - }), - createVNode(unref(NSlider), { - value: unref(settings).data.settings.txt2img.steps, - "onUpdate:value": _cache[0] || (_cache[0] = ($event) => unref(settings).data.settings.txt2img.steps = $event), - min: 5, - max: 300, - style: { "margin-right": "12px" } - }, null, 8, ["value"]), - createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.txt2img.steps, - "onUpdate:value": _cache[1] || (_cache[1] = ($event) => unref(settings).data.settings.txt2img.steps = $event), - size: "small", - style: { "min-width": "96px", "width": "96px" } - }, null, 8, ["value"]) - ]), - createBaseVNode("div", _hoisted_5, [ - createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { - trigger: withCtx(() => [ - _hoisted_6 - ]), - default: withCtx(() => [ - createTextVNode(' Guidance scale indicates how much should model stay close to the prompt. Higher values might be exactly what you want, but generated images might have some artefacts. Lower values indicates that model can "dream" about this prompt more. 
'), - _hoisted_7 - ]), - _: 1 - }), - createVNode(unref(NSlider), { - value: unref(settings).data.settings.txt2img.cfg_scale, - "onUpdate:value": _cache[2] || (_cache[2] = ($event) => unref(settings).data.settings.txt2img.cfg_scale = $event), - min: 1, - max: 30, - step: 0.5, - style: { "margin-right": "12px" } - }, null, 8, ["value"]), - createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.txt2img.cfg_scale, - "onUpdate:value": _cache[3] || (_cache[3] = ($event) => unref(settings).data.settings.txt2img.cfg_scale = $event), - size: "small", - style: { "min-width": "96px", "width": "96px" }, - step: 0.5 - }, null, 8, ["value"]) - ]), - Number.isInteger(unref(settings).data.settings.txt2img.sampler) && ((_a = unref(settings).data.settings.model) == null ? void 0 : _a.backend) === "PyTorch" ? (openBlock(), createElementBlock("div", _hoisted_8, [ - createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { - trigger: withCtx(() => [ - _hoisted_9 - ]), - default: withCtx(() => [ - createTextVNode(" If self attention is >0, SAG will guide the model and improve the quality of the image at the cost of speed. Higher values will follow the guidance more closely, which can lead to better, more sharp and detailed outputs. ") - ]), - _: 1 - }), - createVNode(unref(NSlider), { - value: unref(settings).data.settings.txt2img.self_attention_scale, - "onUpdate:value": _cache[4] || (_cache[4] = ($event) => unref(settings).data.settings.txt2img.self_attention_scale = $event), - min: 0, - max: 1, - step: 0.05, - style: { "margin-right": "12px" } - }, null, 8, ["value"]), - createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.txt2img.self_attention_scale, - "onUpdate:value": _cache[5] || (_cache[5] = ($event) => unref(settings).data.settings.txt2img.self_attention_scale = $event), - size: "small", - style: { "min-width": "96px", "width": "96px" }, - step: 0.05 - }, null, 8, ["value"]) - ])) : createCommentVNode("", true), - createBaseVNode("div", _hoisted_10, [ - createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { - trigger: withCtx(() => [ - _hoisted_11 - ]), - default: withCtx(() => [ - createTextVNode(" Number of images to generate after each other. ") - ]), - _: 1 - }), - createVNode(unref(NSlider), { - value: unref(settings).data.settings.txt2img.batch_count, - "onUpdate:value": _cache[6] || (_cache[6] = ($event) => unref(settings).data.settings.txt2img.batch_count = $event), - min: 1, - max: 9, - style: { "margin-right": "12px" } - }, null, 8, ["value"]), - createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.txt2img.batch_count, - "onUpdate:value": _cache[7] || (_cache[7] = ($event) => unref(settings).data.settings.txt2img.batch_count = $event), - size: "small", - style: { "min-width": "96px", "width": "96px" } - }, null, 8, ["value"]) - ]), - createVNode(unref(_sfc_main$5), { - "batch-size-object": unref(settings).data.settings.txt2img - }, null, 8, ["batch-size-object"]), - createBaseVNode("div", _hoisted_12, [ - createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { - trigger: withCtx(() => [ - _hoisted_13 - ]), - default: withCtx(() => [ - createTextVNode(" Seed is a number that represents the starting canvas of your image. If you want to create the same image as your friend, you can use the same settings and seed to do so. 
"), - _hoisted_14 - ]), - _: 1 - }), - createVNode(unref(NInputNumber), { - value: unref(settings).data.settings.txt2img.seed, - "onUpdate:value": _cache[8] || (_cache[8] = ($event) => unref(settings).data.settings.txt2img.seed = $event), - size: "small", - style: { "flex-grow": "1" } - }, null, 8, ["value"]) - ]) - ]; - }), + default: withCtx(() => [ + createVNode(unref(Prompt), { tab: "txt2img" }), + createVNode(unref(_sfc_main$6), { type: "txt2img" }), + createVNode(unref(_sfc_main$9), { + "dimensions-object": unref(settings).data.settings.txt2img + }, null, 8, ["dimensions-object"]), + createBaseVNode("div", _hoisted_2, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_3 + ]), + default: withCtx(() => [ + createTextVNode(" Number of steps to take in the diffusion process. Higher values will result in more detailed images but will take longer to generate. There is also a point of diminishing returns around 100 steps. "), + _hoisted_4 + ]), + _: 1 + }), + createVNode(unref(NSlider), { + value: unref(settings).data.settings.txt2img.steps, + "onUpdate:value": _cache[0] || (_cache[0] = ($event) => unref(settings).data.settings.txt2img.steps = $event), + min: 5, + max: 300, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: unref(settings).data.settings.txt2img.steps, + "onUpdate:value": _cache[1] || (_cache[1] = ($event) => unref(settings).data.settings.txt2img.steps = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" } + }, null, 8, ["value"]) + ]), + createVNode(unref(_sfc_main$7), { tab: "txt2img" }), + createVNode(unref(_sfc_main$8), { tab: "txt2img" }), + createBaseVNode("div", _hoisted_5, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_6 + ]), + default: withCtx(() => [ + createTextVNode(" Number of images to generate after each other. ") + ]), + _: 1 + }), + createVNode(unref(NSlider), { + value: unref(settings).data.settings.txt2img.batch_count, + "onUpdate:value": _cache[2] || (_cache[2] = ($event) => unref(settings).data.settings.txt2img.batch_count = $event), + min: 1, + max: 9, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: unref(settings).data.settings.txt2img.batch_count, + "onUpdate:value": _cache[3] || (_cache[3] = ($event) => unref(settings).data.settings.txt2img.batch_count = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" } + }, null, 8, ["value"]) + ]), + createVNode(unref(_sfc_main$a), { + "batch-size-object": unref(settings).data.settings.txt2img + }, null, 8, ["batch-size-object"]), + createBaseVNode("div", _hoisted_7, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_8 + ]), + default: withCtx(() => [ + createTextVNode(" Seed is a number that represents the starting canvas of your image. If you want to create the same image as your friend, you can use the same settings and seed to do so. 
"), + _hoisted_9 + ]), + _: 1 + }), + createVNode(unref(NInputNumber), { + value: unref(settings).data.settings.txt2img.seed, + "onUpdate:value": _cache[4] || (_cache[4] = ($event) => unref(settings).data.settings.txt2img.seed = $event), + size: "small", + style: { "flex-grow": "1" } + }, null, 8, ["value"]) + ]) + ]), _: 1 }) ]), _: 1 }), - createVNode(unref(_sfc_main$2), { style: { "margin-top": "12px", "margin-bottom": "12px" } }) + isSelectedModelSDXL.value ? (openBlock(), createBlock(unref(_sfc_main$4), { + key: 0, + "dimensions-object": unref(settings).data.settings.txt2img + }, null, 8, ["dimensions-object"])) : createCommentVNode("", true), + isSelectedModelSDXL.value ? (openBlock(), createBlock(unref(_sfc_main$2), { key: 1 })) : createCommentVNode("", true), + createVNode(unref(_sfc_main$b), { tab: "txt2img" }), + createVNode(unref(_sfc_main$c), { tab: "txt2img" }), + createVNode(unref(_sfc_main$3), { tab: "txt2img" }) ]), _: 1 }), createVNode(unref(NGi), null, { default: withCtx(() => [ - createVNode(unref(_sfc_main$6), { generate }), - createVNode(unref(_sfc_main$7), { + createVNode(unref(_sfc_main$d), { generate }), + createVNode(unref(_sfc_main$e), { "current-image": unref(global).state.txt2img.currentImage, images: unref(global).state.txt2img.images, data: unref(settings).data.settings.txt2img, - onImageClicked: _cache[9] || (_cache[9] = ($event) => unref(global).state.txt2img.currentImage = $event) + onImageClicked: _cache[5] || (_cache[5] = ($event) => unref(global).state.txt2img.currentImage = $event) }, null, 8, ["current-image", "images", "data"]), - createVNode(unref(_sfc_main$8), { + createVNode(unref(_sfc_main$f), { style: { "margin-top": "12px" }, "gen-data": unref(global).state.txt2img.genData }, null, 8, ["gen-data"]) diff --git a/frontend/dist/assets/TrashBin.js b/frontend/dist/assets/TrashBin.js index 8f894ef91..94dfc07fe 100644 --- a/frontend/dist/assets/TrashBin.js +++ b/frontend/dist/assets/TrashBin.js @@ -1,4 +1,4 @@ -import { O as replaceable, y as h, d as defineComponent, bL as isBrowser, T as useTheme, P as createInjectionKey, aa as c, Q as cB, bM as fadeInTransition, aS as fadeInScaleUpTransition, ac as cNotM, X as toRef, bN as imageLight, z as ref, ad as useLocale, K as watch, aD as on, aC as off, aB as onBeforeUnmount, R as inject, c as computed, S as useConfig, Y as useThemeClass, bO as isMounted, bP as LazyTeleport, br as withDirectives, bQ as zindexable, aX as Transition, F as Fragment, aj as NBaseIcon, bs as vShow, ba as normalizeStyle, bR as kebabCase, l as NTooltip, aR as beforeNextFrameOnce, aW as createId, a3 as provide, bC as getCurrentInstance, b9 as onMounted, af as watchEffect, o as openBlock, j as createElementBlock, f as createBaseVNode } from "./index.js"; +import { P as replaceable, A as h, d as defineComponent, bQ as isBrowser, U as useTheme, Q as createInjectionKey, ab as c, R as cB, bR as fadeInTransition, aT as fadeInScaleUpTransition, ad as cNotM, Y as toRef, bS as imageLight, B as ref, ae as useLocale, K as watch, aE as on, aD as off, aC as onBeforeUnmount, S as inject, c as computed, T as useConfig, Z as useThemeClass, bT as isMounted, bU as LazyTeleport, bv as withDirectives, bV as zindexable, aY as Transition, F as Fragment, ak as NBaseIcon, bw as vShow, bc as normalizeStyle, bW as kebabCase, N as NTooltip, aS as beforeNextFrameOnce, aX as createId, a4 as provide, bG as getCurrentInstance, bb as onMounted, ag as watchEffect, o as openBlock, a as createElementBlock, b as createBaseVNode } from "./index.js"; const 
RotateClockwiseIcon = replaceable("rotateClockwise", h( "svg", { viewBox: "0 0 20 20", fill: "none", xmlns: "http://www.w3.org/2000/svg" }, diff --git a/frontend/dist/assets/SamplerPicker.vue_vue_type_script_setup_true_lang.js b/frontend/dist/assets/Upscale.vue_vue_type_script_setup_true_lang.js similarity index 69% rename from frontend/dist/assets/SamplerPicker.vue_vue_type_script_setup_true_lang.js rename to frontend/dist/assets/Upscale.vue_vue_type_script_setup_true_lang.js index d3120ca90..eeae2599e 100644 --- a/frontend/dist/assets/SamplerPicker.vue_vue_type_script_setup_true_lang.js +++ b/frontend/dist/assets/Upscale.vue_vue_type_script_setup_true_lang.js @@ -1,7 +1,8 @@ -import { R as inject, bC as getCurrentInstance, K as watch, aB as onBeforeUnmount, Q as cB, ab as cM, aa as c, P as createInjectionKey, d as defineComponent, S as useConfig, T as useTheme, z as ref, a3 as provide, y as h, bD as formLight, a2 as keysOf, c as computed, az as formatLength, aH as get, bE as commonVariables, at as cE, X as toRef, aW as createId, bF as formItemInjectionKey, b9 as onMounted, ah as createKey, Y as useThemeClass, aX as Transition, av as resolveWrappedSlot, aM as warn, u as useSettings, o as openBlock, j as createElementBlock, f as createBaseVNode, g as createVNode, w as withCtx, h as unref, n as NCard, F as Fragment, L as renderList, A as NButton, k as createTextVNode, C as toDisplayString, by as convertToTextString, e as createBlock, bG as resolveDynamicComponent, bd as NModal, l as NTooltip, i as NSelect, B as NIcon } from "./index.js"; -import { S as Settings, a as NCheckbox } from "./Settings.js"; +import { S as inject, bG as getCurrentInstance, K as watch, aC as onBeforeUnmount, R as cB, ac as cM, ab as c, Q as createInjectionKey, d as defineComponent, T as useConfig, U as useTheme, B as ref, a4 as provide, A as h, bH as formLight, a3 as keysOf, c as computed, aA as formatLength, aI as get, bI as commonVariables, au as cE, Y as toRef, aX as createId, bJ as formItemInjectionKey, bb as onMounted, ai as createKey, Z as useThemeClass, aY as Transition, aw as resolveWrappedSlot, aO as warn, u as useSettings, o as openBlock, a as createElementBlock, b as createBaseVNode, e as createVNode, f as unref, w as withCtx, h as createTextVNode, bK as NAlert, m as NCard, q as NSelect, k as createCommentVNode, F as Fragment, l as useState, L as upscalerOptions, g as createBlock, N as NTooltip, j as NSpace, n as NTabPane, p as NTabs, M as renderList, C as NButton, E as toDisplayString, bC as convertToTextString, bL as resolveDynamicComponent, bh as NModal, D as NIcon } from "./index.js"; +import { N as NSwitch } from "./Switch.js"; import { N as NInputNumber } from "./InputNumber.js"; -import { a as NSlider } from "./Switch.js"; +import { N as NSlider } from "./Slider.js"; +import { S as Settings, a as NCheckbox } from "./Settings.js"; function useInjectionInstanceCollection(injectionName, collectionKey, registerKeyRef) { var _a; const injection = inject(injectionName, null); @@ -1839,17 +1840,472 @@ const NFormItem = defineComponent({ ); } }); -const _hoisted_1 = { class: "flex-container" }; -const _hoisted_2 = { style: { "margin-left": "12px", "margin-right": "12px", "white-space": "nowrap" } }; -const _hoisted_3 = /* @__PURE__ */ createBaseVNode("p", { style: { "margin-right": "12px", "width": "100px" } }, "Sampler", -1); -const _hoisted_4 = /* @__PURE__ */ createBaseVNode("a", { +const _hoisted_1$4 = { class: "flex-container" }; +const _hoisted_2$4 = /* @__PURE__ */ createBaseVNode("div", { class: 
"slider-label" }, [ + /* @__PURE__ */ createBaseVNode("p", null, "Enabled") +], -1); +const _hoisted_3$4 = { key: 0 }; +const _hoisted_4$4 = /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "diffusers", -1); +const _hoisted_5$4 = { class: "flex-container space-between" }; +const _hoisted_6$4 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Depth", -1); +const _hoisted_7$4 = { class: "flex-container" }; +const _hoisted_8$3 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Stop at", -1); +const _hoisted_9$3 = { class: "flex-container space-between" }; +const _hoisted_10$2 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Depth", -1); +const _hoisted_11$1 = { class: "flex-container" }; +const _hoisted_12$1 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Stop at", -1); +const _hoisted_13$1 = { class: "flex-container" }; +const _hoisted_14$1 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Scale", -1); +const _hoisted_15$1 = { class: "flex-container" }; +const _hoisted_16$1 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Latent scaler", -1); +const _hoisted_17$1 = { class: "flex-container" }; +const _hoisted_18$1 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Early out", -1); +const _sfc_main$5 = /* @__PURE__ */ defineComponent({ + __name: "DeepShrink", + props: { + tab: { + type: String, + required: true + }, + target: { + type: String, + required: false, + default: "settings" + } + }, + setup(__props) { + const props = __props; + const settings = useSettings(); + const latentUpscalerOptions = [ + { label: "Nearest", value: "nearest" }, + { label: "Nearest exact", value: "nearest-exact" }, + { label: "Area", value: "area" }, + { label: "Bilinear", value: "bilinear" }, + { label: "Bicubic", value: "bicubic" }, + { label: "Bislerp", value: "bislerp" } + ]; + const target = computed(() => { + if (props.target === "settings") { + return settings.data.settings; + } + return settings.defaultSettings; + }); + return (_ctx, _cache) => { + return openBlock(), createElementBlock(Fragment, null, [ + createBaseVNode("div", _hoisted_1$4, [ + _hoisted_2$4, + createVNode(unref(NSwitch), { + value: target.value[props.tab].deepshrink.enabled, + "onUpdate:value": _cache[0] || (_cache[0] = ($event) => target.value[props.tab].deepshrink.enabled = $event) + }, null, 8, ["value"]) + ]), + target.value[props.tab].deepshrink.enabled ? 
(openBlock(), createElementBlock("div", _hoisted_3$4, [ + createVNode(unref(NAlert), { type: "warning" }, { + default: withCtx(() => [ + createTextVNode(" Only works on "), + _hoisted_4$4, + createTextVNode(" samplers ") + ]), + _: 1 + }), + createVNode(unref(NCard), { + bordered: false, + title: "First layer" + }, { + default: withCtx(() => [ + createBaseVNode("div", _hoisted_5$4, [ + _hoisted_6$4, + createVNode(unref(NInputNumber), { + value: target.value[props.tab].deepshrink.depth_1, + "onUpdate:value": _cache[1] || (_cache[1] = ($event) => target.value[props.tab].deepshrink.depth_1 = $event), + max: 4, + min: 1, + step: 1 + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_7$4, [ + _hoisted_8$3, + createVNode(unref(NSlider), { + value: target.value[props.tab].deepshrink.stop_at_1, + "onUpdate:value": _cache[2] || (_cache[2] = ($event) => target.value[props.tab].deepshrink.stop_at_1 = $event), + min: 0.05, + max: 1, + step: 0.05, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: target.value[props.tab].deepshrink.stop_at_1, + "onUpdate:value": _cache[3] || (_cache[3] = ($event) => target.value[props.tab].deepshrink.stop_at_1 = $event), + max: 1, + min: 0.05, + step: 0.05 + }, null, 8, ["value"]) + ]) + ]), + _: 1 + }), + createVNode(unref(NCard), { + bordered: false, + title: "Second layer" + }, { + default: withCtx(() => [ + createBaseVNode("div", _hoisted_9$3, [ + _hoisted_10$2, + createVNode(unref(NInputNumber), { + value: target.value[props.tab].deepshrink.depth_2, + "onUpdate:value": _cache[4] || (_cache[4] = ($event) => target.value[props.tab].deepshrink.depth_2 = $event), + max: 4, + min: 1, + step: 1 + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_11$1, [ + _hoisted_12$1, + createVNode(unref(NSlider), { + value: target.value[props.tab].deepshrink.stop_at_2, + "onUpdate:value": _cache[5] || (_cache[5] = ($event) => target.value[props.tab].deepshrink.stop_at_2 = $event), + min: 0.05, + max: 1, + step: 0.05 + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: target.value[props.tab].deepshrink.stop_at_2, + "onUpdate:value": _cache[6] || (_cache[6] = ($event) => target.value[props.tab].deepshrink.stop_at_2 = $event), + max: 1, + min: 0.05, + step: 0.05 + }, null, 8, ["value"]) + ]) + ]), + _: 1 + }), + createVNode(unref(NCard), { + bordered: false, + title: "Scale" + }, { + default: withCtx(() => [ + createBaseVNode("div", _hoisted_13$1, [ + _hoisted_14$1, + createVNode(unref(NSlider), { + value: target.value[props.tab].deepshrink.base_scale, + "onUpdate:value": _cache[7] || (_cache[7] = ($event) => target.value[props.tab].deepshrink.base_scale = $event), + min: 0.05, + max: 1, + step: 0.05 + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: target.value[props.tab].deepshrink.base_scale, + "onUpdate:value": _cache[8] || (_cache[8] = ($event) => target.value[props.tab].deepshrink.base_scale = $event), + max: 1, + min: 0.05, + step: 0.05 + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_15$1, [ + _hoisted_16$1, + createVNode(unref(NSelect), { + value: target.value[props.tab].deepshrink.scaler, + "onUpdate:value": _cache[9] || (_cache[9] = ($event) => target.value[props.tab].deepshrink.scaler = $event), + filterable: "", + options: latentUpscalerOptions + }, null, 8, ["value"]) + ]) + ]), + _: 1 + }), + createVNode(unref(NCard), { + bordered: false, + title: "Other" + }, { + default: withCtx(() => [ + createBaseVNode("div", _hoisted_17$1, [ + 
_hoisted_18$1, + createVNode(unref(NSwitch), { + value: target.value[props.tab].deepshrink.early_out, + "onUpdate:value": _cache[10] || (_cache[10] = ($event) => target.value[props.tab].deepshrink.early_out = $event) + }, null, 8, ["value"]) + ]) + ]), + _: 1 + }) + ])) : createCommentVNode("", true) + ], 64); + }; + } +}); +const _hoisted_1$3 = { class: "flex-container" }; +const _hoisted_2$3 = /* @__PURE__ */ createBaseVNode("div", { class: "slider-label" }, [ + /* @__PURE__ */ createBaseVNode("p", null, "Enabled") +], -1); +const _hoisted_3$3 = { class: "flex-container" }; +const _hoisted_4$3 = /* @__PURE__ */ createBaseVNode("div", { class: "slider-label" }, [ + /* @__PURE__ */ createBaseVNode("p", null, "Mode") +], -1); +const _hoisted_5$3 = { key: 0 }; +const _hoisted_6$3 = { class: "flex-container" }; +const _hoisted_7$3 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Upscaler", -1); +const _hoisted_8$2 = { key: 1 }; +const _hoisted_9$2 = { class: "flex-container" }; +const _hoisted_10$1 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Antialiased", -1); +const _hoisted_11 = { class: "flex-container" }; +const _hoisted_12 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Latent Mode", -1); +const _hoisted_13 = { class: "flex-container" }; +const _hoisted_14 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Steps", -1); +const _hoisted_15 = /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "We recommend using 20-50 steps for most images.", -1); +const _hoisted_16 = { class: "flex-container" }; +const _hoisted_17 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Scale", -1); +const _hoisted_18 = { class: "flex-container" }; +const _hoisted_19 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Strength", -1); +const _sfc_main$4 = /* @__PURE__ */ defineComponent({ + __name: "HighResFix", + props: { + tab: { + type: String, + required: true + }, + target: { + type: String, + required: false, + default: "settings" + } + }, + setup(__props) { + const props = __props; + const settings = useSettings(); + const global = useState(); + const target = computed(() => { + if (props.target === "settings") { + return settings.data.settings; + } + return settings.defaultSettings; + }); + const imageUpscalerOptions = computed(() => { + const localModels = global.state.models.filter( + (model) => model.backend === "Upscaler" && !(upscalerOptions.map((option) => option.label).indexOf(model.name) !== -1) + ).map((model) => ({ + label: model.name, + value: model.path + })); + return [...upscalerOptions, ...localModels]; + }); + const latentUpscalerOptions = [ + { label: "Nearest", value: "nearest" }, + { label: "Nearest exact", value: "nearest-exact" }, + { label: "Area", value: "area" }, + { label: "Bilinear", value: "bilinear" }, + { label: "Bicubic", value: "bicubic" }, + { label: "Bislerp", value: "bislerp" } + ]; + return (_ctx, _cache) => { + return openBlock(), createElementBlock(Fragment, null, [ + createBaseVNode("div", _hoisted_1$3, [ + _hoisted_2$3, + createVNode(unref(NSwitch), { + value: target.value[props.tab].highres.enabled, + "onUpdate:value": _cache[0] || (_cache[0] = ($event) => target.value[props.tab].highres.enabled = $event) + }, null, 8, ["value"]) + ]), + target.value[props.tab].highres.enabled ? 
(openBlock(), createBlock(unref(NSpace), { + key: 0, + vertical: "", + class: "left-container" + }, { + default: withCtx(() => [ + createBaseVNode("div", _hoisted_3$3, [ + _hoisted_4$3, + createVNode(unref(NSelect), { + value: target.value[props.tab].highres.mode, + "onUpdate:value": _cache[1] || (_cache[1] = ($event) => target.value[props.tab].highres.mode = $event), + options: [ + { label: "Latent", value: "latent" }, + { label: "Image", value: "image" } + ] + }, null, 8, ["value"]) + ]), + target.value[props.tab].highres.mode === "image" ? (openBlock(), createElementBlock("div", _hoisted_5$3, [ + createBaseVNode("div", _hoisted_6$3, [ + _hoisted_7$3, + createVNode(unref(NSelect), { + value: target.value[props.tab].highres.image_upscaler, + "onUpdate:value": _cache[2] || (_cache[2] = ($event) => target.value[props.tab].highres.image_upscaler = $event), + size: "small", + style: { "flex-grow": "1" }, + filterable: "", + options: imageUpscalerOptions.value + }, null, 8, ["value", "options"]) + ]) + ])) : (openBlock(), createElementBlock("div", _hoisted_8$2, [ + createBaseVNode("div", _hoisted_9$2, [ + _hoisted_10$1, + createVNode(unref(NSwitch), { + value: target.value[props.tab].highres.antialiased, + "onUpdate:value": _cache[3] || (_cache[3] = ($event) => target.value[props.tab].highres.antialiased = $event) + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_11, [ + _hoisted_12, + createVNode(unref(NSelect), { + value: target.value[props.tab].highres.latent_scale_mode, + "onUpdate:value": _cache[4] || (_cache[4] = ($event) => target.value[props.tab].highres.latent_scale_mode = $event), + size: "small", + style: { "flex-grow": "1" }, + filterable: "", + options: latentUpscalerOptions + }, null, 8, ["value"]) + ]) + ])), + createBaseVNode("div", _hoisted_13, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_14 + ]), + default: withCtx(() => [ + createTextVNode(" Number of steps to take in the diffusion process. Higher values will result in more detailed images but will take longer to generate. There is also a point of diminishing returns around 100 steps. 
"), + _hoisted_15 + ]), + _: 1 + }), + createVNode(unref(NSlider), { + value: target.value[props.tab].highres.steps, + "onUpdate:value": _cache[5] || (_cache[5] = ($event) => target.value[props.tab].highres.steps = $event), + min: 5, + max: 300, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: target.value[props.tab].highres.steps, + "onUpdate:value": _cache[6] || (_cache[6] = ($event) => target.value[props.tab].highres.steps = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" } + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_16, [ + _hoisted_17, + createVNode(unref(NSlider), { + value: target.value[props.tab].highres.scale, + "onUpdate:value": _cache[7] || (_cache[7] = ($event) => target.value[props.tab].highres.scale = $event), + min: 1, + max: 8, + step: 0.1, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: target.value[props.tab].highres.scale, + "onUpdate:value": _cache[8] || (_cache[8] = ($event) => target.value[props.tab].highres.scale = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" }, + step: 0.1 + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_18, [ + _hoisted_19, + createVNode(unref(NSlider), { + value: target.value[props.tab].highres.strength, + "onUpdate:value": _cache[9] || (_cache[9] = ($event) => target.value[props.tab].highres.strength = $event), + min: 0.1, + max: 0.9, + step: 0.05, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: target.value[props.tab].highres.strength, + "onUpdate:value": _cache[10] || (_cache[10] = ($event) => target.value[props.tab].highres.strength = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" }, + min: 0.1, + max: 0.9, + step: 0.05 + }, null, 8, ["value"]) + ]) + ]), + _: 1 + })) : createCommentVNode("", true) + ], 64); + }; + } +}); +const _sfc_main$3 = /* @__PURE__ */ defineComponent({ + __name: "HighResFixTabs", + props: { + tab: { + type: String, + required: true + }, + target: { + type: String, + required: false, + default: "settings" + } + }, + setup(__props) { + const props = __props; + return (_ctx, _cache) => { + return openBlock(), createBlock(unref(NCard), { + title: "High Resolution Fix", + class: "generate-extra-card" + }, { + default: withCtx(() => [ + createVNode(unref(NTabs), { + animated: "", + type: "segment" + }, { + default: withCtx(() => [ + createVNode(unref(NTabPane), { + tab: "Image to Image", + name: "highresfix" + }, { + default: withCtx(() => [ + createVNode(unref(_sfc_main$4), { + tab: props.tab, + target: props.target + }, null, 8, ["tab", "target"]) + ]), + _: 1 + }), + createVNode(unref(NTabPane), { + tab: "Scalecrafter", + name: "scalecrafter" + }, { + default: withCtx(() => [ + createVNode(unref(_sfc_main$1), { + tab: props.tab, + target: props.target + }, null, 8, ["tab", "target"]) + ]), + _: 1 + }), + createVNode(unref(NTabPane), { + tab: "DeepShrink", + name: "deepshrink" + }, { + default: withCtx(() => [ + createVNode(unref(_sfc_main$5), { + tab: props.tab, + target: props.target + }, null, 8, ["tab", "target"]) + ]), + _: 1 + }) + ]), + _: 1 + }) + ]), + _: 1 + }); + }; + } +}); +const _hoisted_1$2 = { class: "flex-container" }; +const _hoisted_2$2 = { style: { "margin-left": "12px", "margin-right": "12px", "white-space": "nowrap" } }; +const _hoisted_3$2 = /* @__PURE__ */ createBaseVNode("p", { style: { "margin-right": 
"12px", "width": "100px" } }, "Sampler", -1); +const _hoisted_4$2 = /* @__PURE__ */ createBaseVNode("a", { target: "_blank", href: "https://docs.google.com/document/d/1n0YozLAUwLJWZmbsx350UD_bwAx3gZMnRuleIZt_R1w" }, "Learn more", -1); -const _hoisted_5 = { class: "flex-container" }; -const _hoisted_6 = /* @__PURE__ */ createBaseVNode("p", { style: { "margin-right": "12px", "width": "94px" } }, "Sigmas", -1); -const _hoisted_7 = /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, 'Only "Default" and "Karras" sigmas work on diffusers samplers (and "Karras" are only applied to KDPM samplers)', -1); -const _sfc_main = /* @__PURE__ */ defineComponent({ +const _hoisted_5$2 = { class: "flex-container" }; +const _hoisted_6$2 = /* @__PURE__ */ createBaseVNode("p", { style: { "margin-right": "12px", "width": "94px" } }, "Sigmas", -1); +const _hoisted_7$2 = /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, 'Only "Default" and "Karras" sigmas work on diffusers samplers (and "Karras" are only applied to KDPM samplers)', -1); +const _sfc_main$2 = /* @__PURE__ */ defineComponent({ __name: "SamplerPicker", props: { type: { @@ -1953,7 +2409,7 @@ const _sfc_main = /* @__PURE__ */ defineComponent({ }); return (_ctx, _cache) => { return openBlock(), createElementBlock(Fragment, null, [ - createBaseVNode("div", _hoisted_1, [ + createBaseVNode("div", _hoisted_1$2, [ createVNode(unref(NModal), { show: showModal.value, "onUpdate:show": _cache[1] || (_cache[1] = ($event) => showModal.value = $event), @@ -1985,7 +2441,7 @@ const _sfc_main = /* @__PURE__ */ defineComponent({ ]), _: 2 }, 1032, ["type", "disabled", "onClick"]), - createBaseVNode("p", _hoisted_2, toDisplayString(unref(convertToTextString)(param)), 1), + createBaseVNode("p", _hoisted_2$2, toDisplayString(unref(convertToTextString)(param)), 1), (openBlock(), createBlock(resolveDynamicComponent( resolveComponent( target.value.sampler_config["ui_settings"][param], @@ -2002,11 +2458,11 @@ const _sfc_main = /* @__PURE__ */ defineComponent({ }, 8, ["show"]), createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { trigger: withCtx(() => [ - _hoisted_3 + _hoisted_3$2 ]), default: withCtx(() => [ createTextVNode(" The sampler is the method used to generate the image. Your result may vary drastically depending on the sampler you choose. "), - _hoisted_4 + _hoisted_4$2 ]), _: 1 }), @@ -2032,14 +2488,14 @@ const _sfc_main = /* @__PURE__ */ defineComponent({ _: 1 }) ]), - createBaseVNode("div", _hoisted_5, [ + createBaseVNode("div", _hoisted_5$2, [ createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { trigger: withCtx(() => [ - _hoisted_6 + _hoisted_6$2 ]), default: withCtx(() => [ createTextVNode(" Changes the sigmas used in the diffusion process. Can change the quality of the output. 
"), - _hoisted_7 + _hoisted_7$2 ]), _: 1 }), @@ -2054,8 +2510,263 @@ const _sfc_main = /* @__PURE__ */ defineComponent({ }; } }); +const _hoisted_1$1 = { class: "flex-container" }; +const _hoisted_2$1 = /* @__PURE__ */ createBaseVNode("div", { class: "slider-label" }, [ + /* @__PURE__ */ createBaseVNode("p", null, "Enabled") +], -1); +const _hoisted_3$1 = /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "Automatic", -1); +const _hoisted_4$1 = /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "Karras", -1); +const _hoisted_5$1 = { class: "flex-container" }; +const _hoisted_6$1 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Disperse", -1); +const _hoisted_7$1 = /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, " However, this comes at the cost of increased vram usage, generally in the range of 3-4x. ", -1); +const _hoisted_8$1 = { class: "flex-container" }; +const _hoisted_9$1 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Unsafe resolutions", -1); +const _sfc_main$1 = /* @__PURE__ */ defineComponent({ + __name: "Scalecrafter", + props: { + tab: { + type: String, + required: true + }, + target: { + type: String, + required: false, + default: "settings" + } + }, + setup(__props) { + const props = __props; + const settings = useSettings(); + const target = computed(() => { + if (props.target === "settings") { + return settings.data.settings; + } + return settings.defaultSettings; + }); + return (_ctx, _cache) => { + return openBlock(), createElementBlock(Fragment, null, [ + createBaseVNode("div", _hoisted_1$1, [ + _hoisted_2$1, + createVNode(unref(NSwitch), { + value: target.value[props.tab].scalecrafter.enabled, + "onUpdate:value": _cache[0] || (_cache[0] = ($event) => target.value[props.tab].scalecrafter.enabled = $event) + }, null, 8, ["value"]) + ]), + target.value[props.tab].scalecrafter.enabled ? (openBlock(), createBlock(unref(NSpace), { + key: 0, + vertical: "", + class: "left-container" + }, { + default: withCtx(() => [ + createVNode(unref(NAlert), { type: "warning" }, { + default: withCtx(() => [ + createTextVNode(" Only works with "), + _hoisted_3$1, + createTextVNode(" and "), + _hoisted_4$1, + createTextVNode(" sigmas ") + ]), + _: 1 + }), + createBaseVNode("div", _hoisted_5$1, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_6$1 + ]), + default: withCtx(() => [ + createTextVNode(" May generate more unique images. "), + _hoisted_7$1 + ]), + _: 1 + }), + createVNode(unref(NSwitch), { + value: target.value[props.tab].scalecrafter.disperse, + "onUpdate:value": _cache[1] || (_cache[1] = ($event) => target.value[props.tab].scalecrafter.disperse = $event) + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_8$1, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_9$1 + ]), + default: withCtx(() => [ + createTextVNode(" Allow generating with unique resolutions that don't have configs ready for them, or clamp them (really, force them) to the closest resolution. 
") + ]), + _: 1 + }), + createVNode(unref(NSwitch), { + value: target.value[props.tab].scalecrafter.unsafe_resolutions, + "onUpdate:value": _cache[2] || (_cache[2] = ($event) => target.value[props.tab].scalecrafter.unsafe_resolutions = $event) + }, null, 8, ["value"]) + ]) + ]), + _: 1 + })) : createCommentVNode("", true) + ], 64); + }; + } +}); +const _hoisted_1 = { class: "flex-container" }; +const _hoisted_2 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Enabled", -1); +const _hoisted_3 = { class: "flex-container" }; +const _hoisted_4 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Model", -1); +const _hoisted_5 = { class: "flex-container" }; +const _hoisted_6 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Scale Factor", -1); +const _hoisted_7 = { class: "flex-container" }; +const _hoisted_8 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Tile Size", -1); +const _hoisted_9 = { class: "flex-container" }; +const _hoisted_10 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Tile Padding", -1); +const _sfc_main = /* @__PURE__ */ defineComponent({ + __name: "Upscale", + props: { + tab: { + type: String, + required: true + }, + target: { + type: String, + required: false, + default: "settings" + } + }, + setup(__props) { + const props = __props; + const global = useState(); + const settings = useSettings(); + const target = computed(() => { + if (props.target === "settings") { + return settings.data.settings; + } + return settings.defaultSettings; + }); + const upscalerOptionsFull = computed(() => { + const localModels = global.state.models.filter( + (model) => model.backend === "Upscaler" && !(upscalerOptions.map((option) => option.label).indexOf(model.name) !== -1) + ).map((model) => ({ + label: model.name, + value: model.path + })); + return [...upscalerOptions, ...localModels]; + }); + return (_ctx, _cache) => { + return openBlock(), createBlock(unref(NCard), { + title: "Upscale", + class: "generate-extra-card" + }, { + default: withCtx(() => [ + createBaseVNode("div", _hoisted_1, [ + _hoisted_2, + createVNode(unref(NSwitch), { + value: target.value[props.tab].upscale.enabled, + "onUpdate:value": _cache[0] || (_cache[0] = ($event) => target.value[props.tab].upscale.enabled = $event) + }, null, 8, ["value"]) + ]), + target.value[props.tab].upscale.enabled ? 
(openBlock(), createBlock(unref(NSpace), { + key: 0, + vertical: "", + class: "left-container" + }, { + default: withCtx(() => [ + createBaseVNode("div", _hoisted_3, [ + _hoisted_4, + createVNode(unref(NSelect), { + value: target.value[props.tab].upscale.model, + "onUpdate:value": _cache[1] || (_cache[1] = ($event) => target.value[props.tab].upscale.model = $event), + style: { "margin-right": "12px" }, + filterable: "", + tag: "", + options: upscalerOptionsFull.value + }, null, 8, ["value", "options"]) + ]), + createBaseVNode("div", _hoisted_5, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_6 + ]), + default: withCtx(() => [ + createTextVNode(" TODO ") + ]), + _: 1 + }), + createVNode(unref(NSlider), { + value: target.value[props.tab].upscale.upscale_factor, + "onUpdate:value": _cache[2] || (_cache[2] = ($event) => target.value[props.tab].upscale.upscale_factor = $event), + min: 1, + max: 4, + step: 0.1, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: target.value[props.tab].upscale.upscale_factor, + "onUpdate:value": _cache[3] || (_cache[3] = ($event) => target.value[props.tab].upscale.upscale_factor = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" }, + min: 1, + max: 4, + step: 0.1 + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_7, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_8 + ]), + default: withCtx(() => [ + createTextVNode(" How large each tile should be. Larger tiles will use more memory. 0 will disable tiling. ") + ]), + _: 1 + }), + createVNode(unref(NSlider), { + value: target.value[props.tab].upscale.tile_size, + "onUpdate:value": _cache[4] || (_cache[4] = ($event) => target.value[props.tab].upscale.tile_size = $event), + min: 32, + max: 2048, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: target.value[props.tab].upscale.tile_size, + "onUpdate:value": _cache[5] || (_cache[5] = ($event) => target.value[props.tab].upscale.tile_size = $event), + size: "small", + min: 32, + max: 2048, + style: { "min-width": "96px", "width": "96px" } + }, null, 8, ["value"]) + ]), + createBaseVNode("div", _hoisted_9, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_10 + ]), + default: withCtx(() => [ + createTextVNode(" How much the tiles should overlap. Larger padding will use more memory, but the image should not have visible seams. 
") + ]), + _: 1 + }), + createVNode(unref(NSlider), { + value: target.value[props.tab].upscale.tile_padding, + "onUpdate:value": _cache[6] || (_cache[6] = ($event) => target.value[props.tab].upscale.tile_padding = $event), + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: target.value[props.tab].upscale.tile_padding, + "onUpdate:value": _cache[7] || (_cache[7] = ($event) => target.value[props.tab].upscale.tile_padding = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" } + }, null, 8, ["value"]) + ]) + ]), + _: 1 + })) : createCommentVNode("", true) + ]), + _: 1 + }); + }; + } +}); export { NForm as N, - _sfc_main as _, - NFormItem as a + _sfc_main$2 as _, + _sfc_main$3 as a, + _sfc_main as b, + NFormItem as c }; diff --git a/frontend/dist/assets/clock.js b/frontend/dist/assets/clock.js index 7deffb759..eaff2dd06 100644 --- a/frontend/dist/assets/clock.js +++ b/frontend/dist/assets/clock.js @@ -5,16 +5,17 @@ var __publicField = (obj, key, value) => { return value; }; import { N as NDescriptionsItem, a as NDescriptions } from "./DescriptionsItem.js"; -import { d as defineComponent, o as openBlock, j as createElementBlock, f as createBaseVNode, e as createBlock, w as withCtx, g as createVNode, h as unref, k as createTextVNode, C as toDisplayString, n as NCard, m as createCommentVNode, u as useSettings, l as NTooltip, F as Fragment, a as useState, c as computed, G as spaceRegex, B as NIcon, i as NSelect, H as promptHandleKeyUp, I as promptHandleKeyDown, J as NInput, _ as _export_sfc, K as watch, z as ref, t as serverUrl } from "./index.js"; -import { a as NSlider, N as NSwitch } from "./Switch.js"; +import { d as defineComponent, o as openBlock, a as createElementBlock, b as createBaseVNode, g as createBlock, w as withCtx, e as createVNode, f as unref, h as createTextVNode, E as toDisplayString, m as NCard, k as createCommentVNode, u as useSettings, N as NTooltip, c as computed, F as Fragment, l as useState, G as spaceRegex, D as NIcon, q as NSelect, H as promptHandleKeyUp, I as promptHandleKeyDown, J as NInput, _ as _export_sfc, K as watch, B as ref, x as serverUrl } from "./index.js"; +import { N as NSlider } from "./Slider.js"; import { N as NInputNumber } from "./InputNumber.js"; -import { N as NForm, a as NFormItem } from "./SamplerPicker.vue_vue_type_script_setup_true_lang.js"; -const _hoisted_1$2 = { +import { N as NForm, c as NFormItem } from "./Upscale.vue_vue_type_script_setup_true_lang.js"; +import { N as NSwitch } from "./Switch.js"; +const _hoisted_1$4 = { xmlns: "http://www.w3.org/2000/svg", "xmlns:xlink": "http://www.w3.org/1999/xlink", viewBox: "0 0 512 512" }; -const _hoisted_2$2 = /* @__PURE__ */ createBaseVNode( +const _hoisted_2$4 = /* @__PURE__ */ createBaseVNode( "path", { d: "M262.29 192.31a64 64 0 1 0 57.4 57.4a64.13 64.13 0 0 0-57.4-57.4zM416.39 256a154.34 154.34 0 0 1-1.53 20.79l45.21 35.46a10.81 10.81 0 0 1 2.45 13.75l-42.77 74a10.81 10.81 0 0 1-13.14 4.59l-44.9-18.08a16.11 16.11 0 0 0-15.17 1.75A164.48 164.48 0 0 1 325 400.8a15.94 15.94 0 0 0-8.82 12.14l-6.73 47.89a11.08 11.08 0 0 1-10.68 9.17h-85.54a11.11 11.11 0 0 1-10.69-8.87l-6.72-47.82a16.07 16.07 0 0 0-9-12.22a155.3 155.3 0 0 1-21.46-12.57a16 16 0 0 0-15.11-1.71l-44.89 18.07a10.81 10.81 0 0 1-13.14-4.58l-42.77-74a10.8 10.8 0 0 1 2.45-13.75l38.21-30a16.05 16.05 0 0 0 6-14.08c-.36-4.17-.58-8.33-.58-12.5s.21-8.27.58-12.35a16 16 0 0 0-6.07-13.94l-38.19-30A10.81 10.81 0 0 1 49.48 186l42.77-74a10.81 10.81 0 0 1 13.14-4.59l44.9 
18.08a16.11 16.11 0 0 0 15.17-1.75A164.48 164.48 0 0 1 187 111.2a15.94 15.94 0 0 0 8.82-12.14l6.73-47.89A11.08 11.08 0 0 1 213.23 42h85.54a11.11 11.11 0 0 1 10.69 8.87l6.72 47.82a16.07 16.07 0 0 0 9 12.22a155.3 155.3 0 0 1 21.46 12.57a16 16 0 0 0 15.11 1.71l44.89-18.07a10.81 10.81 0 0 1 13.14 4.58l42.77 74a10.8 10.8 0 0 1-2.45 13.75l-38.21 30a16.05 16.05 0 0 0-6.05 14.08c.33 4.14.55 8.3.55 12.47z", @@ -28,14 +29,14 @@ const _hoisted_2$2 = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_3$2 = [_hoisted_2$2]; +const _hoisted_3$3 = [_hoisted_2$4]; const SettingsOutline = defineComponent({ name: "SettingsOutline", render: function render(_ctx, _cache) { - return openBlock(), createElementBlock("svg", _hoisted_1$2, _hoisted_3$2); + return openBlock(), createElementBlock("svg", _hoisted_1$4, _hoisted_3$3); } }); -const _sfc_main$3 = /* @__PURE__ */ defineComponent({ +const _sfc_main$5 = /* @__PURE__ */ defineComponent({ __name: "OutputStats", props: { genData: { @@ -73,17 +74,17 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ }; } }); -const _hoisted_1$1 = { +const _hoisted_1$3 = { key: 0, class: "flex-container" }; -const _hoisted_2$1 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Batch Size", -1); -const _hoisted_3$1 = { +const _hoisted_2$3 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Batch Size", -1); +const _hoisted_3$2 = { key: 1, class: "flex-container" }; const _hoisted_4$1 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Batch Size", -1); -const _sfc_main$2 = /* @__PURE__ */ defineComponent({ +const _sfc_main$4 = /* @__PURE__ */ defineComponent({ __name: "BatchSizeInput", props: { batchSizeObject: { @@ -95,10 +96,10 @@ const _sfc_main$2 = /* @__PURE__ */ defineComponent({ const props = __props; const settings = useSettings(); return (_ctx, _cache) => { - return unref(settings).data.settings.aitDim.batch_size ? (openBlock(), createElementBlock("div", _hoisted_1$1, [ + return unref(settings).data.settings.aitDim.batch_size ? (openBlock(), createElementBlock("div", _hoisted_1$3, [ createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { trigger: withCtx(() => [ - _hoisted_2$1 + _hoisted_2$3 ]), default: withCtx(() => [ createTextVNode(" Number of images to generate in parallel. 
") @@ -120,7 +121,7 @@ const _sfc_main$2 = /* @__PURE__ */ defineComponent({ max: unref(settings).data.settings.aitDim.batch_size[1], style: { "min-width": "96px", "width": "96px" } }, null, 8, ["value", "min", "max"]) - ])) : (openBlock(), createElementBlock("div", _hoisted_3$1, [ + ])) : (openBlock(), createElementBlock("div", _hoisted_3$2, [ createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { trigger: withCtx(() => [ _hoisted_4$1 @@ -147,11 +148,81 @@ const _sfc_main$2 = /* @__PURE__ */ defineComponent({ }; } }); -const _hoisted_1 = { +const _hoisted_1$2 = { class: "flex-container" }; +const _hoisted_2$2 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "CFG Scale", -1); +const _hoisted_3$1 = /* @__PURE__ */ createBaseVNode("b", { class: "highlight" }, "We recommend using 3-15 for most images.", -1); +const _sfc_main$3 = /* @__PURE__ */ defineComponent({ + __name: "CFGScaleInput", + props: { + tab: { + type: String, + required: true + }, + target: { + type: String, + required: false, + default: "settings" + } + }, + setup(__props) { + const props = __props; + const settings = useSettings(); + const cfgMax = computed(() => { + var scale = 30; + return scale + Math.max( + settings.defaultSettings.api.apply_unsharp_mask ? 15 : 0, + settings.defaultSettings.api.cfg_rescale_threshold == "off" ? 0 : 30 + ); + }); + const target = computed(() => { + if (props.target === "settings") { + return settings.data.settings[props.tab]; + } else if (props.target === "adetailer") { + return settings.data.settings[props.tab].adetailer; + } else if (props.target === "defaultSettingsAdetailer") { + return settings.defaultSettings[props.tab].adetailer; + } else { + return settings.defaultSettings; + } + }); + return (_ctx, _cache) => { + return openBlock(), createElementBlock("div", _hoisted_1$2, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_2$2 + ]), + default: withCtx(() => [ + createTextVNode(" Guidance scale indicates how close should the model stay to the prompt. Higher values might be exactly what you want, but generated images might have some artifacts. Lower values give the model more freedom, and therefore might produce more coherent/less-artifacty images, but wouldn't follow the prompt as closely. 
"), + _hoisted_3$1 + ]), + _: 1 + }), + createVNode(unref(NSlider), { + value: target.value.cfg_scale, + "onUpdate:value": _cache[0] || (_cache[0] = ($event) => target.value.cfg_scale = $event), + min: 1, + max: cfgMax.value, + step: 0.5, + style: { "margin-right": "12px" } + }, null, 8, ["value", "max"]), + createVNode(unref(NInputNumber), { + value: target.value.cfg_scale, + "onUpdate:value": _cache[1] || (_cache[1] = ($event) => target.value.cfg_scale = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" }, + min: 1, + max: cfgMax.value, + step: 0.5 + }, null, 8, ["value", "max"]) + ]); + }; + } +}); +const _hoisted_1$1 = { key: 0, class: "flex-container" }; -const _hoisted_2 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Width", -1); +const _hoisted_2$1 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Width", -1); const _hoisted_3 = { key: 1, class: "flex-container" @@ -167,7 +238,7 @@ const _hoisted_7 = { class: "flex-container" }; const _hoisted_8 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Height", -1); -const _sfc_main$1 = /* @__PURE__ */ defineComponent({ +const _sfc_main$2 = /* @__PURE__ */ defineComponent({ __name: "DimensionsInput", props: { dimensionsObject: { @@ -180,15 +251,14 @@ const _sfc_main$1 = /* @__PURE__ */ defineComponent({ const settings = useSettings(); return (_ctx, _cache) => { return openBlock(), createElementBlock(Fragment, null, [ - unref(settings).data.settings.aitDim.width ? (openBlock(), createElementBlock("div", _hoisted_1, [ - _hoisted_2, + unref(settings).data.settings.aitDim.width ? (openBlock(), createElementBlock("div", _hoisted_1$1, [ + _hoisted_2$1, createVNode(unref(NSlider), { value: props.dimensionsObject.width, "onUpdate:value": _cache[0] || (_cache[0] = ($event) => props.dimensionsObject.width = $event), min: unref(settings).data.settings.aitDim.width[0], max: unref(settings).data.settings.aitDim.width[1], - step: 64, - style: { "margin-right": "12px" } + step: 64 }, null, 8, ["value", "min", "max"]), createVNode(unref(NInputNumber), { value: props.dimensionsObject.width, @@ -206,8 +276,7 @@ const _sfc_main$1 = /* @__PURE__ */ defineComponent({ "onUpdate:value": _cache[2] || (_cache[2] = ($event) => props.dimensionsObject.width = $event), min: 128, max: 2048, - step: 1, - style: { "margin-right": "12px" } + step: 1 }, null, 8, ["value"]), createVNode(unref(NInputNumber), { value: props.dimensionsObject.width, @@ -224,8 +293,7 @@ const _sfc_main$1 = /* @__PURE__ */ defineComponent({ "onUpdate:value": _cache[4] || (_cache[4] = ($event) => props.dimensionsObject.height = $event), min: unref(settings).data.settings.aitDim.height[0], max: unref(settings).data.settings.aitDim.height[1], - step: 64, - style: { "margin-right": "12px" } + step: 64 }, null, 8, ["value", "min", "max"]), createVNode(unref(NInputNumber), { value: props.dimensionsObject.height, @@ -243,8 +311,7 @@ const _sfc_main$1 = /* @__PURE__ */ defineComponent({ "onUpdate:value": _cache[6] || (_cache[6] = ($event) => props.dimensionsObject.height = $event), min: 128, max: 2048, - step: 1, - style: { "margin-right": "12px" } + step: 1 }, null, 8, ["value"]), createVNode(unref(NInputNumber), { value: props.dimensionsObject.height, @@ -258,7 +325,7 @@ const _sfc_main$1 = /* @__PURE__ */ defineComponent({ }; } }); -const _sfc_main = /* @__PURE__ */ defineComponent({ +const _sfc_main$1 = /* @__PURE__ */ defineComponent({ __name: "Prompt", props: { tab: { @@ -411,7 +478,70 @@ const _sfc_main = /* @__PURE__ */ 
defineComponent({ }; } }); -const Prompt = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-780680bc"]]); +const Prompt = /* @__PURE__ */ _export_sfc(_sfc_main$1, [["__scopeId", "data-v-780680bc"]]); +const _hoisted_1 = { + key: 0, + class: "flex-container" +}; +const _hoisted_2 = /* @__PURE__ */ createBaseVNode("p", { class: "slider-label" }, "Self Attention Scale", -1); +const _sfc_main = /* @__PURE__ */ defineComponent({ + __name: "SAGInput", + props: { + tab: { + type: String, + required: true + }, + target: { + type: String, + required: false, + default: "settings" + } + }, + setup(__props) { + const props = __props; + const settings = useSettings(); + const target = computed(() => { + if (props.target === "settings") { + return settings.data.settings[props.tab]; + } else if (props.target === "adetailer") { + return settings.data.settings[props.tab].adetailer; + } else if (props.target === "defaultSettingsAdetailer") { + return settings.defaultSettings[props.tab].adetailer; + } else { + return settings.defaultSettings; + } + }); + return (_ctx, _cache) => { + var _a; + return ((_a = unref(settings).data.settings.model) == null ? void 0 : _a.backend) === "PyTorch" ? (openBlock(), createElementBlock("div", _hoisted_1, [ + createVNode(unref(NTooltip), { style: { "max-width": "600px" } }, { + trigger: withCtx(() => [ + _hoisted_2 + ]), + default: withCtx(() => [ + createTextVNode(" If self attention is >0, SAG will guide the model and improve the quality of the image at the cost of speed. Higher values will follow the guidance more closely, which can lead to better, more sharp and detailed outputs. ") + ]), + _: 1 + }), + createVNode(unref(NSlider), { + value: target.value.self_attention_scale, + "onUpdate:value": _cache[0] || (_cache[0] = ($event) => target.value.self_attention_scale = $event), + min: 0, + max: 1, + step: 0.05, + style: { "margin-right": "12px" } + }, null, 8, ["value"]), + createVNode(unref(NInputNumber), { + value: target.value.self_attention_scale, + "onUpdate:value": _cache[1] || (_cache[1] = ($event) => target.value.self_attention_scale = $event), + size: "small", + style: { "min-width": "96px", "width": "96px" }, + step: 0.05 + }, null, 8, ["value"]) + ])) : createCommentVNode("", true); + }; + } +}); class BurnerClock { constructor(observed_value, settings, callback, timerOverrride = 0, sendInterrupt = true) { __publicField(this, "isChanging", ref(false)); @@ -484,7 +614,9 @@ class BurnerClock { export { BurnerClock as B, Prompt as P, - _sfc_main$1 as _, - _sfc_main$2 as a, - _sfc_main$3 as b + _sfc_main$3 as _, + _sfc_main as a, + _sfc_main$2 as b, + _sfc_main$4 as c, + _sfc_main$5 as d }; diff --git a/frontend/dist/assets/index.css b/frontend/dist/assets/index.css index 9df01e50a..c76c322d6 100644 --- a/frontend/dist/assets/index.css +++ b/frontend/dist/assets/index.css @@ -2,14 +2,15 @@ margin: 0 12px; } -.split { - width: 50%; -} - .flex-container { width: 100%; display: inline-flex; align-items: center; + gap: 0 8px; +} + +.flex-container.space-between { + justify-content: space-between; } .slider-label { @@ -26,6 +27,14 @@ color: #63e2b7; } +.generate-extra-card { + margin-top: 12px; +} + +.generate-extra-card:last-child { + margin-bottom: 12px; +} + .navbar { position: fixed; top: 0; @@ -50,23 +59,23 @@ justify-content: center; } -.progress-container[data-v-29f01b28] { +.progress-container[data-v-3a99505a] { margin: 12px; flex-grow: 1; width: 400px; } -.top-bar[data-v-29f01b28] { +.top-bar[data-v-3a99505a] { display: inline-flex; align-items: 
center; padding-top: 10px; padding-bottom: 10px; - width: calc(100% - 64px); + width: var(--ca9a9586); height: 32px; position: fixed; top: 0; z-index: 1; } -.logo[data-v-29f01b28] { +.logo[data-v-3a99505a] { margin-right: 16px; margin-left: 16px; } @@ -77,52 +86,39 @@ margin-bottom: 10px; } -.image-container img[data-v-5358ed01] { +.image-container img[data-v-d4ff54ab] { width: 100%; height: 100%; object-fit: contain; overflow: hidden; } -.image-container[data-v-5358ed01] { +.image-container[data-v-d4ff54ab] { height: 70vh; width: 100%; display: flex; justify-content: center; } -.image-container img[data-v-efacc8fd] { +.image-container img[data-v-a4145f6c] { width: 100%; height: 100%; object-fit: contain; overflow: hidden; } -.image-container[data-v-efacc8fd] { +.image-container[data-v-a4145f6c] { height: 70vh; width: 100%; display: flex; justify-content: center; } -.image-container img[data-v-9c556ef8] { - width: 100%; - height: 100%; - object-fit: contain; - overflow: hidden; -} -.image-container[data-v-9c556ef8] { - height: 70vh; - width: 100%; - display: flex; - justify-content: center; -} - -.hidden-input[data-v-7963dde9] { +.hidden-input[data-v-23b19530] { display: none; } -.utility-button[data-v-7963dde9] { +.utility-button[data-v-23b19530] { margin-right: 8px; } -.file-upload[data-v-7963dde9] { +.file-upload[data-v-23b19530] { appearance: none; background-color: transparent; border: 1px solid #63e2b7; @@ -142,51 +138,43 @@ vertical-align: middle; white-space: nowrap; } -.file-upload[data-v-7963dde9]:focus:not(:focus-visible):not(.focus-visible) { +.file-upload[data-v-23b19530]:focus:not(:focus-visible):not(.focus-visible) { box-shadow: none; outline: none; } -.file-upload[data-v-7963dde9]:focus { +.file-upload[data-v-23b19530]:focus { box-shadow: rgba(46, 164, 79, 0.4) 0 0 0 3px; outline: none; } -.file-upload[data-v-7963dde9]:disabled { +.file-upload[data-v-23b19530]:disabled { background-color: #94d3a2; border-color: rgba(27, 31, 35, 0.1); color: rgba(255, 255, 255, 0.8); cursor: default; } -.image-container[data-v-7963dde9] { +.image-container[data-v-23b19530] { width: 100%; display: flex; justify-content: center; } -.img-slider[data-v-e10a07d2] { - aspect-ratio: 1/1; - height: 182px; - width: auto; -} -.image-grid[data-v-e10a07d2] { +.image-grid[data-v-89afc237] { display: grid; - grid-template-columns: repeat( - var(--6b1de230), - 1fr - ); + grid-template-columns: repeat(auto-fill, minmax(280px, 1fr)); grid-gap: 8px; } -.top-bar[data-v-e10a07d2] { - background-color: var(--a55b21d8); -} -.image-column[data-v-e10a07d2] { - display: flex; - flex-direction: column; +.top-bar[data-v-89afc237] { + background-color: var(--66e6d45b); } .install[data-v-b405f046] { width: 100%; padding: 10px 0px; } + +.router-container { + margin-top: 52px; +} .autocomplete { position: relative; display: inline-block; @@ -194,52 +182,54 @@ .autocomplete-items { position: absolute; z-index: 99; - background-color: var(--4c7ba08e); - border-radius: var(--01ab46a4); + background-color: var(--e68ef196); + border-radius: var(--3f674355); padding: 2px; } .autocomplete-items div { padding: 8px; cursor: pointer; - border-radius: var(--01ab46a4); + border-radius: var(--3f674355); } .autocomplete-active { - background-color: var(--e4e78d9e); - color: var(--d0777f2a); + background-color: var(--646dc050); + color: var(--96bf2bb8); } #autocomplete-list { max-height: min(600px, 70vh); overflow-y: auto; } .n-card { - backdrop-filter: var(--98485856); + backdrop-filter: var(--b08f9a64); } .navbar .n-layout { - 
backdrop-filter: var(--98485856); + backdrop-filter: var(--b08f9a64); } .navbar .n-layout-toggle-button { - backdrop-filter: var(--98485856); + backdrop-filter: var(--b08f9a64); } .top-bar { - backdrop-filter: var(--98485856); - background-color: var(--6a1d04dc); + backdrop-filter: var(--b08f9a64); + background-color: var(--139458d6); } .navbar { - backdrop-filter: var(--98485856); + backdrop-filter: var(--b08f9a64); } #background { width: 100vw; height: 100vh; position: fixed; - background-image: var(--344206c2); + background-image: var(--31f48ff4); background-size: cover; background-position: center; background-attachment: fixed; top: 0; left: 0; z-index: -99; +} +.main { + margin-left: var(--3ac72808); }body { - margin-left: 64px; color: #fff; font-family: Inter, -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Oxygen, Ubuntu, Cantarell, "Fira Sans", "Droid Sans", "Helvetica Neue", @@ -276,3 +266,7 @@ a, .n-menu-item:last-child { margin-top: auto; } + +.n-drawer-content-wrapper { + backdrop-filter: blur(10px); +} diff --git a/frontend/dist/assets/index.js b/frontend/dist/assets/index.js index 3b6a6c669..82d03952d 100644 --- a/frontend/dist/assets/index.js +++ b/frontend/dist/assets/index.js @@ -1552,7 +1552,7 @@ function renderComponentRoot(instance) { slots, attrs, emit: emit2, - render: render15, + render: render17, renderCache, data, setupState, @@ -1566,7 +1566,7 @@ function renderComponentRoot(instance) { if (vnode.shapeFlag & 4) { const proxyToUse = withProxy || proxy; result = normalizeVNode( - render15.call( + render17.call( proxyToUse, proxyToUse, renderCache, @@ -2668,7 +2668,7 @@ function applyOptions(instance) { beforeUnmount, destroyed, unmounted, - render: render15, + render: render17, renderTracked, renderTriggered, errorCaptured, @@ -2767,8 +2767,8 @@ function applyOptions(instance) { instance.exposed = {}; } } - if (render15 && instance.render === NOOP) { - instance.render = render15; + if (render17 && instance.render === NOOP) { + instance.render = render17; } if (inheritAttrs != null) { instance.inheritAttrs = inheritAttrs; @@ -3000,7 +3000,7 @@ function createAppContext() { }; } let uid$1 = 0; -function createAppAPI(render15, hydrate) { +function createAppAPI(render17, hydrate) { return function createApp2(rootComponent, rootProps = null) { if (!isFunction$2(rootComponent)) { rootComponent = extend({}, rootComponent); @@ -3069,7 +3069,7 @@ function createAppAPI(render15, hydrate) { if (isHydrate && hydrate) { hydrate(vnode, rootContainer); } else { - render15(vnode, rootContainer, isSVG2); + render17(vnode, rootContainer, isSVG2); } isMounted2 = true; app2._container = rootContainer; @@ -3079,7 +3079,7 @@ function createAppAPI(render15, hydrate) { }, unmount() { if (isMounted2) { - render15(null, app2._container); + render17(null, app2._container); delete app2._container.__vue_app__; } }, @@ -4794,7 +4794,7 @@ function baseCreateRenderer(options, createHydrationFns) { } return hostNextSibling(vnode.anchor || vnode.el); }; - const render15 = (vnode, container, isSVG2) => { + const render17 = (vnode, container, isSVG2) => { if (vnode == null) { if (container._vnode) { unmount2(container._vnode, null, null, true); @@ -4826,9 +4826,9 @@ function baseCreateRenderer(options, createHydrationFns) { ); } return { - render: render15, + render: render17, hydrate, - createApp: createAppAPI(render15, hydrate) + createApp: createAppAPI(render17, hydrate) }; } function toggleRecurse({ effect, update }, allowed) { @@ -5360,6 +5360,11 @@ function cloneVNode(vnode, extraProps, 
mergeRef = false) { function createTextVNode(text = " ", flag = 0) { return createVNode(Text, null, text, flag); } +function createStaticVNode(content, numberOfNodes) { + const vnode = createVNode(Static, null, content); + vnode.staticCount = numberOfNodes; + return vnode; +} function createCommentVNode(text = "", asBlock = false) { return asBlock ? (openBlock(), createBlock(Comment, null, text)) : createVNode(Comment, null, text); } @@ -27886,8 +27891,8 @@ const NDropdownRenderOption = defineComponent({ } }, render() { - const { rawNode: { render: render15, props } } = this.tmNode; - return h("div", props, [render15 === null || render15 === void 0 ? void 0 : render15()]); + const { rawNode: { render: render17, props } } = this.tmNode; + return h("div", props, [render17 === null || render17 === void 0 ? void 0 : render17()]); } }); const NDropdownMenu = defineComponent({ @@ -38100,12 +38105,494 @@ const NThemeEditor = defineComponent({ }); } }); -const _hoisted_1$h = { +const loc = window.location; +let new_uri; +if (loc.protocol === "https:") { + new_uri = "wss:"; +} else { + new_uri = "ws:"; +} +const serverUrl = loc.protocol + "//" + loc.host; +const webSocketUrl = new_uri + "//" + loc.host; +const huggingfaceModelsFile = "https://raw.githubusercontent.com/VoltaML/voltaML-fast-stable-diffusion/experimental/static/huggingface-models.json"; +const isDev = false; +const defaultCapabilities = { + supported_backends: [["CPU", "cpu"]], + supported_precisions_cpu: ["float32"], + supported_precisions_gpu: ["float32"], + supported_torch_compile_backends: ["inductor"], + supported_self_attentions: [ + ["Cross-Attention", "cross-attention"], + ["Subquadratic Attention", "subquadratic"], + ["Multihead Attention", "multihead"] + ], + has_tensorfloat: false, + has_tensor_cores: false, + supports_xformers: false, + supports_triton: false, + supports_int8: false +}; +async function getCapabilities() { + try { + const response = await fetch(`${serverUrl}/api/hardware/capabilities`); + if (response.status !== 200) { + console.error("Server is not responding"); + return defaultCapabilities; + } + const data = await response.json(); + return data; + } catch (error) { + console.error(error); + return defaultCapabilities; + } +} +var _a; +const isClient = typeof window !== "undefined"; +const isFunction = (val) => typeof val === "function"; +const isString = (val) => typeof val === "string"; +const noop$1 = () => { +}; +const isIOS = isClient && ((_a = window == null ? void 0 : window.navigator) == null ? void 0 : _a.userAgent) && /iP(ad|hone|od)/.test(window.navigator.userAgent); +function resolveUnref(r) { + return typeof r === "function" ? r() : unref(r); +} +function identity(arg) { + return arg; +} +function tryOnScopeDispose(fn) { + if (getCurrentScope()) { + onScopeDispose(fn); + return true; + } + return false; +} +function resolveRef(r) { + return typeof r === "function" ? 
computed(r) : ref(r); +} +function tryOnMounted(fn, sync = true) { + if (getCurrentInstance()) + onMounted(fn); + else if (sync) + fn(); + else + nextTick(fn); +} +function useIntervalFn(cb, interval = 1e3, options = {}) { + const { + immediate = true, + immediateCallback = false + } = options; + let timer = null; + const isActive = ref(false); + function clean() { + if (timer) { + clearInterval(timer); + timer = null; + } + } + function pause() { + isActive.value = false; + clean(); + } + function resume() { + const intervalValue = resolveUnref(interval); + if (intervalValue <= 0) + return; + isActive.value = true; + if (immediateCallback) + cb(); + clean(); + timer = setInterval(cb, intervalValue); + } + if (immediate && isClient) + resume(); + if (isRef(interval) || isFunction(interval)) { + const stopWatch = watch(interval, () => { + if (isActive.value && isClient) + resume(); + }); + tryOnScopeDispose(stopWatch); + } + tryOnScopeDispose(pause); + return { + isActive, + pause, + resume + }; +} +function unrefElement(elRef) { + var _a2; + const plain = resolveUnref(elRef); + return (_a2 = plain == null ? void 0 : plain.$el) != null ? _a2 : plain; +} +const defaultWindow = isClient ? window : void 0; +function useEventListener(...args) { + let target; + let events2; + let listeners; + let options; + if (isString(args[0]) || Array.isArray(args[0])) { + [events2, listeners, options] = args; + target = defaultWindow; + } else { + [target, events2, listeners, options] = args; + } + if (!target) + return noop$1; + if (!Array.isArray(events2)) + events2 = [events2]; + if (!Array.isArray(listeners)) + listeners = [listeners]; + const cleanups = []; + const cleanup = () => { + cleanups.forEach((fn) => fn()); + cleanups.length = 0; + }; + const register = (el, event2, listener, options2) => { + el.addEventListener(event2, listener, options2); + return () => el.removeEventListener(event2, listener, options2); + }; + const stopWatch = watch(() => [unrefElement(target), resolveUnref(options)], ([el, options2]) => { + cleanup(); + if (!el) + return; + cleanups.push(...events2.flatMap((event2) => { + return listeners.map((listener) => register(el, event2, listener, options2)); + })); + }, { immediate: true, flush: "post" }); + const stop = () => { + stopWatch(); + cleanup(); + }; + tryOnScopeDispose(stop); + return stop; +} +let _iOSWorkaround = false; +function onClickOutside(target, handler, options = {}) { + const { window: window2 = defaultWindow, ignore = [], capture = true, detectIframe = false } = options; + if (!window2) + return; + if (isIOS && !_iOSWorkaround) { + _iOSWorkaround = true; + Array.from(window2.document.body.children).forEach((el) => el.addEventListener("click", noop$1)); + } + let shouldListen = true; + const shouldIgnore = (event2) => { + return ignore.some((target2) => { + if (typeof target2 === "string") { + return Array.from(window2.document.querySelectorAll(target2)).some((el) => el === event2.target || event2.composedPath().includes(el)); + } else { + const el = unrefElement(target2); + return el && (event2.target === el || event2.composedPath().includes(el)); + } + }); + }; + const listener = (event2) => { + const el = unrefElement(target); + if (!el || el === event2.target || event2.composedPath().includes(el)) + return; + if (event2.detail === 0) + shouldListen = !shouldIgnore(event2); + if (!shouldListen) { + shouldListen = true; + return; + } + handler(event2); + }; + const cleanup = [ + useEventListener(window2, "click", listener, { passive: true, capture }), + 
useEventListener(window2, "pointerdown", (e) => { + const el = unrefElement(target); + if (el) + shouldListen = !e.composedPath().includes(el) && !shouldIgnore(e); + }, { passive: true }), + detectIframe && useEventListener(window2, "blur", (event2) => { + var _a2; + const el = unrefElement(target); + if (((_a2 = window2.document.activeElement) == null ? void 0 : _a2.tagName) === "IFRAME" && !(el == null ? void 0 : el.contains(window2.document.activeElement))) + handler(event2); + }) + ].filter(Boolean); + const stop = () => cleanup.forEach((fn) => fn()); + return stop; +} +function useSupported(callback, sync = false) { + const isSupported = ref(); + const update = () => isSupported.value = Boolean(callback()); + update(); + tryOnMounted(update, sync); + return isSupported; +} +function useMediaQuery(query2, options = {}) { + const { window: window2 = defaultWindow } = options; + const isSupported = useSupported(() => window2 && "matchMedia" in window2 && typeof window2.matchMedia === "function"); + let mediaQuery; + const matches = ref(false); + const cleanup = () => { + if (!mediaQuery) + return; + if ("removeEventListener" in mediaQuery) + mediaQuery.removeEventListener("change", update); + else + mediaQuery.removeListener(update); + }; + const update = () => { + if (!isSupported.value) + return; + cleanup(); + mediaQuery = window2.matchMedia(resolveRef(query2).value); + matches.value = mediaQuery.matches; + if ("addEventListener" in mediaQuery) + mediaQuery.addEventListener("change", update); + else + mediaQuery.addListener(update); + }; + watchEffect(update); + tryOnScopeDispose(() => cleanup()); + return matches; +} +const _global = typeof globalThis !== "undefined" ? globalThis : typeof window !== "undefined" ? window : typeof global !== "undefined" ? global : typeof self !== "undefined" ? self : {}; +const globalKey = "__vueuse_ssr_handlers__"; +_global[globalKey] = _global[globalKey] || {}; +var SwipeDirection; +(function(SwipeDirection2) { + SwipeDirection2["UP"] = "UP"; + SwipeDirection2["RIGHT"] = "RIGHT"; + SwipeDirection2["DOWN"] = "DOWN"; + SwipeDirection2["LEFT"] = "LEFT"; + SwipeDirection2["NONE"] = "NONE"; +})(SwipeDirection || (SwipeDirection = {})); +var __defProp2 = Object.defineProperty; +var __getOwnPropSymbols = Object.getOwnPropertySymbols; +var __hasOwnProp = Object.prototype.hasOwnProperty; +var __propIsEnum = Object.prototype.propertyIsEnumerable; +var __defNormalProp2 = (obj, key, value) => key in obj ? 
__defProp2(obj, key, { enumerable: true, configurable: true, writable: true, value }) : obj[key] = value; +var __spreadValues = (a, b) => { + for (var prop in b || (b = {})) + if (__hasOwnProp.call(b, prop)) + __defNormalProp2(a, prop, b[prop]); + if (__getOwnPropSymbols) + for (var prop of __getOwnPropSymbols(b)) { + if (__propIsEnum.call(b, prop)) + __defNormalProp2(a, prop, b[prop]); + } + return a; +}; +const _TransitionPresets = { + easeInSine: [0.12, 0, 0.39, 0], + easeOutSine: [0.61, 1, 0.88, 1], + easeInOutSine: [0.37, 0, 0.63, 1], + easeInQuad: [0.11, 0, 0.5, 0], + easeOutQuad: [0.5, 1, 0.89, 1], + easeInOutQuad: [0.45, 0, 0.55, 1], + easeInCubic: [0.32, 0, 0.67, 0], + easeOutCubic: [0.33, 1, 0.68, 1], + easeInOutCubic: [0.65, 0, 0.35, 1], + easeInQuart: [0.5, 0, 0.75, 0], + easeOutQuart: [0.25, 1, 0.5, 1], + easeInOutQuart: [0.76, 0, 0.24, 1], + easeInQuint: [0.64, 0, 0.78, 0], + easeOutQuint: [0.22, 1, 0.36, 1], + easeInOutQuint: [0.83, 0, 0.17, 1], + easeInExpo: [0.7, 0, 0.84, 0], + easeOutExpo: [0.16, 1, 0.3, 1], + easeInOutExpo: [0.87, 0, 0.13, 1], + easeInCirc: [0.55, 0, 1, 0.45], + easeOutCirc: [0, 0.55, 0.45, 1], + easeInOutCirc: [0.85, 0, 0.15, 1], + easeInBack: [0.36, 0, 0.66, -0.56], + easeOutBack: [0.34, 1.56, 0.64, 1], + easeInOutBack: [0.68, -0.6, 0.32, 1.6] +}; +__spreadValues({ + linear: identity +}, _TransitionPresets); +const DEFAULT_PING_MESSAGE = "ping"; +function resolveNestedOptions(options) { + if (options === true) + return {}; + return options; +} +function useWebSocket(url, options = {}) { + const { + onConnected, + onDisconnected, + onError, + onMessage, + immediate = true, + autoClose = true, + protocols = [] + } = options; + const data = ref(null); + const status = ref("CLOSED"); + const wsRef = ref(); + const urlRef = resolveRef(url); + let heartbeatPause; + let heartbeatResume; + let explicitlyClosed = false; + let retried = 0; + let bufferedData = []; + let pongTimeoutWait; + const close = (code = 1e3, reason) => { + if (!wsRef.value) + return; + explicitlyClosed = true; + heartbeatPause == null ? void 0 : heartbeatPause(); + wsRef.value.close(code, reason); + }; + const _sendBuffer = () => { + if (bufferedData.length && wsRef.value && status.value === "OPEN") { + for (const buffer of bufferedData) + wsRef.value.send(buffer); + bufferedData = []; + } + }; + const resetHeartbeat = () => { + clearTimeout(pongTimeoutWait); + pongTimeoutWait = void 0; + }; + const send = (data2, useBuffer = true) => { + if (!wsRef.value || status.value !== "OPEN") { + if (useBuffer) + bufferedData.push(data2); + return false; + } + _sendBuffer(); + wsRef.value.send(data2); + return true; + }; + const _init = () => { + if (explicitlyClosed || typeof urlRef.value === "undefined") + return; + const ws = new WebSocket(urlRef.value, protocols); + wsRef.value = ws; + status.value = "CONNECTING"; + ws.onopen = () => { + status.value = "OPEN"; + onConnected == null ? void 0 : onConnected(ws); + heartbeatResume == null ? void 0 : heartbeatResume(); + _sendBuffer(); + }; + ws.onclose = (ev) => { + status.value = "CLOSED"; + wsRef.value = void 0; + onDisconnected == null ? void 0 : onDisconnected(ws, ev); + if (!explicitlyClosed && options.autoReconnect) { + const { + retries = -1, + delay = 1e3, + onFailed + } = resolveNestedOptions(options.autoReconnect); + retried += 1; + if (typeof retries === "number" && (retries < 0 || retried < retries)) + setTimeout(_init, delay); + else if (typeof retries === "function" && retries()) + setTimeout(_init, delay); + else + onFailed == null ? 
void 0 : onFailed(); + } + }; + ws.onerror = (e) => { + onError == null ? void 0 : onError(ws, e); + }; + ws.onmessage = (e) => { + if (options.heartbeat) { + resetHeartbeat(); + const { + message = DEFAULT_PING_MESSAGE + } = resolveNestedOptions(options.heartbeat); + if (e.data === message) + return; + } + data.value = e.data; + onMessage == null ? void 0 : onMessage(ws, e); + }; + }; + if (options.heartbeat) { + const { + message = DEFAULT_PING_MESSAGE, + interval = 1e3, + pongTimeout = 1e3 + } = resolveNestedOptions(options.heartbeat); + const { pause, resume } = useIntervalFn(() => { + send(message, false); + if (pongTimeoutWait != null) + return; + pongTimeoutWait = setTimeout(() => { + close(); + }, pongTimeout); + }, interval, { immediate: false }); + heartbeatPause = pause; + heartbeatResume = resume; + } + if (autoClose) { + useEventListener(window, "beforeunload", () => close()); + tryOnScopeDispose(close); + } + const open = () => { + close(); + explicitlyClosed = false; + retried = 0; + _init(); + }; + if (immediate) + watch(urlRef, open, { immediate: true }); + return { + data, + status, + close, + send, + open, + ws: wsRef + }; +} +const isLargeScreen = useMediaQuery("(min-width: 1000px)"); +const _hoisted_1$k = { xmlns: "http://www.w3.org/2000/svg", "xmlns:xlink": "http://www.w3.org/1999/xlink", viewBox: "0 0 512 512" }; -const _hoisted_2$f = /* @__PURE__ */ createBaseVNode( +const _hoisted_2$i = /* @__PURE__ */ createBaseVNode( + "path", + { + fill: "none", + stroke: "currentColor", + "stroke-linecap": "round", + "stroke-linejoin": "round", + "stroke-width": "32", + d: "M256 112v288" + }, + null, + -1 + /* HOISTED */ +); +const _hoisted_3$h = /* @__PURE__ */ createBaseVNode( + "path", + { + fill: "none", + stroke: "currentColor", + "stroke-linecap": "round", + "stroke-linejoin": "round", + "stroke-width": "32", + d: "M400 256H112" + }, + null, + -1 + /* HOISTED */ +); +const _hoisted_4$e = [_hoisted_2$i, _hoisted_3$h]; +const Add = defineComponent({ + name: "Add", + render: function render2(_ctx, _cache) { + return openBlock(), createElementBlock("svg", _hoisted_1$k, _hoisted_4$e); + } +}); +const _hoisted_1$j = { + xmlns: "http://www.w3.org/2000/svg", + "xmlns:xlink": "http://www.w3.org/1999/xlink", + viewBox: "0 0 512 512" +}; +const _hoisted_2$h = /* @__PURE__ */ createBaseVNode( "path", { d: "M368 96H144a16 16 0 0 1 0-32h224a16 16 0 0 1 0 32z", @@ -38115,7 +38602,7 @@ const _hoisted_2$f = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_3$e = /* @__PURE__ */ createBaseVNode( +const _hoisted_3$g = /* @__PURE__ */ createBaseVNode( "path", { d: "M400 144H112a16 16 0 0 1 0-32h288a16 16 0 0 1 0 32z", @@ -38125,7 +38612,7 @@ const _hoisted_3$e = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_4$b = /* @__PURE__ */ createBaseVNode( +const _hoisted_4$d = /* @__PURE__ */ createBaseVNode( "path", { d: "M419.13 448H92.87A44.92 44.92 0 0 1 48 403.13V204.87A44.92 44.92 0 0 1 92.87 160h326.26A44.92 44.92 0 0 1 464 204.87v198.26A44.92 44.92 0 0 1 419.13 448z", @@ -38135,19 +38622,19 @@ const _hoisted_4$b = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_5$7 = [_hoisted_2$f, _hoisted_3$e, _hoisted_4$b]; +const _hoisted_5$9 = [_hoisted_2$h, _hoisted_3$g, _hoisted_4$d]; const Albums = defineComponent({ name: "Albums", - render: function render2(_ctx, _cache) { - return openBlock(), createElementBlock("svg", _hoisted_1$h, _hoisted_5$7); + render: function render3(_ctx, _cache) { + return openBlock(), createElementBlock("svg", 
_hoisted_1$j, _hoisted_5$9); } }); -const _hoisted_1$g = { +const _hoisted_1$i = { xmlns: "http://www.w3.org/2000/svg", "xmlns:xlink": "http://www.w3.org/1999/xlink", viewBox: "0 0 512 512" }; -const _hoisted_2$e = /* @__PURE__ */ createBaseVNode( +const _hoisted_2$g = /* @__PURE__ */ createBaseVNode( "path", { d: "M459.94 53.25a16.06 16.06 0 0 0-23.22-.56L424.35 65a8 8 0 0 0 0 11.31l11.34 11.32a8 8 0 0 0 11.34 0l12.06-12c6.1-6.09 6.67-16.01.85-22.38z", @@ -38157,7 +38644,7 @@ const _hoisted_2$e = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_3$d = /* @__PURE__ */ createBaseVNode( +const _hoisted_3$f = /* @__PURE__ */ createBaseVNode( "path", { d: "M399.34 90L218.82 270.2a9 9 0 0 0-2.31 3.93L208.16 299a3.91 3.91 0 0 0 4.86 4.86l24.85-8.35a9 9 0 0 0 3.93-2.31L422 112.66a9 9 0 0 0 0-12.66l-9.95-10a9 9 0 0 0-12.71 0z", @@ -38167,7 +38654,7 @@ const _hoisted_3$d = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_4$a = /* @__PURE__ */ createBaseVNode( +const _hoisted_4$c = /* @__PURE__ */ createBaseVNode( "path", { d: "M386.34 193.66L264.45 315.79A41.08 41.08 0 0 1 247.58 326l-25.9 8.67a35.92 35.92 0 0 1-44.33-44.33l8.67-25.9a41.08 41.08 0 0 1 10.19-16.87l122.13-121.91a8 8 0 0 0-5.65-13.66H104a56 56 0 0 0-56 56v240a56 56 0 0 0 56 56h240a56 56 0 0 0 56-56V199.31a8 8 0 0 0-13.66-5.65z", @@ -38177,19 +38664,19 @@ const _hoisted_4$a = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_5$6 = [_hoisted_2$e, _hoisted_3$d, _hoisted_4$a]; +const _hoisted_5$8 = [_hoisted_2$g, _hoisted_3$f, _hoisted_4$c]; const Create = defineComponent({ name: "Create", - render: function render3(_ctx, _cache) { - return openBlock(), createElementBlock("svg", _hoisted_1$g, _hoisted_5$6); + render: function render4(_ctx, _cache) { + return openBlock(), createElementBlock("svg", _hoisted_1$i, _hoisted_5$8); } }); -const _hoisted_1$f = { +const _hoisted_1$h = { xmlns: "http://www.w3.org/2000/svg", "xmlns:xlink": "http://www.w3.org/1999/xlink", viewBox: "0 0 512 512" }; -const _hoisted_2$d = /* @__PURE__ */ createBaseVNode( +const _hoisted_2$f = /* @__PURE__ */ createBaseVNode( "path", { d: "M440.9 136.3a4 4 0 0 0 0-6.91L288.16 40.65a64.14 64.14 0 0 0-64.33 0L71.12 129.39a4 4 0 0 0 0 6.91L254 243.88a4 4 0 0 0 4.06 0z", @@ -38199,7 +38686,7 @@ const _hoisted_2$d = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_3$c = /* @__PURE__ */ createBaseVNode( +const _hoisted_3$e = /* @__PURE__ */ createBaseVNode( "path", { d: "M54 163.51a4 4 0 0 0-6 3.49v173.89a48 48 0 0 0 23.84 41.39L234 479.51a4 4 0 0 0 6-3.46V274.3a4 4 0 0 0-2-3.46z", @@ -38209,7 +38696,7 @@ const _hoisted_3$c = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_4$9 = /* @__PURE__ */ createBaseVNode( +const _hoisted_4$b = /* @__PURE__ */ createBaseVNode( "path", { d: "M272 275v201a4 4 0 0 0 6 3.46l162.15-97.23A48 48 0 0 0 464 340.89V167a4 4 0 0 0-6-3.45l-184 108a4 4 0 0 0-2 3.45z", @@ -38219,19 +38706,19 @@ const _hoisted_4$9 = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_5$5 = [_hoisted_2$d, _hoisted_3$c, _hoisted_4$9]; +const _hoisted_5$7 = [_hoisted_2$f, _hoisted_3$e, _hoisted_4$b]; const Cube = defineComponent({ name: "Cube", - render: function render4(_ctx, _cache) { - return openBlock(), createElementBlock("svg", _hoisted_1$f, _hoisted_5$5); + render: function render5(_ctx, _cache) { + return openBlock(), createElementBlock("svg", _hoisted_1$h, _hoisted_5$7); } }); -const _hoisted_1$e = { +const _hoisted_1$g = { xmlns: 
"http://www.w3.org/2000/svg", "xmlns:xlink": "http://www.w3.org/1999/xlink", viewBox: "0 0 512 512" }; -const _hoisted_2$c = /* @__PURE__ */ createBaseVNode( +const _hoisted_2$e = /* @__PURE__ */ createBaseVNode( "path", { d: "M428 224H288a48 48 0 0 1-48-48V36a4 4 0 0 0-4-4h-92a64 64 0 0 0-64 64v320a64 64 0 0 0 64 64h224a64 64 0 0 0 64-64V228a4 4 0 0 0-4-4zm-92 160H176a16 16 0 0 1 0-32h160a16 16 0 0 1 0 32zm0-80H176a16 16 0 0 1 0-32h160a16 16 0 0 1 0 32z", @@ -38241,7 +38728,7 @@ const _hoisted_2$c = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_3$b = /* @__PURE__ */ createBaseVNode( +const _hoisted_3$d = /* @__PURE__ */ createBaseVNode( "path", { d: "M419.22 188.59L275.41 44.78a2 2 0 0 0-3.41 1.41V176a16 16 0 0 0 16 16h129.81a2 2 0 0 0 1.41-3.41z", @@ -38251,19 +38738,19 @@ const _hoisted_3$b = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_4$8 = [_hoisted_2$c, _hoisted_3$b]; +const _hoisted_4$a = [_hoisted_2$e, _hoisted_3$d]; const DocumentText = defineComponent({ name: "DocumentText", - render: function render5(_ctx, _cache) { - return openBlock(), createElementBlock("svg", _hoisted_1$e, _hoisted_4$8); + render: function render6(_ctx, _cache) { + return openBlock(), createElementBlock("svg", _hoisted_1$g, _hoisted_4$a); } }); -const _hoisted_1$d = { +const _hoisted_1$f = { xmlns: "http://www.w3.org/2000/svg", "xmlns:xlink": "http://www.w3.org/1999/xlink", viewBox: "0 0 512 512" }; -const _hoisted_2$b = /* @__PURE__ */ createBaseVNode( +const _hoisted_2$d = /* @__PURE__ */ createBaseVNode( "path", { d: "M408 112H184a72 72 0 0 0-72 72v224a72 72 0 0 0 72 72h224a72 72 0 0 0 72-72V184a72 72 0 0 0-72-72zm-32.45 200H312v63.55c0 8.61-6.62 16-15.23 16.43A16 16 0 0 1 280 376v-64h-63.55c-8.61 0-16-6.62-16.43-15.23A16 16 0 0 1 216 280h64v-63.55c0-8.61 6.62-16 15.23-16.43A16 16 0 0 1 312 216v64h64a16 16 0 0 1 16 16.77c-.42 8.61-7.84 15.23-16.45 15.23z", @@ -38273,7 +38760,7 @@ const _hoisted_2$b = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_3$a = /* @__PURE__ */ createBaseVNode( +const _hoisted_3$c = /* @__PURE__ */ createBaseVNode( "path", { d: "M395.88 80A72.12 72.12 0 0 0 328 32H104a72 72 0 0 0-72 72v224a72.12 72.12 0 0 0 48 67.88V160a80 80 0 0 1 80-80z", @@ -38283,19 +38770,19 @@ const _hoisted_3$a = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_4$7 = [_hoisted_2$b, _hoisted_3$a]; +const _hoisted_4$9 = [_hoisted_2$d, _hoisted_3$c]; const Duplicate = defineComponent({ name: "Duplicate", - render: function render6(_ctx, _cache) { - return openBlock(), createElementBlock("svg", _hoisted_1$d, _hoisted_4$7); + render: function render7(_ctx, _cache) { + return openBlock(), createElementBlock("svg", _hoisted_1$f, _hoisted_4$9); } }); -const _hoisted_1$c = { +const _hoisted_1$e = { xmlns: "http://www.w3.org/2000/svg", "xmlns:xlink": "http://www.w3.org/1999/xlink", viewBox: "0 0 512 512" }; -const _hoisted_2$a = /* @__PURE__ */ createBaseVNode( +const _hoisted_2$c = /* @__PURE__ */ createBaseVNode( "path", { d: "M416 64H96a64.07 64.07 0 0 0-64 64v256a64.07 64.07 0 0 0 64 64h320a64.07 64.07 0 0 0 64-64V128a64.07 64.07 0 0 0-64-64zm-80 64a48 48 0 1 1-48 48a48.05 48.05 0 0 1 48-48zM96 416a32 32 0 0 1-32-32v-67.63l94.84-84.3a48.06 48.06 0 0 1 65.8 1.9l64.95 64.81L172.37 416zm352-32a32 32 0 0 1-32 32H217.63l121.42-121.42a47.72 47.72 0 0 1 61.64-.16L448 333.84z", @@ -38305,19 +38792,19 @@ const _hoisted_2$a = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_3$9 = [_hoisted_2$a]; +const _hoisted_3$b 
= [_hoisted_2$c]; const Image$1 = defineComponent({ name: "Image", - render: function render7(_ctx, _cache) { - return openBlock(), createElementBlock("svg", _hoisted_1$c, _hoisted_3$9); + render: function render8(_ctx, _cache) { + return openBlock(), createElementBlock("svg", _hoisted_1$e, _hoisted_3$b); } }); -const _hoisted_1$b = { +const _hoisted_1$d = { xmlns: "http://www.w3.org/2000/svg", "xmlns:xlink": "http://www.w3.org/1999/xlink", viewBox: "0 0 512 512" }; -const _hoisted_2$9 = /* @__PURE__ */ createBaseVNode( +const _hoisted_2$b = /* @__PURE__ */ createBaseVNode( "path", { d: "M450.29 112H142c-34 0-62 27.51-62 61.33v245.34c0 33.82 28 61.33 62 61.33h308c34 0 62-26.18 62-60V173.33c0-33.82-27.68-61.33-61.71-61.33zm-77.15 61.34a46 46 0 1 1-46.28 46a46.19 46.19 0 0 1 46.28-46.01zm-231.55 276c-17 0-29.86-13.75-29.86-30.66v-64.83l90.46-80.79a46.54 46.54 0 0 1 63.44 1.83L328.27 337l-113 112.33zM480 418.67a30.67 30.67 0 0 1-30.71 30.66H259L376.08 333a46.24 46.24 0 0 1 59.44-.16L480 370.59z", @@ -38327,7 +38814,7 @@ const _hoisted_2$9 = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_3$8 = /* @__PURE__ */ createBaseVNode( +const _hoisted_3$a = /* @__PURE__ */ createBaseVNode( "path", { d: "M384 32H64A64 64 0 0 0 0 96v256a64.11 64.11 0 0 0 48 62V152a72 72 0 0 1 72-72h326a64.11 64.11 0 0 0-62-48z", @@ -38337,51 +38824,105 @@ const _hoisted_3$8 = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_4$6 = [_hoisted_2$9, _hoisted_3$8]; +const _hoisted_4$8 = [_hoisted_2$b, _hoisted_3$a]; const Images = defineComponent({ name: "Images", - render: function render8(_ctx, _cache) { - return openBlock(), createElementBlock("svg", _hoisted_1$b, _hoisted_4$6); + render: function render9(_ctx, _cache) { + return openBlock(), createElementBlock("svg", _hoisted_1$d, _hoisted_4$8); } }); -const _hoisted_1$a = { +const _hoisted_1$c = { xmlns: "http://www.w3.org/2000/svg", "xmlns:xlink": "http://www.w3.org/1999/xlink", viewBox: "0 0 512 512" }; -const _hoisted_2$8 = /* @__PURE__ */ createBaseVNode( +const _hoisted_2$a = /* @__PURE__ */ createBaseVNode( "path", { - d: "M256 464c-114.69 0-208-93.23-208-207.82a207.44 207.44 0 0 1 74.76-160.13l16.9-14l28.17 33.72l-16.9 14A163.72 163.72 0 0 0 92 256.18c0 90.39 73.57 163.93 164 163.93s164-73.54 164-163.93a163.38 163.38 0 0 0-59.83-126.36l-17-14l28-33.82l17 14A207.13 207.13 0 0 1 464 256.18C464 370.77 370.69 464 256 464z", - fill: "currentColor" + fill: "none", + stroke: "currentColor", + "stroke-linecap": "round", + "stroke-miterlimit": "10", + "stroke-width": "48", + d: "M88 152h336" }, null, -1 /* HOISTED */ ); -const _hoisted_3$7 = /* @__PURE__ */ createBaseVNode( +const _hoisted_3$9 = /* @__PURE__ */ createBaseVNode( "path", { - d: "M234 48h44v224h-44z", - fill: "currentColor" + fill: "none", + stroke: "currentColor", + "stroke-linecap": "round", + "stroke-miterlimit": "10", + "stroke-width": "48", + d: "M88 256h336" }, null, -1 /* HOISTED */ ); -const _hoisted_4$5 = [_hoisted_2$8, _hoisted_3$7]; -const PowerSharp = defineComponent({ - name: "PowerSharp", - render: function render9(_ctx, _cache) { - return openBlock(), createElementBlock("svg", _hoisted_1$a, _hoisted_4$5); - } -}); -const _hoisted_1$9 = { - xmlns: "http://www.w3.org/2000/svg", - "xmlns:xlink": "http://www.w3.org/1999/xlink", +const _hoisted_4$7 = /* @__PURE__ */ createBaseVNode( + "path", + { + fill: "none", + stroke: "currentColor", + "stroke-linecap": "round", + "stroke-miterlimit": "10", + "stroke-width": "48", + d: "M88 360h336" + }, + null, + -1 + /* 
HOISTED */ +); +const _hoisted_5$6 = [_hoisted_2$a, _hoisted_3$9, _hoisted_4$7]; +const Menu = defineComponent({ + name: "Menu", + render: function render10(_ctx, _cache) { + return openBlock(), createElementBlock("svg", _hoisted_1$c, _hoisted_5$6); + } +}); +const _hoisted_1$b = { + xmlns: "http://www.w3.org/2000/svg", + "xmlns:xlink": "http://www.w3.org/1999/xlink", viewBox: "0 0 512 512" }; -const _hoisted_2$7 = /* @__PURE__ */ createBaseVNode( +const _hoisted_2$9 = /* @__PURE__ */ createBaseVNode( + "path", + { + d: "M256 464c-114.69 0-208-93.23-208-207.82a207.44 207.44 0 0 1 74.76-160.13l16.9-14l28.17 33.72l-16.9 14A163.72 163.72 0 0 0 92 256.18c0 90.39 73.57 163.93 164 163.93s164-73.54 164-163.93a163.38 163.38 0 0 0-59.83-126.36l-17-14l28-33.82l17 14A207.13 207.13 0 0 1 464 256.18C464 370.77 370.69 464 256 464z", + fill: "currentColor" + }, + null, + -1 + /* HOISTED */ +); +const _hoisted_3$8 = /* @__PURE__ */ createBaseVNode( + "path", + { + d: "M234 48h44v224h-44z", + fill: "currentColor" + }, + null, + -1 + /* HOISTED */ +); +const _hoisted_4$6 = [_hoisted_2$9, _hoisted_3$8]; +const PowerSharp = defineComponent({ + name: "PowerSharp", + render: function render11(_ctx, _cache) { + return openBlock(), createElementBlock("svg", _hoisted_1$b, _hoisted_4$6); + } +}); +const _hoisted_1$a = { + xmlns: "http://www.w3.org/2000/svg", + "xmlns:xlink": "http://www.w3.org/1999/xlink", + viewBox: "0 0 512 512" +}; +const _hoisted_2$8 = /* @__PURE__ */ createBaseVNode( "path", { d: "M256 176a80 80 0 1 0 80 80a80.24 80.24 0 0 0-80-80zm172.72 80a165.53 165.53 0 0 1-1.64 22.34l48.69 38.12a11.59 11.59 0 0 1 2.63 14.78l-46.06 79.52a11.64 11.64 0 0 1-14.14 4.93l-57.25-23a176.56 176.56 0 0 1-38.82 22.67l-8.56 60.78a11.93 11.93 0 0 1-11.51 9.86h-92.12a12 12 0 0 1-11.51-9.53l-8.56-60.78A169.3 169.3 0 0 1 151.05 393L93.8 416a11.64 11.64 0 0 1-14.14-4.92L33.6 331.57a11.59 11.59 0 0 1 2.63-14.78l48.69-38.12A174.58 174.58 0 0 1 83.28 256a165.53 165.53 0 0 1 1.64-22.34l-48.69-38.12a11.59 11.59 0 0 1-2.63-14.78l46.06-79.52a11.64 11.64 0 0 1 14.14-4.93l57.25 23a176.56 176.56 0 0 1 38.82-22.67l8.56-60.78A11.93 11.93 0 0 1 209.94 26h92.12a12 12 0 0 1 11.51 9.53l8.56 60.78A169.3 169.3 0 0 1 361 119l57.2-23a11.64 11.64 0 0 1 14.14 4.92l46.06 79.52a11.59 11.59 0 0 1-2.63 14.78l-48.69 38.12a174.58 174.58 0 0 1 1.64 22.66z", @@ -38391,19 +38932,19 @@ const _hoisted_2$7 = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_3$6 = [_hoisted_2$7]; +const _hoisted_3$7 = [_hoisted_2$8]; const SettingsSharp = defineComponent({ name: "SettingsSharp", - render: function render10(_ctx, _cache) { - return openBlock(), createElementBlock("svg", _hoisted_1$9, _hoisted_3$6); + render: function render12(_ctx, _cache) { + return openBlock(), createElementBlock("svg", _hoisted_1$a, _hoisted_3$7); } }); -const _hoisted_1$8 = { +const _hoisted_1$9 = { xmlns: "http://www.w3.org/2000/svg", "xmlns:xlink": "http://www.w3.org/1999/xlink", viewBox: "0 0 512 512" }; -const _hoisted_2$6 = /* @__PURE__ */ createBaseVNode( +const _hoisted_2$7 = /* @__PURE__ */ createBaseVNode( "path", { d: "M425.7 118.25A240 240 0 0 0 76.32 447l.18.2c.33.35.64.71 1 1.05c.74.84 1.58 1.79 2.57 2.78a41.17 41.17 0 0 0 60.36-.42a157.13 157.13 0 0 1 231.26 0a41.18 41.18 0 0 0 60.65.06l3.21-3.5l.18-.2a239.93 239.93 0 0 0-10-328.76zM240 128a16 16 0 0 1 32 0v32a16 16 0 0 1-32 0zM128 304H96a16 16 0 0 1 0-32h32a16 16 0 0 1 0 32zm48.8-95.2a16 16 0 0 1-22.62 0l-22.63-22.62a16 16 0 0 1 22.63-22.63l22.62 22.63a16 16 0 0 1 0 22.62zm149.3 23.1l-47.5 75.5a31 31 0 
0 1-7 7a30.11 30.11 0 0 1-35-49l75.5-47.5a10.23 10.23 0 0 1 11.7 0a10.06 10.06 0 0 1 2.3 14zm31.72-23.1a16 16 0 0 1-22.62-22.62l22.62-22.63a16 16 0 0 1 22.63 22.63zm65.88 227.6zM416 304h-32a16 16 0 0 1 0-32h32a16 16 0 0 1 0 32z", @@ -38413,19 +38954,19 @@ const _hoisted_2$6 = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_3$5 = [_hoisted_2$6]; +const _hoisted_3$6 = [_hoisted_2$7]; const Speedometer = defineComponent({ name: "Speedometer", - render: function render11(_ctx, _cache) { - return openBlock(), createElementBlock("svg", _hoisted_1$8, _hoisted_3$5); + render: function render13(_ctx, _cache) { + return openBlock(), createElementBlock("svg", _hoisted_1$9, _hoisted_3$6); } }); -const _hoisted_1$7 = { +const _hoisted_1$8 = { xmlns: "http://www.w3.org/2000/svg", "xmlns:xlink": "http://www.w3.org/1999/xlink", viewBox: "0 0 512 512" }; -const _hoisted_2$5 = /* @__PURE__ */ createBaseVNode( +const _hoisted_2$6 = /* @__PURE__ */ createBaseVNode( "path", { d: "M104 496H72a24 24 0 0 1-24-24V328a24 24 0 0 1 24-24h32a24 24 0 0 1 24 24v144a24 24 0 0 1-24 24z", @@ -38435,7 +38976,7 @@ const _hoisted_2$5 = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_3$4 = /* @__PURE__ */ createBaseVNode( +const _hoisted_3$5 = /* @__PURE__ */ createBaseVNode( "path", { d: "M328 496h-32a24 24 0 0 1-24-24V232a24 24 0 0 1 24-24h32a24 24 0 0 1 24 24v240a24 24 0 0 1-24 24z", @@ -38445,7 +38986,7 @@ const _hoisted_3$4 = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_4$4 = /* @__PURE__ */ createBaseVNode( +const _hoisted_4$5 = /* @__PURE__ */ createBaseVNode( "path", { d: "M440 496h-32a24 24 0 0 1-24-24V120a24 24 0 0 1 24-24h32a24 24 0 0 1 24 24v352a24 24 0 0 1-24 24z", @@ -38455,7 +38996,7 @@ const _hoisted_4$4 = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_5$4 = /* @__PURE__ */ createBaseVNode( +const _hoisted_5$5 = /* @__PURE__ */ createBaseVNode( "path", { d: "M216 496h-32a24 24 0 0 1-24-24V40a24 24 0 0 1 24-24h32a24 24 0 0 1 24 24v432a24 24 0 0 1-24 24z", @@ -38465,19 +39006,19 @@ const _hoisted_5$4 = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_6$2 = [_hoisted_2$5, _hoisted_3$4, _hoisted_4$4, _hoisted_5$4]; +const _hoisted_6$3 = [_hoisted_2$6, _hoisted_3$5, _hoisted_4$5, _hoisted_5$5]; const StatsChart = defineComponent({ name: "StatsChart", - render: function render12(_ctx, _cache) { - return openBlock(), createElementBlock("svg", _hoisted_1$7, _hoisted_6$2); + render: function render14(_ctx, _cache) { + return openBlock(), createElementBlock("svg", _hoisted_1$8, _hoisted_6$3); } }); -const _hoisted_1$6 = { +const _hoisted_1$7 = { xmlns: "http://www.w3.org/2000/svg", "xmlns:xlink": "http://www.w3.org/1999/xlink", viewBox: "0 0 512 512" }; -const _hoisted_2$4 = /* @__PURE__ */ createBaseVNode( +const _hoisted_2$5 = /* @__PURE__ */ createBaseVNode( "path", { d: "M434.67 285.59v-29.8c0-98.73-80.24-178.79-179.2-178.79a179 179 0 0 0-140.14 67.36m-38.53 82v29.8C76.8 355 157 435 256 435a180.45 180.45 0 0 0 140-66.92", @@ -38491,7 +39032,7 @@ const _hoisted_2$4 = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_3$3 = /* @__PURE__ */ createBaseVNode( +const _hoisted_3$4 = /* @__PURE__ */ createBaseVNode( "path", { fill: "none", @@ -38505,7 +39046,7 @@ const _hoisted_3$3 = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_4$3 = /* @__PURE__ */ createBaseVNode( +const _hoisted_4$4 = /* @__PURE__ */ createBaseVNode( "path", { fill: "none", @@ -38519,19 +39060,19 @@ const 
_hoisted_4$3 = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_5$3 = [_hoisted_2$4, _hoisted_3$3, _hoisted_4$3]; +const _hoisted_5$4 = [_hoisted_2$5, _hoisted_3$4, _hoisted_4$4]; const SyncSharp = defineComponent({ name: "SyncSharp", - render: function render13(_ctx, _cache) { - return openBlock(), createElementBlock("svg", _hoisted_1$6, _hoisted_5$3); + render: function render15(_ctx, _cache) { + return openBlock(), createElementBlock("svg", _hoisted_1$7, _hoisted_5$4); } }); -const _hoisted_1$5 = { +const _hoisted_1$6 = { xmlns: "http://www.w3.org/2000/svg", "xmlns:xlink": "http://www.w3.org/1999/xlink", viewBox: "0 0 512 512" }; -const _hoisted_2$3 = /* @__PURE__ */ createBaseVNode( +const _hoisted_2$4 = /* @__PURE__ */ createBaseVNode( "path", { d: "M346.65 304.3a136 136 0 0 0-180.71 0a21 21 0 1 0 27.91 31.38a94 94 0 0 1 124.89 0a21 21 0 0 0 27.91-31.4z", @@ -38541,7 +39082,7 @@ const _hoisted_2$3 = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_3$2 = /* @__PURE__ */ createBaseVNode( +const _hoisted_3$3 = /* @__PURE__ */ createBaseVNode( "path", { d: "M256.28 183.7a221.47 221.47 0 0 0-151.8 59.92a21 21 0 1 0 28.68 30.67a180.28 180.28 0 0 1 246.24 0a21 21 0 1 0 28.68-30.67a221.47 221.47 0 0 0-151.8-59.92z", @@ -38551,7 +39092,7 @@ const _hoisted_3$2 = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_4$2 = /* @__PURE__ */ createBaseVNode( +const _hoisted_4$3 = /* @__PURE__ */ createBaseVNode( "path", { d: "M462 175.86a309 309 0 0 0-411.44 0a21 21 0 1 0 28 31.29a267 267 0 0 1 355.43 0a21 21 0 0 0 28-31.31z", @@ -38561,7 +39102,7 @@ const _hoisted_4$2 = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_5$2 = /* @__PURE__ */ createBaseVNode( +const _hoisted_5$3 = /* @__PURE__ */ createBaseVNode( "circle", { cx: "256.28", @@ -38573,11 +39114,11 @@ const _hoisted_5$2 = /* @__PURE__ */ createBaseVNode( -1 /* HOISTED */ ); -const _hoisted_6$1 = [_hoisted_2$3, _hoisted_3$2, _hoisted_4$2, _hoisted_5$2]; +const _hoisted_6$2 = [_hoisted_2$4, _hoisted_3$3, _hoisted_4$3, _hoisted_5$3]; const Wifi = defineComponent({ name: "Wifi", - render: function render14(_ctx, _cache) { - return openBlock(), createElementBlock("svg", _hoisted_1$5, _hoisted_6$1); + render: function render16(_ctx, _cache) { + return openBlock(), createElementBlock("svg", _hoisted_1$6, _hoisted_6$2); } }); /*! @@ -38598,7 +39139,7 @@ function applyToParams(fn, params) { } return newParams; } -const noop$1 = () => { +const noop = () => { }; const isArray = Array.isArray; const TRAILING_SLASH_RE = /\/$/; @@ -39339,7 +39880,7 @@ function createRouterMatcher(routes, globalOptions) { } return originalMatcher ? () => { removeRoute(originalMatcher); - } : noop$1; + } : noop; } function removeRoute(matcherRef) { if (isRouteName(matcherRef)) { @@ -39702,7 +40243,7 @@ function useLink(props) { return router2[unref(props.replace) ? 
"replace" : "push"]( unref(props.to) // avoid uncaught errors are they are logged anyway - ).catch(noop$1); + ).catch(noop); } return Promise.resolve(); } @@ -40180,7 +40721,7 @@ function createRouter(options) { const toLocation = resolve2(to); const shouldRedirect = handleRedirectRecord(toLocation); if (shouldRedirect) { - pushWithRedirect(assign(shouldRedirect, { replace: true }), toLocation).catch(noop$1); + pushWithRedirect(assign(shouldRedirect, { replace: true }), toLocation).catch(noop); return; } pendingLocation = toLocation; @@ -40213,7 +40754,7 @@ function createRouter(options) { ) && !info.delta && info.type === NavigationType.pop) { routerHistory.go(-1, false); } - }).catch(noop$1); + }).catch(noop); return Promise.reject(); } if (info.delta) { @@ -40245,7 +40786,7 @@ function createRouter(options) { } } triggerAfterEach(toLocation, from, failure); - }).catch(noop$1); + }).catch(noop); }); } let readyHandlers = useCallbacks(); @@ -40378,159 +40919,9 @@ function extractChangingRecords(to, from) { function useRouter() { return inject(routerKey); } -const _hoisted_1$4 = { class: "navbar" }; -const _sfc_main$8 = /* @__PURE__ */ defineComponent({ - __name: "CollapsibleNavbar", - setup(__props) { - function renderIcon(icon) { - return () => h(NIcon, null, { default: () => h(icon) }); - } - const menuOptionsMain = [ - { - label: () => h(RouterLink, { to: "/" }, { default: () => "Text to Image" }), - key: "txt2img", - icon: renderIcon(Image$1) - }, - { - label: () => h(RouterLink, { to: "/img2img" }, { default: () => "Image to Image" }), - key: "img2img", - icon: renderIcon(Images) - }, - { - label: () => h( - RouterLink, - { to: "/imageProcessing" }, - { default: () => "Image Processing" } - ), - key: "imageProcessing", - icon: renderIcon(Duplicate) - }, - { - label: () => h(RouterLink, { to: "/tagger" }, { default: () => "Tagger" }), - key: "tagger", - icon: renderIcon(Create) - }, - { - label: () => h( - RouterLink, - { to: "/imageBrowser" }, - { default: () => "Image Browser" } - ), - key: "imageBrowser", - icon: renderIcon(Albums) - }, - { - label: () => h(RouterLink, { to: "/models" }, { default: () => "Models" }), - key: "models", - icon: renderIcon(Cube) - }, - { - label: () => h(RouterLink, { to: "/accelerate" }, { default: () => "Accelerate" }), - key: "plugins", - icon: renderIcon(Speedometer) - }, - // { - // label: () => h(RouterLink, { to: "/extra" }, { default: () => "Extra" }), - // key: "extra", - // icon: renderIcon(Archive), - // }, - { - label: () => h(RouterLink, { to: "/settings" }, { default: () => "Settings" }), - key: "settings", - icon: renderIcon(SettingsSharp) - } - ]; - let collapsed = ref(true); - return (_ctx, _cache) => { - return openBlock(), createElementBlock("div", _hoisted_1$4, [ - createVNode(unref(NLayout), { - style: { "height": "100%", "overflow": "visible" }, - "has-sider": "", - "content-style": "overflow: visible" - }, { - default: withCtx(() => [ - createVNode(unref(NLayoutSider), { - bordered: "", - "collapse-mode": "width", - "collapsed-width": 64, - width: 240, - collapsed: unref(collapsed), - "show-trigger": "", - onCollapse: _cache[0] || (_cache[0] = ($event) => isRef(collapsed) ? collapsed.value = true : collapsed = true), - onExpand: _cache[1] || (_cache[1] = ($event) => isRef(collapsed) ? 
collapsed.value = false : collapsed = false), - style: { "overflow": "visible", "overflow-x": "visible" } - }, { - default: withCtx(() => [ - createVNode(unref(NSpace), { - vertical: "", - justify: "space-between", - style: { "height": "100%", "overflow": "visible", "overflow-x": "visible" }, - "item-style": "height: 100%" - }, { - default: withCtx(() => [ - createVNode(unref(NMenu), { - collapsed: unref(collapsed), - "collapsed-width": 64, - "collapsed-icon-size": 22, - options: menuOptionsMain, - style: { "height": "100%", "display": "flex", "flex-direction": "column" } - }, null, 8, ["collapsed"]) - ]), - _: 1 - }) - ]), - _: 1 - }, 8, ["collapsed"]) - ]), - _: 1 - }) - ]); - }; - } -}); -const CollapsibleNavbar_vue_vue_type_style_index_0_lang = ""; -const loc = window.location; -let new_uri; -if (loc.protocol === "https:") { - new_uri = "wss:"; -} else { - new_uri = "ws:"; -} -const serverUrl = loc.protocol + "//" + loc.host; -const webSocketUrl = new_uri + "//" + loc.host; -const huggingfaceModelsFile = "https://raw.githubusercontent.com/VoltaML/voltaML-fast-stable-diffusion/experimental/static/huggingface-models.json"; -const defaultCapabilities = { - supported_backends: [["CPU", "cpu"]], - supported_precisions_cpu: ["float32"], - supported_precisions_gpu: ["float32"], - supported_torch_compile_backends: ["inductor"], - supported_self_attentions: [ - ["Cross-Attention", "cross-attention"], - ["Subquadratic Attention", "subquadratic"], - ["Multihead Attention", "multihead"] - ], - has_tensorfloat: false, - has_tensor_cores: false, - supports_xformers: false, - supports_triton: false, - supports_int8: false -}; -async function getCapabilities() { - try { - const response = await fetch(`${serverUrl}/api/hardware/capabilities`); - if (response.status !== 200) { - console.error("Server is not responding"); - return defaultCapabilities; - } - const data = await response.json(); - return data; - } catch (error) { - console.error(error); - return defaultCapabilities; - } -} const useState2 = defineStore("state", () => { const state2 = reactive({ + collapsibleBarActive: false, progress: 0, generating: false, downloading: false, @@ -40550,6 +40941,8 @@ const useState2 = defineStore("state", () => { txt2img: { images: [], highres: false, + refiner: false, + sdxl_resize: false, currentImage: "", genData: { time_taken: null, @@ -40631,8 +41024,14 @@ const useState2 = defineStore("state", () => { }, autofill: [], autofill_special: [], - capabilities: defaultCapabilities + capabilities: defaultCapabilities, // Should get replaced at runtime + settings_diff: { + active: false, + default_value: "", + current_value: "", + key: [] + } }); async function fetchCapabilites() { state2.capabilities = await getCapabilities(); @@ -40650,16 +41049,182 @@ const useState2 = defineStore("state", () => { } return { state: state2, fetchCapabilites, fetchAutofill }; }); +const _hoisted_1$5 = { class: "navbar" }; +const _sfc_main$9 = /* @__PURE__ */ defineComponent({ + __name: "CollapsibleNavbar", + setup(__props) { + const global2 = useState2(); + function renderIcon(icon) { + return () => h(NIcon, null, { default: () => h(icon) }); + } + const menuOptionsMain = [ + { + label: () => h(RouterLink, { to: "/" }, { default: () => "Text to Image" }), + key: "txt2img", + icon: renderIcon(Image$1) + }, + { + label: () => h(RouterLink, { to: "/img2img" }, { default: () => "Image to Image" }), + key: "img2img", + icon: renderIcon(Images) + }, + { + label: () => h( + RouterLink, + { to: "/imageProcessing" }, + { default: () 
=> "Image Processing" } + ), + key: "imageProcessing", + icon: renderIcon(Duplicate) + }, + { + label: () => h(RouterLink, { to: "/tagger" }, { default: () => "Tagger" }), + key: "tagger", + icon: renderIcon(Create) + }, + { + label: () => h( + RouterLink, + { to: "/imageBrowser" }, + { default: () => "Image Browser" } + ), + key: "imageBrowser", + icon: renderIcon(Albums) + }, + { + label: () => h(RouterLink, { to: "/models" }, { default: () => "Models" }), + key: "models", + icon: renderIcon(Cube) + }, + { + label: () => h(RouterLink, { to: "/accelerate" }, { default: () => "Accelerate" }), + key: "plugins", + icon: renderIcon(Speedometer) + }, + // { + // label: () => h(RouterLink, { to: "/extra" }, { default: () => "Extra" }), + // key: "extra", + // icon: renderIcon(Archive), + // }, + { + label: () => h(RouterLink, { to: "/settings" }, { default: () => "Settings" }), + key: "settings", + icon: renderIcon(SettingsSharp) + } + ]; + return (_ctx, _cache) => { + return openBlock(), createElementBlock("div", _hoisted_1$5, [ + unref(isLargeScreen) ? (openBlock(), createBlock(unref(NLayout), { + key: 0, + style: { "height": "100%", "overflow": "visible" }, + "has-sider": "", + "content-style": "overflow: visible" + }, { + default: withCtx(() => [ + createVNode(unref(NLayoutSider), { + bordered: "", + "collapse-mode": "width", + "collapsed-width": 64, + width: 240, + collapsed: !unref(global2).state.collapsibleBarActive, + "show-trigger": "", + onCollapse: _cache[0] || (_cache[0] = ($event) => unref(global2).state.collapsibleBarActive = false), + onExpand: _cache[1] || (_cache[1] = ($event) => unref(global2).state.collapsibleBarActive = true), + style: { "overflow": "visible", "overflow-x": "visible" } + }, { + default: withCtx(() => [ + createVNode(unref(NSpace), { + vertical: "", + justify: "space-between", + style: { "height": "100%", "overflow": "visible", "overflow-x": "visible" }, + "item-style": "height: 100%" + }, { + default: withCtx(() => [ + createVNode(unref(NMenu), { + collapsed: !unref(global2).state.collapsibleBarActive, + "collapsed-width": 64, + "collapsed-icon-size": 22, + options: menuOptionsMain, + style: { "height": "100%", "display": "flex", "flex-direction": "column" } + }, null, 8, ["collapsed"]) + ]), + _: 1 + }) + ]), + _: 1 + }, 8, ["collapsed"]) + ]), + _: 1 + })) : (openBlock(), createBlock(unref(NDrawer), { + key: 1, + show: unref(global2).state.collapsibleBarActive, + "onUpdate:show": _cache[2] || (_cache[2] = ($event) => unref(global2).state.collapsibleBarActive = $event), + placement: "left", + width: "272px" + }, { + default: withCtx(() => [ + createVNode(unref(NDrawerContent), { "body-content-style": { + padding: "0px" + } }, { + default: withCtx(() => [ + createVNode(unref(NLayout), { + style: { "height": "100%", "overflow": "visible" }, + "has-sider": "", + "content-style": "overflow: visible" + }, { + default: withCtx(() => [ + createVNode(unref(NLayoutSider), { + bordered: "", + "collapse-mode": "width", + collapsed: false, + style: { "overflow": "visible", "overflow-x": "visible" } + }, { + default: withCtx(() => [ + createVNode(unref(NSpace), { + vertical: "", + justify: "space-between", + style: { "height": "100%", "overflow": "visible", "overflow-x": "visible" }, + "item-style": "height: 100%" + }, { + default: withCtx(() => [ + createVNode(unref(NMenu), { + collapsed: false, + "collapsed-width": 64, + "collapsed-icon-size": 22, + options: menuOptionsMain, + style: { "height": "100%", "display": "flex", "flex-direction": "column" } + }) + ]), + _: 1 
+ }) + ]), + _: 1 + }) + ]), + _: 1 + }) + ]), + _: 1 + }) + ]), + _: 1 + }, 8, ["show"])) + ]); + }; + } +}); +const CollapsibleNavbar_vue_vue_type_style_index_0_lang = ""; var Backends = /* @__PURE__ */ ((Backends2) => { Backends2[Backends2["PyTorch"] = 0] = "PyTorch"; - Backends2[Backends2["AITemplate"] = 1] = "AITemplate"; - Backends2[Backends2["ONNX"] = 2] = "ONNX"; - Backends2[Backends2["unknown"] = 3] = "unknown"; - Backends2[Backends2["LoRA"] = 4] = "LoRA"; - Backends2[Backends2["LyCORIS"] = 5] = "LyCORIS"; - Backends2[Backends2["VAE"] = 6] = "VAE"; - Backends2[Backends2["Textual Inversion"] = 7] = "Textual Inversion"; - Backends2[Backends2["Upscaler"] = 8] = "Upscaler"; + Backends2[Backends2["SDXL"] = 1] = "SDXL"; + Backends2[Backends2["AITemplate"] = 2] = "AITemplate"; + Backends2[Backends2["ONNX"] = 3] = "ONNX"; + Backends2[Backends2["unknown"] = 4] = "unknown"; + Backends2[Backends2["LoRA"] = 5] = "LoRA"; + Backends2[Backends2["LyCORIS"] = 6] = "LyCORIS"; + Backends2[Backends2["VAE"] = 7] = "VAE"; + Backends2[Backends2["Textual Inversion"] = 8] = "Textual Inversion"; + Backends2[Backends2["Upscaler"] = 9] = "Upscaler"; return Backends2; })(Backends || {}); var ControlNetType = /* @__PURE__ */ ((ControlNetType2) => { @@ -40673,991 +41238,10 @@ var ControlNetType = /* @__PURE__ */ ((ControlNetType2) => { ControlNetType2["SEGMENTATION"] = "lllyasviel/sd-controlnet-seg"; return ControlNetType2; })(ControlNetType || {}); -const defaultSettings = { - $schema: "./schema/ui_data/settings.json", - backend: "PyTorch", - model: null, - flags: { - highres: { - image_upscaler: "RealESRGAN_x4plus_anime_6B", - mode: "latent", - scale: 2, - latent_scale_mode: "bislerp-tortured", - strength: 0.7, - steps: 50, - antialiased: false - } - }, - aitDim: { - width: void 0, - height: void 0, - batch_size: void 0 - }, - txt2img: { - width: 512, - height: 512, - seed: -1, - cfg_scale: 7, - sampler: 8, - prompt: "", - steps: 25, - batch_count: 1, - batch_size: 1, - negative_prompt: "", - self_attention_scale: 0, - sigmas: "automatic" - }, - img2img: { - width: 512, - height: 512, - seed: -1, - cfg_scale: 7, - sampler: 8, - prompt: "", - steps: 25, - batch_count: 1, - batch_size: 1, - negative_prompt: "", - denoising_strength: 0.6, - image: "", - self_attention_scale: 0, - sigmas: "automatic" - }, - inpainting: { - prompt: "", - negative_prompt: "", - image: "", - mask_image: "", - width: 512, - height: 512, - steps: 25, - cfg_scale: 7, - seed: -1, - batch_count: 1, - batch_size: 1, - sampler: 8, - self_attention_scale: 0, - sigmas: "automatic" - }, - controlnet: { - prompt: "", - image: "", - sampler: 8, - controlnet: ControlNetType.CANNY, - negative_prompt: "", - width: 512, - height: 512, - steps: 25, - cfg_scale: 7, - seed: -1, - batch_size: 1, - batch_count: 1, - controlnet_conditioning_scale: 1, - detection_resolution: 512, - is_preprocessed: false, - save_preprocessed: false, - return_preprocessed: true, - sigmas: "automatic" - }, - upscale: { - image: "", - upscale_factor: 4, - model: "RealESRGAN_x4plus_anime_6B", - tile_size: 128, - tile_padding: 10 - }, - tagger: { - image: "", - model: "deepdanbooru", - threshold: 0.5 - }, - api: { - websocket_sync_interval: 0.02, - websocket_perf_interval: 1, - enable_websocket_logging: true, - clip_skip: 1, - clip_quantization: "full", - autocast: true, - attention_processor: "xformers", - subquadratic_size: 512, - attention_slicing: "disabled", - channels_last: true, - vae_slicing: false, - vae_tiling: false, - trace_model: false, - cudnn_benchmark: false, - 
offload: "disabled", - dont_merge_latents: false, - device: "cuda:0", - data_type: "float16", - use_tomesd: true, - tomesd_ratio: 0.4, - tomesd_downsample_layers: 1, - deterministic_generation: false, - reduced_precision: false, - clear_memory_policy: "always", - huggingface_style_parsing: false, - autoloaded_textual_inversions: [], - autoloaded_models: [], - autoloaded_vae: {}, - save_path_template: "{folder}/{prompt}/{id}-{index}.{extension}", - image_extension: "png", - image_quality: 95, - disable_grid: false, - torch_compile: false, - torch_compile_fullgraph: false, - torch_compile_dynamic: false, - torch_compile_backend: "inductor", - torch_compile_mode: "default", - sfast_compile: false, - sfast_xformers: true, - sfast_triton: true, - sfast_cuda_graph: false, - hypertile: false, - hypertile_unet_chunk: 256, - sgm_noise_multiplier: false, - kdiffusers_quantization: true, - generator: "device", - live_preview_method: "approximation", - live_preview_delay: 2, - prompt_to_prompt: false, - prompt_to_prompt_model: "lllyasviel/Fooocus-Expansion", - prompt_to_prompt_device: "gpu", - free_u: false, - free_u_s1: 0.9, - free_u_s2: 0.2, - free_u_b1: 1.2, - free_u_b2: 1.4 - }, - aitemplate: { - num_threads: 8 - }, - onnx: { - quant_dict: { - text_encoder: null, - unet: null, - vae_decoder: null, - vae_encoder: null - }, - convert_to_fp16: true, - simplify_unet: false - }, - bot: { - default_scheduler: 8, - verbose: false, - use_default_negative_prompt: true - }, - frontend: { - theme: "dark", - enable_theme_editor: false, - image_browser_columns: 5, - on_change_timer: 2e3, - nsfw_ok_threshold: 0, - background_image_override: "", - disable_analytics: true - }, - sampler_config: {} -}; -let rSettings = JSON.parse(JSON.stringify(defaultSettings)); -try { - const req = new XMLHttpRequest(); - req.open("GET", `${serverUrl}/api/settings/`, false); - req.send(); - rSettings = { ...rSettings, ...JSON.parse(req.responseText) }; -} catch (e) { - console.error(e); -} -console.log("Settings:", rSettings); -const recievedSettings = rSettings; -class Settings { - constructor(settings_override) { - __publicField(this, "settings"); - this.settings = { ...defaultSettings, ...settings_override }; - } - to_json() { - return JSON.stringify(this.settings); - } -} -const diffusersSchedulerTuple = { - DDIM: 1, - DDPM: 2, - PNDM: 3, - LMSD: 4, - EulerDiscrete: 5, - HeunDiscrete: 6, - EulerAncestralDiscrete: 7, - DPMSolverMultistep: 8, - DPMSolverSinglestep: 9, - KDPM2Discrete: 10, - KDPM2AncestralDiscrete: 11, - DEISMultistep: 12, - UniPCMultistep: 13, - DPMSolverSDEScheduler: 14 -}; -const upscalerOptions = [ - { - label: "RealESRGAN_x4plus", - value: "RealESRGAN_x4plus" - }, - { - label: "RealESRNet_x4plus", - value: "RealESRNet_x4plus" - }, - { - label: "RealESRGAN_x4plus_anime_6B", - value: "RealESRGAN_x4plus_anime_6B" - }, - { - label: "RealESRGAN_x2plus", - value: "RealESRGAN_x2plus" - }, - { - label: "RealESR-general-x4v3", - value: "RealESR-general-x4v3" - } -]; -function getSchedulerOptions() { - const scheduler_options = [ - { - type: "group", - label: "k-diffusion", - key: "K-Diffusion", - children: [ - { label: "Euler a", value: "euler_a" }, - { label: "Euler", value: "euler" }, - { label: "LMS", value: "lms" }, - { label: "Heun", value: "heun" }, - { label: "DPM Fast", value: "dpm_fast" }, - { label: "DPM Adaptive", value: "dpm_adaptive" }, - { label: "DPM2", value: "dpm2" }, - { label: "DPM2 a", value: "dpm2_a" }, - { label: "DPM++ 2S a", value: "dpmpp_2s_a" }, - { label: "DPM++ 2M", value: "dpmpp_2m" }, 
- { label: "DPM++ 2M Sharp", value: "dpmpp_2m_sharp" }, - { label: "DPM++ SDE", value: "dpmpp_sde" }, - { label: "DPM++ 2M SDE", value: "dpmpp_2m_sde" }, - { label: "DPM++ 3M SDE", value: "dpmpp_3m_sde" }, - { label: "UniPC Multistep", value: "unipc_multistep" }, - { label: "Restart", value: "restart" } - ] - }, - { - type: "group", - label: "Diffusers", - key: "diffusers", - children: Object.keys(diffusersSchedulerTuple).map((key) => { - return { - label: key, - value: diffusersSchedulerTuple[key] - }; - }) - } - ]; - return scheduler_options; -} -function getControlNetOptions() { - const controlnet_options = [ - { - type: "group", - label: "ControlNet 1.1", - key: "ControlNet 1.1", - children: [ - { - label: "lllyasviel/control_v11p_sd15_canny", - value: "lllyasviel/control_v11p_sd15_canny" - }, - { - label: "lllyasviel/control_v11f1p_sd15_depth", - value: "lllyasviel/control_v11f1p_sd15_depth" - }, - { - label: "lllyasviel/control_v11e_sd15_ip2p", - value: "lllyasviel/control_v11e_sd15_ip2p" - }, - { - label: "lllyasviel/control_v11p_sd15_softedge", - value: "lllyasviel/control_v11p_sd15_softedge" - }, - { - label: "lllyasviel/control_v11p_sd15_openpose", - value: "lllyasviel/control_v11p_sd15_openpose" - }, - { - label: "lllyasviel/control_v11f1e_sd15_tile", - value: "lllyasviel/control_v11f1e_sd15_tile" - }, - { - label: "lllyasviel/control_v11p_sd15_mlsd", - value: "lllyasviel/control_v11p_sd15_mlsd" - }, - { - label: "lllyasviel/control_v11p_sd15_scribble", - value: "lllyasviel/control_v11p_sd15_scribble" - }, - { - label: "lllyasviel/control_v11p_sd15_seg", - value: "lllyasviel/control_v11p_sd15_seg" - } - ] - }, - { - type: "group", - label: "Special", - key: "Special", - children: [ - { - label: "DionTimmer/controlnet_qrcode", - value: "DionTimmer/controlnet_qrcode" - }, - { - label: "CrucibleAI/ControlNetMediaPipeFace", - value: "CrucibleAI/ControlNetMediaPipeFace" - } - ] - }, - { - type: "group", - label: "Original", - key: "Original", - children: [ - { - label: "lllyasviel/sd-controlnet-canny", - value: "lllyasviel/sd-controlnet-canny" - }, - { - label: "lllyasviel/sd-controlnet-depth", - value: "lllyasviel/sd-controlnet-depth" - }, - { - label: "lllyasviel/sd-controlnet-hed", - value: "lllyasviel/sd-controlnet-hed" - }, - { - label: "lllyasviel/sd-controlnet-mlsd", - value: "lllyasviel/sd-controlnet-mlsd" - }, - { - label: "lllyasviel/sd-controlnet-normal", - value: "lllyasviel/sd-controlnet-normal" - }, - { - label: "lllyasviel/sd-controlnet-openpose", - value: "lllyasviel/sd-controlnet-openpose" - }, - { - label: "lllyasviel/sd-controlnet-scribble", - value: "lllyasviel/sd-controlnet-scribble" - }, - { - label: "lllyasviel/sd-controlnet-seg", - value: "lllyasviel/sd-controlnet-seg" - } - ] - } - ]; - return controlnet_options; -} -const deepcopiedSettings = JSON.parse(JSON.stringify(recievedSettings)); -const useSettings = defineStore("settings", () => { - const data = reactive(new Settings(recievedSettings)); - const scheduler_options = computed(() => { - return getSchedulerOptions(); - }); - const controlnet_options = computed(() => { - return getControlNetOptions(); - }); - function resetSettings() { - console.log("Resetting settings to default"); - Object.assign(defaultSettings$1, defaultSettings); - } - const defaultSettings$1 = reactive(deepcopiedSettings); - return { - data, - scheduler_options, - controlnet_options, - defaultSettings: defaultSettings$1, - resetSettings - }; -}); -const ImageUpload_vue_vue_type_style_index_0_scoped_9ed1514f_lang = ""; -const 
_export_sfc = (sfc, props) => { - const target = sfc.__vccOpts || sfc; - for (const [key, val] of props) { - target[key] = val; - } - return target; -}; -const _sfc_main$7 = /* @__PURE__ */ defineComponent({ - __name: "InitHandler", - setup(__props) { - console.log( - ` - ██╗ ██╗ █████╗ ██╗ ████████╗ █████╗ ███╗ ███╗██╗ - ██║ ██║██╔══██╗██║ ╚══██╔══╝██╔══██╗████╗ ████║██║ - ╚██╗ ██╔╝██║ ██║██║ ██║ ███████║██╔████╔██║██║ - ╚████╔╝ ██║ ██║██║ ██║ ██╔══██║██║╚██╔╝██║██║ - ╚██╔╝ ╚█████╔╝███████╗ ██║ ██║ ██║██║ ╚═╝ ██║███████╗ - ╚═╝ ╚════╝ ╚══════╝ ╚═╝ ╚═╝ ╚═╝╚═╝ ╚═╝╚══════╝ - ` - ); - const global2 = useState2(); - global2.fetchCapabilites().then(() => { - console.log("Capabilities successfully fetched from the server"); - }); - global2.fetchAutofill(); - return (_ctx, _cache) => { - return null; - }; - } -}); -const _sfc_main$6 = /* @__PURE__ */ defineComponent({ - __name: "LogDrawer", - setup(__props) { - const glob = useState2(); - const log = computed(() => glob.state.log_drawer.logs.join("\n")); - return (_ctx, _cache) => { - return openBlock(), createBlock(unref(NDrawer), { - placement: "bottom", - show: unref(glob).state.log_drawer.enabled, - "onUpdate:show": _cache[0] || (_cache[0] = ($event) => unref(glob).state.log_drawer.enabled = $event), - "auto-focus": false, - "show-mask": true, - height: "70vh" - }, { - default: withCtx(() => [ - createVNode(unref(NDrawerContent), { - closable: "", - title: "Log - 500 latest messages" - }, { - default: withCtx(() => [ - createVNode(unref(NLog), { - ref: "logRef", - log: log.value, - trim: "", - style: { "height": "100%" } - }, null, 8, ["log"]) - ]), - _: 1 - }) - ]), - _: 1 - }, 8, ["show"]); - }; - } -}); -const _hoisted_1$3 = { style: { "width": "100%", "display": "inline-flex", "align-items": "center" } }; -const _hoisted_2$2 = /* @__PURE__ */ createBaseVNode("p", { style: { "width": "108px" } }, "Utilization", -1); -const _hoisted_3$1 = { style: { "width": "100%", "display": "inline-flex", "align-items": "center" } }; -const _hoisted_4$1 = /* @__PURE__ */ createBaseVNode("p", { style: { "width": "108px" } }, "Memory", -1); -const _hoisted_5$1 = { style: { "align-self": "flex-end", "margin-left": "12px" } }; -const _sfc_main$5 = /* @__PURE__ */ defineComponent({ - __name: "PerformanceDrawer", - setup(__props) { - const global2 = useState2(); - const glob = useState2(); - return (_ctx, _cache) => { - return openBlock(), createBlock(unref(NDrawer), { - placement: "bottom", - show: unref(glob).state.perf_drawer.enabled, - "onUpdate:show": _cache[0] || (_cache[0] = ($event) => unref(glob).state.perf_drawer.enabled = $event), - "auto-focus": false, - "show-mask": true, - height: "70vh" - }, { - default: withCtx(() => [ - createVNode(unref(NDrawerContent), { - closable: "", - title: "Performance statistics" - }, { - default: withCtx(() => [ - (openBlock(true), createElementBlock(Fragment, null, renderList(unref(global2).state.perf_drawer.gpus, (gpu) => { - return openBlock(), createBlock(unref(NCard), { - key: gpu.uuid, - style: { "margin-bottom": "12px" } - }, { - default: withCtx(() => [ - createVNode(unref(NSpace), { - inline: "", - justify: "space-between", - style: { "width": "100%" } - }, { - default: withCtx(() => [ - createBaseVNode("h3", null, "[" + toDisplayString(gpu.index) + "] " + toDisplayString(gpu.name), 1), - createBaseVNode("h4", null, toDisplayString(gpu.power_draw) + " / " + toDisplayString(gpu.power_limit) + "W ─ " + toDisplayString(gpu.temperature) + "°C ", 1) - ]), - _: 2 - }, 1024), - createBaseVNode("div", _hoisted_1$3, [ 
- _hoisted_2$2, - createVNode(unref(NProgress), { - percentage: gpu.utilization, - type: "line", - "indicator-placement": "inside", - style: { "flex-grow": "1", "width": "400px" } - }, null, 8, ["percentage"]) - ]), - createBaseVNode("div", _hoisted_3$1, [ - _hoisted_4$1, - createVNode(unref(NProgress), { - percentage: gpu.memory_usage, - type: "line", - style: { "flex-grow": "1", "width": "400px" }, - color: "#63e2b7", - "indicator-placement": "inside" - }, null, 8, ["percentage"]), - createBaseVNode("p", _hoisted_5$1, toDisplayString(gpu.memory_used) + " / " + toDisplayString(gpu.memory_total) + " MB ", 1) - ]) - ]), - _: 2 - }, 1024); - }), 128)) - ]), - _: 1 - }) - ]), - _: 1 - }, 8, ["show"]); - }; - } -}); -const _hoisted_1$2 = /* @__PURE__ */ createBaseVNode("a", { - target: "_blank", - href: "https://huggingface.co/settings/tokens" -}, "this page", -1); -const _hoisted_2$1 = { style: { "margin-top": "8px", "width": "100%", "display": "flex", "justify-content": "end" } }; -const _sfc_main$4 = /* @__PURE__ */ defineComponent({ - __name: "SecretsHandler", - setup(__props) { - const message = useMessage(); - const global2 = useState2(); - const hf_loading = ref(false); - const hf_token = ref(""); - function noSideSpace(value) { - return !/ /g.test(value); - } - function setHuggingfaceToken() { - hf_loading.value = true; - const url = new URL(`${serverUrl}/api/settings/inject-var-into-dotenv`); - url.searchParams.append("key", "HUGGINGFACE_TOKEN"); - url.searchParams.append("value", hf_token.value); - fetch(url, { method: "POST" }).then((res) => { - if (res.status !== 200) { - message.create("Failed to set HuggingFace token", { type: "error" }); - return; - } - global2.state.secrets.huggingface = "ok"; - message.create("HuggingFace token set successfully", { type: "success" }); - }).catch((e) => { - message.create(`Failed to set HuggingFace token: ${e.message}`, { - type: "error" - }); - }); - hf_loading.value = false; - } - return (_ctx, _cache) => { - return openBlock(), createBlock(unref(NModal), { - show: unref(global2).state.secrets.huggingface !== "ok", - preset: "card", - title: "Missing HuggingFace Token", - style: { "width": "80vw" }, - closable: false - }, { - default: withCtx(() => [ - createVNode(unref(NText), null, { - default: withCtx(() => [ - createTextVNode(" API does not have a HuggingFace token. Please enter a valid token to continue. You can get a token from "), - _hoisted_1$2 - ]), - _: 1 - }), - createVNode(unref(NInput), { - type: "password", - placeholder: "hf_123...", - style: { "margin-top": "8px" }, - "allow-input": noSideSpace, - value: hf_token.value, - "onUpdate:value": _cache[0] || (_cache[0] = ($event) => hf_token.value = $event) - }, null, 8, ["value"]), - createBaseVNode("div", _hoisted_2$1, [ - createVNode(unref(NButton), { - ghost: "", - type: "primary", - loading: hf_loading.value, - onClick: setHuggingfaceToken - }, { - default: withCtx(() => [ - createTextVNode("Set Token") - ]), - _: 1 - }, 8, ["loading"]) - ]) - ]), - _: 1 - }, 8, ["show"]); - }; - } -}); -var _a; -const isClient = typeof window !== "undefined"; -const isFunction = (val) => typeof val === "function"; -const isString = (val) => typeof val === "string"; -const noop = () => { -}; -const isIOS = isClient && ((_a = window == null ? void 0 : window.navigator) == null ? void 0 : _a.userAgent) && /iP(ad|hone|od)/.test(window.navigator.userAgent); -function resolveUnref(r) { - return typeof r === "function" ? 
r() : unref(r); -} -function identity(arg) { - return arg; -} -function tryOnScopeDispose(fn) { - if (getCurrentScope()) { - onScopeDispose(fn); - return true; - } - return false; -} -function resolveRef(r) { - return typeof r === "function" ? computed(r) : ref(r); -} -function useIntervalFn(cb, interval = 1e3, options = {}) { - const { - immediate = true, - immediateCallback = false - } = options; - let timer = null; - const isActive = ref(false); - function clean() { - if (timer) { - clearInterval(timer); - timer = null; - } - } - function pause() { - isActive.value = false; - clean(); - } - function resume() { - const intervalValue = resolveUnref(interval); - if (intervalValue <= 0) - return; - isActive.value = true; - if (immediateCallback) - cb(); - clean(); - timer = setInterval(cb, intervalValue); - } - if (immediate && isClient) - resume(); - if (isRef(interval) || isFunction(interval)) { - const stopWatch = watch(interval, () => { - if (isActive.value && isClient) - resume(); - }); - tryOnScopeDispose(stopWatch); - } - tryOnScopeDispose(pause); - return { - isActive, - pause, - resume - }; -} -function unrefElement(elRef) { - var _a2; - const plain = resolveUnref(elRef); - return (_a2 = plain == null ? void 0 : plain.$el) != null ? _a2 : plain; -} -const defaultWindow = isClient ? window : void 0; -function useEventListener(...args) { - let target; - let events2; - let listeners; - let options; - if (isString(args[0]) || Array.isArray(args[0])) { - [events2, listeners, options] = args; - target = defaultWindow; - } else { - [target, events2, listeners, options] = args; - } - if (!target) - return noop; - if (!Array.isArray(events2)) - events2 = [events2]; - if (!Array.isArray(listeners)) - listeners = [listeners]; - const cleanups = []; - const cleanup = () => { - cleanups.forEach((fn) => fn()); - cleanups.length = 0; - }; - const register = (el, event2, listener, options2) => { - el.addEventListener(event2, listener, options2); - return () => el.removeEventListener(event2, listener, options2); - }; - const stopWatch = watch(() => [unrefElement(target), resolveUnref(options)], ([el, options2]) => { - cleanup(); - if (!el) - return; - cleanups.push(...events2.flatMap((event2) => { - return listeners.map((listener) => register(el, event2, listener, options2)); - })); - }, { immediate: true, flush: "post" }); - const stop = () => { - stopWatch(); - cleanup(); - }; - tryOnScopeDispose(stop); - return stop; -} -let _iOSWorkaround = false; -function onClickOutside(target, handler, options = {}) { - const { window: window2 = defaultWindow, ignore = [], capture = true, detectIframe = false } = options; - if (!window2) - return; - if (isIOS && !_iOSWorkaround) { - _iOSWorkaround = true; - Array.from(window2.document.body.children).forEach((el) => el.addEventListener("click", noop)); - } - let shouldListen = true; - const shouldIgnore = (event2) => { - return ignore.some((target2) => { - if (typeof target2 === "string") { - return Array.from(window2.document.querySelectorAll(target2)).some((el) => el === event2.target || event2.composedPath().includes(el)); - } else { - const el = unrefElement(target2); - return el && (event2.target === el || event2.composedPath().includes(el)); - } - }); - }; - const listener = (event2) => { - const el = unrefElement(target); - if (!el || el === event2.target || event2.composedPath().includes(el)) - return; - if (event2.detail === 0) - shouldListen = !shouldIgnore(event2); - if (!shouldListen) { - shouldListen = true; - return; - } - handler(event2); - }; 
- const cleanup = [ - useEventListener(window2, "click", listener, { passive: true, capture }), - useEventListener(window2, "pointerdown", (e) => { - const el = unrefElement(target); - if (el) - shouldListen = !e.composedPath().includes(el) && !shouldIgnore(e); - }, { passive: true }), - detectIframe && useEventListener(window2, "blur", (event2) => { - var _a2; - const el = unrefElement(target); - if (((_a2 = window2.document.activeElement) == null ? void 0 : _a2.tagName) === "IFRAME" && !(el == null ? void 0 : el.contains(window2.document.activeElement))) - handler(event2); - }) - ].filter(Boolean); - const stop = () => cleanup.forEach((fn) => fn()); - return stop; -} -const _global = typeof globalThis !== "undefined" ? globalThis : typeof window !== "undefined" ? window : typeof global !== "undefined" ? global : typeof self !== "undefined" ? self : {}; -const globalKey = "__vueuse_ssr_handlers__"; -_global[globalKey] = _global[globalKey] || {}; -var SwipeDirection; -(function(SwipeDirection2) { - SwipeDirection2["UP"] = "UP"; - SwipeDirection2["RIGHT"] = "RIGHT"; - SwipeDirection2["DOWN"] = "DOWN"; - SwipeDirection2["LEFT"] = "LEFT"; - SwipeDirection2["NONE"] = "NONE"; -})(SwipeDirection || (SwipeDirection = {})); -var __defProp2 = Object.defineProperty; -var __getOwnPropSymbols = Object.getOwnPropertySymbols; -var __hasOwnProp = Object.prototype.hasOwnProperty; -var __propIsEnum = Object.prototype.propertyIsEnumerable; -var __defNormalProp2 = (obj, key, value) => key in obj ? __defProp2(obj, key, { enumerable: true, configurable: true, writable: true, value }) : obj[key] = value; -var __spreadValues = (a, b) => { - for (var prop in b || (b = {})) - if (__hasOwnProp.call(b, prop)) - __defNormalProp2(a, prop, b[prop]); - if (__getOwnPropSymbols) - for (var prop of __getOwnPropSymbols(b)) { - if (__propIsEnum.call(b, prop)) - __defNormalProp2(a, prop, b[prop]); - } - return a; -}; -const _TransitionPresets = { - easeInSine: [0.12, 0, 0.39, 0], - easeOutSine: [0.61, 1, 0.88, 1], - easeInOutSine: [0.37, 0, 0.63, 1], - easeInQuad: [0.11, 0, 0.5, 0], - easeOutQuad: [0.5, 1, 0.89, 1], - easeInOutQuad: [0.45, 0, 0.55, 1], - easeInCubic: [0.32, 0, 0.67, 0], - easeOutCubic: [0.33, 1, 0.68, 1], - easeInOutCubic: [0.65, 0, 0.35, 1], - easeInQuart: [0.5, 0, 0.75, 0], - easeOutQuart: [0.25, 1, 0.5, 1], - easeInOutQuart: [0.76, 0, 0.24, 1], - easeInQuint: [0.64, 0, 0.78, 0], - easeOutQuint: [0.22, 1, 0.36, 1], - easeInOutQuint: [0.83, 0, 0.17, 1], - easeInExpo: [0.7, 0, 0.84, 0], - easeOutExpo: [0.16, 1, 0.3, 1], - easeInOutExpo: [0.87, 0, 0.13, 1], - easeInCirc: [0.55, 0, 1, 0.45], - easeOutCirc: [0, 0.55, 0.45, 1], - easeInOutCirc: [0.85, 0, 0.15, 1], - easeInBack: [0.36, 0, 0.66, -0.56], - easeOutBack: [0.34, 1.56, 0.64, 1], - easeInOutBack: [0.68, -0.6, 0.32, 1.6] -}; -__spreadValues({ - linear: identity -}, _TransitionPresets); -const DEFAULT_PING_MESSAGE = "ping"; -function resolveNestedOptions(options) { - if (options === true) - return {}; - return options; -} -function useWebSocket(url, options = {}) { - const { - onConnected, - onDisconnected, - onError, - onMessage, - immediate = true, - autoClose = true, - protocols = [] - } = options; - const data = ref(null); - const status = ref("CLOSED"); - const wsRef = ref(); - const urlRef = resolveRef(url); - let heartbeatPause; - let heartbeatResume; - let explicitlyClosed = false; - let retried = 0; - let bufferedData = []; - let pongTimeoutWait; - const close = (code = 1e3, reason) => { - if (!wsRef.value) - return; - explicitlyClosed = true; - 
heartbeatPause == null ? void 0 : heartbeatPause(); - wsRef.value.close(code, reason); - }; - const _sendBuffer = () => { - if (bufferedData.length && wsRef.value && status.value === "OPEN") { - for (const buffer of bufferedData) - wsRef.value.send(buffer); - bufferedData = []; - } - }; - const resetHeartbeat = () => { - clearTimeout(pongTimeoutWait); - pongTimeoutWait = void 0; - }; - const send = (data2, useBuffer = true) => { - if (!wsRef.value || status.value !== "OPEN") { - if (useBuffer) - bufferedData.push(data2); - return false; - } - _sendBuffer(); - wsRef.value.send(data2); - return true; - }; - const _init = () => { - if (explicitlyClosed || typeof urlRef.value === "undefined") - return; - const ws = new WebSocket(urlRef.value, protocols); - wsRef.value = ws; - status.value = "CONNECTING"; - ws.onopen = () => { - status.value = "OPEN"; - onConnected == null ? void 0 : onConnected(ws); - heartbeatResume == null ? void 0 : heartbeatResume(); - _sendBuffer(); - }; - ws.onclose = (ev) => { - status.value = "CLOSED"; - wsRef.value = void 0; - onDisconnected == null ? void 0 : onDisconnected(ws, ev); - if (!explicitlyClosed && options.autoReconnect) { - const { - retries = -1, - delay = 1e3, - onFailed - } = resolveNestedOptions(options.autoReconnect); - retried += 1; - if (typeof retries === "number" && (retries < 0 || retried < retries)) - setTimeout(_init, delay); - else if (typeof retries === "function" && retries()) - setTimeout(_init, delay); - else - onFailed == null ? void 0 : onFailed(); - } - }; - ws.onerror = (e) => { - onError == null ? void 0 : onError(ws, e); - }; - ws.onmessage = (e) => { - if (options.heartbeat) { - resetHeartbeat(); - const { - message = DEFAULT_PING_MESSAGE - } = resolveNestedOptions(options.heartbeat); - if (e.data === message) - return; - } - data.value = e.data; - onMessage == null ? 
void 0 : onMessage(ws, e); - }; - }; - if (options.heartbeat) { - const { - message = DEFAULT_PING_MESSAGE, - interval = 1e3, - pongTimeout = 1e3 - } = resolveNestedOptions(options.heartbeat); - const { pause, resume } = useIntervalFn(() => { - send(message, false); - if (pongTimeoutWait != null) - return; - pongTimeoutWait = setTimeout(() => { - close(); - }, pongTimeout); - }, interval, { immediate: false }); - heartbeatPause = pause; - heartbeatResume = resume; - } - if (autoClose) { - useEventListener(window, "beforeunload", () => close()); - tryOnScopeDispose(close); - } - const open = () => { - close(); - explicitlyClosed = false; - retried = 0; - _init(); - }; - if (immediate) - watch(urlRef, open, { immediate: true }); - return { - data, - status, - close, - send, - open, - ws: wsRef - }; -} -function processWebSocket(message, global2, notificationProvider) { - switch (message.type) { - case "test": { - break; +function processWebSocket(message, global2, notificationProvider) { + switch (message.type) { + case "test": { + break; } case "progress": { global2.state.progress = message.data.progress; @@ -41692,8 +41276,10 @@ function processWebSocket(message, global2, notificationProvider) { break; } case "notification": { - message.data.timeout = message.data.timeout || 5e3; console.log(message.data.message); + if (message.data.timeout === 0) { + message.data.timeout = null; + } notificationProvider.create({ type: message.data.severity, title: message.data.title, @@ -41742,344 +41328,1098 @@ function processWebSocket(message, global2, notificationProvider) { global2.state.log_drawer.logs.pop(); } } - break; + break; + } + case "incorrect_settings_value": { + global2.state.settings_diff.default_value = message.data.default_value; + global2.state.settings_diff.current_value = message.data.current_value; + global2.state.settings_diff.key = message.data.key; + global2.state.settings_diff.active = true; + break; + } + default: { + console.log(message); + } + } +} +const useWebsocket = defineStore("websocket", () => { + const notificationProvider = useNotification(); + const messageProvider = useMessage(); + const global2 = useState2(); + const onConnectedCallbacks = []; + const onDisconnectedCallbacks = []; + const onRefreshCallbacks = []; + const websocket = useWebSocket(`${webSocketUrl}/api/websockets/master`, { + heartbeat: { + message: "ping", + interval: 1e3, + pongTimeout: 5e3 + }, + immediate: false, + onMessage: (ws, event2) => { + if (event2.data === "pong") { + return; + } + const data = JSON.parse(event2.data); + if (data.type === "refresh_models") { + onRefreshCallbacks.forEach((callback) => callback()); + console.log("Models refreshed"); + return; + } + processWebSocket(data, global2, notificationProvider); + }, + onConnected: () => { + messageProvider.success("Connected to server"); + onConnectedCallbacks.forEach((callback) => callback()); + }, + onDisconnected: () => { + onDisconnectedCallbacks.forEach((callback) => callback()); + } + }); + function ws_text() { + switch (readyState.value) { + case "CLOSED": + return "Closed"; + case "CONNECTING": + return "Connecting"; + case "OPEN": + return "Connected"; + } + } + function get_color() { + switch (readyState.value) { + case "CLOSED": + return "error"; + case "CONNECTING": + return "warning"; + case "OPEN": + return "success"; + } + } + const readyState = ref(websocket.status); + const loading = computed(() => readyState.value === "CONNECTING"); + const text = computed(() => ws_text()); + const color = computed(() => 
get_color()); + return { + websocket, + readyState, + loading, + text, + ws_open: websocket.open, + color, + onConnectedCallbacks, + onDisconnectedCallbacks, + onRefreshCallbacks + }; +}); +const spaceRegex = new RegExp("[\\s,]+"); +const arrowKeys = [38, 40]; +let currentFocus = -1; +function convertToTextString(str) { + const upper = str.charAt(0).toUpperCase() + str.slice(1); + return upper.replace(/_/g, " "); +} +function cloneObj(obj) { + return window.structuredClone(obj); +} +function addActive(x) { + if (!x) + return false; + removeActive(x); + if (currentFocus >= x.length) { + currentFocus = 0; + } + if (currentFocus < 0) { + currentFocus = x.length - 1; + } + x[currentFocus].classList.add("autocomplete-active"); +} +function removeActive(x) { + for (let i = 0; i < x.length; i++) { + x[i].classList.remove("autocomplete-active"); + } +} +function closeAllLists(elmnt, input) { + var _a2, _b; + const x = document.getElementsByClassName("autocomplete-items"); + for (let i = 0; i < x.length; i++) { + if (elmnt != x[i] && elmnt != input) { + (_b = (_a2 = x[i]) == null ? void 0 : _a2.parentNode) == null ? void 0 : _b.removeChild(x[i]); + } + } +} +async function startWebsocket(messageProvider) { + const websocketState = useWebsocket(); + const timeout = 1e3; + const controller = new AbortController(); + const id = setTimeout(() => controller.abort(), timeout); + const response = await fetch(`${serverUrl}/api/test/alive`, { + signal: controller.signal + }).catch(() => { + messageProvider.error("Server is not responding"); + }); + clearTimeout(id); + if (response === void 0) { + return; + } + if (response.status !== 200) { + messageProvider.error("Server is not responding"); + return; + } + console.log("Starting websocket"); + websocketState.ws_open(); +} +function getTextBoundaries(elem) { + if (elem === null) { + console.error("Element is null"); + return [0, 0]; + } + if (elem.tagName === "INPUT" && elem.type === "text" || elem.tagName === "TEXTAREA") { + return [ + elem.selectionStart === null ? 0 : elem.selectionStart, + elem.selectionEnd === null ? 
0 : elem.selectionEnd + ]; + } + console.error("Element is not input"); + return [0, 0]; +} +function promptHandleKeyUp(e, data, key, globalState) { + var _a2, _b, _c, _d, _e; + if (e.key === "ArrowUp" && e.ctrlKey) { + const values = getTextBoundaries( + document.activeElement + ); + const boundaryIndexStart = values[0]; + const boundaryIndexEnd = values[1]; + e.preventDefault(); + const elem = document.activeElement; + const current_selection = elem.value.substring( + boundaryIndexStart, + boundaryIndexEnd + ); + const regex = /\(([^:]+([:]?[\s]?)([\d.\d]+))\)/; + const matches = regex.exec(current_selection); + if (matches) { + if (matches) { + const value = parseFloat(matches[3]); + const new_value = (value + 0.1).toFixed(1); + const beforeString = elem.value.substring(0, boundaryIndexStart); + const afterString = elem.value.substring(boundaryIndexEnd); + const newString = `${beforeString}${current_selection.replace( + matches[3], + new_value + )}${afterString}`; + elem.value = newString; + data[key] = newString; + elem.setSelectionRange(boundaryIndexStart, boundaryIndexEnd); + } + } else if (boundaryIndexStart !== boundaryIndexEnd) { + const new_inner_string = `(${current_selection}:1.1)`; + const beforeString = elem.value.substring(0, boundaryIndexStart); + const afterString = elem.value.substring(boundaryIndexEnd); + elem.value = `${beforeString}${new_inner_string}${afterString}`; + data[key] = `${beforeString}${new_inner_string}${afterString}`; + elem.setSelectionRange(boundaryIndexStart, boundaryIndexEnd + 6); + } else { + console.log("No selection, cannot parse for weighting"); + } + } + if (e.key === "ArrowDown" && e.ctrlKey) { + const values = getTextBoundaries( + document.activeElement + ); + const boundaryIndexStart = values[0]; + const boundaryIndexEnd = values[1]; + e.preventDefault(); + const elem = document.activeElement; + const current_selection = elem.value.substring( + boundaryIndexStart, + boundaryIndexEnd + ); + const regex = /\(([^:]+([:]?[\s]?)([\d.\d]+))\)/; + const matches = regex.exec(current_selection); + if (matches) { + if (matches) { + const value = parseFloat(matches[3]); + const new_value = Math.max(value - 0.1, 0).toFixed(1); + const beforeString = elem.value.substring(0, boundaryIndexStart); + const afterString = elem.value.substring(boundaryIndexEnd); + const newString = `${beforeString}${current_selection.replace( + matches[3], + new_value + )}${afterString}`; + elem.value = newString; + data[key] = newString; + elem.setSelectionRange(boundaryIndexStart, boundaryIndexEnd); + } + } else if (boundaryIndexStart !== boundaryIndexEnd) { + const new_inner_string = `(${current_selection}:0.9)`; + const beforeString = elem.value.substring(0, boundaryIndexStart); + const afterString = elem.value.substring(boundaryIndexEnd); + elem.value = `${beforeString}${new_inner_string}${afterString}`; + data[key] = `${beforeString}${new_inner_string}${afterString}`; + elem.setSelectionRange(boundaryIndexStart, boundaryIndexEnd + 6); + } else { + console.log("No selection, cannot parse for weighting"); + } + } + const input = e.target; + if (input) { + const text = input.value; + const currentTokenStripped = (_a2 = text.split(",").pop()) == null ? 
void 0 : _a2.trim(); + closeAllLists(void 0, input); + if (!currentTokenStripped) { + return false; + } + const toAppend = []; + for (let i = 0; i < globalState.state.autofill_special.length; i++) { + if (globalState.state.autofill_special[i].toLowerCase().includes(currentTokenStripped.toLowerCase())) { + const b = document.createElement("DIV"); + b.innerText = globalState.state.autofill_special[i]; + b.innerHTML += ""; + b.addEventListener("click", function() { + input.value = text.substring(0, text.lastIndexOf(",") + 1) + globalState.state.autofill_special[i]; + data[key] = input.value; + closeAllLists(void 0, input); + }); + toAppend.push(b); + } + } + const lowercaseStrippedToken = currentTokenStripped.toLowerCase(); + if (lowercaseStrippedToken.length >= 3) { + for (let i = 0; i < globalState.state.autofill.length; i++) { + if (globalState.state.autofill[i].toLowerCase().includes(lowercaseStrippedToken)) { + if (toAppend.length >= 30) { + break; + } + const b = document.createElement("DIV"); + b.innerText = globalState.state.autofill[i]; + b.innerHTML += ""; + b.addEventListener("click", function() { + input.value = text.substring(0, text.lastIndexOf(",") + 1) + globalState.state.autofill[i]; + data[key] = input.value; + closeAllLists(void 0, input); + }); + toAppend.push(b); + } + } } - default: { - console.log(message); + if (toAppend.length === 0) { + return false; } - } -} -const useWebsocket = defineStore("websocket", () => { - const notificationProvider = useNotification(); - const messageProvider = useMessage(); - const global2 = useState2(); - const onConnectedCallbacks = []; - const onDisconnectedCallbacks = []; - const onRefreshCallbacks = []; - const websocket = useWebSocket(`${webSocketUrl}/api/websockets/master`, { - heartbeat: { - message: "ping", - interval: 1e3, - pongTimeout: 5e3 - }, - immediate: false, - onMessage: (ws, event2) => { - if (event2.data === "pong") { - return; - } - const data = JSON.parse(event2.data); - if (data.type === "refresh_models") { - onRefreshCallbacks.forEach((callback) => callback()); - console.log("Models refreshed"); - return; - } - processWebSocket(data, global2, notificationProvider); - }, - onConnected: () => { - messageProvider.success("Connected to server"); - onConnectedCallbacks.forEach((callback) => callback()); - }, - onDisconnected: () => { - onDisconnectedCallbacks.forEach((callback) => callback()); + const div = document.createElement("DIV"); + div.setAttribute("id", "autocomplete-list"); + div.setAttribute("class", "autocomplete-items"); + (_e = (_d = (_c = (_b = input.parentNode) == null ? void 0 : _b.parentNode) == null ? void 0 : _c.parentNode) == null ? void 0 : _d.parentNode) == null ? void 0 : _e.appendChild(div); + for (let i = 0; i < toAppend.length; i++) { + div.appendChild(toAppend[i]); } - }); - function ws_text() { - switch (readyState.value) { - case "CLOSED": - return "Closed"; - case "CONNECTING": - return "Connecting"; - case "OPEN": - return "Connected"; + onClickOutside(div, () => { + closeAllLists(void 0, input); + }); + const autocompleteList = document.getElementById("autocomplete-list"); + const x = autocompleteList == null ? 
void 0 : autocompleteList.getElementsByTagName("div"); + if (e.key === "ArrowDown") { + currentFocus++; + addActive(x); + e.preventDefault(); + } else if (e.key === "ArrowUp") { + currentFocus--; + addActive(x); + e.preventDefault(); + } else if (e.key === "Enter" || e.key === "Tab") { + e.stopImmediatePropagation(); + e.preventDefault(); + if (currentFocus > -1) { + if (x) + x[currentFocus].click(); + } + } else if (e.key === "Escape") { + closeAllLists(void 0, input); } } - function get_color() { - switch (readyState.value) { - case "CLOSED": - return "error"; - case "CONNECTING": - return "warning"; - case "OPEN": - return "success"; +} +function promptHandleKeyDown(e) { + if (arrowKeys.includes(e.keyCode) && e.ctrlKey) { + e.preventDefault(); + } + if (document.getElementById("autocomplete-list")) { + if (e.key === "Enter" || e.key === "Tab" || e.key === "ArrowDown" || e.key === "ArrowUp") { + e.preventDefault(); } } - const readyState = ref(websocket.status); - const loading = computed(() => readyState.value === "CONNECTING"); - const text = computed(() => ws_text()); - const color = computed(() => get_color()); - return { - websocket, - readyState, - loading, - text, - ws_open: websocket.open, - color, - onConnectedCallbacks, - onDisconnectedCallbacks, - onRefreshCallbacks - }; +} +function urlFromPath(path) { + const url = new URL(path, serverUrl); + return url.href; +} +const defaultADetailerSettings = { + enabled: false, + steps: 30, + cfg_scale: 7, + seed: -1, + sampler: "dpmpp_2m", + self_attention_scale: 0, + sigmas: "exponential", + strength: 0.4, + mask_dilation: 0, + mask_blur: 0, + mask_padding: 0, + iterations: 1, + upscale: 2 +}; +const deepShrinkFlagDefault = Object.freeze({ + enabled: false, + depth_1: 3, + stop_at_1: 0.15, + depth_2: 4, + stop_at_2: 0.3, + scaler: "bislerp", + base_scale: 0.5, + early_out: false }); -const spaceRegex = new RegExp("[\\s,]+"); -const arrowKeys = [38, 40]; -let currentFocus = -1; -function convertToTextString(str) { - const upper = str.charAt(0).toUpperCase() + str.slice(1); - return upper.replace(/_/g, " "); +const highresFixFlagDefault = Object.freeze({ + enabled: false, + scale: 2, + mode: "image", + image_upscaler: "RealESRGAN_x4plus_anime_6B", + latent_scale_mode: "bislerp", + antialiased: false, + strength: 0.65, + steps: 50 +}); +const scaleCrafterFlagDefault = Object.freeze({ + enabled: false, + base: "sd15", + unsafe_resolutions: true, + disperse: false +}); +const upscaleFlagDefault = Object.freeze({ + enabled: false, + upscale_factor: 4, + tile_size: 128, + tile_padding: 10, + model: "RealESRGAN_x4plus_anime_6B" +}); +var Sampler = /* @__PURE__ */ ((Sampler2) => { + Sampler2[Sampler2["DDIM"] = 1] = "DDIM"; + Sampler2[Sampler2["DDPM"] = 2] = "DDPM"; + Sampler2[Sampler2["PNDM"] = 3] = "PNDM"; + Sampler2[Sampler2["LMSD"] = 4] = "LMSD"; + Sampler2[Sampler2["EulerDiscrete"] = 5] = "EulerDiscrete"; + Sampler2[Sampler2["HeunDiscrete"] = 6] = "HeunDiscrete"; + Sampler2[Sampler2["EulerAncestralDiscrete"] = 7] = "EulerAncestralDiscrete"; + Sampler2[Sampler2["DPMSolverMultistep"] = 8] = "DPMSolverMultistep"; + Sampler2[Sampler2["DPMSolverSinglestep"] = 9] = "DPMSolverSinglestep"; + Sampler2[Sampler2["KDPM2Discrete"] = 10] = "KDPM2Discrete"; + Sampler2[Sampler2["KDPM2AncestralDiscrete"] = 11] = "KDPM2AncestralDiscrete"; + Sampler2[Sampler2["DEISMultistep"] = 12] = "DEISMultistep"; + Sampler2[Sampler2["UniPCMultistep"] = 13] = "UniPCMultistep"; + Sampler2[Sampler2["DPMSolverSDEScheduler"] = 14] = "DPMSolverSDEScheduler"; + return Sampler2; 
+})(Sampler || {}); +const defaultSettings = { + $schema: "./schema/ui_data/settings.json", + backend: "PyTorch", + model: null, + flags: { + sdxl: { + original_size: { + width: 1024, + height: 1024 + } + }, + refiner: { + model: void 0, + aesthetic_score: 6, + negative_aesthetic_score: 2.5, + steps: 50, + strength: 0.3 + } + }, + aitDim: { + width: void 0, + height: void 0, + batch_size: void 0 + }, + txt2img: { + width: 512, + height: 512, + seed: -1, + cfg_scale: 7, + sampler: Sampler.DPMSolverMultistep, + prompt: "", + steps: 25, + batch_count: 1, + batch_size: 1, + negative_prompt: "", + self_attention_scale: 0, + sigmas: "automatic", + highres: cloneObj(highresFixFlagDefault), + upscale: cloneObj(upscaleFlagDefault), + deepshrink: cloneObj(deepShrinkFlagDefault), + scalecrafter: cloneObj(scaleCrafterFlagDefault), + adetailer: cloneObj(defaultADetailerSettings) + }, + img2img: { + width: 512, + height: 512, + seed: -1, + cfg_scale: 7, + sampler: Sampler.DPMSolverMultistep, + prompt: "", + steps: 25, + batch_count: 1, + batch_size: 1, + negative_prompt: "", + denoising_strength: 0.6, + image: "", + self_attention_scale: 0, + sigmas: "automatic", + highres: cloneObj(highresFixFlagDefault), + upscale: cloneObj(upscaleFlagDefault), + deepshrink: cloneObj(deepShrinkFlagDefault), + scalecrafter: cloneObj(scaleCrafterFlagDefault), + adetailer: cloneObj(defaultADetailerSettings) + }, + inpainting: { + prompt: "", + negative_prompt: "", + image: "", + mask_image: "", + width: 512, + height: 512, + steps: 25, + cfg_scale: 7, + seed: -1, + batch_count: 1, + batch_size: 1, + strength: 0.65, + sampler: Sampler.DPMSolverMultistep, + self_attention_scale: 0, + sigmas: "automatic", + highres: cloneObj(highresFixFlagDefault), + upscale: cloneObj(upscaleFlagDefault), + deepshrink: cloneObj(deepShrinkFlagDefault), + scalecrafter: cloneObj(scaleCrafterFlagDefault), + adetailer: cloneObj(defaultADetailerSettings) + }, + controlnet: { + prompt: "", + image: "", + sampler: Sampler.DPMSolverMultistep, + controlnet: ControlNetType.CANNY, + negative_prompt: "", + width: 512, + height: 512, + steps: 25, + cfg_scale: 7, + seed: -1, + batch_size: 1, + batch_count: 1, + controlnet_conditioning_scale: 1, + detection_resolution: 512, + is_preprocessed: false, + save_preprocessed: false, + return_preprocessed: true, + self_attention_scale: 0, + sigmas: "automatic", + highres: cloneObj(highresFixFlagDefault), + upscale: cloneObj(upscaleFlagDefault), + deepshrink: cloneObj(deepShrinkFlagDefault), + scalecrafter: cloneObj(scaleCrafterFlagDefault), + adetailer: cloneObj(defaultADetailerSettings) + }, + upscale: { + image: "", + upscale_factor: 4, + model: "RealESRGAN_x4plus_anime_6B", + tile_size: 128, + tile_padding: 10 + }, + tagger: { + image: "", + model: "deepdanbooru", + threshold: 0.5 + }, + api: { + websocket_sync_interval: 0.02, + websocket_perf_interval: 1, + enable_websocket_logging: true, + clip_skip: 1, + clip_quantization: "full", + autocast: true, + attention_processor: "xformers", + subquadratic_size: 512, + attention_slicing: "disabled", + channels_last: true, + trace_model: false, + cudnn_benchmark: false, + offload: "disabled", + dont_merge_latents: false, + device: "cuda:0", + data_type: "float16", + use_tomesd: true, + tomesd_ratio: 0.4, + tomesd_downsample_layers: 1, + deterministic_generation: false, + reduced_precision: false, + clear_memory_policy: "always", + huggingface_style_parsing: false, + autoloaded_textual_inversions: [], + autoloaded_models: [], + autoloaded_vae: {}, + 
save_path_template: "{folder}/{prompt}/{id}-{index}.{extension}", + image_extension: "png", + image_quality: 95, + disable_grid: false, + torch_compile: false, + torch_compile_fullgraph: false, + torch_compile_dynamic: false, + torch_compile_backend: "inductor", + torch_compile_mode: "default", + sfast_compile: false, + sfast_xformers: true, + sfast_triton: true, + sfast_cuda_graph: false, + hypertile: false, + hypertile_unet_chunk: 256, + sgm_noise_multiplier: false, + kdiffusers_quantization: true, + xl_refiner: "joint", + generator: "device", + live_preview_method: "approximation", + live_preview_delay: 2, + upcast_vae: false, + vae_slicing: false, + vae_tiling: false, + apply_unsharp_mask: false, + cfg_rescale_threshold: 10, + prompt_to_prompt: false, + prompt_to_prompt_model: "lllyasviel/Fooocus-Expansion", + prompt_to_prompt_device: "gpu", + free_u: false, + free_u_s1: 0.9, + free_u_s2: 0.2, + free_u_b1: 1.2, + free_u_b2: 1.4 + }, + aitemplate: { + num_threads: 8 + }, + onnx: { + quant_dict: { + text_encoder: null, + unet: null, + vae_decoder: null, + vae_encoder: null + }, + convert_to_fp16: true, + simplify_unet: false + }, + bot: { + default_scheduler: Sampler.DPMSolverMultistep, + verbose: false, + use_default_negative_prompt: true + }, + frontend: { + theme: "dark", + enable_theme_editor: false, + image_browser_columns: 5, + on_change_timer: 2e3, + nsfw_ok_threshold: 0, + background_image_override: "", + disable_analytics: true + }, + sampler_config: {} +}; +let rSettings = JSON.parse(JSON.stringify(defaultSettings)); +try { + const req = new XMLHttpRequest(); + req.open("GET", `${serverUrl}/api/settings/`, false); + req.send(); + rSettings = { ...rSettings, ...JSON.parse(req.responseText) }; +} catch (e) { + console.error(e); } -function addActive(x) { - if (!x) - return false; - removeActive(x); - if (currentFocus >= x.length) { - currentFocus = 0; +console.log("Settings:", rSettings); +const recievedSettings = rSettings; +class Settings { + constructor(settings_override) { + __publicField(this, "settings"); + this.settings = { ...defaultSettings, ...settings_override }; } - if (currentFocus < 0) { - currentFocus = x.length - 1; + to_json() { + return JSON.stringify(this.settings); } - x[currentFocus].classList.add("autocomplete-active"); } -function removeActive(x) { - for (let i = 0; i < x.length; i++) { - x[i].classList.remove("autocomplete-active"); +const diffusersSchedulerTuple = { + DDIM: 1, + DDPM: 2, + PNDM: 3, + LMSD: 4, + EulerDiscrete: 5, + HeunDiscrete: 6, + EulerAncestralDiscrete: 7, + DPMSolverMultistep: 8, + DPMSolverSinglestep: 9, + KDPM2Discrete: 10, + KDPM2AncestralDiscrete: 11, + DEISMultistep: 12, + UniPCMultistep: 13, + DPMSolverSDEScheduler: 14 +}; +const upscalerOptions = [ + { + label: "RealESRGAN_x4plus", + value: "RealESRGAN_x4plus" + }, + { + label: "RealESRNet_x4plus", + value: "RealESRNet_x4plus" + }, + { + label: "RealESRGAN_x4plus_anime_6B", + value: "RealESRGAN_x4plus_anime_6B" + }, + { + label: "RealESRGAN_x2plus", + value: "RealESRGAN_x2plus" + }, + { + label: "RealESR-general-x4v3", + value: "RealESR-general-x4v3" } -} -function closeAllLists(elmnt, input) { - var _a2, _b; - const x = document.getElementsByClassName("autocomplete-items"); - for (let i = 0; i < x.length; i++) { - if (elmnt != x[i] && elmnt != input) { - (_b = (_a2 = x[i]) == null ? void 0 : _a2.parentNode) == null ? 
void 0 : _b.removeChild(x[i]); +]; +function getSamplerOptions() { + const scheduler_options = [ + { + type: "group", + label: "k-diffusion", + key: "K-Diffusion", + children: [ + { label: "Euler a", value: "euler_a" }, + { label: "Euler", value: "euler" }, + { label: "LMS", value: "lms" }, + { label: "Heun", value: "heun" }, + { label: "Heun++", value: "heunpp" }, + { label: "DPM Fast", value: "dpm_fast" }, + { label: "DPM Adaptive", value: "dpm_adaptive" }, + { label: "DPM2", value: "dpm2" }, + { label: "DPM2 a", value: "dpm2_a" }, + { label: "DPM++ 2S a", value: "dpmpp_2s_a" }, + { label: "DPM++ 2M", value: "dpmpp_2m" }, + { label: "DPM++ 2M Sharp", value: "dpmpp_2m_sharp" }, + { label: "DPM++ SDE", value: "dpmpp_sde" }, + { label: "DPM++ 2M SDE", value: "dpmpp_2m_sde" }, + { label: "DPM++ 3M SDE", value: "dpmpp_3m_sde" }, + { label: "UniPC Multistep", value: "unipc_multistep" }, + { label: "Restart", value: "restart" } + ] + }, + { + type: "group", + label: "Diffusers", + key: "diffusers", + children: [ + ...Object.keys(diffusersSchedulerTuple).map((key) => { + return { + label: key, + value: diffusersSchedulerTuple[key] + }; + }), + { label: "SASolverMultistep", value: "sasolver" } + ] } - } + ]; + return scheduler_options; } -async function startWebsocket(messageProvider) { - const websocketState = useWebsocket(); - const timeout = 1e3; - const controller = new AbortController(); - const id = setTimeout(() => controller.abort(), timeout); - const response = await fetch(`${serverUrl}/api/test/alive`, { - signal: controller.signal - }).catch(() => { - messageProvider.error("Server is not responding"); +function getControlNetOptions() { + const controlnet_options = [ + { + type: "group", + label: "ControlNet 1.1", + key: "ControlNet 1.1", + children: [ + { + label: "lllyasviel/control_v11p_sd15_canny", + value: "lllyasviel/control_v11p_sd15_canny" + }, + { + label: "lllyasviel/control_v11f1p_sd15_depth", + value: "lllyasviel/control_v11f1p_sd15_depth" + }, + { + label: "lllyasviel/control_v11e_sd15_ip2p", + value: "lllyasviel/control_v11e_sd15_ip2p" + }, + { + label: "lllyasviel/control_v11p_sd15_softedge", + value: "lllyasviel/control_v11p_sd15_softedge" + }, + { + label: "lllyasviel/control_v11p_sd15_openpose", + value: "lllyasviel/control_v11p_sd15_openpose" + }, + { + label: "lllyasviel/control_v11f1e_sd15_tile", + value: "lllyasviel/control_v11f1e_sd15_tile" + }, + { + label: "lllyasviel/control_v11p_sd15_mlsd", + value: "lllyasviel/control_v11p_sd15_mlsd" + }, + { + label: "lllyasviel/control_v11p_sd15_scribble", + value: "lllyasviel/control_v11p_sd15_scribble" + }, + { + label: "lllyasviel/control_v11p_sd15_seg", + value: "lllyasviel/control_v11p_sd15_seg" + } + ] + }, + { + type: "group", + label: "Special", + key: "Special", + children: [ + { + label: "DionTimmer/controlnet_qrcode", + value: "DionTimmer/controlnet_qrcode" + }, + { + label: "CrucibleAI/ControlNetMediaPipeFace", + value: "CrucibleAI/ControlNetMediaPipeFace" + } + ] + }, + { + type: "group", + label: "Original", + key: "Original", + children: [ + { + label: "lllyasviel/sd-controlnet-canny", + value: "lllyasviel/sd-controlnet-canny" + }, + { + label: "lllyasviel/sd-controlnet-depth", + value: "lllyasviel/sd-controlnet-depth" + }, + { + label: "lllyasviel/sd-controlnet-hed", + value: "lllyasviel/sd-controlnet-hed" + }, + { + label: "lllyasviel/sd-controlnet-mlsd", + value: "lllyasviel/sd-controlnet-mlsd" + }, + { + label: "lllyasviel/sd-controlnet-normal", + value: "lllyasviel/sd-controlnet-normal" + }, + { + label: 
"lllyasviel/sd-controlnet-openpose", + value: "lllyasviel/sd-controlnet-openpose" + }, + { + label: "lllyasviel/sd-controlnet-scribble", + value: "lllyasviel/sd-controlnet-scribble" + }, + { + label: "lllyasviel/sd-controlnet-seg", + value: "lllyasviel/sd-controlnet-seg" + } + ] + } + ]; + return controlnet_options; +} +const deepcopiedSettings = JSON.parse(JSON.stringify(recievedSettings)); +const useSettings = defineStore("settings", () => { + const data = reactive(new Settings(recievedSettings)); + const samplers = computed(() => { + return getSamplerOptions(); }); - clearTimeout(id); - if (response === void 0) { - return; - } - if (response.status !== 200) { - messageProvider.error("Server is not responding"); - return; + const controlnet_options = computed(() => { + return getControlNetOptions(); + }); + function resetSettings() { + console.log("Resetting settings to default"); + Object.assign(defaultSettings$1, defaultSettings); } - console.log("Starting websocket"); - websocketState.ws_open(); -} -function getTextBoundaries(elem) { - if (elem === null) { - console.error("Element is null"); - return [0, 0]; + async function saveSettings() { + fetch(`${serverUrl}/api/settings/save`, { + method: "POST", + headers: { + "Content-Type": "application/json" + }, + body: JSON.stringify(defaultSettings$1) + }).then((res) => { + if (res.status === 200) { + console.log("Settings saved successfully"); + } else { + throw new Error("Failed to save settings"); + } + }); } - if (elem.tagName === "INPUT" && elem.type === "text" || elem.tagName === "TEXTAREA") { - return [ - elem.selectionStart === null ? 0 : elem.selectionStart, - elem.selectionEnd === null ? 0 : elem.selectionEnd - ]; + const defaultSettings$1 = reactive(deepcopiedSettings); + return { + data, + scheduler_options: samplers, + controlnet_options, + defaultSettings: defaultSettings$1, + resetSettings, + saveSettings + }; +}); +const ImageUpload_vue_vue_type_style_index_0_scoped_9ed1514f_lang = ""; +const _export_sfc = (sfc, props) => { + const target = sfc.__vccOpts || sfc; + for (const [key, val] of props) { + target[key] = val; } - console.error("Element is not input"); - return [0, 0]; -} -function promptHandleKeyUp(e, data, key, globalState) { - var _a2, _b, _c, _d, _e; - if (e.key === "ArrowUp" && e.ctrlKey) { - const values = getTextBoundaries( - document.activeElement - ); - const boundaryIndexStart = values[0]; - const boundaryIndexEnd = values[1]; - e.preventDefault(); - const elem = document.activeElement; - const current_selection = elem.value.substring( - boundaryIndexStart, - boundaryIndexEnd + return target; +}; +const _sfc_main$8 = /* @__PURE__ */ defineComponent({ + __name: "InitHandler", + setup(__props) { + console.log( + ` + ██╗ ██╗ █████╗ ██╗ ████████╗ █████╗ ███╗ ███╗██╗ + ██║ ██║██╔══██╗██║ ╚══██╔══╝██╔══██╗████╗ ████║██║ + ╚██╗ ██╔╝██║ ██║██║ ██║ ███████║██╔████╔██║██║ + ╚████╔╝ ██║ ██║██║ ██║ ██╔══██║██║╚██╔╝██║██║ + ╚██╔╝ ╚█████╔╝███████╗ ██║ ██║ ██║██║ ╚═╝ ██║███████╗ + ╚═╝ ╚════╝ ╚══════╝ ╚═╝ ╚═╝ ╚═╝╚═╝ ╚═╝╚══════╝ + ` ); - const regex = /\(([^:]+([:]?[\s]?)([\d.\d]+))\)/; - const matches = regex.exec(current_selection); - if (matches) { - if (matches) { - const value = parseFloat(matches[3]); - const new_value = (value + 0.1).toFixed(1); - const beforeString = elem.value.substring(0, boundaryIndexStart); - const afterString = elem.value.substring(boundaryIndexEnd); - const newString = `${beforeString}${current_selection.replace( - matches[3], - new_value - )}${afterString}`; - elem.value = newString; - 
data[key] = newString; - elem.setSelectionRange(boundaryIndexStart, boundaryIndexEnd); - } - } else if (boundaryIndexStart !== boundaryIndexEnd) { - const new_inner_string = `(${current_selection}:1.1)`; - const beforeString = elem.value.substring(0, boundaryIndexStart); - const afterString = elem.value.substring(boundaryIndexEnd); - elem.value = `${beforeString}${new_inner_string}${afterString}`; - data[key] = `${beforeString}${new_inner_string}${afterString}`; - elem.setSelectionRange(boundaryIndexStart, boundaryIndexEnd + 6); - } else { - console.log("No selection, cannot parse for weighting"); - } + const global2 = useState2(); + global2.fetchCapabilites().then(() => { + console.log("Capabilities successfully fetched from the server"); + }); + global2.fetchAutofill(); + return (_ctx, _cache) => { + return null; + }; } - if (e.key === "ArrowDown" && e.ctrlKey) { - const values = getTextBoundaries( - document.activeElement - ); - const boundaryIndexStart = values[0]; - const boundaryIndexEnd = values[1]; - e.preventDefault(); - const elem = document.activeElement; - const current_selection = elem.value.substring( - boundaryIndexStart, - boundaryIndexEnd - ); - const regex = /\(([^:]+([:]?[\s]?)([\d.\d]+))\)/; - const matches = regex.exec(current_selection); - if (matches) { - if (matches) { - const value = parseFloat(matches[3]); - const new_value = Math.max(value - 0.1, 0).toFixed(1); - const beforeString = elem.value.substring(0, boundaryIndexStart); - const afterString = elem.value.substring(boundaryIndexEnd); - const newString = `${beforeString}${current_selection.replace( - matches[3], - new_value - )}${afterString}`; - elem.value = newString; - data[key] = newString; - elem.setSelectionRange(boundaryIndexStart, boundaryIndexEnd); - } - } else if (boundaryIndexStart !== boundaryIndexEnd) { - const new_inner_string = `(${current_selection}:0.9)`; - const beforeString = elem.value.substring(0, boundaryIndexStart); - const afterString = elem.value.substring(boundaryIndexEnd); - elem.value = `${beforeString}${new_inner_string}${afterString}`; - data[key] = `${beforeString}${new_inner_string}${afterString}`; - elem.setSelectionRange(boundaryIndexStart, boundaryIndexEnd + 6); - } else { - console.log("No selection, cannot parse for weighting"); - } +}); +const _sfc_main$7 = /* @__PURE__ */ defineComponent({ + __name: "LogDrawer", + setup(__props) { + const glob = useState2(); + const log = computed(() => glob.state.log_drawer.logs.join("\n")); + return (_ctx, _cache) => { + return openBlock(), createBlock(unref(NDrawer), { + placement: "bottom", + show: unref(glob).state.log_drawer.enabled, + "onUpdate:show": _cache[0] || (_cache[0] = ($event) => unref(glob).state.log_drawer.enabled = $event), + "auto-focus": false, + "show-mask": true, + height: "70vh" + }, { + default: withCtx(() => [ + createVNode(unref(NDrawerContent), { + closable: "", + title: "Log - 500 latest messages" + }, { + default: withCtx(() => [ + createVNode(unref(NLog), { + ref: "logRef", + log: log.value, + trim: "", + style: { "height": "100%" } + }, null, 8, ["log"]) + ]), + _: 1 + }) + ]), + _: 1 + }, 8, ["show"]); + }; } - const input = e.target; - if (input) { - const text = input.value; - const currentTokenStripped = (_a2 = text.split(",").pop()) == null ? 
void 0 : _a2.trim(); - closeAllLists(void 0, input); - if (!currentTokenStripped) { - return false; - } - const toAppend = []; - for (let i = 0; i < globalState.state.autofill_special.length; i++) { - if (globalState.state.autofill_special[i].toLowerCase().includes(currentTokenStripped.toLowerCase())) { - const b = document.createElement("DIV"); - b.innerText = globalState.state.autofill_special[i]; - b.innerHTML += ""; - b.addEventListener("click", function() { - input.value = text.substring(0, text.lastIndexOf(",") + 1) + globalState.state.autofill_special[i]; - data[key] = input.value; - closeAllLists(void 0, input); - }); - toAppend.push(b); - } +}); +const _hoisted_1$4 = { style: { "width": "100%", "display": "inline-flex", "align-items": "center" } }; +const _hoisted_2$3 = /* @__PURE__ */ createBaseVNode("p", { style: { "width": "108px" } }, "Utilization", -1); +const _hoisted_3$2 = { style: { "width": "100%", "display": "inline-flex", "align-items": "center" } }; +const _hoisted_4$2 = /* @__PURE__ */ createBaseVNode("p", { style: { "width": "108px" } }, "Memory", -1); +const _hoisted_5$2 = { style: { "align-self": "flex-end", "margin-left": "12px" } }; +const _sfc_main$6 = /* @__PURE__ */ defineComponent({ + __name: "PerformanceDrawer", + setup(__props) { + const global2 = useState2(); + const glob = useState2(); + return (_ctx, _cache) => { + return openBlock(), createBlock(unref(NDrawer), { + placement: "bottom", + show: unref(glob).state.perf_drawer.enabled, + "onUpdate:show": _cache[0] || (_cache[0] = ($event) => unref(glob).state.perf_drawer.enabled = $event), + "auto-focus": false, + "show-mask": true, + height: "70vh" + }, { + default: withCtx(() => [ + createVNode(unref(NDrawerContent), { + closable: "", + title: "Performance statistics" + }, { + default: withCtx(() => [ + (openBlock(true), createElementBlock(Fragment, null, renderList(unref(global2).state.perf_drawer.gpus, (gpu) => { + return openBlock(), createBlock(unref(NCard), { + key: gpu.uuid, + style: { "margin-bottom": "12px" } + }, { + default: withCtx(() => [ + createVNode(unref(NSpace), { + inline: "", + justify: "space-between", + style: { "width": "100%" } + }, { + default: withCtx(() => [ + createBaseVNode("h3", null, "[" + toDisplayString(gpu.index) + "] " + toDisplayString(gpu.name), 1), + createBaseVNode("h4", null, toDisplayString(gpu.power_draw) + " / " + toDisplayString(gpu.power_limit) + "W ─ " + toDisplayString(gpu.temperature) + "°C ", 1) + ]), + _: 2 + }, 1024), + createBaseVNode("div", _hoisted_1$4, [ + _hoisted_2$3, + createVNode(unref(NProgress), { + percentage: gpu.utilization, + type: "line", + "indicator-placement": "inside", + style: { "flex-grow": "1", "width": "400px" } + }, null, 8, ["percentage"]) + ]), + createBaseVNode("div", _hoisted_3$2, [ + _hoisted_4$2, + createVNode(unref(NProgress), { + percentage: gpu.memory_usage, + type: "line", + style: { "flex-grow": "1", "width": "400px" }, + color: "#63e2b7", + "indicator-placement": "inside" + }, null, 8, ["percentage"]), + createBaseVNode("p", _hoisted_5$2, toDisplayString(gpu.memory_used) + " / " + toDisplayString(gpu.memory_total) + " MB ", 1) + ]) + ]), + _: 2 + }, 1024); + }), 128)) + ]), + _: 1 + }) + ]), + _: 1 + }, 8, ["show"]); + }; + } +}); +const _hoisted_1$3 = /* @__PURE__ */ createBaseVNode("a", { + target: "_blank", + href: "https://huggingface.co/settings/tokens" +}, "this page", -1); +const _hoisted_2$2 = { style: { "margin-top": "8px", "width": "100%", "display": "flex", "justify-content": "end" } }; +const _sfc_main$5 = /* 
@__PURE__ */ defineComponent({ + __name: "SecretsHandler", + setup(__props) { + const message = useMessage(); + const global2 = useState2(); + const hf_loading = ref(false); + const hf_token = ref(""); + function noSideSpace(value) { + return !/ /g.test(value); } - const lowercaseStrippedToken = currentTokenStripped.toLowerCase(); - if (lowercaseStrippedToken.length >= 3) { - for (let i = 0; i < globalState.state.autofill.length; i++) { - if (globalState.state.autofill[i].toLowerCase().includes(lowercaseStrippedToken)) { - if (toAppend.length >= 30) { - break; - } - const b = document.createElement("DIV"); - b.innerText = globalState.state.autofill[i]; - b.innerHTML += ""; - b.addEventListener("click", function() { - input.value = text.substring(0, text.lastIndexOf(",") + 1) + globalState.state.autofill[i]; - data[key] = input.value; - closeAllLists(void 0, input); - }); - toAppend.push(b); + function setHuggingfaceToken() { + hf_loading.value = true; + const url = new URL(`${serverUrl}/api/settings/inject-var-into-dotenv`); + url.searchParams.append("key", "HUGGINGFACE_TOKEN"); + url.searchParams.append("value", hf_token.value); + fetch(url, { method: "POST" }).then((res) => { + if (res.status !== 200) { + message.create("Failed to set HuggingFace token", { type: "error" }); + return; } - } - } - if (toAppend.length === 0) { - return false; - } - const div = document.createElement("DIV"); - div.setAttribute("id", "autocomplete-list"); - div.setAttribute("class", "autocomplete-items"); - (_e = (_d = (_c = (_b = input.parentNode) == null ? void 0 : _b.parentNode) == null ? void 0 : _c.parentNode) == null ? void 0 : _d.parentNode) == null ? void 0 : _e.appendChild(div); - for (let i = 0; i < toAppend.length; i++) { - div.appendChild(toAppend[i]); - } - onClickOutside(div, () => { - closeAllLists(void 0, input); - }); - const autocompleteList = document.getElementById("autocomplete-list"); - const x = autocompleteList == null ? void 0 : autocompleteList.getElementsByTagName("div"); - if (e.key === "ArrowDown") { - currentFocus++; - addActive(x); - e.preventDefault(); - } else if (e.key === "ArrowUp") { - currentFocus--; - addActive(x); - e.preventDefault(); - } else if (e.key === "Enter" || e.key === "Tab") { - e.stopImmediatePropagation(); - e.preventDefault(); - if (currentFocus > -1) { - if (x) - x[currentFocus].click(); - } - } else if (e.key === "Escape") { - closeAllLists(void 0, input); - } - } -} -function promptHandleKeyDown(e) { - if (arrowKeys.includes(e.keyCode) && e.ctrlKey) { - e.preventDefault(); - } - if (document.getElementById("autocomplete-list")) { - if (e.key === "Enter" || e.key === "Tab" || e.key === "ArrowDown" || e.key === "ArrowUp") { - e.preventDefault(); + global2.state.secrets.huggingface = "ok"; + message.create("HuggingFace token set successfully", { type: "success" }); + }).catch((e) => { + message.create(`Failed to set HuggingFace token: ${e.message}`, { + type: "error" + }); + }); + hf_loading.value = false; } + return (_ctx, _cache) => { + return openBlock(), createBlock(unref(NModal), { + show: unref(global2).state.secrets.huggingface !== "ok", + preset: "card", + title: "Missing HuggingFace Token", + style: { "width": "80vw" }, + closable: false + }, { + default: withCtx(() => [ + createVNode(unref(NText), null, { + default: withCtx(() => [ + createTextVNode(" API does not have a HuggingFace token. Please enter a valid token to continue. 
You can get a token from "), + _hoisted_1$3 + ]), + _: 1 + }), + createVNode(unref(NInput), { + type: "password", + placeholder: "hf_123...", + style: { "margin-top": "8px" }, + "allow-input": noSideSpace, + value: hf_token.value, + "onUpdate:value": _cache[0] || (_cache[0] = ($event) => hf_token.value = $event) + }, null, 8, ["value"]), + createBaseVNode("div", _hoisted_2$2, [ + createVNode(unref(NButton), { + ghost: "", + type: "primary", + loading: hf_loading.value, + onClick: setHuggingfaceToken + }, { + default: withCtx(() => [ + createTextVNode("Set Token") + ]), + _: 1 + }, 8, ["loading"]) + ]) + ]), + _: 1 + }, 8, ["show"]); + }; } -} -function urlFromPath(path) { - const url = new URL(path, serverUrl); - return url.href; -} -const _withScopeId = (n) => (pushScopeId("data-v-29f01b28"), n = n(), popScopeId(), n); -const _hoisted_1$1 = { class: "top-bar" }; -const _hoisted_2 = { key: 0 }; -const _hoisted_3 = { key: 1 }; -const _hoisted_4 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("img", { +}); +const _withScopeId = (n) => (pushScopeId("data-v-3a99505a"), n = n(), popScopeId(), n); +const _hoisted_1$2 = { class: "top-bar" }; +const _hoisted_2$1 = { key: 0 }; +const _hoisted_3$1 = { key: 0 }; +const _hoisted_4$1 = { key: 1 }; +const _hoisted_5$1 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("img", { src: "https://i.imgflip.com/84840n.jpg", style: { "max-width": "30vw", "max-height": "30vh" } }, null, -1)); -const _hoisted_5 = { key: 2 }; -const _hoisted_6 = { style: { "display": "inline-flex", "width": "100%", "margin-bottom": "12px" } }; -const _hoisted_7 = { style: { "display": "inline-flex" } }; -const _hoisted_8 = { key: 0 }; +const _hoisted_6$1 = { key: 2 }; +const _hoisted_7$1 = { style: { "display": "inline-flex", "width": "100%", "margin-bottom": "12px" } }; +const _hoisted_8 = { style: { "display": "flex", "flex-direction": "row", "align-items": "center" } }; const _hoisted_9 = { style: { "display": "inline-flex" } }; -const _hoisted_10 = { key: 1 }; -const _hoisted_11 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("b", null, "Ignore the tokens on CivitAI", -1)); -const _hoisted_12 = { key: 0 }; -const _hoisted_13 = { style: { "display": "inline-flex" } }; -const _hoisted_14 = { key: 1 }; -const _hoisted_15 = { class: "progress-container" }; -const _hoisted_16 = { style: { "display": "inline-flex", "align-items": "center" } }; -const _sfc_main$3 = /* @__PURE__ */ defineComponent({ +const _hoisted_10 = { key: 0 }; +const _hoisted_11 = { style: { "display": "inline-flex" } }; +const _hoisted_12 = { key: 1 }; +const _hoisted_13 = /* @__PURE__ */ _withScopeId(() => /* @__PURE__ */ createBaseVNode("b", null, "Ignore the tokens on CivitAI", -1)); +const _hoisted_14 = { key: 0 }; +const _hoisted_15 = { style: { "display": "inline-flex" } }; +const _hoisted_16 = { key: 1 }; +const _hoisted_17 = { class: "progress-container" }; +const _hoisted_18 = { style: { "display": "inline-flex", "align-items": "center" } }; +const _sfc_main$4 = /* @__PURE__ */ defineComponent({ __name: "TopBar", setup(__props) { + useCssVars((_ctx) => ({ + "ca9a9586": topBarWidth.value + })); const router2 = useRouter(); const websocketState = useWebsocket(); const global2 = useState2(); @@ -42130,35 +42470,74 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ } }); }); + const manualVAEModels = computed(() => { + const selectedModel = global2.state.selected_model; + if ((selectedModel == null ? 
void 0 : selectedModel.type) === "SDXL") { + return [ + { + name: "Default VAE (fp32)", + path: "default", + backend: "VAE", + valid: true, + state: "not loaded", + vae: "default", + textual_inversions: [], + type: "SDXL", + stage: "last_stage" + }, + { + name: "FP16 VAE", + path: "madebyollin/sdxl-vae-fp16-fix", + backend: "VAE", + valid: true, + state: "not loaded", + vae: "fp16", + textual_inversions: [], + type: "SDXL", + stage: "last_stage" + } + ]; + } else { + return [ + { + name: "Default VAE", + path: "default", + backend: "VAE", + valid: true, + state: "not loaded", + vae: "default", + textual_inversions: [], + type: "SD1.x", + stage: "last_stage" + }, + { + name: "Tiny VAE (fast)", + path: "madebyollin/taesd", + backend: "VAE", + valid: true, + state: "not loaded", + vae: "madebyollin/taesd", + textual_inversions: [], + type: "SD1.x", + stage: "last_stage" + }, + { + name: "Asymmetric VAE", + path: "cross-attention/asymmetric-autoencoder-kl-x-1-5", + backend: "VAE", + valid: true, + state: "not loaded", + vae: "cross-attention/asymmetric-autoencoder-kl-x-1-5", + textual_inversions: [], + type: "SD1.x", + stage: "last_stage" + } + ]; + } + }); const vaeModels = computed(() => { return [ - { - name: "Default VAE", - path: "default", - backend: "VAE", - valid: true, - state: "not loaded", - vae: "default", - textual_inversions: [] - }, - { - name: "Tiny VAE (fast)", - path: "madebyollin/taesd", - backend: "VAE", - valid: true, - state: "not loaded", - vae: "madebyollin/taesd", - textual_inversions: [] - }, - { - name: "Asymmetric VAE", - path: "cross-attention/asymmetric-autoencoder-kl-x-1-5", - backend: "VAE", - valid: true, - state: "not loaded", - vae: "cross-attention/asymmetric-autoencoder-kl-x-1-5", - textual_inversions: [] - }, + ...manualVAEModels.value, ...filteredModels.value.filter((model) => { return model.backend === "VAE"; }).sort((a, b) => { @@ -42273,7 +42652,11 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ model.state = "loading"; modelsLoading.value = true; const load_url = new URL(`${serverUrl}/api/models/load`); - const params = { model: model.path, backend: model.backend }; + const params = { + model: model.path, + backend: model.backend, + type: model.type + }; load_url.search = new URLSearchParams(params).toString(); fetch(load_url, { method: "POST" @@ -42314,7 +42697,7 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ "Content-Type": "application/json" }, body: JSON.stringify({ - model: global2.state.selected_model.name, + model: global2.state.selected_model.path, vae: vae.path }) }); @@ -42335,7 +42718,7 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ "Content-Type": "application/json" }, body: JSON.stringify({ - model: global2.state.selected_model.name, + model: global2.state.selected_model.path, textual_inversion: textualInversion.path }) }); @@ -42389,6 +42772,21 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ global2.state.models.splice(0, global2.state.models.length); console.log("Reset models"); } + function getModelTag(type) { + switch (type) { + case "SD1.x": + return [type, "primary"]; + case "SD2.x": + return [type, "info"]; + case "SDXL": + return [type, "warning"]; + case "Kandinsky 2.1": + case "Kandinsky 2.2": + return ["Kandinsky", "success"]; + default: + return [type, "error"]; + } + } websocketState.onConnectedCallbacks.push(() => { refreshModels(); }); @@ -42544,12 +42942,31 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ break; } } + const topBarWidth = computed(() => { + return 
isLargeScreen.value ? "calc(100% - 64px)" : "100%"; + }); startWebsocket(message); return (_ctx, _cache) => { var _a2; - return openBlock(), createElementBlock("div", _hoisted_1$1, [ + return openBlock(), createElementBlock("div", _hoisted_1$2, [ + !unref(isLargeScreen) ? (openBlock(), createBlock(unref(NButton), { + key: 0, + bordered: false, + style: { "margin": "0 2px", "padding": "8px 8px" }, + onClick: _cache[0] || (_cache[0] = ($event) => unref(global2).state.collapsibleBarActive = true) + }, { + default: withCtx(() => [ + createVNode(unref(NIcon), { size: "24" }, { + default: withCtx(() => [ + createVNode(unref(Menu)) + ]), + _: 1 + }) + ]), + _: 1 + })) : createCommentVNode("", true), createVNode(unref(NSelect), { - style: { "max-width": "250px", "padding-left": "12px", "padding-right": "12px" }, + style: { "max-width": "250px", "padding-right": "4px" }, options: generatedModelOptions.value, "onUpdate:value": onModelChange, loading: modelsLoading.value, @@ -42559,18 +42976,26 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ filterable: "" }, null, 8, ["options", "loading", "value"]), createVNode(unref(NButton), { - onClick: _cache[0] || (_cache[0] = ($event) => showModal.value = true), + onClick: _cache[1] || (_cache[1] = ($event) => showModal.value = true), loading: modelsLoading.value, type: unref(settings).data.settings.model ? "default" : "success" }, { default: withCtx(() => [ - createTextVNode(" Load Model") + unref(isLargeScreen) ? (openBlock(), createElementBlock("p", _hoisted_2$1, "Load Model")) : (openBlock(), createBlock(unref(NIcon), { + key: 1, + size: "18" + }, { + default: withCtx(() => [ + createVNode(unref(Add)) + ]), + _: 1 + })) ]), _: 1 }, 8, ["loading", "type"]), createVNode(unref(NModal), { show: showModal.value, - "onUpdate:show": _cache[4] || (_cache[4] = ($event) => showModal.value = $event), + "onUpdate:show": _cache[5] || (_cache[5] = ($event) => showModal.value = $event), closable: "", "mask-closable": "", preset: "card", @@ -42579,7 +43004,7 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ "auto-focus": false }, { default: withCtx(() => [ - unref(websocketState).readyState === "CLOSED" ? (openBlock(), createElementBlock("div", _hoisted_2, [ + unref(websocketState).readyState === "CLOSED" ? (openBlock(), createElementBlock("div", _hoisted_3$1, [ createVNode(unref(NResult), { title: "You are not connected to the server", description: "Click the button below to reconnect", @@ -42589,7 +43014,7 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ footer: withCtx(() => [ createVNode(unref(NButton), { type: "success", - onClick: _cache[1] || (_cache[1] = ($event) => unref(startWebsocket)(unref(message))) + onClick: _cache[2] || (_cache[2] = ($event) => unref(startWebsocket)(unref(message))) }, { default: withCtx(() => [ createTextVNode("Reconnect") @@ -42599,7 +43024,7 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ ]), _: 1 }) - ])) : unref(global2).state.models.length === 0 ? (openBlock(), createElementBlock("div", _hoisted_3, [ + ])) : unref(global2).state.models.length === 0 ? 
(openBlock(), createElementBlock("div", _hoisted_4$1, [ createVNode(unref(NResult), { title: "No models found", style: { "height": "70vh", "display": "flex", "align-items": "center", "justify-content": "center", "flex-direction": "column" }, @@ -42610,7 +43035,7 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ trigger: withCtx(() => [ createVNode(unref(NButton), { type: "success", - onClick: _cache[2] || (_cache[2] = () => { + onClick: _cache[3] || (_cache[3] = () => { unref(global2).state.modelManager.tab = "civitai"; unref(router2).push("/models"); showModal.value = false; @@ -42623,18 +43048,18 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ }) ]), default: withCtx(() => [ - _hoisted_4 + _hoisted_5$1 ]), _: 1 }) ]), _: 1 }) - ])) : (openBlock(), createElementBlock("div", _hoisted_5, [ - createBaseVNode("div", _hoisted_6, [ + ])) : (openBlock(), createElementBlock("div", _hoisted_6$1, [ + createBaseVNode("div", _hoisted_7$1, [ createVNode(unref(NInput), { value: filter.value, - "onUpdate:value": _cache[3] || (_cache[3] = ($event) => filter.value = $event), + "onUpdate:value": _cache[4] || (_cache[4] = ($event) => filter.value = $event), clearable: "", placeholder: "Filter Models" }, null, 8, ["value"]), @@ -42681,8 +43106,20 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ style: { "display": "inline-flex", "width": "100%", "align-items": "center", "justify-content": "space-between", "border-bottom": "1px solid rgb(66, 66, 71)" }, key: model.path }, [ - createBaseVNode("p", null, toDisplayString(model.name), 1), - createBaseVNode("div", _hoisted_7, [ + createBaseVNode("div", _hoisted_8, [ + createVNode(unref(NTag), { + type: getModelTag(model.type)[1], + ghost: "", + style: { "margin-right": "8px" } + }, { + default: withCtx(() => [ + createTextVNode(toDisplayString(getModelTag(model.type)[0]), 1) + ]), + _: 2 + }, 1032, ["type"]), + createBaseVNode("p", null, toDisplayString(model.name), 1) + ]), + createBaseVNode("div", _hoisted_9, [ model.state === "loaded" ? (openBlock(), createBlock(unref(NButton), { key: 0, type: "error", @@ -42730,7 +43167,7 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ default: withCtx(() => [ createVNode(unref(NCard), { title: vae_title.value }, { default: withCtx(() => [ - unref(global2).state.selected_model !== null ? (openBlock(), createElementBlock("div", _hoisted_8, [ + unref(global2).state.selected_model !== null ? (openBlock(), createElementBlock("div", _hoisted_10, [ (openBlock(true), createElementBlock(Fragment, null, renderList(vaeModels.value, (vae) => { var _a3; return openBlock(), createElementBlock("div", { @@ -42738,7 +43175,7 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ key: vae.path }, [ createBaseVNode("p", null, toDisplayString(vae.name), 1), - createBaseVNode("div", _hoisted_9, [ + createBaseVNode("div", _hoisted_11, [ ((_a3 = unref(global2).state.selected_model) == null ? void 0 : _a3.vae) == vae.path ? (openBlock(), createBlock(unref(NButton), { key: 0, type: "error", @@ -42765,7 +43202,7 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ ]) ]); }), 128)) - ])) : (openBlock(), createElementBlock("div", _hoisted_10, [ + ])) : (openBlock(), createElementBlock("div", _hoisted_12, [ createVNode(unref(NAlert), { type: "warning", "show-icon": "", @@ -42794,12 +43231,12 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ title: "Usage of textual inversion" }, { default: withCtx(() => [ - _hoisted_11, + _hoisted_13, createTextVNode(". 
The name of the inversion that is displayed here will be the actual token (easynegative.pt -> easynegative) ") ]), _: 1 }), - unref(global2).state.selected_model !== null ? (openBlock(), createElementBlock("div", _hoisted_12, [ + unref(global2).state.selected_model !== null ? (openBlock(), createElementBlock("div", _hoisted_14, [ (openBlock(true), createElementBlock(Fragment, null, renderList(textualInversionModels.value, (textualInversion) => { var _a3; return openBlock(), createElementBlock("div", { @@ -42807,7 +43244,7 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ key: textualInversion.path }, [ createBaseVNode("p", null, toDisplayString(textualInversion.name), 1), - createBaseVNode("div", _hoisted_13, [ + createBaseVNode("div", _hoisted_15, [ ((_a3 = unref(global2).state.selected_model) == null ? void 0 : _a3.textual_inversions.includes( textualInversion.path )) ? (openBlock(), createBlock(unref(NButton), { @@ -42836,7 +43273,7 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ ]) ]); }), 128)) - ])) : (openBlock(), createElementBlock("div", _hoisted_14, [ + ])) : (openBlock(), createElementBlock("div", _hoisted_16, [ createVNode(unref(NAlert), { type: "warning", "show-icon": "", @@ -42971,7 +43408,7 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ ]), _: 1 }, 8, ["show"]), - createBaseVNode("div", _hoisted_15, [ + createBaseVNode("div", _hoisted_17, [ createVNode(unref(NProgress), { type: "line", percentage: unref(global2).state.progress, @@ -42991,7 +43428,7 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ _: 1 }, 8, ["percentage", "processing"]) ]), - createBaseVNode("div", _hoisted_16, [ + createBaseVNode("div", _hoisted_18, [ createVNode(unref(NDropdown), { options: dropdownOptions, onSelect: dropdownSelected @@ -43003,7 +43440,7 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ "icon-placement": "left", "render-icon": renderIcon(unref(Wifi)), loading: unref(websocketState).loading, - onClick: _cache[5] || (_cache[5] = ($event) => unref(startWebsocket)(unref(message))) + onClick: _cache[6] || (_cache[6] = ($event) => unref(startWebsocket)(unref(message))) }, null, 8, ["type", "render-icon", "loading"]) ]), _: 1 @@ -43013,16 +43450,125 @@ const _sfc_main$3 = /* @__PURE__ */ defineComponent({ }; } }); -const TopBar_vue_vue_type_style_index_0_scoped_29f01b28_lang = ""; -const TopBar = /* @__PURE__ */ _export_sfc(_sfc_main$3, [["__scopeId", "data-v-29f01b28"]]); +const TopBar_vue_vue_type_style_index_0_scoped_3a99505a_lang = ""; +const TopBar = /* @__PURE__ */ _export_sfc(_sfc_main$4, [["__scopeId", "data-v-3a99505a"]]); const Prompt_vue_vue_type_style_index_0_lang = ""; const Prompt_vue_vue_type_style_index_1_scoped_780680bc_lang = ""; -const Upscale_vue_vue_type_style_index_0_scoped_5358ed01_lang = ""; -const ControlNet_vue_vue_type_style_index_0_scoped_efacc8fd_lang = ""; -const Img2Img_vue_vue_type_style_index_0_scoped_9c556ef8_lang = ""; -const Inpainting_vue_vue_type_style_index_0_scoped_7963dde9_lang = ""; -const CivitAIDownload_vue_vue_type_style_index_0_scoped_e10a07d2_lang = ""; +const ControlNet_vue_vue_type_style_index_0_scoped_d4ff54ab_lang = ""; +const Img2Img_vue_vue_type_style_index_0_scoped_a4145f6c_lang = ""; +const Inpainting_vue_vue_type_style_index_0_scoped_23b19530_lang = ""; +const CivitAIDownload_vue_vue_type_style_index_0_scoped_89afc237_lang = ""; const HuggingfaceDownload_vue_vue_type_style_index_0_scoped_b405f046_lang = ""; +const _hoisted_1$1 = { style: { "margin": "16px 0" } }; +const _hoisted_2 = /* @__PURE__ */ 
createBaseVNode("b", null, "Key in question:", -1); +const _hoisted_3 = /* @__PURE__ */ createBaseVNode("br", null, null, -1); +const _hoisted_4 = /* @__PURE__ */ createBaseVNode("b", null, "Current value", -1); +const _hoisted_5 = /* @__PURE__ */ createBaseVNode("br", null, null, -1); +const _hoisted_6 = /* @__PURE__ */ createBaseVNode("b", null, "Default value", -1); +const _hoisted_7 = /* @__PURE__ */ createBaseVNode("br", null, null, -1); +const _sfc_main$3 = /* @__PURE__ */ defineComponent({ + __name: "SettingsDiffResolver", + setup(__props) { + const global2 = useState2(); + const settings = useSettings(); + const message = useMessage(); + function apply2() { + const key = global2.state.settings_diff.key; + const default_value = global2.state.settings_diff.default_value; + const indexable_keys = key.slice(0, key.length - 1); + const last_key = key[key.length - 1]; + let current = settings.defaultSettings; + for (const indexable_key of indexable_keys) { + current = current[indexable_key]; + } + current[last_key] = default_value; + settings.saveSettings().then(() => { + message.success("Settings saved"); + }).catch((e) => { + message.error("Failed to save settings: " + e); + }).finally(() => { + global2.state.settings_diff.active = false; + }); + } + return (_ctx, _cache) => { + return openBlock(), createBlock(unref(NModal), { + show: unref(global2).state.settings_diff.active + }, { + default: withCtx(() => [ + createVNode(unref(NCard), { + title: "Settings Diff Resolver", + style: { "max-width": "90vw" } + }, { + default: withCtx(() => [ + createVNode(unref(NAlert), { + "show-icon": "", + type: "warning" + }, { + default: withCtx(() => [ + createTextVNode("Failed to save config") + ]), + _: 1 + }), + createBaseVNode("div", _hoisted_1$1, [ + _hoisted_2, + createTextVNode(" " + toDisplayString(unref(global2).state.settings_diff.key.join("->")) + " ", 1), + _hoisted_3, + _hoisted_4, + createTextVNode(" " + toDisplayString(unref(global2).state.settings_diff.current_value) + " ", 1), + _hoisted_5, + _hoisted_6, + createTextVNode(" " + toDisplayString(unref(global2).state.settings_diff.default_value) + " ", 1), + _hoisted_7 + ]), + createVNode(unref(NGrid), { + cols: "2", + "x-gap": "8" + }, { + default: withCtx(() => [ + createVNode(unref(NGi), null, { + default: withCtx(() => [ + createVNode(unref(NButton), { + type: "warning", + block: "", + ghost: "", + style: { "width": "100%" }, + onClick: _cache[0] || (_cache[0] = ($event) => unref(global2).state.settings_diff.active = false) + }, { + default: withCtx(() => [ + createTextVNode(" I Will Fix It Myself ") + ]), + _: 1 + }) + ]), + _: 1 + }), + createVNode(unref(NGi), null, { + default: withCtx(() => [ + createVNode(unref(NButton), { + type: "primary", + style: { "width": "100%" }, + onClick: apply2 + }, { + default: withCtx(() => [ + createTextVNode(" Apply Default Value and Save ") + ]), + _: 1 + }) + ]), + _: 1 + }) + ]), + _: 1 + }) + ]), + _: 1 + }) + ]), + _: 1 + }, 8, ["show"]); + }; + } +}); const _sfc_main$2 = {}; function _sfc_render(_ctx, _cache) { const _component_RouterView = resolveComponent("RouterView"); @@ -43036,7 +43582,7 @@ const _sfc_main$1 = /* @__PURE__ */ defineComponent({ return (_ctx, _cache) => { return openBlock(), createBlock(unref(NNotificationProvider), { placement: "bottom-right", - max: 3 + max: 2 }, { default: withCtx(() => [ createVNode(unref(NLoadingBarProvider), null, { @@ -43044,13 +43590,14 @@ const _sfc_main$1 = /* @__PURE__ */ defineComponent({ createVNode(unref(NMessageProvider), null, { default: 
withCtx(() => [ _hoisted_1, - createVNode(unref(_sfc_main$4)), - createVNode(unref(_sfc_main$8)), + createVNode(unref(_sfc_main$5)), + createVNode(unref(_sfc_main$9)), createVNode(unref(TopBar)), + createVNode(unref(_sfc_main$8)), + createVNode(routerContainerVue, { class: "router-container" }), + createVNode(unref(_sfc_main$6)), createVNode(unref(_sfc_main$7)), - createVNode(routerContainerVue, { style: { "margin-top": "52px" } }), - createVNode(unref(_sfc_main$5)), - createVNode(unref(_sfc_main$6)) + createVNode(unref(_sfc_main$3)) ]), _: 1 }) @@ -43063,19 +43610,21 @@ const _sfc_main$1 = /* @__PURE__ */ defineComponent({ }; } }); +const Content_vue_vue_type_style_index_0_lang = ""; const _sfc_main = /* @__PURE__ */ defineComponent({ __name: "App", setup(__props) { useCssVars((_ctx) => { var _a2, _b, _c; return { - "4c7ba08e": theme.value.common.popoverColor, - "01ab46a4": theme.value.common.borderRadius, - "e4e78d9e": theme.value.common.pressedColor, - "d0777f2a": theme.value.common.primaryColorHover, - "98485856": blur.value, - "6a1d04dc": ((_b = (_a2 = overrides.value) == null ? void 0 : _a2.Card) == null ? void 0 : _b.color) ?? ((_c = theme.value.Card.common) == null ? void 0 : _c.cardColor), - "344206c2": backgroundImage.value + "e68ef196": theme.value.common.popoverColor, + "3f674355": theme.value.common.borderRadius, + "646dc050": theme.value.common.pressedColor, + "96bf2bb8": theme.value.common.primaryColorHover, + "b08f9a64": blur.value, + "139458d6": ((_b = (_a2 = overrides.value) == null ? void 0 : _a2.Card) == null ? void 0 : _b.color) ?? ((_c = theme.value.Card.common) == null ? void 0 : _c.cardColor), + "31f48ff4": backgroundImage.value, + "3ac72808": marginLeft.value }; }); const settings = useSettings(); @@ -43121,6 +43670,9 @@ const _sfc_main = /* @__PURE__ */ defineComponent({ document.body.style.backgroundColor = ((_b = (_a2 = overrides.value) == null ? void 0 : _a2.common) == null ? void 0 : _b.baseColor) ?? theme.value.common.baseColor; } ); + const marginLeft = computed(() => { + return isLargeScreen.value ? "64px" : "0px"; + }); return (_ctx, _cache) => { return openBlock(), createBlock(unref(NConfigProvider), { theme: theme.value, @@ -43198,27 +43750,27 @@ const router = createRouter({ { path: "/", name: "home", - component: () => __vitePreload(() => import("./TextToImageView.js"), true ? ["assets/TextToImageView.js","assets/GenerateSection.vue_vue_type_script_setup_true_lang.js","assets/ImageOutput.vue_vue_type_script_setup_true_lang.js","assets/SendOutputTo.vue_vue_type_script_setup_true_lang.js","assets/Switch.js","assets/TrashBin.js","assets/clock.js","assets/DescriptionsItem.js","assets/InputNumber.js","assets/SamplerPicker.vue_vue_type_script_setup_true_lang.js","assets/Settings.js","assets/v4.js"] : void 0) + component: () => __vitePreload(() => import("./TextToImageView.js"), true ? ["assets/TextToImageView.js","assets/GenerateSection.vue_vue_type_script_setup_true_lang.js","assets/ImageOutput.vue_vue_type_script_setup_true_lang.js","assets/SendOutputTo.vue_vue_type_script_setup_true_lang.js","assets/Switch.js","assets/TrashBin.js","assets/clock.js","assets/DescriptionsItem.js","assets/Slider.js","assets/InputNumber.js","assets/Upscale.vue_vue_type_script_setup_true_lang.js","assets/Settings.js","assets/v4.js"] : void 0) }, { path: "/txt2img", name: "txt2img", - component: () => __vitePreload(() => import("./TextToImageView.js"), true ? 
["assets/TextToImageView.js","assets/GenerateSection.vue_vue_type_script_setup_true_lang.js","assets/ImageOutput.vue_vue_type_script_setup_true_lang.js","assets/SendOutputTo.vue_vue_type_script_setup_true_lang.js","assets/Switch.js","assets/TrashBin.js","assets/clock.js","assets/DescriptionsItem.js","assets/InputNumber.js","assets/SamplerPicker.vue_vue_type_script_setup_true_lang.js","assets/Settings.js","assets/v4.js"] : void 0) + component: () => __vitePreload(() => import("./TextToImageView.js"), true ? ["assets/TextToImageView.js","assets/GenerateSection.vue_vue_type_script_setup_true_lang.js","assets/ImageOutput.vue_vue_type_script_setup_true_lang.js","assets/SendOutputTo.vue_vue_type_script_setup_true_lang.js","assets/Switch.js","assets/TrashBin.js","assets/clock.js","assets/DescriptionsItem.js","assets/Slider.js","assets/InputNumber.js","assets/Upscale.vue_vue_type_script_setup_true_lang.js","assets/Settings.js","assets/v4.js"] : void 0) }, { path: "/img2img", name: "img2img", - component: () => __vitePreload(() => import("./Image2ImageView.js"), true ? ["assets/Image2ImageView.js","assets/clock.js","assets/DescriptionsItem.js","assets/Switch.js","assets/InputNumber.js","assets/SamplerPicker.vue_vue_type_script_setup_true_lang.js","assets/Settings.js","assets/GenerateSection.vue_vue_type_script_setup_true_lang.js","assets/ImageOutput.vue_vue_type_script_setup_true_lang.js","assets/SendOutputTo.vue_vue_type_script_setup_true_lang.js","assets/TrashBin.js","assets/ImageUpload.js","assets/CloudUpload.js","assets/v4.js"] : void 0) + component: () => __vitePreload(() => import("./Image2ImageView.js"), true ? ["assets/Image2ImageView.js","assets/clock.js","assets/DescriptionsItem.js","assets/Slider.js","assets/InputNumber.js","assets/Upscale.vue_vue_type_script_setup_true_lang.js","assets/Switch.js","assets/Settings.js","assets/GenerateSection.vue_vue_type_script_setup_true_lang.js","assets/ImageOutput.vue_vue_type_script_setup_true_lang.js","assets/SendOutputTo.vue_vue_type_script_setup_true_lang.js","assets/TrashBin.js","assets/ImageUpload.js","assets/CloudUpload.js","assets/v4.js"] : void 0) }, { path: "/imageProcessing", name: "imageProcessing", - component: () => __vitePreload(() => import("./ImageProcessingView.js"), true ? ["assets/ImageProcessingView.js","assets/GenerateSection.vue_vue_type_script_setup_true_lang.js","assets/ImageOutput.vue_vue_type_script_setup_true_lang.js","assets/SendOutputTo.vue_vue_type_script_setup_true_lang.js","assets/Switch.js","assets/TrashBin.js","assets/ImageUpload.js","assets/CloudUpload.js","assets/InputNumber.js"] : void 0) + component: () => __vitePreload(() => import("./ImageProcessingView.js"), true ? ["assets/ImageProcessingView.js","assets/GenerateSection.vue_vue_type_script_setup_true_lang.js","assets/ImageOutput.vue_vue_type_script_setup_true_lang.js","assets/SendOutputTo.vue_vue_type_script_setup_true_lang.js","assets/Switch.js","assets/TrashBin.js","assets/ImageUpload.js","assets/CloudUpload.js","assets/Slider.js","assets/InputNumber.js"] : void 0) }, { path: "/models", name: "models", - component: () => __vitePreload(() => import("./ModelsView.js"), true ? ["assets/ModelsView.js","assets/ModelPopup.vue_vue_type_script_setup_true_lang.js","assets/DescriptionsItem.js","assets/GridOutline.js","assets/Switch.js","assets/Settings.js","assets/TrashBin.js","assets/CloudUpload.js"] : void 0) + component: () => __vitePreload(() => import("./ModelsView.js"), true ? 
["assets/ModelsView.js","assets/ModelPopup.vue_vue_type_script_setup_true_lang.js","assets/DescriptionsItem.js","assets/Settings.js","assets/Switch.js","assets/TrashBin.js","assets/CloudUpload.js"] : void 0) }, { path: "/about", @@ -43228,7 +43780,7 @@ const router = createRouter({ { path: "/accelerate", name: "accelerate", - component: () => __vitePreload(() => import("./AccelerateView.js"), true ? ["assets/AccelerateView.js","assets/Switch.js","assets/InputNumber.js"] : void 0) + component: () => __vitePreload(() => import("./AccelerateView.js"), true ? ["assets/AccelerateView.js","assets/Slider.js","assets/InputNumber.js","assets/Switch.js"] : void 0) }, { path: "/extra", @@ -43243,17 +43795,17 @@ const router = createRouter({ { path: "/settings", name: "settings", - component: () => __vitePreload(() => import("./SettingsView.js"), true ? ["assets/SettingsView.js","assets/SamplerPicker.vue_vue_type_script_setup_true_lang.js","assets/Settings.js","assets/InputNumber.js","assets/Switch.js"] : void 0) + component: () => __vitePreload(() => import("./SettingsView.js"), true ? ["assets/SettingsView.js","assets/Upscale.vue_vue_type_script_setup_true_lang.js","assets/Switch.js","assets/InputNumber.js","assets/Slider.js","assets/Settings.js"] : void 0) }, { path: "/imageBrowser", name: "imageBrowser", - component: () => __vitePreload(() => import("./ImageBrowserView.js"), true ? ["assets/ImageBrowserView.js","assets/SendOutputTo.vue_vue_type_script_setup_true_lang.js","assets/Switch.js","assets/GridOutline.js","assets/TrashBin.js","assets/DescriptionsItem.js","assets/ImageBrowserView.css"] : void 0) + component: () => __vitePreload(() => import("./ImageBrowserView.js"), true ? ["assets/ImageBrowserView.js","assets/SendOutputTo.vue_vue_type_script_setup_true_lang.js","assets/Switch.js","assets/TrashBin.js","assets/Slider.js","assets/DescriptionsItem.js","assets/ImageBrowserView.css"] : void 0) }, { path: "/tagger", name: "tagger", - component: () => __vitePreload(() => import("./TaggerView.js"), true ? ["assets/TaggerView.js","assets/GenerateSection.vue_vue_type_script_setup_true_lang.js","assets/ImageUpload.js","assets/CloudUpload.js","assets/v4.js","assets/Switch.js","assets/InputNumber.js","assets/TaggerView.css"] : void 0) + component: () => __vitePreload(() => import("./TaggerView.js"), true ? 
["assets/TaggerView.js","assets/GenerateSection.vue_vue_type_script_setup_true_lang.js","assets/ImageUpload.js","assets/CloudUpload.js","assets/v4.js","assets/Slider.js","assets/InputNumber.js","assets/Switch.js","assets/TaggerView.css"] : void 0) }, { path: "/:pathMatch(.*)", @@ -43275,190 +43827,193 @@ app.use(index, { }); app.mount("#app"); export { - createTmOptions as $, - NButton as A, - NIcon as B, - toDisplayString as C, - NTabPane as D, - NTabs as E, + NInternalSelectMenu as $, + h as A, + ref as B, + NButton as C, + NIcon as D, + toDisplayString as E, Fragment as F, spaceRegex as G, promptHandleKeyUp as H, promptHandleKeyDown as I, NInput as J, watch as K, - renderList as L, - NScrollbar as M, - NSpace as N, - replaceable as O, - createInjectionKey as P, - cB as Q, - inject as R, - useConfig as S, - useTheme as T, - popselectLight$1 as U, - createTreeMate as V, - nextTick as W, - toRef as X, - useThemeClass as Y, - NInternalSelectMenu as Z, + upscalerOptions as L, + renderList as M, + NTooltip as N, + NScrollbar as O, + replaceable as P, + createInjectionKey as Q, + cB as R, + inject as S, + useConfig as T, + useTheme as U, + popselectLight$1 as V, + createTreeMate as W, + nextTick as X, + toRef as Y, + useThemeClass as Z, _export_sfc as _, - useState2 as a, - AddIcon as a$, - happensIn as a0, - call as a1, - keysOf as a2, - provide as a3, - keep as a4, - createRefSetter as a5, - mergeEventHandlers as a6, - omit as a7, - NPopover as a8, - popoverBaseProps as a9, - NScrollbar$1 as aA, - onBeforeUnmount as aB, - off as aC, - on as aD, - ChevronDownIcon as aE, - NDropdown as aF, - pxfy as aG, - get as aH, - NIconSwitchTransition as aI, - NBaseLoading as aJ, - ChevronRightIcon as aK, - VResizeObserver as aL, - warn$2 as aM, - cssrAnchorMetaName as aN, - VVirtualList as aO, - NEmpty as aP, - repeat as aQ, - beforeNextFrameOnce as aR, - fadeInScaleUpTransition as aS, - iconSwitchTransition as aT, - insideModal as aU, - insidePopover as aV, - createId as aW, - Transition as aX, - dataTableLight$1 as aY, - loadingBarApiInjectionKey as aZ, - throwError as a_, - c$1 as aa, - cM as ab, - cNotM as ac, - useLocale as ad, - useMergedState as ae, - watchEffect as af, - useRtl as ag, - createKey as ah, - resolveSlot as ai, - NBaseIcon as aj, - useAdjustedTo as ak, - paginationLight$1 as al, - useMergedClsPrefix as am, - ellipsisLight$1 as an, - onDeactivated as ao, - mergeProps as ap, - useStyle as aq, - useFormItem as ar, - useMemo as as, - cE as at, - radioLight$1 as au, - resolveWrappedSlot as av, - flatten$2 as aw, - getSlot$1 as ax, - depx as ay, - formatLength as az, - upscalerOptions as b, - VFollower as b$, - NProgress as b0, - NFadeInExpandTransition as b1, - EyeIcon as b2, - fadeInHeightExpandTransition as b3, - Teleport as b4, - uploadLight$1 as b5, - useCssVars as b6, - themeOverridesKey as b7, - reactive as b8, - onMounted as b9, - useNotification as bA, - defaultSettings as bB, - getCurrentInstance as bC, - formLight$1 as bD, - commonVariables$m as bE, - formItemInjectionKey as bF, - resolveDynamicComponent as bG, - checkboxLight$1 as bH, - urlFromPath as bI, - diffusersSchedulerTuple as bJ, - useRouter as bK, - isBrowser$3 as bL, - fadeInTransition as bM, - imageLight as bN, - isMounted as bO, - LazyTeleport as bP, - zindexable$1 as bQ, - kebabCase$1 as bR, - useCompitable as bS, - descriptionsLight$1 as bT, - withModifiers as bU, - NAlert as bV, - rgba as bW, - inputNumberLight$1 as bX, - XButton as bY, - VBinder as bZ, - VTarget as b_, - normalizeStyle as ba, - NText as bb, - 
huggingfaceModelsFile as bc, - NModal as bd, - NDivider as be, - Backends as bf, - stepsLight$1 as bg, - FinishedIcon as bh, - ErrorIcon$1 as bi, - upperFirst$1 as bj, - toString as bk, - createCompounder as bl, - cloneVNode as bm, - onBeforeUpdate as bn, - indexMap as bo, - onUpdated as bp, - resolveSlotWithProps as bq, - withDirectives as br, - vShow as bs, - getPreciseEventTarget as bt, - carouselLight$1 as bu, - color2Class as bv, - rateLight as bw, - NTag as bx, - convertToTextString as by, - themeKey as bz, + createElementBlock as a, + throwError as a$, + createTmOptions as a0, + happensIn as a1, + call as a2, + keysOf as a3, + provide as a4, + keep as a5, + createRefSetter as a6, + mergeEventHandlers as a7, + omit as a8, + NPopover as a9, + formatLength as aA, + NScrollbar$1 as aB, + onBeforeUnmount as aC, + off as aD, + on as aE, + ChevronDownIcon as aF, + NDropdown as aG, + pxfy as aH, + get as aI, + NIconSwitchTransition as aJ, + NBaseLoading as aK, + ChevronRightIcon as aL, + cssrAnchorMetaName as aM, + VResizeObserver as aN, + warn$2 as aO, + VVirtualList as aP, + NEmpty as aQ, + repeat as aR, + beforeNextFrameOnce as aS, + fadeInScaleUpTransition as aT, + iconSwitchTransition as aU, + insideModal as aV, + insidePopover as aW, + createId as aX, + Transition as aY, + dataTableLight$1 as aZ, + loadingBarApiInjectionKey as a_, + popoverBaseProps as aa, + c$1 as ab, + cM as ac, + cNotM as ad, + useLocale as ae, + useMergedState as af, + watchEffect as ag, + useRtl as ah, + createKey as ai, + resolveSlot as aj, + NBaseIcon as ak, + useAdjustedTo as al, + paginationLight$1 as am, + useMergedClsPrefix as an, + ellipsisLight$1 as ao, + onDeactivated as ap, + mergeProps as aq, + useStyle as ar, + useFormItem as as, + useMemo as at, + cE as au, + radioLight$1 as av, + resolveWrappedSlot as aw, + flatten$2 as ax, + getSlot$1 as ay, + depx as az, + createBaseVNode as b, + XButton as b$, + AddIcon as b0, + NProgress as b1, + NFadeInExpandTransition as b2, + EyeIcon as b3, + fadeInHeightExpandTransition as b4, + Teleport as b5, + uploadLight$1 as b6, + createStaticVNode as b7, + useCssVars as b8, + themeOverridesKey as b9, + rateLight as bA, + NTag as bB, + convertToTextString as bC, + themeKey as bD, + useNotification as bE, + defaultSettings as bF, + getCurrentInstance as bG, + formLight$1 as bH, + commonVariables$m as bI, + formItemInjectionKey as bJ, + NAlert as bK, + resolveDynamicComponent as bL, + checkboxLight$1 as bM, + urlFromPath as bN, + diffusersSchedulerTuple as bO, + useRouter as bP, + isBrowser$3 as bQ, + fadeInTransition as bR, + imageLight as bS, + isMounted as bT, + LazyTeleport as bU, + zindexable$1 as bV, + kebabCase$1 as bW, + useCompitable as bX, + descriptionsLight$1 as bY, + rgba as bZ, + inputNumberLight$1 as b_, + reactive as ba, + onMounted as bb, + normalizeStyle as bc, + NText as bd, + withModifiers as be, + huggingfaceModelsFile as bf, + Menu as bg, + NModal as bh, + NDivider as bi, + Backends as bj, + stepsLight$1 as bk, + FinishedIcon as bl, + ErrorIcon$1 as bm, + upperFirst$1 as bn, + toString as bo, + createCompounder as bp, + cloneVNode as bq, + onBeforeUpdate as br, + indexMap as bs, + onUpdated as bt, + resolveSlotWithProps as bu, + withDirectives as bv, + vShow as bw, + getPreciseEventTarget as bx, + carouselLight$1 as by, + color2Class as bz, computed as c, - sliderLight$1 as c0, - isSlotEmpty as c1, - switchLight$1 as c2, - NResult as c3, + VBinder as c0, + VTarget as c1, + VFollower as c2, + sliderLight$1 as c3, + isSlotEmpty as c4, + switchLight$1 
as c5, + NResult as c6, defineComponent as d, - createBlock as e, - createBaseVNode as f, - createVNode as g, - unref as h, - NSelect as i, - createElementBlock as j, - createTextVNode as k, - NTooltip as l, - createCommentVNode as m, - NCard as n, + createVNode as e, + unref as f, + createBlock as g, + createTextVNode as h, + isDev as i, + NSpace as j, + createCommentVNode as k, + useState2 as l, + NCard as m, + NTabPane as n, openBlock as o, - useMessage as p, - onUnmounted as q, - NGi as r, - NGrid as s, - serverUrl as t, + NTabs as p, + NSelect as q, + useMessage as r, + onUnmounted as s, + NGi as t, useSettings as u, - pushScopeId as v, + NGrid as v, withCtx as w, - popScopeId as x, - h as y, - ref as z + serverUrl as x, + pushScopeId as y, + popScopeId as z }; diff --git a/frontend/src/App.vue b/frontend/src/App.vue index f2346d5b1..445ae374b 100644 --- a/frontend/src/App.vue +++ b/frontend/src/App.vue @@ -17,6 +17,7 @@ import { computed, provide, ref, watch } from "vue"; import { useState as useAnalytics } from "vue-gtag-next"; import Content from "./Content.vue"; import { serverUrl } from "./env"; +import { isLargeScreen } from "./helper/mediaQueries"; import { useSettings } from "./store/settings"; import type { ExtendedThemeOverrides } from "./types"; @@ -72,6 +73,10 @@ watch( overrides.value?.common?.baseColor ?? theme.value.common.baseColor; } ); + +const marginLeft = computed(() => { + return isLargeScreen.value ? "64px" : "0px"; +}); diff --git a/frontend/src/Content.vue b/frontend/src/Content.vue index 5c9d38880..045b62ce9 100644 --- a/frontend/src/Content.vue +++ b/frontend/src/Content.vue @@ -1,6 +1,6 @@ diff --git a/frontend/src/components/imageProcessing/index.ts b/frontend/src/components/imageProcessing/index.ts index b35cd812b..c27a8a325 100644 --- a/frontend/src/components/imageProcessing/index.ts +++ b/frontend/src/components/imageProcessing/index.ts @@ -1 +1 @@ -export { default as Upscale } from "./Upscale.vue"; +export { default as ESRGAN } from "./ESRGAN.vue"; diff --git a/frontend/src/components/inference/ControlNet.vue b/frontend/src/components/inference/ControlNet.vue index d24702797..99f646741 100644 --- a/frontend/src/components/inference/ControlNet.vue +++ b/frontend/src/components/inference/ControlNet.vue @@ -77,36 +77,9 @@ /> - -
- - - Guidance scale indicates how much should model stay close to the - prompt. Higher values might be exactly what you want, but - generated images might have some artefacts. Lower values - indicates that model can "dream" about this prompt more. - We recommend using 3-15 for most images. - - - -
+ + +
@@ -248,6 +221,9 @@
+ + + @@ -277,13 +253,17 @@ import "@/assets/2img.css"; import { BurnerClock } from "@/clock"; import { BatchSizeInput, + CFGScale, DimensionsInput, GenerateSection, + HighResFixTabs, ImageOutput, ImageUpload, OutputStats, Prompt, + SAGInput, SamplerPicker, + Upscale, } from "@/components"; import { serverUrl } from "@/env"; import { @@ -376,6 +356,36 @@ const generate = () => { settings.data.settings.controlnet.return_preprocessed, }, model: settings.data.settings.model?.path, + flags: { + ...(settings.data.settings.controlnet.highres.enabled + ? { + highres_fix: { + mode: settings.data.settings.controlnet.highres.mode, + image_upscaler: + settings.data.settings.controlnet.highres.image_upscaler, + scale: settings.data.settings.controlnet.highres.scale, + latent_scale_mode: + settings.data.settings.controlnet.highres.latent_scale_mode, + strength: settings.data.settings.controlnet.highres.strength, + steps: settings.data.settings.controlnet.highres.steps, + antialiased: + settings.data.settings.controlnet.highres.antialiased, + }, + } + : {}), + ...(settings.data.settings.controlnet.upscale.enabled + ? { + upscale: { + upscale_factor: + settings.data.settings.controlnet.upscale.upscale_factor, + tile_size: settings.data.settings.controlnet.upscale.tile_size, + tile_padding: + settings.data.settings.controlnet.upscale.tile_padding, + model: settings.data.settings.controlnet.upscale.model, + }, + } + : {}), + }, }), }) .then((res) => { diff --git a/frontend/src/components/inference/Img2Img.vue b/frontend/src/components/inference/Img2Img.vue index 550eccfd4..514d25d19 100644 --- a/frontend/src/components/inference/Img2Img.vue +++ b/frontend/src/components/inference/Img2Img.vue @@ -51,76 +51,9 @@ /> - -
- - - Guidance scale indicates how much should model stay close to the - prompt. Higher values might be exactly what you want, but - generated images might have some artefacts. Lower values - indicates that model can "dream" about this prompt more. - We recommend using 3-15 for most images. - - - -
- - -
- - - PyTorch ONLY. If self attention is >0, - SAG will guide the model and improve the quality of the image at - the cost of speed. Higher values will follow the guidance more - closely, which can lead to better, more sharp and detailed - outputs. - + - - -
+
@@ -200,6 +133,9 @@
+ + + @@ -229,13 +165,17 @@ import "@/assets/2img.css"; import { BurnerClock } from "@/clock"; import { BatchSizeInput, + CFGScale, DimensionsInput, GenerateSection, + HighResFixTabs, ImageOutput, ImageUpload, OutputStats, Prompt, + SAGInput, SamplerPicker, + Upscale, } from "@/components"; import { serverUrl } from "@/env"; import { @@ -314,7 +254,52 @@ const generate = () => { prompt_to_prompt: settings.data.settings.api.prompt_to_prompt, }, }, + ...(settings.data.settings.img2img.deepshrink.enabled + ? { + flags: { + deepshrink: { + early_out: settings.data.settings.img2img.deepshrink.early_out, + depth_1: settings.data.settings.img2img.deepshrink.depth_1, + stop_at_1: settings.data.settings.img2img.deepshrink.stop_at_1, + depth_2: settings.data.settings.img2img.deepshrink.depth_2, + stop_at_2: settings.data.settings.img2img.deepshrink.stop_at_2, + scaler: settings.data.settings.img2img.deepshrink.scaler, + base_scale: + settings.data.settings.img2img.deepshrink.base_scale, + }, + }, + } + : {}), model: settings.data.settings.model?.path, + flags: { + ...(settings.data.settings.img2img.highres.enabled + ? { + highres_fix: { + mode: settings.data.settings.img2img.highres.mode, + image_upscaler: + settings.data.settings.img2img.highres.image_upscaler, + scale: settings.data.settings.img2img.highres.scale, + latent_scale_mode: + settings.data.settings.img2img.highres.latent_scale_mode, + strength: settings.data.settings.img2img.highres.strength, + steps: settings.data.settings.img2img.highres.steps, + antialiased: settings.data.settings.img2img.highres.antialiased, + }, + } + : {}), + ...(settings.data.settings.img2img.upscale.enabled + ? { + upscale: { + upscale_factor: + settings.data.settings.img2img.upscale.upscale_factor, + tile_size: settings.data.settings.img2img.upscale.tile_size, + tile_padding: + settings.data.settings.img2img.upscale.tile_padding, + model: settings.data.settings.img2img.upscale.model, + }, + } + : {}), + }, }), }) .then((res) => { diff --git a/frontend/src/components/inference/Inpainting.vue b/frontend/src/components/inference/Inpainting.vue index 9be531806..1beefe62c 100644 --- a/frontend/src/components/inference/Inpainting.vue +++ b/frontend/src/components/inference/Inpainting.vue @@ -170,74 +170,32 @@ /> - -
- - - Guidance scale indicates how much should model stay close to the - prompt. Higher values might be exactly what you want, but - generated images might have some artefacts. Lower values - indicates that model can "dream" about this prompt more. - We recommend using 3-15 for most images. - - - -
+ - -
+ + + +
- PyTorch ONLY. If self attention is >0, - SAG will guide the model and improve the quality of the image at - the cost of speed. Higher values will follow the guidance more - closely, which can lead to better, more sharp and detailed - outputs. + How much should the masked area be changed from the original -
@@ -306,6 +264,9 @@
+ + + @@ -334,11 +295,15 @@ import "@/assets/2img.css"; import { BurnerClock } from "@/clock"; import { + CFGScale, GenerateSection, + HighResFixTabs, ImageOutput, OutputStats, Prompt, + SAGInput, SamplerPicker, + Upscale, } from "@/components"; import { serverUrl } from "@/env"; import { @@ -425,7 +390,56 @@ const generate = () => { prompt_to_prompt: settings.data.settings.api.prompt_to_prompt, }, }, + ...(settings.data.settings.inpainting.deepshrink.enabled + ? { + flags: { + deepshrink: { + early_out: + settings.data.settings.inpainting.deepshrink.early_out, + depth_1: settings.data.settings.inpainting.deepshrink.depth_1, + stop_at_1: + settings.data.settings.inpainting.deepshrink.stop_at_1, + depth_2: settings.data.settings.inpainting.deepshrink.depth_2, + stop_at_2: + settings.data.settings.inpainting.deepshrink.stop_at_2, + scaler: settings.data.settings.inpainting.deepshrink.scaler, + base_scale: + settings.data.settings.inpainting.deepshrink.base_scale, + }, + }, + } + : {}), model: settings.data.settings.model?.path, + flags: { + ...(settings.data.settings.inpainting.highres.enabled + ? { + highres_fix: { + mode: settings.data.settings.inpainting.highres.mode, + image_upscaler: + settings.data.settings.inpainting.highres.image_upscaler, + scale: settings.data.settings.inpainting.highres.scale, + latent_scale_mode: + settings.data.settings.inpainting.highres.latent_scale_mode, + strength: settings.data.settings.inpainting.highres.strength, + steps: settings.data.settings.inpainting.highres.steps, + antialiased: + settings.data.settings.inpainting.highres.antialiased, + }, + } + : {}), + ...(settings.data.settings.inpainting.upscale.enabled + ? { + upscale: { + upscale_factor: + settings.data.settings.inpainting.upscale.upscale_factor, + tile_size: settings.data.settings.inpainting.upscale.tile_size, + tile_padding: + settings.data.settings.inpainting.upscale.tile_padding, + model: settings.data.settings.inpainting.upscale.model, + }, + } + : {}), + }, }), }) .then((res) => { diff --git a/frontend/src/components/inference/Txt2Img.vue b/frontend/src/components/inference/Txt2Img.vue index 981fbbb17..80bb95431 100644 --- a/frontend/src/components/inference/Txt2Img.vue +++ b/frontend/src/components/inference/Txt2Img.vue @@ -43,71 +43,9 @@ /> - -
- - - Guidance scale indicates how much should model stay close to the - prompt. Higher values might be exactly what you want, but - generated images might have some artefacts. Lower values - indicates that model can "dream" about this prompt more. - We recommend using 3-15 for most images. - - - -
- - -
- - - If self attention is >0, SAG will guide the model and improve - the quality of the image at the cost of speed. Higher values - will follow the guidance more closely, which can lead to better, - more sharp and detailed outputs. - + - - -
+
@@ -154,7 +92,16 @@ - + + + + + + + @@ -182,13 +129,19 @@ diff --git a/frontend/src/components/models/CivitAIModelImage.vue b/frontend/src/components/models/CivitAIModelImage.vue new file mode 100644 index 000000000..de3e40b3f --- /dev/null +++ b/frontend/src/components/models/CivitAIModelImage.vue @@ -0,0 +1,129 @@ + + + diff --git a/frontend/src/components/models/index.ts b/frontend/src/components/models/index.ts index 54946f188..f56a8412c 100644 --- a/frontend/src/components/models/index.ts +++ b/frontend/src/components/models/index.ts @@ -1,4 +1,5 @@ export { default as CivitAIDownload } from "./CivitAIDownload.vue"; +export { default as CivitAIModelImage } from "./CivitAIModelImage.vue"; export { default as HuggingfaceDownload } from "./HuggingfaceDownload.vue"; export { default as ModelConvert } from "./ModelConvert.vue"; export { default as ModelManager } from "./ModelManager.vue"; diff --git a/frontend/src/components/settings/FlagsSettings.vue b/frontend/src/components/settings/FlagsSettings.vue deleted file mode 100644 index 5fa9b97ec..000000000 --- a/frontend/src/components/settings/FlagsSettings.vue +++ /dev/null @@ -1,77 +0,0 @@ - - - diff --git a/frontend/src/components/settings/ReproducibilitySettings.vue b/frontend/src/components/settings/ReproducibilitySettings.vue index 7c788223f..495e1e274 100644 --- a/frontend/src/components/settings/ReproducibilitySettings.vue +++ b/frontend/src/components/settings/ReproducibilitySettings.vue @@ -189,37 +189,133 @@ -
- - - - - - - - - - - - -
+ +
+
+ + Apply SD 1.4 Defaults + + + Apply SD 1.5 Defaults + + + Apply SD 2.1 Defaults + + + Apply SDXL Defaults + +
+ + + + + + + + + + + + + +
+
+ + + + + + + + + + + + + diff --git a/frontend/src/components/settings/defaultSettings/ControlNetSettings.vue b/frontend/src/components/settings/defaultSettings/ControlNetSettings.vue index eec3dbaa0..7ff20081e 100644 --- a/frontend/src/components/settings/defaultSettings/ControlNetSettings.vue +++ b/frontend/src/components/settings/defaultSettings/ControlNetSettings.vue @@ -77,13 +77,16 @@ :step="8" /> + + +