Commit 24a2deb

Adds support for FP16 in memory encoder
1 parent dd14966 commit 24a2deb

5 files changed: +19 -12 lines


tripy/examples/segment-anything-model-v2/configs/sam2_hiera_l.yaml

Lines changed: 3 additions & 0 deletions
@@ -62,6 +62,7 @@ model:
   memory_encoder:
       _target_: sam2.modeling.memory_encoder.MemoryEncoder
       out_dim: 64
+      dtype: float16
       position_encoding:
         _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
         num_pos_feats: 64
@@ -73,6 +74,7 @@ model:
         kernel_size: 3
         stride: 2
         padding: 1
+        dtype: float16
       fuser:
         _target_: sam2.modeling.memory_encoder.Fuser
         layer:
@@ -82,6 +84,7 @@ model:
           padding: 3
           layer_scale_init_value: 1e-6
           use_dwconv: True  # depth-wise convs
+          dtype: float16
         num_layers: 2
 
   num_maskmem: 7

tripy/examples/segment-anything-model-v2/configs/sam2_hiera_s.yaml

Lines changed: 3 additions & 0 deletions
@@ -63,6 +63,7 @@ model:
   memory_encoder:
       _target_: sam2.modeling.memory_encoder.MemoryEncoder
       out_dim: 64
+      dtype: float16
       position_encoding:
         _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
         num_pos_feats: 64
@@ -74,6 +75,7 @@ model:
         kernel_size: 3
         stride: 2
         padding: 1
+        dtype: float16
       fuser:
         _target_: sam2.modeling.memory_encoder.Fuser
         layer:
@@ -83,6 +85,7 @@ model:
           padding: 3
           layer_scale_init_value: 1e-6
           use_dwconv: True  # depth-wise convs
+          dtype: float16
         num_layers: 2
 
   num_maskmem: 7

tripy/examples/segment-anything-model-v2/configs/sam2_hiera_t.yaml

Lines changed: 3 additions & 0 deletions
@@ -63,6 +63,7 @@ model:
   memory_encoder:
       _target_: sam2.modeling.memory_encoder.MemoryEncoder
       out_dim: 64
+      dtype: float16
       position_encoding:
         _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
         num_pos_feats: 64
@@ -74,6 +75,7 @@ model:
         kernel_size: 3
         stride: 2
         padding: 1
+        dtype: float16
       fuser:
         _target_: sam2.modeling.memory_encoder.Fuser
         layer:
@@ -83,6 +85,7 @@ model:
           padding: 3
           layer_scale_init_value: 1e-6
           use_dwconv: True  # depth-wise convs
+          dtype: float16
         num_layers: 2
 
   num_maskmem: 7
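
All three Hiera configs gain the same `dtype: float16` field. Under Hydra, every key alongside `_target_` is forwarded to the target's constructor by `instantiate`, which is presumably how the new field reaches `MemoryEncoder`, `MaskDownSampler`, and `CXBlock`. A minimal sketch of that mechanism with a hypothetical stand-in class (the real constructors take more arguments than shown here):

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

class MemoryEncoder:  # hypothetical stand-in, not the real sam2 class
    def __init__(self, out_dim: int, dtype: str = "float32"):
        self.out_dim = out_dim
        self.dtype = dtype

cfg = OmegaConf.create({
    "_target_": "__main__.MemoryEncoder",
    "out_dim": 64,
    "dtype": "float16",
})

# Hydra imports the _target_ class and passes the remaining keys as kwargs.
encoder = instantiate(cfg)
print(encoder.dtype)  # float16
```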

tripy/examples/segment-anything-model-v2/sam2/build_sam.py

Lines changed: 5 additions & 10 deletions
@@ -23,8 +23,6 @@
 # limitations under the License.
 
 
-import logging
-
 import torch
 from hydra import compose
 from hydra.utils import instantiate
@@ -184,10 +182,10 @@ def get_component_configs(model, cfg):
         "memory_encoder": {
             "enabled": True,
             "model": model.memory_encoder,
-            "dtype": "float32",  # TODO add fp16 to yaml
+            "dtype": model_precision,
             "compile_args": [
-                tp.InputInfo((batchsize, 256, 64, 64), tp.float32),
-                tp.InputInfo((batchsize, num_obj, 1024, 1024), tp.float32),
+                tp.InputInfo((batchsize, 256, 64, 64), getattr(tp, model_precision)),
+                tp.InputInfo((batchsize, num_obj, 1024, 1024), getattr(tp, model_precision)),
                 True,
             ],
             "skip_dtype_convert": ["ln", "norm"]
@@ -227,10 +225,7 @@ def get_component_configs(model, cfg):
             "compile_args": [
                 tp.InputInfo(
                     (batchsize, 3, 1024, 1024),
-                    dtype=getattr(
-                        tp,
-                        model_precision,
-                    ),
+                    dtype=getattr(tp, model_precision),
                 ),
             ],
             "skip_dtype_convert": ["norm"],
@@ -285,7 +280,7 @@ def get_or_compile_component(self, comp_name: str, comp_info: Dict[str, Any]) ->
         else:
             print(f"Compiling {comp_name}...")
             start = time.time()
-            compiled_model = tp.compile(comp_info["model"], args=comp_info["compile_args"])
+            compiled_model = tp.compile(comp_info["model"], optimization_level=5, args=comp_info["compile_args"])
             print(f"Compilation took {time.time() - start:.2f}s")
             compiled_model.save(executable_file)
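
The compile call now pins `optimization_level=5`. A hedged sketch of the same call on a toy function, assuming `tp.compile` accepts plain functions as it does modules (the function and file name below are placeholders, not from the commit):

```python
import time
import tripy as tp

def double(x):
    return x * 2.0  # toy stand-in for a model component

start = time.time()
compiled = tp.compile(
    double,
    optimization_level=5,  # level pinned by this commit
    args=[tp.InputInfo((1, 256, 64, 64), tp.float16)],
)
print(f"Compilation took {time.time() - start:.2f}s")
compiled.save("double.tpymodel")  # hypothetical cache file, mirroring executable_file
```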

tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_base.py

Lines changed: 5 additions & 2 deletions
@@ -710,7 +710,7 @@ def _encode_new_memory(
         # scale the raw mask logits with a temperature before applying sigmoid
         binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
         if binarize and not self.training:
-            mask_for_mem = (pred_masks_high_res > 0).float()
+            mask_for_mem = pred_masks_high_res > 0
         else:
             # apply sigmoid on the raw mask logits to turn them into range (0, 1)
             mask_for_mem = torch.sigmoid(pred_masks_high_res)
@@ -720,8 +720,11 @@ def _encode_new_memory(
         if self.sigmoid_bias_for_mem_enc != 0.0:
             mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
 
+        if self.memory_encoder.input_infos["masks"].dtype == tp.float16:
+            mask_for_mem = mask_for_mem.half()
+
         maskmem_features, maskmem_pos_enc = self.memory_encoder(
-            tp.Tensor(pix_feat.float().contiguous()), tp.Tensor(mask_for_mem.contiguous())
+            tp.Tensor(pix_feat.contiguous()), tp.Tensor(mask_for_mem.contiguous())
         )  # sigmoid already applied
         maskmem_features = torch.from_dlpack(maskmem_features)
         maskmem_pos_enc = [torch.from_dlpack(maskmem_pos_enc)]
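
The new guard reads the dtype the compiled memory encoder declared for its `masks` input and casts the torch-side tensor to match before it crosses into Tripy; this matters because the binarize path above now yields a bool tensor rather than float32. A standalone sketch of the cast-to-expected-dtype pattern (the helper name is ours, not from the commit):

```python
import torch
import tripy as tp

def match_executable_dtype(t: torch.Tensor, expected) -> torch.Tensor:
    """Cast a torch tensor (e.g. a bool mask) to the dtype the compiled
    executable was built for, mirroring the input_infos check above."""
    if expected == tp.float16 and t.dtype != torch.float16:
        return t.half()
    return t

mask = torch.rand(1, 1, 1024, 1024) > 0           # bool mask, as in the binarize path
mask = match_executable_dtype(mask, tp.float16)   # -> float16 zeros/ones
wrapped = tp.Tensor(mask.contiguous())            # hand-off, as in the call above
```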
