diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index 116097ac7..e69062a44 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -204,6 +204,10 @@ def make_optional(typ: MypyType) -> MypyType: return UnionType.make_union([typ, NoneTyp()]) +def is_optional(typ: MypyType) -> bool: + return isinstance(typ, UnionType) and any(isinstance(item, NoneTyp) for item in typ.items) + + # Duplicating mypy.semanal_shared.parse_bool because importing it directly caused ImportError (#1784) def parse_bool(expr: Expression) -> Optional[bool]: if isinstance(expr, NameExpr): diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index 98654e749..731a102b5 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -4,9 +4,10 @@ from django.db.models.fields import AutoField, Field from django.db.models.fields.related import RelatedField from django.db.models.fields.reverse_related import ForeignObjectRel +from mypy.maptype import map_instance_to_supertype from mypy.nodes import AssignmentStmt, NameExpr, TypeInfo from mypy.plugin import FunctionContext -from mypy.types import AnyType, Instance, ProperType, TypeOfAny, UnionType +from mypy.types import AnyType, Instance, NoneType, ProperType, TypeOfAny, UninhabitedType, UnionType from mypy.types import Type as MypyType from mypy_django_plugin.django.context import DjangoContext @@ -150,6 +151,25 @@ def set_descriptor_types_for_field( is_set_nullable=is_set_nullable or is_nullable, is_get_nullable=is_get_nullable or is_nullable, ) + + # reconcile set and get types with the base field class + base_field_type = next(base for base in default_return_type.type.mro if base.fullname == fullnames.FIELD_FULLNAME) + mapped_instance = map_instance_to_supertype(default_return_type, base_field_type) + mapped_set_type, mapped_get_type = mapped_instance.args + + # bail if either mapped_set_type or mapped_get_type have type Never + if not (isinstance(mapped_set_type, UninhabitedType) or isinstance(mapped_get_type, UninhabitedType)): + # always replace set_type and get_type with (non-Any) mapped types + set_type = helpers.convert_any_to_type(mapped_set_type, set_type) + get_type = helpers.convert_any_to_type(mapped_get_type, get_type) + + # the get_type must be optional if the field is nullable + if (is_get_nullable or is_nullable) and not (isinstance(get_type, NoneType) or helpers.is_optional(get_type)): + ctx.api.fail( + f"{default_return_type.type.name} is nullable but its generic get type parameter is not optional", + ctx.context, + ) + return helpers.reparametrize_instance(default_return_type, [set_type, get_type]) diff --git a/tests/typecheck/fields/test_custom_fields.yml b/tests/typecheck/fields/test_custom_fields.yml new file mode 100644 index 000000000..14d175348 --- /dev/null +++ b/tests/typecheck/fields/test_custom_fields.yml @@ -0,0 +1,78 @@ +- case: test_custom_model_fields_with_generic_type + main: | + from myapp.models import User, CustomFieldValue + user = User() + reveal_type(user.id) # N: Revealed type is "builtins.int" + reveal_type(user.my_custom_field1) # N: Revealed type is "myapp.models.CustomFieldValue" + reveal_type(user.my_custom_field2) # N: Revealed type is "myapp.models.CustomFieldValue" + reveal_type(user.my_custom_field3) # N: Revealed type is "builtins.bool" + reveal_type(user.my_custom_field4) # N: Revealed type is "myapp.models.CustomFieldValue" + reveal_type(user.my_custom_field5) # N: Revealed type is "myapp.models.CustomFieldValue" + reveal_type(user.my_custom_field6) # N: Revealed type is "myapp.models.CustomFieldValue" + reveal_type(user.my_custom_field7) # N: Revealed type is "builtins.bool" + reveal_type(user.my_custom_field8) # N: Revealed type is "myapp.models.CustomFieldValue" + reveal_type(user.my_custom_field9) # N: Revealed type is "myapp.models.CustomFieldValue" + reveal_type(user.my_custom_field10) # N: Revealed type is "builtins.bool" + reveal_type(user.my_custom_field11) # N: Revealed type is "builtins.bool" + reveal_type(user.my_custom_field12) # N: Revealed type is "Union[myapp.models.CustomFieldValue, None]" + reveal_type(user.my_custom_field13) # N: Revealed type is "Union[myapp.models.CustomFieldValue, None]" + reveal_type(user.my_custom_field14) # N: Revealed type is "Union[builtins.bool, None]" + reveal_type(user.my_custom_field15) # N: Revealed type is "None" + monkeypatch: true + out: | + myapp/models:31: error: GenericField is nullable but its generic get type parameter is not optional [misc] + myapp/models:32: error: CustomValueField is nullable but its generic get type parameter is not optional [misc] + myapp/models:33: error: SingleTypeField is nullable but its generic get type parameter is not optional [misc] + myapp/models:34: error: AdditionalTypeVarField is nullable but its generic get type parameter is not optional [misc] + myapp/models:35: error: Field is nullable but its generic get type parameter is not optional [misc] + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + from django.db.models import fields + + from typing import Any, TypeVar, Generic, Union + + _ST = TypeVar("_ST", contravariant=True) + _GT = TypeVar("_GT", covariant=True) + + T = TypeVar("T") + + class CustomFieldValue: ... + + class GenericField(fields.Field[_ST, _GT]): ... + + class SingleTypeField(fields.Field[T, T]): ... + + class CustomValueField(fields.Field[Union[CustomFieldValue, int], CustomFieldValue]): ... + + class AdditionalTypeVarField(fields.Field[_ST, _GT], Generic[_ST, _GT, T]): ... + + class CustomSmallIntegerField(fields.SmallIntegerField[_ST, _GT]): ... + + class User(models.Model): + id = models.AutoField(primary_key=True) + my_custom_field1 = GenericField[Union[CustomFieldValue, int], CustomFieldValue]() + my_custom_field2 = CustomValueField() + my_custom_field3 = SingleTypeField[bool]() + my_custom_field4 = AdditionalTypeVarField[Union[CustomFieldValue, int], CustomFieldValue, bool]() + + # test null=True on fields with non-optional generic types throw error + my_custom_field5 = GenericField[Union[CustomFieldValue, int], CustomFieldValue](null=True) + my_custom_field6 = CustomValueField(null=True) + my_custom_field7 = SingleTypeField[bool](null=True) + my_custom_field8 = AdditionalTypeVarField[Union[CustomFieldValue, int], CustomFieldValue, bool](null=True) + my_custom_field9 = fields.Field[Union[CustomFieldValue, int], CustomFieldValue](null=True) + + # test overriding fields that set _pyi_private_set_type or _pyi_private_get_type + my_custom_field10 = fields.SmallIntegerField[bool, bool]() + my_custom_field11 = CustomSmallIntegerField[bool, bool]() + + # test null=True on fields with non-optional generic types throw no errors + my_custom_field12 = fields.Field[Union[CustomFieldValue, int], Union[CustomFieldValue, None]](null=True) + my_custom_field13 = GenericField[Union[CustomFieldValue, int], Union[CustomFieldValue, None]](null=True) + my_custom_field14 = SingleTypeField[Union[bool, None]](null=True) + my_custom_field15 = fields.Field[None, None](null=True)