diff --git a/README.md b/README.md
index 79f1c309..9ea3704e 100755
--- a/README.md
+++ b/README.md
@@ -33,6 +33,7 @@ ## 📑 Task
 - [MLIP-Machine Learning Interatomic Potential](interatomic_potentials/README.md)
 - [MLES-Machine Learning Electronic Structure](electronic_structure/README.md)
+- [SPEN-Spectrum Enhancement](spectrum_enhancement/README.md)
 - [PP-Property Prediction](property_prediction/README.md)
 - [SG-Structure Generation](structure_generation/README.md)
 - [SE-Spectrum Elucidation](spectrum_elucidation/README.md)
diff --git a/ppmat/datasets/__init__.py b/ppmat/datasets/__init__.py
index 98eec451..564654c0 100644
--- a/ppmat/datasets/__init__.py
+++ b/ppmat/datasets/__init__.py
@@ -42,10 +42,11 @@
 from ppmat.datasets.msd_nmr_dataset import MSDnmrDataset
 from ppmat.datasets.msd_nmr_dataset import MSDnmrinfos
 from ppmat.datasets.density_dataset import DensityDataset
-from ppmat.datasets.small_density_dataset import SmallDensityDataset
+from ppmat.datasets.small_density_dataset import SmallDensityDataset
+from ppmat.datasets.stem_image_dataset import STEMImageDataset
 from ppmat.datasets.num_atom_crystal_dataset import NumAtomsCrystalDataset
 from ppmat.datasets.oc20_s2ef_dataset import OC20S2EFDataset  # noqa
-from ppmat.datasets.qm9_dataset import QM9Dataset  # noqa
+from ppmat.datasets.qm9_dataset import QM9Dataset  # noqa
 from ppmat.datasets.omol25_dataset import OMol25Dataset
 from ppmat.datasets.split_mptrj_data import none_to_zero
 from ppmat.datasets.transform import build_transforms
@@ -64,8 +65,9 @@
     "HighLevelWaterDataset",
     "MSDnmrDataset",
     "MatbenchDataset",
-    "DensityDataset",
+    "DensityDataset",
     "SmallDensityDataset",
+    "STEMImageDataset",
     "OMol25Dataset",
 ]
