Skip to content

Commit 360375f

Browse files
Improves error messages for incorrect types
Cleans up the display of type annotations and includes module names for arguments provided from outside of Tripy or Python builtins. Before (line breaks added for clarity): ``` Not a valid overload because: For parameter: 'other', expected an instance of type: 'typing.Union[ForwardRef('tripy.Tensor'), ForwardRef('tripy.types.NestedNumberSequence')]' but got argument of type: 'Tensor'. ``` After: ``` Not a valid overload because: For parameter: 'other', expected an instance of type: 'Union[tripy.Tensor, tripy.types.NestedNumberSequence]' but got argument of type: 'torch.Tensor'. ```
1 parent 8d67c62 commit 360375f

File tree

2 files changed

+88
-33
lines changed

2 files changed

+88
-33
lines changed

tripy/tests/test_function_registry.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@
2020
from typing import Any, Dict, List, Sequence, Union
2121

2222
import pytest
23-
23+
import torch
2424
from tests import helper
2525

26+
import tripy as tp
2627
from tripy import TripyException
27-
from tripy.function_registry import AnnotationInfo, FunctionRegistry
28+
from tripy.function_registry import AnnotationInfo, FunctionRegistry, render_arg_type, sanitize_name
2829

2930

3031
@pytest.fixture()
@@ -344,7 +345,7 @@ def func(n: Sequence[int]) -> int:
344345
[0-9]+ \| \.\.\.
345346
\|\s
346347
347-
Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[int\]' but got argument of type: 'List\[float\]'\.
348+
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[int\]' but got argument of type: 'List\[float\]'\.
348349
"""
349350
).strip(),
350351
):
@@ -368,7 +369,7 @@ def func(n: Union[int, float]) -> int:
368369
[0-9]+ \| \.\.\.
369370
\|\s
370371
371-
Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Union\[int, float\]' but got argument of type: 'List\[str\]'\.
372+
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Union\[int, float\]' but got argument of type: 'List\[str\]'\.
372373
"""
373374
).strip(),
374375
):
@@ -392,7 +393,7 @@ def func(n: Sequence[int]) -> int:
392393
[0-9]+ \| \.\.\.
393394
\|\s
394395
395-
Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[int\]' but got argument of type: 'List\[Union\[(int, str)|(str, int)\]\]'\.
396+
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[int\]' but got argument of type: 'List\[Union\[(int, str)|(str, int)\]\]'\.
396397
"""
397398
).strip(),
398399
):
@@ -416,7 +417,7 @@ def func(n: Sequence[Sequence[int]]) -> int:
416417
[0-9]+ \| \.\.\.
417418
\|\s
418419
419-
Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[typing\.Sequence\[int\]\]' but got argument of type: 'List\[int\]'\.
420+
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[Sequence\[int\]\]' but got argument of type: 'List\[int\]'\.
420421
"""
421422
).strip(),
422423
):
@@ -440,7 +441,7 @@ def func(n: Sequence[Sequence[int]]) -> int:
440441
[0-9]+ \| \.\.\.
441442
\|\s
442443
443-
Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[typing\.Sequence\[int\]\]' but got argument of type: 'List\[List\[float\]\]'\.
444+
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[Sequence\[int\]\]' but got argument of type: 'List\[List\[float\]\]'\.
444445
"""
445446
).strip(),
446447
):
@@ -464,7 +465,7 @@ def func(n: Sequence[Union[int, float]]) -> int:
464465
[0-9]+ \| \.\.\.
465466
\|\s
466467
467-
Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Sequence\[typing\.Union\[int, float\]\]' but got argument of type: 'List\[str\]'\.
468+
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[Union\[int, float\]\]' but got argument of type: 'List\[str\]'\.
468469
"""
469470
).strip(),
470471
):
@@ -488,7 +489,7 @@ def func(n: "tripy.types.NestedNumberSequence"):
488489
[0-9]+ \| \.\.\.
489490
\|\s
490491
491-
Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Union\[numbers\.Number, typing\.Sequence\[ForwardRef\('tripy\.types\.NestedNumberSequence'\)\]\]' but got argument of type: 'str'\.
492+
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Union\[numbers\.Number, Sequence\[tripy\.types\.NestedNumberSequence\]\]' but got argument of type: 'str'\.
492493
"""
493494
).strip(),
494495
):
@@ -512,8 +513,35 @@ def func(n: "tripy.types.NestedNumberSequence"):
512513
[0-9]+ \| \.\.\.
513514
\|\s
514515
515-
Not a valid overload because: For parameter: 'n', expected an instance of type: 'typing\.Union\[numbers\.Number, typing\.Sequence\[ForwardRef\('tripy\.types\.NestedNumberSequence'\)\]\]' but got argument of type: 'List\[List\[str\]\]'
516+
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Union\[numbers\.Number, Sequence\[tripy\.types\.NestedNumberSequence\]\]' but got argument of type: 'List\[List\[str\]\]'
516517
"""
517518
).strip(),
518519
):
519520
registry["test"]([["a"], ["b"], ["c"]])
521+
522+
523+
@pytest.mark.parametrize(
524+
"typ, expected",
525+
[
526+
(tp.types.TensorLike, "Union[tripy.Tensor, tripy.types.NestedNumberSequence]"),
527+
(tp.types.ShapeLike, "Union[tripy.Shape, Sequence[Union[int, tripy.ShapeScalar]]]"),
528+
(tp.Tensor, "Tensor"),
529+
(torch.Tensor, "torch.Tensor"),
530+
(int, "int"),
531+
],
532+
)
533+
def test_sanitize_name(typ, expected):
534+
assert sanitize_name(typ) == expected
535+
536+
537+
@pytest.mark.parametrize(
538+
"typ, expected",
539+
[
540+
(tp.Tensor([1, 2, 3]), "Tensor"),
541+
(torch.tensor([1, 2, 3]), "torch.Tensor"),
542+
(0, "int"),
543+
("hi", "str"),
544+
],
545+
)
546+
def test_render_arg_type(typ, expected):
547+
assert render_arg_type(typ) == expected

