Skip to content

Commit c3ad24c

Browse files
Add FIGConvNet to crash example (#1207)
* Add FIGConvNet to crash example. * Add FIGConvNet to crash example * Update model config
1 parent 04d5fe9 commit c3ad24c

File tree

5 files changed

+130
-11
lines changed

5 files changed

+130
-11
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-FileCopyrightText: All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
_target_: rollout.FIGConvUNetTimeConditionalRollout
18+
_convert_: all
19+
20+
# Input/output channels
21+
in_channels: 2 # thickness + time
22+
out_channels: 3 # displacement offset (xyz)
23+
24+
# Architecture
25+
kernel_size: 3
26+
hidden_channels: [16, 16, 16] # channels at each level
27+
num_levels: 2 # number of down/up levels
28+
num_down_blocks: 1
29+
num_up_blocks: 1
30+
mlp_channels: [256, 256]
31+
32+
# Spatial domain
33+
aabb_max: [2.0, 2.0, 2.0]
34+
aabb_min: [-2.0, -2.0, -2.0]
35+
voxel_size: null
36+
37+
# Grid resolutions (factorized implicit grids)
38+
# Format: Uses res_mem_pair resolver (memory_format_enum, resolution_tuple)
39+
resolution_memory_format_pairs:
40+
- [b_xc_y_z, [2, 64, 64]]
41+
- [b_yc_x_z, [64, 2, 64]]
42+
- [b_zc_x_y, [64, 64, 2]]
43+
44+
# Position encoding
45+
use_rel_pos: true
46+
use_rel_pos_embed: true
47+
pos_encode_dim: 16
48+
49+
# Communication and sampling
50+
communication_types: ["sum"]
51+
to_point_sample_method: "graphconv"
52+
neighbor_search_type: "knn"
53+
knn_k: 16
54+
reductions: ["mean"]
55+
56+
# Pooling (for global features if needed)
57+
pooling_type: "max"
58+
pooling_layers: [2]
59+
60+
# Rollout parameters
61+
num_time_steps: ${training.num_time_steps}

examples/structural_mechanics/crash/datapipe.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -350,13 +350,6 @@ def _normalize_node_tensor(
350350
)
351351
return (invar - mu.view(1, 1, -1)) / (std.view(1, 1, -1) + EPS)
352352

353-
@staticmethod
354-
def _normalize_thickness_tensor(
355-
thickness: torch.Tensor, mu: torch.Tensor, std: torch.Tensor
356-
):
357-
# thickness: [N], mu/std: scalar tensors
358-
return (thickness - mu) / (std + EPS)
359-
360353

361354
class CrashGraphDataset(CrashBaseDataset):
362355
"""

examples/structural_mechanics/crash/rollout.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
# limitations under the License.
1616

1717
import torch
18-
import torch.nn as nn
1918
from torch.utils.checkpoint import checkpoint as ckpt
2019

2120
from physicsnemo.models.transolver import Transolver
2221
from physicsnemo.models.meshgraphnet import MeshGraphNet
22+
from physicsnemo.models.figconvnet.figconvunet import FIGConvUNet
2323

2424
from datapipe import SimSample
2525

@@ -406,3 +406,64 @@ def step_fn(nf, ef, g):
406406
y_t0, y_t1 = y_t1, y_t2_pred
407407

408408
return torch.stack(outputs, dim=0) # [T,N,3]
409+
410+
411+
class FIGConvUNetTimeConditionalRollout(FIGConvUNet):
412+
"""
413+
FIGConvUNet with time-conditional rollout for crash simulation.
414+
415+
Predicts each time step independently, conditioned on normalized time.
416+
"""
417+
418+
def __init__(self, *args, **kwargs):
419+
self.rollout_steps: int = kwargs.pop("num_time_steps") - 1
420+
super().__init__(*args, **kwargs)
421+
422+
def forward(
423+
self,
424+
sample: SimSample,
425+
data_stats: dict,
426+
) -> torch.Tensor:
427+
"""
428+
Args:
429+
Sample: SimSample containing node_features and node_target
430+
data_stats: dict containing normalization stats
431+
Returns:
432+
[T, N, 3] rollout of predicted positions
433+
"""
434+
inputs = sample.node_features
435+
x = inputs["coords"] # initial pos [N, 3]
436+
features = inputs.get("features", x.new_zeros((x.size(0), 0))) # [N, F]
437+
438+
outputs: list[torch.Tensor] = []
439+
time_seq = torch.linspace(0.0, 1.0, self.rollout_steps, device=x.device)
440+
441+
for time_t in time_seq:
442+
# Prepare vertices for FIGConvUNet: [1, N, 3]
443+
vertices = x.unsqueeze(0) # [1, N, 3]
444+
445+
# Prepare features: features + time [N, F+1]
446+
time_expanded = time_t.expand(x.size(0), 1) # [N, 1]
447+
features_t = torch.cat([features, time_expanded], dim=-1) # [N, F+1]
448+
features_t = features_t.unsqueeze(0) # [1, N, F+1]
449+
450+
def step_fn(verts, feats):
451+
out, _ = super(FIGConvUNetTimeConditionalRollout, self).forward(
452+
vertices=verts, features=feats
453+
)
454+
return out
455+
456+
if self.training:
457+
outf = ckpt(
458+
step_fn,
459+
vertices,
460+
features_t,
461+
use_reentrant=False,
462+
).squeeze(0) # [N, 3]
463+
else:
464+
outf = step_fn(vertices, features_t).squeeze(0) # [N, 3]
465+
466+
y_t = x + outf
467+
outputs.append(y_t)
468+
469+
return torch.stack(outputs, dim=0) # [T, N, 3]

examples/structural_mechanics/crash/train.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from torch.nn.parallel import DistributedDataParallel
3131
from torch.utils.data.distributed import DistributedSampler
3232
from torch.utils.tensorboard import SummaryWriter
33-
from tqdm import tqdm
3433

3534
from physicsnemo.distributed.manager import DistributedManager
3635
from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper
@@ -64,6 +63,11 @@ def __init__(self, cfg: DictConfig, logger0: RankZeroLoggingWrapper):
6463
f"Model {model_name} requires a point-cloud datapipe, "
6564
f"but you selected {datapipe_name}."
6665
)
66+
if "FIGConvUNet" in model_name and "PointCloudDataset" not in datapipe_name:
67+
raise ValueError(
68+
f"Model {model_name} requires a point-cloud datapipe, "
69+
f"but you selected {datapipe_name}."
70+
)
6771

6872
# Dataset
6973
reader = instantiate(cfg.reader)
@@ -223,7 +227,7 @@ def main(cfg: DictConfig) -> None:
223227
for sample in trainer.dataloader:
224228
sample = sample[0].to(dist.device) # SimSample .to()
225229
loss = trainer.train(sample)
226-
total_loss += loss.item()
230+
total_loss += loss.detach().item()
227231
num_batches += 1
228232

229233
trainer.scheduler.step()

examples/structural_mechanics/crash/vtp_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def process_vtp_data(data_dir, num_samples=2, write_vtp=False, logger=None):
169169

170170
if not vtp_files:
171171
if logger:
172-
logger.error("No .vtp files found in:", base_data_dir)
172+
logger.error(f"No .vtp files found in: {base_data_dir}")
173173
exit(1)
174174

175175
for vtp_path in vtp_files:

0 commit comments

Comments
 (0)