Skip to content

Commit 4b471cd

Browse files
Improves error messages for incorrect types
Cleans up the display of type annotations and includes module names for arguments provided from outside of Tripy or Python builtins. Before (line breaks added for clarity): ``` Not a valid overload because: For parameter: 'other', expected an instance of type: 'typing.Union[ForwardRef('tripy.Tensor'), ForwardRef('tripy.types.NestedNumberSequence')]' but got argument of type: 'Tensor'. ``` After: ``` Not a valid overload because: For parameter: 'other', expected an instance of type: 'Union[tripy.Tensor, tripy.types.NestedNumberSequence]' but got argument of type: 'torch.Tensor'. ```
1 parent 8d67c62 commit 4b471cd

File tree

2 files changed

+103
-35
lines changed

2 files changed

+103
-35
lines changed

tripy/tests/test_function_registry.py

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717

1818
import inspect
1919
from textwrap import dedent
20-
from typing import Any, Dict, List, Sequence, Union
20+
from typing import Any, Dict, List, Sequence, Union, Optional
2121

2222
import pytest
23-
23+
import torch
2424
from tests import helper
2525

26+
import tripy as tp
2627
from tripy import TripyException
27-
from tripy.function_registry import AnnotationInfo, FunctionRegistry
28+
from tripy.function_registry import AnnotationInfo, FunctionRegistry, render_arg_type, sanitize_name
2829

2930

