
Wan vae trt #17

Open · wants to merge 6 commits into base: main
3 changes: 3 additions & 0 deletions .gitignore
@@ -25,3 +25,6 @@

# just4dev
devscripts/
out*
mycode
work_dir
51 changes: 51 additions & 0 deletions examples/text_encoder_trt/convert_CLIP_L_to_trt_engine.py
@@ -0,0 +1,51 @@
import os
import argparse

import torch
from loguru import logger

from lightx2v.text2v.models.text_encoders.hf.clip.model import TextEncoderHFClipModel
from lightx2v.text2v.models.text_encoders.trt.clip.trt_clip_infer import CLIPTrtModelInfer


def parse_args():
args = argparse.ArgumentParser()
args.add_argument("--model_path", help="", type=str, default="/mtc/yongyang/models/x2v_models/hunyuan/lightx2v_format/t2v")
args.add_argument("--dtype", default=torch.float32)
args.add_argument("--device", default="cuda", type=str)
return args.parse_args()


def convert_trt_engine(args):
init_device = torch.device(args.device)
text_encoder_2 = TextEncoderHFClipModel(os.path.join(args.model_path, "text_encoder_2"), init_device)
texts = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
tokens = text_encoder_2.tokenizer(
texts,
return_length=False,
return_overflowing_tokens=False,
return_attention_mask=True,
truncation=True,
max_length=text_encoder_2.max_length,
padding="max_length",
return_tensors="pt",
).to(init_device)
    input_ids = tokens["input_ids"].to(init_device)
    attention_mask = tokens["attention_mask"].to(init_device)
onnx_path = CLIPTrtModelInfer.export_to_onnx(text_encoder_2.model, model_dir=args.model_path, input_ids=input_ids, attention_mask=attention_mask)
del text_encoder_2
torch.cuda.empty_cache()
engine_path = onnx_path.replace(".onnx", ".engine")
CLIPTrtModelInfer.convert_to_trt_engine(onnx_path, engine_path)
logger.info(f"ONNX: {onnx_path}")
logger.info(f"TRT Engine: {engine_path}")
return


def main():
args = parse_args()
convert_trt_engine(args)


if __name__ == "__main__":
main()
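
For reference, a minimal sketch of how the engine produced by this script would be consumed at inference time, based on the TRT CLIP wrapper added later in this PR (it loads `onnx/clip_l/clip_l.engine` relative to the model directory). The model path and the lightweight args object below are hypothetical stand-ins:

import torch
from types import SimpleNamespace

from lightx2v.text2v.models.text_encoders.trt.clip.model import TextEncoderHFClipModel

# Hypothetical model directory; must contain onnx/clip_l/clip_l.engine.
model_path = "/path/to/t2v"
encoder = TextEncoderHFClipModel(model_path, torch.device("cuda"))
# infer() expects an args object carrying a cpu_offload flag.
embeds, mask = encoder.infer("A cat walks on the grass, realistic style.", SimpleNamespace(cpu_offload=False))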
45 changes: 45 additions & 0 deletions examples/text_encoder_trt/convert_t5xxl_to_trt_engine.py
@@ -0,0 +1,45 @@
from pathlib import Path
import os
import argparse

import torch
from loguru import logger

from lightx2v.text2v.models.text_encoders.hf.t5.model import T5EncoderModel
from lightx2v.text2v.models.text_encoders.trt.t5.trt_t5_infer import T5TrtModelInfer


def parse_args():
args = argparse.ArgumentParser()
args.add_argument("--model_path", help="", type=str, default="models/Wan2.1-T2V-1.3B")
args.add_argument("--dtype", default=torch.float16)
args.add_argument("--device", default="cuda", type=str)
return args.parse_args()


def convert_trt_engine(args):
t5_checkpoint_path = os.path.join(args.model_path, "models_t5_umt5-xxl-enc-bf16.pth")
t5_tokenizer_path = os.path.join(args.model_path, "google/umt5-xxl")
    assert Path(t5_checkpoint_path).exists(), f"{t5_checkpoint_path} does not exist."
model = T5EncoderModel(text_len=512, dtype=args.dtype, device=args.device, checkpoint_path=t5_checkpoint_path, tokenizer_path=t5_tokenizer_path, shard_fn=None)
texts = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
ids, mask = model.tokenizer(texts, return_mask=True, add_special_tokens=True)
ids = ids.to(args.device)
mask = mask.to(args.device)
onnx_path = T5TrtModelInfer.export_to_onnx(model.model, model_dir=args.model_path, ids=ids, mask=mask)
del model
torch.cuda.empty_cache()
engine_path = onnx_path.replace(".onnx", ".engine")
T5TrtModelInfer.convert_to_trt_engine(onnx_path, engine_path)
logger.info(f"ONNX: {onnx_path}")
logger.info(f"TRT Engine: {engine_path}")
return


def main():
args = parse_args()
convert_trt_engine(args)


if __name__ == "__main__":
main()
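
Once the engine exists, its path can be threaded into the runtime the way this PR does in `lightx2v/__main__.py`, where `T5EncoderModel` now accepts extra keyword arguments. A sketch, assuming the engine is stored under the model directory (the `onnx/t5/t5.engine` location is hypothetical):

import os
import torch

from lightx2v.text2v.models.text_encoders.hf.t5.model import T5EncoderModel

model_path = "models/Wan2.1-T2V-1.3B"  # same default as the converter above
model = T5EncoderModel(
    text_len=512,
    dtype=torch.bfloat16,
    device=torch.device("cuda"),
    engine_path=os.path.join(model_path, "onnx/t5/t5.engine"),  # hypothetical location
    checkpoint_path=os.path.join(model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
    tokenizer_path=os.path.join(model_path, "google/umt5-xxl"),
    shard_fn=None,
)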
2 changes: 1 addition & 1 deletion examples/vae_trt/convert_trt.sh
@@ -12,5 +12,5 @@ export PYTHONPATH="./":$PYTHONPATH
# --optShapes=inp:1x16x17x32x16 \
# --maxShapes=inp:1x16x17x32x32

model_path=""
model_path="/mtc/yongyang/models/x2v_models/hunyuan/lightx2v_format/t2v/"
python examples/vae_trt/convert_vae_trt_engine.py --model_path ${model_path}
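
The commented trtexec flags above describe a dynamic-shape profile for the VAE input tensor `inp`. For readers who prefer the TensorRT Python API, a rough equivalent is sketched below; this is not the repo's converter, the ONNX and engine paths are placeholders, and because the min shape is cut off in this diff it is set equal to the opt shape:

import tensorrt as trt

trt_logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(trt_logger)
# Explicit-batch network, as required by the ONNX parser (TRT 8.x-style flag).
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, trt_logger)
with open("vae.onnx", "rb") as f:  # placeholder ONNX path
    assert parser.parse(f.read()), parser.get_error(0)

config = builder.create_builder_config()
profile = builder.create_optimization_profile()
# opt/max mirror the commented --optShapes/--maxShapes flags; min is a placeholder.
profile.set_shape("inp", (1, 16, 17, 32, 16), (1, 16, 17, 32, 16), (1, 16, 17, 32, 32))
config.add_optimization_profile(profile)

engine_bytes = builder.build_serialized_network(network, config)
with open("vae.engine", "wb") as f:  # placeholder engine path
    f.write(engine_bytes)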
4 changes: 2 additions & 2 deletions examples/vae_trt/convert_vae_trt_engine.py
@@ -17,7 +17,7 @@ def parse_args():
return args.parse_args()


-def convert_vae_trt_engine(args):
+def convert_trt_engine(args):
vae_path = os.path.join(args.model_path, "hunyuan-video-t2v-720p/vae")
    assert Path(vae_path).exists(), f"{vae_path} does not exist."
config = AutoencoderKLCausal3D.load_config(vae_path)
@@ -38,7 +38,7 @@ def convert_vae_trt_engine(args):

def main():
args = parse_args()
-    convert_vae_trt_engine(args)
+    convert_trt_engine(args)


if __name__ == "__main__":
5 changes: 3 additions & 2 deletions lightx2v/__main__.py
@@ -10,7 +10,7 @@
import numpy as np
from PIL import Image
from lightx2v.text2v.models.text_encoders.hf.llama.model import TextEncoderHFLlamaModel
-from lightx2v.text2v.models.text_encoders.hf.clip.model import TextEncoderHFClipModel
+from lightx2v.text2v.models.text_encoders.trt.clip.model import TextEncoderHFClipModel
from lightx2v.text2v.models.text_encoders.hf.t5.model import T5EncoderModel
from lightx2v.text2v.models.text_encoders.hf.llava.model import TextEncoderHFLlavaModel

Expand All @@ -22,7 +22,7 @@
from lightx2v.text2v.models.networks.hunyuan.model import HunyuanModel
from lightx2v.text2v.models.networks.wan.model import WanModel

-from lightx2v.text2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
+from lightx2v.text2v.models.video_encoders.trt.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
from lightx2v.text2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.utils.utils import save_videos_grid, seed_all, cache_video
from lightx2v.common.ops import *
@@ -54,6 +54,7 @@ def load_models(args, model_config):
text_len=model_config["text_len"],
dtype=torch.bfloat16,
device=init_device,
engine_path="/mtc/wq/project/sd/code/lightx2v/mycode/onnx/t5_fp32/t5_bf16.engine",
checkpoint_path=os.path.join(args.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
tokenizer_path=os.path.join(args.model_path, "google/umt5-xxl"),
shard_fn=None,
102 changes: 102 additions & 0 deletions lightx2v/common/backend_infer/trt/trt_infer_base.py
@@ -0,0 +1,102 @@
from pathlib import Path

import numpy as np
import torch
import tensorrt as trt
from cuda import cudart
import torch.nn as nn
from loguru import logger

from lightx2v.common.backend_infer.trt import common

TRT_LOGGER = trt.Logger(trt.Logger.INFO)


np_torch_dtype_map = {"float16": torch.float16, "float32": torch.float32}


class TrtModelInferBase(nn.Module):
"""
Implements inference for the TensorRT engine.
"""

def __init__(self, engine_path, **kwargs):
"""
:param engine_path: The path to the serialized engine to load from disk.
"""
        # Load the serialized TRT engine from disk
        super().__init__()  # nn.Module base init; needed so .to() etc. work on subclasses
        if not Path(engine_path).exists():
            raise FileNotFoundError(f"TensorRT engine `{engine_path}` does not exist.")
self.logger = trt.Logger(trt.Logger.ERROR)
with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime:
assert runtime
self.engine = runtime.deserialize_cuda_engine(f.read())
assert self.engine
self.context = self.engine.create_execution_context()
assert self.context
logger.info(f"Loaded tensorrt engine from `{engine_path}`")
self.inp_list = []
self.out_list = []
self.get_io_properties()

def alloc(self, shape_dict):
"""
Setup I/O bindings
"""
self.inputs = []
self.outputs = []
self.allocations = []
for i in range(self.engine.num_io_tensors):
name = self.engine.get_tensor_name(i)
is_input = False
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
is_input = True
dtype = self.engine.get_tensor_dtype(name)
shape = shape_dict[name]
if is_input:
self.context.set_input_shape(name, shape)
self.batch_size = shape[0]
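            # numpy has no bfloat16 dtype, so BF16 buffers are sized as HALF below;
            # both types are 2 bytes per element, so the byte count is identical.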
if dtype == trt.DataType.BF16:
dtype = trt.DataType.HALF
size = np.dtype(trt.nptype(dtype)).itemsize
for s in shape:
size *= s
allocation = common.cuda_call(cudart.cudaMalloc(size))
binding = {
"index": i,
"name": name,
"dtype": np.dtype(trt.nptype(dtype)),
"shape": list(shape),
"allocation": allocation,
}
self.allocations.append(allocation)
if is_input:
self.inputs.append(binding)
else:
self.outputs.append(binding)

assert self.batch_size > 0
assert len(self.inputs) > 0
assert len(self.outputs) > 0
assert len(self.allocations) > 0

def get_io_properties(self):
for bind in self.engine:
mode = self.engine.get_tensor_mode(bind)
dtype = trt.nptype(self.engine.get_tensor_dtype(bind))
if mode.name == "INPUT":
self.inp_list.append({"name": bind, "shape": self.engine.get_tensor_shape(bind), "dtype": dtype})
else:
self.out_list.append({"name": bind, "shape": self.engine.get_tensor_shape(bind), "dtype": dtype})
return

def __call__(self, batch, *args, **kwargs):
pass

@staticmethod
def export_to_onnx(model: torch.nn.Module, model_dir):
pass

@staticmethod
def convert_to_trt_engine(onnx_path, engine_path):
pass
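
The concrete infer classes in this PR (`CLIPTrtModelInfer`, `T5TrtModelInfer`) fill in `__call__` and the two conversion hooks. As a minimal sketch of what a `__call__` can look like on top of `alloc`, assuming the bundled `common` module mirrors NVIDIA's TensorRT sample helpers (`memcpy_host_to_device`, `memcpy_device_to_host`) and that output shapes are static:

import numpy as np

from lightx2v.common.backend_infer.trt import common
from lightx2v.common.backend_infer.trt.trt_infer_base import TrtModelInferBase


class ToyTrtInfer(TrtModelInferBase):
    def __call__(self, feed_dict, *args, **kwargs):
        # feed_dict maps input names to numpy arrays; bind buffers for their shapes.
        shape_dict = {name: arr.shape for name, arr in feed_dict.items()}
        for out in self.out_list:
            shape_dict[out["name"]] = out["shape"]  # assumes static output shapes
        self.alloc(shape_dict)

        # Host -> device copies, synchronous execution, device -> host copies.
        for inp in self.inputs:
            arr = np.ascontiguousarray(feed_dict[inp["name"]], dtype=inp["dtype"])
            common.memcpy_host_to_device(inp["allocation"], arr)
        self.context.execute_v2(self.allocations)

        results = {}
        for out in self.outputs:
            host = np.empty(out["shape"], out["dtype"])
            common.memcpy_device_to_host(host, out["allocation"])
            results[out["name"]] = host
        return results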
30 changes: 14 additions & 16 deletions lightx2v/text2v/models/text_encoders/hf/t5/model.py
@@ -7,7 +7,8 @@
import torch.nn as nn
import torch.nn.functional as F

-from .tokenizer import HuggingfaceTokenizer
+from lightx2v.text2v.models.text_encoders.hf.t5.tokenizer import HuggingfaceTokenizer
+# from .tokenizer import HuggingfaceTokenizer

__all__ = [
"T5Model",
@@ -459,15 +460,7 @@ def umt5_xxl(**kwargs):


class T5EncoderModel:
-    def __init__(
-        self,
-        text_len,
-        dtype=torch.bfloat16,
-        device=torch.cuda.current_device(),
-        checkpoint_path=None,
-        tokenizer_path=None,
-        shard_fn=None,
-    ):
+    def __init__(self, text_len, dtype=torch.bfloat16, device=None, checkpoint_path=None, tokenizer_path=None, shard_fn=None, **kwargs):
self.text_len = text_len
self.dtype = dtype
self.device = device
@@ -497,8 +490,6 @@ def infer(self, texts, args):
self.to_cuda()

ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
-        ids = ids.cuda()
-        mask = mask.cuda()
seq_lens = mask.gt(0).sum(dim=1).long()
context = self.model(ids, mask)

@@ -509,17 +500,24 @@


if __name__ == "__main__":
checkpoint_dir = ""
checkpoint_dir = "/mtc/wq/project/sd/models/Wan2.1-T2V-1.3B"
t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth"
t5_tokenizer = "google/umt5-xxl"
model = T5EncoderModel(
text_len=512,
-        dtype=torch.bfloat16,
-        device=torch.device("cuda"),
+        dtype=torch.float16,
+        device=torch.device("cpu"),
checkpoint_path=os.path.join(checkpoint_dir, t5_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, t5_tokenizer),
shard_fn=None,
)
text = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
-    outputs = model.infer(text)
+    from dataclasses import dataclass
+
+    @dataclass
+    class TempArgs:
+        cpu_offload: bool = False
+
+    args = TempArgs(cpu_offload=False)
+    outputs = model.infer(text, args)
print(outputs)
Empty file.
61 changes: 61 additions & 0 deletions lightx2v/text2v/models/text_encoders/trt/clip/model.py
@@ -0,0 +1,61 @@
import os

import torch
from transformers import AutoTokenizer

from .trt_clip_infer import CLIPTrtModelInfer


class TextEncoderHFClipModel:
def __init__(self, model_path, device, **kwargs):
self.device = device
self.model_path = model_path
self.engine_path = os.path.join(model_path, "onnx/clip_l/clip_l.engine")
self.init()
self.load()

def init(self):
self.max_length = 77

def load(self):
self.model = CLIPTrtModelInfer(engine_path=self.engine_path)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, padding_side="right")

def to_cpu(self):
self.model = self.model.to("cpu")

def to_cuda(self):
self.model = self.model.to("cuda")

@torch.no_grad()
def infer(self, text, args):
if args.cpu_offload:
self.to_cuda()
tokens = self.tokenizer(
text,
return_length=False,
return_overflowing_tokens=False,
return_attention_mask=True,
truncation=True,
max_length=self.max_length,
padding="max_length",
return_tensors="pt",
).to("cuda")

outputs = self.model(
ids=tokens["input_ids"],
mask=tokens["attention_mask"],
)

last_hidden_state = outputs["pooler_output"]
if args.cpu_offload:
self.to_cpu()
return last_hidden_state, tokens["attention_mask"]


if __name__ == "__main__":
model_path = ""
model = TextEncoderHFClipModel(model_path, torch.device("cuda"))
text = "A cat walks on the grass, realistic style."
outputs = model.infer(text)
print(outputs)