|
25 | 25 |
|
26 | 26 | class TestEqual: |
27 | 27 | @pytest.mark.parametrize( |
28 | | - "tensor_a, tensor_b, dtype", |
| 28 | + "tensor_a, tensor_b, dtype, expected", |
29 | 29 | [ |
30 | | - ([1, 2, 3], [1, 2, 3], tp.int64), |
31 | | - ([1, 2, 3], [1, 2, 4], tp.int32), |
32 | | - ([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], tp.float16), |
33 | | - ([1.0, 2.0, 3.0], [1.0, 2.0, 1.0], tp.bfloat16), |
34 | | - ([1.0, 2.0, 3.0], [1.0, 2.0, 3.00001], tp.float32), |
35 | | - ([True, False, True], [True, False, True], tp.bool), |
| 30 | + ([1, 2, 3], [1, 2, 3], tp.int64, True), |
| 31 | + ([1, 2, 3], [1, 2, 4], tp.int32, False), |
| 32 | + ([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], tp.float16, True), |
| 33 | + ([1.0, 2.0, 3.0], [1.0, 2.0, 1.0], tp.bfloat16, False), |
| 34 | + ([1.0, 2.0, 3.0], [1.0, 2.0, 3.00001], tp.float32, False), |
| 35 | + ([True, False, True], [True, False, True], tp.bool, True), |
36 | 36 | ], |
37 | 37 | ) |
38 | | - def test_equal(self, tensor_a, tensor_b, dtype): |
39 | | - # Convert to torch tensors for comparison |
40 | | - torch_a = torch.tensor(tensor_a, dtype=self.torch_dtype(dtype)) |
41 | | - torch_b = torch.tensor(tensor_b, dtype=self.torch_dtype(dtype)) |
| 38 | + def test_equal(self, tensor_a, tensor_b, dtype, expected): |
| 39 | + a = tp.Tensor(tensor_a, dtype=dtype) |
| 40 | + b = tp.Tensor(tensor_b, dtype=dtype) |
42 | 41 |
|
43 | | - # Convert to tripy tensors |
44 | | - tp_a = tp.Tensor(tensor_a, dtype=dtype) |
45 | | - tp_b = tp.Tensor(tensor_b, dtype=dtype) |
| 42 | + out = tp.equal(a, b) |
46 | 43 |
|
47 | | - # Compare results |
48 | | - torch_result = torch.equal(torch_a, torch_b) |
49 | | - tp_result = tp.equal(tp_a, tp_b) |
50 | | - |
51 | | - assert torch_result == tp_result |
52 | | - |
53 | | - @staticmethod |
54 | | - def torch_dtype(tp_dtype): |
55 | | - # Map tripy dtypes to torch dtypes |
56 | | - dtype_map = { |
57 | | - tp.float32: torch.float32, |
58 | | - tp.float16: torch.float16, |
59 | | - tp.bfloat16: torch.bfloat16, |
60 | | - tp.int32: torch.int32, |
61 | | - tp.int64: torch.int64, |
62 | | - tp.int8: torch.int8, |
63 | | - tp.bool: torch.bool, |
64 | | - } |
65 | | - return dtype_map[tp_dtype] |
| 44 | + assert out == expected |
66 | 45 |
|
67 | 46 | def test_equal_different_shapes(self): |
68 | 47 | a = tp.Tensor([1, 2, 3]) |
|
0 commit comments