Skip to content

Commit fec98bd

Browse files
authored
[Tripy] Return the shape immediately if it is statically known instead of producing a trace operator. (#379)
Addresses issue #360.
1 parent e086494 commit fec98bd

File tree

7 files changed

+77
-18
lines changed

7 files changed

+77
-18
lines changed

tripy/tests/backend/api/test_compile.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,19 @@ def test_dynamic_shapes(self):
138138
out = compiled_add(tp.ones((3, 1), dtype=tp.float32), tp.ones((3, 1), dtype=tp.float32))
139139
assert cp.array_equal(cp.from_dlpack(out), cp.ones((3, 1), dtype=cp.float32) * 2)
140140

141+
# if we specify dynamic shapes in compilation, they should not be fixed afterwards
142+
def test_dynamic_shapes_not_fixed(self):
    """
    Verifies that a function compiled with dynamic input shapes keeps deriving
    its output shape from the runtime input rather than a value fixed at
    compile time.
    """
    # Every dimension is declared with the triple (1, 2, 5).
    # NOTE(review): presumably (min, opt, max) bounds - confirm against tp.InputInfo.
    dim_bounds = (1, 2, 5)

    def func(inp):
        # The output length is the sum of the three input dimensions.
        s = inp.shape[0] + inp.shape[1] + inp.shape[2]
        return tp.ones([s], dtype=tp.float32)

    input_info = tp.InputInfo((dim_bounds, dim_bounds, dim_bounds), dtype=tp.float32)
    compiled_ones = tp.compile(func, args=[input_info])

    runtime_shapes = ((1, 1, 1), (3, 3, 3), (2, 4, 5), (5, 2, 1))
    for runtime_shape in runtime_shapes:
        result = compiled_ones(tp.ones(runtime_shape, dtype=tp.float32))
        assert result.shape == [sum(runtime_shape)]
153+
141154
def test_error_if_evaling_input_during_compile(self):
142155
def func(a):
143156
print(a)

tripy/tests/frontend/trace/ops/test_binary_elementwise.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,13 @@ def test_invalid_broadcast_fails(self):
127127
c.eval()
128128

129129
def test_dimension_size_inputs(self):
130-
a = tp.Tensor([1, 2])
130+
d = tp.DimensionSize(1)
131131

132132
# Operations on only DimensionSizes will yield a DimensionSize
133-
out = a.shape[0] + a.shape[0]
133+
out = d + d
134134
assert isinstance(out, tp.DimensionSize)
135135

136136
# Otherwise, a Tensor is yielded.
137-
out = a + a.shape[0]
137+
a = tp.Tensor([1, 2])
138+
out = a + d
138139
assert isinstance(out, tp.Tensor) and not isinstance(out, tp.DimensionSize)

tripy/tests/performance/test_perf.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
import time
1616
from textwrap import dedent
17+
from typing import Callable
1718

1819
import pytest
1920
import torch
@@ -25,6 +26,22 @@
2526
import tripy as tp
2627

2728

29+
def run_timed_trials(thunk: Callable[[], None], warm_up_runs: int = 10, iterations: int = 1000) -> float:
    """
    Returns the average time measured for calls to the thunk (the function intended to be timed)
    in microseconds. First performs the specified number of untimed warm-ups.

    Args:
        thunk: Zero-argument callable to time; its return value is ignored.
        warm_up_runs: Number of untimed calls made before the measurement starts.
        iterations: Number of timed calls to average over. Must be at least 1.

    Returns:
        The mean time per call in microseconds.

    Raises:
        ValueError: If `iterations` is less than 1, since no average can be computed.
    """
    # Fail fast with a clear message instead of dividing by zero after the warm-ups.
    if iterations < 1:
        raise ValueError(f"iterations must be at least 1, but got: {iterations}")

    for _ in range(warm_up_runs):
        thunk()

    start = time.perf_counter_ns()
    for _ in range(iterations):
        thunk()
    elapsed_ns = time.perf_counter_ns() - start

    # perf_counter_ns reports nanoseconds; dividing by 1000 yields microseconds per call.
    return elapsed_ns / (iterations * 1000.0)
43+
44+
2845
@pytest.mark.parametrize("perf_case", PERF_CASES)
2946
def test_perf_regression(perf_case, benchmark):
3047
compiled_tripy_module, _, inputs, _ = perf_case
@@ -115,15 +132,10 @@ def func({arg_str}):
115132
for input in inputs:
116133
input.eval()
117134

118-
for _ in range(warm_up_runs):
119-
compiled_one_io(*inputs)
120-
121-
start = time.perf_counter_ns()
122-
for _ in range(iterations):
123-
compiled_one_io(*inputs)
124-
end = time.perf_counter_ns()
135+
def measure_thunk():
136+
return compiled_one_io(*inputs)
125137

126-
return (end - start) / (iterations * 1000.0)
138+
return run_timed_trials(measure_thunk, warm_up_runs=warm_up_runs, iterations=iterations)
127139

128140
assert measure_overhead(1) < 60.0
129141

@@ -137,3 +149,13 @@ def func({arg_str}):
137149
# Ensure all deltas are within a few microseconds of each other
138150
average_delta = sum(deltas) / float(len(deltas))
139151
assert all(abs(delta - average_delta) < 10 for delta in deltas)
152+
153+
154+
def test_tripy_param_update(benchmark):
    """
    Benchmarks assigning a freshly constructed parameter to a module attribute
    that already holds a parameter.
    """
    module = tp.Module()
    # Seed the attribute so that every timed call overwrites an existing parameter.
    module.param = tp.Parameter([1, 2, 3, 4])

    def reassign_param():
        module.param = tp.Parameter([5, 6, 7, 8])

    benchmark(reassign_param)

tripy/tripy/frontend/trace/ops/reduce.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ def prod(
296296

297297

298298
def mean_impl(tensor: "tripy.Tensor", dim: Union[int, Sequence] = None, keepdim: bool = False, apply_to_divisor=None):
299+
from tripy.frontend.tensor import Tensor
299300
from tripy.frontend.trace.ops.cast import cast
300301

301302
sum_val = sum(tensor, dim=dim, keepdim=keepdim)
@@ -307,7 +308,12 @@ def mean_impl(tensor: "tripy.Tensor", dim: Union[int, Sequence] = None, keepdim:
307308
if apply_to_divisor:
308309
num_elements = apply_to_divisor(num_elements)
309310

310-
return sum_val / (cast(num_elements, sum_val.dtype))
311+
num_elements = (
312+
cast(num_elements, sum_val.dtype)
313+
if isinstance(num_elements, Tensor)
314+
else Tensor(num_elements, dtype=sum_val.dtype)
315+
)
316+
return sum_val / num_elements
311317

312318

313319
@export.public_api(document_under="operations/functions")

tripy/tripy/frontend/trace/ops/reshape.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,17 @@ def to_flat_ir(self, inputs, outputs):
4343

4444

4545
def infer_dimensions(input: "tripy.Tensor", shape: ShapeLike) -> ShapeLike:
46+
4647
num_unknown_dims = len([dim for dim in shape if op_utils.is_minus_one(dim)])
4748
if num_unknown_dims > 1:
4849
raise_error(f"The new shape can have at most one inferred dimension (denoted by -1)", [f"Got shape: {shape}."])
4950

5051
if num_unknown_dims == 1:
5152
input_volume = math.prod(input.shape)
5253
known_dims_volume = math.prod(dim for dim in shape if not op_utils.is_minus_one(dim))
53-
inferred_dim = input_volume / known_dims_volume
54+
inferred_dim = (
55+
input_volume // known_dims_volume
56+
) # If we have scalars, the floor div ensures the result is an int.
5457

5558
shape = [inferred_dim if op_utils.is_minus_one(dim) else dim for dim in shape]
5659

tripy/tripy/frontend/trace/ops/shape.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616
#
1717

1818
from dataclasses import dataclass
19-
from typing import List
2019

2120
from tripy import constraints
2221
from tripy.common.datatype import DATA_TYPES
2322
from tripy.frontend.ops.registry import TENSOR_METHOD_REGISTRY
2423
from tripy.frontend.trace.ops.base import BaseTraceOp
24+
from tripy.types import ShapeLike
2525

2626

2727
@dataclass(repr=False)
@@ -45,7 +45,7 @@ def to_flat_ir(self, inputs, outputs):
4545
@TENSOR_METHOD_REGISTRY("shape")
4646
@property
4747
@constraints.dtypes(constraints={"self": "T1"}, variables={"T1": list(DATA_TYPES.keys())})
48-
def shape(self: "tripy.Tensor") -> List["tripy.DimensionSize"]:
48+
def shape(self: "tripy.Tensor") -> ShapeLike:
4949
"""
5050
Represents the shape of the tensor.
5151
@@ -63,4 +63,9 @@ def shape(self: "tripy.Tensor") -> List["tripy.DimensionSize"]:
6363
assert shape == [8, 2]
6464
"""
6565

66+
# If the shape is statically known, we do not need to insert any operator calls.
67+
# However, if we are tracing, it might still be necessary to insert calls in the final program, so we will keep it.
68+
if all(dim >= 0 for dim in self.trace_tensor.shape) and not self.trace_tensor.is_compile_tracer:
69+
return self.trace_tensor.shape
70+
6671
return [GetDimensionSize.build([self], dim=index, always_cast_to_dimension_size=True) for index in range(self.rank)]

tripy/tripy/frontend/trace/ops/slice.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,9 @@ def __getitem__(
164164
assert np.array_equal(cp.from_dlpack(output).get(), np.arange(10)[8:2:-1])
165165
166166
"""
167+
from tripy.frontend.dimension_size import DimensionSize
167168
from tripy.frontend.tensor import Tensor
169+
from tripy.frontend.trace.ops.binary_elementwise import maximum, minimum
168170
from tripy.frontend.trace.ops.flip import flip
169171
from tripy.frontend.trace.ops.gather import gather
170172
from tripy.frontend.trace.ops.squeeze import squeeze
@@ -198,9 +200,16 @@ def convert_to_positive_idx(index: Union[int, Tensor]) -> Union[int, Tensor]:
198200
# because out of bounds indices for a *slice* mean that the dim should be empty, not an error
199201
def clamp_bound(bound: Union[int, Tensor]) -> Union[int, Tensor]:
200202
if isinstance(bound, int):
201-
return 0 if bound < 0 else where(bound > t_shape[i], t_shape[i], Tensor([bound]))
202-
else:
203-
return where(bound < 0, Tensor([0]), where(bound > t_shape[i], t_shape[i], bound))
203+
if bound < 0:
204+
return 0
205+
206+
if isinstance(t_shape[i], int):
207+
return min(bound, t_shape[i])
208+
return minimum(t_shape[i], Tensor([bound]))
209+
210+
# need the shape dimension to be a tensor to use as an argument to min and max
211+
shape_dim = t_shape[i] if isinstance(t_shape[i], Tensor) else DimensionSize(t_shape[i])
212+
return maximum(Tensor([0]), minimum(shape_dim, bound))
204213

205214
if isinstance(idx, int) or isinstance(idx, Tensor):
206215
args.append(convert_to_positive_idx(idx))

0 commit comments

Comments
 (0)