diff --git a/README.md b/README.md index 79f1c309..9ea3704e 100755 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ ## ๐ Task - [MLIP-Machine Learning Interatomic Potential](interatomic_potentials/README.md) - [MLES-Machine Learning Electronic Structure](electronic_structure/README.md) +- [SPEN-Spectrum Enhancement](spectrum_enhancement/README.md) - [PP-Property Prediction](property_prediction/README.md) - [SG-Structure Generation](structure_generation/README.md) - [SE-Spectrum Elucidation](spectrum_elucidation/README.md) diff --git a/ppmat/datasets/__init__.py b/ppmat/datasets/__init__.py index 98eec451..564654c0 100644 --- a/ppmat/datasets/__init__.py +++ b/ppmat/datasets/__init__.py @@ -42,10 +42,11 @@ from ppmat.datasets.msd_nmr_dataset import MSDnmrDataset from ppmat.datasets.msd_nmr_dataset import MSDnmrinfos from ppmat.datasets.density_dataset import DensityDataset -from ppmat.datasets.small_density_dataset import SmallDensityDataset +from ppmat.datasets.small_density_dataset import SmallDensityDataset +from ppmat.datasets.stem_image_dataset import STEMImageDataset from ppmat.datasets.num_atom_crystal_dataset import NumAtomsCrystalDataset from ppmat.datasets.oc20_s2ef_dataset import OC20S2EFDataset # noqa -from ppmat.datasets.qm9_dataset import QM9Dataset # noqa +from ppmat.datasets.qm9_dataset import QM9Dataset # noqa from ppmat.datasets.omol25_dataset import OMol25Dataset from ppmat.datasets.split_mptrj_data import none_to_zero from ppmat.datasets.transform import build_transforms @@ -64,8 +65,9 @@ "HighLevelWaterDataset", "MSDnmrDataset", "MatbenchDataset", - "DensityDataset", + "DensityDataset", "SmallDensityDataset", + "STEMImageDataset", "OMol25Dataset", ] diff --git a/ppmat/datasets/build_stem.py b/ppmat/datasets/build_stem.py new file mode 100644 index 00000000..3560b8a4 --- /dev/null +++ b/ppmat/datasets/build_stem.py @@ -0,0 +1,297 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import copy +import importlib +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +from ppmat.utils import download as download_utils +from ppmat.utils import logger + + +def _locate_class(class_name: str): + if "." in class_name: + mod, cls = class_name.rsplit(".", 1) + return getattr(importlib.import_module(mod), cls) + return globals()[class_name] + + +def _parse_factory_cfg( + cfg: Optional[Dict[str, Any] | str], + *, + default_class_name: str, +) -> Tuple[str, Dict[str, Any]]: + """ + Parse factory config in a backward/forward-compatible way. + + Supported patterns: + 1) None + -> use default class with empty params + 2) "ClassName" / "pkg.mod.ClassName" + -> use class name string with empty params + 3) {"__class_name__": "...", "__init_params__": {...}} + -> current canonical style (same as many ppmat builders) + 4) {"class_name": "...", "init_params": {...}} + -> compatibility alias + 5) {"type": "...", "params": {...}} + -> compatibility alias + """ + if cfg is None: + return default_class_name, {} + + if isinstance(cfg, str): + return cfg, {} + + if not isinstance(cfg, dict): + raise TypeError( + "Factory cfg must be None, str, or dict, " + f"but got type={type(cfg).__name__}." 
+ ) + + cfg = copy.deepcopy(cfg) + class_name = ( + cfg.pop("__class_name__", None) + or cfg.pop("class_name", None) + or cfg.pop("type", None) + ) + if not class_name: + raise ValueError( + "Factory cfg must include class name key, e.g. " + "{'__class_name__': 'StrictIndexSampleBuilder', '__init_params__': {...}}." + ) + + init_params = ( + cfg.pop("__init_params__", None) + if "__init_params__" in cfg + else cfg.pop("init_params", None) + if "init_params" in cfg + else cfg.pop("params", None) + if "params" in cfg + else {} + ) + if init_params is None: + init_params = {} + if not isinstance(init_params, dict): + raise TypeError( + f"Factory init params must be dict, but got type={type(init_params).__name__}." + ) + + if cfg: + raise ValueError( + f"Unsupported keys in factory cfg for '{class_name}': {list(cfg.keys())}" + ) + return class_name, init_params + + +def _build_component( + cfg: Optional[Dict[str, Any] | str], + *, + default_class_name: str, + required_methods: List[str], +): + class_name, init_params = _parse_factory_cfg( + cfg, + default_class_name=default_class_name, + ) + + cls = _locate_class(class_name) + component = cls(**init_params) + for method_name in required_methods: + if not hasattr(component, method_name): + raise TypeError( + f"Component '{class_name}' must implement method '{method_name}'." 
+ ) + return component + + +class StrictIndexSampleBuilder: + def build( + self, + noisy_dir: Path, + target_dir: Path, + file_suffix: str, + data_count: Optional[int] = None, + ) -> List[Dict[str, str]]: + samples: List[Dict[str, str]] = [] + if data_count is None: + available = sorted(noisy_dir.glob(f"*{file_suffix}")) + data_count = len(available) + + for idx in range(int(data_count)): + name = f"{idx}{file_suffix}" + noisy_path = noisy_dir / name + target_path = target_dir / name + if not noisy_path.exists(): + raise FileNotFoundError(f"Noisy image not found: {noisy_path}") + if not target_path.exists(): + raise FileNotFoundError(f"Target image not found: {target_path}") + samples.append( + { + "name": name, + "noisy_path": str(noisy_path), + "target_path": str(target_path), + } + ) + return samples + + +class MatchedNameSampleBuilder: + def build( + self, + noisy_dir: Path, + target_dir: Path, + file_suffix: str, + data_count: Optional[int] = None, + ) -> List[Dict[str, str]]: + samples: List[Dict[str, str]] = [] + noisy_files = { + p.name: p for p in noisy_dir.glob(f"*{file_suffix}") if p.is_file() + } + target_files = { + p.name: p for p in target_dir.glob(f"*{file_suffix}") if p.is_file() + } + common_names = sorted(set(noisy_files.keys()) & set(target_files.keys())) + if data_count is not None: + common_names = common_names[: int(data_count)] + + for name in common_names: + samples.append( + { + "name": name, + "noisy_path": str(noisy_files[name]), + "target_path": str(target_files[name]), + } + ) + return samples + + +class DefaultSTEMDatasetDownloader: + def __init__(self, datasets_home: Optional[str] = None): + self.datasets_home = datasets_home or download_utils.DATASETS_HOME + + def download( + self, url: str, md5: Optional[str] = None, force_download: bool = False + ) -> Path: + if force_download: + downloaded_root = download_utils.get_path_from_url( + url, + self.datasets_home, + md5sum=md5, + check_exist=False, + decompress=True, + ) + else: + 
downloaded_root = download_utils.get_datasets_path_from_url(url, md5) + return Path(downloaded_root) + + +class PairDirectoryDataRootResolver: + def __init__(self, max_depth: int = 2): + if max_depth < 0: + raise ValueError(f"max_depth must be >= 0, but got {max_depth}") + self.max_depth = int(max_depth) + + @staticmethod + def _contains_pair_dirs(root: Path, noisy_subdir: str, target_subdir: str) -> bool: + return ( + root.is_dir() + and (root / noisy_subdir).exists() + and (root / target_subdir).exists() + ) + + def find_data_roots( + self, + base_root: Path, + split: Optional[str], + noisy_subdir: str, + target_subdir: str, + ) -> List[Path]: + if not base_root.exists(): + return [] + + candidate_roots: List[Path] = [base_root] + frontier: List[Path] = [base_root] + for _ in range(self.max_depth): + next_frontier: List[Path] = [] + for root in frontier: + for child in root.iterdir(): + if child.is_dir(): + candidate_roots.append(child) + next_frontier.append(child) + frontier = next_frontier + + matches: List[Path] = [] + for root in candidate_roots: + if split is not None: + split_root = root / split + if self._contains_pair_dirs(split_root, noisy_subdir, target_subdir): + matches.append(split_root) + if self._contains_pair_dirs(root, noisy_subdir, target_subdir): + matches.append(root) + + seen = set() + unique_matches = [] + for path in matches: + path_str = str(path) + if path_str in seen: + continue + seen.add(path_str) + unique_matches.append(path) + return unique_matches + + +def build_stem_sample_builder( + cfg: Optional[Dict[str, Any] | str], + *, + strict_index_naming: bool, +): + default_class_name = ( + "StrictIndexSampleBuilder" + if strict_index_naming + else "MatchedNameSampleBuilder" + ) + sample_builder = _build_component( + cfg, + default_class_name=default_class_name, + required_methods=["build"], + ) + logger.debug(f"Use sample builder: {sample_builder.__class__.__name__}") + return sample_builder + + +def build_stem_downloader(cfg: 
Optional[Dict[str, Any] | str]): + downloader = _build_component( + cfg, + default_class_name="DefaultSTEMDatasetDownloader", + required_methods=["download"], + ) + logger.debug(f"Use downloader: {downloader.__class__.__name__}") + return downloader + + +def build_stem_data_root_resolver(cfg: Optional[Dict[str, Any] | str]): + resolver = _build_component( + cfg, + default_class_name="PairDirectoryDataRootResolver", + required_methods=["find_data_roots"], + ) + logger.debug(f"Use data root resolver: {resolver.__class__.__name__}") + return resolver diff --git a/ppmat/datasets/stem_image_dataset.py b/ppmat/datasets/stem_image_dataset.py new file mode 100644 index 00000000..88d530de --- /dev/null +++ b/ppmat/datasets/stem_image_dataset.py @@ -0,0 +1,296 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from pathlib import Path +from typing import Dict +from typing import List +from typing import Optional + +import numpy as np +import paddle +from PIL import Image + +from ppmat.datasets.build_stem import build_stem_data_root_resolver +from ppmat.datasets.build_stem import build_stem_downloader +from ppmat.datasets.build_stem import build_stem_sample_builder +from ppmat.utils import logger + + +class STEMImageDataset(paddle.io.Dataset): + """Dataset for paired STEM image restoration/enhancement. 
+ + Supports automatic download and extraction (zip/tar/tar.gz) through + ``ppmat.utils.download.get_datasets_path_from_url``. + + Expected directory layout after extraction: + data_path/ + train/ + noisy/ + 0.png + 1.png + ... + gt_enhance/ + 0.png + 1.png + ... + val/ + noisy/ + ... + gt_enhance/ + ... + test/ + noisy/ + ... + gt_enhance/ + ... + + Or legacy format (backward compatible): + data_path/ + noisy/ + 0.png + ... + gt_enhance/ + 0.png + ... + """ + + name = "stem_enhancement" + url = None + md5 = None + _DEFAULT_URL_MAP = { + "data": "https://paddle-org.bj.bcebos.com/paddlematerials/datasets/SFIN_datasets/haadf_data.zip", + "data_test": "https://paddle-org.bj.bcebos.com/paddlematerials/datasets/SFIN_datasets/haadf_data_test.zip", + "bf_data": "https://paddle-org.bj.bcebos.com/paddlematerials/datasets/SFIN_datasets/bf_data.zip", + "bf_data_test": "https://paddle-org.bj.bcebos.com/paddlematerials/datasets/SFIN_datasets/bf_data_test.zip", + } + _DEFAULT_MD5_MAP: Dict[str, str] = {} + + def __init__( + self, + data_path: str, + split: Optional[str] = None, + data_count: int | None = None, + noisy_subdir: str = "noisy", + target_subdir: str = "gt_enhance", + file_suffix: str = ".png", + strict_index_naming: bool = True, + sample_builder_cfg: Optional[Dict] = None, + downloader_cfg: Optional[Dict] = None, + data_root_resolver_cfg: Optional[Dict] = None, + scale_to_unit: bool = False, + url: Optional[str] = None, + md5: Optional[str] = None, + download: bool = True, + force_download: bool = False, + ): + """Initialize STEMImageDataset. + + Args: + data_path: Root directory for the dataset. + split: Dataset split, one of 'train', 'val', 'test', or None. + If None, uses legacy format without split subdirectories. + data_count: Maximum number of samples to load. None means all. + noisy_subdir: Subdirectory name for noisy images. + target_subdir: Subdirectory name for target/ground truth images. + file_suffix: File extension for images (e.g., '.png', '.tif'). 
+ strict_index_naming: If True, expects files named as {idx}{suffix}. + sample_builder_cfg: Sample builder config in factory style: + { + "__class_name__": "StrictIndexSampleBuilder", + "__init_params__": {} + } + downloader_cfg: Downloader config in factory style. + Default class is "DefaultSTEMDatasetDownloader". + data_root_resolver_cfg: Data root resolver config in factory style. + Default class is "PairDirectoryDataRootResolver". + scale_to_unit: If True, scales pixel values to [0, 1]. + url: URL to download the dataset from. Overrides default URL. + md5: MD5 checksum for downloaded file. Optional. + download: Whether to automatically download if data not found. + force_download: If True, re-download even if data exists. + """ + super().__init__() + + self.split = split + self.data_count = data_count + self.noisy_subdir = noisy_subdir + self.target_subdir = target_subdir + self.file_suffix = file_suffix + self.strict_index_naming = strict_index_naming + self.scale_to_unit = scale_to_unit + self.sample_builder = build_stem_sample_builder( + sample_builder_cfg, + strict_index_naming=strict_index_naming, + ) + self.downloader = build_stem_downloader(downloader_cfg) + self.data_root_resolver = build_stem_data_root_resolver(data_root_resolver_cfg) + + self.url = url if url is not None else self._infer_default_url(data_path) + self.md5 = ( + md5 if md5 is not None else self._infer_default_md5(data_path) or self.md5 + ) + self.data_root = Path(data_path) + self.downloaded_root: Optional[Path] = None + + if self._locate_data_root(self.data_root) is None and ( + download or force_download + ): + self.downloaded_root = self._download_dataset(force_download) + + # Determine actual data directory based on split + self.data_dir = self._resolve_data_dir() + + # Set up noisy and target directories + self.noisy_dir = self.data_dir / noisy_subdir + self.target_dir = self.data_dir / target_subdir + + if not self.noisy_dir.exists(): + raise FileNotFoundError(f"Noisy directory 
not found: {self.noisy_dir}") + if not self.target_dir.exists(): + raise FileNotFoundError(f"Target directory not found: {self.target_dir}") + + self.samples = self._build_samples(data_count) + + def _resolve_data_dir(self) -> Path: + """Resolve the actual data directory based on split configuration.""" + candidate_roots = [self.data_root] + if self.downloaded_root is not None: + for root in [self.downloaded_root, self.downloaded_root.parent]: + if root != self.data_root and root not in candidate_roots: + candidate_roots.append(root) + + for candidate_root in candidate_roots: + matches = self._find_data_roots(candidate_root) + if not matches: + continue + if ( + self.downloaded_root is not None + and candidate_root == self.downloaded_root.parent + and len(matches) > 1 + ): + raise FileNotFoundError( + "Multiple candidate dataset roots were found under " + f"'{candidate_root}': {[str(m) for m in matches]}. " + "Please provide a more specific local `data_path` or explicit `url`." + ) + + data_root_candidate = matches[0] + if self.split is not None and data_root_candidate == candidate_root: + logger.warning( + f"Split '{self.split}' requested but legacy format detected. " + f"Using data directly from {candidate_root}" + ) + return data_root_candidate + + searched_roots = ", ".join([str(path) for path in candidate_roots]) + if self.split is not None: + raise FileNotFoundError( + f"Split '{self.split}' not found under: {searched_roots}" + ) + raise FileNotFoundError( + "Cannot locate dataset directories " + f"'{self.noisy_subdir}' and '{self.target_subdir}' under: {searched_roots}" + ) + + def _download_dataset(self, force_download: bool = False) -> Path: + """Download dataset with built-in ppmat factory utility.""" + if not self.url: + candidate = ", ".join(sorted(self._DEFAULT_URL_MAP.keys())) + raise FileNotFoundError( + f"Dataset not found at '{self.data_root}', and no download URL provided. " + f"Auto-url is only inferred for data_path basename in [{candidate}]." 
+ ) + + logger.message( + f"Dataset root {self.data_root} not found. Will download it now." + ) + downloaded_root = self.downloader.download( + self.url, + self.md5, + force_download=force_download, + ) + logger.info(f"Dataset downloaded to: {downloaded_root}") + return Path(downloaded_root) + + @classmethod + def _infer_default_url(cls, data_path: str) -> Optional[str]: + key = Path(data_path).name + url = cls._DEFAULT_URL_MAP.get(key) + if url is not None: + logger.info( + f"Infer dataset download URL by data_path='{data_path}': {url}" + ) + return url + + @classmethod + def _infer_default_md5(cls, data_path: str) -> Optional[str]: + key = Path(data_path).name + md5 = cls._DEFAULT_MD5_MAP.get(key) + if md5 is not None: + logger.info( + f"Infer dataset md5 by data_path='{data_path}': {md5}" + ) + return md5 + + def _locate_data_root(self, base_root: Path) -> Optional[Path]: + matches = self._find_data_roots(base_root) + if not matches: + return None + return matches[0] + + def _find_data_roots(self, base_root: Path) -> List[Path]: + return self.data_root_resolver.find_data_roots( + base_root=base_root, + split=self.split, + noisy_subdir=self.noisy_subdir, + target_subdir=self.target_subdir, + ) + + def _build_samples(self, data_count: int | None) -> List[Dict[str, str]]: + """Build list of sample dictionaries.""" + return self.sample_builder.build( + noisy_dir=self.noisy_dir, + target_dir=self.target_dir, + file_suffix=self.file_suffix, + data_count=data_count, + ) + + def __len__(self): + return len(self.samples) + + def _load_gray_image(self, path: str) -> paddle.Tensor: + """Load image as grayscale tensor.""" + image = Image.open(path).convert("L") + image_array = np.asarray(image, dtype=np.float32) + if self.scale_to_unit: + image_array = image_array / 255.0 + return paddle.to_tensor(image_array).unsqueeze(0) + + def __getitem__(self, idx: int): + sample = self.samples[idx] + noisy = self._load_gray_image(sample["noisy_path"]) + target = 
self._load_gray_image(sample["target_path"]) + + output = { + "noisy": noisy, + self.target_subdir: target, + "target": target, + "name": sample["name"], + } + # Backward compatibility for legacy code paths that read `gt_enhance`. + if self.target_subdir != "gt_enhance": + output["gt_enhance"] = target + return output diff --git a/ppmat/metrics/__init__.py b/ppmat/metrics/__init__.py index a0e3fb75..a304809f 100644 --- a/ppmat/metrics/__init__.py +++ b/ppmat/metrics/__init__.py @@ -18,11 +18,19 @@ from ppmat.metrics.csp_metric import CSPMetric from ppmat.metrics.diffnmr_streaming_adapter import DiffNMRStreamingAdapter +from ppmat.metrics.sfin_metric import PSNRMetric +from ppmat.metrics.sfin_metric import SSIMMetric +from ppmat.metrics.sfin_metric import calc_psnr +from ppmat.metrics.sfin_metric import calc_ssim __all__ = [ "build_metric", "CSPMetric", "DiffNMRStreamingAdapter", + "PSNRMetric", + "SSIMMetric", + "calc_psnr", + "calc_ssim", # "DiffNMRMetric", # "NLL", "CrossEntropyMetric", "SumExceptBatchMetric", "SumExceptBatchKL", ] diff --git a/ppmat/metrics/sfin_metric.py b/ppmat/metrics/sfin_metric.py new file mode 100644 index 00000000..3a26c9f5 --- /dev/null +++ b/ppmat/metrics/sfin_metric.py @@ -0,0 +1,154 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import paddle + + +def calc_psnr( + pred: paddle.Tensor, + label: paddle.Tensor, + data_range: float = 255.0, + eps: float = 1e-12, +) -> paddle.Tensor: + """Compute batch PSNR for image tensors with shape [N, C, H, W].""" + pred = pred.astype("float64") + label = label.astype("float64") + diff = (pred - label) / data_range + mse = paddle.mean(diff * diff) + mse = paddle.maximum(mse, paddle.to_tensor(eps, dtype=mse.dtype)) + return -10.0 * paddle.log10(mse) + + +def _gaussian_window_2d( + channels: int, + win_size: int = 11, + win_sigma: float = 1.5, + dtype: str = "float32", +) -> paddle.Tensor: + coords = paddle.arange(win_size, dtype=dtype) - (win_size // 2) + gauss = paddle.exp(-(coords**2) / (2.0 * (win_sigma**2))) + gauss = gauss / paddle.sum(gauss) + window_2d = gauss.unsqueeze(1) * gauss.unsqueeze(0) + window = window_2d.reshape([1, 1, win_size, win_size]) + return paddle.tile(window, [channels, 1, 1, 1]) + + +def calc_ssim( + pred: paddle.Tensor, + label: paddle.Tensor, + data_range: float = 255.0, + win_size: int = 11, + win_sigma: float = 1.5, + k1: float = 0.01, + k2: float = 0.03, + nonnegative_ssim: bool = False, +) -> paddle.Tensor: + """Compute batch SSIM for image tensors with shape [N, C, H, W].""" + if pred.shape != label.shape: + raise ValueError( + f"Input images should have the same dimensions, got {pred.shape} and {label.shape}." + ) + if len(pred.shape) != 4: + raise ValueError( + f"Input images should be 4-d tensors [N, C, H, W], got shape {pred.shape}." 
+ ) + if win_size % 2 != 1: + raise ValueError("win_size must be odd.") + + pred = pred.astype("float32") + label = label.astype("float32") + + channels = pred.shape[1] + window = _gaussian_window_2d( + channels=channels, + win_size=win_size, + win_sigma=win_sigma, + dtype=pred.dtype, + ) + + mu1 = paddle.nn.functional.conv2d(pred, window, stride=1, padding=0, groups=channels) + mu2 = paddle.nn.functional.conv2d(label, window, stride=1, padding=0, groups=channels) + + mu1_sq = mu1 * mu1 + mu2_sq = mu2 * mu2 + mu1_mu2 = mu1 * mu2 + + sigma1_sq = paddle.nn.functional.conv2d( + pred * pred, window, stride=1, padding=0, groups=channels + ) - mu1_sq + sigma2_sq = paddle.nn.functional.conv2d( + label * label, window, stride=1, padding=0, groups=channels + ) - mu2_sq + sigma12 = paddle.nn.functional.conv2d( + pred * label, window, stride=1, padding=0, groups=channels + ) - mu1_mu2 + + c1 = (k1 * data_range) ** 2 + c2 = (k2 * data_range) ** 2 + c1 = paddle.to_tensor(c1, dtype=pred.dtype) + c2 = paddle.to_tensor(c2, dtype=pred.dtype) + + cs_map = (2.0 * sigma12 + c2) / (sigma1_sq + sigma2_sq + c2) + ssim_map = ((2.0 * mu1_mu2 + c1) / (mu1_sq + mu2_sq + c1)) * cs_map + + if nonnegative_ssim: + ssim_map = paddle.nn.functional.relu(ssim_map) + + return paddle.mean(ssim_map) + + +class PSNRMetric: + def __init__(self, data_range: float = 255.0, eps: float = 1e-12): + self.data_range = data_range + self.eps = eps + + def __call__(self, pred: paddle.Tensor, label: paddle.Tensor): + return calc_psnr( + pred=pred, + label=label, + data_range=self.data_range, + eps=self.eps, + ) + + +class SSIMMetric: + def __init__( + self, + data_range: float = 255.0, + win_size: int = 11, + win_sigma: float = 1.5, + k1: float = 0.01, + k2: float = 0.03, + nonnegative_ssim: bool = False, + ): + self.data_range = data_range + self.win_size = win_size + self.win_sigma = win_sigma + self.k1 = k1 + self.k2 = k2 + self.nonnegative_ssim = nonnegative_ssim + + def __call__(self, pred: paddle.Tensor, 
label: paddle.Tensor): + return calc_ssim( + pred=pred, + label=label, + data_range=self.data_range, + win_size=self.win_size, + win_sigma=self.win_sigma, + k1=self.k1, + k2=self.k2, + nonnegative_ssim=self.nonnegative_ssim, + ) diff --git a/ppmat/models/__init__.py b/ppmat/models/__init__.py index 95d73232..130704ec 100644 --- a/ppmat/models/__init__.py +++ b/ppmat/models/__init__.py @@ -42,6 +42,7 @@ from ppmat.models.megnet.megnet import MEGNetPlus from ppmat.models.infgcn.infgcn import InfGCN from ppmat.models.mateno.mateno import MatENO +from ppmat.models.sfin.sfin import SFIN from ppmat.utils import download from ppmat.utils import logger from ppmat.utils import save_load @@ -67,6 +68,7 @@ "DiffNMR", "InfGCN", "MatENO", + "SFIN", ] # Warning: The key of the dictionary must be consistent with the file name of the value diff --git a/ppmat/models/sfin/__init__.py b/ppmat/models/sfin/__init__.py new file mode 100644 index 00000000..d890dba7 --- /dev/null +++ b/ppmat/models/sfin/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ppmat.models.sfin.sfin import SFIN, FourierUnit, SpectralTransform, FFC, SFIB, ResnetBlock + +__all__ = [ + "SFIN", + "FourierUnit", + "SpectralTransform", + "FFC", + "SFIB", + "ResnetBlock", +] diff --git a/ppmat/models/sfin/sfin.py b/ppmat/models/sfin/sfin.py new file mode 100644 index 00000000..03bf2823 --- /dev/null +++ b/ppmat/models/sfin/sfin.py @@ -0,0 +1,390 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +SFIN: Noise Calibration and Spatial-Frequency Interactive Network for STEM Image Enhancement +Paper: CVPR 2025 - https://arxiv.org/pdf/2504.02555 +""" + +from typing import Dict + +import paddle +import paddle.nn as nn + +# BatchNorm semantic alignment: +# PyTorch: running = (1 - m_torch) * running + m_torch * batch, default m_torch=0.1 +# Paddle: running = m_paddle * running + (1 - m_paddle) * batch +# so m_paddle = 1 - m_torch = 0.9 +TORCH_BN_MOMENTUM = 0.1 +PADDLE_BN_MOMENTUM = 1.0 - TORCH_BN_MOMENTUM +BN_EPSILON = 1e-5 + + +def _bn_aligned(num_features: int) -> nn.BatchNorm2D: + """Create BatchNorm2D with PyTorch-aligned momentum semantics.""" + return nn.BatchNorm2D(num_features, momentum=PADDLE_BN_MOMENTUM, epsilon=BN_EPSILON) + + +class FourierUnit(nn.Layer): + """Fourier Unit for processing frequency domain features.""" + + def __init__(self, in_channels: int, out_channels: int): + super(FourierUnit, self).__init__() + fu_bound = 1.0 / ((out_channels * 2) * 1 * 1) ** 0.5 + 
+ self.conv_layer = nn.Conv2D( + in_channels=in_channels * 2 + 2, + out_channels=out_channels * 2, + kernel_size=1, + stride=1, + padding=0, + bias_attr=False, + weight_attr=paddle.nn.initializer.KaimingUniform( + negative_slope=5**0.5, # a=sqrt(5) + mode='fan_in', + nonlinearity='leaky_relu' + ) + ) + self.bn = _bn_aligned(out_channels * 2) + self.relu = nn.ReLU() + + def forward(self, x): + batch = x.shape[0] + fft_dim = (-2, -1) + + # Real FFT with ortho normalization + ffted = paddle.fft.rfftn(x, axes=fft_dim, norm='ortho') # (B, C, H, W/2+1) complex + + # Split into real/imaginary parts + ffted_real = paddle.real(ffted) # (B, C, H, W/2+1) + ffted_imag = paddle.imag(ffted) # (B, C, H, W/2+1) + ffted = paddle.stack([ffted_real, ffted_imag], axis=-1) # (B, C, H, W/2+1, 2) + + # Permute to (B, C, 2, H, W/2+1) + ffted = ffted.transpose([0, 1, 4, 2, 3]) # (B, C, 2, H, W/2+1) + ffted = ffted.reshape([batch, -1] + list(ffted.shape[3:])) # (B, C*2, H, W/2+1) + + # Create coordinate grids with DYNAMIC batch handling (critical fix) + height, width = ffted.shape[-2:] + coords_vert = paddle.linspace(0, 1, height).reshape([1, 1, height, 1]) + # SAFE EXPAND: Use x.shape[0] instead of captured 'batch' variable + coords_vert = coords_vert.expand([x.shape[0], 1, height, width]) + + coords_hor = paddle.linspace(0, 1, width).reshape([1, 1, 1, width]) + coords_hor = coords_hor.expand([x.shape[0], 1, height, width]) + + # Concatenate coordinates and FFT features + ffted = paddle.concat([coords_vert, coords_hor, ffted], axis=1) # (B, C*2+2, H, W/2+1) + + # Process through convolution + ffted = self.conv_layer(ffted) + ffted = self.relu(self.bn(ffted)) # (B, C*2, H, W/2+1) + + # Reshape back to complex format: (B, C, 2, H, W/2+1) โ (B, C, H, W/2+1, 2) + ffted = ffted.reshape([batch, -1, 2] + list(ffted.shape[2:])) # (B, C, 2, H, W/2+1) + ffted = ffted.transpose([0, 1, 3, 4, 2]) # (B, C, H, W/2+1, 2) + + # Convert back to complex tensor + ffted = paddle.complex(ffted[..., 0], 
ffted[..., 1]) # (B, C, H, W/2+1) complex + + # Inverse FFT with exact shape matching + output = paddle.fft.irfftn(ffted, s=x.shape[-2:], axes=fft_dim, norm='ortho') + return output + + +class SpectralTransform(nn.Layer): + """Spectral Transform block combining spatial and frequency domain processing.""" + + def __init__(self, in_channels: int): + super(SpectralTransform, self).__init__() + st1_fan_in = (in_channels // 2) * 3 * 3 + st1_bias_bound = 1.0 / st1_fan_in ** 0.5 + + st2_fan_in = in_channels * 3 * 3 # Input: [x (C/2), x2 (C/2)] โ C channels + st2_bias_bound = 1.0 / st2_fan_in ** 0.5 + + self.conv1 = nn.Conv2D( + in_channels // 2, in_channels // 2, 3, padding=1, + weight_attr=nn.initializer.KaimingUniform( + negative_slope=5**0.5, + mode='fan_in', + nonlinearity='leaky_relu' + ), + bias_attr=nn.initializer.Uniform(-st1_bias_bound, st1_bias_bound) + ) + self.fu = FourierUnit(in_channels // 2, in_channels // 2) + self.conv2 = nn.Conv2D( + in_channels, in_channels // 2, 3, padding=1, + weight_attr=nn.initializer.KaimingUniform( + negative_slope=5**0.5, + mode='fan_in', + nonlinearity='leaky_relu' + ), + bias_attr=nn.initializer.Uniform(-st2_bias_bound, st2_bias_bound) + ) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.fu(x1) + x = self.conv2(paddle.concat([x, x2], axis=1)) + return x + + +class FFC(nn.Layer): + """Fast Fourier Convolution block for spatial-frequency interaction.""" + + def __init__(self, in_channels: int): + super(FFC, self).__init__() + ffc_fan_in = (in_channels // 2) * 3 * 3 + ffc_bias_bound = 1.0 / ffc_fan_in ** 0.5 + + self.convl2l = nn.Conv2D( + in_channels // 2, in_channels // 2, 3, padding=1, + weight_attr=nn.initializer.KaimingUniform( + negative_slope=5**0.5, + mode='fan_in', + nonlinearity='leaky_relu' + ), + bias_attr=nn.initializer.Uniform(-ffc_bias_bound, ffc_bias_bound) + ) + self.convl2g = nn.Conv2D( + in_channels // 2, in_channels // 2, 3, padding=1, + weight_attr=nn.initializer.KaimingUniform( + 
negative_slope=5**0.5, + mode='fan_in', + nonlinearity='leaky_relu' + ), + bias_attr=nn.initializer.Uniform(-ffc_bias_bound, ffc_bias_bound) + ) + self.convg2l = nn.Conv2D( + in_channels // 2, in_channels // 2, 3, padding=1, + weight_attr=nn.initializer.KaimingUniform( + negative_slope=5**0.5, + mode='fan_in', + nonlinearity='leaky_relu' + ), + bias_attr=nn.initializer.Uniform(-ffc_bias_bound, ffc_bias_bound) + ) + self.convg2g = SpectralTransform(in_channels) + + def forward(self, x): + if isinstance(x, tuple): + x_l, x_g = x + else: + B, C, H, W = x.shape + x_l = x + # Must be C//2 channels to match local branch split later + x_g = paddle.zeros([B, C // 2, H, W], dtype=x.dtype) + + out_xl = self.convl2l(x_l) + self.convg2l(x_g) + out_xg = self.convl2g(x_l) + self.convg2g(x_g) + return out_xl, out_xg + + +class SFIB(nn.Layer): + """Spatial-Frequency Interactive Block.""" + + def __init__(self, in_channels: int): + super(SFIB, self).__init__() + self.ffc = FFC(in_channels) + self.bn_l = _bn_aligned(in_channels // 2) + self.bn_g = _bn_aligned(in_channels // 2) + self.act_l = nn.ReLU() + self.act_g = nn.ReLU() + + def forward(self, x): + x_l, x_g = self.ffc(x) + x_l = self.act_l(self.bn_l(x_l)) + x_g = self.act_g(self.bn_g(x_g)) + return x_l, x_g + + +class ResnetBlock(nn.Layer): + """Residual block with SFIB.""" + + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + self.conv1 = SFIB(in_channels) + self.conv2 = SFIB(in_channels) + + def forward(self, x): + x_l, x_g = paddle.split(x, [self.in_channels // 2, self.in_channels // 2], axis=1) + id_l, id_g = x_l, x_g + # Apply two SFIB blocks with residual connection + x_l, x_g = self.conv1((x_l, x_g)) + x_l, x_g = self.conv2((x_l, x_g)) + + # Residual connection + x_l = id_l + x_l + x_g = id_g + x_g + + # Recombine branches + out = paddle.concat([x_l, x_g], axis=1) + return out + + +class SFIN(nn.Layer): + """ + SFIN: Noise Calibration and Spatial-Frequency Interactive Network 
for STEM Image Enhancement. + + Args: + in_channels (int): Number of input channels (default: 1 for grayscale images) + base_channels (int): Base number of channels (default: 64) + num_blocks (int): Number of ResNet blocks (default: 8) + + Reference: + Li et al., "Noise Calibration and Spatial-Frequency Interactive Network for + STEM Image Enhancement", CVPR 2025. + https://arxiv.org/pdf/2504.02555 + """ + + def __init__( + self, + in_channels: int = 1, + base_channels: int = 64, + num_blocks: int = 8, + input_name: str = "noisy", + target_name: str = "gt_enhance", + loss_type: str = "l1", + loss_weight: float = 1.0, + ): + super(SFIN, self).__init__() + self.in_channels = in_channels + self.base_channels = base_channels + self.num_blocks = num_blocks + self.input_name = input_name + self.target_name = target_name + self.loss_type = loss_type.lower() + self.loss_weight = loss_weight + + if self.loss_type == "l1": + self.criterion = nn.L1Loss() + elif self.loss_type == "mse": + self.criterion = nn.MSELoss() + else: + raise ValueError(f"Unsupported loss_type '{loss_type}', expected 'l1' or 'mse'.") + + # Build ResNet blocks with proper registration + blocks = [ResnetBlock(base_channels) for _ in range(num_blocks)] + self.body = nn.Sequential(*blocks) + + # Head convolution initialization + head_fan_in = in_channels * 3 * 3 + head_bias_bound = 1.0 / head_fan_in ** 0.5 + self.head_conv = nn.Conv2D( + in_channels, base_channels, 3, padding=1, + weight_attr=nn.initializer.KaimingUniform( + negative_slope=5**0.5, + mode='fan_in', + nonlinearity='leaky_relu' + ), + bias_attr=nn.initializer.Uniform(-head_bias_bound, head_bias_bound) + ) + + # Tail convolution initialization + tail_fan_in = base_channels * 3 * 3 + tail_bias_bound = 1.0 / tail_fan_in ** 0.5 + self.tail_conv = nn.Conv2D( + base_channels, in_channels, 3, padding=1, + weight_attr=nn.initializer.KaimingUniform( + negative_slope=5**0.5, + mode='fan_in', + nonlinearity='leaky_relu' + ), + 
bias_attr=nn.initializer.Uniform(-tail_bias_bound, tail_bias_bound) + ) + + def _forward_tensor(self, x: paddle.Tensor) -> paddle.Tensor: + """ + Tensor-only forward pass of SFIN. + + Args: + x: Input tensor of shape (B, C, H, W) + + Returns: + Enhanced image tensor of shape (B, C, H, W) + """ + x = self.head_conv(x) + shortcut = x + x = self.body(x) + x = x + shortcut + x = self.tail_conv(x) + return x + + def _get_input_tensor(self, batch: Dict) -> paddle.Tensor: + key_candidates = [self.input_name, "image", "noisy", "input", "x"] + for key in key_candidates: + if key in batch and batch[key] is not None: + return batch[key] + raise KeyError( + f"SFIN expects one of input keys {key_candidates}, but got keys: {list(batch.keys())}" + ) + + def _get_label_tensor(self, batch: Dict): + key_candidates = [ + self.target_name, + "gt_enhance", + "target", + "label", + "gt", + "clean", + "y", + ] + for key in key_candidates: + if key in batch and batch[key] is not None: + return batch[key] + return None + + def forward(self, batch): + """ + Unified forward for both: + 1) tensor -> enhanced tensor (for direct use / legacy scripts) + 2) dict -> trainer-ready output with loss_dict and pred_dict + """ + if isinstance(batch, dict): + x = self._get_input_tensor(batch) + enhanced = self._forward_tensor(x) + + pred_dict = { + self.target_name: enhanced, + "pred": enhanced, + } + loss_dict = {} + + label = self._get_label_tensor(batch) + if label is not None: + loss = self.criterion(enhanced, label) * self.loss_weight + loss_dict["loss"] = loss + + return {"loss_dict": loss_dict, "pred_dict": pred_dict} + + return self._forward_tensor(batch) + + def predict(self, batch: Dict) -> Dict: + """ + Prediction interface for BasePredictor. 
+ + Args: + batch: Dictionary containing 'image' key with input tensor + + Returns: + Dictionary containing 'pred' key with enhanced image + """ + if isinstance(batch, dict): + x = self._get_input_tensor(batch) + enhanced = self._forward_tensor(x) + return {self.target_name: enhanced, "pred": enhanced} + + return self._forward_tensor(batch) diff --git a/spectrum_enhancement/README.md b/spectrum_enhancement/README.md new file mode 100644 index 00000000..333be039 --- /dev/null +++ b/spectrum_enhancement/README.md @@ -0,0 +1,27 @@ +# SPEN-Spectrum Enhancement + +## 1.Introduction + +Spectrum Enhancement (SE) focuses on enhancing and denoising spectral and microscopy data. Leveraging advanced deep learning techniques, SE aims to recover high-quality signals from noisy observations, enabling more accurate analysis of material properties at the atomic scale. This task is particularly valuable for STEM (Scanning Transmission Electron Microscopy) image processing, where noise reduction can significantly improve the visualization of crystal structures and defects. 
+ +## 2.Framework Support Matrix + +| **Supported Functions** | **Support** | +| ----------------------------------- | :---------: | +| **ML Capabilities ยท Training** | | +| Single-GPU | โ | +| Distributed training | โ | +| Mixed precision (AMP) | โ | +| Fine-tuning | โ | +| **ML Capabilities ยท Predict** | | +| Standard inference | โ | +| Distributed inference | โ | +| **Data Pipeline** | | +| Local dataset loading | โ | +| Auto dataset download | โ | +| **Task Workflow** | | +| Training / Evaluation / Prediction | โ | + +## 3.Model README + +- [SFIN](./configs/sfin/README.md) diff --git a/spectrum_enhancement/configs/sfin/README.md b/spectrum_enhancement/configs/sfin/README.md new file mode 100644 index 00000000..b56f271c --- /dev/null +++ b/spectrum_enhancement/configs/sfin/README.md @@ -0,0 +1,170 @@ +# SFIN + +[Noise Calibration and Spatial-Frequency Interactive Network for STEM Image Enhancement](https://arxiv.org/pdf/2504.02555) + +## 1.Introduction + +SFIN is a CNN-based model for STEM image restoration. It targets paired reconstruction from noisy grayscale inputs and supports two experimental modes (HAADF and BF), each with two training targets (`enhance`, `detect`). + +## 2.Model Description + +SFIN takes `noisy` as input and predicts one target image (`gt_enhance` or `gt_detect`). + +- Noise calibration module for robust denoising. +- Spatial-frequency interaction blocks for detail recovery. +- Multi-scale feature extraction. 
+ +Training objective (L1): + +$$ +\mathcal{L} = \left\| \hat{I} - I_{gt} \right\|_1 +$$ + +Evaluation metrics: +- PSNR +- SSIM + +## 3.Configurations + +| Config | Mode | Target | Output Dir | +| --- | --- | --- | --- | +| `sfin_tem_enhance.yaml` | HAADF (TEM) | `gt_enhance` | `./output/sfin_tem_enhance` | +| `sfin_tem_detect.yaml` | HAADF (TEM) | `gt_detect` | `./output/sfin_tem_detect` | +| `sfin_bf_enhance.yaml` | BF | `gt_enhance` | `./output/sfin_bf_enhance` | +| `sfin_bf_detect.yaml` | BF | `gt_detect` | `./output/sfin_bf_detect` | + +## 4.Dataset + +### Format + +Expected paired directory format: + +```text +data/ # HAADF train + noisy/ + gt_enhance/ + gt_detect/ + +data_test/ # HAADF val/test + noisy/ + gt_enhance/ + gt_detect/ + +bf_data/ # BF train + noisy/ + gt_enhance/ + gt_detect/ + +bf_data_test/ # BF val/test + noisy/ + gt_enhance/ + gt_detect/ +``` + +### Download Links + +| Dataset | Train | Test | Link | +| :---: | :---: | :---: | :---: | +| HAADF train dataset | 1000 | - | [haadf_data.zip](https://paddle-org.bj.bcebos.com/paddlematerials/datasets/SFIN_datasets/haadf_data.zip) | +| BF train dataset | 1000 | - | [bf_data.zip](https://paddle-org.bj.bcebos.com/paddlematerials/datasets/SFIN_datasets/bf_data.zip) | +| HAADF test dataset | - | 100 | [haadf_data_test.zip](https://paddle-org.bj.bcebos.com/paddlematerials/datasets/SFIN_datasets/haadf_data_test.zip) | +| BF test dataset | - | 100 | [bf_data_test.zip](https://paddle-org.bj.bcebos.com/paddlematerials/datasets/SFIN_datasets/bf_data_test.zip) | + +### Auto Download Behavior + +`STEMImageDataset` supports local loading + auto download: + +- If `data_path` exists locally, data is loaded directly. +- If `data_path` is missing and `download=True`, URL is inferred by `data_path` basename: + - `data` -> `haadf_data.zip` + - `data_test` -> `haadf_data_test.zip` + - `bf_data` -> `bf_data.zip` + - `bf_data_test` -> `bf_data_test.zip` +- Archive formats `zip` and `tar.gz` are both supported. 
+ +## 5.Results + +
| Model Name | Dataset | PSNR (dB) | SSIM | GPUs | Training time | Config | Checkpoint | Log |
+|---|---|---|---|---|---|---|---|---|
+| sfin_tem_enhance | STEM Enhancement | 38.74 | 0.9622 | 1 (V100-32GB) | ~21.5 hours | sfin_tem_enhance | checkpoint | log |
+