Skip to content

Commit e25f408

Browse files
Use FusedLocation to hold layer metadata (#690)
Signed-off-by: yizhuoz004 <[email protected]> Co-authored-by: pranavm-nvidia <[email protected]>
1 parent 42c31f3 commit e25f408

File tree

3 files changed

+31
-25
lines changed

3 files changed

+31
-25
lines changed

tripy/nvtripy/backend/mlir/compiler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ def _get_compilation_task(self, trt_builder_opt_level):
5858
if config.enable_mlir_debug or config.enable_tensorrt_debug:
5959
opts.append("--debug=true")
6060
if config.enable_mlir_debug:
61+
# elide large constants by default, can remove this option
62+
# if we want to run the dumped IR
63+
opts.append(f"--mlir-elide-elementsattrs-if-larger=1024")
6164
opts.append(f"--debug-only={config.mlir_debug_types}")
6265
opts.append(f"--mlir-print-ir-after-all")
6366
opts.append(f"--mlir-print-ir-tree-dir={config.mlir_debug_tree_path}")

tripy/nvtripy/backend/mlir/utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,15 +124,16 @@ def make_mlir_tensor(
124124
OUTPUT_SEPARATOR = ";;<out>;;"
125125

126126

127-
def make_tensor_location(input_names: List[str], output_names: List[str]) -> ir.Location:
128-
return ir.Location.name(f"{','.join(input_names)}{OUTPUT_SEPARATOR}{','.join(output_names)}")
127+
def make_tensor_location(input_names: List[str], output_names: List[str], metadata: str) -> ir.Location:
128+
loc = ir.Location.name(f"{','.join(input_names)}{OUTPUT_SEPARATOR}{','.join(output_names)}")
129+
return ir.Location.fused([loc], ir.StringAttr.get(metadata))
129130

130131

131-
# The way locations are printed by MLIR-TRT differs from how they are printed by TRT, hence all the `?`s.
132-
TENSOR_NAME_PATTERN = re.compile(r'loc\("?(.*?)"?\):? ?')
132+
# MLIR-TRT prints the fused location in the format: loc(fused<"...">["..."]):
133+
TENSOR_NAME_PATTERN = re.compile(rf'"([^"]*{OUTPUT_SEPARATOR}[^"]*)"')
133134
# Noncapturing pattern is required so that when we `.split`, we eliminate the entire pattern and not just
134135
# the captured portions.
135-
TENSOR_NAME_PATTERN_NO_CAPTURE = re.compile(r'loc\("?.*?"?\):? ?')
136+
TENSOR_NAME_PATTERN_NO_CAPTURE = re.compile(rf'loc\((?:fused.*?">\[)?"[^"]*{OUTPUT_SEPARATOR}[^"]*"(?:\])?\):? ?')
136137

137138

138139
def parse_tensor_names_from_location(msg: str) -> Tuple[List[str], List[str], str]:

tripy/nvtripy/trace/trace.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,29 @@ def to_mlir_impl():
159159
layer_input_ops = [mlir_ops[inp.name] for inp in op.inputs]
160160
output_types = [out.to_mlir() for out in op.outputs]
161161

162+
metadata = (
163+
f"<TraceOp: {op}, Stack Info: "
164+
+ ", ".join(
165+
(
166+
str_from_stack_info(
167+
# include_code_index points to the first "useful" frame, i.e. usually the API
168+
# that the user calls, or a few frames below that:
169+
out.stack_info[out.stack_info.include_code_index or 0 :],
170+
enable_color=False,
171+
fetch_source_code=False,
172+
)
173+
if out.stack_info
174+
else ""
175+
).replace("\n", " ")
176+
for out in op.outputs
177+
)
178+
+ " >, "
179+
)
180+
162181
with make_tensor_location(
163-
[inp.name for inp in op.inputs], [out.name for out in op.outputs]
182+
[inp.name for inp in op.inputs],
183+
[out.name for out in op.outputs],
184+
metadata,
164185
):
165186
mlir_output_ops = op.to_mlir(layer_input_ops, output_types)
166187

@@ -180,25 +201,6 @@ def num_known_dims(ranked_tensor_type):
180201
if num_known_dims(output_type) >= num_known_dims(mlir_output_op.type):
181202
mlir_output_op.set_type(output_type)
182203

183-
mlir_output_op.owner.attributes["metadata"] = ir.StringAttr.get(
184-
f"<TraceOp: {op}, Stack Info: "
185-
+ ", ".join(
186-
(
187-
str_from_stack_info(
188-
# include_code_index points to the first "useful" frame, i.e. usually the API
189-
# that the user calls, or a few frames below that:
190-
out.stack_info[out.stack_info.include_code_index or 0 :],
191-
enable_color=False,
192-
fetch_source_code=False,
193-
)
194-
if out.stack_info
195-
else ""
196-
).replace("\n", " ")
197-
for out in op.outputs
198-
)
199-
+ " >, "
200-
)
201-
202204
mlir_ops.update(zip([out.name for out in op.outputs], mlir_output_ops))
203205

204206
func_dialect.ReturnOp([mlir_ops[o.name] for o in self.outputs])

0 commit comments

Comments
 (0)