Skip to content

Commit e35db17

Browse files
authored
Support Wan2.2 t2v diffusers quantization (#556)
## What does this PR do? **Type of change:** new feature **Overview:** Support Wan2.2 t2v diffusers quantization 1. fix torch2.9 support 2. add Wan2.2 t2v diffusers pipeline quantization The main difference between the Wan2.2 pipeline and the existing pipelines is that it has 2 backbone models for denoising, so for quantization we need to quantize both of them. However, it turns out our base library does not support quantizing multiple models at the same time well. Therefore, the change here sticks to quantizing a single model at a time and then running the quantization multiple times. Users therefore need to be able to pick which backbone to quantize, so a new argument is added for it 3. add a workaround for the ONNX export issue that appears when we upgrade diffusers to >= 0.35.0. The issue lies in the exporting of torch.nn.RMSNorm. Some pipelines in diffusers > 0.35.0 use the torch version of RMSNorm, while before that they used diffusers' own version of RMSNorm. It turns out they are directly replaceable, so the workaround is to simply replace the torch RMSNorm usages with the diffusers RMSNorm. But we need to fix it properly soon by porting our ONNX export to be based on torch dynamo instead of torchscript. Issue reported by an external user: #262 4. allow use of a prompts file, which is simply a text file with a list of prompts, one prompt per line 5. allow each component of a pipeline to have a different dtype precision. added a new list-style command line arg --component-dtype for this. example: --component-dtype vae:Float 6. 
print the summary of the quantized model so users can capture issues from log ## Usage python quantize.py \ --model wan2.2-t2v-14b \ --format fp8 \ --batch-size 4 \ --calib-size 64 \ --n-steps 20 \ --backbone transformer \ --model-dtype BFloat16 \ --component-dtype vae:Float \ --trt-high-precision-dtype BFloat16 \ --quantized-torch-ckpt-save-path ./wan_transformer.pt \ --onnx-dir wan-transformer-onnx \ --prompts-file wan-prompts.txt ## Testing Tested SDXL_BASE, LTX_VIDEO_DEV, WAN22_T2V ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information #262 --------- Signed-off-by: Shengliang Xu <[email protected]>
1 parent bc52b6c commit e35db17

File tree

4 files changed

+349
-113
lines changed

4 files changed

+349
-113
lines changed

examples/diffusers/quantization/onnx_utils/export.py

Lines changed: 95 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@
3838
import onnx
3939
import onnx_graphsurgeon as gs
4040
import torch
41-
from diffusers.models.transformers import FluxTransformer2DModel, SD3Transformer2DModel
41+
from diffusers.models.transformers import (
42+
FluxTransformer2DModel,
43+
SD3Transformer2DModel,
44+
WanTransformer3DModel,
45+
)
4246
from diffusers.models.transformers.transformer_ltx import LTXVideoTransformer3DModel
4347
from diffusers.models.unets import UNet2DConditionModel
4448
from torch.onnx import export as onnx_export
@@ -104,6 +108,11 @@
104108
"encoder_attention_mask": {0: "batch_size"},
105109
"video_coords": {0: "batch_size", 2: "latent_dim"},
106110
},
111+
"wan2.2-t2v-14b": {
112+
"hidden_states": {0: "batch_size", 2: "frame_num", 3: "height", 4: "width"},
113+
"encoder_hidden_states": {0: "batch_size"},
114+
"timestep": {0: "batch_size"},
115+
},
107116
}
108117

109118

@@ -159,7 +168,7 @@ def _gen_dummy_inp_and_dyn_shapes_sdxl(backbone, min_bs=1, opt_bs=1):
159168
"added_cond_kwargs.time_ids": {"min": [min_bs, 6], "opt": [opt_bs, 6]},
160169
}
161170

162-
dummy_input = {
171+
dummy_kwargs = {
163172
"sample": torch.randn(*dynamic_shapes["sample"]["min"]),
164173
"timestep": torch.ones(1),
165174
"encoder_hidden_states": torch.randn(*dynamic_shapes["encoder_hidden_states"]["min"]),
@@ -169,9 +178,9 @@ def _gen_dummy_inp_and_dyn_shapes_sdxl(backbone, min_bs=1, opt_bs=1):
169178
},
170179
"return_dict": False,
171180
}
172-
dummy_input = torch_to(dummy_input, dtype=backbone.dtype)
181+
dummy_kwargs = torch_to(dummy_kwargs, dtype=backbone.dtype)
173182

174-
return dummy_input, dynamic_shapes
183+
return dummy_kwargs, dynamic_shapes
175184

176185

177186
def _gen_dummy_inp_and_dyn_shapes_sd3(backbone, min_bs=1, opt_bs=1):
@@ -196,16 +205,16 @@ def _gen_dummy_inp_and_dyn_shapes_sd3(backbone, min_bs=1, opt_bs=1):
196205
},
197206
}
198207

