|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. |
| 2 | +# SPDX-FileCopyrightText: All rights reserved. |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | +# |
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +# you may not use this file except in compliance with the License. |
| 7 | +# You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | + |
| 17 | +import tempfile |
| 18 | +import numpy as np |
| 19 | +import pyvista as pv |
| 20 | +import pytest |
| 21 | +from pathlib import Path |
| 22 | + |
| 23 | +# Import functions from vtp_reader |
| 24 | +import sys |
| 25 | + |
| 26 | +sys.path.insert(0, str(Path(__file__).parent.parent)) |
| 27 | +from vtp_reader import ( |
| 28 | + load_vtp_file, |
| 29 | + extract_mesh_connectivity_from_polydata, |
| 30 | + build_edges_from_mesh_connectivity, |
| 31 | +) |
| 32 | + |
| 33 | + |
| 34 | +@pytest.fixture |
| 35 | +def simple_vtp_file(): |
| 36 | + """Create a simple VTP file for testing.""" |
| 37 | + # Create a simple quad mesh (2x2 grid) |
| 38 | + points = np.array( |
| 39 | + [ |
| 40 | + [0, 0, 0], |
| 41 | + [1, 0, 0], |
| 42 | + [0, 1, 0], |
| 43 | + [1, 1, 0], |
| 44 | + ], |
| 45 | + dtype=np.uint8, |
| 46 | + ) |
| 47 | + |
| 48 | + # Single quad cell |
| 49 | + faces = np.array([4, 0, 1, 3, 2]) # quad with 4 vertices |
| 50 | + |
| 51 | + mesh = pv.PolyData(points, faces, force_float=False) |
| 52 | + |
| 53 | + # Add displacement fields for 3 timesteps |
| 54 | + mesh.point_data["displacement_t0.000"] = np.array( |
| 55 | + [ |
| 56 | + [0, 0, 0], |
| 57 | + [0, 0, 0], |
| 58 | + [0, 0, 0], |
| 59 | + [0, 0, 0], |
| 60 | + ], |
| 61 | + dtype=np.uint8, |
| 62 | + ) |
| 63 | + |
| 64 | + mesh.point_data["displacement_t0.005"] = np.array( |
| 65 | + [ |
| 66 | + [1, 0, 0], |
| 67 | + [1, 0, 0], |
| 68 | + [1, 0, 0], |
| 69 | + [1, 0, 0], |
| 70 | + ], |
| 71 | + dtype=np.uint8, |
| 72 | + ) |
| 73 | + |
| 74 | + mesh.point_data["displacement_t0.010"] = np.array( |
| 75 | + [ |
| 76 | + [2, 0, 0], |
| 77 | + [2, 0, 0], |
| 78 | + [2, 0, 0], |
| 79 | + [2, 0, 0], |
| 80 | + ], |
| 81 | + dtype=np.uint8, |
| 82 | + ) |
| 83 | + |
| 84 | + # Add thickness as additional point data |
| 85 | + mesh.point_data["thickness"] = np.array([1, 1, 1, 1], dtype=np.uint8) |
| 86 | + |
| 87 | + # Save to temporary file |
| 88 | + with tempfile.NamedTemporaryFile(suffix=".vtp", delete=False) as f: |
| 89 | + temp_path = f.name |
| 90 | + |
| 91 | + mesh.save(temp_path) |
| 92 | + yield temp_path |
| 93 | + |
| 94 | + # Cleanup |
| 95 | + Path(temp_path).unlink(missing_ok=True) |
| 96 | + |
| 97 | + |
| 98 | +def test_load_vtp_file_basic(simple_vtp_file): |
| 99 | + """Test basic VTP file loading.""" |
| 100 | + pos_raw, mesh_connectivity, point_data_dict = load_vtp_file(simple_vtp_file) |
| 101 | + |
| 102 | + # Check positions shape: (timesteps, nodes, 3) |
| 103 | + assert pos_raw.shape == (3, 4, 3), f"Expected shape (3, 4, 3), got {pos_raw.shape}" |
| 104 | + |
| 105 | + # Check mesh connectivity |
| 106 | + assert len(mesh_connectivity) == 1, f"Expected 1 cell, got {len(mesh_connectivity)}" |
| 107 | + assert len(mesh_connectivity[0]) == 4, ( |
| 108 | + f"Expected quad with 4 vertices, got {len(mesh_connectivity[0])}" |
| 109 | + ) |
| 110 | + |
| 111 | + # Check point data dict contains thickness |
| 112 | + assert "thickness" in point_data_dict, "Thickness not found in point_data_dict" |
| 113 | + assert point_data_dict["thickness"].shape == (4,), ( |
| 114 | + f"Expected thickness shape (4,), got {point_data_dict['thickness'].shape}" |
| 115 | + ) |
| 116 | + |
| 117 | + |
| 118 | +def test_load_vtp_file_displacements(simple_vtp_file): |
| 119 | + """Test that displacements are correctly applied.""" |
| 120 | + pos_raw, _, _ = load_vtp_file(simple_vtp_file) |
| 121 | + |
| 122 | + # First timestep should be reference coords (displacement = 0) |
| 123 | + expected_t0 = np.array( |
| 124 | + [ |
| 125 | + [0, 0, 0], |
| 126 | + [1, 0, 0], |
| 127 | + [0, 1, 0], |
| 128 | + [1, 1, 0], |
| 129 | + ] |
| 130 | + ) |
| 131 | + np.testing.assert_array_almost_equal(pos_raw[0], expected_t0, decimal=5) |
| 132 | + |
| 133 | + # Second timestep should include displacement |
| 134 | + expected_t1 = expected_t0 + np.array([[1, 0, 0]] * 4) |
| 135 | + np.testing.assert_array_almost_equal(pos_raw[1], expected_t1, decimal=5) |
| 136 | + |
| 137 | + # Third timestep |
| 138 | + expected_t2 = expected_t0 + np.array([[2, 0, 0]] * 4) |
| 139 | + np.testing.assert_array_almost_equal(pos_raw[2], expected_t2, decimal=5) |
| 140 | + |
| 141 | + |
| 142 | +def test_extract_mesh_connectivity(): |
| 143 | + """Test mesh connectivity extraction from PolyData.""" |
| 144 | + points = np.array( |
| 145 | + [ |
| 146 | + [0, 0, 0], |
| 147 | + [1, 0, 0], |
| 148 | + [1, 1, 0], |
| 149 | + [0, 1, 0], |
| 150 | + ] |
| 151 | + ) |
| 152 | + |
| 153 | + # Create a single quad |
| 154 | + faces = np.array([4, 0, 1, 2, 3]) |
| 155 | + poly = pv.PolyData(points, faces, force_float=False) |
| 156 | + |
| 157 | + connectivity = extract_mesh_connectivity_from_polydata(poly) |
| 158 | + |
| 159 | + assert len(connectivity) == 1, f"Expected 1 cell, got {len(connectivity)}" |
| 160 | + assert len(connectivity[0]) == 4, f"Expected 4 vertices, got {len(connectivity[0])}" |
| 161 | + assert connectivity[0] == [0, 1, 2, 3], ( |
| 162 | + f"Expected [0, 1, 2, 3], got {connectivity[0]}" |
| 163 | + ) |
| 164 | + |
| 165 | + |
| 166 | +def test_build_edges_from_mesh_connectivity(): |
| 167 | + """Test edge building from mesh connectivity.""" |
| 168 | + # Single quad: should produce 4 edges |
| 169 | + mesh_connectivity = [[0, 1, 2, 3]] |
| 170 | + edges = build_edges_from_mesh_connectivity(mesh_connectivity) |
| 171 | + |
| 172 | + expected_edges = {(0, 1), (1, 2), (2, 3), (0, 3)} |
| 173 | + assert edges == expected_edges, f"Expected {expected_edges}, got {edges}" |
| 174 | + |
| 175 | + |
| 176 | +def test_point_data_extraction(simple_vtp_file): |
| 177 | + """Test that non-displacement point data is extracted correctly.""" |
| 178 | + _, _, point_data_dict = load_vtp_file(simple_vtp_file) |
| 179 | + |
| 180 | + # Should have thickness |
| 181 | + assert "thickness" in point_data_dict, "Thickness not in point_data_dict" |
| 182 | + |
| 183 | + # Should NOT have displacement fields |
| 184 | + assert "displacement_t0.000" not in point_data_dict, ( |
| 185 | + "Displacement fields should not be in point_data_dict" |
| 186 | + ) |
| 187 | + assert "displacement_t0.005" not in point_data_dict, ( |
| 188 | + "Displacement fields should not be in point_data_dict" |
| 189 | + ) |
| 190 | + |
| 191 | + # Check thickness values |
| 192 | + expected_thickness = np.array([1, 1, 1, 1], dtype=np.uint8) |
| 193 | + np.testing.assert_array_almost_equal( |
| 194 | + point_data_dict["thickness"], expected_thickness, decimal=5 |
| 195 | + ) |
| 196 | + |
| 197 | + |
| 198 | +def test_missing_displacement_fields(): |
| 199 | + """Test that missing displacement fields raises appropriate error.""" |
| 200 | + # Create VTP without displacement fields |
| 201 | + points = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]]) |
| 202 | + faces = np.array([3, 0, 1, 2]) |
| 203 | + mesh = pv.PolyData(points, faces, force_float=False) |
| 204 | + |
| 205 | + with tempfile.NamedTemporaryFile(suffix=".vtp", delete=False) as f: |
| 206 | + temp_path = f.name |
| 207 | + |
| 208 | + mesh.save(temp_path) |
| 209 | + |
| 210 | + try: |
| 211 | + with pytest.raises(ValueError, match="No displacement fields found"): |
| 212 | + load_vtp_file(temp_path) |
| 213 | + finally: |
| 214 | + Path(temp_path).unlink(missing_ok=True) |
| 215 | + |
| 216 | + |
| 217 | +def test_empty_mesh_connectivity(): |
| 218 | + """Test edge building with empty connectivity.""" |
| 219 | + mesh_connectivity = [] |
| 220 | + edges = build_edges_from_mesh_connectivity(mesh_connectivity) |
| 221 | + |
| 222 | + assert len(edges) == 0, f"Expected 0 edges for empty connectivity, got {len(edges)}" |
0 commit comments