Skip to content

Commit 987a937

Browse files
authored
Support context window for PiD and fix lq_latent rounding (Comfy-Org#14136)
1 parent 51ef17e commit 987a937

2 files changed

Lines changed: 20 additions & 2 deletions

File tree

comfy/ldm/pixeldit/pid.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,9 @@ def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_
207207
f"Flux1/SD3 = 16 channels, Flux2 = 128 channels."
208208
)
209209
B = x.shape[0]
210-
Hs = x.shape[2] // self.patch_size
211-
Ws = x.shape[3] // self.patch_size
210+
# Match the backbone's pad_to_patch_size (round up) so the LQ grid lines up with the patch stream.
211+
Hs = -(-x.shape[2] // self.patch_size)
212+
Ws = -(-x.shape[3] // self.patch_size)
212213

213214
degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1)
214215
if degrade_sigma.numel() == 1 and B > 1:

comfy/model_base.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1428,6 +1428,23 @@ def extra_conds(self, **kwargs):
14281428
out["degrade_sigma"] = comfy.conds.CONDRegular(degrade_sigma)
14291429
return out
14301430

1431+
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Slice the ``lq_latent`` conditioning so it matches a context window.

    When ``cond_key`` is ``"lq_latent"`` and ``cond_value.cond`` is a tensor,
    every index in ``window.index_list`` is mapped onto the ``window.dim``
    axis of that tensor by floor-dividing it by
    ``lq_proj.sr_scale * lq_proj.latent_spatial_down_factor`` (both read from
    ``self.diffusion_model.lq_proj``). The unique, in-bounds results are used
    to select slices of the tensor along ``window.dim``.

    Args:
        cond_key: Name of the conditioning entry being resized.
        cond_value: Conditioning object; for the ``lq_latent`` path it must
            expose a tensor ``cond`` attribute and a ``_copy_with`` method.
        window: Context window exposing ``dim`` and ``index_list``.
        x_in: Forwarded untouched to the parent implementation.
        device: Device the selected tensor is moved to.
        retain_index_list: Forwarded untouched to the parent implementation.

    Returns:
        ``cond_value._copy_with(...)`` wrapping the selected slices, ``None``
        when ``window.dim`` is not an axis of the tensor or no mapped index
        lands inside it, or the parent implementation's result for any other
        conditioning.
    """
    # NOTE: the `[]` default is kept for interface compatibility; this method
    # only forwards the list and never mutates it.
    handles_lq = (
        cond_key == "lq_latent"
        and hasattr(cond_value, "cond")
        and isinstance(cond_value.cond, torch.Tensor)
    )
    if not handles_lq:
        return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)

    lq = cond_value.cond
    dim = window.dim
    if dim >= lq.ndim:
        return None

    lq_proj = self.diffusion_model.lq_proj
    ratio = lq_proj.sr_scale * lq_proj.latent_spatial_down_factor
    lq_size = lq.size(dim)

    # Map x window indices -> lq indices. The quotient is computed once per
    # index and coerced to int so a float ratio cannot leak float values into
    # the index list (tensor indexing rejects them).
    mapped = {int(i // ratio) for i in window.index_list}
    lq_indices = sorted(q for q in mapped if 0 <= q < lq_size)
    if not lq_indices:
        return None

    # Full slices for the leading axes, then the selected indices on `dim`.
    idx = tuple([slice(None)] * dim + [lq_indices])
    return cond_value._copy_with(lq[idx].to(device))
1447+
14311448

14321449
class WAN21(BaseModel):
14331450
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):

0 commit comments

Comments
 (0)