Skip to content

Commit 8d67c62

Browse files
Various bug fixes to fix L1 tests
- Adds tests to ensure we only use ASCII characters in our code and markdown files. Some non-ASCII characters that look like ASCII characters were breaking tests that make assumptions about the encoding. - Updates dtype constraint tests to properly construct inputs for various APIs. - Moves dtype constraint tests to L0 since they only add about 20 seconds and were breaking very often otherwise. - Updates binary elementwise ops to only return `ShapeScalar`s if *all* inputs are also `ShapeScalar`s. - Corrects dtype constraints for several APIs. - Simplifies `mean` implementation. - Corrects type annotations for some APIs and introduces a `ShapeLike` type. - Adds support for `Optional` during type checking. - Updates `convert_inputs_to_tensors` to now potentially convert tensors to `Tensor` subclasses. See the note in the code for details on the logic.
1 parent 2503d17 commit 8d67c62

File tree

28 files changed

+234
-136
lines changed

28 files changed

+234
-136
lines changed

tripy/docs/post0_developer_guides/design-decisions.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ to MLIR if needed or return the newly created ones.
100100

101101
## Why Not Build On [JAX](https://github.com/google/jax)?
102102

103-
Tripys architecture looks very similar to JAX's, where python code is staged out and
103+
Tripy's architecture looks very similar to JAX's, where python code is staged out and
104104
lowered into a custom IR and eventually to MLIR.
105105

106106
Then why not build on top of JAX? There are a couple reasons:

tripy/docs/pre0_user_guides/02-compiler.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ fast_geglu(inp).eval()
5858
### Optimization Profiles
5959

6060
In the example above, we assumed `inp` has a static shape of `(1, 2)`.
61-
Now, lets assume that the shape of `inp` can vary from `(1, 2)` to `(16, 2)`, with `(8, 2)`
61+
Now, let's assume that the shape of `inp` can vary from `(1, 2)` to `(16, 2)`, with `(8, 2)`
6262
being the shape we'd like to optimize for. To express this constraint to the compiler,
6363
we can provide the range of shapes to `InputInfo` using `shape=((1, 8, 16), 2)`.
6464
This indicates to the compiler that the first dimension can vary from 1 to 16,

tripy/tests/flat_ir/ops/test_gather.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from tripy.flat_ir.ops import DynamicGatherOp
2222
from tripy.frontend.trace import Trace
23+
import re
2324

2425

2526
class TestGatherOp:
@@ -38,9 +39,9 @@ def test_gather_str(self, axis):
3839
reshape = flat_ir.ops[-2]
3940
print(str(reshape))
4041
assert isinstance(gather, DynamicGatherOp)
41-
assert (
42-
str(gather)
43-
== f"out: [rank=(3), dtype=(float32), loc=(gpu:0)] = DynamicGatherOp(data, indices, t_inter4, axis={axis})"
42+
assert re.match(
43+
rf"out: \[rank=\(3\), dtype=\(float32\), loc=\(gpu:0\)\] = DynamicGatherOp\(data, indices, t_inter[0-9]+, axis={axis}\)",
44+
str(gather),
4445
)
4546

4647
@pytest.mark.parametrize("axis", [0, 1])

tripy/tests/helper.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,24 @@
4343

4444
ROOT_DIR = os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.pardir))
4545

