Commit 3f7a8a4

XAeroNet (#692)

* adding xaeronet-s model
* add validation plots
* xaeronet-v model
* formatting
* update changelog
* remove json file
* address review comments
* multi-scale support, minor fixes

1 parent 297297e · commit 3f7a8a4

23 files changed, +3128 −8 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Bistride Multiscale MeshGraphNet example.
 - FIGConvUNet model and example.
 - The Transolver model.
+- The XAeroNet model.
 - Incorporated CorrDiff-GEFS-HRRR model into CorrDiff, with lead-time aware SongUNet and
   cross entropy loss.

docs/img/xaeronet_s_results.png

1.44 MB

docs/img/xaeronet_v_results.png

1.33 MB

examples/cfd/xaeronet/README.md

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
# XAeroNet: Scalable Neural Models for External Aerodynamics

XAeroNet is a collection of scalable models for large-scale external
aerodynamic evaluations. It consists of two models, XAeroNet-S and XAeroNet-V,
for surface and volume predictions, respectively.

## Problem overview

External aerodynamics plays a crucial role in the design and optimization of vehicles,
aircraft, and other transportation systems. Accurate predictions of aerodynamic
properties such as drag, pressure distribution, and airflow characteristics are
essential for improving fuel efficiency, vehicle stability, and performance.
Traditional approaches, such as computational fluid dynamics (CFD) simulations,
are computationally expensive and time-consuming, especially when evaluating multiple
design iterations or large datasets.

XAeroNet addresses these challenges by leveraging neural network-based surrogate
models to provide fast, scalable, and accurate predictions for both surface-level
and volume-level aerodynamic properties. By using the DrivAerML dataset, which
contains high-fidelity CFD data for a variety of vehicle geometries, XAeroNet aims
to significantly reduce the computational cost while maintaining high prediction
accuracy. The two models in XAeroNet—XAeroNet-S for surface predictions and XAeroNet-V
for volume predictions—enable rapid aerodynamic evaluations across different design
configurations, making it easier to incorporate aerodynamic considerations early in
the design process.
## Model Overview and Architecture

### XAeroNet-S

XAeroNet-S is a scalable MeshGraphNet model that partitions large input graphs into
smaller subgraphs to reduce training memory overhead. Halo regions are added to these
subgraphs to prevent message-passing truncations at the boundaries. Gradient aggregation
is employed to accumulate gradients from each partition before updating the model parameters.
This approach ensures that training on partitions is equivalent to training on the entire
graph in terms of model updates and accuracy. Additionally, XAeroNet-S does not rely on
simulation meshes for training and inference, overcoming a significant limitation of
GNN models in simulation tasks.
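
How partitioned training stays equivalent to full-graph training is easiest to see in
code. Below is a minimal PyTorch-style sketch, not the actual Modulus implementation
(the `partitions` structure and `interior_mask` are illustrative): each subgraph carries
its halo nodes so message passing near a cut is exact, the loss is evaluated only on
interior nodes, and gradients accumulate across all partitions before a single optimizer
step.

```python
import torch

def train_step(model, optimizer, loss_fn, partitions):
    """One optimizer step over a partitioned graph (illustrative sketch).

    Each element of `partitions` is (subgraph, features, targets, interior_mask),
    where the subgraph already includes halo nodes copied from neighboring
    partitions so that message passing at the cut boundaries is not truncated.
    """
    optimizer.zero_grad()
    for subgraph, features, targets, interior_mask in partitions:
        pred = model(subgraph, features)
        # Halo nodes only supply messages; the loss is taken on interior nodes.
        # Dividing by the partition count makes the accumulated gradient match the
        # full-graph gradient of the mean loss (exactly so when partitions have
        # equal interior sizes).
        loss = loss_fn(pred[interior_mask], targets[interior_mask]) / len(partitions)
        loss.backward()  # gradients accumulate in the .grad buffers
    optimizer.step()  # single update over the accumulated gradients
```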

The input to the training pipeline is STL files, from which the model samples a point cloud
on the surface. It then constructs a connectivity graph by linking the N nearest neighbors.
This method also supports multi-mesh setups, where point clouds with different resolutions
are generated, their connectivity graphs are created, and all are superimposed. The METIS
library is used to partition the graph for efficient training.
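
As a concrete illustration of the graph-construction step, the following sketch builds
the N-nearest-neighbor connectivity for a sampled point cloud with SciPy (the helper
name is ours; the actual pipeline may use a different kNN backend):

```python
import numpy as np
from scipy.spatial import cKDTree

def knn_edges(points: np.ndarray, n_neighbors: int) -> np.ndarray:
    """Directed edges linking each surface point to its N nearest neighbors.

    points: (P, 3) point cloud sampled from the STL surface.
    Returns a (P * n_neighbors, 2) array of (source, target) node indices.
    """
    tree = cKDTree(points)
    # Query n_neighbors + 1 because the closest match to each point is itself.
    _, neighbors = tree.query(points, k=n_neighbors + 1)
    src = np.repeat(np.arange(len(points)), n_neighbors)
    dst = neighbors[:, 1:].reshape(-1)  # drop the self-match in column 0
    return np.stack([src, dst], axis=1)

# For a multi-mesh setup, repeat this for point clouds sampled at several
# resolutions and superimpose the resulting connectivity graphs; the combined
# graph is then partitioned with METIS for training.
```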

For the XAeroNet-S model, STL files are used to generate point clouds and establish graph
connectivity. Additionally, the .vtp files are used to interpolate the solution fields onto
the point clouds.
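
A minimal way to realize this interpolation is a nearest-neighbor transfer from the .vtp
mesh points to the sampled cloud; the actual preprocessing may use a proper VTK probe
filter instead, so treat this as a sketch:

```python
import numpy as np
from scipy.spatial import cKDTree

def transfer_surface_fields(vtp_points, vtp_fields, cloud_points):
    """Nearest-neighbor transfer of surface solution fields onto a point cloud.

    vtp_points:   (M, 3) coordinates of the .vtp surface mesh points.
    vtp_fields:   (M, F) per-point solution values (e.g., pressure, wall shear).
    cloud_points: (P, 3) sampled point-cloud coordinates.
    Returns a (P, F) array of fields transferred onto the cloud.
    """
    tree = cKDTree(vtp_points)
    _, nearest = tree.query(cloud_points, k=1)  # index of the closest mesh point
    return vtp_fields[nearest]
```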

### XAeroNet-V

XAeroNet-V is a scalable 3D UNet model with attention gates, designed to partition large
voxel grids into smaller sub-grids to reduce memory overhead during training. Halo regions
are added to these partitions to avoid convolution truncations at the boundaries.
Gradient aggregation is used to accumulate gradients from each partition before updating
the model parameters, ensuring that training on partitions is equivalent to training on
the entire voxel grid in terms of model updates and accuracy. Additionally, XAeroNet-V
incorporates a continuity constraint as an additional loss term during training to
enhance model interpretability.
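
The continuity constraint penalizes the divergence of the predicted velocity field, which
should vanish for incompressible flow. Below is a minimal sketch, assuming the model
predicts (u, v, w) on a uniform voxel grid; the exact formulation in XAeroNet-V may differ.

```python
import torch

def continuity_loss(velocity: torch.Tensor, dx: float) -> torch.Tensor:
    """Penalize the divergence of a predicted velocity field.

    velocity: (B, 3, D, H, W) tensor of (u, v, w) on a uniform voxel grid.
    dx:       voxel spacing.
    For incompressible flow, du/dx + dv/dy + dw/dz should vanish.
    """
    # Central differences over the grid interior.
    du_dx = (velocity[:, 0, 2:, 1:-1, 1:-1] - velocity[:, 0, :-2, 1:-1, 1:-1]) / (2 * dx)
    dv_dy = (velocity[:, 1, 1:-1, 2:, 1:-1] - velocity[:, 1, 1:-1, :-2, 1:-1]) / (2 * dx)
    dw_dz = (velocity[:, 2, 1:-1, 1:-1, 2:] - velocity[:, 2, 1:-1, 1:-1, :-2]) / (2 * dx)
    divergence = du_dx + dv_dy + dw_dz
    return torch.mean(divergence ** 2)
```

In training, such a term would be added to the data loss with a weighting factor, e.g.
`loss = data_loss + lambda_cont * continuity_loss(pred_velocity, dx)`.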

For the XAeroNet-V model, the .vtu files are used to interpolate the volumetric
solution fields onto a voxel grid, while the .stl files are utilized to compute
the signed distance field (SDF) and its derivatives on the voxel grid.
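
For illustration, an SDF and its derivatives can be computed from an STL with trimesh and
NumPy as below. Note that this sketch grids only the body's bounding box at an arbitrary
resolution, whereas the actual pipeline voxelizes the full flow domain:

```python
import numpy as np
import trimesh

def voxel_sdf(stl_path: str, resolution: int = 64):
    """Signed distance field of an STL geometry on a uniform voxel grid.

    Assumes a single-solid, watertight STL (see combine_stl_solids.py).
    Returns the SDF array and its spatial derivatives along each axis.
    """
    mesh = trimesh.load_mesh(stl_path)
    lo, hi = mesh.bounds
    axes = [np.linspace(lo[i], hi[i], resolution) for i in range(3)]
    grid = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1).reshape(-1, 3)
    # trimesh convention: positive inside the surface, negative outside.
    sdf = trimesh.proximity.signed_distance(mesh, grid)
    sdf = sdf.reshape(resolution, resolution, resolution)
    # SDF derivatives via central differences with per-axis grid spacing.
    spacing = [(hi[i] - lo[i]) / (resolution - 1) for i in range(3)]
    grads = np.gradient(sdf, *spacing)
    return sdf, grads
```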

## Dataset

We trained our models using the DrivAerML dataset from the [CAE ML Dataset collection](https://caemldatasets.org/drivaerml/).
This high-fidelity, open-source (CC-BY-SA) public dataset is specifically designed
for automotive aerodynamics research. It comprises 500 parametrically morphed variants
of the widely used DrivAer notchback generic vehicle. Mesh generation and scale-resolving
computational fluid dynamics (CFD) simulations were executed using consistent and validated
automatic workflows that represent the industrial state of the art. Geometries and comprehensive
aerodynamic data are published in open-source formats. For more technical details about this
dataset, please refer to the [paper](https://arxiv.org/pdf/2408.11969).

## Training the XAeroNet-S model

To train the XAeroNet-S model, follow these steps:

1. Download the DrivAerML dataset using the provided `download_aws_dataset.sh` script.

2. Navigate to the `surface` folder.

3. Specify the configurations in `conf/config.yaml`. Make sure the path to the dataset
   is specified correctly.

4. Run `combine_stl_solids.py`. The STL files in the DrivAerML dataset consist of multiple
   solids. Those should be combined into a single solid to properly generate a surface point
   cloud using the Modulus Tesselated geometry module.

5. Run `preprocessing.py`. This will prepare and save the partitioned graphs.

6. Create a `partitions_validation` folder, and move the samples you wish to use for
   validation to that folder.

7. Run `compute_stats.py` to compute the global mean and standard deviation from the
   training samples (a minimal sketch of this accumulation appears after this list).

8. Run `train.py` to start the training.

9. Download the validation results (saved in the form of point clouds in `.vtp` format),
   and visualize them in ParaView.

![XAeroNet-S validation results for sample #500.](../../../docs/img/xaeronet_s_results.png)
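
As referenced in step 7, global statistics can be accumulated across many training samples
without loading the entire dataset at once. The sketch below is illustrative; the file
format and field name are placeholders, not the actual `compute_stats.py` interface.

```python
import numpy as np

def global_mean_std(files, field):
    """Streaming global mean and standard deviation over many samples."""
    count, total, total_sq = 0, 0.0, 0.0
    for path in files:
        data = np.load(path)[field]  # one training sample (placeholder format)
        count += data.size
        total += data.sum(dtype=np.float64)
        total_sq += np.square(data, dtype=np.float64).sum()
    mean = total / count
    std = np.sqrt(total_sq / count - mean**2)
    return mean, std
```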

## Training the XAeroNet-V model

To train the XAeroNet-V model, follow these steps:

1. Download the DrivAerML dataset using the provided `download_aws_dataset.sh` script.

2. Navigate to the `volume` folder.

3. Specify the configurations in `conf/config.yaml`. Make sure the path to the dataset
   is specified correctly.

4. Run `preprocessing.py`. This will prepare and save the voxel grids.

5. Create a `drivaer_aws_h5_validation` folder, and move the samples you wish to
   use for validation to that folder.

6. Run `compute_stats.py` to compute the global mean and standard deviation from
   the training samples.

7. Run `train.py` to start the training. Partitioning is performed prior to training.

8. Download the validation results (saved in the form of voxel grids in `.vti` format),
   and visualize them in ParaView.

![XAeroNet-V validation results.](../../../docs/img/xaeronet_v_results.png)

## Logging

We mainly use TensorBoard for logging training and validation losses, as well as
the learning rate during training. You can also optionally use Weights & Biases to
log training metrics. To visualize TensorBoard running in a
Docker container on a remote server from your local desktop, follow these steps:

1. **Expose the Port in Docker:**
   Expose port 6006 in the Docker container by including
   `-p 6006:6006` in your docker run command.

2. **Launch TensorBoard:**
   Start TensorBoard within the Docker container:

   ```bash
   tensorboard --logdir=/path/to/logdir --port=6006
   ```

3. **Set Up SSH Tunneling:**
   Create an SSH tunnel to forward port 6006 from the remote server to your local machine:

   ```bash
   ssh -L 6006:localhost:6006 <user>@<remote-server-ip>
   ```

   Replace `<user>` with your SSH username and `<remote-server-ip>` with the IP address
   of your remote server. You can use a different port if necessary.

4. **Access TensorBoard:**
   Open your web browser and navigate to `http://localhost:6006` to view TensorBoard.

**Note:** Ensure the remote server’s firewall allows connections on port `6006`
and that your local machine’s firewall allows outgoing connections.
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
#!/bin/bash

# This script identifies and removes corrupted files after downloading the AWS DrivAer dataset.
# It defines two functions: check_and_remove_corrupted_extension and check_all_runs.
# check_and_remove_corrupted_extension looks for files in a given directory that have extra
# characters after their extension; such files are considered corrupted and are removed.
# check_all_runs iterates over all run directories in LOCAL_DIR, checking for corrupted
# files with the extensions ".vtu", ".stl", and ".vtp".
# The cleanup starts by calling check_all_runs on the target directory "./drivaer_data_full".

# Set the local directory to check the files
LOCAL_DIR="./drivaer_data_full" # <--- This is the directory where the files are downloaded.

# Function to check if a file has extra characters after the extension and remove it
check_and_remove_corrupted_extension() {
    local dir=$1
    local base_filename=$2
    local extension=$3

    # Find any files with extra characters after the extension
    for file in "$dir/$base_filename"$extension*; do
        if [[ -f "$file" && "$file" != "$dir/$base_filename$extension" ]]; then
            echo "Corrupted file detected: $file (extra characters after extension), removing it."
            rm "$file"
        fi
    done
}

# Function to go over all the run directories and check files
check_all_runs() {
    for RUN_DIR in "$LOCAL_DIR"/run_*; do
        echo "Checking folder: $RUN_DIR"

        # Check for corrupted .vtu files
        base_vtu="volume_${RUN_DIR##*_}"
        check_and_remove_corrupted_extension "$RUN_DIR" "$base_vtu" ".vtu"

        # Check for corrupted .stl files
        base_stl="drivaer_${RUN_DIR##*_}"
        check_and_remove_corrupted_extension "$RUN_DIR" "$base_stl" ".stl"

        # Check for corrupted .vtp files (named boundary_N.vtp by the download script)
        base_vtp="boundary_${RUN_DIR##*_}"
        check_and_remove_corrupted_extension "$RUN_DIR" "$base_vtp" ".vtp"
    done
}

# Start checking
check_all_runs
Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
#!/bin/bash

# This script downloads the AWS DrivAer files from the Amazon S3 bucket to a local directory.
# Only the volume files (.vtu), STL files (.stl), and VTP files (.vtp) are downloaded.
# The download_run_files function checks for the existence of the three files in a run
# directory; a file is downloaded from the S3 bucket only if it does not already exist.
# Downloads run in parallel, both within a single run and across runs, and the number of
# parallel jobs is limited to avoid overloading the system.

# Set the local directory to download the files
LOCAL_DIR="./drivaer_data_full" # <--- This is the directory where the files will be downloaded.

# Set the S3 bucket and prefix
S3_BUCKET="caemldatasets"
S3_PREFIX="drivaer/dataset"

# Create the local directory if it doesn't exist
mkdir -p "$LOCAL_DIR"

# Function to download files for a specific run
download_run_files() {
    local i=$1
    RUN_DIR="run_$i"
    RUN_LOCAL_DIR="$LOCAL_DIR/$RUN_DIR"

    # Create the run directory if it doesn't exist
    mkdir -p "$RUN_LOCAL_DIR"

    # Check if the .vtu file exists before downloading
    if [ ! -f "$RUN_LOCAL_DIR/volume_$i.vtu" ]; then
        aws s3 cp --no-sign-request "s3://$S3_BUCKET/$S3_PREFIX/$RUN_DIR/volume_$i.vtu" "$RUN_LOCAL_DIR/" &
    else
        echo "File volume_$i.vtu already exists, skipping download."
    fi

    # Check if the .stl file exists before downloading
    if [ ! -f "$RUN_LOCAL_DIR/drivaer_$i.stl" ]; then
        aws s3 cp --no-sign-request "s3://$S3_BUCKET/$S3_PREFIX/$RUN_DIR/drivaer_$i.stl" "$RUN_LOCAL_DIR/" &
    else
        echo "File drivaer_$i.stl already exists, skipping download."
    fi

    # Check if the .vtp file exists before downloading
    if [ ! -f "$RUN_LOCAL_DIR/boundary_$i.vtp" ]; then
        aws s3 cp --no-sign-request "s3://$S3_BUCKET/$S3_PREFIX/$RUN_DIR/boundary_$i.vtp" "$RUN_LOCAL_DIR/" &
    else
        echo "File boundary_$i.vtp already exists, skipping download."
    fi

    wait # Ensure all three files for this run are downloaded before moving to the next run
}

# Loop through the run folders and download the files
for i in $(seq 1 500); do
    download_run_files "$i" &

    # Limit the number of parallel jobs to avoid overloading the system
    if (( $(jobs -r | wc -l) >= 8 )); then
        wait -n # Wait for a background job to finish before starting a new one
    fi
done

# Wait for all remaining background jobs to finish
wait
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
trimesh==4.5.0
Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module provides functionality to convert STL files with multiple solids
to another STL file with a single combined solid. It includes support for
processing multiple files in parallel with progress tracking.
"""

import os
import trimesh
import hydra

from multiprocessing import Pool
from tqdm import tqdm
from hydra.utils import to_absolute_path
from omegaconf import DictConfig


def process_stl_file(task):
    """Combine all solids in one STL file and save the result next to it."""
    stl_path = task

    # Load the STL file using trimesh
    mesh = trimesh.load_mesh(stl_path)

    # If the STL file contains multiple solids (as a Scene object)
    if isinstance(mesh, trimesh.Scene):
        # Extract all geometries (solids) from the scene
        meshes = list(mesh.geometry.values())

        # Combine all the solids into a single mesh
        combined_mesh = trimesh.util.concatenate(meshes)
    else:
        # If it's a single solid, no need to combine
        combined_mesh = mesh

    # Prepare the output file path (next to the original file)
    base_name, ext = os.path.splitext(stl_path)
    output_file_path = to_absolute_path(f"{base_name}_single_solid{ext}")

    # Save the new combined mesh as an STL file
    combined_mesh.export(output_file_path)

    return f"Processed: {stl_path} -> {output_file_path}"


def process_directory(data_path, num_workers=16):
    """Process all STL files in the given directory using multiprocessing with progress tracking."""
    tasks = []
    for root, _, files in os.walk(data_path):
        stl_files = [f for f in files if f.endswith(".stl")]
        for stl_file in stl_files:
            stl_path = os.path.join(root, stl_file)

            # Add the STL file to the tasks list (no output dir needed, saving next to the original)
            tasks.append(stl_path)

    # Use multiprocessing to process the tasks with progress tracking
    with Pool(num_workers) as pool:
        for _ in tqdm(
            pool.imap_unordered(process_stl_file, tasks),
            total=len(tasks),
            desc="Processing STL Files",
            unit="file",
        ):
            pass


@hydra.main(version_base="1.3", config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    # Process the directory with multiple STL files
    process_directory(
        to_absolute_path(cfg.data_path), num_workers=cfg.num_preprocess_workers
    )


if __name__ == "__main__":
    main()
