diff --git a/CHANGELOG.md b/CHANGELOG.md index 6044555ece..2a0d4cd0e6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 and three transient schemes. - Added a check to `stochastic_sampler` that helps handle the `EDMPrecond` model, which has a specific `.forward()` signature +- Examples: added a new example for reservoir simulation using X-MeshGraphNet. + Accessible in `examples/reservoir_simulation` - Added abstract interfaces for constructing active learning workflows, contained under the `physicsnemo.active_learning` namespace. A preliminary example of how to compose and define an active learning workflow is provided in `examples/active_learning`. diff --git a/docs/img/reservoir_simulation/xmgn/Norne/inference/PRES_21_PRED.png b/docs/img/reservoir_simulation/xmgn/Norne/inference/PRES_21_PRED.png new file mode 100644 index 0000000000..905bf620e6 Binary files /dev/null and b/docs/img/reservoir_simulation/xmgn/Norne/inference/PRES_21_PRED.png differ diff --git a/docs/img/reservoir_simulation/xmgn/Norne/inference/PRES_21_TRUE.png b/docs/img/reservoir_simulation/xmgn/Norne/inference/PRES_21_TRUE.png new file mode 100644 index 0000000000..be4b4ddc0b Binary files /dev/null and b/docs/img/reservoir_simulation/xmgn/Norne/inference/PRES_21_TRUE.png differ diff --git a/docs/img/reservoir_simulation/xmgn/Norne/inference/PRES_42_PRED.png b/docs/img/reservoir_simulation/xmgn/Norne/inference/PRES_42_PRED.png new file mode 100644 index 0000000000..79cb3aa86b Binary files /dev/null and b/docs/img/reservoir_simulation/xmgn/Norne/inference/PRES_42_PRED.png differ diff --git a/docs/img/reservoir_simulation/xmgn/Norne/inference/PRES_42_TRUE.png b/docs/img/reservoir_simulation/xmgn/Norne/inference/PRES_42_TRUE.png new file mode 100644 index 0000000000..a8aa3440b7 Binary files /dev/null and b/docs/img/reservoir_simulation/xmgn/Norne/inference/PRES_42_TRUE.png differ diff --git a/docs/img/reservoir_simulation/xmgn/Norne/inference/PRES_DIFF_21.png b/docs/img/reservoir_simulation/xmgn/Norne/inference/PRES_DIFF_21.png new file mode 100644 index 0000000000..78d701de08 Binary files /dev/null and b/docs/img/reservoir_simulation/xmgn/Norne/inference/PRES_DIFF_21.png differ diff --git a/docs/img/reservoir_simulation/xmgn/Norne/inference/PRES_DIFF_42.png b/docs/img/reservoir_simulation/xmgn/Norne/inference/PRES_DIFF_42.png new file mode 100644 index 0000000000..47b0f1752f Binary files /dev/null and b/docs/img/reservoir_simulation/xmgn/Norne/inference/PRES_DIFF_42.png differ diff --git a/docs/img/reservoir_simulation/xmgn/Norne/inference/SWAT_21_PRED.png b/docs/img/reservoir_simulation/xmgn/Norne/inference/SWAT_21_PRED.png new file mode 100644 index 0000000000..c347124f9b Binary files /dev/null and b/docs/img/reservoir_simulation/xmgn/Norne/inference/SWAT_21_PRED.png differ diff --git a/docs/img/reservoir_simulation/xmgn/Norne/inference/SWAT_21_TRUE.png b/docs/img/reservoir_simulation/xmgn/Norne/inference/SWAT_21_TRUE.png new file mode 100644 index 0000000000..ad6f5d9178 Binary files /dev/null and b/docs/img/reservoir_simulation/xmgn/Norne/inference/SWAT_21_TRUE.png differ diff --git a/docs/img/reservoir_simulation/xmgn/Norne/inference/SWAT_42_PRED.png b/docs/img/reservoir_simulation/xmgn/Norne/inference/SWAT_42_PRED.png new file mode 100644 index 0000000000..98d6e5b421 Binary files /dev/null and b/docs/img/reservoir_simulation/xmgn/Norne/inference/SWAT_42_PRED.png differ diff --git 
a/docs/img/reservoir_simulation/xmgn/Norne/inference/SWAT_42_TRUE.png b/docs/img/reservoir_simulation/xmgn/Norne/inference/SWAT_42_TRUE.png new file mode 100644 index 0000000000..b177d4d654 Binary files /dev/null and b/docs/img/reservoir_simulation/xmgn/Norne/inference/SWAT_42_TRUE.png differ diff --git a/docs/img/reservoir_simulation/xmgn/Norne/inference/SWAT_DIFF_21.png b/docs/img/reservoir_simulation/xmgn/Norne/inference/SWAT_DIFF_21.png new file mode 100644 index 0000000000..98007fe25c Binary files /dev/null and b/docs/img/reservoir_simulation/xmgn/Norne/inference/SWAT_DIFF_21.png differ diff --git a/docs/img/reservoir_simulation/xmgn/Norne/inference/SWAT_DIFF_42.png b/docs/img/reservoir_simulation/xmgn/Norne/inference/SWAT_DIFF_42.png new file mode 100644 index 0000000000..db5bbb5061 Binary files /dev/null and b/docs/img/reservoir_simulation/xmgn/Norne/inference/SWAT_DIFF_42.png differ diff --git a/docs/img/reservoir_simulation/xmgn/Norne/static/PARTITION.png b/docs/img/reservoir_simulation/xmgn/Norne/static/PARTITION.png new file mode 100644 index 0000000000..bdf20b7193 Binary files /dev/null and b/docs/img/reservoir_simulation/xmgn/Norne/static/PARTITION.png differ diff --git a/docs/img/reservoir_simulation/xmgn/Norne/static/PERMX.png b/docs/img/reservoir_simulation/xmgn/Norne/static/PERMX.png new file mode 100644 index 0000000000..3314d99778 Binary files /dev/null and b/docs/img/reservoir_simulation/xmgn/Norne/static/PERMX.png differ diff --git a/docs/img/reservoir_simulation/xmgn/Norne/static/PORO.png b/docs/img/reservoir_simulation/xmgn/Norne/static/PORO.png new file mode 100644 index 0000000000..83d6d6ac7c Binary files /dev/null and b/docs/img/reservoir_simulation/xmgn/Norne/static/PORO.png differ diff --git a/examples/reservoir_simulation/sim_utils/README.md b/examples/reservoir_simulation/sim_utils/README.md new file mode 100644 index 0000000000..6015d35b2d --- /dev/null +++ b/examples/reservoir_simulation/sim_utils/README.md @@ -0,0 +1,125 @@ +# Simulation Utilities + +## Overview + +The `sim_utils` package provides utilities for processing ECL/IX style binary +output files to prepare datasets for training. These scripts can read industry +standard simulator output formats (ECLIPSE, IX, OPM) and convert them into +various data structures suitable for different ML architectures. + +## Supported Formats + +- `.INIT` +- `.EGRID` +- `.UNRST` or `.X00xx` +- `.UNSMRY` or `.S00xx` + +## Modules + +### `ecl_reader.py` + +Main class for reading ECLIPSE-style binary output files. 
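+In addition to the static (`read_init`), grid (`read_egrid`), and restart
+(`read_restart`) readers shown under **Usage** below, the class also reads
+well- and field-level summary vectors via `read_smry`. A minimal sketch,
+assuming the keywords and well name used here exist in your case:
+
+```python
+from sim_utils import EclReader
+
+reader = EclReader("path/to/CASE.DATA")
+
+# Summary vectors are returned as {"TIME": ..., "<entity>": {"<keyword>": ...}}
+summary = reader.read_smry(["WBHP", "WOPR"], entities=["INJ1", "FIELD"])
+times = summary["TIME"]             # report-step time values
+inj1_bhp = summary["INJ1"]["WBHP"]  # bottom-hole pressure series for well INJ1
+```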
+ +**Usage**: + +```python +from sim_utils import EclReader + +# Initialize reader with case name +reader = EclReader("path/to/CASE.DATA") + +# Read static properties +init_data = reader.read_init(["PORV", "PERMX", "PERMY", "PERMZ"]) + +# Read grid geometry +egrid_data = reader.read_egrid(["COORD", "ZCORN", "FILEHEAD", "NNC1", "NNC2"]) + +# Read dynamic properties (all timesteps) +restart_data = reader.read_restart(["PRESSURE", "SWAT", "SGAS"]) +``` + +**Common Keywords**: + +Static properties (INIT): + +- `PORV`: Pore volume +- `PERMX`, `PERMY`, `PERMZ`: Permeability in X, Y, Z directions +- `PORO`: Porosity +- `TRANX`, `TRANY`, `TRANZ`: Transmissibility in X, Y, Z directions + +Dynamic properties (UNRST): + +- `PRESSURE`: Cell pressure +- `SWAT`: Water saturation +- `SGAS`: Gas saturation +- `SOIL`: Oil saturation + +Grid geometry (EGRID): + +- `COORD`: Grid pillar coordinates +- `ZCORN`: Grid corner depths +- `FILEHEAD`: File header information +- `NNC1`, `NNC2`: Non-neighboring connections + +### `grid.py` + +**Grid** - Handles reservoir grid structure and operations. + +**Features**: + +- Grid dimensions and active cells +- Cell center coordinates computation +- Connection/edge computation for graph construction +- Aggregating directional transmissibilities for edge features +- Non-Neighboring Connections (NNC) +- Well completion arrays + +**Usage**: + +```python +from sim_utils import Grid + +# Initialize grid from simulation data +grid = Grid(init_data, egrid_data) + +# Get connections and transmissibilities for graph construction +connections, transmissibilities = grid.get_conx_tran() + +# Create completion arrays for wells +completion_inj, completion_prd = grid.create_completion_array(wells) + +# Access grid properties +print(f"Grid dimensions: {grid.nx} x {grid.ny} x {grid.nz}") +print(f"Active cells: {grid.nact}") +print(f"Cell coordinates: X={grid.X}, Y={grid.Y}, Z={grid.Z}") +``` + +### `well.py` + +**Well** and **Completion** - Well and completion data structures. Typically, +use results from `UNRST` (including well name, type, status, I, J, K) to +instantiate the object. + +**Usage**: + +```python +from sim_utils import Well, Completion + +# Create a well +well = Well(name="INJ1", type_id=3, stat=1) # Water injector + +# Add completions +well.add_completion( + I=10, # Grid I-index + J=10, # Grid J-index + K=5, # Grid K-index + dir=3, # Direction (1=X, 2=Y, 3=Z) + stat=1, # Status (1=OPEN) + conx_factor=1.0 # Connection factor +) + +# Check well properties +print(f"Well type: {well.type}") # 'INJ' or 'PRD' +print(f"Well status: {well.status}") # 'OPEN' or 'SHUT' +print(f"Number of completions: {len(well.completions)}") +``` diff --git a/examples/reservoir_simulation/sim_utils/__init__.py b/examples/reservoir_simulation/sim_utils/__init__.py new file mode 100644 index 0000000000..63075e833b --- /dev/null +++ b/examples/reservoir_simulation/sim_utils/__init__.py @@ -0,0 +1,49 @@ +# ignore_header_test +# Copyright 2025 Tsubasa Onishi +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Reservoir Utilities + +Shared utilities for processing reservoir simulation data across different +machine learning architectures (XMeshGraphNet, FNO, DeepONet, etc.). + +This package contains: +- ecl_reader: Read ECLIPSE-style simulation output files (.INIT, .EGRID, .UNRST, etc.) +- grid: Grid data structures and operations for reservoir simulations +- well: Well and completion data structures +""" + +from .ecl_reader import EclReader +from .grid import Grid +from .well import Well, Completion + +__all__ = ["EclReader", "Grid", "Well", "Completion"] + +__version__ = "1.0.0" diff --git a/examples/reservoir_simulation/sim_utils/ecl_reader.py b/examples/reservoir_simulation/sim_utils/ecl_reader.py new file mode 100644 index 0000000000..7e3384858e --- /dev/null +++ b/examples/reservoir_simulation/sim_utils/ecl_reader.py @@ -0,0 +1,864 @@ +# ignore_header_test +# Copyright 2025 Tsubasa Onishi +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import struct +import os +import logging +import glob +from datetime import datetime + +# Module-level logger +logger = logging.getLogger(__name__) + + +class EclReader: + """Reads SLB ECLIPSE style binary output files (.INIT, .EGRID, .UNRST, .X00xx). + + This class provides methods to read various ECLIPSE output files, including + initial conditions (.INIT), grid data (.EGRID), and restart files (.UNRST, .X00xx). + It handles endianness detection and data type conversion. + + Attributes: + input_file_path (str): Path to the main ECLIPSE input file (.DATA or .IXF). 
+ input_file_path_base (str): Base path of the input file (without extension). + init_file_path (str): Path to the initial conditions file (.INIT). + egrid_file_path (str): Path to the grid data file (.EGRID). + unrst_file_path (str): Path to the unified restart file (.UNRST). Currently not used. + """ + + def __init__(self, input_file_path: str) -> None: + """Initializes the EclReader object. + + Parameters + input_file_path (str): Path to the main ECLIPSE input file (.DATA or .AFI). + + Raises: + FileNotFoundError: If the input file or any required related file is not found. + RuntimeError: If the input file has an unsupported extension. + """ + self.input_file_path = input_file_path + self._validate_input_file() + self._initialize_file_names() + + def read_init(self, keys: list = None) -> dict: + """Reads data from the initial conditions file (.INIT). + + Parameters + keys (list, optional): List of keys to read. If None, all keys are read. Defaults to None. + + Returns + dict: Dictionary containing the requested data, keyed by the provided keys. + Returns an empty dictionary if no keys are provided. + """ + return self._read_bin(self.init_file_path, keys) + + def read_egrid(self, keys: list = None) -> dict: + """Reads data from the grid data file (.EGRID). + + Parameters + keys (list, optional): List of keys to read. If None, all keys are read. Defaults to None. + + Returns + dict: Dictionary containing the requested data, keyed by the provided keys. + Returns an empty dictionary if no keys are provided. + """ + return self._read_bin(self.egrid_file_path, keys) + + def read_restart(self, keys: list = None, tstep_id: int = None) -> dict: + """Reads restart data from .UNRST or .X00xx files (automatically selected). + + Parameters + keys (list): List of variables to extract. + tstep_id (int, optional): Specific timestep. If None, reads all available. + + Returns + dict: { timestep_id: { "DATE": ..., "TIME": ..., key1: ..., ... }, ... } + + Raises: + FileNotFoundError: If no restart files found. 
+ """ + + # Try unified first + unified_file = f"{self.input_file_path_base}.UNRST" + if os.path.exists(unified_file): + return self._read_unrst(unified_file, keys, tstep_id) + + # Else fallback to .X00xx + base_dir = os.path.dirname(self.input_file_path_base) + base_name = os.path.basename(self.input_file_path_base) + search_pattern = os.path.join(base_dir, f"{base_name}.X[0-9][0-9][0-9][0-9]") + files = sorted(glob.glob(search_pattern)) + + if not files: + raise FileNotFoundError("No restart files (.UNRST or .X00xx) were found.") + + # Single-step case (still wrap output in dict) + if tstep_id is not None: + match_file = f"{self.input_file_path_base}.X{self._int2ext(tstep_id)}" + if not os.path.exists(match_file): + raise FileNotFoundError(f"Restart file not found: {match_file}") + + data = self._read_bin(match_file, keys) + result = self._add_date_and_time(data) + return {tstep_id: result} + + # Multi-step case: read all .X00xx files + all_results = {} + previous_date = None + cumulative_days = 0 + + for fpath in files: + tstep = int(os.path.basename(fpath).split("X")[-1]) + try: + data = self._read_bin(fpath, keys) + except Exception as e: + logging.warning(f"Skipping {fpath} due to error: {e}") + continue + + result = self._add_date_and_time(data, previous_date, cumulative_days) + if "DATE" in result: + current_date = result["DATE"] + cumulative_days = result["TIME"] + previous_date = current_date + + all_results[tstep] = result + + return all_results + + def read_smry(self, keys: list, entities: list = None) -> dict: + """Reads summary data from .UNSMRY or .Sxxxx files for fields, wells, or groups. + + Parameters + keys (list): Summary variable names to extract (e.g., ["WBHP", "WOPR"]). + entities (list, optional): List of entities (e.g., wells or groups like ["INJ", "PROD", "FIELD"]). + If None, all unique entities will be used. + + Returns + dict: { + "TIME": np.ndarray, + "": { "": np.ndarray, ... }, + ... + } + + Raises: + FileNotFoundError: If the .SMSPEC or summary output files are missing. + """ + smspec_file = f"{self.input_file_path_base}.SMSPEC" + if not os.path.exists(smspec_file): + raise FileNotFoundError( + f"Summary specification file not found: {smspec_file}" + ) + + # --- Step 1: Read key/entity info from SMSPEC using pattern matching --- + with open(smspec_file, "rb") as fid: + file_data = fid.read() + + # Detect endian from the first 4 bytes + first_int = struct.unpack(" 0 and first_int < 1000: # Reasonable header size + endian = "<" + else: + endian = ">" + + # Find KEYWORDS and WGNAMES/NAMES positions by pattern matching + keywords_pos = file_data.find(b"KEYWORDS") + wgnames_pos = file_data.find(b"WGNAMES") + names_pos = file_data.find(b"NAMES") + + if keywords_pos == -1: + raise ValueError("KEYWORDS record not found in SMSPEC file") + + # Read KEYWORDS data + keywords_pos -= 4 # Go back to record start + with open(smspec_file, "rb") as fid: + fid.seek(keywords_pos) + raw_keys, n_key = self._read_smspec_record(fid, endian) + + # Read entity names (try WGNAMES first, then NAMES) + if wgnames_pos != -1: + entity_pos = wgnames_pos - 4 + elif names_pos != -1: + entity_pos = names_pos - 4 + else: + raise ValueError("Neither WGNAMES nor NAMES record found in SMSPEC file") + + with open(smspec_file, "rb") as fid: + fid.seek(entity_pos) + raw_entities, n_ent = self._read_smspec_record(fid, endian) + + if n_key != n_ent: + raise ValueError( + f"Mismatch between number of keys ({n_key}) and entities ({n_ent}) in .SMSPEC." 
+ ) + + all_keys = ["".join(row).strip() for row in raw_keys] + all_entities = ["".join(row).strip() for row in raw_entities] + + if entities is None: + entities = sorted(set(all_entities)) + # entities = ["FIELD" if s == ':+:+:+:+' else s for s in entities] + + n_keys = len(keys) + n_ents = len(entities) + + # Build lookup table (flat index map) + index_map = { + (k, e): i + for i, (k, e) in enumerate(zip(all_keys, all_entities)) + if k in keys and e in entities + } + + # --- Step 2: Read UNSMRY or Sxxxx files --- + time_series = [] + summary_data = [] + + files = [f"{self.input_file_path_base}.UNSMRY"] + if not os.path.exists(files[0]): + base_dir = os.path.dirname(self.input_file_path_base) + base_name = os.path.basename(self.input_file_path_base) + pattern = os.path.join(base_dir, f"{base_name}.S[0-9][0-9][0-9][0-9]") + files = sorted(glob.glob(pattern)) + + for fname in files: + if not os.path.exists(fname): + logging.warning(f"Skipping missing summary file: {fname}") + continue + + with open(fname, "rb") as fid: + self._load_vector(fid, endian) # Skip SEQHDR + + while True: + _, _, label = self._load_vector(fid, endian) # MINISTEP or SEQHDR + if label == "SEQHDR": + continue + + data, _, _ = self._load_vector(fid, endian) + if data is None or len(data) == 0: + break + + time_series.append(data[0]) + + row = np.full((n_keys, n_ents), np.nan, dtype=np.float32) + for i, key in enumerate(keys): + for j, ent in enumerate(entities): + idx = index_map.get((key, ent), -1) + if idx >= 0: + row[i, j] = data[idx] + summary_data.append(row) + + # --- Step 3: Restructure output --- + time_series = np.array(time_series) + summary_data = np.array(summary_data) # [timesteps, n_keys, n_ents] + + result = {"TIME": time_series} + for j, ent in enumerate(entities): + ent_block = { + keys[i]: summary_data[:, i, j] + for i in range(n_keys) + if not np.all(np.isnan(summary_data[:, i, j])) + } + if ent_block: + if ent == ":+:+:+:+": + ent = "FIELD" + result[ent] = ent_block + + return result + + # ---- Private Methods --------------------------------------------------------------------------------------------- + + def _validate_input_file(self) -> None: + """Validates the input file and its extension. + + Raises: + FileNotFoundError: If the input file is not found. + RuntimeError: If the input file has an unsupported extension. + """ + if not os.path.exists(self.input_file_path): + raise FileNotFoundError(f"Input file not found: {self.input_file_path}") + + base, ext = os.path.splitext(self.input_file_path) + if ext.upper() not in [".DATA", ".AFI"]: + raise RuntimeError(f"Unsupported input file: {self.input_file_path}") + + self.input_file_path_base = base + + def _initialize_file_names(self) -> None: + """Initializes file paths for related binary files (.INIT, .EGRID, .UNRST).""" + self.init_file_path = f"{self.input_file_path_base}.INIT" + self.egrid_file_path = f"{self.input_file_path_base}.EGRID" + self.unrst_file_path = f"{self.input_file_path_base}.UNRST" + + def _read_bin(self, file_path: str, keys: list) -> dict: + """Reads ECLIPSE style binary data from the given file. + + Parameters + file_path (str): Path to the binary file. + keys (list): List of keys to read. + + Returns + dict: Dictionary containing the requested data. Returns an empty dictionary if keys is None. 
+ """ + + if keys is None: + logging.warning("No keys provided.") + return {} + + logging.debug(f"Reading keys: {keys} in file: {file_path}") + + variables = {} + with open(file_path, "rb") as fid: + endian = self._detect_endian(fid) + found_keys = {key: False for key in keys} + + while keys and not all(found_keys.values()): + data, _, key = self._load_vector(fid, endian) + key = key.strip() + if key in found_keys: + # Dynamically determine dtype + if isinstance(data, np.ndarray): + variables[key] = data # Keep original dtype + elif isinstance(data, (bytes, str)): + variables[key] = data.decode( + errors="ignore" + ).strip() # Convert bytes to string + elif isinstance(data, (int, float)): + variables[key] = np.array( + [data], dtype=np.float32 + ) # Convert scalars to array + else: + logging.warning(f"Unknown data type for key: {key}") + variables[key] = data # Store as-is + + found_keys[key] = True + + if fid.tell() >= os.fstat(fid.fileno()).st_size: + break + + # Log missing keys (Debug level) + missing_keys = [k for k, v in found_keys.items() if not v] + if missing_keys: + logging.debug(f"The following keys were not found: {missing_keys}") + for key in missing_keys: + variables[key] = np.array([]) + + return variables + + def _load_vector(self, fid, endian): + """Reads a data block (vector) from the binary file. + + Parameters + fid: File object. + endian (str): Endianness ('<' for little-endian, '>' for big-endian). + + Returns + tuple: A tuple containing the data (NumPy array or string), the data count, and the key. + Returns (None, None, key) if an error occurs during reading. + """ + try: + # Read and verify the header + header_size = struct.unpack(endian + "i", fid.read(4))[0] + key = fid.read(8).decode(errors="ignore").strip() + data_count = struct.unpack(endian + "i", fid.read(4))[0] + data_type_raw = fid.read(4) + data_type = data_type_raw.decode(errors="ignore").strip().upper() + end_size = struct.unpack(endian + "i", fid.read(4))[0] + + if header_size != end_size: + logging.warning( + f"Mismatch Detected for {key}: Header={header_size}, End={end_size}" + ) + return None, None, key # Skip this entry + + # Define data type mapping + dtype_map = { + "CHAR": "S1", + "INTE": "i4", + "REAL": "f4", + "DOUB": "f8", + "LOGI": "i4", + } + dtype = dtype_map.get(data_type) + + if dtype: + raw_data = bytearray() + read_count = 0 + + while read_count < data_count: + # Read the header size of this chunk + chunk_size = struct.unpack(endian + "i", fid.read(4))[0] + chunk_data = fid.read(chunk_size) + chunk_end = struct.unpack(endian + "i", fid.read(4))[0] + + if chunk_size != chunk_end: + logging.warning( + f"Chunk mismatch in {key}: Expected {chunk_size}, got {chunk_end}" + ) + return None, None, key + + raw_data.extend(chunk_data) + read_count += chunk_size // np.dtype(dtype).itemsize + + if data_type == "CHAR": + char_array = np.frombuffer(raw_data, dtype="S1").reshape( + (-1, 8) + ) # 8-char wide strings + char_array = np.char.decode(char_array, encoding="utf-8").astype( + str + ) + return char_array, data_count, key + else: + data = np.frombuffer(raw_data, dtype=endian + dtype) + return data, data_count, key + else: + fid.seek(data_count * 4, os.SEEK_CUR) # Skip unknown type + return None, None, key + except struct.error: + return None, None, "" + + def _read_smspec_record(self, fid, endian): + """Read a single SMSPEC record using pattern matching approach. 
+ + Parameters + fid: File object positioned at the start of a record + endian: Endianness string + + Returns + tuple: (data_array, count) + """ + try: + # Read record header + header_size = struct.unpack(endian + "i", fid.read(4))[0] + key = fid.read(8).decode("ascii", errors="ignore").strip() + data_count = struct.unpack(endian + "i", fid.read(4))[0] + data_type = fid.read(4).decode("ascii", errors="ignore").strip() + end_size = struct.unpack(endian + "i", fid.read(4))[0] + + if header_size != end_size: + raise ValueError( + f"Header size mismatch for {key}: {header_size} != {end_size}" + ) + + if data_count <= 0: + return np.array([]), 0 + + # Determine bytes per element based on data type + if data_type == "CHAR": + bytes_per_element = 8 + else: + dtype_map = {"INTE": "i4", "REAL": "f4", "DOUB": "f8", "LOGI": "i4"} + dtype = dtype_map.get(data_type, "i4") + bytes_per_element = np.dtype(dtype).itemsize + + # Read the data in chunks + raw_data = bytearray() + bytes_read = 0 + total_bytes_needed = data_count * bytes_per_element + + while bytes_read < total_bytes_needed: + chunk_size = struct.unpack(endian + "i", fid.read(4))[0] + chunk_data = fid.read(chunk_size) + chunk_end = struct.unpack(endian + "i", fid.read(4))[0] + + if chunk_size != chunk_end: + raise ValueError( + f"Chunk size mismatch: {chunk_size} != {chunk_end}" + ) + + raw_data.extend(chunk_data) + bytes_read += chunk_size + + # Parse the data based on type + if data_type == "CHAR": + # For CHAR data, reshape into 8-character strings + if len(raw_data) >= total_bytes_needed: + char_data = np.frombuffer(raw_data, dtype="S1").reshape((-1, 8)) + char_data = np.char.decode(char_data, encoding="ascii").astype(str) + return char_data, data_count + else: + raise ValueError( + f"Insufficient CHAR data: expected {total_bytes_needed}, got {len(raw_data)}" + ) + else: + # For numeric data + if len(raw_data) >= total_bytes_needed: + data = np.frombuffer(raw_data, dtype=endian + dtype) + return data, data_count + else: + raise ValueError( + f"Insufficient {data_type} data: expected {total_bytes_needed}, got {len(raw_data)}" + ) + + except Exception as e: + raise RuntimeError(f"Error reading SMSPEC record: {e}") + + def _detect_endian(self, fid): + """Detects file endianness. + + Parameters + fid: File object. + + Returns + str: Endianness ('<' for little-endian, '>' for big-endian). + """ + fid.seek(0) + test_int = fid.read(4) + little_endian = struct.unpack("i", test_int)[0] + fid.seek(0) + return "<" if abs(little_endian) < abs(big_endian) else ">" + + def _int2ext(self, i): + """Converts an integer to a formatted string with leading zeros (e.g., 1 to "0001"). + + Parameters + i (int): Integer to convert. + + Returns + str: Formatted string with leading zeros. 
+ """ + return f"{i:04d}" + + def _read_unrst( + self, file_path: str, keys: list = None, tstep_id: int = None + ) -> dict: + """Read restart data from UNRST file with improved pattern matching.""" + + if keys is None: + keys = [] + + all_results = {} + file_size = os.path.getsize(file_path) + + # Read the entire file into memory for pattern matching + with open(file_path, "rb") as fid: + file_data = fid.read() + + # Detect endian from the first 4 bytes + first_int = struct.unpack(" 0 and first_int < 1000: # Reasonable header size + endian = "<" + else: + endian = ">" + + # Find all INTEHEAD positions by pattern matching + intehead_positions = [] + pos = 0 + while True: + pos = file_data.find(b"INTEHEAD", pos) + if pos == -1: + break + intehead_positions.append(pos - 4) # Go back 4 bytes to get to record start + pos += 8 + + # Find all requested key positions by pattern matching + key_positions = {key: [] for key in keys} if keys else {} + for key in keys: + pos = 0 + while True: + pos = file_data.find(key.encode("ascii"), pos) + if pos == -1: + break + + # Verify this is actually a record header by checking the structure + if pos >= 4: + try: + header_size = struct.unpack( + endian + "i", file_data[pos - 4 : pos] + )[0] + if 8 <= header_size <= 1000: # Reasonable header size range + key_positions[key].append(pos - 4) + except (struct.error, IndexError): + pass # Skip if we can't unpack + + pos += len(key) + + # Read data from discovered positions + with open(file_path, "rb") as fid: + dates = [] + times = [] + + # Read INTEHEAD data to get dates + for intehead_pos in intehead_positions: + fid.seek(intehead_pos) + try: + header_size = struct.unpack(endian + "i", fid.read(4))[0] + key = fid.read(8).decode("ascii", errors="ignore").strip() + data_count = struct.unpack(endian + "i", fid.read(4))[0] + data_type = fid.read(4).decode("ascii", errors="ignore").strip() + end_size = struct.unpack(endian + "i", fid.read(4))[0] + + if key == "INTEHEAD" and header_size == end_size: + # Read the INTEHEAD data to get the date + raw_data = bytearray() + read_count = 0 + while read_count < data_count: + chunk_size = struct.unpack(endian + "i", fid.read(4))[0] + chunk_data = fid.read(chunk_size) + chunk_end = struct.unpack(endian + "i", fid.read(4))[0] + if chunk_size != chunk_end: + break + raw_data.extend(chunk_data) + read_count += chunk_size // 4 + + if len(raw_data) >= data_count * 4: + data = np.frombuffer(raw_data, dtype=endian + "i4") + if len(data) > 66: + IDAY, IMON, IYEAR = data[64], data[65], data[66] + date = datetime(IYEAR, IMON, IDAY) + dates.append(date) + else: + raise ValueError( + f"INTEHEAD data too short: expected >66 elements, got {len(data)}" + ) + else: + raise ValueError( + f"Failed to read INTEHEAD data: expected {data_count * 4} bytes, got {len(raw_data)}" + ) + else: + raise ValueError( + f"Invalid INTEHEAD record: key='{key}', header_size={header_size}, end_size={end_size}" + ) + except Exception as e: + raise RuntimeError( + f"Error reading INTEHEAD at position {intehead_pos}: {e}" + ) + + # Calculate cumulative time from dates + times = [] + if len(dates) > 0: + base_date = dates[0] # Use first date as reference + for date in dates: + time_delta = date - base_date + cumulative_days = time_delta.total_seconds() / ( + 24 * 3600 + ) # Convert to days + times.append(cumulative_days) + + # Read data for each timestep + for timestep_idx, (intehead_pos, date, time) in enumerate( + zip(intehead_positions, dates, times) + ): + result = {"DATE": date, "TIME": time} + + # Read INTEHEAD 
data for this timestep + fid.seek(intehead_pos) + try: + header_size = struct.unpack(endian + "i", fid.read(4))[0] + key = fid.read(8).decode("ascii", errors="ignore").strip() + data_count = struct.unpack(endian + "i", fid.read(4))[0] + data_type = fid.read(4).decode("ascii", errors="ignore").strip() + end_size = struct.unpack(endian + "i", fid.read(4))[0] + + if key == "INTEHEAD" and header_size == end_size: + # Read the INTEHEAD data + raw_data = bytearray() + read_count = 0 + while read_count < data_count: + chunk_size = struct.unpack(endian + "i", fid.read(4))[0] + chunk_data = fid.read(chunk_size) + chunk_end = struct.unpack(endian + "i", fid.read(4))[0] + if chunk_size != chunk_end: + break + raw_data.extend(chunk_data) + read_count += chunk_size // 4 + + if len(raw_data) >= data_count * 4: + intehead_data = np.frombuffer(raw_data, dtype=endian + "i4") + result["INTEHEAD"] = intehead_data + else: + result["INTEHEAD"] = np.array([]) + else: + result["INTEHEAD"] = np.array([]) + except Exception as e: + logging.error( + f"Error reading INTEHEAD for timestep {timestep_idx}: {e}" + ) + result["INTEHEAD"] = np.array([]) + + # Read requested keys for this timestep + for key in keys: + if key == "INTEHEAD": + continue # Already handled above + + # Find the key position that comes after this INTEHEAD position + # but before the next INTEHEAD position (or end of file) + key_pos = None + next_intehead_pos = ( + intehead_positions[timestep_idx + 1] + if timestep_idx + 1 < len(intehead_positions) + else file_size + ) + + if key in key_positions: + # Find the first key position that comes after this INTEHEAD + for pos in key_positions[key]: + if intehead_pos < pos < next_intehead_pos: + key_pos = pos + break + + if key_pos is not None: + fid.seek(key_pos) + try: + # Read key data + header_size = struct.unpack(endian + "i", fid.read(4))[0] + key_name = ( + fid.read(8).decode("ascii", errors="ignore").strip() + ) + data_count = struct.unpack(endian + "i", fid.read(4))[0] + data_type = ( + fid.read(4).decode("ascii", errors="ignore").strip() + ) + end_size = struct.unpack(endian + "i", fid.read(4))[0] + + if key_name == key and header_size == end_size: + # Read the data + raw_data = bytearray() + bytes_read = 0 + + # Determine bytes per element based on data type + if data_type == "CHAR": + bytes_per_element = 8 + else: + bytes_per_element = 4 + + total_bytes_needed = data_count * bytes_per_element + + while bytes_read < total_bytes_needed: + chunk_size = struct.unpack( + endian + "i", fid.read(4) + )[0] + chunk_data = fid.read(chunk_size) + chunk_end = struct.unpack( + endian + "i", fid.read(4) + )[0] + if chunk_size != chunk_end: + break + raw_data.extend(chunk_data) + bytes_read += chunk_size + + if len(raw_data) >= total_bytes_needed: + if data_type == "CHAR": + # Handle string data (like well names in ZWEL) + if ( + len(raw_data) >= data_count * 8 + ): # CHAR uses 8 bytes per element + char_data = np.frombuffer( + raw_data, dtype="S1" + ).reshape((-1, 8)) + char_data = np.char.decode( + char_data, encoding="ascii" + ).astype(str) + # Join characters to form complete strings + string_data = np.array( + [ + "".join(row).strip() + for row in char_data + ] + ) + result[key] = string_data + else: + if ( + logger.getEffectiveLevel() + <= logging.DEBUG + ): + logging.debug( + f"Insufficient CHAR data for {key}: expected {data_count * 8}, got {len(raw_data)}" + ) + else: + # Handle numeric data + dtype_map = { + "REAL": "f4", + "DOUB": "f8", + "INTE": "i4", + "LOGI": "i4", + } + dtype = 
dtype_map.get(data_type, "f4") + key_data = np.frombuffer( + raw_data, dtype=endian + dtype + ) + + # For now, just use the data as-is without truncation + + result[key] = key_data + + except Exception as e: + logging.error( + f"Error reading {key} for timestep {timestep_idx}: {e}" + ) + # If reading fails, add empty data + if key in ["ZWEL"]: + result[key] = np.array([]) # Empty string array + else: + result[key] = np.array([]) # Empty numeric array + else: + # Key doesn't exist for this timestep, add empty data + if key in ["ZWEL"]: + result[key] = np.array([]) # Empty string array + else: + result[key] = np.array([]) # Empty numeric array + + all_results[timestep_idx] = result + + # If specific timestep requested, return only that one + if tstep_id is not None: + if tstep_id in all_results: + return {tstep_id: all_results[tstep_id]} + else: + raise ValueError(f"Timestep {tstep_id} not found in {file_path}") + + # Transform results to the expected format + def transform_results_dict(all_results: dict) -> dict: + """Convert {tstep: {key: value}} -> {key: [values]}.""" + merged = {} + for _, result in sorted(all_results.items()): + for k, v in result.items(): + if k not in merged: + merged[k] = [] + merged[k].append(v) + return merged + + return transform_results_dict(all_results) + + def _add_date_and_time(self, data: dict, prev_date=None, prev_days=0) -> dict: + """Adds DATE and TIME fields to restart data if INTEHEAD exists.""" + + result = dict(data) # copy + if "INTEHEAD" in data: + header = data["INTEHEAD"] + day, mon, year = header[64], header[65], header[66] + try: + date = datetime(year, mon, day) + result["DATE"] = date + result["TIME"] = (date - prev_date).days + prev_days if prev_date else 0 + except Exception as e: + logging.warning(f"Invalid date in INTEHEAD: {e}") + return result diff --git a/examples/reservoir_simulation/sim_utils/grid.py b/examples/reservoir_simulation/sim_utils/grid.py new file mode 100644 index 0000000000..3f5f2f6142 --- /dev/null +++ b/examples/reservoir_simulation/sim_utils/grid.py @@ -0,0 +1,380 @@ +# ignore_header_test +# Copyright 2025 Tsubasa Onishi +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import logging + +NUM_MAX_DIMENS = 3 + +# Module-level logger +logger = logging.getLogger(__name__) + + +class Grid: + """Handles reservoir grid structure and operations. + + This class manages grid dimensions, active cells, cell center coordinates, + connections/edges for graph construction, and transmissibilities. + """ + + FULL_GRID_KEYS = ["PORV"] + + def __init__(self, init_data: dict, egrid_data: dict) -> None: + self.nx, self.ny, self.nz = init_data["INTEHEAD"][8:11].astype(int).tolist() + self.nn = self.nx * self.ny * self.nz + self.num_max_dims = NUM_MAX_DIMENS + self.single_layer = self.nz == 1 + self.porv = init_data["PORV"] + self.actnum = (self.porv > 0).astype(int) + self.actnum_bool = (self.porv > 0).astype(bool) + self.nact = np.sum(self.actnum) + self.X, self.Y, self.Z = self._set_cell_center_coord(egrid_data) + self._initialize_tran_keys() + self._set_dual_poro(dp_flag=egrid_data["FILEHEAD"][5]) + self._set_NNC(dict_NNC={"NNC1": egrid_data["NNC1"], "NNC2": egrid_data["NNC2"]}) + self._compute_ijk_to_active_mapping() + self._compute_connections() + self._compute_total_tran(init_data) + + def ijk_from_I_J_K(self, I: int, J: int, K: int) -> int: + return I + (J - 1) * self.nx + (K - 1) * self.nx * self.ny + + def get_conx_tran(self) -> tuple: + """Get connections and transmissibilities for the grid. + + Returns + tuple: A tuple containing (connections, transmissibilities) where + connections is an array of grid cell connections and + transmissibilities is the corresponding edge features. + """ + # Validate edge indices are within valid range + if self._conx.size > 0: + max_idx = np.max(self._conx) + if max_idx >= self.nact: + print( + f"⚠️ Warning: Edge index {max_idx} >= nact ({self.nact}). This may cause issues." + ) + if np.min(self._conx) < 0: + print( + f"⚠️ Warning: Edge index {np.min(self._conx)} < 0. This may cause issues." + ) + return self._conx, self._Txyz_flattened + + def _set_cell_center_coord(self, egrid_data: dict): + """ + Compute cell center coordinates from COORD and ZCORN arrays (corner-point grid). + + Parameters + coord (np.ndarray): shape (6, nx+1, ny+1), coordinates of grid pillars. + zcorn (np.ndarray): shape (8 * nx * ny * nz,), raw ZCORN values (Fortran-ordered). + nx, ny, nz (int): number of cells in each direction. + + Returns + (center_x, center_y, center_z): 3D arrays of shape (nx, ny, nz) with cell centers. 
+ """ + if "COORD" not in egrid_data or "ZCORN" not in egrid_data: + return np.array([]), np.array([]), np.array([]) + + coord, zcorn = egrid_data["COORD"], egrid_data["ZCORN"] + nx, ny, nz = self.nx, self.ny, self.nz + + # Reshape ZCORN into logical (2*nx, 2*ny, 2*nz) grid of corner depths + coord = coord.reshape((6, nx + 1, ny + 1), order="F") + zcorn = zcorn.reshape((2 * nx, 2 * ny, 2 * nz), order="F") + + center_X = np.zeros((nx, ny, nz)) + center_Y = np.zeros((nx, ny, nz)) + center_Z = np.zeros((nx, ny, nz)) + + for k in range(nz): + for j in range(ny): + for i in range(nx): + # Get the 4 pillars for this cell + pillars = [ + coord[:, i, j], + coord[:, i + 1, j], + coord[:, i, j + 1], + coord[:, i + 1, j + 1], + ] + + # Compute average X, Y from top and base of each pillar + x_vals = [ + 0.5 * (p[0] + p[3]) for p in pillars + ] # avg top and base X + y_vals = [ + 0.5 * (p[1] + p[4]) for p in pillars + ] # avg top and base Y + + center_X[i, j, k] = np.mean(x_vals) + center_Y[i, j, k] = np.mean(y_vals) + + # Collect 8 corner Z-values for this cell + z000 = zcorn[2 * i, 2 * j, 2 * k] + z100 = zcorn[2 * i + 1, 2 * j, 2 * k] + z010 = zcorn[2 * i, 2 * j + 1, 2 * k] + z110 = zcorn[2 * i + 1, 2 * j + 1, 2 * k] + z001 = zcorn[2 * i, 2 * j, 2 * k + 1] + z101 = zcorn[2 * i + 1, 2 * j, 2 * k + 1] + z011 = zcorn[2 * i, 2 * j + 1, 2 * k + 1] + z111 = zcorn[2 * i + 1, 2 * j + 1, 2 * k + 1] + + center_Z[i, j, k] = np.mean( + [z000, z100, z010, z110, z001, z101, z011, z111] + ) + + # active cell only vectors + X = center_X.reshape(-1, order="F")[self.actnum_bool] + Y = center_Y.reshape(-1, order="F")[self.actnum_bool] + Z = center_Z.reshape(-1, order="F")[self.actnum_bool] + + return X, Y, Z + + def _initialize_tran_keys(self) -> None: + """Initializes the transmissibility keys.""" + self._tran_keys = ["TRANX", "TRANY", "TRANZ", "TRANNNC"] + + def _set_dual_poro(self, dp_flag: int) -> None: + """Configures the grid for dual porosity. + + Parameters + dp_flag (int): Dual porosity flag (0 for single porosity, 1 or 2 for dual porosity). + """ + self.dual_poro = dp_flag in [1, 2] + if self.dual_poro and self.nz == 2: + self.single_layer = True + elif dp_flag not in [0, 1, 2]: + print( + f"Invalid dual porosity flag found: {dp_flag}. Proceeding with single porosity assumption." + ) + + def _set_NNC(self, dict_NNC: dict) -> None: + """Configures the grid for NNCs. + + Parameters + dict_NNC (dict): Dictionary containing egrid data (NNC1, NNC2). + """ + self.NNC = dict_NNC["NNC1"].size > 0 + if self.NNC: + self.NNC1 = dict_NNC["NNC1"] + self.NNC2 = dict_NNC["NNC2"] + self.num_NNCs = len(self.NNC1) + else: + self.NNC1, self.NNC2 = np.array([]), np.array([]) + self.num_NNCs = 0 + + def _compute_ijk_to_active_mapping(self) -> None: + """Compute mapping from IJK indices to active cell indices. + This is static and can be reused across time steps. + """ + self.ijk_to_active = {} + active_idx = 0 + for ijk in range(self.nn): + if self.actnum[ijk] > 0: + self.ijk_to_active[ijk] = active_idx + active_idx += 1 + + def _compute_connections(self) -> np.ndarray: + """ + Computes the connection matrix. + + Returns + np.ndarray: The connection matrix where each row represents a connection + between two grid cells. + + Notes + Indexing convention: + - cell_idx_cumsum: 1-based sequential indices for active cells (1, 2, 3, ...) 
+ - NNC1/NNC2: 1-based ECLIPSE global cell indices + - Final output (_conx): 0-based Python indices for active cells + """ + + # Compute active cell indexing: creates 1-based sequential indices (1, 2, 3, ...) + # for active cells, 0 for inactive cells + cell_idx = np.ones(self.nn, dtype=int) + cell_idx[self.actnum == 0] = 0 + cell_idx_cumsum = np.cumsum(cell_idx) # 1-based: active cells get 1, 2, 3, ... + cell_idx_cumsum[self.actnum == 0] = 0 # Set inactive cells back to 0 + + # Reshape active grid indices into 3D (maintains 1-based indexing) + cell_idx_3D = cell_idx_cumsum.reshape(self.nx, self.ny, self.nz, order="F") + + # Extend face indexing by adding ghost layers (boundary cells remain 0) + face_idx = np.zeros((self.nx + 2, self.ny + 2, self.nz + 2), dtype=int) + face_idx[1 : self.nx + 1, 1 : self.ny + 1, 1 : self.nz + 1] = cell_idx_3D + + conx = [] + if self.nx > 1: # X-direction connections + idx1 = face_idx[: self.nx + 1, 1 : self.ny + 1, 1 : self.nz + 1] + idx2 = face_idx[1 : self.nx + 2, 1 : self.ny + 1, 1 : self.nz + 1] + conx.append(np.column_stack((idx1.ravel(order="F"), idx2.ravel(order="F")))) + + if self.ny > 1: # Y-direction connections + idx1 = face_idx[1 : self.nx + 1, : self.ny + 1, 1 : self.nz + 1] + idx2 = face_idx[1 : self.nx + 1, 1 : self.ny + 2, 1 : self.nz + 1] + conx.append(np.column_stack((idx1.ravel(order="F"), idx2.ravel(order="F")))) + + if not self.single_layer: # Z-direction connections. Use this flag, instead of nz because nz = 2*nz in DP systems + idx1 = face_idx[1 : self.nx + 1, 1 : self.ny + 1, : self.nz + 1] + idx2 = face_idx[1 : self.nx + 1, 1 : self.ny + 1, 1 : self.nz + 2] + conx.append(np.column_stack((idx1.ravel(order="F"), idx2.ravel(order="F")))) + + # Stack all connections into a single array + conx = np.vstack(conx) + + # Non-neighboring connections (NNC) + # NNC1/NNC2 are 1-based ECLIPSE global cell indices + # We subtract 1 to convert to 0-based Python indices for array access + # cell_idx_flattened then returns the 1-based sequential active cell index + if self.NNC and self.NNC1.size > 0 and self.NNC2.size > 0: + cell_idx_flattened = cell_idx_3D.ravel(order="F") + NNC_conx = np.column_stack( + (cell_idx_flattened[self.NNC1 - 1], cell_idx_flattened[self.NNC2 - 1]) + ) + conx = np.vstack((conx, NNC_conx)) + + # Filter out boundary connections (connections involving inactive cells) + # Inactive cells have index 0, so any connection with 0 is invalid + self._valid_conx_idx = ~np.any(conx == 0, axis=1) + + # Filter boundary connections and convert to 0-based indexing + # At this point, conx contains 1-based sequential active cell indices (1, 2, 3, ...) + # We subtract 1 to convert to 0-based Python indices (0, 1, 2, ...) 
for final output + self._conx = conx[self._valid_conx_idx] - 1 + + # Log detailed grid and connection information at debug level + if self._conx.size > 0 and logger.getEffectiveLevel() <= logging.DEBUG: + total_cells = self.nx * self.ny * self.nz + active_percentage = (self.nact / total_cells) * 100 + nnc_count = self.num_NNCs if hasattr(self, "num_NNCs") else 0 + + logging.debug( + f"Grid dimensions: {self.nx} × {self.ny} × {self.nz} = {total_cells:,} total cells" + ) + logging.debug( + f"Active cells: {self.nact:,} ({active_percentage:.1f}% of total)" + ) + logging.debug(f"Connections: {self._conx.shape[0]:,} edges") + logging.debug( + f" - Regular connections: {self._conx.shape[0] - nnc_count:,}" + ) + logging.debug(f" - NNC connections: {nnc_count:,}") + logging.debug( + f"Edge indices: min={np.min(self._conx)}, max={np.max(self._conx)}" + ) + + def _compute_total_tran(self, init_data: dict) -> None: + """Computes and stores total transmissibility, including TRANNNC. + + Parameters + init_data (dict): Dictionary of initialization data containing transmissibility keys. + """ + # Check if any transmissibility keys exist in init_data + available_tran_keys = [k for k in self._tran_keys if k in init_data] + if not available_tran_keys: + # Initialize transmissibility arrays as None if no tran data available + self._Txyz_flattened = None + self._T_xyz = None + return + + nx, ny, nz = self.nx, self.ny, self.nz + self._T_xyz = np.zeros((self.nn, self.num_max_dims)) # tran in xyz-dirs + self._Tx = np.zeros((nx + 1, ny, nz)) # total tran in x-dir + self._Ty = np.zeros((nx, ny + 1, nz)) # total tran in y-dir + self._Tz = np.zeros((nx, ny, nz + 1)) # total tran in z-dir + + # Store phase tran at active cells + for i, key in enumerate(self._tran_keys[:-1]): + if key in init_data and init_data[key].size: + self._T_xyz[self.actnum_bool, i] = init_data[key] + + # Compute total tran for xyz dirs + Txyz_flattened = np.array([]) + if nx > 1: + self._Tx[1 : nx + 1, :, :] = self._T_xyz[:, 0].reshape( + nx, ny, nz, order="F" + ) + Txyz_flattened = np.append(Txyz_flattened, self._Tx.ravel(order="F")) + if ny > 1: + self._Ty[:, 1 : ny + 1, :] = self._T_xyz[:, 1].reshape( + nx, ny, nz, order="F" + ) + Txyz_flattened = np.append(Txyz_flattened, self._Ty.ravel(order="F")) + if not self.single_layer: + self._Tz[:, :, 1 : nz + 1] = self._T_xyz[:, 2].reshape( + nx, ny, nz, order="F" + ) + Txyz_flattened = np.append(Txyz_flattened, self._Tz.ravel(order="F")) + + # Append total tran for NNCs if applicable + if "TRANNNC" in init_data and init_data["TRANNNC"].size: + Txyz_flattened = np.append(Txyz_flattened, init_data["TRANNNC"]) + + self._Txyz_flattened = Txyz_flattened[self._valid_conx_idx] + + def create_completion_array( + self, wells: dict, use_completion_connection_factor: bool = False + ): + """ + Create completion array from well data for a specific timestep. 
+ + Parameters + ----------- + wells : dict + dict of Well objects for this sample + + Returns + -------- + completion : np.ndarray + Completion status array for active cells (1=injector, 0=closed, -1=producer) + """ + completion_cell_id_inj = np.zeros(self.nact, dtype=float) + completion_cell_id_prd = np.zeros(self.nact, dtype=float) + ijk_to_active = self.ijk_to_active + if not wells: + return completion_cell_id_inj, completion_cell_id_prd + for well_name in wells.keys(): + well = wells[well_name] + if well.status == "OPEN": # Only process open wells + for completion_obj in well.completions: + if completion_obj.status == "OPEN": # Only process open completions + ijk = completion_obj.IJK - 1 # Convert to 0-based indexing + if ijk in ijk_to_active: + active_idx = ijk_to_active[ijk] + cf = ( + completion_obj.connection_factor + if use_completion_connection_factor + else 1.0 + ) + if well.type == "INJ": + completion_cell_id_inj[active_idx] = cf * 1.0 + elif well.type == "PRD": + completion_cell_id_prd[active_idx] = cf * 1.0 + + return completion_cell_id_inj, completion_cell_id_prd diff --git a/examples/reservoir_simulation/sim_utils/well.py b/examples/reservoir_simulation/sim_utils/well.py new file mode 100644 index 0000000000..bb15df5fc2 --- /dev/null +++ b/examples/reservoir_simulation/sim_utils/well.py @@ -0,0 +1,174 @@ +# ignore_header_test +# Copyright 2025 Tsubasa Onishi +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +__all__ = ["Completion", "Well"] + + +def __dir__(): + return __all__ + + +class Completion: + """Represents a well completion. + + A completion defines a connection between a well and a grid cell. It stores + information about the completion's location, status (open or shut), and + flow direction. + + Attributes: + I (int): I-index of the grid cell (1-based). + J (int): J-index of the grid cell (1-based). + K (int): K-index of the grid cell (1-based). + status (str): Completion status ("OPEN" or "SHUT"). + dir (int): Penetration direction (1=X-dir, 2=Y-dir, 3=Z-dir, 4=fractured in X-dir, 5=fractured in Y-dir). + connection_factor (float): Connection transmissibility factor. + IJK (int, optional): Linear index of the grid cell (1-based). Set dynamically + via `set_ijk()` method. May not exist until explicitly set. 
+ flow_rate (float, optional): Flow rate at the completion. Positive: injection, + negative: production. Set dynamically via `set_flow_rate()` method. + May not exist until explicitly set. + """ + + def __init__( + self, I: int, J: int, K: int, dir: int, stat: int, conx_factor: float + ) -> None: + """Initializes a Completion object. + + Parameters + I (int): I-index of the grid cell (1-based). + J (int): J-index of the grid cell (1-based). + K (int): K-index of the grid cell (1-based). + dir (int): penetration direction. 1=X-dir, 2=Y-dir, 3=Z-dir, 4=fractured in X-dir, and 5=fractured in Y-dir. + stat (int): Completion status ID (positive for open, other values for shut). + conx_factor (float): Connection transmissibility factor + """ + self.I = I + self.J = J + self.K = K + self.dir = dir + self._set_status(stat) + self.connection_factor = conx_factor + + def set_ijk(self, ijk: int) -> None: + """Sets the linear grid cell index (IJK). + + Parameters + ijk (int): Linear index of the grid cell (1-based). + """ + self.IJK = ijk + + def set_flow_rate(self, val: float) -> None: + """Sets the flow rate at the completion. + + Parameters + val (float): Flow rate. Positive: injection, negative: production. + """ + self.flow_rate = val # positive: injection, negative: production + + # ---- Private Methods --------------------------------------------------------------------------------------------- + + def _set_status(self, stat_id: int) -> None: + """Sets the completion status. + + Parameters + stat_id (int): Status ID (positive for open, other values for shut). + """ + self.status = "OPEN" if stat_id > 0 else "SHUT" + + +class Well: + """Represents a well. + + A well has a name, type (producer or injector), and a list of completions. + + Attributes: + name (str): Name of the well. + type (str): Type of well ("PRD" or "INJ"). + completions (list[Completion]): List of Completion objects associated with the well. + num_active_completions (int): Number of active (open) completions. + """ + + def __init__(self, name: str, type_id: int, stat: int) -> None: + """Initializes a Well object. + + Parameters + name (str): Name of the well. + type_id (int): Well type ID. + stat (int): Well status ID (positive for open, other values for shut). + """ + self.name = name + self._set_type(type_id) + self.completions = [] + self.num_active_completions = 0 + self.status = "OPEN" if stat > 0 else "SHUT" + + def add_completion( + self, I: int, J: int, K: int, dir: int, stat: int, conx_factor: float + ) -> None: + """Adds a completion to the well. + + Parameters + I (int): I-index of the grid cell (1-based). + J (int): J-index of the grid cell (1-based). + K (int): K-index of the grid cell (1-based). + dir (int): penetration direction. 1=X-dir, 2=Y-dir, 3=Z-dir, 4=fractured in X-dir, and 5=fractured in Y-dir + stat (int): Completion status ID (positive for open, other values for shut). 
+ conx_factor (float): Connection transmissibility factor + """ + cmpl_stat = ( + stat if self.status == "OPEN" else 0 + ) # treatment for OPM (ICON is not updated when a well is shut) + self.completions.append(Completion(I, J, K, dir, cmpl_stat, conx_factor)) + if self.completions[-1].status == "OPEN": + self.num_active_completions += 1 + + def set_status(self) -> None: + """Set well status based on completion status""" + self.status = "SHUT" if self.num_active_completions == 0 else "OPEN" + + # ---- Private Methods --------------------------------------------------------------------------------------------- + + def _set_type(self, type_id: int) -> None: + """Sets the well type. + + Parameters + type_id (int): + - Well type ID - 1 for PRD, 2 for OILINJ, 3 for WATINJ, 4 for GASINJ (ECL) + - 5 for injector identifier for CMG (unclear how to get different injector types in CMG) + """ + if type_id == 1: + self.type = "PRD" + elif type_id in [2, 3, 4, 5]: + self.type = "INJ" + else: + self.type = "UNKNOWN" + logging.warning(f"Unknown well type: {type_id} found at well: {self.name}") diff --git a/examples/reservoir_simulation/xmgn/README.md b/examples/reservoir_simulation/xmgn/README.md new file mode 100644 index 0000000000..49f872e0fd --- /dev/null +++ b/examples/reservoir_simulation/xmgn/README.md @@ -0,0 +1,248 @@ +# XMeshGraphNet for Reservoir Simulation + +An example for surrogate modeling using +[X-MeshGraphNet](https://arxiv.org/pdf/2411.17164) on reservoir simulation +datasets. + +## Overview + +Reservoir simulation predicts reservoir performance using physical and +mathematical models. It plays critical roles in production forecasting, +reservoir management, field development planning, and optimization. Despite +advances in parallel computing and GPU acceleration, routine reservoir +simulation workflows requiring thousands of simulations remain computationally +expensive, creating a need for faster surrogate models. + +This example provides a reference implementation of XMeshGraphNet (X-MGN) for +building reservoir simulation surrogates. X-MGN is naturally compatible with +the finite volume framework commonly used in reservoir simulation. It is +particularly effective for systems with irregular connections such as faults, +pinch-outs, dual-porosity, and discrete fractures, etc. Furthermore, X-MGN +scales efficiently to industry-scale reservoir models with millions of cells. + +## Quick Start + +### Prerequisites + +**Python Version**: Python 3.10 or higher (tested with Python 3.10 and 3.11) + +**Install Dependencies**: + +```bash +pip install -r requirements.txt +``` + +### 0. Dataset Preparation + +You need to provide reservoir simulation data with ECLIPSE/IX style output +format to use this example. + +> **⚠️ Dataset License Disclaimer** +> +> Users are responsible for verifying and complying with the license terms of +> any dataset they use with this example. This includes datasets referenced in +> this documentation (such as the Norne Field dataset) or any proprietary data. +> Please ensure you have the appropriate rights and permissions before using +> any dataset for your research or commercial applications. + +#### Option 1: Use Your Own Simulation Data + +If you have your own reservoir simulation dataset, ensure all simulation cases +are stored in a single directory with ECLIPSE/IX style output files: + +```text +/ +├── CASE_1.DATA +├── CASE_1.INIT +├── CASE_1.EGRID +├── CASE_1.UNRST +├── CASE_2.DATA +├── CASE_2.INIT +└── ... 
(multiple cases) +``` + +#### Option 2: Sample Data + +**Note**: A downloadable sample dataset will be made available soon. + +- Example 1: Waterflood in a 2D quarter five-spot model with varying + permeability distributions generated using a geostatistical method + (1000 samples). +- Example 2: Based on the publicly available + [Norne Field](https://github.com/OPM/opm-data/tree/master/norne) dataset. + A Design of Experiment and sensitivity study identified fault + transmissibility and KVKH multipliers as key variables, which were then + varied using Latin Hypercube Sampling to generate 500 samples. This + well-known model contains numerous faults represented by Non-Neighbor + Connections (NNCs), which X-MGN naturally handles through its + graph structure. + +An open-source reservoir simulator, [OPM](https://opm-project.org/), was used +to generate both datasets. + +#### Expected Data Format + +- **Format**: ECLIPSE/IX compatible binary files +- **Required files per case**: `.INIT`, `.EGRID`, `.UNRST` (or `.X00xx`), `.UNSMRY` (or `.S00xx`) +- **Storage**: All cases in a single directory + +#### Example Visualization: Norne Field + +Static reservoir property and domain partitions: + + + + + + + + + + + +
Permeability (PERMX) distribution | X-MeshGraphNet partitioning (0 = halo region)
+ + +### 1. Data Preprocessing + +Configure your dataset path in `conf/.yaml` by setting +`dataset.sim_dir` to point to your simulation data directory, then run: + +```bash +python src/preprocessor.py --config-name= +``` + +**Note:** Replace `` with your configuration file name from the +`conf/` directory (without the `.yaml` extension). For example, use `config` +for `conf/config.yaml`. Use the same config name for training and inference +steps below. + +**What it does**: + +- Reads simulation binary files (`.INIT`, `.EGRID`, `.UNRST`) in the dataset directory. +- Extracts variables specified in the configuration file +- Builds graph structures with nodes (grid cells) and edges (connections) +- Creates autoregressive training sequences for next-timestep prediction +- Saves processed graphs + +### 2. Training + +Multi-GPU training is supported: + +```bash +torchrun --nproc_per_node=4 --nnodes=1 src/train.py --config-name= +``` + +### 3. Inference and Visualization + +Run autoregressive inference to predict future timesteps: + +```bash +python src/inference.py --config-name= +``` + +**Output Location:** Results are saved to +`outputs//inference/` + +**Output Files:** + +- **HDF5 files**: Contain predictions and targets for each simulation case, + organized by timestep and variable +- **GRDECL files**: Eclipse-compatible ASCII format that can be imported into + popular software such as Petrel and [ResInsight](https://resinsight.org/) + for visualization + +#### Example Results: Autoregressive Inference + +The following shows water saturation and pressure predictions for the Norne +field across 64 timesteps spanning 10 years of operation with varying well +controls. X-MGN demonstrates +good predictability, especially for near-term predictions. As expected for +autoregressive prediction, errors accumulate over time, but the model maintains +reasonable accuracy throughout: + + + +**Pressure** + + + + + + + + + + + + + + + + + + + + + + +
Panels: Ground Truth, X-MGN Prediction, and Prediction Error at 30 Jul 2001 (Timestep 21, Day 1362) and 16 Sep 2003 (Timestep 42, Day 2140).
+ +**Water Saturation** + + + + + + + + + + + + + + + + + + + + + + +
Panels (same layout as the pressure figure): Ground Truth, X-MGN Prediction, and Prediction Error at 30 Jul 2001 (Timestep 21, Day 1362) and 16 Sep 2003 (Timestep 42, Day 2140).
+ + + +## Experiment Tracking + +Launch MLflow UI to monitor training progress (replace `` +with your experiment name from the config): + +```bash +cd outputs/ +mlflow ui --host 0.0.0.0 --port 5000 +``` + +Access the dashboard at: + +## References + +- [X-MeshGraphNet: Scalable Multi-Scale Graph Neural Networks for Physics + Simulation](https://arxiv.org/pdf/2411.17164) +- [Open Porous Media (OPM) Flow Simulator](https://opm-project.org/) diff --git a/examples/reservoir_simulation/xmgn/conf/config.yaml b/examples/reservoir_simulation/xmgn/conf/config.yaml new file mode 100644 index 0000000000..8de0a5de01 --- /dev/null +++ b/examples/reservoir_simulation/xmgn/conf/config.yaml @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: true + name: XMGN_2D_Q5SP_Waterflood + run: + dir: ./outputs/${hydra:job.name} + +# ┌───────────────────────────────────────────┐ +# │ Run Specification │ +# └───────────────────────────────────────────┘ + +runspec: + job_name: ${hydra:job.name} + description: "2D Q5SP waterflood reservoir simulation experiment" + +# ┌───────────────────────────────────────────┐ +# │ Dataset Configuration │ +# └───────────────────────────────────────────┘ + +dataset: + simulator: OPM # Simulator type (currently support simulators with ECLIPSE-style output files: OPM, ECLIPSE, IX, etc.) + sim_dir: ../dataset/2D_OW/CASE_2D.sim # Directory containing simulation cases and results + # num_samples: 20 # Number of samples to process (default = all, else for small size testing) + + # Graph configuration + graph: + node_features: + static: ["PERMX", "PORV", "X", "Y", "Z"] # Static node features (ECL standard + some custom keys supported) + dynamic: + variables: ["PRESSURE", "SWAT", "WCID"] # Dynamic node features (current timestep) + prev_timesteps: 2 # Number of previous timesteps to include (0=current only, 1=current+previous, etc.) + # time_series: ["WWIR", "WGIR", "WBHP"] # Time series variables. mapped onto grid cells with completion cell indices. (currently prototype implementation) + + edge_features: ["TRANX", "TRANY", "TRANZ", "TRANNNC"] # directional transmissibilities (including nnc). will be combined. 
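    # Note: during preprocessing the directional transmissibilities above are
    # merged into a single per-connection value, so the "TRAN:LOG10" entry under
    # nonlinear_scaling below applies to that combined edge feature.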
+ + global_features: + delta_t: true # Include time step size as global feature + time: true # Include time (normalized 0-1) as global feature + # TODO: implement advanced global features (e.g., field management) + + target_vars: + node_features: ["PRESSURE", "SWAT"] # Target variables (next timestep) + # time_series: ["WWCT", "WGOR"] # Time series variables #TODO: implement later (currently not used) + weights: [1.0, 1.0] # Per-variable loss weights (same order as node_features) + loss_functions: ["L2", "L1"] # Loss functions for each target variable (L1, L2, Huber) + + nonlinear_scaling: ["PERMX:LOG10", "TRAN:LOG10"] # Optional: specify irregular distribution, otherwise values will be directly scaled to [0, 1] + + +# ┌───────────────────────────────────────────┐ +# │ Data Preprocessing │ +# └───────────────────────────────────────────┘ + +preprocessing: + skip_graphs: false # Skip graph generation (use existing graphs to test partitioning parameters) + num_partitions: 2 # Number of partitions for each graph + halo_size: 3 # Size of halo region for partitions + num_preprocess_workers: 4 # Number of workers for data preprocessing + data_split: + train_ratio: 0.8 # Ratio of samples for training + val_ratio: 0.1 # Ratio of samples for validation + test_ratio: 0.1 # Ratio of samples for testing + +# ┌───────────────────────────────────────────┐ +# │ Model Configuration │ +# └───────────────────────────────────────────┘ + +model: + num_message_passing_layers: 3 # Number of message passing layers + hidden_dim: 64 # Hidden dimension of the model + activation: silu # Activation function + +# ┌───────────────────────────────────────────┐ +# │ Training Configuration │ +# └───────────────────────────────────────────┘ + +training: + num_epochs: 1000 # Number of epochs + batch_size: 1 # Batch size for training + start_lr: 1e-4 # Initial learning rate (cos annealing schedule is used) + end_lr: 1e-6 # Final learning rate (cos annealing schedule is used) + weight_decay: 1e-3 # Weight decay for AdamW optimizer (L2 regularization) + validation_freq: 5 # Frequency of validation and checkpoint saving + resume: false # Resume training from existing checkpoints (true) or start fresh (false) + early_stopping: + patience: 20 # Number of actual epochs (not validation checks) to wait for improvement + min_delta: 1e-6 # Minimum change to qualify as improvement + +# ┌───────────────────────────────────────────┐ +# │ Performance Optimization │ +# └───────────────────────────────────────────┘ + +performance: + use_concat_trick: true # Use the concatenation trick + checkpoint_segments: 2 # Number of segments for the activation checkpointing + enable_cudnn_benchmark: true # Enable cudnn benchmark + +# ┌───────────────────────────────────────────┐ +# │ Inference Configuration │ +# └───────────────────────────────────────────┘ + +inference: + checkpoint_path: null # Explicit path to .pt checkpoint file (e.g., "outputs/XMeshGraphNet_Reservoir/best_checkpoints/checkpoint.0.30.pt") + model_path: null # Explicit path to .mdlus model file (e.g., "outputs/XMeshGraphNet_Reservoir/best_checkpoints/MeshGraphNet.0.30.mdlus") + # Note: If both are null, automatically uses the best checkpoint from best_checkpoints directory diff --git a/examples/reservoir_simulation/xmgn/conf/config_norne.yaml b/examples/reservoir_simulation/xmgn/conf/config_norne.yaml new file mode 100644 index 0000000000..45476729f1 --- /dev/null +++ b/examples/reservoir_simulation/xmgn/conf/config_norne.yaml @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright 
(c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +hydra: + job: + chdir: true + name: XMGN_Norne + run: + dir: ./outputs/${hydra:job.name} + +# ┌───────────────────────────────────────────┐ +# │ Run Specification │ +# └───────────────────────────────────────────┘ + +runspec: + job_name: ${hydra:job.name} + description: "Experiment with Norne field" + +# ┌───────────────────────────────────────────┐ +# │ Dataset Configuration │ +# └───────────────────────────────────────────┘ + +dataset: + simulator: OPM # Simulator type (currently support simulators with ECLIPSE-style output files: OPM, ECLIPSE, IX, etc.) + sim_dir: ../dataset/norne/NORNE_LHS.sim # Directory containing simulation cases and results + # num_samples: 10 # Number of samples to process (default = all, else for small size testing) + + # Graph configuration + graph: + node_features: + static: ["PERMX", "PORV", "X", "Y", "Z"] # Static node features (ECL standard + some custom keys supported) + dynamic: + variables: ["PRESSURE", "SWAT", "WCID"] # Dynamic node features (current timestep) + prev_timesteps: 2 # Number of previous timesteps to include (0=current only, 1=current+previous, etc.) + # time_series: ["WWIR", "WGIR", "WBHP"] # Time series variables. mapped onto grid cells with completion cell indices. (currently prototype implementation) + + edge_features: ["TRANX", "TRANY", "TRANZ", "TRANNNC"] # directional transmissibilities (including nnc). will be combined. 
+ + global_features: + delta_t: true # Include time step size as global feature + time: true # Include time (normalized 0-1) as global feature + # TODO: implement advanced global features (e.g., field management) + + target_vars: + node_features: ["PRESSURE", "SWAT"] # Target variables (next timestep) + # time_series: ["WWCT", "WGOR"] # Time series variables #TODO: implement later (currently not used) + weights: [1.0, 1.0] # Per-variable loss weights (same order as node_features) + loss_functions: ["L2", "L1"] # Loss functions for each target variable (L1, L2, Huber) + + nonlinear_scaling: ["PERMX:LOG10", "TRAN:LOG10"] # Optional: specify irregular distribution, otherwise values will be directly scaled to [0, 1] + + +# ┌───────────────────────────────────────────┐ +# │ Data Preprocessing │ +# └───────────────────────────────────────────┘ + +preprocessing: + skip_graphs: false # Skip graph generation (use existing graphs to test partitioning parameters) + num_partitions: 3 # Number of partitions for each graph + halo_size: 5 # Size of halo region for partitions + num_preprocess_workers: 8 # Number of workers for data preprocessing + data_split: + train_ratio: 0.8 # Ratio of samples for training + val_ratio: 0.1 # Ratio of samples for validation + test_ratio: 0.1 # Ratio of samples for testing + +# ┌───────────────────────────────────────────┐ +# │ Model Configuration │ +# └───────────────────────────────────────────┘ + +model: + num_message_passing_layers: 5 # Number of message passing layers + hidden_dim: 128 # Hidden dimension of the model + activation: silu # Activation function + +# ┌───────────────────────────────────────────┐ +# │ Training Configuration │ +# └───────────────────────────────────────────┘ + +training: + num_epochs: 1000 # Number of epochs + batch_size: 1 # Batch size for training + start_lr: 1e-3 # Initial learning rate (cos annealing schedule is used) + end_lr: 1e-6 # Final learning rate (cos annealing schedule is used) + weight_decay: 1e-3 # Weight decay for AdamW optimizer (L2 regularization) + validation_freq: 5 # Frequency of validation and checkpoint saving + resume: false # Resume training from existing checkpoints (true) or start fresh (false) + early_stopping: + patience: 20 # Number of actual epochs (not validation checks) to wait for improvement + min_delta: 1e-6 # Minimum change to qualify as improvement + +# ┌───────────────────────────────────────────┐ +# │ Performance Optimization │ +# └───────────────────────────────────────────┘ + +performance: + use_concat_trick: true # Use the concatenation trick + checkpoint_segments: 2 # Number of segments for the activation checkpointing + enable_cudnn_benchmark: true # Enable cudnn benchmark + +# ┌───────────────────────────────────────────┐ +# │ Inference Configuration │ +# └───────────────────────────────────────────┘ + +inference: + checkpoint_path: null # Explicit path to .pt checkpoint file (e.g., "outputs/XMeshGraphNet_Reservoir/best_checkpoints/checkpoint.0.30.pt") + model_path: null # Explicit path to .mdlus model file (e.g., "outputs/XMeshGraphNet_Reservoir/best_checkpoints/MeshGraphNet.0.30.mdlus") + # Note: If both are null, automatically uses the best checkpoint from best_checkpoints directory diff --git a/examples/reservoir_simulation/xmgn/requirements.txt b/examples/reservoir_simulation/xmgn/requirements.txt new file mode 100644 index 0000000000..eda6a39a5e --- /dev/null +++ b/examples/reservoir_simulation/xmgn/requirements.txt @@ -0,0 +1,32 @@ +# XMeshGraphNet Standalone Training Requirements +# 
Tested with Python 3.10 and CUDA 12.1 + +# PyTorch Geometric wheel repository for extension packages +--find-links https://data.pyg.org/whl/torch-2.4.0+cu121.html + +# Core dependencies +torch==2.4.0 +torchaudio==2.4.0 +torchvision==0.19.0 +numpy>=1.26.0,<2.0 +hydra-core>=1.3.0 +omegaconf>=2.3.0 + +# PyTorch Geometric and extensions +torch-geometric>=2.6.0 +torch-scatter>=2.1.2 +torch-sparse>=0.6.18 +torch-cluster>=1.6.3 +torch-spline-conv>=1.2.2 + +# Data handling +h5py>=3.14.0 + +# Experiment tracking +mlflow>=3.4.0 + +# Terminal colors (required by PhysicsNeMo logging) +termcolor>=2.0.0 + +# Notes: +# - scipy is automatically installed as a dependency of torch-sparse diff --git a/examples/reservoir_simulation/xmgn/src/data/__init__.py b/examples/reservoir_simulation/xmgn/src/data/__init__.py new file mode 100644 index 0000000000..5ebf50a710 --- /dev/null +++ b/examples/reservoir_simulation/xmgn/src/data/__init__.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Data processing and loading utilities + +# Import data processing utilities +from .graph_builder import ReservoirGraphBuilder +from sim_utils import EclReader, Grid, Well, Completion +from .dataloader import ( + GraphDataset, + create_dataloader, + load_stats, + find_pt_files, + custom_collate_fn, +) + +__all__ = [ + # Data processing + "ReservoirGraphBuilder", + "EclReader", + "Grid", + "Well", + "Completion", + # Data loading + "GraphDataset", + "create_dataloader", + "load_stats", + "find_pt_files", + "custom_collate_fn", +] diff --git a/examples/reservoir_simulation/xmgn/src/data/dataloader.py b/examples/reservoir_simulation/xmgn/src/data/dataloader.py new file mode 100644 index 0000000000..6c198cb8e6 --- /dev/null +++ b/examples/reservoir_simulation/xmgn/src/data/dataloader.py @@ -0,0 +1,623 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Custom dataset and dataloader utilities for reservoir simulation graph data. +Provides GraphDataset for loading and normalizing partitioned graphs from .pt files, +along with utilities for computing global statistics and creating efficient dataloaders. 
+""" + +import os +import sys +import json +import logging + +import torch +import torch_geometric as pyg +import numpy as np +from torch.utils.data import Dataset, DataLoader + +# Module-level logger +logger = logging.getLogger(__name__) + + +def find_pt_files(directory): + """ + Find all .pt files in a directory. + + Parameters + ---------- + directory : str + Directory to search for .pt files + + Returns + ------- + file_paths : list + List of file paths to .pt files + """ + import glob + + if not os.path.exists(directory): + return [] + + pattern = os.path.join(directory, "**", "*.pt") + file_paths = glob.glob(pattern, recursive=True) + return sorted(file_paths) + + +def save_stats(stats, output_file): + """ + Save statistics to a JSON file. + + Parameters + ---------- + stats : dict + Statistics dictionary + output_file : str + Output file path + """ + with open(output_file, "w") as f: + json.dump(stats, f, indent=2) + + logger.info(f"Statistics saved to {output_file}") + + +def load_stats(stats_file): + """ + Load statistics from a JSON file. + + Parameters + ---------- + stats_file : str + Path to the statistics file + + Returns + ------- + stats : dict + Statistics dictionary + """ + with open(stats_file, "r") as f: + stats = json.load(f) + + return stats + + +def compute_global_statistics(graph_files, stats_file=None): + """ + Compute global statistics (mean and std) across all graphs for normalization. + + Parameters + ---------- + graph_files : list + List of paths to graph files (.pt files) + stats_file : str, optional + Path to save statistics JSON file. If None, statistics are not saved. + + Returns + ------- + dict : Dictionary containing node, edge, and target statistics + """ + logger.info(f"Computing global statistics across {len(graph_files)} graphs...") + + # Collect all node, edge, and target features + all_node_features = [] + all_edge_features = [] + all_target_features = [] + + # Process all graphs to compute statistics + logger.info(f"Computing statistics from {len(graph_files)} graphs...") + for i, file_path in enumerate(graph_files, 1): + try: + graph = torch.load(file_path, weights_only=False) + + # Collect node features + if hasattr(graph, "x") and graph.x is not None: + all_node_features.append(graph.x) + + # Collect edge features + if hasattr(graph, "edge_attr") and graph.edge_attr is not None: + all_edge_features.append(graph.edge_attr) + + # Collect target features + if hasattr(graph, "y") and graph.y is not None: + all_target_features.append(graph.y) + + if i % 100 == 0: + logger.info(f" Processed {i}/{len(graph_files)} graphs...") + + except Exception as e: + logger.warning(f"Failed to load graph {file_path}: {e}") + continue + + # Compute statistics for node features + if all_node_features: + # Filter out graphs with inconsistent feature dimensions + if len(all_node_features) > 1: + # Find the most common feature dimension + feature_dims = [ + feat.shape[1] for feat in all_node_features if feat.numel() > 0 + ] + if feature_dims: + from collections import Counter + + most_common_dim = Counter(feature_dims).most_common(1)[0][0] + # Keep only graphs with the most common feature dimension + all_node_features = [ + feat + for feat in all_node_features + if feat.shape[1] == most_common_dim + ] + logger.info( + f" Filtered to {len(all_node_features)} graphs with consistent {most_common_dim} node features" + ) + + if all_node_features: + # Concatenate all node features: [total_nodes, num_node_features] + all_nodes = torch.cat(all_node_features, dim=0) + 
node_mean = torch.mean(all_nodes, dim=0) # [num_node_features] + node_std = torch.std(all_nodes, dim=0) # [num_node_features] + logger.info( + f" Node features: {all_nodes.shape[1]} features, {all_nodes.shape[0]} total nodes" + ) + else: + node_mean = torch.tensor([]) + node_std = torch.tensor([]) + logger.warning(" No consistent node features found") + else: + node_mean = torch.tensor([]) + node_std = torch.tensor([]) + logger.warning(" No node features found") + + # Compute statistics for edge features + if all_edge_features: + # Filter out graphs with inconsistent feature dimensions + if len(all_edge_features) > 1: + # Find the most common feature dimension + feature_dims = [ + feat.shape[1] for feat in all_edge_features if feat.numel() > 0 + ] + if feature_dims: + from collections import Counter + + most_common_dim = Counter(feature_dims).most_common(1)[0][0] + # Keep only graphs with the most common feature dimension + all_edge_features = [ + feat + for feat in all_edge_features + if feat.shape[1] == most_common_dim + ] + logger.info( + f" Filtered to {len(all_edge_features)} graphs with consistent {most_common_dim} edge features" + ) + + if all_edge_features: + # Concatenate all edge features: [total_edges, num_edge_features] + all_edges = torch.cat(all_edge_features, dim=0) + edge_mean = torch.mean(all_edges, dim=0) # [num_edge_features] + edge_std = torch.std(all_edges, dim=0) # [num_edge_features] + logger.info( + f" Edge features: {all_edges.shape[1]} features, {all_edges.shape[0]} total edges" + ) + else: + edge_mean = torch.tensor([]) + edge_std = torch.tensor([]) + logger.warning(" No consistent edge features found") + else: + edge_mean = torch.tensor([]) + edge_std = torch.tensor([]) + logger.warning(" No edge features found") + + # Compute statistics for target features + if all_target_features: + # Filter out graphs with inconsistent feature dimensions + if len(all_target_features) > 1: + # Find the most common feature dimension + feature_dims = [ + feat.shape[1] for feat in all_target_features if feat.numel() > 0 + ] + if feature_dims: + from collections import Counter + + most_common_dim = Counter(feature_dims).most_common(1)[0][0] + # Keep only graphs with the most common feature dimension + all_target_features = [ + feat + for feat in all_target_features + if feat.shape[1] == most_common_dim + ] + logger.info( + f" Filtered to {len(all_target_features)} graphs with consistent {most_common_dim} target features" + ) + + if all_target_features: + # Concatenate all target features: [total_nodes, num_target_features] + all_targets = torch.cat(all_target_features, dim=0) + target_mean = torch.mean(all_targets, dim=0) # [num_target_features] + target_std = torch.std(all_targets, dim=0) # [num_target_features] + logger.info( + f" Target features: {all_targets.shape[1]} features, {all_targets.shape[0]} total nodes" + ) + else: + target_mean = torch.tensor([]) + target_std = torch.tensor([]) + logger.warning(" No consistent target features found") + else: + target_mean = torch.tensor([]) + target_std = torch.tensor([]) + logger.warning(" No target features found") + + # Create statistics dictionary + stats = { + "node_features": {"mean": node_mean.tolist(), "std": node_std.tolist()}, + "edge_features": {"mean": edge_mean.tolist(), "std": edge_std.tolist()}, + "target_features": {"mean": target_mean.tolist(), "std": target_std.tolist()}, + } + + # Save statistics if requested + if stats_file: + with open(stats_file, "w") as f: + json.dump(stats, f, indent=2) + logger.info(f" 
Statistics saved to {stats_file}") + + logger.info(f" Node features - Mean: {node_mean.tolist()}") + logger.info(f" Node features - Std: {node_std.tolist()}") + logger.info(f" Edge features - Mean: {edge_mean.tolist()}") + logger.info(f" Edge features - Std: {edge_std.tolist()}") + logger.info(f" Target features - Mean: {target_mean.tolist()}") + logger.info(f" Target features - Std: {target_std.tolist()}") + + return stats + + +class PartitionedGraph: + """ + A class for partitioning a graph into multiple parts with halo regions. + + Parameters + ---------- + graph : pyg.data.Data + The graph data. + num_parts : int + The number of partitions. + halo_size : int + The size of the halo region. + """ + + def __init__(self, graph: pyg.data.Data, num_parts: int, halo_size: int): + self.num_nodes = graph.num_nodes + self.num_parts = num_parts + self.halo_size = halo_size + + # Try to partition the graph using PyG METIS, with fallback to simple partitioning + try: + # Partition the graph using PyG METIS. + # https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html#torch_geometric.loader.ClusterData + cluster_data = pyg.loader.ClusterData(graph, num_parts=self.num_parts) + part_meta = cluster_data.partition + except Exception as e: + logger.warning( + f" METIS partitioning failed ({e}), using simple partitioning..." + ) + # Fallback: simple sequential partitioning + part_meta = self._create_simple_partition(graph.num_nodes, num_parts) + + # Create partitions with halo regions using PyG `k_hop_subgraph`. + self.partitions = [] + for i in range(self.num_parts): + # Get inner nodes of the partition. + part_inner_node = part_meta.node_perm[ + part_meta.partptr[i] : part_meta.partptr[i + 1] + ] + # Partition the graph with halo regions. + # https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html?#torch_geometric.utils.k_hop_subgraph + part_node, part_edge_index, inner_node_mapping, edge_mask = ( + pyg.utils.k_hop_subgraph( + part_inner_node, + num_hops=self.halo_size, + edge_index=graph.edge_index, + num_nodes=self.num_nodes, + relabel_nodes=True, + ) + ) + + partition = pyg.data.Data( + edge_index=part_edge_index, + edge_attr=graph.edge_attr[edge_mask], + num_nodes=part_node.size(0), + part_node=part_node, + inner_node=inner_node_mapping, + ) + # Set partition node attributes. + for k, v in graph.items(): + if graph.is_node_attr(k): + setattr(partition, k, v[part_node]) + + self.partitions.append(partition) + + def __len__(self): + return self.num_parts + + def __getitem__(self, idx): + return self.partitions[idx] + + def _create_simple_partition(self, num_nodes, num_parts): + """Create a simple sequential partition as fallback when METIS is not available.""" + import torch + + # Create a simple partition object that mimics the METIS partition structure + class SimplePartition: + def __init__(self, num_nodes, num_parts): + self.node_perm = torch.arange(num_nodes) + + # Calculate partition boundaries + part_size = num_nodes // num_parts + remainder = num_nodes % num_parts + + self.partptr = [0] + for i in range(num_parts): + current_size = part_size + (1 if i < remainder else 0) + self.partptr.append(self.partptr[-1] + current_size) + + return SimplePartition(num_nodes, num_parts) + + +class GraphDataset(Dataset): + """ + A custom dataset class for loading and normalizing graph partition data. + + Parameters + ---------- + file_paths : list + List of file paths to the graph partition files. 
+ node_mean : torch.Tensor + Global mean for node attributes (shape: [num_node_features]). + node_std : torch.Tensor + Global standard deviation for node attributes (shape: [num_node_features]). + edge_mean : torch.Tensor + Global mean for edge attributes (shape: [num_edge_features]). + edge_std : torch.Tensor + Global standard deviation for edge attributes (shape: [num_edge_features]). + target_mean : torch.Tensor + Global mean for target attributes (shape: [num_target_features]). + target_std : torch.Tensor + Global standard deviation for target attributes (shape: [num_target_features]). + """ + + def __init__( + self, + file_paths, + node_mean, + node_std, + edge_mean, + edge_std, + target_mean=None, + target_std=None, + ): + self.file_paths = file_paths + self.node_mean = node_mean + self.node_std = node_std + self.edge_mean = edge_mean + self.edge_std = edge_std + self.target_mean = target_mean + self.target_std = target_std + + def __len__(self): + return len(self.file_paths) + + def __getitem__(self, idx): + # Load the list of graph partitions (following xaeronet pattern) + partitions = torch.load(self.file_paths[idx], weights_only=False) + + # Extract label from filename (sample index) + filename = os.path.basename(self.file_paths[idx]) + # Handle different filename formats: + # - Raw graphs: CASE_2D_1_000.pt or CASE_2D_0001_000.pt -> extract sample index + # - Partitions: partitions_CASE_2D_630_009.pt or partitions_CASE_2D_0630_009.pt -> extract sample index + # - Norne format: partitions_NORNE_ATW2013_DOE_0004_002.pt -> extract 0004 (sample index) + parts = filename.replace(".pt", "").split("_") + label = 0 # Default label + + # Find all numeric parts in the filename + numeric_parts = [] + for i, part in enumerate(parts): + try: + numeric_value = int(part) + numeric_parts.append((i, numeric_value)) + except ValueError: + continue + + # The sample index is typically the second-to-last numeric part + # (last numeric part is usually the timestep) + if len(numeric_parts) >= 2: + # Get the second-to-last numeric part as the sample index + label = numeric_parts[-2][1] + elif len(numeric_parts) == 1: + # If only one numeric part, use it as the label + label = numeric_parts[0][1] + + # Normalize each partition in the list + for partition in partitions: + # Normalize node attributes (per-feature normalization) + if hasattr(partition, "x") and partition.x is not None: + # Ensure dimensions match: partition.x shape should be [num_nodes, num_features] + # node_mean and node_std should be [num_features] + if partition.x.dim() == 2 and self.node_mean.dim() == 1: + # Broadcasting: [num_nodes, num_features] - [num_features] -> [num_nodes, num_features] + partition.x = (partition.x - self.node_mean) / ( + self.node_std + 1e-8 + ) + else: + # Fallback for mismatched dimensions + logger.warning( + f"Dimension mismatch in node features. 
Partition shape: {partition.x.shape}, Stats shape: {self.node_mean.shape}" + ) + partition.x = (partition.x - self.node_mean.unsqueeze(0)) / ( + self.node_std.unsqueeze(0) + 1e-8 + ) + + # Normalize edge attributes (per-feature normalization) + if hasattr(partition, "edge_attr") and partition.edge_attr is not None: + # Ensure dimensions match: partition.edge_attr shape should be [num_edges, num_edge_features] + # edge_mean and edge_std should be [num_edge_features] + if partition.edge_attr.dim() == 2 and self.edge_mean.dim() == 1: + # Broadcasting: [num_edges, num_edge_features] - [num_edge_features] -> [num_edges, num_edge_features] + partition.edge_attr = (partition.edge_attr - self.edge_mean) / ( + self.edge_std + 1e-8 + ) + else: + # Fallback for mismatched dimensions + logger.warning( + f"Dimension mismatch in edge features. Partition shape: {partition.edge_attr.shape}, Stats shape: {self.edge_mean.shape}" + ) + partition.edge_attr = ( + partition.edge_attr - self.edge_mean.unsqueeze(0) + ) / (self.edge_std.unsqueeze(0) + 1e-8) + + # Normalize target attributes (per-feature normalization) + if ( + hasattr(partition, "y") + and partition.y is not None + and self.target_mean is not None + and self.target_std is not None + ): + # Ensure dimensions match: partition.y shape should be [num_nodes, num_target_features] + # target_mean and target_std should be [num_target_features] + if partition.y.dim() == 2 and self.target_mean.dim() == 1: + # Broadcasting: [num_nodes, num_target_features] - [num_target_features] -> [num_nodes, num_target_features] + partition.y = (partition.y - self.target_mean) / ( + self.target_std + 1e-8 + ) + else: + # Fallback for mismatched dimensions + logger.warning( + f"Dimension mismatch in target features. Partition shape: {partition.y.shape}, Stats shape: {self.target_mean.shape}" + ) + partition.y = (partition.y - self.target_mean.unsqueeze(0)) / ( + self.target_std.unsqueeze(0) + 1e-8 + ) + + return partitions, label + + +def custom_collate_fn(batch): + """ + Custom collate function for lists of PartitionedGraph objects (following xaeronet pattern). + + Parameters + ---------- + batch : list + List of (partitions, label) tuples from the dataset + where partitions is a list of PartitionedGraph objects + + Returns + ------- + tuple + (partitions_list, labels) where partitions_list is a list of lists of PartitionedGraph objects + and labels is a tensor of labels + """ + partitions_list, labels = zip(*batch) + return list(partitions_list), torch.tensor(labels, dtype=torch.long) + + +def create_dataloader( + partitions_path, + validation_partitions_path, + stats_file, + batch_size=1, + shuffle=True, + num_workers=0, + prefetch_factor=2, + pin_memory=True, + is_validation=False, +): + """ + Create a data loader for graph partition data. + + Parameters + ---------- + partitions_path : str + Path to the partitions directory. + validation_partitions_path : str + Path to the validation partitions directory. + stats_file : str + Path to the global statistics file. + batch_size : int + Batch size for the data loader. + shuffle : bool + Whether to shuffle the data. + num_workers : int + Number of worker processes for data loading. + prefetch_factor : int + Number of batches to prefetch. + pin_memory : bool + Whether to pin memory for faster GPU transfer. + is_validation : bool + Whether this is for validation data. + + Returns + ------- + DataLoader + The data loader. 
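
    Examples
    --------
    Illustrative call only; the directory and file names below are
    placeholders rather than paths produced by this example::

        loader = create_dataloader(
            partitions_path="graphs/partitions/train",
            validation_partitions_path="graphs/partitions/validation",
            stats_file="graphs/global_stats.json",
            batch_size=1,
            num_workers=4,
        )
        for partitions_list, labels in loader:
            ...  # each element of partitions_list is a list of graph partitions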
+ """ + # Load global statistics + with open(stats_file, "r") as f: + stats = json.load(f) + + # Load per-feature statistics + # node_features should be a list of means/stds for each feature + node_mean = torch.tensor( + stats["node_features"]["mean"] + ) # Shape: [num_node_features] + node_std = torch.tensor(stats["node_features"]["std"]) # Shape: [num_node_features] + edge_mean = torch.tensor( + stats["edge_features"]["mean"] + ) # Shape: [num_edge_features] + edge_std = torch.tensor(stats["edge_features"]["std"]) # Shape: [num_edge_features] + + # Load target feature statistics (if available) + target_mean = None + target_std = None + if "target_features" in stats: + target_mean = torch.tensor( + stats["target_features"]["mean"] + ) # Shape: [num_target_features] + target_std = torch.tensor( + stats["target_features"]["std"] + ) # Shape: [num_target_features] + + # Find partition files + if is_validation: + file_paths = find_pt_files(validation_partitions_path) + else: + file_paths = find_pt_files(partitions_path) + + # Create dataset + dataset = GraphDataset( + file_paths, node_mean, node_std, edge_mean, edge_std, target_mean, target_std + ) + + # Create data loader + dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + prefetch_factor=prefetch_factor, + pin_memory=pin_memory, + collate_fn=custom_collate_fn, # Use custom collate function for lists of PartitionedGraph objects + ) + + return dataloader diff --git a/examples/reservoir_simulation/xmgn/src/data/graph_builder.py b/examples/reservoir_simulation/xmgn/src/data/graph_builder.py new file mode 100644 index 0000000000..8543f0a72f --- /dev/null +++ b/examples/reservoir_simulation/xmgn/src/data/graph_builder.py @@ -0,0 +1,1381 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import glob +import time +import json +import logging +import numpy as np +import torch +from torch_geometric.data import Data +from hydra.utils import to_absolute_path +from sim_utils import EclReader, Well, Grid +from multiprocessing import Pool, cpu_count, Manager +from scipy.interpolate import interp1d + +# Module-level logger +logger = logging.getLogger(__name__) + + +class ReservoirGraphBuilder: + """Builds graph structures from reservoir simulation data. + + This class processes reservoir simulation output files and creates + PyTorch Geometric graph structures for machine learning tasks. 
+ """ + + ECL_SIMULATORS = ["OPM", "ECLIPSE", "IX"] + + # CMG_SIMULATORS = ["IMEX", "GEM", "STARS"] # TODO: implement + # NEXUS_SIMULATORS = [""] # TODO: implement + def __init__(self, cfg): + self.sim_dir = to_absolute_path(cfg.dataset.sim_dir) + self.simulator = cfg.dataset.get("simulator", "").upper() + self.num_samples = cfg.dataset.get( + "num_samples", None + ) # Limit number of samples to process + self.num_preprocess_workers = cfg.preprocessing.get( + "num_preprocess_workers", 4 + ) # Number of parallel workers for sample processing + + # Get graph configuration + self.graph_config = cfg.dataset.get("graph", None) + if self.graph_config is None: + raise ValueError( + "'dataset.graph' section is required in configuration. Please provide graph configuration." + ) + + # Set prev_timestep_idx based on prev_timesteps + self.prev_timestep_idx = ( + self.graph_config.node_features.dynamic.get("prev_timesteps", 0) + 1 + ) + + # Create vars config for reading simulation data and graph creation + # Check which coordinate components are requested as static features + self.requested_coordinates = [ + coord + for coord in ["X", "Y", "Z"] + if coord in self.graph_config.node_features.static + ] + self.include_coordinates_as_features = len(self.requested_coordinates) == 3 + + # Don't filter out X, Y, Z - let them be handled as regular static variables if requested + static_vars = list(self.graph_config.node_features.static) + + self.vars = { + "grid": { + "static": static_vars, + "dynamic": list(self.graph_config.node_features.dynamic.variables), + } + } + + # Add time_series variables if specified + if hasattr(self.graph_config.node_features, "time_series"): + self.vars["time_series"] = list(self.graph_config.node_features.time_series) + else: + self.vars["time_series"] = [] + + self.output_vars = list(self.graph_config.target_vars.node_features) + + # Get nonlinear scaling configuration + self.nonlinear_scaling = getattr(self.graph_config, "nonlinear_scaling", []) + self.global_vars = self.graph_config.global_features + + # Add edge features to static vars if specified + if hasattr(self.graph_config, "edge_features"): + edge_vars = self.graph_config.edge_features + if isinstance(edge_vars, list): + self.vars["grid"]["static"].extend(list(edge_vars)) + + # Add nonlinear scaling if specified + if hasattr(self.graph_config, "nonlinear_scaling"): + self.vars["nonlinear_scaling"] = list(self.graph_config.nonlinear_scaling) + + self._validate_config() + self._parse_dist() + self._set_output_path() + + def _validate_config(self) -> None: + if self.simulator not in self.ECL_SIMULATORS: + raise NotImplementedError( + f"Unsupported simulator '{self.simulator}'. Supported simulators are: {self.ECL_SIMULATORS}" + ) + if self.vars is None: + raise ValueError("'vars' cannot be empty.") + if self.output_vars is None: + raise ValueError("'output_vars' must be specified in config.") + + # Validate output_vars + if not hasattr(self.output_vars, "__iter__"): + raise ValueError("'output_vars' must be iterable.") + + def _parse_dist(self): + self._dist_map = {} + entries = self.vars.get("nonlinear_scaling", []) + for entry in entries: + key, method = entry.split(":") + self._dist_map[key.upper()] = method.upper() + + def _find_primary_input_files(self): + """Find simulation primary input files based on simulator type. 
+ + Returns + list: Sorted list of primary input file paths + + Note: + - ECLIPSE/OPM: *.DATA files + - IX: *.AFI files + - CMG, NEXUS: TODO - implement support + """ + # Determine file extension based on simulator type + if self.simulator == "IX": + extension = "*.AFI" + pattern_regex = r"_(\d+)\.AFI$" + else: # ECLIPSE, OPM + extension = "*.DATA" + pattern_regex = r"_(\d+)\.DATA$" + + pattern = os.path.join(self.sim_dir, "**", extension) + return sorted( + glob.glob(pattern, recursive=True), + key=lambda fp: int( + re.search(pattern_regex, os.path.basename(fp), re.IGNORECASE).group(1) + ) + if re.search(pattern_regex, os.path.basename(fp), re.IGNORECASE) + else float("inf"), + ) + + def _set_output_path(self): + out_dir = os.path.join( + os.path.dirname(self.sim_dir), f"{os.path.basename(self.sim_dir)}.dataset" + ) + self._output_path_graph = os.path.join(out_dir, "graphs") + self._output_path_well = os.path.join(out_dir, "well.json") + # Note: Directory creation is handled by preprocessor to ensure correct job-specific paths + + def get_completion_info(self, grid, well_info) -> list: + """Extract well completion information from simulation data. + + Parameters + grid: Grid object containing grid information. + well_info: Dictionary containing well data from restart files. + + Returns + list: List of Well objects with completion data for each timestep. + """ + wells_lst = [] + for i in range(len(well_info["ZWEL"])): + INTEHEAD = well_info["INTEHEAD"][i] + ZWEL = well_info["ZWEL"][i] + IWEL = well_info["IWEL"][i] + ICON = well_info["ICON"][i] + SCON = well_info["SCON"][i] + + NWELLS, NCWMAX, NICONZ, NSCONZ = ( + INTEHEAD[16], + INTEHEAD[17], + INTEHEAD[32], + INTEHEAD[33], + ) + if NWELLS == 0: + wells_lst.append([]) # no wells operating + continue + + IWEL = IWEL.reshape((-1, NWELLS), order="F") + ICON = ICON.reshape((NICONZ, NCWMAX, NWELLS), order="F") + SCON = SCON.reshape((NSCONZ, NCWMAX, NWELLS), order="F") + + well_names = ["".join(row).strip() for row in ZWEL if "".join(row).strip()] + wells = { + name: Well(name=name, type_id=IWEL[6, i], stat=IWEL[10, i]) + for i, name in enumerate(well_names) + } + + for iwell, name in enumerate(well_names): + for id, icon in enumerate(ICON[:, :, iwell].T): + scon = SCON[:, id, iwell] + if icon[0] == 0: + break + I, J, K = icon[1:4] + well = wells[name] + well.add_completion( + I=I, J=J, K=K, dir=icon[13], stat=icon[5], conx_factor=scon[0] + ) + well.completions[-1].set_ijk(grid.ijk_from_I_J_K(I, J, K)) + + wells_lst.append(wells) + + return wells_lst + + def _apply_nonlinear_scaling(self, data, var_name): + """ + Apply nonlinear scaling to data based on configuration. 
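
        For example, a config entry of the form "PERMX:LOG10" routes PERMX values
        through log10 (floored at 1e-10 to avoid log(0)) before normalization;
        LOG and SQRT entries are handled analogously.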
+ + Parameters + ----------- + data : np.ndarray + Input data array + var_name : str + Variable name to check for scaling configuration + + Returns + -------- + np.ndarray + Scaled data array + """ + # Check if this variable has nonlinear scaling configured + for scaling_config in self.nonlinear_scaling: + if ":" in scaling_config: + var, scaling_type = scaling_config.split(":", 1) + if var == var_name: + if scaling_type.upper() == "LOG10": + # Apply log10 scaling: log10(max(data, 1e-10)) + # Use 1e-10 as minimum to avoid log(0) + data_scaled = np.log10(np.maximum(data, 1e-10)) + return data_scaled + elif scaling_type.upper() == "LOG": + # Apply natural log scaling: log(max(data, 1e-10)) + data_scaled = np.log(np.maximum(data, 1e-10)) + return data_scaled + elif scaling_type.upper() == "SQRT": + # Apply square root scaling: sqrt(max(data, 0)) + data_scaled = np.sqrt(np.maximum(data, 0)) + return data_scaled + else: + logger.warning( + f"Unknown scaling type '{scaling_type}' for variable '{var_name}'. Skipping scaling." + ) + + # No scaling configured for this variable + return data + + def build_graph_from_simulation_data( + self, + grid, + wells_data, + data, + sample_idx, + timestep_idx=0, + case_name=None, + time_series_data=None, + ): + """ + Build a reservoir simulation graph from processed data. + + Parameters + ----------- + grid : Grid object + Grid object with all grid information + wells_data : list + List of Well objects for this sample + data : dict + Combined data dictionary containing both static and dynamic properties + sample_idx : int + Index of the sample + timestep_idx : int + Current timestep index + case_name : str, optional + Name of the case + time_series_data : dict, optional + Interpolated time series data {well_name: {var_name: [values]}} + + Returns + -------- + graph : pyg.data.Data + Graph with node and edge features + """ + + # Get connections and transmissibility + conx, tran = grid.get_conx_tran() + edge_index = conx.T # (2, E) + tran = self._apply_nonlinear_scaling(tran, var_name="TRAN") + edge_features = tran.reshape(-1, 1) # (E, 1) + + # Create coordinates array for the graph (optional - only if coordinates are not in node features) + coordinates = ( + None + if self.include_coordinates_as_features + else np.column_stack([grid.X, grid.Y, grid.Z]) + ) # (N_active, 3) + + # Extract input variables (current timestep) + input_tensors = [] + input_var_names = [] + + # Add static variables (including X, Y, Z if requested as node features in config) + for var in self.vars["grid"]["static"]: + try: + var_data = np.asarray(data[var], dtype=np.float32) + var_data = self._apply_nonlinear_scaling(var_data, var) + + input_tensors.append( + torch.tensor(var_data, dtype=torch.float32).unsqueeze(1) + ) + input_var_names.append(var) + except Exception as e: + logger.error( + f"Failed to process static variable '{var}' - {type(e).__name__}: {e}" + ) + return None + + # Add dynamic variables (multiple previous timesteps) - including completion + dynamic_vars = list(self.vars["grid"]["dynamic"]) + + for var in dynamic_vars: + if var in data: + # Check if we have enough timesteps for the required history + if timestep_idx + 1 < self.prev_timestep_idx: + # Not enough history - skip this timestep (expected for early timesteps) + return None + + try: + # Extract data from multiple previous timesteps + var_data_list = [] + for t in range(self.prev_timestep_idx): + prev_timestep = timestep_idx - t # Go backwards in time + if prev_timestep < 0 or prev_timestep >= len(data[var]): + 
logger.error( + f"Variable '{var}' not available at timestep {prev_timestep} - skipping graph" + ) + return None + var_data_list.append(data[var][prev_timestep]) + + # Stack the historical data: [prev_timestep_idx, n_active] + var_data_stacked = np.stack(var_data_list, axis=0) + # Reshape to [n_active, prev_timestep_idx] to match static features + var_data_reshaped = ( + var_data_stacked.T + ) # Transpose to [n_active, prev_timestep_idx] + + # Apply nonlinear scaling if specified in config + var_data_reshaped = self._apply_nonlinear_scaling( + var_data_reshaped, var + ) + + input_tensors.append( + torch.tensor(var_data_reshaped, dtype=torch.float32) + ) + # Add individual names for each timestep: current, prev_1, prev_2, ... + for t in range(self.prev_timestep_idx): + if t == 0: + input_var_names.append(f"{var}_current") + else: + input_var_names.append(f"{var}_prev_{t}") + except Exception as e: + logger.error( + f"Failed to process variable '{var}' - {type(e).__name__}: {e}" + ) + return None + else: + logger.error(f"Variable '{var}' not available - skipping graph") + return None + + # Add time series variables if available + if time_series_data and len(self.vars.get("time_series", [])) > 0: + time_series_vars = self.vars["time_series"] + + # Get wells for current timestep (use timestep+1 for next state, as done with WCID) + current_wells = ( + wells_data[timestep_idx + 1] + if timestep_idx + 1 < len(wells_data) + else {} + ) + + if current_wells: + # Create time series arrays for this timestep + ts_arrays = self._create_time_series_arrays( + grid, + current_wells, + time_series_data, + time_series_vars, + timestep_idx + 1, + ) + + # Process each time series variable + # Note: Pressure variables (BHP, THP) will create _INJ and _PRD variants + for var_name in time_series_vars: + var_upper = var_name.upper() + is_pressure_var = ("BHP" in var_upper) or ("THP" in var_upper) + + if is_pressure_var: + # Pressure variables create two channels: _INJ and _PRD + channel_names = [f"{var_name}_INJ", f"{var_name}_PRD"] + else: + # Other variables create single channel + channel_names = [var_name] + + # Process each channel (1 for non-pressure vars, 2 for pressure vars) + for channel_name in channel_names: + if channel_name not in ts_arrays: + continue + + try: + # For time series, we can also include history + var_data_list = [] + for t in range(self.prev_timestep_idx): + prev_ts_idx = (timestep_idx + 1) - t + if prev_ts_idx < 0: + logger.error( + f"Time series '{channel_name}' not available at timestep {prev_ts_idx} - skipping graph" + ) + return None + + # Create array for this historical timestep + prev_wells = ( + wells_data[prev_ts_idx] + if prev_ts_idx < len(wells_data) + else {} + ) + if not prev_wells: + logger.error( + f"No wells data at timestep {prev_ts_idx} - skipping graph" + ) + return None + + prev_ts_arrays = self._create_time_series_arrays( + grid, + prev_wells, + time_series_data, + time_series_vars, + prev_ts_idx, + ) + + if channel_name not in prev_ts_arrays: + var_data_list.append( + np.zeros(grid.nact, dtype=np.float32) + ) + else: + var_data_list.append(prev_ts_arrays[channel_name]) + + # Stack historical data + var_data_stacked = np.stack(var_data_list, axis=0) + var_data_reshaped = ( + var_data_stacked.T + ) # [n_active, prev_timestep_idx] + + # Apply nonlinear scaling if specified (use original var_name for config lookup) + var_data_reshaped = self._apply_nonlinear_scaling( + var_data_reshaped, var_name + ) + + input_tensors.append( + torch.tensor(var_data_reshaped, 
dtype=torch.float32) + ) + # Add individual names for each timestep: current, prev_1, prev_2, ... + for t in range(self.prev_timestep_idx): + if t == 0: + input_var_names.append(f"{channel_name}_current") + else: + input_var_names.append(f"{channel_name}_prev_{t}") + + except Exception as e: + logger.error( + f"Failed to process time series variable '{channel_name}' - {type(e).__name__}: {e}" + ) + return None + + # Concatenate all input features (will add temporal features if enabled) + node_features = torch.cat(input_tensors, dim=1) + + # Extract target variables (next timestep) + target_timestep = timestep_idx + 1 + target_tensors = [] + target_var_names = [] + + for var in self.output_vars: + if var in data and target_timestep < len(data[var]): + try: + target_data = np.asarray( + data[var][target_timestep], dtype=np.float32 + ) + target_data = self._apply_nonlinear_scaling(target_data, var) + + target_tensors.append( + torch.tensor(target_data, dtype=torch.float32).unsqueeze(1) + ) + target_var_names.append(var) + except Exception as e: + logger.error( + f"Failed to process target variable '{var}' at timestep {target_timestep} for sample {sample_idx} - {type(e).__name__}: {e}" + ) + return None + else: + logger.error( + f"Target variable '{var}' not available at timestep {target_timestep} for sample {sample_idx}" + ) + # If target is not available, return None to skip this graph + return None + + # Concatenate target features + target = torch.cat(target_tensors, dim=1) + + # Add temporal features to node features if enabled in config + # This is the standard approach for incorporating time info in autoregressive GNNs + n_nodes = node_features.shape[0] + temporal_feature_names = [] + + if self.global_vars.get("delta_t", False) or self.global_vars.get( + "time", False + ): + # Validate that we have TIME data available + if "TIME" not in data or len(data["TIME"]) == 0: + logger.error( + f"No TIME data available in restart file for sample {sample_idx}" + ) + return None + + if target_timestep >= len(data["TIME"]): + logger.error( + f"Target timestep {target_timestep} >= available timesteps {len(data['TIME'])} for sample {sample_idx}" + ) + return None + + if timestep_idx >= len(data["TIME"]): + logger.error( + f"Current timestep {timestep_idx} >= available timesteps {len(data['TIME'])} for sample {sample_idx}" + ) + return None + + try: + # Calculate actual delta_t from TIME array + delta_t = data["TIME"][target_timestep] - data["TIME"][timestep_idx] + except Exception as e: + logger.error( + f"Failed to calculate delta_t for sample {sample_idx} - {type(e).__name__}: {e}" + ) + return None + + # Add delta_t as node feature (broadcast to all nodes) + if self.global_vars.get("delta_t", False): + delta_t_feature = torch.full((n_nodes, 1), delta_t, dtype=torch.float32) + node_features = torch.cat([node_features, delta_t_feature], dim=1) + temporal_feature_names.append("delta_t") + input_var_names.append("delta_t") + + # Add normalized time as node feature (broadcast to all nodes) + if self.global_vars.get("time", False): + total_time = data["TIME"][-1] if len(data["TIME"]) > 0 else 1.0 + time_normalized = data["TIME"][timestep_idx] / max(total_time, 1.0) + time_feature = torch.full( + (n_nodes, 1), time_normalized, dtype=torch.float32 + ) + node_features = torch.cat([node_features, time_feature], dim=1) + temporal_feature_names.append("time") + input_var_names.append("time") + + # Keep global_features for backward compatibility and metadata (not used by model) + global_features = None + + 
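        # Assemble the PyG Data object: stacked node features `x`, cell-to-cell
        # connections as `edge_index`/`edge_attr`, next-timestep targets `y`, and
        # grid metadata (dimensions, cell counts, variable names) for downstream use.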
try: + # Build graph data dictionary dynamically + graph_data = { + "x": node_features, + "edge_index": torch.tensor(edge_index, dtype=torch.long), + "edge_attr": torch.tensor(edge_features, dtype=torch.float32), + "y": target, + "grid_info": { + "nx": grid.nx, + "ny": grid.ny, + "nz": grid.nz, + "total_cells": grid.nn, + "active_cells": grid.nact, + "sample_idx": sample_idx, + "timestep_idx": timestep_idx, + "target_timestep": target_timestep, + "input_vars": input_var_names, + "target_vars": target_var_names, + }, + "case_name": case_name or f"sample_{sample_idx:03d}", + "timestep_id": timestep_idx, + } + + # Add global_features only if populated (for backward compatibility) + if global_features is not None: + graph_data["global_features"] = global_features + + # Add coordinates only if they're not already in node features + if coordinates is not None: + graph_data["coordinates"] = torch.tensor( + coordinates, dtype=torch.float32 + ) + except Exception as e: + logger.error( + f"Failed to create graph data dictionary for sample {sample_idx} - {type(e).__name__}: {e}" + ) + return None + + try: + # Create graph + graph = Data(**graph_data) + return graph + except Exception as e: + logger.error( + f"Failed to create PyTorch Geometric Data object for sample {sample_idx} - {type(e).__name__}: {e}" + ) + return None + + def _prepare_data_keys(self): + """Prepare the keys needed for reading simulation data.""" + init_keys = list( + dict.fromkeys( + self.vars["grid"]["static"] + + ["INTEHEAD", "PORV", "TRANX", "TRANY", "TRANZ", "TRANNNC"] + ) + ) + + # EGRID keys (skip expensive ones if coordinates aren't needed) + egrid_keys_geometry = ["COORD", "ZCORN", "FILEHEAD"] + if not self.include_coordinates_as_features: + egrid_keys_geometry = [ + k for k in egrid_keys_geometry if k not in ("COORD", "ZCORN") + ] + egrid_keys_nnc = ["NNC1", "NNC2"] + + rst_well_keys = ["INTEHEAD", "ZWEL", "IWEL", "ICON", "SCON"] + + return init_keys, (egrid_keys_geometry, egrid_keys_nnc), rst_well_keys + + def _prepare_dynamic_variables(self): + """Prepare dynamic variables and completion requirements.""" + dyn_vars = self.vars.get("grid", {}).get("dynamic", []) or [] + include_well_completion_ids = "WCID" in dyn_vars + include_well_completion_cf = "WCCF" in dyn_vars + if include_well_completion_ids: + dyn_vars.remove("WCID") + if include_well_completion_cf: + dyn_vars.remove("WCCF") + include_well_completions = ( + include_well_completion_ids or include_well_completion_cf + ) + + # Get time series variables + time_series_vars = self.vars.get("time_series", []) or [] + include_time_series = len(time_series_vars) > 0 + + return ( + dyn_vars, + include_well_completion_ids, + include_well_completion_cf, + include_well_completions, + time_series_vars, + include_time_series, + ) + + def _process_static_data(self, reader, init_keys, egrid_data, sample_idx_1based): + """Process static grid data and validate it.""" + init_data = reader.read_init(init_keys) + + # Add coordinates if requested + if self.include_coordinates_as_features: + grid = Grid(init_data, egrid_data) + for key in self.requested_coordinates: + init_data[key] = getattr(grid, key) + else: + grid = Grid(init_data, egrid_data) + + # Filter full-grid keys to active cells only + for key in Grid.FULL_GRID_KEYS: + if key in init_data and len(init_data[key]) == grid.nn: + init_data[key] = init_data[key][grid.actnum_bool] + + # Validate static data + for key in self.vars["grid"]["static"]: + if len(init_data[key]) == 0: + raise ValueError( + f" Error: Failed to read 
{key} from init/egrid file for sample {sample_idx_1based}" + ) + + return init_data, grid + + def _process_dynamic_data( + self, + reader, + grid, + dynamic_variables, + rst_well_keys, + include_well_completions, + include_well_completion_cf, + sample_idx_1based, + ): + """Process dynamic data including wells and completion arrays.""" + wells_data = [] + rst_data = {} + + if not dynamic_variables: + return wells_data, rst_data + + try: + rst_well_data = reader.read_restart(rst_well_keys) + wells_data = self.get_completion_info(grid, rst_well_data) + rst_data = reader.read_restart(dynamic_variables) + + # Handle completion arrays if needed + if wells_data and include_well_completions: + try: + completion_arrays_inj, completion_arrays_prd = [], [] + for wells in wells_data: + cmpl_inj, cmpl_prd = grid.create_completion_array( + wells, include_well_completion_cf + ) + completion_arrays_inj.append(cmpl_inj) + completion_arrays_prd.append(cmpl_prd) + # use states from next time step + rst_data["WCID_INJ"] = completion_arrays_inj[1:] + rst_data["WCID_PRD"] = completion_arrays_prd[1:] + except Exception as e: + self._log_error_and_continue( + f"Failed to create completion arrays for sample {sample_idx_1based} - {type(e).__name__}: {e}" + ) + return None, None + + # Validate dynamic data + for key in dynamic_variables: + if key not in rst_data or not rst_data[key]: + logger.warning( + f"Failed to read {key} from restart file for sample {sample_idx_1based}" + ) + rst_data[key] = [] + + except Exception as e: + self._log_error_and_continue( + f"Failed to read restart data for sample {sample_idx_1based} - {type(e).__name__}: {e}" + ) + return None, None + + return wells_data, rst_data + + def _read_and_interpolate_time_series( + self, + reader, + time_series_vars, + restart_times, + sample_idx_1based, + wells_data=None, + ): + """Read summary data and interpolate to match restart timesteps. 
+ + Parameters + reader: EclReader instance + time_series_vars: List of time series variable names (e.g., ["WWIR", "WGIR", "WBHP"]) + restart_times: Array of restart file timesteps (in days) + sample_idx_1based: Sample index for error messages + wells_data: List of well dictionaries (one per timestep) for checking well status + + Returns + dict: Interpolated time series data structured as: + {well_name: {var_name: [val_t0, val_t1, ..., val_tn]}} + """ + if not time_series_vars or len(restart_times) == 0: + return {} + + try: + # Read summary data for all entities (wells) + smry_data = reader.read_smry(keys=time_series_vars, entities=None) + + if "TIME" not in smry_data: + logger.warning( + f"No TIME data in summary file for sample {sample_idx_1based}" + ) + return {} + + smry_times = smry_data["TIME"] + + # Check if restart times are within summary time range (only warn once) + if restart_times[0] < smry_times[0] or restart_times[-1] > smry_times[-1]: + # Store the warning info but don't print yet (will be collected and summarized) + if not hasattr(self, "_time_range_warnings"): + self._time_range_warnings = [] + self._time_range_warnings.append( + { + "sample": sample_idx_1based, + "restart_range": (restart_times[0], restart_times[-1]), + "summary_range": (smry_times[0], smry_times[-1]), + } + ) + + # Interpolate time series data for each well and variable + interpolated_data = {} + + for entity, entity_data in smry_data.items(): + if entity == "TIME": + continue + + if not isinstance(entity_data, dict): + continue + + interpolated_data[entity] = {} + + for var_name, var_values in entity_data.items(): + if var_name not in time_series_vars: + continue + + # Create interpolation function + # Use linear interpolation with bounds_error=False to extrapolate if needed + interp_func = interp1d( + smry_times, + var_values, + kind="linear", + bounds_error=False, + fill_value=( + var_values[0], + var_values[-1], + ), # Use edge values for extrapolation + ) + + # Interpolate to restart timesteps + interpolated_values = interp_func(restart_times) + + # Check well status for times before summary start + # Use 0.0 for wells that are SHUT or don't exist yet (not drilled) + if wells_data is not None and restart_times[0] < smry_times[0]: + for t_idx, restart_time in enumerate(restart_times): + # Only check times before summary data starts + if restart_time >= smry_times[0]: + break + + # Check if well exists and is open at this timestep + if t_idx < len(wells_data): + if entity not in wells_data[t_idx]: + # Well doesn't exist yet (not drilled/completed), use 0.0 + interpolated_values[t_idx] = 0.0 + else: + well = wells_data[t_idx][entity] + if ( + hasattr(well, "status") + and well.status == "SHUT" + ): + # Well exists but is shut, use 0.0 instead of first summary value + interpolated_values[t_idx] = 0.0 + # else: well is OPEN, keep interpolated value (first summary value) + + interpolated_data[entity][var_name] = interpolated_values + + return interpolated_data + + except Exception as e: + logger.warning( + f"Failed to read/interpolate time series for sample {sample_idx_1based}: {type(e).__name__}: {e}" + ) + return {} + + def _create_time_series_arrays( + self, grid, wells, time_series_data, time_series_vars, timestep_idx + ): + """Create time series arrays mapped to grid cells with well completions. + + Similar to completion ID arrays, this creates one array per time series variable, + where values are assigned to grid cells containing well completions. 
+ + For pressure variables (e.g., WBHP, WTHP - anything with BHP or THP in name), + creates separate channels for injection and production wells + (e.g., WBHP_INJ, WBHP_PRD). + + Parameters + grid: Grid object + wells: Dictionary of Well objects for current timestep + time_series_data: Interpolated time series data + {well_name: {var_name: [val_t0, val_t1, ...]}} + time_series_vars: List of time series variable names + timestep_idx: Current timestep index + + Returns + dict: {var_name: np.ndarray} where arrays have shape (n_active_cells,) + For pressure variables, returns {var_name_INJ: array, var_name_PRD: array} + """ + if not time_series_data or not wells: + return {} + + result = {} + + for var_name in time_series_vars: + # Check if this is a pressure-related variable (bottom-hole pressure or tubing head pressure) + # These should be split into injection and production channels + var_upper = var_name.upper() + is_pressure_var = ("BHP" in var_upper) or ("THP" in var_upper) + + if is_pressure_var: + # Create separate arrays for injection and production wells + var_array_inj = np.zeros(grid.nact, dtype=np.float32) + var_array_prd = np.zeros(grid.nact, dtype=np.float32) + else: + # Single array for non-pressure variables + var_array = np.zeros(grid.nact, dtype=np.float32) + + # Iterate through wells and assign values to completion cells + for well_name, well in wells.items(): + if well_name not in time_series_data: + continue + + if var_name not in time_series_data[well_name]: + continue + + # Skip shut wells - assign zeros (default array value) + if well.status == "SHUT": + continue + + # Get interpolated value at this timestep + var_values = time_series_data[well_name][var_name] + if timestep_idx >= len(var_values): + continue + + value = var_values[timestep_idx] + + # Determine well type from Well object + # well.type is "INJ" or "PRD" (set by _set_type method) + is_injector = well.type == "INJ" + + # Assign value to all completion cells for this well + for comp in well.completions: + # Check if IJK attribute exists and is valid + if hasattr(comp, "IJK") and comp.IJK is not None: + # Convert from 1-based to 0-based indexing, then map to active-only index + # (same logic as WCID in grid.create_completion_array) + ijk = comp.IJK - 1 # Convert to 0-based + if ijk in grid.ijk_to_active: + active_idx = grid.ijk_to_active[ijk] + if is_pressure_var: + # Assign to injection or production array based on well type + if is_injector: + var_array_inj[active_idx] = value + else: + var_array_prd[active_idx] = value + else: + # Single array for non-pressure variables + var_array[active_idx] = value + + # Store results + if is_pressure_var: + result[f"{var_name}_INJ"] = var_array_inj + result[f"{var_name}_PRD"] = var_array_prd + else: + result[var_name] = var_array + + return result + + def _validate_timesteps(self, combined_data, dynamic_variables, sample_idx_1based): + """Validate timesteps and return valid ones.""" + times = combined_data.get("TIME") or [] + if len(times) < 2: + self._log_error_and_continue( + f"Sample {sample_idx_1based} has only {len(times)} timestep(s), need at least 2 for current->target prediction" + ) + return [] + + # Find valid timesteps + valid_timesteps = [] + max_t = len(times) - 1 # because we use t+1 as target + + for t in range(max_t): + if self._is_timestep_valid(combined_data, dynamic_variables, t): + valid_timesteps.append(t) + + if not valid_timesteps: + self._log_error_and_continue( + f"No valid timesteps found for sample {sample_idx_1based} (simulation may have 
died early)" + ) + return [] + + return valid_timesteps + + def _is_timestep_valid(self, combined_data, dynamic_variables, t): + """Check if a timestep has all required data.""" + # Check inputs (t) + for var in dynamic_variables: + if var not in combined_data or len(combined_data[var]) <= t: + return False + + # Check targets (t+1) + for var in self.output_vars: + if var not in combined_data or len(combined_data[var]) <= (t + 1): + return False + + return True + + def _log_error_and_continue(self, message, context=""): + """Log error message and return indication to continue.""" + logger.error(f"{context}{message}") + logger.info("Skipping and continuing with the next one...") + return True + + def _build_graphs_for_sample( + self, + grid, + wells_data, + combined_data, + valid_timesteps, + sample_idx_1based, + case_name, + time_series_data=None, + ): + """Build graphs for all valid timesteps of a sample.""" + graphs = [] + total_possible = len(combined_data.get("TIME", [])) - 1 + + for t in valid_timesteps: + try: + graph = self.build_graph_from_simulation_data( + grid, + wells_data, + combined_data, + sample_idx=sample_idx_1based - 1, + timestep_idx=t, + case_name=case_name, + time_series_data=time_series_data, + ) + if graph is None: + # Only log as error if timestep should have had enough history + if t + 1 >= self.prev_timestep_idx: + self._log_error_and_continue( + f"Graph creation returned None for timestep {t} (missing data or simulation died)", + " ", + ) + # Otherwise skip silently (expected for early timesteps) + continue + + graphs.append(graph) + + except Exception as e: + self._log_error_and_continue( + f"Graph creation failed for timestep {t} - {type(e).__name__}: {e}", + " ", + ) + continue + + return graphs + + def _process_single_sample_worker( + self, + file_path, + sample_idx_1based, + total_samples, + egrid_keys_geometry, + egrid_keys_nnc, + init_keys, + dynamic_variables, + rst_well_keys, + include_well_completions, + include_well_completion_cf, + time_series_vars, + include_time_series, + output_path_graph, + progress_counter, + ): + """ + Process a single sample, save graphs immediately, and return filenames. + + This method is designed to be called in parallel by multiprocessing workers. + Returns (saved_filenames, case_name, error_msg) tuple. 
+ """ + case_name = os.path.splitext(os.path.basename(file_path))[0] + + try: + reader = EclReader(file_path) + + # Read EGRID data (each worker reads independently - simpler than sharing) + egrid_data_geometry = reader.read_egrid(egrid_keys_geometry) + egrid_data_nnc = reader.read_egrid(egrid_keys_nnc) + egrid_data = {**egrid_data_geometry, **egrid_data_nnc} + + # Process static grid data + init_data, grid = self._process_static_data( + reader, init_keys, egrid_data, sample_idx_1based + ) + + # Process dynamic data + wells_data, restart_data = self._process_dynamic_data( + reader, + grid, + dynamic_variables, + rst_well_keys, + include_well_completions, + include_well_completion_cf, + sample_idx_1based, + ) + + if wells_data is None: # Error occurred + return None, case_name, "Error in processing dynamic data" + + # Combine static + dynamic data + combined_data = {**init_data, **restart_data} + + # Read and interpolate time series data if needed + time_series_data = None + if include_time_series and time_series_vars: + restart_times = np.array(combined_data.get("TIME", [])) + if len(restart_times) > 0: + time_series_data = self._read_and_interpolate_time_series( + reader, + time_series_vars, + restart_times, + sample_idx_1based, + wells_data, + ) + + # Validate timesteps + valid_timesteps = self._validate_timesteps( + combined_data, dynamic_variables, sample_idx_1based + ) + + if not valid_timesteps: + return None, case_name, "No valid timesteps found" + + # Build graphs for all valid timesteps + sample_graphs = self._build_graphs_for_sample( + grid, + wells_data, + combined_data, + valid_timesteps, + sample_idx_1based, + case_name, + time_series_data=time_series_data, + ) + + # Save graphs immediately (memory efficient) + saved_filenames = [] + for graph in sample_graphs: + timestep_id = getattr(graph, "timestep_id", 0) + filename = f"{case_name}_{timestep_id:03d}.pt" + graph_path = os.path.join(output_path_graph, filename) + torch.save(graph, graph_path) + saved_filenames.append(filename) + + # Increment progress counter (Manager.Value is automatically thread-safe) + progress_counter.value += 1 + + return saved_filenames, case_name, None + + except Exception as e: + error_msg = f"{type(e).__name__}: {e}" + # Increment progress counter even on error + progress_counter.value += 1 + return None, case_name, error_msg + + def _parse_results_from_samples(self) -> dict: + """ + Parse simulation results and create graphs from all samples. + + This is the main orchestration method that: + 1. Finds and validates input files + 2. Processes each sample's static and dynamic data + 3. Validates timesteps and builds graphs + 4. Returns all successfully created graphs + """ + # === INITIALIZATION === + sim_input_files = self._find_primary_input_files() + if not sim_input_files: + file_type = ".AFI" if self.simulator == "IX" else ".DATA" + raise RuntimeError( + f"No {file_type} files found in {self.sim_dir}. Check the path and file naming." 
+ ) + + # Prepare data keys and variables for reading simulation files + init_keys, egrid_keys, rst_well_keys = self._prepare_data_keys() + egrid_keys_geometry, egrid_keys_nnc = egrid_keys + + ( + dynamic_variables, + include_well_completion_ids, + include_well_completion_cf, + include_well_completions, + time_series_vars, + include_time_series, + ) = self._prepare_dynamic_variables() + + all_graph_files = [] + failed_sample_count = 0 + + # Limit samples if specified + if self.num_samples is not None: + sim_input_files = sim_input_files[: self.num_samples] + + total_samples = len(sim_input_files) + + # Determine number of workers + n_workers = min(self.num_preprocess_workers, cpu_count(), total_samples) + + logger.info( + f"Processing {total_samples} simulation results using {n_workers} parallel workers..." + ) + + start_time = time.time() + + with Manager() as manager: + # Create shared progress counter + progress_counter = manager.Value("i", 0) + + # Prepare arguments for parallel processing + worker_args = [ + ( + file_path, + sample_idx_1based, + total_samples, + egrid_keys_geometry, + egrid_keys_nnc, + init_keys, + dynamic_variables, + rst_well_keys, + include_well_completions, + include_well_completion_cf, + time_series_vars, + include_time_series, + self._output_path_graph, + progress_counter, + ) + for sample_idx_1based, file_path in enumerate(sim_input_files, start=1) + ] + + with Pool(processes=n_workers) as pool: + # Start async processing + async_result = pool.starmap_async( + self._process_single_sample_worker, worker_args + ) + + # Print progress every 30 seconds + while not async_result.ready(): + async_result.wait(timeout=30.0) + if not async_result.ready(): + completed = progress_counter.value + elapsed = time.time() - start_time + logger.info( + f"... {completed}/{total_samples} samples completed (elapsed: {elapsed:.0f}s) ..." + ) + + results = async_result.get() + + elapsed = time.time() - start_time + logger.info( + f"Completed in {elapsed:.1f}s ({total_samples / elapsed:.1f} samples/s)" + ) + + # Collect results and handle errors + logger.info("Collecting results...") + for sample_idx, (saved_filenames, case_name, error_msg) in enumerate( + results, start=1 + ): + if saved_filenames is None: + failed_sample_count += 1 + self._log_error_and_continue( + f"Processing sample {sample_idx} ({case_name}): {error_msg}" + ) + if failed_sample_count > 0.2 * total_samples: + raise RuntimeError("Failed to process too many samples.") + else: + all_graph_files.extend(saved_filenames) + + # === FINAL SUMMARY === + total_samples = len(sim_input_files) + avg_graphs_per_sample = ( + (len(all_graph_files) / total_samples) if total_samples else 0.0 + ) + logger.info("Processing Summary:") + logger.info(f" Total samples processed: {total_samples}") + logger.info(f" Total graphs created: {len(all_graph_files)}") + logger.info(f" Average graphs per sample: {avg_graphs_per_sample:.1f}") + logger.info(f" Graphs saved to: {self._output_path_graph}") + + return all_graph_files + + def _completion_to_dict(self, completion): + """ + Convert a Completion object to a JSON-serializable dictionary. 
+ + Parameters + completion: Completion object + + Returns + dict: Dictionary representation of the completion + """ + comp_dict = { + "I": completion.I, + "J": completion.J, + "K": completion.K, + "dir": completion.dir, + "status": completion.status, + "connection_factor": completion.connection_factor, + } + + # Add optional attributes if they exist + if hasattr(completion, "IJK"): + comp_dict["IJK"] = ( + int(completion.IJK) if completion.IJK is not None else None + ) + if hasattr(completion, "flow_rate"): + comp_dict["flow_rate"] = float(completion.flow_rate) + + return comp_dict + + def _well_to_dict(self, well): + """ + Convert a Well object to a JSON-serializable dictionary. + + Parameters + well: Well object + + Returns + dict: Dictionary representation of the well + """ + return { + "name": well.name, + "type": well.type, + "status": well.status, + "num_active_completions": well.num_active_completions, + "completions": [ + self._completion_to_dict(comp) for comp in well.completions + ], + } + + def _save_well_list_json(self, wells_data): + """ + Save wells_data (list of lists of dicts of Well objects) to a JSON file. + + Parameters + wells_data: List of lists of dictionaries of Well objects + """ + # Convert Wells objects to JSON-serializable format + json_data = [] + for timestep_wells in wells_data: + if isinstance(timestep_wells, dict): + # Dictionary of well_name: Well object + timestep_dict = { + well_name: self._well_to_dict(well) + for well_name, well in timestep_wells.items() + } + elif isinstance(timestep_wells, list): + # List is empty or contains Wells + if len(timestep_wells) == 0: + timestep_dict = {} + else: + timestep_dict = [ + self._well_to_dict(well) for well in timestep_wells + ] + else: + timestep_dict = {} + + json_data.append(timestep_dict) + + with open(self._output_path_well, "w") as f: + json.dump(json_data, f, indent=2) + logger.info(f"Well data saved to {self._output_path_well}") + + def execute(self): + # Process samples and save graphs (returns filenames) + generated_files = self._parse_results_from_samples() + + logger.info(f"Processed {len(generated_files)} graphs successfully!") + + return generated_files diff --git a/examples/reservoir_simulation/xmgn/src/inference.py b/examples/reservoir_simulation/xmgn/src/inference.py new file mode 100644 index 0000000000..3db3600903 --- /dev/null +++ b/examples/reservoir_simulation/xmgn/src/inference.py @@ -0,0 +1,1106 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Inference script for XMeshGraphNet on reservoir simulation data. +Loads the best checkpoint and performs autoregressive inference on test samples. +Generates GRDECL files with predictions for post-processing. 
+""" + +import os +import sys +import json +import glob +from datetime import datetime, timezone + +# Add repository root to Python path for sim_utils import +current_dir = os.path.dirname(os.path.abspath(__file__)) # This is src/ +repo_root = os.path.dirname(os.path.dirname(current_dir)) # Go up two levels from src/ +if repo_root not in sys.path: + sys.path.insert(0, repo_root) + +import torch +import torch.nn as nn +import numpy as np +import h5py +import hydra +from omegaconf import DictConfig + +from physicsnemo.distributed import DistributedManager +from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper + +from physicsnemo.models.meshgraphnet import MeshGraphNet +from physicsnemo.launch.utils import load_checkpoint +from data.dataloader import GraphDataset, load_stats, find_pt_files +from sim_utils import EclReader, Grid +from utils import get_dataset_paths, fix_layernorm_compatibility + +# Fix LayerNorm compatibility issue +fix_layernorm_compatibility() + + +def InitializeLoggers(cfg: DictConfig): + """Initialize distributed manager and loggers for inference.""" + DistributedManager.initialize() + dist = DistributedManager() + logger = PythonLogger(name="xmgn_inference") + + logger.info("XMeshGraphNet - Autoregressive Inference for Reservoir Simulation") + + return dist, RankZeroLoggingWrapper(logger, dist) + + +class InferenceRunner: + """Inference runner for XMeshGraphNet.""" + + def __init__(self, cfg: DictConfig, dist, logger): + """Initialize the inference runner.""" + self.cfg = cfg + self.dist = dist + self.logger = logger + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Set up paths with job name + paths = get_dataset_paths(cfg) + self.dataset_dir = paths["dataset_dir"] + self.stats_file = paths["stats_file"] + self.test_partitions_path = paths["test_partitions_path"] + + # Set up inference output directory + self.inference_output_dir = "inference" + self.inference_metadata_file = os.path.join( + self.inference_output_dir, "inference_metadata.json" + ) + + # Load global statistics + self.stats = load_stats(self.stats_file) + + # Get model dimensions + input_dim_nodes = len(self.stats["node_features"]["mean"]) + input_dim_edges = len(self.stats["edge_features"]["mean"]) + output_dim = len(cfg.dataset.graph.target_vars.node_features) + + # Initialize model + self.model = MeshGraphNet( + input_dim_nodes=input_dim_nodes, + input_dim_edges=input_dim_edges, + output_dim=output_dim, + processor_size=cfg.model.num_message_passing_layers, + aggregation="sum", + hidden_dim_node_encoder=cfg.model.hidden_dim, + hidden_dim_edge_encoder=cfg.model.hidden_dim, + hidden_dim_node_decoder=cfg.model.hidden_dim, + mlp_activation_fn=cfg.model.activation, + do_concat_trick=cfg.performance.use_concat_trick, + ).to(self.device) + + # Load best checkpoint using PhysicsNeMo's load_checkpoint (same as training) + # Set up checkpoint arguments (same as training) + base_output_dir = os.getcwd() + best_checkpoint_dir = os.path.join(base_output_dir, "best_checkpoints") + + # Set up checkpoint arguments (following training pattern - same as bst_ckpt_args) + ckpt_args = { + "path": best_checkpoint_dir, + "models": self.model, + } + + # Check for explicit checkpoint paths in config + explicit_checkpoint = getattr(cfg.inference, "checkpoint_path", None) + explicit_model = getattr(cfg.inference, "model_path", None) + + if explicit_checkpoint or explicit_model: + # Use explicit checkpoint/model paths + if explicit_checkpoint: + self.logger.info(f"Using 
explicit checkpoint: {explicit_checkpoint}") + checkpoint = torch.load(explicit_checkpoint, map_location=self.device) + + # Load model state + if "models" in checkpoint: + self.model.load_state_dict(checkpoint["models"]) + else: + self.model.load_state_dict(checkpoint) + + # Extract epoch from filename if possible + filename = os.path.basename(explicit_checkpoint) + try: + parts = filename.split(".") + if len(parts) >= 3: + loaded_epoch = int(parts[2]) + else: + loaded_epoch = 0 + except Exception: + loaded_epoch = 0 + + self.logger.info( + f"Loaded explicit checkpoint from epoch {loaded_epoch}" + ) + + elif explicit_model: + self.logger.info(f"Using explicit model: {explicit_model}") + # For .mdlus files, we need to use PhysicsNeMo's load_checkpoint + model_ckpt_args = { + "path": os.path.dirname(explicit_model), + "models": self.model, + } + loaded_epoch = load_checkpoint(**model_ckpt_args, device=self.device) + self.logger.info(f"Loaded explicit model from epoch {loaded_epoch}") + else: + # Use automatic best checkpoint selection + self.logger.info("Using automatic best checkpoint selection") + + # Check for multiple checkpoint files and log them + if os.path.exists(best_checkpoint_dir): + checkpoint_files = [ + f for f in os.listdir(best_checkpoint_dir) if f.endswith(".pt") + ] + if len(checkpoint_files) > 1: + self.logger.info( + f"Found {len(checkpoint_files)} checkpoint files in best_checkpoints:" + ) + for file in sorted(checkpoint_files): + self.logger.info(f" - {file}") + self.logger.info( + "PhysicsNeMo will automatically select the best performing checkpoint" + ) + + # Load checkpoint using PhysicsNeMo's system + loaded_epoch = load_checkpoint(**ckpt_args, device=self.device) + self.logger.info(f"Loaded BEST checkpoint from epoch {loaded_epoch}") + + self.model.eval() + self.logger.info(f"Checkpoint directory: {best_checkpoint_dir}") + + # Create test dataset (following training pattern) + # Find partition files + file_paths = find_pt_files(self.test_partitions_path) + + # Load per-feature statistics + node_mean = torch.tensor(self.stats["node_features"]["mean"]) + node_std = torch.tensor(self.stats["node_features"]["std"]) + edge_mean = torch.tensor(self.stats["edge_features"]["mean"]) + edge_std = torch.tensor(self.stats["edge_features"]["std"]) + + # Load target feature statistics (if available) + target_mean = None + target_std = None + if "target_features" in self.stats: + target_mean = torch.tensor(self.stats["target_features"]["mean"]) + target_std = torch.tensor(self.stats["target_features"]["std"]) + + # Create dataset + self.test_dataset = GraphDataset( + file_paths, + node_mean, + node_std, + edge_mean, + edge_std, + target_mean, + target_std, + ) + + self.logger.info(f"Test dataset loaded with {len(self.test_dataset)} samples") + + def denormalize_predictions(self, pred): + """Denormalize predictions using global statistics.""" + target_mean = torch.tensor( + self.stats["target_features"]["mean"], device=self.device + ) + target_std = torch.tensor( + self.stats["target_features"]["std"], device=self.device + ) + return pred * target_std + target_mean + + def denormalize_targets(self, target): + """Denormalize targets using global statistics.""" + target_mean = torch.tensor( + self.stats["target_features"]["mean"], device=self.device + ) + target_std = torch.tensor( + self.stats["target_features"]["std"], device=self.device + ) + return target * target_std + target_mean + + def _get_target_feature_indices(self): + """ + Get the indices in node features that 
correspond to target variables. + These are the features we need to replace with predictions during autoregressive inference. + """ + # Get target variable names + target_vars = self.cfg.dataset.graph.target_vars.node_features + + # Get dynamic variable names from config + dynamic_vars = self.cfg.dataset.graph.node_features.dynamic.variables + + # Find indices of target variables in dynamic variables + target_indices = [] + for target_var in target_vars: + if target_var in dynamic_vars: + idx = dynamic_vars.index(target_var) + target_indices.append(idx) + + return target_indices + + def _update_node_features_with_predictions( + self, partitions_list, predictions_normalized + ): + """ + Update node features in partitions with predictions from previous timestep. + Replace only the features that correspond to target variables. + + Parameters + partitions_list: List of graph partitions + predictions_normalized: Normalized predictions from previous timestep (list of arrays per partition) + + Returns + Updated partitions_list with predictions in node features + """ + target_indices = self._get_target_feature_indices() + + # Get the number of dynamic variables to know the offset in node features + num_static_features = len(self.cfg.dataset.graph.node_features.static) + num_dynamic_features = len( + self.cfg.dataset.graph.node_features.dynamic.variables + ) + prev_timesteps = self.cfg.dataset.graph.node_features.dynamic.prev_timesteps + + # Dynamic features start after static features + # For prev_timesteps=0: dynamic features are at indices [num_static: num_static+num_dynamic] + # For prev_timesteps>0: current timestep is at the end of dynamic features + + if prev_timesteps == 0: + # Current timestep dynamic features start at num_static_features + dynamic_offset = num_static_features + else: + # Current timestep is at the last block of dynamic features + dynamic_offset = num_static_features + prev_timesteps * num_dynamic_features + + # Update each partition + updated_partitions = [] + for partition, pred_array in zip(partitions_list, predictions_normalized): + # Clone the partition to avoid modifying the original + # PyTorch Geometric Data objects need special handling + if hasattr(partition, "clone"): + partition = partition.clone() + + # Clone the node features tensor + partition.x = partition.x.clone() + + # Convert prediction array to tensor if needed + if isinstance(pred_array, np.ndarray): + pred_tensor = torch.tensor( + pred_array, dtype=torch.float32, device=partition.x.device + ) + else: + pred_tensor = ( + pred_array.clone() if hasattr(pred_array, "clone") else pred_array + ) + + # Replace target features in node features with predictions + # Note: predictions are only for inner nodes (excluding halo nodes) + for i, target_idx in enumerate(target_indices): + feature_idx = dynamic_offset + target_idx + if hasattr(partition, "inner_node"): + # Update only inner nodes (predictions don't include halo nodes) + partition.x[partition.inner_node, feature_idx] = pred_tensor[:, i] + else: + # No halo nodes, update all nodes + partition.x[:, feature_idx] = pred_tensor[:, i] + + updated_partitions.append(partition) + + return updated_partitions + + def evaluate_sample( + self, + partitions_list, + use_predictions_as_input=False, + prev_predictions_normalized=None, + ): + """ + Evaluate a single sample (list of partitions). 
+ + Parameters + partitions_list: List of graph partitions for this timestep + use_predictions_as_input: If True, replace target features with predictions from previous timestep + prev_predictions_normalized: Normalized predictions from previous timestep (for autoregressive inference) + + Returns + avg_loss, avg_denorm_loss, predictions, targets, predictions_normalized + """ + total_loss = 0.0 + total_denorm_loss = 0.0 + num_partitions = 0 + + predictions = [] + targets = [] + predictions_normalized = [] # Store normalized predictions for next timestep + + with torch.no_grad(): + # If using autoregressive mode, update node features with previous predictions + if use_predictions_as_input and prev_predictions_normalized is not None: + partitions_list = self._update_node_features_with_predictions( + partitions_list, prev_predictions_normalized + ) + + for partition in partitions_list: + partition = partition.to(self.device) + + # Ensure data is in float32 + if hasattr(partition, "x") and partition.x is not None: + partition.x = partition.x.float() + if hasattr(partition, "edge_attr") and partition.edge_attr is not None: + partition.edge_attr = partition.edge_attr.float() + if hasattr(partition, "y") and partition.y is not None: + partition.y = partition.y.float() + + # Forward pass + pred = self.model(partition.x, partition.edge_attr, partition) + + # Get inner nodes if available + if hasattr(partition, "inner_node"): + pred_inner = pred[partition.inner_node] + target_inner = partition.y[partition.inner_node] + else: + pred_inner = pred + target_inner = partition.y + + # Calculate losses + loss = torch.nn.functional.mse_loss(pred_inner, target_inner) + + # Denormalize for physical units + pred_denorm = self.denormalize_predictions(pred_inner) + target_denorm = self.denormalize_targets(target_inner) + denorm_loss = torch.nn.functional.mse_loss(pred_denorm, target_denorm) + + total_loss += loss.item() + total_denorm_loss += denorm_loss.item() + num_partitions += 1 + + # Store predictions and targets + predictions.append(pred_denorm.cpu().numpy()) + targets.append(target_denorm.cpu().numpy()) + + # Store normalized predictions for next timestep's input + predictions_normalized.append(pred_inner.cpu().numpy()) + + avg_loss = total_loss / num_partitions if num_partitions > 0 else 0.0 + avg_denorm_loss = ( + total_denorm_loss / num_partitions if num_partitions > 0 else 0.0 + ) + + return avg_loss, avg_denorm_loss, predictions, targets, predictions_normalized + + def _extract_case_and_timestep(self, filename): + """Extract case name and time step from filename.""" + + if filename.startswith("partitions_"): + # Remove 'partitions_' prefix + filename = filename[11:] # Remove 'partitions_' + + # Remove .pt extension + filename = filename.replace(".pt", "") + + # Split by underscore and extract case name and time step + parts = filename.split("_") + if len(parts) >= 4: + # Format: CASE_2D_1_000 + case_name = "_".join(parts[:-1]) # CASE_2D_1 + timestep = parts[-1] # 000 + else: + # Fallback + case_name = filename + timestep = "000" + + return case_name, timestep + + def run_inference(self): + """Run autoregressive inference on test dataset.""" + self.logger.info("=" * 70) + self.logger.info("STARTING INFERENCE") + self.logger.info("=" * 70) + + # Get prev_timesteps config for determining initial conditions + prev_timesteps = self.cfg.dataset.graph.node_features.dynamic.prev_timesteps + num_initial_true_timesteps = ( + prev_timesteps + 1 + ) # Initial + prev_timesteps as true inputs + + self.logger.info( 
+ f"Initial timesteps with true features: {num_initial_true_timesteps}" + ) + self.logger.info( + f"Subsequent timesteps: predictions feed into next timestep (autoregressive)" + ) + + # First, organize all samples by case and timestep + case_timestep_data = {} + for idx in range(len(self.test_dataset)): + file_path = self.test_dataset.file_paths[idx] + filename = os.path.basename(file_path) + case_name, timestep = self._extract_case_and_timestep(filename) + + if case_name not in case_timestep_data: + case_timestep_data[case_name] = {} + + case_timestep_data[case_name][timestep] = idx + + # Now process each case autoregressively + total_loss = 0.0 + total_denorm_loss = 0.0 + num_samples = 0 + case_results = {} + + all_cases = sorted(case_timestep_data.keys()) + total_cases = len(all_cases) + self.logger.info(f"Processing {total_cases} cases...") + + for case_idx, case_name in enumerate(all_cases, 1): + case_results[case_name] = { + "predictions": {}, + "targets": {}, + "losses": [], + "denorm_losses": [], + } + + # Get sorted timesteps for this case + timesteps = sorted(case_timestep_data[case_name].keys()) + + self.logger.info( + f"[{case_idx}/{total_cases}] Processing case: {case_name} ({len(timesteps)} timesteps)" + ) + + # Track predictions from previous timestep (normalized) + prev_predictions_normalized = None + + # Process each timestep in order + for timestep_idx, timestep in enumerate(timesteps): + idx = case_timestep_data[case_name][timestep] + partitions_list, label = self.test_dataset[idx] + + # Determine if we should use predictions as input + # Use true features for first num_initial_true_timesteps, then use predictions + use_predictions_as_input = timestep_idx >= num_initial_true_timesteps + + # Evaluate this timestep + loss, denorm_loss, predictions, targets, predictions_normalized = ( + self.evaluate_sample( + partitions_list, + use_predictions_as_input=use_predictions_as_input, + prev_predictions_normalized=prev_predictions_normalized, + ) + ) + + total_loss += loss + total_denorm_loss += denorm_loss + num_samples += 1 + + # Store results + case_results[case_name]["predictions"][timestep] = predictions + case_results[case_name]["targets"][timestep] = targets + case_results[case_name]["losses"].append(loss) + case_results[case_name]["denorm_losses"].append(denorm_loss) + + # Store predictions for next timestep + prev_predictions_normalized = predictions_normalized + + # Calculate final metrics + avg_loss = total_loss / num_samples + avg_denorm_loss = total_denorm_loss / num_samples + + # Save results per simulation case as HDF5 files + self._save_case_results_hdf5(case_results) + + # Calculate overall metrics for logging + all_predictions = [] + all_targets = [] + for case_data in case_results.values(): + for timestep_preds in case_data["predictions"].values(): + all_predictions.extend(timestep_preds) + for timestep_targets in case_data["targets"].values(): + all_targets.extend(timestep_targets) + + all_predictions = np.concatenate(all_predictions, axis=0) + all_targets = np.concatenate(all_targets, axis=0) + + # Calculate additional metrics + mae = np.mean(np.abs(all_predictions - all_targets)) + mse = np.mean((all_predictions - all_targets) ** 2) + rmse = np.sqrt(mse) + + # Log final results + self.logger.info("") + self.logger.info("=" * 70) + self.logger.info("AUTOREGRESSIVE INFERENCE RESULTS") + self.logger.info("=" * 70) + self.logger.info(f"Test samples processed: {num_samples}") + self.logger.info(f"Simulation cases: {len(case_results)}") + 
self.logger.info(f"Average normalized MSE: {avg_loss:.6e}") + self.logger.info(f"Average denormalized MSE: {avg_denorm_loss:.6e}") + self.logger.info(f"Overall MAE: {mae:.6e}") + self.logger.info(f"Overall RMSE: {rmse:.6e}") + self.logger.info("") + self.logger.info("Per-Variable Metrics:") + self.logger.info("-" * 70) + + # Per-variable metrics + target_names = self.cfg.dataset.graph.target_vars.node_features + for i, var_name in enumerate(target_names): + var_mae = np.mean(np.abs(all_predictions[:, i] - all_targets[:, i])) + var_rmse = np.sqrt( + np.mean((all_predictions[:, i] - all_targets[:, i]) ** 2) + ) + self.logger.info( + f" {var_name:>12s} | MAE: {var_mae:>12.6e} | RMSE: {var_rmse:>12.6e}" + ) + + self.logger.info("=" * 70) + + return { + "avg_loss": avg_loss, + "avg_denorm_loss": avg_denorm_loss, + "mae": mae, + "rmse": rmse, + "predictions": all_predictions, + "targets": all_targets, + "num_samples": num_samples, + "case_results": case_results, + } + + def _save_case_results_hdf5(self, case_results): + """Save inference results per simulation case as HDF5 files.""" + os.makedirs(self.inference_output_dir, exist_ok=True) + + self.logger.info("") + self.logger.info("Saving inference results to HDF5 files...") + + target_names = self.cfg.dataset.graph.target_vars.node_features + + for case_name, case_data in case_results.items(): + hdf5_file = os.path.join(self.inference_output_dir, f"{case_name}.hdf5") + + with h5py.File(hdf5_file, "w") as f: + # Create groups for predictions and targets + pred_group = f.create_group("predictions") + target_group = f.create_group("targets") + + # Save metadata + f.attrs["case_name"] = case_name + f.attrs["num_timesteps"] = len(case_data["predictions"]) + f.attrs["target_variables"] = [ + str(name) for name in target_names + ] # Convert to list of strings + f.attrs["avg_loss"] = np.mean(case_data["losses"]) + f.attrs["avg_denorm_loss"] = np.mean(case_data["denorm_losses"]) + + # Organize data by variable (PRESSURE, SWAT) with lists of vectors per timestep + for i, var_name in enumerate(target_names): + var_name_clean = var_name.upper() # Use capital case + + # Collect all timestep data for this variable + pred_vectors = [] + target_vectors = [] + timestep_numbers = [] # Track actual timestep numbers + + for timestep in sorted(case_data["predictions"].keys()): + predictions = case_data["predictions"][timestep] + targets = case_data["targets"][timestep] + + if predictions: + pred_array = np.concatenate(predictions, axis=0) + target_array = np.concatenate(targets, axis=0) + + # Extract this variable's data (column i) + pred_vectors.append(pred_array[:, i]) + target_vectors.append(target_array[:, i]) + # Store actual timestep number (predictions are FOR next timestep) + timestep_numbers.append(int(timestep)) + + # Save as variable groups with lists of vectors + if pred_vectors: + # Create variable groups + var_pred_group = pred_group.create_group(var_name_clean) + var_target_group = target_group.create_group(var_name_clean) + + # Save each timestep as a separate dataset within the variable group + for input_timestep, pred_vec, target_vec in zip( + timestep_numbers, pred_vectors, target_vectors + ): + # Predictions are FOR the next timestep after the input + predicted_timestep = input_timestep + 1 + var_pred_group.create_dataset( + f"timestep_{predicted_timestep:04d}", data=pred_vec + ) + var_target_group.create_dataset( + f"timestep_{predicted_timestep:04d}", data=target_vec + ) + + # Save metadata for this variable + 
var_pred_group.attrs["num_timesteps"] = len(pred_vectors) + var_pred_group.attrs["num_nodes"] = ( + len(pred_vectors[0]) if pred_vectors else 0 + ) + var_target_group.attrs["num_timesteps"] = len(target_vectors) + var_target_group.attrs["num_nodes"] = ( + len(target_vectors[0]) if target_vectors else 0 + ) + + # Save metadata file with list of HDF5 files + hdf5_files = [f"{case_name}.hdf5" for case_name in case_results.keys()] + metadata = { + "hdf5_files": hdf5_files, + "num_cases": len(case_results), + "target_variables": [str(name) for name in target_names], + "created_at": datetime.now(timezone.utc).isoformat(), + } + + with open(self.inference_metadata_file, "w") as f: + json.dump(metadata, f, indent=2) + + self.logger.info( + f"Saved {len(case_results)} HDF5 files to: {self.inference_output_dir}" + ) + self.logger.info(f"Saved metadata: {self.inference_metadata_file}") + + def _load_partition_and_halo_info(self, case_name): + """ + Load partition assignments and halo information for a given case. + First tries to read from partition .pt files (preferred), then falls back to JSON. + + Parameters + case_name: Name of the simulation case + + Returns + tuple: (partition_assignment, halo_info) + - partition_assignment: numpy array with partition IDs for active cells (1-indexed), or None + - halo_info: numpy array indicating which partition includes this cell as halo (0=none, partition_id if halo), or None + """ + # First, try to find partition .pt file (from first timestep) + partitions_dir = os.path.join(self.dataset_dir, "partitions") + + # Check in test/train/val subdirectories + for split in ["test", "train", "val"]: + split_dir = os.path.join(partitions_dir, split) + + # Find the first partition file for this case (any timestep) + partition_pt_file = None + if os.path.exists(split_dir): + pattern = os.path.join(split_dir, f"partitions_{case_name}_*.pt") + matching_files = sorted(glob.glob(pattern)) + + if matching_files: + partition_pt_file = matching_files[ + 0 + ] # Use first available timestep + + if partition_pt_file and os.path.exists(partition_pt_file): + try: + # Load partition data from .pt file + partitions = torch.load( + partition_pt_file, map_location="cpu", weights_only=False + ) + + num_partitions = len(partitions) + + # Build partition assignment and halo info + partition_assignments_dict = {} + halo_info_dict = {} # Which partition includes this cell as halo + + for part_idx, partition in enumerate(partitions): + if hasattr(partition, "part_node") and hasattr( + partition, "inner_node" + ): + # Get inner nodes (these belong to this partition) + inner_global_indices = ( + partition.part_node[partition.inner_node].cpu().numpy() + ) + for global_idx in inner_global_indices: + partition_assignments_dict[global_idx] = ( + part_idx + 1 + ) # 1-indexed + + # Get halo nodes (all nodes NOT in inner_node) + all_local_indices = torch.arange(partition.num_nodes) + halo_mask = torch.ones( + partition.num_nodes, dtype=torch.bool + ) + halo_mask[partition.inner_node] = False + halo_local_indices = all_local_indices[halo_mask] + halo_global_indices = ( + partition.part_node[halo_local_indices].cpu().numpy() + ) + + # Mark these cells as being halo in this partition + for global_idx in halo_global_indices: + halo_info_dict[global_idx] = ( + part_idx + 1 + ) # Which partition includes this as halo + + # Sort by node index and create assignment lists + # Include both inner nodes AND halo nodes + all_node_indices = set(partition_assignments_dict.keys()) | set( + halo_info_dict.keys() + ) 
+ sorted_indices = sorted(all_node_indices) + + partition_assignment = np.array( + [ + partition_assignments_dict.get(idx, 0) + for idx in sorted_indices + ], + dtype=int, + ) + + # Create halo info array (0 = not halo, partition_id = included as halo in that partition) + halo_info = np.array( + [halo_info_dict.get(idx, 0) for idx in sorted_indices], + dtype=int, + ) + + num_halo_cells = np.count_nonzero(halo_info) + + # Debug: check how many cells are halo-only vs inner+halo + num_inner_cells = np.count_nonzero(partition_assignment) + num_halo_only = np.sum( + (halo_info > 0) & (partition_assignment == 0) + ) + num_inner_and_halo = np.sum( + (halo_info > 0) & (partition_assignment > 0) + ) + + self.logger.info( + f"Loaded partition assignments from {split}/{os.path.basename(partition_pt_file)}: " + f"{num_partitions} partitions, {len(partition_assignment)} active cells" + ) + self.logger.info( + f" Inner cells: {num_inner_cells}, Halo-only: {num_halo_only}, Inner+Halo: {num_inner_and_halo}" + ) + + return partition_assignment, halo_info + + except Exception as e: + self.logger.warning( + f"Failed to load partitions from {partition_pt_file}: {e}" + ) + continue + + # Fall back to JSON file if .pt file not found + partition_json_file = os.path.join( + self.dataset_dir, f"{case_name}_partitions.json" + ) + + if os.path.exists(partition_json_file): + try: + with open(partition_json_file, "r") as f: + partition_data = json.load(f) + + partition_assignment = np.array( + partition_data["partition_assignment"], dtype=int + ) + + self.logger.info( + f"Loaded partition assignments from JSON: " + f"{partition_data['num_partitions']} partitions, " + f"{partition_data['num_nodes']} active cells " + f"(halo info not available from JSON)" + ) + + # JSON doesn't have halo info, return None for halo + return partition_assignment, None + + except Exception as e: + self.logger.warning( + f"Failed to load partition assignments from JSON: {e}" + ) + + # Neither .pt nor JSON found + self.logger.warning( + f"No partition data found for {case_name}. PARTITION block will be skipped." + ) + return None, None + + def _extract_coordinates_from_grid(self, sample_idx): + """Extract coordinates from grid files using the general Grid approach.""" + # Load dataset metadata from preprocessing + dataset_metadata_file = os.path.join(self.dataset_dir, "dataset_metadata.json") + if not os.path.exists(dataset_metadata_file): + raise FileNotFoundError( + f"Dataset metadata not found at {dataset_metadata_file}. Please run preprocessing first." 
+ ) + + with open(dataset_metadata_file, "r") as f: + dataset_metadata = json.load(f) + + # Get the case name from the HDF5 metadata + if not os.path.exists(self.inference_metadata_file): + raise FileNotFoundError( + f"No inference metadata found at {self.inference_metadata_file}" + ) + + with open(self.inference_metadata_file, "r") as f: + inference_metadata = json.load(f) + + hdf5_files = inference_metadata.get("hdf5_files", []) + if sample_idx >= len(hdf5_files): + raise IndexError( + f"Sample index {sample_idx} exceeds available cases ({len(hdf5_files)})" + ) + + # Extract case name from HDF5 filename (remove .hdf5 extension) + case_name = hdf5_files[sample_idx].replace(".hdf5", "") + + # Get the original sim_dir from dataset metadata + original_sim_dir = dataset_metadata.get("sim_dir") + if not original_sim_dir: + raise KeyError("sim_dir not found in dataset metadata") + + # Construct the path to the simulator data directory using the original path + data_file = os.path.join(original_sim_dir, f"{case_name}.DATA") + + if not os.path.exists(data_file): + raise FileNotFoundError(f"Simulator data file not found: {data_file}") + + # Create reader and read grid data + reader = EclReader(data_file) + + # Read grid data (COORD, ZCORN for coordinates) + egrid_keys = ["COORD", "ZCORN", "FILEHEAD", "NNC1", "NNC2"] + egrid_data = reader.read_egrid(egrid_keys) + + # Read init data for grid dimensions and porosity + init_keys = ["INTEHEAD", "PORV"] + init_data = reader.read_init(init_keys) + + # Create grid object to get coordinates (same as in reservoir_graph_builder.py) + grid = Grid(init_data, egrid_data) + X, Y, Z = grid.X, grid.Y, grid.Z + + # Get grid dimensions from the grid object + nx, ny, nz = grid.nx, grid.ny, grid.nz + nact = grid.nact # number of active cells + + self.logger.info(f"Extracted coordinates from grid for {case_name}:") + self.logger.info( + f" Grid dimensions: {nx} × {ny} × {nz} = {nx * ny * nz} total cells" + ) + self.logger.info( + f" Active cells: {nact} ({nact / (nx * ny * nz) * 100:.1f}% of total)" + ) + self.logger.info(f" Coordinates: {len(X)} active nodes") + + return X, Y, Z, (nx, ny, nz), nact, grid + + def run_post(self): + """ + Generate Eclipse-style GRDECL ASCII files from HDF5 inference results. + Each HDF5 file is converted to a GRDECL file with format: + KEY_ + + / + """ + self.logger.info("") + self.logger.info("=" * 70) + self.logger.info("POST-PROCESSING: GENERATING GRDECL FILES") + self.logger.info("=" * 70) + + # Load metadata to get HDF5 files + if not os.path.exists(self.inference_metadata_file): + self.logger.warning( + f"No inference metadata found at {self.inference_metadata_file}" + ) + return + + with open(self.inference_metadata_file, "r") as f: + metadata = json.load(f) + + hdf5_files = metadata.get("hdf5_files", []) + if not hdf5_files: + self.logger.warning("No HDF5 files found in metadata") + return + + # Output directory (directly under inference/) + grdecl_output_dir = self.inference_output_dir + os.makedirs(grdecl_output_dir, exist_ok=True) + + self.logger.info(f"Output directory: {grdecl_output_dir}") + self.logger.info(f"Processing {len(hdf5_files)} HDF5 file(s)...") + + # Process each HDF5 file + for sample_idx, hdf5_filename in enumerate(hdf5_files, 1): + hdf5_file = os.path.join(self.inference_output_dir, hdf5_filename) + case_name = os.path.basename(hdf5_filename).replace(".hdf5", "") + + self.logger.info( + f"[{sample_idx}/{len(hdf5_files)}] Generating GRDECL for {case_name}..." 
+ ) + + # Output GRDECL filename + grdecl_filename = f"{case_name}.GRDECL" + grdecl_filepath = os.path.join(grdecl_output_dir, grdecl_filename) + + try: + # Get grid information and actnum for this case + # Note: sample_idx is 1-based for display, convert to 0-based for indexing + X, Y, Z, grid_dims, nact, grid = self._extract_coordinates_from_grid( + sample_idx - 1 + ) + nx, ny, nz = grid_dims + total_cells = nx * ny * nz + actnum = grid.actnum_bool + + # Load partition assignments and halo info for this case + partition_data_active, halo_data_active = ( + self._load_partition_and_halo_info(case_name) + ) + + with h5py.File(hdf5_file, "r") as f: + with open(grdecl_filepath, "w") as grdecl_file: + # Write combined PARTITION block (if available) + # Positive values = partition ID (inner nodes that are NOT halo anywhere) + # Negative values = -partition_id for boundary nodes (cells that serve as halo) + if partition_data_active is not None: + # Start with partition assignments for active cells + combined_data_active = partition_data_active.copy() + + # Mark ALL halo cells with negative values (even if they're also inner somewhere) + if halo_data_active is not None: + # All cells where halo_data_active > 0 get marked as halo (negative) + halo_mask = halo_data_active > 0 + num_halo = np.sum(halo_mask) + num_halo_only = np.sum( + (halo_data_active > 0) + & (partition_data_active == 0) + ) + num_boundary = np.sum( + (halo_data_active > 0) & (partition_data_active > 0) + ) + + self.logger.info( + f" Total halo cells: {num_halo} (Halo-only: {num_halo_only}, Boundary: {num_boundary})" + ) + + # Mark halo cells with negative of their halo partition ID + # Use halo_data_active (not partition_data_active) to avoid -0 for halo-only cells + combined_data_active[halo_mask] = -halo_data_active[ + halo_mask + ] + + # Initialize full array with zeros (for inactive cells) + partition_data_full = np.zeros((total_cells,), dtype=int) + # Populate only active cells with combined partition/halo info + partition_data_full[actnum] = combined_data_active + + # Write PARTITION block + grdecl_file.write("PARTITION\n") + grdecl_file.write( + "-- Positive values: partition ID (inner nodes, not serving as halo)\n" + ) + grdecl_file.write( + "-- Negative values: -partition_id for boundary/halo nodes (e.g., -2 = owned by partition 2, serves as halo)\n" + ) + grdecl_file.write("-- Zero: inactive cells\n") + for i, value in enumerate(partition_data_full): + grdecl_file.write(f"{value} ") + if (i + 1) % 10 == 0: # 10 values per line for integers + grdecl_file.write("\n") + + # Ensure newline before '/' + if len(partition_data_full) % 10 != 0: + grdecl_file.write("\n") + + # Terminator + grdecl_file.write("/\n\n") + + # Get target variables + target_variables = f.attrs.get("target_variables", []) + + # Process each target variable + for var_name in target_variables: + var_name_clean = var_name.upper() + + if var_name_clean not in f["predictions"]: + self.logger.warning( + f"Variable {var_name_clean} not found in {hdf5_filename}" + ) + continue + + # Get all timesteps for this variable + timesteps = sorted(f["predictions"][var_name_clean].keys()) + + for timestep_key in timesteps: + # Extract timestep number from key (e.g., "timestep_001" -> 1) + timestep_num = int(timestep_key.split("_")[-1]) + + # Read prediction data (only active cells) + pred_data_active = f["predictions"][var_name_clean][ + timestep_key + ][:] + + # Initialize full array with zeros for all cells + pred_data_full = np.zeros((total_cells,)) + + # Populate 
only active cells with predicted values + pred_data_full[actnum] = pred_data_active + + # Write in Eclipse GRDECL format + # KEY_ with 4-digit formatting (e.g., 0001, 0010, 0120) + grdecl_file.write( + f"{var_name_clean}_{timestep_num:04d}\n" + ) + + # Values (write 5 values per line for readability) + for i, value in enumerate(pred_data_full): + grdecl_file.write(f"{value:.6e} ") + if (i + 1) % 5 == 0: + grdecl_file.write("\n") + + # Ensure newline before '/' + if len(pred_data_full) % 5 != 0: + grdecl_file.write("\n") + + # Terminator + grdecl_file.write("/\n") + + except Exception as e: + self.logger.error(f"Failed to process {hdf5_filename}: {e}") + continue + + self.logger.info("") + self.logger.info("=" * 70) + self.logger.info(f"GRDECL files saved to: {grdecl_output_dir}") + self.logger.info("=" * 70) + + +@hydra.main(version_base="1.3", config_path="../conf", config_name="config") +def main(cfg: DictConfig) -> None: + """ + Main inference entry point. + Performs autoregressive inference and generates GRDECL files. + """ + + dist, logger = InitializeLoggers(cfg) + + runner = InferenceRunner(cfg, dist, logger) + + runner.run_inference() + + runner.run_post() + + logger.success("Inference and post-processing completed successfully!") + + +if __name__ == "__main__": + main() diff --git a/examples/reservoir_simulation/xmgn/src/preprocessor.py b/examples/reservoir_simulation/xmgn/src/preprocessor.py new file mode 100644 index 0000000000..aa6aca8038 --- /dev/null +++ b/examples/reservoir_simulation/xmgn/src/preprocessor.py @@ -0,0 +1,997 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Preprocessing pipeline for reservoir simulation data. +Converts simulation output to partitioned graphs for XMeshGraphNet training. +Extracts grid properties, connections, and well data, computes global statistics, +and partitions graphs for efficient distributed training. 
+""" + +import os +import sys +import json +import random +import re +import shutil +import contextlib +import io +import warnings +import logging + +# Add src directory to Python path for flexible imports +current_dir = os.path.dirname(os.path.abspath(__file__)) # This is src/ +if current_dir not in sys.path: + sys.path.insert(0, current_dir) + +# Add repository root to Python path for sim_utils import +repo_root = os.path.dirname(os.path.dirname(current_dir)) # Go up two levels from src/ +if repo_root not in sys.path: + sys.path.insert(0, repo_root) + +import torch +import torch_geometric as pyg +from tqdm import tqdm +import hydra +from hydra.utils import to_absolute_path +from omegaconf import DictConfig + +from data.graph_builder import ReservoirGraphBuilder +from data.dataloader import PartitionedGraph, compute_global_statistics +from utils import get_dataset_dir + +logger = logging.getLogger(__name__) + + +class SimplePartition: + """Simple sequential partition as fallback when METIS is not available. + + Mimics the METIS partition structure for compatibility. + """ + + def __init__(self, num_nodes, num_parts): + """ + Initialize a simple sequential partition. + + Parameters + ----------- + num_nodes : int + Total number of nodes to partition + num_parts : int + Number of partitions to create + """ + self.node_perm = torch.arange(num_nodes) + + # Calculate partition boundaries + part_size = num_nodes // num_parts + remainder = num_nodes % num_parts + + # Create partition pointers + self.partptr = [0] + for i in range(num_parts): + # Add extra node to first 'remainder' partitions + current_size = part_size + (1 if i < remainder else 0) + self.partptr.append(self.partptr[-1] + current_size) + + +class ReservoirPreprocessor: + """ + A class to handle the complete preprocessing pipeline for reservoir simulation data. + + This class manages the creation of raw graphs from simulation data, partitioning them + for efficient training, computing global statistics, and organizing data splits. + """ + + def __init__(self, cfg: DictConfig): + """ + Initialize the ReservoirPreprocessor with configuration. + + Parameters + ----------- + cfg : DictConfig + Hydra configuration object containing all preprocessing parameters + """ + self.cfg = cfg + + # Get dataset directory using path_utils utility for consistent job name handling + self.dataset_dir = get_dataset_dir(cfg) + + self.graphs_dir = os.path.join(self.dataset_dir, "graphs") + self.partitions_dir = os.path.join(self.dataset_dir, "partitions") + self.stats_file = os.path.join(self.dataset_dir, "global_stats.json") + + # Set default values for preprocessing + self.cfg.preprocessing.num_preprocess_workers = getattr( + cfg.preprocessing, "num_preprocess_workers", 4 + ) + self.cfg.preprocessing.num_partitions = getattr( + cfg.preprocessing, "num_partitions", 3 + ) + self.cfg.preprocessing.halo_size = getattr(cfg.preprocessing, "halo_size", 1) + + self.graph_file_list = None + self.generated_files = None + + # Extract job name from dataset directory for display + job_name = os.path.basename(self.dataset_dir) + logger.info(f"Dataset directory: {self.dataset_dir}") + logger.info(f"Job name: {job_name}") + + def _extract_case_name_from_filename(self, filename): + """ + Extract case name from a graph filename by removing the timestep suffix. + + Expected format: {case_name}_{timestep:03d}.pt + where timestep is typically 3 digits (e.g., 000, 001, 123). 
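+ Case names may themselves contain underscores and digit groups; only a trailing underscore followed by exactly three digits is treated as the timestep suffix, and names that do not match this pattern are returned unchanged.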
+ + Examples: + CASE_2D_1_000.pt -> CASE_2D_1 + NORNE_ATW2013_DOE_0004_002.pt -> NORNE_ATW2013_DOE_0004 + sample_005_123.pt -> sample_005 + + Parameters + ----------- + filename : str + Graph filename (with or without .pt extension) + + Returns + -------- + str: Case name without timestep suffix + """ + # Remove .pt extension if present + name = filename.replace(".pt", "") + + # Pattern: match case_name followed by underscore and 3-digit timestep at end + # The timestep is formatted as {timestep_id:03d} in graph_builder.py + match = re.match(r"^(.+)_(\d{3})$", name) + + if match: + return match.group(1) # Return everything before the last _XXX + else: + # Fallback: if pattern doesn't match, assume entire name is the case + # (this handles edge cases or future format changes) + return name + + def save_graph_file_list(self, graph_files, list_file="generated_graphs.json"): + """ + Save list of generated graph files for tracking. + + Parameters + ----------- + graph_files : list + List of generated graph file paths + list_file : str + Path to save graph file list + """ + # Save in the graphs directory + list_path = os.path.join(self.graphs_dir, list_file) + + graph_list = { + "generated_files": [os.path.basename(f) for f in graph_files], + "graphs_dir": self.graphs_dir, + "count": len(graph_files), + "timestamp": torch.tensor(0).item(), # Simple timestamp placeholder + } + + with open(list_path, "w") as f: + json.dump(graph_list, f, indent=2) + + logger.info(f"Saved graph file list to: {list_path}") + + def load_graph_file_list(self, list_file="generated_graphs.json"): + """ + Load list of generated graph files. + + Parameters + ----------- + list_file : str + Path to graph file list + + Returns + -------- + list or None: List of graph file names, or None if not found + """ + list_path = os.path.join(self.graphs_dir, list_file) + + if not os.path.exists(list_path): + return None + + try: + with open(list_path, "r") as f: + data = json.load(f) + return data.get("generated_files", []) + except (json.JSONDecodeError, KeyError): + return None + + def save_preprocessing_metadata(self, metadata_file="preprocessing_metadata.json"): + """ + Save preprocessing paths to a metadata file for later retrieval. + + Parameters + ----------- + metadata_file : str + Path to save metadata file + """ + metadata = { + "graphs_dir": self.graphs_dir, + "partitions_dir": self.partitions_dir, + "preprocessing_completed": True, + "partition_config": { + "num_partitions": self.cfg.preprocessing.num_partitions, + "halo_size": self.cfg.preprocessing.halo_size, + }, + } + + with open(metadata_file, "w") as f: + json.dump(metadata, f, indent=2) + + logger.info(f"Saved preprocessing metadata to: {metadata_file}") + + def save_dataset_metadata(self, metadata_file="dataset_metadata.json"): + """ + Save dataset metadata for inference use. 
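+ The metadata records the absolute simulation directory, the graph, partition, and statistics paths, and the partition configuration (num_partitions, halo_size) so that later stages can locate the preprocessed data and check that existing partitions match the current configuration.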
+ + Parameters + ----------- + metadata_file : str + Path to save dataset metadata file + """ + # Get absolute path to sim_dir + sim_dir_abs = to_absolute_path(self.cfg.dataset.sim_dir) + + metadata = { + "sim_dir": sim_dir_abs, # Absolute path to simulator data directory + "dataset_dir": self.dataset_dir, + "graphs_dir": self.graphs_dir, + "partitions_dir": self.partitions_dir, + "stats_file": self.stats_file, + "preprocessing_completed": True, + "job_name": os.path.basename(self.dataset_dir), + "config": { + "simulator": self.cfg.dataset.simulator, + "num_samples": getattr(self.cfg.dataset, "num_samples", None), + }, + "partition_config": { + "num_partitions": self.cfg.preprocessing.num_partitions, + "halo_size": self.cfg.preprocessing.halo_size, + }, + } + + metadata_path = os.path.join(self.dataset_dir, metadata_file) + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + + logger.info(f"Saved dataset metadata to: {metadata_path}") + + def validate_partition_topology(self): + """ + Validate that existing partitions match the current configuration. + + Returns + -------- + bool + True if partitions are valid and match current config, False otherwise + """ + # Check if dataset metadata exists + metadata_path = os.path.join(self.dataset_dir, "dataset_metadata.json") + if not os.path.exists(metadata_path): + logger.warning( + "No dataset metadata found, cannot validate partition topology" + ) + return False + + try: + with open(metadata_path, "r") as f: + metadata = json.load(f) + + # Check if partition config exists in metadata + if "partition_config" not in metadata: + logger.warning( + "Partition configuration not found in metadata, " + "partitions may have been created with older version" + ) + return False + + saved_config = metadata["partition_config"] + current_num_partitions = self.cfg.preprocessing.num_partitions + current_halo_size = self.cfg.preprocessing.halo_size + + saved_num_partitions = saved_config.get("num_partitions") + saved_halo_size = saved_config.get("halo_size") + + # Validate num_partitions + if saved_num_partitions != current_num_partitions: + logger.warning( + f"Partition topology mismatch: " + f"existing partitions have num_partitions={saved_num_partitions}, " + f"but current config has num_partitions={current_num_partitions}" + ) + return False + + # Validate halo_size + if saved_halo_size != current_halo_size: + logger.warning( + f"Partition topology mismatch: " + f"existing partitions have halo_size={saved_halo_size}, " + f"but current config has halo_size={current_halo_size}" + ) + return False + + logger.info( + f"Partition topology validated: " + f"num_partitions={current_num_partitions}, halo_size={current_halo_size}" + ) + return True + + except (json.JSONDecodeError, KeyError) as e: + logger.warning(f"Failed to validate partition topology: {e}") + return False + + def split_samples_by_case(self, train_ratio, val_ratio, test_ratio, random_seed=42): + """ + Split graph files by case (sample) to ensure all timesteps of a sample stay together. 
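+ Splitting at the case level rather than per timestep guarantees that every timestep of a given simulation case lands in the same split, preventing leakage between training, validation, and test data.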
+ + Parameters + ----------- + train_ratio : float + Ratio of samples for training + val_ratio : float + Ratio of samples for validation + test_ratio : float + Ratio of samples for testing + random_seed : int + Random seed for reproducible splits + + Returns + -------- + dict: Dictionary with 'train', 'val', 'test' keys containing lists of file names + """ + # Extract unique case names from file names + case_names = set() + for filename in self.graph_file_list: + # Extract case name using robust regex-based parsing + case_name = self._extract_case_name_from_filename(filename) + case_names.add(case_name) + + case_names = sorted(list(case_names)) + total_cases = len(case_names) + + # Validate that we have enough samples for the split + min_samples_needed = 3 # Need at least 3 samples for train/val/test split + if total_cases < min_samples_needed: + raise ValueError( + f"Insufficient samples for train/val/test split! " + f"Found {total_cases} samples, but need at least {min_samples_needed}. " + f"Please increase num_samples in config or adjust split ratios." + ) + + # Validate split ratios + total_ratio = train_ratio + val_ratio + test_ratio + if abs(total_ratio - 1.0) > 1e-6: + raise ValueError(f"Split ratios must sum to 1.0, but got {total_ratio}") + + # Set random seed for reproducible splits + random.seed(random_seed) + random.shuffle(case_names) + + # Calculate split indices + train_end = int(total_cases * train_ratio) + val_end = train_end + int(total_cases * val_ratio) + + train_cases = case_names[:train_end] + val_cases = case_names[train_end:val_end] + test_cases = case_names[val_end:] + + # Ensure at least one sample in each split + if len(train_cases) == 0: + train_cases = [case_names[0]] + if len(val_cases) > 0: + val_cases = val_cases[1:] + elif len(test_cases) > 0: + test_cases = test_cases[1:] + + if len(val_cases) == 0 and len(test_cases) > 0: + val_cases = [test_cases[0]] + test_cases = test_cases[1:] + + logger.info(f"Sample split:") + logger.info( + f" Training: {len(train_cases)} cases ({len(train_cases) / total_cases * 100:.1f}%)" + ) + logger.info( + f" Validation: {len(val_cases)} cases ({len(val_cases) / total_cases * 100:.1f}%)" + ) + logger.info( + f" Test: {len(test_cases)} cases ({len(test_cases) / total_cases * 100:.1f}%)" + ) + + # Group files by split + splits = {"train": [], "val": [], "test": []} + + for filename in self.graph_file_list: + case_name = self._extract_case_name_from_filename(filename) + if case_name in train_cases: + splits["train"].append(filename) + elif case_name in val_cases: + splits["val"].append(filename) + elif case_name in test_cases: + splits["test"].append(filename) + + logger.info(f"File split:") + logger.info(f" Training: {len(splits['train'])} files") + logger.info(f" Validation: {len(splits['val'])} files") + logger.info(f" Test: {len(splits['test'])} files") + + return splits + + def organize_partitions_by_split(self, splits): + """ + Create partitions and organize them into train/val/test subdirectories. 
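+ The partition files themselves are produced by create_partitions_from_graphs; this step only moves each partitions_{graph_file} from the flat partitions directory into the train/, val/, or test/ subdirectory corresponding to its case's split.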
+ + Parameters + ----------- + splits : dict + Dictionary with 'train', 'val', 'test' keys containing file lists + """ + logger.info(f"\nOrganizing partitions by split...") + + # Create subdirectories + train_dir = os.path.join(self.partitions_dir, "train") + val_dir = os.path.join(self.partitions_dir, "val") + test_dir = os.path.join(self.partitions_dir, "test") + + for split_dir in [train_dir, val_dir, test_dir]: + os.makedirs(split_dir, exist_ok=True) + + # Process each split + total_moved = 0 + for split_name, file_list in splits.items(): + if not file_list: + logger.info(f" → {split_name.capitalize()}: No files to process") + continue + + split_dir = os.path.join(self.partitions_dir, split_name) + logger.info(f" → Processing {split_name} split: {len(file_list)} files") + + moved_count = 0 + logger.info(f"Organizing {split_name} split ({len(file_list)} files)...") + for filename in file_list: + # Load the graph + try: + graph_path = os.path.join( + self.partitions_dir, f"partitions_{filename}" + ) + if not os.path.exists(graph_path): + continue + + # Move the partition file to the appropriate subdirectory + dest_path = os.path.join(split_dir, f"partitions_{filename}") + shutil.move(graph_path, dest_path) + moved_count += 1 + + except Exception as e: + continue + + logger.info( + f" Moved {moved_count}/{len(file_list)} files to {split_name}/" + ) + total_moved += moved_count + + logger.info(f"Partition organization complete!") + + @contextlib.contextmanager + def suppress_all_output(self): + """Context manager to suppress all output including stdout, stderr, warnings, and logging.""" + with ( + contextlib.redirect_stdout(io.StringIO()), + contextlib.redirect_stderr(io.StringIO()), + ): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # Temporarily disable logging + logging.disable(logging.CRITICAL) + try: + yield + finally: + logging.disable(logging.NOTSET) + + def create_simple_partition(self, num_nodes, num_parts): + """Create a simple sequential partition as fallback when METIS is not available. + + Parameters + ----------- + num_nodes : int + Total number of nodes to partition + num_parts : int + Number of partitions to create + + Returns + -------- + SimplePartition + Partition object with node_perm and partptr attributes + """ + return SimplePartition(num_nodes, num_parts) + + def create_partitions_from_graphs(self, graph_file_list=None): + """ + Create partitions from raw graphs for efficient training. + + Parameters + ----------- + graph_file_list : list or None + List of specific graph files to process (if None, process all .pt files) + """ + logger.info(f"\nCreating partitions from graphs...") + + # Create partitions directory + os.makedirs(self.partitions_dir, exist_ok=True) + + # Determine which graph files to process + if graph_file_list is not None: + # Use specific list of files + graph_files = [ + os.path.join(self.graphs_dir, f) + for f in graph_file_list + if f.endswith(".pt") + ] + else: + # Find all graph files + graph_files = [] + for file in os.listdir(self.graphs_dir): + if file.endswith(".pt"): + graph_files.append(os.path.join(self.graphs_dir, file)) + + logger.info( + f" → Processing {len(graph_files)} graphs with {self.cfg.preprocessing.num_partitions} partitions each..." 
+ ) + + # Track partition assignments by case (we'll save once per case, not per timestep) + partition_assignments_by_case = {} + + # Process each graph file + successful_partitions = 0 + for i, graph_file in tqdm( + enumerate(graph_files, 1), + total=len(graph_files), + desc="Creating partitions", + unit="graph", + ): + # Load the graph + try: + graph = torch.load(graph_file, weights_only=False) + + # Create partitions directly without using PartitionedGraph class + # to avoid module path issues + # Try to partition the graph using PyG METIS, with fallback to simple partitioning + try: + with self.suppress_all_output(): + # Partition the graph using PyG METIS + cluster_data = pyg.loader.ClusterData( + graph, num_parts=self.cfg.preprocessing.num_partitions + ) + part_meta = cluster_data.partition + except Exception as e: + logger.warning( + f" WARNING: METIS partitioning failed ({e}), using simple partitioning..." + ) + # Fallback: simple sequential partitioning + part_meta = self.create_simple_partition( + graph.num_nodes, self.cfg.preprocessing.num_partitions + ) + + # Extract partition assignments (which node belongs to which partition) + # Create an array: partition_id[node_idx] = partition_number (1-indexed) + partition_assignment = torch.zeros(graph.num_nodes, dtype=torch.int32) + for part_idx in range(self.cfg.preprocessing.num_partitions): + # Get inner nodes of this partition + part_inner_nodes = part_meta.node_perm[ + part_meta.partptr[part_idx] : part_meta.partptr[part_idx + 1] + ] + # Assign partition ID (1-indexed for visualization) + partition_assignment[part_inner_nodes] = part_idx + 1 + + # Save partition assignments per case (only once per case) + filename = os.path.basename(graph_file) + case_name = self._extract_case_name_from_filename(filename) + + # Only save if we haven't saved for this case yet + if case_name not in partition_assignments_by_case: + partition_assignments_by_case[case_name] = ( + partition_assignment.numpy().tolist() + ) + + # Create partitions with halo regions using PyG `k_hop_subgraph` + partitions = [] + for part_idx in range(self.cfg.preprocessing.num_partitions): + # Get inner nodes of the partition + part_inner_node = part_meta.node_perm[ + part_meta.partptr[part_idx] : part_meta.partptr[part_idx + 1] + ] + # Partition the graph with halo regions + part_node, part_edge_index, inner_node_mapping, edge_mask = ( + pyg.utils.k_hop_subgraph( + part_inner_node, + num_hops=self.cfg.preprocessing.halo_size, + edge_index=graph.edge_index, + num_nodes=graph.num_nodes, + relabel_nodes=True, + ) + ) + + partition = pyg.data.Data( + edge_index=part_edge_index, + edge_attr=graph.edge_attr[edge_mask], + num_nodes=part_node.size(0), + part_node=part_node, + inner_node=inner_node_mapping, + ) + # Set partition node attributes + for k, v in graph.items(): + if graph.is_node_attr(k): + setattr(partition, k, v[part_node]) + + partitions.append(partition) + + # Save partitions as a list (following xaeronet pattern) + partition_file = os.path.join( + self.partitions_dir, f"partitions_{os.path.basename(graph_file)}" + ) + torch.save(partitions, partition_file) + + successful_partitions += 1 + + except Exception as e: + logger.error(f"ERROR: processing {os.path.basename(graph_file)}: {e}") + continue + + # Save partition assignments to JSON files (one per case) + if partition_assignments_by_case: + logger.info( + f"\nSaving partition assignments for {len(partition_assignments_by_case)} cases..." 
+ ) + for case_name, partition_array in partition_assignments_by_case.items(): + partition_json_file = os.path.join( + self.dataset_dir, f"{case_name}_partitions.json" + ) + partition_data = { + "case_name": case_name, + "num_partitions": self.cfg.preprocessing.num_partitions, + "num_nodes": len(partition_array), + "partition_assignment": partition_array, # 1-indexed partition IDs for each active cell + } + with open(partition_json_file, "w") as f: + json.dump(partition_data, f, indent=2) + logger.info( + f" → Saved partition assignments to {self.dataset_dir}/*_partitions.json" + ) + + logger.info( + f"Partitioning complete! {successful_partitions}/{len(graph_files)} graphs processed successfully" + ) + + def check_existing_data(self): + """ + Check if preprocessing data already exists and ask user for overwrite decision. + + Returns + -------- + bool: Whether to overwrite existing data + """ + graphs_exist = ( + os.path.exists(self.graphs_dir) + and len([f for f in os.listdir(self.graphs_dir) if f.endswith(".pt")]) > 0 + ) + stats_exist = os.path.exists(self.stats_file) + + if not graphs_exist and not stats_exist: + return True # No existing data, proceed normally + + logger.warning("\nWARNING: Existing preprocessing data detected:") + if graphs_exist: + graph_count = len( + [f for f in os.listdir(self.graphs_dir) if f.endswith(".pt")] + ) + logger.warning( + f" → Graphs directory exists with {graph_count} graph files" + ) + if stats_exist: + logger.warning(f" → Global statistics file exists") + + # Check if we're in a non-interactive environment + if not sys.stdin.isatty(): + logger.info( + "\nNon-interactive environment detected. Auto-selecting 'y' (overwrite)" + ) + logger.info("Will overwrite all existing data") + return True + + logger.info("\nOptions:") + logger.info("y. Overwrite all existing data and start fresh") + logger.info("n. Exit") + + while True: + try: + choice = input("\nOverwrite existing data? (y/n): ").strip().lower() + if choice in ["y", "yes"]: + logger.info("Will overwrite all existing data") + return True + elif choice in ["n", "no"]: + logger.info("Exiting preprocessing") + sys.exit(0) + else: + logger.info("Invalid choice. Please enter y or n.") + except KeyboardInterrupt: + logger.info("\nExiting preprocessing") + sys.exit(0) + + def validate_config(self) -> None: + """ + Validate configuration parameters relevant to preprocessing. + """ + logger.info("Validating configuration...") + + # Validate dataset parameters + if not hasattr(self.cfg, "dataset") or not hasattr(self.cfg.dataset, "sim_dir"): + raise ValueError("Missing required config: dataset.sim_dir") + + sim_dir_abs = to_absolute_path(self.cfg.dataset.sim_dir) + if not os.path.exists(sim_dir_abs): + raise ValueError(f"Simulation directory not found: {sim_dir_abs}") + + # Validate sample count + num_samples = self.cfg.dataset.get("num_samples", None) + if num_samples is not None and num_samples < 3: + raise ValueError( + f"Insufficient samples: {num_samples} for train/val/test split. Need at least 3." 
+ ) + + # Validate data split ratios + if hasattr(self.cfg, "preprocessing") and hasattr( + self.cfg.preprocessing, "data_split" + ): + data_split = self.cfg.preprocessing.data_split + train_ratio = data_split.get("train_ratio", 0.7) + val_ratio = data_split.get("val_ratio", 0.2) + test_ratio = data_split.get("test_ratio", 0.1) + + total_ratio = train_ratio + val_ratio + test_ratio + if abs(total_ratio - 1.0) > 1e-6: + raise ValueError( + f"Data split ratios must sum to 1.0, but got {total_ratio:.6f} (train={train_ratio}, val={val_ratio}, test={test_ratio})" + ) + + if train_ratio <= 0 or val_ratio <= 0 or test_ratio <= 0: + raise ValueError( + f"All split ratios must be positive. Got train={train_ratio}, val={val_ratio}, test={test_ratio}" + ) + + # Validate preprocessing parameters + if hasattr(self.cfg, "preprocessing"): + num_partitions = getattr(self.cfg.preprocessing, "num_partitions", 3) + halo_size = getattr(self.cfg.preprocessing, "halo_size", 1) + + if num_partitions < 1: + raise ValueError(f"num_partitions must be >= 1, got {num_partitions}") + + if halo_size < 0: + raise ValueError(f"halo_size must be >= 0, got {halo_size}") + + logger.info("Configuration validation passed!") + + def execute(self): + """ + Execute the complete preprocessing pipeline. + + This method orchestrates the entire preprocessing workflow: + 1. Create raw graphs from simulation data + 2. Create partitions from raw graphs + 3. Split samples and organize partitions + 4. Compute global statistics + 5. Save preprocessing metadata + """ + logger.info("Reservoir Simulation XMeshGraphNet Preprocessor") + logger.info("=" * 50) + + # Validate configuration first + self.validate_config() + + # Check for existing data and get user input + overwrite_data = self.check_existing_data() + + # Get skip options + skip_graphs = ( + getattr(self.cfg.preprocessing, "skip_graphs", False) or not overwrite_data + ) + + # Step 1: Create raw graphs (unless skipped) + if not skip_graphs: + logger.info("\nStep 1: Creating graphs from simulation data...") + processor = ReservoirGraphBuilder(self.cfg) + + # Override the output path to use our job-specific dataset directory + processor._output_path_graph = self.graphs_dir + os.makedirs(self.graphs_dir, exist_ok=True) + + self.generated_files = processor.execute() + + # Save list of generated graph files + self.save_graph_file_list( + [os.path.join(self.graphs_dir, f) for f in self.generated_files] + ) + self.graph_file_list = self.generated_files + else: + logger.info( + "\nStep 1: Skipping graph generation (using existing graphs)..." + ) + if not os.path.exists(self.graphs_dir): + raise FileNotFoundError( + f"Graphs directory not found: {self.graphs_dir}" + ) + + # Load existing graph file list + self.graph_file_list = self.load_graph_file_list() + if self.graph_file_list is None: + logger.info( + " → No tracked graph files found, will process all .pt files" + ) + self.graph_file_list = None + + # Step 2: Create partitions from the raw graphs + partitions_exist = ( + os.path.exists(self.partitions_dir) + and len([f for f in os.listdir(self.partitions_dir) if f.endswith(".pt")]) + > 0 + ) + + # Validate partition topology if partitions exist + topology_valid = False + if partitions_exist and not overwrite_data: + topology_valid = self.validate_partition_topology() + if not topology_valid: + logger.warning( + "Existing partitions do not match current configuration. " + "Partitions will be recreated." 
+ ) + + if overwrite_data or not partitions_exist or not topology_valid: + logger.info("\nStep 2: Creating partitions from graphs...") + self.create_partitions_from_graphs(graph_file_list=self.graph_file_list) + else: + logger.info( + "\nStep 2: Skipping partition creation (using existing partitions)" + ) + logger.info(f" → Using existing partitions from {self.partitions_dir}") + + # Step 2b: Split samples and organize partitions + # Check if all split directories exist (train, val, test) + train_dir = os.path.join(self.partitions_dir, "train") + val_dir = os.path.join(self.partitions_dir, "val") + test_dir = os.path.join(self.partitions_dir, "test") + splits_exist = all(os.path.exists(d) for d in [train_dir, val_dir, test_dir]) + + if overwrite_data or not splits_exist: + if not splits_exist: + logger.info("\nStep 2b: Splitting samples and organizing partitions...") + logger.info( + " → One or more split directories (train/val/test) are missing" + ) + else: + logger.info("\nStep 2b: Splitting samples and organizing partitions...") + + # Get split configuration + data_split = getattr(self.cfg.preprocessing, "data_split", {}) + train_ratio = data_split.get("train_ratio", 0.7) + val_ratio = data_split.get("val_ratio", 0.2) + test_ratio = data_split.get("test_ratio", 0.1) + random_seed = data_split.get("random_seed", 42) + + # Split samples by case + splits = self.split_samples_by_case( + train_ratio=train_ratio, + val_ratio=val_ratio, + test_ratio=test_ratio, + random_seed=random_seed, + ) + + # Organize partitions into subdirectories + self.organize_partitions_by_split(splits) + else: + logger.info( + "\nStep 2b: Skipping partition organization (using existing splits)" + ) + logger.info( + f" → Using existing train/val/test splits in {self.partitions_dir}" + ) + + # Step 3: Compute and save global statistics + if overwrite_data or not os.path.exists(self.stats_file): + logger.info("\nStep 3: Computing global statistics...") + + # Get all graph files + graph_files = [ + os.path.join(self.graphs_dir, f) + for f in os.listdir(self.graphs_dir) + if f.endswith(".pt") + ] + + logger.info( + f" → Computing statistics from {len(graph_files)} graph files..." 
+ ) + logger.info( + f" → This includes node features, edge features, and target features" + ) + + # Suppress METIS logging during statistics computation + with self.suppress_all_output(): + stats = compute_global_statistics(graph_files, self.stats_file) + + if stats is not None: + logger.info( + f"Global statistics computed and saved to {self.stats_file}" + ) + logger.info( + f" → Node features: {len(stats['node_features']['mean'])} features" + ) + logger.info( + f" → Edge features: {len(stats['edge_features']['mean'])} features" + ) + if "target_features" in stats: + logger.info( + f" → Target features: {len(stats['target_features']['mean'])} features" + ) + else: + logger.info( + f" → Target features: Not found (graphs may not have target data)" + ) + else: + logger.error("Failed to compute global statistics") + else: + logger.info( + "\nStep 3: Skipping statistics computation (using existing file)" + ) + logger.info(f" → Using existing statistics from {self.stats_file}") + + # Step 4: Save preprocessing metadata + logger.info("\nStep 4: Saving preprocessing metadata...") + # Always save metadata in the outputs directory + # Since hydra.run.dir is not available when running preprocessor directly, + # we'll use the current directory (which should be the outputs directory when run through Hydra) + outputs_dir = os.getcwd() + metadata_file = os.path.join(outputs_dir, "preprocessing_metadata.json") + self.save_preprocessing_metadata(metadata_file) + + # Step 5: Save dataset metadata for inference + logger.info("\nStep 5: Saving dataset metadata...") + self.save_dataset_metadata() + + logger.info("\nPreprocessing complete!") + logger.info(f" → Raw graphs: {self.graphs_dir}") + logger.info(f" → Partitions: {self.partitions_dir}") + + +@hydra.main(version_base="1.3", config_path="../conf", config_name="config") +def main(cfg: DictConfig) -> None: + """ + Main function to preprocess reservoir simulation data. + """ + + preprocessor = ReservoirPreprocessor(cfg) + + preprocessor.execute() + + +if __name__ == "__main__": + main() diff --git a/examples/reservoir_simulation/xmgn/src/train.py b/examples/reservoir_simulation/xmgn/src/train.py new file mode 100644 index 0000000000..f923f3c9ff --- /dev/null +++ b/examples/reservoir_simulation/xmgn/src/train.py @@ -0,0 +1,1153 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Training pipeline for XMeshGraphNet on reservoir simulation data. +Loads partitioned graphs, normalizes features using precomputed statistics, +trains the model with early stopping, and saves checkpoints using PhysicsNeMo utilities. 
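+Supports single- and multi-GPU training via PhysicsNeMo's DistributedManager, wrapping the model in DistributedDataParallel and using a DistributedSampler whenever more than one rank is available.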
+""" + +import os +import sys +import json +from datetime import datetime + +# Add repository root to Python path for sim_utils import +current_dir = os.path.dirname(os.path.abspath(__file__)) # This is src/ +repo_root = os.path.dirname(os.path.dirname(current_dir)) # Go up two levels from src/ +if repo_root not in sys.path: + sys.path.insert(0, repo_root) + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.distributed as dist +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data import DataLoader + +import numpy as np +import hydra +from omegaconf import DictConfig + +from physicsnemo.distributed import DistributedManager +from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper +from physicsnemo.launch.logging.mlflow import initialize_mlflow +from physicsnemo.launch.logging import LaunchLogger +from physicsnemo.launch.utils import load_checkpoint, save_checkpoint +from physicsnemo.models.meshgraphnet import MeshGraphNet + +from utils import get_dataset_paths, fix_layernorm_compatibility, EarlyStopping +from data.dataloader import load_stats, find_pt_files, GraphDataset, custom_collate_fn + +# Fix LayerNorm compatibility issue +fix_layernorm_compatibility() + + +def InitializeLoggers(cfg: DictConfig): + """Initialize distributed manager and loggers following PhysicsNeMo pattern. + + Parameters + ---------- + cfg : DictConfig + Hydra configuration object + + Returns + ------- + tuple + (DistributedManager, PythonLogger) + """ + DistributedManager.initialize() # Only call this once in the entire script! + dist = DistributedManager() + logger = PythonLogger(name="xmgn_reservoir") + + logger.info("XMeshGraphNet - Training for Reservoir Simulation") + + # Initialize MLflow (only on rank 0, following PhysicsNeMo pattern) + if dist.rank == 0: + # Clean up only .trash directory to avoid "deleted experiment" conflicts + # while preserving historical results + import shutil + + for mlflow_dir in ["mlruns", ".mlflow"]: + trash_dir = os.path.join(mlflow_dir, ".trash") + if os.path.exists(trash_dir): + shutil.rmtree(trash_dir) + logger.info( + f"Cleaned {trash_dir} directory to avoid deleted experiment conflicts" + ) + + # Get system username from environment variables + user_name = ( + os.getenv("USER") + or os.getenv("USERNAME") + or os.getenv("LOGNAME") + or "unknown" + ) + + # Initialize PhysicsNeMo's MLflow integration + initialize_mlflow( + experiment_name=cfg.runspec.job_name, + experiment_desc=cfg.runspec.description, + run_name=f"{cfg.runspec.job_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}", + run_desc=f"Training: {cfg.runspec.description}", + user_name=user_name, + mode="offline", + ) + LaunchLogger.initialize(use_mlflow=True) + + return dist, RankZeroLoggingWrapper(logger, dist) + + +class Trainer: + """ + Unified trainer class that handles both partitioned and raw graphs. + Eliminates code duplication between training and validation. + """ + + def __init__(self, cfg, dist, logger): + """ + Initialize trainer with complete setup. 
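+ Sets up the train and validation dataloaders, the MeshGraphNet model, the AdamW optimizer with cosine-annealing schedule, per-variable loss functions, optional early stopping, and checkpoint directories from the provided configuration.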
+ + Parameters + ---------- + cfg : DictConfig + Hydra configuration object + dist : DistributedManager + Distributed manager instance + logger : PythonLogger + Logger instance + """ + self.dist = dist + self.device = self.dist.device + self.logger = logger + self.cfg = cfg + + # Enable cuDNN auto-tuner (only for GPU) + cuda_available = torch.cuda.is_available() + if cuda_available: + torch.backends.cudnn.benchmark = cfg.performance.enable_cudnn_benchmark + + # Auto-generate checkpoint filename with best practices + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + dataset_name = os.path.basename(cfg.dataset.sim_dir) + self.checkpoint_filename = f"checkpoint_{dataset_name}_{timestamp}.pth" + + # Set up dataset paths with job name + paths = get_dataset_paths(cfg) + self.dataset_dir = paths["dataset_dir"] + self.stats_file = paths["stats_file"] + self.train_partitions_path = paths["train_partitions_path"] + self.val_partitions_path = paths["val_partitions_path"] + + # Load statistics (automatically generated in dataset directory) + self.stats = load_stats(self.stats_file) + + # Initialize components + self._initialize_dataloaders(cfg) + self._initialize_model(cfg) + self._initialize_optimizer(cfg) + self._initialize_training_config(cfg) + self._initialize_early_stopping(cfg) + self._initialize_checkpoints(cfg) + + def _initialize_dataloaders(self, cfg): + """Initialize training and validation dataloaders.""" + # Create unified data loaders (automatically handles partitions vs raw graphs) + self.train_dataloader = self._create_dataloader(cfg, is_validation=False) + + # Create validation dataloader on all ranks for proper DDP validation + self.val_dataloader = self._create_dataloader(cfg, is_validation=True) + + # Log dataset information + self.logger.info( + f"Dataset: {len(self.train_dataloader)} training samples, {len(self.val_dataloader)} validation samples" + ) + + def _initialize_model(self, cfg): + """Initialize the MeshGraphNet model.""" + # Get dimensions from stats and data + input_dim_nodes = len(self.stats["node_features"]["mean"]) + input_dim_edges = len(self.stats["edge_features"]["mean"]) + output_dim = len(cfg.dataset.graph.target_vars.node_features) + + # Create model + self.model = MeshGraphNet( + input_dim_nodes=input_dim_nodes, + input_dim_edges=input_dim_edges, + output_dim=output_dim, + processor_size=cfg.model.num_message_passing_layers, + aggregation="sum", + hidden_dim_node_encoder=cfg.model.hidden_dim, + hidden_dim_edge_encoder=cfg.model.hidden_dim, + hidden_dim_node_decoder=cfg.model.hidden_dim, + mlp_activation_fn=cfg.model.activation, + do_concat_trick=cfg.performance.use_concat_trick, + num_processor_checkpoint_segments=cfg.performance.checkpoint_segments, + ).to(self.device) + + # Wrap model for multi-GPU training if available + if self.dist.world_size > 1: + # Use DistributedDataParallel for multi-node/multi-GPU training + self.model = DistributedDataParallel( + self.model, + device_ids=[self.dist.local_rank], + output_device=self.dist.device, + broadcast_buffers=self.dist.broadcast_buffers, + find_unused_parameters=self.dist.find_unused_parameters, + gradient_as_bucket_view=True, + static_graph=True, + ) + + def _initialize_optimizer(self, cfg): + """Initialize optimizer, scheduler, and gradient scaler.""" + + weight_decay = getattr(cfg.training, "weight_decay", 1.0e-3) + + # Create optimizer (AdamW with decoupled weight decay) + self.optimizer = optim.AdamW( + self.model.parameters(), + lr=cfg.training.start_lr, + weight_decay=weight_decay, + betas=(0.9, 
0.99), + eps=1e-8, + ) + + # Create cosine annealing scheduler + self.scheduler = optim.lr_scheduler.CosineAnnealingLR( + self.optimizer, T_max=cfg.training.num_epochs, eta_min=cfg.training.end_lr + ) + + # Create gradient scaler for mixed precision (GPU only) + self.scaler = GradScaler() if self.device.type == "cuda" else None + + self.logger.info( + f"Optimizer: AdamW (lr={cfg.training.start_lr}, weight_decay={weight_decay})" + ) + + def _initialize_training_config(self, cfg): + """Initialize training configuration and loss functions.""" + # Store training config + self.num_epochs = cfg.training.num_epochs + self.validation_freq = cfg.training.validation_freq + + # Load target variable weights + self.target_weights = torch.tensor( + cfg.dataset.graph.target_vars.weights, device=self.device + ) + self.logger.info( + f"Target variables: {cfg.dataset.graph.target_vars.node_features}" + ) + self.logger.info(f"Target variable weights: {self.target_weights.tolist()}") + + # Initialize loss functions (handles defaults and validation) + self._initialize_loss_functions(cfg) + + def _initialize_early_stopping(self, cfg): + """Initialize early stopping if configured.""" + if hasattr(cfg.training, "early_stopping") and hasattr( + cfg.training.early_stopping, "patience" + ): + self.early_stopping = EarlyStopping( + patience=cfg.training.early_stopping.patience, + min_delta=cfg.training.early_stopping.min_delta, + ) + self.logger.info( + f"Early stopping enabled: patience={cfg.training.early_stopping.patience}, " + f"min_delta={cfg.training.early_stopping.min_delta}" + ) + else: + self.early_stopping = None + self.logger.info("Early stopping disabled") + + def _initialize_checkpoints(self, cfg): + """Initialize checkpoint directories and arguments.""" + # Set up checkpoint arguments (following PhysicsNeMo pattern) + # Use current working directory (Hydra changes to output directory) + base_output_dir = os.getcwd() + + checkpoint_dir = os.path.join(base_output_dir, "checkpoints") + best_checkpoint_dir = os.path.join(base_output_dir, "best_checkpoints") + + # Create checkpoint directories if they don't exist + os.makedirs(checkpoint_dir, exist_ok=True) + os.makedirs(best_checkpoint_dir, exist_ok=True) + + # Log checkpoint paths + self.logger.info(f"Checkpoint directory: {checkpoint_dir}") + self.logger.info(f"Best checkpoint directory: {best_checkpoint_dir}") + + self.ckpt_args = { + "path": checkpoint_dir, + "optimizer": self.optimizer, + "scheduler": self.scheduler, + "models": self.model, + } + self.bst_ckpt_args = { + "path": best_checkpoint_dir, + "optimizer": self.optimizer, + "scheduler": self.scheduler, + "models": self.model, + } + + def _create_dataloader(self, cfg, is_validation=False): + """Create dataloader using instance variables with DistributedSampler support.""" + if is_validation: + partitions_path = self.val_partitions_path + else: + partitions_path = self.train_partitions_path + + # Load global statistics + with open(self.stats_file, "r") as f: + stats = json.load(f) + + # Load per-feature statistics + node_mean = torch.tensor(stats["node_features"]["mean"]) + node_std = torch.tensor(stats["node_features"]["std"]) + edge_mean = torch.tensor(stats["edge_features"]["mean"]) + edge_std = torch.tensor(stats["edge_features"]["std"]) + + # Load target feature statistics (if available) + target_mean = None + target_std = None + if "target_features" in stats: + target_mean = torch.tensor(stats["target_features"]["mean"]) + target_std = torch.tensor(stats["target_features"]["std"]) + + # 
Find partition files + file_paths = find_pt_files(partitions_path) + + # Create dataset + dataset = GraphDataset( + file_paths, + node_mean, + node_std, + edge_mean, + edge_std, + target_mean, + target_std, + ) + + # Create DistributedSampler for proper distributed training + if self.dist.world_size > 1: + sampler = DistributedSampler( + dataset, + num_replicas=self.dist.world_size, + rank=self.dist.rank, + shuffle=not is_validation, + drop_last=False, + ) + shuffle = False # DistributedSampler handles shuffling + else: + sampler = None + shuffle = not is_validation + + # Create data loader + dataloader = DataLoader( + dataset, + batch_size=cfg.training.get("batch_size", 1), + shuffle=shuffle, + sampler=sampler, + num_workers=0, + pin_memory=True, + collate_fn=custom_collate_fn, # Use custom collate function for lists of PartitionedGraph objects + ) + + # Store sampler for set_epoch calls + if not is_validation: + self.train_sampler = sampler + else: + self.val_sampler = sampler + + self.logger.info( + f"Using partitioned data loader with {len(dataloader)} batches" + ) + return dataloader + + def denormalize_predictions(self, predictions): + """Denormalize predictions using global statistics.""" + if "target_features" not in self.stats: + self.logger.warning( + "No target feature statistics found for denormalization" + ) + return predictions + + target_mean = torch.tensor( + self.stats["target_features"]["mean"], device=predictions.device + ) + target_std = torch.tensor( + self.stats["target_features"]["std"], device=predictions.device + ) + + # Denormalize: pred_denorm = pred_norm * std + mean + denormalized = predictions * target_std + target_mean + return denormalized + + def denormalize_targets(self, targets): + """Denormalize targets using global statistics.""" + if "target_features" not in self.stats: + self.logger.warning( + "No target feature statistics found for denormalization" + ) + return targets + + target_mean = torch.tensor( + self.stats["target_features"]["mean"], device=targets.device + ) + target_std = torch.tensor( + self.stats["target_features"]["std"], device=targets.device + ) + + # Denormalize: target_denorm = target_norm * std + mean + denormalized = targets * target_std + target_mean + return denormalized + + def _initialize_loss_functions(self, cfg): + """Initialize PyTorch loss functions for each target variable with defaults and validation.""" + # Load loss function configuration with defaults + self.loss_functions = getattr( + cfg.dataset.graph.target_vars, "loss_functions", None + ) + self.huber_delta = getattr(cfg.dataset.graph.target_vars, "huber_delta", None) + + # Set defaults if not provided + if self.loss_functions is None: + self.loss_functions = ["L2"] * len( + cfg.dataset.graph.target_vars.node_features + ) + self.logger.warning( + f"Loss functions not specified in config. Using default: {self.loss_functions}" + ) + elif len(self.loss_functions) != len( + cfg.dataset.graph.target_vars.node_features + ): + self.logger.warning( + f"Number of loss functions ({len(self.loss_functions)}) doesn't match number of target variables ({len(cfg.dataset.graph.target_vars.node_features)}). Using L2 for all." 
+ ) + self.loss_functions = ["L2"] * len( + cfg.dataset.graph.target_vars.node_features + ) + + # Validate loss function names and set defaults for invalid ones (case-insensitive) + valid_losses = ["L1", "L2", "Huber"] + for i, loss_func in enumerate(self.loss_functions): + # Convert to proper case for consistency + loss_func_upper = loss_func.upper() + if loss_func_upper == "L1": + self.loss_functions[i] = "L1" + elif loss_func_upper == "L2": + self.loss_functions[i] = "L2" + elif loss_func_upper == "HUBER": + self.loss_functions[i] = "Huber" + else: + self.logger.warning( + f"Invalid loss function '{loss_func}' for variable {i}. Using L2 instead." + ) + self.loss_functions[i] = "L2" + + self.logger.info(f"Loss functions: {self.loss_functions}") + if "Huber" in self.loss_functions: + if self.huber_delta is None: # Set Huber delta default + self.huber_delta = 0.5 + self.logger.info( + f"Huber delta not specified. Using default value: {self.huber_delta}" + ) + self.logger.info(f"Huber delta: {self.huber_delta}") + + # Initialize PyTorch loss functions + self.loss_fn_objects = [] + + for loss_func in self.loss_functions: + if loss_func == "L1": + self.loss_fn_objects.append(torch.nn.L1Loss()) + elif loss_func == "L2": + self.loss_fn_objects.append(torch.nn.MSELoss()) + elif loss_func == "Huber": + self.loss_fn_objects.append(torch.nn.HuberLoss(delta=self.huber_delta)) + else: + raise ValueError(f"Unknown loss function: {loss_func}") + + # Move loss functions to device + for loss_fn in self.loss_fn_objects: + loss_fn.to(self.device) + + def compute_weighted_loss(self, predictions, targets): + """ + Compute weighted loss for each target variable using configurable loss functions. + + Parameters + ---------- + predictions : torch.Tensor + Model predictions [N, num_target_vars] + targets : torch.Tensor + Target values [N, num_target_vars] + + Returns + ------- + torch.Tensor + Weighted loss + """ + losses_per_var = [] + + for i, loss_fn in enumerate(self.loss_fn_objects): + pred_var = predictions[:, i] + target_var = targets[:, i] + + # Use the initialized PyTorch loss function + loss = loss_fn(pred_var, target_var) + losses_per_var.append(loss) + + # Convert to tensor and apply weights + losses_tensor = torch.stack(losses_per_var) + weighted_loss = torch.sum(self.target_weights * losses_tensor) + + return weighted_loss + + def compute_per_variable_losses(self, predictions, targets): + """ + Compute per-variable losses for logging purposes. + + Parameters + ---------- + predictions : torch.Tensor or np.ndarray + Model predictions [N, num_target_vars] + targets : torch.Tensor or np.ndarray + Target values [N, num_target_vars] + + Returns + ------- + list + List of per-variable losses + """ + losses_per_var = [] + + for i, loss_fn in enumerate(self.loss_fn_objects): + pred_var = predictions[:, i] + target_var = targets[:, i] + + # Convert to torch tensors if needed + if not isinstance(pred_var, torch.Tensor): + pred_var = torch.tensor(pred_var, device=self.device) + if not isinstance(target_var, torch.Tensor): + target_var = torch.tensor(target_var, device=self.device) + + # Use the initialized PyTorch loss function + loss = loss_fn(pred_var, target_var) + losses_per_var.append(loss.item()) + + return losses_per_var + + def _process_partition(self, part, is_training=True): + """ + Process a single partition (for both training and validation). 
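+ For partitioned graphs the forward pass runs over all nodes, including halo nodes, but the loss is computed only on the inner nodes selected via part.inner_node; halo nodes provide message-passing context only.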
+ + Parameters + ---------- + part : torch_geometric.data.Data + The partition to process + is_training : bool + Whether this is training (affects gradient computation) + + Returns + ------- + tuple + (loss, denorm_loss, pred, target) + """ + part = part.to(self.device) + + # Ensure data is in float32 for mixed precision training + if hasattr(part, "x") and part.x is not None: + part.x = part.x.float() + if hasattr(part, "edge_attr") and part.edge_attr is not None: + part.edge_attr = part.edge_attr.float() + if hasattr(part, "y") and part.y is not None: + part.y = part.y.float() + + # Forward pass (disable mixed precision for now to avoid dtype issues) + pred = self.model(part.x, part.edge_attr, part) + + # Get inner nodes if available (for partitioned graphs) + if hasattr(part, "inner_node"): + pred_inner = pred[part.inner_node] + target_inner = ( + part.y[part.inner_node] + if hasattr(part, "y") + else part.y[part.inner_node] + ) + else: + pred_inner = pred + target_inner = part.y + + # Compute weighted normalized loss + loss = self.compute_weighted_loss(pred_inner, target_inner) + + # Denormalize for evaluation + pred_denorm = self.denormalize_predictions(pred_inner) + target_denorm = self.denormalize_targets(target_inner) + denorm_loss = self.compute_weighted_loss(pred_denorm, target_denorm) + + return loss, denorm_loss, pred_inner, target_inner + + def _process_graph(self, graph, is_training=True): + """ + Process a single graph (for both training and validation). + + Parameters + ---------- + graph : torch_geometric.data.Data + The graph to process + is_training : bool + Whether this is training (affects gradient computation) + + Returns + ------- + tuple + (loss, denorm_loss, pred, target) + """ + graph = graph.to(self.device) + + # Forward pass + if is_training and self.device.type == "cuda": + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + pred = self.model(graph.x, graph.edge_attr, graph) + else: + pred = self.model(graph.x, graph.edge_attr, graph) + + # Compute weighted normalized loss + loss = self.compute_weighted_loss(pred, graph.y) + + # Denormalize for evaluation + pred_denorm = self.denormalize_predictions(pred) + target_denorm = self.denormalize_targets(graph.y) + denorm_loss = self.compute_weighted_loss(pred_denorm, target_denorm) + + return loss, denorm_loss, pred, graph.y + + def train_epoch(self): + """Train the model for one epoch.""" + self.model.train() + total_loss = 0.0 + num_batches = 0 + + for batch_idx, batch in enumerate(self.train_dataloader): + # Handle the new format: batch is (partitions_list, labels) + partitions_list, labels = batch + + self.optimizer.zero_grad() + + # Process each sample's partitions in the batch + total_batch_loss = 0.0 + num_samples = len(partitions_list) + + for sample_idx, partitions in enumerate(partitions_list): + # Process each partition in this sample + sample_loss = 0.0 + num_partitions = len(partitions) + + for partition in partitions: + loss, _, _, _ = self._process_partition(partition, is_training=True) + + # For logging: accumulate loss scaled only by num_partitions (consistent with validation) + sample_loss += loss.item() / num_partitions + + # For gradient computation: scale by total number of forward passes in the batch + loss = loss / (num_partitions * num_samples) + + # Backward pass + if self.device.type == "cuda": + self.scaler.scale(loss).backward() + else: + loss.backward() + + # Accumulate loss from this sample + total_batch_loss += sample_loss + + # Update optimizer after processing all samples 
and partitions + if self.device.type == "cuda": + self.scaler.step(self.optimizer) + self.scaler.update() + else: + self.optimizer.step() + + total_loss += total_batch_loss + num_batches += 1 + + avg_loss = total_loss / num_batches if num_batches > 0 else 0.0 + + # Synchronize loss across all GPUs for accurate reporting + if self.dist.world_size > 1: + loss_tensor = torch.tensor(avg_loss, device=self.device) + dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG) + avg_loss = loss_tensor.item() + + return avg_loss + + def validate_epoch(self): + """Validate the model for one epoch.""" + self.model.eval() + total_loss = 0.0 + total_denorm_loss = 0.0 + num_batches = 0 + + # Collect all predictions and targets for per-variable metrics + all_predictions = [] + all_targets = [] + + with torch.no_grad(): + for batch_idx, batch in enumerate(self.val_dataloader): + # Handle the new format: batch is (partitions_list, labels) + partitions_list, labels = batch + + # Process each sample's partitions in the batch + batch_loss = 0.0 + batch_denorm_loss = 0.0 + num_samples = len(partitions_list) + + for sample_idx, partitions in enumerate(partitions_list): + # Process each partition in this sample + sample_loss = 0.0 + sample_denorm_loss = 0.0 + num_partitions = len(partitions) + + for partition in partitions: + loss, denorm_loss, pred, target = self._process_partition( + partition, is_training=False + ) + loss = loss / num_partitions + denorm_loss = denorm_loss / num_partitions + sample_loss += loss.item() + sample_denorm_loss += denorm_loss.item() + + # Collect predictions and targets for per-variable metrics + all_predictions.append(pred.cpu().numpy()) + all_targets.append(target.cpu().numpy()) + + batch_loss += sample_loss + batch_denorm_loss += sample_denorm_loss + + total_loss += batch_loss + total_denorm_loss += batch_denorm_loss + + num_batches += 1 + + avg_loss = total_loss / num_batches + avg_denorm_loss = total_denorm_loss / num_batches + + # Synchronize validation losses across all GPUs + if self.dist.world_size > 1: + loss_tensor = torch.tensor([avg_loss, avg_denorm_loss], device=self.device) + dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG) + avg_loss, avg_denorm_loss = loss_tensor[0].item(), loss_tensor[1].item() + + # Calculate per-variable metrics (simplified logging without tabulate) + if all_predictions and all_targets: + all_predictions = np.concatenate(all_predictions, axis=0) + all_targets = np.concatenate(all_targets, axis=0) + + # Prepare metrics dictionary for MLflow logging + metrics = {} + + # Calculate overall MAE and RMSE + if len(all_predictions) > 0 and len(all_targets) > 0: + overall_mae = np.mean(np.abs(all_predictions - all_targets)) + overall_mse = np.mean((all_predictions - all_targets) ** 2) + overall_rmse = np.sqrt(overall_mse) + + metrics["mae"] = overall_mae + metrics["mse"] = overall_mse + metrics["rmse"] = overall_rmse + + # Calculate per-variable metrics (normalized) + target_names = self.cfg.dataset.graph.target_vars.node_features + for i, var_name in enumerate(target_names): + var_mae = np.mean(np.abs(all_predictions[:, i] - all_targets[:, i])) + var_rmse = np.sqrt( + np.mean((all_predictions[:, i] - all_targets[:, i]) ** 2) + ) + + metrics[f"mae_{var_name.lower()}"] = var_mae + metrics[f"rmse_{var_name.lower()}"] = var_rmse + + # Calculate denormalized per-variable metrics if available + if "target_features" in self.stats: + target_mean = np.array(self.stats["target_features"]["mean"]) + target_std = np.array(self.stats["target_features"]["std"]) + + # 
Denormalize predictions and targets + all_predictions_denorm = all_predictions * target_std + target_mean + all_targets_denorm = all_targets * target_std + target_mean + + # Overall denormalized metrics + overall_mae_denorm = np.mean( + np.abs(all_predictions_denorm - all_targets_denorm) + ) + overall_mse_denorm = np.mean( + (all_predictions_denorm - all_targets_denorm) ** 2 + ) + overall_rmse_denorm = np.sqrt(overall_mse_denorm) + + metrics["mae_denorm"] = overall_mae_denorm + metrics["mse_denorm"] = overall_mse_denorm + metrics["rmse_denorm"] = overall_rmse_denorm + + # Per-variable denormalized metrics + for i, var_name in enumerate(target_names): + var_mae_denorm = np.mean( + np.abs(all_predictions_denorm[:, i] - all_targets_denorm[:, i]) + ) + var_rmse_denorm = np.sqrt( + np.mean( + (all_predictions_denorm[:, i] - all_targets_denorm[:, i]) + ** 2 + ) + ) + + metrics[f"mae_{var_name.lower()}_denorm"] = var_mae_denorm + metrics[f"rmse_{var_name.lower()}_denorm"] = var_rmse_denorm + + # Synchronize all metrics across GPUs + if self.dist.world_size > 1 and metrics: + # Convert metrics dict to tensor for reduction + metric_keys = sorted(metrics.keys()) + metric_values = torch.tensor( + [metrics[k] for k in metric_keys], device=self.device + ) + dist.all_reduce(metric_values, op=dist.ReduceOp.AVG) + + # Update metrics dict with synchronized values + for i, key in enumerate(metric_keys): + metrics[key] = metric_values[i].item() + + return avg_loss, avg_denorm_loss, metrics + + def train(self): + """ + Complete training loop with validation and checkpointing. + Handles resume logic internally based on configuration. + + Returns + -------- + float: Best validation loss + """ + # Handle training resume based on config + loaded_epoch = self._handle_training_resume() + + # Initialize best validation loss + best_val_loss = float("inf") + + # If resuming from a checkpoint, run validation to get current best validation loss + if loaded_epoch > 0: + self.logger.info( + f"Resuming training from epoch {loaded_epoch + 1}. Running validation to get current best validation loss..." 
+ ) + val_loss, _, _ = self.validate_epoch() + best_val_loss = val_loss + self.logger.info(f"Current best validation loss: {best_val_loss:.6f}") + + for epoch in range(max(1, loaded_epoch + 1), self.num_epochs + 1): + # Set epoch for proper distributed sampling + if self.train_sampler is not None: + self.train_sampler.set_epoch(epoch) + if self.val_sampler is not None: + self.val_sampler.set_epoch(epoch) + + # Log progress + self.logger.info(f"Starting Epoch {epoch}/{self.num_epochs}") + + # Increment early stopping epoch counter + if self.early_stopping is not None: + self.early_stopping.step() + + # Train with LaunchLogger (handles MLflow automatically) + with LaunchLogger( + name_space="train", + num_mini_batch=len(self.train_dataloader), + epoch=epoch, + epoch_alert_freq=1, + ) as log: + train_loss = self.train_epoch() + log.log_epoch( + { + "train_loss": train_loss, + "learning_rate": self.optimizer.param_groups[0]["lr"], + "best_val_loss": best_val_loss + if best_val_loss != float("inf") + else 0.0, + } + ) + + # Validation step + val_loss, val_denorm_loss, val_metrics = self._validation_step(epoch) + + # Save best model and check early stopping + should_stop = self._check_early_stopping( + epoch, val_loss, val_metrics, best_val_loss + ) + + if val_loss != float("inf") and val_loss < best_val_loss: + best_val_loss = val_loss + + # Save regular checkpoint (only on rank 0) + if self.dist.rank == 0 and ( + epoch % self.validation_freq == 0 or epoch == self.num_epochs + ): + save_checkpoint(**self.ckpt_args, epoch=epoch) + + # Update learning rate + self.scheduler.step() + + # Log training progress (ZeroRankLogger handles rank 0 automatically) + self.logger.info( + f"Epoch {epoch}/{self.num_epochs}, Train loss: {train_loss:.6f}, LR: {self.optimizer.param_groups[0]['lr']:.6f}" + ) + + # Break if early stopping triggered + if should_stop: + self.logger.info( + f"Training stopped early at epoch {epoch} due to early stopping" + ) + break + + self.logger.info( + f"Training completed! Best validation loss: {best_val_loss:.6f}" + ) + + def _validation_step(self, epoch): + """ + Perform validation step with comprehensive metrics logging. + Validation runs on all ranks for proper DDP, but only rank 0 logs to MLflow. + + Parameters + ---------- + epoch : int + Current epoch number + + Returns + ------- + tuple + (val_loss, val_denorm_loss, val_metrics) + """ + val_loss = float("inf") + val_denorm_loss = float("inf") + val_metrics = None + + # Run validation on all ranks at validation frequency + if epoch % self.validation_freq == 0 or epoch == self.num_epochs: + val_loss, val_denorm_loss, val_metrics = self.validate_epoch() + + # Only log to MLflow on rank 0 + if self.dist.rank == 0: + with LaunchLogger("valid", epoch=epoch) as log: + # Prepare comprehensive metrics for logging + metrics_to_log = self._prepare_validation_metrics( + val_loss, val_denorm_loss, val_metrics + ) + log.log_epoch(metrics_to_log) + + return val_loss, val_denorm_loss, val_metrics + + def _prepare_validation_metrics(self, val_loss, val_denorm_loss, val_metrics): + """ + Prepare comprehensive validation metrics for logging. 
+ + Parameters + ---------- + val_loss : float + Validation loss + val_denorm_loss : float + Denormalized validation loss + val_metrics : dict + Additional validation metrics + + Returns + ------- + dict + Metrics to log + """ + metrics_to_log = { + "val_loss": val_loss, + "val_denorm_loss": val_denorm_loss, + } + + if val_metrics: + # Add overall MAE and MSE + if "mae" in val_metrics: + metrics_to_log["val_mae"] = val_metrics["mae"] + if "mse" in val_metrics: + metrics_to_log["val_mse"] = val_metrics["mse"] + if "rmse" in val_metrics: + metrics_to_log["val_rmse"] = val_metrics["rmse"] + + # Add per-variable metrics (normalized) + target_names = self.cfg.dataset.graph.target_vars.node_features + for i, var_name in enumerate(target_names): + if f"mae_{var_name.lower()}" in val_metrics: + metrics_to_log[f"val_mae_{var_name.lower()}"] = val_metrics[ + f"mae_{var_name.lower()}" + ] + if f"rmse_{var_name.lower()}" in val_metrics: + metrics_to_log[f"val_rmse_{var_name.lower()}"] = val_metrics[ + f"rmse_{var_name.lower()}" + ] + + # Add denormalized per-variable metrics if available + for i, var_name in enumerate(target_names): + if f"mae_{var_name.lower()}_denorm" in val_metrics: + metrics_to_log[f"val_mae_{var_name.lower()}_denorm"] = val_metrics[ + f"mae_{var_name.lower()}_denorm" + ] + if f"rmse_{var_name.lower()}_denorm" in val_metrics: + metrics_to_log[f"val_rmse_{var_name.lower()}_denorm"] = val_metrics[ + f"rmse_{var_name.lower()}_denorm" + ] + + return metrics_to_log + + def _check_early_stopping(self, epoch, val_loss, val_metrics, best_val_loss): + """ + Check early stopping and save best model. + + Parameters + ---------- + epoch : int + Current epoch number + val_loss : float + Validation loss + val_metrics : dict + Validation metrics + best_val_loss : float + Current best validation loss + + Returns + ------- + bool + True if early stopping should trigger, False otherwise + """ + should_stop = False + + # Save best model (only if validation was performed and only on rank 0) + if ( + self.dist.rank == 0 + and val_loss != float("inf") + and val_loss < best_val_loss + ): + save_checkpoint(**self.bst_ckpt_args, epoch=epoch) + + # Check early stopping (only on rank 0 and if validation was performed) + if ( + self.dist.rank == 0 + and self.early_stopping is not None + and val_loss != float("inf") + ): + # Check if validation improved + self.early_stopping.check_improvement(val_loss) + + # Check if we should stop based on epochs without improvement + should_stop = self.early_stopping.should_stop() + + if should_stop: + self.logger.info( + f"Early stopping triggered at epoch {epoch}. 
" + f"Best val_loss: {self.early_stopping.best_score:.6f}, " + f"Current: {val_loss:.6f}, " + f"Epochs without improvement: {self.early_stopping.epochs_since_improvement}/{self.early_stopping.patience}" + ) + + # Broadcast early stopping decision to all ranks + if self.dist.world_size > 1: + should_stop_tensor = torch.tensor(int(should_stop), device=self.device) + dist.broadcast(should_stop_tensor, src=0) + should_stop = bool(should_stop_tensor.item()) + + return should_stop + + def _handle_training_resume(self): + """Handle training resume based on configuration.""" + import os + import shutil + + checkpoint_dir = self.ckpt_args["path"] + best_checkpoint_dir = self.bst_ckpt_args["path"] + + # Check if any checkpoint files exist + has_checkpoints = False + if os.path.exists(checkpoint_dir): + checkpoint_files = [ + f + for f in os.listdir(checkpoint_dir) + if f.endswith(".pt") or f.endswith(".mdlus") + ] + if checkpoint_files: + has_checkpoints = True + + if os.path.exists(best_checkpoint_dir): + best_checkpoint_files = [ + f + for f in os.listdir(best_checkpoint_dir) + if f.endswith(".pt") or f.endswith(".mdlus") + ] + if best_checkpoint_files: + has_checkpoints = True + + if self.cfg.training.resume and has_checkpoints: + self.logger.info("Resuming training from existing checkpoints...") + # Load checkpoint and return the epoch + return load_checkpoint(**self.ckpt_args, device=self.dist.device) + elif self.cfg.training.resume and not has_checkpoints: + self.logger.warning( + "Resume enabled but no checkpoints found. Starting fresh training..." + ) + return 0 + elif not self.cfg.training.resume and has_checkpoints: + self.logger.info("Resume disabled: Deleting existing checkpoint files...") + try: + if os.path.exists(checkpoint_dir): + shutil.rmtree(checkpoint_dir) + os.makedirs(checkpoint_dir, exist_ok=True) + if os.path.exists(best_checkpoint_dir): + shutil.rmtree(best_checkpoint_dir) + os.makedirs(best_checkpoint_dir, exist_ok=True) + self.logger.success( + "Checkpoint files deleted. Starting fresh training..." + ) + except (OSError, PermissionError) as e: + self.logger.warning(f"Could not delete some checkpoint files: {e}") + self.logger.info("Starting fresh training anyway...") + return 0 + else: + self.logger.info("Starting fresh training...") + return 0 + + +@hydra.main(version_base="1.3", config_path="../conf", config_name="config") +def main(cfg: DictConfig) -> None: + """ + Main training entry point. + Trains XMeshGraphNet on reservoir simulation data. + """ + + dist, logger = InitializeLoggers(cfg) + + trainer = Trainer(cfg, dist, logger) + + trainer.train() + + logger.success("Training completed successfully!") + + +if __name__ == "__main__": + main() diff --git a/examples/reservoir_simulation/xmgn/src/utils.py b/examples/reservoir_simulation/xmgn/src/utils.py new file mode 100644 index 0000000000..cee23f75f8 --- /dev/null +++ b/examples/reservoir_simulation/xmgn/src/utils.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utility functions and classes for XMeshGraphNet training and inference. +""" + +import os +import logging +from hydra.utils import to_absolute_path +from omegaconf import DictConfig + +logger = logging.getLogger(__name__) + + +def fix_layernorm_compatibility(): + """ + Fix LayerNorm compatibility issue for PyTorch. + + This addresses a compatibility issue where LayerNorm may not have + the register_load_state_dict_pre_hook method in some PyTorch versions. + Should be called early in script initialization. + """ + import torch.nn as nn + + if not hasattr(nn.LayerNorm, "register_load_state_dict_pre_hook"): + + def _register_load_state_dict_pre_hook(self, hook): + """Dummy implementation for compatibility.""" + return None + + nn.LayerNorm.register_load_state_dict_pre_hook = ( + _register_load_state_dict_pre_hook + ) + + +class EarlyStopping: + """ + Early stopping utility to stop training when validation metric stops improving. + Counts actual epochs, not validation checks. + """ + + def __init__(self, patience=20, min_delta=1e-6): + """ + Initialize early stopping. + + Parameters + ---------- + patience : int + Number of epochs to wait for improvement + min_delta : float + Minimum change to qualify as improvement + """ + self.patience = patience + self.min_delta = min_delta + self.best_score = None + self.epochs_since_improvement = 0 + self.early_stop = False + + def step(self): + """Increment epoch counter. Call this every epoch.""" + self.epochs_since_improvement += 1 + + def check_improvement(self, current_score): + """ + Check if validation score has improved. + + Parameters + ---------- + current_score : float + Current validation loss + + Returns + ------- + bool + True if there was improvement, False otherwise + """ + if self.best_score is None: + self.best_score = current_score + self.epochs_since_improvement = 0 + return True + + # Always use "min" mode (lower is better for loss) + improved = current_score < (self.best_score - self.min_delta) + + if improved: + self.best_score = current_score + self.epochs_since_improvement = 0 + return True + + return False + + def should_stop(self): + """ + Check if training should be stopped. + + Returns + ------- + bool + True if training should stop, False otherwise + """ + if self.epochs_since_improvement >= self.patience: + self.early_stop = True + + return self.early_stop + + +def get_dataset_dir(cfg: DictConfig) -> str: + """ + Get the job-specific dataset directory path. 
+ + Parameters + ----------- + cfg : DictConfig + Hydra configuration object + + Returns + -------- + str: Path to the job-specific dataset directory + """ + # Get job name from runspec (required) + if not hasattr(cfg, "runspec") or not hasattr(cfg.runspec, "job_name"): + raise ValueError("runspec.job_name is required in configuration") + + job_name = cfg.runspec.job_name + + # Get simulation directory from dataset (required) + if not hasattr(cfg, "dataset") or not hasattr(cfg.dataset, "sim_dir"): + raise ValueError("dataset.sim_dir is required in configuration") + + # Create base dataset directory path + base_dataset_dir = to_absolute_path(cfg.dataset.sim_dir + ".dataset") + + # Return job-specific dataset directory + return os.path.join(base_dataset_dir, job_name) + + +def get_dataset_paths(cfg: DictConfig) -> dict: + """ + Get all dataset-related paths for a given configuration. + + Parameters + ----------- + cfg : DictConfig + Hydra configuration object + + Returns + -------- + dict: Dictionary containing all dataset paths + """ + dataset_dir = get_dataset_dir(cfg) + + return { + "dataset_dir": dataset_dir, + "graphs_dir": os.path.join(dataset_dir, "graphs"), + "partitions_dir": os.path.join(dataset_dir, "partitions"), + "stats_file": os.path.join(dataset_dir, "global_stats.json"), + "train_partitions_path": os.path.join(dataset_dir, "partitions", "train"), + "val_partitions_path": os.path.join(dataset_dir, "partitions", "val"), + "test_partitions_path": os.path.join(dataset_dir, "partitions", "test"), + } + + +def print_dataset_info(cfg: DictConfig) -> None: + """ + Print dataset directory information for debugging. + + Parameters + ----------- + cfg : DictConfig + Hydra configuration object + """ + job_name = cfg.runspec.job_name + dataset_dir = get_dataset_dir(cfg) + + logger.info(f"Job name: {job_name}") + logger.info(f"Dataset directory: {dataset_dir}")
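
The `EarlyStopping` helper added in `utils.py` deliberately splits epoch counting (`step()`) from improvement checks (`check_improvement()`), which is what lets `Trainer.train()` tick it every epoch while only evaluating improvement at the validation frequency. Below is a minimal standalone sketch of that contract; the loss values are synthetic and the bare `utils` import assumes the script is run from `examples/reservoir_simulation/xmgn/src`, so treat both as illustrative assumptions rather than part of the example.

```python
# Minimal sketch of the EarlyStopping contract used by the trainer.
# The loss values are made up; `utils` is assumed importable from xmgn/src.
from utils import EarlyStopping

stopper = EarlyStopping(patience=3, min_delta=1e-6)
losses = [0.90, 0.80, 0.79, 0.79, 0.79, 0.79]  # pretend per-epoch validation losses

for epoch, val_loss in enumerate(losses, start=1):
    stopper.step()  # count every epoch, as Trainer.train() does
    improved = stopper.check_improvement(val_loss)  # in the trainer this runs only at validation_freq
    print(
        f"epoch {epoch}: loss={val_loss:.3f} improved={improved} "
        f"since_improvement={stopper.epochs_since_improvement}"
    )
    if stopper.should_stop():
        print(f"early stop at epoch {epoch}, best={stopper.best_score:.3f}")
        break
```

With the values above, the counter only accumulates once the loss plateaus, so the sketch stops after three epochs without improvement, mirroring the trainer's behavior.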
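
`validate_epoch` and `_check_early_stopping` rely on two collective patterns: averaging scalar metrics across ranks with `all_reduce(..., op=ReduceOp.AVG)`, and broadcasting the rank-0 stop decision to every rank. The sketch below condenses those patterns into standalone helpers; the function names `sync_metrics` and `broadcast_stop` are hypothetical, and the reductions are skipped when no process group is initialized so the snippet also runs single-process.

```python
# Hypothetical helpers mirroring the collective patterns in validate_epoch
# and _check_early_stopping; a sketch, not part of the example code itself.
import torch
import torch.distributed as dist


def _multi_rank() -> bool:
    return dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1


def sync_metrics(metrics: dict, device: torch.device) -> dict:
    """Average a dict of scalar metrics across ranks (no-op single-process)."""
    if not _multi_rank():
        return metrics
    keys = sorted(metrics.keys())  # fixed ordering so every rank packs the same tensor
    values = torch.tensor([metrics[k] for k in keys], device=device)
    dist.all_reduce(values, op=dist.ReduceOp.AVG)
    return {k: values[i].item() for i, k in enumerate(keys)}


def broadcast_stop(should_stop: bool, device: torch.device) -> bool:
    """Rank 0 decides whether to stop early; every rank receives the decision."""
    if not _multi_rank():
        return should_stop
    flag = torch.tensor(int(should_stop), device=device)
    dist.broadcast(flag, src=0)
    return bool(flag.item())
```

Sorting the keys before packing the tensor is what keeps the reduction well-defined: every rank must contribute its metrics in the same order, which is why the trainer builds `metric_keys = sorted(metrics.keys())` before the `all_reduce`.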
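
`get_dataset_dir` derives a job-specific dataset root from `dataset.sim_dir` plus a `.dataset` suffix and `runspec.job_name`, and `get_dataset_paths` expands that root into the graph, partition, and stats locations the trainer expects. The sketch below resolves those paths from a hand-built config outside of a Hydra run; the `sim_dir` and `job_name` values are placeholders, and the bare `utils` import again assumes execution from `xmgn/src`.

```python
# Sketch: resolving dataset locations outside of a Hydra run.
# Config values are placeholders; only the key names
# (runspec.job_name, dataset.sim_dir) match what utils.py requires.
from omegaconf import OmegaConf
from utils import get_dataset_paths

cfg = OmegaConf.create(
    {
        "runspec": {"job_name": "norne_demo"},
        "dataset": {"sim_dir": "data/NORNE"},
    }
)

paths = get_dataset_paths(cfg)
# e.g. paths["dataset_dir"] -> <cwd>/data/NORNE.dataset/norne_demo
for name, path in paths.items():
    print(f"{name}: {path}")
```

Because `to_absolute_path` falls back to the current working directory when Hydra is not initialized, the same helpers can be reused by preprocessing or inference scripts without going through `@hydra.main`.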