Skip to content

Commit 324dc49

Browse files
Update tripy/tests/integration/test_equal.py
Co-authored-by: pranavm-nvidia <[email protected]> Signed-off-by: Jhalak Patel <[email protected]>
1 parent 48ede16 commit 324dc49

File tree

1 file changed

+12
-33
lines changed

1 file changed

+12
-33
lines changed

tripy/tests/integration/test_equal.py

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -25,44 +25,23 @@
2525

2626
class TestEqual:
2727
@pytest.mark.parametrize(
28-
"tensor_a, tensor_b, dtype",
28+
"tensor_a, tensor_b, dtype, expected",
2929
[
30-
([1, 2, 3], [1, 2, 3], tp.int64),
31-
([1, 2, 3], [1, 2, 4], tp.int32),
32-
([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], tp.float16),
33-
([1.0, 2.0, 3.0], [1.0, 2.0, 1.0], tp.bfloat16),
34-
([1.0, 2.0, 3.0], [1.0, 2.0, 3.00001], tp.float32),
35-
([True, False, True], [True, False, True], tp.bool),
30+
([1, 2, 3], [1, 2, 3], tp.int64, True),
31+
([1, 2, 3], [1, 2, 4], tp.int32, False),
32+
([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], tp.float16, True),
33+
([1.0, 2.0, 3.0], [1.0, 2.0, 1.0], tp.bfloat16, False),
34+
([1.0, 2.0, 3.0], [1.0, 2.0, 3.00001], tp.float32, False),
35+
([True, False, True], [True, False, True], tp.bool, True),
3636
],
3737
)
38-
def test_equal(self, tensor_a, tensor_b, dtype):
39-
# Convert to torch tensors for comparison
40-
torch_a = torch.tensor(tensor_a, dtype=self.torch_dtype(dtype))
41-
torch_b = torch.tensor(tensor_b, dtype=self.torch_dtype(dtype))
38+
def test_equal(self, tensor_a, tensor_b, dtype, expected):
39+
a = tp.Tensor(tensor_a, dtype=dtype)
40+
b = tp.Tensor(tensor_b, dtype=dtype)
4241

43-
# Convert to tripy tensors
44-
tp_a = tp.Tensor(tensor_a, dtype=dtype)
45-
tp_b = tp.Tensor(tensor_b, dtype=dtype)
42+
out = tp.equal(a, b)
4643

47-
# Compare results
48-
torch_result = torch.equal(torch_a, torch_b)
49-
tp_result = tp.equal(tp_a, tp_b)
50-
51-
assert torch_result == tp_result
52-
53-
@staticmethod
54-
def torch_dtype(tp_dtype):
55-
# Map tripy dtypes to torch dtypes
56-
dtype_map = {
57-
tp.float32: torch.float32,
58-
tp.float16: torch.float16,
59-
tp.bfloat16: torch.bfloat16,
60-
tp.int32: torch.int32,
61-
tp.int64: torch.int64,
62-
tp.int8: torch.int8,
63-
tp.bool: torch.bool,
64-
}
65-
return dtype_map[tp_dtype]
44+
assert out == expected
6645

6746
def test_equal_different_shapes(self):
6847
a = tp.Tensor([1, 2, 3])

0 commit comments

Comments
 (0)