diff --git a/examples/structural_mechanics/crash/README.md b/examples/structural_mechanics/crash/README.md
index 6b8a5ebb3e..f3e8d97833 100644
--- a/examples/structural_mechanics/crash/README.md
+++ b/examples/structural_mechanics/crash/README.md
@@ -19,14 +19,14 @@ For an in-depth comparison between the Transolver and MeshGraphNet models and th

Crash animation

### Crushcan Modeling

Crushcan animation

## Quickstart

@@ -238,7 +238,10 @@ conf/
│   ├── mgn_time_conditional.yaml
│   ├── transolver_autoregressive_rollout_training.yaml
│   ├── transolver_one_step_rollout.yaml
-│   └── transolver_time_conditional.yaml
+│   ├── transolver_time_conditional.yaml
+│   ├── figconvunet_autoregressive_rollout_training.yaml
+│   ├── figconvunet_one_step_rollout.yaml
+│   └── figconvunet_time_conditional.yaml
├── training/default.yaml     # training hyperparameters
└── inference/default.yaml    # inference options
```
@@ -495,7 +498,7 @@ run_post_processing.sh can automate all evaluation tasks across runs.
- AMP is enabled by default in training; it reduces memory and accelerates matmuls on modern GPUs.
- For multi-GPU training, use `torchrun --standalone --nproc_per_node=<N> train.py`.
-- For DDP, prefer `torchrun --standalone --nproc_per_node=<N> train.py`.
+- For DDP, prefer `torchrun --standalone --nproc_per_node=<N> train.py`.

## Troubleshooting / FAQ

diff --git a/examples/structural_mechanics/crash/conf/model/figconvunet_autoregressive_rollout_training.yaml b/examples/structural_mechanics/crash/conf/model/figconvunet_autoregressive_rollout_training.yaml
new file mode 100644
index 0000000000..f7a9f1f672
--- /dev/null
+++ b/examples/structural_mechanics/crash/conf/model/figconvunet_autoregressive_rollout_training.yaml
@@ -0,0 +1,66 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
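+# NOTE: assuming train.py composes conf/ via Hydra (as the conf/ tree above
+# suggests), this config would be selected with
+#   python train.py model=figconvunet_autoregressive_rollout_training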
+
+_target_: rollout.FIGConvUNetAutoregressiveRolloutTraining
+_convert_: all
+
+# Input/output channels
+in_channels: 5  # velocity(3) + features(F) + time(1)
+out_channels: 3  # acceleration (xyz)
+
+# Architecture
+kernel_size: 3
+hidden_channels: [16, 16, 16]  # channels at each level
+num_levels: 2  # number of down/up levels
+num_down_blocks: 1
+num_up_blocks: 1
+mlp_channels: [256, 256]
+
+# Spatial domain
+aabb_max: [2.0, 2.0, 2.0]
+aabb_min: [-2.0, -2.0, -2.0]
+voxel_size: null
+
+# Grid resolutions (factorized implicit grids)
+# Format: [memory_format, resolution_tuple]
+resolution_memory_format_pairs:
+  - [b_xc_y_z, [2, 64, 64]]
+  - [b_yc_x_z, [64, 2, 64]]
+  - [b_zc_x_y, [64, 64, 2]]
+
+# Position encoding
+use_rel_pos: true
+use_rel_pos_embed: true
+pos_encode_dim: 16
+
+# Communication and sampling
+communication_types: ["sum"]
+to_point_sample_method: "graphconv"
+neighbor_search_type: "knn"
+knn_k: 16
+reductions: ["mean"]
+
+use_scalar_output: false
+has_input_features: true
+
+# Pooling (for global features if needed)
+pooling_type: "max"
+pooling_layers: [2]
+
+# Rollout parameters
+num_time_steps: ${training.num_time_steps}
+dt: 5e-3
+initial_vel: 9.22
diff --git a/examples/structural_mechanics/crash/conf/model/figconvunet_one_step_rollout.yaml b/examples/structural_mechanics/crash/conf/model/figconvunet_one_step_rollout.yaml
new file mode 100644
index 0000000000..6272057bcf
--- /dev/null
+++ b/examples/structural_mechanics/crash/conf/model/figconvunet_one_step_rollout.yaml
@@ -0,0 +1,66 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
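+# NOTE: unlike the time-conditional and autoregressive-rollout configs, the
+# one-step model receives no normalized-time channel, so in_channels below is
+# velocity(3) + features(F) rather than velocity(3) + features(F) + time(1).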
+
+_target_: rollout.FIGConvUNetOneStepRollout
+_convert_: all
+
+# Input/output channels
+in_channels: 4  # velocity(3) + features(F)
+out_channels: 3  # acceleration (xyz); next-step positions are integrated from it
+
+# Architecture
+kernel_size: 3
+hidden_channels: [16, 16, 16]  # channels at each level
+num_levels: 2  # number of down/up levels
+num_down_blocks: 1
+num_up_blocks: 1
+mlp_channels: [256, 256]
+
+# Spatial domain
+aabb_max: [2.0, 2.0, 2.0]
+aabb_min: [-2.0, -2.0, -2.0]
+voxel_size: null
+
+# Grid resolutions (factorized implicit grids)
+# Format: uses the res_mem_pair resolver (memory_format_enum, resolution_tuple)
+resolution_memory_format_pairs:
+  - [b_xc_y_z, [2, 64, 64]]
+  - [b_yc_x_z, [64, 2, 64]]
+  - [b_zc_x_y, [64, 64, 2]]
+
+# Position encoding
+use_rel_pos: true
+use_rel_pos_embed: true
+pos_encode_dim: 16
+
+# Communication and sampling
+communication_types: ["sum"]
+to_point_sample_method: "graphconv"
+neighbor_search_type: "knn"
+knn_k: 16
+reductions: ["mean"]
+
+use_scalar_output: false
+has_input_features: true
+
+# Pooling (for global features if needed)
+pooling_type: "max"
+pooling_layers: [2]
+
+# Rollout parameters
+num_time_steps: ${training.num_time_steps}
+dt: 5e-3
+initial_vel: 9.22
diff --git a/examples/structural_mechanics/crash/conf/model/figconvunet_time_conditional.yaml b/examples/structural_mechanics/crash/conf/model/figconvunet_time_conditional.yaml
index 8884737ce1..56ff2e4ddd 100644
--- a/examples/structural_mechanics/crash/conf/model/figconvunet_time_conditional.yaml
+++ b/examples/structural_mechanics/crash/conf/model/figconvunet_time_conditional.yaml
@@ -53,6 +53,9 @@ neighbor_search_type: "knn"
knn_k: 16
reductions: ["mean"]
+use_scalar_output: false
+has_input_features: true
+
# Pooling (for global features if needed)
pooling_type: "max"
pooling_layers: [2]
diff --git a/examples/structural_mechanics/crash/rollout.py b/examples/structural_mechanics/crash/rollout.py
index bab5f9dd54..9bdaaa1bb1 100644
--- a/examples/structural_mechanics/crash/rollout.py
+++ b/examples/structural_mechanics/crash/rollout.py
@@ -467,3 +467,171 @@ def step_fn(verts, feats):
            outputs.append(y_t)
        return torch.stack(outputs, dim=0)  # [T, N, 3]
+
+
+class FIGConvUNetOneStepRollout(FIGConvUNet):
+    """
+    FIGConvUNet with one-step rollout for crash simulation.
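+
+    Each step, the network predicts a normalized acceleration; the position is
+    then advanced with a semi-implicit Euler update, v' = v + dt * a and
+    y' = y + dt * v'. The first step backsteps y_prev = y0 - initial_vel * dt
+    to form the initial velocity.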
+
+    - Training: teacher forcing (uses GT positions at each step)
+    - Inference: autoregressive (uses predictions)
+    """
+
+    def __init__(self, *args, **kwargs):
+        self.dt: float = kwargs.pop("dt", 5e-3)
+        self.initial_vel: torch.Tensor = kwargs.pop("initial_vel")
+        self.rollout_steps: int = kwargs.pop("num_time_steps") - 1
+        super().__init__(*args, **kwargs)
+
+    def forward(self, sample: SimSample, data_stats: dict) -> torch.Tensor:
+        """
+        Args:
+            sample: SimSample containing node_features and node_target
+            data_stats: dict containing normalization stats
+        Returns:
+            [T, N, 3] rollout of predicted positions
+        """
+        inputs = sample.node_features
+        x0 = inputs["coords"]  # initial pos [N, 3]
+        features = inputs.get("features", x0.new_zeros((x0.size(0), 0)))  # [N, F]
+
+        # Ground truth sequence [T, N, 3]
+        N = x0.size(0)
+        gt_seq = torch.cat(
+            [x0.unsqueeze(0), sample.node_target.view(N, -1, 3).transpose(0, 1)],
+            dim=0,
+        )
+
+        outputs: list[torch.Tensor] = []
+        # First step: backstep to create y_-1
+        y_t0 = gt_seq[0] - self.initial_vel * self.dt
+        y_t1 = gt_seq[0]
+
+        for t in range(self.rollout_steps):
+            # In training mode (except first step), use ground truth positions
+            if self.training and t > 0:
+                y_t0, y_t1 = gt_seq[t - 1], gt_seq[t]
+
+            # Prepare vertices for FIGConvUNet: [1, N, 3]
+            vertices = y_t1.unsqueeze(0)  # [1, N, 3]
+
+            vel = (y_t1 - y_t0) / self.dt
+            vel_norm = (vel - data_stats["node"]["norm_vel_mean"]) / (
+                data_stats["node"]["norm_vel_std"] + EPS
+            )
+
+            # [1, N, 3 + F]
+            fx_t = torch.cat([vel_norm, features], dim=-1).unsqueeze(0)
+
+            def step_fn(verts, feats):
+                out, _ = super(FIGConvUNetOneStepRollout, self).forward(
+                    vertices=verts, features=feats
+                )
+                return out
+
+            if self.training:
+                outf = ckpt(
+                    step_fn,
+                    vertices,
+                    fx_t,
+                    use_reentrant=False,
+                ).squeeze(0)  # [N, 3]
+            else:
+                outf = step_fn(vertices, fx_t).squeeze(0)  # [N, 3]
+
+            acc = (
+                outf * data_stats["node"]["norm_acc_std"]
+                + data_stats["node"]["norm_acc_mean"]
+            )
+            vel_pred = self.dt * acc + vel
+            y_t2_pred = self.dt * vel_pred + y_t1
+
+            outputs.append(y_t2_pred)
+
+            if not self.training:
+                # autoregressive update for inference
+                y_t0, y_t1 = y_t1, y_t2_pred
+
+        return torch.stack(outputs, dim=0)  # [T, N, 3]
+
+
+class FIGConvUNetAutoregressiveRolloutTraining(FIGConvUNet):
+    """
+    FIGConvUNet with autoregressive rollout training for crash simulation.
+
+    Predicts the sequence by autoregressively updating velocity and position
+    using predicted accelerations. Supports gradient checkpointing during training.
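+
+    A normalized time channel, t / (rollout_steps - 1) in [0, 1], is appended to
+    the per-node features at every step.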
+ """ + + def __init__(self, *args, **kwargs): + self.dt: float = kwargs.pop("dt") + self.initial_vel: torch.Tensor = kwargs.pop("initial_vel") + self.rollout_steps: int = kwargs.pop("num_time_steps") - 1 + super().__init__(*args, **kwargs) + + def forward(self, sample: SimSample, data_stats: dict) -> torch.Tensor: + """ + Args: + sample: SimSample containing node_features and node_target + data_stats: dict containing normalization stats + Returns: + [T, N, 3] rollout of predicted positions + """ + inputs = sample.node_features + coords = inputs["coords"] # [N, 3] + features = inputs.get("features", coords.new_zeros((coords.size(0), 0))) + N = coords.size(0) + device = coords.device + + # Initial states + y_t1 = coords # [N, 3] + y_t0 = y_t1 - self.initial_vel * self.dt # backstep using initial velocity + + outputs: list[torch.Tensor] = [] + for t in range(self.rollout_steps): + time_t = 0.0 if self.rollout_steps <= 1 else t / (self.rollout_steps - 1) + time_t = torch.tensor([time_t], device=device, dtype=torch.float32) + + # Velocity normalization + vel = (y_t1 - y_t0) / self.dt + vel_norm = (vel - data_stats["node"]["norm_vel_mean"]) / ( + data_stats["node"]["norm_vel_std"] + EPS + ) + + # Prepare vertices for FIGConvUNet: [1, N, 3] + vertices = y_t1.unsqueeze(0) # [1, N, 3] + + # Prepare features: vel_norm + features + time [N, 3+F+1] + fx_t = torch.cat( + [vel_norm, features, time_t.expand(N, 1)], dim=-1 + ) # [N, 3+F+1] + fx_t = fx_t.unsqueeze(0) # [1, N, 3+F+1] + + def step_fn(verts, feats): + out, _ = super(FIGConvUNetAutoregressiveRolloutTraining, self).forward( + vertices=verts, features=feats + ) + return out + + if self.training: + outf = ckpt( + step_fn, + vertices, + fx_t, + use_reentrant=False, + ).squeeze(0) # [N, 3] + else: + outf = step_fn(vertices, fx_t).squeeze(0) # [N, 3] + + # De-normalize acceleration + acc = ( + outf * data_stats["node"]["norm_acc_std"] + + data_stats["node"]["norm_acc_mean"] + ) + vel = self.dt * acc + vel + y_t2 = self.dt * vel + y_t1 + + outputs.append(y_t2) + y_t1, y_t0 = y_t2, y_t1 + + return torch.stack(outputs, dim=0) # [T, N, 3] diff --git a/examples/structural_mechanics/crash/tests/test_rollout.py b/examples/structural_mechanics/crash/tests/test_rollout.py index f691ef7e6b..6cb6cc420a 100644 --- a/examples/structural_mechanics/crash/tests/test_rollout.py +++ b/examples/structural_mechanics/crash/tests/test_rollout.py @@ -90,6 +90,22 @@ def mgn_forward(self, node_features=None, edge_features=None, graph=None): monkeypatch.setattr(rollout.MeshGraphNet, "__init__", mgn_init, raising=True) monkeypatch.setattr(rollout.MeshGraphNet, "forward", mgn_forward, raising=True) + # Stub FIGConvUNet.__init__ and FIGConvUNet.forward + def figconvunet_init(self, *args, **kwargs): + torch.nn.Module.__init__(self) + + def figconvunet_forward(self, vertices=None, features=None): + # Return zeros with shape matching vertices + # vertices: [B, N, 3], features: [B, N, F] + # output: [B, N, 3] + assert vertices is not None + return torch.zeros_like(vertices), None + + monkeypatch.setattr(rollout.FIGConvUNet, "__init__", figconvunet_init, raising=True) + monkeypatch.setattr( + rollout.FIGConvUNet, "forward", figconvunet_forward, raising=True + ) + def test_transolver_autoregressive_rollout_eval(): N, T, F = 5, 4, 2 @@ -181,3 +197,43 @@ class DummyGraph: out = model.forward(sample=sample, data_stats=stats) assert out.shape == (T - 1, N, 3) + + +def test_figconvunet_time_conditional_rollout_eval(): + N, T, F = 6, 5, 3 + sample = make_sample(N=N, T=T, F=F) + stats = 
make_data_stats() + + model = rollout.FIGConvUNetTimeConditionalRollout(num_time_steps=T) + model.eval() + + out = model.forward(sample=sample, data_stats=stats) + assert out.shape == (T - 1, N, 3) + + +def test_figconvunet_one_step_rollout_eval(): + N, T, F = 7, 6, 1 + sample = make_sample(N=N, T=T, F=F) + stats = make_data_stats() + + model = rollout.FIGConvUNetOneStepRollout( + dt=5e-3, initial_vel=torch.zeros(1, 3), num_time_steps=T + ) + model.eval() + + out = model.forward(sample=sample, data_stats=stats) + assert out.shape == (T - 1, N, 3) + + +def test_figconvunet_autoregressive_rollout_eval(): + N, T, F = 5, 4, 2 + sample = make_sample(N=N, T=T, F=F) + stats = make_data_stats() + + model = rollout.FIGConvUNetAutoregressiveRolloutTraining( + dt=5e-3, initial_vel=torch.zeros(1, 3), num_time_steps=T + ) + model.eval() + + out = model.forward(sample=sample, data_stats=stats) + assert out.shape == (T - 1, N, 3) diff --git a/physicsnemo/models/figconvnet/figconvunet.py b/physicsnemo/models/figconvnet/figconvunet.py index 78c87f5699..23ab616168 100644 --- a/physicsnemo/models/figconvnet/figconvunet.py +++ b/physicsnemo/models/figconvnet/figconvunet.py @@ -149,7 +149,8 @@ def __init__( neighbor_search_type: Literal["knn", "radius"] = "radius", knn_k: int = 16, reductions: List[REDUCTION_TYPES] = ["mean"], - drag_loss_weight: Optional[float] = None, + use_scalar_output: bool = True, + has_input_features: bool = False, pooling_type: Literal["attention", "max", "mean"] = "max", pooling_layers: List[int] = None, ): @@ -163,6 +164,7 @@ def __init__( self.point_feature_to_grids = nn.ModuleList() self.aabb_length = torch.tensor(aabb_max) - torch.tensor(aabb_min) self.min_voxel_edge_length = torch.tensor([np.inf, np.inf, np.inf]) + self.use_scalar_output = use_scalar_output for mem_fmt, res in resolution_memory_format_pairs: if isinstance(mem_fmt, str): @@ -256,37 +258,40 @@ def __init__( memory_format=GridFeaturesMemoryFormat.b_x_y_z_c ) - if pooling_layers is None: - pooling_layers = [num_levels] - else: - assert isinstance(pooling_layers, list), ( - f"pooling_layers must be a list, got {type(pooling_layers)}." - ) - for layer in pooling_layers: - assert layer <= num_levels, ( - f"pooling_layer {layer} is greater than num_levels {num_levels}." + if use_scalar_output: + if pooling_layers is None: + pooling_layers = [num_levels] + else: + assert isinstance(pooling_layers, list), ( + f"pooling_layers must be a list, got {type(pooling_layers)}." + ) + for layer in pooling_layers: + assert layer <= num_levels, ( + f"pooling_layer {layer} is greater than num_levels {num_levels}." 
-        if pooling_layers is None:
-            pooling_layers = [num_levels]
-        else:
-            assert isinstance(pooling_layers, list), (
-                f"pooling_layers must be a list, got {type(pooling_layers)}."
-            )
-            for layer in pooling_layers:
-                assert layer <= num_levels, (
-                    f"pooling_layer {layer} is greater than num_levels {num_levels}."
-                )
-        self.pooling_layers = pooling_layers
-        grid_pools = [
-            GridFeatureGroupPool(
-                in_channels=hidden_channels[layer],
-                out_channels=mlp_channels[0],
-                compressed_spatial_dims=self.compressed_spatial_dims,
-                pooling_type=pooling_type,
-            )
-            for layer in pooling_layers
-        ]
-        self.grid_pools = nn.ModuleList(grid_pools)
-
-        self.mlp = MLP(
-            mlp_channels[0] * len(self.compressed_spatial_dims) * len(pooling_layers),
-            mlp_channels[-1],
-            mlp_channels,
-            use_residual=True,
-            activation=nn.GELU,
-        )
-        self.mlp_projection = nn.Linear(mlp_channels[-1], 1)
-        # nn.Sigmoid(),
+        if use_scalar_output:
+            if pooling_layers is None:
+                pooling_layers = [num_levels]
+            else:
+                assert isinstance(pooling_layers, list), (
+                    f"pooling_layers must be a list, got {type(pooling_layers)}."
+                )
+                for layer in pooling_layers:
+                    assert layer <= num_levels, (
+                        f"pooling_layer {layer} is greater than num_levels {num_levels}."
+                    )
+            self.pooling_layers = pooling_layers
+            grid_pools = [
+                GridFeatureGroupPool(
+                    in_channels=hidden_channels[layer],
+                    out_channels=mlp_channels[0],
+                    compressed_spatial_dims=self.compressed_spatial_dims,
+                    pooling_type=pooling_type,
+                )
+                for layer in pooling_layers
+            ]
+            self.grid_pools = nn.ModuleList(grid_pools)
+
+            self.mlp = MLP(
+                mlp_channels[0]
+                * len(self.compressed_spatial_dims)
+                * len(pooling_layers),
+                mlp_channels[-1],
+                mlp_channels,
+                use_residual=True,
+                activation=nn.GELU,
+            )
+            self.mlp_projection = nn.Linear(mlp_channels[-1], 1)
+            # nn.Sigmoid(),

        self.to_point = GridFeatureGroupToPoint(
            grid_in_channels=hidden_channels[0],
@@ -314,16 +319,13 @@ def __init__(
        self.pad_to_match = GridFeatureGroupPadToMatch()

-        vertex_to_point_features = VerticesToPointFeatures(
-            embed_dim=pos_encode_dim,
-            out_features=hidden_channels[0],
-            use_mlp=True,
-            pos_embed_range=aabb_max[0] - aabb_min[0],
-        )
-
-        self.vertex_to_point_features = vertex_to_point_features
-        if drag_loss_weight is not None:
-            self.drag_loss_weight = drag_loss_weight
+        if not has_input_features:
+            self.vertex_to_point_features = VerticesToPointFeatures(
+                embed_dim=pos_encode_dim,
+                out_features=hidden_channels[0],
+                use_mlp=True,
+                pos_embed_range=aabb_max[0] - aabb_min[0],
+            )

    @profile
    def _grid_forward(self, point_features: PointFeatures):
@@ -335,15 +337,17 @@ def _grid_forward(self, point_features: PointFeatures):
        out_features = down_block(down_grid_feature_groups[-1])
        down_grid_feature_groups.append(out_features)

-        # Drag prediction
-        pooled_feats = []
-        for grid_pool, layer in zip(self.grid_pools, self.pooling_layers):
-            pooled_feats.append(grid_pool(down_grid_feature_groups[layer]))
-        if len(pooled_feats) > 1:
-            pooled_feats = torch.cat(pooled_feats, dim=-1)
-        else:
-            pooled_feats = pooled_feats[0]
-        drag_pred = self.mlp_projection(self.mlp(pooled_feats))
+        drag_pred = None
+        if self.use_scalar_output:
+            # Drag prediction
+            pooled_feats = []
+            for grid_pool, layer in zip(self.grid_pools, self.pooling_layers):
+                pooled_feats.append(grid_pool(down_grid_feature_groups[layer]))
+            if len(pooled_feats) > 1:
+                pooled_feats = torch.cat(pooled_feats, dim=-1)
+            else:
+                pooled_feats = pooled_feats[0]
+            drag_pred = self.mlp_projection(self.mlp(pooled_feats))

        for level in reversed(range(self.num_levels)):
            up_grid_features = self.up_blocks[level](
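For context, a minimal sketch of how the new constructor flags are exercised by the crash configs above. It assumes `FIGConvUNet` is importable from `physicsnemo.models.figconvnet.figconvunet` (the file this patch edits) and that the keyword arguments mirror `figconvunet_one_step_rollout.yaml`; treat it as an illustration of `use_scalar_output` / `has_input_features`, not part of the patch:

```python
import torch
from physicsnemo.models.figconvnet.figconvunet import FIGConvUNet

# Mirrors figconvunet_one_step_rollout.yaml: the scalar (drag) head and the
# VerticesToPointFeatures embedding are both skipped.
model = FIGConvUNet(
    in_channels=4,            # velocity(3) + features(1)
    out_channels=3,           # acceleration (xyz)
    kernel_size=3,
    hidden_channels=[16, 16, 16],
    num_levels=2,
    num_down_blocks=1,
    num_up_blocks=1,
    mlp_channels=[256, 256],
    aabb_max=(2.0, 2.0, 2.0),
    aabb_min=(-2.0, -2.0, -2.0),
    resolution_memory_format_pairs=[
        ("b_xc_y_z", (2, 64, 64)),   # strings are converted internally
        ("b_yc_x_z", (64, 2, 64)),
        ("b_zc_x_y", (64, 64, 2)),
    ],
    use_rel_pos=True,
    use_rel_pos_embed=True,
    pos_encode_dim=16,
    communication_types=["sum"],
    to_point_sample_method="graphconv",
    neighbor_search_type="knn",
    knn_k=16,
    reductions=["mean"],
    use_scalar_output=False,  # no pooling/MLP drag branch is built
    has_input_features=True,  # caller supplies per-point features
)

vertices = torch.rand(1, 128, 3) * 4.0 - 2.0  # [B, N, 3], inside the AABB
features = torch.randn(1, 128, 4)             # [B, N, in_channels]
out, drag = model(vertices=vertices, features=features)
# out: [B, N, out_channels]; drag is None because use_scalar_output=False
```

This matches how the rollout wrappers call `super().forward(vertices=..., features=...)` and discard the second (scalar) output.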