Skip to content

Commit 087e0d5

Browse files
Improves docstrings for overloaded functions
Improves the docstrings for overloaded functions to be stylistically similar to non-overloaded functions. Also updates the helpers that generate strings from type annotations to be more consistent with the style the documentation uses. For example, `Union[int, float]` will now be rendered as `int | float`.
1 parent 2ad2c0a commit 087e0d5

File tree

5 files changed

+143
-84
lines changed

5 files changed

+143
-84
lines changed

tripy/docs/_static/style.css

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,14 @@ section {
2121
margin-top: 2rem;
2222
margin-bottom: 2rem;
2323
}
24+
25+
.func-overload-sig {
26+
padding-left: 3em !important;
27+
color: var(--color-api-overall);
28+
font-style: normal;
29+
}
30+
31+
.func-overload-sig p {
32+
margin-bottom: 0 !important;
33+
margin-top: 0 !important;
34+
}

tripy/tests/test_function_registry.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
import tripy as tp
2727
from tripy import TripyException
28-
from tripy.function_registry import AnnotationInfo, FunctionRegistry, render_arg_type, sanitize_name
28+
from tripy.function_registry import AnnotationInfo, FunctionRegistry, type_str_from_arg, str_from_type_annotation
2929

3030

3131
@pytest.fixture()
@@ -199,10 +199,10 @@ def func(a: int):
199199

200200
func_overload = registry.overloads["test"][0]
201201

202-
assert not func_overload.annotations
202+
assert not func_overload._annotations
203203
assert registry["test"](0) == 1
204-
assert func_overload.annotations
205-
assert func_overload.annotations["a"] == AnnotationInfo(int, False, inspect.Parameter.POSITIONAL_OR_KEYWORD)
204+
assert func_overload._annotations
205+
assert func_overload._annotations["a"] == AnnotationInfo(int, False, inspect.Parameter.POSITIONAL_OR_KEYWORD)
206206

