diff --git a/examples/cfd/external_aerodynamics/domino/requirements.txt b/examples/cfd/external_aerodynamics/domino/requirements.txt index 1440cbf948..cafc1c7a4c 100644 --- a/examples/cfd/external_aerodynamics/domino/requirements.txt +++ b/examples/cfd/external_aerodynamics/domino/requirements.txt @@ -1,2 +1,4 @@ torchinfo -warp-lang \ No newline at end of file +warp-lang +tensorboard +cuml-cu12>=25.6.0 diff --git a/examples/cfd/external_aerodynamics/domino/src/inference_on_stl.py b/examples/cfd/external_aerodynamics/domino/src/inference_on_stl.py index 2c8f1afcb8..b76666a7ff 100644 --- a/examples/cfd/external_aerodynamics/domino/src/inference_on_stl.py +++ b/examples/cfd/external_aerodynamics/domino/src/inference_on_stl.py @@ -34,11 +34,17 @@ import numpy as np import torch -from modulus.models.domino.model import DoMINO -from modulus.utils.domino.utils import * +from physicsnemo.models.domino.model import DoMINO +from physicsnemo.utils.domino.utils import ( + unnormalize, + create_directory, + nd_interpolator, + get_filenames, + write_to_vtp, +) from torch.cuda.amp import autocast from torch.nn.parallel import DistributedDataParallel -from modulus.distributed import DistributedManager +from physicsnemo.distributed import DistributedManager from numpy.typing import NDArray from typing import Any, Iterable, List, Literal, Mapping, Optional, Union, Callable @@ -49,7 +55,7 @@ import pyvista as pv try: - from modulus.sym.geometry.tessellation import Tessellation + from physicsnemo.sym.geometry.tessellation import Tessellation SYM_AVAILABLE = True except ImportError: @@ -663,7 +669,7 @@ def __init__( self, cfg: DictConfig, dist: None, - cached_geo_encoding: False, + cached_geo_encoding: bool = False, ): self.cfg = cfg diff --git a/examples/cfd/external_aerodynamics/domino/src/train.py b/examples/cfd/external_aerodynamics/domino/src/train.py index 19abb63fb0..ba6788a04a 100644 --- a/examples/cfd/external_aerodynamics/domino/src/train.py +++ b/examples/cfd/external_aerodynamics/domino/src/train.py @@ -265,7 +265,7 @@ def compute_loss_dict( integral_scaling_factor: float, surf_loss_scaling: float, vol_loss_scaling: float, -) -> Tuple[torch.Tensor, dict]: +) -> tuple[torch.Tensor, dict]: """ Compute the loss terms in a single function call. diff --git a/physicsnemo/datapipes/cae/domino_datapipe.py b/physicsnemo/datapipes/cae/domino_datapipe.py index 159ab0f5a2..ad83331f24 100644 --- a/physicsnemo/datapipes/cae/domino_datapipe.py +++ b/physicsnemo/datapipes/cae/domino_datapipe.py @@ -843,7 +843,7 @@ def preprocess_volume( mesh_indices_flattened, grid_reshaped, use_sign_winding_number=True, - ).reshape(nx, ny, nz) + ).reshape((nx, ny, nz)) if self.config.sampling: volume_coordinates_sampled, idx_volume = shuffle_array( diff --git a/physicsnemo/utils/domino/utils.py b/physicsnemo/utils/domino/utils.py index c2c8e0feaa..8bfb518e70 100644 --- a/physicsnemo/utils/domino/utils.py +++ b/physicsnemo/utils/domino/utils.py @@ -15,219 +15,555 @@ # limitations under the License. """ -Important utilities for data processing and training, testing DoMINO. +Utilities for data processing and training with the DoMINO model architecture. + +This module provides essential utilities for computational fluid dynamics data processing, +mesh manipulation, field normalization, and geometric computations. It supports both +CPU (NumPy) and GPU (CuPy) operations with automatic fallbacks. 
""" -import os -from typing import Any, List, Optional, Sequence, Tuple, Union +from pathlib import Path +from typing import Any, Sequence import numpy as np +import vtk from scipy.spatial import KDTree +from vtk import vtkDataSetTriangleFilter +from vtk.util import numpy_support from physicsnemo.utils.profiling import profile +# Type alias for arrays that can be either NumPy or CuPy + try: import cupy as cp - CUPY_AVAILABLE = True + ArrayType = np.ndarray | cp.ndarray except ImportError: - CUPY_AVAILABLE = False + ArrayType = np.ndarray -try: - import pyvista as pv +def array_type(array: ArrayType) -> "type[np] | type[cp]": + """Determine the array module (NumPy or CuPy) for the given array. - PV_AVAILABLE = True -except ImportError: - PV_AVAILABLE = False -try: - import vtk - from vtk import vtkDataSetTriangleFilter - from vtk.util import numpy_support + This function enables array-agnostic code by returning the appropriate + array module that can be used for operations on the input array. - VTK_AVAILABLE = True -except ImportError: - VTK_AVAILABLE = False + Args: + array: Input array that can be either NumPy or CuPy array. -# Define a typing that works for both numpy and cupy -if CUPY_AVAILABLE: - ArrayType = Union[np.ndarray, cp.ndarray] -else: - # Or just numpy, if cupy is not available. - ArrayType = np.ndarray + Returns: + The array module (numpy or cupy) corresponding to the input array type. - -def array_type(arr: ArrayType): - """Function to return the array type. It's just leveraging - cupy to do this if available, fallback is numpy. + Examples: + >>> import numpy as np + >>> arr = np.array([1, 2, 3]) + >>> xp = array_type(arr) + >>> result = xp.sum(arr) # Uses numpy.sum """ - if CUPY_AVAILABLE: - return cp.get_array_module(arr) - else: + try: + import cupy as cp + + return cp.get_array_module(array) + except ImportError: return np -def calculate_center_of_mass(stl_centers: ArrayType, stl_sizes: ArrayType) -> ArrayType: - """Function to calculate center of mass""" - xp = array_type(stl_centers) - stl_sizes = xp.expand_dims(stl_sizes, -1) - center_of_mass = xp.sum(stl_centers * stl_sizes, axis=0) / xp.sum(stl_sizes, axis=0) - return center_of_mass +def calculate_center_of_mass(centers: ArrayType, sizes: ArrayType) -> ArrayType: + """Calculate the center of mass for a collection of elements. + + Computes the volume-weighted centroid of mesh elements, commonly used + in computational fluid dynamics for mesh analysis and load balancing. + + Args: + centers: Array of shape (n_elements, 3) containing the centroid + coordinates of each element. + sizes: Array of shape (n_elements,) containing the volume + or area of each element used as weights. + + Returns: + Array of shape (1, 3) containing the x, y, z coordinates of the center of mass. + + Raises: + ValueError: If centers and sizes have incompatible shapes. + + Examples: + >>> import numpy as np + >>> centers = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]) + >>> sizes = np.array([1.0, 2.0, 3.0]) + >>> com = calculate_center_of_mass(centers, sizes) + >>> np.allclose(com, [[4.0/3.0, 4.0/3.0, 4.0/3.0]]) + True + """ + xp = array_type(centers) + + total_weighted_position = xp.einsum("i,ij->j", sizes, centers) + total_size = xp.sum(sizes) + + return total_weighted_position[None, ...] / total_size + +def normalize( + field: ArrayType, max_val: ArrayType | None = None, min_val: ArrayType | None = None +) -> ArrayType: + """Normalize field values to the range [-1, 1]. 
+
+    Applies min-max normalization to scale field values to a symmetric range
+    around zero. This is commonly used in machine learning preprocessing to
+    ensure numerical stability and faster convergence.
+
+    Args:
+        field: Input field array to be normalized.
+        max_val: Maximum values for normalization, can be scalar or array.
+            If None, computed from the field data.
+        min_val: Minimum values for normalization, can be scalar or array.
+            If None, computed from the field data.
+
+    Returns:
+        Normalized field with values in the range [-1, 1].
+
+    Note:
+        If max_val equals min_val (zero range), the division produces
+        inf/NaN entries; NumPy emits a runtime warning rather than raising.
+
+    Examples:
+        >>> import numpy as np
+        >>> field = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
+        >>> normalized = normalize(field, 5.0, 1.0)
+        >>> np.allclose(normalized, [-1.0, -0.5, 0.0, 0.5, 1.0])
+        True
+        >>> # Auto-compute min/max
+        >>> normalized_auto = normalize(field)
+        >>> np.allclose(normalized_auto, [-1.0, -0.5, 0.0, 0.5, 1.0])
+        True
+    """
+    xp = array_type(field)

-def normalize(field: ArrayType, mx: ArrayType, mn: ArrayType) -> ArrayType:
-    """Function to normalize fields"""
-    return 2.0 * (field - mn) / (mx - mn) - 1.0
+    if max_val is None:
+        max_val = xp.max(field, axis=0, keepdims=True)
+    if min_val is None:
+        min_val = xp.min(field, axis=0, keepdims=True)
+
+    field_range = max_val - min_val
+
+    return 2.0 * (field - min_val) / field_range - 1.0
+
+
+def unnormalize(
+    normalized_field: ArrayType, max_val: ArrayType, min_val: ArrayType
+) -> ArrayType:
+    """Reverse the normalization process to recover original field values.

-def unnormalize(field: ArrayType, mx: ArrayType, mn: ArrayType) -> ArrayType:
-    """Function to unnormalize fields"""
-    return (field + 1.0) * (mx - mn) * 0.5 + mn
+    Transforms normalized values from the range [-1, 1] back to their original
+    physical range using the stored min/max values.
+
+    Args:
+        normalized_field: Field values in the normalized range [-1, 1].
+        max_val: Maximum values used in the original normalization.
+        min_val: Minimum values used in the original normalization.
+
+    Returns:
+        Field values restored to their original physical range.
+
+    Examples:
+        >>> import numpy as np
+        >>> normalized = np.array([-1.0, -0.5, 0.0, 0.5, 1.0])
+        >>> original = unnormalize(normalized, 5.0, 1.0)
+        >>> np.allclose(original, [1.0, 2.0, 3.0, 4.0, 5.0])
+        True
+    """
+    field_range = max_val - min_val
+    return (normalized_field + 1.0) * field_range * 0.5 + min_val
+
+
+def standardize(
+    field: ArrayType, mean: ArrayType | None = None, std: ArrayType | None = None
+) -> ArrayType:
+    """Standardize field values to have zero mean and unit variance.
+
+    Applies z-score normalization to center the data around zero with
+    standard deviation of one. This is preferred over min-max normalization
+    when the data follows a normal distribution.
+
+    Args:
+        field: Input field array to be standardized.
+        mean: Mean values for standardization. If None, computed from field data.
+        std: Standard deviation values for standardization. If None, computed from field data.
+
+    Returns:
+        Standardized field with approximately zero mean and unit variance.
+
+    Note:
+        Zero entries in std produce inf/NaN values; NumPy emits a runtime
+        warning rather than raising.
+ + Examples: + >>> import numpy as np + >>> field = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + >>> standardized = standardize(field, 3.0, np.sqrt(2.5)) + >>> np.allclose(standardized, [-1.265, -0.632, 0.0, 0.632, 1.265], atol=1e-3) + True + >>> # Auto-compute mean/std + >>> standardized_auto = standardize(field) + >>> np.allclose(np.mean(standardized_auto), 0.0) + True + >>> np.allclose(np.std(standardized_auto, ddof=0), 1.0) + True + """ + xp = array_type(field) + + if mean is None: + mean = xp.mean(field, axis=0, keepdims=True) + if std is None: + std = xp.std(field, axis=0, keepdims=True) -def standardize(field: ArrayType, mean: ArrayType, std: ArrayType) -> ArrayType: - """Function to standardize fields""" return (field - mean) / std -def unstandardize(field: ArrayType, mean: ArrayType, std: ArrayType) -> ArrayType: - """Function to unstandardize fields""" - return field * std + mean +def unstandardize( + standardized_field: ArrayType, mean: ArrayType, std: ArrayType +) -> ArrayType: + """Reverse the standardization process to recover original field values. + + Transforms standardized values (zero mean, unit variance) back to their + original distribution using the stored mean and standard deviation. + + Args: + standardized_field: Field values with zero mean and unit variance. + mean: Mean values used in the original standardization. + std: Standard deviation values used in the original standardization. + + Returns: + Field values restored to their original distribution. + + Examples: + >>> import numpy as np + >>> standardized = np.array([-1.265, -0.632, 0.0, 0.632, 1.265]) + >>> original = unstandardize(standardized, 3.0, np.sqrt(2.5)) + >>> np.allclose(original, [1.0, 2.0, 3.0, 4.0, 5.0], atol=1e-3) + True + """ + return standardized_field * std + mean + + +def write_to_vtp(polydata: "vtk.vtkPolyData", filename: str) -> None: + """Write VTK polydata to a VTP (VTK PolyData) file format. + + VTP files are XML-based and store polygonal data including points, polygons, + and associated field data. This format is commonly used for surface meshes + in computational fluid dynamics visualization. + Args: + polydata: VTK polydata object containing mesh geometry and fields. + filename: Output filename with .vtp extension. Directory will be created + if it doesn't exist. + + Raises: + RuntimeError: If writing fails due to file permissions or disk space. + + """ + # Ensure output directory exists + output_path = Path(filename) + output_path.parent.mkdir(parents=True, exist_ok=True) -def write_to_vtp(polydata: Any, filename: str): - """Function to write polydata to vtp""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") writer = vtk.vtkXMLPolyDataWriter() - writer.SetFileName(filename) + writer.SetFileName(str(output_path)) writer.SetInputData(polydata) - writer.Write() + if not writer.Write(): + raise RuntimeError(f"Failed to write polydata to {output_path}") + + +def write_to_vtu(unstructured_grid: "vtk.vtkUnstructuredGrid", filename: str) -> None: + """Write VTK unstructured grid to a VTU (VTK Unstructured Grid) file format. + + VTU files store 3D volumetric meshes with arbitrary cell types including + tetrahedra, hexahedra, and pyramids. This format is essential for storing + finite element analysis results. + + Args: + unstructured_grid: VTK unstructured grid object containing volumetric mesh + geometry and field data. + filename: Output filename with .vtu extension. Directory will be created + if it doesn't exist. 
+ + Raises: + RuntimeError: If writing fails due to file permissions or disk space. + + """ + # Ensure output directory exists + output_path = Path(filename) + output_path.parent.mkdir(parents=True, exist_ok=True) -def write_to_vtu(polydata: Any, filename: str): - """Function to write polydata to vtu""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") writer = vtk.vtkXMLUnstructuredGridWriter() - writer.SetFileName(filename) - writer.SetInputData(polydata) - writer.Write() + writer.SetFileName(str(output_path)) + writer.SetInputData(unstructured_grid) + + if not writer.Write(): + raise RuntimeError(f"Failed to write unstructured grid to {output_path}") + +def extract_surface_triangles(tetrahedral_mesh: "vtk.vtkUnstructuredGrid") -> list[int]: + """Extract surface triangle indices from a tetrahedral mesh. -def extract_surface_triangles(tet_mesh: Any) -> List[int]: - """Extracts the surface triangles from a triangular mesh.""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - if not PV_AVAILABLE: - raise ImportError("PyVista is not installed. This function cannot be used.") + This function identifies the boundary faces of a 3D tetrahedral mesh and + returns the vertex indices that form triangular faces on the surface. + This is essential for visualization and boundary condition application. + + Args: + tetrahedral_mesh: VTK unstructured grid containing tetrahedral elements. + + Returns: + List of vertex indices forming surface triangles. Every three consecutive + indices define one triangle. + + Raises: + NotImplementedError: If the surface contains non-triangular faces. + + """ + # Extract the surface using VTK filter surface_filter = vtk.vtkDataSetSurfaceFilter() - surface_filter.SetInputData(tet_mesh) + surface_filter.SetInputData(tetrahedral_mesh) surface_filter.Update() + # Wrap with PyVista for easier manipulation + import pyvista as pv + surface_mesh = pv.wrap(surface_filter.GetOutput()) triangle_indices = [] + + # Process faces - PyVista stores faces as [n_vertices, v1, v2, ..., vn] faces = surface_mesh.faces.reshape((-1, 4)) for face in faces: - if face[0] == 3: + if face[0] == 3: # Triangle (3 vertices) triangle_indices.extend([face[1], face[2], face[3]]) else: - raise ValueError("Face is not a triangle") + raise NotImplementedError( + f"Non-triangular face found with {face[0]} vertices" + ) return triangle_indices -def convert_to_tet_mesh(polydata: Any) -> Any: - """Function to convert tet to stl""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - # Create a VTK DataSetTriangleFilter object - tet_filter = vtkDataSetTriangleFilter() - tet_filter.SetInputData(polydata) - tet_filter.Update() # Update to apply the filter - - # Get the output as an UnstructuredGrid - # tet_mesh = pv.wrap(tet_filter.GetOutput()) - tet_mesh = tet_filter.GetOutput() - return tet_mesh - - -def get_node_to_elem(polydata: Any) -> Any: - """Function to convert node to elem""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - c2p = vtk.vtkPointDataToCellData() - c2p.SetInputData(polydata) - c2p.Update() - cell_data = c2p.GetOutput() +def convert_to_tet_mesh(polydata: "vtk.vtkPolyData") -> "vtk.vtkUnstructuredGrid": + """Convert surface polydata to a tetrahedral volumetric mesh. 
+ + This function performs tetrahedralization of a surface mesh, creating + a 3D volumetric mesh suitable for finite element analysis. The process + fills the interior of the surface with tetrahedral elements. + + Args: + polydata: VTK polydata representing a closed surface mesh. + + Returns: + VTK unstructured grid containing tetrahedral elements filling the + volume enclosed by the input surface. + + Raises: + RuntimeError: If tetrahedralization fails (e.g., non-manifold surface). + + """ + tetrahedral_filter = vtkDataSetTriangleFilter() + tetrahedral_filter.SetInputData(polydata) + tetrahedral_filter.Update() + + tetrahedral_mesh = tetrahedral_filter.GetOutput() + return tetrahedral_mesh + + +def convert_point_data_to_cell_data(input_data: "vtk.vtkDataSet") -> "vtk.vtkDataSet": + """Convert point-based field data to cell-based field data. + + This function transforms field variables defined at mesh vertices (nodes) + to values defined at cell centers. This conversion is often needed when + switching between different numerical methods or visualization requirements. + + Args: + input_data: VTK dataset with point data to be converted. + + Returns: + VTK dataset with the same geometry but field data moved from points to cells. + Values are typically averaged from the surrounding points. + + """ + point_to_cell_filter = vtk.vtkPointDataToCellData() + point_to_cell_filter.SetInputData(input_data) + point_to_cell_filter.Update() + + return point_to_cell_filter.GetOutput() + + +def get_node_to_elem(polydata: "vtk.vtkDataSet") -> "vtk.vtkDataSet": + """Convert point data to cell data for VTK dataset. + + This function transforms field variables defined at mesh vertices to + values defined at cell centers using VTK's built-in conversion filter. + + Args: + polydata: VTK dataset with point data to be converted. + + Returns: + VTK dataset with field data moved from points to cells. + + """ + point_to_cell_filter = vtk.vtkPointDataToCellData() + point_to_cell_filter.SetInputData(polydata) + point_to_cell_filter.Update() + cell_data = point_to_cell_filter.GetOutput() return cell_data -def get_fields_from_cell(ptdata, var_list): - """Function to get fields from elem""" - fields = [] - for var in var_list: - variable = ptdata.GetArray(var) - num_tuples = variable.GetNumberOfTuples() - cell_fields = [] - for j in range(num_tuples): - variable_value = np.array(variable.GetTuple(j)) - cell_fields.append(variable_value) - cell_fields = np.asarray(cell_fields) - fields.append(cell_fields) - fields = np.transpose(np.asarray(fields), (1, 0)) - - return fields - - -def get_fields(data, variables): - """Function to get fields from VTP/VTU""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - fields = [] - for array_name in variables: +def get_fields_from_cell( + cell_data: "vtk.vtkCellData", variable_names: list[str] +) -> np.ndarray: + """Extract field variables from VTK cell data. + + This function extracts multiple field variables from VTK cell data and + organizes them into a structured NumPy array. Each variable becomes a + column in the output array. + + Args: + cell_data: VTK cell data object containing field variables. + variable_names: List of variable names to extract from the cell data. + + Returns: + NumPy array of shape (n_cells, n_variables) containing the extracted + field data. Variables are ordered according to the input list. + + Raises: + ValueError: If a requested variable name is not found in the cell data. 
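+
+    Examples:
+        >>> # A minimal sketch: build one scalar cell array and read it back.
+        >>> import numpy as np
+        >>> import vtk
+        >>> from vtk.util import numpy_support
+        >>> pressure = numpy_support.numpy_to_vtk(np.array([1.0, 2.0, 3.0]), deep=True)
+        >>> pressure.SetName("pressure")
+        >>> poly = vtk.vtkPolyData()
+        >>> _ = poly.GetCellData().AddArray(pressure)
+        >>> get_fields_from_cell(poly.GetCellData(), ["pressure"]).shape
+        (3, 1)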
+
+    """
+    extracted_fields = []
+    for variable_name in variable_names:
+        variable_array = cell_data.GetArray(variable_name)
+        if variable_array is None:
+            raise ValueError(f"Variable '{variable_name}' not found in cell data")
+
+        num_tuples = variable_array.GetNumberOfTuples()
+        field_values = []
+        for tuple_idx in range(num_tuples):
+            variable_value = np.array(variable_array.GetTuple(tuple_idx))
+            field_values.append(variable_value)
+        field_values = np.asarray(field_values)
+        extracted_fields.append(field_values)
+
+    # Stack per-variable columns to get shape (n_cells, n_variables); each
+    # (n_cells, n_components) block is flattened so a scalar variable
+    # contributes exactly one column.
+    extracted_fields = np.concatenate(
+        [values.reshape(values.shape[0], -1) for values in extracted_fields],
+        axis=-1,
+    )
+    return extracted_fields
+
+
+def get_fields(
+    data_attributes: "vtk.vtkDataSetAttributes", variable_names: list[str]
+) -> list[np.ndarray]:
+    """Extract multiple field variables from VTK data attributes.
+
+    This function extracts field variables from VTK data attributes (either
+    point data or cell data) and returns them as a list of NumPy arrays.
+    It handles both point and cell data seamlessly.
+
+    Args:
+        data_attributes: VTK data attributes object (point data or cell data).
+        variable_names: List of variable names to extract.
+
+    Returns:
+        List of NumPy arrays, one for each requested variable. Each array
+        has shape (n_points/n_cells, n_components) where n_components
+        depends on the variable (1 for scalars, 3 for vectors, etc.).
+
+    Raises:
+        ValueError: If a requested variable is not found in the data attributes.
+
+    """
+    extracted_fields = []
+    for variable_name in variable_names:
         try:
-            array = data.GetArray(array_name)
-        except ValueError:
+            vtk_array = data_attributes.GetArray(variable_name)
+        except ValueError as e:
             raise ValueError(
-                f"Failed to get array {array_name} from the unstructured grid."
+                f"Failed to get array '{variable_name}' from the data attributes: {e}"
             )
+        # GetArray returns None (without raising) when the name is missing.
+        if vtk_array is None:
+            raise ValueError(
+                f"Array '{variable_name}' was not found in the data attributes."
+            )

-        array_data = numpy_support.vtk_to_numpy(array).reshape(
-            array.GetNumberOfTuples(), array.GetNumberOfComponents()
-        )
-        fields.append(array_data)
-    return fields
+        # Convert VTK array to NumPy array with proper shape
+        numpy_array = numpy_support.vtk_to_numpy(vtk_array).reshape(
+            vtk_array.GetNumberOfTuples(), vtk_array.GetNumberOfComponents()
+        )
+        extracted_fields.append(numpy_array)
+    return extracted_fields

-def get_vertices(polydata):
-    """Function to get vertices"""
-    if not VTK_AVAILABLE:
-        raise ImportError("VTK or is not installed. This function cannot be used.")
-    points = polydata.GetPoints()
-    vertices = numpy_support.vtk_to_numpy(points.GetData())
+
+def get_vertices(polydata: "vtk.vtkPolyData") -> np.ndarray:
+    """Extract vertex coordinates from VTK polydata object.
+
+    This function converts VTK polydata to a NumPy array containing the 3D
+    coordinates of all vertices in the mesh.
+
+    Args:
+        polydata: VTK polydata object containing mesh geometry.
+
+    Returns:
+        NumPy array of shape (n_points, 3) containing [x, y, z] coordinates
+        for each vertex.
+
+    """
+    vtk_points = polydata.GetPoints()
+    vertices = numpy_support.vtk_to_numpy(vtk_points.GetData())
     return vertices


-def get_volume_data(polydata, variables):
-    """Function to get volume data"""
+def get_volume_data(
+    polydata: "vtk.vtkPolyData", variable_names: list[str]
+) -> tuple[np.ndarray, list[np.ndarray]]:
+    """Extract vertices and field data from 3D volumetric mesh.
+
+    This function extracts both geometric information (vertex coordinates)
+    and field data from a 3D volumetric mesh. It's commonly used for
+    processing finite element analysis results.
+ + Args: + polydata: VTK polydata representing a 3D volumetric mesh. + variable_names: List of field variable names to extract. + + Returns: + Tuple containing: + - Vertex coordinates as NumPy array of shape (n_vertices, 3) + - List of field arrays, one per variable + + """ vertices = get_vertices(polydata) point_data = polydata.GetPointData() - - fields = get_fields(point_data, variables) + fields = get_fields(point_data, variable_names) return vertices, fields -def get_surface_data(polydata, variables): - """Function to get surface data""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") +def get_surface_data( + polydata: "vtk.vtkPolyData", variable_names: list[str] +) -> tuple[np.ndarray, list[np.ndarray], list[tuple[int, int]]]: + """Extract surface mesh data including vertices, fields, and edge connectivity. + + This function extracts comprehensive surface mesh information including + vertex coordinates, field data at vertices, and edge connectivity information. + It's commonly used for processing CFD surface results and boundary conditions. + + Args: + polydata: VTK polydata representing a surface mesh. + variable_names: List of field variable names to extract from the mesh. + + Returns: + Tuple containing: + - Vertex coordinates as NumPy array of shape (n_vertices, 3) + - List of field arrays, one per variable + - List of edge tuples representing mesh connectivity + + Raises: + ValueError: If a requested variable is not found or polygon data is missing. + + """ points = polydata.GetPoints() vertices = np.array([points.GetPoint(i) for i in range(points.GetNumberOfPoints())]) point_data = polydata.GetPointData() fields = [] - for array_name in variables: + for array_name in variable_names: try: array = point_data.GetArray(array_name) except ValueError: @@ -259,14 +595,43 @@ def get_surface_data(polydata, variables): def calculate_normal_positional_encoding( coordinates_a: ArrayType, - coordinates_b: Optional[ArrayType] = None, - cell_length: Sequence[float] = [], + coordinates_b: ArrayType | None = None, + cell_dimensions: Sequence[float] = (1.0, 1.0, 1.0), ) -> ArrayType: - """Function to get normal positional encoding""" - dx = cell_length[0] - dy = cell_length[1] - dz = cell_length[2] + """Calculate sinusoidal positional encoding for 3D coordinates. + + This function computes transformer-style positional encodings for 3D spatial + coordinates, enabling neural networks to understand spatial relationships. + The encoding uses sinusoidal functions at different frequencies to create + unique representations for each spatial position. + + Args: + coordinates_a: Primary coordinates array of shape (n_points, 3). + coordinates_b: Optional secondary coordinates for computing relative positions. + If provided, the encoding is computed for (coordinates_a - coordinates_b). + cell_dimensions: Characteristic length scales for x, y, z dimensions used + for normalization. Defaults to unit dimensions. + + Returns: + Array of shape (n_points, 12) containing positional encodings with + 4 encoding dimensions per spatial axis (x, y, z). 
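+
+    Note:
+        Each axis is scaled by its cell dimension and expanded with
+        calculate_pos_encoding(..., d=4), i.e. into the pairs
+        sin(x / 10000**(2 * k / 4)) and cos(x / 10000**(2 * k / 4)) for
+        k = 0, 1, giving four values per axis and twelve in total.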
+
+    Examples:
+        >>> import numpy as np
+        >>> coords = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
+        >>> cell_size = [0.1, 0.1, 0.1]
+        >>> encoding = calculate_normal_positional_encoding(coords, cell_dimensions=cell_size)
+        >>> encoding.shape
+        (2, 12)
+        >>> # Relative positioning example
+        >>> coords_b = np.array([[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]])
+        >>> encoding_rel = calculate_normal_positional_encoding(coords, coords_b, cell_size)
+        >>> encoding_rel.shape
+        (2, 12)
+    """
+    dx, dy, dz = cell_dimensions[0], cell_dimensions[1], cell_dimensions[2]
     xp = array_type(coordinates_a)
+
     if coordinates_b is not None:
         normals = coordinates_a - coordinates_b
         pos_x = xp.asarray(calculate_pos_encoding(normals[:, 0] / dx, d=4))
@@ -283,32 +648,122 @@ def calculate_normal_positional_encoding(
     return pos_normals


-def nd_interpolator(coodinates, field, grid):
-    """Function to for nd interpolation"""
+def nd_interpolator(
+    coordinates: Sequence[ArrayType], field: ArrayType, grid: ArrayType, k: int = 2
+) -> ArrayType:
+    """Perform n-dimensional interpolation using k-nearest neighbors.
+
+    This function interpolates field values from scattered points to a regular
+    grid using k-nearest neighbor averaging. It's useful for reconstructing
+    fields on regular grids from irregular measurement points.
+
+    Args:
+        coordinates: Sequence whose first element is an array of shape
+            (n_points, n_dims) containing source point coordinates; the
+            function indexes coordinates[0], as in the example below.
+        field: Array of shape (n_points, n_fields) containing field values at source points.
+        grid: Array of shape (n_field_points, n_dims) containing target grid points for interpolation.
+        k: Number of nearest neighbors to use for interpolation.
+
+    Returns:
+        Interpolated field values at grid points using k-nearest neighbor averaging.
+
+    Note:
+        This function currently uses SciPy's KDTree which only supports CPU arrays.
+        A future enhancement could add CuML support for GPU acceleration.
+
+    Examples:
+        >>> import numpy as np
+        >>> # Simple 2D interpolation example
+        >>> coords = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
+        >>> field_vals = np.array([[1.0], [2.0], [3.0], [4.0]])
+        >>> grid_points = np.array([[0.5, 0.5]])
+        >>> result = nd_interpolator([coords], field_vals, grid_points)
+        >>> result.shape[0] == 1  # One grid point
+        True
+    """
     # TODO - this function should get updated for cuml if using cupy.
-    interp_func = KDTree(coodinates[0])
-    dd, ii = interp_func.query(grid, k=2)
+    kdtree = KDTree(coordinates[0])
+    distances, neighbor_indices = kdtree.query(grid, k=k)

-    field_grid = field[ii]
-    field_grid = np.float32(np.mean(field_grid, (3)))
+    field_grid = field[neighbor_indices]
+    field_grid = np.mean(field_grid, axis=1)
     return field_grid


-def pad(arr: ArrayType, npoin: int, pad_value: float = 0.0) -> ArrayType:
-    """Function for padding"""
+def pad(arr: ArrayType, n_points: int, pad_value: float = 0.0) -> ArrayType:
+    """Pad 2D array with constant values to reach target size.
+
+    This function extends a 2D array by adding rows filled with a constant
+    value. It's commonly used to standardize array sizes in batch processing
+    for machine learning applications.
+
+    Args:
+        arr: Input array of shape (n_points, n_features) to be padded.
+        n_points: Target number of points (rows) after padding.
+        pad_value: Constant value used for padding. Defaults to 0.0.
+
+    Returns:
+        Padded array of shape (n_points, n_features). If n_points <= arr.shape[0],
+        returns the original array unchanged.
+ + Examples: + >>> import numpy as np + >>> arr = np.array([[1.0, 2.0], [3.0, 4.0]]) + >>> padded = pad(arr, 4, -1.0) + >>> padded.shape + (4, 2) + >>> np.array_equal(padded[:2], arr) + True + >>> bool(np.all(padded[2:] == -1.0)) + True + >>> # No padding needed + >>> same = pad(arr, 2) + >>> np.array_equal(same, arr) + True + """ xp = array_type(arr) + if n_points <= arr.shape[0]: + return arr + arr_pad = pad_value * xp.ones( - (npoin - arr.shape[0], arr.shape[1]), dtype=xp.float32 + (n_points - arr.shape[0], arr.shape[1]), dtype=xp.float32 ) arr_padded = xp.concatenate((arr, arr_pad), axis=0) return arr_padded -def pad_inp(arr: ArrayType, npoin: int, pad_value: float = 0.0) -> ArrayType: - """Function for padding arrays""" +def pad_inp(arr: ArrayType, n_points: int, pad_value: float = 0.0) -> ArrayType: + """Pad 3D array with constant values to reach target size. + + This function extends a 3D array by adding entries along the first dimension + filled with a constant value. Used for standardizing 3D tensor sizes in + batch processing workflows. + + Args: + arr: Input array of shape (n_points, height, width) to be padded. + n_points: Target number of points along first dimension after padding. + pad_value: Constant value used for padding. Defaults to 0.0. + + Returns: + Padded array of shape (n_points, height, width). If n_points <= arr.shape[0], + returns the original array unchanged. + + Examples: + >>> import numpy as np + >>> arr = np.array([[[1.0, 2.0]], [[3.0, 4.0]]]) + >>> padded = pad_inp(arr, 4, 0.0) + >>> padded.shape + (4, 1, 2) + >>> np.array_equal(padded[:2], arr) + True + >>> bool(np.all(padded[2:] == 0.0)) + True + """ xp = array_type(arr) + if n_points <= arr.shape[0]: + return arr + arr_pad = pad_value * xp.ones( - (npoin - arr.shape[0], arr.shape[1], arr.shape[2]), dtype=xp.float32 + (n_points - arr.shape[0], arr.shape[1], arr.shape[2]), dtype=xp.float32 ) arr_padded = xp.concatenate((arr, arr_pad), axis=0) return arr_padded @@ -317,82 +772,231 @@ def pad_inp(arr: ArrayType, npoin: int, pad_value: float = 0.0) -> ArrayType: @profile def shuffle_array( arr: ArrayType, - npoin: int, -) -> Tuple[ArrayType, ArrayType]: - """Function for shuffling arrays""" + n_points: int, +) -> tuple[ArrayType, ArrayType]: + """Randomly sample points from array without replacement. + + This function performs random sampling from the input array, selecting + n_points points without replacement. It's commonly used for creating training + subsets and data augmentation in machine learning workflows. + + Args: + arr: Input array to sample from, shape (n_points, ...). + n_points: Number of points to sample. If greater than arr.shape[0], + all points are returned. + + Returns: + Tuple containing: + - Sampled array subset + - Indices of the selected points + + Examples: + >>> import numpy as np + >>> np.random.seed(42) # For reproducible results + >>> data = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + >>> subset, indices = shuffle_array(data, 2) + >>> subset.shape + (2, 2) + >>> indices.shape + (2,) + >>> len(np.unique(indices)) == 2 # No duplicates + True + """ xp = array_type(arr) - if npoin > arr.shape[0]: + if n_points > arr.shape[0]: # If asking too many points, truncate the ask but still shuffle. 
- npoin = arr.shape[0] - idx = xp.random.choice(arr.shape[0], size=npoin, replace=False) + n_points = arr.shape[0] + idx = xp.random.choice(arr.shape[0], size=n_points, replace=False) return arr[idx], idx -def shuffle_array_without_sampling(arr: ArrayType) -> Tuple[ArrayType, ArrayType]: - """Function for shuffline arrays without sampling.""" +def shuffle_array_without_sampling(arr: ArrayType) -> tuple[ArrayType, ArrayType]: + """Shuffle array order without changing the number of elements. + + This function reorders all elements in the array randomly while preserving + all data points. It's useful for randomizing data order before training + while maintaining the complete dataset. + + Args: + arr: Input array to shuffle, shape (n_points, ...). + + Returns: + Tuple containing: + - Shuffled array with same shape as input + - Permutation indices used for shuffling + + Examples: + >>> import numpy as np + >>> np.random.seed(42) # For reproducible results + >>> data = np.array([[1], [2], [3], [4]]) + >>> shuffled, indices = shuffle_array_without_sampling(data) + >>> shuffled.shape + (4, 1) + >>> indices.shape + (4,) + >>> set(indices) == set(range(4)) # All original indices present + True + """ xp = array_type(arr) idx = xp.arange(arr.shape[0]) xp.random.shuffle(idx) return arr[idx], idx -def create_directory(filepath: str) -> None: - """Function to create directories""" - if not os.path.exists(filepath): - os.makedirs(filepath) - - -def get_filenames(filepath: str, exclude_dirs: bool = False) -> List[str]: - """Function to get filenames from a directory""" - if os.path.exists(filepath): - filenames = [] - for item in os.listdir(filepath): - item_path = os.path.join(filepath, item) - if exclude_dirs and os.path.isdir(item_path): - # Include directories ending with .zarr even when exclude_dirs is True - if item.endswith(".zarr"): - filenames.append(item) - continue - filenames.append(item) - return filenames - else: - FileNotFoundError() +def create_directory(filepath: str | Path) -> None: + """Create directory and all necessary parent directories. + + This function creates a directory at the specified path, including any + necessary parent directories. It's equivalent to `mkdir -p` in Unix systems. + + Args: + filepath: Path to the directory to create. Can be string or Path object. + + """ + Path(filepath).mkdir(parents=True, exist_ok=True) -def calculate_pos_encoding(nx: ArrayType, d: int = 8) -> ArrayType: - """Function for calculating positional encoding""" +def get_filenames(filepath: str | Path, exclude_dirs: bool = False) -> list[str]: + """Get list of filenames in a directory with optional directory filtering. + + This function returns all items in a directory, with options to exclude + subdirectories. It handles special cases like .zarr directories which + are treated as files even when exclude_dirs is True. + + Args: + filepath: Path to the directory to list. Can be string or Path object. + exclude_dirs: If True, exclude subdirectories from results. + Exception: .zarr directories are always included as they represent + data files in array storage format. + + Returns: + List of filenames/directory names found in the specified directory. + + Raises: + FileNotFoundError: If the specified directory does not exist. 
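+
+    Examples:
+        >>> # A minimal sketch using a throwaway temporary directory.
+        >>> import tempfile
+        >>> from pathlib import Path
+        >>> tmp_dir = tempfile.mkdtemp()
+        >>> _ = (Path(tmp_dir) / "mesh_000.vtp").write_text("")
+        >>> get_filenames(tmp_dir)
+        ['mesh_000.vtp']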
+ + """ + path = Path(filepath) + if not path.exists(): + raise FileNotFoundError(f"Directory {filepath} does not exist") + + filenames = [] + for item in path.iterdir(): + if exclude_dirs and item.is_dir(): + # Include directories ending with .zarr even when exclude_dirs is True + if item.name.endswith(".zarr"): + filenames.append(item.name) + continue + filenames.append(item.name) + return filenames + + +def calculate_pos_encoding(nx: ArrayType, d: int = 8) -> list[ArrayType]: + """Calculate sinusoidal positional encoding for transformer architectures. + + This function computes positional encodings using alternating sine and cosine + functions at different frequencies. These encodings help neural networks + understand positional relationships in sequences or spatial data. + + Args: + nx: Input positions/coordinates to encode. + d: Encoding dimensionality. Must be even number. Defaults to 8. + + Returns: + List of d arrays containing alternating sine and cosine encodings. + Each pair (sin, cos) uses progressively lower frequencies. + + Examples: + >>> import numpy as np + >>> positions = np.array([0.0, 1.0, 2.0]) + >>> encodings = calculate_pos_encoding(positions, d=4) + >>> len(encodings) + 4 + >>> all(enc.shape == (3,) for enc in encodings) + True + """ vec = [] xp = array_type(nx) for k in range(int(d / 2)): - vec.append(xp.sin(nx / 10000 ** (2 * (k) / d))) - vec.append(xp.cos(nx / 10000 ** (2 * (k) / d))) + vec.append(xp.sin(nx / 10000 ** (2 * k / d))) + vec.append(xp.cos(nx / 10000 ** (2 * k / d))) return vec -def combine_dict(old_dict, new_dict): - """Function to combine dictionaries""" - for j in old_dict.keys(): - old_dict[j] += new_dict[j] - return old_dict +def combine_dict(old_dict: dict[Any, Any], new_dict: dict[Any, Any]) -> dict[Any, Any]: + """Combine two dictionaries by adding values for matching keys. + This function performs element-wise addition of dictionary values for + keys that exist in both dictionaries. It's commonly used for accumulating + statistics or metrics across multiple iterations. -def merge(*lists): - """Function to merge lists""" - newlist = lists[:] - for x in lists: - if x not in newlist: - newlist.extend(x) - return newlist + Args: + old_dict: Base dictionary to update. + new_dict: Dictionary with values to add to old_dict. + Returns: + Updated old_dict with combined values. -def create_grid(mx: ArrayType, mn: ArrayType, nres: ArrayType) -> ArrayType: - """Function to create grid""" + Note: + This function modifies old_dict in place and returns it. + Values must support the + operator. - xp = array_type(mx) + Examples: + >>> stats1 = {"loss": 0.5, "accuracy": 0.8} + >>> stats2 = {"loss": 0.3, "accuracy": 0.1} + >>> combined = combine_dict(stats1, stats2) + >>> combined["loss"] + 0.8 + >>> combined["accuracy"] + 0.9 + """ + for key in old_dict.keys(): + old_dict[key] += new_dict[key] + return old_dict - dx = xp.linspace(mn[0], mx[0], nres[0], dtype=mx.dtype) - dy = xp.linspace(mn[1], mx[1], nres[1], dtype=mx.dtype) - dz = xp.linspace(mn[2], mx[2], nres[2], dtype=mx.dtype) + +def create_grid( + max_coords: ArrayType, min_coords: ArrayType, resolution: ArrayType +) -> ArrayType: + """Create a 3D regular grid from coordinate bounds and resolution. + + This function generates a regular 3D grid spanning from min_coords to + max_coords with the specified resolution in each dimension. The resulting + grid is commonly used for interpolation, visualization, and regular sampling. 
+
+    Args:
+        max_coords: Maximum coordinates [x_max, y_max, z_max] for the grid bounds.
+        min_coords: Minimum coordinates [x_min, y_min, z_min] for the grid bounds.
+        resolution: Number of grid points [nx, ny, nz] in each dimension.
+
+    Returns:
+        Grid array of shape (nx, ny, nz, 3) containing 3D coordinates for each
+        grid point. The last dimension contains [x, y, z] coordinates.
+
+    Note:
+        meshgrid is called with its default "xy" indexing, so when nx != ny
+        the leading two axes of the result follow (ny, nx, nz) order rather
+        than the order given in resolution.
+
+    Examples:
+        >>> import numpy as np
+        >>> min_bounds = np.array([0.0, 0.0, 0.0])
+        >>> max_bounds = np.array([1.0, 1.0, 1.0])
+        >>> grid_res = np.array([2, 2, 2])
+        >>> grid = create_grid(max_bounds, min_bounds, grid_res)
+        >>> grid.shape
+        (2, 2, 2, 3)
+        >>> np.allclose(grid[0, 0, 0], [0.0, 0.0, 0.0])
+        True
+        >>> np.allclose(grid[1, 1, 1], [1.0, 1.0, 1.0])
+        True
+    """
+    xp = array_type(max_coords)
+
+    dx = xp.linspace(
+        min_coords[0], max_coords[0], resolution[0], dtype=max_coords.dtype
+    )
+    dy = xp.linspace(
+        min_coords[1], max_coords[1], resolution[1], dtype=max_coords.dtype
+    )
+    dz = xp.linspace(
+        min_coords[2], max_coords[2], resolution[2], dtype=max_coords.dtype
+    )

     xv, yv, zv = xp.meshgrid(dx, dy, dz)
     xv = xp.expand_dims(xv, -1)
@@ -406,10 +1010,34 @@ def mean_std_sampling(
     field: ArrayType, mean: ArrayType, std: ArrayType, tolerance: float = 3.0
-) -> ArrayType:
-    """Function for mean/std based sampling"""
+) -> list[int]:
+    """Identify outlier points based on statistical distance from mean.
+
+    This function identifies data points that are statistical outliers by
+    checking if they fall outside mean ± tolerance*std for any field component.
+    It's useful for data cleaning and identifying regions of interest in CFD data.
+
+    Args:
+        field: Input field array of shape (n_points, n_components).
+        mean: Mean values for each field component, shape (n_components,).
+        std: Standard deviation for each component, shape (n_components,).
+        tolerance: Number of standard deviations to use as outlier threshold.
+            Defaults to 3.0 (99.7% of normal distribution).
+
+    Returns:
+        List of indices identifying outlier points that exceed the statistical threshold.
+
+    Examples:
+        >>> import numpy as np
+        >>> # Create test data with outliers
+        >>> field = np.array([[1.0], [2.0], [3.0], [10.0]])  # 10.0 is outlier
+        >>> field_mean = np.array([2.0])
+        >>> field_std = np.array([1.0])
+        >>> outliers = mean_std_sampling(field, field_mean, field_std, 2.0)
+        >>> 3 in outliers  # Index 3 (value 10.0) should be detected as outlier
+        True
+    """
     xp = array_type(field)
-
     idx_all = []
     for v in range(field.shape[-1]):
         fv = field[:, v]
@@ -422,8 +1050,33 @@
-def dict_to_device(state_dict, device, exclude_keys=["filename"]):
-    """Function to load dictionary to device"""
+def dict_to_device(
+    state_dict: dict[str, Any], device: Any, exclude_keys: list[str] | None = None
+) -> dict[str, Any]:
+    """Move dictionary values to specified device (GPU/CPU).
+
+    This function transfers PyTorch tensors in a dictionary to the specified
+    compute device while preserving the dictionary structure. It's commonly
+    used for moving model parameters and data between CPU and GPU.
+
+    Args:
+        state_dict: Dictionary containing tensors and other values.
+        device: Target device (e.g., torch.device('cuda:0')).
+        exclude_keys: List of keys to skip during device transfer.
+            Defaults to ["filename"] if None.
+
+    Returns:
+        New dictionary with tensors moved to the specified device.
+        Entries whose keys are listed in exclude_keys are skipped and
+        omitted from the returned dictionary.
+
+    Examples:
+        >>> import torch
+        >>> data = {"weights": torch.randn(10, 10), "filename": "model.pt"}
+        >>> gpu_data = dict_to_device(data, torch.device('cuda:0'))  # doctest: +SKIP
+    """
+    if exclude_keys is None:
+        exclude_keys = ["filename"]
+
     new_state_dict = {}
     for k, v in state_dict.items():
         if k not in exclude_keys:
@@ -432,33 +1085,64 @@ def area_weighted_shuffle_array(
-    arr: ArrayType, npoin: int, area: ArrayType
-) -> Tuple[ArrayType, ArrayType]:
-    """Function for area weighted shuffling"""
+    arr: ArrayType, n_points: int, area: ArrayType, area_factor: float = 1.0
+) -> tuple[ArrayType, ArrayType]:
+    """Perform area-weighted random sampling from array.
+
+    This function samples points from an array with probability proportional to
+    their associated area weights. This is particularly useful in CFD applications
+    where larger cells or surface elements should have higher sampling probability.
+
+    Args:
+        arr: Input array to sample from, shape (n_points, ...).
+        n_points: Number of points to sample. If greater than arr.shape[0],
+            samples all available points.
+        area: Area weights for each point, shape (n_points,). Larger values
+            indicate higher sampling probability.
+        area_factor: Exponent applied to area weights to control sampling bias.
+            Values > 1.0 increase bias toward larger areas, values < 1.0 reduce bias.
+            Defaults to 1.0 (linear weighting).
+
+    Returns:
+        Tuple containing:
+        - Sampled array subset weighted by area
+        - Indices of the selected points
+
+    Note:
+        Sampling is done with replacement via xp.random.choice, so duplicate
+        indices are possible and, for CuPy inputs, the draw runs directly on
+        the GPU. The Alias method could be implemented as a future optimization.
+
+    Examples:
+        >>> import numpy as np
+        >>> np.random.seed(42)  # For reproducible results
+        >>> mesh_data = np.array([[1.0], [2.0], [3.0], [4.0]])
+        >>> cell_areas = np.array([0.1, 0.1, 0.1, 10.0])  # Last point has much larger area
+        >>> subset, indices = area_weighted_shuffle_array(mesh_data, 2, cell_areas)
+        >>> subset.shape
+        (2, 1)
+        >>> indices.shape
+        (2,)
+        >>> # The point with large area (index 3) should likely be selected
+        >>> len(set(indices)) <= 2  # At most 2 unique indices
+        True
+        >>> # Use higher area_factor for stronger bias toward large areas
+        >>> subset_biased, _ = area_weighted_shuffle_array(mesh_data, 2, cell_areas, area_factor=2.0)
+    """
     xp = array_type(arr)

-    # Compute the total_area:
-    factor = 1.0
-    total_area = xp.sum(area**factor)
-    probs = area**factor / total_area
+    # Calculate area-weighted probabilities
+    sampling_probabilities = area**area_factor
+    sampling_probabilities /= xp.sum(sampling_probabilities)  # Normalize to sum to 1

-    if npoin > arr.shape[0]:
-        npoin = arr.shape[0]
+    # Ensure we don't request more points than available
+    n_points = min(n_points, arr.shape[0])

-    idx = xp.arange(arr.shape[0])
+    # Create index array for all available points
+    point_indices = xp.arange(arr.shape[0])

-    # This is too memory intensive to run on the GPU.
-    if xp == cp:
-        idx = idx.get()
-        probs = probs.get()
-        # Under the hood, this has a search over the probabilities.
-        # It's very expensive in memory, as far as I can tell.
-        # In principle, we could use the Alias method to speed this up
-        # on the GPU but it's not yet a bottleneck.
-
-        ids = np.random.choice(idx, npoin, p=probs)
-        ids = xp.asarray(ids)
-    else:
-        # Chug along on the CPU:
-        ids = xp.random.choice(idx, npoin, p=probs)
+    selected_indices = xp.random.choice(
+        point_indices, size=n_points, p=sampling_probabilities
+    )

-    return arr[ids], ids
+    return arr[selected_indices], selected_indices
diff --git a/test/utils/test_domino_utils.py b/test/utils/test_domino_utils.py
new file mode 100644
index 0000000000..8a0e03637b
--- /dev/null
+++ b/test/utils/test_domino_utils.py
@@ -0,0 +1,219 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Test suite for domino utils module.
+
+This test file duplicates all the docstring examples from the domino utils
+module to ensure that the documented examples work correctly.
+"""
+
+import numpy as np
+
+from physicsnemo.utils.domino.utils import (
+    area_weighted_shuffle_array,
+    calculate_center_of_mass,
+    calculate_normal_positional_encoding,
+    calculate_pos_encoding,
+    combine_dict,
+    create_grid,
+    mean_std_sampling,
+    nd_interpolator,
+    normalize,
+    pad,
+    pad_inp,
+    shuffle_array,
+    shuffle_array_without_sampling,
+    standardize,
+    unnormalize,
+    unstandardize,
+)
+
+
+def test_calculate_center_of_mass():
+    """Test calculate_center_of_mass function with docstring example."""
+    centers = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])
+    sizes = np.array([1.0, 2.0, 3.0])
+    com = calculate_center_of_mass(centers, sizes)
+    expected = np.array([[4.0 / 3.0, 4.0 / 3.0, 4.0 / 3.0]])
+    assert np.allclose(com, expected)
+
+
+def test_normalize():
+    """Test normalize function with docstring examples."""
+    # Example 1: With explicit min/max
+    field = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
+    normalized = normalize(field, 5.0, 1.0)
+    expected = np.array([-1.0, -0.5, 0.0, 0.5, 1.0])
+    assert np.allclose(normalized, expected)
+
+    # Example 2: Auto-compute min/max
+    normalized_auto = normalize(field)
+    expected_auto = np.array([-1.0, -0.5, 0.0, 0.5, 1.0])
+    assert np.allclose(normalized_auto, expected_auto)
+
+
+def test_unnormalize():
+    """Test unnormalize function with docstring example."""
+    normalized = np.array([-1.0, -0.5, 0.0, 0.5, 1.0])
+    original = unnormalize(normalized, 5.0, 1.0)
+    expected = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
+    assert np.allclose(original, expected)
+
+
+def test_standardize():
+    """Test standardize function with docstring examples."""
+    # Example 1: With explicit mean/std
+    field = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
+    standardized = standardize(field, 3.0, np.sqrt(2.5))
+    expected = np.array([-1.265, -0.632, 0.0, 0.632, 1.265])
+    assert np.allclose(standardized, expected, atol=1e-3)
+
+    # Example 2: Auto-compute mean/std
+    standardized_auto = standardize(field)
+    assert np.allclose(np.mean(standardized_auto), 0.0)
+    assert np.allclose(np.std(standardized_auto, ddof=0), 1.0)
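+
+
+def test_normalize_unnormalize_roundtrip():
+    """Round-trip check (not from a docstring): unnormalize inverts normalize.
+
+    A minimal sketch with illustrative values; the per-column min/max mirror
+    how the normalization statistics are documented to be computed.
+    """
+    field = np.array([[0.5, -2.0], [1.5, 3.0], [-0.5, 0.0]])
+    max_val = field.max(axis=0, keepdims=True)
+    min_val = field.min(axis=0, keepdims=True)
+    recovered = unnormalize(normalize(field, max_val, min_val), max_val, min_val)
+    assert np.allclose(recovered, field)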
+
+
+def test_unstandardize():
+    """Test unstandardize function with docstring example."""
+    standardized = np.array([-1.265, -0.632, 0.0, 0.632, 1.265])
+    original = unstandardize(standardized, 3.0, np.sqrt(2.5))
+    expected = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
+    assert np.allclose(original, expected, atol=1e-3)
+
+
+def test_calculate_normal_positional_encoding():
+    """Test calculate_normal_positional_encoding function with docstring examples."""
+    # Example 1: Basic coordinates
+    coords = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
+    cell_size = [0.1, 0.1, 0.1]
+    encoding = calculate_normal_positional_encoding(coords, cell_dimensions=cell_size)
+    assert encoding.shape == (2, 12)
+
+    # Example 2: Relative positioning
+    coords_b = np.array([[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]])
+    encoding_rel = calculate_normal_positional_encoding(coords, coords_b, cell_size)
+    assert encoding_rel.shape == (2, 12)
+
+
+def test_nd_interpolator():
+    """Test nd_interpolator function with docstring example."""
+    # Simple 2D interpolation example
+    coords = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
+    field_vals = np.array([[1.0], [2.0], [3.0], [4.0]])
+    grid_points = np.array([[0.5, 0.5]])
+    result = nd_interpolator([coords], field_vals, grid_points)
+    assert result.shape[0] == 1  # One grid point
+
+
+def test_pad():
+    """Test pad function with docstring examples."""
+    # Example 1: Padding needed
+    arr = np.array([[1.0, 2.0], [3.0, 4.0]])
+    padded = pad(arr, 4, -1.0)
+    assert padded.shape == (4, 2)
+    assert np.array_equal(padded[:2], arr)
+    assert bool(np.all(padded[2:] == -1.0))
+
+    # Example 2: No padding needed
+    same = pad(arr, 2)
+    assert np.array_equal(same, arr)
+
+
+def test_pad_inp():
+    """Test pad_inp function with docstring example."""
+    arr = np.array([[[1.0, 2.0]], [[3.0, 4.0]]])
+    padded = pad_inp(arr, 4, 0.0)
+    assert padded.shape == (4, 1, 2)
+    assert np.array_equal(padded[:2], arr)
+    assert bool(np.all(padded[2:] == 0.0))
+
+
+def test_shuffle_array():
+    """Test shuffle_array function with docstring example."""
+    np.random.seed(42)  # For reproducible results
+    data = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
+    subset, indices = shuffle_array(data, 2)
+    assert subset.shape == (2, 2)
+    assert indices.shape == (2,)
+    assert len(np.unique(indices)) == 2  # No duplicates
+
+
+def test_shuffle_array_without_sampling():
+    """Test shuffle_array_without_sampling function with docstring example."""
+    np.random.seed(42)  # For reproducible results
+    data = np.array([[1], [2], [3], [4]])
+    shuffled, indices = shuffle_array_without_sampling(data)
+    assert shuffled.shape == (4, 1)
+    assert indices.shape == (4,)
+    assert set(indices) == set(range(4))  # All original indices present
+
+
+def test_calculate_pos_encoding():
+    """Test calculate_pos_encoding function with docstring example."""
+    positions = np.array([0.0, 1.0, 2.0])
+    encodings = calculate_pos_encoding(positions, d=4)
+    assert len(encodings) == 4
+    assert all(enc.shape == (3,) for enc in encodings)
+
+
+def test_combine_dict():
+    """Test combine_dict function with docstring example."""
+    stats1 = {"loss": 0.5, "accuracy": 0.8}
+    stats2 = {"loss": 0.3, "accuracy": 0.1}
+    combined = combine_dict(stats1, stats2)
+    assert combined["loss"] == 0.8
+    assert combined["accuracy"] == 0.9
+
+
+def test_create_grid():
+    """Test create_grid function with docstring example."""
+    min_bounds = np.array([0.0, 0.0, 0.0])
+    max_bounds = np.array([1.0, 1.0, 1.0])
+    grid_res = np.array([2, 2, 2])
+    grid = create_grid(max_bounds, min_bounds, grid_res)
+    assert grid.shape == (2, 2, 2, 3)
+    assert np.allclose(grid[0, 0, 0], [0.0, 0.0, 0.0])
+    assert np.allclose(grid[1, 1, 1], [1.0, 1.0, 1.0])
+
+
+def test_mean_std_sampling():
+    """Test mean_std_sampling function with docstring example."""
+    # Create test data with outliers
+    field = np.array([[1.0], [2.0], [3.0], [10.0]])  # 10.0 is outlier
+    field_mean = np.array([2.0])
+    field_std = np.array([1.0])
+    outliers = mean_std_sampling(field, field_mean, field_std, 2.0)
+    assert 3 in outliers  # Index 3 (value 10.0) should be detected as outlier
+
+
+def test_area_weighted_shuffle_array():
+    """Test area_weighted_shuffle_array function with docstring example."""
+    np.random.seed(42)  # For reproducible results
+    mesh_data = np.array([[1.0], [2.0], [3.0], [4.0]])
+    cell_areas = np.array([0.1, 0.1, 0.1, 10.0])  # Last point has much larger area
+    subset, indices = area_weighted_shuffle_array(mesh_data, 2, cell_areas)
+    assert subset.shape == (2, 1)
+    assert indices.shape == (2,)
+    # The point with large area (index 3) should likely be selected
+    assert len(set(indices)) <= 2  # At most 2 unique indices
+
+    # Use higher area_factor for stronger bias toward large areas
+    subset_biased, _ = area_weighted_shuffle_array(
+        mesh_data, 2, cell_areas, area_factor=2.0
+    )
+    assert subset_biased.shape == (2, 1)
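+
+
+def test_create_directory_and_get_filenames(tmp_path):
+    """Smoke test for create_directory and get_filenames (no docstring example).
+
+    A minimal sketch using pytest's built-in tmp_path fixture; the directory
+    and file names here are illustrative only.
+    """
+    from physicsnemo.utils.domino.utils import create_directory, get_filenames
+
+    target = tmp_path / "outputs" / "run_0"
+    create_directory(target)
+    assert target.is_dir()
+    (target / "result.npy").touch()
+    assert get_filenames(target) == ["result.npy"]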