Skip to content

Commit debf018

Browse files
authored
Fix unnecessary calls to compute shape of trace tensor; fix #246 (#280)
This MR resolves the following:

1. In a trivial example like the one below, we were previously computing the shape of a dynamic-shape trace tensor, which can become very expensive if the reshape occurs somewhere in the middle of a big compute graph. The fix is a simple hack to force the shape of the trace tensor when it is statically known (in this case, when we convert a shape scalar to a 1D tensor).

```py
a = tp.ones((2, 3, 4))
s1, s2, s3 = a.shape
out = tp.reshape(a, (s1, s2, s3 / 2, 2))
```

2. Fixes #246 (thanks to @yizhuoz004 for the suggested fix).
1 parent 9c0055f commit debf018

File tree

8 files changed

+28
-53
lines changed

8 files changed

+28
-53
lines changed

tripy/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ test = [
5757
"pytest-xdist==3.6.1",
5858
"pytest-benchmark==4.0.0",
5959
"pytest-lazy-fixture==0.6.3",
60+
"pytest-mock==3.14.0",
61+
"path.py==12.5.0",
6062
# Triton is required for torch.compile
6163
"triton==3.0.0",
6264
"snakeviz==2.2.0",

tripy/tests/frontend/trace/ops/test_unsqueeze.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@
1616
#
1717

1818
import tripy as tp
19-
from tripy.frontend.trace.ops import Unsqueeze
19+
from tripy.frontend.trace.ops import Reshape
2020

2121

2222
class TestUnsqueeze:
2323
def test_func_op(self):
2424
a = tp.ones((2, 1))
2525
a = tp.unsqueeze(a, 0)
2626
assert isinstance(a, tp.Tensor)
27-
assert isinstance(a.trace_tensor.producer, Unsqueeze)
27+
assert isinstance(a.trace_tensor.producer, Reshape)
2828

2929
def test_infer_rank(self):
3030
a = tp.ones((2, 1))

tripy/tests/integration/test_reshape.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,13 @@ def test_flatten_with_unknown_dims(self):
9797
a = tp.ones((2, 3, 4, 5))
9898
b = tp.flatten(a, start_dim=1, end_dim=-1)
9999
assert np.array_equal(cp.from_dlpack(b).get(), np.ones((2, 60), dtype=np.float32))
100+
101+
102+
def test_reshape_with_no_trace_shape_inference(mocker):
103+
# This test ensures that Tripy does not unnecessarily compute shape of a trace tensor.
104+
mock_function = mocker.patch("tripy.backend.mlir.utils.ShapeContext.get_shape_of_dynamic_trace_tensor")
105+
a = tp.ones((2, 3, 4))
106+
s1, s2, s3 = a.shape
107+
out = tp.reshape(a, (s1, s2, s3 / 2, 2))
108+
out.eval()
109+
mock_function.assert_not_called()

tripy/tests/integration/test_unsqueeze.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,9 @@ def func(a):
3535
assert tp.allclose(out, tp.Tensor(ref_out))
3636

3737
assert out.shape == ref_out.shape
38+
39+
def test_unsqueeze_compile(self):
40+
def func(a):
41+
return tp.unsqueeze(a, 3) == tp.Tensor(3, dtype=tp.float32)
42+
43+
c = tp.compile(func, args=[tp.InputInfo(((1, 2, 3), 2, 3), dtype=tp.float32)])

tripy/tripy/backend/mlir/utils.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -380,22 +380,14 @@ def traverse_backwards(tensor, visited_tensors, visited_producers):
380380
"""
381381
Recurse back from tensor to the inputs to the graph and store the visited tensors and nodes.
382382
"""
383-
from tripy.frontend.trace.ops.unsqueeze import Unsqueeze
384-
385383
if id(tensor) in visited_tensors:
386384
return
387385

388386
visited_tensors[id(tensor)] = tensor
389387
if tensor.producer is not None:
390388
visited_producers[id(tensor.producer)] = tensor.producer
391-
# Special recursion conditions op by op basis.
392-
# Only recurse inputs which are used in output shape calculations.
393-
if isinstance(tensor.producer, Unsqueeze):
394-
traverse_backwards(tensor.producer.inputs[1], visited_tensors, visited_producers)
395-
else:
396-
# Naively recurse all the inputs until a constant or user input.
397-
for input_tensor in tensor.producer.inputs:
398-
traverse_backwards(input_tensor, visited_tensors, visited_producers)
389+
for input_tensor in tensor.producer.inputs:
390+
traverse_backwards(input_tensor, visited_tensors, visited_producers)
399391

400392
def find_inputs(graph_nodes):
401393
"""

tripy/tripy/frontend/trace/ops/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,4 @@
3939
from tripy.frontend.trace.ops.split import Split
4040
from tripy.frontend.trace.ops.storage import Storage
4141
from tripy.frontend.trace.ops.unary_elementwise import UnaryElementwise
42-
from tripy.frontend.trace.ops.unsqueeze import Unsqueeze
4342
from tripy.frontend.trace.ops.where import Where

tripy/tripy/frontend/trace/ops/unsqueeze.py

Lines changed: 2 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -15,46 +15,7 @@
1515
# limitations under the License.
1616
#
1717

18-
from dataclasses import dataclass
19-
2018
from tripy import export, constraints
21-
from tripy.frontend.trace.ops.base import BaseTraceOp
22-
import tripy.frontend.trace.ops.utils as op_utils
23-
from tripy.common.datatype import DATA_TYPES
24-
25-
26-
@dataclass(repr=False)
27-
class Unsqueeze(BaseTraceOp):
28-
29-
dim: int
30-
31-
# the result will not be rank 1 and so can't be a shape but we may want to unsqueeze shapes
32-
infer_shape_output_idxs = op_utils.ShapeOutputIdxPolicies.never_return_shape
33-
34-
def infer_dtypes(self):
35-
self.outputs[0].dtype = self.inputs[0].dtype
36-
37-
def infer_rank(self):
38-
self.outputs[0].rank = self.inputs[0].rank + 1
39-
40-
def to_flat_ir(self, inputs, outputs):
41-
from tripy.flat_ir.ops import DynamicBroadcastOp
42-
43-
broadcast_dim = list(range(inputs[0].rank))
44-
for idx in range(len(broadcast_dim)):
45-
if idx >= self.dim:
46-
broadcast_dim[idx] += 1
47-
48-
DynamicBroadcastOp.build(
49-
[inputs[0], inputs[1]],
50-
[outputs[0]],
51-
broadcast_dim=broadcast_dim,
52-
)
53-
54-
55-
# Two operand unsqueeze op to ensure that Trace op is 1:1 with Python code (for error messaging).
56-
def unsqueeze_two_operand(input, result_shape, dim):
57-
return Unsqueeze.build([input, result_shape], dim)
5819

5920

6021
@export.public_api(document_under="operations/functions")
@@ -85,6 +46,7 @@ def unsqueeze(input: "tripy.Tensor", dim: int) -> "tripy.Tensor":
8546
assert np.array_equal(cp.from_dlpack(output).get(), np.expand_dims(cp.from_dlpack(input).get(), 1))
8647
"""
8748
from tripy.frontend.trace.ops.concatenate import concatenate
49+
from tripy.frontend.trace.ops.reshape import reshape
8850

8951
from tripy.frontend import Shape
9052

@@ -97,4 +59,4 @@ def unsqueeze(input: "tripy.Tensor", dim: int) -> "tripy.Tensor":
9759
else:
9860
input_shape = input.shape
9961
result_shape = concatenate([input_shape[:dim], Shape([1]), input_shape[dim:]], dim=0)
100-
return unsqueeze_two_operand(input, result_shape, dim)
62+
return reshape(input, result_shape)

tripy/tripy/frontend/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,9 @@ def convert_nontensor_arg(arg, list_index=None):
447447
if member.rank != 0:
448448
raise_error("Tensor in a shape argument must be a scalar.", [f"Got {member}"])
449449
member = Shape(unsqueeze(member, 0))
450+
# Force the trace tensor shape to be (1,) since it's known that we are reshaping a scalar to a 1D tensor.
451+
# If we don't force the shape below, Tripy might require computing the shape of this trace tensor which can be expensive.
452+
member.trace_tensor.shape = (1,)
450453
shape_components.append(member)
451454
if len(acc) > 0:
452455
shape_components.append(convert_nontensor_arg(acc))
@@ -639,6 +642,7 @@ def pretty_print(data_list, shape, threshold=1000, linewidth=10, edgeitems=3):
639642
"""
640643
Returns a pretty-print string of list format data.
641644
"""
645+
642646
def _data_str(data, summarize, linewidth, edgeitems, indent=0):
643647
if isinstance(data, (float, int)):
644648
return str(data)

0 commit comments

Comments
 (0)