Skip to content

Commit f72a7af

Browse files
authored
[Tripy][Bugfix] Use correct types in __str__ method for ShapeScalar (#212)
Noticed a small error: The `__str__` method for `ShapeScalar` assumed that evaluating the scalar would give a list output, resulting in a type error, since the actual result is a scalar. This PR fixes that and adds a check to the unit tests.
1 parent d3e9d2c commit f72a7af

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

tripy/tests/frontend/test_shape.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,26 @@ def other_values(request):
4040

4141

4242
class TestShapeScalar:
43-
@pytest.mark.parametrize("value", [1, tp.Tensor(1), np.array(2)])
43+
@pytest.mark.parametrize(
44+
"value",
45+
[
46+
1,
47+
tp.Tensor(1),
48+
# Note: if we don't specify the dtype, the tensor constructor will insert a cast
49+
# and the assert below about the trace_tensor's producer will fail.
50+
np.array(2, dtype=np.int32),
51+
],
52+
)
4453
def test_scalar_shape(self, value):
45-
s = tp.ShapeScalar(values)
54+
s = tp.ShapeScalar(value)
4655

4756
assert isinstance(s, tp.ShapeScalar)
4857
assert s.trace_tensor.producer.inputs == []
4958

59+
def test_scalar_shape_str_method(self):
60+
s = tp.ShapeScalar(12)
61+
assert s.__str__() == f"shape_scalar(12)"
62+
5063
def test_scalar_slice(self):
5164
a = tp.iota((3, 3))
5265
assert isinstance(a.shape[0], tp.ShapeScalar)

tripy/tripy/frontend/shape.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@ def __repr__(self) -> str:
7777
return "shape_scalar" + tensor_repr[6:]
7878

7979
def __str__(self) -> str:
80-
return "shape_scalar" + "(" + ", ".join(map(str, self.tolist())) + ")"
80+
val = self.tolist()
81+
assert isinstance(val, int)
82+
return f"shape_scalar({val})"
8183

8284

8385
@export.public_api()

0 commit comments

Comments
 (0)