From cd2dccf6f294d070f750788b3e5ca296706b66e1 Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Thu, 27 Jun 2024 16:18:21 +0300
Subject: [PATCH 01/25] Redo attention processor to support other attention
 types

---
 .../stable_diffusion/diffusers_pipeline.py    |   4 +-
 .../diffusion/custom_atttention.py            | 416 +++++++++++++-----
 .../diffusion/unet_attention_patcher.py       |  43 +-
 3 files changed, 347 insertions(+), 116 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index ee464f73e1f..fe5310c216d 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -330,8 +330,6 @@ def latents_from_embeddings(
             # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
             latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
 
-        self._adjust_memory_efficient_attention(latents)
-
         # Handle mask guidance (a.k.a. inpainting).
         mask_guidance: AddsMaskGuidance | None = None
         if mask is not None and not is_inpainting_model(self.unet):
@@ -371,6 +369,8 @@ def latents_from_embeddings(
             )
             unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
             attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
+        else:
+            self._adjust_memory_efficient_attention(latents)
 
         with attn_ctx:
             callback(
diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
index 1334313fe6e..8a2c62354d0 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
@@ -1,14 +1,20 @@
 from dataclasses import dataclass
-from typing import List, Optional, cast
+from typing import List, Optional, Union, Callable, cast
 
 import torch
 import torch.nn.functional as F
 from diffusers.models.attention_processor import Attention, AttnProcessor2_0
+from diffusers.utils.import_utils import is_xformers_available
 
 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
 from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
 from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
 
+if is_xformers_available():
+    import xformers
+    import xformers.ops
+else:
+    xformers = None
 
 @dataclass
 class IPAdapterAttentionWeights:
@@ -16,10 +22,13 @@ class IPAdapterAttentionWeights:
     skip: bool
 
 
-class CustomAttnProcessor2_0(AttnProcessor2_0):
-    """A custom implementation of AttnProcessor2_0 that supports additional Invoke features.
+class CustomAttnProcessor:
+    """A custom implementation of attention processor that supports additional Invoke features.
     This implementation is based on
-    https://github.com/huggingface/diffusers/blame/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204
+    AttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L732)
+    SlicedAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1616)
+    XFormersAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1113)
+    AttnProcessor2_0 (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204)
     Supported custom features:
     - IP-Adapter
     - Regional prompt attention
@@ -27,17 +36,48 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
 
     def __init__(
         self,
+        attention_type: str,
         ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None,
+        # xformers
+        attention_op: Optional[Callable] = None,
+        # sliced
+        slice_size: Optional[Union[str, int]] = None, # TODO: or "auto"?
+
     ):
-        """Initialize a CustomAttnProcessor2_0.
+        """Initialize a CustomAttnProcessor.
         Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
         layer-specific are passed to __init__().
         Args:
             ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights
                 for the i'th IP-Adapter.
+            attention_op (`Callable`, *optional*, defaults to `None`):
+                The base
+                [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+                use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+                operator.
+            slice_size (`int`, *optional*):
+                The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
+                `attention_head_dim` must be a multiple of the `slice_size`.
         """
-        super().__init__()
+        if attention_type not in ["normal", "sliced", "xformers", "torch-sdp"]:
+            raise Exception(f"Unknown attention type: {attention_type}")
+
+        if attention_type == "xformers" and xformers is None:
+            raise ImportError("xformers attention requires xformers module to be installed.")
+
+        if attention_type == "torch-sdp" and not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("torch-sdp attention requires PyTorch 2.0; to use it, please upgrade PyTorch to 2.0.")
+
+        if attention_type == "sliced":
+            if slice_size is None:
+                raise Exception(f"slice_size required for sliced attention")
+            if slice_size not in ["auto", "max"] and not isinstance(slice_size, int):
+                raise Exception(f"Unsupported slice_size: {slice_size}")
+
         self._ip_adapter_attention_weights = ip_adapter_attention_weights
+        self.attention_type = attention_type
+        self.attention_op = attention_op
+        self.slice_size = slice_size
 
     def __call__(
         self,
@@ -53,19 +93,12 @@ def __call__(
         regional_ip_data: Optional[RegionalIPData] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
-        """Apply attention.
-        Args:
-            regional_prompt_data: The regional prompt data for the current batch. If not None, this will be used to
-                apply regional prompt masking.
-            regional_ip_data: The IP-Adapter data for the current batch.
-        """
+    ) -> torch.Tensor:
         # If true, we are doing cross-attention, if false we are doing self-attention.
         is_cross_attention = encoder_hidden_states is not None
 
-        # Start unmodified block from AttnProcessor2_0.
-        # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
         residual = hidden_states
+
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
 
@@ -75,18 +108,134 @@ def __call__(
             batch_size, channel, height, width = hidden_states.shape
             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
 
-        batch_size, sequence_length, _ = (
+        batch_size, key_length, _ = (
             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         )
-        # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-        # End unmodified block from AttnProcessor2_0.
+        query_length = hidden_states.shape[1]
+
+        attention_mask = self.prepare_attention_mask(
+            attn=attn,
+            attention_mask=attention_mask,
+            batch_size=batch_size,
+            key_length=key_length,
+            query_length=query_length,
+            is_cross_attention=is_cross_attention,
+            regional_prompt_data=regional_prompt_data,
+        )
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        hidden_states = self.run_attention(
+            attn=attn,
+            query=query,
+            key=key,
+            value=value,
+            attention_mask=attention_mask,
+        )
+
+        if is_cross_attention:
+            hidden_states = self.run_ip_adapters(
+                attn=attn,
+                hidden_states=hidden_states,
+                regional_ip_data=regional_ip_data,
+                query_length=query_length,
+                query=query,
+            )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+    def run_ip_adapters(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        regional_ip_data: Optional[RegionalIPData],
+        query_length: int, # TODO: just read from query?
+        query: torch.Tensor,
+    ) -> torch.Tensor:
+        if self._ip_adapter_attention_weights is None:
+            # If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
+            assert regional_ip_data is None
+            return hidden_states
+
+        ip_masks = regional_ip_data.get_masks(query_seq_len=query_length)
+
+        assert (
+            len(regional_ip_data.image_prompt_embeds)
+            == len(self._ip_adapter_attention_weights)
+            == len(regional_ip_data.scales)
+            == ip_masks.shape[1]
+        )
+
+        for ipa_index, ip_hidden_states in enumerate(regional_ip_data.image_prompt_embeds):
+            # The batch dimensions should match.
+            #assert ip_hidden_states.shape[0] == encoder_hidden_states.shape[0]
+            # The token_len dimensions should match.
+            #assert ip_hidden_states.shape[-1] == encoder_hidden_states.shape[-1]
+
+            if self._ip_adapter_attention_weights[ipa_index].skip:
+                continue
+
+            ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights
+            ipa_scale = regional_ip_data.scales[ipa_index]
+            ip_mask = ip_masks[0, ipa_index, ...]
+
+            # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
+            ip_key = ipa_weights.to_k_ip(ip_hidden_states)
+            ip_value = ipa_weights.to_v_ip(ip_hidden_states)
+
+            ip_hidden_states = self.run_attention(
+                attn=attn,
+                query=query,
+                key=ip_key,
+                value=ip_value,
+                attention_mask=None,
+            )
+
+            # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
+            hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
+
+        return hidden_states
+
+
+    def prepare_attention_mask(
+        self,
+        attn: Attention,
+        attention_mask: Optional[torch.Tensor],
+        batch_size: int,
+        key_length: int,
+        query_length: int,
+        is_cross_attention: bool,
+        regional_prompt_data: Optional[RegionalPromptData],
+    ) -> Optional[torch.Tensor]:
 
-        _, query_seq_len, _ = hidden_states.shape
-        # Handle regional prompt attention masks.
         if regional_prompt_data is not None and is_cross_attention:
-            assert percent_through is not None
             prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
-                query_seq_len=query_seq_len, key_seq_len=sequence_length
+                query_seq_len=query_length, key_seq_len=key_length
             )
 
             if attention_mask is None:
@@ -94,32 +243,112 @@ def __call__(
             else:
                 attention_mask = prompt_region_attention_mask + attention_mask
 
-        # Start unmodified block from AttnProcessor2_0.
-        # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
 
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+        attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size)
 
-        query = attn.to_q(hidden_states)
+        if self.attention_type in ["normal", "sliced"]:
+            pass
 
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+        elif self.attention_type == "xformers":
+            if attention_mask is not None:
+                # expand our mask's singleton query_length dimension:
+                #   [batch*heads,            1, key_length] ->
+                #   [batch*heads, query_length, key_length]
+                # so that it can be added as a bias onto the attention scores that xformers computes:
+                #   [batch*heads, query_length, key_length]
+                # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
+                attention_mask = attention_mask.expand(-1, query_length, -1)
 
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
+        elif self.attention_type == "torch-sdp":
+            if attention_mask is not None:
+                # scaled_dot_product_attention expects attention_mask shape to be
+                # (batch, heads, source_length, target_length)
+                attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        else:
+            raise Exception(f"Unknown attention type: {self.attention_type}")
+
+        return attention_mask
 
+    def run_attention(
+        self,
+        attn: Attention,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        if self.attention_type == "normal":
+            attn_call = self.run_attention_normal
+        elif self.attention_type == "xformers":
+            attn_call = self.run_attention_xformers
+        elif self.attention_type == "torch-sdp":
+            attn_call = self.run_attention_sdp
+        elif self.attention_type == "sliced":
+            attn_call = self.run_attention_sliced
+        else:
+            raise Exception(f"Unknown attention type: {self.attention_type}")
+
+        return attn_call(
+            attn=attn,
+            query=query,
+            key=key,
+            value=value,
+            attention_mask=attention_mask,
+        )
+
+    def run_attention_normal(
+        self,
+        attn: Attention,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        query = attn.head_to_batch_dim(query)
+        key   = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        return hidden_states
+
+    def run_attention_xformers(
+        self,
+        attn: Attention,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # attention_op
+        query = attn.head_to_batch_dim(query).contiguous()
+        key   = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        return hidden_states
+
+    def run_attention_sdp(
+        self,
+        attn: Attention,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        batch_size = key.shape[0]
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
 
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
@@ -131,84 +360,51 @@ def __call__(
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
-        # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-        # End unmodified block from AttnProcessor2_0.
-
-        # Apply IP-Adapter conditioning.
-        if is_cross_attention:
-            if self._ip_adapter_attention_weights:
-                assert regional_ip_data is not None
-                ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
-
-                assert (
-                    len(regional_ip_data.image_prompt_embeds)
-                    == len(self._ip_adapter_attention_weights)
-                    == len(regional_ip_data.scales)
-                    == ip_masks.shape[1]
-                )
-
-                for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
-                    ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights
-                    ipa_scale = regional_ip_data.scales[ipa_index]
-                    ip_mask = ip_masks[0, ipa_index, ...]
-
-                    # The batch dimensions should match.
-                    assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
-                    # The token_len dimensions should match.
-                    assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
 
-                    ip_hidden_states = ipa_embed
+        return hidden_states
 
-                    # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
-
-                    if not self._ip_adapter_attention_weights[ipa_index].skip:
-                        ip_key = ipa_weights.to_k_ip(ip_hidden_states)
-                        ip_value = ipa_weights.to_v_ip(ip_hidden_states)
-
-                        # Expected ip_key and ip_value shape:
-                        # (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
-
-                        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-                        # Expected ip_key and ip_value shape:
-                        # (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
-
-                        # TODO: add support for attn.scale when we move to Torch 2.1
-                        ip_hidden_states = F.scaled_dot_product_attention(
-                            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                        )
+    def run_attention_sliced(
+        self,
+        attn: Attention,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # slice_size
+        if self.slice_size == "max":
+            slice_size = 1
+        elif self.slice_size == "auto":
+            slice_size = max(1, attn.sliceable_head_dim // 2)
+        else:
+            slice_size = min(self.slice_size, attn.sliceable_head_dim)
+
+        dim = query.shape[-1]
+
+        query = attn.head_to_batch_dim(query)
+        key   = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        batch_size_attention, query_tokens, _ = query.shape
+        hidden_states = torch.zeros(
+            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
+        )
 
-                        # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
-                        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
-                            batch_size, -1, attn.heads * head_dim
-                        )
+        for i in range(batch_size_attention // slice_size):
+            start_idx = i * slice_size
+            end_idx = (i + 1) * slice_size
 
-                        ip_hidden_states = ip_hidden_states.to(query.dtype)
+            query_slice = query[start_idx:end_idx]
+            key_slice = key[start_idx:end_idx]
+            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
 
-                        # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
-                        hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
-            else:
-                # If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
-                assert regional_ip_data is None
+            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
 
-        # Start unmodified block from AttnProcessor2_0.
-        # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
+            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
 
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+            hidden_states[start_idx:end_idx] = attn_slice
 
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
+        hidden_states = attn.batch_to_head_dim(hidden_states)
 
-        hidden_states = hidden_states / attn.rescale_output_factor
-        # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-        # End of unmodified block from AttnProcessor2_0
+        return hidden_states
 
-        # casting torch.Tensor to torch.FloatTensor to avoid type issues
-        return cast(torch.FloatTensor, hidden_states)
diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
index ac00a8e06ea..c2f79607c02 100644
--- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
+++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
@@ -2,10 +2,12 @@
 from typing import List, Optional, TypedDict
 
 from diffusers.models import UNet2DConditionModel
+from diffusers.utils.import_utils import is_xformers_available
 
+from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
-    CustomAttnProcessor2_0,
+    CustomAttnProcessor,
     IPAdapterAttentionWeights,
 )
 
@@ -16,22 +18,52 @@ class UNetIPAdapterData(TypedDict):
 
 
 class UNetAttentionPatcher:
-    """A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
+    """A class for patching a UNet with CustomAttnProcessor attention layers."""
 
     def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]):
         self._ip_adapters = ip_adapter_data
 
+    def get_attention_processor_kwargs(self, unet: UNet2DConditionModel):
+        config = get_config()
+        kwargs = dict()
+        
+        # TODO:
+        attention_type = config.attention_type
+        if attention_type == "auto":
+            if self.unet.device.type == "cuda":
+                if is_xformers_available():
+                    attention_type = "xformers"
+                elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
+                    attention_type = "torch-sdp"
+                else:
+                    attention_type = "normal"
+            else:
+                attention_type = "sliced"
+
+        kwargs["attention_type"] = attention_type
+        
+        if attention_type == "sliced":
+            slice_size = config.attention_slice_size
+            if slice_size == "balanced":
+                slice_size = "auto"
+            kwargs["slice_size"] = slice_size
+
+        return kwargs
+
     def _prepare_attention_processors(self, unet: UNet2DConditionModel):
         """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
         weights into them (if IP-Adapters are being applied).
         Note that the `unet` param is only used to determine attention block dimensions and naming.
         """
         # Construct a dict of attention processors based on the UNet's architecture.
+
+        attn_processor_kwargs = self.get_attention_processor_kwargs(unet)
+
         attn_procs = {}
         for idx, name in enumerate(unet.attn_processors.keys()):
             if name.endswith("attn1.processor") or self._ip_adapters is None:
                 # "attn1" processors do not use IP-Adapters.
-                attn_procs[name] = CustomAttnProcessor2_0()
+                attn_procs[name] = CustomAttnProcessor(**attn_processor_kwargs)
             else:
                 # Collect the weights from each IP Adapter for the idx'th attention processor.
                 ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = []
@@ -48,7 +80,10 @@ def _prepare_attention_processors(self, unet: UNet2DConditionModel):
                     )
                     ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights)
 
-                attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights_collection)
+                attn_procs[name] = CustomAttnProcessor(
+                    ip_adapter_attention_weights=ip_adapter_attention_weights_collection,
+                    **attn_processor_kwargs,
+                )
 
         return attn_procs
 

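A minimal reviewer sketch (not part of the diff) of the API introduced by the patch above: how the reworked processor is constructed per attention type. Class, argument, and module names follow the diff; the concrete values are illustrative only.

    from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor

    # UNetAttentionPatcher.get_attention_processor_kwargs() picks these kwargs from the app
    # config and passes them to every layer's processor; __call__() then dispatches to
    # run_attention_normal / run_attention_xformers / run_attention_sdp / run_attention_sliced.
    proc_sdp = CustomAttnProcessor(attention_type="torch-sdp")
    proc_sliced = CustomAttnProcessor(attention_type="sliced", slice_size="auto")
    # proc_xf = CustomAttnProcessor(attention_type="xformers")  # raises ImportError if xformers is not installed
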
From 9f40c2da8d81fdcb3af48e741a88159eee53508e Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Sun, 28 Jul 2024 02:24:40 +0300
Subject: [PATCH 02/25] Remove xformers and normal attention

---
 invokeai/app/api/routers/app_info.py          |   8 +-
 .../stable_diffusion/diffusers_pipeline.py    |  36 +++---
 .../diffusion/custom_atttention.py            | 107 +++---------------
 .../diffusion/unet_attention_patcher.py       |  22 ++--
 invokeai/backend/util/hotfixes.py             |  46 --------
 5 files changed, 38 insertions(+), 181 deletions(-)

diff --git a/invokeai/app/api/routers/app_info.py b/invokeai/app/api/routers/app_info.py
index c3bc98a0387..22827ef7879 100644
--- a/invokeai/app/api/routers/app_info.py
+++ b/invokeai/app/api/routers/app_info.py
@@ -1,6 +1,6 @@
 import typing
 from enum import Enum
-from importlib.metadata import PackageNotFoundError, version
+from importlib.metadata import version
 from pathlib import Path
 from platform import python_version
 from typing import Optional
@@ -76,10 +76,6 @@ async def get_version() -> AppVersion:
 
 @app_router.get("/app_deps", operation_id="get_app_deps", status_code=200, response_model=AppDependencyVersions)
 async def get_app_deps() -> AppDependencyVersions:
-    try:
-        xformers = version("xformers")
-    except PackageNotFoundError:
-        xformers = None
     return AppDependencyVersions(
         accelerate=version("accelerate"),
         compel=version("compel"),
@@ -93,7 +89,7 @@ async def get_app_deps() -> AppDependencyVersions:
         torch=torch.version.__version__,
         torchvision=version("torchvision"),
         transformers=version("transformers"),
-        xformers=xformers,
+        xformers=None,  # TODO: ask frontend
     )
 
 
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index fe5310c216d..33115cbeb4d 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -10,15 +10,16 @@
 import psutil
 import torch
 import torchvision.transforms as T
+from diffusers.models.attention_processor import AttnProcessor2_0
 from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
-from diffusers.utils.import_utils import is_xformers_available
 from pydantic import Field
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
+import invokeai.backend.util.logging as logger
 from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
@@ -177,14 +178,13 @@ def __init__(
         self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
 
     def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
-        """
-        if xformers is available, use it, otherwise use sliced attention.
-        """
         config = get_config()
-        if config.attention_type == "xformers":
-            self.enable_xformers_memory_efficient_attention()
-            return
-        elif config.attention_type == "sliced":
+        attention_type = config.attention_type
+        if attention_type in ["normal", "xformers"]:
+            logger.warning(f'Attention "{attention_type}" is no longer supported; "torch-sdp" will be used instead.')
+            attention_type = "torch-sdp"
+
+        if config.attention_type == "sliced":
             slice_size = config.attention_slice_size
             if slice_size == "auto":
                 slice_size = auto_detect_slice_size(latents)
@@ -192,24 +192,14 @@ def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
                 slice_size = "auto"
             self.enable_attention_slicing(slice_size=slice_size)
             return
-        elif config.attention_type == "normal":
-            self.disable_attention_slicing()
-            return
         elif config.attention_type == "torch-sdp":
-            if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
-                # diffusers enables sdp automatically
-                return
-            else:
-                raise Exception("torch-sdp attention slicing not available")
+            self.unet.set_attn_processor(AttnProcessor2_0())
+            return
 
         # the remainder if this code is called when attention_type=='auto'
         if self.unet.device.type == "cuda":
-            if is_xformers_available():
-                self.enable_xformers_memory_efficient_attention()
-                return
-            elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
-                # diffusers enables sdp automatically
-                return
+            self.unet.set_attn_processor(AttnProcessor2_0())
+            return
 
         if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
             mem_free = psutil.virtual_memory().free
@@ -234,7 +224,7 @@ def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
             # diffusers recommends always enabling for mps
             self.enable_attention_slicing(slice_size="max")
         else:
-            self.disable_attention_slicing()
+            self.unet.set_attn_processor(AttnProcessor2_0())
 
     def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
         raise Exception("Should not be called")
diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
index 8a2c62354d0..d5f78ded70f 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
@@ -1,20 +1,14 @@
 from dataclasses import dataclass
-from typing import List, Optional, Union, Callable, cast
+from typing import List, Optional, Union
 
 import torch
 import torch.nn.functional as F
-from diffusers.models.attention_processor import Attention, AttnProcessor2_0
-from diffusers.utils.import_utils import is_xformers_available
+from diffusers.models.attention_processor import Attention
 
 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
 from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
 from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
 
-if is_xformers_available():
-    import xformers
-    import xformers.ops
-else:
-    xformers = None
 
 @dataclass
 class IPAdapterAttentionWeights:
@@ -25,9 +19,7 @@ class IPAdapterAttentionWeights:
 class CustomAttnProcessor:
     """A custom implementation of attention processor that supports additional Invoke features.
     This implementation is based on
-    AttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L732)
     SlicedAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1616)
-    XFormersAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1113)
     AttnProcessor2_0 (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204)
     Supported custom features:
     - IP-Adapter
@@ -38,11 +30,8 @@ def __init__(
         self,
         attention_type: str,
         ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None,
-        # xformers
-        attention_op: Optional[Callable] = None,
         # sliced
-        slice_size: Optional[Union[str, int]] = None, # TODO: or "auto"?
-
+        slice_size: Optional[Union[str, int]] = None,
     ):
         """Initialize a CustomAttnProcessor.
         Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
@@ -50,33 +39,24 @@ def __init__(
         Args:
             ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights
                 for the i'th IP-Adapter.
-            attention_op (`Callable`, *optional*, defaults to `None`):
-                The base
-                [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
-                use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
-                operator.
             slice_size (`int`, *optional*):
                 The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
                 `attention_head_dim` must be a multiple of the `slice_size`.
         """
-        if attention_type not in ["normal", "sliced", "xformers", "torch-sdp"]:
-            raise Exception(f"Unknown attention type: {attention_type}")
-
-        if attention_type == "xformers" and xformers is None:
-            raise ImportError("xformers attention requires xformers module to be installed.")
+        if attention_type not in ["sliced", "torch-sdp"]:
+            raise ValueError(f"Unknown attention type: {attention_type}")
 
         if attention_type == "torch-sdp" and not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("torch-sdp attention requires PyTorch 2.0; to use it, please upgrade PyTorch to 2.0.")
 
         if attention_type == "sliced":
             if slice_size is None:
-                raise Exception(f"slice_size required for sliced attention")
+                raise ValueError("slice_size required for sliced attention")
             if slice_size not in ["auto", "max"] and not isinstance(slice_size, int):
-                raise Exception(f"Unsupported slice_size: {slice_size}")
+                raise ValueError(f"Unsupported slice_size: {slice_size}")
 
         self._ip_adapter_attention_weights = ip_adapter_attention_weights
         self.attention_type = attention_type
-        self.attention_op = attention_op
         self.slice_size = slice_size
 
     def __call__(
@@ -165,16 +145,14 @@ def __call__(
             hidden_states = hidden_states + residual
 
         hidden_states = hidden_states / attn.rescale_output_factor
-
         return hidden_states
 
-
     def run_ip_adapters(
         self,
         attn: Attention,
         hidden_states: torch.Tensor,
         regional_ip_data: Optional[RegionalIPData],
-        query_length: int, # TODO: just read from query?
+        query_length: int,  # TODO: just read from query?
         query: torch.Tensor,
     ) -> torch.Tensor:
         if self._ip_adapter_attention_weights is None:
@@ -193,9 +171,9 @@ def run_ip_adapters(
 
         for ipa_index, ip_hidden_states in enumerate(regional_ip_data.image_prompt_embeds):
             # The batch dimensions should match.
-            #assert ip_hidden_states.shape[0] == encoder_hidden_states.shape[0]
+            # assert ip_hidden_states.shape[0] == encoder_hidden_states.shape[0]
             # The token_len dimensions should match.
-            #assert ip_hidden_states.shape[-1] == encoder_hidden_states.shape[-1]
+            # assert ip_hidden_states.shape[-1] == encoder_hidden_states.shape[-1]
 
             if self._ip_adapter_attention_weights[ipa_index].skip:
                 continue
@@ -221,7 +199,6 @@ def run_ip_adapters(
 
         return hidden_states
 
-
     def prepare_attention_mask(
         self,
         attn: Attention,
@@ -232,7 +209,6 @@ def prepare_attention_mask(
         is_cross_attention: bool,
         regional_prompt_data: Optional[RegionalPromptData],
     ) -> Optional[torch.Tensor]:
-
         if regional_prompt_data is not None and is_cross_attention:
             prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
                 query_seq_len=query_length, key_seq_len=key_length
@@ -243,22 +219,11 @@ def prepare_attention_mask(
             else:
                 attention_mask = prompt_region_attention_mask + attention_mask
 
-
         attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size)
 
-        if self.attention_type in ["normal", "sliced"]:
+        if self.attention_type == "sliced":
             pass
 
-        elif self.attention_type == "xformers":
-            if attention_mask is not None:
-                # expand our mask's singleton query_length dimension:
-                #   [batch*heads,            1, key_length] ->
-                #   [batch*heads, query_length, key_length]
-                # so that it can be added as a bias onto the attention scores that xformers computes:
-                #   [batch*heads, query_length, key_length]
-                # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
-                attention_mask = attention_mask.expand(-1, query_length, -1)
-
         elif self.attention_type == "torch-sdp":
             if attention_mask is not None:
                 # scaled_dot_product_attention expects attention_mask shape to be
@@ -278,11 +243,7 @@ def run_attention(
         value: torch.Tensor,
         attention_mask: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        if self.attention_type == "normal":
-            attn_call = self.run_attention_normal
-        elif self.attention_type == "xformers":
-            attn_call = self.run_attention_xformers
-        elif self.attention_type == "torch-sdp":
+        if self.attention_type == "torch-sdp":
             attn_call = self.run_attention_sdp
         elif self.attention_type == "sliced":
             attn_call = self.run_attention_sliced
@@ -297,45 +258,6 @@ def run_attention(
             attention_mask=attention_mask,
         )
 
-    def run_attention_normal(
-        self,
-        attn: Attention,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        attention_mask: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        query = attn.head_to_batch_dim(query)
-        key   = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
-        attention_probs = attn.get_attention_scores(query, key, attention_mask)
-        hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        return hidden_states
-
-    def run_attention_xformers(
-        self,
-        attn: Attention,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        attention_mask: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        # attention_op
-        query = attn.head_to_batch_dim(query).contiguous()
-        key   = attn.head_to_batch_dim(key).contiguous()
-        value = attn.head_to_batch_dim(value).contiguous()
-
-        hidden_states = xformers.ops.memory_efficient_attention(
-            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
-        )
-        hidden_states = hidden_states.to(query.dtype)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        return hidden_states
-
     def run_attention_sdp(
         self,
         attn: Attention,
@@ -382,7 +304,7 @@ def run_attention_sliced(
         dim = query.shape[-1]
 
         query = attn.head_to_batch_dim(query)
-        key   = attn.head_to_batch_dim(key)
+        key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)
 
         batch_size_attention, query_tokens, _ = query.shape
@@ -399,12 +321,9 @@ def run_attention_sliced(
             attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
 
             attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
-
             attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
 
             hidden_states[start_idx:end_idx] = attn_slice
 
         hidden_states = attn.batch_to_head_dim(hidden_states)
-
         return hidden_states
-
diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
index c2f79607c02..065b0563578 100644
--- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
+++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
@@ -2,8 +2,8 @@
 from typing import List, Optional, TypedDict
 
 from diffusers.models import UNet2DConditionModel
-from diffusers.utils.import_utils import is_xformers_available
 
+import invokeai.backend.util.logging as logger
 from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
@@ -25,23 +25,21 @@ def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]):
 
     def get_attention_processor_kwargs(self, unet: UNet2DConditionModel):
         config = get_config()
-        kwargs = dict()
-        
-        # TODO:
+        kwargs = {}
+
         attention_type = config.attention_type
+        if attention_type in ["normal", "xformers"]:
+            logger.warning(f'Attention "{attention_type}" is no longer supported; "torch-sdp" will be used instead.')
+            attention_type = "torch-sdp"
+
         if attention_type == "auto":
-            if self.unet.device.type == "cuda":
-                if is_xformers_available():
-                    attention_type = "xformers"
-                elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
-                    attention_type = "torch-sdp"
-                else:
-                    attention_type = "normal"
+            if unet.device.type == "cuda":
+                attention_type = "torch-sdp"
             else:
                 attention_type = "sliced"
 
         kwargs["attention_type"] = attention_type
-        
+
         if attention_type == "sliced":
             slice_size = config.attention_slice_size
             if slice_size == "balanced":
diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py
index 7e362fe9589..a9ed2538825 100644
--- a/invokeai/backend/util/hotfixes.py
+++ b/invokeai/backend/util/hotfixes.py
@@ -791,49 +791,3 @@ def new_LoRACompatibleConv_forward(self, hidden_states, scale: float = 1.0):
 
 
 diffusers.models.lora.LoRACompatibleConv.forward = new_LoRACompatibleConv_forward
-
-try:
-    import xformers
-
-    xformers_available = True
-except Exception:
-    xformers_available = False
-
-
-if xformers_available:
-    # TODO: remove when fixed in diffusers
-    _xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention
-
-    def new_memory_efficient_attention(
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        attn_bias=None,
-        p: float = 0.0,
-        scale: Optional[float] = None,
-        *,
-        op=None,
-    ):
-        # diffusers not align shape to 8, which is required by xformers
-        if attn_bias is not None and type(attn_bias) is torch.Tensor:
-            orig_size = attn_bias.shape[-1]
-            new_size = ((orig_size + 7) // 8) * 8
-            aligned_attn_bias = torch.zeros(
-                (attn_bias.shape[0], attn_bias.shape[1], new_size),
-                device=attn_bias.device,
-                dtype=attn_bias.dtype,
-            )
-            aligned_attn_bias[:, :, :orig_size] = attn_bias
-            attn_bias = aligned_attn_bias[:, :, :orig_size]
-
-        return _xformers_memory_efficient_attention(
-            query=query,
-            key=key,
-            value=value,
-            attn_bias=attn_bias,
-            p=p,
-            scale=scale,
-            op=op,
-        )
-
-    xformers.ops.memory_efficient_attention = new_memory_efficient_attention

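A minimal reviewer sketch (not part of the diff) of the behaviour after the patch above: only "sliced" and "torch-sdp" remain valid for the custom processor, and the plain (non-IP-Adapter) paths fall back to the stock diffusers processor. Names follow the diff; the calls are illustrative only.

    from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor

    CustomAttnProcessor(attention_type="torch-sdp")                 # ok on PyTorch >= 2.0
    CustomAttnProcessor(attention_type="sliced", slice_size="max")  # ok
    # CustomAttnProcessor(attention_type="xformers")                # now raises ValueError

    # _adjust_memory_efficient_attention() now installs the stock diffusers processor
    # (unet.set_attn_processor(AttnProcessor2_0())) for the "torch-sdp" and CUDA "auto" paths.
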
From 1ab827619c10ea3014cdadc8d89e6766568f7d46 Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Sun, 28 Jul 2024 02:26:32 +0300
Subject: [PATCH 03/25] Fix file name

---
 .../diffusion/{custom_atttention.py => custom_attention.py}     | 0
 .../stable_diffusion/diffusion/unet_attention_patcher.py        | 2 +-
 2 files changed, 1 insertion(+), 1 deletion(-)
 rename invokeai/backend/stable_diffusion/diffusion/{custom_atttention.py => custom_attention.py} (100%)

diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
similarity index 100%
rename from invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
rename to invokeai/backend/stable_diffusion/diffusion/custom_attention.py
diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
index 065b0563578..3365fafef26 100644
--- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
+++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
@@ -6,7 +6,7 @@
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
+from invokeai.backend.stable_diffusion.diffusion.custom_attention import (
     CustomAttnProcessor,
     IPAdapterAttentionWeights,
 )

From 89c37c3979338df6a78f4b2b78a912e3a1e1fba1 Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Sun, 28 Jul 2024 02:56:49 +0300
Subject: [PATCH 04/25] Sync fixes

---
 invokeai/app/invocations/denoise_latents.py   |  4 +-
 .../diffusion/custom_attention.py             | 44 ++++++++++++-------
 .../diffusion/unet_attention_patcher.py       | 32 +-------------
 3 files changed, 31 insertions(+), 49 deletions(-)

diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index 2787074265c..c446f714b60 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -55,7 +55,7 @@
     TextConditioningData,
     TextConditioningRegions,
 )
-from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
+from invokeai.backend.stable_diffusion.diffusion.custom_attention import CustomAttnProcessor
 from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
 from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
@@ -810,7 +810,7 @@ def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
                 seed=seed,
                 scheduler_step_kwargs=scheduler_step_kwargs,
                 conditioning_data=conditioning_data,
-                attention_processor_cls=CustomAttnProcessor2_0,
+                attention_processor_cls=CustomAttnProcessor,
             ),
             unet=None,
             scheduler=scheduler,
diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
index d5f78ded70f..4ba53e7f784 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
@@ -1,13 +1,16 @@
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import List, Optional
 
 import torch
 import torch.nn.functional as F
 from diffusers.models.attention_processor import Attention
 
+import invokeai.backend.util.logging as logger
+from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
 from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
 from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
+from invokeai.backend.util.devices import TorchDevice
 
 
 @dataclass
@@ -28,10 +31,7 @@ class CustomAttnProcessor:
 
     def __init__(
         self,
-        attention_type: str,
         ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None,
-        # sliced
-        slice_size: Optional[Union[str, int]] = None,
     ):
         """Initialize a CustomAttnProcessor.
         Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
@@ -39,25 +39,37 @@ def __init__(
         Args:
             ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights
                 for the i'th IP-Adapter.
-            slice_size (`int`, *optional*):
-                The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
-                `attention_head_dim` must be a multiple of the `slice_size`.
         """
-        if attention_type not in ["sliced", "torch-sdp"]:
-            raise ValueError(f"Unknown attention type: {attention_type}")
+
+        self._ip_adapter_attention_weights = ip_adapter_attention_weights
+        self.attention_type, self.slice_size = self._select_attention()
+
+    def _select_attention(self):
+        config = get_config()
+        attention_type = config.attention_type
+        if attention_type in ["normal", "xformers"]:
+            logger.warning(f'Attention "{attention_type}" is no longer supported; "torch-sdp" will be used instead.')
+            attention_type = "torch-sdp"
+
+        if attention_type == "auto":
+            exec_device = TorchDevice.choose_torch_device()
+            if exec_device.type == "mps":
+                attention_type = "sliced"
+            else:
+                attention_type = "torch-sdp"
 
         if attention_type == "torch-sdp" and not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("torch-sdp attention requires PyTorch 2.0; to use it, please upgrade PyTorch to 2.0.")
 
+        slice_size = None
         if attention_type == "sliced":
-            if slice_size is None:
-                raise ValueError("slice_size required for sliced attention")
-            if slice_size not in ["auto", "max"] and not isinstance(slice_size, int):
-                raise ValueError(f"Unsupported slice_size: {slice_size}")
+            slice_size = config.attention_slice_size
+            if slice_size not in ["auto", "balanced", "max"] and not isinstance(slice_size, int):
+                raise ValueError(f"Unsupported attention_slice_size: {slice_size}")
+            if slice_size == "balanced":
+                slice_size = "auto"
 
-        self._ip_adapter_attention_weights = ip_adapter_attention_weights
-        self.attention_type = attention_type
-        self.slice_size = slice_size
+        return attention_type, slice_size
 
     def __call__(
         self,
diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
index 3365fafef26..ce45ac157c2 100644
--- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
+++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
@@ -3,8 +3,6 @@
 
 from diffusers.models import UNet2DConditionModel
 
-import invokeai.backend.util.logging as logger
-from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.stable_diffusion.diffusion.custom_attention import (
     CustomAttnProcessor,
@@ -23,31 +21,6 @@ class UNetAttentionPatcher:
     def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]):
         self._ip_adapters = ip_adapter_data
 
-    def get_attention_processor_kwargs(self, unet: UNet2DConditionModel):
-        config = get_config()
-        kwargs = {}
-
-        attention_type = config.attention_type
-        if attention_type in ["normal", "xformers"]:
-            logger.warning(f'Attention "{attention_type}" is no longer supported; "torch-sdp" will be used instead.')
-            attention_type = "torch-sdp"
-
-        if attention_type == "auto":
-            if unet.device.type == "cuda":
-                attention_type = "torch-sdp"
-            else:
-                attention_type = "sliced"
-
-        kwargs["attention_type"] = attention_type
-
-        if attention_type == "sliced":
-            slice_size = config.attention_slice_size
-            if slice_size == "balanced":
-                slice_size = "auto"
-            kwargs["slice_size"] = slice_size
-
-        return kwargs
-
     def _prepare_attention_processors(self, unet: UNet2DConditionModel):
         """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
         weights into them (if IP-Adapters are being applied).
@@ -55,13 +28,11 @@ def _prepare_attention_processors(self, unet: UNet2DConditionModel):
         """
         # Construct a dict of attention processors based on the UNet's architecture.
 
-        attn_processor_kwargs = self.get_attention_processor_kwargs(unet)
-
         attn_procs = {}
         for idx, name in enumerate(unet.attn_processors.keys()):
             if name.endswith("attn1.processor") or self._ip_adapters is None:
                 # "attn1" processors do not use IP-Adapters.
-                attn_procs[name] = CustomAttnProcessor(**attn_processor_kwargs)
+                attn_procs[name] = CustomAttnProcessor()
             else:
                 # Collect the weights from each IP Adapter for the idx'th attention processor.
                 ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = []
@@ -80,7 +51,6 @@ def _prepare_attention_processors(self, unet: UNet2DConditionModel):
 
                 attn_procs[name] = CustomAttnProcessor(
                     ip_adapter_attention_weights=ip_adapter_attention_weights_collection,
-                    **attn_processor_kwargs,
                 )
 
         return attn_procs
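
A condensed restatement of the selection rules that now live in CustomAttnProcessor._select_attention(); pick_attention() below is a hypothetical standalone helper that mirrors the logic in the hunk above for illustration only, it is not code from the patch:

def pick_attention(attention_type: str, device_type: str) -> str:
    # Removed attention types fall back to torch-sdp (the real code also logs a warning).
    if attention_type in ("normal", "xformers"):
        return "torch-sdp"
    # "auto" prefers sliced attention on MPS and torch-sdp everywhere else.
    if attention_type == "auto":
        return "sliced" if device_type == "mps" else "torch-sdp"
    return attention_type

assert pick_attention("auto", "mps") == "sliced"
assert pick_attention("xformers", "cuda") == "torch-sdp"
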

From e9cc750f8b1b732ddd4277051d89bacfce73cf57 Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Sun, 28 Jul 2024 23:13:33 +0300
Subject: [PATCH 05/25] Update app config

---
 .../app/services/config/config_default.py     | 27 ++++++++++++++++---
 1 file changed, 24 insertions(+), 3 deletions(-)

diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index 6c39760bdc8..36cb56c9dbe 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -28,11 +28,11 @@
 DEFAULT_VRAM_CACHE = 0.25
 DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
 PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
-ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
+ATTENTION_TYPE = Literal["auto", "sliced", "torch-sdp"]
 ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
 LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
 LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
-CONFIG_SCHEMA_VERSION = "4.0.2"
+CONFIG_SCHEMA_VERSION = "4.0.3"
 
 
 def get_default_ram_cache_size() -> float:
@@ -107,7 +107,7 @@ class InvokeAIAppConfig(BaseSettings):
         device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
         precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
         sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
-        attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
+        attention_type: Attention type.<br>Valid values: `auto`, `sliced`, `torch-sdp`
         attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
         force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
         pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
@@ -433,6 +433,24 @@ def migrate_v4_0_1_to_4_0_2_config_dict(config_dict: dict[str, Any]) -> dict[str
     return parsed_config_dict
 
 
+def migrate_v4_0_2_to_4_0_3_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
+    """Migrate v4.0.2 config dictionary to a v4.0.3 config dictionary.
+
+    Args:
+        config_dict: A dictionary of settings from a v4.0.2 config file.
+
+    Returns:
+        A config dict with the settings migrated to v4.0.3.
+    """
+    parsed_config_dict: dict[str, Any] = copy.deepcopy(config_dict)
+    # The `normal` and `xformers` attention types were removed in 4.0.3.
+    attention_type = parsed_config_dict.get("attention_type", None)
+    if attention_type in ["normal", "xformers"]:
+        parsed_config_dict["attention_type"] = "torch-sdp"
+    parsed_config_dict["schema_version"] = "4.0.3"
+    return parsed_config_dict
+
+
 def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
     """Load and migrate a config file to the latest version.
 
@@ -458,6 +476,9 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
     if loaded_config_dict["schema_version"] == "4.0.1":
         migrated = True
         loaded_config_dict = migrate_v4_0_1_to_4_0_2_config_dict(loaded_config_dict)
+    if loaded_config_dict["schema_version"] == "4.0.2":
+        migrated = True
+        loaded_config_dict = migrate_v4_0_2_to_4_0_3_config_dict(loaded_config_dict)
 
     if migrated:
         shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
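
A minimal usage sketch of the new migration step, assuming migrate_v4_0_2_to_4_0_3_config_dict is importable from invokeai.app.services.config.config_default as added in the hunk above:

from invokeai.app.services.config.config_default import migrate_v4_0_2_to_4_0_3_config_dict

old = {"schema_version": "4.0.2", "attention_type": "xformers", "host": "0.0.0.0"}
new = migrate_v4_0_2_to_4_0_3_config_dict(old)
assert new["attention_type"] == "torch-sdp"   # removed type falls back to torch-sdp
assert new["schema_version"] == "4.0.3"
assert old["attention_type"] == "xformers"    # the input dict is deep-copied, not mutated
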

From 4b6d61377af4142cb58ef03d0b420220704a34cf Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Mon, 29 Jul 2024 13:47:51 +0300
Subject: [PATCH 06/25] Remove remaining references to xformers

---
 docker/Dockerfile                    | 7 +------
 flake.nix                            | 2 +-
 installer/lib/installer.py           | 4 ++--
 invokeai/app/api/routers/app_info.py | 2 --
 pyproject.toml                       | 6 ------
 5 files changed, 4 insertions(+), 17 deletions(-)

diff --git a/docker/Dockerfile b/docker/Dockerfile
index 7ea078af0d9..24f2ff9e2f7 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -43,12 +43,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
         extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu121"; \
     fi &&\
 
-    # xformers + triton fails to install on arm64
-    if [ "$GPU_DRIVER" = "cuda" ] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then \
-        pip install $extra_index_url_arg -e ".[xformers]"; \
-    else \
-        pip install $extra_index_url_arg -e "."; \
-    fi
+    pip install $extra_index_url_arg -e ".";
 
 # #### Build the Web UI ------------------------------------
 
diff --git a/flake.nix b/flake.nix
index 3ccc6658121..bf8d2ae9466 100644
--- a/flake.nix
+++ b/flake.nix
@@ -84,7 +84,7 @@
     in
     {
       devShells.${system} = rec {
-        develop = mkShell { dir = "venv"; install = "-e '.[xformers]' --extra-index-url https://download.pytorch.org/whl/cu118"; };
+        develop = mkShell { dir = "venv"; install = "-e '.' --extra-index-url https://download.pytorch.org/whl/cu118"; };
         default = develop;
       };
     };
diff --git a/installer/lib/installer.py b/installer/lib/installer.py
index 11823b413e0..504c801df6d 100644
--- a/installer/lib/installer.py
+++ b/installer/lib/installer.py
@@ -418,11 +418,11 @@ def get_torch_source() -> Tuple[str | None, str | None]:
             url = "https://download.pytorch.org/whl/cpu"
         elif device.value == "cuda":
             # CUDA uses the default PyPi index
-            optional_modules = "[xformers,onnx-cuda]"
+            optional_modules = "[onnx-cuda]"
     elif OS == "Windows":
         if device.value == "cuda":
             url = "https://download.pytorch.org/whl/cu121"
-            optional_modules = "[xformers,onnx-cuda]"
+            optional_modules = "[onnx-cuda]"
         elif device.value == "cpu":
             # CPU  uses the default PyPi index, no optional modules
             pass
diff --git a/invokeai/app/api/routers/app_info.py b/invokeai/app/api/routers/app_info.py
index e2556ecaa7c..9f87e2cdec0 100644
--- a/invokeai/app/api/routers/app_info.py
+++ b/invokeai/app/api/routers/app_info.py
@@ -56,7 +56,6 @@ class AppDependencyVersions(BaseModel):
     torch: str = Field(description="PyTorch version")
     torchvision: str = Field(description="PyTorch Vision version")
     transformers: str = Field(description="transformers version")
-    xformers: Optional[str] = Field(description="xformers version")
 
 
 class AppConfig(BaseModel):
@@ -88,7 +87,6 @@ async def get_app_deps() -> AppDependencyVersions:
         torch=torch.version.__version__,
         torchvision=version("torchvision"),
         transformers=version("transformers"),
-        xformers=None,  # TODO: ask frontend
     )
 
 
diff --git a/pyproject.toml b/pyproject.toml
index 9acaa17e44d..d1be2215f0d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -95,12 +95,6 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-"xformers" = [
-  # Core generation dependencies, pinned for reproducible builds.
-  "xformers==0.0.25post1; sys_platform!='darwin'",
-  # Auxiliary dependencies, pinned only if necessary.
-  "triton; sys_platform=='linux'",
-]
 "onnx" = ["onnxruntime"]
 "onnx-cuda" = ["onnxruntime-gpu"]
 "onnx-directml" = ["onnxruntime-directml"]

From d5fa938eb005a11ff9ff03588ae6ae27c2483bdd Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Tue, 30 Jul 2024 04:09:02 +0300
Subject: [PATCH 07/25] Run api regen

---
 invokeai/frontend/web/src/services/api/schema.ts | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts
index 59f9897f740..5334db35e1e 100644
--- a/invokeai/frontend/web/src/services/api/schema.ts
+++ b/invokeai/frontend/web/src/services/api/schema.ts
@@ -725,11 +725,6 @@ export type components = {
        * @description transformers version
        */
       transformers: string;
-      /**
-       * Xformers
-       * @description xformers version
-       */
-      xformers: string | null;
     };
     /**
      * AppVersion

From 5a9cc04e79623f95ee1f07d330879636199ea097 Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Fri, 2 Aug 2024 00:46:17 +0300
Subject: [PATCH 08/25] Small rearrangement

---
 .../diffusion/custom_attention.py             | 62 +++++--------------
 1 file changed, 17 insertions(+), 45 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
index 4ba53e7f784..3884fe06ee6 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
@@ -105,15 +105,18 @@ def __call__(
         )
         query_length = hidden_states.shape[1]
 
-        attention_mask = self.prepare_attention_mask(
-            attn=attn,
-            attention_mask=attention_mask,
-            batch_size=batch_size,
-            key_length=key_length,
-            query_length=query_length,
-            is_cross_attention=is_cross_attention,
-            regional_prompt_data=regional_prompt_data,
-        )
+        # Regional Prompt Attention Mask
+        if regional_prompt_data is not None and is_cross_attention:
+            prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
+                query_seq_len=query_length, key_seq_len=key_length
+            )
+
+            if attention_mask is None:
+                attention_mask = prompt_region_attention_mask
+            else:
+                attention_mask = prompt_region_attention_mask + attention_mask
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size)
 
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -211,42 +214,6 @@ def run_ip_adapters(
 
         return hidden_states
 
-    def prepare_attention_mask(
-        self,
-        attn: Attention,
-        attention_mask: Optional[torch.Tensor],
-        batch_size: int,
-        key_length: int,
-        query_length: int,
-        is_cross_attention: bool,
-        regional_prompt_data: Optional[RegionalPromptData],
-    ) -> Optional[torch.Tensor]:
-        if regional_prompt_data is not None and is_cross_attention:
-            prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
-                query_seq_len=query_length, key_seq_len=key_length
-            )
-
-            if attention_mask is None:
-                attention_mask = prompt_region_attention_mask
-            else:
-                attention_mask = prompt_region_attention_mask + attention_mask
-
-        attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size)
-
-        if self.attention_type == "sliced":
-            pass
-
-        elif self.attention_type == "torch-sdp":
-            if attention_mask is not None:
-                # scaled_dot_product_attention expects attention_mask shape to be
-                # (batch, heads, source_length, target_length)
-                attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        else:
-            raise Exception(f"Unknown attention type: {self.attention_type}")
-
-        return attention_mask
-
     def run_attention(
         self,
         attn: Attention,
@@ -286,6 +253,11 @@ def run_attention_sdp(
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
+        if attention_mask is not None:
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
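
An illustration of why the regional prompt mask can simply be added to an incoming attention mask, as done inline above: both use large negative values for masked positions, so a position masked by either stays masked after the sum. The tensors are made-up toy values, not patch code:

import torch

incoming = torch.tensor([[0.0, -10000.0]])   # position 1 already masked
regional = torch.tensor([[-10000.0, 0.0]])   # region mask hides position 0
combined = regional + incoming               # both positions stay suppressed
assert (combined <= -10000.0).all()
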

From be84746e6783a9259e8cbde96d78fbb8514058bb Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Fri, 2 Aug 2024 00:57:19 +0300
Subject: [PATCH 09/25] Add assert check

Co-Authored-By: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com>
---
 invokeai/backend/stable_diffusion/diffusion/custom_attention.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
index 3884fe06ee6..ad2f68627a8 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
@@ -175,6 +175,7 @@ def run_ip_adapters(
             assert regional_ip_data is None
             return hidden_states
 
+        assert regional_ip_data is not None
         ip_masks = regional_ip_data.get_masks(query_seq_len=query_length)
 
         assert (

From bf2f798341a40ff58b7df9dc10463237593d2d64 Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Sat, 3 Aug 2024 01:27:01 +0300
Subject: [PATCH 10/25] Fix bad generation on slice_size not factor of heads
 count

---
 .gitignore                                    |   8 +
 invokeai.yaml.bak                             |   6 +
 .../diffusion/custom_attention.py             |   2 +-
 .../diffusion/custom_atttention.py            | 383 ++++++++++++++++++
 invokeai/frontend/web/scripts/typegen.js      |   2 +-
 invokeai/frontend/web/vite.config.mts         |   6 +-
 6 files changed, 402 insertions(+), 5 deletions(-)
 create mode 100644 invokeai.yaml.bak
 create mode 100644 invokeai/backend/stable_diffusion/diffusion/custom_atttention.py

diff --git a/.gitignore b/.gitignore
index 29d27d78ed5..a9739a7294b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,13 @@
 .idea/
 
+models/
+nodes/
+configs/
+databases/
+invokeai.yaml
+invokeai.example.yaml
+outputs/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/invokeai.yaml.bak b/invokeai.yaml.bak
new file mode 100644
index 00000000000..b348590cae6
--- /dev/null
+++ b/invokeai.yaml.bak
@@ -0,0 +1,6 @@
+# Internal metadata - do not edit:
+schema_version: 4.0.2
+
+# Put user settings here - see https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/:
+host: 0.0.0.0
+attention_type: torch-sdp
diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
index ad2f68627a8..743a1d5658c 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
@@ -297,7 +297,7 @@ def run_attention_sliced(
             (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
         )
 
-        for i in range(batch_size_attention // slice_size):
+        for i in range((batch_size_attention - 1) // slice_size + 1):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
 
diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
new file mode 100644
index 00000000000..c5a48847f8d
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
@@ -0,0 +1,383 @@
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+from diffusers.models.attention_processor import Attention
+
+import invokeai.backend.util.logging as logger
+from invokeai.app.services.config.config_default import get_config
+from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
+from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
+from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
+from invokeai.backend.util.devices import TorchDevice
+
+
+@dataclass
+class IPAdapterAttentionWeights:
+    ip_adapter_weights: IPAttentionProcessorWeights
+    skip: bool
+
+
+class CustomAttnProcessor2_0:
+    """A custom implementation of attention processor that supports additional Invoke features.
+    This implementation is based on
+    SlicedAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1616)
+    AttnProcessor2_0 (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204)
+    Supported custom features:
+    - IP-Adapter
+    - Regional prompt attention
+    """
+
+    def __init__(
+        self,
+        ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None,
+    ):
+        """Initialize a CustomAttnProcessor.
+        Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
+        layer-specific are passed to __init__().
+        Args:
+            ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights
+                for the i'th IP-Adapter.
+        """
+
+        self._ip_adapter_attention_weights = ip_adapter_attention_weights
+        self.attention_type, self.slice_size = self._select_attention()
+
+    def _select_attention(self):
+        config = get_config()
+        attention_type = config.attention_type
+        if attention_type in ["normal", "xformers"]:
+            logger.warning(f'Attention "{attention_type}" no longer supported, "torch-sdp" will be used instead.')
+            attention_type = "torch-sdp"
+
+        if attention_type == "auto":
+            exec_device = TorchDevice.choose_torch_device()
+            if exec_device.type == "mps":
+                attention_type = "sliced"
+            else:
+                attention_type = "torch-sdp"
+
+        if attention_type == "torch-sdp" and not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("torch-sdp attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+        slice_size = None
+        if attention_type == "sliced":
+            slice_size = config.attention_slice_size
+            if slice_size not in ["auto", "balanced", "max"] and not isinstance(slice_size, int):
+                raise ValueError(f"Unsupported attention_slice_size: {slice_size}")
+            if slice_size == "balanced":
+                slice_size = "auto"
+
+        return attention_type, slice_size
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        # For Regional Prompting:
+        regional_prompt_data: Optional[RegionalPromptData] = None,
+        percent_through: Optional[torch.Tensor] = None,
+        # For IP-Adapter:
+        regional_ip_data: Optional[RegionalIPData] = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        # If true, we are doing cross-attention, if false we are doing self-attention.
+        is_cross_attention = encoder_hidden_states is not None
+
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, key_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        query_length = hidden_states.shape[1]
+
+        # Regional Prompt Attention Mask
+        if regional_prompt_data is not None and is_cross_attention:
+            prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
+                query_seq_len=query_length, key_seq_len=key_length
+            )
+
+            if attention_mask is None:
+                attention_mask = prompt_region_attention_mask
+            else:
+                attention_mask = prompt_region_attention_mask + attention_mask
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        hidden_states = self.run_attention(
+            attn=attn,
+            query=query,
+            key=key,
+            value=value,
+            attention_mask=attention_mask,
+        )
+
+        if is_cross_attention:
+            hidden_states = self.run_ip_adapters(
+                attn=attn,
+                hidden_states=hidden_states,
+                regional_ip_data=regional_ip_data,
+                query_length=query_length,
+                query=query,
+            )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+        return hidden_states
+
+    def run_ip_adapters(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        regional_ip_data: Optional[RegionalIPData],
+        query_length: int,  # TODO: just read from query?
+        query: torch.Tensor,
+    ) -> torch.Tensor:
+        if self._ip_adapter_attention_weights is None:
+            # If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
+            assert regional_ip_data is None
+            return hidden_states
+
+        assert regional_ip_data is not None
+        ip_masks = regional_ip_data.get_masks(query_seq_len=query_length)
+
+        assert (
+            len(regional_ip_data.image_prompt_embeds)
+            == len(self._ip_adapter_attention_weights)
+            == len(regional_ip_data.scales)
+            == ip_masks.shape[1]
+        )
+
+        for ipa_index, ip_hidden_states in enumerate(regional_ip_data.image_prompt_embeds):
+            # The batch dimensions should match.
+            # assert ip_hidden_states.shape[0] == encoder_hidden_states.shape[0]
+            # The token_len dimensions should match.
+            # assert ip_hidden_states.shape[-1] == encoder_hidden_states.shape[-1]
+
+            if self._ip_adapter_attention_weights[ipa_index].skip:
+                continue
+
+            ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights
+            ipa_scale = regional_ip_data.scales[ipa_index]
+            ip_mask = ip_masks[0, ipa_index, ...]
+
+            # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
+            ip_key = ipa_weights.to_k_ip(ip_hidden_states)
+            ip_value = ipa_weights.to_v_ip(ip_hidden_states)
+
+            ip_hidden_states = self.run_attention(
+                attn=attn,
+                query=query,
+                key=ip_key,
+                value=ip_value,
+                attention_mask=None,
+            )
+
+            # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
+            hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
+
+        return hidden_states
+
+    def run_attention(
+        self,
+        attn: Attention,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        if self.attention_type == "torch-sdp":
+            attn_call = self.run_attention_sdp
+        elif self.attention_type == "sliced":
+            attn_call = self.run_attention_sliced
+        else:
+            raise Exception(f"Unknown attention type: {self.attention_type}")
+
+        return attn_call(
+            attn=attn,
+            query=query,
+            key=key,
+            value=value,
+            attention_mask=attention_mask,
+        )
+
+    def run_attention_sdp(
+        self,
+        attn: Attention,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        batch_size = key.shape[0]
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attention_mask is not None:
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        return hidden_states
+
+    def run_attention_sliced(
+        self,
+        attn: Attention,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        if True:
+            func = self._run_attention_sliced_norm
+        else:
+            func = self._run_attention_sliced_sdp
+
+        return func(
+            attn=attn,
+            query=query,
+            key=key,
+            value=value,
+            attention_mask=attention_mask,
+        )
+
+    def _run_attention_sliced_norm(
+        self,
+        attn: Attention,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # slice_size
+        if self.slice_size == "max":
+            slice_size = 1
+        elif self.slice_size == "auto":
+            slice_size = max(1, attn.sliceable_head_dim // 2)
+        else:
+            slice_size = min(self.slice_size, attn.sliceable_head_dim)
+
+        dim = query.shape[-1]
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        batch_size_attention, query_tokens, _ = query.shape
+        hidden_states = torch.zeros(
+            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
+        )
+
+        for i in range(batch_size_attention // slice_size):
+            start_idx = i * slice_size
+            end_idx = (i + 1) * slice_size
+
+            query_slice = query[start_idx:end_idx]
+            key_slice = key[start_idx:end_idx]
+            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
+
+            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
+            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+            hidden_states[start_idx:end_idx] = attn_slice
+
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+        return hidden_states
+
+
+    def _run_attention_sliced_sdp(
+        self,
+        attn: Attention,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # slice_size
+        if self.slice_size == "max":
+            slice_size = 1
+        elif self.slice_size == "auto":
+            slice_size = max(1, attn.sliceable_head_dim // 2)
+        else:
+            slice_size = min(self.slice_size, attn.sliceable_head_dim)
+
+        dim = query.shape[-1]
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        batch_size_attention, query_tokens, _ = query.shape
+        hidden_states = torch.zeros(
+            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
+        )
+
+        for i in range(batch_size_attention // slice_size):
+            start_idx = i * slice_size
+            end_idx = (i + 1) * slice_size
+
+            query_slice = query[start_idx:end_idx]
+            key_slice = key[start_idx:end_idx]
+            value_slice = value[start_idx:end_idx]
+            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
+
+            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            attn_slice = F.scaled_dot_product_attention(
+                query_slice, key_slice, value_slice, attn_mask=attn_mask_slice, dropout_p=0.0, is_causal=False
+            )
+
+            hidden_states[start_idx:end_idx] = attn_slice
+
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+        return hidden_states
diff --git a/invokeai/frontend/web/scripts/typegen.js b/invokeai/frontend/web/scripts/typegen.js
index fa2d791350d..435c82a1abb 100644
--- a/invokeai/frontend/web/scripts/typegen.js
+++ b/invokeai/frontend/web/scripts/typegen.js
@@ -3,7 +3,7 @@ import fs from 'node:fs';
 
 import openapiTS from 'openapi-typescript';
 
-const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json';
+const OPENAPI_URL = 'http://192.168.5.199:9090/openapi.json';
 const OUTPUT_FILE = 'src/services/api/schema.ts';
 
 async function generateTypes(schema) {
diff --git a/invokeai/frontend/web/vite.config.mts b/invokeai/frontend/web/vite.config.mts
index a40c515465c..59a3cf1901f 100644
--- a/invokeai/frontend/web/vite.config.mts
+++ b/invokeai/frontend/web/vite.config.mts
@@ -71,18 +71,18 @@ export default defineConfig(({ mode }) => {
       proxy: {
         // Proxy socket.io to the nodes socketio server
         '/ws/socket.io': {
-          target: 'ws://127.0.0.1:9090',
+          target: 'ws://192.168.5.199:9090',
           ws: true,
         },
         // Proxy openapi schema definiton
         '/openapi.json': {
-          target: 'http://127.0.0.1:9090/openapi.json',
+          target: 'http://192.168.5.199:9090/openapi.json',
           rewrite: (path) => path.replace(/^\/openapi.json/, ''),
           changeOrigin: true,
         },
         // proxy nodes api
         '/api/': {
-          target: 'http://127.0.0.1:9090/api/',
+          target: 'http://192.168.5.199:9090/api/',
           rewrite: (path) => path.replace(/^\/api/, ''),
           changeOrigin: true,
         },

From 91cc89a75a6d70414e1a3414a6d32560b37b534b Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Sat, 3 Aug 2024 01:27:40 +0300
Subject: [PATCH 11/25] Use Invoke slice_size values to avoid confusion

---
 .../backend/stable_diffusion/diffusion/custom_attention.py  | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
index 743a1d5658c..bf5e5fa5949 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
@@ -66,8 +66,8 @@ def _select_attention(self):
             slice_size = config.attention_slice_size
             if slice_size not in ["auto", "balanced", "max"] and not isinstance(slice_size, int):
                 raise ValueError(f"Unsupported attention_slice_size: {slice_size}")
-            if slice_size == "balanced":
-                slice_size = "auto"
+            if slice_size == "auto":
+                slice_size = "balanced"
 
         return attention_type, slice_size
 
@@ -281,7 +281,7 @@ def run_attention_sliced(
         # slice_size
         if self.slice_size == "max":
             slice_size = 1
-        elif self.slice_size == "auto":
+        elif self.slice_size == "balanced":
             slice_size = max(1, attn.sliceable_head_dim // 2)
         else:
             slice_size = min(self.slice_size, attn.sliceable_head_dim)
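
A small restatement of how the Invoke-facing slice_size values resolve to a per-call slice size inside run_attention_sliced(); resolve_slice_size() is an illustrative helper, not part of the patch:

def resolve_slice_size(slice_size, sliceable_head_dim: int) -> int:
    if slice_size == "max":
        return 1                                  # one head-batch at a time, lowest memory
    if slice_size == "balanced":
        return max(1, sliceable_head_dim // 2)    # half the heads per slice
    return min(slice_size, sliceable_head_dim)    # explicit integer, clamped to the head count

assert resolve_slice_size("max", 8) == 1
assert resolve_slice_size("balanced", 8) == 4
assert resolve_slice_size(16, 8) == 8
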

From 719daebd187a4671d7d54d2b7b488cd858180716 Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Sat, 3 Aug 2024 01:28:24 +0300
Subject: [PATCH 12/25] Add torch-sdp scale parameter support (added in torch
 2.1)

---
 .../diffusion/custom_attention.py                 | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
index bf5e5fa5949..eed3bb3fb9a 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
@@ -44,6 +44,15 @@ def __init__(
         self._ip_adapter_attention_weights = ip_adapter_attention_weights
         self.attention_type, self.slice_size = self._select_attention()
 
+        # inspect.signature() doesn't work here because sdp is a native function.
+        # Torch 2.0 has no `scale` argument in sdp; it was added in torch 2.1.
+        # This could probably be selected based on the torch version instead.
+        try:
+            F.scaled_dot_product_attention(torch.zeros(1, 1), torch.zeros(1, 1), torch.zeros(1, 1), scale=0.5)
+            self.scaled_sdp = True
+        except Exception:
+            self.scaled_sdp = False
+
     def _select_attention(self):
         config = get_config()
         attention_type = config.attention_type
@@ -260,9 +269,11 @@ def run_attention_sdp(
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
+        scale_kwargs = {}
+        if self.scaled_sdp:
+            scale_kwargs["scale"] = attn.scale
         hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, **scale_kwargs
         )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
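
An alternative to the runtime probe above would be to gate on the torch version. A sketch only; it assumes the `packaging` dependency is available, which the patch does not require:

from packaging.version import Version
import torch

# `scale=` was added to F.scaled_dot_product_attention in torch 2.1.
scaled_sdp = Version(torch.__version__.split("+")[0]) >= Version("2.1.0")
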

From a16fa31479ee1e0a65f43e033aec65f36209ff08 Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Sat, 3 Aug 2024 01:28:48 +0300
Subject: [PATCH 13/25] Test implementation of sliced attention using torch-sdp

---
 .../diffusion/custom_attention.py             | 25 ++++++++++++++++---
 1 file changed, 21 insertions(+), 4 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
index eed3bb3fb9a..460910a1078 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
@@ -314,12 +314,29 @@ def run_attention_sliced(
 
             query_slice = query[start_idx:end_idx]
             key_slice = key[start_idx:end_idx]
+            value_slice = value[start_idx:end_idx]
             attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
 
-            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
-            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
-
-            hidden_states[start_idx:end_idx] = attn_slice
+            # TODO: compare speed/memory on mps
+            # cuda, sd1, 31 step, 1024x1024
+            # denoise_latents       1   19.667s     3.418G
+            # denoise_latents       1   11.601s     2.133G (sdp)
+            # cpu, sd1, 10 steps, 512x512
+            # denoise_latents       1   43.859s     0.000G
+            # denoise_latents       1   40.696s     0.000G (sdp)
+            if False:
+                attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
+                torch.bmm(attn_slice, value_slice, out=hidden_states[start_idx:end_idx])
+            else:
+                if attn_mask_slice is not None:
+                    attn_mask_slice = attn_mask_slice.unsqueeze(0)
+
+                scale_kwargs = {}
+                if self.scaled_sdp:
+                    scale_kwargs["scale"] = attn.scale
+                hidden_states[start_idx:end_idx] = F.scaled_dot_product_attention(
+                    query_slice.unsqueeze(0), key_slice.unsqueeze(0), value_slice.unsqueeze(0), attn_mask=attn_mask_slice, dropout_p=0.0, is_causal=False, **scale_kwargs
+                ).squeeze(0)
 
         hidden_states = attn.batch_to_head_dim(hidden_states)
         return hidden_states
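
A standalone sanity check (toy shapes, no mask, default 1/sqrt(head_dim) scale) that the per-slice SDPA path used above matches a softmax(QK^T)V reference; it is not taken from the patch:

import torch
import torch.nn.functional as F

q = torch.randn(2, 5, 8)
k = torch.randn(2, 7, 8)
v = torch.randn(2, 7, 8)

ref = torch.bmm(torch.softmax(torch.bmm(q, k.transpose(1, 2)) / 8**0.5, dim=-1), v)
sdp = F.scaled_dot_product_attention(q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)).squeeze(0)
assert torch.allclose(ref, sdp, atol=1e-5)
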

From 7ffceaa7ff27a0b24bb6f0e8c2970d51609eebe4 Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Sat, 3 Aug 2024 02:33:30 +0300
Subject: [PATCH 14/25] Fix slice_size handling

---
 invokeai/backend/stable_diffusion/diffusers_pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index ace300ee03d..42c2fdba397 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -179,7 +179,7 @@ def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
             slice_size = config.attention_slice_size
             if slice_size == "auto":
                 slice_size = auto_detect_slice_size(latents)
-            elif slice_size == "balanced":
+            if slice_size == "balanced":
                 slice_size = "auto"
             self.enable_attention_slicing(slice_size=slice_size)
             return

From c7e71038dd6836fd4c8376cbc639a99c00650a45 Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Sat, 3 Aug 2024 17:04:00 +0300
Subject: [PATCH 15/25] Revert "Fix bad generation on slice_size not factor of
 heads count"

This reverts commit bf2f798341a40ff58b7df9dc10463237593d2d64.
---
 .gitignore                                    |   8 -
 invokeai.yaml.bak                             |   6 -
 .../diffusion/custom_atttention.py            | 383 ------------------
 invokeai/frontend/web/scripts/typegen.js      |   2 +-
 invokeai/frontend/web/vite.config.mts         |   6 +-
 5 files changed, 4 insertions(+), 401 deletions(-)
 delete mode 100644 invokeai.yaml.bak
 delete mode 100644 invokeai/backend/stable_diffusion/diffusion/custom_atttention.py

diff --git a/.gitignore b/.gitignore
index a9739a7294b..29d27d78ed5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,13 +1,5 @@
 .idea/
 
-models/
-nodes/
-configs/
-databases/
-invokeai.yaml
-invokeai.example.yaml
-outputs/
-
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/invokeai.yaml.bak b/invokeai.yaml.bak
deleted file mode 100644
index b348590cae6..00000000000
--- a/invokeai.yaml.bak
+++ /dev/null
@@ -1,6 +0,0 @@
-# Internal metadata - do not edit:
-schema_version: 4.0.2
-
-# Put user settings here - see https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/:
-host: 0.0.0.0
-attention_type: torch-sdp
diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
deleted file mode 100644
index c5a48847f8d..00000000000
--- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
+++ /dev/null
@@ -1,383 +0,0 @@
-from dataclasses import dataclass
-from typing import List, Optional
-
-import torch
-import torch.nn.functional as F
-from diffusers.models.attention_processor import Attention
-
-import invokeai.backend.util.logging as logger
-from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
-from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
-from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
-from invokeai.backend.util.devices import TorchDevice
-
-
-@dataclass
-class IPAdapterAttentionWeights:
-    ip_adapter_weights: IPAttentionProcessorWeights
-    skip: bool
-
-
-class CustomAttnProcessor2_0:
-    """A custom implementation of attention processor that supports additional Invoke features.
-    This implementation is based on
-    SlicedAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1616)
-    AttnProcessor2_0 (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204)
-    Supported custom features:
-    - IP-Adapter
-    - Regional prompt attention
-    """
-
-    def __init__(
-        self,
-        ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None,
-    ):
-        """Initialize a CustomAttnProcessor.
-        Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
-        layer-specific are passed to __init__().
-        Args:
-            ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights
-                for the i'th IP-Adapter.
-        """
-
-        self._ip_adapter_attention_weights = ip_adapter_attention_weights
-        self.attention_type, self.slice_size = self._select_attention()
-
-    def _select_attention(self):
-        config = get_config()
-        attention_type = config.attention_type
-        if attention_type in ["normal", "xformers"]:
-            logger.warning(f'Attention "{attention_type}" no longer supported, "torch-sdp" will be used instead.')
-            attention_type = "torch-sdp"
-
-        if attention_type == "auto":
-            exec_device = TorchDevice.choose_torch_device()
-            if exec_device.type == "mps":
-                attention_type = "sliced"
-            else:
-                attention_type = "torch-sdp"
-
-        if attention_type == "torch-sdp" and not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("torch-sdp attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-        slice_size = None
-        if attention_type == "sliced":
-            slice_size = config.attention_slice_size
-            if slice_size not in ["auto", "balanced", "max"] and not isinstance(slice_size, int):
-                raise ValueError(f"Unsupported attention_slice_size: {slice_size}")
-            if slice_size == "balanced":
-                slice_size = "auto"
-
-        return attention_type, slice_size
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        temb: Optional[torch.Tensor] = None,
-        # For Regional Prompting:
-        regional_prompt_data: Optional[RegionalPromptData] = None,
-        percent_through: Optional[torch.Tensor] = None,
-        # For IP-Adapter:
-        regional_ip_data: Optional[RegionalIPData] = None,
-        *args,
-        **kwargs,
-    ) -> torch.Tensor:
-        # If true, we are doing cross-attention, if false we are doing self-attention.
-        is_cross_attention = encoder_hidden_states is not None
-
-        residual = hidden_states
-
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, key_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-        query_length = hidden_states.shape[1]
-
-        # Regional Prompt Attention Mask
-        if regional_prompt_data is not None and is_cross_attention:
-            prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
-                query_seq_len=query_length, key_seq_len=key_length
-            )
-
-            if attention_mask is None:
-                attention_mask = prompt_region_attention_mask
-            else:
-                attention_mask = prompt_region_attention_mask + attention_mask
-
-        attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size)
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        hidden_states = self.run_attention(
-            attn=attn,
-            query=query,
-            key=key,
-            value=value,
-            attention_mask=attention_mask,
-        )
-
-        if is_cross_attention:
-            hidden_states = self.run_ip_adapters(
-                attn=attn,
-                hidden_states=hidden_states,
-                regional_ip_data=regional_ip_data,
-                query_length=query_length,
-                query=query,
-            )
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-        return hidden_states
-
-    def run_ip_adapters(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        regional_ip_data: Optional[RegionalIPData],
-        query_length: int,  # TODO: just read from query?
-        query: torch.Tensor,
-    ) -> torch.Tensor:
-        if self._ip_adapter_attention_weights is None:
-            # If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
-            assert regional_ip_data is None
-            return hidden_states
-
-        assert regional_ip_data is not None
-        ip_masks = regional_ip_data.get_masks(query_seq_len=query_length)
-
-        assert (
-            len(regional_ip_data.image_prompt_embeds)
-            == len(self._ip_adapter_attention_weights)
-            == len(regional_ip_data.scales)
-            == ip_masks.shape[1]
-        )
-
-        for ipa_index, ip_hidden_states in enumerate(regional_ip_data.image_prompt_embeds):
-            # The batch dimensions should match.
-            # assert ip_hidden_states.shape[0] == encoder_hidden_states.shape[0]
-            # The token_len dimensions should match.
-            # assert ip_hidden_states.shape[-1] == encoder_hidden_states.shape[-1]
-
-            if self._ip_adapter_attention_weights[ipa_index].skip:
-                continue
-
-            ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights
-            ipa_scale = regional_ip_data.scales[ipa_index]
-            ip_mask = ip_masks[0, ipa_index, ...]
-
-            # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
-            ip_key = ipa_weights.to_k_ip(ip_hidden_states)
-            ip_value = ipa_weights.to_v_ip(ip_hidden_states)
-
-            ip_hidden_states = self.run_attention(
-                attn=attn,
-                query=query,
-                key=ip_key,
-                value=ip_value,
-                attention_mask=None,
-            )
-
-            # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
-            hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
-
-        return hidden_states
-
-    def run_attention(
-        self,
-        attn: Attention,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        attention_mask: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        if self.attention_type == "torch-sdp":
-            attn_call = self.run_attention_sdp
-        elif self.attention_type == "sliced":
-            attn_call = self.run_attention_sliced
-        else:
-            raise Exception(f"Unknown attention type: {self.attention_type}")
-
-        return attn_call(
-            attn=attn,
-            query=query,
-            key=key,
-            value=value,
-            attention_mask=attention_mask,
-        )
-
-    def run_attention_sdp(
-        self,
-        attn: Attention,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        attention_mask: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        batch_size = key.shape[0]
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attention_mask is not None:
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        return hidden_states
-
-    def run_attention_sliced(
-        self,
-        attn: Attention,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        attention_mask: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        if True:
-            func = self._run_attention_sliced_norm
-        else:
-            func = self._run_attention_sliced_sdp
-
-        return func(
-            attn=attn,
-            query=query,
-            key=key,
-            value=value,
-            attention_mask=attention_mask,
-        )
-
-    def _run_attention_sliced_norm(
-        self,
-        attn: Attention,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        attention_mask: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        # slice_size
-        if self.slice_size == "max":
-            slice_size = 1
-        elif self.slice_size == "auto":
-            slice_size = max(1, attn.sliceable_head_dim // 2)
-        else:
-            slice_size = min(self.slice_size, attn.sliceable_head_dim)
-
-        dim = query.shape[-1]
-
-        query = attn.head_to_batch_dim(query)
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
-        batch_size_attention, query_tokens, _ = query.shape
-        hidden_states = torch.zeros(
-            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
-        )
-
-        for i in range(batch_size_attention // slice_size):
-            start_idx = i * slice_size
-            end_idx = (i + 1) * slice_size
-
-            query_slice = query[start_idx:end_idx]
-            key_slice = key[start_idx:end_idx]
-            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
-
-            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
-            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
-
-            hidden_states[start_idx:end_idx] = attn_slice
-
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-        return hidden_states
-
-
-    def _run_attention_sliced_sdp(
-        self,
-        attn: Attention,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        attention_mask: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        # slice_size
-        if self.slice_size == "max":
-            slice_size = 1
-        elif self.slice_size == "auto":
-            slice_size = max(1, attn.sliceable_head_dim // 2)
-        else:
-            slice_size = min(self.slice_size, attn.sliceable_head_dim)
-
-        dim = query.shape[-1]
-
-        query = attn.head_to_batch_dim(query)
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
-        batch_size_attention, query_tokens, _ = query.shape
-        hidden_states = torch.zeros(
-            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
-        )
-
-        for i in range(batch_size_attention // slice_size):
-            start_idx = i * slice_size
-            end_idx = (i + 1) * slice_size
-
-            query_slice = query[start_idx:end_idx]
-            key_slice = key[start_idx:end_idx]
-            value_slice = value[start_idx:end_idx]
-            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
-
-            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-            # TODO: add support for attn.scale when we move to Torch 2.1
-            attn_slice = F.scaled_dot_product_attention(
-                query_slice, key_slice, value_slice, attn_mask=attn_mask_slice, dropout_p=0.0, is_causal=False
-            )
-
-            hidden_states[start_idx:end_idx] = attn_slice
-
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-        return hidden_states
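
Note: the two sliced variants removed above share one head-batch slicing pattern, which the later patches keep as the single sliced path. As a reference, here is a hedged standalone sketch of that pattern (not InvokeAI code); `attn` is assumed to be a diffusers `Attention` module providing `head_to_batch_dim`, `batch_to_head_dim`, and `get_attention_scores`.

```python
from typing import Optional

import torch
from diffusers.models.attention_processor import Attention


def sliced_attention_sketch(
    attn: Attention,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    slice_size: int,
) -> torch.Tensor:
    # Fold heads into the batch dimension: (batch, seq, dim) -> (batch * heads, seq, head_dim).
    query = attn.head_to_batch_dim(query)
    key = attn.head_to_batch_dim(key)
    value = attn.head_to_batch_dim(value)

    batch_size_attention, query_tokens, head_dim = query.shape
    hidden_states = torch.zeros(
        (batch_size_attention, query_tokens, head_dim), device=query.device, dtype=query.dtype
    )

    # Attend `slice_size` head-batches at a time, so the (seq x seq) score matrix is
    # only ever materialized for a slice, bounding peak memory.
    for start_idx in range(0, batch_size_attention, slice_size):
        end_idx = start_idx + slice_size
        mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

        probs = attn.get_attention_scores(query[start_idx:end_idx], key[start_idx:end_idx], mask_slice)
        hidden_states[start_idx:end_idx] = torch.bmm(probs, value[start_idx:end_idx])

    # Unfold heads back out of the batch dimension.
    return attn.batch_to_head_dim(hidden_states)
```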
diff --git a/invokeai/frontend/web/scripts/typegen.js b/invokeai/frontend/web/scripts/typegen.js
index 435c82a1abb..fa2d791350d 100644
--- a/invokeai/frontend/web/scripts/typegen.js
+++ b/invokeai/frontend/web/scripts/typegen.js
@@ -3,7 +3,7 @@ import fs from 'node:fs';
 
 import openapiTS from 'openapi-typescript';
 
-const OPENAPI_URL = 'http://192.168.5.199:9090/openapi.json';
+const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json';
 const OUTPUT_FILE = 'src/services/api/schema.ts';
 
 async function generateTypes(schema) {
diff --git a/invokeai/frontend/web/vite.config.mts b/invokeai/frontend/web/vite.config.mts
index 59a3cf1901f..a40c515465c 100644
--- a/invokeai/frontend/web/vite.config.mts
+++ b/invokeai/frontend/web/vite.config.mts
@@ -71,18 +71,18 @@ export default defineConfig(({ mode }) => {
       proxy: {
         // Proxy socket.io to the nodes socketio server
         '/ws/socket.io': {
-          target: 'ws://192.168.5.199:9090',
+          target: 'ws://127.0.0.1:9090',
           ws: true,
         },
         // Proxy openapi schema definiton
         '/openapi.json': {
-          target: 'http://192.168.5.199:9090/openapi.json',
+          target: 'http://127.0.0.1:9090/openapi.json',
           rewrite: (path) => path.replace(/^\/openapi.json/, ''),
           changeOrigin: true,
         },
         // proxy nodes api
         '/api/': {
-          target: 'http://192.168.5.199:9090/api/',
+          target: 'http://127.0.0.1:9090/api/',
           rewrite: (path) => path.replace(/^\/api/, ''),
           changeOrigin: true,
         },

From 302dc9faeec3116987d626418a50d2bb3f53868e Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Sun, 4 Aug 2024 02:05:32 +0300
Subject: [PATCH 16/25] Return normal attention, change slicing logic, remove
 old attention code

---
 .../app/services/config/config_default.py     |  33 ++-
 .../stable_diffusion/diffusers_pipeline.py    |  81 +-----
 .../diffusion/custom_attention.py             | 152 +++++++----
 .../diffusion/unet_attention_patcher.py       |   2 +-
 .../multi_diffusion_pipeline.py               | 239 +++++++++---------
 5 files changed, 253 insertions(+), 254 deletions(-)

diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index 36cb56c9dbe..352b1b46830 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -13,6 +13,7 @@
 from typing import Any, Literal, Optional
 
 import psutil
+import torch
 import yaml
 from pydantic import BaseModel, Field, PrivateAttr, field_validator
 from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
@@ -28,8 +29,8 @@
 DEFAULT_VRAM_CACHE = 0.25
 DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
 PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
-ATTENTION_TYPE = Literal["auto", "sliced", "torch-sdp"]
-ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
+ATTENTION_TYPE = Literal["auto", "normal", "torch-sdp"]
+ATTENTION_SLICE_SIZE = Literal["auto", "none", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
 LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
 LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
 CONFIG_SCHEMA_VERSION = "4.0.3"
@@ -181,7 +182,7 @@ class InvokeAIAppConfig(BaseSettings):
     # GENERATION
     sequential_guidance:           bool = Field(default=False,              description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.")
     attention_type:      ATTENTION_TYPE = Field(default="auto",             description="Attention type.")
-    attention_slice_size: ATTENTION_SLICE_SIZE = Field(default="auto",      description='Slice size, valid when attention_type=="sliced".')
+    attention_slice_size: ATTENTION_SLICE_SIZE = Field(default="auto",      description='Slice size')
     force_tiled_decode:            bool = Field(default=False,              description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
     pil_compress_level:             int = Field(default=1,                  description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
     max_queue_size:                 int = Field(default=10000, gt=0,        description="Maximum number of items in the session queue.")
@@ -443,10 +444,30 @@ def migrate_v4_0_2_to_4_0_3_config_dict(config_dict: dict[str, Any]) -> dict[str
         An config dict with the settings migrated to v4.0.3.
     """
     parsed_config_dict: dict[str, Any] = copy.deepcopy(config_dict)
-    # normal and xformers attentions removed in 4.0.3
     attention_type = parsed_config_dict.get("attention_type", None)
-    if attention_type in ["normal", "xformers"]:
-        parsed_config_dict["attention_type"] = "torch-sdp"
+
+    # attention_slice_size now controls whether attention slicing is enabled
+    if attention_type != "sliced" and "attention_slice_size" in parsed_config_dict:
+        del parsed_config_dict["attention_slice_size"]
+
+    # xformers attention removed; normal attention works better on mps
+    if attention_type == "xformers":
+        if torch.backends.mps.is_available():
+            parsed_config_dict["attention_type"] = "normal"
+        else:
+            parsed_config_dict["attention_type"] = "torch-sdp"
+
+    # slicing attention now enabled by `attention_slice_size`
+    if attention_type == "sliced":
+        if torch.backends.mps.is_available():
+            parsed_config_dict["attention_type"] = "normal"
+        else:
+            parsed_config_dict["attention_type"] = "torch-sdp"
+
+        # if no attention_slice_size in config, use balanced as the default option
+        if "attention_slice_size" not in parsed_config_dict:
+            parsed_config_dict["attention_slice_size"] = "balanced"
+
     parsed_config_dict["schema_version"] = "4.0.3"
     return parsed_config_dict
 
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 42c2fdba397..6c2dca11f37 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -1,16 +1,13 @@
 from __future__ import annotations
 
 import math
-from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import Any, Callable, List, Optional, Union
 
 import einops
 import PIL.Image
-import psutil
 import torch
 import torchvision.transforms as T
-from diffusers.models.attention_processor import AttnProcessor2_0
 from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
@@ -19,14 +16,10 @@
 from pydantic import Field
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
-import invokeai.backend.util.logging as logger
-from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
 from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState
-from invokeai.backend.util.attention import auto_detect_slice_size
-from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.hotfixes import ControlNetModel
 
 
@@ -168,55 +161,6 @@ def __init__(
 
         self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
 
-    def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
-        config = get_config()
-        attention_type = config.attention_type
-        if attention_type in ["normal", "xformers"]:
-            logger.warning(f'Attention "{attention_type}" no longer supported, "torch-sdp" will be used instead.')
-            attention_type = "torch-sdp"
-
-        if config.attention_type == "sliced":
-            slice_size = config.attention_slice_size
-            if slice_size == "auto":
-                slice_size = auto_detect_slice_size(latents)
-            if slice_size == "balanced":
-                slice_size = "auto"
-            self.enable_attention_slicing(slice_size=slice_size)
-            return
-        elif config.attention_type == "torch-sdp":
-            self.unet.set_attn_processor(AttnProcessor2_0())
-            return
-
-        # the remainder if this code is called when attention_type=='auto'
-        if self.unet.device.type == "cuda":
-            self.unet.set_attn_processor(AttnProcessor2_0())
-            return
-
-        if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
-            mem_free = psutil.virtual_memory().free
-        elif self.unet.device.type == "cuda":
-            mem_free, _ = torch.cuda.mem_get_info(TorchDevice.normalize(self.unet.device))
-        else:
-            raise ValueError(f"unrecognized device {self.unet.device}")
-        # input tensor of [1, 4, h/8, w/8]
-        # output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
-        bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
-        max_size_required_for_baddbmm = (
-            16
-            * latents.size(dim=2)
-            * latents.size(dim=3)
-            * latents.size(dim=2)
-            * latents.size(dim=3)
-            * bytes_per_element_needed_for_baddbmm_duplication
-        )
-        if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0):  # 3.3 / 4.0 is from old Invoke code
-            self.enable_attention_slicing(slice_size="max")
-        elif torch.backends.mps.is_available():
-            # diffusers recommends always enabling for mps
-            self.enable_attention_slicing(slice_size="max")
-        else:
-            self.unet.set_attn_processor(AttnProcessor2_0())
-
     def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
         raise Exception("Should not be called")
 
@@ -335,25 +279,14 @@ def latents_from_embeddings(
                 is_gradient_mask=is_gradient_mask,
             )
 
-        use_ip_adapter = ip_adapter_data is not None
-        use_regional_prompting = (
-            conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
-        )
-        unet_attention_patcher = None
-        attn_ctx = nullcontext()
-
-        if use_ip_adapter or use_regional_prompting:
-            ip_adapters: Optional[List[UNetIPAdapterData]] = (
-                [{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
-                if use_ip_adapter
-                else None
-            )
-            unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
-            attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
-        else:
-            self._adjust_memory_efficient_attention(latents)
+        ip_adapters: Optional[List[UNetIPAdapterData]] = None
+        if ip_adapter_data is not None:
+            ip_adapters = [
+                {"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data
+            ]
 
-        with attn_ctx:
+        unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
+        with unet_attention_patcher.apply_custom_attention(self.invokeai_diffuser.model):
             callback(
                 PipelineIntermediateState(
                     step=-1,
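
Note: the net effect of this hunk is that the custom attention processors are now always installed for the duration of denoising, with or without IP-Adapters or regional prompts. A hedged usage sketch of the new call pattern follows; `run_denoising_loop` is a hypothetical callable standing in for the loop body of `latents_from_embeddings`.

```python
from typing import Callable, List, Optional

from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel

from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import (
    UNetAttentionPatcher,
    UNetIPAdapterData,
)


def denoise_with_custom_attention(
    unet: UNet2DConditionModel,
    ip_adapter_data: Optional[list],
    run_denoising_loop: Callable[[], None],
) -> None:
    ip_adapters: Optional[List[UNetIPAdapterData]] = None
    if ip_adapter_data is not None:
        ip_adapters = [
            {"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data
        ]

    # Applied unconditionally: even with ip_adapters=None, the patcher swaps in the
    # custom processors that implement the configured attention type and slicing.
    patcher = UNetAttentionPatcher(ip_adapters)
    with patcher.apply_custom_attention(unet):
        run_denoising_loop()
```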
diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
index 460910a1078..370c10b1d0b 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
@@ -1,11 +1,13 @@
+import math
 from dataclasses import dataclass
 from typing import List, Optional
 
+import psutil
 import torch
 import torch.nn.functional as F
 from diffusers.models.attention_processor import Attention
+from packaging.version import Version
 
-import invokeai.backend.util.logging as logger
 from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
 from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
@@ -42,43 +44,48 @@ def __init__(
         """
 
         self._ip_adapter_attention_weights = ip_adapter_attention_weights
-        self.attention_type, self.slice_size = self._select_attention()
-
-        # inspect didn't work because it's native function
-        # In 2.0 torch there no scale argument in sdp, it's added in 2.1
-        # Probably can selected based on torch version instead
-        try:
-            F.scaled_dot_product_attention(torch.zeros(1,1), torch.zeros(1,1), torch.zeros(1,1), scale=0.5)
-            self.scaled_sdp = True
-        except:
-            self.scaled_sdp = False
-
-    def _select_attention(self):
+
         config = get_config()
-        attention_type = config.attention_type
-        if attention_type in ["normal", "xformers"]:
-            logger.warning(f'Attention "{attention_type}" no longer supported, "torch-sdp" will be used instead.')
-            attention_type = "torch-sdp"
-
-        if attention_type == "auto":
-            exec_device = TorchDevice.choose_torch_device()
-            if exec_device.type == "mps":
-                attention_type = "sliced"
-            else:
-                attention_type = "torch-sdp"
+        self.attention_type = config.attention_type
+        if self.attention_type == "auto":
+            self.attention_type = self._select_attention_type()
 
-        if attention_type == "torch-sdp" and not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("torch-sdp attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+        self.slice_size = config.attention_slice_size
+        if self.slice_size == "auto":
+            self.slice_size = self._select_slice_size()
 
-        slice_size = None
-        if attention_type == "sliced":
-            slice_size = config.attention_slice_size
-            if slice_size not in ["auto", "balanced", "max"] and not isinstance(slice_size, int):
-                raise ValueError(f"Unsupported attention_slice_size: {slice_size}")
-            if slice_size == "auto":
-                slice_size = "balanced"
+        if self.attention_type == "torch-sdp" and not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("torch-sdp attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
-        return attention_type, slice_size
+        # In torch 2.0 there is no `scale` argument in sdp; it was added in 2.1
+        self.scaled_sdp = Version(torch.__version__) >= Version("2.1")
+
+    def _select_attention_type(self) -> str:
+        device = TorchDevice.choose_torch_device()
+        # On some mps system normal attention still faster than torch-sdp on others - on par
+        if device.type == "mps":
+            return "normal"
+        else:  # cuda, cpu
+            return "torch-sdp"
+
+    def _select_slice_size(self) -> str:
+        device = TorchDevice.choose_torch_device()
+        if device.type in ["cpu", "mps"]:
+            total_ram_gb = math.ceil(psutil.virtual_memory().total / 2**30)
+            if total_ram_gb <= 16:
+                return "max"
+            if total_ram_gb <= 32:
+                return "balanced"
+            return "none"
+        elif device.type == "cuda":
+            total_vram_gb = math.ceil(torch.cuda.get_device_properties(device).total_memory / 2**30)
+            if total_vram_gb <= 4:
+                return "max"
+            if total_vram_gb <= 6:
+                return "balanced"
+            return "none"
+        else:
+            raise ValueError(f"Unknown device: {device.type}")
 
     def __call__(
         self,
@@ -224,6 +231,19 @@ def run_ip_adapters(
 
         return hidden_states
 
+    def _get_slice_size(self, attn) -> Optional[int]:
+        if self.slice_size == "none":
+            return None
+        if isinstance(self.slice_size, int):
+            return self.slice_size
+
+        if self.slice_size == "max":
+            return 1
+        if self.slice_size == "balanced":
+            return max(1, attn.sliceable_head_dim // 2)
+
+        raise ValueError(f"Incorrect slice_size value: {self.slice_size}")
+
     def run_attention(
         self,
         attn: Attention,
@@ -232,10 +252,21 @@ def run_attention(
         value: torch.Tensor,
         attention_mask: Optional[torch.Tensor],
     ) -> torch.Tensor:
+        slice_size = self._get_slice_size(attn)
+        if slice_size is not None:
+            return self.run_attention_sliced(
+                attn=attn,
+                query=query,
+                key=key,
+                value=value,
+                attention_mask=attention_mask,
+                slice_size=slice_size,
+            )
+
         if self.attention_type == "torch-sdp":
             attn_call = self.run_attention_sdp
-        elif self.attention_type == "sliced":
-            attn_call = self.run_attention_sliced
+        elif self.attention_type == "normal":
+            attn_call = self.run_attention_normal
         else:
             raise Exception(f"Unknown attention type: {self.attention_type}")
 
@@ -247,6 +278,24 @@ def run_attention(
             attention_mask=attention_mask,
         )
 
+    def run_attention_normal(
+        self,
+        attn: Attention,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        return hidden_states
+
     def run_attention_sdp(
         self,
         attn: Attention,
@@ -288,15 +337,8 @@ def run_attention_sliced(
         key: torch.Tensor,
         value: torch.Tensor,
         attention_mask: Optional[torch.Tensor],
+        slice_size: int,
     ) -> torch.Tensor:
-        # slice_size
-        if self.slice_size == "max":
-            slice_size = 1
-        elif self.slice_size == "balanced":
-            slice_size = max(1, attn.sliceable_head_dim // 2)
-        else:
-            slice_size = min(self.slice_size, attn.sliceable_head_dim)
-
         dim = query.shape[-1]
 
         query = attn.head_to_batch_dim(query)
@@ -317,17 +359,11 @@ def run_attention_sliced(
             value_slice = value[start_idx:end_idx]
             attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
 
-            # TODO: compare speed/memory on mps
-            # cuda, sd1, 31 step, 1024x1024
-            # denoise_latents       1   19.667s     3.418G
-            # denoise_latents       1   11.601s     2.133G (sdp)
-            # cpu, sd1, 10 steps, 512x512
-            # denoise_latents       1   43.859s     0.000G
-            # denoise_latents       1   40.696s     0.000G (sdp)
-            if False:
+            if self.attention_type == "normal":
                 attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
                 torch.bmm(attn_slice, value_slice, out=hidden_states[start_idx:end_idx])
-            else:
+                del attn_slice
+            elif self.attention_type == "torch-sdp":
                 if attn_mask_slice is not None:
                     attn_mask_slice = attn_mask_slice.unsqueeze(0)
 
@@ -335,8 +371,16 @@ def run_attention_sliced(
                 if self.scaled_sdp:
                     scale_kwargs["scale"] = attn.scale
                 hidden_states[start_idx:end_idx] = F.scaled_dot_product_attention(
-                    query_slice.unsqueeze(0), key_slice.unsqueeze(0), value_slice.unsqueeze(0), attn_mask=attn_mask_slice, dropout_p=0.0, is_causal=False, **scale_kwargs
+                    query_slice.unsqueeze(0),
+                    key_slice.unsqueeze(0),
+                    value_slice.unsqueeze(0),
+                    attn_mask=attn_mask_slice,
+                    dropout_p=0.0,
+                    is_causal=False,
+                    **scale_kwargs,
                 ).squeeze(0)
+            else:
+                raise ValueError(f"Unknown attention type: {self.attention_type}")
 
         hidden_states = attn.batch_to_head_dim(hidden_states)
         return hidden_states
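
Note: the auto-selection introduced here splits into two steps: pick a slice-size label from total system or GPU memory, then resolve that label to a concrete number of head-batches per slice inside the processor. A minimal sketch of both steps, assuming the same thresholds as `_select_slice_size` and `_get_slice_size` above:

```python
import math
from typing import Optional, Union

import psutil
import torch


def select_slice_size_label(device: torch.device) -> str:
    # Smaller memory budgets get more aggressive slicing; large ones disable it.
    if device.type in ("cpu", "mps"):
        total_gb = math.ceil(psutil.virtual_memory().total / 2**30)
        low, mid = 16, 32  # system RAM thresholds (GB)
    elif device.type == "cuda":
        total_gb = math.ceil(torch.cuda.get_device_properties(device).total_memory / 2**30)
        low, mid = 4, 6  # VRAM thresholds (GB)
    else:
        raise ValueError(f"Unknown device: {device.type}")

    if total_gb <= low:
        return "max"
    if total_gb <= mid:
        return "balanced"
    return "none"


def resolve_slice_size(slice_size: Union[str, int], sliceable_head_dim: int) -> Optional[int]:
    # None means "do not slice"; "max" is one head-batch per slice (lowest memory);
    # "balanced" attends half of the sliceable head dimension at a time.
    if slice_size == "none":
        return None
    if isinstance(slice_size, int):
        return slice_size
    if slice_size == "max":
        return 1
    if slice_size == "balanced":
        return max(1, sliceable_head_dim // 2)
    raise ValueError(f"Incorrect slice_size value: {slice_size}")
```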
diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
index ce45ac157c2..8ba8b3acf38 100644
--- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
+++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
@@ -56,7 +56,7 @@ def _prepare_attention_processors(self, unet: UNet2DConditionModel):
         return attn_procs
 
     @contextmanager
-    def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
+    def apply_custom_attention(self, unet: UNet2DConditionModel):
         """A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers."""
         attn_procs = self._prepare_attention_processors(unet)
         orig_attn_processors = unet.attn_processors
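
Note: `apply_custom_attention` keeps the patching reversible by restoring the original processors on exit. A hedged sketch of that pattern with a hypothetical helper name, relying only on the standard diffusers `attn_processors` / `set_attn_processor` API:

```python
from contextlib import contextmanager
from typing import Dict, Iterator

from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel


@contextmanager
def swap_attn_processors(unet: UNet2DConditionModel, new_procs: Dict[str, object]) -> Iterator[None]:
    # Hypothetical helper mirroring apply_custom_attention: install the custom
    # processors for the duration of the block, then restore the originals even
    # if the denoising loop raises.
    orig_procs = unet.attn_processors
    try:
        unet.set_attn_processor(new_procs)
        yield
    finally:
        unet.set_attn_processor(orig_procs)
```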
diff --git a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py
index 6c07fc1c2c8..f78facaf581 100644
--- a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py
+++ b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py
@@ -13,6 +13,7 @@
     StableDiffusionGeneratorPipeline,
 )
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
+from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
 from invokeai.backend.tiles.utils import Tile
 
 
@@ -63,132 +64,132 @@ def multi_diffusion_denoise(
             latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
             assert isinstance(latents, torch.Tensor)  # For static type checking.
 
-        # TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after
-        # cropping into regions.
-        self._adjust_memory_efficient_attention(latents)
-
-        # Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
-        # we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
-        # separate scheduler state for each region batch.
-        # TODO(ryand): This solution allows all schedulers to **run**, but does not fully solve the issue of scheduler
-        # statefulness. Some schedulers store previous model outputs in their state, but these values become incorrect
-        # as Multi-Diffusion blending is applied (e.g. the PNDMScheduler). This can result in a blurring effect when
-        # multiple MultiDiffusion regions overlap. Solving this properly would require a case-by-case review of each
-        # scheduler to determine how it's state needs to be updated for compatibilty with Multi-Diffusion.
-        region_batch_schedulers: list[SchedulerMixin] = [
-            copy.deepcopy(self.scheduler) for _ in multi_diffusion_conditioning
-        ]
-
-        callback(
-            PipelineIntermediateState(
-                step=-1,
-                order=self.scheduler.order,
-                total_steps=len(timesteps),
-                timestep=self.scheduler.config.num_train_timesteps,
-                latents=latents,
-            )
-        )
-
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            batched_t = t.expand(batch_size)
-
-            merged_latents = torch.zeros_like(latents)
-            merged_latents_weights = torch.zeros(
-                (1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
-            )
-            merged_pred_original: torch.Tensor | None = None
-            for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning):
-                # Switch to the scheduler for the region batch.
-                self.scheduler = region_batch_schedulers[region_idx]
-
-                # Crop the inputs to the region.
-                region_latents = latents[
-                    :,
-                    :,
-                    region_conditioning.region.coords.top : region_conditioning.region.coords.bottom,
-                    region_conditioning.region.coords.left : region_conditioning.region.coords.right,
-                ]
-
-                # Run the denoising step on the region.
-                step_output = self.step(
-                    t=batched_t,
-                    latents=region_latents,
-                    conditioning_data=region_conditioning.text_conditioning_data,
-                    step_index=i,
-                    total_step_count=len(timesteps),
-                    scheduler_step_kwargs=scheduler_step_kwargs,
-                    mask_guidance=None,
-                    mask=None,
-                    masked_latents=None,
-                    control_data=region_conditioning.control_data,
-                )
-
-                # Build a region_weight matrix that applies gradient blending to the edges of the region.
-                region = region_conditioning.region
-                _, _, region_height, region_width = step_output.prev_sample.shape
-                region_weight = torch.ones(
-                    (1, 1, region_height, region_width),
-                    dtype=latents.dtype,
-                    device=latents.device,
-                )
-                if region.overlap.left > 0:
-                    left_grad = torch.linspace(
-                        0, 1, region.overlap.left, device=latents.device, dtype=latents.dtype
-                    ).view((1, 1, 1, -1))
-                    region_weight[:, :, :, : region.overlap.left] *= left_grad
-                if region.overlap.top > 0:
-                    top_grad = torch.linspace(
-                        0, 1, region.overlap.top, device=latents.device, dtype=latents.dtype
-                    ).view((1, 1, -1, 1))
-                    region_weight[:, :, : region.overlap.top, :] *= top_grad
-                if region.overlap.right > 0:
-                    right_grad = torch.linspace(
-                        1, 0, region.overlap.right, device=latents.device, dtype=latents.dtype
-                    ).view((1, 1, 1, -1))
-                    region_weight[:, :, :, -region.overlap.right :] *= right_grad
-                if region.overlap.bottom > 0:
-                    bottom_grad = torch.linspace(
-                        1, 0, region.overlap.bottom, device=latents.device, dtype=latents.dtype
-                    ).view((1, 1, -1, 1))
-                    region_weight[:, :, -region.overlap.bottom :, :] *= bottom_grad
-
-                # Update the merged results with the region results.
-                merged_latents[
-                    :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
-                ] += step_output.prev_sample * region_weight
-                merged_latents_weights[
-                    :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
-                ] += region_weight
-
-                pred_orig_sample = getattr(step_output, "pred_original_sample", None)
-                if pred_orig_sample is not None:
-                    # If one region has pred_original_sample, then we can assume that all regions will have it, because
-                    # they all use the same scheduler.
-                    if merged_pred_original is None:
-                        merged_pred_original = torch.zeros_like(latents)
-                    merged_pred_original[
-                        :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
-                    ] += pred_orig_sample
-
-            # Normalize the merged results.
-            latents = torch.where(merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents)
-            # For debugging, uncomment this line to visualize the region seams:
-            # latents = torch.where(merged_latents_weights > 1, 0.0, latents)
-            predicted_original = None
-            if merged_pred_original is not None:
-                predicted_original = torch.where(
-                    merged_latents_weights > 0, merged_pred_original / merged_latents_weights, merged_pred_original
-                )
+        unet_attention_patcher = UNetAttentionPatcher(ip_adapter_data=None)
+        with unet_attention_patcher.apply_custom_attention(self.invokeai_diffuser.model):
+            # Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
+            # we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
+            # separate scheduler state for each region batch.
+            # TODO(ryand): This solution allows all schedulers to **run**, but does not fully solve the issue of scheduler
+            # statefulness. Some schedulers store previous model outputs in their state, but these values become incorrect
+            # as Multi-Diffusion blending is applied (e.g. the PNDMScheduler). This can result in a blurring effect when
+            # multiple MultiDiffusion regions overlap. Solving this properly would require a case-by-case review of each
+            # scheduler to determine how its state needs to be updated for compatibility with Multi-Diffusion.
+            region_batch_schedulers: list[SchedulerMixin] = [
+                copy.deepcopy(self.scheduler) for _ in multi_diffusion_conditioning
+            ]
 
             callback(
                 PipelineIntermediateState(
-                    step=i,
+                    step=-1,
                     order=self.scheduler.order,
                     total_steps=len(timesteps),
-                    timestep=int(t),
+                    timestep=self.scheduler.config.num_train_timesteps,
                     latents=latents,
-                    predicted_original=predicted_original,
                 )
             )
 
+            for i, t in enumerate(self.progress_bar(timesteps)):
+                batched_t = t.expand(batch_size)
+
+                merged_latents = torch.zeros_like(latents)
+                merged_latents_weights = torch.zeros(
+                    (1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
+                )
+                merged_pred_original: torch.Tensor | None = None
+                for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning):
+                    # Switch to the scheduler for the region batch.
+                    self.scheduler = region_batch_schedulers[region_idx]
+
+                    # Crop the inputs to the region.
+                    region_latents = latents[
+                        :,
+                        :,
+                        region_conditioning.region.coords.top : region_conditioning.region.coords.bottom,
+                        region_conditioning.region.coords.left : region_conditioning.region.coords.right,
+                    ]
+
+                    # Run the denoising step on the region.
+                    step_output = self.step(
+                        t=batched_t,
+                        latents=region_latents,
+                        conditioning_data=region_conditioning.text_conditioning_data,
+                        step_index=i,
+                        total_step_count=len(timesteps),
+                        scheduler_step_kwargs=scheduler_step_kwargs,
+                        mask_guidance=None,
+                        mask=None,
+                        masked_latents=None,
+                        control_data=region_conditioning.control_data,
+                    )
+
+                    # Build a region_weight matrix that applies gradient blending to the edges of the region.
+                    region = region_conditioning.region
+                    _, _, region_height, region_width = step_output.prev_sample.shape
+                    region_weight = torch.ones(
+                        (1, 1, region_height, region_width),
+                        dtype=latents.dtype,
+                        device=latents.device,
+                    )
+                    if region.overlap.left > 0:
+                        left_grad = torch.linspace(
+                            0, 1, region.overlap.left, device=latents.device, dtype=latents.dtype
+                        ).view((1, 1, 1, -1))
+                        region_weight[:, :, :, : region.overlap.left] *= left_grad
+                    if region.overlap.top > 0:
+                        top_grad = torch.linspace(
+                            0, 1, region.overlap.top, device=latents.device, dtype=latents.dtype
+                        ).view((1, 1, -1, 1))
+                        region_weight[:, :, : region.overlap.top, :] *= top_grad
+                    if region.overlap.right > 0:
+                        right_grad = torch.linspace(
+                            1, 0, region.overlap.right, device=latents.device, dtype=latents.dtype
+                        ).view((1, 1, 1, -1))
+                        region_weight[:, :, :, -region.overlap.right :] *= right_grad
+                    if region.overlap.bottom > 0:
+                        bottom_grad = torch.linspace(
+                            1, 0, region.overlap.bottom, device=latents.device, dtype=latents.dtype
+                        ).view((1, 1, -1, 1))
+                        region_weight[:, :, -region.overlap.bottom :, :] *= bottom_grad
+
+                    # Update the merged results with the region results.
+                    merged_latents[
+                        :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
+                    ] += step_output.prev_sample * region_weight
+                    merged_latents_weights[
+                        :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
+                    ] += region_weight
+
+                    pred_orig_sample = getattr(step_output, "pred_original_sample", None)
+                    if pred_orig_sample is not None:
+                        # If one region has pred_original_sample, then we can assume that all regions will have it, because
+                        # they all use the same scheduler.
+                        if merged_pred_original is None:
+                            merged_pred_original = torch.zeros_like(latents)
+                        merged_pred_original[
+                            :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
+                        ] += pred_orig_sample
+
+                # Normalize the merged results.
+                latents = torch.where(
+                    merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents
+                )
+                # For debugging, uncomment this line to visualize the region seams:
+                # latents = torch.where(merged_latents_weights > 1, 0.0, latents)
+                predicted_original = None
+                if merged_pred_original is not None:
+                    predicted_original = torch.where(
+                        merged_latents_weights > 0, merged_pred_original / merged_latents_weights, merged_pred_original
+                    )
+
+                callback(
+                    PipelineIntermediateState(
+                        step=i,
+                        order=self.scheduler.order,
+                        total_steps=len(timesteps),
+                        timestep=int(t),
+                        latents=latents,
+                        predicted_original=predicted_original,
+                    )
+                )
+
         return latents
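
Note: in the region blending above, each tile's prediction is weighted by a mask that ramps linearly from 0 to 1 across any edge that overlaps a neighbouring tile, so overlapping predictions cross-fade instead of leaving hard seams. A standalone sketch of that weight construction follows; the helper name is hypothetical, and the real code builds the mask inline from `region.overlap`.

```python
import torch


def region_blend_weight(
    height: int,
    width: int,
    overlap_left: int,
    overlap_top: int,
    overlap_right: int,
    overlap_bottom: int,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    # Start from uniform weight and taper each overlapping edge with a linear ramp.
    weight = torch.ones((1, 1, height, width), device=device, dtype=dtype)
    if overlap_left > 0:
        ramp = torch.linspace(0, 1, overlap_left, device=device, dtype=dtype).view(1, 1, 1, -1)
        weight[:, :, :, :overlap_left] *= ramp
    if overlap_top > 0:
        ramp = torch.linspace(0, 1, overlap_top, device=device, dtype=dtype).view(1, 1, -1, 1)
        weight[:, :, :overlap_top, :] *= ramp
    if overlap_right > 0:
        ramp = torch.linspace(1, 0, overlap_right, device=device, dtype=dtype).view(1, 1, 1, -1)
        weight[:, :, :, -overlap_right:] *= ramp
    if overlap_bottom > 0:
        ramp = torch.linspace(1, 0, overlap_bottom, device=device, dtype=dtype).view(1, 1, -1, 1)
        weight[:, :, -overlap_bottom:, :] *= ramp
    return weight
```

Accumulating `prediction * weight` together with the weights themselves and then dividing wherever the weight sum is positive reproduces the normalization step at the end of the loop.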

From 18fc36dbcd88cdfa69bc56ca7b9cd33ea9053aac Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Sun, 4 Aug 2024 04:26:05 +0300
Subject: [PATCH 17/25] Suggested changes

Co-Authored-By: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
---
 .../app/services/config/config_default.py     | 21 +++----------------
 .../diffusion/custom_attention.py             | 19 +++--------------
 2 files changed, 6 insertions(+), 34 deletions(-)

diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index 352b1b46830..7624d6a22dd 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -13,7 +13,6 @@
 from typing import Any, Literal, Optional
 
 import psutil
-import torch
 import yaml
 from pydantic import BaseModel, Field, PrivateAttr, field_validator
 from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
@@ -450,23 +449,9 @@ def migrate_v4_0_2_to_4_0_3_config_dict(config_dict: dict[str, Any]) -> dict[str
     if attention_type != "sliced" and "attention_slice_size" in parsed_config_dict:
         del parsed_config_dict["attention_slice_size"]
 
-    # xformers attention removed; normal attention works better on mps
-    if attention_type == "xformers":
-        if torch.backends.mps.is_available():
-            parsed_config_dict["attention_type"] = "normal"
-        else:
-            parsed_config_dict["attention_type"] = "torch-sdp"
-
-    # slicing attention now enabled by `attention_slice_size`
-    if attention_type == "sliced":
-        if torch.backends.mps.is_available():
-            parsed_config_dict["attention_type"] = "normal"
-        else:
-            parsed_config_dict["attention_type"] = "torch-sdp"
-
-        # if no attention_slice_size in config, use balanced as the default option
-        if "attention_slice_size" not in parsed_config_dict:
-            parsed_config_dict["attention_slice_size"] = "balanced"
+    # xformers attention removed, sliced moved to attention_slice_size
+    if attention_type in ["sliced", "xformers"]:
+        parsed_config_dict["attention_type"] = "auto"
 
     parsed_config_dict["schema_version"] = "4.0.3"
     return parsed_config_dict
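
Note: after this simplification, the v4.0.2 to v4.0.3 attention migration reduces to dropping a stale slice size and mapping the removed attention types back to automatic selection. A hedged sketch of just that piece (the real function also deep-copies the dict and bumps `schema_version`):

```python
from typing import Any, Dict


def migrate_attention_settings(config_dict: Dict[str, Any]) -> Dict[str, Any]:
    attention_type = config_dict.get("attention_type", None)

    # attention_slice_size only carried meaning for the old "sliced" type.
    if attention_type != "sliced" and "attention_slice_size" in config_dict:
        del config_dict["attention_slice_size"]

    # "xformers" was removed and "sliced" is now expressed via attention_slice_size,
    # so both fall back to automatic selection.
    if attention_type in ("sliced", "xformers"):
        config_dict["attention_type"] = "auto"

    return config_dict
```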
diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
index 370c10b1d0b..0ef6deb3988 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
@@ -6,7 +6,6 @@
 import torch
 import torch.nn.functional as F
 from diffusers.models.attention_processor import Attention
-from packaging.version import Version
 
 from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
@@ -54,12 +53,6 @@ def __init__(
         if self.slice_size == "auto":
             self.slice_size = self._select_slice_size()
 
-        if self.attention_type == "torch-sdp" and not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("torch-sdp attention requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-        # In torch 2.0 there is no `scale` argument in sdp; it was added in 2.1
-        self.scaled_sdp = Version(torch.__version__) >= Version("2.1")
-
     def _select_attention_type(self) -> str:
         device = TorchDevice.choose_torch_device()
         # On some mps system normal attention still faster than torch-sdp on others - on par
@@ -231,7 +224,7 @@ def run_ip_adapters(
 
         return hidden_states
 
-    def _get_slice_size(self, attn) -> Optional[int]:
+    def _get_slice_size(self, attn: Attention) -> Optional[int]:
         if self.slice_size == "none":
             return None
         if isinstance(self.slice_size, int):
@@ -318,11 +311,8 @@ def run_attention_sdp(
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        scale_kwargs = {}
-        if self.scaled_sdp:
-            scale_kwargs["scale"] = attn.scale
         hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, **scale_kwargs
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
         )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -367,9 +357,6 @@ def run_attention_sliced(
                 if attn_mask_slice is not None:
                     attn_mask_slice = attn_mask_slice.unsqueeze(0)
 
-                scale_kwargs = {}
-                if self.scaled_sdp:
-                    scale_kwargs["scale"] = attn.scale
                 hidden_states[start_idx:end_idx] = F.scaled_dot_product_attention(
                     query_slice.unsqueeze(0),
                     key_slice.unsqueeze(0),
@@ -377,7 +364,7 @@ def run_attention_sliced(
                     attn_mask=attn_mask_slice,
                     dropout_p=0.0,
                     is_causal=False,
-                    **scale_kwargs,
+                    scale=attn.scale,
                 ).squeeze(0)
             else:
                 raise ValueError(f"Unknown attention type: {self.attention_type}")

From f44e0cd01423382b77e0a16ac0ddd22168711ce6 Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Sun, 4 Aug 2024 13:17:41 +0300
Subject: [PATCH 18/25] Update config docstring

Co-Authored-By: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
---
 invokeai/app/services/config/config_default.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index 7624d6a22dd..6281facbc28 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -107,8 +107,8 @@ class InvokeAIAppConfig(BaseSettings):
         device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
         precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
         sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
-        attention_type: Attention type.<br>Valid values: `auto`, `sliced`, `torch-sdp`
-        attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
+        attention_type: Attention type.<br>Valid values: `auto`, `normal`, `torch-sdp`
+        attention_slice_size: Slice size<br>Valid values: `auto`, `none`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
         force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
         pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
         max_queue_size: Maximum number of items in the session queue.

From 9618b6e11f56370ed23dfe730ee7648ebdd7c7c8 Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Tue, 6 Aug 2024 20:31:26 +0300
Subject: [PATCH 19/25] Suggested changes

Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
---
 docs/installation/020_INSTALL_MANUAL.md              |  8 --------
 .../stable_diffusion/diffusion/custom_attention.py   |  6 ++++--
 .../diffusion/shared_invokeai_diffusion.py           |  3 ---
 .../backend/stable_diffusion/diffusion_backend.py    |  3 ---
 .../stable_diffusion/extensions/controlnet.py        |  5 -----
 invokeai/version/__init__.py                         | 12 ------------
 scripts/invokeai-web.py                              |  3 ---
 7 files changed, 4 insertions(+), 36 deletions(-)

diff --git a/docs/installation/020_INSTALL_MANUAL.md b/docs/installation/020_INSTALL_MANUAL.md
index 059834eb453..8b7eeb0cbf7 100644
--- a/docs/installation/020_INSTALL_MANUAL.md
+++ b/docs/installation/020_INSTALL_MANUAL.md
@@ -87,14 +87,6 @@ Before you start, go through the [installation requirements](./INSTALL_REQUIREME
             pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
             ```
 
-    - If you have a CUDA GPU and want to install with `xformers`, you need to add an option to the package name. Note that `xformers` is not necessary. PyTorch includes an implementation of the SDP attention algorithm with the same performance.
-
-        !!! example "Install with `xformers`"
-
-            ```bash
-            pip install "InvokeAI[xformers]" --use-pep517
-            ```
-
 1. Deactivate and reactivate your runtime directory so that the invokeai-specific commands become available in the environment:
 
     === "Linux/macOS"
diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
index 0ef6deb3988..ca79f4a6f55 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
@@ -55,7 +55,10 @@ def __init__(
 
     def _select_attention_type(self) -> str:
         device = TorchDevice.choose_torch_device()
-        # On some mps system normal attention still faster than torch-sdp on others - on par
+        # On some mps systems normal attention is still faster than torch-sdp; on others it is on par
+        # Results torch-sdp vs normal attention
+        # gogurt: 67.993s vs 67.729s
+        # Adreitz: 260.868s vs 226.638s
         if device.type == "mps":
             return "normal"
         else:  # cuda, cpu
@@ -89,7 +92,6 @@ def __call__(
         temb: Optional[torch.Tensor] = None,
         # For Regional Prompting:
         regional_prompt_data: Optional[RegionalPromptData] = None,
-        percent_through: Optional[torch.Tensor] = None,
         # For IP-Adapter:
         regional_ip_data: Optional[RegionalIPData] = None,
         *args,
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index f418133e49f..d5f54ccff16 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -335,7 +335,6 @@ def _apply_standard_conditioning(
             cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
                 regions=regions, device=x.device, dtype=x.dtype
             )
-            cross_attention_kwargs["percent_through"] = step_index / total_step_count
 
         both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
             conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
@@ -426,7 +425,6 @@ def _apply_standard_conditioning_sequentially(
             cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
                 regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
             )
-            cross_attention_kwargs["percent_through"] = step_index / total_step_count
 
         # Run unconditioned UNet denoising (i.e. negative prompt).
         unconditioned_next_x = self.model_forward_callback(
@@ -474,7 +472,6 @@ def _apply_standard_conditioning_sequentially(
             cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
                 regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
             )
-            cross_attention_kwargs["percent_through"] = step_index / total_step_count
 
         # Run conditioned UNet denoising (i.e. positive prompt).
         conditioned_next_x = self.model_forward_callback(
diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py
index 4191db734f9..dddf7e555d4 100644
--- a/invokeai/backend/stable_diffusion/diffusion_backend.py
+++ b/invokeai/backend/stable_diffusion/diffusion_backend.py
@@ -114,9 +114,6 @@ def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditio
             sample=sample,
             timestep=ctx.timestep,
             encoder_hidden_states=None,  # set later by conditoning
-            cross_attention_kwargs=dict(  # noqa: C408
-                percent_through=ctx.step_index / len(ctx.inputs.timesteps),
-            ),
         )
 
         ctx.conditioning_mode = conditioning_mode
diff --git a/invokeai/backend/stable_diffusion/extensions/controlnet.py b/invokeai/backend/stable_diffusion/extensions/controlnet.py
index a48a681af3f..8728779e00b 100644
--- a/invokeai/backend/stable_diffusion/extensions/controlnet.py
+++ b/invokeai/backend/stable_diffusion/extensions/controlnet.py
@@ -112,8 +112,6 @@ def pre_unet_step(self, ctx: DenoiseContext):
             ctx.unet_kwargs.mid_block_additional_residual += mid_sample
 
     def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: ConditioningMode):
-        total_steps = len(ctx.inputs.timesteps)
-
         model_input = ctx.latent_model_input
         image_tensor = self._image_tensor
         if conditioning_mode == ConditioningMode.Both:
@@ -124,9 +122,6 @@ def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: Con
             sample=model_input,
             timestep=ctx.timestep,
             encoder_hidden_states=None,  # set later by conditioning
-            cross_attention_kwargs=dict(  # noqa: C408
-                percent_through=ctx.step_index / total_steps,
-            ),
         )
 
         ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode)
diff --git a/invokeai/version/__init__.py b/invokeai/version/__init__.py
index 57efb1af95f..8720b915320 100644
--- a/invokeai/version/__init__.py
+++ b/invokeai/version/__init__.py
@@ -6,15 +6,3 @@
 
 __app_id__ = "invoke-ai/InvokeAI"
 __app_name__ = "InvokeAI"
-
-
-def _ignore_xformers_triton_message_on_windows():
-    import logging
-
-    logging.getLogger("xformers").addFilter(
-        lambda record: "A matching Triton is not available" not in record.getMessage()
-    )
-
-
-# In order to be effective, this needs to happen before anything could possibly import xformers.
-_ignore_xformers_triton_message_on_windows()
diff --git a/scripts/invokeai-web.py b/scripts/invokeai-web.py
index 691e58f7d17..cf68004cc6c 100755
--- a/scripts/invokeai-web.py
+++ b/scripts/invokeai-web.py
@@ -2,13 +2,10 @@
 
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
-import logging
 import os
 
 from invokeai.app.run_app import run_app
 
-logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())
-
 
 def main():
     # Change working directory to the repo root

From 09aef431f414665abdff5ff745dab9a0e3a92f07 Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Wed, 7 Aug 2024 20:53:54 +0300
Subject: [PATCH 20/25] Restore xformers

---
 docker/Dockerfile                             |  7 ++-
 docs/installation/020_INSTALL_MANUAL.md       |  8 +++
 flake.nix                                     |  2 +-
 installer/lib/installer.py                    |  4 +-
 invokeai/app/api/routers/app_info.py          |  8 ++-
 .../app/services/config/config_default.py     |  6 +-
 .../diffusion/custom_attention.py             | 56 ++++++++++++++++++-
 invokeai/backend/util/hotfixes.py             | 46 +++++++++++++++
 .../frontend/web/src/services/api/schema.ts   |  5 ++
 invokeai/version/__init__.py                  | 12 ++++
 pyproject.toml                                |  6 ++
 scripts/invokeai-web.py                       |  3 +
 12 files changed, 154 insertions(+), 9 deletions(-)

diff --git a/docker/Dockerfile b/docker/Dockerfile
index 24f2ff9e2f7..7ea078af0d9 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -43,7 +43,12 @@ RUN --mount=type=cache,target=/root/.cache/pip \
         extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu121"; \
     fi &&\
 
-    pip install $extra_index_url_arg -e ".";
+    # xformers + triton fails to install on arm64
+    if [ "$GPU_DRIVER" = "cuda" ] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then \
+        pip install $extra_index_url_arg -e ".[xformers]"; \
+    else \
+        pip install $extra_index_url_arg -e "."; \
+    fi
 
 # #### Build the Web UI ------------------------------------
 
diff --git a/docs/installation/020_INSTALL_MANUAL.md b/docs/installation/020_INSTALL_MANUAL.md
index 8b7eeb0cbf7..059834eb453 100644
--- a/docs/installation/020_INSTALL_MANUAL.md
+++ b/docs/installation/020_INSTALL_MANUAL.md
@@ -87,6 +87,14 @@ Before you start, go through the [installation requirements](./INSTALL_REQUIREME
             pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
             ```
 
+    - If you have a CUDA GPU and want to install with `xformers`, you need to add an option to the package name. Note that `xformers` is not necessary. PyTorch includes an implementation of the SDP attention algorithm with the same performance.
+
+        !!! example "Install with `xformers`"
+
+            ```bash
+            pip install "InvokeAI[xformers]" --use-pep517
+            ```
+
 1. Deactivate and reactivate your runtime directory so that the invokeai-specific commands become available in the environment:
 
     === "Linux/macOS"
diff --git a/flake.nix b/flake.nix
index bf8d2ae9466..3ccc6658121 100644
--- a/flake.nix
+++ b/flake.nix
@@ -84,7 +84,7 @@
     in
     {
       devShells.${system} = rec {
-        develop = mkShell { dir = "venv"; install = "-e '.' --extra-index-url https://download.pytorch.org/whl/cu118"; };
+        develop = mkShell { dir = "venv"; install = "-e '.[xformers]' --extra-index-url https://download.pytorch.org/whl/cu118"; };
         default = develop;
       };
     };
diff --git a/installer/lib/installer.py b/installer/lib/installer.py
index 504c801df6d..11823b413e0 100644
--- a/installer/lib/installer.py
+++ b/installer/lib/installer.py
@@ -418,11 +418,11 @@ def get_torch_source() -> Tuple[str | None, str | None]:
             url = "https://download.pytorch.org/whl/cpu"
         elif device.value == "cuda":
             # CUDA uses the default PyPi index
-            optional_modules = "[onnx-cuda]"
+            optional_modules = "[xformers,onnx-cuda]"
     elif OS == "Windows":
         if device.value == "cuda":
             url = "https://download.pytorch.org/whl/cu121"
-            optional_modules = "[onnx-cuda]"
+            optional_modules = "[xformers,onnx-cuda]"
         elif device.value == "cpu":
             # CPU  uses the default PyPi index, no optional modules
             pass
diff --git a/invokeai/app/api/routers/app_info.py b/invokeai/app/api/routers/app_info.py
index 9f87e2cdec0..3206adb2421 100644
--- a/invokeai/app/api/routers/app_info.py
+++ b/invokeai/app/api/routers/app_info.py
@@ -1,6 +1,6 @@
 import typing
 from enum import Enum
-from importlib.metadata import version
+from importlib.metadata import PackageNotFoundError, version
 from pathlib import Path
 from platform import python_version
 from typing import Optional
@@ -56,6 +56,7 @@ class AppDependencyVersions(BaseModel):
     torch: str = Field(description="PyTorch version")
     torchvision: str = Field(description="PyTorch Vision version")
     transformers: str = Field(description="transformers version")
+    xformers: Optional[str] = Field(description="xformers version")
 
 
 class AppConfig(BaseModel):
@@ -74,6 +75,10 @@ async def get_version() -> AppVersion:
 
 @app_router.get("/app_deps", operation_id="get_app_deps", status_code=200, response_model=AppDependencyVersions)
 async def get_app_deps() -> AppDependencyVersions:
+    try:
+        xformers = version("xformers")
+    except PackageNotFoundError:
+        xformers = None
     return AppDependencyVersions(
         accelerate=version("accelerate"),
         compel=version("compel"),
@@ -87,6 +92,7 @@ async def get_app_deps() -> AppDependencyVersions:
         torch=torch.version.__version__,
         torchvision=version("torchvision"),
         transformers=version("transformers"),
+        xformers=xformers,
     )
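
The `get_app_deps` change above reports the xformers version only when the package is installed. The same `version`/`PackageNotFoundError` pattern generalizes to any optional dependency; a small sketch follows (the helper name `optional_package_version` is illustrative, not part of the InvokeAI API):

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def optional_package_version(name: str) -> Optional[str]:
    """Return the installed version of `name`, or None if it is not installed."""
    try:
        return version(name)
    except PackageNotFoundError:
        return None


print(optional_package_version("xformers"))  # e.g. "0.0.25.post1", or None if absent
```
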
 
 
diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index 6281facbc28..4f703b59262 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -28,7 +28,7 @@
 DEFAULT_VRAM_CACHE = 0.25
 DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
 PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
-ATTENTION_TYPE = Literal["auto", "normal", "torch-sdp"]
+ATTENTION_TYPE = Literal["auto", "normal", "xformers", "torch-sdp"]
 ATTENTION_SLICE_SIZE = Literal["auto", "none", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
 LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
 LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
@@ -449,8 +449,8 @@ def migrate_v4_0_2_to_4_0_3_config_dict(config_dict: dict[str, Any]) -> dict[str
     if attention_type != "sliced" and "attention_slice_size" in parsed_config_dict:
         del parsed_config_dict["attention_slice_size"]
 
-    # xformers attention removed, sliced moved to attention_slice_size
-    if attention_type in ["sliced", "xformers"]:
+    # sliced moved to attention_slice_size
+    if attention_type == "sliced":
         parsed_config_dict["attention_type"] = "auto"
 
     parsed_config_dict["schema_version"] = "4.0.3"
diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
index ca79f4a6f55..35f81db9ba0 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
@@ -6,6 +6,7 @@
 import torch
 import torch.nn.functional as F
 from diffusers.models.attention_processor import Attention
+from diffusers.utils.import_utils import is_xformers_available
 
 from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
@@ -13,6 +14,12 @@
 from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
 from invokeai.backend.util.devices import TorchDevice
 
+if is_xformers_available():
+    import xformers
+    import xformers.ops
+else:
+    xformers = None
+
 
 @dataclass
 class IPAdapterAttentionWeights:
@@ -23,7 +30,9 @@ class IPAdapterAttentionWeights:
 class CustomAttnProcessor:
     """A custom implementation of attention processor that supports additional Invoke features.
     This implementation is based on
+    AttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L732)
     SlicedAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1616)
+    XFormersAttnProcessor (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1113)
     AttnProcessor2_0 (https://github.com/huggingface/diffusers/blob/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204)
     Supported custom features:
     - IP-Adapter
@@ -53,6 +62,9 @@ def __init__(
         if self.slice_size == "auto":
             self.slice_size = self._select_slice_size()
 
+        if self.attention_type == "xformers" and xformers is None:
+            raise ImportError("xformers attention requires xformers module to be installed.")
+
     def _select_attention_type(self) -> str:
         device = TorchDevice.choose_torch_device()
         # On some mps system normal attention still faster than torch-sdp, on others - on par
@@ -61,7 +73,14 @@ def _select_attention_type(self) -> str:
         # Adreitz: 260.868s vs 226.638s
         if device.type == "mps":
             return "normal"
-        else:  # cuda, cpu
+        elif device.type == "cuda":
+            # Flash Attention is supported from sm80 compute capability onwards in PyTorch
+            # https://pytorch.org/blog/accelerated-pytorch-2/
+            if torch.cuda.get_device_capability("cuda")[0] < 8 and xformers is not None:
+                return "xformers"
+            else:
+                return "torch-sdp"
+        else:  # cpu
             return "torch-sdp"
 
     def _select_slice_size(self) -> str:
@@ -262,6 +281,8 @@ def run_attention(
             attn_call = self.run_attention_sdp
         elif self.attention_type == "normal":
             attn_call = self.run_attention_normal
+        elif self.attention_type == "xformers":
+            attn_call = self.run_attention_xformers
         else:
             raise Exception(f"Unknown attention type: {self.attention_type}")
 
@@ -291,6 +312,35 @@ def run_attention_normal(
 
         return hidden_states
 
+    def run_attention_xformers(
+        self,
+        attn: Attention,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        query = attn.head_to_batch_dim(query).contiguous()
+        key   = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        if attention_mask is not None:
+            # expand our mask's singleton query_length dimension:
+            #   [batch*heads,            1, key_length] ->
+            #   [batch*heads, query_length, key_length]
+            # so that it can be added as a bias onto the attention scores that xformers computes:
+            #   [batch*heads, query_length, key_length]
+            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
+            attention_mask = attention_mask.expand(-1, query.shape[1], -1)
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=None, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        return hidden_states
+
     def run_attention_sdp(
         self,
         attn: Attention,
@@ -355,6 +405,10 @@ def run_attention_sliced(
                 attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
                 torch.bmm(attn_slice, value_slice, out=hidden_states[start_idx:end_idx])
                 del attn_slice
+            elif self.attention_type == "xformers":
+                hidden_states[start_idx:end_idx] = xformers.ops.memory_efficient_attention(
+                    query_slice, key_slice, value_slice, attn_bias=attn_mask_slice, op=None, scale=attn.scale
+                )
             elif self.attention_type == "torch-sdp":
                 if attn_mask_slice is not None:
                     attn_mask_slice = attn_mask_slice.unsqueeze(0)
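
The `run_attention_xformers` path above folds the attention heads into the batch dimension before calling `xformers.ops.memory_efficient_attention`. A standalone sketch of that call with arbitrary shapes, assuming a CUDA device and an installed xformers build (the attention bias is omitted here for brevity):

```python
# Heads folded into the batch dimension: the 3-dim [B*H, S, He] layout that
# memory_efficient_attention accepts. Shapes are arbitrary; attn_bias is omitted.
import torch
import xformers.ops

B, H, S_q, S_k, He = 2, 8, 96, 77, 40
device, dtype = torch.device("cuda"), torch.float16

q = torch.randn(B * H, S_q, He, device=device, dtype=dtype)
k = torch.randn(B * H, S_k, He, device=device, dtype=dtype)
v = torch.randn(B * H, S_k, He, device=device, dtype=dtype)

out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, scale=He**-0.5)
print(out.shape)  # torch.Size([16, 96, 40]); fold back to [B, S_q, H*He] afterwards
```
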
diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py
index a9ed2538825..7e362fe9589 100644
--- a/invokeai/backend/util/hotfixes.py
+++ b/invokeai/backend/util/hotfixes.py
@@ -791,3 +791,49 @@ def new_LoRACompatibleConv_forward(self, hidden_states, scale: float = 1.0):
 
 
 diffusers.models.lora.LoRACompatibleConv.forward = new_LoRACompatibleConv_forward
+
+try:
+    import xformers
+
+    xformers_available = True
+except Exception:
+    xformers_available = False
+
+
+if xformers_available:
+    # TODO: remove when fixed in diffusers
+    _xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention
+
+    def new_memory_efficient_attention(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_bias=None,
+        p: float = 0.0,
+        scale: Optional[float] = None,
+        *,
+        op=None,
+    ):
+        # diffusers does not align the attn_bias shape to a multiple of 8, which xformers requires
+        if attn_bias is not None and type(attn_bias) is torch.Tensor:
+            orig_size = attn_bias.shape[-1]
+            new_size = ((orig_size + 7) // 8) * 8
+            aligned_attn_bias = torch.zeros(
+                (attn_bias.shape[0], attn_bias.shape[1], new_size),
+                device=attn_bias.device,
+                dtype=attn_bias.dtype,
+            )
+            aligned_attn_bias[:, :, :orig_size] = attn_bias
+            attn_bias = aligned_attn_bias[:, :, :orig_size]
+
+        return _xformers_memory_efficient_attention(
+            query=query,
+            key=key,
+            value=value,
+            attn_bias=attn_bias,
+            p=p,
+            scale=scale,
+            op=op,
+        )
+
+    xformers.ops.memory_efficient_attention = new_memory_efficient_attention
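
The hotfix above works because padding the bias's last dimension up to a multiple of 8 and then slicing the view back keeps the logical shape while giving the rows an 8-aligned stride. A small sketch of the effect on strides:

```python
import torch

bias = torch.randn(4, 1, 77)                 # 77 is not a multiple of 8
print(bias.shape, bias.stride())             # torch.Size([4, 1, 77]) (77, 77, 1)

orig_size = bias.shape[-1]
new_size = ((orig_size + 7) // 8) * 8        # 80
aligned = torch.zeros(bias.shape[:-1] + (new_size,), dtype=bias.dtype)
aligned[..., :orig_size] = bias
bias_view = aligned[..., :orig_size]         # same logical shape, 8-aligned row stride
print(bias_view.shape, bias_view.stride())   # torch.Size([4, 1, 77]) (80, 80, 1)
```
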
diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts
index b4b39eae32b..79b82a23fae 100644
--- a/invokeai/frontend/web/src/services/api/schema.ts
+++ b/invokeai/frontend/web/src/services/api/schema.ts
@@ -725,6 +725,11 @@ export type components = {
        * @description transformers version
        */
       transformers: string;
+      /**
+       * Xformers
+       * @description xformers version
+       */
+      xformers: string | null;
     };
     /**
      * AppVersion
diff --git a/invokeai/version/__init__.py b/invokeai/version/__init__.py
index 8720b915320..57efb1af95f 100644
--- a/invokeai/version/__init__.py
+++ b/invokeai/version/__init__.py
@@ -6,3 +6,15 @@
 
 __app_id__ = "invoke-ai/InvokeAI"
 __app_name__ = "InvokeAI"
+
+
+def _ignore_xformers_triton_message_on_windows():
+    import logging
+
+    logging.getLogger("xformers").addFilter(
+        lambda record: "A matching Triton is not available" not in record.getMessage()
+    )
+
+
+# In order to be effective, this needs to happen before anything could possibly import xformers.
+_ignore_xformers_triton_message_on_windows()
diff --git a/pyproject.toml b/pyproject.toml
index 5bcf74d88cd..cdf032b301b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -94,6 +94,12 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
+"xformers" = [
+  # Core generation dependencies, pinned for reproducible builds.
+  "xformers==0.0.25post1; sys_platform!='darwin'",
+  # Auxiliary dependencies, pinned only if necessary.
+  "triton; sys_platform=='linux'",
+]
 "onnx" = ["onnxruntime"]
 "onnx-cuda" = ["onnxruntime-gpu"]
 "onnx-directml" = ["onnxruntime-directml"]
diff --git a/scripts/invokeai-web.py b/scripts/invokeai-web.py
index cf68004cc6c..691e58f7d17 100755
--- a/scripts/invokeai-web.py
+++ b/scripts/invokeai-web.py
@@ -2,10 +2,13 @@
 
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
+import logging
 import os
 
 from invokeai.app.run_app import run_app
 
+logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())
+
 
 def main():
     # Change working directory to the repo root

From 37dfab7cb1d5ed03ec027cec5fb6daacf6d978df Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Wed, 7 Aug 2024 21:23:32 +0300
Subject: [PATCH 21/25] Small fixes

---
 invokeai/app/services/config/config_default.py               | 2 +-
 .../backend/stable_diffusion/diffusion/custom_attention.py   | 5 ++++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index 4f703b59262..06a197f630e 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -107,7 +107,7 @@ class InvokeAIAppConfig(BaseSettings):
         device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
         precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
         sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
-        attention_type: Attention type.<br>Valid values: `auto`, `normal`, `torch-sdp`
+        attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `torch-sdp`
         attention_slice_size: Slice size<br>Valid values: `auto`, `none`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
         force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
         pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
index 35f81db9ba0..d29ec07815c 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
@@ -321,7 +321,7 @@ def run_attention_xformers(
         attention_mask: Optional[torch.Tensor],
     ) -> torch.Tensor:
         query = attn.head_to_batch_dim(query).contiguous()
-        key   = attn.head_to_batch_dim(key).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
         value = attn.head_to_batch_dim(value).contiguous()
 
         if attention_mask is not None:
@@ -406,6 +406,9 @@ def run_attention_sliced(
                 torch.bmm(attn_slice, value_slice, out=hidden_states[start_idx:end_idx])
                 del attn_slice
             elif self.attention_type == "xformers":
+                if attn_mask_slice is not None:
+                    attn_mask_slice = attn_mask_slice.expand(-1, query.shape[1], -1)
+
                 hidden_states[start_idx:end_idx] = xformers.ops.memory_efficient_attention(
                     query_slice, key_slice, value_slice, attn_bias=attn_mask_slice, op=None, scale=attn.scale
                 )
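
The fix in this patch exists because torch-sdp broadcasts a singleton query dimension in the mask, while xformers expects the bias to already match the score shape. A shape-only sketch of the difference (no kernels are called, shapes are arbitrary):

```python
import torch

BH, S_q, S_k = 16, 96, 77
mask = torch.zeros(BH, 1, S_k)

# torch-sdp: a [BH, 1, S_k] mask broadcasts against the [BH, S_q, S_k] scores as-is.
sdp_mask = mask

# xformers: expand explicitly (a view, no copy) so the bias matches the score shape.
xformers_bias = mask.expand(-1, S_q, -1)
print(sdp_mask.shape, xformers_bias.shape)  # [16, 1, 77] vs [16, 96, 77]
```
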

From 192fba4fe3e2673d22ef69291a52f8859cd13bba Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Tue, 20 Aug 2024 02:03:34 +0300
Subject: [PATCH 22/25] Rewrite sliced attention, more optimizations (batched
 torch-sdp for old CUDA, multi-head xformers for high head counts)

---
 .../diffusion/custom_attention.py             | 272 ++++++++++++------
 invokeai/backend/util/hotfixes.py             |  15 +-
 2 files changed, 193 insertions(+), 94 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
index d29ec07815c..d854e20efde 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
@@ -53,6 +53,9 @@ def __init__(
 
         self._ip_adapter_attention_weights = ip_adapter_attention_weights
 
+        device = TorchDevice.choose_torch_device()
+        self.is_old_cuda = device.type == "cuda" and torch.cuda.get_device_capability(device)[0] < 8
+
         config = get_config()
         self.attention_type = config.attention_type
         if self.attention_type == "auto":
@@ -76,7 +79,7 @@ def _select_attention_type(self) -> str:
         elif device.type == "cuda":
             # Flash Attention is supported from sm80 compute capability onwards in PyTorch
             # https://pytorch.org/blog/accelerated-pytorch-2/
-            if torch.cuda.get_device_capability("cuda")[0] < 8 and xformers is not None:
+            if self.is_old_cuda and xformers is not None:
                 return "xformers"
             else:
                 return "torch-sdp"
@@ -265,9 +268,10 @@ def run_attention(
         key: torch.Tensor,
         value: torch.Tensor,
         attention_mask: Optional[torch.Tensor],
+        no_sliced: bool = False,
     ) -> torch.Tensor:
         slice_size = self._get_slice_size(attn)
-        if slice_size is not None:
+        if not no_sliced and slice_size is not None:
             return self.run_attention_sliced(
                 attn=attn,
                 query=query,
@@ -294,6 +298,41 @@ def run_attention(
             attention_mask=attention_mask,
         )
 
+    @staticmethod
+    def _align_attention_mask_memory(attention_mask: torch.Tensor, alignment: int = 8) -> torch.Tensor:
+        if attention_mask.stride(-2) % alignment == 0 and attention_mask.stride(-2) != 0:
+            return attention_mask
+
+        last_mask_dim = attention_mask.shape[-1]
+        new_last_mask_dim = last_mask_dim + (alignment - (last_mask_dim % alignment))
+        attention_mask_mem = torch.empty(
+            attention_mask.shape[:-1] + (new_last_mask_dim,),
+            device=attention_mask.device,
+            dtype=attention_mask.dtype,
+        )
+        attention_mask_mem[..., :last_mask_dim] = attention_mask
+        return attention_mask_mem[..., :last_mask_dim]
+
+    @staticmethod
+    def _head_to_batch_dim(tensor: torch.Tensor, head_dim: int) -> torch.Tensor:
+        # [B, S, H*He] -> [B, S, H, He]
+        tensor = tensor.reshape(tensor.shape[0], tensor.shape[1], -1, head_dim)
+        # [B, S, H, He] -> [B, H, S, He]
+        tensor = tensor.permute(0, 2, 1, 3)
+        # [B, H, S, He] -> [B*H, S, He]
+        tensor = tensor.reshape(-1, tensor.shape[2], head_dim)
+        return tensor
+
+    @staticmethod
+    def _batch_to_head_dim(tensor: torch.Tensor, batch_size: int) -> torch.Tensor:
+        # [B*H, S, He] -> [B, H, S, He]
+        tensor = tensor.reshape(batch_size, -1, tensor.shape[1], tensor.shape[2])
+        # [B, H, S, He] -> [B, S, H, He]
+        tensor = tensor.permute(0, 2, 1, 3)
+        # [B, S, H, He] -> [B, S, H*He]
+        tensor = tensor.reshape(tensor.shape[0], tensor.shape[1], -1)
+        return tensor
+
     def run_attention_normal(
         self,
         attn: Attention,
@@ -302,14 +341,17 @@ def run_attention_normal(
         value: torch.Tensor,
         attention_mask: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        query = attn.head_to_batch_dim(query)
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
+        batch_size = query.shape[0]
+        head_dim = attn.to_q.weight.shape[0] // attn.heads
+
+        query = self._head_to_batch_dim(query, head_dim)
+        key = self._head_to_batch_dim(key, head_dim)
+        value = self._head_to_batch_dim(value, head_dim)
 
         attention_probs = attn.get_attention_scores(query, key, attention_mask)
         hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
 
+        hidden_states = self._batch_to_head_dim(hidden_states, batch_size)
         return hidden_states
 
     def run_attention_xformers(
@@ -319,25 +361,62 @@ def run_attention_xformers(
         key: torch.Tensor,
         value: torch.Tensor,
         attention_mask: Optional[torch.Tensor],
+        multihead: Optional[bool] = None,
     ) -> torch.Tensor:
-        query = attn.head_to_batch_dim(query).contiguous()
-        key = attn.head_to_batch_dim(key).contiguous()
-        value = attn.head_to_batch_dim(value).contiguous()
+        batch_size = query.shape[0]
+        head_dim = attn.to_q.weight.shape[0] // attn.heads
+
+        # batched execution with xformers is slightly faster for small head counts
+        if multihead is None:
+            heads_count = query.shape[2] // head_dim
+            multihead = heads_count >= 4
+
+        if multihead:
+            # [B, S, H*He] -> [B, S, H, He]
+            query = query.view(batch_size, query.shape[1], -1, head_dim)
+            key = key.view(batch_size, key.shape[1], -1, head_dim)
+            value = value.view(batch_size, value.shape[1], -1, head_dim)
+
+            if attention_mask is not None:
+                # [B*H, 1, S_key] -> [B, H, 1, S_key]
+                attention_mask = attention_mask.view(batch_size, -1, attention_mask.shape[1], attention_mask.shape[2])
+                # expand our mask's singleton query dimension:
+                #   [B, H,       1, S_key] ->
+                #   [B, H, S_query, S_key]
+                # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
+                attention_mask = attention_mask.expand(-1, -1, query.shape[1], -1)
+                # xformers requires mask memory to be aligned to 8
+                attention_mask = self._align_attention_mask_memory(attention_mask)
+
+            hidden_states = xformers.ops.memory_efficient_attention(
+                query, key, value, attn_bias=attention_mask, op=None, scale=attn.scale
+            )
+            # [B, S_query, H, He] -> [B, S_query, H*He]
+            hidden_states = hidden_states.reshape(hidden_states.shape[:-2] + (-1,))
+            hidden_states = hidden_states.to(query.dtype)
 
-        if attention_mask is not None:
-            # expand our mask's singleton query_length dimension:
-            #   [batch*heads,            1, key_length] ->
-            #   [batch*heads, query_length, key_length]
-            # so that it can be added as a bias onto the attention scores that xformers computes:
-            #   [batch*heads, query_length, key_length]
-            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
-            attention_mask = attention_mask.expand(-1, query.shape[1], -1)
-
-        hidden_states = xformers.ops.memory_efficient_attention(
-            query, key, value, attn_bias=attention_mask, op=None, scale=attn.scale
-        )
-        hidden_states = hidden_states.to(query.dtype)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
+        else:
+            # contiguous inputs are slightly faster in batched execution
+            # [B, S, H*He] -> [B*H, S, He]
+            query = self._head_to_batch_dim(query, head_dim).contiguous()
+            key = self._head_to_batch_dim(key, head_dim).contiguous()
+            value = self._head_to_batch_dim(value, head_dim).contiguous()
+
+            if attention_mask is not None:
+                # expand our mask's singleton query dimension:
+                #   [B*H,       1, S_key] ->
+                #   [B*H, S_query, S_key]
+                # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
+                attention_mask = attention_mask.expand(-1, query.shape[1], -1)
+                # xformers requires mask memory to be aligned to 8
+                attention_mask = self._align_attention_mask_memory(attention_mask)
+
+            hidden_states = xformers.ops.memory_efficient_attention(
+                query, key, value, attn_bias=attention_mask, op=None, scale=attn.scale
+            )
+            hidden_states = hidden_states.to(query.dtype)
+            # [B*H, S_query, He] -> [B, S_query, H*He]
+            hidden_states = self._batch_to_head_dim(hidden_states, batch_size)
 
         return hidden_states
 
@@ -348,27 +427,54 @@ def run_attention_sdp(
         key: torch.Tensor,
         value: torch.Tensor,
         attention_mask: Optional[torch.Tensor],
+        multihead: Optional[bool] = None,
     ) -> torch.Tensor:
-        batch_size = key.shape[0]
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attention_mask is not None:
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+        batch_size = query.shape[0]
+        head_dim = attn.to_q.weight.shape[0] // attn.heads
+
+        if multihead is None:
+            # multi-head execution is extremely slow on old CUDA GPUs
+            multihead = not self.is_old_cuda
+
+        if multihead:
+            # [B, S, H*He] -> [B, H, S, He]
+            query = query.view(batch_size, query.shape[1], -1, head_dim).transpose(1, 2)
+            key = key.view(batch_size, key.shape[1], -1, head_dim).transpose(1, 2)
+            value = value.view(batch_size, value.shape[1], -1, head_dim).transpose(1, 2)
+
+            if attention_mask is not None:
+                # [B*H, 1, S_key] -> [B, H, 1, S_key]
+                attention_mask = attention_mask.view(batch_size, -1, attention_mask.shape[1], attention_mask.shape[2])
+                # mask alignment to 8 decreases memory consumption and increases speed
+                attention_mask = self._align_attention_mask_memory(attention_mask)
+
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
+            )
 
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
-        )
+            # [B, H, S_query, He] -> [B, S_query, H, He]
+            hidden_states = hidden_states.transpose(1, 2)
+            # [B, S_query, H, He] -> [B, S_query, H*He]
+            hidden_states = hidden_states.reshape(hidden_states.shape[:-2] + (-1,))
+            hidden_states = hidden_states.to(query.dtype)
+        else:
+            # [B, S, H*He] -> [B*H, S, He]
+            query = self._head_to_batch_dim(query, head_dim)
+            key = self._head_to_batch_dim(key, head_dim)
+            value = self._head_to_batch_dim(value, head_dim)
+
+            if attention_mask is not None:
+                # attention mask already in shape [B*H, 1, S_key]/[B*H, S_query, S_key]
+                # mask alignment to 8 decreases memory consumption and increases speed
+                attention_mask = self._align_attention_mask_memory(attention_mask)
+
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
+            )
 
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
+            hidden_states = hidden_states.to(query.dtype)
+            # [B*H, S_query, He] -> [B, S_query, H*He]
+            hidden_states = self._batch_to_head_dim(hidden_states, batch_size)
 
         return hidden_states
 
@@ -381,52 +487,50 @@ def run_attention_sliced(
         attention_mask: Optional[torch.Tensor],
         slice_size: int,
     ) -> torch.Tensor:
-        dim = query.shape[-1]
-
-        query = attn.head_to_batch_dim(query)
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
+        batch_size = query.shape[0]
+        head_dim = attn.to_q.weight.shape[0] // attn.heads
+        heads_count = query.shape[2] // head_dim
+
+        # [B, S, H*He] -> [B, H, S, He]
+        query = query.reshape(query.shape[0], query.shape[1], -1, head_dim).transpose(1, 2)
+        key = key.reshape(key.shape[0], key.shape[1], -1, head_dim).transpose(1, 2)
+        value = value.reshape(value.shape[0], value.shape[1], -1, head_dim).transpose(1, 2)
+        # [B*H, S_query/1, S_key] -> [B, H, S_query/1, S_key]
+        if attention_mask is not None:
+            attention_mask = attention_mask.reshape(batch_size, -1, attention_mask.shape[1], attention_mask.shape[2])
 
-        batch_size_attention, query_tokens, _ = query.shape
-        hidden_states = torch.zeros(
-            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
-        )
+        # [B, H, S_query, He]
+        hidden_states = torch.empty(query.shape, device=query.device, dtype=query.dtype)
 
-        for i in range((batch_size_attention - 1) // slice_size + 1):
+        for i in range((heads_count - 1) // slice_size + 1):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
 
-            query_slice = query[start_idx:end_idx]
-            key_slice = key[start_idx:end_idx]
-            value_slice = value[start_idx:end_idx]
-            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
-
-            if self.attention_type == "normal":
-                attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
-                torch.bmm(attn_slice, value_slice, out=hidden_states[start_idx:end_idx])
-                del attn_slice
-            elif self.attention_type == "xformers":
-                if attn_mask_slice is not None:
-                    attn_mask_slice = attn_mask_slice.expand(-1, query.shape[1], -1)
-
-                hidden_states[start_idx:end_idx] = xformers.ops.memory_efficient_attention(
-                    query_slice, key_slice, value_slice, attn_bias=attn_mask_slice, op=None, scale=attn.scale
-                )
-            elif self.attention_type == "torch-sdp":
-                if attn_mask_slice is not None:
-                    attn_mask_slice = attn_mask_slice.unsqueeze(0)
-
-                hidden_states[start_idx:end_idx] = F.scaled_dot_product_attention(
-                    query_slice.unsqueeze(0),
-                    key_slice.unsqueeze(0),
-                    value_slice.unsqueeze(0),
-                    attn_mask=attn_mask_slice,
-                    dropout_p=0.0,
-                    is_causal=False,
-                    scale=attn.scale,
-                ).squeeze(0)
-            else:
-                raise ValueError(f"Unknown attention type: {self.attention_type}")
+            # [B, H_s, S, He] -> [B, S, H_s*He]
+            query_slice = query[:, start_idx:end_idx, :, :].transpose(1, 2).reshape(batch_size, query.shape[2], -1)
+            key_slice = key[:, start_idx:end_idx, :, :].transpose(1, 2).reshape(batch_size, key.shape[2], -1)
+            value_slice = value[:, start_idx:end_idx, :, :].transpose(1, 2).reshape(batch_size, value.shape[2], -1)
 
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-        return hidden_states
+            # [B, H_s, S_query/1, S_key] -> [B*H_s, S_query/1, S_key]
+            attn_mask_slice = None
+            if attention_mask is not None:
+                attn_mask_slice = attention_mask[:, start_idx:end_idx, :, :].reshape((-1,) + attention_mask.shape[-2:])
+
+            # [B, S_query, H_s*He]
+            hidden_states_slice = self.run_attention(
+                attn=attn,
+                query=query_slice,
+                key=key_slice,
+                value=value_slice,
+                attention_mask=attn_mask_slice,
+                no_sliced=True,
+            )
+
+            # [B, S_query, H_s*He] -> [B, H_s, S_query, He]
+            hidden_states[:, start_idx:end_idx] = hidden_states_slice.reshape(
+                hidden_states_slice.shape[:-1] + (-1, head_dim)
+            ).transpose(1, 2)
+
+        # [B, H_s, S_query, He] -> [B, S_query, H_s*He]
+        hidden_states = hidden_states.transpose(1, 2)
+        return hidden_states.reshape(hidden_states.shape[:-2] + (-1,))
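
The `_head_to_batch_dim`/`_batch_to_head_dim` helpers introduced above are pure reshapes, so they round-trip exactly. A standalone re-implementation for illustration (not imported from InvokeAI):

```python
import torch


def head_to_batch_dim(t: torch.Tensor, head_dim: int) -> torch.Tensor:
    # [B, S, H*He] -> [B*H, S, He]
    b, s, _ = t.shape
    return t.reshape(b, s, -1, head_dim).permute(0, 2, 1, 3).reshape(-1, s, head_dim)


def batch_to_head_dim(t: torch.Tensor, batch_size: int) -> torch.Tensor:
    # [B*H, S, He] -> [B, S, H*He]
    _, s, he = t.shape
    return t.reshape(batch_size, -1, s, he).permute(0, 2, 1, 3).reshape(batch_size, s, -1)


B, H, S, He = 2, 8, 64, 40
x = torch.randn(B, S, H * He)
y = head_to_batch_dim(x, He)
print(y.shape)                                  # torch.Size([16, 64, 40])
assert torch.equal(batch_to_head_dim(y, B), x)  # lossless round trip
```
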
diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py
index 7e362fe9589..1991aed3ccc 100644
--- a/invokeai/backend/util/hotfixes.py
+++ b/invokeai/backend/util/hotfixes.py
@@ -802,8 +802,11 @@ def new_LoRACompatibleConv_forward(self, hidden_states, scale: float = 1.0):
 
 if xformers_available:
     # TODO: remove when fixed in diffusers
+    from invokeai.backend.stable_diffusion.diffusion.custom_attention import CustomAttnProcessor
+
     _xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention
 
+    # TODO: remove? Or are there still calls into xformers that bypass our attention processor?
     def new_memory_efficient_attention(
         query: torch.Tensor,
         key: torch.Tensor,
@@ -815,16 +818,8 @@ def new_memory_efficient_attention(
         op=None,
     ):
         # diffusers does not align the attn_bias shape to a multiple of 8, which xformers requires
-        if attn_bias is not None and type(attn_bias) is torch.Tensor:
-            orig_size = attn_bias.shape[-1]
-            new_size = ((orig_size + 7) // 8) * 8
-            aligned_attn_bias = torch.zeros(
-                (attn_bias.shape[0], attn_bias.shape[1], new_size),
-                device=attn_bias.device,
-                dtype=attn_bias.dtype,
-            )
-            aligned_attn_bias[:, :, :orig_size] = attn_bias
-            attn_bias = aligned_attn_bias[:, :, :orig_size]
+        if attn_bias is not None and isinstance(attn_bias, torch.Tensor):
+            attn_bias = CustomAttnProcessor._align_attention_mask_memory(attn_bias)
 
         return _xformers_memory_efficient_attention(
             query=query,

From 0b1ff8f659f5bf605e7da3932b6c7652238ada5a Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Tue, 20 Aug 2024 02:19:57 +0300
Subject: [PATCH 23/25] Remove redundant alignment in batched torch-sdp
 execution, add comments

---
 .../diffusion/custom_attention.py             | 24 +++++++++++++++----
 1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
index d854e20efde..8b1b0d3872d 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
@@ -367,6 +367,12 @@ def run_attention_xformers(
         head_dim = attn.to_q.weight.shape[0] // attn.heads
 
         # batched execution with xformers is slightly faster for small head counts
+        # 8 heads:
+        # xformers(dim3): 20.155955553054810 vram: 16483328
+        # xformers(dim4): 17.558132648468018 vram: 16483328
+        # 1 head:
+        # xformers(dim3):  5.660739183425903 vram:  9516032
+        # xformers(dim4):  6.114191055297852 vram:  9516032
         if multihead is None:
             heads_count = query.shape[2] // head_dim
             multihead = heads_count >= 4
@@ -433,7 +439,9 @@ def run_attention_sdp(
         head_dim = attn.to_q.weight.shape[0] // attn.heads
 
         if multihead is None:
-            # multi-head execution is extremely slow on old CUDA GPUs
+            # multi-head execution is extremely slow on old CUDA GPUs:
+            # torch-sdp(dim3): 30.07543110847473 vram: 23954432
+            # torch-sdp(dim4): 299.3908393383026 vram: 13861888
             multihead = not self.is_old_cuda
 
         if multihead:
@@ -446,6 +454,12 @@ def run_attention_sdp(
                 # [B*H, 1, S_key] -> [B, H, 1, S_key]
                 attention_mask = attention_mask.view(batch_size, -1, attention_mask.shape[1], attention_mask.shape[2])
                 # mask alignment to 8 decreases memory consumption and increases speed
+                # fp16:
+                # torch-sdp(dim4, mask):          6.1701478958129880 vram:  7864320
+                # torch-sdp(dim4, aligned mask):  3.3127212524414062 vram:  2621440
+                # fp32:
+                # torch-sdp(dim4, mask):         23.0943229198455800 vram: 16121856
+                # torch-sdp(dim4, aligned mask): 17.3104763031005860 vram:  5636096
                 attention_mask = self._align_attention_mask_memory(attention_mask)
 
             hidden_states = F.scaled_dot_product_attention(
@@ -463,10 +477,10 @@ def run_attention_sdp(
             key = self._head_to_batch_dim(key, head_dim)
             value = self._head_to_batch_dim(value, head_dim)
 
-            if attention_mask is not None:
-                # attention mask already in shape [B*H, 1, S_key]/[B*H, S_query, S_key]
-                # mask alignment to 8 decreases memory consumption and increases speed
-                attention_mask = self._align_attention_mask_memory(attention_mask)
+            # attention mask already in shape [B*H, 1, S_key]/[B*H, S_query, S_key]
+            # and there are no noticeable changes from memory alignment:
+            # torch-sdp(dim3, mask):          9.7391905784606930 vram: 12713984
+            # torch-sdp(dim3, aligned mask): 10.0090200901031500 vram: 12713984
 
             hidden_states = F.scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
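
The timing/VRAM figures quoted in these comments come from repeated attention calls; the harness itself is not part of the patch. A hedged sketch of how such numbers could be gathered on a CUDA device (shapes and iteration count are arbitrary):

```python
import time

import torch
import torch.nn.functional as F


def bench(fn, iters=1000):
    """Time `iters` calls of fn() and report elapsed seconds and peak VRAM in bytes."""
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return time.time() - start, torch.cuda.max_memory_allocated()


q = torch.randn(16, 4096, 40, device="cuda", dtype=torch.float16)
k = torch.randn(16, 77, 40, device="cuda", dtype=torch.float16)
v = torch.randn(16, 77, 40, device="cuda", dtype=torch.float16)

seconds, peak_bytes = bench(lambda: F.scaled_dot_product_attention(q, k, v))
print(f"torch-sdp(dim3): {seconds}s vram: {peak_bytes}b")
```
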

From 3d19cacdc4c72bd9371db74b626bdeda5d002cd8 Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Tue, 20 Aug 2024 03:05:39 +0300
Subject: [PATCH 24/25] Suggested changes

Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
---
 .../diffusion/custom_attention.py             | 21 +++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
index 8b1b0d3872d..740402d8b91 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
@@ -8,6 +8,7 @@
 from diffusers.models.attention_processor import Attention
 from diffusers.utils.import_utils import is_xformers_available
 
+import invokeai.backend.util.logging as logger
 from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
 from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
@@ -77,12 +78,20 @@ def _select_attention_type(self) -> str:
         if device.type == "mps":
             return "normal"
         elif device.type == "cuda":
+            # In testing on a Tesla P40 (Pascal architecture), torch-sdp is much slower than xformers
+            # (8.84 s/it vs. 1.81 s/it for SDXL). We have not tested extensively to find the precise GPU architecture or
+            # compute capability where this performance gap begins.
             # Flash Attention is supported from sm80 compute capability onwards in PyTorch
-            # https://pytorch.org/blog/accelerated-pytorch-2/
-            if self.is_old_cuda and xformers is not None:
-                return "xformers"
-            else:
-                return "torch-sdp"
+            # (https://pytorch.org/blog/accelerated-pytorch-2/). For now, we use this as the cutoff for selecting
+            # between xformers and torch-sdp.
+            if self.is_old_cuda:
+                if xformers is not None:
+                    return "xformers"
+                logger.warning(
+                    f"xFormers is not installed, but is recommended for best performance with GPU {torch.cuda.get_device_properties(device).name}"
+                )
+
+            return "torch-sdp"
         else:  # cpu
             return "torch-sdp"
 
@@ -478,7 +487,7 @@ def run_attention_sdp(
             value = self._head_to_batch_dim(value, head_dim)
 
             # attention mask already in shape [B*H, 1, S_key]/[B*H, S_query, S_key]
-            # and there are no noticeable changes from memory alignment:
+            # and there are no noticeable changes from memory alignment in batched runs:
             # torch-sdp(dim3, mask):          9.7391905784606930 vram: 12713984
             # torch-sdp(dim3, aligned mask): 10.0090200901031500 vram: 12713984
 

From b947129799ecaa68880585a8183fa48446fa0f65 Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Tue, 20 Aug 2024 21:28:02 +0300
Subject: [PATCH 25/25] Edit comments

---
 .../diffusion/custom_attention.py             | 34 ++++++++++---------
 1 file changed, 18 insertions(+), 16 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
index 740402d8b91..68d3bdc7c84 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py
@@ -376,12 +376,12 @@ def run_attention_xformers(
         head_dim = attn.to_q.weight.shape[0] // attn.heads
 
         # batched execution with xformers is slightly faster for small head counts
-        # 8 heads:
-        # xformers(dim3): 20.155955553054810 vram: 16483328
-        # xformers(dim4): 17.558132648468018 vram: 16483328
-        # 1 head:
-        # xformers(dim3):  5.660739183425903 vram:  9516032
-        # xformers(dim4):  6.114191055297852 vram:  9516032
+        # 8 heads, fp16 (100000 attention calls):
+        # xformers(dim3): 20.155955553054810s vram: 16483328b
+        # xformers(dim4): 17.558132648468018s vram: 16483328b
+        # 1 head, fp16 (100000 attention calls):
+        # xformers(dim3):  5.660739183425903s vram:  9516032b
+        # xformers(dim4):  6.114191055297852s vram:  9516032b
         if multihead is None:
             heads_count = query.shape[2] // head_dim
             multihead = heads_count >= 4
@@ -449,8 +449,9 @@ def run_attention_sdp(
 
         if multihead is None:
             # multi-head execution is extremely slow on old CUDA GPUs:
-            # torch-sdp(dim3): 30.07543110847473 vram: 23954432
-            # torch-sdp(dim4): 299.3908393383026 vram: 13861888
+            # fp16 (100000 attention calls):
+            # torch-sdp(dim3): 30.07543110847473s vram: 23954432b
+            # torch-sdp(dim4): 299.3908393383026s vram: 13861888b
             multihead = not self.is_old_cuda
 
         if multihead:
@@ -463,12 +464,12 @@ def run_attention_sdp(
                 # [B*H, 1, S_key] -> [B, H, 1, S_key]
                 attention_mask = attention_mask.view(batch_size, -1, attention_mask.shape[1], attention_mask.shape[2])
                 # mask alignment to 8 decreases memory consumption and increases speed
-                # fp16:
-                # torch-sdp(dim4, mask):          6.1701478958129880 vram:  7864320
-                # torch-sdp(dim4, aligned mask):  3.3127212524414062 vram:  2621440
-                # fp32:
-                # torch-sdp(dim4, mask):         23.0943229198455800 vram: 16121856
-                # torch-sdp(dim4, aligned mask): 17.3104763031005860 vram:  5636096
+                # fp16 (100000 attention calls):
+                # torch-sdp(dim4, mask):          6.1701478958129880s vram:  7864320b
+                # torch-sdp(dim4, aligned mask):  3.3127212524414062s vram:  2621440b
+                # fp32 (100000 attention calls):
+                # torch-sdp(dim4, mask):         23.0943229198455800s vram: 16121856b
+                # torch-sdp(dim4, aligned mask): 17.3104763031005860s vram:  5636096b
                 attention_mask = self._align_attention_mask_memory(attention_mask)
 
             hidden_states = F.scaled_dot_product_attention(
@@ -488,8 +489,9 @@ def run_attention_sdp(
 
             # attention mask already in shape [B*H, 1, S_key]/[B*H, S_query, S_key]
             # and there are no noticeable changes from memory alignment in batched runs:
-            # torch-sdp(dim3, mask):          9.7391905784606930 vram: 12713984
-            # torch-sdp(dim3, aligned mask): 10.0090200901031500 vram: 12713984
+            # fp16 (100000 attention calls):
+            # torch-sdp(dim3, mask):          9.7391905784606930s vram: 12713984b
+            # torch-sdp(dim3, aligned mask): 10.0090200901031500s vram: 12713984b
 
             hidden_states = F.scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale