diff --git a/ppmat/datasets/__init__.py b/ppmat/datasets/__init__.py index 98eec451..93eb0115 100644 --- a/ppmat/datasets/__init__.py +++ b/ppmat/datasets/__init__.py @@ -47,6 +47,8 @@ from ppmat.datasets.oc20_s2ef_dataset import OC20S2EFDataset # noqa from ppmat.datasets.qm9_dataset import QM9Dataset # noqa from ppmat.datasets.omol25_dataset import OMol25Dataset +from ppmat.datasets.ir_dataset import IRDataset +from ppmat.datasets.ecd_dataset import ECDDataset from ppmat.datasets.split_mptrj_data import none_to_zero from ppmat.datasets.transform import build_transforms from ppmat.utils import logger @@ -67,6 +69,8 @@ "DensityDataset", "SmallDensityDataset", "OMol25Dataset", + "IRDataset", + "ECDDataset", ] INFO_CLASS_REGISTRY: Dict[str, type] = { @@ -277,7 +281,7 @@ def set_build_sample(sampler_cfg, world_size, dataset): ) batch_sampler = getattr(io, batch_sampler_cls)( dataset, - batch_size=init_params["batch_size"], + batch_size=2, # use default batch_size=2 to avoid error when batch_sampler is not specified shuffle=False, drop_last=False, ) diff --git a/ppmat/datasets/build_ecd.py b/ppmat/datasets/build_ecd.py new file mode 100644 index 00000000..f7b44b64 --- /dev/null +++ b/ppmat/datasets/build_ecd.py @@ -0,0 +1,433 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import os +import copy +import importlib +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, List +import paddle +import pandas as pd +import numpy as np + +from ppmat.datasets.geometric_data_type.data import Data + +from ppmat.utils import download as download_utils +from ppmat.utils import logger +from ppmat.utils import ColoredTqdm as tqdm +from ppmat.utils.compound_tools import ( + atom_id_names, bond_id_names, bond_angle_float_names +) + +def _locate_class(class_name: str): + if "." in class_name: + mod, cls = class_name.rsplit(".", 1) + return getattr(importlib.import_module(mod), cls) + return globals()[class_name] + + +def _parse_factory_cfg( + cfg: Optional[Dict[str, Any] | str], + *, + default_class_name: str, +) -> Tuple[str, Dict[str, Any]]: + """Parse factory configuration, compatible with multiple formats""" + if cfg is None: + return default_class_name, {} + + if isinstance(cfg, str): + return cfg, {} + + if not isinstance(cfg, dict): + raise TypeError(f"cfg must be None, str, or dict, got {type(cfg).__name__}") + + cfg = copy.deepcopy(cfg) + class_name = cfg.pop("__class_name__", None) or cfg.pop("class_name", None) or cfg.pop("type", None) + if not class_name: + raise ValueError("Factory cfg must include class name key") + + init_params = ( + cfg.pop("__init_params__", None) or + cfg.pop("init_params", None) or + cfg.pop("params", None) or {} + ) + if not isinstance(init_params, dict): + raise TypeError(f"init_params must be dict, got {type(init_params).__name__}") + + if cfg: + raise ValueError(f"Unsupported keys in cfg: {list(cfg.keys())}") + return class_name, init_params + + +class StrictIndexSampleBuilder: + """Build samples by strict index (for ECD dataset)""" + def build(self, data_dir: Path, index_file: str, sample_path: str, data_count: Optional[int] = None): + import pandas as pd + samples = [] + df = pd.read_csv(data_dir / index_file, encoding='gbk') + ids = df['Unnamed: 
0'].values[:data_count] if data_count else df['Unnamed: 0'].values + for idx in ids: + samples.append({ + 'id': int(idx), + 'smiles': df[df['Unnamed: 0'] == idx]['SMILES'].values[0], + 'spectrum_path': str(Path(sample_path) / f"{idx}.csv") + }) + return samples + + +class DefaultECDDatasetDownloader: + """ECD dataset downloader""" + def __init__(self, datasets_home: Optional[str] = None): + self.datasets_home = datasets_home or download_utils.DATASETS_HOME + + def download(self, url: str, md5: Optional[str] = None, force_download: bool = False) -> Path: + if force_download: + downloaded_root = download_utils.get_path_from_url( + url, self.datasets_home, md5sum=md5, check_exist=False, decompress=True + ) + else: + downloaded_root = download_utils.get_datasets_path_from_url(url, md5) + return Path(downloaded_root) + +def build_ecformer_downloader(cfg: Optional[Dict[str, Any] | str]): + """Build downloader""" + class_name, init_params = _parse_factory_cfg(cfg, default_class_name="DefaultECDDatasetDownloader") + cls = _locate_class(class_name) + downloader = cls(**init_params) + if not hasattr(downloader, 'download'): + raise TypeError(f"Downloader {class_name} must implement 'download' method") + logger.debug(f"Use downloader: {class_name}") + return downloader + +def get_key_padding_mask(tokens): + """Generate query padding mask""" + key_padding_mask = paddle.zeros(tokens.shape) + key_padding_mask[tokens == -1] = -paddle.inf + return key_padding_mask + +def normalize_func(src_list, norm_range=[-100, 100]): + # lihao implecation for list normalization + # input: src_list, normalization range + # output: tgt_list after normalization + + src_max, src_min = max(src_list), min(src_list) + norm_min, norm_max = norm_range[0], norm_range[1] + if src_max == 0: src_max = 1 + if src_min == 0: src_min = -1 + + tgt_list = [] + for i in range(len(src_list)): + if src_list[i] >= 0: + tgt_list.append(src_list[i] * norm_max / src_max) + else: + tgt_list.append(src_list[i] * norm_min 
/ src_min) + + assert len(src_list) == len(tgt_list) + return tgt_list + +def get_sequence_peak(sequence): + # input- seq: List + # output- peak_list contains peak position + peak_list = [] + for i in range(1, len(sequence)-1): + if sequence[i-1]sequence[i+1]: + peak_list.append(i) + if sequence[i-1]>sequence[i] and sequence[i] 1 else 0 for i in mdegs_o] + + # Remove leading and trailing zeros + begin, end = 0, 0 + for i in range(len(mdegs)): + if mdegs[i] != 0: + begin = i + break + for i in range(len(mdegs) - 1, 0, -1): + if mdegs[i] != 0: + end = i + break + + ecd_dict[fileid] = { + 'wavelengths': wavelengths[begin: end + 1], + 'ecd': mdegs[begin: end + 1], + } + ecd_original_dict[fileid] = { + 'wavelengths': wavelengths, + 'ecd': mdegs, + } + + # Process spectrum sequences, extract peaks + ecd_final_list = [] + for key, itm in ecd_dict.items(): + # Uniform sampling + distance = int(len(itm['ecd']) / (fix_length - 1)) + sequence_org = [itm['ecd'][i] for i in range(0, len(itm['ecd']), distance)][:fix_length] + + # Normalize + sequence = normalize_func(sequence_org, norm_range=[-100, 100]) + + # Pad to fixed length + if len(sequence) < fix_length: + sequence.extend([0] * (fix_length - len(sequence))) + sequence_org.extend([0] * (fix_length - len(sequence_org))) + assert len(sequence) == fix_length + + # Generate peak mask + peak_mask = [0] * len(sequence) + for i in range(1, len(sequence) - 1): + if sequence[i - 1] < sequence[i] and sequence[i] > sequence[i + 1]: + if peak_mask[i - 1] != 2: + peak_mask[i - 1] = 1 + peak_mask[i] = 2 + if peak_mask[i + 1] != 2: + peak_mask[i + 1] = 1 + if sequence[i - 1] > sequence[i] and sequence[i] < sequence[i + 1]: + if peak_mask[i - 1] != 2: + peak_mask[i - 1] = 1 + peak_mask[i] = 2 + if peak_mask[i + 1] != 2: + peak_mask[i + 1] = 1 + + # Extract peak positions + peak_position_list = get_sequence_peak(sequence) + peak_number = len(peak_position_list) + assert peak_number < 9, f"Peak number {peak_number} >= 9" + + # Peak signs + 
peak_height_list = [] + for i in peak_position_list: + peak_height_list.append(1 if sequence[i] >= 0 else 0) + + # Pad to 9 peaks + peak_position_list = peak_position_list + [-1] * (9 - peak_number) + peak_height_list = peak_height_list + [-1] * (9 - peak_number) + query_padding_mask = get_key_padding_mask(paddle.to_tensor(peak_position_list)) + + tmp_dict = { + 'id': key, + 'seq': [0] + sequence, + 'seq_original': sequence_org, + 'seq_mask': peak_mask, + 'peak_num': peak_number, + 'peak_position': peak_position_list, + 'peak_height': peak_height_list, + 'query_mask': query_padding_mask.unsqueeze(0), + } + ecd_final_list.append(tmp_dict) + + ecd_final_list.sort(key=lambda x: x['id']) + return ecd_final_list, ecd_original_dict + + +def Construct_dataset(dataset, data_index, path): + """ + Construct graph data from raw features + Completely reuse the Construct_dataset logic from the prototype program + """ + graph_atom_bond = [] + graph_bond_angle = [] + + all_descriptor = np.load(os.path.join(path, 'descriptor_all_column.npy')) # (25847, 1826) + + for i in tqdm(range(len(dataset)), desc="Constructing graphs"): + data = dataset[i] + + # Collect atom features + atom_feature = [] + for name in atom_id_names: + atom_feature.append(data[name]) + + # Collect bond features + bond_feature = [] + for name in bond_id_names[0:3]: + bond_feature.append(data[name]) + + # Convert to Tensor + atom_feature = paddle.to_tensor(np.array(atom_feature).T, dtype='int64') + bond_feature = paddle.to_tensor(np.array(bond_feature).T, dtype='int64') + bond_float_feature = paddle.to_tensor(data['bond_length'].astype(paddle.get_default_dtype())) + bond_angle_feature = paddle.to_tensor(data['bond_angle'].astype(paddle.get_default_dtype())) + edge_index = paddle.to_tensor(data['edges'].T, dtype='int64') + bond_index = paddle.to_tensor(data['BondAngleGraph_edges'].T, dtype='int64') + data_index_int = paddle.to_tensor(np.array(data_index[i]), dtype='int64') + + # Add descriptor features (exactly 
the same as prototype program) + TPSA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 820] / 100 + RASA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 821] + RPSA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 822] + MDEC = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 1568] + MATS = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 457] + + # Merge features + bond_feature = paddle.concat( + [bond_feature.astype(bond_float_feature.dtype), + bond_float_feature.reshape([-1, 1])], + axis=1 + ) + + bond_angle_feature = paddle.concat( + [bond_angle_feature.reshape([-1, 1]), TPSA.reshape([-1, 1])], + axis=1 + ) + bond_angle_feature = paddle.concat([bond_angle_feature, RASA.reshape([-1, 1])], axis=1) + bond_angle_feature = paddle.concat([bond_angle_feature, RPSA.reshape([-1, 1])], axis=1) + bond_angle_feature = paddle.concat([bond_angle_feature, MDEC.reshape([-1, 1])], axis=1) + bond_angle_feature = paddle.concat([bond_angle_feature, MATS.reshape([-1, 1])], axis=1) + + # Create Data objects + data_atom_bond = Data( + x=atom_feature, + edge_index=edge_index, + edge_attr=bond_feature, + data_index=data_index_int, + ) + data_bond_angle = Data( + edge_index=bond_index, + edge_attr=bond_angle_feature, + num_nodes=atom_feature.shape[0] + ) + + graph_atom_bond.append(data_atom_bond) + graph_bond_angle.append(data_bond_angle) + + return graph_atom_bond, graph_bond_angle + + + + +def GetAtomBondAngleDataset( + sample_path, + dataset_all, + index_all, + hand_idx_dict, + line_idx_dict +): + """ + Core function: build and return the processed graph dataset + + Args: + sample_path: ECD spectrum file path + dataset_all: info list loaded from npy + index_all: index list + hand_idx_dict: chiral pair mapping + line_idx_dict: line number mapping + + Returns: + dataset_graph_atom_bond: atom-bond graph list + dataset_graph_bond_angle: bond-angle graph list + """ + # 1. 
Read ECD spectrum sequences + ecd_sequences, ecd_original_sequences = read_total_ecd(sample_path) + + # 2. Construct graph data + total_graph_atom_bond, total_graph_bond_angle = Construct_dataset( + dataset_all, index_all, sample_path + ) + print("Case Before Process = ", len(total_graph_atom_bond), len(total_graph_bond_angle)) + + # 3. Attach spectrum sequence information to graph data + dataset_graph_atom_bond, dataset_graph_bond_angle = [], [] + + for itm in ecd_sequences: + line_num = itm['id'] - 1 + atom_bond = total_graph_atom_bond[line_num] + + # Attach spectrum information + atom_bond.sequence = paddle.to_tensor([itm['seq']]) + atom_bond.ecd_id = paddle.to_tensor(itm['id']) + atom_bond.seq_mask = paddle.to_tensor([itm['seq_mask']]) + atom_bond.seq_original = paddle.to_tensor([itm['seq_original']]) + atom_bond.peak_num = paddle.to_tensor([itm['peak_num']]) + atom_bond.peak_position = paddle.to_tensor([itm['peak_position']]) + atom_bond.peak_height = paddle.to_tensor([itm['peak_height']]) + atom_bond.query_mask = itm['query_mask'] + + dataset_graph_atom_bond.append(atom_bond) + dataset_graph_bond_angle.append(total_graph_bond_angle[line_num]) + + # 4. 
Enantiomer enhancement: add enantiomer samples + hand_id, unnamed_id = line_idx_dict[line_num]['hand_id'], line_idx_dict[line_num]['unnamed_id'] + another_line_num = -1 + + for alternative in hand_idx_dict[hand_id]: + if alternative['unnamed_id'] != unnamed_id: + another_line_num = alternative['line_number'] + break + + assert another_line_num != -1, f"cannot find the hand info of {line_num}" + + # Enantiomer: invert the spectrum + atom_bond_oppo = total_graph_atom_bond[another_line_num] + atom_bond_oppo.sequence = paddle.neg(paddle.to_tensor([itm['seq']])) + atom_bond_oppo.ecd_id = paddle.to_tensor(another_line_num + 1) + atom_bond_oppo.seq_mask = paddle.to_tensor([itm['seq_mask']]) + atom_bond_oppo.seq_original = paddle.neg(paddle.to_tensor([itm['seq_original']])) + atom_bond_oppo.peak_num = paddle.to_tensor([itm['peak_num']]) + atom_bond_oppo.peak_position = paddle.to_tensor([itm['peak_position']]) + atom_bond_oppo.peak_height = paddle.to_tensor([itm['peak_height']]) + atom_bond_oppo.query_mask = itm['query_mask'] + + dataset_graph_atom_bond.append(atom_bond_oppo) + dataset_graph_bond_angle.append(total_graph_bond_angle[another_line_num]) + + total_num = len(dataset_graph_atom_bond) + print("Case After Process = ", len(dataset_graph_atom_bond), len(dataset_graph_bond_angle)) + print('=================== Data prepared ================\n') + + return dataset_graph_atom_bond, dataset_graph_bond_angle + + +def build_ecformer_sample_builder(cfg: Optional[Dict[str, Any] | str]): + """Build sample builder""" + class_name, init_params = _parse_factory_cfg(cfg, default_class_name="StrictIndexSampleBuilder") + cls = _locate_class(class_name) + builder = cls(**init_params) + if not hasattr(builder, 'build'): + raise TypeError(f"Sample builder {class_name} must implement 'build' method") + logger.debug(f"Use sample builder: {class_name}") + return builder \ No newline at end of file diff --git a/ppmat/datasets/build_ir.py b/ppmat/datasets/build_ir.py new file mode 100644 
index 00000000..f04139ad --- /dev/null +++ b/ppmat/datasets/build_ir.py @@ -0,0 +1,355 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import copy +import importlib +import json +import warnings +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, List + +import paddle +import numpy as np +import pandas as pd +from scipy.signal import find_peaks +from ppmat.datasets.geometric_data_type.data import Data + +from ppmat.utils import download as download_utils +from ppmat.utils import logger +from ppmat.utils import ColoredTqdm as tqdm +from ppmat.utils.compound_tools import ( + atom_id_names, bond_id_names, bond_angle_float_names +) + + +def _locate_class(class_name: str): + if "." 
in class_name: + mod, cls = class_name.rsplit(".", 1) + return getattr(importlib.import_module(mod), cls) + return globals()[class_name] + + +def _parse_factory_cfg( + cfg: Optional[Dict[str, Any] | str], + *, + default_class_name: str, +) -> Tuple[str, Dict[str, Any]]: + """Parse factory configuration, compatible with multiple formats""" + if cfg is None: + return default_class_name, {} + + if isinstance(cfg, str): + return cfg, {} + + if not isinstance(cfg, dict): + raise TypeError(f"cfg must be None, str, or dict, got {type(cfg).__name__}") + + cfg = copy.deepcopy(cfg) + class_name = cfg.pop("__class_name__", None) or cfg.pop("class_name", None) or cfg.pop("type", None) + if not class_name: + raise ValueError("Factory cfg must include class name key") + + init_params = ( + cfg.pop("__init_params__", None) or + cfg.pop("init_params", None) or + cfg.pop("params", None) or {} + ) + if not isinstance(init_params, dict): + raise TypeError(f"init_params must be dict, got {type(init_params).__name__}") + + if cfg: + raise ValueError(f"Unsupported keys in cfg: {list(cfg.keys())}") + return class_name, init_params + + +class IRStrictIndexSampleBuilder: + """Build IR samples by strict index""" + def build(self, data_dir: Path, meta_file: str, spectra_dir: str, data_count: Optional[int] = None): + """Build sample list""" + samples = [] + meta_path = data_dir / meta_file + data = np.load(meta_path, allow_pickle=True).item() + + index_all = data['index_all'][:data_count] if data_count else data['index_all'] + + for idx in index_all: + samples.append({ + 'id': int(idx), + 'smiles': data['smiles_all'][data['index_all'].index(idx)] if hasattr(data['index_all'], 'index') else None, + 'spectrum_path': str(Path(spectra_dir) / f"{idx}.json") + }) + return samples + + +class DefaultIRDatasetDownloader: + """IR dataset downloader""" + def __init__(self, datasets_home: Optional[str] = None): + self.datasets_home = datasets_home or download_utils.DATASETS_HOME + + def download(self, 
url: str, md5: Optional[str] = None, force_download: bool = False) -> Path: + if force_download: + downloaded_root = download_utils.get_path_from_url( + url, self.datasets_home, md5sum=md5, check_exist=False, decompress=True + ) + else: + downloaded_root = download_utils.get_datasets_path_from_url(url, md5) + return Path(downloaded_root) + + +def build_ir_downloader(cfg: Optional[Dict[str, Any] | str]): + """Build downloader""" + class_name, init_params = _parse_factory_cfg(cfg, default_class_name="DefaultIRDatasetDownloader") + cls = _locate_class(class_name) + downloader = cls(**init_params) + if not hasattr(downloader, 'download'): + raise TypeError(f"Downloader {class_name} must implement 'download' method") + logger.debug(f"Use downloader: {class_name}") + return downloader + + +def build_ir_sample_builder(cfg: Optional[Dict[str, Any] | str]): + """Build sample builder""" + class_name, init_params = _parse_factory_cfg(cfg, default_class_name="IRStrictIndexSampleBuilder") + cls = _locate_class(class_name) + builder = cls(**init_params) + if not hasattr(builder, 'build'): + raise TypeError(f"Sample builder {class_name} must implement 'build' method") + logger.debug(f"Use sample builder: {class_name}") + return builder + + +# ==================== IR Specific Utility Functions ==================== + +IR_WAVELENGTH_MIN = 500 +IR_WAVELENGTH_MAX = 4000 +IR_STEP = 100 +DEFAULT_MAX_PEAKS = 15 + + +def get_key_padding_mask(tokens): + """Generate query padding mask""" + key_padding_mask = paddle.zeros(tokens.shape) + key_padding_mask[tokens == -1] = -paddle.inf + return key_padding_mask + + +def x_bin_position(real_x, distance=IR_STEP): + """Convert actual wavenumber to bin ID""" + return int((real_x - IR_WAVELENGTH_MIN) / distance) + + +def Construct_IR_Dataset(dataset, data_index, descriptor_path=None): + """ + Construct IR graph data from raw features + """ + graph_atom_bond = [] + graph_bond_angle = [] + + all_descriptor = None + if descriptor_path and 
os.path.exists(descriptor_path): + all_descriptor = np.load(descriptor_path) + + for i in tqdm(range(len(dataset)), desc="Constructing IR graphs"): + data = dataset[i] + + # Collect atom features + atom_feature = [] + for name in atom_id_names: + if name in data: + atom_feature.append(data[name]) + else: + if i == 0: + warnings.warn(f"Feature {name} not found in data, using zeros") + num_atoms = data.get('atomic_num', np.zeros(1)).shape[0] + atom_feature.append(np.zeros(num_atoms)) + + # Collect bond features + bond_feature = [] + for name in bond_id_names: + if name in data: + bond_feature.append(data[name]) + else: + if i == 0: + warnings.warn(f"Bond feature {name} not found, using zeros") + num_bonds = data.get('bond_dir', np.zeros(1)).shape[0] + bond_feature.append(np.zeros(num_bonds)) + + # Convert to Tensor + atom_feature = paddle.to_tensor(np.array(atom_feature).T, dtype='int64') + bond_feature = paddle.to_tensor(np.array(bond_feature).T, dtype='int64') + + bond_float_feature = paddle.to_tensor(data.get('bond_length', np.zeros(data['edges'].shape[0])).astype(paddle.get_default_dtype())) + bond_angle_feature = paddle.to_tensor(data.get('bond_angle', np.zeros(data.get('BondAngleGraph_edges', np.zeros((0,2))).shape[0])).astype(paddle.get_default_dtype())) + + edge_index = paddle.to_tensor(data['edges'].T, dtype='int64') + bond_index = paddle.to_tensor(data.get('BondAngleGraph_edges', np.zeros((0,2))).T, dtype='int64') + + data_index_int = paddle.to_tensor(np.array(int(data_index[i])), dtype='int64') + num_atoms = atom_feature.shape[0] + + # Merge bond features + bond_feature = paddle.concat( + [bond_feature.astype(bond_float_feature.dtype), + bond_float_feature.reshape([-1, 1])], + axis=1 + ) + + # Process bond angle features - ensure output is 6-dim + if bond_angle_feature.shape[0] > 0: + # Base feature: bond_angle + features = [bond_angle_feature.reshape([-1, 1])] + + # If descriptors exist,add 5 descriptor features + if all_descriptor is not None and i < 
all_descriptor.shape[0]: + TPSA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 820] / 100 + RASA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 821] + RPSA = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 822] + MDEC = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 1568] + MATS = paddle.ones([bond_angle_feature.shape[0]]) * all_descriptor[i, 457] + + features.extend([ + TPSA.reshape([-1, 1]), + RASA.reshape([-1, 1]), + RPSA.reshape([-1, 1]), + MDEC.reshape([-1, 1]), + MATS.reshape([-1, 1]) + ]) + else: + # If no descriptors, fill the remaining 5 dimensions with zeros + for _ in range(5): + features.append(paddle.zeros([bond_angle_feature.shape[0], 1])) + + # Concatenate to [E_ba, 6] + bond_angle_feature = paddle.concat(features, axis=1) + else: + # If no bond angles, create all-zero [0, 6] + bond_angle_feature = paddle.zeros([0, 6]) + + data_atom_bond = Data( + x=atom_feature, + edge_index=edge_index, + edge_attr=bond_feature, + data_index=data_index_int, + ) + + data_bond_angle = Data( + edge_index=bond_index, + edge_attr=bond_angle_feature if bond_angle_feature.shape[0] > 0 else paddle.zeros([0, 1]), + num_nodes=num_atoms, + ) + + graph_atom_bond.append(data_atom_bond) + graph_bond_angle.append(data_bond_angle) + + return graph_atom_bond, graph_bond_angle + + +def read_ir_spectra_by_ids(sample_path, index_all, max_peak=DEFAULT_MAX_PEAKS): + """ + Read IR spectrum files on demand + """ + ir_final_list = [] + + for fileid in tqdm(index_all, desc="Reading IR spectra by ID"): + filepath = os.path.join(sample_path, f"{fileid}.json") + + try: + with open(filepath, 'r') as f: + raw_ir_info = json.load(f) + + ir_x = raw_ir_info['x'] + ir_y = raw_ir_info['y_40'] + + peaks_raw, _ = find_peaks(x=ir_y, height=0.1, distance=100) + peaks_raw = peaks_raw.tolist() + + peak_num = min(len(peaks_raw), max_peak) + + if peak_num > 0: + if len(peaks_raw) > max_peak: + peaks = peaks_raw[len(peaks_raw)-max_peak:] + 
else: + peaks = peaks_raw + + peak_position_list = [x_bin_position(ir_x[i]) for i in peaks] + peak_height_list = [ir_y[i] for i in peaks] + else: + peak_position_list = [] + peak_height_list = [] + + peak_position_list = peak_position_list + [-1] * (max_peak - len(peak_position_list)) + peak_height_list = peak_height_list + [-1] * (max_peak - len(peak_height_list)) + + query_padding_mask = get_key_padding_mask(paddle.to_tensor(peak_position_list)) + + tmp_dict = { + 'id': fileid, + 'seq_40': ir_y, + 'peak_num': peak_num, + 'peak_position': peak_position_list, + 'peak_height': peak_height_list, + 'query_mask': query_padding_mask.unsqueeze(0), + } + ir_final_list.append(tmp_dict) + + except Exception as e: + warnings.warn(f"Error processing {fileid}.json: {e}") + continue + + ir_final_list.sort(key=lambda x: x['id']) + return ir_final_list + + +def GetIRDataset( + sample_path, + dataset_all, + index_all, +): + """ + Core function: build and return IR graph dataset + """ + # 1. Read IR spectrum sequences + ir_sequences = read_ir_spectra_by_ids(sample_path, index_all) + + # 2. Construct graph data + total_graph_atom_bond, total_graph_bond_angle = Construct_IR_Dataset( + dataset_all, index_all, sample_path + ) + print("Case Before Process = ", len(total_graph_atom_bond), len(total_graph_bond_angle)) + + # 3. 
Attach spectrum information to graph data + dataset_graph_atom_bond, dataset_graph_bond_angle = [], [] + + for i, itm in enumerate(ir_sequences): + atom_bond = total_graph_atom_bond[i] + + # Attach spectrum information + atom_bond.sequence = paddle.to_tensor([itm['seq_40']]) + atom_bond.ir_id = paddle.to_tensor(int(itm['id'])) + atom_bond.peak_num = paddle.to_tensor([itm['peak_num']]) + atom_bond.peak_position = paddle.to_tensor([itm['peak_position']]) + atom_bond.peak_height = paddle.to_tensor([itm['peak_height']]) + atom_bond.query_mask = itm['query_mask'] + + dataset_graph_atom_bond.append(atom_bond) + dataset_graph_bond_angle.append(total_graph_bond_angle[i]) + + total_num = len(dataset_graph_atom_bond) + print("Case After Process = ", len(dataset_graph_atom_bond), len(dataset_graph_bond_angle)) + print('=================== Data prepared ================\n') + + return dataset_graph_atom_bond, dataset_graph_bond_angle \ No newline at end of file diff --git a/ppmat/datasets/collate_fn.py b/ppmat/datasets/collate_fn.py index 9073af4c..c53ee7b3 100644 --- a/ppmat/datasets/collate_fn.py +++ b/ppmat/datasets/collate_fn.py @@ -302,3 +302,89 @@ def pad_sequence(sequences, batch_first=False, padding_value=0): out_tensor[:length, i, ...] 
= tensor return out_tensor + + +class ECDCollator(DefaultCollator): + def __call__(self, batch: List[Any]) -> Any: + batch = [list(x) for x in zip(*batch)] # transpose + for i in range(len(batch)): # Group into batches + batch[i] = Batch.from_data_list(batch[i]) + + batch0 = batch[0] + batch1 = batch[1] + + # Unpack Data to Tensor dictionary + batch_atom_bond, batch_bond_angle = batch0, batch1 + x, edge_index, edge_attr, query_mask = ( + batch_atom_bond.x, + batch_atom_bond.edge_index, + batch_atom_bond.edge_attr, + batch_atom_bond.query_mask, + ) + ba_edge_index, ba_edge_attr = ( + batch_bond_angle.edge_index, + batch_bond_angle.edge_attr, + ) + batch_data = batch_atom_bond.batch + pos_gt = batch_atom_bond.peak_position + height_gt = batch_atom_bond.peak_height + num_gt = batch_atom_bond.peak_num + return ( + { + "x": x, + "edge_index": edge_index, + "edge_attr": edge_attr, + "batch_data": batch_data, + "ba_edge_index": ba_edge_index, + "ba_edge_attr": ba_edge_attr, + "query_mask": query_mask, + }, + { + "peak_number": num_gt, + "peak_position": pos_gt, + "peak_height": height_gt, + }, + ) + + +class IRCollator(DefaultCollator): + """IR dataset specific collator, returns Tensor dictionary""" + + def __call__(self, batch: List[Any]) -> Any: + batch = [list(x) for x in zip(*batch)] # transpose + for i in range(len(batch)): + batch[i] = Batch.from_data_list(batch[i]) + + batch_atom_bond, batch_bond_angle = batch[0], batch[1] + + x, edge_index, edge_attr, query_mask = ( + batch_atom_bond.x, + batch_atom_bond.edge_index, + batch_atom_bond.edge_attr, + batch_atom_bond.query_mask, + ) + ba_edge_index, ba_edge_attr = ( + batch_bond_angle.edge_index, + batch_bond_angle.edge_attr, + ) + batch_data = batch_atom_bond.batch + pos_gt = batch_atom_bond.peak_position + height_gt = batch_atom_bond.peak_height + num_gt = batch_atom_bond.peak_num + + return ( + { + "x": x, + "edge_index": edge_index, + "edge_attr": edge_attr, + "batch_data": batch_data, + "ba_edge_index": 
ba_edge_index, + "ba_edge_attr": ba_edge_attr, + "query_mask": query_mask, + }, + { + "peak_number": num_gt, + "peak_position": pos_gt, + "peak_height": height_gt, + }, + ) \ No newline at end of file diff --git a/ppmat/datasets/ecd_dataset.py b/ppmat/datasets/ecd_dataset.py new file mode 100644 index 00000000..f3174c96 --- /dev/null +++ b/ppmat/datasets/ecd_dataset.py @@ -0,0 +1,162 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import os +import numpy as np +import pandas as pd +import paddle +from paddle.io import Dataset +from pathlib import Path +from typing import Dict, Optional + +from ppmat.utils import PlaceEnv +from ppmat.utils.compound_tools import get_atom_feature_dims, get_bond_feature_dims +from ppmat.datasets.build_ecd import build_ecformer_sample_builder +from ppmat.datasets.build_ecd import build_ecformer_downloader +from ppmat.datasets.build_ecd import GetAtomBondAngleDataset + +_cache = () + +class ECDDataset(Dataset): + """ + ECDFormer ECD Spectrum Prediction Dataset + + Data source: https://paddle-org.bj.bcebos.com/paddlematerials/datasets/ECD/ECD.tar.gz + """ + + url = "https://paddle-org.bj.bcebos.com/paddlematerials/datasets/ECD/ECD.tar.gz" + md5 = "aa86eddee2397dbc37c4b7b9a45b1e27" + + def __init__( + self, + data_path: str, + split: Optional[str] = None, # 'train'/'val'/'test' + data_count: Optional[int] = None, + sample_builder_cfg: Optional[Dict] = None, + downloader_cfg: Optional[Dict] = None, + download: bool = True, + force_download: bool = False, + use_geometry_enhanced: bool = True, + use_column_info: bool = False, + ): + super().__init__() + + self.data_path = Path(data_path) + self.split = split + self.data_count = data_count + self.use_geometry_enhanced = use_geometry_enhanced + self.use_column_info = use_column_info + + # Build components + self.sample_builder = build_ecformer_sample_builder(sample_builder_cfg) + self.downloader = build_ecformer_downloader(downloader_cfg) + + # Handle download + if force_download or (not self._check_files() and download): + self.downloaded_root = self.downloader.download( + self.url, self.md5, force_download=force_download + ) + self.data_path = self.downloaded_root + + # Load data + self._load_data() + + def _check_files(self): + """Check if necessary files exist""" + npy_path = self.data_path / 'ecd_column_charity_new_smiles.npy' + csv_path = self.data_path / 'ecd_info.csv' + + 
if not npy_path.exists(): + return False + if not csv_path.exists(): + return False + return True + + def _load_data(self): + """Load all data""" + # 1. Load npy file + npy_path = self.data_path / 'ecd_column_charity_new_smiles.npy' + if not npy_path.exists(): + raise FileNotFoundError(f"npy file not found: {npy_path}") + + self.ecd_dataset = np.load(npy_path, allow_pickle=True).tolist() + + # 2. Load csv file + csv_path = self.data_path / 'ecd_info.csv' + if not csv_path.exists(): + raise FileNotFoundError(f"csv file not found: {csv_path}") + + self.ecd_info = pd.read_csv(csv_path, encoding='gbk') + + # 3. Extract data + self.dataset_all = [item['info'] for item in self.ecd_dataset] + self.smiles_all = [item['smiles'] for item in self.ecd_dataset] + self.index_all = self.ecd_info['Unnamed: 0'].values + + # 4. Build chiral pair mapping + self._build_chiral_mapping() + + # 5. Build graph dataset + self._build_graph_dataset() + + def _build_chiral_mapping(self): + """Build chiral enantiomer mapping""" + self.hand_idx_dict = {} + self.line_idx_dict = {} + + for i, itm in enumerate(self.ecd_dataset): + self.line_idx_dict[i] = { + 'hand_id': itm['hand_id'], + 'unnamed_id': itm['id'], + 'smiles': itm['smiles'] + } + + if itm['hand_id'] not in self.hand_idx_dict: + self.hand_idx_dict[itm['hand_id']] = [] + self.hand_idx_dict[itm['hand_id']].append({ + 'line_number': i, + 'unnamed_id': itm['id'], + 'smiles': itm['smiles'] + }) + + @PlaceEnv(paddle.CPUPlace()) + def _build_graph_dataset(self): + """Build graph dataset""" + global _cache + + if len(_cache) > 0: + self.graph_atom_bond, self.graph_bond_angle = _cache + return + + self.graph_atom_bond, self.graph_bond_angle = GetAtomBondAngleDataset( + sample_path=str(self.data_path), + dataset_all=self.dataset_all, + index_all=self.index_all, + hand_idx_dict=self.hand_idx_dict, + line_idx_dict=self.line_idx_dict + ) + + _cache = (self.graph_atom_bond, self.graph_bond_angle) + + assert len(self.graph_atom_bond) == 
len(self.graph_bond_angle) + + def __len__(self): + return len(self.graph_atom_bond) + + @PlaceEnv(paddle.CPUPlace()) + def __getitem__(self, idx): + """Returns (atom_bond_graph, bond_angle_graph)""" + return self.graph_atom_bond[idx], self.graph_bond_angle[idx] \ No newline at end of file diff --git a/ppmat/datasets/ir_dataset.py b/ppmat/datasets/ir_dataset.py new file mode 100644 index 00000000..5d73db47 --- /dev/null +++ b/ppmat/datasets/ir_dataset.py @@ -0,0 +1,168 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import numpy as np +import paddle +from paddle.io import Dataset +from pathlib import Path +from typing import Dict, Optional + +from ppmat.utils import PlaceEnv +from ppmat.datasets.build_ir import ( + build_ir_sample_builder, + build_ir_downloader, + read_ir_spectra_by_ids, + Construct_IR_Dataset, +) + +_cache = {} + + +class IRDataset(Dataset): + """ + ECFormer IR Spectrum Prediction Dataset + + Supports three preloading modes: + - '100': Small dataset with 100 samples (default, for quick testing) + - '10000': Medium dataset with 10,000 samples + - 'all': All samples (may be very large) + """ + + url = "https://paddle-org.bj.bcebos.com/paddlematerials/datasets/IR/IR.tar.gz" + md5 = "e1ea5624cf9b92b3657933245196f5dc" + + def __init__( + self, + data_path: str, + mode: str = '100', + split: Optional[str] = None, + data_count: Optional[int] = None, + sample_builder_cfg: Optional[Dict] = None, + downloader_cfg: Optional[Dict] = None, + download: bool = True, + force_download: bool = False, + use_geometry_enhanced: bool = True, + use_cache: bool = True, + ): + super().__init__() + + self.data_path = Path(data_path) + self.mode = mode + self.split = split + self.data_count = data_count + self.use_geometry_enhanced = use_geometry_enhanced + self.use_cache = use_cache + + cache_key = f"{data_path}_{mode}_{use_geometry_enhanced}" + + # If cache is enabled and hit, return directly + if use_cache and cache_key in _cache: + cached_data = _cache[cache_key] + self.graph_atom_bond = cached_data['atom_bond'] + self.graph_bond_angle = cached_data['bond_angle'] + self.smiles_list = cached_data.get('smiles', []) + return + + # Build components + self.sample_builder = build_ir_sample_builder(sample_builder_cfg) + self.downloader = build_ir_downloader(downloader_cfg) + + # Handle download + if force_download or (not self._check_files() and download): + self.downloaded_root = self.downloader.download( + self.url, self.md5, 
force_download=force_download + ) + self.data_path = self.downloaded_root + + # Load data + self._load_data() + + # Store in cache + if use_cache: + _cache[cache_key] = { + 'atom_bond': self.graph_atom_bond, + 'bond_angle': self.graph_bond_angle, + 'smiles': self.smiles_list + } + + def _check_files(self): + """Check if necessary files exist""" + meta_path = self.data_path / f'ir_column_charity_{self.mode}.npy' + spectra_path = self.data_path / 'qm9_ir_spec' + + if not meta_path.exists(): + return False + if not spectra_path.exists(): + return False + return True + + def _load_data(self): + """Load all data""" + # 1. Load metadata file + meta_path = self.data_path / f'ir_column_charity_{self.mode}.npy' + if not meta_path.exists(): + raise FileNotFoundError(f"IR meta file {meta_path} not found") + + data = np.load(meta_path, allow_pickle=True).item() + dataset_all = data['dataset_all'] + smiles_all = data['smiles_all'] + index_all = data['index_all'] + + print(f"Loaded meta data: {len(index_all)} samples") + + # 2. Read IR spectra on demand + spectra_path = self.data_path / 'qm9_ir_spec' + self.ir_sequences = read_ir_spectra_by_ids(str(spectra_path), index_all) + + print(f"Loaded {len(self.ir_sequences)} IR spectra") + + # 3. Construct graph data + descriptor_path = self.data_path / 'descriptor_all_column.npy' + if not descriptor_path.exists(): + descriptor_path = None + + total_graph_atom_bond, total_graph_bond_angle = Construct_IR_Dataset( + dataset_all, index_all, descriptor_path + ) + + # 4. 
Attach spectrum information to graph data + self.graph_atom_bond = [] + self.graph_bond_angle = [] + self.smiles_list = [] + + for i, itm in enumerate(self.ir_sequences): + atom_bond = total_graph_atom_bond[i] + + atom_bond.sequence = paddle.to_tensor([itm['seq_40']]) + atom_bond.ir_id = paddle.to_tensor(int(itm['id'])) + atom_bond.peak_num = paddle.to_tensor([itm['peak_num']]) + atom_bond.peak_position = paddle.to_tensor([itm['peak_position']]) + atom_bond.peak_height = paddle.to_tensor([itm['peak_height']]) + atom_bond.query_mask = itm['query_mask'] + + self.graph_atom_bond.append(atom_bond) + self.graph_bond_angle.append(total_graph_bond_angle[i]) + self.smiles_list.append(smiles_all[i]) + + print(f"Final dataset size: {len(self.graph_atom_bond)}") + + def __len__(self): + return len(self.graph_atom_bond) + + @PlaceEnv(paddle.CPUPlace()) + def __getitem__(self, idx): + """Returns (atom_bond_graph, bond_angle_graph)""" + return self.graph_atom_bond[idx], self.graph_bond_angle[idx] \ No newline at end of file diff --git a/ppmat/losses/__init__.py b/ppmat/losses/__init__.py index e5258105..085290c9 100644 --- a/ppmat/losses/__init__.py +++ b/ppmat/losses/__init__.py @@ -18,6 +18,8 @@ from ppmat.losses.l1_loss import L1Loss from ppmat.losses.l1_loss import MAELoss from ppmat.losses.l1_loss import SmoothL1Loss +from ppmat.losses.ecd_loss import ECDLoss +from ppmat.losses.ir_loss import IRLoss from ppmat.losses.loss_warper import LossWarper from ppmat.losses.mse_loss import MSELoss @@ -25,6 +27,8 @@ "MSELoss", "L1Loss", "SmoothL1Loss", + "ECDLoss", + "IRLoss", "MAELoss", "HuberLoss", "LossWarper", diff --git a/ppmat/losses/ecd_loss.py b/ppmat/losses/ecd_loss.py new file mode 100644 index 00000000..e659417d --- /dev/null +++ b/ppmat/losses/ecd_loss.py @@ -0,0 +1,132 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the License); +# you may not use this file except in compliance with the License. 
import paddle
import paddle.nn as nn


class ECDLoss(nn.Layer):
    """Loss for the ECFormer ECD task.

    total = CE(peak number) + CE(peak position) + w * CE(peak symbol),
    where the position/symbol terms are averaged over samples that contain
    at least one peak.

    Epoch statistics: all three per-batch loss values are accumulated and
    normalized by the number of batches.  (The previous implementation
    divided per-batch sums by a per-sample count, mixing units and skewing
    the reported epoch averages.)
    """

    def __init__(self, loss_weight_height=2.0, num_position_classes=20, height_classes=2):
        """
        Args:
            loss_weight_height (float): Weight of the peak-symbol term (2.0 in paper).
            num_position_classes (int): Number of position classes (default: 20).
            height_classes (int): Number of symbol classes (default 2: positive/negative).
        """
        super().__init__()
        self.ce_loss = nn.CrossEntropyLoss()
        self.loss_weight_height = loss_weight_height
        self.num_position_classes = num_position_classes
        self.height_classes = height_classes
        self.reset()

    def forward(self, predictions, targets):
        """Compute ECD losses for one batch.

        Args:
            predictions (dict): Model outputs containing:
                - peak_number: [batch_size, max_peaks] logits for peak count
                - peak_position: [batch_size, max_peaks, num_position_classes] position logits
                - peak_height: [batch_size, max_peaks, height_classes] symbol logits
            targets (dict): Ground truth containing:
                - peak_number: [batch_size] true peak counts
                - peak_position: [batch_size, max_peaks] true position labels
                - peak_height: [batch_size, max_peaks] true symbol labels

        Returns:
            dict: 'loss' (total) plus 'loss_num', 'loss_pos', 'loss_height'.
        """
        loss_num = self.ce_loss(predictions['peak_number'], targets['peak_number'])

        batch_size = targets['peak_number'].shape[0]
        loss_pos_total = 0.0
        loss_height_total = 0.0
        valid_samples = 0

        for i in range(batch_size):
            n_peaks = int(targets['peak_number'][i])
            if n_peaks == 0:
                continue  # no peaks -> no position/symbol supervision

            # Position loss over the first n_peaks slots only.
            pos_pred = predictions['peak_position'][i, :n_peaks, :].reshape(
                [-1, self.num_position_classes]
            )
            pos_gt = targets['peak_position'][i, :n_peaks].reshape([-1])
            loss_pos_total += self.ce_loss(pos_pred, pos_gt)

            # Symbol (sign) loss over the same slots.
            height_pred = predictions['peak_height'][i, :n_peaks, :].reshape(
                [-1, self.height_classes]
            )
            height_gt = targets['peak_height'][i, :n_peaks].reshape([-1])
            loss_height_total += self.ce_loss(height_pred, height_gt)

            valid_samples += 1

        if valid_samples > 0:
            loss_pos = loss_pos_total / valid_samples
            loss_height = loss_height_total / valid_samples
        else:
            loss_pos = paddle.to_tensor(0.0)
            loss_height = paddle.to_tensor(0.0)

        total_loss = loss_num + self.loss_weight_height * loss_height + loss_pos

        self._accumulate(loss_num, loss_pos, loss_height, valid_samples)

        return {
            "loss": total_loss,
            "loss_num": loss_num,
            "loss_pos": loss_pos,
            "loss_height": loss_height,
        }

    def _accumulate(self, loss_num, loss_pos, loss_height, valid_samples):
        """Accumulate per-batch mean losses for epoch-level statistics."""
        self.loss_num_sum += float(loss_num)
        self.loss_pos_sum += float(loss_pos)
        self.loss_height_sum += float(loss_height)
        self.num_batches += 1
        self.total_samples += valid_samples  # kept for external readers

    def reset(self):
        """Reset accumulated statistics."""
        self.loss_num_sum = 0.0
        self.loss_pos_sum = 0.0
        self.loss_height_sum = 0.0
        self.num_batches = 0
        self.total_samples = 0

    def log_epoch_metrics(self):
        """Return epoch-level loss averages (-1.0 sentinels when no batch seen)."""
        if self.num_batches == 0:
            return {
                "train_epoch/loss_num": -1.0,
                "train_epoch/loss_pos": -1.0,
                "train_epoch/loss_height": -1.0,
            }
        return {
            "train_epoch/loss_num": self.loss_num_sum / self.num_batches,
            "train_epoch/loss_pos": self.loss_pos_sum / self.num_batches,
            "train_epoch/loss_height": self.loss_height_sum / self.num_batches,
        }


class IRLoss(nn.Layer):
    """Loss for the ECFormer IR task.

    total = CE(peak number) + CE(peak position) [+ MSE(peak intensity)],
    with the position/intensity terms averaged over samples that contain at
    least one peak.  Epoch statistics are normalized by batch count (see
    ECDLoss for rationale).
    """

    def __init__(self, num_position_classes=36, use_height_prediction=True):
        """
        Args:
            num_position_classes (int): Number of position classes (36 for IR).
            use_height_prediction (bool): Enable the intensity-regression term.
        """
        super().__init__()
        self.ce_loss = nn.CrossEntropyLoss()
        self.mse_loss = nn.MSELoss(reduction='mean')
        self.num_position_classes = num_position_classes
        self.use_height_prediction = use_height_prediction
        self.reset()

    def forward(self, predictions, targets):
        """Compute IR losses for one batch.

        Args:
            predictions (dict): Model outputs containing:
                - peak_number: [batch_size, max_peaks+1] logits for peak count
                - peak_position: [batch_size, max_peaks, num_position_classes] position logits
                - peak_height (optional): [batch_size, max_peaks] predicted intensities
            targets (dict): Ground truth containing:
                - peak_number: [batch_size] true peak counts
                - peak_position: [batch_size, max_peaks] true position labels
                - peak_height: [batch_size, max_peaks] true intensities

        Returns:
            dict: 'loss' (total) plus 'loss_num', 'loss_pos', 'loss_height'.
        """
        loss_num = self.ce_loss(predictions['peak_number'], targets['peak_number'])

        batch_size = targets['peak_number'].shape[0]
        loss_pos_total = 0.0
        loss_height_total = 0.0
        valid_samples = 0

        # Loop-invariant: whether the intensity term is active this call.
        height_enabled = self.use_height_prediction and 'peak_height' in predictions

        for i in range(batch_size):
            n_peaks = int(targets['peak_number'][i])
            if n_peaks == 0:
                continue  # no peaks -> no position/intensity supervision

            pos_pred = predictions['peak_position'][i, :n_peaks, :].reshape(
                [-1, self.num_position_classes]
            )
            pos_gt = targets['peak_position'][i, :n_peaks].reshape([-1])
            loss_pos_total += self.ce_loss(pos_pred, pos_gt)

            if height_enabled:
                height_pred = predictions['peak_height'][i, :n_peaks].reshape([-1])
                height_gt = targets['peak_height'][i, :n_peaks].reshape([-1])
                loss_height_total += self.mse_loss(height_pred, height_gt)

            valid_samples += 1

        loss_pos = (
            loss_pos_total / valid_samples if valid_samples > 0 else paddle.to_tensor(0.0)
        )
        total_loss = loss_num + loss_pos

        if self.use_height_prediction:
            loss_height = (
                loss_height_total / valid_samples
                if valid_samples > 0
                else paddle.to_tensor(0.0)
            )
            total_loss += loss_height
        else:
            loss_height = paddle.to_tensor(0.0)

        self._accumulate(loss_num, loss_pos, loss_height, valid_samples)

        return {
            "loss": total_loss,
            "loss_num": loss_num,
            "loss_pos": loss_pos,
            "loss_height": loss_height,
        }

    def _accumulate(self, loss_num, loss_pos, loss_height, valid_samples):
        """Accumulate per-batch mean losses for epoch-level statistics."""
        self.loss_num_sum += float(loss_num)
        self.loss_pos_sum += float(loss_pos)
        self.loss_height_sum += float(loss_height)
        self.num_batches += 1
        self.total_samples += valid_samples  # kept for external readers

    def reset(self):
        """Reset accumulated statistics."""
        self.loss_num_sum = 0.0
        self.loss_pos_sum = 0.0
        self.loss_height_sum = 0.0
        self.num_batches = 0
        self.total_samples = 0

    def log_epoch_metrics(self):
        """Return epoch-level loss averages (-1.0 sentinels when no batch seen)."""
        if self.num_batches == 0:
            return {
                "train_epoch/loss_num": -1.0,
                "train_epoch/loss_pos": -1.0,
                "train_epoch/loss_height": -1.0,
            }
        return {
            "train_epoch/loss_num": self.loss_num_sum / self.num_batches,
            "train_epoch/loss_pos": self.loss_pos_sum / self.num_batches,
            "train_epoch/loss_height": self.loss_height_sum / self.num_batches,
        }
import paddle
import paddle.nn as nn
from typing import Dict


# =========================
# Utilities
# NOTE(review): these three helpers are duplicated in ir_metric.py —
# consider hoisting them into a shared module.
# =========================

def _is_dist():
    """Return True when distributed training is initialized with world_size > 1."""
    try:
        import paddle.distributed as dist
        return dist.is_initialized() and dist.get_world_size() > 1
    except Exception:
        return False


def _all_reduce_sum_(t: paddle.Tensor) -> paddle.Tensor:
    """In-place SUM all_reduce when distributed; returns t either way."""
    if _is_dist():
        import paddle.distributed as dist
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t


def _to_f32(x) -> paddle.Tensor:
    """Coerce a scalar or tensor to a float32 paddle tensor."""
    return (
        paddle.to_tensor(x, dtype="float32")
        if not isinstance(x, paddle.Tensor)
        else x.astype("float32")
    )


# =========================
# ECD Metrics
# =========================

class ECDMetrics(nn.Layer):
    """Evaluation metrics for the ECFormer ECD task.

    Computes:
        - num_rmse: RMSE of predicted vs true peak count
        - pos_rmse: RMSE of predicted vs true peak positions (class indices)
        - symbol_acc: accuracy of predicted peak symbols (positive/negative)
        - first_symbol_acc: accuracy of the first peak's symbol

    Position and symbol statistics are computed only on the first
    min(n_true, n_pred) "matched" peaks of samples where both counts are
    positive.
    """

    def __init__(self, num_position_classes=20, max_peaks=9):
        """
        Args:
            num_position_classes (int): Number of position classes (default: 20).
            max_peaks (int): Maximum number of peaks (default: 9).
        """
        super().__init__()
        self.num_position_classes = num_position_classes
        self.max_peaks = max_peaks
        self.reset()

    def reset(self):
        """Reset all accumulated statistics."""
        self.num_rmse_sum = _to_f32(0.0)
        self.pos_rmse_sum = _to_f32(0.0)
        # Dedicated matched-peak counter for position RMSE.  Previously the
        # position average reused symbol_total, silently coupling the two
        # statistics; they are now tracked independently (consistent with
        # the IR metrics).
        self.pos_count = _to_f32(0.0)
        self.symbol_correct = _to_f32(0.0)
        self.symbol_total = _to_f32(0.0)
        self.first_symbol_correct = _to_f32(0.0)
        self.first_symbol_total = _to_f32(0.0)
        self.num_samples = _to_f32(0.0)

    def update(self, predictions: Dict[str, paddle.Tensor], targets: Dict[str, paddle.Tensor]):
        """Accumulate one batch of predictions/targets.

        Args:
            predictions: dict from model forward
                - peak_number: [batch_size, max_peaks] logits for peak count
                - peak_position: [batch_size, max_peaks, num_position_classes] position logits
                - peak_height: [batch_size, max_peaks, 2] symbol logits
            targets: dict from dataloader
                - peak_num: [batch_size] true peak counts
                - peak_position: [batch_size, max_peaks] true position labels
                - peak_height: [batch_size, max_peaks] true symbol labels

        NOTE(review): the loss classes read targets['peak_number'] while the
        metrics read targets['peak_num'] — confirm both keys exist in the
        batch dict.
        """
        batch_size = targets['peak_num'].shape[0]

        pred_nums = predictions['peak_number'].argmax(axis=1)
        true_nums = targets['peak_num']

        # Squared-error accumulation for the peak-count RMSE.
        num_errors = (pred_nums - true_nums).astype('float32')
        self.num_rmse_sum += paddle.sum(paddle.square(num_errors))

        for i in range(batch_size):
            n_true = int(true_nums[i])
            n_pred = int(pred_nums[i])

            if n_true > 0 and n_pred > 0:
                n_match = min(n_true, n_pred)

                # Position squared errors on matched peaks only.
                pos_true = targets['peak_position'][i, :n_match].astype('int64')
                pos_pred = predictions['peak_position'][i, :n_match, :].argmax(axis=1)
                pos_errors = (pos_pred - pos_true).astype('float32')
                self.pos_rmse_sum += paddle.sum(paddle.square(pos_errors))
                self.pos_count += _to_f32(n_match)

                # Symbol accuracy on matched peaks.
                height_true = targets['peak_height'][i, :n_match]
                height_pred = predictions['peak_height'][i, :n_match, :].argmax(axis=1)
                correct = (height_true == height_pred).astype('float32')
                self.symbol_correct += paddle.sum(correct)
                self.symbol_total += _to_f32(n_match)

                # First-peak symbol accuracy.
                if height_true[0] == height_pred[0]:
                    self.first_symbol_correct += _to_f32(1.0)
                self.first_symbol_total += _to_f32(1.0)

        self.num_samples += _to_f32(batch_size)

    def accumulate(self) -> Dict[str, float]:
        """Reduce (across ranks when distributed) and return final metrics."""
        num_rmse_sum = _all_reduce_sum_(self.num_rmse_sum.clone())
        pos_rmse_sum = _all_reduce_sum_(self.pos_rmse_sum.clone())
        pos_count = _all_reduce_sum_(self.pos_count.clone())
        symbol_correct = _all_reduce_sum_(self.symbol_correct.clone())
        symbol_total = _all_reduce_sum_(self.symbol_total.clone())
        first_correct = _all_reduce_sum_(self.first_symbol_correct.clone())
        first_total = _all_reduce_sum_(self.first_symbol_total.clone())
        num_samples = _all_reduce_sum_(self.num_samples.clone())

        # max(denominator, 1) guards empty accumulators against div-by-zero.
        num_rmse = paddle.sqrt(
            num_rmse_sum / paddle.maximum(num_samples, _to_f32(1.0))
        ).item()
        pos_rmse = paddle.sqrt(
            pos_rmse_sum / paddle.maximum(pos_count, _to_f32(1.0))
        ).item()
        symbol_acc = (symbol_correct / paddle.maximum(symbol_total, _to_f32(1.0))).item()
        first_symbol_acc = (first_correct / paddle.maximum(first_total, _to_f32(1.0))).item()

        return {
            'num_rmse': num_rmse,
            'pos_rmse': pos_rmse,
            'symbol_acc': symbol_acc,
            'first_symbol_acc': first_symbol_acc,
        }

    def forward(self, predictions: Dict[str, paddle.Tensor], targets: Dict[str, paddle.Tensor]) -> Dict[str, float]:
        """Compute metrics for a single batch (non-streaming convenience)."""
        self.reset()
        self.update(predictions, targets)
        return self.accumulate()


class ECFormerIRMetrics(nn.Layer):
    """Evaluation metrics for the ECFormer IR task.

    NOTE(review): this class duplicates IRMetrics (ir_metric.py) except for
    the extra num_position_classes/max_peaks constructor knobs; kept for
    backward compatibility, but the two should be unified.

    Computes:
        - num_rmse: RMSE of predicted vs true peak count
        - pos_rmse: RMSE of predicted vs true peak positions (class indices)
        - height_rmse: RMSE of predicted vs true peak intensities (if enabled)
    """

    def __init__(self, num_position_classes=36, max_peaks=15, use_height_prediction=True):
        """
        Args:
            num_position_classes (int): Number of position classes (36 for IR).
            max_peaks (int): Maximum number of peaks (default: 15).
            use_height_prediction (bool): Whether intensity prediction is used.
        """
        super().__init__()
        self.num_position_classes = num_position_classes
        self.max_peaks = max_peaks
        self.use_height_prediction = use_height_prediction
        self.reset()

    def reset(self):
        """Reset all accumulated statistics."""
        self.num_rmse_sum = _to_f32(0.0)
        self.pos_rmse_sum = _to_f32(0.0)
        self.height_rmse_sum = _to_f32(0.0)
        self.pos_count = _to_f32(0.0)
        self.height_count = _to_f32(0.0)
        self.num_samples = _to_f32(0.0)

    def update(self, predictions: Dict[str, paddle.Tensor], targets: Dict[str, paddle.Tensor]):
        """Accumulate one batch of predictions/targets.

        Args:
            predictions: dict from model forward
                - peak_number: [batch_size, max_peaks+1] logits for peak count
                - peak_position: [batch_size, max_peaks, num_position_classes] position logits
                - peak_height (optional): [batch_size, max_peaks] predicted intensities
            targets: dict from dataloader
                - peak_num: [batch_size] true peak counts
                - peak_position: [batch_size, max_peaks] true position labels
                - peak_height: [batch_size, max_peaks] true intensities
        """
        batch_size = targets['peak_num'].shape[0]

        pred_nums = predictions['peak_number'].argmax(axis=1)
        true_nums = targets['peak_num']

        num_errors = (pred_nums - true_nums).astype('float32')
        self.num_rmse_sum += paddle.sum(paddle.square(num_errors))

        # Loop-invariant: whether intensity errors are accumulated this call.
        height_enabled = self.use_height_prediction and 'peak_height' in predictions

        for i in range(batch_size):
            n_true = int(true_nums[i])
            n_pred = int(pred_nums[i])

            if n_true > 0 and n_pred > 0:
                n_match = min(n_true, n_pred)

                pos_true = targets['peak_position'][i, :n_match].astype('int64')
                pos_pred = predictions['peak_position'][i, :n_match, :].argmax(axis=1)
                pos_errors = (pos_pred - pos_true).astype('float32')
                self.pos_rmse_sum += paddle.sum(paddle.square(pos_errors))
                self.pos_count += _to_f32(n_match)

                if height_enabled:
                    height_true = targets['peak_height'][i, :n_match].astype('float32')
                    height_pred = predictions['peak_height'][i, :n_match].reshape([-1])
                    height_errors = height_true - height_pred
                    self.height_rmse_sum += paddle.sum(paddle.square(height_errors))
                    self.height_count += _to_f32(n_match)

        self.num_samples += _to_f32(batch_size)

    def accumulate(self) -> Dict[str, float]:
        """Reduce (across ranks when distributed) and return final metrics."""
        num_rmse_sum = _all_reduce_sum_(self.num_rmse_sum.clone())
        pos_rmse_sum = _all_reduce_sum_(self.pos_rmse_sum.clone())
        height_rmse_sum = _all_reduce_sum_(self.height_rmse_sum.clone())
        pos_count = _all_reduce_sum_(self.pos_count.clone())
        height_count = _all_reduce_sum_(self.height_count.clone())
        num_samples = _all_reduce_sum_(self.num_samples.clone())

        num_rmse = paddle.sqrt(
            num_rmse_sum / paddle.maximum(num_samples, _to_f32(1.0))
        ).item()
        pos_rmse = paddle.sqrt(
            pos_rmse_sum / paddle.maximum(pos_count, _to_f32(1.0))
        ).item()

        metrics = {
            'num_rmse': num_rmse,
            'pos_rmse': pos_rmse,
        }

        if self.use_height_prediction:
            metrics['height_rmse'] = paddle.sqrt(
                height_rmse_sum / paddle.maximum(height_count, _to_f32(1.0))
            ).item()

        return metrics

    def forward(self, predictions: Dict[str, paddle.Tensor], targets: Dict[str, paddle.Tensor]) -> Dict[str, float]:
        """Compute metrics for a single batch (non-streaming convenience)."""
        self.reset()
        self.update(predictions, targets)
        return self.accumulate()
import paddle
import paddle.nn as nn
from typing import Dict


# =========================
# Utilities
# NOTE(review): duplicated from ecd_metric.py — consider a shared module.
# =========================

def _is_dist():
    """Return True when distributed training is initialized with world_size > 1."""
    try:
        import paddle.distributed as dist
        return dist.is_initialized() and dist.get_world_size() > 1
    except Exception:
        return False


def _all_reduce_sum_(t: paddle.Tensor) -> paddle.Tensor:
    """In-place SUM all_reduce when distributed; returns t either way."""
    if _is_dist():
        import paddle.distributed as dist
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t


def _to_f32(x) -> paddle.Tensor:
    """Coerce a scalar or tensor to a float32 paddle tensor."""
    return (
        paddle.to_tensor(x, dtype="float32")
        if not isinstance(x, paddle.Tensor)
        else x.astype("float32")
    )


# =========================
# IR Metrics
# =========================

class IRMetrics(nn.Layer):
    """Evaluation metrics for the ECFormer IR task.

    Computes:
        - num_rmse: RMSE of predicted vs true peak count
        - pos_rmse: RMSE of predicted vs true peak positions (class indices)
        - height_rmse: RMSE of predicted vs true peak intensities (if enabled)

    Position/intensity statistics cover only the first min(n_true, n_pred)
    "matched" peaks of samples where both counts are positive.
    """

    def __init__(self, use_height_prediction=True):
        """
        Args:
            use_height_prediction (bool): Whether intensity prediction is used.
        """
        super().__init__()
        self.use_height_prediction = use_height_prediction
        self.reset()

    def reset(self):
        """Reset all accumulated statistics."""
        self.num_rmse_sum = _to_f32(0.0)
        self.pos_rmse_sum = _to_f32(0.0)
        self.height_rmse_sum = _to_f32(0.0)
        self.pos_count = _to_f32(0.0)
        self.height_count = _to_f32(0.0)
        self.num_samples = _to_f32(0.0)

    def update(self, predictions: Dict[str, paddle.Tensor], targets: Dict[str, paddle.Tensor]):
        """Accumulate one batch of predictions/targets.

        Args:
            predictions: dict from model forward
                - peak_number: [batch_size, max_peaks+1] logits for peak count
                - peak_position: [batch_size, max_peaks, num_position_classes] position logits
                - peak_height (optional): [batch_size, max_peaks] predicted intensities
            targets: dict from dataloader
                - peak_num: [batch_size] true peak counts
                - peak_position: [batch_size, max_peaks] true position labels
                - peak_height: [batch_size, max_peaks] true intensities
        """
        batch_size = targets['peak_num'].shape[0]

        pred_nums = predictions['peak_number'].argmax(axis=1)
        true_nums = targets['peak_num']

        # Squared-error accumulation for the peak-count RMSE.
        num_errors = (pred_nums - true_nums).astype('float32')
        self.num_rmse_sum += paddle.sum(paddle.square(num_errors))

        # Loop-invariant: whether intensity errors are accumulated this call.
        height_enabled = self.use_height_prediction and 'peak_height' in predictions

        for i in range(batch_size):
            n_true = int(true_nums[i])
            n_pred = int(pred_nums[i])

            if n_true > 0 and n_pred > 0:
                n_match = min(n_true, n_pred)

                # Position squared errors on matched peaks only.
                pos_true = targets['peak_position'][i, :n_match].astype('int64')
                pos_pred = predictions['peak_position'][i, :n_match, :].argmax(axis=1)
                pos_errors = (pos_pred - pos_true).astype('float32')
                self.pos_rmse_sum += paddle.sum(paddle.square(pos_errors))
                self.pos_count += _to_f32(n_match)

                # Intensity squared errors on matched peaks.
                if height_enabled:
                    height_true = targets['peak_height'][i, :n_match].astype('float32')
                    height_pred = predictions['peak_height'][i, :n_match].reshape([-1])
                    height_errors = height_true - height_pred
                    self.height_rmse_sum += paddle.sum(paddle.square(height_errors))
                    self.height_count += _to_f32(n_match)

        self.num_samples += _to_f32(batch_size)

    def accumulate(self) -> Dict[str, float]:
        """Reduce (across ranks when distributed) and return final metrics."""
        num_rmse_sum = _all_reduce_sum_(self.num_rmse_sum.clone())
        pos_rmse_sum = _all_reduce_sum_(self.pos_rmse_sum.clone())
        height_rmse_sum = _all_reduce_sum_(self.height_rmse_sum.clone())
        pos_count = _all_reduce_sum_(self.pos_count.clone())
        height_count = _all_reduce_sum_(self.height_count.clone())
        num_samples = _all_reduce_sum_(self.num_samples.clone())

        # max(denominator, 1) guards empty accumulators against div-by-zero.
        num_rmse = paddle.sqrt(
            num_rmse_sum / paddle.maximum(num_samples, _to_f32(1.0))
        ).item()
        pos_rmse = paddle.sqrt(
            pos_rmse_sum / paddle.maximum(pos_count, _to_f32(1.0))
        ).item()

        metrics = {
            'num_rmse': num_rmse,
            'pos_rmse': pos_rmse,
        }

        if self.use_height_prediction:
            metrics['height_rmse'] = paddle.sqrt(
                height_rmse_sum / paddle.maximum(height_count, _to_f32(1.0))
            ).item()

        return metrics

    def forward(self, predictions: Dict[str, paddle.Tensor], targets: Dict[str, paddle.Tensor]) -> Dict[str, float]:
        """Compute metrics for a single batch (non-streaming convenience)."""
        self.reset()
        self.update(predictions, targets)
        return self.accumulate()
+ +# export model calss +from .models.ECD import ECFormerECD +from .models.IR import ECFormerIR + +# export encoder(if want to use directly) +from .encoders.gin_node_embedding import GINNodeEmbedding + +__all__ = [ + 'ECFormerECD', + 'ECFormerIR', + 'GINNodeEmbedding', +] \ No newline at end of file diff --git a/ppmat/models/ecformer/encoders/__init__.py b/ppmat/models/ecformer/encoders/__init__.py new file mode 100644 index 00000000..36adf7e7 --- /dev/null +++ b/ppmat/models/ecformer/encoders/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .gin_node_embedding import GINNodeEmbedding \ No newline at end of file diff --git a/ppmat/models/ecformer/encoders/gin_node_embedding.py b/ppmat/models/ecformer/encoders/gin_node_embedding.py new file mode 100644 index 00000000..ec3f1fb3 --- /dev/null +++ b/ppmat/models/ecformer/encoders/gin_node_embedding.py @@ -0,0 +1,183 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ..layers.atom_encoder import AtomEncoder
from ..layers.bond_encoder import BondEncoder
from ..layers.rbf import BondFloatRBF, BondAngleFloatRBF
from ..layers.gin_conv import GINConv


class GINNodeEmbedding(nn.Layer):
    """GIN node embedding supporting a geometry-enhanced dual graph.

    The atom-bond graph is processed by GIN convolutions; when geometry
    enhancement is active, a second bond-angle graph updates the edge
    (bond) representations in lockstep with the node representations.
    """

    def __init__(
        self,
        full_atom_feature_dims,
        full_bond_feature_dims,
        bond_float_names,
        bond_angle_float_names,
        bond_id_names,
        num_layers=5,
        emb_dim=128,
        drop_ratio=0.5,
        JK="last",
        residual=False,
        use_geometry_enhanced=True
    ):
        """
        Args:
            full_atom_feature_dims: vocabulary size per discrete atom feature.
            full_bond_feature_dims: vocabulary size per discrete bond feature.
            bond_float_names: names of continuous bond features (RBF-encoded).
            bond_angle_float_names: names of continuous bond-angle features.
            bond_id_names: names of the discrete bond-id columns; their count
                determines how ``edge_attr`` is split into int/float parts.
            num_layers: number of GIN layers (must be >= 2).
            emb_dim: embedding width.
            drop_ratio: dropout probability applied after every layer.
            JK: jumping-knowledge mode, "last" or "sum".
            residual: add skip connections between consecutive layers.
            use_geometry_enhanced: enable the bond-angle graph branch.
        """
        super(GINNodeEmbedding, self).__init__()

        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.residual = residual
        self.use_geometry_enhanced = use_geometry_enhanced
        self.bond_id_names = bond_id_names

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        # Input encoders for atoms, bonds and continuous bond features.
        self.atom_encoder = AtomEncoder(full_atom_feature_dims, emb_dim)
        self.bond_encoder = BondEncoder(full_bond_feature_dims, emb_dim)
        self.bond_float_encoder = BondFloatRBF(bond_float_names, emb_dim)
        self.bond_angle_encoder = BondAngleFloatRBF(bond_angle_float_names, emb_dim)

        # Per-layer modules: node conv, edge conv and per-layer
        # re-encoders of the raw edge attributes.
        self.convs = nn.LayerList()
        self.convs_bond_angle = nn.LayerList()
        self.convs_bond_embedding = nn.LayerList()
        self.convs_bond_float = nn.LayerList()
        self.convs_angle_float = nn.LayerList()
        self.batch_norms = nn.LayerList()
        self.batch_norms_ba = nn.LayerList()

        for _ in range(num_layers):
            self.convs.append(GINConv(emb_dim))
            self.convs_bond_angle.append(GINConv(emb_dim))
            self.convs_bond_embedding.append(BondEncoder(full_bond_feature_dims, emb_dim))
            self.convs_bond_float.append(BondFloatRBF(bond_float_names, emb_dim))
            self.convs_angle_float.append(BondAngleFloatRBF(bond_angle_float_names, emb_dim))
            self.batch_norms.append(nn.BatchNorm1D(emb_dim))
            self.batch_norms_ba.append(nn.BatchNorm1D(emb_dim))

    def forward(
        self,
        x,                   # [N, F] atom features
        edge_index,          # [2, E] edge indices
        edge_attr,           # [E, D] edge features
        # Geometry enhancement related inputs
        ba_edge_index=None,  # [2, E_ba] bond-angle graph edge indices
        ba_edge_attr=None,   # [E_ba, D_ba] bond-angle graph edge features
    ):
        """Encode nodes.

        Returns:
            (node_repr, edge_repr) in geometry-enhanced mode, otherwise
            node_repr only.
        """
        # Embedding lookups require int64 indices.
        if x.dtype != paddle.int64:
            x = x.astype(paddle.int64)
        h_list = [self.atom_encoder(x)]

        if self.use_geometry_enhanced and ba_edge_index is not None:
            return self._forward_enhanced(
                h_list, edge_index, edge_attr,
                ba_edge_index, ba_edge_attr
            )
        return self._forward_simple(h_list, edge_index, edge_attr)

    def _forward_enhanced(self, h_list, edge_index, edge_attr,
                          ba_edge_index, ba_edge_attr):
        """Geometry-enhanced pass: node and bond embeddings co-evolve."""
        bond_id_len = len(self.bond_id_names)

        # Discrete bond ids occupy columns [0, bond_id_len); everything
        # after that is a continuous bond feature.
        h_list_ba = [
            self.bond_float_encoder(edge_attr[:, bond_id_len:].astype('float32'))
            + self.bond_encoder(edge_attr[:, 0:bond_id_len].astype('int64'))
        ]

        for layer in range(self.num_layers):
            # Node update on the atom-bond graph.
            h = self.convs[layer](h_list[layer], edge_index, h_list_ba[layer])

            # Edge update on the bond-angle graph.
            cur_h_ba = self.convs_bond_embedding[layer](
                edge_attr[:, 0:bond_id_len].astype('int64')
            ) + self.convs_bond_float[layer](
                edge_attr[:, bond_id_len:].astype('float32')
            )
            cur_angle_hidden = self.convs_angle_float[layer](ba_edge_attr)
            h_ba = self.convs_bond_angle[layer](cur_h_ba, ba_edge_index, cur_angle_hidden)

            # No ReLU after the final layer, per the usual GIN recipe.
            if layer == self.num_layers - 1:
                h = F.dropout(h, self.drop_ratio, training=self.training)
                h_ba = F.dropout(h_ba, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
                h_ba = F.dropout(F.relu(h_ba), self.drop_ratio, training=self.training)

            if self.residual:
                h += h_list[layer]
                h_ba += h_list_ba[layer]

            h_list.append(h)
            h_list_ba.append(h_ba)

        # JK connection strategy.
        if self.JK == "last":
            return h_list[-1], h_list_ba[-1]
        elif self.JK == "sum":
            return sum(h_list), sum(h_list_ba)
        # Previously an unsupported JK mode fell through to an obscure
        # NameError; fail loudly instead.
        raise ValueError(f"Unsupported JK mode: {self.JK!r}")

    def _forward_simple(self, h_list, edge_index, edge_attr):
        """Plain (non-geometric) pass over the atom-bond graph."""
        bond_id_len = len(self.bond_id_names)

        for layer in range(self.num_layers):
            h = self.convs[layer](
                h_list[layer],
                edge_index,
                self.convs_bond_embedding[layer](edge_attr[:, 0:bond_id_len].astype('int64'))
                + self.convs_bond_float[layer](edge_attr[:, bond_id_len:].astype('float32'))
            )
            h = self.batch_norms[layer](h)

            if layer == self.num_layers - 1:
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

            if self.residual:
                h += h_list[layer]

            h_list.append(h)

        if self.JK == "last":
            return h_list[-1]
        elif self.JK == "sum":
            return sum(h_list)
        # Previously returned None silently for unknown JK modes.
        raise ValueError(f"Unsupported JK mode: {self.JK!r}")
+# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .atom_encoder import AtomEncoder +from .bond_encoder import BondEncoder +from .rbf import RBF, BondFloatRBF +from .gin_conv import GINConv \ No newline at end of file diff --git a/ppmat/models/ecformer/layers/atom_encoder.py b/ppmat/models/ecformer/layers/atom_encoder.py new file mode 100644 index 00000000..44d5c322 --- /dev/null +++ b/ppmat/models/ecformer/layers/atom_encoder.py @@ -0,0 +1,33 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import paddle.nn as nn


class AtomEncoder(nn.Layer):
    """Embed each discrete atom-feature column and sum the embeddings.

    One ``nn.Embedding`` table per feature column; the per-column
    embeddings are added into a single [N, emb_dim] representation.
    """

    def __init__(self, full_atom_feature_dims, emb_dim):
        super(AtomEncoder, self).__init__()
        tables = nn.LayerList()
        xavier = nn.initializer.XavierUniform()
        # "+ 5" enlarges each vocabulary by a few spare slots —
        # presumably headroom for out-of-range ids; TODO confirm.
        for feature_dim in full_atom_feature_dims:
            table = nn.Embedding(feature_dim + 5, emb_dim)
            xavier(table.weight)
            tables.append(table)
        self.atom_embedding_list = tables

    def forward(self, x):
        """x: [N, F] integer features -> [N, emb_dim] summed embeddings."""
        total = 0
        for col in range(x.shape[1]):
            total = total + self.atom_embedding_list[col](x[:, col])
        return total
import paddle.nn as nn


class BondEncoder(nn.Layer):
    """Embed each discrete bond-feature column and sum the embeddings.

    Mirrors ``AtomEncoder`` but operates on edge attributes.
    """

    def __init__(self, full_bond_feature_dims, emb_dim):
        super(BondEncoder, self).__init__()
        tables = nn.LayerList()
        xavier = nn.initializer.XavierUniform()
        # "+ 5" enlarges each vocabulary by a few spare slots —
        # presumably headroom for out-of-range ids; TODO confirm.
        for feature_dim in full_bond_feature_dims:
            table = nn.Embedding(feature_dim + 5, emb_dim)
            xavier(table.weight)
            tables.append(table)
        self.bond_embedding_list = tables

    def forward(self, edge_attr):
        """edge_attr: [E, F] integer features -> [E, emb_dim] embeddings."""
        total = 0
        for col in range(edge_attr.shape[1]):
            total = total + self.bond_embedding_list[col](edge_attr[:, col])
        return total
from ppmat.models.common.message_passing.message_passing import MessagePassing
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class GINConv(MessagePassing):
    """Graph Isomorphism convolution with edge features (aggr="add").

    out = MLP((1 + eps) * x + sum_{j in N(i)} relu(x_j + e_ij)).
    """

    def __init__(self, emb_dim):
        super(GINConv, self).__init__(aggr="add")

        # Two-layer MLP applied after neighborhood aggregation.
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.BatchNorm1D(emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim),
        )
        # Learnable epsilon, initialised to zero as in the GIN paper.
        self.eps = paddle.create_parameter(
            shape=[1],
            dtype=paddle.get_default_dtype(),
            default_initializer=nn.initializer.Assign(paddle.to_tensor([0.])),
        )

    def forward(self, x, edge_index, edge_attr):
        aggregated = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        return self.mlp((1 + self.eps) * x + aggregated)

    def message(self, x_j, edge_attr):
        # Edge feature is added to the source node before the nonlinearity.
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        # Identity update: the MLP is applied in forward() instead.
        return aggr_out
import paddle
import paddle.nn as nn
import numpy as np


class RBF(nn.Layer):
    """Fixed radial-basis expansion: exp(-gamma * (x - centers)^2).

    The centers/gamma tensors are stored detached from the layer's
    parameter list, so they are not trained.
    """

    # NOTE: the original torch-style annotations
    # (``paddle.nn.parameter.Parameter``) were evaluated eagerly at import
    # time and that module path is a torch idiom, not documented paddle
    # API — removed to avoid an import-time AttributeError.
    def __init__(self, centers, gamma):
        """
        Args:
            centers: 1-D parameter/tensor of RBF centers.
            gamma: scalar parameter/tensor controlling kernel width.
        """
        super(RBF, self).__init__()
        # NOTE(review): `.data` is a torch-style accessor; assumed to work
        # on paddle parameters here — confirm, otherwise use `.detach()`.
        self.centers = centers.data.reshape([1, -1])
        self.gamma = gamma.data

    def forward(self, x):
        """Expand x (any shape) into [numel(x), num_centers] features."""
        x = x.reshape([-1, 1])
        return paddle.exp(-self.gamma * paddle.square(x - self.centers))


class BondFloatRBF(nn.Layer):
    """RBF encoder for continuous bond features (e.g. bond length)."""

    def __init__(self, bond_float_names, embed_dim, rbf_params=None):
        """
        Args:
            bond_float_names: names of the continuous bond features, in
                the column order of the input tensor.
            embed_dim: output embedding width.
            rbf_params: optional {name: (centers, gamma)} overrides.
        """
        super(BondFloatRBF, self).__init__()
        self.bond_float_names = bond_float_names

        if rbf_params is None:
            self.rbf_params = self._default_rbf_params()
        else:
            self.rbf_params = rbf_params

        self.linear_list = nn.LayerList()
        self.rbf_list = nn.LayerList()
        for name in self.bond_float_names:
            centers, gamma = self.rbf_params[name]
            rbf = RBF(centers, gamma)
            self.rbf_list.append(rbf)
            # One projection per feature: num_centers -> embed_dim.
            linear = nn.Linear(len(centers), embed_dim)
            self.linear_list.append(linear)

    def _default_rbf_params(self):
        """Default grid: 20 centers on [0, 2) with gamma = 10.

        Presumably bond lengths in Angstrom — TODO confirm units.
        """
        centers = paddle.arange(0, 2, 0.1)
        gamma = paddle.to_tensor([10.0])
        return {
            'bond_length': (
                paddle.create_parameter(
                    shape=centers.shape,
                    dtype=centers.dtype,
                    default_initializer=paddle.nn.initializer.Assign(centers),
                ),
                paddle.create_parameter(
                    shape=gamma.shape,
                    dtype=gamma.dtype,
                    default_initializer=paddle.nn.initializer.Assign(gamma),
                ),
            ),
        }

    def forward(self, bond_float_features):
        """bond_float_features: [E, num_feats] -> [E, embed_dim]."""
        out_embed = 0
        for i, name in enumerate(self.bond_float_names):
            x = bond_float_features[:, i].reshape([-1, 1])
            rbf_x = self.rbf_list[i](x)
            out_embed += self.linear_list[i](rbf_x)
        return out_embed


class BondAngleFloatRBF(nn.Layer):
    """RBF encoder for continuous bond-angle features.

    Only the 'bond_angle' entry receives an RBF expansion; all remaining
    float features are projected jointly by a single linear layer (hence
    the ``break`` after it is built, mirrored in ``forward``).
    """

    def __init__(self, bond_angle_float_names, embed_dim, rbf_params=None):
        """
        Args:
            bond_angle_float_names: names of the continuous bond-angle
                features, in input column order.
            embed_dim: output embedding width.
            rbf_params: optional {name: (centers, gamma)} overrides.
        """
        super(BondAngleFloatRBF, self).__init__()
        self.bond_angle_float_names = bond_angle_float_names

        if rbf_params is None:
            # Angles live on [0, pi); 32 centers with gamma = 10.
            centers = paddle.arange(0, np.pi, 0.1)
            gamma = paddle.to_tensor([10.0])
            self.rbf_params = {
                'bond_angle': (
                    paddle.create_parameter(
                        shape=centers.shape,
                        dtype='float32',
                        default_initializer=paddle.nn.initializer.Assign(centers),
                    ),
                    paddle.create_parameter(
                        shape=gamma.shape,
                        dtype=gamma.dtype,
                        default_initializer=paddle.nn.initializer.Assign(gamma),
                    ),
                ),
            }
        else:
            self.rbf_params = rbf_params

        self.linear_list = nn.LayerList()
        self.rbf_list = nn.LayerList()

        for name in self.bond_angle_float_names:
            if name == 'bond_angle':
                centers, gamma = self.rbf_params[name]
                rbf = RBF(centers, gamma)
                self.rbf_list.append(rbf)
                linear = nn.Linear(len(centers), embed_dim)
                self.linear_list.append(linear)
            else:
                # All non-angle float features share one joint projection.
                linear = nn.Linear(len(self.bond_angle_float_names) - 1, embed_dim)
                self.linear_list.append(linear)
                break

    def forward(self, bond_angle_float_features):
        """bond_angle_float_features: [E, num_feats] -> [E, embed_dim]."""
        out_embed = 0
        for i, name in enumerate(self.bond_angle_float_names):
            if name == 'bond_angle':
                x = bond_angle_float_features[:, i].reshape([-1, 1])
                rbf_x = self.rbf_list[i](x)
                out_embed += self.linear_list[i](rbf_x)
            else:
                # Columns 1: are projected together, then iteration stops.
                x = bond_angle_float_features[:, 1:]
                out_embed += self.linear_list[i](x)
                break
        return out_embed
import paddle.nn as nn

from .base_ecformer import ECFormerBase


