Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions examples/cfd/vortex_shedding_mgn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
import hydra
from hydra.utils import to_absolute_path

from dgl.dataloading import GraphDataLoader
import matplotlib.pyplot as plt
from matplotlib import animation
from matplotlib import tri as mtri
from matplotlib.patches import Rectangle
import numpy as np
from omegaconf import DictConfig
import torch
from torch_geometric.loader import DataLoader

from physicsnemo.models.meshgraphnet import MeshGraphNet
from physicsnemo.datapipes.gnn.vortex_shedding_dataset import VortexSheddingDataset
Expand All @@ -53,7 +53,7 @@ def __init__(self, cfg: DictConfig, logger: PythonLogger):
)

# instantiate dataloader
self.dataloader = GraphDataLoader(
self.dataloader = DataLoader(
self.dataset,
batch_size=1, # TODO add support for batch_size > 1
shuffle=False,
Expand Down Expand Up @@ -95,30 +95,30 @@ def predict(self):
for i, (graph, cells, mask) in enumerate(self.dataloader):
graph = graph.to(self.device)
# denormalize data
graph.ndata["x"][:, 0:2] = self.dataset.denormalize(
graph.ndata["x"][:, 0:2], stats["velocity_mean"], stats["velocity_std"]
graph.x[:, 0:2] = self.dataset.denormalize(
graph.x[:, 0:2], stats["velocity_mean"], stats["velocity_std"]
)
graph.ndata["y"][:, 0:2] = self.dataset.denormalize(
graph.ndata["y"][:, 0:2],
graph.y[:, 0:2] = self.dataset.denormalize(
graph.y[:, 0:2],
stats["velocity_diff_mean"],
stats["velocity_diff_std"],
)
graph.ndata["y"][:, [2]] = self.dataset.denormalize(
graph.ndata["y"][:, [2]],
graph.y[:, [2]] = self.dataset.denormalize(
graph.y[:, [2]],
stats["pressure_mean"],
stats["pressure_std"],
)

# inference step
invar = graph.ndata["x"].clone()
invar = graph.x.clone()

if i % (self.num_test_time_steps - 1) != 0:
invar[:, 0:2] = self.pred[i - 1][:, 0:2].clone()
i += 1
invar[:, 0:2] = self.dataset.normalize_node(
invar[:, 0:2], stats["velocity_mean"], stats["velocity_std"]
)
pred_i = self.model(invar, graph.edata["x"], graph).detach() # predict
pred_i = self.model(invar, graph.edge_attr, graph).detach() # predict

# denormalize prediction
pred_i[:, 0:2] = self.dataset.denormalize(
Expand Down Expand Up @@ -146,8 +146,8 @@ def predict(self):
self.exact.append(
torch.cat(
(
(graph.ndata["y"][:, 0:2] + graph.ndata["x"][:, 0:2]),
graph.ndata["y"][:, [2]],
(graph.y[:, 0:2] + graph.x[:, 0:2]),
graph.y[:, [2]],
),
dim=-1,
).cpu()
Expand Down Expand Up @@ -185,8 +185,8 @@ def animate(self, num):
y_star = self.pred_i[num].numpy()
y_exact = self.exact_i[num].numpy()
triang = mtri.Triangulation(
graph.ndata["mesh_pos"][:, 0].numpy(),
graph.ndata["mesh_pos"][:, 1].numpy(),
graph["mesh_pos"][:, 0].numpy(),
graph["mesh_pos"][:, 1].numpy(),
self.faces[num],
)
self.ax[0].cla()
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np
import pyvista as pv
from scipy.interpolate import griddata
from typing import List, Dict, Tuple
from typing import List, Dict


def midpoint_data_interp(
Expand Down
4 changes: 3 additions & 1 deletion examples/cfd/vortex_shedding_mgn/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@ tensorflow<=2.17.1
hydra-core>=1.2.0
wandb>=0.13.7
scipy>=1.15.0
vtk>=9.2.6
vtk>=9.2.6
torch_geometric>=2.6.1
torch_scatter>=2.1.2
29 changes: 19 additions & 10 deletions examples/cfd/vortex_shedding_mgn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
import torch
import wandb

from dgl.dataloading import GraphDataLoader

from omegaconf import DictConfig

from torch.cuda.amp import GradScaler, autocast
from torch_geometric.loader import DataLoader

from torch.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler

from physicsnemo.datapipes.gnn.vortex_shedding_dataset import VortexSheddingDataset
from physicsnemo.distributed.manager import DistributedManager
Expand Down Expand Up @@ -62,14 +63,20 @@ def __init__(self, cfg: DictConfig, rank_zero_logger: RankZeroLoggingWrapper):
num_steps=cfg.num_training_time_steps,
)

# instantiate dataloader
self.dataloader = GraphDataLoader(
sampler = DistributedSampler(
dataset,
batch_size=cfg.batch_size,
shuffle=True,
drop_last=True,
num_replicas=self.dist.world_size,
rank=self.dist.rank,
)

# instantiate dataloader
self.dataloader = DataLoader(
dataset,
batch_size=cfg.batch_size,
sampler=sampler,
pin_memory=True,
use_ddp=self.dist.world_size > 1,
num_workers=cfg.num_dataloader_workers,
)

Expand Down Expand Up @@ -150,9 +157,9 @@ def train(self, graph):

def forward(self, graph):
# forward pass
with autocast(enabled=self.amp):
pred = self.model(graph.ndata["x"], graph.edata["x"], graph)
loss = self.criterion(pred, graph.ndata["y"])
with autocast(device_type=self.dist.device.type, enabled=self.amp):
pred = self.model(graph.x, graph.edge_attr, graph)
loss = self.criterion(pred, graph.y)
return loss

def backward(self, loss):
Expand Down Expand Up @@ -188,6 +195,8 @@ def main(cfg: DictConfig) -> None:
start = time.time()
rank_zero_logger.info("Training started...")
for epoch in range(trainer.epoch_init, cfg.epochs):
trainer.dataloader.sampler.set_epoch(epoch)

for graph in trainer.dataloader:
loss = trainer.train(graph)
rank_zero_logger.info(
Expand Down
138 changes: 138 additions & 0 deletions examples/cfd/vortex_shedding_mgn_dgl/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# MeshGraphNet for transient vortex shedding

This example is a re-implementation of the DeepMind's vortex shedding example
<https://github.com/deepmind/deepmind-research/tree/master/meshgraphnets> in PyTorch.
It demonstrates how to train a Graph Neural Network (GNN) for evaluation of the
transient vortex shedding on parameterized geometries.

## Problem overview

Mesh-based simulations play a central role in modeling complex physical systems across
various scientific and engineering disciplines. They offer robust numerical integration
methods and allow for adaptable resolution to strike a balance between accuracy and
efficiency. Machine learning surrogate models have emerged as powerful tools to reduce
the cost of tasks like design optimization, design space exploration, and what-if
analysis, which involve repetitive high-dimensional scientific simulations.

However, some existing machine learning surrogate models, such as CNN-type models,
are constrained by structured grids,
making them less suitable for complex geometries or shells. The homogeneous fidelity of
CNNs is a significant limitation for many complex physical systems that require an
adaptive mesh representation to resolve multi-scale physics.

Graph Neural Networks (GNNs) present a viable approach for surrogate modeling in science
and engineering. They are data-driven and capable of handling complex physics. Being
mesh-based, GNNs can handle geometry irregularities and multi-scale physics,
making them well-suited for a wide range of applications.

## Dataset

We rely on DeepMind's vortex shedding dataset for this example. The dataset includes
1000 training, 100 validation, and 100 test samples that are simulated using COMSOL
with irregular triangle 2D meshes, each for 600 time steps with a time step size of
0.01s. These samples vary in the size and the position of the cylinder. Each sample
has a unique mesh due to geometry variations across samples, and the meshes have 1885
nodes on average. Note that the model can handle different meshes with different number
of nodes and edges as the input.

## Model overview and architecture

The model is free-running and auto-regressive. It takes the initial condition as the
input and predicts the solution at the first time step. It then takes the prediction at
the first time step to predict the solution at the next time step. The model continues
to use the prediction at time step $t$ to predict the solution at time step $t+1$, until
the rollout is complete. Note that the model is also able to predict beyond the
simulation time span and extrapolate in time. However, the accuracy of the prediction
might degrade over time and if possible, extrapolation should be avoided unless
the underlying data patterns remain stationary and consistent.

The model uses the input mesh to construct a bi-directional DGL graph for each sample.
The node features include (6 in total):

- Velocity components at time step $t$, i.e., $u_t$, $v_t$
- One-hot encoded node type (interior node, no-slip node, inlet node, outlet node)

The edge features for each sample are time-independent and include (3 in total):

- Relative $x$ and $y$ distance between the two end nodes of an edge
- L2 norm of the relative distance vector

The output of the model is the velocity components at time step t+1, i.e.,
$u_{t+1}$, $v_{t+1}$, as well as the pressure $p_{t+1}$.

![Comparison between the MeshGraphNet prediction and the
ground truth for the horizontal velocity for different test samples.
](../../../docs/img/vortex_shedding.gif)

A hidden dimensionality of 128 is used in the encoder,
processor, and decoder. The encoder and decoder consist of two hidden layers, and
the processor includes 15 message passing layers. Batch size per GPU is set to 1.
Summation aggregation is used in the
processor for message aggregation. A learning rate of 0.0001 is used, decaying
exponentially with a rate of 0.9999991. Training is performed on 8 NVIDIA A100
GPUs, leveraging data parallelism for 25 epochs.

## Prerequisites

This example requires the `tensorflow` library to load the data in the `.tfrecord`
format.

Note: If installing tensorflow inside the PhysicsNeMo docker container, it's recommended
to use `pip install "tensorflow<=2.17.1"`

Install the requirements using:

```bash
pip install -r requirements.txt
pip install dgl -f https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html --no-deps
```

## Getting Started

To download the data from DeepMind's repo, run

```bash
cd raw_dataset
sh download_dataset.sh cylinder_flow
```

To train the model, run

```bash
python train.py
```

Data parallelism is also supported with multi-GPU runs. To launch a multi-GPU training,
run

```bash
mpirun -np <num_GPUs> python train.py
```

If running in a docker container, you may need to include the `--allow-run-as-root` in
the multi-GPU run command.

Progress and loss logs can be monitored using Weights & Biases. To activate that,
set `wandb_mode` to `online` in the `constants.py`. This requires to have an active
Weights & Biases account. You also need to provide your API key. There are multiple ways
for providing the API key but you can simply export it as an environment variable

```bash
export WANDB_API_KEY=<your_api_key>
```

The URL to the dashboard will be displayed in the terminal after the run is launched.
Alternatively, the logging utility in `train.py` can be switched to MLFlow.

Once the model is trained, run

```bash
python inference.py
```

This will save the predictions for the test dataset in `.gif` format in the `animations`
directory.

## References

- [Learning Mesh-Based Simulation with Graph Networks](https://arxiv.org/abs/2010.03409)
57 changes: 57 additions & 0 deletions examples/cfd/vortex_shedding_mgn_dgl/conf/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

hydra:
job:
chdir: True
run:
dir: ./outputs/

# data configs
data_dir: ./raw_dataset/cylinder_flow/cylinder_flow

# training configs
batch_size: 1
epochs: 25
num_training_samples: 400
num_training_time_steps: 300
lr: 0.0001
lr_decay_rate: 0.9999991
num_input_features: 6
num_output_features: 3
num_edge_features: 3

# performance configs
use_apex: True
amp: False
jit: False
num_dataloader_workers: 4
do_concat_trick: False
num_processor_checkpoint_segments: 0
recompute_activation: False

# wandb configs
wandb_mode: disabled
watch_model: False

ckpt_path: "./checkpoints"

# test & visualization configs
num_test_samples: 10
num_test_time_steps: 300
viz_vars: ["u", "v", "p"]
frame_skip: 10
frame_interval: 1
Loading