diff --git a/examples/cfd/external_aerodynamics/domino/README.md b/examples/cfd/external_aerodynamics/domino/README.md index dc0a23123d..3d2c884623 100644 --- a/examples/cfd/external_aerodynamics/domino/README.md +++ b/examples/cfd/external_aerodynamics/domino/README.md @@ -262,7 +262,7 @@ velocity, air density, etc.) that can vary across different simulations. 4. Ensure your simulation data includes global parameter values. The DoMINO datapipe expects these parameters in the pre-processed `.npy`/`.npz` files: - - Examine `openfoam_datapipe.py` and `process_data.py` for examples of how global + - Examine `vtk_cfd_dataset.py` and `process_data.py` for examples of how global parameter values are incorporated for external aerodynamics - For the automotive example, `air_density` and `inlet_velocity` remain constant across simulations @@ -327,7 +327,7 @@ The steps below outline the process. - `eval.checkpoint_name`: Checkpoint name `outputs/{project.name}/models` to evaluate model. - `eval.scaling_param_path`: Scaling parameters populated in `outputs/{project.name}`. -3. Before running `process_data.py` to process the data, be sure to modify `openfoam_datapipe.py`. +3. Before running `process_data.py` to process the data, be sure to modify `vtk_cfd_dataset.py`. This is the entry point for the user to modify the datapipe for dataprocessing. A couple of things that might need to be changed are non-dimensionalizing schemes based on the order of your variables and the `DrivAerAwsPaths` class with the diff --git a/examples/cfd/external_aerodynamics/domino/src/conf/config.yaml b/examples/cfd/external_aerodynamics/domino/src/conf/config.yaml index f0005e5e81..2f6e9bd4b1 100644 --- a/examples/cfd/external_aerodynamics/domino/src/conf/config.yaml +++ b/examples/cfd/external_aerodynamics/domino/src/conf/config.yaml @@ -81,6 +81,7 @@ data: # Input directory for training and validation data max: [4.5 , 1.2 , 1.2] gpu_preprocessing: true gpu_output: true + stl_suffix: .stl # ┌───────────────────────────────────────────┐ # │ Domain Parallelism Settings │ @@ -155,6 +156,7 @@ model: base_layer: 512 fourier_features: false num_modes: 5 + activation: ${model.activation} # ┌───────────────────────────────────────────┐ # │ Training Configs │ diff --git a/examples/cfd/external_aerodynamics/domino/src/inference_on_stl.py b/examples/cfd/external_aerodynamics/domino/src/inference_on_stl.py index 6bfda60cc8..f46e98db2b 100644 --- a/examples/cfd/external_aerodynamics/domino/src/inference_on_stl.py +++ b/examples/cfd/external_aerodynamics/domino/src/inference_on_stl.py @@ -41,18 +41,22 @@ nd_interpolator, get_filenames, write_to_vtp, + extract_global_parameters, + create_global_parameters_reference_array, ) from torch.cuda.amp import autocast from torch.nn.parallel import DistributedDataParallel from physicsnemo.distributed import DistributedManager from numpy.typing import NDArray -from typing import Any, Iterable, List, Literal, Mapping, Optional, Union, Callable +from typing import Optional import warp as wp -from pathlib import Path -import pandas as pd import matplotlib.pyplot as plt import pyvista as pv +from vtk.util.numpy_support import numpy_to_vtk + +from scipy.spatial import KDTree + try: from physicsnemo.sym.geometry.tessellation import Tessellation @@ -686,9 +690,18 @@ def __init__( else: self.device = self.dist.device + # Legacy support - will be replaced by global_params self.air_density = torch.full((1, 1), 1.205, dtype=torch.float32).to( self.device ) + + # New global parameter system + 
self.global_params_values = None + self.global_params_reference = None + self.global_params_reference_array = None + self.global_params_types = None + self.global_params_reference_dict = None + ( self.num_vol_vars, self.num_surf_vars, @@ -763,8 +776,12 @@ def load_volume_scaling_factors(self): scaling_param_path, "volume_scaling_factors.npy" ) - vol_factors = np.load(vol_factors_path, allow_pickle=True) - vol_factors = torch.from_numpy(vol_factors).to(self.device) + if os.path.exists(vol_factors_path): + vol_factors = np.load(vol_factors_path, allow_pickle=True) + vol_factors = torch.from_numpy(vol_factors).to(self.device) + else: + vol_factors = None + print("Volume scaling factors not found") return vol_factors @@ -774,8 +791,12 @@ def load_surface_scaling_factors(self): scaling_param_path, "surface_scaling_factors.npy" ) - surf_factors = np.load(surf_factors_path, allow_pickle=True) - surf_factors = torch.from_numpy(surf_factors).to(self.device) + if os.path.exists(surf_factors_path): + surf_factors = np.load(surf_factors_path, allow_pickle=True) + surf_factors = torch.from_numpy(surf_factors).to(self.device) + else: + surf_factors = None + print("Surface scaling factors not found") return surf_factors @@ -827,30 +848,37 @@ def read_stl_trimesh( self.length_scale = length_scale def get_num_variables(self): + model_type = self.cfg.model.model_type volume_variable_names = list(self.cfg.variables.volume.solution.keys()) num_vol_vars = 0 - for j in volume_variable_names: - if self.cfg.variables.volume.solution[j] == "vector": - num_vol_vars += 3 - else: - num_vol_vars += 1 + if model_type in ["volume", "combined"]: + for j in volume_variable_names: + if self.cfg.variables.volume.solution[j] == "vector": + num_vol_vars += 3 + else: + num_vol_vars += 1 + else: + num_vol_vars = None surface_variable_names = list(self.cfg.variables.surface.solution.keys()) num_surf_vars = 0 - for j in surface_variable_names: - if self.cfg.variables.surface.solution[j] == "vector": - num_surf_vars += 3 - else: - num_surf_vars += 1 + if model_type in ["surface", "combined"]: + for j in surface_variable_names: + if self.cfg.variables.surface.solution[j] == "vector": + num_surf_vars += 3 + else: + num_surf_vars += 1 + else: + num_surf_vars = None num_global_features = 0 - global_params_names = list(cfg.variables.global_parameters.keys()) + global_params_names = list(self.cfg.variables.global_parameters.keys()) for param in global_params_names: - if cfg.variables.global_parameters[param].type == "vector": + if self.cfg.variables.global_parameters[param].type == "vector": num_global_features += len( - cfg.variables.global_parameters[param].reference + self.cfg.variables.global_parameters[param].reference ) - elif cfg.variables.global_parameters[param].type == "scalar": + elif self.cfg.variables.global_parameters[param].type == "scalar": num_global_features += 1 else: raise ValueError(f"Unknown global parameter type") @@ -893,16 +921,103 @@ def initialize_model(self, model_path): self.vol_factors = self.load_volume_scaling_factors() self.surf_factors = self.load_surface_scaling_factors() self.load_bounding_box() + # Initialize global parameter reference values (read-only) + self.initialize_global_params_reference() + + def initialize_global_params_reference(self): + """Initialize global parameter reference values from config (read-only).""" + global_params_names = list(self.cfg.variables.global_parameters.keys()) + self.global_params_reference_dict = { + name: self.cfg.variables.global_parameters[name]["reference"] + 
for name in global_params_names
+        }
+        self.global_params_types = {
+            name: self.cfg.variables.global_parameters[name]["type"]
+            for name in global_params_names
+        }
+
+        # Create reference array
+        self.global_params_reference_array = create_global_parameters_reference_array(
+            self.global_params_types, self.global_params_reference_dict
+        )
+        self.global_params_reference = torch.from_numpy(
+            self.global_params_reference_array
+        ).to(self.device)
+        self.global_params_reference = torch.unsqueeze(
+            self.global_params_reference, 0
+        )  # (1, N)
+        self.global_params_reference = torch.unsqueeze(
+            self.global_params_reference, -1
+        )  # (1, N, 1)
+
+        # Initialize with reference values
+        self.global_params_values = self.global_params_reference.clone()
+
+    def set_global_params(self, params_dict):
+        """Set global parameters from a dictionary."""
+        if self.global_params_types is None:
+            raise RuntimeError(
+                "Global parameters not initialized. Call initialize_model first."
+            )
+
+        # Extract parameters using the utility function
+        global_params_array = extract_global_parameters(
+            params_dict, self.global_params_types, "set_global_params"
+        )
+        self.global_params_values = torch.from_numpy(global_params_array).to(
+            self.device
+        )
+        self.global_params_values = torch.unsqueeze(
+            self.global_params_values, 0
+        )  # (1, N)
+        self.global_params_values = torch.unsqueeze(
+            self.global_params_values, -1
+        )  # (1, N, 1)
+
+    def get_stream_velocity_and_air_density(self):
+        """Extract stream velocity and air density from global parameters for backward compatibility."""
+        if self.global_params_values is not None:
+            # Extract from global parameters based on the parameter order
+            param_idx = 0
+            stream_velocity = None
+            air_density = None
+
+            for name, param_type in self.global_params_types.items():
+                if name in ("stream_velocity", "inlet_velocity"):
+                    # For a vector velocity, take the first component
+                    # This assumes no side-slip
+                    stream_velocity = self.global_params_values[0, param_idx, 0]
+                elif name == "air_density":
+                    air_density = self.global_params_values[0, param_idx, 0]
+
+                # Update index based on parameter type
+                if param_type == "vector":
+                    param_idx += len(self.global_params_reference_dict[name])
+                else:
+                    param_idx += 1
+
+            return stream_velocity, air_density
+        else:
+            # Fall back to legacy values; these attributes may never have been set
+            stream_velocity = getattr(self, "stream_velocity", None)
+            air_density = getattr(self, "air_density", None)
+            return (
+                stream_velocity[0, 0] if stream_velocity is not None else None,
+                air_density[0, 0] if air_density is not None else None,
+            )
+    def set_stencil_size(self, stencil_size):
+        self.stencil_size = stencil_size
+
+    # Legacy methods for backward compatibility
     def set_stream_velocity(self, stream_velocity):
+        """Legacy method - use set_global_params instead."""
         self.stream_velocity = torch.full(
             (1, 1), stream_velocity, dtype=torch.float32
         ).to(self.device)
 
-    def set_stencil_size(self, stencil_size):
-        self.stencil_size = stencil_size
-
     def set_air_density(self, air_density):
+        """Legacy method - use set_global_params instead."""
         self.air_density = torch.full((1, 1), air_density, dtype=torch.float32).to(
             self.device
         )
@@ -1036,8 +1151,6 @@ def compute_surface_solutions(self, num_sample_points=None, plot_solutions=False
                     pos_normals_com[:, start_idx:end_idx],
                     self.s_grid,
                     self.model,
-                    inlet_velocity=self.stream_velocity,
-                    air_density=self.air_density,
                 )
                 surface_solutions[:, start_idx:end_idx] = surface_solutions_batch
         else:
@@ -1061,8 +1174,6 @@ def compute_surface_solutions(self, num_sample_points=None, plot_solutions=False
                     pos_normals_com[:, start_idx:end_idx],
                     self.s_grid,
                     self.model,
-                    
inlet_velocity=self.stream_velocity, - air_density=self.air_density, ) # print(torch.amax(surface_solutions_batch, (0, 1)), torch.amin(surface_solutions_batch, (0, 1))) surface_solutions[:, start_idx:end_idx] = surface_solutions_batch @@ -1091,18 +1202,17 @@ def compute_surface_solutions(self, num_sample_points=None, plot_solutions=False surface_solutions_all, self.surf_factors[0], self.surf_factors[1] ) + # Get stream velocity and air density for result processing + stream_velocity, air_density = self.get_stream_velocity_and_air_density() + self.out_dict["surface_coordinates"] = ( 0.5 * (surface_coordinates_all + 1.0) * (cmax - cmin) + cmin ) self.out_dict["pressure_surface"] = ( - surface_solutions_all[:, :, :1] - * self.stream_velocity**2.0 - * self.air_density + surface_solutions_all[:, :, :1] * stream_velocity**2.0 * air_density ) self.out_dict["wall-shear-stress"] = ( - surface_solutions_all[:, :, 1:4] - * self.stream_velocity**2.0 - * self.air_density + surface_solutions_all[:, :, 1:4] * stream_velocity**2.0 * air_density ) self.sampling_indices = sampling_indices @@ -1161,8 +1271,6 @@ def compute_volume_solutions(self, num_sample_points, plot_solutions=False): self.grid, self.model, use_sdf_basis=self.cfg.model.use_sdf_in_basis_func, - inlet_velocity=self.stream_velocity, - air_density=self.air_density, ) volume_solutions[:, start_idx:end_idx] = volume_solutions_batch end_event.record() @@ -1184,9 +1292,6 @@ def compute_volume_solutions(self, num_sample_points, plot_solutions=False): volume_coordinates_all = volume_coordinates volume_solutions_all = volume_solutions - cmax = scaling_factors[0] - cmin = scaling_factors[1] - volume_coordinates_all = torch.reshape( volume_coordinates_all, (1, num_sample_points, 3) ) @@ -1199,24 +1304,23 @@ def compute_volume_solutions(self, num_sample_points, plot_solutions=False): volume_solutions_all, self.vol_factors[0], self.vol_factors[1] ) + stream_velocity, air_density = self.get_stream_velocity_and_air_density() + self.out_dict["coordinates"] = ( 0.5 * (volume_coordinates_all + 1.0) * (cmax - cmin) + cmin ) - self.out_dict["velocity"] = ( - volume_solutions_all[:, :, :3] * self.stream_velocity - ) + self.out_dict["velocity"] = volume_solutions_all[:, :, :3] * stream_velocity + self.out_dict["pressure"] = ( - volume_solutions_all[:, :, 3:4] - * self.stream_velocity**2.0 - * self.air_density + volume_solutions_all[:, :, 3:4] * stream_velocity**2.0 * air_density ) # self.out_dict["turbulent-kinetic-energy"] = ( # volume_solutions_all[:, :, 4:5] - # * self.stream_velocity**2.0 - # * self.air_density + # * stream_velocity**2.0 + # * air_density # ) # self.out_dict["turbulent-viscosity"] = ( - # volume_solutions_all[:, :, 5:] * self.stream_velocity * self.length_scale + # volume_solutions_all[:, :, 5:] * stream_velocity * self.length_scale # ) self.out_dict["bounding_box_dims"] = torch.vstack(self.bounding_box_min_max) @@ -1232,20 +1336,18 @@ def compute_volume_solutions(self, num_sample_points, plot_solutions=False): 0.5 * (volume_coordinates_all + 1.0) * (cmax - cmin) + cmin ) volume_solutions_all[:, :, :3] = ( - volume_solutions_all[:, :, :3] * self.stream_velocity + volume_solutions_all[:, :, :3] * stream_velocity ) volume_solutions_all[:, :, 3:4] = ( - volume_solutions_all[:, :, 3:4] - * self.stream_velocity**2.0 - * self.air_density + volume_solutions_all[:, :, 3:4] * stream_velocity**2.0 * air_density ) # volume_solutions_all[:, :, 4:5] = ( # volume_solutions_all[:, :, 4:5] - # * self.stream_velocity**2.0 - # * self.air_density + # * 
stream_velocity**2.0 + # * air_density # ) # volume_solutions_all[:, :, 5] = ( - # volume_solutions_all[:, :, 5] * self.stream_velocity * self.length_scale + # volume_solutions_all[:, :, 5] * stream_velocity * self.length_scale # ) volume_coordinates_all = volume_coordinates_all.cpu().numpy() volume_solutions_all = volume_solutions_all.cpu().numpy() @@ -1263,7 +1365,7 @@ def compute_volume_solutions(self, num_sample_points, plot_solutions=False): prediction_grid[:, int(ny / 4), :, 0], prediction_grid[:, int(ny / 2), :, 0], var="x-vel", - save_path=plot_save_path + f"x-vel-midplane_{self.stream_velocity}.png", + save_path=plot_save_path + f"x-vel-midplane_{stream_velocity}.png", axes_titles=axes_titles, plot_error=False, ) @@ -1271,7 +1373,7 @@ def compute_volume_solutions(self, num_sample_points, plot_solutions=False): prediction_grid[:, int(ny / 4), :, 1], prediction_grid[:, int(ny / 2), :, 1], var="y-vel", - save_path=plot_save_path + f"y-vel-midplane_{self.stream_velocity}.png", + save_path=plot_save_path + f"y-vel-midplane_{stream_velocity}.png", axes_titles=axes_titles, plot_error=False, ) @@ -1279,7 +1381,7 @@ def compute_volume_solutions(self, num_sample_points, plot_solutions=False): prediction_grid[:, int(ny / 4), :, 2], prediction_grid[:, int(ny / 2), :, 2], var="z-vel", - save_path=plot_save_path + f"z-vel-midplane_{self.stream_velocity}.png", + save_path=plot_save_path + f"z-vel-midplane_{stream_velocity}.png", axes_titles=axes_titles, plot_error=False, ) @@ -1287,7 +1389,7 @@ def compute_volume_solutions(self, num_sample_points, plot_solutions=False): prediction_grid[:, int(ny / 4), :, 3], prediction_grid[:, int(ny / 2), :, 3], var="pres", - save_path=plot_save_path + f"pres-midplane_{self.stream_velocity}.png", + save_path=plot_save_path + f"pres-midplane_{stream_velocity}.png", axes_titles=axes_titles, plot_error=False, ) @@ -1295,7 +1397,7 @@ def compute_volume_solutions(self, num_sample_points, plot_solutions=False): # prediction_grid[:, int(ny / 4), :, 4], # prediction_grid[:, int(ny / 2), :, 4], # var="tke", - # save_path=plot_save_path + f"tke-midplane_{self.stream_velocity}.png", + # save_path=plot_save_path + f"tke-midplane_{stream_velocity}.png", # axes_titles=axes_titles, # plot_error=False, # ) @@ -1303,7 +1405,7 @@ def compute_volume_solutions(self, num_sample_points, plot_solutions=False): # prediction_grid[:, int(ny / 4), :, 5], # prediction_grid[:, int(ny / 2), :, 5], # var="nut", - # save_path=plot_save_path + f"nut-midplane_{self.stream_velocity}.png", + # save_path=plot_save_path + f"nut-midplane_{stream_velocity}.png", # axes_titles=axes_titles, # plot_error=False, # ) @@ -1369,24 +1471,13 @@ def compute_solution_on_surface( pos_normals_com, s_grid, model, - inlet_velocity, - air_density, ): """ - Global parameters: For this particular case, the model was trained on single velocity/density values - across all simulations. Hence, global_params_values and global_params_reference are the same. + Compute surface solutions using the generalized global parameter system. 
""" - global_params_values = torch.cat( - (inlet_velocity, air_density), axis=1 - ) # (1, 2) - global_params_values = torch.unsqueeze(global_params_values, -1) # (1, 2, 1) - - global_params_reference = torch.cat( - (inlet_velocity, air_density), axis=1 - ) # (1, 2) - global_params_reference = torch.unsqueeze( - global_params_reference, -1 - ) # (1, 2, 1) + # Use the generalized global parameter system + global_params_values = self.global_params_values + global_params_reference = self.global_params_reference if self.dist.world_size == 1: geo_encoding_local = model.geo_encoding_local( @@ -1445,22 +1536,13 @@ def compute_solution_in_volume( p_grid, model, use_sdf_basis, - inlet_velocity, - air_density, ): - - ## Global parameters - global_params_values = torch.cat( - (inlet_velocity, air_density), axis=1 - ) # (1, 2) - global_params_values = torch.unsqueeze(global_params_values, -1) # (1, 2, 1) - - global_params_reference = torch.cat( - (inlet_velocity, air_density), axis=1 - ) # (1, 2) - global_params_reference = torch.unsqueeze( - global_params_reference, -1 - ) # (1, 2, 1) + """ + Compute volume solutions using the generalized global parameter system. + """ + # Use the generalized global parameter system + global_params_values = self.global_params_values + global_params_reference = self.global_params_reference if self.dist.world_size == 1: geo_encoding_local = model.geo_encoding_local( @@ -1518,7 +1600,7 @@ def compute_solution_in_volume( input_path = cfg.eval.test_path dirnames = get_filenames(input_path) dev_id = torch.cuda.current_device() - num_files = int(len(dirnames) / 8) + num_files = int(len(dirnames) / dist.world_size) dirnames_per_gpu = dirnames[int(num_files * dev_id) : int(num_files * (dev_id + 1))] domino = dominoInference(cfg, dist, False) @@ -1550,13 +1632,15 @@ def compute_solution_in_volume( domino.compute_geo_encoding() # Calculate volume solutions - domino.compute_volume_solutions( - num_sample_points=10_256_000, plot_solutions=False - ) + if cfg.model.model_type in ["volume", "combined"]: + domino.compute_volume_solutions( + num_sample_points=10_256_000, plot_solutions=True + ) # Calculate surface solutions - domino.compute_surface_solutions() - domino.compute_forces() + if cfg.model.model_type in ["surface", "combined"]: + domino.compute_surface_solutions() + domino.compute_forces() out_dict = domino.get_out_dict() print( diff --git a/examples/cfd/external_aerodynamics/domino/src/process_data.py b/examples/cfd/external_aerodynamics/domino/src/process_data.py index 3401001d6d..081cf0ef39 100644 --- a/examples/cfd/external_aerodynamics/domino/src/process_data.py +++ b/examples/cfd/external_aerodynamics/domino/src/process_data.py @@ -15,17 +15,17 @@ # limitations under the License. """ -This code runs the data processing in parallel to load OpenFoam files, process them -and save in the npy format for faster processing in the DoMINO datapipes. Several -parameters such as number of processors, input and output paths, etc. can be +This code runs the data processing in parallel to load VTK CFD files, process them +and save in the npy format for faster processing in the DoMINO datapipes. Several +parameters such as number of processors, input and output paths, etc. can be configured in config.yaml in the data_processing tab. 
""" -from openfoam_datapipe import OpenFoamDataset +from vtk_cfd_dataset import VtkCfdDataset from physicsnemo.utils.domino.utils import * import multiprocessing import hydra, time -from hydra.utils import to_absolute_path +import numbers from omegaconf import DictConfig, OmegaConf @@ -54,7 +54,6 @@ def process_files(*args_list): @hydra.main(version_base="1.3", config_path="conf", config_name="config") def main(cfg: DictConfig): print(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") - phase = "train" volume_variable_names = list(cfg.variables.volume.solution.keys()) num_vol_vars = 0 for j in volume_variable_names: @@ -77,12 +76,21 @@ def main(cfg: DictConfig): name: cfg.variables.global_parameters[name]["reference"] for name in global_params_names } + + epsilon = 1e-10 # Normalization will fail if reference is too close to zero + for name, value in global_params_reference.items(): + if isinstance(value, numbers.Number): + if abs(value) <= epsilon: + raise ValueError( + f"Global parameter '{name}' has a reference value of {value}, which is too close to zero." + ) + global_params_types = { name: cfg.variables.global_parameters[name]["type"] for name in global_params_names } - fm_data = OpenFoamDataset( + fm_data = VtkCfdDataset( cfg.data_processor.input_dir, kind=cfg.data_processor.kind, volume_variables=volume_variable_names, @@ -90,6 +98,7 @@ def main(cfg: DictConfig): global_params_types=global_params_types, global_params_reference=global_params_reference, model_type=cfg.model.model_type, + stl_suffix=cfg.data.stl_suffix, ) output_dir = cfg.data_processor.output_dir create_directory(output_dir) diff --git a/examples/cfd/external_aerodynamics/domino/src/test.py b/examples/cfd/external_aerodynamics/domino/src/test.py index b1914f99b1..dccac9ea76 100644 --- a/examples/cfd/external_aerodynamics/domino/src/test.py +++ b/examples/cfd/external_aerodynamics/domino/src/test.py @@ -35,29 +35,26 @@ import numpy as np -from collections import defaultdict from pathlib import Path -from typing import Any, Iterable, List, Literal, Mapping, Optional, Union, Callable -import pandas as pd +from vtk_cfd_dataset import ( + DriveSimPaths, + DrivAerAwsPaths, + SHIFTPaths, +) import pyvista as pv import torch from torch.nn.parallel import DistributedDataParallel -from torch.utils.data import DataLoader, Dataset import vtk from vtk.util import numpy_support from physicsnemo.distributed import DistributedManager -from physicsnemo.datapipes.cae.domino_datapipe import DoMINODataPipe from physicsnemo.models.domino.model import DoMINO from physicsnemo.utils.domino.utils import * from physicsnemo.utils.sdf import signed_distance_field -# AIR_DENSITY = 1.205 -# STREAM_VELOCITY = 30.00 - def loss_fn(output, target): masked_loss = torch.mean(((output - target) ** 2.0), (0, 1, 2)) @@ -396,21 +393,65 @@ def main(cfg: DictConfig): model = model.module dirnames = get_filenames(input_path) - dev_id = torch.cuda.current_device() - num_files = int(len(dirnames) / dist.world_size) - dirnames_per_gpu = dirnames[int(num_files * dev_id) : int(num_files * (dev_id + 1))] + # Calculate base number of files per GPU + base_files_per_gpu = len(dirnames) // dist.world_size + + # Calculate remainder + remainder = len(dirnames) % dist.world_size + + # Distribute the remainder evenly + # GPUs with rank < remainder get one extra file + if dist.rank < remainder: + start_idx = dist.rank * (base_files_per_gpu + 1) + end_idx = start_idx + base_files_per_gpu + 1 + else: + start_idx = dist.rank * base_files_per_gpu + remainder + 
end_idx = start_idx + base_files_per_gpu + + dirnames_per_gpu = dirnames[start_idx:end_idx] pred_save_path = cfg.eval.save_path if dist.rank == 0: create_directory(pred_save_path) - for count, dirname in enumerate(dirnames_per_gpu): + path_getters = { + "drivesim": DriveSimPaths, + "drivaer_aws": DrivAerAwsPaths, + "shift": SHIFTPaths, + } + kind = cfg.data_processor.kind + path_getter = path_getters[kind] + stl_suffix = cfg.data.stl_suffix + + # Get global parameters and global parameters scaling from config.yaml + global_params_names = list(cfg.variables.global_parameters.keys()) + global_params_reference_dict = { + name: cfg.variables.global_parameters[name]["reference"] + for name in global_params_names + } + global_params_types = { + name: cfg.variables.global_parameters[name]["type"] + for name in global_params_names + } + + global_params_reference = create_global_parameters_reference_array( + global_params_types, global_params_reference_dict + ) + + for i, dirname in enumerate(dirnames_per_gpu): # print(f"Processing file {dirname}") - filepath = os.path.join(input_path, dirname) - tag = int(re.findall(r"(\w+?)(\d+)", dirname)[0][1]) - stl_path = os.path.join(filepath, f"drivaer_{tag}.stl") - vtp_path = os.path.join(filepath, f"boundary_{tag}.vtp") - vtu_path = os.path.join(filepath, f"volume_{tag}.vtu") + filepath = Path(input_path) / dirname + # Load parameters from JSON file if available, otherwise use reference values + global_params_values, params_data = load_parameters_from_json( + filepath / "params.json", global_params_types, global_params_reference + ) + if kind == "drivaer_aws": + tag = int(re.findall(r"(\w+?)(\d+)", dirname)[0][1]) + else: + tag = dirname.replace("/", "") + stl_path = path_getter.geometry_path(filepath, stl_suffix) + vtp_path = path_getter.surface_path(filepath) + vtu_path = path_getter.volume_path(filepath) vtp_pred_save_path = os.path.join( pred_save_path, f"boundary_{tag}_predicted.vtp" @@ -459,50 +500,6 @@ def main(cfg: DictConfig): sdf_surf_grid = np.float32(sdf_surf_grid) surf_grid_max_min = np.float32(np.asarray([s_min, s_max])) - # Get global parameters and global parameters scaling from config.yaml - global_params_names = list(cfg.variables.global_parameters.keys()) - global_params_reference = { - name: cfg.variables.global_parameters[name]["reference"] - for name in global_params_names - } - global_params_types = { - name: cfg.variables.global_parameters[name]["type"] - for name in global_params_names - } - stream_velocity = global_params_reference["inlet_velocity"][0] - air_density = global_params_reference["air_density"] - - # Arrange global parameters reference in a list, ensuring it is flat - global_params_reference_list = [] - for name, type in global_params_types.items(): - if type == "vector": - global_params_reference_list.extend(global_params_reference[name]) - elif type == "scalar": - global_params_reference_list.append(global_params_reference[name]) - else: - raise ValueError( - f"Global parameter {name} not supported for this dataset" - ) - global_params_reference = np.array( - global_params_reference_list, dtype=np.float32 - ) - - # Define the list of global parameter values for each simulation. - # Note: The user must ensure that the values provided here correspond to the - # `global_parameters` specified in `config.yaml` and that these parameters - # exist within each simulation file. 
-        global_params_values_list = []
-        for key in global_params_types.keys():
-            if key == "inlet_velocity":
-                global_params_values_list.append(stream_velocity)
-            elif key == "air_density":
-                global_params_values_list.append(air_density)
-            else:
-                raise ValueError(
-                    f"Global parameter {key} not supported for this dataset"
-                )
-        global_params_values = np.array(global_params_values_list, dtype=np.float32)
-
         # Read VTP
         if model_type == "surface" or model_type == "combined":
             reader = vtk.vtkXMLPolyDataReader()
             reader.SetFileName(vtp_path)
             reader.Update()
             polydata_surf = reader.GetOutput()
 
-            celldata_all = get_node_to_elem(polydata_surf)
-
-            celldata = celldata_all.GetCellData()
-            surface_fields = get_fields(celldata, surface_variable_names)
+            if cfg.data_processor.kind != "shift":
+                celldata_all = get_node_to_elem(polydata_surf)
+            else:
+                celldata_all = polydata_surf
+
+            surface_fields = get_fields(
+                celldata_all.GetCellData(), surface_variable_names
+            )
             surface_fields = np.concatenate(surface_fields, axis=-1)
 
             mesh = pv.PolyData(polydata_surf)
@@ -828,9 +831,10 @@ def main(cfg: DictConfig):
             volParam_vtk.SetName(f"{volume_variable_names[1]}Pred")
             polydata_vol.GetPointData().AddArray(volParam_vtk)
 
-            volParam_vtk = numpy_support.numpy_to_vtk(prediction_vol[:, 4:5])
-            volParam_vtk.SetName(f"{volume_variable_names[2]}Pred")
-            polydata_vol.GetPointData().AddArray(volParam_vtk)
+            if num_vol_vars > 4:
+                volParam_vtk = numpy_support.numpy_to_vtk(prediction_vol[:, 4:5])
+                volParam_vtk.SetName(f"{volume_variable_names[2]}Pred")
+                polydata_vol.GetPointData().AddArray(volParam_vtk)
 
             write_to_vtu(polydata_vol, vtu_pred_save_path)
 
diff --git a/examples/cfd/external_aerodynamics/domino/src/openfoam_datapipe.py b/examples/cfd/external_aerodynamics/domino/src/vtk_cfd_dataset.py
similarity index 60%
rename from examples/cfd/external_aerodynamics/domino/src/openfoam_datapipe.py
rename to examples/cfd/external_aerodynamics/domino/src/vtk_cfd_dataset.py
index 7432ecd8a4..d6d750f5d0 100644
--- a/examples/cfd/external_aerodynamics/domino/src/openfoam_datapipe.py
+++ b/examples/cfd/external_aerodynamics/domino/src/vtk_cfd_dataset.py
@@ -15,8 +15,8 @@
 # limitations under the License.
 
 """
-This is the datapipe to read OpenFoam files (vtp/vtu/stl) and save them as point clouds
-in npy format.
+This is the datapipe to read VTK CFD files (vtp/vtu/stl) and save them as point clouds
+in npy format. Supports OpenFOAM and data exported from other solvers (like SHIFT datasets).
 
""" @@ -32,14 +32,11 @@ from physicsnemo.utils.domino.utils import * from torch.utils.data import Dataset -# AIR_DENSITY = 1.205 -# STREAM_VELOCITY = 30.00 - class DriveSimPaths: @staticmethod - def geometry_path(car_dir: Path) -> Path: - return car_dir / "body.stl" + def geometry_path(car_dir: Path, stl_suffix: str = ".stl") -> Path: + return car_dir / f"body{stl_suffix}" @staticmethod def volume_path(car_dir: Path) -> Path: @@ -52,32 +49,50 @@ def surface_path(car_dir: Path) -> Path: class DrivAerAwsPaths: @staticmethod - def _get_index(car_dir: Path) -> str: - return car_dir.name.removeprefix("run_") + def _get_index(case_dir: Path) -> str: + return case_dir.name.removeprefix("run_") @staticmethod - def geometry_path(car_dir: Path) -> Path: - return car_dir / f"drivaer_{DrivAerAwsPaths._get_index(car_dir)}.stl" + def geometry_path(case_dir: Path, stl_suffix: str = ".stl") -> Path: + return case_dir / f"drivaer_{DrivAerAwsPaths._get_index(case_dir)}{stl_suffix}" @staticmethod - def volume_path(car_dir: Path) -> Path: - return car_dir / f"volume_{DrivAerAwsPaths._get_index(car_dir)}.vtu" + def volume_path(case_dir: Path) -> Path: + return case_dir / f"volume_{DrivAerAwsPaths._get_index(case_dir)}.vtu" @staticmethod - def surface_path(car_dir: Path) -> Path: - return car_dir / f"boundary_{DrivAerAwsPaths._get_index(car_dir)}.vtp" + def surface_path(case_dir: Path) -> Path: + return case_dir / f"boundary_{DrivAerAwsPaths._get_index(case_dir)}.vtp" -class OpenFoamDataset(Dataset): +class SHIFTPaths: + """Path utilities for SHIFT dataset file locations.""" + + @staticmethod + def geometry_path(case_dir: Path, stl_suffix: str = ".stl") -> Path: + """Get path to geometry file for a SHIFT directory.""" + return case_dir / f"merged_surfaces{stl_suffix}" + + @staticmethod + def volume_path(case_dir: Path) -> Path: + return case_dir / f"merged_volumes.vtu" + + @staticmethod + def surface_path(case_dir: Path) -> Path: + return case_dir / f"merged_surfaces.vtp" + + +class VtkCfdDataset(Dataset): """ - Datapipe for converting openfoam dataset to npy + Datapipe for converting VTK CFD dataset to npy format. + Supports OpenFOAM and data exported from other solvers (like SHIFT datasets). 
""" def __init__( self, data_path: Union[str, Path], - kind: Literal["drivesim", "drivaer_aws"] = "drivesim", + kind: Literal["drivesim", "drivaer_aws", "shift"] = "drivesim", surface_variables: Optional[list] = [ "pMean", "wallShearStress", @@ -93,6 +108,7 @@ def __init__( }, device: int = 0, model_type=None, + stl_suffix: str = ".stl", ): if isinstance(data_path, str): data_path = Path(data_path) @@ -100,11 +116,17 @@ def __init__( self.data_path = data_path - supported_kinds = ["drivesim", "drivaer_aws"] + supported_kinds = ["drivesim", "drivaer_aws", "shift"] assert ( kind in supported_kinds ), f"kind should be one of {supported_kinds}, got {kind}" - self.path_getter = DriveSimPaths if kind == "drivesim" else DrivAerAwsPaths + path_getters = { + "drivesim": DriveSimPaths, + "drivaer_aws": DrivAerAwsPaths, + "shift": SHIFTPaths, + } + + self.path_getter = path_getters[kind] assert self.data_path.exists(), f"Path {self.data_path} does not exist" @@ -112,7 +134,6 @@ def __init__( self.filenames = get_filenames(self.data_path) random.shuffle(self.filenames) - self.indices = np.array(len(self.filenames)) self.surface_variables = surface_variables self.volume_variables = volume_variables @@ -121,24 +142,41 @@ def __init__( self.global_params_reference = global_params_reference self.stream_velocity = 0.0 - for vel_component in self.global_params_reference["inlet_velocity"]: + inlet_velocity_name = ( + "inlet_velocity" + if "inlet_velocity" in self.global_params_reference + else "stream_velocity" + ) + + for vel_component in self.global_params_reference[inlet_velocity_name]: self.stream_velocity += vel_component**2 self.stream_velocity = np.sqrt(self.stream_velocity) self.air_density = self.global_params_reference["air_density"] self.device = device self.model_type = model_type + self.surface_data_on_points = kind != "shift" + self.stl_suffix = stl_suffix def __len__(self): return len(self.filenames) def __getitem__(self, idx): cfd_filename = self.filenames[idx] - car_dir = self.data_path / cfd_filename + case_dir = self.data_path / cfd_filename + + stl_path = self.path_getter.geometry_path(case_dir, self.stl_suffix) + + # Check if STL file exists + if not Path(stl_path).exists(): + raise FileNotFoundError(f"STL file not found: {stl_path}") - stl_path = self.path_getter.geometry_path(car_dir) reader = pv.get_reader(stl_path) mesh_stl = reader.read() + # Check if STL data was successfully loaded + if mesh_stl is None or mesh_stl.n_points == 0: + raise RuntimeError(f"Failed to load STL data from: {stl_path}") + stl_vertices = mesh_stl.points stl_faces = np.array(mesh_stl.faces).reshape((-1, 4))[ :, 1: @@ -151,13 +189,21 @@ def __getitem__(self, idx): length_scale = np.amax(np.amax(stl_vertices, 0) - np.amin(stl_vertices, 0)) if self.model_type == "volume" or self.model_type == "combined": - filepath = self.path_getter.volume_path(car_dir) + filepath = self.path_getter.volume_path(case_dir) + # Check if volume file exists + if not Path(filepath).exists(): + raise FileNotFoundError(f"Volume file not found: {filepath}") + reader = vtk.vtkXMLUnstructuredGridReader() reader.SetFileName(filepath) reader.Update() # Get the unstructured grid data polydata = reader.GetOutput() + # Check if VTK reader successfully loaded the data + if polydata is None or polydata.GetNumberOfCells() == 0: + raise RuntimeError(f"Failed to load volume data from: {filepath}") + volume_coordinates, volume_fields = get_volume_data( polydata, self.volume_variables ) @@ -168,8 +214,8 @@ def __getitem__(self, idx): 
volume_fields[:, 3:4] = volume_fields[:, 3:4] / (
                 self.air_density * self.stream_velocity**2.0
             )
 
-            volume_fields[:, 4:] = volume_fields[:, 4:] / (
+            volume_fields[:, 4:5] = volume_fields[:, 4:5] / (
                 self.stream_velocity * length_scale
             )
         else:
             volume_fields = None
             volume_coordinates = None
@@ -177,14 +223,33 @@
         if self.model_type == "surface" or self.model_type == "combined":
-            surface_filepath = self.path_getter.surface_path(car_dir)
+            surface_filepath = self.path_getter.surface_path(case_dir)
+            # Check if surface file exists
+            if not Path(surface_filepath).exists():
+                raise FileNotFoundError(f"Surface file not found: {surface_filepath}")
+
             reader = vtk.vtkXMLPolyDataReader()
             reader.SetFileName(surface_filepath)
             reader.Update()
             polydata = reader.GetOutput()
+            # Check if VTK reader successfully loaded the data
+            if polydata is None or polydata.GetNumberOfCells() == 0:
+                raise RuntimeError(
+                    f"Failed to load surface data from: {surface_filepath}"
+                )
+
+            if self.surface_data_on_points:
+                celldata_all = get_node_to_elem(polydata)
+                celldata = celldata_all.GetCellData()
+            else:
+                celldata = polydata.GetCellData()
+
+            # Check if celldata is valid before processing
+            if celldata is None:
+                raise RuntimeError(
+                    f"No cell data found in surface file: {surface_filepath}"
+                )
 
-            celldata_all = get_node_to_elem(polydata)
-            celldata = celldata_all.GetCellData()
             surface_fields = get_fields(celldata, self.surface_variables)
             surface_fields = np.concatenate(surface_fields, axis=-1)
 
@@ -212,40 +277,16 @@
             surface_normals = None
             surface_sizes = None
 
-        # Arrange global parameters reference in a list based on the type of the parameter
-        global_params_reference_list = []
-        for name, type in self.global_params_types.items():
-            if type == "vector":
-                global_params_reference_list.extend(self.global_params_reference[name])
-            elif type == "scalar":
-                global_params_reference_list.append(self.global_params_reference[name])
-            else:
-                raise ValueError(
-                    f"Global parameter {name} not supported for this dataset"
-                )
-        global_params_reference = np.array(
-            global_params_reference_list, dtype=np.float32
+        # Create reference array from global parameters configuration
+        global_params_reference = create_global_parameters_reference_array(
+            self.global_params_types, self.global_params_reference
         )
 
-        # Prepare the list of global parameter values for each simulation file
-        # Note: The user must ensure that the values provided here correspond to the
-        # `global_parameters` specified in `config.yaml` and that these parameters
-        # exist within each simulation file.
- global_params_values_list = [] - for key in self.global_params_types.keys(): - if key == "inlet_velocity": - global_params_values_list.extend( - self.global_params_reference["inlet_velocity"] - ) - elif key == "air_density": - global_params_values_list.append( - self.global_params_reference["air_density"] - ) - else: - raise ValueError( - f"Global parameter {key} not supported for this dataset" - ) - global_params_values = np.array(global_params_values_list, dtype=np.float32) + # Load parameters from JSON file if available, otherwise use constants + params_json_path = case_dir / "params.json" + global_params_values, _ = load_parameters_from_json( + params_json_path, self.global_params_types, global_params_reference + ) # Add the parameters to the dictionary return { @@ -266,7 +307,7 @@ def __getitem__(self, idx): if __name__ == "__main__": - fm_data = OpenFoamDataset( + fm_data = VtkCfdDataset( data_path="/code/aerofoundationdata/", phase="train", volume_variables=["UMean", "pMean", "nutMean"], diff --git a/physicsnemo/utils/domino/utils.py b/physicsnemo/utils/domino/utils.py index 8bfb518e70..d721f9ae4b 100644 --- a/physicsnemo/utils/domino/utils.py +++ b/physicsnemo/utils/domino/utils.py @@ -22,6 +22,8 @@ CPU (NumPy) and GPU (CuPy) operations with automatic fallbacks. """ +import json +import numbers from pathlib import Path from typing import Any, Sequence @@ -1146,3 +1148,145 @@ def area_weighted_shuffle_array( selected_indices = xp.asarray(selected_indices) return arr[selected_indices], selected_indices + + +def extract_global_parameters( + params_data: dict[str, Any], + global_params_types: dict[str, str], + params_location: str = "param dictionary", +) -> np.ndarray: + """Extract and flatten global parameters from parameter data dictionary. + This function processes global parameters based on their types (scalar/vector) + and flattens them into a single array. Scalars are automatically promoted to + vectors when the parameter type is defined as "vector". + Args: + params_data: Dictionary containing parameter names and their values. + global_params_types: Dictionary mapping parameter names to their types ("scalar" or "vector"). + params_location: Description of parameter source for error messages. + Returns: + Flattened array containing all global parameter values as float32. + Raises: + ValueError: If a required parameter is missing or has unsupported type. 
+ Examples: + >>> params = {"pressure": 101325.0, "velocity": [1.0, 0.0, 0.0]} + >>> types = {"pressure": "scalar", "velocity": "vector"} + >>> result = extract_global_parameters(params, types) + >>> result.shape + (4,) + >>> np.allclose(result, [101325.0, 1.0, 0.0, 0.0]) + True + >>> # Scalar automatically promoted to vector + >>> params_scalar = {"pressure": 101325.0, "stream_velocity": 2.5} + >>> types_scalar = {"pressure": "scalar", "stream_velocity": "vector"} + >>> result_scalar = extract_global_parameters(params_scalar, types_scalar) + >>> result_scalar.shape + (2,) + >>> np.allclose(result_scalar, [101325.0, 2.5]) + True + """ + global_params_values_list = [] + for name, typ in global_params_types.items(): + if name not in params_data: + raise ValueError(f"Global parameter {name} not found in {params_location}") + param_value = params_data[name] + if typ == "vector": + # Automatically promote scalars to vectors when vector type is expected + if isinstance(param_value, numbers.Number): + global_params_values_list.append(param_value) + else: + global_params_values_list.extend(param_value) + elif typ == "scalar": + global_params_values_list.append(param_value) + else: + raise ValueError(f"Global parameter {name} not supported for this dataset") + return np.array(global_params_values_list, dtype=np.float32) + + +def create_global_parameters_reference_array( + global_params_types: dict[str, str], global_params_reference: dict[str, Any] +) -> np.ndarray: + """Create flattened reference array from global parameters types and reference values. + This function arranges global parameter reference values into a flattened array + based on their types (scalar/vector). This is commonly used to create reference + arrays for parameter normalization and JSON parameter loading. + Args: + global_params_types: Dictionary mapping parameter names to their types ("scalar" or "vector"). + global_params_reference: Dictionary mapping parameter names to their reference values. + Returns: + Flattened array of reference values arranged by parameter type. + Raises: + ValueError: If a parameter has an unsupported type. + Examples: + >>> types = {"pressure": "scalar", "velocity": "vector"} + >>> refs = {"pressure": 101325.0, "velocity": [1.0, 0.0, 0.0]} + >>> ref_array = create_global_parameters_reference_array(types, refs) + >>> ref_array.shape + (4,) + >>> np.allclose(ref_array, [101325.0, 1.0, 0.0, 0.0]) + True + """ + # Arrange global parameters reference in a list based on the type of the parameter + global_params_reference_list = [] + for name, param_type in global_params_types.items(): + if param_type == "vector": + global_params_reference_list.extend(global_params_reference[name]) + elif param_type == "scalar": + global_params_reference_list.append(global_params_reference[name]) + else: + raise ValueError(f"Global parameter {name} not supported for this dataset") + + return np.array(global_params_reference_list, dtype=np.float32) + + +def load_parameters_from_json( + params_json_path: Path, + global_params_types: dict[str, str], + global_params_reference: np.ndarray, +) -> tuple[np.ndarray, dict[str, Any]]: + """Load global parameters from JSON file with fallback to reference values. + This function attempts to load global parameters from a JSON file and extract + them according to the specified types. If the file doesn't exist or parsing + fails, it falls back to using reference parameter values. + Args: + params_json_path: Path to JSON file containing parameter definitions. 
+ global_params_types: Dictionary mapping parameter names to types ("scalar" or "vector"). + global_params_reference: Default parameter values to use as fallback. + Returns: + Tuple containing: + - Array of global parameter values loaded from JSON or fallback reference values + - Dictionary of raw parameter data from JSON file, or empty dict if file not loaded + Examples: + >>> from pathlib import Path + >>> import tempfile + >>> import json + >>> # Create temporary JSON file + >>> with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + ... json.dump({"pressure": 101325.0, "velocity": [1.0, 0.0, 0.0]}, f) + ... temp_path = Path(f.name) + >>> types = {"pressure": "scalar", "velocity": "vector"} + >>> reference = np.array([0.0, 0.0, 0.0, 0.0]) + >>> params, data = load_parameters_from_json(temp_path, types, reference) + >>> params.shape + (4,) + >>> "pressure" in data + True + >>> # Cleanup + >>> temp_path.unlink() + """ + if params_json_path.exists(): + try: + with open(params_json_path, "r") as f: + params_data = json.load(f) + return ( + extract_global_parameters( + params_data, global_params_types, params_json_path + ), + params_data, + ) + except (json.JSONDecodeError, IOError) as e: + # Fall back to constants if JSON parsing fails + print( + f"Warning: Could not parse {params_json_path}: {e}. Using default constants." + ) + return global_params_reference.copy(), {} + return global_params_reference.copy(), {}
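
For reference, here is a minimal sketch (not part of the patch) of how the generalized global-parameter utilities added in `physicsnemo/utils/domino/utils.py` fit together. The parameter names and the `params.json` contents below are illustrative; in a real run they must match the entries under `variables.global_parameters` in `config.yaml`:

```python
import json
import tempfile
from pathlib import Path

from physicsnemo.utils.domino.utils import (
    create_global_parameters_reference_array,
    load_parameters_from_json,
)

# Types and reference values as they would appear under
# variables.global_parameters in config.yaml (illustrative names/values).
global_params_types = {"inlet_velocity": "vector", "air_density": "scalar"}
global_params_reference_dict = {
    "inlet_velocity": [30.0, 0.0, 0.0],
    "air_density": 1.205,
}

# Flatten the per-parameter reference values into the (N,) float32 array
# that the datapipe and inference code use for normalization.
reference = create_global_parameters_reference_array(
    global_params_types, global_params_reference_dict
)

# Per-case values come from params.json when it exists; cases without one
# fall back to the reference values.
case_dir = Path(tempfile.mkdtemp())
(case_dir / "params.json").write_text(
    json.dumps({"inlet_velocity": [35.0, 0.0, 0.0], "air_density": 1.205})
)
values, raw = load_parameters_from_json(
    case_dir / "params.json", global_params_types, reference
)
print(values)  # approx. [35. 0. 0. 1.205], dtype=float32
```

On the inference side, the same flattened array shape is what `dominoInference.set_global_params` expects to receive as a dictionary before it is unsqueezed to `(1, N, 1)`.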