Commit 0e51f4d

Improves attention mask implementation in SAMv2
1 parent e5fe27e commit 0e51f4d

File tree

1 file changed: +1 -5 lines changed
  • tripy/examples/segment-anything-model-v2/sam2/modeling


tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_utils.py

Lines changed: 1 addition & 5 deletions
@@ -108,11 +108,7 @@ def scaled_dot_product_attention(
         target_shape.trace_tensor.shape = (2,)
         attn_mask = tp.cast(tp.tril(tp.ones(target_shape)), tp.bool)
     if attn_mask is not None and attn_mask.dtype == tp.bool:
-        attn_mask = tp.where(
-            (attn_mask == 0),
-            tp.ones_like(attn_mask) * -float("inf"),
-            tp.zeros_like(attn_mask),
-        )
+        attn_mask = tp.where((attn_mask == 0), tp.cast(tp.Tensor(-float("inf")), dtype=query.dtype), 0.0)
     if embedding_dim is None:
         embedding_dim = query.shape[-1]
     qk = query @ tp.transpose(key, -2, -1) / tp.sqrt(tp.cast(embedding_dim, query.dtype))
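
For context: the new one-liner replaces the old five-line tp.where, which materialized full ones_like/zeros_like tensors, with scalar fill values that broadcast, and it casts the -inf fill to the query's dtype rather than inheriting the mask's. Below is a minimal NumPy sketch of the underlying boolean-to-additive-mask technique (NumPy stands in for Tripy here; the helper name additive_mask_from_bool is illustrative, not from the commit):

import numpy as np

def additive_mask_from_bool(mask, dtype=np.float32):
    # Where the boolean mask is False (a disallowed position), fill with -inf so
    # softmax assigns zero weight; elsewhere fill with 0.0 so scores pass through.
    # The scalar branches broadcast, mirroring the one-line tp.where in the diff.
    return np.where(mask == 0, dtype(-np.inf), dtype(0.0))

# A 4x4 causal (lower-triangular) mask, analogous to the tp.tril call above.
bool_mask = np.tril(np.ones((4, 4))).astype(bool)
add_mask = additive_mask_from_bool(bool_mask)

# The additive mask is added to the raw attention scores before softmax:
scores = np.random.rand(4, 4).astype(np.float32) + add_mask
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)  # each row attends only to past positions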
