Skip to content

Commit 00b633f

Browse files
Revert "Add SeedVR2 support (CORE-6) (Comfy-Org#14110)" (Comfy-Org#14359)
This reverts commit 7863cf0.
1 parent a0a055b commit 00b633f

26 files changed

Lines changed: 40 additions & 7383 deletions

comfy/latent_formats.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@ class LatentFormat:
44
scale_factor = 1.0
55
latent_channels = 4
66
latent_dimensions = 2
7-
preserve_empty_channel_multiples = False
87
latent_rgb_factors = None
98
latent_rgb_factors_bias = None
109
latent_rgb_factors_reshape = None
@@ -780,10 +779,6 @@ class ACEAudio(LatentFormat):
780779
latent_channels = 8
781780
latent_dimensions = 2
782781

783-
class SeedVR2(LatentFormat):
784-
latent_channels = 16
785-
preserve_empty_channel_multiples = True
786-
787782
class ACEAudio15(LatentFormat):
788783
latent_channels = 64
789784
latent_dimensions = 1

comfy/ldm/modules/attention.py

Lines changed: 2 additions & 82 deletions
Original file line number | Diff line number | Diff line change
@@ -735,86 +735,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
735735
)
736736
return out
737737

738-
def _var_attention_qkv(q, k, v, heads, skip_reshape):
739-
if skip_reshape:
740-
return q, k, v, q.shape[-1]
741-
total_tokens, embed_dim = q.shape
742-
head_dim = embed_dim // heads
743-
return (
744-
q.view(total_tokens, heads, head_dim),
745-
k.view(k.shape[0], heads, head_dim),
746-
v.view(v.shape[0], heads, head_dim),
747-
head_dim,
748-
)
749738

750-
751-
def _var_attention_output(out, heads, head_dim, skip_output_reshape):
752-
if skip_output_reshape:
753-
return out
754-
return out.reshape(-1, heads * head_dim)
755-
756-
757-
def _use_blackwell_attention():
758-
device = model_management.get_torch_device()
759-
if device.type != "cuda":
760-
return False
761-
major, minor = torch.cuda.get_device_capability(device)
762-
return (major, minor) >= (12, 0)
763-
764-
765-
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
766-
if cu_seqlens.dtype not in (torch.int32, torch.int64):
767-
raise ValueError(f"{name} must use an integer dtype")
768-
if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2:
769-
raise ValueError(f"{name} must be a 1D tensor with at least two offsets")
770-
if cu_seqlens[0].item() != 0:
771-
raise ValueError(f"{name} must start at 0")
772-
if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item():
773-
raise ValueError(f"{name} must be strictly increasing")
774-
if cu_seqlens[-1].item() != token_count:
775-
raise ValueError(f"{name} does not match token count")
776-
777-
778-
def _split_indices(cu_seqlens):
779-
return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)
780-
781-
782-
def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
783-
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
784-
785-
_validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
786-
_validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0])
787-
if cu_seqlens_k[-1].item() != v.shape[0]:
788-
raise ValueError("cu_seqlens_k does not match v token count")
789-
790-
q_split_indices = _split_indices(cu_seqlens_q)
791-
k_split_indices = _split_indices(cu_seqlens_k)
792-
q_splits = torch.tensor_split(q, q_split_indices, dim=0)
793-
k_splits = torch.tensor_split(k, k_split_indices, dim=0)
794-
v_splits = torch.tensor_split(v, k_split_indices, dim=0)
795-
if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits):
796-
raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count")
797-
798-
out = []
799-
for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits):
800-
q_i = q_i.permute(1, 0, 2).unsqueeze(0)
801-
k_i = k_i.permute(1, 0, 2).unsqueeze(0)
802-
v_i = v_i.permute(1, 0, 2).unsqueeze(0)
803-
out_dtype = q_i.dtype
804-
if optimized_attention is attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16):
805-
q_i = q_i.to(torch.bfloat16)
806-
k_i = k_i.to(torch.bfloat16)
807-
v_i = v_i.to(torch.bfloat16)
808-
out_i = optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True)
809-
if out_i.dtype != out_dtype:
810-
out_i = out_i.to(out_dtype)
811-
out.append(out_i.squeeze(0).permute(1, 0, 2))
812-
813-
out = torch.cat(out, dim=0)
814-
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
815-
816-
817-
optimized_var_attention = var_attention_optimized_split
818739
optimized_attention = attention_basic
819740

820741
if model_management.sage_attention_enabled():
@@ -837,8 +758,6 @@ def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *a
837758
logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
838759
optimized_attention = attention_sub_quad
839760

840-
logging.info("Using optimized_attention split-loop for variable-length attention")
841-
842761
optimized_attention_masked = optimized_attention
843762

844763

@@ -854,7 +773,6 @@ def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *a
854773
register_attention_function("pytorch", attention_pytorch)
855774
register_attention_function("sub_quad", attention_sub_quad)
856775
register_attention_function("split", attention_split)
857-
register_attention_function("var_attention_optimized_split", var_attention_optimized_split)
858776

859777

860778
def optimized_attention_for_device(device, mask=False, small_input=False):
@@ -1291,3 +1209,5 @@ def forward(
12911209
x = self.proj_out(x)
12921210
out = x + x_in
12931211
return out
1212+
1213+

comfy/ldm/modules/diffusionmodules/model.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313
import xformers
1414
import xformers.ops
1515

16-
1716
def torch_cat_if_needed(xl, dim):
1817
xl = [x for x in xl if x is not None and x.shape[dim] > 0]
1918
if len(xl) > 1:
@@ -23,8 +22,7 @@ def torch_cat_if_needed(xl, dim):
2322
else:
2423
return None
2524

26-
27-
def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
25+
def get_timestep_embedding(timesteps, embedding_dim):
2826
"""
2927
This matches the implementation in Denoising Diffusion Probabilistic Models:
3028
From Fairseq.
@@ -35,13 +33,11 @@ def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, down
3533
assert len(timesteps.shape) == 1
3634

3735
half_dim = embedding_dim // 2
38-
emb = math.log(10000) / (half_dim - downscale_freq_shift)
36+
emb = math.log(10000) / (half_dim - 1)
3937
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
4038
emb = emb.to(device=timesteps.device)
4139
emb = timesteps.float()[:, None] * emb[None, :]
4240
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
43-
if flip_sin_to_cos:
44-
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
4541
if embedding_dim % 2 == 1: # zero pad
4642
emb = torch.nn.functional.pad(emb, (0,1,0,0))
4743
return emb

0 commit comments

Comments
 (0)