Skip to content

Improve support for type inheritance from other mapped types #253

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,46 @@ query {
"""
```

### Type Inheritance

You can inherit fields from other mapped types using standard Python class inheritance.

- Fields from the parent type (e.g., ApiA) are inherited by the child (e.g., ApiB).

- The `__exclude__` setting applies to inherited fields.

- If both SQLAlchemy models define the same field name, the field from the model inside `.type(...)` takes precedence.

- Declaring a field manually in the mapped type overrides everything else.

```python
class ModelA(base):
__tablename__ = "a"

id = Column(String, primary_key=True)
common_field = Column(String(50))


class ModelB(base):
__tablename__ = "b"

id = Column(String, primary_key=True)
common_field = Column(Integer) # Conflicting field
extra_field = Column(String(50))


@mapper.type(ModelA)
class ApiA:
__exclude__ = ["id"] # This field will be excluded in ApiA (and its children)


@mapper.type(ModelB)
class ApiB(ApiA):
# Inherits fields from ApiA, except "id"
# "common_field" will come from ModelB, not ModelA, so it will be a Integer
# "extra_field" will be overridden and will be a float now instead of the String type declared in ModelB:
extra_field: float = strawberry.field(name="extraField")
```
## Limitations

### Supported Types
Expand Down
93 changes: 93 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
Release type: patch

This release improves how types inherit fields from other mapped types using `@mapper.type(...)`.
You can now safely inherit from another mapped type, and the resulting GraphQL type will include all expected fields with predictable conflict resolution.

Some examples:

- Basic Inheritance:

```python
@mapper.type(ModelA)
class ApiA:
pass


@mapper.type(ModelB)
class ApiB(ApiA):
# ApiB inherits all fields declared in ApiA
pass
```


- The `__exclude__` option continues working:

```python
@mapper.type(ModelA)
class ApiA:
__exclude__ = ["relationshipB_id"]


@mapper.type(ModelB)
class ApiB(ApiA):
# ApiB will have all fields declared in ApiA, except "relationshipB_id"
pass
```

- If two SQLAlchemy models define fields with the same name, the field from the model inside `.type(...)` takes precedence:

```python
class ModelA(base):
__tablename__ = "a"

id = Column(String, primary_key=True)
example_field = Column(String(50))


class ModelB(base):
__tablename__ = "b"

id = Column(String, primary_key=True)
example_field = Column(Integer, autoincrement=True)


@mapper.type(ModelA)
class ApiA:
# example_field will be a String
pass


@mapper.type(ModelB)
class ApiB(ApiA):
# example_field will be taken from ModelB and will be an Integer
pass
```


- If a field is explicitly declared in the mapped type, it will override any inherited or model-based definition:

```python
class ModelA(base):
__tablename__ = "a"

id = Column(String, primary_key=True)
example_field = Column(String(50))


class ModelB(base):
__tablename__ = "b"

id = Column(String, primary_key=True)
example_field = Column(Integer, autoincrement=True)


@mapper.type(ModelA)
class ApiA:
pass


@mapper.type(ModelB)
class ApiB(ApiA):
# example_field will be a Float
example_field: float = strawberry.field(name="exampleField")
```
58 changes: 44 additions & 14 deletions src/strawberry_sqlalchemy_mapper/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@
Protocol,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
get_type_hints,
overload,
)
from typing_extensions import Self
Expand Down Expand Up @@ -651,27 +653,47 @@ class Employee:
```
"""

def _get_generated_field_keys(type_, old_annotations) -> Tuple[List[str], Dict[str, Any]]:
old_annotations = old_annotations.copy()
generated_field_keys = set()

for key in dir(type_):
val = getattr(type_, key)
if getattr(val, _IS_GENERATED_RESOLVER_KEY, False):
setattr(type_, key, field(resolver=val))
generated_field_keys.add(key)

# Checks for an original type annotation, useful in resolving inheritance-related types
if original_type := getattr(type_, _ORIGINAL_TYPE_KEY, None):
for key in dir(original_type):
if key.startswith("__") and key.endswith("__"):
continue

val = getattr(original_type, key)
if getattr(val, _IS_GENERATED_RESOLVER_KEY, False):
setattr(type_, key, field(resolver=val))
generated_field_keys.add(key)
try:
annotations = get_type_hints(original_type)
except Exception:
annotations = original_type.__annotations__

if key in annotations:
old_annotations[key] = annotations[key]

return list(generated_field_keys), old_annotations

def convert(type_: Any) -> Any:
old_annotations = getattr(type_, "__annotations__", {})
type_.__annotations__ = {k: v for k, v in old_annotations.items() if is_private(v)}
mapper: Mapper = cast("Mapper", inspect(model))
generated_field_keys = []

excluded_keys = getattr(type_, "__exclude__", [])
list_keys = getattr(type_, "__use_list__", [])

# if the type inherits from another mapped type, then it may have
# generated resolvers. These will be treated by dataclasses as having
# a default value, which will likely cause issues because of keys
# that don't have default values. To fix this, we wrap them in
# `strawberry.field()` (like when they were originally made), so
# dataclasses will ignore them.
# TODO: Potentially raise/fix this issue upstream
for key in dir(type_):
val = getattr(type_, key)
if getattr(val, _IS_GENERATED_RESOLVER_KEY, False):
setattr(type_, key, field(resolver=val))
generated_field_keys.append(key)
generated_field_keys, old_annotations = _get_generated_field_keys(
type_, old_annotations
)

self._handle_columns(mapper, type_, excluded_keys, generated_field_keys)
relationship: RelationshipProperty
Expand Down Expand Up @@ -798,7 +820,15 @@ def convert(type_: Any) -> Any:
# because the pre-existing fields might have default values,
# which will cause the mapped fields to fail
# (because they may not have default values)
type_.__annotations__.update(old_annotations)

# For Python versions <= 3.9, only update annotations that don't already exist
# because this versions handle inherance differently
if sys.version_info[:2] <= (3, 9):
for k, v in old_annotations.items():
if k not in type_.__annotations__:
type_.__annotations__[k] = v
else:
type_.__annotations__.update(old_annotations)
Comment on lines +824 to +831
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: what is the exact difference between them? In theory they should be handling that the same

One thing that I know is that in previous versions, __annotations__ could not exist in the class if that class didn't have any annotations in it, something which I had to workaround on strawberry-django.

Is that it or something else?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @bellini666, sorry for the delay!
The problem isn’t directly related to __annotations__ existance, but to inconsistencies with inherited classes and this annotations.

In Python 3.10 and newer, the __annotations__ no longer includes the annotations from original_type (base class in this case), which caused some resolvers or type hints to be missing in the final class. So I had to add extra logic to extract those annotations manually.
I didn’t investigate this deeply, but I suspect it only happens because we’re using a decorator, which might interfere with how Python handles class inheritance and type resolution at that point.

On the other hand, in Python 3.8 and 3.9, the inherited annotations are already present in the subclass’s __annotations__. So thats why I only add missing keys, since its was raising a error when only using update().

Let me know if you see any better idea to fix this, this one was the most stable solution I found to make it work consistently across versions.


if make_interface:
mapped_type = strawberry.interface(type_)
Expand Down
Loading
Loading