Skip to content

Commit 6584c93

Browse files
Updates MLIR to include layer metadata (#669)
Updates our Trace->MLIR translation to include metadata for each layer that includes stack information and details about the Trace operation.
1 parent 515b0e3 commit 6584c93

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

tripy/nvtripy/common/exception.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
# limitations under the License.
1616
#
1717

18-
import inspect
19-
from dataclasses import dataclass
2018
from textwrap import indent
2119
from typing import Any, List, Optional, Tuple
2220

@@ -84,12 +82,14 @@ def apply_color(inp, color):
8482
return frame_info
8583

8684

87-
def str_from_stack_info(stack_info: "utils.stack_info.StackInfo", enable_color: bool = True) -> Optional[str]:
85+
def str_from_stack_info(
86+
stack_info: "utils.stack_info.StackInfo", enable_color: bool = True, fetch_source_code: bool = True
87+
) -> Optional[str]:
8888
from nvtripy.frontend.module import module
8989

9090
def should_exclude(source_info):
9191
return (
92-
source_info.code is None
92+
(fetch_source_code and source_info.code is None)
9393
or source_info.module in utils.stack_info.get_module_names_to_exclude_from_stack_info()
9494
# Exclude module.__call__ since it just invokes forward and clutters the stack trace
9595
or (source_info.module == module.__name__ and source_info.function == "__call__")
@@ -98,7 +98,8 @@ def should_exclude(source_info):
9898
frame_strs = []
9999
num_frames_printed = 0
100100

101-
stack_info.fetch_source_code()
101+
if fetch_source_code:
102+
stack_info.fetch_source_code()
102103
for index, source_info in enumerate(stack_info):
103104
if should_exclude(source_info):
104105
continue

tripy/nvtripy/trace/trace.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
map_error_to_user_code_and_raise,
2929
redirect_stderr,
3030
)
31-
from nvtripy.common.exception import raise_error
31+
from nvtripy.common.exception import raise_error, str_from_stack_info
3232
from nvtripy.logging import logger
3333
from nvtripy.trace.tensor import TraceTensor
3434
from nvtripy.trace.utils import topological_sort
@@ -178,6 +178,25 @@ def num_known_dims(ranked_tensor_type):
178178
if num_known_dims(output_type) >= num_known_dims(mlir_output_op.type):
179179
mlir_output_op.set_type(output_type)
180180

181+
mlir_output_op.owner.attributes["metadata"] = ir.StringAttr.get(
182+
f"<TraceOp: {op}, Stack Info: "
183+
+ ", ".join(
184+
(
185+
str_from_stack_info(
186+
# include_code_index points to the first "useful" frame, i.e. usually the API
187+
# that the user calls, or a few frames below that:
188+
out.stack_info[out.stack_info.include_code_index or 0 :],
189+
enable_color=False,
190+
fetch_source_code=False,
191+
)
192+
if out.stack_info
193+
else ""
194+
).replace("\n", " ")
195+
for out in op.outputs
196+
)
197+
+ " >, "
198+
)
199+
181200
mlir_ops.update(zip([out.name for out in op.outputs], mlir_output_ops))
182201

183202
func_dialect.ReturnOp([mlir_ops[o.name] for o in self.outputs])

0 commit comments

Comments
 (0)