diff --git a/README.md b/README.md
index 71b1007..e6560ff 100644
--- a/README.md
+++ b/README.md
@@ -29,6 +29,8 @@
 
 ✅ [CogVideoX1.5-5B-T2V](https://huggingface.co/THUDM/CogVideoX1.5-5B)
 
+✅ [CogVideoX1.5-5B-I2V](https://huggingface.co/THUDM/CogVideoX1.5-5B-I2V)
+
 ## How to Run
 
 Please refer to the [documentation](https://github.com/ModelTC/lightx2v/tree/main/docs) in lightx2v.
diff --git a/configs/cogvideox_i2v.json b/configs/cogvideox_i2v.json
new file mode 100644
index 0000000..5bf28df
--- /dev/null
+++ b/configs/cogvideox_i2v.json
@@ -0,0 +1,42 @@
+{
+    "seed": 42,
+    "text_len": 226,
+    "num_videos_per_prompt": 1,
+    "target_video_length": 81,
+    "num_inference_steps": 50,
+    "num_train_timesteps": 1000,
+    "timestep_spacing": "trailing",
+    "steps_offset": 0,
+    "latent_channels": 16,
+    "height": 768,
+    "width": 1360,
+    "vae_scale_factor_temporal": 4,
+    "vae_scale_factor_spatial": 8,
+    "vae_scaling_factor_image": 0.7,
+    "vae_invert_scale_latents": true,
+    "batch_size": 1,
+    "patch_size": 2,
+    "patch_size_t": 2,
+    "guidance_scale": 0,
+    "use_rotary_positional_embeddings": true,
+    "do_classifier_free_guidance": false,
+    "transformer_sample_width": 170,
+    "transformer_sample_height": 96,
+    "transformer_sample_frames": 81,
+    "transformer_attention_head_dim": 64,
+    "transformer_num_attention_heads": 48,
+    "transformer_temporal_compression_ratio": 4,
+    "transformer_temporal_interpolation_scale": 1.0,
+    "transformer_use_learned_positional_embeddings": false,
+    "transformer_spatial_interpolation_scale": 1.875,
+    "transformer_num_layers": 42,
+    "transformer_ofs_embed_dim": 512,
+    "beta_schedule": "scaled_linear",
+    "scheduler_beta_start": 0.00085,
+    "scheduler_beta_end": 0.012,
+    "scheduler_set_alpha_to_one": true,
+    "scheduler_snr_shift_scale": 1.0,
+    "scheduler_rescale_betas_zero_snr": true,
+    "scheduler_prediction_type": "v_prediction",
+    "use_dynamic_cfg": true
+}
diff --git a/lightx2v/models/networks/cogvideox/infer/pre_infer.py b/lightx2v/models/networks/cogvideox/infer/pre_infer.py
index 052d5d7..6283122 100644
--- a/lightx2v/models/networks/cogvideox/infer/pre_infer.py
+++ b/lightx2v/models/networks/cogvideox/infer/pre_infer.py
@@ -33,7 +33,7 @@ def _get_positional_embeddings(self, sample_height, sample_width, sample_frames,
 
         return joint_pos_embedding
 
-    def infer(self, weights, hidden_states, timestep, encoder_hidden_states):
+    def infer(self, weights, hidden_states, timestep, encoder_hidden_states, ofs=None):
         t_emb = get_timestep_embedding(
             timestep,
             self.inner_dim,
@@ -46,6 +46,20 @@ def infer(self, weights, hidden_states, timestep, encoder_hidden_states):
         sample = torch.nn.functional.silu(sample)
         emb = weights.time_embedding_linear_2.apply(sample)
 
+        if ofs is not None:
+            ofs_emb = get_timestep_embedding(
+                ofs,
+                self.config.transformer_ofs_embed_dim,
+                flip_sin_to_cos=self.flip_sin_to_cos,
+                downscale_freq_shift=self.freq_shift,
+                scale=self.scale,
+            )
+            ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
+            ofs_sample = weights.ofs_embedding_linear_1.apply(ofs_emb)
+            ofs_sample = torch.nn.functional.silu(ofs_sample)
+            ofs_emb = weights.ofs_embedding_linear_2.apply(ofs_sample)
+            emb = emb + ofs_emb
+
         text_embeds = weights.patch_embed_text_proj.apply(encoder_hidden_states)
         num_frames, channels, height, width = hidden_states.shape
         infer_shapes = (num_frames, channels, height, width)
diff --git a/lightx2v/models/networks/cogvideox/model.py b/lightx2v/models/networks/cogvideox/model.py
index 5e4473d..fe3e358 100644
--- a/lightx2v/models/networks/cogvideox/model.py
+++ b/lightx2v/models/networks/cogvideox/model.py
@@ -84,8 +84,15 @@ def infer(self, inputs):
         t = self.scheduler.timesteps[self.scheduler.step_index]
         text_encoder_output = inputs["text_encoder_output"]["context"]
         do_classifier_free_guidance = self.config.guidance_scale > 1.0
-        latent_model_input = self.scheduler.latents
+        latent_model_input = torch.cat([self.scheduler.latents] * 2) if do_classifier_free_guidance else self.scheduler.latents
         latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+        if self.config.task in ["i2v"]:
+            ofs_emb = None if self.config.transformer_ofs_embed_dim is None else latent_model_input.new_full((1,), fill_value=2.0)
+            image_latents = self.scheduler.image_latents
+            latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
+            latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
+        else:
+            ofs_emb = None
         timestep = t.expand(latent_model_input.shape[0])
 
         hidden_states, encoder_hidden_states, emb, infer_shapes = self.pre_infer.infer(
@@ -93,6 +100,7 @@
             latent_model_input[0],
             timestep,
             text_encoder_output[0],
+            ofs=ofs_emb,
         )
 
         hidden_states, encoder_hidden_states = self.transformer_infer.infer(
@@ -103,13 +111,12 @@
         )
 
         noise_pred = self.post_infer.infer(self.post_weight, hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=emb, infer_shapes=infer_shapes)
-        noise_pred = noise_pred.float()
 
-        if self.config.use_dynamic_cfg:  # True
+        if self.config.use_dynamic_cfg:
             self.scheduler.guidance_scale = 1 + self.scheduler.guidance_scale * ((1 - math.cos(math.pi * ((self.scheduler.infer_steps - t.item()) / self.scheduler.infer_steps) ** 5.0)) / 2)
 
-        if do_classifier_free_guidance:  # False
+        if do_classifier_free_guidance:
             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
             noise_pred = noise_pred_uncond + self.scheduler.guidance_scale * (noise_pred_text - noise_pred_uncond)
diff --git a/lightx2v/models/networks/cogvideox/weights/pre_weights.py b/lightx2v/models/networks/cogvideox/weights/pre_weights.py
index 618c125..a655c9d 100644
--- a/lightx2v/models/networks/cogvideox/weights/pre_weights.py
+++ b/lightx2v/models/networks/cogvideox/weights/pre_weights.py
@@ -12,9 +12,13 @@ def load_weights(self, weight_dict):
         self.time_embedding_linear_2 = MM_WEIGHT_REGISTER["Default"]("time_embedding.linear_2.weight", "time_embedding.linear_2.bias")
         self.patch_embed_proj = MM_WEIGHT_REGISTER["Default"]("patch_embed.proj.weight", "patch_embed.proj.bias")
         self.patch_embed_text_proj = MM_WEIGHT_REGISTER["Default"]("patch_embed.text_proj.weight", "patch_embed.text_proj.bias")
-
         self.weight_list = [self.time_embedding_linear_1, self.time_embedding_linear_2, self.patch_embed_proj, self.patch_embed_text_proj]
 
+        if "ofs_embed_dim" in self.config:
+            self.ofs_embedding_linear_1 = MM_WEIGHT_REGISTER["Default"]("ofs_embedding.linear_1.weight", "ofs_embedding.linear_1.bias")
+            self.ofs_embedding_linear_2 = MM_WEIGHT_REGISTER["Default"]("ofs_embedding.linear_2.weight", "ofs_embedding.linear_2.bias")
+            self.weight_list += [self.ofs_embedding_linear_1, self.ofs_embedding_linear_2]
+
         for mm_weight in self.weight_list:
             mm_weight.set_config(self.config)
             mm_weight.load(weight_dict)
diff --git a/lightx2v/models/runners/cogvideox/cogvidex_runner.py b/lightx2v/models/runners/cogvideox/cogvidex_runner.py
index 40bc22e..30831c6 100644
--- a/lightx2v/models/runners/cogvideox/cogvidex_runner.py
+++ b/lightx2v/models/runners/cogvideox/cogvidex_runner.py
@@ -1,6 +1,7 @@
-from diffusers.utils import export_to_video
+from diffusers.utils import export_to_video, load_image
 import imageio
 import numpy as np
+import torch
 
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.default_runner import DefaultRunner
@@ -16,27 +17,19 @@ class CogvideoxRunner(DefaultRunner):
     def __init__(self, config):
         super().__init__(config)
 
-    def load_transformer(self, init_device):
-        model = CogvideoxModel(self.config)
-        return model
-
-    def load_image_encoder(self, init_device):
-        return None
-
-    def load_text_encoder(self, init_device):
+    @ProfilingContext("Load models")
+    def load_model(self):
         text_encoder = T5EncoderModel_v1_1_xxl(self.config)
-        text_encoders = [text_encoder]
-        return text_encoders
-
-    def load_vae(self, init_device):
-        vae_model = CogvideoxVAE(self.config)
-        return vae_model, vae_model
+        self.text_encoders = [text_encoder]
+        self.model = CogvideoxModel(self.config)
+        self.vae_model = CogvideoxVAE(self.config)
+        image_encoder = None
 
     def init_scheduler(self):
         scheduler = CogvideoxXDPMScheduler(self.config)
         self.model.set_scheduler(scheduler)
 
-    def run_text_encoder(self, text, img):
+    def run_text_encoder(self, text):
         text_encoder_output = {}
         n_prompt = self.config.get("negative_prompt", "")
         context = self.text_encoders[0].infer([text], self.config)
@@ -45,38 +38,60 @@ def run_text_encoder(self, text, img):
         text_encoder_output["context_null"] = context_null
         return text_encoder_output
 
-    def run_vae_encoder(self, img):
-        # TODO: implement vae encoder for Cogvideox
-        raise NotImplementedError("I2V inference is not implemented for Cogvideox.")
+    def run_image_encoder(self, config, vae_model):
+        image = load_image(image=self.config.image_path)
+        image = vae_model.video_processor.preprocess(image, height=config.height, width=config.width).to(torch.device("cuda"), dtype=torch.bfloat16)
+        image = image.unsqueeze(2)
+        image = [vae_model.encode(img.unsqueeze(0)) for img in image]
+        return image
 
-    def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img):
-        # TODO: Implement image encoder for Cogvideox-I2V
-        raise ValueError(f"Unsupported model class: {self.config['model_cls']}")
+    @ProfilingContext("Run Encoders")
+    async def run_input_encoder_local_t2v(self):
+        prompt = self.config["prompt"]
+        text_encoder_output = self.run_text_encoder(prompt)
+        return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}
+
+    @ProfilingContext("Run Encoders")
+    async def run_input_encoder_local_i2v(self):
+        image_encoder_output = self.run_image_encoder(self.config, self.vae_model)
+        prompt = self.config["prompt"]
+        text_encoder_output = self.run_text_encoder(prompt)
+        return {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
+
+    @ProfilingContext("Run VAE Decoder")
+    async def run_vae_decoder_local(self, latents, generator):
+        images = self.vae_model.decode(latents, generator=generator, config=self.config)
+        return images
 
     def set_target_shape(self):
-        ret = {}
-        if self.config.task == "i2v":
-            # TODO: implement set_target_shape for Cogvideox-I2V
-            raise NotImplementedError("I2V inference is not implemented for Cogvideox.")
-        else:
-            num_frames = self.config.target_video_length
-            latent_frames = (num_frames - 1) // self.config.vae_scale_factor_temporal + 1
-            additional_frames = 0
-            patch_size_t = self.config.patch_size_t
-            if patch_size_t is not None and latent_frames % patch_size_t != 0:
-                additional_frames = patch_size_t - latent_frames % patch_size_t
-                num_frames += additional_frames * self.config.vae_scale_factor_temporal
-            self.config.target_shape = (
+        num_frames = self.config.target_video_length
+        latent_frames = (num_frames - 1) // self.config.vae_scale_factor_temporal + 1
+        additional_frames = 0
+        patch_size_t = self.config.patch_size_t
+        if patch_size_t is not None and latent_frames % patch_size_t != 0:
+            additional_frames = patch_size_t - latent_frames % patch_size_t
+            num_frames += additional_frames * self.config.vae_scale_factor_temporal
+        target_shape = (
+            self.config.batch_size,
+            (num_frames - 1) // self.config.vae_scale_factor_temporal + 1,
+            self.config.latent_channels,
+            self.config.height // self.config.vae_scale_factor_spatial,
+            self.config.width // self.config.vae_scale_factor_spatial,
+        )
+        if self.config.task in ["t2v"]:
+            self.config.target_shape = target_shape
+        elif self.config.task in ["i2v"]:
+            self.config.target_shape = target_shape[:1] + (target_shape[1] + target_shape[1] % self.config.patch_size_t,) + target_shape[2:]
+            self.config.padding_shape = (
                 self.config.batch_size,
-                (num_frames - 1) // self.config.vae_scale_factor_temporal + 1,
+                (num_frames - 1) // self.config.vae_scale_factor_temporal,
                 self.config.latent_channels,
                 self.config.height // self.config.vae_scale_factor_spatial,
                 self.config.width // self.config.vae_scale_factor_spatial,
             )
-        ret["target_shape"] = self.config.target_shape
-        return ret
+        return None
 
-    def save_video_func(self, images):
+    def save_video(self, images):
         with imageio.get_writer(self.config.save_video_path, fps=16) as writer:
             for pil_image in images:
                 frame_np = np.array(pil_image, dtype=np.uint8)
diff --git a/lightx2v/models/schedulers/cogvideox/scheduler.py b/lightx2v/models/schedulers/cogvideox/scheduler.py
index 9c6e19d..cbb9eca 100644
--- a/lightx2v/models/schedulers/cogvideox/scheduler.py
+++ b/lightx2v/models/schedulers/cogvideox/scheduler.py
@@ -56,6 +56,18 @@ def rescale_zero_terminal_snr(alphas_cumprod):
     return alphas_bar
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator=None, sample_mode="sample"):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 class CogvideoxXDPMScheduler(BaseScheduler):
     def __init__(self, config):
         self.config = config
@@ -133,6 +145,8 @@ def set_timesteps(self):
 
     def prepare(self, image_encoder_output):
         self.image_encoder_output = image_encoder_output
+        if self.config.task in ["i2v"]:
+            self.prepare_image_latents(image=image_encoder_output, padding_shape=self.config.padding_shape, dtype=torch.bfloat16)
         self.prepare_latents(shape=self.config.target_shape, dtype=torch.bfloat16)
         self.prepare_guidance()
         self.prepare_rotary_pos_embedding()
@@ -144,6 +158,25 @@ def prepare_latents(self, shape, dtype):
         self.latents = latents
         self.old_pred_original_sample = None
 
+    def prepare_image_latents(self, image, padding_shape, dtype):
+        image_latents = [retrieve_latents(img, self.generator) for img in image]
+        image_latents = torch.cat(image_latents, dim=0).to(torch.bfloat16).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
+
+        if not self.config.vae_invert_scale_latents:
+            image_latents = self.config.vae_scaling_factor_image * image_latents
+        else:
+            # This is awkward but required because the CogVideoX team forgot to multiply the
+            # scaling factor during training :)
+            image_latents = 1 / self.config.vae_scaling_factor_image * image_latents
+
+        latent_padding = torch.zeros(padding_shape, device=torch.device("cuda"), dtype=torch.bfloat16)
+        image_latents = torch.cat([image_latents, latent_padding], dim=1)
+        # Select the first frame along the second dimension
+        if self.config.patch_size_t is not None:
+            first_frame = image_latents[:, : image_latents.size(1) % self.config.patch_size_t, ...]
+            image_latents = torch.cat([first_frame, image_latents], dim=1)
+        self.image_latents = image_latents
+
     def prepare_guidance(self):
         self.guidance_scale = self.config.guidance_scale
diff --git a/lightx2v/models/video_encoders/hf/cogvideox/autoencoder_ks_cogvidex.py b/lightx2v/models/video_encoders/hf/cogvideox/autoencoder_ks_cogvidex.py
index b0ac115..a8fe918 100644
--- a/lightx2v/models/video_encoders/hf/cogvideox/autoencoder_ks_cogvidex.py
+++ b/lightx2v/models/video_encoders/hf/cogvideox/autoencoder_ks_cogvidex.py
@@ -1126,7 +1126,6 @@ def disable_slicing(self) -> None:
 
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = x.shape
-
         if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
             return self.tiled_encode(x)
 
@@ -1169,7 +1168,6 @@ def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[Autoencoder
             h = torch.cat(encoded_slices)
         else:
             h = self._encode(x)
-
         posterior = DiagonalGaussianDistribution(h)
 
         if not return_dict:
diff --git a/lightx2v/models/video_encoders/hf/cogvideox/model.py b/lightx2v/models/video_encoders/hf/cogvideox/model.py
index 952f3d2..af85ac7 100644
--- a/lightx2v/models/video_encoders/hf/cogvideox/model.py
+++ b/lightx2v/models/video_encoders/hf/cogvideox/model.py
@@ -1,8 +1,8 @@
 import os
 import glob
-import torch  # type: ignore
-from safetensors import safe_open  # type: ignore
-from diffusers.video_processor import VideoProcessor  # type: ignore
+import torch
+from safetensors import safe_open
+from diffusers.video_processor import VideoProcessor
 
 from lightx2v.models.video_encoders.hf.cogvideox.autoencoder_ks_cogvidex import AutoencoderKLCogVideoX
 
@@ -34,13 +34,19 @@ def load(self):
         self.vae_config = AutoencoderKLCogVideoX.load_config(vae_path)
         self.model = AutoencoderKLCogVideoX.from_config(self.vae_config)
         vae_ckpt = self._load_ckpt(vae_path)
-        self.vae_scale_factor_spatial = 2 ** (len(self.vae_config["block_out_channels"]) - 1)  # 8
-        self.vae_scale_factor_temporal = self.vae_config["temporal_compression_ratio"]  # 4
-        self.vae_scaling_factor_image = self.vae_config["scaling_factor"]  # 0.7
+        self.vae_scale_factor_spatial = 2 ** (len(self.vae_config["block_out_channels"]) - 1)
+        self.vae_scale_factor_temporal = self.vae_config["temporal_compression_ratio"]
+        self.vae_scaling_factor_image = self.vae_config["scaling_factor"]
         self.model.load_state_dict(vae_ckpt)
         self.model.to(torch.bfloat16).to(torch.device("cuda"))
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
 
+    @torch.no_grad()
+    def encode(self, x, return_dict=True):
+        self.model.enable_tiling()
+        x = self.model.encode(x, return_dict)
+        return x
+
     @torch.no_grad()
     def decode(self, latents, generator, config):
         latents = latents.permute(0, 2, 1, 3, 4)
diff --git a/scripts/run_cogvideox_i2v.sh b/scripts/run_cogvideox_i2v.sh
new file mode 100644
index 0000000..5b7af90
--- /dev/null
+++ b/scripts/run_cogvideox_i2v.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+# set path first
+lightx2v_path=
+model_path=
+
+# check section
+if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
+    cuda_devices=3
+    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}; change it in this script or set the env variable."
+    export CUDA_VISIBLE_DEVICES=${cuda_devices}
+fi
+
+if [ -z "${lightx2v_path}" ]; then
+    echo "Error: lightx2v_path is not set. Please set this variable first."
+    exit 1
+fi
+
+if [ -z "${model_path}" ]; then
+    echo "Error: model_path is not set. Please set this variable first."
+    exit 1
+fi
+
+export TOKENIZERS_PARALLELISM=false
+
+export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+
+export ENABLE_PROFILING_DEBUG=true
+export ENABLE_GRAPH_MODE=false
+
+export PYTHONPATH=/mtc/wushuo/VideoGen/diffusers:$PYTHONPATH
+
+python -m lightx2v.infer \
+--model_cls cogvideox \
+--task i2v \
+--model_path $model_path \
+--image_path /mtc/wushuo/VideoGen/CogVideo/input.jpg \
+--config_json ${lightx2v_path}/configs/cogvideox_i2v.json \
+--prompt "This guy gave you a thumbs-up." \
+--save_video_path ${lightx2v_path}/save_results/output_lightx2v_cogvideox_i2v.mp4
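
For reference, a minimal standalone sketch of the shape bookkeeping the new I2V path implies, assuming the values in `configs/cogvideox_i2v.json` (768x1360, 81 frames, `patch_size_t` 2). It is not repo code: random tensors stand in for the real VAE image latents, and only the `target_shape`/`padding_shape` arithmetic from `set_target_shape`, the frame padding from `prepare_image_latents`, and the `dim=2` concatenation from `model.py` are reproduced.

```python
# Standalone shape check (illustrative only, hypothetical values from configs/cogvideox_i2v.json).
import torch

batch_size, latent_channels = 1, 16
height, width, num_frames = 768, 1360, 81
spatial, temporal, patch_size_t = 8, 4, 2

# Frame rounding as in set_target_shape: 81 frames -> 21 latent frames -> 22 (multiple of patch_size_t).
latent_frames = (num_frames - 1) // temporal + 1
if latent_frames % patch_size_t != 0:
    num_frames += (patch_size_t - latent_frames % patch_size_t) * temporal  # 81 -> 85

target_shape = (batch_size, (num_frames - 1) // temporal + 1, latent_channels, height // spatial, width // spatial)
target_shape = target_shape[:1] + (target_shape[1] + target_shape[1] % patch_size_t,) + target_shape[2:]  # i2v branch
padding_shape = (batch_size, (num_frames - 1) // temporal, latent_channels, height // spatial, width // spatial)

# Scheduler side: one encoded image frame plus zero padding matches the frame count of the noise latents.
noise_latents = torch.randn(target_shape)  # [1, 22, 16, 96, 170]
image_latents = torch.randn(batch_size, 1, latent_channels, height // spatial, width // spatial)
image_latents = torch.cat([image_latents, torch.zeros(padding_shape)], dim=1)  # [1, 22, 16, 96, 170]

# Model side: at every denoising step the image latents are concatenated along the channel dim.
model_input = torch.cat([noise_latents, image_latents], dim=2)
print(model_input.shape)  # torch.Size([1, 22, 32, 96, 170])
```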