Skip to content

Commit a9a9c0a

Browse files
Formatting changes
1 parent d6271b6 commit a9a9c0a

File tree

2 files changed

+106
-77
lines changed

2 files changed

+106
-77
lines changed
Lines changed: 103 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
22
# SPDX-FileCopyrightText: All rights reserved.
33
# SPDX-License-Identifier: Apache-2.0
44
#
@@ -22,6 +22,7 @@
2222

2323
# Import functions from vtp_reader
2424
import sys
25+
2526
sys.path.insert(0, str(Path(__file__).parent.parent))
2627
from vtp_reader import (
2728
load_vtp_file,
@@ -34,136 +35,164 @@
3435
def simple_vtp_file():
3536
"""Create a simple VTP file for testing."""
3637
# Create a simple quad mesh (2x2 grid)
37-
points = np.array([
38-
[0, 0, 0],
39-
[1, 0, 0],
40-
[0, 1, 0],
41-
[1, 1, 0],
42-
], dtype=np.uint8)
43-
38+
points = np.array(
39+
[
40+
[0, 0, 0],
41+
[1, 0, 0],
42+
[0, 1, 0],
43+
[1, 1, 0],
44+
],
45+
dtype=np.uint8,
46+
)
47+
4448
# Single quad cell
4549
faces = np.array([4, 0, 1, 3, 2]) # quad with 4 vertices
46-
50+
4751
mesh = pv.PolyData(points, faces, force_float=False)
48-
52+
4953
# Add displacement fields for 3 timesteps
50-
mesh.point_data['displacement_t0.000'] = np.array([
51-
[0, 0, 0],
52-
[0, 0, 0],
53-
[0, 0, 0],
54-
[0, 0, 0],
55-
], dtype=np.uint8)
56-
57-
mesh.point_data['displacement_t0.005'] = np.array([
58-
[1, 0, 0],
59-
[1, 0, 0],
60-
[1, 0, 0],
61-
[1, 0, 0],
62-
], dtype=np.uint8)
63-
64-
mesh.point_data['displacement_t0.010'] = np.array([
65-
[2, 0, 0],
66-
[2, 0, 0],
67-
[2, 0, 0],
68-
[2, 0, 0],
69-
], dtype=np.uint8)
70-
54+
mesh.point_data["displacement_t0.000"] = np.array(
55+
[
56+
[0, 0, 0],
57+
[0, 0, 0],
58+
[0, 0, 0],
59+
[0, 0, 0],
60+
],
61+
dtype=np.uint8,
62+
)
63+
64+
mesh.point_data["displacement_t0.005"] = np.array(
65+
[
66+
[1, 0, 0],
67+
[1, 0, 0],
68+
[1, 0, 0],
69+
[1, 0, 0],
70+
],
71+
dtype=np.uint8,
72+
)
73+
74+
mesh.point_data["displacement_t0.010"] = np.array(
75+
[
76+
[2, 0, 0],
77+
[2, 0, 0],
78+
[2, 0, 0],
79+
[2, 0, 0],
80+
],
81+
dtype=np.uint8,
82+
)
83+
7184
# Add thickness as additional point data
72-
mesh.point_data['thickness'] = np.array([1, 1, 1, 1], dtype=np.uint8)
73-
85+
mesh.point_data["thickness"] = np.array([1, 1, 1, 1], dtype=np.uint8)
86+
7487
# Save to temporary file
75-
with tempfile.NamedTemporaryFile(suffix='.vtp', delete=False) as f:
88+
with tempfile.NamedTemporaryFile(suffix=".vtp", delete=False) as f:
7689
temp_path = f.name
77-
90+
7891
mesh.save(temp_path)
7992
yield temp_path
80-
93+
8194
# Cleanup
8295
Path(temp_path).unlink(missing_ok=True)
8396

8497

8598
def test_load_vtp_file_basic(simple_vtp_file):
8699
"""Test basic VTP file loading."""
87100
pos_raw, mesh_connectivity, point_data_dict = load_vtp_file(simple_vtp_file)
88-
101+
89102
# Check positions shape: (timesteps, nodes, 3)
90103
assert pos_raw.shape == (3, 4, 3), f"Expected shape (3, 4, 3), got {pos_raw.shape}"
91-
104+
92105
# Check mesh connectivity
93106
assert len(mesh_connectivity) == 1, f"Expected 1 cell, got {len(mesh_connectivity)}"
94-
assert len(mesh_connectivity[0]) == 4, f"Expected quad with 4 vertices, got {len(mesh_connectivity[0])}"
95-
107+
assert len(mesh_connectivity[0]) == 4, (
108+
f"Expected quad with 4 vertices, got {len(mesh_connectivity[0])}"
109+
)
110+
96111
# Check point data dict contains thickness
97-
assert 'thickness' in point_data_dict, "Thickness not found in point_data_dict"
98-
assert point_data_dict['thickness'].shape == (4,), f"Expected thickness shape (4,), got {point_data_dict['thickness'].shape}"
112+
assert "thickness" in point_data_dict, "Thickness not found in point_data_dict"
113+
assert point_data_dict["thickness"].shape == (4,), (
114+
f"Expected thickness shape (4,), got {point_data_dict['thickness'].shape}"
115+
)
99116

100117

101118
def test_load_vtp_file_displacements(simple_vtp_file):
102119
"""Test that displacements are correctly applied."""
103120
pos_raw, _, _ = load_vtp_file(simple_vtp_file)
104-
121+
105122
# First timestep should be reference coords (displacement = 0)
106-
expected_t0 = np.array([
107-
[0, 0, 0],
108-
[1, 0, 0],
109-
[0, 1, 0],
110-
[1, 1, 0],
111-
])
123+
expected_t0 = np.array(
124+
[
125+
[0, 0, 0],
126+
[1, 0, 0],
127+
[0, 1, 0],
128+
[1, 1, 0],
129+
]
130+
)
112131
np.testing.assert_array_almost_equal(pos_raw[0], expected_t0, decimal=5)
113-
132+
114133
# Second timestep should include displacement
115134
expected_t1 = expected_t0 + np.array([[1, 0, 0]] * 4)
116135
np.testing.assert_array_almost_equal(pos_raw[1], expected_t1, decimal=5)
117-
136+
118137
# Third timestep
119138
expected_t2 = expected_t0 + np.array([[2, 0, 0]] * 4)
120139
np.testing.assert_array_almost_equal(pos_raw[2], expected_t2, decimal=5)
121140

122141

123142
def test_extract_mesh_connectivity():
124143
"""Test mesh connectivity extraction from PolyData."""
125-
points = np.array([
126-
[0, 0, 0],
127-
[1, 0, 0],
128-
[1, 1, 0],
129-
[0, 1, 0],
130-
])
131-
144+
points = np.array(
145+
[
146+
[0, 0, 0],
147+
[1, 0, 0],
148+
[1, 1, 0],
149+
[0, 1, 0],
150+
]
151+
)
152+
132153
# Create a single quad
133154
faces = np.array([4, 0, 1, 2, 3])
134155
poly = pv.PolyData(points, faces, force_float=False)
135-
156+
136157
connectivity = extract_mesh_connectivity_from_polydata(poly)
137-
158+
138159
assert len(connectivity) == 1, f"Expected 1 cell, got {len(connectivity)}"
139160
assert len(connectivity[0]) == 4, f"Expected 4 vertices, got {len(connectivity[0])}"
140-
assert connectivity[0] == [0, 1, 2, 3], f"Expected [0, 1, 2, 3], got {connectivity[0]}"
161+
assert connectivity[0] == [0, 1, 2, 3], (
162+
f"Expected [0, 1, 2, 3], got {connectivity[0]}"
163+
)
141164

142165

143166
def test_build_edges_from_mesh_connectivity():
144167
"""Test edge building from mesh connectivity."""
145168
# Single quad: should produce 4 edges
146169
mesh_connectivity = [[0, 1, 2, 3]]
147170
edges = build_edges_from_mesh_connectivity(mesh_connectivity)
148-
171+
149172
expected_edges = {(0, 1), (1, 2), (2, 3), (0, 3)}
150173
assert edges == expected_edges, f"Expected {expected_edges}, got {edges}"
151174

152175

153176
def test_point_data_extraction(simple_vtp_file):
154177
"""Test that non-displacement point data is extracted correctly."""
155178
_, _, point_data_dict = load_vtp_file(simple_vtp_file)
156-
179+
157180
# Should have thickness
158-
assert 'thickness' in point_data_dict, "Thickness not in point_data_dict"
159-
181+
assert "thickness" in point_data_dict, "Thickness not in point_data_dict"
182+
160183
# Should NOT have displacement fields
161-
assert 'displacement_t0.000' not in point_data_dict, "Displacement fields should not be in point_data_dict"
162-
assert 'displacement_t0.005' not in point_data_dict, "Displacement fields should not be in point_data_dict"
163-
184+
assert "displacement_t0.000" not in point_data_dict, (
185+
"Displacement fields should not be in point_data_dict"
186+
)
187+
assert "displacement_t0.005" not in point_data_dict, (
188+
"Displacement fields should not be in point_data_dict"
189+
)
190+
164191
# Check thickness values
165192
expected_thickness = np.array([1, 1, 1, 1], dtype=np.uint8)
166-
np.testing.assert_array_almost_equal(point_data_dict['thickness'], expected_thickness, decimal=5)
193+
np.testing.assert_array_almost_equal(
194+
point_data_dict["thickness"], expected_thickness, decimal=5
195+
)
167196

168197

169198
def test_missing_displacement_fields():
@@ -172,12 +201,12 @@ def test_missing_displacement_fields():
172201
points = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]])
173202
faces = np.array([3, 0, 1, 2])
174203
mesh = pv.PolyData(points, faces, force_float=False)
175-
176-
with tempfile.NamedTemporaryFile(suffix='.vtp', delete=False) as f:
204+
205+
with tempfile.NamedTemporaryFile(suffix=".vtp", delete=False) as f:
177206
temp_path = f.name
178-
207+
179208
mesh.save(temp_path)
180-
209+
181210
try:
182211
with pytest.raises(ValueError, match="No displacement fields found"):
183212
load_vtp_file(temp_path)
@@ -189,5 +218,5 @@ def test_empty_mesh_connectivity():
189218
"""Test edge building with empty connectivity."""
190219
mesh_connectivity = []
191220
edges = build_edges_from_mesh_connectivity(mesh_connectivity)
192-
221+
193222
assert len(edges) == 0, f"Expected 0 edges for empty connectivity, got {len(edges)}"

examples/structural_mechanics/crash/vtp_reader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,13 @@ def natural_key(name):
105105

106106
pos_raw = np.stack(pos_list, axis=0)
107107
mesh_connectivity = extract_mesh_connectivity_from_polydata(poly)
108-
108+
109109
# Extract all other point data fields (not displacement fields)
110110
point_data_dict = {}
111111
for name in poly.point_data.keys():
112112
if not name.startswith("displacement_"):
113113
point_data_dict[name] = np.asarray(poly.point_data[name])
114-
114+
115115
return pos_raw, mesh_connectivity, point_data_dict
116116

117117

@@ -208,7 +208,7 @@ def process_vtp_data(data_dir, num_samples=2, write_vtp=False, logger=None):
208208
write_vtp=write_vtp,
209209
logger=logger,
210210
)
211-
211+
212212
# Create record with coords and all other point data fields
213213
record = {"coords": mesh_pos_all}
214214
record.update(point_data_dict) # Add thickness and any other fields

0 commit comments

Comments
 (0)