Skip to content

Commit ce3b396

Browse files
committed
Implement type equality for recursive types, fixes #161
1 parent cda5130 commit ce3b396

File tree

4 files changed

+119
-20
lines changed

4 files changed

+119
-20
lines changed

crates/zuban_python/src/file/inference.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1369,7 +1369,7 @@ impl<'db, 'file> Inference<'db, 'file, '_> {
13691369
&& self
13701370
.infer_name_of_definition_by_index(first_index)
13711371
.as_cow_type(i_s)
1372-
.is_equal_type(i_s.db, &value.as_cow_type(i_s))
1372+
.is_equal_type(i_s.db, None, &value.as_cow_type(i_s))
13731373
{
13741374
} else {
13751375
let mut node_ref = name_def_ref;

crates/zuban_python/src/type_helpers/typing.rs

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::{
77
format_data::FormatData,
88
inference_state::InferenceState,
99
inferred::Inferred,
10-
matching::{CouldBeALiteral, Generic, Generics, ResultContext},
10+
matching::{CheckedTypeRecursion, CouldBeALiteral, Generic, Generics, ResultContext},
1111
type_::{
1212
CallableParams, ClassGenerics, GenericClass, ParamType, StarParamType, StarStarParamType,
1313
TupleArgs, Type, TypeVarKind, TypedDict, TypedDictGenerics,
@@ -69,7 +69,7 @@ pub(crate) fn execute_cast<'db>(i_s: &InferenceState<'db, '_>, args: &dyn Args<'
6969
{
7070
let t_in = actual.as_cow_type(i_s);
7171
let t_out = result.as_cow_type(i_s);
72-
if t_in.is_equal_type(i_s.db, &t_out) && !(t_in.is_any()) {
72+
if t_in.is_equal_type(i_s.db, None, &t_out) && !(t_in.is_any()) {
7373
args.add_issue(
7474
i_s,
7575
IssueKind::RedundantCast {
@@ -243,7 +243,7 @@ pub(crate) fn execute_assert_type<'db>(
243243
return Inferred::new_any_from_error();
244244
};
245245
let second_type = second.as_cow_type(i_s);
246-
if !first_type.is_equal_type(i_s.db, &second_type) {
246+
if !first_type.is_equal_type(i_s.db, None, &second_type) {
247247
let mut format_data = FormatData::new_short(i_s.db);
248248
format_data.hide_implicit_literals = false;
249249
let mut actual = first_type.format(&format_data);
@@ -259,8 +259,13 @@ pub(crate) fn execute_assert_type<'db>(
259259
}
260260

261261
impl Type {
262-
pub fn is_equal_type(&self, db: &Database, other: &Type) -> bool {
263-
let eq = |t1: &Type, t2: &Type| t1.is_equal_type(db, t2);
262+
pub fn is_equal_type(
263+
&self,
264+
db: &Database,
265+
checking_type_recursion: Option<CheckedTypeRecursion>,
266+
other: &Type,
267+
) -> bool {
268+
let eq = |t1: &Type, t2: &Type| t1.is_equal_type(db, checking_type_recursion, t2);
264269
let all_eq =
265270
|ts1: &[Type], ts2: &[Type]| ts1.iter().zip(ts2.iter()).all(|(t1, t2)| eq(t1, t2));
266271
let typed_dict_eq = |td1: &TypedDict, td2: &TypedDict| {
@@ -276,7 +281,8 @@ impl Type {
276281
&& match (&m1.extra_items, &m2.extra_items) {
277282
(None, None) => true,
278283
(Some(t1), Some(t2)) => {
279-
t1.t.is_equal_type(db, &t2.t) && t1.read_only == t2.read_only
284+
t1.t.is_equal_type(db, checking_type_recursion, &t2.t)
285+
&& t1.read_only == t2.read_only
280286
}
281287
_ => false,
282288
}
@@ -387,15 +393,16 @@ impl Type {
387393
}
388394
}
389395
(Type::Type(t1), Type::Type(t2)) => eq(t1, t2),
390-
(Type::RecursiveType(r1), Type::RecursiveType(r2)) => {
391-
r1.link == r2.link
392-
&& r1
393-
.generics
394-
.as_ref()
395-
.zip(r2.generics.as_ref())
396-
.is_none_or(|(g1, g2)| {
397-
matches_generics(Generics::List(g1, None), Generics::List(g2, None))
398-
})
396+
(t1 @ Type::RecursiveType(r1), t2) | (t2, t1 @ Type::RecursiveType(r1)) => {
397+
let checking_type_recursion = CheckedTypeRecursion {
398+
current: (t1, t2),
399+
previous: checking_type_recursion.as_ref(),
400+
};
401+
if checking_type_recursion.is_cycle() {
402+
return true;
403+
}
404+
r1.calculated_type(db)
405+
.is_equal_type(db, Some(checking_type_recursion), t2)
399406
}
400407
(Type::Literal(l1), Type::Literal(l2)) => l1.value(db) == l2.value(db),
401408
(Type::Literal(l), Type::Class(c)) | (Type::Class(c), Type::Literal(l)) => {
@@ -406,12 +413,16 @@ impl Type {
406413
(Type::Never(_), Type::Never(_)) => true,
407414
(Type::Union(u1), Type::Union(u2)) => is_equal_union_or_intersection(
408415
db,
416+
checking_type_recursion,
409417
u1.entries.iter().map(|e| &e.type_),
410418
u2.entries.iter().map(|e| &e.type_),
411419
),
412-
(Type::Intersection(i1), Type::Intersection(i2)) => {
413-
is_equal_union_or_intersection(db, i1.iter_entries(), i2.iter_entries())
414-
}
420+
(Type::Intersection(i1), Type::Intersection(i2)) => is_equal_union_or_intersection(
421+
db,
422+
checking_type_recursion,
423+
i1.iter_entries(),
424+
i2.iter_entries(),
425+
),
415426
(Type::EnumMember(m1), Type::EnumMember(m2)) => {
416427
m1.member_index == m2.member_index && m1.enum_.defined_at == m2.enum_.defined_at
417428
}
@@ -431,6 +442,7 @@ impl Type {
431442

432443
fn is_equal_union_or_intersection<'x>(
433444
db: &Database,
445+
checking_type_recursion: Option<CheckedTypeRecursion>,
434446
ts1: impl ExactSizeIterator<Item = &'x Type>,
435447
ts2: impl ExactSizeIterator<Item = &'x Type>,
436448
) -> bool {
@@ -440,7 +452,7 @@ fn is_equal_union_or_intersection<'x>(
440452
let mut all_second: Vec<_> = ts2.collect();
441453
'outer: for t1 in ts1 {
442454
for (i, t2) in all_second.iter().enumerate() {
443-
if t1.is_equal_type(db, t2) {
455+
if t1.is_equal_type(db, checking_type_recursion, t2) {
444456
all_second.remove(i);
445457
continue 'outer;
446458
}

crates/zuban_python/tests/mypylike/tests/dataclasses.test

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -989,3 +989,79 @@ class BaseArithmeticExpression(BaseExpression):
989989
return Addition()
990990

991991
class Addition[G](BaseArithmeticExpression): ...
992+
993+
[case dataclass_recursive_types_assert_type]
994+
# From GH #161
995+
import dataclasses
996+
from typing import Literal, assert_type, overload
997+
998+
type ArithmeticExpression = int | float | BaseArithmeticExpression
999+
1000+
@dataclasses.dataclass(frozen=True, kw_only=True)
1001+
class BaseExpression[G]:
1002+
pass
1003+
1004+
class BaseComparableExpression[G](BaseExpression[G]):
1005+
pass
1006+
1007+
class BaseNumericExpression[G](BaseComparableExpression[G]):
1008+
pass
1009+
1010+
class BaseArithmeticExpression[G](BaseNumericExpression[G]):
1011+
@overload
1012+
def __add__[S: BaseArithmeticExpression[Literal[False]]](
1013+
self: S,
1014+
other: int,
1015+
) -> Addition[Literal[False], S, int]: ...
1016+
1017+
@overload
1018+
def __add__[S: BaseArithmeticExpression[Literal[False]]](
1019+
self: S,
1020+
other: float,
1021+
) -> Addition[Literal[False], S, float]: ...
1022+
1023+
def __add__(self, other: ArithmeticExpression) -> Addition:
1024+
return Addition(lhs=self, rhs=other) # type: ignore
1025+
1026+
@overload
1027+
def __radd__[S: BaseArithmeticExpression[Literal[False]]](
1028+
self: S,
1029+
other: int,
1030+
) -> Addition[Literal[False], int, S]: ...
1031+
1032+
@overload
1033+
def __radd__[S: BaseArithmeticExpression[Literal[False]]](
1034+
self: S,
1035+
other: float,
1036+
) -> Addition[Literal[False], float, S]: ...
1037+
1038+
def __radd__(self, other: ArithmeticExpression) -> Addition:
1039+
return Addition(lhs=other, rhs=self) # type: ignore
1040+
1041+
class BaseBinaryArithmeticOperation[
1042+
G,
1043+
L: ArithmeticExpression,
1044+
R: ArithmeticExpression,
1045+
](BaseArithmeticExpression[G]):
1046+
lhs: L
1047+
rhs: R
1048+
1049+
class Addition[
1050+
G,
1051+
L: ArithmeticExpression,
1052+
R: ArithmeticExpression,
1053+
](BaseBinaryArithmeticOperation[G, L, R]):
1054+
pass
1055+
1056+
class BaseNamedSymbol(BaseExpression[Literal[False]]):
1057+
name: str
1058+
1059+
class IntegerNamedSymbol(BaseArithmeticExpression[Literal[False]]):
1060+
pass
1061+
1062+
age = IntegerNamedSymbol()
1063+
1064+
result1 = age + 3.7
1065+
assert_type(result1, Addition[Literal[False], IntegerNamedSymbol, float])
1066+
result2 = 3.7 + age
1067+
assert_type(result2, Addition[Literal[False], float, IntegerNamedSymbol])

crates/zuban_python/tests/mypylike/tests/recursive-type-aliases.test

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,14 @@ IncEx: TypeAlias = Union[
200200
def foo(a: IncEx | None) -> None: ...
201201

202202
foo(a={"foo"})
203+
204+
[case recursive_type_alias_assert_type]
205+
from typing import assert_type
206+
X = list['X'] | int
207+
208+
def f(x: X):
209+
assert_type(x, list[X] | int)
210+
assert_type(x, X)
211+
if isinstance(x, list):
212+
assert_type(x[0], list[X] | int)
213+
assert_type(x[0], X)

0 commit comments

Comments
 (0)