tripy/tripy/function_registry.py

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from typing import Any, Callable, Dict, List, Optional
2323

2424
from dataclasses import dataclass
25+
from collections.abc import Sequence as ABCSequence
26+
from typing import ForwardRef, get_args, get_origin, Sequence, Union, Optional
2527

2628

2729
@dataclass
@@ -31,6 +33,53 @@ class AnnotationInfo:
3133
kind: Any # Uses inspect.Parameter.<kind>
3234

3335

36+
def get_type_name(typ):
37+
# Attach module name if possible
38+
module_name = ""
39+
try:
40+
module_name = typ.__module__ + "."
41+
except AttributeError:
42+
pass
43+
else:
44+
# Don't attach prefix for built-in types or Tripy types.
45+
# If we include modules for Tripy, they will include all submodules, which can be confusing
46+
# e.g. Tensor will be something like "tripy.frontend.tensor.Tensor"
47+
if any(module_name.startswith(skip_module) for skip_module in {"builtins", "tripy"}):
48+
module_name = ""
49+
50+
return module_name + typ.__qualname__
51+
52+
53+
def sanitize_name(annotation):
54+
if get_origin(annotation) in {Union, ABCSequence, Optional}:
55+
types = get_args(annotation)
56+
return f"{annotation.__name__}[{', '.join(sanitize_name(typ) for typ in types)}]"
57+
58+
if isinstance(annotation, ForwardRef):
59+
return annotation.__forward_arg__
60+
61+
# typing module annotations are likely to be better when pretty-printed due to including subscripts
62+
return annotation if annotation.__module__ == "typing" else get_type_name(annotation)
63+
64+
65+
def render_arg_type(arg: Any) -> str:
66+
# it is more useful to report more detailed types for sequences/tuples in error messages
67+
from typing import List, Tuple
68+
69+
if isinstance(arg, List):
70+
if len(arg) == 0:
71+
return "List"
72+
# catch inconsistencies this way
73+
arg_types = {render_arg_type(member) for member in arg}
74+
if len(arg_types) == 1:
75+
return f"List[{list(arg_types)[0]}]"
76+
return f"List[Union[{', '.join(arg_types)}]]"
77+
if isinstance(arg, Tuple):
78+
return f"Tuple[{', '.join(map(render_arg_type, arg))}]"
79+
80+
return get_type_name(type(arg))
81+
82+
3483
class FuncOverload:
3584
def __init__(self, func):
3685
self.func = func
@@ -98,29 +147,7 @@ def _get_annotations(self):
98147
def matches_arg_types(self, args, kwargs) -> "Result":
99148
from tripy.utils.result import Result
100149

101-
def sanitize_name(annotation):
102-
# typing module annotations are likely to be better when pretty-printed due to including subscripts
103-
return annotation if annotation.__module__ == "typing" else annotation.__qualname__
104-
105-
def render_arg_type(arg: Any) -> str:
106-
# it is more useful to report more detailed types for sequences/tuples in error messages
107-
from typing import List, Tuple
108-
109-
if isinstance(arg, List):
110-
if len(arg) == 0:
111-
return "List"
112-
# catch inconsistencies this way
113-
arg_types = {render_arg_type(member) for member in arg}
114-
if len(arg_types) == 1:
115-
return f"List[{list(arg_types)[0]}]"
116-
return f"List[Union[{', '.join(arg_types)}]]"
117-
if isinstance(arg, Tuple):
118-
return f"Tuple[{', '.join(map(render_arg_type, arg))}]"
119-
return type(arg).__qualname__
120-
121150
def matches_type(name: str, annotation: type, arg: Any) -> bool:
122-
from collections.abc import Sequence as ABCSequence
123-
from typing import ForwardRef, get_args, get_origin, Sequence, Union, Optional
124151

125152
# In cases where a type is not available at the time of function definition, the type
126153
# annotation may be provided as a string. Since we need the actual type, we just
@@ -154,7 +181,7 @@ def matches_type(name: str, annotation: type, arg: Any) -> bool:
154181

155182
# Forward references can be used for recursive type definitions. Warning: Has the potential for infinite looping if there is no base case!
156183
if isinstance(annotation, ForwardRef):
157-
# need this import in case the annotation references tripy
184+
# NOTE: We need this import in case the annotation references tripy
158185
import tripy
159186

160187
return matches_type(name, eval(annotation.__forward_arg__), arg)

0 commit comments

Comments
 (0)