Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,6 @@ def _insert_complex_io_adapters(
partitioned_module.graph.lint()
partitioned_module.recompile()


@fn_supports_debugger # type: ignore[misc]
def compile_module(
gm: torch.fx.GraphModule,
Expand Down
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,7 @@ def aten_ops_rsqrt(
)


@dynamo_tensorrt_converter(operator.neg, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.neg.default, supports_dynamic_shapes=True)
def aten_ops_neg(
ctx: ConversionContext,
Expand Down Expand Up @@ -2223,6 +2224,7 @@ def aten_ops_maximum(
)


@dynamo_tensorrt_converter(torch.sym_min, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.minimum.default, supports_dynamic_shapes=True)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we have this in both the converter and a lowering pass to remove sym_min?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lowering pass should only remove a specific instance of sym_min

def aten_ops_minimum(
ctx: ConversionContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .complex_graph_rewrite import complex_graph_detection
from .constant_folding import constant_fold
from .eliminate_sym_min_int64_max import eliminate_sym_min_int64_max
from .force_causal_efficient_attention import force_causal_efficient_attention
from .fuse_prims_broadcast import fuse_prims_broadcast
from .pass_manager import DynamoPassManager
Expand All @@ -23,6 +24,7 @@
from .replace_fused_rms_norm import replace_fused_rms_norm
from .replace_max_pool_with_indices import replace_max_pool_with_indices
from .rule_based_autocast import rule_based_autocast
from .normalize_negative_slice_stop import normalize_negative_slice_stop

pre_lowering_pass_list = [
remove_detach,
Expand All @@ -41,6 +43,8 @@
remove_num_users_is_0_nodes,
complex_graph_detection,
force_causal_efficient_attention,
eliminate_sym_min_int64_max,
normalize_negative_slice_stop,
]

if not is_tegra_platform():
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import sys

import torch
from torch.fx import GraphModule, Node

from .pass_utils import clean_up_graph_after_modifications


_INT64_MAX = 2**63 - 1
_SYM_MIN = getattr(torch, "sym_min", None)


def _is_int64_max(x: object) -> bool:
return isinstance(x, int) and x in (sys.maxsize, _INT64_MAX)


def eliminate_sym_min_int64_max(
gm: GraphModule, settings: object = None
) -> GraphModule:
"""Remove no-op sym_min nodes where one operand is INT64_MAX.

torch.export may emit sym_min(sym, INT64_MAX) for an effectively unbounded
symbolic value. That expression is equivalent to sym, and leaving it in the
graph can produce runtime calls to torch.sym_min with Tensor inputs.
"""
if _SYM_MIN is None:
return gm

modified = False
for node in list(gm.graph.nodes):
if (
node.op != "call_function"
or node.target is not _SYM_MIN
or len(node.args) < 2
):
continue

lhs, rhs = node.args[:2]
if _is_int64_max(rhs) and isinstance(lhs, Node):
passthrough = lhs
elif _is_int64_max(lhs) and isinstance(rhs, Node):
passthrough = rhs
else:
continue

node.replace_all_uses_with(passthrough)
gm.graph.erase_node(node)
modified = True

return clean_up_graph_after_modifications(gm) if modified else gm
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import operator
from typing import Optional

import torch
from torch.fx import GraphModule, Node
from torch_tensorrt.dynamo.lowering._SubgraphBuilder import SubgraphBuilder

from .pass_utils import clean_up_graph_after_modifications


Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cant we handle this case in the converter itself?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason we would prefer that?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think generally simpler converters and more in lowering is better, but I think the line needs to be clear, either we normalize dimensions in the graph or we do it in the converter

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

going through this case ya lowering seems better, since the negative seems to be ITensor, so we would need extra Iselect layers, making the converter more complicated. Though we have done this before in converters, but if we want simpler converters now, doing in lowering makes more sense

def _negative_symint_operand(x: object) -> Optional[object]:
# Return n for symbolic bounds represented as -n. The caller rewrites
# that bound to dim_size - n, matching Python's negative indexing rules.
if (
isinstance(x, Node)
and x.op == "call_function"
and x.target in (operator.neg, torch.ops.aten.neg.default)
and len(x.args) == 1
):
return x.args[0]
return None


def _rank(x: Node) -> Optional[int]:
val = x.meta.get("val")
if isinstance(val, torch.Tensor):
return val.dim()
if hasattr(val, "shape"):
return len(val.shape)
return None


def normalize_negative_slice_stop(
gm: GraphModule, settings: object = None
) -> GraphModule:
"""Normalize negative symbolic slice bounds to positive dim-relative bounds.

Python slicing accepts negative bounds such as x[-n:] or x[:-n]. TensorRT
shape expressions need the equivalent positive bound, dim_size - n.
"""
modified = False

for node in list(gm.graph.nodes):
if node.op != "call_function" or node.target != torch.ops.aten.slice.Tensor:
continue

args = list(node.args)
if len(args) < 3:
continue

input_node, dim = args[:2]
if not isinstance(input_node, Node) or not isinstance(dim, int):
continue

rank = _rank(input_node)
if rank is not None:
# Match PyTorch dim normalization for negative dims.
dim = dim % rank

rewritten = False
# aten.slice.Tensor can appear as (input, dim, start) or
# (input, dim, start, stop, ...). Normalize either symbolic bound.
for bound_index in (2, 3):
if len(args) <= bound_index:
continue

bound = args[bound_index]
positive_offset = _negative_symint_operand(bound)
if positive_offset is None:
continue

with SubgraphBuilder(gm.graph, node.prev) as b:
dim_size = b(torch.ops.aten.sym_size.int, input_node, dim)
# A negative symbolic bound -n becomes dim_size - n.
normalized_bound = b(operator.sub, dim_size, positive_offset)

args[bound_index] = normalized_bound
rewritten = True

if rewritten:
args[1] = dim
node.args = tuple(args)
modified = True

return clean_up_graph_after_modifications(gm) if modified else gm
99 changes: 99 additions & 0 deletions tests/py/dynamo/lowering/test_aten_lowering_passes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import operator
import sys

import torch
import torch_tensorrt
from torch.testing._internal.common_utils import TestCase, run_tests
Expand Down Expand Up @@ -278,6 +281,102 @@ def forward(self, x: torch.Tensor):
self.assertTrue(True)


class TestNormalizeNegativeSliceStop(TestCase):
def test_normalizes_negative_symbolic_start_bound(self):
from torch_tensorrt.dynamo.lowering.passes.normalize_negative_slice_stop import (
normalize_negative_slice_stop,
)

graph = torch.fx.Graph()
x = graph.placeholder("x")
x.meta["val"] = torch.empty(2, 5, 3)
n = graph.placeholder("n")
neg = graph.call_function(operator.neg, args=(n,))
sliced = graph.call_function(torch.ops.aten.slice.Tensor, args=(x, -2, neg))
graph.output(sliced)

gm = torch.fx.GraphModule({}, graph)
gm = normalize_negative_slice_stop(gm)

slice_node = next(
node
for node in gm.graph.nodes
if node.op == "call_function" and node.target == torch.ops.aten.slice.Tensor
)
self.assertEqual(slice_node.args[1], 1)

normalized_start = slice_node.args[2]
self.assertEqual(normalized_start.op, "call_function")
self.assertEqual(normalized_start.target, operator.sub)

dim_size, offset = normalized_start.args
self.assertEqual(dim_size.target, torch.ops.aten.sym_size.int)
self.assertEqual(dim_size.args[0], x)
self.assertEqual(dim_size.args[1], 1)
self.assertEqual(offset, n)

def test_normalizes_negative_symbolic_stop_bound(self):
from torch_tensorrt.dynamo.lowering.passes.normalize_negative_slice_stop import (
normalize_negative_slice_stop,
)

graph = torch.fx.Graph()
x = graph.placeholder("x")
x.meta["val"] = torch.empty(2, 5, 3)
n = graph.placeholder("n")
neg = graph.call_function(torch.ops.aten.neg.default, args=(n,))
sliced = graph.call_function(torch.ops.aten.slice.Tensor, args=(x, 1, 0, neg))
graph.output(sliced)

gm = torch.fx.GraphModule({}, graph)
gm = normalize_negative_slice_stop(gm)

slice_node = next(
node
for node in gm.graph.nodes
if node.op == "call_function" and node.target == torch.ops.aten.slice.Tensor
)

normalized_stop = slice_node.args[3]
self.assertEqual(normalized_stop.op, "call_function")
self.assertEqual(normalized_stop.target, operator.sub)

dim_size, offset = normalized_stop.args
self.assertEqual(dim_size.target, torch.ops.aten.sym_size.int)
self.assertEqual(dim_size.args[0], x)
self.assertEqual(dim_size.args[1], 1)
self.assertEqual(offset, n)


class TestEliminateSymMinInt64Max(TestCase):
def test_eliminates_noop_sym_min_int64_max(self):
if not hasattr(torch, "sym_min"):
self.skipTest("torch.sym_min is not available")

from torch_tensorrt.dynamo.lowering.passes.eliminate_sym_min_int64_max import (
eliminate_sym_min_int64_max,
)

graph = torch.fx.Graph()
x = graph.placeholder("x")
rhs_int64_max = graph.call_function(torch.sym_min, args=(x, sys.maxsize))
lhs_int64_max = graph.call_function(torch.sym_min, args=(2**63 - 1, x))
graph.output((rhs_int64_max, lhs_int64_max))

gm = torch.fx.GraphModule({}, graph)
gm = eliminate_sym_min_int64_max(gm)

self.assertFalse(
any(
node.op == "call_function" and node.target is torch.sym_min
for node in gm.graph.nodes
)
)

output_node = next(node for node in gm.graph.nodes if node.op == "output")
self.assertEqual(output_node.args[0], (x, x))


class TestRewriteEfficientAttention(TestCase):
def test_force_causal_efficient_attention(self):
class RewriteEfficientAttention(torch.nn.Module):
Expand Down
Loading