Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@ Changelog
0.19.0
------

- ``StructureScene.species`` is now stored as a tuple. The sequence
is fixed at construction.

- New :meth:`~hofmann.StructureScene.select_by_species` method filters
a full-length per-atom array to keep only selected species, filling
the rest with the appropriate missing sentinel (``NaN`` for numeric,
``None`` for categorical).

- :meth:`~hofmann.StructureScene.set_atom_data` gains ``by_species``
and ``by_index`` keyword arguments for sparse per-atom metadata
assignment. ``by_species`` maps species labels to values;
Expand Down
19 changes: 19 additions & 0 deletions docs/colouring.rst
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,25 @@ of length ``n_frames``. Either of these promotes the output to 2-D.
Scalar and 1-D ``by_species`` values and scalar ``by_index`` values
broadcast across frames automatically.

Filtering a full-length array
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

When you have a full-length array (e.g. from an external calculation)
but only want to colour certain species, use
:meth:`~hofmann.StructureScene.select_by_species` to replace
non-selected atoms with the appropriate missing sentinel:

.. code-block:: python

# Keep only O-atom charges; other atoms fall back to species colour.
scene.set_atom_data(
"charge",
scene.select_by_species(full_charge_array, "O"),
)

This handles integer-to-float promotion, unicode-to-object promotion,
and species-label validation automatically.

Custom colouring functions
--------------------------

Expand Down
14 changes: 7 additions & 7 deletions src/hofmann/construction/bonds.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


