Skip to content

Commit 86fb4e3

Browse files
Adds a warning if a tensor is evaluated while compiling
In the case where a tensor is evaluated while compiling and then used in the computation graph, we throw an error. However, there are cases where the result of evaluation could make a round-trip through non-Tripy code, in which case we lose visibility. An example of this, assuming `a` and `b` are tensors, is: ```py b = b + int(a.shape[0]) ``` Here, `a.shape[0]` will be evaluated due to the `int` conversion and then used by the add operation with `b`. Because it was a Python integer in between, Tripy has no way to track that it actually came from an evaluated tensor. Hence, this change prints warnings in these ambiguous cases when a tensor is evaluated while compiling. Here's an example of the warning messages: ``` [W] Tensor was evaluated while compiling which may cause unexpected behavior in the executable. For example, this could cause values to be baked into the executable or dynamic shapes to become static. If the result of the evaluation is not being used by other operations, you can safely ignore this warning. [W] Note: Tensor was evaluated while compiling here: --> /tripy/tests/backend/api/test_compile.py:174 in func() | 174 | print(a.shape) | ^^^^^^^^^^^^^^ [2, 3] [W] Note: Tensor was evaluated while compiling here: --> /tripy/tests/backend/api/test_compile.py:176 in func() | 176 | c = a - int(a.shape[0]) | ^^^^^^^^^^^^^^^ [W] Note: Tensor was evaluated while compiling here: --> /tripy/tests/backend/api/test_compile.py:177 in func() | 177 | print(c) | ^^^^^^^^ ```
1 parent eb4956f commit 86fb4e3

File tree

8 files changed

+65
-23
lines changed

8 files changed

+65
-23
lines changed

tripy/tests/backend/api/test_compile.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,16 +168,21 @@ def func(a):
168168
with helper.raises(tp.TripyException, match="Cannot evaluate a tensor while compiling."):
169169
tp.compile(func, args=[tp.InputInfo((2, 3), dtype=tp.float32)])
170170

171-
def test_allow_eval_if_tensor_unused_in_compile(self):
171+
def test_allow_eval_if_tensor_unused_in_compile(self, capsys):
172172
# If the tensor is not actually used in the computation graph then we don't care if it's eval'd.
173173
def func(a):
174174
print(a.shape)
175175

176-
c = a - 1
176+
c = a - int(a.shape[0])
177177
print(c)
178178
return a
179179

180-
tp.compile(func, args=[tp.InputInfo((2, 3), dtype=tp.float32)])
180+
tp.compile(func, args=[tp.InputInfo((2, 3), dtype=tp.int32)])
181+
out, _ = capsys.readouterr()
182+
print(f"\n{out}")
183+
184+
# Ensure that a warning is printed for each evaluation (2 prints + int).
185+
assert out.count("Tensor was evaluated while compiling here:") == 3
181186

182187
def test_allow_eval_for_non_input_to_compile(self):
183188
# We should allow non-inputs to be evaluated.

tripy/tests/common/test_exception.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from tests import helper
2323

2424
import tripy as tp
25-
from tripy.common.exception import TripyException, _get_function_file_and_lines, _make_stack_info_message, raise_error
25+
from tripy.common.exception import TripyException, _get_function_file_and_lines, str_from_stack_info, raise_error
2626
from tripy.frontend.utils import convert_to_tensors
2727
from tripy.utils import StackInfo, get_stack_info
2828
from tripy.utils.stack_info import SourceInfo
@@ -112,7 +112,7 @@ def test_can_determine_column_range(self):
112112
]
113113
)
114114

