Skip to content

Commit dc19495

Browse files
Daverballsondrelg
authored andcommitted
Improves robustness of detecting type checking only declarations
1 parent 41cb96c commit dc19495

File tree

3 files changed

+50
-10
lines changed

3 files changed

+50
-10
lines changed

flake8_type_checking/checker.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ast import Index, literal_eval
88
from contextlib import suppress
99
from dataclasses import dataclass
10+
from itertools import chain
1011
from pathlib import Path
1112
from typing import TYPE_CHECKING, Literal, NamedTuple, cast
1213

@@ -17,6 +18,7 @@
1718
ATTRIBUTE_PROPERTY,
1819
ATTRS_DECORATORS,
1920
ATTRS_IMPORTS,
21+
GLOBAL_PROPERTY,
2022
NAME_RE,
2123
TC001,
2224
TC002,
@@ -30,7 +32,6 @@
3032
TC101,
3133
TC200,
3234
TC201,
33-
TOP_LEVEL_PROPERTY,
3435
py38,
3536
)
3637

@@ -584,8 +585,44 @@ def is_true_when_type_checking(self, node: ast.AST) -> bool | Literal['TYPE_CHEC
584585
return 'TYPE_CHECKING'
585586
return False
586587

588+
def visit_Module(self, node: ast.Module) -> ast.Module:
589+
"""
590+
Mark global statments.
591+
592+
We propagate this marking when visiting control flow nodes, that don't affect
593+
scope, such as if/else, try/except. Although for simplicity we don't handle
594+
quite all the possible cases, since we're only interested in type checking blocks
595+
and it's not realistic to encounter these for example inside a TryStar/With/Match.
596+
597+
If we're serious about handling all the cases it would probably make more sense
598+
to override generic_visit to propagate this property for a sequence of node types
599+
and attributes that contain the statements that should propagate global scope.
600+
"""
601+
for stmt in node.body:
602+
setattr(stmt, GLOBAL_PROPERTY, True)
603+
604+
self.generic_visit(node)
605+
return node
606+
607+
def visit_Try(self, node: ast.Try) -> ast.Try:
608+
"""Propagate global statements."""
609+
if getattr(node, GLOBAL_PROPERTY, False):
610+
for stmt in chain(node.body, (s for h in node.handlers for s in h.body), node.orelse, node.finalbody):
611+
setattr(stmt, GLOBAL_PROPERTY, True)
612+
613+
self.generic_visit(node)
614+
return node
615+
587616
def visit_If(self, node: ast.If) -> Any:
588-
"""Look for a TYPE_CHECKING block."""
617+
"""
618+
Look for a TYPE_CHECKING block.
619+
620+
Also recursively propagate global, since if/else does not affect scope.
621+
"""
622+
if getattr(node, GLOBAL_PROPERTY, False):
623+
for stmt in chain(node.body, getattr(node, 'orelse', ()) or ()):
624+
setattr(stmt, GLOBAL_PROPERTY, True)
625+
589626
type_checking_condition = self.is_true_when_type_checking(node.test) == 'TYPE_CHECKING'
590627

591628
# If it is, note down the line-number-range where the type-checking block exists
@@ -598,10 +635,6 @@ def visit_If(self, node: ast.If) -> Any:
598635
# first element in the else block - 1
599636
start_of_else_block = node.orelse[0].lineno - 1
600637

601-
# mark all the top-level statements
602-
for stmt in node.body:
603-
setattr(stmt, TOP_LEVEL_PROPERTY, True)
604-
605638
# Check for TC005 errors.
606639
if ((node.end_lineno or node.lineno) - node.lineno == 1) and (
607640
len(node.body) == 1 and isinstance(node.body[0], ast.Pass)
@@ -834,7 +867,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
834867
):
835868
self.add_annotation(node.value, 'alias')
836869

837-
if getattr(node, TOP_LEVEL_PROPERTY, False):
870+
if getattr(node, GLOBAL_PROPERTY, False) and self.in_type_checking_block(node.lineno, node.col_offset):
838871
self.type_checking_block_declarations.add(node.target.id)
839872

840873
# if it wasn't a TypeAlias we need to visit the value expression
@@ -849,9 +882,10 @@ def visit_Assign(self, node: ast.Assign) -> ast.Assign:
849882
target and it should be an `ast.Name`.
850883
"""
851884
if (
852-
getattr(node, TOP_LEVEL_PROPERTY, False)
885+
getattr(node, GLOBAL_PROPERTY, False)
853886
and len(node.targets) == 1
854887
and isinstance(node.targets[0], ast.Name)
888+
and self.in_type_checking_block(node.lineno, node.col_offset)
855889
):
856890
self.type_checking_block_declarations.add(node.targets[0].id)
857891

@@ -872,7 +906,7 @@ def visit_TypeAlias(self, node: ast.TypeAlias) -> None:
872906
"""
873907
self.add_annotation(node.value, 'new-alias')
874908

875-
if getattr(node, TOP_LEVEL_PROPERTY, False):
909+
if getattr(node, GLOBAL_PROPERTY, False) and self.in_type_checking_block(node.lineno, node.col_offset):
876910
self.type_checking_block_declarations.add(node.name.id)
877911

878912
def register_function_ranges(self, node: Union[FunctionDef, AsyncFunctionDef]) -> None:

flake8_type_checking/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
ATTRIBUTE_PROPERTY = '_flake8-type-checking__parent'
77
ANNOTATION_PROPERTY = '_flake8-type-checking__is_annotation'
8-
TOP_LEVEL_PROPERTY = '_flake8-type-checking__is_top_level'
8+
GLOBAL_PROPERTY = '_flake8-type-checking__is_global'
99

1010
NAME_RE = re.compile(r'(?<![\'"])\b[A-Za-z_]\w*(?![\'"])')
1111

tests/test_tc200.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,12 @@ class FooProtocol(Protocol):
9999
100100
class FooDict(TypedDict):
101101
seq: Sequence[int]
102+
103+
if TYPE_CHECKING:
104+
# this should not count as a type checking global
105+
Bar: int
106+
107+
x: Bar
102108
'''),
103109
{
104110
'9:5 ' + TC200.format(annotation='TypeAlias'),

0 commit comments

Comments
 (0)