|
17 | 17 | import pytest |
18 | 18 | import nvtripy as tp |
19 | 19 |
|
| 20 | +test_cases = [ |
| 21 | + ((2, 3, 4), 0, -1, (24,)), # Flatten all dimensions |
| 22 | + ((2, 3, 4), 1, -1, (2, 12)), # Flatten dimensions 1 through end |
| 23 | + ((2, 3, 4), 1, 2, (2, 12)), # Flatten dimensions 1 through 2 |
| 24 | + ((2, 3, 4), 0, 1, (6, 4)), # Flatten dimensions 0 through 1 |
| 25 | + ((2, 3, 4, 5), 1, 3, (2, 60)), # Flatten dimensions 1 through 3 |
| 26 | +] |
| 27 | + |
20 | 28 |
|
21 | 29 | class TestFlatten: |
22 | 30 | @pytest.mark.parametrize( |
23 | 31 | "shape, start_dim, end_dim, expected_shape", |
24 | | - [ |
25 | | - ((2, 3, 4), 0, -1, (24,)), # Flatten all dimensions |
26 | | - ((2, 3, 4), 1, -1, (2, 12)), # Flatten dimensions 1 through end |
27 | | - ((2, 3, 4), 1, 2, (2, 12)), # Flatten dimensions 1 through 2 |
28 | | - ((2, 3, 4), 0, 1, (6, 4)), # Flatten dimensions 0 through 1 |
29 | | - ((2, 3, 4, 5), 1, 3, (2, 60)), # Flatten dimensions 1 through 3 |
30 | | - ], |
| 32 | + test_cases, |
31 | 33 | ) |
32 | 34 | def test_flatten(self, shape, start_dim, end_dim, expected_shape, eager_or_compiled): |
33 | 35 | cp_a = cp.arange(np.prod(shape)).reshape(shape).astype(np.float32) |
@@ -55,3 +57,15 @@ def test_flatten_with_unknown_dims(self, eager_or_compiled): |
55 | 57 | a = tp.ones((2, 3, 4, 5)) |
56 | 58 | b = eager_or_compiled(tp.flatten, a, start_dim=1, end_dim=-1) |
57 | 59 | assert np.array_equal(cp.from_dlpack(b).get(), np.ones((2, 60), dtype=np.float32)) |
| 60 | + |
| 61 | + @pytest.mark.parametrize( |
| 62 | + "shape, start_dim, end_dim, expected_shape", |
| 63 | + test_cases, |
| 64 | + ) |
| 65 | + def test_flatten_tensor_method(self, shape, start_dim, end_dim, expected_shape, eager_or_compiled): |
| 66 | + """Test that tensor.flatten() method works and produces same result as free function.""" |
| 67 | + cp_a = cp.arange(np.prod(shape)).reshape(shape).astype(np.float32) |
| 68 | + a = tp.Tensor(cp_a) |
| 69 | + b = eager_or_compiled(lambda t: t.flatten(start_dim=start_dim, end_dim=end_dim), a) |
| 70 | + assert b.shape == expected_shape |
| 71 | + assert np.array_equal(cp.from_dlpack(b).get(), cp_a.reshape(expected_shape).get()) |
0 commit comments