diff --git a/docs/img/crash/crash_case4_reduced.gif b/docs/img/crash/crash_case4_reduced.gif
new file mode 100644
index 0000000000..95187bb882
Binary files /dev/null and b/docs/img/crash/crash_case4_reduced.gif differ
diff --git a/docs/img/crash/crushcan.gif b/docs/img/crash/crushcan.gif
new file mode 100644
index 0000000000..f7eeabaeff
Binary files /dev/null and b/docs/img/crash/crushcan.gif differ
diff --git a/docs/img/crash/roof_crash.gif b/docs/img/crash/roof_crash.gif
new file mode 100644
index 0000000000..bc96fa4e1e
Binary files /dev/null and b/docs/img/crash/roof_crash.gif differ
diff --git a/examples/structural_mechanics/crash/README.md b/examples/structural_mechanics/crash/README.md
index d13b32de05..fda64ab21c 100644
--- a/examples/structural_mechanics/crash/README.md
+++ b/examples/structural_mechanics/crash/README.md
@@ -1,26 +1,115 @@
-# Machine Learning Surrogates for Automotive Crash Dynamics
+# Machine Learning Surrogates for Automotive Crash Dynamics 🧱💥🚗
## Problem Overview
-Automotive crashworthiness assessment is a critical step in vehicle design.
-Traditionally, engineers rely on high-fidelity finite element (FE)
-simulations (e.g., LS-DYNA) to predict structural deformation and crash responses.
-While accurate, these simulations are computationally expensive and
-limit the speed of design iterations.
+Automotive crashworthiness assessment is a critical step in vehicle design. Traditionally, engineers rely on high-fidelity finite element (FE) simulations (e.g., LS-DYNA) to predict structural deformation and crash responses. While accurate, these simulations are computationally expensive and limit the speed of design iterations.
-Machine Learning (ML) surrogates provide a promising alternative by learning
-mappings directly from simulation data, enabling:
+Machine Learning (ML) surrogates provide a promising alternative by learning mappings directly from simulation data, enabling:
- **Rapid prediction** of deformation histories across thousands of design candidates.
- **Scalability** to large structural models without rerunning costly FE simulations.
- **Flexibility** in experimenting with different model architectures (GNNs, Transformers).
-In this example, we demonstrate a unified pipeline for crash dynamics modeling.
-The implementation supports both:
+In this example, we demonstrate a unified pipeline for crash dynamics modeling. The implementation supports Transolver and MeshGraphNet architectures with multiple rollout schemes. It supports multiple dataset formats including d3plot and VTP. The design is highly modular, enabling users to write their own readers, bring their own architectures, or implement custom rollout/transient schemes.
-- **Mesh-based Graph Neural Networks (MeshGraphNet)** – leverage connectivity from FE meshes.
-- **Point-cloud Transformers (Transolver)** – avoid explicit mesh dependency.
+For an in-depth comparison between the Transolver and MeshGraphNet models and the transient schemes for crash dynamics, see [this paper](https://arxiv.org/pdf/2510.15201).
+
+### Body-in-White Crash Modeling
+
+
+
+
+
+
+### Crushcan Modeling
+
+
+
+
+
+
+### Roof Crash Modeling
+
+
+
+
+
+
+## Quickstart
+
+1) Select your recipe (reader, datapipe, model) in `conf/config.yaml`.
+
+```yaml
+# conf/config.yaml
+defaults:
+ - reader: vtp # or d3plot, or your custom reader
+ - datapipe: point_cloud # or graph
+ - model: transolver_time_conditional # or an MGN variant
+ - training: default
+ - inference: default
+ - _self_
+```
+
+2) Point to your datasets and core training knobs.
+
+- `conf/training/default.yaml`:
+ - `raw_data_dir`: path to TRAIN runs (folder of run folders for d3plot, or folder of .vtp files for VTP)
+ - `num_time_steps`: number of frames to use per run
+ - `num_training_samples`: how many runs to load
+
+```yaml
+# conf/training/default.yaml
+raw_data_dir: "/path/to/train" # REQUIRED: change this
+num_time_steps: 14 # adjust to your data
+num_training_samples: 8 # adjust to available runs
+```
+
+- `conf/inference/default.yaml`:
+ - `raw_data_dir_test`: path to TEST runs
+ - `output_dir_pred`/`output_dir_exact`: where to write predicted/exact VTPs
+
+```yaml
+# conf/inference/default.yaml
+raw_data_dir_test: "/path/to/test" # REQUIRED: change this
+```
+
+3) Configure the datapipe features list (order matters and defines columns of `x['features']`).
+
+```yaml
+# conf/datapipe/point_cloud.yaml (same keys for graph.yaml)
+features: [thickness] # or [] for no features; preserve order if adding more
+```
+
+4) Reader‑specific options (optional).
+
+- d3plot: `conf/reader/d3plot.yaml` → `wall_node_disp_threshold`
+
+5) Model config: ensure input dimensions match your features.
+
+- Transolver (time‑conditional): set `functional_dim = len(features)` and `embedding_dim = 3`;
+
+```yaml
+# conf/model/transolver_time_conditional.yaml
+functional_dim: 1 # e.g., 1 if features: [thickness]
+embedding_dim: 3
+time_input: true
+```
+
+6) Launch training.
+
+```bash
+python train.py # single GPU
+torchrun --standalone --nproc_per_node=4 train.py # multi-GPU (DDP)
+```
+
+7) Run inference.
+
+```bash
+python inference.py
+```
+
+Outputs: predictions are saved under `output_dir_pred` (default `./predicted_vtps/`). Normalization stats are written to `./stats/` during training and reused for inference.
## Prerequisites
@@ -39,20 +128,6 @@ This will install:
- lasso-python (for LS-DYNA file parsing),
- torch_geometric and torch_scatter (for GNN operations),
-## Dataset Preprocessing
-
-Crash simulation data is parsed from LS-DYNA d3plot files using the d3plot_reader.py utility.
-
-Key steps:
-
-- Load node coordinates, displacements, element connectivity, and part IDs.
-- Parse .k keyword files to assign part thickness values.
-- Filter out rigid wall nodes using displacement thresholds.
-- Build edges (for graphs) and store per-node features (e.g., thickness).
-- Optionally export time-stepped meshes as .vtp for visualization.
-
-Run preprocessing automatically via the dataset class (CrashGraphDataset or CrashPointCloudDataset) when launching training or inference.
-
## Training
Training is managed via Hydra configurations located in conf/.
@@ -101,6 +176,131 @@ python inference.py
Predicted meshes are written as .vtp files under
./predicted_vtps/, and can be opened using ParaView.
+## Datapipe: how inputs are constructed and normalized
+
+The datapipe is responsible for turning raw LS-DYNA/Abaqus or other crash runs into model-ready tensors and statistics. It does three things in a predictable, repeatable way: it reads and filters the raw data, it constructs inputs and targets with a stable interface, and it computes the statistics required to normalize both positions and features. This section explains what the datapipe returns, how to configure it, and what models should expect to receive at training and inference time.
+
+At a high level, each sample corresponds to one crash run. The datapipe loads the full deformation trajectory for that run, and emits exactly two items: inputs x and targets y. Inputs are a dictionary with two entries. The first entry, 'coords', is a [N, 3] tensor that contains the positions at the first timestep (t0) for all retained nodes. The second entry, 'features', is a [N, F] tensor that contains the concatenation of all node-wise features configured for this experiment. The order of columns in 'features' matches the order you provide in the configuration. This means if your configuration lists features as [thickness, Y_modulus], then column 0 will always be thickness and column 1 will always be Y_modulus. Targets y are the remaining positions from t1 to tT flattened along the feature dimension, so y has shape [N, (T-1)*3].
+
+Configuration lives under `conf/datapipe/`. There are two datapipe variants: one for graph-based models and one for point-cloud models. Both accept the same core options, and both expose a `features` list. The `features` list is the single source of truth for what goes into the 'features' tensor and in which order. If you do not want any features, set `features: []` and the datapipe will return an empty [N, 0] tensor for 'features' while keeping 'coords' intact. If you add more features later, the datapipe will preserve their order and update the per-dimension statistics automatically.
+
+Under the hood the datapipe reads node positions over time from LS-DYNA (via `d3plot_reader.py` or any compatible reader you configure). For each run it constructs a fixed number of time steps, selects and reindexes the active nodes, and optionally builds graph connectivity. It also computes statistics necessary for normalization. Position statistics include per-axis means and standard deviations, as well as normalized velocity and acceleration statistics used by autoregressive rollouts. Feature statistics are computed column-wise on the concatenated 'features' tensor. During dataset creation the datapipe normalizes the position trajectory using position means and standard deviations and normalizes every column of 'features' using feature means and standard deviations. The resulting tensors are numerically stable and consistent across training and evaluation. The statistics are written under `./stats/` as `node_stats.json` and `feature_stats.json` during training, and then read back in evaluation or inference.
+
+Readers are configurable through Hydra. A reader is any callable that returns `(srcs, dsts, point_data)`, where `point_data` is a list of records—one per run. Each record must include 'coords' as a [T, N, 3] array and one array per configured feature name. Arrays for features can be [N] or [N, K]; the datapipe will promote [N] to [N, 1] and then concatenate all feature arrays in the order declared in the configuration to form 'features'. If you are using graph-based models, the `srcs` and `dsts` arrays will be used to build a PyG `Data` object with symmetric edges and self-loops, and initial edge features are computed from positions at t0 (displacements and distances). If you are using point-cloud models, graph connectivity is ignored but the remainder of the pipeline is identical.
+
+Models should consume the two-part input without guessing column indices. Positions are always available in `x['coords']` and every node-wise feature is already concatenated in `x['features']`. If you need to separate features later—for example to log per-feature metrics—you can do so deterministically because the order of columns in `x['features']` exactly matches the `features` list in the configuration. For time-conditional models, you can pass the full `x['features']` to your functional input; for autoregressive models, you can concatenate `x['features']` to the normalized velocity (and time, if used) to form the model input at each rollout step.
+
+Finally, the datapipe is designed to be resilient to the “no features” case. If you set `features: []`, the 'features' tensor simply has width zero. Statistics are computed correctly (zero-length mean and unit standard deviation) and concatenations degrade gracefully to the original position-only behavior. This makes it easy to start simple and then scale up to richer feature sets without revisiting model-side code or the data normalization logic.
+
+For completeness, the datapipe also records a lightweight name-to-column map called `_feature_slices`. It associates each configured feature name with its [start, end) slice in `x['features']`. You typically won’t need it if you just consume the full `features` tensor, but it enables reliable, reproducible slicing by name for diagnostics or logging.
+
+### Model I/O at a glance (what models receive)
+
+- Inputs `x` (dictionary):
+ - `x['coords']`: `[N, 3]` positions at `t0`
+ - `x['features']`: `[N, F]` concatenated node features in the config‑specified order (can be width 0)
+
+- Targets `y`: `[N, (T-1)*3]` positions from `t1..tT` flattened along the feature dimension.
+
+- Rollout input construction (high level):
+ - Autoregressive: per step, the model consumes normalized velocity, optionally time, and `x['features']`; positions are fed as embeddings/state.
+ - Time‑conditional one‑step: time index is provided once per call along with `x['features']` and the positional embedding.
+
+- Transolver specifics: for unstructured data, the embedding tensor is required; in this pipeline it is the current positions over the rollout. If you set `features: []`, the functional input still includes velocity (and optionally time), so the overall functional dimension remains > 0.
+
+## Reader: built-in d3plot and vtp readers and how to add your own
+
+The reader is the component that actually opens the raw simulation outputs and produces the arrays the datapipe consumes. It is intentionally thin and swappable via Hydra so you can adapt the pipeline to LS‑DYNA exports, Abaqus exports, or your own internal formats without touching the rest of the code.
+
+### Built-in d3plot reader
+
+The default reader is implemented in `d3plot_reader.py`. It searches the data directory for subfolders that contain a `d3plot` file and treats each such folder as one “run.” For each run it opens the `d3plot` with `lasso.dyna.D3plot` and extracts node coordinates, time-varying displacements, element connectivity, and part identifiers. If a LS‑DYNA keyword (`.k`) file is present, it parses the shell section definitions to obtain per-part thickness values, then converts those into per-node thickness by averaging the values of incident elements. To avoid contaminating the training with rigid content, the reader classifies nodes as structural or wall based on a displacement variation threshold and drops wall nodes. After filtering, it builds a compact node index, remaps connectivity, and—if you are training a graph model—collects undirected edges from the remapped shell elements. It can optionally save one VTP file per time step to help you visually inspect the trajectories, or write the predictions to those files in inference.
+
+The reader then assembles the per-run record expected by the datapipe. Positions are returned under the key `'coords'` as a float array of shape `[T, N, 3]`, where T is the number of time steps and N is the number of retained nodes after filtering and remapping. Feature arrays are returned one per configured feature name; for example, if your datapipe configuration lists `features: [thickness, Y_modulus]`, the reader should provide a `'thickness'` array with shape `[N]` or `[N, 1]` and a `'Y_modulus'` array with shape `[N]` or `[N, K]`. The datapipe promotes 1D arrays to 2D and concatenates all provided feature arrays in the order given by the configuration to form the final `'features'` block supplied to the model.
+
+If you use the graph datapipe, the edge list is produced by walking the filtered shell elements and collecting unique boundary pairs, then symmetrized and augmented with self-loops inside the datapipe when constructing the PyG `Data` object. If you use the point‑cloud datapipe, the edge outputs are ignored but the rest of the record shape is the same, so you can swap between model families by changing configuration only.
+
+### Built‑in VTP reader (PolyData)
+
+In addition to `d3plot`, a lightweight VTP reader is provided in `vtp_reader.py`. It treats each `.vtp` file in a directory as a separate run and expects point displacements to be stored as vector arrays in `poly.point_data` with names like `displacement_t0.000`, `displacement_t0.005`, … (a more permissive fallback of any `displacement_t*` is also supported). The reader:
+
+- loads the reference coordinates from `poly.points`
+- builds absolute positions per timestep as `[t0: coords, t>0: coords + displacement_t]`
+- extracts cell connectivity from the PolyData faces and converts it to unique edges
+- returns `(srcs, dsts, point_data)` where `point_data` contains `'coords': [T, N, 3]`
+
+By default, the VTP reader does not attach additional features; it is compatible with `features: []`. If your `.vtp` files include additional per‑point arrays you would like to model (e.g., thickness or modulus), extend the reader to add those arrays to each run’s record using keys that match your `features` list. The datapipe will then concatenate them in the configured order.
+
+Example Hydra configuration for the VTP reader:
+
+```yaml
+# conf/reader/vtp.yaml
+_target_: vtp_reader.Reader
+```
+
+Select it in `conf/config.yaml`:
+
+```yaml
+defaults:
+ - datapipe: point_cloud
+ - model: transolver_time_conditional
+ - training: default
+ - inference: default
+ - reader: vtp
+```
+
+And set `features` to empty (or to the names you add in your extended reader) in `conf/datapipe/point_cloud.yaml` or `conf/datapipe/graph.yaml`:
+
+```yaml
+features: [] # or [thickness, Y_modulus] if your reader provides them
+```
+
+### Data layout expected by readers
+
+- d3plot reader (`d3plot_reader.py`):
+ - `//d3plot` (required)
+ - `//*.k` (optional; used to parse thickness)
+
+- VTP reader (`vtp_reader.py`):
+ - `/*.vtp` (each `.vtp` is treated as one run)
+ - Displacements stored as 3‑component arrays in point_data with names like `displacement_t0.000`, `displacement_t0.005`, ... (fallback accepts any `displacement_t*`).
+
+### Write your own reader
+
+To write your own reader, implement a Hydra‑instantiable function or class whose call returns a three‑tuple `(srcs, dsts, point_data)`. The first two entries are lists of integer arrays describing edges per run (they can be empty lists if you are not producing a graph), and `point_data` is a list of Python dicts with one dict per run. Each dict must contain `'coords'` as a `[T, N, 3]` array and one array per feature name listed in `conf/datapipe/*.yaml` under `features`. Feature arrays can be `[N]` or `[N, K]` and should use the same node indexing as `'coords'`. For convenience, a simple class reader can accept the Hydra `split` argument (e.g., "train" or "test") and decide whether to save VTP frames, but this is optional.
+
+As a starting point, your YAML can point to a class by dotted path. For a class:
+
+```yaml
+# conf/reader/my_reader.yaml
+_target_: my_reader.MyReader
+# any constructor kwargs here, e.g. thresholds or unit conversions
+```
+
+Then, in `conf/config.yaml`, select the reader by adding or overriding `- reader: my_reader` (or `my_reader_fn`). The datapipe will call your reader with `data_dir`, `num_samples`, `split`, and an optional `logger`, and will expect the tuple described above. Provided you populate `'coords'` and the configured feature arrays per run, the rest of the pipeline—normalization, batching, graph construction, and model rollout—will work without code changes.
+
+A note on reader signatures and future‑proofing: the datapipe currently passes `data_dir`, `num_samples`, `split`, and `logger` when invoking the reader, and may pass additional keys in the future. To stay resilient, implement your reader with optional parameters and a catch‑all `**kwargs`.
+
+For a class reader, use this signature in `__call__`:
+
+```python
+class MyReader:
+ def __init__(self, some_option: float = 1.0):
+ self.some_option = some_option
+
+ def __call__(
+ self,
+ data_dir: str,
+ num_samples: int,
+ split: str | None = None,
+ logger=None,
+ **kwargs,
+ ):
+ ...
+```
+
+With this pattern, your reader will keep working even if the framework adds new optional arguments later.
+
## Postprocessing and Evaluation
The postprocessing/ folder provides scripts for quantitative and qualitative evaluation:
@@ -149,6 +349,28 @@ python postprocessing/plot_cross_section.py \
run_post_processing.sh can automate all evaluation tasks across runs.
+## Performance tips
+
+- AMP is enabled by default in training; it reduces memory and accelerates matmuls on modern GPUs.
+- For multi-GPU training, use `torchrun --standalone --nproc_per_node= train.py`.
+- For DDP, prefer `torchrun --standalone --nproc_per_node= train.py`.
+
+## Troubleshooting / FAQ
+
+- My `.vtp` has no displacement fields.
+ - Ensure point_data contains vector arrays named like `displacement_t0.000`, `displacement_t0.005`, ...; the reader falls back to any `displacement_t*` pattern.
+
+- I want no node features.
+ - Set `features: []`. The datapipe will return `x['features']` with shape `[N, 0]`, and the rollout will still concatenate velocity (and time if configured) for the model input.
+
+- Can functional_dim be 0 for Transolver?
+ - It can be 0 only if the total MLP input dimension remains > 0: e.g., you provide an embedding (required for unstructured) and/or time. In this pipeline, rollout always supplies an embedding (positions), so you are safe with `features: []`.
+
+- My custom reader doesn’t accept `split` or `logger`.
+ - Implement `__call__(..., split: str | None = None, logger=None, **kwargs)` to remain forward‑compatible with optional arguments.
+
## References
-- Automotive Crash Dynamics Modeling Accelerated with Machine Learning (https://arxiv.org/pdf/2510.15201)
\ No newline at end of file
+- [Automotive Crash Dynamics Modeling Accelerated with Machine Learning](https://arxiv.org/pdf/2510.15201)
+- [Transolver: A Fast Transformer Solver for PDEs on General Geometries](https://arxiv.org/pdf/2402.02366)
+- [Learning Mesh-Based Simulation with Graph Networks](https://arxiv.org/pdf/2010.03409)
diff --git a/examples/structural_mechanics/crash/conf/config.yaml b/examples/structural_mechanics/crash/conf/config.yaml
index a34192ccd6..3a3e288d66 100644
--- a/examples/structural_mechanics/crash/conf/config.yaml
+++ b/examples/structural_mechanics/crash/conf/config.yaml
@@ -25,6 +25,7 @@ experiment_desc: "unified training recipe for crash models"
run_desc: "unified training recipe for crash models"
defaults:
+ - reader: vtp #d3plot
- datapipe: point_cloud # will be overridden by model configs
- model: transolver_autoregressive_rollout_training
- training: default
diff --git a/examples/structural_mechanics/crash/conf/datapipe/graph.yaml b/examples/structural_mechanics/crash/conf/datapipe/graph.yaml
index b71dad7e38..3b14a56891 100644
--- a/examples/structural_mechanics/crash/conf/datapipe/graph.yaml
+++ b/examples/structural_mechanics/crash/conf/datapipe/graph.yaml
@@ -15,9 +15,10 @@
# limitations under the License.
_target_: datapipe.CrashGraphDataset
+_convert_: all
data_dir: ${training.raw_data_dir}
name: crash_train
split: train
num_samples: ${training.num_training_samples}
num_steps: ${training.num_time_steps}
-wall_node_disp_threshold: 1.0
\ No newline at end of file
+features: [thickness]
diff --git a/examples/structural_mechanics/crash/conf/datapipe/point_cloud.yaml b/examples/structural_mechanics/crash/conf/datapipe/point_cloud.yaml
index 78996f3235..b758a30734 100644
--- a/examples/structural_mechanics/crash/conf/datapipe/point_cloud.yaml
+++ b/examples/structural_mechanics/crash/conf/datapipe/point_cloud.yaml
@@ -15,7 +15,8 @@
# limitations under the License.
_target_: datapipe.CrashPointCloudDataset
+_convert_: all
data_dir: ${training.raw_data_dir}
num_samples: ${training.num_training_samples}
num_steps: ${training.num_time_steps}
-wall_node_disp_threshold: 1.0
\ No newline at end of file
+features: [thickness]
\ No newline at end of file
diff --git a/examples/structural_mechanics/crash/conf/model/mgn_autoregressive_rollout_training.yaml b/examples/structural_mechanics/crash/conf/model/mgn_autoregressive_rollout_training.yaml
index 9293c8e262..023a9748e5 100644
--- a/examples/structural_mechanics/crash/conf/model/mgn_autoregressive_rollout_training.yaml
+++ b/examples/structural_mechanics/crash/conf/model/mgn_autoregressive_rollout_training.yaml
@@ -15,6 +15,7 @@
# limitations under the License.
_target_: rollout.MeshGraphNetAutoregressiveRolloutTraining
+_convert_: all
input_dim_nodes: 7 # pos(3) + vel(3) + thickness(1)
input_dim_edges: 4 # dx, dy, dz, distance
diff --git a/examples/structural_mechanics/crash/conf/model/mgn_one_step_rollout.yaml b/examples/structural_mechanics/crash/conf/model/mgn_one_step_rollout.yaml
index 9293c8e262..023a9748e5 100644
--- a/examples/structural_mechanics/crash/conf/model/mgn_one_step_rollout.yaml
+++ b/examples/structural_mechanics/crash/conf/model/mgn_one_step_rollout.yaml
@@ -15,6 +15,7 @@
# limitations under the License.
_target_: rollout.MeshGraphNetAutoregressiveRolloutTraining
+_convert_: all
input_dim_nodes: 7 # pos(3) + vel(3) + thickness(1)
input_dim_edges: 4 # dx, dy, dz, distance
diff --git a/examples/structural_mechanics/crash/conf/model/mgn_time_conditional.yaml b/examples/structural_mechanics/crash/conf/model/mgn_time_conditional.yaml
index e6293cc925..030c5c0151 100644
--- a/examples/structural_mechanics/crash/conf/model/mgn_time_conditional.yaml
+++ b/examples/structural_mechanics/crash/conf/model/mgn_time_conditional.yaml
@@ -15,6 +15,7 @@
# limitations under the License.
_target_: rollout.MeshGraphNetTimeConditionalRollout
+_convert_: all
input_dim_nodes: 5 # pos(3) + thickness(1) + time(1)
input_dim_edges: 4 # dx, dy, dz, distance
diff --git a/examples/structural_mechanics/crash/conf/model/transolver_autoregressive_rollout_training.yaml b/examples/structural_mechanics/crash/conf/model/transolver_autoregressive_rollout_training.yaml
index 3635ce2ed8..8b76174fbe 100644
--- a/examples/structural_mechanics/crash/conf/model/transolver_autoregressive_rollout_training.yaml
+++ b/examples/structural_mechanics/crash/conf/model/transolver_autoregressive_rollout_training.yaml
@@ -15,6 +15,7 @@
# limitations under the License.
_target_: rollout.TransolverAutoregressiveRolloutTraining
+_convert_: all
functional_dim: 5
embedding_dim: 3
diff --git a/examples/structural_mechanics/crash/conf/model/transolver_one_step_rollout.yaml b/examples/structural_mechanics/crash/conf/model/transolver_one_step_rollout.yaml
index b0a7bdd7f3..8338067f26 100644
--- a/examples/structural_mechanics/crash/conf/model/transolver_one_step_rollout.yaml
+++ b/examples/structural_mechanics/crash/conf/model/transolver_one_step_rollout.yaml
@@ -15,6 +15,7 @@
# limitations under the License.
_target_: rollout.TransolverOneStepRollout
+_convert_: all
functional_dim: 4
embedding_dim: 3
diff --git a/examples/structural_mechanics/crash/conf/model/transolver_time_conditional.yaml b/examples/structural_mechanics/crash/conf/model/transolver_time_conditional.yaml
index 3a23fdbafd..451f0568a0 100644
--- a/examples/structural_mechanics/crash/conf/model/transolver_time_conditional.yaml
+++ b/examples/structural_mechanics/crash/conf/model/transolver_time_conditional.yaml
@@ -15,6 +15,7 @@
# limitations under the License.
_target_: rollout.TransolverTimeConditionalRollout
+_convert_: all
functional_dim: 1 # thickness
embedding_dim: 3 # position
diff --git a/examples/structural_mechanics/crash/conf/reader/d3plot.yaml b/examples/structural_mechanics/crash/conf/reader/d3plot.yaml
new file mode 100644
index 0000000000..9ee10a460e
--- /dev/null
+++ b/examples/structural_mechanics/crash/conf/reader/d3plot.yaml
@@ -0,0 +1,20 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+_target_: d3plot_reader.Reader
+_convert_: all
+wall_node_disp_threshold: 1.0
+
diff --git a/examples/structural_mechanics/crash/conf/reader/vtp.yaml b/examples/structural_mechanics/crash/conf/reader/vtp.yaml
new file mode 100644
index 0000000000..d8a2e32bb7
--- /dev/null
+++ b/examples/structural_mechanics/crash/conf/reader/vtp.yaml
@@ -0,0 +1,18 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+_target_: vtp_reader.Reader
+_convert_: all
\ No newline at end of file
diff --git a/examples/structural_mechanics/crash/d3plot_reader.py b/examples/structural_mechanics/crash/d3plot_reader.py
index ad07ff74a2..4f8b0edc9f 100644
--- a/examples/structural_mechanics/crash/d3plot_reader.py
+++ b/examples/structural_mechanics/crash/d3plot_reader.py
@@ -19,10 +19,10 @@
import pyvista as pv
from lasso.dyna import D3plot, ArrayType
-from typing import Dict, List, Optional
+from typing import Optional
-def find_run_folders(base_data_dir: str) -> List[str]:
+def find_run_folders(base_data_dir: str) -> list[str]:
"""
Find run directories containing LS-DYNA d3plot files.
@@ -43,7 +43,7 @@ def find_run_folders(base_data_dir: str) -> List[str]:
return run_dirs
-def parse_k_file(k_file_path: str) -> Dict[int, float]:
+def parse_k_file(k_file_path: str) -> dict[int, float]:
"""
Parse LS-DYNA keyword (.k) file to extract part thickness values.
@@ -53,8 +53,8 @@ def parse_k_file(k_file_path: str) -> Dict[int, float]:
Returns:
Dictionary mapping part ID -> thickness.
"""
- part_to_section: Dict[int, int] = {}
- section_thickness: Dict[int, float] = {}
+ part_to_section: dict[int, int] = {}
+ section_thickness: dict[int, float] = {}
with open(k_file_path, "r") as f:
lines = [
@@ -187,7 +187,7 @@ def build_edges_from_mesh_connectivity(mesh_connectivity) -> set:
def compute_node_thickness(
mesh_connectivity,
part_ids,
- part_thickness_map: Dict[int, float],
+ part_thickness_map: dict[int, float],
actual_part_ids: Optional[np.ndarray] = None,
) -> np.ndarray:
"""
@@ -404,8 +404,34 @@ def process_d3plot_data(
write_vtp,
logger,
)
- point_data_all.append(
- {"mesh_pos": mesh_pos_all, "thickness": filtered_thickness}
- )
+ point_data_all.append({"coords": mesh_pos_all, "thickness": filtered_thickness})
return srcs, dsts, point_data_all
+
+
+class Reader:
+ """
+ Reader for LS-DYNA d3plot files.
+
+ Args:
+ wall_node_disp_threshold: threshold for filtering wall nodes
+ """
+
+ def __init__(self, wall_node_disp_threshold: float = 1.0):
+ self.wall_node_disp_threshold = wall_node_disp_threshold
+
+ def __call__(
+ self,
+ data_dir: str,
+ num_samples: int,
+ split: str,
+ logger=None,
+ ):
+ write_vtp = False if split == "train" else True
+ return process_d3plot_data(
+ data_dir=data_dir,
+ num_samples=num_samples,
+ wall_node_disp_threshold=self.wall_node_disp_threshold,
+ write_vtp=write_vtp,
+ logger=logger,
+ )
diff --git a/examples/structural_mechanics/crash/datapipe.py b/examples/structural_mechanics/crash/datapipe.py
index fb98ac5188..a01b9dee05 100644
--- a/examples/structural_mechanics/crash/datapipe.py
+++ b/examples/structural_mechanics/crash/datapipe.py
@@ -17,9 +17,8 @@
import os
import numpy as np
import torch
-from typing import Optional, List, Dict, Any, Tuple
+from typing import Any, Callable, Optional
-from d3plot_reader import process_d3plot_data
from torch_geometric.data import Data
from torch_geometric.utils import coalesce, add_self_loops
@@ -28,7 +27,7 @@
STATS_DIRNAME = "stats"
NODE_STATS_FILE = "node_stats.json"
-THK_STATS_FILE = "thickness_stats.json"
+FEATURE_STATS_FILE = "feature_stats.json"
EDGE_STATS_FILE = "edge_stats.json"
EPS = 1e-8 # numerical stability for std
@@ -39,30 +38,31 @@ class SimSample:
Attributes
---------
- node_features : FloatTensor [N, Din]
+ node_features: dict[str, Tensor] with at least:
+ - 'coords': FloatTensor [N, 3]
+ - any other feature keys configured, e.g., 'thickness': [N, Fk]
node_target : FloatTensor [N, Dout] or [N, (T-1)*3] depending on task
graph : PyG Data or None
"""
def __init__(
self,
- node_features: torch.Tensor,
+ node_features: dict[str, torch.Tensor],
node_target: torch.Tensor,
graph: Optional[Data] = None,
):
- assert node_features.ndim == 2, (
- f"node_features must be [N, D], got {node_features.shape}"
- )
- assert node_target.ndim >= 2, (
- f"node_target must be [N, ...], got {node_target.shape}"
- )
-
+ assert isinstance(node_features, dict), "node_features must be a dict"
+ assert "coords" in node_features, "node_features must contain 'coords'"
+ assert (
+ node_features["coords"].ndim == 2 and node_features["coords"].shape[1] == 3
+ ), f"'coords' must be [N,3], got {node_features['coords'].shape}"
self.node_features = node_features
self.node_target = node_target
self.graph = graph # PyG Data or None
def to(self, device: torch.device):
- self.node_features = self.node_features.to(device)
+ for k, v in self.node_features.items():
+ self.node_features[k] = v.to(device)
self.node_target = self.node_target.to(device)
if self.graph is not None:
self.graph = self.graph.to(device)
@@ -72,15 +72,19 @@ def is_graph(self) -> bool:
return self.graph is not None
def __repr__(self) -> str:
- n = self.node_features.shape[0]
- din = self.node_features.shape[1]
+ n = self.node_features["coords"].shape[0]
+ keys = {k: tuple(v.shape) for k, v in self.node_features.items()}
+ din = 3
+ for k, v in self.node_features.items():
+ if k != "coords":
+ din += v.shape[1]
dout = (
self.node_target.shape[1]
if self.node_target.ndim == 2
else tuple(self.node_target.shape[1:])
)
e = 0 if self.graph is None else self.graph.num_edges
- return f"SimSample(N={n}, Din={din}, Dout={dout}, E={e})"
+ return f"SimSample(N={n}, keys={list(self.node_features.keys())}, Din={din}, Dout={dout}, E={e})"
class CrashBaseDataset:
@@ -97,12 +101,12 @@ class CrashBaseDataset:
def __init__(
self,
name: str = "dataset",
+ reader: Optional[Callable] = None,
data_dir: Optional[str] = None,
split: str = "train",
num_samples: int = 1000,
num_steps: int = 400,
- wall_node_disp_threshold: float = 1.0,
- write_vtp: bool = False,
+ features: Optional[list[str]] = None,
logger=None,
dt: float = 5e-3,
):
@@ -112,6 +116,7 @@ def __init__(
self.split = split
self.num_samples = num_samples
self.num_steps = num_steps
+ self.features = features
self.length = num_samples
self.logger = logger or PythonLogger()
self.dt = dt
@@ -120,102 +125,141 @@ def __init__(
f"[{self.__class__.__name__}] Preparing the {split} dataset..."
)
+ self.features = features or []
+
# Prepare stats dir
self._stats_dir = STATS_DIRNAME
os.makedirs(STATS_DIRNAME, exist_ok=True)
- # Load raw records; we keep (srcs, dsts) for graph dataset; point-cloud ignores them
- self.srcs, self.dsts, point_data = process_d3plot_data(
- self.data_dir,
- num_samples,
- wall_node_disp_threshold,
- write_vtp,
+ # Load raw records via provided reader callable (Hydra can pass a class/callable)
+ if reader is None:
+ raise ValueError("Data reader function is not specified.")
+ self.srcs, self.dsts, point_data = reader(
+ data_dir=self.data_dir,
+ num_samples=num_samples,
+ split=split,
logger=self.logger,
)
# Storage for per-sample tensors
- self.mesh_pos_seq: List[torch.Tensor] = [] # [T, N, 3], float32
- self.thickness_data: List[torch.Tensor] = [] # [N], float32
+ self.mesh_pos_seq: list[torch.Tensor] = [] # [T,N,3]
+ self.node_features_data: list[torch.Tensor] = [] # [N,F]
+ self._feature_slices: dict[
+ str, tuple[int, int]
+ ] = {} # per-sample feature slices
for rec in point_data:
- # Expect keys: "mesh_pos": [T, N, 3], "thickness": [N]
- mesh_pos_np = rec["mesh_pos"][:num_steps]
- thk_np = rec["thickness"]
-
- assert mesh_pos_np.ndim == 3 and mesh_pos_np.shape[-1] == 3, (
- f"mesh_pos must be [T,N,3], got {mesh_pos_np.shape}"
+ # Coordinates
+ if "coords" not in rec:
+ raise KeyError(f"Missing coordinates key 'coords' in reader record")
+ coords_np = rec["coords"][:num_steps]
+ assert coords_np.ndim == 3 and coords_np.shape[-1] == 3, (
+ f"coords must be [T,N,3], got {coords_np.shape}"
+ )
+ self.mesh_pos_seq.append(torch.as_tensor(coords_np, dtype=torch.float32))
+
+ # Features: concatenate requested keys if present; allow empty
+ parts = []
+ for k in self.features:
+ if k not in rec:
+ raise KeyError(f"Missing feature key '{k}' in reader record")
+ arr = rec[k]
+ if arr.ndim == 1:
+ arr = arr[:, None]
+ parts.append(arr)
+
+ feats_np = (
+ np.concatenate(parts, axis=-1)
+ if len(parts) > 0
+ else np.zeros((coords_np.shape[1], 0), dtype=np.float32)
)
- assert thk_np.ndim == 1, f"thickness must be [N], got {thk_np.shape}"
+ assert feats_np.ndim == 2 and feats_np.shape[0] == coords_np.shape[1], (
+ f"features must be [N,F], got {feats_np.shape}, N mismatch with {coords_np.shape}"
+ )
+
+ # build slice map on first record to make future slicing trivial
+ if len(self._feature_slices) == 0:
+ start = 0
+ for k in self.features:
+ width = rec[k].shape[1] if rec[k].ndim > 1 else 1
+ self._feature_slices[k] = (start, start + width)
+ start += width
- self.mesh_pos_seq.append(torch.as_tensor(mesh_pos_np, dtype=torch.float32))
- self.thickness_data.append(torch.as_tensor(thk_np, dtype=torch.float32))
+ self.node_features_data.append(
+ torch.as_tensor(feats_np, dtype=torch.float32)
+ )
- # Stats (node + thickness)
+ # Stats (node + generic features)
node_stats_path = os.path.join(self._stats_dir, NODE_STATS_FILE)
- thk_stats_path = os.path.join(self._stats_dir, THK_STATS_FILE)
+ feat_stats_path = os.path.join(self._stats_dir, FEATURE_STATS_FILE)
if self.split == "train":
self.node_stats = self._compute_autoreg_node_stats()
- self.thickness_stats = self._compute_thickness_stats()
+ self.feature_stats = self._compute_feature_stats()
save_json(self.node_stats, node_stats_path)
- save_json(self.thickness_stats, thk_stats_path)
+ save_json(self.feature_stats, feat_stats_path)
else:
- # Load if exists; otherwise compute and persist
- if os.path.exists(node_stats_path) and os.path.exists(thk_stats_path):
+ if os.path.exists(node_stats_path) and os.path.exists(feat_stats_path):
self.node_stats = load_json(node_stats_path)
- self.thickness_stats = load_json(thk_stats_path)
+ self.feature_stats = load_json(feat_stats_path)
else:
raise FileNotFoundError(
- f"Node stats file {node_stats_path} or thickness stats file {thk_stats_path} not found"
+ f"Node stats file {node_stats_path} or feature stats file {feat_stats_path} not found"
)
- # Normalize trajectories and thickness
+ # Normalize trajectories and features
for i in range(self.num_samples):
self.mesh_pos_seq[i] = self._normalize_node_tensor(
self.mesh_pos_seq[i],
self.node_stats["pos_mean"],
self.node_stats["pos_std"],
)
- self.thickness_data[i] = self._normalize_thickness_tensor(
- self.thickness_data[i],
- torch.tensor(
- self.thickness_stats["thickness_mean"], dtype=torch.float32
- ),
- torch.tensor(
- self.thickness_stats["thickness_std"], dtype=torch.float32
- ),
- )
+ if self.node_features_data[i].numel() > 0:
+ mu = torch.as_tensor(
+ self.feature_stats.get("feature_mean", []), dtype=torch.float32
+ )
+ std = torch.as_tensor(
+ self.feature_stats.get("feature_std", []), dtype=torch.float32
+ )
+ if mu.numel() == 0:
+ continue
+ self.node_features_data[i] = (
+ self.node_features_data[i] - mu.view(1, -1)
+ ) / (std.view(1, -1) + EPS)
def __len__(self):
return self.length
- def _xy_shapes(self, idx: int) -> Tuple[int, int]:
+ def _xy_shapes(self, idx: int) -> tuple[int, int]:
T, N, _ = self.mesh_pos_seq[idx].shape
- Din = 4
+ F = self.node_features_data[idx].shape[1]
+ Din = 3 + F
Dout = (T - 1) * 3
return Din, Dout
# Common x/y construction used by both datasets
def build_xy(self, idx: int):
"""
- x: [N, 4] = pos_t0(3) + thickness(1)
- y: [N, (T-1)*3] flattened all future positions
+ x: dict with two keys:
+ - 'coords': [N, 3] at t0
+ - 'features': [N, F] concatenated in the order given by self.features
+ y: [N, (T-1)*3]
"""
assert 0 <= idx < self.num_samples, f"Index {idx} out of range"
pos_seq = self.mesh_pos_seq[idx] # [T,N,3]
- thk = self.thickness_data[idx] # [N]
+ feats = self.node_features_data[idx] # [N,F]
T, N, _ = pos_seq.shape
+ F = feats.shape[1]
pos_t0 = pos_seq[0] # [N,3]
- thickness_expanded = thk.unsqueeze(1) # [N,1]
- x = torch.cat([pos_t0, thickness_expanded], dim=1) # [N,4]
+ x = {"coords": pos_t0, "features": feats}
# Flatten all future positions along feature dim
y = pos_seq[1:].transpose(0, 1).flatten(start_dim=1) # [N,(T-1)*3]
- Din, Dout = self._xy_shapes(idx)
- assert x.shape == (N, Din), (
- f"x shape mismatch: expected {(N, Din)}, got {x.shape}"
+ _, Dout = self._xy_shapes(idx)
+ assert x["coords"].shape == (N, 3) and x["features"].shape == (N, F), (
+ f"coords shape {x['coords'].shape}, features shape {x['features'].shape}, expected (N,3)/(N,{F})"
)
assert y.shape == (N, Dout), (
f"y shape mismatch: expected {(N, Dout)}, got {y.shape}"
@@ -273,11 +317,28 @@ def _compute_autoreg_node_stats(self):
"norm_acc_std": acc_std,
}
- def _compute_thickness_stats(self):
- all_thickness = torch.cat(self.thickness_data, dim=0) # [sum_N]
- thk_mean = torch.mean(all_thickness)
- thk_std = torch.std(all_thickness)
- return {"thickness_mean": thk_mean, "thickness_std": thk_std}
+ def _compute_feature_stats(self):
+ # If no features, return empty stats compatible with normalization branch
+ fdim = self.node_features_data[0].shape[1]
+ for t in self.node_features_data:
+ assert t.shape[1] == fdim, f"Feature dim mismatch: {t.shape[1]} vs {fdim}"
+
+ if fdim == 0:
+ mu = torch.zeros(0, dtype=torch.float32)
+ std = torch.ones(0, dtype=torch.float32)
+ return {"feature_mean": mu, "feature_std": std}
+
+ feat_mean = torch.zeros(fdim, dtype=torch.float32)
+ feat_meansqr = torch.zeros(fdim, dtype=torch.float32)
+ for i in range(self.num_samples):
+ x = self.node_features_data[i].to(torch.float32)
+ m = torch.mean(x, dim=0)
+ msq = torch.mean(x * x, dim=0)
+ feat_mean += m / self.num_samples
+ feat_meansqr += msq / self.num_samples
+ feat_var = torch.clamp(feat_meansqr - feat_mean * feat_mean, min=0.0)
+ feat_std = torch.sqrt(feat_var + EPS)
+ return {"feature_mean": feat_mean, "feature_std": feat_std}
@staticmethod
def _normalize_node_tensor(
@@ -316,7 +377,7 @@ def __init__(self, *args, **kwargs):
_dsts.append(np.asarray(dst)[mask])
self.srcs, self.dsts = _srcs, _dsts
- self.graphs: List[Data] = []
+ self.graphs: list[Data] = []
for i in range(self.num_samples):
g = self.create_graph(
self.srcs[i],
@@ -330,15 +391,14 @@ def __init__(self, *args, **kwargs):
# Edge stats
edge_stats_path = os.path.join(self._stats_dir, EDGE_STATS_FILE)
- if self.split == "train" and not os.path.exists(edge_stats_path):
+ if self.split == "train":
self.edge_stats = self._compute_edge_stats()
save_json(self.edge_stats, edge_stats_path)
else:
if os.path.exists(edge_stats_path):
self.edge_stats = load_json(edge_stats_path)
else:
- self.edge_stats = self._compute_edge_stats()
- save_json(self.edge_stats, edge_stats_path)
+ raise FileNotFoundError(f"Edge stats file {edge_stats_path} not found")
# Convert loaded stats to tensors
self.edge_stats["edge_mean"] = torch.as_tensor(
@@ -359,7 +419,7 @@ def __init__(self, *args, **kwargs):
def __getitem__(self, idx: int):
assert 0 <= idx < self.num_samples, f"Index {idx} out of range"
g = self.graphs[idx]
- x, y = self.build_xy(idx) # [N,4], [N,(T-1)*3]
+ x, y = self.build_xy(idx) # [N,3+F], [N,(T-1)*3]
return SimSample(
node_features=x,
@@ -426,7 +486,7 @@ class CrashPointCloudDataset(CrashBaseDataset):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.edge_stats: Dict[str, Any] = {}
+ self.edge_stats: dict[str, Any] = {}
def __getitem__(self, idx: int):
assert 0 <= idx < self.num_samples, f"Index {idx} out of range"
@@ -434,7 +494,7 @@ def __getitem__(self, idx: int):
return SimSample(node_features=x, node_target=y)
-def simsample_collate(batch: List[SimSample]) -> List[SimSample]:
+def simsample_collate(batch: list[SimSample]) -> list[SimSample]:
"""
Keep samples as a list (variable N per item is common here).
Models should iterate the list or implement internal padding.
diff --git a/examples/structural_mechanics/crash/inference.py b/examples/structural_mechanics/crash/inference.py
index 080e1b7958..983a61aac5 100644
--- a/examples/structural_mechanics/crash/inference.py
+++ b/examples/structural_mechanics/crash/inference.py
@@ -168,8 +168,9 @@ def run_on_single_run(self, run_path: str):
k: v.to(self.device)
for k, v in getattr(dataset, "edge_stats", {}).items()
},
- thickness={
- k: v.to(self.device) for k, v in dataset.thickness_stats.items()
+ feature={
+ k: v.to(self.device)
+ for k, v in getattr(dataset, "feature_stats", {}).items()
},
)
diff --git a/examples/structural_mechanics/crash/rollout.py b/examples/structural_mechanics/crash/rollout.py
index 86a1a12510..3e7c0725c7 100644
--- a/examples/structural_mechanics/crash/rollout.py
+++ b/examples/structural_mechanics/crash/rollout.py
@@ -17,7 +17,6 @@
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as ckpt
-from typing import List
from physicsnemo.models.transolver import Transolver
from physicsnemo.models.meshgraphnet import MeshGraphNet
@@ -49,16 +48,17 @@ def forward(self, sample: SimSample, data_stats: dict) -> torch.Tensor:
Returns:
[T, N, 3] rollout of predicted positions
"""
- node_features = sample.node_features # [N,F_in]
- N = sample.node_features.size(0)
- device = sample.node_features.device
+ inputs = sample.node_features
+ coords = inputs["coords"] # [N,3]
+ features = inputs.get("features", coords.new_zeros((coords.size(0), 0)))
+ N = coords.size(0)
+ device = coords.device
# Initial states
- y_t1 = node_features[..., :3] # [N,3]
- thickness = node_features[..., -1:] # [N,1]
+ y_t1 = coords # [N,3]
y_t0 = y_t1 - self.initial_vel * self.dt # backstep using initial velocity
- outputs: List[torch.Tensor] = []
+ outputs: list[torch.Tensor] = []
for t in range(self.rollout_steps):
time_t = 0.0 if self.rollout_steps <= 1 else t / (self.rollout_steps - 1)
time_t = torch.tensor([time_t], device=device, dtype=torch.float32)
@@ -71,8 +71,8 @@ def forward(self, sample: SimSample, data_stats: dict) -> torch.Tensor:
# Model input
fx_t = torch.cat(
- [vel_norm, thickness, time_t.expand(N, 1)], dim=-1
- ) # [N, 3+1+1]
+ [vel_norm, features, time_t.expand(N, 1)], dim=-1
+ ) # [N, 3+F+1]
def step_fn(fx, embedding):
return super(TransolverAutoregressiveRolloutTraining, self).forward(
@@ -123,18 +123,15 @@ def forward(
Returns:
[T, N, 3] rollout of predicted positions
"""
- node_features = sample.node_features # [N,4] (pos(3) + thickness(1))
- assert node_features.ndim == 2 and node_features.shape[1] == 4, (
- f"Expected node_features [N,4], got {node_features.shape}"
- )
+ inputs = sample.node_features
+ x = inputs["coords"] # [N,3]
+ features = inputs.get("features", x.new_zeros((x.size(0), 0))) # [N,F]
- x = node_features[..., :3] # initial pos
- thickness = node_features[..., -1:]
- outputs: List[torch.Tensor] = []
+ outputs: list[torch.Tensor] = []
time_seq = torch.linspace(0.0, 1.0, self.rollout_steps, device=x.device)
for time in time_seq:
- fx_t = thickness # [N,1]
+ fx_t = features # [N,F]
def step_fn(fx, embedding, time_t):
return super(TransolverTimeConditionalRollout, self).forward(
@@ -177,22 +174,25 @@ def forward(self, sample: SimSample, data_stats: dict) -> torch.Tensor:
Returns:
[T, N, 3] rollout of predicted positions
"""
- node_features = sample.node_features
+ inputs = sample.node_features
+ coords = inputs["coords"] # [N,3]
+ features = inputs.get(
+ "features", coords.new_zeros((coords.size(0), 0))
+ ) # [N,F]
edge_features = sample.graph.edge_attr
graph = sample.graph
- N = node_features.size(0)
- y_t1 = node_features[..., :3]
- thickness = node_features[..., -1:]
+ N = coords.size(0)
+ y_t1 = coords
y_t0 = y_t1 - self.initial_vel * self.dt
- outputs: List[torch.Tensor] = []
+ outputs: list[torch.Tensor] = []
for _ in range(self.rollout_steps):
vel = (y_t1 - y_t0) / self.dt
vel_norm = (vel - data_stats["node"]["norm_vel_mean"]) / (
data_stats["node"]["norm_vel_std"] + EPS
)
- fx_t = torch.cat([y_t1, vel_norm, thickness], dim=-1)
+ fx_t = torch.cat([y_t1, vel_norm, features], dim=-1)
def step_fn(nf, ef, g):
return super(MeshGraphNetAutoregressiveRolloutTraining, self).forward(
@@ -234,17 +234,17 @@ def forward(self, sample: SimSample, data_stats: dict) -> torch.Tensor:
Returns:
[T, N, 3] rollout of predicted positions
"""
- node_features = sample.node_features
+ inputs = sample.node_features
+ x = inputs["coords"] # [N,3]
+ features = inputs.get("features", x.new_zeros((x.size(0), 0))) # [N,F]
edge_features = sample.graph.edge_attr
graph = sample.graph
- x = node_features[..., :3]
- thickness = node_features[..., -1:]
- outputs: List[torch.Tensor] = []
+ outputs: list[torch.Tensor] = []
time_seq = torch.linspace(0.0, 1.0, self.rollout_steps, device=x.device)
for time in time_seq:
- fx_t = torch.cat([x, thickness, time.expand(x.size(0), 1)], dim=-1)
+ fx_t = torch.cat([x, features, time.expand(x.size(0), 1)], dim=-1)
def step_fn(nf, ef, g):
return super(MeshGraphNetTimeConditionalRollout, self).forward(
@@ -279,19 +279,18 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, sample: SimSample, data_stats: dict) -> torch.Tensor:
- N = sample.node_features.size(0)
- thickness = sample.node_features[..., -1:] # [N,1]
+ inputs = sample.node_features
+ coords0 = inputs["coords"] # [N,3]
+ features = inputs.get("features", coords0.new_zeros((coords0.size(0), 0)))
# Ground truth sequence [T,N,3]
+ N = coords0.size(0)
gt_seq = torch.cat(
- [
- sample.node_features[..., :3].unsqueeze(0), # pos_t0
- sample.node_target.view(N, -1, 3).transpose(0, 1), # pos_t1..pos_T
- ],
+ [coords0.unsqueeze(0), sample.node_target.view(N, -1, 3).transpose(0, 1)],
dim=0,
)
- outputs: List[torch.Tensor] = []
+ outputs: list[torch.Tensor] = []
# First step: backstep to create y_-1
y_t0 = gt_seq[0] - self.initial_vel * self.dt
@@ -306,7 +305,7 @@ def forward(self, sample: SimSample, data_stats: dict) -> torch.Tensor:
vel_norm = (vel - data_stats["node"]["norm_vel_mean"]) / (
data_stats["node"]["norm_vel_std"] + EPS
)
- fx_t = torch.cat([vel_norm, thickness], dim=-1)
+ fx_t = torch.cat([vel_norm, features], dim=-1)
def step_fn(fx, embedding):
return super(TransolverOneStepRollout, self).forward(
@@ -350,23 +349,22 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, sample: SimSample, data_stats: dict) -> torch.Tensor:
- node_features = sample.node_features
+ inputs = sample.node_features
+ coords0 = inputs["coords"] # [N,3]
+ features = inputs.get(
+ "features", coords0.new_zeros((coords0.size(0), 0))
+ ) # [N,F]
edge_features = sample.graph.edge_attr
graph = sample.graph
- N = node_features.size(0)
- thickness = node_features[..., -1:]
-
# Full ground truth trajectory [T,N,3]
+ N = coords0.size(0)
gt_seq = torch.cat(
- [
- node_features[..., :3].unsqueeze(0), # pos_t0
- sample.node_target.view(N, -1, 3).transpose(0, 1), # pos_t1..T
- ],
+ [coords0.unsqueeze(0), sample.node_target.view(N, -1, 3).transpose(0, 1)],
dim=0,
)
- outputs: List[torch.Tensor] = []
+ outputs: list[torch.Tensor] = []
# First step: construct backstep
y_t0 = gt_seq[0] - self.initial_vel * self.dt
@@ -382,7 +380,7 @@ def forward(self, sample: SimSample, data_stats: dict) -> torch.Tensor:
data_stats["node"]["norm_vel_std"] + EPS
)
- fx_t = torch.cat([y_t1, vel_norm, thickness], dim=-1)
+ fx_t = torch.cat([y_t1, vel_norm, features], dim=-1)
def step_fn(nf, ef, g):
return super(MeshGraphNetOneStepRollout, self).forward(
diff --git a/examples/structural_mechanics/crash/tests/test_rollout.py b/examples/structural_mechanics/crash/tests/test_rollout.py
new file mode 100644
index 0000000000..f691ef7e6b
--- /dev/null
+++ b/examples/structural_mechanics/crash/tests/test_rollout.py
@@ -0,0 +1,183 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+from typing import Dict
+
+import torch
+import pytest
+
+
+# Ensure we can import modules from the crash example directory
+THIS_DIR = os.path.dirname(__file__)
+CRASH_DIR = os.path.abspath(os.path.join(THIS_DIR, ".."))
+if CRASH_DIR not in sys.path:
+ sys.path.insert(0, CRASH_DIR)
+
+import rollout # noqa: E402
+from datapipe import SimSample # noqa: E402
+
+
+def make_sample(N: int = 5, T: int = 4, F: int = 2) -> SimSample:
+ torch.manual_seed(0)
+ coords = torch.randn(N, 3)
+ features = torch.randn(N, F)
+ # create ground-truth future positions flattened: [N, (T-1)*3]
+ future = torch.randn(N, (T - 1) * 3)
+
+ class DummyGraph:
+ pass
+
+ graph = DummyGraph()
+ graph.edge_attr = torch.zeros(0, 1)
+
+ node_inputs: Dict[str, torch.Tensor] = {"coords": coords, "features": features}
+ return SimSample(node_features=node_inputs, node_target=future, graph=graph)
+
+
+def make_data_stats() -> Dict[str, Dict[str, torch.Tensor]]:
+ # Broadcastable stats: [1, 3]
+ zeros = torch.zeros(1, 3)
+ ones = torch.ones(1, 3)
+ return {
+ "node": {
+ "norm_vel_mean": zeros,
+ "norm_vel_std": ones,
+ "norm_acc_mean": zeros,
+ "norm_acc_std": ones,
+ }
+ }
+
+
+@pytest.fixture(autouse=True)
+def stub_parent_classes(monkeypatch):
+ # Stub Transolver.__init__ and Transolver.forward
+ def transolver_init(self, *args, **kwargs):
+ torch.nn.Module.__init__(self)
+
+ def transolver_forward(self, fx=None, embedding=None, time=None):
+ # Match shapes expected downstream: return zeros like embedding
+ assert embedding is not None
+ return torch.zeros_like(embedding)
+
+ monkeypatch.setattr(rollout.Transolver, "__init__", transolver_init, raising=True)
+ monkeypatch.setattr(rollout.Transolver, "forward", transolver_forward, raising=True)
+
+ # Stub MeshGraphNet.__init__ and MeshGraphNet.forward
+ def mgn_init(self, *args, **kwargs):
+ torch.nn.Module.__init__(self)
+
+ def mgn_forward(self, node_features=None, edge_features=None, graph=None):
+ # Return zeros acceleration with shape [N, 3]
+ assert node_features is not None
+ N = node_features.shape[0]
+ return torch.zeros(N, 3, dtype=node_features.dtype, device=node_features.device)
+
+ monkeypatch.setattr(rollout.MeshGraphNet, "__init__", mgn_init, raising=True)
+ monkeypatch.setattr(rollout.MeshGraphNet, "forward", mgn_forward, raising=True)
+
+
+def test_transolver_autoregressive_rollout_eval():
+ N, T, F = 5, 4, 2
+ sample = make_sample(N=N, T=T, F=F)
+ stats = make_data_stats()
+
+ model = rollout.TransolverAutoregressiveRolloutTraining(
+ dt=5e-3, initial_vel=torch.zeros(1, 3), num_time_steps=T
+ )
+ model.eval()
+
+ out = model.forward(sample=sample, data_stats=stats)
+ assert out.shape == (T - 1, N, 3)
+
+
+def test_transolver_time_conditional_rollout_eval():
+ N, T, F = 6, 5, 3
+ sample = make_sample(N=N, T=T, F=F)
+ stats = make_data_stats()
+
+ model = rollout.TransolverTimeConditionalRollout(num_time_steps=T)
+ model.eval()
+
+ out = model.forward(sample=sample, data_stats=stats)
+ assert out.shape == (T - 1, N, 3)
+
+
+def test_transolver_one_step_rollout_eval():
+ N, T, F = 7, 6, 1
+ sample = make_sample(N=N, T=T, F=F)
+ stats = make_data_stats()
+
+ model = rollout.TransolverOneStepRollout(
+ dt=5e-3, initial_vel=torch.zeros(1, 3), num_time_steps=T
+ )
+ model.eval()
+
+ out = model.forward(sample=sample, data_stats=stats)
+ assert out.shape == (T - 1, N, 3)
+
+
+def test_meshgraphnet_autoregressive_rollout_eval():
+ N, T, F = 4, 4, 2
+ sample = make_sample(N=N, T=T, F=F)
+ stats = make_data_stats()
+
+ model = rollout.MeshGraphNetAutoregressiveRolloutTraining(
+ dt=5e-3, initial_vel=torch.zeros(1, 3), num_time_steps=T
+ )
+ model.eval()
+
+ out = model.forward(sample=sample, data_stats=stats)
+ assert out.shape == (T - 1, N, 3)
+
+
+def test_meshgraphnet_time_conditional_rollout_eval():
+ N, T, F = 3, 5, 4
+ sample = make_sample(N=N, T=T, F=F)
+ stats = make_data_stats()
+
+ model = rollout.MeshGraphNetTimeConditionalRollout(num_time_steps=T)
+ model.eval()
+
+ out = model.forward(sample=sample, data_stats=stats)
+ assert out.shape == (T - 1, N, 3)
+
+
+def test_meshgraphnet_one_step_rollout_eval():
+ N, T, F = 8, 3, 0
+ # allow zero features
+ torch.manual_seed(0)
+ coords = torch.randn(N, 3)
+ future = torch.randn(N, (T - 1) * 3)
+
+ class DummyGraph:
+ pass
+
+ graph = DummyGraph()
+ graph.edge_attr = torch.zeros(0, 1)
+
+ node_inputs = {"coords": coords, "features": coords.new_zeros((N, 0))}
+ sample = SimSample(node_features=node_inputs, node_target=future, graph=graph)
+ stats = make_data_stats()
+
+ model = rollout.MeshGraphNetOneStepRollout(
+ dt=5e-3, initial_vel=torch.zeros(1, 3), num_time_steps=T
+ )
+ model.eval()
+
+ out = model.forward(sample=sample, data_stats=stats)
+ assert out.shape == (T - 1, N, 3)
diff --git a/examples/structural_mechanics/crash/train.py b/examples/structural_mechanics/crash/train.py
index 31caeffb33..47ef0331f8 100644
--- a/examples/structural_mechanics/crash/train.py
+++ b/examples/structural_mechanics/crash/train.py
@@ -66,12 +66,16 @@ def __init__(self, cfg: DictConfig, logger0: RankZeroLoggingWrapper):
)
# Dataset
+ reader = instantiate(cfg.reader)
+ logging.getLogger().setLevel(logging.INFO)
dataset = instantiate(
cfg.datapipe,
name="crash_train",
+ reader=reader,
split="train",
logger=logger0,
)
+ logging.getLogger().setLevel(logging.INFO)
# Move stats to device
self.data_stats = dict(
node={k: v.to(self.dist.device) for k, v in dataset.node_stats.items()},
@@ -79,8 +83,9 @@ def __init__(self, cfg: DictConfig, logger0: RankZeroLoggingWrapper):
k: v.to(self.dist.device)
for k, v in getattr(dataset, "edge_stats", {}).items()
},
- thickness={
- k: v.to(self.dist.device) for k, v in dataset.thickness_stats.items()
+ feature={
+ k: v.to(self.dist.device)
+ for k, v in getattr(dataset, "feature_stats", {}).items()
},
)
diff --git a/examples/structural_mechanics/crash/vtp_reader.py b/examples/structural_mechanics/crash/vtp_reader.py
new file mode 100644
index 0000000000..ec8a8eea96
--- /dev/null
+++ b/examples/structural_mechanics/crash/vtp_reader.py
@@ -0,0 +1,234 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+import numpy as np
+import pyvista as pv
+
+
+def find_run_folders(base_data_dir):
+ """Return a list of absolute VTP file paths; each file is a separate sample."""
+ if not os.path.isdir(base_data_dir):
+ return []
+ vtps = [
+ os.path.join(base_data_dir, f)
+ for f in os.listdir(base_data_dir)
+ if f.lower().endswith(".vtp")
+ ]
+
+ def natural_key(name):
+ return [
+ int(s) if s.isdigit() else s.lower()
+ for s in re.findall(r"\d+|\D+", os.path.basename(name))
+ ]
+
+ return sorted(vtps, key=natural_key)
+
+
+def extract_mesh_connectivity_from_polydata(poly: pv.PolyData):
+ """Extract mesh connectivity (list of cells with node indices) from a PolyData."""
+ faces = poly.faces
+ connectivity = []
+ i = 0
+ n = faces.size
+ while i < n:
+ fsz = int(faces[i])
+ ids = faces[i + 1 : i + 1 + fsz].tolist()
+ if len(ids) >= 3:
+ connectivity.append(ids)
+ i += 1 + fsz
+ return connectivity
+
+
+def load_vtp_file(vtp_path):
+ """Load positions over time and connectivity from a single VTP file.
+
+ Expects displacement fields in point_data named like:
+ - displacement_t0.000, displacement_t0.005, ..., displacement_t0.100
+ Returns:
+ pos_raw: (timesteps, num_nodes, 3) absolute positions (coords + displacement_t)
+ mesh_connectivity: list[list[int]]
+ """
+ poly = pv.read(vtp_path)
+ if not isinstance(poly, pv.PolyData):
+ poly = poly.extract_surface().cast_to_polydata()
+
+ coords = np.array(poly.points, dtype=np.float64)
+
+ # Collect displacement vector arrays (3 components) and sort naturally
+ disp_names = [
+ name
+ for name in poly.point_data.keys()
+ if re.match(r"displacement_t0\.[0-9]{3}$", name)
+ ]
+ if not disp_names:
+ disp_names = [
+ name for name in poly.point_data.keys() if name.startswith("displacement_t")
+ ]
+ if not disp_names:
+ raise ValueError(f"No displacement fields found in {vtp_path}")
+
+ def natural_key(name):
+ return [
+ int(s) if s.isdigit() else s.lower() for s in re.findall(r"\d+|\D+", name)
+ ]
+
+ disp_names = sorted(disp_names, key=natural_key)
+
+ pos_list = []
+ for idx, name in enumerate(disp_names):
+ disp = np.asarray(poly.point_data[name])
+ if disp.ndim != 2 or disp.shape[1] != 3:
+ raise ValueError(
+ f"Point-data array '{name}' must be a 3-component vector (got shape {disp.shape})."
+ )
+ # Force zero displacement at t0: pos_raw[0] = coords
+ if idx == 0:
+ pos_list.append(coords)
+ else:
+ pos_list.append(coords + disp)
+
+ pos_raw = np.stack(pos_list, axis=0)
+ mesh_connectivity = extract_mesh_connectivity_from_polydata(poly)
+ return pos_raw, mesh_connectivity
+
+
+def build_edges_from_mesh_connectivity(mesh_connectivity):
+ """Build unique edges from mesh connectivity (cells of any size)."""
+ edges = set()
+ for cell in mesh_connectivity:
+ n = len(cell)
+ for idx in range(n):
+ edge = tuple(sorted((cell[idx], cell[(idx + 1) % n])))
+ edges.add(edge)
+ return edges
+
+
+def collect_mesh_pos(
+ output_dir, pos_raw, filtered_mesh_connectivity, write_vtp=False, logger=None
+):
+ """Write VTP files for each timestep and collect mesh/point data."""
+ n_timesteps = pos_raw.shape[0]
+ mesh_pos_all = []
+ pos0 = pos_raw[0] # reference for displacement
+ for t in range(n_timesteps):
+ pos = pos_raw[t, :, :]
+
+ faces = []
+ for cell in filtered_mesh_connectivity:
+ if len(cell) == 3:
+ faces.extend([3, *cell])
+ elif len(cell) == 4:
+ faces.extend([4, *cell])
+ elif len(cell) > 4:
+ continue
+
+ faces = np.array(faces)
+ mesh = pv.PolyData(pos, faces)
+
+ # Add displacement vector relative to t0
+ disp = pos - pos0
+ mesh.point_data["displacement"] = disp
+
+ if write_vtp:
+ filename = os.path.join(output_dir, f"frame_{t:03d}.vtp")
+ mesh.save(filename)
+ if write_vtp and logger:
+ logger.info(f"Saved: {filename}")
+
+ mesh_pos_all.append(pos)
+ return np.stack(mesh_pos_all)
+
+
+def process_vtp_data(data_dir, num_samples=2, write_vtp=False, logger=None):
+ """
+ Preprocesses VTP crash simulation data in a given directory.
+ Each .vtp file is treated as one sample. For each sample, computes edges from connectivity,
+ keeps all nodes, and optionally writes VTP files for each timestep.
+ Returns lists of source/destination node indices and point data for all samples.
+ """
+ processed_runs = 0
+ base_data_dir = data_dir
+ vtp_files = find_run_folders(base_data_dir)
+ srcs, dsts = [], []
+ point_data_all = []
+
+ if not vtp_files:
+ if logger:
+ logger.error("No .vtp files found in:", base_data_dir)
+ exit(1)
+
+ for vtp_path in vtp_files:
+ if logger:
+ logger.info(f"Processing {vtp_path}...")
+ output_dir = f"./output_{os.path.splitext(os.path.basename(vtp_path))[0]}"
+ os.makedirs(output_dir, exist_ok=True)
+
+ pos_raw, mesh_connectivity = load_vtp_file(vtp_path)
+
+ # Use unfiltered data
+ filtered_pos_raw = pos_raw
+ filtered_mesh_connectivity = mesh_connectivity
+
+ # Build edges and sanity-check ranges
+ edges = build_edges_from_mesh_connectivity(filtered_mesh_connectivity)
+ edge_arr = np.array(list(edges), dtype=np.int64)
+ assert edge_arr.min() >= 0 and edge_arr.max() < filtered_pos_raw.shape[1]
+
+ src, dst = np.array(list(edges)).T
+ srcs.append(src)
+ dsts.append(dst)
+
+ mesh_pos_all = collect_mesh_pos(
+ output_dir,
+ filtered_pos_raw,
+ filtered_mesh_connectivity,
+ write_vtp=write_vtp,
+ logger=logger,
+ )
+ point_data_all.append({"coords": mesh_pos_all})
+
+ processed_runs += 1
+ if processed_runs >= num_samples:
+ break
+
+ return srcs, dsts, point_data_all
+
+
+class Reader:
+ """
+ Reader for VTP files.
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(
+ self,
+ data_dir: str,
+ num_samples: int,
+ split: str | None = None,
+ logger=None,
+ **kwargs,
+ ):
+ write_vtp = False if split == "train" else True
+ return process_vtp_data(
+ data_dir=data_dir,
+ num_samples=num_samples,
+ write_vtp=write_vtp,
+ logger=logger,
+ )