@@ -52,7 +52,7 @@ def compile_clip(model, dtype=tp.int32, verbose=False):
     return compile_model(model, inputs, verbose=verbose)


-def compile_unet(model, dtype=tp.float16, verbose=False):
+def compile_unet(model, dtype, verbose=False):
     unconditional_context_shape = (1, 77, 768)
     conditional_context_shape = (1, 77, 768)
     latent_shape = (1, 4, 64, 64)
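`compile_model` is defined earlier in the file and unchanged by this diff. Judging from the `tp.InputInfo` tuples these helpers build, it is presumably a thin wrapper over Tripy's ahead-of-time compile API; a minimal sketch under that assumption (the real implementation may differ):

```python
import tripy as tp  # assumed import alias used throughout this example

def compile_model(model, inputs, verbose=False):
    # Compile ahead of time for fixed input shapes/dtypes; returns a tp.Executable.
    if verbose:
        print(f"[I] Compiling {type(model).__name__}...")
    return tp.compile(model, args=list(inputs))
```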
@@ -68,16 +68,16 @@ def compile_unet(model, dtype=tp.float16, verbose=False):
     return compile_model(model, inputs, verbose=verbose)


-def compile_vae(model, dtype=tp.float16, verbose=False):
+def compile_vae(model, dtype, verbose=False):
     inputs = (tp.InputInfo((1, 4, 64, 64), dtype=dtype),)
     return compile_model(model, inputs, verbose=verbose)


-def run_diffusion_loop(model, unconditional_context, context, latent, steps, guidance):
+def run_diffusion_loop(model, unconditional_context, context, latent, steps, guidance, dtype):
     timesteps = list(range(1, 1000, 1000 // steps))
-    print(f"[I] Running diffusion for {timesteps} timesteps...")
-    alphas = get_alphas_cumprod()[tp.Tensor(timesteps)]
-    alphas_prev = tp.concatenate([tp.Tensor([1.0]), alphas[:-1]], dim=0)
+    print(f"[I] Running diffusion for {steps} timesteps...")
+    alphas = get_alphas_cumprod(dtype=dtype)[tp.Tensor(timesteps)]
+    alphas_prev = tp.concatenate([tp.Tensor([1.0], dtype=dtype), alphas[:-1]], dim=0)

     for index, timestep in (t := tqdm(list(enumerate(timesteps))[::-1])):
         t.set_description("idx: %1d, timestep: %3d" % (index, timestep))
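For context, `get_alphas_cumprod` (defined elsewhere in the file) now takes the target `dtype` so the noise schedule matches the precision of the rest of the pipeline. A minimal sketch of what it likely computes, assuming the standard scaled-linear schedule that the `LMSDiscreteScheduler` below is configured with (`beta_start=0.00085`, `beta_end=0.012`, 1000 training timesteps):

```python
import numpy as np
import tripy as tp  # assumed import alias

def get_alphas_cumprod(beta_start=0.00085, beta_end=0.012, n_steps=1000, dtype=tp.float32):
    # "scaled_linear" schedule: betas are linearly spaced in sqrt-space.
    betas = np.linspace(beta_start**0.5, beta_end**0.5, n_steps, dtype=np.float32) ** 2
    alphas_cumprod = np.cumprod(1.0 - betas)  # cumulative product of (1 - beta_t)
    return tp.cast(tp.Tensor(alphas_cumprod), dtype)
```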
@@ -86,32 +86,34 @@ def run_diffusion_loop(model, unconditional_context, context, latent, steps, gui
             unconditional_context,
             context,
             latent,
-            tp.cast(tp.Tensor([timestep]), tp.float32),
+            tp.Tensor([timestep], dtype=dtype),
             alphas[tid],
             alphas_prev[tid],
-            tp.Tensor([guidance]),
+            tp.Tensor([guidance], dtype=dtype),
         )
     return latent


 def tripy_diffusion(args):
     run_start_time = time.perf_counter()

-    if os.path.isdir("engines"):
+    dtype, torch_dtype = (tp.float16, torch.float16) if args.fp16 else (tp.float32, torch.float32)
+
+    if os.path.isdir(args.engine_dir):
         print("[I] Loading cached engines from disk...")
-        clip_compiled = tp.Executable.load(os.path.join("engines", "clip_executable.json"))
-        unet_compiled = tp.Executable.load(os.path.join("engines", "unet_executable.json"))
-        vae_compiled = tp.Executable.load(os.path.join("engines", "vae_executable.json"))
+        clip_compiled = tp.Executable.load(os.path.join(args.engine_dir, "clip_executable.json"))
+        unet_compiled = tp.Executable.load(os.path.join(args.engine_dir, "unet_executable.json"))
+        vae_compiled = tp.Executable.load(os.path.join(args.engine_dir, "vae_executable.json"))
     else:
-        model = StableDiffusion(StableDiffusionConfig(dtype=tp.float16))
+        model = StableDiffusion(StableDiffusionConfig(dtype=dtype))
         print("[I] Loading model weights...", flush=True)
-        load_from_diffusers(model, tp.float16, debug=True)
+        load_from_diffusers(model, dtype, args.hf_token, debug=True)
         clip_compiled = compile_clip(model.cond_stage_model.transformer.text_model, verbose=True)
-        unet_compiled = compile_unet(model, verbose=True)
-        vae_compiled = compile_vae(model.decode, verbose=True)
+        unet_compiled = compile_unet(model, dtype, verbose=True)
+        vae_compiled = compile_vae(model.decode, dtype, verbose=True)

-        os.mkdir("engines")
-        print("[I] Saving engines to disk...")
+        os.mkdir(args.engine_dir)
+        print(f"[I] Saving engines to {args.engine_dir}...")
-        clip_compiled.save(os.path.join("engines", "clip_executable.json"))
-        unet_compiled.save(os.path.join("engines", "unet_executable.json"))
-        vae_compiled.save(os.path.join("engines", "vae_executable.json"))
+        clip_compiled.save(os.path.join(args.engine_dir, "clip_executable.json"))
+        unet_compiled.save(os.path.join(args.engine_dir, "unet_executable.json"))
+        vae_compiled.save(os.path.join(args.engine_dir, "vae_executable.json"))
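The load-or-compile branches above repeat the same three paths; a small helper capturing the pattern, built only from the `tp.Executable.load`/`.save` calls already shown in this diff (helper name hypothetical):

```python
import os
import tripy as tp  # assumed import alias

def load_or_compile(path, compile_fn):
    # Reuse a cached serialized engine if present; otherwise compile and cache it.
    if os.path.exists(path):
        return tp.Executable.load(path)
    executable = compile_fn()
    os.makedirs(os.path.dirname(path), exist_ok=True)
    executable.save(path)
    return executable

# Usage (sketch):
# clip_compiled = load_or_compile(
#     os.path.join(args.engine_dir, "clip_executable.json"),
#     lambda: compile_clip(model.cond_stage_model.transformer.text_model, verbose=True),
# )
```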
@@ -135,11 +137,11 @@ def tripy_diffusion(args):
     # Backbone of diffusion - the UNet
     if args.seed is not None:
         torch.manual_seed(args.seed)
-    torch_latent = torch.randn((1, 4, 64, 64)).to("cuda")
+    torch_latent = torch.randn((1, 4, 64, 64), dtype=torch_dtype).to("cuda")
     latent = tp.Tensor(torch_latent)

     diffusion_run_start = time.perf_counter()
-    latent = run_diffusion_loop(unet_compiled, unconditional_context, context, latent, args.steps, args.guidance)
+    latent = run_diffusion_loop(unet_compiled, unconditional_context, context, latent, args.steps, args.guidance, dtype)
     diffusion_run_end = time.perf_counter()
     print(f"[I] Finished diffusion denoising. Inference took {diffusion_run_end - diffusion_run_start} seconds.")

@@ -173,15 +175,17 @@ def hf_diffusion(args):

     run_start_time = time.perf_counter()

+    dtype = torch.float16 if args.fp16 else torch.float32
+    model_opts = {'variant': 'fp16', 'torch_dtype': torch.float16} if args.fp16 else {}
+
     # Initialize models
-    model_id = "CompVis/stable-diffusion-v1-4"  # "benjamin-paine/stable-diffusion-v1-5" # "runwayml/stable-diffusion-v1-5"
-    clip_id = "openai/clip-vit-large-patch14"
+    model_id = "KiwiXR/stable-diffusion-v1-5"

     print("[I] Loading models...")
-    hf_tokenizer = CLIPTokenizer.from_pretrained(clip_id)
-    hf_encoder = CLIPTextModel.from_pretrained(clip_id).to("cuda")
-    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to("cuda")
-    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to("cuda")
+    hf_tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
+    hf_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to("cuda")
+    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", use_auth_token=args.hf_token, **model_opts).to("cuda")
+    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", use_auth_token=args.hf_token, **model_opts).to("cuda")
     scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

     # Run through CLIP to get context from prompt
@@ -192,19 +196,20 @@ def hf_diffusion(args):
     uncond_input = hf_tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt").to("cuda")
     text_embeddings = hf_encoder(text_input.input_ids, output_hidden_states=True)[0]
     uncond_embeddings = hf_encoder(uncond_input.input_ids)[0]
-    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+    text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype)
     clip_run_end = time.perf_counter()
     print(f"took {clip_run_end - clip_run_start} seconds.")

     # Backbone of diffusion - the UNet
     if args.seed is not None:
         torch.manual_seed(args.seed)
-    torch_latent = torch.randn((1, 4, 64, 64)).to("cuda")
+    torch_latent = torch.randn((1, 4, 64, 64), dtype=dtype).to("cuda")
     torch_latent *= scheduler.init_noise_sigma

     scheduler.set_timesteps(args.steps)

     diffusion_run_start = time.perf_counter()
+    print(f"[I] Running diffusion for {args.steps} timesteps...")
     for t in tqdm(scheduler.timesteps):
         # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
         latent_model_input = torch.cat([torch_latent] * 2)
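The remainder of this loop is unchanged and elided by the diff. For reference, the classifier-free guidance step that consumes this doubled batch conventionally looks like the following (a sketch of the standard technique using `diffusers` APIs, not necessarily this file's exact code):

```python
# Standard classifier-free guidance step (variable names are illustrative).
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
with torch.no_grad():
    noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_uncond, noise_text = noise_pred.chunk(2)
noise_pred = noise_uncond + args.guidance * (noise_text - noise_uncond)
torch_latent = scheduler.step(noise_pred, t, torch_latent).prev_sample
```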
@@ -267,7 +272,6 @@ def print_summary(denoising_steps, times):


 # TODO: Add torch compilation modes
-# TODO: Add fp16 support
 # TODO: Add Timing context
 def main():
     default_prompt = "a horse sized cat eating a bagel"
@@ -282,6 +286,8 @@ def main():
     parser.add_argument("--seed", type=int, help="Set the random latent seed")
     parser.add_argument("--guidance", type=float, default=7.5, help="Prompt strength")
     parser.add_argument('--torch-inference', action='store_true', help="Run inference with PyTorch (eager mode) instead of TensorRT.")
+    parser.add_argument('--hf-token', type=str, default='', help="HuggingFace API access token for downloading model checkpoints")
+    parser.add_argument('--engine-dir', type=str, default='engines', help="Output directory for TensorRT engines")
     args = parser.parse_args()

     if args.torch_inference:
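With the new flags, a TensorRT run looks something like `python3 <script>.py --fp16 --hf-token <your_token> --engine-dir engines_fp16` (script name elided here). Note that the cache check is simply `os.path.isdir(args.engine_dir)`, so the directory name is the entire cache key: use distinct `--engine-dir` values for fp16 and fp32 runs, otherwise engines compiled at the other precision will be loaded.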