Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,13 @@ electronic_structure/configs/qm9_data_split.json
electronic_structure/configs/qm9.json
electronic_structure/configs/crystal_data_split.json
electronic_structure/configs/crystal.json

# ignore export test data
*.npz
# ignore CIF crystal structure files
*.cif
# ignore paddle weight files.
*.pdparams
.claude/

*.pkl
12 changes: 12 additions & 0 deletions ppmat/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@
from ppmat.models.megnet.megnet import MEGNetPlus
from ppmat.models.infgcn.infgcn import InfGCN
from ppmat.models.mateno.mateno import MatENO
from ppmat.models.chemeleon2 import VAEModule
from ppmat.models.chemeleon2 import LDMModule
from ppmat.models.chemeleon2 import RLModule
from ppmat.models.chemeleon2.ldm_module.dit import DiT
from ppmat.models.chemeleon2.vae_module.encoder import TransformerEncoder
from ppmat.models.chemeleon2.vae_module.decoder import TransformerDecoder
from ppmat.utils import download
from ppmat.utils import logger
from ppmat.utils import save_load
Expand All @@ -67,6 +73,12 @@
"DiffNMR",
"InfGCN",
"MatENO",
"VAEModule",
"LDMModule",
"RLModule",
"DiT",
"TransformerEncoder",
"TransformerDecoder",
]

# Warning: The key of the dictionary must be consistent with the file name of the value
Expand Down
9 changes: 9 additions & 0 deletions ppmat/models/chemeleon2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from ppmat.models.chemeleon2.vae_module.vae import VAEModule
from ppmat.models.chemeleon2.ldm_module.ldm import LDMModule
from ppmat.models.chemeleon2.rl_module.rl import RLModule

# Names exported by `from ppmat.models.chemeleon2 import *`; each one is
# imported from its submodule at the top of this file.
__all__ = [
"VAEModule",
"LDMModule",
"RLModule",
]
31 changes: 31 additions & 0 deletions ppmat/models/chemeleon2/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from ppmat.models.chemeleon2.common.distributions import DiagonalGaussianDistribution
from ppmat.models.chemeleon2.common.schema import CrystalBatch
from ppmat.models.chemeleon2.common.scatter import scatter_mean
from ppmat.models.chemeleon2.common.scatter import scatter_sum
from ppmat.models.chemeleon2.common.scatter import scatter_std
from ppmat.models.chemeleon2.common.lattice_utils import lattice_params_to_matrix
from ppmat.models.chemeleon2.common.lattice_utils import matrix_to_lattice_params
from ppmat.models.chemeleon2.common.lattice_utils import frac_to_cart_coords
from ppmat.models.chemeleon2.common.lattice_utils import cart_to_frac_coords
from ppmat.models.chemeleon2.common.lattice_utils import get_pbc_distances
from ppmat.models.chemeleon2.common.lattice_utils import lattice_vector_to_volume
from ppmat.models.chemeleon2.common.batch_utils import to_dense_batch
from ppmat.models.chemeleon2.common.data_augmentation import apply_augmentation
from ppmat.models.chemeleon2.common.data_augmentation import apply_noise

# Names exported by `from ppmat.models.chemeleon2.common import *`; each one
# is imported from its helper submodule at the top of this file.
__all__ = [
"DiagonalGaussianDistribution",
"CrystalBatch",
"scatter_mean",
"scatter_sum",
"scatter_std",
"lattice_params_to_matrix",
"matrix_to_lattice_params",
"frac_to_cart_coords",
"cart_to_frac_coords",
"get_pbc_distances",
"lattice_vector_to_volume",
"to_dense_batch",
"apply_augmentation",
"apply_noise",
]
41 changes: 41 additions & 0 deletions ppmat/models/chemeleon2/common/batch_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import paddle


