
Commit 0f1b4ee: Remove alignment warnings, improve packaging

1 parent: 9da2df9

File tree: 8 files changed (+78, -20 lines)

tripy/examples/__init__.py

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 2025-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 2025-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#

tripy/examples/diffusion/example.py

Lines changed: 16 additions & 11 deletions

@@ -25,9 +25,9 @@
import nvtripy as tp

from transformers import CLIPTokenizer
-from examples.diffusion.models.clip_model import CLIPConfig
-from examples.diffusion.models.model import StableDiffusion, StableDiffusionConfig
-from examples.diffusion.weight_loader import load_from_diffusers
+from models.clip_model import CLIPConfig
+from models.model import StableDiffusion, StableDiffusionConfig
+from weight_loader import load_from_diffusers


def compile_model(model, inputs, engine_path, verbose=False):
@@ -57,18 +57,19 @@ def compile_clip(model, engine_path, dtype=tp.int32, verbose=False):
    return compile_model(model, inputs, engine_path, verbose=verbose)


-def compile_unet(model, engine_path, dtype, verbose=False):
+def compile_unet(model, steps, engine_path, dtype, verbose=False):
    unconditional_context_shape = (1, 77, 768)
    conditional_context_shape = (1, 77, 768)
    latent_shape = (1, 4, 64, 64)
    inputs = (
        tp.InputInfo(unconditional_context_shape, dtype=dtype),
        tp.InputInfo(conditional_context_shape, dtype=dtype),
        tp.InputInfo(latent_shape, dtype=dtype),
+        tp.InputInfo((steps,), dtype=dtype),
+        tp.InputInfo((steps,), dtype=dtype),
+        tp.InputInfo((steps,), dtype=dtype),
        tp.InputInfo((1,), dtype=dtype),
-        tp.InputInfo((1,), dtype=dtype),
-        tp.InputInfo((1,), dtype=dtype),
-        tp.InputInfo((1,), dtype=dtype),
+        tp.InputInfo((1,), dtype=tp.int32),
    )
    return compile_model(model, inputs, engine_path, verbose=verbose)

@@ -90,6 +91,7 @@ def run_diffusion_loop(model, unconditional_context, context, latent, steps, gui
    torch_dtype = torch.float16 if dtype == tp.float16 else torch.float32
    idx_timesteps = list(range(1, 1000, 1000 // steps))
    num_timesteps = len(idx_timesteps)
+    print(f"num_timesteps: {num_timesteps}")
    timesteps = torch.tensor(idx_timesteps, dtype=torch_dtype, device="cuda")
    guidance = torch.tensor([guidance], dtype=torch_dtype, device="cuda")

@@ -104,14 +106,16 @@ def run_diffusion_loop(model, unconditional_context, context, latent, steps, gui
    iterator = list(range(num_timesteps))[::-1]

    for index in iterator:
+        idx = torch.tensor([index], dtype=torch.int32, device="cuda")
        latent = model(
            unconditional_context,
            context,
            latent,
-            tp.Tensor(timesteps[index : index + 1]),
-            tp.Tensor(alphas[index : index + 1]),
-            tp.Tensor(alphas_prev[index : index + 1]),
+            tp.Tensor(timesteps),
+            tp.Tensor(alphas),
+            tp.Tensor(alphas_prev),
            tp.Tensor(guidance),
+            tp.Tensor(idx),
        )

    return latent
@@ -161,8 +165,9 @@ def tripy_diffusion(args):
        os.mkdir(args.engine_dir)

    # Load existing engines if they exist, otherwise compile and save them
+    timesteps_size = len(list(range(1, 1000, 1000 // args.steps)))
    clip_compiled = compile_clip(model.text_encoder, engine_path=clip_path, verbose=args.verbose)
-    unet_compiled = compile_unet(model, engine_path=unet_path, dtype=dtype, verbose=args.verbose)
+    unet_compiled = compile_unet(model, timesteps_size, engine_path=unet_path, dtype=dtype, verbose=args.verbose)
    vae_compiled = compile_vae(model.decode, engine_path=vae_path, dtype=dtype, verbose=args.verbose)

    # Run through CLIP to get context from prompt
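
Note on the UNet change above: run_diffusion_loop now hands the compiled UNet the full (steps,)-length timesteps/alphas/alphas_prev tensors on every call, plus a one-element int32 index, so the per-step lookup happens inside the compiled graph rather than as a host-side slice. A minimal sketch of the resulting call pattern (assuming `model` is the compiled UNet engine and `num_timesteps`, `unconditional_context`, `context`, `latent`, `timesteps`, `alphas`, `alphas_prev`, and `guidance` are prepared as in run_diffusion_loop; an illustration, not a drop-in replacement):

import torch
import nvtripy as tp

for index in reversed(range(num_timesteps)):
    # One-element int32 tensor selecting the current step; the gather over
    # timesteps/alphas is now done inside the compiled engine.
    idx = torch.tensor([index], dtype=torch.int32, device="cuda")
    latent = model(
        unconditional_context,
        context,
        latent,
        tp.Tensor(timesteps),    # full (steps,) tensors, no per-iteration slicing
        tp.Tensor(alphas),
        tp.Tensor(alphas_prev),
        tp.Tensor(guidance),
        tp.Tensor(idx),
    )

Passing the whole tensors and selecting by index inside the graph removes the one-element device slices that the old loop created on every iteration.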
Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 2025-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#

tripy/examples/diffusion/models/clip_model.py

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@

from dataclasses import dataclass

-from examples.diffusion.models.utils import scaled_dot_product_attention
+from models.utils import scaled_dot_product_attention


@dataclass

tripy/examples/diffusion/models/model.py

Lines changed: 10 additions & 5 deletions

@@ -23,10 +23,10 @@
from typing import Optional
from dataclasses import dataclass, field

-from examples.diffusion.models.clip_model import CLIPTextTransformer, CLIPConfig
-from examples.diffusion.models.unet_model import UNetModel, UNetConfig
-from examples.diffusion.models.vae_model import AutoencoderKL, VAEConfig
-from examples.diffusion.models.utils import clamp
+from models.clip_model import CLIPTextTransformer, CLIPConfig
+from models.unet_model import UNetModel, UNetConfig
+from models.vae_model import AutoencoderKL, VAEConfig
+from models.utils import clamp


@dataclass
@@ -81,7 +81,12 @@ def decode(self, x):
        x = clamp(tp.permute(tp.reshape(x, (3, 512, 512)), (1, 2, 0)), 0, 1) * 255
        return x

-    def __call__(self, unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance):
+    def __call__(
+        self, unconditional_context, context, latent, timesteps, alphas_cumprod, alphas_cumprod_prev, guidance, index
+    ):
+        timestep = tp.reshape(timesteps[index], (1,))
+        alphas = alphas_cumprod[index]
+        alphas_prev = alphas_cumprod_prev[index]
        e_t = self.get_model_output(unconditional_context, context, latent, timestep, guidance)
        x_prev, _ = self.get_x_prev_and_pred_x0(latent, e_t, alphas, alphas_prev)
        return x_prev
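
For reference, the eight InputInfos built in compile_unet (see example.py above) line up positionally with this new __call__ signature. A sketch of that mapping, where `steps` and `dtype` are assumed to be supplied by the caller as in tripy_diffusion:

import nvtripy as tp

# Shapes and dtypes are taken from the compile_unet diff above; the trailing
# comments name the __call__ parameter each input feeds.
inputs = (
    tp.InputInfo((1, 77, 768), dtype=dtype),    # unconditional_context
    tp.InputInfo((1, 77, 768), dtype=dtype),    # context
    tp.InputInfo((1, 4, 64, 64), dtype=dtype),  # latent
    tp.InputInfo((steps,), dtype=dtype),        # timesteps
    tp.InputInfo((steps,), dtype=dtype),        # alphas_cumprod
    tp.InputInfo((steps,), dtype=dtype),        # alphas_cumprod_prev
    tp.InputInfo((1,), dtype=dtype),            # guidance
    tp.InputInfo((1,), dtype=tp.int32),         # index (per-step selector)
)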

tripy/examples/diffusion/models/unet_model.py

Lines changed: 2 additions & 2 deletions

@@ -21,7 +21,7 @@
import nvtripy as tp
from dataclasses import dataclass

-from examples.diffusion.models.utils import scaled_dot_product_attention, Upsample, Downsample
+from models.utils import scaled_dot_product_attention, Upsample, Downsample


@dataclass
@@ -289,7 +289,7 @@ def __init__(self, config: UNetConfig):
            config.model_channels, config.io_channels, (3, 3), padding=((1, 1), (1, 1)), dtype=config.dtype
        )

-    def __call__(self, x, timesteps=None, context=None):
+    def __call__(self, x, timesteps=None, context=None, index=None):
        t_emb = timestep_embedding(timesteps, self.config.model_channels, self.config.dtype)
        emb = self.time_embedding(t_emb)
        x = self.conv_in(x)

tripy/examples/diffusion/models/vae_model.py

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@
import nvtripy as tp
from dataclasses import dataclass

-from examples.diffusion.models.utils import scaled_dot_product_attention, Upsample, Downsample
+from models.utils import scaled_dot_product_attention, Upsample, Downsample


@dataclass
