Skip to content

Commit 16e0912

Browse files
committed
Move slicing in denoising loop back outside of engine
1 parent 0f1b4ee commit 16e0912

File tree

3 files changed

+10
-20
lines changed

3 files changed

+10
-20
lines changed

tripy/examples/diffusion/example.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,18 @@ def compile_clip(model, engine_path, dtype=tp.int32, verbose=False):
5757
return compile_model(model, inputs, engine_path, verbose=verbose)
5858

5959

60-
def compile_unet(model, steps, engine_path, dtype, verbose=False):
60+
def compile_unet(model, engine_path, dtype, verbose=False):
6161
unconditional_context_shape = (1, 77, 768)
6262
conditional_context_shape = (1, 77, 768)
6363
latent_shape = (1, 4, 64, 64)
6464
inputs = (
6565
tp.InputInfo(unconditional_context_shape, dtype=dtype),
6666
tp.InputInfo(conditional_context_shape, dtype=dtype),
6767
tp.InputInfo(latent_shape, dtype=dtype),
68-
tp.InputInfo((steps,), dtype=dtype),
69-
tp.InputInfo((steps,), dtype=dtype),
70-
tp.InputInfo((steps,), dtype=dtype),
7168
tp.InputInfo((1,), dtype=dtype),
72-
tp.InputInfo((1,), dtype=tp.int32),
69+
tp.InputInfo((1,), dtype=dtype),
70+
tp.InputInfo((1,), dtype=dtype),
71+
tp.InputInfo((1,), dtype=dtype),
7372
)
7473
return compile_model(model, inputs, engine_path, verbose=verbose)
7574

@@ -91,7 +90,6 @@ def run_diffusion_loop(model, unconditional_context, context, latent, steps, gui
9190
torch_dtype = torch.float16 if dtype == tp.float16 else torch.float32
9291
idx_timesteps = list(range(1, 1000, 1000 // steps))
9392
num_timesteps = len(idx_timesteps)
94-
print(f"num_timesteps: {num_timesteps}")
9593
timesteps = torch.tensor(idx_timesteps, dtype=torch_dtype, device="cuda")
9694
guidance = torch.tensor([guidance], dtype=torch_dtype, device="cuda")
9795

@@ -106,16 +104,14 @@ def run_diffusion_loop(model, unconditional_context, context, latent, steps, gui
106104
iterator = list(range(num_timesteps))[::-1]
107105

108106
for index in iterator:
109-
idx = torch.tensor([index], dtype=torch.int32, device="cuda")
110107
latent = model(
111108
unconditional_context,
112109
context,
113110
latent,
114-
tp.Tensor(timesteps),
115-
tp.Tensor(alphas),
116-
tp.Tensor(alphas_prev),
111+
tp.Tensor(timesteps[index : index + 1]),
112+
tp.Tensor(alphas[index : index + 1]),
113+
tp.Tensor(alphas_prev[index : index + 1]),
117114
tp.Tensor(guidance),
118-
tp.Tensor(idx),
119115
)
120116

121117
return latent
@@ -165,9 +161,8 @@ def tripy_diffusion(args):
165161
os.mkdir(args.engine_dir)
166162

167163
# Load existing engines if they exist, otherwise compile and save them
168-
timesteps_size = len(list(range(1, 1000, 1000 // args.steps)))
169164
clip_compiled = compile_clip(model.text_encoder, engine_path=clip_path, verbose=args.verbose)
170-
unet_compiled = compile_unet(model, timesteps_size, engine_path=unet_path, dtype=dtype, verbose=args.verbose)
165+
unet_compiled = compile_unet(model, engine_path=unet_path, dtype=dtype, verbose=args.verbose)
171166
vae_compiled = compile_vae(model.decode, engine_path=vae_path, dtype=dtype, verbose=args.verbose)
172167

173168
# Run through CLIP to get context from prompt

tripy/examples/diffusion/models/model.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,7 @@ def decode(self, x):
8181
x = clamp(tp.permute(tp.reshape(x, (3, 512, 512)), (1, 2, 0)), 0, 1) * 255
8282
return x
8383

84-
def __call__(
85-
self, unconditional_context, context, latent, timesteps, alphas_cumprod, alphas_cumprod_prev, guidance, index
86-
):
87-
timestep = tp.reshape(timesteps[index], (1,))
88-
alphas = alphas_cumprod[index]
89-
alphas_prev = alphas_cumprod_prev[index]
84+
def __call__(self, unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance):
9085
e_t = self.get_model_output(unconditional_context, context, latent, timestep, guidance)
9186
x_prev, _ = self.get_x_prev_and_pred_x0(latent, e_t, alphas, alphas_prev)
9287
return x_prev

tripy/examples/diffusion/models/unet_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def __init__(self, config: UNetConfig):
289289
config.model_channels, config.io_channels, (3, 3), padding=((1, 1), (1, 1)), dtype=config.dtype
290290
)
291291

292-
def __call__(self, x, timesteps=None, context=None, index=None):
292+
def __call__(self, x, timesteps=None, context=None):
293293
t_emb = timestep_embedding(timesteps, self.config.model_channels, self.config.dtype)
294294
emb = self.time_embedding(t_emb)
295295
x = self.conv_in(x)

0 commit comments

Comments
 (0)