Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tripy/nvtripy/backend/mlir/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def _get_compilation_task(self, trt_builder_opt_level):
if config.enable_mlir_debug or config.enable_tensorrt_debug:
opts.append("--debug=true")
if config.enable_mlir_debug:
# Elide large constants by default. Remove this option if the dumped IR
# needs to be runnable.
opts.append(f"--mlir-elide-elementsattrs-if-larger=1024")
opts.append(f"--debug-only={config.mlir_debug_types}")
opts.append(f"--mlir-print-ir-after-all")
opts.append(f"--mlir-print-ir-tree-dir={config.mlir_debug_tree_path}")
Expand Down
11 changes: 6 additions & 5 deletions tripy/nvtripy/backend/mlir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,16 @@ def make_mlir_tensor(
OUTPUT_SEPARATOR = ";;<out>;;"


def make_tensor_location(input_names: List[str], output_names: List[str]) -> ir.Location:
    """
    Creates a named MLIR location that encodes tensor names.

    The location name is the comma-joined input names, followed by
    `OUTPUT_SEPARATOR`, followed by the comma-joined output names.

    Args:
        input_names: Names of the input tensors.
        output_names: Names of the output tensors.

    Returns:
        The named location.
    """
    joined_inputs = ",".join(input_names)
    joined_outputs = ",".join(output_names)
    return ir.Location.name(joined_inputs + OUTPUT_SEPARATOR + joined_outputs)
def make_tensor_location(input_names: List[str], output_names: List[str], metadata: str) -> ir.Location:
    """
    Creates a fused MLIR location that encodes tensor names and carries metadata.

    The tensor names are stored in a named location whose name is the
    comma-joined input names, followed by `OUTPUT_SEPARATOR`, followed by the
    comma-joined output names. That named location is then wrapped in a fused
    location which holds `metadata` as a string attribute.

    Args:
        input_names: Names of the input tensors.
        output_names: Names of the output tensors.
        metadata: Text to attach to the fused location.

    Returns:
        The fused location wrapping the named location.
    """
    joined_inputs = ",".join(input_names)
    joined_outputs = ",".join(output_names)
    named_loc = ir.Location.name(f"{joined_inputs}{OUTPUT_SEPARATOR}{joined_outputs}")
    metadata_attr = ir.StringAttr.get(metadata)
    return ir.Location.fused([named_loc], metadata_attr)


# The way locations are printed by MLIR-TRT differs from how they are printed by TRT, hence all the `?`s.
TENSOR_NAME_PATTERN = re.compile(r'loc\("?(.*?)"?\):? ?')
# MLIR-TRT prints the fused location in the format: loc(fused<"...">["..."]):
# The tensor names are captured from the quoted string that contains OUTPUT_SEPARATOR.
# The separator is escaped so that it is always matched literally, even if it is
# later changed to contain regex metacharacters.
TENSOR_NAME_PATTERN = re.compile(rf'"([^"]*{re.escape(OUTPUT_SEPARATOR)}[^"]*)"')
# Noncapturing pattern is required so that when we `.split`, we eliminate the entire pattern and not just
# the captured portions.
TENSOR_NAME_PATTERN_NO_CAPTURE = re.compile(r'loc\("?.*?"?\):? ?')
# Matches the whole `loc(...)` text, with or without the `fused<"...">[...]` wrapper.
# The separator is escaped so that it is always matched literally, even if it is
# later changed to contain regex metacharacters.
TENSOR_NAME_PATTERN_NO_CAPTURE = re.compile(
    rf'loc\((?:fused.*?">\[)?"[^"]*{re.escape(OUTPUT_SEPARATOR)}[^"]*"(?:\])?\):? ?'
)


def parse_tensor_names_from_location(msg: str) -> Tuple[List[str], List[str], str]:
Expand Down
42 changes: 22 additions & 20 deletions tripy/nvtripy/trace/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,29 @@ def to_mlir_impl():
layer_input_ops = [mlir_ops[inp.name] for inp in op.inputs]
output_types = [out.to_mlir() for out in op.outputs]

metadata = (
f"<TraceOp: {op}, Stack Info: "
+ ", ".join(
(
str_from_stack_info(
# include_code_index points to the first "useful" frame, i.e. usually the API
# that the user calls, or a few frames below that:
out.stack_info[out.stack_info.include_code_index or 0 :],
enable_color=False,
fetch_source_code=False,
)
if out.stack_info
else ""
).replace("\n", " ")
for out in op.outputs
)
+ " >, "
)

with make_tensor_location(
[inp.name for inp in op.inputs], [out.name for out in op.outputs]
[inp.name for inp in op.inputs],
[out.name for out in op.outputs],
metadata,
):
mlir_output_ops = op.to_mlir(layer_input_ops, output_types)

Expand All @@ -180,25 +201,6 @@ def num_known_dims(ranked_tensor_type):
if num_known_dims(output_type) >= num_known_dims(mlir_output_op.type):
mlir_output_op.set_type(output_type)

mlir_output_op.owner.attributes["metadata"] = ir.StringAttr.get(
f"<TraceOp: {op}, Stack Info: "
+ ", ".join(
(
str_from_stack_info(
# include_code_index points to the first "useful" frame, i.e. usually the API
# that the user calls, or a few frames below that:
out.stack_info[out.stack_info.include_code_index or 0 :],
enable_color=False,
fetch_source_code=False,
)
if out.stack_info
else ""
).replace("\n", " ")
for out in op.outputs
)
+ " >, "
)

mlir_ops.update(zip([out.name for out in op.outputs], mlir_output_ops))

func_dialect.ReturnOp([mlir_ops[o.name] for o in self.outputs])
Expand Down
Loading