diff --git a/dmff/admp/pairwise.py b/dmff/admp/pairwise.py old mode 100755 new mode 100644 index 2af81574..e263d047 --- a/dmff/admp/pairwise.py +++ b/dmff/admp/pairwise.py @@ -63,7 +63,7 @@ def generate_pairwise_interaction(pair_int_kernel, static_args): with the order in kernel ''' - def pair_int(positions, box, pairs, mScales, *atomic_params): + def pair_int(positions, box, pairs, mScales, *atomic_params): # pairs = regularize_pairs(pairs) pairs = pairs.at[:, :2].set(regularize_pairs(pairs[:, :2])) @@ -77,7 +77,7 @@ def pair_int(positions, box, pairs, mScales, *atomic_params): buffer_scales = pair_buffer_scales(pairs) mscales = mscales * buffer_scales # mscales = mScales[nbonds-1] - box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36) + box_inv = jnp.linalg.inv(box) dr = ri - rj dr = v_pbc_shift(dr, box, box_inv) dr = jnp.linalg.norm(dr, axis=1) @@ -89,7 +89,7 @@ def pair_int(positions, box, pairs, mScales, *atomic_params): # pair_params.append(param[pairs[:, 0]]) # pair_params.append(param[pairs[:, 1]]) - energy = jnp.sum(pair_int_kernel(dr, mscales, *pair_params) * buffer_scales) + energy = jnp.sum(pair_int_kernel(dr, mscales, *pair_params) * buffer_scales) return energy return pair_int @@ -155,7 +155,9 @@ def slater_disp_damping_kernel(dr, m, bi, bj, c6i, c6j, c8i, c8j, c10i, c10j): @vmap @jit_condition(static_argnums=()) -def slater_sr_kernel(dr, m, ai, aj, bi, bj): +# with hardcore potential +def slater_sr_hc_kernel(dr, m, ai, aj, bi, bj): + ''' Slater-ISA type short range terms see jctc 12 3851 @@ -165,5 +167,30 @@ def slater_sr_kernel(dr, m, ai, aj, bi, bj): br = b * dr br2 = br * br P = 1/3 * br2 + br + 1 - return a * P * jnp.exp(-br) * m + alpha = 0.24 + beta = 14 + x = alpha * br + x2 = x * x + x4 = x2 * x2 + x8 = x4 * x4 + x12 = x4 * x8 + x14 = x12 * x2 + HardCorePotential = a / x14 * m + return a * P * jnp.exp(-br) * m + HardCorePotential + +@vmap +@jit_condition(static_argnums=()) +def slater_sr_kernel(dr, m, ai, aj, bi, bj): + + ''' + Slater-ISA type short range terms + see jctc 12 3851 + ''' + b = jnp.sqrt(bi * bj) + a = ai * aj + br = b * dr + br2 = br * br + P = 1/3 * br2 + br + 1 + + return a * P * jnp.exp(-br) * m diff --git a/dmff/admp/pme.py b/dmff/admp/pme.py index 60b5d76e..875e7e54 100755 --- a/dmff/admp/pme.py +++ b/dmff/admp/pme.py @@ -32,6 +32,15 @@ DEFAULT_THOLE_WIDTH = 5.0 +# variables used in soft dipole truncation +MAX_DIP = 1.0 +TRUNCATION_HARDNESS = 25 # the smaller, the softer +def SOFT_TRUNCATION(x): + x2 = x * x + x_abs = jnp.sqrt(x2 + 1e-6) + val = -1/TRUNCATION_HARDNESS * jnp.log(1 + jnp.exp(-TRUNCATION_HARDNESS*(x_abs-MAX_DIP))) + MAX_DIP + return val * x/x_abs + class ADMPPmeForce: """ @@ -416,6 +425,8 @@ def update_U(i, U): dScales, ) U = U - field * pol[:, jnp.newaxis] / DIELECTRIC + # soft truncation: stop polarization catastrophe + U = SOFT_TRUNCATION(U) return U U = jax.lax.fori_loop(0, steps_pol, update_U, U) diff --git a/dmff/generators/admp.py b/dmff/generators/admp.py index 29ab2cef..456eb489 100644 --- a/dmff/generators/admp.py +++ b/dmff/generators/admp.py @@ -13,7 +13,8 @@ TT_damping_qq_c6_kernel, generate_pairwise_interaction, slater_disp_damping_kernel, - slater_sr_kernel, + slater_sr_kernel, ## no Hard Core Potential + slater_sr_hc_kernel, ## added Hard Core Potential TT_damping_qq_kernel, ) from ..admp.pme import ADMPPmeForce @@ -759,20 +760,21 @@ def createPotential( topdata._meta[self.name+"_map_atomtype"] = map_atomtype - pot_fn_sr = generate_pairwise_interaction(slater_sr_kernel, static_args={}) + pot_fn_sr = generate_pairwise_interaction(slater_sr_hc_kernel, static_args={}) + #slater_ex_sr_kernel: added Hard Core Potential has_aux = False if "has_aux" in kwargs and kwargs["has_aux"]: has_aux = True - def potential_fn(positions, box, pairs, params, aux=None): + def potential_fn(positions, box, pairs, params, aux=None): positions = positions * 10 box = box * 10 params = params[self.name] a_list = params["A"][map_atomtype] b_list = params["B"][map_atomtype] / 10 # nm^-1 to A^-1 - energy = pot_fn_sr(positions, box, pairs, self.mScales, a_list, b_list) + energy = pot_fn_sr(positions, box, pairs, self.mScales, a_list, b_list) if has_aux: return energy, aux else: @@ -790,6 +792,7 @@ def getJaxPotential(self): _DMFFGenerators["SlaterExForce"] = SlaterExGenerator + # Here are all the short range "charge penetration" terms # They all have the exchange form with minus sign class SlaterSrEsGenerator(SlaterExGenerator): @@ -798,7 +801,6 @@ def __init__(self, ffinfo: dict, paramset: ParamSet, default_name=None): super().__init__(ffinfo, paramset, default_name="SlaterSrEsForce") else: super().__init__(ffinfo, paramset, default_name=default_name) - def createPotential( self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutoff, **kwargs ): @@ -812,14 +814,14 @@ def createPotential( topdata._meta[self.name+"_map_atomtype"] = map_atomtype - pot_fn_sr = generate_pairwise_interaction(slater_sr_kernel, - static_args={}) + pot_fn_sr = generate_pairwise_interaction(slater_sr_kernel, static_args={}) + ## slater_sr_others_kernel: no Hard Core Potential has_aux = False if "has_aux" in kwargs and kwargs["has_aux"]: has_aux = True - def potential_fn(positions, box, pairs, params, aux=None): + def potential_fn(positions, box, pairs, params, aux=None): positions = positions * 10 box = box * 10 params = params[self.name] @@ -934,10 +936,7 @@ def __init__(self, ffinfo: dict, paramset: ParamSet): for node in self.ffinfo["Forces"][self.name]["node"] if node["name"] in ["Multipole", "Atom"] ] - c0, dX, dY, dZ, qXX, qYY, qZZ, qXY, qXZ, qYZ, oXXX, oXXY, oXYY, oYYY, oXXZ, oXYZ, oYYZ, oXZZ, oYZZ, oZZZ = ( - [], - [], - [], + c0, dX, dY, dZ, qXX, qYY, qZZ, qXY, qXZ, qYZ = ( [], [], [], @@ -948,13 +947,6 @@ def __init__(self, ffinfo: dict, paramset: ParamSet): [], [], [], - [], - [], - [], - [], - [], - [], - [] ) kxs, kys, kzs = [], [], [] multipole_masks = [] @@ -997,29 +989,6 @@ def __init__(self, ffinfo: dict, paramset: ParamSet): qXY.append(0.0) qXZ.append(0.0) qYZ.append(0.0) - if self.lmax >= 3: - oXXX.append(float(attribs["oXXX"])) - oXXY.append(float(attribs["oXXY"])) - oXYY.append(float(attribs["oXYY"])) - oYYY.append(float(attribs["oYYY"])) - oXXZ.append(float(attribs["oXXZ"])) - oXYZ.append(float(attribs["oXYZ"])) - oYYZ.append(float(attribs["oYYZ"])) - oXZZ.append(float(attribs["oXZZ"])) - oYZZ.append(float(attribs["oYZZ"])) - oZZZ.append(float(attribs["oZZZ"])) - else: - oXXX.append(0.0) - oXXY.append(0.0) - oXYY.append(0.0) - oYYY.append(0.0) - oXXZ.append(0.0) - oXYZ.append(0.0) - oYYZ.append(0.0) - oXZZ.append(0.0) - oYZZ.append(0.0) - oZZZ.append(0.0) - mask = 1.0 if "mask" in attribs and attribs["mask"].upper() == "TRUE": mask = 0.0 @@ -1077,8 +1046,6 @@ def __init__(self, ffinfo: dict, paramset: ParamSet): n_mtps = 4 elif self.lmax == 2: n_mtps = 10 - elif self.lmax == 3: - n_mtps = 20 Q = np.zeros((n_atoms, n_mtps)) # TDDO: unit conversion @@ -1096,19 +1063,6 @@ def __init__(self, ffinfo: dict, paramset: ParamSet): Q[:, 8] = qXZ Q[:, 9] = qYZ Q[:, 4:10] *= 300 - if self.lmax >= 3: - Q[:, 10] = oXXX - Q[:, 11] = oXXY - Q[:, 12] = oXYY - Q[:, 13] = oYYY - Q[:, 14] = oXXZ - Q[:, 15] = oXYZ - Q[:, 16] = oYYZ - Q[:, 17] = oXZZ - Q[:, 18] = oYZZ - Q[:, 19] = oZZZ - # TO DO: To be decided - Q[:, 10:20] *= 15000 # add all differentiable params to self.params Q_local = convert_cart2harm(jnp.array(Q), self.lmax) @@ -1138,18 +1092,6 @@ def overwrite(self, paramset): node["attrib"]["qXY"] = Q_global[n_multipole, 7] / 300.0 node["attrib"]["qXZ"] = Q_global[n_multipole, 8] / 300.0 node["attrib"]["qYZ"] = Q_global[n_multipole, 9] / 300.0 - if self.lmax >= 3: - node["attrib"]["oXXX"] = Q_global[n_multipole, 10] / 15000.0 - node["attrib"]["oXXY"] = Q_global[n_multipole, 11] / 15000.0 - node["attrib"]["oXYY"] = Q_global[n_multipole, 12] / 15000.0 - node["attrib"]["oYYY"] = Q_global[n_multipole, 13] / 15000.0 - node["attrib"]["oXXZ"] = Q_global[n_multipole, 14] / 15000.0 - node["attrib"]["oXYZ"] = Q_global[n_multipole, 15] / 15000.0 - node["attrib"]["oYYZ"] = Q_global[n_multipole, 16] / 15000.0 - node["attrib"]["oXZZ"] = Q_global[n_multipole, 17] / 15000.0 - node["attrib"]["oYZZ"] = Q_global[n_multipole, 18] / 15000.0 - node["attrib"]["oZZZ"] = Q_global[n_multipole, 19] / 15000.0 - if q_local_masks[n_multipole] < 0.999: node["mask"] = "true" n_multipole += 1