From fd473234425603e6347177bb99c3d7b03a8131f0 Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Fri, 20 Jun 2025 07:47:35 -0400 Subject: [PATCH 01/19] Massive refactor on domino utils.py to improve code quality --- physicsnemo/utils/domino/utils.py | 1120 ++++++++++++++++++++++------- 1 file changed, 875 insertions(+), 245 deletions(-) diff --git a/physicsnemo/utils/domino/utils.py b/physicsnemo/utils/domino/utils.py index c2c8e0feaa..c0592481cd 100644 --- a/physicsnemo/utils/domino/utils.py +++ b/physicsnemo/utils/domino/utils.py @@ -15,219 +15,564 @@ # limitations under the License. """ -Important utilities for data processing and training, testing DoMINO. +Utilities for data processing and training with the DoMINO model architecture. + +This module provides essential utilities for computational fluid dynamics data processing, +mesh manipulation, field normalization, and geometric computations. It supports both +CPU (NumPy) and GPU (CuPy) operations with automatic fallbacks. """ -import os -from typing import Any, List, Optional, Sequence, Tuple, Union +from pathlib import Path +from typing import Any, Sequence import numpy as np +import cupy as cp +import pyvista as pv +import vtk +from vtk import vtkDataSetTriangleFilter +from vtk.util import numpy_support from scipy.spatial import KDTree from physicsnemo.utils.profiling import profile -try: - import cupy as cp - - CUPY_AVAILABLE = True -except ImportError: - CUPY_AVAILABLE = False - - -try: - import pyvista as pv - - PV_AVAILABLE = True -except ImportError: - PV_AVAILABLE = False -try: - import vtk - from vtk import vtkDataSetTriangleFilter - from vtk.util import numpy_support - - VTK_AVAILABLE = True -except ImportError: - VTK_AVAILABLE = False - -# Define a typing that works for both numpy and cupy -if CUPY_AVAILABLE: - ArrayType = Union[np.ndarray, cp.ndarray] -else: - # Or just numpy, if cupy is not available. - ArrayType = np.ndarray - - -def array_type(arr: ArrayType): - """Function to return the array type. It's just leveraging - cupy to do this if available, fallback is numpy. +# Type alias for arrays that can be either NumPy or CuPy +ArrayType = np.ndarray | cp.ndarray + + +def array_type(array: ArrayType) -> type[np] | type[cp]: + """Determine the array module (NumPy or CuPy) for the given array. + + This function enables array-agnostic code by returning the appropriate + array module that can be used for operations on the input array. + + Args: + array: Input array that can be either NumPy or CuPy array. + + Returns: + The array module (numpy or cupy) corresponding to the input array type. 
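+
+    Note:
+        CuPy is imported unconditionally in this module, so this helper
+        assumes a working CuPy installation. For a device array the same
+        call dispatches to CuPy (hypothetical sketch):
+
+        >>> import cupy as cp  # doctest: +SKIP
+        >>> xp = array_type(cp.asarray([1, 2, 3]))  # doctest: +SKIP
+        >>> xp is cp  # doctest: +SKIP
+        True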
+ + Examples: + >>> import numpy as np + >>> arr = np.array([1, 2, 3]) + >>> xp = array_type(arr) + >>> result = xp.sum(arr) # Uses numpy.sum """ - if CUPY_AVAILABLE: - return cp.get_array_module(arr) - else: - return np + return cp.get_array_module(array) -def calculate_center_of_mass(stl_centers: ArrayType, stl_sizes: ArrayType) -> ArrayType: - """Function to calculate center of mass""" - xp = array_type(stl_centers) - stl_sizes = xp.expand_dims(stl_sizes, -1) - center_of_mass = xp.sum(stl_centers * stl_sizes, axis=0) / xp.sum(stl_sizes, axis=0) - return center_of_mass - - -def normalize(field: ArrayType, mx: ArrayType, mn: ArrayType) -> ArrayType: - """Function to normalize fields""" - return 2.0 * (field - mn) / (mx - mn) - 1.0 - - -def unnormalize(field: ArrayType, mx: ArrayType, mn: ArrayType) -> ArrayType: - """Function to unnormalize fields""" - return (field + 1.0) * (mx - mn) * 0.5 + mn +def calculate_center_of_mass( + centers: ArrayType, sizes: ArrayType +) -> ArrayType: + """Calculate the center of mass for a collection of elements. + + Computes the volume-weighted centroid of mesh elements, commonly used + in computational fluid dynamics for mesh analysis and load balancing. + + Args: + centers: Array of shape (n_elements, 3) containing the centroid + coordinates of each element. + sizes: Array of shape (n_elements,) containing the volume + or area of each element used as weights. + + Returns: + Array of shape (3,) containing the x, y, z coordinates of the center of mass. + + Raises: + ValueError: If centers and sizes have incompatible shapes. + + Examples: + >>> centers = np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2]]) + >>> sizes = np.array([1.0, 2.0, 3.0]) + >>> com = calculate_center_of_mass(centers, sizes) + >>> print(com) # [1.5, 1.5, 1.5] + """ + xp = array_type(centers) + + total_weighted_position = xp.einsum('i,ij->ij', sizes, centers) + total_size = xp.sum(sizes) + + return total_weighted_position / total_size + + +def normalize( + field: ArrayType, + max_val: ArrayType | None = None, + min_val: ArrayType | None = None +) -> ArrayType: + """Normalize field values to the range [-1, 1]. + + Applies min-max normalization to scale field values to a symmetric range + around zero. This is commonly used in machine learning preprocessing to + ensure numerical stability and faster convergence. + + Args: + field: Input field array to be normalized. + max_val: Maximum values for normalization, can be scalar or array. + If None, computed from the field data. + min_val: Minimum values for normalization, can be scalar or array. + If None, computed from the field data. + + Returns: + Normalized field with values in the range [-1, 1]. + + Raises: + ZeroDivisionError: If max_val equals min_val (zero range). + + Examples: + >>> field = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + >>> normalized = normalize(field, 5.0, 1.0) + >>> print(normalized) # [-1, -0.5, 0, 0.5, 1] + + >>> # Auto-compute min/max + >>> normalized = normalize(field) + """ + xp = array_type(field) + + if max_val is None: + max_val = xp.max(field, axis=0, keepdims=True) + if min_val is None: + min_val = xp.min(field, axis=0, keepdims=True) + + field_range = max_val - min_val + return 2.0 * (field - min_val) / field_range - 1.0 + + +def unnormalize( + normalized_field: ArrayType, max_val: ArrayType, min_val: ArrayType +) -> ArrayType: + """Reverse the normalization process to recover original field values. 
+ + Transforms normalized values from the range [-1, 1] back to their original + physical range using the stored min/max values. + + Args: + normalized_field: Field values in the normalized range [-1, 1]. + max_val: Maximum values used in the original normalization. + min_val: Minimum values used in the original normalization. + + Returns: + Field values restored to their original physical range. + + Examples: + >>> normalized = np.array([-1.0, -0.5, 0.0, 0.5, 1.0]) + >>> original = unnormalize(normalized, 5.0, 1.0) + >>> print(original) # [1, 2, 3, 4, 5] + """ + field_range = max_val - min_val + return (normalized_field + 1.0) * field_range * 0.5 + min_val -def standardize(field: ArrayType, mean: ArrayType, std: ArrayType) -> ArrayType: - """Function to standardize fields""" +def standardize( + field: ArrayType, + mean: ArrayType | None = None, + std: ArrayType | None = None +) -> ArrayType: + """Standardize field values to have zero mean and unit variance. + + Applies z-score normalization to center the data around zero with + standard deviation of one. This is preferred over min-max normalization + when the data follows a normal distribution. + + Args: + field: Input field array to be standardized. + mean: Mean values for standardization. If None, computed from field data. + std: Standard deviation values for standardization. If None, computed from field data. + + Returns: + Standardized field with approximately zero mean and unit variance. + + Raises: + ZeroDivisionError: If std contains zeros. + + Examples: + >>> field = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + >>> standardized = standardize(field, 3.0, np.sqrt(2.5)) + >>> print(np.mean(standardized), np.std(standardized)) # ~0.0, ~1.0 + + >>> # Auto-compute mean/std + >>> standardized = standardize(field) + """ + xp = array_type(field) + + if mean is None: + mean = xp.mean(field, axis=0, keepdims=True) + if std is None: + std = xp.std(field, axis=0, keepdims=True) + return (field - mean) / std -def unstandardize(field: ArrayType, mean: ArrayType, std: ArrayType) -> ArrayType: - """Function to unstandardize fields""" - return field * std + mean - - -def write_to_vtp(polydata: Any, filename: str): - """Function to write polydata to vtp""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") +def unstandardize( + standardized_field: ArrayType, mean: ArrayType, std: ArrayType +) -> ArrayType: + """Reverse the standardization process to recover original field values. + + Transforms standardized values (zero mean, unit variance) back to their + original distribution using the stored mean and standard deviation. + + Args: + standardized_field: Field values with zero mean and unit variance. + mean: Mean values used in the original standardization. + std: Standard deviation values used in the original standardization. + + Returns: + Field values restored to their original distribution. + + Examples: + >>> standardized = np.array([-1.26, -0.63, 0.0, 0.63, 1.26]) + >>> original = unstandardize(standardized, 3.0, np.sqrt(2.5)) + >>> print(original) # approximately [1, 2, 3, 4, 5] + """ + return standardized_field * std + mean + + +def write_to_vtp(polydata: "vtk.vtkPolyData", filename: str) -> None: + """Write VTK polydata to a VTP (VTK PolyData) file format. + + VTP files are XML-based and store polygonal data including points, polygons, + and associated field data. This format is commonly used for surface meshes + in computational fluid dynamics visualization. 
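+
+    The underlying ``vtkXMLPolyDataWriter.Write()`` call returns 1 on
+    success and 0 on failure, which is why its result is checked below and
+    converted into a ``RuntimeError``.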
+ + Args: + polydata: VTK polydata object containing mesh geometry and fields. + filename: Output filename with .vtp extension. Directory will be created + if it doesn't exist. + + Raises: + RuntimeError: If writing fails due to file permissions or disk space. + + Examples: + >>> # Assuming you have a VTK polydata object + >>> write_to_vtp(surface_mesh, "output/surface.vtp") + """ + # Ensure output directory exists + output_path = Path(filename) + output_path.parent.mkdir(parents=True, exist_ok=True) + writer = vtk.vtkXMLPolyDataWriter() - writer.SetFileName(filename) + writer.SetFileName(str(output_path)) writer.SetInputData(polydata) - writer.Write() - - -def write_to_vtu(polydata: Any, filename: str): - """Function to write polydata to vtu""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") + + if not writer.Write(): + raise RuntimeError(f"Failed to write polydata to {output_path}") + + +def write_to_vtu(unstructured_grid: "vtk.vtkUnstructuredGrid", filename: str) -> None: + """Write VTK unstructured grid to a VTU (VTK Unstructured Grid) file format. + + VTU files store 3D volumetric meshes with arbitrary cell types including + tetrahedra, hexahedra, and pyramids. This format is essential for storing + finite element analysis results. + + Args: + unstructured_grid: VTK unstructured grid object containing volumetric mesh + geometry and field data. + filename: Output filename with .vtu extension. Directory will be created + if it doesn't exist. + + Raises: + RuntimeError: If writing fails due to file permissions or disk space. + + Examples: + >>> # Assuming you have a VTK unstructured grid object + >>> write_to_vtu(volume_mesh, "output/volume.vtu") + """ + # Ensure output directory exists + output_path = Path(filename) + output_path.parent.mkdir(parents=True, exist_ok=True) + writer = vtk.vtkXMLUnstructuredGridWriter() - writer.SetFileName(filename) - writer.SetInputData(polydata) - writer.Write() - - -def extract_surface_triangles(tet_mesh: Any) -> List[int]: - """Extracts the surface triangles from a triangular mesh.""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - if not PV_AVAILABLE: - raise ImportError("PyVista is not installed. This function cannot be used.") + writer.SetFileName(str(output_path)) + writer.SetInputData(unstructured_grid) + + if not writer.Write(): + raise RuntimeError(f"Failed to write unstructured grid to {output_path}") + + +def extract_surface_triangles(tetrahedral_mesh: "vtk.vtkUnstructuredGrid") -> list[int]: + """Extract surface triangle indices from a tetrahedral mesh. + + This function identifies the boundary faces of a 3D tetrahedral mesh and + returns the vertex indices that form triangular faces on the surface. + This is essential for visualization and boundary condition application. + + Args: + tetrahedral_mesh: VTK unstructured grid containing tetrahedral elements. + + Returns: + List of vertex indices forming surface triangles. Every three consecutive + indices define one triangle. + + Raises: + NotImplementedError: If the surface contains non-triangular faces. + + Examples: + >>> # Extract surface from a tet mesh for visualization + >>> surface_indices = extract_surface_triangles(tet_mesh) + >>> # surface_indices = [v1, v2, v3, v4, v5, v6, ...] 
for triangles + """ + # Extract the surface using VTK filter surface_filter = vtk.vtkDataSetSurfaceFilter() - surface_filter.SetInputData(tet_mesh) + surface_filter.SetInputData(tetrahedral_mesh) surface_filter.Update() + # Wrap with PyVista for easier manipulation surface_mesh = pv.wrap(surface_filter.GetOutput()) triangle_indices = [] + + # Process faces - PyVista stores faces as [n_vertices, v1, v2, ..., vn] faces = surface_mesh.faces.reshape((-1, 4)) for face in faces: - if face[0] == 3: + if face[0] == 3: # Triangle (3 vertices) triangle_indices.extend([face[1], face[2], face[3]]) else: - raise ValueError("Face is not a triangle") + raise NotImplementedError(f"Non-triangular face found with {face[0]} vertices") return triangle_indices -def convert_to_tet_mesh(polydata: Any) -> Any: - """Function to convert tet to stl""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - # Create a VTK DataSetTriangleFilter object - tet_filter = vtkDataSetTriangleFilter() - tet_filter.SetInputData(polydata) - tet_filter.Update() # Update to apply the filter - - # Get the output as an UnstructuredGrid - # tet_mesh = pv.wrap(tet_filter.GetOutput()) - tet_mesh = tet_filter.GetOutput() - return tet_mesh - - -def get_node_to_elem(polydata: Any) -> Any: - """Function to convert node to elem""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - c2p = vtk.vtkPointDataToCellData() - c2p.SetInputData(polydata) - c2p.Update() - cell_data = c2p.GetOutput() +def convert_to_tet_mesh(polydata: "vtk.vtkPolyData") -> "vtk.vtkUnstructuredGrid": + """Convert surface polydata to a tetrahedral volumetric mesh. + + This function performs tetrahedralization of a surface mesh, creating + a 3D volumetric mesh suitable for finite element analysis. The process + fills the interior of the surface with tetrahedral elements. + + Args: + polydata: VTK polydata representing a closed surface mesh. + + Returns: + VTK unstructured grid containing tetrahedral elements filling the + volume enclosed by the input surface. + + Raises: + RuntimeError: If tetrahedralization fails (e.g., non-manifold surface). + + Examples: + >>> # Convert a surface mesh to volume mesh for FEM analysis + >>> tet_mesh = convert_polydata_to_tetrahedral_mesh(surface_polydata) + """ + tetrahedral_filter = vtkDataSetTriangleFilter() + tetrahedral_filter.SetInputData(polydata) + tetrahedral_filter.Update() + + tetrahedral_mesh = tetrahedral_filter.GetOutput() + return tetrahedral_mesh + + +def convert_point_data_to_cell_data(input_data: "vtk.vtkDataSet") -> "vtk.vtkDataSet": + """Convert point-based field data to cell-based field data. + + This function transforms field variables defined at mesh vertices (nodes) + to values defined at cell centers. This conversion is often needed when + switching between different numerical methods or visualization requirements. + + Args: + input_data: VTK dataset with point data to be converted. + + Returns: + VTK dataset with the same geometry but field data moved from points to cells. + Values are typically averaged from the surrounding points. 
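+
+    Note:
+        This helper is functionally identical to ``get_node_to_elem``
+        below; both are thin wrappers around ``vtk.vtkPointDataToCellData``.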
+ + Examples: + >>> # Convert nodal pressure values to cell-centered values + >>> cell_data = convert_point_data_to_cell_data(point_based_mesh) + """ + point_to_cell_filter = vtk.vtkPointDataToCellData() + point_to_cell_filter.SetInputData(input_data) + point_to_cell_filter.Update() + + return point_to_cell_filter.GetOutput() + + +def get_node_to_elem(polydata: "vtk.vtkDataSet") -> "vtk.vtkDataSet": + """Convert point data to cell data for VTK dataset. + + This function transforms field variables defined at mesh vertices to + values defined at cell centers using VTK's built-in conversion filter. + + Args: + polydata: VTK dataset with point data to be converted. + + Returns: + VTK dataset with field data moved from points to cells. + + Examples: + >>> cell_data = get_node_to_elem(point_based_mesh) + """ + point_to_cell_filter = vtk.vtkPointDataToCellData() + point_to_cell_filter.SetInputData(polydata) + point_to_cell_filter.Update() + cell_data = point_to_cell_filter.GetOutput() return cell_data -def get_fields_from_cell(ptdata, var_list): - """Function to get fields from elem""" - fields = [] - for var in var_list: - variable = ptdata.GetArray(var) - num_tuples = variable.GetNumberOfTuples() - cell_fields = [] - for j in range(num_tuples): - variable_value = np.array(variable.GetTuple(j)) - cell_fields.append(variable_value) - cell_fields = np.asarray(cell_fields) - fields.append(cell_fields) - fields = np.transpose(np.asarray(fields), (1, 0)) - - return fields - - -def get_fields(data, variables): - """Function to get fields from VTP/VTU""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - fields = [] - for array_name in variables: +def get_fields_from_cell(cell_data: "vtk.vtkCellData", variable_names: list[str]) -> np.ndarray: + """Extract field variables from VTK cell data. + + This function extracts multiple field variables from VTK cell data and + organizes them into a structured NumPy array. Each variable becomes a + column in the output array. + + Args: + cell_data: VTK cell data object containing field variables. + variable_names: List of variable names to extract from the cell data. + + Returns: + NumPy array of shape (n_cells, n_variables) containing the extracted + field data. Variables are ordered according to the input list. + + Raises: + ValueError: If a requested variable name is not found in the cell data. + + Examples: + >>> # Extract pressure and velocity magnitude from cell data + >>> variable_names = ["pressure", "velocity_magnitude"] + >>> fields = get_fields_from_cell(cell_data, variable_names) + >>> print(fields.shape) # (n_cells, 2) + """ + extracted_fields = [] + for variable_name in variable_names: + variable_array = cell_data.GetArray(variable_name) + if variable_array is None: + raise ValueError(f"Variable '{variable_name}' not found in cell data") + + num_tuples = variable_array.GetNumberOfTuples() + field_values = [] + for tuple_idx in range(num_tuples): + variable_value = np.array(variable_array.GetTuple(tuple_idx)) + field_values.append(variable_value) + field_values = np.asarray(field_values) + extracted_fields.append(field_values) + + # Transpose to get shape (n_cells, n_variables) + extracted_fields = np.transpose(np.asarray(extracted_fields), (1, 0)) + return extracted_fields + + +def get_fields(data_attributes: "vtk.vtkDataSetAttributes", variable_names: list[str]) -> list[np.ndarray]: + """Extract multiple field variables from VTK data attributes. 
+ + This function extracts field variables from VTK data attributes (either + point data or cell data) and returns them as a list of NumPy arrays. + It handles both point and cell data seamlessly. + + Args: + data_attributes: VTK data attributes object (point data or cell data). + variable_names: List of variable names to extract. + + Returns: + List of NumPy arrays, one for each requested variable. Each array + has shape (n_points/n_cells, n_components) where n_components + depends on the variable (1 for scalars, 3 for vectors, etc.). + + Raises: + ValueError: If a requested variable is not found in the data attributes. + + Examples: + >>> # Extract multiple variables from point data + >>> point_data = mesh.GetPointData() + >>> variable_names = ["pressure", "velocity", "temperature"] + >>> fields = get_fields(point_data, variable_names) + >>> print(len(fields)) # 3 arrays + """ + extracted_fields = [] + for variable_name in variable_names: try: - array = data.GetArray(array_name) - except ValueError: - raise ValueError( - f"Failed to get array {array_name} from the unstructured grid." - ) - array_data = numpy_support.vtk_to_numpy(array).reshape( - array.GetNumberOfTuples(), array.GetNumberOfComponents() + vtk_array = data_attributes.GetArray(variable_name) + except ValueError as e: + raise ValueError(f"Failed to get array '{variable_name}' from the data attributes: {e}") + + # Convert VTK array to NumPy array with proper shape + numpy_array = numpy_support.vtk_to_numpy(vtk_array).reshape( + vtk_array.GetNumberOfTuples(), vtk_array.GetNumberOfComponents() ) - fields.append(array_data) - return fields - - -def get_vertices(polydata): - """Function to get vertices""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") - points = polydata.GetPoints() - vertices = numpy_support.vtk_to_numpy(points.GetData()) + extracted_fields.append(numpy_array) + + return extracted_fields + + +def get_vertices(polydata: "vtk.vtkPolyData") -> np.ndarray: + """Extract vertex coordinates from VTK polydata object. + + This function converts VTK polydata to a NumPy array containing the 3D + coordinates of all vertices in the mesh. + + Args: + polydata: VTK polydata object containing mesh geometry. + + Returns: + NumPy array of shape (n_points, 3) containing [x, y, z] coordinates + for each vertex. + + Examples: + >>> vertices = get_vertices(mesh) + >>> print(vertices.shape) # (n_vertices, 3) + """ + vtk_points = polydata.GetPoints() + vertices = numpy_support.vtk_to_numpy(vtk_points.GetData()) return vertices -def get_volume_data(polydata, variables): - """Function to get volume data""" +def get_volume_data(polydata: "vtk.vtkPolyData", variable_names: list[str]) -> tuple[np.ndarray, list[np.ndarray]]: + """Extract vertices and field data from 3D volumetric mesh. + + This function extracts both geometric information (vertex coordinates) + and field data from a 3D volumetric mesh. It's commonly used for + processing finite element analysis results. + + Args: + polydata: VTK polydata representing a 3D volumetric mesh. + variable_names: List of field variable names to extract. 
+ + Returns: + Tuple containing: + - Vertex coordinates as NumPy array of shape (n_vertices, 3) + - List of field arrays, one per variable + + Examples: + >>> # Extract geometry and flow fields from CFD results + >>> vertices, fields = get_volume_data(polydata, ["pressure", "velocity"]) + >>> print(vertices.shape, len(fields)) # (n_vertices, 3), 2 + """ vertices = get_vertices(polydata) point_data = polydata.GetPointData() - - fields = get_fields(point_data, variables) - + fields = get_fields(point_data, variable_names) + return vertices, fields -def get_surface_data(polydata, variables): - """Function to get surface data""" - if not VTK_AVAILABLE: - raise ImportError("VTK or is not installed. This function cannot be used.") +def get_surface_data(polydata: "vtk.vtkPolyData", variable_names: list[str]) -> tuple[np.ndarray, list[np.ndarray], list[tuple[int, int]]]: + """Extract surface mesh data including vertices, fields, and edge connectivity. + + This function extracts comprehensive surface mesh information including + vertex coordinates, field data at vertices, and edge connectivity information. + It's commonly used for processing CFD surface results and boundary conditions. + + Args: + polydata: VTK polydata representing a surface mesh. + variable_names: List of field variable names to extract from the mesh. + + Returns: + Tuple containing: + - Vertex coordinates as NumPy array of shape (n_vertices, 3) + - List of field arrays, one per variable + - List of edge tuples representing mesh connectivity + + Raises: + ValueError: If a requested variable is not found or polygon data is missing. + + Examples: + >>> # Extract surface mesh data for visualization + >>> vertices, fields, edges = get_surface_data(surface_mesh, ["pressure", "shear_stress"]) + >>> print(vertices.shape, len(fields), len(edges)) + """ points = polydata.GetPoints() vertices = np.array([points.GetPoint(i) for i in range(points.GetNumberOfPoints())]) point_data = polydata.GetPointData() fields = [] - for array_name in variables: + for array_name in variable_names: try: array = point_data.GetArray(array_name) except ValueError: @@ -259,14 +604,39 @@ def get_surface_data(polydata, variables): def calculate_normal_positional_encoding( coordinates_a: ArrayType, - coordinates_b: Optional[ArrayType] = None, - cell_length: Sequence[float] = [], + coordinates_b: ArrayType | None = None, + cell_dimensions: Sequence[float] = (1.0, 1.0, 1.0), ) -> ArrayType: - """Function to get normal positional encoding""" - dx = cell_length[0] - dy = cell_length[1] - dz = cell_length[2] + """Calculate sinusoidal positional encoding for 3D coordinates. + + This function computes transformer-style positional encodings for 3D spatial + coordinates, enabling neural networks to understand spatial relationships. + The encoding uses sinusoidal functions at different frequencies to create + unique representations for each spatial position. + + Args: + coordinates_a: Primary coordinates array of shape (n_points, 3). + coordinates_b: Optional secondary coordinates for computing relative positions. + If provided, the encoding is computed for (coordinates_a - coordinates_b). + cell_dimensions: Characteristic length scales for x, y, z dimensions used + for normalization. Defaults to unit dimensions. + + Returns: + Array of shape (n_points, 12) containing positional encodings with + 4 encoding dimensions per spatial axis (x, y, z). 
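+
+    Note:
+        Per axis, the four channels follow the transformer convention with
+        ``d=4``: ``sin(x'), cos(x'), sin(x'/100), cos(x'/100)``, where
+        ``x'`` is the coordinate divided by the corresponding cell
+        dimension (see ``calculate_pos_encoding``).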
+ + Examples: + >>> coords = np.random.rand(100, 3) * 10 # Random 3D points + >>> cell_size = [0.1, 0.1, 0.1] # Grid resolution + >>> encoding = calculate_positional_encoding_for_coordinates(coords, cell_dimensions=cell_size) + >>> print(encoding.shape) # (100, 12) + + >>> # For relative positions between two point sets + >>> encoding_rel = calculate_positional_encoding_for_coordinates(coords_a, coords_b, cell_size) + """ + dx, dy, dz = cell_dimensions[0], cell_dimensions[1], cell_dimensions[2] xp = array_type(coordinates_a) + if coordinates_b is not None: normals = coordinates_a - coordinates_b pos_x = xp.asarray(calculate_pos_encoding(normals[:, 0] / dx, d=4)) @@ -283,10 +653,31 @@ def calculate_normal_positional_encoding( return pos_normals -def nd_interpolator(coodinates, field, grid): - """Function to for nd interpolation""" +def nd_interpolator(coordinates: ArrayType, field: ArrayType, grid: ArrayType) -> ArrayType: + """Perform n-dimensional interpolation using k-nearest neighbors. + + This function interpolates field values from scattered points to a regular + grid using k-nearest neighbor averaging. It's useful for reconstructing + fields on regular grids from irregular measurement points. + + Args: + coordinates: Array of shape (n_points, n_dims) containing source point coordinates. + field: Array of shape (n_points, n_fields) containing field values at source points. + grid: Array containing target grid points for interpolation. + + Returns: + Interpolated field values at grid points using 2-nearest neighbor averaging. + + Note: + This function currently uses SciPy's KDTree which only supports CPU arrays. + A future enhancement could add CuML support for GPU acceleration. + + Examples: + >>> # Interpolate scattered CFD data to regular grid + >>> grid_values = nd_interpolator(mesh_coords, pressure_field, regular_grid) + """ # TODO - this function should get updated for cuml if using cupy. - interp_func = KDTree(coodinates[0]) + interp_func = KDTree(coordinates[0]) dd, ii = interp_func.query(grid, k=2) field_grid = field[ii] @@ -294,21 +685,65 @@ def nd_interpolator(coodinates, field, grid): return field_grid -def pad(arr: ArrayType, npoin: int, pad_value: float = 0.0) -> ArrayType: - """Function for padding""" +def pad(arr: ArrayType, n_points: int, pad_value: float = 0.0) -> ArrayType: + """Pad 2D array with constant values to reach target size. + + This function extends a 2D array by adding rows filled with a constant + value. It's commonly used to standardize array sizes in batch processing + for machine learning applications. + + Args: + arr: Input array of shape (n_points, n_features) to be padded. + n_points: Target number of points (rows) after padding. + pad_value: Constant value used for padding. Defaults to 0.0. + + Returns: + Padded array of shape (n_points, n_features). If n_points <= arr.shape[0], + returns the original array unchanged. 
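+
+    Note:
+        Padding rows are created as ``float32``; if ``arr`` has another
+        dtype, the concatenation follows the usual type-promotion rules,
+        so the result may not preserve the input dtype.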
+ + Examples: + >>> arr = np.random.rand(100, 3) + >>> padded = pad(arr, 150, -1.0) # Pad to 150 points with -1.0 + >>> print(padded.shape) # (150, 3) + """ xp = array_type(arr) + if n_points <= arr.shape[0]: + return arr + arr_pad = pad_value * xp.ones( - (npoin - arr.shape[0], arr.shape[1]), dtype=xp.float32 + (n_points - arr.shape[0], arr.shape[1]), dtype=xp.float32 ) arr_padded = xp.concatenate((arr, arr_pad), axis=0) return arr_padded -def pad_inp(arr: ArrayType, npoin: int, pad_value: float = 0.0) -> ArrayType: - """Function for padding arrays""" +def pad_inp(arr: ArrayType, n_points: int, pad_value: float = 0.0) -> ArrayType: + """Pad 3D array with constant values to reach target size. + + This function extends a 3D array by adding entries along the first dimension + filled with a constant value. Used for standardizing 3D tensor sizes in + batch processing workflows. + + Args: + arr: Input array of shape (n_points, height, width) to be padded. + n_points: Target number of points along first dimension after padding. + pad_value: Constant value used for padding. Defaults to 0.0. + + Returns: + Padded array of shape (n_points, height, width). If n_points <= arr.shape[0], + returns the original array unchanged. + + Examples: + >>> arr = np.random.rand(50, 10, 5) + >>> padded = pad_inp(arr, 100, 0.0) # Pad to 100 entries + >>> print(padded.shape) # (100, 10, 5) + """ xp = array_type(arr) + if n_points <= arr.shape[0]: + return arr + arr_pad = pad_value * xp.ones( - (npoin - arr.shape[0], arr.shape[1], arr.shape[2]), dtype=xp.float32 + (n_points - arr.shape[0], arr.shape[1], arr.shape[2]), dtype=xp.float32 ) arr_padded = xp.concatenate((arr, arr_pad), axis=0) return arr_padded @@ -317,82 +752,203 @@ def pad_inp(arr: ArrayType, npoin: int, pad_value: float = 0.0) -> ArrayType: @profile def shuffle_array( arr: ArrayType, - npoin: int, -) -> Tuple[ArrayType, ArrayType]: - """Function for shuffling arrays""" + n_points: int, +) -> tuple[ArrayType, ArrayType]: + """Randomly sample points from array without replacement. + + This function performs random sampling from the input array, selecting + n_points points without replacement. It's commonly used for creating training + subsets and data augmentation in machine learning workflows. + + Args: + arr: Input array to sample from, shape (n_points, ...). + n_points: Number of points to sample. If greater than arr.shape[0], + all points are returned. + + Returns: + Tuple containing: + - Sampled array subset + - Indices of the selected points + + Examples: + >>> data = np.random.rand(1000, 3) + >>> subset, indices = shuffle_array(data, 100) + >>> print(subset.shape, indices.shape) # (100, 3), (100,) + """ xp = array_type(arr) - if npoin > arr.shape[0]: + if n_points > arr.shape[0]: # If asking too many points, truncate the ask but still shuffle. - npoin = arr.shape[0] - idx = xp.random.choice(arr.shape[0], size=npoin, replace=False) + n_points = arr.shape[0] + idx = xp.random.choice(arr.shape[0], size=n_points, replace=False) return arr[idx], idx -def shuffle_array_without_sampling(arr: ArrayType) -> Tuple[ArrayType, ArrayType]: - """Function for shuffline arrays without sampling.""" +def shuffle_array_without_sampling(arr: ArrayType) -> tuple[ArrayType, ArrayType]: + """Shuffle array order without changing the number of elements. + + This function reorders all elements in the array randomly while preserving + all data points. It's useful for randomizing data order before training + while maintaining the complete dataset. 
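+
+    The permutation indices are returned alongside the shuffled array so
+    that the same reordering can be applied to other row-aligned arrays,
+    such as matching labels or sample weights.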
+ + Args: + arr: Input array to shuffle, shape (n_points, ...). + + Returns: + Tuple containing: + - Shuffled array with same shape as input + - Permutation indices used for shuffling + + Examples: + >>> data = np.arange(100).reshape(100, 1) + >>> shuffled, indices = shuffle_array_without_sampling(data) + >>> print(shuffled.shape) # (100, 1) - same size, different order + """ xp = array_type(arr) idx = xp.arange(arr.shape[0]) xp.random.shuffle(idx) return arr[idx], idx -def create_directory(filepath: str) -> None: - """Function to create directories""" - if not os.path.exists(filepath): - os.makedirs(filepath) - - -def get_filenames(filepath: str, exclude_dirs: bool = False) -> List[str]: - """Function to get filenames from a directory""" - if os.path.exists(filepath): - filenames = [] - for item in os.listdir(filepath): - item_path = os.path.join(filepath, item) - if exclude_dirs and os.path.isdir(item_path): - # Include directories ending with .zarr even when exclude_dirs is True - if item.endswith(".zarr"): - filenames.append(item) - continue - filenames.append(item) - return filenames - else: - FileNotFoundError() - - -def calculate_pos_encoding(nx: ArrayType, d: int = 8) -> ArrayType: - """Function for calculating positional encoding""" +def create_directory(filepath: str | Path) -> None: + """Create directory and all necessary parent directories. + + This function creates a directory at the specified path, including any + necessary parent directories. It's equivalent to `mkdir -p` in Unix systems. + + Args: + filepath: Path to the directory to create. Can be string or Path object. + + Examples: + >>> create_directory("data/processed/meshes") + >>> create_directory(Path("output") / "results" / "visualization") + """ + Path(filepath).mkdir(parents=True, exist_ok=True) + + +def get_filenames(filepath: str | Path, exclude_dirs: bool = False) -> list[str]: + """Get list of filenames in a directory with optional directory filtering. + + This function returns all items in a directory, with options to exclude + subdirectories. It handles special cases like .zarr directories which + are treated as files even when exclude_dirs is True. + + Args: + filepath: Path to the directory to list. Can be string or Path object. + exclude_dirs: If True, exclude subdirectories from results. + Exception: .zarr directories are always included as they represent + data files in array storage format. + + Returns: + List of filenames/directory names found in the specified directory. + + Raises: + FileNotFoundError: If the specified directory does not exist. + + Examples: + >>> files = get_filenames("data/inputs") + >>> data_files = get_filenames("results", exclude_dirs=True) + """ + path = Path(filepath) + if not path.exists(): + raise FileNotFoundError(f"Directory {filepath} does not exist") + + filenames = [] + for item in path.iterdir(): + if exclude_dirs and item.is_dir(): + # Include directories ending with .zarr even when exclude_dirs is True + if item.name.endswith(".zarr"): + filenames.append(item.name) + continue + filenames.append(item.name) + return filenames + + +def calculate_pos_encoding(nx: ArrayType, d: int = 8) -> list[ArrayType]: + """Calculate sinusoidal positional encoding for transformer architectures. + + This function computes positional encodings using alternating sine and cosine + functions at different frequencies. These encodings help neural networks + understand positional relationships in sequences or spatial data. + + Args: + nx: Input positions/coordinates to encode. 
+ d: Encoding dimensionality. Must be even number. Defaults to 8. + + Returns: + List of d arrays containing alternating sine and cosine encodings. + Each pair (sin, cos) uses progressively lower frequencies. + + Examples: + >>> positions = np.linspace(0, 100, 50) + >>> encodings = calculate_pos_encoding(positions, d=16) + >>> print(len(encodings)) # 16 encoding dimensions + """ vec = [] xp = array_type(nx) for k in range(int(d / 2)): - vec.append(xp.sin(nx / 10000 ** (2 * (k) / d))) - vec.append(xp.cos(nx / 10000 ** (2 * (k) / d))) + vec.append(xp.sin(nx / 10000 ** (2 * k / d))) + vec.append(xp.cos(nx / 10000 ** (2 * k / d))) return vec -def combine_dict(old_dict, new_dict): - """Function to combine dictionaries""" - for j in old_dict.keys(): - old_dict[j] += new_dict[j] +def combine_dict(old_dict: dict[Any, Any], new_dict: dict[Any, Any]) -> dict[Any, Any]: + """Combine two dictionaries by adding values for matching keys. + + This function performs element-wise addition of dictionary values for + keys that exist in both dictionaries. It's commonly used for accumulating + statistics or metrics across multiple iterations. + + Args: + old_dict: Base dictionary to update. + new_dict: Dictionary with values to add to old_dict. + + Returns: + Updated old_dict with combined values. + + Note: + This function modifies old_dict in place and returns it. + Values must support the + operator. + + Examples: + >>> stats1 = {"loss": 0.5, "accuracy": 0.8} + >>> stats2 = {"loss": 0.3, "accuracy": 0.1} + >>> combined = combine_dict(stats1, stats2) + >>> print(combined) # {"loss": 0.8, "accuracy": 0.9} + """ + for key in old_dict.keys(): + old_dict[key] += new_dict[key] return old_dict -def merge(*lists): - """Function to merge lists""" - newlist = lists[:] - for x in lists: - if x not in newlist: - newlist.extend(x) - return newlist - - -def create_grid(mx: ArrayType, mn: ArrayType, nres: ArrayType) -> ArrayType: - """Function to create grid""" - - xp = array_type(mx) +def create_grid(max_coords: ArrayType, min_coords: ArrayType, resolution: ArrayType) -> ArrayType: + """Create a 3D regular grid from coordinate bounds and resolution. + + This function generates a regular 3D grid spanning from min_coords to + max_coords with the specified resolution in each dimension. The resulting + grid is commonly used for interpolation, visualization, and regular sampling. + + Args: + max_coords: Maximum coordinates [x_max, y_max, z_max] for the grid bounds. + min_coords: Minimum coordinates [x_min, y_min, z_min] for the grid bounds. + resolution: Number of grid points [nx, ny, nz] in each dimension. + + Returns: + Grid array of shape (nx, ny, nz, 3) containing 3D coordinates for each + grid point. The last dimension contains [x, y, z] coordinates. 
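+
+    Note:
+        ``meshgrid`` is called with its default ``indexing="xy"``, which
+        swaps the first two axes, so for unequal resolutions the actual
+        shape is ``(ny, nx, nz, 3)``. A quick sketch:
+
+        >>> grid = create_grid(np.array([1.0, 1.0, 1.0]),
+        ...                    np.array([0.0, 0.0, 0.0]),
+        ...                    np.array([2, 3, 4]))
+        >>> grid.shape
+        (3, 2, 4, 3)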
+ + Examples: + >>> # Create a 10x10x10 grid from (0,0,0) to (1,1,1) + >>> min_bounds = np.array([0, 0, 0]) + >>> max_bounds = np.array([1, 1, 1]) + >>> grid_res = np.array([10, 10, 10]) + >>> grid = create_grid(max_bounds, min_bounds, grid_res) + >>> print(grid.shape) # (10, 10, 10, 3) + """ + xp = array_type(max_coords) - dx = xp.linspace(mn[0], mx[0], nres[0], dtype=mx.dtype) - dy = xp.linspace(mn[1], mx[1], nres[1], dtype=mx.dtype) - dz = xp.linspace(mn[2], mx[2], nres[2], dtype=mx.dtype) + dx = xp.linspace(min_coords[0], max_coords[0], resolution[0], dtype=max_coords.dtype) + dy = xp.linspace(min_coords[1], max_coords[1], resolution[1], dtype=max_coords.dtype) + dz = xp.linspace(min_coords[2], max_coords[2], resolution[2], dtype=max_coords.dtype) xv, yv, zv = xp.meshgrid(dx, dy, dz) xv = xp.expand_dims(xv, -1) @@ -406,10 +962,32 @@ def create_grid(mx: ArrayType, mn: ArrayType, nres: ArrayType) -> ArrayType: def mean_std_sampling( field: ArrayType, mean: ArrayType, std: ArrayType, tolerance: float = 3.0 -) -> ArrayType: - """Function for mean/std based sampling""" +) -> list[int]: + """Identify outlier points based on statistical distance from mean. + + This function identifies data points that are statistical outliers by + checking if they fall outside mean ± tolerance*std for any field component. + It's useful for data cleaning and identifying regions of interest in CFD data. + + Args: + field: Input field array of shape (n_points, n_components). + mean: Mean values for each field component, shape (n_components,). + std: Standard deviation for each component, shape (n_components,). + tolerance: Number of standard deviations to use as outlier threshold. + Defaults to 3.0 (99.7% of normal distribution). + + Returns: + List of indices identifying outlier points that exceed the statistical threshold. + + Examples: + >>> # Find outliers in pressure field + >>> pressure_data = np.random.normal(100, 10, (1000, 1)) + >>> pressure_mean = np.array([100.0]) + >>> pressure_std = np.array([10.0]) + >>> outliers = mean_std_sampling(pressure_data, pressure_mean, pressure_std, 2.5) + >>> print(f"Found {len(outliers)} outlier points") + """ xp = array_type(field) - idx_all = [] for v in range(field.shape[-1]): fv = field[:, v] @@ -422,8 +1000,31 @@ def mean_std_sampling( return idx_all -def dict_to_device(state_dict, device, exclude_keys=["filename"]): - """Function to load dictionary to device""" +def dict_to_device(state_dict: dict[str, Any], device: Any, exclude_keys: list[str] | None = None) -> dict[str, Any]: + """Move dictionary values to specified device (GPU/CPU). + + This function transfers PyTorch tensors in a dictionary to the specified + compute device while preserving the dictionary structure. It's commonly + used for moving model parameters and data between CPU and GPU. + + Args: + state_dict: Dictionary containing tensors and other values. + device: Target device (e.g., torch.device('cuda:0')). + exclude_keys: List of keys to skip during device transfer. + Defaults to ["filename"] if None. + + Returns: + New dictionary with tensors moved to the specified device. + Non-tensor values and excluded keys are preserved as-is. 
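+
+    Note:
+        Retained values are moved with their ``.to(device)`` method, so
+        they are assumed to be tensors or tensor-like objects that
+        implement ``.to``.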
+ + Examples: + >>> import torch + >>> data = {"weights": torch.randn(10, 10), "filename": "model.pt"} + >>> gpu_data = dict_to_device(data, torch.device('cuda:0')) + """ + if exclude_keys is None: + exclude_keys = ["filename"] + new_state_dict = {} for k, v in state_dict.items(): if k not in exclude_keys: @@ -432,17 +1033,46 @@ def dict_to_device(state_dict, device, exclude_keys=["filename"]): def area_weighted_shuffle_array( - arr: ArrayType, npoin: int, area: ArrayType -) -> Tuple[ArrayType, ArrayType]: - """Function for area weighted shuffling""" + arr: ArrayType, n_points: int, area: ArrayType +) -> tuple[ArrayType, ArrayType]: + """Perform area-weighted random sampling from array. + + This function samples points from an array with probability proportional to + their associated area weights. This is particularly useful in CFD applications + where larger cells or surface elements should have higher sampling probability. + + Args: + arr: Input array to sample from, shape (n_points, ...). + n_points: Number of points to sample. If greater than arr.shape[0], + samples all available points. + area: Area weights for each point, shape (n_points,). Larger values + indicate higher sampling probability. + + Returns: + Tuple containing: + - Sampled array subset weighted by area + - Indices of the selected points + + Note: + For GPU arrays (CuPy), the sampling is performed on CPU due to memory + efficiency considerations. The Alias method could be implemented for + future GPU acceleration. + + Examples: + >>> # Sample mesh points with area weighting + >>> mesh_data = np.random.rand(1000, 3) + >>> cell_areas = np.random.exponential(1.0, 1000) # Area weights + >>> subset, indices = area_weighted_shuffle_array(mesh_data, 100, cell_areas) + >>> print(subset.shape) # (100, 3) - larger cells more likely selected + """ xp = array_type(arr) # Compute the total_area: factor = 1.0 total_area = xp.sum(area**factor) probs = area**factor / total_area - if npoin > arr.shape[0]: - npoin = arr.shape[0] + if n_points > arr.shape[0]: + n_points = arr.shape[0] idx = xp.arange(arr.shape[0]) @@ -455,10 +1085,10 @@ def area_weighted_shuffle_array( # In principle, we could use the Alias method to speed this up # on the GPU but it's not yet a bottleneck. 
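+        # numpy's choice(p=...) runs on the host; the result is moved
+        # back to the original array module via xp.asarray just below.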
- ids = np.random.choice(idx, npoin, p=probs) + ids = np.random.choice(idx, n_points, p=probs) ids = xp.asarray(ids) else: # Chug along on the CPU: - ids = xp.random.choice(idx, npoin, p=probs) + ids = xp.random.choice(idx, n_points, p=probs) return arr[ids], ids From 832ea539a8ff7385e7c8f6f71001fd663bc21c67 Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Fri, 20 Jun 2025 07:56:33 -0400 Subject: [PATCH 02/19] Adds missing tensorboard requirement --- examples/cfd/external_aerodynamics/domino/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/cfd/external_aerodynamics/domino/requirements.txt b/examples/cfd/external_aerodynamics/domino/requirements.txt index 1440cbf948..1011ad4a82 100644 --- a/examples/cfd/external_aerodynamics/domino/requirements.txt +++ b/examples/cfd/external_aerodynamics/domino/requirements.txt @@ -1,2 +1,3 @@ torchinfo -warp-lang \ No newline at end of file +warp-lang +tensorboard From 5f23d5c00bd7815c178e9619d12c999bd73db4a7 Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Fri, 20 Jun 2025 08:08:14 -0400 Subject: [PATCH 03/19] Fixes missing cuml requirement --- examples/cfd/external_aerodynamics/domino/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/cfd/external_aerodynamics/domino/requirements.txt b/examples/cfd/external_aerodynamics/domino/requirements.txt index 1011ad4a82..cafc1c7a4c 100644 --- a/examples/cfd/external_aerodynamics/domino/requirements.txt +++ b/examples/cfd/external_aerodynamics/domino/requirements.txt @@ -1,3 +1,4 @@ torchinfo warp-lang tensorboard +cuml-cu12>=25.6.0 From f0dc954070e1d2e419d28d48258d7d0f0026f43c Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Fri, 20 Jun 2025 08:22:35 -0400 Subject: [PATCH 04/19] Begins process of fixing inference_on_stl.py --- .../domino/src/inference_on_stl.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/cfd/external_aerodynamics/domino/src/inference_on_stl.py b/examples/cfd/external_aerodynamics/domino/src/inference_on_stl.py index 2c8f1afcb8..6a9701df2a 100644 --- a/examples/cfd/external_aerodynamics/domino/src/inference_on_stl.py +++ b/examples/cfd/external_aerodynamics/domino/src/inference_on_stl.py @@ -34,11 +34,11 @@ import numpy as np import torch -from modulus.models.domino.model import DoMINO -from modulus.utils.domino.utils import * +from physicsnemo.models.domino.model import DoMINO +from physicsnemo.utils.domino.utils import unnormalize, create_directory, nd_interpolator, get_filenames, write_to_vtp from torch.cuda.amp import autocast from torch.nn.parallel import DistributedDataParallel -from modulus.distributed import DistributedManager +from physicsnemo.distributed import DistributedManager from numpy.typing import NDArray from typing import Any, Iterable, List, Literal, Mapping, Optional, Union, Callable @@ -49,7 +49,7 @@ import pyvista as pv try: - from modulus.sym.geometry.tessellation import Tessellation + from physicsnemo.sym.geometry.tessellation import Tessellation SYM_AVAILABLE = True except ImportError: @@ -663,7 +663,7 @@ def __init__( self, cfg: DictConfig, dist: None, - cached_geo_encoding: False, + cached_geo_encoding: bool=False, ): self.cfg = cfg From 55c8c7de40b59a16ebfb74d9d0a2197d29e1d7de Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Fri, 20 Jun 2025 08:22:51 -0400 Subject: [PATCH 05/19] Fixes outdated type definition --- examples/cfd/external_aerodynamics/domino/src/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/examples/cfd/external_aerodynamics/domino/src/train.py b/examples/cfd/external_aerodynamics/domino/src/train.py index 19abb63fb0..ba6788a04a 100644 --- a/examples/cfd/external_aerodynamics/domino/src/train.py +++ b/examples/cfd/external_aerodynamics/domino/src/train.py @@ -265,7 +265,7 @@ def compute_loss_dict( integral_scaling_factor: float, surf_loss_scaling: float, vol_loss_scaling: float, -) -> Tuple[torch.Tensor, dict]: +) -> tuple[torch.Tensor, dict]: """ Compute the loss terms in a single function call. From 69b07668691b15186cce1e18f4374336463407d8 Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Fri, 20 Jun 2025 08:29:02 -0400 Subject: [PATCH 06/19] black formatting pass --- physicsnemo/utils/domino/utils.py | 384 ++++++++++++++++-------------- 1 file changed, 201 insertions(+), 183 deletions(-) diff --git a/physicsnemo/utils/domino/utils.py b/physicsnemo/utils/domino/utils.py index c0592481cd..7fcbeb5d0e 100644 --- a/physicsnemo/utils/domino/utils.py +++ b/physicsnemo/utils/domino/utils.py @@ -41,16 +41,16 @@ def array_type(array: ArrayType) -> type[np] | type[cp]: """Determine the array module (NumPy or CuPy) for the given array. - + This function enables array-agnostic code by returning the appropriate array module that can be used for operations on the input array. - + Args: array: Input array that can be either NumPy or CuPy array. - + Returns: The array module (numpy or cupy) corresponding to the input array type. - + Examples: >>> import numpy as np >>> arr = np.array([1, 2, 3]) @@ -60,26 +60,24 @@ def array_type(array: ArrayType) -> type[np] | type[cp]: return cp.get_array_module(array) -def calculate_center_of_mass( - centers: ArrayType, sizes: ArrayType -) -> ArrayType: +def calculate_center_of_mass(centers: ArrayType, sizes: ArrayType) -> ArrayType: """Calculate the center of mass for a collection of elements. - + Computes the volume-weighted centroid of mesh elements, commonly used in computational fluid dynamics for mesh analysis and load balancing. - + Args: centers: Array of shape (n_elements, 3) containing the centroid coordinates of each element. sizes: Array of shape (n_elements,) containing the volume or area of each element used as weights. - + Returns: Array of shape (3,) containing the x, y, z coordinates of the center of mass. - + Raises: ValueError: If centers and sizes have incompatible shapes. - + Examples: >>> centers = np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2]]) >>> sizes = np.array([1.0, 2.0, 3.0]) @@ -87,52 +85,50 @@ def calculate_center_of_mass( >>> print(com) # [1.5, 1.5, 1.5] """ xp = array_type(centers) - - total_weighted_position = xp.einsum('i,ij->ij', sizes, centers) + + total_weighted_position = xp.einsum("i,ij->ij", sizes, centers) total_size = xp.sum(sizes) - + return total_weighted_position / total_size def normalize( - field: ArrayType, - max_val: ArrayType | None = None, - min_val: ArrayType | None = None + field: ArrayType, max_val: ArrayType | None = None, min_val: ArrayType | None = None ) -> ArrayType: """Normalize field values to the range [-1, 1]. - + Applies min-max normalization to scale field values to a symmetric range around zero. This is commonly used in machine learning preprocessing to ensure numerical stability and faster convergence. - + Args: field: Input field array to be normalized. max_val: Maximum values for normalization, can be scalar or array. If None, computed from the field data. min_val: Minimum values for normalization, can be scalar or array. If None, computed from the field data. 
- + Returns: Normalized field with values in the range [-1, 1]. - + Raises: ZeroDivisionError: If max_val equals min_val (zero range). - + Examples: >>> field = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) >>> normalized = normalize(field, 5.0, 1.0) >>> print(normalized) # [-1, -0.5, 0, 0.5, 1] - + >>> # Auto-compute min/max >>> normalized = normalize(field) """ xp = array_type(field) - + if max_val is None: max_val = xp.max(field, axis=0, keepdims=True) if min_val is None: min_val = xp.min(field, axis=0, keepdims=True) - + field_range = max_val - min_val return 2.0 * (field - min_val) / field_range - 1.0 @@ -141,18 +137,18 @@ def unnormalize( normalized_field: ArrayType, max_val: ArrayType, min_val: ArrayType ) -> ArrayType: """Reverse the normalization process to recover original field values. - + Transforms normalized values from the range [-1, 1] back to their original physical range using the stored min/max values. - + Args: normalized_field: Field values in the normalized range [-1, 1]. max_val: Maximum values used in the original normalization. min_val: Minimum values used in the original normalization. - + Returns: Field values restored to their original physical range. - + Examples: >>> normalized = np.array([-1.0, -0.5, 0.0, 0.5, 1.0]) >>> original = unnormalize(normalized, 5.0, 1.0) @@ -163,42 +159,40 @@ def unnormalize( def standardize( - field: ArrayType, - mean: ArrayType | None = None, - std: ArrayType | None = None + field: ArrayType, mean: ArrayType | None = None, std: ArrayType | None = None ) -> ArrayType: """Standardize field values to have zero mean and unit variance. - + Applies z-score normalization to center the data around zero with standard deviation of one. This is preferred over min-max normalization when the data follows a normal distribution. - + Args: field: Input field array to be standardized. mean: Mean values for standardization. If None, computed from field data. std: Standard deviation values for standardization. If None, computed from field data. - + Returns: Standardized field with approximately zero mean and unit variance. - + Raises: ZeroDivisionError: If std contains zeros. - + Examples: >>> field = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) >>> standardized = standardize(field, 3.0, np.sqrt(2.5)) >>> print(np.mean(standardized), np.std(standardized)) # ~0.0, ~1.0 - + >>> # Auto-compute mean/std >>> standardized = standardize(field) """ xp = array_type(field) - + if mean is None: mean = xp.mean(field, axis=0, keepdims=True) if std is None: std = xp.std(field, axis=0, keepdims=True) - + return (field - mean) / std @@ -206,18 +200,18 @@ def unstandardize( standardized_field: ArrayType, mean: ArrayType, std: ArrayType ) -> ArrayType: """Reverse the standardization process to recover original field values. - + Transforms standardized values (zero mean, unit variance) back to their original distribution using the stored mean and standard deviation. - + Args: standardized_field: Field values with zero mean and unit variance. mean: Mean values used in the original standardization. std: Standard deviation values used in the original standardization. - + Returns: Field values restored to their original distribution. - + Examples: >>> standardized = np.array([-1.26, -0.63, 0.0, 0.63, 1.26]) >>> original = unstandardize(standardized, 3.0, np.sqrt(2.5)) @@ -228,19 +222,19 @@ def unstandardize( def write_to_vtp(polydata: "vtk.vtkPolyData", filename: str) -> None: """Write VTK polydata to a VTP (VTK PolyData) file format. 
- + VTP files are XML-based and store polygonal data including points, polygons, and associated field data. This format is commonly used for surface meshes in computational fluid dynamics visualization. - + Args: polydata: VTK polydata object containing mesh geometry and fields. filename: Output filename with .vtp extension. Directory will be created if it doesn't exist. - + Raises: RuntimeError: If writing fails due to file permissions or disk space. - + Examples: >>> # Assuming you have a VTK polydata object >>> write_to_vtp(surface_mesh, "output/surface.vtp") @@ -248,31 +242,31 @@ def write_to_vtp(polydata: "vtk.vtkPolyData", filename: str) -> None: # Ensure output directory exists output_path = Path(filename) output_path.parent.mkdir(parents=True, exist_ok=True) - + writer = vtk.vtkXMLPolyDataWriter() writer.SetFileName(str(output_path)) writer.SetInputData(polydata) - + if not writer.Write(): raise RuntimeError(f"Failed to write polydata to {output_path}") def write_to_vtu(unstructured_grid: "vtk.vtkUnstructuredGrid", filename: str) -> None: """Write VTK unstructured grid to a VTU (VTK Unstructured Grid) file format. - + VTU files store 3D volumetric meshes with arbitrary cell types including tetrahedra, hexahedra, and pyramids. This format is essential for storing finite element analysis results. - + Args: unstructured_grid: VTK unstructured grid object containing volumetric mesh geometry and field data. filename: Output filename with .vtu extension. Directory will be created if it doesn't exist. - + Raises: RuntimeError: If writing fails due to file permissions or disk space. - + Examples: >>> # Assuming you have a VTK unstructured grid object >>> write_to_vtu(volume_mesh, "output/volume.vtu") @@ -280,32 +274,32 @@ def write_to_vtu(unstructured_grid: "vtk.vtkUnstructuredGrid", filename: str) -> # Ensure output directory exists output_path = Path(filename) output_path.parent.mkdir(parents=True, exist_ok=True) - + writer = vtk.vtkXMLUnstructuredGridWriter() writer.SetFileName(str(output_path)) writer.SetInputData(unstructured_grid) - + if not writer.Write(): raise RuntimeError(f"Failed to write unstructured grid to {output_path}") def extract_surface_triangles(tetrahedral_mesh: "vtk.vtkUnstructuredGrid") -> list[int]: """Extract surface triangle indices from a tetrahedral mesh. - + This function identifies the boundary faces of a 3D tetrahedral mesh and returns the vertex indices that form triangular faces on the surface. This is essential for visualization and boundary condition application. - + Args: tetrahedral_mesh: VTK unstructured grid containing tetrahedral elements. - + Returns: List of vertex indices forming surface triangles. Every three consecutive indices define one triangle. - + Raises: NotImplementedError: If the surface contains non-triangular faces. 
- + Examples: >>> # Extract surface from a tet mesh for visualization >>> surface_indices = extract_surface_triangles(tet_mesh) @@ -319,35 +313,37 @@ def extract_surface_triangles(tetrahedral_mesh: "vtk.vtkUnstructuredGrid") -> li # Wrap with PyVista for easier manipulation surface_mesh = pv.wrap(surface_filter.GetOutput()) triangle_indices = [] - + # Process faces - PyVista stores faces as [n_vertices, v1, v2, ..., vn] faces = surface_mesh.faces.reshape((-1, 4)) for face in faces: if face[0] == 3: # Triangle (3 vertices) triangle_indices.extend([face[1], face[2], face[3]]) else: - raise NotImplementedError(f"Non-triangular face found with {face[0]} vertices") + raise NotImplementedError( + f"Non-triangular face found with {face[0]} vertices" + ) return triangle_indices def convert_to_tet_mesh(polydata: "vtk.vtkPolyData") -> "vtk.vtkUnstructuredGrid": """Convert surface polydata to a tetrahedral volumetric mesh. - + This function performs tetrahedralization of a surface mesh, creating a 3D volumetric mesh suitable for finite element analysis. The process fills the interior of the surface with tetrahedral elements. - + Args: polydata: VTK polydata representing a closed surface mesh. - + Returns: VTK unstructured grid containing tetrahedral elements filling the volume enclosed by the input surface. - + Raises: RuntimeError: If tetrahedralization fails (e.g., non-manifold surface). - + Examples: >>> # Convert a surface mesh to volume mesh for FEM analysis >>> tet_mesh = convert_polydata_to_tetrahedral_mesh(surface_polydata) @@ -362,18 +358,18 @@ def convert_to_tet_mesh(polydata: "vtk.vtkPolyData") -> "vtk.vtkUnstructuredGrid def convert_point_data_to_cell_data(input_data: "vtk.vtkDataSet") -> "vtk.vtkDataSet": """Convert point-based field data to cell-based field data. - + This function transforms field variables defined at mesh vertices (nodes) to values defined at cell centers. This conversion is often needed when switching between different numerical methods or visualization requirements. - + Args: input_data: VTK dataset with point data to be converted. - + Returns: VTK dataset with the same geometry but field data moved from points to cells. Values are typically averaged from the surrounding points. - + Examples: >>> # Convert nodal pressure values to cell-centered values >>> cell_data = convert_point_data_to_cell_data(point_based_mesh) @@ -381,22 +377,22 @@ def convert_point_data_to_cell_data(input_data: "vtk.vtkDataSet") -> "vtk.vtkDat point_to_cell_filter = vtk.vtkPointDataToCellData() point_to_cell_filter.SetInputData(input_data) point_to_cell_filter.Update() - + return point_to_cell_filter.GetOutput() def get_node_to_elem(polydata: "vtk.vtkDataSet") -> "vtk.vtkDataSet": """Convert point data to cell data for VTK dataset. - + This function transforms field variables defined at mesh vertices to values defined at cell centers using VTK's built-in conversion filter. - + Args: polydata: VTK dataset with point data to be converted. - + Returns: VTK dataset with field data moved from points to cells. - + Examples: >>> cell_data = get_node_to_elem(point_based_mesh) """ @@ -407,24 +403,26 @@ def get_node_to_elem(polydata: "vtk.vtkDataSet") -> "vtk.vtkDataSet": return cell_data -def get_fields_from_cell(cell_data: "vtk.vtkCellData", variable_names: list[str]) -> np.ndarray: +def get_fields_from_cell( + cell_data: "vtk.vtkCellData", variable_names: list[str] +) -> np.ndarray: """Extract field variables from VTK cell data. 
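
# A point-to-cell conversion sketch, assuming pyvista for mesh construction;
# a scalar stored on the sphere's points comes back averaged onto its cells.
import numpy as np
import pyvista as pv
from physicsnemo.utils.domino.utils import convert_point_data_to_cell_data

mesh = pv.Sphere()
mesh["pressure"] = np.random.rand(mesh.n_points)      # point-based field
cell_mesh = pv.wrap(convert_point_data_to_cell_data(mesh))
print(cell_mesh.cell_data["pressure"].shape)          # (n_cells,)
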
- + This function extracts multiple field variables from VTK cell data and organizes them into a structured NumPy array. Each variable becomes a column in the output array. - + Args: cell_data: VTK cell data object containing field variables. variable_names: List of variable names to extract from the cell data. - + Returns: NumPy array of shape (n_cells, n_variables) containing the extracted field data. Variables are ordered according to the input list. - + Raises: ValueError: If a requested variable name is not found in the cell data. - + Examples: >>> # Extract pressure and velocity magnitude from cell data >>> variable_names = ["pressure", "velocity_magnitude"] @@ -436,7 +434,7 @@ def get_fields_from_cell(cell_data: "vtk.vtkCellData", variable_names: list[str] variable_array = cell_data.GetArray(variable_name) if variable_array is None: raise ValueError(f"Variable '{variable_name}' not found in cell data") - + num_tuples = variable_array.GetNumberOfTuples() field_values = [] for tuple_idx in range(num_tuples): @@ -444,31 +442,33 @@ def get_fields_from_cell(cell_data: "vtk.vtkCellData", variable_names: list[str] field_values.append(variable_value) field_values = np.asarray(field_values) extracted_fields.append(field_values) - + # Transpose to get shape (n_cells, n_variables) extracted_fields = np.transpose(np.asarray(extracted_fields), (1, 0)) return extracted_fields -def get_fields(data_attributes: "vtk.vtkDataSetAttributes", variable_names: list[str]) -> list[np.ndarray]: +def get_fields( + data_attributes: "vtk.vtkDataSetAttributes", variable_names: list[str] +) -> list[np.ndarray]: """Extract multiple field variables from VTK data attributes. - + This function extracts field variables from VTK data attributes (either point data or cell data) and returns them as a list of NumPy arrays. It handles both point and cell data seamlessly. - + Args: data_attributes: VTK data attributes object (point data or cell data). variable_names: List of variable names to extract. - + Returns: List of NumPy arrays, one for each requested variable. Each array has shape (n_points/n_cells, n_components) where n_components depends on the variable (1 for scalars, 3 for vectors, etc.). - + Raises: ValueError: If a requested variable is not found in the data attributes. - + Examples: >>> # Extract multiple variables from point data >>> point_data = mesh.GetPointData() @@ -481,30 +481,32 @@ def get_fields(data_attributes: "vtk.vtkDataSetAttributes", variable_names: list try: vtk_array = data_attributes.GetArray(variable_name) except ValueError as e: - raise ValueError(f"Failed to get array '{variable_name}' from the data attributes: {e}") - + raise ValueError( + f"Failed to get array '{variable_name}' from the data attributes: {e}" + ) + # Convert VTK array to NumPy array with proper shape numpy_array = numpy_support.vtk_to_numpy(vtk_array).reshape( vtk_array.GetNumberOfTuples(), vtk_array.GetNumberOfComponents() ) extracted_fields.append(numpy_array) - + return extracted_fields def get_vertices(polydata: "vtk.vtkPolyData") -> np.ndarray: """Extract vertex coordinates from VTK polydata object. - - This function converts VTK polydata to a NumPy array containing the 3D + + This function converts VTK polydata to a NumPy array containing the 3D coordinates of all vertices in the mesh. - + Args: polydata: VTK polydata object containing mesh geometry. - + Returns: NumPy array of shape (n_points, 3) containing [x, y, z] coordinates for each vertex. 
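
# A field-extraction sketch, assuming pyvista; get_fields reads named arrays
# from any vtkDataSetAttributes, and get_vertices returns the (n_points, 3)
# coordinates. The array names "p" and "u" here are arbitrary examples.
import numpy as np
import pyvista as pv
from physicsnemo.utils.domino.utils import get_fields, get_vertices

mesh = pv.Sphere()
mesh["p"] = np.zeros(mesh.n_points)           # scalar point field
mesh["u"] = np.zeros((mesh.n_points, 3))      # vector point field
p, u = get_fields(mesh.GetPointData(), ["p", "u"])
print(p.shape, u.shape, get_vertices(mesh).shape)  # (n,1) (n,3) (n,3)
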
- + Examples: >>> vertices = get_vertices(mesh) >>> print(vertices.shape) # (n_vertices, 3) @@ -514,22 +516,24 @@ def get_vertices(polydata: "vtk.vtkPolyData") -> np.ndarray: return vertices -def get_volume_data(polydata: "vtk.vtkPolyData", variable_names: list[str]) -> tuple[np.ndarray, list[np.ndarray]]: +def get_volume_data( + polydata: "vtk.vtkPolyData", variable_names: list[str] +) -> tuple[np.ndarray, list[np.ndarray]]: """Extract vertices and field data from 3D volumetric mesh. - + This function extracts both geometric information (vertex coordinates) and field data from a 3D volumetric mesh. It's commonly used for processing finite element analysis results. - + Args: polydata: VTK polydata representing a 3D volumetric mesh. variable_names: List of field variable names to extract. - + Returns: Tuple containing: - Vertex coordinates as NumPy array of shape (n_vertices, 3) - List of field arrays, one per variable - + Examples: >>> # Extract geometry and flow fields from CFD results >>> vertices, fields = get_volume_data(polydata, ["pressure", "velocity"]) @@ -538,30 +542,32 @@ def get_volume_data(polydata: "vtk.vtkPolyData", variable_names: list[str]) -> t vertices = get_vertices(polydata) point_data = polydata.GetPointData() fields = get_fields(point_data, variable_names) - + return vertices, fields -def get_surface_data(polydata: "vtk.vtkPolyData", variable_names: list[str]) -> tuple[np.ndarray, list[np.ndarray], list[tuple[int, int]]]: +def get_surface_data( + polydata: "vtk.vtkPolyData", variable_names: list[str] +) -> tuple[np.ndarray, list[np.ndarray], list[tuple[int, int]]]: """Extract surface mesh data including vertices, fields, and edge connectivity. - + This function extracts comprehensive surface mesh information including vertex coordinates, field data at vertices, and edge connectivity information. It's commonly used for processing CFD surface results and boundary conditions. - + Args: polydata: VTK polydata representing a surface mesh. variable_names: List of field variable names to extract from the mesh. - + Returns: Tuple containing: - Vertex coordinates as NumPy array of shape (n_vertices, 3) - List of field arrays, one per variable - List of edge tuples representing mesh connectivity - + Raises: ValueError: If a requested variable is not found or polygon data is missing. - + Examples: >>> # Extract surface mesh data for visualization >>> vertices, fields, edges = get_surface_data(surface_mesh, ["pressure", "shear_stress"]) @@ -608,35 +614,35 @@ def calculate_normal_positional_encoding( cell_dimensions: Sequence[float] = (1.0, 1.0, 1.0), ) -> ArrayType: """Calculate sinusoidal positional encoding for 3D coordinates. - + This function computes transformer-style positional encodings for 3D spatial coordinates, enabling neural networks to understand spatial relationships. The encoding uses sinusoidal functions at different frequencies to create unique representations for each spatial position. - + Args: coordinates_a: Primary coordinates array of shape (n_points, 3). coordinates_b: Optional secondary coordinates for computing relative positions. If provided, the encoding is computed for (coordinates_a - coordinates_b). cell_dimensions: Characteristic length scales for x, y, z dimensions used for normalization. Defaults to unit dimensions. - + Returns: Array of shape (n_points, 12) containing positional encodings with 4 encoding dimensions per spatial axis (x, y, z). 
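
# A shape check for the encoding described above: 4 sinusoidal components per
# axis yield a (n_points, 12) feature, assuming NumPy inputs.
import numpy as np
from physicsnemo.utils.domino.utils import calculate_normal_positional_encoding

coords = np.random.rand(8, 3)
enc = calculate_normal_positional_encoding(coords, cell_dimensions=(0.1, 0.1, 0.1))
assert enc.shape == (8, 12)
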
- + Examples: >>> coords = np.random.rand(100, 3) * 10 # Random 3D points >>> cell_size = [0.1, 0.1, 0.1] # Grid resolution >>> encoding = calculate_positional_encoding_for_coordinates(coords, cell_dimensions=cell_size) >>> print(encoding.shape) # (100, 12) - + >>> # For relative positions between two point sets >>> encoding_rel = calculate_positional_encoding_for_coordinates(coords_a, coords_b, cell_size) """ dx, dy, dz = cell_dimensions[0], cell_dimensions[1], cell_dimensions[2] xp = array_type(coordinates_a) - + if coordinates_b is not None: normals = coordinates_a - coordinates_b pos_x = xp.asarray(calculate_pos_encoding(normals[:, 0] / dx, d=4)) @@ -653,25 +659,27 @@ def calculate_normal_positional_encoding( return pos_normals -def nd_interpolator(coordinates: ArrayType, field: ArrayType, grid: ArrayType) -> ArrayType: +def nd_interpolator( + coordinates: ArrayType, field: ArrayType, grid: ArrayType +) -> ArrayType: """Perform n-dimensional interpolation using k-nearest neighbors. - + This function interpolates field values from scattered points to a regular grid using k-nearest neighbor averaging. It's useful for reconstructing fields on regular grids from irregular measurement points. - + Args: coordinates: Array of shape (n_points, n_dims) containing source point coordinates. field: Array of shape (n_points, n_fields) containing field values at source points. grid: Array containing target grid points for interpolation. - + Returns: Interpolated field values at grid points using 2-nearest neighbor averaging. - + Note: This function currently uses SciPy's KDTree which only supports CPU arrays. A future enhancement could add CuML support for GPU acceleration. - + Examples: >>> # Interpolate scattered CFD data to regular grid >>> grid_values = nd_interpolator(mesh_coords, pressure_field, regular_grid) @@ -687,20 +695,20 @@ def nd_interpolator(coordinates: ArrayType, field: ArrayType, grid: ArrayType) - def pad(arr: ArrayType, n_points: int, pad_value: float = 0.0) -> ArrayType: """Pad 2D array with constant values to reach target size. - + This function extends a 2D array by adding rows filled with a constant value. It's commonly used to standardize array sizes in batch processing for machine learning applications. - + Args: arr: Input array of shape (n_points, n_features) to be padded. n_points: Target number of points (rows) after padding. pad_value: Constant value used for padding. Defaults to 0.0. - + Returns: Padded array of shape (n_points, n_features). If n_points <= arr.shape[0], returns the original array unchanged. - + Examples: >>> arr = np.random.rand(100, 3) >>> padded = pad(arr, 150, -1.0) # Pad to 150 points with -1.0 @@ -709,7 +717,7 @@ def pad(arr: ArrayType, n_points: int, pad_value: float = 0.0) -> ArrayType: xp = array_type(arr) if n_points <= arr.shape[0]: return arr - + arr_pad = pad_value * xp.ones( (n_points - arr.shape[0], arr.shape[1]), dtype=xp.float32 ) @@ -719,20 +727,20 @@ def pad(arr: ArrayType, n_points: int, pad_value: float = 0.0) -> ArrayType: def pad_inp(arr: ArrayType, n_points: int, pad_value: float = 0.0) -> ArrayType: """Pad 3D array with constant values to reach target size. - + This function extends a 3D array by adding entries along the first dimension filled with a constant value. Used for standardizing 3D tensor sizes in batch processing workflows. - + Args: arr: Input array of shape (n_points, height, width) to be padded. n_points: Target number of points along first dimension after padding. pad_value: Constant value used for padding. 
Defaults to 0.0. - + Returns: Padded array of shape (n_points, height, width). If n_points <= arr.shape[0], returns the original array unchanged. - + Examples: >>> arr = np.random.rand(50, 10, 5) >>> padded = pad_inp(arr, 100, 0.0) # Pad to 100 entries @@ -741,7 +749,7 @@ def pad_inp(arr: ArrayType, n_points: int, pad_value: float = 0.0) -> ArrayType: xp = array_type(arr) if n_points <= arr.shape[0]: return arr - + arr_pad = pad_value * xp.ones( (n_points - arr.shape[0], arr.shape[1], arr.shape[2]), dtype=xp.float32 ) @@ -755,21 +763,21 @@ def shuffle_array( n_points: int, ) -> tuple[ArrayType, ArrayType]: """Randomly sample points from array without replacement. - + This function performs random sampling from the input array, selecting n_points points without replacement. It's commonly used for creating training subsets and data augmentation in machine learning workflows. - + Args: arr: Input array to sample from, shape (n_points, ...). n_points: Number of points to sample. If greater than arr.shape[0], all points are returned. - + Returns: Tuple containing: - Sampled array subset - Indices of the selected points - + Examples: >>> data = np.random.rand(1000, 3) >>> subset, indices = shuffle_array(data, 100) @@ -785,19 +793,19 @@ def shuffle_array( def shuffle_array_without_sampling(arr: ArrayType) -> tuple[ArrayType, ArrayType]: """Shuffle array order without changing the number of elements. - + This function reorders all elements in the array randomly while preserving all data points. It's useful for randomizing data order before training while maintaining the complete dataset. - + Args: arr: Input array to shuffle, shape (n_points, ...). - + Returns: Tuple containing: - Shuffled array with same shape as input - Permutation indices used for shuffling - + Examples: >>> data = np.arange(100).reshape(100, 1) >>> shuffled, indices = shuffle_array_without_sampling(data) @@ -811,13 +819,13 @@ def shuffle_array_without_sampling(arr: ArrayType) -> tuple[ArrayType, ArrayType def create_directory(filepath: str | Path) -> None: """Create directory and all necessary parent directories. - + This function creates a directory at the specified path, including any necessary parent directories. It's equivalent to `mkdir -p` in Unix systems. - + Args: filepath: Path to the directory to create. Can be string or Path object. - + Examples: >>> create_directory("data/processed/meshes") >>> create_directory(Path("output") / "results" / "visualization") @@ -827,23 +835,23 @@ def create_directory(filepath: str | Path) -> None: def get_filenames(filepath: str | Path, exclude_dirs: bool = False) -> list[str]: """Get list of filenames in a directory with optional directory filtering. - + This function returns all items in a directory, with options to exclude subdirectories. It handles special cases like .zarr directories which are treated as files even when exclude_dirs is True. - + Args: filepath: Path to the directory to list. Can be string or Path object. exclude_dirs: If True, exclude subdirectories from results. Exception: .zarr directories are always included as they represent data files in array storage format. - + Returns: List of filenames/directory names found in the specified directory. - + Raises: FileNotFoundError: If the specified directory does not exist. 
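
# A batching sketch combining the sampling and padding helpers above: draw a
# fixed-size subset, then pad rows up to the batch size. NumPy-only; the batch
# sizes are arbitrary examples.
import numpy as np
from physicsnemo.utils.domino.utils import pad, shuffle_array

points = np.random.rand(500, 3)
subset, idx = shuffle_array(points, 128)    # random subset plus its indices
batch = pad(subset, 256, pad_value=-10.0)   # pad rows to a fixed batch size
assert batch.shape == (256, 3)
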
- + Examples: >>> files = get_filenames("data/inputs") >>> data_files = get_filenames("results", exclude_dirs=True) @@ -851,7 +859,7 @@ def get_filenames(filepath: str | Path, exclude_dirs: bool = False) -> list[str] path = Path(filepath) if not path.exists(): raise FileNotFoundError(f"Directory {filepath} does not exist") - + filenames = [] for item in path.iterdir(): if exclude_dirs and item.is_dir(): @@ -865,19 +873,19 @@ def get_filenames(filepath: str | Path, exclude_dirs: bool = False) -> list[str] def calculate_pos_encoding(nx: ArrayType, d: int = 8) -> list[ArrayType]: """Calculate sinusoidal positional encoding for transformer architectures. - + This function computes positional encodings using alternating sine and cosine functions at different frequencies. These encodings help neural networks understand positional relationships in sequences or spatial data. - + Args: nx: Input positions/coordinates to encode. d: Encoding dimensionality. Must be even number. Defaults to 8. - + Returns: List of d arrays containing alternating sine and cosine encodings. Each pair (sin, cos) uses progressively lower frequencies. - + Examples: >>> positions = np.linspace(0, 100, 50) >>> encodings = calculate_pos_encoding(positions, d=16) @@ -893,22 +901,22 @@ def calculate_pos_encoding(nx: ArrayType, d: int = 8) -> list[ArrayType]: def combine_dict(old_dict: dict[Any, Any], new_dict: dict[Any, Any]) -> dict[Any, Any]: """Combine two dictionaries by adding values for matching keys. - + This function performs element-wise addition of dictionary values for keys that exist in both dictionaries. It's commonly used for accumulating statistics or metrics across multiple iterations. - + Args: old_dict: Base dictionary to update. new_dict: Dictionary with values to add to old_dict. - + Returns: Updated old_dict with combined values. - + Note: This function modifies old_dict in place and returns it. Values must support the + operator. - + Examples: >>> stats1 = {"loss": 0.5, "accuracy": 0.8} >>> stats2 = {"loss": 0.3, "accuracy": 0.1} @@ -920,22 +928,24 @@ def combine_dict(old_dict: dict[Any, Any], new_dict: dict[Any, Any]) -> dict[Any return old_dict -def create_grid(max_coords: ArrayType, min_coords: ArrayType, resolution: ArrayType) -> ArrayType: +def create_grid( + max_coords: ArrayType, min_coords: ArrayType, resolution: ArrayType +) -> ArrayType: """Create a 3D regular grid from coordinate bounds and resolution. - + This function generates a regular 3D grid spanning from min_coords to max_coords with the specified resolution in each dimension. The resulting grid is commonly used for interpolation, visualization, and regular sampling. - + Args: max_coords: Maximum coordinates [x_max, y_max, z_max] for the grid bounds. min_coords: Minimum coordinates [x_min, y_min, z_min] for the grid bounds. resolution: Number of grid points [nx, ny, nz] in each dimension. - + Returns: Grid array of shape (nx, ny, nz, 3) containing 3D coordinates for each grid point. The last dimension contains [x, y, z] coordinates. 
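
# A NumPy-only sketch of the transformer-style frequency ladder the docstring
# above references; sketch_pos_encoding is a hypothetical reimplementation
# under the standard 10000**(2k/d) recipe, not necessarily the module's exact
# formula.
import numpy as np

def sketch_pos_encoding(nx: np.ndarray, d: int = 8) -> list[np.ndarray]:
    vec = []
    for k in range(d // 2):
        denom = 10000 ** (2 * k / d)       # wavelength grows with k
        vec.append(np.sin(nx / denom))
        vec.append(np.cos(nx / denom))
    return vec

enc = sketch_pos_encoding(np.linspace(0.0, 1.0, 5), d=8)
assert len(enc) == 8 and enc[0].shape == (5,)
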
- + Examples: >>> # Create a 10x10x10 grid from (0,0,0) to (1,1,1) >>> min_bounds = np.array([0, 0, 0]) @@ -946,9 +956,15 @@ def create_grid(max_coords: ArrayType, min_coords: ArrayType, resolution: ArrayT """ xp = array_type(max_coords) - dx = xp.linspace(min_coords[0], max_coords[0], resolution[0], dtype=max_coords.dtype) - dy = xp.linspace(min_coords[1], max_coords[1], resolution[1], dtype=max_coords.dtype) - dz = xp.linspace(min_coords[2], max_coords[2], resolution[2], dtype=max_coords.dtype) + dx = xp.linspace( + min_coords[0], max_coords[0], resolution[0], dtype=max_coords.dtype + ) + dy = xp.linspace( + min_coords[1], max_coords[1], resolution[1], dtype=max_coords.dtype + ) + dz = xp.linspace( + min_coords[2], max_coords[2], resolution[2], dtype=max_coords.dtype + ) xv, yv, zv = xp.meshgrid(dx, dy, dz) xv = xp.expand_dims(xv, -1) @@ -964,21 +980,21 @@ def mean_std_sampling( field: ArrayType, mean: ArrayType, std: ArrayType, tolerance: float = 3.0 ) -> list[int]: """Identify outlier points based on statistical distance from mean. - + This function identifies data points that are statistical outliers by checking if they fall outside mean ± tolerance*std for any field component. It's useful for data cleaning and identifying regions of interest in CFD data. - + Args: field: Input field array of shape (n_points, n_components). mean: Mean values for each field component, shape (n_components,). std: Standard deviation for each component, shape (n_components,). tolerance: Number of standard deviations to use as outlier threshold. Defaults to 3.0 (99.7% of normal distribution). - + Returns: List of indices identifying outlier points that exceed the statistical threshold. - + Examples: >>> # Find outliers in pressure field >>> pressure_data = np.random.normal(100, 10, (1000, 1)) @@ -1000,23 +1016,25 @@ def mean_std_sampling( return idx_all -def dict_to_device(state_dict: dict[str, Any], device: Any, exclude_keys: list[str] | None = None) -> dict[str, Any]: +def dict_to_device( + state_dict: dict[str, Any], device: Any, exclude_keys: list[str] | None = None +) -> dict[str, Any]: """Move dictionary values to specified device (GPU/CPU). - + This function transfers PyTorch tensors in a dictionary to the specified compute device while preserving the dictionary structure. It's commonly used for moving model parameters and data between CPU and GPU. - + Args: state_dict: Dictionary containing tensors and other values. device: Target device (e.g., torch.device('cuda:0')). exclude_keys: List of keys to skip during device transfer. Defaults to ["filename"] if None. - + Returns: New dictionary with tensors moved to the specified device. Non-tensor values and excluded keys are preserved as-is. - + Examples: >>> import torch >>> data = {"weights": torch.randn(10, 10), "filename": "model.pt"} @@ -1024,7 +1042,7 @@ def dict_to_device(state_dict: dict[str, Any], device: Any, exclude_keys: list[s """ if exclude_keys is None: exclude_keys = ["filename"] - + new_state_dict = {} for k, v in state_dict.items(): if k not in exclude_keys: @@ -1036,28 +1054,28 @@ def area_weighted_shuffle_array( arr: ArrayType, n_points: int, area: ArrayType ) -> tuple[ArrayType, ArrayType]: """Perform area-weighted random sampling from array. - + This function samples points from an array with probability proportional to their associated area weights. This is particularly useful in CFD applications where larger cells or surface elements should have higher sampling probability. 
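
# An outlier-filtering sketch for mean_std_sampling above: plant one obvious
# outlier, detect indices beyond 3 sigma, and drop them. NumPy-only.
import numpy as np
from physicsnemo.utils.domino.utils import mean_std_sampling

field = np.random.normal(0.0, 1.0, (1000, 1))
field[0] = 50.0                                  # plant an obvious outlier
outliers = mean_std_sampling(field, np.array([0.0]), np.array([1.0]), 3.0)
assert 0 in outliers
clean = np.delete(field, outliers, axis=0)       # filtered field
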
- + Args: arr: Input array to sample from, shape (n_points, ...). n_points: Number of points to sample. If greater than arr.shape[0], samples all available points. area: Area weights for each point, shape (n_points,). Larger values indicate higher sampling probability. - + Returns: Tuple containing: - Sampled array subset weighted by area - Indices of the selected points - + Note: For GPU arrays (CuPy), the sampling is performed on CPU due to memory efficiency considerations. The Alias method could be implemented for future GPU acceleration. - + Examples: >>> # Sample mesh points with area weighting >>> mesh_data = np.random.rand(1000, 3) From 57dba919e7e06feb6748f1034dfc13acb4fcb694 Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Fri, 20 Jun 2025 08:33:09 -0400 Subject: [PATCH 07/19] Fixes import order --- physicsnemo/utils/domino/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/physicsnemo/utils/domino/utils.py b/physicsnemo/utils/domino/utils.py index 7fcbeb5d0e..230ef2058e 100644 --- a/physicsnemo/utils/domino/utils.py +++ b/physicsnemo/utils/domino/utils.py @@ -25,13 +25,13 @@ from pathlib import Path from typing import Any, Sequence -import numpy as np import cupy as cp +import numpy as np import pyvista as pv import vtk +from scipy.spatial import KDTree from vtk import vtkDataSetTriangleFilter from vtk.util import numpy_support -from scipy.spatial import KDTree from physicsnemo.utils.profiling import profile From 5fc6d963ad95de6ebc856cd7174eb20ad2387d0f Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Fri, 20 Jun 2025 08:49:32 -0400 Subject: [PATCH 08/19] black formatting --- .../domino/src/inference_on_stl.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/cfd/external_aerodynamics/domino/src/inference_on_stl.py b/examples/cfd/external_aerodynamics/domino/src/inference_on_stl.py index 6a9701df2a..b76666a7ff 100644 --- a/examples/cfd/external_aerodynamics/domino/src/inference_on_stl.py +++ b/examples/cfd/external_aerodynamics/domino/src/inference_on_stl.py @@ -35,7 +35,13 @@ import torch from physicsnemo.models.domino.model import DoMINO -from physicsnemo.utils.domino.utils import unnormalize, create_directory, nd_interpolator, get_filenames, write_to_vtp +from physicsnemo.utils.domino.utils import ( + unnormalize, + create_directory, + nd_interpolator, + get_filenames, + write_to_vtp, +) from torch.cuda.amp import autocast from torch.nn.parallel import DistributedDataParallel from physicsnemo.distributed import DistributedManager @@ -663,7 +669,7 @@ def __init__( self, cfg: DictConfig, dist: None, - cached_geo_encoding: bool=False, + cached_geo_encoding: bool = False, ): self.cfg = cfg From d0fd16b93a84c2120af52fe5bf24ecc3a1d5088e Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Fri, 20 Jun 2025 10:31:16 -0400 Subject: [PATCH 09/19] Reshape accepts a shape, not a splatted iterable --- physicsnemo/datapipes/cae/domino_datapipe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/physicsnemo/datapipes/cae/domino_datapipe.py b/physicsnemo/datapipes/cae/domino_datapipe.py index 159ab0f5a2..ad83331f24 100644 --- a/physicsnemo/datapipes/cae/domino_datapipe.py +++ b/physicsnemo/datapipes/cae/domino_datapipe.py @@ -843,7 +843,7 @@ def preprocess_volume( mesh_indices_flattened, grid_reshaped, use_sign_winding_number=True, - ).reshape(nx, ny, nz) + ).reshape((nx, ny, nz)) if self.config.sampling: volume_coordinates_sampled, idx_volume = shuffle_array( From f41a1bf32f01ae5c0796d6f582e6448619ada12c Mon Sep 
17 00:00:00 2001 From: Peter Sharpe Date: Fri, 20 Jun 2025 10:33:28 -0400 Subject: [PATCH 10/19] Fixes lost array axis --- physicsnemo/utils/domino/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/physicsnemo/utils/domino/utils.py b/physicsnemo/utils/domino/utils.py index 230ef2058e..4352a34593 100644 --- a/physicsnemo/utils/domino/utils.py +++ b/physicsnemo/utils/domino/utils.py @@ -73,7 +73,7 @@ def calculate_center_of_mass(centers: ArrayType, sizes: ArrayType) -> ArrayType: or area of each element used as weights. Returns: - Array of shape (3,) containing the x, y, z coordinates of the center of mass. + Array of shape (1, 3) containing the x, y, z coordinates of the center of mass. Raises: ValueError: If centers and sizes have incompatible shapes. @@ -82,14 +82,14 @@ def calculate_center_of_mass(centers: ArrayType, sizes: ArrayType) -> ArrayType: >>> centers = np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2]]) >>> sizes = np.array([1.0, 2.0, 3.0]) >>> com = calculate_center_of_mass(centers, sizes) - >>> print(com) # [1.5, 1.5, 1.5] + >>> print(com) # [[1.5, 1.5, 1.5]] """ xp = array_type(centers) - total_weighted_position = xp.einsum("i,ij->ij", sizes, centers) + total_weighted_position = xp.einsum("i,ij->j", sizes, centers) total_size = xp.sum(sizes) - return total_weighted_position / total_size + return total_weighted_position[None, ...] / total_size def normalize( From d2167a92644e0b8f591c6eec36cf9883b7dc247e Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Fri, 20 Jun 2025 11:42:27 -0400 Subject: [PATCH 11/19] Enhances docstrings in utils.py with examples and improved clarity; removes outdated examples. --- physicsnemo/utils/domino/utils.py | 226 +++++++++++++++++------------- 1 file changed, 128 insertions(+), 98 deletions(-) diff --git a/physicsnemo/utils/domino/utils.py b/physicsnemo/utils/domino/utils.py index 4352a34593..a7de84e7f7 100644 --- a/physicsnemo/utils/domino/utils.py +++ b/physicsnemo/utils/domino/utils.py @@ -79,10 +79,12 @@ def calculate_center_of_mass(centers: ArrayType, sizes: ArrayType) -> ArrayType: ValueError: If centers and sizes have incompatible shapes. Examples: - >>> centers = np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2]]) + >>> import numpy as np + >>> centers = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]) >>> sizes = np.array([1.0, 2.0, 3.0]) >>> com = calculate_center_of_mass(centers, sizes) - >>> print(com) # [[1.5, 1.5, 1.5]] + >>> np.allclose(com, [[1.5, 1.5, 1.5]]) + True """ xp = array_type(centers) @@ -115,12 +117,15 @@ def normalize( ZeroDivisionError: If max_val equals min_val (zero range). Examples: + >>> import numpy as np >>> field = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) >>> normalized = normalize(field, 5.0, 1.0) - >>> print(normalized) # [-1, -0.5, 0, 0.5, 1] - + >>> np.allclose(normalized, [-1.0, -0.5, 0.0, 0.5, 1.0]) + True >>> # Auto-compute min/max - >>> normalized = normalize(field) + >>> normalized_auto = normalize(field) + >>> np.allclose(normalized_auto, [-1.0, -0.5, 0.0, 0.5, 1.0]) + True """ xp = array_type(field) @@ -150,9 +155,11 @@ def unnormalize( Field values restored to their original physical range. 
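
# A quick check of the einsum contraction in the fix above: "i,ij->j" reduces
# the size-weighted centers over elements, whereas the earlier "i,ij->ij" only
# scaled rows without summing. Pure NumPy.
import numpy as np

centers = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])
sizes = np.array([1.0, 2.0, 3.0])
weighted_sum = np.einsum("i,ij->j", sizes, centers)   # shape (3,)
com = weighted_sum[None, ...] / np.sum(sizes)         # shape (1, 3)
assert np.allclose(com, [[4.0 / 3.0] * 3])
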
Examples: + >>> import numpy as np >>> normalized = np.array([-1.0, -0.5, 0.0, 0.5, 1.0]) >>> original = unnormalize(normalized, 5.0, 1.0) - >>> print(original) # [1, 2, 3, 4, 5] + >>> np.allclose(original, [1.0, 2.0, 3.0, 4.0, 5.0]) + True """ field_range = max_val - min_val return (normalized_field + 1.0) * field_range * 0.5 + min_val @@ -179,12 +186,17 @@ def standardize( ZeroDivisionError: If std contains zeros. Examples: + >>> import numpy as np >>> field = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) >>> standardized = standardize(field, 3.0, np.sqrt(2.5)) - >>> print(np.mean(standardized), np.std(standardized)) # ~0.0, ~1.0 - + >>> np.allclose(standardized, [-1.265, -0.632, 0.0, 0.632, 1.265], atol=1e-3) + True >>> # Auto-compute mean/std - >>> standardized = standardize(field) + >>> standardized_auto = standardize(field) + >>> np.abs(np.mean(standardized_auto)) < 1e-15 + True + >>> np.allclose(np.std(standardized_auto, ddof=0), 1.0) + True """ xp = array_type(field) @@ -213,9 +225,11 @@ def unstandardize( Field values restored to their original distribution. Examples: - >>> standardized = np.array([-1.26, -0.63, 0.0, 0.63, 1.26]) + >>> import numpy as np + >>> standardized = np.array([-1.265, -0.632, 0.0, 0.632, 1.265]) >>> original = unstandardize(standardized, 3.0, np.sqrt(2.5)) - >>> print(original) # approximately [1, 2, 3, 4, 5] + >>> np.allclose(original, [1.0, 2.0, 3.0, 4.0, 5.0], atol=1e-3) + True """ return standardized_field * std + mean @@ -235,9 +249,6 @@ def write_to_vtp(polydata: "vtk.vtkPolyData", filename: str) -> None: Raises: RuntimeError: If writing fails due to file permissions or disk space. - Examples: - >>> # Assuming you have a VTK polydata object - >>> write_to_vtp(surface_mesh, "output/surface.vtp") """ # Ensure output directory exists output_path = Path(filename) @@ -267,9 +278,6 @@ def write_to_vtu(unstructured_grid: "vtk.vtkUnstructuredGrid", filename: str) -> Raises: RuntimeError: If writing fails due to file permissions or disk space. - Examples: - >>> # Assuming you have a VTK unstructured grid object - >>> write_to_vtu(volume_mesh, "output/volume.vtu") """ # Ensure output directory exists output_path = Path(filename) @@ -300,10 +308,6 @@ def extract_surface_triangles(tetrahedral_mesh: "vtk.vtkUnstructuredGrid") -> li Raises: NotImplementedError: If the surface contains non-triangular faces. - Examples: - >>> # Extract surface from a tet mesh for visualization - >>> surface_indices = extract_surface_triangles(tet_mesh) - >>> # surface_indices = [v1, v2, v3, v4, v5, v6, ...] for triangles """ # Extract the surface using VTK filter surface_filter = vtk.vtkDataSetSurfaceFilter() @@ -343,10 +347,7 @@ def convert_to_tet_mesh(polydata: "vtk.vtkPolyData") -> "vtk.vtkUnstructuredGrid Raises: RuntimeError: If tetrahedralization fails (e.g., non-manifold surface). - - Examples: - >>> # Convert a surface mesh to volume mesh for FEM analysis - >>> tet_mesh = convert_polydata_to_tetrahedral_mesh(surface_polydata) + """ tetrahedral_filter = vtkDataSetTriangleFilter() tetrahedral_filter.SetInputData(polydata) @@ -370,9 +371,6 @@ def convert_point_data_to_cell_data(input_data: "vtk.vtkDataSet") -> "vtk.vtkDat VTK dataset with the same geometry but field data moved from points to cells. Values are typically averaged from the surrounding points. 
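
# A round-trip sketch matching the standardize/unstandardize doctests above,
# assuming NumPy inputs.
import numpy as np
from physicsnemo.utils.domino.utils import standardize, unstandardize

field = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
z = standardize(field, 3.0, np.sqrt(2.5))
assert np.allclose(unstandardize(z, 3.0, np.sqrt(2.5)), field)
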
- Examples: - >>> # Convert nodal pressure values to cell-centered values - >>> cell_data = convert_point_data_to_cell_data(point_based_mesh) """ point_to_cell_filter = vtk.vtkPointDataToCellData() point_to_cell_filter.SetInputData(input_data) @@ -393,8 +391,6 @@ def get_node_to_elem(polydata: "vtk.vtkDataSet") -> "vtk.vtkDataSet": Returns: VTK dataset with field data moved from points to cells. - Examples: - >>> cell_data = get_node_to_elem(point_based_mesh) """ point_to_cell_filter = vtk.vtkPointDataToCellData() point_to_cell_filter.SetInputData(polydata) @@ -423,11 +419,6 @@ def get_fields_from_cell( Raises: ValueError: If a requested variable name is not found in the cell data. - Examples: - >>> # Extract pressure and velocity magnitude from cell data - >>> variable_names = ["pressure", "velocity_magnitude"] - >>> fields = get_fields_from_cell(cell_data, variable_names) - >>> print(fields.shape) # (n_cells, 2) """ extracted_fields = [] for variable_name in variable_names: @@ -469,12 +460,6 @@ def get_fields( Raises: ValueError: If a requested variable is not found in the data attributes. - Examples: - >>> # Extract multiple variables from point data - >>> point_data = mesh.GetPointData() - >>> variable_names = ["pressure", "velocity", "temperature"] - >>> fields = get_fields(point_data, variable_names) - >>> print(len(fields)) # 3 arrays """ extracted_fields = [] for variable_name in variable_names: @@ -507,9 +492,6 @@ def get_vertices(polydata: "vtk.vtkPolyData") -> np.ndarray: NumPy array of shape (n_points, 3) containing [x, y, z] coordinates for each vertex. - Examples: - >>> vertices = get_vertices(mesh) - >>> print(vertices.shape) # (n_vertices, 3) """ vtk_points = polydata.GetPoints() vertices = numpy_support.vtk_to_numpy(vtk_points.GetData()) @@ -534,10 +516,6 @@ def get_volume_data( - Vertex coordinates as NumPy array of shape (n_vertices, 3) - List of field arrays, one per variable - Examples: - >>> # Extract geometry and flow fields from CFD results - >>> vertices, fields = get_volume_data(polydata, ["pressure", "velocity"]) - >>> print(vertices.shape, len(fields)) # (n_vertices, 3), 2 """ vertices = get_vertices(polydata) point_data = polydata.GetPointData() @@ -568,10 +546,6 @@ def get_surface_data( Raises: ValueError: If a requested variable is not found or polygon data is missing. - Examples: - >>> # Extract surface mesh data for visualization - >>> vertices, fields, edges = get_surface_data(surface_mesh, ["pressure", "shear_stress"]) - >>> print(vertices.shape, len(fields), len(edges)) """ points = polydata.GetPoints() vertices = np.array([points.GetPoint(i) for i in range(points.GetNumberOfPoints())]) @@ -632,13 +606,17 @@ def calculate_normal_positional_encoding( 4 encoding dimensions per spatial axis (x, y, z). 
Examples: - >>> coords = np.random.rand(100, 3) * 10 # Random 3D points - >>> cell_size = [0.1, 0.1, 0.1] # Grid resolution - >>> encoding = calculate_positional_encoding_for_coordinates(coords, cell_dimensions=cell_size) - >>> print(encoding.shape) # (100, 12) - - >>> # For relative positions between two point sets - >>> encoding_rel = calculate_positional_encoding_for_coordinates(coords_a, coords_b, cell_size) + >>> import numpy as np + >>> coords = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]) + >>> cell_size = [0.1, 0.1, 0.1] + >>> encoding = calculate_normal_positional_encoding(coords, cell_dimensions=cell_size) + >>> encoding.shape + (2, 12) + >>> # Relative positioning example + >>> coords_b = np.array([[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]) + >>> encoding_rel = calculate_normal_positional_encoding(coords, coords_b, cell_size) + >>> encoding_rel.shape + (2, 12) """ dx, dy, dz = cell_dimensions[0], cell_dimensions[1], cell_dimensions[2] xp = array_type(coordinates_a) @@ -681,8 +659,14 @@ def nd_interpolator( A future enhancement could add CuML support for GPU acceleration. Examples: - >>> # Interpolate scattered CFD data to regular grid - >>> grid_values = nd_interpolator(mesh_coords, pressure_field, regular_grid) + >>> import numpy as np + >>> # Simple 2D interpolation example + >>> coords = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + >>> field_vals = np.array([[1.0], [2.0], [3.0], [4.0]]) + >>> grid_points = np.array([[0.5, 0.5]]) + >>> result = nd_interpolator([coords], field_vals, grid_points) + >>> result.shape[0] == 1 # One grid point + True """ # TODO - this function should get updated for cuml if using cupy. interp_func = KDTree(coordinates[0]) @@ -710,9 +694,19 @@ def pad(arr: ArrayType, n_points: int, pad_value: float = 0.0) -> ArrayType: returns the original array unchanged. Examples: - >>> arr = np.random.rand(100, 3) - >>> padded = pad(arr, 150, -1.0) # Pad to 150 points with -1.0 - >>> print(padded.shape) # (150, 3) + >>> import numpy as np + >>> arr = np.array([[1.0, 2.0], [3.0, 4.0]]) + >>> padded = pad(arr, 4, -1.0) + >>> padded.shape + (4, 2) + >>> np.array_equal(padded[:2], arr) + True + >>> np.all(padded[2:] == -1.0) + True + >>> # No padding needed + >>> same = pad(arr, 2) + >>> np.array_equal(same, arr) + True """ xp = array_type(arr) if n_points <= arr.shape[0]: @@ -742,9 +736,15 @@ def pad_inp(arr: ArrayType, n_points: int, pad_value: float = 0.0) -> ArrayType: returns the original array unchanged. 
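
# A k-nearest-neighbour averaging sketch for nd_interpolator, assuming the
# axis-1 mean from the later fix in this series; the interpolated value at the
# centre must lie between the corner values.
import numpy as np
from physicsnemo.utils.domino.utils import nd_interpolator

coords = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
vals = np.array([[1.0], [2.0], [3.0], [4.0]])
out = nd_interpolator([coords], vals, np.array([[0.5, 0.5]]))
assert 1.0 <= float(out[0, 0]) <= 4.0
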
Examples: - >>> arr = np.random.rand(50, 10, 5) - >>> padded = pad_inp(arr, 100, 0.0) # Pad to 100 entries - >>> print(padded.shape) # (100, 10, 5) + >>> import numpy as np + >>> arr = np.array([[[1.0, 2.0]], [[3.0, 4.0]]]) + >>> padded = pad_inp(arr, 4, 0.0) + >>> padded.shape + (4, 1, 2) + >>> np.array_equal(padded[:2], arr) + True + >>> np.all(padded[2:] == 0.0) + True """ xp = array_type(arr) if n_points <= arr.shape[0]: @@ -779,9 +779,16 @@ def shuffle_array( - Indices of the selected points Examples: - >>> data = np.random.rand(1000, 3) - >>> subset, indices = shuffle_array(data, 100) - >>> print(subset.shape, indices.shape) # (100, 3), (100,) + >>> import numpy as np + >>> np.random.seed(42) # For reproducible results + >>> data = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + >>> subset, indices = shuffle_array(data, 2) + >>> subset.shape + (2, 2) + >>> indices.shape + (2,) + >>> len(np.unique(indices)) == 2 # No duplicates + True """ xp = array_type(arr) if n_points > arr.shape[0]: @@ -807,9 +814,16 @@ def shuffle_array_without_sampling(arr: ArrayType) -> tuple[ArrayType, ArrayType - Permutation indices used for shuffling Examples: - >>> data = np.arange(100).reshape(100, 1) + >>> import numpy as np + >>> np.random.seed(42) # For reproducible results + >>> data = np.array([[1], [2], [3], [4]]) >>> shuffled, indices = shuffle_array_without_sampling(data) - >>> print(shuffled.shape) # (100, 1) - same size, different order + >>> shuffled.shape + (4, 1) + >>> indices.shape + (4,) + >>> set(indices) == set(range(4)) # All original indices present + True """ xp = array_type(arr) idx = xp.arange(arr.shape[0]) @@ -826,9 +840,6 @@ def create_directory(filepath: str | Path) -> None: Args: filepath: Path to the directory to create. Can be string or Path object. - Examples: - >>> create_directory("data/processed/meshes") - >>> create_directory(Path("output") / "results" / "visualization") """ Path(filepath).mkdir(parents=True, exist_ok=True) @@ -852,9 +863,6 @@ def get_filenames(filepath: str | Path, exclude_dirs: bool = False) -> list[str] Raises: FileNotFoundError: If the specified directory does not exist. - Examples: - >>> files = get_filenames("data/inputs") - >>> data_files = get_filenames("results", exclude_dirs=True) """ path = Path(filepath) if not path.exists(): @@ -887,9 +895,13 @@ def calculate_pos_encoding(nx: ArrayType, d: int = 8) -> list[ArrayType]: Each pair (sin, cos) uses progressively lower frequencies. Examples: - >>> positions = np.linspace(0, 100, 50) - >>> encodings = calculate_pos_encoding(positions, d=16) - >>> print(len(encodings)) # 16 encoding dimensions + >>> import numpy as np + >>> positions = np.array([0.0, 1.0, 2.0]) + >>> encodings = calculate_pos_encoding(positions, d=4) + >>> len(encodings) + 4 + >>> all(enc.shape == (3,) for enc in encodings) + True """ vec = [] xp = array_type(nx) @@ -921,7 +933,10 @@ def combine_dict(old_dict: dict[Any, Any], new_dict: dict[Any, Any]) -> dict[Any >>> stats1 = {"loss": 0.5, "accuracy": 0.8} >>> stats2 = {"loss": 0.3, "accuracy": 0.1} >>> combined = combine_dict(stats1, stats2) - >>> print(combined) # {"loss": 0.8, "accuracy": 0.9} + >>> combined["loss"] + 0.8 + >>> combined["accuracy"] + 0.8999999999999999 """ for key in old_dict.keys(): old_dict[key] += new_dict[key] @@ -947,12 +962,17 @@ def create_grid( grid point. The last dimension contains [x, y, z] coordinates. 
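
# An accumulation sketch for combine_dict above: running totals across two
# hypothetical batch dictionaries (in-place addition per key).
from physicsnemo.utils.domino.utils import combine_dict

totals = {"loss": 0.0, "n_samples": 0}
for batch_stats in ({"loss": 0.5, "n_samples": 8}, {"loss": 0.3, "n_samples": 8}):
    totals = combine_dict(totals, batch_stats)
assert totals == {"loss": 0.8, "n_samples": 16}
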
Examples: - >>> # Create a 10x10x10 grid from (0,0,0) to (1,1,1) - >>> min_bounds = np.array([0, 0, 0]) - >>> max_bounds = np.array([1, 1, 1]) - >>> grid_res = np.array([10, 10, 10]) + >>> import numpy as np + >>> min_bounds = np.array([0.0, 0.0, 0.0]) + >>> max_bounds = np.array([1.0, 1.0, 1.0]) + >>> grid_res = np.array([2, 2, 2]) >>> grid = create_grid(max_bounds, min_bounds, grid_res) - >>> print(grid.shape) # (10, 10, 10, 3) + >>> grid.shape + (2, 2, 2, 3) + >>> np.allclose(grid[0, 0, 0], [0.0, 0.0, 0.0]) + True + >>> np.allclose(grid[1, 1, 1], [1.0, 1.0, 1.0]) + True """ xp = array_type(max_coords) @@ -996,12 +1016,14 @@ def mean_std_sampling( List of indices identifying outlier points that exceed the statistical threshold. Examples: - >>> # Find outliers in pressure field - >>> pressure_data = np.random.normal(100, 10, (1000, 1)) - >>> pressure_mean = np.array([100.0]) - >>> pressure_std = np.array([10.0]) - >>> outliers = mean_std_sampling(pressure_data, pressure_mean, pressure_std, 2.5) - >>> print(f"Found {len(outliers)} outlier points") + >>> import numpy as np + >>> # Create test data with outliers + >>> field = np.array([[1.0], [2.0], [3.0], [10.0]]) # 10.0 is outlier + >>> field_mean = np.array([2.0]) + >>> field_std = np.array([1.0]) + >>> outliers = mean_std_sampling(field, field_mean, field_std, 2.0) + >>> 3 in outliers # Index 3 (value 10.0) should be detected as outlier + True """ xp = array_type(field) idx_all = [] @@ -1077,11 +1099,18 @@ def area_weighted_shuffle_array( future GPU acceleration. Examples: - >>> # Sample mesh points with area weighting - >>> mesh_data = np.random.rand(1000, 3) - >>> cell_areas = np.random.exponential(1.0, 1000) # Area weights - >>> subset, indices = area_weighted_shuffle_array(mesh_data, 100, cell_areas) - >>> print(subset.shape) # (100, 3) - larger cells more likely selected + >>> import numpy as np + >>> np.random.seed(42) # For reproducible results + >>> mesh_data = np.array([[1.0], [2.0], [3.0], [4.0]]) + >>> cell_areas = np.array([0.1, 0.1, 0.1, 10.0]) # Last point has much larger area + >>> subset, indices = area_weighted_shuffle_array(mesh_data, 2, cell_areas) + >>> subset.shape + (2, 1) + >>> indices.shape + (2,) + >>> # The point with large area (index 3) should likely be selected + >>> len(set(indices)) <= 2 # At most 2 unique indices + True """ xp = array_type(arr) # Compute the total_area: @@ -1110,3 +1139,4 @@ def area_weighted_shuffle_array( ids = xp.random.choice(idx, n_points, p=probs) return arr[ids], ids + From a7d844841c33738cdc3b372d42c52cea19d6b599 Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Fri, 20 Jun 2025 11:46:52 -0400 Subject: [PATCH 12/19] Enhances area_weighted_shuffle_array function by adding area_factor parameter for adjustable sampling bias; updates docstring with detailed explanation and examples. --- physicsnemo/utils/domino/utils.py | 51 ++++++++++++++++++------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/physicsnemo/utils/domino/utils.py b/physicsnemo/utils/domino/utils.py index a7de84e7f7..67ca19f7ef 100644 --- a/physicsnemo/utils/domino/utils.py +++ b/physicsnemo/utils/domino/utils.py @@ -1073,7 +1073,7 @@ def dict_to_device( def area_weighted_shuffle_array( - arr: ArrayType, n_points: int, area: ArrayType + arr: ArrayType, n_points: int, area: ArrayType, area_factor: float = 1.0 ) -> tuple[ArrayType, ArrayType]: """Perform area-weighted random sampling from array. @@ -1087,6 +1087,9 @@ def area_weighted_shuffle_array( samples all available points. 
area: Area weights for each point, shape (n_points,). Larger values indicate higher sampling probability. + area_factor: Exponent applied to area weights to control sampling bias. + Values > 1.0 increase bias toward larger areas, values < 1.0 reduce bias. + Defaults to 1.0 (linear weighting). Returns: Tuple containing: @@ -1111,32 +1114,36 @@ def area_weighted_shuffle_array( >>> # The point with large area (index 3) should likely be selected >>> len(set(indices)) <= 2 # At most 2 unique indices True + >>> # Use higher area_factor for stronger bias toward large areas + >>> subset_biased, _ = area_weighted_shuffle_array(mesh_data, 2, cell_areas, area_factor=2.0) """ xp = array_type(arr) - # Compute the total_area: - factor = 1.0 - total_area = xp.sum(area**factor) - probs = area**factor / total_area + # Calculate area-weighted probabilities + sampling_probabilities = area**area_factor + sampling_probabilities /= xp.sum(sampling_probabilities) # Normalize to sum to 1 - if n_points > arr.shape[0]: - n_points = arr.shape[0] + # Ensure we don't request more points than available + n_points = min(n_points, arr.shape[0]) - idx = xp.arange(arr.shape[0]) + # Create index array for all available points + point_indices = xp.arange(arr.shape[0]) - # This is too memory intensive to run on the GPU. + # Handle GPU vs CPU sampling differently due to memory constraints if xp == cp: - idx = idx.get() - probs = probs.get() - # Under the hood, this has a search over the probabilities. - # It's very expensive in memory, as far as I can tell. - # In principle, we could use the Alias method to speed this up - # on the GPU but it's not yet a bottleneck. - - ids = np.random.choice(idx, n_points, p=probs) - ids = xp.asarray(ids) + # Note: np.random.choice performs expensive probability search on CPU + # Future optimization: Consider implementing Alias method for GPU acceleration + selected_indices = np.random.choice( + point_indices.get(), + size=n_points, + p=sampling_probabilities.get() + ) + selected_indices = xp.asarray(selected_indices) else: - # Chug along on the CPU: - ids = xp.random.choice(idx, n_points, p=probs) - - return arr[ids], ids + # Direct sampling on CPU + selected_indices = np.random.choice( + point_indices, + size=n_points, + p=sampling_probabilities + ) + return arr[selected_indices], selected_indices From 87dafe01dc112fa510b84e63f3510ce1720d7ce3 Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Fri, 20 Jun 2025 13:44:21 -0400 Subject: [PATCH 13/19] Updates docstrings in utils.py for accuracy and clarity; modifies examples in calculate_center_of_mass, standardize, nd_interpolator, pad, and pad_inp functions; adjusts k-nearest neighbors parameter in nd_interpolator for flexibility; corrects boolean checks in pad and pad_inp examples. 
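
# A weighting sketch for the area_factor parameter introduced above: raising
# the exponent concentrates probability mass on the largest cells. NumPy-only.
import numpy as np

area = np.array([0.1, 0.1, 0.1, 10.0])
for factor in (0.5, 1.0, 2.0):
    p = area**factor / np.sum(area**factor)
    print(factor, p[-1])   # mass on the largest cell grows with the factor
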
--- physicsnemo/utils/domino/utils.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/physicsnemo/utils/domino/utils.py b/physicsnemo/utils/domino/utils.py index 67ca19f7ef..9eeb42a8f4 100644 --- a/physicsnemo/utils/domino/utils.py +++ b/physicsnemo/utils/domino/utils.py @@ -83,7 +83,7 @@ def calculate_center_of_mass(centers: ArrayType, sizes: ArrayType) -> ArrayType: >>> centers = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]) >>> sizes = np.array([1.0, 2.0, 3.0]) >>> com = calculate_center_of_mass(centers, sizes) - >>> np.allclose(com, [[1.5, 1.5, 1.5]]) + >>> np.allclose(com, [[4.0/3.0, 4.0/3.0, 4.0/3.0]]) True """ xp = array_type(centers) @@ -193,7 +193,7 @@ def standardize( True >>> # Auto-compute mean/std >>> standardized_auto = standardize(field) - >>> np.abs(np.mean(standardized_auto)) < 1e-15 + >>> np.allclose(np.mean(standardized_auto), 0.0) True >>> np.allclose(np.std(standardized_auto, ddof=0), 1.0) True @@ -638,7 +638,8 @@ def calculate_normal_positional_encoding( def nd_interpolator( - coordinates: ArrayType, field: ArrayType, grid: ArrayType + coordinates: ArrayType, field: ArrayType, grid: ArrayType, + k: int = 2 ) -> ArrayType: """Perform n-dimensional interpolation using k-nearest neighbors. @@ -649,10 +650,11 @@ def nd_interpolator( Args: coordinates: Array of shape (n_points, n_dims) containing source point coordinates. field: Array of shape (n_points, n_fields) containing field values at source points. - grid: Array containing target grid points for interpolation. - + grid: Array of shape (n_field_points, n_dims) containing target grid points for interpolation. + k: Number of nearest neighbors to use for interpolation. + Returns: - Interpolated field values at grid points using 2-nearest neighbor averaging. + Interpolated field values at grid points using k-nearest neighbor averaging. Note: This function currently uses SciPy's KDTree which only supports CPU arrays. @@ -669,11 +671,11 @@ def nd_interpolator( True """ # TODO - this function should get updated for cuml if using cupy. 
- interp_func = KDTree(coordinates[0]) - dd, ii = interp_func.query(grid, k=2) + kdtree = KDTree(coordinates[0]) + distances, neighbor_indices = kdtree.query(grid, k=k) - field_grid = field[ii] - field_grid = np.float32(np.mean(field_grid, (3))) + field_grid = field[neighbor_indices] + field_grid = np.mean(field_grid, axis=1) return field_grid @@ -701,7 +703,7 @@ def pad(arr: ArrayType, n_points: int, pad_value: float = 0.0) -> ArrayType: (4, 2) >>> np.array_equal(padded[:2], arr) True - >>> np.all(padded[2:] == -1.0) + >>> bool(np.all(padded[2:] == -1.0)) True >>> # No padding needed >>> same = pad(arr, 2) @@ -743,7 +745,7 @@ def pad_inp(arr: ArrayType, n_points: int, pad_value: float = 0.0) -> ArrayType: (4, 1, 2) >>> np.array_equal(padded[:2], arr) True - >>> np.all(padded[2:] == 0.0) + >>> bool(np.all(padded[2:] == 0.0)) True """ xp = array_type(arr) @@ -936,7 +938,7 @@ def combine_dict(old_dict: dict[Any, Any], new_dict: dict[Any, Any]) -> dict[Any >>> combined["loss"] 0.8 >>> combined["accuracy"] - 0.8999999999999999 + 0.9 """ for key in old_dict.keys(): old_dict[key] += new_dict[key] From 8de5fc6d1f653004065928e61cdc48e950c4826a Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Fri, 20 Jun 2025 13:44:38 -0400 Subject: [PATCH 14/19] black format --- physicsnemo/utils/domino/utils.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/physicsnemo/utils/domino/utils.py b/physicsnemo/utils/domino/utils.py index 9eeb42a8f4..27a2ba2166 100644 --- a/physicsnemo/utils/domino/utils.py +++ b/physicsnemo/utils/domino/utils.py @@ -347,7 +347,7 @@ def convert_to_tet_mesh(polydata: "vtk.vtkPolyData") -> "vtk.vtkUnstructuredGrid Raises: RuntimeError: If tetrahedralization fails (e.g., non-manifold surface). - + """ tetrahedral_filter = vtkDataSetTriangleFilter() tetrahedral_filter.SetInputData(polydata) @@ -638,8 +638,7 @@ def calculate_normal_positional_encoding( def nd_interpolator( - coordinates: ArrayType, field: ArrayType, grid: ArrayType, - k: int = 2 + coordinates: ArrayType, field: ArrayType, grid: ArrayType, k: int = 2 ) -> ArrayType: """Perform n-dimensional interpolation using k-nearest neighbors. @@ -652,7 +651,7 @@ def nd_interpolator( field: Array of shape (n_points, n_fields) containing field values at source points. grid: Array of shape (n_field_points, n_dims) containing target grid points for interpolation. k: Number of nearest neighbors to use for interpolation. - + Returns: Interpolated field values at grid points using k-nearest neighbor averaging. @@ -1135,17 +1134,13 @@ def area_weighted_shuffle_array( # Note: np.random.choice performs expensive probability search on CPU # Future optimization: Consider implementing Alias method for GPU acceleration selected_indices = np.random.choice( - point_indices.get(), - size=n_points, - p=sampling_probabilities.get() + point_indices.get(), size=n_points, p=sampling_probabilities.get() ) selected_indices = xp.asarray(selected_indices) else: # Direct sampling on CPU selected_indices = np.random.choice( - point_indices, - size=n_points, - p=sampling_probabilities + point_indices, size=n_points, p=sampling_probabilities ) return arr[selected_indices], selected_indices From daac91a0a48ddb68d1b240466008c479f456a693 Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Tue, 8 Jul 2025 19:41:35 +0200 Subject: [PATCH 15/19] Add test suite for domino utils module This commit introduces a new test file `test_domino_utils.py` that includes comprehensive unit tests for various functions in the domino utils module. 
Each test verifies the functionality of the corresponding utility function using examples from the documentation, ensuring correctness and reliability. --- test/utils/test_domino_utils.py | 230 ++++++++++++++++++++++++++++++++ 1 file changed, 230 insertions(+) create mode 100644 test/utils/test_domino_utils.py diff --git a/test/utils/test_domino_utils.py b/test/utils/test_domino_utils.py new file mode 100644 index 0000000000..24b780e47f --- /dev/null +++ b/test/utils/test_domino_utils.py @@ -0,0 +1,230 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test suite for domino utils module. + +This test file duplicates all the docstring examples from the domino utils +module to ensure that the documented examples work correctly. +""" + +import numpy as np +import pytest + +from physicsnemo.utils.domino.utils import ( + array_type, + calculate_center_of_mass, + normalize, + unnormalize, + standardize, + unstandardize, + calculate_normal_positional_encoding, + nd_interpolator, + pad, + pad_inp, + shuffle_array, + shuffle_array_without_sampling, + calculate_pos_encoding, + combine_dict, + create_grid, + mean_std_sampling, + area_weighted_shuffle_array, +) + + +def test_array_type(): + """Test array_type function with docstring example.""" + arr = np.array([1, 2, 3]) + xp = array_type(arr) + result = xp.sum(arr) # Uses numpy.sum + # Just verify the function runs without error + assert result is not None + + +def test_calculate_center_of_mass(): + """Test calculate_center_of_mass function with docstring example.""" + centers = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]) + sizes = np.array([1.0, 2.0, 3.0]) + com = calculate_center_of_mass(centers, sizes) + expected = np.array([[4.0 / 3.0, 4.0 / 3.0, 4.0 / 3.0]]) + assert np.allclose(com, expected) + + +def test_normalize(): + """Test normalize function with docstring examples.""" + # Example 1: With explicit min/max + field = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + normalized = normalize(field, 5.0, 1.0) + expected = np.array([-1.0, -0.5, 0.0, 0.5, 1.0]) + assert np.allclose(normalized, expected) + + # Example 2: Auto-compute min/max + normalized_auto = normalize(field) + expected_auto = np.array([-1.0, -0.5, 0.0, 0.5, 1.0]) + assert np.allclose(normalized_auto, expected_auto) + + +def test_unnormalize(): + """Test unnormalize function with docstring example.""" + normalized = np.array([-1.0, -0.5, 0.0, 0.5, 1.0]) + original = unnormalize(normalized, 5.0, 1.0) + expected = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + assert np.allclose(original, expected) + + +def test_standardize(): + """Test standardize function with docstring examples.""" + # Example 1: With explicit mean/std + field = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + standardized = standardize(field, 3.0, np.sqrt(2.5)) + expected = np.array([-1.265, -0.632, 0.0, 0.632, 1.265]) + assert 
np.allclose(standardized, expected, atol=1e-3) + + # Example 2: Auto-compute mean/std + standardized_auto = standardize(field) + assert np.allclose(np.mean(standardized_auto), 0.0) + assert np.allclose(np.std(standardized_auto, ddof=0), 1.0) + + +def test_unstandardize(): + """Test unstandardize function with docstring example.""" + standardized = np.array([-1.265, -0.632, 0.0, 0.632, 1.265]) + original = unstandardize(standardized, 3.0, np.sqrt(2.5)) + expected = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + assert np.allclose(original, expected, atol=1e-3) + + +def test_calculate_normal_positional_encoding(): + """Test calculate_normal_positional_encoding function with docstring examples.""" + # Example 1: Basic coordinates + coords = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]) + cell_size = [0.1, 0.1, 0.1] + encoding = calculate_normal_positional_encoding(coords, cell_dimensions=cell_size) + assert encoding.shape == (2, 12) + + # Example 2: Relative positioning + coords_b = np.array([[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]) + encoding_rel = calculate_normal_positional_encoding(coords, coords_b, cell_size) + assert encoding_rel.shape == (2, 12) + + +def test_nd_interpolator(): + """Test nd_interpolator function with docstring example.""" + # Simple 2D interpolation example + coords = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + field_vals = np.array([[1.0], [2.0], [3.0], [4.0]]) + grid_points = np.array([[0.5, 0.5]]) + result = nd_interpolator([coords], field_vals, grid_points) + assert result.shape[0] == 1 # One grid point + + +def test_pad(): + """Test pad function with docstring examples.""" + # Example 1: Padding needed + arr = np.array([[1.0, 2.0], [3.0, 4.0]]) + padded = pad(arr, 4, -1.0) + assert padded.shape == (4, 2) + assert np.array_equal(padded[:2], arr) + assert bool(np.all(padded[2:] == -1.0)) + + # Example 2: No padding needed + same = pad(arr, 2) + assert np.array_equal(same, arr) + + +def test_pad_inp(): + """Test pad_inp function with docstring example.""" + arr = np.array([[[1.0, 2.0]], [[3.0, 4.0]]]) + padded = pad_inp(arr, 4, 0.0) + assert padded.shape == (4, 1, 2) + assert np.array_equal(padded[:2], arr) + assert bool(np.all(padded[2:] == 0.0)) + + +def test_shuffle_array(): + """Test shuffle_array function with docstring example.""" + np.random.seed(42) # For reproducible results + data = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + subset, indices = shuffle_array(data, 2) + assert subset.shape == (2, 2) + assert indices.shape == (2,) + assert len(np.unique(indices)) == 2 # No duplicates + + +def test_shuffle_array_without_sampling(): + """Test shuffle_array_without_sampling function with docstring example.""" + np.random.seed(42) # For reproducible results + data = np.array([[1], [2], [3], [4]]) + shuffled, indices = shuffle_array_without_sampling(data) + assert shuffled.shape == (4, 1) + assert indices.shape == (4,) + assert set(indices) == set(range(4)) # All original indices present + + +def test_calculate_pos_encoding(): + """Test calculate_pos_encoding function with docstring example.""" + positions = np.array([0.0, 1.0, 2.0]) + encodings = calculate_pos_encoding(positions, d=4) + assert len(encodings) == 4 + assert all(enc.shape == (3,) for enc in encodings) + + +def test_combine_dict(): + """Test combine_dict function with docstring example.""" + stats1 = {"loss": 0.5, "accuracy": 0.8} + stats2 = {"loss": 0.3, "accuracy": 0.1} + combined = combine_dict(stats1, stats2) + assert combined["loss"] == 0.8 + assert combined["accuracy"] == 0.9 + + +def 
test_create_grid(): + """Test create_grid function with docstring example.""" + min_bounds = np.array([0.0, 0.0, 0.0]) + max_bounds = np.array([1.0, 1.0, 1.0]) + grid_res = np.array([2, 2, 2]) + grid = create_grid(max_bounds, min_bounds, grid_res) + assert grid.shape == (2, 2, 2, 3) + assert np.allclose(grid[0, 0, 0], [0.0, 0.0, 0.0]) + assert np.allclose(grid[1, 1, 1], [1.0, 1.0, 1.0]) + + +def test_mean_std_sampling(): + """Test mean_std_sampling function with docstring example.""" + # Create test data with outliers + field = np.array([[1.0], [2.0], [3.0], [10.0]]) # 10.0 is outlier + field_mean = np.array([2.0]) + field_std = np.array([1.0]) + outliers = mean_std_sampling(field, field_mean, field_std, 2.0) + assert 3 in outliers # Index 3 (value 10.0) should be detected as outlier + + +def test_area_weighted_shuffle_array(): + """Test area_weighted_shuffle_array function with docstring example.""" + np.random.seed(42) # For reproducible results + mesh_data = np.array([[1.0], [2.0], [3.0], [4.0]]) + cell_areas = np.array([0.1, 0.1, 0.1, 10.0]) # Last point has much larger area + subset, indices = area_weighted_shuffle_array(mesh_data, 2, cell_areas) + assert subset.shape == (2, 1) + assert indices.shape == (2,) + # The point with large area (index 3) should likely be selected + assert len(set(indices)) <= 2 # At most 2 unique indices + + # Use higher area_factor for stronger bias toward large areas + subset_biased, _ = area_weighted_shuffle_array( + mesh_data, 2, cell_areas, area_factor=2.0 + ) + assert subset_biased.shape == (2, 1) From d2b11c7583a45e0b161891e3c55718b5f89022f4 Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Tue, 8 Jul 2025 19:49:44 +0200 Subject: [PATCH 16/19] Refactor array_type function to handle CuPy import gracefully and optimize area_weighted_shuffle_array for consistent array handling. Remove redundant test for array_type. --- physicsnemo/utils/domino/utils.py | 36 +++++++++++++++---------------- test/utils/test_domino_utils.py | 11 ---------- 2 files changed, 18 insertions(+), 29 deletions(-) diff --git a/physicsnemo/utils/domino/utils.py b/physicsnemo/utils/domino/utils.py index 27a2ba2166..623c22d464 100644 --- a/physicsnemo/utils/domino/utils.py +++ b/physicsnemo/utils/domino/utils.py @@ -24,8 +24,6 @@ from pathlib import Path from typing import Any, Sequence - -import cupy as cp import numpy as np import pyvista as pv import vtk @@ -36,10 +34,15 @@ from physicsnemo.utils.profiling import profile # Type alias for arrays that can be either NumPy or CuPy -ArrayType = np.ndarray | cp.ndarray + +try: + import cupy as cp + ArrayType = np.ndarray | cp.ndarray +except ImportError: + ArrayType = np.ndarray -def array_type(array: ArrayType) -> type[np] | type[cp]: +def array_type(array: ArrayType) -> "type[np] | type[cp]": """Determine the array module (NumPy or CuPy) for the given array. 
This function enables array-agnostic code by returning the appropriate @@ -57,7 +60,11 @@ def array_type(array: ArrayType) -> type[np] | type[cp]: >>> xp = array_type(arr) >>> result = xp.sum(arr) # Uses numpy.sum """ - return cp.get_array_module(array) + try: + import cupy as cp + return cp.get_array_module(array) + except ImportError: + return np def calculate_center_of_mass(centers: ArrayType, sizes: ArrayType) -> ArrayType: @@ -1129,18 +1136,11 @@ def area_weighted_shuffle_array( # Create index array for all available points point_indices = xp.arange(arr.shape[0]) - # Handle GPU vs CPU sampling differently due to memory constraints - if xp == cp: - # Note: np.random.choice performs expensive probability search on CPU - # Future optimization: Consider implementing Alias method for GPU acceleration - selected_indices = np.random.choice( - point_indices.get(), size=n_points, p=sampling_probabilities.get() - ) - selected_indices = xp.asarray(selected_indices) - else: - # Direct sampling on CPU - selected_indices = np.random.choice( - point_indices, size=n_points, p=sampling_probabilities - ) + selected_indices = xp.random.choice( + xp.asarray(point_indices), + size=n_points, + p=xp.asarray(sampling_probabilities) + ) + selected_indices = xp.asarray(selected_indices) return arr[selected_indices], selected_indices diff --git a/test/utils/test_domino_utils.py b/test/utils/test_domino_utils.py index 24b780e47f..6d87f3d8ea 100644 --- a/test/utils/test_domino_utils.py +++ b/test/utils/test_domino_utils.py @@ -25,7 +25,6 @@ import pytest from physicsnemo.utils.domino.utils import ( - array_type, calculate_center_of_mass, normalize, unnormalize, @@ -44,16 +43,6 @@ area_weighted_shuffle_array, ) - -def test_array_type(): - """Test array_type function with docstring example.""" - arr = np.array([1, 2, 3]) - xp = array_type(arr) - result = xp.sum(arr) # Uses numpy.sum - # Just verify the function runs without error - assert result is not None - - def test_calculate_center_of_mass(): """Test calculate_center_of_mass function with docstring example.""" centers = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]) From 02849e0462a551e4bc637b257146c147ca3da9cb Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Tue, 8 Jul 2025 19:51:23 +0200 Subject: [PATCH 17/19] Import PyVista conditionally in extract_surface_triangles function to avoid unnecessary dependency loading. 
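A sketch of the intended effect (hedged: `mesh` below is a placeholder vtkUnstructuredGrid, and the environment is assumed to lack PyVista):

    from physicsnemo.utils.domino import utils  # module import now succeeds without PyVista
    utils.extract_surface_triangles(mesh)       # ImportError is deferred to this call
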
--- physicsnemo/utils/domino/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/physicsnemo/utils/domino/utils.py b/physicsnemo/utils/domino/utils.py index 623c22d464..a15f0ddfef 100644 --- a/physicsnemo/utils/domino/utils.py +++ b/physicsnemo/utils/domino/utils.py @@ -25,7 +25,6 @@ from pathlib import Path from typing import Any, Sequence import numpy as np -import pyvista as pv import vtk from scipy.spatial import KDTree from vtk import vtkDataSetTriangleFilter @@ -322,6 +321,7 @@ def extract_surface_triangles(tetrahedral_mesh: "vtk.vtkUnstructuredGrid") -> li surface_filter.Update() # Wrap with PyVista for easier manipulation + import pyvista as pv surface_mesh = pv.wrap(surface_filter.GetOutput()) triangle_indices = [] From cdbf352f81b78380f384394a4f2ce8b29cd22843 Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Tue, 8 Jul 2025 20:02:57 +0200 Subject: [PATCH 18/19] black formatting --- physicsnemo/utils/domino/utils.py | 8 +++++--- test/utils/test_domino_utils.py | 19 ++++++++++--------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/physicsnemo/utils/domino/utils.py b/physicsnemo/utils/domino/utils.py index a15f0ddfef..8bfb518e70 100644 --- a/physicsnemo/utils/domino/utils.py +++ b/physicsnemo/utils/domino/utils.py @@ -24,6 +24,7 @@ from pathlib import Path from typing import Any, Sequence + import numpy as np import vtk from scipy.spatial import KDTree @@ -36,6 +37,7 @@ try: import cupy as cp + ArrayType = np.ndarray | cp.ndarray except ImportError: ArrayType = np.ndarray @@ -61,6 +63,7 @@ def array_type(array: ArrayType) -> "type[np] | type[cp]": """ try: import cupy as cp + return cp.get_array_module(array) except ImportError: return np @@ -322,6 +325,7 @@ def extract_surface_triangles(tetrahedral_mesh: "vtk.vtkUnstructuredGrid") -> li # Wrap with PyVista for easier manipulation import pyvista as pv + surface_mesh = pv.wrap(surface_filter.GetOutput()) triangle_indices = [] @@ -1137,9 +1141,7 @@ def area_weighted_shuffle_array( point_indices = xp.arange(arr.shape[0]) selected_indices = xp.random.choice( - xp.asarray(point_indices), - size=n_points, - p=xp.asarray(sampling_probabilities) + xp.asarray(point_indices), size=n_points, p=xp.asarray(sampling_probabilities) ) selected_indices = xp.asarray(selected_indices) diff --git a/test/utils/test_domino_utils.py b/test/utils/test_domino_utils.py index 6d87f3d8ea..448c05873a 100644 --- a/test/utils/test_domino_utils.py +++ b/test/utils/test_domino_utils.py @@ -25,24 +25,25 @@ import pytest from physicsnemo.utils.domino.utils import ( + area_weighted_shuffle_array, calculate_center_of_mass, - normalize, - unnormalize, - standardize, - unstandardize, calculate_normal_positional_encoding, + calculate_pos_encoding, + combine_dict, + create_grid, + mean_std_sampling, nd_interpolator, + normalize, pad, pad_inp, shuffle_array, shuffle_array_without_sampling, - calculate_pos_encoding, - combine_dict, - create_grid, - mean_std_sampling, - area_weighted_shuffle_array, + standardize, + unnormalize, + unstandardize, ) + def test_calculate_center_of_mass(): """Test calculate_center_of_mass function with docstring example.""" centers = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]) From 6251d3c8e088159a18149fc6d7a47f883bc1b4f1 Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Tue, 8 Jul 2025 20:11:57 +0200 Subject: [PATCH 19/19] Remove unused import --- test/utils/test_domino_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/utils/test_domino_utils.py 
b/test/utils/test_domino_utils.py index 448c05873a..8a0e03637b 100644 --- a/test/utils/test_domino_utils.py +++ b/test/utils/test_domino_utils.py @@ -22,7 +22,6 @@ """ import numpy as np -import pytest from physicsnemo.utils.domino.utils import ( area_weighted_shuffle_array,
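
As a closing illustration of the NumPy/CuPy-agnostic pattern this refactor settles on, a minimal sketch (assuming `physicsnemo` with these utilities is importable; it runs on plain NumPy, and the same code path serves CuPy arrays when CuPy is installed):

    import numpy as np
    from physicsnemo.utils.domino.utils import array_type, normalize, unnormalize

    field = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
    xp = array_type(field)                 # numpy here; cupy for cp.ndarray inputs
    lo, hi = xp.min(field), xp.max(field)
    scaled = normalize(field, hi, lo)      # values mapped into [-1, 1]
    restored = unnormalize(scaled, hi, lo)
    assert xp.allclose(field, restored)    # the round trip recovers the original field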