
Commit 8c8e5c4

Rebase to TOT, small fixes
1 parent 2646d24 commit 8c8e5c4

4 files changed: +10 −16 lines changed

tripy/examples/diffusion/example.py

Lines changed: 3 additions & 2 deletions
@@ -83,6 +83,7 @@ def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=100
 def run_diffusion_loop(model, unconditional_context, context, latent, steps, guidance, dtype):
     np_type = np.float16 if dtype == tp.float16 else np.float32
     idx_timesteps = list(range(1, 1000, 1000 // steps))
+    timesteps = np.array(idx_timesteps, dtype=np_type)
     guidance = np.array([guidance], dtype=np_type)
 
     print(f"[I] Running diffusion for {steps} timesteps...")
@@ -91,12 +92,12 @@ def run_diffusion_loop(model, unconditional_context, context, latent, steps, gui
     guidance = tp.Tensor(guidance)
 
     model.stream = tp.Stream()
-    for index in tqdm(range(len(idx_timesteps))):
+    for index, timestep in (t := tqdm(list(enumerate(timesteps))[::-1])):
         latent = model(
             unconditional_context,
             context,
             latent,
-            tp.Tensor(np.array([idx_timesteps[index]], dtype=np_type)),
+            tp.Tensor(np.array([timestep])),
             tp.Tensor(alphas[index : index + 1]),
             tp.Tensor(alphas_prev[index : index + 1]),
             guidance,
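
In isolation, the new loop pattern works like this: enumerate pairs each timestep value with its index into the alphas/alphas_prev schedules, the [::-1] slice walks that list in reverse so sampling starts from the noisiest timestep, and the walrus assignment keeps a handle t on the tqdm bar for progress updates. A standalone sketch (steps = 5 and the description text are illustrative, not from the commit):

import numpy as np
from tqdm import tqdm

steps = 5  # illustrative value, not from the commit
idx_timesteps = list(range(1, 1000, 1000 // steps))  # [1, 201, 401, 601, 801]
timesteps = np.array(idx_timesteps, dtype=np.float32)

# Walk the schedule in reverse: sampling denoises from the noisiest
# timestep back toward timestep 1, while `index` still addresses the
# alphas[index : index + 1] slices in their original order.
for index, timestep in (t := tqdm(list(enumerate(timesteps))[::-1])):
    t.set_description(f"step {index}: t={timestep}")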

tripy/examples/diffusion/helper.py

Lines changed: 0 additions & 7 deletions
@@ -34,12 +34,5 @@ def scaled_dot_product_attention(
     return tp.cast(tp.softmax((qk + attn_mask) if attn_mask is not None else qk, -1), query.dtype) @ value
 
 
-def sequential(input: tp.Tensor, ll: List[Callable[[tp.Tensor], tp.Tensor]]):
-    """
-    Applies a sequence of functions to `self` chaining the output of each function to the input of the next.
-    """
-    return reduce(lambda x, f: f(x), ll, input)
-
-
 def clamp(tensor: tp.Tensor, min: int, max: int):
     return tp.minimum(tp.maximum(tensor, tp.ones_like(tensor) * min), tp.ones_like(tensor) * max)
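
For reference, the deleted sequential helper was a plain left-fold over a list of callables; its role is taken over by tp.Sequential in unet_model.py below. A self-contained sketch of the same behavior (generic names, not repo code):

from functools import reduce
from typing import Callable, List

def sequential(x, layers: List[Callable]):
    # Fold left: the output of each callable becomes the input of the next.
    return reduce(lambda acc, f: f(acc), layers, x)

# (x + 1) * 2 with x = 3 gives 8
assert sequential(3, [lambda v: v + 1, lambda v: v * 2]) == 8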

tripy/examples/diffusion/unet_model.py

Lines changed: 6 additions & 6 deletions
@@ -22,7 +22,7 @@
 import tripy as tp
 from dataclasses import dataclass
 
-from examples.diffusion.helper import scaled_dot_product_attention, sequential
+from examples.diffusion.helper import scaled_dot_product_attention
 from examples.diffusion.vae_model import Upsample, Downsample
 
 
@@ -70,7 +70,7 @@ def __init__(self, config: UNetConfig, query_dim, context_dim, n_heads, d_head):
         self.to_v = tp.Linear(context_dim, n_heads * d_head, bias=False, dtype=config.dtype)
         self.num_heads = n_heads
         self.head_size = d_head
-        self.to_out = [tp.Linear(n_heads * d_head, query_dim, dtype=config.dtype)]
+        self.to_out = tp.Sequential(tp.Linear(n_heads * d_head, query_dim, dtype=config.dtype),)
         self.dtype = config.dtype
 
     def __call__(self, x, context=None):
@@ -83,7 +83,7 @@ def __call__(self, x, context=None):
             scaled_dot_product_attention(q, k, v, embedding_dim=self.head_size, dtype=self.dtype), 1, 2
         )
         h_ = tp.reshape(attention, (x.shape[0], -1, self.num_heads * self.head_size))
-        out = sequential(h_, self.to_out)
+        out = self.to_out(h_)
         return out
 
 
@@ -108,14 +108,14 @@ def __call__(self, x):
 
 class FeedForward(tp.Module):
     def __init__(self, config: UNetConfig, dim, mult=4):
-        self.net = [
+        self.net = tp.Sequential(
             GEGLU(config, dim, dim * mult),
             Dummy(),  # Accounts for Dropout layer, needed for weight loading
             tp.Linear(dim * mult, dim, dtype=config.dtype),
-        ]
+        )
 
     def __call__(self, x):
-        return sequential(x, self.net)
+        return self.net(x)
 
 
 class BasicTransformerBlock(tp.Module):
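
The replacement idiom stores submodules on a container object and chains them when the container is called, rather than keeping them in a bare Python list threaded through a free function; a container also registers its children, which weight loading relies on (hence the no-op Dummy() standing in for Dropout). A minimal plain-Python sketch of the idiom, not tripy's actual tp.Sequential implementation:

class Sequential:
    """Minimal sequential container: calls its submodules in order."""

    def __init__(self, *modules):
        self.modules = modules  # held in call order

    def __call__(self, x):
        for module in self.modules:
            x = module(x)  # each output feeds the next module
        return x

# Mirrors FeedForward above: net = Sequential(GEGLU(...), Dummy(), Linear(...))
net = Sequential(lambda v: v + 1, lambda v: v * 2)
assert net(3) == 8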

tripy/examples/diffusion/weight_loader.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ def load_weights_from_hf(model, hf_model, dtype, debug=False):
         weight = hf_state_dict[key]
         if "norm" not in key:
             weight = weight.to(torch_dtype)
-        param = tp.Parameter(weight)
+        param = tp.Tensor(weight.contiguous())
         tripy_state_dict[key.removeprefix("text_model.")] = param
 
     model.load_state_dict(tripy_state_dict)
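
The added .contiguous() call guards against state-dict tensors that are non-contiguous views (e.g. transposed weights): wrapping such storage directly is unsafe if the consumer assumes a dense, row-major buffer. A stock-PyTorch illustration, independent of this repo:

import torch

w = torch.ones(4, 8).t()      # a transposed view: same storage, swapped strides
print(w.is_contiguous())      # False
dense = w.contiguous()        # materializes a dense, row-major copy
print(dense.is_contiguous())  # True: safe to hand off as a flat buffer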
