 import nvtripy as tp
 
 from transformers import CLIPTokenizer
-from examples.diffusion.models.clip_model import CLIPConfig
-from examples.diffusion.models.model import StableDiffusion, StableDiffusionConfig
-from examples.diffusion.weight_loader import load_from_diffusers
+from models.clip_model import CLIPConfig
+from models.model import StableDiffusion, StableDiffusionConfig
+from weight_loader import load_from_diffusers
 
 
 def compile_model(model, inputs, engine_path, verbose=False):
@@ -57,18 +57,19 @@ def compile_clip(model, engine_path, dtype=tp.int32, verbose=False):
     return compile_model(model, inputs, engine_path, verbose=verbose)
 
 
-def compile_unet(model, engine_path, dtype, verbose=False):
+def compile_unet(model, steps, engine_path, dtype, verbose=False):
     unconditional_context_shape = (1, 77, 768)
     conditional_context_shape = (1, 77, 768)
     latent_shape = (1, 4, 64, 64)
     inputs = (
         tp.InputInfo(unconditional_context_shape, dtype=dtype),
         tp.InputInfo(conditional_context_shape, dtype=dtype),
         tp.InputInfo(latent_shape, dtype=dtype),
+        tp.InputInfo((steps,), dtype=dtype),
+        tp.InputInfo((steps,), dtype=dtype),
+        tp.InputInfo((steps,), dtype=dtype),
         tp.InputInfo((1,), dtype=dtype),
-        tp.InputInfo((1,), dtype=dtype),
-        tp.InputInfo((1,), dtype=dtype),
-        tp.InputInfo((1,), dtype=dtype),
+        tp.InputInfo((1,), dtype=tp.int32),
     )
     return compile_model(model, inputs, engine_path, verbose=verbose)
 
@@ -90,6 +91,7 @@ def run_diffusion_loop(model, unconditional_context, context, latent, steps, gui
     torch_dtype = torch.float16 if dtype == tp.float16 else torch.float32
     idx_timesteps = list(range(1, 1000, 1000 // steps))
     num_timesteps = len(idx_timesteps)
+    print(f"num_timesteps: {num_timesteps}")
     timesteps = torch.tensor(idx_timesteps, dtype=torch_dtype, device="cuda")
     guidance = torch.tensor([guidance], dtype=torch_dtype, device="cuda")
 
@@ -104,14 +106,16 @@ def run_diffusion_loop(model, unconditional_context, context, latent, steps, gui
     iterator = list(range(num_timesteps))[::-1]
 
     for index in iterator:
+        idx = torch.tensor([index], dtype=torch.int32, device="cuda")
         latent = model(
             unconditional_context,
             context,
             latent,
-            tp.Tensor(timesteps[index : index + 1]),
-            tp.Tensor(alphas[index : index + 1]),
-            tp.Tensor(alphas_prev[index : index + 1]),
+            tp.Tensor(timesteps),
+            tp.Tensor(alphas),
+            tp.Tensor(alphas_prev),
             tp.Tensor(guidance),
+            tp.Tensor(idx),
         )
 
     return latent
@@ -161,8 +165,9 @@ def tripy_diffusion(args):
         os.mkdir(args.engine_dir)
 
     # Load existing engines if they exist, otherwise compile and save them
+    timesteps_size = len(list(range(1, 1000, 1000 // args.steps)))
     clip_compiled = compile_clip(model.text_encoder, engine_path=clip_path, verbose=args.verbose)
-    unet_compiled = compile_unet(model, engine_path=unet_path, dtype=dtype, verbose=args.verbose)
+    unet_compiled = compile_unet(model, timesteps_size, engine_path=unet_path, dtype=dtype, verbose=args.verbose)
     vae_compiled = compile_vae(model.decode, engine_path=vae_path, dtype=dtype, verbose=args.verbose)
 
     # Run through CLIP to get context from prompt
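
As a rough usage sketch of the changed `compile_unet` signature (assuming 50 denoising steps; the engine path and dtype below are placeholders, and `model` stands for the `StableDiffusion` instance built in `tripy_diffusion`), the new `steps` argument is derived from the schedule length and sizes the three new schedule inputs:

```python
import nvtripy as tp

steps = 50  # illustrative; the script derives this from args.steps
timesteps_size = len(list(range(1, 1000, 1000 // steps)))

# The UNet engine is now compiled against the full schedule length: the
# timesteps/alphas/alphas_prev tensors are passed whole as (steps,) inputs,
# while a (1,) int32 index selects the current step each iteration.
unet_compiled = compile_unet(
    model,                              # StableDiffusion instance (assumed)
    timesteps_size,                     # sizes the three (steps,) InputInfos
    engine_path="engines/unet_engine",  # placeholder path
    dtype=tp.float16,                   # assumed dtype
    verbose=True,
)
```

Binding the schedules once and selecting the step with an int32 index avoids slicing and re-uploading the schedule tensors on the host for every denoising iteration, which appears to be the motivation for the reshaped inputs.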