@@ -57,19 +57,18 @@ def compile_clip(model, engine_path, dtype=tp.int32, verbose=False):
     return compile_model(model, inputs, engine_path, verbose=verbose)
 
 
-def compile_unet(model, steps, engine_path, dtype, verbose=False):
+def compile_unet(model, engine_path, dtype, verbose=False):
     unconditional_context_shape = (1, 77, 768)
     conditional_context_shape = (1, 77, 768)
     latent_shape = (1, 4, 64, 64)
     inputs = (
         tp.InputInfo(unconditional_context_shape, dtype=dtype),
         tp.InputInfo(conditional_context_shape, dtype=dtype),
         tp.InputInfo(latent_shape, dtype=dtype),
-        tp.InputInfo((steps,), dtype=dtype),
-        tp.InputInfo((steps,), dtype=dtype),
-        tp.InputInfo((steps,), dtype=dtype),
         tp.InputInfo((1,), dtype=dtype),
-        tp.InputInfo((1,), dtype=tp.int32),
+        tp.InputInfo((1,), dtype=dtype),
+        tp.InputInfo((1,), dtype=dtype),
+        tp.InputInfo((1,), dtype=dtype),
     )
     return compile_model(model, inputs, engine_path, verbose=verbose)
 
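To make the shape change above concrete: every runtime-varying UNet input is now a fixed `(1,)`-shaped tensor, so the engine's shape signature no longer depends on the step count. A hypothetical helper (not part of the patch; the name and layout are illustrative) showing the per-call inputs the compiled UNet now expects:

```python
# Hypothetical helper, not in the patch: builds one step's inputs with the
# exact shapes compile_unet now declares via tp.InputInfo.
def unet_step_inputs(uncond_ctx, cond_ctx, latent, timesteps, alphas, alphas_prev, guidance, index):
    return (
        uncond_ctx,                      # (1, 77, 768)
        cond_ctx,                        # (1, 77, 768)
        latent,                          # (1, 4, 64, 64)
        timesteps[index : index + 1],    # (1,) -- was the full (steps,) schedule
        alphas[index : index + 1],       # (1,) -- was (steps,)
        alphas_prev[index : index + 1],  # (1,) -- was (steps,)
        guidance,                        # (1,) -- and the (1,) int32 loop index is gone
    )
```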
@@ -91,7 +90,6 @@ def run_diffusion_loop(model, unconditional_context, context, latent, steps, gui
     torch_dtype = torch.float16 if dtype == tp.float16 else torch.float32
     idx_timesteps = list(range(1, 1000, 1000 // steps))
     num_timesteps = len(idx_timesteps)
-    print(f"num_timesteps: {num_timesteps}")
     timesteps = torch.tensor(idx_timesteps, dtype=torch_dtype, device="cuda")
     guidance = torch.tensor([guidance], dtype=torch_dtype, device="cuda")
 
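A side note on the schedule length in this hunk: `num_timesteps` is derived from the materialized list rather than assumed equal to `steps`, because the integer-division stride makes the two diverge whenever `steps` does not divide 1000 evenly. A quick check:

```python
>>> len(list(range(1, 1000, 1000 // 10)))  # stride 100: exactly 10 timesteps
10
>>> len(list(range(1, 1000, 1000 // 7)))   # stride 142: 8 timesteps, not 7
8
```

This is also why the old code recomputed `timesteps_size` from the same expression in `tripy_diffusion` instead of passing `args.steps` straight into `compile_unet`.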
@@ -106,16 +104,14 @@ def run_diffusion_loop(model, unconditional_context, context, latent, steps, gui
     iterator = list(range(num_timesteps))[::-1]
 
     for index in iterator:
-        idx = torch.tensor([index], dtype=torch.int32, device="cuda")
         latent = model(
             unconditional_context,
             context,
             latent,
-            tp.Tensor(timesteps),
-            tp.Tensor(alphas),
-            tp.Tensor(alphas_prev),
+            tp.Tensor(timesteps[index : index + 1]),
+            tp.Tensor(alphas[index : index + 1]),
+            tp.Tensor(alphas_prev[index : index + 1]),
             tp.Tensor(guidance),
-            tp.Tensor(idx),
         )
 
     return latent
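The slicing style matters here: `timesteps[index : index + 1]` keeps a rank-1 tensor of shape `(1,)`, matching the new `tp.InputInfo((1,), ...)` entries, whereas plain `timesteps[index]` would yield a 0-d scalar. A minimal check (on CPU for portability; the real loop uses `device="cuda"`):

```python
import torch

timesteps = torch.tensor(list(range(1, 1000, 1000 // 10)), dtype=torch.float32)
index = 3
print(timesteps[index : index + 1].shape)  # torch.Size([1]) -- matches tp.InputInfo((1,), ...)
print(timesteps[index].shape)              # torch.Size([])  -- 0-d, would not match
```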
@@ -165,9 +161,8 @@ def tripy_diffusion(args):
         os.mkdir(args.engine_dir)
 
     # Load existing engines if they exist, otherwise compile and save them
-    timesteps_size = len(list(range(1, 1000, 1000 // args.steps)))
     clip_compiled = compile_clip(model.text_encoder, engine_path=clip_path, verbose=args.verbose)
-    unet_compiled = compile_unet(model, timesteps_size, engine_path=unet_path, dtype=dtype, verbose=args.verbose)
+    unet_compiled = compile_unet(model, engine_path=unet_path, dtype=dtype, verbose=args.verbose)
     vae_compiled = compile_vae(model.decode, engine_path=vae_path, dtype=dtype, verbose=args.verbose)
 
     # Run through CLIP to get context from prompt
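My reading of why `timesteps_size` can be dropped entirely (implied rather than stated by the patch): the old `compile_unet` baked the schedule length into three `(steps,)`-shaped inputs, tying the cached engine to one particular `--steps` value, while the new all-`(1,)` signature lets one cached engine serve any step count. A sketch of the old coupling:

```python
# Reconstructed old coupling, for illustration only:
compiled_for = len(list(range(1, 1000, 1000 // 10)))  # engine built for (10,)-shaped inputs
requested = len(list(range(1, 1000, 1000 // 25)))     # a later run wants (25,)-shaped inputs
assert compiled_for != requested  # -> shape mismatch; previously forced a recompile
```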