diff --git a/examples/structural_mechanics/crash/README.md b/examples/structural_mechanics/crash/README.md
index 6b8a5ebb3e..f3e8d97833 100644
--- a/examples/structural_mechanics/crash/README.md
+++ b/examples/structural_mechanics/crash/README.md
@@ -19,14 +19,14 @@ For an in-depth comparison between the Transolver and MeshGraphNet models and th
-
+
### Crushcan Modeling
-
+
## Quickstart
@@ -238,7 +238,10 @@ conf/
│ ├── mgn_time_conditional.yaml
│ ├── transolver_autoregressive_rollout_training.yaml
│ ├── transolver_one_step_rollout.yaml
-│ └── transolver_time_conditional.yaml
+│ ├── transolver_time_conditional.yaml
+│ ├── figconvunet_autoregressive_rollout_training.yaml
+│ ├── figconvunet_one_step_rollout.yaml
+│ └── figconvunet_time_conditional.yaml
├── training/default.yaml # training hyperparameters
└── inference/default.yaml # inference options
```
@@ -495,7 +498,7 @@ run_post_processing.sh can automate all evaluation tasks across runs.
- AMP is enabled by default in training; it reduces memory and accelerates matmuls on modern GPUs.
- For multi-GPU training, use `torchrun --standalone --nproc_per_node=<num_gpus> train.py`.
-- For DDP, prefer `torchrun --standalone --nproc_per_node= train.py`.
+- For DDP, prefer `torchrun --standalone --nproc_per_node=<num_gpus> train.py`.
## Troubleshooting / FAQ
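With the three new `figconvunet_*` entries in `conf/model/`, a run can select them through Hydra's config group override. A minimal sketch using the compose API (the root config name `config` and the `conf/` search path are assumptions; `train.py` may wire this up differently):

```python
# Sketch: select one of the new FIGConvUNet model configs via Hydra.
# Assumption: the root config is named "config" and lives under conf/.
from hydra import compose, initialize

with initialize(version_base=None, config_path="conf"):
    cfg = compose(
        config_name="config",
        overrides=["model=figconvunet_time_conditional"],
    )
print(cfg.model["_target_"])  # e.g. rollout.FIGConvUNetTimeConditionalRollout
```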
diff --git a/examples/structural_mechanics/crash/conf/model/figconvunet_autoregressive_rollout_training.yaml b/examples/structural_mechanics/crash/conf/model/figconvunet_autoregressive_rollout_training.yaml
new file mode 100644
index 0000000000..f7a9f1f672
--- /dev/null
+++ b/examples/structural_mechanics/crash/conf/model/figconvunet_autoregressive_rollout_training.yaml
@@ -0,0 +1,66 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+_target_: rollout.FIGConvUNetAutoregressiveRolloutTraining
+_convert_: all
+
+# Input/output channels
+in_channels: 5 # velocity (3) + features (F=1) + time (1)
+out_channels: 3 # acceleration (xyz)
+
+# Architecture
+kernel_size: 3
+hidden_channels: [16, 16, 16] # channels at each level
+num_levels: 2 # number of down/up levels
+num_down_blocks: 1
+num_up_blocks: 1
+mlp_channels: [256, 256]
+
+# Spatial domain
+aabb_max: [2.0, 2.0, 2.0]
+aabb_min: [-2.0, -2.0, -2.0]
+voxel_size: null
+
+# Grid resolutions (factorized implicit grids)
+# Format: [memory_format, resolution_tuple]
+resolution_memory_format_pairs:
+ - [b_xc_y_z, [2, 64, 64]]
+ - [b_yc_x_z, [64, 2, 64]]
+ - [b_zc_x_y, [64, 64, 2]]
+
+# Position encoding
+use_rel_pos: true
+use_rel_pos_embed: true
+pos_encode_dim: 16
+
+# Communication and sampling
+communication_types: ["sum"]
+to_point_sample_method: "graphconv"
+neighbor_search_type: "knn"
+knn_k: 16
+reductions: ["mean"]
+
+use_scalar_output: false
+has_input_features: true
+
+# Pooling (only used when use_scalar_output is true; inert here)
+pooling_type: "max"
+pooling_layers: [2]
+
+# Rollout parameters
+num_time_steps: ${training.num_time_steps}
+dt: 5e-3
+initial_vel: 9.22
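For reference, this config can be instantiated on its own. A hedged sketch, assuming it is run from the example directory so `rollout` is importable, and that plain strings are accepted for the memory formats (the `isinstance(mem_fmt, str)` check in the `figconvunet.py` hunk below suggests they are); `${training.num_time_steps}` is patched by hand since no training config is composed:

```python
# Sketch: standalone instantiation of the autoregressive-rollout config.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("conf/model/figconvunet_autoregressive_rollout_training.yaml")
cfg.num_time_steps = 10  # stand-in for ${training.num_time_steps}
model = instantiate(cfg)  # rollout.FIGConvUNetAutoregressiveRolloutTraining
```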
diff --git a/examples/structural_mechanics/crash/conf/model/figconvunet_one_step_rollout.yaml b/examples/structural_mechanics/crash/conf/model/figconvunet_one_step_rollout.yaml
new file mode 100644
index 0000000000..6272057bcf
--- /dev/null
+++ b/examples/structural_mechanics/crash/conf/model/figconvunet_one_step_rollout.yaml
@@ -0,0 +1,66 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+_target_: rollout.FIGConvUNetOneStepRollout
+_convert_: all
+
+# Input/output channels
+in_channels: 4 # velocity (3) + features (F=1)
+out_channels: 3 # acceleration (xyz), integrated to the next-step position
+
+# Architecture
+kernel_size: 3
+hidden_channels: [16, 16, 16] # channels at each level
+num_levels: 2 # number of down/up levels
+num_down_blocks: 1
+num_up_blocks: 1
+mlp_channels: [256, 256]
+
+# Spatial domain
+aabb_max: [2.0, 2.0, 2.0]
+aabb_min: [-2.0, -2.0, -2.0]
+voxel_size: null
+
+# Grid resolutions (factorized implicit grids)
+# Format: Uses res_mem_pair resolver (memory_format_enum, resolution_tuple)
+resolution_memory_format_pairs:
+ - [b_xc_y_z, [2, 64, 64]]
+ - [b_yc_x_z, [64, 2, 64]]
+ - [b_zc_x_y, [64, 64, 2]]
+
+# Position encoding
+use_rel_pos: true
+use_rel_pos_embed: true
+pos_encode_dim: 16
+
+# Communication and sampling
+communication_types: ["sum"]
+to_point_sample_method: "graphconv"
+neighbor_search_type: "knn"
+knn_k: 16
+reductions: ["mean"]
+
+use_scalar_output: false
+has_input_features: true
+
+# Pooling (only used when use_scalar_output is true; inert here)
+pooling_type: "max"
+pooling_layers: [2]
+
+# Rollout parameters
+num_time_steps: ${training.num_time_steps}
+dt: 5e-3
+initial_vel: 9.22
diff --git a/examples/structural_mechanics/crash/conf/model/figconvunet_time_conditional.yaml b/examples/structural_mechanics/crash/conf/model/figconvunet_time_conditional.yaml
index 8884737ce1..56ff2e4ddd 100644
--- a/examples/structural_mechanics/crash/conf/model/figconvunet_time_conditional.yaml
+++ b/examples/structural_mechanics/crash/conf/model/figconvunet_time_conditional.yaml
@@ -53,6 +53,9 @@ neighbor_search_type: "knn"
knn_k: 16
reductions: ["mean"]
+use_scalar_output: false
+has_input_features: true
+
# Pooling (for global features if needed)
pooling_type: "max"
pooling_layers: [2]
diff --git a/examples/structural_mechanics/crash/rollout.py b/examples/structural_mechanics/crash/rollout.py
index bab5f9dd54..9bdaaa1bb1 100644
--- a/examples/structural_mechanics/crash/rollout.py
+++ b/examples/structural_mechanics/crash/rollout.py
@@ -467,3 +467,171 @@ def step_fn(verts, feats):
outputs.append(y_t)
return torch.stack(outputs, dim=0) # [T, N, 3]
+
+
+class FIGConvUNetOneStepRollout(FIGConvUNet):
+ """
+ FIGConvUNet with one-step rollout for crash simulation.
+
+ - Training: teacher forcing (uses GT positions at each step)
+ - Inference: autoregressive (uses predictions)
+ """
+
+ def __init__(self, *args, **kwargs):
+ self.dt: float = kwargs.pop("dt", 5e-3)
+        self.initial_vel = kwargs.pop("initial_vel")  # scalar or [1, 3] tensor
+ self.rollout_steps: int = kwargs.pop("num_time_steps") - 1
+ super().__init__(*args, **kwargs)
+
+ def forward(self, sample: SimSample, data_stats: dict) -> torch.Tensor:
+ """
+ Args:
+            sample: SimSample containing node_features and node_target
+ data_stats: dict containing normalization stats
+ Returns:
+            [T-1, N, 3] rollout of predicted positions (T = num_time_steps)
+ """
+ inputs = sample.node_features
+ x0 = inputs["coords"] # initial pos [N, 3]
+ features = inputs.get("features", x0.new_zeros((x0.size(0), 0))) # [N, F]
+
+ # Ground truth sequence [T, N, 3]
+ N = x0.size(0)
+ gt_seq = torch.cat(
+ [x0.unsqueeze(0), sample.node_target.view(N, -1, 3).transpose(0, 1)],
+ dim=0,
+ )
+
+ outputs: list[torch.Tensor] = []
+ # First step: backstep to create y_-1
+ y_t0 = gt_seq[0] - self.initial_vel * self.dt
+ y_t1 = gt_seq[0]
+
+ for t in range(self.rollout_steps):
+ # In training mode (except first step), use ground truth positions
+ if self.training and t > 0:
+ y_t0, y_t1 = gt_seq[t - 1], gt_seq[t]
+
+ # Prepare vertices for FIGConvUNet: [1, N, 3]
+ vertices = y_t1.unsqueeze(0) # [1, N, 3]
+
+ vel = (y_t1 - y_t0) / self.dt
+ vel_norm = (vel - data_stats["node"]["norm_vel_mean"]) / (
+ data_stats["node"]["norm_vel_std"] + EPS
+ )
+
+ # [1, N, 3 + F]
+ fx_t = torch.cat([vel_norm, features], dim=-1).unsqueeze(0)
+
+ def step_fn(verts, feats):
+ out, _ = super(FIGConvUNetOneStepRollout, self).forward(
+ vertices=verts, features=feats
+ )
+ return out
+
+ if self.training:
+ outf = ckpt(
+ step_fn,
+ vertices,
+ fx_t,
+ use_reentrant=False,
+ ).squeeze(0) # [N, 3]
+ else:
+ outf = step_fn(vertices, fx_t).squeeze(0) # [N, 3]
+
+ acc = (
+ outf * data_stats["node"]["norm_acc_std"]
+ + data_stats["node"]["norm_acc_mean"]
+ )
+ vel_pred = self.dt * acc + vel
+ y_t2_pred = self.dt * vel_pred + y_t1
+
+ outputs.append(y_t2_pred)
+
+ if not self.training:
+ # autoregressive update for inference
+ y_t0, y_t1 = y_t1, y_t2_pred
+
+        return torch.stack(outputs, dim=0)  # [T-1, N, 3]
+
+
+class FIGConvUNetAutoregressiveRolloutTraining(FIGConvUNet):
+ """
+ FIGConvUNet with autoregressive rollout training for crash simulation.
+
+ Predicts sequence by autoregressively updating velocity and position
+ using predicted accelerations. Supports gradient checkpointing during training.
+ """
+
+ def __init__(self, *args, **kwargs):
+ self.dt: float = kwargs.pop("dt")
+        self.initial_vel = kwargs.pop("initial_vel")  # scalar or [1, 3] tensor
+ self.rollout_steps: int = kwargs.pop("num_time_steps") - 1
+ super().__init__(*args, **kwargs)
+
+ def forward(self, sample: SimSample, data_stats: dict) -> torch.Tensor:
+ """
+ Args:
+ sample: SimSample containing node_features and node_target
+ data_stats: dict containing normalization stats
+ Returns:
+            [T-1, N, 3] rollout of predicted positions (T = num_time_steps)
+ """
+ inputs = sample.node_features
+ coords = inputs["coords"] # [N, 3]
+ features = inputs.get("features", coords.new_zeros((coords.size(0), 0)))
+ N = coords.size(0)
+ device = coords.device
+
+ # Initial states
+ y_t1 = coords # [N, 3]
+ y_t0 = y_t1 - self.initial_vel * self.dt # backstep using initial velocity
+
+ outputs: list[torch.Tensor] = []
+ for t in range(self.rollout_steps):
+ time_t = 0.0 if self.rollout_steps <= 1 else t / (self.rollout_steps - 1)
+ time_t = torch.tensor([time_t], device=device, dtype=torch.float32)
+
+ # Velocity normalization
+ vel = (y_t1 - y_t0) / self.dt
+ vel_norm = (vel - data_stats["node"]["norm_vel_mean"]) / (
+ data_stats["node"]["norm_vel_std"] + EPS
+ )
+
+ # Prepare vertices for FIGConvUNet: [1, N, 3]
+ vertices = y_t1.unsqueeze(0) # [1, N, 3]
+
+ # Prepare features: vel_norm + features + time [N, 3+F+1]
+ fx_t = torch.cat(
+ [vel_norm, features, time_t.expand(N, 1)], dim=-1
+ ) # [N, 3+F+1]
+ fx_t = fx_t.unsqueeze(0) # [1, N, 3+F+1]
+
+ def step_fn(verts, feats):
+ out, _ = super(FIGConvUNetAutoregressiveRolloutTraining, self).forward(
+ vertices=verts, features=feats
+ )
+ return out
+
+ if self.training:
+ outf = ckpt(
+ step_fn,
+ vertices,
+ fx_t,
+ use_reentrant=False,
+ ).squeeze(0) # [N, 3]
+ else:
+ outf = step_fn(vertices, fx_t).squeeze(0) # [N, 3]
+
+ # De-normalize acceleration
+ acc = (
+ outf * data_stats["node"]["norm_acc_std"]
+ + data_stats["node"]["norm_acc_mean"]
+ )
+ vel = self.dt * acc + vel
+ y_t2 = self.dt * vel + y_t1
+
+ outputs.append(y_t2)
+ y_t1, y_t0 = y_t2, y_t1
+
+        return torch.stack(outputs, dim=0)  # [T-1, N, 3]
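Both rollout classes share one update rule: de-normalize the predicted acceleration, advance velocity, then advance position with the updated velocity (semi-implicit Euler), and both seed step 0 by backstepping with `initial_vel`. A standalone sketch with illustrative names, mirroring the arithmetic above:

```python
# Sketch of the shared integration step (names are illustrative only).
import torch

DT = 5e-3  # dt from the configs

def integrate_step(acc_norm, vel, pos, acc_mean=0.0, acc_std=1.0, dt=DT):
    acc = acc_norm * acc_std + acc_mean  # undo dataset normalization
    vel = vel + dt * acc                 # velocity update first ...
    pos = pos + dt * vel                 # ... position uses the *new* velocity
    return vel, pos

# Backstep used to synthesize y_{-1} from the initial velocity:
pos0 = torch.zeros(4, 3)
vel0 = torch.full((1, 3), 9.22)  # initial_vel from the configs
prev = pos0 - vel0 * DT          # y_{-1}
vel, pos = integrate_step(torch.zeros(4, 3), (pos0 - prev) / DT, pos0)
```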
diff --git a/examples/structural_mechanics/crash/tests/test_rollout.py b/examples/structural_mechanics/crash/tests/test_rollout.py
index f691ef7e6b..6cb6cc420a 100644
--- a/examples/structural_mechanics/crash/tests/test_rollout.py
+++ b/examples/structural_mechanics/crash/tests/test_rollout.py
@@ -90,6 +90,22 @@ def mgn_forward(self, node_features=None, edge_features=None, graph=None):
monkeypatch.setattr(rollout.MeshGraphNet, "__init__", mgn_init, raising=True)
monkeypatch.setattr(rollout.MeshGraphNet, "forward", mgn_forward, raising=True)
+ # Stub FIGConvUNet.__init__ and FIGConvUNet.forward
+ def figconvunet_init(self, *args, **kwargs):
+ torch.nn.Module.__init__(self)
+
+ def figconvunet_forward(self, vertices=None, features=None):
+ # Return zeros with shape matching vertices
+ # vertices: [B, N, 3], features: [B, N, F]
+ # output: [B, N, 3]
+ assert vertices is not None
+ return torch.zeros_like(vertices), None
+
+ monkeypatch.setattr(rollout.FIGConvUNet, "__init__", figconvunet_init, raising=True)
+ monkeypatch.setattr(
+ rollout.FIGConvUNet, "forward", figconvunet_forward, raising=True
+ )
+
def test_transolver_autoregressive_rollout_eval():
N, T, F = 5, 4, 2
@@ -181,3 +197,43 @@ class DummyGraph:
out = model.forward(sample=sample, data_stats=stats)
assert out.shape == (T - 1, N, 3)
+
+
+def test_figconvunet_time_conditional_rollout_eval():
+ N, T, F = 6, 5, 3
+ sample = make_sample(N=N, T=T, F=F)
+ stats = make_data_stats()
+
+ model = rollout.FIGConvUNetTimeConditionalRollout(num_time_steps=T)
+ model.eval()
+
+ out = model.forward(sample=sample, data_stats=stats)
+ assert out.shape == (T - 1, N, 3)
+
+
+def test_figconvunet_one_step_rollout_eval():
+ N, T, F = 7, 6, 1
+ sample = make_sample(N=N, T=T, F=F)
+ stats = make_data_stats()
+
+ model = rollout.FIGConvUNetOneStepRollout(
+ dt=5e-3, initial_vel=torch.zeros(1, 3), num_time_steps=T
+ )
+ model.eval()
+
+ out = model.forward(sample=sample, data_stats=stats)
+ assert out.shape == (T - 1, N, 3)
+
+
+def test_figconvunet_autoregressive_rollout_eval():
+ N, T, F = 5, 4, 2
+ sample = make_sample(N=N, T=T, F=F)
+ stats = make_data_stats()
+
+ model = rollout.FIGConvUNetAutoregressiveRolloutTraining(
+ dt=5e-3, initial_vel=torch.zeros(1, 3), num_time_steps=T
+ )
+ model.eval()
+
+ out = model.forward(sample=sample, data_stats=stats)
+ assert out.shape == (T - 1, N, 3)
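Because the tests stub `FIGConvUNet.__init__` and `FIGConvUNet.forward`, they run without building the real model. A hedged convenience runner (path assumed relative to `examples/structural_mechanics/crash`):

```python
# Sketch: run only the new FIGConvUNet rollout tests.
import sys

import pytest

sys.exit(pytest.main(["-q", "tests/test_rollout.py", "-k", "figconvunet"]))
```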
diff --git a/physicsnemo/models/figconvnet/figconvunet.py b/physicsnemo/models/figconvnet/figconvunet.py
index 78c87f5699..23ab616168 100644
--- a/physicsnemo/models/figconvnet/figconvunet.py
+++ b/physicsnemo/models/figconvnet/figconvunet.py
@@ -149,7 +149,8 @@ def __init__(
neighbor_search_type: Literal["knn", "radius"] = "radius",
knn_k: int = 16,
reductions: List[REDUCTION_TYPES] = ["mean"],
- drag_loss_weight: Optional[float] = None,
+ use_scalar_output: bool = True,
+ has_input_features: bool = False,
pooling_type: Literal["attention", "max", "mean"] = "max",
pooling_layers: List[int] = None,
):
@@ -163,6 +164,7 @@ def __init__(
self.point_feature_to_grids = nn.ModuleList()
self.aabb_length = torch.tensor(aabb_max) - torch.tensor(aabb_min)
self.min_voxel_edge_length = torch.tensor([np.inf, np.inf, np.inf])
+ self.use_scalar_output = use_scalar_output
for mem_fmt, res in resolution_memory_format_pairs:
if isinstance(mem_fmt, str):
@@ -256,37 +258,40 @@ def __init__(
memory_format=GridFeaturesMemoryFormat.b_x_y_z_c
)
- if pooling_layers is None:
- pooling_layers = [num_levels]
- else:
- assert isinstance(pooling_layers, list), (
- f"pooling_layers must be a list, got {type(pooling_layers)}."
- )
- for layer in pooling_layers:
- assert layer <= num_levels, (
- f"pooling_layer {layer} is greater than num_levels {num_levels}."
+ if use_scalar_output:
+ if pooling_layers is None:
+ pooling_layers = [num_levels]
+ else:
+ assert isinstance(pooling_layers, list), (
+ f"pooling_layers must be a list, got {type(pooling_layers)}."
+ )
+ for layer in pooling_layers:
+ assert layer <= num_levels, (
+ f"pooling_layer {layer} is greater than num_levels {num_levels}."
+ )
+ self.pooling_layers = pooling_layers
+ grid_pools = [
+ GridFeatureGroupPool(
+ in_channels=hidden_channels[layer],
+ out_channels=mlp_channels[0],
+ compressed_spatial_dims=self.compressed_spatial_dims,
+ pooling_type=pooling_type,
)
- self.pooling_layers = pooling_layers
- grid_pools = [
- GridFeatureGroupPool(
- in_channels=hidden_channels[layer],
- out_channels=mlp_channels[0],
- compressed_spatial_dims=self.compressed_spatial_dims,
- pooling_type=pooling_type,
+ for layer in pooling_layers
+ ]
+ self.grid_pools = nn.ModuleList(grid_pools)
+
+ self.mlp = MLP(
+ mlp_channels[0]
+ * len(self.compressed_spatial_dims)
+ * len(pooling_layers),
+ mlp_channels[-1],
+ mlp_channels,
+ use_residual=True,
+ activation=nn.GELU,
)
- for layer in pooling_layers
- ]
- self.grid_pools = nn.ModuleList(grid_pools)
-
- self.mlp = MLP(
- mlp_channels[0] * len(self.compressed_spatial_dims) * len(pooling_layers),
- mlp_channels[-1],
- mlp_channels,
- use_residual=True,
- activation=nn.GELU,
- )
- self.mlp_projection = nn.Linear(mlp_channels[-1], 1)
- # nn.Sigmoid(),
+ self.mlp_projection = nn.Linear(mlp_channels[-1], 1)
+ # nn.Sigmoid(),
self.to_point = GridFeatureGroupToPoint(
grid_in_channels=hidden_channels[0],
@@ -314,16 +319,13 @@ def __init__(
self.pad_to_match = GridFeatureGroupPadToMatch()
- vertex_to_point_features = VerticesToPointFeatures(
- embed_dim=pos_encode_dim,
- out_features=hidden_channels[0],
- use_mlp=True,
- pos_embed_range=aabb_max[0] - aabb_min[0],
- )
-
- self.vertex_to_point_features = vertex_to_point_features
- if drag_loss_weight is not None:
- self.drag_loss_weight = drag_loss_weight
+ if not has_input_features:
+ self.vertex_to_point_features = VerticesToPointFeatures(
+ embed_dim=pos_encode_dim,
+ out_features=hidden_channels[0],
+ use_mlp=True,
+ pos_embed_range=aabb_max[0] - aabb_min[0],
+ )
@profile
def _grid_forward(self, point_features: PointFeatures):
@@ -335,15 +337,17 @@ def _grid_forward(self, point_features: PointFeatures):
out_features = down_block(down_grid_feature_groups[-1])
down_grid_feature_groups.append(out_features)
- # Drag prediction
- pooled_feats = []
- for grid_pool, layer in zip(self.grid_pools, self.pooling_layers):
- pooled_feats.append(grid_pool(down_grid_feature_groups[layer]))
- if len(pooled_feats) > 1:
- pooled_feats = torch.cat(pooled_feats, dim=-1)
- else:
- pooled_feats = pooled_feats[0]
- drag_pred = self.mlp_projection(self.mlp(pooled_feats))
+ drag_pred = None
+ if self.use_scalar_output:
+ # Drag prediction
+ pooled_feats = []
+ for grid_pool, layer in zip(self.grid_pools, self.pooling_layers):
+ pooled_feats.append(grid_pool(down_grid_feature_groups[layer]))
+ if len(pooled_feats) > 1:
+ pooled_feats = torch.cat(pooled_feats, dim=-1)
+ else:
+ pooled_feats = pooled_feats[0]
+ drag_pred = self.mlp_projection(self.mlp(pooled_feats))
for level in reversed(range(self.num_levels)):
up_grid_features = self.up_blocks[level](
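Taken together: `use_scalar_output=False` skips building the pooling layers, the MLP, and `mlp_projection`, and `_grid_forward` returns `drag_pred=None`; `has_input_features=True` skips the `VerticesToPointFeatures` embedding because the caller supplies per-point features. A hedged construction sketch; only arguments visible in this diff are spelled out, and the remaining defaults are assumed to be workable:

```python
# Sketch: field-output-only FIGConvUNet (no scalar/drag head).
from physicsnemo.models.figconvnet.figconvunet import FIGConvUNet

model = FIGConvUNet(
    in_channels=4,
    out_channels=3,
    hidden_channels=[16, 16, 16],
    num_levels=2,
    use_scalar_output=False,  # pooling + MLP projection head never built
    has_input_features=True,  # no VerticesToPointFeatures embedding
)
# forward(vertices=[B, N, 3], features=[B, N, F]) should then return
# (field_out, None): the second element is the now-absent scalar prediction.
```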