def to_dense_batch(x, batch_idx, max_num_nodes=None):
    """Convert flat per-node features into a zero-padded dense batch and a mask.

    Rows of ``x`` are copied in order: the first ``count[0]`` rows go to
    structure 0, the next ``count[1]`` rows to structure 1, and so on, where
    ``count[i]`` is the number of entries of ``batch_idx`` equal to ``i``.
    ``batch_idx`` is therefore expected to be sorted in ascending order.

    Args:
        x: Feature tensor of shape [N, D], N being the total number of nodes
            and D the feature dimension.
        batch_idx: Integer tensor of shape [N] giving the structure index of
            each node.
        max_num_nodes: Length of the padded node axis. When ``None`` it is
            set to the size of the largest structure.

    Returns:
        x_dense: Tensor of shape [B, max_num_nodes, D], zero-padded.
        mask: Boolean tensor of shape [B, max_num_nodes]; ``True`` marks a
            real node and ``False`` marks padding.

    Raises:
        ValueError: If ``max_num_nodes`` is smaller than the largest
            structure in the batch.
    """
    if batch_idx.numel() > 0:
        batch_size = int(batch_idx.max().item()) + 1
        # One vectorised count and a single host transfer instead of one
        # full-tensor comparison and several ``.item()`` syncs per structure.
        counts = paddle.bincount(batch_idx, minlength=batch_size).tolist()
    else:
        # An empty input is treated as a single structure with no nodes.
        batch_size = 1
        counts = [0]

    largest = max(counts)
    if max_num_nodes is None:
        max_num_nodes = largest
    elif max_num_nodes < largest:
        raise ValueError(
            f"max_num_nodes={max_num_nodes} is smaller than the largest "
            f"structure in the batch ({largest} nodes)"
        )

    feat_dim = x.shape[-1]
    x_dense = paddle.zeros([batch_size, max_num_nodes, feat_dim], dtype=x.dtype)
    mask = paddle.zeros([batch_size, max_num_nodes], dtype='bool')

    start = 0
    for i, n in enumerate(counts):
        if n:
            x_dense[i, :n] = x[start:start + n]
            mask[i, :n] = True
        start += n

    return x_dense, mask

107 changes: 107 additions & 0 deletions ppmat/models/chemeleon2/common/data_augmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import paddle


def apply_augmentation(batch, translate=False, rotate=False):
    """Return an augmented copy of ``batch``.

    When neither augmentation is requested the input object itself is
    returned without copying. Otherwise the batch is cloned and the enabled
    transforms are applied to the clone, translation first and rotation
    second.

    Args:
        batch: Batch object exposing ``clone()`` and the attributes used by
            the translate/rotate helpers.
        translate: Apply a random translation when true.
        rotate: Apply a random rotation when true.

    Returns:
        The original ``batch`` (no augmentation) or the augmented clone.
    """
    if not (translate or rotate):
        return batch

    augmented = batch.clone()
    pipeline = (
        (translate, _augmentation_translate),
        (rotate, _augmentation_rotate),
    )
    for enabled, transform in pipeline:
        if enabled:
            augmented = transform(augmented)
    return augmented


def _augmentation_translate(batch):
    """Translate all atoms by a random Cartesian offset, modifying ``batch``.

    The offset is drawn from a normal distribution whose mean and standard
    deviation are the absolute mean and (population) standard deviation of
    ``batch.lengths`` along axis 0, then halved. Fractional coordinates are
    recomputed from the shifted Cartesian coordinates and wrapped into
    [0, 1); the stored Cartesian coordinates are left unwrapped.

    NOTE(review): presumably ``batch.lengths`` is [B, 3], so a single 3-vector
    is shared by every atom in the batch -- confirm against CrystalBatch.

    Args:
        batch: Batch whose ``cart_coords`` and ``frac_coords`` are overwritten.

    Returns:
        The same ``batch`` object.
    """
    # Floor the spread at 1e-8 so the sampler never receives a zero std.
    spread = paddle.maximum(
        paddle.abs(batch.lengths.std(axis=0, unbiased=False)),
        paddle.to_tensor([1e-8]),
    )
    centre = paddle.abs(batch.lengths.mean(axis=0))
    shift = paddle.normal(mean=centre, std=spread) / 2

    shifted_cart = batch.cart_coords + shift

    # One inverse cell per atom, gathered through the atom -> structure index.
    inv_cells = paddle.inverse(batch.lattices[batch.batch])
    wrapped_frac = paddle.einsum('bi,bij->bj', shifted_cart, inv_cells) % 1.0

    batch.cart_coords = shifted_cart
    batch.frac_coords = wrapped_frac
    return batch


