From 48266f26ff3e2ca84ef0c2168041ae1c45fc5f8e Mon Sep 17 00:00:00 2001
From: Matthieu Devlin <matt@zumper.com>
Date: Mon, 8 Apr 2024 10:57:18 -0700
Subject: [PATCH 1/4] Use field generic types for descriptors

---
 mypy_django_plugin/transformers/fields.py     | 20 +++++-
 tests/typecheck/fields/test_custom_fields.yml | 62 +++++++++++++++++++
 2 files changed, 81 insertions(+), 1 deletion(-)
 create mode 100644 tests/typecheck/fields/test_custom_fields.yml

diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py
index 98654e749..b4a14ddb7 100644
--- a/mypy_django_plugin/transformers/fields.py
+++ b/mypy_django_plugin/transformers/fields.py
@@ -4,14 +4,16 @@
 from django.db.models.fields import AutoField, Field
 from django.db.models.fields.related import RelatedField
 from django.db.models.fields.reverse_related import ForeignObjectRel
+from mypy.maptype import map_instance_to_supertype
 from mypy.nodes import AssignmentStmt, NameExpr, TypeInfo
 from mypy.plugin import FunctionContext
-from mypy.types import AnyType, Instance, ProperType, TypeOfAny, UnionType
+from mypy.types import AnyType, Instance, ProperType, TypeOfAny, UninhabitedType, UnionType
 from mypy.types import Type as MypyType
 
 from mypy_django_plugin.django.context import DjangoContext
 from mypy_django_plugin.exceptions import UnregisteredModelError
 from mypy_django_plugin.lib import fullnames, helpers
+from mypy_django_plugin.lib.fullnames import FIELD_FULLNAME
 from mypy_django_plugin.lib.helpers import parse_bool
 from mypy_django_plugin.transformers import manytomany
 
@@ -150,6 +152,22 @@ def set_descriptor_types_for_field(
         is_set_nullable=is_set_nullable or is_nullable,
         is_get_nullable=is_get_nullable or is_nullable,
     )
+
+    # reconcile set and get types with the base field class
+    base_field_type = next(base for base in default_return_type.type.mro if base.fullname == FIELD_FULLNAME)
+    mapped_instance = map_instance_to_supertype(default_return_type, base_field_type)
+    mapped_set_type, mapped_get_type = mapped_instance.args
+
+    # bail if either mapped_set_type or mapped_get_type have type Never
+    if not (isinstance(mapped_set_type, UninhabitedType) or isinstance(mapped_get_type, UninhabitedType)):
+        # only replace set_type and get_type with mapped types if their original value is Any
+        set_type = helpers.convert_any_to_type(
+            set_type, helpers.make_optional(mapped_set_type) if is_set_nullable or is_nullable else mapped_set_type
+        )
+        get_type = helpers.convert_any_to_type(
+            get_type, helpers.make_optional(mapped_get_type) if is_get_nullable or is_nullable else mapped_get_type
+        )
+
     return helpers.reparametrize_instance(default_return_type, [set_type, get_type])
 
 
