From 422ef32877443ef47f9dcc3cee6747f9574749c9 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 11 Mar 2026 13:46:29 -0400 Subject: [PATCH] Support typing.Self --- effectful/internals/unification.py | 63 +++++++++ effectful/ops/types.py | 3 +- tests/test_internals_unification.py | 199 ++++++++++++++++++++++++++++ tests/test_ops_semantics.py | 84 ++++++++++++ 4 files changed, 348 insertions(+), 1 deletion(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 71d6583f..0e7f27fa 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -94,6 +94,60 @@ Substitutions = collections.abc.Mapping[TypeVariable, TypeExpressions] +def _has_typing_self(typ) -> bool: + """Check if typing.Self appears anywhere in a type expression.""" + if typ is typing.Self: + return True + if isinstance(typ, inspect.Signature): + if _has_typing_self(typ.return_annotation): + return True + for p in typ.parameters.values(): + if _has_typing_self(p.annotation): + return True + return False + if isinstance(typ, list): + return any(_has_typing_self(item) for item in typ) + for arg in typing.get_args(typ): + if _has_typing_self(arg): + return True + return False + + +def _replace_self(typ, self_tv: typing.TypeVar): + """Replace all occurrences of typing.Self with the given TypeVar.""" + if typ is typing.Self: + return self_tv + elif isinstance(typ, inspect.Signature): + new_params = [] + for i, p in enumerate(typ.parameters.values()): + if _has_typing_self(p.annotation): + new_params.append( + p.replace(annotation=_replace_self(p.annotation, self_tv)) + ) + elif i == 0 and p.annotation is inspect.Parameter.empty: + new_params.append(p.replace(annotation=self_tv)) + else: + new_params.append(p) + new_ret = ( + _replace_self(typ.return_annotation, self_tv) + if _has_typing_self(typ.return_annotation) + else typ.return_annotation + ) + return typ.replace(parameters=new_params, return_annotation=new_ret) + elif isinstance(typ, list): + return [_replace_self(item, self_tv) for item in typ] + args = typing.get_args(typ) + if not args: + return typ + new_args = tuple(_replace_self(a, self_tv) for a in args) + if new_args == args: + return typ + origin = typing.get_origin(typ) + if origin is not None: + return origin[new_args] + return typ + + @dataclass class Box[T]: """Boxed types. Prevents confusion between types computed by __type_rule__ @@ -347,6 +401,12 @@ def _unify_signature( if typ != subtyp.signature: raise TypeError(f"Cannot unify {typ} with {subtyp} given {subs}. ") + if _has_typing_self(typ): + self_tv = typing.TypeVar("Self") + typ = _replace_self(typ, self_tv) + subtyp = typ.bind(*subtyp.args, **subtyp.kwargs) + return {**unify(typ, subtyp, subs), typing.Self: self_tv} # type: ignore + for name, param in typ.parameters.items(): if param.annotation is inspect.Parameter.empty: continue @@ -913,6 +973,9 @@ def substitute(typ, subs: Substitutions) -> TypeExpressions: >>> substitute(int, {T: str}) """ + if typing.Self in subs and _has_typing_self(typ): + return substitute(_replace_self(typ, subs[typing.Self]), subs) + if isinstance(typ, typing.TypeVar | typing.ParamSpec | typing.TypeVarTuple): return substitute(subs[typ], subs) if typ in subs else typ elif isinstance(typ, list | tuple): diff --git a/effectful/ops/types.py b/effectful/ops/types.py index 975c66dc..52d891f4 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -380,6 +380,7 @@ def __type_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> type[V]: """ from effectful.internals.unification import ( + _has_typing_self, freetypevars, nested_type, substitute, @@ -394,7 +395,7 @@ def __type_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> type[V]: return typing.cast(type[V], object) elif return_anno is None: return type(None) # type: ignore - elif not freetypevars(return_anno): + elif not freetypevars(return_anno) and not _has_typing_self(return_anno): return return_anno type_args = tuple(nested_type(a).value for a in args) diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 18b158d8..61a1fcab 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -7,6 +7,7 @@ from effectful.internals.unification import ( Box, + _has_typing_self, canonicalize, freetypevars, nested_type, @@ -137,6 +138,8 @@ class GenericClass[T]: (int, {T: str}, int), (str, {}, str), (list[int], {T: str}, list[int]), + # typing.Self with no binding passes through unchanged + (typing.Self, {}, typing.Self), # Single TypeVar in generic (list[T], {T: int}, list[int]), (set[T], {T: str}, set[str]), @@ -488,6 +491,44 @@ def variadic_kwargs_func[T](**kwargs: T) -> T: # Variadic kwargs not supported return next(iter(kwargs.values())) +class _Foo: + def return_self(self) -> typing.Self: + return self + + def return_list_self(self) -> list[typing.Self]: + return [self] + + def return_self_or_none(self) -> typing.Self | None: + return self + + def annotated_self(self: typing.Self) -> typing.Self: + return self + + def takes_other(self, other: typing.Self) -> typing.Self: + return self + + def mixed_with_typevar[T](self, x: T) -> tuple[typing.Self, T]: + return (self, x) + + def return_dict_self(self) -> dict[str, typing.Self]: + return {"me": self} + + def return_callable_self(self) -> collections.abc.Callable[[typing.Self], int]: + return id + + def return_type_self(self) -> type[typing.Self]: + return type(self) + + @classmethod + def from_config(cls, config: int) -> typing.Self: + return cls() + + +class _Bar: + def return_self(self) -> typing.Self: + return self + + @pytest.mark.parametrize( "func,args,kwargs,expected_return_type", [ @@ -565,6 +606,32 @@ def variadic_kwargs_func[T](**kwargs: T) -> T: # Variadic kwargs not supported (variadic_args_func, (int, int), {}, int), (variadic_kwargs_func, (), {"x": int}, int), (variadic_kwargs_func, (), {"x": int, "y": int}, int), + # typing.Self return types (methods) + (_Foo.return_self, (int,), {}, int), + (_Foo.return_self, (str,), {}, str), + (_Foo.return_self, (_Foo,), {}, _Foo), + (_Foo.return_self, (_Bar,), {}, _Bar), + (_Bar.return_self, (_Bar,), {}, _Bar), + (_Foo.annotated_self, (_Foo,), {}, _Foo), + (_Foo.return_list_self, (int,), {}, list[int]), + (_Foo.return_list_self, (_Foo,), {}, list[_Foo]), + (_Foo.return_self_or_none, (int,), {}, int | None), + (_Foo.return_self_or_none, (_Foo,), {}, _Foo | None), + # Self as a non-self parameter + (_Foo.takes_other, (int, int), {}, int), + (_Foo.takes_other, (_Foo, _Foo), {}, _Foo), + # Self mixed with other TypeVars + (_Foo.mixed_with_typevar, (int, str), {}, tuple[int, str]), + (_Foo.mixed_with_typevar, (_Foo, list[int]), {}, tuple[_Foo, list[int]]), + # Self in dict[str, Self] + (_Foo.return_dict_self, (int,), {}, dict[str, int]), + (_Foo.return_dict_self, (_Foo,), {}, dict[str, _Foo]), + # Self inside Callable[[Self], int] + (_Foo.return_callable_self, (int,), {}, collections.abc.Callable[[int], int]), + (_Foo.return_callable_self, (_Foo,), {}, collections.abc.Callable[[_Foo], int]), + # type[Self] + (_Foo.return_type_self, (int,), {}, type[int]), + (_Foo.return_type_self, (_Foo,), {}, type[_Foo]), ], ) def test_infer_return_type_success( @@ -1698,3 +1765,135 @@ def test_binary_on_sequence_elements(f, seq, index1, index2): ), collections.abc.Mapping, ) + + +# ============================================================ +# typing.Self resolution tests +# ============================================================ + + +# --- _has_typing_self --- + + +@pytest.mark.parametrize( + "typ,expected", + [ + (typing.Self, True), + (list[typing.Self], True), # type: ignore[misc] + (typing.Self | None, True), + (dict[str, typing.Self], True), # type: ignore[misc] + (collections.abc.Callable[[typing.Self], int], True), + (type[typing.Self], True), + (int, False), + (list[int], False), + (T, False), + (list[T], False), + ], +) +def test_has_typing_self(typ, expected): + assert _has_typing_self(typ) == expected + + +# --- chaining two signatures with Self from different "classes" --- + + +def test_chained_self_signatures(): + """Two unify calls sharing subs must not conflate Self.""" + sig_a = inspect.signature(_Foo.return_self) + sig_b = inspect.signature(_Bar.return_self) + + subs = unify(sig_a, sig_a.bind(_Foo)) + assert canonicalize(substitute(sig_a.return_annotation, subs)) == _Foo + + # Chaining: second unify with shared subs must not break + subs2 = unify(sig_b, sig_b.bind(_Bar), subs) + assert canonicalize(substitute(sig_b.return_annotation, subs2)) == _Bar + + +# --- classmethod with Self: cls is stripped, Self stays unresolved --- + + +def test_classmethod_self_not_resolved(): + """Classmethod Self stays unresolved when cls is stripped. + + inspect.signature strips `cls`, so `from_config(config: int) -> Self` + has no unannotated first parameter. The Self TypeVar is created but + nothing binds it, so it remains free in the substitution result. + """ + sig = inspect.signature(_Foo.from_config) + subs = unify(sig, sig.bind(int)) + result = substitute(sig.return_annotation, subs) + # Self was replaced with a TypeVar, but nothing bound it. + assert isinstance(result, typing.TypeVar) + assert result.__name__ == "Self" + + +# --- composition tests with Self --- + + +@pytest.mark.parametrize("obj_type", [_Foo, _Bar, int, str]) +def test_infer_self_composition_1(obj_type): + """Compose return_list_self -> get_first, verify matches return_self. + + Step 1: return_list_self(obj) -> list[Self] (Self method) + Step 2: get_first(list[T]) -> T (generic function) + Direct: return_self(obj) -> Self + """ + sig1 = inspect.signature(_Foo.return_list_self) + sig2 = inspect.signature(get_first) + sig_direct = inspect.signature(_Foo.return_self) + + # Step 1: infer list[Self] with Self bound to obj_type + inferred_type1 = substitute( + sig1.return_annotation, + unify(sig1, sig1.bind(obj_type)), + ) + + # Step 2: get_first(list[obj_type]) -> obj_type + inferred_type2 = substitute( + sig2.return_annotation, + unify(sig2, sig2.bind(nested_type(Box(inferred_type1)).value)), + ) + + # Direct: return_self(obj_type) -> obj_type + inferred_direct = substitute( + sig_direct.return_annotation, + unify(sig_direct, sig_direct.bind(obj_type)), + ) + + # The composed inference should match the direct inference + assert isinstance(unify(inferred_type2, inferred_direct), collections.abc.Mapping) + + +@pytest.mark.parametrize("obj_type", [_Foo, _Bar, int, str]) +def test_infer_self_composition_2(obj_type): + """Compose identity -> return_list_self, verify matches wrap_in_list. + + Step 1: identity(x: T) -> T (generic function) + Step 2: return_list_self(self) -> list[Self] (Self method) + Direct: wrap_in_list(x: T) -> list[T] + """ + sig1 = inspect.signature(identity) + sig2 = inspect.signature(_Foo.return_list_self) + sig_direct = inspect.signature(wrap_in_list) + + # Step 1: identity(obj_type) -> obj_type + inferred_type1 = substitute( + sig1.return_annotation, + unify(sig1, sig1.bind(obj_type)), + ) + + # Step 2: return_list_self(obj_type) -> list[obj_type] + inferred_type2 = substitute( + sig2.return_annotation, + unify(sig2, sig2.bind(nested_type(Box(inferred_type1)).value)), + ) + + # Direct: wrap_in_list(obj_type) -> list[obj_type] + inferred_direct = substitute( + sig_direct.return_annotation, + unify(sig_direct, sig_direct.bind(obj_type)), + ) + + # The composed inference should match the direct inference + assert isinstance(unify(inferred_type2, inferred_direct), collections.abc.Mapping) diff --git a/tests/test_ops_semantics.py b/tests/test_ops_semantics.py index 79b80663..82a900a4 100644 --- a/tests/test_ops_semantics.py +++ b/tests/test_ops_semantics.py @@ -2,6 +2,7 @@ import functools import itertools import logging +import typing from collections.abc import Callable, Mapping from typing import Annotated, Any, Literal, Union @@ -863,3 +864,86 @@ def get_mixed() -> Literal[1, "a"]: with pytest.raises(TypeError, match="Union types are not supported"): typeof(get_mixed()) + + +# --- Module-level classes for typing.Self tests --- +# Must be at module level so @defop can resolve annotations via get_type_hints. + + +class _SelfA: + @defop + def ret_self(self, x: int) -> typing.Self: + raise NotHandled + + @defop + def ret_list_self(self, x: int) -> list[typing.Self]: + raise NotHandled + + @defop + def annotated_self(self: typing.Self, x: int) -> typing.Self: + raise NotHandled + + @defop + def ret_self_or_none(self, x: int) -> typing.Self | None: + raise NotHandled + + +class _SelfB: + @defop + def ret_self(self, x: int) -> typing.Self: + raise NotHandled + + +def test_typeof_self_basic(): + """typeof resolves typing.Self to the type of the first argument.""" + obj = _SelfA() + assert typeof(_SelfA.ret_self(obj, 42)) is _SelfA + + +def test_typeof_self_list(): + """typeof resolves list[Self] to list (origin type).""" + obj = _SelfA() + assert typeof(_SelfA.ret_list_self(obj, 42)) is list + + +def test_typeof_self_annotated_param(): + """Self as both the self-parameter annotation and return type.""" + obj = _SelfA() + assert typeof(_SelfA.annotated_self(obj, 42)) is _SelfA + + +def test_typeof_self_two_classes(): + """Self resolves independently per class.""" + a, b = _SelfA(), _SelfB() + assert typeof(_SelfA.ret_self(a, 42)) is _SelfA + assert typeof(_SelfB.ret_self(b, 42)) is _SelfB + + +def test_typeof_self_nested_polymorphic(): + """Self composes with a polymorphic identity operation.""" + + @defop + def identity[T](x: T) -> T: + raise NotHandled + + obj = _SelfA() + assert typeof(identity(_SelfA.ret_self(obj, 42))) is _SelfA + + +@pytest.mark.xfail(reason="Union types are not yet supported") +def test_typeof_self_union(): + """Self | None is a union return type — unsupported.""" + obj = _SelfA() + typeof(_SelfA.ret_self_or_none(obj, 42)) + + +class _SelfClassmethod: + @defop + @classmethod + def cls_ret(cls) -> typing.Self: + raise NotHandled + + +def test_typeof_self_classmethod(): + """Classmethod with Self — cls is stripped so Self is unresolved but does not crash.""" + typeof(_SelfClassmethod.cls_ret())