def _augmentation_rotate(batch):
    """Rotate Cartesian coordinates and lattices by one random rotation.

    Both ``batch.cart_coords`` and ``batch.lattices`` are right-multiplied by
    the transpose of the same validated rotation matrix and written back to
    ``batch``. Fractional coordinates are not touched: rotating the atoms and
    the cell vectors together leaves them unchanged.

    Args:
        batch: Batch whose ``cart_coords`` and ``lattices`` are overwritten.

    Returns:
        The same ``batch`` object.
    """
    rotation_t = _random_rotation_matrix(validate=True).T

    # Evaluate both products before assigning either attribute.
    batch.cart_coords, batch.lattices = (
        paddle.matmul(batch.cart_coords, rotation_t),
        paddle.matmul(batch.lattices, rotation_t),
    )
    return batch


def apply_noise(batch, ratio=0.1, corruption_scale=0.1):
    """Return a copy of ``batch`` with a fraction of its atoms corrupted.

    ``int(num_nodes * ratio)`` atoms get their atom type set to 0, and an
    independently drawn subset of the same size gets Gaussian noise (scaled by
    ``corruption_scale``) added to its Cartesian coordinates. Fractional
    coordinates are then recomputed from the noisy Cartesian coordinates and
    wrapped into [0, 1).

    NOTE(review): atom type 0 presumably acts as a mask/unknown token --
    confirm against the model's vocabulary.

    Args:
        batch: Batch exposing ``clone()``, ``num_nodes``, ``atom_types``,
            ``cart_coords``, ``lattices`` and ``batch``.
        ratio: Fraction of atoms to corrupt; values ``<= 0`` return the input
            object unchanged.
        corruption_scale: Multiplier applied to the standard-normal
            coordinate noise.

    Returns:
        The input ``batch`` itself when ``ratio <= 0``, otherwise a corrupted
        clone.
    """
    if ratio <= 0:
        return batch

    noisy = batch.clone()
    n_atoms = noisy.num_nodes
    n_corrupt = int(n_atoms * ratio)

    # Atom types: zero out one random subset.
    types = noisy.atom_types.clone()
    type_idx = paddle.randperm(n_atoms)[:n_corrupt]
    types[type_idx] = 0

    # Positions: perturb a second, independently drawn subset.
    coords = noisy.cart_coords.clone()
    coord_idx = paddle.randperm(n_atoms)[:n_corrupt]
    coords[coord_idx] += paddle.randn([n_corrupt, 3]) * corruption_scale

    # Keep fractional coordinates consistent with the perturbed positions.
    inv_cells = paddle.inverse(batch.lattices[batch.batch])
    frac = paddle.einsum('bi,bij->bj', coords, inv_cells) % 1.0

    noisy.atom_types = types
    noisy.cart_coords = coords
    noisy.frac_coords = frac
    return noisy


def _random_rotation_matrix(validate=False):
    """Draw a random 3x3 rotation matrix from a random unit quaternion.

    Args:
        validate: When true, check that the matrix is orthonormal
            (``R @ R.T`` close to the identity) before returning it.

    Returns:
        A float32 tensor of shape [3, 3].

    Raises:
        AssertionError: If ``validate`` is true and the matrix is not
            orthonormal within ``atol = rtol = 1e-5``.
    """
    # Normalising four i.i.d. standard normals gives a point uniformly
    # distributed on the unit 3-sphere, i.e. a uniformly random rotation.
    # Uniform [0, 1) samples would keep every quaternion component
    # non-negative and confine the rotation axes to a single octant.
    q = paddle.randn([4])
    q = q / paddle.norm(q)

    # Work with plain floats so the matrix is built from a nested list of
    # numbers rather than a nested list of tensors.
    w, x, y, z = (float(v) for v in q.tolist())

    rot_mat = paddle.to_tensor([
        [
            1 - 2 * y ** 2 - 2 * z ** 2,
            2 * x * y - 2 * w * z,
            2 * x * z + 2 * w * y,
        ],
        [
            2 * x * y + 2 * w * z,
            1 - 2 * x ** 2 - 2 * z ** 2,
            2 * y * z - 2 * w * x,
        ],
        [
            2 * x * z - 2 * w * y,
            2 * y * z + 2 * w * x,
            1 - 2 * x ** 2 - 2 * y ** 2,
        ],
    ], dtype='float32')

    if validate:
        identity = paddle.matmul(rot_mat, rot_mat.T)
        eye = paddle.eye(3)
        # Raised explicitly so the requested check survives ``python -O``.
        if not bool(paddle.allclose(identity, eye, atol=1e-5, rtol=1e-5)):
            raise AssertionError("Not a rotation matrix.")

    return rot_mat
