diff --git a/.importlinter b/.importlinter index e69bf55e9f..2101743002 100644 --- a/.importlinter +++ b/.importlinter @@ -1,75 +1,100 @@ [importlinter] root_package = physicsnemo include_external_packages = True +contract_types = + forbidden_import: prevent_untracked_imports.ForbiddenImportContract [importlinter:contract:physicsnemo-modules] name = Prevent Upward Imports in the PhysicsNemo Structure type = layers containers= - physicsnemo + physicsnemo layers = - experimental - active_learning - models : registry : datapipes : metrics : domain_parallel - nn - utils - distributed - core + experimental + active_learning + models : registry : datapipes : metrics : domain_parallel + nn + utils + distributed + core [importlinter:contract:physicsnemo-core] name = Control Dependencies in PhysicsNeMo core type = layers containers= - physicsnemo.core + physicsnemo.core layers = - module : registry - meta - warnings | version_check | filesystem + module : registry + meta + warnings | version_check | filesystem [importlinter:contract:physicsnemo-distributed] name = Control Dependencies in PhysicsNeMo distributed type = layers containers= - physicsnemo.distributed + physicsnemo.distributed layers = - fft | autograd - mappings - utils - manager - config + fft | autograd + mappings + utils + manager + config [importlinter:contract:physicsnemo-utils] name = Control Dependencies in PhysicsNeMo utils type = layers containers= - physicsnemo.utils + physicsnemo.utils layers = - mesh | insolation | zenith_angle - profiling - checkpoint - capture - logging | memory + mesh | insolation | zenith_angle + profiling + checkpoint + capture + logging | memory [importlinter:contract:physicsnemo-nn] name = Control Dependencies in PhysicsNeMo nn type = layers containers= - physicsnemo.nn + physicsnemo.nn layers = - fourier_layers | transformer_layers - dgm_layers | mlp_layers | fully_connected_layers | gnn_layers + fourier_layers | transformer_layers + dgm_layers | mlp_layers | fully_connected_layers | gnn_layers activations | attention_layers | ball_query | conv_layers | drop | fft | fused_silu | interpolation | kan_layers | resample_layers | sdf | siren_layers | spectral_layers | transformer_decoder | weight_fact | weight_norm - neighbors - utils + neighbors + utils [importlinter:contract:physicsnemo-models] name = Prevent Imports between physicsnemo models type = layers containers= - physicsnemo.models + physicsnemo.models layers = - mesh_reduced - afno | dlwp | dlwp_healpix | domino | dpot | fengwu | figconvnet | fno | graphcast | meshgraphnet | pangu | pix2pix | rnn | srrn | swinvrnn | topodiff | transolver | vfgn - unet | diffusion | dlwp_healpix_layers + mesh_reduced + afno | dlwp | dlwp_healpix | domino | dpot | fengwu | figconvnet | fno | graphcast | meshgraphnet | pangu | pix2pix | rnn | srrn | swinvrnn | topodiff | transolver | vfgn + unet | diffusion | dlwp_healpix_layers +[importlinter:contract:physicsnemo-core-external-imports] +name = Prevent Non-listed external imports in physicsnemo core +type = forbidden_import +container = physicsnemo.core +dependency_group = core + +[importlinter:contract:physicsnemo-distributed-external-imports] +name = Prevent Non-listed external imports in physicsnemo distributed +type = forbidden_import +container = physicsnemo.distributed +dependency_group = distributed + +[importlinter:contract:physicsnemo-utils-external-imports] +name = Prevent Non-listed external imports in physicsnemo utils +type = forbidden_import +container = physicsnemo.utils +dependency_group = utils + 
+[importlinter:contract:physicsnemo-nn-external-imports] +name = Prevent Non-listed external imports in physicsnemo nn +type = forbidden_import +container = physicsnemo.nn +dependency_group = nn diff --git a/examples/structural_mechanics/crash/README.md b/examples/structural_mechanics/crash/README.md index 1475fa393f..6b8a5ebb3e 100644 --- a/examples/structural_mechanics/crash/README.md +++ b/examples/structural_mechanics/crash/README.md @@ -36,7 +36,7 @@ For an in-depth comparison between the Transolver and MeshGraphNet models and th ```yaml # conf/config.yaml defaults: - - reader: vtp # or d3plot, or your custom reader + - reader: vtp # vtp, zarr, d3plot, or your custom reader - datapipe: point_cloud # or graph - model: transolver_time_conditional # or an MGN variant - training: default @@ -47,7 +47,7 @@ defaults: 2) Point to your datasets and core training knobs. - `conf/training/default.yaml`: - - `raw_data_dir`: path to TRAIN runs (folder of run folders for d3plot, or folder of .vtp files for VTP) + - `raw_data_dir`: path to TRAIN runs (folder of run folders for d3plot, folder of .vtp files for VTP, or folder of .zarr stores for Zarr) - `num_time_steps`: number of frames to use per run - `num_training_samples`: how many runs to load @@ -77,6 +77,7 @@ features: [thickness] # or [] for no features; preserve order if adding more 4) Reader‑specific options (optional). - d3plot: `conf/reader/d3plot.yaml` → `wall_node_disp_threshold` +- VTP and Zarr readers have no additional options (they read pre-processed data) 5) Model config: ensure input dimensions match your features. @@ -127,26 +128,38 @@ This will install: [PhysicsNeMo-Curator](https://github.com/NVIDIA/physicsnemo-curator). Using `PhysicsNeMo-Curator`, crash simulation data from LS-DYNA can be processed into training-ready formats easily. -Currently, this can be used to preprocess d3plot files into VTP. +PhysicsNeMo-Curator can preprocess d3plot files into **VTP** (for visualization and smaller datasets) or **Zarr** (for large-scale ML training). ### Quick Start Install PhysicsNeMo-Curator following [these instructions](https://github.com/NVIDIA/physicsnemo-curator?tab=readme-ov-file#installation-and-usage). -Process your LS-DYNA data: +Process your LS-DYNA data to **VTP format**: ```bash export PYTHONPATH=$PYTHONPATH:examples && -physicsnemo-curator-etl \ - --config-dir=examples/config \ - --config-name=crash_etl \ - etl.source.input_dir=/data/crash_sims/ \ - etl.sink.output_dir=/data/crash_processed_vtp/ \ +physicsnemo-curator-etl \ + --config-dir=examples/structural_mechanics/crash/config \ + --config-name=crash_etl \ + serialization_format=vtp \ + etl.source.input_dir=/data/crash_sims/ \ + serialization_format.sink.output_dir=/data/crash_vtp/ \ etl.processing.num_processes=4 ``` -This will process all LS-DYNA runs in `/data/crash_sims/` and output VTP files to `/data/crash_processed_vtp/`. 
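+This will process all LS-DYNA runs in `/data/crash_sims/` and output VTP files to `/data/crash_vtp/`.
+To spot-check the curated output before training, one of the produced files can be opened with PyVista
+(the run name and paths below are illustrative):
+
+```python
+# Inspect a single curated run; adjust the path and run name to your data.
+import pyvista as pv
+
+poly = pv.read("/data/crash_vtp/Run100.vtp")
+print(poly.points.shape)  # reference coordinates at t=0, shape (N, 3)
+print(sorted(k for k in poly.point_data.keys() if k.startswith("displacement_t")))
+```
+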
+Or process to **Zarr format** for large-scale training: + +```bash +export PYTHONPATH=$PYTHONPATH:examples && +physicsnemo-curator-etl \ + --config-dir=examples/structural_mechanics/crash/config \ + --config-name=crash_etl \ + serialization_format=zarr \ + etl.source.input_dir=/data/crash_sims/ \ + serialization_format.sink.output_dir=/data/crash_zarr/ \ + etl.processing.num_processes=4 +``` ### Input Data Structure @@ -165,7 +178,7 @@ crash_sims/ ### Output Formats -#### VTP Format (Recommended for this example) +#### VTP Format Produces single VTP file per run with all timesteps as displacement fields: @@ -179,10 +192,33 @@ crash_processed_vtp/ Each VTP contains: - Reference coordinates at t=0 - Displacement fields: `displacement_t0.000`, `displacement_t0.005`, etc. -- Node thickness values +- Node thickness and other point data features This format is directly compatible with the VTP reader in this example. +#### Zarr Format + +Produces one Zarr store per run with pre-computed graph structure: + +``` +crash_processed_zarr/ +├── Run100.zarr/ +│ ├── mesh_pos # (timesteps, nodes, 3) - temporal positions +│ ├── thickness # (nodes,) - node features +│ └── edges # (num_edges, 2) - pre-computed graph connectivity +├── Run101.zarr/ +└── ... +``` + +Each Zarr store contains: +- `mesh_pos`: Full temporal trajectory (no displacement reconstruction needed) +- `thickness`: Per-node features +- `edges`: Pre-computed edge connectivity (no edge rebuilding during training) + +**NOTE:** All heavy preprocessing (node filtering, edge building, thickness computation) is done once during curation using PhysicsNeMo-Curator. The reader simply loads pre-computed arrays. + +This format is directly compatible with the Zarr reader in this example. + ## Training Training is managed via Hydra configurations located in conf/. @@ -277,14 +313,15 @@ If you use the graph datapipe, the edge list is produced by walking the filtered ### Built‑in VTP reader (PolyData) -In addition to `d3plot`, a lightweight VTP reader is provided in `vtp_reader.py`. It treats each `.vtp` file in a directory as a separate run and expects point displacements to be stored as vector arrays in `poly.point_data` with names like `displacement_t0.000`, `displacement_t0.005`, … (a more permissive fallback of any `displacement_t*` is also supported). The reader: +A lightweight VTP reader is provided in `vtp_reader.py`. It treats each `.vtp` file in a directory as a separate run and expects point displacements to be stored as vector arrays in `poly.point_data` with names like `displacement_t0.000`, `displacement_t0.005`, … (a more permissive fallback of any `displacement_t*` is also supported). The reader: - loads the reference coordinates from `poly.points` - builds absolute positions per timestep as `[t0: coords, t>0: coords + displacement_t]` - extracts cell connectivity from the PolyData faces and converts it to unique edges -- returns `(srcs, dsts, point_data)` where `point_data` contains `'coords': [T, N, 3]` +- extracts all point data fields dynamically (e.g., thickness, modulus) +- returns `(srcs, dsts, point_data)` where `point_data` contains `'coords': [T, N, 3]` and all feature arrays -By default, the VTP reader does not attach additional features; it is compatible with `features: []`. If your `.vtp` files include additional per‑point arrays you would like to model (e.g., thickness or modulus), extend the reader to add those arrays to each run’s record using keys that match your `features` list. 
The datapipe will then concatenate them in the configured order.
+The VTP reader dynamically extracts all non-displacement point data fields from the VTP file and makes them available to the datapipe. If your `.vtp` files include additional per‑point arrays (e.g., thickness or modulus), simply add their names to the `features` list in your datapipe config.
 
 Example Hydra configuration for the VTP reader:
 
@@ -304,12 +341,58 @@ defaults:
   - reader: vtp
 ```
 
-And set `features` to empty (or to the names you add in your extended reader) in `conf/datapipe/point_cloud.yaml` or `conf/datapipe/graph.yaml`:
+And configure features in `conf/datapipe/point_cloud.yaml` or `conf/datapipe/graph.yaml`:
 
 ```yaml
-features: [] # or [thickness, Y_modulus] if your reader provides them
+features: [thickness] # or [] for no features
 ```
 
+### Built‑in Zarr reader
+
+A Zarr reader is provided in `zarr_reader.py`. It reads pre-processed Zarr stores created by PhysicsNeMo-Curator, where all heavy computation (node filtering, edge building, thickness computation) has already been done during the ETL pipeline. The reader:
+
+- loads pre-computed temporal positions directly from `mesh_pos` (no displacement reconstruction)
+- loads pre-computed edges (no connectivity-to-edge conversion needed)
+- dynamically extracts all point data fields (thickness, etc.) from the Zarr store
+- returns `(srcs, dsts, point_data)` similar to the VTP reader
+
+Data layout expected by the Zarr reader:
+- `/*.zarr/` (each `.zarr` directory is treated as one run)
+- Each Zarr store must contain:
+  - `mesh_pos`: `[T, N, 3]` temporal positions
+  - `edges`: `[E, 2]` pre-computed edge connectivity
+  - Feature arrays (e.g., `thickness`): `[N]` or `[N, K]` per-node features
+
+Example Hydra configuration for the Zarr reader:
+
+```yaml
+# conf/reader/zarr.yaml
+_target_: zarr_reader.Reader
+```
+
+Select it in `conf/config.yaml`:
+
+```yaml
+defaults:
+  - reader: zarr # Options are: vtp, d3plot, zarr
+  - datapipe: point_cloud # will be overridden by model configs
+  - model: transolver_autoregressive_rollout_training
+  - training: default
+  - inference: default
+  - _self_
+```
+
+And configure features in `conf/datapipe/graph.yaml`:
+
+```yaml
+features: [thickness] # Must match fields stored in Zarr
+```
+
+**Recommended workflow:**
+1. Use PhysicsNeMo-Curator to preprocess d3plot → VTP or Zarr once
+2. Use the corresponding reader for all training/validation
+3. Optionally use the d3plot reader for quick prototyping on raw data
+
 ### Data layout expected by readers
 
 - d3plot reader (`d3plot_reader.py`):
@@ -320,6 +403,10 @@ features: [] # or [thickness, Y_modulus] if your reader provides them
   - `/*.vtp` (each `.vtp` is treated as one run)
   - Displacements stored as 3‑component arrays in point_data with names like `displacement_t0.000`, `displacement_t0.005`, ... (fallback accepts any `displacement_t*`).
+- Zarr reader (`zarr_reader.py`):
+  - `/*.zarr/` (each `.zarr` directory is treated as one run)
+  - Contains pre-computed `mesh_pos`, `edges`, and feature arrays
+
 ### Write your own reader
 
 To write your own reader, implement a Hydra‑instantiable function or class whose call returns a three‑tuple `(srcs, dsts, point_data)`. The first two entries are lists of integer arrays describing edges per run (they can be empty lists if you are not producing a graph), and `point_data` is a list of Python dicts with one dict per run. Each dict must contain `'coords'` as a `[T, N, 3]` array and one array per feature name listed in `conf/datapipe/*.yaml` under `features`.
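+
+A minimal sketch of a reader that satisfies this contract is shown below (the file name `my_reader.py`, the
+synthetic arrays, and the feature name are placeholders; the call signature mirrors the built-in readers):
+
+```python
+# my_reader.py - illustrative only; a real reader would load runs from data_dir
+import numpy as np
+
+
+class Reader:
+    def __call__(self, data_dir, num_samples, split=None, logger=None, **kwargs):
+        srcs, dsts, point_data = [], [], []
+        for _ in range(num_samples):
+            coords = np.zeros((10, 100, 3), dtype=np.float64)  # [T, N, 3] positions for one run
+            thickness = np.ones(100, dtype=np.float32)  # one array per name in `features`
+            # Edge lists may be left empty if you are not producing a graph
+            srcs.append(np.empty(0, dtype=np.int64))
+            dsts.append(np.empty(0, dtype=np.int64))
+            point_data.append({"coords": coords, "thickness": thickness})
+        return srcs, dsts, point_data
+```
+
+Such a reader can then be selected through a reader config containing `_target_: my_reader.Reader`, following
+the same pattern as the VTP and Zarr reader configs above.
+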
Feature arrays can be `[N]` or `[N, K]` and should use the same node indexing as `'coords'`. For convenience, a simple class reader can accept the Hydra `split` argument (e.g., "train" or "test") and decide whether to save VTP frames, but this is optional. diff --git a/examples/structural_mechanics/crash/conf/config.yaml b/examples/structural_mechanics/crash/conf/config.yaml index 3a3e288d66..7bba46e8c4 100644 --- a/examples/structural_mechanics/crash/conf/config.yaml +++ b/examples/structural_mechanics/crash/conf/config.yaml @@ -25,7 +25,7 @@ experiment_desc: "unified training recipe for crash models" run_desc: "unified training recipe for crash models" defaults: - - reader: vtp #d3plot + - reader: vtp # Options are: vtp, d3plot, zarr - datapipe: point_cloud # will be overridden by model configs - model: transolver_autoregressive_rollout_training - training: default diff --git a/examples/structural_mechanics/crash/conf/reader/zarr.yaml b/examples/structural_mechanics/crash/conf/reader/zarr.yaml new file mode 100644 index 0000000000..733730067f --- /dev/null +++ b/examples/structural_mechanics/crash/conf/reader/zarr.yaml @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +_target_: zarr_reader.Reader +_convert_: all \ No newline at end of file diff --git a/examples/structural_mechanics/crash/tests/test_zarr_reader.py b/examples/structural_mechanics/crash/tests/test_zarr_reader.py new file mode 100644 index 0000000000..8b5d0fe745 --- /dev/null +++ b/examples/structural_mechanics/crash/tests/test_zarr_reader.py @@ -0,0 +1,380 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +from pathlib import Path + +import numpy as np +import pytest +import zarr + +# Import functions from zarr_reader +import sys + +sys.path.insert(0, str(Path(__file__).parent.parent)) +import zarr_reader + + +def create_mock_zarr_store( + store_path: Path, + num_timesteps: int = 3, + num_nodes: int = 4, + thickness_value: float = 1.0, +): + """ + Helper function to create a mock Zarr store with crash simulation data. 
+ + Args: + store_path: Path where the Zarr store should be created + num_timesteps: Number of timesteps + num_nodes: Number of nodes + thickness_value: Constant thickness value for all nodes + + Returns: + Tuple of (mesh_pos, node_thickness, edges) arrays that were written + """ + store_path.mkdir(exist_ok=True) + + # Create mock data + mesh_pos = np.random.randn(num_timesteps, num_nodes, 3).astype(np.float32) + node_thickness = np.ones(num_nodes, dtype=np.float32) * thickness_value + edges = np.array([[0, 1], [1, 2], [2, 3], [3, 0]], dtype=np.int64) + + # Write to Zarr store + store = zarr.open(str(store_path), mode="w") + store.create_dataset("mesh_pos", data=mesh_pos, dtype=np.float32) + store.create_dataset("thickness", data=node_thickness, dtype=np.float32) + store.create_dataset("edges", data=edges, dtype=np.int64) + + return mesh_pos, node_thickness, edges + + +@pytest.fixture +def mock_zarr_store(): + """Create a temporary Zarr store with mock crash simulation data.""" + with tempfile.TemporaryDirectory() as temp_dir: + store_path = Path(temp_dir) / "Run001.zarr" + mesh_pos, node_thickness, edges = create_mock_zarr_store( + store_path, thickness_value=2.0 + ) + yield temp_dir, mesh_pos, node_thickness, edges + + +@pytest.fixture +def mock_zarr_directory(): + """Create a directory with multiple Zarr stores.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create multiple zarr stores + for i in range(3): + store_path = temp_path / f"Run{i:03d}.zarr" + create_mock_zarr_store(store_path) + + # Create a non-zarr directory (should be ignored) + (temp_path / "NotAZarr").mkdir() + + # Create a regular file (should be ignored) + (temp_path / "some_file.txt").touch() + + yield temp_dir + + +def test_find_zarr_stores(mock_zarr_directory): + """Test that find_zarr_stores correctly identifies Zarr directories.""" + zarr_stores = zarr_reader.find_zarr_stores(mock_zarr_directory) + + assert len(zarr_stores) == 3, f"Expected 3 zarr stores, got {len(zarr_stores)}" + assert all(path.endswith(".zarr") for path in zarr_stores) + assert all("Run" in Path(path).name for path in zarr_stores) + + +def test_find_zarr_stores_empty_directory(): + """Test find_zarr_stores with empty directory.""" + with tempfile.TemporaryDirectory() as temp_dir: + zarr_stores = zarr_reader.find_zarr_stores(temp_dir) + assert len(zarr_stores) == 0, ( + "Should return empty list for directory with no zarr stores" + ) + + +def test_find_zarr_stores_nonexistent_directory(): + """Test find_zarr_stores with nonexistent directory.""" + zarr_stores = zarr_reader.find_zarr_stores("/nonexistent/path") + assert len(zarr_stores) == 0, "Should return empty list for nonexistent directory" + + +def test_load_zarr_store(mock_zarr_store): + """Test loading data from a Zarr store.""" + temp_dir, expected_mesh_pos, expected_thickness, expected_edges = mock_zarr_store + store_path = Path(temp_dir) / "Run001.zarr" + + mesh_pos, edges, point_data_dict = zarr_reader.load_zarr_store(str(store_path)) + + # Check shapes + assert mesh_pos.shape == expected_mesh_pos.shape + assert edges.shape == expected_edges.shape + assert "thickness" in point_data_dict, "Should have thickness in point_data" + assert point_data_dict["thickness"].shape == expected_thickness.shape + + # Check data types + assert mesh_pos.dtype == np.float64, "mesh_pos should be float64" + assert point_data_dict["thickness"].dtype == np.float32, ( + "thickness should be float32" + ) + assert edges.dtype == np.int64, "edges should be int64" + + # 
Check values + np.testing.assert_array_almost_equal(mesh_pos, expected_mesh_pos) + np.testing.assert_array_almost_equal( + point_data_dict["thickness"], expected_thickness + ) + np.testing.assert_array_equal(edges, expected_edges) + + +def test_load_zarr_store_missing_fields(): + """Test that loading a Zarr store with missing required fields raises KeyError.""" + with tempfile.TemporaryDirectory() as temp_dir: + store_path = Path(temp_dir) / "incomplete.zarr" + store_path.mkdir() + + # Create store with only thickness (missing mesh_pos and edges) + store = zarr.open(str(store_path), mode="w") + store.create_dataset("thickness", data=np.ones(4, dtype=np.float32)) + + # Should raise KeyError for missing mesh_pos + with pytest.raises(KeyError, match="mesh_pos"): + zarr_reader.load_zarr_store(str(store_path)) + + # Test missing edges + store_path2 = Path(temp_dir) / "incomplete2.zarr" + store_path2.mkdir() + store2 = zarr.open(str(store_path2), mode="w") + store2.create_dataset( + "mesh_pos", data=np.random.randn(3, 4, 3).astype(np.float32) + ) + + # Should raise KeyError for missing edges + with pytest.raises(KeyError, match="edges"): + zarr_reader.load_zarr_store(str(store_path2)) + + +def test_load_zarr_store_multiple_point_data_fields(): + """Test that load_zarr_store dynamically reads all point data fields.""" + with tempfile.TemporaryDirectory() as temp_dir: + store_path = Path(temp_dir) / "multi_fields.zarr" + store_path.mkdir() + + # Create store with multiple point data fields + num_nodes = 10 + store = zarr.open(str(store_path), mode="w") + store.create_dataset( + "mesh_pos", data=np.random.randn(3, num_nodes, 3).astype(np.float32) + ) + store.create_dataset("edges", data=np.array([[0, 1]], dtype=np.int64)) + # Add multiple point data fields + store.create_dataset("thickness", data=np.ones(num_nodes, dtype=np.float32)) + store.create_dataset( + "stress", data=np.random.randn(num_nodes).astype(np.float32) + ) + store.create_dataset( + "temperature", data=np.random.randn(num_nodes).astype(np.float32) + ) + # This should be skipped (mesh connectivity, not point data) + store.create_dataset( + "mesh_connectivity_flat", data=np.array([0, 1, 2], dtype=np.int64) + ) + + mesh_pos, edges, point_data_dict = zarr_reader.load_zarr_store(str(store_path)) + + # Should have all three point data fields + assert "thickness" in point_data_dict + assert "stress" in point_data_dict + assert "temperature" in point_data_dict + # Should NOT include mesh connectivity + assert "mesh_connectivity_flat" not in point_data_dict + # Should NOT include mesh_pos or edges + assert "mesh_pos" not in point_data_dict + assert "edges" not in point_data_dict + + # Check that all point data fields have correct shape + for name, data in point_data_dict.items(): + assert data.shape == (num_nodes,), ( + f"{name} should have shape ({num_nodes},)" + ) + assert data.dtype == np.float32, f"{name} should be float32" + + +def test_load_zarr_store_2d_feature_arrays(): + """Test that load_zarr_store correctly handles 2D feature arrays [N, K].""" + with tempfile.TemporaryDirectory() as temp_dir: + store_path = Path(temp_dir) / "2d_features.zarr" + store_path.mkdir() + + # Create store with 2D feature array + num_nodes = 8 + feature_dim = 3 + store = zarr.open(str(store_path), mode="w") + store.create_dataset( + "mesh_pos", data=np.random.randn(3, num_nodes, 3).astype(np.float32) + ) + store.create_dataset("edges", data=np.array([[0, 1]], dtype=np.int64)) + # Add 1D feature (thickness) + store.create_dataset("thickness", 
data=np.ones(num_nodes, dtype=np.float32)) + # Add 2D feature array [N, K] (e.g., stress tensor components) + stress_tensor = np.random.randn(num_nodes, feature_dim).astype(np.float32) + store.create_dataset("stress_tensor", data=stress_tensor) + + mesh_pos, edges, point_data_dict = zarr_reader.load_zarr_store(str(store_path)) + + # Should have both 1D and 2D features + assert "thickness" in point_data_dict + assert "stress_tensor" in point_data_dict + + # Check 1D feature shape + assert point_data_dict["thickness"].shape == (num_nodes,) + assert point_data_dict["thickness"].ndim == 1 + + # Check 2D feature shape + assert point_data_dict["stress_tensor"].shape == (num_nodes, feature_dim) + assert point_data_dict["stress_tensor"].ndim == 2 + + # Verify values match + np.testing.assert_array_almost_equal( + point_data_dict["stress_tensor"], stress_tensor + ) + + +def test_process_zarr_data(mock_zarr_directory): + """Test processing multiple Zarr stores.""" + srcs, dsts, point_data = zarr_reader.process_zarr_data( + data_dir=mock_zarr_directory, + num_samples=2, + ) + + # Check we got 2 samples + assert len(srcs) == 2, f"Expected 2 samples, got {len(srcs)}" + assert len(dsts) == 2 + assert len(point_data) == 2 + + # Check each sample has correct structure + for i in range(2): + assert srcs[i].ndim == 1, "srcs should be 1D array" + assert dsts[i].ndim == 1, "dsts should be 1D array" + assert len(srcs[i]) == len(dsts[i]), "srcs and dsts should have same length" + + # Check point_data structure + assert "coords" in point_data[i], "point_data should have 'coords' key" + assert "thickness" in point_data[i], "point_data should have 'thickness' key" + + coords = point_data[i]["coords"] + thickness = point_data[i]["thickness"] + + assert coords.ndim == 3, "coords should be [T,N,3]" + assert coords.shape[-1] == 3, "coords last dimension should be 3" + assert thickness.ndim == 1, "thickness should be 1D" + assert len(thickness) == coords.shape[1], ( + "thickness length should match num_nodes" + ) + + +def test_process_zarr_data_no_stores(): + """Test that processing directory with no Zarr stores raises error.""" + with tempfile.TemporaryDirectory() as temp_dir: + with pytest.raises(ValueError, match="No .zarr stores found"): + zarr_reader.process_zarr_data( + data_dir=temp_dir, + num_samples=1, + ) + + +def test_process_zarr_data_validation(): + """Test that process_zarr_data validates data shapes.""" + with tempfile.TemporaryDirectory() as temp_dir: + store_path = Path(temp_dir) / "bad_store.zarr" + store_path.mkdir() + + # Create store with invalid mesh_pos shape (should be [T,N,3]) + store = zarr.open(str(store_path), mode="w") + store.create_dataset( + "mesh_pos", data=np.random.randn(3, 4, 2).astype(np.float32) + ) # Wrong last dim + store.create_dataset("thickness", data=np.ones(4, dtype=np.float32)) + store.create_dataset("edges", data=np.array([[0, 1]], dtype=np.int64)) + + with pytest.raises(ValueError, match="mesh_pos must be"): + zarr_reader.process_zarr_data( + data_dir=temp_dir, + num_samples=1, + ) + + +def test_process_zarr_data_edge_bounds(): + """Test that process_zarr_data validates edge indices are within bounds.""" + with tempfile.TemporaryDirectory() as temp_dir: + store_path = Path(temp_dir) / "bad_edges.zarr" + store_path.mkdir() + + num_nodes = 4 + store = zarr.open(str(store_path), mode="w") + store.create_dataset( + "mesh_pos", data=np.random.randn(3, num_nodes, 3).astype(np.float32) + ) + store.create_dataset("thickness", data=np.ones(num_nodes, dtype=np.float32)) + # Edge 
references node 10 which is out of bounds + store.create_dataset("edges", data=np.array([[0, 10]], dtype=np.int64)) + + with pytest.raises(ValueError, match="Edge indices out of bounds"): + zarr_reader.process_zarr_data( + data_dir=temp_dir, + num_samples=1, + ) + + +def test_reader_class(mock_zarr_directory): + """Test the Reader class callable interface.""" + reader = zarr_reader.Reader() + + srcs, dsts, point_data = reader( + data_dir=mock_zarr_directory, + num_samples=2, + split="train", + ) + + assert len(srcs) == 2 + assert len(dsts) == 2 + assert len(point_data) == 2 + + +def test_natural_sorting(mock_zarr_directory): + """Test that Zarr stores are sorted naturally (Run1, Run2, ..., Run10).""" + temp_path = Path(mock_zarr_directory) + + # Add more stores with different numbering + for i in [10, 5, 20]: + store_path = temp_path / f"Run{i}.zarr" + create_mock_zarr_store(store_path) + + zarr_stores = zarr_reader.find_zarr_stores(mock_zarr_directory) + store_names = [Path(p).name for p in zarr_stores] + + # Should be sorted: Run000, Run001, Run002, Run5, Run10, Run20 + assert store_names[0] == "Run000.zarr" + assert store_names[1] == "Run001.zarr" + assert store_names[2] == "Run002.zarr" + assert "Run5.zarr" in store_names + assert "Run10.zarr" in store_names + assert "Run20.zarr" in store_names diff --git a/examples/structural_mechanics/crash/zarr_reader.py b/examples/structural_mechanics/crash/zarr_reader.py new file mode 100644 index 0000000000..150ff957a1 --- /dev/null +++ b/examples/structural_mechanics/crash/zarr_reader.py @@ -0,0 +1,233 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import numpy as np +import zarr + + +def find_zarr_stores(base_data_dir: str) -> list[str]: + """ + Find all Zarr stores (directories ending with .zarr) in the base directory. + + Args: + base_data_dir: Path to directory containing Zarr stores. + + Returns: + List of Zarr store paths sorted naturally. + """ + if not os.path.isdir(base_data_dir): + return [] + + zarr_stores = [ + os.path.join(base_data_dir, f) + for f in os.listdir(base_data_dir) + if f.endswith(".zarr") and os.path.isdir(os.path.join(base_data_dir, f)) + ] + + def natural_key(name): + """Natural sort key to handle numeric sorting.""" + return [ + int(s) if s.isdigit() else s.lower() + for s in re.findall(r"\d+|\D+", os.path.basename(name)) + ] + + return sorted(zarr_stores, key=natural_key) + + +def load_zarr_store(zarr_path: str): + """ + Load mesh positions, edges, and all point data fields from a Zarr store. + + Args: + zarr_path: Path to the Zarr store directory. + + Returns: + mesh_pos: (timesteps, num_nodes, 3) temporal positions + edges: (num_edges, 2) edge connectivity + point_data_dict: Dictionary of all point data fields (e.g., thickness, etc.) 
+ """ + store = zarr.open(zarr_path, mode="r") + + # Read mesh positions (temporal coordinates) + if "mesh_pos" not in store: + raise KeyError(f"'mesh_pos' not found in Zarr store {zarr_path}") + mesh_pos = np.array(store["mesh_pos"][:], dtype=np.float64) + + # Read edges + if "edges" not in store: + raise KeyError(f"'edges' not found in Zarr store {zarr_path}") + edges = np.array(store["edges"][:], dtype=np.int64) + + # Extract all other datasets as point data (excluding mesh-level data) + # Skip: mesh_pos, edges, mesh_connectivity_* (these are not per-node features) + point_data_dict = {} + for name in store.keys(): + if name in ("mesh_pos", "edges"): + continue + if name.startswith("mesh_connectivity_"): + continue + # Read as point data feature + point_data_dict[name] = np.array(store[name][:], dtype=np.float32) + + return mesh_pos, edges, point_data_dict + + +def process_zarr_data( + data_dir: str, + num_samples: int, + logger=None, +): + """ + Process Zarr crash simulation data from a given directory. + + Each .zarr store is treated as one sample. Reads mesh positions, edges, + and all available point data fields (e.g., thickness, etc.) from the Zarr stores. + + Args: + data_dir: Directory containing .zarr stores + num_samples: Maximum number of samples to process + logger: Optional logger for logging progress + + Returns: + srcs: List of source node indices for edges (one array per sample) + dsts: List of destination node indices for edges (one array per sample) + point_data_all: List of dicts with 'coords' and all point data fields + """ + zarr_stores = find_zarr_stores(data_dir) + + if not zarr_stores: + if logger: + logger.error(f"No .zarr stores found in: {data_dir}") + raise ValueError(f"No .zarr stores found in: {data_dir}") + + srcs, dsts = [], [] + point_data_all = [] + + processed_runs = 0 + for zarr_path in zarr_stores: + if processed_runs >= num_samples: + break + + if logger: + logger.info(f"Processing Zarr store: {os.path.basename(zarr_path)}") + + try: + mesh_pos, edges, point_data_dict = load_zarr_store(zarr_path) + + # Validate shapes + if mesh_pos.ndim != 3 or mesh_pos.shape[-1] != 3: + raise ValueError( + f"mesh_pos must be [T,N,3], got {mesh_pos.shape} in {zarr_path}" + ) + + if edges.ndim != 2 or edges.shape[-1] != 2: + raise ValueError( + f"edges must be [E,2], got {edges.shape} in {zarr_path}" + ) + + num_nodes = mesh_pos.shape[1] + + # Validate point data features + for name, data in point_data_dict.items(): + if data.ndim == 1: + if len(data) != num_nodes: + raise ValueError( + f"Point data '{name}' length {len(data)} doesn't match " + f"number of nodes {num_nodes} in {zarr_path}" + ) + elif data.ndim == 2: + if data.shape[0] != num_nodes: + raise ValueError( + f"Point data '{name}' shape {data.shape} doesn't match " + f"number of nodes {num_nodes} in {zarr_path}" + ) + else: + raise ValueError( + f"Point data '{name}' must be [N] or [N,K], got shape {data.shape} in {zarr_path}" + ) + + # Validate edge indices are within bounds + if edges.size > 0: + if edges.min() < 0 or edges.max() >= num_nodes: + raise ValueError( + f"Edge indices out of bounds [0, {num_nodes - 1}] in {zarr_path}" + ) + + # Extract source and destination node indices from edges + src, dst = edges.T + srcs.append(src) + dsts.append(dst) + + # Create record with coordinates and all point data fields + record = {"coords": mesh_pos} + record.update(point_data_dict) # Add all point data features dynamically + point_data_all.append(record) + + processed_runs += 1 + + except Exception as e: + if 
logger: + logger.error(f"Error processing {zarr_path}: {e}") + raise + + if logger: + logger.info(f"Successfully processed {processed_runs} Zarr stores") + + return srcs, dsts, point_data_all + + +class Reader: + """ + Reader for Zarr crash simulation stores. + + This reader loads preprocessed crash simulation data from Zarr stores + created by the PhysicsNeMo Curator ETL pipeline. + """ + + def __init__(self): + """Initialize the Zarr reader.""" + pass + + def __call__( + self, + data_dir: str, + num_samples: int, + split: str | None = None, + logger=None, + **kwargs, + ): + """ + Load Zarr crash simulation data. + + Args: + data_dir: Directory containing .zarr stores + num_samples: Number of samples to load + split: Data split ('train', 'validation', 'test') - not used for Zarr + logger: Optional logger + **kwargs: Additional arguments (ignored) + + Returns: + srcs: List of source node arrays for graph edges + dsts: List of destination node arrays for graph edges + point_data: List of dicts with 'coords' and all available point data fields + """ + return process_zarr_data( + data_dir=data_dir, + num_samples=num_samples, + logger=logger, + ) diff --git a/physicsnemo/core/registry.py b/physicsnemo/core/registry.py index 71c2c4b3af..0150e7611e 100644 --- a/physicsnemo/core/registry.py +++ b/physicsnemo/core/registry.py @@ -28,12 +28,13 @@ ENTRY_POINT_CLASSES = [ EntryPoint, ] -try: - from importlib_metadata import EntryPoint as EntryPointOld # noqa: E402 +# This is now deprecated, since EntryPoint is python 3.10 or higher. +# try: +# from importlib_metadata import EntryPoint as EntryPointOld # noqa: E402 - ENTRY_POINT_CLASSES.append(EntryPointOld) -except ImportError: - pass +# ENTRY_POINT_CLASSES.append(EntryPointOld) +# except ImportError: +# pass # This model registry follows conventions similar to fsspec, diff --git a/physicsnemo/core/version_check.py b/physicsnemo/core/version_check.py index c5b9195fde..250ef29f70 100644 --- a/physicsnemo/core/version_check.py +++ b/physicsnemo/core/version_check.py @@ -18,96 +18,96 @@ """ Utilities for version compatibility checking. -Specifically in use to prevent some newer physicsnemo modules from being used with -and older version of pytorch. - +This is used to provide a uniform and consistent way to check for missing +packages, when not all packages are required for the base physicsnemo +install. Additionally, for some packages (it's not mandatory to do this), +we have a registry of packages -> install tip that is used +to provide a helpful error message. 
""" -import importlib +import functools +from importlib import metadata from typing import Optional -from packaging import version +from packaging.specifiers import SpecifierSet +from packaging.version import Version + +install_cmds = { + "cupy": "pip install cupy-cuda13", + "cuml": "pip install cuml-cu13", + "scipy": "pip install scipy", +} -# Dictionary mapping module paths to their version requirements -# This can be expanded as needed for different modules -VERSION_REQUIREMENTS = { - "physicsnemo.distributed.shard_tensor": {"torch": "2.5.9"}, - "device_mesh": {"torch": "2.4.0"}, +extra_info = { + "cupy": "For more details about installing cupy, see https://docs.cupy.dev/en/stable/install.html/.", + "cuml": "For more details about installing cuml, see https://docs.rapids.ai/install/.", + "scipy": "For more details about installing scipy, see https://www.scipy.org/install/.", } -def check_min_version( - package_name: str, - min_version: str, +@functools.lru_cache(maxsize=None) +def get_installed_version(distribution_name: str) -> Optional[str]: + """ + Return the installed version for a given distribution without importing it. + Uses importlib.metadata to avoid heavy import-time side effects. + Cached for repeated lookups. + """ + try: + return metadata.version(distribution_name) + except metadata.PackageNotFoundError: + return None + + +def check_version_spec( + distribution_name: str, + spec: str, + *, error_msg: Optional[str] = None, hard_fail: bool = True, ) -> bool: """ - Check if an installed package meets the minimum version requirement. + Check whether the installed distribution satisfies a PEP 440 version specifier. Args: - package_name: Name of the package to check - min_version: Minimum required version string (e.g. '2.6.0') + distribution_name: Distribution (package) name as installed by pip + spec: PEP 440 version specifier (e.g., '>=2.4,<2.6') error_msg: Optional custom error message hard_fail: Whether to raise an ImportError if the version requirement is not met Returns: - True if version requirement is met + True if version requirement is met; False if not and hard_fail=False Raises: - ImportError: If package is not installed or version is too low + ImportError: If package is not installed or requirement not satisfied (and hard_fail=True) """ - try: - package = importlib.import_module(package_name) - package_version = getattr(package, "__version__", "0.0.0") - except ImportError: + installed = get_installed_version(distribution_name) + if installed is None: if hard_fail: - raise ImportError(f"Package {package_name} is required but not installed.") + raise ImportError( + f"Package '{distribution_name}' is required but not installed." + ) else: return False - if version.parse(package_version) < version.parse(min_version): + ok = Version(installed) in SpecifierSet(spec) + if not ok: msg = ( error_msg - or f"{package_name} version {min_version} or higher is required, but found {package_version}" + or f"{distribution_name} {spec} is required, but found {installed}" ) if hard_fail: raise ImportError(msg) - else: - return False + return False return True -def check_module_requirements(module_path: str, hard_fail: bool = True) -> bool: +def require_version_spec(package_name: str, spec: str = ">=0.0.0"): """ - Check all version requirements for a specific module. 
-
-    Args:
-        module_path: The import path of the module to check requirements for
-
-    Raises:
-        ImportError: If any requirement is not met
-    """
-    if module_path not in VERSION_REQUIREMENTS:
-        return
-
-    requirements_pass = True
-
-    for package, min_version in VERSION_REQUIREMENTS[module_path].items():
-        result = check_min_version(package, min_version, hard_fail=hard_fail)
-        requirements_pass = requirements_pass and result
-
-    return requirements_pass
-
-
-def require_version(package_name: str, min_version: str):
-    """
-    Decorator that prevents a function from being called unless the
-    specified package meets the minimum version requirement.
+    Decorator variant that accepts a full PEP 440 specifier instead of a single minimum version.
 
     Args:
         package_name: Name of the package to check
-        min_version: Minimum required version string (e.g. '2.3')
+        spec: PEP 440 version specifier (e.g., '>=2.4,<2.6')
 
     Returns:
         Decorator function that checks version requirement before execution
@@ -120,16 +120,52 @@ def my_function():
     """
 
     def decorator(func):
-        import functools
-
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
-            # Verify the package meets minimum version before executing
-            check_min_version(package_name, min_version)
-
-            # If we get here, version check passed
+            check_version_spec(package_name, spec, hard_fail=True)
             return func(*args, **kwargs)
 
         return wrapper
 
     return decorator
+
+
+def ensure_available(
+    distribution_name: str,
+    spec: str = ">=0.0.0",
+    *,
+    install_hint: Optional[str] = None,
+    extra_message: Optional[str] = None,
+    hard_fail: bool = True,
+) -> bool:
+    """
+    Ensure a distribution is installed and satisfies the given specifier.
+    If not satisfied:
+    - When hard_fail=True, raises ImportError with an actionable message
+    - When hard_fail=False, returns False
+
+    Args:
+        distribution_name: Distribution (package) name as installed by pip
+        spec: PEP 440 specifier (e.g., '>=24.0.0', '>=2.4,<2.6')
+        install_hint: Optional string suggesting how to install (e.g., "pip install cupy-cuda12x")
+        extra_message: Optional extra context to append to the error
+        hard_fail: Whether to raise on failure
+    """
+    try:
+        return check_version_spec(distribution_name, spec, hard_fail=True)
+    except ImportError as e:
+        if not hard_fail:
+            return False
+        msg_parts = [str(e)]
+        # If not provided, and we have a hint above, use it:
+        if install_hint is None and distribution_name in install_cmds:
+            install_hint = install_cmds[distribution_name]
+        # If not provided, and we have extra info above, use it:
+        if extra_message is None and distribution_name in extra_info:
+            extra_message = extra_info[distribution_name]
+
+        if install_hint:
+            msg_parts.append(f"Install hint: {install_hint}")
+        if extra_message:
+            msg_parts.append(extra_message)
+        raise ImportError(" | ".join(msg_parts))
diff --git a/physicsnemo/datapipes/climate/climate.py b/physicsnemo/datapipes/climate/climate.py
index a7d041404b..264cbd19e9 100644
--- a/physicsnemo/datapipes/climate/climate.py
+++ b/physicsnemo/datapipes/climate/climate.py
@@ -46,7 +46,7 @@ from physicsnemo.datapipes.climate.utils.zenith_angle import cos_zenith_angle
 
 from physicsnemo.datapipes.datapipe import Datapipe
 from physicsnemo.datapipes.meta import DatapipeMetaData
-from physicsnemo.launch.logging import PythonLogger
+from physicsnemo.utils.logging import PythonLogger
 
 Tensor = torch.Tensor
 
diff --git a/physicsnemo/distributed/manager.py b/physicsnemo/distributed/manager.py
index 890222ad25..2d51d2ca4f 100644
--- a/physicsnemo/distributed/manager.py
+++ 
b/physicsnemo/distributed/manager.py @@ -25,7 +25,7 @@ import torch import torch.distributed as dist -from physicsnemo.core.version_check import check_min_version, require_version +from physicsnemo.core.version_check import check_version_spec, require_version_spec from physicsnemo.distributed.config import ProcessGroupConfig, ProcessGroupNode # warnings.simplefilter("default", DeprecationWarning) @@ -179,7 +179,7 @@ def global_mesh(self): """ # Properties don't mesh with decorators. So in this function, I call the check manually: - check_min_version("torch", "2.4") + check_version_spec("torch", ">=2.4", hard_fail=True) if self._global_mesh is None: # Fully flat mesh (1D) by default: @@ -187,14 +187,14 @@ def global_mesh(self): return self._global_mesh - @require_version("torch", "2.4") + @require_version_spec("torch", ">=2.4") def mesh_names(self): """ Return mesh axis names """ return self._mesh_dims.keys() - @require_version("torch", "2.4") + @require_version_spec("torch", ">=2.4") def mesh_sizes(self): """ Return mesh axis sizes @@ -214,7 +214,7 @@ def group(self, name=None): else: raise PhysicsNeMoUndefinedGroupError(name) - @require_version("torch", "2.4") + @require_version_spec("torch", ">=2.4") def mesh(self, name=None): """ Return a device_mesh with the given name. @@ -434,7 +434,7 @@ def initialize(): # Set per rank numpy random seed for data sampling np.random.seed(seed=DistributedManager().rank) - @require_version("torch", "2.4") + @require_version_spec("torch", ">=2.4") def initialize_mesh( self, mesh_shape: Tuple[int, ...], mesh_dim_names: Tuple[str, ...] ) -> "torch.distributed.DeviceMesh": @@ -521,7 +521,7 @@ def initialize_mesh( return self._global_mesh # Device mesh available in torch 2.4 or higher - @require_version("torch", "2.4") + @require_version_spec("torch", ">=2.4") def get_mesh_group(self, mesh: "dist.DeviceMesh") -> dist.ProcessGroup: """ Get the process group for a given mesh. diff --git a/physicsnemo/models/diffusion/song_unet.py b/physicsnemo/models/diffusion/song_unet.py index 7160168488..8b7852139b 100644 --- a/physicsnemo/models/diffusion/song_unet.py +++ b/physicsnemo/models/diffusion/song_unet.py @@ -234,7 +234,7 @@ class SongUNet(Module): architectures. Despite the name, these embeddings encode temporal information about the diffusion process rather than spatial position information. • Limitations on input image resolution: for a model that has :math:`N` levels, - the latent state :math:`\mathbf{x}` must have resolution that is a multiple of :math:`2^N` in each dimension. + the latent state :math:`\mathbf{x}` must have resolution that is a multiple of :math:`2^{N-1}` in each dimension. This is due to a limitation in the decoder that does not support shape mismatch in the residual connections from the encoder to the decoder. 
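+      For example, a model with :math:`N = 4` levels requires each spatial dimension to be a multiple of
+      :math:`2^{3} = 8`, so a :math:`96 \times 64` latent state is valid while :math:`100 \times 64` is not.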
For images that do not match this requirement, it is recommended to interpolate your data on a grid of the required resolution @@ -337,7 +337,7 @@ def __init__( self.img_shape_x = img_resolution[1] self._num_levels = len(channel_mult) - self._input_shape_mult = 2**self._num_levels + self._input_shape_mult = 2 ** (self._num_levels - 1) # set the threshold for checkpointing based on image resolution self.checkpoint_threshold = ( @@ -534,7 +534,7 @@ def forward(self, x, noise_labels, class_labels, augment_labels=None): f"got {x.ndim}D tensor with shape {tuple(x.shape)}" ) - # Check spatial dimensions are powers of 2 or multiples of 2^N + # Check spatial dimensions are powers of 2 or multiples of 2^{N-1} for d in x.shape[-2:]: # Check if d is a power of 2 is_power_of_2 = (d & (d - 1)) == 0 and d > 0 @@ -545,7 +545,7 @@ def forward(self, x, noise_labels, class_labels, augment_labels=None): ): raise ValueError( f"Input spatial dimensions ({x.shape[-2:]}) must be " - f"either powers of 2 or multiples of 2**N where " + f"either powers of 2 or multiples of 2**(N-1) where " f"N (={self._num_levels}) is the number of levels " f"in the U-Net." ) diff --git a/physicsnemo/models/domino/utils/__init__.py b/physicsnemo/models/domino/utils/__init__.py index 35b0d1b0f6..0acd78e250 100644 --- a/physicsnemo/models/domino/utils/__init__.py +++ b/physicsnemo/models/domino/utils/__init__.py @@ -21,6 +21,7 @@ calculate_pos_encoding, combine_dict, create_grid, + get_filenames, mean_std_sampling, nd_interpolator, normalize, diff --git a/physicsnemo/nn/neighbors/_knn/_cuml_impl.py b/physicsnemo/nn/neighbors/_knn/_cuml_impl.py index 0f40a495fc..e9fd6354ae 100644 --- a/physicsnemo/nn/neighbors/_knn/_cuml_impl.py +++ b/physicsnemo/nn/neighbors/_knn/_cuml_impl.py @@ -14,15 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. + +import importlib + import torch from physicsnemo.core.version_check import check_min_version CUML_AVAILABLE = check_min_version("cuml", "24.0.0", hard_fail=False) +CUPY_AVAILABLE = check_min_version("cupy", "13.0.0", hard_fail=False) -if CUML_AVAILABLE: - import cuml - import cupy as cp +if CUML_AVAILABLE and CUPY_AVAILABLE: + cuml = importlib.import_module("cuml") + cp = importlib.import_module("cupy") @torch.library.custom_op("physicsnemo::knn_cuml", mutates_args=()) def knn_impl( diff --git a/physicsnemo/nn/neighbors/_knn/_scipy_impl.py b/physicsnemo/nn/neighbors/_knn/_scipy_impl.py index e28ec502e2..894d20bde6 100644 --- a/physicsnemo/nn/neighbors/_knn/_scipy_impl.py +++ b/physicsnemo/nn/neighbors/_knn/_scipy_impl.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib + import torch from physicsnemo.core.version_check import check_min_version @@ -21,7 +23,7 @@ SCIPY_AVAILABLE = check_min_version("scipy", "1.7.0", hard_fail=False) if SCIPY_AVAILABLE: - from scipy.spatial import KDTree + KDTree = importlib.import_module("scipy.spatial").KDTree @torch.library.custom_op("physicsnemo::knn_scipy", mutates_args=()) def knn_impl( diff --git a/physicsnemo/utils/logging/__init__.py b/physicsnemo/utils/logging/__init__.py index 766dee4ec7..59a36a8ab6 100644 --- a/physicsnemo/utils/logging/__init__.py +++ b/physicsnemo/utils/logging/__init__.py @@ -15,4 +15,4 @@ # limitations under the License. 
from .console import PythonLogger, RankZeroLoggingWrapper -from .launch import LaunchLogger +# from .launch import LaunchLogger diff --git a/physicsnemo/utils/logging/launch.py b/physicsnemo/utils/logging/launch.py index 4e97cca893..c95892aeae 100644 --- a/physicsnemo/utils/logging/launch.py +++ b/physicsnemo/utils/logging/launch.py @@ -27,6 +27,9 @@ from physicsnemo.distributed import DistributedManager, reduce_loss from .console import PythonLogger +from .wandb import _WANDB_AVAILABLE +from .wandb import alert as _wandb_alert +from .wandb import wandb as _wandb class LaunchLogger(object): @@ -130,11 +133,12 @@ def __init__( self.total_iteration_index = 0 # Set x axis metric to epoch for this namespace - if self.wandb_backend: - import wandb - - wandb.define_metric(name_space + "/mini_batch_*", step_metric="iter") - wandb.define_metric(name_space + "/*", step_metric="epoch") + if self.wandb_backend and _WANDB_AVAILABLE: + _wandb.define_metric(name_space + "/mini_batch_*", step_metric="iter") + _wandb.define_metric(name_space + "/*", step_metric="epoch") + elif self.wandb_backend: + self.pyLogger.warning("WandB not installed, turning off") + self.__class__.wandb_backend = False def log_minibatch(self, losses: Dict[str, float]): """Logs metrics for a mini-batch epoch @@ -283,16 +287,15 @@ def __exit__(self, exc_type, exc_value, exc_tb): and self.root and self.epoch % self.epoch_alert_freq == 0 ): - if self.wandb_backend: - import wandb - - from .wandb import alert - + if self.wandb_backend and _WANDB_AVAILABLE: # TODO: Make this a little more informative? - alert( + _wandb_alert( title=f"{sys.argv[0]} training progress report", - text=f"Run {wandb.run.name} is at epoch {self.epoch}.", + text=f"Run {_wandb.run.name} is at epoch {self.epoch}.", ) + elif self.wandb_backend: + self.pyLogger.warning("WandB not installed, turning off") + self.__class__.wandb_backend = False def _log_backends( self, @@ -324,14 +327,15 @@ def _log_backends( ) # WandB Logging - if self.wandb_backend: - import wandb - + if self.wandb_backend and _WANDB_AVAILABLE: # For WandB send step in as a metric # Step argument in lod function does not work with multiple log calls at # different intervals metric_dict[step[0]] = step[1] - wandb.log(metric_dict) + _wandb.log(metric_dict) + elif self.wandb_backend: + self.pyLogger.warning("WandB not installed, turning off") + self.__class__.wandb_backend = False def log_figure( self, @@ -357,10 +361,11 @@ def log_figure( if dist.rank != 0: return - if self.wandb_backend: - import wandb - - wandb.log({artifact_file: figure}) + if self.wandb_backend and _WANDB_AVAILABLE: + _wandb.log({artifact_file: figure}) + elif self.wandb_backend: + self.pyLogger.warning("WandB not installed, turning off") + self.__class__.wandb_backend = False if self.mlflow_backend: self.mlflow_client.log_figure( @@ -414,16 +419,17 @@ def initialize(use_wandb: bool = False, use_mlflow: bool = False): Use MLFlow logging, by default False """ if use_wandb: - import wandb - - if wandb.run is None: + if not _WANDB_AVAILABLE: + PythonLogger().warning("WandB not installed, turning off") + use_wandb = False + elif _wandb.run is None: PythonLogger().warning("WandB not initialized, turning off") use_wandb = False if use_wandb: LaunchLogger.toggle_wandb(True) - wandb.define_metric("epoch") - wandb.define_metric("iter") + _wandb.define_metric("epoch") + _wandb.define_metric("iter") # let only root process log to mlflow if DistributedManager.is_initialized(): diff --git a/physicsnemo/utils/logging/mlflow.py 
b/physicsnemo/utils/logging/mlflow.py index fbf3de64a4..f0bd442752 100644 --- a/physicsnemo/utils/logging/mlflow.py +++ b/physicsnemo/utils/logging/mlflow.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib import os import time from datetime import datetime @@ -22,178 +23,198 @@ import torch -try: - import mlflow # noqa: F401 for docs - from mlflow.entities.run import Run - from mlflow.tracking import MlflowClient -except ImportError: - raise ImportError( - "These utilities require the MLFlow library. Install MLFlow using `pip install mlflow`. " - + "For more info, refer: https://www.mlflow.org/docs/2.5.0/quickstart.html#install-mlflow" - ) - +from physicsnemo.core.version_check import check_min_version from physicsnemo.distributed import DistributedManager from .console import PythonLogger from .launch import LaunchLogger +MLFLOW_AVAILABLE = check_min_version("mlflow", "2.5.0", hard_fail=False) + + logger = PythonLogger("mlflow") +if MLFLOW_AVAILABLE: + mlflow = importlib.import_module("mlflow") + Run = importlib.import_module("mlflow.entities.run").Run + MlflowClient = importlib.import_module("mlflow.tracking").MlflowClient + + def initialize_mlflow( + experiment_name: str, + experiment_desc: str = None, + run_name: str = None, + run_desc: str = None, + user_name: str = None, + mode: Literal["offline", "online", "ngc"] = "offline", + tracking_location: str = None, + artifact_location: str = None, + ) -> Tuple[MlflowClient, Run]: + """Initializes MLFlow logging client and run. + + Parameters + ---------- + experiment_name : str + Experiment name + experiment_desc : str, optional + Experiment description, by default None + run_name : str, optional + Run name, by default None + run_desc : str, optional + Run description, by default None + user_name : str, optional + User name, by default None + mode : str, optional + MLFlow mode. Supports "offline", "online" and "ngc". Offline mode records logs to + local file system. Online mode is for remote tracking servers. NGC is specific + standardized setup for NGC runs, default "offline" + tracking_location : str, optional + Tracking location for MLFlow. For offline this would be an absolute folder directory. + For online mode this would be a http URI or databricks. For NGC, this option is + ignored, by default "//mlruns" + artifact_location : str, optional + Optional separate artifact location, by default None + + Note + ---- + For NGC mode, one needs to mount a NGC workspace / folder system with a metric folder + at `/mlflow/mlflow_metrics/` and a artifact folder at `/mlflow/mlflow_artifacts/`. + + Note + ---- + This will set up PhysicsNeMo Launch logger for MLFlow logging. Only one MLFlow logging + client is supported with the PhysicsNeMo Launch logger. 
+ + Returns + ------- + Tuple[MlflowClient, Run] + Returns MLFlow logging client and active run object + """ + dist = DistributedManager() + if dist.rank != 0: # only root process should be logging to mlflow + return + + start_time = datetime.now().astimezone() + time_string = start_time.strftime("%m/%d/%y_%H-%M-%S") + group_name = f"{run_name}_{time_string}" + + # Set default value here for Hydra + if tracking_location is None: + tracking_location = str(Path("./mlruns").absolute()) + + # Set up URI (remote or local) + if mode == "online": + tracking_uri = tracking_location + elif mode == "offline": + if not tracking_location.startswith("file://"): + tracking_location = "file://" + tracking_location + tracking_uri = tracking_location + elif mode == "ngc": + if not Path("/mlflow/mlflow_metrics").is_dir(): + raise IOError( + "NGC MLFlow config select but metrics folder '/mlflow/mlflow_metrics'" + + " not found. Aborting MLFlow setup." + ) + return + + if not Path("/mlflow/mlflow_artifacts").is_dir(): + raise IOError( + "NGC MLFlow config select but artifact folder '/mlflow/mlflow_artifacts'" + + " not found. Aborting MLFlow setup." + ) + return + tracking_uri = "file:///mlflow/mlflow_metrics" + artifact_location = "file:///mlflow/mlflow_artifacts" + else: + logger.warning(f"Unsupported MLFlow mode '{mode}' provided") + tracking_uri = "file://" + str(Path("./mlruns").absolute()) + + mlflow.set_tracking_uri(tracking_uri) + client = MlflowClient() + + check_mlflow_logged_in(client) -def initialize_mlflow( - experiment_name: str, - experiment_desc: str = None, - run_name: str = None, - run_desc: str = None, - user_name: str = None, - mode: Literal["offline", "online", "ngc"] = "offline", - tracking_location: str = None, - artifact_location: str = None, -) -> Tuple[MlflowClient, Run]: - """Initializes MLFlow logging client and run. - - Parameters - ---------- - experiment_name : str - Experiment name - experiment_desc : str, optional - Experiment description, by default None - run_name : str, optional - Run name, by default None - run_desc : str, optional - Run description, by default None - user_name : str, optional - User name, by default None - mode : str, optional - MLFlow mode. Supports "offline", "online" and "ngc". Offline mode records logs to - local file system. Online mode is for remote tracking servers. NGC is specific - standardized setup for NGC runs, default "offline" - tracking_location : str, optional - Tracking location for MLFlow. For offline this would be an absolute folder directory. - For online mode this would be a http URI or databricks. For NGC, this option is - ignored, by default "//mlruns" - artifact_location : str, optional - Optional separate artifact location, by default None - - Note - ---- - For NGC mode, one needs to mount a NGC workspace / folder system with a metric folder - at `/mlflow/mlflow_metrics/` and a artifact folder at `/mlflow/mlflow_artifacts/`. - - Note - ---- - This will set up PhysicsNeMo Launch logger for MLFlow logging. Only one MLFlow logging - client is supported with the PhysicsNeMo Launch logger. 
- - Returns - ------- - Tuple[MlflowClient, Run] - Returns MLFlow logging client and active run object - """ - dist = DistributedManager() - if dist.rank != 0: # only root process should be logging to mlflow - return - - start_time = datetime.now().astimezone() - time_string = start_time.strftime("%m/%d/%y_%H-%M-%S") - group_name = f"{run_name}_{time_string}" - - # Set default value here for Hydra - if tracking_location is None: - tracking_location = str(Path("./mlruns").absolute()) - - # Set up URI (remote or local) - if mode == "online": - tracking_uri = tracking_location - elif mode == "offline": - if not tracking_location.startswith("file://"): - tracking_location = "file://" + tracking_location - tracking_uri = tracking_location - elif mode == "ngc": - if not Path("/mlflow/mlflow_metrics").is_dir(): - raise IOError( - "NGC MLFlow config select but metrics folder '/mlflow/mlflow_metrics'" - + " not found. Aborting MLFlow setup." + experiment = client.get_experiment_by_name(experiment_name) + # If experiment does not exist create one + if experiment is None: + logger.info(f"No {experiment_name} experiment found, creating...") + experiment_id = client.create_experiment( + experiment_name, artifact_location=artifact_location ) - return + client.set_experiment_tag( + experiment_id, "mlflow.note.content", experiment_desc + ) + else: + logger.success(f"Existing {experiment_name} experiment found") + experiment_id = experiment.experiment_id - if not Path("/mlflow/mlflow_artifacts").is_dir(): - raise IOError( - "NGC MLFlow config select but artifact folder '/mlflow/mlflow_artifacts'" - + " not found. Aborting MLFlow setup." + # Create an run and set its tags + run = client.create_run( + experiment_id, tags={"mlflow.user": user_name}, run_name=run_name + ) + client.set_tag(run.info.run_id, "mlflow.note.content", run_desc) + + start_time = datetime.now().astimezone() + time_string = start_time.strftime("%m/%d/%y %H:%M:%S") + client.set_tag(run.info.run_id, "date", time_string) + client.set_tag(run.info.run_id, "host", os.uname()[1]) + if torch.cuda.is_available(): + client.set_tag( + run.info.run_id, "gpu", torch.cuda.get_device_name(dist.device) ) - return - tracking_uri = "file:///mlflow/mlflow_metrics" - artifact_location = "file:///mlflow/mlflow_artifacts" - else: - logger.warning(f"Unsupported MLFlow mode '{mode}' provided") - tracking_uri = "file://" + str(Path("./mlruns").absolute()) - - mlflow.set_tracking_uri(tracking_uri) - client = MlflowClient() - - check_mlflow_logged_in(client) - - experiment = client.get_experiment_by_name(experiment_name) - # If experiment does not exist create one - if experiment is None: - logger.info(f"No {experiment_name} experiment found, creating...") - experiment_id = client.create_experiment( - experiment_name, artifact_location=artifact_location + client.set_tag(run.info.run_id, "group", group_name) + + run = client.get_run(run.info.run_id) + + # Set run instance in PhysicsNeMo logger + LaunchLogger.mlflow_run = run + LaunchLogger.mlflow_client = client + + return client, run + + def check_mlflow_logged_in(client: MlflowClient): + """Checks to see if MLFlow URI is functioning + + This isn't the best solution right now and overrides http timeout. Can update if MLFlow + use is increased. 
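+
+        Example
+        -------
+        Typically called right after the client is constructed, mirroring the
+        (illustrative) setup in ``initialize_mlflow`` above::
+
+            mlflow.set_tracking_uri(tracking_uri)
+            client = MlflowClient()
+            check_mlflow_logged_in(client)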
+ """ + + logger.warning( + "Checking MLFlow logging location is working (if this hangs it's not)" + ) + t0 = os.environ.get("MLFLOW_HTTP_REQUEST_TIMEOUT", None) + try: + # Adjust http timeout to 5 seconds + os.environ["MLFLOW_HTTP_REQUEST_TIMEOUT"] = ( + str(max(int(t0), 5)) if t0 else "5" + ) + experiment = client.create_experiment(f"test-{int(time.time())}") + client.delete_experiment(experiment) + + except Exception as e: + logger.error("Failed to validate MLFlow logging location works") + raise e + finally: + # Restore http request + if t0: + os.environ["MLFLOW_HTTP_REQUEST_TIMEOUT"] = t0 + else: + del os.environ["MLFLOW_HTTP_REQUEST_TIMEOUT"] + + logger.success("MLFlow logging location is working") + +else: + + def initialize_mlflow( + *args, + **kwargs, + ): + raise ImportError( + "These utilities require the MLFlow library. Install MLFlow using `pip install mlflow`. " + + "For more info, refer: https://www.mlflow.org/docs/2.5.0/quickstart.html#install-mlflow" ) - client.set_experiment_tag(experiment_id, "mlflow.note.content", experiment_desc) - else: - logger.success(f"Existing {experiment_name} experiment found") - experiment_id = experiment.experiment_id - - # Create an run and set its tags - run = client.create_run( - experiment_id, tags={"mlflow.user": user_name}, run_name=run_name - ) - client.set_tag(run.info.run_id, "mlflow.note.content", run_desc) - - start_time = datetime.now().astimezone() - time_string = start_time.strftime("%m/%d/%y %H:%M:%S") - client.set_tag(run.info.run_id, "date", time_string) - client.set_tag(run.info.run_id, "host", os.uname()[1]) - if torch.cuda.is_available(): - client.set_tag(run.info.run_id, "gpu", torch.cuda.get_device_name(dist.device)) - client.set_tag(run.info.run_id, "group", group_name) - - run = client.get_run(run.info.run_id) - - # Set run instance in PhysicsNeMo logger - LaunchLogger.mlflow_run = run - LaunchLogger.mlflow_client = client - - return client, run - - -def check_mlflow_logged_in(client: MlflowClient): - """Checks to see if MLFlow URI is functioning - - This isn't the best solution right now and overrides http timeout. Can update if MLFlow - use is increased. - """ - - logger.warning( - "Checking MLFlow logging location is working (if this hangs it's not)" - ) - t0 = os.environ.get("MLFLOW_HTTP_REQUEST_TIMEOUT", None) - try: - # Adjust http timeout to 5 seconds - os.environ["MLFLOW_HTTP_REQUEST_TIMEOUT"] = str(max(int(t0), 5)) if t0 else "5" - experiment = client.create_experiment(f"test-{int(time.time())}") - client.delete_experiment(experiment) - - except Exception as e: - logger.error("Failed to validate MLFlow logging location works") - raise e - finally: - # Restore http request - if t0: - os.environ["MLFLOW_HTTP_REQUEST_TIMEOUT"] = t0 - else: - del os.environ["MLFLOW_HTTP_REQUEST_TIMEOUT"] - logger.success("MLFlow logging location is working") + def check_mlflow_logged_in(*args, **kwargs): + raise ImportError( + "These utilities require the MLFlow library. Install MLFlow using `pip install mlflow`. 
" + + "For more info, refer: https://www.mlflow.org/docs/2.5.0/quickstart.html#install-mlflow" + ) diff --git a/physicsnemo/utils/logging/wandb.py b/physicsnemo/utils/logging/wandb.py index e19042e943..a76651556d 100644 --- a/physicsnemo/utils/logging/wandb.py +++ b/physicsnemo/utils/logging/wandb.py @@ -16,121 +16,143 @@ """Weights and Biases Routines and Utilities""" +import importlib import logging import os from datetime import datetime from pathlib import Path from typing import Literal -import wandb -from wandb import AlertLevel - +from physicsnemo.core.version_check import check_min_version from physicsnemo.distributed import DistributedManager from .utils import create_ddp_group_tag -DEFAULT_WANDB_CONFIG = "~/.netrc" -logger = logging.getLogger(__name__) - -_WANDB_INITIALIZED = False - - -def initialize_wandb( - project: str, - entity: str, - name: str = "train", - group: str = None, - sync_tensorboard: bool = False, - save_code: bool = False, - resume: str = None, - wandb_id: str = None, - config=None, - mode: Literal["offline", "online", "disabled"] = "offline", - results_dir: str = None, - init_timeout: int = 90, -): - """Function to initialize wandb client with the weights and biases server. - - Parameters - ---------- - project : str - Name of the project to sync data with - entity : str, - Name of the wanbd entity - sync_tensorboard : bool, optional - sync tensorboard summary writer with wandb, by default False - save_code : bool, optional - Whether to push a copy of the code to wandb dashboard, by default False - name : str, optional - Name of the task running, by default "train" - group : str, optional - Group name of the task running. Good to set for ddp runs, by default None - resume: str, optional - Sets the resuming behavior. Options: "allow", "must", "never", "auto" or None, - by default None. - wandb_id: str, optional - A unique ID for this run, used for resuming. Used in conjunction with `resume` - parameter to enable experiment resuming. - See W&B documentation for more details: - https://docs.wandb.ai/guides/runs/resuming/ - config : optional - a dictionary-like object for saving inputs , like hyperparameters. - If dict, argparse or absl.flags, it will load the key value pairs into the - wandb.config object. If str, it will look for a yaml file by that name, - by default None. - mode: str, optional - Can be "offline", "online" or "disabled", by default "offline" - results_dir : str, optional - Output directory of the experiment, by default "//wandb" - init_timeout : int, optional - Timeout for wandb initialization, by default 90 seconds. 
- """ - - # Set default value here for Hydra - if results_dir is None: - results_dir = str(Path("./wandb").absolute()) - - wandb_dir = results_dir - if DistributedManager.is_initialized() and DistributedManager().distributed: - if group is None: - group = create_ddp_group_tag() - start_time = datetime.now().astimezone() - time_string = start_time.strftime("%m/%d/%y_%H:%M:%S") - wandb_name = f"{name}_Process_{DistributedManager().rank}_{time_string}" - else: - start_time = datetime.now().astimezone() - time_string = start_time.strftime("%m/%d/%y_%H:%M:%S") - wandb_name = f"{name}_{time_string}" - - if not os.path.exists(wandb_dir): - os.makedirs(wandb_dir, exist_ok=True) - - wandb.init( - project=project, - entity=entity, - sync_tensorboard=sync_tensorboard, - name=wandb_name, - resume=resume, - config=config, - mode=mode, - dir=wandb_dir, - group=group, - save_code=save_code, - id=wandb_id, - settings=wandb.Settings(init_timeout=init_timeout), - ) - - -def alert(title, text, duration=300, level=0, is_master=True): - """Send alert.""" - alert_levels = {0: AlertLevel.INFO, 1: AlertLevel.WARN, 2: AlertLevel.ERROR} - if is_wandb_initialized() and is_master: - wandb.alert( - title=title, text=text, level=alert_levels[level], wait_duration=duration +WANDB_AVAILABLE = check_min_version("wandb", "0.15.0", hard_fail=False) + +if WANDB_AVAILABLE: + wandb = importlib.import_module("wandb") + AlertLevel = importlib.import_module("wandb").AlertLevel + + DEFAULT_WANDB_CONFIG = "~/.netrc" + logger = logging.getLogger(__name__) + + _WANDB_INITIALIZED = False + + def initialize_wandb( + project: str, + entity: str, + name: str = "train", + group: str = None, + sync_tensorboard: bool = False, + save_code: bool = False, + resume: str = None, + wandb_id: str = None, + config=None, + mode: Literal["offline", "online", "disabled"] = "offline", + results_dir: str = None, + init_timeout: int = 90, + ): + """Function to initialize wandb client with the weights and biases server. + + Parameters + ---------- + project : str + Name of the project to sync data with + entity : str, + Name of the wanbd entity + sync_tensorboard : bool, optional + sync tensorboard summary writer with wandb, by default False + save_code : bool, optional + Whether to push a copy of the code to wandb dashboard, by default False + name : str, optional + Name of the task running, by default "train" + group : str, optional + Group name of the task running. Good to set for ddp runs, by default None + resume: str, optional + Sets the resuming behavior. Options: "allow", "must", "never", "auto" or None, + by default None. + wandb_id: str, optional + A unique ID for this run, used for resuming. Used in conjunction with `resume` + parameter to enable experiment resuming. + See W&B documentation for more details: + https://docs.wandb.ai/guides/runs/resuming/ + config : optional + a dictionary-like object for saving inputs , like hyperparameters. + If dict, argparse or absl.flags, it will load the key value pairs into the + wandb.config object. If str, it will look for a yaml file by that name, + by default None. + mode: str, optional + Can be "offline", "online" or "disabled", by default "offline" + results_dir : str, optional + Output directory of the experiment, by default "//wandb" + init_timeout : int, optional + Timeout for wandb initialization, by default 90 seconds. 
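+
+        Example
+        -------
+        A minimal offline-mode initialization (project and entity names are
+        purely illustrative) might look like::
+
+            initialize_wandb(
+                project="physicsnemo-example",
+                entity="my-team",
+                name="train",
+                mode="offline",
+            )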
+ """ + + # Set default value here for Hydra + if results_dir is None: + results_dir = str(Path("./wandb").absolute()) + + wandb_dir = results_dir + if DistributedManager.is_initialized() and DistributedManager().distributed: + if group is None: + group = create_ddp_group_tag() + start_time = datetime.now().astimezone() + time_string = start_time.strftime("%m/%d/%y_%H:%M:%S") + wandb_name = f"{name}_Process_{DistributedManager().rank}_{time_string}" + else: + start_time = datetime.now().astimezone() + time_string = start_time.strftime("%m/%d/%y_%H:%M:%S") + wandb_name = f"{name}_{time_string}" + + if not os.path.exists(wandb_dir): + os.makedirs(wandb_dir, exist_ok=True) + + wandb.init( + project=project, + entity=entity, + sync_tensorboard=sync_tensorboard, + name=wandb_name, + resume=resume, + config=config, + mode=mode, + dir=wandb_dir, + group=group, + save_code=save_code, + id=wandb_id, + settings=wandb.Settings(init_timeout=init_timeout), + ) + + def alert(title, text, duration=300, level=0, is_master=True): + """Send alert.""" + alert_levels = {0: AlertLevel.INFO, 1: AlertLevel.WARN, 2: AlertLevel.ERROR} + if is_wandb_initialized() and is_master: + wandb.alert( + title=title, + text=text, + level=alert_levels[level], + wait_duration=duration, + ) + + def is_wandb_initialized(): + """Check if wandb has been initialized.""" + global _WANDB_INITIALIZED + return _WANDB_INITIALIZED + +else: + + def _raise_wandb_not_installed(): + raise ImportError( + "These utilities require the WandB library. Install WandB using `pip install wandb`. " + + "For more info, refer: https://wandb.ai/site" ) + def initialize_wandb(*args, **kwargs): + _raise_wandb_not_installed() + + def alert(*args, **kwargs): + _raise_wandb_not_installed() -def is_wandb_initialized(): - """Check if wandb has been initialized.""" - global _WANDB_INITIALIZED - return _WANDB_INITIALIZED + def is_wandb_initialized(*args, **kwargs): + _raise_wandb_not_installed() diff --git a/prevent_untracked_imports.py b/prevent_untracked_imports.py new file mode 100644 index 0000000000..774471dc08 --- /dev/null +++ b/prevent_untracked_imports.py @@ -0,0 +1,202 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.util +import os +import sys +import sysconfig +from pathlib import Path +from typing import Dict, List, Set, Union + +import tomllib +from importlinter import Contract, ContractCheck, fields, output +from packaging.requirements import Requirement + +Dependency = Union[str, Dict[str, str]] + + +class ForbiddenImportContract(Contract): + """ + PhysicsNemo specific contract to prevent external imports + that are not included in requirements. + + This will, for each sub-package, check the external imports and ensure + via uv that the list dependencies encompass the entire import graph. 
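+
+    Example
+    -------
+    Contracts of this type are declared in the import-linter configuration and
+    are executed through the standard import-linter CLI, e.g. (command shown
+    for illustration)::
+
+        lint-imports --config .importlinter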
+    """
+
+    container = fields.StringField()
+    dependency_group = fields.StringField()
+
+    def check(self, graph, verbose):
+        output.verbose_print(
+            verbose,
+            f"Getting import details from {self.container} vs uv group {self.dependency_group}...",
+        )
+
+        upstream_modules = graph.find_upstream_modules(self.container)
+
+        # Remove any modules that start with "physicsnemo":
+        upstream_modules = set(
+            module
+            for module in upstream_modules
+            if not module.startswith("physicsnemo")
+        )
+
+        upstream_external_modules = remove_standard_library(upstream_modules)
+
+        # Now, read the tree from pyproject.toml:
+        dependency_tree = resolve_dependency_group_no_versions(
+            Path("pyproject.toml"), self.dependency_group
+        )
+
+        broken_imports = upstream_external_modules - dependency_tree
+        violations = {}
+
+        for broken_import in broken_imports:
+            violations[broken_import] = graph.find_modules_that_directly_import(
+                broken_import
+            )
+            violations[broken_import] = [
+                v for v in violations[broken_import] if self.container in v
+            ]
+
+        return ContractCheck(
+            kept=len(broken_imports) == 0,
+            metadata={
+                "broken_imports": list(broken_imports),
+                "violations": violations,
+            },
+        )
+
+    def render_broken_contract(self, check):
+        for broken_import in check.metadata["broken_imports"]:
+            violations = ", ".join(check.metadata["violations"][broken_import])
+            output.print_error(
+                f"{self.container} is not allowed to import {broken_import} (from {violations})",
+                bold=True,
+            )
+            output.new_line()
+        output.new_line()
+
+
+def resolve_dependency_group_no_versions(
+    pyproject_path: str | Path, group_name: str
+) -> Set[str]:
+    """
+    Open a uv-style pyproject.toml, recursively resolve a dependency group,
+    and strip version specifiers from all dependencies.
+    """
+    pyproject_path = Path(pyproject_path)
+    with pyproject_path.open("rb") as f:
+        data = tomllib.load(f)
+
+    dep_groups: Dict[str, List[Dependency]] = data.get("dependency-groups", {})
+
+    if group_name not in dep_groups:
+        raise KeyError(f"Dependency group '{group_name}' not found")
+
+    def _resolve(group: str, seen: set[str] = None) -> List[str]:
+        if seen is None:
+            seen = set()
+        if group in seen:
+            return []
+        seen.add(group)
+        deps: List[str] = []
+        for item in dep_groups.get(group, []):
+            if isinstance(item, str):
+                # strip version using packaging
+                deps.append(Requirement(item).name)
+            elif isinstance(item, dict) and "include-group" in item:
+                deps.extend(_resolve(item["include-group"], seen))
+            else:
+                raise ValueError(f"Unknown dependency format: {item}")
+        return deps
+
+    # remove duplicates
+    resolved = _resolve(group_name)
+    seen_ordered = set()
+    return set([d for d in resolved if not (d in seen_ordered or seen_ordered.add(d))])
+
+
+def flatten_deps(tree: Dict) -> Set[str]:
+    """Flatten nested dependency dict into a set of package names."""
+    packages = set()
+
+    def recurse(d: Dict):
+        for name, info in d.items():
+            packages.add(name.replace("-", "_"))  # normalize for imports
+            recurse(info["dependencies"])
+
+    recurse(tree)
+    return packages
+
+
+def remove_standard_library(packages: Set[str]) -> Set[str]:
+    """Remove standard library packages from the set of packages.
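+
+    For example, given ``{"os", "numpy", "torch"}`` this should keep only the
+    third-party names, i.e. roughly ``{"numpy", "torch"}``.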
+
+    Heuristics:
+    - Builtins (sys.builtin_module_names)
+    - sys.stdlib_module_names (when available, Python 3.10+)
+    - importlib spec origin located within sysconfig stdlib/platstdlib
+    - 'built-in' or 'frozen' origins
+    """
+    builtin_names = set(sys.builtin_module_names)
+    stdlib_names = set(getattr(sys, "stdlib_module_names", ()))
+
+    stdlib_dirs = {
+        d
+        for d in {
+            sysconfig.get_path("stdlib"),
+            sysconfig.get_path("platstdlib"),
+        }
+        if d
+    }
+    stdlib_dirs = {os.path.realpath(d) for d in stdlib_dirs}
+
+    def is_in_stdlib_path(path: str) -> bool:
+        if not path:
+            return False
+        real = os.path.realpath(path)
+        for d in stdlib_dirs:
+            # Match dir itself or any descendant
+            if real == d or real.startswith(d + os.sep):
+                return True
+        return False
+
+    def is_stdlib(mod_name: str) -> bool:
+        # Fast checks
+        if mod_name in builtin_names or mod_name in stdlib_names:
+            return True
+
+        spec = importlib.util.find_spec(mod_name)
+        if spec is None:
+            return False
+
+        # Built-in/frozen indicators
+        if spec.origin in ("built-in", "frozen"):
+            return True
+
+        # Package locations
+        if spec.submodule_search_locations:
+            for loc in spec.submodule_search_locations:
+                if is_in_stdlib_path(loc):
+                    return True
+            return False
+
+        # Modules
+        return is_in_stdlib_path(spec.origin)
+
+    return {p for p in packages if not is_stdlib(p)}
diff --git a/pyproject.toml b/pyproject.toml
index 95b30c437a..ce75a83248 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,76 @@
-[build-system]
-requires = ["hatchling"]
-build-backend = "hatchling.build"
+[tool.uv]
+no-build-isolation-package = ["torch_scatter"]
+managed = true
+default-groups = ["physicsnemo"]
+
+# The dependency-group tree is critically important for physicsnemo.
+# Here, we list dependencies for each physicsnemo package. Optional
+# dependencies are listed with the name `package`-extras. Dependencies
+# are chained together: for example, everything in core is a dep of the
+# entire repo, but utils-extras only shows up for subsequent *-extras
+# lists.
+#
+# We do this to ensure a consistent install path for targeted levels of the
+# repository. If you just want the distributed manager, for example, you can
+# target `distributed` instead of `physicsnemo` and you're up and running.
+#
+# These lists are the SINGLE SOURCE OF TRUTH. Models are also included
+# below, to make single-model installation easier.
+#
+# In general, we do not draw a finer line than "required" and "extra".
+# So, physicsnemo.nn's requirements do not include scipy and cuml, but the
+# "extra" version includes BOTH.
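+#
+# As a purely illustrative example (assuming a uv-managed checkout), a
+# contributor who only needs the distributed manager could sync just that
+# group with something like:
+#
+#   uv sync --only-group distributed
+#
+# instead of installing the full `physicsnemo` group.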
+ +[dependency-groups] +core = [ + "torch>=2.4.0", + "tqdm>=4.60.0", + "requests", + "GitPython", +] +# no core-extras +distributed = [ + "numpy>=1.22.4", + {include-group = "core"} +] +# no distributed-extras +utils = [ + "termcolor", + "onnx", + "pandas", + "nvtx", + {include-group = "distributed"}, +] +utils-extras = [ + "wandb", + "mlflow", + "line_profiler", + "vtx", + "warp-lang", +] +nn = [ + "einops", + "warp-lang", + {include-group= "utils"} +] +nn-extras = [ + "cuml", + "scipy", + "cupy", +] +physicsnemo = [ + {include-group = "nn"}, + {include-group = "utils"}, +] +physicsnemo-extras = [ + {include-group = "physicsnemo"}, + {include-group = "nn-extras"}, + {include-group = "utils-extras"}, +] +dev = [ + "pytest", + "import-linter" +] [project] name = "nvidia-physicsnemo" @@ -11,22 +81,6 @@ description = "A deep learning framework for AI-driven multi-physics systems" readme = "README.md" requires-python = ">=3.10" license = "Apache-2.0" -dependencies = [ - "certifi>=2023.7.22", - "fsspec>=2023.1.0", - "numpy>=1.22.4", - "onnx>=1.14.0", - "packaging>=24.2", - "s3fs>=2023.5.0", - "setuptools>=77.0.3", - "timm>=1.0.0", - "torch>=2.4.0", - "tqdm>=4.60.0", - "treelib>=1.2.5", - "xarray>=2023.1.0", - "zarr>=2.14.2", - -] classifiers = [ "Programming Language :: Python :: 3", "Operating System :: OS Independent", @@ -39,96 +93,6 @@ Documentation = "https://docs.nvidia.com/physicsnemo/index.html#core" Issues = "https://github.com/NVIDIA/physicsnemo/issues" Changelog = "https://github.com/NVIDIA/physicsnemo/blob/main/CHANGELOG.md" -[project.optional-dependencies] -phsysicsnemo-core = [ - "packaging", - "fsspec", - "requests", - "s3fs", -] - - -launch = [ - "hydra-core>=1.2.0", - "termcolor>=2.1.1", - "wandb>=0.13.7", - "mlflow>=2.1.1", - "pydantic>=2.4.2", - "imageio>=2.28.1", - "moviepy>=1.0.3", -] - -dev = [ - "pytest>=6.0.0", - "pyyaml>=6.0", - "interrogate==1.5.0", - "coverage==6.5.0", - "ruff==0.12.5", - "moto[s3]>=5.0.28", - "pre-commit>=4.0.0", - "pytest-timeout", - "import-linter", -] - -# makani = [ -# # TODO(akamenev): PyPI does not allow direct URL deps, update once Makani is in PyPI -# # "makani @ git+https://github.com/NVIDIA/modulus-makani.git@v0.1.0", -# "torch-harmonics>=0.6.5,<0.7.1", -# "tensorly>=0.8.1", -# "tensorly-torch>=0.4.0", -# ] - -# fignet = [ -# "jaxtyping>=0.2", -# "torch_scatter>=2.1", -# "torchinfo>=1.8", -# "warp-lang>=1.0", -# "webdataset>=0.2", -# ] - -storage = [ - "multi-storage-client[boto3]>=0.33.0", -] - - - - -[tool.hatch.version] -path = "physicsnemo/__init__.py" - -all = [ - "nvidia_dali_cuda120>=1.35.0", - "h5py>=3.7.0", - "netcdf4>=1.6.3", - "ruamel.yaml>=0.17.22", - "scikit-learn>=1.0.2", - "scikit-image>=0.24.0", - "warp-lang>=1.0", - "vtk>=9.2.6", - "pyvista>=0.40.1", - "cftime>=1.6.2", - "einops>=0.7.0", - "pyspng>=0.1.0", - "shapely>=2.0.6", - "pytz>=2023.3", - "nvtx>=0.2.8", - "nvidia-physicsnemo[launch]", - "nvidia-physicsnemo[dev]", - "nvidia-physicsnemo[makani]", - "nvidia-physicsnemo[fignet]", - "nvidia-physicsnemo[storage]", -] - - - -[tool.hatch.build] -include = [ - "physicsnemo", - "physicsnemo/*", - "LICENSE" -] -exclude = ["tests", "examples"] - [tool.ruff] # Enable flake8/pycodestyle (`E`), Pyflakes (`F`), flake8-bandit (`S`), @@ -150,25 +114,3 @@ exclude = ["docs", "physicsnemo/experimental"] # Ignore `S101` (assertions) in all `test` files. 
"test/*.py" = ["S101"] - -# ==== UV configuration ==== -[tool.uv] -no-build-isolation-package = ["torch_scatter"] -managed = false - -[project.entry-points."physicsnemo.models"] -AFNO = "physicsnemo.models.afno:AFNO" -DLWP = "physicsnemo.models.dlwp:DLWP" -FNO = "physicsnemo.models.fno:FNO" -GraphCastNet = "physicsnemo.models.graphcast:GraphCastNet" -MeshGraphNet = "physicsnemo.models.meshgraphnet:MeshGraphNet" -FullyConnected = "physicsnemo.models.mlp:FullyConnected" -Pix2Pix = "physicsnemo.models.pix2pix:Pix2Pix" -One2ManyRNN = "physicsnemo.models.rnn:One2ManyRNN" -SRResNet = "physicsnemo.models.srrn:SRResNet" -Pangu = "physicsnemo.models.pangu:Pangu" -Fengwu = "physicsnemo.models.fengwu:Fengwu" -SwinRNN = "physicsnemo.models.swinvrnn:SwinRNN" -EDMPrecondSR = "physicsnemo.models.diffusion:EDMPrecondSR" -UNet = "physicsnemo.models.diffusion:UNet" - diff --git a/test/core/test_version_check.py b/test/core/test_version_check.py index 19f07c71c2..62d12a86dc 100644 --- a/test/core/test_version_check.py +++ b/test/core/test_version_check.py @@ -14,159 +14,132 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import MagicMock, patch +from importlib import metadata +from unittest.mock import patch import pytest from physicsnemo.core.version_check import ( - VERSION_REQUIREMENTS, - check_min_version, - check_module_requirements, + check_version_spec, + ensure_available, + get_installed_version, + require_version_spec, ) -def test_check_min_version_success(): - """Test that check_min_version succeeds when version requirement is met""" - with patch("importlib.import_module") as mock_import: - # Create a mock module with version 2.6.0 - mock_module = MagicMock() - mock_module.__version__ = "2.6.0" - mock_import.return_value = mock_module - - # Should pass with same version - assert check_min_version("torch", "2.6.0") is True - - # Should pass with lower required version - assert check_min_version("torch", "2.5.0") is True - - -def test_check_min_version_failure(): - """Test that check_min_version raises ImportError when version requirement is not met""" - with patch("importlib.import_module") as mock_import: - # Create a mock module with version 2.5.0 - mock_module = MagicMock() - mock_module.__version__ = "2.5.0" - mock_import.return_value = mock_module - - # Should fail with higher required version - with pytest.raises(ImportError) as excinfo: - check_min_version("torch", "2.6.0") - - assert "torch version 2.6.0 or higher is required" in str(excinfo.value) +def test_get_installed_version_found(): + """get_installed_version returns version string when package is installed""" + with patch( + "physicsnemo.core.version_check.metadata.version", return_value="2.6.0" + ) as mock_version: + assert get_installed_version("torch") == "2.6.0" + mock_version.assert_called_once_with("torch") -def test_check_min_version_custom_error(): - """Test that check_min_version uses custom error message if provided""" - with patch("importlib.import_module") as mock_import: - # Create a mock module with version 2.5.0 - mock_module = MagicMock() - mock_module.__version__ = "2.5.0" - mock_import.return_value = mock_module +def test_get_installed_version_not_found(): + """get_installed_version returns None when package is not installed""" + with patch( + "physicsnemo.core.version_check.metadata.version", + side_effect=metadata.PackageNotFoundError, + ): + assert get_installed_version("nonexistent_package") is None - custom_msg = "Custom error message" 
- with pytest.raises(ImportError) as excinfo: - check_min_version("torch", "2.6.0", error_msg=custom_msg) - assert custom_msg in str(excinfo.value) +def test_check_version_spec_success(): + """check_version_spec returns True when requirement is satisfied""" + with patch( + "physicsnemo.core.version_check.get_installed_version", return_value="2.6.0" + ): + assert check_version_spec("torch", ">=2.5,<3.0") is True -def test_check_min_version_package_not_found(): - """Test that check_min_version raises ImportError when package is not installed""" - with patch("importlib.import_module", side_effect=ImportError("Package not found")): +def test_check_version_spec_failure_hard(): + """check_version_spec raises ImportError when requirement is not met and hard_fail=True""" + with patch( + "physicsnemo.core.version_check.get_installed_version", return_value="2.5.0" + ): with pytest.raises(ImportError) as excinfo: - check_min_version("nonexistent_package", "1.0.0") + check_version_spec("torch", ">=2.6.0", hard_fail=True) + msg = str(excinfo.value) + assert "torch >=2.6.0 is required" in msg + assert "found 2.5.0" in msg - assert "Package nonexistent_package is required but not installed" in str( - excinfo.value - ) - -def test_check_module_requirements_success(): - """Test that check_module_requirements succeeds when all requirements are met""" +def test_check_version_spec_failure_soft(): + """check_version_spec returns False when requirement not met and hard_fail=False""" with patch( - "physicsnemo.core.version_check.check_min_version" - ) as mock_check_min_version: - mock_check_min_version.return_value = True - - # Should run check_min_version for known module - check_module_requirements("physicsnemo.distributed.shard_tensor") - mock_check_min_version.assert_called_once_with("torch", "2.5.9") + "physicsnemo.core.version_check.get_installed_version", return_value="2.5.0" + ): + assert check_version_spec("torch", ">=2.6.0", hard_fail=False) is False -def test_check_module_requirements_unknown_module(): - """Test that check_module_requirements does nothing for unknown modules""" +def test_check_version_spec_custom_error_message(): + """check_version_spec uses provided custom error message""" with patch( - "physicsnemo.core.version_check.check_min_version" - ) as mock_check_min_version: - # Should not call check_min_version for unknown module - check_module_requirements("unknown.module.path") - mock_check_min_version.assert_not_called() + "physicsnemo.core.version_check.get_installed_version", return_value="2.5.0" + ): + with pytest.raises(ImportError) as excinfo: + check_version_spec( + "torch", ">=2.6.0", error_msg="Custom error", hard_fail=True + ) + assert "Custom error" in str(excinfo.value) -def test_version_requirements_structure(): - """Test that VERSION_REQUIREMENTS dictionary has the expected structure""" - assert "physicsnemo.distributed.shard_tensor" in VERSION_REQUIREMENTS - assert "torch" in VERSION_REQUIREMENTS["physicsnemo.distributed.shard_tensor"] - assert ( - VERSION_REQUIREMENTS["physicsnemo.distributed.shard_tensor"]["torch"] == "2.5.9" - ) +def test_check_version_spec_package_not_found_hard(): + """Raises with clear message when package is not installed and hard_fail=True""" + with patch( + "physicsnemo.core.version_check.get_installed_version", return_value=None + ): + with pytest.raises(ImportError) as excinfo: + check_version_spec("torch", ">=2.0.0", hard_fail=True) + assert "Package 'torch' is required but not installed." 
in str(excinfo.value) -def test_require_version_success(): - """Test that require_version decorator allows function to run when version requirement is met""" - with patch("importlib.import_module") as mock_import: - # Create a mock module with version 2.6.0 - mock_module = MagicMock() - mock_module.__version__ = "2.6.0" - mock_import.return_value = mock_module +def test_check_version_spec_package_not_found_soft(): + """Returns False when package is not installed and hard_fail=False""" + with patch( + "physicsnemo.core.version_check.get_installed_version", return_value=None + ): + assert check_version_spec("torch", ">=2.0.0", hard_fail=False) is False - # Create a decorated function - from physicsnemo.core.version_check import require_version - @require_version("torch", "2.5.0") - def test_function(): - return "Function executed" +def test_require_version_spec_success(): + """Decorator allows execution when requirement is met""" + with patch("physicsnemo.core.version_check.check_version_spec", return_value=True): - # Function should execute normally when version requirement is met - assert test_function() == "Function executed" + @require_version_spec("torch", ">=2.5.0") + def fn(): + return "ok" + assert fn() == "ok" -def test_require_version_failure(): - """Test that require_version decorator prevents function from running when version requirement is not met""" - with patch("importlib.import_module") as mock_import: - # Create a mock module with version 2.5.0 - mock_module = MagicMock() - mock_module.__version__ = "2.5.0" - mock_import.return_value = mock_module - # Create a decorated function - from physicsnemo.core.version_check import require_version +def test_require_version_spec_failure(): + """Decorator prevents execution when requirement is not met""" + with patch( + "physicsnemo.core.version_check.check_version_spec", + side_effect=ImportError("not satisfied"), + ): - @require_version("torch", "2.6.0") - def test_function(): - return "Function executed" + @require_version_spec("torch", ">=2.6.0") + def fn(): + return "ok" - # Function should raise ImportError when version requirement is not met with pytest.raises(ImportError) as excinfo: - test_function() - - assert "torch version 2.6.0 or higher is required" in str(excinfo.value) + fn() + assert "not satisfied" in str(excinfo.value) -def test_require_version_package_not_found(): - """Test that require_version decorator raises ImportError when package is not installed""" - with patch("importlib.import_module", side_effect=ImportError("Package not found")): - # Create a decorated function - from physicsnemo.core.version_check import require_version +def test_ensure_available_success(): + """ensure_available returns True when requirement passes""" + with patch("physicsnemo.core.version_check.check_version_spec", return_value=True): + assert ensure_available("torch", ">=2.0.0") is True - @require_version("nonexistent_package", "1.0.0") - def test_function(): - return "Function executed" - # Function should raise ImportError when package is not installed - with pytest.raises(ImportError) as excinfo: - test_function() - - assert "Package nonexistent_package is required but not installed" in str( - excinfo.value - ) +def test_ensure_available_soft_failure(): + """ensure_available returns False when requirement fails and hard_fail=False""" + with patch( + "physicsnemo.core.version_check.check_version_spec", + side_effect=ImportError("bad"), + ): + assert ensure_available("torch", ">=3.0.0", hard_fail=False) is False