|
20 | 20 | import nvtripy as tp |
21 | 21 | from dataclasses import dataclass |
22 | 22 |
|
23 | | -from examples.diffusion.helper import scaled_dot_product_attention |
| 23 | +from examples.diffusion.models.utils import scaled_dot_product_attention, Upsample, Downsample |
24 | 24 |
|
25 | 25 |
|
26 | 26 | @dataclass |
@@ -80,24 +80,6 @@ def __call__(self, x): |
80 | 80 | return self.conv_shortcut(x) + h |
81 | 81 |
|
82 | 82 |
|
83 | | -class Downsample(tp.Module): |
84 | | - def __init__(self, config, channels): |
85 | | - self.conv = tp.Conv(channels, channels, (3, 3), stride=(2, 2), padding=((1, 1), (1, 1)), dtype=config.dtype) |
86 | | - |
87 | | - def __call__(self, x): |
88 | | - return self.conv(x) |
89 | | - |
90 | | - |
91 | | -class Upsample(tp.Module): |
92 | | - def __init__(self, config, channels): |
93 | | - self.conv = tp.Conv(channels, channels, (3, 3), padding=((1, 1), (1, 1)), dtype=config.dtype) |
94 | | - |
95 | | - def __call__(self, x): |
96 | | - bs, c, py, px = x.shape |
97 | | - x = tp.reshape(tp.expand(tp.reshape(x, (bs, c, py, 1, px, 1)), (bs, c, py, 2, px, 2)), (bs, c, py * 2, px * 2)) |
98 | | - return self.conv(x) |
99 | | - |
100 | | - |
101 | 83 | class UpDecoderBlock2D(tp.Module): |
102 | 84 | def __init__(self, config: VAEConfig, start_channels, channels, use_upsampler=True): |
103 | 85 | self.resnets = [ |
|
0 commit comments