diff --git a/ppmat/datasets/build_stem.py b/ppmat/datasets/build_stem.py
new file mode 100644
index 00000000..3560b8a4
--- /dev/null
+++ b/ppmat/datasets/build_stem.py
@@ -0,0 +1,297 @@
+# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import copy
+import importlib
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+from ppmat.utils import download as download_utils
+from ppmat.utils import logger
+
+
+def _locate_class(class_name: str):
+    # Dotted names resolve through importlib; bare names resolve in this module.
+    if "." in class_name:
+        mod, cls = class_name.rsplit(".", 1)
+        return getattr(importlib.import_module(mod), cls)
+    return globals()[class_name]
+
+
+def _parse_factory_cfg(
+    cfg: Optional[Dict[str, Any] | str],
+    *,
+    default_class_name: str,
+) -> Tuple[str, Dict[str, Any]]:
+    """
+    Parse factory config in a backward/forward-compatible way.
+
+    Supported patterns:
+    1) None
+       -> use default class with empty params
+    2) "ClassName" / "pkg.mod.ClassName"
+       -> use class name string with empty params
+    3) {"__class_name__": "...", "__init_params__": {...}}
+       -> current canonical style (same as many ppmat builders)
+    4) {"class_name": "...", "init_params": {...}}
+       -> compatibility alias
+    5) {"type": "...", "params": {...}}
+       -> compatibility alias
+    """
+    if cfg is None:
+        return default_class_name, {}
+
+    if isinstance(cfg, str):
+        return cfg, {}
+
+    if not isinstance(cfg, dict):
+        raise TypeError(
+            "Factory cfg must be None, str, or dict, "
+            f"but got type={type(cfg).__name__}."
+        )
+
+    cfg = copy.deepcopy(cfg)
+    class_name = (
+        cfg.pop("__class_name__", None)
+        or cfg.pop("class_name", None)
+        or cfg.pop("type", None)
+    )
+    if not class_name:
+        raise ValueError(
+            "Factory cfg must include class name key, e.g. "
+            "{'__class_name__': 'StrictIndexSampleBuilder', '__init_params__': {...}}."
+        )
+
+    init_params = (
+        cfg.pop("__init_params__", None)
+        if "__init_params__" in cfg
+        else cfg.pop("init_params", None)
+        if "init_params" in cfg
+        else cfg.pop("params", None)
+        if "params" in cfg
+        else {}
+    )
+    if init_params is None:
+        init_params = {}
+    if not isinstance(init_params, dict):
+        raise TypeError(
+            f"Factory init params must be dict, but got type={type(init_params).__name__}."
+        )
+
+    if cfg:
+        raise ValueError(
+            f"Unsupported keys in factory cfg for '{class_name}': {list(cfg.keys())}"
+        )
+    return class_name, init_params
+
+
+def _build_component(
+    cfg: Optional[Dict[str, Any] | str],
+    *,
+    default_class_name: str,
+    required_methods: List[str],
+):
+    class_name, init_params = _parse_factory_cfg(
+        cfg,
+        default_class_name=default_class_name,
+    )
+
+    cls = _locate_class(class_name)
+    component = cls(**init_params)
+    for method_name in required_methods:
+        if not hasattr(component, method_name):
+            raise TypeError(
+                f"Component '{class_name}' must implement method '{method_name}'."
+            )
+    return component
+
+
+class StrictIndexSampleBuilder:
+    def build(
+        self,
+        noisy_dir: Path,
+        target_dir: Path,
+        file_suffix: str,
+        data_count: Optional[int] = None,
+    ) -> List[Dict[str, str]]:
+        samples: List[Dict[str, str]] = []
+        if data_count is None:
+            available = sorted(noisy_dir.glob(f"*{file_suffix}"))
+            data_count = len(available)
+
+        for idx in range(int(data_count)):
+            name = f"{idx}{file_suffix}"
+            noisy_path = noisy_dir / name
+            target_path = target_dir / name
+            if not noisy_path.exists():
+                raise FileNotFoundError(f"Noisy image not found: {noisy_path}")
+            if not target_path.exists():
+                raise FileNotFoundError(f"Target image not found: {target_path}")
+            samples.append(
+                {
+                    "name": name,
+                    "noisy_path": str(noisy_path),
+                    "target_path": str(target_path),
+                }
+            )
+        return samples
+
+
+class MatchedNameSampleBuilder:
+    def build(
+        self,
+        noisy_dir: Path,
+        target_dir: Path,
+        file_suffix: str,
+        data_count: Optional[int] = None,
+    ) -> List[Dict[str, str]]:
+        samples: List[Dict[str, str]] = []
+        noisy_files = {
+            p.name: p for p in noisy_dir.glob(f"*{file_suffix}") if p.is_file()
+        }
+        target_files = {
+            p.name: p for p in target_dir.glob(f"*{file_suffix}") if p.is_file()
+        }
+        common_names = sorted(set(noisy_files.keys()) & set(target_files.keys()))
+        if data_count is not None:
+            common_names = common_names[: int(data_count)]
+
+        for name in common_names:
+            samples.append(
+                {
+                    "name": name,
+                    "noisy_path": str(noisy_files[name]),
+                    "target_path": str(target_files[name]),
+                }
+            )
+        return samples
+
+
+class DefaultSTEMDatasetDownloader:
+    def __init__(self, datasets_home: Optional[str] = None):
+        self.datasets_home = datasets_home or download_utils.DATASETS_HOME
+
+    def download(
+        self, url: str, md5: Optional[str] = None, force_download: bool = False
+    ) -> Path:
+        if force_download:
+            downloaded_root = download_utils.get_path_from_url(
+                url,
+                self.datasets_home,
+                md5sum=md5,
+                check_exist=False,
+                decompress=True,
+            )
+        else:
+            downloaded_root = download_utils.get_datasets_path_from_url(url, md5)
+        return Path(downloaded_root)
+
+
+class PairDirectoryDataRootResolver:
+    def __init__(self, max_depth: int = 2):
+        if max_depth < 0:
+            raise ValueError(f"max_depth must be >= 0, but got {max_depth}")
+        self.max_depth = int(max_depth)
+
+    @staticmethod
+    def _contains_pair_dirs(root: Path, noisy_subdir: str, target_subdir: str) -> bool:
+        return (
+            root.is_dir()
+            and (root / noisy_subdir).exists()
+            and (root / target_subdir).exists()
+        )
+
+    def find_data_roots(
+        self,
+        base_root: Path,
+        split: Optional[str],
+        noisy_subdir: str,
+        target_subdir: str,
+    ) -> List[Path]:
+        if not base_root.exists():
+            return []
+
+        candidate_roots: List[Path] = [base_root]
+        frontier: List[Path] = [base_root]
+        for _ in range(self.max_depth):
+            next_frontier: List[Path] = []
+            for root in frontier:
+                for child in root.iterdir():
+                    if child.is_dir():
+                        candidate_roots.append(child)
+                        next_frontier.append(child)
+            frontier = next_frontier
+
+        matches: List[Path] = []
+        for root in candidate_roots:
+            if split is not None:
+                split_root = root / split
+                if self._contains_pair_dirs(split_root, noisy_subdir, target_subdir):
+                    matches.append(split_root)
+            if self._contains_pair_dirs(root, noisy_subdir, target_subdir):
+                matches.append(root)
+
+        seen = set()
+        unique_matches = []
+        for path in matches:
+            path_str = str(path)
+            if path_str in seen:
+                continue
+            seen.add(path_str)
+            unique_matches.append(path)
+        return unique_matches
+
+
+def build_stem_sample_builder(
+    cfg: Optional[Dict[str, Any] | str],
+    *,
+    strict_index_naming: bool,
+):
+    default_class_name = (
+        "StrictIndexSampleBuilder"
+        if strict_index_naming
+        else "MatchedNameSampleBuilder"
+    )
+    sample_builder = _build_component(
+        cfg,
+        default_class_name=default_class_name,
+        required_methods=["build"],
+    )
+    logger.debug(f"Use sample builder: {sample_builder.__class__.__name__}")
+    return sample_builder
+
+
+def build_stem_downloader(cfg: Optional[Dict[str, Any] | str]):
+    downloader = _build_component(
+        cfg,
+        default_class_name="DefaultSTEMDatasetDownloader",
+        required_methods=["download"],
+    )
+    logger.debug(f"Use downloader: {downloader.__class__.__name__}")
+    return downloader
+
+
+def build_stem_data_root_resolver(cfg: Optional[Dict[str, Any] | str]):
+    resolver = _build_component(
+        cfg,
+        default_class_name="PairDirectoryDataRootResolver",
+        required_methods=["find_data_roots"],
+    )
+    logger.debug(f"Use data root resolver: {resolver.__class__.__name__}")
+    return resolver
diff --git a/ppmat/datasets/stem_image_dataset.py b/ppmat/datasets/stem_image_dataset.py
new file mode 100644
index 00000000..88d530de
--- /dev/null
+++ b/ppmat/datasets/stem_image_dataset.py
@@ -0,0 +1,296 @@
+# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Dict
+from typing import List
+from typing import Optional
+
+import numpy as np
+import paddle
+from PIL import Image
+
+from ppmat.datasets.build_stem import build_stem_data_root_resolver
+from ppmat.datasets.build_stem import build_stem_downloader
+from ppmat.datasets.build_stem import build_stem_sample_builder
+from ppmat.utils import logger
+
+
+class STEMImageDataset(paddle.io.Dataset):
+    """Dataset for paired STEM image restoration/enhancement.
+
+    Supports automatic download and extraction (zip/tar/tar.gz) through
+    ``ppmat.utils.download.get_datasets_path_from_url``.
+
+    Expected directory layout after extraction:
+        data_path/
+            train/
+                noisy/
+                    0.png
+                    1.png
+                    ...
+                gt_enhance/
+                    0.png
+                    1.png
+                    ...
+            val/
+                noisy/
+                    ...
+                gt_enhance/
+                    ...
+            test/
+                noisy/
+                    ...
+                gt_enhance/
+                    ...
+
+    Or legacy format (backward compatible):
+        data_path/
+            noisy/
+                0.png
+                ...
+            gt_enhance/
+                0.png
+                ...
+    """
+
+    name = "stem_enhancement"
+    url = None
+    md5 = None
+    _DEFAULT_URL_MAP = {
+        "data": "https://paddle-org.bj.bcebos.com/paddlematerials/datasets/SFIN_datasets/haadf_data.zip",
+        "data_test": "https://paddle-org.bj.bcebos.com/paddlematerials/datasets/SFIN_datasets/haadf_data_test.zip",
+        "bf_data": "https://paddle-org.bj.bcebos.com/paddlematerials/datasets/SFIN_datasets/bf_data.zip",
+        "bf_data_test": "https://paddle-org.bj.bcebos.com/paddlematerials/datasets/SFIN_datasets/bf_data_test.zip",
+    }
+    _DEFAULT_MD5_MAP: Dict[str, str] = {}
+
+    def __init__(
+        self,
+        data_path: str,
+        split: Optional[str] = None,
+        data_count: int | None = None,
+        noisy_subdir: str = "noisy",
+        target_subdir: str = "gt_enhance",
+        file_suffix: str = ".png",
+        strict_index_naming: bool = True,
+        sample_builder_cfg: Optional[Dict] = None,
+        downloader_cfg: Optional[Dict] = None,
+        data_root_resolver_cfg: Optional[Dict] = None,
+        scale_to_unit: bool = False,
+        url: Optional[str] = None,
+        md5: Optional[str] = None,
+        download: bool = True,
+        force_download: bool = False,
+    ):
+        """Initialize STEMImageDataset.
+
+        Args:
+            data_path: Root directory for the dataset.
+            split: Dataset split, one of 'train', 'val', 'test', or None.
+                If None, uses legacy format without split subdirectories.
+            data_count: Maximum number of samples to load. None means all.
+            noisy_subdir: Subdirectory name for noisy images.
+            target_subdir: Subdirectory name for target/ground truth images.
+            file_suffix: File extension for images (e.g., '.png', '.tif').
+            strict_index_naming: If True, expects files named as {idx}{suffix}.
+            sample_builder_cfg: Sample builder config in factory style:
+                {
+                    "__class_name__": "StrictIndexSampleBuilder",
+                    "__init_params__": {}
+                }
+            downloader_cfg: Downloader config in factory style.
+                Default class is "DefaultSTEMDatasetDownloader".
+            data_root_resolver_cfg: Data root resolver config in factory style.
+                Default class is "PairDirectoryDataRootResolver".
+            scale_to_unit: If True, scales pixel values to [0, 1].
+            url: URL to download the dataset from. Overrides default URL.
+            md5: MD5 checksum for downloaded file. Optional.
+            download: Whether to automatically download if data not found.
+            force_download: If True, re-download even if data exists.
+        """
+        super().__init__()
+
+        self.split = split
+        self.data_count = data_count
+        self.noisy_subdir = noisy_subdir
+        self.target_subdir = target_subdir
+        self.file_suffix = file_suffix
+        self.strict_index_naming = strict_index_naming
+        self.scale_to_unit = scale_to_unit
+        self.sample_builder = build_stem_sample_builder(
+            sample_builder_cfg,
+            strict_index_naming=strict_index_naming,
+        )
+        self.downloader = build_stem_downloader(downloader_cfg)
+        self.data_root_resolver = build_stem_data_root_resolver(data_root_resolver_cfg)
+
+        self.url = url if url is not None else self._infer_default_url(data_path)
+        self.md5 = (
+            md5 if md5 is not None else self._infer_default_md5(data_path) or self.md5
+        )
+        self.data_root = Path(data_path)
+        self.downloaded_root: Optional[Path] = None
+
+        if self._locate_data_root(self.data_root) is None and (
+            download or force_download
+        ):
+            self.downloaded_root = self._download_dataset(force_download)
+
+        # Determine actual data directory based on split
+        self.data_dir = self._resolve_data_dir()
+
+        # Set up noisy and target directories
+        self.noisy_dir = self.data_dir / noisy_subdir
+        self.target_dir = self.data_dir / target_subdir
+
+        if not self.noisy_dir.exists():
+            raise FileNotFoundError(f"Noisy directory not found: {self.noisy_dir}")
+        if not self.target_dir.exists():
+            raise FileNotFoundError(f"Target directory not found: {self.target_dir}")
+
+        self.samples = self._build_samples(data_count)
+
+    def _resolve_data_dir(self) -> Path:
+        """Resolve the actual data directory based on split configuration."""
+        candidate_roots = [self.data_root]
+        if self.downloaded_root is not None:
+            for root in [self.downloaded_root, self.downloaded_root.parent]:
+                if root != self.data_root and root not in candidate_roots:
+                    candidate_roots.append(root)
+
+        for candidate_root in candidate_roots:
+            matches = self._find_data_roots(candidate_root)
+            if not matches:
+                continue
+            if (
+                self.downloaded_root is not None
+                and candidate_root == self.downloaded_root.parent
+                and len(matches) > 1
+            ):
+                raise FileNotFoundError(
+                    "Multiple candidate dataset roots were found under "
+                    f"'{candidate_root}': {[str(m) for m in matches]}. "
+                    "Please provide a more specific local `data_path` or explicit `url`."
+                )
+
+            data_root_candidate = matches[0]
+            if self.split is not None and data_root_candidate == candidate_root:
+                logger.warning(
+                    f"Split '{self.split}' requested but legacy format detected. "
+                    f"Using data directly from {candidate_root}"
+                )
+            return data_root_candidate
+
+        searched_roots = ", ".join([str(path) for path in candidate_roots])
+        if self.split is not None:
+            raise FileNotFoundError(
+                f"Split '{self.split}' not found under: {searched_roots}"
+            )
+        raise FileNotFoundError(
+            "Cannot locate dataset directories "
+            f"'{self.noisy_subdir}' and '{self.target_subdir}' under: {searched_roots}"
+        )
+
+    def _download_dataset(self, force_download: bool = False) -> Path:
+        """Download dataset with built-in ppmat factory utility."""
+        if not self.url:
+            candidate = ", ".join(sorted(self._DEFAULT_URL_MAP.keys()))
+            raise FileNotFoundError(
+                f"Dataset not found at '{self.data_root}', and no download URL provided. "
+                f"Auto-url is only inferred for data_path basename in [{candidate}]."
+            )
+
+        logger.message(
+            f"Dataset root {self.data_root} not found. Will download it now."
+        )
+        downloaded_root = self.downloader.download(
+            self.url,
+            self.md5,
+            force_download=force_download,
+        )
+        logger.info(f"Dataset downloaded to: {downloaded_root}")
+        return Path(downloaded_root)
+
+    @classmethod
+    def _infer_default_url(cls, data_path: str) -> Optional[str]:
+        key = Path(data_path).name
+        url = cls._DEFAULT_URL_MAP.get(key)
+        if url is not None:
+            logger.info(
+                f"Infer dataset download URL by data_path='{data_path}': {url}"
+            )
+        return url
+
+    @classmethod
+    def _infer_default_md5(cls, data_path: str) -> Optional[str]:
+        key = Path(data_path).name
+        md5 = cls._DEFAULT_MD5_MAP.get(key)
+        if md5 is not None:
+            logger.info(
+                f"Infer dataset md5 by data_path='{data_path}': {md5}"
+            )
+        return md5
+
+    def _locate_data_root(self, base_root: Path) -> Optional[Path]:
+        matches = self._find_data_roots(base_root)
+        if not matches:
+            return None
+        return matches[0]
+
+    def _find_data_roots(self, base_root: Path) -> List[Path]:
+        return self.data_root_resolver.find_data_roots(
+            base_root=base_root,
+            split=self.split,
+            noisy_subdir=self.noisy_subdir,
+            target_subdir=self.target_subdir,
+        )
+
+    def _build_samples(self, data_count: int | None) -> List[Dict[str, str]]:
+        """Build list of sample dictionaries."""
+        return self.sample_builder.build(
+            noisy_dir=self.noisy_dir,
+            target_dir=self.target_dir,
+            file_suffix=self.file_suffix,
+            data_count=data_count,
+        )
+
+    def __len__(self):
+        return len(self.samples)
+
+    def _load_gray_image(self, path: str) -> paddle.Tensor:
+        """Load image as grayscale tensor."""
+        image = Image.open(path).convert("L")
+        image_array = np.asarray(image, dtype=np.float32)
+        if self.scale_to_unit:
+            image_array = image_array / 255.0
+        return paddle.to_tensor(image_array).unsqueeze(0)
+
+    def __getitem__(self, idx: int):
+        sample = self.samples[idx]
+        noisy = self._load_gray_image(sample["noisy_path"])
+        target = self._load_gray_image(sample["target_path"])
+
+        output = {
+            "noisy": noisy,
+            self.target_subdir: target,
+            "target": target,
+            "name": sample["name"],
+        }
+        # Backward compatibility for legacy code paths that read `gt_enhance`.
+        if self.target_subdir != "gt_enhance":
+            output["gt_enhance"] = target
+        return output
diff --git a/ppmat/metrics/__init__.py b/ppmat/metrics/__init__.py
index a0e3fb75..a304809f 100644
--- a/ppmat/metrics/__init__.py
+++ b/ppmat/metrics/__init__.py
@@ -18,11 +18,19 @@
 from ppmat.metrics.csp_metric import CSPMetric
 from ppmat.metrics.diffnmr_streaming_adapter import DiffNMRStreamingAdapter
+from ppmat.metrics.sfin_metric import PSNRMetric
+from ppmat.metrics.sfin_metric import SSIMMetric
+from ppmat.metrics.sfin_metric import calc_psnr
+from ppmat.metrics.sfin_metric import calc_ssim
 
 __all__ = [
     "build_metric",
     "CSPMetric",
     "DiffNMRStreamingAdapter",
+    "PSNRMetric",
+    "SSIMMetric",
+    "calc_psnr",
+    "calc_ssim",
     # "DiffNMRMetric",
     # "NLL",
     "CrossEntropyMetric", "SumExceptBatchMetric", "SumExceptBatchKL",
 ]
