diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md
index 9699db4f1a6..7e20fb68280 100644
--- a/docs/contributing/MODEL_MANAGER.md
+++ b/docs/contributing/MODEL_MANAGER.md
@@ -1328,7 +1328,7 @@ from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegist
config = InvokeAIAppConfig.get_config()
ram_cache = ModelCache(
- max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
+ max_cache_size=config.ram_cache_size, logger=logger
)
convert_cache = ModelConvertCache(
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 1e78e10d380..67b25dcda93 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -103,6 +103,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False,
+ device=TorchDevice.choose_torch_device(),
)
conjunction = Compel.parse_prompt_string(self.prompt)
@@ -117,6 +118,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])
conditioning_name = context.conditioning.save(conditioning_data)
+
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
@@ -203,6 +205,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
truncate_long_prompts=False, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled,
+ device=TorchDevice.choose_torch_device(),
)
conjunction = Compel.parse_prompt_string(prompt)
@@ -313,7 +316,6 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput:
)
]
)
-
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput(
diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index e94daf70bdd..db43723e339 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -1,4 +1,5 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
+import copy
import inspect
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
@@ -193,9 +194,8 @@ def _get_text_embeddings_and_masks(
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
text_embeddings_masks: list[Optional[torch.Tensor]] = []
for cond in cond_list:
- cond_data = context.conditioning.load(cond.conditioning_name)
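+            # Deep copy: the cached conditioning object is shared between worker
+            # threads, and its .to(device=...) method mutates it in place.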
+ cond_data = copy.deepcopy(context.conditioning.load(cond.conditioning_name))
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
-
mask = cond.mask
if mask is not None:
mask = context.tensors.load(mask.tensor_name)
@@ -226,6 +226,7 @@ def _preprocess_regional_prompt_mask(
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
resized_mask = tf(mask)
+ assert isinstance(resized_mask, torch.Tensor)
return resized_mask
def _concat_regional_text_embeddings(
diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index 1dc75add1d6..0ff902067d8 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -26,13 +26,13 @@
DEFAULT_RAM_CACHE = 10.0
DEFAULT_VRAM_CACHE = 0.25
DEFAULT_CONVERT_CACHE = 20.0
-DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
-PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
+DEVICE = Literal["auto", "cpu", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7", "mps"]
+PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
-CONFIG_SCHEMA_VERSION = "4.0.1"
+CONFIG_SCHEMA_VERSION = "4.0.2"
def get_default_ram_cache_size() -> float:
@@ -105,14 +105,16 @@ class InvokeAIAppConfig(BaseSettings):
convert_cache: Maximum size of on-disk converted models cache (GB).
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
-        device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
-            Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
-        precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
-            Valid values: `auto`, `float16`, `bfloat16`, `float32`
+        device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
+            Valid values: `auto`, `cpu`, `cuda:0`, `cuda:1`, `cuda:2`, `cuda:3`, `cuda:4`, `cuda:5`, `cuda:6`, `cuda:7`, `mps`
+        devices: List of execution devices; if set, overrides the `device` setting.
+        precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
+            Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type.
Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
attention_slice_size: Slice size, valid when attention_type=="sliced".
Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue.
+        max_threads: Maximum number of session queue execution threads. Calculated automatically from the number of GPUs if not set.
clear_queue_on_startup: Empties session queue on startup.
allow_nodes: List of nodes to allow. Omit to allow all.
deny_nodes: List of nodes to deny. Omit to deny none.
@@ -178,6 +180,7 @@ class InvokeAIAppConfig(BaseSettings):
# DEVICE
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
+    devices: Optional[list[DEVICE]] = Field(default=None, description="List of execution devices; if set, overrides the `device` setting.")
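+    # e.g. in invokeai.yaml (a sketch):  devices: ["cuda:0", "cuda:1"]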
precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")
# GENERATION
@@ -187,6 +190,7 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
+    max_threads: Optional[int] = Field(default=None, description="Maximum number of session queue execution threads. Calculated automatically from the number of GPUs if not set.")
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")
# NODES
@@ -376,9 +380,6 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
parsed_config_dict["ram"] = v
- # `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
- if k == "max_vram_cache_size" and "vram" not in category_dict:
- parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
@@ -426,6 +427,27 @@ def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig
return config
+def migrate_v4_0_1_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
+ """Migrate v4.0.1 config dictionary to a current config object.
+
+    A few new multi-GPU options were added in 4.0.2; existing settings are
+    carried forward and only the schema version is updated.
+
+ Args:
+ config_dict: A dictionary of settings from a v4.0.1 config file.
+
+ Returns:
+ An instance of `InvokeAIAppConfig` with the migrated settings.
+ """
+    parsed_config_dict: dict[str, Any] = {}
+    for k, v in config_dict.items():
+        parsed_config_dict[k] = v
+        if k == "schema_version":
+            parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
+ config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
+ return config
+
+
+# TODO: replace this with a formal registration and migration system
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.
@@ -457,6 +479,10 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)
+ elif loaded_config_dict["schema_version"] == "4.0.1":
+ loaded_config_dict = migrate_v4_0_1_config_dict(loaded_config_dict)
+ loaded_config_dict.write_file(config_path)
+
# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py
index f4fce6098f3..b23a3bb3277 100644
--- a/invokeai/app/services/invocation_services.py
+++ b/invokeai/app/services/invocation_services.py
@@ -53,11 +53,11 @@ def __init__(
model_images: "ModelImageFileStorageBase",
model_manager: "ModelManagerServiceBase",
download_queue: "DownloadQueueServiceBase",
- performance_statistics: "InvocationStatsServiceBase",
session_queue: "SessionQueueBase",
session_processor: "SessionProcessorBase",
invocation_cache: "InvocationCacheBase",
names: "NameServiceBase",
+ performance_statistics: "InvocationStatsServiceBase",
urls: "UrlServiceBase",
workflow_records: "WorkflowRecordsStorageBase",
tensors: "ObjectSerializerBase[torch.Tensor]",
@@ -77,11 +77,11 @@ def __init__(
self.model_images = model_images
self.model_manager = model_manager
self.download_queue = download_queue
- self.performance_statistics = performance_statistics
self.session_queue = session_queue
self.session_processor = session_processor
self.invocation_cache = invocation_cache
self.names = names
+ self.performance_statistics = performance_statistics
self.urls = urls
self.workflow_records = workflow_records
self.tensors = tensors
diff --git a/invokeai/app/services/invocation_stats/invocation_stats_default.py b/invokeai/app/services/invocation_stats/invocation_stats_default.py
index 5a41f1f5d6b..2aa6f28f658 100644
--- a/invokeai/app/services/invocation_stats/invocation_stats_default.py
+++ b/invokeai/app/services/invocation_stats/invocation_stats_default.py
@@ -74,9 +74,9 @@ def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: st
)
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
- def reset_stats(self):
- self._stats = {}
- self._cache_stats = {}
+    def reset_stats(self, graph_execution_state_id: str) -> None:
+        self._stats.pop(graph_execution_state_id, None)
+        self._cache_stats.pop(graph_execution_state_id, None)
def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py
index dd1b44d8999..0ac58aff98f 100644
--- a/invokeai/app/services/model_install/model_install_default.py
+++ b/invokeai/app/services/model_install/model_install_default.py
@@ -284,9 +284,14 @@ def prune_jobs(self) -> None:
unfinished_jobs = [x for x in self._install_jobs if not x.in_terminal_state]
self._install_jobs = unfinished_jobs
- def _migrate_yaml(self) -> None:
+    def _migrate_yaml(self, rename_yaml: bool = True, overwrite_db: bool = False) -> None:
db_models = self.record_store.all_models()
+ if overwrite_db:
+ for model in db_models:
+ self.record_store.del_model(model.key)
+ db_models = self.record_store.all_models()
+
legacy_models_yaml_path = (
self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml"
)
@@ -336,7 +341,8 @@ def _migrate_yaml(self) -> None:
self._logger.warning(f"Model at {model_path} could not be migrated: {e}")
# Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
- legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
+ if rename_yaml:
+ legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
# Unset the path - we are done with it either way
self._app_config.legacy_models_yaml_path = None
diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py
index da567721956..990f8ca207e 100644
--- a/invokeai/app/services/model_load/model_load_base.py
+++ b/invokeai/app/services/model_load/model_load_base.py
@@ -33,6 +33,11 @@ def ram_cache(self) -> ModelCacheBase[AnyModel]:
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""
+ @property
+ @abstractmethod
+ def gpu_count(self) -> int:
+ """Return the number of GPUs we are configured to use."""
+
@abstractmethod
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py
index 70674813785..00e14f0d72f 100644
--- a/invokeai/app/services/model_load/model_load_default.py
+++ b/invokeai/app/services/model_load/model_load_default.py
@@ -46,6 +46,7 @@ def __init__(
self._registry = registry
def start(self, invoker: Invoker) -> None:
+ """Start the service."""
self._invoker = invoker
@property
@@ -53,6 +54,11 @@ def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the RAM cache used by this loader."""
return self._ram_cache
+ @property
+ def gpu_count(self) -> int:
+ """Return the number of GPUs available for our uses."""
+ return len(self._ram_cache.execution_devices)
+
@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""
diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py
index af1b68e1ec3..d20aefd4f3a 100644
--- a/invokeai/app/services/model_manager/model_manager_base.py
+++ b/invokeai/app/services/model_manager/model_manager_base.py
@@ -1,6 +1,7 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from abc import ABC, abstractmethod
+from typing import Optional, Set
import torch
from typing_extensions import Self
@@ -31,7 +32,7 @@ def build_model_manager(
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
- execution_device: torch.device,
+ execution_devices: Optional[Set[torch.device]] = None,
) -> Self:
"""
Construct the model manager service instance.
diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py
index 1a2b9a34022..6ff1b7de675 100644
--- a/invokeai/app/services/model_manager/model_manager_default.py
+++ b/invokeai/app/services/model_manager/model_manager_default.py
@@ -1,14 +1,10 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""
-from typing import Optional
-
-import torch
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
-from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from ..config import InvokeAIAppConfig
@@ -69,7 +65,6 @@ def build_model_manager(
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
- execution_device: Optional[torch.device] = None,
) -> Self:
"""
Construct the model manager service instance.
@@ -82,9 +77,7 @@ def build_model_manager(
ram_cache = ModelCache(
max_cache_size=app_config.ram,
max_vram_cache_size=app_config.vram,
- lazy_offloading=app_config.lazy_offload,
logger=logger,
- execution_device=execution_device or TorchDevice.choose_torch_device(),
)
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
loader = ModelLoadService(
diff --git a/invokeai/app/services/object_serializer/object_serializer_disk.py b/invokeai/app/services/object_serializer/object_serializer_disk.py
index 8edd29e1505..0c9567553a6 100644
--- a/invokeai/app/services/object_serializer/object_serializer_disk.py
+++ b/invokeai/app/services/object_serializer/object_serializer_disk.py
@@ -1,5 +1,6 @@
import shutil
import tempfile
+import threading
import typing
from pathlib import Path
from typing import TYPE_CHECKING, Optional, TypeVar
@@ -9,6 +10,7 @@
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
from invokeai.app.util.misc import uuid_string
+from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.app.services.invoker import Invoker
@@ -70,7 +72,10 @@ def _get_path(self, name: str) -> Path:
return self._output_dir / name
def _new_name(self) -> str:
- return f"{self._obj_class_name}_{uuid_string()}"
+ tid = threading.current_thread().ident
+        # Add tid to the object name because uuid4 is not thread-safe on Windows
+ # See https://stackoverflow.com/questions/2759644/python-multiprocessing-doesnt-play-nicely-with-uuid-uuid4
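+        # e.g. "Tensor_140327390971584-<uuid4>" - the tid prefix keeps names
+        # unique even if two threads draw the same uuid4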
+ return f"{self._obj_class_name}_{tid}-{uuid_string()}"
def _tempdir_cleanup(self) -> None:
"""Calls `cleanup` on the temporary directory, if it exists."""
diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py
index 3f348fb239d..6ca4e164e06 100644
--- a/invokeai/app/services/session_processor/session_processor_default.py
+++ b/invokeai/app/services/session_processor/session_processor_default.py
@@ -1,8 +1,9 @@
import traceback
from contextlib import suppress
-from threading import BoundedSemaphore, Thread
+from queue import Queue
+from threading import BoundedSemaphore, Lock, Thread
from threading import Event as ThreadEvent
-from typing import Optional
+from typing import Optional, Set
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.events.events_common import (
@@ -26,6 +27,7 @@
from invokeai.app.services.shared.graph import NodeInputError
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
from invokeai.app.util.profiler import Profiler
+from invokeai.backend.util.devices import TorchDevice
from ..invoker import Invoker
from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase
@@ -57,8 +59,11 @@ def __init__(
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
self._on_node_error_callbacks = on_node_error_callbacks or []
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
+ self._process_lock = Lock()
- def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
+ def start(
+ self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None
+ ) -> None:
self._services = services
self._cancel_event = cancel_event
self._profiler = profiler
@@ -76,7 +81,8 @@ def run(self, queue_item: SessionQueueItem):
# Loop over invocations until the session is complete or canceled
while True:
try:
- invocation = queue_item.session.next()
+ with self._process_lock:
+ invocation = queue_item.session.next()
# Anything other than a `NodeInputError` is handled as a processor error
except NodeInputError as e:
error_type = e.__class__.__name__
@@ -108,7 +114,7 @@ def run(self, queue_item: SessionQueueItem):
self._on_after_run_session(queue_item=queue_item)
- def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
+ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
try:
# Any unhandled exception in this scope is an invocation error & will fail the graph
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
@@ -210,7 +216,7 @@ def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._services.performance_statistics.log_stats(queue_item.session.id)
- self._services.performance_statistics.reset_stats()
+ self._services.performance_statistics.reset_stats(queue_item.session.id)
for callback in self._on_after_run_session_callbacks:
callback(queue_item=queue_item)
@@ -324,7 +330,7 @@ def __init__(
def start(self, invoker: Invoker) -> None:
self._invoker: Invoker = invoker
- self._queue_item: Optional[SessionQueueItem] = None
+ self._active_queue_items: Set[SessionQueueItem] = set()
self._invocation: Optional[BaseInvocation] = None
self._resume_event = ThreadEvent()
@@ -350,7 +356,14 @@ def start(self, invoker: Invoker) -> None:
else None
)
+ self._worker_thread_count = self._invoker.services.configuration.max_threads or len(
+ TorchDevice.execution_devices()
+ )
+
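+        # Unbounded hand-off queue between the single polling thread and the
+        # per-GPU session worker threads started below.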
+ self._session_worker_queue: Queue[SessionQueueItem] = Queue()
+
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler)
+        # Session processor - single-threaded
self._thread = Thread(
name="session_processor",
target=self._process,
@@ -363,6 +376,16 @@ def start(self, invoker: Invoker) -> None:
)
self._thread.start()
+        # Session processor workers - multi-threaded
+ self._invoker.services.logger.debug(f"Starting {self._worker_thread_count} session processing threads.")
+        for _i in range(self._worker_thread_count):
+ worker = Thread(
+ name="session_worker",
+ target=self._process_next_session,
+ daemon=True,
+ )
+ worker.start()
+
def stop(self, *args, **kwargs) -> None:
self._stop_event.set()
@@ -370,7 +393,7 @@ def _poll_now(self) -> None:
self._poll_now_event.set()
async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None:
- if self._queue_item and self._queue_item.queue_id == event[1].queue_id:
+ if any(item.queue_id == event[1].queue_id for item in self._active_queue_items):
self._cancel_event.set()
self._poll_now()
@@ -378,7 +401,7 @@ async def _on_batch_enqueued(self, event: FastAPIEvent[BatchEnqueuedEvent]) -> N
self._poll_now()
async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None:
- if self._queue_item and event[1].status in ["completed", "failed", "canceled"]:
+ if self._active_queue_items and event[1].status in ["completed", "failed", "canceled"]:
# When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is
# emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel
# event, which the session runner checks between invocations. If set, the session runner loop is broken.
@@ -403,7 +426,7 @@ def pause(self) -> SessionProcessorStatus:
def get_status(self) -> SessionProcessorStatus:
return SessionProcessorStatus(
is_started=self._resume_event.is_set(),
- is_processing=self._queue_item is not None,
+ is_processing=len(self._active_queue_items) > 0,
)
def _process(
@@ -428,30 +451,22 @@ def _process(
resume_event.wait()
# Get the next session to process
- self._queue_item = self._invoker.services.session_queue.dequeue()
+ queue_item = self._invoker.services.session_queue.dequeue()
- if self._queue_item is None:
+ if queue_item is None:
# The queue was empty, wait for next polling interval or event to try again
self._invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(self._polling_interval)
continue
- self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
+ self._session_worker_queue.put(queue_item)
+ self._invoker.services.logger.debug(f"Scheduling queue item {queue_item.item_id} to run")
cancel_event.clear()
-                    # Run the graph
- self.session_runner.run(queue_item=self._queue_item)
-
- except Exception as e:
- error_type = e.__class__.__name__
- error_message = str(e)
- error_traceback = traceback.format_exc()
- self._on_non_fatal_processor_error(
- queue_item=self._queue_item,
- error_type=error_type,
- error_message=error_message,
- error_traceback=error_traceback,
- )
+                except Exception:
+                    self._invoker.services.logger.exception("Error while dispatching a session to the worker threads")
# Wait for next polling interval or event to try again
poll_now_event.wait(self._polling_interval)
continue
@@ -466,9 +481,25 @@ def _process(
finally:
stop_event.clear()
poll_now_event.clear()
- self._queue_item = None
self._thread_semaphore.release()
+ def _process_next_session(self) -> None:
+ while True:
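+            # Honor pause/resume, then block until the dispatcher thread
+            # hands us the next session to run.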
+ self._resume_event.wait()
+ queue_item = self._session_worker_queue.get()
+ if queue_item.status == "canceled":
+ continue
+ try:
+ self._active_queue_items.add(queue_item)
+ # reserve a GPU for this session - may block
+ with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device():
+ # Run the session on the reserved GPU
+ self.session_runner.run(queue_item=queue_item)
+            except Exception:
+                self._invoker.services.logger.exception(f"Error while processing queue item {queue_item.item_id}")
+                continue
+ finally:
+ self._active_queue_items.remove(queue_item)
+
def _on_non_fatal_processor_error(
self,
queue_item: Optional[SessionQueueItem],
diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py
index 7f4601eba73..3cff330cff7 100644
--- a/invokeai/app/services/session_queue/session_queue_common.py
+++ b/invokeai/app/services/session_queue/session_queue_common.py
@@ -236,6 +236,9 @@ def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO
}
)
+ def __hash__(self) -> int:
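+        # item_id is unique per queue item, so it is sufficient as a hash for
+        # tracking active queue items in the session processor's set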
+ return self.item_id
+
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
pass
diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py
index d745e738233..60fd909881b 100644
--- a/invokeai/app/services/shared/graph.py
+++ b/invokeai/app/services/shared/graph.py
@@ -652,7 +652,7 @@ def _is_iterator_connection_valid(
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
# Input type must be a list
- if get_origin(input_field) != list:
+ if get_origin(input_field) is not list:
return False
# Validate that all outputs match the input type
diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py
index 01662335e46..e8f3d083b13 100644
--- a/invokeai/app/services/shared/invocation_context.py
+++ b/invokeai/app/services/shared/invocation_context.py
@@ -2,6 +2,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional, Union
+import torch
from PIL.Image import Image
from pydantic.networks import AnyHttpUrl
from torch import Tensor
@@ -26,11 +27,13 @@
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
+from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
+ from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
"""
The InvocationContext provides access to various services and data about the current invocation.
@@ -323,7 +326,6 @@ def load(self, name: str) -> ConditioningFieldData:
Returns:
The loaded conditioning data.
"""
-
return self._services.conditioning.load(name)
@@ -557,6 +559,28 @@ def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_m
is_canceled=self.is_canceled,
)
+ def torch_device(self) -> torch.device:
+ """
+ Return a torch device to use in the current invocation.
+
+ Returns:
+ A torch.device not currently in use by the system.
+ """
+ ram_cache: "ModelCacheBase[AnyModel]" = self._services.model_manager.load.ram_cache
+ return ram_cache.get_execution_device()
+
+ def torch_dtype(self, device: Optional[torch.device] = None) -> torch.dtype:
+ """
+ Return a precision type to use with the current invocation and torch device.
+
+ Args:
+ device: Optional device.
+
+ Returns:
+ A torch.dtype suited for the current device.
+ """
+ return TorchDevice.choose_torch_dtype(device)
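+
+    # A minimal usage sketch (assuming these helpers live on the same
+    # interface as `sd_step_callback` above), e.g. inside a custom invocation:
+    #
+    #     device = context.util.torch_device()
+    #     dtype = context.util.torch_dtype(device)
+    #     noise = torch.randn((1, 4, 64, 64), device=device, dtype=dtype)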
+
class InvocationContext:
"""Provides access to various services and data for the current invocation.
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index 7ed12a7674d..c0065cefa9b 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -25,6 +25,7 @@
from typing import Literal, Optional, Type, TypeAlias, Union
import torch
+from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
@@ -37,7 +38,7 @@
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
-AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
+AnyModel = Union[ConfigMixin, ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
class InvalidModelConfigException(Exception):
@@ -177,6 +178,7 @@ class ModelConfigBase(BaseModel):
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
+        """Mark `key`, `type` and `format` as required in the generated JSON schema."""
schema["required"].extend(["key", "type", "format"])
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
@@ -443,7 +445,7 @@ def make_config(
model = dest_class.model_validate(model_data)
else:
# mypy doesn't typecheck TypeAdapters well?
- model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
+ model = AnyModelConfigValidator.validate_python(model_data)
assert model is not None
if key:
model.key = key
diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py
index 6748e85dca1..9291e599456 100644
--- a/invokeai/backend/model_manager/load/load_base.py
+++ b/invokeai/backend/model_manager/load/load_base.py
@@ -65,8 +65,7 @@ class LoadedModelWithoutConfig:
def __enter__(self) -> AnyModel:
"""Context entry."""
- self._locker.lock()
- return self.model
+ return self._locker.lock()
def __exit__(self, *args: Any, **kwargs: Any) -> None:
"""Context exit."""
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
index 012fd42d556..4fe99c31e6d 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
@@ -8,9 +8,10 @@
"""
from abc import ABC, abstractmethod
+from contextlib import contextmanager
from dataclasses import dataclass, field
from logging import Logger
-from typing import Dict, Generic, Optional, TypeVar
+from typing import Dict, Generator, Generic, Optional, Set, TypeVar
import torch
@@ -51,44 +52,13 @@ class CacheRecord(Generic[T]):
Elements of the cache:
key: Unique key for each model, same as used in the models database.
- model: Model in memory.
- state_dict: A read-only copy of the model's state dict in RAM. It will be
- used as a template for creating a copy in the VRAM.
+        model: Read-only copy of the model residing on the storage device; a copy is moved onto an execution device when the model is locked.
size: Size of the model
- loaded: True if the model's state dict is currently in VRAM
-
- Before a model is executed, the state_dict template is copied into VRAM,
- and then injected into the model. When the model is finished, the VRAM
- copy of the state dict is deleted, and the RAM version is reinjected
- into the model.
-
- The state_dict should be treated as a read-only attribute. Do not attempt
- to patch or otherwise modify it. Instead, patch the copy of the state_dict
- after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
- context manager call `model_on_device()`.
"""
key: str
- model: T
- device: torch.device
- state_dict: Optional[Dict[str, torch.Tensor]]
size: int
- loaded: bool = False
- _locks: int = 0
-
- def lock(self) -> None:
- """Lock this record."""
- self._locks += 1
-
- def unlock(self) -> None:
- """Unlock this record."""
- self._locks -= 1
- assert self._locks >= 0
-
- @property
- def locked(self) -> bool:
- """Return true if record is locked."""
- return self._locks > 0
+ model: T
@dataclass
@@ -115,30 +85,33 @@ def storage_device(self) -> torch.device:
@property
@abstractmethod
- def execution_device(self) -> torch.device:
- """Return the exection device (e.g. "cuda" for VRAM)."""
+ def execution_devices(self) -> Set[torch.device]:
+ """Return the set of available execution devices."""
pass
- @property
+ @contextmanager
@abstractmethod
- def lazy_offloading(self) -> bool:
- """Return true if the cache is configured to lazily offload models in VRAM."""
+    def reserve_execution_device(self, timeout: Optional[int] = None) -> Generator[torch.device, None, None]:
+ """Reserve an execution device (GPU) under the current thread id."""
pass
- @property
@abstractmethod
- def max_cache_size(self) -> float:
- """Return true if the cache is configured to lazily offload models in VRAM."""
- pass
+ def get_execution_device(self) -> torch.device:
+ """
+ Return an execution device that has been reserved for current thread.
- @abstractmethod
- def offload_unlocked_models(self, size_required: int) -> None:
- """Offload from VRAM any models not actively in use."""
+ Note that reservations are done using the current thread's TID.
+ It might be better to do this using the session ID, but that involves
+ too many detailed changes to model manager calls.
+
+        Raises a ValueError if no GPU has been reserved.
+ """
pass
+ @property
@abstractmethod
- def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
- """Move model into the indicated device."""
+ def max_cache_size(self) -> float:
+        """Return the maximum size of the RAM cache, in GB."""
pass
@property
@@ -202,6 +175,11 @@ def exists(
"""Return true if the model identified by key and submodel_type is in the cache."""
pass
+ @abstractmethod
+ def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
+ """Move a copy of the model into the indicated device and return it."""
+ pass
+
@abstractmethod
def cache_size(self) -> int:
"""Get the total size of the models currently cached."""
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index d48e45426e3..31a10b6ea40 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -18,17 +18,19 @@
"""
+import copy
import gc
-import math
-import time
-from contextlib import suppress
+import sys
+import threading
+from contextlib import contextmanager, suppress
from logging import Logger
-from typing import Dict, List, Optional
+from threading import BoundedSemaphore
+from typing import Dict, Generator, List, Optional, Set
import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
-from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
+from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@@ -39,9 +41,7 @@
# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
-
-# amount of GPU memory to hold in reserve for use by generations (GB)
-DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
+DEFAULT_MAX_VRAM_CACHE_SIZE = 0.25
# actual size of a gig
GIG = 1073741824
@@ -57,12 +57,8 @@ def __init__(
self,
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
- execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
precision: torch.dtype = torch.float16,
- sequential_offload: bool = False,
- lazy_offloading: bool = True,
- sha_chunksize: int = 16777216,
log_memory_usage: bool = False,
logger: Optional[Logger] = None,
):
@@ -70,23 +66,19 @@ def __init__(
Initialize the model RAM cache.
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
- :param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
- :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
-        :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour.
"""
- # allow lazy offloading only when vram cache enabled
- self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self._precision: torch.dtype = precision
self._max_cache_size: float = max_cache_size
self._max_vram_cache_size: float = max_vram_cache_size
- self._execution_device: torch.device = execution_device
self._storage_device: torch.device = storage_device
+ self._ram_lock = threading.Lock()
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
self._log_memory_usage = log_memory_usage
self._stats: Optional[CacheStats] = None
@@ -94,25 +86,87 @@ def __init__(
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
self._cache_stack: List[str] = []
+        # Maps each execution device to the id of the thread that reserved it (0 = free)
+ self._device_lock = threading.Lock()
+ self._execution_devices: Dict[torch.device, int] = {x: 0 for x in TorchDevice.execution_devices()}
+ self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
+
+ self.logger.info(
+ f"Using rendering device(s): {', '.join(sorted([str(x) for x in self._execution_devices.keys()]))}"
+ )
+
@property
def logger(self) -> Logger:
"""Return the logger used by the cache."""
return self._logger
- @property
- def lazy_offloading(self) -> bool:
- """Return true if the cache is configured to lazily offload models in VRAM."""
- return self._lazy_offloading
-
@property
def storage_device(self) -> torch.device:
"""Return the storage device (e.g. "CPU" for RAM)."""
return self._storage_device
@property
- def execution_device(self) -> torch.device:
- """Return the exection device (e.g. "cuda" for VRAM)."""
- return self._execution_device
+ def execution_devices(self) -> Set[torch.device]:
+ """Return the set of available execution devices."""
+        return set(self._execution_devices.keys())
+
+ def get_execution_device(self) -> torch.device:
+ """
+ Return an execution device that has been reserved for current thread.
+
+ Note that reservations are done using the current thread's TID.
+ It would be better to do this using the session ID, but that involves
+ too many detailed changes to model manager calls.
+
+        Raises a ValueError if no GPU has been reserved.
+ """
+ current_thread = threading.current_thread().ident
+ assert current_thread is not None
+ assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
+ if not assigned:
+ raise ValueError(f"No GPU has been reserved for the use of thread {current_thread}")
+ return assigned[0]
+
+ @contextmanager
+ def reserve_execution_device(self, timeout: Optional[int] = None) -> Generator[torch.device, None, None]:
+ """Reserve an execution device (e.g. GPU) for exclusive use by a generation thread.
+
+ Note that the reservation is done using the current thread's TID.
+ It would be better to do this using the session ID, but that involves
+ too many detailed changes to model manager calls.
+ """
+ device = None
+ with self._device_lock:
+ current_thread = threading.current_thread().ident
+ assert current_thread is not None
+
+ # look for a device that has already been assigned to this thread
+ assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
+ if assigned:
+ device = assigned[0]
+
+        # No device is assigned to this thread yet; wait for a free one.
+        if device is None:
+            if not self._free_execution_device.acquire(timeout=timeout):
+                raise TimeoutError(f"Timed out waiting for a free execution device after {timeout} seconds")
+            with self._device_lock:
+                free_device = [x for x, tid in self._execution_devices.items() if tid == 0]
+                self._execution_devices[free_device[0]] = current_thread
+                device = free_device[0]
+
+ # we are outside the lock region now
+ self.logger.info(f"{current_thread} Reserved torch device {device}")
+
+ # Tell TorchDevice to use this object to get the torch device.
+ TorchDevice.set_model_cache(self)
+ try:
+ yield device
+ finally:
+ with self._device_lock:
+ self.logger.info(f"{current_thread} Released torch device {device}")
+ self._execution_devices[device] = 0
+ self._free_execution_device.release()
+ torch.cuda.empty_cache()
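+
+    # Usage sketch: a session worker wraps an entire session in a reservation
+    # so that every model lock inside it lands on the same GPU:
+    #
+    #     with cache.reserve_execution_device() as device:
+    #         run_session_on(device)  # hypothetical helper
+    #
+    # A nested call from the same thread returns the already-reserved device
+    # rather than consuming a second semaphore slot.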
@property
def max_cache_size(self) -> float:
@@ -157,16 +211,16 @@ def put(
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""
- key = self._make_cache_key(key, submodel_type)
- if key in self._cached_models:
- return
- size = calc_model_size_by_data(model)
- self.make_room(size)
+ with self._ram_lock:
+ key = self._make_cache_key(key, submodel_type)
+ if key in self._cached_models:
+ return
+ size = calc_model_size_by_data(model)
+ self.make_room(size)
- state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
- cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
- self._cached_models[key] = cache_record
- self._cache_stack.append(key)
+ cache_record = CacheRecord(key=key, model=model, size=size)
+ self._cached_models[key] = cache_record
+ self._cache_stack.append(key)
def get(
self,
@@ -184,35 +238,36 @@ def get(
This may raise an IndexError if the model is not in the cache.
"""
- key = self._make_cache_key(key, submodel_type)
- if key in self._cached_models:
- if self.stats:
- self.stats.hits += 1
- else:
+ with self._ram_lock:
+ key = self._make_cache_key(key, submodel_type)
+ if key in self._cached_models:
+ if self.stats:
+ self.stats.hits += 1
+ else:
+ if self.stats:
+ self.stats.misses += 1
+ raise IndexError(f"The model with key {key} is not in the cache.")
+
+ cache_entry = self._cached_models[key]
+
+ # more stats
if self.stats:
- self.stats.misses += 1
- raise IndexError(f"The model with key {key} is not in the cache.")
-
- cache_entry = self._cached_models[key]
-
- # more stats
- if self.stats:
- stats_name = stats_name or key
- self.stats.cache_size = int(self._max_cache_size * GIG)
- self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
- self.stats.in_cache = len(self._cached_models)
- self.stats.loaded_model_sizes[stats_name] = max(
- self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
- )
+ stats_name = stats_name or key
+ self.stats.cache_size = int(self._max_cache_size * GIG)
+ self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
+ self.stats.in_cache = len(self._cached_models)
+ self.stats.loaded_model_sizes[stats_name] = max(
+ self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
+ )
- # this moves the entry to the top (right end) of the stack
- with suppress(Exception):
- self._cache_stack.remove(key)
- self._cache_stack.append(key)
- return ModelLocker(
- cache=self,
- cache_entry=cache_entry,
- )
+ # this moves the entry to the top (right end) of the stack
+ with suppress(Exception):
+ self._cache_stack.remove(key)
+ self._cache_stack.append(key)
+ return ModelLocker(
+ cache=self,
+ cache_entry=cache_entry,
+ )
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:
@@ -225,127 +280,34 @@ def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType]
else:
return model_key
- def offload_unlocked_models(self, size_required: int) -> None:
- """Move any unused models from VRAM."""
- reserved = self._max_vram_cache_size * GIG
- vram_in_use = torch.cuda.memory_allocated() + size_required
- self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
- for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
- if vram_in_use <= reserved:
- break
- if not cache_entry.loaded:
- continue
- if not cache_entry.locked:
- self.move_model_to_device(cache_entry, self.storage_device)
- cache_entry.loaded = False
- vram_in_use = torch.cuda.memory_allocated() + size_required
- self.logger.debug(
- f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
- )
-
- TorchDevice.empty_cache()
-
- def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
- """Move model into the indicated device.
+ def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
+ """Move a copy of the model into the indicated device and return it.
:param cache_entry: The CacheRecord for the model
:param target_device: The torch.device to move the model into
May raise a torch.cuda.OutOfMemoryError
"""
- self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
- source_device = cache_entry.device
-
- # Note: We compare device types only so that 'cuda' == 'cuda:0'.
- # This would need to be revised to support multi-GPU.
- if torch.device(source_device).type == torch.device(target_device).type:
- return
-
- # Some models don't have a `to` method, in which case they run in RAM/CPU.
- if not hasattr(cache_entry.model, "to"):
- return
-
- # This roundabout method for moving the model around is done to avoid
- # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
- # When moving to VRAM, we copy (not move) each element of the state dict from
- # RAM to a new state dict in VRAM, and then inject it into the model.
- # This operation is slightly faster than running `to()` on the whole model.
- #
- # When the model needs to be removed from VRAM we simply delete the copy
- # of the state dict in VRAM, and reinject the state dict that is cached
- # in RAM into the model. So this operation is very fast.
- start_model_to_time = time.time()
- snapshot_before = self._capture_memory_snapshot()
-
- try:
- if cache_entry.state_dict is not None:
- assert hasattr(cache_entry.model, "load_state_dict")
- if target_device == self.storage_device:
- cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
- else:
- new_dict: Dict[str, torch.Tensor] = {}
- for k, v in cache_entry.state_dict.items():
- new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
- cache_entry.model.load_state_dict(new_dict, assign=True)
- cache_entry.model.to(target_device, non_blocking=True)
- cache_entry.device = target_device
- except Exception as e: # blow away cache entry
- self._delete_cache_entry(cache_entry)
- raise e
-
- snapshot_after = self._capture_memory_snapshot()
- end_model_to_time = time.time()
- self.logger.debug(
- f"Moved model '{cache_entry.key}' from {source_device} to"
- f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
- f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
- f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
- )
-
- if (
- snapshot_before is not None
- and snapshot_after is not None
- and snapshot_before.vram is not None
- and snapshot_after.vram is not None
- ):
- vram_change = abs(snapshot_before.vram - snapshot_after.vram)
-
- # If the estimated model size does not match the change in VRAM, log a warning.
- if not math.isclose(
- vram_change,
- cache_entry.size,
- rel_tol=0.1,
- abs_tol=10 * MB,
- ):
- self.logger.debug(
- f"Moving model '{cache_entry.key}' from {source_device} to"
- f" {target_device} caused an unexpected change in VRAM usage. The model's"
- " estimated size may be incorrect. Estimated model size:"
- f" {(cache_entry.size/GIG):.3f} GB.\n"
- f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
- )
+ with self._ram_lock:
+ self.logger.debug(f"Called to move {cache_entry.key} ({type(cache_entry.model)=}) to {target_device}")
+
+            # Some models don't have a `to` method, in which case they are
+            # executed in place from the storage device (CPU).
+ if hasattr(cache_entry.model, "to"):
+ model_in_gpu = copy.deepcopy(cache_entry.model)
+ assert hasattr(model_in_gpu, "to")
+ model_in_gpu.to(target_device)
+ return model_in_gpu
+ else:
+ return cache_entry.model # what happens in CPU stays in CPU
def print_cuda_stats(self) -> None:
"""Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
ram = "%4.2fG" % (self.cache_size() / GIG)
- in_ram_models = 0
- in_vram_models = 0
- locked_in_vram_models = 0
- for cache_record in self._cached_models.values():
- if hasattr(cache_record.model, "device"):
- if cache_record.model.device == self.storage_device:
- in_ram_models += 1
- else:
- in_vram_models += 1
- if cache_record.locked:
- locked_in_vram_models += 1
-
- self.logger.debug(
- f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
- f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
- )
+ in_ram_models = len(self._cached_models)
+ self.logger.debug(f"Current VRAM/RAM usage for {in_ram_models} models: {vram}/{ram}")
def make_room(self, size: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size."""
@@ -368,12 +330,14 @@ def make_room(self, size: int) -> None:
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key]
- device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
- self.logger.debug(
- f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
- )
- if not cache_entry.locked:
+ refs = sys.getrefcount(cache_entry.model)
+
+ # Expected refs:
+ # 1 from cache_entry
+ # 1 from getrefcount function
+ # 1 from onnx runtime object
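+            # (sys.getrefcount is CPython-specific: a model still referenced by
+            # a worker thread carries extra refs and is skipped here)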
+ if refs <= (3 if "onnx" in model_key else 2):
self.logger.debug(
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)
@@ -400,10 +364,26 @@ def make_room(self, size: int) -> None:
if self.stats:
self.stats.cleared = models_cleared
gc.collect()
-
TorchDevice.empty_cache()
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
+ def _check_free_vram(self, target_device: torch.device, needed_size: int) -> None:
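+        """Raise torch.cuda.OutOfMemoryError if target_device has fewer than needed_size bytes of free VRAM."""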
+ if target_device.type != "cuda":
+ return
+ vram_device = ( # mem_get_info() needs an indexed device
+ target_device if target_device.index is not None else torch.device(str(target_device), index=0)
+ )
+ free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
+        if needed_size > free_mem:
+            raise torch.cuda.OutOfMemoryError(
+                f"Insufficient VRAM on {vram_device}: need {(needed_size/GIG):.2f} GB, only {(free_mem/GIG):.2f} GB free"
+            )
+
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
- self._cache_stack.remove(cache_entry.key)
- del self._cached_models[cache_entry.key]
+ try:
+ self._cache_stack.remove(cache_entry.key)
+ del self._cached_models[cache_entry.key]
+        except (ValueError, KeyError):  # entry was already evicted
+ pass
+
+ @staticmethod
+ def _device_name(device: torch.device) -> str:
+ return f"{device.type}:{device.index}"
diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py
index 9de17ca5f53..68af7ba97a9 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_locker.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py
@@ -10,6 +10,8 @@
from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase
+MAX_GPU_WAIT = 600 # wait up to 10 minutes for a GPU to become free
+
class ModelLocker(ModelLockerBase):
"""Internal class that mediates movement in and out of GPU."""
@@ -29,33 +31,29 @@ def model(self) -> AnyModel:
"""Return the model without moving it around."""
return self._cache_entry.model
- def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
- """Return the state dict (if any) for the cached model."""
- return self._cache_entry.state_dict
-
def lock(self) -> AnyModel:
"""Move the model into the execution device (GPU) and lock it."""
- self._cache_entry.lock()
try:
- if self._cache.lazy_offloading:
- self._cache.offload_unlocked_models(self._cache_entry.size)
- self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
- self._cache_entry.loaded = True
- self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
+ device = self._cache.get_execution_device()
+ model_on_device = self._cache.model_to_device(self._cache_entry, device)
+ self._cache.logger.debug(f"Moved {self._cache_entry.key} to {device}")
self._cache.print_cuda_stats()
except torch.cuda.OutOfMemoryError:
self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
- self._cache_entry.unlock()
raise
-        except Exception:
-            self._cache_entry.unlock()
-            raise
- return self.model
+ return model_on_device
+ # It is no longer necessary to move the model out of VRAM
+ # because it will be removed when it goes out of scope
+ # in the caller's context
def unlock(self) -> None:
"""Call upon exit from context."""
- self._cache_entry.unlock()
- if not self._cache.lazy_offloading:
- self._cache.offload_unlocked_models(0)
- self._cache.print_cuda_stats()
+ self._cache.print_cuda_stats()
+
+    # No longer used in the multi-GPU implementation; retained for interface compatibility.
+ def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
+ """Return the state dict (if any) for the cached model."""
+ return None
diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py
index fdc79539ae7..b879c3d4e80 100644
--- a/invokeai/backend/model_patcher.py
+++ b/invokeai/backend/model_patcher.py
@@ -4,6 +4,7 @@
from __future__ import annotations
import pickle
+import threading
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
@@ -34,6 +35,8 @@
# TODO: rename smth like ModelPatcher and add TI method?
class ModelPatcher:
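+ # Class-level lock: apply_lora() patches shared model weights in place, so
+ # patching must be serialized when several execution-device threads share a model.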
+ _thread_lock = threading.Lock()
+
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
@@ -106,7 +109,7 @@ def apply_lora(
"""
original_weights = {}
try:
- with torch.no_grad():
+ with torch.no_grad(), cls._thread_lock:
for lora, lora_weight in loras:
# assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
@@ -129,9 +132,7 @@ def apply_lora(
dtype = module.weight.dtype
if module_key not in original_weights:
- if model_state_dict is not None: # we were provided with the CPU copy of the state dict
- original_weights[module_key] = model_state_dict[module_key + ".weight"]
- else:
+ if model_state_dict is None: # no CPU copy of the state dict was provided
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
index 85950a01df5..01aae6b5a49 100644
--- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -32,8 +32,11 @@ class SDXLConditioningInfo(BasicConditioningInfo):
def to(self, device, dtype=None):
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
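+ # With several execution devices active, verify that the tensors really landed
+ # on the requested device (`device` is expected to be an indexed torch.device).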
+ assert self.pooled_embeds.device == device
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
- return super().to(device=device, dtype=dtype)
+ result = super().to(device=device, dtype=dtype)
+ assert self.embeds.device == device
+ return result
@dataclass
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index f418133e49f..a8f47247eca 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import math
+import threading
from typing import Any, Callable, Optional, Union
import torch
@@ -293,24 +294,31 @@ def _apply_standard_conditioning(
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
added_cond_kwargs = None
- if conditioning_data.is_sdxl():
- added_cond_kwargs = {
- "text_embeds": torch.cat(
- [
- # TODO: how to pad? just by zeros? or even truncate?
- conditioning_data.uncond_text.pooled_embeds,
- conditioning_data.cond_text.pooled_embeds,
- ],
- dim=0,
- ),
- "time_ids": torch.cat(
- [
- conditioning_data.uncond_text.add_time_ids,
- conditioning_data.cond_text.add_time_ids,
- ],
- dim=0,
- ),
- }
+ try:
+ if conditioning_data.is_sdxl():
+ added_cond_kwargs = {
+ "text_embeds": torch.cat(
+ [
+ # TODO: how to pad? just by zeros? or even truncate?
+ conditioning_data.uncond_text.pooled_embeds,
+ conditioning_data.cond_text.pooled_embeds,
+ ],
+ dim=0,
+ ),
+ "time_ids": torch.cat(
+ [
+ conditioning_data.uncond_text.add_time_ids,
+ conditioning_data.cond_text.add_time_ids,
+ ],
+ dim=0,
+ ),
+ }
+ except Exception as e:
+ # Report the failing thread; a device-mismatch error here usually means the
+ # conditioning tensors were not moved to this thread's execution device.
+ tid = threading.current_thread().ident
+ print(f"DEBUG: {tid} {e}", flush=True)
+ raise
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py
index e8380dc8bcd..dc2bafaa9c4 100644
--- a/invokeai/backend/util/devices.py
+++ b/invokeai/backend/util/devices.py
@@ -1,10 +1,16 @@
-from typing import Dict, Literal, Optional, Union
+"""Torch Device class provides torch device selection services."""
+
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Set, Union
import torch
from deprecated import deprecated
from invokeai.app.services.config.config_default import get_config
+if TYPE_CHECKING:
+ from invokeai.backend.model_manager.config import AnyModel
+ from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
+
# legacy APIs
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
CPU_DEVICE = torch.device("cpu")
@@ -42,9 +48,23 @@ def torch_dtype(device: torch.device) -> torch.dtype:
class TorchDevice:
"""Abstraction layer for torch devices."""
+ _model_cache: Optional["ModelCacheBase[AnyModel]"] = None
+
+ @classmethod
+ def set_model_cache(cls, cache: Optional["ModelCacheBase[AnyModel]"]) -> None:
+ """Set the current model cache, or pass None to disable cache-driven device selection."""
+ cls._model_cache = cache
+
@classmethod
def choose_torch_device(cls) -> torch.device:
"""Return the torch.device to use for accelerated inference."""
+ if cls._model_cache:
+ return cls._model_cache.get_execution_device()
+ else:
+ return cls._choose_device()
+
+ @classmethod
+ def _choose_device(cls) -> torch.device:
app_config = get_config()
if app_config.device != "auto":
device = torch.device(app_config.device)
@@ -56,11 +76,19 @@ def choose_torch_device(cls) -> torch.device:
device = CPU_DEVICE
return cls.normalize(device)
+ @classmethod
+ def execution_devices(cls) -> Set[torch.device]:
+ """Return a list of torch.devices that can be used for accelerated inference."""
+ app_config = get_config()
+ if app_config.devices is None:
+ return cls._lookup_execution_devices()
+ return {torch.device(x) for x in app_config.devices}
+
@classmethod
def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
"""Return the precision to use for accelerated inference."""
- device = device or cls.choose_torch_device()
config = get_config()
+ device = device or cls._choose_device()
if device.type == "cuda" and torch.cuda.is_available():
device_name = torch.cuda.get_device_name(device)
if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
@@ -108,3 +136,13 @@ def empty_cache(cls) -> None:
@classmethod
def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
return NAME_TO_PRECISION[precision_name]
+
+ @classmethod
+ def _lookup_execution_devices(cls) -> Set[torch.device]:
+ if torch.cuda.is_available():
+ devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
+ elif torch.backends.mps.is_available():
+ devices = {torch.device("mps")}
+ else:
+ devices = {torch.device("cpu")}
+ return devices
diff --git a/scripts/populate_model_db_from_yaml.py b/scripts/populate_model_db_from_yaml.py
new file mode 100755
index 00000000000..80e5bcfc5c6
--- /dev/null
+++ b/scripts/populate_model_db_from_yaml.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python
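+"""Populate the models database from a legacy models.yaml file.
+
+Example invocation (paths are illustrative):
+
+    python scripts/populate_model_db_from_yaml.py --root ~/invokeai --yaml_file ~/invokeai/configs/models.yaml
+"""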
+
+from argparse import ArgumentParser, Namespace
+from pathlib import Path
+
+from invokeai.app.services.config import InvokeAIAppConfig, get_config
+from invokeai.app.services.download import DownloadQueueService
+from invokeai.app.services.model_install import ModelInstallService
+from invokeai.app.services.model_records import ModelRecordServiceSQL
+from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
+from invokeai.backend.util.logging import InvokeAILogger
+
+
+def get_args() -> Namespace:
+ parser = ArgumentParser(description="Update the models database from a legacy models.yaml file")
+ parser.add_argument("--root", type=Path, required=False, default=None)
+ parser.add_argument("--yaml_file", type=Path, required=False, default=None)
+ return parser.parse_args()
+
+
+def populate_config() -> InvokeAIAppConfig:
+ args = get_args()
+ config = get_config()
+ if args.root:
+ config._root = args.root
+ if args.yaml_file:
+ config.legacy_models_yaml_path = args.yaml_file
+ else:
+ config.legacy_models_yaml_path = config.root_path / "configs/models.yaml"
+ return config
+
+
+def initialize_installer(config: InvokeAIAppConfig) -> ModelInstallService:
+ logger = InvokeAILogger.get_logger(config=config)
+ db = SqliteDatabase(config.db_path, logger)
+ record_store = ModelRecordServiceSQL(db)
+ queue = DownloadQueueService()
+ queue.start()
+ installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=queue)
+ return installer
+
+
+def main() -> None:
+ config = populate_config()
+ installer = initialize_installer(config)
+ installer._migrate_yaml(rename_yaml=False, overwrite_db=True)
+ print("\n")
+ print("\t".join(["key", "name", "type", "path"]))
+ for model in installer.record_store.all_models():
+ print("\t".join([model.key, model.name, model.type, (config.models_path / model.path).as_posix()]))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/backend/model_manager/model_loading/test_model_load.py b/tests/backend/model_manager/model_loading/test_model_load.py
index 3f12f7f8ee9..ff80d49397f 100644
--- a/tests/backend/model_manager/model_loading/test_model_load.py
+++ b/tests/backend/model_manager/model_loading/test_model_load.py
@@ -14,13 +14,14 @@ def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Pat
matches = store.search_by_attr(model_name="test_embedding")
assert len(matches) == 0
key = mm2_model_manager.install.register_path(embedding_file)
- loaded_model = mm2_model_manager.load.load_model(store.get_model(key))
- assert loaded_model is not None
- assert loaded_model.config.key == key
- with loaded_model as model:
- assert isinstance(model, TextualInversionModelRaw)
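+ # In the MGPU design a caller reserves an execution device for its thread
+ # before loading; reserve_execution_device() is the context manager for that.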
+ with mm2_model_manager.load.ram_cache.reserve_execution_device():
+ loaded_model = mm2_model_manager.load.load_model(store.get_model(key))
+ assert loaded_model is not None
+ assert loaded_model.config.key == key
+ with loaded_model as model:
+ assert isinstance(model, TextualInversionModelRaw)
- config = mm2_model_manager.store.get_model(key)
- loaded_model_2 = mm2_model_manager.load.load_model(config)
+ config = mm2_model_manager.store.get_model(key)
+ loaded_model_2 = mm2_model_manager.load.load_model(config)
- assert loaded_model.config.key == loaded_model_2.config.key
+ assert loaded_model.config.key == loaded_model_2.config.key
diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py
index f82239298e1..e7e592d9b71 100644
--- a/tests/backend/model_manager/model_manager_fixtures.py
+++ b/tests/backend/model_manager/model_manager_fixtures.py
@@ -89,11 +89,10 @@ def mm2_download_queue(mm2_session: Session) -> DownloadQueueServiceBase:
@pytest.fixture
-def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
+def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase:
ram_cache = ModelCache(
logger=InvokeAILogger.get_logger(),
max_cache_size=mm2_app_config.ram,
- max_vram_cache_size=mm2_app_config.vram,
)
convert_cache = ModelConvertCache(mm2_app_config.convert_cache_path)
return ModelLoadService(
diff --git a/tests/backend/util/test_devices.py b/tests/backend/util/test_devices.py
index 8e810e43678..d854a82e622 100644
--- a/tests/backend/util/test_devices.py
+++ b/tests/backend/util/test_devices.py
@@ -8,7 +8,9 @@
import torch
from invokeai.app.services.config import get_config
+from invokeai.backend.model_manager.load import ModelCache
from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype
+from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
devices = ["cpu", "cuda:0", "cuda:1", "mps"]
device_types_cpu = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float32)]
@@ -20,6 +22,7 @@
def test_device_choice(device_name):
config = get_config()
config.device = device_name
+ TorchDevice.set_model_cache(None) # disable dynamic selection of GPU device
torch_device = TorchDevice.choose_torch_device()
assert torch_device == torch.device(device_name)
@@ -130,3 +133,32 @@ def test_legacy_precision_name():
assert "float16" == choose_precision(torch.device("cuda"))
assert "float16" == choose_precision(torch.device("mps"))
assert "float32" == choose_precision(torch.device("cpu"))
+
+
+def test_multi_device_support_1():
+ config = get_config()
+ config.devices = ["cuda:0", "cuda:1"]
+ assert TorchDevice.execution_devices() == {torch.device("cuda:0"), torch.device("cuda:1")}
+
+
+def test_multi_device_support_2():
+ config = get_config()
+ config.devices = None
+ with (
+ patch("torch.cuda.device_count", return_value=3),
+ patch("torch.cuda.is_available", return_value=True),
+ ):
+ assert TorchDevice.execution_devices() == {
+ torch.device("cuda:0"),
+ torch.device("cuda:1"),
+ torch.device("cuda:2"),
+ }
+
+
+def test_multi_device_support_3():
+ config = get_config()
+ config.devices = ["cuda:0", "cuda:1"]
+ cache = ModelCache()
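+ # The reservation should hand this thread one of the configured GPUs, and
+ # TorchDevice.choose_torch_device() should agree while it is held.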
+ with cache.reserve_execution_device() as gpu:
+ assert gpu in [torch.device(x) for x in config.devices]
+ assert TorchDevice.choose_torch_device() == gpu
diff --git a/tests/conftest.py b/tests/conftest.py
index 8a67e9473c8..e140bcd7df4 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -17,7 +17,6 @@
from invokeai.app.services.images.images_default import ImageService
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_services import InvocationServices
-from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager.model_manager_fixtures import * # noqa: F403
@@ -49,13 +48,13 @@ def mock_services() -> InvocationServices:
model_manager=None, # type: ignore
download_queue=None, # type: ignore
names=None, # type: ignore
- performance_statistics=InvocationStatsService(),
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
tensors=None, # type: ignore
conditioning=None, # type: ignore
+ performance_statistics=None, # type: ignore
)
diff --git a/tests/test_config.py b/tests/test_config.py
index a6ea2a34806..6606bc0f3e0 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -92,7 +92,6 @@ def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None):
assert config.host == "192.168.1.1"
assert config.port == 8080
assert config.ram == 100
- assert config.vram == 50
assert config.legacy_models_yaml_path == Path("/custom/models.yaml")
# This should be stripped out
assert not hasattr(config, "esrgan")