diff --git a/tests/typecheck/fields/test_custom_fields.yml b/tests/typecheck/fields/test_custom_fields.yml
new file mode 100644
index 000000000..0845dc9e8
--- /dev/null
+++ b/tests/typecheck/fields/test_custom_fields.yml
@@ -0,0 +1,62 @@
+-   case: test_custom_model_fields_with_generic_type
+    main: |
+        from myapp.models import User, CustomFieldValue
+        user = User()
+        reveal_type(user.id)  # N: Revealed type is "builtins.int"
+        reveal_type(user.my_custom_field1)  # N: Revealed type is "myapp.models.CustomFieldValue"
+        reveal_type(user.my_custom_field2)  # N: Revealed type is "myapp.models.CustomFieldValue"
+        reveal_type(user.my_custom_field3)  # N: Revealed type is "builtins.bool"
+        reveal_type(user.my_custom_field4)  # N: Revealed type is "myapp.models.CustomFieldValue"
+        reveal_type(user.my_custom_field5)  # N: Revealed type is "Union[myapp.models.CustomFieldValue, None]"
+        # user.my_custom_field6 is incorrectly typed as non-optional
+        # reveal_type(user.my_custom_field6) ## N: Revealed type is "Union[myapp.models.CustomFieldValue, None]"
+        reveal_type(user.my_custom_field7)  # N: Revealed type is "Union[builtins.bool, None]"
+        reveal_type(user.my_custom_field8)  # N: Revealed type is "Union[myapp.models.CustomFieldValue, None]"
+        reveal_type(user.my_custom_field9)  # N: Revealed type is "myapp.models.CustomFieldValue"
+        reveal_type(user.my_custom_field10)  # N: Revealed type is "Union[myapp.models.CustomFieldValue, None]"
+        # Fields that set _pyi_private_set_type or _pyi_private_get_type retain these types
+        reveal_type(user.my_custom_field11)  # N: Revealed type is "builtins.int"
+        reveal_type(user.my_custom_field12)  # N: Revealed type is "builtins.int"
+    monkeypatch: true
+    installed_apps:
+        - myapp
+    files:
+        -   path: myapp/__init__.py
+        -   path: myapp/models.py
+            content: |
+                from django.db import models
+                from django.db.models import fields
+
+                from typing import Any, TypeVar, Generic
+
+                _ST = TypeVar("_ST", contravariant=True)
+                _GT = TypeVar("_GT", covariant=True)
+
+                T = TypeVar("T")
+
+                class CustomFieldValue: ...
+
+                class GenericField(fields.Field[_ST, _GT]): ...
+
+                class SingleTypeField(fields.Field[T, T]): ...
+
+                class CustomValueField(fields.Field[CustomFieldValue | int, CustomFieldValue]): ...
+
+                class AdditionalTypeVarField(fields.Field[_ST, _GT], Generic[_ST, _GT, T]): ...
+
+                class CustomSmallIntegerField(fields.SmallIntegerField[_ST, _GT]): ...
+
+                class User(models.Model):
+                    id = models.AutoField(primary_key=True)
+                    my_custom_field1 = GenericField[CustomFieldValue | int, CustomFieldValue]()
+                    my_custom_field2 = CustomValueField()
+                    my_custom_field3 = SingleTypeField[bool]()
+                    my_custom_field4 = AdditionalTypeVarField[CustomFieldValue | int, CustomFieldValue, bool]()
+                    my_custom_field5 = GenericField[CustomFieldValue | int, CustomFieldValue](null=True)
+                    my_custom_field6 = CustomValueField(null=True)
+                    my_custom_field7 = SingleTypeField[bool](null=True)
+                    my_custom_field8 = AdditionalTypeVarField[CustomFieldValue | int, CustomFieldValue, bool](null=True)
+                    my_custom_field9 = fields.Field[CustomFieldValue | int, CustomFieldValue]()
+                    my_custom_field10 = fields.Field[CustomFieldValue | int, CustomFieldValue](null=True)
+                    my_custom_field11 = fields.SmallIntegerField[bool, bool]()
+                    my_custom_field12 = CustomSmallIntegerField[bool, bool]()

From f9529e2bc20b6e7dbd6aea109245c1735c796a28 Mon Sep 17 00:00:00 2001
From: Matthieu Devlin <matt@zumper.com>
Date: Mon, 8 Apr 2024 16:16:07 -0700
Subject: [PATCH 2/4] Fix type annotations for older python versions

---
 tests/typecheck/fields/test_custom_fields.yml | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/tests/typecheck/fields/test_custom_fields.yml b/tests/typecheck/fields/test_custom_fields.yml
index 0845dc9e8..87b075867 100644
--- a/tests/typecheck/fields/test_custom_fields.yml
+++ b/tests/typecheck/fields/test_custom_fields.yml
@@ -27,7 +27,7 @@
                 from django.db import models
                 from django.db.models import fields
 
-                from typing import Any, TypeVar, Generic
+                from typing import Any, TypeVar, Generic, Union
 
                 _ST = TypeVar("_ST", contravariant=True)
                 _GT = TypeVar("_GT", covariant=True)
@@ -40,7 +40,7 @@
 
                 class SingleTypeField(fields.Field[T, T]): ...
 
-                class CustomValueField(fields.Field[CustomFieldValue | int, CustomFieldValue]): ...
+                class CustomValueField(fields.Field[Union[CustomFieldValue, int], CustomFieldValue]): ...
 
                 class AdditionalTypeVarField(fields.Field[_ST, _GT], Generic[_ST, _GT, T]): ...
 
@@ -48,15 +48,15 @@
 
                 class User(models.Model):
                     id = models.AutoField(primary_key=True)
-                    my_custom_field1 = GenericField[CustomFieldValue | int, CustomFieldValue]()
+                    my_custom_field1 = GenericField[Union[CustomFieldValue, int], CustomFieldValue]()
                     my_custom_field2 = CustomValueField()
                     my_custom_field3 = SingleTypeField[bool]()
