diff --git a/README.md b/README.md
index 62a22c3..c48cad6 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,13 @@
+
This repo contains code to run LongLive using 2x 3090, both single prompt and interactive prompts. It needs at least 2x24gb gpus
+Download bfloat weights from https://huggingface.co/srivassid/LongLiveMultiGPU/tree/main, create a new folder longlive_models/models under the same parent as LongLive and put them there,
+ and clone the model weight [LongLive-1.3B](https://huggingface.co/Efficient-Large-Model/LongLive-1.3B) and put that folder inside LongLive folder.
+Create a folder longlive_models/prompts and create a file called interactive_example.jsonl in it; pick one line from LongLive/example/interactive_example.jsonl and save it there.
+Do the same for single prompt: create a file longlive_models/prompts/vidprom_filtered_extended.txt, go to LongLive/example/long_example.txt and paste one of the lines into that file.
+If the nvidia-pyindex package throws an error, comment it out in requirements.txt.
+For single prompts run single_quantized.sh, for interactive prompts, run run_quantized.sh
+I put cuda:2 in interactive_inference_quantized.py and single_prompt_inference.py. Change it to cuda:1 for the second gpu
+Any issues, leave a comment
+
diff --git a/configs/longlive_inference.yaml b/configs/longlive_inference.yaml
index 2c0daf5..ed1111f 100644
--- a/configs/longlive_inference.yaml
+++ b/configs/longlive_inference.yaml
@@ -4,7 +4,7 @@ denoising_step_list:
- 500
- 250
warp_denoising_step: true # need to remove - 0 in denoising_step_list if warp_denoising_step is true
-num_frame_per_block: 3
+num_frame_per_block: 6
model_name: Wan2.1-T2V-1.3B
model_kwargs:
local_attn_size: 12
diff --git a/configs/longlive_interactive_inference.yaml b/configs/longlive_interactive_inference.yaml
index 113ae78..64aeaf1 100644
--- a/configs/longlive_interactive_inference.yaml
+++ b/configs/longlive_interactive_inference.yaml
@@ -5,7 +5,7 @@ denoising_step_list:
- 500
- 250
warp_denoising_step: true # need to remove - 0 in denoising_step_list if warp_denoising_step is true
-num_frame_per_block: 3
+num_frame_per_block: 6
model_name: Wan2.1-T2V-1.3B
model_kwargs:
local_attn_size: 12
@@ -17,12 +17,12 @@ model_kwargs:
data_path: longlive_models/prompts/interactive_example.jsonl
output_folder: videos/interactive
inference_iter: -1
-num_output_frames: 240
+num_output_frames: 144
use_ema: false
seed: 1
num_samples: 1
save_with_index: true
-switch_frame_indices: 40, 80, 120, 160, 200
+switch_frame_indices: 48, 96
global_sink: true
context_noise: 0
diff --git a/interactive_inference_quantized.py b/interactive_inference_quantized.py
new file mode 100644
index 0000000..e7aa291
--- /dev/null
+++ b/interactive_inference_quantized.py
@@ -0,0 +1,341 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# To view a copy of this license, visit http://www.apache.org/licenses/LICENSE-2.0
+#
+# No warranties are given. The work is provided "AS IS", without warranty of any kind, express or implied.
+#
+# SPDX-License-Identifier: Apache-2.0
+import argparse
+import os
+from typing import List
+
+import torch
+import torch.distributed as dist
+from omegaconf import OmegaConf
+from tqdm import tqdm
+from torch.utils.data import DataLoader, SequentialSampler
+from torch.utils.data.distributed import DistributedSampler
+from torchvision.io import write_video
+from torchvision import transforms # noqa: F401
+from einops import rearrange
+
+from utils.misc import set_seed
+from utils.distributed import barrier
+from utils.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller
+
+from pipeline.interactive_causal_inference import (
+ InteractiveCausalInferencePipeline,
+)
+from utils.dataset import MultiTextDataset
+
+# ----------------------------- Argument parsing -----------------------------
+parser = argparse.ArgumentParser("Interactive causal inference")
+parser.add_argument("--config_path", type=str, help="Path to the config file")
+parser.add_argument("--use_quantized", action="store_true",
+ help="Use quantized models from ../longlive_models/")
+args = parser.parse_args()
+
+config = OmegaConf.load(args.config_path)
+# config.model_kwargs.local_attn_size = 32
+# ======================== LOAD QUANTIZED MODELS TO CPU FIRST ========================
+quantized_base_state = None
+quantized_lora_state = None
+
+if args.use_quantized:
+ print("\n" + "="*70)
+ print("š§ Pre-loading QUANTIZED models to CPU")
+ print("="*70)
+
+ import sys
+ sys.path.insert(0, '..')
+
+ base_path = '../longlive_models/longlive_base_bfloat16.pt'
+ lora_path = '../longlive_models/lora_bfloat16.pt'
+
+ print(f"š„ Loading base checkpoint to CPU from: {base_path}")
+ base_checkpoint = torch.load(base_path, map_location='cpu', weights_only=False)
+
+ print(f"š„ Loading LoRA checkpoint to CPU from: {lora_path}")
+ lora_checkpoint = torch.load(lora_path, map_location='cpu', weights_only=False)
+
+ # Extract the models
+ quantized_base_state = base_checkpoint['generator']
+ quantized_lora_state = lora_checkpoint['generator_lora']
+ quantized_critic_lora = lora_checkpoint.get('critic_lora', {})
+
+    print(f"✅ Loaded to CPU - will transfer after pipeline init")
+ print(f" Base params: {sum(p.numel() for p in quantized_base_state.values() if isinstance(p, torch.Tensor)):,}")
+ print(f" LoRA params: {sum(p.numel() for p in quantized_lora_state.values() if isinstance(p, torch.Tensor)):,}")
+ print("="*70 + "\n")
+
+ # Don't load from checkpoint files
+ config.generator_ckpt = None
+ config.lora_ckpt = None # We'll load this manually later
+
+# ======================== END PRE-LOADING ========================
+
+# ----------------------------- Distributed setup -----------------------------
+if "LOCAL_RANK" in os.environ: # Multi-GPU via torchrun
+ os.environ["NCCL_CROSS_NIC"] = "1"
+ os.environ["NCCL_DEBUG"] = os.environ.get("NCCL_DEBUG", "INFO")
+ os.environ["NCCL_TIMEOUT"] = os.environ.get("NCCL_TIMEOUT", "1800")
+
+ local_rank = int(os.environ["LOCAL_RANK"])
+ world_size = int(os.environ.get("WORLD_SIZE", "1"))
+ rank = int(os.environ.get("RANK", str(local_rank)))
+
+ torch.cuda.set_device(local_rank)
+ device = torch.device(f"cuda:{local_rank}")
+
+ if not dist.is_initialized():
+ dist.init_process_group(
+ backend="nccl",
+ rank=rank,
+ world_size=world_size,
+ timeout=torch.distributed.constants.default_pg_timeout,
+ )
+
+ set_seed(config.seed + local_rank)
+ print(f"[Rank {rank}] Distributed mode on GPU {local_rank}")
+
+else: # Single-GPU mode
+ assert torch.cuda.is_available(), "CUDA is required but not available"
+
+ local_rank = 0
+ rank = 0
+ device = torch.device("cuda:0")
+ torch.cuda.set_device(device)
+
+ set_seed(config.seed)
+ print("Single GPU mode on cuda:0")
+
+low_memory = get_cuda_free_memory_gb(device) < 40
+torch.set_grad_enabled(False)
+
+# ======================== INITIALIZE PIPELINE ========================
+pipeline = InteractiveCausalInferencePipeline(config, device=device)
+print("Generator device:", next(pipeline.generator.parameters()).device)
+print("VAE device:", next(pipeline.vae.parameters()).device)
+print("Text encoder device:", next(pipeline.text_encoder.parameters()).device)
+# ======================== LOAD QUANTIZED WEIGHTS ========================
+if args.use_quantized and quantized_base_state is not None:
+ print("\n" + "="*70)
+ print("š§ Loading quantized base model into pipeline")
+ print("="*70)
+
+ missing, unexpected = pipeline.generator.load_state_dict(quantized_base_state, strict=False)
+ if local_rank == 0:
+ if missing:
+ print(f"[Warning] {len(missing)} parameters missing: {missing[:8]} ...")
+ if unexpected:
+ print(f"[Warning] {len(unexpected)} unexpected params: {unexpected[:8]} ...")
+        print("✅ Quantized base model loaded")
+
+ print("š§ Converting ALL generator parameters to bfloat16...")
+ for name, param in pipeline.generator.named_parameters():
+ if param.dtype != torch.bfloat16:
+ param.data = param.data.to(torch.bfloat16)
+ for name, buffer in pipeline.generator.named_buffers():
+ if buffer.dtype != torch.bfloat16 and buffer.dtype == torch.float32:
+ buffer.data = buffer.data.to(torch.bfloat16)
+    print("✅ All generator parameters converted to bfloat16")
+
+ # Clear the CPU checkpoint to free memory
+ del base_checkpoint
+ del quantized_base_state
+ import gc
+ gc.collect()
+
+# --------------------------- LoRA support (optional) ---------------------------
+from utils.lora_utils import configure_lora_for_model
+import peft
+
+pipeline.is_lora_enabled = False
+if getattr(config, "adapter", None) and configure_lora_for_model is not None:
+ if local_rank == 0:
+ print(f"\nš§ LoRA enabled with config: {config.adapter}")
+ print("Applying LoRA to generator (inference)...")
+
+ pipeline.generator.model = configure_lora_for_model(
+ pipeline.generator.model,
+ model_name="generator",
+ lora_config=config.adapter,
+ is_main_process=(local_rank == 0),
+ )
+
+ # Load quantized LoRA weights
+ if args.use_quantized and quantized_lora_state is not None:
+ if local_rank == 0:
+ print(f"Loading QUANTIZED LoRA weights from CPU")
+ peft.set_peft_model_state_dict(pipeline.generator.model, quantized_lora_state)
+ if local_rank == 0:
+            print("✅ Quantized LoRA weights loaded")
+
+ # Clear LoRA checkpoint
+ del lora_checkpoint
+ del quantized_lora_state
+ gc.collect()
+ elif not args.use_quantized:
+ # Original LoRA loading for non-quantized
+ lora_ckpt_path = getattr(config, "lora_ckpt", None)
+ if lora_ckpt_path:
+ if local_rank == 0:
+ print(f"Loading LoRA checkpoint from {lora_ckpt_path}")
+ lora_checkpoint = torch.load(lora_ckpt_path, map_location="cpu")
+ if isinstance(lora_checkpoint, dict) and "generator_lora" in lora_checkpoint:
+ peft.set_peft_model_state_dict(pipeline.generator.model, lora_checkpoint["generator_lora"])
+ else:
+ peft.set_peft_model_state_dict(pipeline.generator.model, lora_checkpoint)
+ if local_rank == 0:
+ print("LoRA weights loaded")
+
+ pipeline.is_lora_enabled = True
+
+# Move pipeline to appropriate dtype and device
+# print("\nš§ Moving pipeline to bfloat16...")
+# pipeline = pipeline.to(dtype=torch.bfloat16)
+
+if low_memory:
+ DynamicSwapInstaller.install_model(pipeline.text_encoder, device=device)
+
+# ======================== USE GPU 1 or 2 FOR INFERENCE ========================
+
+# ======================== USE GPU 2 WITH DYNAMIC SWAP ========================
+# inference_device = torch.device('cuda:2')
+#
+# print(f"\nš¬ Loading all components to GPU 2")
+# print(f" GPU 2 Free VRAM before: {get_cuda_free_memory_gb(inference_device):.2f} GB")
+#
+# torch.cuda.empty_cache()
+#
+# print(f" GPU 2 Free VRAM after clearing GPU 0: {get_cuda_free_memory_gb(inference_device):.2f} GB")
+#
+# # Now move everything to GPU 2
+# print(" Moving generator to GPU 2...")
+# pipeline.generator.to(device=inference_device)
+#
+# print(" Converting VAE to bfloat16...")
+# pipeline.vae = pipeline.vae.to(dtype=torch.bfloat16)
+#
+# # print(" Moving text_encoder from CPU to GPU 2...")
+# # pipeline.text_encoder = pipeline.text_encoder.to(device=inference_device)
+#
+# print(f"✅ All components on GPU 2")
+# print(f" GPU 2 VRAM used: {(torch.cuda.memory_allocated(2) / 1024**3):.2f} GB")
+
+print("\nš Moving models to GPU 2 for inference...")
+
+inference_device = torch.device("cuda:2")
+
+pipeline.generator.to(inference_device)
+pipeline.vae.to(inference_device)
+
+# Optional but recommended for speed
+pipeline.generator = pipeline.generator.to(dtype=torch.bfloat16)
+pipeline.vae = pipeline.vae.to(dtype=torch.bfloat16)
+
+torch.cuda.empty_cache()
+
+print("Generator device:", next(pipeline.generator.parameters()).device)
+print("VAE device:", next(pipeline.vae.parameters()).device)
+print("Text encoder device:", next(pipeline.text_encoder.parameters()).device)
+
+device = inference_device
+
+# ======================== END GPU SETUP ========================
+
+# ----------------------------- Build dataset -----------------------------
+if isinstance(config.switch_frame_indices, int):
+ switch_frame_indices: List[int] = [int(config.switch_frame_indices)]
+else:
+ switch_frame_indices: List[int] = [
+ int(x) for x in str(config.switch_frame_indices).split(",") if str(x).strip()
+ ]
+
+dataset = MultiTextDataset(config.data_path)
+
+num_segments = len(dataset[0]["prompts_list"])
+assert len(switch_frame_indices) == num_segments - 1, (
+ "The number of switch_frame_indices should be the number of prompt segments minus 1"
+)
+
+print("Number of segments:", num_segments)
+print("Switch frame indices:", switch_frame_indices)
+
+num_prompts_total = len(dataset)
+print(f"Number of prompt lines: {num_prompts_total}")
+
+if dist.is_initialized():
+ sampler = DistributedSampler(dataset, shuffle=False, drop_last=True)
+else:
+ sampler = SequentialSampler(dataset)
+
+dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False)
+
+if local_rank == 0:
+ os.makedirs(config.output_folder, exist_ok=True)
+
+if dist.is_initialized():
+ dist.barrier()
+
+# ----------------------------- Inference loop -----------------------------
+print("\n" + "="*70)
+print("š Starting video generation...")
+print("="*70)
+
+for i, batch_data in tqdm(enumerate(dataloader), disable=(local_rank != 0)):
+ idx = batch_data["idx"].item()
+ prompts_list: List[str] = batch_data["prompts_list"]
+
+ sampled_noise = torch.randn(
+ [
+ config.num_samples,
+ config.num_output_frames,
+ 16,
+ 60,
+ 104,
+ ],
+ device=device,
+ dtype=torch.bfloat16,
+ )
+
+ with torch.autocast("cuda", dtype=torch.bfloat16):
+ video = pipeline.inference(
+ noise=sampled_noise,
+ text_prompts_list=prompts_list,
+ switch_frame_indices=switch_frame_indices,
+ return_latents=False,
+ )
+
+ current_video = rearrange(video, "b t c h w -> b t h w c").cpu() * 255.0
+
+ if dist.is_initialized():
+ rank = dist.get_rank()
+ else:
+ rank = 0
+
+ model_type = "quantized" if args.use_quantized else "regular"
+
+ for seed_idx in range(config.num_samples):
+ if config.save_with_index:
+ output_path = os.path.join(config.output_folder, f"rank{rank}-{idx}-{seed_idx}_{model_type}.mp4")
+ else:
+ short_name = prompts_list[0][0][:100].replace("/", "_")
+ output_path = os.path.join(config.output_folder, f"rank{rank}-{short_name}-{seed_idx}_{model_type}.mp4")
+
+ write_video(output_path, current_video[seed_idx].to(torch.uint8), fps=16)
+
+ if local_rank == 0:
+        print(f"✅ Saved: {output_path}")
+
+ if config.inference_iter != -1 and i >= config.inference_iter:
+ break
+
+print("\n" + "="*70)
+print("š Video generation complete!")
+print("="*70)
+
+if dist.is_initialized():
+ dist.destroy_process_group()
diff --git a/pipeline/interactive_causal_inference.py b/pipeline/interactive_causal_inference.py
index e44dd92..7f2fde5 100644
--- a/pipeline/interactive_causal_inference.py
+++ b/pipeline/interactive_causal_inference.py
@@ -32,6 +32,7 @@ def __init__(
# Internal helpers
def _recache_after_switch(self, output, current_start_frame, new_conditional_dict):
+ print("š RECACHING NOW...")
if not self.global_sink:
# reset kv cache
for block_idx in range(self.num_transformer_blocks):
@@ -120,12 +121,28 @@ def inference(
"length of switch_frame_indices should be one less than text_prompts_list"
)
assert num_output_frames % self.num_frame_per_block == 0
+ gen_device = next(self.generator.parameters()).device
+ vae_device = next(self.vae.parameters()).device
+ text_device = next(self.text_encoder.parameters()).device
+ noise = noise.to(gen_device)
num_blocks = num_output_frames // self.num_frame_per_block
# encode all prompts
print(text_prompts_list)
- cond_list = [self.text_encoder(text_prompts=p) for p in text_prompts_list]
+ gen_device = next(self.generator.parameters()).device
+ gen_dtype = next(self.generator.parameters()).dtype
+
+ cond_list = []
+ for p in text_prompts_list:
+ cond = self.text_encoder(text_prompts=p)
+
+ # Move every tensor inside the conditioning dict to generator device AND dtype
+ for k, v in cond.items():
+ if torch.is_tensor(v):
+ cond[k] = v.to(device=gen_device, dtype=gen_dtype)
+
+ cond_list.append(cond)
if low_memory:
gpu_memory_preservation = get_cuda_free_memory_gb(gpu) + 5
@@ -205,12 +222,13 @@ def inference(
noisy_input = noise[
:, current_start_frame : current_start_frame + current_num_frames
]
-
+ # print("KV cache device:", self.kv_cache1[0]["k"].device)
+ # print("Cross-attn cache device:", self.crossattn_cache[0]["k"].device)
# ---------------- Spatial denoising loop ----------------
for index, current_timestep in enumerate(self.denoising_step_list):
timestep = (
torch.ones([batch_size, current_num_frames],
- device=noise.device,
+ device=gen_device,
dtype=torch.int64)
* current_timestep
)
@@ -230,7 +248,7 @@ def inference(
torch.randn_like(denoised_pred.flatten(0, 1)),
next_timestep
* torch.ones(
- [batch_size * current_num_frames], device=noise.device, dtype=torch.long
+ [batch_size * current_num_frames], device=gen_device, dtype=torch.long
),
).unflatten(0, denoised_pred.shape[:2])
else:
@@ -261,7 +279,8 @@ def inference(
current_start_frame += current_num_frames
# Standard decoding
- video = self.vae.decode_to_pixel(output.to(noise.device), use_cache=False)
+ # video = self.vae.decode_to_pixel(output.to(vae_device, dtype=next(self.vae.parameters()).dtype), use_cache=False)
+ video = self.vae.decode_to_pixel(output.to(vae_device, dtype=torch.float32), use_cache=False)
video = (video * 0.5 + 0.5).clamp(0, 1)
if return_latents:
diff --git a/pipeline/single_prompt_causal_inference.py b/pipeline/single_prompt_causal_inference.py
new file mode 100644
index 0000000..7622b78
--- /dev/null
+++ b/pipeline/single_prompt_causal_inference.py
@@ -0,0 +1,193 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# To view a copy of this license, visit http://www.apache.org/licenses/LICENSE-2.0
+#
+# No warranties are given. The work is provided "AS IS", without warranty of any kind, express or implied.
+#
+# SPDX-License-Identifier: Apache-2.0
+from typing import List, Optional
+import torch
+
+from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
+from utils.memory import gpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation
+from pipeline.causal_inference import CausalInferencePipeline
+import torch.distributed as dist
+from utils.debug_option import DEBUG
+
+
+class SinglePromptCausalInferencePipeline(CausalInferencePipeline):
+ """
+ Simplified causal inference pipeline for single prompt (no prompt switching).
+ This is essentially the base CausalInferencePipeline without the multi-prompt logic.
+ """
+ def __init__(
+ self,
+ args,
+ device,
+ *,
+ generator: WanDiffusionWrapper | None = None,
+ text_encoder: WanTextEncoder | None = None,
+ vae: WanVAEWrapper | None = None,
+ ):
+ super().__init__(args, device, generator=generator, text_encoder=text_encoder, vae=vae)
+ self.global_sink = getattr(args, "global_sink", False)
+
+ def inference(
+ self,
+ noise: torch.Tensor,
+ *,
+ text_prompts: List[str],
+ return_latents: bool = False,
+ low_memory: bool = False,
+ ):
+ """Generate a video with a single prompt throughout.
+
+ Args:
+ noise: Noise tensor, shape = (B, T_out, C, H, W).
+ text_prompts: List[str], prompts for each sample in the batch.
+ return_latents: Whether to also return the latent tensor.
+ low_memory: Enable low-memory mode.
+ """
+ batch_size, num_output_frames, num_channels, height, width = noise.shape
+ assert num_output_frames % self.num_frame_per_block == 0
+
+ # Get device information
+ gen_device = next(self.generator.parameters()).device
+ vae_device = next(self.vae.parameters()).device
+ text_device = next(self.text_encoder.parameters()).device
+
+ # Ensure noise is on the generator device
+ noise = noise.to(gen_device, dtype=torch.bfloat16)
+
+ num_blocks = num_output_frames // self.num_frame_per_block
+
+ # Encode the prompt once
+ print(f"Prompts: {text_prompts}")
+ gen_dtype = next(self.generator.parameters()).dtype
+
+ conditional_dict = self.text_encoder(text_prompts=text_prompts)
+
+ # Move conditioning tensors to generator device and dtype
+ for k, v in conditional_dict.items():
+ if torch.is_tensor(v):
+ conditional_dict[k] = v.to(device=gen_device, dtype=gen_dtype)
+
+ if low_memory:
+ gpu_memory_preservation = get_cuda_free_memory_gb(gpu) + 5
+ move_model_to_device_with_memory_preservation(
+ self.text_encoder,
+ target_device=gpu,
+ preserved_memory_gb=gpu_memory_preservation,
+ )
+
+ output_device = torch.device('cpu') if low_memory else gen_device
+ output = torch.zeros(
+ [batch_size, num_output_frames, num_channels, height, width],
+ device=output_device,
+ dtype=torch.bfloat16, # Match noise dtype
+ )
+
+ # Initialize caches
+ local_attn_cfg = getattr(self.args.model_kwargs, "local_attn_size", -1)
+ kv_policy = ""
+ if local_attn_cfg != -1:
+ # local attention
+ kv_cache_size = local_attn_cfg * self.frame_seq_length
+ kv_policy = f"local, size={local_attn_cfg}"
+ else:
+ # global attention
+ kv_cache_size = num_output_frames * self.frame_seq_length
+ kv_policy = "global (-1)"
+ print(f"KV cache size: {kv_cache_size} (policy: {kv_policy}, frame_seq_length: {self.frame_seq_length}, num_output_frames: {num_output_frames})")
+
+ self._initialize_kv_cache(
+ batch_size,
+ dtype=torch.bfloat16,
+ device=gen_device, # Use generator device
+ kv_cache_size_override=kv_cache_size
+ )
+ self._initialize_crossattn_cache(
+ batch_size=batch_size,
+ dtype=torch.bfloat16,
+ device=gen_device, # Use generator device
+ )
+
+ current_start_frame = 0
+ self.generator.model.local_attn_size = self.local_attn_size
+ print(f"[inference] local_attn_size set on model: {self.generator.model.local_attn_size}")
+ self._set_all_modules_max_attention_size(self.local_attn_size)
+
+ # Temporal denoising by blocks
+ all_num_frames = [self.num_frame_per_block] * num_blocks
+
+ if DEBUG:
+ print("[SinglePrompt] all_num_frames", all_num_frames)
+
+ for current_num_frames in all_num_frames:
+ noisy_input = noise[
+ :, current_start_frame : current_start_frame + current_num_frames
+ ]
+
+ # ---------------- Spatial denoising loop ----------------
+ for index, current_timestep in enumerate(self.denoising_step_list):
+ timestep = (
+ torch.ones([batch_size, current_num_frames],
+ device=gen_device,
+ dtype=torch.int64)
+ * current_timestep
+ )
+
+ if index < len(self.denoising_step_list) - 1:
+ _, denoised_pred = self.generator(
+ noisy_image_or_video=noisy_input,
+ conditional_dict=conditional_dict,
+ timestep=timestep,
+ kv_cache=self.kv_cache1,
+ crossattn_cache=self.crossattn_cache,
+ current_start=current_start_frame * self.frame_seq_length,
+ )
+ next_timestep = self.denoising_step_list[index + 1]
+ noisy_input = self.scheduler.add_noise(
+ denoised_pred.flatten(0, 1),
+ torch.randn_like(denoised_pred.flatten(0, 1)),
+ next_timestep
+ * torch.ones(
+ [batch_size * current_num_frames], device=gen_device, dtype=torch.long
+ ),
+ ).unflatten(0, denoised_pred.shape[:2])
+ else:
+ _, denoised_pred = self.generator(
+ noisy_image_or_video=noisy_input,
+ conditional_dict=conditional_dict,
+ timestep=timestep,
+ kv_cache=self.kv_cache1,
+ crossattn_cache=self.crossattn_cache,
+ current_start=current_start_frame * self.frame_seq_length,
+ )
+
+ # Record output
+ output[:, current_start_frame : current_start_frame + current_num_frames] = denoised_pred.to(output.device)
+
+ # Rerun with clean context to update cache
+ context_timestep = torch.ones_like(timestep) * self.args.context_noise
+ self.generator(
+ noisy_image_or_video=denoised_pred,
+ conditional_dict=conditional_dict,
+ timestep=context_timestep,
+ kv_cache=self.kv_cache1,
+ crossattn_cache=self.crossattn_cache,
+ current_start=current_start_frame * self.frame_seq_length,
+ )
+
+ # Update frame pointer
+ current_start_frame += current_num_frames
+
+ # Standard decoding
+ video = self.vae.decode_to_pixel(output.to(vae_device, dtype=torch.float32), use_cache=False)
+ video = (video * 0.5 + 0.5).clamp(0, 1)
+
+ if return_latents:
+ return video, output
+ return video
diff --git a/requirements.txt b/requirements.txt
index 7d586e2..764c4db 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,7 +27,7 @@ pydantic==2.10.6
scikit-image
huggingface_hub[cli]
dominate
-nvidia-pyindex
+#nvidia-pyindex
nvidia-tensorrt
pycuda
onnx
diff --git a/run_quantized.sh b/run_quantized.sh
new file mode 100755
index 0000000..f8af39f
--- /dev/null
+++ b/run_quantized.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+# Run LongLive with your quantized models
+# Place this in: /media/sid/Kingston/longlive/LongLive/
+
+echo "āāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāā"
+echo "ā LongLive Interactive Inference with Quantized Models ā"
+echo "āāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāā"
+echo ""
+
+# Check if quantized models exist
+if [ ! -f "../longlive_models/longlive_base_bfloat16.pt" ]; then
+ echo "ā Error: Quantized models not found!"
+ echo ""
+ echo "Please run the quantization script first:"
+ echo " cd /media/sid/Kingston/longlive"
+ echo " python longlive_3x3090.py"
+ exit 1
+fi
+
+echo "✅ Quantized models found"
+echo ""
+
+# Create a simple prompt file for testing
+cat > test_prompts.txt << 'EOF'
+A serene garden with colorful butterflies, sunny day, photorealistic
+The butterflies begin to glow with a magical light, gathering together
+The garden transforms into an enchanted mystical forest at twilight
+EOF
+
+echo "š Created test prompts:"
+cat test_prompts.txt
+echo ""
+
+# Run with quantized models
+python interactive_inference_quantized.py \
+ --config_path configs/longlive_interactive_inference.yaml \
+ --use_quantized
+
+echo ""
+echo "š Done! Check the output folder for your video."
diff --git a/single_prompt_inference.py b/single_prompt_inference.py
new file mode 100644
index 0000000..2386f2d
--- /dev/null
+++ b/single_prompt_inference.py
@@ -0,0 +1,331 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# To view a copy of this license, visit http://www.apache.org/licenses/LICENSE-2.0
+#
+# No warranties are given. The work is provided "AS IS", without warranty of any kind, express or implied.
+#
+# SPDX-License-Identifier: Apache-2.0
+import argparse
+import os
+from typing import List
+
+import torch
+import torch.distributed as dist
+from omegaconf import OmegaConf
+from tqdm import tqdm
+from torch.utils.data import DataLoader, SequentialSampler
+from torch.utils.data.distributed import DistributedSampler
+from torchvision.io import write_video
+from torchvision import transforms # noqa: F401
+from einops import rearrange
+
+from utils.misc import set_seed
+from utils.distributed import barrier
+from utils.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller
+
+from pipeline.single_prompt_causal_inference import (
+ SinglePromptCausalInferencePipeline,
+)
+from utils.dataset import TextDataset
+
+# ----------------------------- Argument parsing -----------------------------
+parser = argparse.ArgumentParser("Single prompt causal inference")
+parser.add_argument("--config_path", type=str, help="Path to the config file")
+parser.add_argument("--use_quantized", action="store_true",
+ help="Use quantized models from ../longlive_models/")
+args = parser.parse_args()
+
+config = OmegaConf.load(args.config_path)
+
+# ======================== LOAD QUANTIZED MODELS TO CPU FIRST ========================
+quantized_base_state = None
+quantized_lora_state = None
+
+if args.use_quantized:
+ print("\n" + "="*70)
+ print("š§ Pre-loading QUANTIZED models to CPU")
+ print("="*70)
+
+ import sys
+ sys.path.insert(0, '..')
+
+ base_path = '../longlive_models/longlive_base_bfloat16.pt'
+ lora_path = '../longlive_models/lora_bfloat16.pt'
+
+ print(f"š„ Loading base checkpoint to CPU from: {base_path}")
+ base_checkpoint = torch.load(base_path, map_location='cpu', weights_only=False)
+
+ print(f"š„ Loading LoRA checkpoint to CPU from: {lora_path}")
+ lora_checkpoint = torch.load(lora_path, map_location='cpu', weights_only=False)
+
+ # Extract the models
+ quantized_base_state = base_checkpoint['generator']
+ quantized_lora_state = lora_checkpoint['generator_lora']
+ quantized_critic_lora = lora_checkpoint.get('critic_lora', {})
+
+ print(f"ā
Loaded to CPU - will transfer after pipeline init")
+ print(f" Base params: {sum(p.numel() for p in quantized_base_state.values() if isinstance(p, torch.Tensor)):,}")
+ print(f" LoRA params: {sum(p.numel() for p in quantized_lora_state.values() if isinstance(p, torch.Tensor)):,}")
+ print("="*70 + "\n")
+
+ # Don't load from checkpoint files
+ config.generator_ckpt = None
+ config.lora_ckpt = None
+
+# ======================== END PRE-LOADING ========================
+
+# ----------------------------- Distributed setup -----------------------------
+if "LOCAL_RANK" in os.environ: # Multi-GPU via torchrun
+ os.environ["NCCL_CROSS_NIC"] = "1"
+ os.environ["NCCL_DEBUG"] = os.environ.get("NCCL_DEBUG", "INFO")
+ os.environ["NCCL_TIMEOUT"] = os.environ.get("NCCL_TIMEOUT", "1800")
+
+ local_rank = int(os.environ["LOCAL_RANK"])
+ world_size = int(os.environ.get("WORLD_SIZE", "1"))
+ rank = int(os.environ.get("RANK", str(local_rank)))
+
+ torch.cuda.set_device(local_rank)
+ device = torch.device(f"cuda:{local_rank}")
+
+ if not dist.is_initialized():
+ dist.init_process_group(
+ backend="nccl",
+ rank=rank,
+ world_size=world_size,
+ timeout=torch.distributed.constants.default_pg_timeout,
+ )
+
+ set_seed(config.seed + local_rank)
+ print(f"[Rank {rank}] Distributed mode on GPU {local_rank}")
+
+else: # Single-GPU mode
+ assert torch.cuda.is_available(), "CUDA is required but not available"
+
+ local_rank = 0
+ rank = 0
+ device = torch.device("cuda:0")
+ torch.cuda.set_device(device)
+
+ set_seed(config.seed)
+ print("Single GPU mode on cuda:0")
+
+low_memory = get_cuda_free_memory_gb(device) < 40
+torch.set_grad_enabled(False)
+
+# ======================== INITIALIZE PIPELINE ========================
+pipeline = SinglePromptCausalInferencePipeline(config, device=device)
+print("Generator device:", next(pipeline.generator.parameters()).device)
+print("VAE device:", next(pipeline.vae.parameters()).device)
+print("Text encoder device:", next(pipeline.text_encoder.parameters()).device)
+
+# ======================== LOAD QUANTIZED WEIGHTS ========================
+if args.use_quantized and quantized_base_state is not None:
+ print("\n" + "="*70)
+ print("š§ Loading quantized base model into pipeline")
+ print("="*70)
+
+ missing, unexpected = pipeline.generator.load_state_dict(quantized_base_state, strict=False)
+ if local_rank == 0:
+ if missing:
+ print(f"[Warning] {len(missing)} parameters missing: {missing[:8]} ...")
+ if unexpected:
+ print(f"[Warning] {len(unexpected)} unexpected params: {unexpected[:8]} ...")
+ print("[OK] Quantized base model loaded")
+
+ print("Converting ALL generator parameters to bfloat16...")
+ for name, param in pipeline.generator.named_parameters():
+ if param.dtype != torch.bfloat16:
+ param.data = param.data.to(torch.bfloat16)
+ for name, buffer in pipeline.generator.named_buffers():
+ if buffer.dtype != torch.bfloat16 and buffer.dtype == torch.float32:
+ buffer.data = buffer.data.to(torch.bfloat16)
+ print("[OK] All generator parameters converted to bfloat16")
+
+ # Clear the CPU checkpoint to free memory
+ del base_checkpoint
+ del quantized_base_state
+ import gc
+ gc.collect()
+
+# --------------------------- LoRA support (optional) ---------------------------
+from utils.lora_utils import configure_lora_for_model
+import peft
+
+pipeline.is_lora_enabled = False
+if getattr(config, "adapter", None) and configure_lora_for_model is not None:
+ if local_rank == 0:
+ print(f"\nš§ LoRA enabled with config: {config.adapter}")
+ print("Applying LoRA to generator (inference)...")
+
+ pipeline.generator.model = configure_lora_for_model(
+ pipeline.generator.model,
+ model_name="generator",
+ lora_config=config.adapter,
+ is_main_process=(local_rank == 0),
+ )
+
+ # Load quantized LoRA weights
+ if args.use_quantized and quantized_lora_state is not None:
+ if local_rank == 0:
+ print(f"Loading QUANTIZED LoRA weights from CPU")
+ peft.set_peft_model_state_dict(pipeline.generator.model, quantized_lora_state)
+ if local_rank == 0:
+ print("ā
Quantized LoRA weights loaded")
+
+ # Clear LoRA checkpoint
+ del lora_checkpoint
+ del quantized_lora_state
+ gc.collect()
+ elif not args.use_quantized:
+ # Original LoRA loading for non-quantized
+ lora_ckpt_path = getattr(config, "lora_ckpt", None)
+ if lora_ckpt_path:
+ if local_rank == 0:
+ print(f"Loading LoRA checkpoint from {lora_ckpt_path}")
+ lora_checkpoint = torch.load(lora_ckpt_path, map_location="cpu")
+ if isinstance(lora_checkpoint, dict) and "generator_lora" in lora_checkpoint:
+ peft.set_peft_model_state_dict(pipeline.generator.model, lora_checkpoint["generator_lora"])
+ else:
+ peft.set_peft_model_state_dict(pipeline.generator.model, lora_checkpoint)
+ if local_rank == 0:
+ print("LoRA weights loaded")
+
+ pipeline.is_lora_enabled = True
+
+# ======================== MODEL PARALLEL GPU SETUP ========================
+print("\nš Setting up model parallelism (Text ā GPU0, Gen+VAE ā GPU2)")
+
+gpu_text = torch.device("cuda:0")
+gpu_main = torch.device("cuda:2")
+
+torch.cuda.empty_cache()
+
+# 1ļøā£ Move TEXT ENCODER to GPU 0 (prompt encoding only)
+print("š§ Moving text encoder to GPU 0...")
+pipeline.text_encoder = pipeline.text_encoder.to(gpu_text)
+print(f"Text encoder device: {next(pipeline.text_encoder.parameters()).device}")
+
+torch.cuda.empty_cache()
+
+# 2ļøā£ Move GENERATOR to GPU 2
+print("š¬ Moving generator to GPU 2 in bfloat16...")
+pipeline.generator = pipeline.generator.to(gpu_main, dtype=torch.bfloat16)
+print(f"Generator device: {next(pipeline.generator.parameters()).device}")
+
+torch.cuda.empty_cache()
+
+# 3ļøā£ Move VAE to GPU 2
+print("š¼ļø Moving VAE to GPU 2 in bfloat16...")
+pipeline.vae = pipeline.vae.to(gpu_main, dtype=torch.bfloat16)
+print(f"VAE device: {next(pipeline.vae.parameters()).device}")
+
+torch.cuda.empty_cache()
+
+print("\nā
Model parallel setup complete")
+print(f" Text Encoder ā {next(pipeline.text_encoder.parameters()).device}")
+print(f" Generator ā {next(pipeline.generator.parameters()).device}")
+print(f" VAE ā {next(pipeline.vae.parameters()).device}")
+
+# Set main device for noise generation & diffusion
+device = gpu_main
+# ======================== END GPU SETUP ========================
+
+# ======================== OPTIONAL: USE SPECIFIC GPU FOR INFERENCE ========================
+# Uncomment and modify if you want to use a specific GPU (e.g., GPU 2)
+# inference_device = torch.device("cuda:2")
+# print(f"\nš Moving models to {inference_device} for inference...")
+# pipeline.generator.to(inference_device)
+# pipeline.vae.to(inference_device)
+# torch.cuda.empty_cache()
+# device = inference_device
+# ======================== END GPU SETUP ========================
+
+# ----------------------------- Build dataset -----------------------------
+dataset = TextDataset(prompt_path=config.data_path, extended_prompt_path=config.data_path)
+
+num_prompts_total = len(dataset)
+print(f"Number of prompt lines: {num_prompts_total}")
+
+if dist.is_initialized():
+ sampler = DistributedSampler(dataset, shuffle=False, drop_last=True)
+else:
+ sampler = SequentialSampler(dataset)
+
+dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False)
+
+if local_rank == 0:
+ os.makedirs(config.output_folder, exist_ok=True)
+
+if dist.is_initialized():
+ dist.barrier()
+
+# ----------------------------- Inference loop -----------------------------
+print("\n" + "="*70)
+print("š Starting video generation...")
+print("="*70)
+
+for i, batch_data in tqdm(enumerate(dataloader), disable=(local_rank != 0)):
+ idx = batch_data["idx"].item()
+
+ # Get the prompt from the batch
+ prompt = batch_data['prompts'][0]
+ extended_prompt = batch_data.get('extended_prompts', [prompt])[0]
+
+ # Use extended prompt if available, otherwise use regular prompt
+ text_prompt = extended_prompt if extended_prompt else prompt
+ prompts = [text_prompt] * config.num_samples
+
+ # Get the actual device where generator is located
+ gen_device = next(pipeline.generator.parameters()).device
+
+ sampled_noise = torch.randn(
+ [
+ config.num_samples,
+ config.num_output_frames,
+ 16,
+ 60,
+ 104,
+ ],
+ device=gen_device, # Use generator's device
+ dtype=torch.bfloat16,
+ )
+
+ with torch.autocast("cuda", dtype=torch.bfloat16):
+ video = pipeline.inference(
+ noise=sampled_noise,
+ text_prompts=prompts,
+ return_latents=False,
+ )
+
+ current_video = rearrange(video, "b t c h w -> b t h w c").cpu() * 255.0
+
+ if dist.is_initialized():
+ rank = dist.get_rank()
+ else:
+ rank = 0
+
+ model_type = "quantized" if args.use_quantized else "regular"
+
+ for seed_idx in range(config.num_samples):
+ if config.save_with_index:
+ output_path = os.path.join(config.output_folder, f"rank{rank}-{idx}-{seed_idx}_{model_type}.mp4")
+ else:
+ short_name = text_prompt[:100].replace("/", "_")
+ output_path = os.path.join(config.output_folder, f"rank{rank}-{short_name}-{seed_idx}_{model_type}.mp4")
+
+ write_video(output_path, current_video[seed_idx].to(torch.uint8), fps=16)
+
+ if local_rank == 0:
+ print(f"ā
Saved: {output_path}")
+
+ if config.inference_iter != -1 and i >= config.inference_iter:
+ break
+
+print("\n" + "="*70)
+print("š Video generation complete!")
+print("="*70)
+
+if dist.is_initialized():
+ dist.destroy_process_group()
diff --git a/single_quantized.sh b/single_quantized.sh
new file mode 100755
index 0000000..186caf0
--- /dev/null
+++ b/single_quantized.sh
@@ -0,0 +1 @@
+python single_prompt_inference.py --config_path configs/longlive_inference.yaml --use_quantized
\ No newline at end of file
diff --git a/test_prompts.txt b/test_prompts.txt
new file mode 100644
index 0000000..0d75b9f
--- /dev/null
+++ b/test_prompts.txt
@@ -0,0 +1,3 @@
+A serene garden with colorful butterflies, sunny day, photorealistic
+The butterflies begin to glow with a magical light, gathering together
+The garden transforms into an enchanted mystical forest at twilight