diff --git a/README.md b/README.md index 62a22c3..c48cad6 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,13 @@ +

This repo contains code to run LongLive using 2x 3090, both single prompt and interactive prompts. It needs at least 2x 24 GB GPUs

+

Download bfloat weights from https://huggingface.co/srivassid/LongLiveMultiGPU/tree/main, create a new folder longlive_models/models under the same parent as LongLive and put them there, + and clone the model weights [LongLive-1.3B](https://huggingface.co/Efficient-Large-Model/LongLive-1.3B) and put that folder inside the LongLive folder.

+

Create a folder under longlive_models/prompts and create a file called interactive_example.jsonl (this must match the data_path in configs/longlive_interactive_inference.yaml), pick one line from LongLive/example/interactive_example.jsonl and save it.

+

Do the same for single prompt: create a file longlive_models/prompts/vidprom_filtered_extended.txt, go to LongLive/example/long_example.txt and paste one of the lines in the file

+

if nvidia-pyindex package throws an error, comment it out from requirements.txt

+

For single prompts run single_quantized.sh, for interactive prompts, run run_quantized.sh

+

I put cuda:2 in interactive_inference_quantized.py and single_prompt_inference.py. Change it to cuda:1 for the second gpu

+

Any issues, leave a comment

+

logo

diff --git a/configs/longlive_inference.yaml b/configs/longlive_inference.yaml index 2c0daf5..ed1111f 100644 --- a/configs/longlive_inference.yaml +++ b/configs/longlive_inference.yaml @@ -4,7 +4,7 @@ denoising_step_list: - 500 - 250 warp_denoising_step: true # need to remove - 0 in denoising_step_list if warp_denoising_step is true -num_frame_per_block: 3 +num_frame_per_block: 6 model_name: Wan2.1-T2V-1.3B model_kwargs: local_attn_size: 12 diff --git a/configs/longlive_interactive_inference.yaml b/configs/longlive_interactive_inference.yaml index 113ae78..64aeaf1 100644 --- a/configs/longlive_interactive_inference.yaml +++ b/configs/longlive_interactive_inference.yaml @@ -5,7 +5,7 @@ denoising_step_list: - 500 - 250 warp_denoising_step: true # need to remove - 0 in denoising_step_list if warp_denoising_step is true -num_frame_per_block: 3 +num_frame_per_block: 6 model_name: Wan2.1-T2V-1.3B model_kwargs: local_attn_size: 12 @@ -17,12 +17,12 @@ model_kwargs: data_path: longlive_models/prompts/interactive_example.jsonl output_folder: videos/interactive inference_iter: -1 -num_output_frames: 240 +num_output_frames: 144 use_ema: false seed: 1 num_samples: 1 save_with_index: true -switch_frame_indices: 40, 80, 120, 160, 200 +switch_frame_indices: 48, 96 global_sink: true context_noise: 0 diff --git a/interactive_inference_quantized.py b/interactive_inference_quantized.py new file mode 100644 index 0000000..e7aa291 --- /dev/null +++ b/interactive_inference_quantized.py @@ -0,0 +1,341 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# To view a copy of this license, visit http://www.apache.org/licenses/LICENSE-2.0 +# +# No warranties are given. The work is provided "AS IS", without warranty of any kind, express or implied. 
+# +# SPDX-License-Identifier: Apache-2.0 +import argparse +import os +from typing import List + +import torch +import torch.distributed as dist +from omegaconf import OmegaConf +from tqdm import tqdm +from torch.utils.data import DataLoader, SequentialSampler +from torch.utils.data.distributed import DistributedSampler +from torchvision.io import write_video +from torchvision import transforms # noqa: F401 +from einops import rearrange + +from utils.misc import set_seed +from utils.distributed import barrier +from utils.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller + +from pipeline.interactive_causal_inference import ( + InteractiveCausalInferencePipeline, +) +from utils.dataset import MultiTextDataset + +# ----------------------------- Argument parsing ----------------------------- +parser = argparse.ArgumentParser("Interactive causal inference") +parser.add_argument("--config_path", type=str, help="Path to the config file") +parser.add_argument("--use_quantized", action="store_true", + help="Use quantized models from ../longlive_models/") +args = parser.parse_args() + +config = OmegaConf.load(args.config_path) +# config.model_kwargs.local_attn_size = 32 +# ======================== LOAD QUANTIZED MODELS TO CPU FIRST ======================== +quantized_base_state = None +quantized_lora_state = None + +if args.use_quantized: + print("\n" + "="*70) + print("šŸ”§ Pre-loading QUANTIZED models to CPU") + print("="*70) + + import sys + sys.path.insert(0, '..') + + base_path = '../longlive_models/longlive_base_bfloat16.pt' + lora_path = '../longlive_models/lora_bfloat16.pt' + + print(f"šŸ“„ Loading base checkpoint to CPU from: {base_path}") + base_checkpoint = torch.load(base_path, map_location='cpu', weights_only=False) + + print(f"šŸ“„ Loading LoRA checkpoint to CPU from: {lora_path}") + lora_checkpoint = torch.load(lora_path, map_location='cpu', weights_only=False) + + # Extract the models + quantized_base_state = base_checkpoint['generator'] + 
quantized_lora_state = lora_checkpoint['generator_lora'] + quantized_critic_lora = lora_checkpoint.get('critic_lora', {}) + + print(f"āœ… Loaded to CPU - will transfer after pipeline init") + print(f" Base params: {sum(p.numel() for p in quantized_base_state.values() if isinstance(p, torch.Tensor)):,}") + print(f" LoRA params: {sum(p.numel() for p in quantized_lora_state.values() if isinstance(p, torch.Tensor)):,}") + print("="*70 + "\n") + + # Don't load from checkpoint files + config.generator_ckpt = None + config.lora_ckpt = None # We'll load this manually later + +# ======================== END PRE-LOADING ======================== + +# ----------------------------- Distributed setup ----------------------------- +if "LOCAL_RANK" in os.environ: # Multi-GPU via torchrun + os.environ["NCCL_CROSS_NIC"] = "1" + os.environ["NCCL_DEBUG"] = os.environ.get("NCCL_DEBUG", "INFO") + os.environ["NCCL_TIMEOUT"] = os.environ.get("NCCL_TIMEOUT", "1800") + + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + rank = int(os.environ.get("RANK", str(local_rank))) + + torch.cuda.set_device(local_rank) + device = torch.device(f"cuda:{local_rank}") + + if not dist.is_initialized(): + dist.init_process_group( + backend="nccl", + rank=rank, + world_size=world_size, + timeout=torch.distributed.constants.default_pg_timeout, + ) + + set_seed(config.seed + local_rank) + print(f"[Rank {rank}] Distributed mode on GPU {local_rank}") + +else: # Single-GPU mode + assert torch.cuda.is_available(), "CUDA is required but not available" + + local_rank = 0 + rank = 0 + device = torch.device("cuda:0") + torch.cuda.set_device(device) + + set_seed(config.seed) + print("Single GPU mode on cuda:0") + +low_memory = get_cuda_free_memory_gb(device) < 40 +torch.set_grad_enabled(False) + +# ======================== INITIALIZE PIPELINE ======================== +pipeline = InteractiveCausalInferencePipeline(config, device=device) +print("Generator device:", 
next(pipeline.generator.parameters()).device) +print("VAE device:", next(pipeline.vae.parameters()).device) +print("Text encoder device:", next(pipeline.text_encoder.parameters()).device) +# ======================== LOAD QUANTIZED WEIGHTS ======================== +if args.use_quantized and quantized_base_state is not None: + print("\n" + "="*70) + print("šŸ”§ Loading quantized base model into pipeline") + print("="*70) + + missing, unexpected = pipeline.generator.load_state_dict(quantized_base_state, strict=False) + if local_rank == 0: + if missing: + print(f"[Warning] {len(missing)} parameters missing: {missing[:8]} ...") + if unexpected: + print(f"[Warning] {len(unexpected)} unexpected params: {unexpected[:8]} ...") + print("āœ… Quantized base model loaded") + + print("šŸ”§ Converting ALL generator parameters to bfloat16...") + for name, param in pipeline.generator.named_parameters(): + if param.dtype != torch.bfloat16: + param.data = param.data.to(torch.bfloat16) + for name, buffer in pipeline.generator.named_buffers(): + if buffer.dtype != torch.bfloat16 and buffer.dtype == torch.float32: + buffer.data = buffer.data.to(torch.bfloat16) + print("āœ… All generator parameters converted to bfloat16") + + # Clear the CPU checkpoint to free memory + del base_checkpoint + del quantized_base_state + import gc + gc.collect() + +# --------------------------- LoRA support (optional) --------------------------- +from utils.lora_utils import configure_lora_for_model +import peft + +pipeline.is_lora_enabled = False +if getattr(config, "adapter", None) and configure_lora_for_model is not None: + if local_rank == 0: + print(f"\nšŸ”§ LoRA enabled with config: {config.adapter}") + print("Applying LoRA to generator (inference)...") + + pipeline.generator.model = configure_lora_for_model( + pipeline.generator.model, + model_name="generator", + lora_config=config.adapter, + is_main_process=(local_rank == 0), + ) + + # Load quantized LoRA weights + if args.use_quantized and 
quantized_lora_state is not None: + if local_rank == 0: + print(f"Loading QUANTIZED LoRA weights from CPU") + peft.set_peft_model_state_dict(pipeline.generator.model, quantized_lora_state) + if local_rank == 0: + print("āœ… Quantized LoRA weights loaded") + + # Clear LoRA checkpoint + del lora_checkpoint + del quantized_lora_state + gc.collect() + elif not args.use_quantized: + # Original LoRA loading for non-quantized + lora_ckpt_path = getattr(config, "lora_ckpt", None) + if lora_ckpt_path: + if local_rank == 0: + print(f"Loading LoRA checkpoint from {lora_ckpt_path}") + lora_checkpoint = torch.load(lora_ckpt_path, map_location="cpu") + if isinstance(lora_checkpoint, dict) and "generator_lora" in lora_checkpoint: + peft.set_peft_model_state_dict(pipeline.generator.model, lora_checkpoint["generator_lora"]) + else: + peft.set_peft_model_state_dict(pipeline.generator.model, lora_checkpoint) + if local_rank == 0: + print("LoRA weights loaded") + + pipeline.is_lora_enabled = True + +# Move pipeline to appropriate dtype and device +# print("\nšŸ”§ Moving pipeline to bfloat16...") +# pipeline = pipeline.to(dtype=torch.bfloat16) + +if low_memory: + DynamicSwapInstaller.install_model(pipeline.text_encoder, device=device) + +# ======================== USE GPU 1 or 2 FOR INFERENCE ======================== + +# ======================== USE GPU 2 WITH DYNAMIC SWAP ======================== +# inference_device = torch.device('cuda:2') +# +# print(f"\nšŸŽ¬ Loading all components to GPU 2") +# print(f" GPU 2 Free VRAM before: {get_cuda_free_memory_gb(inference_device):.2f} GB") +# +# torch.cuda.empty_cache() +# +# print(f" GPU 2 Free VRAM after clearing GPU 0: {get_cuda_free_memory_gb(inference_device):.2f} GB") +# +# # Now move everything to GPU 2 +# print(" Moving generator to GPU 2...") +# pipeline.generator.to(device=inference_device) +# +# print(" Converting VAE to bfloat16...") +# pipeline.vae = pipeline.vae.to(dtype=torch.bfloat16) +# +# # print(" Moving text_encoder from 
CPU to GPU 2...") +# # pipeline.text_encoder = pipeline.text_encoder.to(device=inference_device) +# +# print(f"āœ… All components on GPU 2") +# print(f" GPU 2 VRAM used: {(torch.cuda.memory_allocated(2) / 1024**3):.2f} GB") + +print("\n🚚 Moving models to GPU 2 for inference...") + +inference_device = torch.device("cuda:2") + +pipeline.generator.to(inference_device) +pipeline.vae.to(inference_device) + +# Optional but recommended for speed +pipeline.generator = pipeline.generator.to(dtype=torch.bfloat16) +pipeline.vae = pipeline.vae.to(dtype=torch.bfloat16) + +torch.cuda.empty_cache() + +print("Generator device:", next(pipeline.generator.parameters()).device) +print("VAE device:", next(pipeline.vae.parameters()).device) +print("Text encoder device:", next(pipeline.text_encoder.parameters()).device) + +device = inference_device + +# ======================== END GPU SETUP ======================== + +# ----------------------------- Build dataset ----------------------------- +if isinstance(config.switch_frame_indices, int): + switch_frame_indices: List[int] = [int(config.switch_frame_indices)] +else: + switch_frame_indices: List[int] = [ + int(x) for x in str(config.switch_frame_indices).split(",") if str(x).strip() + ] + +dataset = MultiTextDataset(config.data_path) + +num_segments = len(dataset[0]["prompts_list"]) +assert len(switch_frame_indices) == num_segments - 1, ( + "The number of switch_frame_indices should be the number of prompt segments minus 1" +) + +print("Number of segments:", num_segments) +print("Switch frame indices:", switch_frame_indices) + +num_prompts_total = len(dataset) +print(f"Number of prompt lines: {num_prompts_total}") + +if dist.is_initialized(): + sampler = DistributedSampler(dataset, shuffle=False, drop_last=True) +else: + sampler = SequentialSampler(dataset) + +dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False) + +if local_rank == 0: + os.makedirs(config.output_folder, exist_ok=True) + +if 
dist.is_initialized(): + dist.barrier() + +# ----------------------------- Inference loop ----------------------------- +print("\n" + "="*70) +print("šŸš€ Starting video generation...") +print("="*70) + +for i, batch_data in tqdm(enumerate(dataloader), disable=(local_rank != 0)): + idx = batch_data["idx"].item() + prompts_list: List[str] = batch_data["prompts_list"] + + sampled_noise = torch.randn( + [ + config.num_samples, + config.num_output_frames, + 16, + 60, + 104, + ], + device=device, + dtype=torch.bfloat16, + ) + + with torch.autocast("cuda", dtype=torch.bfloat16): + video = pipeline.inference( + noise=sampled_noise, + text_prompts_list=prompts_list, + switch_frame_indices=switch_frame_indices, + return_latents=False, + ) + + current_video = rearrange(video, "b t c h w -> b t h w c").cpu() * 255.0 + + if dist.is_initialized(): + rank = dist.get_rank() + else: + rank = 0 + + model_type = "quantized" if args.use_quantized else "regular" + + for seed_idx in range(config.num_samples): + if config.save_with_index: + output_path = os.path.join(config.output_folder, f"rank{rank}-{idx}-{seed_idx}_{model_type}.mp4") + else: + short_name = prompts_list[0][0][:100].replace("/", "_") + output_path = os.path.join(config.output_folder, f"rank{rank}-{short_name}-{seed_idx}_{model_type}.mp4") + + write_video(output_path, current_video[seed_idx].to(torch.uint8), fps=16) + + if local_rank == 0: + print(f"āœ… Saved: {output_path}") + + if config.inference_iter != -1 and i >= config.inference_iter: + break + +print("\n" + "="*70) +print("šŸŽ‰ Video generation complete!") +print("="*70) + +if dist.is_initialized(): + dist.destroy_process_group() diff --git a/pipeline/interactive_causal_inference.py b/pipeline/interactive_causal_inference.py index e44dd92..7f2fde5 100644 --- a/pipeline/interactive_causal_inference.py +++ b/pipeline/interactive_causal_inference.py @@ -32,6 +32,7 @@ def __init__( # Internal helpers def _recache_after_switch(self, output, current_start_frame, 
new_conditional_dict): + print("šŸ” RECACHING NOW...") if not self.global_sink: # reset kv cache for block_idx in range(self.num_transformer_blocks): @@ -120,12 +121,28 @@ def inference( "length of switch_frame_indices should be one less than text_prompts_list" ) assert num_output_frames % self.num_frame_per_block == 0 + gen_device = next(self.generator.parameters()).device + vae_device = next(self.vae.parameters()).device + text_device = next(self.text_encoder.parameters()).device + noise = noise.to(gen_device) num_blocks = num_output_frames // self.num_frame_per_block # encode all prompts print(text_prompts_list) - cond_list = [self.text_encoder(text_prompts=p) for p in text_prompts_list] + gen_device = next(self.generator.parameters()).device + gen_dtype = next(self.generator.parameters()).dtype + + cond_list = [] + for p in text_prompts_list: + cond = self.text_encoder(text_prompts=p) + + # Move every tensor inside the conditioning dict to generator device AND dtype + for k, v in cond.items(): + if torch.is_tensor(v): + cond[k] = v.to(device=gen_device, dtype=gen_dtype) + + cond_list.append(cond) if low_memory: gpu_memory_preservation = get_cuda_free_memory_gb(gpu) + 5 @@ -205,12 +222,13 @@ def inference( noisy_input = noise[ :, current_start_frame : current_start_frame + current_num_frames ] - + # print("KV cache device:", self.kv_cache1[0]["k"].device) + # print("Cross-attn cache device:", self.crossattn_cache[0]["k"].device) # ---------------- Spatial denoising loop ---------------- for index, current_timestep in enumerate(self.denoising_step_list): timestep = ( torch.ones([batch_size, current_num_frames], - device=noise.device, + device=gen_device, dtype=torch.int64) * current_timestep ) @@ -230,7 +248,7 @@ def inference( torch.randn_like(denoised_pred.flatten(0, 1)), next_timestep * torch.ones( - [batch_size * current_num_frames], device=noise.device, dtype=torch.long + [batch_size * current_num_frames], device=gen_device, dtype=torch.long ), 
).unflatten(0, denoised_pred.shape[:2]) else: @@ -261,7 +279,8 @@ def inference( current_start_frame += current_num_frames # Standard decoding - video = self.vae.decode_to_pixel(output.to(noise.device), use_cache=False) + # video = self.vae.decode_to_pixel(output.to(vae_device, dtype=next(self.vae.parameters()).dtype), use_cache=False) + video = self.vae.decode_to_pixel(output.to(vae_device, dtype=torch.float32), use_cache=False) video = (video * 0.5 + 0.5).clamp(0, 1) if return_latents: diff --git a/pipeline/single_prompt_causal_inference.py b/pipeline/single_prompt_causal_inference.py new file mode 100644 index 0000000..7622b78 --- /dev/null +++ b/pipeline/single_prompt_causal_inference.py @@ -0,0 +1,193 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# To view a copy of this license, visit http://www.apache.org/licenses/LICENSE-2.0 +# +# No warranties are given. The work is provided "AS IS", without warranty of any kind, express or implied. +# +# SPDX-License-Identifier: Apache-2.0 +from typing import List, Optional +import torch + +from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper +from utils.memory import gpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation +from pipeline.causal_inference import CausalInferencePipeline +import torch.distributed as dist +from utils.debug_option import DEBUG + + +class SinglePromptCausalInferencePipeline(CausalInferencePipeline): + """ + Simplified causal inference pipeline for single prompt (no prompt switching). + This is essentially the base CausalInferencePipeline without the multi-prompt logic. 
+ """ + def __init__( + self, + args, + device, + *, + generator: WanDiffusionWrapper | None = None, + text_encoder: WanTextEncoder | None = None, + vae: WanVAEWrapper | None = None, + ): + super().__init__(args, device, generator=generator, text_encoder=text_encoder, vae=vae) + self.global_sink = getattr(args, "global_sink", False) + + def inference( + self, + noise: torch.Tensor, + *, + text_prompts: List[str], + return_latents: bool = False, + low_memory: bool = False, + ): + """Generate a video with a single prompt throughout. + + Args: + noise: Noise tensor, shape = (B, T_out, C, H, W). + text_prompts: List[str], prompts for each sample in the batch. + return_latents: Whether to also return the latent tensor. + low_memory: Enable low-memory mode. + """ + batch_size, num_output_frames, num_channels, height, width = noise.shape + assert num_output_frames % self.num_frame_per_block == 0 + + # Get device information + gen_device = next(self.generator.parameters()).device + vae_device = next(self.vae.parameters()).device + text_device = next(self.text_encoder.parameters()).device + + # Ensure noise is on the generator device + noise = noise.to(gen_device, dtype=torch.bfloat16) + + num_blocks = num_output_frames // self.num_frame_per_block + + # Encode the prompt once + print(f"Prompts: {text_prompts}") + gen_dtype = next(self.generator.parameters()).dtype + + conditional_dict = self.text_encoder(text_prompts=text_prompts) + + # Move conditioning tensors to generator device and dtype + for k, v in conditional_dict.items(): + if torch.is_tensor(v): + conditional_dict[k] = v.to(device=gen_device, dtype=gen_dtype) + + if low_memory: + gpu_memory_preservation = get_cuda_free_memory_gb(gpu) + 5 + move_model_to_device_with_memory_preservation( + self.text_encoder, + target_device=gpu, + preserved_memory_gb=gpu_memory_preservation, + ) + + output_device = torch.device('cpu') if low_memory else gen_device + output = torch.zeros( + [batch_size, num_output_frames, 
num_channels, height, width], + device=output_device, + dtype=torch.bfloat16, # Match noise dtype + ) + + # Initialize caches + local_attn_cfg = getattr(self.args.model_kwargs, "local_attn_size", -1) + kv_policy = "" + if local_attn_cfg != -1: + # local attention + kv_cache_size = local_attn_cfg * self.frame_seq_length + kv_policy = f"local, size={local_attn_cfg}" + else: + # global attention + kv_cache_size = num_output_frames * self.frame_seq_length + kv_policy = "global (-1)" + print(f"KV cache size: {kv_cache_size} (policy: {kv_policy}, frame_seq_length: {self.frame_seq_length}, num_output_frames: {num_output_frames})") + + self._initialize_kv_cache( + batch_size, + dtype=torch.bfloat16, + device=gen_device, # Use generator device + kv_cache_size_override=kv_cache_size + ) + self._initialize_crossattn_cache( + batch_size=batch_size, + dtype=torch.bfloat16, + device=gen_device, # Use generator device + ) + + current_start_frame = 0 + self.generator.model.local_attn_size = self.local_attn_size + print(f"[inference] local_attn_size set on model: {self.generator.model.local_attn_size}") + self._set_all_modules_max_attention_size(self.local_attn_size) + + # Temporal denoising by blocks + all_num_frames = [self.num_frame_per_block] * num_blocks + + if DEBUG: + print("[SinglePrompt] all_num_frames", all_num_frames) + + for current_num_frames in all_num_frames: + noisy_input = noise[ + :, current_start_frame : current_start_frame + current_num_frames + ] + + # ---------------- Spatial denoising loop ---------------- + for index, current_timestep in enumerate(self.denoising_step_list): + timestep = ( + torch.ones([batch_size, current_num_frames], + device=gen_device, + dtype=torch.int64) + * current_timestep + ) + + if index < len(self.denoising_step_list) - 1: + _, denoised_pred = self.generator( + noisy_image_or_video=noisy_input, + conditional_dict=conditional_dict, + timestep=timestep, + kv_cache=self.kv_cache1, + crossattn_cache=self.crossattn_cache, + 
current_start=current_start_frame * self.frame_seq_length, + ) + next_timestep = self.denoising_step_list[index + 1] + noisy_input = self.scheduler.add_noise( + denoised_pred.flatten(0, 1), + torch.randn_like(denoised_pred.flatten(0, 1)), + next_timestep + * torch.ones( + [batch_size * current_num_frames], device=gen_device, dtype=torch.long + ), + ).unflatten(0, denoised_pred.shape[:2]) + else: + _, denoised_pred = self.generator( + noisy_image_or_video=noisy_input, + conditional_dict=conditional_dict, + timestep=timestep, + kv_cache=self.kv_cache1, + crossattn_cache=self.crossattn_cache, + current_start=current_start_frame * self.frame_seq_length, + ) + + # Record output + output[:, current_start_frame : current_start_frame + current_num_frames] = denoised_pred.to(output.device) + + # Rerun with clean context to update cache + context_timestep = torch.ones_like(timestep) * self.args.context_noise + self.generator( + noisy_image_or_video=denoised_pred, + conditional_dict=conditional_dict, + timestep=context_timestep, + kv_cache=self.kv_cache1, + crossattn_cache=self.crossattn_cache, + current_start=current_start_frame * self.frame_seq_length, + ) + + # Update frame pointer + current_start_frame += current_num_frames + + # Standard decoding + video = self.vae.decode_to_pixel(output.to(vae_device, dtype=torch.float32), use_cache=False) + video = (video * 0.5 + 0.5).clamp(0, 1) + + if return_latents: + return video, output + return video diff --git a/requirements.txt b/requirements.txt index 7d586e2..764c4db 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,7 +27,7 @@ pydantic==2.10.6 scikit-image huggingface_hub[cli] dominate -nvidia-pyindex +#nvidia-pyindex nvidia-tensorrt pycuda onnx diff --git a/run_quantized.sh b/run_quantized.sh new file mode 100755 index 0000000..f8af39f --- /dev/null +++ b/run_quantized.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# Run LongLive with your quantized models +# Place this in: /media/sid/Kingston/longlive/LongLive/ + +echo 
"╔══════════════════════════════════════════════════════════════════╗" +echo "ā•‘ LongLive Interactive Inference with Quantized Models ā•‘" +echo "ā•šā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•" +echo "" + +# Check if quantized models exist +if [ ! -f "../longlive_models/longlive_base_bfloat16.pt" ]; then + echo "āŒ Error: Quantized models not found!" + echo "" + echo "Please run the quantization script first:" + echo " cd /media/sid/Kingston/longlive" + echo " python longlive_3x3090.py" + exit 1 +fi + +echo "āœ… Quantized models found" +echo "" + +# Create a simple prompt file for testing +cat > test_prompts.txt << 'EOF' +A serene garden with colorful butterflies, sunny day, photorealistic +The butterflies begin to glow with a magical light, gathering together +The garden transforms into an enchanted mystical forest at twilight +EOF + +echo "šŸ“ Created test prompts:" +cat test_prompts.txt +echo "" + +# Run with quantized models +python interactive_inference_quantized.py \ + --config_path configs/longlive_interactive_inference.yaml \ + --use_quantized + +echo "" +echo "šŸŽ‰ Done! Check the output folder for your video." diff --git a/single_prompt_inference.py b/single_prompt_inference.py new file mode 100644 index 0000000..2386f2d --- /dev/null +++ b/single_prompt_inference.py @@ -0,0 +1,331 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# To view a copy of this license, visit http://www.apache.org/licenses/LICENSE-2.0 +# +# No warranties are given. The work is provided "AS IS", without warranty of any kind, express or implied. 
+# +# SPDX-License-Identifier: Apache-2.0 +import argparse +import os +from typing import List + +import torch +import torch.distributed as dist +from omegaconf import OmegaConf +from tqdm import tqdm +from torch.utils.data import DataLoader, SequentialSampler +from torch.utils.data.distributed import DistributedSampler +from torchvision.io import write_video +from torchvision import transforms # noqa: F401 +from einops import rearrange + +from utils.misc import set_seed +from utils.distributed import barrier +from utils.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller + +from pipeline.single_prompt_causal_inference import ( + SinglePromptCausalInferencePipeline, +) +from utils.dataset import TextDataset + +# ----------------------------- Argument parsing ----------------------------- +parser = argparse.ArgumentParser("Single prompt causal inference") +parser.add_argument("--config_path", type=str, help="Path to the config file") +parser.add_argument("--use_quantized", action="store_true", + help="Use quantized models from ../longlive_models/") +args = parser.parse_args() + +config = OmegaConf.load(args.config_path) + +# ======================== LOAD QUANTIZED MODELS TO CPU FIRST ======================== +quantized_base_state = None +quantized_lora_state = None + +if args.use_quantized: + print("\n" + "="*70) + print("šŸ”§ Pre-loading QUANTIZED models to CPU") + print("="*70) + + import sys + sys.path.insert(0, '..') + + base_path = '../longlive_models/longlive_base_bfloat16.pt' + lora_path = '../longlive_models/lora_bfloat16.pt' + + print(f"šŸ“„ Loading base checkpoint to CPU from: {base_path}") + base_checkpoint = torch.load(base_path, map_location='cpu', weights_only=False) + + print(f"šŸ“„ Loading LoRA checkpoint to CPU from: {lora_path}") + lora_checkpoint = torch.load(lora_path, map_location='cpu', weights_only=False) + + # Extract the models + quantized_base_state = base_checkpoint['generator'] + quantized_lora_state = 
lora_checkpoint['generator_lora'] + quantized_critic_lora = lora_checkpoint.get('critic_lora', {}) + + print(f"āœ… Loaded to CPU - will transfer after pipeline init") + print(f" Base params: {sum(p.numel() for p in quantized_base_state.values() if isinstance(p, torch.Tensor)):,}") + print(f" LoRA params: {sum(p.numel() for p in quantized_lora_state.values() if isinstance(p, torch.Tensor)):,}") + print("="*70 + "\n") + + # Don't load from checkpoint files + config.generator_ckpt = None + config.lora_ckpt = None + +# ======================== END PRE-LOADING ======================== + +# ----------------------------- Distributed setup ----------------------------- +if "LOCAL_RANK" in os.environ: # Multi-GPU via torchrun + os.environ["NCCL_CROSS_NIC"] = "1" + os.environ["NCCL_DEBUG"] = os.environ.get("NCCL_DEBUG", "INFO") + os.environ["NCCL_TIMEOUT"] = os.environ.get("NCCL_TIMEOUT", "1800") + + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + rank = int(os.environ.get("RANK", str(local_rank))) + + torch.cuda.set_device(local_rank) + device = torch.device(f"cuda:{local_rank}") + + if not dist.is_initialized(): + dist.init_process_group( + backend="nccl", + rank=rank, + world_size=world_size, + timeout=torch.distributed.constants.default_pg_timeout, + ) + + set_seed(config.seed + local_rank) + print(f"[Rank {rank}] Distributed mode on GPU {local_rank}") + +else: # Single-GPU mode + assert torch.cuda.is_available(), "CUDA is required but not available" + + local_rank = 0 + rank = 0 + device = torch.device("cuda:0") + torch.cuda.set_device(device) + + set_seed(config.seed) + print("Single GPU mode on cuda:0") + +low_memory = get_cuda_free_memory_gb(device) < 40 +torch.set_grad_enabled(False) + +# ======================== INITIALIZE PIPELINE ======================== +pipeline = SinglePromptCausalInferencePipeline(config, device=device) +print("Generator device:", next(pipeline.generator.parameters()).device) +print("VAE 
device:", next(pipeline.vae.parameters()).device) +print("Text encoder device:", next(pipeline.text_encoder.parameters()).device) + +# ======================== LOAD QUANTIZED WEIGHTS ======================== +if args.use_quantized and quantized_base_state is not None: + print("\n" + "="*70) + print("šŸ”§ Loading quantized base model into pipeline") + print("="*70) + + missing, unexpected = pipeline.generator.load_state_dict(quantized_base_state, strict=False) + if local_rank == 0: + if missing: + print(f"[Warning] {len(missing)} parameters missing: {missing[:8]} ...") + if unexpected: + print(f"[Warning] {len(unexpected)} unexpected params: {unexpected[:8]} ...") + print("āœ… Quantized base model loaded") + + print("šŸ”§ Converting ALL generator parameters to bfloat16...") + for name, param in pipeline.generator.named_parameters(): + if param.dtype != torch.bfloat16: + param.data = param.data.to(torch.bfloat16) + for name, buffer in pipeline.generator.named_buffers(): + if buffer.dtype != torch.bfloat16 and buffer.dtype == torch.float32: + buffer.data = buffer.data.to(torch.bfloat16) + print("āœ… All generator parameters converted to bfloat16") + + # Clear the CPU checkpoint to free memory + del base_checkpoint + del quantized_base_state + import gc + gc.collect() + +# --------------------------- LoRA support (optional) --------------------------- +from utils.lora_utils import configure_lora_for_model +import peft + +pipeline.is_lora_enabled = False +if getattr(config, "adapter", None) and configure_lora_for_model is not None: + if local_rank == 0: + print(f"\nšŸ”§ LoRA enabled with config: {config.adapter}") + print("Applying LoRA to generator (inference)...") + + pipeline.generator.model = configure_lora_for_model( + pipeline.generator.model, + model_name="generator", + lora_config=config.adapter, + is_main_process=(local_rank == 0), + ) + + # Load quantized LoRA weights + if args.use_quantized and quantized_lora_state is not None: + if local_rank == 0: + 
print(f"Loading QUANTIZED LoRA weights from CPU") + peft.set_peft_model_state_dict(pipeline.generator.model, quantized_lora_state) + if local_rank == 0: + print("āœ… Quantized LoRA weights loaded") + + # Clear LoRA checkpoint + del lora_checkpoint + del quantized_lora_state + gc.collect() + elif not args.use_quantized: + # Original LoRA loading for non-quantized + lora_ckpt_path = getattr(config, "lora_ckpt", None) + if lora_ckpt_path: + if local_rank == 0: + print(f"Loading LoRA checkpoint from {lora_ckpt_path}") + lora_checkpoint = torch.load(lora_ckpt_path, map_location="cpu") + if isinstance(lora_checkpoint, dict) and "generator_lora" in lora_checkpoint: + peft.set_peft_model_state_dict(pipeline.generator.model, lora_checkpoint["generator_lora"]) + else: + peft.set_peft_model_state_dict(pipeline.generator.model, lora_checkpoint) + if local_rank == 0: + print("LoRA weights loaded") + + pipeline.is_lora_enabled = True + +# ======================== MODEL PARALLEL GPU SETUP ======================== +print("\n🚚 Setting up model parallelism (Text → GPU0, Gen+VAE → GPU2)") + +gpu_text = torch.device("cuda:0") +gpu_main = torch.device("cuda:2") + +torch.cuda.empty_cache() + +# 1ļøāƒ£ Move TEXT ENCODER to GPU 0 (prompt encoding only) +print("🧠 Moving text encoder to GPU 0...") +pipeline.text_encoder = pipeline.text_encoder.to(gpu_text) +print(f"Text encoder device: {next(pipeline.text_encoder.parameters()).device}") + +torch.cuda.empty_cache() + +# 2ļøāƒ£ Move GENERATOR to GPU 2 +print("šŸŽ¬ Moving generator to GPU 2 in bfloat16...") +pipeline.generator = pipeline.generator.to(gpu_main, dtype=torch.bfloat16) +print(f"Generator device: {next(pipeline.generator.parameters()).device}") + +torch.cuda.empty_cache() + +# 3ļøāƒ£ Move VAE to GPU 2 +print("šŸ–¼ļø Moving VAE to GPU 2 in bfloat16...") +pipeline.vae = pipeline.vae.to(gpu_main, dtype=torch.bfloat16) +print(f"VAE device: {next(pipeline.vae.parameters()).device}") + +torch.cuda.empty_cache() + +print("\nāœ… Model 
print("\nāœ… Model parallel setup complete")
print(f" Text Encoder → {next(pipeline.text_encoder.parameters()).device}")
print(f" Generator → {next(pipeline.generator.parameters()).device}")
print(f" VAE → {next(pipeline.vae.parameters()).device}")

# Set main device for noise generation & diffusion
device = gpu_main
# ======================== END GPU SETUP ========================

# ======================== OPTIONAL: USE SPECIFIC GPU FOR INFERENCE ========================
# Uncomment and modify if you want to use a specific GPU (e.g., GPU 2)
# inference_device = torch.device("cuda:2")
# print(f"\n🚚 Moving models to {inference_device} for inference...")
# pipeline.generator.to(inference_device)
# pipeline.vae.to(inference_device)
# torch.cuda.empty_cache()
# device = inference_device
# ======================== END GPU SETUP ========================

# ----------------------------- Build dataset -----------------------------
# NOTE(review): the same file is passed as both prompt and extended-prompt
# source — presumably intentional for the single-file workflow; confirm.
dataset = TextDataset(prompt_path=config.data_path, extended_prompt_path=config.data_path)

num_prompts_total = len(dataset)
print(f"Number of prompt lines: {num_prompts_total}")

# Distributed runs shard the prompts across ranks; single-GPU walks them in order.
if dist.is_initialized():
    sampler = DistributedSampler(dataset, shuffle=False, drop_last=True)
else:
    sampler = SequentialSampler(dataset)

dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False)

# Only rank 0 creates the output directory; everyone syncs before generating.
if local_rank == 0:
    os.makedirs(config.output_folder, exist_ok=True)

if dist.is_initialized():
    dist.barrier()

# ----------------------------- Inference loop -----------------------------
print("\n" + "="*70)
print("šŸš€ Starting video generation...")
print("="*70)

for i, batch_data in tqdm(enumerate(dataloader), disable=(local_rank != 0)):
    idx = batch_data["idx"].item()

    # Get the prompt from the batch (batch_size=1, so element 0).
    prompt = batch_data['prompts'][0]
    extended_prompt = batch_data.get('extended_prompts', [prompt])[0]
extended_prompt if extended_prompt else prompt + prompts = [text_prompt] * config.num_samples + + # Get the actual device where generator is located + gen_device = next(pipeline.generator.parameters()).device + + sampled_noise = torch.randn( + [ + config.num_samples, + config.num_output_frames, + 16, + 60, + 104, + ], + device=gen_device, # Use generator's device + dtype=torch.bfloat16, + ) + + with torch.autocast("cuda", dtype=torch.bfloat16): + video = pipeline.inference( + noise=sampled_noise, + text_prompts=prompts, + return_latents=False, + ) + + current_video = rearrange(video, "b t c h w -> b t h w c").cpu() * 255.0 + + if dist.is_initialized(): + rank = dist.get_rank() + else: + rank = 0 + + model_type = "quantized" if args.use_quantized else "regular" + + for seed_idx in range(config.num_samples): + if config.save_with_index: + output_path = os.path.join(config.output_folder, f"rank{rank}-{idx}-{seed_idx}_{model_type}.mp4") + else: + short_name = text_prompt[:100].replace("/", "_") + output_path = os.path.join(config.output_folder, f"rank{rank}-{short_name}-{seed_idx}_{model_type}.mp4") + + write_video(output_path, current_video[seed_idx].to(torch.uint8), fps=16) + + if local_rank == 0: + print(f"āœ… Saved: {output_path}") + + if config.inference_iter != -1 and i >= config.inference_iter: + break + +print("\n" + "="*70) +print("šŸŽ‰ Video generation complete!") +print("="*70) + +if dist.is_initialized(): + dist.destroy_process_group() diff --git a/single_quantized.sh b/single_quantized.sh new file mode 100755 index 0000000..186caf0 --- /dev/null +++ b/single_quantized.sh @@ -0,0 +1 @@ +python single_prompt_inference.py --config_path configs/longlive_inference.yaml --use_quantized \ No newline at end of file diff --git a/test_prompts.txt b/test_prompts.txt new file mode 100644 index 0000000..0d75b9f --- /dev/null +++ b/test_prompts.txt @@ -0,0 +1,3 @@ +A serene garden with colorful butterflies, sunny day, photorealistic +The butterflies begin to glow with a 
magical light, gathering together +The garden transforms into an enchanted mystical forest at twilight