38 changes: 38 additions & 0 deletions ppmat/models/chemeleon2/common/distributions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import paddle


class DiagonalGaussianDistribution:
    """Gaussian with diagonal covariance, described by mean and log-variance.

    ``parameters`` is split into two equal halves along axis 1: the first half
    is the mean and the second the log-variance. The log-variance is clipped
    to [-30, 20] before the variance and standard deviation are derived.
    """

    def __init__(self, parameters):
        self.parameters = parameters
        mean, logvar = paddle.chunk(parameters, 2, axis=1)
        logvar = paddle.clip(logvar, -30.0, 20.0)
        self.mean = mean
        self.logvar = logvar
        self.std = paddle.exp(0.5 * logvar)
        self.var = paddle.exp(logvar)

    def sample(self):
        """Draw one reparameterised sample: ``mean + std * eps``."""
        eps = paddle.randn(self.mean.shape)
        return self.mean + self.std * eps

    def kl(self, other=None):
        """Return KL(self || other), or KL(self || N(0, I)) when ``other`` is None.

        The divergence is summed over axes [1, 2, 3] for 4-D means
        (image-like tensors) and over axis 1 otherwise (latent vectors).
        """
        reduce_axes = [1, 2, 3] if self.mean.ndim == 4 else [1]

        if other is None:
            terms = paddle.pow(self.mean, 2) + self.var - 1.0 - self.logvar
        else:
            terms = (
                paddle.pow(self.mean - other.mean, 2) / other.var
                + self.var / other.var - 1.0 - self.logvar + other.logvar
            )
        return 0.5 * paddle.sum(terms, axis=reduce_axes)

    def mode(self):
        """Return the distribution's mode, which is its mean."""
        return self.mean
67 changes: 67 additions & 0 deletions ppmat/models/chemeleon2/common/lattice_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import paddle
from ppmat.utils.crystal import lattice_params_to_matrix_paddle
from ppmat.utils.crystal import lattices_to_params_shape


def lattice_params_to_matrix(lengths, angles):
    """Build lattice matrices from cell lengths and angles.

    Thin wrapper that forwards both arguments unchanged to
    ``ppmat.utils.crystal.lattice_params_to_matrix_paddle`` and returns its
    result.
    """
    matrix = lattice_params_to_matrix_paddle(lengths, angles)
    return matrix


def matrix_to_lattice_params(lattices):
    """Convert lattice matrices back to lattice parameters.

    Thin wrapper that forwards ``lattices`` unchanged to
    ``ppmat.utils.crystal.lattices_to_params_shape`` and returns its result.
    """
    params = lattices_to_params_shape(lattices)
    return params


def frac_to_cart_coords(frac_coords, lattice):
    """Convert fractional coordinates to Cartesian coordinates.

    Computes ``frac_coords @ lattice`` (coordinates as row vectors).

    Args:
        frac_coords: Tensor of shape [N, 3], or [B, N, 3] when used with a
            batched lattice.
        lattice: Tensor of shape [3, 3] or [1, 3, 3] for a single cell, or
            [B, 3, 3] together with 3-D ``frac_coords``.

    Returns:
        Cartesian coordinates with the same shape as ``frac_coords``.

    Raises:
        ValueError: If 2-D coordinates are combined with a batch of more than
            one lattice, since the atom-to-cell assignment is then ambiguous.
    """
    if lattice.ndim == 3 and frac_coords.ndim == 2:
        if lattice.shape[0] != 1:
            raise ValueError(
                "frac_to_cart_coords got 2-D coordinates with a batch of "
                f"{lattice.shape[0]} lattices; pass a single lattice or "
                "coordinates of shape [B, N, 3]"
            )
        # Drop the singleton batch axis so the result stays [N, 3].
        lattice = lattice[0]
    return paddle.matmul(frac_coords, lattice)


