Skip to content

PyTorch Geometric support in MeshGraphNet #995

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Jul 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Improved documentation for diffusion models and diffusion utils.
- Safe API to override `__init__`'s arguments saved in checkpoint file with
`Module.from_checkpoint("chkpt.mdlus", models_args)`.
- PyTorch Geometric MeshGraphNet backend.

### Changed

Expand All @@ -23,6 +24,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
*after* the model instantiation.
- Updated healpix data module to use correct `DistributedSampler` target for
test data loader
- Existing DGL-based vortex shedding example has been renamed to `vortex_shedding_mgn_dgl`.
Added new `vortex_shedding_mgn` example that uses PyTorch Geometric instead.

### Deprecated

Expand Down Expand Up @@ -74,7 +77,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
amortizing regression costs
- Explicit handling of Warp device for ball query and sdf
- Merged SongUNetPosLtEmb with SongUNetPosEmb, add support for batch>1
- Add lead time embedding support for `positional_embedding_selector`. Enable
- Add lead time embedding support for `positional_embedding_selector`. Enable
arbitrary positioning of probabilistic variables
- Enable lead time aware regression without CE loss
- Bumped minimum PyTorch version from 2.0.0 to 2.4.0, to minimize
Expand Down
3 changes: 1 addition & 2 deletions examples/cfd/vortex_shedding_mgn/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ simulation time span and extrapolate in time. However, the accuracy of the predi
might degrade over time and if possible, extrapolation should be avoided unless
the underlying data patterns remain stationary and consistent.

The model uses the input mesh to construct a bi-directional DGL graph for each sample.
The model uses the input mesh to construct a bi-directional graph for each sample.
The node features include (6 in total):

- Velocity components at time step $t$, i.e., $u_t$, $v_t$
Expand Down Expand Up @@ -84,7 +84,6 @@ Install the requirements using:

```bash
pip install -r requirements.txt
pip install dgl -f https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html --no-deps
```

## Getting Started
Expand Down
28 changes: 14 additions & 14 deletions examples/cfd/vortex_shedding_mgn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
import hydra
from hydra.utils import to_absolute_path

from dgl.dataloading import GraphDataLoader
import matplotlib.pyplot as plt
from matplotlib import animation
from matplotlib import tri as mtri
from matplotlib.patches import Rectangle
import numpy as np
from omegaconf import DictConfig
import torch
from torch_geometric.loader import DataLoader as PyGDataLoader

from physicsnemo.models.meshgraphnet import MeshGraphNet
from physicsnemo.datapipes.gnn.vortex_shedding_dataset import VortexSheddingDataset
Expand All @@ -53,7 +53,7 @@ def __init__(self, cfg: DictConfig, logger: PythonLogger):
)

