Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions effectful/internals/unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,60 @@
Substitutions = collections.abc.Mapping[TypeVariable, TypeExpressions]


def _has_typing_self(typ) -> bool:
"""Check if typing.Self appears anywhere in a type expression."""
if typ is typing.Self:
return True
if isinstance(typ, inspect.Signature):
if _has_typing_self(typ.return_annotation):
return True
for p in typ.parameters.values():
if _has_typing_self(p.annotation):
return True
return False
if isinstance(typ, list):
return any(_has_typing_self(item) for item in typ)
for arg in typing.get_args(typ):
if _has_typing_self(arg):
return True
return False


def _replace_self(typ, self_tv: typing.TypeVar):
"""Replace all occurrences of typing.Self with the given TypeVar."""
if typ is typing.Self:
return self_tv
elif isinstance(typ, inspect.Signature):
new_params = []
for i, p in enumerate(typ.parameters.values()):
if _has_typing_self(p.annotation):
new_params.append(
p.replace(annotation=_replace_self(p.annotation, self_tv))
)
elif i == 0 and p.annotation is inspect.Parameter.empty:
new_params.append(p.replace(annotation=self_tv))
else:
new_params.append(p)
new_ret = (
_replace_self(typ.return_annotation, self_tv)
if _has_typing_self(typ.return_annotation)
else typ.return_annotation
)
return typ.replace(parameters=new_params, return_annotation=new_ret)
elif isinstance(typ, list):
return [_replace_self(item, self_tv) for item in typ]
args = typing.get_args(typ)
if not args:
return typ
new_args = tuple(_replace_self(a, self_tv) for a in args)
if new_args == args:
return typ
origin = typing.get_origin(typ)
if origin is not None:
return origin[new_args]
return typ


@dataclass
class Box[T]:
"""Boxed types. Prevents confusion between types computed by __type_rule__
Expand Down Expand Up @@ -347,6 +401,12 @@ def _unify_signature(
if typ != subtyp.signature:
raise TypeError(f"Cannot unify {typ} with {subtyp} given {subs}. ")

if _has_typing_self(typ):
self_tv = typing.TypeVar("Self")
typ = _replace_self(typ, self_tv)
subtyp = typ.bind(*subtyp.args, **subtyp.kwargs)
return {**unify(typ, subtyp, subs), typing.Self: self_tv} # type: ignore

for name, param in typ.parameters.items():
if param.annotation is inspect.Parameter.empty:
continue
Expand Down Expand Up @@ -913,6 +973,9 @@ def substitute(typ, subs: Substitutions) -> TypeExpressions:
>>> substitute(int, {T: str})
<class 'int'>
"""
if typing.Self in subs and _has_typing_self(typ):
return substitute(_replace_self(typ, subs[typing.Self]), subs)

if isinstance(typ, typing.TypeVar | typing.ParamSpec | typing.TypeVarTuple):
return substitute(subs[typ], subs) if typ in subs else typ
elif isinstance(typ, list | tuple):
Expand Down
3 changes: 2 additions & 1 deletion effectful/ops/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def __type_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> type[V]:

"""
from effectful.internals.unification import (
_has_typing_self,
freetypevars,
nested_type,
substitute,
Expand All @@ -394,7 +395,7 @@ def __type_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> type[V]:
return typing.cast(type[V], object)
elif return_anno is None:
return type(None) # type: ignore
elif not freetypevars(return_anno):
elif not freetypevars(return_anno) and not _has_typing_self(return_anno):
return return_anno

type_args = tuple(nested_type(a).value for a in args)
Expand Down
199 changes: 199 additions & 0 deletions tests/test_internals_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from effectful.internals.unification import (
Box,
_has_typing_self,
canonicalize,
freetypevars,
nested_type,
Expand Down Expand Up @@ -137,6 +138,8 @@ class GenericClass[T]:
(int, {T: str}, int),
(str, {}, str),
(list[int], {T: str}, list[int]),
# typing.Self with no binding passes through unchanged
(typing.Self, {}, typing.Self),
# Single TypeVar in generic
(list[T], {T: int}, list[int]),
(set[T], {T: str}, set[str]),
Expand Down Expand Up @@ -488,6 +491,44 @@ def variadic_kwargs_func[T](**kwargs: T) -> T: # Variadic kwargs not supported
return next(iter(kwargs.values()))


class _Foo:
def return_self(self) -> typing.Self:
return self

def return_list_self(self) -> list[typing.Self]:
return [self]

def return_self_or_none(self) -> typing.Self | None:
return self

def annotated_self(self: typing.Self) -> typing.Self:
return self

def takes_other(self, other: typing.Self) -> typing.Self:
return self

def mixed_with_typevar[T](self, x: T) -> tuple[typing.Self, T]:
return (self, x)

def return_dict_self(self) -> dict[str, typing.Self]:
return {"me": self}

def return_callable_self(self) -> collections.abc.Callable[[typing.Self], int]:
return id

def return_type_self(self) -> type[typing.Self]:
return type(self)

@classmethod
def from_config(cls, config: int) -> typing.Self:
return cls()


class _Bar:
def return_self(self) -> typing.Self:
return self


@pytest.mark.parametrize(
"func,args,kwargs,expected_return_type",
[
Expand Down Expand Up @@ -565,6 +606,32 @@ def variadic_kwargs_func[T](**kwargs: T) -> T: # Variadic kwargs not supported
(variadic_args_func, (int, int), {}, int),
(variadic_kwargs_func, (), {"x": int}, int),
(variadic_kwargs_func, (), {"x": int, "y": int}, int),
# typing.Self return types (methods)
(_Foo.return_self, (int,), {}, int),
(_Foo.return_self, (str,), {}, str),
(_Foo.return_self, (_Foo,), {}, _Foo),
(_Foo.return_self, (_Bar,), {}, _Bar),
(_Bar.return_self, (_Bar,), {}, _Bar),
(_Foo.annotated_self, (_Foo,), {}, _Foo),
(_Foo.return_list_self, (int,), {}, list[int]),
(_Foo.return_list_self, (_Foo,), {}, list[_Foo]),
(_Foo.return_self_or_none, (int,), {}, int | None),
(_Foo.return_self_or_none, (_Foo,), {}, _Foo | None),
# Self as a non-self parameter
(_Foo.takes_other, (int, int), {}, int),
(_Foo.takes_other, (_Foo, _Foo), {}, _Foo),
# Self mixed with other TypeVars
(_Foo.mixed_with_typevar, (int, str), {}, tuple[int, str]),
(_Foo.mixed_with_typevar, (_Foo, list[int]), {}, tuple[_Foo, list[int]]),
# Self in dict[str, Self]
(_Foo.return_dict_self, (int,), {}, dict[str, int]),
(_Foo.return_dict_self, (_Foo,), {}, dict[str, _Foo]),
# Self inside Callable[[Self], int]
(_Foo.return_callable_self, (int,), {}, collections.abc.Callable[[int], int]),
(_Foo.return_callable_self, (_Foo,), {}, collections.abc.Callable[[_Foo], int]),
# type[Self]
(_Foo.return_type_self, (int,), {}, type[int]),
(_Foo.return_type_self, (_Foo,), {}, type[_Foo]),
],
)
def test_infer_return_type_success(
Expand Down Expand Up @@ -1698,3 +1765,135 @@ def test_binary_on_sequence_elements(f, seq, index1, index2):
),
collections.abc.Mapping,
)


# ============================================================
# typing.Self resolution tests
# ============================================================


# --- _has_typing_self ---


@pytest.mark.parametrize(
"typ,expected",
[
(typing.Self, True),
(list[typing.Self], True), # type: ignore[misc]
(typing.Self | None, True),
(dict[str, typing.Self], True), # type: ignore[misc]
(collections.abc.Callable[[typing.Self], int], True),
(type[typing.Self], True),
(int, False),
(list[int], False),
(T, False),
(list[T], False),
],
)
def test_has_typing_self(typ, expected):
assert _has_typing_self(typ) == expected


# --- chaining two signatures with Self from different "classes" ---


def test_chained_self_signatures():
"""Two unify calls sharing subs must not conflate Self."""
sig_a = inspect.signature(_Foo.return_self)
sig_b = inspect.signature(_Bar.return_self)

subs = unify(sig_a, sig_a.bind(_Foo))
assert canonicalize(substitute(sig_a.return_annotation, subs)) == _Foo

# Chaining: second unify with shared subs must not break
subs2 = unify(sig_b, sig_b.bind(_Bar), subs)
assert canonicalize(substitute(sig_b.return_annotation, subs2)) == _Bar


# --- classmethod with Self: cls is stripped, Self stays unresolved ---


def test_classmethod_self_not_resolved():
"""Classmethod Self stays unresolved when cls is stripped.

