Skip to content

Commit 6632d00

Browse files
committed
more review fixes
1 parent 0cbac6b commit 6632d00

File tree

6 files changed

+25
-28
lines changed

6 files changed

+25
-28
lines changed

tripy/examples/diffusion/example.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def compile_model(model, inputs, engine_path, verbose=False):
4747
if verbose:
4848
compile_end_time = time.perf_counter()
4949
print(f"saved engine to {engine_path}.")
50-
print(f"took {compile_end_time - compile_start_time} seconds.")
50+
print(f"Took {compile_end_time - compile_start_time} seconds.")
5151

5252
return compiled_model
5353

@@ -133,9 +133,7 @@ def save_image(image, args):
133133

134134
# Save image
135135
print(f"[I] Saving image to {filename}")
136-
if not os.path.isdir(os.path.dirname(filename)):
137-
print(f"[I] Creating '{os.path.dirname(filename)}' directory.")
138-
os.makedirs(os.path.dirname(filename))
136+
os.makedirs(os.path.dirname(filename), exist_ok=True)
139137
image.save(filename)
140138

141139

tripy/examples/diffusion/models/clip_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from dataclasses import dataclass
2121

22-
from examples.diffusion.helper import scaled_dot_product_attention
22+
from examples.diffusion.models.utils import scaled_dot_product_attention
2323

2424

2525
@dataclass

tripy/examples/diffusion/models/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@
2626
from examples.diffusion.models.clip_model import CLIPTextTransformer, CLIPConfig
2727
from examples.diffusion.models.unet_model import UNetModel, UNetConfig
2828
from examples.diffusion.models.vae_model import AutoencoderKL, VAEConfig
29-
from examples.diffusion.helper import clamp
29+
from examples.diffusion.models.utils import clamp
3030

3131

3232
@dataclass
3333
class StableDiffusionConfig:
34-
dtype: tp.dtype = tp.float32
34+
dtype: tp.dtype
3535
clip_config: Optional[CLIPConfig] = field(default=None, init=False)
3636
unet_config: Optional[UNetConfig] = field(default=None, init=False)
3737
vae_config: Optional[VAEConfig] = field(default=None, init=False)

tripy/examples/diffusion/models/unet_model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@
2121
import nvtripy as tp
2222
from dataclasses import dataclass
2323

24-
from examples.diffusion.helper import scaled_dot_product_attention
25-
from examples.diffusion.models.vae_model import Upsample, Downsample
24+
from examples.diffusion.models.utils import scaled_dot_product_attention, Upsample, Downsample
2625

2726

2827
@dataclass

tripy/examples/diffusion/helper.py renamed to tripy/examples/diffusion/models/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,21 @@ def scaled_dot_product_attention(
3838

3939
def clamp(tensor: tp.Tensor, min: int, max: int):
4040
return tp.minimum(tp.maximum(tensor, min), max)
41+
42+
43+
class Upsample(tp.Module):
44+
def __init__(self, config, channels):
45+
self.conv = tp.Conv(channels, channels, (3, 3), padding=((1, 1), (1, 1)), dtype=config.dtype)
46+
47+
def __call__(self, x):
48+
bs, c, py, px = x.shape
49+
x = tp.reshape(tp.expand(tp.reshape(x, (bs, c, py, 1, px, 1)), (bs, c, py, 2, px, 2)), (bs, c, py * 2, px * 2))
50+
return self.conv(x)
51+
52+
53+
class Downsample(tp.Module):
54+
def __init__(self, config, channels):
55+
self.conv = tp.Conv(channels, channels, (3, 3), stride=(2, 2), padding=((1, 1), (1, 1)), dtype=config.dtype)
56+
57+
def __call__(self, x):
58+
return self.conv(x)

tripy/examples/diffusion/models/vae_model.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import nvtripy as tp
2121
from dataclasses import dataclass
2222

23-
from examples.diffusion.helper import scaled_dot_product_attention
23+
from examples.diffusion.models.utils import scaled_dot_product_attention, Upsample, Downsample
2424

2525

2626
@dataclass
@@ -80,24 +80,6 @@ def __call__(self, x):
8080
return self.conv_shortcut(x) + h
8181

8282

83-
class Downsample(tp.Module):
84-
def __init__(self, config, channels):
85-
self.conv = tp.Conv(channels, channels, (3, 3), stride=(2, 2), padding=((1, 1), (1, 1)), dtype=config.dtype)
86-
87-
def __call__(self, x):
88-
return self.conv(x)
89-
90-
91-
class Upsample(tp.Module):
92-
def __init__(self, config, channels):
93-
self.conv = tp.Conv(channels, channels, (3, 3), padding=((1, 1), (1, 1)), dtype=config.dtype)
94-
95-
def __call__(self, x):
96-
bs, c, py, px = x.shape
97-
x = tp.reshape(tp.expand(tp.reshape(x, (bs, c, py, 1, px, 1)), (bs, c, py, 2, px, 2)), (bs, c, py * 2, px * 2))
98-
return self.conv(x)
99-
100-
10183
class UpDecoderBlock2D(tp.Module):
10284
def __init__(self, config: VAEConfig, start_channels, channels, use_upsampler=True):
10385
self.resnets = [

0 commit comments

Comments
 (0)