|
21 | 21 | from tripy import utils |
22 | 22 | from tests import helper |
23 | 23 | from collections import defaultdict |
| 24 | +from typing import Optional, Sequence |
24 | 25 |
|
25 | 26 |
|
26 | 27 | class TestMd5: |
@@ -109,3 +110,49 @@ def test_gen_uid(self, inputs, outputs, expected_prefix): |
109 | 110 | def test_uniqueness(self): |
110 | 111 | uids = [utils.UniqueNameGen.gen_uid() for _ in range(100)] |
111 | 112 | assert len(set(uids)) == 100 |
| 113 | + |
| 114 | + |
| 115 | +class TestStride: |
| 116 | + @pytest.mark.parametrize( |
| 117 | + "shape, provided_stride, expected_stride", |
| 118 | + [ |
| 119 | + ((), None, ()), |
| 120 | + ((1,), None, (1,)), |
| 121 | + ((2, 3), None, (3, 1)), |
| 122 | + ((0, 5), None, (1, 1)), |
| 123 | + ((1, 0, 3), None, (1, 1, 1)), |
| 124 | + ((2, 1, 4), None, (4, 1, 1)), |
| 125 | + ((3, 0, 0, 5), None, (5, 1, 1, 1)), |
| 126 | + ((2, 3), (3, 1), (3, 1)), |
| 127 | + ((0, 5), (5, 1), (5, 1)), |
| 128 | + ((1, 0, 3), (3, 1, 1), (3, 1, 1)), |
| 129 | + (None, None, None), |
| 130 | + ], |
| 131 | + ) |
| 132 | + def test_get_stride( |
| 133 | + self, |
| 134 | + shape: Optional[Sequence[int]], |
| 135 | + provided_stride: Optional[Sequence[int]], |
| 136 | + expected_stride: Optional[Sequence[int]], |
| 137 | + ): |
| 138 | + """Test both get_stride and get_canonical_stride functions.""" |
| 139 | + assert utils.get_stride(shape, provided_stride) == expected_stride |
| 140 | + if provided_stride is None and shape is not None: |
| 141 | + assert utils.get_canonical_stride(shape) == expected_stride |
| 142 | + |
| 143 | + @pytest.mark.parametrize( |
| 144 | + "shape, stride, expected_stride, expected_result", |
| 145 | + [ |
| 146 | + ((2, 3), (3, 1), (3, 1), True), |
| 147 | + ((2, 3), (6, 2), (3, 1), False), |
| 148 | + ((0, 5), (1, 1), (5, 1), True), |
| 149 | + ((1, 0, 3), (0, 3, 1), (3, 3, 1), True), |
| 150 | + ((2, 1, 4), (8, 4, 1), (4, 1, 1), False), |
| 151 | + ((2, 3), (3, 1, 2), (3, 1), False), # Mismatched lengths |
| 152 | + ], |
| 153 | + ) |
| 154 | + def test_are_strides_equivalent( |
| 155 | + self, shape: Sequence[int], stride: Sequence[int], expected_stride: Sequence[int], expected_result: bool |
| 156 | + ): |
| 157 | + """Test are_strides_equivalent function.""" |
| 158 | + assert utils.are_strides_equivalent(shape, stride, expected_stride) == expected_result |
0 commit comments