-                    my_custom_field4 = AdditionalTypeVarField[CustomFieldValue | int, CustomFieldValue, bool]()
-                    my_custom_field5 = GenericField[CustomFieldValue | int, CustomFieldValue](null=True)
+                    my_custom_field4 = AdditionalTypeVarField[Union[CustomFieldValue, int], CustomFieldValue, bool]()
+                    my_custom_field5 = GenericField[Union[CustomFieldValue, int], CustomFieldValue](null=True)
                     my_custom_field6 = CustomValueField(null=True)
                     my_custom_field7 = SingleTypeField[bool](null=True)
-                    my_custom_field8 = AdditionalTypeVarField[CustomFieldValue | int, CustomFieldValue, bool](null=True)
-                    my_custom_field9 = fields.Field[CustomFieldValue | int, CustomFieldValue]()
-                    my_custom_field10 = fields.Field[CustomFieldValue | int, CustomFieldValue](null=True)
+                    my_custom_field8 = AdditionalTypeVarField[Union[CustomFieldValue, int], CustomFieldValue, bool](null=True)
+                    my_custom_field9 = fields.Field[Union[CustomFieldValue, int], CustomFieldValue]()
+                    my_custom_field10 = fields.Field[Union[CustomFieldValue, int], CustomFieldValue](null=True)
                     my_custom_field11 = fields.SmallIntegerField[bool, bool]()
                     my_custom_field12 = CustomSmallIntegerField[bool, bool]()

From 8fa7a1421acfc07ac3aed17e0f37cd50a8c935ce Mon Sep 17 00:00:00 2001
From: Matthieu Devlin <matt@zumper.com>
Date: Wed, 17 Apr 2024 21:17:04 -0700
Subject: [PATCH 3/4] Enforce get type is optional when field is nullable and
 don't implicitly convert to optional

---
 mypy_django_plugin/lib/helpers.py             |  4 ++
 mypy_django_plugin/transformers/fields.py     | 16 +++----
 tests/typecheck/fields/test_custom_fields.yml | 42 +++++++++++++------
 3 files changed, 41 insertions(+), 21 deletions(-)

diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py
index 116097ac7..e69062a44 100644
--- a/mypy_django_plugin/lib/helpers.py
+++ b/mypy_django_plugin/lib/helpers.py
@@ -204,6 +204,10 @@ def make_optional(typ: MypyType) -> MypyType:
     return UnionType.make_union([typ, NoneTyp()])
 
 
+def is_optional(typ: MypyType) -> bool:
+    return isinstance(typ, UnionType) and any(isinstance(item, NoneTyp) for item in typ.items)
+
+
 # Duplicating mypy.semanal_shared.parse_bool because importing it directly caused ImportError (#1784)
 def parse_bool(expr: Expression) -> Optional[bool]:
     if isinstance(expr, NameExpr):
diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py
index b4a14ddb7..74d33fcbc 100644
--- a/mypy_django_plugin/transformers/fields.py
+++ b/mypy_django_plugin/transformers/fields.py
@@ -7,7 +7,7 @@
 from mypy.maptype import map_instance_to_supertype
 from mypy.nodes import AssignmentStmt, NameExpr, TypeInfo
 from mypy.plugin import FunctionContext
-from mypy.types import AnyType, Instance, ProperType, TypeOfAny, UninhabitedType, UnionType
+from mypy.types import AnyType, Instance, NoneType, ProperType, TypeOfAny, UninhabitedType, UnionType
 from mypy.types import Type as MypyType
 
 from mypy_django_plugin.django.context import DjangoContext
