Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 48 additions & 11 deletions tripy/tests/test_function_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@

import inspect
from textwrap import dedent
from typing import Any, Dict, List, Sequence, Union
from typing import Any, Dict, List, Sequence, Union, Optional

import pytest

import torch
from tests import helper

import tripy as tp
from tripy import TripyException
from tripy.function_registry import AnnotationInfo, FunctionRegistry
from tripy.function_registry import AnnotationInfo, FunctionRegistry, render_arg_type, sanitize_name


@pytest.fixture()
Expand Down Expand Up @@ -326,6 +327,14 @@ def func(n: "tripy.types.NestedNumberSequence"):
assert registry["test"]([1, 2, 3]) == [1, 2, 3]
assert registry["test"]([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]) == [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]

def test_optional_can_be_none(self, registry):
@registry("test")
def func(n: Optional[int]):
return n

assert registry["test"](None) == None
assert registry["test"](1) == 1

def test_error_sequence(self, registry):
@registry("test")
def func(n: Sequence[int]) -> int:
Expand All @@ -344,7 +353,7 @@ def func(n: Sequence[int]) -> int:
[0-9]+ \| \.\.\.
\|\s

Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[int\]' but got argument of type: 'List\[float\]'\.
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[int\]' but got argument of type: 'List\[float\]'\.
"""
).strip(),
):
Expand All @@ -368,7 +377,7 @@ def func(n: Union[int, float]) -> int:
[0-9]+ \| \.\.\.
\|\s

Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Union\[int, float\]' but got argument of type: 'List\[str\]'\.
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Union\[int, float\]' but got argument of type: 'List\[str\]'\.
"""
).strip(),
):
Expand All @@ -392,7 +401,7 @@ def func(n: Sequence[int]) -> int:
[0-9]+ \| \.\.\.
\|\s

Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[int\]' but got argument of type: 'List\[Union\[(int, str)|(str, int)\]\]'\.
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[int\]' but got argument of type: 'List\[Union\[(int, str)|(str, int)\]\]'\.
"""
).strip(),
):
Expand All @@ -416,7 +425,7 @@ def func(n: Sequence[Sequence[int]]) -> int:
[0-9]+ \| \.\.\.
\|\s

Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[typing\.Sequence\[int\]\]' but got argument of type: 'List\[int\]'\.
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[Sequence\[int\]\]' but got argument of type: 'List\[int\]'\.
"""
).strip(),
):
Expand All @@ -440,7 +449,7 @@ def func(n: Sequence[Sequence[int]]) -> int:
[0-9]+ \| \.\.\.
\|\s

Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[typing\.Sequence\[int\]\]' but got argument of type: 'List\[List\[float\]\]'\.
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[Sequence\[int\]\]' but got argument of type: 'List\[List\[float\]\]'\.
"""
).strip(),
):
Expand All @@ -464,7 +473,7 @@ def func(n: Sequence[Union[int, float]]) -> int:
[0-9]+ \| \.\.\.
\|\s

Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[typing\.Union\[int, float\]\]' but got argument of type: 'List\[str\]'\.
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[Union\[int, float\]\]' but got argument of type: 'List\[str\]'\.
"""
).strip(),
):
Expand All @@ -488,7 +497,7 @@ def func(n: "tripy.types.NestedNumberSequence"):
[0-9]+ \| \.\.\.
\|\s

Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Union\[numbers\.Number, typing\.Sequence\[ForwardRef\('tripy\.types\.NestedNumberSequence'\)\]\]' but got argument of type: 'str'\.
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Union\[numbers\.Number, Sequence\[tripy\.types\.NestedNumberSequence\]\]' but got argument of type: 'str'\.
"""
).strip(),
):
Expand All @@ -512,8 +521,36 @@ def func(n: "tripy.types.NestedNumberSequence"):
[0-9]+ \| \.\.\.
\|\s

Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Union\[numbers\.Number, typing\.Sequence\[ForwardRef\('tripy\.types\.NestedNumberSequence'\)\]\]' but got argument of type: 'List\[List\[str\]\]'
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Union\[numbers\.Number, Sequence\[tripy\.types\.NestedNumberSequence\]\]' but got argument of type: 'List\[List\[str\]\]'
"""
).strip(),
):
registry["test"]([["a"], ["b"], ["c"]])


@pytest.mark.parametrize(
"typ, expected",
[
(tp.types.TensorLike, "Union[tripy.Tensor, tripy.types.NestedNumberSequence]"),
(tp.types.ShapeLike, "Union[tripy.Shape, Sequence[Union[int, tripy.ShapeScalar]]]"),
(tp.Tensor, "Tensor"),
(torch.Tensor, "torch.Tensor"),
(int, "int"),
(Optional[int], "Optional[int]"),
],
)
def test_sanitize_name(typ, expected):
assert sanitize_name(typ) == expected