diff --git a/ppmat/metrics/sfin_metric.py b/ppmat/metrics/sfin_metric.py
new file mode 100644
index 00000000..3a26c9f5
--- /dev/null
+++ b/ppmat/metrics/sfin_metric.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import paddle
+
+
+def calc_psnr(
+    pred: paddle.Tensor,
+    label: paddle.Tensor,
+    data_range: float = 255.0,
+    eps: float = 1e-12,
+) -> paddle.Tensor:
+    """Compute batch PSNR for image tensors with shape [N, C, H, W]."""
+    pred = pred.astype("float64")
+    label = label.astype("float64")
+    diff = (pred - label) / data_range
+    mse = paddle.mean(diff * diff)
+    mse = paddle.maximum(mse, paddle.to_tensor(eps, dtype=mse.dtype))
+    return -10.0 * paddle.log10(mse)
+
+
+def _gaussian_window_2d(
+    channels: int,
+    win_size: int = 11,
+    win_sigma: float = 1.5,
+    dtype: str = "float32",
+) -> paddle.Tensor:
+    coords = paddle.arange(win_size, dtype=dtype) - (win_size // 2)
+    gauss = paddle.exp(-(coords**2) / (2.0 * (win_sigma**2)))
+    gauss = gauss / paddle.sum(gauss)
+    window_2d = gauss.unsqueeze(1) * gauss.unsqueeze(0)
+    window = window_2d.reshape([1, 1, win_size, win_size])
+    return paddle.tile(window, [channels, 1, 1, 1])
+
+
+def calc_ssim(
+    pred: paddle.Tensor,
+    label: paddle.Tensor,
+    data_range: float = 255.0,
+    win_size: int = 11,
+    win_sigma: float = 1.5,
+    k1: float = 0.01,
+    k2: float = 0.03,
+    nonnegative_ssim: bool = False,
+) -> paddle.Tensor:
+    """Compute batch SSIM for image tensors with shape [N, C, H, W]."""
+    if pred.shape != label.shape:
+        raise ValueError(
+            f"Input images should have the same dimensions, got {pred.shape} and {label.shape}."
+        )
+    if len(pred.shape) != 4:
+        raise ValueError(
+            f"Input images should be 4-d tensors [N, C, H, W], got shape {pred.shape}."
+        )
+    if win_size % 2 != 1:
+        raise ValueError("win_size must be odd.")
+
+    pred = pred.astype("float32")
+    label = label.astype("float32")
+
+    channels = pred.shape[1]
+    window = _gaussian_window_2d(
+        channels=channels,
+        win_size=win_size,
+        win_sigma=win_sigma,
+        dtype=pred.dtype,
+    )
+
+    mu1 = paddle.nn.functional.conv2d(pred, window, stride=1, padding=0, groups=channels)
+    mu2 = paddle.nn.functional.conv2d(label, window, stride=1, padding=0, groups=channels)
+
+    mu1_sq = mu1 * mu1
+    mu2_sq = mu2 * mu2
+    mu1_mu2 = mu1 * mu2
+
+    sigma1_sq = paddle.nn.functional.conv2d(
+        pred * pred, window, stride=1, padding=0, groups=channels
+    ) - mu1_sq
+    sigma2_sq = paddle.nn.functional.conv2d(
+        label * label, window, stride=1, padding=0, groups=channels
+    ) - mu2_sq
+    sigma12 = paddle.nn.functional.conv2d(
+        pred * label, window, stride=1, padding=0, groups=channels
+    ) - mu1_mu2
+
+    c1 = (k1 * data_range) ** 2
+    c2 = (k2 * data_range) ** 2
+    c1 = paddle.to_tensor(c1, dtype=pred.dtype)
+    c2 = paddle.to_tensor(c2, dtype=pred.dtype)
+
+    cs_map = (2.0 * sigma12 + c2) / (sigma1_sq + sigma2_sq + c2)
+    ssim_map = ((2.0 * mu1_mu2 + c1) / (mu1_sq + mu2_sq + c1)) * cs_map
+
+    if nonnegative_ssim:
+        ssim_map = paddle.nn.functional.relu(ssim_map)
+
+    return paddle.mean(ssim_map)
+
+
+class PSNRMetric:
+    def __init__(self, data_range: float = 255.0, eps: float = 1e-12):
+        self.data_range = data_range
+        self.eps = eps
+
+    def __call__(self, pred: paddle.Tensor, label: paddle.Tensor):
+        return calc_psnr(
+            pred=pred,
+            label=label,
+            data_range=self.data_range,
+            eps=self.eps,
+        )
+
+
+class SSIMMetric:
+    def __init__(
+        self,
+        data_range: float = 255.0,
+        win_size: int = 11,
+        win_sigma: float = 1.5,
+        k1: float = 0.01,
+        k2: float = 0.03,
+        nonnegative_ssim: bool = False,
+    ):
+        self.data_range = data_range
+        self.win_size = win_size
+        self.win_sigma = win_sigma
+        self.k1 = k1
+        self.k2 = k2
+        self.nonnegative_ssim = nonnegative_ssim
+
+    def __call__(self, pred: paddle.Tensor, label: paddle.Tensor):
+        return calc_ssim(
+            pred=pred,
+            label=label,
+            data_range=self.data_range,
+            win_size=self.win_size,
+            win_sigma=self.win_sigma,
+            k1=self.k1,
+            k2=self.k2,
+            nonnegative_ssim=self.nonnegative_ssim,
+        )
diff --git a/ppmat/models/__init__.py b/ppmat/models/__init__.py
index 95d73232..130704ec 100644
--- a/ppmat/models/__init__.py
+++ b/ppmat/models/__init__.py
@@ -42,6 +42,7 @@
 from ppmat.models.megnet.megnet import MEGNetPlus
 from ppmat.models.infgcn.infgcn import InfGCN
 from ppmat.models.mateno.mateno import MatENO
+from ppmat.models.sfin.sfin import SFIN
 from ppmat.utils import download
 from ppmat.utils import logger
 from ppmat.utils import save_load
@@ -67,6 +68,7 @@
     "DiffNMR",
     "InfGCN",
     "MatENO",
+    "SFIN",
 ]
 
 # Warning: The key of the dictionary must be consistent with the file name of the value