class ECFormerECD(ECFormerBase):
    """ECFormer head for ECD spectra — decoupled peak-attribute prediction.

    Adds three prediction heads on top of the shared backbone: peak
    count, peak position (classification) and peak sign.
    """

    def __init__(
        self,
        num_position_classes = 20,
        height_classes = 2,
        **kwargs
    ):
        super().__init__(**kwargs)

        dim = self.emb_dim

        # Peak-count head.
        # NOTE(review): outputs `max_peaks` logits while the IR variant
        # uses `max_peaks + 1` (counts 0..max_peaks) — confirm intentional.
        self.pred_number_layer = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.ReLU(),
            nn.Linear(dim * 2, self.max_peaks),
        )

        # Peak-position head (classification over position bins).
        self.pred_position_layer = nn.Sequential(
            nn.Linear(dim, dim // 4),
            nn.ReLU(),
            nn.Linear(dim // 4, num_position_classes),
        )

        # Peak-sign head (height_classes classes).
        self.pred_height_layer = nn.Sequential(
            nn.Linear(dim, dim // 4),
            nn.ReLU(),
            nn.Linear(dim // 4, height_classes),
        )
import paddle.nn as nn

from .base_ecformer import ECFormerBase


class ECFormerIR(ECFormerBase):
    """ECFormer head for IR spectra — peak-sequence regression variant."""

    def __init__(
        self,
        num_position_classes=36,
        use_height_prediction=True,
        **kwargs
    ):
        # IR spectra carry more peaks than ECD; default max_peaks to 15
        # unless the caller overrides it.
        kwargs['max_peaks'] = kwargs.get('max_peaks', 15)

        super().__init__(**kwargs)

        dim = self.emb_dim

        # Peak-count head: classes 0..max_peaks inclusive.
        self.pred_number_layer = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.ReLU(),
            nn.Linear(dim * 2, self.max_peaks + 1),
        )

        # Peak-position head over the IR position bins.
        self.pred_position_layer = nn.Sequential(
            nn.Linear(dim, dim // 4),
            nn.ReLU(),
            nn.Linear(dim // 4, num_position_classes),
        )

        # Scalar intensity regression per peak query.
        # NOTE(review): when disabled, `pred_height_layer` is never
        # defined — the base-class forward must tolerate its absence.
        if use_height_prediction:
            self.pred_height_layer = nn.Sequential(
                nn.Linear(dim, dim // 4),
                nn.ReLU(),
                nn.Linear(dim // 4, 1),
            )
+# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .ECD import ECFormerECD +from .IR import ECFormerIR \ No newline at end of file diff --git a/ppmat/models/ecformer/models/base_ecformer.py b/ppmat/models/ecformer/models/base_ecformer.py new file mode 100644 index 00000000..e913d6d1 --- /dev/null +++ b/ppmat/models/ecformer/models/base_ecformer.py @@ -0,0 +1,249 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from abc import ABC
import paddle
import paddle.nn as nn
from paddle.nn import TransformerEncoder, TransformerEncoderLayer

from ..encoders.gin_node_embedding import GINNodeEmbedding
from ppmat.utils.graph_utils import pad_node_features, feat_padding_mask
from paddle_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set


def fix_mask_for_paddle(mask, n_head=None):
    """Expand a 2-D key-padding mask into the 4-D attention-mask shape.

    Args:
        mask: [batch_size, src_len] additive mask.
        n_head: number of attention heads; when given, the mask is
            broadcast across heads explicitly.

    Returns:
        [batch_size, n_head or 1, src_len, src_len] tensor.
    """
    shape = mask.shape
    assert len(shape) == 2
    batch_size, s_len = shape
    # [B, L] -> [B, H, L, L] (or [B, 1, L, L] without n_head).
    if n_head:
        return mask.reshape([batch_size, 1, 1, s_len]).expand([-1, n_head, s_len, -1])
    return mask.unsqueeze(1).unsqueeze(2).expand([-1, -1, s_len, -1])


class ECFormerBase(nn.Layer, ABC):
    """Shared ECFormer backbone: GNN encoder + peak-query Transformer.

    Subclasses must define `pred_number_layer` and `pred_position_layer`
    (and optionally `pred_height_layer`) — see ECFormerECD / ECFormerIR.
    """

    def __init__(
        self,
        # GNN parameters
        full_atom_feature_dims,
        full_bond_feature_dims,
        bond_float_names,
        bond_angle_float_names,
        bond_id_names,
        num_layers=5,
        emb_dim=128,
        drop_ratio=0.0,
        JK="last",
        residual=False,
        graph_pooling="attention",
        use_geometry_enhanced=True,
        max_node_num=63,
        # Transformer parameters
        num_heads=4,
        tf_layers=2,
        tf_dropout=0.1,
        max_peaks=9,
    ):
        super().__init__()

        self.emb_dim = emb_dim
        self.max_node_num = max_node_num
        self.max_peaks = max_peaks
        self.use_geometry_enhanced = use_geometry_enhanced

        # 1. GNN node encoder.
        self.gnn_node = GINNodeEmbedding(
            full_atom_feature_dims=full_atom_feature_dims,
            full_bond_feature_dims=full_bond_feature_dims,
            bond_float_names=bond_float_names,
            bond_angle_float_names=bond_angle_float_names,
            bond_id_names=bond_id_names,
            num_layers=num_layers,
            emb_dim=emb_dim,
            drop_ratio=drop_ratio,
            JK=JK,
            residual=residual,
            use_geometry_enhanced=use_geometry_enhanced
        )

        # 2. Graph pooling layer.
        self.pool = self._build_pooling(graph_pooling, emb_dim)

        # 3. Learnable peak-query embeddings (one per potential peak).
        self.query_embed = nn.Embedding(max_peaks, emb_dim)

        # 4. Transformer encoder over [graph token | nodes | queries].
        self.tf_encoder = self._build_transformer(emb_dim, num_heads, tf_layers, tf_dropout)

    def _build_pooling(self, graph_pooling, emb_dim):
        """Build the graph-level pooling operator."""
        if graph_pooling == "sum":
            return global_add_pool
        elif graph_pooling == "mean":
            return global_mean_pool
        elif graph_pooling == "max":
            return global_max_pool
        elif graph_pooling == "attention":
            return GlobalAttention(
                gate_nn=nn.Sequential(
                    nn.Linear(emb_dim, emb_dim),
                    nn.BatchNorm1D(emb_dim),
                    nn.ReLU(),
                    nn.Linear(emb_dim, 1)
                )
            )
        elif graph_pooling == "set2set":
            return Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError(f"Invalid graph pooling type: {graph_pooling}")

    def _build_transformer(self, emb_dim, num_heads, num_layers, dropout):
        """Build the Transformer encoder stack."""
        assert emb_dim % num_heads == 0, "emb_dim must be divisible by num_heads"

        self.tf_enc_layer = TransformerEncoderLayer(
            d_model=emb_dim,
            nhead=num_heads,
            dim_feedforward=emb_dim,
            dropout=dropout,
            activation='relu',
        )
        return TransformerEncoder(self.tf_enc_layer, num_layers=num_layers)

    def encode_molecule(
        self,
        x,                   # [N, F] atom features
        edge_index,          # [2, E] edge indices
        edge_attr,           # [E, D] edge features
        batch_data,          # [N] graph id per node
        # Geometry enhancement related
        ba_edge_index=None,  # [2, E_ba] bond-angle graph edge indices
        ba_edge_attr=None,   # [E_ba, D_ba] bond-angle graph edge features
    ):
        """Encode a molecule batch into padded token features.

        Returns:
            (total_node_feat, total_padding_mask, node_padding_mask) where
            slot 0 of total_node_feat is the pooled graph token.
        """
        # 1. GNN encoding (edge representations are discarded here).
        if self.use_geometry_enhanced and ba_edge_index is not None:
            h_node, _ = self.gnn_node(
                x=x,
                edge_index=edge_index,
                edge_attr=edge_attr,
                ba_edge_index=ba_edge_index,
                ba_edge_attr=ba_edge_attr
            )
        else:
            h_node = self.gnn_node(
                x=x,
                edge_index=edge_index,
                edge_attr=edge_attr,
            )

        # 2. Pad per-graph node features to a fixed max_node_num.
        # batch_data is assumed sorted so its last entry is num_graphs - 1.
        batch_size = batch_data[-1] + 1
        node_feat, node_index = pad_node_features(
            h_node, batch_data, batch_size, self.max_node_num, self.emb_dim
        )

        # 3. Graph-level pooling -> one extra token per graph.
        h_graph = self.pool(h_node, batch_data).unsqueeze(1)

        # 4. Prepend the pooled graph token to the node tokens.
        total_node_feat = paddle.concat([h_graph, node_feat], axis=1)

        # 5. Padding mask; the pooled token is never masked (zeros).
        node_padding_mask = feat_padding_mask(node_index, self.max_node_num)
        pooling_padding_mask = paddle.zeros([node_padding_mask.shape[0], 1], dtype=paddle.get_default_dtype())
        total_padding_mask = paddle.concat([pooling_padding_mask, node_padding_mask], axis=1)

        return total_node_feat, total_padding_mask, node_padding_mask

    def forward(self,
                x: paddle.Tensor,
                edge_index: paddle.Tensor,
                edge_attr: paddle.Tensor,
                batch_data: paddle.Tensor,
                ba_edge_index: paddle.Tensor = None,
                ba_edge_attr: paddle.Tensor = None,
                query_mask: paddle.Tensor = None):
        """Predict peak number / positions / heights for a molecule batch.

        NOTE(review): during training `query_mask` must be supplied by the
        caller (it is only built internally at inference) — confirm all
        training callers pass it, else the concat below fails on None.
        """
        # 0. Graph scatter indices must be int64.
        if batch_data.dtype != paddle.int64:
            batch_data = batch_data.astype(paddle.int64)

        # 1. Molecule encoding.
        node_feat, padding_mask, node_padding_mask = self.encode_molecule(
            x, edge_index, edge_attr, batch_data, ba_edge_index, ba_edge_attr
        )

        # 2. Peak-count prediction from the pooled graph token (slot 0).
        graph_feat = node_feat[:, 0, :]
        pred_number = self.pred_number_layer(graph_feat)

        # 3. Peak queries; at inference, mask the queries beyond the
        # predicted peak count.
        query_feat = self.query_embed.weight.unsqueeze(0).tile([node_feat.shape[0], 1, 1])
        if not self.training:
            pred_peak_num = pred_number.argmax(axis=1)
            peak_position = [
                [1] * int(pred_peak_num[i]) + [-1] * (self.max_peaks - int(pred_peak_num[i]))
                for i in range(len(pred_peak_num))
            ]
            peak_position = paddle.to_tensor(peak_position)
            query_mask = get_key_padding_mask(peak_position)

        # 4. Joint Transformer encoding of node tokens + query tokens.
        encoder_input = paddle.concat([node_feat, query_feat], axis=1)
        encoder_padding_mask = paddle.concat([padding_mask, query_mask], axis=1)
        encoder_output = self.tf_encoder(encoder_input, fix_mask_for_paddle(encoder_padding_mask))

        # 5. Per-query peak attribute heads.
        query_output = encoder_output[:, node_feat.shape[1]:, :]
        pred_position = self.pred_position_layer(query_output)
        # The height head is optional (e.g. ECFormerIR with
        # use_height_prediction=False never defines it); previously this
        # raised AttributeError — emit None instead.
        height_layer = getattr(self, 'pred_height_layer', None)
        pred_height = height_layer(query_output) if height_layer is not None else None

        # 6. Node-vs-first-query attention scores for visualization.
        node_feat_output = encoder_output[:, :node_feat.shape[1], :]
        attn_weights = paddle.einsum(
            "bid,bjd->bij",
            node_feat_output,
            query_output[:, 0, :].unsqueeze(1),
        )
        # Drop the pooled-graph token. NOTE(review): `.squeeze()` also
        # removes the batch dim when batch_size == 1 — confirm downstream
        # consumers handle that.
        attn_weights = attn_weights[:, 1:, :].squeeze()
        attn_mask = node_padding_mask[:, 1:]

        return {
            'peak_number': pred_number,
            'peak_position': pred_position,
            'peak_height': pred_height,
            'attention': {
                'weights': attn_weights.cpu().tolist() if not self.training else None,
                'mask': attn_mask.cpu().tolist() if not self.training else None
            }
        }


def get_key_padding_mask(tokens):
    """Additive mask: 0.0 where a token is valid, -inf where token == -1."""
    key_padding_mask = paddle.zeros(tokens.shape)
    # float('-inf') avoids depending on the `paddle.inf` alias.
    key_padding_mask[tokens == -1] = float('-inf')
    return key_padding_mask
import ECFormerTrainer -__all__ = ["BaseTrainer", "build_trainer"] +__all__ = ["BaseTrainer", "build_trainer", "ECFormerTrainer"] def build_trainer(cfg, **kwargs): diff --git a/ppmat/trainer/ecformer_trainer.py b/ppmat/trainer/ecformer_trainer.py new file mode 100644 index 00000000..142a006c --- /dev/null +++ b/ppmat/trainer/ecformer_trainer.py @@ -0,0 +1,563 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +from collections import OrderedDict +from typing import Dict, Optional, List, Any, Union + +import numpy as np +import paddle +from paddle import nn +from paddle import optimizer as optim +from paddle.distributed import fleet +from tqdm import tqdm + +from ppmat.trainer.base_trainer import BaseTrainer +from ppmat.utils import logger +from ppmat.utils import AverageMeter +from ppmat.utils import save_load +from ppmat.metrics.ecd_metric import ECDMetrics +from ppmat.metrics.ir_metric import IRMetrics +from ppmat.losses.ecd_loss import ECDLoss +from ppmat.losses.ir_loss import IRLoss + + +class ECFormerTrainer(BaseTrainer): + """ + ECFormer trainer supporting both ECD and IR tasks with dedicated metrics. 
+ + Features: + - Automatic task detection from model class name + - Task-specific loss functions (ECDLoss for classification, IRLoss for regression) + - Task-specific streaming metrics (ECDMetrics, IRMetrics) + - Attention visualization during inference + - Compatible with BaseTrainer training loop + """ + + def __init__( + self, + config: Dict, + model: nn.Layer, + train_dataloader: Optional[paddle.io.DataLoader] = None, + val_dataloader: Optional[paddle.io.DataLoader] = None, + optimizer: Optional[optim.Optimizer] = None, + lr_scheduler: Optional[optim.lr.LRScheduler] = None, + compute_metric_func_dict: Optional[Dict] = None, + ): + # Initialize parent class + super().__init__( + config=config, + model=model, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + compute_metric_func_dict=compute_metric_func_dict, + ) + + # Task detection from model class name + model_class_name = model.__class__.__name__ + self.is_ir_task = "IR" in model_class_name + self.is_ecd_task = "ECD" in model_class_name + + logger.info(f"Task type detected: {'IR' if self.is_ir_task else 'ECD' if self.is_ecd_task else 'Unknown'}") + + # Get task-specific parameters from config + self.max_peaks = config.get("max_peaks", 15 if self.is_ir_task else 9) + self.num_position_classes = config.get("num_position_classes", 36 if self.is_ir_task else 20) + + # Initialize task-specific loss function + if self.is_ecd_task: + self.loss_fn = ECDLoss( + loss_weight_height=config.get("loss_weight_height", 2.0), + num_position_classes=self.num_position_classes, + height_classes=config.get("height_classes", 2) + ) + logger.info("Using ECDLoss for ECD task") + elif self.is_ir_task: + self.loss_fn = IRLoss( + num_position_classes=self.num_position_classes, + use_height_prediction=config.get("use_height_prediction", True) + ) + logger.info("Using IRLoss for IR task") + else: + # Fallback to simple cross-entropy + self.ce_loss = 
nn.CrossEntropyLoss() + logger.warning("Unknown task type, using fallback CrossEntropyLoss") + + # Initialize task-specific metrics (will be attached via attach_metrics) + self.train_metrics = None + self.eval_metrics = None + + # Cache for dataset building to avoid repeated decompression + self._dataset_cache = {} + + def attach_metrics(self, metric_cfg=None, **runtime_objs): + """ + Attach task-specific metrics to the trainer. + + Args: + metric_cfg: Metric configuration from config file + **runtime_objs: Additional runtime objects + """ + super().attach_metrics(metric_cfg, **runtime_objs) + + # Create task-specific metric instances if not already in metric_modules + if self.is_ecd_task and 'ECDMetrics' not in str(self.metric_modules): + self.metric_modules['ecd_metrics'] = ECDMetrics( + num_position_classes=self.num_position_classes, + max_peaks=self.max_peaks + ) + logger.info("ECDMetrics attached") + elif self.is_ir_task and 'IRMetrics' not in str(self.metric_modules): + self.metric_modules['ir_metrics'] = IRMetrics( + use_height_prediction=config.get("use_height_prediction", True) + ) + logger.info("IRMetrics attached") + + def train_epoch(self, dataloader: paddle.io.DataLoader): + """ + Train for one epoch using task-specific loss functions. 
+ + Args: + dataloader: Training data loader + + Returns: + tuple: time_info, loss_info, metric_info + """ + self.model.train() + + # Initialize statistics + loss_info = {} + metric_info = {} + time_info = { + "reader_cost": AverageMeter(name="reader_cost", postfix="s"), + "batch_cost": AverageMeter(name="batch_cost", postfix="s"), + } + + # Update training state + self.state.max_steps_in_train_epoch = len(dataloader) + self.state.step_in_train_epoch = 0 + + # Determine if this is the main process for progress bar + is_main_process = paddle.distributed.get_rank() == 0 if paddle.distributed.is_initialized() else True + + # Create progress bar only on main process + if is_main_process: + pbar = tqdm( + total=len(dataloader), + desc=f"Epoch {self.state.epoch}/{self.max_epochs}", + unit="batch", + ncols=100, + leave=True + ) + + # Timers + reader_tic = time.perf_counter() + batch_tic = time.perf_counter() + + for iter_id, batch in enumerate(dataloader): + # Parse batch data (adapting to ECDCollator/IRCollator format) + model_inputs, targets = batch + + reader_cost = time.perf_counter() - reader_tic + time_info["reader_cost"].update(reader_cost) + + # Calculate batch size + batch_size = model_inputs['x'].shape[0] if hasattr(model_inputs['x'], 'shape') else 1 + + # Forward pass + with self.autocast_context_manager(self.use_amp, self.amp_level): + predictions = self.model( + x=model_inputs['x'], + edge_index=model_inputs['edge_index'], + edge_attr=model_inputs['edge_attr'], + batch_data=model_inputs['batch_data'], + ba_edge_index=model_inputs.get('ba_edge_index', None), + ba_edge_attr=model_inputs.get('ba_edge_attr', None), + query_mask=model_inputs.get('query_mask', None) + ) + + # Compute loss using task-specific loss function + loss_dict = self.loss_fn(predictions, targets) + loss = loss_dict["loss"] + + # Backward pass + if self.use_amp: + loss_scaled = self.scaler.scale(loss) + loss_scaled.backward() + else: + loss.backward() + + # Update parameters + if 
self.use_amp: + self.scaler.minimize(self.optimizer, loss_scaled) + else: + self.optimizer.step() + self.optimizer.clear_grad() + + # Update loss statistics + for key, value in loss_dict.items(): + if key not in loss_info: + loss_info[key] = AverageMeter(key) + loss_info[key].update(float(value), batch_size) + + # Update streaming metrics + self._update_streaming_metrics(result={'predictions': predictions, 'loss_dict': loss_dict}, + batch=targets, stage='train') + + batch_cost = time.perf_counter() - batch_tic + time_info["batch_cost"].update(batch_cost) + + # Update state + self.state.step_in_train_epoch += 1 + self.state.global_step += 1 + + # Update learning rate (step-based) + if self.lr_scheduler is not None and not self.lr_scheduler.by_epoch: + self.lr_scheduler.step() + + # Update progress bar on main process + if is_main_process: + # Prepare current metrics for display + current_metrics = {} + current_metrics["lr"] = f"{self.optimizer.get_lr():.2e}" + for name, meter in loss_info.items(): + # Show only the most important metrics + if "loss" == name.lower() or "acc" in name.lower(): + current_metrics[name] = f"{meter.val:.4f}" + + # Add streaming metrics if available (only show a few key metrics to avoid clutter) + stream_metrics = self._compute_streaming_metrics(stage='train') + for name, value in stream_metrics.items(): + if isinstance(value, (int, float)): + # Show only the most important metrics + if "loss" == name.lower() or "acc" in name.lower(): + short_name = name.split('/')[-1] if '/' in name else name + current_metrics[short_name] = f"{value:.4f}" + + # Update progress bar postfix + pbar.set_postfix(current_metrics, refresh=False) + pbar.update(1) + + # Write to visualization tools (every log_freq steps) + if self.state.step_in_train_epoch % self.log_freq == 0: + logs = OrderedDict() + logs["lr"] = self.optimizer.get_lr() + for name, meter in time_info.items(): + logs[name] = meter.val + for name, meter in loss_info.items(): + logs[name] = 
meter.val + + # Add streaming metrics + stream_metrics = self._compute_streaming_metrics(stage='train') + for name, value in stream_metrics.items(): + if isinstance(value, (int, float)): + logs[f"{name}"] = value + if name not in metric_info: + metric_info[name] = AverageMeter(name) + metric_info[name].update(float(value), 1) + + # Write to visualization tools (not to console) + logger.scalar( + tag="train(step)", + metric_dict=logs, + step=self.state.global_step, + visualdl_writer=self.visualdl_writer, + wandb_writer=self.wandb_writer, + tensorboard_writer=self.tensorboard_writer, + ) + + batch_tic = time.perf_counter() + reader_tic = time.perf_counter() + + # Close progress bar + if is_main_process: + pbar.close() + + # Compute epoch-level streaming metrics + epoch_stream_metrics = self._compute_streaming_metrics(stage='train') + for name, value in epoch_stream_metrics.items(): + if isinstance(value, (int, float)): + if name not in metric_info: + metric_info[name] = AverageMeter(name) + metric_info[name].update(float(value), 1) + + # Log epoch summary to file (not to console) + logger.info(f"Epoch {self.state.epoch} completed. Avg Loss: {loss_info.get('loss', AverageMeter('loss')).avg:.4f}") + + return time_info, loss_info, metric_info + + def eval_epoch(self, dataloader: paddle.io.DataLoader): + """ + Evaluate for one epoch using task-specific metrics. 
+ + Args: + dataloader: Validation data loader + + Returns: + tuple: time_info, loss_info, metric_info + """ + self.model.eval() + + loss_info = {} + metric_info = {} + time_info = { + "reader_cost": AverageMeter(name="reader_cost", postfix="s"), + "batch_cost": AverageMeter(name="batch_cost", postfix="s"), + } + + self.state.max_steps_in_eval_epoch = len(dataloader) + self.state.step_in_eval_epoch = 0 + + # Reset streaming metrics for evaluation + for _, m in self.metric_modules.items(): + if hasattr(m, 'reset'): + m.reset() + + # Determine if this is the main process for progress bar + is_main_process = paddle.distributed.get_rank() == 0 if paddle.distributed.is_initialized() else True + + # Create progress bar only on main process + if is_main_process: + pbar = tqdm( + total=len(dataloader), + desc=f"Eval Epoch {self.state.epoch}/{self.max_epochs}", + unit="batch", + ncols=80, + leave=False + ) + + reader_tic = time.perf_counter() + batch_tic = time.perf_counter() + + with paddle.no_grad(): + for iter_id, batch in enumerate(dataloader): + model_inputs, targets = batch + + reader_cost = time.perf_counter() - reader_tic + time_info["reader_cost"].update(reader_cost) + + batch_size = model_inputs['x'].shape[0] if hasattr(model_inputs['x'], 'shape') else 1 + + # Forward pass + with self.autocast_context_manager(self.use_amp, self.amp_level): + predictions = self.model( + x=model_inputs['x'], + edge_index=model_inputs['edge_index'], + edge_attr=model_inputs['edge_attr'], + batch_data=model_inputs['batch_data'], + ba_edge_index=model_inputs.get('ba_edge_index', None), + ba_edge_attr=model_inputs.get('ba_edge_attr', None), + query_mask=model_inputs.get('query_mask', None) + ) + + # Compute loss + loss_dict = self.loss_fn(predictions, targets) + + # Update loss statistics + for key, value in loss_dict.items(): + if key not in loss_info: + loss_info[key] = AverageMeter(key) + loss_info[key].update(float(value), batch_size) + + # Update streaming metrics + 
self._update_streaming_metrics(result={'predictions': predictions, 'loss_dict': loss_dict}, + batch=targets, stage='eval') + + # Step-wise metric computation (if configured) + if self.metric_strategy_during_eval == "step": + step_metrics = self._compute_streaming_metrics(stage='eval') + for name, value in step_metrics.items(): + if isinstance(value, (int, float)): + if name not in metric_info: + metric_info[name] = AverageMeter(name) + metric_info[name].update(float(value), batch_size) + + batch_cost = time.perf_counter() - batch_tic + time_info["batch_cost"].update(batch_cost) + + self.state.step_in_eval_epoch += 1 + + # Update progress bar on main process + if is_main_process: + current_metrics = {} + for name, meter in loss_info.items(): + # Show only the most important metrics + if "loss" == name.lower() or "acc" in name.lower(): + current_metrics[name] = f"{meter.val:.4f}" + pbar.set_postfix(current_metrics, refresh=False) + pbar.update(1) + + batch_tic = time.perf_counter() + reader_tic = time.perf_counter() + + # Close progress bar + if is_main_process: + pbar.close() + + # Compute epoch-level metrics from streaming accumulators + epoch_metrics = self._compute_streaming_metrics(stage='eval') + for name, value in epoch_metrics.items(): + if isinstance(value, (int, float)): + if name not in metric_info: + metric_info[name] = AverageMeter(name) + metric_info[name].update(float(value), len(dataloader.dataset)) + + # Log evaluation summary to file (not to console) + logger.info(f"Eval Epoch {self.state.epoch} completed. Avg Loss: {loss_info.get('loss', AverageMeter('loss')).avg:.4f}") + + return time_info, loss_info, metric_info + + def predict(self, dataloader: paddle.io.DataLoader) -> Dict[str, Any]: + """ + Run inference and return predictions with attention visualization. 
+ + Args: + dataloader: Data loader for prediction + + Returns: + dict: Predictions including peak positions, heights, and attention weights + """ + self.model.eval() + + all_pos_pred = [] + all_height_pred = [] + all_attn_weights = [] + all_peak_nums = [] + + # Determine if this is the main process for progress bar + is_main_process = paddle.distributed.get_rank() == 0 if paddle.distributed.is_initialized() else True + + # Create progress bar only on main process + if is_main_process: + pbar = tqdm( + total=len(dataloader), + desc="Predicting", + unit="batch", + ncols=80, + leave=True + ) + + with paddle.no_grad(): + for batch in pbar if is_main_process else dataloader: + model_inputs, _ = batch # No targets needed for inference + + predictions = self.model( + x=model_inputs['x'], + edge_index=model_inputs['edge_index'], + edge_attr=model_inputs['edge_attr'], + batch_data=model_inputs['batch_data'], + ba_edge_index=model_inputs.get('ba_edge_index', None), + ba_edge_attr=model_inputs.get('ba_edge_attr', None), + query_mask=model_inputs.get('query_mask', None) + ) + + # Get predicted peak numbers + prob_num = paddle.nn.functional.softmax(predictions['peak_number'], axis=1) + pred_peak_num = paddle.argmax(prob_num, axis=1) + all_peak_nums.extend(pred_peak_num.cpu().numpy().tolist()) + + for i in range(pred_peak_num.shape[0]): + n_pred = int(pred_peak_num[i]) + + # Position predictions + pos_pred = paddle.argmax( + predictions['peak_position'][i, :n_pred, :], axis=1 + ).cpu().numpy().tolist() + + # Height predictions (classification or regression) + if 'peak_height' in predictions: + if len(predictions['peak_height'].shape) == 3: # Classification (ECD) + height_pred = paddle.argmax( + predictions['peak_height'][i, :n_pred, :], axis=1 + ).cpu().numpy().tolist() + else: # Regression (IR) + height_pred = predictions['peak_height'][i, :n_pred].reshape([-1]).cpu().numpy().tolist() + else: + height_pred = [] + + all_pos_pred.append(pos_pred) + 
all_height_pred.append(height_pred) + + # Attention weights for visualization + if predictions.get('attention', {}).get('weights'): + all_attn_weights.append({ + 'weights': predictions['attention']['weights'][i], + 'mask': predictions['attention']['mask'][i] if predictions['attention']['mask'] else None + }) + + # Close progress bar + if is_main_process: + pbar.close() + + return { + 'peak_number': all_peak_nums, + 'peak_position': all_pos_pred, + 'peak_height': all_height_pred, + 'attention': all_attn_weights if all_attn_weights else None + } + + def _update_streaming_metrics(self, *, result, batch, stage: str): + """ + Update streaming metrics with predictions and targets. + + Args: + result: dict containing 'predictions' from model + batch: target batch + stage: 'train' or 'eval' + """ + predictions = result.get('predictions', {}) + + for name, metric in self.metric_modules.items(): + if hasattr(metric, 'update'): + try: + metric.update(predictions, batch) + except Exception as e: + logger.debug(f"Error updating metric {name}: {e}") + + def _compute_streaming_metrics(self, *, stage: str) -> Dict[str, float]: + """ + Compute and reset streaming metrics. 
+ + Args: + stage: 'train' or 'eval' + + Returns: + dict: Computed metrics + """ + all_metrics = {} + + for name, metric in self.metric_modules.items(): + if hasattr(metric, 'accumulate'): + try: + metrics = metric.accumulate() + if isinstance(metrics, dict): + # Add prefix for clarity + for k, v in metrics.items(): + all_metrics[f"{name}/{k}"] = v + else: + all_metrics[name] = metrics + except Exception as e: + logger.debug(f"Error computing metric {name}: {e}") + + if hasattr(metric, 'reset'): + try: + metric.reset() + except Exception: + pass + + return all_metrics \ No newline at end of file diff --git a/ppmat/utils/__init__.py b/ppmat/utils/__init__.py index 8b5fb924..913a51b1 100644 --- a/ppmat/utils/__init__.py +++ b/ppmat/utils/__init__.py @@ -22,6 +22,8 @@ from ppmat.utils.save_load import load_checkpoint from ppmat.utils.save_load import load_pretrain from ppmat.utils.save_load import save_checkpoint +from ppmat.utils.place_env import PlaceEnv +from ppmat.utils.colored_tqdm import ColoredTqdm __all__ = [ logger, @@ -33,4 +35,6 @@ load_checkpoint, load_pretrain, save_checkpoint, + PlaceEnv, + ColoredTqdm ] diff --git a/ppmat/utils/colored_tqdm.py b/ppmat/utils/colored_tqdm.py new file mode 100644 index 00000000..f4762c4f --- /dev/null +++ b/ppmat/utils/colored_tqdm.py @@ -0,0 +1,80 @@ +from tqdm import tqdm +import time +import os;os.system("") # Compatible with Windows + +def hex_to_ansi(hex_color: str, background: bool = False) -> str: + """ + Convert hexadecimal color to ANSI escape sequence + + Args: + hex_color: Hexadecimal color, e.g., '#dda0a0' or 'dda0a0' + background: True for background color, False for foreground color + + Returns: + ANSI escape sequence string, e.g., '\033[38;2;221;160;160m' + + Example: + >>> print(f"{hex_to_ansi('#dda0a0')}Hello{hex_to_ansi('#000000')} World") + >>> print(f"{hex_to_ansi('dda0a0', background=True)}Background color{hex_to_ansi.reset()}") + """ + # Remove # symbol and convert to lowercase + hex_color = 
hex_color.lower().lstrip('#') + + # Handle shorthand form (#fff -> ffffff) + if len(hex_color) == 3: + hex_color = ''.join([c * 2 for c in hex_color]) + + # Convert to RGB values + r = int(hex_color[0:2], 16) + g = int(hex_color[2:4], 16) + b = int(hex_color[4:6], 16) + + # ANSI true color sequence + # 38;2;R;G;B for foreground, 48;2;R;G;B for background + code = 48 if background else 38 + return f'\033[{code};2;{r};{g};{b}m' + +def rgb_to_ansi(r: int, g: int, b: int, background: bool = False) -> str: + """Convert RGB values directly to ANSI""" + code = 48 if background else 38 + return f'\033[{code};2;{r};{g};{b}m' + +# ANSI code to reset color +hex_to_ansi.reset = '\033[0m' + +class ColoredTqdm(tqdm): + def __init__(self, *args, + start_color=(221, 160, 160), # RGB: #DDA0A0 + end_color=(160, 221, 160), # RGB: #A0DDA0 + **kwargs): + super().__init__(*args, **kwargs) + self.start_color = start_color + self.end_color = end_color + + def get_current_color(self): + + if self.total is None: + return "#FFFFFF" + + progress = self.n / self.total if self.total > 0 else 0 + current_rgb = tuple( + int(start + (end - start) * progress) + for start, end in zip(self.start_color, self.end_color) + ) + result = current_rgb[0] * 16 ** 4 \ + + current_rgb[1] * 16 ** 2 \ + + current_rgb[2] * 16 ** 0 + return "%06x" % result + + def update(self, n=1): + super().update(n) + style = hex_to_ansi(self.get_current_color()) + self.bar_format = f'{{l_bar}}{style}{{bar}}{hex_to_ansi.reset}{{r_bar}}' + self.refresh() + + +if __name__ == "__main__": + # Usage example + for i in ColoredTqdm(range(10), desc="🌈 Rainbow gradient", leave = False): + for j in ColoredTqdm(range(100), desc="🌈 Rainbow gradient", leave = False): + time.sleep(0.01) \ No newline at end of file diff --git a/ppmat/utils/compound_tools.py b/ppmat/utils/compound_tools.py new file mode 100644 index 00000000..03461f0b --- /dev/null +++ b/ppmat/utils/compound_tools.py @@ -0,0 +1,846 @@ +import numpy as np +from rdkit import 
Chem +from rdkit.Chem import AllChem +from rdkit.Chem import rdchem + +DAY_LIGHT_FG_SMARTS_LIST = [ + # C + "[CX4]", + "[$([CX2](=C)=C)]", + "[$([CX3]=[CX3])]", + "[$([CX2]#C)]", + # C & O + "[CX3]=[OX1]", + "[$([CX3]=[OX1]),$([CX3+]-[OX1-])]", + "[CX3](=[OX1])C", + "[OX1]=CN", + "[CX3](=[OX1])O", + "[CX3](=[OX1])[F,Cl,Br,I]", + "[CX3H1](=O)[#6]", + "[CX3](=[OX1])[OX2][CX3](=[OX1])", + "[NX3][CX3](=[OX1])[#6]", + "[NX3][CX3]=[NX3+]", + "[NX3,NX4+][CX3](=[OX1])[OX2,OX1-]", + "[NX3][CX3](=[OX1])[OX2H0]", + "[NX3,NX4+][CX3](=[OX1])[OX2H,OX1-]", + "[CX3](=O)[O-]", + "[CX3](=[OX1])(O)O", + "[CX3](=[OX1])([OX2])[OX2H,OX1H0-1]", + "C[OX2][CX3](=[OX1])[OX2]C", + "[CX3](=O)[OX2H1]", + "[CX3](=O)[OX1H0-,OX2H1]", + "[NX3][CX2]#[NX1]", + "[#6][CX3](=O)[OX2H0][#6]", + "[#6][CX3](=O)[#6]", + "[OD2]([#6])[#6]", + # H + "[H]", + "[!#1]", + "[H+]", + "[+H]", + "[!H]", + # N + "[NX3;H2,H1;!$(NC=O)]", + "[NX3][CX3]=[CX3]", + "[NX3;H2;!$(NC=[!#6]);!$(NC#[!#6])][#6]", + "[NX3;H2,H1;!$(NC=O)].[NX3;H2,H1;!$(NC=O)]", + "[NX3][$(C=C),$(cc)]", + "[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[O,N]", + "[NX3H2,NH3X4+][CX4H]([*])[CX3](=[OX1])[NX3,NX4+][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-]", + "[$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H]([*])[CX3](=[OX1])[OX2H,OX1-,N]", + "[CH3X4]", + "[CH2X4][CH2X4][CH2X4][NHX3][CH0X3](=[NH2X3+,NHX2+0])[NH2X3]", + "[CH2X4][CX3](=[OX1])[NX3H2]", + "[CH2X4][CX3](=[OX1])[OH0-,OH]", + "[CH2X4][SX2H,SX1H0-]", + "[CH2X4][CH2X4][CX3](=[OX1])[OH0-,OH]", + "[$([$([NX3H2,NX4H3+]),$([NX3H](C)(C))][CX4H2][CX3](=[OX1])[OX2H,OX1-,N])]", + "[CH2X4][#6X3]1:[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]:\ +[$([#7X3H+,#7X2H0+0]:[#6X3H]:[#7X3H]),$([#7X3H])]:[#6X3H]1", + "[CHX4]([CH3X4])[CH2X4][CH3X4]", + "[CH2X4][CHX4]([CH3X4])[CH3X4]", + "[CH2X4][CH2X4][CH2X4][CH2X4][NX4+,NX3+0]", + "[CH2X4][CH2X4][SX2][CH3X4]", + "[CH2X4][cX3]1[cX3H][cX3H][cX3H][cX3H][cX3H]1", + "[$([NX3H,NX4H2+]),$([NX3](C)(C)(C))]1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[OX2H,OX1-,N]", + "[CH2X4][OX2H]", + 
"[NX3][CX3]=[SX1]", + "[CHX4]([CH3X4])[OX2H]", + "[CH2X4][cX3]1[cX3H][nX3H][cX3]2[cX3H][cX3H][cX3H][cX3H][cX3]12", + "[CH2X4][cX3]1[cX3H][cX3H][cX3]([OHX2,OH0X1-])[cX3H][cX3H]1", + "[CHX4]([CH3X4])[CH3X4]", + "N[CX4H2][CX3](=[OX1])[O,N]", + "N1[CX4H]([CH2][CH2][CH2]1)[CX3](=[OX1])[O,N]", + "[$(*-[NX2-]-[NX2+]#[NX1]),$(*-[NX2]=[NX2+]=[NX1-])]", + "[$([NX1-]=[NX2+]=[NX1-]),$([NX1]#[NX2+]-[NX1-2])]", + "[#7]", + "[NX2]=N", + "[NX2]=[NX2]", + "[$([NX2]=[NX3+]([O-])[#6]),$([NX2]=[NX3+0](=[O])[#6])]", + "[$([#6]=[N+]=[N-]),$([#6-]-[N+]#[N])]", + "[$([nr5]:[nr5,or5,sr5]),$([nr5]:[cr5]:[nr5,or5,sr5])]", + "[NX3][NX3]", + "[NX3][NX2]=[*]", + "[CX3;$([C]([#6])[#6]),$([CH][#6])]=[NX2][#6]", + "[$([CX3]([#6])[#6]),$([CX3H][#6])]=[$([NX2][#6]),$([NX2H])]", + "[NX3+]=[CX3]", + "[CX3](=[OX1])[NX3H][CX3](=[OX1])", + "[CX3](=[OX1])[NX3H0]([#6])[CX3](=[OX1])", + "[CX3](=[OX1])[NX3H0]([NX3H0]([CX3](=[OX1]))[CX3](=[OX1]))[CX3](=[OX1])", + "[$([NX3](=[OX1])(=[OX1])O),$([NX3+]([OX1-])(=[OX1])O)]", + "[$([OX1]=[NX3](=[OX1])[OX1-]),$([OX1]=[NX3+]([OX1-])[OX1-])]", + "[NX1]#[CX2]", + "[CX1-]#[NX2+]", + "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]", + "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8].[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]", + "[NX2]=[OX1]", + "[$([#7+][OX1-]),$([#7v5]=[OX1]);!$([#7](~[O])~[O]);!$([#7]=[#7])]", + # O + "[OX2H]", + "[#6][OX2H]", + "[OX2H][CX3]=[OX1]", + "[OX2H]P", + "[OX2H][#6X3]=[#6]", + "[OX2H][cX3]:[c]", + "[OX2H][$(C=C),$(cc)]", + "[$([OH]-*=[!#6])]", + "[OX2,OX1-][OX2,OX1-]", + # P + "[$(P(=[OX1])([$([OX2H]),$([OX1-]),$([OX2]P)])([$([OX2H]),$([OX1-]),\ +$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)]),$([P+]([OX1-])([$([OX2H]),$([OX1-])\ +,$([OX2]P)])([$([OX2H]),$([OX1-]),$([OX2]P)])[$([OX2H]),$([OX1-]),$([OX2]P)])]", + "[$(P(=[OX1])([OX2][#6])([$([OX2H]),$([OX1-]),$([OX2][#6])])[$([OX2H]),\ +$([OX1-]),$([OX2][#6]),$([OX2]P)]),$([P+]([OX1-])([OX2][#6])([$([OX2H]),$([OX1-]),\ +$([OX2][#6])])[$([OX2H]),$([OX1-]),$([OX2][#6]),$([OX2]P)])]", + # S + "[S-][CX3](=S)[#6]", + 
"[#6X3](=[SX1])([!N])[!N]", + "[SX2]", + "[#16X2H]", + "[#16!H0]", + "[#16X2H0]", + "[#16X2H0][!#16]", + "[#16X2H0][#16X2H0]", + "[#16X2H0][!#16].[#16X2H0][!#16]", + "[$([#16X3](=[OX1])[OX2H0]),$([#16X3+]([OX1-])[OX2H0])]", + "[$([#16X3](=[OX1])[OX2H,OX1H0-]),$([#16X3+]([OX1-])[OX2H,OX1H0-])]", + "[$([#16X4](=[OX1])=[OX1]),$([#16X4+2]([OX1-])[OX1-])]", + "[$([#16X4](=[OX1])(=[OX1])([#6])[#6]),$([#16X4+2]([OX1-])([OX1-])([#6])[#6])]", + "[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H,OX1H0-]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H,OX1H0-])]", + "[$([#16X4](=[OX1])(=[OX1])([#6])[OX2H0]),$([#16X4+2]([OX1-])([OX1-])([#6])[OX2H0])]", + "[$([#16X4]([NX3])(=[OX1])(=[OX1])[#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[#6])]", + "[SX4](C)(C)(=O)=N", + "[$([SX4](=[OX1])(=[OX1])([!O])[NX3]),$([SX4+2]([OX1-])([OX1-])([!O])[NX3])]", + "[$([#16X3]=[OX1]),$([#16X3+][OX1-])]", + "[$([#16X3](=[OX1])([#6])[#6]),$([#16X3+]([OX1-])([#6])[#6])]", + "[$([#16X4](=[OX1])(=[OX1])([OX2H,OX1H0-])[OX2][#6]),$([#16X4+2]([OX1-])([OX1-])([OX2H,OX1H0-])[OX2][#6])]", + "[$([SX4](=O)(=O)(O)O),$([SX4+2]([O-])([O-])(O)O)]", + "[$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6]),$([#16X4](=[OX1])(=[OX1])([OX2][#6])[OX2][#6])]", + "[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2][#6]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2][#6])]", + "[$([#16X4]([NX3])(=[OX1])(=[OX1])[OX2H,OX1H0-]),$([#16X4+2]([NX3])([OX1-])([OX1-])[OX2H,OX1H0-])]", + "[#16X2][OX2H,OX1H0-]", + "[#16X2][OX2H0]", + # X + "[#6][F,Cl,Br,I]", + "[F,Cl,Br,I]", + "[F,Cl,Br,I].[F,Cl,Br,I].[F,Cl,Br,I]", + ] + +def get_gasteiger_partial_charges(mol, n_iter=12): + """ + Calculates list of gasteiger partial charges for each atom in mol object. + Args: + mol: rdkit mol object. + n_iter(int): number of iterations. Default 12. + Returns: + list of computed partial charges for each atom. 
+ """ + Chem.rdPartialCharges.ComputeGasteigerCharges(mol, nIter=n_iter, + throwOnParamFailure=True) + partial_charges = [float(a.GetProp('_GasteigerCharge')) for a in + mol.GetAtoms()] + return partial_charges + + +def create_standardized_mol_id(smiles): + """ + Args: + smiles: smiles sequence. + Returns: + inchi. + """ + if check_smiles_validity(smiles): + # remove stereochemistry + smiles = AllChem.MolToSmiles(AllChem.MolFromSmiles(smiles), + isomericSmiles=False) + mol = AllChem.MolFromSmiles(smiles) + if not mol is None: # to catch weird issue with O=C1O[al]2oc(=O)c3ccc(cn3)c3ccccc3c3cccc(c3)c3ccccc3c3cc(C(F)(F)F)c(cc3o2)-c2ccccc2-c2cccc(c2)-c2ccccc2-c2cccnc21 + if '.' in smiles: # if multiple species, pick largest molecule + mol_species_list = split_rdkit_mol_obj(mol) + largest_mol = get_largest_mol(mol_species_list) + inchi = AllChem.MolToInchi(largest_mol) + else: + inchi = AllChem.MolToInchi(mol) + return inchi + else: + return + else: + return + + +def check_smiles_validity(smiles): + """ + Check whether the smile can't be converted to rdkit mol object. + """ + try: + m = Chem.MolFromSmiles(smiles) + if m: + return True + else: + return False + except Exception as e: + return False + + +def split_rdkit_mol_obj(mol): + """ + Split rdkit mol object containing multiple species or one species into a + list of mol objects or a list containing a single object respectively. + Args: + mol: rdkit mol object. + """ + smiles = AllChem.MolToSmiles(mol, isomericSmiles=True) + smiles_list = smiles.split('.') + mol_species_list = [] + for s in smiles_list: + if check_smiles_validity(s): + mol_species_list.append(AllChem.MolFromSmiles(s)) + return mol_species_list + + +def get_largest_mol(mol_list): + """ + Given a list of rdkit mol objects, returns mol object containing the + largest num of atoms. If multiple containing largest num of atoms, + picks the first one. + Args: + mol_list(list): a list of rdkit mol object. + Returns: + the largest mol. 
+ """ + num_atoms_list = [len(m.GetAtoms()) for m in mol_list] + largest_mol_idx = num_atoms_list.index(max(num_atoms_list)) + return mol_list[largest_mol_idx] + + +def rdchem_enum_to_list(values): + """values = {0: rdkit.Chem.rdchem.ChiralType.CHI_UNSPECIFIED, + 1: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, + 2: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, + 3: rdkit.Chem.rdchem.ChiralType.CHI_OTHER} + """ + return [values[i] for i in range(len(values))] + + +def safe_index(alist, elem): + """ + Return index of element e in list l. If e is not present, return the last index + """ + try: + return alist.index(elem) + except ValueError: + return len(alist) - 1 + + +def get_atom_feature_dims(list_acquired_feature_names): + """ tbd + """ + return list(map(len, [CompoundKit.atom_vocab_dict[name] for name in list_acquired_feature_names])) + + +def get_bond_feature_dims(list_acquired_feature_names): + """ tbd + """ + list_bond_feat_dim = list(map(len, [CompoundKit.bond_vocab_dict[name] for name in list_acquired_feature_names])) + # +1 for self loop edges + return [_l + 1 for _l in list_bond_feat_dim] + + +class CompoundKit(object): + """ + CompoundKit + """ + atom_vocab_dict = { + "atomic_num": list(range(1, 119)) + ['misc'], + "chiral_tag": rdchem_enum_to_list(rdchem.ChiralType.values), + "degree": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'], + "explicit_valence": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 'misc'], + "formal_charge": [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'], + "hybridization": rdchem_enum_to_list(rdchem.HybridizationType.values), + "implicit_valence": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 'misc'], + "is_aromatic": [0, 1], + "total_numHs": [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], + 'num_radical_e': [0, 1, 2, 3, 4, 'misc'], + 'atom_is_in_ring': [0, 1], + 'valence_out_shell': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], + 'in_num_ring_with_size3': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], + 'in_num_ring_with_size4': [0, 1, 2, 3, 4, 5, 6, 
7, 8, 'misc'], + 'in_num_ring_with_size5': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], + 'in_num_ring_with_size6': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], + 'in_num_ring_with_size7': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], + 'in_num_ring_with_size8': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], + } + bond_vocab_dict = { + "bond_dir": rdchem_enum_to_list(rdchem.BondDir.values), + "bond_type": rdchem_enum_to_list(rdchem.BondType.values), + "is_in_ring": [0, 1], + + 'bond_stereo': rdchem_enum_to_list(rdchem.BondStereo.values), + 'is_conjugated': [0, 1], + } + # float features + atom_float_names = ["van_der_waals_radis", "partial_charge", 'mass'] + # bond_float_feats= ["bond_length", "bond_angle"] # optional + + ### functional groups + day_light_fg_smarts_list = DAY_LIGHT_FG_SMARTS_LIST + day_light_fg_mo_list = [Chem.MolFromSmarts(smarts) for smarts in day_light_fg_smarts_list] + + morgan_fp_N = 200 + morgan2048_fp_N = 2048 + maccs_fp_N = 167 + + period_table = Chem.GetPeriodicTable() + + ### atom + + @staticmethod + def get_atom_value(atom, name): + """get atom values""" + if name == 'atomic_num': + return atom.GetAtomicNum() + elif name == 'chiral_tag': + return atom.GetChiralTag() + elif name == 'degree': + return atom.GetDegree() + elif name == 'explicit_valence': + return atom.GetExplicitValence() + elif name == 'formal_charge': + return atom.GetFormalCharge() + elif name == 'hybridization': + return atom.GetHybridization() + elif name == 'implicit_valence': + return atom.GetImplicitValence() + elif name == 'is_aromatic': + return int(atom.GetIsAromatic()) + elif name == 'mass': + return int(atom.GetMass()) + elif name == 'total_numHs': + return atom.GetTotalNumHs() + elif name == 'num_radical_e': + return atom.GetNumRadicalElectrons() + elif name == 'atom_is_in_ring': + return int(atom.IsInRing()) + elif name == 'valence_out_shell': + return CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum()) + else: + raise ValueError(name) + + @staticmethod + def get_atom_feature_id(atom, 
name): + """get atom features id""" + assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name + return safe_index(CompoundKit.atom_vocab_dict[name], CompoundKit.get_atom_value(atom, name)) + + @staticmethod + def get_atom_feature_size(name): + """get atom features size""" + assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name + return len(CompoundKit.atom_vocab_dict[name]) + + ### bond + + @staticmethod + def get_bond_value(bond, name): + """get bond values""" + if name == 'bond_dir': + return bond.GetBondDir() + elif name == 'bond_type': + return bond.GetBondType() + elif name == 'is_in_ring': + return int(bond.IsInRing()) + elif name == 'is_conjugated': + return int(bond.GetIsConjugated()) + elif name == 'bond_stereo': + return bond.GetStereo() + else: + raise ValueError(name) + + @staticmethod + def get_bond_feature_id(bond, name): + """get bond features id""" + assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name + return safe_index(CompoundKit.bond_vocab_dict[name], CompoundKit.get_bond_value(bond, name)) + + @staticmethod + def get_bond_feature_size(name): + """get bond features size""" + assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name + return len(CompoundKit.bond_vocab_dict[name]) + + ### fingerprint + + @staticmethod + def get_morgan_fingerprint(mol, radius=2): + """get morgan fingerprint""" + nBits = CompoundKit.morgan_fp_N + mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits) + return [int(b) for b in mfp.ToBitString()] + + @staticmethod + def get_morgan2048_fingerprint(mol, radius=2): + """get morgan2048 fingerprint""" + nBits = CompoundKit.morgan2048_fp_N + mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits) + return [int(b) for b in mfp.ToBitString()] + + @staticmethod + def get_maccs_fingerprint(mol): + """get maccs fingerprint""" + fp = AllChem.GetMACCSKeysFingerprint(mol) + return 
[int(b) for b in fp.ToBitString()] + + ### functional groups + + @staticmethod + def get_daylight_functional_group_counts(mol): + """get daylight functional group counts""" + fg_counts = [] + for fg_mol in CompoundKit.day_light_fg_mo_list: + sub_structs = Chem.Mol.GetSubstructMatches(mol, fg_mol, uniquify=True) + fg_counts.append(len(sub_structs)) + return fg_counts + + @staticmethod + def get_ring_size(mol): + """return (N,6) list""" + rings = mol.GetRingInfo() + rings_info = [] + for r in rings.AtomRings(): + rings_info.append(r) + ring_list = [] + for atom in mol.GetAtoms(): + atom_result = [] + for ringsize in range(3, 9): + num_of_ring_at_ringsize = 0 + for r in rings_info: + if len(r) == ringsize and atom.GetIdx() in r: + num_of_ring_at_ringsize += 1 + if num_of_ring_at_ringsize > 8: + num_of_ring_at_ringsize = 9 + atom_result.append(num_of_ring_at_ringsize) + + ring_list.append(atom_result) + return ring_list + + @staticmethod + def atom_to_feat_vector(atom): + """ tbd """ + atom_names = { + "atomic_num": safe_index(CompoundKit.atom_vocab_dict["atomic_num"], atom.GetAtomicNum()), + "chiral_tag": safe_index(CompoundKit.atom_vocab_dict["chiral_tag"], atom.GetChiralTag()), + "degree": safe_index(CompoundKit.atom_vocab_dict["degree"], atom.GetTotalDegree()), + "explicit_valence": safe_index(CompoundKit.atom_vocab_dict["explicit_valence"], atom.GetExplicitValence()), + "formal_charge": safe_index(CompoundKit.atom_vocab_dict["formal_charge"], atom.GetFormalCharge()), + "hybridization": safe_index(CompoundKit.atom_vocab_dict["hybridization"], atom.GetHybridization()), + "implicit_valence": safe_index(CompoundKit.atom_vocab_dict["implicit_valence"], atom.GetImplicitValence()), + "is_aromatic": safe_index(CompoundKit.atom_vocab_dict["is_aromatic"], int(atom.GetIsAromatic())), + "total_numHs": safe_index(CompoundKit.atom_vocab_dict["total_numHs"], atom.GetTotalNumHs()), + 'num_radical_e': safe_index(CompoundKit.atom_vocab_dict['num_radical_e'], 
atom.GetNumRadicalElectrons()), + 'atom_is_in_ring': safe_index(CompoundKit.atom_vocab_dict['atom_is_in_ring'], int(atom.IsInRing())), + 'valence_out_shell': safe_index(CompoundKit.atom_vocab_dict['valence_out_shell'], + CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum())), + 'van_der_waals_radis': CompoundKit.period_table.GetRvdw(atom.GetAtomicNum()), + 'partial_charge': CompoundKit.check_partial_charge(atom), + 'mass': atom.GetMass(), + } + return atom_names + + @staticmethod + def get_atom_names(mol): + """get atom name list + TODO: to be remove in the future + """ + atom_features_dicts = [] + Chem.rdPartialCharges.ComputeGasteigerCharges(mol) + for i, atom in enumerate(mol.GetAtoms()): + atom_features_dicts.append(CompoundKit.atom_to_feat_vector(atom)) + + ring_list = CompoundKit.get_ring_size(mol) + for i, atom in enumerate(mol.GetAtoms()): + atom_features_dicts[i]['in_num_ring_with_size3'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size3'], ring_list[i][0]) + atom_features_dicts[i]['in_num_ring_with_size4'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size4'], ring_list[i][1]) + atom_features_dicts[i]['in_num_ring_with_size5'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size5'], ring_list[i][2]) + atom_features_dicts[i]['in_num_ring_with_size6'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size6'], ring_list[i][3]) + atom_features_dicts[i]['in_num_ring_with_size7'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size7'], ring_list[i][4]) + atom_features_dicts[i]['in_num_ring_with_size8'] = safe_index( + CompoundKit.atom_vocab_dict['in_num_ring_with_size8'], ring_list[i][5]) + + return atom_features_dicts + + @staticmethod + def check_partial_charge(atom): + """tbd""" + pc = atom.GetDoubleProp('_GasteigerCharge') + if pc != pc: + # unsupported atom, replace nan with 0 + pc = 0 + if pc == float('inf'): + # max 4 for other atoms, set to 10 here if inf is get + pc = 10 + 
return pc + + +class Compound3DKit(object): + """the 3Dkit of Compound""" + + @staticmethod + def get_atom_poses(mol, conf): + """tbd""" + atom_poses = [] + for i, atom in enumerate(mol.GetAtoms()): + if atom.GetAtomicNum() == 0: + return [[0.0, 0.0, 0.0]] * len(mol.GetAtoms()) + pos = conf.GetAtomPosition(i) + atom_poses.append([pos.x, pos.y, pos.z]) + return atom_poses + + @staticmethod + def get_MMFF_atom_poses(mol, numConfs=None, return_energy=False): + """the atoms of mol will be changed in some cases.""" + conf = mol.GetConformer() + atom_poses = Compound3DKit.get_atom_poses(mol, conf) + return mol,atom_poses + # try: + # new_mol = Chem.AddHs(mol) + # res = AllChem.EmbedMultipleConfs(new_mol, numConfs=numConfs) + # ### MMFF generates multiple conformations + # res = AllChem.MMFFOptimizeMoleculeConfs(new_mol) + # new_mol = Chem.RemoveHs(new_mol) + # index = np.argmin([x[1] for x in res]) + # energy = res[index][1] + # conf = new_mol.GetConformer(id=int(index)) + # except: + # new_mol = mol + # AllChem.Compute2DCoords(new_mol) + # energy = 0 + # conf = new_mol.GetConformer() + # + # atom_poses = Compound3DKit.get_atom_poses(new_mol, conf) + # if return_energy: + # return new_mol, atom_poses, energy + # else: + # return new_mol, atom_poses + + @staticmethod + def get_2d_atom_poses(mol): + """get 2d atom poses""" + AllChem.Compute2DCoords(mol) + conf = mol.GetConformer() + atom_poses = Compound3DKit.get_atom_poses(mol, conf) + return atom_poses + + @staticmethod + def get_bond_lengths(edges, atom_poses): + """get bond lengths""" + bond_lengths = [] + for src_node_i, tar_node_j in edges: + bond_lengths.append(np.linalg.norm(atom_poses[tar_node_j] - atom_poses[src_node_i])) + bond_lengths = np.array(bond_lengths, 'float32') + return bond_lengths + + @staticmethod + def get_superedge_angles(edges, atom_poses, dir_type='HT'): + """get superedge angles""" + + def _get_vec(atom_poses, edge): + return atom_poses[edge[1]] - atom_poses[edge[0]] + + def _get_angle(vec1, 
vec2): + norm1 = np.linalg.norm(vec1) + norm2 = np.linalg.norm(vec2) + if norm1 == 0 or norm2 == 0: + return 0 + vec1 = vec1 / (norm1 + 1e-5) # 1e-5: prevent numerical errors + vec2 = vec2 / (norm2 + 1e-5) + angle = np.arccos(np.dot(vec1, vec2)) + return angle + + E = len(edges) + edge_indices = np.arange(E) + super_edges = [] + bond_angles = [] + bond_angle_dirs = [] + for tar_edge_i in range(E): + tar_edge = edges[tar_edge_i] + if dir_type == 'HT': + src_edge_indices = edge_indices[edges[:, 1] == tar_edge[0]] + elif dir_type == 'HH': + src_edge_indices = edge_indices[edges[:, 1] == tar_edge[1]] + else: + raise ValueError(dir_type) + for src_edge_i in src_edge_indices: + if src_edge_i == tar_edge_i: + continue + src_edge = edges[src_edge_i] + src_vec = _get_vec(atom_poses, src_edge) + tar_vec = _get_vec(atom_poses, tar_edge) + super_edges.append([src_edge_i, tar_edge_i]) + angle = _get_angle(src_vec, tar_vec) + bond_angles.append(angle) + bond_angle_dirs.append(src_edge[1] == tar_edge[0]) # H -> H or H -> T + + if len(super_edges) == 0: + super_edges = np.zeros([0, 2], 'int64') + bond_angles = np.zeros([0, ], 'float32') + else: + super_edges = np.array(super_edges, 'int64') + bond_angles = np.array(bond_angles, 'float32') + return super_edges, bond_angles, bond_angle_dirs + + +def new_smiles_to_graph_data(smiles, **kwargs): + """ + Convert smiles to graph data. + """ + mol = AllChem.MolFromSmiles(smiles) + if mol is None: + return None + data = new_mol_to_graph_data(mol) + return data + + +def new_mol_to_graph_data(mol): + """ + mol_to_graph_data + Args: + atom_features: Atom features. + edge_features: Edge features. + morgan_fingerprint: Morgan fingerprint. + functional_groups: Functional groups. 
+ """ + if len(mol.GetAtoms()) == 0: + return None + + atom_id_names = list(CompoundKit.atom_vocab_dict.keys()) + CompoundKit.atom_float_names + bond_id_names = list(CompoundKit.bond_vocab_dict.keys()) + + data = {} + + ### atom features + data = {name: [] for name in atom_id_names} + + raw_atom_feat_dicts = CompoundKit.get_atom_names(mol) + for atom_feat in raw_atom_feat_dicts: + for name in atom_id_names: + data[name].append(atom_feat[name]) + + ### bond and bond features + for name in bond_id_names: + data[name] = [] + data['edges'] = [] + + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + # i->j and j->i + data['edges'] += [(i, j), (j, i)] + for name in bond_id_names: + bond_feature_id = CompoundKit.get_bond_feature_id(bond, name) + data[name] += [bond_feature_id] * 2 + + #### self loop + N = len(data[atom_id_names[0]]) + for i in range(N): + data['edges'] += [(i, i)] + for name in bond_id_names: + bond_feature_id = get_bond_feature_dims([name])[0] - 1 # self loop: value = len - 1 + data[name] += [bond_feature_id] * N + + ### make ndarray and check length + for name in list(CompoundKit.atom_vocab_dict.keys()): + data[name] = np.array(data[name], 'int64') + for name in CompoundKit.atom_float_names: + data[name] = np.array(data[name], 'float32') + for name in bond_id_names: + data[name] = np.array(data[name], 'int64') + data['edges'] = np.array(data['edges'], 'int64') + + ### morgan fingerprint + data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64') + # data['morgan2048_fp'] = np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64') + data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64') + data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64') + return data + + +def mol_to_graph_data(mol): + """ + mol_to_graph_data + Args: + atom_features: Atom features. + edge_features: Edge features. + morgan_fingerprint: Morgan fingerprint. 
+ functional_groups: Functional groups. + """ + if len(mol.GetAtoms()) == 0: + return None + + atom_id_names = [ + "atomic_num", "chiral_tag", "degree", "explicit_valence", + "formal_charge", "hybridization", "implicit_valence", + "is_aromatic", "total_numHs", + ] + bond_id_names = [ + "bond_dir", "bond_type", "is_in_ring", + ] + + data = {} + for name in atom_id_names: + data[name] = [] + data['mass'] = [] + for name in bond_id_names: + data[name] = [] + data['edges'] = [] + + ### atom features + for i, atom in enumerate(mol.GetAtoms()): + if atom.GetAtomicNum() == 0: + return None + for name in atom_id_names: + data[name].append(CompoundKit.get_atom_feature_id(atom, name) + 1) # 0: OOV + data['mass'].append(CompoundKit.get_atom_value(atom, 'mass') * 0.01) + + ### bond features + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + # i->j and j->i + data['edges'] += [(i, j), (j, i)] + for name in bond_id_names: + bond_feature_id = CompoundKit.get_bond_feature_id(bond, name) + 1 # 0: OOV + data[name] += [bond_feature_id] * 2 + + ### self loop (+2) + N = len(data[atom_id_names[0]]) + for i in range(N): + data['edges'] += [(i, i)] + for name in bond_id_names: + bond_feature_id = CompoundKit.get_bond_feature_size(name) + 2 # N + 2: self loop + data[name] += [bond_feature_id] * N + + ### check whether edge exists + if len(data['edges']) == 0: # mol has no bonds + for name in bond_id_names: + data[name] = np.zeros((0,), dtype="int64") + data['edges'] = np.zeros((0, 2), dtype="int64") + + ### make ndarray and check length + for name in atom_id_names: + data[name] = np.array(data[name], 'int64') + data['mass'] = np.array(data['mass'], 'float32') + for name in bond_id_names: + data[name] = np.array(data[name], 'int64') + data['edges'] = np.array(data['edges'], 'int64') + + ### morgan fingerprint + data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64') + # data['morgan2048_fp'] = 
np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64') + data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64') + data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64') + return data + + +def mol_to_geognn_graph_data(mol, atom_poses, dir_type): + """ + mol: rdkit molecule + dir_type: direction type for bond_angle grpah + """ + if len(mol.GetAtoms()) == 0: + return None + + data = mol_to_graph_data(mol) + + data['atom_pos'] = np.array(atom_poses, 'float32') + data['bond_length'] = Compound3DKit.get_bond_lengths(data['edges'], data['atom_pos']) + BondAngleGraph_edges, bond_angles, bond_angle_dirs = \ + Compound3DKit.get_superedge_angles(data['edges'], data['atom_pos']) + data['BondAngleGraph_edges'] = BondAngleGraph_edges + data['bond_angle'] = np.array(bond_angles, 'float32') + return data + + +def mol_to_geognn_graph_data_MMFF3d(mol): + """tbd""" + if len(mol.GetAtoms()) <= 400: + mol, atom_poses = Compound3DKit.get_MMFF_atom_poses(mol, numConfs=10) + else: + atom_poses = Compound3DKit.get_2d_atom_poses(mol) + return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT') + + +def mol_to_geognn_graph_data_raw3d(mol): + """tbd""" + atom_poses = Compound3DKit.get_atom_poses(mol, mol.GetConformer()) + return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT') + +def obtain_3D_mol(smiles,name): + mol = AllChem.MolFromSmiles(smiles) + new_mol = Chem.AddHs(mol) + res = AllChem.EmbedMultipleConfs(new_mol) + ### MMFF generates multiple conformations + res = AllChem.MMFFOptimizeMoleculeConfs(new_mol) + new_mol = Chem.RemoveHs(new_mol) + Chem.MolToMolFile(new_mol, name+'.mol') + return new_mol + +def predict_SMILES_info(smiles): + # by lihao, input smiles, output dict + mol = AllChem.MolFromSmiles(smiles) + AllChem.EmbedMolecule(mol) + info_dict = mol_to_geognn_graph_data_MMFF3d(mol) + return info_dict + +# ----------------Commonly-used Parameters---------------- +atom_id_names = [ + 
"atomic_num", "chiral_tag", "degree", "explicit_valence", + "formal_charge", "hybridization", "implicit_valence", + "is_aromatic", "total_numHs", +] +bond_id_names = ["bond_dir", "bond_type", "is_in_ring"] +full_atom_feature_dims = get_atom_feature_dims(atom_id_names) +full_bond_feature_dims = get_bond_feature_dims(bond_id_names) +bond_angle_float_names = ['bond_angle', 'TPSA', 'RASA', 'RPSA', 'MDEC', 'MATS'] +column_specify={ + 'ADH':[1,5,0,0],'ODH':[1,5,0,1],'IC':[0,5,1,2],'IA':[0,5,1,3],'OJH':[1,5,0,4], + 'ASH':[1,5,0,5],'IC3':[0,3,1,6],'IE':[0,5,1,7],'ID':[0,5,1,8],'OD3':[1,3,0,9], + 'IB':[0,5,1,10],'AD':[1,10,0,11],'AD3':[1,3,0,12],'IF':[0,5,1,13],'OD':[1,10,0,14], + 'AS':[1,10,0,15],'OJ3':[1,3,0,16],'IG':[0,5,1,17],'AZ':[1,10,0,18],'IAH':[0,5,1,19], + 'OJ':[1,10,0,20],'ICH':[0,5,1,21],'OZ3':[1,3,0,22],'IF3':[0,3,1,23],'IAU':[0,1.6,1,24] +} +bond_float_names = [] + +if __name__ == "__main__": + # smiles = "OCc1ccccc1CN" + smiles = r"[H]/[NH+]=C(\N)C1=CC(=O)/C(=C\C=c2ccc(=C(N)[NH3+])cc2)C=C1" + # smiles = 'CC' + mol = AllChem.MolFromSmiles(smiles) + AllChem.EmbedMolecule(mol) + data = mol_to_geognn_graph_data_MMFF3d(mol) + for key, value in data.items(): + print(key, value.shape) \ No newline at end of file diff --git a/ppmat/utils/graph_utils.py b/ppmat/utils/graph_utils.py new file mode 100644 index 00000000..394328d7 --- /dev/null +++ b/ppmat/utils/graph_utils.py @@ -0,0 +1,58 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle + +def index_transform(raw_index, batch_size): + """Convert compressed batch indices to a list of node indices for each sample""" + + def get_index1(lst=None, batch_num=-1): + return [index for (index, value) in enumerate(lst) if value == batch_num] + + raw_index = raw_index.tolist() + index_list = [] + for batch_id in range(batch_size): + index_list.append(get_index1(raw_index, batch_id)) + return index_list + + +def get_key_padding_mask(tokens): + """Generate key padding mask""" + key_padding_mask = paddle.zeros(tokens.shape) + key_padding_mask[tokens == -1] = -paddle.inf + return key_padding_mask + + +def feat_padding_mask(index, max_node_num): + """Generate feature padding mask based on node indices""" + new_index = [] + for itm_list in index: + new_index.append(itm_list + [-1] * (max_node_num - len(itm_list))) + new_index = paddle.to_tensor(new_index) + return get_key_padding_mask(new_index) + + +def pad_node_features(molecule_features, batch_index, this_batch_size, max_node_num, emb_dim): + """Pad compressed node features to [batch, max_node, emb_dim] format""" + index_list = index_transform(batch_index, this_batch_size) + + new_batch_list = [] + for batch_id in range(this_batch_size): + empty_batch_tensor = paddle.zeros([max_node_num, emb_dim]) + for i in range(len(index_list[batch_id])): + empty_batch_tensor[i, :] = molecule_features[index_list[batch_id][i], :] + new_batch_list.append(empty_batch_tensor) + + node_feature = paddle.stack(new_batch_list, axis=0) + return node_feature, index_list \ No newline at end of file diff --git a/ppmat/utils/place_env.py b/ppmat/utils/place_env.py new file mode 100644 index 00000000..83274396 --- /dev/null +++ b/ppmat/utils/place_env.py @@ -0,0 +1,109 @@ +import paddle +import functools +from paddle._typing.device_like import PlaceLike + +class PlaceEnv: + """ + Class version of context manager, also 
supports decorator functionality + """ + + def __init__(self, place: PlaceLike): + """ + Initialize PlaceEnv + + Args: + place: device objects like paddle.CPUPlace() or paddle.CUDAPlace(0) + """ + self.place = place + self.original_device = None + + def __enter__(self): + """Called when entering the context""" + # Save current device setting + self.original_device = paddle.get_device() + paddle.set_device(self.place) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Called when exiting the context""" + # Restore original device setting + if self.original_device is not None: + paddle.set_device(self.original_device) + + def __call__(self, func): + """ + Allows instance to be used as a decorator + + Args: + func: function to be decorated + + Returns: + Decorated function + """ + @functools.wraps(func) + def wrapper(*args, **kwargs): + # Use with statement to temporarily change device setting + with self: + return func(*args, **kwargs) + return wrapper + + + +# Usage example +if __name__ == "__main__": + # Test with statement + print("=== Testing with statement ===") + print(f"Current device: {paddle.get_device()}") + + print("\n=== Testing class version with statement ===") + with PlaceEnv(paddle.CPUPlace()): + print(f"Device inside with block: {paddle.get_device()}") + y = paddle.ones([2, 3]) + print(f"Created tensor: {y}") + + print(f"Device outside with block: {paddle.get_device()}") + + # Test decorator functionality + print("\n=== Testing decorator functionality ===") + + @PlaceEnv(paddle.CPUPlace()) + def cpu_function(): + """This function will run on CPU""" + print(f"Device inside function: {paddle.get_device()}") + return paddle.rand([2, 2]) + + # Check if GPU is available + if paddle.device.cuda.device_count() > 0: + @PlaceEnv(paddle.CUDAPlace(0)) + def gpu_function(): + """This function will run on GPU""" + print(f"Device inside function: {paddle.get_device()}") + return paddle.rand([2, 2]) + + # Call decorated functions + print("Calling 
cpu_function:") + result_cpu = cpu_function() + print(f"Device after function execution: {paddle.get_device()}") + print(f"Result: {result_cpu}") + + if paddle.device.cuda.device_count() > 0: + print("\nCalling gpu_function:") + result_gpu = gpu_function() + print(f"Device after function execution: {paddle.get_device()}") + print(f"Result: {result_gpu}") + + print("\n=== Testing multiple nesting ===") + print(f"Initial device: {paddle.get_device()}") + + with PlaceEnv(paddle.CPUPlace()): + print(f"Device inside first with block: {paddle.get_device()}") + + if paddle.device.cuda.device_count() > 0: + with PlaceEnv(paddle.CUDAPlace(0)): + print(f"Device inside second with block: {paddle.get_device()}") + z = paddle.rand([2, 2]) + print(f"Created tensor: {z}") + + print(f"Back to first with block device: {paddle.get_device()}") + + print(f"Final device: {paddle.get_device()}") \ No newline at end of file diff --git a/spectrum_elucidation/configs/ecformer/ecd.yaml b/spectrum_elucidation/configs/ecformer/ecd.yaml new file mode 100644 index 00000000..53c9d583 --- /dev/null +++ b/spectrum_elucidation/configs/ecformer/ecd.yaml @@ -0,0 +1,113 @@ +# ECFormer ECD Task Configuration +# Task: Electronic Circular Dichroism Spectrum Prediction + +Global: + do_train: True + do_eval: True + do_test: True + label_names: ["peak_number", "peak_position", "peak_height"] + +Dataset: + train: + dataset: + __class_name__: ECDDataset + __init_params__: + data_path: ./datasets/ECD + data_count: null # null means use all data + use_geometry_enhanced: True + use_column_info: False + loader: + num_workers: 4 + collate_fn: ECDCollator + sampler: + __class_name__: BatchSampler + __init_params__: + batch_size: 128 + shuffle: True + + val: + dataset: + __class_name__: ECDDataset + __init_params__: + data_path: ./datasets/ECD + data_count: null + use_geometry_enhanced: True + use_column_info: False + loader: + num_workers: 4 + collate_fn: ECDCollator + sampler: + __class_name__: BatchSampler + 
__init_params__: + batch_size: 128 + shuffle: False + + test: + dataset: + __class_name__: ECDDataset + __init_params__: + data_path: ./datasets/ECD + data_count: null + use_geometry_enhanced: True + use_column_info: False + loader: + num_workers: 4 + collate_fn: ECDCollator + sampler: + __class_name__: BatchSampler + __init_params__: + batch_size: 128 + shuffle: False + +Model: + __class_name__: ECFormerECD + __init_params__: + full_atom_feature_dims: [119, 9, 12, 14, 17, 9, 14, 2, 10] + full_bond_feature_dims: [8, 23, 3] + bond_float_names: ['bond_length'] + bond_angle_float_names: ['bond_angle', 'TPSA', 'RASA', 'RPSA', 'MDEC', 'MATS'] + bond_id_names: ['bond_dir', 'bond_type', 'is_in_ring'] + num_layers: 5 + emb_dim: 256 + drop_ratio: 0.0 + graph_pooling: 'sum' + use_geometry_enhanced: True + max_node_num: 63 + num_heads: 4 + tf_layers: 2 + tf_dropout: 0.1 + max_peaks: 9 + num_position_classes: 20 + height_classes: 2 + +Optimizer: + __class_name__: AdamW + __init_params__: + lr: 0.001 + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-8 + weight_decay: 0.01 + +Metric: + __class_name__: ECDMetrics + __init_params__: + num_position_classes: 20 + max_peaks: 9 + +Trainer: + output_dir: ./output/ecformer_ecd + max_epochs: 100 + log_freq: 10 + save_freq: 5 + eval_freq: 1 + seed: 42 + use_amp: False + start_eval_epoch: 1 + amp_level: 'O1' + eval_with_no_grad: True + compute_metric_during_train: False + metric_strategy_during_eval: 'step' + use_visualdl: False + use_wandb: False + use_tensorboard: False \ No newline at end of file diff --git a/spectrum_elucidation/configs/ecformer/ir.yaml b/spectrum_elucidation/configs/ecformer/ir.yaml new file mode 100644 index 00000000..383b5787 --- /dev/null +++ b/spectrum_elucidation/configs/ecformer/ir.yaml @@ -0,0 +1,123 @@ +# ECFormer IR Task Configuration +# Task: Infrared Spectrum Prediction + +Global: + do_train: True + do_eval: True + do_test: True + label_names: ["peak_number", "peak_position", "peak_height"] + +Dataset: + train: + 
dataset: + __class_name__: IRDataset + __init_params__: + data_path: ./datasets/IR + data_count: null + use_geometry_enhanced: True + loader: + num_workers: 4 + collate_fn: IRCollator + sampler: + __class_name__: BatchSampler + __init_params__: + batch_size: 32 + shuffle: True + + val: + dataset: + __class_name__: IRDataset + __init_params__: + data_path: ./datasets/IR + data_count: null + use_geometry_enhanced: True + loader: + num_workers: 4 + collate_fn: IRCollator + sampler: + __class_name__: BatchSampler + __init_params__: + batch_size: 32 + shuffle: False + + test: + dataset: + __class_name__: IRDataset + __init_params__: + data_path: ./datasets/IR + data_count: null + use_geometry_enhanced: True + loader: + num_workers: 4 + collate_fn: IRCollator + sampler: + __class_name__: BatchSampler + __init_params__: + batch_size: 32 + shuffle: False + +Model: + __class_name__: ECFormerIR + __init_params__: + # Atom feature dimensions (consistent with ECD) + full_atom_feature_dims: [119, 9, 12, 14, 17, 9, 14, 2, 10] + # Bond feature dimensions (consistent with ECD) + full_bond_feature_dims: [8, 23, 3] + bond_float_names: ['bond_length'] + bond_angle_float_names: ['bond_angle', 'TPSA', 'RASA', 'RPSA', 'MDEC', 'MATS'] + bond_id_names: ['bond_dir', 'bond_type', 'is_in_ring'] + + # GNN parameters + num_layers: 5 + emb_dim: 128 # IR task can use smaller embedding dimension + drop_ratio: 0.0 + graph_pooling: 'sum' + use_geometry_enhanced: True + max_node_num: 63 # Maximum number of atoms (consistent with ECD) + + # Transformer parameters + num_heads: 4 + tf_layers: 2 + tf_dropout: 0.1 + + # IR-specific parameters + max_peaks: 15 # IR spectra have up to 15 peaks + num_position_classes: 36 # IR position classes (wavenumber range) + use_height_prediction: True # IR uses intensity prediction (regression) + +Optimizer: + __class_name__: AdamW + __init_params__: + lr: 0.001 + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-8 + weight_decay: 0.01 + +Metric: + __class_name__: IRMetrics + 
__init_params__: + use_height_prediction: True # Consistent with model + +Trainer: + output_dir: ./output/ecformer_ir + max_epochs: 100 + log_freq: 10 + save_freq: 5 + eval_freq: 1 + seed: 42 + use_amp: False + start_eval_epoch: 1 + amp_level: 'O1' + eval_with_no_grad: True + compute_metric_during_train: False + metric_strategy_during_eval: 'epoch' # 'epoch' recommended for streaming metrics + use_visualdl: False + use_wandb: False + use_tensorboard: False + + # The following parameters can be passed to trainer (override model settings if needed) + loss_weight_height: 1.0 # Optional, override in trainer layer + max_peaks: 15 # Optional + num_position_classes: 36 # Optional + height_classes: 1 # IR intensity is regression task, output dimension is 1 \ No newline at end of file diff --git a/spectrum_elucidation/sample.py b/spectrum_elucidation/diffnmr/sample.py similarity index 100% rename from spectrum_elucidation/sample.py rename to spectrum_elucidation/diffnmr/sample.py diff --git a/spectrum_elucidation/train.py b/spectrum_elucidation/diffnmr/train.py similarity index 100% rename from spectrum_elucidation/train.py rename to spectrum_elucidation/diffnmr/train.py diff --git a/spectrum_elucidation/ecformer/train.py b/spectrum_elucidation/ecformer/train.py new file mode 100644 index 00000000..58b332ea --- /dev/null +++ b/spectrum_elucidation/ecformer/train.py @@ -0,0 +1,233 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os
import os.path as osp

import paddle.distributed as dist
from omegaconf import OmegaConf

from ppmat.datasets import build_dataloader
from ppmat.datasets import set_signal_handlers
from ppmat.models import build_model
from ppmat.optimizer import build_optimizer
from ppmat.utils import logger
from ppmat.utils import misc
from ppmat.utils import save_load

from ppmat.trainer import ECFormerTrainer


def _parse_args():
    """Parse CLI arguments for the ECDFormer driver.

    Returns:
        Tuple of (known args namespace, list of unrecognized tokens). The
        unrecognized tokens are later treated as OmegaConf dot-list overrides.
    """
    parser = argparse.ArgumentParser(
        description="ECDFormer for ECD Spectrum Prediction"
    )
    parser.add_argument(
        "-c", "--config",
        type=str,
        default="../configs/ecformer/ecd.yaml",
        help="Path to config file",
    )
    parser.add_argument(
        "--resume",
        type=str,
        default=None,
        help="Resume from checkpoint path",
    )
    parser.add_argument(
        "--eval-only",
        action="store_true",
        help="Only run evaluation on validation set",
    )
    parser.add_argument(
        "--test-only",
        action="store_true",
        help="Only run evaluation on test set",
    )
    parser.add_argument(
        "--predict",
        type=str,
        default=None,
        help="Path to data for prediction (inference mode)",
    )
    return parser.parse_known_args()


def _apply_mode_overrides(config, args):
    """Override the `Global` run-mode flags in `config` from CLI switches.

    Precedence is eval > test > predict (mutually exclusive, first match
    wins). `--predict PATH` also injects PATH into the predict dataset cfg.
    """
    if args.eval_only:
        config.Global.do_train = False
        config.Global.do_eval = True
        config.Global.do_test = False
    elif args.test_only:
        config.Global.do_train = False
        config.Global.do_eval = False
        config.Global.do_test = True
    elif args.predict is not None:
        config.Global.do_train = False
        config.Global.do_eval = False
        config.Global.do_test = False
        config.Global.do_predict = True
        config.Dataset.predict.data_path = args.predict


def _build_dataloaders(config):
    """Build the dataloaders required by the enabled run modes.

    Args:
        config: Plain-dict config (already resolved from OmegaConf).

    Returns:
        Dict with any of the keys "train", "val", "test", "predict".

    Raises:
        AssertionError: If a required dataset section is missing.
    """
    dataloaders = {}
    global_cfg = config["Global"]
    dataset_cfg = config["Dataset"]

    if global_cfg.get("do_train", True):
        train_cfg = dataset_cfg.get("train")
        assert train_cfg is not None, "train dataset must be defined when do_train is True"
        dataloaders["train"] = build_dataloader(train_cfg)
        logger.info(f"Train dataset loaded, size: {len(dataloaders['train'].dataset)}")

    # A validation loader is also built during training (for periodic eval),
    # not only when do_eval is set.
    if global_cfg.get("do_eval", False) or global_cfg.get("do_train", True):
        val_cfg = dataset_cfg.get("val")
        if val_cfg is not None:
            dataloaders["val"] = build_dataloader(val_cfg)
            logger.info(f"Validation dataset loaded, size: {len(dataloaders['val'].dataset)}")
        else:
            logger.info("No validation dataset defined.")

    if global_cfg.get("do_test", False):
        test_cfg = dataset_cfg.get("test")
        assert test_cfg is not None, "test dataset must be defined when do_test is True"
        dataloaders["test"] = build_dataloader(test_cfg)
        logger.info(f"Test dataset loaded, size: {len(dataloaders['test'].dataset)}")

    if global_cfg.get("do_predict", False):
        predict_cfg = dataset_cfg.get("predict")
        assert predict_cfg is not None, "predict dataset must be defined when do_predict is True"
        dataloaders["predict"] = build_dataloader(predict_cfg)
        logger.info(f"Prediction dataset loaded, size: {len(dataloaders['predict'].dataset)}")

    return dataloaders


def _log_metric_summary(title, metric_info):
    """Log averaged metrics as a single `title | key: value` line.

    Args:
        title: Leading label, e.g. "Validation Results:".
        metric_info: Mapping of metric name to a meter exposing `.avg`.
    """
    msg = title
    for key, meter in metric_info.items():
        msg += f" | {key}: {meter.avg:.6f}"
    logger.info(msg)


def main():
    """Entry point: load config, build components, run requested modes."""
    args, dynamic_args = _parse_args()

    # Load configuration and merge CLI dot-list overrides on top of it.
    config = OmegaConf.load(args.config)
    cli_config = OmegaConf.from_dotlist(dynamic_args)
    config = OmegaConf.merge(config, cli_config)

    # Override Global configuration based on command line arguments.
    _apply_mode_overrides(config, args)

    # Save the effective configuration (rank 0 only in distributed runs).
    if dist.get_rank() == 0:
        os.makedirs(config.Trainer.output_dir, exist_ok=True)
        config_name = os.path.basename(args.config)
        OmegaConf.save(config, osp.join(config.Trainer.output_dir, config_name))

    # Convert to a plain dictionary for the rest of the pipeline.
    config = OmegaConf.to_container(config, resolve=True)

    # Initialize logging.
    logger_path = osp.join(config["Trainer"]["output_dir"], "run.log")
    logger.init_logger(log_file=logger_path)
    logger.info(f"Logger saved to {logger_path}")
    logger.info(f"Config: {config}")

    # Set random seed for reproducibility.
    seed = config["Trainer"].get("seed", 42)
    misc.set_random_seed(seed)
    logger.info(f"Set random seed to {seed}")

    set_signal_handlers()

    # Build data loaders for the enabled modes.
    dataloaders = _build_dataloaders(config)

    # Build model.
    model_cfg = config["Model"]
    model = build_model(model_cfg)
    logger.info(f"Model built: {model_cfg['__class_name__']}")

    # Parameter counts; Paddle marks frozen parameters via stop_gradient=True.
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if not p.stop_gradient)
    logger.info(f"Total parameters: {total_params / 1e6:.2f}M")
    logger.info(f"Trainable parameters: {trainable_params / 1e6:.2f}M")

    # Build optimizer and learning rate scheduler (training runs only).
    optimizer = None
    lr_scheduler = None
    if config.get("Optimizer") is not None and config["Global"].get("do_train", True):
        assert dataloaders.get("train") is not None, "train_loader must be defined when optimizer is defined"
        assert config["Trainer"].get("max_epochs") is not None, "max_epochs must be defined"

        optimizer, lr_scheduler = build_optimizer(
            config["Optimizer"],
            model,
            config["Trainer"]["max_epochs"],
            len(dataloaders["train"]),
        )
        logger.info(f"Optimizer built: {config['Optimizer']['__class_name__']}")

    # Build trainer.
    trainer = ECFormerTrainer(
        config=config["Trainer"],
        model=model,
        train_dataloader=dataloaders.get("train"),
        val_dataloader=dataloaders.get("val"),
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
    )

    # Resume from checkpoint if requested.
    if args.resume is not None:
        logger.info(f"Resuming from checkpoint: {args.resume}")
        save_load.load_checkpoint(
            args.resume,
            model,
            optimizer,
            trainer.scaler,
        )

    # Execute training / evaluation / prediction.
    if config["Global"].get("do_train", True):
        logger.info("Starting training...")
        trainer.train()

    if config["Global"].get("do_eval", False):
        logger.info("Evaluating on validation set...")
        if "val" in dataloaders:
            # Only the metric dict is reported; timing/loss meters are unused here.
            _, _, metric_info = trainer.eval(dataloaders["val"])
            _log_metric_summary("Validation Results:", metric_info)
        else:
            logger.warning("No validation dataloader found, skipping evaluation.")

    if config["Global"].get("do_test", False):
        logger.info("Evaluating on test set...")
        if "test" in dataloaders:
            _, _, metric_info = trainer.eval(dataloaders["test"])
            _log_metric_summary("Test Results:", metric_info)
        else:
            logger.warning("No test dataloader found, skipping test evaluation.")

    if config["Global"].get("do_predict", False):
        logger.info("Running prediction...")
        if "predict" in dataloaders:
            results = trainer.predict(dataloaders["predict"])

            # Persist prediction results alongside the run artifacts.
            # NOTE(review): assumes `results` is JSON-serializable — confirm
            # against ECFormerTrainer.predict's return type.
            output_path = osp.join(config["Trainer"]["output_dir"], "predictions.json")
            with open(output_path, "w") as f:
                json.dump(results, f, indent=2)
            logger.info(f"Predictions saved to {output_path}")
        else:
            logger.warning("No prediction dataloader found, skipping prediction.")


if __name__ == "__main__":
    main()