Skip to content

Commit 97355ee

Browse files
authored
Allow repeat to accept shapescalar along with int for repeats argument (#230)
1 parent 956ca4c commit 97355ee

File tree

5 files changed

+46
-7
lines changed

5 files changed

+46
-7
lines changed

tripy/tests/frontend/test_shape.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ def test_shape(self, values):
8282

8383
assert isinstance(s, tp.Shape)
8484
assert len(s) == len(values)
85-
assert s.trace_tensor.producer.inputs == []
8685
assert cp.from_dlpack(s).get().tolist() == values
8786

8887
def test_empty_shape(self):

tripy/tests/integration/test_repeat.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,11 @@ def test_repeat(self, repeats, dim):
3737
expected = np.repeat(inp, repeats, dim)
3838

3939
assert np.array_equal(np.from_dlpack(tp.copy(out, device=tp.device("cpu"))), expected)
40+
41+
def test_repeat_shape_scalar(self):
42+
inp = np.arange(4, dtype=np.int32).reshape((2, 2))
43+
s = tp.ones((1, 2))
44+
out = tp.repeat(tp.Tensor(inp), s.shape[1], 0)
45+
expected = np.repeat(inp, 2, 0)
46+
47+
assert np.array_equal(np.from_dlpack(tp.copy(out, device=tp.device("cpu"))), expected)

tripy/tripy/backend/mlir/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,14 @@ def get_mlir_scalar_attr(mlir_dtype, value):
9393

9494

9595
def list_to_dense_attr(data: List, mlir_dtype):
96+
from tripy.frontend.shape import ShapeScalar
97+
9698
if isinstance(data, numbers.Number):
9799
return [get_mlir_scalar_attr(mlir_dtype, data)]
100+
101+
if isinstance(data, ShapeScalar):
102+
return [get_mlir_scalar_attr(mlir_dtype, data.tolist())]
103+
98104
attrs = []
99105
for element in data:
100106
attrs.extend(list_to_dense_attr(element, mlir_dtype))

tripy/tripy/frontend/ops/repeat.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16+
from typing import Union
1617
from tripy import constraints, export
1718
from tripy.common.exception import raise_error
1819
from tripy.frontend import utils as frontend_utils
@@ -26,7 +27,7 @@
2627
dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"},
2728
)
2829
@frontend_utils.process_dim
29-
def repeat(input: "tripy.Tensor", repeats: int, dim: int) -> "tripy.Tensor":
30+
def repeat(input: "tripy.Tensor", repeats: Union[int, "tripy.ShapeScalar"], dim: int) -> "tripy.Tensor":
3031
"""
3132
Repeats each element of a tensor after itself along the specified dimension.
3233
@@ -68,9 +69,14 @@ def repeat(input: "tripy.Tensor", repeats: int, dim: int) -> "tripy.Tensor":
6869
from tripy.frontend.trace.ops.expand import expand
6970
from tripy.frontend.trace.ops.reshape import reshape
7071
from tripy.frontend.trace.ops.unsqueeze import unsqueeze
72+
from tripy.frontend.tensor import Tensor
73+
from tripy.frontend.shape import ShapeScalar, Shape
74+
from tripy.frontend.trace.ops.concatenate import concatenate
7175

72-
if repeats < 0:
73-
raise_error("`repeats` value must be non-negative.", [f"Got: repeats={repeats}."])
76+
if isinstance(repeats, int):
77+
if repeats < 0:
78+
raise_error("`repeats` value must be non-negative.", [f"Got: repeats={repeats}."])
79+
repeats = ShapeScalar(repeats)
7480

7581
# By constraining repeats to be a single integer, we can use a very
7682
# simple implementation for repeat.
@@ -84,10 +90,11 @@ def repeat(input: "tripy.Tensor", repeats: int, dim: int) -> "tripy.Tensor":
8490
# [2],] [2, 2],]
8591
#
8692
out = unsqueeze(input, dim + 1)
87-
out = expand(out, input.shape[: dim + 1] + [repeats] + input.shape[dim + 1 :])
93+
out = expand(out, input.shape[: dim + 1] + Shape([repeats]) + input.shape[dim + 1 :])
8894

89-
repeat_mask = [1] * input.rank
90-
repeat_mask[dim] = repeats
95+
repeat_mask = concatenate(
96+
[reshape(repeats, (1,)) if idx == dim else Tensor([1]) for idx in range(input.rank)], dim=0
97+
)
9198
new_shape = input.shape.multiply(repeat_mask)
9299
out = reshape(out, new_shape)
93100
return out

tripy/tripy/frontend/shape.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,25 @@ def __init__(
120120
details=[data],
121121
)
122122

123+
# the shape of data should correspond to the given rank
124+
super().__init__(data=None, dtype=int32, name=name, device=data.device)
125+
# share the underlying data
126+
self.trace_tensor = data.trace_tensor
127+
self.stack_info = data.stack_info
128+
elif (
129+
isinstance(data, Sequence)
130+
and len(data) > 0
131+
and all(map(lambda e: isinstance(e, int) or isinstance(e, ShapeScalar), data))
132+
):
133+
# Handle the case where data is a list of mixed int and ShapeScalar elements
134+
# Example: [1, a.shape[0]]
135+
# We convert this to a tensor to avoid expensive evaluation of ShapeScalar elements (like a.shape[0])
136+
from tripy.frontend.trace.ops.concatenate import concatenate
137+
from tripy.frontend.trace.ops.reshape import reshape
138+
139+
data = concatenate(
140+
[reshape(e, (1,)) if isinstance(e, ShapeScalar) else Tensor([e], dtype=int32) for e in data], dim=0
141+
)
123142
# the shape of data should correspond to the given rank
124143
super().__init__(data=None, dtype=int32, name=name, device=data.device)
125144
# share the underlying data

0 commit comments

Comments
 (0)