Skip to content

Commit c69af65

Browse files
Uncap cosmos predict2 res and fix mem estimation. (Comfy-Org#8518)
1 parent 251f54a commit c69af65

2 files changed

Lines changed: 11 additions & 10 deletions

File tree

comfy/ldm/cosmos/position_embedding.py

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,6 @@ def __init__(
7272
):
7373
del kwargs
7474
super().__init__()
75-
self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device))
7675
self.base_fps = base_fps
7776
self.max_h = len_h
7877
self.max_w = len_w
@@ -134,21 +133,19 @@ def generate_embeddings(
134133
temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device))
135134

136135
B, T, H, W, _ = B_T_H_W_C
136+
seq = torch.arange(max(H, W, T), dtype=torch.float, device=device)
137137
uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max())
138138
assert (
139139
uniform_fps or B == 1 or T == 1
140140
), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
141-
assert (
142-
H <= self.max_h and W <= self.max_w
143-
), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
144-
half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs)
145-
half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs)
141+
half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs)
142+
half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs)
146143

147144
# apply sequence scaling in temporal dimension
148145
if fps is None or self.enable_fps_modulation is False: # image case
149-
half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs)
146+
half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs)
150147
else:
151-
half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
148+
half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs)
152149

153150
half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1)
154151
half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1)

comfy/supported_models.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -923,9 +923,13 @@ class CosmosT2IPredict2(supported_models_base.BASE):
923923
unet_extra_config = {}
924924
latent_format = latent_formats.Wan21
925925

926-
memory_usage_factor = 1.6 #TODO
926+
memory_usage_factor = 1.0
927927

928-
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO
928+
supported_inference_dtypes = [torch.bfloat16, torch.float32]
929+
930+
def __init__(self, unet_config):
931+
super().__init__(unet_config)
932+
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9
929933

930934
def get_model(self, state_dict, prefix="", device=None):
931935
out = model_base.CosmosPredict2(self, device=device)

0 commit comments

Comments
 (0)