Skip to content

Commit 9656432

Browse files
Updates exception throwing logic to correctly exclude decorators
1 parent 13dde1b commit 9656432

File tree

5 files changed

+63
-18
lines changed

5 files changed

+63
-18
lines changed

tripy/docs/pre0_user_guides/02-compiler.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,13 @@ fast_geglu(inp).eval()
5252
In the example above, we assumed `inp` has a static shape of `(1, 2)`.
5353
Now, let's assume that the shape of `inp` can vary from `(1, 2)` to `(16, 2)`, with `(8, 2)`
5454
being the shape we'd like to optimize for. To express this constraint to the compiler,
55-
we can provide the range of shapes to `InputInfo` using `shape=((1, 8, 16), 2)`.
55+
we can provide the range of shapes to `InputInfo` using `shape=([1, 8, 16], 2)`.
5656
This indicates to the compiler that the first dimension can vary from 1 to 16,
5757
and it should optimize for a size of 8.
5858

5959
```py
6060
# doc: print-locals out out_change_shape
61-
inp_info = tp.InputInfo(shape=((1, 8, 16), 2), dtype=tp.float32)
61+
inp_info = tp.InputInfo(shape=([1, 8, 16], 2), dtype=tp.float32)
6262
fast_geglu = tp.compile(layer, args=[inp_info])
6363
out = fast_geglu(inp)
6464

tripy/tests/README.md

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,9 @@ The `tests/integration` directory captures the latter group of tests.
1010

1111
You can run all tests locally in the development container by running:
1212
```bash
13-
pytest tests/ -v -n 4 --dist worksteal --ignore tests/performance
14-
pytest tests/performance -v
13+
pytest tests/ -v
1514
```
1615

17-
Performance tests are run separately because they must run serially to ensure
18-
accurate measurements.
19-
2016
You can also provide marker arguments to only run specific test cadences
2117
(see [the test cadence section](#test-cadence) below). For example, to run only
2218
L0 tests, use:
@@ -26,6 +22,9 @@ pytest tests/ -v -m "not l1 and not manual" -n 4 --dist worksteal --ignore tests
2622
pytest tests/performance -v -m "not l1 and not manual"
2723
```
2824

25+
Note that the L0/L1 tests can be parallelized, which is not necessarily
26+
true of `manual` tests. In that case, performance tests are run separately
27+
because they must run serially to ensure accurate measurements.
2928

3029
## Profiling
3130

tripy/tests/common/test_exception.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,15 @@
1515
# limitations under the License.
1616
#
1717

18+
import re
1819
from dataclasses import dataclass
1920
from textwrap import dedent
2021

2122
from tests import helper
22-
from tripy.common.exception import TripyException, _make_stack_info_message, raise_error
23+
24+
import tripy as tp
25+
from tripy.common.exception import TripyException, _get_function_file_and_lines, _make_stack_info_message, raise_error
26+
from tripy.frontend.utils import convert_shape_inputs
2327
from tripy.utils import StackInfo, get_stack_info
2428
from tripy.utils.stack_info import SourceInfo
2529

@@ -120,3 +124,37 @@ def test_can_determine_column_range(self):
120124
).strip()
121125
in dedent(error_msg).strip()
122126
)
127+
128+
def test_convert_shape_inputs_is_excluded(self):
129+
filename, start_line, end_line = _get_function_file_and_lines(convert_shape_inputs)
130+
tensor = tp.ones((2, 3))
131+
132+
stack_info = tensor.stack_info
133+
134+
assert any(
135+
frame.file == filename and frame.line >= start_line and frame.line <= end_line for frame in stack_info
136+
)
137+
138+
# Make sure that no extraneous wrapper code is included
139+
expected = dedent(
140+
r"""
141+
--> [a-z_/\.]+:[0-9]+ in full\(\)
142+
|
143+
[0-9]+ | return full_impl\(shape, value, dtype, output_rank\)
144+
|
145+
146+
--> [a-z_/\.]+:[0-9]+ in ones\(\)
147+
|
148+
[0-9]+ | return full\(shape, 1, dtype\)
149+
| ^^^^^^^^^^^^^^^^^^^^^ --- required from here
150+
151+
--> [a-z_/\.]+:[0-9]+ in test_convert_shape_inputs_is_excluded\(\)
152+
|
153+
[0-9]+ | tensor = tp.ones\(\(2, 3\)\)
154+
| ^^^^^^^^^^^^^^^ --- required from here
155+
156+
"""
157+
).strip()
158+
159+
actual = _make_stack_info_message(stack_info, enable_color=False)
160+
assert re.search(expected, actual) is not None

tripy/tripy/common/exception.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -101,22 +101,30 @@ def apply_color(inp, color):
101101
return frame_info
102102

103103

104+
def _get_function_file_and_lines(func):
105+
filename = inspect.getsourcefile(func)
106+
lines, start_line = inspect.getsourcelines(func)
107+
return filename, start_line, start_line + len(lines)
108+
109+
104110
def _make_stack_info_message(stack_info: "utils.StackInfo", enable_color: bool = True) -> Optional[str]:
111+
105112
from tripy.frontend.utils import convert_inputs_to_tensors, convert_shape_inputs
106113

107114
EXCLUDE_FUNCTIONS = [convert_inputs_to_tensors, convert_shape_inputs]
108115

109-
def should_exclude(frame):
110-
for func in EXCLUDE_FUNCTIONS:
111-
filename = inspect.getsourcefile(func)
112-
lines, start_line = inspect.getsourcelines(func)
116+
exclude_file_lines = {} # Maps filenames to ranges of lines that should be ignored.
117+
for func in EXCLUDE_FUNCTIONS:
118+
filename, start_line, end_line = _get_function_file_and_lines(func)
119+
120+
exclude_file_lines[filename] = (start_line, end_line)
113121

114-
if frame.file != filename:
115-
return False
122+
def should_exclude(frame):
123+
if frame.file not in exclude_file_lines:
124+
return False
116125

117-
if frame.line < start_line or frame.line > (start_line + len(lines)):
118-
return False
119-
return True
126+
start_line, end_line = exclude_file_lines[frame.file]
127+
return frame.line >= start_line and frame.line <= end_line
120128

121129
frame_strs = []
122130
num_frames_printed = 0

tripy/tripy/utils/stack_info.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def get_stack_info(include_code_index: int = None) -> StackInfo:
108108
column_range=None,
109109
)
110110
if source_info.module == tripy.function_registry.__name__ and source_info.function == "wrapper":
111-
source_info._dispatch_target = frame.f_locals.get("key", "")
111+
source_info._dispatch_target = frame.f_locals["key"]
112112

113113
try:
114114
# In Python 3.11, frames contain column offset information.

0 commit comments

Comments
 (0)