
Commit 06d6aa3

Merge branch 'NVIDIA:main' into interp-model-example
2 parents 2a6c06a + 470e6fa commit 06d6aa3

41 files changed (+3836, -138 lines)

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -22,6 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
   models. Accessible in `examples/geophysics/diffusion_fwi`.
 - Domain Parallelism: Domain Parallelism is now available for kNN, radius_search,
   and torch.nn.functional.pad.
+- Unified recipe for crash modeling, supporting Transolver and MeshGraphNet,
+  and three transient schemes.
 - Added a check to `stochastic_sampler` that helps handle the `EDMPrecond` model,
   which has a specific `.forward()` signature
```

examples/cfd/external_aerodynamics/domino/README.md

Lines changed: 8 additions & 13 deletions
````diff
@@ -188,13 +188,14 @@ GPUs and perform operations in a numerically consistent way. For more information
 about the techniques of domain parallelism and `ShardTensor`, refer to PhysicsNeMo
 tutorials such as [`ShardTensor`](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.distributed.shardtensor.html).
 
-In DoMINO specifically, domain parallelism has been abled in two ways, which
+In DoMINO specifically, domain parallelism has been enabled in two ways, which
 can be used concurrently or separately. First, the input sampled volumetric
 and surface points can be sharded to accommodate higher resolution point sampling.
 Second, the latent space of the model - typically a regularized grid - can be
 sharded to reduce computational complexity of the latent processing. When training
 with sharded models in DoMINO, the primary objective is to enable higher
-resolution inputs and larger latent spaces without sacrificing substantial compute time.
+resolution inputs and larger latent spaces without sacrificing
+substantial compute time.
 
 When configuring DoMINO for sharded training, adjust the following parameters
 from `src/conf/config.yaml`:
@@ -207,19 +208,13 @@ domain_parallelism:
 ```
 
 The `domain_size` represents the number of GPUs used for each batch - setting
-`domain_size: 1` is not advised since that is the standard training regime,
-but with extra overhead. `shard_grid` and `shard_points` will enable domain
+`domain_size: 1` is the standard training regime, and `domain_parallelism`
+will be ignored. `shard_grid` and `shard_points` will enable domain
 parallelism over the latent space and input/output points, respectively.
 
-As one last note regarding domain-parallel training: in the phase of DoMINO
-where the output solutions are calculated, the model can use two different
-techniques (numerically identical) to calculate the output. Due to the
-overhead of potential communication at each operation, it's recommended to
-use the `one-loop` mode with `model.solution_calculation_mode` when doing
-sharded training. This technique launches vectorized kernels with less
-launch overhead at the cost of more memory use. For non-sharded
-training, the `two-loop` setting is more optimal. The difference in `one-loop`
-or `two-loop` is purely computational, not algorithmic.
+Setting `domain_size > 1` without specifying `shard_points=True` or `shard_grid=True`
+will result in a runtime error during configuration - if you do not want to use
+domain parallelism, leave `domain_size=1`.
 
 ### Performance Optimizations
 
````
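For orientation, here is a hypothetical sketch of the `domain_parallelism` block in `src/conf/config.yaml` that this hunk refers to; the field names come from the diff above, while the values and comments are illustrative assumptions:

```yaml
# Illustrative values only; field names follow the README diff above.
domain_parallelism:
  domain_size: 2      # GPUs per batch; leave at 1 to disable domain parallelism
  shard_grid: true    # shard the latent grid across the domain mesh
  shard_points: true  # shard input/output point clouds across the domain mesh
```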
examples/cfd/external_aerodynamics/domino/src/train.py

Lines changed: 30 additions & 13 deletions
```diff
@@ -44,6 +44,9 @@
 
 import torchinfo
 import torch.distributed as dist
+from torch.distributed.fsdp import fully_shard
+from torch.distributed.tensor import distribute_module
+
 from torch.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader
@@ -333,6 +336,13 @@ def main(cfg: DictConfig) -> None:
     # how to set that up, if needed.
     domain_mesh, data_mesh, placements = coordinate_distributed_environment(cfg)
 
+    if data_mesh is not None:
+        data_replica_size = data_mesh.size()
+        data_rank = data_mesh.get_local_rank()
+    else:
+        data_replica_size = dist.world_size
+        data_rank = dist.rank
+
     ################################
     # Initialize NVML
     ################################
@@ -438,8 +448,8 @@ def main(cfg: DictConfig) -> None:
     )
     train_sampler = DistributedSampler(
         train_dataloader,
-        num_replicas=data_mesh.size(),
-        rank=data_mesh.get_local_rank(),
+        num_replicas=data_replica_size,
+        rank=data_rank,
         **cfg.train.sampler,
     )
 
@@ -458,8 +468,8 @@
     )
     val_sampler = DistributedSampler(
         val_dataloader,
-        num_replicas=data_mesh.size(),
-        rank=data_mesh.get_local_rank(),
+        num_replicas=data_replica_size,
+        rank=data_rank,
         **cfg.val.sampler,
     )
 
@@ -478,15 +488,22 @@
     logger.info(f"Model summary:\n{torchinfo.summary(model, verbose=0, depth=2)}\n")
 
     if dist.world_size > 1:
-        model = DistributedDataParallel(
-            model,
-            device_ids=[dist.local_rank],
-            output_device=dist.device,
-            broadcast_buffers=dist.broadcast_buffers,
-            find_unused_parameters=dist.find_unused_parameters,
-            gradient_as_bucket_view=True,
-            static_graph=True,
-        )
+        if domain_mesh is None:
+            model = DistributedDataParallel(
+                model,
+                device_ids=[dist.local_rank],
+                output_device=dist.device,
+                broadcast_buffers=dist.broadcast_buffers,
+                find_unused_parameters=dist.find_unused_parameters,
+                gradient_as_bucket_view=True,
+                static_graph=True,
+            )
+        else:
+            model = distribute_module(
+                model,
+                device_mesh=domain_mesh,
+            )
+            model = fully_shard(model, mesh=data_mesh)
 
     ######################################################
     # Initialize optimizer and gradient scaler
```
examples/cfd/external_aerodynamics/domino/src/utils.py

Lines changed: 64 additions & 36 deletions
```diff
@@ -170,44 +170,72 @@ def coordinate_distributed_environment(cfg: DictConfig):
     # Default to no domain parallelism:
     domain_size = cfg.get("domain_parallelism", {}).get("domain_size", 1)
 
-    # Initialize the device mesh:
-    mesh = dist.initialize_mesh(
-        mesh_shape=(-1, domain_size), mesh_dim_names=("ddp", "domain")
-    )
-    domain_mesh = mesh["domain"]
-    data_mesh = mesh["ddp"]
-
-    if domain_size > 1:
-        # Define the default placements for each tensor that might show up in
-        # the data. Note that we'll define placements for all keys, even if
-        # they aren't actually used.
-
-        # Note that placements are defined for pre-batched data, no batch index!
-
-        grid_like_placement = [
-            Shard(0),
-        ]
-        point_like_placement = [
-            Shard(0),
-        ]
-        replicate_placement = [
-            Replicate(),
-        ]
-        placements = {
-            "stl_coordinates": point_like_placement,
-            "stl_centers": point_like_placement,
-            "stl_faces": point_like_placement,
-            "stl_areas": point_like_placement,
-            "surface_fields": point_like_placement,
-            "volume_mesh_centers": point_like_placement,
-            "volume_fields": point_like_placement,
-            "surface_mesh_centers": point_like_placement,
-            "surface_normals": point_like_placement,
-            "surface_areas": point_like_placement,
-        }
-    else:
+    if dist.world_size == 1:
         domain_mesh = None
+        data_mesh = None
         placements = None
+    else:
+        # Initialize the device mesh:
+        mesh = dist.initialize_mesh(
+            mesh_shape=(-1, domain_size), mesh_dim_names=("ddp", "domain")
+        )
+        domain_mesh = mesh["domain"]
+        data_mesh = mesh["ddp"]
+
+        if domain_size > 1:
+            # Define the default placements for each tensor that might show up in
+            # the data. Note that we'll define placements for all keys, even if
+            # they aren't actually used.
+
+            # Note that placements are defined for pre-batched data, no batch index!
+
+            shard_grid = cfg.get("domain_parallelism", {}).get("shard_grid", False)
+            shard_points = cfg.get("domain_parallelism", {}).get("shard_points", False)
+
+            if not shard_grid and not shard_points:
+                raise ValueError(
+                    "Either shard_grid or shard_points must be True if domain_size > 1"
+                )
+
+            # Not supported with physics loss:
+            if cfg.train.add_physics_loss:
+                raise ValueError(
+                    "Domain parallelism is not supported with physics loss"
+                )
+
+            if shard_grid:
+                grid_like_placement = [
+                    Shard(0),
+                ]
+            else:
+                grid_like_placement = [
+                    Replicate(),
+                ]
+
+            if shard_points:
+                point_like_placement = [
+                    Shard(0),
+                ]
+            else:
+                point_like_placement = [
+                    Replicate(),
+                ]
+
+            placements = {
+                "stl_coordinates": point_like_placement,
+                "stl_centers": point_like_placement,
+                "stl_faces": point_like_placement,
+                "stl_areas": point_like_placement,
+                "surface_fields": point_like_placement,
+                "volume_mesh_centers": point_like_placement,
+                "volume_fields": point_like_placement,
+                "surface_mesh_centers": point_like_placement,
+                "surface_normals": point_like_placement,
+                "surface_areas": point_like_placement,
+            }
+        else:
+            domain_mesh = None
+            placements = None
 
     return domain_mesh, data_mesh, placements
 
```
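To make the role of the returned `placements` concrete, here is a minimal sketch (not the example's own code) of how such a dict could be applied to a batch. The example's pipeline builds on PhysicsNeMo's `ShardTensor` (see the README above); plain PyTorch `distribute_tensor` is used here purely for illustration:

```python
# Illustrative only: wrap each tensor in a batch as a DTensor using the
# placement configured for its key (Shard(0) or Replicate()).
from typing import Optional

from torch.distributed.tensor import distribute_tensor


def shard_batch(batch: dict, domain_mesh, placements: Optional[dict]) -> dict:
    if domain_mesh is None or placements is None:
        return batch  # single-GPU / non-sharded path
    return {
        key: distribute_tensor(value, device_mesh=domain_mesh, placements=placements[key])
        if key in placements
        else value
        for key, value in batch.items()
    }
```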
Lines changed: 154 additions & 0 deletions (new file)
<!-- markdownlint-disable -->
# Machine Learning Surrogates for Automotive Crash Dynamics

## Problem Overview

Automotive crashworthiness assessment is a critical step in vehicle design.
Traditionally, engineers rely on high-fidelity finite element (FE)
simulations (e.g., LS-DYNA) to predict structural deformation and crash responses.
While accurate, these simulations are computationally expensive and
limit the speed of design iterations.

Machine Learning (ML) surrogates provide a promising alternative by learning
mappings directly from simulation data, enabling:

- **Rapid prediction** of deformation histories across thousands of design candidates.
- **Scalability** to large structural models without rerunning costly FE simulations.
- **Flexibility** in experimenting with different model architectures (GNNs, Transformers).

In this example, we demonstrate a unified pipeline for crash dynamics modeling.
The implementation supports both:

- **Mesh-based Graph Neural Networks (MeshGraphNet)** – leverages connectivity from FE meshes.
- **Point-cloud Transformers (Transolver)** – avoids explicit mesh dependency.

## Prerequisites

This example requires:

- Access to LS-DYNA crash datasets (with `d3plot` and `.k` keyword files).
- A GPU-enabled environment with PyTorch.

Install the dependencies:

```bash
pip install -r requirements.txt
```

This installs:

- `lasso-python` (for LS-DYNA file parsing)
- `torch_geometric` and `torch_scatter` (for GNN operations)
## Dataset Preprocessing

Crash simulation data is parsed from LS-DYNA `d3plot` files using the
`d3plot_reader.py` utility.

Key steps:

- Load node coordinates, displacements, element connectivity, and part IDs.
- Parse `.k` keyword files to assign part thickness values.
- Filter out rigid-wall nodes using displacement thresholds.
- Build edges (for graphs) and store per-node features (e.g., thickness).
- Optionally export time-stepped meshes as `.vtp` files for visualization.

Preprocessing runs automatically via the dataset class (`CrashGraphDataset` or
`CrashPointCloudDataset`) when training or inference is launched; a minimal
parsing sketch is shown below.
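As an illustration of the first step only, here is a hypothetical sketch of reading nodal data with `lasso-python`; the example's actual parser is `d3plot_reader.py`, and the file path below is a placeholder:

```python
# Illustrative only: read nodal kinematics and shell connectivity from an
# LS-DYNA d3plot with lasso-python. The path is a placeholder.
from lasso.dyna import ArrayType, D3plot

d3plot = D3plot("path/to/run_001/d3plot")

# (n_nodes, 3) reference coordinates; (n_timesteps, n_nodes, 3) displacements.
coords = d3plot.arrays[ArrayType.node_coordinates]
displacement = d3plot.arrays[ArrayType.node_displacement]

# Shell connectivity and per-element part IDs: the ingredients for building
# graph edges and for assigning part thickness parsed from the .k file.
shell_nodes = d3plot.arrays[ArrayType.element_shell_node_indexes]
shell_parts = d3plot.arrays[ArrayType.element_shell_part_indexes]

print(coords.shape, displacement.shape, shell_nodes.shape, shell_parts.shape)
```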
## Training

Training is managed via Hydra configurations located in `conf/`.
The main script is `train.py`.

### Config Structure

```bash
conf/
├── config.yaml              # master config (sets datapipe, model, training)
├── datapipe/                # dataset configs
│   ├── graph.yaml
│   └── point_cloud.yaml
├── model/                   # model configs
│   ├── mgn_autoregressive_rollout_training.yaml
│   ├── mgn_one_step_rollout.yaml
│   ├── mgn_time_conditional.yaml
│   ├── transolver_autoregressive_rollout_training.yaml
│   ├── transolver_one_step_rollout.yaml
│   └── transolver_time_conditional.yaml
├── training/default.yaml    # training hyperparameters
└── inference/default.yaml   # inference options
```
### Launch Training

Single GPU:

```bash
python train.py
```

Multi-GPU (Distributed Data Parallel):

```bash
torchrun --standalone --nproc_per_node=<NUM_GPUS> train.py
```
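Model and datapipe variants are selected through the config groups shown above. Assuming standard Hydra override syntax (an assumption; the README does not spell this out), a specific combination could be chosen at launch:

```bash
# Hypothetical Hydra overrides; group and option names follow the conf/ tree above.
python train.py model=transolver_time_conditional datapipe=point_cloud
```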
## Inference

Use `inference.py` to evaluate trained models on test crash runs.

```bash
python inference.py
```

Predicted meshes are written as `.vtp` files under
`./predicted_vtps/` and can be opened with ParaView.
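For scripted inspection instead of ParaView, the `.vtp` outputs can also be loaded programmatically. A hypothetical sketch with `pyvista` (not listed in `requirements.txt`, so it would need to be installed separately; the file name is a placeholder):

```python
# Illustrative only: load one predicted .vtp and list its point-data arrays.
import pyvista as pv

mesh = pv.read("./predicted_vtps/run_001/<timestep>.vtp")  # placeholder name
print(mesh.n_points, list(mesh.point_data.keys()))
```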
## Postprocessing and Evaluation

The `postprocessing/` folder provides scripts for quantitative and qualitative evaluation:

- **Relative $L^2$ error** (`compute_l2_error.py`): computes the per-timestep
  relative position error across runs and produces plots and optional CSVs.

Example:

```bash
python postprocessing/compute_l2_error.py \
    --predicted_parent ./predicted_vtps \
    --exact_parent ./exact_vtps \
    --output_plot rel_error.png \
    --output_csv rel_error.csv
```
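The script's exact normalization is not documented here; a common definition for a per-timestep relative $L^2$ position error of this kind is

$$
\varepsilon_t = \frac{\lVert \mathbf{u}_t^{\mathrm{pred}} - \mathbf{u}_t^{\mathrm{true}} \rVert_2}{\lVert \mathbf{u}_t^{\mathrm{true}} \rVert_2},
$$

where $\mathbf{u}_t$ stacks the nodal positions of a run at timestep $t$.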
- **Probe kinematics, driver vs. passenger toe pan** (`compute_probe_kinematics.py`):
  extracts displacement/velocity/acceleration histories at selected probe nodes
  and generates comparison plots (ground truth vs. predicted).

Example:

```bash
python postprocessing/compute_probe_kinematics.py \
    --pred_dir ./predicted_vtps/run_001 \
    --exact_dir ./exact_vtps/run_001 \
    --driver_points "70658-70659,70664" \
    --passenger_points "70676-70679" \
    --dt 0.005 \
    --output_plot probe_kinematics.png
```
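Velocity and acceleration must be recovered from the displacement history sampled at `--dt`. A minimal sketch of that differentiation step, assuming finite differences (the script's internals may differ):

```python
# Illustrative only: finite-difference kinematics for one probe node.
import numpy as np

dt = 0.005  # matches the --dt flag above
disp = np.cumsum(np.random.rand(100, 3) * 1e-3, axis=0)  # placeholder history

vel = np.gradient(disp, dt, axis=0)  # central differences in time
acc = np.gradient(vel, dt, axis=0)

print(vel.shape, acc.shape)
```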
- **Cross-sectional plots** (`plot_cross_section.py`): plots 2D slices of
  predicted vs. ground truth deformations at specified cross-sections.

Example:

```bash
python postprocessing/plot_cross_section.py \
    --pred_dir ./predicted_vtps/run_001 \
    --exact_dir ./exact_vtps/run_001 \
    --output_file cross_section.png
```

`run_post_processing.sh` can automate all evaluation tasks across runs.
## References

- [Automotive Crash Dynamics Modeling Accelerated with Machine Learning](https://arxiv.org/pdf/2510.15201)