diff --git a/ppmat/models/sfin/__init__.py b/ppmat/models/sfin/__init__.py
new file mode 100644
index 00000000..d890dba7
--- /dev/null
+++ b/ppmat/models/sfin/__init__.py
@@ -0,0 +1,24 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ppmat.models.sfin.sfin import SFIN, FourierUnit, SpectralTransform, FFC, SFIB, ResnetBlock
+
+__all__ = [
+    "SFIN",
+    "FourierUnit",
+    "SpectralTransform",
+    "FFC",
+    "SFIB",
+    "ResnetBlock",
+]
diff --git a/ppmat/models/sfin/sfin.py b/ppmat/models/sfin/sfin.py
new file mode 100644
index 00000000..03bf2823
--- /dev/null
+++ b/ppmat/models/sfin/sfin.py
@@ -0,0 +1,390 @@
+# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+SFIN: Noise Calibration and Spatial-Frequency Interactive Network for STEM Image Enhancement
+Paper: CVPR 2025 - https://arxiv.org/pdf/2504.02555
+"""
+
+from typing import Dict
+
+import paddle
+import paddle.nn as nn
+
+# BatchNorm semantic alignment:
+#   PyTorch: running = (1 - m_torch) * running + m_torch * batch, default m_torch=0.1
+#   Paddle:  running = m_paddle * running + (1 - m_paddle) * batch
+#   so m_paddle = 1 - m_torch = 0.9
+TORCH_BN_MOMENTUM = 0.1
+PADDLE_BN_MOMENTUM = 1.0 - TORCH_BN_MOMENTUM
+BN_EPSILON = 1e-5
+
+
+def _bn_aligned(num_features: int) -> nn.BatchNorm2D:
+    """Create BatchNorm2D with PyTorch-aligned momentum semantics."""
+    return nn.BatchNorm2D(num_features, momentum=PADDLE_BN_MOMENTUM, epsilon=BN_EPSILON)
+
+
+class FourierUnit(nn.Layer):
+    """Fourier Unit for processing frequency domain features."""
+
+    def __init__(self, in_channels: int, out_channels: int):
+        super(FourierUnit, self).__init__()
+        fu_bound = 1.0 / ((out_channels * 2) * 1 * 1) ** 0.5
+
+        self.conv_layer = nn.Conv2D(
+            in_channels=in_channels * 2 + 2,
+            out_channels=out_channels * 2,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias_attr=False,
+            weight_attr=paddle.nn.initializer.KaimingUniform(
+                negative_slope=5**0.5,  # a=sqrt(5)
+                mode='fan_in',
+                nonlinearity='leaky_relu'
+            )
+        )
+        self.bn = _bn_aligned(out_channels * 2)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        batch = x.shape[0]
+        fft_dim = (-2, -1)
+
+        # Real FFT with ortho normalization
+        ffted = paddle.fft.rfftn(x, axes=fft_dim, norm='ortho')  # (B, C, H, W/2+1) complex
+
+        # Split into real/imaginary parts
+        ffted_real = paddle.real(ffted)  # (B, C, H, W/2+1)
+        ffted_imag = paddle.imag(ffted)  # (B, C, H, W/2+1)
+        ffted = paddle.stack([ffted_real, ffted_imag], axis=-1)  # (B, C, H, W/2+1, 2)
+
+        # Permute to (B, C, 2, H, W/2+1)
+        ffted = ffted.transpose([0, 1, 4, 2, 3])  # (B, C, 2, H, W/2+1)
+        ffted = ffted.reshape([batch, -1] + list(ffted.shape[3:]))  # (B, C*2, H, W/2+1)
+
+        # Create coordinate grids with DYNAMIC batch handling (critical fix)
+        height, width = ffted.shape[-2:]
+        coords_vert = paddle.linspace(0, 1, height).reshape([1, 1, height, 1])
+        # SAFE EXPAND: Use x.shape[0] instead of captured 'batch' variable
+        coords_vert = coords_vert.expand([x.shape[0], 1, height, width])
+
+        coords_hor = paddle.linspace(0, 1, width).reshape([1, 1, 1, width])
+        coords_hor = coords_hor.expand([x.shape[0], 1, height, width])
+
+        # Concatenate coordinates and FFT features
+        ffted = paddle.concat([coords_vert, coords_hor, ffted], axis=1)  # (B, C*2+2, H, W/2+1)
+
+        # Process through convolution
+        ffted = self.conv_layer(ffted)
+        ffted = self.relu(self.bn(ffted))  # (B, C*2, H, W/2+1)
+
+        # Reshape back to complex format: (B, C, 2, H, W/2+1) → (B, C, H, W/2+1, 2)
+        ffted = ffted.reshape([batch, -1, 2] + list(ffted.shape[2:]))  # (B, C, 2, H, W/2+1)
+        ffted = ffted.transpose([0, 1, 3, 4, 2])  # (B, C, H, W/2+1, 2)
+
+        # Convert back to complex tensor
+        ffted = paddle.complex(ffted[..., 0], ffted[..., 1])  # (B, C, H, W/2+1) complex
+
+        # Inverse FFT with exact shape matching
+        output = paddle.fft.irfftn(ffted, s=x.shape[-2:], axes=fft_dim, norm='ortho')
+        return output
+
+
+class SpectralTransform(nn.Layer):
+    """Spectral Transform block combining spatial and frequency domain processing."""
+
+    def __init__(self, in_channels: int):
+        super(SpectralTransform, self).__init__()
+        st1_fan_in = (in_channels // 2) * 3 * 3
+        st1_bias_bound = 1.0 / st1_fan_in ** 0.5
+
+        st2_fan_in = in_channels * 3 * 3  # Input: [x (C/2), x2 (C/2)] → C channels
+        st2_bias_bound = 1.0 / st2_fan_in ** 0.5
+
+        self.conv1 = nn.Conv2D(
+            in_channels // 2, in_channels // 2, 3, padding=1,
+            weight_attr=nn.initializer.KaimingUniform(
+                negative_slope=5**0.5,
+                mode='fan_in',
+                nonlinearity='leaky_relu'
+            ),
+            bias_attr=nn.initializer.Uniform(-st1_bias_bound, st1_bias_bound)
+        )
+        self.fu = FourierUnit(in_channels // 2, in_channels // 2)
+        self.conv2 = nn.Conv2D(
+            in_channels, in_channels // 2, 3, padding=1,
+            weight_attr=nn.initializer.KaimingUniform(
+                negative_slope=5**0.5,
+                mode='fan_in',
+                nonlinearity='leaky_relu'
+            ),
+            bias_attr=nn.initializer.Uniform(-st2_bias_bound, st2_bias_bound)
+        )
+
+    def forward(self, x):
+        x1 = self.conv1(x)
+        x2 = self.fu(x1)
+        x = self.conv2(paddle.concat([x, x2], axis=1))
+        return x
+
+
+class FFC(nn.Layer):
+    """Fast Fourier Convolution block for spatial-frequency interaction."""
+
+    def __init__(self, in_channels: int):
+        super(FFC, self).__init__()
+        ffc_fan_in = (in_channels // 2) * 3 * 3
+        ffc_bias_bound = 1.0 / ffc_fan_in ** 0.5
+
+        self.convl2l = nn.Conv2D(
+            in_channels // 2, in_channels // 2, 3, padding=1,
+            weight_attr=nn.initializer.KaimingUniform(
+                negative_slope=5**0.5,
+                mode='fan_in',
+                nonlinearity='leaky_relu'
+            ),
+            bias_attr=nn.initializer.Uniform(-ffc_bias_bound, ffc_bias_bound)
+        )
+        self.convl2g = nn.Conv2D(
+            in_channels // 2, in_channels // 2, 3, padding=1,
+            weight_attr=nn.initializer.KaimingUniform(
+                negative_slope=5**0.5,
+                mode='fan_in',
+                nonlinearity='leaky_relu'
+            ),
+            bias_attr=nn.initializer.Uniform(-ffc_bias_bound, ffc_bias_bound)
+        )
+        self.convg2l = nn.Conv2D(
+            in_channels // 2, in_channels // 2, 3, padding=1,
+            weight_attr=nn.initializer.KaimingUniform(
+                negative_slope=5**0.5,
+                mode='fan_in',
+                nonlinearity='leaky_relu'
+            ),
+            bias_attr=nn.initializer.Uniform(-ffc_bias_bound, ffc_bias_bound)
+        )
+        self.convg2g = SpectralTransform(in_channels)
+
+    def forward(self, x):
+        if isinstance(x, tuple):
+            x_l, x_g = x
+        else:
+            B, C, H, W = x.shape
+            x_l = x
+            # Must be C//2 channels to match local branch split later
+            x_g = paddle.zeros([B, C // 2, H, W], dtype=x.dtype)
+
+        out_xl = self.convl2l(x_l) + self.convg2l(x_g)
+        out_xg = self.convl2g(x_l) + self.convg2g(x_g)
+        return out_xl, out_xg
+
+
+class SFIB(nn.Layer):
+    """Spatial-Frequency Interactive Block."""
+
+    def __init__(self, in_channels: int):
+        super(SFIB, self).__init__()
+        self.ffc = FFC(in_channels)
+        self.bn_l = _bn_aligned(in_channels // 2)
+        self.bn_g = _bn_aligned(in_channels // 2)
+        self.act_l = nn.ReLU()
+        self.act_g = nn.ReLU()
+
+    def forward(self, x):
+        x_l, x_g = self.ffc(x)
+        x_l = self.act_l(self.bn_l(x_l))
+        x_g = self.act_g(self.bn_g(x_g))
+        return x_l, x_g
+
+
+class ResnetBlock(nn.Layer):
+    """Residual block with SFIB."""
+
+    def __init__(self, in_channels: int):
+        super().__init__()
+        self.in_channels = in_channels
+        self.conv1 = SFIB(in_channels)
+        self.conv2 = SFIB(in_channels)
+
+    def forward(self, x):
+        x_l, x_g = paddle.split(x, [self.in_channels // 2, self.in_channels // 2], axis=1)
+        id_l, id_g = x_l, x_g
+        # Apply two SFIB blocks with residual connection
+        x_l, x_g = self.conv1((x_l, x_g))
+        x_l, x_g = self.conv2((x_l, x_g))
+
+        # Residual connection
+        x_l = id_l + x_l
+        x_g = id_g + x_g
+
+        # Recombine branches
+        out = paddle.concat([x_l, x_g], axis=1)
+        return out
+
+
+class SFIN(nn.Layer):
+    """
+    SFIN: Noise Calibration and Spatial-Frequency Interactive Network
+    for STEM Image Enhancement.
+
+    Args:
+        in_channels (int): Number of input channels (default: 1 for grayscale images)
+        base_channels (int): Base number of channels (default: 64)
+        num_blocks (int): Number of ResNet blocks (default: 8)
+
+    Reference:
+        Li et al., "Noise Calibration and Spatial-Frequency Interactive Network for
+        STEM Image Enhancement", CVPR 2025.
+        https://arxiv.org/pdf/2504.02555
+    """
+
+    def __init__(
+        self,
+        in_channels: int = 1,
+        base_channels: int = 64,
+        num_blocks: int = 8,
+        input_name: str = "noisy",
+        target_name: str = "gt_enhance",
+        loss_type: str = "l1",
+        loss_weight: float = 1.0,
+    ):
+        super(SFIN, self).__init__()
+        self.in_channels = in_channels
+        self.base_channels = base_channels
+        self.num_blocks = num_blocks
+        self.input_name = input_name
+        self.target_name = target_name
+        self.loss_type = loss_type.lower()
+        self.loss_weight = loss_weight
+
+        if self.loss_type == "l1":
+            self.criterion = nn.L1Loss()
+        elif self.loss_type == "mse":
+            self.criterion = nn.MSELoss()
+        else:
+            raise ValueError(f"Unsupported loss_type '{loss_type}', expected 'l1' or 'mse'.")
+
+        # Build ResNet blocks with proper registration
+        blocks = [ResnetBlock(base_channels) for _ in range(num_blocks)]
+        self.body = nn.Sequential(*blocks)
+
+        # Head convolution initialization
+        head_fan_in = in_channels * 3 * 3
+        head_bias_bound = 1.0 / head_fan_in ** 0.5
+        self.head_conv = nn.Conv2D(
+            in_channels, base_channels, 3, padding=1,
+            weight_attr=nn.initializer.KaimingUniform(
+                negative_slope=5**0.5,
+                mode='fan_in',
+                nonlinearity='leaky_relu'
+            ),
+            bias_attr=nn.initializer.Uniform(-head_bias_bound, head_bias_bound)
+        )
+
+        # Tail convolution initialization
+        tail_fan_in = base_channels * 3 * 3
+        tail_bias_bound = 1.0 / tail_fan_in ** 0.5
+        self.tail_conv = nn.Conv2D(
+            base_channels, in_channels, 3, padding=1,
+            weight_attr=nn.initializer.KaimingUniform(
+                negative_slope=5**0.5,
+                mode='fan_in',
+                nonlinearity='leaky_relu'
+            ),
bias_attr=nn.initializer.Uniform(-tail_bias_bound, tail_bias_bound) + ) + + def _forward_tensor(self, x: paddle.Tensor) -> paddle.Tensor: + """ + Tensor-only forward pass of SFIN. + + Args: + x: Input tensor of shape (B, C, H, W) + + Returns: + Enhanced image tensor of shape (B, C, H, W) + """ + x = self.head_conv(x) + shortcut = x + x = self.body(x) + x = x + shortcut + x = self.tail_conv(x) + return x + + def _get_input_tensor(self, batch: Dict) -> paddle.Tensor: + key_candidates = [self.input_name, "image", "noisy", "input", "x"] + for key in key_candidates: + if key in batch and batch[key] is not None: + return batch[key] + raise KeyError( + f"SFIN expects one of input keys {key_candidates}, but got keys: {list(batch.keys())}" + ) + + def _get_label_tensor(self, batch: Dict): + key_candidates = [ + self.target_name, + "gt_enhance", + "target", + "label", + "gt", + "clean", + "y", + ] + for key in key_candidates: + if key in batch and batch[key] is not None: + return batch[key] + return None + + def forward(self, batch): + """ + Unified forward for both: + 1) tensor -> enhanced tensor (for direct use / legacy scripts) + 2) dict -> trainer-ready output with loss_dict and pred_dict + """ + if isinstance(batch, dict): + x = self._get_input_tensor(batch) + enhanced = self._forward_tensor(x) + + pred_dict = { + self.target_name: enhanced, + "pred": enhanced, + } + loss_dict = {} + + label = self._get_label_tensor(batch) + if label is not None: + loss = self.criterion(enhanced, label) * self.loss_weight + loss_dict["loss"] = loss + + return {"loss_dict": loss_dict, "pred_dict": pred_dict} + + return self._forward_tensor(batch) + + def predict(self, batch: Dict) -> Dict: + """ + Prediction interface for BasePredictor. 
+ + Args: + batch: Dictionary containing 'image' key with input tensor + + Returns: + Dictionary containing 'pred' key with enhanced image + """ + if isinstance(batch, dict): + x = self._get_input_tensor(batch) + enhanced = self._forward_tensor(x) + return {self.target_name: enhanced, "pred": enhanced} + + return self._forward_tensor(batch) diff --git a/spectrum_enhancement/README.md b/spectrum_enhancement/README.md new file mode 100644 index 00000000..333be039 --- /dev/null +++ b/spectrum_enhancement/README.md @@ -0,0 +1,27 @@ +# SPEN-Spectrum Enhancement + +## 1.Introduction + +Spectrum Enhancement (SE) focuses on enhancing and denoising spectral and microscopy data. Leveraging advanced deep learning techniques, SE aims to recover high-quality signals from noisy observations, enabling more accurate analysis of material properties at the atomic scale. This task is particularly valuable for STEM (Scanning Transmission Electron Microscopy) image processing, where noise reduction can significantly improve the visualization of crystal structures and defects. 
+ +## 2.Framework Support Matrix + +| **Supported Functions** | **Support** | +| ----------------------------------- | :---------: | +| **ML Capabilities ยท Training** | | +| Single-GPU | โœ… | +| Distributed training | โœ… | +| Mixed precision (AMP) | โ€” | +| Fine-tuning | โœ… | +| **ML Capabilities ยท Predict** | | +| Standard inference | โœ… | +| Distributed inference | โ€” | +| **Data Pipeline** | | +| Local dataset loading | โœ… | +| Auto dataset download | โœ… | +| **Task Workflow** | | +| Training / Evaluation / Prediction | โœ… | + +## 3.Model README + +- [SFIN](./configs/sfin/README.md) diff --git a/spectrum_enhancement/configs/sfin/README.md b/spectrum_enhancement/configs/sfin/README.md new file mode 100644 index 00000000..b56f271c --- /dev/null +++ b/spectrum_enhancement/configs/sfin/README.md @@ -0,0 +1,170 @@ +# SFIN + +[Noise Calibration and Spatial-Frequency Interactive Network for STEM Image Enhancement](https://arxiv.org/pdf/2504.02555) + +## 1.Introduction + +SFIN is a CNN-based model for STEM image restoration. It targets paired reconstruction from noisy grayscale inputs and supports two experimental modes (HAADF and BF), each with two training targets (`enhance`, `detect`). + +## 2.Model Description + +SFIN takes `noisy` as input and predicts one target image (`gt_enhance` or `gt_detect`). + +- Noise calibration module for robust denoising. +- Spatial-frequency interaction blocks for detail recovery. +- Multi-scale feature extraction. 
+ +Training objective (L1): + +$$ +\mathcal{L} = \left\| \hat{I} - I_{gt} \right\|_1 +$$ + +Evaluation metrics: +- PSNR +- SSIM + +## 3.Configurations + +| Config | Mode | Target | Output Dir | +| --- | --- | --- | --- | +| `sfin_tem_enhance.yaml` | HAADF (TEM) | `gt_enhance` | `./output/sfin_tem_enhance` | +| `sfin_tem_detect.yaml` | HAADF (TEM) | `gt_detect` | `./output/sfin_tem_detect` | +| `sfin_bf_enhance.yaml` | BF | `gt_enhance` | `./output/sfin_bf_enhance` | +| `sfin_bf_detect.yaml` | BF | `gt_detect` | `./output/sfin_bf_detect` | + +## 4.Dataset + +### Format + +Expected paired directory format: + +```text +data/ # HAADF train + noisy/ + gt_enhance/ + gt_detect/ + +data_test/ # HAADF val/test + noisy/ + gt_enhance/ + gt_detect/ + +bf_data/ # BF train + noisy/ + gt_enhance/ + gt_detect/ + +bf_data_test/ # BF val/test + noisy/ + gt_enhance/ + gt_detect/ +``` + +### Download Links + +| Dataset | Train | Test | Link | +| :---: | :---: | :---: | :---: | +| HAADF train dataset | 1000 | - | [haadf_data.zip](https://paddle-org.bj.bcebos.com/paddlematerials/datasets/SFIN_datasets/haadf_data.zip) | +| BF train dataset | 1000 | - | [bf_data.zip](https://paddle-org.bj.bcebos.com/paddlematerials/datasets/SFIN_datasets/bf_data.zip) | +| HAADF test dataset | - | 100 | [haadf_data_test.zip](https://paddle-org.bj.bcebos.com/paddlematerials/datasets/SFIN_datasets/haadf_data_test.zip) | +| BF test dataset | - | 100 | [bf_data_test.zip](https://paddle-org.bj.bcebos.com/paddlematerials/datasets/SFIN_datasets/bf_data_test.zip) | + +### Auto Download Behavior + +`STEMImageDataset` supports local loading + auto download: + +- If `data_path` exists locally, data is loaded directly. +- If `data_path` is missing and `download=True`, URL is inferred by `data_path` basename: + - `data` -> `haadf_data.zip` + - `data_test` -> `haadf_data_test.zip` + - `bf_data` -> `bf_data.zip` + - `bf_data_test` -> `bf_data_test.zip` +- Archive formats `zip` and `tar.gz` are both supported. 
+ +## 5.Results + + + + + + + + + + + + + + + + + + + + + + + + + + +
Model NameDatasetPSNR (dB)SSIMGPUsTraining timeConfigCheckpoint | Log
sfin_tem_enhanceSTEM Enhancement38.740.96221 (V100-32GB)~21.5 hourssfin_tem_enhancecheckpoint | log
+ +## 6.Command + +Run from `PaddleMaterials` root: + +```bash +pip install -e . --no-build-isolation +``` + +### Training + +```bash +# HAADF enhance +python spectrum_enhancement/train.py \ + -c spectrum_enhancement/configs/sfin/sfin_tem_enhance.yaml + +# HAADF detect +python spectrum_enhancement/train.py \ + -c spectrum_enhancement/configs/sfin/sfin_tem_detect.yaml + +# BF enhance +python spectrum_enhancement/train.py \ + -c spectrum_enhancement/configs/sfin/sfin_bf_enhance.yaml + +# BF detect +python spectrum_enhancement/train.py \ + -c spectrum_enhancement/configs/sfin/sfin_bf_detect.yaml +``` + +### Evaluation + +```bash +python spectrum_enhancement/train.py \ + -c spectrum_enhancement/configs/sfin/sfin_tem_enhance.yaml \ + Global.do_train=False Global.do_eval=True Global.do_test=True \ + Trainer.pretrained_model_path='path/to/model.pdparams' +``` + +### Prediction + +```bash +python spectrum_enhancement/predict.py \ + --config_path spectrum_enhancement/configs/sfin/sfin_tem_enhance.yaml \ + --checkpoint_path https://paddle-org.bj.bcebos.com/paddlematerials/checkpoints/spectrum_enhancement/sfin/sfin_he_500.pdparams \ + --data_path ./data_test \ + --output_dir ./output/sfin_predictions +``` + +## 7.Citation + +```bibtex +@inproceedings{li2025sfin, + title={Noise Calibration and Spatial-Frequency Interactive Network for STEM Image Enhancement}, + author={Li, Hesong and Wu, Ziqi and Shao, Ruiwen and Zhang, Tao and Fu, Ying}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + year={2025} +} +``` diff --git a/spectrum_enhancement/configs/sfin/sfin_bf_detect.yaml b/spectrum_enhancement/configs/sfin/sfin_bf_detect.yaml new file mode 100644 index 00000000..339dda4f --- /dev/null +++ b/spectrum_enhancement/configs/sfin/sfin_bf_detect.yaml @@ -0,0 +1,131 @@ +Global: + do_train: True + do_eval: True + do_test: True + prim_eager_enabled: False + +Trainer: + max_epochs: 500 + seed: 42 + output_dir: 
./output/sfin_bf_detect + save_freq: 100 + log_freq: 100 + start_eval_epoch: 1 + eval_freq: 1 + pretrained_model_path: null + pretrained_weight_name: null + resume_from_checkpoint: null + use_amp: False + amp_level: "O1" + eval_with_no_grad: True + gradient_accumulation_steps: 1 + + best_metric_indicator: "eval_metric" + name_for_best_metric: "gt_detect" + greater_is_better: True + + compute_metric_during_train: True + metric_strategy_during_eval: "epoch" + + use_visualdl: False + use_wandb: False + use_tensorboard: False + +Model: + __class_name__: SFIN + __init_params__: + in_channels: 1 + base_channels: 64 + num_blocks: 8 + input_name: "noisy" + target_name: "gt_detect" + loss_type: "l1" + loss_weight: 1.0 + +Optimizer: + __class_name__: Adam + __init_params__: + lr: + __class_name__: MultiStepDecay + __init_params__: + learning_rate: 2.0e-4 + milestones: [250, 400, 425, 450, 475] + gamma: 0.5 + by_epoch: True + beta1: 0.9 + beta2: 0.999 + epsilon: 1.0e-8 + weight_decay: 0.0 + +Metric: + gt_detect: + __class_name__: PSNRMetric + __init_params__: + data_range: 255.0 + +Dataset: + train: + dataset: + __class_name__: STEMImageDataset + __init_params__: + data_path: "./bf_data" + noisy_subdir: "noisy" + target_subdir: "gt_detect" + file_suffix: ".png" + strict_index_naming: True + scale_to_unit: False + download: True + sampler: + __class_name__: BatchSampler + __init_params__: + shuffle: True + drop_last: True + batch_size: 8 + loader: + num_workers: 0 + use_shared_memory: False + collate_fn: DefaultCollator + + val: + dataset: + __class_name__: STEMImageDataset + __init_params__: + data_path: "./bf_data_test" + noisy_subdir: "noisy" + target_subdir: "gt_detect" + file_suffix: ".png" + strict_index_naming: True + scale_to_unit: False + download: True + sampler: + __class_name__: BatchSampler + __init_params__: + shuffle: False + drop_last: False + batch_size: 1 + loader: + num_workers: 0 + use_shared_memory: False + collate_fn: DefaultCollator + + test: + dataset: 
+ __class_name__: STEMImageDataset + __init_params__: + data_path: "./bf_data_test" + noisy_subdir: "noisy" + target_subdir: "gt_detect" + file_suffix: ".png" + strict_index_naming: True + scale_to_unit: False + download: True + sampler: + __class_name__: BatchSampler + __init_params__: + shuffle: False + drop_last: False + batch_size: 1 + loader: + num_workers: 0 + use_shared_memory: False + collate_fn: DefaultCollator diff --git a/spectrum_enhancement/configs/sfin/sfin_bf_enhance.yaml b/spectrum_enhancement/configs/sfin/sfin_bf_enhance.yaml new file mode 100644 index 00000000..1f2ebf09 --- /dev/null +++ b/spectrum_enhancement/configs/sfin/sfin_bf_enhance.yaml @@ -0,0 +1,131 @@ +Global: + do_train: True + do_eval: True + do_test: True + prim_eager_enabled: False + +Trainer: + max_epochs: 500 + seed: 42 + output_dir: ./output/sfin_bf_enhance + save_freq: 100 + log_freq: 100 + start_eval_epoch: 1 + eval_freq: 1 + pretrained_model_path: null + pretrained_weight_name: null + resume_from_checkpoint: null + use_amp: False + amp_level: "O1" + eval_with_no_grad: True + gradient_accumulation_steps: 1 + + best_metric_indicator: "eval_metric" + name_for_best_metric: "gt_enhance" + greater_is_better: True + + compute_metric_during_train: True + metric_strategy_during_eval: "epoch" + + use_visualdl: False + use_wandb: False + use_tensorboard: False + +Model: + __class_name__: SFIN + __init_params__: + in_channels: 1 + base_channels: 64 + num_blocks: 8 + input_name: "noisy" + target_name: "gt_enhance" + loss_type: "l1" + loss_weight: 1.0 + +Optimizer: + __class_name__: Adam + __init_params__: + lr: + __class_name__: MultiStepDecay + __init_params__: + learning_rate: 2.0e-4 + milestones: [250, 400, 425, 450, 475] + gamma: 0.5 + by_epoch: True + beta1: 0.9 + beta2: 0.999 + epsilon: 1.0e-8 + weight_decay: 0.0 + +Metric: + gt_enhance: + __class_name__: PSNRMetric + __init_params__: + data_range: 255.0 + +Dataset: + train: + dataset: + __class_name__: STEMImageDataset + 
__init_params__: + data_path: "./bf_data" + noisy_subdir: "noisy" + target_subdir: "gt_enhance" + file_suffix: ".png" + strict_index_naming: True + scale_to_unit: False + download: True + sampler: + __class_name__: BatchSampler + __init_params__: + shuffle: True + drop_last: True + batch_size: 8 + loader: + num_workers: 0 + use_shared_memory: False + collate_fn: DefaultCollator + + val: + dataset: + __class_name__: STEMImageDataset + __init_params__: + data_path: "./bf_data_test" + noisy_subdir: "noisy" + target_subdir: "gt_enhance" + file_suffix: ".png" + strict_index_naming: True + scale_to_unit: False + download: True + sampler: + __class_name__: BatchSampler + __init_params__: + shuffle: False + drop_last: False + batch_size: 1 + loader: + num_workers: 0 + use_shared_memory: False + collate_fn: DefaultCollator + + test: + dataset: + __class_name__: STEMImageDataset + __init_params__: + data_path: "./bf_data_test" + noisy_subdir: "noisy" + target_subdir: "gt_enhance" + file_suffix: ".png" + strict_index_naming: True + scale_to_unit: False + download: True + sampler: + __class_name__: BatchSampler + __init_params__: + shuffle: False + drop_last: False + batch_size: 1 + loader: + num_workers: 0 + use_shared_memory: False + collate_fn: DefaultCollator diff --git a/spectrum_enhancement/configs/sfin/sfin_tem_detect.yaml b/spectrum_enhancement/configs/sfin/sfin_tem_detect.yaml new file mode 100644 index 00000000..43ce2bee --- /dev/null +++ b/spectrum_enhancement/configs/sfin/sfin_tem_detect.yaml @@ -0,0 +1,131 @@ +Global: + do_train: True + do_eval: True + do_test: True + prim_eager_enabled: False + +Trainer: + max_epochs: 500 + seed: 42 + output_dir: ./output/sfin_tem_detect + save_freq: 100 + log_freq: 100 + start_eval_epoch: 1 + eval_freq: 1 + pretrained_model_path: null + pretrained_weight_name: null + resume_from_checkpoint: null + use_amp: False + amp_level: "O1" + eval_with_no_grad: True + gradient_accumulation_steps: 1 + + best_metric_indicator: "eval_metric" 
+ name_for_best_metric: "gt_detect" + greater_is_better: True + + compute_metric_during_train: True + metric_strategy_during_eval: "epoch" + + use_visualdl: False + use_wandb: False + use_tensorboard: False + +Model: + __class_name__: SFIN + __init_params__: + in_channels: 1 + base_channels: 64 + num_blocks: 8 + input_name: "noisy" + target_name: "gt_detect" + loss_type: "l1" + loss_weight: 1.0 + +Optimizer: + __class_name__: Adam + __init_params__: + lr: + __class_name__: MultiStepDecay + __init_params__: + learning_rate: 2.0e-4 + milestones: [250, 400, 425, 450, 475] + gamma: 0.5 + by_epoch: True + beta1: 0.9 + beta2: 0.999 + epsilon: 1.0e-8 + weight_decay: 0.0 + +Metric: + gt_detect: + __class_name__: PSNRMetric + __init_params__: + data_range: 255.0 + +Dataset: + train: + dataset: + __class_name__: STEMImageDataset + __init_params__: + data_path: "./data" + noisy_subdir: "noisy" + target_subdir: "gt_detect" + file_suffix: ".png" + strict_index_naming: True + scale_to_unit: False + download: True + sampler: + __class_name__: BatchSampler + __init_params__: + shuffle: True + drop_last: True + batch_size: 8 + loader: + num_workers: 0 + use_shared_memory: False + collate_fn: DefaultCollator + + val: + dataset: + __class_name__: STEMImageDataset + __init_params__: + data_path: "./data_test" + noisy_subdir: "noisy" + target_subdir: "gt_detect" + file_suffix: ".png" + strict_index_naming: True + scale_to_unit: False + download: True + sampler: + __class_name__: BatchSampler + __init_params__: + shuffle: False + drop_last: False + batch_size: 1 + loader: + num_workers: 0 + use_shared_memory: False + collate_fn: DefaultCollator + + test: + dataset: + __class_name__: STEMImageDataset + __init_params__: + data_path: "./data_test" + noisy_subdir: "noisy" + target_subdir: "gt_detect" + file_suffix: ".png" + strict_index_naming: True + scale_to_unit: False + download: True + sampler: + __class_name__: BatchSampler + __init_params__: + shuffle: False + drop_last: False + 
batch_size: 1 + loader: + num_workers: 0 + use_shared_memory: False + collate_fn: DefaultCollator diff --git a/spectrum_enhancement/configs/sfin/sfin_tem_enhance.yaml b/spectrum_enhancement/configs/sfin/sfin_tem_enhance.yaml new file mode 100644 index 00000000..b97ddd9f --- /dev/null +++ b/spectrum_enhancement/configs/sfin/sfin_tem_enhance.yaml @@ -0,0 +1,131 @@ +Global: + do_train: True + do_eval: True + do_test: True + prim_eager_enabled: False + +Trainer: + max_epochs: 500 + seed: 42 + output_dir: ./output/sfin_tem_enhance + save_freq: 100 + log_freq: 100 + start_eval_epoch: 1 + eval_freq: 1 + pretrained_model_path: null + pretrained_weight_name: null + resume_from_checkpoint: null + use_amp: False + amp_level: "O1" + eval_with_no_grad: True + gradient_accumulation_steps: 1 + + best_metric_indicator: "eval_metric" + name_for_best_metric: "gt_enhance" + greater_is_better: True + + compute_metric_during_train: True + metric_strategy_during_eval: "epoch" + + use_visualdl: False + use_wandb: False + use_tensorboard: False + +Model: + __class_name__: SFIN + __init_params__: + in_channels: 1 + base_channels: 64 + num_blocks: 8 + input_name: "noisy" + target_name: "gt_enhance" + loss_type: "l1" + loss_weight: 1.0 + +Optimizer: + __class_name__: Adam + __init_params__: + lr: + __class_name__: MultiStepDecay + __init_params__: + learning_rate: 2.0e-4 + milestones: [250, 400, 425, 450, 475] + gamma: 0.5 + by_epoch: True + beta1: 0.9 + beta2: 0.999 + epsilon: 1.0e-8 + weight_decay: 0.0 + +Metric: + gt_enhance: + __class_name__: PSNRMetric + __init_params__: + data_range: 255.0 + +Dataset: + train: + dataset: + __class_name__: STEMImageDataset + __init_params__: + data_path: "./data" + noisy_subdir: "noisy" + target_subdir: "gt_enhance" + file_suffix: ".png" + strict_index_naming: True + scale_to_unit: False + download: True + sampler: + __class_name__: BatchSampler + __init_params__: + shuffle: True + drop_last: True + batch_size: 8 + loader: + num_workers: 0 + 
use_shared_memory: False + collate_fn: DefaultCollator + + val: + dataset: + __class_name__: STEMImageDataset + __init_params__: + data_path: "./data_test" + noisy_subdir: "noisy" + target_subdir: "gt_enhance" + file_suffix: ".png" + strict_index_naming: True + scale_to_unit: False + download: True + sampler: + __class_name__: BatchSampler + __init_params__: + shuffle: False + drop_last: False + batch_size: 1 + loader: + num_workers: 0 + use_shared_memory: False + collate_fn: DefaultCollator + + test: + dataset: + __class_name__: STEMImageDataset + __init_params__: + data_path: "./data_test" + noisy_subdir: "noisy" + target_subdir: "gt_enhance" + file_suffix: ".png" + strict_index_naming: True + scale_to_unit: False + download: True + sampler: + __class_name__: BatchSampler + __init_params__: + shuffle: False + drop_last: False + batch_size: 1 + loader: + num_workers: 0 + use_shared_memory: False + collate_fn: DefaultCollator diff --git a/spectrum_enhancement/predict.py b/spectrum_enhancement/predict.py new file mode 100644 index 00000000..6ac35258 --- /dev/null +++ b/spectrum_enhancement/predict.py @@ -0,0 +1,555 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import argparse +import copy +from abc import ABC, abstractmethod +from contextlib import nullcontext +from pathlib import Path +from typing import Any +from typing import Dict +from typing import Optional +from typing import Type + +import numpy as np +import paddle +from omegaconf import OmegaConf +from PIL import Image + +from ppmat.datasets.stem_image_dataset import STEMImageDataset +from ppmat.models import build_model +from ppmat.models import build_model_from_name +from ppmat.utils import logger +from ppmat.utils import save_load + + +def _normalize_split(split: Optional[str]) -> Optional[str]: + if split is None: + return None + if split == "validation": + return "val" + return split + + +CASE_PROCESSOR_REGISTRY: Dict[str, Type["BaseCaseProcessor"]] = {} + + +def register_case_processor(cls: Type["BaseCaseProcessor"]) -> Type["BaseCaseProcessor"]: + case_name = cls.case_name.strip().lower() + if not case_name: + raise ValueError("Case processor must define a non-empty `case_name`.") + CASE_PROCESSOR_REGISTRY[case_name] = cls + return cls + + +class BaseCaseProcessor(ABC): + """ + Case-level hooks for custom data processing and output processing. + + To add a new model case: + 1. Subclass BaseCaseProcessor. + 2. Implement the abstract methods. + 3. Register with @register_case_processor. 
+ """ + + case_name = "" + + def __init__(self, config: Dict[str, Any]): + self.config = config + + @abstractmethod + def build_dataset(self, args: argparse.Namespace) -> paddle.io.Dataset: + raise NotImplementedError + + def prepare_model_input( + self, + sample: Dict[str, Any], + index: int, + args: argparse.Namespace, + ) -> Any: + return sample + + def forward_model( + self, + model: paddle.nn.Layer, + model_input: Any, + args: argparse.Namespace, + ) -> Any: + if hasattr(model, "predict"): + return model.predict(model_input) + return model(model_input) + + @abstractmethod + def parse_model_output( + self, + model_output: Any, + sample: Dict[str, Any], + index: int, + args: argparse.Namespace, + ) -> Any: + raise NotImplementedError + + @abstractmethod + def save_prediction( + self, + parsed_output: Any, + sample: Dict[str, Any], + index: int, + output_dir: Path, + args: argparse.Namespace, + ) -> Path: + raise NotImplementedError + + +@register_case_processor +class SFINCaseProcessor(BaseCaseProcessor): + case_name = "sfin" + + def __init__(self, config: Dict[str, Any]): + super().__init__(config) + model_init = config.get("Model", {}).get("__init_params__", {}) + self.target_name = model_init.get("target_name", "gt_enhance") + + def _resolve_dataset_init_params(self, split: Optional[str]) -> Dict[str, Any]: + dataset_cfg_root = self.config.get("Dataset", {}) + if not isinstance(dataset_cfg_root, dict): + return {} + + split = _normalize_split(split) + candidate_keys = [] + if split is not None: + candidate_keys.append(split) + candidate_keys.extend(["test", "val", "train"]) + + for key in candidate_keys: + branch_cfg = dataset_cfg_root.get(key) + if not isinstance(branch_cfg, dict): + continue + dataset_cfg = branch_cfg.get("dataset", {}) + if not isinstance(dataset_cfg, dict): + continue + if dataset_cfg.get("__class_name__") == "STEMImageDataset": + return copy.deepcopy(dataset_cfg.get("__init_params__", {})) + return {} + + def build_dataset(self, args: 
argparse.Namespace) -> paddle.io.Dataset: + init_params = self._resolve_dataset_init_params(args.split) + + init_params["data_path"] = args.data_path or init_params.get("data_path", "./data_test") + init_params["file_suffix"] = args.file_suffix or init_params.get("file_suffix", ".png") + init_params["split"] = ( + _normalize_split(args.split) + if args.split is not None + else init_params.get("split", None) + ) + + if args.data_count > 0: + init_params["data_count"] = args.data_count + else: + init_params["data_count"] = None + + if args.noisy_subdir is not None: + init_params["noisy_subdir"] = args.noisy_subdir + if args.target_subdir is not None: + init_params["target_subdir"] = args.target_subdir + + # Preserve config defaults unless CLI explicitly overrides. + if args.download is not None: + init_params["download"] = bool(args.download) + if args.force_download is not None: + init_params["force_download"] = bool(args.force_download) + return STEMImageDataset(**init_params) + + def prepare_model_input( + self, + sample: Dict[str, Any], + index: int, + args: argparse.Namespace, + ) -> Dict[str, Any]: + if not isinstance(sample, dict): + return sample + + model_init = self.config.get("Model", {}).get("__init_params__", {}) + input_key = model_init.get("input_name", "noisy") + key_candidates = [input_key, "image", "noisy", "input", "x"] + + model_input = dict(sample) + for key in key_candidates: + x = model_input.get(key) + if isinstance(x, paddle.Tensor): + if x.ndim == 3: + model_input[key] = x.unsqueeze(0) + elif x.ndim == 2: + model_input[key] = x.unsqueeze(0).unsqueeze(0) + break + return model_input + + def _pick_prediction_tensor(self, output: Any) -> paddle.Tensor: + if isinstance(output, dict): + if "pred_dict" in output and isinstance(output["pred_dict"], dict): + output = output["pred_dict"] + + for key in [self.target_name, "pred", "output", "enhanced", "image"]: + if key in output and output[key] is not None: + pred = output[key] + break + else: + pred 
= None + for value in output.values(): + if isinstance(value, paddle.Tensor): + pred = value + break + if pred is None: + raise KeyError( + "Cannot find prediction tensor in model output dict. " + f"Keys: {list(output.keys())}" + ) + elif isinstance(output, (list, tuple)): + if not output: + raise ValueError("Model output list/tuple is empty.") + pred = output[0] + else: + pred = output + + if not isinstance(pred, paddle.Tensor): + pred = paddle.to_tensor(pred) + return pred + + def parse_model_output( + self, + model_output: Any, + sample: Dict[str, Any], + index: int, + args: argparse.Namespace, + ) -> np.ndarray: + pred = self._pick_prediction_tensor(model_output) + pred = paddle.clip(pred, min=0.0, max=255.0) + pred_np = pred.squeeze().detach().cpu().numpy().astype(np.uint8) + + if pred_np.ndim == 3 and pred_np.shape[0] in (1, 3): + pred_np = np.transpose(pred_np, (1, 2, 0)) + if pred_np.ndim == 3 and pred_np.shape[-1] == 1: + pred_np = pred_np[..., 0] + return pred_np + + def save_prediction( + self, + parsed_output: np.ndarray, + sample: Dict[str, Any], + index: int, + output_dir: Path, + args: argparse.Namespace, + ) -> Path: + file_name = sample.get("name", f"{index}{args.file_suffix or '.png'}") + if args.save_suffix: + suffix = args.save_suffix + if not suffix.startswith("."): + suffix = f".{suffix}" + file_name = f"{Path(file_name).stem}{suffix}" + elif Path(file_name).suffix == "": + file_name = f"{file_name}{args.file_suffix or '.png'}" + + save_path = output_dir / file_name + Image.fromarray(parsed_output).save(save_path) + return save_path + + +class SpectrumPredictor: + def __init__( + self, + case: str, + device: str, + config_path: Optional[str] = None, + checkpoint_path: Optional[str] = None, + model_name: Optional[str] = None, + weights_name: Optional[str] = None, + ): + self.case = case.strip().lower() + self.config_path = config_path + self.checkpoint_path = checkpoint_path + self.model_name = model_name + self.weights_name = weights_name + 
self.device = device + + paddle.set_device(device) + + if self.model_name: + logger.info( + f"Loading predefined model by name: {self.model_name} " + f"(weights_name={self.weights_name})" + ) + self.model, self.config = build_model_from_name( + self.model_name, self.weights_name + ) + self.model.eval() + else: + if not self.config_path: + raise ValueError( + "`config_path` is required when `model_name` is not provided." + ) + if not self.checkpoint_path: + raise ValueError( + "`checkpoint_path` is required when `model_name` is not provided." + ) + self.config = self._load_config(self.config_path) + self.model = self._build_model_and_load_checkpoint() + + self.case_processor = self._build_case_processor(self.case) + self.eval_with_no_grad = ( + self.config.get("Predict", {}).get("eval_with_no_grad", True) + ) + + def _load_config(self, config_path: str) -> Dict[str, Any]: + config = OmegaConf.load(config_path) + return OmegaConf.to_container(config, resolve=True) + + def _build_model_and_load_checkpoint(self): + model_cfg = self.config.get("Model") + if model_cfg is None: + raise ValueError("`Model` section is required in config.") + model = build_model(model_cfg) + self._load_checkpoint(model, self.checkpoint_path) + model.eval() + return model + + def _build_case_processor(self, case: str) -> BaseCaseProcessor: + processor_cls = CASE_PROCESSOR_REGISTRY.get(case) + if processor_cls is None: + available = ", ".join(sorted(CASE_PROCESSOR_REGISTRY.keys())) + raise ValueError(f"Unsupported case '{case}'. 
Available cases: [{available}]") + return processor_cls(self.config) + + @staticmethod + def _load_checkpoint(model, checkpoint_path: Optional[str]) -> None: + if not checkpoint_path: + raise ValueError("`checkpoint_path` must not be empty.") + checkpoint_loaded = False + try: + checkpoint = paddle.load(checkpoint_path) + if isinstance(checkpoint, dict): + state_dict = None + for key in ("model_state_dict", "model", "state_dict"): + if key in checkpoint and isinstance(checkpoint[key], dict): + state_dict = checkpoint[key] + break + if state_dict is None: + state_dict = checkpoint + model.set_state_dict(state_dict) + checkpoint_loaded = True + except Exception: + checkpoint_loaded = False + + if not checkpoint_loaded: + save_load.load_pretrain(model, checkpoint_path) + + def run(self, args: argparse.Namespace) -> list[Path]: + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + dataset = self.case_processor.build_dataset(args) + if len(dataset) == 0: + raise ValueError("No samples found in dataset.") + + saved_paths = [] + context = paddle.no_grad() if self.eval_with_no_grad else nullcontext() + with context: + for idx in range(len(dataset)): + sample = dataset[idx] + model_input = self.case_processor.prepare_model_input(sample, idx, args) + model_output = self.case_processor.forward_model( + self.model, + model_input, + args, + ) + parsed_output = self.case_processor.parse_model_output( + model_output, + sample, + idx, + args, + ) + save_path = self.case_processor.save_prediction( + parsed_output, + sample, + idx, + output_dir, + args, + ) + saved_paths.append(save_path) + + return saved_paths + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Generic spectrum enhancement prediction with case-level hooks." + ) + parser.add_argument( + "--case", + type=str, + default="sfin", + help="Prediction case name. 
Extend by registering a new case processor.", + ) + parser.add_argument( + "--model_name", + type=str, + default=None, + help=( + "Optional predefined model name from MODEL_REGISTRY. " + "If provided, `config_path` and `checkpoint_path` are optional." + ), + ) + parser.add_argument( + "--weights_name", + type=str, + default=None, + help=( + "Optional weight filename when `model_name` is used " + "(e.g., best.pdparams / latest.pdparams)." + ), + ) + parser.add_argument( + "--config_path", + type=str, + default="./spectrum_enhancement/configs/sfin/sfin_tem_enhance.yaml", + help="Path to model config yaml (used when model_name is not provided).", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + default=None, + help="Path or URL to checkpoint (*.pdparams) (used when model_name is not provided).", + ) + parser.add_argument( + "--data_path", + type=str, + default=None, + help="Root directory of input data. If omitted, infer from config Dataset section.", + ) + parser.add_argument( + "--split", + type=str, + default=None, + choices=["train", "val", "validation", "test"], + help="Dataset split to use. If omitted, case processor chooses default split.", + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="Directory to save predictions. 
Defaults to /predictions.", + ) + parser.add_argument( + "--file_suffix", + type=str, + default=None, + help="Input file suffix for dataset scanning (e.g., .png).", + ) + parser.add_argument( + "--save_suffix", + type=str, + default=None, + help="Optional output file suffix override (e.g., .png).", + ) + parser.add_argument( + "--data_count", + type=int, + default=-1, + help="Max number of samples to process, <=0 means all.", + ) + parser.add_argument( + "--noisy_subdir", + type=str, + default=None, + help="Optional override for noisy image sub-directory.", + ) + parser.add_argument( + "--target_subdir", + type=str, + default=None, + help="Optional override for target image sub-directory.", + ) + parser.add_argument( + "--download", + action="store_true", + default=None, + help=( + "Enable auto-download when data_path is missing. " + "If omitted, keep dataset config default." + ), + ) + parser.add_argument( + "--force_download", + action="store_true", + default=None, + help=( + "Force re-download dataset archive. " + "If omitted, keep dataset config default." + ), + ) + parser.add_argument( + "--device", + type=str, + default="gpu" if paddle.device.cuda.device_count() > 0 else "cpu", + choices=["cpu", "gpu"], + help="Device to run inference.", + ) + return parser.parse_args() + + +def resolve_output_dir(args: argparse.Namespace, config: Dict[str, Any]) -> str: + if args.output_dir: + return args.output_dir + trainer_output_dir = config.get("Trainer", {}).get("output_dir") + if trainer_output_dir: + return str(Path(trainer_output_dir) / "predictions") + if args.config_path: + return str(Path("./output") / Path(args.config_path).stem / "predictions") + if args.model_name: + return str(Path("./output") / args.model_name / "predictions") + return str(Path("./output") / "spectrum_enhancement" / "predictions") + + +def validate_args(args: argparse.Namespace) -> None: + # Backward-compatible behavior: + # - Existing config+checkpoint workflow keeps working. 
+ # - New model_name workflow is optional. + if args.model_name: + return + if not args.config_path or not args.checkpoint_path: + raise ValueError( + "Either provide `--model_name`, or provide both " + "`--config_path` and `--checkpoint_path`." + ) + + +def main(): + args = parse_args() + validate_args(args) + predictor = SpectrumPredictor( + case=args.case, + device=args.device, + config_path=args.config_path, + checkpoint_path=args.checkpoint_path, + model_name=args.model_name, + weights_name=args.weights_name, + ) + args.output_dir = resolve_output_dir(args, predictor.config) + saved_paths = predictor.run(args) + logger.info(f"Saved {len(saved_paths)} predictions to {args.output_dir}") + + +if __name__ == "__main__": + main() diff --git a/spectrum_enhancement/train.py b/spectrum_enhancement/train.py new file mode 100644 index 00000000..af6202a2 --- /dev/null +++ b/spectrum_enhancement/train.py @@ -0,0 +1,324 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

from __future__ import annotations

import argparse
import copy
import datetime
import os
import os.path as osp
from abc import ABC
from typing import Any
from typing import Dict
from typing import Type

import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from omegaconf import OmegaConf

from ppmat.datasets import build_dataloader
from ppmat.datasets import set_signal_handlers
from ppmat.datasets.transform import run_dataset_transform
from ppmat.metrics import build_metric
from ppmat.models import build_model
from ppmat.optimizer import build_optimizer
from ppmat.trainer.base_trainer import BaseTrainer
from ppmat.utils import logger
from ppmat.utils import misc
from ppmat.utils.eager_comp_setting import setting_eager_mode


def read_independent_dataloader_config(config: Dict[str, Any]):
    """Build train/val/test dataloaders from independent Dataset sections.

    Which loaders get built is driven by the Global do_train/do_eval/do_test
    flags; the validation loader is also built during training when a val
    section exists, so in-training evaluation can run.
    """
    if config["Global"].get("do_train", True):
        train_data_cfg = config["Dataset"].get("train")
        assert train_data_cfg is not None, (
            "train_data_cfg must be defined when Global.do_train is True"
        )
        train_loader = build_dataloader(train_data_cfg)
    else:
        train_loader = None

    if config["Global"].get("do_eval", False) or config["Global"].get("do_train", True):
        val_data_cfg = config["Dataset"].get("val")
        if val_data_cfg is not None:
            val_loader = build_dataloader(val_data_cfg)
        else:
            logger.info("No validation dataset defined.")
            val_loader = None
    else:
        val_loader = None

    if config["Global"].get("do_test", False):
        test_data_cfg = config["Dataset"].get("test")
        assert test_data_cfg is not None, (
            "test_data_cfg must be defined when Global.do_test is True"
        )
        test_loader = build_dataloader(test_data_cfg)
    else:
        test_loader = None
    return train_loader, val_loader, test_loader


# Registry mapping lowercase case names to their BaseTrainCase subclasses.
TRAIN_CASE_REGISTRY: Dict[str, Type["BaseTrainCase"]] = {}


def register_train_case(cls: Type["BaseTrainCase"]) -> Type["BaseTrainCase"]:
    """Class decorator that registers a train case under its `case_name`."""
    case_name = cls.case_name.strip().lower()
    if not case_name:
        raise ValueError("Train case must define a non-empty `case_name`.")
    TRAIN_CASE_REGISTRY[case_name] = cls
    return cls


class BaseTrainCase(ABC):
    """Default training workflow; subclasses override hooks per case."""

    # Subclasses must set this to a non-empty registry key.
    case_name = ""

    def __init__(self, config: Dict[str, Any]):
        self.config = config

    def build_dataloaders(self):
        """Build loaders, either from a single split dataset or per-split cfgs."""
        set_signal_handlers()
        dataset_cfg = self.config.get("Dataset", {})
        if dataset_cfg.get("split_dataset_ratio") is not None:
            # One dataset split internally by ratio into train/val/test.
            loaders = build_dataloader(dataset_cfg)
            return (
                loaders.get("train", None),
                loaders.get("val", None),
                loaders.get("test", None),
            )
        return read_independent_dataloader_config(self.config)

    def _maybe_apply_dataset_transform(self, train_loader, model_cfg: Dict[str, Any]):
        """Run the configured dataset transform and inject its mean/std
        into the model init params. No-op when training is disabled or no
        transform is configured."""
        if not self.config["Global"].get("do_train", True):
            return
        dataset_trans_cfg = self.config.get("Dataset", {}).get("transform")
        if dataset_trans_cfg is None:
            return
        if train_loader is None:
            raise ValueError(
                "Dataset.transform is configured, but train_loader is None."
            )

        trans_cfg = copy.deepcopy(dataset_trans_cfg)
        trans_func = trans_cfg.pop("__class_name__", None)
        trans_params = trans_cfg.pop("__init_params__", {})
        if trans_func is None:
            raise KeyError("Dataset.transform.__class_name__ is required.")

        label_names = self.config.get("Global", {}).get("label_names")
        if label_names is None:
            raise KeyError(
                "Global.label_names is required when Dataset.transform is enabled."
            )

        logger.info(f"Using dataset transform function: {trans_func}")
        data_mean, data_std = run_dataset_transform(
            trans_func, train_loader, label_names, **trans_params
        )
        logger.info(
            f"Target is {label_names}, data mean is {data_mean}, data std is {data_std}"
        )

        # Hand the statistics to the model so it can normalize targets.
        model_cfg.setdefault("__init_params__", {})
        model_cfg["__init_params__"]["data_mean"] = data_mean
        model_cfg["__init_params__"]["data_std"] = data_std

    def build_model(self, train_loader, val_loader, test_loader):
        """Build the model, applying the optional dataset transform first."""
        model_cfg = copy.deepcopy(self.config["Model"])
        self._maybe_apply_dataset_transform(train_loader, model_cfg)
        return build_model(model_cfg)

    def build_optimizer(self, model, train_loader):
        """Build (optimizer, lr_scheduler); (None, None) when not training
        or no Optimizer section is configured."""
        if self.config.get("Optimizer") is not None and self.config["Global"].get(
            "do_train", True
        ):
            assert train_loader is not None, (
                "train_loader must be defined when Optimizer is provided."
            )
            assert self.config["Trainer"].get("max_epochs") is not None, (
                "Trainer.max_epochs must be defined when Optimizer is provided."
            )
            return build_optimizer(
                self.config["Optimizer"],
                model,
                self.config["Trainer"]["max_epochs"],
                len(train_loader),
            )
        return None, None

    def build_metric(self):
        """Build the metric dict when a Metric section exists, else None."""
        metric_cfg = self.config.get("Metric")
        return build_metric(metric_cfg) if metric_cfg is not None else None

    def build_trainer(
        self,
        model,
        train_loader,
        val_loader,
        optimizer,
        lr_scheduler,
        metric_func,
    ):
        """Assemble the BaseTrainer from the built components."""
        return BaseTrainer(
            self.config["Trainer"],
            model,
            train_dataloader=train_loader,
            val_dataloader=val_loader,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            compute_metric_func_dict=metric_func,
        )

    def post_build_trainer(self, trainer, model, train_loader, val_loader, test_loader):
        """Hook for subclasses to tweak the trainer after construction."""
        return

    def run(self, trainer, train_loader, val_loader, test_loader):
        """Execute the enabled phases: train, then eval on val/test sets."""
        if self.config["Global"].get("do_train", True):
            trainer.train()
        if self.config["Global"].get("do_eval", False):
            logger.info("Evaluating on validation set")
            trainer.eval(val_loader)
        if self.config["Global"].get("do_test", False):
            logger.info("Evaluating on test set")
            trainer.eval(test_loader)


@register_train_case
class SFINTrainCase(BaseTrainCase):
    """SFIN spectrum-enhancement training case; uses the default workflow."""

    case_name = "sfin"


class TrainRunner:
    """Loads config, sets up runtime, and drives a registered train case."""

    def __init__(
        self,
        case: str,
        config_path: str,
        dynamic_args: list[str],
        append_timestamp: bool = False,
    ):
        self.case = case.strip().lower()
        self.config_path = config_path
        self.dynamic_args = dynamic_args
        self.append_timestamp = append_timestamp

    def _build_case(self, config: Dict[str, Any]) -> BaseTrainCase:
        """Look up and instantiate the registered case for this runner."""
        case_cls = TRAIN_CASE_REGISTRY.get(self.case)
        if case_cls is None:
            available = ", ".join(sorted(TRAIN_CASE_REGISTRY.keys()))
            raise ValueError(
                f"Unsupported train case '{self.case}'. Available: [{available}]"
            )
        return case_cls(config)

    def _load_and_merge_config(self):
        """Load the YAML config, merge CLI dot-list overrides, and optionally
        suffix the output dir with a timestamp and seed."""
        cfg = OmegaConf.load(self.config_path)
        cli_cfg = OmegaConf.from_dotlist(self.dynamic_args)
        cfg = OmegaConf.merge(cfg, cli_cfg)

        if self.append_timestamp or cfg["Trainer"].get("append_timestamp", False):
            seed = cfg["Trainer"].get("seed", 42)
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            base_output_dir = cfg["Trainer"]["output_dir"]
            cfg["Trainer"]["output_dir"] = f"{base_output_dir}_t_{timestamp}_s_{seed}"
        return cfg

    def _save_config(self, cfg):
        """Persist the merged config into the output dir (rank 0 only)."""
        if dist.get_rank() == 0:
            os.makedirs(cfg["Trainer"]["output_dir"], exist_ok=True)
            config_name = os.path.basename(self.config_path)
            OmegaConf.save(cfg, osp.join(cfg["Trainer"]["output_dir"], config_name))

    @staticmethod
    def _setup_runtime(config: Dict[str, Any]):
        """Initialize logging, random seed, and eager/prim execution mode."""
        logger_path = osp.join(config["Trainer"]["output_dir"], "run.log")
        logger.init_logger(log_file=logger_path)
        logger.info(f"Logger saved to {logger_path}")

        seed = config["Trainer"].get("seed", 42)
        misc.set_random_seed(seed)
        logger.info(f"Set random seed to {seed}")

        enabled = config["Global"].get("prim_eager_enabled", False)
        white_list = config["Global"].get("prim_backward_white_list", None)
        setting_eager_mode(enabled, white_list)

    def run(self):
        """End-to-end flow: config -> runtime -> case -> components -> train."""
        cfg = self._load_and_merge_config()
        self._save_config(cfg)
        config = OmegaConf.to_container(cfg, resolve=True)

        self._setup_runtime(config)

        train_case = self._build_case(config)
        train_loader, val_loader, test_loader = train_case.build_dataloaders()
        model = train_case.build_model(train_loader, val_loader, test_loader)
        optimizer, lr_scheduler = train_case.build_optimizer(model, train_loader)
        metric_func = train_case.build_metric()
        trainer = train_case.build_trainer(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            metric_func=metric_func,
        )
        train_case.post_build_trainer(
            trainer,
            model,
            train_loader,
            val_loader,
            test_loader,
        )
        train_case.run(trainer, train_loader, val_loader, test_loader)


def parse_args():
    """Parse known CLI args; unknown args pass through as config overrides."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--case",
        type=str,
        default="sfin",
        help="Train case name. Extend by registering a new train case.",
    )
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default="./spectrum_enhancement/configs/sfin/sfin_tem_enhance.yaml",
        help="Path to config file.",
    )
    parser.add_argument(
        "--append_timestamp",
        action="store_true",
        help="Append timestamp to Trainer.output_dir.",
    )
    return parser.parse_known_args()


def main():
    """Training entry point; initializes fleet for multi-process runs."""
    if dist.get_world_size() > 1:
        fleet.init(is_collective=True)

    args, dynamic_args = parse_args()
    runner = TrainRunner(
        case=args.case,
        config_path=args.config,
        dynamic_args=dynamic_args,
        append_timestamp=args.append_timestamp,
    )
    runner.run()


if __name__ == "__main__":
    main()