support cogvideox i2v & adapt to the changes in commit 607f968 #59
base: main
Changes from all commits
```diff
@@ -0,0 +1,42 @@
+{
+    "seed": 42,
+    "text_len": 226,
+    "num_videos_per_prompt": 1,
+    "target_video_length": 81,
+    "num_inference_steps": 50,
+    "num_train_timesteps": 1000,
+    "timestep_spacing": "trailing",
+    "steps_offset": 0,
+    "latent_channels": 16,
+    "height": 768,
+    "width": 1360,
+    "vae_scale_factor_temporal": 4,
+    "vae_scale_factor_spatial": 8,
+    "vae_scaling_factor_image": 0.7,
+    "vae_invert_scale_latents": true,
+    "batch_size": 1,
+    "patch_size": 2,
+    "patch_size_t": 2,
+    "guidance_scale": 0,
+    "use_rotary_positional_embeddings": true,
+    "do_classifier_free_guidance": false,
+    "transformer_sample_width": 170,
+    "transformer_sample_height": 96,
+    "transformer_sample_frames": 81,
+    "transformer_attention_head_dim": 64,
+    "transformer_num_attention_heads": 48,
+    "transformer_temporal_compression_ratio": 4,
+    "transformer_temporal_interpolation_scale": 1.0,
+    "transformer_use_learned_positional_embeddings": false,
+    "transformer_spatial_interpolation_scale": 1.875,
+    "transformer_num_layers": 42,
+    "transformer_ofs_embed_dim": 512,
+    "beta_schedule": "scaled_linear",
+    "scheduler_beta_start": 0.00085,
+    "scheduler_beta_end": 0.012,
+    "scheduler_set_alpha_to_one": true,
+    "scheduler_snr_shift_scale": 1.0,
+    "scheduler_rescale_betas_zero_snr": true,
+    "scheduler_prediction_type": "v_prediction",
+    "use_dynamic_cfg": true
+}
```
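As a sanity check on these values, the latent shape they imply can be worked out by hand (a quick sketch mirroring the `set_target_shape` logic in the runner diff below; plain Python, not part of the PR):

```python
# Sketch: latent target shape implied by this config (t2v case).
target_video_length = 81
vae_scale_factor_temporal = 4
vae_scale_factor_spatial = 8
patch_size_t = 2

latent_frames = (target_video_length - 1) // vae_scale_factor_temporal + 1  # 21
if latent_frames % patch_size_t != 0:  # 21 is odd, so pad by one latent frame
    additional_frames = patch_size_t - latent_frames % patch_size_t  # 1
    target_video_length += additional_frames * vae_scale_factor_temporal  # 85
latent_frames = (target_video_length - 1) // vae_scale_factor_temporal + 1  # 22

target_shape = (1, latent_frames, 16, 768 // vae_scale_factor_spatial, 1360 // vae_scale_factor_spatial)
print(target_shape)  # (1, 22, 16, 96, 170) -- 96 x 170 matches transformer_sample_height/width
```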
```diff
@@ -1,6 +1,7 @@
-from diffusers.utils import export_to_video
+from diffusers.utils import export_to_video, load_image
 import imageio
 import numpy as np
+import torch

 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.default_runner import DefaultRunner
```
```diff
@@ -16,27 +17,19 @@ class CogvideoxRunner(DefaultRunner):
     def __init__(self, config):
         super().__init__(config)

-    def load_transformer(self, init_device):
-        model = CogvideoxModel(self.config)
-        return model
-
-    def load_image_encoder(self, init_device):
-        return None
-
-    def load_text_encoder(self, init_device):
+    @ProfilingContext("Load models")
+    def load_model(self):
         text_encoder = T5EncoderModel_v1_1_xxl(self.config)
-        text_encoders = [text_encoder]
-        return text_encoders
-
-    def load_vae(self, init_device):
-        vae_model = CogvideoxVAE(self.config)
-        return vae_model, vae_model
+        self.text_encoders = [text_encoder]
+        self.model = CogvideoxModel(self.config)
+        self.vae_model = CogvideoxVAE(self.config)
+        image_encoder = None

     def init_scheduler(self):
         scheduler = CogvideoxXDPMScheduler(self.config)
         self.model.set_scheduler(scheduler)

-    def run_text_encoder(self, text, img):
+    def run_text_encoder(self, text):
         text_encoder_output = {}
         n_prompt = self.config.get("negative_prompt", "")
         context = self.text_encoders[0].infer([text], self.config)
```
```diff
@@ -45,38 +38,60 @@ def run_text_encoder(self, text, img):
         text_encoder_output["context_null"] = context_null
         return text_encoder_output

-    def run_vae_encoder(self, img):
-        # TODO: implement vae encoder for Cogvideox
-        raise NotImplementedError("I2V inference is not implemented for Cogvideox.")
+    def run_image_encoder(self, config, vae_model):
+        image = load_image(image=self.config.image_path)
+        image = vae_model.video_processor.preprocess(image, height=config.height, width=config.width).to(torch.device("cuda"), dtype=torch.bfloat16)
```
> **Review comment:** The device and dtype for image preprocessing are hardcoded here to `torch.device("cuda")` and `torch.bfloat16`. Consider making them configurable instead of fixing them in the runner.
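If the hardcoding is a concern, one possible fix (a hypothetical sketch, not part of the PR; the `device` and `dtype` config keys are assumptions and do not exist in lightx2v) is to read both from the config with the current values as defaults:

```python
# Hypothetical: derive device/dtype from config instead of hardcoding cuda/bfloat16.
device = torch.device(self.config.get("device", "cuda"))
dtype = getattr(torch, self.config.get("dtype", "bfloat16"))
image = vae_model.video_processor.preprocess(image, height=config.height, width=config.width).to(device, dtype=dtype)
```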
```diff
+        image = image.unsqueeze(2)
+        image = [vae_model.encode(img.unsqueeze(0)) for img in image]
+        return image

-    def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img):
-        # TODO: Implement image encoder for Cogvideox-I2V
-        raise ValueError(f"Unsupported model class: {self.config['model_cls']}")
+    @ProfilingContext("Run Encoders")
+    async def run_input_encoder_local_t2v(self):
+        prompt = self.config["prompt"]
+        text_encoder_output = self.run_text_encoder(prompt)
+        return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}
+
+    @ProfilingContext("Run Encoders")
+    async def run_input_encoder_local_i2v(self):
+        image_encoder_output = self.run_image_encoder(self.config, self.vae_model)
+        prompt = self.config["prompt"]
+        text_encoder_output = self.run_text_encoder(prompt)
+        return {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
+
+    @ProfilingContext("Run VAE Decoder")
+    async def run_vae_decoder_local(self, latents, generator):
+        images = self.vae_model.decode(latents, generator=generator, config=self.config)
+        return images

     def set_target_shape(self):
-        ret = {}
-        if self.config.task == "i2v":
-            # TODO: implement set_target_shape for Cogvideox-I2V
-            raise NotImplementedError("I2V inference is not implemented for Cogvideox.")
-        else:
-            num_frames = self.config.target_video_length
-            latent_frames = (num_frames - 1) // self.config.vae_scale_factor_temporal + 1
-            additional_frames = 0
-            patch_size_t = self.config.patch_size_t
-            if patch_size_t is not None and latent_frames % patch_size_t != 0:
-                additional_frames = patch_size_t - latent_frames % patch_size_t
-                num_frames += additional_frames * self.config.vae_scale_factor_temporal
-            self.config.target_shape = (
+        num_frames = self.config.target_video_length
+        latent_frames = (num_frames - 1) // self.config.vae_scale_factor_temporal + 1
+        additional_frames = 0
+        patch_size_t = self.config.patch_size_t
+        if patch_size_t is not None and latent_frames % patch_size_t != 0:
+            additional_frames = patch_size_t - latent_frames % patch_size_t
+            num_frames += additional_frames * self.config.vae_scale_factor_temporal
+        target_shape = (
             self.config.batch_size,
             (num_frames - 1) // self.config.vae_scale_factor_temporal + 1,
             self.config.latent_channels,
             self.config.height // self.config.vae_scale_factor_spatial,
             self.config.width // self.config.vae_scale_factor_spatial,
         )
+        if self.config.task in ["t2v"]:
+            self.config.target_shape = target_shape
+        elif self.config.task in ["i2v"]:
+            self.config.target_shape = target_shape[:1] + (target_shape[1] + target_shape[1] % self.config.patch_size_t,) + target_shape[2:]
```
> **Review comment:** The frame dimension of `target_shape` is padded for I2V with `target_shape[1] % self.config.patch_size_t`. Given that `num_frames` is already padded above so that the latent frame count is a multiple of `patch_size_t`, this term should always be zero, making the `elif` branch a no-op. Could you clarify the intent here? Is this an additional padding requirement specific to CogVideoX I2V, or is there a scenario where `target_shape[1]` is not already a multiple of `patch_size_t`?
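For the config shipped with this PR, the extra term does come out to zero (a quick check, assuming `target_video_length=81`, `vae_scale_factor_temporal=4`, `patch_size_t=2`):

```python
# After the padding above, num_frames = 85, so the latent frame dim is:
frames = (85 - 1) // 4 + 1  # 22
print(frames % 2)  # 0 -> the i2v branch adds no extra frames for this config
```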
```diff
+
+        self.config.padding_shape = (
+            self.config.batch_size,
+            (num_frames - 1) // self.config.vae_scale_factor_temporal,
+            self.config.latent_channels,
+            self.config.height // self.config.vae_scale_factor_spatial,
+            self.config.width // self.config.vae_scale_factor_spatial,
+        )
-        ret["target_shape"] = self.config.target_shape
-        return ret
+        return None

-    def save_video_func(self, images):
+    def save_video(self, images):
         with imageio.get_writer(self.config.save_video_path, fps=16) as writer:
             for pil_image in images:
                 frame_np = np.array(pil_image, dtype=np.uint8)
```
```diff
@@ -56,6 +56,18 @@ def rescale_zero_terminal_snr(alphas_cumprod):
     return alphas_bar


+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator=None, sample_mode="sample"):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 class CogvideoxXDPMScheduler(BaseScheduler):
     def __init__(self, config):
         self.config = config
```
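For context, `retrieve_latents` normalizes the two shapes of VAE encoder output (a posterior distribution vs. already-materialized latents). A toy illustration of its branches, built on diffusers' `DiagonalGaussianDistribution` directly (hypothetical tensors; import paths assume a recent diffusers version):

```python
import torch
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.models.modeling_outputs import AutoencoderKLOutput

# Toy posterior: the channel dim stacks mean and logvar (2 * 16 latent channels).
params = torch.randn(1, 32, 1, 96, 170)
out = AutoencoderKLOutput(latent_dist=DiagonalGaussianDistribution(params))

gen = torch.Generator().manual_seed(42)
lat = retrieve_latents(out, generator=gen)          # stochastic sample from the posterior
mode = retrieve_latents(out, sample_mode="argmax")  # deterministic posterior mode (the mean)
```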
```diff
@@ -133,6 +145,8 @@ def set_timesteps(self):

     def prepare(self, image_encoder_output):
         self.image_encoder_output = image_encoder_output
+        if self.config.task in ["i2v"]:
+            self.prepare_image_latents(image=image_encoder_output, padding_shape=self.config.padding_shape, dtype=torch.bfloat16)
         self.prepare_latents(shape=self.config.target_shape, dtype=torch.bfloat16)
         self.prepare_guidance()
         self.prepare_rotary_pos_embedding()
```
```diff
@@ -144,6 +158,25 @@ def prepare_latents(self, shape, dtype):
         self.latents = latents
         self.old_pred_original_sample = None

+    def prepare_image_latents(self, image, padding_shape, dtype):
+        image_latents = [retrieve_latents(img, self.generator) for img in image]
+        image_latents = torch.cat(image_latents, dim=0).to(torch.bfloat16).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
+
+        if not self.config.vae_invert_scale_latents:
+            image_latents = self.config.vae_scaling_factor_image * image_latents
+        else:
+            # This is awkward but required because the CogVideoX team forgot to multiply the
+            # scaling factor during training :)
+            image_latents = 1 / self.config.vae_scaling_factor_image * image_latents
+
+        latent_padding = torch.zeros(padding_shape, device=torch.device("cuda"), dtype=torch.bfloat16)
+        image_latents = torch.cat([image_latents, latent_padding], dim=1)
+        # Select the first frame along the second dimension
+        if self.config.patch_size_t is not None:
+            first_frame = image_latents[:, : image_latents.size(1) % self.config.patch_size_t, ...]
+            image_latents = torch.cat([first_frame, image_latents], dim=1)
```
> **Review comment (on lines +175 to +177):** This padding logic prepends the first `image_latents.size(1) % patch_size_t` frames of `image_latents` to itself. If the frame count is already a multiple of `patch_size_t`, nothing is prepended and this is a no-op. Is this specific way of padding (making the frame count divisible by `patch_size_t` by repeating leading frames) the intended behavior here?
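A toy illustration of what those three lines do (assumes `patch_size_t = 2`; the tensor sizes are made up):

```python
import torch

patch_size_t = 2
image_latents = torch.zeros(1, 13, 16, 96, 170)  # 13 latent frames (odd)

first_frame = image_latents[:, : image_latents.size(1) % patch_size_t, ...]  # 13 % 2 = 1 frame
image_latents = torch.cat([first_frame, image_latents], dim=1)
print(image_latents.size(1))  # 14 -> now divisible by patch_size_t
```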
```diff
+
+        self.image_latents = image_latents

     def prepare_guidance(self):
         self.guidance_scale = self.config.guidance_scale
```
> **Review comment:** The `ofs_emb` is initialized with `fill_value=2.0` when `self.config.transformer_ofs_embed_dim` is not None. Could you clarify the significance of this specific value `2.0`? If it's a critical parameter, consider defining it as a named constant or making it configurable for better readability and maintainability.
? If it's a critical parameter, consider defining it as a named constant or making it configurable for better readability and maintainability.