Skip to content

Commit 5a4c440

Browse files
committed
Generate memory pos embedding ahead of time
1 parent 2e0c74a commit 5a4c440

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

tripy/examples/segment-anything-model-v2/sam2/modeling/position_encoding.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ def __init__(
4848
if scale is None:
4949
scale = 2 * math.pi
5050
self.scale = scale
51+
self.static_memory_embed_shape = [2, 64, 64, 64]
52+
self.static_memory_pos_embedding = self.generate_static_embedding(self.static_memory_embed_shape, "float16")
5153

5254
def forward(self, x: tp.Tensor):
5355
# x: [B, C, H, W]
@@ -105,6 +107,8 @@ def generate_static_embedding(self, inp_shape, dtype=None):
105107
return pos.contiguous()
106108

107109
def generate_pos_embedding_torch(self, x):
110+
if list(x.shape) == self.static_memory_embed_shape:
111+
return self.static_memory_pos_embedding.to(x.dtype).contiguous()
108112
return self.generate_static_embedding(x.shape).to(x.dtype).contiguous()
109113

110114

0 commit comments

Comments
 (0)