3031
@pytest.fixture()
@@ -326,6 +327,14 @@ def func(n: "tripy.types.NestedNumberSequence"):
326327
assert registry["test"]([1, 2, 3]) == [1, 2, 3]
327328
assert registry["test"]([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]) == [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
328329

330+
def test_optional_can_be_none(self, registry):
331+
@registry("test")
332+
def func(n: Optional[int]):
333+
return n
334+
335+
assert registry["test"](None) == None
336+
assert registry["test"](1) == 1
337+
329338
def test_error_sequence(self, registry):
330339
@registry("test")
331340
def func(n: Sequence[int]) -> int:
@@ -344,7 +353,7 @@ def func(n: Sequence[int]) -> int:
344353
[0-9]+ \| \.\.\.
345354
\|\s
346355
347-
Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[int\]' but got argument of type: 'List\[float\]'\.
356+
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[int\]' but got argument of type: 'List\[float\]'\.
348357
"""
349358
).strip(),
350359
):
@@ -368,7 +377,7 @@ def func(n: Union[int, float]) -> int:
368377
[0-9]+ \| \.\.\.
369378
\|\s
370379
371-
Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Union\[int, float\]' but got argument of type: 'List\[str\]'\.
380+
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Union\[int, float\]' but got argument of type: 'List\[str\]'\.
372381
"""
373382
).strip(),
374383
):
@@ -392,7 +401,7 @@ def func(n: Sequence[int]) -> int:
392401
[0-9]+ \| \.\.\.
393402
\|\s
394403
395-
Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[int\]' but got argument of type: 'List\[Union\[(int, str)|(str, int)\]\]'\.
404+
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[int\]' but got argument of type: 'List\[Union\[(int, str)|(str, int)\]\]'\.
396405
"""
397406
).strip(),
398407
):
@@ -416,7 +425,7 @@ def func(n: Sequence[Sequence[int]]) -> int:
416425
[0-9]+ \| \.\.\.
417426
\|\s
418427
419-
Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[typing\.Sequence\[int\]\]' but got argument of type: 'List\[int\]'\.
428+
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[Sequence\[int\]\]' but got argument of type: 'List\[int\]'\.
420429
"""
421430
).strip(),
422431
):
@@ -440,7 +449,7 @@ def func(n: Sequence[Sequence[int]]) -> int:
440449
[0-9]+ \| \.\.\.
441450
\|\s
442451
443-
Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[typing\.Sequence\[int\]\]' but got argument of type: 'List\[List\[float\]\]'\.
452+
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[Sequence\[int\]\]' but got argument of type: 'List\[List\[float\]\]'\.
444453
"""
445454
).strip(),
446455
):
@@ -464,7 +473,7 @@ def func(n: Sequence[Union[int, float]]) -> int:
464473
[0-9]+ \| \.\.\.
465474
\|\s
466475
467-
Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[typing\.Union\[int, float\]\]' but got argument of type: 'List\[str\]'\.
476+
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[Union\[int, float\]\]' but got argument of type: 'List\[str\]'\.
468477
"""
469478
).strip(),
470479
):
@@ -488,7 +497,7 @@ def func(n: "tripy.types.NestedNumberSequence"):
488497
[0-9]+ \| \.\.\.
489498
\|\s
490499
491-
Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Union\[numbers\.Number, typing\.Sequence\[ForwardRef\('tripy\.types\.NestedNumberSequence'\)\]\]' but got argument of type: 'str'\.
500+
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Union\[numbers\.Number, Sequence\[tripy\.types\.NestedNumberSequence\]\]' but got argument of type: 'str'\.
492501
"""
493502
).strip(),
494503
):
@@ -512,8 +521,36 @@ def func(n: "tripy.types.NestedNumberSequence"):
512521
[0-9]+ \| \.\.\.
513522
\|\s
514523
515-
Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Union\[numbers\.Number, typing\.Sequence\[ForwardRef\('tripy\.types\.NestedNumberSequence'\)\]\]' but got argument of type: 'List\[List\[str\]\]'
524+
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Union\[numbers\.Number, Sequence\[tripy\.types\.NestedNumberSequence\]\]' but got argument of type: 'List\[List\[str\]\]'
516525
"""
517526
).strip(),
518527
):
519528
registry["test"]([["a"], ["b"], ["c"]])
529+
530+
531+
@pytest.mark.parametrize(
532+
"typ, expected",
533+
[
534+
(tp.types.TensorLike, "Union[tripy.Tensor, tripy.types.NestedNumberSequence]"),
535+
(tp.types.ShapeLike, "Union[tripy.Shape, Sequence[Union[int, tripy.ShapeScalar]]]"),
536+
(tp.Tensor, "Tensor"),
537+
(torch.Tensor, "torch.Tensor"),
538+
(int, "int"),
539+
(Optional[int], "Optional[int]"),
540+
],
541+
)
542+
def test_sanitize_name(typ, expected):
543+
assert sanitize_name(typ) == expected
544+
545+
546+
@pytest.mark.parametrize(
547+
"typ, expected",
548+
[
549+
(tp.Tensor([1, 2, 3]), "Tensor"),
550+
(torch.tensor([1, 2, 3]), "torch.Tensor"),
551+
(0, "int"),
552+
("hi", "str"),
553+
],
554+
)
555+
def test_render_arg_type(typ, expected):
556+
assert render_arg_type(typ) == expected

tripy/tripy/function_registry.py

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from typing import Any, Callable, Dict, List, Optional
2323

2424
from dataclasses import dataclass
25+
from collections.abc import Sequence as ABCSequence
26+
from typing import ForwardRef, get_args, get_origin, Sequence, Union, Optional
2527

2628

2729
@dataclass
@@ -31,6 +33,57 @@ class AnnotationInfo:
3133
kind: Any # Uses inspect.Parameter.<kind>
3234

3335

