Commit 5c95436

Apply DimensionInputInfo to SAM sample
1 parent 43b1de6 commit 5c95436

3 files changed, +3 -10 lines

tripy/examples/segment-anything-model-v2/sam2/build_sam.py

Lines changed: 1 addition & 2 deletions
@@ -81,8 +81,7 @@ def get_component_configs(model, cfg):
                 (seq_len, mem_attention_batch, 64),
                 getattr(tp, model_precision),
             ),
-            # TODO (#594): Remove this hack once we are able to pass in DimensionSizes directly:
-            tp.InputInfo(((4, 16, 64),), tp.int32),
+            tp.DimensionInputInfo(value_bounds=(4, 16, 64)),
         ],
         "skip_dtype_convert": [],
     },
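For context, a minimal sketch (not part of this commit) of how a DimensionInputInfo entry is consumed at compile time; the tp.compile call, its args keyword, and the import alias are assumptions based on how the sample already uses Tripy:

import tripy as tp  # assumed alias, matching the sample's existing usage

def takes_a_size(x: tp.Tensor, n: tp.DimensionSize) -> tp.Tensor:
    # The body is immaterial for this sketch; the point is that `n` arrives as a
    # DimensionSize rather than as a 1-D int32 tensor whose length encodes the value.
    return tp.reshape(x, (n, 64))

compiled = tp.compile(
    takes_a_size,
    args=[
        tp.InputInfo(((4, 16, 64), 64), tp.float16),      # tensor input with a dynamic leading dim
        tp.DimensionInputInfo(value_bounds=(4, 16, 64)),  # min/opt/max bounds for the size argument
    ],
)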

tripy/examples/segment-anything-model-v2/sam2/modeling/memory_attention.py

Lines changed: 1 addition & 3 deletions
@@ -186,10 +186,8 @@ def forward(
         memory: tp.Tensor,  # cross-attention inputs
         curr_pos: Optional[tp.Tensor] = None,  # pos_enc for self-attention inputs
         memory_pos: Optional[tp.Tensor] = None,  # pos_enc for cross-attention inputs
-        num_obj_ptr_tokens: Optional[tp.Tensor] = None,  # number of object pointer *tokens*
+        num_obj_ptr_tokens: Optional[tp.DimensionSize] = None,  # number of object pointer *tokens*
     ):
-        # TODO (#594): Remove this hack once we are able to pass in DimensionSizes directly:
-        num_obj_ptr_tokens = num_obj_ptr_tokens.shape[0]
         output = curr
         if self.pos_enc_at_input and curr_pos is not None:
             output = output + 0.1 * curr_pos
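A minimal, illustrative sketch of what the signature change means inside the function (the real forward() goes on to run the attention layers; names here are placeholders):

import tripy as tp  # assumed alias
from typing import Optional

def forward_sketch(curr: tp.Tensor, num_obj_ptr_tokens: Optional[tp.DimensionSize] = None) -> tp.Tensor:
    # Previously the count was smuggled in as a 1-D int32 tensor and recovered with
    #     num_obj_ptr_tokens = num_obj_ptr_tokens.shape[0]
    # With the new signature it is already a DimensionSize and can be used directly
    # wherever the attention layers need the object-pointer token count.
    return curr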

tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_base.py

Lines changed: 1 addition & 5 deletions
@@ -242,8 +242,6 @@ def _build_sam_heads(self):
         else:
             self.obj_ptr_tpos_proj = torch.nn.Identity()
 
-        self.fake_object_ptrs = torch.ones((1,), dtype=torch.int32, device="cuda")
-
     def _forward_sam_heads(
         self,
         backbone_features,
@@ -667,14 +665,12 @@ def _prepare_memory_conditioned_features(
         memory = torch.cat(to_cat_memory, dim=0)
         memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
         if isinstance(self.memory_attention, tp.Module) or isinstance(self.memory_attention, tp.Executable):
-            if self.fake_object_ptrs.shape != (num_obj_ptr_tokens,):
-                self.fake_object_ptrs = torch.ones((num_obj_ptr_tokens,), dtype=torch.int32, device="cuda")
             pix_feat_with_mem = self.memory_attention(
                 curr=tp.Tensor(current_vision_feats[0].half().contiguous()),
                 memory=tp.Tensor(memory.half().contiguous()),
                 curr_pos=tp.Tensor(current_vision_pos_embeds[0].half().contiguous()),
                 memory_pos=tp.Tensor(memory_pos_embed.half().contiguous()),
-                num_obj_ptr_tokens=tp.Tensor(self.fake_object_ptrs),
+                num_obj_ptr_tokens=tp.DimensionSize(num_obj_ptr_tokens),
             )
         else:
             pix_feat_with_mem = self.memory_attention(
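At the call site, the change boils down to wrapping a plain Python integer instead of allocating a placeholder CUDA tensor. A minimal sketch of that boundary (import aliases are assumptions matching the sample; the old-style lines mirror the code removed above):

import torch
import tripy as tp  # assumed alias

num_obj_ptr_tokens = 16  # plain Python int computed on the PyTorch side

# Old hack: allocate a CUDA tensor of that length solely so the compiled module
# could recover the count from the tensor's shape inside forward().
fake_object_ptrs = torch.ones((num_obj_ptr_tokens,), dtype=torch.int32, device="cuda")
old_style_arg = tp.Tensor(fake_object_ptrs)

# New approach: wrap the integer directly; no placeholder tensor is needed.
new_style_arg = tp.DimensionSize(num_obj_ptr_tokens)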