# instantiate dataloader
self.dataloader = GraphDataLoader(
self.dataloader = PyGDataLoader(
self.dataset,
batch_size=1, # TODO add support for batch_size > 1
shuffle=False,
Expand Down Expand Up @@ -95,30 +95,30 @@ def predict(self):
for i, (graph, cells, mask) in enumerate(self.dataloader):
graph = graph.to(self.device)
# denormalize data
graph.ndata["x"][:, 0:2] = self.dataset.denormalize(
graph.ndata["x"][:, 0:2], stats["velocity_mean"], stats["velocity_std"]
graph.x[:, 0:2] = self.dataset.denormalize(
graph.x[:, 0:2], stats["velocity_mean"], stats["velocity_std"]
)
graph.ndata["y"][:, 0:2] = self.dataset.denormalize(
graph.ndata["y"][:, 0:2],
graph.y[:, 0:2] = self.dataset.denormalize(
graph.y[:, 0:2],
stats["velocity_diff_mean"],
stats["velocity_diff_std"],
)
graph.ndata["y"][:, [2]] = self.dataset.denormalize(
graph.ndata["y"][:, [2]],
graph.y[:, [2]] = self.dataset.denormalize(
graph.y[:, [2]],
stats["pressure_mean"],
stats["pressure_std"],
)

# inference step
invar = graph.ndata["x"].clone()
invar = graph.x.clone()

if i % (self.num_test_time_steps - 1) != 0:
invar[:, 0:2] = self.pred[i - 1][:, 0:2].clone()
i += 1
invar[:, 0:2] = self.dataset.normalize_node(
invar[:, 0:2], stats["velocity_mean"], stats["velocity_std"]
)
pred_i = self.model(invar, graph.edata["x"], graph).detach() # predict
pred_i = self.model(invar, graph.edge_attr, graph).detach() # predict

# denormalize prediction
pred_i[:, 0:2] = self.dataset.denormalize(
Expand Down Expand Up @@ -146,8 +146,8 @@ def predict(self):
self.exact.append(
torch.cat(
(
(graph.ndata["y"][:, 0:2] + graph.ndata["x"][:, 0:2]),
graph.ndata["y"][:, [2]],
(graph.y[:, 0:2] + graph.x[:, 0:2]),
graph.y[:, [2]],
),
dim=-1,
).cpu()
Expand Down Expand Up @@ -185,8 +185,8 @@ def animate(self, num):
y_star = self.pred_i[num].numpy()
y_exact = self.exact_i[num].numpy()
triang = mtri.Triangulation(
graph.ndata["mesh_pos"][:, 0].numpy(),
graph.ndata["mesh_pos"][:, 1].numpy(),
graph["mesh_pos"][:, 0].numpy(),
graph["mesh_pos"][:, 1].numpy(),
self.faces[num],
)
self.ax[0].cla()
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np
import pyvista as pv
from scipy.interpolate import griddata
from typing import List, Dict, Tuple
from typing import List, Dict


def midpoint_data_interp(
Expand Down
4 changes: 3 additions & 1 deletion examples/cfd/vortex_shedding_mgn/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@ tensorflow<=2.17.1
hydra-core>=1.2.0
wandb>=0.13.7
scipy>=1.15.0
vtk>=9.2.6
vtk>=9.2.6
torch_geometric>=2.6.1
torch_scatter>=2.1.2
38 changes: 26 additions & 12 deletions examples/cfd/vortex_shedding_mgn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
import torch
import wandb

from dgl.dataloading import GraphDataLoader

from omegaconf import DictConfig

from torch.cuda.amp import GradScaler, autocast
from torch_geometric.loader import DataLoader as PyGDataLoader

from torch.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler

from physicsnemo.datapipes.gnn.vortex_shedding_dataset import VortexSheddingDataset
from physicsnemo.distributed.manager import DistributedManager
Expand Down Expand Up @@ -62,14 +63,20 @@ def __init__(self, cfg: DictConfig, rank_zero_logger: RankZeroLoggingWrapper):
num_steps=cfg.num_training_time_steps,
)

# instantiate dataloader
self.dataloader = GraphDataLoader(
sampler = DistributedSampler(
dataset,
batch_size=cfg.batch_size,
shuffle=True,
drop_last=True,
num_replicas=self.dist.world_size,
rank=self.dist.rank,
)

# instantiate dataloader
self.dataloader = PyGDataLoader(
dataset,
batch_size=cfg.batch_size,
sampler=sampler,
pin_memory=True,
use_ddp=self.dist.world_size > 1,
num_workers=cfg.num_dataloader_workers,
)

Expand Down Expand Up @@ -150,9 +157,9 @@ def train(self, graph):

def forward(self, graph):
# forward pass
with autocast(enabled=self.amp):
pred = self.model(graph.ndata["x"], graph.edata["x"], graph)
loss = self.criterion(pred, graph.ndata["y"])
with autocast(device_type=self.dist.device.type, enabled=self.amp):
pred = self.model(graph.x, graph.edge_attr, graph)
loss = self.criterion(pred, graph.y)
return loss

def backward(self, loss):
Expand Down Expand Up @@ -188,12 +195,19 @@ def main(cfg: DictConfig) -> None:
start = time.time()
rank_zero_logger.info("Training started...")
for epoch in range(trainer.epoch_init, cfg.epochs):
trainer.dataloader.sampler.set_epoch(epoch)

epoch_loss = 0.0

for graph in trainer.dataloader:
loss = trainer.train(graph)
epoch_loss += loss.detach().cpu()

epoch_loss /= len(trainer.dataloader)
rank_zero_logger.info(
f"epoch: {epoch}, loss: {loss:10.3e}, time per epoch: {(time.time()-start):10.3e}"
f"epoch: {epoch}, loss: {epoch_loss:10.3e}, time per epoch: {(time.time()-start):10.3e}"
)
wandb.log({"loss": loss.detach().cpu()})
wandb.log({"loss": epoch_loss})

# save checkpoint
if dist.world_size > 1:
Expand Down
145 changes: 145 additions & 0 deletions examples/cfd/vortex_shedding_mgn_dgl/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# MeshGraphNet for transient vortex shedding

> [!IMPORTANT]
> Deprecation Notice
>
> Over the next 2-3 releases, DGL-based functionality will be phased out and replaced
> by equivalent or improved implementations using PyTorch Geometric (PyG).
> PyG will become the default and only supported graph backend.

This example is a re-implementation of the DeepMind's vortex shedding example
<https://github.com/deepmind/deepmind-research/tree/master/meshgraphnets> in PyTorch.
It demonstrates how to train a Graph Neural Network (GNN) for evaluation of the
transient vortex shedding on parameterized geometries.

## Problem overview

Mesh-based simulations play a central role in modeling complex physical systems across
various scientific and engineering disciplines. They offer robust numerical integration
methods and allow for adaptable resolution to strike a balance between accuracy and
efficiency. Machine learning surrogate models have emerged as powerful tools to reduce
the cost of tasks like design optimization, design space exploration, and what-if
analysis, which involve repetitive high-dimensional scientific simulations.

However, some existing machine learning surrogate models, such as CNN-type models,
are constrained by structured grids,
making them less suitable for complex geometries or shells. The homogeneous fidelity of
CNNs is a significant limitation for many complex physical systems that require an
adaptive mesh representation to resolve multi-scale physics.

Graph Neural Networks (GNNs) present a viable approach for surrogate modeling in science
and engineering. They are data-driven and capable of handling complex physics. Being
mesh-based, GNNs can handle geometry irregularities and multi-scale physics,
making them well-suited for a wide range of applications.

## Dataset

We rely on DeepMind's vortex shedding dataset for this example. The dataset includes
1000 training, 100 validation, and 100 test samples that are simulated using COMSOL
with irregular triangle 2D meshes, each for 600 time steps with a time step size of
0.01s. These samples vary in the size and the position of the cylinder. Each sample
has a unique mesh due to geometry variations across samples, and the meshes have 1885
nodes on average. Note that the model can handle different meshes with different number
of nodes and edges as the input.

## Model overview and architecture

The model is free-running and auto-regressive. It takes the initial condition as the
input and predicts the solution at the first time step. It then takes the prediction at
the first time step to predict the solution at the next time step. The model continues
to use the prediction at time step $t$ to predict the solution at time step $t+1$, until
the rollout is complete. Note that the model is also able to predict beyond the
simulation time span and extrapolate in time. However, the accuracy of the prediction
might degrade over time and if possible, extrapolation should be avoided unless
the underlying data patterns remain stationary and consistent.

The model uses the input mesh to construct a bi-directional DGL graph for each sample.
The node features include (6 in total):

- Velocity components at time step $t$, i.e., $u_t$, $v_t$
- One-hot encoded node type (interior node, no-slip node, inlet node, outlet node)

The edge features for each sample are time-independent and include (3 in total):

- Relative $x$ and $y$ distance between the two end nodes of an edge
- L2 norm of the relative distance vector

The output of the model is the velocity components at time step t+1, i.e.,
$u_{t+1}$, $v_{t+1}$, as well as the pressure $p_{t+1}$.

![Comparison between the MeshGraphNet prediction and the
ground truth for the horizontal velocity for different test samples.
](../../../docs/img/vortex_shedding.gif)

A hidden dimensionality of 128 is used in the encoder,
processor, and decoder. The encoder and decoder consist of two hidden layers, and
the processor includes 15 message passing layers. Batch size per GPU is set to 1.
Summation aggregation is used in the
processor for message aggregation. A learning rate of 0.0001 is used, decaying
exponentially with a rate of 0.9999991. Training is performed on 8 NVIDIA A100
GPUs, leveraging data parallelism for 25 epochs.

## Prerequisites

This example requires the `tensorflow` library to load the data in the `.tfrecord`
format.

Note: If installing tensorflow inside the PhysicsNeMo docker container, it's recommended
to use `pip install "tensorflow<=2.17.1"`

Install the requirements using:

```bash
pip install -r requirements.txt
pip install dgl -f https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html --no-deps
```

## Getting Started

To download the data from DeepMind's repo, run

```bash
cd raw_dataset
sh download_dataset.sh cylinder_flow
```

To train the model, run

```bash
python train.py
```

Data parallelism is also supported with multi-GPU runs. To launch a multi-GPU training,
run

```bash
mpirun -np <num_GPUs> python train.py
```

If running in a docker container, you may need to include the `--allow-run-as-root` in
the multi-GPU run command.

Progress and loss logs can be monitored using Weights & Biases. To activate that,
set `wandb_mode` to `online` in the `constants.py`. This requires to have an active
Weights & Biases account. You also need to provide your API key. There are multiple ways
for providing the API key but you can simply export it as an environment variable

```bash
export WANDB_API_KEY=<your_api_key>
```

The URL to the dashboard will be displayed in the terminal after the run is launched.
Alternatively, the logging utility in `train.py` can be switched to MLFlow.

Once the model is trained, run

```bash
python inference.py
```

This will save the predictions for the test dataset in `.gif` format in the `animations`
directory.

## References

- [Learning Mesh-Based Simulation with Graph Networks](https://arxiv.org/abs/2010.03409)
Loading