@@ -160,13 +160,13 @@ def set_descriptor_types_for_field(
 
     # bail if either mapped_set_type or mapped_get_type have type Never
     if not (isinstance(mapped_set_type, UninhabitedType) or isinstance(mapped_get_type, UninhabitedType)):
-        # only replace set_type and get_type with mapped types if their original value is Any
-        set_type = helpers.convert_any_to_type(
-            set_type, helpers.make_optional(mapped_set_type) if is_set_nullable or is_nullable else mapped_set_type
-        )
-        get_type = helpers.convert_any_to_type(
-            get_type, helpers.make_optional(mapped_get_type) if is_get_nullable or is_nullable else mapped_get_type
-        )
+        # always replace set_type and get_type with (non-Any) mapped types
+        set_type = helpers.convert_any_to_type(mapped_set_type, set_type)
+        get_type = helpers.convert_any_to_type(mapped_get_type, get_type)
+
+        # the get_type must be optional if the field is nullable
+        if (is_get_nullable or is_nullable) and not (isinstance(get_type, NoneType) or helpers.is_optional(get_type)):
+            ctx.api.fail("Field is nullable but generic get type parameter is not optional", ctx.context)
 
     return helpers.reparametrize_instance(default_return_type, [set_type, get_type])
 
diff --git a/tests/typecheck/fields/test_custom_fields.yml b/tests/typecheck/fields/test_custom_fields.yml
index 87b075867..9216c0b13 100644
--- a/tests/typecheck/fields/test_custom_fields.yml
+++ b/tests/typecheck/fields/test_custom_fields.yml
@@ -7,17 +7,24 @@
         reveal_type(user.my_custom_field2)  # N: Revealed type is "myapp.models.CustomFieldValue"
         reveal_type(user.my_custom_field3)  # N: Revealed type is "builtins.bool"
         reveal_type(user.my_custom_field4)  # N: Revealed type is "myapp.models.CustomFieldValue"
-        reveal_type(user.my_custom_field5)  # N: Revealed type is "Union[myapp.models.CustomFieldValue, None]"
-        # user.my_custom_field6 is incorrectly typed as non-optional
-        # reveal_type(user.my_custom_field6) ## N: Revealed type is "Union[myapp.models.CustomFieldValue, None]"
-        reveal_type(user.my_custom_field7)  # N: Revealed type is "Union[builtins.bool, None]"
-        reveal_type(user.my_custom_field8)  # N: Revealed type is "Union[myapp.models.CustomFieldValue, None]"
+        reveal_type(user.my_custom_field5)  # N: Revealed type is "myapp.models.CustomFieldValue"
+        reveal_type(user.my_custom_field6)  # N: Revealed type is "myapp.models.CustomFieldValue"
+        reveal_type(user.my_custom_field7)  # N: Revealed type is "builtins.bool"
+        reveal_type(user.my_custom_field8)  # N: Revealed type is "myapp.models.CustomFieldValue"
         reveal_type(user.my_custom_field9)  # N: Revealed type is "myapp.models.CustomFieldValue"
-        reveal_type(user.my_custom_field10)  # N: Revealed type is "Union[myapp.models.CustomFieldValue, None]"
-        # Fields that set _pyi_private_set_type or _pyi_private_get_type retain these types
-        reveal_type(user.my_custom_field11)  # N: Revealed type is "builtins.int"
-        reveal_type(user.my_custom_field12)  # N: Revealed type is "builtins.int"
+        reveal_type(user.my_custom_field10)  # N: Revealed type is "builtins.bool"
+        reveal_type(user.my_custom_field11)  # N: Revealed type is "builtins.bool"
+        reveal_type(user.my_custom_field12)  # N: Revealed type is "Union[myapp.models.CustomFieldValue, None]"
+        reveal_type(user.my_custom_field13)  # N: Revealed type is "Union[myapp.models.CustomFieldValue, None]"
+        reveal_type(user.my_custom_field14)  # N: Revealed type is "Union[builtins.bool, None]"
+        reveal_type(user.my_custom_field15)  # N: Revealed type is "None"
     monkeypatch: true
+    out: |
+      myapp/models:31: error: Field is nullable but generic get type parameter is not optional  [misc]
+      myapp/models:32: error: Field is nullable but generic get type parameter is not optional  [misc]
+      myapp/models:33: error: Field is nullable but generic get type parameter is not optional  [misc]
+      myapp/models:34: error: Field is nullable but generic get type parameter is not optional  [misc]
+      myapp/models:35: error: Field is nullable but generic get type parameter is not optional  [misc]
     installed_apps:
         - myapp
     files:
@@ -52,11 +59,20 @@
                     my_custom_field2 = CustomValueField()
                     my_custom_field3 = SingleTypeField[bool]()
                     my_custom_field4 = AdditionalTypeVarField[Union[CustomFieldValue, int], CustomFieldValue, bool]()
+
+                    # test null=True on fields with non-optional generic types throw error
                     my_custom_field5 = GenericField[Union[CustomFieldValue, int], CustomFieldValue](null=True)
                     my_custom_field6 = CustomValueField(null=True)
                     my_custom_field7 = SingleTypeField[bool](null=True)
                     my_custom_field8 = AdditionalTypeVarField[Union[CustomFieldValue, int], CustomFieldValue, bool](null=True)
-                    my_custom_field9 = fields.Field[Union[CustomFieldValue, int], CustomFieldValue]()
-                    my_custom_field10 = fields.Field[Union[CustomFieldValue, int], CustomFieldValue](null=True)
-                    my_custom_field11 = fields.SmallIntegerField[bool, bool]()
-                    my_custom_field12 = CustomSmallIntegerField[bool, bool]()
+                    my_custom_field9 = fields.Field[Union[CustomFieldValue, int], CustomFieldValue](null=True)
+
+                    # test overriding fields that set _pyi_private_set_type or _pyi_private_get_type
+                    my_custom_field10 = fields.SmallIntegerField[bool, bool]()
+                    my_custom_field11 = CustomSmallIntegerField[bool, bool]()
+
+                    # test null=True on fields with non-optional generic types throw no errors
+                    my_custom_field12 = fields.Field[Union[CustomFieldValue, int], Union[CustomFieldValue, None]](null=True)
+                    my_custom_field13 = GenericField[Union[CustomFieldValue, int], Union[CustomFieldValue, None]](null=True)
+                    my_custom_field14 = SingleTypeField[Union[bool, None]](null=True)
+                    my_custom_field15 = fields.Field[None, None](null=True)

From babf96899175e886017b98eb4ce22f848dbdcac7 Mon Sep 17 00:00:00 2001
From: Matthieu Devlin <matt@zumper.com>
Date: Thu, 18 Apr 2024 09:58:35 -0700
Subject: [PATCH 4/4] Fix imports and add field name to error message

---
 mypy_django_plugin/transformers/fields.py     |  8 +++++---
 tests/typecheck/fields/test_custom_fields.yml | 10 +++++-----
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py
index 74d33fcbc..731a102b5 100644
--- a/mypy_django_plugin/transformers/fields.py
+++ b/mypy_django_plugin/transformers/fields.py
@@ -13,7 +13,6 @@
 from mypy_django_plugin.django.context import DjangoContext
 from mypy_django_plugin.exceptions import UnregisteredModelError
 from mypy_django_plugin.lib import fullnames, helpers
-from mypy_django_plugin.lib.fullnames import FIELD_FULLNAME
 from mypy_django_plugin.lib.helpers import parse_bool
 from mypy_django_plugin.transformers import manytomany
 
@@ -154,7 +153,7 @@ def set_descriptor_types_for_field(
     )
 
     # reconcile set and get types with the base field class
-    base_field_type = next(base for base in default_return_type.type.mro if base.fullname == FIELD_FULLNAME)
+    base_field_type = next(base for base in default_return_type.type.mro if base.fullname == fullnames.FIELD_FULLNAME)
     mapped_instance = map_instance_to_supertype(default_return_type, base_field_type)
     mapped_set_type, mapped_get_type = mapped_instance.args
 
@@ -166,7 +165,10 @@ def set_descriptor_types_for_field(
 
         # the get_type must be optional if the field is nullable
         if (is_get_nullable or is_nullable) and not (isinstance(get_type, NoneType) or helpers.is_optional(get_type)):
-            ctx.api.fail("Field is nullable but generic get type parameter is not optional", ctx.context)
+            ctx.api.fail(
+                f"{default_return_type.type.name} is nullable but its generic get type parameter is not optional",
+                ctx.context,
+            )
 
     return helpers.reparametrize_instance(default_return_type, [set_type, get_type])
 
diff --git a/tests/typecheck/fields/test_custom_fields.yml b/tests/typecheck/fields/test_custom_fields.yml
index 9216c0b13..14d175348 100644
--- a/tests/typecheck/fields/test_custom_fields.yml
+++ b/tests/typecheck/fields/test_custom_fields.yml
@@ -20,11 +20,11 @@
         reveal_type(user.my_custom_field15)  # N: Revealed type is "None"
     monkeypatch: true
     out: |
-      myapp/models:31: error: Field is nullable but generic get type parameter is not optional  [misc]
-      myapp/models:32: error: Field is nullable but generic get type parameter is not optional  [misc]
-      myapp/models:33: error: Field is nullable but generic get type parameter is not optional  [misc]
-      myapp/models:34: error: Field is nullable but generic get type parameter is not optional  [misc]
-      myapp/models:35: error: Field is nullable but generic get type parameter is not optional  [misc]
+      myapp/models:31: error: GenericField is nullable but its generic get type parameter is not optional  [misc]
+      myapp/models:32: error: CustomValueField is nullable but its generic get type parameter is not optional  [misc]
+      myapp/models:33: error: SingleTypeField is nullable but its generic get type parameter is not optional  [misc]
+      myapp/models:34: error: AdditionalTypeVarField is nullable but its generic get type parameter is not optional  [misc]
+      myapp/models:35: error: Field is nullable but its generic get type parameter is not optional  [misc]
     installed_apps:
         - myapp
     files: