1414# See the License for the specific language governing permissions and
1515# limitations under the License.
1616
17+ from itertools import chain
18+ from typing import Dict
19+
1720import hydra
18- from omegaconf import DictConfig
19- import torch
20- import numpy as np
2121import matplotlib .pyplot as plt
22- from hydra .utils import to_absolute_path
22+ import numpy as np
23+ import torch
2324import torch .nn .functional as F
24- from torch .utils .data import DataLoader
25- from itertools import chain
26-
27- from modulus .models .mlp import FullyConnected
28- from modulus .models .fno import FNO
25+ from hydra .utils import to_absolute_path
2926from modulus .launch .logging import LaunchLogger
3027from modulus .launch .utils .checkpoint import save_checkpoint
31-
32- from utils import HDF5MapStyleDataset
28+ from modulus . models . fno import FNO
29+ from modulus . models . mlp import FullyConnected
3330from modulus .sym .eq .pdes .diffusion import Diffusion
34-
35- from modulus .sym .graph import Graph
31+ from modulus .sym .eq .phy_informer import PhysicsInformer
3632from modulus .sym .key import Key
37- from modulus .sym .node import Node
38- from typing import Optional , Dict
3933from modulus .sym .models .arch import Arch
34+ from omegaconf import DictConfig
35+ from torch .utils .data import DataLoader
36+
37+ from utils import HDF5MapStyleDataset
4038
4139
4240def validation_step (graph , dataloader , epoch ):
@@ -61,14 +59,14 @@ def validation_step(graph, dataloader, epoch):
6159 # plotting
6260 fig , ax = plt .subplots (1 , 3 , figsize = (25 , 5 ))
6361
64- d_min = np .min (outvar [0 , 0 , ... ])
65- d_max = np .max (outvar [0 , 0 , ... ])
62+ d_min = np .min (outvar [0 , 0 ])
63+ d_max = np .max (outvar [0 , 0 ])
6664
67- im = ax [0 ].imshow (outvar [0 , 0 , ... ], vmin = d_min , vmax = d_max )
65+ im = ax [0 ].imshow (outvar [0 , 0 ], vmin = d_min , vmax = d_max )
6866 plt .colorbar (im , ax = ax [0 ])
69- im = ax [1 ].imshow (predvar [0 , 0 , ... ], vmin = d_min , vmax = d_max )
67+ im = ax [1 ].imshow (predvar [0 , 0 ], vmin = d_min , vmax = d_max )
7068 plt .colorbar (im , ax = ax [1 ])
71- im = ax [2 ].imshow (np .abs (predvar [0 , 0 , ... ] - outvar [0 , 0 , ... ]))
69+ im = ax [2 ].imshow (np .abs (predvar [0 , 0 ] - outvar [0 , 0 ]))
7270 plt .colorbar (im , ax = ax [2 ])
7371
7472 ax [0 ].set_title ("True" )
@@ -113,7 +111,7 @@ def __init__(
113111 trunk_net = None ,
114112 branch_net = None ,
115113 ):
116- super (MdlsSymWrapper , self ).__init__ (
114+ super ().__init__ (
117115 input_keys = input_keys ,
118116 output_keys = output_keys ,
119117 )
@@ -162,9 +160,8 @@ def main(cfg: DictConfig):
162160 LaunchLogger .initialize ()
163161
164162 # Use Diffusion equation for the Darcy PDE
165- darcy = Diffusion (T = "u" , time = False , dim = 2 , D = "k" , Q = 1.0 * 4.49996e00 * 3.88433e-03 )
166-
167- darcy_node = darcy .make_nodes ()
163+ forcing_fn = 1.0 * 4.49996e00 * 3.88433e-03 # after scaling
164+ darcy = Diffusion (T = "u" , time = False , dim = 2 , D = "k" , Q = forcing_fn )
168165
169166 dataset = HDF5MapStyleDataset (
170167 to_absolute_path ("./datasets/Darcy_241/train.hdf5" ), device = device
@@ -204,30 +201,14 @@ def main(cfg: DictConfig):
204201 output_keys = [Key ("k" ), Key ("u" )],
205202 trunk_net = model_trunk ,
206203 branch_net = model_branch ,
207- )
208-
209- nodes = darcy_node + [model .make_node (name = "network" , jit = False )]
210-
211- # note: this example uses the Graph class from Modulus Sym to construct the
212- # computational graph. This allows you to leverage Modulus Sym's optimized
213- # derivative backend to compute the derivatives, along with other benefits like
214- # symbolic definition of PDEs and leveraging the PDEs from Modulus Sym's PDE
215- # module.
216- # For more details, refer: https://docs.nvidia.com/deeplearning/modulus/modulus-sym/api/modulus.sym.html#module-modulus.sym.graph
217- graph = Graph (
218- nodes ,
219- [Key ("k_prime" ), Key ("x" ), Key ("y" )],
220- [Key ("k" ), Key ("u" ), Key ("diffusion_u" )],
221- func_arch = False ,
222204 ).to (device )
223205
224- # For pure inference (no gradients)
225- graph_infer = Graph (
226- [model .make_node (name = "network" , jit = False )],
227- [Key ("k_prime" ), Key ("x" ), Key ("y" )],
228- [Key ("k" ), Key ("u" )], # No PDE Key
229- func_arch = False ,
230- ).to (device )
206+ phy_informer = PhysicsInformer (
207+ required_outputs = ["diffusion_u" ],
208+ equations = darcy ,
209+ grad_method = "autodiff" ,
210+ device = device ,
211+ )
231212
232213 optimizer = torch .optim .Adam (
233214 chain (model_branch .parameters (), model_trunk .parameters ()),
@@ -251,16 +232,24 @@ def main(cfg: DictConfig):
251232 optimizer .zero_grad ()
252233 outvar = data [1 ]
253234
235+ coords = torch .stack ([data [2 ], data [3 ]], dim = 1 ).requires_grad_ (True )
254236 # compute forward pass
255- out = graph .forward (
237+ out = model .forward (
256238 {
257239 "k_prime" : data [0 ][:, 0 ].unsqueeze (dim = 1 ),
258- "x" : data [ 2 ]. requires_grad_ ( True ) ,
259- "y" : data [ 3 ]. requires_grad_ ( True ) ,
240+ "x" : coords [:, 0 : 1 ] ,
241+ "y" : coords [:, 1 : 2 ] ,
260242 }
261243 )
262244
263- pde_out_arr = out ["diffusion_u" ]
245+ residuals = phy_informer .forward (
246+ {
247+ "coordinates" : coords ,
248+ "u" : out ["u" ],
249+ "k" : out ["k" ],
250+ }
251+ )
252+ pde_out_arr = residuals ["diffusion_u" ]
264253
265254 # Boundary condition
266255 pde_out_arr = F .pad (
@@ -276,7 +265,7 @@ def main(cfg: DictConfig):
276265 )
277266
278267 # Compute total loss
279- loss = loss_data + cfg .phy_wt * loss_pde
268+ loss = loss_data + cfg .physics_weight * loss_pde
280269
281270 # Backward pass and optimizer and learning rate update
282271 loss .backward ()
@@ -289,7 +278,7 @@ def main(cfg: DictConfig):
289278 log .log_epoch ({"Learning Rate" : optimizer .param_groups [0 ]["lr" ]})
290279
291280 with LaunchLogger ("valid" , epoch = epoch ) as log :
292- error = validation_step (graph_infer , validation_dataloader , epoch )
281+ error = validation_step (model , validation_dataloader , epoch )
293282 log .log_epoch ({"Validation error" : error })
294283
295284 save_checkpoint (
0 commit comments