Skip to content

Commit 61cd9a6

Browse files
committed
Use FusedLocation to hold layer metadata
1 parent f821499 commit 61cd9a6

File tree

3 files changed

+26
-22
lines changed

3 files changed

+26
-22
lines changed

tripy/nvtripy/backend/mlir/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def _get_compilation_task(self, trt_builder_opt_level):
5858
if config.enable_mlir_debug or config.enable_tensorrt_debug:
5959
opts.append("--debug=true")
6060
if config.enable_mlir_debug:
61+
opts.append(f"--mlir-elide-elementsattrs-if-larger=32")
6162
opts.append(f"--debug-only={config.mlir_debug_types}")
6263
opts.append(f"--mlir-print-ir-after-all")
6364
opts.append(f"--mlir-print-ir-tree-dir={config.mlir_debug_tree_path}")

tripy/nvtripy/backend/mlir/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,9 @@ def make_mlir_tensor(
124124
OUTPUT_SEPARATOR = ";;<out>;;"
125125

126126

127-
def make_tensor_location(input_names: List[str], output_names: List[str]) -> ir.Location:
128-
return ir.Location.name(f"{','.join(input_names)}{OUTPUT_SEPARATOR}{','.join(output_names)}")
127+
def make_tensor_location(input_names: List[str], output_names: List[str], metadata: str) -> ir.Location:
128+
loc = ir.Location.name(f"{','.join(input_names)}{OUTPUT_SEPARATOR}{','.join(output_names)}")
129+
return ir.Location.fused([loc], ir.StringAttr.get(metadata))
129130

130131

131132
# The way locations are printed by MLIR-TRT differs from how they are printed by TRT, hence all the `?`s.

tripy/nvtripy/trace/trace.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,29 @@ def to_mlir_impl():
159159
layer_input_ops = [mlir_ops[inp.name] for inp in op.inputs]
160160
output_types = [out.to_mlir() for out in op.outputs]
161161

162+
metadata = (
163+
f"<TraceOp: {op}, Stack Info: "
164+
+ ", ".join(
165+
(
166+
str_from_stack_info(
167+
# include_code_index points to the first "useful" frame, i.e. usually the API
168+
# that the user calls, or a few frames below that:
169+
out.stack_info[out.stack_info.include_code_index or 0 :],
170+
enable_color=False,
171+
fetch_source_code=False,
172+
)
173+
if out.stack_info
174+
else ""
175+
).replace("\n", " ")
176+
for out in op.outputs
177+
)
178+
+ " >, "
179+
)
180+
162181
with make_tensor_location(
163-
[inp.name for inp in op.inputs], [out.name for out in op.outputs]
182+
[inp.name for inp in op.inputs],
183+
[out.name for out in op.outputs],
184+
metadata,
164185
):
165186
mlir_output_ops = op.to_mlir(layer_input_ops, output_types)
166187

@@ -180,25 +201,6 @@ def num_known_dims(ranked_tensor_type):
180201
if num_known_dims(output_type) >= num_known_dims(mlir_output_op.type):
181202
mlir_output_op.set_type(output_type)
182203

183-
mlir_output_op.owner.attributes["metadata"] = ir.StringAttr.get(
184-
f"<TraceOp: {op}, Stack Info: "
185-
+ ", ".join(
186-
(
187-
str_from_stack_info(
188-
# include_code_index points to the first "useful" frame, i.e. usually the API
189-
# that the user calls, or a few frames below that:
190-
out.stack_info[out.stack_info.include_code_index or 0 :],
191-
enable_color=False,
192-
fetch_source_code=False,
193-
)
194-
if out.stack_info
195-
else ""
196-
).replace("\n", " ")
197-
for out in op.outputs
198-
)
199-
+ " >, "
200-
)
201-
202204
mlir_ops.update(zip([out.name for out in op.outputs], mlir_output_ops))
203205

204206
func_dialect.ReturnOp([mlir_ops[o.name] for o in self.outputs])

0 commit comments

Comments
 (0)