11 changes: 7 additions & 4 deletions examples/structural_mechanics/crash/README.md
@@ -19,14 +19,14 @@ For an in-depth comparison between the Transolver and MeshGraphNet models and th

<p align="center">
<img src="../../../docs/img/crash/crash_case4_reduced.gif" alt="Crash animation" width="80%" />

</p>

### Crushcan Modeling

<p align="center">
<img src="../../../docs/img/crash/crushcan.gif" alt="Crushcan animation" width="80%" />

</p>

## Quickstart
@@ -238,7 +238,10 @@ conf/
│ ├── mgn_time_conditional.yaml
│ ├── transolver_autoregressive_rollout_training.yaml
│ ├── transolver_one_step_rollout.yaml
│ └── transolver_time_conditional.yaml
│ ├── transolver_time_conditional.yaml
│ ├── figconvunet_autoregressive_rollout_training.yaml
│ ├── figconvunet_one_step_rollout.yaml
│ └── figconvunet_time_conditional.yaml
├── training/default.yaml # training hyperparameters
└── inference/default.yaml # inference options
```
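
The three `figconvunet_*` entries are new in this PR and select the FIGConvUNet rollout wrappers defined in `rollout.py`. A minimal selection sketch, assuming a Hydra entry point with a `model` config group and a top-level `config.yaml` (both assumed here, not shown in the diff):

```python
# Sketch only: assumes conf/ is a Hydra config root with a "model" group and a
# top-level config.yaml; adjust names to the repo's actual entry point.
from hydra import compose, initialize

with initialize(config_path="conf", version_base=None):
    cfg = compose(
        config_name="config",
        overrides=["model=figconvunet_one_step_rollout"],
    )
print(cfg.model["_target_"])  # rollout.FIGConvUNetOneStepRollout
```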
@@ -495,7 +498,7 @@ run_post_processing.sh can automate all evaluation tasks across runs.

- AMP is enabled by default in training; it reduces memory and accelerates matmuls on modern GPUs.
- For multi-GPU training, use `torchrun --standalone --nproc_per_node=<NUM_GPUS> train.py`.
- For DDP, prefer `torchrun --standalone --nproc_per_node=<NUM_GPUS> train.py`.

## Troubleshooting / FAQ

65 changes: 65 additions & 0 deletions examples/structural_mechanics/crash/conf/model/figconvunet_autoregressive_rollout_training.yaml
@@ -0,0 +1,65 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

_target_: rollout.FIGConvUNetAutoregressiveRolloutTraining
_convert_: all

# Input/output channels
in_channels: 5 # velocity(3) + features(F) + time(1)
out_channels: 3 # acceleration (xyz)

# Architecture
kernel_size: 3
hidden_channels: [16, 16, 16] # channels at each level
num_levels: 2 # number of down/up levels
num_down_blocks: 1
num_up_blocks: 1
mlp_channels: [256, 256]

# Spatial domain
aabb_max: [2.0, 2.0, 2.0]
aabb_min: [-2.0, -2.0, -2.0]
voxel_size: null

# Grid resolutions (factorized implicit grids)
# Format: [memory_format, resolution_tuple]
resolution_memory_format_pairs:
- [b_xc_y_z, [2, 64, 64]]
- [b_yc_x_z, [64, 2, 64]]
- [b_zc_x_y, [64, 64, 2]]

# Position encoding
use_rel_pos: true
use_rel_pos_embed: true
pos_encode_dim: 16

# Communication and sampling
communication_types: ["sum"]
to_point_sample_method: "graphconv"
neighbor_search_type: "knn"
knn_k: 16
reductions: ["mean"]

use_scalar_output: false

# Pooling (for global features if needed)
pooling_type: "max"
pooling_layers: [2]

# Rollout parameters
num_time_steps: ${training.num_time_steps}
dt: 5e-3
initial_vel: 9.22
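
Since the file carries `_target_` and `_convert_`, Hydra can build the model straight from it. A hedged sketch, assuming `hydra-core`/`omegaconf` are available and using a placeholder value for the `${training.num_time_steps}` interpolation:

```python
# Sketch: build the rollout model from the config above via Hydra.
# instantiate() imports rollout.FIGConvUNetAutoregressiveRolloutTraining and
# passes the remaining keys as kwargs; _convert_: all converts nested
# DictConfig/ListConfig values to plain dicts/lists first.
from hydra.utils import instantiate
from omegaconf import OmegaConf

model_cfg = OmegaConf.load("figconvunet_autoregressive_rollout_training.yaml")
# ${training.num_time_steps} resolves against the root config, so nest the
# model config beside a training node (80 is a placeholder, not a repo value):
cfg = OmegaConf.create({"training": {"num_time_steps": 80}, "model": model_cfg})
model = instantiate(cfg.model)
```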
65 changes: 65 additions & 0 deletions examples/structural_mechanics/crash/conf/model/figconvunet_one_step_rollout.yaml
@@ -0,0 +1,65 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

_target_: rollout.FIGConvUNetOneStepRollout
_convert_: all

# Input/output channels
in_channels: 4 # velocity(3) + features (F)
out_channels: 3 # acceleration (xyz)

# Architecture
kernel_size: 3
hidden_channels: [16, 16, 16] # channels at each level
num_levels: 2 # number of down/up levels
num_down_blocks: 1
num_up_blocks: 1
mlp_channels: [256, 256]

# Spatial domain
aabb_max: [2.0, 2.0, 2.0]
aabb_min: [-2.0, -2.0, -2.0]
voxel_size: null

# Grid resolutions (factorized implicit grids)
# Format: Uses res_mem_pair resolver (memory_format_enum, resolution_tuple)
resolution_memory_format_pairs:
- [b_xc_y_z, [2, 64, 64]]
- [b_yc_x_z, [64, 2, 64]]
- [b_zc_x_y, [64, 64, 2]]

# Position encoding
use_rel_pos: true
use_rel_pos_embed: true
pos_encode_dim: 16

# Communication and sampling
communication_types: ["sum"]
to_point_sample_method: "graphconv"
neighbor_search_type: "knn"
knn_k: 16
reductions: ["mean"]

use_scalar_output: false

# Pooling (for global features if needed)
pooling_type: "max"
pooling_layers: [2]

# Rollout parameters
num_time_steps: ${training.num_time_steps}
dt: 5e-3
initial_vel: 9.22
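
Both new configs use the same factorized implicit grids: three thin, axis-aligned grids instead of one dense voxel grid. A back-of-the-envelope cell count (plain arithmetic, not project code) shows the saving at these resolutions:

```python
# Cell-count comparison for the resolutions in the config above.
factorized = [(2, 64, 64), (64, 2, 64), (64, 64, 2)]
dense = (64, 64, 64)

n_fact = sum(x * y * z for x, y, z in factorized)  # 3 * 8192 = 24576 cells
n_dense = dense[0] * dense[1] * dense[2]           # 262144 cells
print(n_dense / n_fact)                            # ~10.7x fewer cells to convolve over
```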
2 changes: 2 additions & 0 deletions examples/structural_mechanics/crash/conf/model/figconvunet_time_conditional.yaml
@@ -53,6 +53,8 @@ neighbor_search_type: "knn"
knn_k: 16
reductions: ["mean"]

use_scalar_output: false

# Pooling (for global features if needed)
pooling_type: "max"
pooling_layers: [2]
168 changes: 168 additions & 0 deletions examples/structural_mechanics/crash/rollout.py
@@ -467,3 +467,171 @@ def step_fn(verts, feats):
            outputs.append(y_t)

        return torch.stack(outputs, dim=0)  # [T, N, 3]


class FIGConvUNetOneStepRollout(FIGConvUNet):
    """
    FIGConvUNet with one-step rollout for crash simulation.

    - Training: teacher forcing (uses GT positions at each step)
    - Inference: autoregressive (uses predictions)
    """

    def __init__(self, *args, **kwargs):
        self.dt: float = kwargs.pop("dt", 5e-3)
        self.initial_vel: torch.Tensor = kwargs.pop("initial_vel")
        self.rollout_steps: int = kwargs.pop("num_time_steps") - 1
        super().__init__(*args, **kwargs)

    def forward(self, sample: SimSample, data_stats: dict) -> torch.Tensor:
        """
        Args:
            sample: SimSample containing node_features and node_target
            data_stats: dict containing normalization stats
        Returns:
            [T, N, 3] rollout of predicted positions
        """
        inputs = sample.node_features
        x0 = inputs["coords"]  # initial pos [N, 3]
        features = inputs.get("features", x0.new_zeros((x0.size(0), 0)))  # [N, F]

        # Ground truth sequence [T, N, 3]
        N = x0.size(0)
        gt_seq = torch.cat(
            [x0.unsqueeze(0), sample.node_target.view(N, -1, 3).transpose(0, 1)],
            dim=0,
        )

        outputs: list[torch.Tensor] = []
        # First step: backstep to create y_-1
        y_t0 = gt_seq[0] - self.initial_vel * self.dt
        y_t1 = gt_seq[0]

        for t in range(self.rollout_steps):
            # In training mode (except first step), use ground truth positions
            if self.training and t > 0:
                y_t0, y_t1 = gt_seq[t - 1], gt_seq[t]

            # Prepare vertices for FIGConvUNet: [1, N, 3]
            vertices = y_t1.unsqueeze(0)  # [1, N, 3]

            vel = (y_t1 - y_t0) / self.dt
            vel_norm = (vel - data_stats["node"]["norm_vel_mean"]) / (
                data_stats["node"]["norm_vel_std"] + EPS
            )

            # [1, N, 3 + F]
            fx_t = torch.cat([vel_norm, features], dim=-1).unsqueeze(0)

            def step_fn(verts, feats):
                out, _ = super(FIGConvUNetOneStepRollout, self).forward(
                    vertices=verts, features=feats
                )
                return out

            if self.training:
                outf = ckpt(
                    step_fn,
                    vertices,
                    fx_t,
                    use_reentrant=False,
                ).squeeze(0)  # [N, 3]
            else:
                outf = step_fn(vertices, fx_t).squeeze(0)  # [N, 3]

            acc = (
                outf * data_stats["node"]["norm_acc_std"]
                + data_stats["node"]["norm_acc_mean"]
            )
            vel_pred = self.dt * acc + vel
            y_t2_pred = self.dt * vel_pred + y_t1

            outputs.append(y_t2_pred)

            if not self.training:
                # autoregressive update for inference
                y_t0, y_t1 = y_t1, y_t2_pred

        return torch.stack(outputs, dim=0)  # [T, N, 3]
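
A hedged usage sketch of the train/eval asymmetry above; `model`, `sample`, and `data_stats` are placeholders matching the `forward` signature (the `SimSample` field shapes are inferred from the body, not from its definition):

```python
# Hypothetical objects: `model` is an instantiated FIGConvUNetOneStepRollout,
# `sample` a SimSample with node_features={"coords": [N, 3], "features": [N, F]}
# and node_target reshapeable to [N, T-1, 3]; `data_stats` holds the norm stats.
import torch

model.train()
pred_tf = model(sample, data_stats)      # teacher forcing: GT positions at t > 0
assert pred_tf.shape[0] == model.rollout_steps

model.eval()
with torch.no_grad():
    pred_ar = model(sample, data_stats)  # autoregressive: feeds back its own output
```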


class FIGConvUNetAutoregressiveRolloutTraining(FIGConvUNet):
    """
    FIGConvUNet with autoregressive rollout training for crash simulation.

    Predicts sequence by autoregressively updating velocity and position
    using predicted accelerations. Supports gradient checkpointing during training.
    """

    def __init__(self, *args, **kwargs):
        self.dt: float = kwargs.pop("dt")
        self.initial_vel: torch.Tensor = kwargs.pop("initial_vel")
        self.rollout_steps: int = kwargs.pop("num_time_steps") - 1
        super().__init__(*args, **kwargs)

    def forward(self, sample: SimSample, data_stats: dict) -> torch.Tensor:
        """
        Args:
            sample: SimSample containing node_features and node_target
            data_stats: dict containing normalization stats
        Returns:
            [T, N, 3] rollout of predicted positions
        """
        inputs = sample.node_features
        coords = inputs["coords"]  # [N, 3]
        features = inputs.get("features", coords.new_zeros((coords.size(0), 0)))
        N = coords.size(0)
        device = coords.device

        # Initial states
        y_t1 = coords  # [N, 3]
        y_t0 = y_t1 - self.initial_vel * self.dt  # backstep using initial velocity

        outputs: list[torch.Tensor] = []
        for t in range(self.rollout_steps):
            time_t = 0.0 if self.rollout_steps <= 1 else t / (self.rollout_steps - 1)
            time_t = torch.tensor([time_t], device=device, dtype=torch.float32)

            # Velocity normalization
            vel = (y_t1 - y_t0) / self.dt
            vel_norm = (vel - data_stats["node"]["norm_vel_mean"]) / (
                data_stats["node"]["norm_vel_std"] + EPS
            )

            # Prepare vertices for FIGConvUNet: [1, N, 3]
            vertices = y_t1.unsqueeze(0)  # [1, N, 3]

            # Prepare features: vel_norm + features + time [N, 3+F+1]
            fx_t = torch.cat(
                [vel_norm, features, time_t.expand(N, 1)], dim=-1
            )  # [N, 3+F+1]
            fx_t = fx_t.unsqueeze(0)  # [1, N, 3+F+1]

            def step_fn(verts, feats):
                out, _ = super(
                    FIGConvUNetAutoregressiveRolloutTraining, self
                ).forward(vertices=verts, features=feats)
                return out

            if self.training:
                outf = ckpt(
                    step_fn,
                    vertices,
                    fx_t,
                    use_reentrant=False,
                ).squeeze(0)  # [N, 3]
            else:
                outf = step_fn(vertices, fx_t).squeeze(0)  # [N, 3]

            # De-normalize acceleration
            acc = (
                outf * data_stats["node"]["norm_acc_std"]
                + data_stats["node"]["norm_acc_mean"]
            )
            vel = self.dt * acc + vel
            y_t2 = self.dt * vel + y_t1

            outputs.append(y_t2)
            y_t1, y_t0 = y_t2, y_t1

        return torch.stack(outputs, dim=0)  # [T, N, 3]
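
Both classes integrate with the same semi-implicit Euler step, seeded by a backstep from the initial velocity. A tiny standalone numeric check of that scheme on a scalar state (constant placeholder acceleration; `dt` and `v0` mirror the config values):

```python
# Integrator used above: v_{t+1} = dt * a_t + v_t, y_{t+1} = dt * v_{t+1} + y_t,
# with the backstep y_{-1} = y_0 - v0 * dt encoding the initial velocity.
dt, v0, a = 5e-3, 9.22, -2.0             # a is an arbitrary placeholder
y_prev, y_curr = 0.0 - v0 * dt, 0.0      # backstep

for _ in range(3):
    v = (y_curr - y_prev) / dt           # finite-difference velocity; v0 at t=0
    v_next = dt * a + v                  # velocity update
    y_next = dt * v_next + y_curr        # position update with the new velocity
    y_prev, y_curr = y_curr, y_next

print(y_curr)  # forward motion from v0, slightly decelerated by a
```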