Skip to content

Commit bf10d56

Browse files
committed
Fix pydantic generics and add ability to directive json schema
1 parent 0bdfea4 commit bf10d56

File tree

8 files changed

+317
-5
lines changed

8 files changed

+317
-5
lines changed

RELEASE.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Release type: minor
2+
3+
This release fixes the pydantic support for generics and allows capture of the Pydantic
4+
JSON schema attributes through a schema directive.

docs/integrations/pydantic.md

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,3 +519,55 @@ user_type = UserType(id="abc", content_name="Bob", content_description=None)
519519
print(user_type.to_pydantic())
520520
# id='abc' content={<ContentType.NAME: 'name'>: 'Bob'}
521521
```
522+
523+
## Schema directives to capture JSON schema data
524+
525+
The pydantic conversion also supports capturing the JSON schema metadata from Pydantic. Note the fields in the directive must match the json schema names.
526+
527+
```python
528+
from pydantic import BaseModel, Field
529+
from typing import Annotated, Optional
530+
from strawberry.schema_directive import Location
531+
import strawberry
532+
533+
class User(BaseModel):
534+
id: Annotated[int, Field(gt=0)]
535+
name: Annotated[str, Field(json_schema_extra={"name_type": "full"})]
536+
537+
538+
@strawberry.schema_directive(locations=[Location.FIELD_DEFINITION])
539+
class MyJsonSchema:
540+
exclusive_minimum: Optional[int] = None
541+
name_type: Optional[str] = None
542+
543+
544+
@strawberry.experimental.pydantic.type(model=User, json_schema_directive=MyJsonSchema)
545+
class UserType:
546+
id: strawberry.auto
547+
name: strawberry.auto
548+
549+
550+
@strawberry.type
551+
class Query:
552+
@strawberry.field
553+
def test() -> UserType:
554+
return UserType.from_pydantic(User(id=123, name="John Doe"))
555+
556+
557+
schema = strawberry.Schema(query=Query)
558+
```
559+
560+
Now if [the schema is exported](../guides/schema-export), the result will contain:
561+
562+
```graphql
563+
directive @myJsonSchema(exclusiveMinimum: Int = null, nameType: String = null) on FIELD_DEFINITION
564+
565+
type Query {
566+
test: UserType!
567+
}
568+
569+
type UserType {
570+
id: Int!
571+
name: String! @myJsonSchema(exclusiveMinimum: null, nameType: "full")
572+
}
573+
```

strawberry/experimental/pydantic/_compat.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ def get_model_fields(
178178
new_fields |= self.get_model_computed_fields(model)
179179
return new_fields
180180

181+
def get_model_json_schema(self, model: type[BaseModel]) -> dict[str, Any]:
182+
return model.model_json_schema()
183+
181184
@cached_property
182185
def fields_map(self) -> dict[Any, Any]:
183186
return get_fields_map_for_v2()
@@ -273,6 +276,9 @@ def get_basic_type(self, type_: Any) -> type[Any]:
273276

274277
return type_
275278

279+
def get_model_json_schema(self, model: type[BaseModel]) -> dict[str, Any]:
280+
return model.schema()
281+
276282
def model_dump(self, model_instance: BaseModel) -> dict[Any, Any]:
277283
return model_instance.dict()
278284

strawberry/experimental/pydantic/fields.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@ def replace_types_recursively(
4646
origin = get_origin(type_)
4747

4848
if not origin or not hasattr(type_, "__args__"):
49-
return replaced_type
49+
if hasattr(basic_type, "__pydantic_generic_metadata__") and basic_type.__pydantic_generic_metadata__["args"]:
50+
return replaced_type[*basic_type.__pydantic_generic_metadata__["args"]]
51+
else:
52+
return replaced_type
5053

5154
converted = tuple(
5255
replace_types_recursively(t, is_input=is_input, compat=compat)

strawberry/experimental/pydantic/object_type.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from strawberry.types.field import StrawberryField
3434
from strawberry.types.object_type import _process_type, _wrap_dataclass
3535
from strawberry.types.type_resolver import _get_fields
36+
from strawberry.utils.str_converters import to_snake_case
3637

3738
if TYPE_CHECKING:
3839
import builtins
@@ -41,7 +42,9 @@
4142
from graphql import GraphQLResolveInfo
4243

4344

44-
def get_type_for_field(field: CompatModelField, is_input: bool, compat: PydanticCompat): # noqa: ANN201
45+
def get_type_for_field(
46+
field: CompatModelField, is_input: bool, compat: PydanticCompat
47+
): # noqa: ANN201
4548
outer_type = field.outer_type_
4649

4750
replaced_type = replace_types_recursively(outer_type, is_input, compat=compat)
@@ -62,6 +65,8 @@ def _build_dataclass_creation_fields(
6265
auto_fields_set: set[str],
6366
use_pydantic_alias: bool,
6467
compat: PydanticCompat,
68+
json_schema: dict[str, Any],
69+
json_schema_directive: Optional[builtins.type] = None,
6570
) -> DataclassCreationFields:
6671
field_type = (
6772
get_type_for_field(field, is_input, compat=compat)
@@ -84,6 +89,28 @@ def _build_dataclass_creation_fields(
8489
elif field.has_alias and use_pydantic_alias:
8590
graphql_name = field.alias
8691

92+
if json_schema_directive and json_schema:
93+
field_names = {
94+
field.name for field in dataclasses.fields(json_schema_directive)
95+
}
96+
applicable_values = {
97+
to_snake_case(key): value
98+
for key, value in json_schema.items()
99+
if to_snake_case(key) in field_names
100+
}
101+
if applicable_values:
102+
json_directive = json_schema_directive(
103+
**applicable_values
104+
)
105+
directives = (
106+
*(existing_field.directives if existing_field else ()),
107+
json_directive,
108+
)
109+
else:
110+
directives = existing_field.directives if existing_field else ()
111+
else:
112+
directives = ()
113+
87114
strawberry_field = StrawberryField(
88115
python_name=field.name,
89116
graphql_name=graphql_name,
@@ -98,7 +125,7 @@ def _build_dataclass_creation_fields(
98125
permission_classes=(
99126
existing_field.permission_classes if existing_field else []
100127
),
101-
directives=existing_field.directives if existing_field else (),
128+
directives=directives,
102129
metadata=existing_field.metadata if existing_field else {},
103130
)
104131

@@ -128,6 +155,7 @@ def type(
128155
all_fields: bool = False,
129156
include_computed: bool = False,
130157
use_pydantic_alias: bool = True,
158+
json_schema_directive: Optional[Any] = None,
131159
) -> Callable[..., builtins.type[StrawberryTypeFromPydantic[PydanticModel]]]:
132160
def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]:
133161
compat = PydanticCompat.from_model(model)
@@ -184,6 +212,11 @@ def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]:
184212
private_fields = get_private_fields(wrapped)
185213

186214
extra_fields_dict = {field.name: field for field in extra_strawberry_fields}
215+
fields_json_schema = (
216+
compat.get_model_json_schema(model).get("properties", {})
217+
if json_schema_directive
218+
else {}
219+
)
187220

188221
all_model_fields: list[DataclassCreationFields] = [
189222
_build_dataclass_creation_fields(
@@ -193,6 +226,8 @@ def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]:
193226
auto_fields_set,
194227
use_pydantic_alias,
195228
compat=compat,
229+
json_schema_directive=json_schema_directive,
230+
json_schema=fields_json_schema.get(field.name, {}),
196231
)
197232
for field_name, field in model_fields.items()
198233
if field_name in fields_set
@@ -250,10 +285,12 @@ def is_type_of(cls: builtins.type, obj: Any, _info: GraphQLResolveInfo) -> bool:
250285
else:
251286
kwargs["init"] = False
252287

288+
bases = cls.__orig_bases__ if hasattr(cls, "__orig_bases__") else cls.__bases__
289+
253290
cls = dataclasses.make_dataclass(
254291
cls.__name__,
255292
[field.to_tuple() for field in all_model_fields],
256-
bases=cls.__bases__,
293+
bases=bases,
257294
namespace=namespace,
258295
**kwargs, # type: ignore
259296
)
@@ -317,6 +354,7 @@ def input(
317354
directives: Optional[Sequence[object]] = (),
318355
all_fields: bool = False,
319356
use_pydantic_alias: bool = True,
357+
json_schema_directive: Optional[builtins.type] = None,
320358
) -> Callable[..., builtins.type[StrawberryTypeFromPydantic[PydanticModel]]]:
321359
"""Convenience decorator for creating an input type from a Pydantic model.
322360
@@ -334,6 +372,7 @@ def input(
334372
directives=directives,
335373
all_fields=all_fields,
336374
use_pydantic_alias=use_pydantic_alias,
375+
json_schema_directive=json_schema_directive,
337376
)
338377

339378

tests/experimental/pydantic/schema/test_basic.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import textwrap
22
from enum import Enum
3-
from typing import Optional, Union
3+
from typing import Annotated, Generic, Optional, TypeAlias, Union, TypeVar
44

55
import pydantic
66

@@ -528,6 +528,71 @@ def user(self) -> User:
528528
assert result.data["user"]["age"] == 1
529529
assert result.data["user"]["password"] is None
530530

531+
def test_nested_type_with_resolved_generic():
532+
533+
A = TypeVar("A")
534+
class Hobby(pydantic.BaseModel, Generic[A]):
535+
name: A
536+
537+
@strawberry.experimental.pydantic.type(Hobby)
538+
class HobbyType(Generic[A]):
539+
name: strawberry.auto
540+
541+
class User(pydantic.BaseModel):
542+
hobby: Hobby[str]
543+
544+
@strawberry.experimental.pydantic.type(User)
545+
class UserType:
546+
hobby: strawberry.auto
547+
548+
@strawberry.type
549+
class Query:
550+
@strawberry.field
551+
def user(self) -> UserType:
552+
return UserType(hobby=HobbyType(name="Skii"))
553+
554+
schema = strawberry.Schema(query=Query)
555+
556+
query = "{ user { hobby { name } } }"
557+
558+
result = schema.execute_sync(query)
559+
560+
assert not result.errors
561+
assert result.data["user"]["hobby"]["name"] == "Skii"
562+
563+
def test_nested_type_with_resolved_field_generic():
564+
Count: TypeAlias = Annotated[float, pydantic.Field(ge = 0)]
565+
566+
A = TypeVar("A")
567+
class Hobby(pydantic.BaseModel, Generic[A]):
568+
count: A
569+
570+
@strawberry.experimental.pydantic.type(Hobby)
571+
class HobbyType(Generic[A]):
572+
count: strawberry.auto
573+
574+
class User(pydantic.BaseModel):
575+
hobby: Hobby[Count]
576+
577+
@strawberry.experimental.pydantic.type(User)
578+
class UserType:
579+
hobby: strawberry.auto
580+
581+
@strawberry.type
582+
class Query:
583+
@strawberry.field
584+
def user(self) -> UserType:
585+
return UserType(hobby=HobbyType(count=2))
586+
587+
schema = strawberry.Schema(query=Query)
588+
589+
query = "{ user { hobby { count } } }"
590+
591+
result = schema.execute_sync(query)
592+
593+
assert not result.errors
594+
assert result.data["user"]["hobby"]["count"] == 2
595+
531596

532597
@needs_pydantic_v1
533598
def test_basic_type_with_constrained_list():

tests/experimental/pydantic/schema/test_generic.py

Whitespace-only changes.

0 commit comments

Comments
 (0)