Skip to content

Commit 1136e55

Browse files
authored
feat: Add support for singledispatch/singledispatchmethod
1 parent 7cf333d commit 1136e55

File tree

3 files changed

+151
-1
lines changed

3 files changed

+151
-1
lines changed

flake8_type_checking/checker.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,66 @@ def handle_fastapi_decorator(self, node: Union[AsyncFunctionDef, FunctionDef]) -
582582
self.visit(node.args.vararg.annotation)
583583

584584

585+
class FunctoolsSingledispatchMixin:
586+
"""
587+
Contains the necessary logic for `functools.singledispatch` support.
588+
589+
`functools.singledispatch` and `functools.singledispatchmethod` require
590+
runtime access to all annotations.
591+
592+
```python
593+
from functools import singledispatch
594+
595+
from mylib import Foo
596+
597+
@singledispatch
598+
def foo(arg: Foo) -> str:
599+
return arg.name
600+
```
601+
602+
Since the only use of `Foo` is within an annotation, we would usually emit
603+
a TC003 for the `mylib` import. But since `singledispatch` requires runtime
604+
access to `Foo`, that would be a false positive.
605+
"""
606+
607+
if TYPE_CHECKING:
608+
609+
def in_type_checking_block(self, lineno: int, col_offset: int) -> bool: # noqa: D102
610+
...
611+
612+
def lookup_full_name(self, node: ast.AST) -> str | None: # noqa: D102
613+
...
614+
615+
def visit(self, node: ast.AST) -> ast.AST: # noqa: D102
616+
...
617+
618+
def visit_FunctionDef(self, node: FunctionDef) -> None:
619+
"""Remove and map function arguments and returns."""
620+
super().visit_FunctionDef(node) # type: ignore[misc]
621+
if self.has_singledispatch_decorator(node):
622+
self.handle_singledispatch_decorator(node)
623+
624+
def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
625+
"""Remove and map function arguments and returns."""
626+
super().visit_AsyncFunctionDef(node) # type: ignore[misc]
627+
if self.has_singledispatch_decorator(node):
628+
self.handle_singledispatch_decorator(node)
629+
630+
def has_singledispatch_decorator(self, node: FunctionDef | AsyncFunctionDef) -> bool:
631+
"""Determine whether this function is decorated with `functools.singledispatch`."""
632+
return any(
633+
self.lookup_full_name(decorator_node) in ('functools.singledispatch', 'functools.singledispatchmethod')
634+
for decorator_node in node.decorator_list
635+
)
636+
637+
def handle_singledispatch_decorator(self, node: FunctionDef | AsyncFunctionDef) -> None:
638+
"""Walk all the annotations to register them as runtime uses."""
639+
for path in [node.args.args, node.args.kwonlyargs, node.args.posonlyargs]:
640+
for argument in path:
641+
if hasattr(argument, 'annotation') and argument.annotation:
642+
self.visit(argument.annotation)
643+
644+
585645
@dataclass
586646
class ImportName:
587647
"""DTO for representing an import in different string-formats."""
@@ -949,7 +1009,14 @@ def visit_annotated_value(self, node: ast.expr) -> None:
9491009

9501010

9511011
class ImportVisitor(
952-
DunderAllMixin, AttrsMixin, InjectorMixin, FastAPIMixin, PydanticMixin, SQLAlchemyMixin, ast.NodeVisitor
1012+
DunderAllMixin,
1013+
FunctoolsSingledispatchMixin,
1014+
AttrsMixin,
1015+
InjectorMixin,
1016+
FastAPIMixin,
1017+
PydanticMixin,
1018+
SQLAlchemyMixin,
1019+
ast.NodeVisitor,
9531020
):
9541021
"""Map all imports outside of type-checking blocks."""
9551022

tests/test_tc001_to_tc003.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,47 @@ def get_tc_001_to_003_tests(import_: str, ERROR: str) -> L:
120120
(f'import {import_}\ntype x = {import_}', {f"1:0 {ERROR.format(module=f'{import_}')}"})
121121
)
122122

123+
# Imports used for `functools.singledispatch`. None of these should generate errors.
124+
used_for_singledispatch: L = [
125+
(
126+
textwrap.dedent(f'''
127+
import functools
128+
129+
from {import_} import Dict, Any
130+
131+
@functools.singledispatch
132+
def foo(arg: Dict[str, Any]) -> Any:
133+
return 1
134+
'''),
135+
set(),
136+
),
137+
(
138+
textwrap.dedent(f'''
139+
from functools import singledispatch
140+
141+
from {import_} import Dict, Any
142+
143+
@singledispatch
144+
def foo(arg: Dict[str, Any]) -> Any:
145+
return 1
146+
'''),
147+
set(),
148+
),
149+
(
150+
textwrap.dedent(f'''
151+
from functools import singledispatchmethod
152+
153+
from {import_} import Dict, Any
154+
155+
class Foo:
156+
@singledispatchmethod
157+
def foo(self, arg: Dict[str, Any]) -> Any:
158+
return 1
159+
'''),
160+
set(),
161+
),
162+
]
163+
123164
other_useful_test_cases: L = [
124165
(
125166
textwrap.dedent(f'''
@@ -237,6 +278,7 @@ class Migration:
237278
*used_for_arg_annotations_only,
238279
*used_for_return_annotations_only,
239280
*used_for_type_alias_only,
281+
*used_for_singledispatch,
240282
*other_useful_test_cases,
241283
]
242284

tests/test_tc004.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,47 @@ class X:
227227
'3:0 ' + TC004.format(module='z'),
228228
},
229229
),
230+
# functools.singledispatch
231+
(
232+
textwrap.dedent("""
233+
import functools
234+
235+
if TYPE_CHECKING:
236+
from foo import FooType
237+
238+
@functools.singledispatch
239+
def foo(arg: FooType) -> int:
240+
return 1
241+
"""),
242+
{'5:0 ' + TC004.format(module='FooType')},
243+
),
244+
(
245+
textwrap.dedent("""
246+
from functools import singledispatch
247+
248+
if TYPE_CHECKING:
249+
from foo import FooType
250+
251+
@functools.singledispatch
252+
def foo(arg: FooType) -> int:
253+
return 1
254+
"""),
255+
{'5:0 ' + TC004.format(module='FooType')},
256+
),
257+
(
258+
textwrap.dedent("""
259+
from functools import singledispatchmethod
260+
261+
if TYPE_CHECKING:
262+
from foo import FooType
263+
264+
class Foo:
265+
@functools.singledispatch
266+
def foo(self, arg: FooType) -> int:
267+
return 1
268+
"""),
269+
{'5:0 ' + TC004.format(module='FooType')},
270+
),
230271
]
231272

232273

0 commit comments

Comments
 (0)