Skip to content

Assorted niche optimizations #19587

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,7 +915,7 @@ def deserialize(cls, data: JsonDict) -> FuncDef:
# NOTE: ret.info is set in the fixup phase.
ret.arg_names = data["arg_names"]
ret.original_first_arg = data.get("original_first_arg")
ret.arg_kinds = [ArgKind(x) for x in data["arg_kinds"]]
ret.arg_kinds = [ARG_KINDS[x] for x in data["arg_kinds"]]
ret.abstract_status = data["abstract_status"]
ret.dataclass_transform_spec = (
DataclassTransformSpec.deserialize(data["dataclass_transform_spec"])
Expand Down Expand Up @@ -2013,6 +2013,8 @@ def is_star(self) -> bool:
ARG_STAR2: Final = ArgKind.ARG_STAR2
ARG_NAMED_OPT: Final = ArgKind.ARG_NAMED_OPT

ARG_KINDS: Final = (ARG_POS, ARG_OPT, ARG_STAR, ARG_NAMED, ARG_STAR2, ARG_NAMED_OPT)


class CallExpr(Expression):
"""Call expression.
Expand Down Expand Up @@ -3488,6 +3490,8 @@ def update_tuple_type(self, typ: mypy.types.TupleType) -> None:
self.special_alias = alias
else:
self.special_alias.target = alias.target
# Invalidate recursive status cache in case it was previously set.
self.special_alias._is_recursive = None

def update_typeddict_type(self, typ: mypy.types.TypedDictType) -> None:
"""Update typeddict_type and special_alias as needed."""
Expand All @@ -3497,6 +3501,8 @@ def update_typeddict_type(self, typ: mypy.types.TypedDictType) -> None:
self.special_alias = alias
else:
self.special_alias.target = alias.target
# Invalidate recursive status cache in case it was previously set.
self.special_alias._is_recursive = None

def __str__(self) -> str:
"""Return a string representation of the type.
Expand Down
2 changes: 2 additions & 0 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5633,6 +5633,8 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None:
existing.node.target = res
existing.node.alias_tvars = alias_tvars
updated = True
# Invalidate recursive status cache in case it was previously set.
existing.node._is_recursive = None
else:
# Otherwise just replace existing placeholder with type alias.
existing.node = alias_node
Expand Down
14 changes: 8 additions & 6 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
flatten_nested_unions,
get_proper_type,
get_proper_types,
remove_dups,
)
from mypy.typetraverser import TypeTraverserVisitor
from mypy.typevars import fill_typevars
Expand Down Expand Up @@ -995,7 +996,7 @@ def is_singleton_type(typ: Type) -> bool:
return typ.is_singleton_type()