@pytest.mark.parametrize(
"typ, expected",
[
(tp.Tensor([1, 2, 3]), "Tensor"),
(torch.tensor([1, 2, 3]), "torch.Tensor"),
(0, "int"),
("hi", "str"),
],
)
def test_render_arg_type(typ, expected):
assert render_arg_type(typ) == expected
79 changes: 55 additions & 24 deletions tripy/tripy/function_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from typing import Any, Callable, Dict, List, Optional

from dataclasses import dataclass
from collections.abc import Sequence as ABCSequence
from typing import ForwardRef, get_args, get_origin, Sequence, Union, Optional


@dataclass
Expand All @@ -31,6 +33,57 @@ class AnnotationInfo:
kind: Any # Uses inspect.Parameter.<kind>


def get_type_name(typ):
# Attach module name if possible
module_name = ""
try:
module_name = typ.__module__ + "."
except AttributeError:
pass
else:
# Don't attach prefix for built-in types or Tripy types.
# If we include modules for Tripy, they will include all submodules, which can be confusing
# e.g. Tensor will be something like "tripy.frontend.tensor.Tensor"
if any(module_name.startswith(skip_module) for skip_module in {"builtins", "tripy"}):
module_name = ""

return module_name + typ.__qualname__


def sanitize_name(annotation):
if get_origin(annotation) is Union and annotation._name == "Optional":
types = get_args(annotation)
return f"{annotation.__name__}[{sanitize_name(types[0])}]"

if get_origin(annotation) in {Union, ABCSequence}:
types = get_args(annotation)
return f"{annotation.__name__}[{', '.join(sanitize_name(typ) for typ in types)}]"

if isinstance(annotation, ForwardRef):
return annotation.__forward_arg__

# typing module annotations are likely to be better when pretty-printed due to including subscripts
return annotation if annotation.__module__ == "typing" else get_type_name(annotation)


def render_arg_type(arg: Any) -> str:
# it is more useful to report more detailed types for sequences/tuples in error messages
from typing import List, Tuple

if isinstance(arg, List):
if len(arg) == 0:
return "List"
# catch inconsistencies this way
arg_types = {render_arg_type(member) for member in arg}
if len(arg_types) == 1:
return f"List[{list(arg_types)[0]}]"
return f"List[Union[{', '.join(arg_types)}]]"
if isinstance(arg, Tuple):
return f"Tuple[{', '.join(map(render_arg_type, arg))}]"

return get_type_name(type(arg))


class FuncOverload:
def __init__(self, func):
self.func = func
Expand Down Expand Up @@ -98,29 +151,7 @@ def _get_annotations(self):
def matches_arg_types(self, args, kwargs) -> "Result":
from tripy.utils.result import Result

def sanitize_name(annotation):
# typing module annotations are likely to be better when pretty-printed due to including subscripts
return annotation if annotation.__module__ == "typing" else annotation.__qualname__

def render_arg_type(arg: Any) -> str:
# it is more useful to report more detailed types for sequences/tuples in error messages
from typing import List, Tuple

if isinstance(arg, List):
if len(arg) == 0:
return "List"
# catch inconsistencies this way
arg_types = {render_arg_type(member) for member in arg}
if len(arg_types) == 1:
return f"List[{list(arg_types)[0]}]"
return f"List[Union[{', '.join(arg_types)}]]"
if isinstance(arg, Tuple):
return f"Tuple[{', '.join(map(render_arg_type, arg))}]"
return type(arg).__qualname__

def matches_type(name: str, annotation: type, arg: Any) -> bool:
from collections.abc import Sequence as ABCSequence
from typing import ForwardRef, get_args, get_origin, Sequence, Union, Optional

# In cases where a type is not available at the time of function definition, the type
# annotation may be provided as a string. Since we need the actual type, we just
Expand Down Expand Up @@ -149,12 +180,12 @@ def matches_type(name: str, annotation: type, arg: Any) -> bool:
return all(map(lambda member: matches_type(name, seq_arg[0], member), arg))
return True

if get_origin(annotation) is Optional:
if get_origin(annotation) is Union and annotation._name == "Optional":
return arg is None or matches_type(arg, get_args(annotation)[0])

# Forward references can be used for recursive type definitions. Warning: Has the potential for infinite looping if there is no base case!
if isinstance(annotation, ForwardRef):
# need this import in case the annotation references tripy
# NOTE: We need this import in case the annotation references tripy
import tripy

return matches_type(name, eval(annotation.__forward_arg__), arg)
Expand Down