115-
error_msg = _make_stack_info_message(stack_info, enable_color=False)
115+
error_msg = str_from_stack_info(stack_info, enable_color=False)
116116
assert (
117117
dedent(
118118
"""
@@ -156,5 +156,5 @@ def test_convert_to_tensors_is_excluded(self):
156156
"""
157157
).strip()
158158

159-
actual = _make_stack_info_message(stack_info, enable_color=False)
159+
actual = str_from_stack_info(stack_info, enable_color=False)
160160
assert re.search(expected, actual) is not None

tripy/tests/helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
import tripy as tp
3636
from tripy import utils
37-
from tripy.common.exception import _make_stack_info_message
37+
from tripy.common.exception import str_from_stack_info
3838
from tripy.frontend import Tensor
3939
from tripy.frontend.trace import Trace
4040

@@ -74,7 +74,7 @@ def raises(ExcType: type, match: Optional[str] = None, has_stack_info_for: Seque
7474
has_stack_info_for = has_stack_info_for or []
7575
for tensor in has_stack_info_for:
7676
# Stack info is indented since it's part of the `details` block in `raise_error`
77-
expected_stack_info = indent(_make_stack_info_message(tensor.stack_info).strip(), " " * 4)
77+
expected_stack_info = indent(str_from_stack_info(tensor.stack_info).strip(), " " * 4)
7878
assert expected_stack_info in error_msg, f"Missing stack information for tensor:\n{expected_stack_info}"
7979

8080

tripy/tripy/common/exception.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def _get_function_file_and_lines(func):
107107
return filename, start_line, start_line + len(lines)
108108

109109

110-
def _make_stack_info_message(stack_info: "utils.StackInfo", enable_color: bool = True) -> Optional[str]:
110+
def str_from_stack_info(stack_info: "utils.StackInfo", enable_color: bool = True) -> Optional[str]:
111111
from tripy.frontend.utils import convert_to_tensors
112112

113113
EXCLUDE_FUNCTIONS = [convert_to_tensors]
@@ -187,9 +187,9 @@ def raise_error(summary: str, details: List[Any] = []):
187187
for detail in details:
188188
stack_info_message = None
189189
if hasattr(detail, "stack_info"):
190-
stack_info_message = _make_stack_info_message(detail.stack_info)
190+
stack_info_message = str_from_stack_info(detail.stack_info)
191191
elif isinstance(detail, utils.StackInfo):
192-
stack_info_message = _make_stack_info_message(detail)
192+
stack_info_message = str_from_stack_info(detail)
193193

194194
if stack_info_message is not None:
195195
detail_msg += stack_info_message

tripy/tripy/frontend/module/module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import Any, Dict, Iterator, List, Tuple, Union, Set, Sequence, TypeVar
2121

2222
from tripy import export, utils
23-
from tripy.common.exception import raise_error, _make_stack_info_message
23+
from tripy.common.exception import raise_error, str_from_stack_info
2424
from tripy.frontend.module.parameter import Parameter
2525
from tripy.logging import logger
2626

@@ -111,7 +111,7 @@ def __setattr__(self, name: str, value: Any) -> None:
111111
):
112112
stack_info = utils.get_stack_info()
113113
stack_info.fetch_source_code()
114-
stack_info_msg = _make_stack_info_message(stack_info)
114+
stack_info_msg = str_from_stack_info(stack_info)
115115

116116
logger.warning(
117117
"A container of mixed types will not be registered with this module's state_dict()."

tripy/tripy/frontend/tensor.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,11 @@
2626
from tripy import export, utils
2727
from tripy.backend.mlir import memref
2828
from tripy.common import datatype
29-
from tripy.common.exception import raise_error
29+
from tripy.common.exception import raise_error, str_from_stack_info
3030
from tripy.frontend.ops.registry import TENSOR_METHOD_REGISTRY
3131
from tripy.frontend.trace.ops import Storage
3232
from tripy.frontend.trace.tensor import TraceTensor
33+
from tripy.logging.logger import logger
3334
from tripy.utils.stack_info import StackInfo
3435

3536

@@ -198,6 +199,18 @@ def eval(self) -> runtime.MemRefValue:
198199
self.trace_tensor.device = flat_ir.outputs[0].device
199200

200201
self.trace_tensor.eval_stack_info = utils.get_stack_info()
202+
if self.trace_tensor.is_compile_tracer:
203+
logger.warning(
204+
f"Tensor was evaluated while compiling which may cause unexpected behavior in the executable.\n"
205+
f"For example, this could cause values to be baked into the executable or dynamic shapes to become static.\n"
206+
f"If the result of the evaluation is not being used by other operations, you can safely ignore this warning.",
207+
mode="once",
208+
)
209+
logger.warning(
210+
f"Note: Tensor was evaluated while compiling here: {str_from_stack_info(self.trace_tensor.eval_stack_info)}",
211+
mode="once",
212+
)
213+
201214
return data
202215

203216
def tolist(self):

tripy/tripy/logging/logger.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def __init__(self) -> None:
105105
self._indentation = 0
106106
self.verbosity: Union[str, Set[str], Dict[str, str], Dict[str, Set[str]]] = "info"
107107
self.enable_color = True
108+
self._already_logged_hashes = set()
108109

109110
@property
110111
def verbosity(self):
@@ -167,7 +168,9 @@ def indent(self, level: int = 4):
167168
finally:
168169
self._indentation = old_indentation
169170

170-
def log(self, message: Union[str, Callable[[], str]], verbosity: str, stack_depth: int = 2) -> None:
171+
def log(
172+
self, message: Union[str, Callable[[], str]], verbosity: str, mode: str = "each", stack_depth: int = 2
173+
) -> None:
171174
"""
172175
Logs a message to standard output.
173176
@@ -178,6 +181,9 @@ def log(self, message: Union[str, Callable[[], str]], verbosity: str, stack_dept
178181
message: The message to log. This can be provided as a callable in which case it will not
179182
be called unless the message actually needs to be logged.
180183
verbosity: The verbosity at which to log this message.
184+
mode: Indicates when or how to log the message. Available modes are:
185+
- "each": Log the message each time.
186+
- "once": Only log a message the first time it is seen.
181187
stack_depth: The stack depth to use when determining which file the message is being logged from.
182188
"""
183189
assert (
@@ -202,6 +208,7 @@ def get_rel_file_path():
202208
return module_path(file_path)
203209

204210
def should_log():
211+
205212
path = None
206213
# Don't actually need to get the path if there are no non-default entries in the trie.
207214
if self.verbosity.has_non_default_entries:
@@ -211,6 +218,12 @@ def should_log():
211218
if not should_log():
212219
return
213220

221+
if mode == "once":
222+
message_hash = hash(message)
223+
if message_hash in self._already_logged_hashes:
224+
return
225+
self._already_logged_hashes.add(message_hash)
226+
214227
if callable(message):
215228
message = message()
216229

tripy/tripy/utils/ast.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#
1717

1818
import ast
19-
from typing import List, Optional, Tuple
19+
from typing import List, Optional, Tuple, Set
2020

2121
from tripy.utils.result import Result
2222
from tripy.utils.stack_info import SourceInfo
@@ -40,14 +40,25 @@ def get_parsed_ast(code: str) -> Result[Tuple[str, int]]:
4040
return Result.ok((parsed_ast, indentation))
4141

4242

43-
def get_callee_func_name(callee: SourceInfo):
44-
callee_name = callee.function
43+
def get_callee_func_name_candidates(callee: SourceInfo) -> Set[str]:
4544
# Some functions (e.g. tensor methods) are routed through a function registry.
4645
# We don't actually care about the dispatch function, so we look at the `key`
4746
# to determine which underlying method we're actually calling.
4847
if callee._dispatch_target:
49-
callee_name = callee._dispatch_target
50-
return callee_name
48+
candidates = {callee._dispatch_target}
49+
else:
50+
candidates = {callee.function}
51+
52+
# Some methods are called by other builtins:
53+
SPECIAL_METHODS = {
54+
"__repr__": {"repr", "print"},
55+
"__str__": {"str"},
56+
"__int__": {"int"},
57+
"__bool__": {"bool"},
58+
}
59+
candidates.update(SPECIAL_METHODS.get(callee.function, set()))
60+
61+
return candidates
5162

5263

5364
def get_ast_node_func_name(node) -> Optional[str]:
@@ -144,7 +155,7 @@ def index_into_expr(node: ast.expr, index: int) -> ast.expr:
144155
# Grab column offsets for a given frame based on information from its callee.
145156
# This method is not perfect and is not required for Python 3.11+, where frames include column offsets.
146157
def get_candidate_column_offsets(cur_frame: SourceInfo, callee: SourceInfo) -> List[Tuple[int, int]]:
147-
callee_name = get_callee_func_name(callee)
158+
candidate_callee_names = get_callee_func_name_candidates(callee)
148159

149160
candidate_column_offsets = []
150161

@@ -165,8 +176,8 @@ def get_candidate_column_offsets(cur_frame: SourceInfo, callee: SourceInfo) -> L
165176

166177
def check_name_matches():
167178
# We need special checking for __init__ methods since the AST node will just be the class name, e.g. `Tensor`.
168-
if callee_name != "__init__":
169-
return ast_node_name == callee_name
179+
if "__init__" not in candidate_callee_names:
180+
return ast_node_name in candidate_callee_names
170181

171182
# We hardcode names of some common classes here to avoid creating an import dependency:
172183
if ast_node_name in {"Tensor"}:

0 commit comments

Comments
 (0)