
SAM3 does not support custom inference resolutions #42331

@Kallinteris-Andreas

Description


System Info

Note: I am running the latest git version. System info should not be relevant to this issue, and transformers env currently crashes:
$ transformers env
Traceback (most recent call last):
  File "/home/master-andreas/panopticon/test_env/bin/transformers", line 3, in <module>
    from transformers.cli.transformers import main
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/transformers/cli/transformers.py", line 23, in <module>
    from transformers.cli.serve import Serve
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/transformers/cli/serve.py", line 351, in <module>
    class Serve:
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/transformers/cli/serve.py", line 658, in Serve
    ) -> ChatCompletionChunk:
         ^^^^^^^^^^^^^^^^^^^
NameError: name 'ChatCompletionChunk' is not defined
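Since transformers env crashes here, roughly the same information can be collected by hand. This is a minimal stdlib-only sketch; the package list is an assumption about what matters for this report, not a copy of what transformers env itself prints.

```python
import platform
import sys
from importlib.metadata import PackageNotFoundError, version


def collect_env(packages=("transformers", "torch", "Pillow", "huggingface_hub")):
    """Return Python/platform info plus the installed version of each package."""
    info = {
        "python": sys.version.split()[0],
        "platform": platform.platform(),
    }
    for name in packages:
        try:
            info[name] = version(name)
        except PackageNotFoundError:
            # The package is not installed in this environment
            info[name] = "not installed"
    return info


if __name__ == "__main__":
    for key, value in collect_env().items():
        print(f"- {key}: {value}")
```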

Who can help?

@yonigozlan

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

"""
Test script for SAM3 text prompting only.
This script demonstrates how to use SAM3 for text-based segmentation on images.
"""

import torch
from PIL import Image
import requests
from transformers import Sam3Processor, Sam3Model


INFERENCE_RESOLUTION = (1008, 1008)  # Other sizes fail, e.g. (1400, 1400) below
# INFERENCE_RESOLUTION = (1400, 1400)


def test_sam3_text_prompting():
    """Test SAM3 with text prompting on a sample image."""

    # Set device
    device = "cpu"
    print(f"Using device: {device}")

    # Load model and processor
    print("Loading SAM3 model and processor...")
    model = Sam3Model.from_pretrained("facebook/sam3").to(device)
    processor = Sam3Processor.from_pretrained("facebook/sam3")

    # Load a sample image
    print("Loading sample image...")
    image_url = "http://images.cocodataset.org/val2017/000000077595.jpg"
    image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")

    # Define text prompts to test
    text_prompts = ["cat", "ear", "eye"]

    for text_prompt in text_prompts:
        print(f"\nTesting text prompt: '{text_prompt}'")

        # Prepare inputs
        inputs = processor(images=image, text=text_prompt, size=INFERENCE_RESOLUTION, return_tensors="pt").to(device)

        # Run inference
        with torch.no_grad():
            outputs = model(**inputs)

        # Post-process results
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=0.5,
            mask_threshold=0.5,
            target_sizes=inputs.get("original_sizes").tolist()
        )[0]

        # Display results
        num_objects = len(results['masks'])
        print(f"Found {num_objects} objects matching '{text_prompt}'")

        if num_objects > 0:
            # Show scores for first few objects
            scores = results['scores']
            print(f"Confidence scores: {scores[:3].tolist()}")

            # Show bounding boxes for first object
            if 'boxes' in results and len(results['boxes']) > 0:
                box = results['boxes'][0]
                print(f"First object bounding box (xyxy): {box.tolist()}")


if __name__ == "__main__":
    print("SAM3 Text Prompting Test Script")
    print("=" * 40)

    try:
        test_sam3_text_prompting()
        print("\n✓ All tests completed successfully!")

    except Exception as e:
        print(f"\n✗ Test failed with error: {e}")
        raise

Output with INFERENCE_RESOLUTION = (1400, 1400):

$ py test_sam3_text.py
SAM3 Text Prompting Test Script
========================================
Using device: cpu
Loading SAM3 model and processor...
Loading weights: 100%|| 1468/1468 [00:00<00:00, 2709.52it/s, Materializing param=vision_encoder.neck.fpn
Loading sample image...

Testing text prompt: 'cat'

✗ Test failed with error: The size of tensor a (10000) must match the size of tensor b (5184) at non-singleton dimension 2
Traceback (most recent call last):
  File "/home/master-andreas/panopticon/test_sam3_text.py", line 124, in <module>
    test_sam3_text_prompting()
  File "/home/master-andreas/panopticon/test_sam3_text.py", line 48, in test_sam3_text_prompting
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/transformers/utils/generic.py", line 938, in wrapper
    outputs = func(self, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/transformers/models/sam3/modeling_sam3.py", line 2264, in forward
    vision_outputs = self.vision_encoder(pixel_values, **kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/transformers/utils/generic.py", line 938, in wrapper
    outputs = func(self, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/transformers/models/sam3/modeling_sam3.py", line 1014, in forward
    backbone_output = self.backbone(pixel_values, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/transformers/utils/generic.py", line 938, in wrapper
    outputs = func(self, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/transformers/models/sam3/modeling_sam3.py", line 803, in forward
    hidden_states = layer(hidden_states, **kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/transformers/modeling_layers.py", line 94, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/transformers/models/sam3/modeling_sam3.py", line 734, in forward
    hidden_states, _ = self.attention(hidden_states, position_embeddings, **kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/transformers/models/sam3/modeling_sam3.py", line 500, in forward
    query, key = apply_rotary_pos_emb_2d(query, key, cos=cos, sin=sin)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/master-andreas/panopticon/test_env/lib/python3.12/site-packages/transformers/models/sam3/modeling_sam3.py", line 461, in apply_rotary_pos_emb_2d
    q_embed = (q_embed * cos) + (rotate_pairwise(q_embed) * sin)
               ~~~~~~~~^~~~~
RuntimeError: The size of tensor a (10000) must match the size of tensor b (5184) at non-singleton dimension 2
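The two sizes in this error follow from patch-grid arithmetic: with a patch size of 14, a 1400x1400 input gives a 100 × 100 = 10000-token grid, while 1008x1008 gives 72 × 72 = 5184. Until this is fixed, a caller could fail early with a clearer message. The helpers below are hypothetical (not part of transformers), and the 5184-token limit is an assumption read off this traceback.

```python
PATCH_SIZE = 14
# Assumption: number of positions the current SAM3 rotary tables cover,
# taken from the traceback (72 x 72 patches at 1008x1008).
SUPPORTED_NUM_PATCHES = 5184


def num_patch_tokens(height, width, patch_size=PATCH_SIZE):
    """Number of ViT patch tokens for a height x width input."""
    if height % patch_size or width % patch_size:
        raise ValueError(f"{height}x{width} is not a multiple of the patch size {patch_size}")
    return (height // patch_size) * (width // patch_size)


def check_sam3_resolution(height, width):
    """Hypothetical guard: raise a readable error before calling the model."""
    tokens = num_patch_tokens(height, width)
    if tokens != SUPPORTED_NUM_PATCHES:
        raise ValueError(
            f"{height}x{width} gives {tokens} patch tokens, but the rotary "
            f"embeddings only cover {SUPPORTED_NUM_PATCHES}"
        )
    return tokens
```

Calling check_sam3_resolution(*INFERENCE_RESOLUTION) just before processor(...) in the script would turn the 1400x1400 case into a ValueError naming both token counts, instead of a RuntimeError deep inside the vision encoder.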

Expected behavior

Like SAM and SAM 2, I expect SAM3 to support any resolution that is a multiple of the patch size (14).
Note: from testing, other resolutions that also produce 5184 patch tokens do work, e.g. 1344x756 (96 × 54 = 5184 patches).
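As a stopgap, the grids that keep exactly 5184 patches can be enumerated, and the one closest to an image's aspect ratio picked. This is a hypothetical workaround sketch: it only guarantees that the token count matches, and it is an unverified assumption that the model's positional encoding stays meaningful for non-square grids of that size.

```python
PATCH_SIZE = 14
NUM_PATCHES = 5184  # 72 x 72, the grid at the default 1008x1008 resolution


def resolutions_with_token_count(num_patches=NUM_PATCHES, patch_size=PATCH_SIZE):
    """All (height, width) pixel sizes whose patch grid has exactly num_patches tokens."""
    sizes = []
    for rows in range(1, num_patches + 1):
        if num_patches % rows == 0:
            cols = num_patches // rows
            sizes.append((rows * patch_size, cols * patch_size))
    return sizes


def closest_resolution(image_height, image_width):
    """Pick the equal-token (height, width) whose aspect ratio best matches the image."""
    target = image_width / image_height
    return min(
        resolutions_with_token_count(),
        key=lambda hw: abs(hw[1] / hw[0] - target),
    )
```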

Note: this may be related to mask_size being hardcoded.
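The traceback fails in apply_rotary_pos_emb_2d, where q_embed has 10000 positions but cos has 5184, which suggests the rotary tables are built once for a 72 × 72 grid. One possible direction for a fix would be to build them for the actual grid at forward time. The sketch below is a generic 2D axial rotary-frequency computation in plain Python, assuming half of each head's rotary dimension encodes the row index and half the column index, rotated in adjacent pairs. It is not the transformers SAM3 implementation; theta and dim are illustrative values.

```python
import math


def axial_rope_2d(grid_h, grid_w, dim, theta=10000.0):
    """Build cos/sin tables for a grid_h x grid_w patch grid.

    Returns two row-major lists of length grid_h * grid_w; each entry holds
    `dim` values. The first dim // 2 angles encode the row index and the
    last dim // 2 encode the column index.
    """
    if dim % 4:
        raise ValueError("dim must be divisible by 4 (two axes, rotated in pairs)")
    quarter = dim // 4
    # One inverse frequency per rotated pair, shared by both axes
    inv_freq = [1.0 / theta ** (i / quarter) for i in range(quarter)]
    cos_table, sin_table = [], []
    for y in range(grid_h):
        for x in range(grid_w):
            # Each frequency appears twice because rotation acts on adjacent pairs
            angles = [y * f for f in inv_freq for _ in (0, 1)]
            angles += [x * f for f in inv_freq for _ in (0, 1)]
            cos_table.append([math.cos(a) for a in angles])
            sin_table.append([math.sin(a) for a in angles])
    return cos_table, sin_table
```

For example, axial_rope_2d(1400 // 14, 1400 // 14, dim=64) yields 10000 rows, matching the query length in the failing case. Recomputing per call costs some time; caching the tables by (grid_h, grid_w) would avoid that.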