inspect.signature strips `cls`, so `from_config(config: int) -> Self`
has no unannotated first parameter. The Self TypeVar is created but
nothing binds it, so it remains free in the substitution result.
"""
sig = inspect.signature(_Foo.from_config)
subs = unify(sig, sig.bind(int))
result = substitute(sig.return_annotation, subs)
# Self was replaced with a TypeVar, but nothing bound it.
assert isinstance(result, typing.TypeVar)
assert result.__name__ == "Self"


# --- composition tests with Self ---


@pytest.mark.parametrize("obj_type", [_Foo, _Bar, int, str])
def test_infer_self_composition_1(obj_type):
"""Compose return_list_self -> get_first, verify matches return_self.

Step 1: return_list_self(obj) -> list[Self] (Self method)
Step 2: get_first(list[T]) -> T (generic function)
Direct: return_self(obj) -> Self
"""
sig1 = inspect.signature(_Foo.return_list_self)
sig2 = inspect.signature(get_first)
sig_direct = inspect.signature(_Foo.return_self)

# Step 1: infer list[Self] with Self bound to obj_type
inferred_type1 = substitute(
sig1.return_annotation,
unify(sig1, sig1.bind(obj_type)),
)

# Step 2: get_first(list[obj_type]) -> obj_type
inferred_type2 = substitute(
sig2.return_annotation,
unify(sig2, sig2.bind(nested_type(Box(inferred_type1)).value)),
)

# Direct: return_self(obj_type) -> obj_type
inferred_direct = substitute(
sig_direct.return_annotation,
unify(sig_direct, sig_direct.bind(obj_type)),
)

# The composed inference should match the direct inference
assert isinstance(unify(inferred_type2, inferred_direct), collections.abc.Mapping)


@pytest.mark.parametrize("obj_type", [_Foo, _Bar, int, str])
def test_infer_self_composition_2(obj_type):
"""Compose identity -> return_list_self, verify matches wrap_in_list.

Step 1: identity(x: T) -> T (generic function)
Step 2: return_list_self(self) -> list[Self] (Self method)
Direct: wrap_in_list(x: T) -> list[T]
"""
sig1 = inspect.signature(identity)
sig2 = inspect.signature(_Foo.return_list_self)
sig_direct = inspect.signature(wrap_in_list)

# Step 1: identity(obj_type) -> obj_type
inferred_type1 = substitute(
sig1.return_annotation,
unify(sig1, sig1.bind(obj_type)),
)

# Step 2: return_list_self(obj_type) -> list[obj_type]
inferred_type2 = substitute(
sig2.return_annotation,
unify(sig2, sig2.bind(nested_type(Box(inferred_type1)).value)),
)

# Direct: wrap_in_list(obj_type) -> list[obj_type]
inferred_direct = substitute(
sig_direct.return_annotation,
unify(sig_direct, sig_direct.bind(obj_type)),
)

# The composed inference should match the direct inference
assert isinstance(unify(inferred_type2, inferred_direct), collections.abc.Mapping)
Loading
Loading