diff --git a/docs/scripts/generate_notebooks.py b/docs/scripts/generate_notebooks.py index 6373f5e..4c0d91f 100644 --- a/docs/scripts/generate_notebooks.py +++ b/docs/scripts/generate_notebooks.py @@ -441,7 +441,7 @@ def generate_argyrodite_site_analysis(): "coords = np.array(\n" " [[0.5, 0.5, 0.5], # P (type 0) - PS4 tetrahedra\n" " [0.9, 0.9, 0.6], # type 1 (Li in reference)\n" - " [0.23, 0.92, 0.09], # type 2 (Mg in reference)\n" + " [0.77, 0.585, 0.585], # type 2 (Mg in reference, 48h)\n" " [0.25, 0.25, 0.25], # type 3 (Na in reference)\n" " [0.15, 0.15, 0.15], # type 4 (Be in reference)\n" " [0.0, 0.183, 0.183], # type 5 (K in reference)\n" @@ -710,7 +710,7 @@ def generate_residence_times_and_transitions(): "coords = np.array(\n" " [[0.5, 0.5, 0.5], # P (type 0) - PS4 tetrahedra\n" " [0.9, 0.9, 0.6], # type 1 (Li in reference)\n" - " [0.23, 0.92, 0.09], # type 2 (Mg in reference)\n" + " [0.77, 0.585, 0.585], # type 2 (Mg in reference, 48h)\n" " [0.25, 0.25, 0.25], # type 3 (Na in reference)\n" " [0.15, 0.15, 0.15], # type 4 (Be in reference)\n" " [0.0, 0.183, 0.183], # type 5 (K in reference)\n" diff --git a/docs/source/guides/builders.md b/docs/source/guides/builders.md index 1e2ce80..e82f79e 100644 --- a/docs/source/guides/builders.md +++ b/docs/source/guides/builders.md @@ -287,6 +287,25 @@ Sets species to use for mapping sites between reference and target structures. builder.with_site_mapping(mapping_species=["O", "Ti"]) ``` +#### `with_min_atom_distance(distance)` +Sets the minimum allowed distance between same-species atoms in the reference structure. + +When `build()` is called with a reference structure, the builder checks that no two atoms of the same species are closer than this threshold. This catches a common mistake where an atom coordinate sits on a general Wyckoff position instead of the correct special position, which produces pairs of atoms very close together and results in duplicate coordination environments. + +This check runs whenever a reference structure is set, regardless of site type. Set to 0 to disable the check. + +**Parameters:** +- `distance`: Minimum distance in the same units as the lattice parameters. Must be non-negative. Default is 0.5. + +**Examples:** +```python +# Tighten the threshold +builder.with_min_atom_distance(1.0) + +# Disable the check entirely +builder.with_min_atom_distance(0) +``` + #### `with_existing_sites(sites)` Uses pre-existing site objects instead of creating new ones. @@ -511,3 +530,33 @@ trajectory = builder.build() # Raises TypeError builder.with_polyhedral_sites(...) # Without with_reference_structure() trajectory = builder.build() # Raises ValueError ``` + +### Reference Structure Validation + +When a reference structure is set, `build()` checks for two problems that can produce incorrect site definitions: + +**Close same-species atoms**: If any pair of atoms of the same species are closer than `min_atom_distance` (default 0.5), the build fails with a `ValueError`. This typically means an atom coordinate is on a general Wyckoff position instead of the correct special position, producing near-duplicate atoms in the same coordination environment. + +```python +# This will raise ValueError because the Mg coordinate is on a +# general position (96i), producing Mg pairs ~0.14 apart +builder.with_reference_structure(bad_reference) + .with_polyhedral_sites(...) +trajectory = builder.build() +# ValueError: Reference structure has Mg atoms at indices 42 and 43 +# that are only 0.140 apart (threshold: 0.5). ... + +# To disable this check (e.g. if close atoms are intentional): +builder.with_min_atom_distance(0) +``` + +**Duplicate sites**: After generating sites, the builder checks that no two polyhedral sites share the same vertex indices and no two dynamic Voronoi sites share the same reference indices. Duplicates indicate that the reference structure has multiple atoms inside the same coordination environment. + +```python +# Even with the distance check disabled, duplicate sites are caught +builder.with_min_atom_distance(0) + .with_polyhedral_sites(...) +trajectory = builder.build() +# ValueError: Duplicate sites: site 0 and site 1 share the same +# vertex_indices [3, 7, 12, 15]. ... +``` diff --git a/docs/source/tutorials/argyrodite_site_definitions.ipynb b/docs/source/tutorials/argyrodite_site_definitions.ipynb index 75ec85d..43aa4f5 100644 --- a/docs/source/tutorials/argyrodite_site_definitions.ipynb +++ b/docs/source/tutorials/argyrodite_site_definitions.ipynb @@ -61,51 +61,10 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Reference structure contains 1664 atoms\n", - "Composition: K384 Na32 Li128 Mg768 Be128 P32 S192\n" - ] - } - ], - "source": [ - "# Create a reference structure with the argyrodite topology\n", - "# The key approach: use different atom types to differentiate each site type\n", - "# - P occupies the t0 tetrahedra (phosphorus in PS4 units)\n", - "# - Different dummy atoms (Li, Mg, Na, Be, K) occupy the t1-t5 tetrahedra\n", - "# to allow us to define each tetrahedral site type separately\n", - "# - S occupies all the anion sites\n", - "\n", - "lattice = Lattice.cubic(a=10.155) # Use the experimental lattice parameter\n", - "\n", - "coords = np.array(\n", - " [[0.5, 0.5, 0.5], # P (t0) - PS4 tetrahedra positions\n", - " [0.9, 0.9, 0.6], # t1 - first type of Li site (represented by Li atoms)\n", - " [0.23, 0.92, 0.09], # t2 - second type of Li site (represented by Mg atoms)\n", - " [0.25, 0.25, 0.25], # t3 - third type of Li site (represented by Na atoms)\n", - " [0.15, 0.15, 0.15], # t4 - fourth type of Li site (represented by Be atoms)\n", - " [0.0, 0.183, 0.183], # t5 - fifth type of Li site (represented by K atoms)\n", - " [0.0, 0.0, 0.0], # S - anion position (4a site)\n", - " [0.75, 0.25, 0.25], # S - anion position (4c site)\n", - " [0.11824, 0.11824, 0.38176]] # S - anion position (16e site)\n", - ") \n", - "\n", - "# Create the reference structure with F-43m space group symmetry\n", - "# and replicate it as a 2x2x2 supercell to match the MD simulations\n", - "reference_structure = Structure.from_spacegroup(\n", - " sg=\"F-43m\",\n", - " lattice=lattice,\n", - " species=['P', 'Li', 'Mg', 'Na', 'Be', 'K', 'S', 'S', 'S'],\n", - " coords=coords) * [2, 2, 2]\n", - "\n", - "print(f\"Reference structure contains {len(reference_structure)} atoms\")\n", - "print(f\"Composition: {reference_structure.composition.formula}\")" - ] + "outputs": [], + "source": "# Create a reference structure with the argyrodite topology\n# The key approach: use different atom types to differentiate each site type\n# - P occupies the t0 tetrahedra (phosphorus in PS4 units)\n# - Different dummy atoms (Li, Mg, Na, Be, K) occupy the t1-t5 tetrahedra\n# to allow us to define each tetrahedral site type separately\n# - S occupies all the anion sites\n\nlattice = Lattice.cubic(a=10.155) # Use the experimental lattice parameter\n\ncoords = np.array(\n [[0.5, 0.5, 0.5], # P (t0) - PS4 tetrahedra positions (4b)\n [0.9, 0.9, 0.6], # t1 - first type of Li site (16e)\n [0.77, 0.585, 0.585], # t2 - second type of Li site (48h)\n [0.25, 0.25, 0.25], # t3 - third type of Li site (4c)\n [0.15, 0.15, 0.15], # t4 - fourth type of Li site (16e)\n [0.0, 0.183, 0.183], # t5 - fifth type of Li site (48h)\n [0.0, 0.0, 0.0], # S - anion position (4a site)\n [0.75, 0.25, 0.25], # S - anion position (4c site)\n [0.11824, 0.11824, 0.38176]] # S - anion position (16e site)\n) \n\n# Create the reference structure with F-43m space group symmetry\n# and replicate it as a 2x2x2 supercell to match the MD simulations\nreference_structure = Structure.from_spacegroup(\n sg=\"F-43m\",\n lattice=lattice,\n species=['P', 'Li', 'Mg', 'Na', 'Be', 'K', 'S', 'S', 'S'],\n coords=coords) * [2, 2, 2]\n\nprint(f\"Reference structure contains {len(reference_structure)} atoms\")\nprint(f\"Composition: {reference_structure.composition.formula}\")" }, { "cell_type": "markdown", @@ -517,4 +476,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/docs/source/tutorials/residence_times_and_transitions.md b/docs/source/tutorials/residence_times_and_transitions.md index b8279c7..051b893 100644 --- a/docs/source/tutorials/residence_times_and_transitions.md +++ b/docs/source/tutorials/residence_times_and_transitions.md @@ -37,7 +37,7 @@ lattice = Lattice.cubic(a=10.155) coords = np.array( [[0.5, 0.5, 0.5], # P (type 0) - PS4 tetrahedra [0.9, 0.9, 0.6], # type 1 (Li in reference) - [0.23, 0.92, 0.09], # type 2 (Mg in reference) + [0.77, 0.585, 0.585], # type 2 (Mg in reference, 48h) [0.25, 0.25, 0.25], # type 3 (Na in reference) [0.15, 0.15, 0.15], # type 4 (Be in reference) [0.0, 0.183, 0.183], # type 5 (K in reference) diff --git a/site_analysis/builders.py b/site_analysis/builders.py index c954667..4b567b1 100644 --- a/site_analysis/builders.py +++ b/site_analysis/builders.py @@ -139,7 +139,10 @@ def reset(self) -> 'TrajectoryBuilder': # Mapping options self._mapping_species: list[str] | None = None - + + # Validation options + self._min_atom_distance: float = 0.5 + # Functions to be called during build() to create sites self._site_generators: list[Callable] = [] @@ -282,6 +285,41 @@ def with_site_mapping(self, mapping_species: str | list[str] | None) -> Trajecto self._mapping_species = mapping_species return self + def with_min_atom_distance(self, distance: float) -> 'TrajectoryBuilder': + """Set the minimum allowed distance between same-species atoms + in the reference structure. + + If any pair of atoms of the same species in the reference + structure is closer than this threshold, ``build()`` raises + ``ValueError``. This catches reference structures where atoms + sit on a general Wyckoff position instead of the correct + special position, producing duplicate coordination environments. + + This check runs whenever a reference structure is set, + regardless of site type. It is most relevant for polyhedral + and dynamic Voronoi site workflows, but also applies when a + reference structure is provided for mapping or alignment. + + Set to 0 to disable the check. + + Args: + distance: Minimum distance in the same units as the + lattice parameters. Must be non-negative. + The builder default is 0.5. + + Returns: + self: For method chaining. + + Raises: + ValueError: If distance is negative. + """ + if distance < 0: + raise ValueError( + f"min_atom_distance must be non-negative, got {distance}" + ) + self._min_atom_distance = distance + return self + def with_spherical_sites(self, centres: list[list[float]], radii: float | list[float], @@ -630,6 +668,80 @@ def with_existing_atoms(self, atoms: list) -> TrajectoryBuilder: self._atoms = atoms return self + def _validate_reference_atom_distances(self) -> None: + """Check that no same-species atom pairs in the reference + structure are closer than ``_min_atom_distance``.""" + from site_analysis.distances import all_mic_distances + from site_analysis.tools import indices_for_species + + ref = self._reference_structure + lattice_matrix = np.array(ref.lattice.matrix) # type: ignore[union-attr] + species_list = [s.species_string for s in ref] # type: ignore[union-attr] + frac_coords = np.array(ref.frac_coords) # type: ignore[union-attr] + + for sp in set(species_list): + indices = indices_for_species(species_list, sp) + if len(indices) < 2: + continue + + coords = frac_coords[indices] + dists = all_mic_distances(coords, coords, lattice_matrix) + np.fill_diagonal(dists, np.inf) + + min_dist = float(np.min(dists)) + if min_dist < self._min_atom_distance: + i_local, j_local = np.unravel_index( + int(np.argmin(dists)), dists.shape) + raise ValueError( + f"Reference structure has {sp} atoms at indices " + f"{indices[i_local]} and {indices[j_local]} that are " + f"only {min_dist:.3f} apart (threshold: " + f"{self._min_atom_distance}). This typically means " + f"the atom coordinate is on a general Wyckoff position " + f"instead of the correct special position, which " + f"produces duplicate coordination environments. " + f"To disable this check, call " + f".with_min_atom_distance(0) on the builder." + ) + + def _validate_unique_sites(self, sites: list[Site]) -> None: + """Check that no two sites share the same defining indices. + + Applies to PolyhedralSite (vertex_indices) and + DynamicVoronoiSite (reference_indices). Skipped for + VoronoiSite and SphericalSite where duplicate detection + based on coordinates is unreliable. + """ + from site_analysis.polyhedral_site import PolyhedralSite + from site_analysis.dynamic_voronoi_site import DynamicVoronoiSite + from site_analysis.voronoi_site import VoronoiSite + from site_analysis.spherical_site import SphericalSite + + if not sites: + return + + if isinstance(sites[0], PolyhedralSite): + index_attr = "vertex_indices" + elif isinstance(sites[0], DynamicVoronoiSite): + index_attr = "reference_indices" + elif isinstance(sites[0], (VoronoiSite, SphericalSite)): + return + else: + return # Unknown site type — skip gracefully + + seen: dict[frozenset[int], int] = {} + for site in sites: + key = frozenset(getattr(site, index_attr)) + if key in seen: + raise ValueError( + f"Duplicate sites: site {seen[key]} and site " + f"{site.index} share the same {index_attr} " + f"{sorted(key)}. This typically means the reference " + f"structure has multiple atoms inside the same " + f"coordination environment." + ) + seen[key] = site.index + def build(self) -> Trajectory: """Build and return the Trajectory object. @@ -637,10 +749,13 @@ def build(self) -> Trajectory: using the previously configured site generator. Returns: - Trajectory: The constructed Trajectory object - + The constructed Trajectory object. + Raises: - ValueError: If required parameters are missing + ValueError: If required parameters are missing, if the + reference structure has same-species atom pairs closer + than ``min_atom_distance``, or if duplicate sites are + detected. """ # Validate basic requirements if not self._structure: @@ -648,16 +763,20 @@ def build(self) -> Trajectory: if not self._site_generators: raise ValueError("Site type must be defined using one of the with_*_sites methods") + # Pre-build: validate reference structure + if self._reference_structure is not None and self._min_atom_distance > 0: + self._validate_reference_atom_distances() + # Reset the site index counter Site.reset_index() - + # Generate all sites sites: list[Site] = [] site_type = None - + for generator in self._site_generators: generated_sites = generator() - + # Verify site type consistency if generated_sites: current_type = type(generated_sites[0]) @@ -667,7 +786,10 @@ def build(self) -> Trajectory: raise TypeError(f"Cannot mix site types: {site_type.__name__} and {current_type.__name__}") sites.extend(generated_sites) - + + # Post-build: check for duplicate sites + self._validate_unique_sites(sites) + # Create atoms if not already set if not self._atoms: if not self._mobile_species: diff --git a/tests/benchmark_all_site_types.py b/tests/benchmark_all_site_types.py new file mode 100644 index 0000000..524fe37 --- /dev/null +++ b/tests/benchmark_all_site_types.py @@ -0,0 +1,150 @@ +"""Benchmark per-frame site assignment for all four site types. + +Isolates the assign_site_occupations hot path by pre-building the +trajectory and then timing only the per-frame analyse_structure calls +(which include coord assignment + site assignment). + +Uses the Li6PS5Cl argyrodite MD trajectory. + +Usage: + python tests/benchmark_all_site_types.py [--repeats N] +""" + +import argparse +import sys +import time +from pathlib import Path + +import numpy as np +from pymatgen.core import Lattice, Structure +from pymatgen.io.vasp import Xdatcar + +from site_analysis import TrajectoryBuilder +from site_analysis.site import Site + +DATA_DIR = Path(__file__).resolve().parent.parent / "docs" / "source" +ARGYRODITE_XDATCAR = DATA_DIR / "tutorials" / "data" / "Li6PS5Cl_0p_XDATCAR.gz" + + +def build_reference_structure(): + lattice = Lattice.cubic(a=10.155) + coords = np.array([ + [0.5, 0.5, 0.5], + [0.9, 0.9, 0.6], + [0.77, 0.585, 0.585], + [0.25, 0.25, 0.25], + [0.15, 0.15, 0.15], + [0.0, 0.183, 0.183], + [0.0, 0.0, 0.0], + [0.75, 0.25, 0.25], + [0.11824, 0.11824, 0.38176], + ]) + return Structure.from_spacegroup( + sg="F-43m", lattice=lattice, + species=["P", "Li", "Mg", "Na", "Be", "K", "S", "S", "S"], + coords=coords, + ) * [2, 2, 2] + + +def build_trajectory(site_type, structures, reference_structure): + Site._newid = 0 + builder = ( + TrajectoryBuilder() + .with_structure(structures[0]) + .with_reference_structure(reference_structure) + .with_mobile_species("Li") + .with_structure_alignment(align_species="P") + .with_site_mapping(mapping_species=["S", "Cl"]) + ) + + if site_type == "polyhedral": + for sp, label in [("Li", "1"), ("Mg", "2"), ("Na", "3"), + ("Be", "4"), ("K", "5")]: + builder.with_polyhedral_sites( + centre_species=sp, vertex_species="S", + cutoff=3.0, n_vertices=4, label=label) + + elif site_type == "voronoi": + li_idx = [i for i, s in enumerate(reference_structure) + if s.species_string == "Li"] + centres = [reference_structure[i].frac_coords.tolist() for i in li_idx] + builder.with_voronoi_sites(centres=centres) + + elif site_type == "dynamic_voronoi": + for sp, label in [("Li", "1"), ("Mg", "2"), ("Na", "3"), + ("Be", "4"), ("K", "5")]: + builder.with_dynamic_voronoi_sites( + centre_species=sp, reference_species="S", + cutoff=3.0, n_reference=4, label=label) + + elif site_type == "spherical": + centres = [] + for sp in ["Li", "Mg", "Na", "Be", "K"]: + idx = [i for i, s in enumerate(reference_structure) + if s.species_string == sp] + centres.extend([reference_structure[i].frac_coords.tolist() + for i in idx]) + builder.with_spherical_sites(centres=centres, radii=1.5) + + return builder.build() + + +def benchmark_site_type(site_type, structures, reference_structure, repeats): + trajectory = build_trajectory(site_type, structures, reference_structure) + + # Warm up: first frame populates caches + trajectory.append_timestep(structures[0], t=0) + + # Benchmark: repeated passes through the trajectory + all_structures = structures * repeats + start = time.perf_counter() + for s in all_structures: + trajectory.analyse_structure(s) + elapsed = time.perf_counter() - start + + n_frames = len(all_structures) + n_sites = len(trajectory.sites) + n_atoms = len(trajectory.atoms) + + return { + "site_type": site_type, + "n_sites": n_sites, + "n_atoms": n_atoms, + "n_frames": n_frames, + "elapsed_s": elapsed, + "ms_per_frame": elapsed / n_frames * 1000, + "frames_per_s": n_frames / elapsed, + } + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--repeats", type=int, default=5) + args = parser.parse_args() + + if not ARGYRODITE_XDATCAR.exists(): + print(f"Data file not found: {ARGYRODITE_XDATCAR}") + sys.exit(1) + + print("Loading trajectory data...") + xdatcar = Xdatcar(str(ARGYRODITE_XDATCAR)) + structures = xdatcar.structures + n_frames = len(structures) + reference_structure = build_reference_structure() + print(f" {n_frames} frames, {len(reference_structure)} atoms in reference") + print(f" {args.repeats} repeats = {n_frames * args.repeats} total frames\n") + + print(f"{'Site type':<20} {'Sites':>6} {'Atoms':>6} " + f"{'ms/frame':>10} {'frames/s':>10}") + print("-" * 60) + + for site_type in ["polyhedral", "voronoi", "dynamic_voronoi", "spherical"]: + result = benchmark_site_type( + site_type, structures, reference_structure, args.repeats) + print(f"{result['site_type']:<20} {result['n_sites']:>6} " + f"{result['n_atoms']:>6} {result['ms_per_frame']:>10.2f} " + f"{result['frames_per_s']:>10.1f}") + + +if __name__ == "__main__": + main() diff --git a/tests/benchmark_before_after.py b/tests/benchmark_before_after.py new file mode 100644 index 0000000..59b34ad --- /dev/null +++ b/tests/benchmark_before_after.py @@ -0,0 +1,112 @@ +"""Simple before/after benchmark for pymatgen decoupling. + +Build trajectory, then time trajectory_from_structures across real MD data. + +Usage: + python tests/benchmark_before_after.py +""" + +import sys +import time +from pathlib import Path + +# Ensure we import from the repo containing this script, not an installed package +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +import numpy as np +from pymatgen.core import Lattice, Structure +from pymatgen.io.vasp import Xdatcar + +from site_analysis import TrajectoryBuilder +from site_analysis.site import Site + +DATA_DIR = Path(__file__).resolve().parent.parent / "docs" / "source" +ARGYRODITE_XDATCAR = DATA_DIR / "tutorials" / "data" / "Li6PS5Cl_0p_XDATCAR.gz" + + +def build_reference(): + lattice = Lattice.cubic(a=10.155) + coords = np.array([ + [0.5, 0.5, 0.5], [0.9, 0.9, 0.6], [0.77, 0.585, 0.585], + [0.25, 0.25, 0.25], [0.15, 0.15, 0.15], [0.0, 0.183, 0.183], + [0.0, 0.0, 0.0], [0.75, 0.25, 0.25], [0.11824, 0.11824, 0.38176], + ]) + return Structure.from_spacegroup( + sg="F-43m", lattice=lattice, + species=["P", "Li", "Mg", "Na", "Be", "K", "S", "S", "S"], + coords=coords, + ) * [2, 2, 2] + + +def run_trajectory(site_type, structures, ref): + """Build and run a full trajectory. Returns (n_sites, elapsed_s, n_frames).""" + Site._newid = 0 + builder = ( + TrajectoryBuilder() + .with_structure(structures[0]) + .with_reference_structure(ref) + .with_mobile_species("Li") + .with_structure_alignment(align_species="P") + .with_site_mapping(mapping_species=["S", "Cl"]) + ) + + if site_type == "polyhedral": + for sp, l in [("Li", "1"), ("Mg", "2"), ("Na", "3"), ("Be", "4"), ("K", "5")]: + builder.with_polyhedral_sites( + centre_species=sp, vertex_species="S", cutoff=3.0, n_vertices=4, label=l) + elif site_type == "voronoi": + centres = [] + for sp in ["Li", "Mg", "Na", "Be", "K"]: + idx = [i for i, s in enumerate(ref) if s.species_string == sp] + centres.extend([ref[i].frac_coords.tolist() for i in idx]) + builder.with_voronoi_sites(centres=centres) + elif site_type == "dynamic_voronoi": + for sp, l in [("Li", "1"), ("Mg", "2"), ("Na", "3"), ("Be", "4"), ("K", "5")]: + builder.with_dynamic_voronoi_sites( + centre_species=sp, reference_species="S", cutoff=3.0, n_reference=4, label=l) + elif site_type == "spherical": + centres = [] + for sp in ["Li", "Mg", "Na", "Be", "K"]: + idx = [i for i, s in enumerate(ref) if s.species_string == sp] + centres.extend([ref[i].frac_coords.tolist() for i in idx]) + builder.with_spherical_sites(centres=centres, radii=1.5) + + trajectory = builder.build() + n_sites = len(trajectory.sites) + + # Time the actual MD trajectory analysis + start = time.perf_counter() + for i, s in enumerate(structures): + trajectory.append_timestep(s, t=i) + elapsed = time.perf_counter() - start + + return n_sites, elapsed, len(structures) + + +def main(): + # Verify which code we're running + import site_analysis + import inspect + from site_analysis.dynamic_voronoi_site_collection import DynamicVoronoiSiteCollection as D + src = inspect.getsource(D.assign_site_occupations) + dist_fn = "get_all_distances" if "get_all_distances" in src else "all_mic_distances" + print(f"Code: {site_analysis.__file__}") + print(f"Distance function: {dist_fn}") + + print("\nLoading Li6PS5Cl trajectory...") + xdatcar = Xdatcar(str(ARGYRODITE_XDATCAR)) + structures = xdatcar.structures + ref = build_reference() + print(f" {len(structures)} frames, {len(ref)} atoms in reference\n") + + print(f"{'Site type':<20} {'Sites':>6} {'Total (s)':>10} {'ms/frame':>10}") + print("-" * 50) + + for site_type in ["polyhedral", "voronoi", "dynamic_voronoi", "spherical"]: + n_sites, elapsed, n_frames = run_trajectory(site_type, structures, ref) + ms_per_frame = elapsed / n_frames * 1000 + print(f"{site_type:<20} {n_sites:>6} {elapsed:>10.3f} {ms_per_frame:>10.2f}") + + +if __name__ == "__main__": + main() diff --git a/tests/benchmark_containment.py b/tests/benchmark_containment.py index 2d2b26d..1c89bed 100644 --- a/tests/benchmark_containment.py +++ b/tests/benchmark_containment.py @@ -35,7 +35,7 @@ def build_argyrodite_trajectory(structures): coords = np.array([ [0.5, 0.5, 0.5], [0.9, 0.9, 0.6], - [0.23, 0.92, 0.09], + [0.77, 0.585, 0.585], [0.25, 0.25, 0.25], [0.15, 0.15, 0.15], [0.0, 0.183, 0.183], diff --git a/tests/benchmark_dynamic_voronoi.py b/tests/benchmark_dynamic_voronoi.py index 73aad31..ccacc8e 100644 --- a/tests/benchmark_dynamic_voronoi.py +++ b/tests/benchmark_dynamic_voronoi.py @@ -35,7 +35,7 @@ def build_dynamic_voronoi_trajectory(structures): coords = np.array([ [0.5, 0.5, 0.5], [0.9, 0.9, 0.6], - [0.23, 0.92, 0.09], + [0.77, 0.585, 0.585], [0.25, 0.25, 0.25], [0.15, 0.15, 0.15], [0.0, 0.183, 0.183], diff --git a/tests/test_builders.py b/tests/test_builders.py index 8fbd1b0..0d2a35a 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -1160,7 +1160,7 @@ def test_builder_state_reset_after_build(self): """Test that the builder resets its entire state after build() is called.""" # Create a builder builder = TrajectoryBuilder() - + # Set state directly to non-default values builder._structure = Mock(spec=Structure) builder._reference_structure = Mock(spec=Structure) @@ -1171,12 +1171,14 @@ def test_builder_state_reset_after_build(self): builder._align_metric = "max_dist" builder._mapping_species = ["Cl"] builder._site_generators = [lambda: []] - - # Mock the methods called within build + + # Mock the methods called within build (including validation) with patch('site_analysis.builders.atoms_from_structure'), \ patch('site_analysis.builders.Trajectory'), \ - patch('site_analysis.builders.Site.reset_index'): - + patch('site_analysis.builders.Site.reset_index'), \ + patch.object(builder, '_validate_reference_atom_distances'), \ + patch.object(builder, '_validate_unique_sites'): + # Call build builder.build() @@ -1526,5 +1528,168 @@ def test_use_reference_centers_passed_to_reference_based_sites_for_dynamic_voron use_reference_centers=False ) +class TestBuilderValidation(unittest.TestCase): + """Tests for builder validation of reference structures and duplicate sites.""" + + def _make_argyrodite_ref(self, mg_coord): + """Build an argyrodite reference with a given Mg coordinate.""" + lattice = Lattice.cubic(a=10.155) + coords = np.array([ + [0.5, 0.5, 0.5], + [0.9, 0.9, 0.6], + mg_coord, + [0.25, 0.25, 0.25], + [0.15, 0.15, 0.15], + [0.0, 0.183, 0.183], + [0.0, 0.0, 0.0], + [0.75, 0.25, 0.25], + [0.11824, 0.11824, 0.38176], + ]) + return Structure.from_spacegroup( + sg="F-43m", lattice=lattice, + species=["P", "Li", "Mg", "Na", "Be", "K", "S", "S", "S"], + coords=coords, + ) * [2, 2, 2] + + def _make_target(self, ref): + """Build a target structure from a reference (just use the reference).""" + return ref.copy() + + def test_close_atoms_raises_valueerror(self): + """Pre-build check rejects reference with close same-species atoms.""" + ref = self._make_argyrodite_ref([0.23, 0.92, 0.09]) # 96i — produces close pairs + target = self._make_target(ref) + + builder = (TrajectoryBuilder() + .with_structure(target) + .with_reference_structure(ref) + .with_mobile_species("Li") + .with_polyhedral_sites(centre_species="Li", vertex_species="S", + cutoff=3.0, n_vertices=4, label="test")) + + with self.assertRaises(ValueError) as ctx: + builder.build() + self.assertIn("Mg", str(ctx.exception)) + self.assertIn("with_min_atom_distance(0)", str(ctx.exception)) + + def test_min_atom_distance_zero_disables_check(self): + """Setting min_atom_distance to 0 disables the pre-build check.""" + ref = self._make_argyrodite_ref([0.23, 0.92, 0.09]) + target = self._make_target(ref) + + builder = (TrajectoryBuilder() + .with_structure(target) + .with_reference_structure(ref) + .with_mobile_species("Li") + .with_min_atom_distance(0) + .with_polyhedral_sites(centre_species="Li", vertex_species="S", + cutoff=3.0, n_vertices=4, label="test")) + + # Should not raise + traj = builder.build() + self.assertGreater(len(traj.sites), 0) + + def test_custom_threshold(self): + """Custom threshold below the close pair distance allows build.""" + ref = self._make_argyrodite_ref([0.23, 0.92, 0.09]) + target = self._make_target(ref) + + # Close pairs are ~0.14 A apart; threshold of 0.1 should pass + builder = (TrajectoryBuilder() + .with_structure(target) + .with_reference_structure(ref) + .with_mobile_species("Li") + .with_min_atom_distance(0.1) + .with_polyhedral_sites(centre_species="Li", vertex_species="S", + cutoff=3.0, n_vertices=4, label="test")) + + traj = builder.build() + self.assertGreater(len(traj.sites), 0) + + def test_valid_reference_passes(self): + """A correct reference structure passes validation.""" + ref = self._make_argyrodite_ref([0.77, 0.585, 0.585]) # 48h — no close pairs + target = self._make_target(ref) + + builder = (TrajectoryBuilder() + .with_structure(target) + .with_reference_structure(ref) + .with_mobile_species("Li") + .with_polyhedral_sites(centre_species="Li", vertex_species="S", + cutoff=3.0, n_vertices=4, label="test")) + + traj = builder.build() + self.assertGreater(len(traj.sites), 0) + + def test_only_checks_within_same_species(self): + """Different species at close distances should not trigger the check.""" + lattice = Lattice.cubic(5.0) + # Li and Na at almost the same position — different species, should be fine + structure = Structure(lattice=lattice, + species=["Li", "Na", "O", "O", "O", "O"], + coords=[[0.0, 0.0, 0.0], [0.01, 0.0, 0.0], + [0.5, 0.0, 0.0], [0.0, 0.5, 0.0], + [0.0, 0.0, 0.5], [0.5, 0.5, 0.5]]) + + builder = (TrajectoryBuilder() + .with_structure(structure) + .with_reference_structure(structure) + .with_mobile_species("Li") + .with_spherical_sites(centres=[[0.0, 0.0, 0.0]], radii=1.0)) + + # Should not raise — Li and Na are different species + traj = builder.build() + self.assertEqual(len(traj.sites), 1) + + def test_duplicate_polyhedral_sites_raises(self): + """Post-build check rejects duplicate polyhedral sites.""" + ref = self._make_argyrodite_ref([0.23, 0.92, 0.09]) + target = self._make_target(ref) + + builder = (TrajectoryBuilder() + .with_structure(target) + .with_reference_structure(ref) + .with_mobile_species("Li") + .with_min_atom_distance(0) # Disable pre-build to test post-build + .with_polyhedral_sites(centre_species="Mg", vertex_species="S", + cutoff=3.0, n_vertices=4, label="test")) + + with self.assertRaises(ValueError) as ctx: + builder.build() + self.assertIn("Duplicate sites", str(ctx.exception)) + + def test_duplicate_dynamic_voronoi_sites_raises(self): + """Post-build check rejects duplicate dynamic voronoi sites.""" + ref = self._make_argyrodite_ref([0.23, 0.92, 0.09]) + target = self._make_target(ref) + + builder = (TrajectoryBuilder() + .with_structure(target) + .with_reference_structure(ref) + .with_mobile_species("Li") + .with_min_atom_distance(0) # Disable pre-build to test post-build + .with_dynamic_voronoi_sites(centre_species="Mg", reference_species="S", + cutoff=3.0, n_reference=4, label="test")) + + with self.assertRaises(ValueError) as ctx: + builder.build() + self.assertIn("Duplicate sites", str(ctx.exception)) + + def test_non_duplicate_sites_pass(self): + """Non-duplicate polyhedral sites pass the post-build check.""" + ref = self._make_argyrodite_ref([0.77, 0.585, 0.585]) + target = self._make_target(ref) + + builder = (TrajectoryBuilder() + .with_structure(target) + .with_reference_structure(ref) + .with_mobile_species("Li") + .with_polyhedral_sites(centre_species="Mg", vertex_species="S", + cutoff=3.0, n_vertices=4, label="test")) + + traj = builder.build() + self.assertEqual(len(traj.sites), 384) + + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/tests/test_numba_validation.py b/tests/test_numba_validation.py index efd4d7d..48ea1a5 100644 --- a/tests/test_numba_validation.py +++ b/tests/test_numba_validation.py @@ -161,7 +161,7 @@ def _run_argyrodite_analysis(self): [ [0.5, 0.5, 0.5], [0.9, 0.9, 0.6], - [0.23, 0.92, 0.09], + [0.77, 0.585, 0.585], [0.25, 0.25, 0.25], [0.15, 0.15, 0.15], [0.0, 0.183, 0.183], diff --git a/tests/validate_decoupling.py b/tests/validate_decoupling.py new file mode 100644 index 0000000..57dc9e3 --- /dev/null +++ b/tests/validate_decoupling.py @@ -0,0 +1,309 @@ +"""Validation and benchmarking for pymatgen decoupling (#59). + +Runs the full Li6PS5Cl argyrodite analysis with all four site types +(polyhedral, Voronoi, dynamic Voronoi, spherical) using the high-level +TrajectoryBuilder API. Captures per-frame site assignments and timing. + +This script uses only the public TrajectoryBuilder API, which is +unchanged between the pre-decoupling baseline and the current code. +Run on both commits and compare the JSON output files for correctness. + +Usage: + # On current code: + python tests/validate_decoupling.py --output results_after.json + + # On baseline (git checkout 7379d58): + python tests/validate_decoupling.py --output results_before.json + + # Compare: + python tests/validate_decoupling.py --compare results_before.json results_after.json +""" + +import argparse +import json +import sys +import time +from pathlib import Path + +import numpy as np +from pymatgen.core import Lattice, Structure +from pymatgen.io.vasp import Xdatcar + +from site_analysis import TrajectoryBuilder +from site_analysis.site import Site + +DATA_DIR = Path(__file__).resolve().parent.parent / "docs" / "source" +ARGYRODITE_XDATCAR = DATA_DIR / "tutorials" / "data" / "Li6PS5Cl_0p_XDATCAR.gz" + +# Reference structure for Li6PS5Cl argyrodite (2x2x2 supercell) +LATTICE = Lattice.cubic(a=10.155) +REF_COORDS = np.array([ + [0.5, 0.5, 0.5], + [0.9, 0.9, 0.6], + [0.77, 0.585, 0.585], + [0.25, 0.25, 0.25], + [0.15, 0.15, 0.15], + [0.0, 0.183, 0.183], + [0.0, 0.0, 0.0], + [0.75, 0.25, 0.25], + [0.11824, 0.11824, 0.38176], +]) +REF_SPECIES = ["P", "Li", "Mg", "Na", "Be", "K", "S", "S", "S"] + + +def build_reference_structure(): + return Structure.from_spacegroup( + sg="F-43m", lattice=LATTICE, species=REF_SPECIES, coords=REF_COORDS, + ) * [2, 2, 2] + + +def build_trajectory(site_type, structures, reference_structure): + """Build a trajectory for one site type.""" + Site._newid = 0 + + builder = ( + TrajectoryBuilder() + .with_structure(structures[0]) + .with_reference_structure(reference_structure) + .with_mobile_species("Li") + .with_structure_alignment(align_species="P") + .with_site_mapping(mapping_species=["S", "Cl"]) + ) + + if site_type == "polyhedral": + for centre_species, label in [ + ("Li", "type 1"), ("Mg", "type 2"), ("Na", "type 3"), + ("Be", "type 4"), ("K", "type 5"), + ]: + builder.with_polyhedral_sites( + centre_species=centre_species, + vertex_species="S", + cutoff=3.0, + n_vertices=4, + label=label, + ) + + elif site_type == "voronoi": + # Use all Li-related site centres (matching other site types) + centres = [] + for sp in ["Li", "Mg", "Na", "Be", "K"]: + indices = [i for i, s in enumerate(reference_structure) + if s.species_string == sp] + centres.extend([reference_structure[i].frac_coords.tolist() + for i in indices]) + builder.with_voronoi_sites(centres=centres) + + elif site_type == "dynamic_voronoi": + for centre_species, label in [ + ("Li", "type 1"), ("Mg", "type 2"), ("Na", "type 3"), + ("Be", "type 4"), ("K", "type 5"), + ]: + builder.with_dynamic_voronoi_sites( + centre_species=centre_species, + reference_species="S", + cutoff=3.0, + n_reference=4, + label=label, + ) + + elif site_type == "spherical": + # Use all Li-related site centres with a 1.5 A radius + centres = [] + for sp in ["Li", "Mg", "Na", "Be", "K"]: + indices = [i for i, s in enumerate(reference_structure) + if s.species_string == sp] + centres.extend([reference_structure[i].frac_coords.tolist() + for i in indices]) + builder.with_spherical_sites(centres=centres, radii=1.5) + + else: + raise ValueError(f"Unknown site type: {site_type}") + + return builder.build() + + +def run_analysis(site_type, structures, reference_structure, n_repeats=1): + """Run analysis and capture results.""" + trajectory = build_trajectory(site_type, structures, reference_structure) + + n_sites = len(trajectory.sites) + n_atoms = len(trajectory.atoms) + + # Run the trajectory analysis with timing + all_structures = structures * n_repeats + start = time.perf_counter() + for i, s in enumerate(all_structures): + trajectory.append_timestep(s, t=i) + elapsed = time.perf_counter() - start + + # Capture per-atom trajectories (site assignments over time) + atom_trajectories = {} + for atom in trajectory.atoms: + atom_trajectories[str(atom.index)] = [ + int(s) if s is not None else None + for s in atom.trajectory + ] + + # Capture site occupancy counts + site_data = {} + for site in trajectory.sites: + site_data[str(site.index)] = { + "label": site.label, + "average_occupation": float(site.average_occupation) + if hasattr(site, 'average_occupation') and site.average_occupation is not None + else None, + } + + return { + "site_type": site_type, + "n_sites": n_sites, + "n_atoms": n_atoms, + "n_frames": len(all_structures), + "elapsed_s": elapsed, + "ms_per_frame": elapsed / len(all_structures) * 1000, + "atom_trajectories": atom_trajectories, + "site_data": site_data, + } + + +def run_all(structures, reference_structure, n_repeats=1): + """Run analysis for all four site types.""" + results = {} + for site_type in ["polyhedral", "voronoi", "dynamic_voronoi", "spherical"]: + print(f" Running {site_type}...", end=" ", flush=True) + try: + result = run_analysis( + site_type, structures, reference_structure, n_repeats) + print(f"{result['n_sites']} sites, {result['ms_per_frame']:.2f} ms/frame") + results[site_type] = result + except Exception as e: + print(f"FAILED: {e}") + results[site_type] = {"error": str(e)} + return results + + +def compare_results(before_path, after_path): + """Compare two result files for correctness and performance.""" + with open(before_path) as f: + before = json.load(f) + with open(after_path) as f: + after = json.load(f) + + print("\n=== Correctness Comparison ===\n") + + all_match = True + for site_type in ["polyhedral", "voronoi", "dynamic_voronoi", "spherical"]: + b = before.get(site_type, {}) + a = after.get(site_type, {}) + + if "error" in b or "error" in a: + print(f"{site_type}: SKIPPED (error in one or both)") + if "error" in b: + print(f" Before: {b['error']}") + if "error" in a: + print(f" After: {a['error']}") + all_match = False + continue + + # Compare site counts + if b["n_sites"] != a["n_sites"]: + print(f"{site_type}: MISMATCH — {b['n_sites']} sites before, " + f"{a['n_sites']} after") + all_match = False + continue + + # Compare atom trajectories frame by frame + b_traj = b["atom_trajectories"] + a_traj = a["atom_trajectories"] + + mismatches = 0 + total_assignments = 0 + for atom_idx in b_traj: + if atom_idx not in a_traj: + mismatches += len(b_traj[atom_idx]) + total_assignments += len(b_traj[atom_idx]) + continue + for frame_i, (bv, av) in enumerate( + zip(b_traj[atom_idx], a_traj[atom_idx])): + total_assignments += 1 + if bv != av: + mismatches += 1 + + if mismatches == 0: + print(f"{site_type}: MATCH ({total_assignments} assignments identical)") + else: + print(f"{site_type}: MISMATCH — {mismatches}/{total_assignments} " + f"assignments differ ({mismatches/total_assignments*100:.1f}%)") + all_match = False + + print("\n=== Performance Comparison ===\n") + print(f"{'Site type':<20} {'Before (ms/f)':>15} {'After (ms/f)':>15} {'Speedup':>10}") + print("-" * 65) + for site_type in ["polyhedral", "voronoi", "dynamic_voronoi", "spherical"]: + b = before.get(site_type, {}) + a = after.get(site_type, {}) + if "error" in b or "error" in a: + print(f"{site_type:<20} {'error':>15} {'error':>15} {'—':>10}") + continue + b_ms = b["ms_per_frame"] + a_ms = a["ms_per_frame"] + speedup = b_ms / a_ms if a_ms > 0 else float("inf") + print(f"{site_type:<20} {b_ms:>15.2f} {a_ms:>15.2f} {speedup:>9.2f}x") + + if all_match: + print("\nAll site assignments match. Decoupling is correctness-preserving.") + else: + print("\nWARNING: Some site assignments differ. Investigate before closing #59.") + + return all_match + + +def main(): + parser = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--output", type=str, default=None, + help="Save results to JSON file") + parser.add_argument("--compare", nargs=2, metavar=("BEFORE", "AFTER"), + help="Compare two result files") + parser.add_argument("--repeats", type=int, default=1, + help="Number of times to repeat the trajectory") + args = parser.parse_args() + + if args.compare: + success = compare_results(args.compare[0], args.compare[1]) + sys.exit(0 if success else 1) + + if not ARGYRODITE_XDATCAR.exists(): + print(f"Data file not found: {ARGYRODITE_XDATCAR}") + sys.exit(1) + + print("Loading trajectory data...") + xdatcar = Xdatcar(str(ARGYRODITE_XDATCAR)) + structures = xdatcar.structures + print(f" {len(structures)} frames") + + reference_structure = build_reference_structure() + print(f" Reference: {len(reference_structure)} atoms") + + print(f"\nRunning analysis (repeats={args.repeats})...") + results = run_all(structures, reference_structure, n_repeats=args.repeats) + + if args.output: + with open(args.output, "w") as f: + json.dump(results, f, indent=2) + print(f"\nResults saved to {args.output}") + else: + # Print summary + print("\n=== Summary ===\n") + print(f"{'Site type':<20} {'Sites':>8} {'Atoms':>8} {'ms/frame':>12}") + print("-" * 50) + for site_type, result in results.items(): + if "error" in result: + print(f"{site_type:<20} {'ERROR':>8}") + else: + print(f"{site_type:<20} {result['n_sites']:>8} " + f"{result['n_atoms']:>8} {result['ms_per_frame']:>12.2f}") + + +if __name__ == "__main__": + main()