-
Notifications
You must be signed in to change notification settings - Fork 403
fix 4326: bound exported dynamic shapes and normalize symbolic slice … #4341
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
446b9f3
8731394
3f68272
d0e1bc0
68ac7c3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| import sys | ||
|
|
||
| import torch | ||
| from torch.fx import GraphModule, Node | ||
|
|
||
| from .pass_utils import clean_up_graph_after_modifications | ||
|
|
||
|
|
||
| _INT64_MAX = 2**63 - 1 | ||
| _SYM_MIN = getattr(torch, "sym_min", None) | ||
|
|
||
|
|
||
| def _is_int64_max(x: object) -> bool: | ||
| return isinstance(x, int) and x in (sys.maxsize, _INT64_MAX) | ||
|
|
||
|
|
||
| def eliminate_sym_min_int64_max( | ||
| gm: GraphModule, settings: object = None | ||
| ) -> GraphModule: | ||
| """Remove no-op sym_min nodes where one operand is INT64_MAX. | ||
|
|
||
| torch.export may emit sym_min(sym, INT64_MAX) for an effectively unbounded | ||
| symbolic value. That expression is equivalent to sym, and leaving it in the | ||
| graph can produce runtime calls to torch.sym_min with Tensor inputs. | ||
| """ | ||
| if _SYM_MIN is None: | ||
| return gm | ||
|
|
||
| modified = False | ||
| for node in list(gm.graph.nodes): | ||
| if ( | ||
| node.op != "call_function" | ||
| or node.target is not _SYM_MIN | ||
| or len(node.args) < 2 | ||
| ): | ||
| continue | ||
|
|
||
| lhs, rhs = node.args[:2] | ||
| if _is_int64_max(rhs) and isinstance(lhs, Node): | ||
| passthrough = lhs | ||
| elif _is_int64_max(lhs) and isinstance(rhs, Node): | ||
| passthrough = rhs | ||
| else: | ||
| continue | ||
|
|
||
| node.replace_all_uses_with(passthrough) | ||
| gm.graph.erase_node(node) | ||
| modified = True | ||
|
|
||
| return clean_up_graph_after_modifications(gm) if modified else gm |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| import operator | ||
| from typing import Optional | ||
|
|
||
| import torch | ||
| from torch.fx import GraphModule, Node | ||
| from torch_tensorrt.dynamo.lowering._SubgraphBuilder import SubgraphBuilder | ||
|
|
||
| from .pass_utils import clean_up_graph_after_modifications | ||
|
|
||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cant we handle this case in the converter itself?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a reason we would prefer that?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think generally simpler converters and more in lowering is better, but I think the line needs to be clear, either we normalize dimensions in the graph or we do it in the converter
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. going through this case ya lowering seems better, since the negative seems to be ITensor, so we would need extra Iselect layers, making the converter more complicated. Though we have done this before in converters, but if we want simpler converters now, doing in lowering makes more sense |
||
| def _negative_symint_operand(x: object) -> Optional[object]: | ||
| # Return n for symbolic bounds represented as -n. The caller rewrites | ||
| # that bound to dim_size - n, matching Python's negative indexing rules. | ||
| if ( | ||
| isinstance(x, Node) | ||
| and x.op == "call_function" | ||
| and x.target in (operator.neg, torch.ops.aten.neg.default) | ||
| and len(x.args) == 1 | ||
| ): | ||
| return x.args[0] | ||
| return None | ||
|
|
||
|
|
||
| def _rank(x: Node) -> Optional[int]: | ||
| val = x.meta.get("val") | ||
| if isinstance(val, torch.Tensor): | ||
| return val.dim() | ||
| if hasattr(val, "shape"): | ||
| return len(val.shape) | ||
| return None | ||
|
|
||
|
|
||
| def normalize_negative_slice_stop( | ||
| gm: GraphModule, settings: object = None | ||
| ) -> GraphModule: | ||
| """Normalize negative symbolic slice bounds to positive dim-relative bounds. | ||
|
|
||
| Python slicing accepts negative bounds such as x[-n:] or x[:-n]. TensorRT | ||
| shape expressions need the equivalent positive bound, dim_size - n. | ||
| """ | ||
| modified = False | ||
|
|
||
| for node in list(gm.graph.nodes): | ||
| if node.op != "call_function" or node.target != torch.ops.aten.slice.Tensor: | ||
| continue | ||
|
|
||
| args = list(node.args) | ||
| if len(args) < 3: | ||
| continue | ||
|
|
||
| input_node, dim = args[:2] | ||
| if not isinstance(input_node, Node) or not isinstance(dim, int): | ||
| continue | ||
|
|
||
| rank = _rank(input_node) | ||
| if rank is not None: | ||
| # Match PyTorch dim normalization for negative dims. | ||
| dim = dim % rank | ||
|
|
||
| rewritten = False | ||
| # aten.slice.Tensor can appear as (input, dim, start) or | ||
| # (input, dim, start, stop, ...). Normalize either symbolic bound. | ||
| for bound_index in (2, 3): | ||
| if len(args) <= bound_index: | ||
| continue | ||
|
|
||
| bound = args[bound_index] | ||
| positive_offset = _negative_symint_operand(bound) | ||
| if positive_offset is None: | ||
| continue | ||
|
|
||
| with SubgraphBuilder(gm.graph, node.prev) as b: | ||
| dim_size = b(torch.ops.aten.sym_size.int, input_node, dim) | ||
| # A negative symbolic bound -n becomes dim_size - n. | ||
| normalized_bound = b(operator.sub, dim_size, positive_offset) | ||
|
|
||
| args[bound_index] = normalized_bound | ||
| rewritten = True | ||
|
|
||
| if rewritten: | ||
| args[1] = dim | ||
| node.args = tuple(args) | ||
| modified = True | ||
|
|
||
| return clean_up_graph_after_modifications(gm) if modified else gm | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we have this in both the converter and a lowering pass to remove sym_min?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lowering pass should only remove a specific instance of sym_min