diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index 9699db4f1a6..7e20fb68280 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -1328,7 +1328,7 @@ from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegist config = InvokeAIAppConfig.get_config() ram_cache = ModelCache( - max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger + max_cache_size=config.ram_cache_size, logger=logger ) convert_cache = ModelConvertCache( cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 1e78e10d380..67b25dcda93 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -103,6 +103,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: textual_inversion_manager=ti_manager, dtype_for_device_getter=TorchDevice.choose_torch_dtype, truncate_long_prompts=False, + device=TorchDevice.choose_torch_device(), ) conjunction = Compel.parse_prompt_string(self.prompt) @@ -117,6 +118,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)]) conditioning_name = context.conditioning.save(conditioning_data) + return ConditioningOutput( conditioning=ConditioningField( conditioning_name=conditioning_name, @@ -203,6 +205,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: truncate_long_prompts=False, # TODO: returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip requires_pooled=get_pooled, + device=TorchDevice.choose_torch_device(), ) conjunction = Compel.parse_prompt_string(prompt) @@ -313,7 +316,6 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput: ) ] ) - conditioning_name = context.conditioning.save(conditioning_data) return ConditioningOutput( diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index e94daf70bdd..db43723e339 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -1,4 +1,5 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) +import copy import inspect from contextlib import ExitStack from typing import Any, Dict, Iterator, List, Optional, Tuple, Union @@ -193,9 +194,8 @@ def _get_text_embeddings_and_masks( text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = [] text_embeddings_masks: list[Optional[torch.Tensor]] = [] for cond in cond_list: - cond_data = context.conditioning.load(cond.conditioning_name) + cond_data = copy.deepcopy(context.conditioning.load(cond.conditioning_name)) text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype)) - mask = cond.mask if mask is not None: mask = context.tensors.load(mask.tensor_name) @@ -226,6 +226,7 @@ def _preprocess_regional_prompt_mask( # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w). 
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w) resized_mask = tf(mask) + assert isinstance(resized_mask, torch.Tensor) return resized_mask def _concat_regional_text_embeddings( diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 1dc75add1d6..0ff902067d8 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -26,13 +26,13 @@ DEFAULT_RAM_CACHE = 10.0 DEFAULT_VRAM_CACHE = 0.25 DEFAULT_CONVERT_CACHE = 20.0 -DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"] -PRECISION = Literal["auto", "float16", "bfloat16", "float32"] +DEVICE = Literal["auto", "cpu", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7", "mps"] +PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"] ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"] LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"] -CONFIG_SCHEMA_VERSION = "4.0.1" +CONFIG_SCHEMA_VERSION = "4.0.2" def get_default_ram_cache_size() -> float: @@ -105,14 +105,16 @@ class InvokeAIAppConfig(BaseSettings): convert_cache: Maximum size of on-disk converted models cache (GB). lazy_offload: Keep models in VRAM until their space is needed. log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour. - device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps` - precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
Valid values: `auto`, `float16`, `bfloat16`, `float32` + device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda:0`, `cuda:1`, `cuda:2`, `cuda:3`, `cuda:4`, `cuda:5`, `cuda:6`, `cuda:7`, `mps` + devices: List of execution devices; will override default device selected. + precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast` sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements. attention_type: Attention type.
Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp` attention_slice_size: Slice size, valid when attention_type=="sliced".
Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8` force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty). pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting. max_queue_size: Maximum number of items in the session queue. + max_threads: Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set. clear_queue_on_startup: Empties session queue on startup. allow_nodes: List of nodes to allow. Omit to allow all. deny_nodes: List of nodes to deny. Omit to deny none. @@ -178,6 +180,7 @@ class InvokeAIAppConfig(BaseSettings): # DEVICE device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.") + devices: Optional[list[DEVICE]] = Field(default=None, description="List of execution devices; will override default device selected.") precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.") # GENERATION @@ -187,6 +190,7 @@ class InvokeAIAppConfig(BaseSettings): force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).") pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.") max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.") + max_threads: Optional[int] = Field(default=None, description="Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.") clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.") # NODES @@ -376,9 +380,6 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig: # `max_cache_size` was renamed to `ram` some time in v3, but both names were used if k == "max_cache_size" and "ram" not in category_dict: parsed_config_dict["ram"] = v - # `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used - if k == "max_vram_cache_size" and "vram" not in category_dict: - parsed_config_dict["vram"] = v # autocast was removed in v4.0.1 if k == "precision" and v == "autocast": parsed_config_dict["precision"] = "auto" @@ -426,6 +427,27 @@ def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig return config +def migrate_v4_0_1_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig: + """Migrate v4.0.1 config dictionary to a current config object. + + A few new multi-GPU options were added in 4.0.2, and this simply + updates the schema label. + + Args: + config_dict: A dictionary of settings from a v4.0.1 config file. + + Returns: + An instance of `InvokeAIAppConfig` with the migrated settings. 
+ """ + parsed_config_dict: dict[str, Any] = {} + for k, _ in config_dict.items(): + if k == "schema_version": + parsed_config_dict[k] = CONFIG_SCHEMA_VERSION + config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict) + return config + + +# TO DO: replace this with a formal registration and migration system def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: """Load and migrate a config file to the latest version. @@ -457,6 +479,10 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict) loaded_config_dict.write_file(config_path) + elif loaded_config_dict["schema_version"] == "4.0.1": + loaded_config_dict = migrate_v4_0_1_config_dict(loaded_config_dict) + loaded_config_dict.write_file(config_path) + # Attempt to load as a v4 config file try: # Meta is not included in the model fields, so we need to validate it separately diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index f4fce6098f3..b23a3bb3277 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -53,11 +53,11 @@ def __init__( model_images: "ModelImageFileStorageBase", model_manager: "ModelManagerServiceBase", download_queue: "DownloadQueueServiceBase", - performance_statistics: "InvocationStatsServiceBase", session_queue: "SessionQueueBase", session_processor: "SessionProcessorBase", invocation_cache: "InvocationCacheBase", names: "NameServiceBase", + performance_statistics: "InvocationStatsServiceBase", urls: "UrlServiceBase", workflow_records: "WorkflowRecordsStorageBase", tensors: "ObjectSerializerBase[torch.Tensor]", @@ -77,11 +77,11 @@ def __init__( self.model_images = model_images self.model_manager = model_manager self.download_queue = download_queue - self.performance_statistics = performance_statistics self.session_queue = session_queue self.session_processor = session_processor self.invocation_cache = invocation_cache self.names = names + self.performance_statistics = performance_statistics self.urls = urls self.workflow_records = workflow_records self.tensors = tensors diff --git a/invokeai/app/services/invocation_stats/invocation_stats_default.py b/invokeai/app/services/invocation_stats/invocation_stats_default.py index 5a41f1f5d6b..2aa6f28f658 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_default.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_default.py @@ -74,9 +74,9 @@ def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: st ) self._stats[graph_execution_state_id].add_node_execution_stats(node_stats) - def reset_stats(self): - self._stats = {} - self._cache_stats = {} + def reset_stats(self, graph_execution_state_id: str): + self._stats.pop(graph_execution_state_id) + self._cache_stats.pop(graph_execution_state_id) def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary: graph_stats_summary = self._get_graph_summary(graph_execution_state_id) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index dd1b44d8999..0ac58aff98f 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -284,9 +284,14 @@ def prune_jobs(self) -> None: unfinished_jobs = [x for x in self._install_jobs if not x.in_terminal_state] self._install_jobs = unfinished_jobs - def 
_migrate_yaml(self) -> None: + def _migrate_yaml(self, rename_yaml: Optional[bool] = True, overwrite_db: Optional[bool] = False) -> None: db_models = self.record_store.all_models() + if overwrite_db: + for model in db_models: + self.record_store.del_model(model.key) + db_models = self.record_store.all_models() + legacy_models_yaml_path = ( self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml" ) @@ -336,7 +341,8 @@ def _migrate_yaml(self) -> None: self._logger.warning(f"Model at {model_path} could not be migrated: {e}") # Rename `models.yaml` to `models.yaml.bak` to prevent re-migration - legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak")) + if rename_yaml: + legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak")) # Unset the path - we are done with it either way self._app_config.legacy_models_yaml_path = None diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index da567721956..990f8ca207e 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -33,6 +33,11 @@ def ram_cache(self) -> ModelCacheBase[AnyModel]: def convert_cache(self) -> ModelConvertCacheBase: """Return the checkpoint convert cache used by this loader.""" + @property + @abstractmethod + def gpu_count(self) -> int: + """Return the number of GPUs we are configured to use.""" + @abstractmethod def load_model_from_path( self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 70674813785..00e14f0d72f 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -46,6 +46,7 @@ def __init__( self._registry = registry def start(self, invoker: Invoker) -> None: + """Start the service.""" self._invoker = invoker @property @@ -53,6 +54,11 @@ def ram_cache(self) -> ModelCacheBase[AnyModel]: """Return the RAM cache used by this loader.""" return self._ram_cache + @property + def gpu_count(self) -> int: + """Return the number of GPUs available for our uses.""" + return len(self._ram_cache.execution_devices) + @property def convert_cache(self) -> ModelConvertCacheBase: """Return the checkpoint convert cache used by this loader.""" diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index af1b68e1ec3..d20aefd4f3a 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -1,6 +1,7 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team from abc import ABC, abstractmethod +from typing import Optional, Set import torch from typing_extensions import Self @@ -31,7 +32,7 @@ def build_model_manager( model_record_service: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, events: EventServiceBase, - execution_device: torch.device, + execution_devices: Optional[Set[torch.device]] = None, ) -> Self: """ Construct the model manager service instance. 
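With this signature change the caller no longer pins the model cache to a single execution device; device discovery happens inside the cache itself. A minimal sketch of constructing the cache under the new API, assuming the import paths used elsewhere in this patch (the logger name is arbitrary):

    from invokeai.app.services.config import get_config
    from invokeai.backend.model_manager.load import ModelCache
    from invokeai.backend.util.logging import InvokeAILogger

    config = get_config()
    logger = InvokeAILogger.get_logger("model_cache_demo")  # arbitrary logger name

    # No execution_device argument any more: the cache enumerates GPUs via
    # TorchDevice.execution_devices(), optionally narrowed by the new `devices`
    # setting in invokeai.yaml.
    ram_cache = ModelCache(
        max_cache_size=config.ram,
        max_vram_cache_size=config.vram,
        logger=logger,
    )
    print(ram_cache.execution_devices)  # e.g. {cuda:0, cuda:1} on a two-GPU host
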
diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index 1a2b9a34022..6ff1b7de675 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -1,14 +1,10 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team """Implementation of ModelManagerServiceBase.""" -from typing import Optional - -import torch from typing_extensions import Self from invokeai.app.services.invoker import Invoker from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry -from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger from ..config import InvokeAIAppConfig @@ -69,7 +65,6 @@ def build_model_manager( model_record_service: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, events: EventServiceBase, - execution_device: Optional[torch.device] = None, ) -> Self: """ Construct the model manager service instance. @@ -82,9 +77,7 @@ def build_model_manager( ram_cache = ModelCache( max_cache_size=app_config.ram, max_vram_cache_size=app_config.vram, - lazy_offloading=app_config.lazy_offload, logger=logger, - execution_device=execution_device or TorchDevice.choose_torch_device(), ) convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache) loader = ModelLoadService( diff --git a/invokeai/app/services/object_serializer/object_serializer_disk.py b/invokeai/app/services/object_serializer/object_serializer_disk.py index 8edd29e1505..0c9567553a6 100644 --- a/invokeai/app/services/object_serializer/object_serializer_disk.py +++ b/invokeai/app/services/object_serializer/object_serializer_disk.py @@ -1,5 +1,6 @@ import shutil import tempfile +import threading import typing from pathlib import Path from typing import TYPE_CHECKING, Optional, TypeVar @@ -9,6 +10,7 @@ from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError from invokeai.app.util.misc import uuid_string +from invokeai.backend.util.devices import TorchDevice if TYPE_CHECKING: from invokeai.app.services.invoker import Invoker @@ -70,7 +72,10 @@ def _get_path(self, name: str) -> Path: return self._output_dir / name def _new_name(self) -> str: - return f"{self._obj_class_name}_{uuid_string()}" + tid = threading.current_thread().ident + # Add tid to the object name because uuid4 not thread-safe on windows + # See https://stackoverflow.com/questions/2759644/python-multiprocessing-doesnt-play-nicely-with-uuid-uuid4 + return f"{self._obj_class_name}_{tid}-{uuid_string()}" def _tempdir_cleanup(self) -> None: """Calls `cleanup` on the temporary directory, if it exists.""" diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 3f348fb239d..6ca4e164e06 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -1,8 +1,9 @@ import traceback from contextlib import suppress -from threading import BoundedSemaphore, Thread +from queue import Queue +from threading import BoundedSemaphore, Lock, Thread from threading import Event as ThreadEvent -from typing import Optional +from typing import Optional, Set from 
invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.services.events.events_common import ( @@ -26,6 +27,7 @@ from invokeai.app.services.shared.graph import NodeInputError from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context from invokeai.app.util.profiler import Profiler +from invokeai.backend.util.devices import TorchDevice from ..invoker import Invoker from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase @@ -57,8 +59,11 @@ def __init__( self._on_after_run_node_callbacks = on_after_run_node_callbacks or [] self._on_node_error_callbacks = on_node_error_callbacks or [] self._on_after_run_session_callbacks = on_after_run_session_callbacks or [] + self._process_lock = Lock() - def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None): + def start( + self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None + ) -> None: self._services = services self._cancel_event = cancel_event self._profiler = profiler @@ -76,7 +81,8 @@ def run(self, queue_item: SessionQueueItem): # Loop over invocations until the session is complete or canceled while True: try: - invocation = queue_item.session.next() + with self._process_lock: + invocation = queue_item.session.next() # Anything other than a `NodeInputError` is handled as a processor error except NodeInputError as e: error_type = e.__class__.__name__ @@ -108,7 +114,7 @@ def run(self, queue_item: SessionQueueItem): self._on_after_run_session(queue_item=queue_item) - def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): + def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None: try: # Any unhandled exception in this scope is an invocation error & will fail the graph with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id): @@ -210,7 +216,7 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: # we don't care about that - suppress the error. 
with suppress(GESStatsNotFoundError): self._services.performance_statistics.log_stats(queue_item.session.id) - self._services.performance_statistics.reset_stats() + self._services.performance_statistics.reset_stats(queue_item.session.id) for callback in self._on_after_run_session_callbacks: callback(queue_item=queue_item) @@ -324,7 +330,7 @@ def __init__( def start(self, invoker: Invoker) -> None: self._invoker: Invoker = invoker - self._queue_item: Optional[SessionQueueItem] = None + self._active_queue_items: Set[SessionQueueItem] = set() self._invocation: Optional[BaseInvocation] = None self._resume_event = ThreadEvent() @@ -350,7 +356,14 @@ def start(self, invoker: Invoker) -> None: else None ) + self._worker_thread_count = self._invoker.services.configuration.max_threads or len( + TorchDevice.execution_devices() + ) + + self._session_worker_queue: Queue[SessionQueueItem] = Queue() + self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler) + # Session processor - singlethreaded self._thread = Thread( name="session_processor", target=self._process, @@ -363,6 +376,16 @@ def start(self, invoker: Invoker) -> None: ) self._thread.start() + # Session processor workers - multithreaded + self._invoker.services.logger.debug(f"Starting {self._worker_thread_count} session processing threads.") + for _i in range(0, self._worker_thread_count): + worker = Thread( + name="session_worker", + target=self._process_next_session, + daemon=True, + ) + worker.start() + def stop(self, *args, **kwargs) -> None: self._stop_event.set() @@ -370,7 +393,7 @@ def _poll_now(self) -> None: self._poll_now_event.set() async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None: - if self._queue_item and self._queue_item.queue_id == event[1].queue_id: + if any(item.queue_id == event[1].queue_id for item in self._active_queue_items): self._cancel_event.set() self._poll_now() @@ -378,7 +401,7 @@ async def _on_batch_enqueued(self, event: FastAPIEvent[BatchEnqueuedEvent]) -> N self._poll_now() async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None: - if self._queue_item and event[1].status in ["completed", "failed", "canceled"]: + if self._active_queue_items and event[1].status in ["completed", "failed", "canceled"]: # When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is # emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel # event, which the session runner checks between invocations. If set, the session runner loop is broken. 
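Under the multi-threaded processor, each worker claims a whole GPU for the duration of one session through the cache's new reservation context manager. A hedged sketch of that flow; `services`, `queue_item` and `session_runner` are the objects the worker already holds, not new APIs:

    def run_one_session(services, queue_item, session_runner) -> None:
        """Sketch of the per-session GPU reservation used by the worker threads."""
        ram_cache = services.model_manager.load.ram_cache
        # Blocks until one of the configured execution devices is free, then
        # pins it to the current thread ident.
        with ram_cache.reserve_execution_device() as device:
            # While reserved, TorchDevice.choose_torch_device() and
            # ram_cache.get_execution_device() resolve to `device` on this thread.
            session_runner.run(queue_item=queue_item)
        # On exit the device returns to the pool and torch.cuda.empty_cache() runs.
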
@@ -403,7 +426,7 @@ def pause(self) -> SessionProcessorStatus: def get_status(self) -> SessionProcessorStatus: return SessionProcessorStatus( is_started=self._resume_event.is_set(), - is_processing=self._queue_item is not None, + is_processing=len(self._active_queue_items) > 0, ) def _process( @@ -428,30 +451,22 @@ def _process( resume_event.wait() # Get the next session to process - self._queue_item = self._invoker.services.session_queue.dequeue() + queue_item = self._invoker.services.session_queue.dequeue() - if self._queue_item is None: + if queue_item is None: # The queue was empty, wait for next polling interval or event to try again self._invoker.services.logger.debug("Waiting for next polling interval or event") poll_now_event.wait(self._polling_interval) continue - self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") + self._session_worker_queue.put(queue_item) + self._invoker.services.logger.debug(f"Scheduling queue item {queue_item.item_id} to run") cancel_event.clear() # Run the graph - self.session_runner.run(queue_item=self._queue_item) - - except Exception as e: - error_type = e.__class__.__name__ - error_message = str(e) - error_traceback = traceback.format_exc() - self._on_non_fatal_processor_error( - queue_item=self._queue_item, - error_type=error_type, - error_message=error_message, - error_traceback=error_traceback, - ) + # self.session_runner.run(queue_item=self._queue_item) + + except Exception: # Wait for next polling interval or event to try again poll_now_event.wait(self._polling_interval) continue @@ -466,9 +481,25 @@ def _process( finally: stop_event.clear() poll_now_event.clear() - self._queue_item = None self._thread_semaphore.release() + def _process_next_session(self) -> None: + while True: + self._resume_event.wait() + queue_item = self._session_worker_queue.get() + if queue_item.status == "canceled": + continue + try: + self._active_queue_items.add(queue_item) + # reserve a GPU for this session - may block + with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device(): + # Run the session on the reserved GPU + self.session_runner.run(queue_item=queue_item) + except Exception: + continue + finally: + self._active_queue_items.remove(queue_item) + def _on_non_fatal_processor_error( self, queue_item: Optional[SessionQueueItem], diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index 7f4601eba73..3cff330cff7 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -236,6 +236,9 @@ def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO } ) + def __hash__(self) -> int: + return self.item_id + class SessionQueueItemDTO(SessionQueueItemWithoutGraph): pass diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index d745e738233..60fd909881b 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -652,7 +652,7 @@ def _is_iterator_connection_valid( output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs] # Input type must be a list - if get_origin(input_field) != list: + if get_origin(input_field) is not list: return False # Validate that all outputs match the input type diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 01662335e46..e8f3d083b13 
100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Callable, Optional, Union +import torch from PIL.Image import Image from pydantic.networks import AnyHttpUrl from torch import Tensor @@ -26,11 +27,13 @@ from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData +from invokeai.backend.util.devices import TorchDevice if TYPE_CHECKING: from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem + from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase """ The InvocationContext provides access to various services and data about the current invocation. @@ -323,7 +326,6 @@ def load(self, name: str) -> ConditioningFieldData: Returns: The loaded conditioning data. """ - return self._services.conditioning.load(name) @@ -557,6 +559,28 @@ def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_m is_canceled=self.is_canceled, ) + def torch_device(self) -> torch.device: + """ + Return a torch device to use in the current invocation. + + Returns: + A torch.device not currently in use by the system. + """ + ram_cache: "ModelCacheBase[AnyModel]" = self._services.model_manager.load.ram_cache + return ram_cache.get_execution_device() + + def torch_dtype(self, device: Optional[torch.device] = None) -> torch.dtype: + """ + Return a precision type to use with the current invocation and torch device. + + Args: + device: Optional device. + + Returns: + A torch.dtype suited for the current device. + """ + return TorchDevice.choose_torch_dtype(device) + class InvocationContext: """Provides access to various services and data for the current invocation. 
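For node authors, the new helpers are added alongside `sd_step_callback`, so inside an invocation they should be reachable as `context.util.torch_device()` and `context.util.torch_dtype()`. A hedged sketch of a node that allocates a tensor on the GPU reserved for its session; the invocation id, class name and output type are illustrative, not part of this patch:

    import torch

    from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
    from invokeai.app.invocations.primitives import LatentsOutput
    from invokeai.app.services.shared.invocation_context import InvocationContext

    @invocation("demo_noise", title="Demo Noise", version="1.0.0")
    class DemoNoiseInvocation(BaseInvocation):
        """Illustrative node: allocate noise on the session's reserved GPU."""

        def invoke(self, context: InvocationContext) -> LatentsOutput:
            device = context.util.torch_device()      # GPU reserved for this worker thread
            dtype = context.util.torch_dtype(device)  # precision suited to that device
            noise = torch.randn(1, 4, 64, 64, device=device, dtype=dtype)
            name = context.tensors.save(noise)
            return LatentsOutput.build(latents_name=name, latents=noise, seed=None)
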
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 7ed12a7674d..c0065cefa9b 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -25,6 +25,7 @@ from typing import Literal, Optional, Type, TypeAlias, Union import torch +from diffusers.configuration_utils import ConfigMixin from diffusers.models.modeling_utils import ModelMixin from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter from typing_extensions import Annotated, Any, Dict @@ -37,7 +38,7 @@ # ModelMixin is the base class for all diffusers and transformers models # RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime -AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]] +AnyModel = Union[ConfigMixin, ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]] class InvalidModelConfigException(Exception): @@ -177,6 +178,7 @@ class ModelConfigBase(BaseModel): @staticmethod def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: + """Extend the pydantic schema from a json.""" schema["required"].extend(["key", "type", "format"]) model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra) @@ -443,7 +445,7 @@ def make_config( model = dest_class.model_validate(model_data) else: # mypy doesn't typecheck TypeAdapters well? - model = AnyModelConfigValidator.validate_python(model_data) # type: ignore + model = AnyModelConfigValidator.validate_python(model_data) assert model is not None if key: model.key = key diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 6748e85dca1..9291e599456 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -65,8 +65,7 @@ class LoadedModelWithoutConfig: def __enter__(self) -> AnyModel: """Context entry.""" - self._locker.lock() - return self.model + return self._locker.lock() def __exit__(self, *args: Any, **kwargs: Any) -> None: """Context exit.""" diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py index 012fd42d556..4fe99c31e6d 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py @@ -8,9 +8,10 @@ """ from abc import ABC, abstractmethod +from contextlib import contextmanager from dataclasses import dataclass, field from logging import Logger -from typing import Dict, Generic, Optional, TypeVar +from typing import Dict, Generator, Generic, Optional, Set, TypeVar import torch @@ -51,44 +52,13 @@ class CacheRecord(Generic[T]): Elements of the cache: key: Unique key for each model, same as used in the models database. - model: Model in memory. - state_dict: A read-only copy of the model's state dict in RAM. It will be - used as a template for creating a copy in the VRAM. + model: Read-only copy of the model *without weights* residing in the "meta device" size: Size of the model - loaded: True if the model's state dict is currently in VRAM - - Before a model is executed, the state_dict template is copied into VRAM, - and then injected into the model. When the model is finished, the VRAM - copy of the state dict is deleted, and the RAM version is reinjected - into the model. 
- - The state_dict should be treated as a read-only attribute. Do not attempt - to patch or otherwise modify it. Instead, patch the copy of the state_dict - after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel` - context manager call `model_on_device()`. """ key: str - model: T - device: torch.device - state_dict: Optional[Dict[str, torch.Tensor]] size: int - loaded: bool = False - _locks: int = 0 - - def lock(self) -> None: - """Lock this record.""" - self._locks += 1 - - def unlock(self) -> None: - """Unlock this record.""" - self._locks -= 1 - assert self._locks >= 0 - - @property - def locked(self) -> bool: - """Return true if record is locked.""" - return self._locks > 0 + model: T @dataclass @@ -115,30 +85,33 @@ def storage_device(self) -> torch.device: @property @abstractmethod - def execution_device(self) -> torch.device: - """Return the exection device (e.g. "cuda" for VRAM).""" + def execution_devices(self) -> Set[torch.device]: + """Return the set of available execution devices.""" pass - @property + @contextmanager @abstractmethod - def lazy_offloading(self) -> bool: - """Return true if the cache is configured to lazily offload models in VRAM.""" + def reserve_execution_device(self, timeout: int = 0) -> Generator[torch.device, None, None]: + """Reserve an execution device (GPU) under the current thread id.""" pass - @property @abstractmethod - def max_cache_size(self) -> float: - """Return true if the cache is configured to lazily offload models in VRAM.""" - pass + def get_execution_device(self) -> torch.device: + """ + Return an execution device that has been reserved for current thread. - @abstractmethod - def offload_unlocked_models(self, size_required: int) -> None: - """Offload from VRAM any models not actively in use.""" + Note that reservations are done using the current thread's TID. + It might be better to do this using the session ID, but that involves + too many detailed changes to model manager calls. + + May generate a ValueError if no GPU has been reserved. 
+ """ pass + @property @abstractmethod - def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None: - """Move model into the indicated device.""" + def max_cache_size(self) -> float: + """Return true if the cache is configured to lazily offload models in VRAM.""" pass @property @@ -202,6 +175,11 @@ def exists( """Return true if the model identified by key and submodel_type is in the cache.""" pass + @abstractmethod + def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel: + """Move a copy of the model into the indicated device and return it.""" + pass + @abstractmethod def cache_size(self) -> int: """Get the total size of the models currently cached.""" diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index d48e45426e3..31a10b6ea40 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -18,17 +18,19 @@ """ +import copy import gc -import math -import time -from contextlib import suppress +import sys +import threading +from contextlib import contextmanager, suppress from logging import Logger -from typing import Dict, List, Optional +from threading import BoundedSemaphore +from typing import Dict, Generator, List, Optional, Set import torch from invokeai.backend.model_manager import AnyModel, SubModelType -from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff +from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger @@ -39,9 +41,7 @@ # Maximum size of the cache, in gigs # Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously DEFAULT_MAX_CACHE_SIZE = 6.0 - -# amount of GPU memory to hold in reserve for use by generations (GB) -DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75 +DEFAULT_MAX_VRAM_CACHE_SIZE = 0.25 # actual size of a gig GIG = 1073741824 @@ -57,12 +57,8 @@ def __init__( self, max_cache_size: float = DEFAULT_MAX_CACHE_SIZE, max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE, - execution_device: torch.device = torch.device("cuda"), storage_device: torch.device = torch.device("cpu"), precision: torch.dtype = torch.float16, - sequential_offload: bool = False, - lazy_offloading: bool = True, - sha_chunksize: int = 16777216, log_memory_usage: bool = False, logger: Optional[Logger] = None, ): @@ -70,23 +66,19 @@ def __init__( Initialize the model RAM cache. :param max_cache_size: Maximum size of the RAM cache [6.0 GB] - :param execution_device: Torch device to load active model into [torch.device('cuda')] :param storage_device: Torch device to save inactive model in [torch.device('cpu')] :param precision: Precision for loaded models [torch.float16] - :param lazy_offloading: Keep model in VRAM until another model needs to be loaded :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). 
There is a time cost to capturing the memory snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's behaviour. """ - # allow lazy offloading only when vram cache enabled - self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0 self._precision: torch.dtype = precision self._max_cache_size: float = max_cache_size self._max_vram_cache_size: float = max_vram_cache_size - self._execution_device: torch.device = execution_device self._storage_device: torch.device = storage_device + self._ram_lock = threading.Lock() self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__) self._log_memory_usage = log_memory_usage self._stats: Optional[CacheStats] = None @@ -94,25 +86,87 @@ def __init__( self._cached_models: Dict[str, CacheRecord[AnyModel]] = {} self._cache_stack: List[str] = [] + # device to thread id + self._device_lock = threading.Lock() + self._execution_devices: Dict[torch.device, int] = {x: 0 for x in TorchDevice.execution_devices()} + self._free_execution_device = BoundedSemaphore(len(self._execution_devices)) + + self.logger.info( + f"Using rendering device(s): {', '.join(sorted([str(x) for x in self._execution_devices.keys()]))}" + ) + @property def logger(self) -> Logger: """Return the logger used by the cache.""" return self._logger - @property - def lazy_offloading(self) -> bool: - """Return true if the cache is configured to lazily offload models in VRAM.""" - return self._lazy_offloading - @property def storage_device(self) -> torch.device: """Return the storage device (e.g. "CPU" for RAM).""" return self._storage_device @property - def execution_device(self) -> torch.device: - """Return the exection device (e.g. "cuda" for VRAM).""" - return self._execution_device + def execution_devices(self) -> Set[torch.device]: + """Return the set of available execution devices.""" + devices = self._execution_devices.keys() + return set(devices) + + def get_execution_device(self) -> torch.device: + """ + Return an execution device that has been reserved for current thread. + + Note that reservations are done using the current thread's TID. + It would be better to do this using the session ID, but that involves + too many detailed changes to model manager calls. + + May generate a ValueError if no GPU has been reserved. + """ + current_thread = threading.current_thread().ident + assert current_thread is not None + assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid] + if not assigned: + raise ValueError(f"No GPU has been reserved for the use of thread {current_thread}") + return assigned[0] + + @contextmanager + def reserve_execution_device(self, timeout: Optional[int] = None) -> Generator[torch.device, None, None]: + """Reserve an execution device (e.g. GPU) for exclusive use by a generation thread. + + Note that the reservation is done using the current thread's TID. + It would be better to do this using the session ID, but that involves + too many detailed changes to model manager calls. + """ + device = None + with self._device_lock: + current_thread = threading.current_thread().ident + assert current_thread is not None + + # look for a device that has already been assigned to this thread + assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid] + if assigned: + device = assigned[0] + + # no device already assigned. Get one. 
+ if device is None: + self._free_execution_device.acquire(timeout=timeout) + with self._device_lock: + free_device = [x for x, tid in self._execution_devices.items() if tid == 0] + self._execution_devices[free_device[0]] = current_thread + device = free_device[0] + + # we are outside the lock region now + self.logger.info(f"{current_thread} Reserved torch device {device}") + + # Tell TorchDevice to use this object to get the torch device. + TorchDevice.set_model_cache(self) + try: + yield device + finally: + with self._device_lock: + self.logger.info(f"{current_thread} Released torch device {device}") + self._execution_devices[device] = 0 + self._free_execution_device.release() + torch.cuda.empty_cache() @property def max_cache_size(self) -> float: @@ -157,16 +211,16 @@ def put( submodel_type: Optional[SubModelType] = None, ) -> None: """Store model under key and optional submodel_type.""" - key = self._make_cache_key(key, submodel_type) - if key in self._cached_models: - return - size = calc_model_size_by_data(model) - self.make_room(size) + with self._ram_lock: + key = self._make_cache_key(key, submodel_type) + if key in self._cached_models: + return + size = calc_model_size_by_data(model) + self.make_room(size) - state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None - cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size) - self._cached_models[key] = cache_record - self._cache_stack.append(key) + cache_record = CacheRecord(key=key, model=model, size=size) + self._cached_models[key] = cache_record + self._cache_stack.append(key) def get( self, @@ -184,35 +238,36 @@ def get( This may raise an IndexError if the model is not in the cache. """ - key = self._make_cache_key(key, submodel_type) - if key in self._cached_models: - if self.stats: - self.stats.hits += 1 - else: + with self._ram_lock: + key = self._make_cache_key(key, submodel_type) + if key in self._cached_models: + if self.stats: + self.stats.hits += 1 + else: + if self.stats: + self.stats.misses += 1 + raise IndexError(f"The model with key {key} is not in the cache.") + + cache_entry = self._cached_models[key] + + # more stats if self.stats: - self.stats.misses += 1 - raise IndexError(f"The model with key {key} is not in the cache.") - - cache_entry = self._cached_models[key] - - # more stats - if self.stats: - stats_name = stats_name or key - self.stats.cache_size = int(self._max_cache_size * GIG) - self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size()) - self.stats.in_cache = len(self._cached_models) - self.stats.loaded_model_sizes[stats_name] = max( - self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size - ) + stats_name = stats_name or key + self.stats.cache_size = int(self._max_cache_size * GIG) + self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size()) + self.stats.in_cache = len(self._cached_models) + self.stats.loaded_model_sizes[stats_name] = max( + self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size + ) - # this moves the entry to the top (right end) of the stack - with suppress(Exception): - self._cache_stack.remove(key) - self._cache_stack.append(key) - return ModelLocker( - cache=self, - cache_entry=cache_entry, - ) + # this moves the entry to the top (right end) of the stack + with suppress(Exception): + self._cache_stack.remove(key) + self._cache_stack.append(key) + return ModelLocker( + cache=self, + cache_entry=cache_entry, + ) def 
_capture_memory_snapshot(self) -> Optional[MemorySnapshot]: if self._log_memory_usage: @@ -225,127 +280,34 @@ def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType] else: return model_key - def offload_unlocked_models(self, size_required: int) -> None: - """Move any unused models from VRAM.""" - reserved = self._max_vram_cache_size * GIG - vram_in_use = torch.cuda.memory_allocated() + size_required - self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB") - for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): - if vram_in_use <= reserved: - break - if not cache_entry.loaded: - continue - if not cache_entry.locked: - self.move_model_to_device(cache_entry, self.storage_device) - cache_entry.loaded = False - vram_in_use = torch.cuda.memory_allocated() + size_required - self.logger.debug( - f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB" - ) - - TorchDevice.empty_cache() - - def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None: - """Move model into the indicated device. + def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel: + """Move a copy of the model into the indicated device and return it. :param cache_entry: The CacheRecord for the model :param target_device: The torch.device to move the model into May raise a torch.cuda.OutOfMemoryError """ - self.logger.debug(f"Called to move {cache_entry.key} to {target_device}") - source_device = cache_entry.device - - # Note: We compare device types only so that 'cuda' == 'cuda:0'. - # This would need to be revised to support multi-GPU. - if torch.device(source_device).type == torch.device(target_device).type: - return - - # Some models don't have a `to` method, in which case they run in RAM/CPU. - if not hasattr(cache_entry.model, "to"): - return - - # This roundabout method for moving the model around is done to avoid - # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM. - # When moving to VRAM, we copy (not move) each element of the state dict from - # RAM to a new state dict in VRAM, and then inject it into the model. - # This operation is slightly faster than running `to()` on the whole model. - # - # When the model needs to be removed from VRAM we simply delete the copy - # of the state dict in VRAM, and reinject the state dict that is cached - # in RAM into the model. So this operation is very fast. 
- start_model_to_time = time.time() - snapshot_before = self._capture_memory_snapshot() - - try: - if cache_entry.state_dict is not None: - assert hasattr(cache_entry.model, "load_state_dict") - if target_device == self.storage_device: - cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True) - else: - new_dict: Dict[str, torch.Tensor] = {} - for k, v in cache_entry.state_dict.items(): - new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True) - cache_entry.model.load_state_dict(new_dict, assign=True) - cache_entry.model.to(target_device, non_blocking=True) - cache_entry.device = target_device - except Exception as e: # blow away cache entry - self._delete_cache_entry(cache_entry) - raise e - - snapshot_after = self._capture_memory_snapshot() - end_model_to_time = time.time() - self.logger.debug( - f"Moved model '{cache_entry.key}' from {source_device} to" - f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s." - f"Estimated model size: {(cache_entry.size/GIG):.3f} GB." - f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" - ) - - if ( - snapshot_before is not None - and snapshot_after is not None - and snapshot_before.vram is not None - and snapshot_after.vram is not None - ): - vram_change = abs(snapshot_before.vram - snapshot_after.vram) - - # If the estimated model size does not match the change in VRAM, log a warning. - if not math.isclose( - vram_change, - cache_entry.size, - rel_tol=0.1, - abs_tol=10 * MB, - ): - self.logger.debug( - f"Moving model '{cache_entry.key}' from {source_device} to" - f" {target_device} caused an unexpected change in VRAM usage. The model's" - " estimated size may be incorrect. Estimated model size:" - f" {(cache_entry.size/GIG):.3f} GB.\n" - f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" - ) + with self._ram_lock: + self.logger.debug(f"Called to move {cache_entry.key} ({type(cache_entry.model)=}) to {target_device}") + + # Some models don't have a state dictionary, in which case the + # stored model will still reside in CPU + if hasattr(cache_entry.model, "to"): + model_in_gpu = copy.deepcopy(cache_entry.model) + assert hasattr(model_in_gpu, "to") + model_in_gpu.to(target_device) + return model_in_gpu + else: + return cache_entry.model # what happens in CPU stays in CPU def print_cuda_stats(self) -> None: """Log CUDA diagnostics.""" vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG) ram = "%4.2fG" % (self.cache_size() / GIG) - in_ram_models = 0 - in_vram_models = 0 - locked_in_vram_models = 0 - for cache_record in self._cached_models.values(): - if hasattr(cache_record.model, "device"): - if cache_record.model.device == self.storage_device: - in_ram_models += 1 - else: - in_vram_models += 1 - if cache_record.locked: - locked_in_vram_models += 1 - - self.logger.debug( - f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) =" - f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})" - ) + in_ram_models = len(self._cached_models) + self.logger.debug(f"Current VRAM/RAM usage for {in_ram_models} models: {vram}/{ram}") def make_room(self, size: int) -> None: """Make enough room in the cache to accommodate a new model of indicated size.""" @@ -368,12 +330,14 @@ def make_room(self, size: int) -> None: while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack): model_key = self._cache_stack[pos] cache_entry = self._cached_models[model_key] - device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None - 
self.logger.debug( - f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}" - ) - if not cache_entry.locked: + refs = sys.getrefcount(cache_entry.model) + + # Expected refs: + # 1 from cache_entry + # 1 from getrefcount function + # 1 from onnx runtime object + if refs <= (3 if "onnx" in model_key else 2): self.logger.debug( f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" ) @@ -400,10 +364,26 @@ def make_room(self, size: int) -> None: if self.stats: self.stats.cleared = models_cleared gc.collect() - TorchDevice.empty_cache() self.logger.debug(f"After making room: cached_models={len(self._cached_models)}") + def _check_free_vram(self, target_device: torch.device, needed_size: int) -> None: + if target_device.type != "cuda": + return + vram_device = ( # mem_get_info() needs an indexed device + target_device if target_device.index is not None else torch.device(str(target_device), index=0) + ) + free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device)) + if needed_size > free_mem: + raise torch.cuda.OutOfMemoryError + def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None: - self._cache_stack.remove(cache_entry.key) - del self._cached_models[cache_entry.key] + try: + self._cache_stack.remove(cache_entry.key) + del self._cached_models[cache_entry.key] + except ValueError: + pass + + @staticmethod + def _device_name(device: torch.device) -> str: + return f"{device.type}:{device.index}" diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py index 9de17ca5f53..68af7ba97a9 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_locker.py +++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py @@ -10,6 +10,8 @@ from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase +MAX_GPU_WAIT = 600 # wait up to 10 minutes for a GPU to become free + class ModelLocker(ModelLockerBase): """Internal class that mediates movement in and out of GPU.""" @@ -29,33 +31,29 @@ def model(self) -> AnyModel: """Return the model without moving it around.""" return self._cache_entry.model - def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]: - """Return the state dict (if any) for the cached model.""" - return self._cache_entry.state_dict - def lock(self) -> AnyModel: """Move the model into the execution device (GPU) and lock it.""" - self._cache_entry.lock() try: - if self._cache.lazy_offloading: - self._cache.offload_unlocked_models(self._cache_entry.size) - self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device) - self._cache_entry.loaded = True - self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}") + device = self._cache.get_execution_device() + model_on_device = self._cache.model_to_device(self._cache_entry, device) + self._cache.logger.debug(f"Moved {self._cache_entry.key} to {device}") self._cache.print_cuda_stats() except torch.cuda.OutOfMemoryError: self._cache.logger.warning("Insufficient GPU memory to load model. 
Aborting") - self._cache_entry.unlock() raise except Exception: - self._cache_entry.unlock() raise - return self.model + return model_on_device + # It is no longer necessary to move the model out of VRAM + # because it will be removed when it goes out of scope + # in the caller's context def unlock(self) -> None: """Call upon exit from context.""" - self._cache_entry.unlock() - if not self._cache.lazy_offloading: - self._cache.offload_unlocked_models(0) - self._cache.print_cuda_stats() + self._cache.print_cuda_stats() + + # This is no longer in use in MGPU. + def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]: + """Return the state dict (if any) for the cached model.""" + return None diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py index fdc79539ae7..b879c3d4e80 100644 --- a/invokeai/backend/model_patcher.py +++ b/invokeai/backend/model_patcher.py @@ -4,6 +4,7 @@ from __future__ import annotations import pickle +import threading from contextlib import contextmanager from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union @@ -34,6 +35,8 @@ # TODO: rename smth like ModelPatcher and add TI method? class ModelPatcher: + _thread_lock = threading.Lock() + @staticmethod def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]: assert "." not in lora_key @@ -106,7 +109,7 @@ def apply_lora( """ original_weights = {} try: - with torch.no_grad(): + with torch.no_grad(), cls._thread_lock: for lora, lora_weight in loras: # assert lora.device.type == "cpu" for layer_key, layer in lora.layers.items(): @@ -129,9 +132,7 @@ def apply_lora( dtype = module.weight.dtype if module_key not in original_weights: - if model_state_dict is not None: # we were provided with the CPU copy of the state dict - original_weights[module_key] = model_state_dict[module_key + ".weight"] - else: + if model_state_dict is None: # no CPU copy of the state dict was provided original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 85950a01df5..01aae6b5a49 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -32,8 +32,11 @@ class SDXLConditioningInfo(BasicConditioningInfo): def to(self, device, dtype=None): self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype) + assert self.pooled_embeds.device == device self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype) - return super().to(device=device, dtype=dtype) + result = super().to(device=device, dtype=dtype) + assert self.embeds.device == device + return result @dataclass diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index f418133e49f..a8f47247eca 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -1,6 +1,7 @@ from __future__ import annotations import math +import threading from typing import Any, Callable, Optional, Union import torch @@ -293,24 +294,31 @@ def _apply_standard_conditioning( cross_attention_kwargs["regional_ip_data"] = regional_ip_data 
diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py
index fdc79539ae7..b879c3d4e80 100644
--- a/invokeai/backend/model_patcher.py
+++ b/invokeai/backend/model_patcher.py
@@ -4,6 +4,7 @@
 from __future__ import annotations

 import pickle
+import threading
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union

@@ -34,6 +35,8 @@

 # TODO: rename smth like ModelPatcher and add TI method?
 class ModelPatcher:
+    _thread_lock = threading.Lock()
+
     @staticmethod
     def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
         assert "." not in lora_key
@@ -106,7 +109,7 @@ def apply_lora(
         """
         original_weights = {}
         try:
-            with torch.no_grad():
+            with torch.no_grad(), cls._thread_lock:
                 for lora, lora_weight in loras:
                     # assert lora.device.type == "cpu"
                     for layer_key, layer in lora.layers.items():
@@ -129,9 +132,7 @@ def apply_lora(
                         dtype = module.weight.dtype

                         if module_key not in original_weights:
-                            if model_state_dict is not None:  # we were provided with the CPU copy of the state dict
-                                original_weights[module_key] = model_state_dict[module_key + ".weight"]
-                            else:
+                            if model_state_dict is None:  # no CPU copy of the state dict was provided
                                 original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)

                         layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
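The class-level `threading.Lock` added above serializes weight patching when several session threads share one RAM copy of a model. The toy example below shows only the concurrency pattern (it is not InvokeAI code; `WeightPatcher`, the plain dict of weights, and the sleep are contrived to make the race visible):

```python
import threading
import time


class WeightPatcher:
    # Single lock shared by all threads, mirroring ModelPatcher._thread_lock above.
    _thread_lock = threading.Lock()

    @classmethod
    def patch(cls, weights: dict, delta: dict) -> None:
        # Only one thread may run a read-modify-write cycle at a time; without the lock,
        # concurrent patches could interleave and lose updates.
        with cls._thread_lock:
            for key, value in delta.items():
                current = weights.get(key, 0.0)
                time.sleep(0.001)  # widen the read-modify-write window
                weights[key] = current + value


weights: dict = {}
threads = [threading.Thread(target=WeightPatcher.patch, args=(weights, {"w": 1.0})) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(weights["w"])  # 4.0, deterministic because patching is serialized
```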
diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
index 85950a01df5..01aae6b5a49 100644
--- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -32,8 +32,11 @@ class SDXLConditioningInfo(BasicConditioningInfo):

     def to(self, device, dtype=None):
         self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
+        assert self.pooled_embeds.device == device
         self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
-        return super().to(device=device, dtype=dtype)
+        result = super().to(device=device, dtype=dtype)
+        assert self.embeds.device == device
+        return result


 @dataclass
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index f418133e49f..a8f47247eca 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import math
+import threading
 from typing import Any, Callable, Optional, Union

 import torch
@@ -293,24 +294,31 @@ def _apply_standard_conditioning(
             cross_attention_kwargs["regional_ip_data"] = regional_ip_data

         added_cond_kwargs = None
-        if conditioning_data.is_sdxl():
-            added_cond_kwargs = {
-                "text_embeds": torch.cat(
-                    [
-                        # TODO: how to pad? just by zeros? or even truncate?
-                        conditioning_data.uncond_text.pooled_embeds,
-                        conditioning_data.cond_text.pooled_embeds,
-                    ],
-                    dim=0,
-                ),
-                "time_ids": torch.cat(
-                    [
-                        conditioning_data.uncond_text.add_time_ids,
-                        conditioning_data.cond_text.add_time_ids,
-                    ],
-                    dim=0,
-                ),
-            }
+        try:
+            if conditioning_data.is_sdxl():
+                # tid = threading.current_thread().ident
+                # print(f'DEBUG {tid} {conditioning_data.uncond_text.pooled_embeds.device=} {conditioning_data.cond_text.pooled_embeds.device=}', flush=True),
+                added_cond_kwargs = {
+                    "text_embeds": torch.cat(
+                        [
+                            # TODO: how to pad? just by zeros? or even truncate?
+                            conditioning_data.uncond_text.pooled_embeds,
+                            conditioning_data.cond_text.pooled_embeds,
+                        ],
+                        dim=0,
+                    ),
+                    "time_ids": torch.cat(
+                        [
+                            conditioning_data.uncond_text.add_time_ids,
+                            conditioning_data.cond_text.add_time_ids,
+                        ],
+                        dim=0,
+                    ),
+                }
+        except Exception as e:
+            tid = threading.current_thread().ident
+            print(f"DEBUG: {tid} {str(e)}")
+            raise e

         if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
             # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py
index e8380dc8bcd..dc2bafaa9c4 100644
--- a/invokeai/backend/util/devices.py
+++ b/invokeai/backend/util/devices.py
@@ -1,10 +1,16 @@
-from typing import Dict, Literal, Optional, Union
+"""Torch Device class provides torch device selection services."""
+
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Set, Union

 import torch
 from deprecated import deprecated

 from invokeai.app.services.config.config_default import get_config

+if TYPE_CHECKING:
+    from invokeai.backend.model_manager.config import AnyModel
+    from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
+
 # legacy APIs
 TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
 CPU_DEVICE = torch.device("cpu")
@@ -42,9 +48,23 @@ def torch_dtype(device: torch.device) -> torch.dtype:
 class TorchDevice:
     """Abstraction layer for torch devices."""

+    _model_cache: Optional["ModelCacheBase[AnyModel]"] = None
+
+    @classmethod
+    def set_model_cache(cls, cache: "ModelCacheBase[AnyModel]"):
+        """Set the current model cache."""
+        cls._model_cache = cache
+
     @classmethod
     def choose_torch_device(cls) -> torch.device:
         """Return the torch.device to use for accelerated inference."""
+        if cls._model_cache:
+            return cls._model_cache.get_execution_device()
+        else:
+            return cls._choose_device()
+
+    @classmethod
+    def _choose_device(cls) -> torch.device:
         app_config = get_config()
         if app_config.device != "auto":
             device = torch.device(app_config.device)
@@ -56,11 +76,19 @@ def choose_torch_device(cls) -> torch.device:
             device = CPU_DEVICE
         return cls.normalize(device)

+    @classmethod
+    def execution_devices(cls) -> Set[torch.device]:
+        """Return a list of torch.devices that can be used for accelerated inference."""
+        app_config = get_config()
+        if app_config.devices is None:
+            return cls._lookup_execution_devices()
+        return {torch.device(x) for x in app_config.devices}
+
     @classmethod
     def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
         """Return the precision to use for accelerated inference."""
-        device = device or cls.choose_torch_device()
         config = get_config()
+        device = device or cls._choose_device()
         if device.type == "cuda" and torch.cuda.is_available():
             device_name = torch.cuda.get_device_name(device)
             if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
@@ -108,3 +136,13 @@ def empty_cache(cls) -> None:
     @classmethod
     def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
         return NAME_TO_PRECISION[precision_name]
+
+    @classmethod
+    def _lookup_execution_devices(cls) -> Set[torch.device]:
+        if torch.cuda.is_available():
+            devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
+        elif torch.backends.mps.is_available():
+            devices = {torch.device("mps")}
+        else:
+            devices = {torch.device("cpu")}
+        return devices
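The `TorchDevice` changes above add an optional pool of execution devices (`config.devices`) and let a registered model cache override the static device selection. A short sketch of how the new entry points fit together (module paths and names as in this patch; the printed values naturally depend on the host's hardware):

```python
from invokeai.app.services.config import get_config
from invokeai.backend.util.devices import TorchDevice

config = get_config()

# Explicit device pool; leaving `devices` as None falls back to
# _lookup_execution_devices(): every visible CUDA device, else MPS, else CPU.
config.devices = ["cuda:0", "cuda:1"]
print(TorchDevice.execution_devices())  # expected: {torch.device("cuda:0"), torch.device("cuda:1")}

# With no model cache registered, choose_torch_device() uses the static, config-driven
# _choose_device(); after set_model_cache(cache) it defers to the cache's per-thread
# get_execution_device() instead.
TorchDevice.set_model_cache(None)
print(TorchDevice.choose_torch_device())
```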
diff --git a/scripts/populate_model_db_from_yaml.py b/scripts/populate_model_db_from_yaml.py
new file mode 100755
index 00000000000..80e5bcfc5c6
--- /dev/null
+++ b/scripts/populate_model_db_from_yaml.py
@@ -0,0 +1,54 @@
+#!/bin/env python
+
+from argparse import ArgumentParser, Namespace
+from pathlib import Path
+
+from invokeai.app.services.config import InvokeAIAppConfig, get_config
+from invokeai.app.services.download import DownloadQueueService
+from invokeai.app.services.model_install import ModelInstallService
+from invokeai.app.services.model_records import ModelRecordServiceSQL
+from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
+from invokeai.backend.util.logging import InvokeAILogger
+
+
+def get_args() -> Namespace:
+    parser = ArgumentParser(description="Update models database from yaml file")
+    parser.add_argument("--root", type=Path, required=False, default=None)
+    parser.add_argument("--yaml_file", type=Path, required=False, default=None)
+    return parser.parse_args()
+
+
+def populate_config() -> InvokeAIAppConfig:
+    args = get_args()
+    config = get_config()
+    if args.root:
+        config._root = args.root
+    if args.yaml_file:
+        config.legacy_models_yaml_path = args.yaml_file
+    else:
+        config.legacy_models_yaml_path = config.root_path / "configs/models.yaml"
+    return config
+
+
+def initialize_installer(config: InvokeAIAppConfig) -> ModelInstallService:
+    logger = InvokeAILogger.get_logger(config=config)
+    db = SqliteDatabase(config.db_path, logger)
+    record_store = ModelRecordServiceSQL(db)
+    queue = DownloadQueueService()
+    queue.start()
+    installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=queue)
+    return installer
+
+
+def main() -> None:
+    config = populate_config()
+    installer = initialize_installer(config)
+    installer._migrate_yaml(rename_yaml=False, overwrite_db=True)
+    print("\n")
+    print("\t".join(["key", "name", "type", "path"]))
+    for model in installer.record_store.all_models():
+        print("\t".join([model.key, model.name, model.type, (config.models_path / model.path).as_posix()]))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/backend/model_manager/model_loading/test_model_load.py b/tests/backend/model_manager/model_loading/test_model_load.py
index 3f12f7f8ee9..ff80d49397f 100644
--- a/tests/backend/model_manager/model_loading/test_model_load.py
+++ b/tests/backend/model_manager/model_loading/test_model_load.py
@@ -14,13 +14,14 @@ def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Pat
     matches = store.search_by_attr(model_name="test_embedding")
     assert len(matches) == 0
     key = mm2_model_manager.install.register_path(embedding_file)
-    loaded_model = mm2_model_manager.load.load_model(store.get_model(key))
-    assert loaded_model is not None
-    assert loaded_model.config.key == key
-    with loaded_model as model:
-        assert isinstance(model, TextualInversionModelRaw)
+    with mm2_model_manager.load.ram_cache.reserve_execution_device():
+        loaded_model = mm2_model_manager.load.load_model(store.get_model(key))
+        assert loaded_model is not None
+        assert loaded_model.config.key == key
+        with loaded_model as model:
+            assert isinstance(model, TextualInversionModelRaw)

-    config = mm2_model_manager.store.get_model(key)
-    loaded_model_2 = mm2_model_manager.load.load_model(config)
+        config = mm2_model_manager.store.get_model(key)
+        loaded_model_2 = mm2_model_manager.load.load_model(config)

-    assert loaded_model.config.key == loaded_model_2.config.key
+        assert loaded_model.config.key == loaded_model_2.config.key
diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py
index f82239298e1..e7e592d9b71 100644
--- a/tests/backend/model_manager/model_manager_fixtures.py
+++ b/tests/backend/model_manager/model_manager_fixtures.py
@@ -89,11 +89,10 @@ def mm2_download_queue(mm2_session: Session) -> DownloadQueueServiceBase:


 @pytest.fixture
-def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
+def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase:
     ram_cache = ModelCache(
         logger=InvokeAILogger.get_logger(),
         max_cache_size=mm2_app_config.ram,
-        max_vram_cache_size=mm2_app_config.vram,
     )
     convert_cache = ModelConvertCache(mm2_app_config.convert_cache_path)
     return ModelLoadService(
diff --git a/tests/backend/util/test_devices.py b/tests/backend/util/test_devices.py
index 8e810e43678..d854a82e622 100644
--- a/tests/backend/util/test_devices.py
+++ b/tests/backend/util/test_devices.py
@@ -8,7 +8,9 @@
 import torch

 from invokeai.app.services.config import get_config
+from invokeai.backend.model_manager.load import ModelCache
 from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype
+from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403

 devices = ["cpu", "cuda:0", "cuda:1", "mps"]
 device_types_cpu = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float32)]
@@ -20,6 +22,7 @@ def test_device_choice(device_name):
     config = get_config()
     config.device = device_name
+    TorchDevice.set_model_cache(None)  # disable dynamic selection of GPU device
     torch_device = TorchDevice.choose_torch_device()
     assert torch_device == torch.device(device_name)

@@ -130,3 +133,32 @@ def test_legacy_precision_name():
     assert "float16" == choose_precision(torch.device("cuda"))
     assert "float16" == choose_precision(torch.device("mps"))
     assert "float32" == choose_precision(torch.device("cpu"))
+
+
+def test_multi_device_support_1():
+    config = get_config()
+    config.devices = ["cuda:0", "cuda:1"]
+    assert TorchDevice.execution_devices() == {torch.device("cuda:0"), torch.device("cuda:1")}
+
+
+def test_multi_device_support_2():
+    config = get_config()
+    config.devices = None
+    with (
+        patch("torch.cuda.device_count", return_value=3),
+        patch("torch.cuda.is_available", return_value=True),
+    ):
+        assert TorchDevice.execution_devices() == {
+            torch.device("cuda:0"),
+            torch.device("cuda:1"),
+            torch.device("cuda:2"),
+        }
+
+
+def test_multi_device_support_3():
+    config = get_config()
+    config.devices = ["cuda:0", "cuda:1"]
+    cache = ModelCache()
+    with cache.reserve_execution_device() as gpu:
+        assert gpu in [torch.device(x) for x in config.devices]
+        assert TorchDevice.choose_torch_device() == gpu
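The new multi-device tests above exercise `reserve_execution_device()` directly. The intended pattern for parallel workers appears to be one reservation per thread, roughly as sketched below (illustration only: `cache` is a placeholder for a `ModelCache` constructed elsewhere, and the claim that reservation waits while every device is busy is inferred from the `MAX_GPU_WAIT` constant earlier in this patch):

```python
import threading


def worker(cache, results, results_lock):
    # `cache`: a ModelCache instance constructed elsewhere (placeholder).
    # Each worker reserves a device for the duration of its work; if every device in
    # the pool is already claimed, entry is expected to wait (cf. MAX_GPU_WAIT).
    with cache.reserve_execution_device() as gpu:
        with results_lock:
            results.append(gpu)
        # ... run this thread's generation workload on `gpu` here ...


results: list = []
results_lock = threading.Lock()
workers = [threading.Thread(target=worker, args=(cache, results, results_lock)) for _ in range(2)]
for t in workers:
    t.start()
for t in workers:
    t.join()
# With config.devices = ["cuda:0", "cuda:1"], both reserved devices should come from
# that pool, as asserted in test_multi_device_support_3 above.
```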
diff --git a/tests/conftest.py b/tests/conftest.py
index 8a67e9473c8..e140bcd7df4 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -17,7 +17,6 @@
 from invokeai.app.services.images.images_default import ImageService
 from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
 from invokeai.app.services.invocation_services import InvocationServices
-from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
 from invokeai.app.services.invoker import Invoker
 from invokeai.backend.util.logging import InvokeAILogger
 from tests.backend.model_manager.model_manager_fixtures import *  # noqa: F403
@@ -49,13 +48,13 @@ def mock_services() -> InvocationServices:
         model_manager=None,  # type: ignore
         download_queue=None,  # type: ignore
         names=None,  # type: ignore
-        performance_statistics=InvocationStatsService(),
         session_processor=None,  # type: ignore
         session_queue=None,  # type: ignore
         urls=None,  # type: ignore
         workflow_records=None,  # type: ignore
         tensors=None,  # type: ignore
         conditioning=None,  # type: ignore
+        performance_statistics=None,  # type: ignore
     )

diff --git a/tests/test_config.py b/tests/test_config.py
index a6ea2a34806..6606bc0f3e0 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -92,7 +92,6 @@ def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None):
     assert config.host == "192.168.1.1"
     assert config.port == 8080
     assert config.ram == 100
-    assert config.vram == 50
     assert config.legacy_models_yaml_path == Path("/custom/models.yaml")
     # This should be stripped out
     assert not hasattr(config, "esrgan")