Commit b634fd8

Updates SAMv2 to name dimensions in order to trigger MHA fusions
1 parent 3c757f6 commit b634fd8
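
The idea behind the change: when the same dynamic dimension is described by one shared, named object rather than an anonymous (min, opt, max) tuple, the compiler can see that the attention inputs agree on that dimension at runtime, which lets it fuse the attention ops into a single multi-head attention (MHA) kernel. A minimal sketch of the pattern, reusing the tp.NamedDimension and tp.InputInfo calls from the diff below (the float16 dtype and the specific shapes here are illustrative, not taken from the commit):

import nvtripy as tp

# One NamedDimension object, reused across every input that shares the
# dimension: a name plus (min, opt, max) profile values.
batch = tp.NamedDimension("batch", 1, 2, 8)

# Because q/k/v all reference the same `batch` object, the compiler knows
# these dimensions are equal. With anonymous ranges like (1, 2, 8) it
# cannot assume that, which can block the MHA fusion.
q = tp.InputInfo((4096, batch, 256), tp.float16)
k = tp.InputInfo((4096, batch, 256), tp.float16)
v = tp.InputInfo((4096, batch, 256), tp.float16)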

File tree

3 files changed: +29 −36 lines changed

tripy/examples/segment-anything-model-v2/sam2/build_sam.py

Lines changed: 20 additions & 18 deletions
@@ -54,8 +54,10 @@ def get_component_configs(model, cfg):
     """
     Get configurations for different components, including both compilation and weight loading info.
     """
-    batchsize = (1, 2, 4)
-    num_obj = (1, 2, 4)
+    batch = tp.NamedDimension("batch", 1, 2, 4)
+    num_obj = tp.NamedDimension("num_obj", 1, 2, 4)
+    seq_len = tp.NamedDimension("seq_len", 4100, 16400, 28736)
+    mem_attention_batch = tp.NamedDimension("mem_attention_batch", 1, 2, 8)
     model_precision = getattr(cfg["model"], "model_precision", "float32")
     return {
         "memory_attention": {
@@ -64,19 +66,19 @@ def get_component_configs(model, cfg):
             "dtype": model_precision,
             "compile_args": [
                 tp.InputInfo(
-                    (4096, (1, 2, 8), 256),
+                    (4096, mem_attention_batch, 256),
                     getattr(tp, model_precision),
                 ),
                 tp.InputInfo(
-                    ((4100, 16400, 28736), (1, 2, 8), 64),
+                    (seq_len, mem_attention_batch, 64),
                     getattr(tp, model_precision),
                 ),
                 tp.InputInfo(
-                    (4096, (1, 2, 8), 256),
+                    (4096, mem_attention_batch, 256),
                     getattr(tp, model_precision),
                 ),
                 tp.InputInfo(
-                    ((4100, 16400, 28736), (1, 2, 8), 64),
+                    (seq_len, mem_attention_batch, 64),
                     getattr(tp, model_precision),
                 ),
                 # TODO (#594): Remove this hack once we are able to pass in DimensionSizes directly:
@@ -124,29 +126,29 @@ def get_component_configs(model, cfg):
             "dtype": model_precision,
             "compile_args": [
                 tp.InputInfo(
-                    (batchsize, 256, 64, 64),
+                    (batch, 256, 64, 64),
                     dtype=getattr(tp, model_precision),
                 ),  # image_embeddings
                 tp.InputInfo(
                     (1, 256, 64, 64),
                     dtype=getattr(tp, model_precision),
                 ),  # image_pe
                 tp.InputInfo(
-                    (batchsize, (2, 4, 6), 256),
+                    (batch, (2, 4, 6), 256),
                     dtype=getattr(tp, model_precision),
                 ),  # sparse_prompt_embeddings
                 tp.InputInfo(
-                    (batchsize, 256, 64, 64),
+                    (batch, 256, 64, 64),
                     dtype=getattr(tp, model_precision),
                 ),  # dense_prompt_embeddings
                 True,  # multimask_output
                 False,  # repeat_image
                 tp.InputInfo(
-                    (batchsize, 32, 256, 256),
+                    (batch, 32, 256, 256),
                     dtype=getattr(tp, model_precision),
                 ),  # high_res_features_1
                 tp.InputInfo(
-                    (batchsize, 64, 128, 128),
+                    (batch, 64, 128, 128),
                     dtype=getattr(tp, model_precision),
                 ),  # high_res_features_2
             ],
@@ -159,7 +161,7 @@ def get_component_configs(model, cfg):
             "dtype": model_precision,
             "compile_args": [
                 tp.InputInfo(
-                    (batchsize, 256, 256, 256),
+                    (batch, 256, 256, 256),
                     dtype=getattr(tp, model_precision),
                 )
             ],
@@ -172,7 +174,7 @@ def get_component_configs(model, cfg):
             "dtype": model_precision,
             "compile_args": [
                 tp.InputInfo(
-                    (batchsize, 256, 128, 128),
+                    (batch, 256, 128, 128),
                     dtype=getattr(tp, model_precision),
                 )
             ],
@@ -184,8 +186,8 @@ def get_component_configs(model, cfg):
             "model": model.memory_encoder,
             "dtype": model_precision,
             "compile_args": [
-                tp.InputInfo((batchsize, 256, 64, 64), getattr(tp, model_precision)),
-                tp.InputInfo((batchsize, num_obj, 1024, 1024), getattr(tp, model_precision)),
+                tp.InputInfo((batch, 256, 64, 64), getattr(tp, model_precision)),
+                tp.InputInfo((batch, num_obj, 1024, 1024), getattr(tp, model_precision)),
                 True,
             ],
             "skip_dtype_convert": ["ln", "norm"]
@@ -196,8 +198,8 @@ def get_component_configs(model, cfg):
             "model": model.sam_prompt_encoder,
             "dtype": "float32",
             "compile_args": [
-                tp.InputInfo((batchsize, num_obj, 2), dtype=tp.float32),
-                tp.InputInfo((batchsize, num_obj), dtype=tp.int32),
+                tp.InputInfo((batch, num_obj, 2), dtype=tp.float32),
+                tp.InputInfo((batch, num_obj), dtype=tp.int32),
                 None,
                 None,
             ],
@@ -224,7 +226,7 @@ def get_component_configs(model, cfg):
             "dtype": model_precision,
             "compile_args": [
                 tp.InputInfo(
-                    (batchsize, 3, 1024, 1024),
+                    (batch, 3, 1024, 1024),
                     dtype=getattr(tp, model_precision),
                 ),
             ],
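
These compile_args are what build_sam.py later feeds to the compiler for each component. A rough sketch of that consumption, assuming nvtripy's tp.compile(func, args=...) signature and that non-InputInfo entries (True, None, ...) are accepted as baked-in constants; the loop below is illustrative, not the file's actual code:

# Illustrative only: compile each component with its recorded InputInfos.
# Reusing NamedDimension objects such as `batch` across components keeps
# the dynamic-dimension equality visible to every compiled function.
compiled = {}
for name, config in get_component_configs(model, cfg).items():
    compiled[name] = tp.compile(config["model"], args=config["compile_args"])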

tripy/examples/segment-anything-model-v2/sam2/modeling/memory_attention.py

Lines changed: 3 additions & 8 deletions
@@ -22,16 +22,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, List
-
-from sam2.modeling.sam.transformer import RoPEAttention
-from sam2.modeling.sam2_utils import get_activation_fn
+from typing import List, Optional
 
 import nvtripy as tp
+from sam2.modeling.sam2_utils import get_activation_fn
+from sam2.modeling.sam.transformer import RoPEAttention
 
 
 class MemoryAttentionLayer(tp.Module):
-
     def __init__(
         self,
         activation: str,
@@ -77,8 +75,6 @@ def _forward_sa(self, tgt, query_pos):
         return tgt
 
     def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
-        kwds = {}
-
         # Cross-Attention
         tgt2 = tp.cast(self.norm2(tp.cast(tgt, self.norm2.dtype)), self.dtype)
 
@@ -112,7 +108,6 @@ def forward(
 
 
 class MemoryAttention(tp.Module):
-
     def __init__(
         self,
         d_model: int,

tripy/examples/segment-anything-model-v2/video_demo.py

Lines changed: 6 additions & 10 deletions
@@ -42,19 +42,15 @@
 
 
 def compute_mask_properties(mask):
-    # Ensure we have a boolean array
-    test_mask = np.asarray(mask, dtype=bool)
-
-    # Calculate basic stats
-    volume = np.sum(test_mask)
+    volume = torch.sum(mask)
 
     # Calculate centroid (center of mass)
     if volume > 0:
-        indices = np.where(test_mask)
-        centroid = tuple(float(np.mean(idx)) for idx in indices)
+        indices = torch.where(mask)
+        centroid = tuple((torch.sum(idx) / float(len(idx))).item() for idx in indices)
     else:
         centroid = None
-    return volume, centroid
+    return volume.item(), centroid
 
 
 def main(video_dir: str, save_path: Optional[str] = None):
@@ -161,7 +157,7 @@ def make_tensors_contiguous(d):
     video_segments = {}  # video_segments contains the per-frame segmentation results
     for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
         video_segments[out_frame_idx] = {
-            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() for i, out_obj_id in enumerate(out_obj_ids)
+            out_obj_id: (out_mask_logits[i] > 0.0) for i, out_obj_id in enumerate(out_obj_ids)
         }
     end = time.perf_counter()
     print(f"Video segmentation took {(end - start)}s")
@@ -175,7 +171,7 @@ def make_tensors_contiguous(d):
         plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
         for out_obj_id, out_mask in video_segments[out_frame_idx].items():
             vol, centre = compute_mask_properties(out_mask)
-            show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
+            show_mask(out_mask.cpu().numpy(), plt.gca(), obj_id=out_obj_id)
         plt.savefig(os.path.join(save_path, f"video_final_mask_{out_frame_idx}.png"))
 
     # Print the properties of the mask generated for the final image for integration testing.
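
As a quick sanity check of the rewritten torch-based helper, here is a toy invocation (hypothetical values; a CPU tensor is used for simplicity, whereas the demo keeps masks on the GPU until plotting):

import torch

# 4x4 boolean mask with a 2x2 "object" in the top-left corner.
mask = torch.zeros((4, 4), dtype=torch.bool)
mask[0:2, 0:2] = True

volume, centroid = compute_mask_properties(mask)
print(volume)    # 4
print(centroid)  # (0.5, 0.5): mean row index, mean column index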