def cart_to_frac_coords(cart_coords, lattice):
    """Convert Cartesian coordinates to fractional coordinates.

    Computes ``cart_coords @ inverse(lattice)`` (coordinates as row vectors).
    The result is not wrapped into [0, 1).

    Args:
        cart_coords: Tensor of shape [N, 3], or [B, N, 3] when used with a
            batched lattice.
        lattice: Tensor of shape [3, 3] or [1, 3, 3] for a single cell, or
            [B, 3, 3] together with 3-D ``cart_coords``.

    Returns:
        Fractional coordinates with the same shape as ``cart_coords``.

    Raises:
        ValueError: If 2-D coordinates are combined with a batch of more than
            one lattice, since the atom-to-cell assignment is then ambiguous.
    """
    if lattice.ndim == 3 and cart_coords.ndim == 2:
        if lattice.shape[0] != 1:
            raise ValueError(
                "cart_to_frac_coords got 2-D coordinates with a batch of "
                f"{lattice.shape[0]} lattices; pass a single lattice or "
                "coordinates of shape [B, N, 3]"
            )
        # Drop the singleton batch axis so the result stays [N, 3].
        lattice = lattice[0]
    return paddle.matmul(cart_coords, paddle.inverse(lattice))


def get_pbc_distances(
    coords1,
    coords2,
    lattice,
    num_atoms=None,
    return_offsets=False,
):
    """Distances between paired Cartesian points under periodic boundaries.

    The displacement ``coords2 - coords1`` is converted to fractional
    coordinates, each component is reduced by its nearest integer (so it lies
    within half a cell of zero), and the result is converted back to Cartesian
    space before taking the Euclidean norm. For strongly skewed cells this
    component-wise reduction is not guaranteed to give the shortest image.

    Args:
        coords1: Cartesian coordinates of the first points.
        coords2: Cartesian coordinates of the second points; must have the
            same shape as ``coords1``.
        lattice: Lattice matrix of shape [3, 3] or [1, 3, 3].
        num_atoms: Currently unused; kept for interface compatibility.
        return_offsets: When true, also return the integer cell offsets that
            were subtracted from the fractional displacement.

    Returns:
        ``distances``, or ``(distances, offsets)`` when ``return_offsets`` is
        true.

    Raises:
        ValueError: If ``coords1`` and ``coords2`` differ in shape.
    """
    if lattice.ndim == 2:
        lattice = lattice.unsqueeze(0)

    if coords1.shape != coords2.shape:
        raise ValueError("coords1 and coords2 must have the same shape")

    diff_frac = cart_to_frac_coords(coords2 - coords1, lattice)

    # The rounded fractional displacement is both the wrap correction and the
    # cell offset, so it is computed once and reused.
    offsets = paddle.round(diff_frac)
    diff_cart = frac_to_cart_coords(diff_frac - offsets, lattice)

    distances = paddle.norm(diff_cart, axis=-1)

    if return_offsets:
        return distances, offsets

    return distances


def lattice_vector_to_volume(lattice):
    """Return the cell volume ``|a . (b x c)|`` for each lattice.

    Args:
        lattice: Tensor of shape [3, 3] or [B, 3, 3] whose rows are the
            lattice vectors a, b and c.

    Returns:
        Tensor of shape [B] holding the (non-negative) volumes; a single
        [3, 3] input yields shape [1].
    """
    if lattice.ndim == 2:
        lattice = lattice.unsqueeze(0)

    a = lattice[:, 0, :]
    b = lattice[:, 1, :]
    c = lattice[:, 2, :]

    # The axis must be explicit: paddle.cross defaults to the *first* axis of
    # length 3, which is the batch axis whenever exactly three lattices are
    # passed, and would silently produce wrong volumes.
    b_cross_c = paddle.cross(b, c, axis=-1)
    volume = paddle.abs(paddle.sum(a * b_cross_c, axis=-1))

    return volume
Loading