Skip to content

Commit e5844cc

Browse files
authored
[FEA] Physics informed examples using the PhysicsInformer utility (#664)
* update to use newer apis, and some misc cleanup * update the stokes example * add pinns example * add pi informed gnn example using least squares ---------
1 parent eb01d2a commit e5844cc

File tree

17 files changed

+1033
-311
lines changed

17 files changed

+1033
-311
lines changed

examples/cfd/darcy_physics_informed/README.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,15 @@ the loss function and the use of one over the other can change from case-to-case
5050
With this example, we intend to demonstrate both such cases so that the users can compare
5151
and contrast the two approaches.
5252

53-
In this example we will also use the `PDE` class from Modulus-Sym to symbolically define
54-
the PDEs. This is very convinient and most natural way to define these PDEs and allows
55-
us to print the equations to check for correctness. This also abstracts out the
53+
In this example we will use the `PDE` class from Modulus-Sym to symbolically define
54+
the PDEs and use the `PhysicsInformer` utility to introduce the PDE
55+
constraints. Defining the PDEs sympolically is very convinient and most natural way to
56+
define these PDEs and allows us to print the equations to check for correctness.
57+
This also abstracts out the
5658
complexity of converting the equation into a pytorch representation. Modulus Sym also
5759
provides several complex, well tested PDEs like 3D Navier-Stokes, Linear elasticity,
5860
Electromagnetics, etc. pre-defined which can be used directly in physics-informing
59-
applications.
61+
applications.
6062

6163
## Getting Started
6264

examples/cfd/darcy_physics_informed/conf/config_deeponet.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ start_lr: 0.001
2424
gamma: 0.99948708
2525
max_epochs: 50
2626

27-
phy_wt: 0.1
27+
physics_weight: 0.1
2828

2929
model:
3030
fno:

examples/cfd/darcy_physics_informed/conf/config_pino.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ start_lr: 0.001
2424
gamma: 0.99948708
2525
max_epochs: 50
2626

27-
phy_wt: 0.1
27+
physics_weight: 0.1
2828

2929
model:
3030
fno:

examples/cfd/darcy_physics_informed/darcy_physics_informed_deeponet.py

Lines changed: 41 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -14,29 +14,27 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
from itertools import chain
18+
from typing import Dict
19+
1720
import hydra
18-
from omegaconf import DictConfig
19-
import torch
20-
import numpy as np
2121
import matplotlib.pyplot as plt
22-
from hydra.utils import to_absolute_path
22+
import numpy as np
23+
import torch
2324
import torch.nn.functional as F
24-
from torch.utils.data import DataLoader
25-
from itertools import chain
26-
27-
from modulus.models.mlp import FullyConnected
28-
from modulus.models.fno import FNO
25+
from hydra.utils import to_absolute_path
2926
from modulus.launch.logging import LaunchLogger
3027
from modulus.launch.utils.checkpoint import save_checkpoint
31-
32-
from utils import HDF5MapStyleDataset
28+
from modulus.models.fno import FNO
29+
from modulus.models.mlp import FullyConnected
3330
from modulus.sym.eq.pdes.diffusion import Diffusion
34-
35-
from modulus.sym.graph import Graph
31+
from modulus.sym.eq.phy_informer import PhysicsInformer
3632
from modulus.sym.key import Key
37-
from modulus.sym.node import Node
38-
from typing import Optional, Dict
3933
from modulus.sym.models.arch import Arch
34+
from omegaconf import DictConfig
35+
from torch.utils.data import DataLoader
36+
37+
from utils import HDF5MapStyleDataset
4038

4139

4240
def validation_step(graph, dataloader, epoch):
@@ -61,14 +59,14 @@ def validation_step(graph, dataloader, epoch):
6159
# plotting
6260
fig, ax = plt.subplots(1, 3, figsize=(25, 5))
6361

64-
d_min = np.min(outvar[0, 0, ...])
65-
d_max = np.max(outvar[0, 0, ...])
62+
d_min = np.min(outvar[0, 0])
63+
d_max = np.max(outvar[0, 0])
6664

67-
im = ax[0].imshow(outvar[0, 0, ...], vmin=d_min, vmax=d_max)
65+
im = ax[0].imshow(outvar[0, 0], vmin=d_min, vmax=d_max)
6866
plt.colorbar(im, ax=ax[0])
69-
im = ax[1].imshow(predvar[0, 0, ...], vmin=d_min, vmax=d_max)
67+
im = ax[1].imshow(predvar[0, 0], vmin=d_min, vmax=d_max)
7068
plt.colorbar(im, ax=ax[1])
71-
im = ax[2].imshow(np.abs(predvar[0, 0, ...] - outvar[0, 0, ...]))
69+
im = ax[2].imshow(np.abs(predvar[0, 0] - outvar[0, 0]))
7270
plt.colorbar(im, ax=ax[2])
7371

7472
ax[0].set_title("True")
@@ -113,7 +111,7 @@ def __init__(
113111
trunk_net=None,
114112
branch_net=None,
115113
):
116-
super(MdlsSymWrapper, self).__init__(
114+
super().__init__(
117115
input_keys=input_keys,
118116
output_keys=output_keys,
119117
)
@@ -162,9 +160,8 @@ def main(cfg: DictConfig):
162160
LaunchLogger.initialize()
163161

164162
# Use Diffusion equation for the Darcy PDE
165-
darcy = Diffusion(T="u", time=False, dim=2, D="k", Q=1.0 * 4.49996e00 * 3.88433e-03)
166-
167-
darcy_node = darcy.make_nodes()
163+
forcing_fn = 1.0 * 4.49996e00 * 3.88433e-03 # after scaling
164+
darcy = Diffusion(T="u", time=False, dim=2, D="k", Q=forcing_fn)
168165

169166
dataset = HDF5MapStyleDataset(
170167
to_absolute_path("./datasets/Darcy_241/train.hdf5"), device=device
@@ -204,30 +201,14 @@ def main(cfg: DictConfig):
204201
output_keys=[Key("k"), Key("u")],
205202
trunk_net=model_trunk,
206203
branch_net=model_branch,
207-
)
208-
209-
nodes = darcy_node + [model.make_node(name="network", jit=False)]
210-
211-
# note: this example uses the Graph class from Modulus Sym to construct the
212-
# computational graph. This allows you to leverage Modulus Sym's optimized
213-
# derivative backend to compute the derivatives, along with other benefits like
214-
# symbolic definition of PDEs and leveraging the PDEs from Modulus Sym's PDE
215-
# module.
216-
# For more details, refer: https://docs.nvidia.com/deeplearning/modulus/modulus-sym/api/modulus.sym.html#module-modulus.sym.graph
217-
graph = Graph(
218-
nodes,
219-
[Key("k_prime"), Key("x"), Key("y")],
220-
[Key("k"), Key("u"), Key("diffusion_u")],
221-
func_arch=False,
222204
).to(device)
223205

224-
# For pure inference (no gradients)
225-
graph_infer = Graph(
226-
[model.make_node(name="network", jit=False)],
227-
[Key("k_prime"), Key("x"), Key("y")],
228-
[Key("k"), Key("u")], # No PDE Key
229-
func_arch=False,
230-
).to(device)
206+
phy_informer = PhysicsInformer(
207+
required_outputs=["diffusion_u"],
208+
equations=darcy,
209+
grad_method="autodiff",
210+
device=device,
211+
)
231212

232213
optimizer = torch.optim.Adam(
233214
chain(model_branch.parameters(), model_trunk.parameters()),
@@ -251,16 +232,24 @@ def main(cfg: DictConfig):
251232
optimizer.zero_grad()
252233
outvar = data[1]
253234

235+
coords = torch.stack([data[2], data[3]], dim=1).requires_grad_(True)
254236
# compute forward pass
255-
out = graph.forward(
237+
out = model.forward(
256238
{
257239
"k_prime": data[0][:, 0].unsqueeze(dim=1),
258-
"x": data[2].requires_grad_(True),
259-
"y": data[3].requires_grad_(True),
240+
"x": coords[:, 0:1],
241+
"y": coords[:, 1:2],
260242
}
261243
)
262244

263-
pde_out_arr = out["diffusion_u"]
245+
residuals = phy_informer.forward(
246+
{
247+
"coordinates": coords,
248+
"u": out["u"],
249+
"k": out["k"],
250+
}
251+
)
252+
pde_out_arr = residuals["diffusion_u"]
264253

265254
# Boundary condition
266255
pde_out_arr = F.pad(
@@ -276,7 +265,7 @@ def main(cfg: DictConfig):
276265
)
277266

278267
# Compute total loss
279-
loss = loss_data + cfg.phy_wt * loss_pde
268+
loss = loss_data + cfg.physics_weight * loss_pde
280269

281270
# Backward pass and optimizer and learning rate update
282271
loss.backward()
@@ -289,7 +278,7 @@ def main(cfg: DictConfig):
289278
log.log_epoch({"Learning Rate": optimizer.param_groups[0]["lr"]})
290279

291280
with LaunchLogger("valid", epoch=epoch) as log:
292-
error = validation_step(graph_infer, validation_dataloader, epoch)
281+
error = validation_step(model, validation_dataloader, epoch)
293282
log.log_epoch({"Validation error": error})
294283

295284
save_checkpoint(

examples/cfd/darcy_physics_informed/darcy_physics_informed_fno.py

Lines changed: 28 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,20 @@
1515
# limitations under the License.
1616

1717
import hydra
18-
from omegaconf import DictConfig
19-
import torch
20-
import numpy as np
21-
2218
import matplotlib.pyplot as plt
23-
from hydra.utils import to_absolute_path
24-
import torch.nn.functional as F
25-
from torch.utils.data import DataLoader
19+
import numpy as np
20+
import torch
2621
import torch.nn.functional as F
27-
28-
from modulus.models.fno import FNO
22+
from hydra.utils import to_absolute_path
2923
from modulus.launch.logging import LaunchLogger
3024
from modulus.launch.utils.checkpoint import save_checkpoint
25+
from modulus.models.fno import FNO
3126
from modulus.sym.eq.pdes.diffusion import Diffusion
27+
from modulus.sym.eq.phy_informer import PhysicsInformer
28+
from omegaconf import DictConfig
29+
from torch.utils.data import DataLoader
3230

3331
from utils import HDF5MapStyleDataset
34-
from ops import dx, ddx
3532

3633

3734
def validation_step(model, dataloader, epoch):
@@ -53,14 +50,14 @@ def validation_step(model, dataloader, epoch):
5350
# plotting
5451
fig, ax = plt.subplots(1, 3, figsize=(25, 5))
5552

56-
d_min = np.min(outvar[0, 0, ...])
57-
d_max = np.max(outvar[0, 0, ...])
53+
d_min = np.min(outvar[0, 0])
54+
d_max = np.max(outvar[0, 0])
5855

59-
im = ax[0].imshow(outvar[0, 0, ...], vmin=d_min, vmax=d_max)
56+
im = ax[0].imshow(outvar[0, 0], vmin=d_min, vmax=d_max)
6057
plt.colorbar(im, ax=ax[0])
61-
im = ax[1].imshow(predvar[0, 0, ...], vmin=d_min, vmax=d_max)
58+
im = ax[1].imshow(predvar[0, 0], vmin=d_min, vmax=d_max)
6259
plt.colorbar(im, ax=ax[1])
63-
im = ax[2].imshow(np.abs(predvar[0, 0, ...] - outvar[0, 0, ...]))
60+
im = ax[2].imshow(np.abs(predvar[0, 0] - outvar[0, 0]))
6461
plt.colorbar(im, ax=ax[2])
6562

6663
ax[0].set_title("True")
@@ -84,8 +81,8 @@ def main(cfg: DictConfig):
8481
LaunchLogger.initialize()
8582

8683
# Use Diffusion equation for the Darcy PDE
87-
darcy = Diffusion(T="u", time=False, dim=2, D="k", Q=1.0 * 4.49996e00 * 3.88433e-03)
88-
darcy_node = darcy.make_nodes()
84+
forcing_fn = 1.0 * 4.49996e00 * 3.88433e-03 # after scaling
85+
darcy = Diffusion(T="u", time=False, dim=2, D="k", Q=forcing_fn)
8986

9087
dataset = HDF5MapStyleDataset(
9188
to_absolute_path("./datasets/Darcy_241/train.hdf5"), device=device
@@ -110,6 +107,14 @@ def main(cfg: DictConfig):
110107
padding=cfg.model.fno.padding,
111108
).to(device)
112109

110+
phy_informer = PhysicsInformer(
111+
required_outputs=["diffusion_u"],
112+
equations=darcy,
113+
grad_method="finite_difference",
114+
device=device,
115+
fd_dx=1 / 240, # Unit square with resoultion as 240
116+
)
117+
113118
optimizer = torch.optim.Adam(
114119
model.parameters(),
115120
betas=(0.9, 0.999),
@@ -135,37 +140,15 @@ def main(cfg: DictConfig):
135140
# Compute forward pass
136141
out = model(invar[:, 0].unsqueeze(dim=1))
137142

138-
dxf = 1.0 / out.shape[-2]
139-
dyf = 1.0 / out.shape[-1]
140-
141-
# Compute gradients using finite difference
142-
sol_x = dx(out, dx=dxf, channel=0, dim=1, order=1, padding="zeros")
143-
sol_y = dx(out, dx=dyf, channel=0, dim=0, order=1, padding="zeros")
144-
sol_x_x = ddx(out, dx=dxf, channel=0, dim=1, order=1, padding="zeros")
145-
sol_y_y = ddx(out, dx=dyf, channel=0, dim=0, order=1, padding="zeros")
146-
147-
k_x = dx(invar, dx=dxf, channel=0, dim=1, order=1, padding="zeros")
148-
k_y = dx(invar, dx=dxf, channel=0, dim=0, order=1, padding="zeros")
149-
150-
k, _, _ = (
151-
invar[:, 0],
152-
invar[:, 1],
153-
invar[:, 2],
154-
)
155-
156-
pde_out = darcy_node[0].evaluate(
143+
# print(out.shape, invar[:,0:1].shape)
144+
residuals = phy_informer.forward(
157145
{
158-
"u__x": sol_x,
159-
"u__y": sol_y,
160-
"u__x__x": sol_x_x,
161-
"u__y__y": sol_y_y,
162-
"k": k,
163-
"k__x": k_x,
164-
"k__y": k_y,
146+
"u": out,
147+
"k": invar[:, 0:1],
165148
}
166149
)
150+
pde_out_arr = residuals["diffusion_u"]
167151

168-
pde_out_arr = pde_out["diffusion_u"]
169152
pde_out_arr = F.pad(
170153
pde_out_arr[:, :, 2:-2, 2:-2], [2, 2, 2, 2], "constant", 0
171154
)
@@ -175,7 +158,7 @@ def main(cfg: DictConfig):
175158
loss_data = F.mse_loss(outvar, out)
176159

177160
# Compute total loss
178-
loss = loss_data + 1 / 240 * cfg.phy_wt * loss_pde
161+
loss = loss_data + 1 / 240 * cfg.physics_weight * loss_pde
179162

180163
# Backward pass and optimizer and learning rate update
181164
loss.backward()

examples/cfd/darcy_physics_informed/download_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import h5py
1818
import numpy as np
19+
1920
from utils import download_FNO_dataset
2021

2122
download_FNO_dataset("Darcy_241", outdir="datasets/")

0 commit comments

Comments
 (0)