77from ast import Index , literal_eval
88from contextlib import suppress
99from dataclasses import dataclass
10+ from itertools import chain
1011from pathlib import Path
1112from typing import TYPE_CHECKING , Literal , NamedTuple , cast
1213
1718 ATTRIBUTE_PROPERTY ,
1819 ATTRS_DECORATORS ,
1920 ATTRS_IMPORTS ,
21+ GLOBAL_PROPERTY ,
2022 NAME_RE ,
2123 TC001 ,
2224 TC002 ,
3032 TC101 ,
3133 TC200 ,
3234 TC201 ,
33- TOP_LEVEL_PROPERTY ,
3435 py38 ,
3536)
3637
@@ -584,8 +585,44 @@ def is_true_when_type_checking(self, node: ast.AST) -> bool | Literal['TYPE_CHEC
584585 return 'TYPE_CHECKING'
585586 return False
586587
588+ def visit_Module (self , node : ast .Module ) -> ast .Module :
589+ """
590+ Mark global statments.
591+
592+ We propagate this marking when visiting control flow nodes, that don't affect
593+ scope, such as if/else, try/except. Although for simplicity we don't handle
594+ quite all the possible cases, since we're only interested in type checking blocks
595+ and it's not realistic to encounter these for example inside a TryStar/With/Match.
596+
597+ If we're serious about handling all the cases it would probably make more sense
598+ to override generic_visit to propagate this property for a sequence of node types
599+ and attributes that contain the statements that should propagate global scope.
600+ """
601+ for stmt in node .body :
602+ setattr (stmt , GLOBAL_PROPERTY , True )
603+
604+ self .generic_visit (node )
605+ return node
606+
607+ def visit_Try (self , node : ast .Try ) -> ast .Try :
608+ """Propagate global statements."""
609+ if getattr (node , GLOBAL_PROPERTY , False ):
610+ for stmt in chain (node .body , (s for h in node .handlers for s in h .body ), node .orelse , node .finalbody ):
611+ setattr (stmt , GLOBAL_PROPERTY , True )
612+
613+ self .generic_visit (node )
614+ return node
615+
587616 def visit_If (self , node : ast .If ) -> Any :
588- """Look for a TYPE_CHECKING block."""
617+ """
618+ Look for a TYPE_CHECKING block.
619+
620+ Also recursively propagate global, since if/else does not affect scope.
621+ """
622+ if getattr (node , GLOBAL_PROPERTY , False ):
623+ for stmt in chain (node .body , getattr (node , 'orelse' , ()) or ()):
624+ setattr (stmt , GLOBAL_PROPERTY , True )
625+
589626 type_checking_condition = self .is_true_when_type_checking (node .test ) == 'TYPE_CHECKING'
590627
591628 # If it is, note down the line-number-range where the type-checking block exists
@@ -598,10 +635,6 @@ def visit_If(self, node: ast.If) -> Any:
598635 # first element in the else block - 1
599636 start_of_else_block = node .orelse [0 ].lineno - 1
600637
601- # mark all the top-level statements
602- for stmt in node .body :
603- setattr (stmt , TOP_LEVEL_PROPERTY , True )
604-
605638 # Check for TC005 errors.
606639 if ((node .end_lineno or node .lineno ) - node .lineno == 1 ) and (
607640 len (node .body ) == 1 and isinstance (node .body [0 ], ast .Pass )
@@ -834,7 +867,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
834867 ):
835868 self .add_annotation (node .value , 'alias' )
836869
837- if getattr (node , TOP_LEVEL_PROPERTY , False ):
870+ if getattr (node , GLOBAL_PROPERTY , False ) and self . in_type_checking_block ( node . lineno , node . col_offset ):
838871 self .type_checking_block_declarations .add (node .target .id )
839872
840873 # if it wasn't a TypeAlias we need to visit the value expression
@@ -849,9 +882,10 @@ def visit_Assign(self, node: ast.Assign) -> ast.Assign:
849882 target and it should be an `ast.Name`.
850883 """
851884 if (
852- getattr (node , TOP_LEVEL_PROPERTY , False )
885+ getattr (node , GLOBAL_PROPERTY , False )
853886 and len (node .targets ) == 1
854887 and isinstance (node .targets [0 ], ast .Name )
888+ and self .in_type_checking_block (node .lineno , node .col_offset )
855889 ):
856890 self .type_checking_block_declarations .add (node .targets [0 ].id )
857891
@@ -872,7 +906,7 @@ def visit_TypeAlias(self, node: ast.TypeAlias) -> None:
872906 """
873907 self .add_annotation (node .value , 'new-alias' )
874908
875- if getattr (node , TOP_LEVEL_PROPERTY , False ):
909+ if getattr (node , GLOBAL_PROPERTY , False ) and self . in_type_checking_block ( node . lineno , node . col_offset ):
876910 self .type_checking_block_declarations .add (node .name .id )
877911
878912 def register_function_ranges (self , node : Union [FunctionDef , AsyncFunctionDef ]) -> None :
0 commit comments