36+
def get_type_name(typ):
37+
# Attach module name if possible
38+
module_name = ""
39+
try:
40+
module_name = typ.__module__ + "."
41+
except AttributeError:
42+
pass
43+
else:
44+
# Don't attach prefix for built-in types or Tripy types.
45+
# If we include modules for Tripy, they will include all submodules, which can be confusing
46+
# e.g. Tensor will be something like "tripy.frontend.tensor.Tensor"
47+
if any(module_name.startswith(skip_module) for skip_module in {"builtins", "tripy"}):
48+
module_name = ""
49+
50+
return module_name + typ.__qualname__
51+
52+
53+
def sanitize_name(annotation):
54+
if get_origin(annotation) is Union and annotation._name == "Optional":
55+
types = get_args(annotation)
56+
return f"{annotation.__name__}[{sanitize_name(types[0])}]"
57+
58+
if get_origin(annotation) in {Union, ABCSequence}:
59+
types = get_args(annotation)
60+
return f"{annotation.__name__}[{', '.join(sanitize_name(typ) for typ in types)}]"
61+
62+
if isinstance(annotation, ForwardRef):
63+
return annotation.__forward_arg__
64+
65+
# typing module annotations are likely to be better when pretty-printed due to including subscripts
66+
return annotation if annotation.__module__ == "typing" else get_type_name(annotation)
67+
68+
69+
def render_arg_type(arg: Any) -> str:
70+
# it is more useful to report more detailed types for sequences/tuples in error messages
71+
from typing import List, Tuple
72+
73+
if isinstance(arg, List):
74+
if len(arg) == 0:
75+
return "List"
76+
# catch inconsistencies this way
77+
arg_types = {render_arg_type(member) for member in arg}
78+
if len(arg_types) == 1:
79+
return f"List[{list(arg_types)[0]}]"
80+
return f"List[Union[{', '.join(arg_types)}]]"
81+
if isinstance(arg, Tuple):
82+
return f"Tuple[{', '.join(map(render_arg_type, arg))}]"
83+
84+
return get_type_name(type(arg))
85+
86+
3487
class FuncOverload:
3588
def __init__(self, func):
3689
self.func = func
@@ -98,29 +151,7 @@ def _get_annotations(self):
98151
def matches_arg_types(self, args, kwargs) -> "Result":
99152
from tripy.utils.result import Result
100153

101-
def sanitize_name(annotation):
102-
# typing module annotations are likely to be better when pretty-printed due to including subscripts
103-
return annotation if annotation.__module__ == "typing" else annotation.__qualname__
104-
105-
def render_arg_type(arg: Any) -> str:
106-
# it is more useful to report more detailed types for sequences/tuples in error messages
107-
from typing import List, Tuple
108-
109-
if isinstance(arg, List):
110-
if len(arg) == 0:
111-
return "List"
112-
# catch inconsistencies this way
113-
arg_types = {render_arg_type(member) for member in arg}
114-
if len(arg_types) == 1:
115-
return f"List[{list(arg_types)[0]}]"
116-
return f"List[Union[{', '.join(arg_types)}]]"
117-
if isinstance(arg, Tuple):
118-
return f"Tuple[{', '.join(map(render_arg_type, arg))}]"
119-
return type(arg).__qualname__
120-
121154
def matches_type(name: str, annotation: type, arg: Any) -> bool:
122-
from collections.abc import Sequence as ABCSequence
123-
from typing import ForwardRef, get_args, get_origin, Sequence, Union, Optional
124155

125156
# In cases where a type is not available at the time of function definition, the type
126157
# annotation may be provided as a string. Since we need the actual type, we just
@@ -149,12 +180,12 @@ def matches_type(name: str, annotation: type, arg: Any) -> bool:
149180
return all(map(lambda member: matches_type(name, seq_arg[0], member), arg))
150181
return True
151182

152-
if get_origin(annotation) is Optional:
183+
if get_origin(annotation) is Union and annotation._name == "Optional":
153184
return arg is None or matches_type(arg, get_args(annotation)[0])
154185

155186
# Forward references can be used for recursive type definitions. Warning: Has the potential for infinite looping if there is no base case!
156187
if isinstance(annotation, ForwardRef):
157-
# need this import in case the annotation references tripy
188+
# NOTE: We need this import in case the annotation references tripy
158189
import tripy
159190

160191
return matches_type(name, eval(annotation.__forward_arg__), arg)

0 commit comments

Comments
 (0)