 import time
 
 import torch
-import cupy as cp
 import numpy as np
 import nvtripy as tp
 
@@ -84,18 +83,24 @@ def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=100
     return alphas_cumprod
 
 
-def run_diffusion_loop(model, unconditional_context, context, latent, steps, guidance, dtype):
+def run_diffusion_loop(model, unconditional_context, context, latent, steps, guidance, dtype, verbose=False):
     torch_dtype = torch.float16 if dtype == tp.float16 else torch.float32
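     # Sample `steps` evenly spaced timesteps out of the 1000-step training schedule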
     idx_timesteps = list(range(1, 1000, 1000 // steps))
     num_timesteps = len(idx_timesteps)
     timesteps = torch.tensor(idx_timesteps, dtype=torch_dtype, device="cuda")
     guidance = torch.tensor([guidance], dtype=torch_dtype, device="cuda")
 
-    print(f"[I] Running diffusion for {steps} timesteps...")
+    if verbose:
+        print(f"[I] Running diffusion for {steps} timesteps...")
     alphas = get_alphas_cumprod(dtype=torch_dtype)[idx_timesteps]
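     # Shift by one step: the "previous" cumulative alpha for the first step is 1.0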
     alphas_prev = torch.cat((torch.tensor([1.0], dtype=torch_dtype, device="cuda"), alphas[:-1]))
 
-    for index in (t := tqdm(list(range(num_timesteps))[::-1])):
+    if verbose:
+        iterator = tqdm(list(range(num_timesteps))[::-1])
+    else:
+        iterator = list(range(num_timesteps))[::-1]
+
+    for index in iterator:
         latent = model(
             unconditional_context,
             context,
@@ -121,36 +126,39 @@ def save_image(image, args):
         f"seed{args.seed if args.seed else 'rand'}-"
         f"{int(time.time())}.png"
     )
+    filename = os.path.join("output", filename)
 
-    target = os.path.join("output", filename)
     # Save image
-    print(f"[I] Saving image to {target}")
-    if not os.path.isdir("output"):
-        print("[I] Creating 'output' directory.")
-        os.mkdir("output")
-    image.save(target)
+    print(f"[I] Saving image to {filename}")
+    if not os.path.isdir(os.path.dirname(filename)):
+        print(f"[I] Creating '{os.path.dirname(filename)}' directory.")
+        os.makedirs(os.path.dirname(filename))
+    image.save(filename)
 
 
 def tripy_diffusion(args):
-    run_start_time = time.perf_counter()
+    run_start_time = time.perf_counter() if args.verbose else None
 
     dtype, torch_dtype = (tp.float16, torch.float16) if args.fp16 else (tp.float32, torch.float32)
 
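     # Reuse cached compiled engines when the engine directory already exists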
     if os.path.isdir(args.engine_dir):
-        print(f"[I] Loading cached engines from {args.engine_dir}...")
+        if args.verbose:
+            print(f"[I] Loading cached engines from {args.engine_dir}...")
         clip_compiled = tp.Executable.load(os.path.join(args.engine_dir, "clip_executable.tpymodel"))
         unet_compiled = tp.Executable.load(os.path.join(args.engine_dir, "unet_executable.tpymodel"))
         vae_compiled = tp.Executable.load(os.path.join(args.engine_dir, "vae_executable.tpymodel"))
     else:
         model = StableDiffusion(StableDiffusionConfig(dtype=dtype))
-        print("[I] Loading model weights...", flush=True)
+        if args.verbose:
+            print("[I] Loading model weights...", flush=True)
         load_from_diffusers(model, dtype, args.hf_token, debug=True)
-        clip_compiled = compile_clip(model.cond_stage_model.transformer.text_model, verbose=True)
-        unet_compiled = compile_unet(model, dtype, verbose=True)
-        vae_compiled = compile_vae(model.decode, dtype, verbose=True)
+        clip_compiled = compile_clip(model.cond_stage_model.transformer.text_model, verbose=args.verbose)
+        unet_compiled = compile_unet(model, dtype, verbose=args.verbose)
+        vae_compiled = compile_vae(model.decode, dtype, verbose=args.verbose)
 
         os.mkdir(args.engine_dir)
-        print(f"[I] Saving engines to ./{args.engine_dir}...")
+        if args.verbose:
+            print(f"[I] Saving engines to ./{args.engine_dir}...")
         clip_compiled.save(os.path.join(args.engine_dir, "clip_executable.tpymodel"))
         unet_compiled.save(os.path.join(args.engine_dir, "unet_executable.tpymodel"))
         vae_compiled.save(os.path.join(args.engine_dir, "vae_executable.tpymodel"))
@@ -161,147 +169,74 @@ def tripy_diffusion(args):
         args.prompt, padding="max_length", max_length=CLIPConfig.max_seq_len, truncation=True, return_tensors="pt"
     )
     prompt = tp.Tensor(torch_prompt.input_ids.to(torch.int32).to("cuda"))
-    print(f"[I] Got tokenized prompt.")
+    if args.verbose:
+        print(f"[I] Got tokenized prompt.")
     torch_unconditional_prompt = tokenizer(
         [""], padding="max_length", max_length=CLIPConfig.max_seq_len, return_tensors="pt"
     )
     unconditional_prompt = tp.Tensor(torch_unconditional_prompt.input_ids.to(torch.int32).to("cuda"))
-    print(f"[I] Got unconditional tokenized prompt.")
+    if args.verbose:
+        print(f"[I] Got unconditional tokenized prompt.")
 
-    print("[I] Getting CLIP conditional and unconditional context...", end=" ")
-    clip_run_start = time.perf_counter()
+    if args.verbose:
+        print("[I] Getting CLIP conditional and unconditional context...", end=" ")
+    clip_run_start = time.perf_counter() if args.verbose else None
     context = clip_compiled(prompt)
     unconditional_context = clip_compiled(unconditional_prompt)
-    tp.default_stream().synchronize()
-    clip_run_end = time.perf_counter()
-    print(f"took {clip_run_end - clip_run_start} seconds.")
+    if args.verbose:
+        tp.default_stream().synchronize()
+        clip_run_end = time.perf_counter()
+        print(f"took {clip_run_end - clip_run_start} seconds.")
+    else:
+        clip_run_start = None
+        clip_run_end = None
 
     # Backbone of diffusion - the UNet
     if args.seed is not None:
         torch.manual_seed(args.seed)
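     # 1x4x64x64 initial latent noise; the VAE later decodes it into a 512x512 image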
     torch_latent = torch.randn((1, 4, 64, 64), dtype=torch_dtype, device="cuda")
     latent = tp.Tensor(torch_latent)
 
-    diffusion_run_start = time.perf_counter()
-    latent = run_diffusion_loop(unet_compiled, unconditional_context, context, latent, args.steps, args.guidance, dtype)
-    tp.default_stream().synchronize()
-    diffusion_run_end = time.perf_counter()
-    print(f"[I] Finished diffusion denoising. Inference took {diffusion_run_end - diffusion_run_start} seconds.")
+    diffusion_run_start = time.perf_counter() if args.verbose else None
+    latent = run_diffusion_loop(
+        unet_compiled, unconditional_context, context, latent, args.steps, args.guidance, dtype, verbose=args.verbose
+    )
+    if args.verbose:
+        tp.default_stream().synchronize()
+        diffusion_run_end = time.perf_counter()
+        print(f"[I] Finished diffusion denoising. Inference took {diffusion_run_end - diffusion_run_start} seconds.")
+    else:
+        diffusion_run_start = None
+        diffusion_run_end = None
 
     # Upsample latent space to image with autoencoder
-    print(f"[I] Decoding latent...", end=" ")
-    vae_run_start = time.perf_counter()
+    if args.verbose:
+        print(f"[I] Decoding latent...", end=" ")
+    vae_run_start = time.perf_counter() if args.verbose else None
     x = vae_compiled(latent)
-    tp.default_stream().synchronize()
-    vae_run_end = time.perf_counter()
-    print(f"took {vae_run_end - vae_run_start} seconds.")
+    if args.verbose:
+        tp.default_stream().synchronize()
+        vae_run_end = time.perf_counter()
+        print(f"took {vae_run_end - vae_run_start} seconds.")
+    else:
+        vae_run_start = None
+        vae_run_end = None
 
     # Evaluate output
-    run_end_time = time.perf_counter()
-    print(f"[I] Full script took {run_end_time - run_start_time} seconds.")
+    run_end_time = time.perf_counter() if args.verbose else None
+    if args.verbose:
+        print(f"[I] Full script took {run_end_time - run_start_time} seconds.")
 
-    image = Image.fromarray(cp.from_dlpack(x).get().astype(np.uint8, copy=False))
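+    # Copy the output tensor to host memory and hand it to NumPy via DLPack (drops the CuPy dependency)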
+    image_array = np.from_dlpack(tp.copy(x, tp.device("cpu"))).astype(np.uint8, copy=False)
+    image = Image.fromarray(image_array)
 
     return image, [clip_run_start, clip_run_end, diffusion_run_start, diffusion_run_end, vae_run_start, vae_run_end]
 
 
-# referenced from https://huggingface.co/blog/stable_diffusion
-def hf_diffusion(args):
-    from transformers import CLIPTextModel, CLIPTokenizer
-    from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
-    from tqdm.auto import tqdm
-
-    run_start_time = time.perf_counter()
+def print_summary(denoising_steps, times, verbose=False):
+    if not verbose or times is None or None in times:
+        return
 
-    dtype = torch.float16 if args.fp16 else torch.float32
-    model_opts = {"variant": "fp16", "torch_dtype": torch.float16} if args.fp16 else {}
-
-    # Initialize models
-    model_id = "KiwiXR/stable-diffusion-v1-5"
-
-    print("[I] Loading models...")
-    hf_tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
-    hf_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to("cuda")
-    unet = UNet2DConditionModel.from_pretrained(
-        model_id, subfolder="unet", use_auth_token=args.hf_token, **model_opts
-    ).to("cuda")
-    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", use_auth_token=args.hf_token, **model_opts).to(
-        "cuda"
-    )
-    scheduler = LMSDiscreteScheduler(
-        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
-    )
-
-    # Run through CLIP to get context from prompt
-    print("[I] Starting tokenization and running clip...", end=" ")
-    clip_run_start = time.perf_counter()
-    text_input = hf_tokenizer(
-        args.prompt,
-        padding="max_length",
-        max_length=hf_tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    ).to("cuda")
-    max_length = text_input.input_ids.shape[-1]  # 77
-    uncond_input = hf_tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt").to("cuda")
-    text_embeddings = hf_encoder(text_input.input_ids, output_hidden_states=True)[0]
-    uncond_embeddings = hf_encoder(uncond_input.input_ids)[0]
-    text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype)
-    clip_run_end = time.perf_counter()
-    print(f"took {clip_run_end - clip_run_start} seconds.")
-
-    # Backbone of diffusion - the UNet
-    if args.seed is not None:
-        torch.manual_seed(args.seed)
-    torch_latent = torch.randn((1, 4, 64, 64), dtype=dtype, device="cuda")
-    torch_latent *= scheduler.init_noise_sigma
-
-    scheduler.set_timesteps(args.steps)
-
-    diffusion_run_start = time.perf_counter()
-    print(f"[I] Running diffusion for {args.steps} timesteps...")
-    for t in tqdm(scheduler.timesteps):
-        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
-        latent_model_input = torch.cat([torch_latent] * 2)
-        latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
-
-        # predict the noise residual
-        with torch.no_grad():
-            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-        # perform guidance
-        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-        noise_pred = noise_pred_uncond + args.guidance * (noise_pred_text - noise_pred_uncond)
-
-        # compute the previous noisy sample x_t -> x_t-1
-        torch_latent = scheduler.step(noise_pred, t, torch_latent).prev_sample
-
-    diffusion_run_end = time.perf_counter()
-    print(f"[I] Finished diffusion denoising. Inference took {diffusion_run_end - diffusion_run_start} seconds.")
-
-    # Upsample latent space to image with autoencoder
-    print(f"[I] Decoding latent...", end=" ")
-    vae_run_start = time.perf_counter()
-    torch_latent = 1 / 0.18215 * torch_latent
-    with torch.no_grad():
-        image = vae.decode(torch_latent).sample
-    vae_run_end = time.perf_counter()
-    print(f"took {vae_run_end - vae_run_start} seconds.")
-
-    # Evaluate Output
-    image = (image / 2 + 0.5).clamp(0, 1)
-    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
-    images = (image * 255).round().astype("uint8")
-    pil_images = [Image.fromarray(image) for image in images]
-    image = pil_images[0]
-
-    run_end_time = time.perf_counter()
-    print(f"[I] Full script took {run_end_time - run_start_time} seconds.")
-
-    return image, [clip_run_start, clip_run_end, diffusion_run_start, diffusion_run_end, vae_run_start, vae_run_end]
-
-
-def print_summary(denoising_steps, times):
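     # times holds (start, end) pairs for the CLIP, UNet, and VAE stages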
     stages_ms = [1000 * (times[i + 1] - times[i]) for i in range(0, 6, 2)]
     total_ms = sum(stages_ms)
     print("|-----------------|--------------|")
@@ -316,8 +251,6 @@ def print_summary(denoising_steps, times):
     print("Throughput: {:.2f} image/s".format(1000.0 / total_ms))
 
 
-# TODO: Add torch compilation
-# TODO: Add Timing context (depends on how we measure perf)
 def main():
     default_prompt = "a beautiful photograph of Mt. Fuji during cherry blossom"
     parser = argparse.ArgumentParser(
@@ -336,6 +269,9 @@ def main():
         "--hf-token", type=str, default="", help="HuggingFace API access token for downloading model checkpoints"
     )
     parser.add_argument("--engine-dir", type=str, default="engines", help="Output directory for TensorRT engines")
+    parser.add_argument(
+        "--verbose", action="store_true", default=False, help="Enable verbose output with timing and progress bars"
+    )
     args = parser.parse_args()
 
     if args.torch_inference:
@@ -344,7 +280,7 @@ def main():
         image, times = tripy_diffusion(args)
 
     save_image(image, args)
-    print_summary(args.steps, times)
+    print_summary(args.steps, times, verbose=args.verbose)
 
 
 if __name__ == "__main__":