def try_expanding_sum_type_to_union(typ: Type, target_fullname: str) -> ProperType:
def try_expanding_sum_type_to_union(typ: Type, target_fullname: str) -> Type:
"""Attempts to recursively expand any enum Instances with the given target_fullname
into a Union of all of its component LiteralTypes.

Expand All @@ -1017,21 +1018,22 @@ class Status(Enum):
typ = get_proper_type(typ)

if isinstance(typ, UnionType):
# Non-empty enums cannot subclass each other so simply removing duplicates is enough.
items = [
try_expanding_sum_type_to_union(item, target_fullname) for item in typ.relevant_items()
try_expanding_sum_type_to_union(item, target_fullname)
for item in remove_dups(flatten_nested_unions(typ.relevant_items()))
]
return make_simplified_union(items, contract_literals=False)
return UnionType.make_union(items)

if isinstance(typ, Instance) and typ.type.fullname == target_fullname:
if typ.type.fullname == "builtins.bool":
items = [LiteralType(True, typ), LiteralType(False, typ)]
return make_simplified_union(items, contract_literals=False)
return UnionType([LiteralType(True, typ), LiteralType(False, typ)])

if typ.type.is_enum:
items = [LiteralType(name, typ) for name in typ.type.enum_members]
if not items:
return typ
return make_simplified_union(items, contract_literals=False)
return UnionType.make_union(items)

return typ

Expand Down
25 changes: 13 additions & 12 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import mypy.nodes
from mypy.bogus_type import Bogus
from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, INVARIANT, ArgKind, FakeInfo, SymbolNode
from mypy.nodes import ARG_KINDS, ARG_POS, ARG_STAR, ARG_STAR2, INVARIANT, ArgKind, SymbolNode
from mypy.options import Options
from mypy.state import state
from mypy.util import IdMapper
Expand Down Expand Up @@ -538,6 +538,10 @@ def __repr__(self) -> str:
return self.raw_id.__repr__()

def __eq__(self, other: object) -> bool:
# Although this call is not expensive (like UnionType or TypedDictType),
# most of the time we get the same object here, so add a fast path.
if self is other:
return True
return (
isinstance(other, TypeVarId)
and self.raw_id == other.raw_id
Expand Down Expand Up @@ -1780,7 +1784,9 @@ def deserialize(cls, data: JsonDict) -> Parameters:
assert data[".class"] == "Parameters"
return Parameters(
[deserialize_type(t) for t in data["arg_types"]],
[ArgKind(x) for x in data["arg_kinds"]],
# This is a micro-optimization until mypyc gets dedicated enum support. Otherwise,
# we would spend ~20% of types deserialization time in Enum.__call__().
[ARG_KINDS[x] for x in data["arg_kinds"]],
data["arg_names"],
variables=[cast(TypeVarLikeType, deserialize_type(v)) for v in data["variables"]],
imprecise_arg_kinds=data["imprecise_arg_kinds"],
Expand All @@ -1797,7 +1803,7 @@ def __hash__(self) -> int:
)

def __eq__(self, other: object) -> bool:
if isinstance(other, (Parameters, CallableType)):
if isinstance(other, Parameters):
return (
self.arg_types == other.arg_types
and self.arg_names == other.arg_names
Expand Down Expand Up @@ -2210,15 +2216,9 @@ def with_normalized_var_args(self) -> Self:
)

def __hash__(self) -> int:
# self.is_type_obj() will fail if self.fallback.type is a FakeInfo
if isinstance(self.fallback.type, FakeInfo):
is_type_obj = 2
else:
is_type_obj = self.is_type_obj()
return hash(
(
self.ret_type,
is_type_obj,
self.is_ellipsis_args,
self.name,
tuple(self.arg_types),
Expand All @@ -2236,7 +2236,6 @@ def __eq__(self, other: object) -> bool:
and self.arg_names == other.arg_names
and self.arg_kinds == other.arg_kinds
and self.name == other.name
and self.is_type_obj() == other.is_type_obj()
and self.is_ellipsis_args == other.is_ellipsis_args
and self.type_guard == other.type_guard
and self.type_is == other.type_is
Expand Down Expand Up @@ -2271,10 +2270,10 @@ def serialize(self) -> JsonDict:
@classmethod
def deserialize(cls, data: JsonDict) -> CallableType:
assert data[".class"] == "CallableType"
# TODO: Set definition to the containing SymbolNode?
# The .definition link is set in fixup.py.
return CallableType(
[deserialize_type(t) for t in data["arg_types"]],
[ArgKind(x) for x in data["arg_kinds"]],
[ARG_KINDS[x] for x in data["arg_kinds"]],
data["arg_names"],
deserialize_type(data["ret_type"]),
Instance.deserialize(data["fallback"]),
Expand Down Expand Up @@ -2931,6 +2930,8 @@ def __hash__(self) -> int:
def __eq__(self, other: object) -> bool:
if not isinstance(other, UnionType):
return NotImplemented
if self is other:
return True
return frozenset(self.items) == frozenset(other.items)

@overload
Expand Down
Loading