Skip to content

Commit 2e0c74a

Browse files
committed
Use torch to generate pos embeddings
1 parent 3ae73ea commit 2e0c74a

File tree

4 files changed

+13
-8
lines changed

4 files changed

+13
-8
lines changed

tripy/examples/segment-anything-model-v2/sam2/modeling/backbones/image_encoder.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -120,7 +120,8 @@ def __init__(
120120
self.position_encoding = []
121121
position_encoding_shapes = [[256, 256], [128, 128], [64, 64], [32, 32]]
122122
for s in position_encoding_shapes:
123-
self.position_encoding.append(position_encoding.generate_static_embedding([1, 256] + s, dtype=dtype))
123+
embed = position_encoding.generate_static_embedding([1, 256] + s, dtype=dtype)
124+
self.position_encoding.append(tp.Tensor(embed))
124125

125126
def __call__(self, xs: List[tp.Tensor]):
126127

tripy/examples/segment-anything-model-v2/sam2/modeling/memory_encoder.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -179,6 +179,4 @@ def forward(
179179
x = self.fuser(x)
180180
x = self.out_proj(x)
181181

182-
pos = tp.cast(self.position_encoding(x), x.dtype)
183-
184-
return x, pos
182+
return x

tripy/examples/segment-anything-model-v2/sam2/modeling/position_encoding.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -80,7 +80,7 @@ def forward(self, x: tp.Tensor):
8080
pos = tp.permute(pos, (0, 3, 1, 2))
8181
return pos
8282

83-
def generate_static_embedding(self, inp_shape, dtype):
83+
def generate_static_embedding(self, inp_shape, dtype=None):
8484
import torch
8585

8686
B, _, H, W = inp_shape
@@ -100,7 +100,12 @@ def generate_static_embedding(self, inp_shape, dtype):
100100
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
101101
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
102102
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
103-
return tp.Tensor(pos.to(getattr(torch, dtype)).contiguous())
103+
if dtype is not None:
104+
pos = pos.to(getattr(torch, dtype))
105+
return pos.contiguous()
106+
107+
def generate_pos_embedding_torch(self, x):
108+
return self.generate_static_embedding(x.shape).to(x.dtype).contiguous()
104109

105110

106111
class PositionEmbeddingRandom(tp.Module):

tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_base.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -139,6 +139,7 @@ def __init__(
139139

140140
# Part 3: memory encoder for the previous frame's outputs
141141
self.memory_encoder = memory_encoder
142+
self.position_encoder = self.memory_encoder.position_encoding
142143
self.mem_dim = self.hidden_dim
143144
if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
144145
# if there is compression of memories along channel dim
@@ -720,11 +721,11 @@ def _encode_new_memory(
720721
if self.sigmoid_bias_for_mem_enc != 0.0:
721722
mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
722723

723-
maskmem_features, maskmem_pos_enc = self.memory_encoder(
724+
maskmem_features = self.memory_encoder(
724725
tp.Tensor(pix_feat.float().contiguous()), tp.Tensor(mask_for_mem.contiguous())
725726
) # sigmoid already applied
726727
maskmem_features = torch.from_dlpack(maskmem_features)
727-
maskmem_pos_enc = [torch.from_dlpack(maskmem_pos_enc)]
728+
maskmem_pos_enc = [self.position_encoder.generate_pos_embedding_torch(maskmem_features)]
728729

729730
return maskmem_features, maskmem_pos_enc
730731

0 commit comments

Comments
 (0)