Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 54 additions & 19 deletions src/somd2/runner/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ def __init__(self, system, config):

# Make sure the system contains perturbable molecules.
try:
self._system.molecules("property is_perturbable")
atoms = self._system["property is_perturbable"].atoms()
pert_idxs = self._system.atoms().find(atoms)
except KeyError:
msg = "No perturbable molecules in the system"
_logger.error(msg)
Expand Down Expand Up @@ -265,7 +266,18 @@ def __init__(self, system, config):
from ghostly import modify

_logger.info("Applying modifications to ghost atom bonded terms")
self._system, self._modifications = modify(self._system)
try:
self._system, self._modifications = modify(self._system)
# Angle optimisation can sometimes fail.
except Exception as e1:
try:
self._system, self._modifications = modify(
self_system, optimise_angles=False
)
except Exception as e2:
msg = f"Unable to apply modifications to ghost atom bonded terms: {e1}; {e2}"
_logger.error(msg)
raise RuntimeError(msg)

# Check for a periodic space.
self._has_space = self._check_space()
Expand Down Expand Up @@ -363,6 +375,7 @@ def __init__(self, system, config):
from math import isclose

# Set the REST2 scale factors.
is_rest2 = False
if self._config.rest2_scale is not None:
# Single value. Interpolate between 1.0 at the end states and rest2_scale
# at lambda = 0.5.
Expand Down Expand Up @@ -396,6 +409,45 @@ def __init__(self, system, config):
raise ValueError(msg)
self._rest2_scale_factors = self._config.rest2_scale

# If there are any non-zero REST2 scale factors, then log it.
if any(
not isclose(factor, 1.0, abs_tol=1e-4)
for factor in self._rest2_scale_factors
):
is_rest2 = True
_logger.info(f"REST2 scaling factors: {self._rest2_scale_factors}")

# Make sure the REST2 selection is valid.
if self._config.rest2_selection is not None:

try:
atoms = _sr.mol.selection_to_atoms(
self._system, self._config.rest2_selection
)
except:
msg = "Invalid 'rest2_selection' value."
_logger.error(msg)
raise ValueError(msg)

# Make sure the user hasn't selected all atoms.
if len(atoms) == self._system.num_atoms():
msg = "REST2 selection cannot contain all atoms in the system."
_logger.error(msg)
raise ValueError(msg)

# Get the atom indices.
idxs = self._system.atoms().find(atoms)

# If no indices are in the perturbable region, then add them.
if not any(i in pert_idxs for i in idxs):
idxs = sorted(pert_idxs + idxs)
else:
idxs = pert_idxs

# Log the atom indices in the REST2 selection.
if is_rest2:
_logger.info(f"REST2 selection contains {len(atoms)} atoms: {idxs}")

# Apply hydrogen mass repartitioning.
if self._config.hmr:
# Work out the current hydrogen mass factor.
Expand Down Expand Up @@ -444,23 +496,6 @@ def __init__(self, system, config):
self._system, self._config.h_mass_factor
)

# Make sure the REST2 selection is valid.
if self._config.rest2_selection is not None:
from sire.mol import selection_to_atoms

try:
atoms = selection_to_atoms(self._system, self._config.rest2_selection)
except:
msg = "Invalid 'rest2_selection' value."
_logger.error(msg)
raise ValueError(msg)

# Make sure the user hasn't selected all atoms.
if len(atoms) == self._system.num_atoms():
msg = "REST2 selection cannot contain all atoms in the system."
_logger.error(msg)
raise ValueError(msg)

# Flag whether this is a GPU simulation.
self._is_gpu = self._config.platform in ["cuda", "opencl", "hip"]

Expand Down