diff --git a/.gitignore b/.gitignore
index ad9e8e0..d5596fb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -25,3 +25,6 @@
 
 # just4dev
 devscripts/
+out*
+mycode
+work_dir
diff --git a/examples/text_encoder_trt/convert_CLIP_L_to_trt_engine.py b/examples/text_encoder_trt/convert_CLIP_L_to_trt_engine.py
new file mode 100644
index 0000000..0b3a662
--- /dev/null
+++ b/examples/text_encoder_trt/convert_CLIP_L_to_trt_engine.py
@@ -0,0 +1,51 @@
+import os
+import argparse
+
+import torch
+from loguru import logger
+
+from lightx2v.text2v.models.text_encoders.hf.clip.model import TextEncoderHFClipModel
+from lightx2v.text2v.models.text_encoders.trt.clip.trt_clip_infer import CLIPTrtModelInfer
+
+
+def parse_args():
+    args = argparse.ArgumentParser()
+    args.add_argument("--model_path", help="", type=str, default="/mtc/yongyang/models/x2v_models/hunyuan/lightx2v_format/t2v")
+    args.add_argument("--dtype", default=torch.float32)
+    args.add_argument("--device", default="cuda", type=str)
+    return args.parse_args()
+
+
+def convert_trt_engine(args):
+    init_device = torch.device(args.device)
+    text_encoder_2 = TextEncoderHFClipModel(os.path.join(args.model_path, "text_encoder_2"), init_device)
+    texts = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
+    tokens = text_encoder_2.tokenizer(
+        texts,
+        return_length=False,
+        return_overflowing_tokens=False,
+        return_attention_mask=True,
+        truncation=True,
+        max_length=text_encoder_2.max_length,
+        padding="max_length",
+        return_tensors="pt",
+    ).to(init_device)
+    input_ids = tokens["input_ids"].to(init_device)
+    attention_mask = tokens["attention_mask"].to(init_device)
+    onnx_path = CLIPTrtModelInfer.export_to_onnx(text_encoder_2.model, model_dir=args.model_path, input_ids=input_ids, attention_mask=attention_mask)
+    del text_encoder_2
+    torch.cuda.empty_cache()
+    engine_path = onnx_path.replace(".onnx", ".engine")
+    CLIPTrtModelInfer.convert_to_trt_engine(onnx_path, engine_path)
+    logger.info(f"ONNX: {onnx_path}")
+    logger.info(f"TRT Engine: {engine_path}")
+    return
+
+
+def main():
+    args = parse_args()
+    convert_trt_engine(args)
+
+
+if __name__ == "__main__":
+    main()
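The new script drives a two-step flow: export the Hugging Face CLIP-L text encoder to ONNX, then build a TensorRT engine from the ONNX file with trtexec. A minimal smoke test for the resulting engine might look like the sketch below; it reuses the HF wrapper's tokenizer and the engine layout produced by `CLIPTrtModelInfer.export_to_onnx`, but the model root path and the expected (1, 768) pooled-output shape are illustrative assumptions, not something the patch checks.

    import os
    import torch
    from lightx2v.text2v.models.text_encoders.hf.clip.model import TextEncoderHFClipModel
    from lightx2v.text2v.models.text_encoders.trt.clip.trt_clip_infer import CLIPTrtModelInfer

    model_path = "/path/to/hunyuan/t2v"  # placeholder model root
    hf_clip = TextEncoderHFClipModel(os.path.join(model_path, "text_encoder_2"), torch.device("cuda"))
    engine = CLIPTrtModelInfer(engine_path=os.path.join(model_path, "text_encoder_2/onnx/clip_l/clip_l.engine"))

    tokens = hf_clip.tokenizer(
        "A cat walks on the grass, realistic style.",
        truncation=True,
        max_length=hf_clip.max_length,
        padding="max_length",
        return_tensors="pt",
    ).to("cuda")
    out = engine(ids=tokens["input_ids"], mask=tokens["attention_mask"])
    print(out["pooler_output"].shape)  # expect (1, 768) for CLIP-L pooled embeddings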
diff --git a/examples/text_encoder_trt/convert_t5xxl_to_trt_engine.py b/examples/text_encoder_trt/convert_t5xxl_to_trt_engine.py
new file mode 100644
index 0000000..69e2d46
--- /dev/null
+++ b/examples/text_encoder_trt/convert_t5xxl_to_trt_engine.py
@@ -0,0 +1,45 @@
+from pathlib import Path
+import os
+import argparse
+
+import torch
+from loguru import logger
+
+from lightx2v.text2v.models.text_encoders.hf.t5.model import T5EncoderModel
+from lightx2v.text2v.models.text_encoders.trt.t5.trt_t5_infer import T5TrtModelInfer
+
+
+def parse_args():
+    args = argparse.ArgumentParser()
+    args.add_argument("--model_path", help="", type=str, default="models/Wan2.1-T2V-1.3B")
+    args.add_argument("--dtype", default=torch.float16)
+    args.add_argument("--device", default="cuda", type=str)
+    return args.parse_args()
+
+
+def convert_trt_engine(args):
+    t5_checkpoint_path = os.path.join(args.model_path, "models_t5_umt5-xxl-enc-bf16.pth")
+    t5_tokenizer_path = os.path.join(args.model_path, "google/umt5-xxl")
+    assert Path(t5_checkpoint_path).exists(), f"{t5_checkpoint_path} does not exist."
+    model = T5EncoderModel(text_len=512, dtype=args.dtype, device=args.device, checkpoint_path=t5_checkpoint_path, tokenizer_path=t5_tokenizer_path, shard_fn=None)
+    texts = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
+    ids, mask = model.tokenizer(texts, return_mask=True, add_special_tokens=True)
+    ids = ids.to(args.device)
+    mask = mask.to(args.device)
+    onnx_path = T5TrtModelInfer.export_to_onnx(model.model, model_dir=args.model_path, ids=ids, mask=mask)
+    del model
+    torch.cuda.empty_cache()
+    engine_path = onnx_path.replace(".onnx", ".engine")
+    T5TrtModelInfer.convert_to_trt_engine(onnx_path, engine_path)
+    logger.info(f"ONNX: {onnx_path}")
+    logger.info(f"TRT Engine: {engine_path}")
+    return
+
+
+def main():
+    args = parse_args()
+    convert_trt_engine(args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/vae_trt/convert_trt.sh b/examples/vae_trt/convert_trt.sh
index f0c5599..12f671c 100644
--- a/examples/vae_trt/convert_trt.sh
+++ b/examples/vae_trt/convert_trt.sh
@@ -12,5 +12,5 @@ export PYTHONPATH="./":$PYTHONPATH
 # --optShapes=inp:1x16x17x32x16 \
 # --maxShapes=inp:1x16x17x32x32
 
-model_path=""
+model_path="/mtc/yongyang/models/x2v_models/hunyuan/lightx2v_format/t2v/"
 python examples/vae_trt/convert_vae_trt_engine.py --model_path ${model_path}
diff --git a/examples/vae_trt/convert_vae_trt_engine.py b/examples/vae_trt/convert_vae_trt_engine.py
index 0882449..50ecf0f 100644
--- a/examples/vae_trt/convert_vae_trt_engine.py
+++ b/examples/vae_trt/convert_vae_trt_engine.py
@@ -17,7 +17,7 @@ def parse_args():
     return args.parse_args()
 
 
-def convert_vae_trt_engine(args):
+def convert_trt_engine(args):
     vae_path = os.path.join(args.model_path, "hunyuan-video-t2v-720p/vae")
     assert Path(vae_path).exists(), f"{vae_path} not exists."
     config = AutoencoderKLCausal3D.load_config(vae_path)
@@ -38,7 +38,7 @@ def convert_vae_trt_engine(args):
 
 def main():
     args = parse_args()
-    convert_vae_trt_engine(args)
+    convert_trt_engine(args)
 
 
 if __name__ == "__main__":
diff --git a/lightx2v/__main__.py b/lightx2v/__main__.py
index 9b14168..f39d291 100755
--- a/lightx2v/__main__.py
+++ b/lightx2v/__main__.py
@@ -10,7 +10,7 @@ import numpy as np
 from PIL import Image
 
 from lightx2v.text2v.models.text_encoders.hf.llama.model import TextEncoderHFLlamaModel
-from lightx2v.text2v.models.text_encoders.hf.clip.model import TextEncoderHFClipModel
+from lightx2v.text2v.models.text_encoders.trt.clip.model import TextEncoderHFClipModel
 from lightx2v.text2v.models.text_encoders.hf.t5.model import T5EncoderModel
 from lightx2v.text2v.models.text_encoders.hf.llava.model import TextEncoderHFLlavaModel
 
@@ -22,7 +22,7 @@ from lightx2v.text2v.models.networks.hunyuan.model import HunyuanModel
 
 from lightx2v.text2v.models.networks.wan.model import WanModel
 
-from lightx2v.text2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
+from lightx2v.text2v.models.video_encoders.trt.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
 from lightx2v.text2v.models.video_encoders.hf.wan.vae import WanVAE
 from lightx2v.utils.utils import save_videos_grid, seed_all, cache_video
 from lightx2v.common.ops import *
@@ -54,6 +54,7 @@ def load_models(args, model_config):
             text_len=model_config["text_len"],
             dtype=torch.bfloat16,
             device=init_device,
+            engine_path="/mtc/wq/project/sd/code/lightx2v/mycode/onnx/t5_fp32/t5_bf16.engine",
             checkpoint_path=os.path.join(args.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
             tokenizer_path=os.path.join(args.model_path, "google/umt5-xxl"),
             shard_fn=None,
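The `__main__.py` hunks switch the CLIP text encoder and the causal-3D VAE to the TensorRT-backed classes of the same name, and pass a hard-coded, developer-local `engine_path` into `T5EncoderModel`. That one-line addition works because both T5 wrappers now tolerate extra keyword arguments, so the HF and TRT variants can be built from the same keyword set. The sketch below illustrates that point with placeholder paths; it is not code from the patch.

    import torch
    from lightx2v.text2v.models.text_encoders.hf.t5 import model as hf_t5
    from lightx2v.text2v.models.text_encoders.trt.t5 import model as trt_t5

    common_kwargs = dict(
        text_len=512,
        dtype=torch.bfloat16,
        device="cuda",
        checkpoint_path="/path/to/models_t5_umt5-xxl-enc-bf16.pth",  # placeholder
        tokenizer_path="/path/to/google/umt5-xxl",                   # placeholder
        engine_path="/path/to/t5.engine",                            # placeholder; ignored by the HF class via **kwargs
        shard_fn=None,
    )
    hf_encoder = hf_t5.T5EncoderModel(**common_kwargs)    # loads the PyTorch checkpoint
    trt_encoder = trt_t5.T5EncoderModel(**common_kwargs)  # loads the serialized TensorRT engine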
diff --git a/lightx2v/common/backend_infer/trt/trt_infer_base.py b/lightx2v/common/backend_infer/trt/trt_infer_base.py
new file mode 100644
index 0000000..f1b6e44
--- /dev/null
+++ b/lightx2v/common/backend_infer/trt/trt_infer_base.py
@@ -0,0 +1,102 @@
+from pathlib import Path
+
+import numpy as np
+import torch
+import tensorrt as trt
+from cuda import cudart
+import torch.nn as nn
+from loguru import logger
+
+from lightx2v.common.backend_infer.trt import common
+
+TRT_LOGGER = trt.Logger(trt.Logger.INFO)
+
+
+np_torch_dtype_map = {"float16": torch.float16, "float32": torch.float32}
+
+
+class TrtModelInferBase(nn.Module):
+    """
+    Implements inference for the TensorRT engine.
+    """
+
+    def __init__(self, engine_path, **kwargs):
+        """
+        :param engine_path: The path to the serialized engine to load from disk.
+        """
+        super().__init__()
+        # Load TRT engine
+        if not Path(engine_path).exists():
+            raise FileNotFoundError(f"Tensorrt engine `{str(engine_path)}` does not exist.")
+        self.logger = trt.Logger(trt.Logger.ERROR)
+        with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime:
+            assert runtime
+            self.engine = runtime.deserialize_cuda_engine(f.read())
+        assert self.engine
+        self.context = self.engine.create_execution_context()
+        assert self.context
+        logger.info(f"Loaded tensorrt engine from `{engine_path}`")
+        self.inp_list = []
+        self.out_list = []
+        self.get_io_properties()
+
+    def alloc(self, shape_dict):
+        """
+        Setup I/O bindings
+        """
+        self.inputs = []
+        self.outputs = []
+        self.allocations = []
+        for i in range(self.engine.num_io_tensors):
+            name = self.engine.get_tensor_name(i)
+            is_input = False
+            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
+                is_input = True
+            dtype = self.engine.get_tensor_dtype(name)
+            shape = shape_dict[name]
+            if is_input:
+                self.context.set_input_shape(name, shape)
+                self.batch_size = shape[0]
+            if dtype == trt.DataType.BF16:
+                # numpy has no bfloat16, so size the buffer as fp16 (same 2-byte element width).
+                dtype = trt.DataType.HALF
+            size = np.dtype(trt.nptype(dtype)).itemsize
+            for s in shape:
+                size *= s
+            allocation = common.cuda_call(cudart.cudaMalloc(size))
+            binding = {
+                "index": i,
+                "name": name,
+                "dtype": np.dtype(trt.nptype(dtype)),
+                "shape": list(shape),
+                "allocation": allocation,
+            }
+            self.allocations.append(allocation)
+            if is_input:
+                self.inputs.append(binding)
+            else:
+                self.outputs.append(binding)
+
+        assert self.batch_size > 0
+        assert len(self.inputs) > 0
+        assert len(self.outputs) > 0
+        assert len(self.allocations) > 0
+
+    def get_io_properties(self):
+        for bind in self.engine:
+            mode = self.engine.get_tensor_mode(bind)
+            dtype = trt.nptype(self.engine.get_tensor_dtype(bind))
+            if mode.name == "INPUT":
+                self.inp_list.append({"name": bind, "shape": self.engine.get_tensor_shape(bind), "dtype": dtype})
+            else:
+                self.out_list.append({"name": bind, "shape": self.engine.get_tensor_shape(bind), "dtype": dtype})
+        return
+
+    def __call__(self, batch, *args, **kwargs):
+        pass
+
+    @staticmethod
+    def export_to_onnx(model: torch.nn.Module, model_dir):
+        pass
+
+    @staticmethod
+    def convert_to_trt_engine(onnx_path, engine_path):
+        pass
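`TrtModelInferBase` centralizes engine deserialization, I/O introspection (`get_io_properties`) and buffer allocation (`alloc`); subclasses are left to implement `__call__` plus the export/convert helpers. A condensed sketch of the call pattern the CLIP and T5 subclasses follow is shown below. It assumes a single-input, single-output engine built with static shapes and is illustrative only.

    import numpy as np
    import torch

    from lightx2v.common.backend_infer.trt import common
    from lightx2v.common.backend_infer.trt.trt_infer_base import TrtModelInferBase


    class SingleIOTrtInfer(TrtModelInferBase):
        """Illustrative subclass: one input tensor, one output tensor, static shapes."""

        def __call__(self, inp: torch.Tensor):
            # Shapes come from the engine metadata collected by get_io_properties().
            shapes = {d["name"]: d["shape"] for d in self.inp_list + self.out_list}
            self.alloc(shapes)  # builds self.inputs / self.outputs / self.allocations
            out = np.zeros(self.outputs[0]["shape"], self.outputs[0]["dtype"])
            common.memcpy_host_to_device(self.inputs[0]["allocation"], np.ascontiguousarray(inp.cpu().numpy()))
            self.context.execute_v2(self.allocations)
            common.memcpy_device_to_host(out, self.outputs[0]["allocation"])
            return torch.from_numpy(out).to(inp.device)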
diff --git a/lightx2v/text2v/models/text_encoders/hf/t5/model.py b/lightx2v/text2v/models/text_encoders/hf/t5/model.py
index 5b7fbee..1574016 100755
--- a/lightx2v/text2v/models/text_encoders/hf/t5/model.py
+++ b/lightx2v/text2v/models/text_encoders/hf/t5/model.py
@@ -7,7 +7,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .tokenizer import HuggingfaceTokenizer
+from lightx2v.text2v.models.text_encoders.hf.t5.tokenizer import HuggingfaceTokenizer
+# from .tokenizer import HuggingfaceTokenizer
 
 __all__ = [
     "T5Model",
@@ -459,15 +460,7 @@ def umt5_xxl(**kwargs):
 
 
 class T5EncoderModel:
-    def __init__(
-        self,
-        text_len,
-        dtype=torch.bfloat16,
-        device=torch.cuda.current_device(),
-        checkpoint_path=None,
-        tokenizer_path=None,
-        shard_fn=None,
-    ):
+    def __init__(self, text_len, dtype=torch.bfloat16, device=None, checkpoint_path=None, tokenizer_path=None, shard_fn=None, **kwargs):
         self.text_len = text_len
         self.dtype = dtype
         self.device = device
@@ -497,8 +490,6 @@ def infer(self, texts, args):
             self.to_cuda()
 
         ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
-        ids = ids.cuda()
-        mask = mask.cuda()
         seq_lens = mask.gt(0).sum(dim=1).long()
         context = self.model(ids, mask)
 
@@ -509,17 +500,24 @@ def infer(self, texts, args):
 if __name__ == "__main__":
-    checkpoint_dir = ""
+    checkpoint_dir = "/mtc/wq/project/sd/models/Wan2.1-T2V-1.3B"
     t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth"
     t5_tokenizer = "google/umt5-xxl"
 
     model = T5EncoderModel(
         text_len=512,
-        dtype=torch.bfloat16,
-        device=torch.device("cuda"),
+        dtype=torch.float16,
+        device=torch.device("cpu"),
         checkpoint_path=os.path.join(checkpoint_dir, t5_checkpoint),
         tokenizer_path=os.path.join(checkpoint_dir, t5_tokenizer),
         shard_fn=None,
     )
 
     text = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
-    outputs = model.infer(text)
+    from dataclasses import dataclass
+
+    @dataclass
+    class TempArgs:
+        cpu_offload: bool = False
+
+    args = TempArgs(cpu_offload=False)
+    outputs = model.infer(text, args)
     print(outputs)
diff --git a/lightx2v/text2v/models/text_encoders/trt/clip/__init__.py b/lightx2v/text2v/models/text_encoders/trt/clip/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/lightx2v/text2v/models/text_encoders/trt/clip/model.py b/lightx2v/text2v/models/text_encoders/trt/clip/model.py
new file mode 100644
index 0000000..5f4856f
--- /dev/null
+++ b/lightx2v/text2v/models/text_encoders/trt/clip/model.py
@@ -0,0 +1,61 @@
+import os
+
+import torch
+from transformers import AutoTokenizer
+
+from .trt_clip_infer import CLIPTrtModelInfer
+
+
+class TextEncoderHFClipModel:
+    def __init__(self, model_path, device, **kwargs):
+        self.device = device
+        self.model_path = model_path
+        self.engine_path = os.path.join(model_path, "onnx/clip_l/clip_l.engine")
+        self.init()
+        self.load()
+
+    def init(self):
+        self.max_length = 77
+
+    def load(self):
+        self.model = CLIPTrtModelInfer(engine_path=self.engine_path)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, padding_side="right")
+
+    def to_cpu(self):
+        self.model = self.model.to("cpu")
+
+    def to_cuda(self):
+        self.model = self.model.to("cuda")
+
+    @torch.no_grad()
+    def infer(self, text, args):
+        if args.cpu_offload:
+            self.to_cuda()
+        tokens = self.tokenizer(
+            text,
+            return_length=False,
+            return_overflowing_tokens=False,
+            return_attention_mask=True,
+            truncation=True,
+            max_length=self.max_length,
+            padding="max_length",
+            return_tensors="pt",
+        ).to("cuda")
+
+        outputs = self.model(
+            ids=tokens["input_ids"],
+            mask=tokens["attention_mask"],
+        )
+
+        last_hidden_state = outputs["pooler_output"]
+        if args.cpu_offload:
+            self.to_cpu()
+        return last_hidden_state, tokens["attention_mask"]
+
+
+if __name__ == "__main__":
+    from types import SimpleNamespace
+
+    model_path = ""
+    model = TextEncoderHFClipModel(model_path, torch.device("cuda"))
+    text = "A cat walks on the grass, realistic style."
+    args = SimpleNamespace(cpu_offload=False)
+    outputs = model.infer(text, args)
+    print(outputs)
diff --git a/lightx2v/text2v/models/text_encoders/trt/clip/trt_clip_infer.py b/lightx2v/text2v/models/text_encoders/trt/clip/trt_clip_infer.py
new file mode 100644
index 0000000..6c48306
--- /dev/null
+++ b/lightx2v/text2v/models/text_encoders/trt/clip/trt_clip_infer.py
@@ -0,0 +1,73 @@
+from pathlib import Path
+from subprocess import Popen
+
+import torch
+import tensorrt as trt
+from loguru import logger
+import numpy as np
+from torch.nn.modules import Module
+
+from lightx2v.common.backend_infer.trt import common
+from lightx2v.common.backend_infer.trt.trt_infer_base import TrtModelInferBase, np_torch_dtype_map
+
+TRT_LOGGER = trt.Logger(trt.Logger.INFO)
+
+
+class CLIPTrtModelInfer(TrtModelInferBase):
+    def __init__(self, engine_path, **kwargs):
+        super().__init__(engine_path, **kwargs)
+
+    def __call__(self, ids, mask, *args, **kwargs):
+        device = ids.device
+        ids = ids.cpu().numpy()
+        mask = mask.cpu().numpy()
+        shp_dict = {i["name"]: i["shape"] for i in self.inp_list}
+        shp_dict.update({i["name"]: i["shape"] for i in self.out_list})
+        self.alloc(shp_dict)
+
+        out_list = []
+        for o in self.outputs:
+            out_list.append(np.zeros(o["shape"], o["dtype"]))
+        for inp, data in zip(self.inputs, [ids, mask]):
+            common.memcpy_host_to_device(inp["allocation"], np.ascontiguousarray(data))
+        self.context.execute_v2(self.allocations)
+        outs = []
+        for i, out in enumerate(out_list):
+            common.memcpy_device_to_host(out, self.outputs[i]["allocation"])
+            out = torch.from_numpy(out).to(device)
+            out = out.type(torch.bfloat16)
+            outs.append(out)
+        return {"pooler_output": outs[1]}
+
+    @staticmethod
+    def export_to_onnx(model: Module, model_dir, *args, **kwargs):
+        ids = kwargs.get("input_ids")
+        mask = kwargs.get("attention_mask")
+        onnx_dir = Path(model_dir) / "text_encoder_2/onnx/clip_l"
+        onnx_dir.mkdir(parents=True, exist_ok=True)
+        onnx_path = str(onnx_dir / "clip_l.onnx")
+
+        class ClipWrapper(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, input_ids, attention_mask, return_dict=False, output_hidden_states=False):
+                out = self.model(input_ids, attention_mask, return_dict=return_dict, output_hidden_states=output_hidden_states)
+                return out
+
+        model_wrapped = ClipWrapper()
+        model_wrapped.model = model
+        torch.onnx.export(model_wrapped, (ids, mask), onnx_path, opset_version=14)
+        return onnx_path
+
+    @staticmethod
+    def convert_to_trt_engine(onnx_path, engine_path, *args, **kwargs):
+        logger.info("Starting ONNX to TensorRT engine conversion.")
+        cmd = f"trtexec --onnx={onnx_path} --saveEngine={engine_path} --bf16 "
+        p = Popen(cmd, shell=True)
+        p.wait()
+        if not Path(engine_path).exists():
+            raise RuntimeError(f"Failed to convert ONNX ({onnx_path}) to a TensorRT engine.")
+        logger.info("Finished TensorRT engine conversion.")
+        return engine_path
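`CLIPTrtModelInfer.__call__` assumes the engine exposes the CLIP outputs in export order and takes `outs[1]` as the pooled embedding. The binding names and ordering depend on how `torch.onnx.export` labelled the outputs, so a quick way to confirm what a freshly built engine actually exposes is to print the I/O metadata the base class collects; the engine path below is a placeholder.

    from lightx2v.text2v.models.text_encoders.trt.clip.trt_clip_infer import CLIPTrtModelInfer

    infer = CLIPTrtModelInfer(engine_path="/path/to/clip_l.engine")  # placeholder
    for t in infer.inp_list:
        print("input :", t["name"], tuple(t["shape"]), t["dtype"])
    for t in infer.out_list:
        print("output:", t["name"], tuple(t["shape"]), t["dtype"])  # outs[1] above is the second entry here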
diff --git a/lightx2v/text2v/models/text_encoders/trt/t5/model.py b/lightx2v/text2v/models/text_encoders/trt/t5/model.py
new file mode 100644
index 0000000..6b3b6e4
--- /dev/null
+++ b/lightx2v/text2v/models/text_encoders/trt/t5/model.py
@@ -0,0 +1,39 @@
+import torch
+
+from ...hf.t5.tokenizer import HuggingfaceTokenizer
+from .trt_t5_infer import T5TrtModelInfer
+
+
+class T5EncoderModel:
+    def __init__(self, text_len, dtype=torch.bfloat16, device=torch.cuda.current_device(), engine_path=None, checkpoint_path=None, tokenizer_path=None, **kwargs):
+        self.text_len = text_len
+        self.dtype = dtype
+        self.device = device
+        self.checkpoint_path = checkpoint_path
+        self.tokenizer_path = tokenizer_path
+
+        # init model
+        self.model = T5TrtModelInfer(engine_path=engine_path)
+        # init tokenizer
+        self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace")
+
+    def to_cpu(self):
+        self.model = self.model.to("cpu")
+
+    def to_cuda(self):
+        self.model = self.model.to("cuda")
+
+    def infer(self, texts, args, **kwargs):
+        if args.cpu_offload:
+            self.to_cuda()
+
+        ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
+        ids = ids.cuda()
+        mask = mask.cuda()
+        seq_lens = mask.gt(0).sum(dim=1).long()
+        context = self.model(ids, mask)
+
+        if args.cpu_offload:
+            self.to_cpu()
+
+        return [u[:v] for u, v in zip(context, seq_lens)]
diff --git a/lightx2v/text2v/models/text_encoders/trt/t5/trt_t5_infer.py b/lightx2v/text2v/models/text_encoders/trt/t5/trt_t5_infer.py
new file mode 100644
index 0000000..496d69a
--- /dev/null
+++ b/lightx2v/text2v/models/text_encoders/trt/t5/trt_t5_infer.py
@@ -0,0 +1,62 @@
+from pathlib import Path
+from subprocess import Popen
+
+import torch
+import tensorrt as trt
+from loguru import logger
+import numpy as np
+from torch.nn.modules import Module
+
+from lightx2v.common.backend_infer.trt import common
+from lightx2v.common.backend_infer.trt.trt_infer_base import TrtModelInferBase, np_torch_dtype_map
+
+TRT_LOGGER = trt.Logger(trt.Logger.INFO)
+
+
+class T5TrtModelInfer(TrtModelInferBase):
+    def __init__(self, engine_path, **kwargs):
+        super().__init__(engine_path, **kwargs)
+
+    def __call__(self, ids, mask, *args, **kwargs):
+        device = ids.device
+        ids = ids.cpu().numpy()
+        mask = mask.cpu().numpy()
+        shp_dict = {i["name"]: i["shape"] for i in self.inp_list}
+        shp_dict.update({i["name"]: i["shape"] for i in self.out_list})
+        self.alloc(shp_dict)
+
+        out_list = []
+        for o in self.outputs:
+            out_list.append(np.zeros(o["shape"], o["dtype"]))
+        for inp, data in zip(self.inputs, [ids, mask]):
+            common.memcpy_host_to_device(inp["allocation"], np.ascontiguousarray(data))
+        self.context.execute_v2(self.allocations)
+        outs = []
+        for i, out in enumerate(out_list):
+            common.memcpy_device_to_host(out, self.outputs[i]["allocation"])
+            out = torch.from_numpy(out).to(device)
+            out = out.type(torch.bfloat16)
+            outs.append(out)
+        return outs[0]
+
+    @staticmethod
+    def export_to_onnx(model: Module, model_dir, *args, **kwargs):
+        ids = kwargs.get("ids")
+        mask = kwargs.get("mask")
+        onnx_dir = Path(model_dir) / "onnx/t5"
+        onnx_dir.mkdir(parents=True, exist_ok=True)
+        onnx_path = str(onnx_dir / "t5.onnx")
+        torch.onnx.export(model, (ids, mask), onnx_path, opset_version=14)
+        return onnx_path
+
+    @staticmethod
+    def convert_to_trt_engine(onnx_path, engine_path, *args, **kwargs):
+        logger.info("Starting ONNX to TensorRT engine conversion.")
+        cmd = f"trtexec --onnx={onnx_path} --saveEngine={engine_path} --bf16 "
+        p = Popen(cmd, shell=True)
+        p.wait()
+        if not Path(engine_path).exists():
+            raise RuntimeError(f"Failed to convert ONNX ({onnx_path}) to a TensorRT engine.")
+        logger.info("Finished TensorRT engine conversion.")
+        return engine_path
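With the engine built, the TRT-backed T5 wrapper is a drop-in for the HF one at inference time: it tokenizes to the fixed `text_len`, runs the engine, and trims each sequence back to its real token count. A minimal usage sketch, with placeholder paths and the same `cpu_offload`-style args object used elsewhere in the patch:

    from types import SimpleNamespace

    import torch
    from lightx2v.text2v.models.text_encoders.trt.t5.model import T5EncoderModel

    enc = T5EncoderModel(
        text_len=512,
        dtype=torch.bfloat16,
        device="cuda",
        engine_path="/path/to/onnx/t5/t5.engine",   # placeholder
        tokenizer_path="/path/to/google/umt5-xxl",  # placeholder
    )
    args = SimpleNamespace(cpu_offload=False)
    context = enc.infer("A cat walks on the grass, realistic style.", args)
    print(context[0].shape)  # (number of real tokens, hidden size), padding trimmed off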
mode="nearest-exact"), + Upsample(scale_factor=(2.0, 2.0), mode="nearest"), nn.Conv2d(dim, dim // 2, 3, padding=1), ) elif mode == "upsample3d": self.resample = nn.Sequential( - Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + Upsample(scale_factor=(2.0, 2.0), mode="nearest"), nn.Conv2d(dim, dim // 2, 3, padding=1), ) self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) @@ -253,7 +253,8 @@ def forward(self, x): k, v, ) - x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + # x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + x = x.view(x.shape[0], x.shape[2], x.shape[3]).permute(0, 2, 1).reshape(b * t, c, h, w) # output x = self.proj(x) @@ -649,7 +650,7 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs): return model -class WanVAE: +class WanVAE(nn.Module): def __init__( self, z_dim=16, @@ -658,6 +659,7 @@ def __init__( device="cuda", parallel=False, ): + super().__init__() self.dtype = dtype self.device = device self.parallel = parallel @@ -811,3 +813,6 @@ def decode(self, zs, generator, args): self.to_cpu() return images + + def forward(self, zs): + return self.decode(zs, None, None) diff --git a/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/model.py b/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/model.py index 774168f..2ca8045 100755 --- a/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/model.py +++ b/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/model.py @@ -6,7 +6,7 @@ class VideoEncoderKLCausal3DModel: - def __init__(self, model_path, dtype, device): + def __init__(self, model_path, dtype, device, **kwargs): self.model_path = model_path self.dtype = dtype self.device = device @@ -24,7 +24,7 @@ def load(self): trt_decoder = trt_vae_infer.HyVaeTrtModelInfer(engine_path=os.path.join(self.vae_path, "vae_decoder.engine")) self.model.decoder = trt_decoder - def decode(self, latents, generator): + def decode(self, latents, generator, **kwargs): latents = latents / self.model.config.scaling_factor latents = latents.to(dtype=self.dtype, device=self.device) self.model.enable_tiling() diff --git a/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py b/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py index 03664c2..7c79560 100644 --- a/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py +++ b/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py @@ -2,97 +2,26 @@ from pathlib import Path from subprocess import Popen -import numpy as np import torch +import numpy as np import tensorrt as trt -from cuda import cudart -import torch.nn as nn from loguru import logger from lightx2v.common.backend_infer.trt import common +from lightx2v.common.backend_infer.trt.trt_infer_base import TrtModelInferBase TRT_LOGGER = trt.Logger(trt.Logger.INFO) -class HyVaeTrtModelInfer(nn.Module): +class HyVaeTrtModelInfer(TrtModelInferBase): """ - Implements inference for the TensorRT engine. + Implements hunyuan vae inference for the TensorRT engine. """ def __init__(self, engine_path): - """ - :param engine_path: The path to the serialized engine to load from disk. 
- """ - # Load TRT engine - if not Path(engine_path).exists(): - # dir_name = str(Path(engine_path).parents) - # onnx_path = self.export_to_onnx(decoder, dir_name) - # self.convert_to_trt_engine(onnx_path, engine_path) - raise FileNotFoundError(f"VAE tensorrt engine `{str(engine_path)}` not exists.") - self.logger = trt.Logger(trt.Logger.ERROR) - with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime: - assert runtime - self.engine = runtime.deserialize_cuda_engine(f.read()) - assert self.engine - self.context = self.engine.create_execution_context() - assert self.context - logger.info(f"Loaded VAE tensorrt engine from `{engine_path}`") - - def alloc(self, shape_dict): - """ - Setup I/O bindings - """ - self.inputs = [] - self.outputs = [] - self.allocations = [] - for i in range(self.engine.num_io_tensors): - name = self.engine.get_tensor_name(i) - is_input = False - if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: - is_input = True - dtype = self.engine.get_tensor_dtype(name) - # shape = self.engine.get_tensor_shape(name) - shape = shape_dict[name] - if is_input: - self.context.set_input_shape(name, shape) - self.batch_size = shape[0] - size = np.dtype(trt.nptype(dtype)).itemsize - for s in shape: - size *= s - allocation = common.cuda_call(cudart.cudaMalloc(size)) - binding = { - "index": i, - "name": name, - "dtype": np.dtype(trt.nptype(dtype)), - "shape": list(shape), - "allocation": allocation, - } - self.allocations.append(allocation) - if is_input: - self.inputs.append(binding) - else: - self.outputs.append(binding) - - assert self.batch_size > 0 - assert len(self.inputs) > 0 - assert len(self.outputs) > 0 - assert len(self.allocations) > 0 - - def input_spec(self): - """ - Get the specs for the input tensor of the network. Useful to prepare memory allocations. - :return: Two items, the shape of the input tensor and its (numpy) datatype. - """ - return self.inputs[0]["shape"], self.inputs[0]["dtype"] - - def output_spec(self): - """ - Get the specs for the output tensor of the network. Useful to prepare memory allocations. - :return: Two items, the shape of the output tensor and its (numpy) datatype. - """ - return self.outputs[0]["shape"], self.outputs[0]["dtype"] + super().__init__(engine_path) - def __call__(self, batch, top=1): + def __call__(self, batch, *args, **kwargs): """ Execute inference """ @@ -106,9 +35,10 @@ def get_output_shape(shp): out = (b, 3, 4 * (t - 1) + 1, h * 8, w * 8) return out - shp_dict = {"inp": batch.shape, "out": get_output_shape(batch.shape)} + vae_out_shape = get_output_shape(batch.shape) + shp_dict = {"inp": batch.shape, "out": vae_out_shape} self.alloc(shp_dict) - output = np.zeros(*self.output_spec()) + output = np.zeros(vae_out_shape, self.out_list[0]["dtype"]) # Process I/O and execute the network common.memcpy_host_to_device(self.inputs[0]["allocation"], np.ascontiguousarray(batch)) @@ -132,11 +62,7 @@ def export_to_onnx(decoder: torch.nn.Module, model_dir): opset_version=14, dynamic_axes={"inp": {1: "c1", 2: "c2", 3: "c3", 4: "c4"}, "out": {1: "c1", 2: "c2", 3: "c3", 4: "c4"}}, ) - # onnx_ori = onnx.load(out_path) os.system(f"onnxsim {out_path} {out_path}") - # onnx_opt, check = simplify(onnx_ori) - # assert check, f"Simplified ONNX model({out_path}) could not be validated." 
diff --git a/scripts/run_hunyuan_t2v.sh b/scripts/run_hunyuan_t2v.sh
index c73fa19..008273a 100755
--- a/scripts/run_hunyuan_t2v.sh
+++ b/scripts/run_hunyuan_t2v.sh
@@ -1,12 +1,12 @@
 #!/bin/bash
 
 # set path and first
-lightx2v_path=""
-model_path=""
+lightx2v_path="/mtc/wq/project/sd/code/lightx2v"
+model_path="/mtc/yongyang/models/x2v_models/hunyuan/lightx2v_format/t2v"
 
 # check section
 if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
-    cuda_devices=0
+    cuda_devices=2
     echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
     export CUDA_VISIBLE_DEVICES=${cuda_devices}
 fi