207207
def test_doc_of_non_overloaded_func(self, registry):
208208
# When there is no overload, the registry function should
@@ -224,31 +224,46 @@ def func(a: int):
224224
"""
225225
pass
226226

227+
# Tripy types should turn into class links
227228
@registry("test")
228-
def func(a: float):
229+
def func(a: Union[int, "tripy.Tensor"]):
229230
"""
230-
This func takes a float.
231+
This func takes an int or a tensor.
231232
"""
232233
pass
233234

234235
print(registry["test"].__doc__)
235236
assert (
236237
registry["test"].__doc__
237238
== dedent(
238-
"""
239+
r"""
239240
*This function has multiple overloads:*
240241
241242
----------
242243
243-
> **test** (*a*: :class:`int`) -> None
244+
.. role:: sig-prename
245+
:class: sig-prename descclassname
246+
.. role:: sig-name
247+
:class: sig-name descname
248+
249+
.. container:: func-overload-sig sig sig-object py
250+
251+
:sig-prename:`tripy`\ .\ :sig-name:`test`\ (a: int) -> None
244252
245253
This func takes an int.
246254
247255
----------
248256
249-
> **test** (*a*: :class:`float`) -> None
257+
.. role:: sig-prename
258+
:class: sig-prename descclassname
259+
.. role:: sig-name
260+
:class: sig-name descname
261+
262+
.. container:: func-overload-sig sig sig-object py
263+
264+
:sig-prename:`tripy`\ .\ :sig-name:`test`\ (a: int | :class:`tripy.Tensor`) -> None
250265
251-
This func takes a float.
266+
This func takes an int or a tensor.
252267
"""
253268
).strip()
254269
)
@@ -379,7 +394,7 @@ def func(n: Union[int, float]) -> int:
379394
[0-9]+ \| \.\.\.
380395
\|\s
381396
382-
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Union\[int, float\]' but got argument of type: 'List\[str\]'\.
397+
Not a valid overload because: For parameter: 'n', expected an instance of type: 'int | float' but got argument of type: 'List\[str\]'\.
383398
"""
384399
).strip(),
385400
):
@@ -403,7 +418,7 @@ def func(n: Sequence[int]) -> int:
403418
[0-9]+ \| \.\.\.
404419
\|\s
405420
406-
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[int\]' but got argument of type: 'List\[Union\[(int, str)|(str, int)\]\]'\.
421+
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[int\]' but got argument of type: 'List\[(int \| str)|(str \| int)\]'\.
407422
"""
408423
).strip(),
409424
):
@@ -475,7 +490,7 @@ def func(n: Sequence[Union[int, float]]) -> int:
475490
[0-9]+ \| \.\.\.
476491
\|\s
477492
478-
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[Union\[int, float\]\]' but got argument of type: 'List\[str\]'\.
493+
Not a valid overload because: For parameter: 'n', expected an instance of type: 'Sequence\[int | float\]' but got argument of type: 'List\[str\]'\.
479494
"""
480495
).strip(),
481496
):
@@ -496,16 +511,16 @@ def func(a: int, *args: int) -> int:
496511
@pytest.mark.parametrize(
497512
"typ, expected",
498513
[
499-
(tp.types.TensorLike, "Union[tripy.Tensor, numbers.Number]"),
500-
(tp.types.ShapeLike, "Sequence[Union[int, tripy.DimensionSize]]"),
514+
(tp.types.TensorLike, "tripy.Tensor | numbers.Number"),
515+
(tp.types.ShapeLike, "Sequence[int | tripy.DimensionSize]"),
501516
(tp.Tensor, "Tensor"),
502517
(torch.Tensor, "torch.Tensor"),
503518
(int, "int"),
504-
(Optional[int], "Optional[int]"),
519+
(Optional[int], "int | None"),
505520
],
506521
)
507-
def test_sanitize_name(typ, expected):
508-
assert sanitize_name(typ) == expected
522+
def test_str_from_type_annotation(typ, expected):
523+
assert str_from_type_annotation(typ) == expected
509524

510525

511526
@pytest.mark.parametrize(
@@ -517,5 +532,5 @@ def test_sanitize_name(typ, expected):
517532
("hi", "str"),
518533
],
519534
)
520-
def test_render_arg_type(typ, expected):
521-
assert render_arg_type(typ) == expected
535+
def test_type_str_from_arg(typ, expected):
536+
assert type_str_from_arg(typ) == expected

tripy/tripy/backend/api/executable.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from tripy.backend.mlir import utils as mlir_utils
2424
from tripy.common.exception import raise_error
2525
from tripy.frontend import Tensor
26-
from tripy.function_registry import sanitize_name
26+
from tripy.function_registry import str_from_type_annotation
2727
from tripy.utils import json as json_utils
2828
from dataclasses import dataclass
2929

@@ -73,8 +73,11 @@ def stream(self, stream):
7373
self._executor.stream = stream
7474

7575
def __str__(self) -> str:
76-
params = [f"{name}: {sanitize_name(param.annotation)}" for name, param in self.__signature__.parameters.items()]
77-
return f"Executable({', '.join(params)}) -> {sanitize_name(self.__signature__.return_annotation)}"
76+
params = [
77+
f"{name}: {str_from_type_annotation(param.annotation)}"
78+
for name, param in self.__signature__.parameters.items()
79+
]
80+
return f"Executable({', '.join(params)}) -> {str_from_type_annotation(self.__signature__.return_annotation)}"
7881

7982
@staticmethod
8083
def load(path: str) -> "tripy.Executable":

tripy/tripy/frontend/ops/tensor_initializers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def arange(
292292
) -> "tripy.Tensor":
293293
r"""
294294
Returns a 1D tensor containing a sequence of numbers in the half-open interval
295-
:math:`[0, \text{stop})` incrementing by :math:`\text{step}`.
295+
:math:`[\text{start}, \text{stop})` incrementing by :math:`\text{step}`.
296296
297297
Args:
298298
start: The inclusive lower bound of the values to generate. If a tensor is provided, it must be a scalar tensor.

0 commit comments

Comments
 (0)