
support cogvideox i2v & adapt to the changes in commit 607f968 #59


Open · wants to merge 2 commits into main
2 changes: 2 additions & 0 deletions README.md
@@ -29,6 +29,8 @@

✅ [CogVideoX1.5-5B-T2V](https://huggingface.co/THUDM/CogVideoX1.5-5B)

✅ [CogVideoX1.5-5B-I2V](https://huggingface.co/THUDM/CogVideoX1.5-5B-I2V)

## How to Run

Please refer to the [documentation](https://github.com/ModelTC/lightx2v/tree/main/docs) in lightx2v.
42 changes: 42 additions & 0 deletions configs/cogvideox_i2v.json
@@ -0,0 +1,42 @@
{
"seed": 42,
"text_len": 226,
"num_videos_per_prompt": 1,
"target_video_length": 81,
"num_inference_steps": 50,
"num_train_timesteps": 1000,
"timestep_spacing": "trailing",
"steps_offset": 0,
"latent_channels": 16,
"height": 768,
"width": 1360,
"vae_scale_factor_temporal": 4,
"vae_scale_factor_spatial": 8,
"vae_scaling_factor_image": 0.7,
"vae_invert_scale_latents": true,
"batch_size": 1,
"patch_size": 2,
"patch_size_t": 2,
"guidance_scale": 0,
"use_rotary_positional_embeddings": true,
"do_classifier_free_guidance": false,
"transformer_sample_width": 170,
"transformer_sample_height": 96,
"transformer_sample_frames": 81,
"transformer_attention_head_dim": 64,
"transformer_num_attention_heads": 48,
"transformer_temporal_compression_ratio": 4,
"transformer_temporal_interpolation_scale": 1.0,
"transformer_use_learned_positional_embeddings": false,
"transformer_spatial_interpolation_scale": 1.875,
"transformer_num_layers": 42,
"transformer_ofs_embed_dim": 512,
"beta_schedule": "scaled_linear",
"scheduler_beta_start": 0.00085,
"scheduler_beta_end": 0.012,
"scheduler_set_alpha_to_one": true,
"scheduler_snr_shift_scale": 1.0,
"scheduler_rescale_betas_zero_snr": true,
"scheduler_prediction_type": "v_prediction",
"use_dynamic_cfg": true
}
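For orientation, here is a minimal sketch of how the latent dimensions used later in the runner follow from these values, assuming the file is loaded into a plain dict named `config` (the load path and printout are illustrative, not part of the PR):

```python
import json

# Hypothetical path; adjust to wherever the config lives in your checkout.
with open("configs/cogvideox_i2v.json") as f:
    config = json.load(f)

# Temporal: 81 video frames with 4x temporal compression -> (81 - 1) // 4 + 1 = 21 latent frames
# (the runner later pads this to a multiple of patch_size_t).
latent_frames = (config["target_video_length"] - 1) // config["vae_scale_factor_temporal"] + 1

# Spatial: 768 x 1360 pixels with 8x spatial compression -> 96 x 170 latents,
# which matches transformer_sample_height x transformer_sample_width above.
latent_height = config["height"] // config["vae_scale_factor_spatial"]
latent_width = config["width"] // config["vae_scale_factor_spatial"]

print(latent_frames, latent_height, latent_width)  # 21 96 170
```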
16 changes: 15 additions & 1 deletion lightx2v/models/networks/cogvideox/infer/pre_infer.py
@@ -33,7 +33,7 @@ def _get_positional_embeddings(self, sample_height, sample_width, sample_frames,

return joint_pos_embedding

def infer(self, weights, hidden_states, timestep, encoder_hidden_states):
def infer(self, weights, hidden_states, timestep, encoder_hidden_states, ofs=None):
t_emb = get_timestep_embedding(
timestep,
self.inner_dim,
@@ -46,6 +46,20 @@ def infer(self, weights, hidden_states, timestep, encoder_hidden_states):
sample = torch.nn.functional.silu(sample)
emb = weights.time_embedding_linear_2.apply(sample)

if ofs is not None:
ofs_emb = get_timestep_embedding(
ofs,
self.config.transformer_ofs_embed_dim,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.freq_shift,
scale=self.scale,
)
ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
ofs_sample = weights.ofs_embedding_linear_1.apply(ofs_emb)
ofs_sample = torch.nn.functional.silu(ofs_sample)
ofs_emb = weights.ofs_embedding_linear_2.apply(ofs_sample)
emb = emb + ofs_emb

text_embeds = weights.patch_embed_text_proj.apply(encoder_hidden_states)
num_frames, channels, height, width = hidden_states.shape
infer_shapes = (num_frames, channels, height, width)
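For readers unfamiliar with the OFS path added above, here is a standalone sketch of the same conditioning pattern: a sinusoidal embedding of the scalar ofs value followed by a two-layer SiLU MLP whose output is added to the time embedding. The linear widths (512 -> 512), the freshly initialized layers, and the flip_sin_to_cos/freq_shift values (taken from the diffusers CogVideoX defaults) are illustrative assumptions; the real shapes come from the ofs_embedding weights loaded in pre_weights.py.

```python
import torch
from diffusers.models.embeddings import get_timestep_embedding

ofs_embed_dim = 512  # transformer_ofs_embed_dim in configs/cogvideox_i2v.json

# ofs is a single scalar per batch; model.py fills it with 2.0 for I2V.
ofs = torch.full((1,), 2.0)
ofs_emb = get_timestep_embedding(ofs, ofs_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0)

# Two-layer MLP with SiLU in between, mirroring ofs_embedding_linear_1/2 (random weights here).
linear_1 = torch.nn.Linear(ofs_embed_dim, ofs_embed_dim)
linear_2 = torch.nn.Linear(ofs_embed_dim, ofs_embed_dim)
ofs_emb = linear_2(torch.nn.functional.silu(linear_1(ofs_emb)))

print(ofs_emb.shape)  # torch.Size([1, 512]); pre_infer then does emb = emb + ofs_emb
```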
15 changes: 11 additions & 4 deletions lightx2v/models/networks/cogvideox/model.py
@@ -84,15 +84,23 @@ def infer(self, inputs):
t = self.scheduler.timesteps[self.scheduler.step_index]
text_encoder_output = inputs["text_encoder_output"]["context"]
do_classifier_free_guidance = self.config.guidance_scale > 1.0
latent_model_input = self.scheduler.latents
latent_model_input = torch.cat([self.scheduler.latents] * 2) if do_classifier_free_guidance else self.scheduler.latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
if self.config.task in ["i2v"]:
ofs_emb = None if self.config.transformer_ofs_embed_dim is None else latent_model_input.new_full((1,), fill_value=2.0)
Review comment (Contributor, severity: medium): The ofs_emb is initialized with fill_value=2.0 when self.config.transformer_ofs_embed_dim is not None. Could you clarify the significance of this specific value 2.0? If it's a critical parameter, consider defining it as a named constant or making it configurable for better readability and maintainability.
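One way to act on this, sketched as a small helper; the function name make_ofs_emb and the config key "ofs_emb_fill_value" are hypothetical, not existing lightx2v APIs:

```python
import torch

# Hypothetical refactor: the current diff hardcodes 2.0 inline in model.py.
OFS_EMB_FILL_VALUE = 2.0

def make_ofs_emb(latent_model_input: torch.Tensor, ofs_embed_dim, fill_value: float = OFS_EMB_FILL_VALUE):
    """Return the scalar OFS conditioning tensor, or None when the model has no OFS embedding."""
    if ofs_embed_dim is None:
        return None
    return latent_model_input.new_full((1,), fill_value=fill_value)

# Usage inside infer(), with the fill value optionally read from config:
# ofs_emb = make_ofs_emb(latent_model_input, self.config.transformer_ofs_embed_dim,
#                        self.config.get("ofs_emb_fill_value", OFS_EMB_FILL_VALUE))
```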

image_latents = self.scheduler.image_latents
latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
else:
ofs_emb = None
timestep = t.expand(latent_model_input.shape[0])

hidden_states, encoder_hidden_states, emb, infer_shapes = self.pre_infer.infer(
self.pre_weight,
latent_model_input[0],
timestep,
text_encoder_output[0],
ofs=ofs_emb,
)

hidden_states, encoder_hidden_states = self.transformer_infer.infer(
@@ -103,13 +111,12 @@
)

noise_pred = self.post_infer.infer(self.post_weight, hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=emb, infer_shapes=infer_shapes)

noise_pred = noise_pred.float()

if self.config.use_dynamic_cfg: # True
if self.config.use_dynamic_cfg:
self.scheduler.guidance_scale = 1 + self.scheduler.guidance_scale * ((1 - math.cos(math.pi * ((self.scheduler.infer_steps - t.item()) / self.scheduler.infer_steps) ** 5.0)) / 2)

if do_classifier_free_guidance: # False
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.scheduler.guidance_scale * (noise_pred_text - noise_pred_uncond)
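Since the dynamic CFG line above packs the whole schedule into one expression, here is a standalone sketch of the same formula; guidance_scale = 6 and infer_steps = 50 are illustrative, and t is swept from infer_steps down to 0 purely to show the shape of the ramp (in the diff, t is the scheduler timestep):

```python
import math

def dynamic_guidance(guidance_scale: float, t: float, infer_steps: int) -> float:
    # Same cosine ramp as in model.py: guidance grows from 1 toward 1 + guidance_scale
    # as (infer_steps - t) / infer_steps goes from 0 to 1.
    return 1 + guidance_scale * ((1 - math.cos(math.pi * ((infer_steps - t) / infer_steps) ** 5.0)) / 2)

for t in (50, 40, 25, 10, 0):
    print(t, round(dynamic_guidance(6.0, t, 50), 3))
```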

6 changes: 5 additions & 1 deletion lightx2v/models/networks/cogvideox/weights/pre_weights.py
@@ -12,9 +12,13 @@ def load_weights(self, weight_dict):
self.time_embedding_linear_2 = MM_WEIGHT_REGISTER["Default"]("time_embedding.linear_2.weight", "time_embedding.linear_2.bias")
self.patch_embed_proj = MM_WEIGHT_REGISTER["Default"]("patch_embed.proj.weight", "patch_embed.proj.bias")
self.patch_embed_text_proj = MM_WEIGHT_REGISTER["Default"]("patch_embed.text_proj.weight", "patch_embed.text_proj.bias")

self.weight_list = [self.time_embedding_linear_1, self.time_embedding_linear_2, self.patch_embed_proj, self.patch_embed_text_proj]

if "ofs_embed_dim" in self.config:
self.ofs_embedding_linear_1 = MM_WEIGHT_REGISTER["Default"]("ofs_embedding.linear_1.weight", "ofs_embedding.linear_1.bias")
self.ofs_embedding_linear_2 = MM_WEIGHT_REGISTER["Default"]("ofs_embedding.linear_2.weight", "ofs_embedding.linear_2.bias")
self.weight_list += [self.ofs_embedding_linear_1, self.ofs_embedding_linear_2]

for mm_weight in self.weight_list:
mm_weight.set_config(self.config)
mm_weight.load(weight_dict)
93 changes: 54 additions & 39 deletions lightx2v/models/runners/cogvideox/cogvidex_runner.py
@@ -1,6 +1,7 @@
from diffusers.utils import export_to_video
from diffusers.utils import export_to_video, load_image
import imageio
import numpy as np
import torch

from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.default_runner import DefaultRunner
@@ -16,27 +17,19 @@ class CogvideoxRunner(DefaultRunner):
def __init__(self, config):
super().__init__(config)

def load_transformer(self, init_device):
model = CogvideoxModel(self.config)
return model

def load_image_encoder(self, init_device):
return None

def load_text_encoder(self, init_device):
@ProfilingContext("Load models")
def load_model(self):
text_encoder = T5EncoderModel_v1_1_xxl(self.config)
text_encoders = [text_encoder]
return text_encoders

def load_vae(self, init_device):
vae_model = CogvideoxVAE(self.config)
return vae_model, vae_model
self.text_encoders = [text_encoder]
self.model = CogvideoxModel(self.config)
self.vae_model = CogvideoxVAE(self.config)
image_encoder = None

def init_scheduler(self):
scheduler = CogvideoxXDPMScheduler(self.config)
self.model.set_scheduler(scheduler)

def run_text_encoder(self, text, img):
def run_text_encoder(self, text):
text_encoder_output = {}
n_prompt = self.config.get("negative_prompt", "")
context = self.text_encoders[0].infer([text], self.config)
@@ -45,38 +38,60 @@ def run_text_encoder(self, text, img):
text_encoder_output["context_null"] = context_null
return text_encoder_output

def run_vae_encoder(self, img):
# TODO: implement vae encoder for Cogvideox
raise NotImplementedError("I2V inference is not implemented for Cogvideox.")
def run_image_encoder(self, config, vae_model):
image = load_image(image=self.config.image_path)
image = vae_model.video_processor.preprocess(image, height=config.height, width=config.width).to(torch.device("cuda"), dtype=torch.bfloat16)
Review comment (Contributor, severity: medium): The device and dtype for image preprocessing are hardcoded here to torch.device("cuda") and torch.bfloat16. While this might be the common case, would it be beneficial to align this with a global device/dtype configuration from self.config for greater flexibility, similar to how init_device is determined in get_init_device?
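A possible direction, sketched as a standalone helper; the config keys "device" and "dtype" are hypothetical and would need to match however lightx2v actually exposes a global device/dtype setting:

```python
import torch

def preprocess_image(vae_model, config, image):
    # Read device/dtype from config instead of hardcoding cuda/bfloat16,
    # falling back to the current behavior when the keys are absent.
    device = torch.device(config.get("device", "cuda"))
    dtype = getattr(torch, config.get("dtype", "bfloat16"))
    return vae_model.video_processor.preprocess(image, height=config.height, width=config.width).to(device, dtype=dtype)
```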

image = image.unsqueeze(2)
image = [vae_model.encode(img.unsqueeze(0)) for img in image]
return image

def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img):
# TODO: Implement image encoder for Cogvideox-I2V
raise ValueError(f"Unsupported model class: {self.config['model_cls']}")
@ProfilingContext("Run Encoders")
async def run_input_encoder_local_t2v(self):
prompt = self.config["prompt"]
text_encoder_output = self.run_text_encoder(prompt)
return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}

@ProfilingContext("Run Encoders")
async def run_input_encoder_local_i2v(self):
image_encoder_output = self.run_image_encoder(self.config, self.vae_model)
prompt = self.config["prompt"]
text_encoder_output = self.run_text_encoder(prompt)
return {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}

@ProfilingContext("Run VAE Decoder")
async def run_vae_decoder_local(self, latents, generator):
images = self.vae_model.decode(latents, generator=generator, config=self.config)
return images

def set_target_shape(self):
ret = {}
if self.config.task == "i2v":
# TODO: implement set_target_shape for Cogvideox-I2V
raise NotImplementedError("I2V inference is not implemented for Cogvideox.")
else:
num_frames = self.config.target_video_length
latent_frames = (num_frames - 1) // self.config.vae_scale_factor_temporal + 1
additional_frames = 0
patch_size_t = self.config.patch_size_t
if patch_size_t is not None and latent_frames % patch_size_t != 0:
additional_frames = patch_size_t - latent_frames % patch_size_t
num_frames += additional_frames * self.config.vae_scale_factor_temporal
self.config.target_shape = (
num_frames = self.config.target_video_length
latent_frames = (num_frames - 1) // self.config.vae_scale_factor_temporal + 1
additional_frames = 0
patch_size_t = self.config.patch_size_t
if patch_size_t is not None and latent_frames % patch_size_t != 0:
additional_frames = patch_size_t - latent_frames % patch_size_t
num_frames += additional_frames * self.config.vae_scale_factor_temporal
target_shape = (
self.config.batch_size,
(num_frames - 1) // self.config.vae_scale_factor_temporal + 1,
self.config.latent_channels,
self.config.height // self.config.vae_scale_factor_spatial,
self.config.width // self.config.vae_scale_factor_spatial,
)
if self.config.task in ["t2v"]:
self.config.target_shape = target_shape
elif self.config.task in ["i2v"]:
self.config.target_shape = target_shape[:1] + (target_shape[1] + target_shape[1] % self.config.patch_size_t,) + target_shape[2:]
Review comment (Contributor, severity: medium): The frame dimension of target_shape for I2V is adjusted using target_shape[1] + target_shape[1] % self.config.patch_size_t.

Given that target_shape[1] (derived from num_frames via the padding logic just above) should already be a multiple of self.config.patch_size_t, target_shape[1] % self.config.patch_size_t would be 0, making this addition redundant.

If self.config.patch_size_t is always 2 (as in cogvideox_i2v.json), this logic effectively ensures the dimension is even (L if L is even, L+1 if L is odd).

Could you clarify the intent here? Is this an additional padding requirement specific to CogVideoX I2V, or is there a scenario where target_shape[1] might not be a multiple of patch_size_t at this point?
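A quick numeric check of the point above (standalone, with illustrative values only):

```python
# new_len = L + L % P: for P = 2 this only makes the length even;
# for other P it does not, in general, land on a multiple of P.
for L, P in [(21, 2), (22, 2), (80, 3)]:
    new_len = L + L % P
    print(f"L={L}, P={P} -> {new_len} (multiple of P: {new_len % P == 0})")
```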

self.config.padding_shape = (
self.config.batch_size,
(num_frames - 1) // self.config.vae_scale_factor_temporal + 1,
(num_frames - 1) // self.config.vae_scale_factor_temporal,
self.config.latent_channels,
self.config.height // self.config.vae_scale_factor_spatial,
self.config.width // self.config.vae_scale_factor_spatial,
)
ret["target_shape"] = self.config.target_shape
return ret
return None

def save_video_func(self, images):
def save_video(self, images):
with imageio.get_writer(self.config.save_video_path, fps=16) as writer:
for pil_image in images:
frame_np = np.array(pil_image, dtype=np.uint8)
33 changes: 33 additions & 0 deletions lightx2v/models/schedulers/cogvideox/scheduler.py
@@ -56,6 +56,18 @@ def rescale_zero_terminal_snr(alphas_cumprod):
return alphas_bar


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(encoder_output, generator=None, sample_mode="sample"):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
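A toy usage example of retrieve_latents with a diffusers-style VAE output; a randomly initialized AutoencoderKL stands in for the CogVideoX VAE here, purely for illustration:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL()                        # default, randomly initialized config
img = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    posterior = vae.encode(img)              # AutoencoderKLOutput exposing .latent_dist
latents = retrieve_latents(posterior, generator=torch.Generator().manual_seed(0))
print(latents.shape)
```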


class CogvideoxXDPMScheduler(BaseScheduler):
def __init__(self, config):
self.config = config
@@ -133,6 +145,8 @@ def set_timesteps(self):

def prepare(self, image_encoder_output):
self.image_encoder_output = image_encoder_output
if self.config.task in ["i2v"]:
self.prepare_image_latents(image=image_encoder_output, padding_shape=self.config.padding_shape, dtype=torch.bfloat16)
self.prepare_latents(shape=self.config.target_shape, dtype=torch.bfloat16)
self.prepare_guidance()
self.prepare_rotary_pos_embedding()
@@ -144,6 +158,25 @@ def prepare_latents(self, shape, dtype):
self.latents = latents
self.old_pred_original_sample = None

def prepare_image_latents(self, image, padding_shape, dtype):
image_latents = [retrieve_latents(img, self.generator) for img in image]
image_latents = torch.cat(image_latents, dim=0).to(torch.bfloat16).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]

if not self.config.vae_invert_scale_latents:
image_latents = self.config.vae_scaling_factor_image * image_latents
else:
# This is awkward but required because the CogVideoX team forgot to multiply the
# scaling factor during training :)
image_latents = 1 / self.config.vae_scaling_factor_image * image_latents

latent_padding = torch.zeros(padding_shape, device=torch.device("cuda"), dtype=torch.bfloat16)
image_latents = torch.cat([image_latents, latent_padding], dim=1)
# Select the first frame along the second dimension
if self.config.patch_size_t is not None:
first_frame = image_latents[:, : image_latents.size(1) % self.config.patch_size_t, ...]
image_latents = torch.cat([first_frame, image_latents], dim=1)
Review comment on lines +175 to +177 (Contributor, severity: medium): This padding logic (first_frame = image_latents[:, : image_latents.size(1) % self.config.patch_size_t, ...]; image_latents = torch.cat([first_frame, image_latents], dim=1)) results in a new length L_new = L + (L % P), where L is the current number of frames and P is patch_size_t.

If patch_size_t is 2 (as per cogvideox_i2v.json), this makes the frame dimension even. This is the same pattern observed in the runner's set_target_shape for I2V.

Is this specific way of padding (making the frame count L + L % P) a requirement for subsequent operations in CogVideoX, particularly when patch_size_t might not be 2? If patch_size_t could be other values (e.g. 3), L + L % P (e.g. 80 + 80 % 3 = 82) does not make the length a multiple of P (82 is not divisible by 3). Clarifying the rationale would be helpful.
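If the intent is to guarantee a multiple of patch_size_t for any value of P (rather than the current L + L % P behavior), a ceil-style pad is the usual pattern. The sketch below illustrates that alternative under a hypothetical helper name; it is not the PR's current behavior:

```python
import torch

def pad_frames_to_multiple(image_latents: torch.Tensor, patch_size_t: int) -> torch.Tensor:
    # image_latents: [B, F, C, H, W]. Repeat the first frame until F is a multiple of patch_size_t.
    num_frames = image_latents.size(1)
    pad = (-num_frames) % patch_size_t  # frames needed to reach the next multiple
    if pad:
        first = image_latents[:, :1].repeat(1, pad, 1, 1, 1)
        image_latents = torch.cat([first, image_latents], dim=1)
    return image_latents
```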

self.image_latents = image_latents

def prepare_guidance(self):
self.guidance_scale = self.config.guidance_scale

lightx2v/models/video_encoders/hf/cogvideox/autoencoder_ks_cogvidex.py
@@ -1126,7 +1126,6 @@ def disable_slicing(self) -> None:

def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape

if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)

@@ -1169,7 +1168,6 @@ def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[Autoencoder
h = torch.cat(encoded_slices)
else:
h = self._encode(x)

posterior = DiagonalGaussianDistribution(h)

if not return_dict:
18 changes: 12 additions & 6 deletions lightx2v/models/video_encoders/hf/cogvideox/model.py
@@ -1,8 +1,8 @@
import os
import glob
import torch # type: ignore
from safetensors import safe_open # type: ignore
from diffusers.video_processor import VideoProcessor # type: ignore
import torch
from safetensors import safe_open
from diffusers.video_processor import VideoProcessor

from lightx2v.models.video_encoders.hf.cogvideox.autoencoder_ks_cogvidex import AutoencoderKLCogVideoX

@@ -34,13 +34,19 @@ def load(self):
self.vae_config = AutoencoderKLCogVideoX.load_config(vae_path)
self.model = AutoencoderKLCogVideoX.from_config(self.vae_config)
vae_ckpt = self._load_ckpt(vae_path)
self.vae_scale_factor_spatial = 2 ** (len(self.vae_config["block_out_channels"]) - 1) # 8
self.vae_scale_factor_temporal = self.vae_config["temporal_compression_ratio"] # 4
self.vae_scaling_factor_image = self.vae_config["scaling_factor"] # 0.7
self.vae_scale_factor_spatial = 2 ** (len(self.vae_config["block_out_channels"]) - 1)
self.vae_scale_factor_temporal = self.vae_config["temporal_compression_ratio"]
self.vae_scaling_factor_image = self.vae_config["scaling_factor"]
self.model.load_state_dict(vae_ckpt)
self.model.to(torch.bfloat16).to(torch.device("cuda"))
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

@torch.no_grad()
def encode(self, x, return_dict=True):
self.model.enable_tiling()
x = self.model.encode(x, return_dict)
return x

@torch.no_grad()
def decode(self, latents, generator, config):
latents = latents.permute(0, 2, 1, 3, 4)