diff --git a/src/flashmd/ase/bussi.py b/src/flashmd/ase/bussi.py index e773e07..63cbb12 100644 --- a/src/flashmd/ase/bussi.py +++ b/src/flashmd/ase/bussi.py @@ -17,9 +17,10 @@ def __init__( time_constant: float = 10.0 * ase.units.fs, device: str | torch.device = "auto", rescale_energy: bool = True, + random_rotation: bool = False, **kwargs, ): - super().__init__(atoms, timestep, model, device, rescale_energy, **kwargs) + super().__init__(atoms, timestep, model, device, rescale_energy, random_rotation, **kwargs) self.temperature_K = temperature_K self.time_constant = time_constant diff --git a/src/flashmd/ase/langevin.py b/src/flashmd/ase/langevin.py index ee71890..2bcd231 100644 --- a/src/flashmd/ase/langevin.py +++ b/src/flashmd/ase/langevin.py @@ -19,9 +19,10 @@ def __init__( time_constant: float = 100.0 * ase.units.fs, device: str | torch.device = "auto", rescale_energy: bool = True, + random_rotation: bool = False, **kwargs, ): - super().__init__(atoms, timestep, model, device, rescale_energy, **kwargs) + super().__init__(atoms, timestep, model, device, rescale_energy, random_rotation, **kwargs) self.temperature_K = temperature_K self.friction = 1.0 / time_constant diff --git a/src/flashmd/ase/velocity_verlet.py b/src/flashmd/ase/velocity_verlet.py index b7f7452..c9f3433 100644 --- a/src/flashmd/ase/velocity_verlet.py +++ b/src/flashmd/ase/velocity_verlet.py @@ -8,9 +8,9 @@ from metatensor.torch.atomistic import System import ase from ..stepper import FlashMDStepper -import numpy as np - +import numpy as np +from scipy.spatial.transform import Rotation class VelocityVerlet(MolecularDynamics): def __init__( self, @@ -19,6 +19,7 @@ def __init__( model: MetatensorAtomisticModel | List[MetatensorAtomisticModel], device: str | torch.device = "auto", rescale_energy: bool = True, + random_rotation: bool = False, **kwargs, ): super().__init__(atoms, timestep, **kwargs) @@ -47,15 +48,40 @@ def __init__( self.stepper = FlashMDStepper(models, n_time_steps, self.device) self.rescale_energy = rescale_energy + self.random_rotation = random_rotation def step(self): + if self.rescale_energy: old_energy = self.atoms.get_total_energy() system = _convert_atoms_to_system( self.atoms, device=self.device, dtype=self.dtype ) + + if self.random_rotation: + # generate a random rotation matrix with SciPy + R = torch.tensor( + _random_R(), + device=system.positions.device, + dtype=system.positions.dtype, + ) + # apply the random rotation + old_cell = system.cell + system.cell = system.cell @ R.T + system.positions = system.positions @ R.T + momenta = system.get_data("momenta").block(0).values.squeeze() + momenta[:] = momenta @ R.T # does the change in place + new_system = self.stepper.step(system) + + if self.random_rotation: + # revert q, p to the original reference frame, load old cell + new_system.cell = old_cell + new_system.positions = system.positions @ R + new_momenta = new_system.get_data("momenta").block(0).values.squeeze() + new_momenta[:] = new_momenta @ R + self.atoms.set_positions(new_system.positions.detach().cpu().numpy()) self.atoms.set_momenta( new_system.get_data("momenta") @@ -126,3 +152,10 @@ def _convert_atoms_to_system( ), ) return system + + +def _random_R(): + R = Rotation.random().as_matrix() + if np.random.rand() < 0.5: + R[:, 0] *= -1 + return R