46-
MARKDOWN_FILES = [
47-
path
48-
for path in glob.glob(os.path.join(ROOT_DIR, "**", "*.md"), recursive=True)
49-
if not path.startswith(
50-
(
51-
os.path.join(ROOT_DIR, "build"),
52-
os.path.join(ROOT_DIR, "mlir-tensorrt"),
53-
os.path.join(ROOT_DIR, "stablehlo"),
46+
47+
def get_files_with_extension(ext):
48+
return [
49+
path
50+
for path in glob.glob(os.path.join(ROOT_DIR, "**", f"*{ext}"), recursive=True)
51+
if not path.startswith(
52+
(
53+
os.path.join(ROOT_DIR, "build"),
54+
os.path.join(ROOT_DIR, "mlir-tensorrt"),
55+
os.path.join(ROOT_DIR, "stablehlo"),
56+
)
5457
)
55-
)
56-
]
58+
]
59+
60+
61+
MARKDOWN_FILES = get_files_with_extension(".md")
62+
63+
PYTHON_FILES = get_files_with_extension(".py")
5764

5865

5966
@contextlib.contextmanager

tripy/tests/spec_verification/object_builders.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def default_builder(init, dtype, namespace):
6060

6161
find_func = {
6262
"tripy.Tensor": tensor_builder,
63+
"tripy.types.TensorLike": tensor_builder,
6364
"tripy.Shape": tensor_builder,
6465
"tripy.dtype": dtype_builder,
6566
datatype.dtype: dtype_builder,
@@ -83,6 +84,12 @@ def default_builder(init, dtype, namespace):
8384
default_constraints_all = {
8485
"__getitem__": {"index": 2},
8586
"__matmul__": {"self": tp.ones((2, 3))},
87+
# Force broadcasting for binary ops so the entire broadcasting code path is triggered.
88+
"__add__": {"other": 1},
89+
"__mul__": {"other": 1},
90+
"__pow__": {"other": 1},
91+
"__sub__": {"other": 1},
92+
"__truediv__": {"other": 1},
8693
"__radd__": {"self": 1},
8794
"__rmul__": {"self": 1},
8895
"__rpow__": {"self": 1},
@@ -105,29 +112,31 @@ def default_builder(init, dtype, namespace):
105112
},
106113
"cumsum": {"dim": 0},
107114
"dequantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0},
108-
"expand": {"sizes": tp.Tensor([3, 4]), "input": tp.ones((3, 1))},
115+
"expand": {"sizes": [3, 4], "input": tp.ones((3, 1))},
109116
"flip": {"dim": 1},
110117
"full_like": {"value": 1},
111-
"full": {"shape": tp.Tensor([3]), "value": 1},
118+
"full": {"shape": [3], "value": 1},
112119
"gather": {"dim": 0, "index": tp.Tensor([1])},
113-
"iota": {"shape": tp.Tensor([4])},
120+
"iota": {"shape": [4]},
114121
"masked_fill": {"value": 1},
122+
"maxpool": {"input": tp.ones((1, 3, 5, 5)), "kernel_dims": (3, 3)},
115123
"max": {"dim": 0},
116124
"mean": {"dim": 0},
117-
"ones": {"shape": tp.Tensor([3, 2])},
125+
"ones": {"shape": [3, 2]},
126+
"outer": {"vec1": tp.Tensor([2, 3, 4, 5]), "vec2": tp.Tensor([1, 2, 3, 4])},
118127
"permute": {"perm": [1, 0]},
119128
"prod": {"dim": 0},
120129
"quantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0},
121130
"repeat": {"repeats": 2, "dim": 0},
122-
"reshape": {"shape": tp.Tensor([6])},
131+
"reshape": {"shape": [6]},
123132
"softmax": {"dim": 1},
124133
"split": {"indices_or_sections": 2},
125134
"squeeze": {"input": tp.ones((3, 1)), "dims": (1)},
126135
"sum": {"dim": 0},
127136
"transpose": {"dim0": 0, "dim1": 1},
128137
"unsqueeze": {"dim": 1},
129138
"var": {"dim": 0},
130-
"zeros": {"shape": tp.Tensor([3, 2])},
139+
"zeros": {"shape": [3, 2]},
131140
}
132141

133142

tripy/tests/spec_verification/test_dtype_constraints.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,6 @@ def _run_dtype_constraints_subtest(test_data):
172172
return ret_val, namespace
173173

174174

175-
# Positive dtype testing is run during L1 testing.
176-
@pytest.mark.l1
177175
@pytest.mark.parametrize("test_data", DTYPE_CONSTRAINT_CASES, ids=lambda val: val[-1])
178176
def test_dtype_constraints(test_data):
179177
# If data type checking is enabled, negative tests will trivially pass (we will throw an

tripy/tests/test_files.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import pytest
16+
from tests import helper
17+
18+
19+
# Checks all Python and markdown files to ensure there are no non-ASCII characters
20+
@pytest.mark.parametrize("file", helper.MARKDOWN_FILES + helper.PYTHON_FILES)
21+
def test_no_non_ascii_characters(file):
22+
with open(file, "rb") as f:
23+
contents = f.read()
24+
25+
try:
26+
contents.decode("ascii")
27+
except UnicodeDecodeError as err:
28+
str_contents = contents.decode("utf-8")
29+
30+
non_ascii = str_contents[err.start : err.end]
31+
32+
line_num = [line_num for line_num, line in enumerate(str_contents.splitlines()) if non_ascii in line][0] + 1
33+
34+
assert False, f"Detected non-ASCII character(s) on line {line_num}: {non_ascii}"

tripy/tripy/common/exception.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,6 @@ def apply_color(inp, color):
9797
frame_info += " " * start + apply_color("^" * (size), Fore.red)
9898
if not is_first_frame:
9999
frame_info += " --- required from here"
100-
else:
101-
if not is_first_frame:
102-
frame_info = "Required from:\n" + frame_info
103100
frame_info += "\n\n"
104101
return frame_info
105102

tripy/tripy/export.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from dataclasses import dataclass
2020
from typing import List, Optional, Any
2121
from types import ModuleType
22-
22+
from textwrap import dedent
2323
from tripy.function_registry import FunctionRegistry
2424

2525

@@ -40,7 +40,11 @@ class PublicAPI:
4040

4141

4242
def public_api(
43-
document_under: str = "", autodoc_options: Optional[List[str]] = None, module: ModuleType = None, symbol: str = None
43+
document_under: str = "",
44+
autodoc_options: Optional[List[str]] = None,
45+
module: ModuleType = None,
46+
symbol: str = None,
47+
doc: str = None,
4448
):
4549
"""
4650
Decorator that exports a function/class to the public API under the top-level module and
@@ -71,6 +75,9 @@ def public_api(
7175
module: The module under which to export this public API. Defaults to the top-level Tripy module.
7276
7377
symbol: The name of the symbol, if different from ``__name__``.
78+
79+
doc: Optional docstring. This is useful in cases where the docstring cannot be provided as normal.
80+
For example, global variables sometimes don't register docstrings correctly.
7481
"""
7582
assert not autodoc_options or (
7683
":no-members:" not in autodoc_options or ":no-special-members:" in autodoc_options
@@ -80,6 +87,9 @@ def export_impl(obj):
8087
nonlocal module, symbol
8188
import tripy
8289

90+
if doc is not None:
91+
obj.__doc__ = dedent(doc)
92+
8393
module = module or tripy
8494

8595
symbol = symbol or obj.__name__

tripy/tripy/frontend/module/parameter.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
from typing import Any, Sequence
1919

20-
import tripy.frontend.utils as frontend_utils
2120
from tripy import export, utils
2221
from tripy.frontend.tensor import Tensor
2322
from tripy.utils import Result
@@ -30,7 +29,6 @@ class Parameter(Tensor):
3029
constant, enabling additional optimization opportunities.
3130
"""
3231

33-
@frontend_utils.convert_inputs_to_tensors()
3432
def __init__(self, tensor: Any) -> None:
3533
"""
3634
Args:

0 commit comments

Comments
 (0)