199-
dummy_input = {
208+
dummy_kwargs = {
200209
"hidden_states": torch.randn(*dynamic_shapes["hidden_states"]["min"]),
201210
"timestep": torch.ones(1),
202211
"encoder_hidden_states": torch.randn(*dynamic_shapes["encoder_hidden_states"]["min"]),
203212
"pooled_projections": torch.randn(*dynamic_shapes["pooled_projections"]["min"]),
204213
"return_dict": False,
205214
}
206-
dummy_input = torch_to(dummy_input, dtype=backbone.dtype)
215+
dummy_kwargs = torch_to(dummy_kwargs, dtype=backbone.dtype)
207216

208-
return dummy_input, dynamic_shapes
217+
return dummy_kwargs, dynamic_shapes
209218

210219

211220
def _gen_dummy_inp_and_dyn_shapes_flux(backbone, min_bs=1, opt_bs=1):
@@ -237,7 +246,7 @@ def _gen_dummy_inp_and_dyn_shapes_flux(backbone, min_bs=1, opt_bs=1):
237246
dynamic_shapes["guidance"] = {"min": [1], "opt": [1]}
238247

239248
dtype = backbone.dtype
240-
dummy_input = {
249+
dummy_kwargs = {
241250
"hidden_states": torch.randn(*dynamic_shapes["hidden_states"]["min"], dtype=dtype),
242251
"encoder_hidden_states": torch.randn(
243252
*dynamic_shapes["encoder_hidden_states"]["min"], dtype=dtype
@@ -251,9 +260,9 @@ def _gen_dummy_inp_and_dyn_shapes_flux(backbone, min_bs=1, opt_bs=1):
251260
"return_dict": False,
252261
}
253262
if cfg.guidance_embeds: # flux-dev
254-
dummy_input["guidance"] = torch.full((1,), 3.5, dtype=torch.float32)
263+
dummy_kwargs["guidance"] = torch.full((1,), 3.5, dtype=torch.float32)
255264

256-
return dummy_input, dynamic_shapes
265+
return dummy_kwargs, dynamic_shapes
257266

258267

259268
def _gen_dummy_inp_and_dyn_shapes_ltx(backbone, min_bs=2, opt_bs=2):
@@ -282,7 +291,7 @@ def _gen_dummy_inp_and_dyn_shapes_ltx(backbone, min_bs=2, opt_bs=2):
282291
"opt": [opt_bs, 3, video_dim],
283292
},
284293
}
285-
dummy_input = {
294+
dummy_kwargs = {
286295
"hidden_states": torch.randn(*dynamic_shapes["hidden_states"]["min"], dtype=dtype),
287296
"encoder_hidden_states": torch.randn(
288297
*dynamic_shapes["encoder_hidden_states"]["min"], dtype=dtype
@@ -293,7 +302,57 @@ def _gen_dummy_inp_and_dyn_shapes_ltx(backbone, min_bs=2, opt_bs=2):
293302
),
294303
"video_coords": torch.randn(*dynamic_shapes["video_coords"]["min"], dtype=dtype),
295304
}
296-
return dummy_input, dynamic_shapes
305+
306+
return dummy_kwargs, dynamic_shapes
307+
308+
309+
def _gen_dummy_inp_and_dyn_shapes_wan(backbone, min_bs=1, opt_bs=2):
    """Build dummy forward kwargs and min/opt shape profiles for a Wan transformer.

    Args:
        backbone: The ``WanTransformer3DModel`` being exported; its ``dtype`` is
            used for every dummy tensor.
        min_bs: Batch size of the "min" profile (also used for the dummy tensors).
        opt_bs: Batch size of the "opt" profile.

    Returns:
        A ``(dummy_kwargs, dynamic_shapes)`` tuple. ``dynamic_shapes`` maps
        ``hidden_states``, ``encoder_hidden_states`` and ``timestep`` to
        ``{"min": [...], "opt": [...]}`` shape lists; ``dummy_kwargs`` maps the
        same names to tensors created from the "min" shapes.
    """
    assert isinstance(backbone, WanTransformer3DModel)

    latent_channels = 16  # latent channels from VAE
    text_hidden_size = 4096  # text encoder hidden size (UMT5-XXL)

    def _profile(batch, num_frames, height, width):
        # Pixel height/width are integer-divided by 8 to get the latent sizes.
        return {
            "hidden_states": [batch, latent_channels, num_frames, height // 8, width // 8],
            "encoder_hidden_states": [batch, 512, text_hidden_size],
            "timestep": [batch],
        }

    # Frame counts for wan are 4*n+1, as in the official codebase:
    # https://github.com/Wan-Video/Wan2.2/blob/e9783574ef77be11fcab9aa5607905402538c08d/generate.py#L126
    # n == 1 is used for min and n == 20 for opt, since 81 is the default frame count there.
    #
    # Height and width configs are from their codebase:
    # https://github.com/Wan-Video/Wan2.2/blob/e9783574ef77be11fcab9aa5607905402538c08d/wan/configs/__init__.py#L21
    # Height max can be 1280, but the opt setting is 1280x720, so 720 is the opt height.
    min_profile = _profile(min_bs, 4 * 1 + 1, 480, 480)
    opt_profile = _profile(opt_bs, 4 * 20 + 1, 720, 1280)

    dynamic_shapes = {
        name: {"min": min_profile[name], "opt": opt_profile[name]} for name in min_profile
    }

    # Dummy tensors are created from the "min" shapes in the backbone's dtype.
    tensor_dtype = backbone.dtype
    dummy_kwargs = {
        "hidden_states": torch.randn(*min_profile["hidden_states"], dtype=tensor_dtype),
        "encoder_hidden_states": torch.randn(
            *min_profile["encoder_hidden_states"], dtype=tensor_dtype
        ),
        "timestep": torch.ones(*min_profile["timestep"], dtype=tensor_dtype),
    }
    return dummy_kwargs, dynamic_shapes
297356

298357

299358
def update_dynamic_axes(model_id, dynamic_axes):
@@ -327,30 +386,32 @@ def _create_dynamic_shapes(dynamic_shapes):
327386
def generate_dummy_inputs_and_dynamic_axes_and_shapes(model_id, backbone):
328387
"""Generate dummy inputs, dynamic axes, and dynamic shapes for the given model."""
329388
if model_id in ["sdxl-1.0", "sdxl-turbo"]:
330-
dummy_input, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_sdxl(
331-
backbone, min_bs=2, opt_bs=16
389+
dummy_kwargs, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_sdxl(
390+
backbone, min_bs=1, opt_bs=16
332391
)
333392
elif model_id in ["sd3-medium", "sd3.5-medium"]:
334-
dummy_input, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_sd3(
335-
backbone, min_bs=2, opt_bs=16
393+
dummy_kwargs, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_sd3(
394+
backbone, min_bs=1, opt_bs=16
336395
)
337396
elif model_id in ["flux-dev", "flux-schnell"]:
338-
dummy_input, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_flux(
339-
backbone, min_bs=1, opt_bs=1
397+
dummy_kwargs, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_flux(
398+
backbone, min_bs=1, opt_bs=2
340399
)
341400
elif model_id == "ltx-video-dev":
342-
dummy_input, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_ltx(
343-
backbone, min_bs=2, opt_bs=2
401+
dummy_kwargs, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_ltx(
402+
backbone, min_bs=1, opt_bs=2
403+
)
404+
elif model_id == "wan2.2-t2v-14b":
405+
dummy_kwargs, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_wan(
406+
backbone, min_bs=1, opt_bs=2
344407
)
345408
else:
346409
raise NotImplementedError(f"Unsupported model_id: {model_id}")
347410

348-
dummy_input = torch_to(dummy_input, device=backbone.device)
349-
dummy_inputs = (dummy_input,)
411+
dummy_kwargs = torch_to(dummy_kwargs, device=backbone.device)
350412
dynamic_axes = MODEL_ID_TO_DYNAMIC_AXES[model_id]
351-
dynamic_shapes = _create_dynamic_shapes(dynamic_shapes)
352413

353-
return dummy_inputs, dynamic_axes, dynamic_shapes
414+
return dummy_kwargs, dynamic_axes, dynamic_shapes
354415

355416

356417
def get_io_shapes(model_id, onnx_load_path, dynamic_shapes):
@@ -415,7 +476,7 @@ def modelopt_export_sd(backbone, onnx_dir, model_name, precision):
415476
configure_linear_module_onnx_quantizers(backbone) if precision == "fp4" else nullcontext()
416477
)
417478

418-
dummy_inputs, dynamic_axes, _ = generate_dummy_inputs_and_dynamic_axes_and_shapes(
479+
dummy_kwargs, dynamic_axes, _ = generate_dummy_inputs_and_dynamic_axes_and_shapes(
419480
model_name, backbone
420481
)
421482

@@ -449,6 +510,13 @@ def modelopt_export_sd(backbone, onnx_dir, model_name, precision):
449510
"video_coords",
450511
]
451512
output_names = ["latent"]
513+
elif model_name in ["wan2.2-t2v-14b"]:
514+
input_names = [
515+
"hidden_states",
516+
"timestep",
517+
"encoder_hidden_states",
518+
]
519+
output_names = ["latent"]
452520
else:
453521
raise NotImplementedError(f"Unsupported model_id: {model_name}")
454522

@@ -458,8 +526,9 @@ def modelopt_export_sd(backbone, onnx_dir, model_name, precision):
458526
with quantizer_context, torch.inference_mode():
459527
onnx_export(
460528
backbone,
461-
dummy_inputs,
529+
(),
462530
f=tmp_output.as_posix(),
531+
kwargs=dummy_kwargs,
463532
input_names=input_names,
464533
output_names=output_names,
465534
dynamic_axes=dynamic_axes,

0 commit comments

Comments
 (0)