Skip to content

Commit da40b3f

Browse files
zongyi-limnabianMohammad Amin Nabian
authored
Fea ext lagrangian MeshGraphNet (#667)
* update lagrangian graph, add an example and a data loader * update readme and formatting * make reshape work with both 2d and 3d * fix readme * put activation in config, and raise error if recompute_activation with other act than silu * fix wandb * fix actviation in inference * add an unittest for lagrangian dataset * formatting * fix datapipe test * make the test compatible with later DGL version * fix unit test * uncomment @nfsdata_or_fail * formatting --------- Co-authored-by: Mohammad Amin Nabian <[email protected]> Co-authored-by: Mohammad Amin Nabian <[email protected]>
1 parent 3f7a8a4 commit da40b3f

File tree

10 files changed

+1520
-0
lines changed

10 files changed

+1520
-0
lines changed
1.03 MB
Loading
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# MeshGraphNet with Lagrangian mesh
2+
3+
This is an example of Meshgraphnet for particle-based simulation on the
4+
water dataset based on
5+
<https://github.com/google-deepmind/deepmind-research/tree/master/learning_to_simulate>
6+
in PyTorch.
7+
It demonstrates how to train a Graph Neural Network (GNN) for evaluation
8+
of the Lagrangian fluid.
9+
10+
## Problem overview
11+
12+
In this project, we provide an example of Lagrangian mesh simulation for fluids. The
13+
Lagrangian mesh is particle-based, where vertices represent fluid particles and
14+
edges represent their interactions. Compared to an Eulerian mesh, where the mesh
15+
grid is fixed, a Lagrangian mesh is more flexible since it does not require
16+
tessellating the domain or aligning with boundaries.
17+
18+
As a result, Lagrangian meshes are well-suited for representing complex geometries
19+
and free-boundary problems, such as water splashes and object collisions. However,
20+
a drawback of Lagrangian simulation is that it typically requires smaller time
21+
steps to maintain physically valid prediction.
22+
23+
## Dataset
24+
25+
We rely on [DeepMind's particle physics datasets](https://sites.google.com/view/learning-to-simulate)
26+
for this example. They datasets are particle-based simulation of fluid splashing
27+
and bouncing in a box or cube.
28+
29+
| Datasets | Num Particles | Num Time Steps | dt | Ground Truth Simulator |
30+
|--------------|---------------|----------------|----------|------------------------|
31+
| Water-3D | 14k | 800 | 5ms | SPH |
32+
| Water-2D | 2k | 1000 | 2.5ms | MPM |
33+
| WaterRamp | 2.5k | 600 | 2.5ms | MPM |
34+
35+
## Model overview and architecture
36+
37+
In this model, we utilize a Meshgraphnet to capture the fluid system’s dynamics.
38+
We represent the system as a graph, with vertices corresponding to fluid particles
39+
and edges representing their interactions. The model is autoregressive, using
40+
historical data to predict future states. The input features for the vertices
41+
include the current position, current velocity, node type (e.g., fluid, sand,
42+
boundary), and historical velocity. The model's output is the acceleration,
43+
defined as the difference between the current and next velocity. Both velocity
44+
and acceleration are derived from the position sequence and normalized to a
45+
standard Gaussian distribution for consistency.
46+
47+
For computational efficiency, we do not explicitly construct wall nodes for
48+
square or cubic domains. Instead, we assign a wall feature to each interior
49+
particle node, representing its distance from the domain boundaries. For a
50+
system dimensionality of \(d = 2\) or \(d = 3\), the features are structured
51+
as follows:
52+
53+
- **Node features**: position (\(d\)), historical velocity (\(t \times d\)),
54+
one-hot encoding of node type (6), wall feature (\(2 \times d\))
55+
- **Edge features**: displacement (\(d\)), distance (1)
56+
- **Node target**: acceleration (\(d\))
57+
58+
We construct edges based on a predefined radius, connecting pairs of particle
59+
nodes if their pairwise distance is within this radius. During training, we
60+
shuffle the time sequence and train in batches, with the graph constructed
61+
dynamically within the dataloader. For inference, predictions are rolled out
62+
iteratively, and a new graph is constructed based on previous predictions.
63+
Wall features are computed online during this process. To enhance robustness,
64+
a small amount of noise is added during training.
65+
66+
The model uses a hidden dimensionality of 128 for the encoder, processor, and
67+
decoder. The encoder and decoder each contain two hidden layers, while the
68+
processor consists of eight message-passing layers. We use a batch size of
69+
20 per GPU, and summation aggregation is applied for message passing in the
70+
processor. The learning rate is set to 0.0001 and decays exponentially with
71+
a rate of 0.9999991. These hyperparameters can be configured in the config file.
72+
73+
## Getting Started
74+
75+
This example requires the `tensorflow` library to load the data in the `.tfrecord`
76+
format. Install with
77+
78+
```bash
79+
pip install tensorflow
80+
```
81+
82+
To download the data from DeepMind's repo, run
83+
84+
```bash
85+
cd raw_dataset
86+
bash download_dataset.sh Water /data/
87+
```
88+
89+
Change the data path in `conf/config_2d.yaml` correspondingly
90+
91+
To train the model, run
92+
93+
```bash
94+
python train.py
95+
```
96+
97+
Progress and loss logs can be monitored using Weights & Biases. To activatethat,
98+
set `wandb_mode` to `online` in the `conf/config_2d.yaml` This requires to have an active
99+
Weights & Biases account. You also need to provide your API key in the config file.
100+
101+
```bash
102+
wandb_key: <your_api_key>
103+
```
104+
105+
The URL to the dashboard will be displayed in the terminal after the run is launched.
106+
Alternatively, the logging utility in `train.py` can be switched to MLFlow.
107+
108+
Once the model is trained, run
109+
110+
```bash
111+
python inference.py
112+
```
113+
114+
This will save the predictions for the test dataset in `.gif` format in the `animations`
115+
directory.
116+
117+
## References
118+
119+
- [Learning to simulate complex physicswith graph networks](arxiv.org/abs/2002.09405)
120+
- [Dataset](https://sites.google.com/view/learning-to-simulate)
121+
- [Learning Mesh-Based Simulation with Graph Networks](https://arxiv.org/abs/2010.03409)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-FileCopyrightText: All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
hydra:
18+
job:
19+
chdir: True
20+
run:
21+
dir: ./outputs/
22+
23+
# data configs
24+
data_dir: /data/Water
25+
dim: 2
26+
27+
# model config
28+
activation: "silu"
29+
30+
# training configs
31+
batch_size: 20
32+
epochs: 20
33+
num_training_samples: 1000 # 400
34+
num_training_time_steps: 990 # 600 - 5 (history)
35+
lr: 1e-4
36+
lr_min: 1e-6
37+
lr_decay_rate: 0.999 # every 10 epoch decays to 35%
38+
num_input_features: 22 # 2 (pos) + 2*5 (history of velocity) + 4 boundary features + 6 (node type)
39+
num_output_features: 2 # 2 acceleration
40+
num_edge_features: 3 # 2 displacement + 1 distance
41+
processor_size: 8
42+
radius: 0.015
43+
dt: 0.0025
44+
45+
# performance configs
46+
use_apex: True
47+
amp: False
48+
jit: False
49+
num_dataloader_workers: 10 # 4
50+
do_concat_trick: False
51+
num_processor_checkpoint_segments: 0
52+
recompute_activation: False
53+
54+
# wandb configs
55+
wandb_mode: offline
56+
watch_model: False
57+
wandb_key:
58+
wandb_project: "meshgraphnet"
59+
wandb_entity:
60+
wandb_name:
61+
ckpt_path: "./checkpoints_2d"
62+
63+
# test & visualization configs
64+
num_test_samples: 1
65+
num_test_time_steps: 200
66+
frame_skip: 1
67+
frame_interval: 1
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-FileCopyrightText: All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
hydra:
18+
job:
19+
chdir: True
20+
run:
21+
dir: ./outputs/
22+
23+
# data configs
24+
data_dir: /data/Water-3D
25+
dim: 3
26+
27+
# model config
28+
activation: "silu"
29+
30+
# training configs
31+
batch_size: 2
32+
epochs: 20
33+
num_training_samples: 1000 # 400
34+
num_training_time_steps: 300 # 600 - 5 (history)
35+
lr: 1e-4
36+
lr_min: 1e-6
37+
lr_decay_rate: 0.999 # every 10 epoch decays to 35%
38+
num_input_features: 30 # 3 (pos) + 3*5 (history of velocity) + 6 boundary features + 6 (node type)
39+
num_output_features: 3 # 2 acceleration
40+
num_edge_features: 4 # 2 displacement + 1 distance
41+
processor_size: 8
42+
radius: 0.035
43+
dt: 0.005
44+
45+
# performance configs
46+
use_apex: True
47+
amp: False
48+
jit: False
49+
num_dataloader_workers: 4 # 4
50+
do_concat_trick: False
51+
num_processor_checkpoint_segments: 0
52+
recompute_activation: False
53+
54+
# wandb configs
55+
wandb_mode: offline
56+
watch_model: False
57+
wandb_key:
58+
wandb_project: "meshgraphnet"
59+
wandb_entity:
60+
wandb_name:
61+
ckpt_path: "./checkpoints_3d"
62+
63+
# test & visualization configs
64+
num_test_samples: 1
65+
num_test_time_steps: 400
66+
frame_skip: 1
67+
frame_interval: 1

0 commit comments

Comments
 (0)