Commit b5a9c9c

Small change to scale dot attn
1 parent 8c8e5c4 commit b5a9c9c

5 files changed (+10, -23 lines)

tripy/examples/diffusion/clip_model.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def __call__(self, hidden_states, causal_attention_mask):
             for x in (q, k, v)
         ]
         attn_output = scaled_dot_product_attention(
-            q, k, v, embedding_dim=self.head_dim, attn_mask=causal_attention_mask, dtype=self.dtype,
+            q, k, v, embedding_dim=self.head_dim, attn_mask=causal_attention_mask,
         )
         out = self.out_proj(tp.reshape(tp.transpose(attn_output, 1, 2), (bsz, tgt_len, embed_dim)))
         return out

tripy/examples/diffusion/helper.py

Lines changed: 6 additions & 19 deletions
@@ -1,6 +1,5 @@
 import math
-from functools import reduce
-from typing import List, Callable, Optional
+from typing import Optional
 
 import tripy as tp
 
@@ -9,29 +8,17 @@ def scaled_dot_product_attention(
     query: tp.Tensor,
     key: tp.Tensor,
     value: tp.Tensor,
-    embedding_dim: Optional[int] = None,
+    embedding_dim: int,
     attn_mask: Optional[tp.Tensor] = None,
-    is_causal: bool = False,
-    dtype: tp.dtype = tp.float32
 ) -> tp.Tensor:
-    """
-    Computes scaled dot-product attention.
-    `self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
-
-    - Described: https://paperswithcode.com/method/scaled
-    - Paper: https://arxiv.org/abs/1706.03762v7
-    """
-    if is_causal: # this path is not called in demoDiffusion
-        target_shape = query.shape[-2:-1] + key.shape[-2:-1]
-        # TODO: #228: WAR to prevent computing output rank in infer_rank for reshape
-        target_shape.trace_tensor.shape = (2,)
-        attn_mask = tp.cast(tp.tril(tp.ones(target_shape)), tp.bool)
+    dtype = query.dtype
     if attn_mask is not None and attn_mask.dtype == tp.bool:
         attn_mask = tp.where((attn_mask == 0), tp.ones_like(attn_mask, dtype=dtype) * -float("inf"), tp.zeros_like(attn_mask, dtype=dtype))
     if attn_mask is not None:
         attn_mask = tp.cast(attn_mask, dtype)
-    qk = query @ tp.transpose(key, -2, -1) / math.sqrt(embedding_dim)
-    return tp.cast(tp.softmax((qk + attn_mask) if attn_mask is not None else qk, -1), query.dtype) @ value
+    k_t = tp.transpose(key, -2, -1)
+    qk = (query @ k_t) * (1.0 / math.sqrt(embedding_dim))
+    return tp.softmax((qk + attn_mask) if attn_mask is not None else qk, -1) @ value
 
 
 def clamp(tensor: tp.Tensor, min: int, max: int):
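
For reference, a minimal usage sketch of the updated helper. The tensor shapes, the tp.ones construction, and the import path are illustrative assumptions, not part of this commit; the point is that callers now pass embedding_dim explicitly and no longer pass dtype=, since the helper derives the dtype from query.dtype.

import tripy as tp

from examples.diffusion.helper import scaled_dot_product_attention  # import path assumed

# Placeholder shapes: (batch, num_heads, seq_len, head_dim)
q = tp.ones((1, 8, 77, 64), dtype=tp.float32)
k = tp.ones((1, 8, 77, 64), dtype=tp.float32)
v = tp.ones((1, 8, 77, 64), dtype=tp.float32)

# embedding_dim is now required; the output dtype follows q's dtype.
out = scaled_dot_product_attention(q, k, v, embedding_dim=64)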

tripy/examples/diffusion/model.py

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ def decode(self, x):
         x = clamp(tp.permute(tp.reshape(x, (3, 512, 512)), (1, 2, 0)), 0, 1) * 255
         return x
 
-    def __call__(self, unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance):
+    def __call__(self, unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance):
         e_t = self.get_model_output(unconditional_context, context, latent, timestep, guidance)
         x_prev, _ = self.get_x_prev_and_pred_x0(latent, e_t, alphas, alphas_prev)
         return x_prev

tripy/examples/diffusion/unet_model.py

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ def __call__(self, x, context=None):
             tp.transpose(tp.reshape(y, (x.shape[0], -1, self.num_heads, self.head_size)), 1, 2) for y in (q, k, v)
         ]
         attention = tp.transpose(
-            scaled_dot_product_attention(q, k, v, embedding_dim=self.head_size, dtype=self.dtype), 1, 2
+            scaled_dot_product_attention(q, k, v, embedding_dim=self.head_size), 1, 2
         )
         h_ = tp.reshape(attention, (x.shape[0], -1, self.num_heads * self.head_size))
         out = self.to_out(h_)

tripy/examples/diffusion/vae_model.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def __call__(self, x):
         q, k, v = self.to_q(h_flat), self.to_k(h_flat), self.to_v(h_flat)
 
         # compute attention
-        h_ = scaled_dot_product_attention(q, k, v, embedding_dim=self.in_channels, dtype=self.dtype)
+        h_ = scaled_dot_product_attention(q, k, v, embedding_dim=self.in_channels)
         out = tp.reshape(
             tp.transpose(self.to_out[0](h_), 1, 2),
             (b, c, h, w),
