Skip to content

Commit ca3e8b3

Browse files
cyyeverpytorchmergebot
authored andcommitted
[1/N] Use TYPE_CHECKING (pytorch#165852)
This PR moves typing imports into the `TYPE_CHECKING` block. Pull Request resolved: pytorch#165852 Approved by: https://github.com/Lucaskabela
1 parent b3bc797 commit ca3e8b3

File tree

6 files changed

+17
-13
lines changed

6 files changed

+17
-13
lines changed

torch/_export/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424

2525

2626
if TYPE_CHECKING:
27+
import sympy
28+
2729
from torch._export.passes.lift_constants_pass import ConstantAttrMap
2830
from torch._ops import OperatorBase
2931
from torch.export import ExportedProgram
@@ -433,8 +435,6 @@ def _check_symint(
433435
def _check_input_constraints_for_graph(
434436
input_placeholders: list[torch.fx.Node], flat_args_with_path, range_constraints
435437
) -> None:
436-
import sympy # noqa: TC002
437-
438438
if len(flat_args_with_path) != len(input_placeholders):
439439
raise RuntimeError(
440440
"Unexpected number of inputs "

torch/ao/nn/qat/dynamic/modules/linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55

66
if TYPE_CHECKING:
7-
from torch.ao.quantization.qconfig import QConfig # noqa: TC004
7+
from torch.ao.quantization.qconfig import QConfig
88

99

1010
__all__ = ["Linear"]

torch/distributed/checkpoint/state_dict_saver.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,12 @@
66
from concurrent.futures import Future
77
from dataclasses import dataclass
88
from enum import Enum
9-
from typing import cast, Optional, Union
9+
from typing import cast, Optional, TYPE_CHECKING, Union
1010
from typing_extensions import deprecated
1111

1212
import torch
1313
import torch.distributed as dist
1414
from torch.distributed._state_dict_utils import STATE_DICT_TYPE
15-
from torch.distributed.checkpoint._async_executor import ( # noqa: TC001
16-
_AsyncCheckpointExecutor,
17-
)
1815
from torch.distributed.checkpoint._async_process_executor import (
1916
_ProcessBasedAsyncCheckpointExecutor,
2017
)
@@ -38,6 +35,10 @@
3835
from .utils import _api_bc_check, _DistWrapper, _profile
3936

4037

38+
if TYPE_CHECKING:
39+
from torch.distributed.checkpoint._async_executor import _AsyncCheckpointExecutor
40+
41+
4142
__all__ = [
4243
"save_state_dict",
4344
"save",

torch/fx/passes/_tensorify_python_scalars.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import logging
44
import os
5-
from typing import Any, Union
5+
from typing import Any, TYPE_CHECKING, Union
66

77
from sympy import Integer, Number, Symbol
88
from sympy.logic.boolalg import BooleanAtom
@@ -13,16 +13,14 @@
1313
from torch._dynamo.symbolic_convert import TensorifyState
1414
from torch._dynamo.utils import get_metrics_context
1515
from torch._prims_common import get_computation_dtype
16-
from torch._subclasses import fake_tensor # noqa: TCH001
1716
from torch._subclasses.fake_tensor import FakeTensor
1817
from torch._utils_internal import justknobs_check
1918
from torch.fx._utils import lazy_format_graph_code
20-
from torch.fx.experimental.symbolic_shapes import ( # noqa: TCH001
19+
from torch.fx.experimental.symbolic_shapes import (
2120
guard_scalar,
2221
has_free_symbols,
2322
ShapeEnv,
2423
)
25-
from torch.fx.graph_module import GraphModule # noqa: TCH001
2624

2725
# TODO: refactor
2826
from torch.fx.passes.runtime_assert import _get_sym_val
@@ -32,6 +30,11 @@
3230
from torch.utils._sympy.symbol import symbol_is_type, SymT
3331

3432

33+
if TYPE_CHECKING:
34+
from torch._subclasses import fake_tensor
35+
from torch.fx.graph_module import GraphModule
36+
37+
3538
__all__: list[str] = []
3639

3740
log = logging.getLogger(__name__)

torch/onnx/_internal/exporter/_torchlib/ops/nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""torch.ops.aten operators under the `core` module."""
22
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index"
33
# pyrefly: ignore-errors
4-
# ruff: noqa: TCH001,TCH002
4+
# ruff: noqa: TC001,TC002
55
# flake8: noqa: B950
66

77
from __future__ import annotations

torch/onnx/_internal/fx/type_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
if TYPE_CHECKING:
17-
import onnx.defs # noqa: TCH004
17+
import onnx.defs
1818

1919

2020
# Enable both TorchScriptTensor and torch.Tensor to be tested

0 commit comments

Comments
 (0)