def compute_bonds(
species: list[str],
species: tuple[str, ...],
coords: np.ndarray,
Comment thread
bjmorgan marked this conversation as resolved.
bond_specs: list[BondSpec],
lattice: np.ndarray | None = None,
Expand All @@ -28,7 +28,7 @@ def compute_bonds(
lattice parameter apart).

Args:
species: List of species labels, length ``n_atoms``.
species: Species labels, one per atom.
coords: Coordinates array of shape ``(n_atoms, 3)``.
bond_specs: List of BondSpec rules to apply.
lattice: 3x3 matrix of lattice vectors (row vectors).
Expand Down Expand Up @@ -70,7 +70,7 @@ def compute_bonds(


def _compute_bonds_direct(
species: list[str],
species: tuple[str, ...],
diff: np.ndarray,
bond_specs: list[BondSpec],
unique_species: list[str],
Expand Down Expand Up @@ -121,7 +121,7 @@ def _inscribed_sphere_radius(lattice: np.ndarray) -> float:


def _compute_bonds_periodic(
species: list[str],
species: tuple[str, ...],
diff: np.ndarray,
bond_specs: list[BondSpec],
lattice: np.ndarray,
Expand Down Expand Up @@ -160,7 +160,7 @@ def _compute_bonds_periodic(


def _compute_bonds_mic(
species: list[str],
species: tuple[str, ...],
diff_frac: np.ndarray,
bond_specs: list[BondSpec],
lattice: np.ndarray,
Expand Down Expand Up @@ -200,7 +200,7 @@ def _compute_bonds_mic(


def _compute_bonds_multi_image(
species: list[str],
species: tuple[str, ...],
diff_frac: np.ndarray,
bond_specs: list[BondSpec],
lattice: np.ndarray,
Expand Down Expand Up @@ -268,7 +268,7 @@ def _compute_bonds_multi_image(

def _species_pair_mask(
spec: BondSpec,
species: list[str],
species: tuple[str, ...],
unique_species: list[str],
) -> np.ndarray:
"""Build a boolean (n, n) mask for species pairs matching *spec*."""
Expand Down
3 changes: 2 additions & 1 deletion src/hofmann/construction/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import json
from collections.abc import Sequence
from importlib import resources

from hofmann.model import AtomStyle, BondSpec, Colour
Expand Down Expand Up @@ -262,7 +263,7 @@ def _load_vesta_cutoffs() -> dict[tuple[str, str], float]:


def default_bond_specs(
species: list[str],
species: Sequence[str],
*,
bond_radius: float | None = None,
bond_colour: Colour | None = None,
Expand Down
2 changes: 1 addition & 1 deletion src/hofmann/construction/polyhedra.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


def compute_polyhedra(
species: list[str],
species: tuple[str, ...],
coords: np.ndarray,
bonds: list[Bond],
polyhedra_specs: list[PolyhedronSpec],
Expand Down
6 changes: 3 additions & 3 deletions src/hofmann/construction/rendering_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class RenderingSet:
``[0, 1, 2, ...]``.
"""

species: list[str]
species: list[str] # list, not tuple: built incrementally during image expansion
coords: np.ndarray
bonds: list[Bond]
source_indices: np.ndarray
Expand Down Expand Up @@ -196,7 +196,7 @@ def _discover_bonds_for_new_atoms(


def _complete_polyhedra_vertices(
species: list[str],
species: tuple[str, ...],
coords: np.ndarray,
lattice: np.ndarray,
n_physical: int,
Expand Down Expand Up @@ -292,7 +292,7 @@ def _is_centre(sp: str) -> bool:


def build_rendering_set(
species: list[str],
species: tuple[str, ...],
coords: np.ndarray,
periodic_bonds: list[Bond],
bond_specs: list[BondSpec],
Expand Down
4 changes: 2 additions & 2 deletions src/hofmann/model/colour.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def normalise_colour(colour: Colour) -> tuple[float, float, float]:


def _species_colours(
species: list[str],
species: tuple[str, ...],
atom_styles: dict[str, AtomStyle],
) -> list[tuple[float, float, float]]:
"""Return per-atom colours from species styles (the default path).
Expand Down Expand Up @@ -167,7 +167,7 @@ def _resolve_single_layer(


def _resolve_atom_colours(
species: list[str],
species: tuple[str, ...],
atom_styles: dict[str, AtomStyle],
atom_data: dict[str, np.ndarray],
colour_by: str | list[str] | None = None,
Expand Down
95 changes: 91 additions & 4 deletions src/hofmann/model/structure_scene.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from collections.abc import Sequence
from collections.abc import Iterable, Sequence
from pathlib import Path
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -30,7 +30,8 @@ class StructureScene:
(per-atom metadata) properties are documented individually below.

Attributes:
species: One label per atom.
species: One label per atom. Stored as a tuple; the
sequence is fixed at construction.
frames: List of coordinate snapshots. Each :class:`Frame` may
carry its own ``lattice`` for variable-cell trajectories.
atom_styles: Mapping from species label to visual style.
Expand All @@ -41,7 +42,7 @@ class StructureScene:

def __init__(
self,
species: list[str],
species: Sequence[str],
frames: list[Frame],
atom_styles: dict[str, AtomStyle] | None = None,
bond_specs: list[BondSpec] | None = None,
Expand All @@ -50,7 +51,7 @@ def __init__(
title: str = "",
atom_data: dict[str, ArrayLike] | None = None,
) -> None:
self.species = species
self.species: tuple[str, ...] = tuple(species)
self.frames = frames
self.atom_styles = atom_styles if atom_styles is not None else {}
self.bond_specs = bond_specs if bond_specs is not None else []
Expand Down Expand Up @@ -699,6 +700,92 @@ def clear_2d_atom_data(self) -> None:
"""
self._atom_data._clear_2d()

def select_by_species(
self,
values: ArrayLike,
species: str | Iterable[str],
) -> np.ndarray:
"""Keep values for selected species, fill the rest with sentinels.

Returns a copy of *values* with entries for non-selected atoms
replaced by the appropriate missing sentinel: ``NaN`` for
numeric data (with integer-to-float promotion) or ``None``
for categorical data (with unicode-to-object promotion).

Intended for filtering a full-length array before passing it
to :meth:`set_atom_data`::

scene.set_atom_data(
"charge",
scene.select_by_species(full_array, "O"),
)

Args:
values: Array-like of shape ``(n_atoms,)`` or
``(n_frames, n_atoms)``.
species: A single species label or an iterable of labels
to keep.

Returns:
A new array with the same shape as *values*.

Raises:
ValueError: If *species* contains unknown labels or if
*values* has the wrong shape.
"""
arr = np.asarray(values)
n_atoms = len(self.species)

if arr.ndim == 1:
if len(arr) != n_atoms:
raise ValueError(
f"values must have length {n_atoms}, got {len(arr)}"
)
elif arr.ndim == 2:
if arr.shape[1] != n_atoms:
raise ValueError(
f"values must have {n_atoms} columns, "
f"got {arr.shape[1]}"
)
else:
raise ValueError(
f"values must be 1-D or 2-D, got {arr.ndim}-D"
)

if arr.dtype.kind not in ("b", "i", "u", "f", "U", "O"):
raise ValueError(
f"unsupported dtype {arr.dtype}; supported dtypes "
f"are bool, integer, float, string, and object"
)

if isinstance(species, str):
keep = {species}
else:
keep = set(species)

known = set(self.species)
unknown = keep - known
if unknown:
raise ValueError(
f"unknown species: {', '.join(sorted(unknown))}"
)

mask = np.array([s in keep for s in self.species])

# Build output with appropriate sentinel.
if arr.dtype.kind in ("U", "O"):
out = np.empty_like(arr, dtype=object)
out[:] = None
Comment thread
bjmorgan marked this conversation as resolved.
else:
out = np.full_like(arr, np.nan, dtype=float)

Comment thread
bjmorgan marked this conversation as resolved.
if arr.ndim == 1:
out[mask] = arr[mask]
else:
out[:, mask] = arr[:, mask]

return out

def _validate_for_render(self) -> None:
"""Raise if atom_data is incompatible with ``len(self.frames)``.

Expand Down
2 changes: 1 addition & 1 deletion src/hofmann/rendering/precompute.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def _precompute_scene(
)
if rs.deduplicate_molecules:
rset = deduplicate_molecules(rset, lattice)
species = rset.species
species = tuple(rset.species)
Comment thread
bjmorgan marked this conversation as resolved.
coords = rset.coords
bonds = rset.bonds
source_indices = rset.source_indices
Expand Down
Loading
Loading